Skip to content

Exercises

Johnny Boursiquot edited this page Jun 7, 2024 · 2 revisions

Exercise 1: Refactoring with Interfaces

Consider the following code:

package mypackage

import (
	"context"
	"fmt"
	"os"

	"github.com/aws/aws-sdk-go/aws"
	"github.com/aws/aws-sdk-go/service/dynamodb"
	"github.com/aws/aws-sdk-go/service/dynamodb/dynamodbattribute"
)

// Person is the record main marshals and writes to DynamoDB.
type Person struct{ Name string }

var client *dynamodb.DynamoDB

func main() {
    client = dynamodb.New(aws.NewConfig().WithRegion("us-east-1")),

	ctx := context.Background()
	person := &Person{Name: "John Doe"}

    item, err := dynamodbattribute.MarshalMap(p)
	if err != nil {
        log.Fatalln("failed to marshal person for storage: %s", err)
	}

	input := &dynamodb.PutItemInput{
        Item:      item,
		TableName: aws.String(os.Getenv("TABLE_NAME")),
	}

	_, err := client.PutItemWithContext(ctx, input)
	if err != nil {
		fmt.Printf("Failed to save person: %v\n", err)
	}
}

Objectives

  1. Create an interface type to provide dependency injection
  2. Implicitly implement the functionality of an interface
  3. Leverage interfaces during testing by writing stubs/mocks

Starter Code

The following is a starter sample for how you might refactor. You don't have to use it.

package mypackage

import (
	"context"
	"fmt"
	"os"

	"github.com/aws/aws-sdk-go/aws"
	"github.com/aws/aws-sdk-go/aws/request"
	"github.com/aws/aws-sdk-go/service/dynamodb"
	"github.com/aws/aws-sdk-go/service/dynamodb/dynamodbattribute"
)
// Person is the value DynamoDBSaver.Save marshals and writes as one
// DynamoDB item; its only attribute is Name.
type Person struct{ Name string }

// ddbClient is the narrow slice of DynamoDB behavior that DynamoDBSaver
// depends on. Declaring it here, at the consumer, is what lets a stub be
// substituted for the real SDK client in tests (objective 3 above).
type ddbClient interface {
	// TODO: Add a method to the interface for putting an item into DynamoDB. Look at how the client is used in the `Save` method below.
	// Note the key takeaway here:
	// We don't need to mock out the entire DynamoDB client interface, just the methods we need.
	// We don't even need to use the interface AWS provides for the DynamoDB SDK; again, just the behavior we need.
}

// DynamoDBSaver writes values to DynamoDB through whichever ddbClient it
// holds, so the caller (or a test) decides which concrete client is used.
type DynamoDBSaver struct {
	Client ddbClient // injected dependency; see Save for how it is called
}

// Save marshals p into a DynamoDB attribute map and writes it with
// PutItemWithContext to the table named by the TABLE_NAME environment
// variable, using the client injected into s.
//
// A marshalling failure is wrapped with %w so callers can inspect the cause.
// The error from PutItemWithContext is returned as-is, so callers see the
// SDK's error value unchanged.
func (s *DynamoDBSaver) Save(ctx context.Context, p *Person) error {
	item, err := dynamodbattribute.MarshalMap(p)
	if err != nil {
		return fmt.Errorf("failed to marshal person for storage: %w", err)
	}

	input := &dynamodb.PutItemInput{
		Item:      item,
		TableName: aws.String(os.Getenv("TABLE_NAME")),
	}

	_, err = s.Client.PutItemWithContext(ctx, input)
	return err
}

Exercise 2: Refactor for testability

Your objective is to refactor the following code to improve its testability. Leverage interfaces for dependency injection where possible.

package main

import (
	"errors"
	"flag"
	"fmt"
	"net"
	"os"
	"runtime"
	"strconv"
	"strings"
	"sync"
	"time"
)

// Command-line configuration. Defaults are assigned when the flags are
// registered in init; user-supplied values arrive when main calls flag.Parse.
var (
	ports   string // a single port or an inclusive "low-high" range
	workers int    // how many scan stages main starts
)

