57 lines
1.9 KiB
Python
57 lines
1.9 KiB
Python
import onnxruntime
|
|
import numpy as np
|
|
import sys
|
|
|
|
from app.core.config import settings
|
|
from app.constants import ALPHABET, CONTEXT_SIZE, SINGLE_KEY_PROBABILITY
|
|
from app.utils import tokenize
|
|
|
|
import logging
|
|
|
|
class KeyboardPredictor:
    """Predict the next typed character from a rolling window of recent input.

    Characters fed through :meth:`update` are kept (lower-cased) in
    ``self.buffer``. :meth:`get_predictions` runs the ONNX model loaded from
    ``settings.MODEL_PATH`` on that buffer and returns the likely next keys.
    """

    def __init__(self):
        self.logger = logging.getLogger(__name__)

        # Rolling window of the most recently typed (lower-cased) characters.
        self.buffer = ""
        # Maximum number of characters kept in ``self.buffer``.
        self.context_size = CONTEXT_SIZE

        try:
            self.session = onnxruntime.InferenceSession(settings.MODEL_PATH)
        except onnxruntime.capi.onnxruntime_pybind11_state.InvalidProtobuf as e:
            # The predictor is useless without a model, so stop the process.
            # The exception is attached so the log shows why loading failed.
            self.logger.critical("Invalid model found. Exiting...", exc_info=e)
            sys.exit(1)

    def reset(self):
        """Clear the context buffer so the next prediction starts from scratch."""
        self.logger.info("Clearing prediction buffer")
        self.buffer = ""

    def update(self, char: str):
        """Append ``char`` (lower-cased) and keep only the newest characters.

        The buffer is trimmed by slicing to the last ``self.context_size``
        characters. This keeps the window bounded even when ``char`` contains
        several characters. Dropping a single character only when the buffer
        is exactly full would let it grow past the limit.

        Args:
            char: Newly typed text, usually a single character.
        """
        self.logger.info("Adding new char '%s'", char)

        combined = self.buffer + char.lower()
        # max(0, ...) also makes a context size of 0 yield an empty buffer.
        self.buffer = combined[max(0, len(combined) - self.context_size):]

    def __get_prediction(self) -> dict[str, float]:
        """Return the model's probability for every candidate next character.

        Returns:
            Mapping from each character of ``ALPHABET`` plus ``"?"`` to the
            softmax probability of the model's corresponding logit, as a plain
            Python float.
        """
        # NOTE(review): assumes ``tokenize`` maps the buffer to the fixed-length
        # id sequence the model expects (e.g. padding short buffers) — confirm.
        input_array = np.array([tokenize(self.buffer)], dtype=np.int64)  # batch of 1

        outputs = self.session.run(None, {"input": input_array})
        logits = outputs[0][0]  # logits for the single batch item

        # Softmax. Subtracting the max keeps np.exp from overflowing.
        exp_logits = np.exp(logits - np.max(logits))
        # tolist() yields Python floats, matching the annotated return type.
        # Numpy scalars such as float32 are rejected by the json module.
        probs = (exp_logits / exp_logits.sum()).tolist()

        labels = ALPHABET + ["?"]
        return {label: probs[i] for i, label in enumerate(labels)}

    def get_predictions(self) -> dict[str, float]:
        """Return the keys that are most likely to be pressed next.

        Only characters whose predicted probability is strictly greater than
        ``2 * SINGLE_KEY_PROBABILITY`` are kept.

        Returns:
            Mapping from character to its predicted probability.
        """
        self.logger.debug("Getting fresh predictions")

        threshold = 2 * SINGLE_KEY_PROBABILITY  # drop-out cut-off

        probs = self.__get_prediction()
        return {char: prob for char, prob in probs.items() if prob > threshold}
|