-
Notifications
You must be signed in to change notification settings - Fork 2
Exercises
Johnny Boursiquot edited this page Jun 7, 2024
·
2 revisions
Consider the following code:
package mypackage
import (
"context"
"fmt"
"os"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/service/dynamodb"
"github.com/aws/aws-sdk-go/service/dynamodb/dynamodbattribute"
)
// Person is the record that main marshals and stores in DynamoDB.
type Person struct {
// Name is the only field carried by the record.
Name string
}
var client *dynamodb.DynamoDB
func main() {
client = dynamodb.New(aws.NewConfig().WithRegion("us-east-1")),
ctx := context.Background()
person := &Person{Name: "John Doe"}
item, err := dynamodbattribute.MarshalMap(p)
if err != nil {
log.Fatalln("failed to marshal person for storage: %s", err)
}
input := &dynamodb.PutItemInput{
Item: item,
TableName: aws.String(os.Getenv("TABLE_NAME")),
}
_, err := client.PutItemWithContext(ctx, input)
if err != nil {
fmt.Printf("Failed to save person: %v\n", err)
}
}
- Create an interface type to provide dependency injection
- Implicitly implement the functionality of an interface
- Leverage interfaces during testing by writing stubs/mocks
The following is a starter sample for how you might refactor. You don't have to use it.
package mypackage
import (
"context"
"fmt"
"os"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/request"
"github.com/aws/aws-sdk-go/service/dynamodb"
"github.com/aws/aws-sdk-go/service/dynamodb/dynamodbattribute"
)
// Person is a simple struct that we will use to store data in DynamoDB.
type Person struct {
// Name is the only field of the record.
Name string
}
// ddbClient describes the subset of DynamoDB client behavior that DynamoDBSaver depends on.
// It is empty in this starter sample; see the TODO below.
type ddbClient interface {
// TODO: Add a method to the interface for putting an item into DynamoDB. Look at how the client is used in the `Save` method below.
// Note the key takeaway here:
// We don't need to mock out the entire DynamoDB client interface, just the methods we need.
// We don't even need to use the interface AWS provides for the DynamoDB SDK; again, just the behavior we need.
}
// DynamoDBSaver is a wrapper that encapsulates interactions with DynamoDB.
type DynamoDBSaver struct {
// Client is the injected dependency that Save calls; any value satisfying ddbClient can be supplied.
Client ddbClient
}
// Save marshals p into a DynamoDB attribute map and puts it into the table
// named by the TABLE_NAME environment variable, using the client assigned to
// DynamoDBSaver.
//
// It returns a wrapped error when p cannot be marshalled; otherwise it returns
// whatever error the client's put call reports (nil on success).
func (s *DynamoDBSaver) Save(ctx context.Context, p *Person) error {
	item, err := dynamodbattribute.MarshalMap(p)
	if err != nil {
		// %w keeps the underlying error reachable through errors.Is/As.
		return fmt.Errorf("failed to marshal person for storage: %w", err)
	}

	input := &dynamodb.PutItemInput{
		Item:      item,
		TableName: aws.String(os.Getenv("TABLE_NAME")),
	}

	// The client's error is returned untouched so callers (and test stubs)
	// can compare it directly.
	_, err = s.Client.PutItemWithContext(ctx, input)
	return err
}
Your objective is to refactor the following code in order to improve its testability. Leverage interfaces for dependency injection where possible.
package main
import (
"errors"
"flag"
"fmt"
"net"
"os"
"runtime"
"strconv"
"strings"
"sync"
"time"
)
// Command-line settings; init registers the flags that fill them in.
var (
	// ports is the raw -ports flag value: a single port or a "min-max" range.
	ports string
	// workers is the -workers flag value: how many scan stages main starts.
	workers int
)

