
Commit d9f807c

Merge pull request #456 from lukaszkaiser/push
1.3.2 (small changes for colab)
2 parents: 970dac9 + f2fb96b

5 files changed: +543, -405 lines

setup.py (+1, -1)

@@ -5,7 +5,7 @@

 setup(
     name='tensor2tensor',
-    version='1.3.1',
+    version='1.3.2',
     description='Tensor2Tensor',
     author='Google Inc.',
     author_email='[email protected]',

tensor2tensor/data_generators/translate_enfr.py (+8)

@@ -143,6 +143,14 @@ def use_small_dataset(self):
     return False


+@registry.register_problem
+class TranslateEnfrWmt32kPacked(TranslateEnfrWmt32k):
+
+  @property
+  def packed_length(self):
+    return 256
+
+
 @registry.register_problem
 class TranslateEnfrWmtSmallCharacters(translate.TranslateProblem):
   """Problem spec for WMT En-Fr translation."""

tensor2tensor/layers/common_attention.py (+15, -4)

@@ -1182,7 +1182,8 @@ def dot_product_attention(q,
                           dropout_rate=0.0,
                           image_shapes=None,
                           name=None,
-                          make_image_summary=True):
+                          make_image_summary=True,
+                          save_weights_to=None):
   """dot-product attention.

   Args:
@@ -1195,17 +1196,22 @@ def dot_product_attention(q,
       see comments for attention_image_summary()
     name: an optional string
     make_image_summary: True if you want an image summary.
+    save_weights_to: an optional dictionary to capture attention weights
+      for vizualization; the weights tensor will be appended there under
+      a string key created from the variable scope (including name).

   Returns:
     A Tensor.
   """
   with tf.variable_scope(
-      name, default_name="dot_product_attention", values=[q, k, v]):
+      name, default_name="dot_product_attention", values=[q, k, v]) as scope:
     # [batch, num_heads, query_length, memory_length]
     logits = tf.matmul(q, k, transpose_b=True)
     if bias is not None:
       logits += bias
     weights = tf.nn.softmax(logits, name="attention_weights")
+    if save_weights_to is not None:
+      save_weights_to[scope.name] = weights
     # dropping out the attention links for each of the heads
     weights = tf.nn.dropout(weights, 1.0 - dropout_rate)
     if (not tf.get_variable_scope().reuse and
@@ -2245,6 +2251,7 @@ def multihead_attention(query_antecedent,
                         gap_size=0,
                         num_memory_blocks=2,
                         name=None,
+                        save_weights_to=None,
                         **kwargs):
   """Multihead scaled-dot-product attention with input/output transformations.

@@ -2284,7 +2291,10 @@ def multihead_attention(query_antecedent,
       memory blocks.
     num_memory_blocks: Integer option to indicate how many memory blocks to look
       at.
-    name: an optional string
+    name: an optional string.
+    save_weights_to: an optional dictionary to capture attention weights
+      for vizualization; the weights tensor will be appended there under
+      a string key created from the variable scope (including name).
     **kwargs (dict): Parameters for the attention function

   Caching:
@@ -2345,7 +2355,8 @@ def multihead_attention(query_antecedent,
       if isinstance(x, tuple):
         x, additional_returned_value = x  # Unpack
     elif attention_type == "dot_product":
-      x = dot_product_attention(q, k, v, bias, dropout_rate, image_shapes)
+      x = dot_product_attention(q, k, v, bias, dropout_rate, image_shapes,
+                                save_weights_to=save_weights_to)
     elif attention_type == "dot_product_relative":
       x = dot_product_attention_relative(q, k, v, bias, max_relative_position,
                                          dropout_rate, image_shapes)
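
Taken together, these changes let a caller pass a dictionary through multihead_attention down to dot_product_attention, which stores the softmaxed attention weights under the enclosing variable-scope name. A minimal sketch of direct use, assuming TF 1.x graph mode; the tensor shapes and the absence of a bias are illustrative choices, not part of the commit:

    # Minimal sketch (TF 1.x graph mode); shapes are illustrative.
    import tensorflow as tf
    from tensor2tensor.layers import common_attention

    q = tf.random_normal([1, 8, 10, 64])  # [batch, heads, query_length, depth]
    k = tf.random_normal([1, 8, 12, 64])  # [batch, heads, memory_length, depth]
    v = tf.random_normal([1, 8, 12, 64])

    attention_weights = {}  # filled in by dot_product_attention via save_weights_to
    out = common_attention.dot_product_attention(
        q, k, v, bias=None, make_image_summary=False,
        save_weights_to=attention_weights)

    with tf.Session() as sess:
      _, weights = sess.run([out, attention_weights])
      for scope_name, w in weights.items():
        print(scope_name, w.shape)  # e.g. (1, 8, 10, 12): [batch, heads, query, memory]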

tensor2tensor/models/transformer.py (+22, -5)

@@ -45,6 +45,10 @@
 class Transformer(t2t_model.T2TModel):
   """Attention net. See file docstring."""

+  def __init__(self, *args, **kwargs):
+    super(Transformer, self).__init__(*args, **kwargs)
+    self.attention_weights = dict()  # For vizualizing attention heads.
+
   def encode(self, inputs, target_space, hparams, features=None):
     """Encode transformer inputs.

@@ -73,7 +77,8 @@ def encode(self, inputs, target_space, hparams, features=None):

     encoder_output = transformer_encoder(
         encoder_input, self_attention_bias,
-        hparams, nonpadding=_features_to_nonpadding(features, "inputs"))
+        hparams, nonpadding=_features_to_nonpadding(features, "inputs"),
+        save_weights_to=self.attention_weights)

     return encoder_output, encoder_decoder_attention_bias

@@ -114,7 +119,8 @@ def decode(self,
         encoder_decoder_attention_bias,
         hparams,
         cache=cache,
-        nonpadding=nonpadding)
+        nonpadding=nonpadding,
+        save_weights_to=self.attention_weights)

     if hparams.use_tpu and hparams.mode == tf.estimator.ModeKeys.TRAIN:
       # TPU does not react kindly to extra dimensions.
@@ -507,7 +513,8 @@ def transformer_encoder(encoder_input,
                         encoder_self_attention_bias,
                         hparams,
                         name="encoder",
-                        nonpadding=None):
+                        nonpadding=None,
+                        save_weights_to=None):
   """A stack of transformer layers.

   Args:
@@ -522,6 +529,9 @@ def transformer_encoder(encoder_input,
       encoder_self_attention_bias. The knowledge about padding is used
       for pad_remover(efficiency) and to mask out padding in convoltutional
       layers.
+    save_weights_to: an optional dictionary to capture attention weights
+      for vizualization; the weights tensor will be appended there under
+      a string key created from the variable scope (including name).

   Returns:
     y: a Tensors
@@ -551,6 +561,7 @@ def transformer_encoder(encoder_input,
               hparams.num_heads,
               hparams.attention_dropout,
               attention_type=hparams.self_attention_type,
+              save_weights_to=save_weights_to,
               max_relative_position=hparams.max_relative_position)
           x = common_layers.layer_postprocess(x, y, hparams)
         with tf.variable_scope("ffn"):
@@ -571,7 +582,8 @@ def transformer_decoder(decoder_input,
                         hparams,
                         cache=None,
                         name="decoder",
-                        nonpadding=None):
+                        nonpadding=None,
+                        save_weights_to=None):
   """A stack of transformer layers.

   Args:
@@ -590,6 +602,9 @@ def transformer_decoder(decoder_input,
       to mask out padding in convoltutional layers. We generally only
       need this mask for "packed" datasets, because for ordinary datasets,
       no padding is ever followed by nonpadding.
+    save_weights_to: an optional dictionary to capture attention weights
+      for vizualization; the weights tensor will be appended there under
+      a string key created from the variable scope (including name).

   Returns:
     y: a Tensors
@@ -612,6 +627,7 @@ def transformer_decoder(decoder_input,
               hparams.num_heads,
               hparams.attention_dropout,
               attention_type=hparams.self_attention_type,
+              save_weights_to=save_weights_to,
              max_relative_position=hparams.max_relative_position,
              cache=layer_cache)
           x = common_layers.layer_postprocess(x, y, hparams)
@@ -624,7 +640,8 @@ def transformer_decoder(decoder_input,
                 hparams.attention_key_channels or hparams.hidden_size,
                 hparams.attention_value_channels or hparams.hidden_size,
                 hparams.hidden_size, hparams.num_heads,
-                hparams.attention_dropout)
+                hparams.attention_dropout,
+                save_weights_to=save_weights_to)
             x = common_layers.layer_postprocess(x, y, hparams)
         with tf.variable_scope("ffn"):
           y = transformer_ffn_layer(
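
With save_weights_to threaded through encode and decode, a Transformer instance accumulates one entry per attention call in self.attention_weights once its forward graph is built, presumably what the colab-oriented changes in this release visualize. A hedged sketch of inspecting it; `model` and `sess` are assumed to exist already, and the scope name shown is illustrative:

    # Sketch only: assumes `model` is an already-built transformer.Transformer whose
    # forward pass has been constructed, and `sess` is a tf.Session with a checkpoint
    # loaded; the scope name in the comment below is illustrative.
    for scope_name, weights in model.attention_weights.items():
      # Each value is a tensor of shape [batch, num_heads, query_length, memory_length],
      # keyed by its variable scope, e.g.
      # "transformer/body/encoder/layer_0/self_attention/dot_product_attention".
      print(scope_name, weights.shape)

    # To visualize concrete values, fetch the whole dictionary for a given input:
    # attention_values = sess.run(model.attention_weights, feed_dict=...)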
