
Commit

Merge branch 'feature/neural_net' into main
Tickloop committed May 17, 2022
2 parents 6f928bc + 2f7a2d3 commit 17a4fb6
Showing 12 changed files with 933 additions and 3 deletions.
6 changes: 3 additions & 3 deletions .gitignore
@@ -1,6 +1,6 @@
 __pycache__
 data/games.txt
-100epoch_loss.png
-data/100epoch_naive
 data/dummy.txt
-data/interaction_history.json
+models/
+plots/
+interactions/
39 changes: 39 additions & 0 deletions dataset.py
@@ -0,0 +1,39 @@
from torch import empty, long
from torch.utils.data import Dataset

class WordleDataset(Dataset):
"""
A dataset that contains the words from the chosen word list (official/unofficial.txt) under data/ subdirectory.
Can be iterated through by simply calling `for words, labels in dataset:`
"""
def __init__(self, root_dir):
def get_wordlist(filename):
"""
To read all the lines from the given file, strip the enidng '\n' and lower case all the characters
"""
words = []
with open(filename, 'r') as f:
words = f.readlines()
words = [word.strip() for word in words]
words = [word.lower() for word in words]
return words

def get_labels(words):
labels = empty((len(words), 5), dtype=long)
for i, word in enumerate(words):
for j, k in enumerate(word):
labels[i, j] = ord(k) - ord('a')
return labels

self.root_dir = root_dir
self.words = get_wordlist(root_dir)
self.labels = get_labels(self.words)

def __len__(self):
return len(self.words)

def __getitem__(self, idx):
word = self.words[idx]
label = self.labels[idx]
return word, label
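For reference, a minimal usage sketch of WordleDataset as defined above, assuming a five-letter word list at data/official.txt (one word per line):

from dataset import WordleDataset

dataset = WordleDataset("data/official.txt")
print(len(dataset))        # number of words in the list
word, label = dataset[0]   # e.g. ("aback", tensor([0, 1, 0, 2, 10]))
print(word, label)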
29 changes: 29 additions & 0 deletions game_loop.py
@@ -0,0 +1,29 @@
from torch import load
from utils import get_default_features, get_feedback, get_mask_tree, get_word_beam_search, get_updated_features, get_wordset
from visualize import get_colored_word

if __name__ == "__main__":
word_set = get_wordset("data/official.txt")
model = load("models/100epoch_bigger_train_beam_3")
mask_tree = get_mask_tree("data/official.txt")

while True:
        correct_word = input("Choose word: ")
        print("word in vocab:", correct_word in word_set)
if correct_word == "quit":
break

features = get_default_features()

for attempt in range(6):
outputs = model(features)
guessed_word = get_word_beam_search(outputs, mask_tree)

feedback = get_feedback(guessed_word, correct_word)
features = get_updated_features(features, feedback, guessed_word)

colored_word = get_colored_word(guessed_word, feedback)
print(colored_word)

if guessed_word == correct_word:
break
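utils.py is not part of this diff, so the exact encoding that get_feedback returns is not shown. A plausible sketch of Wordle-style feedback (2 = right letter, right position; 1 = present elsewhere; 0 = absent), with duplicate letters handled by counting, under the hypothetical name wordle_feedback:

from collections import Counter

def wordle_feedback(guess: str, answer: str) -> list:
    feedback = [0] * len(guess)
    remaining = Counter(answer)
    # first pass: exact matches consume their letter counts
    for i, (g, a) in enumerate(zip(guess, answer)):
        if g == a:
            feedback[i] = 2
            remaining[g] -= 1
    # second pass: misplaced letters, limited by the counts left over
    for i, g in enumerate(guess):
        if feedback[i] == 0 and remaining[g] > 0:
            feedback[i] = 1
            remaining[g] -= 1
    return feedback

# e.g. wordle_feedback("crane", "caner") == [2, 1, 1, 1, 1]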
80 changes: 80 additions & 0 deletions main.py
@@ -0,0 +1,80 @@
import torch
from torch.optim import Adam
from torch.nn import CrossEntropyLoss
import numpy as np
from metrics import accuracy, avg_loss
from utils import *
from models import BaseModel

device = "cuda:0" if torch.cuda.is_available else "cpu"
print(f"Training models on {device}")

torch.manual_seed(2002)
torch.autograd.set_detect_anomaly(True)

def train(model, datasets, mask_tree, max_epochs, eta):
losses = np.zeros(max_epochs)
val_acc = np.zeros(max_epochs)
val_loss = np.zeros(max_epochs)
word_count = len(datasets['train'])
max_val_acc = float('-inf')

optimizer = Adam(model.parameters(), lr=eta)
loss_criterion = CrossEntropyLoss()
interactions = { ep : { word : {} for word, label in datasets['train'] } for ep in range(max_epochs) }

for epoch in range(max_epochs):
i = 0
model.train()

for correct_word, correct_word_labels in datasets['train']:
features = get_default_features()
i += 1
print(f"Word: {correct_word} {i}/{word_count}", end='\r')

for attempt in range(6):
optimizer.zero_grad()

outputs = model(features)
guessed_word = get_word_beam_search(outputs, mask_tree, k=3)

word_loss = loss_criterion(outputs, correct_word_labels)

word_loss.backward()
optimizer.step()

losses[epoch] += word_loss.item()

feedback = get_feedback(guessed_word, correct_word)
features = get_updated_features(features, feedback, guessed_word)

interactions[epoch][correct_word][attempt] = {
'feedback': feedback,
'guessed_word': guessed_word
}

if guessed_word == correct_word:
break

        model.eval()
        # with splits = [1.0, 0, 0] the validation split is empty, so these
        # "validation" metrics are computed on the training split
        val_acc[epoch], _ = accuracy(model, datasets['train'], mask_tree)
        val_loss[epoch] = avg_loss(model, datasets['train'], mask_tree)
print(f"Epoch {epoch} / {max_epochs}, loss => {losses[epoch]}, val_acc => {val_acc[epoch]}, val_loss => {val_loss[epoch]}")

if val_acc[epoch] > max_val_acc:
save_model(model, "100epoch_bigger_full")
max_val_acc = val_acc[epoch]

return losses, interactions

if __name__ == "__main__":
splits = [1.0, 0, 0]
mask_tree = get_mask_tree("data/official.txt")
dataset = get_dataset("data/official.txt")
datasets = get_split_dataset(dataset, splits)

b1 = BaseModel(in_features=26 * 12)
b1_loss, interaction_history = train(b1, datasets, mask_tree, max_epochs=100, eta=0.00005)

save_history(interaction_history, "final_interaction_history.json")
save_loss(b1_loss, "100epoch_bigger_full.npy")
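get_split_dataset lives in utils.py, which is outside this diff. Based on the split_and_print helper in tests/test_dataset.py, a plausible sketch (split_dataset_sketch is a hypothetical name):

import torch
from torch.utils.data import random_split

def split_dataset_sketch(dataset, ratios, seed=42):
    counts = [int(len(dataset) * r) for r in ratios]
    counts[-1] = len(dataset) - sum(counts[:-1])  # last split absorbs the rounding remainder
    train, val, test = random_split(
        dataset, counts, generator=torch.Generator().manual_seed(seed))
    return {'train': train, 'val': val, 'test': test}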
47 changes: 47 additions & 0 deletions metrics.py
@@ -0,0 +1,47 @@
from torch.nn import CrossEntropyLoss
from utils import get_default_features, get_feedback, get_mask_tree, get_updated_features, get_word_beam_search


def accuracy(model, dataset, mask_tree):
acc = 0.
count = 0.
attempt_count = {}
for correct_word, label in dataset:
features = get_default_features()

for attempt in range(6):
output = model(features)

guessed_word = get_word_beam_search(output, mask_tree)

if guessed_word == correct_word:
acc += 1
attempt_count[correct_word] = 1 + attempt
break

feedback = get_feedback(guessed_word, correct_word)
features = get_updated_features(features, feedback, guessed_word)
count += 1

acc = 100 * acc / count
acc = round(acc, 4)
return acc, attempt_count

def avg_loss(model, dataset, mask_tree):
loss_fn = CrossEntropyLoss()
loss = 0.
for correct_word, label in dataset:
features = get_default_features()

for attempt in range(6):
outputs = model(features)
            loss += loss_fn(outputs, label).item()  # .item() avoids accumulating the autograd graph during evaluation

guessed_word = get_word_beam_search(outputs, mask_tree)
feedback = get_feedback(guessed_word, correct_word)
features = get_updated_features(features, feedback, guessed_word)

if guessed_word == correct_word:
break

    # divide by the dataset size so the function returns an average, as its name suggests
    return loss / len(dataset)
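A minimal sketch of calling these metrics, assuming a trained BaseModel already bound to `model` and the utils.py helpers referenced in main.py:

from metrics import accuracy, avg_loss
from utils import get_dataset, get_mask_tree

mask_tree = get_mask_tree("data/official.txt")
dataset = get_dataset("data/official.txt")
acc, attempts = accuracy(model, dataset, mask_tree)  # model: a trained BaseModel
mean_attempts = sum(attempts.values()) / max(len(attempts), 1)
print(f"solve rate: {acc}%, mean attempts on solved words: {mean_attempts:.2f}")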
37 changes: 37 additions & 0 deletions models.py
@@ -0,0 +1,37 @@
import torch
import torch.nn as nn

class BaseModel(nn.Module):
def __init__(self, in_features):
super(BaseModel, self).__init__()

self.linear_layers = nn.Sequential(
nn.Linear(in_features=in_features, out_features=512),
nn.ReLU(),
nn.Linear(in_features=512, out_features=512),
nn.ReLU(),
nn.Linear(in_features=512, out_features=512),
nn.ReLU(),
)

        # nn.ModuleList (rather than a plain Python list) registers the heads'
        # parameters with the module, so model.parameters() and the optimizer see them
        self.output_char_layers = nn.ModuleList([
            nn.Linear(in_features=512, out_features=26) for _ in range(5)
        ])

self.flatten = nn.Flatten(start_dim=0)
self.activation = nn.ReLU()

def forward(self, x):
output = self.flatten(x)
output = self.linear_layers(output)

        # one independent 26-way classifier per letter position; stacking gives a (5, 26) tensor
        outputs = torch.stack([layer(output) for layer in self.output_char_layers])

return outputs
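A minimal forward-pass sketch of BaseModel, assuming the feature tensor is shaped (12, 26) so that flattening yields the 26 * 12 = 312 inputs used in main.py; the actual layout of get_default_features is defined in utils.py, outside this diff:

import torch
from models import BaseModel

model = BaseModel(in_features=26 * 12)
features = torch.zeros((12, 26))  # hypothetical feature layout
logits = model(features)          # per-position letter logits
print(logits.shape)               # torch.Size([5, 26])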

7 changes: 7 additions & 0 deletions pipeline.sh
@@ -0,0 +1,7 @@
#!/bin/bash

# first run the trainer
python main.py

# then the visualizer
python visualize.py
65 changes: 65 additions & 0 deletions tests/test_dataset.py
@@ -0,0 +1,65 @@
import torch
from torch import empty, long
from torch.utils.data import Dataset
class WordleDataset(Dataset):
def __init__(self, root_dir):

""" To read all the lines from the given file, strip the enidng '\n' and lower case all the characters """
def get_wordlist(filename):
words = []
with open(filename, 'r') as f:
words = f.readlines()
words = [word.strip() for word in words]
words = [word.lower() for word in words]
return words

def get_labels(words):
labels = empty((len(words), 5), dtype=long)
for i, word in enumerate(words):
for j, k in enumerate(word):
labels[i, j] = ord(k) - ord('a')
return labels

self.root_dir = root_dir
self.words = get_wordlist(root_dir)
self.labels = get_labels(self.words)

def __len__(self):
return len(self.words)

def __getitem__(self, idx):
word = self.words[idx]
label = self.labels[idx]
return word, label

def split_and_print(dataset, splits):
total_count = len(dataset)
    splits = [int(total_count * ratio) for ratio in splits]
    splits[-1] = total_count - sum(splits[:-1])  # exclude the last entry itself so it absorbs the remainder

train_set, val_set, test_set = torch.utils.data.random_split(dataset, splits, generator=torch.Generator().manual_seed(42))
datasets = {
'train': train_set,
'val': val_set,
'test': test_set,
}

# print the first 10 in each
    for name, split in datasets.items():  # 'split' avoids shadowing the dataset parameter
        print(f"{name}: {len(split)} samples")
        for idx in range(min(10, len(split))):
            print(split[idx])

if __name__ == "__main__":
# splits
splits = [0.8, 0.05, 0]

official_dataset = WordleDataset("data/official.txt")
total_count = len(official_dataset)
print(f"\ntotal samples in Official Dataset: {total_count}")
split_and_print(official_dataset, splits)

unofficial_dataset = WordleDataset("data/words.txt")
total_count = len(unofficial_dataset)
print(f"\ntotal samples in Official Dataset: {total_count}")
split_and_print(unofficial_dataset, splits)
File renamed without changes.
