Skip to content

Commit

Permalink
Use decision transformer prompt
Browse files Browse the repository at this point in the history
  • Loading branch information
evhub committed Apr 29, 2022
1 parent 907f20f commit 3031d18
Show file tree
Hide file tree
Showing 10 changed files with 568 additions and 527 deletions.
105 changes: 61 additions & 44 deletions bbopt-source/backends/openai.coco
Original file line number Diff line number Diff line change
Expand Up @@ -3,19 +3,20 @@ The OpenAI backend. Uses large language models for black box optimization.
"""

import os
import random
from ast import literal_eval

import openai

from bbopt import constants
from bbopt.util import printerr
from bbopt.util import printerr, stdev
from bbopt.params import param_processor
from bbopt.backends.util import StandardBackend


# Utilities:

def get_prompt(params, data_points, losses) =
def get_prompt(params, data_points, losses, hoped_for_loss) =
"""Get the OpenAI API prompt to use."""
'''# black box function to be minimized
def f({func_params}) -> float:
Expand All @@ -28,9 +29,10 @@ def f({func_params}) -> float:
"""
return black_box_loss({names})

# optimization history (MUST stay within the bounds, SHOULD fully explore the bounds, SHOULD converge to minimum)
# experimentally observed data
# new experiments MUST stay within the bounds, SHOULD fully explore the bounds, and SHOULD converge to minimum
# bounds: f({bounds})
{values}f('''.format(
{values}{hoped_for_loss} == f('''.format(
func_params=", ".join(
"{name}: {type}".format(
name=name,
Expand Down Expand Up @@ -61,23 +63,16 @@ def f({func_params}) -> float:
for name, (func, args, _) in params.items()
),
values="".join(
"f({args}) == {loss}\n".format(
"{loss} == f({args})\n".format(
args=", ".join(params |> map$(point[]) |> map$(repr)),
loss=loss,
)
for point, loss in zip(data_points, losses)
),
hoped_for_loss=hoped_for_loss,
)


def get_completion_len(data_points):
    """Get the maximum number of characters in a completion.

    A completion is the comma-separated repr of one data point's values,
    so the bound is the longest such rendering plus one character.
    """
    longest = max(
        len(", ".join(repr(val) for val in point.values()))
        for point in data_points
    )
    return longest + 1


def to_python(completion, params):
"""Convert a completion to Python code as best as possible."""
completion = completion.strip("(,")
Expand All @@ -87,6 +82,7 @@ def to_python(completion, params):
("\u2019", "'"),
("\u201c", '"'),
("\u201d", '"'),
("\u221e", 'float("inf")'),
) :: (
(f"{name}=", "") for name in params
) :: (
Expand All @@ -96,6 +92,13 @@ def to_python(completion, params):
return completion


def get_loss_eps(min_loss):
    """Get a reasonably-sized expected loss improvement.

    Decomposes |min_loss| into an exact integer numerator/denominator and
    replaces the numerator with a same-digit-length run of 1s, yielding a
    small positive value on roughly the same scale as min_loss.
    """
    numer, denom = float(abs(min_loss)).as_integer_ratio()
    # a number with the same digit count as numer but made entirely of 1s
    ones_numer = int("1" * len(str(numer)))
    return ones_numer / denom


# Backend:

class OpenAIBackend(StandardBackend):
Expand All @@ -110,7 +113,7 @@ class OpenAIBackend(StandardBackend):

max_prompt_len = float("inf")

def setup_backend(self, params, engine=None, temperature=None, max_retries=None, api_key=None, debug=False):
def setup_backend(self, params, engine=None, temperature=None, max_retries=None, api_key=None, debug=True):
self.params = params

self.engine = engine ?? constants.openai_default_engine
Expand All @@ -127,37 +130,25 @@ class OpenAIBackend(StandardBackend):
self.data_points += new_data
self.losses += new_losses

def retry_get_values(self, temp=None, cached_values=None):
    """Recursively retry get_next_values with temporarily-modified settings.

    Consumes one retry from the budget, optionally raises the sampling
    temperature and/or seeds partially-parsed values into the prompt, then
    restores all previous settings once the recursive call finishes.

    Returns the values dict from get_next_values, or {} when the retry
    budget is exhausted (an empty dict makes the caller fall back to the
    random backend).
    """
    if not self.max_retries:
        printerr(f"BBopt Warning: Maximum number of OpenAI API retries exceeded on:\n== PROMPT ==\n{self.get_prompt()}\n== END ==")
        return {} # return empty values so that the fallback random backend will be used instead
    if self.debug:
        if temp is None:
            print(f"RETRYING with: {self.max_prompt_len=}")
        else:
            print(f"RETRYING with new temperature: {self.temp} -> {temp}")
    # consume one retry only for the duration of the recursive call
    old_retries, self.max_retries = self.max_retries, self.max_retries - 1
    if temp is not None:
        old_temp, self.temp = self.temp, temp
    if cached_values is not None:
        if self.debug:
            # show which values were already cached vs. newly added
            print(f"CACHING values: {cached_values[:len(self.cached_values)]} + {cached_values[len(self.cached_values):]}")
        self.cached_values = cached_values
    try:
        return self.get_next_values()
    finally:
        # restore prior state so enclosing retry levels see their own settings
        self.max_retries = old_retries
        if temp is not None:
            self.temp = old_temp
        if cached_values is not None:
            self.cached_values = ()

def get_prompt(self) = (
get_prompt(self.params, self.data_points, self.losses)
get_prompt(
self.params,
self.data_points,
self.losses,
hoped_for_loss=min_loss - random.uniform(0, stdev(self.losses) + get_loss_eps(min_loss)),
)
+ ", ".join(self.cached_values |> map$(repr))
# only "," not ", " since the prompt shouldn't end in a space
+ ("," if self.cached_values else "")
)
) where:
min_loss = min(self.losses)

def get_completion_len(self):
    """Get the maximum number of characters in a completion.

    Renders each stored data point as the comma-separated reprs of its
    per-parameter values and returns the longest rendering plus one.
    """
    longest = max(
        len(", ".join(repr(point[name]) for name in self.params))
        for point in self.data_points
    )
    return longest + 1

def get_next_values(self):
# generate prompt
Expand All @@ -175,7 +166,7 @@ class OpenAIBackend(StandardBackend):
engine=self.engine,
prompt=prompt,
temperature=self.temp,
max_tokens=get_completion_len(self.data_points),
max_tokens=self.get_completion_len(),
)
except openai.error.InvalidRequestError as api_err:
if self.debug:
Expand All @@ -187,7 +178,7 @@ class OpenAIBackend(StandardBackend):
if self.max_prompt_len == float("inf"):
self.max_prompt_len = len(prompt.rsplit("\n", 1)[0])
else:
self.max_prompt_len -= get_completion_len(self.data_points)
self.max_prompt_len -= self.get_completion_len()
return self.retry_get_values()

# parse response
Expand Down Expand Up @@ -216,7 +207,7 @@ class OpenAIBackend(StandardBackend):
)
if len(legal_values) < len(self.params):
if self.debug:
if len(valvec) < len(self.params):
if len(valvec) < len(self.params) - len(self.cached_values):
print(f"ERROR: insufficient values (got {len(valvec)}; expected {len(self.params) - len(self.cached_values)})")
else:
print(f"ERROR: got illegal values: {valvec!r}")
Expand All @@ -230,6 +221,32 @@ class OpenAIBackend(StandardBackend):
return self.retry_get_values(temp=self.temp + (constants.openai_max_temp - self.temp) / 2)
return values

