Major changes, most notably structure for commands is being implemented and full cleanup on exit is performed
This commit is contained in:
		
							parent
							
								
									61399c555a
								
							
						
					
					
						commit
						a70d052cbf
					
				| @ -3,6 +3,11 @@ | ||||
| # invalid value is provided | ||||
| RESPONSE_TIMEOUT=5 | ||||
| 
 | ||||
| # In seconds | ||||
| # Default 60 if user doesn't interact | ||||
| # within this timeframe | ||||
| CLIENT_IDLE_TIMEOUT=60 | ||||
| 
 | ||||
| # A valid port number | ||||
| # If not provided or invalid, defaults to 65526 | ||||
| PORT=65526 | ||||
|  | ||||
| @ -6,7 +6,7 @@ authors = ["Thastertyn <thastertyn@thastertyn.xyz>"] | ||||
| readme = "README.md" | ||||
| 
 | ||||
| [tool.poetry.dependencies] | ||||
| python = "^3.12" | ||||
| python = "^3.6" | ||||
| sqlalchemy = "^2.0.37" | ||||
| python-dotenv = "^1.0.1" | ||||
| 
 | ||||
|  | ||||
| @ -1,4 +1,5 @@ | ||||
| from bank_node.bank_node import BankNode | ||||
| 
 | ||||
| if __name__ == "__main__": | ||||
|     BankNode().start_server() | ||||
|     node = BankNode() | ||||
|     node.start() | ||||
|  | ||||
| @ -3,12 +3,14 @@ import signal | ||||
| import sys | ||||
| import logging | ||||
| 
 | ||||
| 
 | ||||
| from core.config import BankNodeConfig | ||||
| from core.exceptions import ConfigError | ||||
| from utils import setup_logger | ||||
| from database.database_manager import DatabaseManager | ||||
| 
 | ||||
| from bank_node.bank_worker import BankWorker | ||||
| 
 | ||||
| 
 | ||||
| class BankNode(): | ||||
|     def __init__(self): | ||||
|         try: | ||||
| @ -20,16 +22,24 @@ class BankNode(): | ||||
|             self.logger.info("Config is valid") | ||||
|             self.logger.debug("Starting Bank Node") | ||||
| 
 | ||||
|             self._setup_signals() | ||||
|             self.__setup_signals() | ||||
| 
 | ||||
|             self.database_manager = DatabaseManager() | ||||
|             self.socket_server = None | ||||
| 
 | ||||
|         except ConfigError as e: | ||||
|             print(e) | ||||
|             self.exit_with_error() | ||||
| 
 | ||||
|     def _setup_signals(self): | ||||
|     def __setup_signals(self): | ||||
|         self.logger.debug("Setting up exit signal hooks") | ||||
| 
 | ||||
|         # Not as clean as | ||||
|         # atexit.register(self.gracefully_exit) | ||||
|         # But it gives more control and is easier | ||||
|         # to understand in the output | ||||
|         # and looks better | ||||
| 
 | ||||
|         # Handle C standard signals | ||||
|         signal.signal(signal.SIGTERM, self.gracefully_exit) | ||||
|         signal.signal(signal.SIGINT, self.gracefully_exit) | ||||
| @ -40,27 +50,48 @@ class BankNode(): | ||||
|             signal.signal(signal.CTRL_BREAK_EVENT, self.gracefully_exit) | ||||
|             signal.signal(signal.SIGBREAK, self.gracefully_exit) | ||||
| 
 | ||||
|     def start_server(self): | ||||
|     def start(self): | ||||
|         for port in range(self.config.scan_port_start, self.config.scan_port_end + 1): | ||||
|             self.logger.debug("Trying port %d", port) | ||||
|             try: | ||||
|                 self.config.used_port = port | ||||
|                 self.__start_server(port) | ||||
|                 return | ||||
|             except socket.error as e: | ||||
|                 if e.errno == 98:  # Address is in use | ||||
|                     self.logger.info("Port %d in use, trying next port", port) | ||||
| 
 | ||||
