From d8e1e05d4c0e0d267c05b08a34499e31cef7aa2e Mon Sep 17 00:00:00 2001 From: Daniel Wiesmann Date: Fri, 17 Jan 2025 17:31:07 +0000 Subject: [PATCH] Buffer feature with test --- tests/test_dist_alerts.py | 29 +++ zeno/agents/distalert/tools.py | 31 ++- zeno/tools/distalert/dist_alerts_tool.py | 301 ----------------------- zeno/tools/distalert/drivers.py | 166 ------------- 4 files changed, 58 insertions(+), 469 deletions(-) delete mode 100644 zeno/tools/distalert/dist_alerts_tool.py delete mode 100644 zeno/tools/distalert/drivers.py diff --git a/tests/test_dist_alerts.py b/tests/test_dist_alerts.py index 955d1e2..6717657 100644 --- a/tests/test_dist_alerts.py +++ b/tests/test_dist_alerts.py @@ -114,3 +114,32 @@ def mockfunction(gadm_id, gadm_level): ) # Context layer type is as expected assert list(result.keys()) == [expected_natural_lands] + + +def test_dist_alert_tool_buffer(): + result = tools.dist_alerts_tool.invoke( + input={ + "name": "BRA.13.369_2", + "gadm_id": "BRA.13.369_2", + "gadm_level": 2, + "context_layer_name": "WRI/SBTN/naturalLands/v1/2020", + "threshold": 8, + "min_date": datetime.date(2021, 8, 12), + "max_date": datetime.date(2024, 8, 12), + } + ) + result_buffered = tools.dist_alerts_tool.invoke( + input={ + "name": "BRA.13.369_2", + "gadm_id": "BRA.13.369_2", + "gadm_level": 2, + "context_layer_name": "WRI/SBTN/naturalLands/v1/2020", + "threshold": 8, + "min_date": datetime.date(2021, 8, 12), + "max_date": datetime.date(2024, 8, 12), + "buffer_distance": 1000, + } + ) + assert ( + result["natural short vegetation"] < result_buffered["natural short vegetation"] + ) diff --git a/zeno/agents/distalert/tools.py b/zeno/agents/distalert/tools.py index 1c54201..948c5a0 100644 --- a/zeno/agents/distalert/tools.py +++ b/zeno/agents/distalert/tools.py @@ -9,6 +9,7 @@ from dotenv import load_dotenv from langchain_core.tools import tool from pydantic import BaseModel, Field +from pyproj import CRS from zeno.agents.contextfinder.tools import table as contextfinder_table from zeno.agents.distalert.drivers import DRIVER_VALUEMAP, get_drivers @@ -43,6 +44,9 @@ class DistAlertsInput(BaseModel): threshold: Optional[Literal[1, 2, 3, 4, 5, 6, 7, 8]] = Field( default=5, description="Threshold for disturbance alert scale" ) + buffer_distance: Optional[float] = Field( + default=None, description="Buffer distance in meters for buffering the features." + ) def get_date_mask(min_date: datetime.date, max_date: datetime.date) -> ee.image.Image: @@ -179,12 +183,32 @@ def get_distalerts_unfiltered( return zone_stats_result, vectorize -def get_features(gadm_id: str, gadm_level: int) -> ee.FeatureCollection: +def detect_utm_zone(lat, lon): + """ + Detect the UTM zone for a given latitude and longitude in WGS84. + """ + zone_number = int((lon + 180) // 6) + 1 + hemisphere = "north" if lat >= 0 else "south" + utm_crs = CRS.from_dict( + {"proj": "utm", "zone": zone_number, "south": hemisphere == "south"} + ) + return utm_crs + + +def get_features( + gadm_id: str, gadm_level: int, buffer_distance: float +) -> ee.FeatureCollection: aoi_df = gpd.read_file( Path("data") / f"gadm_410_level_{gadm_level}.gpkg", where=f"GID_{gadm_level} like '{gadm_id}'", ) aoi = aoi_df.geometry.iloc[0] + + if buffer_distance: + utm = detect_utm_zone(aoi.centroid.y, aoi.centroid.x) + aoi_df_utm = aoi_df.to_crs(utm) + aoi = aoi_df_utm.buffer(buffer_distance).to_crs(aoi_df.crs).iloc[0] + return ee.FeatureCollection([ee.Feature(aoi.__geo_interface__)]) @@ -202,6 +226,7 @@ def dist_alerts_tool( max_date: datetime.date, context_layer_name: Optional[str] = None, threshold: Optional[Literal[1, 2, 3, 4, 5, 6, 7, 8]] = 5, + buffer_distance: Optional[float] = None, ) -> dict: """ Dist alerts tool @@ -211,7 +236,9 @@ def dist_alerts_tool( """ print("---DIST ALERTS TOOL---") - gee_features = get_features(gadm_id=gadm_id, gadm_level=gadm_level) + gee_features = get_features( + gadm_id=gadm_id, gadm_level=gadm_level, buffer_distance=buffer_distance + ) distalerts = ee.ImageCollection(GEE_FOLDER + "VEG-DIST-STATUS").mosaic() date_mask = get_date_mask(min_date, max_date) diff --git a/zeno/tools/distalert/dist_alerts_tool.py b/zeno/tools/distalert/dist_alerts_tool.py deleted file mode 100644 index 7baf741..0000000 --- a/zeno/tools/distalert/dist_alerts_tool.py +++ /dev/null @@ -1,301 +0,0 @@ -import datetime -from typing import List, Literal, Optional, Tuple, Union - -import ee -import fiona -import googleapiclient -from dotenv import load_dotenv -from langchain_core.tools import tool -from pydantic import BaseModel, Field - -from zeno.tools.contextlayer.layers import layer_choices -from zeno.tools.distalert.drivers import DRIVER_VALUEMAP, get_drivers -from zeno.tools.distalert.gee import init_gee - -# Load environment variables -load_dotenv(".env") - -# Initialize gee -init_gee() - -gadm_1 = fiona.open("data/gadm_410_level_1.gpkg") -gadm_2 = fiona.open("data/gadm_410_level_2.gpkg") - -DIST_ALERT_REF_DATE = datetime.date(2020, 12, 31) -DIST_ALERT_SCALE = 30 -M2_TO_HA = 10000 -GEE_FOLDER = "projects/glad/HLSDIST/backend/" - - -class DistAlertsInput(BaseModel): - """Input schema for dist tool""" - - features: List[str] = Field( - description="List of GADM ids are used for zonal statistics" - ) - landcover: Optional[str] = Field( - default=None, - description="Landcover layer name to group zonal statistics by", - ) - threshold: Optional[Literal[1, 2, 3, 4, 5, 6, 7, 8]] = Field( - default=5, description="Threshold for disturbance alert scale" - ) - min_date: Optional[datetime.date] = Field( - default=None, - description="Cutoff date for alerts. Alerts before that date will be excluded.", - ) - max_date: Optional[datetime.date] = Field( - default=None, - description="Cutoff date for alerts. Alerts after that date will be excluded.", - ) - - -def print_meta( - layer: Union[ee.image.Image, ee.imagecollection.ImageCollection], -) -> None: - """Print layer metadata""" - # Get all metadata as a dictionary - metadata = layer.getInfo() - - # Print metadata - print("Image Metadata:") - for key, value in metadata.items(): - print(f"{key}: {value}") - - -def get_class_table( - band_name: str, - layer: Union[ee.image.Image, ee.imagecollection.ImageCollection], -) -> dict: - band_info = layer.select(band_name).getInfo() - - names = band_info["features"][0]["properties"][f"{band_name}_class_names"] - values = band_info["features"][0]["properties"][ - f"{band_name}_class_values" - ] - colors = band_info["features"][0]["properties"][ - f"{band_name}_class_palette" - ] - - pairs = [] - for name, color in zip(names, colors): - pairs.append({"name": name, "color": color}) - - return {val: pair for val, pair in zip(values, pairs)} - - -def get_date_mask( - min_date: datetime.date, max_date: datetime.date -) -> ee.image.Image: - today = datetime.date.today() - date_mask = None - if min_date and min_date > DIST_ALERT_REF_DATE and min_date < today: - days_passed = (today - min_date).days - days_since_start = (today - DIST_ALERT_REF_DATE).days - cutoff = days_since_start - days_passed - date_mask = ( - ee.ImageCollection(GEE_FOLDER + "VEG-DIST-DATE") - .mosaic() - .gte(cutoff) - .selfMask() - ) - - if max_date and max_date > DIST_ALERT_REF_DATE and max_date < today: - days_passed = (today - max_date).days - days_since_start = (today - DIST_ALERT_REF_DATE).days - cutoff = days_since_start - days_passed - date_mask_max = ( - ee.ImageCollection(GEE_FOLDER + "VEG-DIST-DATE") - .mosaic() - .lte(cutoff) - .selfMask() - ) - if date_mask: - date_mask = date_mask.And(date_mask_max) - else: - date_mask = date_mask_max - - return date_mask - - -def get_alerts_by_landcover( - distalerts: ee.Image, - landcover: str, - gee_features: ee.FeatureCollection, - date_mask: ee.Image, - threshold: int, -) -> Tuple[dict, ee.Image]: - lc_choice = [dat for dat in layer_choices if dat["dataset"] == landcover] - choice = {} - if lc_choice: - choice = lc_choice[0] - if choice["type"] == "ImageCollection": - landcover_layer = ee.ImageCollection(landcover) - else: - landcover_layer = ee.Image(landcover) - - if "class_table" in choice: - class_table = choice["class_table"] - else: - class_table = get_class_table(choice["band"], landcover_layer) - - if choice["type"] == "ImageCollection": - landcover_layer = landcover_layer.mosaic() - - landcover_layer = landcover_layer.select(choice["band"]) - else: - # TODO: replace this with a better selection. For now - # assumes if the choice did not exist that the drivers are requested. - landcover_layer = get_drivers() - class_table = { - val: {"name": key} for key, val in DRIVER_VALUEMAP.items() - } - - zone_stats_img = ( - distalerts.pixelArea() - .divide(M2_TO_HA) - .addBands(landcover_layer) - .updateMask(distalerts.gte(threshold)) - ) - if date_mask: - zone_stats_img = zone_stats_img.updateMask( - zone_stats_img.selfMask().And(date_mask) - ) - - zone_stats = zone_stats_img.reduceRegions( - collection=gee_features, - reducer=ee.Reducer.sum().group( - groupField=1, groupName=choice.get("band", "name") - ), - scale=choice.get("resolution", 30), - ).getInfo() - - zone_stats_result = {} - for feat in zone_stats["features"]: - if "GID_2" in feat["properties"]: - gadmid = feat["properties"]["GID_2"] - else: - gadmid = feat["properties"]["GID_1"] - - zone_stats_result[gadmid] = { - class_table[dat[choice.get("band", "name")]]["name"]: dat["sum"] - for dat in feat["properties"]["groups"] - } - vectorize = landcover_layer.updateMask(distalerts.gte(threshold)) - - return zone_stats_result, vectorize - - -def get_distalerts_unfiltered( - distalerts: ee.Image, - gee_features: ee.FeatureCollection, - date_mask: ee.Image, - threshold: int, -) -> Tuple[dict, ee.Image]: - zone_stats_img = ( - distalerts.pixelArea() - .divide(M2_TO_HA) - .updateMask(distalerts.gte(threshold)) - ) - if date_mask: - zone_stats_img = zone_stats_img.updateMask( - zone_stats_img.selfMask().And(date_mask) - ) - - zone_stats = zone_stats_img.reduceRegions( - collection=gee_features, - reducer=ee.Reducer.sum(), - scale=DIST_ALERT_SCALE, - ).getInfo() - - zone_stats_result = {} - for feat in zone_stats["features"]: - if "GID_2" in feat["properties"]: - gadmid = feat["properties"]["GID_2"] - else: - gadmid = feat["properties"]["GID_1"] - zone_stats_result[gadmid] = {"disturbances": feat["properties"]["sum"]} - - vectorize = ( - distalerts.gte(threshold) - .updateMask(distalerts.gte(threshold)) - .selfMask() - ) - return zone_stats_result, vectorize - - -def get_features(features: List[str]) -> ee.FeatureCollection: - if features[0].count(".") == 2: - gadm = gadm_2 - gadm_level = 2 - else: - gadm = gadm_1 - gadm_level = 1 - - matches = [ - next(gadm.filter(where=f"GID_{gadm_level} = '{id}'")) - for id in features - ] - - return ee.FeatureCollection( - [ee.Feature(dat.__geo_interface__) for dat in matches] - ) - - -@tool( - "dist-alerts-tool", - args_schema=DistAlertsInput, - return_direct=True, - response_format="content_and_artifact", -) -def dist_alerts_tool( - features: List[str], - landcover: Optional[str] = None, - threshold: Optional[Literal[1, 2, 3, 4, 5, 6, 7, 8]] = 5, - min_date: Optional[datetime.date] = None, - max_date: Optional[datetime.date] = None, -) -> dict: - """ - Dist alerts tool - - This tool quantifies vegetation disturbance alerts over an area of interest - and summarizes the alerts in statistics by landcover types. - """ - print("---DIST ALERTS TOOL---") - distalerts = ee.ImageCollection(GEE_FOLDER + "VEG-DIST-STATUS").mosaic() - - gee_features = get_features(features) - - date_mask = get_date_mask(min_date, max_date) - - if landcover: - zone_stats_result, vectorize = get_alerts_by_landcover( - distalerts=distalerts, - landcover=landcover, - gee_features=gee_features, - date_mask=date_mask, - threshold=threshold, - ) - else: - zone_stats_result, vectorize = get_distalerts_unfiltered( - distalerts=distalerts, - gee_features=gee_features, - date_mask=date_mask, - threshold=threshold, - ) - - # Vectorize the masked classification - vectors = vectorize.reduceToVectors( - geometryType="polygon", - scale=DIST_ALERT_SCALE, - maxPixels=1e8, - geometry=gee_features, - eightConnected=True, - ) - - try: - vectorized = vectors.getInfo() - except googleapiclient.errors.HttpError: - vectorized = {} - - return zone_stats_result, vectorized diff --git a/zeno/tools/distalert/drivers.py b/zeno/tools/distalert/drivers.py deleted file mode 100644 index 0a7be8e..0000000 --- a/zeno/tools/distalert/drivers.py +++ /dev/null @@ -1,166 +0,0 @@ -import ee - -# Example from https://code.earthengine.google.com/06f83955d12f278414f8c53655abdd6d - -DRIVER_VALUEMAP = { - "wildfire": 1, - "crop_cycle": 2, - "flooding": 3, - "conversion": 4, - "other_conversion": 5, -} - - -def get_drivers(): - folder = "projects/glad/HLSDIST/backend" - natural_lands = ee.Image("WRI/SBTN/naturalLands/v1/2020").select("natural") - vegdistcount = ee.ImageCollection(folder + "/VEG-DIST-COUNT").mosaic() - veganommax = ee.ImageCollection(folder + "/VEG-ANOM-MAX").mosaic() - confmask = vegdistcount.gte(2).And(veganommax.gt(50)) - - wf_collection = ee.ImageCollection.fromImages( - [ - ee.Image( - "projects/wri-dist-alert-drivers/assets/wildfire/dist-alert-wildfire-africa-nov2023-oct2024_v01" - ), - ee.Image( - "projects/wri-dist-alert-drivers/assets/wildfire/dist-alert-wildfire-europe-nov2023-oct2024_v01" - ), - ee.Image( - "projects/wri-dist-alert-drivers/assets/wildfire/dist-alert-wildfire-latam-nov2023-oct2024_v01" - ), - ee.Image( - "projects/wri-dist-alert-drivers/assets/wildfire/dist-alert-wildfire-ne-asia-nov2023-oct2024_v01" - ), - ee.Image( - "projects/wri-dist-alert-drivers/assets/wildfire/dist-alert-wildfire-north-am-nov2023-oct2024_v01" - ), - ee.Image( - "projects/wri-dist-alert-drivers/assets/wildfire/dist-alert-wildfire-se-asia-oceania-nov2023-oct2024_v01" - ), - ] - ) - - wildfire = ( - wf_collection.mosaic() - .neq(0) - .updateMask(confmask) - .updateMask(natural_lands) - ) - - cc_collection = ee.ImageCollection.fromImages( - [ - ee.Image( - "projects/ee-jamesmaccarthy-wri/assets/dist-alert-crop-cycle-africa-nov2023-oct2024-tiles-all" - ), - ee.Image( - "projects/ee-jamesmaccarthy-wri/assets/dist-alert-crop-cycle-europe-nov2023-oct2024-tiles-all" - ), - ee.Image( - "projects/ee-jamesmaccarthy-wri/assets/dist-alert-crop-cycle-latam-nov2023-oct2024-tiles-all" - ), - ee.Image( - "projects/ee-jamesmaccarthy-wri/assets/dist-alert-crop-cycle-northam-nov2023-oct2024-tiles-all" - ), - ee.Image( - "projects/ee-jamesmaccarthy-wri/assets/dist-alert-crop-cycle-seasia_oceania-nov2023-oct2024-tiles-all" - ), - ee.Image( - "projects/ee-jamesmaccarthy-wri/assets/dist-alert-crop-cycle-neasia-nov2023-oct2024-tiles-all" - ), - ] - ) - - crop_cycle = cc_collection.mosaic().updateMask(confmask) - - fl_collection = ee.ImageCollection.fromImages( - [ - ee.Image( - "projects/ee-jamesmaccarthy-wri/assets/dist-alert-flooding-africa-nov2023-oct2024-tiles-all" - ), - ee.Image( - "projects/ee-jamesmaccarthy-wri/assets/dist-alert-flooding-europe-nov2023-oct2024-tiles-all" - ), - ee.Image( - "projects/ee-jamesmaccarthy-wri/assets/dist-alert-flooding-latam-nov2023-oct2024-tiles-all" - ), - ee.Image( - "projects/ee-jamesmaccarthy-wri/assets/dist-alert-flooding-northam-nov2023-oct2024-tiles-all" - ), - ee.Image( - "projects/ee-jamesmaccarthy-wri/assets/dist-alert-flooding-seasia-oceania-nov2023-oct2024-tiles-all" - ), - ee.Image( - "projects/ee-jamesmaccarthy-wri/assets/dist-alert-flooding-neasia-nov2023-oct2024-tiles-all" - ), - ] - ) - - flooding = ( - fl_collection.mosaic() - .updateMask(confmask) - .updateMask(crop_cycle.unmask(2).neq(1)) - .updateMask(wildfire.eq(0).unmask(1)) - ) - - cv_collection = ee.ImageCollection.fromImages( - [ - ee.Image( - "projects/wri-dist-alert-drivers/assets/conversion/dist-alert-conversion-africa-nov2023-oct2024_v01" - ), - ee.Image( - "projects/wri-dist-alert-drivers/assets/conversion/dist-alert-conversion-europe-nov2023-oct2024_v01" - ), - ee.Image( - "projects/wri-dist-alert-drivers/assets/conversion/dist-alert-conversion-latam-nov2023-oct2024_v01" - ), - ee.Image( - "projects/wri-dist-alert-drivers/assets/conversion/dist-alert-conversion-ne-asia-nov2023-oct2024_v01" - ), - ee.Image( - "projects/wri-dist-alert-drivers/assets/conversion/dist-alert-conversion-north-am-nov2023-oct2024_v01" - ), - ee.Image( - "projects/wri-dist-alert-drivers/assets/conversion/dist-alert-conversion-se-asia-oceania-nov2023-oct2024_v01" - ), - ] - ) - - conversion = ( - cv_collection.mosaic() - .updateMask(confmask) - .updateMask(natural_lands) - .updateMask(crop_cycle.unmask(2).neq(1)) - .updateMask(wildfire.unmask(2).neq(1)) - ) - - other_conversion = ( - cv_collection.mosaic() - .updateMask(confmask) - .updateMask(natural_lands.eq(0)) - .updateMask(crop_cycle.unmask(2).neq(1)) - .updateMask(wildfire.unmask(2).neq(1)) - ) - - combo = ( - wildfire.multiply(DRIVER_VALUEMAP["wildfire"]) - .unmask() - .add(crop_cycle.multiply(DRIVER_VALUEMAP["crop_cycle"]).unmask()) - .add(flooding.multiply(DRIVER_VALUEMAP["flooding"]).unmask()) - .add(conversion.multiply(DRIVER_VALUEMAP["conversion"]).unmask()) - .add( - other_conversion.multiply( - DRIVER_VALUEMAP["other_conversion"] - ).unmask() - ) - ) - combo_mask = ( - wildfire.mask() - .Or(crop_cycle.mask()) - .Or(flooding.mask()) - .Or(conversion.mask()) - .Or(other_conversion.mask()) - ) - combo = combo.updateMask(combo_mask) - - return combo