Major changes, most notably structure for commands is being implemented and full cleanup on exit is performed
This commit is contained in:
parent
61399c555a
commit
a70d052cbf
@ -3,6 +3,11 @@
|
||||
# invalid value is provided
|
||||
RESPONSE_TIMEOUT=5
|
||||
|
||||
# In seconds
|
||||
# Default 60 if user doesn't interact
|
||||
# within this timeframe
|
||||
CLIENT_IDLE_TIMEOUT=60
|
||||
|
||||
# A valid port number
|
||||
# If not provided or invalid, defaults to 65526
|
||||
PORT=65526
|
||||
|
@ -6,7 +6,7 @@ authors = ["Thastertyn <thastertyn@thastertyn.xyz>"]
|
||||
readme = "README.md"
|
||||
|
||||
[tool.poetry.dependencies]
|
||||
python = "^3.12"
|
||||
python = "^3.6"
|
||||
sqlalchemy = "^2.0.37"
|
||||
python-dotenv = "^1.0.1"
|
||||
|
||||
|
@ -1,4 +1,5 @@
|
||||
from bank_node.bank_node import BankNode
|
||||
|
||||
if __name__ == "__main__":
|
||||
BankNode().start_server()
|
||||
node = BankNode()
|
||||
node.start()
|
||||
|
@ -3,12 +3,14 @@ import signal
|
||||
import sys
|
||||
import logging
|
||||
|
||||
|
||||
from core.config import BankNodeConfig
|
||||
from core.exceptions import ConfigError
|
||||
from utils import setup_logger
|
||||
from database.database_manager import DatabaseManager
|
||||
|
||||
from bank_node.bank_worker import BankWorker
|
||||
|
||||
|
||||
class BankNode():
|
||||
def __init__(self):
|
||||
try:
|
||||
@ -20,16 +22,24 @@ class BankNode():
|
||||
self.logger.info("Config is valid")
|
||||
self.logger.debug("Starting Bank Node")
|
||||
|
||||
self._setup_signals()
|
||||
self.__setup_signals()
|
||||
|
||||
self.database_manager = DatabaseManager()
|
||||
self.socket_server = None
|
||||
|
||||
except ConfigError as e:
|
||||
print(e)
|
||||
self.exit_with_error()
|
||||
|
||||
def _setup_signals(self):
|
||||
def __setup_signals(self):
|
||||
self.logger.debug("Setting up exit signal hooks")
|
||||
|
||||
# Not as clean as
|
||||
# atexit.register(self.gracefully_exit)
|
||||
# But it gives more control and is easier
|
||||
# to understand in the output
|
||||
# and looks better
|
||||
|
||||
# Handle C standard signals
|
||||
signal.signal(signal.SIGTERM, self.gracefully_exit)
|
||||
signal.signal(signal.SIGINT, self.gracefully_exit)
|
||||
@ -40,27 +50,48 @@ class BankNode():
|
||||
signal.signal(signal.CTRL_BREAK_EVENT, self.gracefully_exit)
|
||||
signal.signal(signal.SIGBREAK, self.gracefully_exit)
|
||||
|
||||
def start_server(self):
|
||||
def start(self):
|
||||
for port in range(self.config.scan_port_start, self.config.scan_port_end + 1):
|
||||
self.logger.debug("Trying port %d", port)
|
||||
try:
|
||||
self.config.used_port = port
|
||||
self.__start_server(port)
|
||||
return
|
||||
except socket.error as e:
|
||||
if e.errno == 98: # Address is in use
|
||||
self.logger.info("Port %d in use, trying next port", port)
|
||||
|
||||
self.logger.error("All ports are in use")
|
||||
self.exit_with_error()
|
||||
|
||||
def __start_server(self, port: int):
|
||||
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as socket_server:
|
||||
socket_server.bind(("127.0.0.1", self.config.port))
|
||||
socket_server.bind((self.config.ip, port))
|
||||
socket_server.listen()
|
||||
self.logger.info("Listening on %s:%s", "127.0.0.1", self.config.port)
|
||||
self.socket_server = socket_server
|
||||
self.logger.info("Listening on %s:%s", self.config.ip, port)
|
||||
|
||||
while True:
|
||||
conn, addr = socket_server.accept()
|
||||
with conn:
|
||||
print(f"{addr} connected")
|
||||
request = conn.recv(1024).decode("utf-8")
|
||||
print(f"Request:\n{request}")
|
||||
client_socket, address = socket_server.accept()
|
||||
self.logger.info("%s connected", address[0])
|
||||
|
||||
process = BankWorker(client_socket, address, self.config.to_dict())
|
||||
process.start()
|
||||
|
||||
def exit_with_error(self):
|
||||
self.cleanup()
|
||||
"""Exit the application with status of 1"""
|
||||
|
||||
sys.exit(1)
|
||||
|
||||
def cleanup(self):
|
||||
pass
|
||||
def gracefully_exit(self, signum, _):
|
||||
"""Log the signal caught and exit with status 0"""
|
||||
|
||||
def gracefully_exit(self, signum, frame):
|
||||
signal_name = signal.Signals(signum).name
|
||||
self.logger.warning("Caught %s. Cleaning up before exiting", signal_name)
|
||||
self.logger.warning("Caught %s - cleaning up", signal_name)
|
||||
self.cleanup()
|
||||
self.logger.info("Exiting")
|
||||
sys.exit(0)
|
||||
|
||||
def cleanup(self):
|
||||
self.logger.debug("Closing socket server")
|
||||
self.socket_server.close()
|
||||
|
@ -1,4 +0,0 @@
|
||||
from multiprocessing import Process
|
||||
|
||||
class BankProcess(Process):
|
||||
pass
|
88
src/bank_node/bank_worker.py
Normal file
88
src/bank_node/bank_worker.py
Normal file
@ -0,0 +1,88 @@
|
||||
import socket
|
||||
import multiprocessing
|
||||
import logging
|
||||
from typing import Tuple, Dict
|
||||
import signal
|
||||
import sys
|
||||
|
||||
from bank_protocol.command_handler import CommandHandler
|
||||
from core import Request, Response
|
||||
from core.exceptions import BankNodeError
|
||||
|
||||
|
||||
class BankWorker(multiprocessing.Process):
|
||||
def __init__(self, client_socket: socket.socket, client_address: Tuple, config: Dict):
|
||||
super().__init__()
|
||||
|
||||
self.logger = logging.getLogger(__name__)
|
||||
|
||||
self.client_socket = client_socket
|
||||
self.client_socket.settimeout(config["client_idle_timeout"])
|
||||
self.client_address = client_address
|
||||
|
||||
self.command_handler = CommandHandler(config)
|
||||
self.config = config
|
||||
|
||||
def __setup_signals(self):
|
||||
self.logger.debug("Setting up exit signal hooks for worker")
|
||||
|
||||
# Handle C standard signals
|
||||
signal.signal(signal.SIGTERM, self.gracefully_exit_worker)
|
||||
signal.signal(signal.SIGINT, self.gracefully_exit_worker)
|
||||
|
||||
# Handle windows related signals
|
||||
if sys.platform == "win32":
|
||||
signal.signal(signal.CTRL_C_EVENT, self.gracefully_exit_worker)
|
||||
signal.signal(signal.CTRL_BREAK_EVENT, self.gracefully_exit_worker)
|
||||
signal.signal(signal.SIGBREAK, self.gracefully_exit_worker)
|
||||
|
||||
def run(self):
|
||||
self.__setup_signals()
|
||||
with self.client_socket:
|
||||
while True:
|
||||
try:
|
||||
raw_request = self.client_socket.recv(1024).decode("utf-8")
|
||||
|
||||
if not raw_request:
|
||||
self.logger.debug("%s disconnected", self.client_address[0])
|
||||
break
|
||||
|
||||
request = Request(raw_request)
|
||||
self.logger.debug("Received request from %s - %s", self.client_address[0], request)
|
||||
|
||||
response: Response = self.command_handler.execute(request) + "\n\r"
|
||||
|
||||
self.client_socket.sendall(response.encode("utf-8"))
|
||||
|
||||
except socket.timeout:
|
||||
self.logger.debug("Client was idle for too long. Ending connection")
|
||||
response = "ER Idle too long\n\r"
|
||||
self.client_socket.sendall(response.encode("utf-8"))
|
||||
self.client_socket.shutdown(socket.SHUT_RDWR)
|
||||
self.client_socket.close()
|
||||
break
|
||||
except BankNodeError as e:
|
||||
response = "ER " + e.message + "\n\r"
|
||||
self.client_socket.sendall(response.encode("utf-8"))
|
||||
except socket.error as e:
|
||||
response = "ER Internal server error\n\r"
|
||||
self.logger.error(e)
|
||||
break
|
||||
|
||||
self.logger.debug("Closing process for %s", self.client_address[0])
|
||||
|
||||
def gracefully_exit_worker(self, signum, _):
|
||||
"""Log the signal caught and exit with status 0"""
|
||||
|
||||
signal_name = signal.Signals(signum).name
|
||||
self.logger.warning("Worker caught %s - cleaning up", signal_name)
|
||||
self.cleanup()
|
||||
sys.exit(0)
|
||||
|
||||
def cleanup(self):
|
||||
self.logger.info("Closing connection with %s", self.client_address[0])
|
||||
if self.client_socket:
|
||||
self.client_socket.close()
|
||||
|
||||
|
||||
__all__ = ["BankWorker"]
|
@ -0,0 +1,42 @@
|
||||
import logging
|
||||
from typing import Dict, Callable
|
||||
|
||||
from core import Request, Response
|
||||
from core.exceptions import BankNodeError
|
||||
|
||||
from bank_protocol.commands import (account_balance,
|
||||
account_create,
|
||||
account_deposit,
|
||||
account_remove,
|
||||
account_withdrawal,
|
||||
bank_code,
|
||||
bank_number_of_clients,
|
||||
bank_total_amount)
|
||||
|
||||
|
||||
class CommandHandler:
|
||||
def __init__(self, config: Dict):
|
||||
self.logger = logging.getLogger(__name__)
|
||||
self.config = config
|
||||
|
||||
self.registered_commands: Dict[str, Callable] = {
|
||||
"BC": bank_code,
|
||||
"AC": account_create,
|
||||
"AD": account_deposit,
|
||||
"AW": account_withdrawal,
|
||||
"AB": account_balance,
|
||||
"AR": account_remove,
|
||||
"BA": bank_total_amount,
|
||||
"BN": bank_number_of_clients
|
||||
}
|
||||
|
||||
def execute(self, request: Request) -> Response:
|
||||
if request.command_code not in self.registered_commands:
|
||||
self.logger.warning("Unknown command %s", request.command_code)
|
||||
raise BankNodeError(f"Unknown command {request.command_code}")
|
||||
|
||||
command = self.registered_commands[request.command_code]
|
||||
|
||||
response = command(request, self.config)
|
||||
|
||||
return f"{request.command_code} {response}"
|
@ -0,0 +1,19 @@
|
||||
from .account_balance_command import *
|
||||
from .account_create_command import *
|
||||
from .account_deposit_command import *
|
||||
from .account_remove_command import *
|
||||
from .account_withdrawal_command import *
|
||||
from .bank_code_command import *
|
||||
from .bank_number_of_clients_command import *
|
||||
from .bank_total_amount_command import *
|
||||
|
||||
__all__ = [
|
||||
*account_balance_command.__all__,
|
||||
*account_create_command.__all__,
|
||||
*account_deposit_command.__all__,
|
||||
*account_remove_command.__all__,
|
||||
*account_withdrawal_command.__all__,
|
||||
*bank_code_command.__all__,
|
||||
*bank_number_of_clients_command.__all__,
|
||||
*bank_total_amount_command.__all__
|
||||
]
|
@ -0,0 +1,8 @@
|
||||
from core.request import Request
|
||||
import re
|
||||
|
||||
def account_balance(request: Request):
|
||||
pass
|
||||
|
||||
|
||||
__all__ = ["account_balance"]
|
@ -0,0 +1,7 @@
|
||||
from core.request import Request
|
||||
|
||||
def account_create(request: Request):
|
||||
pass
|
||||
|
||||
|
||||
__all__ = ["account_create"]
|
@ -0,0 +1,7 @@
|
||||
from core.request import Request
|
||||
|
||||
def account_deposit(request: Request):
|
||||
pass
|
||||
|
||||
|
||||
__all__ = ["account_deposit"]
|
@ -0,0 +1,7 @@
|
||||
from core.request import Request
|
||||
|
||||
def account_remove(request: Request):
|
||||
pass
|
||||
|
||||
|
||||
__all__ = ["account_remove"]
|
@ -0,0 +1,7 @@
|
||||
from core.request import Request
|
||||
|
||||
def account_withdrawal(request: Request):
|
||||
pass
|
||||
|
||||
|
||||
__all__ = ["account_withdrawal"]
|
@ -0,0 +1,12 @@
|
||||
from typing import Dict
|
||||
|
||||
from core import Request, Response
|
||||
from bank_protocol.exceptions import InvalidRequest
|
||||
|
||||
def bank_code(request: Request, config: Dict) -> Response:
|
||||
if request.body is not None:
|
||||
raise InvalidRequest("Incorrect usage")
|
||||
|
||||
return config["ip"]
|
||||
|
||||
__all__ = ["bank_code"]
|
@ -0,0 +1,7 @@
|
||||
from core.request import Request
|
||||
|
||||
def bank_number_of_clients(request: Request):
|
||||
pass
|
||||
|
||||
|
||||
__all__ = ["bank_number_of_clients"]
|
@ -0,0 +1,7 @@
|
||||
from core.request import Request
|
||||
|
||||
def bank_total_amount(request: Request):
|
||||
pass
|
||||
|
||||
|
||||
__all__ = ["bank_total_amount"]
|
@ -0,0 +1,7 @@
|
||||
from .request import *
|
||||
from .response import *
|
||||
|
||||
__all__ = [
|
||||
*request.__all__,
|
||||
*response.__all__
|
||||
]
|
@ -1,9 +1,12 @@
|
||||
import os
|
||||
import logging
|
||||
import re
|
||||
|
||||
import dotenv
|
||||
|
||||
from core.exceptions import ConfigError
|
||||
from utils.ip import get_ip
|
||||
from utils.constants import IP_REGEX
|
||||
|
||||
dotenv.load_dotenv()
|
||||
|
||||
@ -11,11 +14,17 @@ dotenv.load_dotenv()
|
||||
class BankNodeConfig:
|
||||
def __init__(self):
|
||||
self.logger = logging.getLogger(__name__)
|
||||
port = os.getenv("PORT", "65526")
|
||||
# Added for compatibility with python 3.9 instead of using
|
||||
# logging.getLevelNamesMapping()
|
||||
allowed_levels = ["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"]
|
||||
|
||||
ip = os.getenv("IP", get_ip())
|
||||
timeout = os.getenv("RESPONSE_TIMEOUT", "5")
|
||||
client_idle_timeout = os.getenv("CLIENT_IDLE_TIMEOUT", "60")
|
||||
verbosity = os.getenv("VERBOSITY", "INFO")
|
||||
scan_port_range = os.getenv("SCAN_PORT_RANGE")
|
||||
|
||||
# Port validation
|
||||
if not scan_port_range or scan_port_range == "":
|
||||
self.logger.error("Scan port range not defined")
|
||||
raise ConfigError("Scan port range not defined")
|
||||
@ -26,20 +35,37 @@ class BankNodeConfig:
|
||||
self.logger.error("Scan port range is not in valid format")
|
||||
raise ConfigError("Scan port range is not in valid format")
|
||||
|
||||
|
||||
if not port.isdigit():
|
||||
self.logger.warning("Application port is invalid - Not a number. Falling back to default port 65526")
|
||||
|
||||
# Timeout validation
|
||||
if not timeout.isdigit():
|
||||
self.logger.warning("Request timeout is invalid - Not a number. Falling back to default timeout 5 seconds")
|
||||
|
||||
allowed_levels = logging.getLevelNamesMapping()
|
||||
if not client_idle_timeout.isdigit():
|
||||
self.logger.warning("Client idle timeout is invalid - Not a number. Falling back to default idle timeout 60 seconds")
|
||||
|
||||
# Verbosity
|
||||
if verbosity not in allowed_levels:
|
||||
self.logger.warning("Verbosity %s is invalid - Doesn't exist. Falling back to default verbosity INFO", verbosity)
|
||||
self.logger.warning("Unknown verbosity %s. Falling back to default verbosity INFO", verbosity)
|
||||
|
||||
self.port = int(port)
|
||||
# IP validation
|
||||
if not re.match(IP_REGEX, ip):
|
||||
self.logger.error("Invalid IP in configuration")
|
||||
raise ConfigError("Invalid IP in configuration")
|
||||
|
||||
self.used_port: int
|
||||
self.ip = ip
|
||||
self.timeout = int(timeout)
|
||||
self.client_idle_timeout = int(client_idle_timeout)
|
||||
self.verbosity = verbosity
|
||||
self.scan_port_start = int(range_split[0])
|
||||
self.scan_port_end = int(range_split[1])
|
||||
|
||||
def to_dict(self):
|
||||
return {
|
||||
"used_port": self.used_port,
|
||||
"ip": self.ip,
|
||||
"timeout": self.timeout,
|
||||
"client_idle_timeout": self.client_idle_timeout,
|
||||
"verbosity": self.verbosity,
|
||||
"scan_port_start": self.scan_port_start,
|
||||
"scan_port_end": self.scan_port_end,
|
||||
}
|
||||
|
@ -1,10 +1,28 @@
|
||||
from core.exceptions import *
|
||||
import re
|
||||
|
||||
from bank_protocol.exceptions import InvalidRequest
|
||||
|
||||
|
||||
class Request():
|
||||
def __init__(self, raw_request):
|
||||
try:
|
||||
self.command_code = raw_request[0:1]
|
||||
self.body = raw_request[2:-1]
|
||||
except IndexError:
|
||||
pass
|
||||
|
||||
def __init__(self, raw_request: str):
|
||||
if re.match(r"^[A-Z]{2}$", raw_request):
|
||||
self.command_code = raw_request[0:2] # Still take the first 2 characters, because of lingering crlf
|
||||
self.body = None
|
||||
elif re.match(r"^[A-Z]{2} .+", raw_request):
|
||||
command_code: str = raw_request[0:2]
|
||||
body: str = raw_request[3:-1] or ""
|
||||
|
||||
if len(body.split("\n")) > 1:
|
||||
raise InvalidRequest("Multiline requests are not supported")
|
||||
|
||||
self.command_code = command_code
|
||||
self.body = body
|
||||
else:
|
||||
raise InvalidRequest("Invalid request")
|
||||
|
||||
def __str__(self):
|
||||
return f"{self.command_code} {self.body}"
|
||||
|
||||
|
||||
__all__ = ["Request"]
|
||||
|
@ -0,0 +1,3 @@
|
||||
Response = str
|
||||
|
||||
__all__ = ["Response"]
|
5
src/utils/constants.py
Normal file
5
src/utils/constants.py
Normal file
@ -0,0 +1,5 @@
|
||||
import re
|
||||
|
||||
IP_REGEX = r"^[0-9]{1,3}.[0-9]{1,3}.[0-9]{1,3}.[0-9]{1,3}$"
|
||||
ACCOUNT_NUMBER_REGEX = r"[0-9]{9}"
|
||||
MONEY_AMOUNT_MAXIMUM = (2 ^ 63) - 1
|
17
src/utils/ip.py
Normal file
17
src/utils/ip.py
Normal file
@ -0,0 +1,17 @@
|
||||
import socket
|
||||
|
||||
|
||||
def get_ip():
|
||||
s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
|
||||
s.settimeout(0)
|
||||
try:
|
||||
s.connect(('1.1.1.1', 1))
|
||||
ip = s.getsockname()[0]
|
||||
except Exception:
|
||||
ip = '127.0.0.1'
|
||||
finally:
|
||||
s.close()
|
||||
return ip
|
||||
|
||||
|
||||
__all__ = ["get_ip"]
|
@ -4,9 +4,9 @@ import logging
|
||||
|
||||
def setup_logger(verbosity: str):
|
||||
if verbosity == "DEBUG":
|
||||
log_format = "[%(levelname)s] - %(name)s:%(lineno)d - %(message)s"
|
||||
log_format = "[ %(levelname)s / %(processName)s ] - %(name)s:%(lineno)d - %(message)s"
|
||||
else:
|
||||
log_format = "[%(levelname)s] - %(message)s"
|
||||
log_format = "[ %(levelname)s ] - %(message)s"
|
||||
|
||||
logging.basicConfig(level=verbosity, format=log_format, stream=sys.stdout)
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user