-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathmain.go
199 lines (168 loc) · 5.35 KB
/
main.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
package main
import (
"bufio"
"context"
"encoding/json"
"flag"
"fmt"
"log"
"os"
"regexp"
"strings"
"time"
"github.com/openai/openai-go"
"github.com/openai/openai-go/option"
"github.com/spf13/viper"
"github.com/xrash/smetrics"
)
// TriviaResponse is the JSON document the model is asked to return: two
// parallel lists in which Answers[i] answers Questions[i].
type TriviaResponse struct {
	// Questions holds the question texts, decoded from the "questions" key.
	Questions []string `json:"questions"`

	// Answers holds the expected answers, decoded from the "answers" key.
	Answers []string `json:"answers"`
}
func main() {
log.SetFlags(0)
cliApiKey := flag.String("apiKey", "", "API key (overrides config file)")
flag.Parse()
configFile := "config.json"
apiKey := ""
// Check if the file exists
if _, err := os.Stat(configFile); os.IsNotExist(err) {
fmt.Printf("Config file %s does not exist. I will use the api key flag\n", configFile)
if *cliApiKey != "" {
apiKey = *cliApiKey
} else {
fmt.Println("No API key provided. Please set the API key in the config file or use the -apiKey flag.")
panic("No API Key")
}
} else {
apiKey = setApiKey()
}
mainPrompt := generatePrompt()
client := openai.NewClient(option.WithAPIKey(apiKey))
// create a channel to signal when to stop the loading indicator
done := make(chan bool)
// start loading indicator
go loadingIndicator(done)
completion := generateCompletion(context.TODO(), client, mainPrompt)
// stop the loading indicator
done <- true
var trivia TriviaResponse
err := json.Unmarshal([]byte(completion.Choices[0].Message.Content), &trivia)
if err != nil {
log.Fatalf("Error unmarshaling JSON: %v", err)
}
runTriviaGame(trivia.Questions, trivia.Answers)
}
// setApiKey loads config.json from the working directory via viper and
// returns its "apikey" value with surrounding whitespace removed.
//
// Without a key nothing else can run, so the program exits with a message if
// the config file cannot be read or the key is missing or blank. A blank key
// would otherwise only surface later as an authentication failure from the
// API call.
func setApiKey() string {
	viper.SetConfigName("config") // config file name without extension
	viper.SetConfigType("json")   // parse it as JSON
	viper.AddConfigPath(".")      // look in the current working directory

	if err := viper.ReadInConfig(); err != nil {
		log.Fatalf("Error reading config file: %v", err)
	}

	apiKey := strings.TrimSpace(viper.GetString("apikey"))
	if apiKey == "" {
		log.Fatal(`Config file has no "apikey" value. Please set the API key in the config file or use the -apiKey flag.`)
	}
	return apiKey
}
// generatePrompt asks the player for a trivia topic and builds the prompt
// that requests 20 questions on it, returned as a JSON object with parallel
// "questions" and "answers" arrays. A blank or whitespace-only topic is
// replaced with "random topics".
func generatePrompt() string {
	const topicPrompt = "Enter a topic for your trivia questions, or leave blank for random topics: "

	// Trim so that a topic of only spaces is treated as blank rather than
	// being interpolated into the prompt as "about    .".
	userTopic := strings.TrimSpace(getUserInput(topicPrompt, true))
	if userTopic == "" {
		userTopic = "random topics"
	}

	return fmt.Sprintf(`Generate a series of 20 trivia questions about %s.
Please respond only in the following valid JSON format, with no extra formatting or text in your response:
{
"questions": ["Question 1", "Question 2", ...],
"answers": ["Answer 1", "Answer 2", ...]
}
Do not include any questions where the answer would include special symbols or characters.`, userTopic)
}
// generateCompletion sends prompt to the chat completions endpoint as a
// single user message, using the openai.ChatModelGPT4oMini model, and returns
// the completion.
//
// This is a CLI, so an API error or a reply with no choices is reported as a
// plain message and the program exits, rather than panicking with a stack
// trace. The returned completion is guaranteed to have at least one choice.
func generateCompletion(ctx context.Context, client *openai.Client, prompt string) *openai.ChatCompletion {
	params := openai.ChatCompletionNewParams{
		Messages: openai.F([]openai.ChatCompletionMessageParamUnion{
			openai.UserMessage(prompt),
		}),
		Model: openai.F(openai.ChatModelGPT4oMini),
	}

	chatCompletion, err := client.Chat.Completions.New(ctx, params)
	if err != nil {
		// Leading newline: the loading spinner may still occupy the current line.
		log.Fatalf("\nError generating trivia questions: %v", err)
	}
	if len(chatCompletion.Choices) == 0 {
		log.Fatal("\nError generating trivia questions: the model returned no choices")
	}
	return chatCompletion
}
// loadingIndicator draws a spinner after the text "Generating trivia
// questions... " on the current terminal line, advancing one frame every
// 100ms. It returns as soon as done receives a value or is closed, after
// overwriting the whole line with spaces.
//
// The stop signal is checked alongside every frame tick, so a sender on done
// is never blocked for longer than it takes to print one frame.
func loadingIndicator(done chan bool) {
	const message = "Generating trivia questions... "
	frames := []rune{'|', '/', '-', '\\'}

	ticker := time.NewTicker(100 * time.Millisecond)
	defer ticker.Stop()

	for frame := 0; ; frame = (frame + 1) % len(frames) {
		fmt.Printf("\r%s%c", message, frames[frame])
		select {
		case <-done:
			// Blank out the message plus the spinner rune, then return the
			// cursor to column 0 so later output starts on a clean line.
			fmt.Print("\r" + strings.Repeat(" ", len(message)+1) + "\r")
			return
		case <-ticker.C:
		}
	}
}
// stdinScanner is shared by every getUserInput call and created lazily on
// first use. A bufio.Scanner reads ahead in chunks, so creating a fresh one
// per call would discard lines already buffered by the previous scanner
// (which happens when input is piped in).
var stdinScanner *bufio.Scanner

