Skip to content
This repository was archived by the owner on Feb 13, 2025. It is now read-only.

Commit 339752a

Browse files
authored
A higher-level interface to model checkpointing (#631)
* Creating generalized tensor extraction for models, reworking GPT-2 to use this. * Adding unit tests for model checkpoint writing. * Added checkpoint reading to Checkpointable. * Minor access level adjustment. * Adding checkpointing overview documentation.
1 parent eda41af commit 339752a

10 files changed

+525
-127
lines changed

Checkpoints/CMakeLists.txt

+1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
add_library(Checkpoints
2+
Checkpointable.swift
23
CheckpointIndexReader.swift
34
CheckpointIndexWriter.swift
45
CheckpointReader.swift

Checkpoints/CheckpointReader.swift

+18
Original file line numberDiff line numberDiff line change
@@ -290,6 +290,24 @@ open class CheckpointReader {
290290
}
291291
}
292292

293+
extension CheckpointReader {
294+
static func recursivelyObtainTensorNames(
295+
_ current: Any, scope: String? = nil, tensors: inout [String],
296+
separator: String, ignoredTensorPaths: Set<String> = []
297+
) {
298+
CheckpointWriter.recursivelyVisitTensors(
299+
current, scope: scope, separator: separator, ignoredTensorPaths: ignoredTensorPaths
300+
) { child, path in
301+
if child.value is Tensor<Float> {
302+
tensors.append(path)
303+
return false
304+
} else {
305+
return true
306+
}
307+
}
308+
}
309+
}
310+
293311
extension Tensorflow_TensorShapeProto {
294312
var shapeArray: [Int] {
295313
return self.dim.map { Int($0.size) }

Checkpoints/CheckpointWriter.swift

+89-10
Original file line numberDiff line numberDiff line change
@@ -16,23 +16,19 @@ import Foundation
1616
import ModelSupport
1717
import TensorFlow
1818

19+
// TODO: Extend handling to different tensor types.
20+
1921
/// A Swift-native TensorFlow v2 checkpoint writer. This writer has no dependencies
2022
/// on the TensorFlow runtime or libraries.
2123
open class CheckpointWriter {
22-
// TODO: Extend handling to different tensor types.
23-
let tensors: [String: Tensor<Float>]
2424
let fileSystem: FileSystem
2525

2626
/// Initializes the checkpoint reader from a dictionary of tensors, keyed on their string names.
2727
///
2828
/// - Parameters:
29-
/// - tensors: A dictionary containing the tensors to be written, with the keys being the
30-
/// names of those tensors to write in the checkpoint.
3129
/// - fileSystem: The filesystem used for writing the checkpoint.
32-
public init(
33-
tensors: [String: Tensor<Float>], fileSystem: FileSystem = FoundationFileSystem()
30+
public init(fileSystem: FileSystem = FoundationFileSystem()
3431
) {
35-
self.tensors = tensors
3632
self.fileSystem = fileSystem
3733
}
3834

@@ -41,11 +37,13 @@ open class CheckpointWriter {
4137
/// [name].data-0000X-of-0000Y binary shard files with the tensor bytes within them.
4238
///
4339
/// - Parameters:
40+
/// - tensors: The tensors to be written, keyed by the names of those tensors to write in the
41+
/// checkpoint.
4442
/// - directory: The directory to write the checkpoint into. If it doesn't exist, it will be
4543
/// created.
4644
/// - name: The base name of the checkpoint, which is what will have the .index and
4745
/// .data-0000X-of-0000Y extensions appended to it for files in the checkpoint directory.
48-
public func write(to directory: URL, name: String) throws {
46+
public func write(tensors: [String: Tensor<Float>], to directory: URL, name: String) throws {
4947
try fileSystem.createDirectoryIfMissing(at: directory.path)
5048
let indexWriter = CheckpointIndexWriter(tensors: tensors)
5149
let indexHeader = indexWriter.serializedHeader()
@@ -56,10 +54,13 @@ open class CheckpointWriter {
5654
// TODO: Handle splitting into multiple shards.
5755
try writeShard(
5856
to: directory.appendingPathComponent("\(name)"), shard: 0, numShards: 1,
59-
tensorList: indexWriter.orderedTensors)
57+
tensors: tensors, tensorList: indexWriter.orderedTensors)
6058
}
6159

62-
func writeShard(to location: URL, shard: Int, numShards: Int, tensorList: [String]) throws {
60+
func writeShard(
61+
to location: URL, shard: Int, numShards: Int, tensors: [String: Tensor<Float>],
62+
tensorList: [String]
63+
) throws {
6364
let shardFile = CheckpointReader.shardFile(
6465
location: location, shard: shard, totalShards: numShards)
6566

@@ -83,3 +84,81 @@ open class CheckpointWriter {
8384
try outputFile.write(outputBuffer)
8485
}
8586
}
87+
88+
extension CheckpointWriter {
89+
static func recursivelyObtainTensors(
90+
_ current: Any, scope: String? = nil, tensors: inout [String: Tensor<Float>],
91+
separator: String, ignoredTensorPaths: Set<String> = []
92+
) {
93+
recursivelyVisitTensors(
94+
current, scope: scope, separator: separator, ignoredTensorPaths: ignoredTensorPaths
95+
) { child, path in
96+
if let tensor = child.value as? Tensor<Float> {
97+
if tensors[path] != nil {
98+
print(
99+
"Warning: Saved two different tensors with the same name: \(path). This is most likely undesired behavior.")
100+
}
101+
tensors[path] = tensor
102+
return false
103+
} else {
104+
return true
105+
}
106+
}
107+
}
108+
109+
static func recursivelyVisitTensors(
110+
_ current: Any, scope: String? = nil, separator: String,
111+
ignoredTensorPaths: Set<String> = [], visitor: (Mirror.Child, String) -> Bool
112+
) {
113+
let currentType = String(describing: type(of: current.self))
114+
let m = Mirror(reflecting: current)
115+
116+
var previousNames: [String: Int] = [:]
117+
var emptyCount = 0
118+
for child in m.children {
119+
let uniqueLabel: String
120+
if let label = child.label {
121+
if let nameCount = previousNames[label] {
122+
uniqueLabel = "\(label)_\(nameCount)"
123+
previousNames[label] = nameCount + 1
124+
} else {
125+
uniqueLabel = label
126+
previousNames[label] = 1
127+
}
128+
} else {
129+
uniqueLabel = "[\(emptyCount)]"
130+
emptyCount += 1
131+
}
132+
let path = (scope != nil ? scope! + separator : "") + uniqueLabel
133+
let compoundTypeDescription = "\(currentType).\(uniqueLabel)"
134+
if ignoredTensorPaths.contains(compoundTypeDescription) {
135+
continue
136+
}
137+
if visitor(child, path) {
138+
recursivelyVisitTensors(
139+
child.value, scope: path, separator: separator,
140+
ignoredTensorPaths: ignoredTensorPaths, visitor: visitor)
141+
}
142+
}
143+
}
144+
145+
static func remapTensorNames(
146+
tensors: [String: Tensor<Float>], nameMap: (String) -> String
147+
) -> [String: Tensor<Float>] {
148+
var remappedTensors: [String: Tensor<Float>] = [:]
149+
for (key, value) in tensors {
150+
remappedTensors[nameMap(key)] = value
151+
}
152+
return remappedTensors
153+
}
154+
155+
static func lookupMap(table: [String: String]) -> (String) -> String {
156+
return {name in
157+
return table[name] ?? name
158+
}
159+
}
160+
161+
static func identityMap(_ name: String) -> String {
162+
return name
163+
}
164+
}

Checkpoints/Checkpointable.swift

+154
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,154 @@
1+
// Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
import ModelSupport
16+
import Foundation
17+
import TensorFlow
18+
19+
/// Models that comply to Checkpointable can have their Tensors be written to and read from disk
20+
/// using the `writeCheckpoint(to:...)` and `readCheckpoint(from:...)` interfaces.
21+
public protocol Checkpointable: KeyPathIterable {
22+
/// Any Tensor that should be ignored for checkpoint reading or writing, specified in
23+
/// `Type.property` syntax. For example, `["Attention.scale"]`.
24+
var ignoredTensorPaths: Set<String> { get }
25+
26+
/// The string separator between descending levels of the model. For example, a separator of `"/"`
27+
/// will yield tensor path names like `conv2/filter`.
28+
var checkpointSeparator: String { get }
29+
30+
/// A mapping function between the internally generated tensor path names and how those names
31+
/// will or do appear in the on-disk checkpoint.
32+
var tensorNameMap: ((String) -> String) { get }
33+
}
34+
35+
public extension Checkpointable {
36+
var ignoredTensorPaths: Set<String> {
37+
return []
38+
}
39+
40+
var checkpointSeparator: String {
41+
return "/"
42+
}
43+
44+
var tensorNameMap: (String) -> String {
45+
return CheckpointWriter.identityMap
46+
}
47+
48+
/// Writes a checkpoint of this model's tensors to disk.
49+
///
50+
/// - Parameters:
51+
/// - location: The directory to write the checkpoint into. If it doesn't exist, it will be
52+
/// created.
53+
/// - name: The base name of the checkpoint, which is what will have the .index and
54+
/// .data-0000X-of-0000Y extensions appended to it for files in the checkpoint directory.
55+
/// - fileSystem: The filesystem used for writing the checkpoint. Defaults to
56+
/// FoundationFileSystem.
57+
/// - nameTable: A lookup table of generated tensor path names to their corresponding tensor
58+
/// name in the checkpoint file. If an internal tensor path name is not represented, the
59+
/// internal path name is used for the on-disk checkpoint.
60+
func writeCheckpoint(
61+
to location: URL, name: String, fileSystem: FileSystem = FoundationFileSystem(),
62+
nameTable: [String: String]
63+
) throws {
64+
try writeCheckpoint(
65+
to: location, name: name, fileSystem: fileSystem,
66+
nameMap: CheckpointWriter.lookupMap(table: nameTable))
67+
}
68+
69+
/// Writes a checkpoint of this model's tensors to disk.
70+
///
71+
/// - Parameters:
72+
/// - location: The directory to write the checkpoint into. If it doesn't exist, it will be
73+
/// created.
74+
/// - name: The base name of the checkpoint, which is what will have the .index and
75+
/// .data-0000X-of-0000Y extensions appended to it for files in the checkpoint directory.
76+
/// - fileSystem: The filesystem used for writing the checkpoint. Defaults to
77+
/// FoundationFileSystem.
78+
/// - nameMap: A mapping function that converts generated tensor path names to their
79+
/// corresponding tensor name in the checkpoint file.
80+
func writeCheckpoint(
81+
to location: URL, name: String, fileSystem: FileSystem = FoundationFileSystem(),
82+
nameMap: ((String) -> String)? = nil
83+
) throws {
84+
var rawTensors: [String: Tensor<Float>] = [:]
85+
CheckpointWriter.recursivelyObtainTensors(
86+
self, tensors: &rawTensors, separator: self.checkpointSeparator,
87+
ignoredTensorPaths: self.ignoredTensorPaths)
88+
89+
let tensors = CheckpointWriter.remapTensorNames(tensors: rawTensors,
90+
nameMap: nameMap ?? self.tensorNameMap)
91+
92+
let writer = CheckpointWriter(fileSystem: fileSystem)
93+
try writer.write(tensors: tensors, to: location, name: name)
94+
}
95+
96+
/// Reads a checkpoint of this model's tensors from disk.
97+
///
98+
/// - Parameters:
99+
/// - location: Either a URL to the checkpoint files, where the last component is the file
100+
/// base of the checkpoint files, or a URL to an archive containing the checkpoint files.
101+
/// - name: The base name of the checkpoint, which is what will have the .index and
102+
/// .data-0000X-of-0000Y extensions appended to it for files in the checkpoint directory.
103+
/// - fileSystem: The filesystem used for reading the checkpoint. Defaults to
104+
/// FoundationFileSystem.
105+
/// - nameMap: A mapping function that converts generated tensor path names to their
106+
/// corresponding tensor name in the checkpoint file.
107+
mutating func readCheckpoint(
108+
from location: URL, name: String, fileSystem: FileSystem = FoundationFileSystem(),
109+
nameMap: ((String) -> String)? = nil
110+
) throws {
111+
var rawTensorNames: [String] = []
112+
CheckpointReader.recursivelyObtainTensorNames(
113+
self, tensors: &rawTensorNames, separator: self.checkpointSeparator,
114+
ignoredTensorPaths: self.ignoredTensorPaths)
115+
116+
let concreteNameMap = nameMap ?? self.tensorNameMap
117+
let tensorNames = rawTensorNames.map{ concreteNameMap($0) }
118+
119+
let keypaths = self.recursivelyAllWritableKeyPaths(to: Tensor<Float>.self)
120+
121+
guard keypaths.count == tensorNames.count else {
122+
fatalError(
123+
"The number of writable key paths: \(keypaths.count) did not match the number of tensor names: \(tensorNames.count)")
124+
}
125+
126+
let reader: CheckpointReader = try CheckpointReader(checkpointLocation: location,
127+
modelName: name)
128+
129+
for (index, keypath) in keypaths.enumerated() {
130+
self[keyPath: keypath] = Tensor<Float>(reader.loadTensor(named: tensorNames[index]))
131+
}
132+
}
133+
134+
/// Reads a checkpoint of this model's tensors from disk.
135+
///
136+
/// - Parameters:
137+
/// - location: Either a URL to the checkpoint files, where the last component is the file
138+
/// base of the checkpoint files, or a URL to an archive containing the checkpoint files.
139+
/// - name: The base name of the checkpoint, which is what will have the .index and
140+
/// .data-0000X-of-0000Y extensions appended to it for files in the checkpoint directory.
141+
/// - fileSystem: The filesystem used for reading the checkpoint. Defaults to
142+
/// FoundationFileSystem.
143+
/// - nameTable: A lookup table of generated tensor path names to their corresponding tensor
144+
/// name in the checkpoint file. If an internal tensor path name is not represented, the
145+
/// internal path name is used for the on-disk checkpoint.
146+
mutating func readCheckpoint(
147+
from location: URL, name: String, fileSystem: FileSystem = FoundationFileSystem(),
148+
nameTable: [String: String]
149+
) throws {
150+
try readCheckpoint(
151+
from: location, name: name, fileSystem: fileSystem,
152+
nameMap: CheckpointWriter.lookupMap(table: nameTable))
153+
}
154+
}

0 commit comments

Comments
 (0)