diff --git a/CHANGELOG.md b/CHANGELOG.md index 8a194fece..0c7f3990c 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,6 +7,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## [Unreleased] +- feat: Add hot-swapping for LoRA adapters + ## [0.3.2] - feat: Update llama.cpp to ggerganov/llama.cpp@74d73dc85cc2057446bf63cc37ff649ae7cebd80 diff --git a/docs/api-reference.md b/docs/api-reference.md index ab51ef754..6aba6b863 100644 --- a/docs/api-reference.md +++ b/docs/api-reference.md @@ -22,6 +22,7 @@ High-level Python bindings for llama.cpp. - __call__ - create_chat_completion - create_chat_completion_openai_v1 + - set_lora_adapter_scale - set_cache - save_state - load_state diff --git a/examples/low_level_api/common.py b/examples/low_level_api/common.py index a0212ff0d..7e397969d 100644 --- a/examples/low_level_api/common.py +++ b/examples/low_level_api/common.py @@ -3,10 +3,11 @@ import re from dataclasses import dataclass, field -from typing import List - -# Based on https://github.com/ggerganov/llama.cpp/blob/master/examples/common.cpp +from typing import List, Sequence, Tuple +import typing +# Based on https://github.com/ggerganov/llama.cpp/blob/master/common/common.cpp +# and https://github.com/ggerganov/llama.cpp/blob/master/common/arg.cpp @dataclass class GptParams: @@ -40,8 +41,8 @@ class GptParams: input_suffix: str = "" antiprompt: List[str] = field(default_factory=list) - lora_adapter: str = "" - lora_base: str = "" + lora: List[str] = None + lora_scaled: List[Tuple[str, float]] = None memory_f16: bool = True random_prompt: bool = False @@ -257,16 +258,56 @@ def gpt_params_parse(argv=None): parser.add_argument( "--lora", type=str, - default="", - help="apply LoRA adapter (implies --no-mmap)", - dest="lora_adapter", - ) - parser.add_argument( - "--lora-base", - type=str, - default="", - help="optional model to use as a base for the layers modified by the LoRA adapter", - dest="lora_base", + action="append", + default=[], + help="path to LoRA adapter (can be repeated to use multiple adapters)", + metavar="FNAME", + dest="lora", + ) + + class MultiTupleAction(argparse.Action): + """Action for handling multiple arguments as tuples with type conversion""" + def __init__(self, + option_strings: Sequence[str], + dest: str, + nargs: int = None, + type: Tuple = None, + metavar: Tuple = None, + **kwargs): + self.tuple_type = type + super().__init__( + option_strings=option_strings, + dest=dest, + type=str, # We will fix + nargs=nargs, + metavar=metavar, + **kwargs + ) + + def __call__(self, parser, namespace, values, option_string=None): + if len(values) != self.nargs: + parser.error( + f'{option_string} requires {len(self.metavar)} arguments: ' + f'{" ".join(self.metavar)}' + ) + + converted_values = tuple(value_type(value) for value_type, value in zip(typing.get_args(self.tuple_type), values)) + # Initialize list if needed + if not hasattr(namespace, self.dest): + setattr(namespace, self.dest, []) + + # Add the converted tuple to the list + getattr(namespace, self.dest).append(converted_values) + + parser.add_argument( + "--lora-scaled", + action=MultiTupleAction, + nargs=2, + type=Tuple[str, float], + help="path to LoRA adapter with user defined scaling (can be repeated to use multiple adapters)", + metavar=('FNAME', 'SCALE'), + dest='lora_scaled', + default=[], ) parser.add_argument( @@ -375,9 +416,6 @@ def gpt_params_parse(argv=None): delattr(args, "logit_bias_str") params = GptParams(**vars(args)) - if params.lora_adapter: - params.use_mmap = False 
- if logit_bias_str != None: for i in logit_bias_str: if m := re.match(r"(\d+)([-+]\d+)", i): params.logit_bias[int(m.group(1))] = float(m.group(2)) return params diff --git a/examples/low_level_api/low_level_api_chat_cpp.py b/examples/low_level_api/low_level_api_chat_cpp.py index 39081be17..52264d76d 100644 --- a/examples/low_level_api/low_level_api_chat_cpp.py +++ b/examples/low_level_api/low_level_api_chat_cpp.py @@ -93,22 +93,14 @@ def __init__(self, params: GptParams) -> None: if self.params.ignore_eos: self.params.logit_bias[llama_cpp.llama_token_eos()] = -float("inf") - if len(self.params.lora_adapter) > 0: - if ( - llama_cpp.llama_apply_lora_from_file( - self.ctx, - self.params.lora_adapter.encode("utf8"), - ( - self.params.lora_base.encode("utf8") - if len(self.params.lora_base) > 0 - else None - ), - self.params.n_threads, - ) - != 0 - ): - print("error: failed to apply lora adapter") - return + for lora_path, scale in [(pth, 1.0) for pth in self.params.lora] + self.params.lora_scaled: + lora_adapter = llama_cpp.llama_lora_adapter_init( + self.model, + lora_path.encode("utf8")) + if lora_adapter is None: + raise RuntimeError(f"error: failed to load lora adapter '{lora_path}'") + if scale != 0.0: + llama_cpp.llama_lora_adapter_set(self.ctx, lora_adapter, scale) print(file=sys.stderr) print( diff --git a/llama_cpp/_internals.py b/llama_cpp/_internals.py index 994d5f149..e43dcfe49 100644 --- a/llama_cpp/_internals.py +++ b/llama_cpp/_internals.py @@ -285,6 +285,18 @@ def kv_cache_seq_keep(self, seq_id: int): def kv_cache_seq_shift(self, seq_id: int, p0: int, p1: int, shift: int): llama_cpp.llama_kv_cache_seq_add(self.ctx, seq_id, p0, p1, shift) + def lora_adapter_set(self, adapter: LlamaLoraAdapter, scale: float): + return_code = llama_cpp.llama_lora_adapter_set(self.ctx, adapter.lora_adapter, scale) + if return_code != 0: + raise RuntimeError(f"lora_adapter_set returned {return_code}") + + def lora_adapter_remove(self, adapter: LlamaLoraAdapter) -> bool: + return_code = llama_cpp.llama_lora_adapter_remove(self.ctx, adapter.lora_adapter) + return return_code != 0 + + def lora_adapter_clear(self): + llama_cpp.llama_lora_adapter_clear(self.ctx) + def get_state_size(self) -> int: return llama_cpp.llama_get_state_size(self.ctx) @@ -861,3 +873,45 @@ def close(self): def __del__(self): self.close() + +class LlamaLoraAdapter: + """Intermediate Python wrapper for a llama.cpp llama_lora_adapter. + NOTE: For stability it's recommended you use the Llama class instead.""" + + def __init__( + self, + model: LlamaModel, + lora_path: str, + *, + verbose: bool = True, + ): + self.model = model + self.lora_path = lora_path + + lora_adapter = None + + if not os.path.exists(lora_path): + raise ValueError(f"LoRA adapter path does not exist: {lora_path}") + + with suppress_stdout_stderr(disable=verbose): + lora_adapter = llama_cpp.llama_lora_adapter_init( + self.model.model, + self.lora_path.encode("utf-8"), + ) + + if lora_adapter is None: + raise RuntimeError( + f"Failed to initialize LoRA adapter from lora path: {self.lora_path}" + ) + + # The llama_lora_adapter will be freed by the llama_model as part of its + # lifecycle. The llama_model destructor destroys each llama_lora_adapter, + # and the destructor for llama_lora_adapter calls llama_lora_adapter_free. + # All we do here is clear the wrapped reference when the LlamaModel wrapper + # is closed, so that the LlamaLoraAdapter wrapper reference is cleared + # when the llama_lora_adapters are freed.
+ def clear_lora_adapter(): + self.lora_adapter = None + self.model._exit_stack.callback(clear_lora_adapter) + + self.lora_adapter = lora_adapter \ No newline at end of file diff --git a/llama_cpp/llama.py b/llama_cpp/llama.py index d15a88b00..fe70c806f 100644 --- a/llama_cpp/llama.py +++ b/llama_cpp/llama.py @@ -18,6 +18,7 @@ List, Literal, Optional, + Tuple, Union, Generator, Sequence, @@ -33,6 +34,7 @@ from .llama_types import * from .llama_grammar import LlamaGrammar from .llama_cache import ( + LlamaCacheKey, BaseLlamaCache, LlamaCache, # type: ignore LlamaDiskCache, # type: ignore @@ -96,9 +98,7 @@ def __init__( # Sampling Params last_n_tokens_size: int = 64, # LoRA Params - lora_base: Optional[str] = None, - lora_scale: float = 1.0, - lora_path: Optional[str] = None, + lora_adapters: Optional[Dict[str, float]] = None, # Backend Params numa: Union[bool, int] = False, # Chat Format Params @@ -174,8 +174,7 @@ def __init__( offload_kqv: Offload K, Q, V to GPU. flash_attn: Use flash attention. last_n_tokens_size: Maximum number of tokens to keep in the last_n_tokens deque. - lora_base: Optional path to base model, useful if using a quantized base model and you want to apply LoRA to an f16 model. - lora_path: Path to a LoRA file to apply to the model. + lora_adapters: Paths to LoRA adapter files and the scale to apply them at (an adapter with a scale of 0.0 is not used during inference). numa: numa policy chat_format: String specifying the chat format to use when calling create_chat_completion. chat_handler: Optional chat handler to use when calling create_chat_completion. @@ -243,7 +242,7 @@ def __init__( ) # keep a reference to the array so it is not gc'd self.model_params.tensor_split = self._c_tensor_split self.model_params.vocab_only = vocab_only - self.model_params.use_mmap = use_mmap if lora_path is None else False + self.model_params.use_mmap = use_mmap self.model_params.use_mlock = use_mlock # kv_overrides is the original python dict @@ -355,9 +354,9 @@ def __init__( self.cache: Optional[BaseLlamaCache] = None - self.lora_base = lora_base - self.lora_scale = lora_scale - self.lora_path = lora_path + self.lora_adapters = ( + lora_adapters if lora_adapters is not None else {} + ) self.spm_infill = spm_infill @@ -406,32 +405,14 @@ def __init__( ) ) - self._lora_adapter: Optional[llama_cpp.llama_lora_adapter_p] = None - - if self.lora_path: - self._lora_adapter = llama_cpp.llama_lora_adapter_init( - self._model.model, - self.lora_path.encode("utf-8"), - ) - if self._lora_adapter is None: - raise RuntimeError( - f"Failed to initialize LoRA adapter from lora path: {self.lora_path}" - ) - - def free_lora_adapter(): - if self._lora_adapter is None: - return - llama_cpp.llama_lora_adapter_free(self._lora_adapter) - self._lora_adapter = None - - self._stack.callback(free_lora_adapter) + # Dict from LoRA path to wrapper + self._lora_adapters_paths: Dict[str, internals.LlamaLoraAdapter] = {} + # Immutable value representing active adapters for use as a key + self._lora_adapters_active: Tuple[Tuple[str, float], ...]
= () - if llama_cpp.llama_lora_adapter_set( - self._ctx.ctx, self._lora_adapter, self.lora_scale - ): - raise RuntimeError( - f"Failed to set LoRA adapter from lora path: {self.lora_path}" - ) + if self.lora_adapters: + for lora_path, scale in self.lora_adapters.copy().items(): + self.set_lora_adapter_scale(lora_path, scale, load_if_needed=True) if self.verbose: print(llama_cpp.llama_print_system_info().decode("utf-8"), file=sys.stderr) @@ -453,6 +434,7 @@ def free_lora_adapter(): self._candidates = internals.LlamaTokenDataArray(n_vocab=self._n_vocab) self.n_tokens = 0 + self.tokens_lora_adapters: Tuple[Tuple[str, float], ...] = () # Adapters that processed tokens self.input_ids: npt.NDArray[np.intc] = np.ndarray((n_ctx,), dtype=np.intc) self.scores: npt.NDArray[np.single] = np.ndarray( (n_ctx if logits_all == True else n_batch, self._n_vocab), dtype=np.single @@ -621,10 +603,52 @@ def set_seed(self, seed: int): seed: The random seed. """ self._seed = seed + + def set_lora_adapter_scale(self, lora_path: str, scale: float, *, load_if_needed=False): + """ + Set the scale for a LoRA adapter, or pass 0.0 to disable it for inference. If the LoRA adapter file + has previously been loaded then this method will set its scale. If the LoRA adapter file has + not been previously loaded, this method will raise an exception, unless load_if_needed is set. + + Args: + lora_path: The path to the LoRA adapter. Unless load_if_needed is set, this path must have been loaded when the `Llama` object was created. + scale: The scaling factor to apply to the LoRA adapter. If 0.0, the LoRA adapter will be disabled so it won't be used during inference. + load_if_needed: Whether or not to load the adapter if it has not previously been loaded. If True, this + method will attempt to load the adapter from the lora_path if needed. If False, loading an adapter that + hasn't already been loaded will raise an exception. + """ + # Load adapter if needed (even if scale 0.0) + lora_adapter = self._lora_adapters_paths.get(lora_path) + if lora_adapter is None: + lora_adapter = internals.LlamaLoraAdapter( + self._model, + lora_path, + verbose=self.verbose, + ) + if lora_adapter is None: + raise RuntimeError( + f"Failed to initialize LoRA adapter from lora path: {lora_path}" + ) + self._lora_adapters_paths[lora_path] = lora_adapter + + if scale == 0.0: + # Remove from context; safe to call even if not in context + self._ctx.lora_adapter_remove(lora_adapter) + else: + # Set scale in context + self._ctx.lora_adapter_set(lora_adapter, scale) + + if self.lora_adapters is None: + self.lora_adapters = {} + self.lora_adapters[lora_path] = scale + self._lora_adapters_active = tuple(sorted( + filter(lambda path_scale: path_scale[1] != 0.0, self.lora_adapters.items()) + )) def reset(self): """Reset the model state.""" self.n_tokens = 0 + self.tokens_lora_adapters = self._lora_adapters_active def eval(self, tokens: Sequence[int]): """Evaluate a list of tokens.
@@ -875,7 +899,7 @@ def generate( ) # Check for kv cache prefix match - if reset and self.n_tokens > 0: + if reset and self.n_tokens > 0 and self.tokens_lora_adapters == self._lora_adapters_active: longest_prefix = 0 for a, b in zip(self._input_ids, tokens[:-1]): if a == b: @@ -1292,7 +1316,8 @@ def logit_bias_processor( if self.cache: try: - cache_item = self.cache[prompt_tokens] + cache_key = LlamaCacheKey(active_lora_adapters=self._lora_adapters_active, tokens=tuple(prompt_tokens)) + cache_item = self.cache[cache_key] cache_prefix_len = Llama.longest_token_prefix( cache_item.input_ids.tolist(), prompt_tokens ) @@ -1630,7 +1655,8 @@ def logit_bias_processor( if self.cache: if self.verbose: print("Llama._create_completion: cache save", file=sys.stderr) - self.cache[prompt_tokens + completion_tokens] = self.save_state() + cache_key = LlamaCacheKey(active_lora_adapters=self._lora_adapters_active, tokens=tuple(prompt_tokens + completion_tokens)) + self.cache[cache_key] = self.save_state() if self.verbose: print("Llama._create_completion: cache saved", file=sys.stderr) return @@ -1638,7 +1664,8 @@ def logit_bias_processor( if self.cache: if self.verbose: print("Llama._create_completion: cache save", file=sys.stderr) - self.cache[prompt_tokens + completion_tokens] = self.save_state() + cache_key = LlamaCacheKey(active_lora_adapters=self._lora_adapters_active, tokens=tuple(prompt_tokens + completion_tokens)) + self.cache[cache_key] = self.save_state() text_str = text.decode("utf-8", errors="ignore") @@ -2095,9 +2122,7 @@ def __getstate__(self): # Sampling Params last_n_tokens_size=self.last_n_tokens_size, # LoRA Params - lora_base=self.lora_base, - lora_scale=self.lora_scale, - lora_path=self.lora_path, + lora_adapters=self.lora_adapters, # Backend Params numa=self.numa, # Chat Format Params diff --git a/llama_cpp/llama_cache.py b/llama_cpp/llama_cache.py index e059e98e1..ce7d2ac7d 100644 --- a/llama_cpp/llama_cache.py +++ b/llama_cpp/llama_cache.py @@ -1,5 +1,6 @@ import sys from abc import ABC, abstractmethod +from dataclasses import dataclass from typing import ( Optional, Sequence, @@ -13,6 +14,17 @@ from .llama_types import * +@dataclass(eq=True, frozen=True) +class LlamaCacheKey: + """A key in a LlamaCache. Stores tokens to key by. Also stores + information about active LoRA adapters, because we need different + cached values for different active adapters, even for the same tokens.""" + active_lora_adapters: Tuple[Tuple[str, float], ...] + tokens: Tuple[int, ...] + + def __post_init__(self): + if not isinstance(self.tokens, tuple): + raise ValueError("tokens must be a tuple") class BaseLlamaCache(ABC): """Base cache class for a llama.cpp model.""" @@ -20,6 +32,13 @@ class BaseLlamaCache(ABC): def __init__(self, capacity_bytes: int = (2 << 30)): self.capacity_bytes = capacity_bytes + def _convert_to_cache_key(self, key: Union[Sequence[int], LlamaCacheKey]) -> LlamaCacheKey: + """Convert raw tokens to a key if needed""" + if type(key) == LlamaCacheKey: + return key + else: + return LlamaCacheKey(active_lora_adapters=(), tokens=tuple(key)) + @property @abstractmethod def cache_size(self) -> int: @@ -27,24 +46,61 @@ def cache_size(self) -> int: def _find_longest_prefix_key( self, - key: Tuple[int, ...], - ) -> Optional[Tuple[int, ...]]: + key: LlamaCacheKey, + ) -> Optional[LlamaCacheKey]: + """Find the cached key with the longest matching token prefix. A match also requires that the active + LoRA adapters match exactly. + + Args: + key (LlamaCacheKey): The key to find a prefix match for. 
+ + Returns: + Optional[LlamaCacheKey]: The key with the longest matching prefix, or None if no match found. + """ pass @abstractmethod - def __getitem__(self, key: Sequence[int]) -> "llama_cpp.llama.LlamaState": + def __getitem__(self, key: Union[Sequence[int], LlamaCacheKey]) -> "llama_cpp.llama.LlamaState": + """Retrieve a cached state by key, matching on the longest common token prefix. A match also requires + that the active LoRA adapters match exactly. + + Args: + key: Key to look up. Raw token sequences are supported for backwards compatibility + and assume no active LoRA adapters. + + Returns: + llama_cpp.llama.LlamaState: The cached state for the entry sharing the longest token prefix. + + Raises: + KeyError: If no prefix match is found. + """ raise NotImplementedError @abstractmethod - def __contains__(self, key: Sequence[int]) -> bool: + def __contains__(self, key: Union[Sequence[int], LlamaCacheKey]) -> bool: + """Check if any cached key shares a token prefix with the given key. + + Args: + key: Key to look up. Raw token sequences are supported for backwards compatibility + and assume no active LoRA adapters. + + Returns: + bool: True if any cached key shares a token prefix with this key. + """ raise NotImplementedError @abstractmethod def __setitem__( - self, key: Sequence[int], value: "llama_cpp.llama.LlamaState" + self, key: Union[Sequence[int], LlamaCacheKey], value: "llama_cpp.llama.LlamaState" ) -> None: - raise NotImplementedError + """Store a state keyed on its tokens and information about active LoRA adapters. + Args: + key: Key to store. Raw token sequences are supported for backwards compatibility + and assume no active LoRA adapters + value: The state to cache + """ + raise NotImplementedError class LlamaRAMCache(BaseLlamaCache): """Cache for a llama.cpp model using RAM.""" @@ -53,7 +109,7 @@ def __init__(self, capacity_bytes: int = (2 << 30)): super().__init__(capacity_bytes) self.capacity_bytes = capacity_bytes self.cache_state: OrderedDict[ - Tuple[int, ...], "llama_cpp.llama.LlamaState" + LlamaCacheKey, "llama_cpp.llama.LlamaState" ] = OrderedDict() @property @@ -62,22 +118,21 @@ def cache_size(self): def _find_longest_prefix_key( self, - key: Tuple[int, ...], - ) -> Optional[Tuple[int, ...]]: + key: LlamaCacheKey, + ) -> Optional[LlamaCacheKey]: min_len = 0 - min_key = None - keys = ( - (k, llama_cpp.llama.Llama.longest_token_prefix(k, key)) - for k in self.cache_state.keys() - ) - for k, prefix_len in keys: + min_key: Optional[LlamaCacheKey] = None + for k in self.cache_state.keys(): + if k.active_lora_adapters != key.active_lora_adapters: continue + if len(k.tokens) < min_len: continue # Optimization + prefix_len = llama_cpp.llama.Llama.longest_token_prefix(k.tokens, key.tokens) if prefix_len > min_len: min_len = prefix_len min_key = k return min_key - def __getitem__(self, key: Sequence[int]) -> "llama_cpp.llama.LlamaState": - key = tuple(key) + def __getitem__(self, key: Union[Sequence[int], LlamaCacheKey]) -> "llama_cpp.llama.LlamaState": + key = self._convert_to_cache_key(key) _key = self._find_longest_prefix_key(key) if _key is None: raise KeyError("Key not found") @@ -85,11 +140,11 @@ def __getitem__(self, key: Sequence[int]) -> "llama_cpp.llama.LlamaState": self.cache_state.move_to_end(_key) return value - def __contains__(self, key: Sequence[int]) -> bool: + def __contains__(self, key: Union[Sequence[int], LlamaCacheKey]) -> bool: return self._find_longest_prefix_key(tuple(key)) is not None - def __setitem__(self, key: Sequence[int], value: 
"llama_cpp.llama.LlamaState"): - key = tuple(key) + def __setitem__(self, key: Union[Sequence[int], LlamaCacheKey], value: "llama_cpp.llama.LlamaState"): + key = self._convert_to_cache_key(key) if key in self.cache_state: del self.cache_state[key] self.cache_state[key] = value @@ -116,19 +171,24 @@ def cache_size(self): def _find_longest_prefix_key( self, - key: Tuple[int, ...], - ) -> Optional[Tuple[int, ...]]: + key: LlamaCacheKey, + ) -> Optional[LlamaCacheKey]: min_len = 0 min_key: Optional[Tuple[int, ...]] = None for k in self.cache.iterkeys(): # type: ignore - prefix_len = llama_cpp.llama.Llama.longest_token_prefix(k, key) + if not isinstance(k, LlamaCacheKey): + print("LlamaDiskCache: Disk cache keys must be LlamaCacheKey objects: skipping") + continue + if k.active_lora_adapters != key.active_lora_adapters: continue + if len(k.tokens) < min_len: continue # Optimization + prefix_len = llama_cpp.llama.Llama.longest_token_prefix(k.tokens, key.tokens) if prefix_len > min_len: min_len = prefix_len - min_key = k # type: ignore + min_key = k return min_key - def __getitem__(self, key: Sequence[int]) -> "llama_cpp.llama.LlamaState": - key = tuple(key) + def __getitem__(self, key: Union[Sequence[int], LlamaCacheKey]) -> "llama_cpp.llama.LlamaState": + key = self._convert_to_cache_key(key) _key = self._find_longest_prefix_key(key) if _key is None: raise KeyError("Key not found") @@ -138,12 +198,12 @@ def __getitem__(self, key: Sequence[int]) -> "llama_cpp.llama.LlamaState": # self.cache.push(_key, side="front") # type: ignore return value - def __contains__(self, key: Sequence[int]) -> bool: - return self._find_longest_prefix_key(tuple(key)) is not None + def __contains__(self, key: Union[Sequence[int], LlamaCacheKey]) -> bool: + return self._find_longest_prefix_key(self._convert_to_cache_key(key)) is not None - def __setitem__(self, key: Sequence[int], value: "llama_cpp.llama.LlamaState"): + def __setitem__(self, key: Union[Sequence[int], LlamaCacheKey], value: "llama_cpp.llama.LlamaState"): print("LlamaDiskCache.__setitem__: called", file=sys.stderr) - key = tuple(key) + key = self._convert_to_cache_key(key) if key in self.cache: print("LlamaDiskCache.__setitem__: delete", file=sys.stderr) del self.cache[key] diff --git a/llama_cpp/server/model.py b/llama_cpp/server/model.py index c6716f919..cb2765c5a 100644 --- a/llama_cpp/server/model.py +++ b/llama_cpp/server/model.py @@ -2,7 +2,7 @@ import json -from typing import Dict, Optional, Union, List +from typing import Any, Dict, Optional, Union, List import llama_cpp import llama_cpp.llama_speculative as llama_speculative @@ -19,6 +19,12 @@ def __init__(self, models: List[ModelSettings]) -> None: for model in models: if not model.model_alias: model.model_alias = model.model + if model.model_alias in self._model_settings_dict: + raise ValueError( + f"Please specify a unique model alias for each model: {model.model_alias}" + ) + if model.verbose: + print(f"Registering model: {model.model_alias}") self._model_settings_dict[model.model_alias] = model self._current_model: Optional[llama_cpp.Llama] = None @@ -28,12 +34,19 @@ def __init__(self, models: List[ModelSettings]) -> None: self._default_model_alias: str = self._default_model_settings.model_alias # type: ignore # Load default model + + if self._default_model_settings.verbose: + print(f"Loading default model {self._default_model_alias}") self._current_model = self.load_llama_from_model_settings( self._default_model_settings ) self._current_model_alias = self._default_model_alias def 
__call__(self, model: Optional[str] = None) -> llama_cpp.Llama: + """Get the Llama model for the given alias, or the default model otherwise. + This may result in model loading, or in hot-swapping if a compatible model + is already loaded and only LoRA adapters need to be changed. + """ if model is None: model = self._default_model_alias @@ -44,12 +57,49 @@ def __call__(self, model: Optional[str] = None) -> llama_cpp.Llama: if self._current_model is not None: return self._current_model + new_settings = self._model_settings_dict[model] + + if self._current_model is not None and self._current_model_alias is not None: + current_settings = self._model_settings_dict[self._current_model_alias] + + def hot_swappable_settings(settings: ModelSettings) -> Dict[str, Any]: + """Subset of settings used to check if models can be hot-swapped""" + values = settings.model_dump() + values.pop('model_alias', None) # The model alias doesn't matter + values.pop('lora_adapters', None) # Different LoRA adapters can be hot-swapped + return values + + if hot_swappable_settings(new_settings) == hot_swappable_settings(current_settings): + # We can hot-swap! First, zero out existing LoRAs + if current_settings.verbose: + print(f"Hot-swapping model, setting existing LoRA adapter scales to 0.0.") + if self._current_model.lora_adapters is not None: + for lora_path in self._current_model.lora_adapters: + self._current_model.set_lora_adapter_scale(lora_path, 0.0) + + # Now enable new LoRAs + if new_settings.lora_adapters is not None: + if new_settings.verbose: + print(f"Hot-swapping model, setting LoRA adapter scales for {model}.") + for lora_path, scale in new_settings.lora_adapters.items(): + self._current_model.set_lora_adapter_scale( + lora_path, + scale, + load_if_needed=True + ) + + self._current_model_alias = model + return self._current_model + if self._current_model: + if current_settings.verbose: + print(f"Switching model, unloading current model {self._current_model}") self._current_model.close() self._current_model = None - settings = self._model_settings_dict[model] - self._current_model = self.load_llama_from_model_settings(settings) + if new_settings.verbose: + print(f"Switching model, loading new model {model}") + self._current_model = self.load_llama_from_model_settings(new_settings) self._current_model_alias = model return self._current_model @@ -268,8 +318,7 @@ def load_llama_from_model_settings(settings: ModelSettings) -> llama_cpp.Llama: # Sampling Params last_n_tokens_size=settings.last_n_tokens_size, # LoRA Params - lora_base=settings.lora_base, - lora_path=settings.lora_path, + lora_adapters=settings.lora_adapters, # Backend Params numa=settings.numa, # Chat Format Params diff --git a/llama_cpp/server/settings.py b/llama_cpp/server/settings.py index 13c951241..e0868f266 100644 --- a/llama_cpp/server/settings.py +++ b/llama_cpp/server/settings.py @@ -113,13 +113,9 @@ class ModelSettings(BaseSettings): description="Last n tokens to keep for repeat penalty calculation.", ) # LoRA Params - lora_base: Optional[str] = Field( + lora_adapters: Optional[Dict[str, float]] = Field( default=None, - description="Optional path to base model, useful if using a quantized base model and you want to apply LoRA to an f16 model.", - ) - lora_path: Optional[str] = Field( - default=None, - description="Path to a LoRA file to apply to the model.", + description="Paths to LoRA adapter files and the scale to apply them at (an adapter with a scale of 0.0 is not used during inference).", ) # Backend Params numa: Union[bool, int] =
Field(
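
Reviewer note: a minimal usage sketch of the hot-swapping API introduced above. The model and adapter paths are placeholders, not files from this repository.

```python
# Sketch only: ./base-model.gguf and ./adapter-*.gguf are stand-in paths for a
# GGUF base model and LoRA adapters.
from llama_cpp import Llama

llm = Llama(
    model_path="./base-model.gguf",
    lora_adapters={
        "./adapter-a.gguf": 1.0,  # loaded and active at scale 1.0
        "./adapter-b.gguf": 0.0,  # loaded up front but disabled (scale 0.0)
    },
)

# Hot-swap adapters between requests without reloading the base model.
llm.set_lora_adapter_scale("./adapter-a.gguf", 0.0)   # disable adapter A
llm.set_lora_adapter_scale("./adapter-b.gguf", 0.75)  # enable adapter B at scale 0.75

# An adapter that was not passed to the constructor can be loaded lazily.
llm.set_lora_adapter_scale("./adapter-c.gguf", 1.0, load_if_needed=True)
```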
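
The llama_cache.py changes can be summarised with a small sketch: cache entries are now keyed by both the token prefix and the tuple of active LoRA adapters, so the same prompt cached under different adapters never collides. The adapter path below is again a placeholder.

```python
from llama_cpp.llama_cache import LlamaCacheKey

# Same tokens, different active adapters -> distinct (frozen, hashable) cache keys.
key_plain = LlamaCacheKey(active_lora_adapters=(), tokens=(1, 2, 3))
key_lora = LlamaCacheKey(
    active_lora_adapters=(("./adapter-a.gguf", 1.0),),  # placeholder adapter path
    tokens=(1, 2, 3),
)
assert key_plain != key_lora

# Raw token sequences are still accepted by the cache classes for backwards
# compatibility and are treated as "no active adapters".
```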