Skip to content

Commit

Permalink
Distalerts drivers (#101)
Browse files Browse the repository at this point in the history
* Initial implementation of using drivers for aggregation

* Fix test and improve prompt for drivers
  • Loading branch information
yellowcap authored Jan 13, 2025
1 parent 5d539fb commit e1564e8
Show file tree
Hide file tree
Showing 5 changed files with 301 additions and 83 deletions.
9 changes: 8 additions & 1 deletion tests/test_dist_agent.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import datetime
import uuid

from zeno.agents.distalert.agent import graph
Expand All @@ -9,7 +10,7 @@ def test_distalert_agent_level_2():
"configurable": {"thread_id": uuid.uuid4()},
}
initial_state = GraphState(
question="Provide data about disturbance alerts in Aveiro summarized by natural lands in 2023"
question="Provide data about disturbance alerts in Aveiro in 2023 summarized by natural lands"
)
for _, chunk in graph.stream(
initial_state,
Expand All @@ -20,8 +21,13 @@ def test_distalert_agent_level_2():
if "assistant" in chunk:
for msg in chunk["assistant"]["messages"]:
for call in msg.tool_calls:
print(call["name"])
if call["name"] == "location-tool":
assert call["args"]["gadm_level"] == 2
assert call["args"]["query"] == "Aveiro"
if call["name"] == "dist-alerts-tool":
assert call["args"]["min_date"] == "2023-01-01"
assert call["args"]["max_date"] == "2023-12-31"

if "tools" in chunk:
for msg in chunk["tools"]["messages"]:
Expand All @@ -47,3 +53,4 @@ def test_distalert_agent_level_1():
call = chunk["assistant"]["messages"][0].tool_calls[0]
if call["name"] == "location-tool":
assert call["args"]["gadm_level"] == 1
assert call["args"]["query"] == "Florida"
2 changes: 1 addition & 1 deletion tests/test_location_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

def test_location_tool_name():
fids = location_tool.invoke(input={"query": "Puri India", "gadm_level": 2})
assert len(fids) == 3
assert len(fids) == 4
assert fids[0] == "IND.26.26_1"


Expand Down
11 changes: 7 additions & 4 deletions zeno/agents/distalert/utils/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,10 @@

tools = [dist_alerts_tool, context_layer_tool, location_tool]

# model = ModelFactory().get("claude-3-5-sonnet-latest").bind_tools(tools)
model = ModelFactory().get("claude-3-5-sonnet-latest").bind_tools(tools)
# model = ModelFactory().get("qwen2.5:7b").bind_tools(tools)
# model = ModelFactory().get("gpt-3.5-turbo").bind_tools(tools)
model = ModelFactory().get("gpt-4o-mini").bind_tools(tools)
# model = ModelFactory().get("gpt-4o-mini").bind_tools(tools)


def assistant(state):
Expand All @@ -29,13 +29,16 @@ def assistant(state):
Think through the solution step-by-step first and then execute.
A context layer can be used to summarize vegetation disturbances by things like landcover or tree height categories.
If such a context layer analysis is is requested, obtain the context layer using the `context-layer-tool`.
If such a context layer analysis is is requested, obtain the context layer using the `context-layer-tool`
Use the `location-tool` to get polygons of any region or place by name. There are two levels, 1 is for
state/province/regional analysis, and 2 is for smaller areas like municiaplities and counties. Use 2 level by default,
and level 1 if someone asks for state/province/regional analysis.
Use the `dist-alerts-tool` to get vegetation disturbance information, pass the context layer and the location as input.
If the user asks for summarizing disturbance alerts by underlying cause or driver, do not use the `context-layer-tool`
and pass `distalert-drivers` in the `landcover` argument like so `landcover="distalert-drivers"`. This will summarize
the disturbance alerts by driver.
"""
)

Expand Down Expand Up @@ -69,7 +72,7 @@ def human_review_location(state):
if action == "continue":
return Command(goto="assistant")
elif action == "update":
last_msg.content = json.dumps(options[option])
last_msg.content = json.dumps([options[option]])
last_msg.artifact = artifact
return Command(goto="assistant")
elif action == "feedback":
Expand Down
203 changes: 126 additions & 77 deletions zeno/tools/distalert/dist_alerts_tool.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import datetime
from typing import List, Literal, Optional, Union
from typing import List, Literal, Optional, Tuple, Union

import ee
import fiona
Expand All @@ -9,6 +9,7 @@
from pydantic import BaseModel, Field

from zeno.tools.contextlayer.layers import layer_choices
from zeno.tools.distalert.drivers import DRIVER_VALUEMAP, get_drivers
from zeno.tools.distalert.gee import init_gee

# Load environment variables
Expand All @@ -21,6 +22,7 @@

DIST_ALERT_REF_DATE = datetime.date(2020, 12, 31)
DIST_ALERT_SCALE = 30
M2_TO_HA = 10000


class DistAlertsInput(BaseModel):
Expand All @@ -44,6 +46,7 @@ class DistAlertsInput(BaseModel):
description="Cutoff date for alerts. Alerts after that date will be excluded.",
)


def print_meta(
layer: Union[ee.image.Image, ee.imagecollection.ImageCollection]
) -> None:
Expand Down Expand Up @@ -73,34 +76,7 @@ def get_class_table(
return {val: pair for val, pair in zip(values, pairs)}


@tool(
"dist-alerts-tool",
args_schema=DistAlertsInput,
return_direct=True,
response_format="content_and_artifact",
)
def dist_alerts_tool(
features: List[str],
landcover: Optional[str] = None,
threshold: Optional[Literal[1, 2, 3, 4, 5, 6, 7, 8]] = 5,
min_date: Optional[datetime.date] = None,
max_date: Optional[datetime.date] = None,
) -> dict:
"""
Dist alerts tool
This tool quantifies vegetation disturbance alerts over an area of interest
and summarizes the alerts in statistics by landcover types.
"""
print("---DIST ALERTS TOOL---")
distalerts = ee.ImageCollection(
"projects/glad/HLSDIST/current/VEG-DIST-STATUS"
).mosaic()

gee_features = ee.FeatureCollection(
[ee.Feature(gadm[int(id)].__geo_interface__) for id in features]
)

def get_date_mask(min_date: datetime.date, max_date: datetime.date) -> ee.image.Image:
today = datetime.date.today()
date_mask = None
if min_date and min_date > DIST_ALERT_REF_DATE and min_date < today:
Expand Down Expand Up @@ -129,8 +105,19 @@ def dist_alerts_tool(
else:
date_mask = date_mask_max

if landcover:
choice = [dat for dat in layer_choices if dat["dataset"] == landcover][0]
return date_mask


def get_alerts_by_landcover(
distalerts: ee.Image,
landcover: ee.Image,
gee_features: ee.FeatureCollection,
date_mask: ee.Image,
threshold: int,
) -> Tuple[dict, ee.Image]:
choice = [dat for dat in layer_choices if dat["dataset"] == landcover]
if choice:
choice = choice[0]
if choice["type"] == "ImageCollection":
landcover_layer = ee.ImageCollection(landcover)
else:
Expand All @@ -140,57 +127,119 @@ def dist_alerts_tool(
class_table = choice["class_table"]
else:
class_table = get_class_table(choice["band"], landcover_layer)
else:
# TODO: replace this with a better selection. For now
# assumes if the choice did not exist that the drivers are requested.
landcover_layer = get_drivers()
class_table = {val: key for key, val in DRIVER_VALUEMAP.items()}

if choice["type"] == "ImageCollection":
landcover_layer = landcover_layer.mosaic()
if choice["type"] == "ImageCollection":
landcover_layer = landcover_layer.mosaic()

landcover_layer = landcover_layer.select(choice["band"])
landcover_layer = landcover_layer.select(choice["band"])

zone_stats_img = (
distalerts.pixelArea()
.divide(10000)
.addBands(landcover_layer)
.updateMask(distalerts.gte(threshold))
)
if date_mask:
zone_stats_img = zone_stats_img.updateMask(
zone_stats_img.selfMask().And(date_mask)
)

zone_stats = zone_stats_img.reduceRegions(
collection=gee_features,
reducer=ee.Reducer.sum().group(groupField=1, groupName=choice["band"]),
scale=choice["resolution"],
).getInfo()

zone_stats_result = {}
for feat in zone_stats["features"]:
zone_stats_result[feat["properties"]["gadmid"]] = {
class_table[dat[choice["band"]]]["name"]: dat["sum"]
for dat in feat["properties"]["groups"]
}
vectorize = landcover_layer.updateMask(distalerts.gte(threshold))
else:
zone_stats_img = (
distalerts.pixelArea().divide(10000).updateMask(distalerts.gte(threshold))
zone_stats_img = (
distalerts.pixelArea()
.divide(M2_TO_HA)
.addBands(landcover_layer)
.updateMask(distalerts.gte(threshold))
)
if date_mask:
zone_stats_img = zone_stats_img.updateMask(
zone_stats_img.selfMask().And(date_mask)
)
if date_mask:
zone_stats_img = zone_stats_img.updateMask(
zone_stats_img.selfMask().And(date_mask)
)

zone_stats = zone_stats_img.reduceRegions(
collection=gee_features,
reducer=ee.Reducer.sum(),
scale=DIST_ALERT_SCALE,
).getInfo()

zone_stats_result = {
feat["properties"]["gadmid"]: {"disturbances": feat["properties"]["sum"]}
for feat in zone_stats["features"]

zone_stats = zone_stats_img.reduceRegions(
collection=gee_features,
reducer=ee.Reducer.sum().group(groupField=1, groupName=choice["band"]),
scale=choice["resolution"],
).getInfo()

zone_stats_result = {}
for feat in zone_stats["features"]:
zone_stats_result[feat["properties"]["gadmid"]] = {
class_table[dat[choice["band"]]]["name"]: dat["sum"]
for dat in feat["properties"]["groups"]
}
vectorize = (
distalerts.gte(threshold).updateMask(distalerts.gte(threshold)).selfMask()
vectorize = landcover_layer.updateMask(distalerts.gte(threshold))

return zone_stats_result, vectorize


def get_distalerts_unfiltered(
    distalerts: ee.Image,
    gee_features: ee.FeatureCollection,
    date_mask: Optional[ee.Image],
    threshold: int,
) -> Tuple[dict, ee.Image]:
    """Sum disturbance alert area per feature, without a landcover breakdown.

    Args:
        distalerts: Disturbance alert image; pixels with a value greater than
            or equal to ``threshold`` are counted as alerts.
        gee_features: Features to summarize over. Each feature's ``gadmid``
            property is used as the key in the returned statistics.
        date_mask: Optional image used to further restrict the alert pixels.
            ``None`` disables date filtering.
        threshold: Inclusive lower bound applied to ``distalerts``.

    Returns:
        A tuple ``(zone_stats_result, vectorize)``. ``zone_stats_result`` maps
        each feature's ``gadmid`` to ``{"disturbances": <summed area>}``, and
        ``vectorize`` is the self-masked image of pixels at or above the
        threshold.
    """
    # Build the threshold mask once and reuse it for stats and vectorization.
    alert_mask = distalerts.gte(threshold)

    # Pixel area (square meters per the Earth Engine docs) divided by
    # M2_TO_HA, kept only where an alert passes the threshold.
    zone_stats_img = distalerts.pixelArea().divide(M2_TO_HA).updateMask(alert_mask)

    # Test against None explicitly: an ee.Image should never be evaluated for
    # truthiness, and "no date filter" is signalled by None.
    if date_mask is not None:
        zone_stats_img = zone_stats_img.updateMask(
            zone_stats_img.selfMask().And(date_mask)
        )

    # getInfo() makes a blocking request and returns a plain Python dict.
    zone_stats = zone_stats_img.reduceRegions(
        collection=gee_features,
        reducer=ee.Reducer.sum(),
        scale=DIST_ALERT_SCALE,
    ).getInfo()

    zone_stats_result = {
        feat["properties"]["gadmid"]: {"disturbances": feat["properties"]["sum"]}
        for feat in zone_stats["features"]
    }

    # NOTE(review): the date mask is not applied to the vectorization layer,
    # only to the statistics - confirm this is intended.
    vectorize = alert_mask.updateMask(alert_mask).selfMask()
    return zone_stats_result, vectorize


@tool(
"dist-alerts-tool",
args_schema=DistAlertsInput,
return_direct=True,
response_format="content_and_artifact",
)
def dist_alerts_tool(
features: List[str],
landcover: Optional[str] = None,
threshold: Optional[Literal[1, 2, 3, 4, 5, 6, 7, 8]] = 5,
min_date: Optional[datetime.date] = None,
max_date: Optional[datetime.date] = None,
) -> dict:
"""
Dist alerts tool
This tool quantifies vegetation disturbance alerts over an area of interest
and summarizes the alerts in statistics by landcover types.
"""
print("---DIST ALERTS TOOL---")
distalerts = ee.ImageCollection(
"projects/glad/HLSDIST/current/VEG-DIST-STATUS"
).mosaic()

gee_features = ee.FeatureCollection(
[ee.Feature(gadm[int(id)].__geo_interface__) for id in features]
)

date_mask = get_date_mask(min_date, max_date)

if landcover:
zone_stats_result, vectorize = get_alerts_by_landcover(
distalerts=distalerts,
landcover=landcover,
gee_features=gee_features,
date_mask=date_mask,
threshold=threshold,
)
else:
zone_stats_result, vectorize = get_distalerts_unfiltered(
distalerts=distalerts,
gee_features=gee_features,
date_mask=date_mask,
threshold=threshold,
)

# Vectorize the masked classification
Expand Down
Loading

0 comments on commit e1564e8

Please sign in to comment.