Merge pull request #32 from wri/streamlit-simple
Fix streamlit app
yellowcap authored Nov 29, 2024
2 parents 109788c + fa1526c, commit a460761
Showing 5 changed files with 31 additions and 141 deletions.
api.py (13 changes: 8 additions & 5 deletions)
@@ -32,14 +32,17 @@ def event_stream(query: str):
     ):
         print(f"Namespace {namespace}")
         for key, val in data.items():
-            print(f"Messager is {key}")
+            print(f"Messenger is {key}")
             if key == "agent":
                 continue
             for key2, val2 in val.items():
                 if key2 == "messages":
-                    for msg in val.get("messages", []):
-                        yield pack({"message": msg.content})
-                        if hasattr(msg, "tool_calls"):
+                    for msg in val2:
+                        if msg.content:
+                            yield pack({"message": msg.content})
+                        if hasattr(msg, "tool_calls") and msg.tool_calls:
                             yield pack({"tool_calls": msg.tool_calls})
-                        if hasattr(msg, "artifact"):
+                        if hasattr(msg, "artifact") and msg.artifact:
                             yield pack({"artifact": msg.artifact})


frontend/app.py (156 changes: 22 additions & 134 deletions)
@@ -1,39 +1,28 @@
 import json
+import os
 
 import folium
 import requests
-import os
-
 import streamlit as st
-from streamlit_folium import st_folium
 from dotenv import load_dotenv
+from streamlit_folium import folium_static
 
 _ = load_dotenv()
+API_BASE_URL = os.environ.get("API_BASE_URL")
 
-API_BASE_URL = os.environ["API_BASE_URL"]
-
-
-st.set_page_config(page_icon="images/resource-racoon.jpg", layout="wide")
-st.header("Resource Raccoon")
+st.set_page_config(page_icon="images/zeno.jpg", layout="wide")
+st.header("Zeno")
 st.caption("Your intelligent EcoBot, saving the forest faster than a 🐼 eats bamboo")
 
 # Sidebar content
 with st.sidebar:
-    st.image("images/resource-racoon.jpg")
-    st.header("Meet Resource Raccoon!")
+    st.image("images/zeno.jpg")
+    st.header("Meet Zeno!")
     st.write(
         """
-        **Resource Raccoon** is your AI sidekick at WRI, trained on all your blog posts! It is a concious consumer and is consuming a local produce only. It can help you with questions about your blog posts. Give it a try!
+        **Zeno** is your AI sidekick, trained on all your blog posts! It is a concious consumer and is consuming a local produce only. It can help you with questions about your blog posts. Give it a try!
         """
     )
 
-    # st.subheader("Select a model:")
-    # available_models = requests.get(f"{API_BASE_URL}/models").json()["models"]
-
-    # model = st.selectbox(
-    #     "Model", format_func=lambda x: x["model_name"], options=available_models
-    # )
-
     st.subheader("🧐 Try asking:")
     st.write(
         """
@@ -50,10 +39,7 @@
         """
     )
 
-# Note: the following section is commented to preseve the work that @DanielW has
-# done to enable the streaming response of the chat messages.
 
-# =========== BEGIN STREAMING RESPONSE ===============
 if user_input := st.chat_input("Type your message here..."):
     st.chat_message("user").write(user_input)
     with requests.post(
@@ -63,115 +49,17 @@
     ) as stream:
         for chunk in stream.iter_lines():
             data = json.loads(chunk.decode("utf-8"))
-            st.write(data)
-# =========== /END STREAMING RESPONSE ===============
-
-# # Initialize session state for messages and selected dataset
-# if "messages" not in st.session_state:
-#     st.session_state["messages"] = []
-# if "selected_dataset" not in st.session_state:
-#     st.session_state["selected_dataset"] = None
-# if "route" not in st.session_state:
-#     st.session_state["route"] = None
-
-# col1, col2 = st.columns([4, 6])
-
-
-# def display_in_streamlit(base64_string):
-#     image_html = f'<img src="data:image/png;base64,{base64_string}">'
-#     st.markdown(image_html, unsafe_allow_html=True)
-
-
-# # Left column (40% width) - Chat Interface
-# with col1:
-#     # User input and API call - only happens on new input
-#     user_input = st.text_input("You:", key="user_input")
-#     if user_input and user_input not in [
-#         msg.get("user", "") for msg in st.session_state["messages"]
-#     ]:
-#         response = requests.post(
-#             f"{API_BASE_URL}/query",
-#             json={"query": user_input, "model_id": model["model_id"]},
-#         )
-#         data = response.json()
-#         st.session_state["route"] = data["route"]
-#         print(data)
-#         # datasets = json.loads(data["messages"][0]["content"])
-
-#         try:
-#             st.session_state["messages"] = []
-#             st.session_state["messages"].append({"user": user_input})
-#             st.session_state["messages"].append({"bot": data})
-#         except Exception as e:
-#             st.error(f"Error processing response: {str(e)}")
-
-#     # Display conversation and dataset buttons
-#     for msg_idx, message in enumerate(st.session_state["messages"]):
-#         if "user" in message:
-#             st.write(f"**You**: {message['user']}")
-#         else:
-#             st.write("**Assistant**:")
-#             data = message["bot"]
-#             try:
-#                 match st.session_state["route"]:
-#                     case "layerfinder":
-#                         datasets = json.loads(data["messages"][0]["content"])
-#                         for idx, dataset in enumerate(datasets):
-#                             st.write(f"**Dataset {idx+1}:** {dataset['explanation']}")
-#                             st.write(f"**URL**: {dataset['uri']}")
-
-#                             # Generate a unique key for each button that includes both message and dataset index
-#                             button_key = f"dataset_{msg_idx}_{idx}"
-#                             if st.button(f"Show Dataset {idx+1}", key=button_key):
-#                                 st.session_state["selected_dataset"] = dataset[
-#                                     "tilelayer"
-#                                 ]
-#                                 print(f"changed state to: {dataset['tilelayer']}")
-#                     case "firealert":
-#                         for msg in data["messages"]:
-#                             if (
-#                                 msg["name"] != "barchart-tool"
-#                             ):  # Only print non-chart messages
-#                                 st.write(msg["content"])
-#                     case "docfinder":
-#                         for msg in data["messages"]:
-#                             st.write(msg["content"])
-#                         # st.write(data["messages"][0]["content"])
-#                     case _:
-#                         st.write("Unable to find an agent for task")
-#             except Exception as e:
-#                 st.error(f"Error processing response: {str(e)}")
-
-# # Right column (60% width) - Map Visualization
-# with col2:
-#     if st.session_state["route"] == "layerfinder":
-#         st.header("Map Visualization")
-#         m = folium.Map(location=[0, 0], zoom_start=2)
-
-#         if st.session_state["selected_dataset"]:
-#             print("yes")
-#             folium.TileLayer(
-#                 tiles=st.session_state["selected_dataset"],
-#                 attr="Global Forest Watch",
-#                 name="Selected Dataset",
-#                 overlay=True,
-#                 control=True,
-#             ).add_to(m)
-
-#         folium.LayerControl().add_to(m)
-#         st_folium(m, width=700, height=500)
-#     elif st.session_state["route"] == "firealert":
-#         st.header("Fire Alert Statistics")
-#         # Display barchart from the most recent message
-#         if st.session_state["messages"]:
-#             for message in reversed(st.session_state["messages"]):
-#                 if "bot" in message:
-#                     data = message["bot"]
-#                     for msg in data["messages"]:
-#                         if msg["name"] == "barchart-tool":
-#                             display_in_streamlit(msg["content"])
-#                             break
-#                     break
-#     else:
-#         st.header("Visualization")
-#         st.write("Select a dataset or query to view visualization")
+            if data.get("artifact", {}).get("type") == "FeatureCollection":
+                geom = data.get("artifact")["features"][0]["geometry"]
+                if geom["type"] == "Polygon":
+                    pnt = geom["coordinates"][0][0]
+                else:
+                    pnt = geom["coordinates"][0][0][0]
+
+                m = folium.Map(location=[pnt[1], pnt[0]], zoom_start=11)
+                g = folium.GeoJson(
+                    data.get("artifact"),
+                ).add_to(m)
+                folium_static(m, width=700, height=500)
+            else:
+                st.write(data)
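On the frontend side, the new branch above centers the map on the first coordinate of the artifact's first feature and draws the whole FeatureCollection with folium.GeoJson. A standalone sketch of that rendering path, with a hard-coded FeatureCollection standing in for the streamed artifact (the polygon coordinates are invented for illustration):

import folium
import streamlit as st
from streamlit_folium import folium_static

# Stand-in for the "artifact" payload streamed back by the API.
artifact = {
    "type": "FeatureCollection",
    "features": [
        {
            "type": "Feature",
            "properties": {},
            "geometry": {
                "type": "Polygon",
                # GeoJSON coordinate order is (lon, lat).
                "coordinates": [
                    [[36.80, -1.30], [36.90, -1.30], [36.90, -1.20], [36.80, -1.20], [36.80, -1.30]]
                ],
            },
        }
    ],
}

geom = artifact["features"][0]["geometry"]
# Polygons nest coordinates one level shallower than MultiPolygons.
if geom["type"] == "Polygon":
    pnt = geom["coordinates"][0][0]
else:
    pnt = geom["coordinates"][0][0][0]

st.header("Map preview")
m = folium.Map(location=[pnt[1], pnt[0]], zoom_start=11)  # folium expects (lat, lon)
folium.GeoJson(artifact).add_to(m)
folium_static(m, width=700, height=500)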
File renamed without changes
frontend/requirements.txt (1 change: 1 addition & 0 deletions)
@@ -1 +1,2 @@
 streamlit==1.40.1
+streamlit_folium==0.23.0
zeno/tools/distalert/context_layer_tool.py (2 changes: 0 additions & 2 deletions)
@@ -6,8 +6,6 @@
 
 from zeno.agents.maingraph.models import ModelFactory
 
-# init_gee()
-
 
 class grade(BaseModel):
     """Binary score for relevance check."""
