diff --git a/backend/graphrag_app/api/query.py b/backend/graphrag_app/api/query.py
index 79fe62a..cacf95f 100644
--- a/backend/graphrag_app/api/query.py
+++ b/backend/graphrag_app/api/query.py
@@ -1,29 +1,16 @@
 # Copyright (c) Microsoft Corporation.
 # Licensed under the MIT License.
 
-import json
-import os
 import traceback
 from pathlib import Path
-from typing import Any
 
-import pandas as pd
 import yaml
-from azure.identity import DefaultAzureCredential
-from azure.search.documents import SearchClient
-from azure.search.documents.models import VectorizedQuery
 from fastapi import (
     APIRouter,
     HTTPException,
 )
 from graphrag.api.query import global_search, local_search
 from graphrag.config.create_graphrag_config import create_graphrag_config
-from graphrag.model.types import TextEmbedder
-from graphrag.vector_stores.base import (
-    BaseVectorStore,
-    VectorStoreDocument,
-    VectorStoreSearchResult,
-)
 
 from graphrag_app.logger.load_logger import load_pipeline_logger
 from graphrag_app.typing.models import (
@@ -54,31 +41,23 @@
 )
 async def global_query(request: GraphRequest):
     # this is a slightly modified version of the graphrag.query.cli.run_global_search method
-    if isinstance(request.index_name, str):
-        index_names = [request.index_name]
-    else:
-        index_names = request.index_name
-    sanitized_index_names = [sanitize_name(name) for name in index_names]
-    sanitized_index_names_link = {
-        s: i for s, i in zip(sanitized_index_names, index_names)
-    }
-
-    for index_name in sanitized_index_names:
-        if not _is_index_complete(index_name):
-            raise HTTPException(
-                status_code=500,
-                detail=f"{index_name} not ready for querying.",
-            )
+    index_name = request.index_name
+    sanitized_index_name = sanitize_name(index_name)
+
+    if not _is_index_complete(sanitized_index_name):
+        raise HTTPException(
+            status_code=500,
+            detail=f"{index_name} not ready for querying.",
+        )
 
     COMMUNITY_REPORT_TABLE = "output/create_final_community_reports.parquet"
     COMMUNITIES_TABLE = "output/create_final_communities.parquet"
     ENTITIES_TABLE = "output/create_final_entities.parquet"
     NODES_TABLE = "output/create_final_nodes.parquet"
 
-    for index_name in sanitized_index_names:
-        validate_index_file_exist(index_name, COMMUNITY_REPORT_TABLE)
-        validate_index_file_exist(index_name, ENTITIES_TABLE)
-        validate_index_file_exist(index_name, NODES_TABLE)
+    validate_index_file_exist(sanitized_index_name, COMMUNITY_REPORT_TABLE)
+    validate_index_file_exist(sanitized_index_name, ENTITIES_TABLE)
+    validate_index_file_exist(sanitized_index_name, NODES_TABLE)
 
     if isinstance(request.community_level, int):
         COMMUNITY_LEVEL = request.community_level
@@ -87,108 +66,19 @@ async def global_query(request: GraphRequest):
         COMMUNITY_LEVEL = 1
 
     try:
