-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathapp.go
127 lines (109 loc) · 3.37 KB
/
app.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
package main
import (
"bufio"
"context"
"flag"
"fmt"
"log"
"math/rand"
"os"
"strconv"
"time"
"github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/aws-sdk-go-v2/config"
"github.com/aws/aws-sdk-go-v2/credentials"
"github.com/aws/aws-sdk-go-v2/service/ssm"
"github.com/aws/aws-sdk-go-v2/service/sts"
stsTypes "github.com/aws/aws-sdk-go-v2/service/sts/types"
)
var (
	// session_prefix is prepended to a random number to build the STS
	// RoleSessionName used when assuming the target role.
	session_prefix = "unraid-secrets"
	// keyfile_location is the default value of the -key-path flag.
	keyfile_location = "/root/keyfile"
	// default_aws_region is the default value of the -region flag.
	default_aws_region = "us-east-1"
)
// SSMGetParameterAPI is the narrow slice of the SSM client this program
// needs: a single GetParameter call. Any type with this method satisfies it,
// including the *ssm.Client that main passes to findParameter.
type SSMGetParameterAPI interface {
	GetParameter(ctx context.Context, params *ssm.GetParameterInput, optFns ...func(*ssm.Options)) (*ssm.GetParameterOutput, error)
}
// findParameter fetches one SSM parameter through api. It forwards c and
// input unchanged, passes no extra options, and returns GetParameter's
// output and error as-is.
func findParameter(c context.Context, api SSMGetParameterAPI, input *ssm.GetParameterInput) (*ssm.GetParameterOutput, error) {
	output, err := api.GetParameter(c, input)
	return output, err
}
// check terminates the program when e is non-nil, logging the error first.
// log.Fatal never returns: it calls os.Exit(1), so deferred functions in the
// caller do not run and nothing after it in this function would execute.
func check(e error) {
	if e != nil {
		log.Fatal(e)
	}
}
// main reads a decrypted SSM parameter using a temporarily assumed IAM role
// and writes its value to a keyfile on disk.
//
// Flags:
//
//	-region      AWS region to operate in (default default_aws_region)
//	-role-arn    role to assume before reading the parameter (required)
//	-param-path  SSM parameter name/path to read (required)
//	-key-path    destination file (default keyfile_location)
//
// Every failure, including a failed parameter lookup, is logged and ends
// the process with a non-zero exit status.
func main() {
	log.SetOutput(os.Stdout)

	// Allow the region to be set via command line argument, or default to 'us-east-1'
	pRegion := flag.String("region", default_aws_region, "The AWS region to operate in")
	pRoleARN := flag.String("role-arn", "", "The role ARN to use for accessing the secret (mandatory)")
	pParamPath := flag.String("param-path", "", "The SSM parameter path to retrieve (mandatory)")
	pKeyPath := flag.String("key-path", keyfile_location, "The path to write the keyfile to")
	flag.Parse()

	// Check the values rather than only whether the flag was passed, so an
	// explicitly empty value (e.g. -role-arn="") is rejected up front instead
	// of failing later inside an AWS call.
	required := []struct{ name, value string }{
		{"role-arn", *pRoleARN},
		{"param-path", *pParamPath},
	}
	for _, req := range required {
		if req.value == "" {
			log.Fatalf("Missing required '%v' argument", req.name)
		}
	}

	log.Printf("AWS Region: %s", *pRegion)
	log.Printf("AWS Role ARN: %s", *pRoleARN)
	log.Printf("AWS SSM Parameter Path: %s", *pParamPath)
	log.Printf("Output Key Path: %s", *pKeyPath)

	ctx := context.Background()

	cfg, err := config.LoadDefaultConfig(ctx, config.WithRegion(*pRegion))
	check(err)

	// Assume the role and extract the credentials. A local generator avoids
	// the deprecated global rand.Seed while still varying the session name.
	rng := rand.New(rand.NewSource(time.Now().UnixNano()))
	sessionName := session_prefix + strconv.Itoa(10000+rng.Intn(25000))
	response, err := sts.NewFromConfig(cfg).AssumeRole(ctx, &sts.AssumeRoleInput{
		RoleArn:         aws.String(*pRoleARN),
		RoleSessionName: aws.String(sessionName),
	})
	if err != nil {
		log.Fatalf("Unable to assume target role, %v", err)
	}
	var assumedRoleCreds *stsTypes.Credentials = response.Credentials
	if assumedRoleCreds == nil {
		log.Fatal("Unable to assume target role, response contained no credentials")
	}

	// Rebuild the config so the SSM client signs requests with the temporary
	// role credentials instead of the default credential chain. aws.ToString
	// guards against nil credential fields rather than panicking on deref.
	cfg, err = config.LoadDefaultConfig(ctx,
		config.WithRegion(*pRegion),
		config.WithCredentialsProvider(credentials.NewStaticCredentialsProvider(
			aws.ToString(assumedRoleCreds.AccessKeyId),
			aws.ToString(assumedRoleCreds.SecretAccessKey),
			aws.ToString(assumedRoleCreds.SessionToken),
		)),
	)
	if err != nil {
		log.Fatalf("Unable to load static credentials for service client config, %v", err)
	}

	// Get the (decrypted) value of the relevant parameter.
	results, err := findParameter(ctx, ssm.NewFromConfig(cfg), &ssm.GetParameterInput{
		Name:           aws.String(*pParamPath),
		WithDecryption: aws.Bool(true),
	})
	if err != nil {
		// Exit non-zero so scripts invoking this tool can detect the failure.
		log.Fatalf("Unable to retrieve SSM parameter %q, %v", *pParamPath, err)
	}
	if results.Parameter == nil || results.Parameter.Value == nil {
		log.Fatalf("SSM parameter %q returned no value", *pParamPath)
	}

	n, err := writeKeyfile(*pKeyPath, *results.Parameter.Value)
	check(err)
	log.Printf("Wrote keyfile to disk (%v bytes)", n)
}

// writeKeyfile writes value to path and returns the number of bytes written.
// A new file is created with owner-only (0600) permissions because it holds
// a secret. An existing file is truncated and keeps its current mode.
// Write, flush and close errors are all returned, so a partially written key
// is never reported as success.
func writeKeyfile(path, value string) (int, error) {
	f, err := os.OpenFile(path, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0600)
	if err != nil {
		return 0, fmt.Errorf("opening keyfile %q: %w", path, err)
	}

	w := bufio.NewWriter(f)
	n, err := w.WriteString(value)
	if err == nil {
		err = w.Flush()
	}
	// Close unconditionally; surface its error only if nothing failed earlier.
	if cerr := f.Close(); err == nil {
		err = cerr
	}
	if err != nil {
		return n, fmt.Errorf("writing keyfile %q: %w", path, err)
	}
	return n, nil
}