Major changes, most notably structure for commands is being implemented and full cleanup on exit is performed

This commit is contained in:
Thastertyn 2025-02-04 16:46:21 +01:00
parent 61399c555a
commit a70d052cbf
23 changed files with 360 additions and 40 deletions

View File

@ -3,6 +3,11 @@
# invalid value is provided # invalid value is provided
RESPONSE_TIMEOUT=5 RESPONSE_TIMEOUT=5
# In seconds
# Default 60 if user doesn't interact
# within this timeframe
CLIENT_IDLE_TIMEOUT=60
# A valid port number # A valid port number
# If not provided or invalid, defaults to 65526 # If not provided or invalid, defaults to 65526
PORT=65526 PORT=65526

View File

@ -6,7 +6,7 @@ authors = ["Thastertyn <thastertyn@thastertyn.xyz>"]
readme = "README.md" readme = "README.md"
[tool.poetry.dependencies] [tool.poetry.dependencies]
python = "^3.12" python = "^3.6"
sqlalchemy = "^2.0.37" sqlalchemy = "^2.0.37"
python-dotenv = "^1.0.1" python-dotenv = "^1.0.1"

View File

@ -1,4 +1,5 @@
from bank_node.bank_node import BankNode from bank_node.bank_node import BankNode
if __name__ == "__main__": if __name__ == "__main__":
BankNode().start_server() node = BankNode()
node.start()

View File

@ -3,12 +3,14 @@ import signal
import sys import sys
import logging import logging
from core.config import BankNodeConfig from core.config import BankNodeConfig
from core.exceptions import ConfigError from core.exceptions import ConfigError
from utils import setup_logger from utils import setup_logger
from database.database_manager import DatabaseManager from database.database_manager import DatabaseManager
from bank_node.bank_worker import BankWorker
class BankNode(): class BankNode():
def __init__(self): def __init__(self):
try: try:
@ -20,16 +22,24 @@ class BankNode():
self.logger.info("Config is valid") self.logger.info("Config is valid")
self.logger.debug("Starting Bank Node") self.logger.debug("Starting Bank Node")
self._setup_signals() self.__setup_signals()
self.database_manager = DatabaseManager() self.database_manager = DatabaseManager()
self.socket_server = None
except ConfigError as e: except ConfigError as e:
print(e) print(e)
self.exit_with_error() self.exit_with_error()
def _setup_signals(self): def __setup_signals(self):
self.logger.debug("Setting up exit signal hooks") self.logger.debug("Setting up exit signal hooks")
# Not as clean as
# atexit.register(self.gracefully_exit)
# But it gives more control and is easier
# to understand in the output
# and looks better
# Handle C standard signals # Handle C standard signals
signal.signal(signal.SIGTERM, self.gracefully_exit) signal.signal(signal.SIGTERM, self.gracefully_exit)
signal.signal(signal.SIGINT, self.gracefully_exit) signal.signal(signal.SIGINT, self.gracefully_exit)
@ -40,27 +50,48 @@ class BankNode():
signal.signal(signal.CTRL_BREAK_EVENT, self.gracefully_exit) signal.signal(signal.CTRL_BREAK_EVENT, self.gracefully_exit)
signal.signal(signal.SIGBREAK, self.gracefully_exit) signal.signal(signal.SIGBREAK, self.gracefully_exit)
def start_server(self): def start(self):
for port in range(self.config.scan_port_start, self.config.scan_port_end + 1):
self.logger.debug("Trying port %d", port)
try:
self.config.used_port = port
self.__start_server(port)
return
except socket.error as e:
if e.errno == 98: # Address is in use
self.logger.info("Port %d in use, trying next port", port)
self.logger.error("All ports are in use")
self.exit_with_error()
def __start_server(self, port: int):
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as socket_server: with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as socket_server:
socket_server.bind(("127.0.0.1", self.config.port)) socket_server.bind((self.config.ip, port))
socket_server.listen() socket_server.listen()
self.logger.info("Listening on %s:%s", "127.0.0.1", self.config.port) self.socket_server = socket_server
self.logger.info("Listening on %s:%s", self.config.ip, port)
while True: while True:
conn, addr = socket_server.accept() client_socket, address = socket_server.accept()
with conn: self.logger.info("%s connected", address[0])
print(f"{addr} connected")
request = conn.recv(1024).decode("utf-8") process = BankWorker(client_socket, address, self.config.to_dict())
print(f"Request:\n{request}") process.start()
def exit_with_error(self): def exit_with_error(self):
self.cleanup() """Exit the application with status of 1"""
sys.exit(1) sys.exit(1)
def cleanup(self): def gracefully_exit(self, signum, _):
pass """Log the signal caught and exit with status 0"""
def gracefully_exit(self, signum, frame):
signal_name = signal.Signals(signum).name signal_name = signal.Signals(signum).name
self.logger.warning("Caught %s. Cleaning up before exiting", signal_name) self.logger.warning("Caught %s - cleaning up", signal_name)
self.cleanup() self.cleanup()
self.logger.info("Exiting")
sys.exit(0) sys.exit(0)
def cleanup(self):
self.logger.debug("Closing socket server")
self.socket_server.close()

