Skip to content

Commit

Permalink
Add Cohere embedding API support. (#2)
Browse files Browse the repository at this point in the history
* Add Cohere embedding API support.

Signed-off-by: Milos Gajdos <[email protected]>
  • Loading branch information
milosgajdos authored Nov 29, 2023
1 parent b53d156 commit 6689e87
Show file tree
Hide file tree
Showing 5 changed files with 313 additions and 0 deletions.
47 changes: 47 additions & 0 deletions cmd/cohere/main.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
package main

import (
"context"
"flag"
"fmt"
"log"

"github.com/milosgajdos/go-embeddings/cohere"
)

var (
input string
model string
truncate string
inputType string
)

func init() {
flag.StringVar(&input, "input", "", "input data")
flag.StringVar(&model, "model", string(cohere.EnglishV3), "model name")
flag.StringVar(&truncate, "truncate", string(cohere.NoneTrunc), "truncate type")
flag.StringVar(&inputType, "input-type", string(cohere.ClusteringInput), "input type")
}

func main() {
flag.Parse()

c, err := cohere.NewClient()
if err != nil {
log.Fatal(err)
}

embReq := &cohere.EmbeddingRequest{
Texts: []string{input},
Model: cohere.Model(model),
InputType: cohere.InputType(inputType),
Truncate: cohere.Truncate(truncate),
}

embs, err := c.Embeddings(context.Background(), embReq)
if err != nil {
log.Fatal(err)
}

fmt.Printf("got %d embeddings", len(embs))
}
134 changes: 134 additions & 0 deletions cohere/client.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,134 @@
package cohere

import (
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"os"
)

const (
// BaseURL is Cohere HTTP API base URL.
BaseURL = "https://api.cohere.ai"
// EmbedAPIVersion is the latest stable embedding API version.
EmbedAPIVersion = "v1"
)

// Client is Cohere HTTP API client.
type Client struct {
apiKey string
baseURL string
version string
hc *http.Client
}

// NewClient creates a new HTTP client and returns it.
// It reads the Cohere API key from COHERE_API_KEY env var
// and uses the default Go http.Client.
// You can override the default options by using the
// client methods.
func NewClient() (*Client, error) {
return &Client{
apiKey: os.Getenv("COHERE_API_KEY"),
baseURL: BaseURL,
version: EmbedAPIVersion,
hc: &http.Client{},
}, nil
}

// WithAPIKey sets the API key.
func (c *Client) WithAPIKey(apiKey string) *Client {
c.apiKey = apiKey
return c
}

// WithBaseURL sets the API base URL.
func (c *Client) WithBaseURL(baseURL string) *Client {
c.baseURL = baseURL
return c
}

// WithVersion sets the API version.
func (c *Client) WithVersion(version string) *Client {
c.version = version
return c
}

// WithHTTPClient sets the HTTP client.
func (c *Client) WithHTTPClient(httpClient *http.Client) *Client {
c.hc = httpClient
return c
}

// ReqOption is http requestion functional option.
type ReqOption func(*http.Request)

// WithSetHeader sets the header key to value val.
func WithSetHeader(key, val string) ReqOption {
return func(req *http.Request) {
if req.Header == nil {
req.Header = make(http.Header)
}
req.Header.Set(key, val)
}
}

// WithAddHeader adds the val to key header.
func WithAddHeader(key, val string) ReqOption {
return func(req *http.Request) {
if req.Header == nil {
req.Header = make(http.Header)
}
req.Header.Add(key, val)
}
}

func (c *Client) newRequest(ctx context.Context, method, url string, body io.Reader, opts ...ReqOption) (*http.Request, error) {
if ctx == nil {
ctx = context.Background()
}
if body == nil {
body = &bytes.Reader{}
}

req, err := http.NewRequestWithContext(ctx, method, url, body)
if err != nil {
return nil, err
}

for _, setOption := range opts {
setOption(req)
}

req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", c.apiKey))
req.Header.Set("Accept", "application/json; charset=utf-8")
if body != nil {
// if no content-type is specified we default to json
if ct := req.Header.Get("Content-Type"); len(ct) == 0 {
req.Header.Set("Content-Type", "application/json; charset=utf-8")
}
}

return req, nil
}

func (c *Client) doRequest(req *http.Request) (*http.Response, error) {
resp, err := c.hc.Do(req)
if err != nil {
return nil, err
}
if resp.StatusCode >= http.StatusOK && resp.StatusCode < http.StatusBadRequest {
return resp, nil
}
defer resp.Body.Close()

var apiErr APIError
if err := json.NewDecoder(resp.Body).Decode(&apiErr); err != nil {
return nil, err
}

return nil, apiErr
}
33 changes: 33 additions & 0 deletions cohere/cohere.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
package cohere

// Model is embedding model.
type Model string

const (
EnglishV3 Model = "embed-english-v3.0"
MultiLingV3 Model = "embed-multilingual-v3.0"
EnglishLightV3 Model = "embed-english-light-v3.0"
MultiLingLightV3 Model = "embed-multilingual-light-v3.0"
EnglishV2 Model = "embed-english-v2.0"
EnglishLightV2 Model = "embed-english-light-v2.0"
MultiLingV2 Model = "embed-multilingual-v2.0"
)

// InputType is embedding input type.
type InputType string

const (
SearchDocInput InputType = "search_document"
SearchQueryInput InputType = "search_query"
ClassificationInput InputType = "classification"
ClusteringInput InputType = "clustering"
)

// Truncate controls input truncating.
type Truncate string

const (
StartTrunc Truncate = "START"
EndTrunc Truncate = "END"
NoneTrunc Truncate = "NONE"
)
84 changes: 84 additions & 0 deletions cohere/embedding.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
package cohere

import (
"bytes"
"context"
"encoding/json"
"io"
"net/http"
"net/url"
)

// Embedding is cohere API vector embedding.
type Embedding struct {
Vector []float64 `json:"vector"`
}

// EmbeddingRequest sent to API endpoint.
type EmbeddingRequest struct {
Texts []string `json:"texts"`
Model Model `json:"model,omitempty"`
InputType InputType `json:"input_type"`
Truncate Truncate `json:"truncate,omitempty"`
}

// EmbedddingResponse received from API endpoint.
type EmbedddingResponse struct {
Embeddings [][]float64 `json:"embeddings"`
Meta *Meta `json:"meta,omitempty"`
}

// Meta stores API response metadata
type Meta struct {
APIVersion *APIVersion `json:"api_version,omitempty"`
}

// APIVersion stores metadata API version.
type APIVersion struct {
Version string `json:"version"`
}

func ToEmbeddings(r io.Reader) ([]*Embedding, error) {
var resp EmbedddingResponse
if err := json.NewDecoder(r).Decode(&resp); err != nil {
return nil, err
}
embs := make([]*Embedding, 0, len(resp.Embeddings))
for _, e := range resp.Embeddings {
floats := make([]float64, len(e))
copy(floats, e)
emb := &Embedding{
Vector: floats,
}
embs = append(embs, emb)
}
return embs, nil
}

// Embeddings returns embeddings for every object in EmbeddingRequest.
func (c *Client) Embeddings(ctx context.Context, embReq *EmbeddingRequest) ([]*Embedding, error) {
u, err := url.Parse(c.baseURL + "/" + c.version + "/embed")
if err != nil {
return nil, err
}

var body = &bytes.Buffer{}
enc := json.NewEncoder(body)
enc.SetEscapeHTML(false)
if err := enc.Encode(embReq); err != nil {
return nil, err
}

req, err := c.newRequest(ctx, http.MethodPost, u.String(), body)
if err != nil {
return nil, err
}

resp, err := c.doRequest(req)
if err != nil {
return nil, err
}
defer resp.Body.Close()

return ToEmbeddings(resp.Body)
}
15 changes: 15 additions & 0 deletions cohere/error.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
package cohere

import "encoding/json"

type APIError struct {
Message string `json:"message"`
}

func (e APIError) Error() string {
b, err := json.Marshal(e)
if err != nil {
return "unknown error"
}
return string(b)
}

0 comments on commit 6689e87

Please sign in to comment.