diff --git a/pyproject.toml b/pyproject.toml index 5f4076040..4ebbb575d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -9,6 +9,7 @@ dependencies = [ "jinja2>=3.1.4", "pydantic-ai>=0.0.26", "pydantic-settings>=2.6.1", + "pydantic[email]>=2.10.2", "rich>=13.9.4", "sqlalchemy[asyncio]>=2.0.36", "typer>=0.15.1", diff --git a/src/marvin/__init__.py b/src/marvin/__init__.py index 729af99a2..ebf3a1032 100644 --- a/src/marvin/__init__.py +++ b/src/marvin/__init__.py @@ -2,7 +2,7 @@ # necessary imports from marvin.settings import settings -from marvin.database import ensure_sqlite_memory_tables_exist +from marvin.database import init_database # core classes from marvin.thread import Thread @@ -31,7 +31,8 @@ from marvin.fns.summarize import summarize, summarize_async from marvin.fns.plan import plan, plan_async -ensure_sqlite_memory_tables_exist() +# Initialize the database on import +init_database() __version__ = _version("marvin") diff --git a/src/marvin/database.py b/src/marvin/database.py index 273360a61..a5e0b7427 100644 --- a/src/marvin/database.py +++ b/src/marvin/database.py @@ -254,16 +254,8 @@ def ensure_sqlite_memory_tables_exist(): created if they don't exist. """ - db_url = settings.database_url - if db_url is None: - raise ValueError("Database URL is not configured") - - if db_url == ":memory:" or db_url.endswith(":memory:"): - # We're using run_sync from another module, so keep it as is - asyncio.run(create_db_and_tables(force=False)) - else: - # For non-memory databases, ensure tables exist - asyncio.run(create_db_and_tables(force=False)) + # Call init_database which handles all database types + init_database() @asynccontextmanager @@ -303,3 +295,18 @@ async def create_db_and_tables(*, force: bool = False) -> None: await conn.run_sync(Base.metadata.create_all) logger.debug("Database tables created.") + + +def init_database(): + """Initialize the database. + + This function should be called during application startup to ensure + database tables exist before they are accessed. + """ + from marvin.utilities.logging import get_logger + + logger = get_logger(__name__) + + logger.debug("Initializing database...") + asyncio.run(create_db_and_tables(force=False)) + logger.debug("Database initialization complete.")