-        links = {
-            "nodes": {},
-            "community": {},
-            "community_reports": {},
-            "entities": {},
-        }
-        max_vals = {
-            "nodes": -1,
-            "community": -1,
-            "community_reports": -1,
-            "entities": -1,
-        }
-
-        communities_dfs = []
-        community_reports_dfs = []
-        entities_dfs = []
-        nodes_dfs = []
-
-        for index_name in sanitized_index_names:
-            # read the parquet files into DataFrames and add provenance information
-            community_report_table_path = (
-                f"abfs://{index_name}/{COMMUNITY_REPORT_TABLE}"
-            )
-            communities_table_path = f"abfs://{index_name}/{COMMUNITIES_TABLE}"
-            entities_table_path = f"abfs://{index_name}/{ENTITIES_TABLE}"
-            nodes_table_path = f"abfs://{index_name}/{NODES_TABLE}"
-
-            # Prepare each index's nodes dataframe for merging
-            nodes_df = get_df(nodes_table_path)
-            for i in nodes_df["human_readable_id"]:
-                links["nodes"][i + max_vals["nodes"] + 1] = {
-                    "index_name": sanitized_index_names_link[index_name],
-                    "id": i,
-                }
-            if max_vals["nodes"] != -1:
-                nodes_df["human_readable_id"] += max_vals["nodes"] + 1
-            nodes_df["community"] = nodes_df["community"].apply(
-                lambda x: x + max_vals["community_reports"] + 1 if x != -1 else x
-            )
-            nodes_df["title"] = nodes_df["title"].apply(
-                lambda x, index_name=index_name: x + f"-{index_name}"
-            )
-            max_vals["nodes"] = int(nodes_df["human_readable_id"].max())
-            nodes_dfs.append(nodes_df)
-
-            # Prepare each index's community reports dataframe for merging
-            community_reports_df = get_df(community_report_table_path)
-            for i in community_reports_df["community"]:
-                links["community_reports"][i + max_vals["community_reports"] + 1] = {
-                    "index_name": sanitized_index_names_link[index_name],
-                    "id": str(i),
-                }
-            community_reports_df["community"] += max_vals["community_reports"] + 1
-            community_reports_df["human_readable_id"] += (
-                max_vals["community_reports"] + 1
-            )
-            max_vals["community_reports"] = int(community_reports_df["community"].max())
-            community_reports_dfs.append(community_reports_df)
-
-            # Prepare each index's communities dataframe for merging
-            communities_df = get_df(communities_table_path)
-            for i in communities_df["community"]:
-                links["community"][i + max_vals["community"] + 1] = {
-                    "index_name": sanitized_index_names_link[index_name],
-                    "id": str(i),
-                }
-            communities_df["community"] += max_vals["community"] + 1
-            communities_df["parent"] = communities_df["parent"].apply(
-                lambda x: x if x == -1 else x + max_vals["community"] + 1
-            )
-            communities_df["human_readable_id"] += max_vals["community"] + 1
-            max_vals["community"] = int(communities_df["community"].max())
-            communities_dfs.append(communities_df)
-
-            # Prepare each index's entities dataframe for merging
-            entities_df = get_df(entities_table_path)
-            for i in entities_df["human_readable_id"]:
-                links["entities"][i + max_vals["entities"] + 1] = {
-                    "index_name": sanitized_index_names_link[index_name],
-                    "id": i,
-                }
-            entities_df["human_readable_id"] += max_vals["entities"] + 1
-            entities_df["title"] = entities_df["title"].apply(
-                lambda x, index_name=index_name: x + f"-{index_name}"
-            )
-            entities_df["text_unit_ids"] = entities_df["text_unit_ids"].apply(
-                lambda x, index_name=index_name: [i + f"-{index_name}" for i in x]
-            )
-            max_vals["entities"] = entities_df["human_readable_id"].max()
-            entities_dfs.append(entities_df)
-
-        # merge the dataframes
-        nodes_combined = pd.concat(nodes_dfs, axis=0, ignore_index=True, sort=False)
-        community_reports_combined = pd.concat(
-            community_reports_dfs, axis=0, ignore_index=True, sort=False
-        )
-        entities_combined = pd.concat(
-            entities_dfs, axis=0, ignore_index=True, sort=False
-        )
-        communities_combined = pd.concat(
-            communities_dfs, axis=0, ignore_index=True, sort=False
+        # read the parquet files into DataFrames and add provenance information
+        community_report_table_path = (
+            f"abfs://{sanitized_index_name}/{COMMUNITY_REPORT_TABLE}"
        )
+        communities_table_path = f"abfs://{sanitized_index_name}/{COMMUNITIES_TABLE}"
+        entities_table_path = f"abfs://{sanitized_index_name}/{ENTITIES_TABLE}"
+        nodes_table_path = f"abfs://{sanitized_index_name}/{NODES_TABLE}"
+
+        # load parquet tables associated with the index
+        nodes_df = get_df(nodes_table_path)
+        community_reports_df = get_df(community_report_table_path)
+        communities_df = get_df(communities_table_path)
+        entities_df = get_df(entities_table_path)
 
         # load custom pipeline settings
        ROOT_DIR = Path(__file__).resolve().parent.parent.parent
@@ -201,20 +91,17 @@ async def global_query(request: GraphRequest):
         # perform async search
         result = await global_search(
             config=parameters,
-            nodes=nodes_combined,
-            entities=entities_combined,
-            communities=communities_combined,
-            community_reports=community_reports_combined,
+            nodes=nodes_df,
+            entities=entities_df,
+            communities=communities_df,
+            community_reports=community_reports_df,
             community_level=COMMUNITY_LEVEL,
             dynamic_community_selection=False,
             response_type="Multiple Paragraphs",
             query=request.query,
         )
 
-        # link index provenance to the context data
-        context_data = _update_context(result[1], links)
-
-        return GraphResponse(result=result[0], context_data=context_data)
+        return GraphResponse(result=result[0], context_data=result[1])
     except Exception as e:
         logger = load_pipeline_logger()
         logger.error(
@@ -233,49 +120,18 @@
     responses={200: {"model": GraphResponse}},
 )
 async def local_query(request: GraphRequest):
-    if isinstance(request.index_name, str):
-        index_names = [request.index_name]
-    else:
-        index_names = request.index_name
-    sanitized_index_names = [sanitize_name(name) for name in index_names]
-    sanitized_index_names_link = {
-        s: i for s, i in zip(sanitized_index_names, index_names)
-    }
-
-    for index_name in sanitized_index_names:
-        if not _is_index_complete(index_name):
-            raise HTTPException(
-                status_code=500,
-                detail=f"{index_name} not ready for querying.",
-            )
+    index_name = request.index_name
+    sanitized_index_name = sanitize_name(index_name)
+
+    if not _is_index_complete(sanitized_index_name):
+        raise HTTPException(
+            status_code=500,
+            detail=f"{index_name} not ready for querying.",
+        )
 
     azure_client_manager = AzureClientManager()
     blob_service_client = azure_client_manager.get_blob_service_client()
 
-    links = {
-        "nodes": {},
-        "community_reports": {},
-        "entities": {},
-        "text_units": {},
-        "relationships": {},
-        "covariates": {},
-    }
-    max_vals = {
-        "nodes": -1,
-        "community_reports": -1,
-        "entities": -1,
-        "text_units": 0,
-        "relationships": -1,
-        "covariates": 0,
-    }
-
-    community_reports_dfs = []
-    entities_dfs = []
-    nodes_dfs = []
-    relationships_dfs = []
-    text_units_dfs = []
-    covariates_dfs = []
-
     COMMUNITY_REPORT_TABLE = "output/create_final_community_reports.parquet"
     COVARIATES_TABLE = "output/create_final_covariates.parquet"
     ENTITIES_TABLE = "output/create_final_entities.parquet"
@@ -289,161 +145,35 @@ async def local_query(request: GraphRequest):
         # Current investigations show that community level 2 is the most useful for local search. Set this as the default value
         COMMUNITY_LEVEL = 2
 
-    # read the parquet files for each index into DataFrames and add provenance information
-    for index_name in sanitized_index_names:
-        # check for existence of files the query relies on to validate the index is complete
-        validate_index_file_exist(index_name, COMMUNITY_REPORT_TABLE)
-        validate_index_file_exist(index_name, ENTITIES_TABLE)
-        validate_index_file_exist(index_name, NODES_TABLE)
-        validate_index_file_exist(index_name, RELATIONSHIPS_TABLE)
-        validate_index_file_exist(index_name, TEXT_UNITS_TABLE)
-
-        community_report_table_path = f"abfs://{index_name}/{COMMUNITY_REPORT_TABLE}"
-        covariates_table_path = f"abfs://{index_name}/{COVARIATES_TABLE}"
-        entities_table_path = f"abfs://{index_name}/{ENTITIES_TABLE}"
-        nodes_table_path = f"abfs://{index_name}/{NODES_TABLE}"
-        relationships_table_path = f"abfs://{index_name}/{RELATIONSHIPS_TABLE}"
-        text_units_table_path = f"abfs://{index_name}/{TEXT_UNITS_TABLE}"
-
-        # Prepare each index's nodes dataframe for merging
-        nodes_df = get_df(nodes_table_path)
-        for i in nodes_df["human_readable_id"]:
-            links["nodes"][i + max_vals["nodes"] + 1] = {
-                "index_name": sanitized_index_names_link[index_name],
-                "id": i,
-            }
-        if max_vals["nodes"] != -1:
-            nodes_df["human_readable_id"] += max_vals["nodes"] + 1
-        nodes_df["community"] = nodes_df["community"].apply(
-            lambda x: x + max_vals["community_reports"] + 1 if x != -1 else x
-        )
-        nodes_df["title"] = nodes_df["title"].apply(
-            lambda x, index_name=index_name: x + f"-{index_name}"
-        )
-        nodes_df["id"] = nodes_df["id"].apply(
-            lambda x, index_name=index_name: x + f"-{index_name}"
-        )
-        max_vals["nodes"] = int(nodes_df["human_readable_id"].max())
-        nodes_dfs.append(nodes_df)
+    # check for existence of files the query relies on to validate the index is complete
+    validate_index_file_exist(sanitized_index_name, COMMUNITY_REPORT_TABLE)
+    validate_index_file_exist(sanitized_index_name, ENTITIES_TABLE)
+    validate_index_file_exist(sanitized_index_name, NODES_TABLE)
+    validate_index_file_exist(sanitized_index_name, RELATIONSHIPS_TABLE)
+    validate_index_file_exist(sanitized_index_name, TEXT_UNITS_TABLE)
 
-        # Prepare each index's community reports dataframe for merging
-        community_reports_df = get_df(community_report_table_path)
-        community_reports_df["community"] = community_reports_df["community"].astype(
-            int
-        )
-        for i in community_reports_df["community"]:
-            links["community_reports"][i + max_vals["community_reports"] + 1] = {
-                "index_name": sanitized_index_names_link[index_name],
-                "id": str(i),
-            }
-        community_reports_df["community"] += max_vals["community_reports"] + 1
-        community_reports_df["human_readable_id"] += max_vals["community_reports"] + 1
-        max_vals["community_reports"] = int(community_reports_df["community"].max())
-        community_reports_dfs.append(community_reports_df)
-
-        # Prepare each index's entities dataframe for merging
-        entities_df = get_df(entities_table_path)
-        for i in entities_df["human_readable_id"]:
-            links["entities"][i + max_vals["entities"] + 1] = {
-                "index_name": sanitized_index_names_link[index_name],
-                "id": i,
-            }
-        entities_df["human_readable_id"] += max_vals["entities"] + 1
-        entities_df["title"] = entities_df["title"].apply(
-            lambda x, index_name=index_name: x + f"-{index_name}"
-        )
-        entities_df["id"] = entities_df["id"].apply(
-            lambda x, index_name=index_name: x + f"-{index_name}"
-        )
-        entities_df["text_unit_ids"] = entities_df["text_unit_ids"].apply(
-            lambda x, index_name=index_name: [i + f"-{index_name}" for i in x]
-        )
-        max_vals["entities"] = int(entities_df["human_readable_id"].max())
-        entities_dfs.append(entities_df)
-
-        # Prepare each index's relationships dataframe for merging
-        relationships_df = get_df(relationships_table_path)
-        for i in relationships_df["human_readable_id"]:
-            links["relationships"][i + max_vals["relationships"] + 1] = {
-                "index_name": sanitized_index_names_link[index_name],
-                "id": i,
-            }
-        if max_vals["relationships"] != -1:
-            col = (
-                relationships_df["human_readable_id"].astype(int)
-                + max_vals["relationships"]
-                + 1
-            )
-            relationships_df["human_readable_id"] = col.astype(str)
-        relationships_df["source"] = relationships_df["source"].apply(
-            lambda x, index_name=index_name: x + f"-{index_name}"
-        )
-        relationships_df["target"] = relationships_df["target"].apply(
-            lambda x, index_name=index_name: x + f"-{index_name}"
-        )
-        relationships_df["text_unit_ids"] = relationships_df["text_unit_ids"].apply(
-            lambda x, index_name=index_name: [i + f"-{index_name}" for i in x]
-        )
-        max_vals["relationships"] = int(relationships_df["human_readable_id"].max())
-        relationships_dfs.append(relationships_df)
-
-        # Prepare each index's text units dataframe for merging
-        text_units_df = get_df(text_units_table_path)
-        for i in range(text_units_df.shape[0]):
-            links["text_units"][i + max_vals["text_units"]] = {
-                "index_name": sanitized_index_names_link[index_name],
-                "id": i,
-            }
-        text_units_df["id"] = text_units_df["id"].apply(
-            lambda x, index_name=index_name: f"{x}-{index_name}"
-        )
-        text_units_df["human_readable_id"] = (
-            text_units_df["human_readable_id"] + max_vals["text_units"]
-        )
-        max_vals["text_units"] += text_units_df.shape[0]
-        text_units_dfs.append(text_units_df)
-
-        # If present, prepare each index's covariates dataframe for merging
-        index_container_client = blob_service_client.get_container_client(index_name)
-        if index_container_client.get_blob_client(COVARIATES_TABLE).exists():
-            covariates_df = get_df(covariates_table_path)
-            for i in covariates_df["human_readable_id"].astype(int):
-                links["covariates"][i + max_vals["covariates"]] = {
-                    "index_name": sanitized_index_names_link[index_name],
-                    "id": i,
-                }
-            covariates_df["id"] = covariates_df["id"].apply(
-                lambda x, index_name=index_name: f"{x}-{index_name}"
-            )
-            covariates_df["human_readable_id"] = (
-                covariates_df["human_readable_id"] + max_vals["covariates"]
-            )
-            covariates_df["text_unit_id"] = covariates_df["text_unit_id"].apply(
-                lambda x, index_name=index_name: x + f"-{index_name}"
-            )
-            covariates_df["subject_id"] = covariates_df["subject_id"].apply(
-                lambda x, index_name=index_name: x + f"-{index_name}"
-            )
-            max_vals["covariates"] += covariates_df.shape[0]
-            covariates_dfs.append(covariates_df)
-
-    # Merge the dataframes
-    nodes_combined = pd.concat(nodes_dfs, axis=0, ignore_index=True, sort=False)
-    community_reports_combined = pd.concat(
-        community_reports_dfs, axis=0, ignore_index=True, sort=False
-    )
-    entities_combined = pd.concat(entities_dfs, axis=0, ignore_index=True, sort=False)
-    relationships_combined = pd.concat(
-        relationships_dfs, axis=0, ignore_index=True, sort=False
+    community_report_table_path = (
+        f"abfs://{sanitized_index_name}/{COMMUNITY_REPORT_TABLE}"
    )
-    text_units_combined = pd.concat(
-        text_units_dfs, axis=0, ignore_index=True, sort=False
+    covariates_table_path = f"abfs://{sanitized_index_name}/{COVARIATES_TABLE}"
+    entities_table_path = f"abfs://{sanitized_index_name}/{ENTITIES_TABLE}"
+    nodes_table_path = f"abfs://{sanitized_index_name}/{NODES_TABLE}"
+    relationships_table_path = f"abfs://{sanitized_index_name}/{RELATIONSHIPS_TABLE}"
+    text_units_table_path = f"abfs://{sanitized_index_name}/{TEXT_UNITS_TABLE}"
+
+    nodes_df = get_df(nodes_table_path)
+    community_reports_df = get_df(community_report_table_path)
+    entities_df = get_df(entities_table_path)
+    relationships_df = get_df(relationships_table_path)
+    text_units_df = get_df(text_units_table_path)
+
+    # If present, prepare each index's covariates dataframe for merging
+    index_container_client = blob_service_client.get_container_client(
+        sanitized_index_name
    )
-    covariates_combined = None
-    if len(covariates_dfs) > 0:
-        covariates_combined = pd.concat(
-            covariates_dfs, axis=0, ignore_index=True, sort=False
-        )
+    covariates_df = None
+    if index_container_client.get_blob_client(COVARIATES_TABLE).exists():
+        covariates_df = get_df(covariates_table_path)
 
     # load custom pipeline settings
     ROOT_DIR = Path(__file__).resolve().parent.parent.parent
@@ -452,28 +182,24 @@ async def local_query(request: GraphRequest):
     # layer the custom settings on top of the default configuration settings of graphrag
     parameters = create_graphrag_config(data, ".")
 
-    # add index_names to vector_store args
-    parameters.embeddings.vector_store["index_names"] = sanitized_index_names
+    parameters.embeddings.vector_store["collection_name"] = sanitized_index_name
 
     # perform async search
     result = await local_search(
         config=parameters,
-        nodes=nodes_combined,
-        entities=entities_combined,
-        community_reports=community_reports_combined,
-        text_units=text_units_combined,
-        relationships=relationships_combined,
-        covariates=covariates_combined,
+        nodes=nodes_df,
+        entities=entities_df,
+        community_reports=community_reports_df,
+        text_units=text_units_df,
+        relationships=relationships_df,
+        covariates=covariates_df,
         community_level=COMMUNITY_LEVEL,
         response_type="Multiple Paragraphs",
         query=request.query,
     )
 
-    # link index provenance to the context data
-    context_data = _update_context(result[1], links)
-
-    return GraphResponse(result=result[0], context_data=context_data)
+    return GraphResponse(result=result[0], context_data=result[1])
 
 
 def _is_index_complete(index_name: str) -> bool:
@@ -496,199 +222,3 @@ def _is_index_complete(index_name: str) -> bool:
     if PipelineJobState(pipeline_job.status) == PipelineJobState.COMPLETE:
         return True
     return False
-
-
-def _update_context(context, links):
-    """
-    Update context data.
-    context_keys = ['reports', 'entities', 'relationships', 'claims', 'sources']
-    """
-    updated_context = {}
-    for key in context:
-        updated_entry = []
-        if key == "reports":
-            updated_entry = [
-                dict(
-                    {k: entry[k] for k in entry},
-                    **{
-                        "index_name": links["community"][int(entry["id"])][
-                            "index_name"
-                        ],
-                        "index_id": links["community"][int(entry["id"])]["id"],
-                    },
-                )
-                for entry in context[key]
-            ]
-        if key == "entities":
-            updated_entry = [
-                dict(
-                    {k: entry[k] for k in entry},
-                    **{
-                        "entity": entry["entity"].split("-")[0],
-                        "index_name": links["entities"][int(entry["id"])]["index_name"],
-                        "index_id": links["entities"][int(entry["id"])]["id"],
-                    },
-                )
-                for entry in context[key]
-            ]
-        if key == "relationships":
-            updated_entry = [
-                dict(
-                    {k: entry[k] for k in entry},
-                    **{
-                        "source": entry["source"].split("-")[0],
-                        "target": entry["target"].split("-")[0],
-                        "index_name": links["relationships"][int(entry["id"])][
-                            "index_name"
-                        ],
-                        "index_id": links["relationships"][int(entry["id"])]["id"],
-                    },
-                )
-                for entry in context[key]
-            ]
-        if key == "claims":
-            updated_entry = [
-                dict(
-                    {k: entry[k] for k in entry},
-                    **{
-                        "index_name": links["claims"][int(entry["id"])]["index_name"],
-                        "index_id": links["claims"][int(entry["id"])]["id"],
-                    },
-                )
-                for entry in context[key]
-            ]
-        if key == "sources":
-            updated_entry = context[key]
-        updated_context[key] = updated_entry
-    return updated_context
-
-
-def _get_embedding_description_store(
-    entities: Any,
-    vector_store_type: str = Any,
-    config_args: dict | None = None,
-):
-    collection_names = [
-        f"{index_name}_description_embedding"
-        for index_name in config_args.get("index_names", [])
-    ]
-    ai_search_url = os.environ["AI_SEARCH_URL"]
-    description_embedding_store = MultiAzureAISearch(
-        collection_name="multi",
-        document_collection=None,
-        db_connection=None,
-    )
-    description_embedding_store.connect(url=ai_search_url)
-    for collection_name in collection_names:
-        description_embedding_store.add_collection(collection_name)
-    return description_embedding_store
-
-
-class MultiAzureAISearch(BaseVectorStore):
-    """The Azure AI Search vector storage implementation."""
-
-    def __init__(
-        self,
-        collection_name: str,
-        db_connection: Any,
-        document_collection: Any,
-        query_filter: Any | None = None,
-        **kwargs: Any,
-    ):
-        self.collection_name = collection_name
-        self.db_connection = db_connection
-        self.document_collection = document_collection
-        self.query_filter = query_filter
-        self.kwargs = kwargs
-        self.collections = []
-
-    def add_collection(self, collection_name: str):
-        self.collections.append(collection_name)
-
-    def connect(self, **kwargs: Any) -> Any:
-        """Connect to the AzureAI vector store."""
-        self.url = kwargs.get("url", None)
-        self.vector_size = kwargs.get("vector_size", 1536)
-
-        self.vector_search_profile_name = kwargs.get(
-            "vector_search_profile_name", "vectorSearchProfile"
-        )
-
-        if self.url:
-            pass
-        else:
-            not_supported_error = (
-                "Azure AI Search client is not supported on local host."
-            )
-            raise ValueError(not_supported_error)
-
-    def load_documents(
-        self, documents: list[VectorStoreDocument], overwrite: bool = True
-    ) -> None:
-        raise NotImplementedError("load_documents() method not implemented")
-
-    def filter_by_id(self, include_ids: list[str] | list[int]) -> Any:
-        """Build a query filter to filter documents by a list of ids."""
-        if include_ids is None or len(include_ids) == 0:
-            self.query_filter = None
-            # returning to keep consistency with other methods, but not needed
-            return self.query_filter
-
-        # more info about odata filtering here: https://learn.microsoft.com/en-us/azure/search/search-query-odata-search-in-function
-        # search.in is faster that joined and/or conditions
-        id_filter = ",".join([f"{id!s}" for id in include_ids])
-        self.query_filter = f"search.in(id, '{id_filter}', ',')"
-
-        # returning to keep consistency with other methods, but not needed
-        # TODO: Refactor on a future PR
-        return self.query_filter
-
-    def similarity_search_by_vector(
-        self, query_embedding: list[float], k: int = 10, **kwargs: Any
-    ) -> list[VectorStoreSearchResult]:
-        """Perform a vector-based similarity search."""
-        vectorized_query = VectorizedQuery(
-            vector=query_embedding, k_nearest_neighbors=k, fields="vector"
-        )
-
-        docs = []
-        for collection_name in self.collections:
-            add_on = "-" + str(collection_name.split("_")[0])
-            audience = os.environ["AI_SEARCH_AUDIENCE"]
-            db_connection = SearchClient(
-                self.url,
-                collection_name,
-                DefaultAzureCredential(),
-                audience=audience,
-            )
-            response = db_connection.search(
-                vector_queries=[vectorized_query],
-            )
-            mod_response = []
-            for r in response:
-                r["id"] = r.get("id", "") + add_on
-                mod_response += [r]
-            docs += mod_response
-        return [
-            VectorStoreSearchResult(
-                document=VectorStoreDocument(
-                    id=doc.get("id", ""),
-                    text=doc.get("text", ""),
-                    vector=doc.get("vector", []),
-                    attributes=(json.loads(doc.get("attributes", "{}"))),
-                ),
-                score=abs(doc["@search.score"]),
-            )
-            for doc in docs
-        ]
-
-    def similarity_search_by_text(
-        self, text: str, text_embedder: TextEmbedder, k: int = 10, **kwargs: Any
-    ) -> list[VectorStoreSearchResult]:
-        """Perform a text-based similarity search."""
-        query_embedding = text_embedder(text)
-        if query_embedding:
-            return self.similarity_search_by_vector(
-                query_embedding=query_embedding, k=k
-            )
-        return []
diff --git a/backend/graphrag_app/main.py b/backend/graphrag_app/main.py
index d8e681f..14e178d 100644
--- a/backend/graphrag_app/main.py
+++ b/backend/graphrag_app/main.py
@@ -26,7 +26,6 @@
 from graphrag_app.api.index import index_route
 from graphrag_app.api.prompt_tuning import prompt_tuning_route
 from graphrag_app.api.query import query_route
-from graphrag_app.api.query_streaming import query_streaming_route
 from graphrag_app.api.source import source_route
 from graphrag_app.logger.load_logger import load_pipeline_logger
 from graphrag_app.utils.azure_clients import AzureClientManager
@@ -130,8 +129,11 @@ async def lifespan(app: FastAPI):
     root_path=os.getenv("API_ROOT_PATH", ""),
     title="GraphRAG",
     version=os.getenv("GRAPHRAG_VERSION", "undefined_version"),
-    lifespan=lifespan,
+    lifespan=lifespan
+    if os.getenv("KUBERNETES_SERVICE_HOST")
+    else None,  # only set lifespan if running in AKS (by checking for a default k8s environment variable)
 )
+
 app.middleware("http")(catch_all_exceptions_middleware)
 app.add_middleware(
     CORSMiddleware,
@@ -143,7 +145,7 @@ async def lifespan(app: FastAPI):
 app.include_router(data_route)
 app.include_router(index_route)
 app.include_router(query_route)
-app.include_router(query_streaming_route)
+# app.include_router(query_streaming_route) # temporarily disable streaming endpoints
 app.include_router(prompt_tuning_route)
 app.include_router(source_route)
 app.include_router(graph_route)
diff --git a/backend/graphrag_app/typing/models.py b/backend/graphrag_app/typing/models.py
index 702ef38..229356a 100644
--- a/backend/graphrag_app/typing/models.py
+++ b/backend/graphrag_app/typing/models.py
@@ -31,7 +31,7 @@ class EntityResponse(BaseModel):
 
 
 class GraphRequest(BaseModel):
-    index_name: str | List[str]
+    index_name: str
     query: str
     community_level: int | None = None
 
diff --git a/notebooks/2-Advanced_Getting_Started.ipynb b/notebooks/2-Advanced_Getting_Started.ipynb
index 21bb020..044d370 100644
--- a/notebooks/2-Advanced_Getting_Started.ipynb
+++ b/notebooks/2-Advanced_Getting_Started.ipynb
@@ -132,7 +132,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 26,
+   "execution_count": 3,
    "id": "9",
    "metadata": {
     "tags": []
@@ -159,7 +159,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 27,
+   "execution_count": 4,
    "id": "10",
    "metadata": {},
    "outputs": [],
@@ -181,7 +181,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 28,
+   "execution_count": 5,
    "id": "12",
    "metadata": {
     "tags": []
@@ -190,7 +190,7 @@
    "source": [
     "def upload_files(\n",
     "    file_directory: str,\n",
-    "    storage_name: str,\n",
+    "    container_name: str,\n",
     "    batch_size: int = 100,\n",
     "    overwrite: bool = True,\n",
     "    max_retries: int = 5,\n",
@@ -200,7 +200,7 @@
     "\n",
     "    Args:\n",
     "        file_directory - a local directory of .txt files to upload. All files must be in utf-8 encoding.\n",
-    "        storage_name - a unique name for the Azure storage container.\n",
+    "        container_name - a unique name for the Azure storage container.\n",
     "        batch_size - the number of files to upload in a single batch.\n",
     "        overwrite - whether or not to overwrite files if they already exist in the storage container.\n",
     "        max_retries - the maximum number of times to retry uploading a batch of files if the API is busy.\n",
@@ -211,13 +211,13 @@
     "    url = endpoint + \"/data\"\n",
     "\n",
     "    def upload_batch(\n",
-    "        files: list, storage_name: str, overwrite: bool, max_retries: int\n",
+    "        files: list, container_name: str, overwrite: bool, max_retries: int\n",
     "    ) -> requests.Response:\n",
     "        for _ in range(max_retries):\n",
     "            response = requests.post(\n",
     "                url=url,\n",
     "                files=files,\n",
-    "                params={\"storage_name\": storage_name, \"overwrite\": overwrite},\n",
+    "                params={\"container_name\": container_name, \"overwrite\": overwrite},\n",
     "                headers=headers,\n",
     "            )\n",
     "            # API may be busy, retry\n",
@@ -246,25 +246,25 @@
     "        )\n",
     "        # upload batch of files\n",
     "        if len(batch_files) == batch_size:\n",
-    "            response = upload_batch(batch_files, storage_name, overwrite, max_retries)\n",
+    "            response = upload_batch(batch_files, container_name, overwrite, max_retries)\n",
     "            # if response is not ok, return early\n",
     "            if not response.ok:\n",
     "                return response\n",
     "            batch_files.clear()\n",
     "    # upload remaining files\n",
     "    if len(batch_files) > 0:\n",
-    "        response = upload_batch(batch_files, storage_name, overwrite, max_retries)\n",
+    "        response = upload_batch(batch_files, container_name, overwrite, max_retries)\n",
     "    return response\n",
     "\n",
     "\n",
-    "def delete_files(storage_name: str) -> requests.Response:\n",
-    "    \"\"\"Delete a blob storage container.\"\"\"\n",
-    "    url = endpoint + f\"/data/{storage_name}\"\n",
+    "def delete_files(container_name: str) -> requests.Response:\n",
+    "    \"\"\"Delete an azure storage container that holds raw data.\"\"\"\n",
+    "    url = endpoint + f\"/data/{container_name}\"\n",
     "    return requests.delete(url=url, headers=headers)\n",
     "\n",
     "\n",
     "def list_files() -> requests.Response:\n",
-    "    \"\"\"List all data storage containers.\"\"\"\n",
+    "    \"\"\"Get a list of all azure storage containers that hold raw data.\"\"\"\n",
     "    url = endpoint + \"/data\"\n",
     "    return requests.get(url=url, headers=headers)\n",
     "\n",
@@ -290,19 +290,19 @@
     "    return requests.post(\n",
     "        url,\n",
     "        files=prompts if len(prompts) > 0 else None,\n",
-    "        params={\"index_name\": index_name, \"storage_name\": storage_name},\n",
+    "        params={\"index_container_name\": index_name, \"storage_container_name\": storage_name},\n",
     "        headers=headers,\n",
     "    )\n",
     "\n",
     "\n",
-    "def delete_index(index_name: str) -> requests.Response:\n",
-    "    \"\"\"Delete a search index.\"\"\"\n",
-    "    url = endpoint + f\"/index/{index_name}\"\n",
+    "def delete_index(container_name: str) -> requests.Response:\n",
+    "    \"\"\"Delete an azure storage container that holds a search index.\"\"\"\n",
+    "    url = endpoint + f\"/index/{container_name}\"\n",
     "    return requests.delete(url, headers=headers)\n",
     "\n",
     "\n",
     "def list_indexes() -> list:\n",
-    "    \"\"\"List all search indexes.\"\"\"\n",
+    "    \"\"\"Get a list of all azure storage containers that hold search indexes.\"\"\"\n",
     "    url = endpoint + \"/index\"\n",
     "    response = requests.get(url, headers=headers)\n",
     "    try:\n",
@@ -313,8 +313,9 @@
     "    return response\n",
     "\n",
     "\n",
-    "def index_status(index_name: str) -> requests.Response:\n",
-    "    url = endpoint + f\"/index/status/{index_name}\"\n",
+    "def index_status(container_name: str) -> requests.Response:\n",
+    "    \"\"\"Get the status of a specific index.\"\"\"\n",
+    "    url = endpoint + f\"/index/status/{container_name}\"\n",
     "    return requests.get(url, headers=headers)\n",
     "\n",
     "\n",
@@ -335,6 +336,7 @@
     "def global_search_streaming(\n",
     "    index_name: str | list[str], query: str, community_level: int\n",
     ") -> requests.Response:\n",
+    "    raise NotImplementedError(\"this functionality has been temporarily removed\")\n",
     "    \"\"\"Run a global query across one or more indexes and stream back the response\"\"\"\n",
     "    url = endpoint + \"/query/streaming/global\"\n",
     "    # optional parameter: community level to query the graph at (default for global query = 1)\n",
@@ -379,6 +381,7 @@
     "def local_search_streaming(\n",
     "    index_name: str | list[str], query: str, community_level: int\n",
     ") -> requests.Response:\n",
+    "    raise NotImplementedError(\"this functionality has been temporarily removed\")\n",
     "    \"\"\"Run a global query across one or more indexes and stream back the response\"\"\"\n",
     "    url = endpoint + \"/query/streaming/local\"\n",
     "    # optional parameter: community level to query the graph at (default for local query = 2)\n",
@@ -469,10 +472,10 @@
     "    return response\n",
     "\n",
     "\n",
-    "def generate_prompts(storage_name: str, limit: int = 1) -> None:\n",
+    "def generate_prompts(container_name: str, limit: int = 1) -> None:\n",
     "    \"\"\"Generate graphrag prompts using data provided in a specific storage container.\"\"\"\n",
     "    url = endpoint + \"/index/config/prompts\"\n",
-    "    params = {\"storage_name\": storage_name, \"limit\": limit}\n",
+    "    params = {\"container_name\": container_name, \"limit\": limit}\n",
     "    return requests.get(url, params=params, headers=headers)"
    ]
   },
@@ -497,7 +500,7 @@
    "source": [
     "response = upload_files(\n",
     "    file_directory=file_directory,\n",
-    "    storage_name=storage_name,\n",
+    "    container_name=storage_name,\n",
     "    batch_size=100,\n",
     "    overwrite=True,\n",
     ")\n",
@@ -540,7 +543,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": null,
+   "execution_count": 35,
    "id": "18",
    "metadata": {},
    "outputs": [],
@@ -568,9 +571,11 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "auto_template_response = generate_prompts(storage_name=storage_name, limit=1)\n",
+    "auto_template_response = generate_prompts(container_name=storage_name, limit=1)\n",
     "if auto_template_response.ok:\n",
-    "    prompts = auto_template_response.json()"
+    "    prompts = auto_template_response.json()\n",
+    "else:\n",
+    "    print(auto_template_response.text)"
    ]
   },
   {
@@ -737,7 +742,6 @@
   },
   "outputs": [],
   "source": [
-    "%%time\n",
     "# pass in a single index name as a string or to query across multiple indexes, set index_name=[myindex1, myindex2]\n",
     "global_response = global_search(\n",
     "    index_name=index_name,\n",
@@ -749,28 +753,6 @@
     "global_response_data"
    ]
   },
-  {
-   "cell_type": "markdown",
-   "id": "35",
-   "metadata": {},
-   "source": [
-    "An API endpoint has been designed to support streaming back the graphrag response while executing a global query (useful in applications like a chatbot)."
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "id": "36",
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "global_search_streaming(\n",
-    "    index_name=index_name,\n",
-    "    query=\"Summarize the main topics found in this data\",\n",
-    "    community_level=1,\n",
-    ")"
-   ]
-  },
   {
    "cell_type": "markdown",
    "id": "37",
@@ -790,7 +772,6 @@
   },
   "outputs": [],
   "source": [
-    "%%time\n",
     "# pass in a single index name as a string or to query across multiple indexes, set index_name=[myindex1, myindex2]\n",
     "local_response = local_search(\n",
     "    index_name=index_name,\n",
@@ -802,28 +783,6 @@
     "local_response_data"
    ]
   },
-  {
-   "cell_type": "markdown",
-   "id": "39",
-   "metadata": {},
-   "source": [
-    "An API endpoint has been designed to support streaming back the graphrag response while executing a local query (useful in applications like a chatbot)."
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "id": "40",
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "local_search_streaming(\n",
-    "    index_name=index_name,\n",
-    "    query=\"Who are the primary actors in these communities?\",\n",
-    "    community_level=2,\n",
-    ")"
-   ]
-  },
   {
    "cell_type": "markdown",
    "id": "41",
@@ -985,7 +944,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": null,
+   "execution_count": 24,
    "id": "54",
    "metadata": {},
    "outputs": [],