Final update including testing, the final model, new data sources, updated notebooks for MLP and logistic regression, and the final keyboard code

This commit is contained in:
Thastertyn 2025-04-06 23:28:30 +02:00
parent bd602a31ea
commit f7bdc26953
34 changed files with 453042 additions and 6426 deletions

4
.gitignore vendored
View File

@ -168,4 +168,6 @@ cython_debug/
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/
data.npy
*.npy
*.pth
*.onnx

View File

@ -1 +1 @@
3.12
3.8

View File

@ -0,0 +1 @@
# Omega - Keyboard app

13
TESTING.md Normal file
View File

@ -0,0 +1,13 @@
# Fixed bugs
1. Incorrect hitbox behavior after key scaling
- Due to a forgotten z-index reset on keys, a key could appear behind another one, yet the old key would still receive the click
2. Click propagation
- An on-click event would propagate further through the app, meaning two keys could get pressed at once
3. Invalid model crashes the app
- If the model path doesn't point to a valid model, the app would crash
# Tests
1. Config from environment (or `.env`) gets parsed - Defaults are used, but are overridden when a `.env` file is present
2. Space properly clears buffers - Pressing the space key clears both the prediction and word-box buffers
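A minimal pytest sketch of how these two tests could look (hypothetical test code, assuming the `app.core.config` and `app.predictor` modules from this commit and a valid `model.onnx` on disk):

```
# Hypothetical pytest sketch; not the project's actual test suite.
import importlib

def test_config_env_overrides(monkeypatch):
    # Defaults get used, but are overwritten when the environment provides a value.
    monkeypatch.setenv("APP_NAME", "NotOmega")
    import app.core.config as config
    importlib.reload(config)  # re-create Settings() so the env var is picked up
    assert config.settings.APP_NAME == "NotOmega"

def test_space_clears_prediction_buffer():
    # The scene calls predictor.reset() when Space is pressed.
    from app.predictor import KeyboardPredictor  # requires a valid model.onnx
    predictor = KeyboardPredictor()
    predictor.update("a")
    predictor.reset()
    assert predictor.buffer == ""
```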

6
app/constants.py Normal file
View File

@ -0,0 +1,6 @@
CONTEXT_SIZE = 10
ALPHABET = list("abcdefghijklmnopqrstuvwxyz")
VOCAB_SIZE = len(ALPHABET) + 1 # +1 for unknown
UNKNOWN_IDX = VOCAB_SIZE - 1
SINGLE_KEY_PROBABILITY = 1 / 26  # chance of any single key under a uniform distribution

View File

@ -14,7 +14,7 @@ class Settings(BaseSettings):
APP_NAME: str = "Omega"
VERBOSITY: Literal["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"] = "INFO"
MODEL_PATH: FilePath = Path("./model/predictor.pth")
MODEL_PATH: FilePath = Path("model.onnx")
settings = Settings()

View File

@ -7,7 +7,7 @@ from app.ui.view import KeyboardView
from app.utils import setup_logger
class Keyboard():
class KeyboardApplication():
def __init__(self):
setup_logger()
self.logger = logging.getLogger(__name__)

56
app/predictor.py Normal file
View File

@ -0,0 +1,56 @@
import onnxruntime
import numpy as np
import sys
from app.core.config import settings
from app.constants import ALPHABET, CONTEXT_SIZE, SINGLE_KEY_PROBABILITY
from app.utils import tokenize
import logging
class KeyboardPredictor():
def __init__(self):
self.logger = logging.getLogger(__name__)
self.buffer = ""
try:
self.session = onnxruntime.InferenceSession(settings.MODEL_PATH)
except onnxruntime.capi.onnxruntime_pybind11_state.InvalidProtobuf as e:
self.logger.critical("Invalid model found. Exiting...")
sys.exit(1)
def reset(self):
self.logger.info("Clearing prediction buffer")
self.buffer = ""
def update(self, char: str):
self.logger.info("Adding new char '%s'", char)
# Remove the oldest character
# and append the current one
if len(self.buffer) == CONTEXT_SIZE:
self.buffer = self.buffer[1:]
self.buffer += char.lower()
def __get_prediction(self) -> dict[str, float]:
"""Returns all probabilities distribution over the next character."""
tokenized = [tokenize(self.buffer)]  # shape: (1, CONTEXT_SIZE)
input_array = np.array(tokenized, dtype=np.int64)
outputs = self.session.run(None, {"input": input_array})
logits = outputs[0][0] # shape: (VOCAB_SIZE,)
exp_logits = np.exp(logits - np.max(logits)) # numerical stability
probs = exp_logits / exp_logits.sum()
return {char: probs[i] for i, char in enumerate(ALPHABET + ["?"])}
def get_predictions(self) -> dict[str, float]:
"""
Returns keys that are most likely to be pressed next.
"""
self.logger.debug("Getting fresh predictions")
# Drop keys that are not at least twice as likely as uniform chance
DROPOUT_PROBABILITY = 2 * SINGLE_KEY_PROBABILITY
probs = self.__get_prediction()
probs = {char: prob for char, prob in probs.items() if prob > DROPOUT_PROBABILITY}
return probs
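A short usage sketch of `KeyboardPredictor` (hypothetical, assuming a valid `model.onnx` at the configured path):

```
# Hypothetical usage of KeyboardPredictor
predictor = KeyboardPredictor()
for ch in "hel":
    predictor.update(ch)
# Only keys with probability above 2/26 survive the dropout filter
print(predictor.get_predictions())  # e.g. {"l": 0.41, "p": 0.12, ...}
```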

54
app/ui/current_word.py Normal file
View File

@ -0,0 +1,54 @@
from PySide6.QtWidgets import QGraphicsRectItem, QGraphicsSimpleTextItem
from PySide6.QtGui import QFont, QFontMetricsF, QColor, QBrush
BACKGROUND_COLOR = QColor("white")
TEXT_SCALE_FACTOR = 0.4 # 40% of the width or height
Z_INDEX = 10
class CurrentWordBox(QGraphicsRectItem):
def __init__(self):
super().__init__()
self.word = ""
self.text = QGraphicsSimpleTextItem("", self)
self.text.setFlag(QGraphicsSimpleTextItem.ItemIgnoresTransformations, True)
self.setBrush(QBrush(BACKGROUND_COLOR))
self.setZValue(Z_INDEX)
self.width = 0
self.height = 0
def set_geometry(self, x: float, y: float, width: float, height: float):
self.width = width
self.height = height
self.setRect(0, 0, width, height)
self.setPos(x, y)
self.update_label_font()
def update_word(self, next_char: str):
self.word += next_char
self.text.setText(self.word)
self.update_label_font()
def clear(self):
self.word = ""
self.text.setText("")
self.update_label_font()
def update_label_font(self):
min_dimension = min(self.width, self.height)
font_size = min_dimension * TEXT_SCALE_FACTOR
font = QFont()
font.setPointSizeF(font_size)
self.text.setFont(font)
metrics = QFontMetricsF(font)
text_rect = metrics.boundingRect(self.word)
text_x = (self.width - text_rect.width()) / 2
text_y = (self.height - text_rect.height()) / 2
self.text.setPos(text_x, text_y)

View File

@ -1,23 +1,45 @@
import logging
from PySide6.QtWidgets import (
QGraphicsRectItem, QGraphicsSimpleTextItem, QGraphicsItem
)
from PySide6.QtGui import QBrush, QColor, QFont, QFontMetricsF
from PySide6.QtGui import QBrush, QColor, QFont, QFontMetricsF, QPen, Qt
from PySide6.QtCore import Signal, QObject, QTimer
NORMAL_COLOR = QColor("lightgray")
CLICKED_COLOR = QColor("darkgray")
TEXT_SCALE_FACTOR = 0.4
CLICK_HIGHLIGHT_DURATION_MS = 100
Z_INDEX_MULTIPLIER = 2.0
class KeyItem(QGraphicsRectItem, QObject):
clicked = Signal(str)
def __init__(self, label: str):
QObject.__init__(self)
QGraphicsRectItem.__init__(self)
self.logger = logging.getLogger(__name__)
class KeyItem(QGraphicsRectItem):
def __init__(self, label):
super().__init__()
self.label = label
self.width = 0.0
self.height = 0.0
self.scale_factor = 1.0
self._reset_timer = QTimer()
self._reset_timer.setSingleShot(True)
self._reset_timer.timeout.connect(self.__reset_color)
self.text = QGraphicsSimpleTextItem(label, self)
self.text.setFlag(QGraphicsItem.ItemIgnoresTransformations, True)
self.setBrush(QBrush(QColor("lightgray")))
self.__normal_brush = QBrush(NORMAL_COLOR)
self.__click_brush = QBrush(CLICKED_COLOR)
self.setBrush(self.__normal_brush)
self.scale_factor = 1.0
self.width = 0
self.height = 0
def set_geometry(self, x, y, width, height):
def set_geometry(self, x: float, y: float, width: float, height: float):
self.width = width
self.height = height
@ -28,25 +50,34 @@ class KeyItem(QGraphicsRectItem):
self.update_label_font()
def update_label_font(self):
# Dynamically size font
font_size = min(self.width, self.height) * 0.4
min_dimension = min(self.width, self.height)
font_size = min_dimension * TEXT_SCALE_FACTOR
font = QFont()
font.setPointSizeF(font_size)
self.text.setFont(font)
# Use font metrics for accurate bounding
metrics = QFontMetricsF(font)
text_rect = metrics.boundingRect(self.label)
text_x = (self.width - text_rect.width()) / 2
text_y = (self.height - text_rect.height()) / 2 # + metrics.ascent() - text_rect.height()
text_y = (self.height - text_rect.height()) / 2
self.text.setPos(text_x, text_y)
def set_scale_factor(self, scale):
self.scale_factor = scale
self.setZValue(scale)
def set_scale_factor(self, scale: float):
scaled_z = scale * Z_INDEX_MULTIPLIER
self.scale_factor = scaled_z
self.setZValue(scaled_z)
self.setScale(scale)
def mousePressEvent(self, q_mouse_event):
print(self.label, "was clicked")
self.logger.info("%s was clicked", self.label)
self.clicked.emit(self.label)
self.__onclick_color()
def __onclick_color(self):
self.setBrush(self.__click_brush)
self._reset_timer.start(CLICK_HIGHLIGHT_DURATION_MS)
def __reset_color(self):
self.setBrush(self.__normal_brush)

70
app/ui/keyboard.py Normal file
View File

@ -0,0 +1,70 @@
from PySide6.QtCore import QSize, QRectF
from app.ui.key import KeyItem
class KeyboardLayout():
def __init__(self, scene):
self.scene = scene
self.keys = {}
self.letter_rows = [
"QWERTZUIOP",
"ASDFGHJKL",
"YXCVBNM",
]
self.space_key = KeyItem("Space")
self.space_key.clicked.connect(lambda: self.on_key_clicked("Space"))
self.scene.addItem(self.space_key)
self.create_keys()
# callback to be set by scene
self.key_pressed_callback = None
def create_keys(self):
for row in self.letter_rows:
for char in row:
key = KeyItem(char)
key.clicked.connect(lambda c=char: self.on_key_clicked(c))
self.scene.addItem(key)
self.keys[char] = key
def layout_keys(self, view_size: QSize, top_offset: float = 0.0):
view_width = view_size.width()
view_height = view_size.height()
padding = 10.0
spacing = 6.0
rows = self.letter_rows + [" "]
row_count = len(rows)
max_keys_in_row = max(len(row) for row in rows)
total_spacing_x = (max_keys_in_row - 1) * spacing
total_spacing_y = (row_count - 1) * spacing
available_width = view_width - 2 * padding - total_spacing_x
available_height = view_height - top_offset - 2 * padding - total_spacing_y
key_width = available_width / max_keys_in_row
key_height = available_height / row_count
y = top_offset + padding
for row in self.letter_rows:
x = padding
for char in row:
key = self.keys[char]
key.set_geometry(x, y, key_width, key_height)
x += key_width + spacing
y += key_height + spacing
space_width = key_width * 7 + spacing * 6
self.space_key.set_geometry(padding, y, space_width, key_height)
def set_scale_factors(self, predictions: dict):
for key in self.keys.values():
key.set_scale_factor(1.0)
for key, probability in predictions.items():
if key.upper() in self.keys:
self.keys[key.upper()].set_scale_factor(1 + probability)
def on_key_clicked(self, label: str):
if self.key_pressed_callback:
self.key_pressed_callback(label)

View File

@ -1,79 +1,52 @@
from PySide6.QtWidgets import (
QGraphicsScene
)
from PySide6.QtGui import QBrush, QColor
from PySide6.QtCore import QTimer, QRectF, QSize
import random
from PySide6.QtWidgets import QGraphicsScene
from PySide6.QtCore import QRectF, QSize
from app.predictor import KeyboardPredictor
from app.ui.current_word import CurrentWordBox
from app.ui.keyboard import KeyboardLayout
WORD_BOX_HEIGHT_MODIFIER = 0.15
WORD_BOX_PADDING_ALL_SIDES = 10
WORD_BOX_MARGIN_BOTTOM = 20
from app.ui.key import KeyItem
class KeyboardScene(QGraphicsScene):
def __init__(self):
super().__init__()
self.keys = {}
self.letter_rows = [
"QWERTYUIOP",
"ASDFGHJKL",
"ZXCVBNM",
]
self.space_key = KeyItem("Space")
self.addItem(self.space_key)
self.timer = QTimer()
self.timer.timeout.connect(self.simulate_prediction)
self.timer.start(2000)
self.create_keys()
self.word_box = CurrentWordBox()
self.addItem(self.word_box)
def create_keys(self):
for row in self.letter_rows:
for char in row:
key = KeyItem(char)
self.addItem(key)
self.keys[char] = key
self.predictor = KeyboardPredictor()
self.keyboard = KeyboardLayout(self)
self.keyboard.key_pressed_callback = self.on_key_clicked
def layout_keys(self, view_size: QSize):
view_width = view_size.width()
view_height = view_size.height()
padding = 10.0
spacing = 6.0
max_scale = 2.0
word_box_height = view_height * WORD_BOX_HEIGHT_MODIFIER
word_box_x = WORD_BOX_PADDING_ALL_SIDES
word_box_y = WORD_BOX_PADDING_ALL_SIDES
word_box_width = view_width - WORD_BOX_MARGIN_BOTTOM
word_box_height_adjusted = word_box_height - WORD_BOX_MARGIN_BOTTOM
rows = self.letter_rows + [" "]
row_count = len(rows)
max_keys_in_row = max(len(row) for row in rows)
total_spacing_x = (max_keys_in_row - 1) * spacing
total_spacing_y = (row_count - 1) * spacing
available_width = view_width - 2 * padding - total_spacing_x
available_height = view_height - 2 * padding - total_spacing_y
key_width = available_width / max_keys_in_row
key_height = available_height / row_count
y = padding
for row in self.letter_rows:
x = padding
for char in row:
key = self.keys[char]
key.set_geometry(x, y, key_width, key_height)
x += key_width + spacing
y += key_height + spacing
# Space key layout
space_width = key_width * 7 + spacing * 6
self.space_key.set_geometry(padding, y, space_width, key_height)
self.word_box.set_geometry(
word_box_x, word_box_y, word_box_width, word_box_height_adjusted
)
self.keyboard.layout_keys(view_size, top_offset=word_box_height)
self.setSceneRect(QRectF(0, 0, view_width, view_height))
def simulate_prediction(self):
most_likely = random.choice(list(self.keys.keys()))
print(f"[Prediction] Most likely: {most_likely}")
def on_key_clicked(self, label: str):
if label == "Space":
self.word_box.clear()
self.predictor.reset()
self.keyboard.set_scale_factors({})
return
for char, key in self.keys.items():
if char == most_likely:
key.set_scale_factor(1.8)
key.setBrush(QBrush(QColor("orange")))
else:
key.set_scale_factor(1.0)
key.setBrush(QBrush(QColor("lightgray")))
self.word_box.update_word(label)
self.predictor.update(label)
predictions = self.predictor.get_predictions()
self.keyboard.set_scale_factors(predictions)

View File

@ -5,6 +5,7 @@ from PySide6.QtGui import QPainter
from PySide6.QtCore import Qt
from .scene import KeyboardScene
from app.core.config import settings
class KeyboardView(QGraphicsView):
def __init__(self):
@ -12,9 +13,9 @@ class KeyboardView(QGraphicsView):
self.scene = KeyboardScene()
self.setScene(self.scene)
self.setRenderHint(QPainter.Antialiasing)
self.setWindowTitle("Dynamic Keyboard")
self.setWindowTitle(settings.APP_NAME)
self.setAlignment(Qt.AlignLeft | Qt.AlignTop)
self.setMinimumSize(600, 200) # Sensible default
self.setMinimumSize(600, 300)
def resizeEvent(self, event):
super().resizeEvent(event)

View File

@ -3,6 +3,11 @@ import logging
from app.core.config import settings
from app.constants import CONTEXT_SIZE, UNKNOWN_IDX, ALPHABET
char_to_index = {char: idx for idx, char in enumerate(ALPHABET)}
def setup_logger():
logger = logging.getLogger()
@ -51,3 +56,13 @@ if sys.platform == "win32":
0, 0, 0, 0,
win32con.SWP_NOMOVE | win32con.SWP_NOSIZE | win32con.SWP_NOACTIVATE
)
def tokenize(text: str) -> list[int]:
"""Convert last CONTEXT_SIZE chars to integer indices."""
text = text.lower()[-CONTEXT_SIZE:] # trim to context length
padded = [' '] * (CONTEXT_SIZE - len(text)) + list(text)
return [
char_to_index.get(c, UNKNOWN_IDX)
for c in padded
]
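A quick worked example of `tokenize` (using `ALPHABET` and `CONTEXT_SIZE = 10` from `app/constants.py`; the space padding falls back to `UNKNOWN_IDX`):

```
# tokenize("Hi") lowercases, pads to CONTEXT_SIZE, and maps chars to indices
tokenize("Hi")  # -> [26, 26, 26, 26, 26, 26, 26, 26, 7, 8]
```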

View File

@ -1,7 +1,7 @@
import sys
from app.keyboard import Keyboard
from app.keyboard import KeyboardApplication
if __name__ == "__main__":
keyboard = Keyboard()
keyboard = KeyboardApplication()
sys.exit(keyboard.run())

BIN
model.onnx Normal file

Binary file not shown.

View File

@ -1,40 +1,55 @@
# omega
# Model and training
## Documentation
First I gathered a couple of long text sources, like the GNU GPL license, Wikipedia articles, or even a book.
This is the section of the code related to training and modifying the model used by this app
Those were transformed into a large text file [see all_words.txt](data/all_words.txt) using the following command
## Training
There are 2 notebooks related to training
- [Multilayer Perceptron](./mlp.ipynb)
- [Logistic Regression model](./logistic.ipynb)
```
grep -o "[[:alpha:]]\{1,\}" "path_to_individual_source.txt" | tr '[:upper:]' '[:lower:]'
```
The MLP proved to be far more accurate, so it is the one used.
This simply finds words at least one character long and normalizes them by converting everything to lowercase.
## Data
See [Sources](#Sources) for all data used.
The data includes news articles, scientific articles, a couple of books, and Wikipedia articles.
The text was extracted by simply copying it (Ctrl+C, Ctrl+V), including some unwanted garbage like numbers, Wikipedia links, etc.
The first processing step uses the included scripts, mainly [`words.sh`](./words.sh), which extracts only alphabetic strings and places each on a new line.
The second step cleans the data by allowing only the 26 English alphabet characters. This is done by [`clear.sh`](./clear.sh).
The third and last step turns the data into a NumPy array for space efficiency (instead of CSV). This is done by [`transform.py`](./transform.py), after which the model can be trained on this data.
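For illustration, a minimal sketch of that third step with a hypothetical helper name (the real implementation is in [`transform.py`](./transform.py)): every word is expanded into rows of `CONTEXT_SIZE` context indices plus the current character as the target.

```
# Minimal sketch of the windowing step (hypothetical helper; see transform.py)
import numpy as np

ALPHABET = list("abcdefghijklmnopqrstuvwxyz")
CONTEXT_SIZE = 10
UNKNOWN_IDX = len(ALPHABET)  # index 26: padding / out-of-vocabulary
char_to_index = {ch: i for i, ch in enumerate(ALPHABET)}

def word_to_rows(word):
    """One row per character: CONTEXT_SIZE context indices + the target index."""
    rows = []
    for i, current in enumerate(word):
        context = word[max(0, i - CONTEXT_SIZE):i]
        padded = [" "] * (CONTEXT_SIZE - len(context)) + list(context)
        indices = [char_to_index.get(c, UNKNOWN_IDX) for c in padded]
        rows.append(indices + [char_to_index.get(current, UNKNOWN_IDX)])
    return rows

rows = [row for word in ["dog", "cat"] for row in word_to_rows(word)]
np.save("data.npy", np.array(rows, dtype=np.int64))  # shape: (num_samples, 11)
```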
## Structure
The current model is a **character-level** predictor that uses the previous 10 characters to predict the next one.
It was trained using the processed word list.
The dataset can be found in [`data/all_cleaned_words.txt`](data/all_cleaned_words.txt).
- **Model type**: **`Multilayer Perceptron`**
- **Input shape**: **`10 * 16`** - the 10 previous characters as 16-dimensional embeddings
- **Output shape**: **`27`** - a probability for each of the 26 letters of the English alphabet, plus one for unknown characters
- **Dataset**: **`220k`** words from various types of sources, mainly books
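The trained network is shipped as `model.onnx` and loaded with onnxruntime in `app/predictor.py`. A hedged sketch of how the export could look (the actual export lives in [`mlp.ipynb`](./mlp.ipynb), whose diff is suppressed below; the input name `input` is taken from the app code, the rest is an assumption):

```
# Sketch of the PyTorch -> ONNX export (assumed; see mlp.ipynb for the real one)
import torch

model.eval()  # the trained MLP from the notebook
dummy_input = torch.zeros(1, CONTEXT_SIZE, dtype=torch.long)  # (batch, context)
torch.onnx.export(
    model,
    dummy_input,
    "model.onnx",
    input_names=["input"],     # app/predictor.py runs the session with {"input": ...}
    output_names=["logits"],
    dynamic_axes={"input": {0: "batch"}},
)
```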
To get as much accuracy as possible, I calculated the average word length (5.819) and went with a character history of 5 letters. This is the norm for now and can easily be omitted from the data if it becomes excessive.
```
awk '{ total += length; count++ } END { if (count > 0) print total / count }' 1000_words.txt
```
## Sources
1. Generic news articles
- https://edition.cnn.com/2025/03/20/middleeast/ronen-bar-shin-bet-israel-vote-dismiss-intl-latam/index.html
- https://edition.cnn.com/2025/03/21/europe/conor-mcgregor-ireland-president-election-intl-hnk/index.html
2. Wikipedia articles
- https://simple.wikipedia.org/wiki/Dog
- https://en.wikipedia.org/wiki/Car
3. Scientific articles ([Kurzgesagt](https://www.youtube.com/@kurzgesagt/videos))
- https://www.youtube.com/watch?v=dCiMUWw1BBc&t=766s
- https://news.umich.edu/astronomers-find-surprising-ice-world-in-the-habitable-zone-with-jwst-data/
- https://www.youtube.com/watch?v=VD6xJq8NguY
- https://www.pnas.org/doi/10.1073/pnas.1711842115
4. License text
- https://www.gnu.org/licenses/gpl-3.0.en.html
- https://www.gnu.org/licenses/old-licenses/gpl-2.0.en.html
5. Books
- https://ia902902.us.archive.org/19/items/diaryofawimpykidbookseriesbyjeffkinney_202004/Diary%20of%20a%20wimpy%20kid%20book02%20rodrick%20rules.pdf
- https://drive.google.com/file/d/1b1Etdxb1cNU3PvDBQnYh0bCAAfssMi8b/view
- https://dhspriory.org/kenny/PhilTexts/Camus/Myth%20of%20Sisyphus-.pdf
- https://www.matermiddlehigh.org/ourpages/auto/2012/11/16/50246772/Beloved.pdf

3
model/clear.sh Executable file
View File

@ -0,0 +1,3 @@
#!/bin/bash
grep -o "[[:alpha:]]\{1,\}" "$1" | tr '[:upper:]' '[:lower:]' | grep -Eo '^([a-z]{2,}|a|i)$'

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

63913
model/data/all_words.txt.bak Normal file

File diff suppressed because it is too large Load Diff

6173
model/data/beloved.txt Normal file

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

339
model/logistic.ipynb Normal file
View File

@ -0,0 +1,339 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# 1. Import and data processing"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Import all necessary libraries"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"import torch\n",
"import torch.nn as nn\n",
"from torch.utils.data import DataLoader, TensorDataset, random_split"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Load the training data"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"data = np.load(\"./data.npy\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Define constants that describe the data and model"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"CONTEXT_SIZE = 10\n",
"ALPHABET = list(\"abcdefghijklmnopqrstuvwxyz\")\n",
"ALPHABET_SIZE = len(ALPHABET)\n",
"TRAINING_DATA_SIZE = 0.9\n",
"\n",
"# +1 is for unknown characters\n",
"VOCAB_SIZE = ALPHABET_SIZE + 1\n",
"\n",
"EMBEDDING_DIM = 10\n",
"\n",
"INPUT_SEQ_LEN = CONTEXT_SIZE\n",
"OUTPUT_SIZE = VOCAB_SIZE\n",
"\n",
"BATCH_SIZE = 2048\n",
"\n",
"EPOCHS = 30\n",
"LEARNING_RATE = 1e-3"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Process the data"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"# Input: embeddings of the previous 10 letters\n",
"# shape: (num_samples, CONTEXT_SIZE)\n",
"X = data[:, :CONTEXT_SIZE]\n",
"\n",
"# Target: current letter index\n",
"# shape: (num_samples,)\n",
"y = data[:, CONTEXT_SIZE]\n",
"\n",
"# Torch dataset (important: use long/int64 for indices)\n",
"X_tensor = torch.tensor(X, dtype=torch.long) # for nn.Embedding\n",
"y_tensor = torch.tensor(y, dtype=torch.long) # for classification target\n",
"\n",
"dataset = TensorDataset(X_tensor, y_tensor)\n",
"\n",
"train_len = int(TRAINING_DATA_SIZE * len(dataset))\n",
"train_set, test_set = random_split(dataset, [train_len, len(dataset) - train_len])\n",
"\n",
"train_loader = DataLoader(train_set, batch_size=BATCH_SIZE, shuffle=True)\n",
"test_loader = DataLoader(test_set, batch_size=BATCH_SIZE)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# 2. Model"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"class Logistic(nn.Module):\n",
" def __init__(self, *, embedding_count: int, embedding_dimension_size: int, context_size: int, output_shape: int):\n",
" super().__init__()\n",
" self.embedding = nn.Embedding(num_embeddings=embedding_count, embedding_dim=embedding_dimension_size)\n",
" self.linear = nn.Linear(context_size * embedding_dimension_size, output_shape)\n",
"\n",
" def forward(self, x):\n",
" embedded = self.embedding(x) # (BATCH_SIZE, CONTEXT_SIZE, EMBEDDING_DIM)\n",
" flattened = embedded.view(x.size(0), -1) # (BATCH_SIZE, CONTEXT_SIZE * EMBEDDING_DIM)\n",
" return self.linear(flattened) # (BATCH_SIZE, OUTPUT_SIZE)"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Using device: cpu\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"/home/thastertyn/Code/Skola/4-rocnik/programove-vybaveni/omega/.venv/lib/python3.12/site-packages/torch/cuda/__init__.py:129: UserWarning: CUDA initialization: CUDA unknown error - this may be due to an incorrectly set up environment, e.g. changing env variable CUDA_VISIBLE_DEVICES after program start. Setting the available devices to be zero. (Triggered internally at /pytorch/c10/cuda/CUDAFunctions.cpp:109.)\n",
" return torch._C._cuda_getDeviceCount() > 0\n"
]
}
],
"source": [
"device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
"print(f\"Using device: {device}\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# 3. Training"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Create fresh instance of the model"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
"model = Logistic(\n",
" embedding_count=VOCAB_SIZE, # e.g., 27 for az + unknown\n",
" embedding_dimension_size=EMBEDDING_DIM, # e.g., 10\n",
" context_size=CONTEXT_SIZE, # e.g., 10\n",
" output_shape=OUTPUT_SIZE # e.g., 27 (next character)\n",
").to(device)"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [],
"source": [
"criterion = nn.CrossEntropyLoss()\n",
"optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[Epoch 1] - Loss: 2.5968 | Accuracy: 23.06%\n",
"[Epoch 2] - Loss: 2.3218 | Accuracy: 30.02%\n",
"[Epoch 3] - Loss: 2.2600 | Accuracy: 31.25%\n",
"[Epoch 4] - Loss: 2.2325 | Accuracy: 31.55%\n",
"[Epoch 5] - Loss: 2.2171 | Accuracy: 31.75%\n",
"[Epoch 6] - Loss: 2.2076 | Accuracy: 31.98%\n",
"[Epoch 7] - Loss: 2.2006 | Accuracy: 32.22%\n",
"[Epoch 8] - Loss: 2.1962 | Accuracy: 32.36%\n",
"[Epoch 9] - Loss: 2.1925 | Accuracy: 32.42%\n",
"[Epoch 10] - Loss: 2.1900 | Accuracy: 32.48%\n",
"[Epoch 11] - Loss: 2.1876 | Accuracy: 32.54%\n",
"[Epoch 12] - Loss: 2.1859 | Accuracy: 32.64%\n",
"[Epoch 13] - Loss: 2.1847 | Accuracy: 32.65%\n",
"[Epoch 14] - Loss: 2.1833 | Accuracy: 32.76%\n",
"[Epoch 15] - Loss: 2.1821 | Accuracy: 32.75%\n",
"[Epoch 16] - Loss: 2.1813 | Accuracy: 32.74%\n",
"[Epoch 17] - Loss: 2.1806 | Accuracy: 32.84%\n",
"[Epoch 18] - Loss: 2.1799 | Accuracy: 32.81%\n",
"[Epoch 19] - Loss: 2.1792 | Accuracy: 32.80%\n",
"[Epoch 20] - Loss: 2.1786 | Accuracy: 32.81%\n",
"[Epoch 21] - Loss: 2.1780 | Accuracy: 32.77%\n",
"[Epoch 22] - Loss: 2.1776 | Accuracy: 32.85%\n",
"[Epoch 23] - Loss: 2.1770 | Accuracy: 32.81%\n",
"[Epoch 24] - Loss: 2.1767 | Accuracy: 32.81%\n",
"[Epoch 25] - Loss: 2.1764 | Accuracy: 32.81%\n",
"[Epoch 26] - Loss: 2.1757 | Accuracy: 32.80%\n",
"[Epoch 27] - Loss: 2.1755 | Accuracy: 32.81%\n",
"[Epoch 28] - Loss: 2.1751 | Accuracy: 32.79%\n",
"[Epoch 29] - Loss: 2.1748 | Accuracy: 32.82%\n",
"[Epoch 30] - Loss: 2.1744 | Accuracy: 32.80%\n"
]
}
],
"source": [
"for epoch in range(EPOCHS):\n",
" model.train()\n",
" total_loss = 0\n",
" correct = 0\n",
" total = 0\n",
"\n",
" for batch_X, batch_y in train_loader:\n",
" batch_X, batch_y = batch_X.to(device), batch_y.to(device)\n",
"\n",
" optimizer.zero_grad()\n",
" logits = model(batch_X) # shape: (BATCH_SIZE, OUTPUT_SIZE)\n",
" loss = criterion(logits, batch_y)\n",
" loss.backward()\n",
" optimizer.step()\n",
"\n",
" total_loss += loss.item() * batch_X.size(0)\n",
"\n",
" # Compute accuracy\n",
" preds = torch.argmax(logits, dim=1)\n",
" correct += (preds == batch_y).sum().item()\n",
" total += batch_X.size(0)\n",
"\n",
" avg_loss = total_loss / total\n",
" accuracy = correct / total * 100\n",
" print(f\"[Epoch {epoch+1}] - Loss: {avg_loss:.4f} | Accuracy: {accuracy:.2f}%\")\n"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Top 1 prediction accuracy: 32.45%\n",
"Top 3 prediction accuracy: 58.55%\n",
"Top 5 prediction accuracy: 72.66%\n"
]
}
],
"source": [
"model.eval()\n",
"correct_top1 = 0\n",
"correct_top3 = 0\n",
"correct_top5 = 0\n",
"total = 0\n",
"\n",
"with torch.no_grad():\n",
" for batch_X, batch_y in test_loader:\n",
" batch_X, batch_y = batch_X.to(device), batch_y.to(device)\n",
" outputs = model(batch_X)\n",
"\n",
" _, top_preds = outputs.topk(5, dim=1)\n",
"\n",
" for true, top5 in zip(batch_y, top_preds):\n",
" total += 1\n",
" if true == top5[0]:\n",
" correct_top1 += 1\n",
" if true in top5[:3]:\n",
" correct_top3 += 1\n",
" if true in top5:\n",
" correct_top5 += 1\n",
"\n",
"top1_acc = correct_top1 / total\n",
"top3_acc = correct_top3 / total\n",
"top5_acc = correct_top5 / total\n",
"\n",
"print(f\"Top 1 prediction accuracy: {(top1_acc * 100):.2f}%\")\n",
"print(f\"Top 3 prediction accuracy: {(top3_acc * 100):.2f}%\")\n",
"print(f\"Top 5 prediction accuracy: {(top5_acc * 100):.2f}%\")"
]
}
],
"metadata": {
"kernelspec": {
"display_name": ".venv",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.3"
}
},
"nbformat": 4,
"nbformat_minor": 2
}

601
model/mlp.ipynb Normal file

File diff suppressed because one or more lines are too long

Binary file not shown.

View File

@ -1,475 +0,0 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Omega\n",
"Prediction of next key to be pressed using Multilayer Perceptron"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 1. Import and load data"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Import all required modules"
]
},
{
"cell_type": "code",
"execution_count": 34,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"import torch\n",
"import torch.nn as nn\n",
"from torch.utils.data import DataLoader, TensorDataset, random_split"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Load data"
]
},
{
"cell_type": "code",
"execution_count": 35,
"metadata": {},
"outputs": [],
"source": [
"data = np.load(\"./data.npy\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Define contstants describing the dataset and other useful information"
]
},
{
"cell_type": "code",
"execution_count": 36,
"metadata": {},
"outputs": [],
"source": [
"CONTEXT_SIZE = 10\n",
"ALPHABET = list(\"abcdefghijklmnopqrstuvwxyz\")\n",
"ALPHABET_SIZE = len(ALPHABET)\n",
"TRAINING_DATA_SIZE = 0.9\n",
"\n",
"VOCAB_SIZE = ALPHABET_SIZE + 1 # 26 letters + 1 for unknown\n",
"EMBEDDING_DIM = 16\n",
"\n",
"INPUT_SEQ_LEN = CONTEXT_SIZE\n",
"OUTPUT_SIZE = VOCAB_SIZE"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Define and split data"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Define input and output columns"
]
},
{
"cell_type": "code",
"execution_count": 37,
"metadata": {},
"outputs": [],
"source": [
"X = data[:, :CONTEXT_SIZE] # shape: (num_samples, CONTEXT_SIZE)\n",
"\n",
"# Target: current letter index\n",
"y = data[:, CONTEXT_SIZE] # shape: (num_samples,)\n",
"\n",
"# Torch dataset (important: use long/int64 for indices)\n",
"X_tensor = torch.tensor(X, dtype=torch.long) # for nn.Embedding\n",
"y_tensor = torch.tensor(y, dtype=torch.long) # for classification target\n",
"\n",
"dataset = TensorDataset(X_tensor, y_tensor)"
]
},
{
"cell_type": "code",
"execution_count": 38,
"metadata": {},
"outputs": [],
"source": [
"train_len = int(TRAINING_DATA_SIZE * len(dataset))\n",
"train_set, test_set = random_split(dataset, [train_len, len(dataset) - train_len])"
]
},
{
"cell_type": "code",
"execution_count": 39,
"metadata": {},
"outputs": [],
"source": [
"train_loader = DataLoader(train_set, batch_size=1024, shuffle=True)\n",
"test_loader = DataLoader(test_set, batch_size=1024)"
]
},
{
"cell_type": "code",
"execution_count": 40,
"metadata": {},
"outputs": [],
"source": [
"learning_rates = [1e-4, 5e-4, 1e-3, 5e-3, 1e-2, 5e-2]\n",
"activation_layers = [nn.ReLU, nn.GELU]"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Model and training"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"To find the best model for MLP, combinations of hyperparams are defined. \n",
"This includes **activation layers** and **learning rates**"
]
},
{
"cell_type": "code",
"execution_count": 32,
"metadata": {},
"outputs": [],
"source": [
"from itertools import product\n",
"all_activation_combinations = list(product(activation_layers, repeat=len(activation_layers)))"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": 66,
"metadata": {},
"outputs": [],
"source": [
"class MLP(nn.Module):\n",
" def __init__(self, activation_layers: list):\n",
" super().__init__()\n",
" self.net = nn.Sequential(\n",
" nn.Embedding(num_embeddings=VOCAB_SIZE, embedding_dim=EMBEDDING_DIM),\n",
" nn.Flatten(),\n",
" nn.Linear(CONTEXT_SIZE * EMBEDDING_DIM, 256),\n",
" activation_layers[0](),\n",
" nn.Linear(256, 128),\n",
" activation_layers[1](),\n",
" nn.Linear(128, OUTPUT_SIZE)\n",
" )\n",
"\n",
" def forward(self, x):\n",
" return self.net(x)"
]
},
{
"cell_type": "code",
"execution_count": 42,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Using device: cuda\n"
]
}
],
"source": [
"device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
"print(f\"Using device: {device}\")"
]
},
{
"cell_type": "code",
"execution_count": 55,
"metadata": {},
"outputs": [],
"source": [
"# model = MLP().to(device)\n",
"model = None"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Test all the activation_layer combinations"
]
},
{
"cell_type": "code",
"execution_count": 65,
"metadata": {},
"outputs": [],
"source": [
"\n",
"criterion = nn.CrossEntropyLoss()"
]
},
{
"cell_type": "code",
"execution_count": 71,
"metadata": {},
"outputs": [],
"source": [
"def train_model(model, optimizer):\n",
" for epoch in range(30):\n",
" model.train()\n",
" total_loss = 0\n",
" for batch_X, batch_y in train_loader:\n",
" batch_X, batch_y = batch_X.to(device), batch_y.to(device)\n",
" optimizer.zero_grad()\n",
" output = model(batch_X)\n",
" loss = criterion(output, batch_y)\n",
" loss.backward()\n",
" optimizer.step()\n",
" total_loss += loss.item()\n",
" # print(f\"Epoch {epoch+1}, Loss: {total_loss:.4f}\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Testing model"
]
},
{
"cell_type": "code",
"execution_count": 70,
"metadata": {},
"outputs": [],
"source": [
"def test_model(model) -> tuple[float]:\n",
" model.eval()\n",
" correct_top1 = 0\n",
" correct_top3 = 0\n",
" correct_top5 = 0\n",
" total = 0\n",
"\n",
" with torch.no_grad():\n",
" for batch_X, batch_y in test_loader:\n",
" batch_X, batch_y = batch_X.to(device), batch_y.to(device)\n",
" outputs = model(batch_X)\n",
"\n",
" _, top_preds = outputs.topk(5, dim=1)\n",
"\n",
" for true, top5 in zip(batch_y, top_preds):\n",
" total += 1\n",
" if true == top5[0]:\n",
" correct_top1 += 1\n",
" if true in top5[:3]:\n",
" correct_top3 += 1\n",
" if true in top5:\n",
" correct_top5 += 1\n",
"\n",
" top1_acc = correct_top1 / total\n",
" top3_acc = correct_top3 / total\n",
" top5_acc = correct_top5 / total\n",
"\n",
" return (top1_acc, top3_acc, top5_acc)\n"
]
},
{
"cell_type": "code",
"execution_count": 72,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Model with activation layers (<class 'torch.nn.modules.activation.ReLU'>, <class 'torch.nn.modules.activation.ReLU'>) and learning rate 0.0001 had success of (0.44952931636286714, 0.6824383880407573, 0.788915135916511)\n",
"Model with activation layers (<class 'torch.nn.modules.activation.ReLU'>, <class 'torch.nn.modules.activation.ReLU'>) and learning rate 0.0005 had success of (0.5080210132919649, 0.7299298381694461, 0.8241018227973064)\n",
"Model with activation layers (<class 'torch.nn.modules.activation.ReLU'>, <class 'torch.nn.modules.activation.ReLU'>) and learning rate 0.001 had success of (0.5215950357860593, 0.7354299615696506, 0.826111483270458)\n",
"Model with activation layers (<class 'torch.nn.modules.activation.ReLU'>, <class 'torch.nn.modules.activation.ReLU'>) and learning rate 0.005 had success of (0.5230758382399605, 0.7383563092761697, 0.8298840038077777)\n",
"Model with activation layers (<class 'torch.nn.modules.activation.ReLU'>, <class 'torch.nn.modules.activation.ReLU'>) and learning rate 0.01 had success of (0.5206783485526919, 0.7364171632055847, 0.8278390861333428)\n",
"Model with activation layers (<class 'torch.nn.modules.activation.ReLU'>, <class 'torch.nn.modules.activation.ReLU'>) and learning rate 0.05 had success of (0.12682015301625357, 0.29884003807777737, 0.45160949123858546)\n",
"Model with activation layers (<class 'torch.nn.modules.activation.ReLU'>, <class 'torch.nn.modules.activation.GELU'>) and learning rate 0.0001 had success of (0.44251313330747805, 0.6765504354264359, 0.7860240454112752)\n",
"Model with activation layers (<class 'torch.nn.modules.activation.ReLU'>, <class 'torch.nn.modules.activation.GELU'>) and learning rate 0.0005 had success of (0.5103127313753835, 0.7293304657476289, 0.8237492507844727)\n",
"Model with activation layers (<class 'torch.nn.modules.activation.ReLU'>, <class 'torch.nn.modules.activation.GELU'>) and learning rate 0.001 had success of (0.5211366921693756, 0.7379332228607693, 0.8288968021718436)\n",
"Model with activation layers (<class 'torch.nn.modules.activation.ReLU'>, <class 'torch.nn.modules.activation.GELU'>) and learning rate 0.005 had success of (0.5246271550964284, 0.739942883333921, 0.8305538906321617)\n",
"Model with activation layers (<class 'torch.nn.modules.activation.ReLU'>, <class 'torch.nn.modules.activation.GELU'>) and learning rate 0.01 had success of (0.5214892641822092, 0.7391319677044036, 0.8297077178013609)\n",
"Model with activation layers (<class 'torch.nn.modules.activation.ReLU'>, <class 'torch.nn.modules.activation.GELU'>) and learning rate 0.05 had success of (0.1655325600253852, 0.3544759017029228, 0.495469449635088)\n",
"Model with activation layers (<class 'torch.nn.modules.activation.GELU'>, <class 'torch.nn.modules.activation.ReLU'>) and learning rate 0.0001 had success of (0.44706131227303175, 0.6806755279765893, 0.7906427387793957)\n",
"Model with activation layers (<class 'torch.nn.modules.activation.GELU'>, <class 'torch.nn.modules.activation.ReLU'>) and learning rate 0.0005 had success of (0.5120050770369848, 0.7312343546169305, 0.8229735923562388)\n",
"Model with activation layers (<class 'torch.nn.modules.activation.GELU'>, <class 'torch.nn.modules.activation.ReLU'>) and learning rate 0.001 had success of (0.5179282868525896, 0.7381800232697528, 0.8289673165744104)\n",
"Model with activation layers (<class 'torch.nn.modules.activation.GELU'>, <class 'torch.nn.modules.activation.ReLU'>) and learning rate 0.005 had success of (0.5234636674540775, 0.7421640870147728, 0.8307654338398618)\n",
"Model with activation layers (<class 'torch.nn.modules.activation.GELU'>, <class 'torch.nn.modules.activation.ReLU'>) and learning rate 0.01 had success of (0.5197264041180412, 0.7384268236787364, 0.8286500017628601)\n",
"Model with activation layers (<class 'torch.nn.modules.activation.GELU'>, <class 'torch.nn.modules.activation.ReLU'>) and learning rate 0.05 had success of (0.12551563656876918, 0.29757077883157634, 0.45034023199238443)\n",
"Model with activation layers (<class 'torch.nn.modules.activation.GELU'>, <class 'torch.nn.modules.activation.GELU'>) and learning rate 0.0001 had success of (0.4493530303564503, 0.683284560871558, 0.7907837675845292)\n",
"Model with activation layers (<class 'torch.nn.modules.activation.GELU'>, <class 'torch.nn.modules.activation.GELU'>) and learning rate 0.0005 had success of (0.5151077107499207, 0.733808130310616, 0.8255121108486408)\n",
"Model with activation layers (<class 'torch.nn.modules.activation.GELU'>, <class 'torch.nn.modules.activation.GELU'>) and learning rate 0.001 had success of (0.5195148609103409, 0.7389204244967035, 0.8294961745936608)\n",
"Model with activation layers (<class 'torch.nn.modules.activation.GELU'>, <class 'torch.nn.modules.activation.GELU'>) and learning rate 0.005 had success of (0.5214892641822092, 0.7401896837429045, 0.8302365758206114)\n",
"Model with activation layers (<class 'torch.nn.modules.activation.GELU'>, <class 'torch.nn.modules.activation.GELU'>) and learning rate 0.01 had success of (0.5198674329231746, 0.7398371117300708, 0.8258294256601911)\n",
"Model with activation layers (<class 'torch.nn.modules.activation.GELU'>, <class 'torch.nn.modules.activation.GELU'>) and learning rate 0.05 had success of (0.3762648520960406, 0.6283538412720798, 0.7500617001022459)\n"
]
}
],
"source": [
"for activation_layer_combination in all_activation_combinations:\n",
" for learning_rate in learning_rates:\n",
" model = MLP(activation_layer_combination).to(device)\n",
" optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)\n",
" train_model(model, optimizer)\n",
" results = test_model(model)\n",
" print(\"Model with activation layers\", activation_layer_combination, \"and learning rate\", learning_rate, \"had success of\", results)"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [],
"source": [
"# Reuse same alphabet + mapping\n",
"alphabet = list(\"abcdefghijklmnopqrstuvwxyz\")\n",
"char_to_idx = {ch: idx for idx, ch in enumerate(alphabet)}\n",
"PAD_IDX = len(alphabet) # index 26 for OOV/padding\n",
"VOCAB_SIZE = len(alphabet) + 1 # 27 total (az + padding)\n",
"CONTEXT_SIZE = 10\n",
"\n",
"idx_to_char = {idx: ch for ch, idx in char_to_idx.items()}\n",
"idx_to_char[PAD_IDX] = \"_\" # for readability\n",
"\n",
"def preprocess_input(context: str) -> torch.Tensor:\n",
" context = context.lower()\n",
" padded = context.rjust(CONTEXT_SIZE, \"_\") # pad with underscores (or any 1-char symbol)\n",
"\n",
" indices = []\n",
" for ch in padded[-CONTEXT_SIZE:]:\n",
" idx = char_to_idx.get(ch, PAD_IDX) # if '_' or unknown → PAD_IDX (26)\n",
" indices.append(idx)\n",
"\n",
" return torch.tensor(indices, dtype=torch.long).unsqueeze(0).to(device)\n",
"\n",
"\n",
"def predict_next_chars(model, context: str, top_k=5):\n",
" model.eval()\n",
" input_tensor = preprocess_input(context)\n",
" with torch.no_grad():\n",
" logits = model(input_tensor)\n",
" probs = torch.softmax(logits, dim=-1)\n",
" top_probs, top_indices = probs.topk(top_k, dim=-1)\n",
"\n",
" predictions = [(idx_to_char[idx.item()], top_probs[0, i].item()) for i, idx in enumerate(top_indices[0])]\n",
" return predictions\n"
]
},
{
"cell_type": "code",
"execution_count": 24,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"I: 89.74 %\n",
"N: 4.42 %\n",
"Y: 1.88 %\n",
"M: 1.51 %\n",
"B: 0.90 %\n",
"E: 0.65 %\n",
"G: 0.21 %\n",
"R: 0.16 %\n",
"L: 0.15 %\n",
"O: 0.13 %\n",
"C: 0.09 %\n",
"U: 0.08 %\n",
"A: 0.05 %\n",
"V: 0.02 %\n",
"S: 0.01 %\n",
"F: 0.00 %\n",
"H: 0.00 %\n",
"T: 0.00 %\n",
"W: 0.00 %\n",
"P: 0.00 %\n"
]
}
],
"source": [
"preds = predict_next_chars(model, \"susta\", top_k=20)\n",
"for char, prob in preds:\n",
" print(f\"{char.upper()}: {(prob * 100):.2f} %\")\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Model saving"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"torch.save(model, \"mlp_full_model.pth\")"
]
},
{
"cell_type": "code",
"execution_count": 38,
"metadata": {},
"outputs": [],
"source": [
"torch.save(model.state_dict(), \"mlp_weights.pth\")"
]
}
],
"metadata": {
"kernelspec": {
"display_name": ".venv",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.3"
}
},
"nbformat": 4,
"nbformat_minor": 2
}

File diff suppressed because it is too large Load Diff

Binary file not shown.

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@ -3,22 +3,13 @@
from typing import Literal, List, Dict
import numpy as np
INPUT_FILE: str = "./data/all_cleaned_words.txt"
INPUT_FILE: str = "./data/all_cleaned_words_shuffled.txt"
OUTPUT_FILE: str = "./data.npy"
alphabet: List[str] = list("abcdefghijklmnopqrstuvwxyz")
vowels: set[str] = set("aeiouy")
char_to_index: Dict[str, int] = {ch: idx for idx, ch in enumerate(alphabet)}
default_index: int = len(alphabet) # Out-of-vocabulary token (e.g., for "")
def get_prev_type(c: str) -> Literal[0, 1, 2]:
if c in vowels:
return 1
elif c in alphabet:
return 2
return 0
default_index: int = len(alphabet) # index 26 -> unknown-character token
def encode_letter(c: str) -> int:
@ -43,12 +34,6 @@ def build_dataset(input_path: str) -> np.ndarray:
# Append current char index (target for classification)
features.append(encode_letter(curr_char))
# Word position features
# is_start: int = 1 if i == 0 else 0
prev1: str = prev_chars[-1]
prev_type: int = get_prev_type(prev1)
# word_length: int = i + 1
# features.extend([prev_type])
all_features.append(features)

View File

@ -3,14 +3,16 @@ name = "app"
version = "0.1.0"
description = "Add your description here"
readme = "README.md"
requires-python = ">=3.12"
requires-python = ">=3.8"
dependencies = [
"onnxruntime>=1.21.0",
"pydantic-settings>=2.8.1",
"pyside6>=6.8.3",
]
[project.optional-dependencies]
train = [
"onnx>=1.17.0",
"torch>=2.6.0",
"ipykernel>=6.29.5",
"numpy>=2.2.4"

102
uv.lock generated
View File

@ -16,6 +16,8 @@ name = "app"
version = "0.1.0"
source = { virtual = "." }
dependencies = [
{ name = "onnx" },
{ name = "onnxruntime" },
{ name = "pydantic-settings" },
{ name = "pyside6" },
]
@ -38,6 +40,8 @@ requires-dist = [
{ name = "matplotlib", marker = "extra == 'visualization'", specifier = ">=3.10.1" },
{ name = "numpy", marker = "extra == 'train'", specifier = ">=2.2.4" },
{ name = "numpy", marker = "extra == 'visualization'", specifier = ">=2.2.4" },
{ name = "onnx", specifier = ">=1.17.0" },
{ name = "onnxruntime", specifier = ">=1.21.0" },
{ name = "pandas", marker = "extra == 'visualization'", specifier = ">=2.2.3" },
{ name = "pydantic-settings", specifier = ">=2.8.1" },
{ name = "pyside6", specifier = ">=6.8.3" },
@ -105,6 +109,18 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/d1/d6/3965ed04c63042e047cb6a3e6ed1a63a35087b6a609aa3a15ed8ac56c221/colorama-0.4.6-py2.py3-none-any.whl", hash = "sha256:4f1d9991f5acc0ca119f9d443620b77f9d6b33703e51011c16baf57afb285fc6", size = 25335 },
]
[[package]]
name = "coloredlogs"
version = "15.0.1"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "humanfriendly" },
]
sdist = { url = "https://files.pythonhosted.org/packages/cc/c7/eed8f27100517e8c0e6b923d5f0845d0cb99763da6fdee00478f91db7325/coloredlogs-15.0.1.tar.gz", hash = "sha256:7c991aa71a4577af2f82600d8f8f3a89f936baeaf9b50a9c197da014e5bf16b0", size = 278520 }
wheels = [
{ url = "https://files.pythonhosted.org/packages/a7/06/3d6badcf13db419e25b07041d9c7b4a2c331d3f4e7134445ec5df57714cd/coloredlogs-15.0.1-py2.py3-none-any.whl", hash = "sha256:612ee75c546f53e92e70049c9dbfcc18c935a2b9a53b66085ce9ef6a6e5c0934", size = 46018 },
]
[[package]]
name = "comm"
version = "0.2.2"
@ -211,6 +227,15 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/4d/36/2a115987e2d8c300a974597416d9de88f2444426de9571f4b59b2cca3acc/filelock-3.18.0-py3-none-any.whl", hash = "sha256:c401f4f8377c4464e6db25fff06205fd89bdd83b65eb0488ed1b160f780e21de", size = 16215 },
]
[[package]]
name = "flatbuffers"
version = "25.2.10"
source = { registry = "https://pypi.org/simple" }
sdist = { url = "https://files.pythonhosted.org/packages/e4/30/eb5dce7994fc71a2f685d98ec33cc660c0a5887db5610137e60d8cbc4489/flatbuffers-25.2.10.tar.gz", hash = "sha256:97e451377a41262f8d9bd4295cc836133415cc03d8cb966410a4af92eb00d26e", size = 22170 }
wheels = [
{ url = "https://files.pythonhosted.org/packages/b8/25/155f9f080d5e4bc0082edfda032ea2bc2b8fab3f4d25d46c1e9dd22a1a89/flatbuffers-25.2.10-py2.py3-none-any.whl", hash = "sha256:ebba5f4d5ea615af3f7fd70fc310636fbb2bbd1f566ac0a23d98dd412de50051", size = 30953 },
]
[[package]]
name = "fonttools"
version = "4.56.0"
@ -245,6 +270,18 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/56/53/eb690efa8513166adef3e0669afd31e95ffde69fb3c52ec2ac7223ed6018/fsspec-2025.3.0-py3-none-any.whl", hash = "sha256:efb87af3efa9103f94ca91a7f8cb7a4df91af9f74fc106c9c7ea0efd7277c1b3", size = 193615 },
]
[[package]]
name = "humanfriendly"
version = "10.0"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "pyreadline3", marker = "sys_platform == 'win32'" },
]
sdist = { url = "https://files.pythonhosted.org/packages/cc/3f/2c29224acb2e2df4d2046e4c73ee2662023c58ff5b113c4c1adac0886c43/humanfriendly-10.0.tar.gz", hash = "sha256:6b0b831ce8f15f7300721aa49829fc4e83921a9a301cc7f606be6686a2288ddc", size = 360702 }
wheels = [
{ url = "https://files.pythonhosted.org/packages/f0/0f/310fb31e39e2d734ccaa2c0fb981ee41f7bd5056ce9bc29b2248bd569169/humanfriendly-10.0-py2.py3-none-any.whl", hash = "sha256:1697e1a8a8f550fd43c2865cd84542fc175a61dcb779b6fee18cf6b6ccba1477", size = 86794 },
]
[[package]]
name = "ipykernel"
version = "6.29.5"
@ -677,6 +714,48 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/87/20/199b8713428322a2f22b722c62b8cc278cc53dffa9705d744484b5035ee9/nvidia_nvtx_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl", hash = "sha256:781e950d9b9f60d8241ccea575b32f5105a5baf4c2351cab5256a24869f12a1a", size = 99144 },
]
[[package]]
name = "onnx"
version = "1.17.0"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "numpy" },
{ name = "protobuf" },
]
sdist = { url = "https://files.pythonhosted.org/packages/9a/54/0e385c26bf230d223810a9c7d06628d954008a5e5e4b73ee26ef02327282/onnx-1.17.0.tar.gz", hash = "sha256:48ca1a91ff73c1d5e3ea2eef20ae5d0e709bb8a2355ed798ffc2169753013fd3", size = 12165120 }
wheels = [
{ url = "https://files.pythonhosted.org/packages/b4/dd/c416a11a28847fafb0db1bf43381979a0f522eb9107b831058fde012dd56/onnx-1.17.0-cp312-cp312-macosx_12_0_universal2.whl", hash = "sha256:0e906e6a83437de05f8139ea7eaf366bf287f44ae5cc44b2850a30e296421f2f", size = 16651271 },
{ url = "https://files.pythonhosted.org/packages/f0/6c/f040652277f514ecd81b7251841f96caa5538365af7df07f86c6018cda2b/onnx-1.17.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:3d955ba2939878a520a97614bcf2e79c1df71b29203e8ced478fa78c9a9c63c2", size = 15907522 },
{ url = "https://files.pythonhosted.org/packages/3d/7c/67f4952d1b56b3f74a154b97d0dd0630d525923b354db117d04823b8b49b/onnx-1.17.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:4f3fb5cc4e2898ac5312a7dc03a65133dd2abf9a5e520e69afb880a7251ec97a", size = 16046307 },
{ url = "https://files.pythonhosted.org/packages/ae/20/6da11042d2ab870dfb4ce4a6b52354d7651b6b4112038b6d2229ab9904c4/onnx-1.17.0-cp312-cp312-win32.whl", hash = "sha256:317870fca3349d19325a4b7d1b5628f6de3811e9710b1e3665c68b073d0e68d7", size = 14424235 },
{ url = "https://files.pythonhosted.org/packages/35/55/c4d11bee1fdb0c4bd84b4e3562ff811a19b63266816870ae1f95567aa6e1/onnx-1.17.0-cp312-cp312-win_amd64.whl", hash = "sha256:659b8232d627a5460d74fd3c96947ae83db6d03f035ac633e20cd69cfa029227", size = 14530453 },
]
[[package]]
name = "onnxruntime"
version = "1.21.0"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "coloredlogs" },
{ name = "flatbuffers" },
{ name = "numpy" },
{ name = "packaging" },
{ name = "protobuf" },
{ name = "sympy" },
]
wheels = [
{ url = "https://files.pythonhosted.org/packages/ff/21/593c9bc56002a6d1ea7c2236f4a648e081ec37c8d51db2383a9e83a63325/onnxruntime-1.21.0-cp312-cp312-macosx_13_0_universal2.whl", hash = "sha256:893d67c68ca9e7a58202fa8d96061ed86a5815b0925b5a97aef27b8ba246a20b", size = 33658780 },
{ url = "https://files.pythonhosted.org/packages/4a/b4/33ec675a8ac150478091262824413e5d4acc359e029af87f9152e7c1c092/onnxruntime-1.21.0-cp312-cp312-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:37b7445c920a96271a8dfa16855e258dc5599235b41c7bbde0d262d55bcc105f", size = 14159975 },
{ url = "https://files.pythonhosted.org/packages/8b/08/eead6895ed83b56711ca6c0d31d82f109401b9937558b425509e497d6fb4/onnxruntime-1.21.0-cp312-cp312-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:9a04aafb802c1e5573ba4552f8babcb5021b041eb4cfa802c9b7644ca3510eca", size = 16019285 },
{ url = "https://files.pythonhosted.org/packages/77/39/e83d56e3c215713b5263cb4d4f0c69e3964bba11634233d8ae04fc7e6bf3/onnxruntime-1.21.0-cp312-cp312-win_amd64.whl", hash = "sha256:7f801318476cd7003d636a5b392f7a37c08b6c8d2f829773f3c3887029e03f32", size = 11760975 },
{ url = "https://files.pythonhosted.org/packages/f2/25/93f65617b06c741a58eeac9e373c99df443b02a774f4cb6511889757c0da/onnxruntime-1.21.0-cp313-cp313-macosx_13_0_universal2.whl", hash = "sha256:85718cbde1c2912d3a03e3b3dc181b1480258a229c32378408cace7c450f7f23", size = 33659581 },
{ url = "https://files.pythonhosted.org/packages/f9/03/6b6829ee8344490ab5197f39a6824499ed097d1fc8c85b1f91c0e6767819/onnxruntime-1.21.0-cp313-cp313-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:94dff3a61538f3b7b0ea9a06bc99e1410e90509c76e3a746f039e417802a12ae", size = 14160534 },
{ url = "https://files.pythonhosted.org/packages/a6/81/e280ddf05f83ad5e0d066ef08e31515b17bd50bb52ef2ea713d9e455e67a/onnxruntime-1.21.0-cp313-cp313-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:c1e704b0eda5f2bbbe84182437315eaec89a450b08854b5a7762c85d04a28a0a", size = 16018947 },
{ url = "https://files.pythonhosted.org/packages/d3/ea/011dfc2536e46e2ea984d2c0256dc585ebb1352366dffdd98764f1f44ee4/onnxruntime-1.21.0-cp313-cp313-win_amd64.whl", hash = "sha256:19b630c6a8956ef97fb7c94948b17691167aa1aaf07b5f214fa66c3e4136c108", size = 11760731 },
{ url = "https://files.pythonhosted.org/packages/47/6b/a00f31322e91c610c7825377ef0cad884483c30d1370b896d57e7032e912/onnxruntime-1.21.0-cp313-cp313t-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:3995c4a2d81719623c58697b9510f8de9fa42a1da6b4474052797b0d712324fe", size = 14172215 },
{ url = "https://files.pythonhosted.org/packages/58/4b/98214f13ac1cd675dfc2713ba47b5722f55ce4fba526d2b2826f2682a42e/onnxruntime-1.21.0-cp313-cp313t-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:36b18b8f39c0f84e783902112a0dd3c102466897f96d73bb83f6a6bff283a423", size = 15990612 },
]
[[package]]
name = "packaging"
version = "24.2"
@ -800,6 +879,20 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/e4/ea/d836f008d33151c7a1f62caf3d8dd782e4d15f6a43897f64480c2b8de2ad/prompt_toolkit-3.0.50-py3-none-any.whl", hash = "sha256:9b6427eb19e479d98acff65196a307c555eb567989e6d88ebbb1b509d9779198", size = 387816 },
]
[[package]]
name = "protobuf"
version = "6.30.2"
source = { registry = "https://pypi.org/simple" }
sdist = { url = "https://files.pythonhosted.org/packages/c8/8c/cf2ac658216eebe49eaedf1e06bc06cbf6a143469236294a1171a51357c3/protobuf-6.30.2.tar.gz", hash = "sha256:35c859ae076d8c56054c25b59e5e59638d86545ed6e2b6efac6be0b6ea3ba048", size = 429315 }
wheels = [
{ url = "https://files.pythonhosted.org/packages/be/85/cd53abe6a6cbf2e0029243d6ae5fb4335da2996f6c177bb2ce685068e43d/protobuf-6.30.2-cp310-abi3-win32.whl", hash = "sha256:b12ef7df7b9329886e66404bef5e9ce6a26b54069d7f7436a0853ccdeb91c103", size = 419148 },
{ url = "https://files.pythonhosted.org/packages/97/e9/7b9f1b259d509aef2b833c29a1f3c39185e2bf21c9c1be1cd11c22cb2149/protobuf-6.30.2-cp310-abi3-win_amd64.whl", hash = "sha256:7653c99774f73fe6b9301b87da52af0e69783a2e371e8b599b3e9cb4da4b12b9", size = 431003 },
{ url = "https://files.pythonhosted.org/packages/8e/66/7f3b121f59097c93267e7f497f10e52ced7161b38295137a12a266b6c149/protobuf-6.30.2-cp39-abi3-macosx_10_9_universal2.whl", hash = "sha256:0eb523c550a66a09a0c20f86dd554afbf4d32b02af34ae53d93268c1f73bc65b", size = 417579 },
{ url = "https://files.pythonhosted.org/packages/d0/89/bbb1bff09600e662ad5b384420ad92de61cab2ed0f12ace1fd081fd4c295/protobuf-6.30.2-cp39-abi3-manylinux2014_aarch64.whl", hash = "sha256:50f32cc9fd9cb09c783ebc275611b4f19dfdfb68d1ee55d2f0c7fa040df96815", size = 317319 },
{ url = "https://files.pythonhosted.org/packages/28/50/1925de813499546bc8ab3ae857e3ec84efe7d2f19b34529d0c7c3d02d11d/protobuf-6.30.2-cp39-abi3-manylinux2014_x86_64.whl", hash = "sha256:4f6c687ae8efae6cf6093389a596548214467778146b7245e886f35e1485315d", size = 316212 },
{ url = "https://files.pythonhosted.org/packages/e5/a1/93c2acf4ade3c5b557d02d500b06798f4ed2c176fa03e3c34973ca92df7f/protobuf-6.30.2-py3-none-any.whl", hash = "sha256:ae86b030e69a98e08c77beab574cbcb9fff6d031d57209f574a5aea1445f4b51", size = 167062 },
]
[[package]]
name = "psutil"
version = "7.0.0"
@ -930,6 +1023,15 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/05/e7/df2285f3d08fee213f2d041540fa4fc9ca6c2d44cf36d3a035bf2a8d2bcc/pyparsing-3.2.3-py3-none-any.whl", hash = "sha256:a749938e02d6fd0b59b356ca504a24982314bb090c383e3cf201c95ef7e2bfcf", size = 111120 },
]
[[package]]
name = "pyreadline3"
version = "3.5.4"
source = { registry = "https://pypi.org/simple" }
sdist = { url = "https://files.pythonhosted.org/packages/0f/49/4cea918a08f02817aabae639e3d0ac046fef9f9180518a3ad394e22da148/pyreadline3-3.5.4.tar.gz", hash = "sha256:8d57d53039a1c75adba8e50dd3d992b28143480816187ea5efbd5c78e6c885b7", size = 99839 }
wheels = [
{ url = "https://files.pythonhosted.org/packages/5a/dc/491b7661614ab97483abf2056be1deee4dc2490ecbf7bff9ab5cdbac86e1/pyreadline3-3.5.4-py3-none-any.whl", hash = "sha256:eaf8e6cc3c49bcccf145fc6067ba8643d1df34d604a1ec0eccbf7a18e6d3fae6", size = 83178 },
]
[[package]]
name = "pyside6"
version = "6.8.3"