Skip to content

Commit

Permalink
Merge pull request #113 from wri/stac-tool
Browse files Browse the repository at this point in the history
Add support for stac-tool to query for satellite imagery
  • Loading branch information
yellowcap authored Jan 20, 2025
2 parents 746e9f2 + 10f1d7b commit befa3d2
Show file tree
Hide file tree
Showing 4 changed files with 64 additions and 25 deletions.
34 changes: 32 additions & 2 deletions frontend/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,8 +85,10 @@ def display_message(message):
stats = data.get("content", {})
stats = json.loads(stats)
print(stats)
df = pd.DataFrame(list(stats.items()), columns=['Category', 'Value'])
st.bar_chart(df, x='Category', y='Value')
df = pd.DataFrame(
list(stats.items()), columns=["Category", "Value"]
)
st.bar_chart(df, x="Category", y="Value")

# plot the artifact which is a geojson featurecollection
artifact = data.get("artifact", {})
Expand All @@ -104,6 +106,28 @@ def display_message(message):
st.chat_message("assistant").write(
f"Adding context layer {message['content']}"
)
elif message["type"] == "stac":
st.chat_message("assistant").write(
"Found satellite images for your area of interest, here are the stac ids: "
)
data = message["content"]
artifact = data.get("artifact", {})
# create a grid of 2 x 5 images
cols = st.columns(5)
for idx, stac_item in enumerate(artifact["features"]):
stac_id = stac_item["id"]
stac_href = next(
(
link["href"]
for link in stac_item["links"]
if link["rel"] == "thumbnail"
),
None,
)
with cols[idx % 5]:
st.chat_message("assistant").image(
stac_href, caption=stac_id, width=100
)
else:
st.chat_message("assistant").write(message["content"])

Expand Down Expand Up @@ -142,6 +166,12 @@ def handle_stream_response(stream):
"type": "context",
"content": data["content"],
}
elif data.get("tool_name") == "stac-tool":
message = {
"role": "assistant",
"type": "stac",
"content": data,
}
else:
message = {
"role": "assistant",
Expand Down
3 changes: 2 additions & 1 deletion zeno/agents/zeno/agent.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
from zeno.agents.contextfinder.tools import context_layer_tool
from zeno.agents.distalert.tools import dist_alerts_tool
from zeno.agents.location.multitools import location_tool
from zeno.tools.stac.stac_tool import stac_tool
from zeno.agents.zeno.models import ModelFactory
from langchain_anthropic import ChatAnthropic

haiku = ChatAnthropic(model="claude-3-5-haiku-latest", temperature=0)

tools_with_hil = [location_tool]
tools_with_hil_names = {t.name for t in tools_with_hil}
tools = [dist_alerts_tool, context_layer_tool]
tools = [dist_alerts_tool, context_layer_tool, stac_tool]

zeno_agent = haiku.bind_tools(tools + tools_with_hil)
6 changes: 5 additions & 1 deletion zeno/agents/zeno/prompts.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,4 +9,8 @@
ask follow up questions without picking a default.
Current date: {current_date}.
"""
""".format(
current_date=current_date
)

print(ZENO_PROMPT)
46 changes: 25 additions & 21 deletions zeno/tools/stac/stac_tool.py
Original file line number Diff line number Diff line change
@@ -1,24 +1,24 @@
import datetime
from pathlib import Path
from typing import Tuple

from langchain_core.tools import tool
from pydantic import BaseModel, Field
from pystac_client import Client
import geopandas as gpd

# Defaults to E84 AWS STAC catalog & Sentinel-2 L2A collection
CATALOG = "https://earth-search.aws.element84.com/v1"
COLLECTION = "sentinel-2-l2a"
data_dir = Path("data")


class StacInput(BaseModel):
"""Input schema for STAC search tool"""

catalog: str = Field(
description="STAC catalog to use for search",
default="https://earth-search.aws.element84.com/v1",
)
collection: str = Field(
description="STAC Clollection to use", default="sentinel-2-l2a"
)
bbox: Tuple[float, float, float, float] = Field(
description="Bounding box for STAC search."
)
name: str = Field(description="Name of the area of interest")
gadm_id: str = Field(description="GADM ID of the area of interest")
gadm_level: int = Field(description="GADM level of the area of interest")
min_date: datetime.datetime = Field(
description="Earliest date for retrieving STAC items.",
)
Expand All @@ -33,28 +33,32 @@ class StacInput(BaseModel):
response_format="content_and_artifact",
)
def stac_tool(
bbox: Tuple[float, float, float, float],
name: str,
gadm_id: str,
gadm_level: int,
min_date: datetime.datetime,
max_date: datetime.datetime,
catalog: str = "https://earth-search.aws.element84.com/v1",
collection: str = "sentinel-2-l2a",
) -> dict:
"""Find locations and their administrative hierarchies given a place name.
Returns a list of IDs with matches at different administrative levels
"""
print("---SENTINEL-TOOL---")
"""Returns satellite images for a given area of interest."""
print("---STAC-TOOL---")

aoi_df = gpd.read_file(
data_dir / f"gadm_410_level_{gadm_level}.gpkg",
where=f"GID_{gadm_level} like '{gadm_id}'",
)
aoi = aoi_df.iloc[0]

catalog = Client.open(catalog)
catalog = Client.open(CATALOG)

query = catalog.search(
collections=[collection],
collections=[COLLECTION],
datetime=[min_date, max_date],
max_items=10,
bbox=bbox,
intersects=aoi.geometry,
)

items = list(query.items())
print(f"Found: {len(items):d} datasets")
print(f"Found: {len(items):d} recent STAC items")

# Convert STAC items into a GeoJSON FeatureCollection
stac_json = query.item_collection_as_dict()
Expand Down

0 comments on commit befa3d2

Please sign in to comment.