Skip to content
This repository was archived by the owner on Feb 13, 2025. It is now read-only.

Commit adda08e

Browse files
authoredSep 24, 2020
TrainingLoop: refactor progress printer and add CSVLogger (#668)
1 parent fafa4f0 commit adda08e

File tree

15 files changed

+556
-274
lines changed

15 files changed

+556
-274
lines changed
 

‎Examples/LeNet-MNIST/main.swift

+19-10
Original file line numberDiff line numberDiff line change
@@ -31,24 +31,33 @@ let dataset = MNIST(batchSize: batchSize, on: device)
3131

3232
// The LeNet-5 model, equivalent to `LeNet` in `ImageClassificationModels`.
3333
var classifier = Sequential {
34-
Conv2D<Float>(filterShape: (5, 5, 1, 6), padding: .same, activation: relu)
35-
AvgPool2D<Float>(poolSize: (2, 2), strides: (2, 2))
36-
Conv2D<Float>(filterShape: (5, 5, 6, 16), activation: relu)
37-
AvgPool2D<Float>(poolSize: (2, 2), strides: (2, 2))
38-
Flatten<Float>()
39-
Dense<Float>(inputSize: 400, outputSize: 120, activation: relu)
40-
Dense<Float>(inputSize: 120, outputSize: 84, activation: relu)
41-
Dense<Float>(inputSize: 84, outputSize: 10)
34+
Conv2D<Float>(filterShape: (5, 5, 1, 6), padding: .same, activation: relu)
35+
AvgPool2D<Float>(poolSize: (2, 2), strides: (2, 2))
36+
Conv2D<Float>(filterShape: (5, 5, 6, 16), activation: relu)
37+
AvgPool2D<Float>(poolSize: (2, 2), strides: (2, 2))
38+
Flatten<Float>()
39+
Dense<Float>(inputSize: 400, outputSize: 120, activation: relu)
40+
Dense<Float>(inputSize: 120, outputSize: 84, activation: relu)
41+
Dense<Float>(inputSize: 84, outputSize: 10)
4242
}
4343

4444
var optimizer = SGD(for: classifier, learningRate: 0.1)
4545

46-
let trainingProgress = TrainingProgress()
4746
var trainingLoop = TrainingLoop(
4847
training: dataset.training,
4948
validation: dataset.validation,
5049
optimizer: optimizer,
5150
lossFunction: softmaxCrossEntropy,
52-
callbacks: [trainingProgress.update])
51+
metrics: [.accuracy],
52+
callbacks: [try! CSVLogger().log])
53+
54+
// Compute statistics only when last batch ends.
55+
trainingLoop.statisticsRecorder!.shouldCompute = {
56+
(
57+
_ batchIndex: Int, _ batchCount: Int, _ epochIndex: Int, _ epochCount: Int,
58+
_ event: TrainingLoopEvent
59+
) -> Bool in
60+
return event == .batchEnd && batchIndex + 1 == batchCount
61+
}
5362

5463
try! trainingLoop.fit(&classifier, epochs: epochCount, on: device)

‎Examples/MobileNetV1-Imagenette/main.swift

+1-2
Original file line numberDiff line numberDiff line change
@@ -29,12 +29,11 @@ let dataset = Imagenette(batchSize: 64, inputSize: .resized320, outputSize: 224,
2929
var model = MobileNetV1(classCount: 10)
3030
let optimizer = SGD(for: model, learningRate: 0.02, momentum: 0.9)
3131

32-
let trainingProgress = TrainingProgress()
3332
var trainingLoop = TrainingLoop(
3433
training: dataset.training,
3534
validation: dataset.validation,
3635
optimizer: optimizer,
3736
lossFunction: softmaxCrossEntropy,
38-
callbacks: [trainingProgress.update])
37+
metrics: [.accuracy])
3938

4039
try! trainingLoop.fit(&model, epochs: 10, on: device)

‎Examples/MobileNetV2-Imagenette/main.swift

+1-2
Original file line numberDiff line numberDiff line change
@@ -29,12 +29,11 @@ let dataset = Imagenette(batchSize: 64, inputSize: .resized320, outputSize: 224,
2929
var model = MobileNetV2(classCount: 10)
3030
let optimizer = SGD(for: model, learningRate: 0.002, momentum: 0.9)
3131

32-
let trainingProgress = TrainingProgress()
3332
var trainingLoop = TrainingLoop(
3433
training: dataset.training,
3534
validation: dataset.validation,
3635
optimizer: optimizer,
3736
lossFunction: softmaxCrossEntropy,
38-
callbacks: [trainingProgress.update])
37+
metrics: [.accuracy])
3938

