-
Notifications
You must be signed in to change notification settings - Fork 2
Testing Concurrency
Johnny Boursiquot edited this page Jun 7, 2024
·
1 revision
At its simplest, testing code designed to run concurrently should be no different than testing non-concurrent code.
To best illustrate the point, we'll talk through the refactoring of the solution for the worker pool exercise from the previous session.
package main
import (
"errors"
"flag"
"fmt"
"net"
"os"
"runtime"
"sort"
"strconv"
"strings"
)
// Command-line configuration, populated from the flags registered in init
// once main calls flag.Parse.
var (
	host       string
	ports      string
	numWorkers int
)

func init() {
	flag.StringVar(&host, "host", "127.0.0.1", "Host to scan.")
	flag.StringVar(&ports, "ports", "80", "Port(s) (e.g. 80, 22-100).")
	// The default is runtime.NumCPU(), not a fixed number, so the usage text
	// must not hard-code one (flag.PrintDefaults also appends "(default N)").
	flag.IntVar(&numWorkers, "workers", runtime.NumCPU(), "Number of concurrent workers (defaults to the number of logical CPUs).")
}
// main parses the flags, scans the requested ports on host with a pool of
// numWorkers goroutines, and prints the open ports in ascending order.
func main() {
	flag.Parse()

	portsToScan, err := parsePortsToScan(ports)
	if err != nil {
		fmt.Printf("Failed to parse ports to scan: %s\n", err)
		os.Exit(1)
	}

	// With zero workers nothing would ever receive from portsChan and the
	// program would deadlock; a negative count would make make() panic.
	if numWorkers < 1 {
		fmt.Printf("invalid number of workers: %d\n", numWorkers)
		os.Exit(1)
	}

	portsChan := make(chan int, numWorkers)
	resultsChan := make(chan int)

	for i := 0; i < numWorkers; i++ {
		go worker(host, portsChan, resultsChan)
	}

	// Feed the ports from a separate goroutine so the loop below can drain
	// results concurrently. Closing portsChan ends the workers' range loops.
	go func() {
		defer close(portsChan)
		for _, p := range portsToScan {
			portsChan <- p
		}
	}()

	// Each port produces exactly one result (0 when closed), so receiving
	// len(portsToScan) values collects everything. resultsChan is deliberately
	// left open: only the sending side should close a channel.
	var openPorts []int
	for range portsToScan {
		if p := <-resultsChan; p != 0 { // non-zero port means it's open
			openPorts = append(openPorts, p)
		}
	}

	fmt.Println("RESULTS")
	sort.Ints(openPorts)
	for _, p := range openPorts {
		fmt.Printf("%d - open\n", p)
	}
}
// parsePortsToScan parses the -ports flag value into the list of ports to scan.
// It accepts either a single port ("80") or an inclusive range ("22-100").
// Every port must be in 1-65535, and a range's start must not exceed its end.
func parsePortsToScan(portsFlag string) ([]int, error) {
	// highestPort is the largest valid TCP port number (ports are 16-bit).
	const highestPort = 65535

	p, err := strconv.Atoi(portsFlag)
	if err == nil {
		// A single port gets the same bounds checks as a range; otherwise
		// values such as "0" or "-5" would be handed to the dialer.
		if p <= 0 {
			return nil, fmt.Errorf("port numbers must be greater than 0")
		}
		if p > highestPort {
			return nil, fmt.Errorf("port numbers must not exceed %d", highestPort)
		}
		return []int{p}, nil
	}

	bounds := strings.Split(portsFlag, "-")
	if len(bounds) != 2 {
		return nil, errors.New("unable to determine port(s) to scan")
	}
	minPort, err := strconv.Atoi(bounds[0])
	if err != nil {
		return nil, fmt.Errorf("failed to convert %s to a valid port number", bounds[0])
	}
	maxPort, err := strconv.Atoi(bounds[1])
	if err != nil {
		return nil, fmt.Errorf("failed to convert %s to a valid port number", bounds[1])
	}
	if minPort <= 0 || maxPort <= 0 {
		return nil, fmt.Errorf("port numbers must be greater than 0")
	}
	if minPort > highestPort || maxPort > highestPort {
		return nil, fmt.Errorf("port numbers must not exceed %d", highestPort)
	}
	// An inverted range would otherwise produce an empty list and a
	// "successful" scan of nothing.
	if minPort > maxPort {
		return nil, fmt.Errorf("invalid port range %d-%d: start is greater than end", minPort, maxPort)
	}

	results := make([]int, 0, maxPort-minPort+1)
	for port := minPort; port <= maxPort; port++ {
		results = append(results, port)
	}
	return results, nil
}
// worker drains portsChan, attempting a TCP connection to host:port for each
// port it receives. A successful dial is reported by sending the port number
// on resultsChan; a failed dial is logged to stdout and reported as 0. The
// worker returns once portsChan is closed and empty.
func worker(host string, portsChan <-chan int, resultsChan chan<- int) {
	for {
		port, ok := <-portsChan
		if !ok {
			return
		}
		target := fmt.Sprintf("%s:%d", host, port)
		c, dialErr := net.Dial("tcp", target)
		if dialErr == nil {
			c.Close()
			resultsChan <- port
			continue
		}
		fmt.Printf("%d CLOSED (%s)\n", port, dialErr)
		resultsChan <- 0
	}
}
To make the program more testable, we need to refactor it by:
- Decoupling logic from the main function.
- Using interfaces to abstract dependencies.
- Making functions pure where possible.
- Injecting dependencies to facilitate mocking during tests.
testing/concurrent/scanner/scanner.go
package scanner
import (
"fmt"
"net"
"runtime"
)
// Dialer abstracts opening a network connection, so a TCPScanner can be
// driven by a real *net.Dialer or by a fake in tests.
type Dialer interface {
	Dial(network, address string) (net.Conn, error)
}

