@@ -98,7 +98,7 @@ public struct ImageNet<Entropy: RandomNumberGenerator> {
98
98
return batches. lazy. map {
99
99
makeImageNetBatch (
100
100
samples: $0, outputSize: outputSize, mean: mean, standardDeviation: standardDeviation,
101
- device: device)
101
+ device: device, applyAugmentation : true )
102
102
}
103
103
}
104
104
@@ -107,7 +107,7 @@ public struct ImageNet<Entropy: RandomNumberGenerator> {
107
107
validation = validationSamples. inBatches ( of: batchSize) . lazy. map {
108
108
makeImageNetBatch (
109
109
samples: $0, outputSize: outputSize, mean: mean, standardDeviation: standardDeviation,
110
- device: device)
110
+ device: device, applyAugmentation : false )
111
111
}
112
112
} catch {
113
113
fatalError ( " Could not load ImageNet dataset: \( error) " )
@@ -205,7 +205,7 @@ func loadImageNetValidationDirectory(
205
205
named: " val " , in: localStorageDirectory, base: base, labelDict: labelDict)
206
206
}
207
207
208
- func applyImageNetDataAugmentation( image: Image ) -> Tensor < Float > {
208
+ func applyImageNetDataAugmentation( image: Image , applyAugmentation : Bool ) -> Tensor < Float > {
209
209
// using the tensorflow imagenet demo from mlperf as reference:
210
210
// https://github.com/mlcommons/training/blob/4f97c909f3aeaa3351da473d12eba461ace0be76/image_classification/tensorflow/official/resnet/imagenet_preprocessing.py#L94
211
211
let imageData = image. tensor
@@ -231,6 +231,11 @@ func applyImageNetDataAugmentation(image: Image) -> Tensor<Float> {
231
231
cropped = Tensor < Float > ( [ offsetY, targetX, targetY, offsetX] )
232
232
}
233
233
234
+ // we skip the above for validation images, but keep the resize operation
235
+ if !applyAugmentation {
236
+ cropped = Tensor < Float > ( [ 0.0 , 0.0 , 1.0 , 1.0 ] )
237
+ }
238
+
234
239
let imageBroadcast = imageData. reshaped ( to: [ 1 , height, width, channels] )
235
240
let bboxBroadcast = cropped. reshaped ( to: [ 1 , 4 ] )
236
241
@@ -241,17 +246,17 @@ func applyImageNetDataAugmentation(image: Image) -> Tensor<Float> {
241
246
242
247
func makeImageNetBatch< BatchSamples: Collection > (
243
248
samples: BatchSamples , outputSize: Int , mean: Tensor < Float > ? , standardDeviation: Tensor < Float > ? ,
244
- device: Device
249
+ device: Device , applyAugmentation : Bool
245
250
) -> LabeledImage where BatchSamples. Element == ( file: URL , label: Int32 ) {
246
251
let images = samples. map ( \. file) . map { url -> Tensor < Float > in
247
252
if url. absoluteString. range ( of: " n02105855_2933.JPEG " ) != nil {
248
253
// this is a png saved as a jpeg, we manually strip an extra alpha channel to start
249
254
let image = Image ( contentsOf: url) . tensor. slice ( lowerBounds: [ 0 , 0 , 0 ] , sizes: [ 189 , 213 , 3 ] )
250
255
let colorOnlyImage = Image ( image)
251
- return applyImageNetDataAugmentation ( image: colorOnlyImage)
256
+ return applyImageNetDataAugmentation ( image: colorOnlyImage, applyAugmentation : applyAugmentation )
252
257
} else {
253
258
let image = Image ( contentsOf: url)
254
- return applyImageNetDataAugmentation ( image: image)
259
+ return applyImageNetDataAugmentation ( image: image, applyAugmentation : applyAugmentation )
255
260
}
256
261
}
257
262
0 commit comments