Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add workspaces to OpenAPI spec #634

Open
wants to merge 8 commits into
base: main
Choose a base branch
from
Open
364 changes: 352 additions & 12 deletions api/openapi.json

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ build-backend = "poetry.core.masonry.api"

[tool.poetry.scripts]
codegate = "codegate.cli:main"
generate-openapi = "src.codegate.dashboard.dashboard:generate_openapi"
generate-openapi = "src.codegate.server:generate_openapi"

[tool.black]
line-length = 100
Expand Down
Original file line number Diff line number Diff line change
@@ -1,43 +1,49 @@
import asyncio
import json
from typing import AsyncGenerator, List, Optional

import requests
import structlog
from fastapi import APIRouter, Depends, FastAPI
from fastapi import APIRouter, Depends
from fastapi.responses import StreamingResponse
from fastapi.routing import APIRoute

from codegate import __version__
from codegate.dashboard.post_processing import (
from codegate.api.dashboard.post_processing import (
parse_get_alert_conversation,
parse_messages_in_conversations,
)
from codegate.dashboard.request_models import AlertConversation, Conversation
from codegate.api.dashboard.request_models import AlertConversation, Conversation
from codegate.db.connection import DbReader, alert_queue

logger = structlog.get_logger("codegate")

dashboard_router = APIRouter(tags=["Dashboard"])
dashboard_router = APIRouter()
db_reader = None


def uniq_name(route: APIRoute):
return f"v1_{route.name}"


def get_db_reader():
global db_reader
if db_reader is None:
db_reader = DbReader()
return db_reader


def fetch_latest_version() -> str:
url = "https://api.github.com/repos/stacklok/codegate/releases/latest"
headers = {
"Accept": "application/vnd.github+json",
"X-GitHub-Api-Version": "2022-11-28"
}
headers = {"Accept": "application/vnd.github+json", "X-GitHub-Api-Version": "2022-11-28"}
response = requests.get(url, headers=headers, timeout=5)
response.raise_for_status()
data = response.json()
return data.get("tag_name", "unknown")

@dashboard_router.get("/dashboard/messages")

@dashboard_router.get(
"/dashboard/messages", tags=["Dashboard"], generate_unique_id_function=uniq_name
)
def get_messages(db_reader: DbReader = Depends(get_db_reader)) -> List[Conversation]:
"""
Get all the messages from the database and return them as a list of conversations.
Expand All @@ -47,7 +53,9 @@ def get_messages(db_reader: DbReader = Depends(get_db_reader)) -> List[Conversat
return asyncio.run(parse_messages_in_conversations(prompts_outputs))


@dashboard_router.get("/dashboard/alerts")
@dashboard_router.get(
"/dashboard/alerts", tags=["Dashboard"], generate_unique_id_function=uniq_name
)
def get_alerts(db_reader: DbReader = Depends(get_db_reader)) -> List[Optional[AlertConversation]]:
"""
Get all the messages from the database and return them as a list of conversations.
Expand All @@ -65,21 +73,26 @@ async def generate_sse_events() -> AsyncGenerator[str, None]:
yield f"data: {message}\n\n"


@dashboard_router.get("/dashboard/alerts_notification")
@dashboard_router.get(
"/dashboard/alerts_notification", tags=["Dashboard"], generate_unique_id_function=uniq_name
)
async def stream_sse():
"""
Send alerts event
"""
return StreamingResponse(generate_sse_events(), media_type="text/event-stream")

@dashboard_router.get("/dashboard/version")

@dashboard_router.get(
"/dashboard/version", tags=["Dashboard"], generate_unique_id_function=uniq_name
)
def version_check():
try:
latest_version = fetch_latest_version()

# normalize the versions as github will return them with a 'v' prefix
current_version = __version__.lstrip('v')
latest_version_stripped = latest_version.lstrip('v')
current_version = __version__.lstrip("v")
latest_version_stripped = latest_version.lstrip("v")

is_latest: bool = latest_version_stripped == current_version

Expand All @@ -95,28 +108,13 @@ def version_check():
"current_version": __version__,
"latest_version": "unknown",
"is_latest": None,
"error": "An error occurred while fetching the latest version"
"error": "An error occurred while fetching the latest version",
}
except Exception as e:
logger.error(f"Unexpected error: {str(e)}")
return {
"current_version": __version__,
"latest_version": "unknown",
"is_latest": None,
"error": "An unexpected error occurred"
"error": "An unexpected error occurred",
}


def generate_openapi():
# Create a temporary FastAPI app instance
app = FastAPI()

# Include your defined router
app.include_router(dashboard_router)

# Generate OpenAPI JSON
openapi_schema = app.openapi()

