
Commit

add prompt autogeneration (#9)
jgbradley1 authored Jun 19, 2024
1 parent d18c2fb commit fc1dc34
Showing 21 changed files with 604 additions and 407 deletions.
Binary file modified backend/graphrag-wheel/graphrag-0.0.1-py3-none-any.whl
3 changes: 1 addition & 2 deletions backend/graphrag-wheel/note.txt
@@ -2,5 +2,4 @@ This graphrag wheel file was built from the following repo

https://github.com/microsoft/graphrag

-on commit hash a389514b4723189803097dc0d602b50d139ee92c
+on commit hash 8cb189635e90d49231f3f09b54e69d4daae1371d
4 changes: 0 additions & 4 deletions backend/run-indexing-job.py
@@ -9,14 +9,10 @@

parser = argparse.ArgumentParser(description="Kickoff indexing job.")
parser.add_argument("-i", "--index-name", required=True)
parser.add_argument("-s", "--storage-name", required=True)
-parser.add_argument("-e", "--entity-config", required=False)
args = parser.parse_args()

asyncio.run(
    _start_indexing_pipeline(
        index_name=args.index_name,
        storage_name=args.storage_name,
-        entity_config_name=args.entity_config,
    )
)
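
With the entity-config option gone, kicking off an indexing job needs only an index name and a storage name. A minimal sketch of the updated invocation (the _start_indexing_pipeline import path is an assumption based on this repo's layout; the index and storage names are placeholders):

# Hypothetical driver mirroring the trimmed-down script; "myindex" and
# "mydata" are placeholder names, and the import path is an assumption.
import asyncio

from src.api.index import _start_indexing_pipeline

asyncio.run(
    _start_indexing_pipeline(index_name="myindex", storage_name="mydata")
)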
4 changes: 2 additions & 2 deletions backend/src/aks-batch-job-template.yaml
@@ -8,8 +8,8 @@ kind: Job
metadata:
  name: PLACEHOLDER
spec:
-  ttlSecondsAfterFinished: 0
-  backoffLimit: 30
+  ttlSecondsAfterFinished: 120
+  backoffLimit: 6
  template:
    metadata:
      labels:
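Finished jobs now linger for two minutes before cleanup and retry at most six times instead of thirty. For context, the backend presumably submits this template as a Kubernetes Job with the PLACEHOLDER name filled in at runtime; a hedged sketch of such a submission follows (the official kubernetes Python client, the namespace, and the job name are assumptions, not code from this commit):

# Sketch only: load the Job template, substitute the PLACEHOLDER name, and
# submit it. Assumes the `kubernetes` client and in-cluster credentials;
# the "graphrag" namespace and job name are hypothetical.
import yaml
from kubernetes import client, config

config.load_incluster_config()  # assumes this runs inside the AKS cluster
with open("backend/src/aks-batch-job-template.yaml") as f:
    job_manifest = yaml.safe_load(f)
job_manifest["metadata"]["name"] = "indexing-job-myindex"  # fills PLACEHOLDER
client.BatchV1Api().create_namespaced_job(namespace="graphrag", body=job_manifest)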
2 changes: 1 addition & 1 deletion backend/src/api/experimental.py
@@ -131,7 +131,7 @@ def stream_response(report_df, query, end_callback=(lambda x: x), timeout=300):
    this_directory = os.path.dirname(
        os.path.abspath(inspect.getfile(inspect.currentframe()))
    )
-    data = yaml.safe_load(open(f"{this_directory}/pipeline_settings.yaml"))
+    data = yaml.safe_load(open(f"{this_directory}/pipeline-settings.yaml"))
    # layer the custom settings on top of the default configuration settings of graphrag
    parameters = create_graphrag_config(data, ".")

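The renamed pipeline-settings.yaml is loaded the same way in several modules. A self-contained sketch of the layering pattern (the create_graphrag_config import path is an assumption based on this diff):

# Sketch of the settings-layering pattern above: read pipeline-settings.yaml
# and let graphrag supply defaults for everything left unspecified.
import os

import yaml
from graphrag.config import create_graphrag_config  # import path assumed

this_directory = os.path.dirname(os.path.abspath(__file__))
with open(os.path.join(this_directory, "pipeline-settings.yaml")) as f:
    data = yaml.safe_load(f)
# layer the custom settings on top of graphrag's default configuration
parameters = create_graphrag_config(data, ".")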
244 changes: 126 additions & 118 deletions backend/src/api/index.py

Large diffs are not rendered by default.

94 changes: 87 additions & 7 deletions backend/src/api/index_configuration.py
@@ -1,16 +1,24 @@
***REMOVED***
***REMOVED***

import inspect
import os
import shutil
from typing import Union

import yaml
from fastapi import (
    APIRouter,
    Depends,
    HTTPException,
)
+from fastapi.responses import StreamingResponse
+from graphrag.prompt_tune.cli import fine_tune as generate_fine_tune_prompts

-from src.api.azure_clients import AzureStorageClientManager
+from src.api.azure_clients import (
+    AzureStorageClientManager,
+    BlobServiceClientSingleton,
+)
from src.api.common import (
    sanitize_name,
    verify_subscription_key_exist,
@@ -23,7 +31,6 @@
from src.reporting import ReporterSingleton

azure_storage_client_manager = AzureStorageClientManager()

index_configuration_route = APIRouter(
    prefix="/index/config", tags=["Index Configuration"]
)
@@ -33,14 +40,17 @@
    Depends(verify_subscription_key_exist)
)

+# NOTE: currently disable all /entity endpoints - to be replaced by the auto-generation of prompts


@index_configuration_route.get(
    "/entity",
    summary="Get all entity configurations",
    response_model=EntityNameList,
    responses={200: {"model": EntityNameList}, 400: {"model": EntityNameList}},
+    include_in_schema=False,
)
-def get_all_entitys():
+async def get_all_entitys():
    """
    Retrieve a list of all entity configuration names.
    """
@@ -62,8 +72,9 @@ def get_all_entitys():
summary="Create an entity configuration",
response_model=BaseResponse,
responses={200: {"model": BaseResponse***REMOVED******REMOVED***,
include_in_schema=False,
)
def create_entity(request: EntityConfiguration):
async def create_entity(request: EntityConfiguration):
# check for entity configuration existence
entity_container = azure_storage_client_manager.get_cosmos_container_client(
database_name="graphrag", container_name="entities"
@@ -121,8 +132,9 @@ def create_entity(request: EntityConfiguration):
summary="Update an existing entity configuration",
response_model=BaseResponse,
responses={200: {"model": BaseResponse***REMOVED******REMOVED***,
include_in_schema=False,
)
def update_entity(request: EntityConfiguration):
async def update_entity(request: EntityConfiguration):
# check for entity configuration existence
reporter = ReporterSingleton.get_instance()
existing_item = None
@@ -179,8 +191,9 @@ def update_entity(request: EntityConfiguration):
summary="Get a specified entity configuration",
response_model=Union[EntityConfiguration, BaseResponse],
responses={200: {"model": EntityConfiguration***REMOVED***, 400: {"model": BaseResponse***REMOVED******REMOVED***,
include_in_schema=False,
)
def get_entity(entity_configuration_name: str):
async def get_entity(entity_configuration_name: str):
reporter = ReporterSingleton.get_instance()
***REMOVED***
existing_item = None
@@ -213,8 +226,9 @@ def get_entity(entity_configuration_name: str):
summary="Delete a specified entity configuration",
response_model=BaseResponse,
responses={200: {"model": BaseResponse***REMOVED******REMOVED***,
include_in_schema=False,
)
def delete_entity(entity_configuration_name: str):
async def delete_entity(entity_configuration_name: str):
reporter = ReporterSingleton.get_instance()
***REMOVED***
entity_container = azure_storage_client_manager.get_cosmos_container_client(
Expand All @@ -235,3 +249,69 @@ def delete_entity(entity_configuration_name: str):
status_code=500,
detail=f"Entity configuration '{entity_configuration_name***REMOVED***' not found.",
***REMOVED***


+@index_configuration_route.get(
+    "/prompts",
+    summary="Generate graphrag prompts from user-provided data.",
+    description="Generating custom prompts from user-provided data may take several minutes to run based on the amount of data used.",
+)
+async def generate_prompts(storage_name: str, limit: int = 5):
+    """
+    Automatically generate custom prompts for entity extraction,
+    community reports, and summarize descriptions based on a sample of provided data.
+    """
+    # check for storage container existence
+    blob_service_client = BlobServiceClientSingleton().get_instance()
+    sanitized_storage_name = sanitize_name(storage_name)
+    if not blob_service_client.get_container_client(sanitized_storage_name).exists():
+        raise HTTPException(
+            status_code=500,
+            detail=f"Data container '{storage_name}' does not exist.",
+        )
+    this_directory = os.path.dirname(
+        os.path.abspath(inspect.getfile(inspect.currentframe()))
+    )
+    print("THIS DIRECTORY: ", this_directory)
+    print("CWD: ", os.getcwd())
+
+    # write custom settings.yaml to a file and store in a temporary directory
+    data = yaml.safe_load(open(f"{this_directory}/pipeline-settings.yaml"))
+    data["input"]["container_name"] = sanitized_storage_name
+    temp_dir = f"/tmp/{sanitized_storage_name}_prompt_tuning"
+    shutil.rmtree(temp_dir, ignore_errors=True)
+    os.makedirs(temp_dir, exist_ok=True)
+    print(f"TEMP SETTINGS DIR: {temp_dir}")
+    with open(f"{temp_dir}/settings.yaml", "w") as f:
+        yaml.dump(data, f, default_flow_style=False)
+
+    # generate prompts
+    try:
+        await generate_fine_tune_prompts(
+            root=temp_dir,
+            domain="",
+            select="random",
+            limit=limit,
+            skip_entity_types=True,
+            output="prompts",
+        )
+    except Exception:
+        raise HTTPException(
+            status_code=500,
+            detail=f"Error generating prompts for data in '{storage_name}'. Please try a lower limit.",
+        )
+
+    # zip up the generated prompt files and return the zip file
+    temp_archive = (
+        f"{temp_dir}/prompts"  # will become a zip file with the name prompts.zip
+    )
+    shutil.make_archive(temp_archive, "zip", root_dir=temp_dir, base_dir="prompts")
+    print(f"ARCHIVE: {temp_archive}.zip")
+    for f in os.listdir(temp_dir):
+        print(f"FILE: {f}")
+
+    def iterfile(file_path: str):
+        with open(file_path, mode="rb") as file_like:
+            yield from file_like
+
+    return StreamingResponse(iterfile(f"{temp_archive}.zip"))
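
Because the new endpoint streams back a zip archive, a client can write the response body straight to disk. A minimal sketch with requests (the host, auth header, and storage name are placeholders for a concrete deployment, not values from this commit):

# Hypothetical client for GET /index/config/prompts; host, subscription-key
# header, and storage name are placeholders.
import requests

response = requests.get(
    "https://<your-api-host>/index/config/prompts",
    params={"storage_name": "mydata", "limit": 5},
    headers={"Ocp-Apim-Subscription-Key": "<your-key>"},  # auth scheme assumed
    timeout=600,  # prompt generation can take several minutes
)
response.raise_for_status()
with open("prompts.zip", "wb") as f:
    f.write(response.content)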
@@ -73,9 +73,14 @@ embeddings:
  overwrite: True
  url: $AI_SEARCH_URL

-entity_extraction:
-  prompt: PLACEHOLDER
-  entity_types: PLACEHOLDER
+# entity_extraction:
+#   prompt: PLACEHOLDER

+# community_reports:
+#   prompt: PLACEHOLDER

+# summarize_descriptions:
+#   prompt: PLACEHOLDER

snapshots:
  graphml: True
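
These prompt sections are commented out because prompts are now expected to come from the /index/config/prompts generator rather than hard-coded placeholders. A short sketch of unpacking the returned archive (the file names inside the zip are an assumption based on graphrag's prompt-tuning output, not verified by this diff):

# Sketch: extract the downloaded prompts.zip and list the generated files;
# the exact file names inside the archive are an assumption.
import os
import zipfile

with zipfile.ZipFile("prompts.zip") as archive:
    archive.extractall(".")  # creates a local "prompts/" directory
for name in os.listdir("prompts"):
    print(name)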
4 changes: 2 additions & 2 deletions backend/src/api/query.py
@@ -106,7 +106,7 @@ async def global_query(request: GraphRequest):
    this_directory = os.path.dirname(
        os.path.abspath(inspect.getfile(inspect.currentframe()))
    )
-    data = yaml.safe_load(open(f"{this_directory}/pipeline_settings.yaml"))
+    data = yaml.safe_load(open(f"{this_directory}/pipeline-settings.yaml"))
    # layer the custom settings on top of the default configuration settings of graphrag
    parameters = create_graphrag_config(data, ".")

@@ -340,7 +340,7 @@ async def local_query(request: GraphRequest):
    this_directory = os.path.dirname(
        os.path.abspath(inspect.getfile(inspect.currentframe()))
    )
-    data = yaml.safe_load(open(f"{this_directory}/pipeline_settings.yaml"))
+    data = yaml.safe_load(open(f"{this_directory}/pipeline-settings.yaml"))
    # layer the custom settings on top of the default configuration settings of graphrag
    parameters = create_graphrag_config(data, ".")