View File

@ -1,4 +0,0 @@
from multiprocessing import Process
class BankProcess(Process):
pass

View File

@ -0,0 +1,88 @@
import socket
import multiprocessing
import logging
from typing import Tuple, Dict
import signal
import sys
from bank_protocol.command_handler import CommandHandler
from core import Request, Response
from core.exceptions import BankNodeError
class BankWorker(multiprocessing.Process):
def __init__(self, client_socket: socket.socket, client_address: Tuple, config: Dict):
super().__init__()
self.logger = logging.getLogger(__name__)
self.client_socket = client_socket
self.client_socket.settimeout(config["client_idle_timeout"])
self.client_address = client_address
self.command_handler = CommandHandler(config)
self.config = config
def __setup_signals(self):
self.logger.debug("Setting up exit signal hooks for worker")
# Handle C standard signals
signal.signal(signal.SIGTERM, self.gracefully_exit_worker)
signal.signal(signal.SIGINT, self.gracefully_exit_worker)
# Handle windows related signals
if sys.platform == "win32":
signal.signal(signal.CTRL_C_EVENT, self.gracefully_exit_worker)
signal.signal(signal.CTRL_BREAK_EVENT, self.gracefully_exit_worker)
signal.signal(signal.SIGBREAK, self.gracefully_exit_worker)
def run(self):
self.__setup_signals()
with self.client_socket:
while True:
try:
raw_request = self.client_socket.recv(1024).decode("utf-8")
if not raw_request:
self.logger.debug("%s disconnected", self.client_address[0])
break
request = Request(raw_request)
self.logger.debug("Received request from %s - %s", self.client_address[0], request)
response: Response = self.command_handler.execute(request) + "\n\r"
self.client_socket.sendall(response.encode("utf-8"))
except socket.timeout:
self.logger.debug("Client was idle for too long. Ending connection")
response = "ER Idle too long\n\r"
self.client_socket.sendall(response.encode("utf-8"))
self.client_socket.shutdown(socket.SHUT_RDWR)
self.client_socket.close()
break
except BankNodeError as e:
response = "ER " + e.message + "\n\r"
self.client_socket.sendall(response.encode("utf-8"))
except socket.error as e:
response = "ER Internal server error\n\r"
self.logger.error(e)
break
self.logger.debug("Closing process for %s", self.client_address[0])
def gracefully_exit_worker(self, signum, _):
"""Log the signal caught and exit with status 0"""
signal_name = signal.Signals(signum).name
self.logger.warning("Worker caught %s - cleaning up", signal_name)
self.cleanup()
sys.exit(0)
def cleanup(self):
self.logger.info("Closing connection with %s", self.client_address[0])
if self.client_socket:
self.client_socket.close()
__all__ = ["BankWorker"]

View File

@ -0,0 +1,42 @@
import logging
from typing import Dict, Callable
from core import Request, Response
from core.exceptions import BankNodeError
from bank_protocol.commands import (account_balance,
account_create,
account_deposit,
account_remove,
account_withdrawal,
bank_code,
bank_number_of_clients,
bank_total_amount)
class CommandHandler:
def __init__(self, config: Dict):
self.logger = logging.getLogger(__name__)
self.config = config
self.registered_commands: Dict[str, Callable] = {
"BC": bank_code,
"AC": account_create,
"AD": account_deposit,
"AW": account_withdrawal,
"AB": account_balance,
"AR": account_remove,
"BA": bank_total_amount,
"BN": bank_number_of_clients
}
def execute(self, request: Request) -> Response:
if request.command_code not in self.registered_commands:
self.logger.warning("Unknown command %s", request.command_code)
raise BankNodeError(f"Unknown command {request.command_code}")
command = self.registered_commands[request.command_code]
response = command(request, self.config)
return f"{request.command_code} {response}"

