Skip to content

Commit fcb4c7e

Browse files
committed
cuda as CLI option
1 parent e2c3451 commit fcb4c7e

File tree

4 files changed

+167
-57
lines changed

4 files changed

+167
-57
lines changed

Diff for: iterative_machine_teaching/__init__.py

+7-1
Original file line numberDiff line numberDiff line change
@@ -13,4 +13,10 @@
1313
SurrogateTeacher,
1414
Teacher,
1515
)
16-
from .train import TeachingType, train
16+
from .train import (
17+
DatasetOptions,
18+
StudentOptions,
19+
TeacherOptions,
20+
TeachingType,
21+
train,
22+
)

Diff for: iterative_machine_teaching/__main__.py

+74-6
Original file line numberDiff line numberDiff line change
@@ -2,22 +2,68 @@
22
import argparse
33

44
from .data import load_gaussian, load_mnist
5-
from .train import TeachingType, train
5+
from .train import (
6+
DatasetOptions,
7+
StudentOptions,
8+
TeacherOptions,
9+
TeachingType,
10+
train,
11+
)
612

713

814
def main() -> None:
915
parser = argparse.ArgumentParser("IterativeMachineTeaching")

# Boolean switch: True when the flag is present, False otherwise.
# argparse's "store_true" action does not accept a `type=` keyword
# (passing one raises TypeError while building the parser), so only the
# action is given here.
parser.add_argument(
    "--cuda",
    action="store_true",
    help="Move the data and the models to CUDA",
)
1118
parser.add_argument("kind", type=TeachingType, choices=list(TeachingType))
19+
parser.add_argument("--train-ratio", type=float, default=4.0 / 5.0)
1220

21+
# student / example options
1322
parser.add_argument(
14-
"-l",
15-
"--limit-train",
23+
"--student-examples",
1624
type=int,
1725
default=-1,
1826
help="Number of examples in student train dataset, "
1927
"negative value means max",
2028
)
29+
parser.add_argument(
30+
"--student-steps",
31+
type=int,
32+
default=1024,
33+
help="Number of forward / backward steps for student",
34+
)
35+
parser.add_argument(
36+
"--student-batch-size",
37+
type=int,
38+
default=8,
39+
help="Batch size for student and example",
40+
)
41+
parser.add_argument(
42+
"--student-lr",
43+
type=float,
44+
default=1e-3,
45+
help="Student and example learning rate",
46+
)
47+
48+
# teacher options
49+
parser.add_argument(
50+
"--teacher-lr", type=float, default=1e-3, help="Teacher learning rate"
51+
)
52+
parser.add_argument(
53+
"--teacher-batch-size", type=int, default=8, help="Teacher batch size"
54+
)
55+
parser.add_argument(
56+
"--research-batch-size",
57+
type=int,
58+
default=512,
59+
help="Batch size for example research",
60+
)
61+
parser.add_argument(
62+
"--teacher-epochs",
63+
type=int,
64+
default=16,
65+
help="Teacher training epochs",
66+
)
2167

2268
dataset_subparser = parser.add_subparsers(
2369
title="dataset", dest="dataset", required=True
@@ -35,14 +81,36 @@ def main() -> None:
3581
args = parser.parse_args()
3682

3783
if args.dataset == "mnist":
38-
dataset = load_mnist(args.input_pickle)
84+
x, y = load_mnist(args.input_pickle)
3985
elif args.dataset == "gaussian":
40-
dataset = load_gaussian(args.dim, args.per_class_example)
86+
x, y = load_gaussian(args.dim, args.per_class_example)
4187
else:
4288
parser.error("Unrecognized dataset")
4389
return
4490

45-
train(dataset, args.dataset, args.kind, args.limit_train)
91+
dataset_options = DatasetOptions(args.dataset, x, y, args.train_ratio)
92+
93+
student_options = StudentOptions(
94+
args.student_examples,
95+
args.student_steps,
96+
args.student_batch_size,
97+
args.student_lr,
98+
)
99+
100+
teacher_options = TeacherOptions(
101+
args.teacher_lr,
102+
args.teacher_batch_size,
103+
args.research_batch_size,
104+
args.teacher_epochs,
105+
)
106+
107+
train(
108+
dataset_options,
109+
args.kind,
110+
teacher_options,
111+
student_options,
112+
args.cuda,
113+
)
46114

47115

48116
if __name__ == "__main__":

Diff for: iterative_machine_teaching/train.py

+84-46
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
# -*- coding: utf-8 -*-
22
from enum import Enum
3-
from typing import Dict, Tuple, Type
3+
from typing import Dict, NamedTuple, Tuple, Type
44

55
import matplotlib.pyplot as plt
66
import torch as th
@@ -21,6 +21,36 @@
2121
Teacher,
2222
)
2323

