Major changes, most notably structure for commands is being implemented and full cleanup on exit is performed

This commit is contained in:
Thastertyn 2025-02-04 16:46:21 +01:00
parent 61399c555a
commit a70d052cbf
23 changed files with 360 additions and 40 deletions

View File

@ -3,6 +3,11 @@
# invalid value is provided
RESPONSE_TIMEOUT=5
# In seconds
# Default 60 if user doesn't interact
# within this timeframe
CLIENT_IDLE_TIMEOUT=60
# A valid port number
# If not provided or invalid, defaults to 65526
PORT=65526

View File

@ -6,7 +6,7 @@ authors = ["Thastertyn <thastertyn@thastertyn.xyz>"]
readme = "README.md"
[tool.poetry.dependencies]
python = "^3.12"
python = "^3.6"
sqlalchemy = "^2.0.37"
python-dotenv = "^1.0.1"

View File

@ -1,4 +1,5 @@
from bank_node.bank_node import BankNode
if __name__ == "__main__":
BankNode().start_server()
node = BankNode()
node.start()

View File

@ -3,12 +3,14 @@ import signal
import sys
import logging
from core.config import BankNodeConfig
from core.exceptions import ConfigError
from utils import setup_logger
from database.database_manager import DatabaseManager
from bank_node.bank_worker import BankWorker
class BankNode():
def __init__(self):
try:
@ -20,16 +22,24 @@ class BankNode():
self.logger.info("Config is valid")
self.logger.debug("Starting Bank Node")
self._setup_signals()
self.__setup_signals()
self.database_manager = DatabaseManager()
self.socket_server = None
except ConfigError as e:
print(e)
self.exit_with_error()
def _setup_signals(self):
def __setup_signals(self):
self.logger.debug("Setting up exit signal hooks")
# Not as clean as
# atexit.register(self.gracefully_exit)
# But it gives more control and is easier
# to understand in the output
# and looks better
# Handle C standard signals
signal.signal(signal.SIGTERM, self.gracefully_exit)
signal.signal(signal.SIGINT, self.gracefully_exit)
@ -40,27 +50,48 @@ class BankNode():
signal.signal(signal.CTRL_BREAK_EVENT, self.gracefully_exit)
signal.signal(signal.SIGBREAK, self.gracefully_exit)
def start_server(self):
def start(self):
for port in range(self.config.scan_port_start, self.config.scan_port_end + 1):
self.logger.debug("Trying port %d", port)
try:
self.config.used_port = port
self.__start_server(port)
return
except socket.error as e:
if e.errno == 98: # Address is in use
self.logger.info("Port %d in use, trying next port", port)
self.logger.error("All ports are in use")
self.exit_with_error()
def __start_server(self, port: int):
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as socket_server:
socket_server.bind(("127.0.0.1", self.config.port))
socket_server.bind((self.config.ip, port))
socket_server.listen()
self.logger.info("Listening on %s:%s", "127.0.0.1", self.config.port)
self.socket_server = socket_server
self.logger.info("Listening on %s:%s", self.config.ip, port)
while True:
conn, addr = socket_server.accept()
with conn:
print(f"{addr} connected")
request = conn.recv(1024).decode("utf-8")
print(f"Request:\n{request}")
client_socket, address = socket_server.accept()
self.logger.info("%s connected", address[0])
process = BankWorker(client_socket, address, self.config.to_dict())
process.start()
def exit_with_error(self):
self.cleanup()
"""Exit the application with status of 1"""
sys.exit(1)
def cleanup(self):
pass
def gracefully_exit(self, signum, _):
"""Log the signal caught and exit with status 0"""
def gracefully_exit(self, signum, frame):
signal_name = signal.Signals(signum).name
self.logger.warning("Caught %s. Cleaning up before exiting", signal_name)
self.logger.warning("Caught %s - cleaning up", signal_name)
self.cleanup()
self.logger.info("Exiting")
sys.exit(0)
def cleanup(self):
self.logger.debug("Closing socket server")
self.socket_server.close()

View File

@ -1,4 +0,0 @@
from multiprocessing import Process
class BankProcess(Process):
pass

View File