View File

@ -0,0 +1,19 @@
from .account_balance_command import *
from .account_create_command import *
from .account_deposit_command import *
from .account_remove_command import *
from .account_withdrawal_command import *
from .bank_code_command import *
from .bank_number_of_clients_command import *
from .bank_total_amount_command import *
__all__ = [
*account_balance_command.__all__,
*account_create_command.__all__,
*account_deposit_command.__all__,
*account_remove_command.__all__,
*account_withdrawal_command.__all__,
*bank_code_command.__all__,
*bank_number_of_clients_command.__all__,
*bank_total_amount_command.__all__
]

View File

@ -0,0 +1,8 @@
from core.request import Request
import re
def account_balance(request: Request):
pass
__all__ = ["account_balance"]

View File

@ -0,0 +1,7 @@
from core.request import Request
def account_create(request: Request):
pass
__all__ = ["account_create"]

View File

@ -0,0 +1,7 @@
from core.request import Request
def account_deposit(request: Request):
pass
__all__ = ["account_deposit"]

View File

@ -0,0 +1,7 @@
from core.request import Request
def account_remove(request: Request):
pass
__all__ = ["account_remove"]

View File

@ -0,0 +1,7 @@
from core.request import Request
def account_withdrawal(request: Request):
pass
__all__ = ["account_withdrawal"]

View File

@ -0,0 +1,12 @@
from typing import Dict
from core import Request, Response
from bank_protocol.exceptions import InvalidRequest
def bank_code(request: Request, config: Dict) -> Response:
if request.body is not None:
raise InvalidRequest("Incorrect usage")
return config["ip"]
__all__ = ["bank_code"]

View File

@ -0,0 +1,7 @@
from core.request import Request
def bank_number_of_clients(request: Request):
pass
__all__ = ["bank_number_of_clients"]

View File

@ -0,0 +1,7 @@
from core.request import Request
def bank_total_amount(request: Request):
pass
__all__ = ["bank_total_amount"]

View File

@ -0,0 +1,7 @@
from .request import *
from .response import *
__all__ = [
*request.__all__,
*response.__all__
]

View File