24+
class DatasetOptions(NamedTuple):
    """Dataset handed to ``train`` together with its split settings.

    Attributes:
        name: Label of the dataset, shown in the console message and in
            the plot title.
        x: Example tensor; ``train`` reads the feature count from its
            second dimension.
        y: Label tensor; ``train`` counts its unique values to get the
            number of classes.
        train_ratio: Fraction of the rows used for training, the rest
            being used for testing. ``train`` requires it to be strictly
            between 0 and 1.
    """

    name: str
    x: th.Tensor
    y: th.Tensor
    train_ratio: float
33+
34+
class StudentOptions(NamedTuple):
    """Hyper-parameters shared by the student and the baseline "example" model.

    Attributes:
        examples: Number of training examples kept for the student and
            example runs; a negative value keeps the whole training set.
        steps: Number of optimisation steps performed by each of the two
            training loops.
        batch_size: Mini-batch size of the example loop, and number of
            examples the teacher selects for the student at each step.
        learning_rate: Learning rate used to build both the student and
            the example model wrapper.
    """

    examples: int
    steps: int
    batch_size: int
    learning_rate: float
43+
44+
class TeacherOptions(NamedTuple):
    """Hyper-parameters of the teacher model.

    Attributes:
        learning_rate: Learning rate passed to the teacher constructor.
        batch_size: Mini-batch size used while training the teacher.
        research_batch_size: Batch size passed to the teacher for its
            example research.
        nb_epoch: Number of epochs of teacher training.
    """

    learning_rate: float
    batch_size: int
    research_batch_size: int
    nb_epoch: int
53+
2454

2555
class TeachingType(Enum):
2656
OMNISCIENT = "OMNISCIENT"
@@ -49,56 +79,68 @@ def get_student(self, clf: Classifier, learning_rate: float) -> Student:
4979

5080

5181
def train(
52-
dataset: Tuple[th.Tensor, th.Tensor],
53-
dataset_name: str,
82+
dataset_options: DatasetOptions,
5483
kind: TeachingType,
55-
example_nb_student: int,
84+
teacher_options: TeacherOptions,
85+
student_options: StudentOptions,
86+
cuda: bool,
5687
) -> None:
5788

58-
x, y = dataset
89+
assert 0.0 < dataset_options.train_ratio < 1.0
5990

60-
num_features = x.size()[1] # 784
61-
num_classes = th.unique(y).size()[0] # 10
91+
x, y = dataset_options.x, dataset_options.y
92+
93+
num_features = x.size()[1]
94+
num_classes = th.unique(y).size()[0]
6295

6396
print(
64-
f'Dataset "{dataset_name}" of {x.size()[0]} '
97+
f'Dataset "{dataset_options.name}" of {x.size()[0]} '
6598
f"examples with {kind.value} teacher."
6699
)
67100

68-
ratio_train = 4.0 / 5.0
69-
limit_train = int(x.size()[0] * ratio_train)
101+
limit_train = int(x.size()[0] * dataset_options.train_ratio)
70102

71-
x_train = x[:limit_train, :].cuda()
72-
y_train = y[:limit_train].cuda()
103+
x_train = x[:limit_train, :]
104+
y_train = y[:limit_train]
73105

74-
x_test = x[limit_train:, :].cuda()
75-
y_test = y[limit_train:].cuda()
106+
x_test = x[limit_train:, :]
107+
y_test = y[limit_train:]
76108

77109
# create models
78-
student_model = LinearClassifier(num_features, num_classes).cuda()
79-
teacher_model = LinearClassifier(num_features, num_classes).cuda()
110+
student_model = LinearClassifier(num_features, num_classes)
111+
example_model = LinearClassifier(num_features, num_classes)
112+
teacher_model = LinearClassifier(num_features, num_classes)
113+
114+
# move the data and the three models to the GPU when cuda is requested
115+
if cuda:
116+
x_train = x_train.cuda()
117+
y_train = y_train.cuda()
118+
119+
x_test = x_test.cuda()
120+
y_test = y_test.cuda()
80121

81-
# create student and teacher
82-
learning_rate = 1e-3
83-
research_batch_size = 512
122+
student_model = student_model.cuda()
123+
example_model = example_model.cuda()
124+
teacher_model = teacher_model.cuda()
84125

