
Commit

Addressed review comments
YiyanZhai committed Feb 17, 2025
1 parent 92e00b9 commit 88e9a41
Showing 6 changed files with 216 additions and 15 deletions.
14 changes: 14 additions & 0 deletions examples/prefix-caching/README.md
@@ -0,0 +1,14 @@
# WebLLM App for Prefix Caching Demo

This example demonstrates the use of `cachedPrefixes` in WebLLM.
To try it out, run the following commands in this folder:

```bash
npm install
npm start
```

Note: if you would like to hack the WebLLM core package, change the `@mlc-ai/web-llm`
dependency to `"file:../.."` and follow the project's build-from-source instructions to
build WebLLM locally. This option is only recommended if you plan to modify the core package.
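
The core of the example is passing `cachedPrefixes` when the engine is created, so the KV cache for the system prompt is prefilled once and reused by later completions that start with the same messages. A minimal sketch of that pattern, mirroring `src/prefix-caching.ts` in this folder (run it inside an ES module or an async function):

```ts
import * as webllm from "@mlc-ai/web-llm";

const SYSTEM_PROMPT_PREFIX =
  "You are a helpful assistant running in the user's browser, responsible for answering questions.";

// Prefill the KV cache for the system prompt at engine creation time.
const engine = await webllm.CreateMLCEngine("Llama-3.1-8B-Instruct-q4f32_1-MLC", {
  cachedPrefixes: [[{ role: "system", content: SYSTEM_PROMPT_PREFIX }]],
});

// Completions that begin with the same system prompt reuse the cached prefix.
const reply = await engine.chat.completions.create({
  messages: [
    { role: "system", content: SYSTEM_PROMPT_PREFIX },
    { role: "user", content: "List three US states." },
  ],
});
console.log(reply.usage);
```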
20 changes: 20 additions & 0 deletions examples/prefix-caching/package.json
@@ -0,0 +1,20 @@
{
"name": "prefix-caching-example",
"version": "0.1.0",
"private": true,
"scripts": {
"start": "parcel src/prefix-caching.html --port 8888",
"build": "parcel build src/prefix-caching.html --dist-dir lib"
},
"devDependencies": {
"buffer": "^5.7.1",
"parcel": "^2.8.3",
"process": "^0.11.10",
"tslib": "^2.3.1",
"typescript": "^4.9.5",
"url": "^0.11.3"
},
"dependencies": {
"@mlc-ai/web-llm": "../.."
}
}
23 changes: 23 additions & 0 deletions examples/prefix-caching/src/prefix-caching.html
@@ -0,0 +1,23 @@
<!doctype html>
<html>
<script>
webLLMGlobal = {};
</script>
<body>
<h2>WebLLM Prefix Caching Test Page</h2>
Open console to see output
<br />
<br />
<label id="init-label"> </label>

<h3>Prompt</h3>
<label id="prompt-label"> </label>

<h3>Response</h3>
<label id="generate-label"> </label>
<br />
<label id="stats-label"> </label>

<script type="module" src="./prefix-caching.ts"></script>
</body>
</html>
142 changes: 142 additions & 0 deletions examples/prefix-caching/src/prefix-caching.ts
@@ -0,0 +1,142 @@
import * as webllm from "@mlc-ai/web-llm";

const SYSTEM_PROMPT_PREFIX =
"You are a helpful assistant running in the user's browser, responsible for answering questions.";

function setLabel(id: string, text: string) {
const label = document.getElementById(id);
if (label == null) {
throw Error("Cannot find label " + id);
}
label.innerText = text;
}

async function testPrefix() {
const initProgressCallback = (report: webllm.InitProgressReport) => {
setLabel("init-label", report.text);
};

const selectedModel = "Llama-3.1-8B-Instruct-q4f32_1-MLC";
const engine: webllm.MLCEngineInterface = await webllm.CreateMLCEngine(
selectedModel,
{
initProgressCallback: initProgressCallback,
logLevel: "INFO",
// Prefilling KV cache for efficiency
cachedPrefixes: [[{ role: "system", content: SYSTEM_PROMPT_PREFIX }]],
},
{
context_window_size: 2048,
},
);

const reply_using_prefix = await engine.chat.completions.create({
messages: [
{ role: "system", content: SYSTEM_PROMPT_PREFIX },
{ role: "user", content: "List three US states." },
],
// below configurations are all optional
n: 1,
temperature: 1.5,
max_tokens: 64,
logprobs: true,
top_logprobs: 2,
});
console.log(reply_using_prefix);
console.log(reply_using_prefix.usage);
}

async function testWithoutPrefix() {
const initProgressCallback = (report: webllm.InitProgressReport) => {
setLabel("init-label", report.text);
};

const selectedModel = "Llama-3.1-8B-Instruct-q4f32_1-MLC";
// Engine Initialization without cachedPrefixes
const engine: webllm.MLCEngineInterface = await webllm.CreateMLCEngine(
selectedModel,
{
initProgressCallback: initProgressCallback,
logLevel: "INFO",
},
{
context_window_size: 2048,
},
);

const reply_without_prefix = await engine.chat.completions.create({
messages: [
{ role: "system", content: SYSTEM_PROMPT_PREFIX },
{ role: "user", content: "List three US states." },
],
// below configurations are all optional
n: 1,
temperature: 1.5,
max_tokens: 64,
logprobs: true,
top_logprobs: 2,
});
console.log(reply_without_prefix);
console.log(reply_without_prefix.usage);
}

async function testMultiRound() {
const initProgressCallback = (report: webllm.InitProgressReport) => {
setLabel("init-label", report.text);
};

const selectedModel = "Llama-3.1-8B-Instruct-q4f32_1-MLC";
const engine: webllm.MLCEngineInterface = await webllm.CreateMLCEngine(
selectedModel,
{
initProgressCallback: initProgressCallback,
logLevel: "INFO",
cachedPrefixes: [[{ role: "system", content: SYSTEM_PROMPT_PREFIX }]], // Prefilling KV cache for efficiency
},
{
context_window_size: 2048,
},
);

// First Completion with cachedPrefixes
const reply0 = await engine.chat.completions.create({
messages: [
{ role: "system", content: SYSTEM_PROMPT_PREFIX },
{ role: "user", content: "List three US states." },
],
// below configurations are all optional
n: 1,
temperature: 1.5,
max_tokens: 64,
logprobs: true,
top_logprobs: 2,
});
console.log(reply0);
console.log(reply0.usage);

// Second Completion with cachedPrefixes
const reply1 = await engine.chat.completions.create({
messages: [
{ role: "system", content: SYSTEM_PROMPT_PREFIX },
{ role: "user", content: "Where is the US capital?" },
],
// below configurations are all optional
n: 1,
temperature: 1.5,
max_tokens: 64,
logprobs: true,
top_logprobs: 2,
});
console.log(reply1);
console.log(reply1.usage);
}

async function main() {
await testPrefix();

await testWithoutPrefix();

await testMultiRound();
}

main();
4 changes: 4 additions & 0 deletions src/config.ts
@@ -106,9 +106,13 @@ export interface ChatOptions extends Partial<ChatConfig> {}
* appConfig: Configure the app, including the list of models and whether to use IndexedDB cache.
* initProgressCallback: A callback for showing the progress of loading the model.
* logitProcessorRegistry: A register for stateful logit processors, see `webllm.LogitProcessor`.
* cachedPrefixes: Specifies conversation prefixes (e.g. a system prompt) that are prefilled when
* the engine loads, creating the corresponding KV cache entries and storing them for reuse.
* These cached KV entries persist until the engine is reloaded.
*
* @note All fields are optional, and `logitProcessorRegistry` is only used for `MLCEngine` and not
* other `MLCEngine`s.
* @note cachedPrefixes is experimental. It may change in future versions.
*/
export interface MLCEngineConfig {
appConfig?: AppConfig;
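
As a usage reference for the new field, here is a sketch (not code from this commit): the nested-array shape is taken from the example above, the second prefix is purely hypothetical, and it assumes `MLCEngineConfig` is exported from the package entry point as the other config types are.

```ts
import * as webllm from "@mlc-ai/web-llm";

const engineConfig: webllm.MLCEngineConfig = {
  logLevel: "INFO",
  // Each inner array is one conversation prefix whose KV cache entries are
  // prefilled at load time and kept until the engine is reloaded.
  cachedPrefixes: [
    [{ role: "system", content: "You are a helpful assistant." }],
    // Hypothetical second prefix, not exercised by this commit's example.
    [{ role: "system", content: "You are a terse code reviewer." }],
  ],
};

const engine = await webllm.CreateMLCEngine(
  "Llama-3.1-8B-Instruct-q4f32_1-MLC",
  engineConfig,
);
```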
28 changes: 13 additions & 15 deletions src/llm_chat.ts
@@ -38,6 +38,9 @@ import { ChatCompletionMessageParam } from "./openai_api_protocols/chat_completi

type ImageURL = ChatCompletionContentPartImage.ImageURL;

// Default sequence ID for chat completion
const CHAT_SEQUENCE_ID = 0;

export class LLMChatPipeline {
private config: ChatConfig;
private tokenizer: Tokenizer;
@@ -177,7 +180,7 @@ export class LLMChatPipeline {
log.info("prepend_space_in_encode: ", this.prepend_space_in_encode);

this.seqIdToPrefix = new Map<number, number[]>();
this.nextSequenceId = 0;
this.nextSequenceId = CHAT_SEQUENCE_ID;
this.device = this.tvm.webgpu();

// 1. Create VM and get the core functions
@@ -510,9 +513,9 @@
msgRole: Role, // either user or tool
inp_role_str?: string,
genConfig?: GenerationConfig,
seqID = 0,
seqID = CHAT_SEQUENCE_ID,
): Promise<void> {
if (seqID === 0) {
if (seqID === CHAT_SEQUENCE_ID) {
if (msgRole !== Role.user && msgRole !== Role.tool) {
throw new MessageOrderError(
"The last message should be from `user` or `tool`.",
@@ -608,7 +611,7 @@
}

// 0. Get inputData from conversation
if (seqID === 0) {
if (seqID === CHAT_SEQUENCE_ID) {
if (conversation.isTextCompletion) {
conversation.prompt = inp;
} else {
@@ -652,13 +655,8 @@

// If a match is found, fork the sequence
if (matchedSeqId !== -1 && maxMatchedLen > 0) {
console.log(
"Forking sequence",
matchedSeqId,
"at position",
maxMatchedLen,
);
if (seqID === 0) {
log.info("Forking sequence", matchedSeqId, "at position", maxMatchedLen);
if (seqID === CHAT_SEQUENCE_ID) {
this.fKVCacheRemoveSequence!(
this.kvCache,
new tvmjs.Scalar(seqID, "int64"),
@@ -672,14 +670,14 @@ export class LLMChatPipeline {
new tvmjs.Scalar(maxMatchedLen, "int64"), // fork_position
);
this.tvm.endScope();
} else if (seqID !== 0) {
} else if (seqID !== CHAT_SEQUENCE_ID) {
// If no match is found, add the new sequence to the KV cache
console.log("Adding prefix to KV cache: ", seqID);
log.info("Adding prefix to KV cache: ", seqID);
this.fKVCacheAddSequence!(this.kvCache, new tvmjs.Scalar(seqID, "int64"));
}

// Add the new sequence to the seqIdToPrefix map (if it is a prefix)
if (seqID !== 0) {
if (seqID !== CHAT_SEQUENCE_ID) {
this.seqIdToPrefix.set(seqID, inputTokens);
}

@@ -996,7 +994,7 @@ export class LLMChatPipeline {
private async embedAndForward(
inputData: Array<Array<number> | ImageURL>,
inputDataLen: number,
seqID = 0,
seqID = CHAT_SEQUENCE_ID,
): Promise<tvmjs.NDArray> {
if (inputDataLen > this.prefillChunkSize) {
throw new Error(
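
For readers following the forking logic above, here is a hypothetical standalone sketch (not code from this commit) of the longest-prefix lookup that `prefill` performs over `seqIdToPrefix` before deciding whether to fork an existing cached sequence:

```ts
// Hypothetical helper: seqIdToPrefix maps a cached sequence ID to the token IDs
// of the prefix that was prefilled for it.
function findLongestMatchedPrefix(
  seqIdToPrefix: Map<number, number[]>,
  inputTokens: number[],
): { matchedSeqId: number; maxMatchedLen: number } {
  let matchedSeqId = -1;
  let maxMatchedLen = 0;
  for (const [seqId, prefixTokens] of seqIdToPrefix) {
    // Count how many leading tokens this cached prefix shares with the input.
    let matched = 0;
    while (
      matched < prefixTokens.length &&
      matched < inputTokens.length &&
      prefixTokens[matched] === inputTokens[matched]
    ) {
      matched++;
    }
    // Keep the longest match; the pipeline forks at this position when one exists.
    if (matched > maxMatchedLen) {
      matchedSeqId = seqId;
      maxMatchedLen = matched;
    }
  }
  return { matchedSeqId, maxMatchedLen };
}
```

In the diff above, when `matchedSeqId !== -1 && maxMatchedLen > 0` the pipeline forks from that sequence at `maxMatchedLen`; otherwise, for non-default `seqID`s, the new sequence is added to the KV cache and recorded in `seqIdToPrefix`.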
