
Commit 16dd3b8

init

0 parents  commit 16dd3b8


96 files changed: +11464 -0 lines changed

.gitignore

+4
@@ -0,0 +1,4 @@
*.pyc
*.csv
.*
*.zip

checkpoints/.gitkeep

Whitespace-only changes.

config.py

+136
@@ -0,0 +1,136 @@
#coding:utf8
import time
import warnings

tfmt = '%m%d_%H%M%S'


class Config(object):
    '''
    Not all of these settings take effect; at runtime each model only reads
    the parameters it actually needs.
    '''

    loss = 'multilabelloss'
    model = 'CNNText'
    title_dim = 100  # number of conv kernels for the title
    content_dim = 200  # number of conv kernels for the description
    num_classes = 1999  # number of classes
    embedding_dim = 256  # embedding size
    linear_hidden_size = 2000  # number of hidden units in the fully connected layer
    kmax_pooling = 2  # k
    hidden_size = 256  # LSTM hidden size
    num_layers = 2  # LSTM layers
    inception_dim = 512  # number of conv kernels in the inception module

    # vocab_size = 11973  # num of chars
    vocab_size = 411720  # num of words
    kernel_size = 3  # single-scale conv kernel
    kernel_sizes = [2, 3, 4]  # multi-scale conv kernels
    title_seq_len = 50  # title length: 30 for word, 50 for char
    content_seq_len = 250  # description length: 120 for word, 250 for char
    type_ = 'word'  # 'word' or 'char'
    all = False  # train char and word simultaneously

    embedding_path = '/mnt/7/zhihu/ieee_zhihu_cup/data/char_embedding.npz'  # Embedding
    train_data_path = '/mnt/7/zhihu/ieee_zhihu_cup/data/train.npz'  # train
    labels_path = '/mnt/7/zhihu/ieee_zhihu_cup/data/labels.json'  # labels
    test_data_path = '/mnt/7/zhihu/ieee_zhihu_cup/data/test.npz'  # test
    result_path = 'csv/' + time.strftime(tfmt) + '.csv'
    shuffle = True  # whether to shuffle the data
    num_workers = 4  # number of worker threads for data loading
    pin_memory = True  # speed up CPU -> pinned memory -> GPU transfers
    batch_size = 128

    env = time.strftime(tfmt)  # Visdom env
    plot_every = 10  # update visdom etc. every 10 batches

    max_epoch = 100
    lr = 5e-3  # learning rate
    lr2 = 1e-3  # learning rate of the embedding layer
    min_lr = 1e-5  # stop training once the learning rate falls below this value
    lr_decay = 0.99  # when an epoch's loss starts to rise: lr = lr * lr_decay
    weight_decay = 0  # 2e-5 # weight decay
    weight = 1  # weight of positive vs negative samples
    decay_every = 3000  # check the score every this many batches and adjust the learning rate

    model_path = None  # load it if present
    optimizer_path = 'optimizer.pth'  # where the optimizer state is saved

    debug_file = '/tmp/debug2'  # enter debug mode if this file exists
    debug = False

    gpu1 = False  # if the code runs on GPU1, the data paths must be changed
    floyd = False  # if the service runs on floyd, the file paths must be changed
    zhuge = False  # if the service runs on zhuge, the file paths must be changed

    ### used by multimode
    model_names = ['MultiCNNTextBNDeep', 'CNNText_inception', 'RCNN', 'LSTMText', 'CNNText_inception']
    model_paths = ['checkpoints/MultiCNNTextBNDeep_0.37125473788', 'checkpoints/CNNText_tmp_0.380390420742', 'checkpoints/RCNN_word_0.373609030286', 'checkpoints/LSTMText_word_0.381833388089', 'checkpoints/CNNText_tmp_0.376364647145']  # ,'checkpoints/CNNText_tmp_0.402429167301']
    static = False  # whether the embedding is trained
    val = False  # run on the test set or the validation set?

    fold = 1  # dataset fold, 0 or 1; see data/fold_dataset.py
    augument = True  # whether to apply data augmentation

    ### stack
    model_num = 7
    data_root = "/data/text/zhihu/result/"
    labels_file = "/home/a/code/pytorch/zhihu/ddd/labels.json"
    val = "/home/a/code/pytorch/zhihu/ddd/val.npz"  # NOTE: shadows the boolean `val` above


def parse(self, kwargs, print_=True):
    '''
    Update the config parameters from the dict kwargs.
    '''
    for k, v in kwargs.items():
        if not hasattr(self, k):
            raise Exception("opt has no attribute <%s>" % k)
        setattr(self, k, v)

    ###### automatically fix the data paths according to which server the program runs on ######
    if self.gpu1:
        self.train_data_path = '/mnt/zhihu/data/train.npz'
        self.test_data_path = '/mnt/zhihu/data/%s.npz' % ('val' if self.val else 'test')
        self.labels_path = '/mnt/zhihu/data/labels.json'
        self.embedding_path = self.embedding_path.replace('/mnt/7/zhihu/ieee_zhihu_cup/', '/mnt/zhihu/')

    if self.floyd:
        self.train_data_path = '/data/train.npz'
        self.test_data_path = '/data/%s.npz' % ('val' if self.val else 'test')
        self.labels_path = '/data/labels.json'
        self.embedding_path = '/data/char_embedding.npz'
    if self.zhuge:
        self.train_data_path = './ddd/train.npz'
        self.test_data_path = './ddd/%s.npz' % ('val' if self.val else 'test')
        self.labels_path = './ddd/labels.json'
        self.embedding_path = './ddd/char_embedding.npz'

    ### word and char use different lengths ##
    if self.type_ == 'word':
        self.vocab_size = 411720  # num of words
        self.title_seq_len = 30
        self.content_seq_len = 120
        self.embedding_path = self.embedding_path.replace('char', 'word') if self.embedding_path is not None else None

    if self.type_ == 'char':
        self.vocab_size = 11973  # num of chars
        self.title_seq_len = 50
        self.content_seq_len = 250

    if self.model_path:
        self.embedding_path = None

    if print_:
        print('user config:')
        print('#################################')
        for k in dir(self):
            if not k.startswith('_') and k != 'parse' and k != 'state_dict':
                print(k, getattr(self, k))
        print('#################################')
    return self


def state_dict(self):
    return {k: getattr(self, k) for k in dir(self)
            if not k.startswith('_') and k != 'parse' and k != 'state_dict'}


Config.parse = parse
Config.state_dict = state_dict
opt = Config()
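
A minimal usage sketch (the override values below are hypothetical; parse and state_dict are the functions attached above). parse validates every key against the existing attributes, applies the overrides, then re-derives the server-specific paths and the word/char sequence lengths:

from config import opt

opt.parse({'model': 'LSTMText', 'type_': 'char', 'batch_size': 64})
# prints the full user config between the '#' rulers and returns opt
snapshot = opt.state_dict()  # plain dict of every public, non-method setting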

data/__init__.py

+1
@@ -0,0 +1 @@
from .fold_dataset import FoldData

data/dataset.1.py

+163
@@ -0,0 +1,163 @@
#encoding:utf-8
from torch.utils import data
import torch as t
import numpy as np
import random
from glob import glob


class StackData(data.Dataset):
    # Concatenate each model's sigmoid validation scores (first 100 rows)
    # side by side into one feature matrix for stacking.
    def __init__(self, data_root, labels_file):
        self.data_files_path = glob(data_root + "*val.pth")
        self.model_num = len(self.data_files_path)
        self.label_file_path = labels_file
        self.data = t.zeros(100, 1999 * self.model_num)
        for i in range(self.model_num):
            self.data[:, i * 1999:i * 1999 + 1999] = t.sigmoid(t.load(self.data_files_path[i]).float()[:100])
        print(self.data.size())


class ZhihuData(data.Dataset):

    def __init__(self, train_root, labels_file, type_='char'):
        '''
        Dataset('/mnt/7/zhihu/ieee_zhihu_cup/train.npz','/mnt/7/zhihu/ieee_zhihu_cup/a.json')
        '''
        import json
        with open(labels_file) as f:
            labels_ = json.load(f)

        # embedding_d = np.load(embedding_root)['vector']
        question_d = np.load(train_root)
        self.type_ = type_
        if type_ == 'char':
            all_data_title, all_data_content = \
                question_d['title_char'], question_d['content_char']

        elif type_ == 'word':
            all_data_title, all_data_content = \
                question_d['title_word'], question_d['content_word']

        # the last 20000 questions are held out for validation
        self.train_data = all_data_title[:-20000], all_data_content[:-20000]
        self.val_data = all_data_title[-20000:], all_data_content[-20000:]

        self.all_num = len(all_data_content)
        # del all_data_title,all_data_content

        self.data_title, self.data_content = self.train_data
        self.len_ = len(self.data_title)

        self.index2qid = question_d['index2qid'].item()
        self.l_end = 0
        self.labels = labels_['d']

    # def augument(self,d):
    #     '''
    #     Data augmentation: random shift
    #     '''
    #     if self.type_=='char':
    #         _index = (-8,8)
    #     else :_index =(-5,5)
    #     r = d.new(d.size()).fill_(0)
    #     index = random.randint(-3,4)
    #     if _index >0:
    #         r[index:] = d[:-index]
    #     else:
    #         r[:-index] = d[index:]
    #     return r

    # def augument(self,d,type_=1):
    #     if type_==1:
    #         return self.shuffle(d)
    #     else :
    #         if self.type_=='char':
    #             return self.dropout(d,p=0.6)

    def shuffle(self, d):
        # augmentation: randomly permute the token order
        return np.random.permutation(d.tolist())

    def dropout(self, d, p=0.5):
        # augmentation: zero out a random fraction p of the tokens
        len_ = len(d)
        index = np.random.choice(len_, int(len_ * p))
        d[index] = 0
        return d

    def train(self, train=True):
        if train:
            self.data_title, self.data_content = self.train_data
            self.l_end = 0
        else:
            self.data_title, self.data_content = self.val_data
            self.l_end = self.all_num - 20000  # offset of the val split in index2qid
        self.len_ = len(self.data_content)
        return self

    def __getitem__(self, index):
        '''
        for (title,content),label in dataloader:
            train

        when using char:
        title: (50,)
        content: (250,)
        labels: (1999,)
        '''
        title, content = self.data_title[index], self.data_content[index]
        qid = self.index2qid[index + self.l_end]
        labels = self.labels[qid]
        data = (t.from_numpy(title).long(), t.from_numpy(content).long())
        label_tensor = t.zeros(1999).scatter_(0, t.LongTensor(labels), 1).long()
        return data, label_tensor

    def __len__(self):
        return self.len_
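
# Usage sketch for ZhihuData (paths taken from the docstring above; the batch
# size is hypothetical). train(True)/train(False) switches the same object
# between the train and validation splits; each batch then yields the
# (title, content) LongTensors plus a 1999-dim multi-hot label:
#
#   dataset = ZhihuData('/mnt/7/zhihu/ieee_zhihu_cup/train.npz',
#                       '/mnt/7/zhihu/ieee_zhihu_cup/a.json', type_='char')
#   loader = data.DataLoader(dataset.train(True), batch_size=128, shuffle=True)
#   for (title, content), label in loader:
#       pass  # title: (128, 50), content: (128, 250), label: (128, 1999)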

class ZhihuALLData(data.Dataset):

    def __init__(self, train_root, labels_file, type_='char'):
        '''
        Dataset('/mnt/7/zhihu/ieee_zhihu_cup/train.npz','/mnt/7/zhihu/ieee_zhihu_cup/a.json')
        '''
        import json
        with open(labels_file) as f:
            labels_ = json.load(f)

        # embedding_d = np.load(embedding_root)['vector']
        question_d = np.load(train_root)

        # all_data_title,all_data_content =\
        all_char_title, all_char_content = question_d['title_char'], question_d['content_char']
        # all_data_title,all_data_content =\
        all_word_title, all_word_content = question_d['title_word'], question_d['content_word']

        self.train_data = (all_char_title[:-20000], all_char_content[:-20000]), (all_word_title[:-20000], all_word_content[:-20000])
        self.val_data = (all_char_title[-20000:], all_char_content[-20000:]), (all_word_title[-20000:], all_word_content[-20000:])
        self.all_num = len(all_char_title)
        # del all_data_title,all_data_content

        self.data_title, self.data_content = self.train_data
        self.len_ = len(self.data_title[0])

        self.index2qid = question_d['index2qid'].item()
        self.l_end = 0
        self.labels = labels_['d']

    def train(self, train=True):
        if train:
            self.data_title, self.data_content = self.train_data
            self.l_end = 0
        else:
            self.data_title, self.data_content = self.val_data
            self.l_end = self.all_num - 20000
        self.len_ = len(self.data_content[0])
        return self

    def __getitem__(self, index):
        '''
        for (title,content),label in dataloader:
            train

        when using char:
        title: (50,)
        content: (250,)
        labels: (1999,)
        '''
        char,word = (self.data_titl
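
The multi-hot label construction used in __getitem__ above deserves a standalone sketch: scatter_ writes a 1 at every tagged topic index, turning a variable-length list of topic ids into a fixed 1999-dim target (the ids below are hypothetical):

import torch as t

topic_ids = [3, 17, 1042]  # hypothetical topic indices for one question
label = t.zeros(1999).scatter_(0, t.LongTensor(topic_ids), 1).long()
print(label.sum())  # sums to 3: one 1 per tagged topic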
