Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for XTC and DRY samplers #1843

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 16 additions & 0 deletions llama_cpp/_internals.py
Original file line number Diff line number Diff line change
Expand Up @@ -806,6 +806,22 @@ def add_mirostat_v2(self, seed: int, tau: float, eta: float):
sampler = llama_cpp.llama_sampler_init_mirostat_v2(seed, tau, eta)
self._add_sampler(sampler)

def add_xtc(self, probability: float, threshold: float, min_keep: int, seed: int):
sampler = llama_cpp.llama_sampler_init_xtc(probability, threshold, min_keep, seed)
self._add_sampler(sampler)

def add_dry(self, model: LlamaModel, ctx: LlamaContext, multiplier: float, base: float,
allowed_length: int, penalty_last_n: int, seq_breakers: list[str] = []):

# Convert Python strings to bytes
seq_breakers_bytes = [s.encode('utf-8') for s in seq_breakers]
# Create array of char*
arr = (ctypes.c_char_p * len(seq_breakers_bytes))(*seq_breakers_bytes)
sampler = llama_cpp.llama_sampler_init_dry(model.vocab, ctx.n_ctx(), multiplier, base,
allowed_length, penalty_last_n,
arr, len(seq_breakers))
self._add_sampler(sampler)

def add_grammar(self, model: LlamaModel, grammar: LlamaGrammar):
sampler = llama_cpp.llama_sampler_init_grammar(
model.vocab, grammar._grammar.encode("utf-8"), grammar._root.encode("utf-8")
Expand Down
100 changes: 100 additions & 0 deletions llama_cpp/llama.py
Original file line number Diff line number Diff line change
Expand Up @@ -680,6 +680,13 @@ def _init_sampler(
mirostat_mode: int = 0,
mirostat_eta: float = 0.1,
mirostat_tau: float = 5.0,
xtc_probability: float = 0.0,
xtc_threshold: float = 0.1,
dry_multiplier: float = 0.0,
dry_allowed_length: int = 2,
dry_base: float = 1.75,
dry_range: int = 0,
dry_seq_breakers: list[str] = [],
penalize_nl: bool = True,
logits_processor: Optional[LogitsProcessorList] = None,
grammar: Optional[LlamaGrammar] = None,
Expand Down Expand Up @@ -747,11 +754,13 @@ def apply_func(token_data_array: llama_cpp.llama_token_data_array_p):
else:
n_probs = 0
min_keep = max(1, n_probs)
sampler.add_dry(self._model, self._ctx, dry_multiplier, dry_base, dry_allowed_length, dry_range, dry_seq_breakers)
sampler.add_top_k(top_k)
sampler.add_typical(typical_p, min_keep)
sampler.add_top_p(top_p, min_keep)
sampler.add_min_p(min_p, min_keep)
sampler.add_temp(temp)
sampler.add_xtc(xtc_probability, xtc_threshold, min_keep, self._seed)
sampler.add_dist(self._seed)
return sampler

Expand All @@ -769,6 +778,13 @@ def sample(
mirostat_mode: int = 0,
mirostat_eta: float = 0.1,
mirostat_tau: float = 5.0,
xtc_probability: float = 0.0,
xtc_threshold: float = 0.1,
dry_multiplier: float = 0.0,
dry_allowed_length: int = 2,
dry_base: float = 1.75,
dry_range: int = 0,
dry_seq_breakers: list[str] = [],
penalize_nl: bool = True,
logits_processor: Optional[LogitsProcessorList] = None,
grammar: Optional[LlamaGrammar] = None,
Expand Down Expand Up @@ -804,6 +820,13 @@ def sample(
mirostat_mode=mirostat_mode,
mirostat_tau=mirostat_tau,
mirostat_eta=mirostat_eta,
xtc_probability=xtc_probability,
xtc_threshold=xtc_threshold,
dry_multiplier=dry_multiplier,
dry_allowed_length=dry_allowed_length,
dry_base=dry_base,
dry_range=dry_range,
dry_seq_breakers=dry_seq_breakers,
penalize_nl=penalize_nl,
logits_processor=logits_processor,
grammar=grammar,
Expand Down Expand Up @@ -833,6 +856,13 @@ def generate(
mirostat_mode: int = 0,
mirostat_tau: float = 5.0,
mirostat_eta: float = 0.1,
xtc_probability: float = 0.0,
xtc_threshold: float = 0.1,
dry_multiplier: float = 0.0,
dry_allowed_length: int = 2,
dry_base: float = 1.75,
dry_range: int = 0,
dry_seq_breakers: list[str] = [],
penalize_nl: bool = True,
logits_processor: Optional[LogitsProcessorList] = None,
stopping_criteria: Optional[StoppingCriteriaList] = None,
Expand Down Expand Up @@ -872,6 +902,13 @@ def generate(
mirostat_mode=mirostat_mode,
mirostat_tau=mirostat_tau,
mirostat_eta=mirostat_eta,
xtc_probability=xtc_probability,
xtc_threshold=xtc_threshold,
dry_multiplier=dry_multiplier,
dry_allowed_length=dry_allowed_length,
dry_base=dry_base,
dry_range=dry_range,
dry_seq_breakers=dry_seq_breakers,
penalize_nl=penalize_nl,
logits_processor=logits_processor,
grammar=grammar,
Expand Down Expand Up @@ -924,6 +961,13 @@ def generate(
mirostat_mode=mirostat_mode,
mirostat_tau=mirostat_tau,
mirostat_eta=mirostat_eta,
xtc_probability=xtc_probability,
xtc_threshold=xtc_threshold,
dry_multiplier=dry_multiplier,
dry_allowed_length=dry_allowed_length,
dry_base=dry_base,
dry_range=dry_range,
dry_seq_breakers=dry_seq_breakers,
logits_processor=logits_processor,
grammar=grammar,
penalize_nl=penalize_nl,
Expand Down Expand Up @@ -1140,6 +1184,13 @@ def _create_completion(
mirostat_mode: int = 0,
mirostat_tau: float = 5.0,
mirostat_eta: float = 0.1,
xtc_probability: float = 0.0,
xtc_threshold: float = 0.1,
dry_multiplier: float = 0.0,
dry_allowed_length: int = 2,
dry_base: float = 1.75,
dry_range: int = 0,
dry_seq_breakers: list[str] = [],
model: Optional[str] = None,
stopping_criteria: Optional[StoppingCriteriaList] = None,
logits_processor: Optional[LogitsProcessorList] = None,
Expand Down Expand Up @@ -1328,6 +1379,13 @@ def logit_bias_processor(
mirostat_mode=mirostat_mode,
mirostat_tau=mirostat_tau,
mirostat_eta=mirostat_eta,
xtc_probability=xtc_probability,
xtc_threshold=xtc_threshold,
dry_multiplier=dry_multiplier,
dry_allowed_length=dry_allowed_length,
dry_base=dry_base,
dry_range=dry_range,
dry_seq_breakers=dry_seq_breakers,
frequency_penalty=frequency_penalty,
presence_penalty=presence_penalty,
repeat_penalty=repeat_penalty,
Expand Down Expand Up @@ -1760,6 +1818,13 @@ def create_completion(
mirostat_mode: int = 0,
mirostat_tau: float = 5.0,
mirostat_eta: float = 0.1,
xtc_probability: float = 0.0,
xtc_threshold: float = 0.1,
dry_multiplier: float = 0.0,
dry_allowed_length: int = 2,
dry_base: float = 1.75,
dry_range: int = 0,
dry_seq_breakers: list[str] = [],
model: Optional[str] = None,
stopping_criteria: Optional[StoppingCriteriaList] = None,
logits_processor: Optional[LogitsProcessorList] = None,
Expand Down Expand Up @@ -1823,6 +1888,13 @@ def create_completion(
mirostat_mode=mirostat_mode,
mirostat_tau=mirostat_tau,
mirostat_eta=mirostat_eta,
xtc_probability=xtc_probability,
xtc_threshold=xtc_threshold,
dry_multiplier=dry_multiplier,
dry_allowed_length=dry_allowed_length,
dry_base=dry_base,
dry_range=dry_range,
dry_seq_breakers=dry_seq_breakers,
model=model,
stopping_criteria=stopping_criteria,
logits_processor=logits_processor,
Expand Down Expand Up @@ -1857,6 +1929,13 @@ def __call__(
mirostat_mode: int = 0,
mirostat_tau: float = 5.0,
mirostat_eta: float = 0.1,
xtc_probability: float = 0.0,
xtc_threshold: float = 0.1,
dry_multiplier: float = 0.0,
dry_allowed_length: int = 2,
dry_base: float = 1.75,
dry_range: int = 0,
dry_seq_breakers: list[str] = [],
model: Optional[str] = None,
stopping_criteria: Optional[StoppingCriteriaList] = None,
logits_processor: Optional[LogitsProcessorList] = None,
Expand Down Expand Up @@ -1920,6 +1999,13 @@ def __call__(
mirostat_mode=mirostat_mode,
mirostat_tau=mirostat_tau,
mirostat_eta=mirostat_eta,
xtc_probability=xtc_probability,
xtc_threshold=xtc_threshold,
dry_multiplier=dry_multiplier,
dry_allowed_length=dry_allowed_length,
dry_base=dry_base,
dry_range=dry_range,
dry_seq_breakers=dry_seq_breakers,
model=model,
stopping_criteria=stopping_criteria,
logits_processor=logits_processor,
Expand Down Expand Up @@ -1951,6 +2037,13 @@ def create_chat_completion(
mirostat_mode: int = 0,
mirostat_tau: float = 5.0,
mirostat_eta: float = 0.1,
xtc_probability: float = 0.0,
xtc_threshold: float = 0.1,
dry_multiplier: float = 0.0,
dry_allowed_length: int = 2,
dry_base: float = 1.75,
dry_range: int = 0,
dry_seq_breakers: list[str] = [],
model: Optional[str] = None,
logits_processor: Optional[LogitsProcessorList] = None,
grammar: Optional[LlamaGrammar] = None,
Expand Down Expand Up @@ -2024,6 +2117,13 @@ def create_chat_completion(
mirostat_mode=mirostat_mode,
mirostat_tau=mirostat_tau,
mirostat_eta=mirostat_eta,
xtc_probability=xtc_probability,
xtc_threshold=xtc_threshold,
dry_multiplier=dry_multiplier,
dry_allowed_length=dry_allowed_length,
dry_base=dry_base,
dry_range=dry_range,
dry_seq_breakers=dry_seq_breakers,
model=model,
logits_processor=logits_processor,
grammar=grammar,
Expand Down
35 changes: 35 additions & 0 deletions llama_cpp/llama_cpp.py
Original file line number Diff line number Diff line change
Expand Up @@ -3646,6 +3646,41 @@ def llama_sampler_init_xtc(
) -> llama_sampler_p:
...

# LLAMA_API struct llama_sampler * llama_sampler_init_dry(
# const struct llama_vocab * vocab,
# int32_t context_size,
# float dry_multiplier,
# float dry_base,
# int32_t dry_allowed_length,
# int32_t dry_penalty_last_n,
# const char ** seq_breakers,
# size_t num_breakers);
@ctypes_function(
"llama_sampler_init_dry",
[
llama_vocab_p_ctypes,
ctypes.c_int32,
ctypes.c_float,
ctypes.c_float,
ctypes.c_int32,
ctypes.c_int32,
ctypes.POINTER(ctypes.c_char_p),
ctypes.c_size_t
],
llama_sampler_p_ctypes,
)
def llama_sampler_init_dry(
vocab: llama_vocab_p,
context_size: int,
dry_multiplier: float,
dry_base: float,
dry_allowed_length: int,
dry_penalty_last_n: int,
seq_breakers: list[str],
num_breakers: int,
) -> llama_sampler_p:
...


# /// @details Top n sigma sampling as described in academic paper "Top-nσ: Not All Logits Are You Need" https://arxiv.org/pdf/2411.07641
# LLAMA_API struct llama_sampler * llama_sampler_init_top_n_sigma(float n);
Expand Down