Skip to content

Commit

Permalink
perf: preallocate tensor in semantic text generation to reduce allocations
Browse files Browse the repository at this point in the history
  • Loading branch information
no2chem committed Jun 21, 2023
1 parent 6cd7f0c commit 4f36747
Showing 1 changed file with 7 additions and 7 deletions.
14 changes: 7 additions & 7 deletions bark/generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -445,16 +445,16 @@ def generate_text_semantic(
with _inference_mode():
x = x.to(device)
n_tot_steps = 768
# preallocate tensor
x_initial = x.shape[1]
x = torch.hstack([x , torch.empty([1, n_tot_steps], dtype=torch.int32, device=device)])
# custom tqdm updates since we don't know when eos will occur
pbar = tqdm.tqdm(disable=silent, total=n_tot_steps)
pbar_state = 0
tot_generated_duration_s = 0
kv_cache = None
for n in range(n_tot_steps):
if use_kv_caching and kv_cache is not None:
x_input = x[:, [-1]]
else:
x_input = x
x_input = x[:, [x_initial + n - 1]] if use_kv_caching and kv_cache is not None else x[:,:x_initial + n]
logits, kv_cache = model(
x_input, merge_context=True, use_cache=use_kv_caching, past_kv=kv_cache
)
Expand Down Expand Up @@ -485,18 +485,18 @@ def generate_text_semantic(
item_next == SEMANTIC_VOCAB_SIZE
or (min_eos_p is not None and probs[-1] >= min_eos_p)
):
n -= 1 # backtrack 1
# eos found, so break
pbar.update(n - pbar_state)
break
x = torch.cat((x, item_next[None]), dim=1)
x[0][x_initial + n] = item_next
tot_generated_duration_s += 1 / SEMANTIC_RATE_HZ
if max_gen_duration_s is not None and tot_generated_duration_s > max_gen_duration_s:
pbar.update(n - pbar_state)
break
if n == n_tot_steps - 1:
pbar.update(n - pbar_state)
break
del logits, relevant_logits, probs, item_next

if n > pbar_state:
if n > pbar.total:
Expand All @@ -506,7 +506,7 @@ def generate_text_semantic(
pbar.total = n
pbar.refresh()
pbar.close()
out = x.detach().cpu().numpy().squeeze()[256 + 256 + 1 :]
out = x.detach().cpu().numpy().squeeze()[x_initial : x_initial + n + 1]
if OFFLOAD_CPU:
model.to("cpu")
assert all(0 <= out) and all(out < SEMANTIC_VOCAB_SIZE)
Expand Down

0 comments on commit 4f36747

Please sign in to comment.