// DefaultNumWorkers is the number of logical CPUs reported by runtime.NumCPU
// at program start, offered as a default worker count.
var DefaultNumWorkers = runtime.NumCPU()

// TCPScanner reports which TCP ports on a single host accept connections,
// dialing them concurrently from a fixed number of worker goroutines.
// Construct it with NewTCPScanner; a zero-value TCPScanner is rejected by Scan.
type TCPScanner struct {
	host    string
	workers int
	dialer  Dialer
}

// validate returns an error if the scanner has fewer than one worker or no
// dialer.
func (s *TCPScanner) validate() error {
	if s.workers < 1 {
		return fmt.Errorf("invalid number of workers: %d", s.workers)
	}
	if s.dialer == nil {
		return fmt.Errorf("dialer is required")
	}
	return nil
}

// NewTCPScanner returns a scanner that dials host through dialer using the
// given number of workers. If workers is less than 1 or dialer is nil, it
// returns a nil scanner and a non-nil error.
func NewTCPScanner(host string, workers int, dialer Dialer) (*TCPScanner, error) {
	s := &TCPScanner{host: host, workers: workers, dialer: dialer}
	if err := s.validate(); err != nil {
		// Don't hand back a half-valid scanner alongside the error.
		return nil, err
	}
	return s, nil
}

type scanner interface {
	Scan(ports []int) ([]int, error)
}

// Compile-time check that *TCPScanner implements the scanner interface.
var _ scanner = (*TCPScanner)(nil)

