@@ -1336,6 +1336,133 @@ def transformer_prepare_decoder(targets, hparams, features=None):
   return (decoder_input, decoder_self_attention_bias)
 
 
+def transformer_decoder_layer(decoder_input,
+                              decoder_self_attention_bias,
+                              layer_idx,
+                              hparams,
+                              encoder_output=None,
+                              encoder_decoder_attention_bias=None,
+                              cache=None,
+                              decode_loop_step=None,
+                              nonpadding=None,
+                              save_weights_to=None,
+                              make_image_summary=False,
+                              losses=None,
+                              layer_collection=None,
+                              recurrent_memory_by_layer=None,
+                              chunk_number=None):
+  """A single transformer decoder layer."""
+  x = decoder_input
+  layer = layer_idx
+  layer_name = "layer_%d" % layer
+  layer_cache = cache[layer_name] if cache is not None else None
+
+  attention_dropout_broadcast_dims = (
+      common_layers.comma_separated_string_to_integer_list(
+          getattr(hparams, "attention_dropout_broadcast_dims", "")))
+
+  if recurrent_memory_by_layer is not None:
+    recurrent_memory = recurrent_memory_by_layer[layer_name]
+  else:
+    recurrent_memory = None
+
+  if layer < hparams.get("num_area_layers", 0):
+    max_area_width = hparams.get("max_area_width", 1)
+    max_area_height = hparams.get("max_area_height", 1)
+    memory_height = hparams.get("max_area_height", 1)
+  else:
+    max_area_width = 1
+    max_area_height = 1
+    memory_height = 1
+  with tf.variable_scope(layer_name):
+    with tf.variable_scope("self_attention"):
+      y = common_attention.multihead_attention(
+          common_layers.layer_preprocess(
+              x, hparams, layer_collection=layer_collection),
+          None,
+          decoder_self_attention_bias,
+          hparams.attention_key_channels or hparams.hidden_size,
+          hparams.attention_value_channels or hparams.hidden_size,
+          hparams.hidden_size,
+          hparams.num_heads,
+          hparams.attention_dropout,
+          attention_type=hparams.self_attention_type,
+          max_relative_position=hparams.max_relative_position,
+          heads_share_relative_embedding=(
+              hparams.heads_share_relative_embedding),
+          add_relative_to_values=hparams.add_relative_to_values,
+          save_weights_to=save_weights_to,
+          cache=layer_cache,
+          make_image_summary=make_image_summary,
+          dropout_broadcast_dims=attention_dropout_broadcast_dims,
+          max_length=hparams.get("max_length"),
+          decode_loop_step=decode_loop_step,
+          vars_3d=hparams.get("attention_variables_3d"),
+          activation_dtype=hparams.get("activation_dtype", "float32"),
+          weight_dtype=hparams.get("weight_dtype", "float32"),
+          layer_collection=layer_collection,
+          recurrent_memory=recurrent_memory,
+          chunk_number=chunk_number,
+          hard_attention_k=hparams.get("hard_attention_k", 0),
+          max_area_width=max_area_width,
+          max_area_height=max_area_height,
+          memory_height=memory_height,
+          area_key_mode=hparams.get("area_key_mode", "none"),
+          area_value_mode=hparams.get("area_value_mode", "none"),
+          training=(hparams.get(
+              "mode",
+              tf.estimator.ModeKeys.TRAIN) == tf.estimator.ModeKeys.TRAIN))
+      x = common_layers.layer_postprocess(x, y, hparams)
+    if encoder_output is not None:
+      with tf.variable_scope("encdec_attention"):
+        y = common_attention.multihead_attention(
+            common_layers.layer_preprocess(
+                x, hparams, layer_collection=layer_collection),
+            encoder_output,
+            encoder_decoder_attention_bias,
+            hparams.attention_key_channels or hparams.hidden_size,
+            hparams.attention_value_channels or hparams.hidden_size,
+            hparams.hidden_size,
+            hparams.num_heads,
+            hparams.attention_dropout,
+            max_relative_position=hparams.max_relative_position,
+            heads_share_relative_embedding=(
+                hparams.heads_share_relative_embedding),
+            add_relative_to_values=hparams.add_relative_to_values,
+            save_weights_to=save_weights_to,
+            cache=layer_cache,
+            make_image_summary=make_image_summary,
+            dropout_broadcast_dims=attention_dropout_broadcast_dims,
+            max_length=hparams.get("max_length"),
+            vars_3d=hparams.get("attention_variables_3d"),
+            activation_dtype=hparams.get("activation_dtype", "float32"),
+            weight_dtype=hparams.get("weight_dtype", "float32"),
+            layer_collection=layer_collection,
+            hard_attention_k=hparams.get("hard_attention_k", 0),
+            max_area_width=max_area_width,
+            max_area_height=max_area_height,
+            memory_height=memory_height,
+            area_key_mode=hparams.get("area_key_mode", "none"),
+            area_value_mode=hparams.get("area_value_mode", "none"),
+            training=(hparams.get(
+                "mode",
+                tf.estimator.ModeKeys.TRAIN) == tf.estimator.ModeKeys.TRAIN))
+        x = common_layers.layer_postprocess(x, y, hparams)
+    with tf.variable_scope("ffn"):
+      y = transformer_ffn_layer(
+          common_layers.layer_preprocess(
+              x, hparams, layer_collection=layer_collection),
+          hparams,
+          conv_padding="LEFT",
+          nonpadding_mask=nonpadding,
+          losses=losses,
+          cache=layer_cache,
+          decode_loop_step=decode_loop_step,
+          layer_collection=layer_collection)
+      x = common_layers.layer_postprocess(x, y, hparams)
+  return x
+
+
 def transformer_decoder(decoder_input,
                         encoder_output,
                         decoder_self_attention_bias,
@@ -1350,8 +1477,7 @@ def transformer_decoder(decoder_input,
                         losses=None,
                         layer_collection=None,
                         recurrent_memory_by_layer=None,
-                        chunk_number=None,
-                        ):
+                        chunk_number=None):
   """A stack of transformer layers.
 
   Args:
@@ -1377,8 +1503,8 @@ def transformer_decoder(decoder_input,
       key created from the variable scope (including name).
     make_image_summary: Whether to make an attention image summary.
     losses: optional list onto which to append extra training losses
-    layer_collection: A tensorflow_kfac.LayerCollection. Only used by the
-      KFAC optimizer. Default is None.
+    layer_collection: A tensorflow_kfac.LayerCollection. Only used by the KFAC
+      optimizer. Default is None.
     recurrent_memory_by_layer: Optional dict, mapping layer names to instances
       of transformer_memory.RecurrentMemory. Default is None.
     chunk_number: an optional integer Tensor with shape [batch] used to operate
@@ -1388,9 +1514,6 @@ def transformer_decoder(decoder_input,
     y: a Tensors
   """
   x = decoder_input
-  attention_dropout_broadcast_dims = (
-      common_layers.comma_separated_string_to_integer_list(
-          getattr(hparams, "attention_dropout_broadcast_dims", "")))
 
   mlperf_log.transformer_print(
       key=mlperf_log.MODEL_HP_NUM_HIDDEN_LAYERS,
@@ -1410,106 +1533,26 @@ def transformer_decoder(decoder_input,
       hparams=hparams)
 
   with tf.variable_scope(name):
-    for layer in range(hparams.num_decoder_layers or hparams.num_hidden_layers):
-      layer_name = "layer_%d" % layer
-      layer_cache = cache[layer_name] if cache is not None else None
-      if recurrent_memory_by_layer is not None:
-        recurrent_memory = recurrent_memory_by_layer[layer_name]
-      else:
-        recurrent_memory = None
+    for layer_idx in range(hparams.num_decoder_layers or
+                           hparams.num_hidden_layers):
+      x = transformer_decoder_layer(
+          x,
+          decoder_self_attention_bias,
+          layer_idx,
+          hparams,
+          encoder_decoder_attention_bias=encoder_decoder_attention_bias,
+          encoder_output=encoder_output,
+          cache=cache,
+          decode_loop_step=decode_loop_step,
+          nonpadding=nonpadding,
+          save_weights_to=save_weights_to,
+          make_image_summary=make_image_summary,
+          losses=losses,
+          layer_collection=layer_collection,
+          recurrent_memory_by_layer=recurrent_memory_by_layer,
+          chunk_number=chunk_number,
+          )
 
-      if layer < hparams.get("num_area_layers", 0):
-        max_area_width = hparams.get("max_area_width", 1)
-        max_area_height = hparams.get("max_area_height", 1)
-        memory_height = hparams.get("max_area_height", 1)
-      else:
-        max_area_width = 1
-        max_area_height = 1
-        memory_height = 1
-      with tf.variable_scope(layer_name):
-        with tf.variable_scope("self_attention"):
-          y = common_attention.multihead_attention(
-              common_layers.layer_preprocess(
-                  x, hparams, layer_collection=layer_collection),
-              None,
-              decoder_self_attention_bias,
-              hparams.attention_key_channels or hparams.hidden_size,
-              hparams.attention_value_channels or hparams.hidden_size,
-              hparams.hidden_size,
-              hparams.num_heads,
-              hparams.attention_dropout,
-              attention_type=hparams.self_attention_type,
-              max_relative_position=hparams.max_relative_position,
-              heads_share_relative_embedding=(
-                  hparams.heads_share_relative_embedding),
-              add_relative_to_values=hparams.add_relative_to_values,
-              save_weights_to=save_weights_to,
-              cache=layer_cache,
-              make_image_summary=make_image_summary,
-              dropout_broadcast_dims=attention_dropout_broadcast_dims,
-              max_length=hparams.get("max_length"),
-              decode_loop_step=decode_loop_step,
-              vars_3d=hparams.get("attention_variables_3d"),
-              activation_dtype=hparams.get("activation_dtype", "float32"),
-              weight_dtype=hparams.get("weight_dtype", "float32"),
-              layer_collection=layer_collection,
-              recurrent_memory=recurrent_memory,
-              chunk_number=chunk_number,
-              hard_attention_k=hparams.get("hard_attention_k", 0),
-              max_area_width=max_area_width,
-              max_area_height=max_area_height,
-              memory_height=memory_height,
-              area_key_mode=hparams.get("area_key_mode", "none"),
-              area_value_mode=hparams.get("area_value_mode", "none"),
-              training=(hparams.get("mode", tf.estimator.ModeKeys.TRAIN)
-                        == tf.estimator.ModeKeys.TRAIN))
-          x = common_layers.layer_postprocess(x, y, hparams)
-        if encoder_output is not None:
-          with tf.variable_scope("encdec_attention"):
-            y = common_attention.multihead_attention(
-                common_layers.layer_preprocess(
-                    x, hparams, layer_collection=layer_collection),
-                encoder_output,
-                encoder_decoder_attention_bias,
-                hparams.attention_key_channels or hparams.hidden_size,
-                hparams.attention_value_channels or hparams.hidden_size,
-                hparams.hidden_size,
-                hparams.num_heads,
-                hparams.attention_dropout,
-                max_relative_position=hparams.max_relative_position,
-                heads_share_relative_embedding=(
-                    hparams.heads_share_relative_embedding),
-                add_relative_to_values=hparams.add_relative_to_values,
-                save_weights_to=save_weights_to,
-                cache=layer_cache,
-                make_image_summary=make_image_summary,
-                dropout_broadcast_dims=attention_dropout_broadcast_dims,
-                max_length=hparams.get("max_length"),
-                vars_3d=hparams.get("attention_variables_3d"),
-                activation_dtype=hparams.get("activation_dtype", "float32"),
-                weight_dtype=hparams.get("weight_dtype", "float32"),
-                layer_collection=layer_collection,
-                hard_attention_k=hparams.get("hard_attention_k", 0),
-                max_area_width=max_area_width,
-                max_area_height=max_area_height,
-                memory_height=memory_height,
-                area_key_mode=hparams.get("area_key_mode", "none"),
-                area_value_mode=hparams.get("area_value_mode", "none"),
-                training=(hparams.get("mode", tf.estimator.ModeKeys.TRAIN)
-                          == tf.estimator.ModeKeys.TRAIN))
-            x = common_layers.layer_postprocess(x, y, hparams)
-        with tf.variable_scope("ffn"):
-          y = transformer_ffn_layer(
-              common_layers.layer_preprocess(
-                  x, hparams, layer_collection=layer_collection),
-              hparams,
-              conv_padding="LEFT",
-              nonpadding_mask=nonpadding,
-              losses=losses,
-              cache=layer_cache,
-              decode_loop_step=decode_loop_step,
-              layer_collection=layer_collection)
-          x = common_layers.layer_postprocess(x, y, hparams)
 
   # if normalization is done in layer_preprocess, then it should also be done
   # on the output, since the output can grow very large, being the sum of
   # a whole stack of unnormalized layer outputs.
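
Note: the refactor is behavior-preserving; transformer_decoder now just loops over the extracted transformer_decoder_layer helper. Below is a minimal sketch (not part of this commit) of calling the new helper on its own, assuming the tensor2tensor package this file lives in and a TF1-style graph environment, since the layer code still uses tf.variable_scope.

# Sketch only: run one decoder layer with no encoder, i.e. masked
# self-attention + ffn, as in a decoder-only language model.
import tensorflow.compat.v1 as tf
from tensor2tensor.layers import common_attention
from tensor2tensor.models import transformer

tf.disable_v2_behavior()

hparams = transformer.transformer_base()  # standard base hparams set
batch, length = 2, 8
decoder_input = tf.zeros([batch, length, hparams.hidden_size])
# Causal bias so position i only attends to positions <= i.
decoder_self_attention_bias = common_attention.attention_bias_lower_triangle(
    length)

x = transformer.transformer_decoder_layer(
    decoder_input,
    decoder_self_attention_bias,
    layer_idx=0,
    hparams=hparams)  # encoder_output defaults to None: no encdec_attention
print(x)  # a float Tensor of shape [batch, length, hparams.hidden_size]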