-
Notifications
You must be signed in to change notification settings - Fork 4
/
Copy pathtypechat.go
86 lines (69 loc) · 2.34 KB
/
typechat.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
package typechat
import (
"fmt"
"strings"
)
// JsonTranslator turns natural-language requests into values of type T by
// prompting a LanguageModel for JSON and validating the result.
type JsonTranslator[T any] interface {
	// CreateRequestPrompt builds the initial prompt for request.
	CreateRequestPrompt(request string) string

	// CreateRepairPrompt builds a follow-up prompt that reports
	// validationError and asks for a corrected JSON object.
	CreateRepairPrompt(validationError string) string

	// Translate converts request into a validated *T.
	Translate(request string) (*T, error)

	// Validator returns the validator used to check model output.
	Validator() JsonValidator[T]

	// Model returns the language model that is queried.
	Model() LanguageModel
}
// baseJsonTranslator is the JsonTranslator implementation returned by
// NewJsonTranslator.
type baseJsonTranslator[T any] struct {
	// model produces completions for the prompts built by this translator.
	model LanguageModel

	// validator checks extracted JSON text and supplies the schema and type
	// name embedded in the request prompt.
	validator JsonValidator[T]

	// attemptRepair, when true, lets Translate send one repair prompt after
	// a validation failure.
	attemptRepair bool

	// stripNulls is not read by any code in this file.
	// NOTE(review): confirm whether another file in the package consumes it.
	stripNulls bool
}
// NewJsonTranslator returns a JsonTranslator that prompts model for values of
// type T. schema is the Go source describing T and typeName names the target
// struct; both are handed to NewJsonValidator. Repair (attemptRepair) is
// enabled; stripNulls is left at its zero value.
func NewJsonTranslator[T any](model LanguageModel, schema string, typeName string) JsonTranslator[T] {
	translator := &baseJsonTranslator[T]{
		validator: NewJsonValidator[T](schema, typeName),
	}
	translator.model = model
	translator.attemptRepair = true
	return translator
}
// CreateRequestPrompt builds the initial prompt sent to the language model.
// It embeds the validator's target type name and Go schema source in a ```go
// fence, followed by the user's request wrapped in triple quotes, and asks for
// the translated JSON object.
func (t *baseJsonTranslator[T]) CreateRequestPrompt(request string) string {
	typeName := t.validator.GetTypeName()
	schema := t.validator.GetSchema()
	// The closing ``` fence must start on its own line; without this, a schema
	// lacking a trailing newline would have the fence glued to its last line.
	if !strings.HasSuffix(schema, "\n") {
		schema += "\n"
	}
	return fmt.Sprintf("You are a service that translates user requests into JSON objects of struct \"%s\" according to the following Go definitions:\n"+
		"```go\n%s```\n"+
		"The following is a user request:\n"+
		"\"\"\"\n%s\n\"\"\"\n"+
		"The following is the user request translated into a JSON object with 2 spaces of indentation and no properties with the value undefined:\n",
		typeName, schema, request)
}
// CreateRepairPrompt builds the follow-up prompt that quotes validationError
// between triple quotes and asks the model for a revised JSON object.
func (t *baseJsonTranslator[T]) CreateRepairPrompt(validationError string) string {
	const (
		header = "The JSON object is invalid for the following reason:\n"
		fence  = "\"\"\"\n"
		footer = "The following is a revised JSON object:\n"
	)
	return header + fence + validationError + "\n" + fence + footer
}
// Translate asks the language model to translate request into a JSON object
// and returns the value produced by the validator.
//
// The JSON object is taken to be the text from the first '{' through the last
// '}' in the model's response. If validation fails and t.attemptRepair is set,
// the invalid JSON and a repair prompt are appended to the prompt and the
// model is queried exactly once more.
//
// Errors from the model are returned unchanged. The final validation error is
// wrapped with %w so callers can inspect it with errors.Is / errors.As.
func (t *baseJsonTranslator[T]) Translate(request string) (*T, error) {
	prompt := t.CreateRequestPrompt(request)
	// Local copy so the single repair attempt is budgeted per call and the
	// translator's configuration is never mutated.
	attemptRepair := t.attemptRepair
	for {
		resp, err := t.model.complete(prompt)
		if err != nil {
			return nil, err
		}
		startIndex := strings.Index(resp, "{")
		endIndex := strings.LastIndex(resp, "}")
		if startIndex < 0 || endIndex <= startIndex {
			return nil, fmt.Errorf("Response is not JSON:\n%s", resp)
		}
		jsonText := resp[startIndex : endIndex+1]
		result, err := t.validator.Validate(jsonText)
		if err == nil {
			return result, nil
		}
		if !attemptRepair {
			// %w (not %v) preserves the validator's error for errors.Is/As
			// while producing the same message text.
			return nil, fmt.Errorf("JSON validation failed: %w\n%s", err, jsonText)
		}
		// Show the model its own invalid output followed by the reason it
		// failed, then ask again.
		prompt += jsonText + "\n" + t.CreateRepairPrompt(err.Error())
		attemptRepair = false
	}
}
// Validator returns the JsonValidator this translator uses to check model
// output.
func (t *baseJsonTranslator[T]) Validator() JsonValidator[T] {
	v := t.validator
	return v
}
// Model returns the LanguageModel this translator sends prompts to.
func (t *baseJsonTranslator[T]) Model() LanguageModel {
	m := t.model
	return m
}