Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ai tutor #57

Open
wants to merge 18 commits into
base: main
Choose a base branch
from
116 changes: 114 additions & 2 deletions ai_chatbots/chatbots.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,15 @@
from operator import add
from typing import Annotated, Optional
from uuid import uuid4
from asgiref.sync import sync_to_async
from channels.db import database_sync_to_async

import posthog
from django.conf import settings
from django.utils.module_loading import import_string
from langchain_community.chat_models import ChatLiteLLM
from langchain_core.language_models.chat_models import BaseChatModel
from langchain_core.messages import HumanMessage, SystemMessage
from langchain_core.messages import HumanMessage, SystemMessage, ToolMessage
from langchain_core.messages.ai import AIMessageChunk
from langchain_core.tools.base import BaseTool
from langgraph.checkpoint.base import BaseCheckpointSaver
Expand All @@ -22,8 +24,13 @@
from langgraph.prebuilt import ToolNode, create_react_agent, tools_condition
from langgraph.prebuilt.chat_agent_executor import AgentState
from openai import BadRequestError
from langchain_openai import ChatOpenAI
from typing_extensions import TypedDict

from open_learning_ai_tutor.problems import get_pb_sol
from open_learning_ai_tutor.StratL import message_tutor, process_StratL_json_output
from open_learning_ai_tutor.tools import tutor_tools
from open_learning_ai_tutor.utils import messages_to_json, json_to_messages, intent_list_to_json
from ai_chatbots.models import TutorBotOutput
from ai_chatbots import tools
from ai_chatbots.api import get_search_tool_metadata
from ai_chatbots.tools import search_content_files
Expand Down Expand Up @@ -403,3 +410,108 @@ async def get_tool_metadata(self) -> str:
thread_id = self.config["configurable"]["thread_id"]
latest_state = await self.get_latest_history()
return get_search_tool_metadata(thread_id, latest_state)

@database_sync_to_async
def create_tutorbot_output(thread_id, chat_json):
    """Persist one tutor exchange (serialized JSON) for the given thread."""
    return TutorBotOutput.objects.create(
        thread_id=thread_id,
        chat_json=chat_json,
    )

@database_sync_to_async
def get_history(thread_id):
    """Return the most recent TutorBotOutput for a thread, or None if new."""
    outputs = TutorBotOutput.objects.filter(thread_id=thread_id)
    return outputs.last()