|         self.logger.error("All ports are in use") | ||||
|         self.exit_with_error() | ||||
| 
 | ||||
|     def __start_server(self, port: int): | ||||
|         with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as socket_server: | ||||
|             socket_server.bind(("127.0.0.1", self.config.port)) | ||||
|             socket_server.bind((self.config.ip, port)) | ||||
|             socket_server.listen() | ||||
|             self.logger.info("Listening on %s:%s", "127.0.0.1", self.config.port) | ||||
|             self.socket_server = socket_server | ||||
|             self.logger.info("Listening on %s:%s", self.config.ip, port) | ||||
| 
 | ||||
|             while True: | ||||
|                 conn, addr = socket_server.accept() | ||||
|                 with conn: | ||||
|                     print(f"{addr} connected") | ||||
|                     request = conn.recv(1024).decode("utf-8") | ||||
|                     print(f"Request:\n{request}") | ||||
|                 client_socket, address = socket_server.accept() | ||||
|                 self.logger.info("%s connected", address[0]) | ||||
| 
 | ||||
|                 process = BankWorker(client_socket, address, self.config.to_dict()) | ||||
|                 process.start() | ||||
| 
 | ||||
|     def exit_with_error(self): | ||||
|         self.cleanup() | ||||
|         """Exit the application with status of 1""" | ||||
| 
 | ||||
|         sys.exit(1) | ||||
| 
 | ||||
|     def cleanup(self): | ||||
|         pass | ||||
|     def gracefully_exit(self, signum, _): | ||||
|         """Log the signal caught and exit with status 0""" | ||||
| 
 | ||||
|     def gracefully_exit(self, signum, frame): | ||||
|         signal_name = signal.Signals(signum).name | ||||
|         self.logger.warning("Caught %s. Cleaning up before exiting", signal_name) | ||||
|         self.logger.warning("Caught %s - cleaning up", signal_name) | ||||
|         self.cleanup() | ||||
|         self.logger.info("Exiting") | ||||
|         sys.exit(0) | ||||
| 
 | ||||
|     def cleanup(self): | ||||
|         self.logger.debug("Closing socket server") | ||||
|         self.socket_server.close() | ||||
|  | ||||
| @ -1,4 +0,0 @@ | ||||
| from multiprocessing import Process | ||||
| 
 | ||||
| class BankProcess(Process): | ||||
|     pass | ||||
							
								
								
									
										88
									
								
								src/bank_node/bank_worker.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										88
									
								
								src/bank_node/bank_worker.py
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,88 @@ | ||||
| import socket | ||||
| import multiprocessing | ||||
| import logging | ||||
| from typing import Tuple, Dict | ||||
| import signal | ||||
| import sys | ||||
| 
 | ||||
| from bank_protocol.command_handler import CommandHandler | ||||
| from core import Request, Response | ||||
| from core.exceptions import BankNodeError | ||||
| 
 | ||||
| 
 | ||||
| class BankWorker(multiprocessing.Process): | ||||
|     def __init__(self, client_socket: socket.socket, client_address: Tuple, config: Dict): | ||||
|         super().__init__() | ||||
| 
 | ||||
|         self.logger = logging.getLogger(__name__) | ||||
| 
 | ||||
|         self.client_socket = client_socket | ||||
|         self.client_socket.settimeout(config["client_idle_timeout"]) | ||||
|         self.client_address = client_address | ||||
| 
 | ||||
|         self.command_handler = CommandHandler(config) | ||||
|         self.config = config | ||||
| 
 | ||||
|     def __setup_signals(self): | ||||
|         self.logger.debug("Setting up exit signal hooks for worker") | ||||
| 
 | ||||
|         # Handle C standard signals | ||||
|         signal.signal(signal.SIGTERM, self.gracefully_exit_worker) | ||||
|         signal.signal(signal.SIGINT, self.gracefully_exit_worker) | ||||
| 
 | ||||
