enh: make bark.h a C header #170

Merged: 6 commits, merged on May 8, 2024
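This PR converts the public bark.h interface so it can be consumed from plain C: std::string parameters become const char *, C++-only types move out of the header into bark.cpp, and new accessor functions expose data that previously required reaching into the context struct. As a rough orientation before the diff, here is a hedged sketch of the header shape this change implies; the guard macros and exact declarations are illustrative, not copied from bark.h:

    /* Hypothetical bark.h outline after this change (illustrative only). */
    #include <stdbool.h>
    #include <stdint.h>

    #ifdef __cplusplus
    extern "C" {
    #endif

    /* bark_verbosity_level is assumed to be a typedef'd enum in the real header */
    struct bark_context;  /* opaque: the definition now lives in bark.cpp */

    struct bark_context *bark_load_model(const char *model_path, bark_verbosity_level verbosity, uint32_t seed);
    bool bark_generate_audio(struct bark_context *bctx, const char *text, int n_threads);

    #ifdef __cplusplus
    }  /* extern "C" */
    #endif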
bark.cpp: 188 changes (170 additions, 18 deletions)
@@ -19,6 +19,8 @@
 #include <random>
 #include <regex>
 #include <string>
+#include <thread>
+#include <vector>
 
 #include "bark.h"
 #include "encodec.h"
@@ -30,6 +32,121 @@
 
 static const size_t MB = 1024 * 1024;
 
+typedef int32_t bark_token;
+typedef std::vector<int32_t> bark_sequence;
+typedef std::vector<std::vector<int32_t>> bark_codes;
+
+struct bark_vocab {
+    using id = int32_t;
+    using token = std::string;
+
+    std::map<token, id> token_to_id;
+    std::map<id, token> id_to_token;
+};
+
+struct gpt_layer {
+    // normalization
+    struct ggml_tensor *ln_1_g;
+    struct ggml_tensor *ln_1_b;
+
+    struct ggml_tensor *ln_2_g;
+    struct ggml_tensor *ln_2_b;
+
+    // attention
+    struct ggml_tensor *c_attn_attn_w;
+    struct ggml_tensor *c_attn_attn_b;
+
+    struct ggml_tensor *c_attn_proj_w;
+    struct ggml_tensor *c_attn_proj_b;
+
+    // mlp
+    struct ggml_tensor *c_mlp_fc_w;
+    struct ggml_tensor *c_mlp_fc_b;
+
+    struct ggml_tensor *c_mlp_proj_w;
+    struct ggml_tensor *c_mlp_proj_b;
+};
+
+struct gpt_model {
+    gpt_hparams hparams;
+
+    // normalization
+    struct ggml_tensor *ln_f_g;
+    struct ggml_tensor *ln_f_b;
+
+    struct ggml_tensor *wpe;                    // position embedding
+    std::vector<struct ggml_tensor *> wtes;     // token embedding
+    std::vector<struct ggml_tensor *> lm_heads; // language model head
+
+    std::vector<gpt_layer> layers;
+
+    // key + value memory
+    struct ggml_tensor *memory_k;
+    struct ggml_tensor *memory_v;
+
+    struct ggml_context *ctx;
+
+    ggml_backend_t backend = NULL;
+
+    ggml_backend_buffer_t buffer_w;
+    ggml_backend_buffer_t buffer_kv;
+
+    std::map<std::string, struct ggml_tensor *> tensors;
+
+    // timing statistics
+    int64_t t_sample_us = 0;
+    int64_t t_predict_us = 0;
+    int64_t t_main_us = 0;
+
+    // sampling statistics
+    int64_t n_sample = 0;
+
+    // memory usage
+    int64_t memsize = 0;
+};
+
+struct bark_model {
+    // The token encoders
+    struct gpt_model semantic_model;
+    struct gpt_model coarse_model;
+    struct gpt_model fine_model;
+
+    // The vocabulary for the semantic encoder
+    struct bark_vocab vocab;
+};
+
+struct bark_context {
+    struct bark_model text_model;
+
+    struct encodec_context *encodec_ctx;
+
+    // buffer for model evaluation
+    ggml_backend_buffer_t buf_compute;
+
+    // custom allocator
+    struct ggml_allocr *allocr = NULL;
+    int n_gpu_layers = 0;
+
+    std::mt19937 rng;
+
+    bark_sequence tokens;
+    bark_sequence semantic_tokens;
+
+    bark_codes coarse_tokens;
+    bark_codes fine_tokens;
+
+    std::vector<float> audio_arr;
+
+    // hyperparameters
+    bark_context_params params;
+
+    // encodec parameters
+    std::string encodec_model_path;
+
+    // statistics
+    bark_statistics stats;
+};
+
 class BarkProgressBar {
 public:
     BarkProgressBar(std::string func_name, double needed_progress) {
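The structs above (bark_vocab, gpt_layer, gpt_model, bark_model, bark_context) hold std::vector, std::map, and std::mt19937 members, none of which can appear in a C header, so this change moves their definitions out of bark.h and into bark.cpp. This is the classic opaque-handle (pimpl-style) API pattern; assuming the header follows it, C callers only ever see the incomplete type:

    /* bark.h side (sketch): an incomplete type is all a C compiler needs. */
    struct bark_context;

    /* bark.cpp side: the full definition, free to use C++ members,
       exactly as in the struct bark_context added above. Callers never
       touch members directly; everything goes through functions. */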
@@ -1070,22 +1187,23 @@ static bool bark_load_model_from_file(
     return true;
 }
 
-struct bark_context* bark_load_model(const std::string& model_path, bark_verbosity_level verbosity, uint32_t seed) {
+struct bark_context* bark_load_model(const char *model_path, bark_verbosity_level verbosity, uint32_t seed) {
     int64_t t_load_start_us = ggml_time_us();
 
     struct bark_context* bctx = new bark_context();
 
     bctx->text_model = bark_model();
-    if (!bark_load_model_from_file(model_path, bctx, verbosity)) {
-        fprintf(stderr, "%s: failed to load model weights from '%s'\n", __func__, model_path.c_str());
+    std::string model_path_str(model_path);
+    if (!bark_load_model_from_file(model_path_str, bctx, verbosity)) {
+        fprintf(stderr, "%s: failed to load model weights from '%s'\n", __func__, model_path);
         return nullptr;
     }
 
     bark_context_params params = bark_context_default_params();
     params.verbosity = verbosity;
     bctx->rng = std::mt19937(seed);
     bctx->params = params;
-    bctx->t_load_us = ggml_time_us() - t_load_start_us;
+    bctx->stats.t_load_us = ggml_time_us() - t_load_start_us;
 
     return bctx;
 }
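The pattern here, repeated for the other entry points below, is to convert at the API boundary: the exported function takes const char *, wraps it in a std::string for the C++ helpers, and passes the raw pointer straight to the printf-style logging. A hedged sketch of the resulting call from a C translation unit (the path is a placeholder, and 0 is assumed to be a valid bark_verbosity_level value):

    #include "bark.h"

    int main(void) {
        /* placeholder path; verbosity 0 and seed 0 are assumptions */
        struct bark_context *bctx = bark_load_model("./ggml_weights.bin", 0, 0);
        if (!bctx) {
            return 1; /* failure is logged to stderr and nullptr returned */
        }
        return 0;
    }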
@@ -1629,6 +1747,7 @@ static bool bark_eval_text_encoder(struct bark_context* bctx, int n_threads) {
     }
 
     bctx->semantic_tokens = output;
+    bctx->stats.n_sample_semantic = model.n_sample;
 
     return true;
 }
@@ -1672,6 +1791,7 @@ bool bark_forward_text_encoder(struct bark_context* bctx, int n_threads) {
     }
 
     model.t_main_us = ggml_time_us() - t_main_start_us;
+    bctx->stats.t_semantic_us = model.t_main_us;
 
     bark_print_statistics(&model);
 
@@ -1797,6 +1917,7 @@ static bool bark_eval_coarse_encoder(struct bark_context* bctx, int n_threads) {
     }
 
     bctx->coarse_tokens = out_coarse;
+    bctx->stats.n_sample_coarse = model.n_sample;
 
     return true;
 }
@@ -1840,6 +1961,7 @@ bool bark_forward_coarse_encoder(struct bark_context* bctx, int n_threads) {
     }
 
     model.t_main_us = ggml_time_us() - t_main_start_us;
+    bctx->stats.t_coarse_us = model.t_main_us;
 
     bark_print_statistics(&model);
 
@@ -1989,6 +2111,7 @@ static bool bark_eval_fine_encoder(struct bark_context* bctx, int n_threads) {
     assert(bctx->coarse_tokens.size() == in_arr.size());
 
     bctx->fine_tokens = in_arr;
+    bctx->stats.n_sample_fine = model.n_sample;
 
     return true;
 }
@@ -2034,6 +2157,7 @@ bool bark_forward_fine_encoder(struct bark_context* bctx, int n_threads) {
     }
 
     model.t_main_us = ggml_time_us() - t_main_start_us;
+    bctx->stats.t_fine_us = model.t_main_us;
 
     bark_print_statistics(&model);
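Each pipeline stage now copies its timings and sample counts into the bark_statistics block on the context rather than leaving them only in the per-model fields, which is what lets the bark_get_statistics accessor added at the end of this diff expose them through a C-safe struct. From the fields this diff writes, the struct plausibly looks like the sketch below; the real declaration lives in bark.h and may differ:

    #include <stdint.h>

    /* Plausible reconstruction from the fields touched in this diff. */
    typedef struct bark_statistics {
        int64_t t_load_us;         /* model load time             */
        int64_t t_eval_us;         /* total generation time       */
        int64_t t_semantic_us;     /* semantic encoder wall time  */
        int64_t t_coarse_us;       /* coarse encoder wall time    */
        int64_t t_fine_us;         /* fine encoder wall time      */
        int64_t n_sample_semantic; /* tokens sampled per stage    */
        int64_t n_sample_coarse;
        int64_t n_sample_fine;
    } bark_statistics;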

@@ -2062,15 +2186,16 @@ static bool bark_forward_eval(struct bark_context* bctx, int n_threads) {
     return true;
 }
 
-bool bark_generate_audio(struct bark_context* bctx, const std::string& text, int n_threads) {
+bool bark_generate_audio(struct bark_context* bctx, const char * text, int n_threads) {
     if (!bctx) {
         fprintf(stderr, "%s: invalid bark context\n", __func__);
         return false;
     }
 
     int64_t t_start_eval_us = ggml_time_us();
 
-    bark_tokenize_input(bctx, text);
+    std::string text_str(text);
+    bark_tokenize_input(bctx, text_str);
 
     if (!bark_forward_eval(bctx, n_threads)) {
         fprintf(stderr, "%s: failed to forward eval\n", __func__);
@@ -2101,8 +2226,7 @@ bool bark_generate_audio(struct bark_context* bctx, const std::string& text, int
     }
 
     bctx->audio_arr = bctx->encodec_ctx->out_audio;
-
-    bctx->t_eval_us = ggml_time_us() - t_start_eval_us;
+    bctx->stats.t_eval_us = ggml_time_us() - t_start_eval_us;
 
     return true;
 }
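With both entry points taking C strings, the end-to-end call sequence from C reduces to the following sketch; the path, prompt, thread count, and verbosity value are placeholders:

    #include <stdio.h>
    #include "bark.h"

    int main(void) {
        struct bark_context *bctx = bark_load_model("./ggml_weights.bin", 0, 0);
        if (!bctx) return 1;

        if (!bark_generate_audio(bctx, "Hello from a C caller", /* n_threads */ 4)) {
            fprintf(stderr, "generation failed\n");
            return 1;
        }
        /* samples are then read back through the accessors added
           at the bottom of this diff */
        return 0;
    }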
@@ -2230,21 +2354,21 @@ bool bark_model_weights_quantize(std::ifstream& fin, std::ofstream& fout, ggml_f
     return true;
 }
 
-bool bark_model_quantize(
-    const std::string& fname_inp,
-    const std::string& fname_out,
-    ggml_ftype ftype) {
-    printf("%s: loading model from '%s'\n", __func__, fname_inp.c_str());
+bool bark_model_quantize(const char * fname_inp, const char * fname_out, ggml_ftype ftype) {
+    printf("%s: loading model from '%s'\n", __func__, fname_inp);
 
-    auto fin = std::ifstream(fname_inp, std::ios::binary);
+    std::string fname_inp_str(fname_inp);
+    std::string fname_out_str(fname_out);
+
+    auto fin = std::ifstream(fname_inp_str, std::ios::binary);
     if (!fin) {
-        fprintf(stderr, "%s: failed to open '%s' for reading\n", __func__, fname_inp.c_str());
+        fprintf(stderr, "%s: failed to open '%s' for reading\n", __func__, fname_inp);
         return false;
     }
 
-    auto fout = std::ofstream(fname_out, std::ios::binary);
+    auto fout = std::ofstream(fname_out_str, std::ios::binary);
     if (!fout) {
-        fprintf(stderr, "%s: failed to open '%s' for writing\n", __func__, fname_out.c_str());
+        fprintf(stderr, "%s: failed to open '%s' for writing\n", __func__, fname_out);
         return false;
     }
 
@@ -2253,7 +2377,7 @@ bool bark_model_quantize(
     uint32_t magic;
     fin.read((char*)&magic, sizeof(magic));
     if (magic != GGML_FILE_MAGIC) {
-        fprintf(stderr, "%s: invalid model file '%s' (bad magic)\n", __func__, fname_inp.c_str());
+        fprintf(stderr, "%s: invalid model file '%s' (bad magic)\n", __func__, fname_inp);
         return false;
     }
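Quantization gets the same boundary treatment: both filenames arrive as const char *, get wrapped into std::string for the stream constructors, and the raw pointers feed the log messages. A sketch of invoking it from C; the filenames are placeholders, while GGML_FTYPE_MOSTLY_Q4_0 is one of the ggml_ftype values defined in ggml.h:

    if (!bark_model_quantize("./ggml_weights.bin", "./ggml_weights_q4.bin",
                             GGML_FTYPE_MOSTLY_Q4_0)) {
        fprintf(stderr, "quantization failed\n");
    }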

@@ -2301,3 +2425,31 @@
 
     return true;
 }
+
+float * bark_get_audio_data(struct bark_context *bctx) {
+    if (!bctx || bctx->audio_arr.empty()) {
+        return nullptr;
+    }
+    return bctx->audio_arr.data();
+}
+
+int bark_get_audio_data_size(struct bark_context *bctx) {
+    if (!bctx || bctx->audio_arr.empty()) {
+        return 0;
+    }
+    return bctx->audio_arr.size();
+}
+
+const bark_statistics * bark_get_statistics(struct bark_context *bctx) {
+    if (!bctx) {
+        return nullptr;
+    }
+    return &bctx->stats;
+}
+
+void bark_reset_statistics(struct bark_context *bctx) {
+    if (!bctx) {
+        return;
+    }
+    memset(&bctx->stats, 0, sizeof(bark_statistics));
+}
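These four accessors are what make the opaque bark_context workable from C: instead of reading bctx->audio_arr or bctx->stats directly, which is no longer possible once the struct definition is private to bark.cpp, callers go through functions. A usage sketch, continuing from a successfully generated bctx as in the earlier example:

    float *audio = bark_get_audio_data(bctx);
    int n_samples = bark_get_audio_data_size(bctx);
    if (audio && n_samples > 0) {
        /* e.g. hand the n_samples float samples to a WAV writer */
    }

    const bark_statistics *stats = bark_get_statistics(bctx);
    if (stats) {
        printf("total eval time: %.2f ms\n", stats->t_eval_us / 1000.0);
    }
    bark_reset_statistics(bctx); /* zero the counters before the next run */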