// init registers the -ports and -workers flags on the default flag set.
func init() {
	defaultWorkers := runtime.NumCPU()
	flag.IntVar(&workers, "workers", defaultWorkers, "Number of workers (defaults to # of logical CPUs).")
	flag.StringVar(&ports, "ports", "5400-5500", "Port(s) (e.g. 80, 22-100).")
}

// main scans 127.0.0.1 on the ports selected by -ports using -workers
// concurrent scan stages, and prints the first result whose dial failed
// with "too many open files", if there is one.
func main() {
	flag.Parse()

	portsToScan, err := parsePortsToScan(ports)
	if err != nil {
		fmt.Printf("Failed to parse ports to scan: %s\n", err)
		os.Exit(1)
	}
	if workers < 1 {
		fmt.Printf("Invalid number of workers: %d (must be at least 1)\n", workers)
		os.Exit(1)
	}

	// The done channel will be shared by the entire pipeline
	// so that when it's closed it serves as a signal
	// for all the goroutines we started to exit.
	done := make(chan struct{})
	defer close(done)

	in := gen(done, portsToScan...)

	// fan-out
	chans := make([]<-chan scanOp, 0, workers)
	for i := 0; i < workers; i++ {
		chans = append(chans, scan(done, in))
	}

	// for s := range filterOpen(done, merge(done, chans...)) {
	// 	fmt.Printf("%#v\n", s)
	// }

	// Only the first matching result is wanted. Returning runs the deferred
	// close(done), which is the stop signal for every pipeline goroutine.
	// Never *send* on done: if every pipeline goroutine has already exited,
	// such a send has no receiver and blocks main forever.
	if s, ok := <-filterErr(done, merge(done, chans...)); ok {
		fmt.Printf("%#v\n", s)
	}
}

// parsePortsToScan turns the -ports flag value into the list of ports to
// scan. It accepts either a single port ("80") or an inclusive range
// ("22-100"). Every port must lie in 1..65535, and a range's start must not
// exceed its end; otherwise an error is returned instead of a silently
// empty or invalid list.
func parsePortsToScan(portsFlag string) ([]int, error) {
	const highestPort = 65535

	if p, err := strconv.Atoi(portsFlag); err == nil {
		if p <= 0 || p > highestPort {
			return nil, fmt.Errorf("port %d is out of range 1-%d", p, highestPort)
		}
		return []int{p}, nil
	}

	bounds := strings.Split(portsFlag, "-")
	if len(bounds) != 2 {
		return nil, errors.New("unable to determine port(s) to scan")
	}

	minPort, err := strconv.Atoi(bounds[0])
	if err != nil {
		return nil, fmt.Errorf("failed to convert %s to a valid port number", bounds[0])
	}

	maxPort, err := strconv.Atoi(bounds[1])
	if err != nil {
		return nil, fmt.Errorf("failed to convert %s to a valid port number", bounds[1])
	}

	if minPort <= 0 || maxPort <= 0 {
		return nil, errors.New("port numbers must be greater than 0")
	}
	if maxPort > highestPort {
		return nil, fmt.Errorf("port %d is out of range 1-%d", maxPort, highestPort)
	}
	if minPort > maxPort {
		return nil, fmt.Errorf("invalid port range %d-%d: start is greater than end", minPort, maxPort)
	}

	results := make([]int, 0, maxPort-minPort+1)
	for p := minPort; p <= maxPort; p++ {
		results = append(results, p)
	}
	return results, nil
}

// scanOp carries one port through the scan pipeline: gen fills in port, and
// scan records the outcome of dialing it. Field order is kept as-is because
// main prints values with %#v.
type scanOp struct {
	port         int           // TCP port on 127.0.0.1 to dial
	open         bool          // true when the dial succeeded
	scanErr      string        // dial error text; left empty when the dial succeeded
	scanDuration time.Duration // time spent in net.Dial
}

// gen is the pipeline's source stage. It emits one scanOp per requested port,
// in order, and closes its output afterwards. The output is buffered to hold
// every port, so the sending goroutine never waits on a reader. It also stops
// early if done is closed.
func gen(done <-chan struct{}, ports ...int) <-chan scanOp {
	ops := make(chan scanOp, len(ports))
	go func() {
		defer close(ops)
		for i := range ports {
			op := scanOp{port: ports[i]}
			select {
			case <-done:
				return
			case ops <- op:
			}
		}
	}()
	return ops
}

