diff --git a/api.py b/api.py
index cab3dbe..2bcc3d6 100644
--- a/api.py
+++ b/api.py
@@ -173,7 +173,6 @@ async def stream_docfinder(
 
 def event_stream_layerfinder(
     query: str,
-    ds_id: Optional[str] = None,
     thread_id: Optional[str] = None,
 ):
     if not thread_id:
@@ -181,7 +180,7 @@ def event_stream_layerfinder(
     config = {"configurable": {"thread_id": thread_id}}
 
     stream = layerfinder.stream(
-        {"question": query, "messages": [HumanMessage(query)], "ds_id": ds_id},
+        {"question": query, "messages": [HumanMessage(query)]},
         stream_mode="updates",
         subgraphs=False,
         config=config,
@@ -221,10 +220,9 @@ def event_stream_layerfinder(
 async def stream_layerfinder(
     query: Annotated[str, Body(embed=True)],
     thread_id: Optional[str] = Body(None),
-    ds_id: Optional[str] = Body(None),
 ):
     return StreamingResponse(
-        event_stream_layerfinder(query=query, thread_id=thread_id, ds_id=ds_id),
+        event_stream_layerfinder(query=query, thread_id=thread_id),
         media_type="application/x-ndjson",
     )
 
diff --git a/frontend/app.py b/frontend/app.py
index b157b4c..57bcec5 100644
--- a/frontend/app.py
+++ b/frontend/app.py
@@ -49,15 +49,10 @@
 
 # Agent data
 agents = [
-    {
-        "name": "Docu Dodo 🐥",
-        "tagline": "A trusty agent that digs through documents to find the information you need.",
-        "description": "Specializes in finding and analyzing WRI & LCL documents. Can search through various document types, extract key information, and provide relevant summaries.",
-    },
     {
         "name": "Owl Gorithm 🦉",
-        "tagline": "A wise, data-savvy agent for discovering relevant datasets.",
-        "description": "Expert at finding relevant datasets hosted by WRI & LCL. It tries its best to find the dataset & explain why it is relevant to your query.",
+        "tagline": "A wise, data-savvy agent for WRI content such as blog posts and datasets.",
+        "description": "Expert at finding relevant content & datasets hosted by WRI & LCL. It tries its best to find the dataset & explain why it is relevant to your query.",
     },
     {
         "name": "Earthy Eagle 🦅",
diff --git "a/frontend/pages/1_\360\237\220\245_Docu_Dodo.py" "b/frontend/pages/1_\360\237\220\245_Docu_Dodo.py"
deleted file mode 100644
index 156b2b4..0000000
--- "a/frontend/pages/1_\360\237\220\245_Docu_Dodo.py"
+++ /dev/null
@@ -1,76 +0,0 @@
-import json
-import os
-import uuid
-
-import requests
-import streamlit as st
-from dotenv import load_dotenv
-
-load_dotenv()
-
-API_BASE_URL = os.environ.get("API_BASE_URL")
-
-if "docfinder_session_id" not in st.session_state:
-    st.session_state.docfinder_session_id = str(uuid.uuid4())
-if "docfinder_messages" not in st.session_state:
-    st.session_state.docfinder_messages = []
-
-st.header("Docu Dodo 🐥")
-st.caption(
-    "Zeno's Docu Dodo, a trusty agent that digs through documents to find the information you need."
-)
-
-with st.sidebar:
-    st.header("🐥")
-    st.write(
-        """
-        Docu Dodo is expert at finding useful information from WRI & LCL documents. Give it a try!
-        """
-    )
-
-    st.subheader("🧐 Try asking:")
-    st.write(
-        """
-        - How many users are using GFW and how long did it take to get there?
-        """
-    )
-
-
-def display_message(message):
-    if message["role"] == "user":
-        st.chat_message("user").write(message["content"])
-    else:
-        st.chat_message("assistant").write(message["content"])
-
-
-def handle_stream_response(stream):
-    for chunk in stream.iter_lines():
-        data = json.loads(chunk.decode("utf-8"))
-        message = {
-            "role": "assistant",
-            "type": "text",
-            "content": data["content"],
-        }
-        st.session_state.docfinder_messages.append(message)
-        display_message(message)
-
-
-# Display chat history
-for message in st.session_state.docfinder_messages:
-    display_message(message)
-
-
-if user_input := st.chat_input("Type your message here..."):
-    message = {"role": "user", "content": user_input, "type": "text"}
-    st.session_state.docfinder_messages.append(message)
-    display_message(message)
-
-    with requests.post(
-        f"{API_BASE_URL}/stream/docfinder",
-        json={
-            "query": user_input,
-            "thread_id": st.session_state.docfinder_session_id,
-        },
-        stream=True,
-    ) as stream:
-        handle_stream_response(stream)
diff --git a/tests/test_layerfinder_graph.py b/tests/test_layerfinder_graph.py
index e8fa1c8..9fe85f4 100644
--- a/tests/test_layerfinder_graph.py
+++ b/tests/test_layerfinder_graph.py
@@ -45,7 +45,8 @@ def test_layerfinder_agent_detail():
 
 
 def test_layerfinder_agent_doc_route():
-    query = "How many users are using GFW and how long did it take to get there?"
+    # query = "How many users are using GFW and how long did it take to get there?"
+    query = "What do you know about Indonesia?"
     stream = layerfinder.invoke(
         {"question": query},
         stream_mode="updates",
diff --git a/zeno/agents/docfinder/graph.py b/zeno/agents/docfinder/graph.py
index 0348ba5..db8cba3 100644
--- a/zeno/agents/docfinder/graph.py
+++ b/zeno/agents/docfinder/graph.py
@@ -7,7 +7,10 @@
 from langgraph.graph import END, START, StateGraph
 from pydantic import BaseModel
 
-from zeno.agents.docfinder.prompts import DOCUMENTS_FOR_DATASETS_PROMPT, GENERATE_PROMPT
+from zeno.agents.docfinder.prompts import (
+    DOCUMENTS_FOR_DATASETS_PROMPT,
+    GENERATE_PROMPT,
+)
 from zeno.agents.docfinder.state import DocFinderState
 from zeno.agents.docfinder.tool_document_retrieve import vectorstore
 
@@ -37,7 +40,6 @@ def generate_node(state: DocFinderState, config: RunnableConfig):
     for msg in state["messages"]:
         if isinstance(msg, HumanMessage):
             questions += ", " + msg.content
-    print("QUESTION", questions)
 
     prompt = GENERATE_PROMPT.format(questions=questions, context=context)
 
diff --git a/zeno/agents/layerfinder/graph.py b/zeno/agents/layerfinder/graph.py
index e02683f..54015f0 100644
--- a/zeno/agents/layerfinder/graph.py
+++ b/zeno/agents/layerfinder/graph.py
@@ -11,7 +11,6 @@
 from zeno.agents.layerfinder.prompts import (
     DATASETS_FOR_DOCS_PROMPT,
     LAYER_CAUTIONS_PROMPT,
-    LAYER_DETAILS_PROMPT,
     LAYER_FINDER_PROMPT,
     ROUTING_PROMPT,
 )
@@ -43,9 +42,9 @@ def retrieve_node(state: LayerFinderState):
         if isinstance(msg, HumanMessage):
             questions += ", " + msg.content
     context = [msg for msg in state["messages"] if isinstance(msg, AIMessage)][
-        0
-    ].content
-    question = questions + context
+        -1
+    ]
+    question = questions + context.content
 
     search_result = db.similarity_search_with_relevance_scores(
         question, k=10, score_threshold=0.3
@@ -110,32 +109,15 @@
 def docfinder_node(state: LayerFinderState):
     return docfinder.invoke([HumanMessage(content=state["question"])])
 
 
-def explain_details_node(state: LayerFinderState):
-    print("---EXPLAIN DETAILS---")
-    ds_id = state["ds_id"]
-    dataset = [ds for ds in state["datasets"] if ds_id == ds.metadata["dataset"]]
-    if not dataset:
-        return {"messages": [AIMessage("No dataset found")]}
-    else:
-        dataset = dataset[0]
-        prompt = LAYER_DETAILS_PROMPT.format(
-            context=dataset.page_content, question=state["question"]
-        )
-        response = haiku.invoke(prompt)
-        return {"messages": [response]}
-
-
 wf = StateGraph(LayerFinderState)
 wf.add_node("retrieve", retrieve_node)
-wf.add_node("detail", explain_details_node)
 wf.add_node("cautions", cautions_node)
 wf.add_node("docfinder", docfinder)
 
 wf.add_conditional_edges(START, route_node)
 wf.add_edge("retrieve", "cautions")
 wf.add_edge("cautions", END)
-wf.add_edge("detail", END)
 wf.add_edge("docfinder", END)
 
 memory = MemorySaver()