Skip to content

Commit

Permalink
add openai assistant
Browse files Browse the repository at this point in the history
  • Loading branch information
eyenpi committed Oct 11, 2024
1 parent db5dc91 commit 480d7f9
Show file tree
Hide file tree
Showing 3 changed files with 337 additions and 22 deletions.
219 changes: 197 additions & 22 deletions app.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import logging
from typing import Iterator, Dict, Any
from http import HTTPStatus
from app.assistants.openai import OpenAIAssistant

app = Flask(__name__)

Expand All @@ -20,6 +21,8 @@
)
logger = logging.getLogger(__name__)

openai_client = OpenAIAssistant()


def get_database_session() -> Iterator[Session]:
db = next(get_db())
Expand All @@ -40,9 +43,9 @@ def register() -> Dict[str, Any]:
password = data.get("password")

if not email or not password:
logger.warning("Missing email or password")
logger.warning("Missing email or password in JSON data")
return (
jsonify({"error": "Email and password are required."}),
jsonify({"error": "Email and password are required"}),
HTTPStatus.BAD_REQUEST,
)

Expand Down Expand Up @@ -74,11 +77,16 @@ def register() -> Dict[str, Any]:

@app.route("/login", methods=["POST"])
def login() -> Dict[str, Any]:
email = request.form.get("email")
password = request.form.get("password")
data = request.get_json()
if not data:
logger.warning("Empty JSON request body")
return jsonify({"error": "Request body must be JSON"}), HTTPStatus.BAD_REQUEST

email = data.get("email")
password = data.get("password")

