import logging
import sys

import numpy as np
import onnxruntime

from app.constants import ALPHABET, CONTEXT_SIZE, SINGLE_KEY_PROBABILITY
from app.core.config import settings
from app.utils import tokenize


class KeyboardPredictor:
    """Predict the next key press from a sliding window of typed characters.

    The predictor keeps the most recent ``CONTEXT_SIZE`` lower-cased
    characters in ``self.buffer`` and feeds them to the ONNX model loaded
    from ``settings.MODEL_PATH``.
    """

    def __init__(self):
        """Create the logger, an empty buffer and the ONNX inference session.

        If the model file cannot be parsed (``InvalidProtobuf``), a critical
        message is logged and the process exits with status 1. Any other
        exception raised while creating the session propagates to the caller.
        """
        self.logger = logging.getLogger(__name__)
        self.buffer = ""
        try:
            self.session = onnxruntime.InferenceSession(settings.MODEL_PATH)
        except onnxruntime.capi.onnxruntime_pybind11_state.InvalidProtobuf:
            self.logger.critical("Invalid model found. Exiting...")
            sys.exit(1)

    def reset(self):
        """Discard all buffered context."""
        self.logger.info("Clearing prediction buffer")
        self.buffer = ""

    def update(self, char: str):
        """Append ``char`` (lower-cased) to the context buffer.

        The buffer is trimmed to the newest ``CONTEXT_SIZE`` characters after
        every append. Trimming by slice rather than dropping exactly one
        character keeps the bound intact even when the appended text is
        longer than one character (a multi-character argument, or a
        character whose lower-case form expands to several code points).
        """
        self.logger.info("Adding new char '%s'", char)
        window = self.buffer + char.lower()
        # Drop the oldest characters so at most CONTEXT_SIZE remain.
        self.buffer = window[max(len(window) - CONTEXT_SIZE, 0):]

    def __get_prediction(self) -> dict[str, float]:
        """Return the probability of every known character being typed next.

        The buffer is tokenized, run through the model, and the resulting
        logits are converted to probabilities with a softmax. Keys are the
        entries of ``ALPHABET`` followed by ``"?"``; values are plain Python
        floats.
        """
        # Batch of one sequence; the model input is expected to be
        # (1, context_size) -- assumes tokenize() yields that many ids.
        tokenized = [tokenize(self.buffer)]
        input_array = np.array(tokenized, dtype=np.int64)

        outputs = self.session.run(None, {"input": input_array})
        logits = outputs[0][0]  # first output, first batch row

        # Softmax; subtracting the max keeps exp() from overflowing.
        exp_logits = np.exp(logits - np.max(logits))
        probs = exp_logits / exp_logits.sum()

        # "?" labels the output slot after the alphabet (presumably the
        # unknown-character class -- confirm against the model's vocabulary).
        labels = [*ALPHABET, "?"]
        # float() turns NumPy scalars into built-in floats so the result
        # matches the annotation and stays JSON-serialisable.
        return {char: float(probs[i]) for i, char in enumerate(labels)}

    def get_predictions(self) -> dict[str, float]:
        """Return the keys that are most likely to be pressed next.

        Only characters whose probability is strictly greater than twice
        ``SINGLE_KEY_PROBABILITY`` are kept.
        """
        self.logger.debug("Getting fresh predictions")
        threshold = 2 * SINGLE_KEY_PROBABILITY
        return {
            char: prob
            for char, prob in self.__get_prediction().items()
            if prob > threshold
        }