omega/model/mlp.ipynb

602 lines
153 KiB
Plaintext

{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# 1. Import and data processing"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Import all necessary libraries"
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"import torch\n",
"import torch.nn as nn\n",
"from torch.utils.data import DataLoader, TensorDataset, random_split"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Load the training data"
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {},
"outputs": [],
"source": [
"# Load the training array from ./data.npy. The processing cell below slices its\n",
"# first CONTEXT_SIZE columns as model inputs and column CONTEXT_SIZE as the target.\n",
"data = np.load(\"./data.npy\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Define constants that describe the data and model"
]
},
{
"cell_type": "code",
"execution_count": 17,
"metadata": {},
"outputs": [],
"source": [
"# Number of preceding characters the model sees for each prediction.\n",
"CONTEXT_SIZE = 10\n",
"ALPHABET = list(\"abcdefghijklmnopqrstuvwxyz\")\n",
"ALPHABET_SIZE = len(ALPHABET)\n",
"# Fraction of the dataset used for training; the remainder becomes the test set.\n",
"TRAINING_DATA_SIZE = 0.9\n",
"\n",
"# +1 is for unknown characters\n",
"VOCAB_SIZE = ALPHABET_SIZE + 1\n",
"\n",
"# Width of each learned character embedding vector.\n",
"EMBEDDING_DIM = 16\n",
"\n",
"# The model reads one index per context position and scores every vocabulary entry.\n",
"INPUT_SEQ_LEN = CONTEXT_SIZE\n",
"OUTPUT_SIZE = VOCAB_SIZE\n",
"\n",
"# 8192 samples per mini-batch.\n",
"BATCH_SIZE = 2048 * 2 * 2\n",
"\n",
"# Maximum number of training epochs, and the Adam learning rate for the single-model run.\n",
"EPOCHS = 50\n",
"LEARNING_RATE = 1e-3"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Process the data"
]
},
{
"cell_type": "code",
"execution_count": 18,
"metadata": {},
"outputs": [],
"source": [
"# Fail early with a readable message instead of a cryptic indexing / CUDA error later.\n",
"if data.ndim != 2 or data.shape[1] <= CONTEXT_SIZE:\n",
"    raise ValueError(\n",
"        f\"expected a 2-D array with at least {CONTEXT_SIZE + 1} columns, got shape {data.shape}\"\n",
"    )\n",
"\n",
"# Inputs: indices of the previous CONTEXT_SIZE characters (the model embeds them itself)\n",
"# shape: (num_samples, CONTEXT_SIZE)\n",
"X = data[:, :CONTEXT_SIZE]\n",
"\n",
"# Target: index of the character that follows the context\n",
"# shape: (num_samples,)\n",
"y = data[:, CONTEXT_SIZE]\n",
"\n",
"# Every index is looked up in an embedding table with VOCAB_SIZE rows and is also a\n",
"# class id for the loss, so it has to lie in [0, VOCAB_SIZE).\n",
"if X.size and (min(X.min(), y.min()) < 0 or max(X.max(), y.max()) >= VOCAB_SIZE):\n",
"    raise ValueError(f\"character indices must lie in [0, {VOCAB_SIZE})\")\n",
"\n",
"# Torch dataset (important: use long/int64 for indices)\n",
"X_tensor = torch.tensor(X, dtype=torch.long)  # for nn.Embedding\n",
"y_tensor = torch.tensor(y, dtype=torch.long)  # for classification target\n",
"\n",
"dataset = TensorDataset(X_tensor, y_tensor)\n",
"\n",
"train_len = int(TRAINING_DATA_SIZE * len(dataset))\n",
"\n",
"# Seed the split so that re-running this cell reproduces the same train/test partition.\n",
"# With an unseeded split, a model trained before the re-run would be evaluated on\n",
"# samples it has already been trained on.\n",
"SPLIT_SEED = 0\n",
"split_generator = torch.Generator().manual_seed(SPLIT_SEED)\n",
"train_set, test_set = random_split(\n",
"    dataset,\n",
"    [train_len, len(dataset) - train_len],\n",
"    generator=split_generator,\n",
")\n",
"\n",
"train_loader = DataLoader(train_set, batch_size=BATCH_SIZE, shuffle=True)\n",
"test_loader = DataLoader(test_set, batch_size=BATCH_SIZE)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# 2. Model"
]
},
{
"cell_type": "code",
"execution_count": 19,
"metadata": {},
"outputs": [],
"source": [
"from dataclasses import dataclass\n",
"\n",
"@dataclass\n",
"class MlpHiddenLayer:\n",
"    \"\"\"Specification of one hidden layer of the MLP.\n",
"\n",
"    Attributes:\n",
"        size: Number of output units of the layer's linear transform.\n",
"        activation_function: Activation *class* (e.g. ``nn.ReLU``), not an\n",
"            instance. The model calls it with no arguments to build the\n",
"            activation module that follows the linear transform.\n",
"    \"\"\"\n",
"\n",
"    size: int\n",
"    activation_function: type[nn.Module]"
]
},
{
"cell_type": "code",
"execution_count": 20,
"metadata": {},
"outputs": [],
"source": [
"class MLP(nn.Module):\n",
"    \"\"\"Character-level MLP: embedding, flatten, hidden stack, linear output.\n",
"\n",
"    All constructor arguments are keyword-only.\n",
"\n",
"    Args:\n",
"        embedding_count: Number of rows in the embedding table.\n",
"        embedding_dimension_size: Length of each embedding vector.\n",
"        input_shape_size: Input width of the first linear layer, i.e. the\n",
"            width of the flattened embedding output.\n",
"        output_shape: Number of values produced by the final linear layer.\n",
"        hidden_layers: Hidden-layer specifications, applied in order.\n",
"    \"\"\"\n",
"\n",
"    def __init__(\n",
"        self,\n",
"        *,\n",
"        embedding_count: int,\n",
"        embedding_dimension_size: int,\n",
"        input_shape_size: int,\n",
"        output_shape: int,\n",
"        hidden_layers: list[MlpHiddenLayer]\n",
"    ):\n",
"        super().__init__()\n",
"\n",
"        # Keep the construction arguments available on the instance.\n",
"        self.embedding_count = embedding_count\n",
"        self.embedding_dimension_size = embedding_dimension_size\n",
"        self.input_shape_size = input_shape_size\n",
"        self.output_shape = output_shape\n",
"        self.hidden_layers = hidden_layers\n",
"\n",
"        modules = [\n",
"            nn.Embedding(num_embeddings=embedding_count, embedding_dim=embedding_dimension_size),\n",
"            nn.Flatten(),\n",
"        ]\n",
"\n",
"        # widths[i] is the input width of hidden layer i; the last entry feeds the output layer.\n",
"        widths = [input_shape_size] + [spec.size for spec in hidden_layers]\n",
"        for spec, fan_in in zip(hidden_layers, widths):\n",
"            modules += [nn.Linear(fan_in, spec.size), spec.activation_function()]\n",
"\n",
"        modules.append(nn.Linear(widths[-1], output_shape))\n",
"\n",
"        self.net = nn.Sequential(*modules)\n",
"\n",
"    def forward(self, x):\n",
"        \"\"\"Run ``x`` through the whole stack and return the final layer's output.\"\"\"\n",
"        return self.net(x)"
]
},
{
"cell_type": "code",
"execution_count": 21,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Using device: cuda\n"
]
}
],
"source": [
"# Prefer the GPU when PyTorch can see one; otherwise fall back to the CPU.\n",
"device_name = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
"device = torch.device(device_name)\n",
"print(f\"Using device: {device}\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# 3. Training"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Create a fresh instance of the model"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Grid search over all hyperparameters"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [
{
"ename": "KeyboardInterrupt",
"evalue": "",
"output_type": "error",
"traceback": [
"\u001b[31m---------------------------------------------------------------------------\u001b[39m",
"\u001b[31mKeyboardInterrupt\u001b[39m Traceback (most recent call last)",
"\u001b[36mCell\u001b[39m\u001b[36m \u001b[39m\u001b[32mIn[14]\u001b[39m\u001b[32m, line 73\u001b[39m\n\u001b[32m 64\u001b[39m model = MLP(\n\u001b[32m 65\u001b[39m hidden_layers=hidden_layers,\n\u001b[32m 66\u001b[39m embedding_count=VOCAB_SIZE,\n\u001b[32m (...)\u001b[39m\u001b[32m 69\u001b[39m output_shape=OUTPUT_SIZE,\n\u001b[32m 70\u001b[39m ).to(device)\n\u001b[32m 72\u001b[39m optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)\n\u001b[32m---> \u001b[39m\u001b[32m73\u001b[39m \u001b[43mtrain_model\u001b[49m\u001b[43m(\u001b[49m\u001b[43mmodel\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43moptimizer\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 74\u001b[39m top1, top3, top5 = test_model(model)\n\u001b[32m 76\u001b[39m results.append({\n\u001b[32m 77\u001b[39m \u001b[33m\"\u001b[39m\u001b[33mconfig_id\u001b[39m\u001b[33m\"\u001b[39m: config_id,\n\u001b[32m 78\u001b[39m \u001b[33m\"\u001b[39m\u001b[33mactivation\u001b[39m\u001b[33m\"\u001b[39m: act_fn.\u001b[34m__name__\u001b[39m,\n\u001b[32m (...)\u001b[39m\u001b[32m 83\u001b[39m \u001b[33m\"\u001b[39m\u001b[33mtop5_acc\u001b[39m\u001b[33m\"\u001b[39m: top5\n\u001b[32m 84\u001b[39m })\n",
"\u001b[36mCell\u001b[39m\u001b[36m \u001b[39m\u001b[32mIn[14]\u001b[39m\u001b[32m, line 16\u001b[39m, in \u001b[36mtrain_model\u001b[39m\u001b[34m(model, optimizer)\u001b[39m\n\u001b[32m 14\u001b[39m model.train()\n\u001b[32m 15\u001b[39m total_loss = \u001b[32m0\u001b[39m\n\u001b[32m---> \u001b[39m\u001b[32m16\u001b[39m \u001b[43m\u001b[49m\u001b[38;5;28;43;01mfor\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[43mbatch_X\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mbatch_y\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;129;43;01min\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[43mtrain_loader\u001b[49m\u001b[43m:\u001b[49m\n\u001b[32m 17\u001b[39m \u001b[43m \u001b[49m\u001b[43mbatch_X\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mbatch_y\u001b[49m\u001b[43m \u001b[49m\u001b[43m=\u001b[49m\u001b[43m \u001b[49m\u001b[43mbatch_X\u001b[49m\u001b[43m.\u001b[49m\u001b[43mto\u001b[49m\u001b[43m(\u001b[49m\u001b[43mdevice\u001b[49m\u001b[43m)\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mbatch_y\u001b[49m\u001b[43m.\u001b[49m\u001b[43mto\u001b[49m\u001b[43m(\u001b[49m\u001b[43mdevice\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 18\u001b[39m \u001b[43m \u001b[49m\u001b[43moptimizer\u001b[49m\u001b[43m.\u001b[49m\u001b[43mzero_grad\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n",
"\u001b[36mFile \u001b[39m\u001b[32m~/Code/Skola/4-rocnik/programove-vybaveni/omega/.venv/lib/python3.12/site-packages/torch/utils/data/dataloader.py:708\u001b[39m, in \u001b[36m_BaseDataLoaderIter.__next__\u001b[39m\u001b[34m(self)\u001b[39m\n\u001b[32m 705\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m._sampler_iter \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[32m 706\u001b[39m \u001b[38;5;66;03m# TODO(https://github.com/pytorch/pytorch/issues/76750)\u001b[39;00m\n\u001b[32m 707\u001b[39m \u001b[38;5;28mself\u001b[39m._reset() \u001b[38;5;66;03m# type: ignore[call-arg]\u001b[39;00m\n\u001b[32m--> \u001b[39m\u001b[32m708\u001b[39m data = \u001b[38;5;28;43mself\u001b[39;49m\u001b[43m.\u001b[49m\u001b[43m_next_data\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 709\u001b[39m \u001b[38;5;28mself\u001b[39m._num_yielded += \u001b[32m1\u001b[39m\n\u001b[32m 710\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m (\n\u001b[32m 711\u001b[39m \u001b[38;5;28mself\u001b[39m._dataset_kind == _DatasetKind.Iterable\n\u001b[32m 712\u001b[39m \u001b[38;5;129;01mand\u001b[39;00m \u001b[38;5;28mself\u001b[39m._IterableDataset_len_called \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m\n\u001b[32m 713\u001b[39m \u001b[38;5;129;01mand\u001b[39;00m \u001b[38;5;28mself\u001b[39m._num_yielded > \u001b[38;5;28mself\u001b[39m._IterableDataset_len_called\n\u001b[32m 714\u001b[39m ):\n",
"\u001b[36mFile \u001b[39m\u001b[32m~/Code/Skola/4-rocnik/programove-vybaveni/omega/.venv/lib/python3.12/site-packages/torch/utils/data/dataloader.py:764\u001b[39m, in \u001b[36m_SingleProcessDataLoaderIter._next_data\u001b[39m\u001b[34m(self)\u001b[39m\n\u001b[32m 762\u001b[39m \u001b[38;5;28;01mdef\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34m_next_data\u001b[39m(\u001b[38;5;28mself\u001b[39m):\n\u001b[32m 763\u001b[39m index = \u001b[38;5;28mself\u001b[39m._next_index() \u001b[38;5;66;03m# may raise StopIteration\u001b[39;00m\n\u001b[32m--> \u001b[39m\u001b[32m764\u001b[39m data = \u001b[38;5;28;43mself\u001b[39;49m\u001b[43m.\u001b[49m\u001b[43m_dataset_fetcher\u001b[49m\u001b[43m.\u001b[49m\u001b[43mfetch\u001b[49m\u001b[43m(\u001b[49m\u001b[43mindex\u001b[49m\u001b[43m)\u001b[49m \u001b[38;5;66;03m# may raise StopIteration\u001b[39;00m\n\u001b[32m 765\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m._pin_memory:\n\u001b[32m 766\u001b[39m data = _utils.pin_memory.pin_memory(data, \u001b[38;5;28mself\u001b[39m._pin_memory_device)\n",
"\u001b[36mFile \u001b[39m\u001b[32m~/Code/Skola/4-rocnik/programove-vybaveni/omega/.venv/lib/python3.12/site-packages/torch/utils/data/_utils/fetch.py:50\u001b[39m, in \u001b[36m_MapDatasetFetcher.fetch\u001b[39m\u001b[34m(self, possibly_batched_index)\u001b[39m\n\u001b[32m 48\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m.auto_collation:\n\u001b[32m 49\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mhasattr\u001b[39m(\u001b[38;5;28mself\u001b[39m.dataset, \u001b[33m\"\u001b[39m\u001b[33m__getitems__\u001b[39m\u001b[33m\"\u001b[39m) \u001b[38;5;129;01mand\u001b[39;00m \u001b[38;5;28mself\u001b[39m.dataset.__getitems__:\n\u001b[32m---> \u001b[39m\u001b[32m50\u001b[39m data = \u001b[38;5;28;43mself\u001b[39;49m\u001b[43m.\u001b[49m\u001b[43mdataset\u001b[49m\u001b[43m.\u001b[49m\u001b[43m__getitems__\u001b[49m\u001b[43m(\u001b[49m\u001b[43mpossibly_batched_index\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 51\u001b[39m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[32m 52\u001b[39m data = [\u001b[38;5;28mself\u001b[39m.dataset[idx] \u001b[38;5;28;01mfor\u001b[39;00m idx \u001b[38;5;129;01min\u001b[39;00m possibly_batched_index]\n",
"\u001b[36mFile \u001b[39m\u001b[32m~/Code/Skola/4-rocnik/programove-vybaveni/omega/.venv/lib/python3.12/site-packages/torch/utils/data/dataset.py:420\u001b[39m, in \u001b[36mSubset.__getitems__\u001b[39m\u001b[34m(self, indices)\u001b[39m\n\u001b[32m 418\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m.dataset.__getitems__([\u001b[38;5;28mself\u001b[39m.indices[idx] \u001b[38;5;28;01mfor\u001b[39;00m idx \u001b[38;5;129;01min\u001b[39;00m indices]) \u001b[38;5;66;03m# type: ignore[attr-defined]\u001b[39;00m\n\u001b[32m 419\u001b[39m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[32m--> \u001b[39m\u001b[32m420\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m [\u001b[38;5;28;43mself\u001b[39;49m\u001b[43m.\u001b[49m\u001b[43mdataset\u001b[49m\u001b[43m[\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[43m.\u001b[49m\u001b[43mindices\u001b[49m\u001b[43m[\u001b[49m\u001b[43midx\u001b[49m\u001b[43m]\u001b[49m\u001b[43m]\u001b[49m \u001b[38;5;28;01mfor\u001b[39;00m idx \u001b[38;5;129;01min\u001b[39;00m indices]\n",
"\u001b[36mFile \u001b[39m\u001b[32m~/Code/Skola/4-rocnik/programove-vybaveni/omega/.venv/lib/python3.12/site-packages/torch/utils/data/dataset.py:211\u001b[39m, in \u001b[36mTensorDataset.__getitem__\u001b[39m\u001b[34m(self, index)\u001b[39m\n\u001b[32m 210\u001b[39m \u001b[38;5;28;01mdef\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34m__getitem__\u001b[39m(\u001b[38;5;28mself\u001b[39m, index):\n\u001b[32m--> \u001b[39m\u001b[32m211\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mtuple\u001b[39;49m\u001b[43m(\u001b[49m\u001b[43mtensor\u001b[49m\u001b[43m[\u001b[49m\u001b[43mindex\u001b[49m\u001b[43m]\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43;01mfor\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[43mtensor\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;129;43;01min\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[43m.\u001b[49m\u001b[43mtensors\u001b[49m\u001b[43m)\u001b[49m\n",
"\u001b[31mKeyboardInterrupt\u001b[39m: "
]
}
],
"source": [
"from itertools import product\n",
"\n",
"# Short alias for the hidden-layer specification dataclass.\n",
"MHL = MlpHiddenLayer\n",
"\n",
"# Search space: for every activation class and depth (number of hidden layers),\n",
"# every combination of layer sizes is trained once per learning rate.\n",
"learning_rates = [1e-2, 5e-3, 1e-3, 5e-4, 1e-4]\n",
"layer_sizes = [32, 64, 128, 256]\n",
"depths = [1, 2, 3]\n",
"activation_functions = [nn.ReLU]\n",
"\n",
"# NOTE(review): nothing in this cell appends to or reads this list; it looks unused.\n",
"all_models = []\n",
"\n",
"def train_model(model, optimizer, epochs=None):\n",
"    \"\"\"Train ``model`` in place and report the mean training loss of each epoch.\n",
"\n",
"    Reads the notebook globals ``train_loader``, ``device`` and ``criterion``.\n",
"\n",
"    Args:\n",
"        model: Module to optimise; it is called on batches moved to ``device``.\n",
"        optimizer: Optimizer that updates the parameters of ``model``.\n",
"        epochs: Number of passes over ``train_loader``. Defaults to the global ``EPOCHS``.\n",
"\n",
"    Returns:\n",
"        A list with one entry per epoch: the sample-weighted mean training loss\n",
"        (``nan`` for an epoch in which the loader yielded no samples).\n",
"    \"\"\"\n",
"    if epochs is None:\n",
"        epochs = EPOCHS\n",
"\n",
"    epoch_losses = []\n",
"    for epoch in range(epochs):\n",
"        model.train()\n",
"        total_loss = 0.0\n",
"        total_samples = 0\n",
"        for batch_X, batch_y in train_loader:\n",
"            batch_X, batch_y = batch_X.to(device), batch_y.to(device)\n",
"            optimizer.zero_grad()\n",
"            output = model(batch_X)\n",
"            loss = criterion(output, batch_y)\n",
"            loss.backward()\n",
"            optimizer.step()\n",
"\n",
"            # The criterion yields a per-batch mean, so weight it by the batch size;\n",
"            # dividing by the sample count afterwards gives the true epoch mean even\n",
"            # when the last batch is smaller than the others.\n",
"            batch_size = batch_X.size(0)\n",
"            total_loss += loss.item() * batch_size\n",
"            total_samples += batch_size\n",
"\n",
"        epoch_losses.append(total_loss / total_samples if total_samples else float(\"nan\"))\n",
"\n",
"    return epoch_losses\n",
"\n",
"def test_model(model):\n",
"    \"\"\"Evaluate ``model`` on ``test_loader`` and return top-1/3/5 accuracy.\n",
"\n",
"    Reads the notebook globals ``test_loader`` and ``device``.\n",
"\n",
"    Args:\n",
"        model: Module that maps a batch of inputs to one score per class.\n",
"\n",
"    Returns:\n",
"        Tuple ``(top1_acc, top3_acc, top5_acc)`` of fractions in [0, 1]: the share\n",
"        of test samples whose target is among the 1, 3 or 5 highest-scored classes.\n",
"\n",
"    Raises:\n",
"        ZeroDivisionError: If ``test_loader`` yields no samples.\n",
"    \"\"\"\n",
"    model.eval()\n",
"    correct_top1 = 0\n",
"    correct_top3 = 0\n",
"    correct_top5 = 0\n",
"    total = 0\n",
"\n",
"    with torch.no_grad():\n",
"        for batch_X, batch_y in test_loader:\n",
"            batch_X, batch_y = batch_X.to(device), batch_y.to(device)\n",
"            outputs = model(batch_X)\n",
"\n",
"            # Never request more candidates than there are classes.\n",
"            k = min(5, outputs.size(1))\n",
"            _, top_preds = outputs.topk(k, dim=1)\n",
"\n",
"            # hits[i, j] is True when the j-th ranked class of sample i is its target.\n",
"            # Counting on whole tensors avoids one host/device round trip per sample.\n",
"            hits = top_preds.eq(batch_y.unsqueeze(1))\n",
"\n",
"            correct_top1 += hits[:, 0].sum().item()\n",
"            correct_top3 += hits[:, :3].any(dim=1).sum().item()\n",
"            correct_top5 += hits.any(dim=1).sum().item()\n",
"            total += batch_y.size(0)\n",
"\n",
"    top1_acc = correct_top1 / total\n",
"    top3_acc = correct_top3 / total\n",
"    top5_acc = correct_top5 / total\n",
"\n",
"    return top1_acc, top3_acc, top5_acc\n",
"\n",
"# Loss shared by every configuration; train_model picks it up as a global.\n",
"criterion = nn.CrossEntropyLoss()\n",
"\n",
"results = []\n",
"config_id = 0\n",
"\n",
"# Walk the grid: activation x depth x layer-size combination x learning rate.\n",
"for act_fn, depth in product(activation_functions, depths):\n",
"    for size_combo in product(layer_sizes, repeat=depth):\n",
"        for learning_rate in learning_rates:\n",
"            hidden_layers = [\n",
"                MlpHiddenLayer(size=width, activation_function=act_fn)\n",
"                for width in size_combo\n",
"            ]\n",
"\n",
"            # A brand-new model and optimizer for every configuration.\n",
"            model = MLP(\n",
"                embedding_count=VOCAB_SIZE,\n",
"                embedding_dimension_size=EMBEDDING_DIM,\n",
"                input_shape_size=CONTEXT_SIZE * EMBEDDING_DIM,\n",
"                output_shape=OUTPUT_SIZE,\n",
"                hidden_layers=hidden_layers,\n",
"            ).to(device)\n",
"            optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)\n",
"\n",
"            train_model(model, optimizer)\n",
"            top1, top3, top5 = test_model(model)\n",
"\n",
"            record = {\n",
"                \"config_id\": config_id,\n",
"                \"activation\": act_fn.__name__,\n",
"                \"layer_sizes\": size_combo,\n",
"                \"learning_rate\": learning_rate,\n",
"                \"top1_acc\": top1,\n",
"                \"top3_acc\": top3,\n",
"                \"top5_acc\": top5\n",
"            }\n",
"            results.append(record)\n",
"\n",
"            print(f\"[#{config_id}] {act_fn.__name__} {size_combo} lr={learning_rate:.0e} → top1={top1:.2f}, top3={top3:.2f}, top5={top5:.2f}\")\n",
"            config_id += 1\n",
"\n",
"            # Drop the finished model and release cached GPU memory before the next run.\n",
"            del model\n",
"            torch.cuda.empty_cache()\n",
"\n",
"\n",
"print(results)\n",
"\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Model training"
]
},
{
"cell_type": "code",
"execution_count": 41,
"metadata": {},
"outputs": [],
"source": [
"# Short aliases for the layer specification and the available activation classes.\n",
"MHL = MlpHiddenLayer\n",
"Relu, Gelu, Silu = nn.ReLU, nn.GELU, nn.SiLU\n",
"\n",
"# Hidden-layer widths, first layer first.\n",
"sizes = [256, 128]\n",
"hidden_layer_specs = [MHL(size=width, activation_function=Relu) for width in sizes]\n",
"\n",
"model = MLP(\n",
"    embedding_count=VOCAB_SIZE,\n",
"    embedding_dimension_size=EMBEDDING_DIM,\n",
"    input_shape_size=CONTEXT_SIZE * EMBEDDING_DIM,\n",
"    output_shape=OUTPUT_SIZE,\n",
"    hidden_layers=hidden_layer_specs,\n",
")\n",
"model = model.to(device)\n",
"\n",
"# Loss and optimizer used by the training cell below.\n",
"criterion = nn.CrossEntropyLoss()\n",
"optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)\n"
]
},
{
"cell_type": "code",
"execution_count": 46,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[Epoch #01] - Loss: 1.48195 | Accuracy: 51.54%\n",
"[Epoch #02] - Loss: 1.48260 | Accuracy: 51.51%\n",
"# Small loss difference detected (1/5)\n",
"[Epoch #03] - Loss: 1.48176 | Accuracy: 51.56%\n",
"# Small loss difference detected (2/5)\n",
"[Epoch #04] - Loss: 1.48150 | Accuracy: 51.54%\n",
"# Small loss difference detected (3/5)\n",
"[Epoch #05] - Loss: 1.48147 | Accuracy: 51.56%\n",
"# Small loss difference detected (4/5)\n",
"[Epoch #06] - Loss: 1.48097 | Accuracy: 51.60%\n",
"# Small loss difference detected (5/5)\n",
"# Loss has been too stagnant for 5 epochs.\n",
"## Ending now\n"
]
}
],
"source": [
"# Early stopping: give up after this many consecutive epochs without a meaningful\n",
"# improvement, where \"meaningful\" means beating the best loss so far by at least\n",
"# TOO_SMALL_CHANGE.\n",
"SMALL_CHANGE_COUNT_TRIGGER = 5\n",
"TOO_SMALL_CHANGE = 1e-3\n",
"\n",
"# Progress is measured against the best loss seen so far rather than the previous\n",
"# epoch, so a loss that merely oscillates cannot keep resetting the counter.\n",
"best_loss = float(\"inf\")\n",
"prev_loss = float(\"inf\")\n",
"small_change_count = 0\n",
"\n",
"for epoch in range(EPOCHS):\n",
"    model.train()\n",
"    total_loss = 0\n",
"    correct = 0\n",
"    total = 0\n",
"\n",
"    for batch_X, batch_y in train_loader:\n",
"        batch_X, batch_y = batch_X.to(device), batch_y.to(device)\n",
"\n",
"        optimizer.zero_grad()\n",
"        output = model(batch_X)\n",
"        loss = criterion(output, batch_y)\n",
"        loss.backward()\n",
"        optimizer.step()\n",
"\n",
"        total_loss += loss.item() * batch_X.size(0)  # Multiply by batch size to sum loss\n",
"        preds = torch.argmax(output, dim=1)\n",
"        correct += (preds == batch_y).sum().item()\n",
"        total += batch_X.size(0)\n",
"\n",
"    avg_loss = total_loss / total\n",
"    accuracy = correct / total * 100\n",
"    print(f\"[Epoch #{(epoch+1):02}] - Loss: {avg_loss:.5f} | Accuracy: {accuracy:.2f}%\")\n",
"\n",
"    if best_loss - avg_loss < TOO_SMALL_CHANGE:\n",
"        small_change_count += 1\n",
"        print(f\"# Small loss difference detected ({small_change_count}/{SMALL_CHANGE_COUNT_TRIGGER})\")\n",
"    else:\n",
"        if small_change_count > 0:\n",
"            print(\"# Loss difference increased again. Resetting counter\")\n",
"        small_change_count = 0\n",
"        # Only a meaningful improvement moves the reference point.\n",
"        best_loss = avg_loss\n",
"\n",
"    prev_loss = avg_loss\n",
"\n",
"    # Check at the end of the epoch so the stop is also reported when the trigger\n",
"    # is reached on the very last epoch.\n",
"    if small_change_count >= SMALL_CHANGE_COUNT_TRIGGER:\n",
"        print(f\"# Loss has been too stagnant for {SMALL_CHANGE_COUNT_TRIGGER} epochs.\\n## Ending now\")\n",
"        break\n"
]
},
{
"cell_type": "code",
"execution_count": 24,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Top 1 prediction accuracy: 49.92%\n",
"Top 3 prediction accuracy: 73.00%\n",
"Top 5 prediction accuracy: 82.72%\n"
]
}
],
"source": [
"# Top-1 / top-3 / top-5 accuracy of the trained model on the held-out test set.\n",
"model.eval()\n",
"correct_top1 = 0\n",
"correct_top3 = 0\n",
"correct_top5 = 0\n",
"total = 0\n",
"\n",
"with torch.no_grad():\n",
"    for batch_X, batch_y in test_loader:\n",
"        batch_X, batch_y = batch_X.to(device), batch_y.to(device)\n",
"        outputs = model(batch_X)\n",
"\n",
"        # Never request more candidates than there are classes.\n",
"        k = min(5, outputs.size(1))\n",
"        _, top_preds = outputs.topk(k, dim=1)\n",
"\n",
"        # hits[i, j] is True when the j-th ranked class of sample i is its target.\n",
"        # Counting on whole tensors avoids one host/device round trip per sample.\n",
"        hits = top_preds.eq(batch_y.unsqueeze(1))\n",
"\n",
"        correct_top1 += hits[:, 0].sum().item()\n",
"        correct_top3 += hits[:, :3].any(dim=1).sum().item()\n",
"        correct_top5 += hits.any(dim=1).sum().item()\n",
"        total += batch_y.size(0)\n",
"\n",
"top1_acc = correct_top1 / total\n",
"top3_acc = correct_top3 / total\n",
"top5_acc = correct_top5 / total\n",
"\n",
"print(f\"Top 1 prediction accuracy: {(top1_acc * 100):.2f}%\")\n",
"print(f\"Top 3 prediction accuracy: {(top3_acc * 100):.2f}%\")\n",
"print(f\"Top 5 prediction accuracy: {(top5_acc * 100):.2f}%\")"
]
},
{
"cell_type": "code",
"execution_count": 62,
"metadata": {},
"outputs": [],
"source": [
"# model.net[0] is the nn.Embedding layer; take its weight matrix out of autograd,\n",
"# move it to the CPU and expose it as a NumPy array.\n",
"# NOTE(review): the plotting cell below rebinds `embeddings` to a torch.Tensor.\n",
"embeddings = model.net[0].weight.detach().cpu().numpy()"
]
},
{
"cell_type": "code",
"execution_count": 26,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "",
"text/plain": [
"<Figure size 800x600 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"def pca_torch(X: torch.Tensor, n_components=3):\n",
"    \"\"\"Project the rows of ``X`` onto their leading principal components.\n",
"\n",
"    The components come from an exact SVD of the column-centred data, so the\n",
"    result does not depend on random state and is stable across re-runs.\n",
"\n",
"    Args:\n",
"        X: 2-D floating-point tensor of shape (n_samples, n_features).\n",
"        n_components: Number of principal components to keep.\n",
"\n",
"    Returns:\n",
"        Tensor of shape (n_samples, n_components) with the centred data expressed\n",
"        in the basis of the first ``n_components`` principal directions.\n",
"\n",
"    Raises:\n",
"        ValueError: If ``X`` is not 2-D or ``n_components`` is not between 1 and\n",
"            ``min(n_samples, n_features)``.\n",
"    \"\"\"\n",
"    if X.dim() != 2:\n",
"        raise ValueError(f\"X must be 2-D, got {X.dim()} dimension(s)\")\n",
"\n",
"    max_components = min(X.shape)\n",
"    if not 1 <= n_components <= max_components:\n",
"        raise ValueError(\n",
"            f\"n_components must be between 1 and {max_components}, got {n_components}\"\n",
"        )\n",
"\n",
"    # Center the data\n",
"    centered = X - X.mean(dim=0)\n",
"\n",
"    # Rows of Vh are the principal directions, ordered by decreasing singular value.\n",
"    _, _, Vh = torch.linalg.svd(centered, full_matrices=False)\n",
"    return torch.matmul(centered, Vh[:n_components].T)\n",
"\n",
"# Extract the embedding matrix (first layer of model.net) and reduce it to 3 dimensions\n",
"embeddings = model.net[0].weight.detach().cpu()\n",
"reduced = pca_torch(embeddings, n_components=3)\n",
"\n",
"# Plot using matplotlib 3D\n",
"import matplotlib.pyplot as plt\n",
"from mpl_toolkits.mplot3d import Axes3D\n",
"\n",
"fig = plt.figure(figsize=(8, 6))\n",
"ax = fig.add_subplot(111, projection='3d')\n",
"\n",
"# One label per embedding row.\n",
"# NOTE(review): assumes the row after the alphabet is the unknown-character slot; confirm\n",
"# against the encoding used to build data.npy.\n",
"labels = ALPHABET + ['<UNK>']\n",
"assert len(labels) == reduced.shape[0], \"number of labels must match number of embedding rows\"\n",
"\n",
"# Coordinates are named pc1/pc2/pc3 so the loop does not overwrite the notebook-level\n",
"# `y` target array defined in the data-processing cell.\n",
"for label, point in zip(labels, reduced):\n",
"    pc1, pc2, pc3 = point.tolist()\n",
"    ax.scatter(pc1, pc2, pc3, s=50)\n",
"    ax.text(pc1, pc2, pc3, label, fontsize=9)\n",
"\n",
"ax.set_title(\"3D PCA Projection of Character Embeddings\")\n",
"ax.set_xlabel(\"PC 1\")\n",
"ax.set_ylabel(\"PC 2\")\n",
"ax.set_zlabel(\"PC 3\")\n",
"plt.tight_layout()\n",
"plt.show()\n"
]
},
{
"cell_type": "code",
"execution_count": 29,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Export successful -> 'model.onnx'\n"
]
}
],
"source": [
"import torch.onnx\n",
"\n",
"FILENAME = \"model.onnx\"\n",
"\n",
"# Switch to evaluation mode before exporting.\n",
"model.eval()\n",
"\n",
"# Example input for the exporter: a single context window of random character indices.\n",
"dummy_input = torch.randint(0, VOCAB_SIZE, (1, CONTEXT_SIZE), dtype=torch.long)\n",
"dummy_input = dummy_input.to(device)\n",
"\n",
"# Mark dimension 0 of both the input and the output as a named dynamic axis.\n",
"export_options = dict(\n",
"    input_names=[\"input\"],\n",
"    output_names=[\"output\"],\n",
"    dynamic_axes={\n",
"        \"input\": {0: \"batch_size\"},\n",
"        \"output\": {0: \"batch_size\"},\n",
"    },\n",
"    opset_version=13,\n",
")\n",
"torch.onnx.export(model, dummy_input, FILENAME, **export_options)\n",
"\n",
"print(f\"Export successful -> '{FILENAME}'\")"
]
}
],
"metadata": {
"kernelspec": {
"display_name": ".venv",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.3"
}
},
"nbformat": 4,
"nbformat_minor": 2
}