Added new signals to catch, models, updated database manager, exceptions and config
This commit is contained in:
parent
a105984aa3
commit
eb7f15528e
@ -5,26 +5,46 @@ import logging
|
|||||||
|
|
||||||
|
|
||||||
from core.config import BankNodeConfig
|
from core.config import BankNodeConfig
|
||||||
|
from core.exceptions import ConfigError
|
||||||
from utils import setup_logger
|
from utils import setup_logger
|
||||||
|
from database.database_manager import DatabaseManager
|
||||||
|
|
||||||
class BankNode():
|
class BankNode():
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
setup_logger()
|
try:
|
||||||
self.logger = logging.getLogger(__name__)
|
self.config = BankNodeConfig()
|
||||||
self._setup_signals()
|
|
||||||
self.config = BankNodeConfig()
|
setup_logger(self.config.verbosity)
|
||||||
|
self.logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
self.logger.info("Config is valid")
|
||||||
|
self.logger.debug("Starting Bank Node")
|
||||||
|
|
||||||
|
self._setup_signals()
|
||||||
|
|
||||||
|
self.database_manager = DatabaseManager()
|
||||||
|
except ConfigError as e:
|
||||||
|
print(e)
|
||||||
|
self.exit_with_error()
|
||||||
|
|
||||||
def _setup_signals(self):
|
def _setup_signals(self):
|
||||||
self.logger.debug("Setting up exit signal hooks")
|
self.logger.debug("Setting up exit signal hooks")
|
||||||
|
|
||||||
|
# Handle C standard signals
|
||||||
signal.signal(signal.SIGTERM, self.gracefully_exit)
|
signal.signal(signal.SIGTERM, self.gracefully_exit)
|
||||||
signal.signal(signal.SIGINT, self.gracefully_exit)
|
signal.signal(signal.SIGINT, self.gracefully_exit)
|
||||||
|
|
||||||
|
# Handle windows related signals
|
||||||
|
if sys.platform == "win32":
|
||||||
|
signal.signal(signal.CTRL_C_EVENT, self.gracefully_exit)
|
||||||
|
signal.signal(signal.CTRL_BREAK_EVENT, self.gracefully_exit)
|
||||||
|
signal.signal(signal.SIGBREAK, self.gracefully_exit)
|
||||||
|
|
||||||
def start_server(self):
|
def start_server(self):
|
||||||
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as socket_server:
|
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as socket_server:
|
||||||
socket_server.bind(("127.0.0.1", 6969))
|
socket_server.bind(("127.0.0.1", self.config.port))
|
||||||
socket_server.listen()
|
socket_server.listen()
|
||||||
|
self.logger.info("Listening on %s:%s", "127.0.0.1", self.config.port)
|
||||||
while True:
|
while True:
|
||||||
conn, addr = socket_server.accept()
|
conn, addr = socket_server.accept()
|
||||||
with conn:
|
with conn:
|
||||||
@ -32,6 +52,10 @@ class BankNode():
|
|||||||
request = conn.recv(1024).decode("utf-8")
|
request = conn.recv(1024).decode("utf-8")
|
||||||
print(f"Request:\n{request}")
|
print(f"Request:\n{request}")
|
||||||
|
|
||||||
|
def exit_with_error(self):
|
||||||
|
self.cleanup()
|
||||||
|
sys.exit(1)
|
||||||
|
|
||||||
def cleanup(self):
|
def cleanup(self):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
@ -0,0 +1,8 @@
|
|||||||
|
|
||||||
|
from core.exceptions import BankNodeError
|
||||||
|
|
||||||
|
|
||||||
|
class InvalidRequest(BankNodeError):
|
||||||
|
def __init__(self, message):
|
||||||
|
super().__init__(message)
|
||||||
|
self.message = message
|
@ -3,15 +3,43 @@ import logging
|
|||||||
|
|
||||||
import dotenv
|
import dotenv
|
||||||
|
|
||||||
from utils import setup_logger
|
from core.exceptions import ConfigError
|
||||||
|
|
||||||
dotenv.load_dotenv()
|
dotenv.load_dotenv()
|
||||||
|
|
||||||
|
|
||||||
class BankNodeConfig:
|
class BankNodeConfig:
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.logger = logging.getLogger(__name__)
|
self.logger = logging.getLogger(__name__)
|
||||||
self.port = os.getenv("PORT", "6969")
|
port = os.getenv("PORT", "65526")
|
||||||
|
timeout = os.getenv("RESPONSE_TIMEOUT", "5")
|
||||||
|
verbosity = os.getenv("VERBOSITY", "INFO")
|
||||||
|
scan_port_range = os.getenv("SCAN_PORT_RANGE")
|
||||||
|
|
||||||
self.timeout = os.getenv("RESPONSE_TIMEOUT", "5")
|
if not scan_port_range or scan_port_range == "":
|
||||||
|
self.logger.error("Scan port range not defined")
|
||||||
|
raise ConfigError("Scan port range not defined")
|
||||||
|
|
||||||
self.verbosity = os.getenv("VERBOSITY", "DEBUG")
|
range_split = scan_port_range.split(":")
|
||||||
|
|
||||||
|
if len(range_split) != 2 or not range_split[0].isdigit() or not range_split[1].isdigit():
|
||||||
|
self.logger.error("Scan port range is not in valid format")
|
||||||
|
raise ConfigError("Scan port range is not in valid format")
|
||||||
|
|
||||||
|
|
||||||
|
if not port.isdigit():
|
||||||
|
self.logger.warning("Application port is invalid - Not a number. Falling back to default port 65526")
|
||||||
|
|
||||||
|
if not timeout.isdigit():
|
||||||
|
self.logger.warning("Request timeout is invalid - Not a number. Falling back to default timeout 5 seconds")
|
||||||
|
|
||||||
|
allowed_levels = logging.getLevelNamesMapping()
|
||||||
|
|
||||||
|
if verbosity not in allowed_levels:
|
||||||
|
self.logger.warning("Verbosity %s is invalid - Doesn't exist. Falling back to default verbosity INFO", verbosity)
|
||||||
|
|
||||||
|
self.port = int(port)
|
||||||
|
self.timeout = int(timeout)
|
||||||
|
self.verbosity = verbosity
|
||||||
|
self.scan_port_start = int(range_split[0])
|
||||||
|
self.scan_port_end = int(range_split[1])
|
||||||
|
@ -1,13 +1,12 @@
|
|||||||
class BankProtocolError(Exception):
|
class BankNodeError(Exception):
|
||||||
def __init__(self, message):
|
def __init__(self, message):
|
||||||
super().__init__(message)
|
super().__init__(message)
|
||||||
self.message = message
|
self.message = message
|
||||||
|
|
||||||
|
|
||||||
class InvalidRequest(BankProtocolError):
|
class ConfigError(BankNodeError):
|
||||||
def __init__(self, message):
|
def __init__(self, message):
|
||||||
super().__init__(message)
|
super().__init__(message)
|
||||||
self.message = message
|
self.message = message
|
||||||
|
|
||||||
|
__all__ = ["BankNodeError", "ConfigError"]
|
||||||
__all__ = ["BankProtocolError", "InvalidRequest"]
|
|
||||||
|
@ -0,0 +1,5 @@
|
|||||||
|
from .database_manager import *
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
*database_manager.__all__
|
||||||
|
]
|
@ -0,0 +1,67 @@
|
|||||||
|
import logging
|
||||||
|
from typing import Generator
|
||||||
|
|
||||||
|
from sqlalchemy.orm import sessionmaker
|
||||||
|
from sqlalchemy import create_engine, text
|
||||||
|
|
||||||
|
from sqlalchemy.exc import DatabaseError
|
||||||
|
|
||||||
|
from database.exceptions import DatabaseConnectionError
|
||||||
|
from models.base_model import Base
|
||||||
|
|
||||||
|
|
||||||
|
class DatabaseManager():
|
||||||
|
|
||||||
|
_instance: 'DatabaseManager' = None
|
||||||
|
|
||||||
|
def __new__(cls, *args, **kwargs):
|
||||||
|
if cls._instance is None:
|
||||||
|
cls._instance = super().__new__(cls)
|
||||||
|
return cls._instance
|
||||||
|
|
||||||
|
def __init__(self) -> None:
|
||||||
|
if hasattr(self, "engine"):
|
||||||
|
return
|
||||||
|
|
||||||
|
self.logger = logging.getLogger(__name__)
|
||||||
|
self.logger.info("Initializing Database")
|
||||||
|
|
||||||
|
self.engine = create_engine('sqlite:///bank.db')
|
||||||
|
|
||||||
|
self.Session = sessionmaker(bind=self.engine)
|
||||||
|
|
||||||
|
def create_tables(self):
|
||||||
|
self.logger.debug("Creating tables")
|
||||||
|
Base.metadata.create_all(self.engine)
|
||||||
|
|
||||||
|
def cleanup(self) -> None:
|
||||||
|
self.logger.debug("Closing connection")
|
||||||
|
self.engine.dispose()
|
||||||
|
|
||||||
|
def test_connection(self) -> bool:
|
||||||
|
self.logger.debug("Testing database connection")
|
||||||
|
try:
|
||||||
|
with self.engine.connect() as connection:
|
||||||
|
connection.execute(text("select 1"))
|
||||||
|
self.logger.debug("Database connection successful")
|
||||||
|
return True
|
||||||
|
except DatabaseError as e:
|
||||||
|
self.logger.critical("Database connection failed: %s", e)
|
||||||
|
raise DatabaseConnectionError("Database connection failed") from e
|
||||||
|
|
||||||
|
return False
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_session(cls) -> Generator:
|
||||||
|
session = cls._instance.Session()
|
||||||
|
try:
|
||||||
|
yield session
|
||||||
|
except Exception as e:
|
||||||
|
session.rollback()
|
||||||
|
cls._instance.logger.error("Transaction failed: %s", e)
|
||||||
|
raise
|
||||||
|
finally:
|
||||||
|
session.close()
|
||||||
|
|
||||||
|
|
||||||
|
__all__ = ["DatabaseManager"]
|
28
src/database/exceptions.py
Normal file
28
src/database/exceptions.py
Normal file
@ -0,0 +1,28 @@
|
|||||||
|
class DatabaseError(Exception):
|
||||||
|
def __init__(self, message: str):
|
||||||
|
super().__init__(message)
|
||||||
|
self.message = message
|
||||||
|
|
||||||
|
|
||||||
|
class DatabaseConnectionError(DatabaseError):
|
||||||
|
def __init__(self, message: str):
|
||||||
|
super().__init__(message)
|
||||||
|
self.message = message
|
||||||
|
|
||||||
|
|
||||||
|
class EmptyDatabaseConfigError(Exception):
|
||||||
|
def __init__(self, message: str, config_name: str):
|
||||||
|
super().__init__(message)
|
||||||
|
|
||||||
|
self.message = message
|
||||||
|
self.config_name = config_name
|
||||||
|
|
||||||
|
|
||||||
|
class DuplicateEntryError(DatabaseError):
|
||||||
|
def __init__(self, duplicate_entry_name: str, message: str):
|
||||||
|
super().__init__(message)
|
||||||
|
self.duplicate_entry_name = duplicate_entry_name
|
||||||
|
self.message = message
|
||||||
|
|
||||||
|
|
||||||
|
__all__ = ["DatabaseError", "DatabaseConnectionError", "DuplicateEntryError"]
|
@ -0,0 +1,7 @@
|
|||||||
|
from .account_model import *
|
||||||
|
from .base_model import *
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
*base_model.__all__,
|
||||||
|
*account_model.__all__
|
||||||
|
]
|
14
src/models/account_model.py
Normal file
14
src/models/account_model.py
Normal file
@ -0,0 +1,14 @@
|
|||||||
|
from sqlalchemy import Column, Integer, CheckConstraint
|
||||||
|
|
||||||
|
from .base_model import Base
|
||||||
|
|
||||||
|
|
||||||
|
class Account(Base):
|
||||||
|
__tablename__ = 'account'
|
||||||
|
__table_args__ = (CheckConstraint('account_number > 10000 and account_number <= 99999'),)
|
||||||
|
|
||||||
|
account_number = Column(Integer, nullable=False, primary_key=True)
|
||||||
|
balance = Column(Integer)
|
||||||
|
|
||||||
|
|
||||||
|
__all__ = ["Account"]
|
5
src/models/base_model.py
Normal file
5
src/models/base_model.py
Normal file
@ -0,0 +1,5 @@
|
|||||||
|
from sqlalchemy.ext.declarative import declarative_base
|
||||||
|
|
||||||
|
Base = declarative_base()
|
||||||
|
|
||||||
|
__all__ = ["Base"]
|
@ -2,15 +2,13 @@ import sys
|
|||||||
import logging
|
import logging
|
||||||
|
|
||||||
|
|
||||||
def setup_logger():
|
def setup_logger(verbosity: str):
|
||||||
logger = logging.getLogger()
|
if verbosity == "DEBUG":
|
||||||
|
log_format = "[%(levelname)s] - %(name)s:%(lineno)d - %(message)s"
|
||||||
|
else:
|
||||||
|
log_format = "[%(levelname)s] - %(message)s"
|
||||||
|
|
||||||
handler = logging.StreamHandler(sys.stdout)
|
logging.basicConfig(level=verbosity, format=log_format, stream=sys.stdout)
|
||||||
|
|
||||||
formatter = logging.Formatter("[%(levelname)s] - %(name)s:%(lineno)d - %(message)s")
|
|
||||||
handler.setFormatter(formatter)
|
|
||||||
|
|
||||||
logger.addHandler(handler)
|
|
||||||
|
|
||||||
|
|
||||||
__all__ = ["setup_logger"]
|
__all__ = ["setup_logger"]
|
||||||
|
Loading…
x
Reference in New Issue
Block a user