
Commit 8774247

feat: inference server rev0 (#233)

Authored Feb 1, 2025
1 parent 8642e93 · commit 8774247

16 files changed: +2237 −1567 lines
 

‎.gitattributes

+6 −1

@@ -1,4 +1,9 @@
 * text=auto eol=lf
 docs/content/** linguist-documentation
 docs/ linguist-detectable
-docs/content/chicago-fullnote-bibliography.csl linguist-vendored
+*.md linguist-detectable=true
+*.md linguist-documentation=false
+*.md linguist-language=Markdown
+*.markdown linguist-detectable=true
+*.markdown linguist-documentation=false
+*.markdown linguist-language=Markdown

‎.gitignore

+0 −1

@@ -86,7 +86,6 @@ ipython_config.py
 # pyenv
 # For a library or package, you might want to ignore these files since the code is
 # intended to run in multiple environments; otherwise, check them in:
-.python-version

 # pipenv
 # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.

‎.node-version

+1 −1

@@ -1 +1 @@
-v20.9.0
+v22.11.0
File renamed without changes.

‎packages/morph/src/styles/variables.scss

+1 −4

@@ -8,7 +8,4 @@ $normalWeight: 400;

 $font-heading: 'Cardo', serif;
 $font-body: 'Bricolage Grotesque', sans-serif;
-$font-mono: 'JetBrains Mono', monospace;
-
-
-
+$font-mono: 'JetBrains Mono', monospace;

‎python/asteraceae/asteraceae/__init__.py

Whitespace-only changes.

‎python/asteraceae/asteraceae/service.py

+0 −63
This file was deleted.

‎python/asteraceae/pyproject.toml

+8 −38

@@ -1,55 +1,22 @@
 [project]
 name = "asteraceae"
-description = "a BentoML-based service to run SAE-intervention"
+description = "a BentoML service that runs SAEs with vLLM"
 readme = "README.md"
 requires-python = ">=3.11"
 license = { text = "Apache-2.0" }
 authors = [{ name = "Aaron Pham", email = "contact@aarnphm.xyz" }]
-dependencies = [
-  "accelerate>=0.34.0",
-  "bentoml @ git+https://github.com/bentoml/BentoML.git@main",
-  "fastapi>=0.115.0",
-  "openai>=1.47.0",
-  "sae @ git+https://github.com/EleutherAI/sae.git@main",
-  "vllm>=0.6.3",
-]
-dynamic = ["version"]
+dependencies = ["bentoml>=1.3.20", "kantoku>=0.18.1", "vllm>=0.7.0"]
+version = "0.0.0"
 [project.urls]
 Documentation = "https://tinymorph.aarnphm.xyz"
 GitHub = "https://github.com/aarnphm/tinymorph"
 Twitter = "https://twitter.com/aarnphm_"
 Tracker = "https://github.com/aarnphm/tinymorph/issues"

-[build-system]
-requires = ["hatchling", "hatch-vcs"]
-build-backend = "hatchling.build"
-
-[tool.hatch.version]
-source = "vcs"
-fallback-version = "0.0.0"
-[tool.hatch.build.hooks.vcs]
-version-file = "asteraceae/_version.py"
-[tool.hatch.version.raw-options]
-git_describe_command = [
-  "git",
-  "describe",
-  "--dirty",
-  "--tags",
-  "--long",
-  "--first-parent",
-]
-version_scheme = "post-release"
-fallback_version = "0.0.0"
-[tool.hatch.metadata]
-allow-direct-references = true
-[tool.hatch.build.targets.sdist]
-only-include = ["asteraceae"]
-[tool.hatch.build.targets.wheel]
-packages = ["asteraceae"]

 [tool.bentoml.build]
-service = "asteraceae.service:Engine"
-include = ["asteraceae/service.py"]
+service = "service:Engine"
+include = ["service.py"]
 [tool.bentoml.build.python]
 lock_packages = false
 [tool.bentoml.build.docker]

@@ -59,3 +26,6 @@ name = "HF_TOKEN"
 [[tool.bentoml.build.envs]]
 name = "HF_HUB_ENABLE_HF_TRANSFER"
 value = "1"
+
+[tool.uv.sources]
+exo = { workspace = true }

‎python/asteraceae/service.py

+103 −0

@@ -0,0 +1,103 @@
# /// script
# requires-python = ">=3.11"
# dependencies = [
#   "bentoml",
#   "vllm>=0.7.0",
# ]
# ///
from __future__ import annotations
import uuid, io, logging, os, traceback, functools, typing
import bentoml, fastapi, pydantic, yaml

from argparse import Namespace
from typing import AsyncGenerator, Literal, Optional, Union, Sequence
from annotated_types import Ge, Le
from typing_extensions import Annotated

logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)

openai_api_app = fastapi.FastAPI()

MAX_TOKENS = 4096
MODEL_ID = 'meta-llama/Llama-3.1-8B-Instruct'

SYSTEM_PROMPT = """You are a proficient writer. Your goal is to create note suggestions for any given text that share similar stylistic choices and tonality as Franz Kafka. YOU MUST RETURN VALID JSON, with schema '{{"suggestion": string, "relevance": float}}'. ONLY RETURN JSON and RETURN AT MOST {num_suggestion} SUGGESTIONS. Keep suggestions terse and authentic."""

PROMPT_TEMPLATE = """<|begin_of_text|><|start_header_id|>system<|end_header_id|>

{system_prompt}<|eot_id|><|start_header_id|>user<|end_header_id|>

{user_prompt}<|eot_id|><|start_header_id|>assistant<|end_header_id|>

"""


@bentoml.asgi_app(openai_api_app, path='/v1')
@bentoml.service(
  name='asteraceae-service',
  traffic={'timeout': 300, 'concurrency': 256},
  resources={'gpu': 1, 'gpu_type': 'nvidia-a100-80gb'},
)
class Engine:
  def __init__(self):
    from transformers import AutoTokenizer
    from vllm import AsyncEngineArgs, AsyncLLMEngine
    import vllm.entrypoints.openai.api_server as vllm_api_server

    ENGINE_ARGS = AsyncEngineArgs(model=MODEL_ID, max_model_len=MAX_TOKENS, enable_prefix_caching=True)
    self.engine = AsyncLLMEngine.from_engine_args(ENGINE_ARGS)
    self.tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)

    # mount vLLM's OpenAI-compatible routes onto the FastAPI app served under /v1
    OPENAI_ENDPOINTS = [
      ["/chat/completions", vllm_api_server.create_chat_completion, ["POST"]],
      ["/completions", vllm_api_server.create_completion, ["POST"]],
      ["/models", vllm_api_server.show_available_models, ["GET"]],
    ]
    for route, endpoint, methods in OPENAI_ENDPOINTS:
      openai_api_app.add_api_route(path=route, endpoint=endpoint, methods=methods)

    # minimal argument namespace expected by vLLM's init_app_state
    model_config = self.engine.engine.get_model_config()
    args = Namespace()
    args.model = MODEL_ID
    args.disable_log_requests = True
    args.max_log_len = 1000
    args.response_role = "assistant"
    args.served_model_name = None
    args.chat_template = None
    args.lora_modules = None
    args.prompt_adapters = None
    args.request_logger = None
    args.disable_log_stats = True
    args.return_tokens_as_token_ids = False
    args.enable_tool_call_parser = True
    args.enable_auto_tool_choice = True
    args.tool_call_parser = "llama3_json"
    args.enable_prompt_tokens_details = False

    vllm_api_server.init_app_state(self.engine, model_config, openai_api_app.state, args)

  @bentoml.api
  async def suggests(
    self,
    essay: str,
    num_suggestions: Annotated[int, Le(10)] = 5,
    max_tokens: Annotated[int, Ge(128), Le(MAX_TOKENS)] = MAX_TOKENS,
  ) -> AsyncGenerator[str, None]:
    from vllm import SamplingParams

    SAMPLING_PARAM = SamplingParams(max_tokens=max_tokens, skip_special_tokens=True)
    messages = [
      {"role": "system", "content": SYSTEM_PROMPT.format(num_suggestion=num_suggestions)},
      {"role": "user", "content": essay},
    ]

    prompt = self.tokenizer.apply_chat_template(
      messages,
      tokenize=False,
      add_generation_prompt=True,
    )
    stream = await self.engine.add_request(uuid.uuid4().hex, prompt, SAMPLING_PARAM)

    # stream back only the newly generated text since the previous chunk
    cursor = 0
    async for request_output in stream:
      text = request_output.outputs[0].text
      yield text[cursor:]
      cursor = len(text)


if __name__ == "__main__":
  Engine.serve_http(port=3000)
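
A minimal client sketch for the service above, assuming it is running locally on port 3000. The endpoint name `suggests`, the `/v1` mount, and the model id come from the listing itself; the use of `bentoml.SyncHTTPClient` and the OpenAI Python client on the calling side is an illustrative assumption, not part of this commit.

```python
# hypothetical caller, assuming the service is reachable at http://localhost:3000
import bentoml
from openai import OpenAI

# stream note suggestions from the custom `suggests` endpoint
with bentoml.SyncHTTPClient("http://localhost:3000") as client:
  for chunk in client.suggests(essay="The trial began before anyone was accused.", num_suggestions=3):
    print(chunk, end="", flush=True)

# or talk to the mounted OpenAI-compatible routes under /v1
openai_client = OpenAI(base_url="http://localhost:3000/v1", api_key="not-needed")
completion = openai_client.chat.completions.create(
  model="meta-llama/Llama-3.1-8B-Instruct",
  messages=[{"role": "user", "content": "Suggest a terse, Kafkaesque note about waiting rooms."}],
)
print(completion.choices[0].message.content)
```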

‎python/exo/README.md

+8 −2

@@ -1,5 +1,11 @@
 ## exo

-_a series of investigation into SAEs prevention for latent activation_
+a vLLM plugin to serve SAEs within the inference engine.

-To be used with `exo-service`
+## installation
+
+```bash
+uv add vllm exo
+```
+
+Note that this will override the default LlamaForCausalLM in vLLM.

‎python/exo/notebooks/sae.ipynb

+1 −0

@@ -35,6 +35,7 @@
     "\n",
     "import torch, pathlib, pandas as pd\n",
     "import huggingface_hub as hf_hub, safetensors as st\n",
+    "from goodfire import Variant\n",
     "\n",
     "# device setup\n",
     "if torch.backends.mps.is_available():\n",

‎python/exo/pyproject.toml

+9 −7

@@ -1,15 +1,14 @@
 [project]
 name = "exo"
-description = "a series of investigation into SAEs prevention for latent activation"
+description = "a vLLM plugin to serve SAEs within the inference engine."
 readme = "README.md"
 requires-python = ">=3.11"
 license = { text = "Apache-2.0" }
 authors = [{ name = "Aaron Pham", email = "contact@aarnphm.xyz" }]
 dependencies = [
-  "nnsight>=0.3.5",
-  "sae",
-  "sae-lens>=0.5.0",
+  "huggingface-hub>=0.25.0",
   "transformers>=4.44.2",
+  "vllm>=0.7.0",
 ]
 dynamic = ["version"]
 [project.urls]

@@ -18,6 +17,9 @@ GitHub = "https://github.com/aarnphm/tinymorph"
 Twitter = "https://twitter.com/aarnphm_"
 Tracker = "https://github.com/aarnphm/tinymorph/issues"

+[project.entry-points."vllm.general_plugins"]
+llama_saes = "exo:register"
+
 [build-system]
 requires = ["hatchling", "hatch-vcs"]
 build-backend = "hatchling.build"

@@ -48,10 +50,10 @@ packages = ["src/exo"]

 [tool.uv]
 dev-dependencies = [
+  "nnsight>=0.3.5",
+  "sae-lens>=0.5.0",
+  "goodfire>=0.3.4",
   "jupyter>=1.1.1",
   "jupyterlab-vim>=4.1.4",
   "notebook>=7.2.2",
 ]
-
-[tool.uv.sources]
-sae = { git = "https://github.com/EleutherAI/sae.git", rev = "main" }

‎python/exo/src/exo/__init__.py

+3 −0

@@ -0,0 +1,3 @@
from .variants import register

__all__ = ["register"]

‎python/exo/src/exo/llama_sae.py

+6 −0

@@ -0,0 +1,6 @@
from __future__ import annotations

from vllm.model_executor.models.llama import LlamaForCausalLM


class LlamaSAEForCausalLM(LlamaForCausalLM): ...

‎python/exo/src/exo/variants.py

+11 −0

@@ -0,0 +1,11 @@
from __future__ import annotations


# NOTE: This file has to be extremely light and can be called multiple times.
def register():
  """Out-of-tree registration for intervention with SAEs."""
  from vllm import ModelRegistry
  from exo.llama_sae import LlamaSAEForCausalLM

  ModelRegistry.register_model("llama", LlamaSAEForCausalLM)
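
The `llama_saes = "exo:register"` entry point declared in `python/exo/pyproject.toml` is what wires `register()` into vLLM's plugin loading. A rough sketch of that discovery step, assuming vLLM consumes the `vllm.general_plugins` group through standard `importlib.metadata` entry points (this mirrors the mechanism, it is not vLLM's actual code):

```python
# sketch of entry-point based plugin discovery, assuming the standard
# importlib.metadata API (Python >= 3.10) and an installed `exo` package
from importlib.metadata import entry_points

for ep in entry_points(group="vllm.general_plugins"):
  register_fn = ep.load()  # resolves "exo:register" to the function in variants.py
  register_fn()            # registers LlamaSAEForCausalLM with vLLM's ModelRegistry
```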
