From 147205abaa9c5572e2cae62ffabcac073292e791 Mon Sep 17 00:00:00 2001
From: Alejandro Ponce de Leon
Date: Thu, 16 Jan 2025 18:49:59 +0200
Subject: [PATCH] feat: Initial migration for Workspaces and pipeline step (#600)

* feat: Initial migration for Workspaces and pipeline step

Related: #454

We noticed that most incoming requests which contain a code snippet only list a relative path with respect to where the code editor is opened. This would make it difficult to accurately distinguish between repositories in Codegate.

For example, a user could open 2 different Python repositories in different sessions, with both repositories containing a `pyproject.toml`. It would be impossible for Codegate to determine the real repository of the file using only the relative path.

Hence, the initial implementation of Workspaces will rely on a pipeline step that is able to take commands and process them. Some commands could be:
- List workspaces
- Add workspace
- Switch active workspace
- Delete workspace

It would be up to the user to select the desired active workspace.

This PR introduces an initial migration for Workspaces and the pipeline step with the `list` command.

* Reformatting changes
* Make workspace names unique
* Introduced Sessions table and added add and activate commands
* Formatting changes and unit tests
* Classes separation into a different file
---
 .../5c2f3eee5f90_introduce_workspaces.py      |  61 +++++++
 src/codegate/__init__.py                      |   2 +
 src/codegate/cli.py                           |   3 +-
 src/codegate/db/connection.py                 | 145 +++++++++++++++-
 src/codegate/db/models.py                     |  48 +++++-
 src/codegate/pipeline/base.py                 |   1 +
 .../extract_snippets/extract_snippets.py      |   6 +-
 src/codegate/pipeline/factory.py              |   2 +
 src/codegate/pipeline/secrets/secrets.py      |   2 +-
 src/codegate/pipeline/workspace/__init__.py   |   0
 src/codegate/pipeline/workspace/commands.py   | 157 ++++++++++++++++++
 src/codegate/pipeline/workspace/workspace.py  |  58 +++++++
 .../providers/ollama/completion_handler.py    |   5 +-
 tests/pipeline/workspace/test_workspace.py    | 125 ++++++++++++++
 14 files changed, 589 insertions(+), 26 deletions(-)
 create mode 100644 migrations/versions/5c2f3eee5f90_introduce_workspaces.py
 create mode 100644 src/codegate/pipeline/workspace/__init__.py
 create mode 100644 src/codegate/pipeline/workspace/commands.py
 create mode 100644 src/codegate/pipeline/workspace/workspace.py
 create mode 100644 tests/pipeline/workspace/test_workspace.py

diff --git a/migrations/versions/5c2f3eee5f90_introduce_workspaces.py b/migrations/versions/5c2f3eee5f90_introduce_workspaces.py
new file mode 100644
index 00000000..9f1bba0a
--- /dev/null
+++ b/migrations/versions/5c2f3eee5f90_introduce_workspaces.py
@@ -0,0 +1,61 @@
+"""introduce workspaces
+
+Revision ID: 5c2f3eee5f90
+Revises: 30d0144e1a50
+Create Date: 2025-01-15 19:27:08.230296
+
+"""
+
+from typing import Sequence, Union
+
+from alembic import op
+
+# revision identifiers, used by Alembic.
+revision: str = "5c2f3eee5f90" +down_revision: Union[str, None] = "30d0144e1a50" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + # Workspaces table + op.execute( + """ + CREATE TABLE workspaces ( + id TEXT PRIMARY KEY, -- UUID stored as TEXT + name TEXT NOT NULL, + UNIQUE (name) + ); + """ + ) + op.execute("INSERT INTO workspaces (id, name) VALUES ('1', 'default');") + # Sessions table + op.execute( + """ + CREATE TABLE sessions ( + id TEXT PRIMARY KEY, -- UUID stored as TEXT + active_workspace_id TEXT NOT NULL, + last_update DATETIME NOT NULL, + FOREIGN KEY (active_workspace_id) REFERENCES workspaces(id) + ); + """ + ) + # Alter table prompts + op.execute("ALTER TABLE prompts ADD COLUMN workspace_id TEXT REFERENCES workspaces(id);") + op.execute("UPDATE prompts SET workspace_id = '1';") + # Create index for workspace_id + op.execute("CREATE INDEX idx_prompts_workspace_id ON prompts (workspace_id);") + # Create index for session_id + op.execute("CREATE INDEX idx_sessions_workspace_id ON sessions (active_workspace_id);") + + +def downgrade() -> None: + # Drop the index for workspace_id + op.execute("DROP INDEX IF EXISTS idx_prompts_workspace_id;") + op.execute("DROP INDEX IF EXISTS idx_sessions_workspace_id;") + # Remove the workspace_id column from prompts table + op.execute("ALTER TABLE prompts DROP COLUMN workspace_id;") + # Drop the sessions table + op.execute("DROP TABLE IF EXISTS sessions;") + # Drop the workspaces table + op.execute("DROP TABLE IF EXISTS workspaces;") diff --git a/src/codegate/__init__.py b/src/codegate/__init__.py index 24f8ce3c..62a7bb40 100644 --- a/src/codegate/__init__.py +++ b/src/codegate/__init__.py @@ -10,6 +10,7 @@ _VERSION = "dev" _DESC = "CodeGate - A Generative AI security gateway." 
+ def __get_version_and_description() -> tuple[str, str]: try: version = metadata.version("codegate") @@ -19,6 +20,7 @@ def __get_version_and_description() -> tuple[str, str]: description = _DESC return version, description + __version__, __description__ = __get_version_and_description() __all__ = ["Config", "ConfigurationError", "LogFormat", "LogLevel", "setup_logging"] diff --git a/src/codegate/cli.py b/src/codegate/cli.py index 06456a00..7bcd035d 100644 --- a/src/codegate/cli.py +++ b/src/codegate/cli.py @@ -14,7 +14,7 @@ from codegate.ca.codegate_ca import CertificateAuthority from codegate.codegate_logging import LogFormat, LogLevel, setup_logging from codegate.config import Config, ConfigurationError -from codegate.db.connection import init_db_sync +from codegate.db.connection import init_db_sync, init_session_if_not_exists from codegate.pipeline.factory import PipelineFactory from codegate.pipeline.secrets.manager import SecretsManager from codegate.providers.copilot.provider import CopilotProvider @@ -307,6 +307,7 @@ def serve( logger = structlog.get_logger("codegate").bind(origin="cli") init_db_sync(cfg.db_path) + init_session_if_not_exists(cfg.db_path) # Check certificates and create CA if necessary logger.info("Checking certificates and creating CA if needed") diff --git a/src/codegate/db/connection.py b/src/codegate/db/connection.py index 4894ad2a..d04dfcf9 100644 --- a/src/codegate/db/connection.py +++ b/src/codegate/db/connection.py @@ -1,23 +1,28 @@ import asyncio import json +import uuid from pathlib import Path from typing import List, Optional, Type import structlog from alembic import command as alembic_command from alembic.config import Config as AlembicConfig -from pydantic import BaseModel -from sqlalchemy import TextClause, text +from pydantic import BaseModel, ValidationError +from sqlalchemy import CursorResult, TextClause, text from sqlalchemy.exc import OperationalError from sqlalchemy.ext.asyncio import create_async_engine from codegate.db.fim_cache import FimCache from codegate.db.models import ( + ActiveWorkspace, Alert, GetAlertsWithPromptAndOutputRow, GetPromptWithOutputsRow, Output, Prompt, + Session, + Workspace, + WorkspaceActive, ) from codegate.pipeline.base import PipelineContext @@ -75,10 +80,14 @@ async def _execute_update_pydantic_model( async def record_request(self, prompt_params: Optional[Prompt] = None) -> Optional[Prompt]: if prompt_params is None: return None + # Get the active workspace to store the request + active_workspace = await DbReader().get_active_workspace() + workspace_id = active_workspace.id if active_workspace else "1" + prompt_params.workspace_id = workspace_id sql = text( """ - INSERT INTO prompts (id, timestamp, provider, request, type) - VALUES (:id, :timestamp, :provider, :request, :type) + INSERT INTO prompts (id, timestamp, provider, request, type, workspace_id) + VALUES (:id, :timestamp, :provider, :request, :type, :workspace_id) RETURNING * """ ) @@ -223,26 +232,78 @@ async def record_context(self, context: Optional[PipelineContext]) -> None: except Exception as e: logger.error(f"Failed to record context: {context}.", error=str(e)) + async def add_workspace(self, workspace_name: str) -> Optional[Workspace]: + try: + workspace = Workspace(id=str(uuid.uuid4()), name=workspace_name) + except ValidationError as e: + logger.error(f"Failed to create workspace with name: {workspace_name}: {str(e)}") + return None + + sql = text( + """ + INSERT INTO workspaces (id, name) + VALUES (:id, :name) + RETURNING * + """ + ) + 
added_workspace = await self._execute_update_pydantic_model(workspace, sql) + return added_workspace + + async def update_session(self, session: Session) -> Optional[Session]: + sql = text( + """ + INSERT INTO sessions (id, active_workspace_id, last_update) + VALUES (:id, :active_workspace_id, :last_update) + ON CONFLICT (id) DO UPDATE SET + active_workspace_id = excluded.active_workspace_id, last_update = excluded.last_update + WHERE id = excluded.id + RETURNING * + """ + ) + # We only pass an object to respect the signature of the function + active_session = await self._execute_update_pydantic_model(session, sql) + return active_session + class DbReader(DbCodeGate): def __init__(self, sqlite_path: Optional[str] = None): super().__init__(sqlite_path) + async def _dump_result_to_pydantic_model( + self, model_type: Type[BaseModel], result: CursorResult + ) -> Optional[List[BaseModel]]: + try: + if not result: + return None + rows = [model_type(**row._asdict()) for row in result.fetchall() if row] + return rows + except Exception as e: + logger.error(f"Failed to dump to pydantic model: {model_type}.", error=str(e)) + return None + async def _execute_select_pydantic_model( self, model_type: Type[BaseModel], sql_command: TextClause - ) -> Optional[BaseModel]: + ) -> Optional[List[BaseModel]]: async with self._async_db_engine.begin() as conn: try: result = await conn.execute(sql_command) - if not result: - return None - rows = [model_type(**row._asdict()) for row in result.fetchall() if row] - return rows + return await self._dump_result_to_pydantic_model(model_type, result) except Exception as e: logger.error(f"Failed to select model: {model_type}.", error=str(e)) return None + async def _exec_select_conditions_to_pydantic( + self, model_type: Type[BaseModel], sql_command: TextClause, conditions: dict + ) -> Optional[List[BaseModel]]: + async with self._async_db_engine.begin() as conn: + try: + result = await conn.execute(sql_command, conditions) + return await self._dump_result_to_pydantic_model(model_type, result) + except Exception as e: + logger.error(f"Failed to select model with conditions: {model_type}.", error=str(e)) + return None + async def get_prompts_with_output(self) -> List[GetPromptWithOutputsRow]: sql = text( """ @@ -286,6 +347,54 @@ async def get_alerts_with_prompt_and_output(self) -> List[GetAlertsWithPromptAnd prompts = await self._execute_select_pydantic_model(GetAlertsWithPromptAndOutputRow, sql) return prompts + async def get_workspaces(self) -> List[WorkspaceActive]: + sql = text( + """ + SELECT + w.id, w.name, s.active_workspace_id + FROM workspaces w + LEFT JOIN sessions s ON w.id = s.active_workspace_id + """ + ) + workspaces = await self._execute_select_pydantic_model(WorkspaceActive, sql) + return workspaces + + async def get_workspace_by_name(self, name: str) -> List[Workspace]: + sql = text( + """ + SELECT + id, name + FROM workspaces + WHERE name = :name + """ + ) + conditions = {"name": name} + workspaces = await self._exec_select_conditions_to_pydantic(Workspace, sql, conditions) + return workspaces[0] if workspaces else None + + async def get_sessions(self) -> List[Session]: + sql = text( + """ + SELECT + id, active_workspace_id, last_update + FROM sessions + """ + ) + sessions = await self._execute_select_pydantic_model(Session, sql) + return sessions + + async def get_active_workspace(self) -> Optional[ActiveWorkspace]: + sql = text( + """ + SELECT + w.id, w.name, s.id as session_id, s.last_update + FROM sessions s + INNER JOIN workspaces w ON w.id = 
s.active_workspace_id + """ + ) + active_workspace = await self._execute_select_pydantic_model(ActiveWorkspace, sql) + return active_workspace[0] if active_workspace else None + def init_db_sync(db_path: Optional[str] = None): """DB will be initialized in the constructor in case it doesn't exist.""" @@ -307,5 +416,23 @@ def init_db_sync(db_path: Optional[str] = None): logger.info("DB initialized successfully.") +def init_session_if_not_exists(db_path: Optional[str] = None): + import datetime + + db_reader = DbReader(db_path) + sessions = asyncio.run(db_reader.get_sessions()) + # If there are no sessions, create a new one + # TODO: For the moment there's a single session. If it already exists, we don't create a new one + if not sessions: + session = Session( + id=str(uuid.uuid4()), + active_workspace_id="1", + last_update=datetime.datetime.now(datetime.timezone.utc), + ) + db_recorder = DbRecorder(db_path) + asyncio.run(db_recorder.update_session(session)) + logger.info("Session in DB initialized successfully.") + + if __name__ == "__main__": init_db_sync() diff --git a/src/codegate/db/models.py b/src/codegate/db/models.py index 22859573..fe5dbb68 100644 --- a/src/codegate/db/models.py +++ b/src/codegate/db/models.py @@ -1,9 +1,11 @@ +import datetime +import re from typing import Any, Optional -import pydantic +from pydantic import BaseModel, field_validator -class Alert(pydantic.BaseModel): +class Alert(BaseModel): id: Any prompt_id: Any code_snippet: Optional[Any] @@ -13,22 +15,23 @@ class Alert(pydantic.BaseModel): timestamp: Any -class Output(pydantic.BaseModel): +class Output(BaseModel): id: Any prompt_id: Any timestamp: Any output: Any -class Prompt(pydantic.BaseModel): +class Prompt(BaseModel): id: Any timestamp: Any provider: Optional[Any] request: Any type: Any + workspace_id: Optional[str] -class Setting(pydantic.BaseModel): +class Setting(BaseModel): id: Any ip: Optional[Any] port: Optional[Any] @@ -37,10 +40,28 @@ class Setting(pydantic.BaseModel): other_settings: Optional[Any] +class Workspace(BaseModel): + id: str + name: str + + @field_validator("name", mode="plain") + @classmethod + def name_must_be_alphanumeric(cls, value): + if not re.match(r"^[a-zA-Z0-9_-]+$", value): + raise ValueError("name must be alphanumeric and can only contain _ and -") + return value + + +class Session(BaseModel): + id: str + active_workspace_id: str + last_update: datetime.datetime + + # Models for select queries -class GetAlertsWithPromptAndOutputRow(pydantic.BaseModel): +class GetAlertsWithPromptAndOutputRow(BaseModel): id: Any prompt_id: Any code_snippet: Optional[Any] @@ -57,7 +78,7 @@ class GetAlertsWithPromptAndOutputRow(pydantic.BaseModel): output_timestamp: Optional[Any] -class GetPromptWithOutputsRow(pydantic.BaseModel): +class GetPromptWithOutputsRow(BaseModel): id: Any timestamp: Any provider: Optional[Any] @@ -66,3 +87,16 @@ class GetPromptWithOutputsRow(pydantic.BaseModel): output_id: Optional[Any] output: Optional[Any] output_timestamp: Optional[Any] + + +class WorkspaceActive(BaseModel): + id: str + name: str + active_workspace_id: Optional[str] + + +class ActiveWorkspace(BaseModel): + id: str + name: str + session_id: str + last_update: datetime.datetime diff --git a/src/codegate/pipeline/base.py b/src/codegate/pipeline/base.py index e22b2915..f0e13196 100644 --- a/src/codegate/pipeline/base.py +++ b/src/codegate/pipeline/base.py @@ -135,6 +135,7 @@ def add_input_request( provider=provider, type="fim" if is_fim_request else "chat", request=request_str, + workspace_id=None, ) # 
Uncomment the below to debug the input # logger.debug(f"Added input request to context: {self.input_request}") diff --git a/src/codegate/pipeline/extract_snippets/extract_snippets.py b/src/codegate/pipeline/extract_snippets/extract_snippets.py index 06e06bea..8f7ebbd7 100644 --- a/src/codegate/pipeline/extract_snippets/extract_snippets.py +++ b/src/codegate/pipeline/extract_snippets/extract_snippets.py @@ -124,10 +124,8 @@ def extract_snippets(message: str) -> List[CodeSnippet]: lang = None #  just correct the typescript exception - lang_map = { - "typescript": "javascript" - } - lang = lang_map.get(lang, lang) + lang_map = {"typescript": "javascript"} + lang = lang_map.get(lang, lang) snippets.append(CodeSnippet(filepath=filename, code=content, language=lang)) return snippets diff --git a/src/codegate/pipeline/factory.py b/src/codegate/pipeline/factory.py index 7a713332..cd14df8e 100644 --- a/src/codegate/pipeline/factory.py +++ b/src/codegate/pipeline/factory.py @@ -14,6 +14,7 @@ ) from codegate.pipeline.system_prompt.codegate import SystemPrompt from codegate.pipeline.version.version import CodegateVersion +from codegate.pipeline.workspace.workspace import CodegateWorkspace class PipelineFactory: @@ -28,6 +29,7 @@ def create_input_pipeline(self) -> SequentialPipelineProcessor: # later steps CodegateSecrets(), CodegateVersion(), + CodegateWorkspace(), CodeSnippetExtractor(), CodegateContextRetriever(), SystemPrompt(Config.get_config().prompts.default_chat), diff --git a/src/codegate/pipeline/secrets/secrets.py b/src/codegate/pipeline/secrets/secrets.py index 49f0627a..fce826c1 100644 --- a/src/codegate/pipeline/secrets/secrets.py +++ b/src/codegate/pipeline/secrets/secrets.py @@ -366,7 +366,7 @@ async def process_chunk( if match: # Found a complete marker, process it encrypted_value = match.group(1) - if encrypted_value.startswith('$'): + if encrypted_value.startswith("$"): encrypted_value = encrypted_value[1:] original_value = input_context.sensitive.manager.get_original_value( encrypted_value, diff --git a/src/codegate/pipeline/workspace/__init__.py b/src/codegate/pipeline/workspace/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/codegate/pipeline/workspace/commands.py b/src/codegate/pipeline/workspace/commands.py new file mode 100644 index 00000000..9651db8f --- /dev/null +++ b/src/codegate/pipeline/workspace/commands.py @@ -0,0 +1,157 @@ +import datetime +from typing import Optional, Tuple + +from codegate.db.connection import DbReader, DbRecorder +from codegate.db.models import Session, Workspace + + +class WorkspaceCrud: + + def __init__(self): + self._db_reader = DbReader() + + async def add_workspace(self, new_workspace_name: str) -> bool: + """ + Add a workspace + + Args: + name (str): The name of the workspace + """ + db_recorder = DbRecorder() + workspace_created = await db_recorder.add_workspace(new_workspace_name) + return bool(workspace_created) + + async def get_workspaces(self): + """ + Get all workspaces + """ + return await self._db_reader.get_workspaces() + + async def _is_workspace_active_or_not_exist( + self, workspace_name: str + ) -> Tuple[bool, Optional[Session], Optional[Workspace]]: + """ + Check if the workspace is active + + Will return: + - True if the workspace was activated + - False if the workspace is already active or does not exist + """ + selected_workspace = await self._db_reader.get_workspace_by_name(workspace_name) + if not selected_workspace: + return True, None, None + + sessions = await self._db_reader.get_sessions() + 
# The current implementation expects only one active session + if len(sessions) != 1: + raise RuntimeError("Something went wrong. No active session found.") + + session = sessions[0] + if session.active_workspace_id == selected_workspace.id: + return True, None, None + return False, session, selected_workspace + + async def activate_workspace(self, workspace_name: str) -> bool: + """ + Activate a workspace + + Will return: + - True if the workspace was activated + - False if the workspace is already active or does not exist + """ + is_active, session, workspace = await self._is_workspace_active_or_not_exist(workspace_name) + if is_active: + return False + + session.active_workspace_id = workspace.id + session.last_update = datetime.datetime.now(datetime.timezone.utc) + db_recorder = DbRecorder() + await db_recorder.update_session(session) + return True + + +class WorkspaceCommands: + + def __init__(self): + self.workspace_crud = WorkspaceCrud() + self.commands = { + "list": self._list_workspaces, + "add": self._add_workspace, + "activate": self._activate_workspace, + } + + async def _list_workspaces(self, *args) -> str: + """ + List all workspaces + """ + workspaces = await self.workspace_crud.get_workspaces() + respond_str = "" + for workspace in workspaces: + respond_str += f"- {workspace.name}" + if workspace.active_workspace_id: + respond_str += " **(active)**" + respond_str += "\n" + return respond_str + + async def _add_workspace(self, *args) -> str: + """ + Add a workspace + """ + if args is None or len(args) == 0: + return "Please provide a name. Use `codegate-workspace add your_workspace_name`" + + new_workspace_name = args[0] + if not new_workspace_name: + return "Please provide a name. Use `codegate-workspace add your_workspace_name`" + + workspace_created = await self.workspace_crud.add_workspace(new_workspace_name) + if not workspace_created: + return ( + "Something went wrong. Workspace could not be added.\n" + "1. Check if the name is alphanumeric and only contains dashes, and underscores.\n" + "2. Check if the workspace already exists." + ) + return f"Workspace **{new_workspace_name}** has been added" + + async def _activate_workspace(self, *args) -> str: + """ + Activate a workspace + """ + if args is None or len(args) == 0: + return "Please provide a name. Use `codegate-workspace activate workspace_name`" + + workspace_name = args[0] + if not workspace_name: + return "Please provide a name. Use `codegate-workspace activate workspace_name`" + + was_activated = await self.workspace_crud.activate_workspace(workspace_name) + if not was_activated: + return ( + f"Workspace **{workspace_name}** does not exist or was already active. 
" + f"Use `codegate-workspace add {workspace_name}` to add it" + ) + return f"Workspace **{workspace_name}** has been activated" + + async def execute(self, command: str, *args) -> str: + """ + Execute the given command + + Args: + command (str): The command to execute + """ + command_to_execute = self.commands.get(command) + if command_to_execute is not None: + return await command_to_execute(*args) + else: + return "Command not found" + + async def parse_execute_cmd(self, last_user_message: str) -> str: + """ + Parse the last user message and execute the command + + Args: + last_user_message (str): The last user message + """ + command_and_args = last_user_message.lower().split("codegate-workspace ")[1] + command, *args = command_and_args.split(" ") + return await self.execute(command, *args) diff --git a/src/codegate/pipeline/workspace/workspace.py b/src/codegate/pipeline/workspace/workspace.py new file mode 100644 index 00000000..5cde5177 --- /dev/null +++ b/src/codegate/pipeline/workspace/workspace.py @@ -0,0 +1,58 @@ +from litellm import ChatCompletionRequest + +from codegate.pipeline.base import ( + PipelineContext, + PipelineResponse, + PipelineResult, + PipelineStep, +) +from codegate.pipeline.workspace.commands import WorkspaceCommands + + +class CodegateWorkspace(PipelineStep): + """Pipeline step that handles workspace information requests.""" + + @property + def name(self) -> str: + """ + Returns the name of this pipeline step. + + Returns: + str: The identifier 'codegate-workspace' + """ + return "codegate-workspace" + + async def process( + self, request: ChatCompletionRequest, context: PipelineContext + ) -> PipelineResult: + """ + Checks if the last user message contains "codegate-workspace" and + responds with command specified. + This short-circuits the pipeline if the message is found. 
+ + Args: + request (ChatCompletionRequest): The chat completion request to process + context (PipelineContext): The current pipeline context + + Returns: + PipelineResult: Contains workspace response if triggered, otherwise continues + pipeline + """ + last_user_message = self.get_last_user_message(request) + + if last_user_message is not None: + last_user_message_str, _ = last_user_message + if "codegate-workspace" in last_user_message_str.lower(): + context.shortcut_response = True + command_output = await WorkspaceCommands().parse_execute_cmd(last_user_message_str) + return PipelineResult( + response=PipelineResponse( + step_name=self.name, + content=command_output, + model=request["model"], + ), + context=context, + ) + + # Fall through + return PipelineResult(request=request, context=context) diff --git a/src/codegate/providers/ollama/completion_handler.py b/src/codegate/providers/ollama/completion_handler.py index bccd4992..4c48f614 100644 --- a/src/codegate/providers/ollama/completion_handler.py +++ b/src/codegate/providers/ollama/completion_handler.py @@ -1,4 +1,3 @@ -import json from typing import AsyncIterator, Optional, Union import structlog @@ -11,9 +10,7 @@ logger = structlog.get_logger("codegate") -async def ollama_stream_generator( - stream: AsyncIterator[ChatResponse] -) -> AsyncIterator[str]: +async def ollama_stream_generator(stream: AsyncIterator[ChatResponse]) -> AsyncIterator[str]: """OpenAI-style SSE format""" try: async for chunk in stream: diff --git a/tests/pipeline/workspace/test_workspace.py b/tests/pipeline/workspace/test_workspace.py new file mode 100644 index 00000000..85f10edc --- /dev/null +++ b/tests/pipeline/workspace/test_workspace.py @@ -0,0 +1,125 @@ +import datetime +from unittest.mock import AsyncMock, patch + +import pytest + +from codegate.db.models import Session, Workspace, WorkspaceActive +from codegate.pipeline.workspace.commands import WorkspaceCommands, WorkspaceCrud + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "mock_workspaces, expected_output", + [ + # Case 1: No workspaces + ([], ""), + # Case 2: One workspace active + ( + [ + # We'll make a MagicMock that simulates a workspace + # with 'name' attribute and 'active_workspace_id' set + WorkspaceActive(id="1", name="Workspace1", active_workspace_id="100") + ], + "- Workspace1 **(active)**\n", + ), + # Case 3: Multiple workspaces, second one active + ( + [ + WorkspaceActive(id="1", name="Workspace1", active_workspace_id=None), + WorkspaceActive(id="2", name="Workspace2", active_workspace_id="200"), + ], + "- Workspace1\n- Workspace2 **(active)**\n", + ), + ], +) +async def test_list_workspaces(mock_workspaces, expected_output): + """ + Test _list_workspaces with different sets of returned workspaces. + """ + workspace_commands = WorkspaceCommands() + + # Mock DbReader inside workspace_commands + mock_get_workspaces = AsyncMock(return_value=mock_workspaces) + workspace_commands.workspace_crud.get_workspaces = mock_get_workspaces + + # Call the method + result = await workspace_commands._list_workspaces() + + # Check the result + assert result == expected_output + mock_get_workspaces.assert_awaited_once() + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "args, existing_workspaces, expected_message", + [ + # Case 1: No workspace name provided + ([], [], "Please provide a name. Use `codegate-workspace add your_workspace_name`"), + # Case 2: Workspace name is empty string + ([""], [], "Please provide a name. 
Use `codegate-workspace add your_workspace_name`"), + # Case 3: Successful add + (["myworkspace"], [], "Workspace **myworkspace** has been added"), + ], +) +async def test_add_workspaces(args, existing_workspaces, expected_message): + """ + Test _add_workspace under different scenarios: + - no args + - empty string arg + - workspace already exists + - workspace successfully added + """ + workspace_commands = WorkspaceCommands() + + # Mock the DbReader to return existing_workspaces + mock_db_reader = AsyncMock() + mock_db_reader.get_workspace_by_name.return_value = existing_workspaces + workspace_commands._db_reader = mock_db_reader + + # We'll also patch DbRecorder to ensure no real DB operations happen + with patch( + "codegate.pipeline.workspace.commands.DbRecorder", autospec=True + ) as mock_recorder_cls: + mock_recorder = mock_recorder_cls.return_value + mock_recorder.add_workspace = AsyncMock() + + # Call the method + result = await workspace_commands._add_workspace(*args) + + # Assertions + assert result == expected_message + + # If expected_message indicates "added", we expect add_workspace to be called once + if "has been added" in expected_message: + mock_recorder.add_workspace.assert_awaited_once_with(args[0]) + else: + mock_recorder.add_workspace.assert_not_awaited() + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "user_message, expected_command, expected_args, mocked_execute_response", + [ + ("codegate-workspace list", "list", [], "List workspaces output"), + ("codegate-workspace add myws", "add", ["myws"], "Added workspace"), + ("codegate-workspace activate myws", "activate", ["myws"], "Activated workspace"), + ], +) +async def test_parse_execute_cmd( + user_message, expected_command, expected_args, mocked_execute_response +): + """ + Test parse_execute_cmd to ensure it parses the user message + and calls the correct command with the correct args. + """ + workspace_commands = WorkspaceCommands() + + with patch.object( + workspace_commands, "execute", return_value=mocked_execute_response + ) as mock_execute: + result = await workspace_commands.parse_execute_cmd(user_message) + assert result == mocked_execute_response + + # Verify 'execute' was called with the expected command and args + mock_execute.assert_awaited_once_with(expected_command, *expected_args)
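For reviewers who want to poke at the new command handling without wiring up a database, below is a minimal sketch (not part of the patch) that exercises the `WorkspaceCommands` entry point added above. It stubs the CRUD layer with `AsyncMock`, mirroring the approach used in the tests; the workspace name and IDs are purely illustrative.

    import asyncio
    from unittest.mock import AsyncMock

    from codegate.db.models import WorkspaceActive
    from codegate.pipeline.workspace.commands import WorkspaceCommands


    async def main() -> None:
        commands = WorkspaceCommands()

        # Stub the CRUD layer so this example never touches the real SQLite DB.
        commands.workspace_crud.get_workspaces = AsyncMock(
            return_value=[WorkspaceActive(id="1", name="default", active_workspace_id="1")]
        )

        # parse_execute_cmd() splits the message on "codegate-workspace " and
        # dispatches to the matching handler in the command table ("list" here).
        print(await commands.parse_execute_cmd("codegate-workspace list"))
        # Prints: - default **(active)**


    asyncio.run(main())

This is the same flow the `CodegateWorkspace` pipeline step triggers when it detects "codegate-workspace" in the last user message.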