@@ -17,31 +17,35 @@ import TensorFlow
17
17
///
18
18
/// Data produced by this handler can be used by ProgressPrinter, CVSLogger, etc.
19
19
public class StatisticsRecorder {
20
- /// A Closure that returns if should call 'reset' on metricMeasurers.
20
+ /// A function that returns `true` iff recorder should call `reset`
21
+ /// on `metricMeasurers`.
21
22
public var shouldReset :
22
23
(
23
24
_ batchIndex: Int , _ batchCount: Int , _ epochIndex: Int , _ epochCount: Int ,
24
25
_ event: TrainingLoopEvent
25
26
) -> Bool
26
27
27
- /// A Closure that returns if should call 'accumulate' on metricMeasurers.
28
+ /// A function that returns `true` iff recorder should call `accumulate`
29
+ /// on `metricMeasurers`.
28
30
public var shouldAccumulate :
29
31
(
30
32
_ batchIndex: Int , _ batchCount: Int , _ epochIndex: Int , _ epochCount: Int ,
31
33
_ event: TrainingLoopEvent
32
34
) -> Bool
33
35
34
- /// A Closure that returns if should call 'compute' on metricMeasurers.
36
+ /// A function that returns `true` iff recorder should call `measure`
37
+ /// on `metricMeasurers`.
35
38
public var shouldCompute :
36
39
(
37
40
_ batchIndex: Int , _ batchCount: Int , _ epochIndex: Int , _ epochCount: Int ,
38
41
_ event: TrainingLoopEvent
39
42
) -> Bool
40
43
41
- /// Instances of MetricsMeasurers.
44
+ /// Instances of MetricsMeasurers that you can reset accumulate and compute
45
+ /// statistics periodically.
42
46
fileprivate var metricMeasurers : [ MetricsMeasurer ]
43
47
44
- /// Create an instance that records ' metrics' during the training loop.
48
+ /// Creates an instance that records ` metrics` during the training loop.
45
49
public init ( metrics: [ TrainingMetrics ] ) {
46
50
metricMeasurers = metrics. map { $0. measurer }
47
51
@@ -70,9 +74,9 @@ public class StatisticsRecorder {
70
74
}
71
75
}
72
76
73
- /// Recording statistics in response of the ' event' .
77
+ /// Records statistics in response of the ` event` .
74
78
///
75
- /// It will record the statistics into ' lastStatsLog' in the loop where other
79
+ /// It will record the statistics into lastStatsLog property in the ` loop` where other
76
80
/// callbacks can consume from.
77
81
public func record< L: TrainingLoopProtocol > ( _ loop: inout L , event: TrainingLoopEvent ) throws {
78
82
guard let batchIndex = loop. batchIndex,
@@ -83,7 +87,6 @@ public class StatisticsRecorder {
83
87
let output = loop. lastStepOutput,
84
88
let target = loop. lastStepTarget
85
89
else {
86
- // No-Op if trainingLoop doesn't set the required values for stats recording.
87
90
return
88
91
}
89
92
@@ -101,18 +104,22 @@ public class StatisticsRecorder {
101
104
}
102
105
}
103
106
107
+ /// Resets each of the metricMeasurers.
104
108
func resetMetricMeasurers( ) {
105
109
for index in metricMeasurers. indices {
106
110
metricMeasurers [ index] . reset ( )
107
111
}
108
112
}
109
113
114
+ /// Lets each of the metricMeasurers accumulate data from
115
+ /// `loss`, `predictions`, `labels`.
110
116
func accumulateMetrics< Output, Target> ( loss: Tensor < Float > , predictions: Output , labels: Target ) {
111
117
for index in metricMeasurers. indices {
112
118
metricMeasurers [ index] . accumulate ( loss: loss, predictions: predictions, labels: labels)
113
119
}
114
120
}
115
121
122
+ /// Lets each of the metricMeasurers compute metrics on cumulated data.
116
123
func computeMetrics( ) -> [ ( String , Float ) ] {
117
124
var result : [ ( String , Float ) ] = [ ]
118
125
for measurer in metricMeasurers {
0 commit comments