Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 49 additions & 0 deletions cmd/hosts_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -232,3 +232,52 @@ func TestPromptFlagOverridesConfig(t *testing.T) {
t.Fatalf("expected CLI prompt to override config, got %q", got)
}
}

func TestParsePortValue(t *testing.T) {
tests := []struct {
name string
input interface{}
want int
wantErr bool
}{
{name: "int", input: 22, want: 22},
{name: "int64", input: int64(2222), want: 2222},
{name: "float64 integer", input: float64(2200), want: 2200},
{name: "string", input: "2022", want: 2022},
{name: "float64 non integer", input: float64(22.5), wantErr: true},
{name: "string invalid", input: "abc", wantErr: true},
{name: "out of range", input: 70000, wantErr: true},
{name: "invalid type", input: true, wantErr: true},
}

for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
got, err := parsePortValue(tc.input)
if tc.wantErr {
if err == nil {
t.Fatalf("expected error")
}
return
}
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if got != tc.want {
t.Fatalf("expected %d, got %d", tc.want, got)
}
})
}
}

func TestValidatePort(t *testing.T) {
if got, err := validatePort(22); err != nil || got != 22 {
t.Fatalf("expected port 22, got %d err=%v", got, err)
}

if _, err := validatePort(0); err == nil {
t.Fatalf("expected out of range error")
}
if _, err := validatePort(65536); err == nil {
t.Fatalf("expected out of range error")
}
}
47 changes: 23 additions & 24 deletions cmd/root.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,14 @@ var cfgFile string
var hostsFile string
var hostGroup string

var loadSSHConfigFunc = sshConn.LoadSSHConfig

var resolveHostFunc = func(resolver *sshConn.SSHConfigResolver, spec sshConn.HostSpec, fallbackUser string) (sshConn.ResolvedHost, error) {
return resolver.ResolveHost(spec, fallbackUser)
}

var spawnShellFunc = shell.Spawn

// RootCmd represents the base command when called without any subcommands
var RootCmd = &cobra.Command{
Use: "pretty",
Expand All @@ -50,19 +58,17 @@ usage:
}
return nil
},
Run: func(cmd *cobra.Command, args []string) {
RunE: func(cmd *cobra.Command, args []string) error {
argsLen := len(args)
hostSpecs, err := parseArgsHosts(args)
if err != nil {
fmt.Println(err)
os.Exit(1)
return err
}

if hostGroup != "" {
groupSpecs, err := parseGroupSpecs(viper.Get(fmt.Sprintf("groups.%s", hostGroup)), hostGroup)
if err != nil {
fmt.Println(err)
os.Exit(1)
return err
}
if argsLen > 1 {
hostSpecs = append(hostSpecs, groupSpecs...)
Expand All @@ -74,13 +80,11 @@ usage:
if hostsFile != "" {
data, err := ioutil.ReadFile(hostsFile)
if err != nil {
fmt.Printf("unable to read hostsFile: %v\n", err)
os.Exit(1)
return fmt.Errorf("unable to read hostsFile: %w", err)
}
fileSpecs, err := parseHostsFile(data)
if err != nil {
fmt.Println(err)
os.Exit(1)
return err
}
hostSpecs = append(hostSpecs, fileSpecs...)
}
Expand Down Expand Up @@ -110,13 +114,12 @@ usage:
if home, err := os.UserHomeDir(); err == nil {
userConfigPath = filepath.Join(home, ".ssh", "config")
}
resolver, err := sshConn.LoadSSHConfig(sshConn.SSHConfigPaths{
resolver, err := loadSSHConfigFunc(sshConn.SSHConfigPaths{
User: userConfigPath,
System: "/etc/ssh/ssh_config",
})
if err != nil {
fmt.Printf("unable to load ssh config: %v\n", err)
os.Exit(1)
return fmt.Errorf("unable to load ssh config: %w", err)
}

globalUser := strings.TrimSpace(viper.GetString("username"))
Expand All @@ -135,10 +138,9 @@ usage:
resolveSpec.User = globalUser
resolveSpec.UserSet = true
}
resolved, err := resolver.ResolveHost(resolveSpec, "")
resolved, err := resolveHostFunc(resolver, resolveSpec, "")
if err != nil {
fmt.Printf("unable to resolve host %q: %v\n", spec.Host, err)
os.Exit(1)
return fmt.Errorf("unable to resolve host %q: %w", spec.Host, err)
}
jumps := make([]sshConn.ResolvedHost, 0, len(resolved.ProxyJump))
for _, jumpAlias := range resolved.ProxyJump {
Expand All @@ -147,10 +149,9 @@ usage:
jumpSpec.User = globalUser
jumpSpec.UserSet = true
}
jumpResolved, err := resolver.ResolveHost(jumpSpec, "")
jumpResolved, err := resolveHostFunc(resolver, jumpSpec, "")
if err != nil {
fmt.Printf("unable to resolve jump host %q: %v\n", jumpAlias, err)
os.Exit(1)
return fmt.Errorf("unable to resolve jump host %q: %w", jumpAlias, err)
}
jumps = append(jumps, jumpResolved)
}
Expand All @@ -167,17 +168,15 @@ usage:
}
hostList.AddHost(host)
}
shell.Spawn(hostList)
spawnShellFunc(hostList)
return nil
},
}

// Execute adds all child commands to the root command and sets flags appropriately.
// This is called by main.main(). It only needs to happen once to the rootCmd.
func Execute() {
if err := RootCmd.Execute(); err != nil {
fmt.Println(err)
os.Exit(1)
}
func Execute() error {
return RootCmd.Execute()
}

func init() {
Expand Down
Loading