From 480d7f99b8e82b963291142b6283977ec423b834 Mon Sep 17 00:00:00 2001 From: Ali Nabipour Date: Sat, 12 Oct 2024 01:43:55 +0200 Subject: [PATCH] add openai assistant --- app.py | 219 +++++++++++++++++++++++++++++++++++---- app/assistants/openai.py | 140 +++++++++++++++++++++++++ app/core/__init__.py | 0 3 files changed, 337 insertions(+), 22 deletions(-) create mode 100644 app/assistants/openai.py delete mode 100644 app/core/__init__.py diff --git a/app.py b/app.py index 1d93b03..e46efba 100644 --- a/app.py +++ b/app.py @@ -7,6 +7,7 @@ import logging from typing import Iterator, Dict, Any from http import HTTPStatus +from app.assistants.openai import OpenAIAssistant app = Flask(__name__) @@ -20,6 +21,8 @@ ) logger = logging.getLogger(__name__) +openai_client = OpenAIAssistant() + def get_database_session() -> Iterator[Session]: db = next(get_db()) @@ -40,9 +43,9 @@ def register() -> Dict[str, Any]: password = data.get("password") if not email or not password: - logger.warning("Missing email or password") + logger.warning("Missing email or password in JSON data") return ( - jsonify({"error": "Email and password are required."}), + jsonify({"error": "Email and password are required"}), HTTPStatus.BAD_REQUEST, ) @@ -74,11 +77,16 @@ def register() -> Dict[str, Any]: @app.route("/login", methods=["POST"]) def login() -> Dict[str, Any]: - email = request.form.get("email") - password = request.form.get("password") + data = request.get_json() + if not data: + logger.warning("Empty JSON request body") + return jsonify({"error": "Request body must be JSON"}), HTTPStatus.BAD_REQUEST + + email = data.get("email") + password = data.get("password") if not email or not password: - logger.warning("Missing email or password in form data") + logger.warning("Missing email or password in JSON data") return ( jsonify({"error": "Email and password are required"}), HTTPStatus.BAD_REQUEST, @@ -92,18 +100,17 @@ def login() -> Dict[str, Any]: response = make_response( jsonify({"message": "Login successful"}), HTTPStatus.OK ) - response.set_cookie( key="token", value=token, httponly=True, - secure=False, # TODO: Set to True in production + secure=False, # TODO: Set to True in production for HTTPS samesite="Lax", ) return response - logger.warning("Invalid login attempt for email: {email}") + logger.warning(f"Invalid login attempt for email: {email}") return jsonify({"error": "Invalid credentials"}), HTTPStatus.UNAUTHORIZED @@ -117,24 +124,37 @@ def get_user_info(current_user: User) -> Dict[str, Any]: @app.route("/conversations", methods=["POST"]) @token_required def create_conversation(current_user: User) -> Dict[str, Any]: - data = request.get_json() - if not data: - logger.warning("Empty JSON request body") - return jsonify({"error": "Request body must be JSON"}), HTTPStatus.BAD_REQUEST - title = data.get("title") - if not title: - logger.warning("Title missing in request body") - return jsonify({"error": "Title is required"}), HTTPStatus.BAD_REQUEST + thread_id = openai_client.create_thread() + if not thread_id: + logger.error("Error creating conversation thread") + return ( + jsonify({"error": "An error occurred creating the conversation thread"}), + HTTPStatus.INTERNAL_SERVER_ERROR, + ) + + assistant_id = openai_client.get_assistant_id() + if not assistant_id: + logger.error("Error getting assistant ID") + return ( + jsonify({"error": "An error occurred getting the assistant ID"}), + HTTPStatus.INTERNAL_SERVER_ERROR, + ) + new_conversation = ConversationThread( + user_id=current_user.id, + thread_id=thread_id, + assistant_id=assistant_id, + ) db_session = next(get_database_session()) - new_thread = ConversationThread(title=title, user_id=current_user.id) - db_session.add(new_thread) + db_session.add(new_conversation) db_session.commit() - logger.info(f"Conversation thread created with ID: {new_thread.id}") + logger.info(f"Conversation thread created with ID: {new_conversation.id}") return ( - jsonify({"message": "Conversation created", "thread_id": new_thread.id}), + jsonify( + {"message": "Conversation created", "conversation_id": new_conversation.id} + ), HTTPStatus.CREATED, ) @@ -143,16 +163,171 @@ def create_conversation(current_user: User) -> Dict[str, Any]: @token_required def list_conversations(current_user: User) -> Dict[str, Any]: db_session = next(get_database_session()) - threads = ( + conversations = ( db_session.query(ConversationThread).filter_by(user_id=current_user.id).all() ) logger.info(f"Conversations listed for user ID: {current_user.id}") return ( - jsonify([{"id": thread.id, "title": thread.title} for thread in threads]), + jsonify( + [ + { + "id": conversation.id, + "thread_id": conversation.thread_id, + "assistant_id": conversation.assistant_id, + "created_at": conversation.created_at, + "status": conversation.status, + } + for conversation in conversations + ] + ), HTTPStatus.OK, ) +@app.route("/conversations//messages", methods=["POST"]) +@token_required +def send_message(current_user: User, conversation_id: int) -> Dict[str, Any]: + data = request.get_json() + if not data: + logger.warning("Empty JSON request body") + return jsonify({"error": "Request body must be JSON"}), HTTPStatus.BAD_REQUEST + + message = data.get("message") + if not message: + logger.warning("Message missing in request body") + return jsonify({"error": "Message is required"}), HTTPStatus.BAD_REQUEST + + db_session = next(get_database_session()) + conversation = ( + db_session.query(ConversationThread) + .filter_by(id=conversation_id, user_id=current_user.id) + .first() + ) + + if not conversation: + logger.warning( + f"Conversation {conversation_id} not found for user {current_user.id}" + ) + return jsonify({"error": "Conversation not found"}), HTTPStatus.NOT_FOUND + + openai_response = openai_client.send_message( + thread_id=conversation.thread_id, + assistant_id=conversation.assistant_id, + message=message, + ) + + if not openai_response: + logger.error("Error sending message to OpenAI") + return ( + jsonify({"error": "An error occurred sending the message"}), + HTTPStatus.INTERNAL_SERVER_ERROR, + ) + + logger.info(f"Message sent to conversation {conversation_id}") + return jsonify(openai_response), HTTPStatus.OK + + +@app.route("/conversations//messages", methods=["GET"]) +@token_required +def get_conversation_messages( + conversation_id: str, current_user: User +) -> Dict[str, Any]: + try: + db_session = next(get_database_session()) + conversation = ( + db_session.query(ConversationThread) + .filter_by(id=conversation_id, user_id=current_user.id) + .first() + ) + + if not conversation: + logger.warning( + f"Conversation with ID {conversation_id} not found for user {current_user.id}" + ) + return jsonify({"error": "Conversation not found"}), HTTPStatus.NOT_FOUND + + # Fetch the thread messages from OpenAI + messages = openai_client.get_thread_messages(conversation.thread_id) + + if not messages: + logger.error("Error fetching conversation messages") + return ( + jsonify( + {"error": "An error occurred fetching the conversation messages"} + ), + HTTPStatus.INTERNAL_SERVER_ERROR, + ) + + return ( + jsonify( + { + "conversation_id": conversation_id, + "messages": messages, + } + ), + HTTPStatus.OK, + ) + + except Exception as e: + logger.error( + f"Error fetching conversation messages for {conversation_id}: {str(e)}" + ) + return ( + jsonify({"error": "Failed to retrieve conversation messages"}), + HTTPStatus.INTERNAL_SERVER_ERROR, + ) + + +@app.route("/conversations//thread", methods=["GET"]) +@token_required +def get_conversation_thread(conversation_id: int, current_user: User) -> Dict[str, Any]: + try: + db_session = next(get_database_session()) + conversation = ( + db_session.query(ConversationThread) + .filter_by(id=conversation_id, user_id=current_user.id) + .first() + ) + + if not conversation: + logger.warning( + f"Conversation with ID {conversation_id} not found for user {current_user.id}" + ) + return jsonify({"error": "Conversation not found"}), HTTPStatus.NOT_FOUND + + # Fetch the thread messages from OpenAI + thread = openai_client.get_thread(conversation.thread_id).to_json() + + if not thread: + logger.error("Error fetching conversation thread messages") + return ( + jsonify( + {"error": "An error occurred fetching the conversation thread"} + ), + HTTPStatus.INTERNAL_SERVER_ERROR, + ) + + return ( + jsonify( + { + "conversation_id": conversation_id, + "thread_id": conversation.thread_id, + "content": thread, + } + ), + HTTPStatus.OK, + ) + + except Exception as e: + logger.error( + f"Error fetching conversation thread for {conversation_id}: {str(e)}" + ) + return ( + jsonify({"error": "Failed to retrieve conversation thread"}), + HTTPStatus.INTERNAL_SERVER_ERROR, + ) + + @app.route("/") def index() -> str: logger.info("Index page accessed") diff --git a/app/assistants/openai.py b/app/assistants/openai.py new file mode 100644 index 0000000..38144b2 --- /dev/null +++ b/app/assistants/openai.py @@ -0,0 +1,140 @@ +import os +import logging +from openai import OpenAI +from dotenv import load_dotenv +from typing import Optional, Dict, Any + +# Load environment variables from .env +load_dotenv() + +# Set up logging configuration +logging.basicConfig( + level=logging.INFO, + format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", + handlers=[ + logging.FileHandler("app.log"), # Logs to a file + logging.StreamHandler(), # Also logs to the console + ], +) +logger = logging.getLogger(__name__) + + +class OpenAIAssistant: + def __init__(self) -> None: + self.api_key = os.getenv("OPENAI_API_KEY") + if not self.api_key: + logger.error("OpenAI API key not set in environment variables.") + raise ValueError("The OpenAI API key is not set.") + + logger.info("OpenAI API key successfully loaded.") + self.client = OpenAI(api_key=self.api_key) + + def get_assistant_id(self) -> Optional[str]: + """Return a fixed assistant ID (hardcoded).""" + assistant_id = "asst_Bd6nPv9qhFDPR74IuN2jOokL" + logger.info(f"Returning assistant ID: {assistant_id}") + return assistant_id + + def create_thread(self) -> Optional[str]: + """Creates a conversation thread in OpenAI.""" + try: + logger.info("Creating a new conversation thread.") + thread = self.client.beta.threads.create() + logger.info(f"Conversation thread created with ID: {thread.id}") + return thread.id + except Exception as e: + logger.error(f"Failed to create conversation thread: {e}") + return None + + def send_message( + self, thread_id: str, assistant_id: str, message: str + ) -> Optional[Dict[str, str]]: + """Sends a message to the assistant in a specific conversation thread.""" + try: + logger.info( + f"Sending message to thread {thread_id} with assistant {assistant_id}." + ) + message_response = self.client.beta.threads.messages.create( + thread_id=thread_id, + role="user", + content=message, + ) + logger.info(f"Message sent: {message}") + + # Wait for the assistant to respond + run_response = self.client.beta.threads.runs.create_and_poll( + thread_id=thread_id, + assistant_id=assistant_id, + ) + + if run_response.status == "completed": + logger.info(f"Assistant response completed for thread {thread_id}.") + messages = self.client.beta.threads.messages.list(thread_id=thread_id) + # assistant_reply = ( + # messages[-1].content if messages else "No response from assistant." + # ) + assistant_reply = messages + logger.info(f"Assistant response: {assistant_reply}") + return self._extract_last_message(assistant_reply) + else: + logger.error( + f"Assistant did not complete the response for thread {thread_id}. Status: {run_response.status}" + ) + return None + + except Exception as e: + logger.error( + f"Error sending message to thread {thread_id} with assistant {assistant_id}: {e}" + ) + return None + + def _extract_last_message(self, response: Dict[str, Any]) -> Optional[str]: + try: + messages = response.data + if not messages: + logger.warning("No messages found.") + return None + + for message in reversed(messages): + if message.role == "assistant": + content = message.content[0].text.value + return content + + logger.info("No assistant message found.") + return None + + except Exception as e: + logger.error(f"Error extracting last assistant message: {e}") + return None + + def get_thread_messages(self, thread_id: str) -> Optional[Dict[str, Any]]: + """Fetches all messages in a conversation thread.""" + try: + messages = self.client.beta.threads.messages.list(thread_id=thread_id) + thread_data = [] + for message in messages.data: + message_data = { + "id": message.id, + "role": message.role, + "created_at": message.created_at, + "content": [ + content_block.text.value for content_block in message.content + ], + } + thread_data.append(message_data) + + logger.info(f"Fetched {len(thread_data)} messages for thread {thread_id}.") + return thread_data + except Exception as e: + logger.error(f"Error fetching messages for thread {thread_id}: {e}") + return None + + def get_thread(self, thread_id: str) -> Optional[Dict[str, Any]]: + """Fetches a conversation thread by ID.""" + try: + thread = self.client.beta.threads.retrieve(thread_id) + logger.info(f"Fetched conversation thread with ID: {thread_id}") + return thread + except Exception as e: + logger.error(f"Error fetching conversation thread {thread_id}: {e}") + return None diff --git a/app/core/__init__.py b/app/core/__init__.py deleted file mode 100644 index e69de29..0000000