// init registers the -ports and -workers flags together with their defaults.
func init() {
	flag.IntVar(&workers, "workers", runtime.NumCPU(), "Number of workers (defaults to # of logical CPUs).")
	flag.StringVar(&ports, "ports", "5400-5500", "Port(s) (e.g. 80, 22-100).")
}
// main parses the command-line flags, builds the scanning pipeline
// (gen -> scan fan-out -> merge -> filterErr) and prints the first scan result
// whose error mentions "too many open files", if there is one.
//
// It exits with status 1 when the -ports flag cannot be parsed.
func main() {
	flag.Parse()

	portsToScan, err := parsePortsToScan(ports)
	if err != nil {
		fmt.Printf("Failed to parse ports to scan: %s\n", err)
		os.Exit(1)
	}

	// The done channel will be shared by the entire pipeline
	// so that when it's closed it serves as a signal
	// for all the goroutines we started to exit.
	done := make(chan struct{})
	defer close(done)

	in := gen(done, portsToScan...)

	// fan-out
	var chans []<-chan scanOp
	for i := 0; i < workers; i++ {
		chans = append(chans, scan(done, in))
	}

	// for s := range filterOpen(done, merge(done, chans...)) {
	// fmt.Printf("%#v\n", s)
	// }

	// Only the first matching result is of interest. Nothing is sent on done:
	// a send on the unbuffered channel would block until some goroutine
	// happened to be waiting on it (and deadlock if none were), while the
	// deferred close above signals every goroutine at once.
	if s, ok := <-filterErr(done, merge(done, chans...)); ok {
		fmt.Printf("%#v\n", s)
	}
	// done chan is closed by the deferred call here
}
// parsePortsToScan converts the value of the -ports flag into the list of
// ports to scan.
//
// portsFlag is either a single port ("80") or an inclusive range written as
// "min-max" ("22-100"); surrounding whitespace is ignored. Every port must lie
// in 1..65535 and a range must not be reversed. On any violation a nil slice
// and a descriptive error are returned.
func parsePortsToScan(portsFlag string) ([]int, error) {
	const maxValidPort = 65535

	// checkPort rejects values outside the valid TCP port range.
	checkPort := func(p int) error {
		switch {
		case p <= 0:
			return errors.New("port numbers must be greater than 0")
		case p > maxValidPort:
			return fmt.Errorf("port numbers must not exceed %d", maxValidPort)
		}
		return nil
	}

	portsFlag = strings.TrimSpace(portsFlag)

	// A bare integer selects exactly one port.
	if p, err := strconv.Atoi(portsFlag); err == nil {
		if err := checkPort(p); err != nil {
			return nil, err
		}
		return []int{p}, nil
	}

	bounds := strings.Split(portsFlag, "-")
	if len(bounds) != 2 {
		return nil, errors.New("unable to determine port(s) to scan")
	}

	minPort, err := strconv.Atoi(strings.TrimSpace(bounds[0]))
	if err != nil {
		return nil, fmt.Errorf("failed to convert %s to a valid port number", bounds[0])
	}
	maxPort, err := strconv.Atoi(strings.TrimSpace(bounds[1]))
	if err != nil {
		return nil, fmt.Errorf("failed to convert %s to a valid port number", bounds[1])
	}

	for _, p := range []int{minPort, maxPort} {
		if err := checkPort(p); err != nil {
			return nil, err
		}
	}
	// A reversed range would otherwise yield an empty list without any error.
	if minPort > maxPort {
		return nil, fmt.Errorf("invalid port range %d-%d: start must not exceed end", minPort, maxPort)
	}

	results := make([]int, 0, maxPort-minPort+1)
	for p := minPort; p <= maxPort; p++ {
		results = append(results, p)
	}
	return results, nil
}
// scanOp carries a single port through the pipeline: gen creates it with only
// the port set and scan fills in the outcome of the connection attempt.
type scanOp struct {
// port is the TCP port to probe.
port int
// open is set to true by scan when the dial succeeds.
open bool
// scanErr holds the text of the dial error; it stays empty when the dial succeeds.
scanErr string
// scanDuration is the time the dial attempt took.
scanDuration time.Duration
}
// gen is the source stage of the pipeline. It returns a channel, buffered to
// hold every requested port, on which one scanOp per port is delivered in
// order. The channel is closed once all ports have been emitted or as soon as
// done is signalled.
func gen(done <-chan struct{}, ports ...int) <-chan scanOp {
	queue := make(chan scanOp, len(ports))

	emit := func() {
		defer close(queue)
		for i := 0; i < len(ports); i++ {
			op := scanOp{port: ports[i]}
			select {
			case <-done:
				return
			case queue <- op:
			}
		}
	}

	go emit()
	return queue
}
// scan is a worker stage of the pipeline. For every scanOp received on in it
// dials 127.0.0.1 on the op's port over TCP, records how long the attempt took
// and whether it succeeded (open) or failed (scanErr), and forwards the op on
// the returned channel.
//
// The returned channel is closed when in is drained or done is signalled. Both
// the start of each iteration and the send of the result observe done, so the
// goroutine cannot stay blocked on a send after the consumer has gone away.
func scan(done <-chan struct{}, in <-chan scanOp) <-chan scanOp {
	out := make(chan scanOp)
	go func() {
		defer close(out)
		for op := range in {
			// Skip the dial entirely if the pipeline was already cancelled.
			select {
			case <-done:
				return
			default:
			}

			address := fmt.Sprintf("127.0.0.1:%d", op.port)
			start := time.Now()
			conn, err := net.Dial("tcp", address)
			op.scanDuration = time.Since(start)
			if err != nil {
				op.scanErr = err.Error()
			} else {
				// The connection was only a probe; a close error carries no
				// information about whether the port is open.
				_ = conn.Close()
				op.open = true
			}

			select {
			case out <- op:
			case <-done:
				return
			}
		}
	}()
	return out
}
// filterOpen forwards only the scanOps whose port was found open.
//
// The returned channel is closed when in is drained or done is signalled. The
// send is guarded by done so the goroutine cannot stay blocked after the
// consumer has stopped reading.
func filterOpen(done <-chan struct{}, in <-chan scanOp) <-chan scanOp {
	out := make(chan scanOp)
	go func() {
		defer close(out)
		for op := range in {
			// Stop promptly once the pipeline has been cancelled.
			select {
			case <-done:
				return
			default:
			}

			if !op.open {
				continue
			}

			select {
			case out <- op:
			case <-done:
				return
			}
		}
	}()
	return out
}
// filterErr forwards only the scanOps that failed with an error mentioning
// "too many open files".
//
// The returned channel is closed when in is drained or done is signalled. The
// send is guarded by done so the goroutine cannot stay blocked after the
// consumer has stopped reading.
func filterErr(done <-chan struct{}, in <-chan scanOp) <-chan scanOp {
	out := make(chan scanOp)
	go func() {
		defer close(out)
		for op := range in {
			// Stop promptly once the pipeline has been cancelled.
			select {
			case <-done:
				return
			default:
			}

			if op.open || !strings.Contains(op.scanErr, "too many open files") {
				continue
			}

			select {
			case out <- op:
			case <-done:
				return
			}
		}
	}()
	return out
}
// merge is the fan-in stage of the pipeline: everything received on any of
// chans is forwarded on the single returned channel. That channel is closed
// after every forwarding goroutine has finished, which happens when its input
// is drained or done is signalled.
func merge(done <-chan struct{}, chans ...<-chan scanOp) <-chan scanOp {
	var pending sync.WaitGroup
	merged := make(chan scanOp)

	// forward copies one input channel onto merged until it is drained or the
	// pipeline is cancelled.
	forward := func(src <-chan scanOp) {
		defer pending.Done()
		for op := range src {
			select {
			case <-done:
				return
			case merged <- op:
			}
		}
	}

	pending.Add(len(chans))
	for i := range chans {
		go forward(chans[i])
	}

	// Close merged only after every forwarder has returned.
	go func() {
		pending.Wait()
		close(merged)
	}()

	return merged
}
You have been tasked with writing a program that generates embeddings for text data and stores them in a PostgreSQL database table. You will make use of the following technologies:
- pgvector and `vector` extension support in a local PostgreSQL database running via a Docker container.
- Ollama for running LLM models locally (also running via Docker)
- Gorm as the ORM for migrations and data management
- Testcontainers for integration testing via locally running servers using Docker
Below is a working proof of concept. You'll need to refactor this code for testability using the same techniques you used in previous solutions. You'll also need to integrate Testcontainers in your tests.
package main
import (
"bufio"
"bytes"
"context"
"encoding/json"
"flag"
"log"
"net/http"
"os"
"strings"
"time"
"github.com/pgvector/pgvector-go"
"gorm.io/driver/postgres"
"gorm.io/gorm"
"gorm.io/gorm/clause"
)
// bookPath is the path of the book to embed, taken from the -book flag.
var bookPath string

