-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
11 changed files
with
339 additions
and
3 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,198 @@ | ||
import streamlit as st | ||
from frontend.services.auth import AuthService | ||
from frontend.services.conversation import ConversationService | ||
from frontend.utils.logger import setup_logger | ||
|
||
logger = setup_logger(__name__) | ||
|
||
|
||
class ChatApp: | ||
def __init__(self): | ||
# Initialize session state if it is not yet initialized | ||
if "logged_in" not in st.session_state: | ||
st.session_state.logged_in = False | ||
if "token" not in st.session_state: | ||
st.session_state.token = None | ||
|
||
def run(self) -> None: | ||
if not st.session_state.logged_in: | ||
self.show_auth_page() | ||
else: | ||
self.show_conversation_page() | ||
|
||
def show_auth_page(self) -> None: | ||
st.title("Login/Register") | ||
choice = st.radio("Choose Action", ("Login", "Register")) | ||
email = st.text_input("Email") | ||
password = st.text_input("Password", type="password") | ||
if st.button(f"{choice}"): | ||
if choice == "Login": | ||
self.login(email, password) | ||
else: | ||
self.register(email, password) | ||
|
||
def login(self, email: str, password: str) -> None: | ||
token = AuthService.login(email, password) | ||
if token: | ||
st.session_state.token = token | ||
st.session_state.logged_in = True # Mark user as logged in | ||
st.success("Login successful!") | ||
st.rerun() | ||
else: | ||
st.error("Login failed!") | ||
|
||
def register(self, email: str, password: str) -> None: | ||
response = AuthService.register(email, password) | ||
if response: | ||
st.success("Registered successfully! Please log in.") | ||
else: | ||
st.error("Registration failed!") | ||
|
||
def show_conversation_page(self) -> None: | ||
try: | ||
# Fetch the list of conversations | ||
conversations = ConversationService.get_conversations( | ||
st.session_state.token | ||
) | ||
|
||
if conversations: | ||
# Store the conversations into the session state | ||
st.session_state.conversations = conversations | ||
|
||
# Function to truncate titles | ||
def truncate_title(title, max_length=40): | ||
return ( | ||
title | ||
if len(title) <= max_length | ||
else title[:max_length] + "..." | ||
) | ||
|
||
# Display Conversations in sidebar | ||
st.sidebar.title("Conversations") | ||
|
||
# Button to create a new conversation | ||
if st.sidebar.button("Create New Conversation"): | ||
self.create_new_conversation() | ||
|
||
# Iterate over all conversations to display them | ||
for conv in conversations: | ||
# Create a readable title with truncation if necessary | ||
title = f"Conversation {conv['id']} ({conv['created_at']})" | ||
truncated_title = truncate_title(title) | ||
|
||
# Use a button for each truncated conversation to handle selection | ||
if st.sidebar.button(truncated_title): | ||
# Set the selected conversation ID | ||
st.session_state.selected_conversation_id = conv["id"] | ||
|
||
# Load and display messages for the selected conversation | ||
if "selected_conversation_id" in st.session_state: | ||
st.write( | ||
f"Selected Conversation: {st.session_state.selected_conversation_id}" | ||
) | ||
self.load_messages() | ||
|
||
# Allow the user to send a message | ||
self.send_message() | ||
else: | ||
st.error("Failed to load conversations.") | ||
except Exception as ex: | ||
st.error("An error occurred while loading conversations.") | ||
logger.error(f"Exception: {ex}") | ||
|
||
def create_new_conversation(self) -> None: | ||
try: | ||
response = ConversationService.create_conversation(st.session_state.token) | ||
if response: | ||
st.success("New conversation created successfully!") | ||
st.rerun() # Refresh to update the conversation list | ||
else: | ||
st.error("Failed to create a new conversation.") | ||
except Exception as ex: | ||
st.error("An error occurred while creating a new conversation.") | ||
logger.error(f"Exception: {ex}") | ||
|
||
def load_messages(self) -> None: | ||
if "selected_conversation_id" in st.session_state: | ||
try: | ||
# Fetch messages using the ConversationService | ||
messages = ConversationService.get_messages( | ||
st.session_state.token, st.session_state.selected_conversation_id | ||
) | ||
if messages: | ||
# Implement CSS for fixed layout | ||
st.markdown( | ||
""" | ||
<style> | ||
.message-box { | ||
border-radius: 8px; | ||
padding: 10px; | ||
margin-bottom: 10px; | ||
box-shadow: 1px 1px 5px rgba(0, 0, 0, 0.1); | ||
} | ||
.message-container { | ||
max-height: 400px; /* Adjust as needed */ | ||
overflow-y: auto; | ||
margin-bottom: 20px; | ||
} | ||
</style> | ||
""", | ||
unsafe_allow_html=True, | ||
) | ||
st.write("Messages:") | ||
# Wrap messages in a scrollable container | ||
st.markdown( | ||
'<div class="message-container">', unsafe_allow_html=True | ||
) | ||
|
||
# Reverse the order of messages for display | ||
for message in reversed(messages["messages"]): | ||
# Set border color based on role | ||
if message["role"] == "user": | ||
border_color = "#0288d1" # Blue for user messages | ||
elif message["role"] == "assistant": | ||
border_color = "#7cb342" # Green for assistant messages | ||
else: | ||
border_color = "#c2185b" # Pink for other roles | ||
|
||
# Join content of each message for display | ||
content = " ".join(message["content"]) | ||
message_html = f""" | ||
<div class="message-box" style="border: 2px solid {border_color};"> | ||
<strong>{message['role']}:</strong> {content} | ||
</div> | ||
""" | ||
# Use st.markdown with unsafe_allow_html=True to render each message | ||
st.markdown(message_html, unsafe_allow_html=True) | ||
# Close the scrollable message container | ||
st.markdown("</div>", unsafe_allow_html=True) | ||
else: | ||
st.error("Failed to load messages.") | ||
except Exception as ex: | ||
st.error("An error occurred while loading messages.") | ||
logger.error(f"Exception: {ex}") | ||
|
||
def send_message(self) -> None: | ||
new_message = st.text_input( | ||
"Your message:", placeholder="Type your message here...", key="new_message" | ||
) | ||
|
||
if st.button("Send"): | ||
try: | ||
response = ConversationService.send_message( | ||
st.session_state.token, | ||
st.session_state.selected_conversation_id, | ||
new_message, # Use the session state value | ||
) | ||
if response: | ||
st.rerun() | ||
else: | ||
st.error("Failed to send message.") | ||
except Exception as ex: | ||
st.error("An error occurred while sending the message.") | ||
logger.error(f"Exception: {ex}") | ||
|
||
|
||
if __name__ == "__main__": | ||
app = ChatApp() | ||
app.run() |
Empty file.
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,37 @@ | ||
import requests | ||
from typing import Dict, Any, Optional | ||
from frontend.utils.logger import setup_logger | ||
from dotenv import load_dotenv | ||
import os | ||
|
||
logger = setup_logger(__name__) | ||
load_dotenv() | ||
|
||
|
||
class AuthService: | ||
BASE_URL = os.getenv("BACKEND_URL") | ||
|
||
@classmethod | ||
def register(cls, email: str, password: str) -> Optional[Dict[str, Any]]: | ||
try: | ||
response = requests.post( | ||
f"{cls.BASE_URL}/register", json={"email": email, "password": password} | ||
) | ||
response.raise_for_status() | ||
return response.json() | ||
except requests.RequestException as e: | ||
logger.error(f"Registration failed: {e}") | ||
return None | ||
|
||
@classmethod | ||
def login(cls, email: str, password: str) -> Optional[str]: | ||
try: | ||
response = requests.post( | ||
f"{cls.BASE_URL}/login", json={"email": email, "password": password} | ||
) | ||
response.raise_for_status() | ||
token = response.cookies.get("token") | ||
return token | ||
except requests.RequestException as e: | ||
logger.error(f"Login failed: {e}") | ||
return None |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,69 @@ | ||
import requests | ||
from typing import Dict, Any, Optional | ||
from frontend.utils.logger import setup_logger | ||
from dotenv import load_dotenv | ||
import os | ||
|
||
logger = setup_logger(__name__) | ||
load_dotenv() | ||
|
||
|
||
class ConversationService: | ||
BASE_URL = os.getenv("BACKEND_URL") | ||
|
||
@classmethod | ||
def get_conversations(cls, token: str) -> Optional[Dict[str, Any]]: | ||
# Set token as 'token' in the cookies | ||
cookies = {"token": token} | ||
try: | ||
response = requests.get(f"{cls.BASE_URL}/conversations", cookies=cookies) | ||
response.raise_for_status() | ||
return response.json() | ||
except requests.RequestException as e: | ||
logger.error(f"Fetching conversations failed: {e}") | ||
return None | ||
|
||
@classmethod | ||
def create_conversation(cls, token: str) -> Optional[Dict[str, Any]]: | ||
# Set token as 'token' in the cookies | ||
cookies = {"token": token} | ||
try: | ||
response = requests.post(f"{cls.BASE_URL}/conversations", cookies=cookies) | ||
response.raise_for_status() | ||
return response.json() | ||
except requests.RequestException as e: | ||
logger.error(f"Creating conversation failed: {e}") | ||
return None | ||
|
||
@classmethod | ||
def get_messages(cls, token: str, conversation_id: str) -> Optional[Dict[str, Any]]: | ||
# Set token as 'token' in the cookies | ||
cookies = {"token": token} | ||
try: | ||
response = requests.get( | ||
f"{cls.BASE_URL}/conversations/{conversation_id}/messages", | ||
cookies=cookies, | ||
) | ||
response.raise_for_status() | ||
return response.json() | ||
except requests.RequestException as e: | ||
logger.error(f"Fetching messages failed: {e}") | ||
return None | ||
|
||
@classmethod | ||
def send_message( | ||
cls, token: str, conversation_id: str, message: str | ||
) -> Optional[Dict[str, Any]]: | ||
# Set token as 'token' in the cookies | ||
cookies = {"token": token} | ||
try: | ||
response = requests.post( | ||
f"{cls.BASE_URL}/conversations/{conversation_id}/messages", | ||
json={"message": message}, | ||
cookies=cookies, | ||
) | ||
response.raise_for_status() | ||
return response.json() | ||
except requests.RequestException as e: | ||
logger.error(f"Sending message failed: {e}") | ||
return None |
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,14 @@ | ||
import logging | ||
|
||
|
||
def setup_logger(name: str) -> logging.Logger: | ||
logger = logging.getLogger(name) | ||
if not logger.hasHandlers(): | ||
logger.setLevel(logging.DEBUG) | ||
handler = logging.StreamHandler() | ||
formatter = logging.Formatter( | ||
"%(asctime)s - %(name)s - %(levelname)s - %(message)s" | ||
) | ||
handler.setFormatter(formatter) | ||
logger.addHandler(handler) | ||
return logger |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,15 @@ | ||
import streamlit as st | ||
|
||
|
||
class SessionState: | ||
def __init__(self): | ||
self.token: str = "" | ||
self.conversations = [] | ||
self.selected_conversation_id = None | ||
self.logged_in: bool = False | ||
|
||
@staticmethod | ||
def get() -> "SessionState": | ||
if not hasattr(st.session_state, "session"): | ||
st.session_state.session = SessionState() | ||
return st.session_state.session |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -8,4 +8,4 @@ alembic | |
gunicorn | ||
flask-cors | ||
alembic | ||
flask_swagger_ui | ||
streamlit |