More tweaks to the model; most likely the final iteration, at 52% top-1 accuracy
parent e44a6cf86e
commit 09b1baf485
BIN mlp_full_model.pth (new file)
Binary file not shown.
BIN mlp_weights.pth (new file)
Binary file not shown.
notebook.ipynb (215 changed lines)
@@ -9,7 +9,7 @@
 },
 {
 "cell_type": "code",
-"execution_count": 10,
+"execution_count": 1,
 "metadata": {},
 "outputs": [],
 "source": [
@@ -21,7 +21,7 @@
 },
 {
 "cell_type": "code",
-"execution_count": 50,
+"execution_count": 2,
 "metadata": {},
 "outputs": [],
 "source": [
@@ -30,7 +30,7 @@
 },
 {
 "cell_type": "code",
-"execution_count": 54,
+"execution_count": 3,
 "metadata": {},
 "outputs": [],
 "source": [
@@ -39,16 +39,11 @@
 "ALPHABET_SIZE = len(ALPHABET)\n",
 "TRAINING_DATA_SIZE = 0.9\n",
 "\n",
-"# Derived values\n",
-"PREV_LETTER_FEATURES = CONTEXT_SIZE * ALPHABET_SIZE\n",
-"CURR_LETTER_FEATURES = ALPHABET_SIZE\n",
-"OTHER_FEATURES = 3 # is_start, prev_type, word_length\n",
-"\n",
-"TOTAL_FEATURES = PREV_LETTER_FEATURES + CURR_LETTER_FEATURES + OTHER_FEATURES\n",
-"\n",
-"INPUT_SIZE = PREV_LETTER_FEATURES + OTHER_FEATURES\n",
-"OUTPUT_SIZE = ALPHABET_SIZE"
+"VOCAB_SIZE = ALPHABET_SIZE + 1 # 26 letters + 1 for unknown\n",
+"EMBEDDING_DIM = 16\n",
+"\n",
+"INPUT_SEQ_LEN = CONTEXT_SIZE\n",
+"OUTPUT_SIZE = VOCAB_SIZE"
 ]
 },
 {
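Note: the hunk above switches the input representation from concatenated one-hot vectors to token indices for an embedding layer, which shrinks the effective feature width. A minimal sketch of the resulting shape arithmetic (CONTEXT_SIZE = 10 is assumed here, as used in the prediction cell later in this diff):

# Shape arithmetic for the embedding-based variant.
ALPHABET_SIZE = 26
VOCAB_SIZE = ALPHABET_SIZE + 1    # 26 letters + 1 unknown/padding -> 27
EMBEDDING_DIM = 16
CONTEXT_SIZE = 10                 # assumed from the prediction cell below

INPUT_SEQ_LEN = CONTEXT_SIZE                       # 10 indices per sample
flattened = CONTEXT_SIZE * EMBEDDING_DIM           # 160, vs 10 * 26 = 260 one-hot
OUTPUT_SIZE = VOCAB_SIZE                           # 27-way classification
print(INPUT_SEQ_LEN, flattened, OUTPUT_SIZE)       # 10 160 27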
@@ -67,29 +62,25 @@
 },
 {
 "cell_type": "code",
-"execution_count": 52,
+"execution_count": 4,
 "metadata": {},
 "outputs": [],
 "source": [
-"X = np.hstack([\n",
-"    data[:, :PREV_LETTER_FEATURES],\n",
-"    data[:, PREV_LETTER_FEATURES + CURR_LETTER_FEATURES:TOTAL_FEATURES]\n",
-"])\n",
+"X = data[:, :CONTEXT_SIZE] # shape: (num_samples, CONTEXT_SIZE)\n",
 "\n",
-"# Extract current letter (one-hot target)\n",
-"y_onehot = data[:, PREV_LETTER_FEATURES:PREV_LETTER_FEATURES + CURR_LETTER_FEATURES]\n",
-"y = np.argmax(y_onehot, axis=1)\n",
+"# Target: current letter index\n",
+"y = data[:, CONTEXT_SIZE] # shape: (num_samples,)\n",
 "\n",
-"# Torch dataset\n",
-"X_tensor = torch.tensor(X, dtype=torch.float32)\n",
-"y_tensor = torch.tensor(y, dtype=torch.long)\n",
+"# Torch dataset (important: use long/int64 for indices)\n",
+"X_tensor = torch.tensor(X, dtype=torch.long) # for nn.Embedding\n",
+"y_tensor = torch.tensor(y, dtype=torch.long) # for classification target\n",
 "\n",
-"dataset = TensorDataset(X_tensor, y_tensor)\n"
+"dataset = TensorDataset(X_tensor, y_tensor)"
 ]
 },
 {
 "cell_type": "code",
-"execution_count": 55,
+"execution_count": 5,
 "metadata": {},
 "outputs": [],
 "source": [
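Note: with the index encoding, each row of `data` holds the ten previous-letter indices followed by the current-letter index, so plain slicing replaces the old hstack/argmax logic; `nn.Embedding` requires int64 input, hence `torch.long`. A self-contained sketch under those assumptions (the random `data` array stands in for `np.load("data.npy")`):

import numpy as np
import torch
from torch.utils.data import TensorDataset

CONTEXT_SIZE = 10

# Hypothetical stand-in for the real data: each row is
# [10 previous-letter indices..., current-letter index]
data = np.random.randint(0, 27, size=(5, CONTEXT_SIZE + 1))

X = data[:, :CONTEXT_SIZE]   # (num_samples, CONTEXT_SIZE)
y = data[:, CONTEXT_SIZE]    # (num_samples,)

# nn.Embedding indexes a lookup table, so inputs must be long (int64);
# CrossEntropyLoss likewise expects long class indices as targets.
X_tensor = torch.tensor(X, dtype=torch.long)
y_tensor = torch.tensor(y, dtype=torch.long)
dataset = TensorDataset(X_tensor, y_tensor)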
@@ -99,7 +90,7 @@
 },
 {
 "cell_type": "code",
-"execution_count": 56,
+"execution_count": 6,
 "metadata": {},
 "outputs": [],
 "source": [
@@ -116,7 +107,7 @@
 },
 {
 "cell_type": "code",
-"execution_count": 79,
+"execution_count": null,
 "metadata": {},
 "outputs": [],
 "source": [
@@ -124,7 +115,9 @@
 "    def __init__(self):\n",
 "        super().__init__()\n",
 "        self.net = nn.Sequential(\n",
-"            nn.Linear(INPUT_SIZE, 256),\n",
+"            nn.Embedding(num_embeddings=VOCAB_SIZE, embedding_dim=EMBEDDING_DIM),\n",
+"            nn.Flatten(),\n",
+"            nn.Linear(CONTEXT_SIZE * EMBEDDING_DIM, 256),\n",
 "            nn.ReLU(),\n",
 "            nn.Linear(256, 128),\n",
 "            nn.ReLU(),\n",
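Note: pieced together from the fragments above, the new network presumably looks like the sketch below; the final 128 -> OUTPUT_SIZE layer is an assumption, since the hunk ends at the second ReLU:

import torch.nn as nn

VOCAB_SIZE, EMBEDDING_DIM, CONTEXT_SIZE, OUTPUT_SIZE = 27, 16, 10, 27

class MLP(nn.Module):
    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(
            # (batch, 10) long indices -> (batch, 10, 16) float vectors
            nn.Embedding(num_embeddings=VOCAB_SIZE, embedding_dim=EMBEDDING_DIM),
            nn.Flatten(),                                  # -> (batch, 160)
            nn.Linear(CONTEXT_SIZE * EMBEDDING_DIM, 256),
            nn.ReLU(),
            nn.Linear(256, 128),
            nn.ReLU(),
            nn.Linear(128, OUTPUT_SIZE),                   # assumed final layer
        )

    def forward(self, x):
        return self.net(x)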
@@ -137,14 +130,22 @@
 },
 {
 "cell_type": "code",
-"execution_count": 80,
+"execution_count": 8,
 "metadata": {},
 "outputs": [
 {
 "name": "stdout",
 "output_type": "stream",
 "text": [
-"Using device: cuda\n"
+"Using device: cpu\n"
 ]
 },
+{
+"name": "stderr",
+"output_type": "stream",
+"text": [
+"/home/thastertyn/Code/Skola/4-rocnik/programove-vybaveni/omega/.venv/lib/python3.12/site-packages/torch/cuda/__init__.py:129: UserWarning: CUDA initialization: CUDA unknown error - this may be due to an incorrectly set up environment, e.g. changing env variable CUDA_VISIBLE_DEVICES after program start. Setting the available devices to be zero. (Triggered internally at /pytorch/c10/cuda/CUDAFunctions.cpp:109.)\n",
+" return torch._C._cuda_getDeviceCount() > 0\n"
+]
+}
 ],
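Note: the stdout/stderr pair above is what the standard device-probe idiom produces when CUDA fails to initialize: `torch.cuda.is_available()` swallows the error, emits the warning on stderr, and the run falls back to CPU. A sketch of that idiom (assuming this is roughly what the cell contains):

import torch

# Falls back to CPU when CUDA cannot initialize, as in the run above.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")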
@@ -155,54 +156,55 @@
 },
 {
 "cell_type": "code",
-"execution_count": 81,
+"execution_count": 9,
 "metadata": {},
 "outputs": [],
 "source": [
 "model = MLP().to(device)\n",
 "optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)\n",
 "\n",
 "criterion = nn.CrossEntropyLoss()"
 ]
 },
 {
 "cell_type": "code",
-"execution_count": 82,
+"execution_count": 10,
 "metadata": {},
 "outputs": [
 {
 "name": "stdout",
 "output_type": "stream",
 "text": [
-"Epoch 1, Loss: 4277.8506\n",
-"Epoch 2, Loss: 3647.3064\n",
-"Epoch 3, Loss: 3421.2898\n",
-"Epoch 4, Loss: 3289.9248\n",
-"Epoch 5, Loss: 3203.0331\n",
-"Epoch 6, Loss: 3141.4064\n",
-"Epoch 7, Loss: 3099.4711\n",
-"Epoch 8, Loss: 3065.2254\n",
-"Epoch 9, Loss: 3040.1093\n",
-"Epoch 10, Loss: 3016.0812\n",
-"Epoch 11, Loss: 2998.2589\n",
-"Epoch 12, Loss: 2982.5763\n",
-"Epoch 13, Loss: 2968.7752\n",
-"Epoch 14, Loss: 2956.6091\n",
-"Epoch 15, Loss: 2945.3793\n",
-"Epoch 16, Loss: 2935.6520\n",
-"Epoch 17, Loss: 2928.2420\n",
-"Epoch 18, Loss: 2918.6128\n",
-"Epoch 19, Loss: 2912.0454\n",
-"Epoch 20, Loss: 2904.7236\n",
-"Epoch 21, Loss: 2898.5873\n",
-"Epoch 22, Loss: 2893.1154\n",
-"Epoch 23, Loss: 2887.1008\n",
-"Epoch 24, Loss: 2884.5473\n",
-"Epoch 25, Loss: 2879.1589\n",
-"Epoch 26, Loss: 2874.9795\n",
-"Epoch 27, Loss: 2870.3030\n",
-"Epoch 28, Loss: 2867.0953\n",
-"Epoch 29, Loss: 2863.1449\n",
-"Epoch 30, Loss: 2859.8749\n"
+"Epoch 1, Loss: 4068.5562\n",
+"Epoch 2, Loss: 3446.1109\n",
+"Epoch 3, Loss: 3260.1651\n",
+"Epoch 4, Loss: 3165.0248\n",
+"Epoch 5, Loss: 3101.6501\n",
+"Epoch 6, Loss: 3054.4113\n",
+"Epoch 7, Loss: 3021.7103\n",
+"Epoch 8, Loss: 2994.6145\n",
+"Epoch 9, Loss: 2973.1683\n",
+"Epoch 10, Loss: 2955.0090\n",
+"Epoch 11, Loss: 2940.0807\n",
+"Epoch 12, Loss: 2928.2814\n",
+"Epoch 13, Loss: 2916.9362\n",
+"Epoch 14, Loss: 2905.9567\n",
+"Epoch 15, Loss: 2897.3687\n",
+"Epoch 16, Loss: 2890.6869\n",
+"Epoch 17, Loss: 2882.7104\n",
+"Epoch 18, Loss: 2876.6815\n",
+"Epoch 19, Loss: 2870.7298\n",
+"Epoch 20, Loss: 2865.6343\n",
+"Epoch 21, Loss: 2860.5506\n",
+"Epoch 22, Loss: 2856.7977\n",
+"Epoch 23, Loss: 2852.8814\n",
+"Epoch 24, Loss: 2847.7687\n",
+"Epoch 25, Loss: 2846.0855\n",
+"Epoch 26, Loss: 2842.2640\n",
+"Epoch 27, Loss: 2838.4780\n",
+"Epoch 28, Loss: 2836.9773\n",
+"Epoch 29, Loss: 2833.8416\n",
+"Epoch 30, Loss: 2830.5508\n"
 ]
 }
 ],
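Note: judging by their magnitude, the per-epoch numbers are batch losses summed over the epoch rather than averaged. The training cell is not shown in this hunk, but a loop shaped roughly like the following would produce that output (the batch size is an assumption; `model`, `optimizer`, `criterion`, `dataset`, and `device` come from the cells above):

from torch.utils.data import DataLoader

loader = DataLoader(dataset, batch_size=64, shuffle=True)  # batch size assumed

for epoch in range(30):
    total_loss = 0.0
    for xb, yb in loader:
        xb, yb = xb.to(device), yb.to(device)
        optimizer.zero_grad()
        loss = criterion(model(xb), yb)   # CrossEntropyLoss over 27 classes
        loss.backward()
        optimizer.step()
        total_loss += loss.item()         # summed per epoch, not averaged
    print(f"Epoch {epoch + 1}, Loss: {total_loss:.4f}")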
@@ -230,16 +232,16 @@
 },
 {
 "cell_type": "code",
-"execution_count": 83,
+"execution_count": 12,
 "metadata": {},
 "outputs": [
 {
 "name": "stdout",
 "output_type": "stream",
 "text": [
-"Top-1 Accuracy: 51.27%\n",
-"Top-3 Accuracy: 73.68%\n",
-"Top-5 Accuracy: 82.94%\n"
+"Top-1 Accuracy: 52.77%\n",
+"Top-3 Accuracy: 74.39%\n",
+"Top-5 Accuracy: 83.37%\n"
 ]
 }
 ],
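Note: top-k accuracy counts a sample as correct when the true class index appears among the k highest logits. A sketch of the computation with `torch.topk` (the evaluation loader and `device` are assumed from the surrounding cells; the notebook's own evaluation cell is not shown in this hunk):

import torch

@torch.no_grad()
def topk_accuracy(model, loader, device, ks=(1, 3, 5)):
    model.eval()
    hits = {k: 0 for k in ks}
    total = 0
    for xb, yb in loader:
        xb, yb = xb.to(device), yb.to(device)
        logits = model(xb)                       # (batch, VOCAB_SIZE)
        _, top = logits.topk(max(ks), dim=-1)    # (batch, max_k) class indices
        for k in ks:
            hits[k] += (top[:, :k] == yb.unsqueeze(1)).any(dim=1).sum().item()
        total += yb.size(0)
    return {k: hits[k] / total for k in ks}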
@@ -275,6 +277,87 @@
 "print(f\"Top-3 Accuracy: {top3_acc * 100:.2f}%\")\n",
 "print(f\"Top-5 Accuracy: {top5_acc * 100:.2f}%\")\n"
 ]
 },
+{
+"cell_type": "code",
+"execution_count": 38,
+"metadata": {},
+"outputs": [],
+"source": [
+"torch.save(model.state_dict(), \"mlp_weights.pth\")"
+]
+},
+{
+"cell_type": "code",
+"execution_count": null,
+"metadata": {},
+"outputs": [],
+"source": [
+"torch.save(model, \"mlp_full_model.pth\")"
+]
+},
+{
+"cell_type": "code",
+"execution_count": 18,
+"metadata": {},
+"outputs": [],
+"source": [
+"# Reuse same alphabet + mapping\n",
+"alphabet = list(\"abcdefghijklmnopqrstuvwxyz\")\n",
+"char_to_idx = {ch: idx for idx, ch in enumerate(alphabet)}\n",
+"PAD_IDX = len(alphabet) # index 26 for OOV/padding\n",
+"VOCAB_SIZE = len(alphabet) + 1 # 27 total (a–z + padding)\n",
+"CONTEXT_SIZE = 10\n",
+"\n",
+"idx_to_char = {idx: ch for ch, idx in char_to_idx.items()}\n",
+"idx_to_char[PAD_IDX] = \"_\" # for readability\n",
+"\n",
+"def preprocess_input(context: str) -> torch.Tensor:\n",
+"    context = context.lower()\n",
+"    padded = context.rjust(CONTEXT_SIZE, \"_\") # pad with underscores (or any 1-char symbol)\n",
+"\n",
+"    indices = []\n",
+"    for ch in padded[-CONTEXT_SIZE:]:\n",
+"        idx = char_to_idx.get(ch, PAD_IDX) # if '_' or unknown → PAD_IDX (26)\n",
+"        indices.append(idx)\n",
+"\n",
+"    return torch.tensor(indices, dtype=torch.long).unsqueeze(0).to(device)\n",
+"\n",
+"\n",
+"def predict_next_chars(model, context: str, top_k=5):\n",
+"    model.eval()\n",
+"    input_tensor = preprocess_input(context)\n",
+"    with torch.no_grad():\n",
+"        logits = model(input_tensor)\n",
+"        probs = torch.softmax(logits, dim=-1)\n",
+"        top_probs, top_indices = probs.topk(top_k, dim=-1)\n",
+"\n",
+"    predictions = [(idx_to_char[idx.item()], top_probs[0, i].item()) for i, idx in enumerate(top_indices[0])]\n",
+"    return predictions\n"
+]
+},
+{
+"cell_type": "code",
+"execution_count": 37,
+"metadata": {},
+"outputs": [
+{
+"name": "stdout",
+"output_type": "stream",
+"text": [
+"A: 0.4302\n",
+"T: 0.2897\n",
+"E: 0.1538\n",
+"I: 0.0905\n",
+"C: 0.0159\n"
+]
+}
+],
+"source": [
+"preds = predict_next_chars(model, \"doors\")\n",
+"for char, prob in preds:\n",
+"    print(f\"{char.upper()}: {prob:.4f}\")\n"
+]
+}
 ],
 "metadata": {
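Note: the two new save cells differ in what they persist. `torch.save(model.state_dict(), ...)` stores only the weight tensors and requires the `MLP` class definition at load time; `torch.save(model, ...)` pickles the whole object and is more fragile across code or module changes. A loading sketch for both (file names from this commit; depending on the PyTorch version, full-model loading may need `weights_only=False`):

import torch

# Option 1: weights only -- needs the MLP class in scope to rebuild the module.
model = MLP()
model.load_state_dict(torch.load("mlp_weights.pth", map_location="cpu"))
model.eval()

# Option 2: full pickled model -- tied to the class/module layout at save time.
model = torch.load("mlp_full_model.pth", map_location="cpu", weights_only=False)
model.eval()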
transform.py (33 changed lines)
@@ -9,11 +9,8 @@ OUTPUT_FILE: str = "./data.npy"
 alphabet: List[str] = list("abcdefghijklmnopqrstuvwxyz")
 vowels: set[str] = set("aeiouy")

-char_to_onehot: Dict[str, List[int]] = {
-    ch: [1 if i == idx else 0 for i in range(26)]
-    for idx, ch in enumerate(alphabet)
-}
-empty_vec: List[int] = [0] * 26
+char_to_index: Dict[str, int] = {ch: idx for idx, ch in enumerate(alphabet)}
+default_index: int = len(alphabet)  # Out-of-vocabulary token (e.g., for "")


 def get_prev_type(c: str) -> Literal[0, 1, 2]:
@@ -24,8 +21,8 @@ def get_prev_type(c: str) -> Literal[0, 1, 2]:
     return 0


-def encode_letter(c: str) -> List[int]:
-    return char_to_onehot.get(c, empty_vec)
+def encode_letter(c: str) -> int:
+    return char_to_index.get(c, default_index)


 def build_dataset(input_path: str) -> np.ndarray:
@@ -34,31 +31,27 @@ def build_dataset(input_path: str) -> np.ndarray:
     with open(input_path, 'r') as input_file:
         for line in input_file:
             word: str = line.strip().lower()
-            prev_chars: List[str] = [""] * 10  # Updated: now 10-character context
+            prev_chars: List[str] = [""] * 10

             for i, curr_char in enumerate(word):
                 features: List[int] = []

-                # One-hot encode 10 previous characters
+                # Use indices instead of one-hot for previous 10 characters
                 for prev in prev_chars:
-                    features.extend(encode_letter(prev))
+                    features.append(encode_letter(prev))

-                # One-hot encode current character
-                features.extend(encode_letter(curr_char))
+                # Append current char index (target for classification)
+                features.append(encode_letter(curr_char))

-                # Word position features
-                is_start: int = 1 if i == 0 else 0
-                features.append(is_start)
-
+                # is_start: int = 1 if i == 0 else 0
                 prev1: str = prev_chars[-1]
-                features.append(get_prev_type(prev1))
-
-                word_length: int = i + 1
-                features.append(word_length)
+                prev_type: int = get_prev_type(prev1)
+                # word_length: int = i + 1

+                # features.extend([prev_type])
                 all_features.append(features)

                 # Shift history
                 prev_chars = prev_chars[1:] + [curr_char]

     return np.array(all_features, dtype=np.int32)
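Note: after this change, each row emitted by build_dataset is eleven small integers — ten context indices followed by the current-letter index — which matches the `data[:, :CONTEXT_SIZE]` / `data[:, CONTEXT_SIZE]` slicing in the notebook. A worked example on the hypothetical word "cab", reusing the same encoding:

alphabet = list("abcdefghijklmnopqrstuvwxyz")
char_to_index = {ch: idx for idx, ch in enumerate(alphabet)}
default_index = len(alphabet)  # 26 for "" / out-of-vocabulary

def encode_letter(c: str) -> int:
    return char_to_index.get(c, default_index)

word = "cab"
prev_chars = [""] * 10
for curr_char in word:
    row = [encode_letter(p) for p in prev_chars] + [encode_letter(curr_char)]
    print(row)  # columns 0-9 feed X, column 10 is the target y
    prev_chars = prev_chars[1:] + [curr_char]

# Prints:
# [26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 2]   'c' with an empty context
# [26, 26, 26, 26, 26, 26, 26, 26, 26, 2, 0]    'a' after 'c'
# [26, 26, 26, 26, 26, 26, 26, 26, 2, 0, 1]     'b' after 'c', 'a'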
transform.sh (deleted, 65 lines)
@@ -1,65 +0,0 @@
-#!/bin/bash
-
-FILE_PATH="./1000_words.txt"
-OUT_FILE_PATH="./out.txt"
-vowels="aeiouy"
-
-
-printf "previous_5,previous_4,previous_3,previous_2,previous_1,current,is_start,previous_type,word_length\n" > "$OUT_FILE_PATH"
-
-while read -r line; do
-    prev_5=""
-    prev_4=""
-    prev_3=""
-    prev_2=""
-    prev_1=""
-
-    for (( i=0; i<"${#line}"; i++ )); do
-        word_length="$((i + 1))"
-        curr="${line:$i:1}"
-
-        # Convert all to lowercase
-        curr_lower=$(echo "$curr" | tr 'A-Z' 'a-z')
-        p1_lower=$(echo "$prev_1" | tr 'A-Z' 'a-z')
-        p2_lower=$(echo "$prev_2" | tr 'A-Z' 'a-z')
-        p3_lower=$(echo "$prev_3" | tr 'A-Z' 'a-z')
-        p4_lower=$(echo "$prev_4" | tr 'A-Z' 'a-z')
-        p5_lower=$(echo "$prev_5" | tr 'A-Z' 'a-z')
-
-        # Convert to ASCII values (default to 0 if empty)
-        val_p5=0; [ -n "$p5_lower" ] && val_p5=$(printf "%d" "'$p5_lower")
-        val_p4=0; [ -n "$p4_lower" ] && val_p4=$(printf "%d" "'$p4_lower")
-        val_p3=0; [ -n "$p3_lower" ] && val_p3=$(printf "%d" "'$p3_lower")
-        val_p2=0; [ -n "$p2_lower" ] && val_p2=$(printf "%d" "'$p2_lower")
-        val_p1=0; [ -n "$p1_lower" ] && val_p1=$(printf "%d" "'$p1_lower")
-        val_curr=$(printf "%d" "'$curr_lower")
-
-        # Determine if this is the start of the word
-        is_start=0
-        [ "$i" -eq 0 ] && is_start=1
-
-        # Determine if prev_1 is vowel or consonant
-        if [[ "$p1_lower" =~ ^[a-z]$ ]]; then
-            if [[ "$vowels" == *"$p1_lower"* ]]; then
-                prev_type="1"
-            else
-                prev_type="2"
-            fi
-        else
-            prev_type="0"
-        fi
-
-        # Output CSV line
-        printf "%d,%d,%d,%d,%d,%d,%d,%d,%d\n" \
-            "$val_p5" "$val_p4" "$val_p3" "$val_p2" "$val_p1" "$val_curr" \
-            "$is_start" "$prev_type" "$word_length" \
-            >> "$OUT_FILE_PATH"
-
-        # Shift history
-        prev_5="$prev_4"
-        prev_4="$prev_3"
-        prev_3="$prev_2"
-        prev_2="$prev_1"
-        prev_1="$curr"
-    done
-done < "$FILE_PATH"