// db is the shared GORM handle; init opens it and main uses it.
var db *gorm.DB

// init registers and parses the -book flag, connects to the local PostgreSQL
// database, enables the vector extension, migrates the book and bookEmbedding
// tables, and makes sure the HNSW index on book_embeddings.embedding exists.
// Any failure terminates the program.
func init() {
	flag.StringVar(&bookPath, "book", "", "path to the book ")
	// NOTE(review): parsing flags inside init runs before the testing package
	// has a chance to parse its own flags; consider moving this into main
	// when refactoring for testability.
	flag.Parse()

	const dsn = "postgres://postgres:password@localhost:5432/test?sslmode=disable"

	var err error
	db, err = gorm.Open(postgres.Open(dsn), &gorm.Config{})
	if err != nil {
		log.Fatalln("failed to connect database", err)
	}
	if err := db.Exec("CREATE EXTENSION IF NOT EXISTS vector").Error; err != nil {
		log.Fatalln("failed to create extension", err)
	}
	if err := db.AutoMigrate(&book{}, &bookEmbedding{}); err != nil {
		log.Fatalln("failed to migrate", err)
	}
	// Name the index and use IF NOT EXISTS: an unnamed CREATE INDEX adds one
	// more identical index on every start. The name is the one PostgreSQL
	// would presumably have auto-generated for the first unnamed index, so
	// databases created by earlier runs are recognised as well.
	if err := db.Exec("CREATE INDEX IF NOT EXISTS book_embeddings_embedding_idx ON book_embeddings USING hnsw (embedding vector_l2_ops)").Error; err != nil {
		log.Fatalln("failed to create index", err)
	}
}
// book is the GORM model for a book whose text is split into embedded chunks.
type book struct {
gorm.Model
Title string
Author string
// Embeddings holds the chunks belonging to this book; presumably GORM links
// them through bookEmbedding.BookID by convention (confirm against GORM docs).
Embeddings []bookEmbedding
}
// bookEmbedding is the GORM model for one chunk of book text together with
// the embedding vector generated for it.
type bookEmbedding struct {
gorm.Model
// BookID is the ID of the book the chunk was taken from.
BookID uint
// Text is the chunk of text the embedding was generated from.
Text string
// Embedding is stored in a vector(384) column.
Embedding pgvector.Vector `gorm:"type:vector(384)"`
}
// main reads the book at bookPath, splits it into overlapping chunks of
// words, stores an embedding for every chunk, and finally logs the text of
// the five stored chunks closest to a sample query sentence.
//
// Every failure is fatal: the error is logged and the program exits.
func main() {
	ctx := context.Background()
	log.Println("Start")
	if bookPath == "" {
		log.Fatalln("book path is required")
	}

	f, err := os.Open(bookPath)
	if err != nil {
		log.Fatalln(err)
	}
	defer f.Close()

	scanner := bufio.NewScanner(f)
	var text []string
	for scanner.Scan() {
		text = append(text, scanner.Text())
	}
	// Scan returns false on errors as well as at EOF; without this check a
	// read failure would silently truncate the book.
	if err := scanner.Err(); err != nil {
		log.Fatalln("failed to read book", err)
	}

	const chunkSize = 512
	const chunkOverlap = 128
	chunks := chunkWords(text, chunkSize, chunkOverlap)

	var b book
	if err := db.FirstOrCreate(&b, book{Title: "Meditations", Author: "Marcus Aurelius"}).Error; err != nil {
		log.Fatalln("failed to create book", err)
	}

	const endpoint = "http://localhost:11434/api/embeddings"
	const model = "all-minilm"
	httpClient := &http.Client{Timeout: 30 * time.Second}

	for _, chunk := range chunks {
		be := bookEmbedding{
			BookID:    b.ID,
			Text:      chunk,
			Embedding: pgvector.NewVector(mustFetchEmbedding(ctx, httpClient, endpoint, model, chunk)),
		}
		if err := db.Save(&be).Error; err != nil {
			log.Fatalln("failed to save book embedding", err)
		}
	}
	if err := db.Save(&b).Error; err != nil {
		log.Fatalln("failed to save book", err)
	}

	str := "How short lived the praiser and the praised, the one who remembers and the remembered."
	strEmbedding := mustFetchEmbedding(ctx, httpClient, endpoint, model, str)

	// Order the stored chunks by their distance to the query embedding and
	// keep the five closest.
	var bookEmbeddings []bookEmbedding
	err = db.Clauses(
		clause.OrderBy{
			Expression: clause.Expr{
				SQL: "embedding <-> ?",
				Vars: []interface{}{
					pgvector.NewVector(strEmbedding),
				},
			},
		},
	).Limit(5).Find(&bookEmbeddings).Error
	if err != nil {
		log.Fatalln("failed to query similar embeddings", err)
	}
	for _, be := range bookEmbeddings {
		log.Println(be.Text)
	}
	log.Println("Done")
}