85-
student = kind.get_student(student_model, learning_rate)
126+
# create student, example and teacher
127+
student = kind.get_student(student_model, student_options.learning_rate)
128+
example = ModelWrapper(example_model, student_options.learning_rate)
86129
teacher = kind.get_teacher(
87-
teacher_model, learning_rate, research_batch_size
130+
teacher_model,
131+
teacher_options.learning_rate,
132+
teacher_options.research_batch_size,
88133
)
89134

90135
# Train teacher
91136
print("Train teacher...")
137+
nb_batch_teacher = x_train.size()[0] // teacher_options.batch_size
92138

93-
nb_epoch_teacher = 25
94-
batch_size_teacher = 32
95-
nb_batch_teacher = x_train.size()[0] // batch_size_teacher
96-
97-
tqdm_bar = tqdm(range(nb_epoch_teacher))
139+
tqdm_bar = tqdm(range(teacher_options.nb_epoch))
98140
for e in tqdm_bar:
99141
for b_idx in range(nb_batch_teacher):
100-
i_min = b_idx * batch_size_teacher
101-
i_max = (b_idx + 1) * batch_size_teacher
142+
i_min = b_idx * teacher_options.batch_size
143+
i_max = (b_idx + 1) * teacher_options.batch_size
102144

103145
_ = teacher.train(x_train[i_min:i_max], y_train[i_min:i_max])
104146

@@ -109,35 +151,31 @@ def train(
109151

110152
tqdm_bar.set_description(f"Epoch {e} : F1-Score = {f1_score_value}")
111153

112-
# For comparison
154+
# For benchmark
113155

114156
# limit the number of training examples to keep the compute manageable
115157
# a negative value means: use every training example
116-
example_nb_student = (
117-
example_nb_student if example_nb_student >= 0 else x_train.size()[0]
158+
student_examples = (
159+
student_options.examples
160+
if student_options.examples >= 0
161+
else x_train.size()[0]
118162
)
119-
x_train = x_train[:example_nb_student]
120-
y_train = y_train[:example_nb_student]
163+
x_train = x_train[:student_examples]
164+
y_train = y_train[:student_examples]
121165

122-
rounds = 1024
123-
batch_size = 16
124-
nb_batch = x_train.size()[0] // batch_size
166+
nb_batch = x_train.size()[0] // student_options.batch_size
125167

126168
# train example
127169
print("Train example...")
128170

129-
example = ModelWrapper(
130-
LinearClassifier(num_features, num_classes).cuda(), learning_rate
131-
)
132-
133171
batch_index_example = 0
134172
loss_values_example = []
135173
metrics_example = []
136174

137-
for _ in tqdm(range(rounds)):
175+
for _ in tqdm(range(student_options.steps)):
138176
b_idx = batch_index_example % nb_batch
139-
i_min = b_idx * batch_size
140-
i_max = (b_idx + 1) * batch_size
177+
i_min = b_idx * student_options.batch_size
178+
i_max = (b_idx + 1) * student_options.batch_size
141179

142180
loss = example.train(x_train[i_min:i_max], y_train[i_min:i_max])
143181

@@ -158,9 +196,9 @@ def train(
158196
loss_values_student = []
159197
metrics_student = []
160198

161-
for _ in tqdm(range(rounds)):
199+
for _ in tqdm(range(student_options.steps)):
162200
selected_x, selected_y = teacher.select_n_examples(
163-
student, x_train, y_train, batch_size
201+
student, x_train, y_train, student_options.batch_size
164202
)
165203

166204
loss = student.train(selected_x, selected_y)
@@ -180,7 +218,7 @@ def train(
180218
plt.plot(metrics_example, c="blue", label="example - f1 score")
181219
plt.plot(metrics_student, c="red", label="student - f1 score")
182220

183-
plt.title(f"{dataset_name} Linear - {kind.value}")
221+
plt.title(f"{dataset_options.name} Linear - {kind.value}")
184222
plt.xlabel("mini-batch optim steps")
185223
plt.legend()
186224
plt.show()

Diff for: setup.py

+2-4
Original file line numberDiff line numberDiff line change
@@ -3,12 +3,10 @@
33

44
setup(
55
name="iterative_machine_teaching",
6-
version="1.0.0",
6+
version="1.1.0",
77
packages=["iterative_machine_teaching"],
88
url="https://github.com/Ipsedo/IterativeMachineTeaching",
9-
license="",
9+
license="GPL-3.0",
1010
author="Ipsedo",
11-
author_email="",
1211
description="Iterative Machine Teaching implementation",
13-
test_suite="tests",
1412
)

0 commit comments

Comments
 (0)