// Scan dials every port in ports and returns those that accepted a connection.
// The order of the returned ports depends on worker scheduling and is not
// specified. A nil slice means no port was open. Port 0 is used internally as
// the "closed" marker and is therefore never reported.
//
// Scan returns an error, without dialing anything, if the scanner's
// configuration is invalid (for example a zero-value TCPScanner).
func (s *TCPScanner) Scan(ports []int) ([]int, error) {
	// Without this check, zero workers would leave the feeder goroutine and the
	// collection loop below blocked forever.
	if err := s.validate(); err != nil {
		return nil, err
	}

	portsChan := make(chan int, s.workers)
	resultsChan := make(chan int)

	for i := 0; i < s.workers; i++ {
		go s.worker(portsChan, resultsChan)
	}

	// Feed ports from a separate goroutine so sending and collecting overlap.
	// Closing portsChan lets each worker's range loop finish.
	go func() {
		defer close(portsChan)
		for _, p := range ports {
			portsChan <- p
		}
	}()

	// Every port yields exactly one result, so after len(ports) receives no
	// worker has anything left to send. resultsChan is left open: only senders
	// should close a channel, and nothing else reads from it.
	var openPorts []int
	for range ports {
		if p := <-resultsChan; p != 0 {
			openPorts = append(openPorts, p)
		}
	}
	return openPorts, nil
}

// worker reads ports until portsChan is closed, sending each port on
// resultsChan when it is open and 0 when it is not.
func (s *TCPScanner) worker(portsChan <-chan int, resultsChan chan<- int) {
	for p := range portsChan {
		if s.scan(p) {
			resultsChan <- p
		} else {
			resultsChan <- 0
		}
	}
}

// scan reports whether a TCP dial to s.host:port succeeds. The connection is
// closed immediately; the dial error, if any, is discarded.
func (s *TCPScanner) scan(port int) bool {
	address := fmt.Sprintf("%s:%d", s.host, port)
	conn, err := s.dialer.Dial("tcp", address)
	if err != nil {
		return false
	}
	conn.Close()
	return true
}
testing/concurrent/scanner/scanner_test.go
package scanner_test
import (
"errors"
"fmt"
"net"
"testing"
"time"
"github.com/idiomat/dodtnyt/testing/concurrent/scanner"
)
// TestNewTCPScanner checks that the constructor accepts a usable configuration
// and rejects a non-positive worker count or a missing dialer.
func TestNewTCPScanner(t *testing.T) {
	cases := map[string]struct {
		workers int
		dialer  scanner.Dialer
		wantErr bool
	}{
		"valid configuration":       {workers: 2, dialer: &MockDialer{}},
		"invalid number of workers": {workers: 0, dialer: &MockDialer{}, wantErr: true},
		"nil dialer":                {workers: 2, dialer: nil, wantErr: true},
	}

	for name, tc := range cases {
		t.Run(name, func(t *testing.T) {
			_, err := scanner.NewTCPScanner("localhost", tc.workers, tc.dialer)
			if gotErr := err != nil; gotErr != tc.wantErr {
				t.Errorf("NewTCPScanner(%d, %v) error = %v, wantErr %v", tc.workers, tc.dialer, err, tc.wantErr)
			}
		})
	}
}
// TestTCPScanner_Scan runs Scan against a MockDialer with a known set of open
// ports and checks that exactly those ports are reported, in any order.
func TestTCPScanner_Scan(t *testing.T) {
	tests := map[string]struct {
		openPorts         map[int]bool
		portsToScan       []int
		expectedOpenPorts []int
	}{
		"mixed open and closed ports": {
			openPorts:         map[int]bool{80: true, 81: false, 82: true},
			portsToScan:       []int{80, 81, 82},
			expectedOpenPorts: []int{80, 82},
		},
		"all ports closed": {
			openPorts:         map[int]bool{80: false, 81: false, 82: false},
			portsToScan:       []int{80, 81, 82},
			expectedOpenPorts: []int{},
		},
		"all ports open": {
			openPorts:         map[int]bool{80: true, 81: true, 82: true},
			portsToScan:       []int{80, 81, 82},
			expectedOpenPorts: []int{80, 81, 82},
		},
		"no ports to scan": {
			openPorts:         map[int]bool{80: true, 81: false, 82: true},
			portsToScan:       []int{},
			expectedOpenPorts: []int{},
		},
	}
	for name, tt := range tests {
		t.Run(name, func(t *testing.T) {
			mockDialer := &MockDialer{
				openPorts: tt.openPorts,
			}
			// Scan 127.0.0.1 so the dialed addresses have the form
			// "127.0.0.1:<port>", which MockDialer recognizes. The result is
			// named s rather than scanner so it doesn't shadow the package.
			s, err := scanner.NewTCPScanner("127.0.0.1", scanner.DefaultNumWorkers, mockDialer)
			if err != nil {
				t.Fatalf("failed to create scanner: %v", err)
			}
			openPorts, err := s.Scan(tt.portsToScan)
			if err != nil {
				// Comparing results after a failed scan would only add noise.
				t.Fatalf("TCPScanner.Scan() error = %v", err)
			}
			// Workers finish in arbitrary order, so compare as multisets.
			if !equal(openPorts, tt.expectedOpenPorts) {
				t.Errorf("TCPScanner.Scan() = %v, want %v", openPorts, tt.expectedOpenPorts)
			}
		})
	}
}
// equal reports whether a and b contain the same elements with the same
// multiplicities, ignoring order (so [1 2 2] equals [2 1 2]). Nil and empty
// slices compare equal.
func equal(a, b []int) bool {
	if len(a) != len(b) {
		return false
	}
	// Tally +1 for each element of a and -1 for each element of b. With equal
	// lengths, the slices are permutations of each other exactly when every
	// tally returns to zero.
	balance := make(map[int]int, len(a))
	for i := range a {
		balance[a[i]]++
		balance[b[i]]--
	}
	for _, n := range balance {
		if n != 0 {
			return false
		}
	}
	return true
}
// MockConn is a stand-in net.Conn with no I/O behind it. Read reports zero
// bytes, Write reports the whole buffer as written, both addresses are the
// IPv4 loopback address, and every other method succeeds.
type MockConn struct{}

