Skip to content

Commit

Permalink
Add agent to answer queries about KBAs (#118)
Browse files Browse the repository at this point in the history
* Add agent to answer queries about KBAs

* Improve the prompts
  • Loading branch information
srmsoumya authored Jan 23, 2025
1 parent 677b1c1 commit ff5281c
Show file tree
Hide file tree
Showing 11 changed files with 529 additions and 15 deletions.
81 changes: 81 additions & 0 deletions api.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from zeno.agents.distalert.graph import graph as dist_alert
from zeno.agents.docfinder.graph import graph as docfinder
from zeno.agents.layerfinder.graph import graph as layerfinder
from zeno.agents.kba.graph import graph as kba

app = FastAPI()
# # langfuse_handler = CallbackHandler()
Expand Down Expand Up @@ -220,3 +221,83 @@ async def stream_layerfinder(
event_stream_layerfinder(query, thread_id),
media_type="application/x-ndjson",
)


def event_stream_kba(
query: str,
user_persona: Optional[str] = None,
thread_id: Optional[str] = None,
):
if not thread_id:
thread_id = str(uuid.uuid4())

config = {"configurable": {"thread_id": thread_id}}
query = HumanMessage(content=query, name="human")
stream = kba.stream(
{"messages": [query], "user_persona": user_persona},
stream_mode="updates",
subgraphs=False,
config=config,
)

for update in stream:
node = next(iter(update.keys()))

if node == "kba_response_node":
report = update[node]["report"].to_dict()
summary = report["summary"]
metrics = report["metrics"]
regional_breakdown = report["regional_breakdown"]
actions = report["actions"]
data_gaps = report["data_gaps"]
yield pack(
{
"node": node,
"type": "report",
"summary": summary,
"metrics": metrics,
"regional_breakdown": regional_breakdown,
"actions": actions,
"data_gaps": data_gaps,
}
)
else:
messages = update[node]["messages"]
if node == "tools":
for message in messages:
message.pretty_print()
yield pack(
{
"node": node,
"type": "tool_call",
"tool_name": message.name,
"content": message.content,
"artifact": (
message.artifact
if hasattr(message, "artifact")
else None
),
}
)
else:
for message in messages:
message.pretty_print()
yield pack(
{
"node": node,
"type": "update",
"content": message.content,
}
)


@app.post("/stream/kba")
async def stream_kba(
query: Annotated[str, Body(embed=True)],
user_persona: Optional[str] = Body(None),
thread_id: Optional[str] = Body(None),
):
return StreamingResponse(
event_stream_kba(query, user_persona, thread_id),
media_type="application/x-ndjson",
)
9 changes: 7 additions & 2 deletions frontend/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,8 +61,13 @@
},
{
"name": "Earthy Eagle 🦅",
"tagline": "An eagle-eyed agent focused on detecting distribution or deforestation alerts.",
"description": "Specializes in detecting distribution alerts. It assists in finding alerts for specific locations and timeframes. Additionally, it helps in understanding the distribution of alerts within a location and provides satellite images for validation.",
"tagline": "An eagle-eyed agent focused on detecting disturbances or deforestation alerts.",
"description": "Specializes in detecting disturbances or deforestation alerts. It assists in finding alerts for specific locations and timeframes. Additionally, it helps in understanding the distribution of alerts within a location and provides satellite images for validation.",
},
{
"name": "Keeper Kaola 🐨",
"tagline": "Keeping a watch over the worlds Key Biodiversity Areas (KBAs).",
"description": "Specializing in planning interventions and answering queries about KBAs - from habitat analysis to species protection strategies. Keeper Koala helps ensure critical ecosystems get the attention they need.",
},
]

Expand Down
16 changes: 5 additions & 11 deletions frontend/pages/3_🦅_Earthy_Eagle.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,15 +23,15 @@

st.header("Earthy Eagle 🦅")
st.caption(
"Zeno's Earthy Eagle is an eagle-eyed agent focused on detecting distribution alerts."
"Zeno's Earthy Eagle is an eagle-eyed agent focused on detecting disturbances or deforestation alerts."
)

# Sidebar content
with st.sidebar:
st.header("🦅")
st.write(
"""
Earthy Eagle specializes in detecting distribution alerts. It assists in finding alerts for specific locations and timeframes.
Earthy Eagle specializes in detecting disturbances or deforestation alerts. It assists in finding alerts for specific locations and timeframes.
Additionally, it helps in understanding the distribution of alerts within a location and provides satellite images for validation.
"""
)
Expand All @@ -52,9 +52,7 @@ def display_message(message):
st.chat_message("user").write(message["content"])
elif message["role"] == "assistant":
if message["type"] == "location":
st.chat_message("assistant").write(
"Found location you searched for..."
)
st.chat_message("assistant").write("Found location you searched for...")
data = message["content"]
artifact = data.get("artifact", {})
artifact = artifact[0]
Expand All @@ -79,9 +77,7 @@ def display_message(message):
stats = data.get("content", {})
stats = json.loads(stats)
print(stats)
df = pd.DataFrame(
list(stats.items()), columns=["Category", "Value"]
)
df = pd.DataFrame(list(stats.items()), columns=["Category", "Value"])
st.bar_chart(df, x="Category", y="Value")