class TutorBot(BaseChatbot):
    """
    Chatbot that assists with problem sets.

    Unlike the other bots, conversation state (chat, intent, and assessment
    histories) is persisted as JSON in TutorBotOutput rows keyed by
    thread_id rather than via the langgraph checkpointer.
    """

    def __init__(  # noqa: PLR0913
        self,
        user_id: str,
        # Default was the BaseCheckpointSaver class object itself, which is
        # not a usable saver instance; None matches the Optional annotation.
        checkpointer: Optional[BaseCheckpointSaver] = None,
        *,
        name: str = "MIT Open Learning Tutor Chatbot",
        model: Optional[str] = None,
        temperature: Optional[float] = None,
        thread_id: Optional[str] = None,
        problem_code: Optional[str] = None,
    ):
        """Initialize the tutor bot and resolve its assigned problem."""
        super().__init__(
            user_id,
            name=name,
            checkpointer=checkpointer,
            temperature=temperature,
            thread_id=thread_id,
            model=model or settings.AI_DEFAULT_TUTOR_MODEL,
        )
        # Look up the problem statement and its solution from the code
        self.problem, self.solution = get_pb_sol(problem_code)

    def get_llm(self, **kwargs) -> BaseChatModel:
        """
        Return the LLM instance for the chatbot.

        Set it up to use a proxy, with required proxy kwargs, if applicable.
        """
        proxy_kwargs = {}
        if self.proxy:
            proxy_kwargs.update(
                self.proxy.get_api_kwargs(
                    base_url_key="base_url", api_key_key="openai_api_key"
                )
            )
            proxy_kwargs.update(self.proxy.get_additional_kwargs(self))
        llm = ChatOpenAI(
            model=f"{self.proxy_prefix}{self.model}",
            **proxy_kwargs,
            **kwargs,
        )
        # Set the temperature if it's supported by the model
        if self.temperature and self.model not in settings.AI_UNSUPPORTED_TEMP_MODELS:
            llm.temperature = self.temperature
        return llm

    async def get_tool_metadata(self) -> str:
        """Return the metadata for the tool (the tutor bot has none)."""
        return None

    async def get_completion(
        self,
        message: str,
        *,
        extra_state: Optional[TypedDict] = None,
        debug: bool = settings.AI_DEBUG,
    ) -> AsyncGenerator[str, None]:
        """
        Call message_tutor with the user query and yield the response.

        Loads any prior chat/intent/assessment history for this thread,
        passes it to the tutor, persists the raw tutor output, then yields
        the text addressed to the student.
        """
        history = await get_history(self.thread_id)

        if history:
            json_history = json.loads(history.chat_json)
            self.chat_history = json_to_messages(
                json_history.get('chat_history', [])
            ) + [HumanMessage(content=message)]
            self.intent_history = json_history.get('intent_history', [])
            self.assessment_history = json_history.get('assessment_history', [])
        else:
            # NOTE(review): intent_history starts as the string '[]' here but
            # as a list when loaded from history — confirm message_tutor
            # accepts both representations.
            self.chat_history = [HumanMessage(content=message)]
            self.intent_history = '[]'
            self.assessment_history = ''

        response = ""

        try:
            json_output = message_tutor(
                self.problem,
                self.solution,
                self.llm,
                messages_to_json([HumanMessage(content=message)]),
                messages_to_json(self.chat_history),
                self.assessment_history,
                self.intent_history,
                '{}',
                tools=tutor_tools
            )

            await create_tutorbot_output(self.thread_id, json_output)
            processed = process_StratL_json_output(json_output)
            # Fallback shown to the student if no text_student tool call is
            # found in the processed output.
            response = "An error has occurred, please try again"
            for index, msg in enumerate(processed[0]):
                if isinstance(msg, ToolMessage) and msg.name == 'text_student':
                    # The student-facing text lives in the tool call of the
                    # AIMessage immediately preceding the ToolMessage ack.
                    response = (
                        processed[0][index - 1].tool_calls[0]['args']['message_to_student']
                    )

            yield response

        except Exception:
            # Log before yielding: consumers may stop iterating after the
            # error chunk, which would skip any code placed after the yield.
            log.exception("Error running AI agent")
            yield '<!-- {"error":{"message":"An error occurred, please try again"}} -->'

109 changes: 109 additions & 0 deletions ai_chatbots/chatbots_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,14 @@
from langchain_community.chat_models import ChatLiteLLM
from langchain_core.messages import HumanMessage
from langchain_core.runnables import RunnableBinding
from open_learning_ai_tutor.problems import get_pb_sol

from ai_chatbots.chatbots import (
ResourceRecommendationBot,
SyllabusAgentState,
SyllabusBot,
TutorBot,
get_history
)
from ai_chatbots.checkpointers import AsyncDjangoSaver
from ai_chatbots.conftest import MockAsyncIterator
Expand Down Expand Up @@ -459,3 +462,109 @@ async def test_proxy_settings(settings, mocker, mock_checkpointer, use_proxy):
**{},
**{},
)