// mockLoopback builds the address reported for both ends of a MockConn.
func mockLoopback() net.Addr {
	return &net.IPAddr{IP: net.IPv4(127, 0, 0, 1)}
}

func (*MockConn) Read(b []byte) (int, error) { return 0, nil }

func (*MockConn) Write(b []byte) (int, error) { return len(b), nil }

func (*MockConn) Close() error { return nil }

func (*MockConn) LocalAddr() net.Addr { return mockLoopback() }

func (*MockConn) RemoteAddr() net.Addr { return mockLoopback() }

func (*MockConn) SetDeadline(t time.Time) error { return nil }

func (*MockConn) SetReadDeadline(t time.Time) error { return nil }

func (*MockConn) SetWriteDeadline(t time.Time) error { return nil }
// MockDialer is a fake scanner.Dialer for tests. Dial succeeds only for ports
// mapped to true in openPorts; every other port is refused.
type MockDialer struct {
	openPorts map[int]bool
}

// Dial ignores network and the host part of address. It returns a *MockConn
// when the port in address is marked open, and a "connection refused" error
// otherwise, including when address contains no parseable port.
func (m *MockDialer) Dial(network, address string) (net.Conn, error) {
	// Split host and port generically instead of matching a hard-coded
	// "127.0.0.1:%d" pattern. With that pattern, a scanner built for any other
	// host (such as "localhost") leaves port at 0, so every port looks closed.
	_, portStr, err := net.SplitHostPort(address)
	if err != nil {
		return nil, errors.New("connection refused")
	}
	var port int
	if _, err := fmt.Sscanf(portStr, "%d", &port); err != nil {
		return nil, errors.New("connection refused")
	}
	if m.openPorts[port] {
		return &MockConn{}, nil
	}
	return nil, errors.New("connection refused")
}
testing/concurrent/main.go
package main
import (
"errors"
"flag"
"fmt"
"net"
"os"
"runtime"
"sort"
"strconv"
"strings"
"github.com/idiomat/dodtnyt/testing/concurrent/scanner"
)
// Command-line configuration, populated from the flags registered in init
// once main calls flag.Parse.
var (
	host       string
	ports      string
	numWorkers int
)

