
Commit a298052

T2T Team authored and copybara-github committed
Splits out the transformer decoder layers so they may be called separately.
PiperOrigin-RevId: 247122386
1 parent ddb9665 · commit a298052

File tree

1 file changed: +149 -106 lines

tensor2tensor/models/transformer.py (+149 -106)
@@ -1336,6 +1336,133 @@ def transformer_prepare_decoder(targets, hparams, features=None):
   return (decoder_input, decoder_self_attention_bias)
 
 
+def transformer_decoder_layer(decoder_input,
+                              decoder_self_attention_bias,
+                              layer_idx,
+                              hparams,
+                              encoder_output=None,
+                              encoder_decoder_attention_bias=None,
+                              cache=None,
+                              decode_loop_step=None,
+                              nonpadding=None,
+                              save_weights_to=None,
+                              make_image_summary=False,
+                              losses=None,
+                              layer_collection=None,
+                              recurrent_memory_by_layer=None,
+                              chunk_number=None):
+  """A single transformer decoder layer."""
+  x = decoder_input
+  layer = layer_idx
+  layer_name = "layer_%d" % layer
+  layer_cache = cache[layer_name] if cache is not None else None
+
+  attention_dropout_broadcast_dims = (
+      common_layers.comma_separated_string_to_integer_list(
+          getattr(hparams, "attention_dropout_broadcast_dims", "")))
+
+  if recurrent_memory_by_layer is not None:
+    recurrent_memory = recurrent_memory_by_layer[layer_name]
+  else:
+    recurrent_memory = None
+
+  if layer < hparams.get("num_area_layers", 0):
+    max_area_width = hparams.get("max_area_width", 1)
+    max_area_height = hparams.get("max_area_height", 1)
+    memory_height = hparams.get("max_area_height", 1)
+  else:
+    max_area_width = 1
+    max_area_height = 1
+    memory_height = 1
+  with tf.variable_scope(layer_name):
+    with tf.variable_scope("self_attention"):
+      y = common_attention.multihead_attention(
+          common_layers.layer_preprocess(
+              x, hparams, layer_collection=layer_collection),
+          None,
+          decoder_self_attention_bias,
+          hparams.attention_key_channels or hparams.hidden_size,
+          hparams.attention_value_channels or hparams.hidden_size,
+          hparams.hidden_size,
+          hparams.num_heads,
+          hparams.attention_dropout,
+          attention_type=hparams.self_attention_type,
+          max_relative_position=hparams.max_relative_position,
+          heads_share_relative_embedding=(
+              hparams.heads_share_relative_embedding),
+          add_relative_to_values=hparams.add_relative_to_values,
+          save_weights_to=save_weights_to,
+          cache=layer_cache,
+          make_image_summary=make_image_summary,
+          dropout_broadcast_dims=attention_dropout_broadcast_dims,
+          max_length=hparams.get("max_length"),
+          decode_loop_step=decode_loop_step,
+          vars_3d=hparams.get("attention_variables_3d"),
+          activation_dtype=hparams.get("activation_dtype", "float32"),
+          weight_dtype=hparams.get("weight_dtype", "float32"),
+          layer_collection=layer_collection,
+          recurrent_memory=recurrent_memory,
+          chunk_number=chunk_number,
+          hard_attention_k=hparams.get("hard_attention_k", 0),
+          max_area_width=max_area_width,
+          max_area_height=max_area_height,
+          memory_height=memory_height,
+          area_key_mode=hparams.get("area_key_mode", "none"),
+          area_value_mode=hparams.get("area_value_mode", "none"),
+          training=(hparams.get(
+              "mode",
+              tf.estimator.ModeKeys.TRAIN) == tf.estimator.ModeKeys.TRAIN))
+      x = common_layers.layer_postprocess(x, y, hparams)
+    if encoder_output is not None:
+      with tf.variable_scope("encdec_attention"):
+        y = common_attention.multihead_attention(
+            common_layers.layer_preprocess(
+                x, hparams, layer_collection=layer_collection),
+            encoder_output,
+            encoder_decoder_attention_bias,
+            hparams.attention_key_channels or hparams.hidden_size,
+            hparams.attention_value_channels or hparams.hidden_size,
+            hparams.hidden_size,
+            hparams.num_heads,
+            hparams.attention_dropout,
+            max_relative_position=hparams.max_relative_position,
+            heads_share_relative_embedding=(
+                hparams.heads_share_relative_embedding),
+            add_relative_to_values=hparams.add_relative_to_values,
+            save_weights_to=save_weights_to,
+            cache=layer_cache,
+            make_image_summary=make_image_summary,
+            dropout_broadcast_dims=attention_dropout_broadcast_dims,
+            max_length=hparams.get("max_length"),
+            vars_3d=hparams.get("attention_variables_3d"),
+            activation_dtype=hparams.get("activation_dtype", "float32"),
+            weight_dtype=hparams.get("weight_dtype", "float32"),
+            layer_collection=layer_collection,
+            hard_attention_k=hparams.get("hard_attention_k", 0),
+            max_area_width=max_area_width,
+            max_area_height=max_area_height,
+            memory_height=memory_height,
+            area_key_mode=hparams.get("area_key_mode", "none"),
+            area_value_mode=hparams.get("area_value_mode", "none"),
+            training=(hparams.get(
+                "mode",
+                tf.estimator.ModeKeys.TRAIN) == tf.estimator.ModeKeys.TRAIN))
+        x = common_layers.layer_postprocess(x, y, hparams)
+    with tf.variable_scope("ffn"):
+      y = transformer_ffn_layer(
+          common_layers.layer_preprocess(
+              x, hparams, layer_collection=layer_collection),
+          hparams,
+          conv_padding="LEFT",
+          nonpadding_mask=nonpadding,
+          losses=losses,
+          cache=layer_cache,
+          decode_loop_step=decode_loop_step,
+          layer_collection=layer_collection)
+      x = common_layers.layer_postprocess(x, y, hparams)
+  return x
+
+
 def transformer_decoder(decoder_input,
                         encoder_output,
                         decoder_self_attention_bias,
@@ -1350,8 +1477,7 @@ def transformer_decoder(decoder_input,
                         losses=None,
                         layer_collection=None,
                         recurrent_memory_by_layer=None,
-                        chunk_number=None,
-                        ):
+                        chunk_number=None):
   """A stack of transformer layers.
 
   Args:
@@ -1377,8 +1503,8 @@ def transformer_decoder(decoder_input,
       key created from the variable scope (including name).
     make_image_summary: Whether to make an attention image summary.
     losses: optional list onto which to append extra training losses
-    layer_collection: A tensorflow_kfac.LayerCollection. Only used by the
-      KFAC optimizer. Default is None.
+    layer_collection: A tensorflow_kfac.LayerCollection. Only used by the KFAC
+      optimizer. Default is None.
     recurrent_memory_by_layer: Optional dict, mapping layer names to instances
       of transformer_memory.RecurrentMemory. Default is None.
     chunk_number: an optional integer Tensor with shape [batch] used to operate
@@ -1388,9 +1514,6 @@ def transformer_decoder(decoder_input,
     y: a Tensors
   """
   x = decoder_input
-  attention_dropout_broadcast_dims = (
-      common_layers.comma_separated_string_to_integer_list(
-          getattr(hparams, "attention_dropout_broadcast_dims", "")))
 
   mlperf_log.transformer_print(
       key=mlperf_log.MODEL_HP_NUM_HIDDEN_LAYERS,
@@ -1410,106 +1533,26 @@ def transformer_decoder(decoder_input,
       hparams=hparams)
 
   with tf.variable_scope(name):
-    for layer in range(hparams.num_decoder_layers or hparams.num_hidden_layers):
-      layer_name = "layer_%d" % layer
-      layer_cache = cache[layer_name] if cache is not None else None
-      if recurrent_memory_by_layer is not None:
-        recurrent_memory = recurrent_memory_by_layer[layer_name]
-      else:
-        recurrent_memory = None
+    for layer_idx in range(hparams.num_decoder_layers or
+                           hparams.num_hidden_layers):
+      x = transformer_decoder_layer(
+          x,
+          decoder_self_attention_bias,
+          layer_idx,
+          hparams,
+          encoder_decoder_attention_bias=encoder_decoder_attention_bias,
+          encoder_output=encoder_output,
+          cache=cache,
+          decode_loop_step=decode_loop_step,
+          nonpadding=nonpadding,
+          save_weights_to=save_weights_to,
+          make_image_summary=make_image_summary,
+          losses=losses,
+          layer_collection=layer_collection,
+          recurrent_memory_by_layer=recurrent_memory_by_layer,
+          chunk_number=chunk_number,
+          )
 
-      if layer < hparams.get("num_area_layers", 0):
-        max_area_width = hparams.get("max_area_width", 1)
-        max_area_height = hparams.get("max_area_height", 1)
-        memory_height = hparams.get("max_area_height", 1)
-      else:
-        max_area_width = 1
-        max_area_height = 1
-        memory_height = 1
-      with tf.variable_scope(layer_name):
-        with tf.variable_scope("self_attention"):
-          y = common_attention.multihead_attention(
-              common_layers.layer_preprocess(
-                  x, hparams, layer_collection=layer_collection),
-              None,
-              decoder_self_attention_bias,
-              hparams.attention_key_channels or hparams.hidden_size,
-              hparams.attention_value_channels or hparams.hidden_size,
-              hparams.hidden_size,
-              hparams.num_heads,
-              hparams.attention_dropout,
-              attention_type=hparams.self_attention_type,
-              max_relative_position=hparams.max_relative_position,
-              heads_share_relative_embedding=(
-                  hparams.heads_share_relative_embedding),
-              add_relative_to_values=hparams.add_relative_to_values,
-              save_weights_to=save_weights_to,
-              cache=layer_cache,
-              make_image_summary=make_image_summary,
-              dropout_broadcast_dims=attention_dropout_broadcast_dims,
-              max_length=hparams.get("max_length"),
-              decode_loop_step=decode_loop_step,
-              vars_3d=hparams.get("attention_variables_3d"),
-              activation_dtype=hparams.get("activation_dtype", "float32"),
-              weight_dtype=hparams.get("weight_dtype", "float32"),
-              layer_collection=layer_collection,
-              recurrent_memory=recurrent_memory,
-              chunk_number=chunk_number,
-              hard_attention_k=hparams.get("hard_attention_k", 0),
-              max_area_width=max_area_width,
-              max_area_height=max_area_height,
-              memory_height=memory_height,
-              area_key_mode=hparams.get("area_key_mode", "none"),
-              area_value_mode=hparams.get("area_value_mode", "none"),
-              training=(hparams.get("mode", tf.estimator.ModeKeys.TRAIN)
-                        == tf.estimator.ModeKeys.TRAIN))
-          x = common_layers.layer_postprocess(x, y, hparams)
-        if encoder_output is not None:
-          with tf.variable_scope("encdec_attention"):
-            y = common_attention.multihead_attention(
-                common_layers.layer_preprocess(
-                    x, hparams, layer_collection=layer_collection),
-                encoder_output,
-                encoder_decoder_attention_bias,
-                hparams.attention_key_channels or hparams.hidden_size,
-                hparams.attention_value_channels or hparams.hidden_size,
-                hparams.hidden_size,
-                hparams.num_heads,
-                hparams.attention_dropout,
-                max_relative_position=hparams.max_relative_position,
-                heads_share_relative_embedding=(
-                    hparams.heads_share_relative_embedding),
-                add_relative_to_values=hparams.add_relative_to_values,
-                save_weights_to=save_weights_to,
-                cache=layer_cache,
-                make_image_summary=make_image_summary,
-                dropout_broadcast_dims=attention_dropout_broadcast_dims,
-                max_length=hparams.get("max_length"),
-                vars_3d=hparams.get("attention_variables_3d"),
-                activation_dtype=hparams.get("activation_dtype", "float32"),
-                weight_dtype=hparams.get("weight_dtype", "float32"),
-                layer_collection=layer_collection,
-                hard_attention_k=hparams.get("hard_attention_k", 0),
-                max_area_width=max_area_width,
-                max_area_height=max_area_height,
-                memory_height=memory_height,
-                area_key_mode=hparams.get("area_key_mode", "none"),
-                area_value_mode=hparams.get("area_value_mode", "none"),
-                training=(hparams.get("mode", tf.estimator.ModeKeys.TRAIN)
-                          == tf.estimator.ModeKeys.TRAIN))
-            x = common_layers.layer_postprocess(x, y, hparams)
-        with tf.variable_scope("ffn"):
-          y = transformer_ffn_layer(
-              common_layers.layer_preprocess(
-                  x, hparams, layer_collection=layer_collection),
-              hparams,
-              conv_padding="LEFT",
-              nonpadding_mask=nonpadding,
-              losses=losses,
-              cache=layer_cache,
-              decode_loop_step=decode_loop_step,
-              layer_collection=layer_collection)
-          x = common_layers.layer_postprocess(x, y, hparams)
   # if normalization is done in layer_preprocess, then it should also be done
   # on the output, since the output can grow very large, being the sum of
   # a whole stack of unnormalized layer outputs.
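With the per-layer function split out, a single decoder layer can be driven on its own rather than only through the transformer_decoder() loop. Below is a minimal usage sketch of such a standalone, decoder-only call; it is not part of this commit, and the TensorFlow 1.x session setup, tensor shapes, and hyperparameter choices are illustrative assumptions.

# Hypothetical usage sketch (not from this commit): run one decoder layer
# standalone. Assumes TensorFlow 1.x graph mode and the existing
# tensor2tensor helpers transformer_base() and attention_bias_lower_triangle().
import tensorflow as tf

from tensor2tensor.layers import common_attention
from tensor2tensor.models import transformer

hparams = transformer.transformer_base()  # standard Transformer hparams

batch_size, target_length = 2, 7  # illustrative shapes
decoder_input = tf.random_normal(
    [batch_size, target_length, hparams.hidden_size])

# Causal bias so each position attends only to itself and earlier positions.
decoder_self_attention_bias = (
    common_attention.attention_bias_lower_triangle(target_length))

with tf.variable_scope("decoder"):
  # Call the newly split-out layer directly. Leaving encoder_output=None
  # skips the encoder-decoder attention sub-layer (decoder-only setting).
  layer_output = transformer.transformer_decoder_layer(
      decoder_input,
      decoder_self_attention_bias,
      layer_idx=0,
      hparams=hparams,
      encoder_output=None)

with tf.Session() as sess:
  sess.run(tf.global_variables_initializer())
  print(sess.run(layer_output).shape)  # (2, 7, hparams.hidden_size)

Passing a real encoder_output together with encoder_decoder_attention_bias would exercise the encoder-decoder attention sub-layer as well, exactly as transformer_decoder() now does for each layer in its loop.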
