1
+ #encoding:utf-8
2
+ from torch .utils import data
3
+ import torch as t
4
+ import numpy as np
5
+ import random
6
+ from glob import glob
7
class StackData(data.Dataset):
    """Dataset of concatenated per-model validation scores, for stacking.

    Scans ``data_root`` for every ``*val.pth`` score file (one per base
    model), squashes each score tensor through a sigmoid, and lays the
    models' class scores side by side in one wide matrix: row *i* holds
    sample *i*'s scores from every model.
    """

    def __init__(self, data_root, labels_file, num_samples=100, num_classes=1999):
        """
        :param data_root: directory prefix (with trailing separator) that is
            globbed for ``*val.pth`` score files
        :param labels_file: path to the labels file; stored, not read here
        :param num_samples: number of leading rows kept from each score
            tensor (was hard-coded to 100)
        :param num_classes: number of score columns each model contributes
            (was hard-coded to 1999)
        """
        self.data_files_path = glob(data_root + "*val.pth")
        self.model_num = len(self.data_files_path)
        self.label_file_path = labels_file
        # One wide row per sample: all models' class scores concatenated.
        self.data = t.zeros(num_samples, num_classes * self.model_num)
        for i in range(self.model_num):
            scores = t.load(self.data_files_path[i]).float()[:num_samples]
            # sigmoid maps raw logits into comparable (0, 1) scores
            self.data[:, i * num_classes:(i + 1) * num_classes] = t.sigmoid(scores)
        # print() call: the original Python 2 `print x` statement is a
        # syntax error on Python 3.
        print(self.data.size())
16
+
17
class ZhihuData(data.Dataset):
    """Zhihu question dataset over a single representation (char or word).

    Loads title/content index arrays from an ``.npz`` archive and the
    question labels from a JSON file, then holds out the last 20000
    samples as a validation split.

    Usage::

        ZhihuData('/mnt/7/zhihu/ieee_zhihu_cup/train.npz',
                  '/mnt/7/zhihu/ieee_zhihu_cup/a.json')
    """

    # Size of the held-out validation split (the archive's last samples).
    _VAL_NUM = 20000

    def __init__(self, train_root, labels_file, type_='char'):
        """
        :param train_root: path to the ``.npz`` archive with question arrays
        :param labels_file: JSON file of the form ``{'d': {qid: [label, ...]}}``
        :param type_: representation to use, ``'char'`` or ``'word'``
        :raises ValueError: if ``type_`` is neither ``'char'`` nor ``'word'``
            (the original fell through and crashed later with a NameError)
        """
        import json
        with open(labels_file) as f:
            labels_ = json.load(f)

        question_d = np.load(train_root)
        self.type_ = type_
        if type_ == 'char':
            all_data_title = question_d['title_char']
            all_data_content = question_d['content_char']
        elif type_ == 'word':
            all_data_title = question_d['title_word']
            all_data_content = question_d['content_word']
        else:
            raise ValueError("type_ must be 'char' or 'word', got %r" % (type_,))

        n_val = self._VAL_NUM
        self.train_data = all_data_title[:-n_val], all_data_content[:-n_val]
        self.val_data = all_data_title[-n_val:], all_data_content[-n_val:]

        self.all_num = len(all_data_content)

        self.data_title, self.data_content = self.train_data
        self.len_ = len(self.data_title)

        # index -> question id; `.item()` unwraps the mapping stored in the
        # archive's 0-d object entry.
        self.index2qid = question_d['index2qid'].item()
        self.l_end = 0  # offset added to indices in val mode (see train())
        self.labels = labels_['d']

    def shuffle(self, d):
        """Return the elements of *d* in random order (as a numpy array)."""
        return np.random.permutation(d.tolist())

    def dropout(self, d, p=0.5):
        """Zero a random fraction *p* of the entries of *d*, in place."""
        len_ = len(d)
        index = np.random.choice(len_, int(len_ * p))
        d[index] = 0
        return d

    def train(self, train=True):
        """Switch between the training and the validation split.

        BUG FIX: the validation offset is ``all_num - 20000`` (the split
        size used in ``__init__``, and what the sibling ``ZhihuALLData``
        uses); the original wrote ``all_num - 200000``, mis-aligning every
        validation qid lookup by 180000.

        :returns: ``self`` so the call can be chained.
        """
        if train:
            self.data_title, self.data_content = self.train_data
            self.l_end = 0
        else:
            self.data_title, self.data_content = self.val_data
            # offset so index2qid is indexed with the absolute sample index
            self.l_end = self.all_num - self._VAL_NUM
        self.len_ = len(self.data_content)
        return self

    def __getitem__(self, index):
        """Return ``(title, content), label`` tensors for one question.

        With ``type_='char'``: title is (50,), content is (250,), and the
        label is a (1999,) multi-hot vector.  Intended use::

            for (title, content), label in dataloader: ...
        """
        title, content = self.data_title[index], self.data_content[index]
        qid = self.index2qid[index + self.l_end]
        labels = self.labels[qid]
        # renamed from `data` to avoid shadowing the torch.utils.data module
        pair = (t.from_numpy(title).long(), t.from_numpy(content).long())
        label_tensor = t.zeros(1999).scatter_(0, t.LongTensor(labels), 1).long()
        return pair, label_tensor

    def __len__(self):
        return self.len_
111
+
112
+ class ZhihuALLData (data .Dataset ):
113
+
114
+ def __init__ (self ,train_root ,labels_file ,type_ = 'char' ):
115
+ '''
116
+ Dataset('/mnt/7/zhihu/ieee_zhihu_cup/train.npz','/mnt/7/zhihu/ieee_zhihu_cup/a.json')
117
+ '''
118
+ import json
119
+ with open (labels_file ) as f :
120
+ labels_ = json .load (f )
121
+
122
+ # embedding_d = np.load(embedding_root)['vector']
123
+ question_d = np .load (train_root )
124
+
125
+ # all_data_title,all_data_content =\
126
+ all_char_title ,all_char_content = question_d ['title_char' ],question_d ['content_char' ]
127
+ # all_data_title,all_data_content =\
128
+ all_word_title ,all_word_content = question_d ['title_word' ],question_d ['content_word' ]
129
+
130
+ self .train_data = (all_char_title [:- 20000 ],all_char_content [:- 20000 ]),( all_word_title [:- 20000 ],all_word_content [:- 20000 ])
131
+ self .val_data = (all_char_title [- 20000 :],all_char_content [- 20000 :]), (all_word_title [- 20000 :],all_word_content [- 20000 :])
132
+ self .all_num = len (all_char_title )
133
+ # del all_data_title,all_data_content
134
+
135
+ self .data_title ,self .data_content = self .train_data
136
+ self .len_ = len (self .data_title [0 ])
137
+
138
+ self .index2qid = question_d ['index2qid' ].item ()
139
+ self .l_end = 0
140
+ self .labels = labels_ ['d' ]
141
+
142
+
143
+ def train (self , train = True ):
144
+ if train :
145
+ self .data_title ,self .data_content = self .train_data
146
+ self .l_end = 0
147
+ else :
148
+ self .data_title ,self .data_content = self .val_data
149
+ self .l_end = self .all_num - 20000
150
+ self .len_ = len (self .data_content [0 ])
151
+ return self
152
+
153
+ def __getitem__ (self ,index ):
154
+ '''
155
+ for (title,content),label in dataloader:
156
+ train
157
+ `
158
+ 当使用char时
159
+ title: (50,)
160
+ content: (250,)
161
+ labels:(1999,)
162
+ '''
163
+ char ,word = (self .data_titl
0 commit comments