# Listing metadata captured with this file: 124 lines, 4.3 KiB, Python.
import codecs
import contextlib
import logging
import multiprocessing
import signal
import socket
import sys
from typing import Optional, Tuple

from bank_protocol.command_handler import CommandHandler
from core import Request, Response, BankNodeConfig
from core.exceptions import BankNodeError
from utils.logger import setup_logger


class BankWorker(multiprocessing.Process):
    """Serve one client connection of the bank protocol in a separate process.

    Requests are text lines terminated by ``\\r\\n``, ``\\n`` or a bare
    ``\\r``. Each line is wrapped in a :class:`Request`, executed by a
    :class:`CommandHandler`, and the reply is written back terminated by
    ``\\r\\n``. Error replies start with ``"ER "``.
    """

    # Terminator appended to every message sent to the client.
    LINE_ENDING = "\r\n"

    # Human-readable names of the accepted request terminators (debug logs).
    _TERMINATOR_NAMES = {"\r\n": "CRLF", "\n": "LF", "\r": "CR"}

    def __init__(self, client_socket: socket.socket, client_address: Tuple, config: BankNodeConfig):
        """Store the connection; heavy setup is deferred to :meth:`run`.

        Args:
            client_socket: Connected socket of the client to serve.
            client_address: Peer address; element ``[0]`` is used in logs.
            config: Node configuration (``verbosity``, ``client_idle_timeout``).
        """
        super().__init__()

        self.client_socket = client_socket
        self.client_address = client_address
        self.config = config

        # Created in run(), i.e. inside the worker process itself.
        self.logger = None
        self.command_handler = None

    def __setup_signals(self):
        """Route termination signals to :meth:`gracefully_exit_worker`."""
        self.logger.debug("Setting up exit signal hooks for worker")

        # Handle C standard signals
        signal.signal(signal.SIGTERM, self.gracefully_exit_worker)
        signal.signal(signal.SIGINT, self.gracefully_exit_worker)

        # Handle windows related signals
        if sys.platform == "win32":
            signal.signal(signal.SIGBREAK, self.gracefully_exit_worker)

    def run(self):
        """Process entry point: set up logging/handlers, then serve the client."""
        # Logging behaves weirdly with processes on windows
        # and loses its configuration by default
        # -> Set it up again in the fresh process
        if sys.platform == "win32":
            setup_logger(self.config.verbosity)

        self.logger = logging.getLogger(__name__)
        self.command_handler = CommandHandler(self.config)

        # A timeout already puts the socket in blocking-with-timeout mode.
        # Do NOT follow this with setblocking(True): that call is equivalent
        # to settimeout(None) and would silently disable the idle timeout.
        self.client_socket.settimeout(self.config.client_idle_timeout)

        self.__setup_signals()

        with self.client_socket:
            self.serve_client()

        self.logger.debug("Closing process for %s", self.client_address[0])

    def serve_client(self):
        """Answer client requests until the connection has to end.

        The loop stops when the client disconnects, stays idle longer than
        the socket timeout, sends bytes that are not valid UTF-8, or a socket
        error occurs. A ``BankNodeError`` raised while handling one request
        is reported to the client and serving continues.
        """
        # An incremental decoder keeps a multi-byte UTF-8 sequence that was
        # split across two recv() calls instead of rejecting it as invalid.
        decoder = codecs.getincrementaldecoder("utf-8")()
        buffer = ""
        # True after a request ended with a bare "\r" at the very end of the
        # buffer: a "\n" arriving next completes that CRLF and is dropped,
        # rather than producing a spurious empty request.
        skip_leading_lf = False

        while True:
            try:
                chunk = self.client_socket.recv(1024)

                if not chunk:
                    self.logger.debug("%s disconnected", self.client_address[0])
                    break

                data = decoder.decode(chunk)
                if skip_leading_lf and data:
                    skip_leading_lf = False
                    if data.startswith("\n"):
                        data = data[1:]

                buffer += data
                self.logger.debug("Buffer updated: %r", buffer)

                # Handle every complete request already buffered, so that
                # pipelined requests do not wait for the next recv().
                while True:
                    split = self._split_request(buffer)
                    if split is None:
                        break
                    request_data, terminator, buffer = split
                    self.logger.debug("%s detected", self._TERMINATOR_NAMES[terminator])
                    skip_leading_lf = terminator == "\r" and not buffer
                    self._handle_request(request_data)

            except socket.timeout:
                self.logger.debug("Client was idle for too long. Ending connection")
                self._send_best_effort("ER Idle too long")
                # The peer may already be gone; shutdown() then raises OSError.
                with contextlib.suppress(OSError):
                    self.client_socket.shutdown(socket.SHUT_RDWR)
                self.client_socket.close()
                break

            except UnicodeDecodeError:
                self.logger.warning("Received a non utf-8 message")
                self._send_best_effort("ER Not utf-8 message")
                break

            except socket.error as e:
                self.logger.error("Socket error with %s: %s", self.client_address[0], e)
                self._send_best_effort("ER Internal server error")
                break

    @staticmethod
    def _split_request(buffer: str) -> Optional[Tuple[str, str, str]]:
        """Split the first complete request off ``buffer``.

        A request ends at the earliest ``"\\r\\n"``, ``"\\n"`` or ``"\\r"``;
        ``"\\r\\n"`` is consumed as a single terminator.

        Returns:
            ``(request, terminator, rest)``, or ``None`` if ``buffer`` holds
            no terminator yet.
        """
        ends = [pos for pos in (buffer.find("\r"), buffer.find("\n")) if pos != -1]
        if not ends:
            return None
        end = min(ends)
        terminator = "\r\n" if buffer.startswith("\r\n", end) else buffer[end]
        return buffer[:end], terminator, buffer[end + len(terminator):]

    def _handle_request(self, request_data: str) -> None:
        """Execute one request line and send the reply (or an ``ER`` reply).

        Raises:
            OSError: if sending the reply fails.
        """
        self.logger.debug("Processing request: %r", request_data)

        try:
            request = Request(request_data)
            response: Response = self.command_handler.execute(request)
        except BankNodeError as e:
            self._send("ER " + e.message)
            return

        self._send(response)
        self.logger.debug("Response sent to %s", self.client_address[0])

    def _send(self, message) -> None:
        """Send ``message`` terminated by :attr:`LINE_ENDING`, UTF-8 encoded."""
        self.client_socket.sendall((message + self.LINE_ENDING).encode("utf-8"))

    def _send_best_effort(self, message) -> None:
        """Like :meth:`_send`, but ignore socket errors.

        Used on paths that are about to end the connection anyway, where the
        peer may already be unreachable.
        """
        with contextlib.suppress(OSError):
            self._send(message)

    def gracefully_exit_worker(self, signum, _):
        """Log the signal caught and exit with status 0"""
        signal_name = signal.Signals(signum).name
        self.logger.warning("Worker caught %s - cleaning up", signal_name)
        self.cleanup()
        sys.exit(0)

    def cleanup(self):
        """Close the client socket, logging the peer address if possible."""
        if self.logger is not None:
            self.logger.info("Closing connection with %s", self.client_address[0])

        if self.client_socket:
            self.client_socket.close()


# Names exported by ``from <module> import *``: only the worker class.
__all__ = [
    "BankWorker",
]