@pytest.mark.parametrize(
    ("model", "temperature"),
    [
        ("gpt-3.5-turbo", 0.1),
        ("gpt-4", None),
        (None, None),
    ],
)
async def test_tutor_bot_intitiation(mocker, model, temperature):
    """Test the tutor class instantiation."""
    name = "My tutor bot"
    problem_code = "A1P1"

    chatbot = TutorBot(
        "user",
        name=name,
        model=model,
        temperature=temperature,
        problem_code=problem_code,
    )
    # Fall back to the configured defaults when model/temperature are omitted.
    # (A former duplicate of the model assert without parentheses was removed:
    # `a == b if cond else c` asserts `c`'s truthiness when cond is falsy.)
    assert chatbot.model == (
        model if model else settings.AI_DEFAULT_TUTOR_MODEL
    )
    assert chatbot.temperature == (
        temperature if temperature else settings.AI_DEFAULT_TEMPERATURE
    )
    problem, solution = get_pb_sol(problem_code)
    assert chatbot.problem == problem
    assert chatbot.solution == solution


async def test_tutor_get_completion(mocker, mock_checkpointer):
    """Test that the tutor bot get_completion method returns expected values."""
    # Canned tutor output: a student question, the tutor's text_student tool
    # call, the tool acknowledgement, and intent/assessment bookkeeping.
    tutor_output = {
        "chat_history": [
            {
                "type": "HumanMessage",
                "content": "what should i try next?"
            },
            {
                "type": "AIMessage",
                "content": "",
                "tool_calls": [
                    {
                        "id": "call_2YfyQtpoDAaSfJo0XiYEVEI3",
                        "function": {
                            "arguments": "{\"message_to_student\":\"Let's start with Problem 1.1. Have you tried plotting the states' centers using latitude and longitude? What do you think should be the first variable in the plot command? Share your thoughts or any code you've tried so far.\"}",
                            "name": "text_student"
                        },
                        "type": "function"
                    }
                ],
                "refusal": None
            },
            {
                "type": "ToolMessage",
                "content": "Message sent",
                "name": "text_student",
                "tool_call_id": "call_2YfyQtpoDAaSfJo0XiYEVEI3"
            }
        ],
        "intent_history": "[[\"P_HYPOTHESIS\"]]",
        "assessment_history": [
            {
                "type": "HumanMessage",
                "content": "Student: \"what should i try next?\""
            },
            {
                "type": "AIMessage",
                "content": "{\"justification\": \"The student is explicitly asking about how to solve the problem, indicating they are seeking guidance on the next steps to take.\", \"selection\": \"g\"}",
                "refusal": None
            }
        ],
        "metadata": {
            "docs": None,
            "rag_queries": None,
            "A_B_test": False,
            "tutor_model": "gpt-4o"
        }
    }

    mocker.patch(
        "ai_chatbots.chatbots.message_tutor",
        return_value=json.dumps(tutor_output),
    )
    user_msg = "what should i try next?"
    thread_id = 'TEST'

    chatbot = TutorBot(
        "anonymous", mock_checkpointer, problem_code="A1P1", thread_id=thread_id
    )

    # The bot should stream back only the message addressed to the student.
    chunks = []
    async for chunk in chatbot.get_completion(user_msg):
        chunks.append(str(chunk))
    results = "".join(chunks)
    assert results == "Let's start with Problem 1.1. Have you tried plotting the states' centers using latitude and longitude? What do you think should be the first variable in the plot command? Share your thoughts or any code you've tried so far."

    # The raw tutor output should have been persisted for this thread.
    history = await get_history(thread_id)
    assert history.thread_id == thread_id
    assert history.chat_json == json.dumps(tutor_output)
33 changes: 30 additions & 3 deletions ai_chatbots/consumers.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,11 @@
from langgraph.checkpoint.base import BaseCheckpointSaver
from rest_framework.exceptions import ValidationError
from rest_framework.status import HTTP_200_OK
from ai_chatbots.chatbots import ResourceRecommendationBot, SyllabusBot
from ai_chatbots.chatbots import ResourceRecommendationBot, SyllabusBot, TutorBot
from ai_chatbots.checkpointers import AsyncDjangoSaver
from ai_chatbots.constants import AI_THREAD_COOKIE_KEY, AI_THREADS_ANONYMOUS_COOKIE_KEY
from ai_chatbots.models import UserChatSession
from ai_chatbots.serializers import ChatRequestSerializer, SyllabusChatRequestSerializer
from ai_chatbots.serializers import ChatRequestSerializer, SyllabusChatRequestSerializer, TutorChatRequestSerializer
from users.models import User

