This repository was archived by the owner on Feb 13, 2025. It is now read-only.

Commit b6e5d01

BradLarson and marcrasi authored
Growing Neural Cellular Automata (#669)
* Initial scaffolding for the example application.
* Completed most of the steps, but doesn't yield non-zero gradients yet.
* Initial model implementation, adding PNG and alpha channel support to image loading and saving.
* Completed cell masking.
* Formatting.
* Minor experimentation.
* Repaired training and masking.
* Fixing the inference image generation, capturing representative samples only every 10th iteration during training.
* Attempting gradient normalization, per the paper.
* Intensifying Sobel kernels, testing different loss, premultiplying alpha on saving for RGB PNGs.
* Fixing initial learning rate stepdown, extracting color component slicing into separate function, fixing one X10 issue.
* Adding Colab demonstration notebook.
* Made example and notebook compatible with X10.
* Replaced sign()-based masking with real Boolean masking and an explicit broadcast that fixes X10-side corruption.
* Reorganizing options.
* Added a sample pool to start replicating Experiment 2.
* Completed Experiment 2 and 3, added images to Readme.
* Reformatting the Readme.
* Reworked Image interface, now that it handles more than JPEGs. Started adding AnimatedImage interface.
* Added functional animated GIF support.
* Fixed color scaling, clamped colors.
* Formatting.
* Adding premultiplied alpha for image loading, to reduce artifacts at the edge of the image.
* Updating PersonLab model for new interfaces.
* Add ability to disable bias in the last convolution stage, which may lead to more stable results.
* Updating notebook with latest implementation and animated GIF writing and display functionality.
* Cell rule formatting.
* Update Support/Image.swift (Co-authored-by: marcrasi <[email protected]>)
* Update Support/Image.swift (Co-authored-by: marcrasi <[email protected]>)
* Merge remote-tracking branch 'upstream/master' into CellularAutomata
* Updating PDE solver image saving to match new interface.
* Move the Node class into the only function where it is used.
* Reworking image saving to use composable extensions, starting to add documentation comments.
* Updating image saving across various models.
* Repaired two examples and added documentation for the remaining Image methods.
* Updating notebook for latest API.

Co-authored-by: marcrasi <[email protected]>
1 parent de7f0e0 commit b6e5d01

32 files changed: +1693 −152 lines

Autoencoder/Autoencoder1D/main.swift

+10-8
@@ -69,14 +69,16 @@ func saveImage<L: TrainingLoopProtocol>(_ loop: inout L, event: TrainingLoopEvent) throws {
 
   let outputFolder = "./output/"
   let selectedImageBatchLocalIndex = selectedImageGlobalIndex % batchSize
-  try saveImage(
-    (input as! Tensor<Float>)[selectedImageBatchLocalIndex..<selectedImageBatchLocalIndex+1],
-    shape: (imageWidth, imageHeight), format: .grayscale,
-    directory: outputFolder, name: "epoch-\(epochIndex + 1)-of-\(epochCount)-input")
-  try saveImage(
-    (output as! Tensor<Float>)[selectedImageBatchLocalIndex..<selectedImageBatchLocalIndex+1],
-    shape: (imageWidth, imageHeight), format: .grayscale,
-    directory: outputFolder, name: "epoch-\(epochIndex + 1)-of-\(epochCount)-output")
+  let inputExample =
+    (input as! Tensor<Float>)[selectedImageBatchLocalIndex..<selectedImageBatchLocalIndex+1]
+    .normalizedToGrayscale().reshaped(to: [imageWidth, imageHeight, 1])
+  try inputExample.saveImage(
+    directory: outputFolder, name: "epoch-\(epochIndex + 1)-of-\(epochCount)-input", format: .png)
+  let outputExample =
+    (output as! Tensor<Float>)[selectedImageBatchLocalIndex..<selectedImageBatchLocalIndex+1]
+    .normalizedToGrayscale().reshaped(to: [imageWidth, imageHeight, 1])
+  try outputExample.saveImage(
+    directory: outputFolder, name: "epoch-\(epochIndex + 1)-of-\(epochCount)-output", format: .png)
 }
 
 var trainingLoop = TrainingLoop(
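
The pattern in this hunk repeats throughout the commit: the free-standing saveImage(_:shape:format:directory:name:) function is replaced by chainable extensions on Tensor. A minimal sketch of the new call style, assuming the extensions ship in the repository's ModelSupport module (the module name and placeholder tensor are illustrative, not taken from this diff):

import TensorFlow
import ModelSupport  // assumed module providing normalizedToGrayscale() and saveImage(...)

// Placeholder batch standing in for a model's flattened grayscale output.
let batch = Tensor<Float>(randomUniform: [4, 28 * 28])

// Select one example, normalize it for grayscale output, give it an explicit
// [width, height, channels] shape, and write it out as a PNG.
let example = batch[0..<1].normalizedToGrayscale().reshaped(to: [28, 28, 1])
try example.saveImage(directory: "./output/", name: "example", format: .png)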

Autoencoder/Autoencoder2D/main.swift

+8-6
@@ -82,12 +82,14 @@ for (epoch, epochBatches) in dataset.training.prefix(epochCount).enumerated() {
   let testImages = model(sampleImages)
 
   do {
-    try saveImage(
-      sampleImages[0..<1], shape: (imageWidth, imageHeight), format: .grayscale,
-      directory: outputFolder, name: "epoch-\(epoch)-input")
-    try saveImage(
-      testImages[0..<1], shape: (imageWidth, imageHeight), format: .grayscale,
-      directory: outputFolder, name: "epoch-\(epoch)-output")
+    let inputExample = sampleImages[0..<1].normalizedToGrayscale()
+      .reshaped(to: [imageWidth, imageHeight, 1])
+    try inputExample.saveImage(
+      directory: outputFolder, name: "epoch-\(epoch)-input", format: .png)
+    let outputExample = testImages[0..<1].normalizedToGrayscale()
+      .reshaped(to: [imageWidth, imageHeight, 1])
+    try outputExample.saveImage(
+      directory: outputFolder, name: "epoch-\(epoch)-output", format: .png)
   } catch {
     print("Could not save image with error: \(error)")
   }

Autoencoder/VAE1D/main.swift

+8-6
@@ -111,12 +111,14 @@ for (epoch, epochBatches) in dataset.training.prefix(epochCount).enumerated() {
   let testLogVars = testOutputs[2]
   if epoch == 0 || (epoch + 1) % 10 == 0 {
     do {
-      try saveImage(
-        sampleImages[0..<1], shape: (imageWidth, imageHeight), format: .grayscale,
-        directory: outputFolder, name: "epoch-\(epoch)-input")
-      try saveImage(
-        testImages[0..<1], shape: (imageWidth, imageHeight), format: .grayscale,
-        directory: outputFolder, name: "epoch-\(epoch)-output")
+      let inputExample = sampleImages[0..<1].normalizedToGrayscale()
+        .reshaped(to: [imageWidth, imageHeight, 1])
+      try inputExample.saveImage(
+        directory: outputFolder, name: "epoch-\(epoch)-input", format: .png)
+      let outputExample = testImages[0..<1].normalizedToGrayscale()
+        .reshaped(to: [imageWidth, imageHeight, 1])
+      try outputExample.saveImage(
+        directory: outputFolder, name: "epoch-\(epoch)-output", format: .png)
     } catch {
       print("Could not save image with error: \(error)")
     }

CycleGAN/Data/Dataset.swift

+1-1
@@ -117,7 +117,7 @@ public struct CycleGANDataset<Entropy: RandomNumberGenerator> {
         options: [.skipsHiddenFiles])
         .filter { $0.pathExtension == "jpg" }
         .map {
-          Image(jpeg: $0).tensor / 127.5 - 1.0
+          Image(contentsOf: $0).tensor / 127.5 - 1.0
         }
   }
 }
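
Here, as in the dataset loaders below, the JPEG-specific Image(jpeg:) initializer gives way to Image(contentsOf:), reflecting the reworked Image interface that, per the commit message, "handles more than JPEGs". A small sketch of the loading pattern (the module name and file path are hypothetical):

import Foundation
import ModelSupport  // assumed module providing the Image type

let url = URL(fileURLWithPath: "facades/trainA/1.jpg")  // hypothetical path
let pixels = Image(contentsOf: url).tensor  // no longer fixed to JPEG input
let normalized = pixels / 127.5 - 1.0       // map [0, 255] into [-1, 1]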

CycleGAN/main.swift

+5-16
@@ -42,7 +42,6 @@ let _ones = Tensorf.one
 var step = 0
 
 var validationImage = dataset.trainSamples[0].domainA.expandingShape(at: 0)
-let validationImageURL = URL(string: FileManager.default.currentDirectoryPath)!.appendingPathComponent("sample.jpg")
 
 // MARK: Train
 
@@ -151,9 +150,7 @@
   Context.local.learningPhase = .inference
 
   let fakeSample = generatorG(validationImage) * 0.5 + 0.5
-
-  let fakeSampleImage = Image(tensor: fakeSample[0] * 255)
-  fakeSampleImage.save(to: validationImageURL, format: .rgb)
+  try fakeSample[0].scaled(by: 255).saveImage(directory: "output", name: "sample")
 
   print("GeneratorG loss: \(gLoss.scalars[0])")
   print("GeneratorF loss: \(fLoss.scalars[0])")
@@ -167,11 +164,6 @@
 
 // MARK: Final test
 
-let aResultsFolder = try createDirectoryIfNeeded(path: FileManager.default
-  .currentDirectoryPath + "/testA_results")
-let bResultsFolder = try createDirectoryIfNeeded(path: FileManager.default
-  .currentDirectoryPath + "/testB_results")
-
 var testStep = 0
 for testBatch in dataset.testing {
   let realX = testBatch.domainA
@@ -183,13 +175,10 @@
   let resultX = realX.concatenated(with: fakeY, alongAxis: 2) * 0.5 + 0.5
   let resultY = realY.concatenated(with: fakeX, alongAxis: 2) * 0.5 + 0.5
 
-  let imageX = Image(tensor: resultX[0] * 255)
-  let imageY = Image(tensor: resultY[0] * 255)
-
-  imageX.save(to: aResultsFolder.appendingPathComponent("\(String(testStep)).jpg", isDirectory: false),
-    format: .rgb)
-  imageY.save(to: bResultsFolder.appendingPathComponent("\(String(testStep)).jpg", isDirectory: false),
-    format: .rgb)
+  try resultX[0].scaled(by: 255)
+    .saveImage(directory: "output/testA_results", name: "\(String(testStep))")
+  try resultY[0].scaled(by: 255)
+    .saveImage(directory: "output/testB_results", name: "\(String(testStep))")
 
   testStep += 1
 }
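
The deleted createDirectoryIfNeeded calls have no direct replacement here, which suggests the new saveImage extension creates its output directory itself; treat that as an inference from this diff rather than a documented guarantee. The value mapping these hunks rely on, spelled out with names from the surrounding file:

// Generator outputs lie in [-1, 1]: shift to [0, 1], then scale to [0, 255].
let fakeSample = generatorG(validationImage) * 0.5 + 0.5
try fakeSample[0]
  .scaled(by: 255)
  .saveImage(directory: "output", name: "sample")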

DCGAN/main.swift

+4-8
@@ -151,10 +151,8 @@ for (epoch, epochBatches) in dataset.training.prefix(epochCount).enumerated() {
   Context.local.learningPhase = .inference
 
   // Render images.
-  let generatedImage = generator(noise)
-  try saveImage(
-    generatedImage, shape: (28, 28), format: .grayscale, directory: outputFolder,
-    name: "\(epoch)")
+  let generatedImage = generator(noise).normalizedToGrayscale().reshaped(to: [28, 28, 1])
+  try generatedImage.saveImage(directory: outputFolder, name: "\(epoch)", format: .png)
 
   // Print loss.
   let generatorLoss_ = generatorLoss(fakeLabels: generatedImage)
@@ -163,7 +161,5 @@ for (epoch, epochBatches) in dataset.training.prefix(epochCount).enumerated() {
 
 // Generate another image.
 let noise1 = Tensor<Float>(randomNormal: TensorShape(1, 100))
-let generatedImage = generator(noise1)
-try saveImage(
-  generatedImage, shape: (28, 28), format: .grayscale, directory: outputFolder,
-  name: "final")
+let generatedImage = generator(noise1).normalizedToGrayscale().reshaped(to: [28, 28, 1])
+try generatedImage.saveImage(directory: outputFolder, name: "final", format: .png)

Datasets/Imagenette/Imagenette.swift

+1-1
@@ -215,7 +215,7 @@ func makeImagenetteBatch<BatchSamples: Collection>(
   device: Device
 ) -> LabeledImage where BatchSamples.Element == (file: URL, label: Int32) {
   let images = samples.map(\.file).map { url -> Tensor<Float> in
-    Image(jpeg: url).resized(to: (outputSize, outputSize)).tensor
+    Image(contentsOf: url).resized(to: (outputSize, outputSize)).tensor
   }
 
   var imageTensor = Tensor(stacking: images)

Datasets/ObjectDetectionDataset.swift

+1-1
@@ -29,7 +29,7 @@ public struct LazyImage {
 
   public func tensor() -> Tensor<Float>? {
     if url != nil {
-      return Image(jpeg: url!).tensor
+      return Image(contentsOf: url!).tensor
     } else {
       return nil
     }

Datasets/OxfordIIITPets/OxfordIIITPets.swift

+2-2
@@ -173,7 +173,7 @@ fileprivate func makeBatch<BatchSamples: Collection>(
   samples: BatchSamples, imageSize: Int, device: Device
 ) -> SegmentedImage where BatchSamples.Element == (file: URL, annotation: URL) {
   let images = samples.map(\.file).map { url -> Tensor<Float> in
-    Image(jpeg: url).resized(to: (imageSize, imageSize)).tensor[0..., 0..., 0..<3]
+    Image(contentsOf: url).resized(to: (imageSize, imageSize)).tensor[0..., 0..., 0..<3]
  }
 
   var imageTensor = Tensor(stacking: images)
@@ -182,7 +182,7 @@ fileprivate func makeBatch<BatchSamples: Collection>(
 
   let annotations = samples.map(\.annotation).map { url -> Tensor<Int32> in
     Tensor<Int32>(
-      Image(jpeg: url).resized(to: (imageSize, imageSize)).tensor[0..., 0..., 0...0] - 1)
+      Image(contentsOf: url).resized(to: (imageSize, imageSize)).tensor[0..., 0..., 0...0] - 1)
   }
   var annotationTensor = Tensor(stacking: annotations)
   annotationTensor = Tensor(copying: annotationTensor, to: device)

Examples/Fractals/ImageUtilities.swift

+5-7
@@ -36,15 +36,16 @@ extension ImageSize: ExpressibleByArgument {
 }
 
 fileprivate func prismColor(_ value: Float, iterations: Int) -> [Float] {
-  guard value < Float(iterations) else { return [0.0, 0.0, 0.0] }
+  guard value < Float(iterations) else { return [0.0, 0.0, 0.0, 1.0] }
 
   let normalizedValue = value / Float(iterations)
 
   // Values drawn from Matplotlib: https://github.com/matplotlib/matplotlib/blob/master/lib/matplotlib/_cm.py
   let red = (0.75 * sinf((normalizedValue * 20.9 + 0.25) * Float.pi) + 0.67) * 255
   let green = (0.75 * sinf((normalizedValue * 20.9 - 0.25) * Float.pi) + 0.33) * 255
   let blue = (-1.1 * sinf((normalizedValue * 20.9) * Float.pi)) * 255
-  return [red, green, blue]
+  let alpha: Float = 255.0
+  return [red, green, blue, alpha]
 }
 
 func saveFractalImage(_ divergenceGrid: Tensor<Float>, iterations: Int, fileName: String) throws {
@@ -54,10 +55,7 @@ func saveFractalImage(_ divergenceGrid: Tensor<Float>, iterations: Int, fileName: String) throws {
     $0 += prismColor($1, iterations: iterations)
   }
   let colorImage = Tensor<Float>(
-    shape: [gridShape[0], gridShape[1], 3], scalars: colorValues, on: divergenceGrid.device)
+    shape: [gridShape[0], gridShape[1], 4], scalars: colorValues, on: divergenceGrid.device)
 
-  try saveImage(
-    colorImage, shape: (gridShape[0], gridShape[1]),
-    format: .rgb, directory: "./", name: fileName,
-    quality: 95)
+  try colorImage.saveImage(directory: "./", name: fileName, format: .png)
 }
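
The renderer now emits four channels per pixel, and the move from quality-95 JPEG to PNG keeps that alpha channel lossless. A toy sketch of the packing with made-up colors (a two-pixel image: opaque red, then half-transparent blue):

let scalars: [Float] = [
  255, 0, 0, 255,  // pixel (0, 0): opaque red
  0, 0, 255, 128,  // pixel (0, 1): half-transparent blue
]
let rgba = Tensor<Float>(shape: [1, 2, 4], scalars: scalars)
try rgba.saveImage(directory: "./", name: "twoPixels", format: .png)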
@@ -0,0 +1,127 @@
// Copyright 2020 The TensorFlow Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

import TensorFlow

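/// A single update step of the cellular automata: fixed Sobel and identity
/// kernels perceive each cell's 3x3 neighborhood, two 1x1 convolutions turn
/// that perception into a state delta, the delta is applied to a random
/// subset of cells governed by `fireRate`, and cells whose neighborhood has
/// no alpha above 0.1 are zeroed out.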
struct CellRule: Layer {
  @noDerivative var perceptionFilter: Tensor<Float>
  @noDerivative let fireRate: Float

  var conv1: Conv2D<Float>
  var conv2: Conv2D<Float>

  init(stateChannels: Int, fireRate: Float, useBias: Bool) {
    self.fireRate = fireRate

    let horizontalSobelKernel =
      Tensor<Float>(
        shape: [3, 3, 1, 1], scalars: [-1.0, 0.0, 1.0, -2.0, 0.0, 2.0, -1.0, 0.0, 1.0]) / 8.0
    let horizontalSobelFilter = horizontalSobelKernel.broadcasted(to: [3, 3, stateChannels, 1])
    let verticalSobelKernel =
      Tensor<Float>(
        shape: [3, 3, 1, 1], scalars: [-1.0, -2.0, -1.0, 0.0, 0.0, 0.0, 1.0, 2.0, 1.0]) / 8.0
    let verticalSobelFilter = verticalSobelKernel.broadcasted(to: [3, 3, stateChannels, 1])
    let identityKernel = Tensor<Float>(
      shape: [3, 3, 1, 1], scalars: [0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0])
    let identityFilter = identityKernel.broadcasted(to: [3, 3, stateChannels, 1])
    perceptionFilter = Tensor(
      concatenating: [horizontalSobelFilter, verticalSobelFilter, identityFilter], alongAxis: 3)

    conv1 = Conv2D<Float>(filterShape: (1, 1, stateChannels * 3, 128))
    conv2 = Conv2D<Float>(
      filterShape: (1, 1, 128, stateChannels), useBias: useBias, filterInitializer: zeros())
  }

  @differentiable
  func livingMask(_ input: Tensor<Float>) -> Tensor<Float> {
    let alphaChannel = input.slice(
      lowerBounds: [0, 0, 0, 3], sizes: [input.shape[0], input.shape[1], input.shape[2], 1])
    let localMaximum =
      maxPool2D(alphaChannel, filterSize: (1, 3, 3, 1), strides: (1, 1, 1, 1), padding: .same)
    return withoutDerivative(at: input) { _ in localMaximum.mask { $0 .> 0.1 } }
  }

  @differentiable
  func perceive(_ input: Tensor<Float>) -> Tensor<Float> {
    return depthwiseConv2D(
      input, filter: perceptionFilter, strides: (1, 1, 1, 1), padding: .same)
  }

  @differentiable
  func callAsFunction(_ input: Tensor<Float>) -> Tensor<Float> {
    let perception = perceive(input)
    let dx = conv2(relu(conv1(perception)))

    let updateDistribution = Tensor<Float>(
      randomUniform: [input.shape[0], input.shape[1], input.shape[2], 1], on: input.device)
    let updateMask = withoutDerivative(at: input) { _ in
      updateDistribution.mask { $0 .< fireRate }
    }

    let updatedState = input + (dx * updateMask)
    let combinedLivingMask = livingMask(input) * livingMask(updatedState)
    return updatedState * combinedLivingMask
  }
}

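/// Rescales each parameter gradient to roughly unit L2 norm: the gradient
/// normalization the commit message describes as applied "per the paper".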
func normalizeGradient(_ gradient: CellRule.TangentVector) -> CellRule.TangentVector {
  var outputGradient = gradient
  for kp in gradient.recursivelyAllWritableKeyPaths(to: Tensor<Float>.self) {
    let norm = sqrt(gradient[keyPath: kp].squared().sum())
    outputGradient[keyPath: kp] = gradient[keyPath: kp] / (norm + 1e-8)
  }
  return outputGradient
}

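/// `colorComponents` slices the leading four (RGBA) channels out of a state
/// tensor; `mask` converts a Boolean condition into a 0/1 tensor of the same
/// type, the explicit Boolean masking mentioned in the commit message.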
extension Tensor where Scalar: Numeric {
  @differentiable(where Scalar: TensorFlowFloatingPoint)
  var colorComponents: Tensor {
    precondition(self.rank == 3 || self.rank == 4)
    if self.rank == 3 {
      return self.slice(
        lowerBounds: [0, 0, 0], sizes: [self.shape[0], self.shape[1], 4])
    } else {
      return self.slice(
        lowerBounds: [0, 0, 0, 0], sizes: [self.shape[0], self.shape[1], self.shape[2], 4])
    }
  }

  func mask(condition: (Tensor) -> Tensor<Bool>) -> Tensor {
    let satisfied = condition(self)
    return Tensor(zerosLike: self)
      .replacing(with: Tensor(onesLike: self), where: satisfied)
  }
}

// Note: the following is an identity function that serves to cut the backward trace into
// smaller identical traces, to improve X10 performance.
@inlinable
@differentiable
func clipBackwardsTrace(_ input: Tensor<Float>) -> Tensor<Float> {
  return input
}

@inlinable
@derivative(of: clipBackwardsTrace)
func _vjpClipBackwardsTrace(
  _ input: Tensor<Float>
) -> (value: Tensor<Float>, pullback: (Tensor<Float>) -> Tensor<Float>) {
  return (
    input,
    {
      LazyTensorBarrier()
      return $0
    }
  )
}
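
Taken together with the commit message, a rough training-step sketch for the rule might look like the following. This is an editorial illustration rather than code from the commit: the grid size, channel count, step count, learning rate, and zero-valued target are placeholder assumptions, and the example's real driver lives in its own main.swift.

import TensorFlow

let stateChannels = 16  // assumed configuration values
var cellRule = CellRule(stateChannels: stateChannels, fireRate: 0.5, useBias: false)
var optimizer = Adam(for: cellRule, learningRate: 2e-3)

// Seed a 32x32 grid with one living center cell: RGB stays zero, while the
// alpha and hidden channels of the center cell are set to one via padding.
let center = Tensor<Float>(ones: [1, 1, 1, stateChannels - 3])
let seed = center.padded(forSizes: [(0, 0), (15, 16), (15, 16), (3, 0)])

let target = Tensor<Float>(zeros: [1, 32, 32, 4])  // placeholder RGBA target

// One training iteration: iterate the rule for a fixed number of steps, compare
// the visible channels against the target, and take a normalized gradient step.
let (loss, ruleGradient) = valueWithGradient(at: cellRule) { rule -> Tensor<Float> in
  var grid = seed
  for _ in 0..<64 {
    grid = rule(grid)
  }
  return meanSquaredError(predicted: grid.colorComponents, expected: target)
}
optimizer.update(&cellRule, along: normalizeGradient(ruleGradient))
print("loss: \(loss)")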
