Skip to content

Commit

Permalink
Verbose output for sync command and adding --region flag. Defaults to…
Browse files Browse the repository at this point in the history
… sso region
  • Loading branch information
null93 committed Nov 5, 2024
1 parent f0e20eb commit 7ec7352
Show file tree
Hide file tree
Showing 5 changed files with 43 additions and 22 deletions.
16 changes: 10 additions & 6 deletions internal/connect.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,10 @@ import (
"os/exec"
"syscall"

"github.com/null93/aws-knox/sdk/credentials"
"github.com/null93/aws-knox/sdk/tui"
"github.com/null93/aws-knox/pkg/color"
"github.com/null93/aws-knox/sdk/credentials"
. "github.com/null93/aws-knox/sdk/style"
"github.com/null93/aws-knox/sdk/tui"
"github.com/spf13/cobra"
)

Expand Down Expand Up @@ -77,8 +77,11 @@ var connectCmd = &cobra.Command{
continue
}
}
if region == "" {
region = role.Region
}
if instanceId == "" {
if instanceId, action, err = tui.SelectInstance(role, searchTerm); err != nil {
if instanceId, action, err = tui.SelectInstance(role, region, searchTerm); err != nil {
ExitWithError(19, "failed to pick an instance", err)
} else if action == "back" {
goBack()
Expand All @@ -91,7 +94,7 @@ var connectCmd = &cobra.Command{
title := TitleStyle.Decorator()
DefaultStyle.Printfln("")
DefaultStyle.Printfln("%s %s", title("SSO Session: "), gray(role.SessionName))
DefaultStyle.Printfln("%s %s", title("Region: "), gray(role.Region))
DefaultStyle.Printfln("%s %s", title("Region: "), gray(region))
DefaultStyle.Printfln("%s %s", title("Account ID: "), gray(role.AccountId))
DefaultStyle.Printfln("%s %s", title("Role Name: "), gray(role.Name))
DefaultStyle.Printfln("%s %s", title("Instance ID: "), yellow(instanceId))
Expand All @@ -106,11 +109,11 @@ var connectCmd = &cobra.Command{
command := exec.Command(
binaryPath,
fmt.Sprintf(`{"SessionId": "%s", "TokenValue": "%s", "StreamUrl": "%s"}`, *details.SessionId, *details.TokenValue, *details.StreamUrl),
role.Region,
region,
"StartSession",
"", // No Profile
fmt.Sprintf(`{"Target": "%s"}`, instanceId),
fmt.Sprintf("https://ssm.%s.amazonaws.com", role.Region),
fmt.Sprintf("https://ssm.%s.amazonaws.com", region),
)
command.Stdin = os.Stdin
command.Stdout = os.Stdout
Expand All @@ -131,6 +134,7 @@ func init() {
connectCmd.Flags().StringVarP(&accountId, "account-id", "a", accountId, "AWS account ID")
connectCmd.Flags().StringVarP(&roleName, "role-name", "r", roleName, "AWS role name")
connectCmd.Flags().StringVarP(&instanceId, "instance-id", "i", instanceId, "EC2 instance ID")
connectCmd.Flags().StringVar(&region, "region", region, "Region for quering instances")
connectCmd.Flags().BoolVarP(&selectCachedFirst, "cached", "c", selectCachedFirst, "select from cached credentials")
connectCmd.Flags().BoolVarP(&lastUsed, "last-used", "l", lastUsed, "select last used credentials")
connectCmd.Flags().Uint32VarP(&connectUid, "uid", "u", connectUid, "UID on instance to 'su' to")
Expand Down
1 change: 1 addition & 0 deletions internal/root.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ var (
accountId string
roleName string
instanceId string
region string
accountAliases map[string]string
)

Expand Down
40 changes: 28 additions & 12 deletions internal/sync.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,9 @@ import (
"strings"
"syscall"

"github.com/null93/aws-knox/pkg/color"
"github.com/null93/aws-knox/sdk/credentials"
. "github.com/null93/aws-knox/sdk/style"
"github.com/null93/aws-knox/sdk/tui"
"github.com/spf13/cobra"
)
Expand Down Expand Up @@ -48,11 +50,11 @@ func rsyncInit(role *credentials.Role, instanceId string) {
command := exec.Command(
binaryPath,
fmt.Sprintf(`{"SessionId": "%s", "TokenValue": "%s", "StreamUrl": "%s"}`, *details.SessionId, *details.TokenValue, *details.StreamUrl),
role.Region,
region,
"StartSession",
"", // No Profile
fmt.Sprintf(`{"Target": "%s"}`, instanceId),
fmt.Sprintf("https://ssm.%s.amazonaws.com", role.Region),
fmt.Sprintf("https://ssm.%s.amazonaws.com", region),
)
command.Stdin = os.Stdin
command.SysProcAttr = &syscall.SysProcAttr{Setpgid: true, Foreground: true}
Expand Down Expand Up @@ -85,11 +87,11 @@ func rsyncStart(role *credentials.Role, instanceId string) {
command := exec.Command(
binaryPath,
fmt.Sprintf(`{"SessionId": "%s", "TokenValue": "%s", "StreamUrl": "%s"}`, *details.SessionId, *details.TokenValue, *details.StreamUrl),
role.Region,
region,
"StartSession",
"", // No Profile
fmt.Sprintf(`{"Target": "%s"}`, instanceId),
fmt.Sprintf("https://ssm.%s.amazonaws.com", role.Region),
fmt.Sprintf("https://ssm.%s.amazonaws.com", region),
)
command.Stdin = os.Stdin
command.Stdout = nil
Expand Down Expand Up @@ -117,11 +119,11 @@ func rsyncClean(role *credentials.Role, instanceId string) {
command := exec.Command(
binaryPath,
fmt.Sprintf(`{"SessionId": "%s", "TokenValue": "%s", "StreamUrl": "%s"}`, *details.SessionId, *details.TokenValue, *details.StreamUrl),
role.Region,
region,
"StartSession",
"", // No Profile
fmt.Sprintf(`{"Target": "%s"}`, instanceId),
fmt.Sprintf("https://ssm.%s.amazonaws.com", role.Region),
fmt.Sprintf("https://ssm.%s.amazonaws.com", region),
)
command.Stdin = os.Stdin
command.SysProcAttr = &syscall.SysProcAttr{Setpgid: true, Foreground: true}
Expand All @@ -146,11 +148,11 @@ func rsyncPortForward(role *credentials.Role, instanceId string) {
command := exec.Command(
binaryPath,
fmt.Sprintf(`{"SessionId": "%s", "TokenValue": "%s", "StreamUrl": "%s"}`, *details.SessionId, *details.TokenValue, *details.StreamUrl),
role.Region,
region,
"StartSession",
"", // No Profile
fmt.Sprintf(`{"Target": "%s"}`, instanceId),
fmt.Sprintf("https://ssm.%s.amazonaws.com", role.Region),
fmt.Sprintf("https://ssm.%s.amazonaws.com", region),
)
command.Stdin = os.Stdin
command.Stdout = nil
Expand Down Expand Up @@ -224,17 +226,30 @@ var syncCmd = &cobra.Command{
continue
}
}
if region == "" {
region = role.Region
}
if instanceId == "" {
if instanceId, action, err = tui.SelectInstance(role, searchTerm); err != nil {
if instanceId, action, err = tui.SelectInstance(role, region, searchTerm); err != nil {
ExitWithError(19, "failed to pick an instance", err)
} else if action == "back" {
goBack()
continue
}
}
fmt.Println("Remote Destination: /root/knox-sync")
fmt.Printf("Example Command: rsync -P ./dump.sql ./release.tar.gz rsync://127.0.0.1:%d/sync\n", localPort)
fmt.Println()

yellow := color.ToForeground(YellowColor).Decorator()
gray := color.ToForeground(LightGrayColor).Decorator()
title := TitleStyle.Decorator()
DefaultStyle.Printfln("")
DefaultStyle.Printfln("%s %s", title("SSO Session: "), gray(role.SessionName))
DefaultStyle.Printfln("%s %s", title("Region: "), gray(region))
DefaultStyle.Printfln("%s %s", title("Account ID: "), gray(role.AccountId))
DefaultStyle.Printfln("%s %s", title("Role Name: "), gray(role.Name))
DefaultStyle.Printfln("%s %s", title("Instance ID: "), gray(instanceId))
DefaultStyle.Printfln("%s %s", title("Remote Destination: "), gray("/root/knox-sync"))
DefaultStyle.Printfln("%s %s", title("Example Command: "), yellow("rsync -P ./dump.sql ./release.tar.gz rsync://127.0.0.1:%d/sync\n", localPort))

defer rsyncClean(role, instanceId)
defer func() {
fmt.Println("\nCleaning up...")
Expand All @@ -257,6 +272,7 @@ func init() {
syncCmd.Flags().StringVarP(&accountId, "account-id", "a", accountId, "AWS account ID")
syncCmd.Flags().StringVarP(&roleName, "role-name", "r", roleName, "AWS role name")
syncCmd.Flags().StringVarP(&instanceId, "instance-id", "i", instanceId, "EC2 instance ID")
syncCmd.Flags().StringVar(&region, "region", region, "Region for quering instances")
syncCmd.Flags().Uint16VarP(&rsyncPort, "rsync-port", "P", rsyncPort, "rsync port")
syncCmd.Flags().Uint16VarP(&localPort, "local-port", "p", localPort, "local port")
syncCmd.Flags().BoolVarP(&lastUsed, "last-used", "l", lastUsed, "select last used credentials")
Expand Down
4 changes: 2 additions & 2 deletions sdk/credentials/session.go
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ func (r Roles) FindByName(name string) *Role {
return nil
}

func (r *Role) GetManagedInstances() (Instances, error) {
func (r *Role) GetManagedInstances(region string) (Instances, error) {
instances := Instances{}
if r.Credentials == nil {
return instances, ErrorRoleCredentialsNil
Expand All @@ -123,7 +123,7 @@ func (r *Role) GetManagedInstances() (Instances, error) {
r.Credentials.SecretAccessKey,
r.Credentials.SessionToken,
)
options := ec2.Options{Region: r.Region, Credentials: staticProvider}
options := ec2.Options{Region: region, Credentials: staticProvider}
client := ec2.New(options)
params := ec2.DescribeInstancesInput{
Filters: []ec2types.Filter{
Expand Down
4 changes: 2 additions & 2 deletions sdk/tui/tui.go
Original file line number Diff line number Diff line change
Expand Up @@ -149,8 +149,8 @@ func SelectRole(roles credentials.Roles) (string, string, error) {
return selection.Value.(string), "", nil
}

func SelectInstance(role *credentials.Role, initialFilter string) (string, string, error) {
instances, err := role.GetManagedInstances()
func SelectInstance(role *credentials.Role, region, initialFilter string) (string, string, error) {
instances, err := role.GetManagedInstances(region)
if err != nil {
return "", "", err
}
Expand Down

0 comments on commit 7ec7352

Please sign in to comment.