From a70d052cbfcbef386ebb49fe5751cb02c68c961a Mon Sep 17 00:00:00 2001 From: Thastertyn Date: Tue, 4 Feb 2025 16:46:21 +0100 Subject: [PATCH] Major changes, most notably structure for commands is being implemented and full cleanup on exit is performed --- .env.example | 5 ++ pyproject.toml | 2 +- src/app.py | 3 +- src/bank_node/bank_node.py | 65 ++++++++++---- src/bank_node/bank_process.py | 4 - src/bank_node/bank_worker.py | 88 +++++++++++++++++++ src/bank_protocol/command_handler.py | 42 +++++++++ src/bank_protocol/commands/__init__.py | 19 ++++ .../commands/account_balance_command.py | 8 ++ .../commands/account_create_command.py | 7 ++ .../commands/account_deposit_command.py | 7 ++ .../commands/account_remove_command.py | 7 ++ .../commands/account_withdrawal_command.py | 7 ++ .../commands/bank_code_command.py | 12 +++ .../bank_number_of_clients_command.py | 7 ++ .../commands/bank_total_amount_command.py | 7 ++ src/core/__init__.py | 7 ++ src/core/config.py | 42 +++++++-- src/core/request.py | 32 +++++-- src/core/response.py | 3 + src/utils/constants.py | 5 ++ src/utils/ip.py | 17 ++++ src/utils/logger.py | 4 +- 23 files changed, 360 insertions(+), 40 deletions(-) delete mode 100644 src/bank_node/bank_process.py create mode 100644 src/bank_node/bank_worker.py create mode 100644 src/utils/constants.py create mode 100644 src/utils/ip.py diff --git a/.env.example b/.env.example index c415451..2025791 100644 --- a/.env.example +++ b/.env.example @@ -3,6 +3,11 @@ # invalid value is provided RESPONSE_TIMEOUT=5 +# In seconds +# Default 60 if user doesn't interact +# within this timeframe +CLIENT_IDLE_TIMEOUT=60 + # A valid port number # If not provided or invalid, defaults to 65526 PORT=65526 diff --git a/pyproject.toml b/pyproject.toml index bcfd786..69e18b9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -6,7 +6,7 @@ authors = ["Thastertyn "] readme = "README.md" [tool.poetry.dependencies] -python = "^3.12" +python = "^3.6" sqlalchemy = "^2.0.37" python-dotenv = "^1.0.1" diff --git a/src/app.py b/src/app.py index 9626a00..22a47cf 100644 --- a/src/app.py +++ b/src/app.py @@ -1,4 +1,5 @@ from bank_node.bank_node import BankNode if __name__ == "__main__": - BankNode().start_server() + node = BankNode() + node.start() diff --git a/src/bank_node/bank_node.py b/src/bank_node/bank_node.py index 24dc8b8..bf646ee 100644 --- a/src/bank_node/bank_node.py +++ b/src/bank_node/bank_node.py @@ -3,12 +3,14 @@ import signal import sys import logging - from core.config import BankNodeConfig from core.exceptions import ConfigError from utils import setup_logger from database.database_manager import DatabaseManager +from bank_node.bank_worker import BankWorker + + class BankNode(): def __init__(self): try: @@ -16,20 +18,28 @@ class BankNode(): setup_logger(self.config.verbosity) self.logger = logging.getLogger(__name__) - + self.logger.info("Config is valid") self.logger.debug("Starting Bank Node") - self._setup_signals() + self.__setup_signals() self.database_manager = DatabaseManager() + self.socket_server = None + except ConfigError as e: print(e) self.exit_with_error() - def _setup_signals(self): + def __setup_signals(self): self.logger.debug("Setting up exit signal hooks") + # Not as clean as + # atexit.register(self.gracefully_exit) + # But it gives more control and is easier + # to understand in the output + # and looks better + # Handle C standard signals signal.signal(signal.SIGTERM, self.gracefully_exit) signal.signal(signal.SIGINT, self.gracefully_exit) @@ -40,27 +50,48 @@ class BankNode(): signal.signal(signal.CTRL_BREAK_EVENT, self.gracefully_exit) signal.signal(signal.SIGBREAK, self.gracefully_exit) - def start_server(self): + def start(self): + for port in range(self.config.scan_port_start, self.config.scan_port_end + 1): + self.logger.debug("Trying port %d", port) + try: + self.config.used_port = port + self.__start_server(port) + return + except socket.error as e: + if e.errno == 98: # Address is in use + self.logger.info("Port %d in use, trying next port", port) + + self.logger.error("All ports are in use") + self.exit_with_error() + + def __start_server(self, port: int): with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as socket_server: - socket_server.bind(("127.0.0.1", self.config.port)) + socket_server.bind((self.config.ip, port)) socket_server.listen() - self.logger.info("Listening on %s:%s", "127.0.0.1", self.config.port) + self.socket_server = socket_server + self.logger.info("Listening on %s:%s", self.config.ip, port) + while True: - conn, addr = socket_server.accept() - with conn: - print(f"{addr} connected") - request = conn.recv(1024).decode("utf-8") - print(f"Request:\n{request}") + client_socket, address = socket_server.accept() + self.logger.info("%s connected", address[0]) + + process = BankWorker(client_socket, address, self.config.to_dict()) + process.start() def exit_with_error(self): - self.cleanup() + """Exit the application with status of 1""" + sys.exit(1) - def cleanup(self): - pass + def gracefully_exit(self, signum, _): + """Log the signal caught and exit with status 0""" - def gracefully_exit(self, signum, frame): signal_name = signal.Signals(signum).name - self.logger.warning("Caught %s. Cleaning up before exiting", signal_name) + self.logger.warning("Caught %s - cleaning up", signal_name) self.cleanup() + self.logger.info("Exiting") sys.exit(0) + + def cleanup(self): + self.logger.debug("Closing socket server") + self.socket_server.close() diff --git a/src/bank_node/bank_process.py b/src/bank_node/bank_process.py deleted file mode 100644 index 46405fe..0000000 --- a/src/bank_node/bank_process.py +++ /dev/null @@ -1,4 +0,0 @@ -from multiprocessing import Process - -class BankProcess(Process): - pass \ No newline at end of file diff --git a/src/bank_node/bank_worker.py b/src/bank_node/bank_worker.py new file mode 100644 index 0000000..764a727 --- /dev/null +++ b/src/bank_node/bank_worker.py @@ -0,0 +1,88 @@ +import socket +import multiprocessing +import logging +from typing import Tuple, Dict +import signal +import sys + +from bank_protocol.command_handler import CommandHandler +from core import Request, Response +from core.exceptions import BankNodeError + + +class BankWorker(multiprocessing.Process): + def __init__(self, client_socket: socket.socket, client_address: Tuple, config: Dict): + super().__init__() + + self.logger = logging.getLogger(__name__) + + self.client_socket = client_socket + self.client_socket.settimeout(config["client_idle_timeout"]) + self.client_address = client_address + + self.command_handler = CommandHandler(config) + self.config = config + + def __setup_signals(self): + self.logger.debug("Setting up exit signal hooks for worker") + + # Handle C standard signals + signal.signal(signal.SIGTERM, self.gracefully_exit_worker) + signal.signal(signal.SIGINT, self.gracefully_exit_worker) + + # Handle windows related signals + if sys.platform == "win32": + signal.signal(signal.CTRL_C_EVENT, self.gracefully_exit_worker) + signal.signal(signal.CTRL_BREAK_EVENT, self.gracefully_exit_worker) + signal.signal(signal.SIGBREAK, self.gracefully_exit_worker) + + def run(self): + self.__setup_signals() + with self.client_socket: + while True: + try: + raw_request = self.client_socket.recv(1024).decode("utf-8") + + if not raw_request: + self.logger.debug("%s disconnected", self.client_address[0]) + break + + request = Request(raw_request) + self.logger.debug("Received request from %s - %s", self.client_address[0], request) + + response: Response = self.command_handler.execute(request) + "\n\r" + + self.client_socket.sendall(response.encode("utf-8")) + + except socket.timeout: + self.logger.debug("Client was idle for too long. Ending connection") + response = "ER Idle too long\n\r" + self.client_socket.sendall(response.encode("utf-8")) + self.client_socket.shutdown(socket.SHUT_RDWR) + self.client_socket.close() + break + except BankNodeError as e: + response = "ER " + e.message + "\n\r" + self.client_socket.sendall(response.encode("utf-8")) + except socket.error as e: + response = "ER Internal server error\n\r" + self.logger.error(e) + break + + self.logger.debug("Closing process for %s", self.client_address[0]) + + def gracefully_exit_worker(self, signum, _): + """Log the signal caught and exit with status 0""" + + signal_name = signal.Signals(signum).name + self.logger.warning("Worker caught %s - cleaning up", signal_name) + self.cleanup() + sys.exit(0) + + def cleanup(self): + self.logger.info("Closing connection with %s", self.client_address[0]) + if self.client_socket: + self.client_socket.close() + + +__all__ = ["BankWorker"] diff --git a/src/bank_protocol/command_handler.py b/src/bank_protocol/command_handler.py index e69de29..2187897 100644 --- a/src/bank_protocol/command_handler.py +++ b/src/bank_protocol/command_handler.py @@ -0,0 +1,42 @@ +import logging +from typing import Dict, Callable + +from core import Request, Response +from core.exceptions import BankNodeError + +from bank_protocol.commands import (account_balance, + account_create, + account_deposit, + account_remove, + account_withdrawal, + bank_code, + bank_number_of_clients, + bank_total_amount) + + +class CommandHandler: + def __init__(self, config: Dict): + self.logger = logging.getLogger(__name__) + self.config = config + + self.registered_commands: Dict[str, Callable] = { + "BC": bank_code, + "AC": account_create, + "AD": account_deposit, + "AW": account_withdrawal, + "AB": account_balance, + "AR": account_remove, + "BA": bank_total_amount, + "BN": bank_number_of_clients + } + + def execute(self, request: Request) -> Response: + if request.command_code not in self.registered_commands: + self.logger.warning("Unknown command %s", request.command_code) + raise BankNodeError(f"Unknown command {request.command_code}") + + command = self.registered_commands[request.command_code] + + response = command(request, self.config) + + return f"{request.command_code} {response}" diff --git a/src/bank_protocol/commands/__init__.py b/src/bank_protocol/commands/__init__.py index e69de29..113590d 100644 --- a/src/bank_protocol/commands/__init__.py +++ b/src/bank_protocol/commands/__init__.py @@ -0,0 +1,19 @@ +from .account_balance_command import * +from .account_create_command import * +from .account_deposit_command import * +from .account_remove_command import * +from .account_withdrawal_command import * +from .bank_code_command import * +from .bank_number_of_clients_command import * +from .bank_total_amount_command import * + +__all__ = [ + *account_balance_command.__all__, + *account_create_command.__all__, + *account_deposit_command.__all__, + *account_remove_command.__all__, + *account_withdrawal_command.__all__, + *bank_code_command.__all__, + *bank_number_of_clients_command.__all__, + *bank_total_amount_command.__all__ +] diff --git a/src/bank_protocol/commands/account_balance_command.py b/src/bank_protocol/commands/account_balance_command.py index e69de29..cc5f3e5 100644 --- a/src/bank_protocol/commands/account_balance_command.py +++ b/src/bank_protocol/commands/account_balance_command.py @@ -0,0 +1,8 @@ +from core.request import Request +import re + +def account_balance(request: Request): + pass + + +__all__ = ["account_balance"] diff --git a/src/bank_protocol/commands/account_create_command.py b/src/bank_protocol/commands/account_create_command.py index e69de29..acc007e 100644 --- a/src/bank_protocol/commands/account_create_command.py +++ b/src/bank_protocol/commands/account_create_command.py @@ -0,0 +1,7 @@ +from core.request import Request + +def account_create(request: Request): + pass + + +__all__ = ["account_create"] diff --git a/src/bank_protocol/commands/account_deposit_command.py b/src/bank_protocol/commands/account_deposit_command.py index e69de29..8bdea17 100644 --- a/src/bank_protocol/commands/account_deposit_command.py +++ b/src/bank_protocol/commands/account_deposit_command.py @@ -0,0 +1,7 @@ +from core.request import Request + +def account_deposit(request: Request): + pass + + +__all__ = ["account_deposit"] diff --git a/src/bank_protocol/commands/account_remove_command.py b/src/bank_protocol/commands/account_remove_command.py index e69de29..db8a873 100644 --- a/src/bank_protocol/commands/account_remove_command.py +++ b/src/bank_protocol/commands/account_remove_command.py @@ -0,0 +1,7 @@ +from core.request import Request + +def account_remove(request: Request): + pass + + +__all__ = ["account_remove"] diff --git a/src/bank_protocol/commands/account_withdrawal_command.py b/src/bank_protocol/commands/account_withdrawal_command.py index e69de29..2111852 100644 --- a/src/bank_protocol/commands/account_withdrawal_command.py +++ b/src/bank_protocol/commands/account_withdrawal_command.py @@ -0,0 +1,7 @@ +from core.request import Request + +def account_withdrawal(request: Request): + pass + + +__all__ = ["account_withdrawal"] diff --git a/src/bank_protocol/commands/bank_code_command.py b/src/bank_protocol/commands/bank_code_command.py index e69de29..97e85a4 100644 --- a/src/bank_protocol/commands/bank_code_command.py +++ b/src/bank_protocol/commands/bank_code_command.py @@ -0,0 +1,12 @@ +from typing import Dict + +from core import Request, Response +from bank_protocol.exceptions import InvalidRequest + +def bank_code(request: Request, config: Dict) -> Response: + if request.body is not None: + raise InvalidRequest("Incorrect usage") + + return config["ip"] + +__all__ = ["bank_code"] diff --git a/src/bank_protocol/commands/bank_number_of_clients_command.py b/src/bank_protocol/commands/bank_number_of_clients_command.py index e69de29..0f34fa9 100644 --- a/src/bank_protocol/commands/bank_number_of_clients_command.py +++ b/src/bank_protocol/commands/bank_number_of_clients_command.py @@ -0,0 +1,7 @@ +from core.request import Request + +def bank_number_of_clients(request: Request): + pass + + +__all__ = ["bank_number_of_clients"] diff --git a/src/bank_protocol/commands/bank_total_amount_command.py b/src/bank_protocol/commands/bank_total_amount_command.py index e69de29..7ae4bed 100644 --- a/src/bank_protocol/commands/bank_total_amount_command.py +++ b/src/bank_protocol/commands/bank_total_amount_command.py @@ -0,0 +1,7 @@ +from core.request import Request + +def bank_total_amount(request: Request): + pass + + +__all__ = ["bank_total_amount"] diff --git a/src/core/__init__.py b/src/core/__init__.py index e69de29..43b6058 100644 --- a/src/core/__init__.py +++ b/src/core/__init__.py @@ -0,0 +1,7 @@ +from .request import * +from .response import * + +__all__ = [ + *request.__all__, + *response.__all__ +] \ No newline at end of file diff --git a/src/core/config.py b/src/core/config.py index d3b685d..35fe4b5 100644 --- a/src/core/config.py +++ b/src/core/config.py @@ -1,9 +1,12 @@ import os import logging +import re import dotenv from core.exceptions import ConfigError +from utils.ip import get_ip +from utils.constants import IP_REGEX dotenv.load_dotenv() @@ -11,11 +14,17 @@ dotenv.load_dotenv() class BankNodeConfig: def __init__(self): self.logger = logging.getLogger(__name__) - port = os.getenv("PORT", "65526") + # Added for compatibility with python 3.9 instead of using + # logging.getLevelNamesMapping() + allowed_levels = ["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"] + + ip = os.getenv("IP", get_ip()) timeout = os.getenv("RESPONSE_TIMEOUT", "5") + client_idle_timeout = os.getenv("CLIENT_IDLE_TIMEOUT", "60") verbosity = os.getenv("VERBOSITY", "INFO") scan_port_range = os.getenv("SCAN_PORT_RANGE") + # Port validation if not scan_port_range or scan_port_range == "": self.logger.error("Scan port range not defined") raise ConfigError("Scan port range not defined") @@ -26,20 +35,37 @@ class BankNodeConfig: self.logger.error("Scan port range is not in valid format") raise ConfigError("Scan port range is not in valid format") - - if not port.isdigit(): - self.logger.warning("Application port is invalid - Not a number. Falling back to default port 65526") - + # Timeout validation if not timeout.isdigit(): self.logger.warning("Request timeout is invalid - Not a number. Falling back to default timeout 5 seconds") - allowed_levels = logging.getLevelNamesMapping() + if not client_idle_timeout.isdigit(): + self.logger.warning("Client idle timeout is invalid - Not a number. Falling back to default idle timeout 60 seconds") + # Verbosity if verbosity not in allowed_levels: - self.logger.warning("Verbosity %s is invalid - Doesn't exist. Falling back to default verbosity INFO", verbosity) + self.logger.warning("Unknown verbosity %s. Falling back to default verbosity INFO", verbosity) - self.port = int(port) + # IP validation + if not re.match(IP_REGEX, ip): + self.logger.error("Invalid IP in configuration") + raise ConfigError("Invalid IP in configuration") + + self.used_port: int + self.ip = ip self.timeout = int(timeout) + self.client_idle_timeout = int(client_idle_timeout) self.verbosity = verbosity self.scan_port_start = int(range_split[0]) self.scan_port_end = int(range_split[1]) + + def to_dict(self): + return { + "used_port": self.used_port, + "ip": self.ip, + "timeout": self.timeout, + "client_idle_timeout": self.client_idle_timeout, + "verbosity": self.verbosity, + "scan_port_start": self.scan_port_start, + "scan_port_end": self.scan_port_end, + } diff --git a/src/core/request.py b/src/core/request.py index 61c0f25..2d2de02 100644 --- a/src/core/request.py +++ b/src/core/request.py @@ -1,10 +1,28 @@ -from core.exceptions import * +import re + +from bank_protocol.exceptions import InvalidRequest class Request(): - def __init__(self, raw_request): - try: - self.command_code = raw_request[0:1] - self.body = raw_request[2:-1] - except IndexError: - pass + + def __init__(self, raw_request: str): + if re.match(r"^[A-Z]{2}$", raw_request): + self.command_code = raw_request[0:2] # Still take the first 2 characters, because of lingering crlf + self.body = None + elif re.match(r"^[A-Z]{2} .+", raw_request): + command_code: str = raw_request[0:2] + body: str = raw_request[3:-1] or "" + + if len(body.split("\n")) > 1: + raise InvalidRequest("Multiline requests are not supported") + + self.command_code = command_code + self.body = body + else: + raise InvalidRequest("Invalid request") + + def __str__(self): + return f"{self.command_code} {self.body}" + + +__all__ = ["Request"] diff --git a/src/core/response.py b/src/core/response.py index e69de29..da10372 100644 --- a/src/core/response.py +++ b/src/core/response.py @@ -0,0 +1,3 @@ +Response = str + +__all__ = ["Response"] diff --git a/src/utils/constants.py b/src/utils/constants.py new file mode 100644 index 0000000..a13813c --- /dev/null +++ b/src/utils/constants.py @@ -0,0 +1,5 @@ +import re + +IP_REGEX = r"^[0-9]{1,3}.[0-9]{1,3}.[0-9]{1,3}.[0-9]{1,3}$" +ACCOUNT_NUMBER_REGEX = r"[0-9]{9}" +MONEY_AMOUNT_MAXIMUM = (2 ^ 63) - 1 diff --git a/src/utils/ip.py b/src/utils/ip.py new file mode 100644 index 0000000..cac71b6 --- /dev/null +++ b/src/utils/ip.py @@ -0,0 +1,17 @@ +import socket + + +def get_ip(): + s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) + s.settimeout(0) + try: + s.connect(('1.1.1.1', 1)) + ip = s.getsockname()[0] + except Exception: + ip = '127.0.0.1' + finally: + s.close() + return ip + + +__all__ = ["get_ip"] diff --git a/src/utils/logger.py b/src/utils/logger.py index 1afa0b0..c21ab50 100644 --- a/src/utils/logger.py +++ b/src/utils/logger.py @@ -4,9 +4,9 @@ import logging def setup_logger(verbosity: str): if verbosity == "DEBUG": - log_format = "[%(levelname)s] - %(name)s:%(lineno)d - %(message)s" + log_format = "[ %(levelname)s / %(processName)s ] - %(name)s:%(lineno)d - %(message)s" else: - log_format = "[%(levelname)s] - %(message)s" + log_format = "[ %(levelname)s ] - %(message)s" logging.basicConfig(level=verbosity, format=log_format, stream=sys.stdout)