44 lines
1.1 KiB
Python
44 lines
1.1 KiB
Python
import logging
|
|
from contextlib import contextmanager
|
|
from typing import Generator
|
|
|
|
from sqlalchemy.exc import DatabaseError as SqlAlchemyError
|
|
from sqlmodel import Session, create_engine, select, SQLModel
|
|
|
|
from app.core.config import settings
|
|
|
|
from app.database.exceptions import DatabaseError
|
|
|
|
import app.database.models
|
|
|
|
logger = logging.getLogger(__name__)

logger.info("Creating engine")
# create_engine() is lazy: it builds the Engine and its pool but does not
# open a DBAPI connection until the engine is first used.
engine = create_engine(str(settings.SQLALCHEMY_DATABASE_URI))

# create_all() is therefore the first statement in this module that actually
# talks to the database. If the database is unreachable, import fails right
# here -- before the explicit test_connection() probe at the bottom of the
# module ever runs. Translate the SQLAlchemy error the same way
# test_connection() does, so import-time callers get the application's
# DatabaseError instead of a raw driver/SQLAlchemy exception.
# NOTE(review): the table set presumably comes from the side-effect
# `import app.database.models` above registering models on
# SQLModel.metadata -- confirm.
try:
    SQLModel.metadata.create_all(engine)
except SqlAlchemyError as e:
    # CONNECTION_ERROR is the only error code visible in this module; an
    # sqlalchemy DatabaseError at startup most likely means the database could
    # not be reached -- revisit if a more specific code exists.
    logger.critical("Database initialization failed: %s", e)
    raise DatabaseError("Database initialization failed", DatabaseError.CONNECTION_ERROR) from e
|
|
|
|
|
|
def test_connection() -> None:
    """Check that the database answers a trivial ``SELECT 1`` query.

    A short-lived ``Session`` is opened on the module-level ``engine`` and the
    probe statement is executed through it.

    Raises:
        DatabaseError: the application's error, with code
            ``DatabaseError.CONNECTION_ERROR``, whenever SQLAlchemy raises its
            own ``DatabaseError`` (imported here as ``SqlAlchemyError``) while
            the probe runs. The SQLAlchemy exception is chained as the cause.
    """
    logger.debug("Testing database connection")
    try:
        with Session(engine) as probe_session:
            probe_session.exec(select(1))
            logger.debug("Database connection successful")
    except SqlAlchemyError as exc:
        # Logged at CRITICAL: this module runs the probe during import, so a
        # failure here propagates out of the import itself.
        logger.critical("Database connection failed: %s", exc)
        raise DatabaseError(
            "Database connection failed",
            DatabaseError.CONNECTION_ERROR,
        ) from exc
|
|
|
|
|
|
def cleanup() -> None:
    """Dispose of the module-level engine's connection pool.

    Delegates to ``engine.dispose()``. Presumably called from an application
    shutdown hook elsewhere -- confirm against callers.
    """
    logger.debug("Closing connection")
    engine.dispose()
|
|
|
|
|
|
@contextmanager
def get_session() -> Generator[Session, None, None]:
    """Provide a ``Session`` bound to the module-level ``engine``.

    The session is closed when the ``with`` block exits, whether it exits
    normally or by exception. No commit is issued here; committing is left to
    the caller.
    """
    db_session = Session(engine)
    try:
        yield db_session
    finally:
        # Equivalent to leaving ``with Session(engine)``: Session.__exit__
        # simply calls close().
        db_session.close()
|
|
|
|
|
|
# Import-time connectivity probe: importing this module raises the
# application's DatabaseError if the `SELECT 1` check fails.
test_connection()
|