#!/usr/bin/env python3
from typing import Literal, List, Dict
import numpy as np
# Path to the newline-delimited word list consumed by build_dataset().
INPUT_FILE: str = "./data/all_cleaned_words.txt"

# Destination for the encoded feature matrix (np.save writes .npy format).
OUTPUT_FILE: str = "./data.npy"

# Lowercase ASCII alphabet; a letter's index in this list is its one-hot position.
alphabet: list[str] = list("abcdefghijklmnopqrstuvwxyz")

# NOTE: 'y' is deliberately counted as a vowel here.
vowels: set[str] = set("aeiouy")

# Map each letter to its 26-dimensional one-hot encoding.
char_to_onehot: dict[str, list[int]] = {
    ch: [1 if i == idx else 0 for i in range(26)]
    for idx, ch in enumerate(alphabet)
}

# All-zero 26-dim vector used for out-of-alphabet / padding characters.
empty_vec: list[int] = [0] * 26
def get_prev_type(c: str) -> Literal[0, 1, 2]:
    """Classify a single character: 1 = vowel, 2 = other letter, 0 = anything else.

    The empty string (used as padding for positions before the word start)
    falls into the 0 bucket.
    """
    # Non-letters (including the "" padding sentinel) short-circuit to 0.
    if c not in alphabet:
        return 0
    # vowels is a subset of alphabet, so this split is exhaustive for letters.
    return 1 if c in vowels else 2
def encode_letter(c: str) -> List[int]:
    """Return a fresh 26-dim one-hot vector for *c*, or all zeros if *c* is
    not a lowercase ASCII letter (e.g. the "" padding sentinel).

    Fix: the previous version handed out the *shared* ``empty_vec`` (and the
    shared one-hot lists) directly; a caller mutating the returned list would
    have silently corrupted the module-level lookup table. Copying makes the
    return value safe to mutate.
    """
    return list(char_to_onehot.get(c, empty_vec))
def build_dataset(input_path: str, context_size: int = 10) -> np.ndarray:
    """Encode every character of every word in *input_path* as a feature row.

    Each row is:
      * ``context_size`` one-hot blocks (26 each) for the preceding characters,
        zero-padded with "" before the word start;
      * one one-hot block (26) for the current character;
      * an is-word-start flag (1/0);
      * the type of the immediately preceding character (0/1/2, see
        ``get_prev_type``);
      * the 1-based position of the current character in the word
        (the original called this "word_length"; it is really prefix length).

    Args:
        input_path: Path to a text file with one word per line.
        context_size: Number of previous characters to one-hot encode
            (default 10, matching the original hard-coded window).

    Returns:
        int32 array of shape (num_characters, context_size * 26 + 29);
        shape (0,) if the file yields no characters.
    """
    all_features: List[List[int]] = []

    # Explicit encoding so the result does not depend on the platform default.
    with open(input_path, 'r', encoding='utf-8') as input_file:
        for line in input_file:
            word: str = line.strip().lower()
            # "" entries pad positions before the start of the word.
            prev_chars: List[str] = [""] * context_size

            for i, curr_char in enumerate(word):
                features: List[int] = []

                # One-hot encode the preceding character window.
                for prev in prev_chars:
                    features.extend(encode_letter(prev))

                # One-hot encode the current character.
                features.extend(encode_letter(curr_char))

                # Word position features.
                is_start: int = 1 if i == 0 else 0
                features.append(is_start)

                # Type (none/vowel/consonant) of the immediately previous char.
                features.append(get_prev_type(prev_chars[-1]))

                # 1-based position of curr_char (prefix length so far).
                features.append(i + 1)

                all_features.append(features)

                # Shift the history window left by one.
                prev_chars = prev_chars[1:] + [curr_char]

    return np.array(all_features, dtype=np.int32)
def _main() -> None:
    """Encode the word list and persist the feature matrix to disk."""
    dataset: np.ndarray = build_dataset(INPUT_FILE)
    np.save(OUTPUT_FILE, dataset)
    print(f"Saved dataset shape: {dataset.shape} → {OUTPUT_FILE}")


if __name__ == "__main__":
    _main()