|         # Handle windows related signals | ||||
|         if sys.platform == "win32": | ||||
|             signal.signal(signal.CTRL_C_EVENT, self.gracefully_exit_worker) | ||||
|             signal.signal(signal.CTRL_BREAK_EVENT, self.gracefully_exit_worker) | ||||
|             signal.signal(signal.SIGBREAK, self.gracefully_exit_worker) | ||||
| 
 | ||||
|     def run(self): | ||||
|         self.__setup_signals() | ||||
|         with self.client_socket: | ||||
|             while True: | ||||
|                 try: | ||||
|                     raw_request = self.client_socket.recv(1024).decode("utf-8") | ||||
| 
 | ||||
|                     if not raw_request: | ||||
|                         self.logger.debug("%s disconnected", self.client_address[0]) | ||||
|                         break | ||||
| 
 | ||||
|                     request = Request(raw_request) | ||||
|                     self.logger.debug("Received request from %s - %s", self.client_address[0], request) | ||||
| 
 | ||||
|                     response: Response = self.command_handler.execute(request) + "\n\r" | ||||
| 
 | ||||
|                     self.client_socket.sendall(response.encode("utf-8")) | ||||
| 
 | ||||
|                 except socket.timeout: | ||||
|                     self.logger.debug("Client was idle for too long. Ending connection") | ||||
|                     response = "ER Idle too long\n\r" | ||||
|                     self.client_socket.sendall(response.encode("utf-8")) | ||||
|                     self.client_socket.shutdown(socket.SHUT_RDWR) | ||||
|                     self.client_socket.close() | ||||
|                     break | ||||
|                 except BankNodeError as e: | ||||
|                     response = "ER " + e.message + "\n\r" | ||||
|                     self.client_socket.sendall(response.encode("utf-8")) | ||||
|                 except socket.error as e: | ||||
|                     response = "ER Internal server error\n\r" | ||||
|                     self.logger.error(e) | ||||
|                     break | ||||
| 
 | ||||
|         self.logger.debug("Closing process for %s", self.client_address[0]) | ||||
| 
 | ||||
|     def gracefully_exit_worker(self, signum, _): | ||||
|         """Log the signal caught and exit with status 0""" | ||||
| 
 | ||||
|         signal_name = signal.Signals(signum).name | ||||
|         self.logger.warning("Worker caught %s - cleaning up", signal_name) | ||||
|         self.cleanup() | ||||
|         sys.exit(0) | ||||
| 
 | ||||
|     def cleanup(self): | ||||
|         self.logger.info("Closing connection with %s", self.client_address[0]) | ||||
|         if self.client_socket: | ||||
|             self.client_socket.close() | ||||
| 
 | ||||
| 
 | ||||
| __all__ = ["BankWorker"] | ||||
| @ -0,0 +1,42 @@ | ||||
| import logging | ||||
| from typing import Dict, Callable | ||||
| 
 | ||||
| from core import Request, Response | ||||
| from core.exceptions import BankNodeError | ||||
| 
 | ||||
| from bank_protocol.commands import (account_balance, | ||||
|                                     account_create, | ||||
|                                     account_deposit, | ||||
|                                     account_remove, | ||||
|                                     account_withdrawal, | ||||
|                                     bank_code, | ||||
|                                     bank_number_of_clients, | ||||
|                                     bank_total_amount) | ||||
| 
 | ||||
| 
 | ||||
| class CommandHandler: | ||||
|     def __init__(self, config: Dict): | ||||
|         self.logger = logging.getLogger(__name__) | ||||
|         self.config = config | ||||
| 
 | ||||
|         self.registered_commands: Dict[str, Callable] = { | ||||
|             "BC": bank_code, | ||||
|             "AC": account_create, | ||||
|             "AD": account_deposit, | ||||
|             "AW": account_withdrawal, | ||||
|             "AB": account_balance, | ||||
|             "AR": account_remove, | ||||
|             "BA": bank_total_amount, | ||||
|             "BN": bank_number_of_clients | ||||
|         } | ||||
| 
 | ||||