# plot the artifact which is a geojson featurecollection
Expand Down Expand Up @@ -216,9 +212,7 @@ def handle_stream_response(stream):
display_message(message)

# If we were waiting for input, this is a response to an interrupt
query_type = (
"human_input" if st.session_state.waiting_for_input else "query"
)
query_type = "human_input" if st.session_state.waiting_for_input else "query"

# Reset the waiting_for_input state
if st.session_state.waiting_for_input:
Expand Down
151 changes: 151 additions & 0 deletions frontend/pages/4_🐨_Keeper_Kaola.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,151 @@
import json
import os
import uuid

import folium
import requests
import streamlit as st
from dotenv import load_dotenv
from streamlit_folium import folium_static

load_dotenv()

API_BASE_URL = os.environ.get("API_BASE_URL")


if "kba_session_id" not in st.session_state:
st.session_state.kba_session_id = str(uuid.uuid4())
if "kba_messages" not in st.session_state:
st.session_state.kba_messages = []


# Add a callback function to reset the session state
def reset_state():
st.session_state.kba_session_id = str(uuid.uuid4())
st.session_state.kba_messages = []
st.session_state.custom_persona = ""


st.header("Keeper Kaola 🐨")
st.caption(
"Zeno's Keeper Kaola, keeping a watch over the worlds Key Biodiversity Areas (KBAs)."
)

with st.sidebar:
st.header("🐥")
st.write(
"""
Keeper Kaola is an expert at planning interventions and answering queries about KBAs - from habitat analysis to species protection strategies.
"""
)

# Add user persona selection
st.subheader("Select or Enter User Persona")
user_personas = [
"I am a conservation manager responsible for overseeing a network of Key Biodiversity Areas. I have basic GIS skills, I am comfortable visualising data but not conducting advanced analysis. I need to identify and understand threats, such as illegal logging or habitat degradation, and monitor changes in ecosystem health over time to allocate resources effectively and plan conservation interventions.",
"I am a program manager implementing nature-based solutions projects focused on agroforestry and land restoration. I am comfortable using tools like QGIS for mapping and visualisation. I need to track project outcomes, such as tree cover gain and carbon sequestration, and prioritise areas for intervention based on risks like soil erosion or forest loss.",
"I am an investment analyst for an impact fund supporting reforestation and agroforestry projects. I have limited GIS skills and rely on intuitive dashboards or visualisations to understand geospatial insights. I need independent geospatial insights to monitor portfolio performance, assess project risks, and ensure investments align with our net-zero commitments.",
"I am a sustainability manager responsible for ensuring our company’s agricultural supply chains meet conversion-free commitments. I have limited GIS skills and can only use simple web-based tools or dashboards. I need to monitor and address risks such as land conversion to maintain compliance and support sustainable sourcing decisions.",
"I am an advocacy program manager for an NGO working on Indigenous Peoples’ land rights. I have basic GIS skills, enabling me to visualise data but not perform advanced analysis. I need to use data to highlight land use changes, advocate for stronger tenure policies, and empower local communities to monitor their territories.",
"I am a journalist covering environmental issues and corporate accountability, with basic GIS skills that enable me to interpret geospatial data by eye but not produce charts or insights myself. I need reliable, accessible data to track whether companies are meeting their EU Deforestation Regulation (EUDR) commitments, identify instances of non-compliance, and write compelling, data-driven stories that hold businesses accountable for their environmental impact.",
]

selected_persona = st.selectbox(
"Choose a persona", user_personas, on_change=reset_state
)
custom_persona = st.text_input("Or enter a custom persona", on_change=reset_state)

# Determine active persona
active_persona = custom_persona if custom_persona else selected_persona
if st.session_state.get("active_persona") != active_persona:
st.session_state.active_persona = active_persona
reset_state()
st.rerun()

if st.session_state.get("active_persona"):
st.success(f"**{st.session_state.active_persona}**", icon="🕵️‍♂️")