4039
try! trainingLoop.fit(&model, epochs: 10, on: device)

‎Examples/ResNet-CIFAR10/main.swift

+1-2
Original file line numberDiff line numberDiff line change
@@ -29,12 +29,11 @@ let dataset = CIFAR10(batchSize: 10, on: device)
2929
var model = ResNet(classCount: 10, depth: .resNet56, downsamplingInFirstStage: false)
3030
var optimizer = SGD(for: model, learningRate: 0.001)
3131

32-
let trainingProgress = TrainingProgress()
3332
var trainingLoop = TrainingLoop(
3433
training: dataset.training,
3534
validation: dataset.validation,
3635
optimizer: optimizer,
3736
lossFunction: softmaxCrossEntropy,
38-
callbacks: [trainingProgress.update])
37+
metrics: [.accuracy])
3938

4039
try! trainingLoop.fit(&model, epochs: 10, on: device)

‎Examples/VGG-Imagewoof/main.swift

+2-2
Original file line numberDiff line numberDiff line change
@@ -39,12 +39,12 @@ public func scheduleLearningRate<L: TrainingLoopProtocol>(
3939
}
4040
}
4141

42-
let trainingProgress = TrainingProgress()
4342
var trainingLoop = TrainingLoop(
4443
training: dataset.training,
4544
validation: dataset.validation,
4645
optimizer: optimizer,
4746
lossFunction: softmaxCrossEntropy,
48-
callbacks: [trainingProgress.update, scheduleLearningRate])
47+
metrics: [.accuracy],
48+
callbacks: [scheduleLearningRate])
4949

5050
try! trainingLoop.fit(&model, epochs: 90, on: device)

‎Support/FileSystem.swift

+1
Original file line numberDiff line numberDiff line change
@@ -39,4 +39,5 @@ public protocol File {
3939
func read(position: Int, count: Int) throws -> Data
4040
func write(_ value: Data) throws
4141
func write(_ value: Data, position: Int) throws
42+
func append(_ value: Data) throws
4243
}

‎Support/FoundationFileSystem.swift

+10
Original file line numberDiff line numberDiff line change
@@ -58,4 +58,14 @@ public struct FoundationFile: File {
5858
// TODO: Incorporate file offset.
5959
try value.write(to: location)
6060
}
61+
62+
/// Append data to the file.
63+
///
64+
/// Parameter value: data to be appended at the end.
65+
public func append(_ value: Data) throws {
66+
let fileHandler = try FileHandle(forUpdating: location)
67+
try fileHandler.seekToEnd()
68+
try fileHandler.write(contentsOf: value)
69+
try fileHandler.close()
70+
}
6171
}

‎TrainingLoop/CMakeLists.txt

