run-inference.py
#!/usr/bin/env python3
# Copyright (c) 2023 Graphcore Ltd. All rights reserved.
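"""Interactive Llama 2 inference with a PopXL pipeline on Graphcore IPUs.

Loads the `llama2_7b_pod4` release configuration from `config/inference.yml`
and the `meta-llama/Llama-2-7b-chat-hf` tokenizer, then loops over prompts
entered on the command line, generating text with the requested temperature,
top-k and output length.
"""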
import logging
from typing import Optional, Tuple

from transformers import AutoTokenizer

from utils.setup import llama_config_setup
from config import LlamaConfig
from api import LlamaPipeline


def run_inference_popxl(config: LlamaConfig, tokenizer, hf_model, sequence_length: Optional[int] = None):
    if sequence_length is not None:
        config.model.sequence_length = sequence_length

    # Build the PopXL inference pipeline from the Hugging Face checkpoint.
    pipe = LlamaPipeline(config, hf_llama_checkpoint=hf_model, tokenizer=tokenizer)

    def get_input() -> Tuple[str, float, int, int]:
        # Ask for a prompt and sampling parameters, re-prompting until the
        # numeric values parse correctly.
        while True:
            try:
                logging.info("-- Enter prompt --")
                prompt = input("> ")
                logging.info("-- Enter Sampling Temperature (0 for greedy) --")
                temperature = float(input("> "))
                logging.info("-- Enter top-k parameter (0 for max) --")
                k = int(input("> "))
                logging.info("-- Enter number of tokens to generate --")
                num_tokens = int(input("> "))
                break
            except ValueError:
                logging.info("Invalid input!")
        return prompt, temperature, k, num_tokens

    # Interactive loop: read a prompt and generate a completion for it.
    while True:
        prompt, temperature, k, output_length = get_input()
        pipe(prompt, k=k, temperature=temperature, output_length=output_length)[0]


def main():
    # --- Setup ---
    config, _, hf_model = llama_config_setup("config/inference.yml", "release", "llama2_7b_pod4", hf_model_setup=True)
    tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-chat-hf")

    run_inference_popxl(config, tokenizer, hf_model=hf_model, sequence_length=2048)


if __name__ == "__main__":
    try:
        main()
    except Exception as e:
        logging.exception(e)  # Log time of exception
        raise