Skip to content
This repository was archived by the owner on Feb 13, 2025. It is now read-only.

Commit a6ad2bb

Browse files
authored
Restructuring checkpoint files, adding an initial filesystem abstraction (#618)
* Initial move of the checkpoint reader / writer to their own module. * Updating BERT for new checkpoint organization. * Update MiniGo for checkpoint migration. * Initial experiments with a filesystem abstraction. * Using a default filesystem for the index reader. * Minor formatting. * Remove unnecessary default. * Remove CMake references to old Checkpoint files. * Fixing more CMake imports. * Actually add the Checkpoints CMakelists. * Enable testing and proper dependency for CMake. * Adding the last few CMake files.
1 parent a275252 commit a6ad2bb

30 files changed

+182
-43
lines changed

CMakeLists.txt

+1
Original file line numberDiff line numberDiff line change
@@ -124,6 +124,7 @@ add_dependencies(ArgumentParser swift-argument-parser-install)
124124
add_subdirectory(Autoencoder)
125125
add_subdirectory(Support)
126126
add_subdirectory(Batcher)
127+
add_subdirectory(Checkpoints)
127128
add_subdirectory(Datasets)
128129
add_subdirectory(Models)
129130
add_subdirectory(MiniGo)

Checkpoints/CMakeLists.txt

+27
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
add_library(Checkpoints
2+
CheckpointIndexReader.swift
3+
CheckpointIndexWriter.swift
4+
CheckpointReader.swift
5+
CheckpointWriter.swift
6+
Protobufs/tensor_bundle.pb.swift
7+
Protobufs/tensor_shape.pb.swift
8+
Protobufs/tensor_slice.pb.swift
9+
Protobufs/types.pb.swift
10+
Protobufs/versions.pb.swift
11+
SnappyDecompression.swift)
12+
set_target_properties(Checkpoints PROPERTIES
13+
INTERFACE_INCLUDE_DIRECTORIES ${CMAKE_Swift_MODULE_DIRECTORY})
14+
target_compile_options(Checkpoints PRIVATE
15+
$<$<BOOL:${BUILD_TESTING}>:-enable-testing>)
16+
target_link_libraries(Checkpoints PUBLIC
17+
ModelSupport)
18+
19+
install(TARGETS Checkpoints
20+
ARCHIVE DESTINATION lib/swift/$<LOWER_CASE:${CMAKE_SYSTEM_NAME}>
21+
LIBRARY DESTINATION lib/swift/$<LOWER_CASE:${CMAKE_SYSTEM_NAME}>
22+
RUNTIME DESTINATION bin)
23+
get_swift_host_arch(swift_arch)
24+
install(FILES
25+
$<TARGET_PROPERTY:Checkpoints,Swift_MODULE_DIRECTORY>/Checkpoints.swiftdoc
26+
$<TARGET_PROPERTY:Checkpoints,Swift_MODULE_DIRECTORY>/Checkpoints.swiftmodule
27+
DESTINATION lib/swift$<$<NOT:$<BOOL:${BUILD_SHARED_LIBS}>>:_static>/$<LOWER_CASE:${CMAKE_SYSTEM_NAME}>/${swift_arch})

Support/Checkpoints/CheckpointIndexReader.swift Checkpoints/CheckpointIndexReader.swift

+4-2
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
// locations determined by this metadata.
3131

3232
import Foundation
33+
import ModelSupport
3334

3435
// The block footer size is constant, and is obtained from the following:
3536
// https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/lib/io/format.h
@@ -43,8 +44,9 @@ class CheckpointIndexReader {
4344

4445
var atEndOfFile: Bool { return index >= (binaryData.count - footerSize - 1) }
4546

46-
init(file: URL) throws {
47-
let fileData = try Data(contentsOf: file)
47+
init(file: URL, fileSystem: FileSystem = FoundationFileSystem()) throws {
48+
let indexFile = fileSystem.open(file.path)
49+
let fileData = try indexFile.read()
4850
if fileData[0] == 0 {
4951
binaryData = fileData
5052
} else {

Support/Checkpoints/CheckpointReader.swift Checkpoints/CheckpointReader.swift

+16-8
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
// shards to obtain their raw bytes.
2121

2222
import Foundation
23+
import ModelSupport
2324
import TensorFlow
2425

2526
/// A Swift-native TensorFlow v2 checkpoint reader that can download all checkpoint files from
@@ -29,6 +30,7 @@ open class CheckpointReader {
2930
let header: Tensorflow_BundleHeaderProto
3031
let metadata: [String: Tensorflow_BundleEntryProto]
3132
var shardCache: [URL: Data] = [:]
33+
let fileSystem: FileSystem
3234

3335
/// The local checkpoint location.
3436
public let localCheckpointLocation: URL
@@ -51,7 +53,11 @@ open class CheckpointReader {
5153
/// base of the checkpoint files, or a URL to an archive containing the checkpoint files.
5254
/// - modelName: A distinct name for the model, to ensure that checkpoints with the same base
5355
/// name but for different models don't collide when downloaded.
54-
public init(checkpointLocation: URL, modelName: String, additionalFiles: [String] = []) throws {
56+
public init(
57+
checkpointLocation: URL, modelName: String, additionalFiles: [String] = [],
58+
fileSystem: FileSystem = FoundationFileSystem()
59+
) throws {
60+
self.fileSystem = fileSystem
5561
let temporaryDirectory = FileManager.default.temporaryDirectory.appendingPathComponent(
5662
modelName, isDirectory: true)
5763

@@ -72,20 +78,23 @@ open class CheckpointReader {
7278
if finalCheckpointLocation.isFileURL {
7379
self.localCheckpointLocation = finalCheckpointLocation
7480
indexReader = try CheckpointIndexReader(
75-
file: finalCheckpointLocation.appendingPathExtension("index"))
81+
file: finalCheckpointLocation.appendingPathExtension("index"),
82+
fileSystem: fileSystem)
7683
self.header = try indexReader.readHeader()
7784
} else {
7885
let temporaryCheckpointBase = temporaryDirectory.appendingPathComponent(checkpointBase)
7986
self.localCheckpointLocation = temporaryCheckpointBase
8087
let localIndexFileLocation = temporaryCheckpointBase.appendingPathExtension("index")
8188
if FileManager.default.fileExists(atPath: localIndexFileLocation.path) {
82-
indexReader = try CheckpointIndexReader(file: localIndexFileLocation)
89+
indexReader = try CheckpointIndexReader(file: localIndexFileLocation,
90+
fileSystem: fileSystem)
8391
self.header = try indexReader.readHeader()
8492
} else {
8593
// The index file contains the number of shards, so obtain that first.
8694
try CheckpointReader.downloadIndexFile(
8795
from: finalCheckpointLocation, to: temporaryDirectory)
88-
indexReader = try CheckpointIndexReader(file: localIndexFileLocation)
96+
indexReader = try CheckpointIndexReader(file: localIndexFileLocation,
97+
fileSystem: fileSystem)
8998
self.header = try indexReader.readHeader()
9099

91100
try CheckpointReader.downloadCheckpointFiles(
@@ -269,10 +278,9 @@ open class CheckpointReader {
269278
} else {
270279
do {
271280
// It is far too slow to read the shards in each time a tensor is accessed, so we
272-
// read the entire shard into an in-memory cache on first access. A better approach
273-
// to mapping these files may be needed, because .alwaysMapped doesn't seem to help
274-
// as much as it should.
275-
let shardBytes = try Data(contentsOf: file, options: .alwaysMapped)
281+
// read the entire shard into an in-memory cache on first access.
282+
let shardFile = fileSystem.open(file.path)
283+
let shardBytes = try shardFile.read()
276284
shardCache[file] = shardBytes
277285
return shardBytes
278286
} catch {

Support/Checkpoints/CheckpointWriter.swift Checkpoints/CheckpointWriter.swift

+12-4
Original file line numberDiff line numberDiff line change
@@ -13,21 +13,27 @@
1313
// limitations under the License.
1414

1515
import Foundation
16+
import ModelSupport
1617
import TensorFlow
1718

1819
/// A Swift-native TensorFlow v2 checkpoint writer. This writer has no dependencies
1920
/// on the TensorFlow runtime or libraries.
2021
open class CheckpointWriter {
2122
// TODO: Extend handling to different tensor types.
2223
let tensors: [String: Tensor<Float>]
24+
let fileSystem: FileSystem
2325

2426
/// Initializes the checkpoint reader from a dictionary of tensors, keyed on their string names.
2527
///
2628
/// - Parameters:
2729
/// - tensors: A dictionary containing the tensors to be written, with the keys being the
2830
/// names of those tensors to write in the checkpoint.
29-
public init(tensors: [String: Tensor<Float>]) {
31+
/// - fileSystem: The filesystem used for writing the checkpoint.
32+
public init(
33+
tensors: [String: Tensor<Float>], fileSystem: FileSystem = FoundationFileSystem()
34+
) {
3035
self.tensors = tensors
36+
self.fileSystem = fileSystem
3137
}
3238

3339
/// Writes the checkpoint to disk, in a specified directory. A TensorFlow v2 checkpoint consists
@@ -40,11 +46,12 @@ open class CheckpointWriter {
4046
/// - name: The base name of the checkpoint, which is what will have the .index and
4147
/// .data-0000X-of-0000Y extensions appended to it for files in the checkpoint directory.
4248
public func write(to directory: URL, name: String) throws {
43-
try createDirectoryIfMissing(at: directory.path)
49+
try fileSystem.createDirectoryIfMissing(at: directory.path)
4450
let indexWriter = CheckpointIndexWriter(tensors: tensors)
4551
let indexHeader = indexWriter.serializedHeader()
4652
let headerLocation = directory.appendingPathComponent("\(name).index")
47-
try indexHeader.write(to: headerLocation)
53+
let headerFile = fileSystem.open(headerLocation.path)
54+
try headerFile.write(indexHeader)
4855

4956
// TODO: Handle splitting into multiple shards.
5057
try writeShard(
@@ -72,6 +79,7 @@ open class CheckpointWriter {
7279
}
7380
}
7481

75-
try outputBuffer.write(to: shardFile)
82+
let outputFile = fileSystem.open(shardFile.path)
83+
try outputFile.write(outputBuffer)
7684
}
7785
}

Support/Checkpoints/SnappyDecompression.swift Checkpoints/SnappyDecompression.swift

+3-3
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ import Foundation
2323
public enum SnappyDecompressionError: Error {
2424
case illegalLiteralLength(upperBits: UInt8)
2525
case impossibleTagType(tagType: UInt8)
26+
case uncompressedDataLengthMismatch(target: Int, actual: Int)
2627
}
2728

2829
// The following extension to Data provides methods that read variable-length byte sequences
@@ -150,9 +151,8 @@ public extension Data {
150151
}
151152
}
152153
if uncompressedData.count != uncompressedLength {
153-
// TODO: Determine if this should be elevated to a thrown error.
154-
printError(
155-
"Warning: uncompressed data length of \(uncompressedData.count) did not match desired length of \(uncompressedLength).")
154+
throw SnappyDecompressionError.uncompressedDataLengthMismatch(
155+
target: uncompressedLength, actual: uncompressedData.count)
156156
}
157157

158158
return uncompressedData

FastStyleTransfer/CMakeLists.txt

+1-1
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ set_target_properties(FastStyleTransfer PROPERTIES
1010
target_compile_options(FastStyleTransfer PRIVATE
1111
$<$<BOOL:${BUILD_TESTING}>:-enable-testing>)
1212
target_link_libraries(FastStyleTransfer PUBLIC
13-
ModelSupport)
13+
Checkpoints)
1414

1515
add_executable(FastStyleTransferDemo
1616
Demo/ColabDemo.ipynb

FastStyleTransfer/Demo/Helpers.swift

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import FastStyleTransfer
22
import Foundation
3-
import ModelSupport
3+
import Checkpoints
44
import TensorFlow
55

66
extension TransformerNet: ImportableLayer {}

FastStyleTransfer/Utility/ImportableLayer.swift

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import TensorFlow
2-
import ModelSupport
2+
import Checkpoints
33

44
public protocol ImportableLayer: KeyPathIterable {}
55

MiniGo/CMakeLists.txt

+1-1
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ set_target_properties(MiniGo PROPERTIES
2727
target_compile_options(MiniGo PRIVATE
2828
$<$<BOOL:${BUILD_TESTING}>:-enable-testing>)
2929
target_link_libraries(MiniGo PUBLIC
30-
ModelSupport)
30+
Checkpoints)
3131

3232
add_executable(MiniGoDemo
3333
main.swift)

MiniGo/Models/MiniGoCheckpointReader.swift

+1-1
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,8 @@
1212
// See the License for the specific language governing permissions and
1313
// limitations under the License.
1414

15+
import Checkpoints
1516
import TensorFlow
16-
import ModelSupport
1717

1818
public class MiniGoCheckpointReader: CheckpointReader {
1919
private var layerCounts: [String: Int] = [:]

Models/Text/BERT.swift

+3-2
Original file line numberDiff line numberDiff line change
@@ -12,10 +12,11 @@
1212
// See the License for the specific language governing permissions and
1313
// limitations under the License.
1414

15-
import Foundation
16-
import TensorFlow
15+
import Checkpoints
1716
import Datasets
17+
import Foundation
1818
import ModelSupport
19+
import TensorFlow
1920

2021
/// Represents a type that can contribute to the regularization term when training models.
2122
public protocol Regularizable: Differentiable {

Models/Text/BERT/BERTCheckpointReader.swift

+1
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
// See the License for the specific language governing permissions and
1313
// limitations under the License.
1414

15+
import Checkpoints
1516
import Datasets
1617
import Foundation
1718
import ModelSupport

Models/Text/GPT2/GPT2.swift

+1
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
// See the License for the specific language governing permissions and
1313
// limitations under the License.
1414

15+
import Checkpoints
1516
import Foundation
1617
import ModelSupport
1718
import TensorFlow

Models/Text/GPT2/PythonCheckpointReader.swift

+1
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
// See the License for the specific language governing permissions and
1313
// limitations under the License.
1414

15+
import Checkpoints
1516
import ModelSupport
1617
import TensorFlow
1718

Package.swift

+8-4
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ let package = Package(
1010
],
1111
products: [
1212
.library(name: "Batcher", targets: ["Batcher"]),
13+
.library(name: "Checkpoints", targets: ["Checkpoints"]),
1314
.library(name: "Datasets", targets: ["Datasets"]),
1415
.library(name: "ModelSupport", targets: ["ModelSupport"]),
1516
.library(name: "ImageClassificationModels", targets: ["ImageClassificationModels"]),
@@ -26,14 +27,17 @@ let package = Package(
2627
],
2728
targets: [
2829
.target(name: "Batcher", path: "Batcher"),
30+
.target(
31+
name: "Checkpoints", dependencies: ["SwiftProtobuf", "ModelSupport"],
32+
path: "Checkpoints"),
2933
.target(name: "Datasets", dependencies: ["ModelSupport"], path: "Datasets"),
3034
.target(name: "STBImage", path: "Support/STBImage"),
3135
.target(
3236
name: "ModelSupport", dependencies: ["SwiftProtobuf", "STBImage"], path: "Support",
3337
exclude: ["STBImage"]),
3438
.target(name: "ImageClassificationModels", path: "Models/ImageClassification"),
3539
.target(name: "VideoClassificationModels", path: "Models/Spatiotemporal"),
36-
.target(name: "TextModels", dependencies: ["Datasets"], path: "Models/Text"),
40+
.target(name: "TextModels", dependencies: ["Checkpoints", "Datasets"], path: "Models/Text"),
3741
.target(name: "RecommendationModels", path: "Models/Recommendation"),
3842
.target(name: "TrainingLoop", dependencies: ["ModelSupport"], path: "TrainingLoop"),
3943
.target(
@@ -76,7 +80,7 @@ let package = Package(
7680
dependencies: ["Datasets", "ImageClassificationModels", "TrainingLoop"],
7781
path: "Examples/MobileNetV2-Imagenette"),
7882
.target(
79-
name: "MiniGo", dependencies: ["ModelSupport"], path: "MiniGo", exclude: ["main.swift"]),
83+
name: "MiniGo", dependencies: ["Checkpoints"], path: "MiniGo", exclude: ["main.swift"]),
8084
.target(
8185
name: "MiniGoDemo", dependencies: ["MiniGo"], path: "MiniGo", sources: ["main.swift"]),
8286
.target(
@@ -100,7 +104,7 @@ let package = Package(
100104
.target(name: "GAN", dependencies: ["Datasets", "ModelSupport"], path: "GAN"),
101105
.target(name: "DCGAN", dependencies: ["Datasets", "ModelSupport"], path: "DCGAN"),
102106
.target(
103-
name: "FastStyleTransfer", dependencies: ["ModelSupport"], path: "FastStyleTransfer",
107+
name: "FastStyleTransfer", dependencies: ["Checkpoints"], path: "FastStyleTransfer",
104108
exclude: ["Demo"]),
105109
.target(
106110
name: "FastStyleTransferDemo", dependencies: ["FastStyleTransfer"],
@@ -113,7 +117,7 @@ let package = Package(
113117
"TextModels"
114118
],
115119
path: "Benchmarks"),
116-
.testTarget(name: "CheckpointTests", dependencies: ["ModelSupport"]),
120+
.testTarget(name: "CheckpointTests", dependencies: ["Checkpoints"]),
117121
.target(
118122
name: "BERT-CoLA", dependencies: ["TextModels", "Datasets"], path: "Examples/BERT-CoLA"),
119123
.testTarget(name: "SupportTests", dependencies: ["ModelSupport"]),

Support/CMakeLists.txt

+2-10
Original file line numberDiff line numberDiff line change
@@ -2,17 +2,9 @@ add_subdirectory(STBImage)
22

33
add_library(ModelSupport
44
BijectiveDictionary.swift
5-
Checkpoints/CheckpointIndexReader.swift
6-
Checkpoints/CheckpointIndexWriter.swift
7-
Checkpoints/CheckpointReader.swift
8-
Checkpoints/CheckpointWriter.swift
9-
Checkpoints/Protobufs/tensor_bundle.pb.swift
10-
Checkpoints/Protobufs/tensor_shape.pb.swift
11-
Checkpoints/Protobufs/tensor_slice.pb.swift
12-
Checkpoints/Protobufs/types.pb.swift
13-
Checkpoints/Protobufs/versions.pb.swift
14-
Checkpoints/SnappyDecompression.swift
155
FileManagement.swift
6+
FileSystem.swift
7+
FoundationFileSystem.swift
168
LabeledData.swift
179
Image.swift
1810
Stderr.swift

Support/FileSystem.swift

+36
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
// Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
import Foundation
16+
17+
public protocol FileSystem {
18+
/// Creates a directory at a path, if missing. If the directory exists, this does nothing.
19+
///
20+
/// - Parameters:
21+
/// - path: The path of the desired directory.
22+
func createDirectoryIfMissing(at path: String) throws
23+
24+
/// Opens a file at the specified location for reading or writing.
25+
///
26+
/// - Parameters:
27+
/// - path: The path of the file to be opened.
28+
func open(_ path: String) -> File
29+
}
30+
31+
public protocol File {
32+
func read() throws -> Data
33+
func read(position: Int, count: Int) throws -> Data
34+
func write(_ value: Data) throws
35+
func write(_ value: Data, position: Int) throws
36+
}

0 commit comments

Comments
 (0)