Skip to content

Testing Concurrency

Johnny Boursiquot edited this page Jun 7, 2024 · 1 revision

At its simplest, testing code designed to run concurrently should be no different than testing non-concurrent code.

To illustrate the point, we'll walk through refactoring the solution to the worker pool exercise from the previous session.

package main

import (
	"errors"
	"flag"
	"fmt"
	"net"
	"os"
	"runtime"
	"sort"
	"strconv"
	"strings"
)

// Command-line configuration, populated by flag.Parse in main.
var (
	host       string
	ports      string
	numWorkers int
)

// init registers the command-line flags. The -workers default is computed at
// startup from runtime.NumCPU, so the usage text does not hard-code a number;
// flag.PrintDefaults appends the actual "(default N)" on its own.
func init() {
	flag.StringVar(&host, "host", "127.0.0.1", "Host to scan.")
	flag.StringVar(&ports, "ports", "80", "Port(s) (e.g. 80, 22-100).")
	flag.IntVar(&numWorkers, "workers", runtime.NumCPU(), "Number of concurrent workers.")
}

// main scans the ports given by -ports on -host using a pool of -workers
// goroutines and prints the open ports in ascending order.
func main() {
	flag.Parse()

	portsToScan, err := parsePortsToScan(ports)
	if err != nil {
		fmt.Fprintf(os.Stderr, "Failed to parse ports to scan: %s\n", err)
		os.Exit(1)
	}

	// A negative worker count would panic in make (negative buffer size), and
	// zero workers would leave nobody to drain portsChan, deadlocking main.
	if numWorkers < 1 {
		fmt.Fprintf(os.Stderr, "Invalid number of workers: %d\n", numWorkers)
		os.Exit(1)
	}

	portsChan := make(chan int, numWorkers)
	resultsChan := make(chan int)

	for i := 0; i < numWorkers; i++ {
		go worker(host, portsChan, resultsChan)
	}

	// Feed ports from a separate goroutine so main can drain results at the
	// same time. The sender closes portsChan, which lets the workers exit.
	go func() {
		defer close(portsChan)
		for _, p := range portsToScan {
			portsChan <- p
		}
	}()

	// Every port yields exactly one result: its number if open, 0 if not.
	// resultsChan is left open because only its senders (the workers) could
	// close it safely, and it is unreachable once main finishes.
	var openPorts []int
	for range portsToScan {
		if p := <-resultsChan; p != 0 {
			openPorts = append(openPorts, p)
		}
	}

	fmt.Println("RESULTS")
	sort.Ints(openPorts)
	for _, p := range openPorts {
		fmt.Printf("%d - open\n", p)
	}
}

// parsePortsToScan parses a port specification and returns the ports to scan.
// portsFlag is either a single port ("80") or an inclusive range ("22-100").
// Every port must lie in 1-65535, and a range's start must not exceed its end;
// otherwise an error is returned. A range is expanded in ascending order.
func parsePortsToScan(portsFlag string) ([]int, error) {
	// highestPort is the largest valid TCP port number (ports are 16-bit).
	const highestPort = 65535

	// checkPort rejects values outside the valid TCP port range.
	checkPort := func(p int) error {
		if p <= 0 {
			return errors.New("port numbers must be greater than 0")
		}
		if p > highestPort {
			return fmt.Errorf("port numbers must not exceed %d", highestPort)
		}
		return nil
	}

	// A bare integer is a single port. strconv.Atoi also accepts "0" and "-5",
	// so the value still needs range checking.
	if p, err := strconv.Atoi(portsFlag); err == nil {
		if err := checkPort(p); err != nil {
			return nil, err
		}
		return []int{p}, nil
	}

	bounds := strings.Split(portsFlag, "-")
	if len(bounds) != 2 {
		return nil, errors.New("unable to determine port(s) to scan")
	}

	minPort, err := strconv.Atoi(bounds[0])
	if err != nil {
		return nil, fmt.Errorf("failed to convert %s to a valid port number", bounds[0])
	}

	maxPort, err := strconv.Atoi(bounds[1])
	if err != nil {
		return nil, fmt.Errorf("failed to convert %s to a valid port number", bounds[1])
	}

	if err := checkPort(minPort); err != nil {
		return nil, err
	}
	if err := checkPort(maxPort); err != nil {
		return nil, err
	}
	// A reversed range would otherwise expand to nothing, and the caller would
	// silently scan zero ports.
	if minPort > maxPort {
		return nil, fmt.Errorf("invalid port range %d-%d: start is greater than end", minPort, maxPort)
	}

	results := make([]int, 0, maxPort-minPort+1)
	for p := minPort; p <= maxPort; p++ {
		results = append(results, p)
	}
	return results, nil
}

