Skip to content

Commit

Permalink
Add version support and change Embedding.Embedding to Vector (#3)
Browse files Browse the repository at this point in the history
Signed-off-by: Milos Gajdos <[email protected]>
  • Loading branch information
milosgajdos authored Nov 29, 2023
1 parent 8448691 commit b53d156
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 12 deletions.
12 changes: 11 additions & 1 deletion openai/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,9 @@ import (

const (
// BaseURL is OpenAI HTTP API base URL.
BaseURL = "https://api.openai.com/v1"
BaseURL = "https://api.openai.com"
// EmbedAPIVersion is the latest stable embedding API version.
EmbedAPIVersion = "v1"
// Org header
OrgHeader = "OpenAI-Organization"
)
Expand All @@ -21,6 +23,7 @@ const (
type Client struct {
apiKey string
baseURL string
version string
orgID string
hc *http.Client
}
Expand All @@ -34,6 +37,7 @@ func NewClient() (*Client, error) {
return &Client{
apiKey: os.Getenv("OPENAI_API_KEY"),
baseURL: BaseURL,
version: EmbedAPIVersion,
orgID: "",
hc: &http.Client{},
}, nil
Expand All @@ -51,6 +55,12 @@ func (c *Client) WithBaseURL(baseURL string) *Client {
return c
}

// WithVersion sets the API version.
func (c *Client) WithVersion(version string) *Client {
c.version = version
return c
}

// WithOrgID sets the organization ID.
func (c *Client) WithOrgID(orgID string) *Client {
c.orgID = orgID
Expand Down
22 changes: 11 additions & 11 deletions openai/embedding.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,11 @@ type Usage struct {
TotalTokens int `json:"total_tokens"`
}

// Embedding is openai API vector embedding.
// Embedding is openai API embedding.
type Embedding struct {
Object string `json:"object"`
Index int `json:"index"`
Embedding []float64 `json:"embedding"`
Object string `json:"object"`
Index int `json:"index"`
Vector []float64 `json:"vector"`
}

// EmbeddingString is base64 encoded embedding.
Expand Down Expand Up @@ -87,9 +87,9 @@ func ToEmbeddings[T any](resp io.Reader) ([]*Embedding, error) {
return nil, err
}
emb := &Embedding{
Object: d.Object,
Index: d.Index,
Embedding: floats,
Object: d.Object,
Index: d.Index,
Vector: floats,
}
embs = append(embs, emb)
}
Expand All @@ -98,9 +98,9 @@ func ToEmbeddings[T any](resp io.Reader) ([]*Embedding, error) {
embs := make([]*Embedding, 0, len(e.Data))
for _, d := range e.Data {
emb := &Embedding{
Object: d.Object,
Index: d.Index,
Embedding: d.Embedding,
Object: d.Object,
Index: d.Index,
Vector: d.Embedding,
}
embs = append(embs, emb)
}
Expand All @@ -112,7 +112,7 @@ func ToEmbeddings[T any](resp io.Reader) ([]*Embedding, error) {

// Embeddings returns embeddings for every object in EmbeddingRequest.
func (c *Client) Embeddings(ctx context.Context, embReq *EmbeddingRequest) ([]*Embedding, error) {
u, err := url.Parse(c.baseURL + "/embeddings")
u, err := url.Parse(c.baseURL + "/" + c.version + "/embeddings")
if err != nil {
return nil, err
}
Expand Down

0 comments on commit b53d156

Please sign in to comment.