# File: nu/src/bank_node/bank_worker.py
# Listing metadata: 124 lines, 4.3 KiB, Python.

import codecs
import logging
import multiprocessing
import signal
import socket
import sys
from typing import Tuple

from bank_protocol.command_handler import CommandHandler
from core import Request, Response, BankNodeConfig
from core.exceptions import BankNodeError
from utils.logger import setup_logger


class BankWorker(multiprocessing.Process):
    """Serve a single client connection in a dedicated process.

    Requests are text lines terminated by CRLF, LF or a lone CR. Each line is
    wrapped in a ``Request``, executed by a ``CommandHandler`` and answered
    with one CRLF-terminated line. Error replies start with ``"ER "``.
    """

    def __init__(self, client_socket: socket.socket, client_address: Tuple, config: BankNodeConfig):
        super().__init__()
        self.client_socket = client_socket
        self.client_address = client_address
        self.config = config
        # Both are created in run(), i.e. inside the worker process itself.
        self.logger = None
        self.command_handler = None

    def __setup_signals(self) -> None:
        """Install gracefully_exit_worker for SIGTERM/SIGINT (and SIGBREAK on Windows)."""
        self.logger.debug("Setting up exit signal hooks for worker")
        # Handle C standard signals
        signal.signal(signal.SIGTERM, self.gracefully_exit_worker)
        signal.signal(signal.SIGINT, self.gracefully_exit_worker)
        # Handle windows related signals
        if sys.platform == "win32":
            signal.signal(signal.SIGBREAK, self.gracefully_exit_worker)

    def run(self) -> None:
        """Process entry point: prepare logging, handler, timeout and signals, then serve."""
        # Logging behaves weirdly with processes on windows
        # and loses its configuration by default
        # -> Set it up again in the fresh process
        # NOTE(review): macOS also defaults to the "spawn" start method since
        # Python 3.8; confirm whether logging must be re-initialised there too.
        if sys.platform == "win32":
            setup_logger(self.config.verbosity)
        self.logger = logging.getLogger(__name__)
        self.command_handler = CommandHandler(self.config)
        # settimeout() alone selects the socket mode: a number enables the
        # idle timeout, None means plain blocking. Do NOT call
        # setblocking(True) afterwards: it equals settimeout(None) and would
        # silently discard the idle timeout.
        self.client_socket.settimeout(self.config.client_idle_timeout)
        self.__setup_signals()
        with self.client_socket:
            self.serve_client()
        self.logger.debug("Closing process for %s", self.client_address[0])

    def serve_client(self) -> None:
        """Read line-delimited requests from the client and answer each one.

        Incoming bytes go through an incremental UTF-8 decoder, so a
        multi-byte character split across ``recv`` calls is reassembled
        rather than rejected. Every complete line in the buffer is handled
        before blocking on the next ``recv``, so pipelined commands are
        answered immediately.

        The loop ends when the client disconnects, stays idle past the socket
        timeout, sends invalid UTF-8, or a socket error occurs.
        """
        decoder = codecs.getincrementaldecoder("utf-8")()
        buffer = ""
        # Set when the last request ended with a lone "\r" at the very end of
        # the buffer: the "\n" of a CRLF may arrive in the next packet and
        # must not be treated as an extra empty request.
        skip_leading_lf = False
        while True:
            try:
                data = self.client_socket.recv(1024)
                if not data:
                    self.logger.debug("%s disconnected", self.client_address[0])
                    break
                buffer += decoder.decode(data)
            except socket.timeout:
                self.logger.debug("Client was idle for too long. Ending connection")
                self._send_best_effort("ER Idle too long\r\n")
                try:
                    self.client_socket.shutdown(socket.SHUT_RDWR)
                except OSError:
                    pass  # peer already gone; the socket is closed below regardless
                self.client_socket.close()
                break
            except UnicodeDecodeError:
                self.logger.warning("Received a non utf-8 message")
                self._send_best_effort("ER Not utf-8 message\r\n")
                break
            except OSError as e:  # socket.error is an alias of OSError
                self.logger.error(e)
                self._send_best_effort("ER Internal server error\r\n")
                break

            if skip_leading_lf and buffer:
                if buffer.startswith("\n"):
                    buffer = buffer[1:]
                skip_leading_lf = False
            self.logger.debug("Buffer updated: %r", buffer)

            # Drain every complete line currently buffered.
            while True:
                extracted = self._extract_line(buffer)
                if extracted is None:
                    break
                request_data, buffer, terminator = extracted
                self.logger.debug("Line terminator %r detected", terminator)
                skip_leading_lf = terminator == "\r" and not buffer
                if not self._handle_request(request_data):
                    return

    @staticmethod
    def _extract_line(buffer: str):
        """Split the first complete line off *buffer*.

        CRLF, LF and a lone CR are all accepted; whichever occurs first wins.
        Returns ``(line, rest, terminator)``, or ``None`` when *buffer* does
        not yet contain a terminator.
        """
        ends = [pos for pos in (buffer.find("\r"), buffer.find("\n")) if pos != -1]
        if not ends:
            return None
        end = min(ends)
        terminator = "\r\n" if buffer.startswith("\r\n", end) else buffer[end]
        return buffer[:end], buffer[end + len(terminator):], terminator

    def _handle_request(self, request_data: str) -> bool:
        """Execute one request line and send its CRLF-terminated response.

        A ``BankNodeError`` becomes an ``"ER <message>"`` reply. Returns False
        if the response could not be sent (the connection should end).
        """
        self.logger.debug("Processing request: %r", request_data)
        try:
            request = Request(request_data)
            response: Response = self.command_handler.execute(request) + "\r\n"
        except BankNodeError as e:
            response = "ER " + e.message + "\r\n"
        try:
            self.client_socket.sendall(response.encode("utf-8"))
        except OSError as e:
            self.logger.error(e)
            return False
        self.logger.debug("Response sent to %s", self.client_address[0])
        return True

    def _send_best_effort(self, message: str) -> None:
        """Send *message*, logging instead of raising on socket errors.

        Used for error replies sent just before the connection is dropped,
        when the peer may already be gone.
        """
        try:
            self.client_socket.sendall(message.encode("utf-8"))
        except OSError as e:
            self.logger.debug("Could not send %r to %s: %s", message, self.client_address[0], e)

    def gracefully_exit_worker(self, signum, _):
        """Log the signal caught and exit with status 0"""
        signal_name = signal.Signals(signum).name
        self.logger.warning("Worker caught %s - cleaning up", signal_name)
        self.cleanup()
        sys.exit(0)

    def cleanup(self) -> None:
        """Log and close the client socket (closing twice is harmless)."""
        self.logger.info("Closing connection with %s", self.client_address[0])
        if self.client_socket is not None:
            self.client_socket.close()


# Public API of this module: only the worker process class is exported.
__all__ = [
    "BankWorker",
]