// worker dials host on each port received from portsChan until that channel
// is closed. For every port it sends exactly one value on resultsChan: the
// port number if a TCP connection succeeded, or 0 if it failed (after
// printing the failure).
func worker(host string, portsChan <-chan int, resultsChan chan<- int) {
	for p := range portsChan {
		// JoinHostPort brackets IPv6 literals ("::1" -> "[::1]:80"); plain
		// "%s:%d" formatting would yield "::1:80", which net.Dial rejects.
		address := net.JoinHostPort(host, strconv.Itoa(p))
		conn, err := net.Dial("tcp", address)
		if err != nil {
			fmt.Printf("%d CLOSED (%s)\n", p, err)
			resultsChan <- 0
			continue
		}
		_ = conn.Close() // the probe already succeeded; a close error changes nothing
		resultsChan <- p
	}
}

To make the program more testable, we need to refactor it by:

  • Decoupling logic from the main function.
  • Using interfaces to abstract dependencies.
  • Making functions pure where possible.
  • Injecting dependencies to facilitate mocking during tests.

testing/concurrent/scanner/scanner.go

package scanner

import (
	"fmt"
	"net"
	"runtime"
)

// Dialer opens network connections. *net.Dialer satisfies it; tests can
// substitute a fake to control which ports appear open.
type Dialer interface {
	Dial(network, address string) (net.Conn, error)
}

// DefaultNumWorkers is a suggested worker count for NewTCPScanner: the number
// of logical CPUs usable by the process, evaluated at package initialization.
var DefaultNumWorkers = runtime.NumCPU()

// TCPScanner probes TCP ports on a single host with a bounded pool of
// concurrent workers. Create one with NewTCPScanner.
type TCPScanner struct {
	host    string // host name or IP literal to probe
	workers int    // maximum number of concurrent dials; must be >= 1
	dialer  Dialer // opens the probe connections; must be non-nil
}

// validate reports whether s is configured well enough to scan.
func (s *TCPScanner) validate() error {
	if s.workers < 1 {
		return fmt.Errorf("invalid number of workers: %d", s.workers)
	}
	if s.dialer == nil {
		return fmt.Errorf("dialer is required")
	}
	return nil
}

// NewTCPScanner returns a scanner that probes host using up to workers
// concurrent dials made through dialer. If workers is less than 1 or dialer
// is nil, it returns a nil scanner and an error.
func NewTCPScanner(host string, workers int, dialer Dialer) (*TCPScanner, error) {
	s := &TCPScanner{host: host, workers: workers, dialer: dialer}
	if err := s.validate(); err != nil {
		return nil, err
	}
	return s, nil
}

// scanner describes the scanning behavior offered by TCPScanner.
type scanner interface {
	Scan(ports []int) ([]int, error)
}

// Compile-time check that *TCPScanner implements the scanner interface.
var _ scanner = (*TCPScanner)(nil)

// portResult is a worker's verdict for one port. Carrying the port together
// with an explicit flag avoids treating 0 as a "closed" sentinel, which would
// make port 0 impossible to report as open.
type portResult struct {
	port int
	open bool
}