+4-2
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,10 @@
11
add_library(TrainingLoop
22
LossFunctions.swift
3+
Metrics.swift
34
TrainingLoop.swift
4-
TrainingProgress.swift
5-
TrainingStatistics.swift)
5+
Callbacks/StatisticsRecorder.swift
6+
Callbacks/ProgressPrinter.swift
7+
Callbacks/CSVLogger.swift)
68
target_link_libraries(TrainingLoop PUBLIC
79
ModelSupport)
810
set_target_properties(TrainingLoop PROPERTIES
+74
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,74 @@
1+
import Foundation
2+
import ModelSupport
3+
4+
public enum CSVLoggerError: Error {
5+
case InvalidPath
6+
}
7+
8+
/// A handler for logging training and validation statistics to a CSV file.
9+
public class CSVLogger {
10+
/// The path of the file that statistics are logged to.
11+
public var path: String
12+
13+
// True iff the header of the CSV file has been written.
14+
fileprivate var headerWritten: Bool
15+
16+
/// Creates an instance that logs to a file with the given path.
17+
///
18+
/// Throws: File system errors.
19+
public init(path: String = "run/log.csv") throws {
20+
self.path = path
21+
22+
// Validate the path.
23+
let url = URL(fileURLWithPath: path)
24+
if url.pathExtension != "csv" {
25+
throw CSVLoggerError.InvalidPath
26+
}
27+
// Create the containing directory if it is missing.
28+
try FoundationFileSystem().createDirectoryIfMissing(at: url.deletingLastPathComponent().path)
29+
// Initialize the file with empty string.
30+
try FoundationFile(path: path).write(Data())
31+
32+
self.headerWritten = false
33+
}
34+
35+
/// Logs the statistics for the 'loop' when 'batchEnd' event happens;
36+
/// ignoring other events.
37+
///
38+
/// Throws: File system errors.
39+
public func log<L: TrainingLoopProtocol>(_ loop: inout L, event: TrainingLoopEvent) throws {
40+
switch event {
41+
case .batchEnd:
42+
guard let epochIndex = loop.epochIndex, let epochCount = loop.epochCount,
43+
let batchIndex = loop.batchIndex, let batchCount = loop.batchCount,
44+
let stats = loop.lastStatsLog
45+
else {
46+
// No-Op if trainingLoop doesn't set the required values for stats logging.
47+
return
48+
}
49+
50+
if !headerWritten {
51+
try writeHeader(stats: stats)
52+
headerWritten = true
53+
}
54+
55+
try writeDataRow(
56+
epoch: "\(epochIndex + 1)/\(epochCount)",
57+
batch: "\(batchIndex + 1)/\(batchCount)",
58+
stats: stats)
59+
default:
60+
return
61+
}
62+
}
63+
64+
func writeHeader(stats: [(name: String, value: Float)]) throws {
65+
let header = (["epoch", "batch"] + stats.lazy.map { $0.name }).joined(separator: ", ") + "\n"
66+
try FoundationFile(path: path).append(header.data(using: .utf8)!)
67+
}
68+
69+
func writeDataRow(epoch: String, batch: String, stats: [(name: String, value: Float)]) throws {
70+
let dataRow = ([epoch, batch] + stats.lazy.map { String($0.value) }).joined(separator: ", ")
71+
+ "\n"
72+
try FoundationFile(path: path).append(dataRow.data(using: .utf8)!)
73+
}
74+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,85 @@
1+
// Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
import Foundation
16+
17+
let progressBarLength = 30
18+
19+
/// A handler for printing the training and validation progress.
20+
public class ProgressPrinter {
21+
/// Print training or validation progress in response of the 'event'.
22+
///
23+
/// An example of the progress would be:
24+
/// Epoch 1/12
25+
/// 468/468 [==============================] - loss: 0.4819 - accuracy: 0.8513
26+
/// 79/79 [==============================] - loss: 0.1520 - accuracy: 0.9521
27+
public func print<L: TrainingLoopProtocol>(_ loop: inout L, event: TrainingLoopEvent) throws {
28+
switch event {
29+
case .epochStart:
30+
guard let epochIndex = loop.epochIndex, let epochCount = loop.epochCount else {
31+
// No-Op if trainingLoop doesn't set the required values for progress printing.
32+
return
33+
}
34+
35+
Swift.print("Epoch \(epochIndex + 1)/\(epochCount)")
36+
case .batchEnd:
37+
guard let batchIndex = loop.batchIndex, let batchCount = loop.batchCount else {
38+
// No-Op if trainingLoop doesn't set the required values for progress printing.
39+
return
40+
}
41+
42+
let progressBar = formatProgressBar(
43+
progress: Float(batchIndex + 1) / Float(batchCount), length: progressBarLength)
44+
var stats: String = ""
45+
if let lastStatsLog = loop.lastStatsLog {
46+
stats = formatStats(lastStatsLog)
47+
}
48+
49+
Swift.print(
50+
"\r\(batchIndex + 1)/\(batchCount) \(progressBar)\(stats)",
51+
terminator: ""
52+
)
53+
fflush(stdout)
54+
case .epochEnd:
55+
Swift.print("")
56+
case .validationStart:
57+
Swift.print("")
58+
default:
59+
return
60+
}
61+
}
62+
63+
func formatProgressBar(progress: Float, length: Int) -> String {
64+
let progressSteps = Int(round(Float(length) * progress))
65+
let leading = String(repeating: "=", count: progressSteps)
66+
let separator: String
67+
let trailing: String
68+
if progressSteps < progressBarLength {
69+
separator = ">"
70+
trailing = String(repeating: ".", count: progressBarLength - progressSteps - 1)
71+
} else {
72+
separator = ""
73+
trailing = ""
74+
}
75+
return "[\(leading)\(separator)\(trailing)]"
76+
}
77+
78+
func formatStats(_ stats: [(String, Float)]) -> String {
79+
var result = ""
80+
for stat in stats {
81+
result += " - \(stat.0): \(String(format: "%.4f", stat.1))"
82+
}
83+
return result
84+
}
85+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,123 @@
1+
// Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
import TensorFlow
15+
16+
/// A handler for recording training and validation statistics.
17+
///
18+
/// Data produced by this handler can be used by ProgressPrinter, CVSLogger, etc.
19+
public class StatisticsRecorder {
20+
/// A Closure that returns if should call 'reset' on metricMeasurers.
21+
public var shouldReset:
22+
(
23+
_ batchIndex: Int, _ batchCount: Int, _ epochIndex: Int, _ epochCount: Int,
24+
_ event: TrainingLoopEvent
25+
) -> Bool
26+
27+
/// A Closure that returns if should call 'accumulate' on metricMeasurers.
28+
public var shouldAccumulate:
29+
(
30+
_ batchIndex: Int, _ batchCount: Int, _ epochIndex: Int, _ epochCount: Int,
31+
_ event: TrainingLoopEvent
32+
) -> Bool
33+
34+
/// A Closure that returns if should call 'compute' on metricMeasurers.
35+
public var shouldCompute:
36+
(
37+
_ batchIndex: Int, _ batchCount: Int, _ epochIndex: Int, _ epochCount: Int,
38+
_ event: TrainingLoopEvent
39+
) -> Bool
40+
41+
/// Instances of MetricsMeasurers.
42+
fileprivate var metricMeasurers: [MetricsMeasurer]
43+
44+
/// Create an instance that records 'metrics' during the training loop.
45+
public init(metrics: [TrainingMetrics]) {
46+
metricMeasurers = metrics.map { $0.measurer }
47+
48+
shouldReset = {
49+
(
50+
_ batchIndex: Int, _ batchCount: Int, _ epochIndex: Int, _ epochCount: Int,
51+
_ event: TrainingLoopEvent
52+
) -> Bool in
53+
return event == .trainingStart || event == .validationStart
54+
}
55+
56+
shouldAccumulate = {
57+
(
58+
_ batchIndex: Int, _ batchCount: Int, _ epochIndex: Int, _ epochCount: Int,
59+
_ event: TrainingLoopEvent
60+
) -> Bool in
61+
return event == .batchEnd
62+
}
63+
64+
shouldCompute = {
65+
(
66+
_ batchIndex: Int, _ batchCount: Int, _ epochIndex: Int, _ epochCount: Int,
67+
_ event: TrainingLoopEvent
68+
) -> Bool in
69+
return event == .batchEnd
70+
}
71+
}
72+
73+
/// Recording statistics in response of the 'event'.
74+
///
75+
/// It will record the statistics into 'lastStatsLog' in the loop where other
76+
/// callbacks can consume from.
77+
public func record<L: TrainingLoopProtocol>(_ loop: inout L, event: TrainingLoopEvent) throws {
78+
guard let batchIndex = loop.batchIndex,
79+
let batchCount = loop.batchCount,
80+
let epochIndex = loop.batchIndex,
81+
let epochCount = loop.epochCount,
82+
let loss = loop.lastStepLoss,
83+
let output = loop.lastStepOutput,
84+
let target = loop.lastStepTarget
85+
else {
86+
// No-Op if trainingLoop doesn't set the required values for stats recording.
87+
return
88+
}
89+
90+
if shouldReset(batchIndex, batchCount, epochIndex, epochCount, event) {
91+
resetMetricMeasurers()
92+
loop.lastStatsLog = nil
93+
}
94+
95+
if shouldAccumulate(batchIndex, batchCount, epochIndex, epochCount, event) {
96+
accumulateMetrics(loss: loss, predictions: output, labels: target)
97+
}
98+
99+
if shouldCompute(batchIndex, batchCount, epochIndex, epochCount, event) {
100+
loop.lastStatsLog = computeMetrics()
101+
}
102+
}
103+
104+
func resetMetricMeasurers() {
105+
for index in metricMeasurers.indices {
106+
metricMeasurers[index].reset()
107+
}
108+
}
109+
110+
func accumulateMetrics<Output, Target>(loss: Tensor<Float>, predictions: Output, labels: Target) {
111+
for index in metricMeasurers.indices {
112+
metricMeasurers[index].accumulate(loss: loss, predictions: predictions, labels: labels)
113+
}
114+
}
115+
116+
func computeMetrics() -> [(String, Float)] {
117+
var result: [(String, Float)] = []
118+
for measurer in metricMeasurers {
119+
result.append((name: measurer.name, value: measurer.measure()))
120+
}
121+
return result
122+
}
123+
}

‎TrainingLoop/Metrics.swift

+100
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,100 @@
1+
import TensorFlow
2+
3+
/// Metrics that can be registered into TrainingLoop.
4+
public enum TrainingMetrics {
5+
case loss
6+
case accuracy
7+
8+
public var name: String {
9+
switch self {
10+
case .loss:
11+
return "loss"
12+
case .accuracy:
13+
return "accuracy"
14+
}
15+
}
16+
17+
public var measurer: MetricsMeasurer {
18+
switch self {
19+
case .loss:
20+
return LossMeasurer(self.name)
21+
case .accuracy:
22+
return AccuracyMeasurer(self.name)
23+
}
24+
}
25+
}
26+
27+
/// A protocal defining functionalities of a metrics measurer.
28+
public protocol MetricsMeasurer {
29+
var name: String { get set }
30+
mutating func reset()
31+
mutating func accumulate<Output, Target>(
32+
loss: Tensor<Float>?, predictions: Output?, labels: Target?
33+
)
34+
func measure() -> Float
35+
}
36+
37+
/// A measurer for measuring loss.
38+
public struct LossMeasurer: MetricsMeasurer {
39+
public var name: String
40+
41+
private var totalBatchLoss: Float = 0
42+
private var batchCount: Int32 = 0
43+
44+
public init(_ name: String = "loss") {
45+
self.name = name
46+
}
47+
48+
public mutating func reset() {
49+
totalBatchLoss = 0
50+
batchCount = 0
51+
}
52+
53+
public mutating func accumulate<Output, Target>(
54+
loss: Tensor<Float>?, predictions: Output?, labels: Target?
55+
) {
56+
if let newBatchLoss = loss {
57+
totalBatchLoss += newBatchLoss.scalarized()
58+
batchCount += 1
59+
}
60+
}
61+
62+
public func measure() -> Float {
63+
return totalBatchLoss / Float(batchCount)
64+
}
65+
}
66+
67+
/// A measurer for measuring accuracy
68+
public struct AccuracyMeasurer: MetricsMeasurer {
69+
public var name: String
70+
71+
private var correctGuessCount: Int32 = 0
72+
private var totalGuessCount: Int32 = 0
73+
74+
public init(_ name: String = "accuracy") {
75+
self.name = name
76+
}
77+
78+
public mutating func reset() {
79+
correctGuessCount = 0
80+
totalGuessCount = 0
81+
}
82+
83+
public mutating func accumulate<Output, Target>(
84+
loss: Tensor<Float>?, predictions: Output?, labels: Target?
85+
) {
86+
guard let predictions = predictions as? Tensor<Float>, let labels = labels as? Tensor<Int32>
87+
else {
88+
fatalError(
89+
"For accuracy measurements, the model output must be Tensor<Float>, and the labels must be Tensor<Int>."
90+
)
91+
}
92+
correctGuessCount += Tensor<Int32>(predictions.argmax(squeezingAxis: 1) .== labels).sum()
93+
.scalarized()
94+
totalGuessCount += Int32(labels.shape[0])
95+
}
96+
97+
public func measure() -> Float {
98+
return Float(correctGuessCount) / Float(totalGuessCount)
99+
}
100+
}

‎TrainingLoop/TrainingLoop.swift

+135-36
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ import TensorFlow
1818
// Workaround https://bugs.swift.org/browse/TF-1122 that prevents us from registering a
1919
// loss function inside our TrainingLoop struct
2020
public final class LossFunctionWrapper<Output: Differentiable, Target> {
21-
public typealias F = @differentiable (Output, @noDerivative Target) -> Tensor<Float>
21+
public typealias F = @differentiable(Output, @noDerivative Target) -> Tensor<Float>
2222
public var f: F
2323
init(_ f: @escaping F) { self.f = f }
2424
}
@@ -34,64 +34,94 @@ public protocol TrainingLoopProtocol {
3434
where
3535
Training: Sequence, Training.Element: Collection,
3636
Training.Element.Element == LabeledData<Opt.Model.Input, Target>
37+
3738
/// The type of the collection of batches for the validation data.
3839
associatedtype Validation
3940
where
4041
Validation: Collection,
4142
Validation.Element == LabeledData<Opt.Model.Input, Target>
43+
4244
/// The type of the target of our model.
4345
associatedtype Target
46+
4447
/// The type of the optimizer used.
4548
associatedtype Opt: Optimizer where Opt.Model: Module
4649

4750
// Typealiases
4851
/// The type of the model.
4952
typealias Model = Opt.Model
53+
5054
/// The type of the input of the model.
5155
typealias Input = Opt.Model.Input
56+
5257
/// The type of the output of the model.
5358
typealias Output = Opt.Model.Output
59+
5460
/// The type of a batch.
5561
typealias Batch = LabeledData<Input, Target>
62+
5663
// In a wrapper for now because of TF-1122.
5764
/// The type of the loss function.
5865
typealias LossFunction = LossFunctionWrapper<Output, Target>
5966

6067
// Data
6168
/// The training epochs.
6269
var training: Training { get }
70+
6371
/// The validation batches.
6472
var validation: Validation { get }
6573

6674
// Optimizer and loss function
6775
/// The optimizer.
6876
var optimizer: Opt { get set }
77+
6978
/// The loss function.
7079
var lossFunction: LossFunction { get set }
7180

81+
/// The metrics
82+
var metrics: [TrainingMetrics] { get set }
83+
7284
// Callbacks
7385
/// The callbacks used to customize the training loop.
7486
var callbacks: [TrainingLoopCallback<Self>] { get set }
7587

7688
// Temporary data
89+
90+
// MARK: - Step-level data
91+
7792
/// The last input fed to the model.
78-
var lastInput: Input? { get set }
93+
var lastStepInput: Input? { get set }
94+
7995
/// The last target.
80-
var lastTarget: Target? { get set }
96+
var lastStepTarget: Target? { get set }
97+
8198
/// The last predictions of the model.
82-
var lastOutput: Output? { get set }
99+
var lastStepOutput: Output? { get set }
100+
83101
/// The last gradients computed.
84-
var lastGradient: Model.TangentVector? { get set }
102+
var lastStepGradient: Model.TangentVector? { get set }
103+
85104
/// The last loss.
86-
var lastLoss: Tensor<Float>? { get set }
87-
/// The number of epochs we are currently fitting for.
88-
var epochCount: Int? { get set }
89-
/// The index of the current epoch.
90-
var epochIndex: Int? { get set }
105+
var lastStepLoss: Tensor<Float>? { get set }
106+
91107
/// The number of batches in the current collection of batches.
92108
var batchCount: Int? { get set }
109+
93110
/// The index of the current batch.
94111
var batchIndex: Int? { get set }
112+
113+
// MARK: - Epoch-level data
114+
115+
/// The number of epochs we are currently fitting for.
116+
var epochCount: Int? { get set }
117+
118+
/// The index of the current epoch.
119+
var epochIndex: Int? { get set }
120+
121+
// MARK: - Others
122+
123+
/// The log for last statistics
124+
var lastStatsLog: [(name: String, value: Float)]? { get set }
95125
}
96126

97127
/// The events that occur during a call to `fit` in the `TrainingLoop`
@@ -101,26 +131,37 @@ public protocol TrainingLoopProtocol {
101131
public enum TrainingLoopEvent {
102132
/// The start of a fit.
103133
case fitStart
134+
104135
/// The end of a fit.
105136
case fitEnd
137+
106138
/// The start of one epoch (training + validation).
107139
case epochStart
140+
108141
/// The start of one epoch (training + validation).
109142
case epochEnd
143+
110144
/// The start of a training phase.
111145
case trainingStart
146+
112147
/// The end of a training phase.
113148
case trainingEnd
149+
114150
/// The start of a validation phase.
115151
case validationStart
152+
116153
/// The end of a validation phase.
117154
case validationEnd
155+
118156
/// The start of a training or inference step on a batch.
119157
case batchStart
158+
120159
/// The end of a training or inference step on a batch.
121160
case batchEnd
161+
122162
/// At the start of the optimizer update, just after the differentiable step.
123163
case updateStart
164+
124165
/// Just after the model prediction at inference, before computing the loss.
125166
case inferencePredictionEnd
126167
}
@@ -146,87 +187,139 @@ where
146187
// Typealiases
147188
/// The type of the model.
148189
public typealias Model = Opt.Model
190+
149191
/// The type of the input of the model.
150192
public typealias Input = Opt.Model.Input
193+
151194
/// The type of the output of the model.
152195
public typealias Output = Opt.Model.Output
196+
153197
/// The type of a batch.
154198
public typealias Batch = LabeledData<Input, Target>
199+
155200
// In a wrapper for now because of TF-1122.
156201
/// The type of the loss function.
157202
public typealias LossFunction = LossFunctionWrapper<Output, Target>
158203

159204
// Data
160205
/// The training epochs.
161206
public let training: Training
207+
162208
/// The validation batches.
163209
public let validation: Validation
164210

165211
// Optimizer and loss function
166212
/// The optimizer.
167213
public var optimizer: Opt
214+
168215
/// The loss function
169216
public var lossFunction: LossFunction
170217

171-
// Callbacks
172-
/// The callbacks used to customize the training loop.
173-
public var callbacks: [TrainingLoopCallback<Self>] = []
218+
/// The metrics
219+
public var metrics: [TrainingMetrics]
220+
221+
/// Callbacks
222+
223+
// MARK: - The callbacks used to customize the training loop.
224+
225+
public var callbacks: [TrainingLoopCallback<Self>]
226+
227+
// MARK: - Default callback objects
228+
229+
public var statisticsRecorder: StatisticsRecorder? = nil
230+
231+
public var progressPrinter: ProgressPrinter? = nil
232+
233+
/// Temporary data
234+
235+
// MARK: - Step-level data
174236

175-
// Temporary data
176237
/// The last input fed to the model.
177-
public var lastInput: Input? = nil
238+
public var lastStepInput: Input? = nil
239+
178240
/// The last target.
179-
public var lastTarget: Target? = nil
241+
public var lastStepTarget: Target? = nil
242+
180243
/// The last predictions of the model.
181-
public var lastOutput: Output? = nil
244+
public var lastStepOutput: Output? = nil
245+
182246
/// The last gradients computed.
183-
public var lastGradient: Model.TangentVector? = nil
247+
public var lastStepGradient: Model.TangentVector? = nil
248+
184249
/// The last loss.
185-
public var lastLoss: Tensor<Float>? = nil
186-
/// The number of epochs we are currently fitting for.
187-
public var epochCount: Int? = nil
188-
/// The index of the current epoch.
189-
public var epochIndex: Int? = nil
250+
public var lastStepLoss: Tensor<Float>? = nil
251+
190252
/// The number of batches in the current collection of batches.
191253
public var batchCount: Int? = nil
254+
192255
/// The index of the current batch.
193256
public var batchIndex: Int? = nil
194257

258+
// MARK: - Epoch-level data
259+
260+
/// The number of epochs we are currently fitting for.
261+
public var epochCount: Int? = nil
262+
263+
/// The index of the current epoch.
264+
public var epochIndex: Int? = nil
265+
266+
// MARK: - Others
267+
268+
/// The log for last statistics
269+
public var lastStatsLog: [(name: String, value: Float)]? = nil
270+
195271
/// Creates an instance from `training` and `validation` data, a `model`, an `optimizer` and a
196272
/// `lossFunction`.
197273
///
198274
/// Parameter callbacks: Callbacks that the `TrainingLoop` will use in every call to fit.
199275
public init(
200276
training: Training, validation: Validation, optimizer: Opt,
201-
lossFunction: @escaping LossFunction.F, callbacks: [TrainingLoopCallback<Self>] = []
277+
lossFunction: @escaping LossFunction.F,
278+
metrics: [TrainingMetrics] = [],
279+
callbacks: [TrainingLoopCallback<Self>] = [],
280+
includeDefaultCallbacks: Bool = true
202281
) {
203282
self.training = training
204283
self.validation = validation
205284
self.optimizer = optimizer
206285
self.lossFunction = LossFunction(lossFunction)
207-
self.callbacks = callbacks
286+
self.metrics = metrics
287+
288+
if includeDefaultCallbacks {
289+
let statisticsRecorder = StatisticsRecorder(metrics: [.loss] + metrics)
290+
let progressPrinter = ProgressPrinter()
291+
self.statisticsRecorder = statisticsRecorder
292+
self.progressPrinter = progressPrinter
293+
self.callbacks = [
294+
statisticsRecorder.record,
295+
progressPrinter.print,
296+
] + callbacks
297+
} else {
298+
self.callbacks = callbacks
299+
}
208300
}
209301
}
210302

211303
extension TrainingLoop {
212304
/// The default differentiable step.
213305
public mutating func differentiableStep(model: Model) throws {
214-
guard let data = lastInput else { return }
215-
guard let target = lastTarget else { return }
216-
(lastLoss, lastGradient) = valueWithGradient(at: model) { (model: Model) -> Tensor<Float> in
306+
guard let data = lastStepInput else { return }
307+
guard let target = lastStepTarget else { return }
308+
(lastStepLoss, lastStepGradient) = valueWithGradient(at: model) {
309+
(model: Model) -> Tensor<Float> in
217310
let predictions = model(data)
218-
lastOutput = predictions
311+
lastStepOutput = predictions
219312
return lossFunction.f(predictions, target)
220313
}
221314
}
222315

223316
/// The step used for inference.
224317
public mutating func inferenceStep(model: Model) throws {
225-
guard let data = lastInput else { return }
226-
lastOutput = model(data)
227-
guard let target = lastTarget else { return }
318+
guard let data = lastStepInput else { return }
319+
lastStepOutput = model(data)
320+
guard let target = lastStepTarget else { return }
228321
try handleEvent(.inferencePredictionEnd)
229-
lastLoss = lossFunction.f(lastOutput!, target)
322+
lastStepLoss = lossFunction.f(lastStepOutput!, target)
230323
}
231324

232325
/// The step used for training.
@@ -235,7 +328,7 @@ extension TrainingLoop {
235328
) throws {
236329
try differentiableStep(model, &self)
237330
try handleEvent(.updateStart)
238-
optimizer.update(&model, along: lastGradient!)
331+
optimizer.update(&model, along: lastStepGradient!)
239332
}
240333
}
241334

@@ -245,12 +338,16 @@ extension TrainingLoop {
245338
public enum TrainingLoopAction: Error {
246339
/// Abort actions in the current training/inference step and goes to the next batch.
247340
case cancelBatch
341+
248342
/// Abort actions in the current training phase and goes to the validation phase.
249343
case cancelTraining
344+
250345
/// Abort actions in the current validation phase and goes to the next epoch.
251346
case cancelValidation
347+
252348
/// Abort actions in the current epoch and goes to the next epoch.
253349
case cancelEpoch
350+
254351
/// Abort actions in the current fit and ends fitting.
255352
case cancelFit
256353
}
@@ -272,7 +369,7 @@ extension TrainingLoop {
272369
batchCount = batches.count
273370
for (i, batch) in batches.enumerated() {
274371
batchIndex = i
275-
(lastInput, lastTarget) = (batch.data, batch.label)
372+
(lastStepInput, lastStepTarget) = (batch.data, batch.label)
276373
do {
277374
try handleEvent(.batchStart)
278375
try step(&self)
@@ -294,7 +391,9 @@ extension TrainingLoop {
294391
public mutating func fit(
295392
_ model: inout Model, epochs: Int, callbacks: [TrainingLoopCallback<Self>] = [],
296393
on device: Device = Device.default,
297-
differentiableStep: (Model, inout Self) throws -> Void = { try $1.differentiableStep(model: $0) }
394+
differentiableStep: (Model, inout Self) throws -> Void = {
395+
try $1.differentiableStep(model: $0)
396+
}
298397
) throws {
299398
let callbacksCount = self.callbacks.count
300399
self.callbacks += callbacks

‎TrainingLoop/TrainingProgress.swift

-110
This file was deleted.

‎TrainingLoop/TrainingStatistics.swift

-108
This file was deleted.

0 commit comments

Comments
 (0)
This repository has been archived.