def display_message(message):
if message["role"] == "user":
st.chat_message("user").write(message["content"])
else:
if message["type"] == "kba_location":
st.chat_message("assistant").write(
"Found Key Biodiversity Areas in your area of interest..."
)
data = message["content"]
artifact = data.get("artifact", {})
artifact = json.loads(artifact)
print(artifact)
# plot the artifact which is a geojson featurecollection using folium
geometry = artifact["features"][0]["geometry"]
if geometry["type"] == "Polygon":
pnt = geometry["coordinates"][0][0]
else:
pnt = geometry["coordinates"][0][0][0]
m = folium.Map(location=[pnt[1], pnt[0]], zoom_start=11)
g = folium.GeoJson(artifact).add_to(m) # noqa: F841
folium_static(m, width=700, height=500)
elif message["type"] == "report":
st.chat_message("assistant").write(message["summary"])
st.chat_message("assistant").write(message["metrics"])
st.chat_message("assistant").write(message["regional_breakdown"])
st.chat_message("assistant").write(message["actions"])
st.chat_message("assistant").write(message["data_gaps"])
elif message["type"] == "update":
st.chat_message("assistant").write(message["content"])


def handle_stream_response(stream):
for chunk in stream.iter_lines():
data = json.loads(chunk.decode("utf-8"))

if data.get("type") == "report":
message = {
"role": "assistant",
"type": "report",
"summary": data["summary"],
"metrics": data["metrics"],
"regional_breakdown": data["regional_breakdown"],
"actions": data["actions"],
"data_gaps": data["data_gaps"],
}
elif data.get("type") == "update":
message = {
"role": "assistant",
"type": "update",
"content": data["content"],
}
elif data.get("type") == "tool_call":
message = {
"role": "assistant",
"type": "kba_location",
"content": data,
}
st.session_state.kba_messages.append(message)
display_message(message)


# Display chat history
if st.session_state.active_persona:
for message in st.session_state.kba_messages:
display_message(message)

if user_input := st.chat_input("Type your message here..."):
message = {"role": "user", "content": user_input, "type": "text"}
st.session_state.kba_messages.append(message)
display_message(message)

with requests.post(
f"{API_BASE_URL}/stream/kba",
json={
"query": user_input,
"user_persona": st.session_state.active_persona, # Include persona in the request
"thread_id": st.session_state.kba_session_id,
},
stream=True,
) as stream:
handle_stream_response(stream)
else:
st.write("Please select or enter a user persona to start the chat.")
7 changes: 5 additions & 2 deletions zeno/agents/distalert/tool_dist_alerts.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,8 @@ class DistAlertsInput(BaseModel):
default=5, description="Threshold for disturbance alert scale"
)
buffer_distance: Optional[float] = Field(
default=None, description="Buffer distance in meters for buffering the features."
default=None,
description="Buffer distance in meters for buffering the features.",
)


Expand Down Expand Up @@ -187,7 +188,9 @@ def get_distalerts_unfiltered(
scale=DIST_ALERT_STATS_SCALE,
).getInfo()

zone_stats_result= {"disturbances": zone_stats["features"][0]["properties"]["sum"]}
zone_stats_result = {
"disturbances": zone_stats["features"][0]["properties"]["sum"]
}

vectorize = (
distalerts.gte(threshold)
Expand Down
Empty file added zeno/agents/kba/__init__.py
Empty file.
62 changes: 62 additions & 0 deletions zeno/agents/kba/agent.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
import json

from langchain_anthropic import ChatAnthropic
from pydantic import BaseModel, Field

from zeno.agents.kba.tool_kba_info import kba_info_tool

from pydantic import BaseModel, Field
from typing import List, Dict, Optional


# Add this mixin to all models
class JSONSerializable(BaseModel):
def to_dict(self):
return json.loads(json.dumps(self, default=lambda o: o.dict()))


class KBAMetrics(JSONSerializable):
total_kbas: int
threatened_kbas: int
protected_coverage: float
key_species: int
habitat_types: List[Dict[str, float]]
threat_categories: List[Dict[str, float]]


class RegionalStats(JSONSerializable):
region_name: str
kba_count: int
protection_status: float
primary_threats: List[str]
trend: float # positive/negative change


class KBAActionItem(JSONSerializable):
priority: str # "High", "Medium", "Low"
area: str
issue: str
recommended_action: str
expected_impact: str
timeframe: str


class KBAResponse(JSONSerializable):
summary: str = Field(
description="Concise summary highlighting key patterns and critical insights"
)
metrics: KBAMetrics = Field(description="Core KBA statistics for visualization")
regional_breakdown: List[RegionalStats] = Field(
description="Geographic distribution and trends"
)
actions: List[KBAActionItem] = Field(description="Prioritized conservation actions")
data_gaps: List[str] = Field(description="Missing or incomplete data areas")


# haiku = ChatAnthropic(model="claude-3-5-haiku-latest", temperature=0)
sonnet = ChatAnthropic(model="claude-3-5-sonnet-latest", temperature=0)


tools = [kba_info_tool]
kba_info_agent = sonnet.bind_tools(tools)
kba_response_agent = sonnet.with_structured_output(KBAResponse)
Loading

0 comments on commit ff5281c

Please sign in to comment.