
Commit 7febb75

BorisPower, madeleineth, and hallacy authored
Boris/examples and cli (#32)
* Add a codex backtranslation example to improve SQL queries (#58)
* Boris update ft example (#57): update fine-tune example to show the new CLI outputs
* model specification for search (#60)
* Catch chunked encoding errors and retry (#63)
* Add batch suggestion logic to prepare_data for fine_tunes and custom Q&A answers logic (#62): add an example of how to create a rudimentary answers endpoint with a custom Q&A model

Co-authored-by: Madeleine Thompson <[email protected]>
Co-authored-by: hallacy <[email protected]>
1 parent c79fefc commit 7febb75

7 files changed (+465, -94 lines)

examples/codex/backtranslation.py

+187
@@ -0,0 +1,187 @@
import openai
from smokey import Smokey
from typing import List, Tuple, Union


def get_candidates(
    prompt: str,
    stop: List[str],
    temperature: float,
    priming_prefix: str,
    engine: str,
    n: int = 5,
) -> List[str]:
    """
    Generate N candidate completions based on the prompt, generated with a specific temperature.

    :param prompt: The prompt to start the conversation with.
    :param stop: A list of tokens that indicate the end of the generation.
    :param temperature: The temperature of the generation.
    :param priming_prefix: The prefix to use for the priming.
    :param engine: The engine to use for the generation.
    :param n: The number of completions to generate.
    :return: A list of completions.
    """
    response = openai.Completion.create(
        engine=engine,
        prompt=prompt,
        temperature=temperature,
        max_tokens=150,
        top_p=1,
        frequency_penalty=0,
        presence_penalty=0,
        stop=stop,
        n=n,
    )
    responses = [priming_prefix + choice.text for choice in response.choices]
    return responses


def rindex(lst: List, value: str) -> int:
    """
    Return the index of the last occurrence of a value in a list.

    :param lst: The list to search in.
    :param value: The value to search for.
    :return: The index of the last occurrence of the value.
    """
    try:
        return len(lst) - lst[::-1].index(value) - 1
    except ValueError:
        raise ValueError(f"Answer start token `{value}` not found in the eval template")


def eval_candidate(
    candidate_answer: str,
    original_instruction: str,
    eval_template: str,
    answer_start_token: str,
    engine: str,
) -> float:
    """
    Evaluate a candidate answer by calculating the average log probability
    of the original instruction, given the candidate answer with a specific
    evaluation template, aimed at reconstructing the original instruction.

    :param candidate_answer: The candidate answer to evaluate.
    :param original_instruction: The original instruction.
    :param eval_template: The template to use for the evaluation.
    :param answer_start_token: The token to use to indicate the start of the answer.
    :param engine: The engine to use for the evaluation.
    :return: The evaluation of the candidate answer.
    """
    response = openai.Completion.create(
        engine=engine,
        prompt=eval_template.format(candidate_answer, original_instruction),
        temperature=0,
        max_tokens=0,
        top_p=1,
        frequency_penalty=0,
        presence_penalty=0,
        logprobs=1,
        echo=True,
    )

    answer_start = rindex(
        response["choices"][0]["logprobs"]["tokens"], answer_start_token
    )
    logprobs = response["choices"][0]["logprobs"]["token_logprobs"][answer_start + 1 :]
    return sum(logprobs) / len(logprobs)


def backtranslation(
    prompt_template: str,
    additional_info: str,
    instruction: str,
    eval_template: str,
    priming_prefix: str = "SELECT",
    stop1: List[str] = ["#", ";"],
    answer_start_token: str = "--",
    n: int = 5,
    temperature: float = 0.5,
    return_all_results: bool = False,
    engine: str = "davinci-codex",
) -> Union[str, List[Tuple[str, float]]]:
    """
    Generate a number of SQL queries given a natural language instruction,
    and pick the best one based on the average log probability of explaining the
    candidate SQL query with the exact original instruction, when prompted for
    a natural language explanation of the candidate SQL query.

    :param prompt_template: The template to use for the prompt to generate SQL.
    :param additional_info: Additional information to include in the prompt
        (SQL Tables, and their properties).
    :param instruction: The instruction in natural language.
    :param eval_template: The template to use for the evaluation.
    :param priming_prefix: The prefix to use for the priming of the SQL query.
    :param stop1: A list of tokens that indicate the end of the generation.
    :param answer_start_token: The token to use to indicate the start of the
        natural answer.
    :param n: The number of candidates to generate.
    :param temperature: The temperature of the generation.
    :param return_all_results: Whether to return all results or just the best one.
    :param engine: The engine to use for the generation and evaluation.
    :return: The best SQL query, or a list of all scored generated SQL queries.
    """
    prompt_template = prompt_template.format(
        additional_info, instruction, priming_prefix
    )

    candidates = []
    responses = get_candidates(
        prompt_template, stop1, temperature, priming_prefix, engine=engine, n=n
    )
    for i in range(n):
        quality = eval_candidate(
            responses[i],
            instruction,
            eval_template,
            answer_start_token,
            engine=engine,
        )
        candidates.append((responses[i], quality))

    candidates.sort(key=lambda x: x[1], reverse=True)
    if return_all_results:
        return candidates
    return candidates[0][0]


def main(
    nl_query: str = "Return the name of each department that had more than 10 employees in June 2021",
    eval_template: str = "{};\n-- Explanation of the above query in human readable format\n-- {}",
    table_definitions: str = "# Employee(id, name, department_id)\n# Department(id, name, address)\n# Salary_Payments(id, employee_id, amount, date)\n",
    prompt_template: str = "### Postgres SQL tables, with their properties:\n#\n{}#\n### {}\n{}",
    n: int = 3,
    temperature: float = 0.3,
    engine: str = "davinci-codex",
):
    """
    Generate a number of SQL queries given a natural language instruction,
    and pick the best one based on the highest backtranslation score.

    :param nl_query: The natural language query.
    :param eval_template: The template to use for the evaluation.
    :param table_definitions: The definitions of the tables used in the query.
    :param prompt_template: The template to use for the prompt to generate SQL.
    :param n: The number of candidates to generate.
    :param temperature: The temperature of the generation.
    :param engine: The engine to use for the generation and evaluation.
    :return: The best SQL query, or a list of all scored generated SQL queries.
    """

    result = backtranslation(
        prompt_template,
        table_definitions,
        nl_query,
        eval_template,
        priming_prefix="SELECT",
        temperature=temperature,
        n=n,
        engine=engine,
    )
    print(result)


if __name__ == "__main__":
    Smokey(main)
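A minimal usage sketch, not part of the diff: assuming an API key is configured and the davinci-codex engine is available, the example can also be driven directly from Python instead of through the Smokey CLI wrapper. All argument values below are illustrative.

    # Hypothetical direct call (illustrative values, not from the commit).
    from backtranslation import backtranslation

    scored_queries = backtranslation(
        prompt_template="### Postgres SQL tables, with their properties:\n#\n{}#\n### {}\n{}",
        additional_info="# Employee(id, name, department_id)\n# Department(id, name, address)\n",
        instruction="Return the name of each department",
        eval_template="{};\n-- Explanation of the above query in human readable format\n-- {}",
        n=3,
        return_all_results=True,  # return every candidate with its backtranslation score
    )
    for query, score in scored_queries:  # sorted best-first
        print(f"{score:.3f}  {query}")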
+142
@@ -0,0 +1,142 @@
import openai
import argparse


def create_context(
    question, search_file_id, max_len=1800, search_model="ada", max_rerank=10
):
    """
    Create a context for a question by finding the most similar context from the search file.
    :param question: The question
    :param search_file_id: The file id of the search file
    :param max_len: The maximum length of the returned context (in tokens)
    :param search_model: The search model to use
    :param max_rerank: The maximum number of documents to rerank
    :return: The context
    """
    results = openai.Engine(search_model).search(
        search_model=search_model,
        query=question,
        max_rerank=max_rerank,
        file=search_file_id,
        return_metadata=True,
    )
    returns = []
    cur_len = 0
    for result in results["data"]:
        cur_len += int(result["metadata"]) + 4
        if cur_len > max_len:
            break
        returns.append(result["text"])
    return "\n\n###\n\n".join(returns)


def answer_question(
    search_file_id="<SEARCH_FILE_ID>",
    fine_tuned_qa_model="<FT_QA_MODEL_ID>",
    question="Which country won the European Football championship in 2021?",
    max_len=1800,
    search_model="ada",
    max_rerank=10,
    debug=False,
    stop_sequence=["\n", "."],
    max_tokens=100,
):
    """
    Answer a question based on the most similar context from the search file, using your fine-tuned model.
    :param question: The question
    :param fine_tuned_qa_model: The fine-tuned QA model
    :param search_file_id: The file id of the search file
    :param max_len: The maximum length of the returned context (in tokens)
    :param search_model: The search model to use
    :param max_rerank: The maximum number of documents to rerank
    :param debug: Whether to output debug information
    :param stop_sequence: The stop sequence for the Q&A model
    :param max_tokens: The maximum number of tokens to return
    :return: The answer
    """
    context = create_context(
        question,
        search_file_id,
        max_len=max_len,
        search_model=search_model,
        max_rerank=max_rerank,
    )
    if debug:
        print("Context:\n" + context)
        print("\n\n")
    try:
        response = openai.Completion.create(
            model=fine_tuned_qa_model,
            prompt=f"Answer the question based on the context below\n\nText: {context}\n\n---\n\nQuestion: {question}\nAnswer:",
            temperature=0,
            max_tokens=max_tokens,
            top_p=1,
            frequency_penalty=0,
            presence_penalty=0,
            stop=stop_sequence,
        )
        return response["choices"][0]["text"]
    except Exception as e:
        print(e)
        return ""


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Rudimentary functionality of the answers endpoint with a fine-tuned Q&A model.",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    parser.add_argument(
        "--search_file_id", help="Search file id", required=True, type=str
    )
    parser.add_argument(
        "--fine_tuned_qa_model", help="Fine-tuned QA model id", required=True, type=str
    )
    parser.add_argument(
        "--question", help="Question to answer", required=True, type=str
    )
    parser.add_argument(
        "--max_len",
        help="Maximum length of the returned context (in tokens)",
        default=1800,
        type=int,
    )
    parser.add_argument(
        "--search_model", help="Search model to use", default="ada", type=str
    )
    parser.add_argument(
        "--max_rerank",
        help="Maximum number of documents to rerank for the search",
        default=10,
        type=int,
    )
    parser.add_argument(
        "--debug", help="Print debug information (context used)", action="store_true"
    )
    parser.add_argument(
        "--stop_sequence",
        help="Stop sequences for the Q&A model",
        default=["\n", "."],
        nargs="+",
        type=str,
    )
    parser.add_argument(
        "--max_tokens",
        help="Maximum number of tokens to return",
        default=100,
        type=int,
    )
    args = parser.parse_args()
    response = answer_question(
        search_file_id=args.search_file_id,
        fine_tuned_qa_model=args.fine_tuned_qa_model,
        question=args.question,
        max_len=args.max_len,
        search_model=args.search_model,
        max_rerank=args.max_rerank,
        debug=args.debug,
        stop_sequence=args.stop_sequence,
        max_tokens=args.max_tokens,
    )
    print(f"Answer:{response}")
