diff --git a/src/bank_node/bank_node.py b/src/bank_node/bank_node.py index 0fd6108..24dc8b8 100644 --- a/src/bank_node/bank_node.py +++ b/src/bank_node/bank_node.py @@ -5,26 +5,46 @@ import logging from core.config import BankNodeConfig +from core.exceptions import ConfigError from utils import setup_logger - +from database.database_manager import DatabaseManager class BankNode(): def __init__(self): - setup_logger() - self.logger = logging.getLogger(__name__) - self._setup_signals() - self.config = BankNodeConfig() + try: + self.config = BankNodeConfig() + + setup_logger(self.config.verbosity) + self.logger = logging.getLogger(__name__) + + self.logger.info("Config is valid") + self.logger.debug("Starting Bank Node") + + self._setup_signals() + + self.database_manager = DatabaseManager() + except ConfigError as e: + print(e) + self.exit_with_error() def _setup_signals(self): self.logger.debug("Setting up exit signal hooks") + + # Handle C standard signals signal.signal(signal.SIGTERM, self.gracefully_exit) signal.signal(signal.SIGINT, self.gracefully_exit) + # Handle windows related signals + if sys.platform == "win32": + signal.signal(signal.CTRL_C_EVENT, self.gracefully_exit) + signal.signal(signal.CTRL_BREAK_EVENT, self.gracefully_exit) + signal.signal(signal.SIGBREAK, self.gracefully_exit) + def start_server(self): with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as socket_server: - socket_server.bind(("127.0.0.1", 6969)) + socket_server.bind(("127.0.0.1", self.config.port)) socket_server.listen() - + self.logger.info("Listening on %s:%s", "127.0.0.1", self.config.port) while True: conn, addr = socket_server.accept() with conn: @@ -32,6 +52,10 @@ class BankNode(): request = conn.recv(1024).decode("utf-8") print(f"Request:\n{request}") + def exit_with_error(self): + self.cleanup() + sys.exit(1) + def cleanup(self): pass diff --git a/src/bank_protocol/exceptions.py b/src/bank_protocol/exceptions.py index e69de29..a4e1a02 100644 --- a/src/bank_protocol/exceptions.py +++ b/src/bank_protocol/exceptions.py @@ -0,0 +1,8 @@ + +from core.exceptions import BankNodeError + + +class InvalidRequest(BankNodeError): + def __init__(self, message): + super().__init__(message) + self.message = message diff --git a/src/core/config.py b/src/core/config.py index 2033f85..d3b685d 100644 --- a/src/core/config.py +++ b/src/core/config.py @@ -3,15 +3,43 @@ import logging import dotenv -from utils import setup_logger +from core.exceptions import ConfigError dotenv.load_dotenv() + class BankNodeConfig: def __init__(self): self.logger = logging.getLogger(__name__) - self.port = os.getenv("PORT", "6969") + port = os.getenv("PORT", "65526") + timeout = os.getenv("RESPONSE_TIMEOUT", "5") + verbosity = os.getenv("VERBOSITY", "INFO") + scan_port_range = os.getenv("SCAN_PORT_RANGE") - self.timeout = os.getenv("RESPONSE_TIMEOUT", "5") + if not scan_port_range or scan_port_range == "": + self.logger.error("Scan port range not defined") + raise ConfigError("Scan port range not defined") - self.verbosity = os.getenv("VERBOSITY", "DEBUG") \ No newline at end of file + range_split = scan_port_range.split(":") + + if len(range_split) != 2 or not range_split[0].isdigit() or not range_split[1].isdigit(): + self.logger.error("Scan port range is not in valid format") + raise ConfigError("Scan port range is not in valid format") + + + if not port.isdigit(): + self.logger.warning("Application port is invalid - Not a number. Falling back to default port 65526") + + if not timeout.isdigit(): + self.logger.warning("Request timeout is invalid - Not a number. Falling back to default timeout 5 seconds") + + allowed_levels = logging.getLevelNamesMapping() + + if verbosity not in allowed_levels: + self.logger.warning("Verbosity %s is invalid - Doesn't exist. Falling back to default verbosity INFO", verbosity) + + self.port = int(port) + self.timeout = int(timeout) + self.verbosity = verbosity + self.scan_port_start = int(range_split[0]) + self.scan_port_end = int(range_split[1]) diff --git a/src/core/exceptions.py b/src/core/exceptions.py index add9c37..e3e6aab 100644 --- a/src/core/exceptions.py +++ b/src/core/exceptions.py @@ -1,13 +1,12 @@ -class BankProtocolError(Exception): +class BankNodeError(Exception): def __init__(self, message): super().__init__(message) self.message = message -class InvalidRequest(BankProtocolError): +class ConfigError(BankNodeError): def __init__(self, message): super().__init__(message) self.message = message - -__all__ = ["BankProtocolError", "InvalidRequest"] +__all__ = ["BankNodeError", "ConfigError"] diff --git a/src/database/__init__.py b/src/database/__init__.py index e69de29..1023f20 100644 --- a/src/database/__init__.py +++ b/src/database/__init__.py @@ -0,0 +1,5 @@ +from .database_manager import * + +__all__ = [ + *database_manager.__all__ +] diff --git a/src/database/database_manager.py b/src/database/database_manager.py index e69de29..dbc0aba 100644 --- a/src/database/database_manager.py +++ b/src/database/database_manager.py @@ -0,0 +1,67 @@ +import logging +from typing import Generator + +from sqlalchemy.orm import sessionmaker +from sqlalchemy import create_engine, text + +from sqlalchemy.exc import DatabaseError + +from database.exceptions import DatabaseConnectionError +from models.base_model import Base + + +class DatabaseManager(): + + _instance: 'DatabaseManager' = None + + def __new__(cls, *args, **kwargs): + if cls._instance is None: + cls._instance = super().__new__(cls) + return cls._instance + + def __init__(self) -> None: + if hasattr(self, "engine"): + return + + self.logger = logging.getLogger(__name__) + self.logger.info("Initializing Database") + + self.engine = create_engine('sqlite:///bank.db') + + self.Session = sessionmaker(bind=self.engine) + + def create_tables(self): + self.logger.debug("Creating tables") + Base.metadata.create_all(self.engine) + + def cleanup(self) -> None: + self.logger.debug("Closing connection") + self.engine.dispose() + + def test_connection(self) -> bool: + self.logger.debug("Testing database connection") + try: + with self.engine.connect() as connection: + connection.execute(text("select 1")) + self.logger.debug("Database connection successful") + return True + except DatabaseError as e: + self.logger.critical("Database connection failed: %s", e) + raise DatabaseConnectionError("Database connection failed") from e + + return False + + @classmethod + def get_session(cls) -> Generator: + session = cls._instance.Session() + try: + yield session + except Exception as e: + session.rollback() + cls._instance.logger.error("Transaction failed: %s", e) + raise + finally: + session.close() + + +__all__ = ["DatabaseManager"] diff --git a/src/database/exceptions.py b/src/database/exceptions.py new file mode 100644 index 0000000..5e7c802 --- /dev/null +++ b/src/database/exceptions.py @@ -0,0 +1,28 @@ +class DatabaseError(Exception): + def __init__(self, message: str): + super().__init__(message) + self.message = message + + +class DatabaseConnectionError(DatabaseError): + def __init__(self, message: str): + super().__init__(message) + self.message = message + + +class EmptyDatabaseConfigError(Exception): + def __init__(self, message: str, config_name: str): + super().__init__(message) + + self.message = message + self.config_name = config_name + + +class DuplicateEntryError(DatabaseError): + def __init__(self, duplicate_entry_name: str, message: str): + super().__init__(message) + self.duplicate_entry_name = duplicate_entry_name + self.message = message + + +__all__ = ["DatabaseError", "DatabaseConnectionError", "DuplicateEntryError"] diff --git a/src/models/__init__.py b/src/models/__init__.py index e69de29..4158f4f 100644 --- a/src/models/__init__.py +++ b/src/models/__init__.py @@ -0,0 +1,7 @@ +from .account_model import * +from .base_model import * + +__all__ = [ + *base_model.__all__, + *account_model.__all__ +] diff --git a/src/models/account_model.py b/src/models/account_model.py new file mode 100644 index 0000000..0366cab --- /dev/null +++ b/src/models/account_model.py @@ -0,0 +1,14 @@ +from sqlalchemy import Column, Integer, CheckConstraint + +from .base_model import Base + + +class Account(Base): + __tablename__ = 'account' + __table_args__ = (CheckConstraint('account_number > 10000 and account_number <= 99999'),) + + account_number = Column(Integer, nullable=False, primary_key=True) + balance = Column(Integer) + + +__all__ = ["Account"] diff --git a/src/models/base_model.py b/src/models/base_model.py new file mode 100644 index 0000000..c1ba9fd --- /dev/null +++ b/src/models/base_model.py @@ -0,0 +1,5 @@ +from sqlalchemy.ext.declarative import declarative_base + +Base = declarative_base() + +__all__ = ["Base"] diff --git a/src/utils/logger.py b/src/utils/logger.py index 00aa200..1afa0b0 100644 --- a/src/utils/logger.py +++ b/src/utils/logger.py @@ -2,15 +2,13 @@ import sys import logging -def setup_logger(): - logger = logging.getLogger() +def setup_logger(verbosity: str): + if verbosity == "DEBUG": + log_format = "[%(levelname)s] - %(name)s:%(lineno)d - %(message)s" + else: + log_format = "[%(levelname)s] - %(message)s" - handler = logging.StreamHandler(sys.stdout) - - formatter = logging.Formatter("[%(levelname)s] - %(name)s:%(lineno)d - %(message)s") - handler.setFormatter(formatter) - - logger.addHandler(handler) + logging.basicConfig(level=verbosity, format=log_format, stream=sys.stdout) __all__ = ["setup_logger"]