Switched to pytorch with multi layer perceptron

This commit is contained in:
Thastertyn 2025-03-23 22:09:28 +01:00
parent 13884a933b
commit e44a6cf86e
3 changed files with 247 additions and 465 deletions

1
.gitignore vendored
View File

@ -168,3 +168,4 @@ cython_debug/
# option (not recommended) you can uncomment the following to ignore the entire idea folder. # option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/ #.idea/
data.npy

View File

@ -9,12 +9,46 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 1, "execution_count": 10,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
"import pandas as pd\n", "import numpy as np\n",
"data = pd.read_csv(\"./out.txt\", sep=',')" "import torch\n",
"import torch.nn as nn\n",
"from torch.utils.data import DataLoader, TensorDataset, random_split"
]
},
{
"cell_type": "code",
"execution_count": 50,
"metadata": {},
"outputs": [],
"source": [
"data = np.load(\"./data.npy\")"
]
},
{
"cell_type": "code",
"execution_count": 54,
"metadata": {},
"outputs": [],
"source": [
"CONTEXT_SIZE = 10\n",
"ALPHABET = list(\"abcdefghijklmnopqrstuvwxyz\")\n",
"ALPHABET_SIZE = len(ALPHABET)\n",
"TRAINING_DATA_SIZE = 0.9\n",
"\n",
"\n",
"# Derived values\n",
"PREV_LETTER_FEATURES = CONTEXT_SIZE * ALPHABET_SIZE\n",
"CURR_LETTER_FEATURES = ALPHABET_SIZE\n",
"OTHER_FEATURES = 3 # is_start, prev_type, word_length\n",
"\n",
"TOTAL_FEATURES = PREV_LETTER_FEATURES + CURR_LETTER_FEATURES + OTHER_FEATURES\n",
"\n",
"INPUT_SIZE = PREV_LETTER_FEATURES + OTHER_FEATURES\n",
"OUTPUT_SIZE = ALPHABET_SIZE"
] ]
}, },
{ {
@ -33,31 +67,44 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 2, "execution_count": 52,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
"input_features = ['previous_5','previous_4','previous_3','previous_2','previous_1','is_start','previous_type','word_length']\n", "X = np.hstack([\n",
"target_feature = 'current'" " data[:, :PREV_LETTER_FEATURES],\n",
" data[:, PREV_LETTER_FEATURES + CURR_LETTER_FEATURES:TOTAL_FEATURES]\n",
"])\n",
"\n",
"# Extract current letter (one-hot target)\n",
"y_onehot = data[:, PREV_LETTER_FEATURES:PREV_LETTER_FEATURES + CURR_LETTER_FEATURES]\n",
"y = np.argmax(y_onehot, axis=1)\n",
"\n",
"# Torch dataset\n",
"X_tensor = torch.tensor(X, dtype=torch.float32)\n",
"y_tensor = torch.tensor(y, dtype=torch.long)\n",
"\n",
"dataset = TensorDataset(X_tensor, y_tensor)\n"
] ]
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 3, "execution_count": 55,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
"from sklearn.model_selection import train_test_split" "train_len = int(TRAINING_DATA_SIZE * len(dataset))\n",
"train_set, test_set = random_split(dataset, [train_len, len(dataset) - train_len])"
] ]
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 8, "execution_count": 56,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
"test_size = 0.1 # @param {\"type\":\"number\",\"placeholder\":\"0.1\"}\n", "train_loader = DataLoader(train_set, batch_size=128, shuffle=True)\n",
"X_train, X_test, y_train, y_test = train_test_split(data[input_features], data[target_feature], test_size=test_size)" "test_loader = DataLoader(test_set, batch_size=128)"
] ]
}, },
{ {
@ -69,458 +116,109 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 10, "execution_count": 79,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
"from sklearn.linear_model import LogisticRegression" "class MLP(nn.Module):\n",
" def __init__(self):\n",
" super().__init__()\n",
" self.net = nn.Sequential(\n",
" nn.Linear(INPUT_SIZE, 256),\n",
" nn.ReLU(),\n",
" nn.Linear(256, 128),\n",
" nn.ReLU(),\n",
" nn.Linear(128, OUTPUT_SIZE)\n",
" )\n",
"\n",
" def forward(self, x):\n",
" return self.net(x)"
] ]
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 18, "execution_count": 80,
"metadata": {}, "metadata": {},
"outputs": [ "outputs": [
{ {
"data": { "name": "stdout",
"text/html": [ "output_type": "stream",
"<style>#sk-container-id-9 {\n", "text": [
" /* Definition of color scheme common for light and dark mode */\n", "Using device: cuda\n"
" --sklearn-color-text: black;\n",
" --sklearn-color-line: gray;\n",
" /* Definition of color scheme for unfitted estimators */\n",
" --sklearn-color-unfitted-level-0: #fff5e6;\n",
" --sklearn-color-unfitted-level-1: #f6e4d2;\n",
" --sklearn-color-unfitted-level-2: #ffe0b3;\n",
" --sklearn-color-unfitted-level-3: chocolate;\n",
" /* Definition of color scheme for fitted estimators */\n",
" --sklearn-color-fitted-level-0: #f0f8ff;\n",
" --sklearn-color-fitted-level-1: #d4ebff;\n",
" --sklearn-color-fitted-level-2: #b3dbfd;\n",
" --sklearn-color-fitted-level-3: cornflowerblue;\n",
"\n",
" /* Specific color for light theme */\n",
" --sklearn-color-text-on-default-background: var(--sg-text-color, var(--theme-code-foreground, var(--jp-content-font-color1, black)));\n",
" --sklearn-color-background: var(--sg-background-color, var(--theme-background, var(--jp-layout-color0, white)));\n",
" --sklearn-color-border-box: var(--sg-text-color, var(--theme-code-foreground, var(--jp-content-font-color1, black)));\n",
" --sklearn-color-icon: #696969;\n",
"\n",
" @media (prefers-color-scheme: dark) {\n",
" /* Redefinition of color scheme for dark theme */\n",
" --sklearn-color-text-on-default-background: var(--sg-text-color, var(--theme-code-foreground, var(--jp-content-font-color1, white)));\n",
" --sklearn-color-background: var(--sg-background-color, var(--theme-background, var(--jp-layout-color0, #111)));\n",
" --sklearn-color-border-box: var(--sg-text-color, var(--theme-code-foreground, var(--jp-content-font-color1, white)));\n",
" --sklearn-color-icon: #878787;\n",
" }\n",
"}\n",
"\n",
"#sk-container-id-9 {\n",
" color: var(--sklearn-color-text);\n",
"}\n",
"\n",
"#sk-container-id-9 pre {\n",
" padding: 0;\n",
"}\n",
"\n",
"#sk-container-id-9 input.sk-hidden--visually {\n",
" border: 0;\n",
" clip: rect(1px 1px 1px 1px);\n",
" clip: rect(1px, 1px, 1px, 1px);\n",
" height: 1px;\n",
" margin: -1px;\n",
" overflow: hidden;\n",
" padding: 0;\n",
" position: absolute;\n",
" width: 1px;\n",
"}\n",
"\n",
"#sk-container-id-9 div.sk-dashed-wrapped {\n",
" border: 1px dashed var(--sklearn-color-line);\n",
" margin: 0 0.4em 0.5em 0.4em;\n",
" box-sizing: border-box;\n",
" padding-bottom: 0.4em;\n",
" background-color: var(--sklearn-color-background);\n",
"}\n",
"\n",
"#sk-container-id-9 div.sk-container {\n",
" /* jupyter's `normalize.less` sets `[hidden] { display: none; }`\n",
" but bootstrap.min.css set `[hidden] { display: none !important; }`\n",
" so we also need the `!important` here to be able to override the\n",
" default hidden behavior on the sphinx rendered scikit-learn.org.\n",
" See: https://github.com/scikit-learn/scikit-learn/issues/21755 */\n",
" display: inline-block !important;\n",
" position: relative;\n",
"}\n",
"\n",
"#sk-container-id-9 div.sk-text-repr-fallback {\n",
" display: none;\n",
"}\n",
"\n",
"div.sk-parallel-item,\n",
"div.sk-serial,\n",
"div.sk-item {\n",
" /* draw centered vertical line to link estimators */\n",
" background-image: linear-gradient(var(--sklearn-color-text-on-default-background), var(--sklearn-color-text-on-default-background));\n",
" background-size: 2px 100%;\n",
" background-repeat: no-repeat;\n",
" background-position: center center;\n",
"}\n",
"\n",
"/* Parallel-specific style estimator block */\n",
"\n",
"#sk-container-id-9 div.sk-parallel-item::after {\n",
" content: \"\";\n",
" width: 100%;\n",
" border-bottom: 2px solid var(--sklearn-color-text-on-default-background);\n",
" flex-grow: 1;\n",
"}\n",
"\n",
"#sk-container-id-9 div.sk-parallel {\n",
" display: flex;\n",
" align-items: stretch;\n",
" justify-content: center;\n",
" background-color: var(--sklearn-color-background);\n",
" position: relative;\n",
"}\n",
"\n",
"#sk-container-id-9 div.sk-parallel-item {\n",
" display: flex;\n",
" flex-direction: column;\n",
"}\n",
"\n",
"#sk-container-id-9 div.sk-parallel-item:first-child::after {\n",
" align-self: flex-end;\n",
" width: 50%;\n",
"}\n",
"\n",
"#sk-container-id-9 div.sk-parallel-item:last-child::after {\n",
" align-self: flex-start;\n",
" width: 50%;\n",
"}\n",
"\n",
"#sk-container-id-9 div.sk-parallel-item:only-child::after {\n",
" width: 0;\n",
"}\n",
"\n",
"/* Serial-specific style estimator block */\n",
"\n",
"#sk-container-id-9 div.sk-serial {\n",
" display: flex;\n",
" flex-direction: column;\n",
" align-items: center;\n",
" background-color: var(--sklearn-color-background);\n",
" padding-right: 1em;\n",
" padding-left: 1em;\n",
"}\n",
"\n",
"\n",
"/* Toggleable style: style used for estimator/Pipeline/ColumnTransformer box that is\n",
"clickable and can be expanded/collapsed.\n",
"- Pipeline and ColumnTransformer use this feature and define the default style\n",
"- Estimators will overwrite some part of the style using the `sk-estimator` class\n",
"*/\n",
"\n",
"/* Pipeline and ColumnTransformer style (default) */\n",
"\n",
"#sk-container-id-9 div.sk-toggleable {\n",
" /* Default theme specific background. It is overwritten whether we have a\n",
" specific estimator or a Pipeline/ColumnTransformer */\n",
" background-color: var(--sklearn-color-background);\n",
"}\n",
"\n",
"/* Toggleable label */\n",
"#sk-container-id-9 label.sk-toggleable__label {\n",
" cursor: pointer;\n",
" display: block;\n",
" width: 100%;\n",
" margin-bottom: 0;\n",
" padding: 0.5em;\n",
" box-sizing: border-box;\n",
" text-align: center;\n",
"}\n",
"\n",
"#sk-container-id-9 label.sk-toggleable__label-arrow:before {\n",
" /* Arrow on the left of the label */\n",
" content: \"▸\";\n",
" float: left;\n",
" margin-right: 0.25em;\n",
" color: var(--sklearn-color-icon);\n",
"}\n",
"\n",
"#sk-container-id-9 label.sk-toggleable__label-arrow:hover:before {\n",
" color: var(--sklearn-color-text);\n",
"}\n",
"\n",
"/* Toggleable content - dropdown */\n",
"\n",
"#sk-container-id-9 div.sk-toggleable__content {\n",
" max-height: 0;\n",
" max-width: 0;\n",
" overflow: hidden;\n",
" text-align: left;\n",
" /* unfitted */\n",
" background-color: var(--sklearn-color-unfitted-level-0);\n",
"}\n",
"\n",
"#sk-container-id-9 div.sk-toggleable__content.fitted {\n",
" /* fitted */\n",
" background-color: var(--sklearn-color-fitted-level-0);\n",
"}\n",
"\n",
"#sk-container-id-9 div.sk-toggleable__content pre {\n",
" margin: 0.2em;\n",
" border-radius: 0.25em;\n",
" color: var(--sklearn-color-text);\n",
" /* unfitted */\n",
" background-color: var(--sklearn-color-unfitted-level-0);\n",
"}\n",
"\n",
"#sk-container-id-9 div.sk-toggleable__content.fitted pre {\n",
" /* unfitted */\n",
" background-color: var(--sklearn-color-fitted-level-0);\n",
"}\n",
"\n",
"#sk-container-id-9 input.sk-toggleable__control:checked~div.sk-toggleable__content {\n",
" /* Expand drop-down */\n",
" max-height: 200px;\n",
" max-width: 100%;\n",
" overflow: auto;\n",
"}\n",
"\n",
"#sk-container-id-9 input.sk-toggleable__control:checked~label.sk-toggleable__label-arrow:before {\n",
" content: \"▾\";\n",
"}\n",
"\n",
"/* Pipeline/ColumnTransformer-specific style */\n",
"\n",
"#sk-container-id-9 div.sk-label input.sk-toggleable__control:checked~label.sk-toggleable__label {\n",
" color: var(--sklearn-color-text);\n",
" background-color: var(--sklearn-color-unfitted-level-2);\n",
"}\n",
"\n",
"#sk-container-id-9 div.sk-label.fitted input.sk-toggleable__control:checked~label.sk-toggleable__label {\n",
" background-color: var(--sklearn-color-fitted-level-2);\n",
"}\n",
"\n",
"/* Estimator-specific style */\n",
"\n",
"/* Colorize estimator box */\n",
"#sk-container-id-9 div.sk-estimator input.sk-toggleable__control:checked~label.sk-toggleable__label {\n",
" /* unfitted */\n",
" background-color: var(--sklearn-color-unfitted-level-2);\n",
"}\n",
"\n",
"#sk-container-id-9 div.sk-estimator.fitted input.sk-toggleable__control:checked~label.sk-toggleable__label {\n",
" /* fitted */\n",
" background-color: var(--sklearn-color-fitted-level-2);\n",
"}\n",
"\n",
"#sk-container-id-9 div.sk-label label.sk-toggleable__label,\n",
"#sk-container-id-9 div.sk-label label {\n",
" /* The background is the default theme color */\n",
" color: var(--sklearn-color-text-on-default-background);\n",
"}\n",
"\n",
"/* On hover, darken the color of the background */\n",
"#sk-container-id-9 div.sk-label:hover label.sk-toggleable__label {\n",
" color: var(--sklearn-color-text);\n",
" background-color: var(--sklearn-color-unfitted-level-2);\n",
"}\n",
"\n",
"/* Label box, darken color on hover, fitted */\n",
"#sk-container-id-9 div.sk-label.fitted:hover label.sk-toggleable__label.fitted {\n",
" color: var(--sklearn-color-text);\n",
" background-color: var(--sklearn-color-fitted-level-2);\n",
"}\n",
"\n",
"/* Estimator label */\n",
"\n",
"#sk-container-id-9 div.sk-label label {\n",
" font-family: monospace;\n",
" font-weight: bold;\n",
" display: inline-block;\n",
" line-height: 1.2em;\n",
"}\n",
"\n",
"#sk-container-id-9 div.sk-label-container {\n",
" text-align: center;\n",
"}\n",
"\n",
"/* Estimator-specific */\n",
"#sk-container-id-9 div.sk-estimator {\n",
" font-family: monospace;\n",
" border: 1px dotted var(--sklearn-color-border-box);\n",
" border-radius: 0.25em;\n",
" box-sizing: border-box;\n",
" margin-bottom: 0.5em;\n",
" /* unfitted */\n",
" background-color: var(--sklearn-color-unfitted-level-0);\n",
"}\n",
"\n",
"#sk-container-id-9 div.sk-estimator.fitted {\n",
" /* fitted */\n",
" background-color: var(--sklearn-color-fitted-level-0);\n",
"}\n",
"\n",
"/* on hover */\n",
"#sk-container-id-9 div.sk-estimator:hover {\n",
" /* unfitted */\n",
" background-color: var(--sklearn-color-unfitted-level-2);\n",
"}\n",
"\n",
"#sk-container-id-9 div.sk-estimator.fitted:hover {\n",
" /* fitted */\n",
" background-color: var(--sklearn-color-fitted-level-2);\n",
"}\n",
"\n",
"/* Specification for estimator info (e.g. \"i\" and \"?\") */\n",
"\n",
"/* Common style for \"i\" and \"?\" */\n",
"\n",
".sk-estimator-doc-link,\n",
"a:link.sk-estimator-doc-link,\n",
"a:visited.sk-estimator-doc-link {\n",
" float: right;\n",
" font-size: smaller;\n",
" line-height: 1em;\n",
" font-family: monospace;\n",
" background-color: var(--sklearn-color-background);\n",
" border-radius: 1em;\n",
" height: 1em;\n",
" width: 1em;\n",
" text-decoration: none !important;\n",
" margin-left: 1ex;\n",
" /* unfitted */\n",
" border: var(--sklearn-color-unfitted-level-1) 1pt solid;\n",
" color: var(--sklearn-color-unfitted-level-1);\n",
"}\n",
"\n",
".sk-estimator-doc-link.fitted,\n",
"a:link.sk-estimator-doc-link.fitted,\n",
"a:visited.sk-estimator-doc-link.fitted {\n",
" /* fitted */\n",
" border: var(--sklearn-color-fitted-level-1) 1pt solid;\n",
" color: var(--sklearn-color-fitted-level-1);\n",
"}\n",
"\n",
"/* On hover */\n",
"div.sk-estimator:hover .sk-estimator-doc-link:hover,\n",
".sk-estimator-doc-link:hover,\n",
"div.sk-label-container:hover .sk-estimator-doc-link:hover,\n",
".sk-estimator-doc-link:hover {\n",
" /* unfitted */\n",
" background-color: var(--sklearn-color-unfitted-level-3);\n",
" color: var(--sklearn-color-background);\n",
" text-decoration: none;\n",
"}\n",
"\n",
"div.sk-estimator.fitted:hover .sk-estimator-doc-link.fitted:hover,\n",
".sk-estimator-doc-link.fitted:hover,\n",
"div.sk-label-container:hover .sk-estimator-doc-link.fitted:hover,\n",
".sk-estimator-doc-link.fitted:hover {\n",
" /* fitted */\n",
" background-color: var(--sklearn-color-fitted-level-3);\n",
" color: var(--sklearn-color-background);\n",
" text-decoration: none;\n",
"}\n",
"\n",
"/* Span, style for the box shown on hovering the info icon */\n",
".sk-estimator-doc-link span {\n",
" display: none;\n",
" z-index: 9999;\n",
" position: relative;\n",
" font-weight: normal;\n",
" right: .2ex;\n",
" padding: .5ex;\n",
" margin: .5ex;\n",
" width: min-content;\n",
" min-width: 20ex;\n",
" max-width: 50ex;\n",
" color: var(--sklearn-color-text);\n",
" box-shadow: 2pt 2pt 4pt #999;\n",
" /* unfitted */\n",
" background: var(--sklearn-color-unfitted-level-0);\n",
" border: .5pt solid var(--sklearn-color-unfitted-level-3);\n",
"}\n",
"\n",
".sk-estimator-doc-link.fitted span {\n",
" /* fitted */\n",
" background: var(--sklearn-color-fitted-level-0);\n",
" border: var(--sklearn-color-fitted-level-3);\n",
"}\n",
"\n",
".sk-estimator-doc-link:hover span {\n",
" display: block;\n",
"}\n",
"\n",
"/* \"?\"-specific style due to the `<a>` HTML tag */\n",
"\n",
"#sk-container-id-9 a.estimator_doc_link {\n",
" float: right;\n",
" font-size: 1rem;\n",
" line-height: 1em;\n",
" font-family: monospace;\n",
" background-color: var(--sklearn-color-background);\n",
" border-radius: 1rem;\n",
" height: 1rem;\n",
" width: 1rem;\n",
" text-decoration: none;\n",
" /* unfitted */\n",
" color: var(--sklearn-color-unfitted-level-1);\n",
" border: var(--sklearn-color-unfitted-level-1) 1pt solid;\n",
"}\n",
"\n",
"#sk-container-id-9 a.estimator_doc_link.fitted {\n",
" /* fitted */\n",
" border: var(--sklearn-color-fitted-level-1) 1pt solid;\n",
" color: var(--sklearn-color-fitted-level-1);\n",
"}\n",
"\n",
"/* On hover */\n",
"#sk-container-id-9 a.estimator_doc_link:hover {\n",
" /* unfitted */\n",
" background-color: var(--sklearn-color-unfitted-level-3);\n",
" color: var(--sklearn-color-background);\n",
" text-decoration: none;\n",
"}\n",
"\n",
"#sk-container-id-9 a.estimator_doc_link.fitted:hover {\n",
" /* fitted */\n",
" background-color: var(--sklearn-color-fitted-level-3);\n",
"}\n",
"</style><div id=\"sk-container-id-9\" class=\"sk-top-container\"><div class=\"sk-text-repr-fallback\"><pre>LogisticRegression(max_iter=10000, multi_class=&#x27;multinomial&#x27;, n_jobs=10,\n",
" solver=&#x27;saga&#x27;)</pre><b>In a Jupyter environment, please rerun this cell to show the HTML representation or trust the notebook. <br />On GitHub, the HTML representation is unable to render, please try loading this page with nbviewer.org.</b></div><div class=\"sk-container\" hidden><div class=\"sk-item\"><div class=\"sk-estimator fitted sk-toggleable\"><input class=\"sk-toggleable__control sk-hidden--visually\" id=\"sk-estimator-id-9\" type=\"checkbox\" checked><label for=\"sk-estimator-id-9\" class=\"sk-toggleable__label fitted sk-toggleable__label-arrow fitted\">&nbsp;&nbsp;LogisticRegression<a class=\"sk-estimator-doc-link fitted\" rel=\"noreferrer\" target=\"_blank\" href=\"https://scikit-learn.org/1.4/modules/generated/sklearn.linear_model.LogisticRegression.html\">?<span>Documentation for LogisticRegression</span></a><span class=\"sk-estimator-doc-link fitted\">i<span>Fitted</span></span></label><div class=\"sk-toggleable__content fitted\"><pre>LogisticRegression(max_iter=10000, multi_class=&#x27;multinomial&#x27;, n_jobs=10,\n",
" solver=&#x27;saga&#x27;)</pre></div> </div></div></div></div>"
],
"text/plain": [
"LogisticRegression(max_iter=10000, multi_class='multinomial', n_jobs=10,\n",
" solver='saga')"
] ]
},
"execution_count": 18,
"metadata": {},
"output_type": "execute_result"
} }
], ],
"source": [ "source": [
"model = LogisticRegression(multi_class=\"multinomial\", solver=\"saga\", max_iter=10_000, n_jobs=10)\n", "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
"model.fit(X_train, y_train)" "print(f\"Using device: {device}\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Create new model which predicts probability"
] ]
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 19, "execution_count": 81,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
"y_pred = model.predict(X_test)" "model = MLP().to(device)\n",
"optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)\n",
"criterion = nn.CrossEntropyLoss()"
]
},
{
"cell_type": "code",
"execution_count": 82,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 1, Loss: 4277.8506\n",
"Epoch 2, Loss: 3647.3064\n",
"Epoch 3, Loss: 3421.2898\n",
"Epoch 4, Loss: 3289.9248\n",
"Epoch 5, Loss: 3203.0331\n",
"Epoch 6, Loss: 3141.4064\n",
"Epoch 7, Loss: 3099.4711\n",
"Epoch 8, Loss: 3065.2254\n",
"Epoch 9, Loss: 3040.1093\n",
"Epoch 10, Loss: 3016.0812\n",
"Epoch 11, Loss: 2998.2589\n",
"Epoch 12, Loss: 2982.5763\n",
"Epoch 13, Loss: 2968.7752\n",
"Epoch 14, Loss: 2956.6091\n",
"Epoch 15, Loss: 2945.3793\n",
"Epoch 16, Loss: 2935.6520\n",
"Epoch 17, Loss: 2928.2420\n",
"Epoch 18, Loss: 2918.6128\n",
"Epoch 19, Loss: 2912.0454\n",
"Epoch 20, Loss: 2904.7236\n",
"Epoch 21, Loss: 2898.5873\n",
"Epoch 22, Loss: 2893.1154\n",
"Epoch 23, Loss: 2887.1008\n",
"Epoch 24, Loss: 2884.5473\n",
"Epoch 25, Loss: 2879.1589\n",
"Epoch 26, Loss: 2874.9795\n",
"Epoch 27, Loss: 2870.3030\n",
"Epoch 28, Loss: 2867.0953\n",
"Epoch 29, Loss: 2863.1449\n",
"Epoch 30, Loss: 2859.8749\n"
]
}
],
"source": [
"for epoch in range(30):\n",
" model.train()\n",
" total_loss = 0\n",
" for batch_X, batch_y in train_loader:\n",
" batch_X, batch_y = batch_X.to(device), batch_y.to(device)\n",
" optimizer.zero_grad()\n",
" output = model(batch_X)\n",
" loss = criterion(output, batch_y)\n",
" loss.backward()\n",
" optimizer.step()\n",
" total_loss += loss.item()\n",
" print(f\"Epoch {epoch+1}, Loss: {total_loss:.4f}\")"
] ]
}, },
{ {
@ -532,43 +230,56 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 20, "execution_count": 83,
"metadata": {},
"outputs": [],
"source": [
"from sklearn.metrics import accuracy_score"
]
},
{
"cell_type": "code",
"execution_count": 21,
"metadata": {},
"outputs": [],
"source": [
"acc = accuracy_score(y_test, y_pred)"
]
},
{
"cell_type": "code",
"execution_count": 22,
"metadata": {}, "metadata": {},
"outputs": [ "outputs": [
{ {
"name": "stdout", "name": "stdout",
"output_type": "stream", "output_type": "stream",
"text": [ "text": [
"Accuracy: 0.211\n" "Top-1 Accuracy: 51.27%\n",
"Top-3 Accuracy: 73.68%\n",
"Top-5 Accuracy: 82.94%\n"
] ]
} }
], ],
"source": [ "source": [
"print(f\"Accuracy: {acc:.3f}\")" "model.eval()\n",
"correct_top1 = 0\n",
"correct_top3 = 0\n",
"correct_top5 = 0\n",
"total = 0\n",
"\n",
"with torch.no_grad():\n",
" for batch_X, batch_y in test_loader:\n",
" batch_X, batch_y = batch_X.to(device), batch_y.to(device)\n",
" outputs = model(batch_X) # shape: [batch_size, 26]\n",
"\n",
" # Get top-5 predictions\n",
" _, top_preds = outputs.topk(5, dim=1) # shape: [batch_size, 5]\n",
"\n",
" for true, top5 in zip(batch_y, top_preds):\n",
" total += 1\n",
" if true == top5[0]:\n",
" correct_top1 += 1\n",
" if true in top5[:3]:\n",
" correct_top3 += 1\n",
" if true in top5:\n",
" correct_top5 += 1\n",
"\n",
"top1_acc = correct_top1 / total\n",
"top3_acc = correct_top3 / total\n",
"top5_acc = correct_top5 / total\n",
"\n",
"print(f\"Top-1 Accuracy: {top1_acc * 100:.2f}%\")\n",
"print(f\"Top-3 Accuracy: {top3_acc * 100:.2f}%\")\n",
"print(f\"Top-5 Accuracy: {top5_acc * 100:.2f}%\")\n"
] ]
} }
], ],
"metadata": { "metadata": {
"kernelspec": { "kernelspec": {
"display_name": "Python 3", "display_name": ".venv",
"language": "python", "language": "python",
"name": "python3" "name": "python3"
}, },

70
transform.py Executable file
View File

@ -0,0 +1,70 @@
#!/usr/bin/env python3
from typing import Literal, List, Dict
import numpy as np
INPUT_FILE: str = "./data/all_cleaned_words.txt"
OUTPUT_FILE: str = "./data.npy"
alphabet: List[str] = list("abcdefghijklmnopqrstuvwxyz")
vowels: set[str] = set("aeiouy")
char_to_onehot: Dict[str, List[int]] = {
ch: [1 if i == idx else 0 for i in range(26)]
for idx, ch in enumerate(alphabet)
}
empty_vec: List[int] = [0] * 26
def get_prev_type(c: str) -> Literal[0, 1, 2]:
if c in vowels:
return 1
elif c in alphabet:
return 2
return 0
def encode_letter(c: str) -> List[int]:
return char_to_onehot.get(c, empty_vec)
def build_dataset(input_path: str) -> np.ndarray:
all_features: List[List[int]] = []
with open(input_path, 'r') as input_file:
for line in input_file:
word: str = line.strip().lower()
prev_chars: List[str] = [""] * 10 # Updated: now 10-character context
for i, curr_char in enumerate(word):
features: List[int] = []
# One-hot encode 10 previous characters
for prev in prev_chars:
features.extend(encode_letter(prev))
# One-hot encode current character
features.extend(encode_letter(curr_char))
# Word position features
is_start: int = 1 if i == 0 else 0
features.append(is_start)
prev1: str = prev_chars[-1]
features.append(get_prev_type(prev1))
word_length: int = i + 1
features.append(word_length)
all_features.append(features)
# Shift history
prev_chars = prev_chars[1:] + [curr_char]
return np.array(all_features, dtype=np.int32)
if __name__ == "__main__":
dataset: np.ndarray = build_dataset(INPUT_FILE)
np.save(OUTPUT_FILE, dataset)
print(f"Saved dataset shape: {dataset.shape}{OUTPUT_FILE}")