Skip to content

Commit

Permalink
Go version: support response streaming
Browse files Browse the repository at this point in the history
  • Loading branch information
ariya committed Sep 25, 2024
1 parent 939f72d commit 55d860d
Showing 1 changed file with 53 additions and 11 deletions.
64 changes: 53 additions & 11 deletions ask-llm.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,15 @@ import (
"fmt"
"net/http"
"os"
"strings"
"time"
)

// Configuration read once from the environment at program start.
var (
// Base URL of the OpenAI-compatible chat API (e.g. ".../v1"); the
// "/chat/completions" path is appended when building requests.
LLMAPIBaseURL = os.Getenv("LLM_API_BASE_URL")
// Optional API key; when non-empty it is sent as a Bearer token.
LLMAPIKey = os.Getenv("LLM_API_KEY")
// Model name passed in the request body's "model" field.
LLMChatModel = os.Getenv("LLM_CHAT_MODEL")
// Streaming is on by default; set LLM_STREAMING=no to disable it.
// Any other value (including unset) leaves streaming enabled.
LLMStreaming = os.Getenv("LLM_STREAMING") != "no"
// Non-empty enables debug output (e.g. per-request timing).
LLMDebug = os.Getenv("LLM_DEBUG")
)

Expand All @@ -28,6 +30,7 @@ type ChatRequest struct {
Stop []string `json:"stop"`
MaxTokens int `json:"max_tokens"`
Temperature float64 `json:"temperature"`
Stream bool `json:"stream"`
}

type Choice struct {
Expand All @@ -36,18 +39,20 @@ type Choice struct {
} `json:"message"`
}

func chat(messages []Message) (string, error) {
func chat(messages []Message, handler func(string)) (string, error) {
url := fmt.Sprintf("%s/chat/completions", LLMAPIBaseURL)
authHeader := ""
if LLMAPIKey != "" {
authHeader = fmt.Sprintf("Bearer %s", LLMAPIKey)
}
stream := LLMStreaming && handler != nil
requestBody := ChatRequest{
Messages: messages,
Model: LLMChatModel,
Stop: []string{"<|im_end|>", "<|end|>", "<|eot_id|>"},
MaxTokens: 200,
Temperature: 0,
Stream: stream,
}
jsonBody, err := json.Marshal(requestBody)
if err != nil {
Expand All @@ -73,15 +78,49 @@ func chat(messages []Message) (string, error) {
return "", fmt.Errorf("HTTP error: %d %s", resp.StatusCode, resp.Status)
}

var data struct {
Choices []Choice `json:"choices"`
}
if err := json.NewDecoder(resp.Body).Decode(&data); err != nil {
return "", err
if !stream {
var data struct {
Choices []Choice `json:"choices"`
}
err := json.NewDecoder(resp.Body).Decode(&data)
if err != nil {
return "", err
}
answer := data.Choices[0].Message.Content
if handler != nil {
handler(answer)
}
return answer, nil
} else {
answer := ""
scanner := bufio.NewScanner(resp.Body)
for scanner.Scan() {
line := scanner.Text()
if strings.HasPrefix(line, "data: ") {
payload := line[6:]
var data struct {
Choices []struct {
Delta struct {
Content string `json:"content"`
} `json:"delta"`
} `json:"choices"`
}
err := json.Unmarshal([]byte(payload), &data)
if err != nil {
return "", err
}
partial := data.Choices[0].Delta.Content
answer += partial
if handler != nil {
handler(partial)
}
}
}
if err := scanner.Err(); err != nil {
return "", err
}
return answer, nil
}

answer := data.Choices[0].Message.Content
return answer, nil
}

const SystemPrompt = "Answer the question politely and concisely."
Expand All @@ -105,14 +144,17 @@ func main() {
}

messages = append(messages, Message{Role: "user", Content: question})
handler := func(partial string) {
fmt.Print(partial)
}
start := time.Now()
answer, err := chat(messages)
answer, err := chat(messages, handler)
if err != nil {
fmt.Println("Error:", err)
break
}
messages = append(messages, Message{Role: "assistant", Content: answer})
fmt.Println(answer)
fmt.Println()
elapsed := time.Since(start)
if LLMDebug != "" {
fmt.Printf("[%d ms]\n", elapsed.Milliseconds())
Expand Down

0 comments on commit 55d860d

Please sign in to comment.