Skip to content

Commit

Permalink
Maint: do not mock os.OpenFile
Browse files Browse the repository at this point in the history
  • Loading branch information
qdm12 committed Jul 23, 2021
1 parent f7e8b9d commit 841eddc
Show file tree
Hide file tree
Showing 7 changed files with 100 additions and 176 deletions.
9 changes: 3 additions & 6 deletions cmd/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@ import (
"github.com/qdm12/dns/pkg/nameserver"
"github.com/qdm12/dns/pkg/unbound"
"github.com/qdm12/golibs/logging"
customOS "github.com/qdm12/golibs/os"
"github.com/qdm12/updated/pkg/dnscrypto"
)

Expand All @@ -43,11 +42,10 @@ func main() {
args := os.Args
logger := logging.NewParent(logging.Settings{})
configReader := config.NewReader(logger)
osIntf := customOS.New()

errorCh := make(chan error)
go func() {
errorCh <- _main(ctx, buildInfo, args, logger, configReader, osIntf)
errorCh <- _main(ctx, buildInfo, args, logger, configReader)
}()

select {
Expand Down Expand Up @@ -78,8 +76,7 @@ func main() {
}

func _main(ctx context.Context, buildInfo models.BuildInformation,
args []string, logger logging.ParentLogger, configReader config.Reader,
os customOS.OS) error {
args []string, logger logging.ParentLogger, configReader config.Reader) error {
if health.IsClientMode(args) {
// Running the program in a separate instance through the Docker
// built-in healthcheck, in an ephemeral fashion to query the
Expand All @@ -99,7 +96,7 @@ func _main(ctx context.Context, buildInfo models.BuildInformation,
const unboundEtcDir = "/unbound"
const unboundPath = "/unbound/unbound"
const cacertsPath = "/unbound/ca-certificates.crt"
dnsConf := unbound.NewConfigurator(logger, os.OpenFile, dnsCrypto, unboundEtcDir, unboundPath, cacertsPath)
dnsConf := unbound.NewConfigurator(logger, dnsCrypto, unboundEtcDir, unboundPath, cacertsPath)

if len(args) > 1 && args[1] == "build" {
return dnsConf.SetupFiles(ctx)
Expand Down
16 changes: 9 additions & 7 deletions pkg/nameserver/nameserver.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,8 @@ import (
"context"
"io/ioutil"
"net"
"os"
"strings"

"github.com/qdm12/golibs/os"
)

// UseDNSInternally is to change the Go program DNS only.
Expand All @@ -18,11 +17,14 @@ func UseDNSInternally(ip net.IP) { //nolint:interfacer
}
}

const resolvConfFilepath = "/etc/resolv.conf"

// UseDNSSystemWide changes the nameserver to use for DNS system wide.
func UseDNSSystemWide(openFile os.OpenFileFunc, ip net.IP, keepNameserver bool) error { //nolint:interfacer
file, err := openFile(resolvConfFilepath, os.O_RDONLY, 0)
// If resolvConfPath is empty, it defaults to /etc/resolv.conf.
func UseDNSSystemWide(resolvConfPath string, ip net.IP, keepNameserver bool) error { //nolint:interfacer
const defaultResolvConfPath = "/etc/resolv.conf"
if resolvConfPath == "" {
resolvConfPath = defaultResolvConfPath
}
file, err := os.Open(resolvConfPath)
if err != nil {
return err
}
Expand Down Expand Up @@ -50,7 +52,7 @@ func UseDNSSystemWide(openFile os.OpenFileFunc, ip net.IP, keepNameserver bool)

s = strings.Join(lines, "\n") + "\n"

file, err = openFile(resolvConfFilepath, os.O_WRONLY|os.O_TRUNC, 0644)
file, err = os.OpenFile(resolvConfPath, os.O_WRONLY|os.O_TRUNC, 0644)
if err != nil {
return err
}
Expand Down
233 changes: 81 additions & 152 deletions pkg/nameserver/nameserver_test.go
Original file line number Diff line number Diff line change
@@ -1,167 +1,96 @@
package nameserver

import (
"fmt"
"io"
"net"
"os"
"path/filepath"
"testing"

"github.com/golang/mock/gomock"
"github.com/qdm12/golibs/os"
"github.com/qdm12/golibs/os/mock_os"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

func Test_UseDNSSystemWide(t *testing.T) {
t.Parallel()

tests := map[string]struct {
ip net.IP
keepNameserver bool
data []byte
firstOpenErr error
readErr error
firstCloseErr error
secondOpenErr error
writtenData string
writeErr error
secondCloseErr error
err error
}{
"no data": {
ip: net.IP{127, 0, 0, 1},
writtenData: "nameserver 127.0.0.1\n",
},
"first open error": {
ip: net.IP{127, 0, 0, 1},
firstOpenErr: fmt.Errorf("error"),
err: fmt.Errorf("error"),
},
"read error": {
readErr: fmt.Errorf("error"),
err: fmt.Errorf("error"),
},
"first close error": {
firstCloseErr: fmt.Errorf("error"),
err: fmt.Errorf("error"),
},
"second open error": {
ip: net.IP{127, 0, 0, 1},
secondOpenErr: fmt.Errorf("error"),
err: fmt.Errorf("error"),
},
"write error": {
ip: net.IP{127, 0, 0, 1},
writtenData: "nameserver 127.0.0.1\n",
writeErr: fmt.Errorf("error"),
err: fmt.Errorf("error"),
},
"second close error": {
ip: net.IP{127, 0, 0, 1},
writtenData: "nameserver 127.0.0.1\n",
secondCloseErr: fmt.Errorf("error"),
err: fmt.Errorf("error"),
},
"lines without nameserver": {
ip: net.IP{127, 0, 0, 1},
data: []byte("abc\ndef\n"),
writtenData: "nameserver 127.0.0.1\nabc\ndef\n",
},
"lines with nameserver": {
ip: net.IP{127, 0, 0, 1},
data: []byte("abc\nnameserver abc def\ndef\n"),
writtenData: "nameserver 127.0.0.1\nabc\ndef\n",
},
"keep nameserver": {
ip: net.IP{127, 0, 0, 1},
keepNameserver: true,
data: []byte("abc\nnameserver abc def\ndef\n"),
writtenData: "nameserver 127.0.0.1\nabc\nnameserver abc def\ndef\n",
},
}
for name, tc := range tests {
tc := tc
t.Run(name, func(t *testing.T) {
t.Parallel()
mockCtrl := gomock.NewController(t)

type fileCall struct {
path string
flag int
perm os.FileMode
file os.File
err error
}

var fileCalls []fileCall

readOnlyFile := mock_os.NewMockFile(mockCtrl)

if tc.firstOpenErr == nil {
firstReadCall := readOnlyFile.EXPECT().
Read(gomock.AssignableToTypeOf([]byte{})).
DoAndReturn(func(b []byte) (int, error) {
copy(b, tc.data)
return len(tc.data), nil
})
readErr := tc.readErr
if readErr == nil {
readErr = io.EOF
}
finalReadCall := readOnlyFile.EXPECT().
Read(gomock.AssignableToTypeOf([]byte{})).
Return(0, readErr).After(firstReadCall)
readOnlyFile.EXPECT().Close().
Return(tc.firstCloseErr).
After(finalReadCall)
}

fileCalls = append(fileCalls, fileCall{
path: resolvConfFilepath,
flag: os.O_RDONLY,
perm: 0,
file: readOnlyFile,
err: tc.firstOpenErr,
}) // always return readOnlyFile

if tc.firstOpenErr == nil && tc.readErr == nil && tc.firstCloseErr == nil {
writeOnlyFile := mock_os.NewMockFile(mockCtrl)
if tc.secondOpenErr == nil {
writeCall := writeOnlyFile.EXPECT().
WriteString(tc.writtenData).
Return(0, tc.writeErr)
writeOnlyFile.EXPECT().
Close().
Return(tc.secondCloseErr).
After(writeCall)
}
fileCalls = append(fileCalls, fileCall{
path: resolvConfFilepath,
flag: os.O_WRONLY | os.O_TRUNC,
perm: os.FileMode(0644),
file: writeOnlyFile,
err: tc.secondOpenErr,
})
}

fileCallIndex := 0
openFile := func(name string, flag int, perm os.FileMode) (os.File, error) {
fileCall := fileCalls[fileCallIndex]
fileCallIndex++
assert.Equal(t, fileCall.path, name)
assert.Equal(t, fileCall.flag, flag)
assert.Equal(t, fileCall.perm, perm)
return fileCall.file, fileCall.err
}

err := UseDNSSystemWide(openFile, tc.ip, tc.keepNameserver)
if tc.err != nil {
require.Error(t, err)
assert.Equal(t, tc.err.Error(), err.Error())
} else {
assert.NoError(t, err)
}
})
}
t.Run("file does not exist", func(t *testing.T) {
t.Parallel()

dirPath, err := os.MkdirTemp("", "")
require.NoError(t, err)
defer func() {
err := os.RemoveAll(dirPath)
require.NoError(t, err)
}()

resolvConfPath := filepath.Join(dirPath, "resolv.conf")
ip := net.IP{1, 1, 1, 1}
const keepNameserver = false

err = UseDNSSystemWide(resolvConfPath, ip, keepNameserver)

require.Error(t, err)
assert.Equal(t, "open "+resolvConfPath+": no such file or directory", err.Error())
})

t.Run("empty file", func(t *testing.T) {
t.Parallel()

file, err := os.CreateTemp("", "")
require.NoError(t, err)
err = file.Close()
require.NoError(t, err)

resolvConfPath := file.Name()

defer func() {
err := os.Remove(resolvConfPath)
require.NoError(t, err)
}()

ip := net.IP{1, 1, 1, 1}
const keepNameserver = false

err = UseDNSSystemWide(resolvConfPath, ip, keepNameserver)

require.NoError(t, err)

file, err = os.Open(resolvConfPath)
require.NoError(t, err)
b, err := io.ReadAll(file)
require.NoError(t, err)
assert.Equal(t, "nameserver 1.1.1.1\n", string(b))
})

t.Run("preserve nameserver", func(t *testing.T) {
t.Parallel()

file, err := os.CreateTemp("", "")
require.NoError(t, err)
_, err = io.WriteString(file, "nameserver 1.2.3.4\n\n")
require.NoError(t, err)
err = file.Close()
require.NoError(t, err)

resolvConfPath := file.Name()

defer func() {
err := os.Remove(resolvConfPath)
require.NoError(t, err)
}()

ip := net.IP{1, 1, 1, 1}
const keepNameserver = true

err = UseDNSSystemWide(resolvConfPath, ip, keepNameserver)

require.NoError(t, err)

file, err = os.Open(resolvConfPath)
require.NoError(t, err)
b, err := io.ReadAll(file)
require.NoError(t, err)
assert.Equal(t, "nameserver 1.1.1.1\nnameserver 1.2.3.4\n", string(b))
})
}
5 changes: 2 additions & 3 deletions pkg/unbound/conf.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,17 +2,16 @@ package unbound

import (
"fmt"
"os"
"path/filepath"
"sort"
"strconv"
"strings"

"github.com/qdm12/golibs/os"
)

func (c *configurator) MakeUnboundConf(settings Settings) (err error) {
configFilepath := filepath.Join(c.unboundEtcDir, unboundConfigFilename)
file, err := c.openFile(configFilepath, os.O_CREATE|os.O_TRUNC|os.O_WRONLY, 0644)
file, err := os.OpenFile(configFilepath, os.O_CREATE|os.O_TRUNC|os.O_WRONLY, 0644)
if err != nil {
return err
}
Expand Down
7 changes: 2 additions & 5 deletions pkg/unbound/dns.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ import (

"github.com/qdm12/golibs/command"
"github.com/qdm12/golibs/logging"
"github.com/qdm12/golibs/os"
"github.com/qdm12/updated/pkg/dnscrypto"
)

Expand All @@ -18,18 +17,16 @@ type Configurator interface {
}

type configurator struct {
openFile os.OpenFileFunc
commander command.Commander
dnscrypto dnscrypto.DNSCrypto
unboundEtcDir string
unboundPath string
cacertsPath string
}

func NewConfigurator(logger logging.Logger, openFile os.OpenFileFunc,
dnscrypto dnscrypto.DNSCrypto, unboundEtcDir, unboundPath, cacertsPath string) Configurator {
func NewConfigurator(logger logging.Logger, dnscrypto dnscrypto.DNSCrypto,
unboundEtcDir, unboundPath, cacertsPath string) Configurator {
return &configurator{
openFile: openFile,
commander: command.NewCommander(),
dnscrypto: dnscrypto,
unboundEtcDir: unboundEtcDir,
Expand Down
2 changes: 1 addition & 1 deletion pkg/unbound/files.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ const includeConfFilename = "include.conf"

func (c *configurator) createEmptyIncludeConf() error {
filepath := filepath.Join(c.unboundEtcDir, includeConfFilename)
file, err := c.openFile(filepath, os.O_CREATE, 0644)
file, err := os.OpenFile(filepath, os.O_CREATE, 0644)
if err != nil {
return err
}
Expand Down
Loading

0 comments on commit 841eddc

Please sign in to comment.