Added new signals to catch, models, updated database manager, exceptions and config

This commit is contained in:
Thastertyn 2025-01-29 10:30:28 +01:00
parent a105984aa3
commit eb7f15528e
11 changed files with 206 additions and 23 deletions

View File

@ -5,26 +5,46 @@ import logging
from core.config import BankNodeConfig from core.config import BankNodeConfig
from core.exceptions import ConfigError
from utils import setup_logger from utils import setup_logger
from database.database_manager import DatabaseManager
class BankNode(): class BankNode():
def __init__(self): def __init__(self):
setup_logger() try:
self.logger = logging.getLogger(__name__) self.config = BankNodeConfig()
self._setup_signals()
self.config = BankNodeConfig() setup_logger(self.config.verbosity)
self.logger = logging.getLogger(__name__)
self.logger.info("Config is valid")
self.logger.debug("Starting Bank Node")
self._setup_signals()
self.database_manager = DatabaseManager()
except ConfigError as e:
print(e)
self.exit_with_error()
def _setup_signals(self): def _setup_signals(self):
self.logger.debug("Setting up exit signal hooks") self.logger.debug("Setting up exit signal hooks")
# Handle C standard signals
signal.signal(signal.SIGTERM, self.gracefully_exit) signal.signal(signal.SIGTERM, self.gracefully_exit)
signal.signal(signal.SIGINT, self.gracefully_exit) signal.signal(signal.SIGINT, self.gracefully_exit)
# Handle windows related signals
if sys.platform == "win32":
signal.signal(signal.CTRL_C_EVENT, self.gracefully_exit)
signal.signal(signal.CTRL_BREAK_EVENT, self.gracefully_exit)
signal.signal(signal.SIGBREAK, self.gracefully_exit)
def start_server(self): def start_server(self):
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as socket_server: with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as socket_server:
socket_server.bind(("127.0.0.1", 6969)) socket_server.bind(("127.0.0.1", self.config.port))
socket_server.listen() socket_server.listen()
self.logger.info("Listening on %s:%s", "127.0.0.1", self.config.port)
while True: while True:
conn, addr = socket_server.accept() conn, addr = socket_server.accept()
with conn: with conn:
@ -32,6 +52,10 @@ class BankNode():
request = conn.recv(1024).decode("utf-8") request = conn.recv(1024).decode("utf-8")
print(f"Request:\n{request}") print(f"Request:\n{request}")
def exit_with_error(self):
self.cleanup()
sys.exit(1)
def cleanup(self): def cleanup(self):
pass pass

View File

@ -0,0 +1,8 @@
from core.exceptions import BankNodeError
class InvalidRequest(BankNodeError):
def __init__(self, message):
super().__init__(message)
self.message = message

View File

@ -3,15 +3,43 @@ import logging
import dotenv import dotenv
from utils import setup_logger from core.exceptions import ConfigError
dotenv.load_dotenv() dotenv.load_dotenv()
class BankNodeConfig: class BankNodeConfig:
def __init__(self): def __init__(self):
self.logger = logging.getLogger(__name__) self.logger = logging.getLogger(__name__)
self.port = os.getenv("PORT", "6969") port = os.getenv("PORT", "65526")
timeout = os.getenv("RESPONSE_TIMEOUT", "5")
verbosity = os.getenv("VERBOSITY", "INFO")
scan_port_range = os.getenv("SCAN_PORT_RANGE")
self.timeout = os.getenv("RESPONSE_TIMEOUT", "5") if not scan_port_range or scan_port_range == "":
self.logger.error("Scan port range not defined")
raise ConfigError("Scan port range not defined")
self.verbosity = os.getenv("VERBOSITY", "DEBUG") range_split = scan_port_range.split(":")
if len(range_split) != 2 or not range_split[0].isdigit() or not range_split[1].isdigit():
self.logger.error("Scan port range is not in valid format")
raise ConfigError("Scan port range is not in valid format")
if not port.isdigit():
self.logger.warning("Application port is invalid - Not a number. Falling back to default port 65526")
if not timeout.isdigit():
self.logger.warning("Request timeout is invalid - Not a number. Falling back to default timeout 5 seconds")
allowed_levels = logging.getLevelNamesMapping()
if verbosity not in allowed_levels:
self.logger.warning("Verbosity %s is invalid - Doesn't exist. Falling back to default verbosity INFO", verbosity)
self.port = int(port)
self.timeout = int(timeout)
self.verbosity = verbosity
self.scan_port_start = int(range_split[0])
self.scan_port_end = int(range_split[1])