// scan is a pipeline worker. For each scanOp received from in, it:
//   - dials 127.0.0.1:<port> over TCP;
//   - records whether the connection succeeded, the dial error text (if any)
//     and how long the dial took;
//   - forwards the result on its output.
//
// The goroutine stops when in is drained or done is closed. That includes
// the case where it is blocked handing a result to a consumer that has
// stopped reading, so it cannot leak after cancellation.
func scan(done <-chan struct{}, in <-chan scanOp) <-chan scanOp {
	out := make(chan scanOp)
	go func() {
		defer close(out)
		for op := range in {
			// Skip the dial entirely once cancellation has been signalled.
			select {
			case <-done:
				return
			default:
			}

			address := fmt.Sprintf("127.0.0.1:%d", op.port)
			start := time.Now()
			conn, err := net.Dial("tcp", address)
			op.scanDuration = time.Since(start)
			if err != nil {
				op.scanErr = err.Error()
			} else {
				_ = conn.Close() // probe connection only; a close failure changes nothing
				op.open = true
			}

			// The hand-off must also watch done; a bare send would block
			// forever if the downstream reader has gone away.
			select {
			case out <- op:
			case <-done:
				return
			}
		}
	}()
	return out
}

// filterOpen forwards only the scan results whose port accepted a connection.
func filterOpen(done <-chan struct{}, in <-chan scanOp) <-chan scanOp {
	return filterScans(done, in, func(op scanOp) bool {
		return op.open
	})
}

// filterErr forwards only failed scans whose dial error mentions
// "too many open files". Presumably these reflect local file-descriptor
// exhaustion rather than the state of the port (confirm).
func filterErr(done <-chan struct{}, in <-chan scanOp) <-chan scanOp {
	return filterScans(done, in, func(op scanOp) bool {
		return !op.open && strings.Contains(op.scanErr, "too many open files")
	})
}

// filterScans starts a goroutine that copies every scanOp from in that
// satisfies keep onto the returned channel, closing it once in is drained.
// Closing done stops the goroutine, including while it is blocked handing a
// result to a reader that has gone away, so it cannot leak.
func filterScans(done <-chan struct{}, in <-chan scanOp, keep func(scanOp) bool) <-chan scanOp {
	out := make(chan scanOp)
	go func() {
		defer close(out)
		for op := range in {
			select {
			case <-done:
				return
			default:
			}
			if !keep(op) {
				continue
			}
			select {
			case out <- op:
			case <-done:
				return
			}
		}
	}()
	return out
}

// merge is the fan-in stage. It copies every value from each of chans onto a
// single unbuffered output channel. That channel is closed once every
// forwarder has finished, either because its input was drained or because
// done was closed.
func merge(done <-chan struct{}, chans ...<-chan scanOp) <-chan scanOp {
	merged := make(chan scanOp)
	var wg sync.WaitGroup

	forward := func(src <-chan scanOp) {
		defer wg.Done()
		for op := range src {
			select {
			case <-done:
				return
			case merged <- op:
			}
		}
	}

	wg.Add(len(chans))
	for _, src := range chans {
		go forward(src)
	}

	// Close only after every forwarder has returned, so no send can hit a
	// closed channel.
	go func() {
		wg.Wait()
		close(merged)
	}()

	return merged
}

Exercise 3: Integration Testing in an LLM world

You have been tasked with writing a program that generates embeddings for text data and stores them in a PostgreSQL database table. You will make use of the following technologies:

  1. pgvector, which provides the `vector` extension, in a local PostgreSQL database running in a Docker container.
  2. Ollama for running LLMs locally (also in a Docker container)
  3. Gorm as the ORM for migrations and data management
  4. Testcontainers for integration tests that start these servers in local Docker containers

Starter Code

Below is a working proof of concept. You'll need to refactor this code for testability using the same techniques you used in previous solutions. You'll also need to integrate Testcontainers in your tests.

package main

