diff --git a/aws/ssm/ssm.go b/aws/ssm/ssm.go index 27cee79..167e874 100644 --- a/aws/ssm/ssm.go +++ b/aws/ssm/ssm.go @@ -2,6 +2,7 @@ package ssm import ( "bytes" + "errors" "fmt" "strings" "time" @@ -20,13 +21,19 @@ type ssm struct { log logger.Logger conf conf.Config - cl *assm.SSM + cl *assm.SSM + sess *session.Session } +var ( + ErrInstanceIDWaiterTimeOut = errors.New("could not get instance ids in the specified time period") +) + func New(log logger.Logger, conf conf.Config, session *session.Session) *ssm { return &ssm{ log: log.Named("ssm"), conf: conf, + sess: session, cl: assm.New(session), } } @@ -45,8 +52,8 @@ func (s ssm) provideTags() []*assm.Target { vals := strings.Split(strings.Split(rawTags, "=")[1], ",") for _, v := range vals { - v := v - tags = append(tags, stringRef(strings.TrimSpace(v))) + v := strings.TrimSpace(v) + tags = append(tags, &v) } resp = append(resp, &assm.Target{ @@ -88,18 +95,29 @@ func (s ssm) waitForCmdExecutionComplete(cmdID *string, instID *string) error { } func (s ssm) waitForCmdExecAndDisplayCmdOutput(command *assm.SendCommandOutput) { - var instIdsSuccess = make([]*string, 0) - + var ( + instIdsSuccess = make([]*string, 0) + err error + ) s.log.Debug("Command", "command", spew.Sdump(command)) - //TODO: get results when only tags are provided - for _, instID := range command.Command.InstanceIds { - if err := s.waitForCmdExecutionComplete(command.Command.CommandId, instID); err != nil { - s.log.Error("Error waiting for command execution", "err", err.Error(), "instance_id", *instID) - } else { - instIdsSuccess = append(instIdsSuccess, instID) - } + + instIdsSuccess, err = s.getInstanceIDsFromTags(command.Command.CommandId) + if err != nil { + s.log.Fatalln("Could not get instance id information", "err", err.Error()) } + //if ids != nil { + // instIds = ids + //} + // + //for _, instID := range instIds { + // if err := s.waitForCmdExecutionComplete(command.Command.CommandId, instID); err != nil { + // s.log.Error("Error waiting for command execution", "err", err.Error(), "instance_id", *instID) + // } else { + // instIdsSuccess = append(instIdsSuccess, instID) + // } + //} + for _, id := range instIdsSuccess { out, err := s.cl.GetCommandInvocation(&assm.GetCommandInvocationInput{ CommandId: command.Command.CommandId, @@ -137,8 +155,46 @@ func displayResults(instanceID *string, data *assm.GetCommandInvocationOutput) { fmt.Print(buff.String()) } -func stringRef(str string) *string { - s := str +func (s ssm) getInstanceIDsFromTags(commandId *string) ([]*string, error) { + var ( + resp = make([]*string, 0) + timeout = time.After(time.Duration(s.conf.CommandResultMaxWait) * time.Second) + ) + + for { + listOut, err := s.cl.ListCommandInvocations(&assm.ListCommandInvocationsInput{ + CommandId: commandId, + }) + if err != nil { + return nil, err + } + + totalInvocations := len(listOut.CommandInvocations) + + s.log.Debug("ListOut", "list", spew.Sdump(listOut)) + + for _, inv := range listOut.CommandInvocations { + if *inv.Status == assm.CommandStatusSuccess { + resp = append(resp, inv.InstanceId) + + if len(resp) == totalInvocations { + return resp, nil + } + } + } + + select { + case <-timeout: + if len(resp) != 0 { + s.log.Warn("Could not get all instance ids, returning partial results") + s.log.Warn("For more information enable debug log") + + return resp, nil + } - return &s + return nil, ErrInstanceIDWaiterTimeOut + default: + time.Sleep(1 * time.Second) + } + } }