generated from milosgajdos/go-repo-template
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add Cohere embedding API support. (#2)
* Add Cohere embedding API support. Signed-off-by: Milos Gajdos <[email protected]>
- Loading branch information
1 parent
b53d156
commit 6689e87
Showing
5 changed files
with
313 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,47 @@ | ||
package main | ||
|
||
import ( | ||
"context" | ||
"flag" | ||
"fmt" | ||
"log" | ||
|
||
"github.com/milosgajdos/go-embeddings/cohere" | ||
) | ||
|
||
// Command-line flag values. They are registered in init and filled in
// when main calls flag.Parse.
var (
	// input is the text sent to the API for embedding.
	input string
	// model is the embedding model name.
	model string
	// truncate is the truncation mode passed with the request.
	truncate string
	// inputType is the embedding input type passed with the request.
	inputType string
)
|
||
func init() { | ||
flag.StringVar(&input, "input", "", "input data") | ||
flag.StringVar(&model, "model", string(cohere.EnglishV3), "model name") | ||
flag.StringVar(&truncate, "truncate", string(cohere.NoneTrunc), "truncate type") | ||
flag.StringVar(&inputType, "input-type", string(cohere.ClusteringInput), "input type") | ||
} | ||
|
||
func main() { | ||
flag.Parse() | ||
|
||
c, err := cohere.NewClient() | ||
if err != nil { | ||
log.Fatal(err) | ||
} | ||
|
||
embReq := &cohere.EmbeddingRequest{ | ||
Texts: []string{input}, | ||
Model: cohere.Model(model), | ||
InputType: cohere.InputType(inputType), | ||
Truncate: cohere.Truncate(truncate), | ||
} | ||
|
||
embs, err := c.Embeddings(context.Background(), embReq) | ||
if err != nil { | ||
log.Fatal(err) | ||
} | ||
|
||
fmt.Printf("got %d embeddings", len(embs)) | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,134 @@ | ||
package cohere | ||
|
||
import ( | ||
"bytes" | ||
"context" | ||
"encoding/json" | ||
"fmt" | ||
"io" | ||
"net/http" | ||
"os" | ||
) | ||
|
||
// BaseURL is the base URL of the Cohere HTTP API.
const BaseURL = "https://api.cohere.ai"

// EmbedAPIVersion is the embedding API version used by default.
const EmbedAPIVersion = "v1"
|
||
// Client is a Cohere HTTP API client.
type Client struct {
	// apiKey is the API key sent with requests.
	apiKey string
	// baseURL is the API base URL, without the version segment.
	baseURL string
	// version is the API version path segment.
	version string
	// hc is the HTTP client used to send requests.
	hc *http.Client
}
|
||
// NewClient creates a new HTTP client and returns it. | ||
// It reads the Cohere API key from COHERE_API_KEY env var | ||
// and uses the default Go http.Client. | ||
// You can override the default options by using the | ||
// client methods. | ||
func NewClient() (*Client, error) { | ||
return &Client{ | ||
apiKey: os.Getenv("COHERE_API_KEY"), | ||
baseURL: BaseURL, | ||
version: EmbedAPIVersion, | ||
hc: &http.Client{}, | ||
}, nil | ||
} | ||
|
||
// WithAPIKey sets the API key. | ||
func (c *Client) WithAPIKey(apiKey string) *Client { | ||
c.apiKey = apiKey | ||
return c | ||
} | ||
|
||
// WithBaseURL sets the API base URL. | ||
func (c *Client) WithBaseURL(baseURL string) *Client { | ||
c.baseURL = baseURL | ||
return c | ||
} | ||
|
||
// WithVersion sets the API version. | ||
func (c *Client) WithVersion(version string) *Client { | ||
c.version = version | ||
return c | ||
} | ||
|
||
// WithHTTPClient sets the HTTP client. | ||
func (c *Client) WithHTTPClient(httpClient *http.Client) *Client { | ||
c.hc = httpClient | ||
return c | ||
} | ||
|
||
// ReqOption is a functional option that modifies an HTTP request
// before it is sent.
type ReqOption func(*http.Request)
|
||
// WithSetHeader sets the header key to value val. | ||
func WithSetHeader(key, val string) ReqOption { | ||
return func(req *http.Request) { | ||
if req.Header == nil { | ||
req.Header = make(http.Header) | ||
} | ||
req.Header.Set(key, val) | ||
} | ||
} | ||
|
||
// WithAddHeader adds the val to key header. | ||
func WithAddHeader(key, val string) ReqOption { | ||
return func(req *http.Request) { | ||
if req.Header == nil { | ||
req.Header = make(http.Header) | ||
} | ||
req.Header.Add(key, val) | ||
} | ||
} | ||
|
||
func (c *Client) newRequest(ctx context.Context, method, url string, body io.Reader, opts ...ReqOption) (*http.Request, error) { | ||
if ctx == nil { | ||
ctx = context.Background() | ||
} | ||
if body == nil { | ||
body = &bytes.Reader{} | ||
} | ||
|
||
req, err := http.NewRequestWithContext(ctx, method, url, body) | ||
if err != nil { | ||
return nil, err | ||
} | ||
|
||
for _, setOption := range opts { | ||
setOption(req) | ||
} | ||
|
||
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", c.apiKey)) | ||
req.Header.Set("Accept", "application/json; charset=utf-8") | ||
if body != nil { | ||
// if no content-type is specified we default to json | ||
if ct := req.Header.Get("Content-Type"); len(ct) == 0 { | ||
req.Header.Set("Content-Type", "application/json; charset=utf-8") | ||
} | ||
} | ||
|
||
return req, nil | ||
} | ||
|
||
func (c *Client) doRequest(req *http.Request) (*http.Response, error) { | ||
resp, err := c.hc.Do(req) | ||
if err != nil { | ||
return nil, err | ||
} | ||
if resp.StatusCode >= http.StatusOK && resp.StatusCode < http.StatusBadRequest { | ||
return resp, nil | ||
} | ||
defer resp.Body.Close() | ||
|
||
var apiErr APIError | ||
if err := json.NewDecoder(resp.Body).Decode(&apiErr); err != nil { | ||
return nil, err | ||
} | ||
|
||
return nil, apiErr | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,33 @@ | ||
package cohere | ||
|
||
// Model is the name of an embedding model.
type Model string

// Known embedding model names.
const (
	// EnglishV3 is the "embed-english-v3.0" model.
	EnglishV3 Model = "embed-english-v3.0"
	// MultiLingV3 is the "embed-multilingual-v3.0" model.
	MultiLingV3 Model = "embed-multilingual-v3.0"
	// EnglishLightV3 is the "embed-english-light-v3.0" model.
	EnglishLightV3 Model = "embed-english-light-v3.0"
	// MultiLingLightV3 is the "embed-multilingual-light-v3.0" model.
	MultiLingLightV3 Model = "embed-multilingual-light-v3.0"
	// EnglishV2 is the "embed-english-v2.0" model.
	EnglishV2 Model = "embed-english-v2.0"
	// EnglishLightV2 is the "embed-english-light-v2.0" model.
	EnglishLightV2 Model = "embed-english-light-v2.0"
	// MultiLingV2 is the "embed-multilingual-v2.0" model.
	MultiLingV2 Model = "embed-multilingual-v2.0"
)
|
||
// InputType is the type of input being embedded.
type InputType string

// Known embedding input types.
const (
	// SearchDocInput is the "search_document" input type.
	SearchDocInput InputType = "search_document"
	// SearchQueryInput is the "search_query" input type.
	SearchQueryInput InputType = "search_query"
	// ClassificationInput is the "classification" input type.
	ClassificationInput InputType = "classification"
	// ClusteringInput is the "clustering" input type.
	ClusteringInput InputType = "clustering"
)
|
||
// Truncate selects how the input is truncated.
// NOTE(review): the exact semantics of each mode are defined by the
// Cohere API and are not visible here — confirm against its documentation.
type Truncate string

// Known truncation modes.
const (
	// StartTrunc is the "START" truncation mode.
	StartTrunc Truncate = "START"
	// EndTrunc is the "END" truncation mode.
	EndTrunc Truncate = "END"
	// NoneTrunc is the "NONE" truncation mode.
	NoneTrunc Truncate = "NONE"
)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,84 @@ | ||
package cohere | ||
|
||
import ( | ||
"bytes" | ||
"context" | ||
"encoding/json" | ||
"io" | ||
"net/http" | ||
"net/url" | ||
) | ||
|
||
// Embedding is a single vector embedding returned by the Cohere API.
type Embedding struct {
	// Vector holds the embedding values.
	Vector []float64 `json:"vector"`
}
|
||
// EmbeddingRequest sent to API endpoint. | ||
type EmbeddingRequest struct { | ||
Texts []string `json:"texts"` | ||
Model Model `json:"model,omitempty"` | ||
InputType InputType `json:"input_type"` | ||
Truncate Truncate `json:"truncate,omitempty"` | ||
} | ||
|
||
// EmbedddingResponse received from API endpoint. | ||
type EmbedddingResponse struct { | ||
Embeddings [][]float64 `json:"embeddings"` | ||
Meta *Meta `json:"meta,omitempty"` | ||
} | ||
|
||
// Meta stores API response metadata | ||
type Meta struct { | ||
APIVersion *APIVersion `json:"api_version,omitempty"` | ||
} | ||
|
||
// APIVersion stores the API version reported in response metadata.
type APIVersion struct {
	// Version is the version string.
	Version string `json:"version"`
}
|
||
func ToEmbeddings(r io.Reader) ([]*Embedding, error) { | ||
var resp EmbedddingResponse | ||
if err := json.NewDecoder(r).Decode(&resp); err != nil { | ||
return nil, err | ||
} | ||
embs := make([]*Embedding, 0, len(resp.Embeddings)) | ||
for _, e := range resp.Embeddings { | ||
floats := make([]float64, len(e)) | ||
copy(floats, e) | ||
emb := &Embedding{ | ||
Vector: floats, | ||
} | ||
embs = append(embs, emb) | ||
} | ||
return embs, nil | ||
} | ||
|
||
// Embeddings returns embeddings for every object in EmbeddingRequest. | ||
func (c *Client) Embeddings(ctx context.Context, embReq *EmbeddingRequest) ([]*Embedding, error) { | ||
u, err := url.Parse(c.baseURL + "/" + c.version + "/embed") | ||
if err != nil { | ||
return nil, err | ||
} | ||
|
||
var body = &bytes.Buffer{} | ||
enc := json.NewEncoder(body) | ||
enc.SetEscapeHTML(false) | ||
if err := enc.Encode(embReq); err != nil { | ||
return nil, err | ||
} | ||
|
||
req, err := c.newRequest(ctx, http.MethodPost, u.String(), body) | ||
if err != nil { | ||
return nil, err | ||
} | ||
|
||
resp, err := c.doRequest(req) | ||
if err != nil { | ||
return nil, err | ||
} | ||
defer resp.Body.Close() | ||
|
||
return ToEmbeddings(resp.Body) | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,15 @@ | ||
package cohere | ||
|
||
import "encoding/json" | ||
|
||
// APIError is the error payload returned by the Cohere API.
type APIError struct {
	// Message is the error message reported by the API.
	Message string `json:"message"`
}

// Error implements the error interface. It returns the JSON encoding of
// the error, or "unknown error" if the encoding fails.
func (e APIError) Error() string {
	data, err := json.Marshal(e)
	if err == nil {
		return string(data)
	}
	return "unknown error"
}