|     def execute(self, request: Request) -> Response: | ||||
|         if request.command_code not in self.registered_commands: | ||||
|             self.logger.warning("Unknown command %s", request.command_code) | ||||
|             raise BankNodeError(f"Unknown command {request.command_code}") | ||||
| 
 | ||||
|         command = self.registered_commands[request.command_code] | ||||
| 
 | ||||
|         response = command(request, self.config) | ||||
| 
 | ||||
|         return f"{request.command_code} {response}" | ||||
| @ -0,0 +1,19 @@ | ||||
| from .account_balance_command import * | ||||
| from .account_create_command import * | ||||
| from .account_deposit_command import * | ||||
| from .account_remove_command import * | ||||
| from .account_withdrawal_command import * | ||||
| from .bank_code_command import * | ||||
| from .bank_number_of_clients_command import * | ||||
| from .bank_total_amount_command import * | ||||
| 
 | ||||
| __all__ = [ | ||||
|     *account_balance_command.__all__, | ||||
|     *account_create_command.__all__, | ||||
|     *account_deposit_command.__all__, | ||||
|     *account_remove_command.__all__, | ||||
|     *account_withdrawal_command.__all__, | ||||
|     *bank_code_command.__all__, | ||||
|     *bank_number_of_clients_command.__all__, | ||||
|     *bank_total_amount_command.__all__ | ||||
| ] | ||||
| @ -0,0 +1,8 @@ | ||||
| from core.request import Request | ||||
| import re | ||||
| 
 | ||||
| def account_balance(request: Request): | ||||
|     pass | ||||
| 
 | ||||
| 
 | ||||
| __all__ = ["account_balance"] | ||||
| @ -0,0 +1,7 @@ | ||||
| from core.request import Request | ||||
| 
 | ||||
| def account_create(request: Request): | ||||
|     pass | ||||
| 
 | ||||
| 
 | ||||
| __all__ = ["account_create"] | ||||
| @ -0,0 +1,7 @@ | ||||
| from core.request import Request | ||||
| 
 | ||||
| def account_deposit(request: Request): | ||||
|     pass | ||||
| 
 | ||||
| 
 | ||||
| __all__ = ["account_deposit"] | ||||
| @ -0,0 +1,7 @@ | ||||
| from core.request import Request | ||||
| 
 | ||||
| def account_remove(request: Request): | ||||
|     pass | ||||
| 
 | ||||
| 
 | ||||
| __all__ = ["account_remove"] | ||||
| @ -0,0 +1,7 @@ | ||||
| from core.request import Request | ||||
| 
 | ||||
| def account_withdrawal(request: Request): | ||||
|     pass | ||||
| 
 | ||||
| 
 | ||||
| __all__ = ["account_withdrawal"] | ||||
| @ -0,0 +1,12 @@ | ||||
| from typing import Dict | ||||
| 
 | ||||
| from core import Request, Response | ||||
| from bank_protocol.exceptions import InvalidRequest | ||||
| 
 | ||||
| def bank_code(request: Request, config: Dict) -> Response: | ||||
|     if request.body is not None: | ||||
|         raise InvalidRequest("Incorrect usage") | ||||
| 
 | ||||
|     return config["ip"] | ||||
| 
 | ||||
| __all__ = ["bank_code"] | ||||
| @ -0,0 +1,7 @@ | ||||
| from core.request import Request | ||||
| 
 | ||||
| def bank_number_of_clients(request: Request): | ||||
|     pass | ||||
| 
 | ||||
| 
 | ||||
| __all__ = ["bank_number_of_clients"] | ||||
| @ -0,0 +1,7 @@ | ||||
| from core.request import Request | ||||
| 
 | ||||
| def bank_total_amount(request: Request): | ||||
|     pass | ||||
| 
 | ||||
| 
 | ||||
