@@ -271,10 +271,14 @@ def attention_image_summary(attn, image_shapes=None):
   Args:
     attn: a Tensor with shape [batch, num_heads, query_length, memory_length]
-    image_shapes: optional quadruple of integer scalars.
+    image_shapes: optional tuple of integer scalars.
       If the query positions and memory positions represent the
-      pixels of a flattened image, then pass in their dimensions:
+      pixels of flattened images, then pass in their dimensions:
         (query_rows, query_cols, memory_rows, memory_cols).
+      If the query positions and memory positions represent the
+      pixels x channels of flattened images, then pass in their dimensions:
+        (query_rows, query_cols, query_channels,
+         memory_rows, memory_cols, memory_channels).
   """
   num_heads = attn.get_shape().as_list()[1]
   # [batch, query_length, memory_length, num_heads]
@@ -286,10 +290,20 @@ def attention_image_summary(attn, image_shapes=None):
   image = split_last_dimension(image, 3)
   image = tf.reduce_max(image, 4)
   if image_shapes is not None:
-    q_rows, q_cols, m_rows, m_cols = list(image_shapes)
-    image = tf.reshape(image, [-1, q_rows, q_cols, m_rows, m_cols, 3])
-    image = tf.transpose(image, [0, 1, 3, 2, 4, 5])
-    image = tf.reshape(image, [-1, q_rows * m_rows, q_cols * m_cols, 3])
+    if len(image_shapes) == 4:
+      q_rows, q_cols, m_rows, m_cols = list(image_shapes)
+      image = tf.reshape(image, [-1, q_rows, q_cols, m_rows, m_cols, 3])
+      image = tf.transpose(image, [0, 1, 3, 2, 4, 5])
+      image = tf.reshape(image, [-1, q_rows * m_rows, q_cols * m_cols, 3])
+    else:
+      assert len(image_shapes) == 6
+      q_rows, q_cols, q_channels, m_rows, m_cols, m_channels = list(
+          image_shapes)
+      image = tf.reshape(image, [-1, q_rows, q_cols, q_channels,
+                                 m_rows, m_cols, m_channels, 3])
+      image = tf.transpose(image, [0, 1, 4, 3, 2, 5, 6, 7])
+      image = tf.reshape(image, [-1, q_rows * m_rows * q_channels,
+                                 q_cols * m_cols * m_channels, 3])
   tf.summary.image("attention", image, max_outputs=1)
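Note: the 6-tuple branch above interleaves the query and memory grids (and now
their channel axes) so the flattened attention matrix renders as one tiled
image. A minimal NumPy sketch of the same shape arithmetic, with made-up
dimensions (not taken from this diff):

import numpy as np

q_rows, q_cols, q_channels = 4, 4, 3   # hypothetical query image dims
m_rows, m_cols, m_channels = 5, 5, 3   # hypothetical memory image dims
batch = 2
image = np.zeros([batch, q_rows * q_cols * q_channels,
                  m_rows * m_cols * m_channels, 3])
# Same reshape -> transpose -> reshape sequence as the 6-tuple branch.
image = image.reshape([-1, q_rows, q_cols, q_channels,
                       m_rows, m_cols, m_channels, 3])
image = image.transpose([0, 1, 4, 3, 2, 5, 6, 7])
image = image.reshape([-1, q_rows * m_rows * q_channels,
                       q_cols * m_cols * m_channels, 3])
assert image.shape == (batch, q_rows * m_rows * q_channels,
                       q_cols * m_cols * m_channels, 3)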
@@ -310,10 +324,8 @@ def dot_product_attention(q,
     bias: bias Tensor (see attention_bias())
     dropout_rate: a floating point number
     summaries: a boolean
-    image_shapes: optional quadruple of integer scalars for image summary.
-      If the query positions and memory positions represent the
-      pixels of a flattened image, then pass in their dimensions:
-        (query_rows, query_cols, memory_rows, memory_cols).
+    image_shapes: optional tuple of integer scalars.
+      see comments for attention_image_summary()
     name: an optional string

   Returns:
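A hedged usage sketch for the new image_shapes argument (TF 1.x style; the
tensor shapes and bias=None are assumptions for illustration, not part of
this diff):

import tensorflow as tf

# 64 query/memory positions, viewed by the summary as flattened 8x8 images.
batch, heads, length, depth = 2, 4, 64, 32
q = tf.random_normal([batch, heads, length, depth])
k = tf.random_normal([batch, heads, length, depth])
v = tf.random_normal([batch, heads, length, depth])
y = dot_product_attention(q, k, v, bias=None, dropout_rate=0.0,
                          summaries=True, image_shapes=(8, 8, 8, 8))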
@@ -356,10 +368,8 @@ def multihead_attention(query_antecedent,
     num_heads: an integer dividing total_key_depth and total_value_depth
     dropout_rate: a floating point number
     summaries: a boolean
-    image_shapes: optional quadruple of integer scalars for image summary.
-      If the query positions and memory positions represent the
-      pixels of a flattened image, then pass in their dimensions:
-        (query_rows, query_cols, memory_rows, memory_cols).
+    image_shapes: optional tuple of integer scalars.
+      see comments for attention_image_summary()
     name: an optional string

   Returns:
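The same argument threads through multihead_attention; a hedged self-attention
sketch (passing memory_antecedent=None for self-attention, and these exact
keyword names, are assumptions based on the surrounding code):

import tensorflow as tf

x = tf.random_normal([2, 64, 512])  # [batch, length, channels]; 64 = 8*8
y = multihead_attention(x, None, bias=None,
                        total_key_depth=512, total_value_depth=512,
                        output_depth=512, num_heads=8, dropout_rate=0.0,
                        summaries=True, image_shapes=(8, 8, 8, 8))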
@@ -398,3 +408,72 @@ def multihead_attention(query_antecedent,
   x = combine_heads(x)
   x = common_layers.conv1d(x, output_depth, 1, name="output_transform")
   return x
+
+
+def parameter_attention(x,
+                        total_key_depth,
+                        total_value_depth,
+                        output_depth,
+                        memory_rows,
+                        num_heads,
+                        dropout_rate,
+                        name=None):
+  """Attention over parameters.
+
+  We use the same multi-headed attention as in the other layers, but the
+  memory keys and values are model parameters. There is no linear
+  transformation on the keys or values.
+
+  We are also a bit more careful about memory usage, since the number of
+  memory positions may be very large.
+
+  Args:
+    x: a Tensor with shape [batch, length_q, channels]
+    total_key_depth: an integer
+    total_value_depth: an integer
+    output_depth: an integer
+    memory_rows: an integer
+    num_heads: an integer dividing total_key_depth and total_value_depth
+    dropout_rate: a floating point number
+    name: an optional string
+
+  Returns:
+    A Tensor.
+  """
+  with tf.variable_scope(name, default_name="parameter_attention",
+                         values=[x]):
+    head_size_k = total_key_depth // num_heads
+    head_size_v = total_value_depth // num_heads
+    var_shape_k = [num_heads, memory_rows, head_size_k]
+    var_shape_v = [num_heads, memory_rows, head_size_v]
+    k = tf.get_variable(
+        "k", var_shape_k,
+        initializer=tf.random_normal_initializer(
+            0, output_depth ** -0.5)) * (num_heads ** 0.5)
+    v = tf.get_variable(
+        "v", var_shape_v,
+        initializer=tf.random_normal_initializer(
+            0, output_depth ** -0.5)) * (output_depth ** 0.5)
+    batch_size = tf.shape(x)[0]
+    length = tf.shape(x)[1]
+    q = common_layers.conv1d(x, total_key_depth, 1, name="q_transform")
+    if dropout_rate:
+      # This is a cheaper form of attention dropout where we use the same
+      # dropout decisions across batch elements and query positions,
+      # but different decisions across heads and memory positions.
+      v = tf.nn.dropout(v, 1.0 - dropout_rate,
+                        noise_shape=[num_heads, memory_rows, 1])
+    # query is [batch, length, hidden_size]
+    # reshape and transpose it to [heads, batch * length, head_size]
+    q = tf.reshape(q, [batch_size, length, num_heads, head_size_k])
+    q = tf.transpose(q, [2, 0, 1, 3])
+    q = tf.reshape(q, [num_heads, batch_size * length, head_size_k])
+    weights = tf.matmul(q, k, transpose_b=True)
+    weights = tf.nn.softmax(weights)
+    y = tf.matmul(weights, v)
+    y = tf.reshape(y, [num_heads, batch_size, length, head_size_v])
+    y = tf.transpose(y, [1, 2, 0, 3])
+    y = tf.reshape(y, [batch_size, length, total_value_depth])
+    y.set_shape([None, None, total_value_depth])
+    y = common_layers.conv1d(y, output_depth, 1, name="output_transform")
+    return y
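A hedged call sketch for the new function (TF 1.x; all sizes invented for
illustration):

import tensorflow as tf

x = tf.random_normal([2, 10, 512])    # [batch, length_q, channels]
y = parameter_attention(x,
                        total_key_depth=512,
                        total_value_depth=512,
                        output_depth=512,
                        memory_rows=1024,  # learned memory positions
                        num_heads=8,
                        dropout_rate=0.0)
# y: [2, 10, 512]; the keys and values are trainable variables, not inputs.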