# Import data

In [10]:
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset, random_split

In [50]:
data = np.load("./data.npy")

In [54]:
CONTEXT_SIZE = 10
ALPHABET = list("abcdefghijklmnopqrstuvwxyz")
ALPHABET_SIZE = len(ALPHABET)
TRAINING_DATA_SIZE = 0.9


# Derived values
PREV_LETTER_FEATURES = CONTEXT_SIZE * ALPHABET_SIZE
CURR_LETTER_FEATURES = ALPHABET_SIZE
OTHER_FEATURES = 3  # is_start, prev_type, word_length

TOTAL_FEATURES = PREV_LETTER_FEATURES + CURR_LETTER_FEATURES + OTHER_FEATURES

INPUT_SIZE = PREV_LETTER_FEATURES + OTHER_FEATURES
OUTPUT_SIZE = ALPHABET_SIZE

# Define and split data

## Define input and output columns

In [52]:
X = np.hstack([
    data[:, :PREV_LETTER_FEATURES],
    data[:, PREV_LETTER_FEATURES + CURR_LETTER_FEATURES:TOTAL_FEATURES]
])

# Extract current letter (one-hot target)
y_onehot = data[:, PREV_LETTER_FEATURES:PREV_LETTER_FEATURES + CURR_LETTER_FEATURES]
y = np.argmax(y_onehot, axis=1)

# Torch dataset
X_tensor = torch.tensor(X, dtype=torch.float32)
y_tensor = torch.tensor(y, dtype=torch.long)

dataset = TensorDataset(X_tensor, y_tensor)


In [55]:
train_len = int(TRAINING_DATA_SIZE * len(dataset))
train_set, test_set = random_split(dataset, [train_len, len(dataset) - train_len])

In [56]:
train_loader = DataLoader(train_set, batch_size=128, shuffle=True)
test_loader = DataLoader(test_set, batch_size=128)

# Train on data

In [79]:
class MLP(nn.Module):
    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(INPUT_SIZE, 256),
            nn.ReLU(),
            nn.Linear(256, 128),
            nn.ReLU(),
            nn.Linear(128, OUTPUT_SIZE)
        )

    def forward(self, x):
        return self.net(x)

In [80]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

Using device: cuda


In [81]:
model = MLP().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
criterion = nn.CrossEntropyLoss()

In [82]:
for epoch in range(30):
    model.train()
    total_loss = 0
    for batch_X, batch_y in train_loader:
        batch_X, batch_y = batch_X.to(device), batch_y.to(device)
        optimizer.zero_grad()
        output = model(batch_X)
        loss = criterion(output, batch_y)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    print(f"Epoch {epoch+1}, Loss: {total_loss:.4f}")

Epoch 1, Loss: 4277.8506
Epoch 2, Loss: 3647.3064
Epoch 3, Loss: 3421.2898
Epoch 4, Loss: 3289.9248
Epoch 5, Loss: 3203.0331
Epoch 6, Loss: 3141.4064
Epoch 7, Loss: 3099.4711
Epoch 8, Loss: 3065.2254
Epoch 9, Loss: 3040.1093
Epoch 10, Loss: 3016.0812
Epoch 11, Loss: 2998.2589
Epoch 12, Loss: 2982.5763
Epoch 13, Loss: 2968.7752
Epoch 14, Loss: 2956.6091
Epoch 15, Loss: 2945.3793
Epoch 16, Loss: 2935.6520
Epoch 17, Loss: 2928.2420
Epoch 18, Loss: 2918.6128
Epoch 19, Loss: 2912.0454
Epoch 20, Loss: 2904.7236
Epoch 21, Loss: 2898.5873
Epoch 22, Loss: 2893.1154
Epoch 23, Loss: 2887.1008
Epoch 24, Loss: 2884.5473
Epoch 25, Loss: 2879.1589
Epoch 26, Loss: 2874.9795
Epoch 27, Loss: 2870.3030
Epoch 28, Loss: 2867.0953
Epoch 29, Loss: 2863.1449
Epoch 30, Loss: 2859.8749


# Testing model

In [83]:
model.eval()
correct_top1 = 0
correct_top3 = 0
correct_top5 = 0
total = 0

with torch.no_grad():
    for batch_X, batch_y in test_loader:
        batch_X, batch_y = batch_X.to(device), batch_y.to(device)
        outputs = model(batch_X)  # shape: [batch_size, 26]

        # Get top-5 predictions
        _, top_preds = outputs.topk(5, dim=1)  # shape: [batch_size, 5]

        for true, top5 in zip(batch_y, top_preds):
            total += 1
            if true == top5[0]:
                correct_top1 += 1
            if true in top5[:3]:
                correct_top3 += 1
            if true in top5:
                correct_top5 += 1

top1_acc = correct_top1 / total
top3_acc = correct_top3 / total
top5_acc = correct_top5 / total

print(f"Top-1 Accuracy: {top1_acc * 100:.2f}%")
print(f"Top-3 Accuracy: {top3_acc * 100:.2f}%")
print(f"Top-5 Accuracy: {top5_acc * 100:.2f}%")


Top-1 Accuracy: 51.27%
Top-3 Accuracy: 73.68%
Top-5 Accuracy: 82.94%