@ -1,9 +1,12 @@
import os import os
import logging import logging
import re
import dotenv import dotenv
from core.exceptions import ConfigError from core.exceptions import ConfigError
from utils.ip import get_ip
from utils.constants import IP_REGEX
dotenv.load_dotenv() dotenv.load_dotenv()
@ -11,11 +14,17 @@ dotenv.load_dotenv()
class BankNodeConfig: class BankNodeConfig:
def __init__(self): def __init__(self):
self.logger = logging.getLogger(__name__) self.logger = logging.getLogger(__name__)
port = os.getenv("PORT", "65526") # Added for compatibility with python 3.9 instead of using
# logging.getLevelNamesMapping()
allowed_levels = ["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"]
ip = os.getenv("IP", get_ip())
timeout = os.getenv("RESPONSE_TIMEOUT", "5") timeout = os.getenv("RESPONSE_TIMEOUT", "5")
client_idle_timeout = os.getenv("CLIENT_IDLE_TIMEOUT", "60")
verbosity = os.getenv("VERBOSITY", "INFO") verbosity = os.getenv("VERBOSITY", "INFO")
scan_port_range = os.getenv("SCAN_PORT_RANGE") scan_port_range = os.getenv("SCAN_PORT_RANGE")
# Port validation
if not scan_port_range or scan_port_range == "": if not scan_port_range or scan_port_range == "":
self.logger.error("Scan port range not defined") self.logger.error("Scan port range not defined")
raise ConfigError("Scan port range not defined") raise ConfigError("Scan port range not defined")
@ -26,20 +35,37 @@ class BankNodeConfig:
self.logger.error("Scan port range is not in valid format") self.logger.error("Scan port range is not in valid format")
raise ConfigError("Scan port range is not in valid format") raise ConfigError("Scan port range is not in valid format")
# Timeout validation
if not port.isdigit():
self.logger.warning("Application port is invalid - Not a number. Falling back to default port 65526")
if not timeout.isdigit(): if not timeout.isdigit():
self.logger.warning("Request timeout is invalid - Not a number. Falling back to default timeout 5 seconds") self.logger.warning("Request timeout is invalid - Not a number. Falling back to default timeout 5 seconds")
allowed_levels = logging.getLevelNamesMapping() if not client_idle_timeout.isdigit():
self.logger.warning("Client idle timeout is invalid - Not a number. Falling back to default idle timeout 60 seconds")
# Verbosity
if verbosity not in allowed_levels: if verbosity not in allowed_levels:
self.logger.warning("Verbosity %s is invalid - Doesn't exist. Falling back to default verbosity INFO", verbosity) self.logger.warning("Unknown verbosity %s. Falling back to default verbosity INFO", verbosity)
self.port = int(port) # IP validation
if not re.match(IP_REGEX, ip):
self.logger.error("Invalid IP in configuration")
raise ConfigError("Invalid IP in configuration")
self.used_port: int
self.ip = ip
self.timeout = int(timeout) self.timeout = int(timeout)
self.client_idle_timeout = int(client_idle_timeout)
self.verbosity = verbosity self.verbosity = verbosity
self.scan_port_start = int(range_split[0]) self.scan_port_start = int(range_split[0])
self.scan_port_end = int(range_split[1]) self.scan_port_end = int(range_split[1])
def to_dict(self):
return {
"used_port": self.used_port,
"ip": self.ip,
"timeout": self.timeout,
"client_idle_timeout": self.client_idle_timeout,
"verbosity": self.verbosity,
"scan_port_start": self.scan_port_start,
"scan_port_end": self.scan_port_end,
}

View File

@ -1,10 +1,28 @@
from core.exceptions import * import re
from bank_protocol.exceptions import InvalidRequest
class Request(): class Request():
def __init__(self, raw_request):
try: def __init__(self, raw_request: str):
self.command_code = raw_request[0:1] if re.match(r"^[A-Z]{2}$", raw_request):
self.body = raw_request[2:-1] self.command_code = raw_request[0:2] # Still take the first 2 characters, because of lingering crlf
except IndexError: self.body = None
pass elif re.match(r"^[A-Z]{2} .+", raw_request):
command_code: str = raw_request[0:2]
body: str = raw_request[3:-1] or ""
if len(body.split("\n")) > 1:
raise InvalidRequest("Multiline requests are not supported")
self.command_code = command_code
self.body = body
else:
raise InvalidRequest("Invalid request")
def __str__(self):
return f"{self.command_code} {self.body}"
__all__ = ["Request"]

View File

@ -0,0 +1,3 @@
Response = str
__all__ = ["Response"]

5
src/utils/constants.py Normal file
View File

@ -0,0 +1,5 @@
import re
IP_REGEX = r"^[0-9]{1,3}.[0-9]{1,3}.[0-9]{1,3}.[0-9]{1,3}$"
ACCOUNT_NUMBER_REGEX = r"[0-9]{9}"
MONEY_AMOUNT_MAXIMUM = (2 ^ 63) - 1

17
src/utils/ip.py Normal file
View File

@ -0,0 +1,17 @@
import socket
def get_ip():
s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
s.settimeout(0)
try:
s.connect(('1.1.1.1', 1))
ip = s.getsockname()[0]
except Exception:
ip = '127.0.0.1'
finally:
s.close()
return ip
__all__ = ["get_ip"]

View File

@ -4,7 +4,7 @@ import logging
def setup_logger(verbosity: str): def setup_logger(verbosity: str):
if verbosity == "DEBUG": if verbosity == "DEBUG":
log_format = "[%(levelname)s] - %(name)s:%(lineno)d - %(message)s" log_format = "[ %(levelname)s / %(processName)s ] - %(name)s:%(lineno)d - %(message)s"
else: else:
log_format = "[ %(levelname)s ] - %(message)s" log_format = "[ %(levelname)s ] - %(message)s"