func init() {
	flag.StringVar(&host, "host", "127.0.0.1", "Host to scan.")
	flag.StringVar(&ports, "ports", "80", "Port(s) (e.g. 80, 22-100).")
	// The default is runtime.NumCPU(), not a fixed number, so the usage text
	// must not hard-code one (flag.PrintDefaults also appends "(default N)").
	flag.IntVar(&numWorkers, "workers", runtime.NumCPU(), "Number of concurrent workers (defaults to the number of logical CPUs).")
}
// main parses the command-line flags, scans the requested ports on the target
// host with a scanner.TCPScanner, and prints the open ports in ascending order.
func main() {
	flag.Parse()

	portsToScan, err := parsePortsToScan(ports)
	if err != nil {
		fmt.Printf("failed to parse ports to scan: %s\n", err)
		os.Exit(1)
	}

	// NewTCPScanner takes the target host at construction time, so Scan takes
	// only the ports. *net.Dialer satisfies scanner.Dialer through its Dial
	// method.
	tcpScanner, err := scanner.NewTCPScanner(host, numWorkers, &net.Dialer{})
	if err != nil {
		fmt.Printf("failed to create TCP scanner: %s\n", err)
		os.Exit(1)
	}

	openPorts, err := tcpScanner.Scan(portsToScan)
	if err != nil {
		fmt.Printf("failed to scan ports: %s\n", err)
		os.Exit(1)
	}

	fmt.Println("RESULTS")
	// Scan's workers report results in arbitrary order; sort for stable output.
	sort.Ints(openPorts)
	for _, p := range openPorts {
		fmt.Printf("%d - open\n", p)
	}
}
// parsePortsToScan parses the -ports flag value into the list of ports to scan.
// It accepts either a single port ("80") or an inclusive range ("22-100").
// Every port must be in 1-65535, and a range's start must not exceed its end.
func parsePortsToScan(portsFlag string) ([]int, error) {
	// highestPort is the largest valid TCP port number (ports are 16-bit).
	const highestPort = 65535

	p, err := strconv.Atoi(portsFlag)
	if err == nil {
		// A single port gets the same bounds checks as a range; otherwise
		// values such as "0" or "-5" would be handed to the dialer.
		if p <= 0 {
			return nil, fmt.Errorf("port numbers must be greater than 0")
		}
		if p > highestPort {
			return nil, fmt.Errorf("port numbers must not exceed %d", highestPort)
		}
		return []int{p}, nil
	}

	bounds := strings.Split(portsFlag, "-")
	if len(bounds) != 2 {
		return nil, errors.New("unable to determine port(s) to scan")
	}
	minPort, err := strconv.Atoi(bounds[0])
	if err != nil {
		return nil, fmt.Errorf("failed to convert %s to a valid port number", bounds[0])
	}
	maxPort, err := strconv.Atoi(bounds[1])
	if err != nil {
		return nil, fmt.Errorf("failed to convert %s to a valid port number", bounds[1])
	}
	if minPort <= 0 || maxPort <= 0 {
		return nil, fmt.Errorf("port numbers must be greater than 0")
	}
	if minPort > highestPort || maxPort > highestPort {
		return nil, fmt.Errorf("port numbers must not exceed %d", highestPort)
	}
	// An inverted range would otherwise produce an empty list and a
	// "successful" scan of nothing.
	if minPort > maxPort {
		return nil, fmt.Errorf("invalid port range %d-%d: start is greater than end", minPort, maxPort)
	}

	results := make([]int, 0, maxPort-minPort+1)
	for port := minPort; port <= maxPort; port++ {
		results = append(results, port)
	}
	return results, nil
}
Now that you've seen how to refactor concurrent code to make it more testable, try applying the same technique to the fan-out/fan-in exercise from the previous session. Head over to that exercise to get started.