1
1
# -*- coding: utf-8 -*-
2
2
from enum import Enum
3
- from typing import Dict , Tuple , Type
3
+ from typing import Dict , NamedTuple , Tuple , Type
4
4
5
5
import matplotlib .pyplot as plt
6
6
import torch as th
21
21
Teacher ,
22
22
)
23
23
24
+ DatasetOptions = NamedTuple (
25
+ "DatasetOptions" ,
26
+ [
27
+ ("name" , str ),
28
+ ("x" , th .Tensor ),
29
+ ("y" , th .Tensor ),
30
+ ("train_ratio" , float ),
31
+ ],
32
+ )
33
+
34
+ StudentOptions = NamedTuple (
35
+ "StudentOptions" ,
36
+ [
37
+ ("examples" , int ),
38
+ ("steps" , int ),
39
+ ("batch_size" , int ),
40
+ ("learning_rate" , float ),
41
+ ],
42
+ )
43
+
44
+ TeacherOptions = NamedTuple (
45
+ "TeacherOptions" ,
46
+ [
47
+ ("learning_rate" , float ),
48
+ ("batch_size" , int ),
49
+ ("research_batch_size" , int ),
50
+ ("nb_epoch" , int ),
51
+ ],
52
+ )
53
+
24
54
25
55
class TeachingType (Enum ):
26
56
OMNISCIENT = "OMNISCIENT"
@@ -49,56 +79,68 @@ def get_student(self, clf: Classifier, learning_rate: float) -> Student:
49
79
50
80
51
81
def train (
52
- dataset : Tuple [th .Tensor , th .Tensor ],
53
- dataset_name : str ,
82
+ dataset_options : DatasetOptions ,
54
83
kind : TeachingType ,
55
- example_nb_student : int ,
84
+ teacher_options : TeacherOptions ,
85
+ student_options : StudentOptions ,
86
+ cuda : bool ,
56
87
) -> None :
57
88
58
- x , y = dataset
89
+ assert 0.0 < dataset_options . train_ratio < 1.0
59
90
60
- num_features = x .size ()[1 ] # 784
61
- num_classes = th .unique (y ).size ()[0 ] # 10
91
+ x , y = dataset_options .x , dataset_options .y
92
+
93
+ num_features = x .size ()[1 ]
94
+ num_classes = th .unique (y ).size ()[0 ]
62
95
63
96
print (
64
- f'Dataset "{ dataset_name } " of { x .size ()[0 ]} '
97
+ f'Dataset "{ dataset_options . name } " of { x .size ()[0 ]} '
65
98
f"examples with { kind .value } teacher."
66
99
)
67
100
68
- ratio_train = 4.0 / 5.0
69
- limit_train = int (x .size ()[0 ] * ratio_train )
101
+ limit_train = int (x .size ()[0 ] * dataset_options .train_ratio )
70
102
71
- x_train = x [:limit_train , :]. cuda ()
72
- y_train = y [:limit_train ]. cuda ()
103
+ x_train = x [:limit_train , :]
104
+ y_train = y [:limit_train ]
73
105
74
- x_test = x [limit_train :, :]. cuda ()
75
- y_test = y [limit_train :]. cuda ()
106
+ x_test = x [limit_train :, :]
107
+ y_test = y [limit_train :]
76
108
77
109
# create models
78
- student_model = LinearClassifier (num_features , num_classes ).cuda ()
79
- teacher_model = LinearClassifier (num_features , num_classes ).cuda ()
110
+ student_model = LinearClassifier (num_features , num_classes )
111
+ example_model = LinearClassifier (num_features , num_classes )
112
+ teacher_model = LinearClassifier (num_features , num_classes )
113
+
114
+ # cuda or not
115
+ if cuda :
116
+ x_train = x_train .cuda ()
117
+ y_train = y_train .cuda ()
118
+
119
+ x_test = x_test .cuda ()
120
+ y_test = y_test .cuda ()
80
121
81
- # create student and teacher
82
- learning_rate = 1e-3
83
- research_batch_size = 512
122
+ student_model = student_model . cuda ()
123
+ example_model = example_model . cuda ()
124
+ teacher_model = teacher_model . cuda ()
84
125
85
- student = kind .get_student (student_model , learning_rate )
126
+ # create student, example and teacher
127
+ student = kind .get_student (student_model , student_options .learning_rate )
128
+ example = ModelWrapper (example_model , student_options .learning_rate )
86
129
teacher = kind .get_teacher (
87
- teacher_model , learning_rate , research_batch_size
130
+ teacher_model ,
131
+ teacher_options .learning_rate ,
132
+ teacher_options .research_batch_size ,
88
133
)
89
134
90
135
# Train teacher
91
136
print ("Train teacher..." )
137
+ nb_batch_teacher = x_train .size ()[0 ] // teacher_options .batch_size
92
138
93
- nb_epoch_teacher = 25
94
- batch_size_teacher = 32
95
- nb_batch_teacher = x_train .size ()[0 ] // batch_size_teacher
96
-
97
- tqdm_bar = tqdm (range (nb_epoch_teacher ))
139
+ tqdm_bar = tqdm (range (teacher_options .nb_epoch ))
98
140
for e in tqdm_bar :
99
141
for b_idx in range (nb_batch_teacher ):
100
- i_min = b_idx * batch_size_teacher
101
- i_max = (b_idx + 1 ) * batch_size_teacher
142
+ i_min = b_idx * teacher_options . batch_size
143
+ i_max = (b_idx + 1 ) * teacher_options . batch_size
102
144
103
145
_ = teacher .train (x_train [i_min :i_max ], y_train [i_min :i_max ])
104
146
@@ -109,35 +151,31 @@ def train(
109
151
110
152
tqdm_bar .set_description (f"Epoch { e } : F1-Score = { f1_score_value } " )
111
153
112
- # For comparison
154
+ # For benchmark
113
155
114
156
# to avoid a lot of compute...
115
157
# if negative -> all train examples
116
- example_nb_student = (
117
- example_nb_student if example_nb_student >= 0 else x_train .size ()[0 ]
158
+ student_examples = (
159
+ student_options .examples
160
+ if student_options .examples >= 0
161
+ else x_train .size ()[0 ]
118
162
)
119
- x_train = x_train [:example_nb_student ]
120
- y_train = y_train [:example_nb_student ]
163
+ x_train = x_train [:student_examples ]
164
+ y_train = y_train [:student_examples ]
121
165
122
- rounds = 1024
123
- batch_size = 16
124
- nb_batch = x_train .size ()[0 ] // batch_size
166
+ nb_batch = x_train .size ()[0 ] // student_options .batch_size
125
167
126
168
# train example
127
169
print ("Train example..." )
128
170
129
- example = ModelWrapper (
130
- LinearClassifier (num_features , num_classes ).cuda (), learning_rate
131
- )
132
-
133
171
batch_index_example = 0
134
172
loss_values_example = []
135
173
metrics_example = []
136
174
137
- for _ in tqdm (range (rounds )):
175
+ for _ in tqdm (range (student_options . steps )):
138
176
b_idx = batch_index_example % nb_batch
139
- i_min = b_idx * batch_size
140
- i_max = (b_idx + 1 ) * batch_size
177
+ i_min = b_idx * student_options . batch_size
178
+ i_max = (b_idx + 1 ) * student_options . batch_size
141
179
142
180
loss = example .train (x_train [i_min :i_max ], y_train [i_min :i_max ])
143
181
@@ -158,9 +196,9 @@ def train(
158
196
loss_values_student = []
159
197
metrics_student = []
160
198
161
- for _ in tqdm (range (rounds )):
199
+ for _ in tqdm (range (student_options . steps )):
162
200
selected_x , selected_y = teacher .select_n_examples (
163
- student , x_train , y_train , batch_size
201
+ student , x_train , y_train , student_options . batch_size
164
202
)
165
203
166
204
loss = student .train (selected_x , selected_y )
@@ -180,7 +218,7 @@ def train(
180
218
plt .plot (metrics_example , c = "blue" , label = "example - f1 score" )
181
219
plt .plot (metrics_student , c = "red" , label = "student - f1 score" )
182
220
183
- plt .title (f"{ dataset_name } Linear - { kind .value } " )
221
+ plt .title (f"{ dataset_options . name } Linear - { kind .value } " )
184
222
plt .xlabel ("mini-batch optim steps" )
185
223
plt .legend ()
186
224
plt .show ()
0 commit comments