// getUserInput prints prompt and returns the next line of standard input with
// surrounding whitespace removed.
//
// When allowEmpty is false it keeps re-prompting until a non-blank line is
// entered. If standard input ends before a usable line arrives, it returns ""
// when allowEmpty is true. When allowEmpty is false it exits the program
// instead, because no answer can ever arrive and re-prompting would loop
// forever. A read error also exits the program.
func getUserInput(prompt string, allowEmpty bool) string {
	if stdinScanner == nil {
		stdinScanner = bufio.NewScanner(os.Stdin)
	}

	fmt.Print(prompt)
	for {
		if !stdinScanner.Scan() {
			if err := stdinScanner.Err(); err != nil {
				log.Fatalf("\nError reading input: %v", err)
			}
			if allowEmpty {
				return ""
			}
			log.Fatal("\nInput ended before an answer was entered.")
		}

		input := strings.TrimSpace(stdinScanner.Text())
		if input != "" || allowEmpty {
			return input
		}
		fmt.Print("Input cannot be empty. Try again: ")
	}
}
// isCorrectAnswer reports whether userGuess should be accepted for
// correctAnswer.
//
// Both strings are lower-cased and split into words (see answerWords). The
// guess is accepted when:
//   - its words appear as a contiguous whole-word run inside the answer, so
//     "Nile" matches "The Nile" and "Shakespeare" matches "William
//     Shakespeare", or
//   - its Jaro-Winkler similarity to the whole answer exceeds 0.85, which
//     forgives simple typos and misspellings.
//
// A blank or whitespace-only guess is always rejected. Earlier code fell
// through to strings.Contains(answer, ""), which is always true. Matching on
// whole words also stops fragments such as "a" from matching inside "paris".
// A guess equal to any single answer word (even a filler word like "the")
// is still accepted.
func isCorrectAnswer(userGuess, correctAnswer string) bool {
	guess := strings.Join(answerWords(userGuess), " ")
	if guess == "" {
		return false
	}
	answer := strings.Join(answerWords(correctAnswer), " ")

	// Padding both sides with a space turns substring search into
	// whole-word-run search.
	if strings.Contains(" "+answer+" ", " "+guess+" ") {
		return true
	}

	return smetrics.JaroWinkler(guess, answer, 0.7, 4) > 0.85
}

// answerWords lower-cases s and splits it into words. ASCII characters other
// than letters and digits (spaces, punctuation) act as separators. Non-ASCII
// runes are kept inside words so accented names stay intact.
func answerWords(s string) []string {
	return strings.FieldsFunc(strings.ToLower(s), func(r rune) bool {
		return r < 0x80 && !('a' <= r && r <= 'z') && !('0' <= r && r <= '9')
	})
}
// runTriviaGame asks each question in turn, reads the player's guess from
// standard input, reports whether it was right, and prints the final score.
//
// questions and answers are parallel slices: answers[i] answers questions[i].
// If fewer answers than questions were supplied, a warning is printed and
// only the paired questions are asked. This avoids an index-out-of-range
// panic on a malformed model reply.
func runTriviaGame(questions, answers []string) {
	// Purely numeric answers (years, counts) require an exact match. Fuzzy
	// matching would accept near misses, e.g. "1968" scores above the
	// similarity cutoff for "1969".
	isNumeric := regexp.MustCompile(`^\d+$`)

	total := len(questions)
	if len(answers) < total {
		fmt.Printf("\nWarning: received %d questions but only %d answers; skipping the unanswerable questions.\n", len(questions), len(answers))
		total = len(answers)
	}

	score := 0
	fmt.Println("\nTrivia Questions:")
	for i := 0; i < total; i++ {
		userGuess := getUserInput(fmt.Sprintf("\nQuestion %d: %s\nEnter your guess: ", i+1, questions[i]), false)
		correctAnswer := strings.TrimSpace(answers[i])

		var correct bool
		if isNumeric.MatchString(correctAnswer) {
			correct = strings.TrimSpace(userGuess) == correctAnswer
		} else {
			correct = isCorrectAnswer(userGuess, correctAnswer)
		}

		if correct {
			fmt.Println("\nCorrect!")
			score++
		} else {
			fmt.Printf("\nIncorrect. The answer was %s\n", correctAnswer)
		}
	}

	fmt.Printf("\nYour final score: %d out of %d\n", score, total)
}