Skip to content

Commit ac5d522

Browse files
committed
nucleus sampling
1 parent f35fa1d commit ac5d522

File tree

3 files changed

+24
-3
lines changed

3 files changed

+24
-3
lines changed

Diff for: src/generate_unconditional_samples.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ def sample_model(
1616
length=None,
1717
temperature=1,
1818
top_k=0,
19+
top_p=1,
1920
models_dir='models',
2021
):
2122
"""
@@ -58,7 +59,7 @@ def sample_model(
5859
hparams=hparams, length=length,
5960
start_token=enc.encoder['<|endoftext|>'],
6061
batch_size=batch_size,
61-
temperature=temperature, top_k=top_k
62+
temperature=temperature, top_k=top_k, top_p=top_p
6263
)[:, 1:]
6364

6465
saver = tf.train.Saver()

Diff for: src/interactive_conditional_samples.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ def interact_model(
1616
length=None,
1717
temperature=1,
1818
top_k=0,
19+
top_p=1,
1920
models_dir='models',
2021
):
2122
"""
@@ -61,7 +62,7 @@ def interact_model(
6162
hparams=hparams, length=length,
6263
context=context,
6364
batch_size=batch_size,
64-
temperature=temperature, top_k=top_k
65+
temperature=temperature, top_k=top_k, top_p=top_p
6566
)
6667

6768
saver = tf.train.Saver()

Diff for: src/sample.py

+20-1
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,25 @@ def _top_k():
2222
)
2323

2424

25-
def sample_sequence(*, hparams, length, start_token=None, batch_size=None, context=None, temperature=1, top_k=0):
25+
def top_p_logits(logits, p):
26+
"""Nucleus sampling"""
27+
batch, _ = logits.shape.as_list()
28+
sorted_logits = tf.sort(logits, direction='DESCENDING', axis=-1)
29+
cumulative_probs = tf.cumsum(tf.nn.softmax(sorted_logits, axis=-1), axis=-1)
30+
indices = tf.stack([
31+
tf.range(0, batch),
32+
# number of indices to include
33+
tf.maximum(tf.reduce_sum(tf.cast(cumulative_probs <= p, tf.int32), axis=-1) - 1, 0),
34+
], axis=-1)
35+
min_values = tf.gather_nd(sorted_logits, indices)
36+
return tf.where(
37+
logits < min_values,
38+
tf.ones_like(logits) * -1e10,
39+
logits,
40+
)
41+
42+
43+
def sample_sequence(*, hparams, length, start_token=None, batch_size=None, context=None, temperature=1, top_k=0, top_p=1):
2644
if start_token is None:
2745
assert context is not None, 'Specify exactly one of start_token and context!'
2846
else:
@@ -45,6 +63,7 @@ def body(past, prev, output):
4563
next_outputs = step(hparams, prev, past=past)
4664
logits = next_outputs['logits'][:, -1, :] / tf.to_float(temperature)
4765
logits = top_k_logits(logits, k=top_k)
66+
logits = top_p_logits(logits, p=top_p)
4867
samples = tf.multinomial(logits, num_samples=1, output_dtype=tf.int32)
4968
return [
5069
next_outputs['presents'] if past is None else tf.concat([past, next_outputs['presents']], axis=-2),

0 commit comments

Comments
 (0)