@ -0,0 +1,88 @@
import socket
import multiprocessing
import logging
from typing import Tuple, Dict
import signal
import sys
from bank_protocol.command_handler import CommandHandler
from core import Request, Response
from core.exceptions import BankNodeError
class BankWorker(multiprocessing.Process):
def __init__(self, client_socket: socket.socket, client_address: Tuple, config: Dict):
super().__init__()
self.logger = logging.getLogger(__name__)
self.client_socket = client_socket
self.client_socket.settimeout(config["client_idle_timeout"])
self.client_address = client_address
self.command_handler = CommandHandler(config)
self.config = config
def __setup_signals(self):
self.logger.debug("Setting up exit signal hooks for worker")
# Handle C standard signals
signal.signal(signal.SIGTERM, self.gracefully_exit_worker)
signal.signal(signal.SIGINT, self.gracefully_exit_worker)
# Handle windows related signals
if sys.platform == "win32":
signal.signal(signal.CTRL_C_EVENT, self.gracefully_exit_worker)
signal.signal(signal.CTRL_BREAK_EVENT, self.gracefully_exit_worker)
signal.signal(signal.SIGBREAK, self.gracefully_exit_worker)
def run(self):
self.__setup_signals()
with self.client_socket:
while True:
try:
raw_request = self.client_socket.recv(1024).decode("utf-8")
if not raw_request:
self.logger.debug("%s disconnected", self.client_address[0])
break
request = Request(raw_request)
self.logger.debug("Received request from %s - %s", self.client_address[0], request)
response: Response = self.command_handler.execute(request) + "\n\r"
self.client_socket.sendall(response.encode("utf-8"))
except socket.timeout:
self.logger.debug("Client was idle for too long. Ending connection")
response = "ER Idle too long\n\r"
self.client_socket.sendall(response.encode("utf-8"))
self.client_socket.shutdown(socket.SHUT_RDWR)
self.client_socket.close()
break
except BankNodeError as e:
response = "ER " + e.message + "\n\r"
self.client_socket.sendall(response.encode("utf-8"))
except socket.error as e:
response = "ER Internal server error\n\r"
self.logger.error(e)
break
self.logger.debug("Closing process for %s", self.client_address[0])
def gracefully_exit_worker(self, signum, _):
"""Log the signal caught and exit with status 0"""
signal_name = signal.Signals(signum).name
self.logger.warning("Worker caught %s - cleaning up", signal_name)
self.cleanup()
sys.exit(0)
def cleanup(self):
self.logger.info("Closing connection with %s", self.client_address[0])
if self.client_socket:
self.client_socket.close()
__all__ = ["BankWorker"]

View File

@ -0,0 +1,42 @@
import logging
from typing import Dict, Callable
from core import Request, Response
from core.exceptions import BankNodeError
from bank_protocol.commands import (account_balance,
account_create,
account_deposit,
account_remove,
account_withdrawal,
bank_code,
bank_number_of_clients,
bank_total_amount)
class CommandHandler:
def __init__(self, config: Dict):
self.logger = logging.getLogger(__name__)
self.config = config
self.registered_commands: Dict[str, Callable] = {
"BC": bank_code,
"AC": account_create,
"AD": account_deposit,
"AW": account_withdrawal,
"AB": account_balance,
"AR": account_remove,
"BA": bank_total_amount,
"BN": bank_number_of_clients
}
def execute(self, request: Request) -> Response:
if request.command_code not in self.registered_commands:
self.logger.warning("Unknown command %s", request.command_code)
raise BankNodeError(f"Unknown command {request.command_code}")
command = self.registered_commands[request.command_code]
response = command(request, self.config)
return f"{request.command_code} {response}"

View File

@ -0,0 +1,19 @@
from .account_balance_command import *
from .account_create_command import *
from .account_deposit_command import *
from .account_remove_command import *
from .account_withdrawal_command import *
from .bank_code_command import *
from .bank_number_of_clients_command import *
from .bank_total_amount_command import *
__all__ = [
*account_balance_command.__all__,
*account_create_command.__all__,
*account_deposit_command.__all__,
*account_remove_command.__all__,
*account_withdrawal_command.__all__,
*bank_code_command.__all__,
*bank_number_of_clients_command.__all__,
*bank_total_amount_command.__all__
]

View File

@ -0,0 +1,8 @@
from core.request import Request
import re
def account_balance(request: Request):
pass
__all__ = ["account_balance"]

View File

@ -0,0 +1,7 @@
from core.request import Request
def account_create(request: Request):
pass
__all__ = ["account_create"]

View File

@ -0,0 +1,7 @@
from core.request import Request
def account_deposit(request: Request):
pass
__all__ = ["account_deposit"]

View File

@ -0,0 +1,7 @@
from core.request import Request
def account_remove(request: Request):
pass
__all__ = ["account_remove"]

View File

@ -0,0 +1,7 @@
from core.request import Request
def account_withdrawal(request: Request):
pass
__all__ = ["account_withdrawal"]

View File

@ -0,0 +1,12 @@
from typing import Dict
from core import Request, Response
from bank_protocol.exceptions import InvalidRequest
def bank_code(request: Request, config: Dict) -> Response:
if request.body is not None:
raise InvalidRequest("Incorrect usage")
return config["ip"]
__all__ = ["bank_code"]

View File

@ -0,0 +1,7 @@
from core.request import Request
def bank_number_of_clients(request: Request):
pass
__all__ = ["bank_number_of_clients"]

View File

@ -0,0 +1,7 @@
from core.request import Request
def bank_total_amount(request: Request):
pass
__all__ = ["bank_total_amount"]

View File

@ -0,0 +1,7 @@
from .request import *
from .response import *
__all__ = [
*request.__all__,
*response.__all__
]

