Skip to content

Commit f696fc1

Browse files
author
Drew MacInnis
authored
Lookup alias if URL matches (#7)
1 parent 6e660e7 commit f696fc1

File tree

3 files changed

+54
-2
lines changed

3 files changed

+54
-2
lines changed

backend/vault.go

+6-2
Original file line numberDiff line numberDiff line change
@@ -334,13 +334,17 @@ func (v *Vault) resolveArgs(args []string) error {
334334
return errors.New("cannot parse url")
335335
}
336336
v.path = v.origURL.Path
337-
v.name = v.origURL.Hostname()
338-
u := core.ResolveAlias(v.viper, v.name)
337+
alias := core.ReverseLookupAlias(v.viper, v.origURL)
338+
if alias == "" {
339+
alias = v.origURL.Hostname()
340+
}
341+
u := core.ResolveAlias(v.viper, alias)
339342
if u != nil {
340343
v.url = u
341344
} else {
342345
v.url = v.origURL
343346
}
347+
v.name = alias
344348
log.Printf("%s using url: %v\n", v.name, v.url)
345349
return nil
346350
}

core/common.go

+13
Original file line numberDiff line numberDiff line change
@@ -30,3 +30,16 @@ func ResolveAlias(v *viper.Viper, alias string) *url.URL {
3030
}
3131
return nil
3232
}
33+
34+
// ReverseLookupAlias from a URL
35+
func ReverseLookupAlias(v *viper.Viper, u *url.URL) string {
36+
urlMap := make(map[string]string, 1)
37+
aliases := v.GetStringMap("vault")
38+
for alias := range aliases {
39+
vkey := fmt.Sprintf("vault.%s.url", alias)
40+
aliasURL := v.GetString(vkey)
41+
urlMap[aliasURL] = alias
42+
}
43+
uStr := fmt.Sprintf("%s://%s:%s", u.Scheme, u.Hostname(), u.Port())
44+
return urlMap[uStr]
45+
}

core/common_test.go

+35
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
package core
2+
3+
import (
4+
"net/url"
5+
"testing"
6+
7+
"github.com/spf13/viper"
8+
"github.com/stretchr/testify/assert"
9+
)
10+
11+
func getViper(file string) *viper.Viper {
12+
v := viper.New()
13+
v.SetConfigFile(file)
14+
v.ReadInConfig()
15+
return v
16+
}
17+
18+
var reverseLookupAliasTests = []struct {
19+
url string
20+
expect string
21+
}{
22+
{"http://localhost:8201", "vault-a"},
23+
{"https://localhost:8201", ""}, // wrong scheme
24+
{"https://unknown-tls", ""}, // unknown URL
25+
{"https://unknown-tls:8200", ""}, // unknown URL
26+
}
27+
28+
func TestReverseLookupAlias_Valid(t *testing.T) {
29+
v := getViper("../testdata/syncrets-test1.yml")
30+
for _, tc := range reverseLookupAliasTests {
31+
u, _ := url.Parse(tc.url)
32+
alias := ReverseLookupAlias(v, u)
33+
assert.Equal(t, tc.expect, alias)
34+
}
35+
}

0 commit comments

Comments
 (0)