Commit: Merge branch 'feature/neural_net' into main
Showing 12 changed files with 933 additions and 3 deletions.
.gitignore
@@ -1,6 +1,6 @@
__pycache__
data/games.txt
100epoch_loss.png
data/100epoch_naive
data/dummy.txt
data/interaction_history.json
models/
plots/
interactions/
New file (path not shown): the WordleDataset class
@@ -0,0 +1,39 @@
from torch import empty, long
from torch.utils.data import Dataset


class WordleDataset(Dataset):
    """
    A dataset containing the words from the chosen word list (official/unofficial .txt under the data/ subdirectory).
    Can be iterated over with `for word, label in dataset:`.
    """
    def __init__(self, root_dir):
        def get_wordlist(filename):
            """
            Read all lines from the given file, strip the trailing '\n', and lowercase every character.
            """
            with open(filename, 'r') as f:
                words = f.readlines()
            words = [word.strip().lower() for word in words]
            return words

        def get_labels(words):
            # One label per letter position: 0 for 'a' through 25 for 'z'.
            labels = empty((len(words), 5), dtype=long)
            for i, word in enumerate(words):
                for j, k in enumerate(word):
                    labels[i, j] = ord(k) - ord('a')
            return labels

        self.root_dir = root_dir
        self.words = get_wordlist(root_dir)
        self.labels = get_labels(self.words)

    def __len__(self):
        return len(self.words)

    def __getitem__(self, idx):
        word = self.words[idx]
        label = self.labels[idx]
        return word, label
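A minimal usage sketch; the module name `dataset` is a placeholder, since the commit page does not show the new file's path:

    from dataset import WordleDataset  # hypothetical module name

    ds = WordleDataset("data/official.txt")
    print(len(ds))            # number of five-letter words in the list
    word, label = ds[0]       # a word string and a length-5 tensor of letter indices (0..25)
    for word, label in ds:    # iterable, as the class docstring advertises
        pass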
New file (path not shown): interactive play script
@@ -0,0 +1,29 @@
from torch import load
from utils import get_default_features, get_feedback, get_mask_tree, get_word_beam_search, get_updated_features, get_wordset
from visualize import get_colored_word


if __name__ == "__main__":
    word_set = get_wordset("data/official.txt")
    model = load("models/100epoch_bigger_train_beam_3")
    mask_tree = get_mask_tree("data/official.txt")

    while True:
        correct_word = input("Choose word: ")
        print("word in vocab: ", correct_word in word_set)
        if correct_word == "quit":
            break

        features = get_default_features()

        # Up to six guesses, as in Wordle; each round feeds the feedback back in.
        for attempt in range(6):
            outputs = model(features)
            guessed_word = get_word_beam_search(outputs, mask_tree)

            feedback = get_feedback(guessed_word, correct_word)
            features = get_updated_features(features, feedback, guessed_word)

            colored_word = get_colored_word(guessed_word, feedback)
            print(colored_word)

            if guessed_word == correct_word:
                break
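The helpers (`get_wordset`, `get_feedback`, `get_word_beam_search`, and friends) live in `utils`, which this commit does not show. For orientation, a `get_feedback` consistent with standard Wordle rules might look like the sketch below; the encoding (2 = correct position, 1 = present elsewhere, 0 = absent) and the duplicate-letter handling are assumptions, not the repository's actual implementation:

    def get_feedback(guess, target):
        # 2 = green, 1 = yellow, 0 = grey
        feedback = [0] * 5
        remaining = list(target)
        # greens first, consuming the matched target letters
        for i, (g, t) in enumerate(zip(guess, target)):
            if g == t:
                feedback[i] = 2
                remaining.remove(g)
        # then yellows against the letters still unmatched
        for i, g in enumerate(guess):
            if feedback[i] == 0 and g in remaining:
                feedback[i] = 1
                remaining.remove(g)
        return feedback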
New file (path not shown, presumably main.py, the trainer invoked by the shell script below): the training script
@@ -0,0 +1,80 @@
import torch
from torch.optim import Adam
from torch.nn import CrossEntropyLoss
import numpy as np
from metrics import accuracy, avg_loss
from utils import *
from models import BaseModel

device = "cuda:0" if torch.cuda.is_available() else "cpu"
print(f"Training models on {device}")

torch.manual_seed(2002)
torch.autograd.set_detect_anomaly(True)


def train(model, datasets, mask_tree, max_epochs, eta):
    losses = np.zeros(max_epochs)
    val_acc = np.zeros(max_epochs)
    val_loss = np.zeros(max_epochs)
    word_count = len(datasets['train'])
    max_val_acc = float('-inf')

    optimizer = Adam(model.parameters(), lr=eta)
    loss_criterion = CrossEntropyLoss()
    interactions = {ep: {word: {} for word, label in datasets['train']} for ep in range(max_epochs)}

    for epoch in range(max_epochs):
        i = 0
        model.train()

        for correct_word, correct_word_labels in datasets['train']:
            features = get_default_features()
            i += 1
            print(f"Word: {correct_word} {i}/{word_count}", end='\r')

            # Up to six guesses per word; every guess gets its own gradient step.
            for attempt in range(6):
                optimizer.zero_grad()

                outputs = model(features)
                guessed_word = get_word_beam_search(outputs, mask_tree, k=3)

                word_loss = loss_criterion(outputs, correct_word_labels)

                word_loss.backward()
                optimizer.step()

                losses[epoch] += word_loss.item()

                feedback = get_feedback(guessed_word, correct_word)
                features = get_updated_features(features, feedback, guessed_word)

                interactions[epoch][correct_word][attempt] = {
                    'feedback': feedback,
                    'guessed_word': guessed_word
                }

                if guessed_word == correct_word:
                    break

        model.eval()
        # with splits = [1.0, 0, 0] below, the "validation" pass reuses the training split
        val_acc[epoch], _ = accuracy(model, datasets['train'], mask_tree)
        val_loss[epoch] = avg_loss(model, datasets['train'], mask_tree)
        print(f"Epoch {epoch} / {max_epochs}, loss => {losses[epoch]}, val_acc => {val_acc[epoch]}, val_loss => {val_loss[epoch]}")

        if val_acc[epoch] > max_val_acc:
            save_model(model, "100epoch_bigger_full")
            max_val_acc = val_acc[epoch]

    return losses, interactions


if __name__ == "__main__":
    splits = [1.0, 0, 0]  # train on the full word list; val/test are left empty
    mask_tree = get_mask_tree("data/official.txt")
    dataset = get_dataset("data/official.txt")
    datasets = get_split_dataset(dataset, splits)

    b1 = BaseModel(in_features=26 * 12)
    b1_loss, interaction_history = train(b1, datasets, mask_tree, max_epochs=100, eta=0.00005)

    save_history(interaction_history, "final_interaction_history.json")
    save_loss(b1_loss, "100epoch_bigger_full.npy")
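`get_dataset` and `get_split_dataset` come from `utils`, which this diff does not include. Given the split helper that appears later in this commit, `get_split_dataset` plausibly wraps `torch.utils.data.random_split`; a sketch under that assumption, with the function body and `seed` parameter invented here:

    import torch

    def get_split_dataset(dataset, splits, seed=42):
        # Turn ratios into counts; the last split absorbs the rounding remainder.
        counts = [int(len(dataset) * r) for r in splits]
        counts[-1] = len(dataset) - sum(counts[:-1])
        train, val, test = torch.utils.data.random_split(
            dataset, counts, generator=torch.Generator().manual_seed(seed))
        return {'train': train, 'val': val, 'test': test}

With splits = [1.0, 0, 0] this yields the whole list as the train split and two empty splits.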
New file (presumably metrics.py, matching the trainer's `from metrics import accuracy, avg_loss`): evaluation metrics
@@ -0,0 +1,47 @@
from torch.nn import CrossEntropyLoss
from utils import get_default_features, get_feedback, get_mask_tree, get_updated_features, get_word_beam_search


def accuracy(model, dataset, mask_tree):
    acc = 0.
    count = 0.
    attempt_count = {}
    for correct_word, label in dataset:
        count += 1  # one per word, so the result is the percentage of words solved
        features = get_default_features()

        for attempt in range(6):
            output = model(features)

            guessed_word = get_word_beam_search(output, mask_tree)

            if guessed_word == correct_word:
                acc += 1
                attempt_count[correct_word] = 1 + attempt
                break

            feedback = get_feedback(guessed_word, correct_word)
            features = get_updated_features(features, feedback, guessed_word)

    acc = 100 * acc / count
    acc = round(acc, 4)
    return acc, attempt_count


def avg_loss(model, dataset, mask_tree):
    loss_fn = CrossEntropyLoss()
    loss = 0.
    for correct_word, label in dataset:
        features = get_default_features()

        for attempt in range(6):
            outputs = model(features)
            loss += loss_fn(outputs, label).item()  # .item() keeps the running sum off the autograd graph

            guessed_word = get_word_beam_search(outputs, mask_tree)
            feedback = get_feedback(guessed_word, correct_word)
            features = get_updated_features(features, feedback, guessed_word)

            if guessed_word == correct_word:
                break

    return loss / len(dataset)  # average per word, matching the function's name
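A usage sketch with the names from the training script; wrapping the calls in `torch.no_grad()` is an addition here (the diff does not do it) that keeps the evaluation passes from building autograd graphs:

    import torch

    with torch.no_grad():
        acc, attempts = accuracy(model, datasets['train'], mask_tree)
        mean_loss = avg_loss(model, datasets['train'], mask_tree)
    print(f"solved {acc}% of words; per-word attempts: {attempts}")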
New file (presumably models.py, matching the trainer's `from models import BaseModel`): the BaseModel network
@@ -0,0 +1,37 @@
import torch
import torch.nn as nn


class BaseModel(nn.Module):
    def __init__(self, in_features):
        super(BaseModel, self).__init__()

        self.linear_layers = nn.Sequential(
            nn.Linear(in_features=in_features, out_features=512),
            nn.ReLU(),
            nn.Linear(in_features=512, out_features=512),
            nn.ReLU(),
            nn.Linear(in_features=512, out_features=512),
            nn.ReLU(),
        )

        # nn.ModuleList rather than a plain Python list, so the five output
        # heads are registered and reached by model.parameters()/the optimizer.
        self.output_char_layers = nn.ModuleList([
            nn.Linear(in_features=512, out_features=26) for _ in range(5)
        ])

        self.flatten = nn.Flatten(start_dim=0)
        self.activation = nn.ReLU()

    def forward(self, x):
        output = self.flatten(x)
        output = self.linear_layers(output)

        # One 26-way logit vector per letter position; stacking avoids
        # in-place writes into a pre-allocated empty tensor.
        outputs = torch.stack([layer(output) for layer in self.output_char_layers])

        return outputs
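A quick shape check, assuming the 26 x 12 feature layout implied by `BaseModel(in_features=26 * 12)` in the training script:

    import torch
    from models import BaseModel  # module name assumed above

    model = BaseModel(in_features=26 * 12)
    x = torch.zeros(26, 12)  # dummy features; real ones come from utils.get_default_features()
    logits = model(x)
    print(logits.shape)      # torch.Size([5, 26]): 26 letter logits per position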
New file (path not shown): driver shell script
@@ -0,0 +1,7 @@
#!/bin/bash

# first run the trainer
python main.py

# then the visualizer
python visualize.py
New file (path not shown): an earlier copy of WordleDataset plus a split-and-print exploration script
@@ -0,0 +1,65 @@
import torch
from torch import empty, long
from torch.utils.data import Dataset


class WordleDataset(Dataset):
    def __init__(self, root_dir):
        def get_wordlist(filename):
            """ Read all lines from the given file, strip the trailing '\n', and lowercase every character """
            with open(filename, 'r') as f:
                words = f.readlines()
            words = [word.strip().lower() for word in words]
            return words

        def get_labels(words):
            labels = empty((len(words), 5), dtype=long)
            for i, word in enumerate(words):
                for j, k in enumerate(word):
                    labels[i, j] = ord(k) - ord('a')
            return labels

        self.root_dir = root_dir
        self.words = get_wordlist(root_dir)
        self.labels = get_labels(self.words)

    def __len__(self):
        return len(self.words)

    def __getitem__(self, idx):
        word = self.words[idx]
        label = self.labels[idx]
        return word, label


def split_and_print(dataset, splits):
    total_count = len(dataset)
    splits = [int(total_count * ratio) for ratio in splits]
    # the last split absorbs the rounding remainder (sum over the other splits only)
    splits[-1] = total_count - sum(splits[:-1])

    train_set, val_set, test_set = torch.utils.data.random_split(dataset, splits, generator=torch.Generator().manual_seed(42))
    datasets = {
        'train': train_set,
        'val': val_set,
        'test': test_set,
    }

    # print the first 10 in each
    for name, dataset in datasets.items():
        print(f"{name}: {len(dataset)} samples")
        for idx in range(10):
            print(dataset[idx])


if __name__ == "__main__":
    # split ratios: train, val, test (the last takes the remainder)
    splits = [0.8, 0.05, 0]

    official_dataset = WordleDataset("data/official.txt")
    total_count = len(official_dataset)
    print(f"\ntotal samples in Official Dataset: {total_count}")
    split_and_print(official_dataset, splits)

    unofficial_dataset = WordleDataset("data/words.txt")
    total_count = len(unofficial_dataset)
    print(f"\ntotal samples in Unofficial Dataset: {total_count}")
    split_and_print(unofficial_dataset, splits)
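A quick check of the remainder arithmetic, using a hypothetical 1,000-word list (the actual list sizes are not shown in the diff):

    splits = [int(1000 * r) for r in [0.8, 0.05, 0]]  # [800, 50, 0]
    splits[-1] = 1000 - sum(splits[:-1])              # last split takes the rest: [800, 50, 150]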
File renamed without changes.