// chunkWords splits lines into whitespace-separated words and groups them into
// chunks of size words, each joined by single spaces. Every chunk after the
// first starts with the last overlap words of the previous chunk. Words left
// over at the end form a final, shorter chunk. Callers must pass
// 0 <= overlap < size.
func chunkWords(lines []string, size, overlap int) []string {
	var chunks []string
	var current []string // words of the chunk being built
	for _, line := range lines {
		for _, word := range strings.Fields(line) {
			current = append(current, word)
			if len(current) < size {
				continue
			}
			chunks = append(chunks, strings.Join(current, " "))
			// Seed the next chunk with the tail of this one; copy it so the
			// next chunk does not share the previous backing array.
			current = append([]string(nil), current[len(current)-overlap:]...)
		}
	}
	// add the last chunk
	if len(current) > 0 {
		chunks = append(chunks, strings.Join(current, " "))
	}
	return chunks
}

// mustFetchEmbedding posts prompt to the embeddings endpoint for the given
// model and returns the embedding from the JSON response. Any failure,
// including a non-200 status, is fatal. The response body is closed before the
// function returns, so calling it in a loop does not accumulate open bodies.
func mustFetchEmbedding(ctx context.Context, client *http.Client, endpoint, model, prompt string) []float32 {
	type embeddingRequest struct {
		Model  string `json:"model"`
		Prompt string `json:"prompt"`
	}
	type embeddingResponse struct {
		Embedding []float32 `json:"embedding"`
	}

	bs, err := json.Marshal(embeddingRequest{
		Model:  model,
		Prompt: prompt,
	})
	if err != nil {
		log.Fatalln(err)
	}
	req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, bytes.NewReader(bs))
	if err != nil {
		log.Fatalln(err)
	}
	req.Header.Set("Content-Type", "application/json")

	res, err := client.Do(req)
	if err != nil {
		log.Fatalln(err)
	}
	defer res.Body.Close()
	if res.StatusCode != http.StatusOK {
		log.Fatalf("unexpected status code: %s", res.Status)
	}

	var response embeddingResponse
	if err := json.NewDecoder(res.Body).Decode(&response); err != nil {
		log.Fatalln(err)
	}
	return response.Embedding
}