Commit b1ea9e0

upload

49 files changed: +1474 −0 lines changed

README.md (+169 lines)
# [PYTORCH] Hierarchical Attention Networks for Document Classification

## Introduction

Here is my PyTorch implementation of the model described in the paper **Hierarchical Attention Networks for Document Classification** [paper](https://www.cs.cmu.edu/%7Ediyiy/docs/naacl16.pdf).

<p align="center">
  <img src="demo/video.gif"><br/>
  <i>A demo of the web app serving my model's output on the DBpedia dataset.</i>
</p>

<p align="center">
  <img src="demo/output.gif"><br/>
  <i>An example of my model's performance on the DBpedia dataset.</i>
</p>

## How to use my code

With my code, you can:
* **Train a model on any dataset**
* **Evaluate any test set that has the same set of classes, using either my trained model or your own**
* **Run a simple web app for testing purposes**

## Requirements:

* **python 3.6**
* **pytorch 0.4**
* **tensorboard**
* **tensorboardX** (this library can be skipped if you do not use SummaryWriter)
* **numpy**

## Datasets:

Statistics of the datasets I used for experiments. These datasets can be downloaded from [link](https://drive.google.com/drive/u/0/folders/0Bz8a_Dbh9Qhbfll6bVpmNUtUcFdjYmF2SEpmZUZUcVNiMUw1TWN6RDV3a0JHT3kxLVhVR2M).

| Dataset                | Classes | Train samples | Test samples |
|------------------------|:-------:|:-------------:|:------------:|
| AG’s News              |    4    |    120 000    |     7 600    |
| Sogou News             |    5    |    450 000    |    60 000    |
| DBPedia                |   14    |    560 000    |    70 000    |
| Yelp Review Polarity   |    2    |    560 000    |    38 000    |
| Yelp Review Full       |    5    |    650 000    |    50 000    |
| Yahoo! Answers         |   10    |   1 400 000   |    60 000    |
| Amazon Review Full     |    5    |   3 000 000   |   650 000    |
| Amazon Review Polarity |    2    |   3 600 000   |   400 000    |

Additionally, I use pre-trained word embeddings taken from GloVe (referred to as word2vec files throughout this repo), which you can download from [link](https://nlp.stanford.edu/projects/glove/). I ran experiments with all four GloVe files (50d, 100d, 200d and 300d). You could easily switch to other common pre-trained embeddings, such as the ones provided by FastText [link](https://fasttext.cc/docs/en/crawl-vectors.html).

The paper states that a pre-trained word2vec model is used. However, to the best of my knowledge, at least in PyTorch, there is no implementation on GitHub that actually loads one: in all the HAN repositories I have seen so far, a default (randomly initialized) embedding layer is used. A HAN model can certainly be trained without pre-trained embeddings; however, to faithfully re-implement the original model, in all experiments I used one of the four pre-trained files above as the initialization of the embedding layer.

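For reference, a minimal sketch of that embedding initialization, assuming a GloVe text file in which each line is a word followed by its vector (the helper name and file path below are illustrative, not from this repository). Index 0 is reserved as a zero vector so that the `+1` index shift performed in src/dataset.py maps padding/unknown words to it:

```python
import csv

import numpy as np
import pandas as pd
import torch
import torch.nn as nn


def embedding_from_glove(glove_path):
    # Each GloVe line is: word v1 v2 ... vd (space-separated, no header).
    data = pd.read_csv(glove_path, header=None, sep=" ", quoting=csv.QUOTE_NONE).values
    vectors = data[:, 1:].astype(np.float32)
    # Prepend a zero row: index 0 serves as the padding/unknown token.
    vectors = np.concatenate([np.zeros((1, vectors.shape[1]), dtype=np.float32), vectors])
    embedding = nn.Embedding(num_embeddings=vectors.shape[0], embedding_dim=vectors.shape[1])
    embedding.weight.data.copy_(torch.from_numpy(vectors))
    return embedding


# Usage (path is an assumption):
# embedding = embedding_from_glove("data/glove.6B.50d.txt")
```
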
## Setting:

During my experiments, I found that, across datasets and embedding dimensions, some combinations of batch size and learning rate yield better performance (faster convergence and higher accuracy) than others. In some cases, with the wrong values for these two important hyperparameters, the model never converges. The detailed setting for each experiment is given in the **Experiments** section.

I did not fix the number of epochs per experiment. Instead, I apply early stopping: training stops once the test loss has not improved for **n** epochs, as sketched below.
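A minimal sketch of this early-stopping rule (train.py is the authoritative implementation; the function below and its helper arguments are illustrative):

```python
import torch


def fit(model, train_one_epoch, evaluate, patience=5, max_epochs=100):
    """Train until the test loss stops improving for `patience` epochs."""
    best_loss = float("inf")
    epochs_without_improvement = 0
    for epoch in range(max_epochs):
        train_one_epoch(model)          # caller-supplied training step
        test_loss = evaluate(model)     # caller-supplied test-set evaluation
        if test_loss < best_loss:
            best_loss = test_loss
            epochs_without_improvement = 0
            torch.save(model, "trained_models/best_model")  # illustrative path
        else:
            epochs_without_improvement += 1
            if epochs_without_improvement >= patience:
                break  # test loss has not improved for `patience` epochs
    return model
```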

## Training

To train a model with the default parameters, run:
- **python train.py**

To train a model with your preferred hyperparameters, such as batch size and learning rate, run:
- **python train.py --batch_size batch_size --lr learning_rate**, for example: python train.py --batch_size 512 --lr 0.01

To train a model with your preferred word2vec file, run:
- **python train.py --word2vec_path path/to/your/word2vec**
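For orientation, a rough sketch of the training loop these commands drive (hyperparameter values and file paths are illustrative; train.py is the actual entry point). One detail worth noting: the recurrent hidden states are sized to the batch, so they must be re-initialized for every batch, just as app.py does at inference time:

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader

from src.dataset import MyDataset
from src.hierarchical_att_model import HierAttNet

batch_size = 512  # illustrative
training_set = MyDataset(data_path="data/train.csv", dict_path="data/glove.6B.50d.txt")
loader = DataLoader(training_set, batch_size=batch_size, shuffle=True)

model = HierAttNet(word_hidden_size=50, sent_hidden_size=50, batch_size=batch_size,
                   num_classes=training_set.num_classes,
                   pretrained_word2vec_path="data/glove.6B.50d.txt",
                   max_sent_length=30, max_word_length=35)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
if torch.cuda.is_available():
    model = model.cuda()

for feature, label in loader:
    if torch.cuda.is_available():
        feature, label = feature.cuda(), label.cuda()
    model._init_hidden_state(feature.size(0))  # reset hidden states for this batch
    prediction = model(feature)
    loss = criterion(prediction, label)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
```
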
## Test

To test a trained model on your own test file, run:
- **python test.py --word2vec_path path/to/your/word2vec**, where the word2vec file is the same one you used in the training phase.

You can find some models I have trained in [link](https://drive.google.com/open?id=1A50PDQMm0THnU6QDxOEsvKqH-ZTxmGpT).

## Experiments:

Results on the test set are presented as **A (B/C)**, where:
- **A** is the accuracy,
- **B** is the learning rate used,
- **C** is the batch size.

For example, 94.20 (0.01/512) would mean 94.20% accuracy with learning rate 0.01 and batch size 512. Each experiment is run over 10 epochs.

| GloVe word2vec  |      50      |      100     |      200     |      300     |
|:---------------:|:------------:|:------------:|:------------:|:------------:|
| ag_news         | updated soon | updated soon | updated soon | updated soon |
| sogou_news      | updated soon | updated soon | updated soon | updated soon |
| db_pedia        | updated soon | updated soon | updated soon | updated soon |
| yelp_polarity   | updated soon | updated soon | updated soon | updated soon |
| yelp_review     | updated soon | updated soon | updated soon | updated soon |
| yahoo_answer    | updated soon | updated soon | updated soon | updated soon |
| amazon_review   | updated soon | updated soon | updated soon | updated soon |
| amazon_polarity | updated soon | updated soon | updated soon | updated soon |

The training/test loss and accuracy curves for each dataset's experiments (ordered from left to right, top to bottom: 50d, 100d, 200d and 300d word2vec) are shown below:

- **ag_news**

<img src="demo/agnews_50.png" width="420"> <img src="demo/agnews_100.png" width="420">
<img src="demo/agnews_200.png" width="420"> <img src="demo/agnews_300.png" width="420">

- **db_pedia**

<img src="demo/dbpedia_50.png" width="420"> <img src="demo/dbpedia_100.png" width="420">
<img src="demo/dbpedia_200.png" width="420"> <img src="demo/dbpedia_300.png" width="420">

- **yelp_polarity**

<img src="demo/yelpreviewpolarity_50.png" width="420"> <img src="demo/yelpreviewpolarity_100.png" width="420">
<img src="demo/yelpreviewpolarity_200.png" width="420"> <img src="demo/empty.png" width="420">

- **yelp_review**

<img src="demo/yelpreviewfull_50.png" width="420"> <img src="demo/empty.png" width="420">
<img src="demo/empty.png" width="420"> <img src="demo/yelpreviewfull_300.png" width="420">

- **yahoo_answer**

<img src="demo/yahoo_50.png" width="420"> <img src="demo/yahoo_100.png" width="420">
<img src="demo/yahoo_200.png" width="420"> <img src="demo/yahoo_300.png" width="420">

- **amazon_review**

<img src="demo/amazonreviewfull_50.png" width="420"> <img src="demo/empty.png" width="420">
<img src="demo/amazonreviewfull_200.png" width="420"> <img src="demo/empty.png" width="420">

- **amazon_polarity**

<img src="demo/amazonreviewpolarity_50.png" width="420"> <img src="demo/amazonreviewpolarity_100.png" width="420">
<img src="demo/empty.png" width="420"> <img src="demo/amazonreviewpolarity_300.png" width="420">

Some experiments I have not had time to run; for those, both the statistics and the loss/accuracy visualizations are empty. For some others, I could not wait until they finished and stopped the training phase early. You can tell whether a run was halted by early stopping or by me from the test loss curve: if the loss had not improved for 5 consecutive epochs, it was early stopping; if the loss was still going down, I stopped it manually. When I have time, I will complete the unfinished experiments and update the results here.

After the training phase completes, you can see the parameters you set, together with the accuracy, loss and confusion matrix on the test set at the end of each epoch, in **root_folder/trained_models/logs.txt**. An example is shown below:

<p align="center">
  <img src="demo/output.png"><br/>
  <i>An example of logs.txt for the DBpedia dataset.</i>
</p>

## Demo:

I wrote a simple web app for quickly testing the model on any input text. To use it, follow these steps (a command-line alternative for the final prediction step is sketched after the list):

- **Step 1**: Run the script app.py
<img src="demo/1.png" width="800">

- **Step 2**: Web interface
<img src="demo/2.png" width="800">

- **Step 3**: Select a trained model
<img src="demo/3.png" width="800">

- **Step 4**: Select a word2vec model
<img src="demo/4.png" width="800">

- **Step 5 (optional)**: Select a file containing the classes (one class per line)
<img src="demo/5.png" width="800">

- **Step 6**: After all necessary files are selected, press the submit button
<img src="demo/6.png" width="800">

- **Step 7**: Paste any text into the textbox
<img src="demo/7.png" width="800">

- **Step 8**: A sample text
<img src="demo/8.png" width="800">

- **Step 9**: After the submit button is pressed, the predicted category and its probability are shown
<img src="demo/9.png" width="800">
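
Once a model and word2vec file have been loaded through the form (steps 3–6), you can also query the prediction endpoint directly from the command line, since /result accepts a POST with a message form field (a quick sketch; the server must still be running with the files already submitted):
- **curl -X POST -d "message=your input text" http://localhost:4555/result**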

app.py (+100 lines)
"""
@author: Viet Nguyen <[email protected]>
"""
import csv
import os
import random
import string

import numpy as np
import pandas as pd
import torch
import torch.nn.functional as F
from flask import Flask, request, render_template
from nltk.tokenize import sent_tokenize, word_tokenize

app = Flask(__name__)
APP_ROOT = os.path.dirname(os.path.abspath(__file__))
IMAGES_FOLDER = "flask_images"
rand_str = lambda n: "".join([random.choice(string.ascii_letters + string.digits) for _ in range(n)])

model = None
dictionary = None  # vocabulary read from the word2vec file
max_length_sentences = 0
max_length_word = 0
num_classes = 0
categories = None


@app.route("/")
def home():
    return render_template("main.html")


@app.route("/input")
def new_input():
    return render_template("input.html")


@app.route("/show", methods=["POST"])
def show():
    global model, dictionary, max_length_word, max_length_sentences, num_classes, categories
    trained_model = request.files["model"]
    if torch.cuda.is_available():
        model = torch.load(trained_model)
    else:
        model = torch.load(trained_model, map_location=lambda storage, loc: storage)
    # Only the first column of the word2vec file (the words) is needed here.
    dictionary = pd.read_csv(filepath_or_buffer=request.files["word2vec"], header=None, sep=" ",
                             quoting=csv.QUOTE_NONE, usecols=[0]).values
    dictionary = [word[0] for word in dictionary]
    max_length_sentences = model.max_sent_length
    max_length_word = model.max_word_length
    # The last module of the network is its final linear layer, whose output
    # size is the number of classes.
    num_classes = list(model.modules())[-1].out_features
    if "classes" in request.files:
        df = pd.read_csv(request.files["classes"], header=None)
        categories = [item[0] for item in df.values]
    return render_template("input.html")


@app.route("/result", methods=["POST"])
def result():
    global dictionary, model, max_length_sentences, max_length_word, categories
    text = request.form["message"]
    # Encode the document as a matrix of word indices; -1 marks unknown words.
    document_encode = [
        [dictionary.index(word) if word in dictionary else -1 for word in word_tokenize(text=sentences)]
        for sentences in sent_tokenize(text=text)]

    # Pad every sentence to max_length_word words ...
    for sentences in document_encode:
        if len(sentences) < max_length_word:
            extended_words = [-1 for _ in range(max_length_word - len(sentences))]
            sentences.extend(extended_words)

    # ... and pad the document to max_length_sentences sentences.
    if len(document_encode) < max_length_sentences:
        extended_sentences = [[-1 for _ in range(max_length_word)]
                              for _ in range(max_length_sentences - len(document_encode))]
        document_encode.extend(extended_sentences)

    # Truncate anything longer than the maximum lengths.
    document_encode = [sentences[:max_length_word] for sentences in document_encode][:max_length_sentences]

    document_encode = np.stack(arrays=document_encode, axis=0)
    # Shift indices by one so that padding/unknown (-1) becomes index 0.
    document_encode += 1
    # Pair the document with an all-zero dummy document to form a batch of two.
    empty_array = np.zeros_like(document_encode, dtype=np.int64)
    input_array = np.stack([document_encode, empty_array], axis=0)
    feature = torch.from_numpy(input_array)
    if torch.cuda.is_available():
        feature = feature.cuda()
    model.eval()
    with torch.no_grad():
        model._init_hidden_state(2)  # batch size 2: the document plus the dummy
        prediction = model(feature)
        prediction = F.softmax(prediction, dim=1)
    max_prob, max_prob_index = torch.max(prediction, dim=-1)
    prob = "{:.2f} %".format(float(max_prob[0]) * 100)
    if categories is not None:
        category = categories[int(max_prob_index[0])]
    else:
        category = int(max_prob_index[0]) + 1
    return render_template("result.html", text=text, value=prob, index=category)


if __name__ == "__main__":
    app.secret_key = os.urandom(12)
    app.run(host="0.0.0.0", port=4555, debug=True)
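
To try the app locally (a usage note, not part of the committed file), run **python app.py** and open http://localhost:4555/ in a browser; the HTML templates it renders are among this commit's files hidden above.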

Binary files added in this commit (image previews not reproduced here):

demo/1.png (252 KB) · demo/2.png (129 KB) · demo/3.png (182 KB) · demo/4.png (182 KB) · demo/5.png (182 KB) · demo/6.png (132 KB) · demo/7.png (109 KB) · demo/8.png (135 KB) · demo/9.png (179 KB)

demo/agnews_50.png (67.4 KB) · demo/agnews_100.png (69.9 KB) · demo/agnews_200.png (65.7 KB) · demo/agnews_300.png (72.8 KB)

demo/amazonreviewfull_50.png (74.2 KB) · demo/amazonreviewfull_200.png (63.3 KB)

demo/amazonreviewpolarity_50.png (75.2 KB) · demo/amazonreviewpolarity_100.png (64.7 KB) · demo/amazonreviewpolarity_300.png (68.5 KB)

demo/dbpedia_50.png (77.1 KB) · demo/dbpedia_100.png (70.5 KB) · demo/dbpedia_200.png (66.5 KB) · demo/dbpedia_300.png (77.2 KB)

demo/empty.png (3.23 KB) · demo/output.gif (37.1 MB) · demo/output.png (87.4 KB) · demo/video.gif (9.24 MB) · demo/video.mp4 (1.42 MB, binary file not shown)

demo/yahoo_50.png (69.8 KB) · demo/yahoo_100.png (62.2 KB) · demo/yahoo_200.png (61.4 KB) · demo/yahoo_300.png (69.3 KB)

demo/yelpreviewfull_50.png (62.4 KB) · demo/yelpreviewfull_300.png (63.9 KB)

demo/yelpreviewpolarity_50.png (68.4 KB) · demo/yelpreviewpolarity_100.png (66.6 KB) · demo/yelpreviewpolarity_200.png (68 KB)
src/dataset.py (+69 lines)
"""
@author: Viet Nguyen <[email protected]>
"""
import csv

import numpy as np
import pandas as pd
from nltk.tokenize import sent_tokenize, word_tokenize
from torch.utils.data.dataset import Dataset


class MyDataset(Dataset):

    def __init__(self, data_path, dict_path, max_length_sentences=30, max_length_word=35):
        super(MyDataset, self).__init__()

        # Each CSV row is: label, followed by one or more text columns.
        texts, labels = [], []
        with open(data_path) as csv_file:
            reader = csv.reader(csv_file, quotechar='"')
            for line in reader:
                text = ""
                for tx in line[1:]:
                    text += tx.lower()
                    text += " "
                label = int(line[0]) - 1  # labels in the CSV are 1-based
                texts.append(text)
                labels.append(label)

        self.texts = texts
        self.labels = labels
        # Only the first column of the word2vec file (the words) is needed.
        self.dict = pd.read_csv(filepath_or_buffer=dict_path, header=None, sep=" ",
                                quoting=csv.QUOTE_NONE, usecols=[0]).values
        self.dict = [word[0] for word in self.dict]
        self.max_length_sentences = max_length_sentences
        self.max_length_word = max_length_word
        self.num_classes = len(set(self.labels))

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, index):
        label = self.labels[index]
        text = self.texts[index]
        # Encode the document as a matrix of word indices; -1 marks unknown words.
        document_encode = [
            [self.dict.index(word) if word in self.dict else -1 for word in word_tokenize(text=sentences)]
            for sentences in sent_tokenize(text=text)]

        # Pad every sentence to max_length_word words ...
        for sentences in document_encode:
            if len(sentences) < self.max_length_word:
                extended_words = [-1 for _ in range(self.max_length_word - len(sentences))]
                sentences.extend(extended_words)

        # ... and pad the document to max_length_sentences sentences.
        if len(document_encode) < self.max_length_sentences:
            extended_sentences = [[-1 for _ in range(self.max_length_word)]
                                  for _ in range(self.max_length_sentences - len(document_encode))]
            document_encode.extend(extended_sentences)

        # Truncate anything longer than the maximum lengths.
        document_encode = [sentences[:self.max_length_word] for sentences in document_encode][
                          :self.max_length_sentences]

        document_encode = np.stack(arrays=document_encode, axis=0)
        # Shift indices by one so that padding/unknown (-1) becomes index 0.
        document_encode += 1

        return document_encode.astype(np.int64), label


if __name__ == '__main__':
    test = MyDataset(data_path="../data/test.csv", dict_path="../data/glove.6B.50d.txt")
    print(test.__getitem__(index=1)[0].shape)
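
Not part of the committed file — a quick sketch of what MyDataset yields when wrapped in a standard DataLoader (paths are illustrative):

from torch.utils.data import DataLoader

from src.dataset import MyDataset

# Each batch is a LongTensor of word indices of shape
# (batch_size, max_length_sentences, max_length_word), here (128, 30, 35).
training_set = MyDataset(data_path="data/train.csv", dict_path="data/glove.6B.50d.txt")
loader = DataLoader(training_set, batch_size=128, shuffle=True)
feature, label = next(iter(loader))
print(feature.shape)  # torch.Size([128, 30, 35])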

src/hierarchical_att_model.py (+44 lines)
"""
@author: Viet Nguyen <[email protected]>
"""
import torch
import torch.nn as nn

from src.sent_att_model import SentAttNet
from src.word_att_model import WordAttNet


class HierAttNet(nn.Module):
    def __init__(self, word_hidden_size, sent_hidden_size, batch_size, num_classes, pretrained_word2vec_path,
                 max_sent_length, max_word_length):
        super(HierAttNet, self).__init__()
        self.batch_size = batch_size
        self.word_hidden_size = word_hidden_size
        self.sent_hidden_size = sent_hidden_size
        self.max_sent_length = max_sent_length
        self.max_word_length = max_word_length
        self.word_att_net = WordAttNet(pretrained_word2vec_path, word_hidden_size)
        self.sent_att_net = SentAttNet(sent_hidden_size, word_hidden_size, num_classes)
        self._init_hidden_state()

    def _init_hidden_state(self, last_batch_size=None):
        # Reset the GRU hidden states; the leading 2 is the number of directions
        # of the bidirectional GRUs used at the word and sentence levels.
        if last_batch_size:
            batch_size = last_batch_size
        else:
            batch_size = self.batch_size
        self.word_hidden_state = torch.zeros(2, batch_size, self.word_hidden_size)
        self.sent_hidden_state = torch.zeros(2, batch_size, self.sent_hidden_size)
        if torch.cuda.is_available():
            self.word_hidden_state = self.word_hidden_state.cuda()
            self.sent_hidden_state = self.sent_hidden_state.cuda()

    def forward(self, input):
        # input: (batch, sentences, words); iterate over sentences.
        output_list = []
        input = input.permute(1, 0, 2)
        for i in input:
            # Word-level attention turns each sentence into a single vector.
            output, self.word_hidden_state = self.word_att_net(i.permute(1, 0), self.word_hidden_state)
            output_list.append(output)
        output = torch.cat(output_list, 0)
        # Sentence-level attention aggregates sentence vectors into class scores.
        output, self.sent_hidden_state = self.sent_att_net(output, self.sent_hidden_state)

        return output
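
Not part of the committed file — a shape-check sketch for the forward pass (hyperparameters and the GloVe path are illustrative, and a CPU-only run is assumed, since _init_hidden_state moves the hidden states to the GPU when one is available):

import torch

from src.hierarchical_att_model import HierAttNet

model = HierAttNet(word_hidden_size=50, sent_hidden_size=50, batch_size=4, num_classes=14,
                   pretrained_word2vec_path="data/glove.6B.50d.txt",
                   max_sent_length=30, max_word_length=35)
docs = torch.randint(0, 1000, (4, 30, 35))  # a batch of 4 encoded documents
model._init_hidden_state()
logits = model(docs)
print(logits.shape)  # expected: torch.Size([4, 14]) — one score per class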