import (
	"bufio"
	"bytes"
	"context"
	"encoding/json"
	"flag"
	"fmt"
	"io"
	"log"
	"net/http"
	"os"
	"strings"
	"time"

	"github.com/pgvector/pgvector-go"
	"gorm.io/driver/postgres"
	"gorm.io/gorm"
	"gorm.io/gorm/clause"
)

// bookPath is the path of the text file to embed, set by the -book flag.
var bookPath string

// db is the shared GORM handle, opened and migrated in init.
var db *gorm.DB

// init does the following, exiting via log.Fatalln on any failure:
//   - parses flags;
//   - connects to the local PostgreSQL instance;
//   - enables the vector extension;
//   - migrates the book tables;
//   - ensures the HNSW index on book_embeddings.embedding exists.
//
// NOTE(review): calling flag.Parse and opening the database in init is
// known to get in the way of `go test` (test flags, required live DB).
// Moving this setup into main is part of the refactor this exercise asks for.
func init() {
	flag.StringVar(&bookPath, "book", "", "path to the book ")
	flag.Parse()

	const dsn = "postgres://postgres:password@localhost:5432/test?sslmode=disable"

	var err error
	db, err = gorm.Open(postgres.Open(dsn), &gorm.Config{})
	if err != nil {
		log.Fatalln("failed to connect database", err)
	}

	if err := db.Exec("CREATE EXTENSION IF NOT EXISTS vector").Error; err != nil {
		log.Fatalln("failed to create extension", err)
	}

	if err := db.AutoMigrate(&book{}, &bookEmbedding{}); err != nil {
		log.Fatalln("failed to migrate", err)
	}

	// Name the index and use IF NOT EXISTS. An unnamed CREATE INDEX is given
	// a fresh auto-generated name each time, so every start-up would add
	// another duplicate HNSW index to maintain on each insert.
	if err := db.Exec("CREATE INDEX IF NOT EXISTS book_embeddings_embedding_idx ON book_embeddings USING hnsw (embedding vector_l2_ops)").Error; err != nil {
		log.Fatalln("failed to create index", err)
	}
}

// book is the GORM model for a book record; init's AutoMigrate creates its
// table. Embeddings is meant to hold the book's text chunks. Presumably they
// are associated through GORM's has-many convention on bookEmbedding.BookID
// (confirm).
type book struct {
	gorm.Model
	Title      string
	Author     string
	Embeddings []bookEmbedding
}

// bookEmbedding stores one chunk of a book's text together with the
// embedding vector computed for it.
type bookEmbedding struct {
	gorm.Model
	BookID    uint            // ID of the book this chunk belongs to
	Text      string          // the chunk of text that was embedded
	Embedding pgvector.Vector `gorm:"type:vector(384)"` // vector(384) column: the embedding model's output length must match
}

// main embeds a book and runs a similarity query against it:
//   - reads the book at -book and splits it into overlapping word chunks;
//   - embeds each chunk via the local Ollama embeddings endpoint;
//   - stores each chunk and its vector in PostgreSQL;
//   - logs the five stored chunks nearest to a sample quotation.
//
// Any failure terminates the program via log.Fatal*.
//
// NOTE(review): re-running against the same database inserts another copy of
// every chunk for the book; consider clearing or de-duplicating first.
func main() {
	ctx := context.Background()

	log.Println("Start")

	if bookPath == "" {
		log.Fatalln("book path is required")
	}

	words, err := readWords(bookPath)
	if err != nil {
		log.Fatalln(err)
	}

	const chunkSize = 512
	const chunkOverlap = 128
	chunks := chunkWords(words, chunkSize, chunkOverlap)

	var b book
	if err := db.FirstOrCreate(&b, book{Title: "Meditations", Author: "Marcus Aurelius"}).Error; err != nil {
		log.Fatalln("failed to create book", err)
	}

	// A single client is reused for every request so connections can be pooled.
	httpClient := &http.Client{Timeout: 30 * time.Second}
	const endpoint = "http://localhost:11434/api/embeddings"
	const model = "all-minilm"

	for _, chunk := range chunks {
		embedding, err := fetchEmbedding(ctx, httpClient, endpoint, model, chunk)
		if err != nil {
			log.Fatalln(err)
		}

		be := bookEmbedding{
			BookID:    b.ID,
			Text:      chunk,
			Embedding: pgvector.NewVector(embedding),
		}
		if err := db.Save(&be).Error; err != nil {
			log.Fatalln("failed to save book embedding", err)
		}
	}

	if err := db.Save(&b).Error; err != nil {
		log.Fatalln("failed to save book", err)
	}

	query := "How short lived the praiser and the praised, the one who remembers and the remembered."
	queryEmbedding, err := fetchEmbedding(ctx, httpClient, endpoint, model, query)
	if err != nil {
		log.Fatalln(err)
	}

	// Order by pgvector's <-> distance operator (L2, matching the
	// vector_l2_ops index created in init) and keep the five closest chunks.
	var bookEmbeddings []bookEmbedding
	err = db.Clauses(
		clause.OrderBy{
			Expression: clause.Expr{
				SQL: "embedding <-> ?",
				Vars: []interface{}{
					pgvector.NewVector(queryEmbedding),
				},
			},
		},
	).Limit(5).Find(&bookEmbeddings).Error
	if err != nil {
		log.Fatalln("failed to query nearest book embeddings", err)
	}

	for _, be := range bookEmbeddings {
		log.Println(be.Text)
	}

	log.Println("Done")
}

