@@ -18,7 +18,7 @@ import TensorFlow
 // Workaround https://bugs.swift.org/browse/TF-1122 that prevents us from registering a
 // loss function inside our TrainingLoop struct
 public final class LossFunctionWrapper<Output: Differentiable, Target> {
-  public typealias F = @differentiable (Output, @noDerivative Target) -> Tensor<Float>
+  public typealias F = @differentiable (Output, @noDerivative Target) -> Tensor<Float>
   public var f: F
   init(_ f: @escaping F) { self.f = f }
 }
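For reference, `F` has the same shape as the standard Swift for TensorFlow losses: differentiable in the prediction, with the target excluded from differentiation. A minimal sketch of a compatible loss function (the concrete `Tensor<Float>` / `Tensor<Int32>` types are illustrative assumptions; `Tensor<Int32>` labels are not `Differentiable`, so they are implicitly non-differentiable):

// Sketch only: concrete Output/Target types are assumptions for illustration.
import TensorFlow

@differentiable
func myCustomLoss(_ logits: Tensor<Float>, _ labels: Tensor<Int32>) -> Tensor<Float> {
  return softmaxCrossEntropy(logits: logits, labels: labels)
}
// Later in this diff, TrainingLoop.init stores such a function via LossFunction(lossFunction).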
@@ -34,64 +34,94 @@ public protocol TrainingLoopProtocol {
   where
     Training: Sequence, Training.Element: Collection,
     Training.Element.Element == LabeledData<Opt.Model.Input, Target>
+
   /// The type of the collection of batches for the validation data.
   associatedtype Validation
   where
     Validation: Collection,
     Validation.Element == LabeledData<Opt.Model.Input, Target>
+
   /// The type of the target of our model.
   associatedtype Target
+
   /// The type of the optimizer used.
   associatedtype Opt: Optimizer where Opt.Model: Module

   // Typealiases
   /// The type of the model.
   typealias Model = Opt.Model
+
   /// The type of the input of the model.
   typealias Input = Opt.Model.Input
+
   /// The type of the output of the model.
   typealias Output = Opt.Model.Output
+
   /// The type of a batch.
   typealias Batch = LabeledData<Input, Target>
+
   // In a wrapper for now because of TF-1122.
   /// The type of the loss function.
   typealias LossFunction = LossFunctionWrapper<Output, Target>

   // Data
   /// The training epochs.
   var training: Training { get }
+
   /// The validation batches.
   var validation: Validation { get }

   // Optimizer and loss function
   /// The optimizer.
   var optimizer: Opt { get set }
+
   /// The loss function.
   var lossFunction: LossFunction { get set }

+  /// The metrics
+  var metrics: [TrainingMetrics] { get set }
+
   // Callbacks
   /// The callbacks used to customize the training loop.
   var callbacks: [TrainingLoopCallback<Self>] { get set }

   // Temporary data
+
+  // MARK: - Step-level data
+
   /// The last input fed to the model.
-  var lastInput: Input? { get set }
+  var lastStepInput: Input? { get set }
+
   /// The last target.
-  var lastTarget: Target? { get set }
+  var lastStepTarget: Target? { get set }
+
   /// The last predictions of the model.
-  var lastOutput: Output? { get set }
+  var lastStepOutput: Output? { get set }
+
   /// The last gradients computed.
-  var lastGradient: Model.TangentVector? { get set }
+  var lastStepGradient: Model.TangentVector? { get set }
+
   /// The last loss.
-  var lastLoss: Tensor<Float>? { get set }
-  /// The number of epochs we are currently fitting for.
-  var epochCount: Int? { get set }
-  /// The index of the current epoch.
-  var epochIndex: Int? { get set }
+  var lastStepLoss: Tensor<Float>? { get set }
+
   /// The number of batches in the current collection of batches.
   var batchCount: Int? { get set }
+
   /// The index of the current batch.
   var batchIndex: Int? { get set }
+
+  // MARK: - Epoch-level data
+
+  /// The number of epochs we are currently fitting for.
+  var epochCount: Int? { get set }
+
+  /// The index of the current epoch.
+  var epochIndex: Int? { get set }
+
+  // MARK: - Others
+
+  /// The log for last statistics
+  var lastStatsLog: [(name: String, value: Float)]? { get set }
 }

 /// The events that occur during a call to `fit` in the `TrainingLoop`
@@ -101,26 +131,37 @@ public protocol TrainingLoopProtocol {
 public enum TrainingLoopEvent {
   /// The start of a fit.
   case fitStart
+
   /// The end of a fit.
   case fitEnd
+
   /// The start of one epoch (training + validation).
   case epochStart
+
   /// The end of one epoch (training + validation).
   case epochEnd
+
   /// The start of a training phase.
   case trainingStart
+
   /// The end of a training phase.
   case trainingEnd
+
   /// The start of a validation phase.
   case validationStart
+
   /// The end of a validation phase.
   case validationEnd
+
   /// The start of a training or inference step on a batch.
   case batchStart
+
   /// The end of a training or inference step on a batch.
   case batchEnd
+
   /// At the start of the optimizer update, just after the differentiable step.
   case updateStart
+
   /// Just after the model prediction at inference, before computing the loss.
   case inferencePredictionEnd
 }
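Callbacks receive these events together with the loop itself; a minimal sketch of a custom callback (assuming, as its use elsewhere in this file suggests, that `TrainingLoopCallback<L>` is a closure of the form `(inout L, TrainingLoopEvent) throws -> Void`):

// Sketch only; the callback signature is inferred from how callbacks are used in this file.
func loggingCallback<L: TrainingLoopProtocol>(
  _ loop: inout L, _ event: TrainingLoopEvent
) throws {
  switch event {
  case .epochStart:
    print("Epoch \((loop.epochIndex ?? 0) + 1)/\(loop.epochCount ?? 0)")
  case .fitEnd:
    print("Done fitting.")
  default:
    break
  }
}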
@@ -146,87 +187,139 @@ where
   // Typealiases
   /// The type of the model.
   public typealias Model = Opt.Model
+
   /// The type of the input of the model.
   public typealias Input = Opt.Model.Input
+
   /// The type of the output of the model.
   public typealias Output = Opt.Model.Output
+
   /// The type of a batch.
   public typealias Batch = LabeledData<Input, Target>
+
   // In a wrapper for now because of TF-1122.
   /// The type of the loss function.
   public typealias LossFunction = LossFunctionWrapper<Output, Target>

   // Data
   /// The training epochs.
   public let training: Training
+
   /// The validation batches.
   public let validation: Validation

   // Optimizer and loss function
   /// The optimizer.
   public var optimizer: Opt
+
   /// The loss function
   public var lossFunction: LossFunction

-  // Callbacks
-  /// The callbacks used to customize the training loop.
-  public var callbacks: [TrainingLoopCallback<Self>] = []
+  /// The metrics
+  public var metrics: [TrainingMetrics]
+
+  /// Callbacks
+
+  // MARK: - The callbacks used to customize the training loop.
+
+  public var callbacks: [TrainingLoopCallback<Self>]
+
+  // MARK: - Default callback objects
+
+  public var statisticsRecorder: StatisticsRecorder? = nil
+
+  public var progressPrinter: ProgressPrinter? = nil
+
+  /// Temporary data
+
+  // MARK: - Step-level data

-  // Temporary data
   /// The last input fed to the model.
-  public var lastInput: Input? = nil
+  public var lastStepInput: Input? = nil
+
   /// The last target.
-  public var lastTarget: Target? = nil
+  public var lastStepTarget: Target? = nil
+
   /// The last predictions of the model.
-  public var lastOutput: Output? = nil
+  public var lastStepOutput: Output? = nil
+
   /// The last gradients computed.
-  public var lastGradient: Model.TangentVector? = nil
+  public var lastStepGradient: Model.TangentVector? = nil
+
   /// The last loss.
-  public var lastLoss: Tensor<Float>? = nil
-  /// The number of epochs we are currently fitting for.
-  public var epochCount: Int? = nil
-  /// The index of the current epoch.
-  public var epochIndex: Int? = nil
+  public var lastStepLoss: Tensor<Float>? = nil
+
   /// The number of batches in the current collection of batches.
   public var batchCount: Int? = nil
+
   /// The index of the current batch.
   public var batchIndex: Int? = nil

+  // MARK: - Epoch-level data
+
+  /// The number of epochs we are currently fitting for.
+  public var epochCount: Int? = nil
+
+  /// The index of the current epoch.
+  public var epochIndex: Int? = nil
+
+  // MARK: - Others
+
+  /// The log for last statistics
+  public var lastStatsLog: [(name: String, value: Float)]? = nil
+
   /// Creates an instance from `training` and `validation` data, a `model`, an `optimizer` and a
   /// `lossFunction`.
   ///
   /// Parameter callbacks: Callbacks that the `TrainingLoop` will use in every call to fit.
   public init(
     training: Training, validation: Validation, optimizer: Opt,
-    lossFunction: @escaping LossFunction.F, callbacks: [TrainingLoopCallback<Self>] = []
+    lossFunction: @escaping LossFunction.F,
+    metrics: [TrainingMetrics] = [],
+    callbacks: [TrainingLoopCallback<Self>] = [],
+    includeDefaultCallbacks: Bool = true
   ) {
     self.training = training
     self.validation = validation
     self.optimizer = optimizer
     self.lossFunction = LossFunction(lossFunction)
-    self.callbacks = callbacks
+    self.metrics = metrics
+
+    if includeDefaultCallbacks {
+      let statisticsRecorder = StatisticsRecorder(metrics: [.loss] + metrics)
+      let progressPrinter = ProgressPrinter()
+      self.statisticsRecorder = statisticsRecorder
+      self.progressPrinter = progressPrinter
+      self.callbacks = [
+        statisticsRecorder.record,
+        progressPrinter.print,
+      ] + callbacks
+    } else {
+      self.callbacks = callbacks
+    }
   }
 }

 extension TrainingLoop {
   /// The default differentiable step.
   public mutating func differentiableStep(model: Model) throws {
-    guard let data = lastInput else { return }
-    guard let target = lastTarget else { return }
-    (lastLoss, lastGradient) = valueWithGradient(at: model) { (model: Model) -> Tensor<Float> in
+    guard let data = lastStepInput else { return }
+    guard let target = lastStepTarget else { return }
+    (lastStepLoss, lastStepGradient) = valueWithGradient(at: model) {
+      (model: Model) -> Tensor<Float> in
       let predictions = model(data)
-      lastOutput = predictions
+      lastStepOutput = predictions
       return lossFunction.f(predictions, target)
     }
   }

   /// The step used for inference.
   public mutating func inferenceStep(model: Model) throws {
-    guard let data = lastInput else { return }
-    lastOutput = model(data)
-    guard let target = lastTarget else { return }
+    guard let data = lastStepInput else { return }
+    lastStepOutput = model(data)
+    guard let target = lastStepTarget else { return }
     try handleEvent(.inferencePredictionEnd)
-    lastLoss = lossFunction.f(lastOutput!, target)
+    lastStepLoss = lossFunction.f(lastStepOutput!, target)
   }

   /// The step used for training.
@@ -235,7 +328,7 @@ extension TrainingLoop {
   ) throws {
     try differentiableStep(model, &self)
     try handleEvent(.updateStart)
-    optimizer.update(&model, along: lastGradient!)
+    optimizer.update(&model, along: lastStepGradient!)
   }
 }
@@ -245,12 +338,16 @@ extension TrainingLoop {
 public enum TrainingLoopAction: Error {
   /// Abort actions in the current training/inference step and goes to the next batch.
   case cancelBatch
+
   /// Abort actions in the current training phase and goes to the validation phase.
   case cancelTraining
+
   /// Abort actions in the current validation phase and goes to the next epoch.
   case cancelValidation
+
   /// Abort actions in the current epoch and goes to the next epoch.
   case cancelEpoch
+
   /// Abort actions in the current fit and ends fitting.
   case cancelFit
 }
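A callback can throw one of these actions to short-circuit the corresponding part of the loop. A hedged sketch of early stopping that ends fitting after a fixed number of epochs (the three-epoch threshold and the callback shape are illustrative assumptions):

// Sketch only: ends fitting by throwing cancelFit from a callback at epochEnd.
func stopAfterThreeEpochs<L: TrainingLoopProtocol>(
  _ loop: inout L, _ event: TrainingLoopEvent
) throws {
  if case .epochEnd = event, let epochIndex = loop.epochIndex, epochIndex >= 2 {
    throw TrainingLoopAction.cancelFit
  }
}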
@@ -272,7 +369,7 @@ extension TrainingLoop {
     batchCount = batches.count
     for (i, batch) in batches.enumerated() {
       batchIndex = i
-      (lastInput, lastTarget) = (batch.data, batch.label)
+      (lastStepInput, lastStepTarget) = (batch.data, batch.label)
       do {
         try handleEvent(.batchStart)
         try step(&self)
@@ -294,7 +391,9 @@ extension TrainingLoop {
   public mutating func fit(
     _ model: inout Model, epochs: Int, callbacks: [TrainingLoopCallback<Self>] = [],
     on device: Device = Device.default,
-    differentiableStep: (Model, inout Self) throws -> Void = { try $1.differentiableStep(model: $0) }
+    differentiableStep: (Model, inout Self) throws -> Void = {
+      try $1.differentiableStep(model: $0)
+    }
   ) throws {
     let callbacksCount = self.callbacks.count
     self.callbacks += callbacks
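End to end, the new initializer is meant to be called with an optional list of metrics. A hedged usage sketch: this diff only establishes the `metrics:` and `includeDefaultCallbacks:` parameters and the default `StatisticsRecorder`/`ProgressPrinter` callbacks; the LeNet model, the MNIST dataset, and the `.accuracy` metric case are assumptions taken from the wider swift-models repository rather than from this diff.

// Sketch only; modules, model, dataset, and the .accuracy case are assumptions.
import Datasets
import ImageClassificationModels
import TensorFlow
import TrainingLoop

let dataset = MNIST(batchSize: 128)
var model = LeNet()
let optimizer = SGD(for: model, learningRate: 0.1)

var trainingLoop = TrainingLoop(
  training: dataset.training,
  validation: dataset.validation,
  optimizer: optimizer,
  lossFunction: softmaxCrossEntropy,
  metrics: [.accuracy])

try trainingLoop.fit(&model, epochs: 5)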