# Convert the schema to JSON string for easier handling or storage
openapi_json = json.dumps(openapi_schema, indent=2)
print(openapi_json)
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

import structlog

from codegate.dashboard.request_models import (
from codegate.api.dashboard.request_models import (
AlertConversation,
ChatMessage,
Conversation,
Expand Down
3 changes: 3 additions & 0 deletions src/codegate/api/v1.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,11 @@
from codegate.api import v1_models
from codegate.db.connection import AlreadyExistsError
from codegate.workspaces.crud import WorkspaceCrud
from codegate.api.dashboard.dashboard import dashboard_router

v1 = APIRouter()
v1.include_router(dashboard_router)

wscrud = WorkspaceCrud()


Expand Down
2 changes: 1 addition & 1 deletion src/codegate/db/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ def __init__(self, sqlite_path: Optional[str] = None):
)
self._db_path = Path(sqlite_path).absolute()
self._db_path.parent.mkdir(parents=True, exist_ok=True)
logger.debug(f"Connecting to DB from path: {self._db_path}")
# logger.debug(f"Connecting to DB from path: {self._db_path}")
engine_dict = {
"url": f"sqlite+aiosqlite:///{self._db_path}",
"echo": False, # Set to False in production
Expand Down
15 changes: 13 additions & 2 deletions src/codegate/server.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
import json
import traceback
from unittest.mock import Mock

import structlog
from fastapi import APIRouter, FastAPI, Request
Expand All @@ -8,7 +10,6 @@

from codegate import __description__, __version__
from codegate.api.v1 import v1
from codegate.dashboard.dashboard import dashboard_router
from codegate.pipeline.factory import PipelineFactory
from codegate.providers.anthropic.provider import AnthropicProvider
from codegate.providers.llamacpp.provider import LlamaCppProvider
Expand Down Expand Up @@ -96,9 +97,19 @@ async def health_check():
return {"status": "healthy"}

app.include_router(system_router)
app.include_router(dashboard_router)

# CodeGate API
app.include_router(v1, prefix="/api/v1", tags=["CodeGate API"])

return app


def generate_openapi():
app = init_app(Mock(spec=PipelineFactory))

# Generate OpenAPI JSON
openapi_schema = app.openapi()

# Convert the schema to JSON string for easier handling or storage
openapi_json = json.dumps(openapi_schema, indent=2)
print(openapi_json)
8 changes: 4 additions & 4 deletions tests/dashboard/test_post_processing.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,14 @@

import pytest

from codegate.dashboard.post_processing import (
from codegate.api.dashboard.post_processing import (
_get_question_answer,
_group_partial_messages,
_is_system_prompt,
parse_output,
parse_request,
)
from codegate.dashboard.request_models import (
from codegate.api.dashboard.request_models import (
PartialQuestions,
)
from codegate.db.models import GetPromptWithOutputsRow
Expand Down Expand Up @@ -162,10 +162,10 @@ async def test_parse_output(output_dict, expected_str):
)
async def test_get_question_answer(request_msg_list, output_msg_str, row):
with patch(
"codegate.dashboard.post_processing.parse_request", new_callable=AsyncMock
"codegate.api.dashboard.post_processing.parse_request", new_callable=AsyncMock
) as mock_parse_request:
with patch(
"codegate.dashboard.post_processing.parse_output", new_callable=AsyncMock
"codegate.api.dashboard.post_processing.parse_output", new_callable=AsyncMock
) as mock_parse_output:
# Set return values for the mocks
mock_parse_request.return_value = request_msg_list
Expand Down
6 changes: 3 additions & 3 deletions tests/test_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,10 +81,10 @@ def test_health_check(test_client: TestClient) -> None:
assert response.status_code == 200
assert response.json() == {"status": "healthy"}

@patch("codegate.dashboard.dashboard.fetch_latest_version", return_value="foo")
@patch("codegate.api.dashboard.dashboard.fetch_latest_version", return_value="foo")
def test_version_endpoint(mock_fetch_latest_version, test_client: TestClient) -> None:
"""Test the version endpoint."""
response = test_client.get("/dashboard/version")
response = test_client.get("/api/v1/dashboard/version")
assert response.status_code == 200

response_data = response.json()
Expand Down Expand Up @@ -139,7 +139,7 @@ def test_dashboard_routes(mock_pipeline_factory) -> None:
routes = [route.path for route in app.routes]

# Verify dashboard endpoints are included
dashboard_routes = [route for route in routes if route.startswith("/dashboard")]
dashboard_routes = [route for route in routes if route.startswith("/api/v1/dashboard")]
assert len(dashboard_routes) > 0


Expand Down
Loading