| __all__ = ["bank_total_amount"] | ||||
| @ -0,0 +1,7 @@ | ||||
| from .request import * | ||||
| from .response import * | ||||
| 
 | ||||
| __all__ = [ | ||||
|     *request.__all__, | ||||
|     *response.__all__ | ||||
| ] | ||||
| @ -1,9 +1,12 @@ | ||||
| import os | ||||
| import logging | ||||
| import re | ||||
| 
 | ||||
| import dotenv | ||||
| 
 | ||||
| from core.exceptions import ConfigError | ||||
| from utils.ip import get_ip | ||||
| from utils.constants import IP_REGEX | ||||
| 
 | ||||
| dotenv.load_dotenv() | ||||
| 
 | ||||
| @ -11,11 +14,17 @@ dotenv.load_dotenv() | ||||
| class BankNodeConfig: | ||||
|     def __init__(self): | ||||
|         self.logger = logging.getLogger(__name__) | ||||
|         port = os.getenv("PORT", "65526") | ||||
|         # Added for compatibility with python 3.9 instead of using | ||||
|         # logging.getLevelNamesMapping() | ||||
|         allowed_levels = ["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"] | ||||
| 
 | ||||
|         ip = os.getenv("IP", get_ip()) | ||||
|         timeout = os.getenv("RESPONSE_TIMEOUT", "5") | ||||
|         client_idle_timeout = os.getenv("CLIENT_IDLE_TIMEOUT", "60") | ||||
|         verbosity = os.getenv("VERBOSITY", "INFO") | ||||
|         scan_port_range = os.getenv("SCAN_PORT_RANGE") | ||||
| 
 | ||||
|         # Port validation | ||||
|         if not scan_port_range or scan_port_range == "": | ||||
|             self.logger.error("Scan port range not defined") | ||||
|             raise ConfigError("Scan port range not defined") | ||||
| @ -26,20 +35,37 @@ class BankNodeConfig: | ||||
|             self.logger.error("Scan port range is not in valid format") | ||||
|             raise ConfigError("Scan port range is not in valid format") | ||||
| 
 | ||||
| 
 | ||||
|         if not port.isdigit(): | ||||
|             self.logger.warning("Application port is invalid - Not a number. Falling back to default port 65526") | ||||
| 
 | ||||
|         # Timeout validation | ||||
|         if not timeout.isdigit(): | ||||
|             self.logger.warning("Request timeout is invalid - Not a number. Falling back to default timeout 5 seconds") | ||||
| 
 | ||||
|         allowed_levels = logging.getLevelNamesMapping() | ||||
|         if not client_idle_timeout.isdigit(): | ||||
|             self.logger.warning("Client idle timeout is invalid - Not a number. Falling back to default idle timeout 60 seconds") | ||||
| 
 | ||||
|         # Verbosity | ||||
|         if verbosity not in allowed_levels: | ||||
|             self.logger.warning("Verbosity %s is invalid - Doesn't exist. Falling back to default verbosity INFO", verbosity) | ||||
|             self.logger.warning("Unknown verbosity %s. Falling back to default verbosity INFO", verbosity) | ||||
| 
 | ||||
|         self.port = int(port) | ||||
|         # IP validation | ||||
|         if not re.match(IP_REGEX, ip): | ||||
|             self.logger.error("Invalid IP in configuration") | ||||
|             raise ConfigError("Invalid IP in configuration") | ||||
| 
 | ||||
|         self.used_port: int | ||||
|         self.ip = ip | ||||
|         self.timeout = int(timeout) | ||||
|         self.client_idle_timeout = int(client_idle_timeout) | ||||
|         self.verbosity = verbosity | ||||
|         self.scan_port_start = int(range_split[0]) | ||||
|         self.scan_port_end = int(range_split[1]) | ||||
| 
 | ||||
