Final update: testing, final model, new data sources, updated notebooks for MLP and logistic regression, and final keyboard code
This commit is contained in:
parent
bd602a31ea
commit
f7bdc26953
4
.gitignore
vendored
@@ -168,4 +168,6 @@ cython_debug/
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/

data.npy
*.npy
*.pth
*.onnx
@@ -1 +1 @@
3.12
3.8
13
TESTING.md
Normal file
@@ -0,0 +1,13 @@
# Fixed bugs
1. Incorrect Hitbox Behavior After Key Scaling
- Due to a forgotten reset of the z-index of keys, a key could sometimes appear behind another one, yet the old key would still get clicked

2. Click propagation
- An on-click event would propagate further through the app, meaning two keys could get pressed at once

3. Invalid model crashes the app
- If the model file path doesn't point to a valid model, the app would crash

# Tests
1. Config from environment (or `.env`) gets parsed - defaults are used, but are overwritten if a `.env` file is present
2. Space properly clears buffers - pressing the space key clears both the prediction and the word box buffers

A sketch of these checks with pytest follows.
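A minimal sketch of the two tests above with pytest; the `scene` fixture and the environment handling are assumptions for illustration, not the project's actual test code.

```
# Hypothetical sketch only. Assumes a `scene` fixture that constructs a
# KeyboardScene inside a Qt test harness, and that Settings reads env vars.
from app.core.config import Settings


def test_config_env_overrides_defaults(monkeypatch):
    assert Settings().APP_NAME == "Omega"         # default applies
    monkeypatch.setenv("APP_NAME", "CustomName")  # env (or .env) wins
    assert Settings().APP_NAME == "CustomName"


def test_space_clears_buffers(scene):
    scene.on_key_clicked("a")
    scene.on_key_clicked("Space")
    assert scene.predictor.buffer == ""           # prediction buffer cleared
    assert scene.word_box.word == ""              # word box cleared
```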
6
app/constants.py
Normal file
@@ -0,0 +1,6 @@
CONTEXT_SIZE = 10
ALPHABET = list("abcdefghijklmnopqrstuvwxyz")
VOCAB_SIZE = len(ALPHABET) + 1  # +1 for unknown
UNKNOWN_IDX = VOCAB_SIZE - 1

SINGLE_KEY_PROBABILITY = 1 / 26
@@ -14,7 +14,7 @@ class Settings(BaseSettings):

    APP_NAME: str = "Omega"
    VERBOSITY: Literal["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"] = "INFO"
    MODEL_PATH: FilePath = Path("./model/predictor.pth")
    MODEL_PATH: FilePath = Path("model.onnx")


settings = Settings()
@@ -7,7 +7,7 @@ from app.ui.view import KeyboardView

from app.utils import setup_logger


class Keyboard():
class KeyboardApplication():
    def __init__(self):
        setup_logger()
        self.logger = logging.getLogger(__name__)
56
app/predictor.py
Normal file
@@ -0,0 +1,56 @@
import onnxruntime
import numpy as np
import sys

from app.core.config import settings
from app.constants import ALPHABET, CONTEXT_SIZE, SINGLE_KEY_PROBABILITY
from app.utils import tokenize

import logging


class KeyboardPredictor():
    def __init__(self):
        self.logger = logging.getLogger(__name__)
        self.buffer = ""
        try:
            self.session = onnxruntime.InferenceSession(settings.MODEL_PATH)
        except onnxruntime.capi.onnxruntime_pybind11_state.InvalidProtobuf:
            self.logger.critical("Invalid model found. Exiting...")
            sys.exit(1)

    def reset(self):
        self.logger.info("Clearing prediction buffer")
        self.buffer = ""

    def update(self, char: str):
        self.logger.info("Adding new char '%s'", char)
        # Remove the oldest character
        # and append the current one
        if len(self.buffer) == CONTEXT_SIZE:
            self.buffer = self.buffer[1:]
        self.buffer += char.lower()

    def __get_prediction(self) -> dict[str, float]:
        """Returns the probability distribution over the next character."""
        tokenized = [tokenize("".join(self.buffer))]  # shape: (1, context_size)
        input_array = np.array(tokenized, dtype=np.int64)

        outputs = self.session.run(None, {"input": input_array})
        logits = outputs[0][0]  # shape: (VOCAB_SIZE,)

        exp_logits = np.exp(logits - np.max(logits))  # numerical stability
        probs = exp_logits / exp_logits.sum()

        return {char: probs[i] for i, char in enumerate(ALPHABET + ["?"])}

    def get_predictions(self) -> dict[str, float]:
        """
        Returns keys that are most likely to be pressed next.
        """
        self.logger.debug("Getting fresh predictions")
        DROPOUT_PROBABILITY = 2 * SINGLE_KEY_PROBABILITY

        probs = self.__get_prediction()
        probs = {char: prob for char, prob in probs.items() if prob > DROPOUT_PROBABILITY}
        return probs
54
app/ui/current_word.py
Normal file
@@ -0,0 +1,54 @@
from PySide6.QtWidgets import QGraphicsRectItem, QGraphicsSimpleTextItem
from PySide6.QtGui import QFont, QFontMetricsF, QColor, QBrush


BACKGROUND_COLOR = QColor("white")
TEXT_SCALE_FACTOR = 0.4  # 40% of the width or height
Z_INDEX = 10


class CurrentWordBox(QGraphicsRectItem):
    def __init__(self):
        super().__init__()

        self.word = ""
        self.text = QGraphicsSimpleTextItem("", self)
        self.text.setFlag(QGraphicsSimpleTextItem.ItemIgnoresTransformations, True)

        self.setBrush(QBrush(BACKGROUND_COLOR))
        self.setZValue(Z_INDEX)

        self.width = 0
        self.height = 0

    def set_geometry(self, x: float, y: float, width: float, height: float):
        self.width = width
        self.height = height
        self.setRect(0, 0, width, height)
        self.setPos(x, y)
        self.update_label_font()

    def update_word(self, next_char: str):
        self.word += next_char
        self.text.setText(self.word)
        self.update_label_font()

    def clear(self):
        self.word = ""
        self.text.setText("")
        self.update_label_font()

    def update_label_font(self):
        min_dimension = min(self.width, self.height)
        font_size = min_dimension * TEXT_SCALE_FACTOR

        font = QFont()
        font.setPointSizeF(font_size)
        self.text.setFont(font)

        metrics = QFontMetricsF(font)
        text_rect = metrics.boundingRect(self.word)

        text_x = (self.width - text_rect.width()) / 2
        text_y = (self.height - text_rect.height()) / 2
        self.text.setPos(text_x, text_y)
@@ -1,23 +1,45 @@
import logging

from PySide6.QtWidgets import (
    QGraphicsRectItem, QGraphicsSimpleTextItem, QGraphicsItem
)
from PySide6.QtGui import QBrush, QColor, QFont, QFontMetricsF
from PySide6.QtGui import QBrush, QColor, QFont, QFontMetricsF, QPen, Qt
from PySide6.QtCore import Signal, QObject, QTimer


NORMAL_COLOR = QColor("lightgray")
CLICKED_COLOR = QColor("darkgray")
TEXT_SCALE_FACTOR = 0.4
CLICK_HIGHLIGHT_DURATION_MS = 100
Z_INDEX_MULTIPLIER = 2.0


class KeyItem(QGraphicsRectItem, QObject):
    clicked = Signal(str)

    def __init__(self, label: str):
        QObject.__init__(self)
        QGraphicsRectItem.__init__(self)

        self.logger = logging.getLogger(__name__)


class KeyItem(QGraphicsRectItem):
    def __init__(self, label):
        super().__init__()
        self.label = label
        self.width = 0.0
        self.height = 0.0
        self.scale_factor = 1.0

        self._reset_timer = QTimer()
        self._reset_timer.setSingleShot(True)
        self._reset_timer.timeout.connect(self.__reset_color)

        self.text = QGraphicsSimpleTextItem(label, self)
        self.text.setFlag(QGraphicsItem.ItemIgnoresTransformations, True)

        self.setBrush(QBrush(QColor("lightgray")))
        self.__normal_brush = QBrush(NORMAL_COLOR)
        self.__click_brush = QBrush(CLICKED_COLOR)
        self.setBrush(self.__normal_brush)

        self.scale_factor = 1.0
        self.width = 0
        self.height = 0

    def set_geometry(self, x, y, width, height):
    def set_geometry(self, x: float, y: float, width: float, height: float):
        self.width = width
        self.height = height

@@ -28,25 +50,34 @@ class KeyItem(QGraphicsRectItem):
        self.update_label_font()

    def update_label_font(self):
        # Dynamically size font
        font_size = min(self.width, self.height) * 0.4
        min_dimension = min(self.width, self.height)
        font_size = min_dimension * TEXT_SCALE_FACTOR

        font = QFont()
        font.setPointSizeF(font_size)
        self.text.setFont(font)

        # Use font metrics for accurate bounding
        metrics = QFontMetricsF(font)
        text_rect = metrics.boundingRect(self.label)

        text_x = (self.width - text_rect.width()) / 2
        text_y = (self.height - text_rect.height()) / 2  # + metrics.ascent() - text_rect.height()

        text_y = (self.height - text_rect.height()) / 2
        self.text.setPos(text_x, text_y)

    def set_scale_factor(self, scale):
        self.scale_factor = scale
        self.setZValue(scale)
    def set_scale_factor(self, scale: float):
        scaled_z = scale * Z_INDEX_MULTIPLIER
        self.scale_factor = scaled_z
        self.setZValue(scaled_z)
        self.setScale(scale)

    def mousePressEvent(self, q_mouse_event):
        print(self.label, "was clicked")
        self.logger.info("%s was clicked", self.label)
        self.clicked.emit(self.label)
        self.__onclick_color()

    def __onclick_color(self):
        self.setBrush(self.__click_brush)
        self._reset_timer.start(CLICK_HIGHLIGHT_DURATION_MS)

    def __reset_color(self):
        self.setBrush(self.__normal_brush)
70
app/ui/keyboard.py
Normal file
@@ -0,0 +1,70 @@
from PySide6.QtCore import QSize, QRectF
from app.ui.key import KeyItem


class KeyboardLayout():
    def __init__(self, scene):
        self.scene = scene
        self.keys = {}
        self.letter_rows = [
            "QWERTZUIOP",
            "ASDFGHJKL",
            "YXCVBNM",
        ]
        self.space_key = KeyItem("Space")
        self.space_key.clicked.connect(lambda: self.on_key_clicked("Space"))
        self.scene.addItem(self.space_key)
        self.create_keys()

        # callback to be set by scene
        self.key_pressed_callback = None

    def create_keys(self):
        for row in self.letter_rows:
            for char in row:
                key = KeyItem(char)
                key.clicked.connect(lambda c=char: self.on_key_clicked(c))
                self.scene.addItem(key)
                self.keys[char] = key

    def layout_keys(self, view_size: QSize, top_offset: float = 0.0):
        view_width = view_size.width()
        view_height = view_size.height()

        padding = 10.0
        spacing = 6.0
        rows = self.letter_rows + [" "]
        row_count = len(rows)
        max_keys_in_row = max(len(row) for row in rows)

        total_spacing_x = (max_keys_in_row - 1) * spacing
        total_spacing_y = (row_count - 1) * spacing

        available_width = view_width - 2 * padding - total_spacing_x
        available_height = view_height - top_offset - 2 * padding - total_spacing_y

        key_width = available_width / max_keys_in_row
        key_height = available_height / row_count

        y = top_offset + padding
        for row in self.letter_rows:
            x = padding
            for char in row:
                key = self.keys[char]
                key.set_geometry(x, y, key_width, key_height)
                x += key_width + spacing
            y += key_height + spacing

        space_width = key_width * 7 + spacing * 6
        self.space_key.set_geometry(padding, y, space_width, key_height)

    def set_scale_factors(self, predictions: dict):
        for key in self.keys.values():
            key.set_scale_factor(1.0)
        for key, probability in predictions.items():
            if key.upper() in self.keys:
                self.keys[key.upper()].set_scale_factor(1 + probability)

    def on_key_clicked(self, label: str):
        if self.key_pressed_callback:
            self.key_pressed_callback(label)
@@ -1,79 +1,52 @@
from PySide6.QtWidgets import (
    QGraphicsScene
)
from PySide6.QtGui import QBrush, QColor
from PySide6.QtCore import QTimer, QRectF, QSize
import random
from PySide6.QtWidgets import QGraphicsScene
from PySide6.QtCore import QRectF, QSize

from app.predictor import KeyboardPredictor
from app.ui.current_word import CurrentWordBox
from app.ui.keyboard import KeyboardLayout


WORD_BOX_HEIGHT_MODIFIER = 0.15
WORD_BOX_PADDING_ALL_SIDES = 10
WORD_BOX_MARGIN_BOTTOM = 20

from app.ui.key import KeyItem


class KeyboardScene(QGraphicsScene):
    def __init__(self):
        super().__init__()
        self.keys = {}
        self.letter_rows = [
            "QWERTYUIOP",
            "ASDFGHJKL",
            "ZXCVBNM",
        ]
        self.space_key = KeyItem("Space")
        self.addItem(self.space_key)
        self.timer = QTimer()
        self.timer.timeout.connect(self.simulate_prediction)
        self.timer.start(2000)

        self.create_keys()
        self.word_box = CurrentWordBox()
        self.addItem(self.word_box)

    def create_keys(self):
        for row in self.letter_rows:
            for char in row:
                key = KeyItem(char)
                self.addItem(key)
                self.keys[char] = key
        self.predictor = KeyboardPredictor()

        self.keyboard = KeyboardLayout(self)
        self.keyboard.key_pressed_callback = self.on_key_clicked

    def layout_keys(self, view_size: QSize):
        view_width = view_size.width()
        view_height = view_size.height()

        padding = 10.0
        spacing = 6.0
        max_scale = 2.0
        word_box_height = view_height * WORD_BOX_HEIGHT_MODIFIER
        word_box_x = WORD_BOX_PADDING_ALL_SIDES
        word_box_y = WORD_BOX_PADDING_ALL_SIDES
        word_box_width = view_width - WORD_BOX_MARGIN_BOTTOM
        word_box_height_adjusted = word_box_height - WORD_BOX_MARGIN_BOTTOM

        rows = self.letter_rows + [" "]
        row_count = len(rows)
        max_keys_in_row = max(len(row) for row in rows)

        total_spacing_x = (max_keys_in_row - 1) * spacing
        total_spacing_y = (row_count - 1) * spacing

        available_width = view_width - 2 * padding - total_spacing_x
        available_height = view_height - 2 * padding - total_spacing_y

        key_width = available_width / max_keys_in_row
        key_height = available_height / row_count

        y = padding
        for row in self.letter_rows:
            x = padding
            for char in row:
                key = self.keys[char]
                key.set_geometry(x, y, key_width, key_height)
                x += key_width + spacing
            y += key_height + spacing

        # Space key layout
        space_width = key_width * 7 + spacing * 6
        self.space_key.set_geometry(padding, y, space_width, key_height)
        self.word_box.set_geometry(
            word_box_x, word_box_y, word_box_width, word_box_height_adjusted
        )

        self.keyboard.layout_keys(view_size, top_offset=word_box_height)
        self.setSceneRect(QRectF(0, 0, view_width, view_height))

    def simulate_prediction(self):
        most_likely = random.choice(list(self.keys.keys()))
        print(f"[Prediction] Most likely: {most_likely}")
    def on_key_clicked(self, label: str):
        if label == "Space":
            self.word_box.clear()
            self.predictor.reset()
            self.keyboard.set_scale_factors({})
            return

        for char, key in self.keys.items():
            if char == most_likely:
                key.set_scale_factor(1.8)
                key.setBrush(QBrush(QColor("orange")))
            else:
                key.set_scale_factor(1.0)
                key.setBrush(QBrush(QColor("lightgray")))
        self.word_box.update_word(label)
        self.predictor.update(label)
        predictions = self.predictor.get_predictions()
        self.keyboard.set_scale_factors(predictions)
@@ -5,6 +5,7 @@ from PySide6.QtGui import QPainter
from PySide6.QtCore import Qt

from .scene import KeyboardScene
from app.core.config import settings


class KeyboardView(QGraphicsView):
    def __init__(self):
@@ -12,9 +13,9 @@ class KeyboardView(QGraphicsView):
        self.scene = KeyboardScene()
        self.setScene(self.scene)
        self.setRenderHint(QPainter.Antialiasing)
        self.setWindowTitle("Dynamic Keyboard")
        self.setWindowTitle(settings.APP_NAME)
        self.setAlignment(Qt.AlignLeft | Qt.AlignTop)
        self.setMinimumSize(600, 200)  # Sensible default
        self.setMinimumSize(600, 300)

    def resizeEvent(self, event):
        super().resizeEvent(event)
15
app/utils.py
@@ -3,6 +3,11 @@ import logging

from app.core.config import settings

from app.constants import CONTEXT_SIZE, UNKNOWN_IDX, ALPHABET

char_to_index = {char: idx for idx, char in enumerate(ALPHABET)}


def setup_logger():
    logger = logging.getLogger()

@@ -51,3 +56,13 @@ if sys.platform == "win32":
        0, 0, 0, 0,
        win32con.SWP_NOMOVE | win32con.SWP_NOSIZE | win32con.SWP_NOACTIVATE
    )


def tokenize(text: str) -> list[int]:
    """Convert last CONTEXT_SIZE chars to integer indices."""
    text = text.lower()[-CONTEXT_SIZE:]  # trim to context length
    padded = [' '] * (CONTEXT_SIZE - len(text)) + list(text)
    return [
        char_to_index.get(c, UNKNOWN_IDX)
        for c in padded
    ]
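For reference, a quick sketch of what `tokenize` produces given the constants above; this is illustration only, with the index values following directly from `app/constants.py` (padding spaces are absent from `char_to_index`, so they map to `UNKNOWN_IDX = 26`).

```
# Illustration only; values follow from tokenize() and app/constants.py.
from app.utils import tokenize

print(tokenize("cat"))
# 7 padding spaces -> UNKNOWN_IDX (26), then 'c'=2, 'a'=0, 't'=19:
# [26, 26, 26, 26, 26, 26, 26, 2, 0, 19]

print(tokenize("predictable"))  # longer than CONTEXT_SIZE: keeps the last 10
# "redictable" -> [17, 4, 3, 8, 2, 19, 0, 1, 11, 4]
```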
4
main.py
@@ -1,7 +1,7 @@
import sys

from app.keyboard import Keyboard
from app.keyboard import KeyboardApplication

if __name__ == "__main__":
    keyboard = Keyboard()
    keyboard = KeyboardApplication()
    sys.exit(keyboard.run())
BIN
model.onnx
Normal file
Binary file not shown.
@@ -1,20 +1,34 @@
# omega
# Model and training

## Documentation
First I gathered a couple of long text sources, like the GNU GPL license, Wikipedia articles, or even a book.
This is the section of the code related to training and modifying the model used by this app.

Those were transformed into a large text file [see all_words.txt](data/all_words.txt) using the following command
## Training
There are two notebooks related to training
- [Multilayer Perceptron](./mlp.ipynb)
- [Logistic Regression model](./logistic.ipynb)

```
grep -o "[[:alpha:]]\{1,\}" "path_to_individual_source.txt" | tr '[:upper:]' '[:lower:]'
```
The MLP proved to be far more accurate, so it is the one used.

Which simply finds words at least one character long and unifies them by transforming them all to lowercase.
## Data
See [Sources](#Sources) for all data used.
The data covers a variety of source types: news articles, scientific articles, a couple of books, and Wikipedia articles.

The text was extracted by simply copying it, including some unwanted garbage like numbers, Wikipedia links, etc. Simply put: Ctrl+C, Ctrl+V.

The next step in data processing was using the included scripts, mainly [`words.sh`](./words.sh), which extracts only alphabetic strings and places them on new lines.
The second was cleaning the data completely by allowing only the 26 English alphabet characters to be present. This is done by [`clear.sh`](./clear.sh).
The third and last was turning the data into a NumPy array for space efficiency (instead of CSV). This is done by [`transform.py`](./transform.py). This is the last step of data processing; the model can now be trained on this data. A sketch of this encoding step is shown below.
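For illustration, a minimal sketch of what this encoding step could look like; the real logic lives in [`transform.py`](./transform.py), so the function name and windowing details here are assumptions, not the script itself.

```
# Hypothetical sketch of the encoding step; see transform.py for the real
# implementation. Emits rows of 10 context indices plus 1 target index.
import numpy as np

ALPHABET = list("abcdefghijklmnopqrstuvwxyz")
CHAR_TO_INDEX = {ch: i for i, ch in enumerate(ALPHABET)}
UNKNOWN_IDX = len(ALPHABET)  # 26 - padding / out-of-vocabulary token
CONTEXT_SIZE = 10


def encode_word(word: str) -> list[list[int]]:
    rows = []
    for i, target in enumerate(word):
        context = word[max(0, i - CONTEXT_SIZE):i]
        # Left-pad short contexts with the unknown token
        padded = [UNKNOWN_IDX] * (CONTEXT_SIZE - len(context)) + [
            CHAR_TO_INDEX.get(c, UNKNOWN_IDX) for c in context
        ]
        rows.append(padded + [CHAR_TO_INDEX.get(target, UNKNOWN_IDX)])
    return rows


with open("data/all_cleaned_words_shuffled.txt") as f:
    rows = [row for line in f for row in encode_word(line.strip())]
np.save("data.npy", np.array(rows, dtype=np.int64))
```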
## Structure
The current model is a **character-level** predictor that uses the previous 10 characters to predict the next one.
It was trained using the processed word list.
The dataset can be found in [`data/all_cleaned_words.txt`](data/all_cleaned_words.txt).

- **Model type**: **`Multilayer Perceptron`**
- **Input shape**: **`10 * 16`** - the 10 previous characters as 16-dimensional embeddings
- **Output shape**: **`26`** - 26 probabilities, one for each letter of the English alphabet
- **Dataset**: **`220k`** words from various types of sources, though mainly books

The app consumes this model as `model.onnx`; a sketch of the export step follows.
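The export itself isn't shown in this commit, so the following is only a plausible sketch; it assumes `model` is the trained MLP from [`mlp.ipynb`](./mlp.ipynb) and uses the `"input"` tensor name that `app/predictor.py` feeds.

```
# Hypothetical export sketch; `model` is assumed to be the trained MLP.
# The input name "input" matches what app/predictor.py passes at inference.
import torch

dummy = torch.zeros(1, 10, dtype=torch.long)  # (batch, CONTEXT_SIZE)
torch.onnx.export(
    model,
    dummy,
    "model.onnx",
    input_names=["input"],
    output_names=["logits"],
    dynamic_axes={"input": {0: "batch"}},
)
```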
For the model to be as accurate as possible, I calculated the average word length (5.819) and went with a character history of 5 letters. This is the norm for now and can easily be omitted from the data if it becomes excessive
```
awk '{ total += length; count++ } END { if (count > 0) print total / count }' 1000_words.txt
```

## Sources
1. Generic news articles
@@ -36,5 +50,6 @@ awk '{ total += length; count++ } END { if (count > 0) print total / count }' 10
- https://www.gnu.org/licenses/old-licenses/gpl-2.0.en.html

5. Books
- https://ia902902.us.archive.org/19/items/diaryofawimpykidbookseriesbyjeffkinney_202004/Diary%20of%20a%20wimpy%20kid%20book02%20rodrick%20rules.pdf
- https://drive.google.com/file/d/1b1Etdxb1cNU3PvDBQnYh0bCAAfssMi8b/view
- https://dhspriory.org/kenny/PhilTexts/Camus/Myth%20of%20Sisyphus-.pdf
- https://www.matermiddlehigh.org/ourpages/auto/2012/11/16/50246772/Beloved.pdf
3
model/clear.sh
Executable file
@@ -0,0 +1,3 @@
#!/bin/bash

grep -o "[[:alpha:]]\{1,\}" "$1" | tr '[:upper:]' '[:lower:]' | grep -Eo '^([a-z]{2,}|a|i)$'
216339
model/data/all_cleaned_words_shuffled.txt
Normal file
File diff suppressed because it is too large
155467
model/data/all_words.txt
File diff suppressed because it is too large
63913
model/data/all_words.txt.bak
Normal file
File diff suppressed because it is too large
6173
model/data/beloved.txt
Normal file
File diff suppressed because it is too large
4941
model/data/myth_of_sisyphus.txt
Normal file
File diff suppressed because it is too large
339
model/logistic.ipynb
Normal file
@@ -0,0 +1,339 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# 1. Import and data processing"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Import all necessary libraries"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 1,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import numpy as np\n",
|
||||
"import torch\n",
|
||||
"import torch.nn as nn\n",
|
||||
"from torch.utils.data import DataLoader, TensorDataset, random_split"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Load the training data"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 2,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"data = np.load(\"./data.npy\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Define constants that describe the data and model"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 3,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"CONTEXT_SIZE = 10\n",
|
||||
"ALPHABET = list(\"abcdefghijklmnopqrstuvwxyz\")\n",
|
||||
"ALPHABET_SIZE = len(ALPHABET)\n",
|
||||
"TRAINING_DATA_SIZE = 0.9\n",
|
||||
"\n",
|
||||
"# +1 is for unknown characters\n",
|
||||
"VOCAB_SIZE = ALPHABET_SIZE + 1\n",
|
||||
"\n",
|
||||
"EMBEDDING_DIM = 10\n",
|
||||
"\n",
|
||||
"INPUT_SEQ_LEN = CONTEXT_SIZE\n",
|
||||
"OUTPUT_SIZE = VOCAB_SIZE\n",
|
||||
"\n",
|
||||
"BATCH_SIZE = 2048\n",
|
||||
"\n",
|
||||
"EPOCHS = 30\n",
|
||||
"LEARNING_RATE = 1e-3"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Process the data"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 4,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Input: embeddings of the previous 10 letters\n",
|
||||
"# shape: (num_samples, CONTEXT_SIZE)\n",
|
||||
"X = data[:, :CONTEXT_SIZE]\n",
|
||||
"\n",
|
||||
"# Target: current letter index\n",
|
||||
"# shape: (num_samples,)\n",
|
||||
"y = data[:, CONTEXT_SIZE]\n",
|
||||
"\n",
|
||||
"# Torch dataset (important: use long/int64 for indices)\n",
|
||||
"X_tensor = torch.tensor(X, dtype=torch.long) # for nn.Embedding\n",
|
||||
"y_tensor = torch.tensor(y, dtype=torch.long) # for classification target\n",
|
||||
"\n",
|
||||
"dataset = TensorDataset(X_tensor, y_tensor)\n",
|
||||
"\n",
|
||||
"train_len = int(TRAINING_DATA_SIZE * len(dataset))\n",
|
||||
"train_set, test_set = random_split(dataset, [train_len, len(dataset) - train_len])\n",
|
||||
"\n",
|
||||
"train_loader = DataLoader(train_set, batch_size=BATCH_SIZE, shuffle=True)\n",
|
||||
"test_loader = DataLoader(test_set, batch_size=BATCH_SIZE)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# 2. Model"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 5,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"class Logistic(nn.Module):\n",
|
||||
" def __init__(self, *, embedding_count: int, embedding_dimension_size: int, context_size: int, output_shape: int):\n",
|
||||
" super().__init__()\n",
|
||||
" self.embedding = nn.Embedding(num_embeddings=embedding_count, embedding_dim=embedding_dimension_size)\n",
|
||||
" self.linear = nn.Linear(context_size * embedding_dimension_size, output_shape)\n",
|
||||
"\n",
|
||||
" def forward(self, x):\n",
|
||||
" embedded = self.embedding(x) # (BATCH_SIZE, CONTEXT_SIZE, EMBEDDING_DIM)\n",
|
||||
" flattened = embedded.view(x.size(0), -1) # (BATCH_SIZE, CONTEXT_SIZE * EMBEDDING_DIM)\n",
|
||||
" return self.linear(flattened) # (BATCH_SIZE, OUTPUT_SIZE)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 6,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Using device: cpu\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"/home/thastertyn/Code/Skola/4-rocnik/programove-vybaveni/omega/.venv/lib/python3.12/site-packages/torch/cuda/__init__.py:129: UserWarning: CUDA initialization: CUDA unknown error - this may be due to an incorrectly set up environment, e.g. changing env variable CUDA_VISIBLE_DEVICES after program start. Setting the available devices to be zero. (Triggered internally at /pytorch/c10/cuda/CUDAFunctions.cpp:109.)\n",
|
||||
" return torch._C._cuda_getDeviceCount() > 0\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
|
||||
"print(f\"Using device: {device}\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# 3. Training"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Create fresh instance of the model"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 7,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"model = Logistic(\n",
|
||||
" embedding_count=VOCAB_SIZE, # e.g., 27 for a–z + unknown\n",
|
||||
" embedding_dimension_size=EMBEDDING_DIM, # e.g., 10\n",
|
||||
" context_size=CONTEXT_SIZE, # e.g., 10\n",
|
||||
" output_shape=OUTPUT_SIZE # e.g., 27 (next character)\n",
|
||||
").to(device)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 8,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"criterion = nn.CrossEntropyLoss()\n",
|
||||
"optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 9,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"[Epoch 1] - Loss: 2.5968 | Accuracy: 23.06%\n",
|
||||
"[Epoch 2] - Loss: 2.3218 | Accuracy: 30.02%\n",
|
||||
"[Epoch 3] - Loss: 2.2600 | Accuracy: 31.25%\n",
|
||||
"[Epoch 4] - Loss: 2.2325 | Accuracy: 31.55%\n",
|
||||
"[Epoch 5] - Loss: 2.2171 | Accuracy: 31.75%\n",
|
||||
"[Epoch 6] - Loss: 2.2076 | Accuracy: 31.98%\n",
|
||||
"[Epoch 7] - Loss: 2.2006 | Accuracy: 32.22%\n",
|
||||
"[Epoch 8] - Loss: 2.1962 | Accuracy: 32.36%\n",
|
||||
"[Epoch 9] - Loss: 2.1925 | Accuracy: 32.42%\n",
|
||||
"[Epoch 10] - Loss: 2.1900 | Accuracy: 32.48%\n",
|
||||
"[Epoch 11] - Loss: 2.1876 | Accuracy: 32.54%\n",
|
||||
"[Epoch 12] - Loss: 2.1859 | Accuracy: 32.64%\n",
|
||||
"[Epoch 13] - Loss: 2.1847 | Accuracy: 32.65%\n",
|
||||
"[Epoch 14] - Loss: 2.1833 | Accuracy: 32.76%\n",
|
||||
"[Epoch 15] - Loss: 2.1821 | Accuracy: 32.75%\n",
|
||||
"[Epoch 16] - Loss: 2.1813 | Accuracy: 32.74%\n",
|
||||
"[Epoch 17] - Loss: 2.1806 | Accuracy: 32.84%\n",
|
||||
"[Epoch 18] - Loss: 2.1799 | Accuracy: 32.81%\n",
|
||||
"[Epoch 19] - Loss: 2.1792 | Accuracy: 32.80%\n",
|
||||
"[Epoch 20] - Loss: 2.1786 | Accuracy: 32.81%\n",
|
||||
"[Epoch 21] - Loss: 2.1780 | Accuracy: 32.77%\n",
|
||||
"[Epoch 22] - Loss: 2.1776 | Accuracy: 32.85%\n",
|
||||
"[Epoch 23] - Loss: 2.1770 | Accuracy: 32.81%\n",
|
||||
"[Epoch 24] - Loss: 2.1767 | Accuracy: 32.81%\n",
|
||||
"[Epoch 25] - Loss: 2.1764 | Accuracy: 32.81%\n",
|
||||
"[Epoch 26] - Loss: 2.1757 | Accuracy: 32.80%\n",
|
||||
"[Epoch 27] - Loss: 2.1755 | Accuracy: 32.81%\n",
|
||||
"[Epoch 28] - Loss: 2.1751 | Accuracy: 32.79%\n",
|
||||
"[Epoch 29] - Loss: 2.1748 | Accuracy: 32.82%\n",
|
||||
"[Epoch 30] - Loss: 2.1744 | Accuracy: 32.80%\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"for epoch in range(EPOCHS):\n",
|
||||
" model.train()\n",
|
||||
" total_loss = 0\n",
|
||||
" correct = 0\n",
|
||||
" total = 0\n",
|
||||
"\n",
|
||||
" for batch_X, batch_y in train_loader:\n",
|
||||
" batch_X, batch_y = batch_X.to(device), batch_y.to(device)\n",
|
||||
"\n",
|
||||
" optimizer.zero_grad()\n",
|
||||
" logits = model(batch_X) # shape: (BATCH_SIZE, OUTPUT_SIZE)\n",
|
||||
" loss = criterion(logits, batch_y)\n",
|
||||
" loss.backward()\n",
|
||||
" optimizer.step()\n",
|
||||
"\n",
|
||||
" total_loss += loss.item() * batch_X.size(0)\n",
|
||||
"\n",
|
||||
" # Compute accuracy\n",
|
||||
" preds = torch.argmax(logits, dim=1)\n",
|
||||
" correct += (preds == batch_y).sum().item()\n",
|
||||
" total += batch_X.size(0)\n",
|
||||
"\n",
|
||||
" avg_loss = total_loss / total\n",
|
||||
" accuracy = correct / total * 100\n",
|
||||
" print(f\"[Epoch {epoch+1}] - Loss: {avg_loss:.4f} | Accuracy: {accuracy:.2f}%\")\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 10,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Top 1 prediction accuracy: 32.45%\n",
|
||||
"Top 3 prediction accuracy: 58.55%\n",
|
||||
"Top 5 prediction accuracy: 72.66%\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"model.eval()\n",
|
||||
"correct_top1 = 0\n",
|
||||
"correct_top3 = 0\n",
|
||||
"correct_top5 = 0\n",
|
||||
"total = 0\n",
|
||||
"\n",
|
||||
"with torch.no_grad():\n",
|
||||
" for batch_X, batch_y in test_loader:\n",
|
||||
" batch_X, batch_y = batch_X.to(device), batch_y.to(device)\n",
|
||||
" outputs = model(batch_X)\n",
|
||||
"\n",
|
||||
" _, top_preds = outputs.topk(5, dim=1)\n",
|
||||
"\n",
|
||||
" for true, top5 in zip(batch_y, top_preds):\n",
|
||||
" total += 1\n",
|
||||
" if true == top5[0]:\n",
|
||||
" correct_top1 += 1\n",
|
||||
" if true in top5[:3]:\n",
|
||||
" correct_top3 += 1\n",
|
||||
" if true in top5:\n",
|
||||
" correct_top5 += 1\n",
|
||||
"\n",
|
||||
"top1_acc = correct_top1 / total\n",
|
||||
"top3_acc = correct_top3 / total\n",
|
||||
"top5_acc = correct_top5 / total\n",
|
||||
"\n",
|
||||
"print(f\"Top 1 prediction accuracy: {(top1_acc * 100):.2f}%\")\n",
|
||||
"print(f\"Top 3 prediction accuracy: {(top3_acc * 100):.2f}%\")\n",
|
||||
"print(f\"Top 5 prediction accuracy: {(top5_acc * 100):.2f}%\")"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": ".venv",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.12.3"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 2
|
||||
}
|
601
model/mlp.ipynb
Normal file
File diff suppressed because one or more lines are too long
Binary file not shown.
@@ -1,475 +0,0 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# Omega\n",
|
||||
"Prediction of next key to be pressed using Multilayer Perceptron"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## 1. Import and load data"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### Import all required modules"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 34,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import numpy as np\n",
|
||||
"import torch\n",
|
||||
"import torch.nn as nn\n",
|
||||
"from torch.utils.data import DataLoader, TensorDataset, random_split"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### Load data"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 35,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"data = np.load(\"./data.npy\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### Define contstants describing the dataset and other useful information"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 36,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"CONTEXT_SIZE = 10\n",
|
||||
"ALPHABET = list(\"abcdefghijklmnopqrstuvwxyz\")\n",
|
||||
"ALPHABET_SIZE = len(ALPHABET)\n",
|
||||
"TRAINING_DATA_SIZE = 0.9\n",
|
||||
"\n",
|
||||
"VOCAB_SIZE = ALPHABET_SIZE + 1 # 26 letters + 1 for unknown\n",
|
||||
"EMBEDDING_DIM = 16\n",
|
||||
"\n",
|
||||
"INPUT_SEQ_LEN = CONTEXT_SIZE\n",
|
||||
"OUTPUT_SIZE = VOCAB_SIZE"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# Define and split data"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Define input and output columns"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 37,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"X = data[:, :CONTEXT_SIZE] # shape: (num_samples, CONTEXT_SIZE)\n",
|
||||
"\n",
|
||||
"# Target: current letter index\n",
|
||||
"y = data[:, CONTEXT_SIZE] # shape: (num_samples,)\n",
|
||||
"\n",
|
||||
"# Torch dataset (important: use long/int64 for indices)\n",
|
||||
"X_tensor = torch.tensor(X, dtype=torch.long) # for nn.Embedding\n",
|
||||
"y_tensor = torch.tensor(y, dtype=torch.long) # for classification target\n",
|
||||
"\n",
|
||||
"dataset = TensorDataset(X_tensor, y_tensor)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 38,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"train_len = int(TRAINING_DATA_SIZE * len(dataset))\n",
|
||||
"train_set, test_set = random_split(dataset, [train_len, len(dataset) - train_len])"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 39,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"train_loader = DataLoader(train_set, batch_size=1024, shuffle=True)\n",
|
||||
"test_loader = DataLoader(test_set, batch_size=1024)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 40,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"learning_rates = [1e-4, 5e-4, 1e-3, 5e-3, 1e-2, 5e-2]\n",
|
||||
"activation_layers = [nn.ReLU, nn.GELU]"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Model and training"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"To find the best model for MLP, combinations of hyperparams are defined. \n",
|
||||
"This includes **activation layers** and **learning rates**"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 32,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from itertools import product\n",
|
||||
"all_activation_combinations = list(product(activation_layers, repeat=len(activation_layers)))"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": []
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 66,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"class MLP(nn.Module):\n",
|
||||
" def __init__(self, activation_layers: list):\n",
|
||||
" super().__init__()\n",
|
||||
" self.net = nn.Sequential(\n",
|
||||
" nn.Embedding(num_embeddings=VOCAB_SIZE, embedding_dim=EMBEDDING_DIM),\n",
|
||||
" nn.Flatten(),\n",
|
||||
" nn.Linear(CONTEXT_SIZE * EMBEDDING_DIM, 256),\n",
|
||||
" activation_layers[0](),\n",
|
||||
" nn.Linear(256, 128),\n",
|
||||
" activation_layers[1](),\n",
|
||||
" nn.Linear(128, OUTPUT_SIZE)\n",
|
||||
" )\n",
|
||||
"\n",
|
||||
" def forward(self, x):\n",
|
||||
" return self.net(x)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 42,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Using device: cuda\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
|
||||
"print(f\"Using device: {device}\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 55,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# model = MLP().to(device)\n",
|
||||
"model = None"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Test all the activation_layer combinations"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 65,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"\n",
|
||||
"criterion = nn.CrossEntropyLoss()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 71,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def train_model(model, optimizer):\n",
|
||||
" for epoch in range(30):\n",
|
||||
" model.train()\n",
|
||||
" total_loss = 0\n",
|
||||
" for batch_X, batch_y in train_loader:\n",
|
||||
" batch_X, batch_y = batch_X.to(device), batch_y.to(device)\n",
|
||||
" optimizer.zero_grad()\n",
|
||||
" output = model(batch_X)\n",
|
||||
" loss = criterion(output, batch_y)\n",
|
||||
" loss.backward()\n",
|
||||
" optimizer.step()\n",
|
||||
" total_loss += loss.item()\n",
|
||||
" # print(f\"Epoch {epoch+1}, Loss: {total_loss:.4f}\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# Testing model"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 70,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def test_model(model) -> tuple[float]:\n",
|
||||
" model.eval()\n",
|
||||
" correct_top1 = 0\n",
|
||||
" correct_top3 = 0\n",
|
||||
" correct_top5 = 0\n",
|
||||
" total = 0\n",
|
||||
"\n",
|
||||
" with torch.no_grad():\n",
|
||||
" for batch_X, batch_y in test_loader:\n",
|
||||
" batch_X, batch_y = batch_X.to(device), batch_y.to(device)\n",
|
||||
" outputs = model(batch_X)\n",
|
||||
"\n",
|
||||
" _, top_preds = outputs.topk(5, dim=1)\n",
|
||||
"\n",
|
||||
" for true, top5 in zip(batch_y, top_preds):\n",
|
||||
" total += 1\n",
|
||||
" if true == top5[0]:\n",
|
||||
" correct_top1 += 1\n",
|
||||
" if true in top5[:3]:\n",
|
||||
" correct_top3 += 1\n",
|
||||
" if true in top5:\n",
|
||||
" correct_top5 += 1\n",
|
||||
"\n",
|
||||
" top1_acc = correct_top1 / total\n",
|
||||
" top3_acc = correct_top3 / total\n",
|
||||
" top5_acc = correct_top5 / total\n",
|
||||
"\n",
|
||||
" return (top1_acc, top3_acc, top5_acc)\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 72,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Model with activation layers (<class 'torch.nn.modules.activation.ReLU'>, <class 'torch.nn.modules.activation.ReLU'>) and learning rate 0.0001 had success of (0.44952931636286714, 0.6824383880407573, 0.788915135916511)\n",
|
||||
"Model with activation layers (<class 'torch.nn.modules.activation.ReLU'>, <class 'torch.nn.modules.activation.ReLU'>) and learning rate 0.0005 had success of (0.5080210132919649, 0.7299298381694461, 0.8241018227973064)\n",
|
||||
"Model with activation layers (<class 'torch.nn.modules.activation.ReLU'>, <class 'torch.nn.modules.activation.ReLU'>) and learning rate 0.001 had success of (0.5215950357860593, 0.7354299615696506, 0.826111483270458)\n",
|
||||
"Model with activation layers (<class 'torch.nn.modules.activation.ReLU'>, <class 'torch.nn.modules.activation.ReLU'>) and learning rate 0.005 had success of (0.5230758382399605, 0.7383563092761697, 0.8298840038077777)\n",
|
||||
"Model with activation layers (<class 'torch.nn.modules.activation.ReLU'>, <class 'torch.nn.modules.activation.ReLU'>) and learning rate 0.01 had success of (0.5206783485526919, 0.7364171632055847, 0.8278390861333428)\n",
|
||||
"Model with activation layers (<class 'torch.nn.modules.activation.ReLU'>, <class 'torch.nn.modules.activation.ReLU'>) and learning rate 0.05 had success of (0.12682015301625357, 0.29884003807777737, 0.45160949123858546)\n",
|
||||
"Model with activation layers (<class 'torch.nn.modules.activation.ReLU'>, <class 'torch.nn.modules.activation.GELU'>) and learning rate 0.0001 had success of (0.44251313330747805, 0.6765504354264359, 0.7860240454112752)\n",
|
||||
"Model with activation layers (<class 'torch.nn.modules.activation.ReLU'>, <class 'torch.nn.modules.activation.GELU'>) and learning rate 0.0005 had success of (0.5103127313753835, 0.7293304657476289, 0.8237492507844727)\n",
|
||||
"Model with activation layers (<class 'torch.nn.modules.activation.ReLU'>, <class 'torch.nn.modules.activation.GELU'>) and learning rate 0.001 had success of (0.5211366921693756, 0.7379332228607693, 0.8288968021718436)\n",
|
||||
"Model with activation layers (<class 'torch.nn.modules.activation.ReLU'>, <class 'torch.nn.modules.activation.GELU'>) and learning rate 0.005 had success of (0.5246271550964284, 0.739942883333921, 0.8305538906321617)\n",
|
||||
"Model with activation layers (<class 'torch.nn.modules.activation.ReLU'>, <class 'torch.nn.modules.activation.GELU'>) and learning rate 0.01 had success of (0.5214892641822092, 0.7391319677044036, 0.8297077178013609)\n",
|
||||
"Model with activation layers (<class 'torch.nn.modules.activation.ReLU'>, <class 'torch.nn.modules.activation.GELU'>) and learning rate 0.05 had success of (0.1655325600253852, 0.3544759017029228, 0.495469449635088)\n",
|
||||
"Model with activation layers (<class 'torch.nn.modules.activation.GELU'>, <class 'torch.nn.modules.activation.ReLU'>) and learning rate 0.0001 had success of (0.44706131227303175, 0.6806755279765893, 0.7906427387793957)\n",
|
||||
"Model with activation layers (<class 'torch.nn.modules.activation.GELU'>, <class 'torch.nn.modules.activation.ReLU'>) and learning rate 0.0005 had success of (0.5120050770369848, 0.7312343546169305, 0.8229735923562388)\n",
|
||||
"Model with activation layers (<class 'torch.nn.modules.activation.GELU'>, <class 'torch.nn.modules.activation.ReLU'>) and learning rate 0.001 had success of (0.5179282868525896, 0.7381800232697528, 0.8289673165744104)\n",
|
||||
"Model with activation layers (<class 'torch.nn.modules.activation.GELU'>, <class 'torch.nn.modules.activation.ReLU'>) and learning rate 0.005 had success of (0.5234636674540775, 0.7421640870147728, 0.8307654338398618)\n",
|
||||
"Model with activation layers (<class 'torch.nn.modules.activation.GELU'>, <class 'torch.nn.modules.activation.ReLU'>) and learning rate 0.01 had success of (0.5197264041180412, 0.7384268236787364, 0.8286500017628601)\n",
|
||||
"Model with activation layers (<class 'torch.nn.modules.activation.GELU'>, <class 'torch.nn.modules.activation.ReLU'>) and learning rate 0.05 had success of (0.12551563656876918, 0.29757077883157634, 0.45034023199238443)\n",
|
||||
"Model with activation layers (<class 'torch.nn.modules.activation.GELU'>, <class 'torch.nn.modules.activation.GELU'>) and learning rate 0.0001 had success of (0.4493530303564503, 0.683284560871558, 0.7907837675845292)\n",
|
||||
"Model with activation layers (<class 'torch.nn.modules.activation.GELU'>, <class 'torch.nn.modules.activation.GELU'>) and learning rate 0.0005 had success of (0.5151077107499207, 0.733808130310616, 0.8255121108486408)\n",
|
||||
"Model with activation layers (<class 'torch.nn.modules.activation.GELU'>, <class 'torch.nn.modules.activation.GELU'>) and learning rate 0.001 had success of (0.5195148609103409, 0.7389204244967035, 0.8294961745936608)\n",
|
||||
"Model with activation layers (<class 'torch.nn.modules.activation.GELU'>, <class 'torch.nn.modules.activation.GELU'>) and learning rate 0.005 had success of (0.5214892641822092, 0.7401896837429045, 0.8302365758206114)\n",
|
||||
"Model with activation layers (<class 'torch.nn.modules.activation.GELU'>, <class 'torch.nn.modules.activation.GELU'>) and learning rate 0.01 had success of (0.5198674329231746, 0.7398371117300708, 0.8258294256601911)\n",
|
||||
"Model with activation layers (<class 'torch.nn.modules.activation.GELU'>, <class 'torch.nn.modules.activation.GELU'>) and learning rate 0.05 had success of (0.3762648520960406, 0.6283538412720798, 0.7500617001022459)\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"for activation_layer_combination in all_activation_combinations:\n",
|
||||
" for learning_rate in learning_rates:\n",
|
||||
" model = MLP(activation_layer_combination).to(device)\n",
|
||||
" optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)\n",
|
||||
" train_model(model, optimizer)\n",
|
||||
" results = test_model(model)\n",
|
||||
" print(\"Model with activation layers\", activation_layer_combination, \"and learning rate\", learning_rate, \"had success of\", results)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 13,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Reuse same alphabet + mapping\n",
|
||||
"alphabet = list(\"abcdefghijklmnopqrstuvwxyz\")\n",
|
||||
"char_to_idx = {ch: idx for idx, ch in enumerate(alphabet)}\n",
|
||||
"PAD_IDX = len(alphabet) # index 26 for OOV/padding\n",
|
||||
"VOCAB_SIZE = len(alphabet) + 1 # 27 total (a–z + padding)\n",
|
||||
"CONTEXT_SIZE = 10\n",
|
||||
"\n",
|
||||
"idx_to_char = {idx: ch for ch, idx in char_to_idx.items()}\n",
|
||||
"idx_to_char[PAD_IDX] = \"_\" # for readability\n",
|
||||
"\n",
|
||||
"def preprocess_input(context: str) -> torch.Tensor:\n",
|
||||
" context = context.lower()\n",
|
||||
" padded = context.rjust(CONTEXT_SIZE, \"_\") # pad with underscores (or any 1-char symbol)\n",
|
||||
"\n",
|
||||
" indices = []\n",
|
||||
" for ch in padded[-CONTEXT_SIZE:]:\n",
|
||||
" idx = char_to_idx.get(ch, PAD_IDX) # if '_' or unknown → PAD_IDX (26)\n",
|
||||
" indices.append(idx)\n",
|
||||
"\n",
|
||||
" return torch.tensor(indices, dtype=torch.long).unsqueeze(0).to(device)\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"def predict_next_chars(model, context: str, top_k=5):\n",
|
||||
" model.eval()\n",
|
||||
" input_tensor = preprocess_input(context)\n",
|
||||
" with torch.no_grad():\n",
|
||||
" logits = model(input_tensor)\n",
|
||||
" probs = torch.softmax(logits, dim=-1)\n",
|
||||
" top_probs, top_indices = probs.topk(top_k, dim=-1)\n",
|
||||
"\n",
|
||||
" predictions = [(idx_to_char[idx.item()], top_probs[0, i].item()) for i, idx in enumerate(top_indices[0])]\n",
|
||||
" return predictions\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 24,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"I: 89.74 %\n",
|
||||
"N: 4.42 %\n",
|
||||
"Y: 1.88 %\n",
|
||||
"M: 1.51 %\n",
|
||||
"B: 0.90 %\n",
|
||||
"E: 0.65 %\n",
|
||||
"G: 0.21 %\n",
|
||||
"R: 0.16 %\n",
|
||||
"L: 0.15 %\n",
|
||||
"O: 0.13 %\n",
|
||||
"C: 0.09 %\n",
|
||||
"U: 0.08 %\n",
|
||||
"A: 0.05 %\n",
|
||||
"V: 0.02 %\n",
|
||||
"S: 0.01 %\n",
|
||||
"F: 0.00 %\n",
|
||||
"H: 0.00 %\n",
|
||||
"T: 0.00 %\n",
|
||||
"W: 0.00 %\n",
|
||||
"P: 0.00 %\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"preds = predict_next_chars(model, \"susta\", top_k=20)\n",
|
||||
"for char, prob in preds:\n",
|
||||
" print(f\"{char.upper()}: {(prob * 100):.2f} %\")\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Model saving"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"torch.save(model, \"mlp_full_model.pth\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 38,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"torch.save(model.state_dict(), \"mlp_weights.pth\")"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": ".venv",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.12.3"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 2
|
||||
}
|
5817
model/out.txt
File diff suppressed because it is too large
Binary file not shown.
1024
model/training_results/raw_2.txt
Normal file
File diff suppressed because it is too large
3782
model/training_results/silu.json
Normal file
File diff suppressed because it is too large
@@ -3,22 +3,13 @@
from typing import Literal, List, Dict
import numpy as np

INPUT_FILE: str = "./data/all_cleaned_words.txt"
INPUT_FILE: str = "./data/all_cleaned_words_shuffled.txt"
OUTPUT_FILE: str = "./data.npy"

alphabet: List[str] = list("abcdefghijklmnopqrstuvwxyz")
vowels: set[str] = set("aeiouy")

char_to_index: Dict[str, int] = {ch: idx for idx, ch in enumerate(alphabet)}
default_index: int = len(alphabet)  # Out-of-vocabulary token (e.g., for "")


def get_prev_type(c: str) -> Literal[0, 1, 2]:
    if c in vowels:
        return 1
    elif c in alphabet:
        return 2
    return 0
default_index: int = len(alphabet)  # 26 + 1 -> Unknown character


def encode_letter(c: str) -> int:
@@ -43,12 +34,6 @@ def build_dataset(input_path: str) -> np.ndarray:
        # Append current char index (target for classification)
        features.append(encode_letter(curr_char))

        # Word position features
        # is_start: int = 1 if i == 0 else 0
        prev1: str = prev_chars[-1]
        prev_type: int = get_prev_type(prev1)
        # word_length: int = i + 1

        # features.extend([prev_type])
        all_features.append(features)
@ -3,14 +3,16 @@ name = "app"
|
||||
version = "0.1.0"
|
||||
description = "Add your description here"
|
||||
readme = "README.md"
|
||||
requires-python = ">=3.12"
|
||||
requires-python = ">=3.8"
|
||||
dependencies = [
|
||||
"onnxruntime>=1.21.0",
|
||||
"pydantic-settings>=2.8.1",
|
||||
"pyside6>=6.8.3",
|
||||
]
|
||||
|
||||
[project.optional-dependencies]
|
||||
train = [
|
||||
"onnx>=1.17.0",
|
||||
"torch>=2.6.0",
|
||||
"ipykernel>=6.29.5",
|
||||
"numpy>=2.2.4"
|
||||
|
102
uv.lock
generated
@ -16,6 +16,8 @@ name = "app"
|
||||
version = "0.1.0"
|
||||
source = { virtual = "." }
|
||||
dependencies = [
|
||||
{ name = "onnx" },
|
||||
{ name = "onnxruntime" },
|
||||
{ name = "pydantic-settings" },
|
||||
{ name = "pyside6" },
|
||||
]
|
||||
@@ -38,6 +40,8 @@ requires-dist = [
|
||||
{ name = "matplotlib", marker = "extra == 'visualization'", specifier = ">=3.10.1" },
|
||||
{ name = "numpy", marker = "extra == 'train'", specifier = ">=2.2.4" },
|
||||
{ name = "numpy", marker = "extra == 'visualization'", specifier = ">=2.2.4" },
|
||||
{ name = "onnx", specifier = ">=1.17.0" },
|
||||
{ name = "onnxruntime", specifier = ">=1.21.0" },
|
||||
{ name = "pandas", marker = "extra == 'visualization'", specifier = ">=2.2.3" },
|
||||
{ name = "pydantic-settings", specifier = ">=2.8.1" },
|
||||
{ name = "pyside6", specifier = ">=6.8.3" },
|
||||
@@ -105,6 +109,18 @@ wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/d1/d6/3965ed04c63042e047cb6a3e6ed1a63a35087b6a609aa3a15ed8ac56c221/colorama-0.4.6-py2.py3-none-any.whl", hash = "sha256:4f1d9991f5acc0ca119f9d443620b77f9d6b33703e51011c16baf57afb285fc6", size = 25335 },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "coloredlogs"
|
||||
version = "15.0.1"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "humanfriendly" },
|
||||
]
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/cc/c7/eed8f27100517e8c0e6b923d5f0845d0cb99763da6fdee00478f91db7325/coloredlogs-15.0.1.tar.gz", hash = "sha256:7c991aa71a4577af2f82600d8f8f3a89f936baeaf9b50a9c197da014e5bf16b0", size = 278520 }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/a7/06/3d6badcf13db419e25b07041d9c7b4a2c331d3f4e7134445ec5df57714cd/coloredlogs-15.0.1-py2.py3-none-any.whl", hash = "sha256:612ee75c546f53e92e70049c9dbfcc18c935a2b9a53b66085ce9ef6a6e5c0934", size = 46018 },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "comm"
|
||||
version = "0.2.2"
|
||||
@@ -211,6 +227,15 @@ wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/4d/36/2a115987e2d8c300a974597416d9de88f2444426de9571f4b59b2cca3acc/filelock-3.18.0-py3-none-any.whl", hash = "sha256:c401f4f8377c4464e6db25fff06205fd89bdd83b65eb0488ed1b160f780e21de", size = 16215 },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "flatbuffers"
|
||||
version = "25.2.10"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/e4/30/eb5dce7994fc71a2f685d98ec33cc660c0a5887db5610137e60d8cbc4489/flatbuffers-25.2.10.tar.gz", hash = "sha256:97e451377a41262f8d9bd4295cc836133415cc03d8cb966410a4af92eb00d26e", size = 22170 }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/b8/25/155f9f080d5e4bc0082edfda032ea2bc2b8fab3f4d25d46c1e9dd22a1a89/flatbuffers-25.2.10-py2.py3-none-any.whl", hash = "sha256:ebba5f4d5ea615af3f7fd70fc310636fbb2bbd1f566ac0a23d98dd412de50051", size = 30953 },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "fonttools"
|
||||
version = "4.56.0"
|
||||
@@ -245,6 +270,18 @@ wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/56/53/eb690efa8513166adef3e0669afd31e95ffde69fb3c52ec2ac7223ed6018/fsspec-2025.3.0-py3-none-any.whl", hash = "sha256:efb87af3efa9103f94ca91a7f8cb7a4df91af9f74fc106c9c7ea0efd7277c1b3", size = 193615 },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "humanfriendly"
|
||||
version = "10.0"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "pyreadline3", marker = "sys_platform == 'win32'" },
|
||||
]
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/cc/3f/2c29224acb2e2df4d2046e4c73ee2662023c58ff5b113c4c1adac0886c43/humanfriendly-10.0.tar.gz", hash = "sha256:6b0b831ce8f15f7300721aa49829fc4e83921a9a301cc7f606be6686a2288ddc", size = 360702 }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/f0/0f/310fb31e39e2d734ccaa2c0fb981ee41f7bd5056ce9bc29b2248bd569169/humanfriendly-10.0-py2.py3-none-any.whl", hash = "sha256:1697e1a8a8f550fd43c2865cd84542fc175a61dcb779b6fee18cf6b6ccba1477", size = 86794 },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "ipykernel"
|
||||
version = "6.29.5"
|
||||
@@ -677,6 +714,48 @@ wheels = [
    { url = "https://files.pythonhosted.org/packages/87/20/199b8713428322a2f22b722c62b8cc278cc53dffa9705d744484b5035ee9/nvidia_nvtx_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl", hash = "sha256:781e950d9b9f60d8241ccea575b32f5105a5baf4c2351cab5256a24869f12a1a", size = 99144 },
]

[[package]]
name = "onnx"
version = "1.17.0"
source = { registry = "https://pypi.org/simple" }
dependencies = [
    { name = "numpy" },
    { name = "protobuf" },
]
sdist = { url = "https://files.pythonhosted.org/packages/9a/54/0e385c26bf230d223810a9c7d06628d954008a5e5e4b73ee26ef02327282/onnx-1.17.0.tar.gz", hash = "sha256:48ca1a91ff73c1d5e3ea2eef20ae5d0e709bb8a2355ed798ffc2169753013fd3", size = 12165120 }
wheels = [
    { url = "https://files.pythonhosted.org/packages/b4/dd/c416a11a28847fafb0db1bf43381979a0f522eb9107b831058fde012dd56/onnx-1.17.0-cp312-cp312-macosx_12_0_universal2.whl", hash = "sha256:0e906e6a83437de05f8139ea7eaf366bf287f44ae5cc44b2850a30e296421f2f", size = 16651271 },
    { url = "https://files.pythonhosted.org/packages/f0/6c/f040652277f514ecd81b7251841f96caa5538365af7df07f86c6018cda2b/onnx-1.17.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:3d955ba2939878a520a97614bcf2e79c1df71b29203e8ced478fa78c9a9c63c2", size = 15907522 },
    { url = "https://files.pythonhosted.org/packages/3d/7c/67f4952d1b56b3f74a154b97d0dd0630d525923b354db117d04823b8b49b/onnx-1.17.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:4f3fb5cc4e2898ac5312a7dc03a65133dd2abf9a5e520e69afb880a7251ec97a", size = 16046307 },
    { url = "https://files.pythonhosted.org/packages/ae/20/6da11042d2ab870dfb4ce4a6b52354d7651b6b4112038b6d2229ab9904c4/onnx-1.17.0-cp312-cp312-win32.whl", hash = "sha256:317870fca3349d19325a4b7d1b5628f6de3811e9710b1e3665c68b073d0e68d7", size = 14424235 },
    { url = "https://files.pythonhosted.org/packages/35/55/c4d11bee1fdb0c4bd84b4e3562ff811a19b63266816870ae1f95567aa6e1/onnx-1.17.0-cp312-cp312-win_amd64.whl", hash = "sha256:659b8232d627a5460d74fd3c96947ae83db6d03f035ac633e20cd69cfa029227", size = 14530453 },
]

[[package]]
name = "onnxruntime"
version = "1.21.0"
source = { registry = "https://pypi.org/simple" }
dependencies = [
    { name = "coloredlogs" },
    { name = "flatbuffers" },
    { name = "numpy" },
    { name = "packaging" },
    { name = "protobuf" },
    { name = "sympy" },
]
wheels = [
    { url = "https://files.pythonhosted.org/packages/ff/21/593c9bc56002a6d1ea7c2236f4a648e081ec37c8d51db2383a9e83a63325/onnxruntime-1.21.0-cp312-cp312-macosx_13_0_universal2.whl", hash = "sha256:893d67c68ca9e7a58202fa8d96061ed86a5815b0925b5a97aef27b8ba246a20b", size = 33658780 },
    { url = "https://files.pythonhosted.org/packages/4a/b4/33ec675a8ac150478091262824413e5d4acc359e029af87f9152e7c1c092/onnxruntime-1.21.0-cp312-cp312-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:37b7445c920a96271a8dfa16855e258dc5599235b41c7bbde0d262d55bcc105f", size = 14159975 },
    { url = "https://files.pythonhosted.org/packages/8b/08/eead6895ed83b56711ca6c0d31d82f109401b9937558b425509e497d6fb4/onnxruntime-1.21.0-cp312-cp312-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:9a04aafb802c1e5573ba4552f8babcb5021b041eb4cfa802c9b7644ca3510eca", size = 16019285 },
    { url = "https://files.pythonhosted.org/packages/77/39/e83d56e3c215713b5263cb4d4f0c69e3964bba11634233d8ae04fc7e6bf3/onnxruntime-1.21.0-cp312-cp312-win_amd64.whl", hash = "sha256:7f801318476cd7003d636a5b392f7a37c08b6c8d2f829773f3c3887029e03f32", size = 11760975 },
    { url = "https://files.pythonhosted.org/packages/f2/25/93f65617b06c741a58eeac9e373c99df443b02a774f4cb6511889757c0da/onnxruntime-1.21.0-cp313-cp313-macosx_13_0_universal2.whl", hash = "sha256:85718cbde1c2912d3a03e3b3dc181b1480258a229c32378408cace7c450f7f23", size = 33659581 },
    { url = "https://files.pythonhosted.org/packages/f9/03/6b6829ee8344490ab5197f39a6824499ed097d1fc8c85b1f91c0e6767819/onnxruntime-1.21.0-cp313-cp313-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:94dff3a61538f3b7b0ea9a06bc99e1410e90509c76e3a746f039e417802a12ae", size = 14160534 },
    { url = "https://files.pythonhosted.org/packages/a6/81/e280ddf05f83ad5e0d066ef08e31515b17bd50bb52ef2ea713d9e455e67a/onnxruntime-1.21.0-cp313-cp313-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:c1e704b0eda5f2bbbe84182437315eaec89a450b08854b5a7762c85d04a28a0a", size = 16018947 },
    { url = "https://files.pythonhosted.org/packages/d3/ea/011dfc2536e46e2ea984d2c0256dc585ebb1352366dffdd98764f1f44ee4/onnxruntime-1.21.0-cp313-cp313-win_amd64.whl", hash = "sha256:19b630c6a8956ef97fb7c94948b17691167aa1aaf07b5f214fa66c3e4136c108", size = 11760731 },
    { url = "https://files.pythonhosted.org/packages/47/6b/a00f31322e91c610c7825377ef0cad884483c30d1370b896d57e7032e912/onnxruntime-1.21.0-cp313-cp313t-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:3995c4a2d81719623c58697b9510f8de9fa42a1da6b4474052797b0d712324fe", size = 14172215 },
    { url = "https://files.pythonhosted.org/packages/58/4b/98214f13ac1cd675dfc2713ba47b5722f55ce4fba526d2b2826f2682a42e/onnxruntime-1.21.0-cp313-cp313t-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:36b18b8f39c0f84e783902112a0dd3c102466897f96d73bb83f6a6bff283a423", size = 15990612 },
]
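`onnxruntime` pulls in `coloredlogs`, `flatbuffers`, `protobuf`, and `sympy`, but it is the only inference engine the application needs at runtime; `torch` never leaves the `train` extra. A minimal smoke test, assuming the illustrative `model.onnx` and input shape from the export sketch earlier:

```python
# Inference-side smoke test; file name and shape are assumptions
# carried over from the hypothetical export example above.
import numpy as np
import onnxruntime

session = onnxruntime.InferenceSession("model.onnx",
                                       providers=["CPUExecutionProvider"])
input_name = session.get_inputs()[0].name
x = np.zeros((1, 10, 27), dtype=np.float32)
(logits,) = session.run(None, {input_name: x})
print(logits.shape)  # expected (1, 27) for the stand-in model
```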

[[package]]
name = "packaging"
version = "24.2"
@@ -800,6 +879,20 @@ wheels = [
    { url = "https://files.pythonhosted.org/packages/e4/ea/d836f008d33151c7a1f62caf3d8dd782e4d15f6a43897f64480c2b8de2ad/prompt_toolkit-3.0.50-py3-none-any.whl", hash = "sha256:9b6427eb19e479d98acff65196a307c555eb567989e6d88ebbb1b509d9779198", size = 387816 },
]

[[package]]
name = "protobuf"
version = "6.30.2"
source = { registry = "https://pypi.org/simple" }
sdist = { url = "https://files.pythonhosted.org/packages/c8/8c/cf2ac658216eebe49eaedf1e06bc06cbf6a143469236294a1171a51357c3/protobuf-6.30.2.tar.gz", hash = "sha256:35c859ae076d8c56054c25b59e5e59638d86545ed6e2b6efac6be0b6ea3ba048", size = 429315 }
wheels = [
    { url = "https://files.pythonhosted.org/packages/be/85/cd53abe6a6cbf2e0029243d6ae5fb4335da2996f6c177bb2ce685068e43d/protobuf-6.30.2-cp310-abi3-win32.whl", hash = "sha256:b12ef7df7b9329886e66404bef5e9ce6a26b54069d7f7436a0853ccdeb91c103", size = 419148 },
    { url = "https://files.pythonhosted.org/packages/97/e9/7b9f1b259d509aef2b833c29a1f3c39185e2bf21c9c1be1cd11c22cb2149/protobuf-6.30.2-cp310-abi3-win_amd64.whl", hash = "sha256:7653c99774f73fe6b9301b87da52af0e69783a2e371e8b599b3e9cb4da4b12b9", size = 431003 },
    { url = "https://files.pythonhosted.org/packages/8e/66/7f3b121f59097c93267e7f497f10e52ced7161b38295137a12a266b6c149/protobuf-6.30.2-cp39-abi3-macosx_10_9_universal2.whl", hash = "sha256:0eb523c550a66a09a0c20f86dd554afbf4d32b02af34ae53d93268c1f73bc65b", size = 417579 },
    { url = "https://files.pythonhosted.org/packages/d0/89/bbb1bff09600e662ad5b384420ad92de61cab2ed0f12ace1fd081fd4c295/protobuf-6.30.2-cp39-abi3-manylinux2014_aarch64.whl", hash = "sha256:50f32cc9fd9cb09c783ebc275611b4f19dfdfb68d1ee55d2f0c7fa040df96815", size = 317319 },
    { url = "https://files.pythonhosted.org/packages/28/50/1925de813499546bc8ab3ae857e3ec84efe7d2f19b34529d0c7c3d02d11d/protobuf-6.30.2-cp39-abi3-manylinux2014_x86_64.whl", hash = "sha256:4f6c687ae8efae6cf6093389a596548214467778146b7245e886f35e1485315d", size = 316212 },
    { url = "https://files.pythonhosted.org/packages/e5/a1/93c2acf4ade3c5b557d02d500b06798f4ed2c176fa03e3c34973ca92df7f/protobuf-6.30.2-py3-none-any.whl", hash = "sha256:ae86b030e69a98e08c77beab574cbcb9fff6d031d57209f574a5aea1445f4b51", size = 167062 },
]
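`protobuf` appears as a transitive dependency because an `.onnx` file is a serialized protobuf `ModelProto`; both `onnx` and `onnxruntime` need it to parse models. A small sketch of inspecting that structure (file name assumed as in the earlier examples):

```python
# onnx.load deserializes the protobuf ModelProto behind an .onnx file.
import onnx

model = onnx.load("model.onnx")
onnx.checker.check_model(model)  # rejects malformed or truncated protobufs
print(model.ir_version)
print([opset.version for opset in model.opset_import])
print([inp.name for inp in model.graph.input])
```

This is also why a corrupt model file surfaces as a protobuf parsing error rather than a generic I/O failure.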

[[package]]
name = "psutil"
version = "7.0.0"
@@ -930,6 +1023,15 @@ wheels = [
    { url = "https://files.pythonhosted.org/packages/05/e7/df2285f3d08fee213f2d041540fa4fc9ca6c2d44cf36d3a035bf2a8d2bcc/pyparsing-3.2.3-py3-none-any.whl", hash = "sha256:a749938e02d6fd0b59b356ca504a24982314bb090c383e3cf201c95ef7e2bfcf", size = 111120 },
]

[[package]]
name = "pyreadline3"
version = "3.5.4"
source = { registry = "https://pypi.org/simple" }
sdist = { url = "https://files.pythonhosted.org/packages/0f/49/4cea918a08f02817aabae639e3d0ac046fef9f9180518a3ad394e22da148/pyreadline3-3.5.4.tar.gz", hash = "sha256:8d57d53039a1c75adba8e50dd3d992b28143480816187ea5efbd5c78e6c885b7", size = 99839 }
wheels = [
    { url = "https://files.pythonhosted.org/packages/5a/dc/491b7661614ab97483abf2056be1deee4dc2490ecbf7bff9ab5cdbac86e1/pyreadline3-3.5.4-py3-none-any.whl", hash = "sha256:eaf8e6cc3c49bcccf145fc6067ba8643d1df34d604a1ec0eccbf7a18e6d3fae6", size = 83178 },
]

[[package]]
name = "pyside6"
version = "6.8.3"