// Scan probes every port in ports and returns those that accepted a TCP
// connection. The order of the returned ports is unspecified, and the slice is
// nil when no port is open. It returns an error only if s is misconfigured
// (for example a zero-value TCPScanner not built with NewTCPScanner).
func (s *TCPScanner) Scan(ports []int) ([]int, error) {
	// Without this check a zero-value scanner would start no workers and
	// block forever waiting for results.
	if err := s.validate(); err != nil {
		return nil, err
	}

	// Never start more workers than there are ports to probe.
	workers := s.workers
	if workers > len(ports) {
		workers = len(ports)
	}

	portsChan := make(chan int, workers)
	resultsChan := make(chan portResult)

	for i := 0; i < workers; i++ {
		go s.worker(portsChan, resultsChan)
	}

	// Feed the pool from a separate goroutine so the loop below can drain
	// results at the same time. Closing portsChan tells the workers to exit.
	go func() {
		defer close(portsChan)
		for _, p := range ports {
			portsChan <- p
		}
	}()

	// Exactly one result arrives per port, so counting receives tells us when
	// every probe has finished. resultsChan is left open because only its
	// senders (the workers) could close it safely.
	var openPorts []int
	for range ports {
		if r := <-resultsChan; r.open {
			openPorts = append(openPorts, r.port)
		}
	}
	return openPorts, nil
}

// worker probes each port received from portsChan until the channel is
// closed, sending exactly one portResult per port to resultsChan.
func (s *TCPScanner) worker(portsChan <-chan int, resultsChan chan<- portResult) {
	for p := range portsChan {
		resultsChan <- portResult{port: p, open: s.scan(p)}
	}
}

// scan reports whether a TCP connection to port on s.host can be established.
// A successful connection is closed immediately; only reachability matters.
func (s *TCPScanner) scan(port int) bool {
	// JoinHostPort brackets IPv6 literals ("::1" -> "[::1]:80"); plain
	// "%s:%d" formatting would produce an address Dial cannot parse.
	address := net.JoinHostPort(s.host, fmt.Sprint(port))
	conn, err := s.dialer.Dial("tcp", address)
	if err != nil {
		return false
	}
	_ = conn.Close() // the probe already succeeded; a close error changes nothing
	return true
}

testing/concurrent/scanner/scanner_test.go

package scanner_test

import (
	"errors"
	"fmt"
	"net"
	"testing"
	"time"

	"github.com/idiomat/dodtnyt/testing/concurrent/scanner"
)

// TestNewTCPScanner verifies that NewTCPScanner accepts a valid configuration
// and rejects one with fewer than one worker or without a dialer.
func TestNewTCPScanner(t *testing.T) {
	cases := map[string]struct {
		workers int
		dialer  scanner.Dialer
		wantErr bool
	}{
		"valid configuration":       {workers: 2, dialer: &MockDialer{}},
		"invalid number of workers": {workers: 0, dialer: &MockDialer{}, wantErr: true},
		"nil dialer":                {workers: 2, wantErr: true},
	}

	for name, tc := range cases {
		t.Run(name, func(t *testing.T) {
			_, err := scanner.NewTCPScanner("localhost", tc.workers, tc.dialer)
			if gotErr := err != nil; gotErr != tc.wantErr {
				t.Errorf("NewTCPScanner(%d, %v) error = %v, wantErr %v", tc.workers, tc.dialer, err, tc.wantErr)
			}
		})
	}
}

// TestTCPScanner_Scan checks that Scan reports exactly the ports the mock
// dialer treats as open, regardless of the order in which workers finish.
func TestTCPScanner_Scan(t *testing.T) {
	// Target the loopback IP literal that MockDialer recognizes. Scanning
	// "localhost" produces "localhost:<port>" addresses, which a mock keyed on
	// "127.0.0.1:<port>" would treat as closed, failing every open-port case.
	const host = "127.0.0.1"

	tests := map[string]struct {
		openPorts         map[int]bool
		portsToScan       []int
		expectedOpenPorts []int
	}{
		"mixed open and closed ports": {
			openPorts:         map[int]bool{80: true, 81: false, 82: true},
			portsToScan:       []int{80, 81, 82},
			expectedOpenPorts: []int{80, 82},
		},
		"all ports closed": {
			openPorts:         map[int]bool{80: false, 81: false, 82: false},
			portsToScan:       []int{80, 81, 82},
			expectedOpenPorts: []int{},
		},
		"all ports open": {
			openPorts:         map[int]bool{80: true, 81: true, 82: true},
			portsToScan:       []int{80, 81, 82},
			expectedOpenPorts: []int{80, 81, 82},
		},
		"no ports to scan": {
			openPorts:         map[int]bool{80: true, 81: false, 82: true},
			portsToScan:       []int{},
			expectedOpenPorts: []int{},
		},
	}

	for name, tt := range tests {
		t.Run(name, func(t *testing.T) {
			dialer := &MockDialer{openPorts: tt.openPorts}
			// Named s rather than scanner so the package name stays usable.
			s, err := scanner.NewTCPScanner(host, scanner.DefaultNumWorkers, dialer)
			if err != nil {
				t.Fatalf("failed to create scanner: %v", err)
			}

			openPorts, err := s.Scan(tt.portsToScan)
			if err != nil {
				t.Fatalf("TCPScanner.Scan() error = %v", err)
			}

			// Workers finish in any order, so compare as multisets.
			if !equal(openPorts, tt.expectedOpenPorts) {
				t.Errorf("TCPScanner.Scan() = %v, want %v", openPorts, tt.expectedOpenPorts)
			}
		})
	}
}