|     def to_dict(self): | ||||
|         return { | ||||
|             "used_port": self.used_port, | ||||
|             "ip": self.ip, | ||||
|             "timeout": self.timeout, | ||||
|             "client_idle_timeout": self.client_idle_timeout, | ||||
|             "verbosity": self.verbosity, | ||||
|             "scan_port_start": self.scan_port_start, | ||||
|             "scan_port_end": self.scan_port_end, | ||||
|         } | ||||
|  | ||||
| @ -1,10 +1,28 @@ | ||||
| from core.exceptions import * | ||||
| import re | ||||
| 
 | ||||
| from bank_protocol.exceptions import InvalidRequest | ||||
| 
 | ||||
| 
 | ||||
| class Request(): | ||||
|     def __init__(self, raw_request): | ||||
|         try: | ||||
|             self.command_code = raw_request[0:1] | ||||
|             self.body = raw_request[2:-1] | ||||
|         except IndexError: | ||||
|             pass | ||||
| 
 | ||||
|     def __init__(self, raw_request: str): | ||||
|         if re.match(r"^[A-Z]{2}$", raw_request): | ||||
|             self.command_code = raw_request[0:2] # Still take the first 2 characters, because of lingering crlf | ||||
|             self.body = None | ||||
|         elif re.match(r"^[A-Z]{2} .+", raw_request): | ||||
|             command_code: str = raw_request[0:2] | ||||
|             body: str = raw_request[3:-1] or "" | ||||
| 
 | ||||
|             if len(body.split("\n")) > 1: | ||||
|                 raise InvalidRequest("Multiline requests are not supported") | ||||
| 
 | ||||
|             self.command_code = command_code | ||||
|             self.body = body | ||||
|         else: | ||||
|             raise InvalidRequest("Invalid request") | ||||
| 
 | ||||
|     def __str__(self): | ||||
|         return f"{self.command_code} {self.body}" | ||||
| 
 | ||||
| 
 | ||||
| __all__ = ["Request"] | ||||
|  | ||||
| @ -0,0 +1,3 @@ | ||||
| Response = str | ||||
| 
 | ||||
| __all__ = ["Response"] | ||||
							
								
								
									
										5
									
								
								src/utils/constants.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										5
									
								
								src/utils/constants.py
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,5 @@ | ||||
| import re | ||||
| 
 | ||||
| IP_REGEX = r"^[0-9]{1,3}.[0-9]{1,3}.[0-9]{1,3}.[0-9]{1,3}$" | ||||
| ACCOUNT_NUMBER_REGEX = r"[0-9]{9}" | ||||
| MONEY_AMOUNT_MAXIMUM = (2 ^ 63) - 1 | ||||
							
								
								
									
										17
									
								
								src/utils/ip.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										17
									
								
								src/utils/ip.py
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,17 @@ | ||||
| import socket | ||||
| 
 | ||||
| 
 | ||||
| def get_ip(): | ||||
|     s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) | ||||
|     s.settimeout(0) | ||||
|     try: | ||||
|         s.connect(('1.1.1.1', 1)) | ||||
|         ip = s.getsockname()[0] | ||||
|     except Exception: | ||||
|         ip = '127.0.0.1' | ||||
|     finally: | ||||
|         s.close() | ||||
|     return ip | ||||
| 
 | ||||
| 
 | ||||
| __all__ = ["get_ip"] | ||||
| @ -4,9 +4,9 @@ import logging | ||||
| 
 | ||||
| def setup_logger(verbosity: str): | ||||
|     if verbosity == "DEBUG": | ||||
|         log_format = "[%(levelname)s] - %(name)s:%(lineno)d - %(message)s" | ||||
|         log_format = "[ %(levelname)s / %(processName)s ] - %(name)s:%(lineno)d - %(message)s" | ||||
|     else: | ||||
|         log_format = "[%(levelname)s] - %(message)s" | ||||
|         log_format = "[ %(levelname)s ] - %(message)s" | ||||
| 
 | ||||
|     logging.basicConfig(level=verbosity, format=log_format, stream=sys.stdout) | ||||
| 
 | ||||
|  | ||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user