Add repetition penalty to LLM inference - mlx-lm (ml-explore#399)
* feat: add repetition penalty

* fix: generate function argument fix

* typo fixes

* update repetition penalty

* update generate_step and generate

* resolve conflicts in generate

* merge latest pull origin master

* update generate

* update generate and generate_step

* update repetition list - rename variable

* refactor token count

* update generate step and generate

* move repetition_context in generate_step

* update generate step

* update generate_step
vishal-14069 authored Feb 17, 2024
1 parent 0ba4663 commit 21e19b5
Showing 1 changed file with 73 additions and 6 deletions.
79 changes: 73 additions & 6 deletions llms/mlx_lm/utils.py
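
The diff implements the repetition penalty from the CTRL paper (https://arxiv.org/abs/1909.05858): before each sampling step, the logit of every token that already appears in the recent context window is rescaled, so a penalty greater than 1.0 makes repeats less likely:

    logit[i] <- logit[i] / penalty   if logit[i] > 0
    logit[i] <- logit[i] * penalty   if logit[i] < 0      (for each token i in the repetition context)
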
@@ -5,7 +5,7 @@
 import logging
 import time
 from pathlib import Path
-from typing import Any, Callable, Dict, Generator, Tuple, Union
+from typing import Any, Callable, Dict, Generator, List, Optional, Tuple, Union

 import mlx.core as mx
 import mlx.nn as nn
@@ -80,10 +80,36 @@ def get_model_path(path_or_hf_repo: str) -> Path:
     return model_path


+def apply_repetition_penalty(logits: mx.array, generated_tokens: Any, penalty: float):
+    """
+    Apply repetition penalty to specific logits based on the given context.
+
+    Paper: https://arxiv.org/abs/1909.05858
+
+    Args:
+        logits (mx.array): The logits produced by the language model.
+        generated_tokens (any): A list of N previous tokens.
+        penalty (float): The repetition penalty factor to be applied.
+
+    Returns:
+        logits (mx.array): Logits with repetition penalty applied to generated tokens.
+    """
+    if len(generated_tokens) > 0:
+        indices = mx.array([token for token in generated_tokens])
+        selected_logits = logits[:, indices]
+        selected_logits = mx.where(
+            selected_logits < 0, selected_logits * penalty, selected_logits / penalty
+        )
+        logits[:, indices] = selected_logits
+    return logits
+
+
 def generate_step(
     prompt: mx.array,
     model: nn.Module,
-    temp: float = 0.0,
+    temp: 0.0,
+    repetition_penalty: Optional[float] = None,
+    repetition_context_size: Optional[int] = 20,
 ) -> Generator[Tuple[mx.array, mx.array], None, None]:
     """
     A generator producing text based on the given prompt from the model.
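
The rescaling above is easiest to see on a toy example. A minimal sketch, not part of the commit; the vocabulary size, logit values, and penalty of 1.3 are made up, and apply_repetition_penalty is assumed to be importable from this module:

import mlx.core as mx

# Toy vocabulary of 4 tokens; tokens 1 and 2 were already generated.
logits = mx.array([[2.0, 1.5, -0.5, 3.0]])
penalized = apply_repetition_penalty(logits, [1, 2], 1.3)
# Token 1: 1.5 is positive, so it is divided by 1.3    -> ~1.15
# Token 2: -0.5 is negative, so it is multiplied by 1.3 -> -0.65
# Tokens 0 and 3 are not in the context and keep their original logits.
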
@@ -92,6 +118,9 @@ def generate_step(
         prompt (mx.array): The input prompt.
         model (nn.Module): The model to use for generation.
         temp (float): The temperature for sampling, if 0 the argmax is used.
+        repetition_penalty (float, optional): The penalty factor for repeating tokens.
+        repetition_context_size (int, optional): The number of tokens to consider
+          for repetition penalty (default 20).
     Yields:
         Generator[Tuple[mx.array, mx.array]]: A generator producing
         one token and probability per call.
@@ -108,12 +137,37 @@ def sample(logits: mx.array) -> Tuple[mx.array, float]:
         prob = softmax_logits[0, token]
         return token, prob

+    if repetition_penalty and (
+        repetition_penalty < 0 or not isinstance(repetition_penalty, float)
+    ):
+        raise ValueError(
+            f"repetition_penalty must be a non-negative float, got {repetition_penalty}"
+        )
+
     y = prompt
     cache = None

+    repetition_context = prompt.tolist()
+
+    if repetition_context_size:
+        repetition_context = repetition_context[-repetition_context_size:]
+
     while True:
         logits, cache = model(y[None], cache=cache)
         logits = logits[:, -1, :]
-        y, prob = sample(logits)
+
+        if repetition_penalty:
+            logits = apply_repetition_penalty(
+                logits, repetition_context, repetition_penalty
+            )
+            y, prob = sample(logits)
+            repetition_context.append(y.item())
+        else:
+            y, prob = sample(logits)
+
+        if repetition_context_size:
+            if len(repetition_context) > repetition_context_size:
+                repetition_context = repetition_context[-repetition_context_size:]
+
         yield y, prob

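For callers that drive generation token by token, generate_step now takes the penalty and context size as its last two positional arguments. A hypothetical streaming loop, assuming model and tokenizer were loaded elsewhere (for example with load() from this module); the prompt text, temperature of 0.7, penalty of 1.1, and window of 20 tokens are only illustrative:

import mlx.core as mx

prompt = mx.array(tokenizer.encode("The quick brown fox"))

for (token, prob), _ in zip(
    generate_step(prompt, model, 0.7, 1.1, 20),  # temp, repetition_penalty, repetition_context_size
    range(50),  # cap at 50 new tokens
):
    if token == tokenizer.eos_token_id:
        break
    print(tokenizer.decode([token.item()]), end="", flush=True)
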

@@ -125,6 +179,8 @@ def generate(
     max_tokens: int = 100,
     verbose: bool = False,
     formatter: Callable = None,
+    repetition_penalty: Optional[float] = None,
+    repetition_context_size: Optional[int] = None,
 ) -> str:
     """
     Generate text from the model.
@@ -139,20 +195,31 @@
             (default ``False``).
         formatter (Optional[Callable]): A function which takes a token and a
             probability and displays it.
+        repetition_penalty (float, optional): The penalty factor for repeating tokens.
+        repetition_context_size (int, optional): The number of tokens to consider for repetition penalty.
     """

     if verbose:
         print("=" * 10)
         print("Prompt:", prompt)

-    prompt = mx.array(tokenizer.encode(prompt))
+    prompt_tokens = mx.array(tokenizer.encode(prompt))

     tic = time.perf_counter()
     tokens = []
     skip = 0
     REPLACEMENT_CHAR = "\ufffd"

-    for (token, prob), n in zip(generate_step(prompt, model, temp), range(max_tokens)):
+    for (token, prob), n in zip(
+        generate_step(
+            prompt_tokens,
+            model,
+            temp,
+            repetition_penalty,
+            repetition_context_size,
+        ),
+        range(max_tokens),
+    ):
         if token == tokenizer.eos_token_id:
             break
         if n == 0:
@@ -179,7 +246,7 @@
         if token_count == 0:
             print("No tokens generated for this prompt")
             return
-        prompt_tps = prompt.size / prompt_time
+        prompt_tps = prompt_tokens.size / prompt_time
         gen_tps = (token_count - 1) / gen_time
         print(f"Prompt: {prompt_tps:.3f} tokens-per-sec")
         print(f"Generation: {gen_tps:.3f} tokens-per-sec")
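
With these changes, generate() exposes the same knobs to end users. A minimal usage sketch, not part of the commit; the model repository name is only an example, and load() is assumed to be the existing loader in this module:

from mlx_lm.utils import generate, load

model, tokenizer = load("mlx-community/Mistral-7B-Instruct-v0.2-4bit")
text = generate(
    model,
    tokenizer,
    prompt="Write a short story about a lighthouse keeper.",
    max_tokens=100,
    repetition_penalty=1.1,       # values > 1.0 discourage repeated tokens; None disables the penalty
    repetition_context_size=20,   # only the most recent 20 tokens are penalized
    verbose=True,
)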
