This repository was archived by the owner on Jul 7, 2023. It is now read-only.

Commit 83d98cd

T2T Team authored and Copybara-Service committed
TPU compatibility for data chunking
PiperOrigin-RevId: 237326592
1 parent 6671139 commit 83d98cd

File tree

1 file changed: +11 -1 lines changed

tensor2tensor/utils/data_reader.py

+11 -1
@@ -490,17 +490,27 @@ def is_nonzero_chunk(example):
   def split_on_length(example):
     """Split a batch of dicts on length."""
     x = example["targets"]
+    # TODO(kitaev): This code breaks if chunk_length * max_chunks < batch_size
     length_diff = chunk_length * max_chunks - tf.shape(x)[1]
     padded_x = tf.pad(x, [(0, 0), (0, length_diff), (0, 0), (0, 0)])
     chunks = [padded_x[:, i*chunk_length:(i+1)*chunk_length, :, :]
               for i in range(max_chunks - 1)]
     chunks.append(padded_x[:, (max_chunks - 1)*chunk_length:, :, :])
     new_example = {}
-    new_example["chunk_number"] = tf.range(max_chunks)
+    # Setting chunk_number to be tf.range(max_chunks) is incompatible with TPU
+    new_example["chunk_number"] = tf.concat([
+        tf.expand_dims(tf.ones_like(c) * n, axis=0)
+        for n, c in enumerate(chunks)
+    ],
+                                            axis=0)
     new_example["targets"] = tf.concat(
         [tf.expand_dims(c, axis=0) for c in chunks], axis=0)
     for k in example:
       if k != "targets":
+        assert k != "chunk_number", (
+            "Chunking code expects the chunk_number feature name to be "
+            "available"
+        )
         new_example[k] = tf.concat(
             [tf.expand_dims(example[k], axis=0) for _ in range(max_chunks)],
             axis=0)
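
For readers following the change, below is a minimal, standalone sketch of the chunking logic with the TPU-friendly chunk_number construction, runnable outside the tensor2tensor input pipeline. The chunk_length and max_chunks values, the toy "targets" shape, and the name split_on_length_sketch are illustrative assumptions, not part of the commit.

import tensorflow as tf

chunk_length = 4   # assumed toy value
max_chunks = 3     # assumed toy value

def split_on_length_sketch(example):
  """Splits [batch, length, 1, 1] targets into max_chunks fixed-size chunks."""
  x = example["targets"]
  length_diff = chunk_length * max_chunks - tf.shape(x)[1]
  padded_x = tf.pad(x, [(0, 0), (0, length_diff), (0, 0), (0, 0)])
  chunks = [padded_x[:, i * chunk_length:(i + 1) * chunk_length, :, :]
            for i in range(max_chunks - 1)]
  chunks.append(padded_x[:, (max_chunks - 1) * chunk_length:, :, :])
  new_example = {}
  # TPU-friendly variant from the commit: chunk_number is a stack of per-chunk
  # tensors shaped like the chunks themselves, instead of the rank-1
  # tf.range(max_chunks) used before this change.
  new_example["chunk_number"] = tf.concat(
      [tf.expand_dims(tf.ones_like(c) * n, axis=0)
       for n, c in enumerate(chunks)], axis=0)
  new_example["targets"] = tf.concat(
      [tf.expand_dims(c, axis=0) for c in chunks], axis=0)
  return new_example

example = {"targets": tf.ones([2, 10, 1, 1], dtype=tf.int32)}
out = split_on_length_sketch(example)
print(out["targets"].shape)       # (3, 2, 4, 1, 1)
print(out["chunk_number"].shape)  # (3, 2, 4, 1, 1), matches targets per chunk

With the earlier tf.range(max_chunks) approach, chunk_number would instead be a single [max_chunks] vector, which is what the added comment flags as incompatible with TPU.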