log = logging.getLogger(__name__)
Expand Down Expand Up @@ -62,8 +62,8 @@ async def assign_thread_cookies(
"""
latest_cookie_key = f"{self.ROOM_NAME}_{AI_THREAD_COOKIE_KEY}"
anon_cookie_key = f"{self.ROOM_NAME}_{AI_THREADS_ANONYMOUS_COOKIE_KEY}"

threads_ids_str = self.scope["cookies"].get(anon_cookie_key) or ""

thread_ids = [tid for tid in (threads_ids_str).split(",") if tid]

if thread_ids:
Expand Down Expand Up @@ -307,3 +307,30 @@ def process_extra_state(self, data: dict) -> dict:
"course_id": [data.get("course_id")],
"collection_name": [data.get("collection_name")],
}


class TutorBotHttpConsumer(BaseBotHttpConsumer):
    """
    Async HTTP consumer for the tutor bot.
    """

    serializer_class = TutorChatRequestSerializer
    ROOM_NAME = TutorBot.__name__

    def create_chatbot(
        self,
        serializer: TutorChatRequestSerializer,
        checkpointer: BaseCheckpointSaver,
    ) -> TutorBot:
        """Return a TutorBot instance configured from the request data."""
        temperature = serializer.validated_data.pop("temperature", None)
        model = serializer.validated_data.pop("model", None)
        problem_code = serializer.validated_data.pop("problem_code", None)

        return TutorBot(
            self.user_id,
            checkpointer,
            temperature=temperature,
            model=model,
            thread_id=self.thread_id,
            problem_code=problem_code,
        )

45 changes: 45 additions & 0 deletions ai_chatbots/consumers_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,13 @@ def syllabus_consumer(async_user):
consumer.channel_name = "test_syllabus_channel"
return consumer

@pytest.fixture
def tutor_consumer(async_user):
    """Return a TutorBotHttpConsumer with a minimal ASGI scope."""
    tutor = consumers.TutorBotHttpConsumer()
    tutor.scope = {"user": async_user, "cookies": {}, "session": None}
    tutor.channel_name = "test_tutor_channel"
    return tutor

@pytest.mark.parametrize(
("message", "temperature", "instructions", "model"),
Expand Down Expand Up @@ -417,3 +424,41 @@ async def test_handle_errors(
"more_body": True,
}
)

async def test_tutor_agent_handle(
    mocker,
    mock_http_consumer_send,
    tutor_consumer,
):
    """Test the handle function of the tutor agent."""
    response = SystemMessageFactory.create()
    user = tutor_consumer.scope["user"]
    user.is_superuser = True
    mock_completion = mocker.patch(
        "ai_chatbots.chatbots.TutorBot.get_completion",
        return_value=mocker.Mock(
            __aiter__=mocker.Mock(
                return_value=MockAsyncIterator(list(response.content.split(" ")))
            )
        ),
    )
    message = "What should i try next?"
    data = {
        "message": message,
        "problem_code": "A1P1",
    }

    await tutor_consumer.handle(json.dumps(data))

    # Headers sent exactly once (a duplicate call_count check was removed),
    # and the completion invoked with the raw message.
    mock_http_consumer_send.send_headers.assert_called_once()
    mock_completion.assert_called_once_with(message, extra_state=None)
    # One body chunk per streamed word, plus two trailing chunks
    # (metadata and stream close).
    assert (
        mock_http_consumer_send.send_body.call_count
        == len(response.content.split(" ")) + 2
    )
    mock_http_consumer_send.send_body.assert_any_call(
        body=response.content.split(" ")[0].encode("utf-8"),
        more_body=True,
    )
Loading
Loading