// equal reports whether a and b contain the same elements with the same
// multiplicities, ignoring order. A nil slice is equal to an empty one.
func equal(a, b []int) bool {
	if len(a) != len(b) {
		return false
	}

	// Tally every element of a, then spend the tally while walking b; any
	// element of b that drives its tally negative is unmatched.
	remaining := make(map[int]int, len(a))
	for _, x := range a {
		remaining[x]++
	}
	for _, x := range b {
		remaining[x]--
		if remaining[x] < 0 {
			return false
		}
	}

	// Equal lengths plus no overdrawn tally means every tally is exactly zero.
	return true
}

// MockConn is a no-op net.Conn for tests. Read returns (0, nil), Write claims
// to have written every byte, Close and the deadline setters succeed, and both
// addresses report 127.0.0.1.
type MockConn struct{}

// Compile-time check that *MockConn implements net.Conn.
var _ net.Conn = (*MockConn)(nil)

func (mc *MockConn) Read(b []byte) (n int, err error) {
	return 0, nil
}

func (mc *MockConn) Write(b []byte) (n int, err error) {
	return len(b), nil
}

func (mc *MockConn) Close() error {
	return nil
}

func (mc *MockConn) LocalAddr() net.Addr {
	return &net.IPAddr{IP: net.IPv4(127, 0, 0, 1)}
}

func (mc *MockConn) RemoteAddr() net.Addr {
	return &net.IPAddr{IP: net.IPv4(127, 0, 0, 1)}
}

func (mc *MockConn) SetDeadline(t time.Time) error {
	return nil
}

func (mc *MockConn) SetReadDeadline(t time.Time) error {
	return nil
}

func (mc *MockConn) SetWriteDeadline(t time.Time) error {
	return nil
}

// MockDialer is a fake dialer for tests. A dial succeeds when openPorts maps
// the port in the address to true, whatever the host part is.
type MockDialer struct {
	openPorts map[int]bool
}

// Dial returns a MockConn when the port in address is marked open and a
// "connection refused" error otherwise. address must be in "host:port" form
// (IPv6 hosts bracketed); a malformed address is reported as an error rather
// than silently treated as port 0.
func (m *MockDialer) Dial(network, address string) (net.Conn, error) {
	// SplitHostPort accepts any host, so the mock works for "localhost",
	// IPv4 and bracketed IPv6 targets alike.
	_, portStr, err := net.SplitHostPort(address)
	if err != nil {
		return nil, fmt.Errorf("mock dial %q: %w", address, err)
	}
	var port int
	if _, err := fmt.Sscan(portStr, &port); err != nil {
		return nil, fmt.Errorf("mock dial %q: invalid port: %w", address, err)
	}
	if m.openPorts[port] {
		return &MockConn{}, nil
	}
	return nil, errors.New("connection refused")
}

testing/concurrent/main.go

package main

import (
	"errors"
	"flag"
	"fmt"
	"net"
	"os"
	"runtime"
	"sort"
	"strconv"
	"strings"

	"github.com/idiomat/dodtnyt/testing/concurrent/scanner"
)

