Skip to content

Add -print flag #52

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
99 changes: 84 additions & 15 deletions main.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ func main() {
template = flag.String("template", DefaultTemplate, "The template used to determine what the SSM parameter name is for an environment variable. When this template returns an empty string, the env variable is not an SSM parameter")
decrypt = flag.Bool("with-decryption", false, "Will attempt to decrypt the parameter, and set the env var as plaintext")
nofail = flag.Bool("no-fail", false, "Don't fail if error retrieving parameter")
print = flag.Bool("print", false, "Print the decrypted env vars without exporting them and exit")
print_version = flag.Bool("V", false, "Print the version and exit")
)
flag.Parse()
Expand All @@ -61,26 +62,51 @@ func main() {
return
}

if len(args) <= 0 {
if !*print && len(args) <= 0 {
flag.Usage()
os.Exit(1)
fmt.Fprintf(os.Stderr, "\nmissing program to execute\n")
os.Exit(2)
}

path, err := exec.LookPath(args[0])
must(err)
if *print && len(args) > 0 {
flag.Usage()
fmt.Fprintf(os.Stderr, "\n-print is incompatible with arguments\n")
os.Exit(3)
}

var os osEnviron
var osEnv osEnviron

// Construct the template we'll use for extracting the ssm params we need to
// fetch.
t, err := parseTemplate(*template)
must(err)

// Construct an expander with the configs for fetching/replacing env vars.
e := &expander{
batchSize: defaultBatchSize,
t: t,
ssm: &lazySSMClient{},
os: os,
os: osEnv,
}
must(e.expandEnviron(*decrypt, *nofail))
must(syscall.Exec(path, args[0:], os.Environ()))
// Attempt to "expand" ssm vars.
vars, err := e.expandEnviron(*decrypt, *nofail)
must(err)

// Actually set the env vars for the process.
e.setEnviron(*print, vars)
// If -print was passed, we're done.
if *print {
os.Exit(0)
}

// Make sure that we're invoking ssm-env with an executable that actually
// exists.
path, err := exec.LookPath(args[0])
must(err)

// Exec whatever command was passed, using the current process' env vars
// (which are now expanded).
must(syscall.Exec(path, args[0:], osEnv.Environ()))
}

// lazySSMClient wraps the AWS SDK SSM client such that the AWS session and
Expand Down Expand Up @@ -124,6 +150,9 @@ func (c *lazySSMClient) awsSession() (*session.Session, error) {
return sess, nil
}

// Construct the template we use for parsing out ssm env var strings (by
// default, `DefaultTemplate`, which works with values like
// "ssm://<path>:<version>").
func parseTemplate(templateText string) (*template.Template, error) {
return template.New("template").Funcs(TemplateFuncs).Parse(templateText)
}
Expand All @@ -134,7 +163,9 @@ type ssmClient interface {

type environ interface {
Environ() []string
Setenv(key, vale string)
Setenv(key, val string)
Getenv(key string) string
Write(s string) error
}

type osEnviron int
Expand All @@ -147,6 +178,16 @@ func (e osEnviron) Setenv(key, val string) {
os.Setenv(key, val)
}

func (e osEnviron) Getenv(key string) string {
return os.Getenv(key)
}

func (e osEnviron) Write(s string) error {
_, err := fmt.Println(s)

return err
}

type ssmVar struct {
envvar string
parameter string
Expand All @@ -172,7 +213,22 @@ func (e *expander) parameter(k, v string) (*string, error) {
return nil, nil
}

func (e *expander) expandEnviron(decrypt bool, nofail bool) error {
func (e *expander) setEnviron(print bool, vars map[string]string) {
// If -print was passed, just dump the decrypted env vars to stdout and return.
if print {
for k, v := range vars {
e.os.Write(fmt.Sprintf("%s=%s", k, v))
}

return
}

for k, v := range vars {
e.os.Setenv(k, v)
}
}

func (e *expander) expandEnviron(decrypt bool, nofail bool) (map[string]string, error) {
// Environment variables that point to some SSM parameters.
var ssmVars []ssmVar

Expand All @@ -183,7 +239,7 @@ func (e *expander) expandEnviron(decrypt bool, nofail bool) error {
parameter, err := e.parameter(k, v)
if err != nil {
// TODO: Should this _also_ not error if nofail is passed?
return fmt.Errorf("determining name of parameter: %v", err)
return make(map[string]string), fmt.Errorf("determining name of parameter: %v", err)
}

if parameter != nil {
Expand All @@ -194,16 +250,20 @@ func (e *expander) expandEnviron(decrypt bool, nofail bool) error {

if len(uniqNames) == 0 {
// Nothing to do, no SSM parameters.
return nil
return make(map[string]string), nil
}

// Construct a string slice to hold each ssm value.
names := make([]string, len(uniqNames))
// Go through and extract the values from uniqNames into the string slice.
i := 0
for k := range uniqNames {
names[i] = k
i++
}

// For each chunk of batched ssm params, get the decrypted values.
decryptedVars := make(map[string]string)
for i := 0; i < len(names); i += e.batchSize {
j := i + e.batchSize
if j > len(names) {
Expand All @@ -212,18 +272,26 @@ func (e *expander) expandEnviron(decrypt bool, nofail bool) error {

values, err := e.getParameters(names[i:j], decrypt, nofail)
if err != nil {
return err
return make(map[string]string), err
}

if nofail && len(values) == 0 {
for _, v := range ssmVars {
decryptedVars[v.envvar] = e.os.Getenv(v.envvar)
}

return decryptedVars, nil
}

for _, v := range ssmVars {
val, ok := values[v.parameter]
if ok {
e.os.Setenv(v.envvar, val)
decryptedVars[v.envvar] = val
}
}
}

return nil
return decryptedVars, nil
}

func (e *expander) getParameters(names []string, decrypt bool, nofail bool) (map[string]string, error) {
Expand Down Expand Up @@ -287,6 +355,7 @@ func splitVar(v string) (key, val string) {
return parts[0], parts[1]
}

// Abort with an error message if err is not nill.
func must(err error) {
if err != nil {
fmt.Fprintf(os.Stderr, "ssm-env: %v\n", err)
Expand Down
Loading