 class Transformer(t2t_model.T2TModel):
   """Attention net. See file docstring."""
 
+  def __init__(self, *args, **kwargs):
+    super(Transformer, self).__init__(*args, **kwargs)
+    self.attention_weights = dict()  # For visualizing attention heads.
+
   def encode(self, inputs, target_space, hparams, features=None):
     """Encode transformer inputs.
 
@@ -73,7 +77,8 @@ def encode(self, inputs, target_space, hparams, features=None):
 
     encoder_output = transformer_encoder(
         encoder_input, self_attention_bias,
-        hparams, nonpadding=_features_to_nonpadding(features, "inputs"))
+        hparams, nonpadding=_features_to_nonpadding(features, "inputs"),
+        save_weights_to=self.attention_weights)
 
     return encoder_output, encoder_decoder_attention_bias
 
@@ -114,7 +119,8 @@ def decode(self,
         encoder_decoder_attention_bias,
         hparams,
         cache=cache,
-        nonpadding=nonpadding)
+        nonpadding=nonpadding,
+        save_weights_to=self.attention_weights)
 
     if hparams.use_tpu and hparams.mode == tf.estimator.ModeKeys.TRAIN:
       # TPU does not react kindly to extra dimensions.
@@ -507,7 +513,8 @@ def transformer_encoder(encoder_input,
                         encoder_self_attention_bias,
                         hparams,
                         name="encoder",
-                        nonpadding=None):
+                        nonpadding=None,
+                        save_weights_to=None):
   """A stack of transformer layers.
 
   Args:
@@ -522,6 +529,9 @@ def transformer_encoder(encoder_input,
       encoder_self_attention_bias. The knowledge about padding is used
       for pad_remover(efficiency) and to mask out padding in convoltutional
       layers.
+    save_weights_to: an optional dictionary to capture attention weights
+      for visualization; the weights tensor will be appended there under
+      a string key created from the variable scope (including name).
 
   Returns:
     y: a Tensors
@@ -551,6 +561,7 @@ def transformer_encoder(encoder_input,
             hparams.num_heads,
             hparams.attention_dropout,
             attention_type=hparams.self_attention_type,
+            save_weights_to=save_weights_to,
             max_relative_position=hparams.max_relative_position)
         x = common_layers.layer_postprocess(x, y, hparams)
       with tf.variable_scope("ffn"):
@@ -571,7 +582,8 @@ def transformer_decoder(decoder_input,
                         hparams,
                         cache=None,
                         name="decoder",
-                        nonpadding=None):
+                        nonpadding=None,
+                        save_weights_to=None):
   """A stack of transformer layers.
 
   Args:
@@ -590,6 +602,9 @@ def transformer_decoder(decoder_input,
       to mask out padding in convoltutional layers. We generally only
       need this mask for "packed" datasets, because for ordinary datasets,
       no padding is ever followed by nonpadding.
+    save_weights_to: an optional dictionary to capture attention weights
+      for visualization; the weights tensor will be appended there under
+      a string key created from the variable scope (including name).
 
   Returns:
     y: a Tensors
@@ -612,6 +627,7 @@ def transformer_decoder(decoder_input,
             hparams.num_heads,
             hparams.attention_dropout,
             attention_type=hparams.self_attention_type,
+            save_weights_to=save_weights_to,
             max_relative_position=hparams.max_relative_position,
             cache=layer_cache)
         x = common_layers.layer_postprocess(x, y, hparams)
@@ -624,7 +640,8 @@ def transformer_decoder(decoder_input,
             hparams.attention_key_channels or hparams.hidden_size,
             hparams.attention_value_channels or hparams.hidden_size,
             hparams.hidden_size, hparams.num_heads,
-            hparams.attention_dropout)
+            hparams.attention_dropout,
+            save_weights_to=save_weights_to)
         x = common_layers.layer_postprocess(x, y, hparams)
       with tf.variable_scope("ffn"):
         y = transformer_ffn_layer(
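The pattern threaded through this diff is an optional dictionary that each attention call writes its weight tensor into, keyed by the enclosing variable scope, so that a Transformer instance can expose the weights as self.attention_weights after a forward pass. Below is a minimal, framework-free sketch of that idea; the softmax/attention helpers and the scope string are illustrative stand-ins, not the tensor2tensor API.

# Minimal, framework-free sketch of the save_weights_to pattern (numpy only;
# function and scope names here are illustrative, not the tensor2tensor API).
import numpy as np


def softmax(x, axis=-1):
  e = np.exp(x - x.max(axis=axis, keepdims=True))
  return e / e.sum(axis=axis, keepdims=True)


def dot_product_attention(q, k, v, scope_name, save_weights_to=None):
  """Single-head attention that optionally records its weight matrix."""
  weights = softmax(q @ k.T / np.sqrt(q.shape[-1]))
  if save_weights_to is not None:
    # Mirrors the diff: store the weights under a string key derived from
    # the enclosing scope so they can be visualized later.
    save_weights_to[scope_name] = weights
  return weights @ v


attention_weights = dict()  # plays the role of Transformer.attention_weights
q = k = v = np.random.rand(4, 8)
_ = dot_product_attention(q, k, v, "encoder/layer_0/self_attention",
                          save_weights_to=attention_weights)
print(list(attention_weights))  # ['encoder/layer_0/self_attention']

With the change above, the intended tensor2tensor usage appears to be analogous: run the model once, then read model.attention_weights, whose keys are derived from the encoder/decoder variable scopes (including the name arguments), as described in the new save_weights_to docstrings.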