def retry_get_values(self, temp=None, cached_values=None):
    """Used in get_next_values to keep track of recursive calls.

    Temporarily decrements the retry budget, optionally overriding the
    sampling temperature and/or seeding partially-parsed values into the
    prompt, then restores everything after the recursive call.

    Returns the values dict from get_next_values, or {} once the retry
    budget is exhausted (so the fallback random backend is used instead).
    """
    if not self.max_retries:
        printerr(f"BBopt Warning: Maximum number of OpenAI API retries exceeded on:\n== PROMPT ==\n{self.get_prompt()}\n== END ==")
        return {} # return empty values so that the fallback random backend will be used instead
    if self.debug:
        if temp is None:
            print(f"RETRYING with: {self.max_prompt_len=}")
        else:
            print(f"RETRYING with new temperature: {self.temp} -> {temp}")
    # consume one retry only for the duration of the recursive call
    old_retries, self.max_retries = self.max_retries, self.max_retries - 1
    if temp is not None:
        old_temp, self.temp = self.temp, temp
    if cached_values is not None:
        if self.debug:
            # show which values were already cached vs. newly added
            print(f"CACHING values: {cached_values[:len(self.cached_values)]} + {cached_values[len(self.cached_values):]}")
        self.cached_values = cached_values
    try:
        return self.get_next_values()
    finally:
        # restore prior state so enclosing retry levels see their own settings
        self.max_retries = old_retries
        if temp is not None:
            self.temp = old_temp
        if cached_values is not None:
            self.cached_values = ()


