[FEATURE] LLM Token-Level Generation Supervision #370
Labels: 💡 feature request (New feature or request)
Comments
Going through the linked workbook, implementing LM Format Enforcer for JSON schema inputs would probably look something like:

```python
# nexa/constants.py
NEXA_RUN_COMPLETION_TEMPLATE_MAP = {
    "format_enforcer": "You MUST answer using the following JSON schema:",
    "octopus-v2": "Below is the query from the users, please call the correct function and generate the parameters to call the function.\n\nQuery: {input} \n\nResponse:",
    "octopus-v4": "<|system|>You are a router. Below is the query from the users, please call the correct function and generate the parameters to call the function.<|end|><|user|>{input}<|end|><|assistant|>",
}
```
```python
# nexa/gguf/structure_utils.py
from typing import Optional

from pydantic import BaseModel
from lmformatenforcer import CharacterLevelParser, JsonSchemaParser
from lmformatenforcer.integrations.llamacpp import (
    build_llamacpp_logits_processor,
    build_token_enforcer_tokenizer_data,
)

from nexa.gguf.llama import Llama, LogitsProcessorList


class FormatEnforcer:
    """
    Character-level parser wrapper for llama.cpp, adapted from
    llamacpp_with_character_level_parser in
    https://github.com/noamgat/lm-format-enforcer
    samples/colab_llamacpppython_integration.ipynb
    """

    def __init__(self, downloaded_path: str = None):
        self.llm = Llama(model_path=downloaded_path)
        # lm-format-enforcer needs tokenizer data from the loaded model
        # to build its logits processor.
        self.tokenizer_data = build_token_enforcer_tokenizer_data(self.llm)

    def __call__(
        self,
        prompt: str,
        character_level_parser: Optional[CharacterLevelParser] = None,
        **sampling_kwargs,
    ) -> str:
        logits_processors: Optional[LogitsProcessorList] = None
        if character_level_parser:
            logits_processors = LogitsProcessorList(
                [build_llamacpp_logits_processor(self.tokenizer_data, character_level_parser)]
            )
        sampling_kwargs.setdefault("max_tokens", 100)
        output = self.llm(prompt, logits_processor=logits_processors, **sampling_kwargs)
        text: str = output["choices"][0]["text"]
        return text


def pydantic_to_json(pydantic_model: BaseModel) -> str:
    """
    Validates a Pydantic model instance and returns its JSON schema. Not currently used.
    Based on: https://github.com/noamgat/lm-format-enforcer
    samples/colab_llamacpppython_integration.ipynb
    """
    pydantic_model.model_validate(pydantic_model, strict=True)
    return pydantic_model.schema_json()
```

This code would then be called on to replace nexa_inference_text.py lines 361-400 with something like:
```python
# nexa/gguf/nexa_inference_text.py
# top of file
from nexa.gguf.structure_utils import FormatEnforcer
from nexa.constants import NEXA_RUN_COMPLETION_TEMPLATE_MAP
from lmformatenforcer import JsonSchemaParser

# from line 361
# `prompt`, `json_schema`, and `kwargs` come from the surrounding method,
# as in the current implementation.
enforcer_instructions = NEXA_RUN_COMPLETION_TEMPLATE_MAP.get("format_enforcer")
structured_prompt = f"{prompt} {enforcer_instructions} {json_schema}"
params = {
    "temperature": self.params.get("temperature", 0.7),
    "max_tokens": self.params.get("max_new_tokens", 2048),
    "top_k": self.params.get("top_k", 50),
    "top_p": self.params.get("top_p", 1.0),
    "stop": self.stop_words,
    "logprobs": self.logprobs,
}
params.update(kwargs)

# Perform structured inference. In practice the enforcer would be constructed once
# (e.g. in __init__); the model path attribute name is assumed here.
enforcer = FormatEnforcer(self.downloaded_path)
# AnswerFormat is whatever Pydantic model describes the expected output (example below).
structured_data = enforcer(structured_prompt, JsonSchemaParser(AnswerFormat.schema()), **params)
```
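For illustration, `AnswerFormat` above stands in for whatever Pydantic model describes the expected response; it is not defined in the SDK. A minimal, hypothetical example, mirroring the one used in the lm-format-enforcer demo notebook:

```python
from pydantic import BaseModel

class AnswerFormat(BaseModel):
    """Hypothetical response schema; the real model would be supplied by the caller."""
    first_name: str
    last_name: str
    year_of_birth: int
    num_seasons_in_nba: int
```

Note that `AnswerFormat.schema()` is the Pydantic v1 spelling; under Pydantic v2 the non-deprecated equivalent is `AnswerFormat.model_json_schema()`.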
Feature Description
Rescued from #368:
You may wish to consider implementing one of the token-level supervision options for LlamaCPP to deliver superior adherence during structured generation. It's the difference between asking "pretty please" and guaranteeing a correctly structured response.
As currently implemented by @xsxszab in nexa_inference_text.py, generation will fail if the model does not return a valid JSON response or doesn't follow the requested schema.
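To make the contrast concrete, the current approach amounts to prompt-then-validate: ask for JSON in the prompt, then check the reply afterwards. A rough illustration of that failure mode (a hypothetical helper, not the actual nexa_inference_text.py code):

```python
import json
from typing import Any

def prompt_and_validate(llm, prompt_with_schema: str) -> Any:
    """Prompt-then-validate: nothing constrains decoding, so the reply can be
    prose, truncated JSON, or valid JSON that ignores the schema."""
    raw_reply = llm(prompt_with_schema, max_tokens=2048)["choices"][0]["text"]
    try:
        return json.loads(raw_reply)
    except json.JSONDecodeError as err:
        # This is the failure token-level supervision is meant to eliminate.
        raise ValueError("Model did not return valid JSON") from err
```

A logits-processor approach instead masks out any token that would take the output off the schema, so the decoded text cannot fail this check.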
Options
LM Format Enforcer (Python)
LM Format Enforcer's llama-cpp-python integration code should be easy to adapt. The package is already used in Red Hat/IBM's enterprise-focused vLLM project (reference).
A demonstration workbook is available here. You may be able to run this workbook as-is by merely changing the imports, e.g.:
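A sketch of that import swap, assuming the nexa module paths used elsewhere in this issue:

```python
# In the lm-format-enforcer demo notebook, replace the llama-cpp-python imports
# with nexa's vendored equivalents:
# from llama_cpp import Llama, LogitsProcessorList      # original notebook imports
from nexa.gguf.llama import Llama, LogitsProcessorList   # nexa-sdk counterparts
```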
LLGuidance (upstream)
The LLGuidance Rust crate has recently been added to upstream llama.cpp.
Enabling this feature during compilation requires some fiddling with Rust, and there are still some bug fixes that need to be finalized (pull 11644). However, these are transitional problems, and adopting this approach would probably make it easier for end-users to use structured generation through the SDK.
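For reference, upstream llama.cpp exposes this as a CMake option that pulls in the Rust toolchain at build time (flag name as documented in the llama.cpp grammar docs; worth re-checking, since the integration is still settling):

```sh
# Build llama.cpp with LLGuidance-backed grammars (requires cargo on PATH)
cmake -B build -DLLAMA_LLGUIDANCE=ON
cmake --build build --config Release
```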