View File

@ -1,9 +1,12 @@
import os
import logging
import re
import dotenv
from core.exceptions import ConfigError
from utils.ip import get_ip
from utils.constants import IP_REGEX
dotenv.load_dotenv()
@ -11,11 +14,17 @@ dotenv.load_dotenv()
class BankNodeConfig:
def __init__(self):
self.logger = logging.getLogger(__name__)
port = os.getenv("PORT", "65526")
# Added for compatibility with python 3.9 instead of using
# logging.getLevelNamesMapping()
allowed_levels = ["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"]
ip = os.getenv("IP", get_ip())
timeout = os.getenv("RESPONSE_TIMEOUT", "5")
client_idle_timeout = os.getenv("CLIENT_IDLE_TIMEOUT", "60")
verbosity = os.getenv("VERBOSITY", "INFO")
scan_port_range = os.getenv("SCAN_PORT_RANGE")
# Port validation
if not scan_port_range or scan_port_range == "":
self.logger.error("Scan port range not defined")
raise ConfigError("Scan port range not defined")
@ -26,20 +35,37 @@ class BankNodeConfig:
self.logger.error("Scan port range is not in valid format")
raise ConfigError("Scan port range is not in valid format")
if not port.isdigit():
self.logger.warning("Application port is invalid - Not a number. Falling back to default port 65526")
# Timeout validation
if not timeout.isdigit():
self.logger.warning("Request timeout is invalid - Not a number. Falling back to default timeout 5 seconds")
allowed_levels = logging.getLevelNamesMapping()
if not client_idle_timeout.isdigit():
self.logger.warning("Client idle timeout is invalid - Not a number. Falling back to default idle timeout 60 seconds")
# Verbosity
if verbosity not in allowed_levels:
self.logger.warning("Verbosity %s is invalid - Doesn't exist. Falling back to default verbosity INFO", verbosity)
self.logger.warning("Unknown verbosity %s. Falling back to default verbosity INFO", verbosity)
self.port = int(port)
# IP validation
if not re.match(IP_REGEX, ip):
self.logger.error("Invalid IP in configuration")
raise ConfigError("Invalid IP in configuration")
self.used_port: int
self.ip = ip
self.timeout = int(timeout)
self.client_idle_timeout = int(client_idle_timeout)
self.verbosity = verbosity
self.scan_port_start = int(range_split[0])
self.scan_port_end = int(range_split[1])
def to_dict(self):
return {
"used_port": self.used_port,
"ip": self.ip,
"timeout": self.timeout,
"client_idle_timeout": self.client_idle_timeout,
"verbosity": self.verbosity,
"scan_port_start": self.scan_port_start,
"scan_port_end": self.scan_port_end,
}

View File

@ -1,10 +1,28 @@
from core.exceptions import *
import re
from bank_protocol.exceptions import InvalidRequest
class Request():
def __init__(self, raw_request):
try:
self.command_code = raw_request[0:1]
self.body = raw_request[2:-1]
except IndexError:
pass
def __init__(self, raw_request: str):
if re.match(r"^[A-Z]{2}$", raw_request):
self.command_code = raw_request[0:2] # Still take the first 2 characters, because of lingering crlf
self.body = None
elif re.match(r"^[A-Z]{2} .+", raw_request):
command_code: str = raw_request[0:2]
body: str = raw_request[3:-1] or ""
if len(body.split("\n")) > 1:
raise InvalidRequest("Multiline requests are not supported")
self.command_code = command_code
self.body = body
else:
raise InvalidRequest("Invalid request")
def __str__(self):
return f"{self.command_code} {self.body}"
__all__ = ["Request"]

View File

@ -0,0 +1,3 @@
Response = str
__all__ = ["Response"]

5
src/utils/constants.py Normal file
View File

@ -0,0 +1,5 @@
import re
IP_REGEX = r"^[0-9]{1,3}.[0-9]{1,3}.[0-9]{1,3}.[0-9]{1,3}$"
ACCOUNT_NUMBER_REGEX = r"[0-9]{9}"
MONEY_AMOUNT_MAXIMUM = (2 ^ 63) - 1

17
src/utils/ip.py Normal file
View File

@ -0,0 +1,17 @@
import socket
def get_ip():
s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
s.settimeout(0)
try:
s.connect(('1.1.1.1', 1))
ip = s.getsockname()[0]
except Exception:
ip = '127.0.0.1'
finally:
s.close()
return ip
__all__ = ["get_ip"]

View File

@ -4,7 +4,7 @@ import logging
def setup_logger(verbosity: str):
if verbosity == "DEBUG":
log_format = "[%(levelname)s] - %(name)s:%(lineno)d - %(message)s"
log_format = "[ %(levelname)s / %(processName)s ] - %(name)s:%(lineno)d - %(message)s"
else:
log_format = "[ %(levelname)s ] - %(message)s"