# Registered names:

Expand Down
10 changes: 7 additions & 3 deletions bbopt-source/benchmarking.coco
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ OPT_FUNCS.append(cond_gain_func)

# Main

def benchmark(algs, plot_func="plot_convergence", n=25):
def benchmark(algs, plot_func="plot_convergence", n=10):
figsize = len(OPT_FUNCS) |> math.sqrt |> math.ceil |> int
fig, axs = plt.subplots(figsize, figsize)
for i, func in enumerate(OPT_FUNCS):
Expand All @@ -98,9 +98,12 @@ def benchmark(algs, plot_func="plot_convergence", n=25):
bb = BlackBoxOptimizer(__file__, tag=f"{func.__name__}_{alg}")
if bb.num_examples < n:
for _ in range(n - bb.num_examples):
bb.run(alg)
if isinstance(alg, tuple):
bb.run_meta(alg)
else:
bb.run(alg)
func(bb)
getattr(bb, plot_func)(ax, label=alg)
getattr(bb, plot_func)(ax, label=str(alg))
ax.set_title(func.__name__)
ax.set_xlabel("")
ax.legend()
Expand All @@ -112,4 +115,5 @@ if __name__ == "__main__":
"tpe_or_gp",
"tree_structured_parzen_estimator",
"safe_gaussian_process",
("openai", "safe_gaussian_process"),
))
2 changes: 1 addition & 1 deletion bbopt-source/constants.coco
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ safe_fallback_alg = "tree_structured_parzen_estimator"
# OpenAI constants:
openai_default_engine = "text-curie-001"

openai_default_temp = 1.1
openai_default_temp = 1
openai_max_temp = 2

openai_default_max_retries = 10
Expand Down
7 changes: 1 addition & 6 deletions bbopt-source/tests/examples_test.coco
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ from coconut.command.util import call_output
from bbopt.util import (
mean,
median,
stdev,
)


Expand Down Expand Up @@ -128,12 +129,6 @@ def middle_mean(xs) =
mean(xs[a:b])


def stdev(xs) =
    """Population standard deviation of xs.

    Materializes xs up front: the data is iterated twice (once for the
    mean, once for the squared deviations), which would silently fail on
    a one-pass iterator/generator input.
    """
    mean((x - mu)**2 for x in xs)**0.5 where:
        xs = tuple(xs)  # must happen before mean(xs) consumes an iterator
        mu = mean(xs)


def assert_improving(data, ave_func=mean, within_stdevs=0.5):
"""Assert that the second half of data is greater/smaller than the first."""
examples = \data["examples"]
Expand Down
7 changes: 7 additions & 0 deletions bbopt-source/util.coco
Original file line number Diff line number Diff line change
Expand Up @@ -313,3 +313,10 @@ def median(xs):
sorted_xs[len(sorted_xs)//2],
sorted_xs[(len(sorted_xs) + 1)//2],
))


def stdev(xs) =
    """Population standard deviation of xs."""
    mean((x - mu)**2 for x in xs)**0.5 where:
        # materialize FIRST: where-statements run in order, so computing
        # mean(xs) before tuple(xs) would exhaust a one-pass iterator and
        # leave xs empty for the second traversal
        xs = tuple(xs)
        mu = mean(xs)
Loading

0 comments on commit 3031d18

Please sign in to comment.