Skip to content

Commit

Permalink
Merge pull request #117 from wri/reactivate-tms-output
Browse files Browse the repository at this point in the history
Reactivate tms output using viz parameters
  • Loading branch information
yellowcap authored Jan 21, 2025
2 parents 8fef055 + 58d62ed commit cf7d904
Show file tree
Hide file tree
Showing 9 changed files with 55 additions and 64 deletions.
1 change: 0 additions & 1 deletion api.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,6 @@ def event_stream_alerts(

for update in stream:
node = next(iter(update.keys()))
# node = list(update.keys())[0]

if node == "__interrupt__":
print("INTERRUPTED")
Expand Down
13 changes: 11 additions & 2 deletions frontend/pages/3_🦅_Earthy_Eagle.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,9 +97,18 @@ def display_message(message):
g = folium.GeoJson(artifact).add_to(m) # noqa: F841
folium_static(m, width=700, height=500)
elif message["type"] == "context":
data = message["content"]
st.chat_message("assistant").write(
f"Adding context layer {message['content']}"
f"Adding context layer {data['content']}"
)
m = folium.Map(location=[0, 0], zoom_start=3)
g = folium.TileLayer(
data['artifact']['tms_url'],
name=data['content'],
attr=data['content'],
).add_to(m) # noqa: F841
folium_static(m, width=700, height=500)

elif message["type"] == "stac":
st.chat_message("assistant").write(
"Found satellite images for your area of interest, here are the stac ids: "
Expand Down Expand Up @@ -158,7 +167,7 @@ def handle_stream_response(stream):
message = {
"role": "assistant",
"type": "context",
"content": data["content"],
"content": data,
}
elif data.get("tool_name") == "stac-tool":
message = {
Expand Down
6 changes: 3 additions & 3 deletions tests/test_context_layer_tool.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
from zeno.tools.contextlayer.context_layer_retriever_tool import context_layer_tool
from zeno.agents.distalert.tool_context_layer import context_layer_tool


def test_context_layer_tool_cereal():
msg = context_layer_tool.invoke(
{
"name": "context-layer-tool",
"args": {"question": "Summarize disturbance alerts by type of cereal"},
"args": {"question": "Summarize disturbance alerts by natural lands"},
"id": "42",
"type": "tool_call",
}
)
assert msg.content == "ESA/WorldCereal/2021/MODELS/v100"
assert msg.content == "WRI/SBTN/naturalLands/v1/2020"
assert "{z}/{x}/{y}" in msg.artifact["tms_url"]
71 changes: 27 additions & 44 deletions tests/test_dist_agent.py
Original file line number Diff line number Diff line change
@@ -1,56 +1,39 @@
import datetime
import uuid

from zeno.agents.distalert.agent import graph
from zeno.agents.maingraph.utils.state import GraphState
from langgraph.types import Command

from zeno.agents.distalert.graph import graph as dist_alert

def test_distalert_agent_level_2():

def test_distalert_agent():
"""
This test just runs the agent without checking any output, it is intended
to be used for debugging
"""
config = {
"configurable": {"thread_id": uuid.uuid4()},
}
initial_state = GraphState(
question="Provide data about disturbance alerts in Aveiro in 2023 summarized by natural lands"
)
for _, chunk in graph.stream(
initial_state,
query = "Provide data about disturbance alerts in Aveiro summarized by natural lands in 2023"
stream = dist_alert.stream(
{"messages": [query]},
stream_mode="updates",
subgraphs=True,
subgraphs=False,
config=config,
):
if "assistant" in chunk:
for msg in chunk["assistant"]["messages"]:
for call in msg.tool_calls:
print(call["name"])
if call["name"] == "location-tool":
assert call["args"]["gadm_level"] == 2
assert call["args"]["query"] == "Aveiro"
if call["name"] == "dist-alerts-tool":
assert call["args"]["min_date"] == "2023-01-01"
assert call["args"]["max_date"] == "2023-12-31"

if "tools" in chunk:
for msg in chunk["tools"]["messages"]:
if msg.name == "context-layer-tool":
assert "WRI/SBTN/naturalLands/v1" in msg.content


def test_distalert_agent_level_1():
config = {
"configurable": {"thread_id": uuid.uuid4()},
}
initial_state = GraphState(
question="Provide data about disturbance alerts in Florida summarized by natural lands in 2023"
)
for _, chunk in graph.stream(
initial_state,
for chunk in stream:
print(str(chunk)[:300], "\n")

query = "Averio"
stream = dist_alert.stream(
Command(
goto="dist_alert",
update={
"messages": [query],
},
),
stream_mode="updates",
subgraphs=True,
subgraphs=False,
config=config,
):
if "assistant" in chunk:
if chunk["assistant"]["messages"][0].tool_calls:
call = chunk["assistant"]["messages"][0].tool_calls[0]
if call["name"] == "location-tool":
assert call["args"]["gadm_level"] == 1
assert call["args"]["query"] == "Florida"
)
for chunk in stream:
print(str(chunk)[:300], "\n")
2 changes: 1 addition & 1 deletion tests/test_dist_alerts.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import datetime

from zeno.agents.distalert import tools
from zeno.agents.distalert import tool_dist_alerts as tools


def test_dist_alert_tool():
Expand Down
6 changes: 3 additions & 3 deletions tests/test_location_agent.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import random

from zeno.agents.location.agent import location_agent
from zeno.agents.distalert.agent import dist_alert_agent


# Test data for level 1 locations
Expand Down Expand Up @@ -36,10 +36,10 @@ def test_location_agent():
# pick a random query each from LEVEL_1_TEST_DATA & LEVEL_2_TEST_DATA
query_1 = random.choice(LEVEL_1_TEST_DATA)
query_2 = random.choice(LEVEL_2_TEST_DATA)
result_1 = location_agent.invoke(
result_1 = dist_alert_agent.invoke(
{"messages": [("user", "Find the location of " + query_1[0])]}
)
result_2 = location_agent.invoke(
result_2 = dist_alert_agent.invoke(
{"messages": [("user", "Find the location of " + query_2[0])]}
)
print(result_1["messages"][-1].content[0]["text"])
Expand Down
2 changes: 1 addition & 1 deletion tests/test_location_tool.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import pytest

from zeno.agents.location.tools import location_tool
from zeno.agents.distalert.tool_location import location_tool

# Test data for level 1 locations
LEVEL_1_TEST_DATA = [
Expand Down
2 changes: 1 addition & 1 deletion tests/test_stac_tool.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import datetime

from zeno.tools.stac.stac_tool import stac_tool
from zeno.agents.distalert.tool_stac import stac_tool


def test_stac_tool():
Expand Down
16 changes: 8 additions & 8 deletions zeno/agents/distalert/tool_context_layer.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import os
from pathlib import Path
import json

import ee
import lancedb
Expand Down Expand Up @@ -28,10 +29,11 @@ def get_tms_url(result: Series):
else:
image = ee.Image(result.dataset)

# TODO: add dynamic viz parameters
map_id = image.select(result.band).getMapId(
visParams=result.visualization_parameters
)
if result.visualization_parameters:
viz_params = json.loads(result.visualization_parameters)
map_id = image.select(result.band).getMapId(viz_params)
else:
map_id = image.select(result.band).getMapId()

return map_id["tile_fetcher"].url_format

Expand Down Expand Up @@ -67,11 +69,9 @@ def context_layer_tool(question: str) -> dict:
.sort_values(by="year", ascending=False)
.iloc[0]
)

# tms_url = get_tms_url(result)

tms_url = get_tms_url(result)
result = result.to_dict()
# result["tms_url"] = tms_url
result["tms_url"] = tms_url

# Delete the dataset key vector as ndarray is not serializable
del result["vector"]
Expand Down

0 comments on commit cf7d904

Please sign in to comment.