Skip to content

Commit

Permalink
feat: Initial migration for Workspaces and pipeline step (#600)
Browse files Browse the repository at this point in the history
* feat: Initial migration for Workspaces and pipeline step

Related: #454

We noticed most of the incoming-requests which contain a code-snippet
only list a relative path with respect to where the code editor is
opened. This would make it difficult to accurately distinguish between
repositories in Codegate. For example, a user could open 2 different
Python code repositories in different sessions and both repositories
contain a `pyproject.toml`. It would be impossible for Codegate to
determine the real repository of the file only using the relative path.

Hence, the initial implementation of Workspaces will rely on a pipeline
step that is able to take commands and process them. Some commands could be:
- List workspaces
- Add workspace
- Switch active workspace
- Delete workspace

It would be up to the user to select the desired active workspace.

This PR introduces an initial migration for Workspaces and the pipeline
step with the `list` command.

* Reformatting changes

* Make unique workspaces name

* Introduced Sessions table and added add and activate commands

* Formatting changes and unit tests

* Classes separation into a different file
  • Loading branch information
aponcedeleonch authored Jan 16, 2025
1 parent b4d719f commit 147205a
Show file tree
Hide file tree
Showing 14 changed files with 589 additions and 26 deletions.
61 changes: 61 additions & 0 deletions migrations/versions/5c2f3eee5f90_introduce_workspaces.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
"""introduce workspaces
Revision ID: 5c2f3eee5f90
Revises: 30d0144e1a50
Create Date: 2025-01-15 19:27:08.230296
"""

from typing import Sequence, Union

from alembic import op

# revision identifiers, used by Alembic.
# `revision` is this migration's id; `down_revision` is the migration it follows
# (the "Revises" entry in the module docstring).
revision: str = "5c2f3eee5f90"
down_revision: Union[str, None] = "30d0144e1a50"
# No branch labels or dependencies on other revisions are declared here.
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None


def upgrade() -> None:
    """Introduce workspaces and sessions, and attach prompts to a workspace.

    The statements run in order:
    1. create ``workspaces`` (unique ``name``) and seed the 'default' workspace
       with id '1';
    2. create ``sessions``, whose ``active_workspace_id`` references
       ``workspaces(id)``;
    3. add ``prompts.workspace_id`` and point every existing prompt at the
       default workspace;
    4. index the two workspace foreign-key columns.
    """
    statements = (
        # Workspaces table.
        """
        CREATE TABLE workspaces (
            id TEXT PRIMARY KEY, -- UUID stored as TEXT
            name TEXT NOT NULL,
            UNIQUE (name)
        );
        """,
        # Seed the default workspace that existing data is assigned to.
        "INSERT INTO workspaces (id, name) VALUES ('1', 'default');",
        # Sessions table.
        """
        CREATE TABLE sessions (
            id TEXT PRIMARY KEY, -- UUID stored as TEXT
            active_workspace_id TEXT NOT NULL,
            last_update DATETIME NOT NULL,
            FOREIGN KEY (active_workspace_id) REFERENCES workspaces(id)
        );
        """,
        # Link prompts to a workspace and backfill existing rows.
        "ALTER TABLE prompts ADD COLUMN workspace_id TEXT REFERENCES workspaces(id);",
        "UPDATE prompts SET workspace_id = '1';",
        # Index prompts by workspace.
        "CREATE INDEX idx_prompts_workspace_id ON prompts (workspace_id);",
        # Index sessions by their active workspace.
        "CREATE INDEX idx_sessions_workspace_id ON sessions (active_workspace_id);",
    )
    for statement in statements:
        op.execute(statement)


def downgrade() -> None:
    """Undo ``upgrade``: drop the indexes, the prompts column, then both tables.

    Order matters: the indexes go first, then ``prompts.workspace_id``, then
    ``sessions`` (which references ``workspaces``), and finally ``workspaces``.
    """
    for statement in (
        # Indexes on the workspace columns.
        "DROP INDEX IF EXISTS idx_prompts_workspace_id;",
        "DROP INDEX IF EXISTS idx_sessions_workspace_id;",
        # Column added to the prompts table.
        "ALTER TABLE prompts DROP COLUMN workspace_id;",
        # Sessions before workspaces, since sessions references workspaces.
        "DROP TABLE IF EXISTS sessions;",
        "DROP TABLE IF EXISTS workspaces;",
    ):
        op.execute(statement)
2 changes: 2 additions & 0 deletions src/codegate/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
_VERSION = "dev"
_DESC = "CodeGate - A Generative AI security gateway."


def __get_version_and_description() -> tuple[str, str]:
try:
version = metadata.version("codegate")
Expand All @@ -19,6 +20,7 @@ def __get_version_and_description() -> tuple[str, str]:
description = _DESC
return version, description


__version__, __description__ = __get_version_and_description()

__all__ = ["Config", "ConfigurationError", "LogFormat", "LogLevel", "setup_logging"]
Expand Down
3 changes: 2 additions & 1 deletion src/codegate/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from codegate.ca.codegate_ca import CertificateAuthority
from codegate.codegate_logging import LogFormat, LogLevel, setup_logging
from codegate.config import Config, ConfigurationError
from codegate.db.connection import init_db_sync
from codegate.db.connection import init_db_sync, init_session_if_not_exists
from codegate.pipeline.factory import PipelineFactory
from codegate.pipeline.secrets.manager import SecretsManager
from codegate.providers.copilot.provider import CopilotProvider
Expand Down Expand Up @@ -307,6 +307,7 @@ def serve(
logger = structlog.get_logger("codegate").bind(origin="cli")

init_db_sync(cfg.db_path)
init_session_if_not_exists(cfg.db_path)

# Check certificates and create CA if necessary
logger.info("Checking certificates and creating CA if needed")
Expand Down
145 changes: 136 additions & 9 deletions src/codegate/db/connection.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,28 @@
import asyncio
import json
import uuid
from pathlib import Path
from typing import List, Optional, Type

import structlog
from alembic import command as alembic_command
from alembic.config import Config as AlembicConfig
from pydantic import BaseModel
from sqlalchemy import TextClause, text
from pydantic import BaseModel, ValidationError
from sqlalchemy import CursorResult, TextClause, text
from sqlalchemy.exc import OperationalError
from sqlalchemy.ext.asyncio import create_async_engine

from codegate.db.fim_cache import FimCache
from codegate.db.models import (
ActiveWorkspace,
Alert,
GetAlertsWithPromptAndOutputRow,
GetPromptWithOutputsRow,
Output,
Prompt,
Session,
Workspace,
WorkspaceActive,
)
from codegate.pipeline.base import PipelineContext

Expand Down Expand Up @@ -75,10 +80,14 @@ async def _execute_update_pydantic_model(
async def record_request(self, prompt_params: Optional[Prompt] = None) -> Optional[Prompt]:
if prompt_params is None:
return None
# Get the active workspace to store the request
active_workspace = await DbReader().get_active_workspace()
workspace_id = active_workspace.id if active_workspace else "1"
prompt_params.workspace_id = workspace_id
sql = text(
"""
INSERT INTO prompts (id, timestamp, provider, request, type)
VALUES (:id, :timestamp, :provider, :request, :type)
INSERT INTO prompts (id, timestamp, provider, request, type, workspace_id)
VALUES (:id, :timestamp, :provider, :request, :type, :workspace_id)
RETURNING *
"""
)
Expand Down Expand Up @@ -223,26 +232,78 @@ async def record_context(self, context: Optional[PipelineContext]) -> None:
except Exception as e:
logger.error(f"Failed to record context: {context}.", error=str(e))

async def add_workspace(self, workspace_name: str) -> Optional[Workspace]:
    """Insert a workspace called ``workspace_name`` under a freshly generated UUID.

    Returns None, after logging the error, when the ``Workspace`` model rejects
    the input. Otherwise returns the result of
    ``_execute_update_pydantic_model`` for the INSERT ... RETURNING statement.
    """
    try:
        new_workspace = Workspace(id=str(uuid.uuid4()), name=workspace_name)
    except ValidationError as e:
        logger.error(f"Failed to create workspace with name: {workspace_name}: {str(e)}")
        return None

    insert_stmt = text(
        """
        INSERT INTO workspaces (id, name)
        VALUES (:id, :name)
        RETURNING *
        """
    )
    return await self._execute_update_pydantic_model(new_workspace, insert_stmt)

async def update_session(self, session: Session) -> Optional[Session]:
    """Upsert ``session``: insert it, or on an id conflict overwrite the stored row.

    On conflict the stored ``active_workspace_id`` and ``last_update`` are
    replaced by the incoming values. Returns the result of
    ``_execute_update_pydantic_model`` for the RETURNING statement.
    """
    # NOTE(review): the trailing WHERE looks redundant for a conflict on id - confirm.
    upsert_stmt = text(
        """
        INSERT INTO sessions (id, active_workspace_id, last_update)
        VALUES (:id, :active_workspace_id, :last_update)
        ON CONFLICT (id) DO UPDATE SET
        active_workspace_id = excluded.active_workspace_id, last_update = excluded.last_update
        WHERE id = excluded.id
        RETURNING *
        """
    )
    return await self._execute_update_pydantic_model(session, upsert_stmt)


class DbReader(DbCodeGate):

def __init__(self, sqlite_path: Optional[str] = None):
    """Forward the optional ``sqlite_path`` to the ``DbCodeGate`` initializer."""
    super().__init__(sqlite_path)

async def _dump_result_to_pydantic_model(
    self, model_type: Type[BaseModel], result: CursorResult
) -> Optional[List[BaseModel]]:
    """Build one ``model_type`` instance per truthy row fetched from ``result``.

    Returns None when ``result`` is falsy, or when fetching or model
    construction raises (the error is logged). Otherwise returns the list of
    models, which may be empty.
    """
    try:
        if not result:
            return None
        models = []
        for record in result.fetchall():
            if record:
                # Each row is expanded into keyword arguments for the model.
                models.append(model_type(**record._asdict()))
        return models
    except Exception as e:
        logger.error(f"Failed to dump to pydantic model: {model_type}.", error=str(e))
        return None

async def _execute_select_pydantic_model(
    self, model_type: Type[BaseModel], sql_command: TextClause
) -> Optional[List[BaseModel]]:
    """Run a parameterless SELECT and convert its rows into ``model_type`` models.

    Args:
        model_type: Pydantic model class instantiated once per row.
        sql_command: The statement to execute.

    Returns:
        The list of models produced by ``_dump_result_to_pydantic_model``, or
        None if execution raises (the error is logged).
    """
    async with self._async_db_engine.begin() as conn:
        try:
            result = await conn.execute(sql_command)
            # Row-to-model conversion lives in one shared helper.
            return await self._dump_result_to_pydantic_model(model_type, result)
        except Exception as e:
            logger.error(f"Failed to select model: {model_type}.", error=str(e))
            return None

async def _exec_select_conditions_to_pydantic(
    self, model_type: Type[BaseModel], sql_command: TextClause, conditions: dict
) -> Optional[List[BaseModel]]:
    """Run a SELECT with bound parameters and convert its rows into models.

    ``conditions`` is passed to ``execute`` alongside ``sql_command``. Returns
    the output of ``_dump_result_to_pydantic_model``, or None if execution
    raises (the error is logged).
    """
    async with self._async_db_engine.begin() as connection:
        try:
            cursor = await connection.execute(sql_command, conditions)
            models = await self._dump_result_to_pydantic_model(model_type, cursor)
            return models
        except Exception as e:
            logger.error(f"Failed to select model with conditions: {model_type}.", error=str(e))
            return None

async def get_prompts_with_output(self) -> List[GetPromptWithOutputsRow]:
sql = text(
"""
Expand Down Expand Up @@ -286,6 +347,54 @@ async def get_alerts_with_prompt_and_output(self) -> List[GetAlertsWithPromptAnd
prompts = await self._execute_select_pydantic_model(GetAlertsWithPromptAndOutputRow, sql)
return prompts

async def get_workspaces(self) -> List[WorkspaceActive]:
    """List every workspace together with the session column that marks it active.

    The LEFT JOIN keeps workspaces that no session points to; for those rows
    ``active_workspace_id`` is NULL. The result comes straight from
    ``_execute_select_pydantic_model``, which returns None if the query fails.
    """
    query = text(
        """
        SELECT
            w.id, w.name, s.active_workspace_id
        FROM workspaces w
        LEFT JOIN sessions s ON w.id = s.active_workspace_id
        """
    )
    return await self._execute_select_pydantic_model(WorkspaceActive, query)

async def get_workspace_by_name(self, name: str) -> Optional[Workspace]:
    """Look up a single workspace by its name.

    Args:
        name: Workspace name, bound to the ``:name`` parameter of the query.

    Returns:
        The first matching ``Workspace``, or None when no row matches or the
        query helper returns nothing.
    """
    sql = text(
        """
        SELECT
            id, name
        FROM workspaces
        WHERE name = :name
        """
    )
    conditions = {"name": name}
    workspaces = await self._exec_select_conditions_to_pydantic(Workspace, sql, conditions)
    # The helper yields a list (or None); callers get one workspace, not the list.
    return workspaces[0] if workspaces else None

async def get_sessions(self) -> List[Session]:
    """Fetch every row of the ``sessions`` table as ``Session`` models.

    The result comes straight from ``_execute_select_pydantic_model``, which
    returns None if the query fails.
    """
    query = text(
        """
        SELECT
            id, active_workspace_id, last_update
        FROM sessions
        """
    )
    return await self._execute_select_pydantic_model(Session, query)

async def get_active_workspace(self) -> Optional[ActiveWorkspace]:
    """Return the workspace that a session currently points to, if any.

    Joins ``sessions`` with ``workspaces`` on ``active_workspace_id`` and
    returns the first row as an ``ActiveWorkspace``. Returns None when the
    query yields nothing.
    """
    query = text(
        """
        SELECT
            w.id, w.name, s.id as session_id, s.last_update
        FROM sessions s
        INNER JOIN workspaces w ON w.id = s.active_workspace_id
        """
    )
    rows = await self._execute_select_pydantic_model(ActiveWorkspace, query)
    if not rows:
        return None
    return rows[0]


def init_db_sync(db_path: Optional[str] = None):
"""DB will be initialized in the constructor in case it doesn't exist."""
Expand All @@ -307,5 +416,23 @@ def init_db_sync(db_path: Optional[str] = None):
logger.info("DB initialized successfully.")


def init_session_if_not_exists(db_path: Optional[str] = None):
    """Create the single session row when the sessions table has none.

    Reads the existing sessions through ``DbReader``. If any exist, nothing is
    written. Otherwise a session with a fresh UUID, active workspace "1" and
    the current UTC time is stored through ``DbRecorder.update_session``.
    """
    import datetime

    existing_sessions = asyncio.run(DbReader(db_path).get_sessions())
    # TODO: For the moment there's a single session. If it already exists, we don't create a new one
    if existing_sessions:
        return

    # No sessions yet: create one that points at the default workspace.
    now_utc = datetime.datetime.now(datetime.timezone.utc)
    new_session = Session(
        id=str(uuid.uuid4()),
        active_workspace_id="1",
        last_update=now_utc,
    )
    recorder = DbRecorder(db_path)
    asyncio.run(recorder.update_session(new_session))
    logger.info("Session in DB initialized successfully.")


# When run as a script, initialize the DB with the default path.
if __name__ == "__main__":
    init_db_sync()
Loading

0 comments on commit 147205a

Please sign in to comment.