View File

@ -1,13 +1,12 @@
class BankProtocolError(Exception): class BankNodeError(Exception):
def __init__(self, message): def __init__(self, message):
super().__init__(message) super().__init__(message)
self.message = message self.message = message
class InvalidRequest(BankProtocolError): class ConfigError(BankNodeError):
def __init__(self, message): def __init__(self, message):
super().__init__(message) super().__init__(message)
self.message = message self.message = message
__all__ = ["BankNodeError", "ConfigError"]
__all__ = ["BankProtocolError", "InvalidRequest"]

View File

@ -0,0 +1,5 @@
from .database_manager import *
__all__ = [
*database_manager.__all__
]

View File

@ -0,0 +1,67 @@
import logging
from typing import Generator
from sqlalchemy.orm import sessionmaker
from sqlalchemy import create_engine, text
from sqlalchemy.exc import DatabaseError
from database.exceptions import DatabaseConnectionError
from models.base_model import Base
class DatabaseManager():
_instance: 'DatabaseManager' = None
def __new__(cls, *args, **kwargs):
if cls._instance is None:
cls._instance = super().__new__(cls)
return cls._instance
def __init__(self) -> None:
if hasattr(self, "engine"):
return
self.logger = logging.getLogger(__name__)
self.logger.info("Initializing Database")
self.engine = create_engine('sqlite:///bank.db')
self.Session = sessionmaker(bind=self.engine)
def create_tables(self):
self.logger.debug("Creating tables")
Base.metadata.create_all(self.engine)
def cleanup(self) -> None:
self.logger.debug("Closing connection")
self.engine.dispose()
def test_connection(self) -> bool:
self.logger.debug("Testing database connection")
try:
with self.engine.connect() as connection:
connection.execute(text("select 1"))
self.logger.debug("Database connection successful")
return True
except DatabaseError as e:
self.logger.critical("Database connection failed: %s", e)
raise DatabaseConnectionError("Database connection failed") from e
return False
@classmethod
def get_session(cls) -> Generator:
session = cls._instance.Session()
try:
yield session
except Exception as e:
session.rollback()
cls._instance.logger.error("Transaction failed: %s", e)
raise
finally:
session.close()
__all__ = ["DatabaseManager"]

View File

@ -0,0 +1,28 @@
class DatabaseError(Exception):
def __init__(self, message: str):
super().__init__(message)
self.message = message
class DatabaseConnectionError(DatabaseError):
def __init__(self, message: str):
super().__init__(message)
self.message = message
class EmptyDatabaseConfigError(Exception):
def __init__(self, message: str, config_name: str):
super().__init__(message)
self.message = message
self.config_name = config_name
class DuplicateEntryError(DatabaseError):
def __init__(self, duplicate_entry_name: str, message: str):
super().__init__(message)
self.duplicate_entry_name = duplicate_entry_name
self.message = message
__all__ = ["DatabaseError", "DatabaseConnectionError", "DuplicateEntryError"]

View File

@ -0,0 +1,7 @@
from .account_model import *
from .base_model import *
__all__ = [
*base_model.__all__,
*account_model.__all__
]

View File

@ -0,0 +1,14 @@
from sqlalchemy import Column, Integer, CheckConstraint
from .base_model import Base
class Account(Base):
__tablename__ = 'account'
__table_args__ = (CheckConstraint('account_number > 10000 and account_number <= 99999'),)
account_number = Column(Integer, nullable=False, primary_key=True)
balance = Column(Integer)
__all__ = ["Account"]

5
src/models/base_model.py Normal file
View File

@ -0,0 +1,5 @@
from sqlalchemy.ext.declarative import declarative_base
Base = declarative_base()
__all__ = ["Base"]

View File

@ -2,15 +2,13 @@ import sys
import logging import logging
def setup_logger(): def setup_logger(verbosity: str):
logger = logging.getLogger() if verbosity == "DEBUG":
log_format = "[%(levelname)s] - %(name)s:%(lineno)d - %(message)s"
else:
log_format = "[%(levelname)s] - %(message)s"
handler = logging.StreamHandler(sys.stdout) logging.basicConfig(level=verbosity, format=log_format, stream=sys.stdout)
formatter = logging.Formatter("[%(levelname)s] - %(name)s:%(lineno)d - %(message)s")
handler.setFormatter(formatter)
logger.addHandler(handler)
__all__ = ["setup_logger"] __all__ = ["setup_logger"]