Skip to content
This repository was archived by the owner on Feb 13, 2025. It is now read-only.

Commit 5e5d7b4

Browse files
authored
skip augmentation for imagenet validation set (#735)
1 parent 339752a commit 5e5d7b4

File tree

1 file changed

+11
-6
lines changed

1 file changed

+11
-6
lines changed

Datasets/Imagenette/ImageNet.swift

+11-6
Original file line numberDiff line numberDiff line change
@@ -98,7 +98,7 @@ public struct ImageNet<Entropy: RandomNumberGenerator> {
9898
return batches.lazy.map {
9999
makeImageNetBatch(
100100
samples: $0, outputSize: outputSize, mean: mean, standardDeviation: standardDeviation,
101-
device: device)
101+
device: device, applyAugmentation: true)
102102
}
103103
}
104104

@@ -107,7 +107,7 @@ public struct ImageNet<Entropy: RandomNumberGenerator> {
107107
validation = validationSamples.inBatches(of: batchSize).lazy.map {
108108
makeImageNetBatch(
109109
samples: $0, outputSize: outputSize, mean: mean, standardDeviation: standardDeviation,
110-
device: device)
110+
device: device, applyAugmentation: false)
111111
}
112112
} catch {
113113
fatalError("Could not load ImageNet dataset: \(error)")
@@ -205,7 +205,7 @@ func loadImageNetValidationDirectory(
205205
named: "val", in: localStorageDirectory, base: base, labelDict: labelDict)
206206
}
207207

208-
func applyImageNetDataAugmentation(image: Image) -> Tensor<Float> {
208+
func applyImageNetDataAugmentation(image: Image, applyAugmentation: Bool) -> Tensor<Float> {
209209
// Using the TensorFlow ImageNet demo from MLPerf as a reference:
210210
// https://github.com/mlcommons/training/blob/4f97c909f3aeaa3351da473d12eba461ace0be76/image_classification/tensorflow/official/resnet/imagenet_preprocessing.py#L94
211211
let imageData = image.tensor
@@ -231,6 +231,11 @@ func applyImageNetDataAugmentation(image: Image) -> Tensor<Float> {
231231
cropped = Tensor<Float>([offsetY, targetX, targetY, offsetX])
232232
}
233233

234+
// we skip the above for validation images, but keep the resize operation
235+
if !applyAugmentation {
236+
cropped = Tensor<Float>([0.0, 0.0, 1.0, 1.0])
237+
}
238+
234239
let imageBroadcast = imageData.reshaped(to: [1, height, width, channels])
235240
let bboxBroadcast = cropped.reshaped(to: [1, 4])
236241

@@ -241,17 +246,17 @@ func applyImageNetDataAugmentation(image: Image) -> Tensor<Float> {
241246

242247
func makeImageNetBatch<BatchSamples: Collection>(
243248
samples: BatchSamples, outputSize: Int, mean: Tensor<Float>?, standardDeviation: Tensor<Float>?,
244-
device: Device
249+
device: Device, applyAugmentation: Bool
245250
) -> LabeledImage where BatchSamples.Element == (file: URL, label: Int32) {
246251
let images = samples.map(\.file).map { url -> Tensor<Float> in
247252
if url.absoluteString.range(of: "n02105855_2933.JPEG") != nil {
248253
// This file is a PNG saved as a JPEG; manually strip its extra alpha channel first.
249254
let image = Image(contentsOf: url).tensor.slice(lowerBounds: [0, 0, 0], sizes: [189, 213, 3])
250255
let colorOnlyImage = Image(image)
251-
return applyImageNetDataAugmentation(image: colorOnlyImage)
256+
return applyImageNetDataAugmentation(image: colorOnlyImage, applyAugmentation: applyAugmentation)
252257
} else {
253258
let image = Image(contentsOf: url)
254-
return applyImageNetDataAugmentation(image: image)
259+
return applyImageNetDataAugmentation(image: image, applyAugmentation: applyAugmentation)
255260
}
256261
}
257262

0 commit comments

Comments
 (0)