from pathlib import Path
from typing import Iterator

from ragna.core import Assistant, PackageRequirement, Source


class Llama38BInstruct(Assistant):
    """Ragna assistant that runs Llama-3-8B-Instruct (EXL2 quantization) locally via exllamav2."""

    @classmethod
    def display_name(cls):
        return "Llama-3-8B-Instruct-exl2"

    @classmethod
    def requirements(cls):
        return [
            PackageRequirement("torch"),
            PackageRequirement("exllamav2"),
        ]

    @classmethod
    def is_available(cls):
        requirements_available = super().is_available()
        if not requirements_available:
            return False

        # exllamav2 requires a CUDA-capable GPU.
        import torch

        return torch.cuda.is_available()

    def __init__(self):
        super().__init__()

        from exllamav2 import (
            ExLlamaV2,
            ExLlamaV2Cache,
            ExLlamaV2Config,
            ExLlamaV2Tokenizer,
        )
        from exllamav2.generator import ExLlamaV2Sampler, ExLlamaV2StreamingGenerator

        # Point exllamav2 at the directory holding the quantized model weights.
        config = ExLlamaV2Config()
        config.model_dir = str(
            Path.home() / "shared/scipy/rags-to-riches" / self.display_name()
        )
        config.prepare()

        self.tokenizer = ExLlamaV2Tokenizer(config)

        model = ExLlamaV2(config)
        cache = ExLlamaV2Cache(model, lazy=True)
        model.load_autosplit(cache)

        self.generator = ExLlamaV2StreamingGenerator(model, cache, self.tokenizer)
        # Stop on the end-of-sequence token as well as the additional
        # Llama-3-specific stop token 78191.
        self.generator.set_stop_conditions({self.tokenizer.eos_token_id, 78191})

        # Greedy decoding for deterministic answers.
        self.settings = ExLlamaV2Sampler.Settings()
        self.settings.temperature = 0.0

    def _make_prompt(self, prompt: str, sources: list[Source]) -> str:
        # Assemble the Llama-3 chat template: a system turn containing the
        # retrieved sources as context, the user question, and the opening
        # of the assistant turn for the model to complete.
        return "\n".join(
            [
                "<|begin_of_text|><|start_header_id|>system<|end_header_id|>",
                "",
                "Answer the question based only on the following context:",
                *[source.content for source in sources],
                "<|eot_id|><|start_header_id|>user<|end_header_id|>",
                "",
                f"{prompt}<|eot_id|><|start_header_id|>assistant<|end_header_id|>",
            ]
        )

    def answer(
        self, prompt: str, sources: list[Source], *, max_new_tokens: int = 256
    ) -> Iterator[str]:
        input_ids = self.tokenizer.encode(
            self._make_prompt(prompt, sources), add_bos=False
        )
        self.generator.begin_stream_ex(input_ids, self.settings)

        # Stream decoded text chunk by chunk until a stop condition is hit
        # or the token budget is exhausted.
        for _ in range(max_new_tokens):
            result = self.generator.stream_ex()
            if result["eos"]:
                break
            yield result["chunk"]
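

# A minimal usage sketch, assuming a CUDA GPU and the model weights at the path
# configured above. `sources` is left empty here purely to exercise the streaming
# loop; in a real ragna chat, Source objects come from the configured source storage.
if __name__ == "__main__":
    assistant = Llama38BInstruct()
    sources: list[Source] = []  # placeholder; normally produced by ragna
    for chunk in assistant.answer("What does the provided context describe?", sources):
        print(chunk, end="", flush=True)
    print()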