// readWords returns the whitespace-separated words of the file at path, in
// order. It scans word by word rather than line by line, so a very long line
// cannot exceed bufio.Scanner's default token-size limit. Scanner errors are
// returned rather than silently truncating the input.
func readWords(path string) ([]string, error) {
	f, err := os.Open(path)
	if err != nil {
		return nil, err
	}
	defer f.Close()

	scanner := bufio.NewScanner(f)
	scanner.Split(bufio.ScanWords)

	var words []string
	for scanner.Scan() {
		words = append(words, scanner.Text())
	}
	if err := scanner.Err(); err != nil {
		return nil, fmt.Errorf("reading %s: %w", path, err)
	}
	return words, nil
}

// chunkWords groups words into chunks of at most size words and joins each
// chunk's words with single spaces. Every chunk after the first begins with
// the last overlap words of the previous chunk. The final chunk always ends
// at the last word, and no trailing chunk consisting solely of
// already-covered overlap words is emitted. It panics unless
// 0 <= overlap < size, since otherwise the window could never advance.
func chunkWords(words []string, size, overlap int) []string {
	if size <= 0 || overlap < 0 || overlap >= size {
		panic("chunkWords: need 0 <= overlap < size")
	}

	var chunks []string
	for start := 0; start < len(words); {
		end := start + size
		if end > len(words) {
			end = len(words)
		}
		chunks = append(chunks, strings.Join(words[start:end], " "))
		if end == len(words) {
			break
		}
		start = end - overlap
	}
	return chunks
}

// fetchEmbedding POSTs text to the embeddings endpoint using model and
// returns the resulting vector. The response body is drained and closed
// before returning, so each request's connection is released immediately
// rather than when main exits.
func fetchEmbedding(ctx context.Context, client *http.Client, endpoint, model, text string) ([]float32, error) {
	type embeddingRequest struct {
		Model  string `json:"model"`
		Prompt string `json:"prompt"`
	}

	type embeddingResponse struct {
		Embedding []float32 `json:"embedding"`
	}

	bs, err := json.Marshal(embeddingRequest{
		Model:  model,
		Prompt: text,
	})
	if err != nil {
		return nil, err
	}

	req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, bytes.NewReader(bs))
	if err != nil {
		return nil, err
	}
	req.Header.Set("Content-Type", "application/json")

	res, err := client.Do(req)
	if err != nil {
		return nil, err
	}
	defer func() {
		// Read whatever the decoder left behind so the connection can be reused.
		_, _ = io.Copy(io.Discard, res.Body)
		res.Body.Close()
	}()

	if res.StatusCode != http.StatusOK {
		return nil, fmt.Errorf("unexpected status code: %s", res.Status)
	}

	var response embeddingResponse
	if err := json.NewDecoder(res.Body).Decode(&response); err != nil {
		return nil, err
	}
	return response.Embedding, nil
}