Created
April 20, 2023 13:02
-
-
Save TadaoYamaoka/db819a7b8a7b1cbc5d459cd178d623fa to your computer and use it in GitHub Desktop.
lyra/cli_example/python_lib.cc
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#include "lyra/cli_example/encoder_main_lib.h" | |
#include "lyra/cli_example/decoder_main_lib.h" | |
#include <algorithm> | |
#include <cstdint> | |
#include <fstream> | |
#include <iterator> | |
#include <memory> | |
#include <optional> | |
#include <sstream> | |
#include <string> | |
#include <type_traits> | |
#include <vector> | |
#include "absl/flags/marshalling.h" | |
#include "absl/random/bit_gen_ref.h" | |
#include "absl/random/random.h" | |
#include "absl/status/status.h" | |
#include "absl/strings/string_view.h" | |
#include "absl/time/clock.h" | |
#include "absl/time/time.h" | |
#include "absl/types/span.h" | |
#include "glog/logging.h" // IWYU pragma: keep | |
#include "include/ghc/filesystem.hpp" | |
#include "lyra/fixed_packet_loss_model.h" | |
#include "lyra/gilbert_model.h" | |
#include "lyra/lyra_config.h" | |
#include "lyra/lyra_decoder.h" | |
#include "lyra/wav_utils.h" | |
#include <pybind11/pybind11.h> | |
#include <pybind11/stl.h> | |
#include <pybind11/numpy.h> | |
namespace py = pybind11; | |
// Encodes 16-bit PCM samples with the Lyra encoder and returns the encoded
// bytes as a one-dimensional uint8 NumPy array.
//
// All arguments are forwarded unchanged to chromemedia::codec::EncodeWav.
// Packets are appended to the encoded buffer; the oldest packet is encoded
// starting at index 0.
//
// Returns a default-constructed array when EncodeWav reports failure.
py::array_t<uint8_t> encode_wav(const std::vector<int16_t>& wav_data,
                                int num_channels, int sample_rate_hz,
                                int bitrate, bool enable_preprocessing,
                                bool enable_dtx,
                                const std::string& model_path) {
  std::vector<uint8_t> encoded_features;
  const bool ok = chromemedia::codec::EncodeWav(
      wav_data, num_channels, sample_rate_hz, bitrate, enable_preprocessing,
      enable_dtx, model_path, &encoded_features);
  if (!ok) {
    return py::array_t<uint8_t>();
  }

  // Allocate the result with the same element type as the declared return
  // type so no signed/unsigned dtype conversion happens on return.
  auto result = py::array_t<uint8_t>(encoded_features.size());
  py::buffer_info buf = result.request();
  uint8_t* out = static_cast<uint8_t*>(buf.ptr);
  std::copy(encoded_features.begin(), encoded_features.end(), out);
  return result;
}
// Decodes Lyra-encoded bytes back into 16-bit PCM samples and returns them as
// a one-dimensional int16 NumPy array.
//
// A LyraDecoder is created for `sample_rate_hz` and
// chromemedia::codec::kNumChannels using the model found at `model_path`.
// Packet loss is simulated with a GilbertModel built from `packet_loss_rate`
// and `average_burst_length`. The packet size passed to DecodeFeatures is
// derived from `bitrate` via BitrateToPacketSize.
//
// Returns a default-constructed array when the decoder or the packet loss
// model cannot be created, or when DecodeFeatures reports failure.
py::array_t<int16_t> decode_features(
    const std::vector<uint8_t>& encoded_features, int sample_rate_hz,
    int bitrate, bool randomize_num_samples_requested, float packet_loss_rate,
    float average_burst_length, const std::string& model_path) {
  auto decoder = chromemedia::codec::LyraDecoder::Create(
      sample_rate_hz, chromemedia::codec::kNumChannels, model_path);
  if (decoder == nullptr) {
    LOG(ERROR) << "Could not create lyra decoder.";
    return py::array_t<int16_t>();
  }

  // This binding never supplies a fixed loss pattern, so the Gilbert model is
  // the only packet loss model that can be selected here.
  std::unique_ptr<chromemedia::codec::PacketLossModelInterface>
      packet_loss_model = chromemedia::codec::GilbertModel::Create(
          packet_loss_rate, average_burst_length);
  if (packet_loss_model == nullptr) {
    LOG(ERROR) << "Could not create packet loss simulator model.";
    return py::array_t<int16_t>();
  }

  const int packet_size = chromemedia::codec::BitrateToPacketSize(bitrate);
  std::vector<int16_t> decoded_audio;
  absl::BitGen gen;
  const bool ok = chromemedia::codec::DecodeFeatures(
      encoded_features, packet_size, randomize_num_samples_requested, gen,
      decoder.get(), packet_loss_model.get(), &decoded_audio);
  if (!ok) {
    return py::array_t<int16_t>();
  }

  // Copy the decoded samples into a freshly allocated ndarray.
  auto result = py::array_t<int16_t>(decoded_audio.size());
  py::buffer_info buf = result.request();
  int16_t* out = static_cast<int16_t*>(buf.ptr);
  std::copy(decoded_audio.begin(), decoded_audio.end(), out);
  return result;
}
// Sets glog's minimum log level flag (FLAGS_minloglevel) to `level`.
//
// glog is initialized under the program name "lyra" on the first call only;
// calling InitGoogleLogging more than once is a fatal error in glog, so later
// calls just update the flag.
void set_loglevel(const int level) {
  // Function-local static: the initializer runs exactly once.
  static const bool glog_initialized = [] {
    google::InitGoogleLogging("lyra");
    return true;
  }();
  (void)glog_initialized;
  FLAGS_minloglevel = level;
}
// Python module definition. Arguments are named so callers can pass them by
// keyword; positional calls keep working unchanged.
PYBIND11_MODULE(lyra, m) {
  m.def("encode_wav", &encode_wav,
        "Encode int16 PCM samples; returns the encoded bytes as a uint8 "
        "array (empty on failure).",
        py::arg("wav_data"), py::arg("num_channels"),
        py::arg("sample_rate_hz"), py::arg("bitrate"),
        py::arg("enable_preprocessing"), py::arg("enable_dtx"),
        py::arg("model_path"));
  m.def("decode_features", &decode_features,
        "Decode encoded bytes; returns the decoded samples as an int16 "
        "array (empty on failure).",
        py::arg("encoded_features"), py::arg("sample_rate_hz"),
        py::arg("bitrate"), py::arg("randomize_num_samples_requested"),
        py::arg("packet_loss_rate"), py::arg("average_burst_length"),
        py::arg("model_path"));
  m.def("set_loglevel", &set_loglevel,
        "Set the glog minimum log level.", py::arg("level"));
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment