Skip to content

Commit

Permalink
Fix UseDNSSystemWide
Browse files Browse the repository at this point in the history
  • Loading branch information
qdm12 committed Jan 3, 2021
1 parent 2b70eec commit 9c3e47c
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 25 deletions.
26 changes: 11 additions & 15 deletions pkg/unbound/nameserver.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,7 @@ func (c *configurator) UseDNSInternally(ip net.IP) {

// UseDNSSystemWide changes the nameserver to use for DNS system wide.
func (c *configurator) UseDNSSystemWide(ip net.IP, keepNameserver bool) error {
const filepath = resolvConfFilepath
file, err := c.openFile(filepath, os.O_RDWR|os.O_TRUNC, 0644)
file, err := c.openFile(resolvConfFilepath, os.O_RDWR|os.O_TRUNC, 0644)
if err != nil {
return err
}
Expand All @@ -29,23 +28,20 @@ func (c *configurator) UseDNSSystemWide(ip net.IP, keepNameserver bool) error {
_ = file.Close()
return err
}

s := strings.TrimSuffix(string(data), "\n")
lines := strings.Split(s, "\n")
if len(lines) == 1 && lines[0] == "" {
lines = nil

lines := []string{
"nameserver " + ip.String(),
}
found := false
if !keepNameserver { // default
for i := range lines {
if strings.HasPrefix(lines[i], "nameserver ") {
lines[i] = "nameserver " + ip.String()
found = true
}
for _, line := range strings.Split(s, "\n") {
if line == "" ||
(!keepNameserver && strings.HasPrefix(line, "nameserver ")) {
continue
}
lines = append(lines, line)
}
if !found {
lines = append(lines, "nameserver "+ip.String())
}

s = strings.Join(lines, "\n") + "\n"
_, err = file.WriteString(s)
if err != nil {
Expand Down
33 changes: 23 additions & 10 deletions pkg/unbound/nameserver_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,18 +16,22 @@ import (
func Test_UseDNSSystemWide(t *testing.T) {
t.Parallel()
tests := map[string]struct {
data []byte
writtenData string
openErr error
readErr error
writeErr error
closeErr error
err error
ip net.IP
keepNameserver bool
data []byte
writtenData string
openErr error
readErr error
writeErr error
closeErr error
err error
}{
"no data": {
ip: net.IP{127, 0, 0, 1},
writtenData: "nameserver 127.0.0.1\n",
},
"open error": {
ip: net.IP{127, 0, 0, 1},
openErr: fmt.Errorf("error"),
err: fmt.Errorf("error"),
},
Expand All @@ -36,17 +40,26 @@ func Test_UseDNSSystemWide(t *testing.T) {
err: fmt.Errorf("error"),
},
"write error": {
ip: net.IP{127, 0, 0, 1},
writtenData: "nameserver 127.0.0.1\n",
writeErr: fmt.Errorf("error"),
err: fmt.Errorf("error"),
},
"lines without nameserver": {
ip: net.IP{127, 0, 0, 1},
data: []byte("abc\ndef\n"),
writtenData: "abc\ndef\nnameserver 127.0.0.1\n",
writtenData: "nameserver 127.0.0.1\nabc\ndef\n",
},
"lines with nameserver": {
ip: net.IP{127, 0, 0, 1},
data: []byte("abc\nnameserver abc def\ndef\n"),
writtenData: "abc\nnameserver 127.0.0.1\ndef\n",
writtenData: "nameserver 127.0.0.1\nabc\ndef\n",
},
"keep nameserver": {
ip: net.IP{127, 0, 0, 1},
keepNameserver: true,
data: []byte("abc\nnameserver abc def\ndef\n"),
writtenData: "nameserver 127.0.0.1\nabc\nnameserver abc def\ndef\n",
},
}
for name, tc := range tests {
Expand Down Expand Up @@ -89,7 +102,7 @@ func Test_UseDNSSystemWide(t *testing.T) {
c := &configurator{
openFile: openFile,
}
err := c.UseDNSSystemWide(net.IP{127, 0, 0, 1}, false)
err := c.UseDNSSystemWide(tc.ip, tc.keepNameserver)
if tc.err != nil {
require.Error(t, err)
assert.Equal(t, tc.err.Error(), err.Error())
Expand Down

0 comments on commit 9c3e47c

Please sign in to comment.