From 55d860d4711fd4403d2fc2a14fb3333478802155 Mon Sep 17 00:00:00 2001 From: Ariya Hidayat Date: Tue, 24 Sep 2024 20:31:09 -0700 Subject: [PATCH] Go version: support response streaming --- ask-llm.go | 64 ++++++++++++++++++++++++++++++++++++++++++++---------- 1 file changed, 53 insertions(+), 11 deletions(-) diff --git a/ask-llm.go b/ask-llm.go index 081dff4..4454b14 100644 --- a/ask-llm.go +++ b/ask-llm.go @@ -7,6 +7,7 @@ import ( "fmt" "net/http" "os" + "strings" "time" ) @@ -14,6 +15,7 @@ var ( LLMAPIBaseURL = os.Getenv("LLM_API_BASE_URL") LLMAPIKey = os.Getenv("LLM_API_KEY") LLMChatModel = os.Getenv("LLM_CHAT_MODEL") + LLMStreaming = os.Getenv("LLM_STREAMING") != "no" LLMDebug = os.Getenv("LLM_DEBUG") ) @@ -28,6 +30,7 @@ type ChatRequest struct { Stop []string `json:"stop"` MaxTokens int `json:"max_tokens"` Temperature float64 `json:"temperature"` + Stream bool `json:"stream"` } type Choice struct { @@ -36,18 +39,20 @@ type Choice struct { } `json:"message"` } -func chat(messages []Message) (string, error) { +func chat(messages []Message, handler func(string)) (string, error) { url := fmt.Sprintf("%s/chat/completions", LLMAPIBaseURL) authHeader := "" if LLMAPIKey != "" { authHeader = fmt.Sprintf("Bearer %s", LLMAPIKey) } + stream := LLMStreaming && handler != nil requestBody := ChatRequest{ Messages: messages, Model: LLMChatModel, Stop: []string{"<|im_end|>", "<|end|>", "<|eot_id|>"}, MaxTokens: 200, Temperature: 0, + Stream: stream, } jsonBody, err := json.Marshal(requestBody) if err != nil { @@ -73,15 +78,49 @@ func chat(messages []Message) (string, error) { return "", fmt.Errorf("HTTP error: %d %s", resp.StatusCode, resp.Status) } - var data struct { - Choices []Choice `json:"choices"` - } - if err := json.NewDecoder(resp.Body).Decode(&data); err != nil { - return "", err + if !stream { + var data struct { + Choices []Choice `json:"choices"` + } + err := json.NewDecoder(resp.Body).Decode(&data) + if err != nil { + return "", err + } + answer := data.Choices[0].Message.Content + if handler != nil { + handler(answer) + } + return answer, nil + } else { + answer := "" + scanner := bufio.NewScanner(resp.Body) + for scanner.Scan() { + line := scanner.Text() + if strings.HasPrefix(line, "data: ") { + payload := line[6:] + var data struct { + Choices []struct { + Delta struct { + Content string `json:"content"` + } `json:"delta"` + } `json:"choices"` + } + err := json.Unmarshal([]byte(payload), &data) + if err != nil { + return "", err + } + partial := data.Choices[0].Delta.Content + answer += partial + if handler != nil { + handler(partial) + } + } + } + if err := scanner.Err(); err != nil { + return "", err + } + return answer, nil } - - answer := data.Choices[0].Message.Content - return answer, nil } const SystemPrompt = "Answer the question politely and concisely." @@ -105,14 +144,17 @@ func main() { } messages = append(messages, Message{Role: "user", Content: question}) + handler := func(partial string) { + fmt.Print(partial) + } start := time.Now() - answer, err := chat(messages) + answer, err := chat(messages, handler) if err != nil { fmt.Println("Error:", err) break } messages = append(messages, Message{Role: "assistant", Content: answer}) - fmt.Println(answer) + fmt.Println() elapsed := time.Since(start) if LLMDebug != "" { fmt.Printf("[%d ms]\n", elapsed.Milliseconds())