introduce stop condition to beam search
bastiscode committed Dec 9, 2024
1 parent 4686635 commit 856706e
Showing 2 changed files with 78 additions and 45 deletions.
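
In short: beam_search loses its max_outputs argument and gains a stop_condition parameter that controls when decoding for a batch element halts. Valid values, per the added assertion, are "max score" (the default), "estimated score" and "max outputs". A standalone sketch of the new argument handling, mirroring the lines added in the first hunks below (resolve_stop_condition is a made-up name for illustration):

    # Mirrors the validation this commit adds at the top of beam_search.
    def resolve_stop_condition(stop_condition: str | None = None) -> str:
        if stop_condition is None:
            stop_condition = "max score"
        assert stop_condition in {
            "max score",
            "estimated score",
            "max outputs",
        }, "stop condition must be 'max score', 'estimated score' or 'max outputs'"
        return stop_condition

    print(resolve_stop_condition())  # -> max score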
98 changes: 65 additions & 33 deletions python/text_utils/inference/__init__.py
@@ -34,7 +34,7 @@ def beam_search(
     logit_fns: list[LogitFn] | None = None,
     kwargs_select_fn: MaskSelectFn | None = None,
     kwargs_update_fn: MaskUpdateFn | None = None,
-    max_outputs: int | list[int] | None = None,
+    stop_condition: str | None = None,
     max_new_tokens: int | None = None,
     return_incomplete: bool = False,
     yield_intermediate: bool = False,
@@ -43,20 +43,20 @@
     assert (
         max_new_tokens is None or max_new_tokens > 0
     ), "max_new_tokens must be None or positive"
+    if stop_condition is None:
+        stop_condition = "max score"
+    assert stop_condition in {
+        "max score",
+        "estimated score",
+        "max outputs",
+    }, "stop condition must be 'max score', 'estimated score' or 'max outputs'"
     batch_size = len(initial)

     decoder_info: Any | None = None
     update_info: list[int] = []
     current_beams: list[list[Beam]] = []
-    beam_queues: list[list[Beam]] = []
-    if max_outputs is None:
-        max_outputs = [beam_width] * batch_size
-    elif isinstance(max_outputs, int):
-        max_outputs = [max_outputs] * batch_size
-    else:
-        assert (
-            len(max_outputs) == batch_size
-        ), "max_outputs must be None, int or list of length batch_size"
+    finished_beams: list[list[Beam]] = []
+    too_long_beams: list[list[Beam]] = []

     for init in initial:
         if isinstance(init, Beam):
Expand All @@ -72,50 +72,82 @@ def beam_search(

current_beams.append(beams) # type: ignore
update_info.append(len(beams))
beam_queues.append([])

def should_stop(beam: Beam) -> bool:
return (
stop_fn(beam)
or len(beam) >= max_length
or beam.decoded_length >= (max_new_tokens or (beam.decoded_length + 1))
)
finished_beams.append([])
too_long_beams.append([])

def filter_beams() -> bool:
finished = True
for idx in range(batch_size):
new_beams = []
for beam in current_beams[idx]:
if should_stop(beam):
beam_queues[idx].append(beam)
if stop_fn(beam):
finished_beams[idx].append(beam)
elif len(beam) >= max_length or beam.decoded_length >= (
max_new_tokens or (beam.decoded_length + 1)
):
too_long_beams[idx].append(beam)
else:
new_beams.append(beam)

current_beams[idx] = new_beams
finished = finished and (
len(current_beams[idx]) == 0
or len(beam_queues[idx]) >= max_outputs[idx]
if not current_beams[idx]:
# we are done with this batch element
continue

elif len(finished_beams[idx]) < beam_width:
finished = False
continue

elif stop_condition == "max outputs":
# we are done with this batch element
# because we have enough finished beams
current_beams[idx] = []
continue

worst_finished = min(
(score_fn(b) for b in finished_beams[idx]), default=float("-inf")
)
if stop_condition == "estimated score":
# best current calculated from current length
# idea: is a current active beam better than the worst finished beam?
best_current = max(score_fn(b) for b in current_beams[idx])
else:
# best current calculated from maximum length
# idea: assume all remaining tokens are perfectly predicted
# with probability 1.0, can a current active beam be better
# than the worst finished beam?
current = current_beams[idx][0]
max_decoded_length = max_length - current.initial_length
length = min(max_decoded_length, max_new_tokens or max_decoded_length)
best_current = max(score_fn(b, length) for b in current_beams[idx])

if worst_finished >= best_current:
# set current beams to empty list to stop processing
current_beams[idx] = []
else:
finished = False

return finished

def get_outputs(intermediate: bool) -> list[list[Beam]]:
outputs = []
for idx in range(batch_size):
beam_queue = beam_queues[idx]
current = current_beams[idx]

if return_incomplete:
finished = finished_beams[idx] + too_long_beams[idx]
else:
finished = finished_beams[idx]

if intermediate:
# for intermediate outputs we
# return the active beams, so swap here
beam_queue, current = current, beam_queue
n = beam_width
# return the active beams first if available
beams = current if current else finished
else:
n = max_outputs[idx]

beam_queue = sorted(beam_queue, key=lambda b: score_fn(b), reverse=True)
if len(beam_queue) == 0 and (return_incomplete or intermediate):
beam_queue = sorted(current, key=lambda b: score_fn(b), reverse=True)
beams = finished

outputs.append(beam_queue[:n])
beams = sorted(beams, key=score_fn, reverse=True)
outputs.append(beams[:beam_width])

return outputs

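The heart of the new filter_beams logic is a bound check per batch element: once even the best still-active beam cannot beat the worst of the beam_width finished beams, the element stops decoding. A self-contained sketch of that decision rule (an illustrative re-implementation with plain float scores, not the library API):

    def can_stop(
        active_scores: list[float],
        finished_scores: list[float],
        beam_width: int,
        stop_condition: str = "max score",
    ) -> bool:
        # Mirrors the per-element stop check in filter_beams.
        if not active_scores:
            return True  # nothing left to expand
        if len(finished_scores) < beam_width:
            return False  # not enough finished beams yet
        if stop_condition == "max outputs":
            return True  # enough finished beams, stop right away
        # For "max score", active_scores are assumed to be upper bounds,
        # i.e. scores computed as if all remaining tokens had probability 1.0;
        # for "estimated score", they are scores at the current length.
        return min(finished_scores) >= max(active_scores)

    # Worst finished beam scores -0.4; the best active beam can at most
    # reach -0.7, so this element is done:
    assert can_stop([-0.7, -1.2], [-0.2, -0.4], beam_width=2)

Why "max score" is a true upper bound: with length-normalized log-likelihood, appending tokens with probability 1.0 leaves log_prob unchanged while the length in the denominator grows, so scoring an active beam at the maximum reachable length is the best it can ever do. "estimated score" compares at the current length instead, which stops earlier but may prune a beam that would still have improved.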
25 changes: 13 additions & 12 deletions python/text_utils/inference/utils.py
@@ -62,6 +62,10 @@ def decoded_token_ids(self) -> list[int]:
     def decoded_log_probs(self) -> list[float]:
         return self.log_probs[self.initial_length :]

+    @property
+    def length(self) -> int:
+        return len(self)
+
     @property
     def log_prob(self) -> float:
         return sum(self.log_probs)
@@ -135,22 +139,19 @@ def __repr__(self) -> str:
 ]


-# takes in a beam and returns a scalar score
-ScoreFn = Callable[[Beam], float]
+# takes in a beam (and optional length) and returns a scalar score
+ScoreFn = Callable[[Beam, int | None], float]


-def log_likelihood_score(
-    normalize_by_length: bool = True, alpha: float = 1.0, full: bool = False
-) -> ScoreFn:
-    def _score(beam: Beam) -> float:
-        if full:
-            log_prob = beam.log_prob
-            length = len(beam)
-        else:
-            log_prob = beam.decoded_log_prob
+def log_likelihood_score(normalize: bool = True, alpha: float = 1.0) -> ScoreFn:
+    assert alpha >= 0.0, "alpha must be positive"
+
+    def _score(beam: Beam, length: int | None = None) -> float:
+        log_prob = beam.decoded_log_prob
+        if length is None:
             length = beam.decoded_length

-        if normalize_by_length and length > 0:
+        if normalize and length > 0:
             return log_prob / (length**alpha)
         else:
             return log_prob
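The widened ScoreFn signature is what lets the "max score" condition evaluate a beam at a hypothetical length. A standalone numeric sketch of the normalized score (mirrors _score above, with made-up numbers):

    def normalized_score(log_prob: float, length: int, alpha: float = 1.0) -> float:
        # Length-normalized log-likelihood, as in log_likelihood_score.
        return log_prob / (length**alpha) if length > 0 else log_prob

    log_prob = -6.0
    print(normalized_score(log_prob, 10))  # -0.6: score at the current length
    # Upper bound used by "max score": same log_prob at the maximum possible
    # length, as if the remaining 10 tokens were predicted with probability 1.0.
    print(normalized_score(log_prob, 20))  # -0.3: best this beam can reach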