// Command-line configuration, populated by flag.Parse in main.
var (
	host       string
	ports      string
	numWorkers int
)

// init registers the command-line flags. The -workers default is computed at
// startup from runtime.NumCPU, so the usage text does not hard-code a number;
// flag.PrintDefaults appends the actual "(default N)" on its own.
func init() {
	flag.StringVar(&host, "host", "127.0.0.1", "Host to scan.")
	flag.StringVar(&ports, "ports", "80", "Port(s) (e.g. 80, 22-100).")
	flag.IntVar(&numWorkers, "workers", runtime.NumCPU(), "Number of concurrent workers.")
}

// main parses the command-line flags, scans the requested ports on -host with
// a scanner.TCPScanner, and prints the open ports in ascending order.
func main() {
	flag.Parse()

	portsToScan, err := parsePortsToScan(ports)
	if err != nil {
		fmt.Fprintf(os.Stderr, "failed to parse ports to scan: %s\n", err)
		os.Exit(1)
	}

	// NewTCPScanner takes the target host up front, so Scan needs only the
	// ports. A zero-value net.Dialer imposes no timeout of its own; probes of
	// filtered ports can block until the operating system gives up.
	tcpScanner, err := scanner.NewTCPScanner(host, numWorkers, &net.Dialer{})
	if err != nil {
		fmt.Fprintf(os.Stderr, "failed to create TCP scanner: %s\n", err)
		os.Exit(1)
	}

	openPorts, err := tcpScanner.Scan(portsToScan)
	if err != nil {
		fmt.Fprintf(os.Stderr, "failed to scan ports: %s\n", err)
		os.Exit(1)
	}

	// Scan returns ports in completion order; sort for stable output.
	fmt.Println("RESULTS")
	sort.Ints(openPorts)
	for _, p := range openPorts {
		fmt.Printf("%d - open\n", p)
	}
}

// parsePortsToScan parses a port specification and returns the ports to scan.
// portsFlag is either a single port ("80") or an inclusive range ("22-100").
// Every port must lie in 1-65535, and a range's start must not exceed its end;
// otherwise an error is returned. A range is expanded in ascending order.
func parsePortsToScan(portsFlag string) ([]int, error) {
	// highestPort is the largest valid TCP port number (ports are 16-bit).
	const highestPort = 65535

	// checkPort rejects values outside the valid TCP port range.
	checkPort := func(p int) error {
		if p <= 0 {
			return errors.New("port numbers must be greater than 0")
		}
		if p > highestPort {
			return fmt.Errorf("port numbers must not exceed %d", highestPort)
		}
		return nil
	}

	// A bare integer is a single port. strconv.Atoi also accepts "0" and "-5",
	// so the value still needs range checking.
	if p, err := strconv.Atoi(portsFlag); err == nil {
		if err := checkPort(p); err != nil {
			return nil, err
		}
		return []int{p}, nil
	}

	bounds := strings.Split(portsFlag, "-")
	if len(bounds) != 2 {
		return nil, errors.New("unable to determine port(s) to scan")
	}

	minPort, err := strconv.Atoi(bounds[0])
	if err != nil {
		return nil, fmt.Errorf("failed to convert %s to a valid port number", bounds[0])
	}

	maxPort, err := strconv.Atoi(bounds[1])
	if err != nil {
		return nil, fmt.Errorf("failed to convert %s to a valid port number", bounds[1])
	}

	if err := checkPort(minPort); err != nil {
		return nil, err
	}
	if err := checkPort(maxPort); err != nil {
		return nil, err
	}
	// A reversed range would otherwise expand to nothing, and the caller would
	// silently scan zero ports.
	if minPort > maxPort {
		return nil, fmt.Errorf("invalid port range %d-%d: start is greater than end", minPort, maxPort)
	}

	results := make([]int, 0, maxPort-minPort+1)
	for p := minPort; p <= maxPort; p++ {
		results = append(results, p)
	}
	return results, nil
}

Exercise 2

Now that you've seen how you can refactor concurrent code to be more testable, it's time to give it a shot by applying the same technique to the fan-out/fan-in exercise from the previous session. Head on over.

Clone this wiki locally