Skip to content
This repository was archived by the owner on Feb 13, 2025. It is now read-only.

Commit de7f0e0

Browse files
authored
Address doc comments in TrainingLoop Callbacks (#670)
1 parent 1ee08bd commit de7f0e0

File tree

6 files changed

+81
-34
lines changed

6 files changed

+81
-34
lines changed

Support/FoundationFileSystem.swift

+3-5
Original file line numberDiff line numberDiff line change
@@ -59,10 +59,8 @@ public struct FoundationFile: File {
5959
try value.write(to: location)
6060
}
6161

62-
/// Append data to the file.
63-
///
64-
/// Parameter value: data to be appended at the end.
65-
public func append(_ value: Data) throws {
62+
/// Appends the bytes in `suffix` to the file.
63+
public func append(_ suffix: Data) throws {
6664
let fileHandler = try FileHandle(forUpdating: location)
6765
#if os(macOS)
6866
// The following are needed in order to build on macOS 10.15 (Catalina). They can be removed
@@ -72,7 +70,7 @@ public struct FoundationFile: File {
7270
fileHandler.closeFile()
7371
#else
7472
try fileHandler.seekToEnd()
75-
try fileHandler.write(contentsOf: value)
73+
try fileHandler.write(contentsOf: suffix)
7674
try fileHandler.close()
7775
#endif
7876
}

TrainingLoop/Callbacks/CSVLogger.swift

+10-4
Original file line numberDiff line numberDiff line change
@@ -7,13 +7,13 @@ public enum CSVLoggerError: Error {
77

88
/// A handler for logging training and validation statistics to a CSV file.
99
public class CSVLogger {
10-
/// The path of the file that statistics are logged to.
10+
/// The path of the file to which statistics are logged.
1111
public var path: String
1212

1313
// True iff the header of the CSV file has been written.
1414
fileprivate var headerWritten: Bool
1515

16-
/// Creates an instance that logs to a file with the given path.
16+
/// Creates an instance that logs to a file with the given `path`.
1717
///
1818
/// Throws: File system errors.
1919
public init(path: String = "run/log.csv") throws {
@@ -32,7 +32,7 @@ public class CSVLogger {
3232
self.headerWritten = false
3333
}
3434

35-
/// Logs the statistics for the 'loop' when 'batchEnd' event happens;
35+
/// Logs the statistics for `loop` when a `batchEnd` event happens;
3636
/// ignoring other events.
3737
///
3838
/// Throws: File system errors.
@@ -43,7 +43,6 @@ public class CSVLogger {
4343
let batchIndex = loop.batchIndex, let batchCount = loop.batchCount,
4444
let stats = loop.lastStatsLog
4545
else {
46-
// No-Op if trainingLoop doesn't set the required values for stats logging.
4746
return
4847
}
4948

@@ -61,11 +60,18 @@ public class CSVLogger {
6160
}
6261
}
6362

63+
/// Writes a row of column names to the file.
64+
///
65+
/// Column names are "epoch", "batch" and the `name` of each element of `stats`,
66+
/// in that order.
6467
func writeHeader(stats: [(name: String, value: Float)]) throws {
6568
let header = (["epoch", "batch"] + stats.lazy.map { $0.name }).joined(separator: ", ") + "\n"
6669
try FoundationFile(path: path).append(header.data(using: .utf8)!)
6770
}
6871

72+
/// Appends a row of statistics log to file with the given value `epoch` for
73+
/// "epoch" column, `batch` for "batch" column, and `value`s of `stats` for corresponding
74+
/// columns indicated by `stats` `name`s.
6975
func writeDataRow(epoch: String, batch: String, stats: [(name: String, value: Float)]) throws {
7076
let dataRow = ([epoch, batch] + stats.lazy.map { String($0.value) }).joined(separator: ", ")
7177
+ "\n"

TrainingLoop/Callbacks/ProgressPrinter.swift

+20-11
Original file line numberDiff line numberDiff line change
@@ -14,28 +14,37 @@
1414

1515
import Foundation
1616

17-
let progressBarLength = 30
18-
1917
/// A handler for printing the training and validation progress.
18+
///
19+
/// The progress includes epoch and batch index the training is currently
20+
/// in, how many percentages of a full training/validation set has been done,
21+
/// and metric statistics.
2022
public class ProgressPrinter {
21-
/// Print training or validation progress in response of the 'event'.
23+
/// Length of the complete progress bar measured in count of `=` signs.
24+
public var progressBarLength: Int
25+
26+
/// Creates an instance that prints training progress with the complete
27+
/// progress bar to be `progressBarLength` characters long.
28+
public init(progressBarLength: Int = 30) {
29+
self.progressBarLength = progressBarLength
30+
}
31+
32+
/// Prints training or validation progress in response of the `event`.
2233
///
2334
/// An example of the progress would be:
2435
/// Epoch 1/12
2536
/// 468/468 [==============================] - loss: 0.4819 - accuracy: 0.8513
26-
/// 79/79 [==============================] - loss: 0.1520 - accuracy: 0.9521
27-
public func print<L: TrainingLoopProtocol>(_ loop: inout L, event: TrainingLoopEvent) throws {
37+
/// 58/79 [======================>.......] - loss: 0.1520 - accuracy: 0.9521
38+
public func printProgress<L: TrainingLoopProtocol>(_ loop: inout L, event: TrainingLoopEvent) throws {
2839
switch event {
2940
case .epochStart:
3041
guard let epochIndex = loop.epochIndex, let epochCount = loop.epochCount else {
31-
// No-Op if trainingLoop doesn't set the required values for progress printing.
3242
return
3343
}
3444

35-
Swift.print("Epoch \(epochIndex + 1)/\(epochCount)")
45+
print("Epoch \(epochIndex + 1)/\(epochCount)")
3646
case .batchEnd:
3747
guard let batchIndex = loop.batchIndex, let batchCount = loop.batchCount else {
38-
// No-Op if trainingLoop doesn't set the required values for progress printing.
3948
return
4049
}
4150

@@ -46,15 +55,15 @@ public class ProgressPrinter {
4655
stats = formatStats(lastStatsLog)
4756
}
4857

49-
Swift.print(
58+
print(
5059
"\r\(batchIndex + 1)/\(batchCount) \(progressBar)\(stats)",
5160
terminator: ""
5261
)
5362
fflush(stdout)
5463
case .epochEnd:
55-
Swift.print("")
64+
print("")
5665
case .validationStart:
57-
Swift.print("")
66+
print("")
5867
default:
5968
return
6069
}

TrainingLoop/Callbacks/StatisticsRecorder.swift

+15-8
Original file line numberDiff line numberDiff line change
@@ -17,31 +17,35 @@ import TensorFlow
1717
///
1818
/// Data produced by this handler can be used by ProgressPrinter, CVSLogger, etc.
1919
public class StatisticsRecorder {
20-
/// A Closure that returns if should call 'reset' on metricMeasurers.
20+
/// A function that returns `true` iff recorder should call `reset`
21+
/// on `metricMeasurers`.
2122
public var shouldReset:
2223
(
2324
_ batchIndex: Int, _ batchCount: Int, _ epochIndex: Int, _ epochCount: Int,
2425
_ event: TrainingLoopEvent
2526
) -> Bool
2627

27-
/// A Closure that returns if should call 'accumulate' on metricMeasurers.
28+
/// A function that returns `true` iff recorder should call `accumulate`
29+
/// on `metricMeasurers`.
2830
public var shouldAccumulate:
2931
(
3032
_ batchIndex: Int, _ batchCount: Int, _ epochIndex: Int, _ epochCount: Int,
3133
_ event: TrainingLoopEvent
3234
) -> Bool
3335

34-
/// A Closure that returns if should call 'compute' on metricMeasurers.
36+
/// A function that returns `true` iff recorder should call `measure`
37+
/// on `metricMeasurers`.
3538
public var shouldCompute:
3639
(
3740
_ batchIndex: Int, _ batchCount: Int, _ epochIndex: Int, _ epochCount: Int,
3841
_ event: TrainingLoopEvent
3942
) -> Bool
4043

41-
/// Instances of MetricsMeasurers.
44+
/// Instances of MetricsMeasurers that you can reset accumulate and compute
45+
/// statistics periodically.
4246
fileprivate var metricMeasurers: [MetricsMeasurer]
4347

44-
/// Create an instance that records 'metrics' during the training loop.
48+
/// Creates an instance that records `metrics` during the training loop.
4549
public init(metrics: [TrainingMetrics]) {
4650
metricMeasurers = metrics.map { $0.measurer }
4751

@@ -70,9 +74,9 @@ public class StatisticsRecorder {
7074
}
7175
}
7276

73-
/// Recording statistics in response of the 'event'.
77+
/// Records statistics in response of the `event`.
7478
///
75-
/// It will record the statistics into 'lastStatsLog' in the loop where other
79+
/// It will record the statistics into lastStatsLog property in the `loop` where other
7680
/// callbacks can consume from.
7781
public func record<L: TrainingLoopProtocol>(_ loop: inout L, event: TrainingLoopEvent) throws {
7882
guard let batchIndex = loop.batchIndex,
@@ -83,7 +87,6 @@ public class StatisticsRecorder {
8387
let output = loop.lastStepOutput,
8488
let target = loop.lastStepTarget
8589
else {
86-
// No-Op if trainingLoop doesn't set the required values for stats recording.
8790
return
8891
}
8992

@@ -101,18 +104,22 @@ public class StatisticsRecorder {
101104
}
102105
}
103106

107+
/// Resets each of the metricMeasurers.
104108
func resetMetricMeasurers() {
105109
for index in metricMeasurers.indices {
106110
metricMeasurers[index].reset()
107111
}
108112
}
109113

114+
/// Lets each of the metricMeasurers accumulate data from
115+
/// `loss`, `predictions`, `labels`.
110116
func accumulateMetrics<Output, Target>(loss: Tensor<Float>, predictions: Output, labels: Target) {
111117
for index in metricMeasurers.indices {
112118
metricMeasurers[index].accumulate(loss: loss, predictions: predictions, labels: labels)
113119
}
114120
}
115121

122+
/// Lets each of the metricMeasurers compute metrics on cumulated data.
116123
func computeMetrics() -> [(String, Float)] {
117124
var result: [(String, Float)] = []
118125
for measurer in metricMeasurers {

TrainingLoop/Metrics.swift

+28-2
Original file line numberDiff line numberDiff line change
@@ -24,31 +24,46 @@ public enum TrainingMetrics {
2424
}
2525
}
2626

27-
/// A protocal defining functionalities of a metrics measurer.
27+
/// An accumulator of statistics.
2828
public protocol MetricsMeasurer {
29+
/// Name of the metrics.
2930
var name: String { get set }
31+
32+
/// Clears accumulated data up and resets measurer to initial state.
3033
mutating func reset()
34+
35+
/// Accumulates data from `loss`, `predictions`, `labels`.
3136
mutating func accumulate<Output, Target>(
32-
loss: Tensor<Float>?, predictions: Output?, labels: Target?)
37+
loss: Tensor<Float>?, predictions: Output?, labels: Target?
38+
)
39+
40+
/// Computes metrics from cumulated data.
3341
func measure() -> Float
3442
}
3543

3644
/// A measurer for measuring loss.
3745
public struct LossMeasurer: MetricsMeasurer {
46+
/// Name of the LossMeasurer.
3847
public var name: String
3948

49+
/// Sum of losses cumulated from batches.
4050
private var totalBatchLoss: Float = 0
51+
52+
/// Count of batchs cumulated so far.
4153
private var batchCount: Int32 = 0
4254

55+
/// Creates an instance with the LossMeasurer named `name`.
4356
public init(_ name: String = "loss") {
4457
self.name = name
4558
}
4659

60+
/// Resets totalBatchLoss and batchCount to zero.
4761
public mutating func reset() {
4862
totalBatchLoss = 0
4963
batchCount = 0
5064
}
5165

66+
/// Adds `loss` to totalBatchLoss and increases batchCount by one.
5267
public mutating func accumulate<Output, Target>(
5368
loss: Tensor<Float>?, predictions: Output?, labels: Target?
5469
) {
@@ -58,27 +73,37 @@ public struct LossMeasurer: MetricsMeasurer {
5873
}
5974
}
6075

76+
/// Computes averaged loss.
6177
public func measure() -> Float {
6278
return totalBatchLoss / Float(batchCount)
6379
}
6480
}
6581

6682
/// A measurer for measuring accuracy
6783
public struct AccuracyMeasurer: MetricsMeasurer {
84+
/// Name of the AccuracyMeasurer.
6885
public var name: String
6986

87+
/// Count of correct guesses.
7088
private var correctGuessCount: Int32 = 0
89+
90+
/// Count of total guesses.
7191
private var totalGuessCount: Int32 = 0
7292

93+
/// Creates an instance with the AccuracyMeasurer named `name`.
7394
public init(_ name: String = "accuracy") {
7495
self.name = name
7596
}
7697

98+
/// Resets correctGuessCount and totalGuessCount to zero.
7799
public mutating func reset() {
78100
correctGuessCount = 0
79101
totalGuessCount = 0
80102
}
81103

104+
/// Computes correct guess count from `loss`, `predictions` and `labels`
105+
/// and adds it to correctGuessCount; Computes total guess count from
106+
/// `labels` shape and adds it to totalGuessCount.
82107
public mutating func accumulate<Output, Target>(
83108
loss: Tensor<Float>?, predictions: Output?, labels: Target?
84109
) {
@@ -93,6 +118,7 @@ public struct AccuracyMeasurer: MetricsMeasurer {
93118
totalGuessCount += Int32(labels.shape.reduce(1, *))
94119
}
95120

121+
/// Computes accuracy as percentage of correct guesses.
96122
public func measure() -> Float {
97123
return Float(correctGuessCount) / Float(totalGuessCount)
98124
}

TrainingLoop/TrainingLoop.swift

+5-4
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,7 @@ public protocol TrainingLoopProtocol {
7878
/// The loss function.
7979
var lossFunction: LossFunction { get set }
8080

81-
/// The metrics
81+
/// The metrics on which training is measured.
8282
var metrics: [TrainingMetrics] { get set }
8383

8484
// Callbacks
@@ -220,14 +220,15 @@ where
220220

221221
/// Callbacks
222222

223-
// MARK: - The callbacks used to customize the training loop.
224-
223+
/// The callbacks used to customize the training loop.
225224
public var callbacks: [TrainingLoopCallback<Self>]
226225

227226
// MARK: - Default callback objects
228227

228+
/// The callback that records the training statistics.
229229
public var statisticsRecorder: StatisticsRecorder? = nil
230230

231+
/// The callback that prints the training progress.
231232
public var progressPrinter: ProgressPrinter? = nil
232233

233234
/// Temporary data
@@ -292,7 +293,7 @@ where
292293
self.progressPrinter = progressPrinter
293294
self.callbacks = [
294295
statisticsRecorder.record,
295-
progressPrinter.print,
296+
progressPrinter.printProgress,
296297
] + callbacks
297298
} else {
298299
self.callbacks = callbacks

0 commit comments

Comments
 (0)