if not email or not password:
logger.warning("Missing email or password in form data")
logger.warning("Missing email or password in JSON data")
return (
jsonify({"error": "Email and password are required"}),
HTTPStatus.BAD_REQUEST,
Expand All @@ -92,18 +100,17 @@ def login() -> Dict[str, Any]:
response = make_response(
jsonify({"message": "Login successful"}), HTTPStatus.OK
)

response.set_cookie(
key="token",
value=token,
httponly=True,
secure=False, # TODO: Set to True in production
secure=False, # TODO: Set to True in production for HTTPS
samesite="Lax",
)

return response

logger.warning("Invalid login attempt for email: {email}")
logger.warning(f"Invalid login attempt for email: {email}")
return jsonify({"error": "Invalid credentials"}), HTTPStatus.UNAUTHORIZED


Expand All @@ -117,24 +124,37 @@ def get_user_info(current_user: User) -> Dict[str, Any]:
@app.route("/conversations", methods=["POST"])
@token_required
def create_conversation(current_user: User) -> Dict[str, Any]:
data = request.get_json()
if not data:
logger.warning("Empty JSON request body")
return jsonify({"error": "Request body must be JSON"}), HTTPStatus.BAD_REQUEST

title = data.get("title")
if not title:
logger.warning("Title missing in request body")
return jsonify({"error": "Title is required"}), HTTPStatus.BAD_REQUEST
thread_id = openai_client.create_thread()
if not thread_id:
logger.error("Error creating conversation thread")
return (
jsonify({"error": "An error occurred creating the conversation thread"}),
HTTPStatus.INTERNAL_SERVER_ERROR,
)

assistant_id = openai_client.get_assistant_id()
if not assistant_id:
logger.error("Error getting assistant ID")
return (
jsonify({"error": "An error occurred getting the assistant ID"}),
HTTPStatus.INTERNAL_SERVER_ERROR,
)

new_conversation = ConversationThread(
user_id=current_user.id,
thread_id=thread_id,
assistant_id=assistant_id,
)
db_session = next(get_database_session())
new_thread = ConversationThread(title=title, user_id=current_user.id)
db_session.add(new_thread)
db_session.add(new_conversation)
db_session.commit()
logger.info(f"Conversation thread created with ID: {new_thread.id}")
logger.info(f"Conversation thread created with ID: {new_conversation.id}")

return (
jsonify({"message": "Conversation created", "thread_id": new_thread.id}),
jsonify(
{"message": "Conversation created", "conversation_id": new_conversation.id}
),
HTTPStatus.CREATED,
)

Expand All @@ -143,16 +163,171 @@ def create_conversation(current_user: User) -> Dict[str, Any]:
@token_required
def list_conversations(current_user: User) -> Dict[str, Any]:
db_session = next(get_database_session())
threads = (
conversations = (
db_session.query(ConversationThread).filter_by(user_id=current_user.id).all()
)
logger.info(f"Conversations listed for user ID: {current_user.id}")
return (
jsonify([{"id": thread.id, "title": thread.title} for thread in threads]),
jsonify(
[
{
"id": conversation.id,
"thread_id": conversation.thread_id,
"assistant_id": conversation.assistant_id,
"created_at": conversation.created_at,
"status": conversation.status,
}
for conversation in conversations
]
),
HTTPStatus.OK,
)


@app.route("/conversations/<int:conversation_id>/messages", methods=["POST"])
@token_required
def send_message(current_user: User, conversation_id: int) -> Dict[str, Any]:
data = request.get_json()
if not data:
logger.warning("Empty JSON request body")
return jsonify({"error": "Request body must be JSON"}), HTTPStatus.BAD_REQUEST

message = data.get("message")
if not message:
logger.warning("Message missing in request body")
return jsonify({"error": "Message is required"}), HTTPStatus.BAD_REQUEST

db_session = next(get_database_session())
conversation = (
db_session.query(ConversationThread)
.filter_by(id=conversation_id, user_id=current_user.id)
.first()
)

if not conversation:
logger.warning(
f"Conversation {conversation_id} not found for user {current_user.id}"
)
return jsonify({"error": "Conversation not found"}), HTTPStatus.NOT_FOUND

openai_response = openai_client.send_message(
thread_id=conversation.thread_id,
assistant_id=conversation.assistant_id,
message=message,
)

if not openai_response:
logger.error("Error sending message to OpenAI")
return (
jsonify({"error": "An error occurred sending the message"}),
HTTPStatus.INTERNAL_SERVER_ERROR,
)

logger.info(f"Message sent to conversation {conversation_id}")
return jsonify(openai_response), HTTPStatus.OK


@app.route("/conversations/<int:conversation_id>/messages", methods=["GET"])
@token_required
def get_conversation_messages(
conversation_id: str, current_user: User
) -> Dict[str, Any]:
try:
db_session = next(get_database_session())
conversation = (
db_session.query(ConversationThread)
.filter_by(id=conversation_id, user_id=current_user.id)
.first()
)

if not conversation:
logger.warning(
f"Conversation with ID {conversation_id} not found for user {current_user.id}"
)
return jsonify({"error": "Conversation not found"}), HTTPStatus.NOT_FOUND

# Fetch the thread messages from OpenAI
messages = openai_client.get_thread_messages(conversation.thread_id)

if not messages:
logger.error("Error fetching conversation messages")
return (
jsonify(
{"error": "An error occurred fetching the conversation messages"}
),
HTTPStatus.INTERNAL_SERVER_ERROR,
)

return (
jsonify(
{
"conversation_id": conversation_id,
"messages": messages,
}
),
HTTPStatus.OK,
)

except Exception as e:
logger.error(
f"Error fetching conversation messages for {conversation_id}: {str(e)}"
)
return (
jsonify({"error": "Failed to retrieve conversation messages"}),
HTTPStatus.INTERNAL_SERVER_ERROR,
)


@app.route("/conversations/<int:conversation_id>/thread", methods=["GET"])
@token_required
def get_conversation_thread(conversation_id: int, current_user: User) -> Dict[str, Any]:
try:
db_session = next(get_database_session())
conversation = (
db_session.query(ConversationThread)
.filter_by(id=conversation_id, user_id=current_user.id)
.first()
)

if not conversation:
logger.warning(
f"Conversation with ID {conversation_id} not found for user {current_user.id}"
)
return jsonify({"error": "Conversation not found"}), HTTPStatus.NOT_FOUND

# Fetch the thread messages from OpenAI
thread = openai_client.get_thread(conversation.thread_id).to_json()

if not thread:
logger.error("Error fetching conversation thread messages")
return (
jsonify(
{"error": "An error occurred fetching the conversation thread"}
),
HTTPStatus.INTERNAL_SERVER_ERROR,
)

return (
jsonify(
{
"conversation_id": conversation_id,
"thread_id": conversation.thread_id,
"content": thread,
}
),
HTTPStatus.OK,
)

except Exception as e:
logger.error(
f"Error fetching conversation thread for {conversation_id}: {str(e)}"
)
return (
jsonify({"error": "Failed to retrieve conversation thread"}),
HTTPStatus.INTERNAL_SERVER_ERROR,
)


@app.route("/")
def index() -> str:
logger.info("Index page accessed")
Expand Down
Loading

0 comments on commit 480d7f9

Please sign in to comment.