diff --git a/cmd/hosts_test.go b/cmd/hosts_test.go index cbe3ffd..6badc51 100644 --- a/cmd/hosts_test.go +++ b/cmd/hosts_test.go @@ -232,3 +232,52 @@ func TestPromptFlagOverridesConfig(t *testing.T) { t.Fatalf("expected CLI prompt to override config, got %q", got) } } + +func TestParsePortValue(t *testing.T) { + tests := []struct { + name string + input interface{} + want int + wantErr bool + }{ + {name: "int", input: 22, want: 22}, + {name: "int64", input: int64(2222), want: 2222}, + {name: "float64 integer", input: float64(2200), want: 2200}, + {name: "string", input: "2022", want: 2022}, + {name: "float64 non integer", input: float64(22.5), wantErr: true}, + {name: "string invalid", input: "abc", wantErr: true}, + {name: "out of range", input: 70000, wantErr: true}, + {name: "invalid type", input: true, wantErr: true}, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + got, err := parsePortValue(tc.input) + if tc.wantErr { + if err == nil { + t.Fatalf("expected error") + } + return + } + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if got != tc.want { + t.Fatalf("expected %d, got %d", tc.want, got) + } + }) + } +} + +func TestValidatePort(t *testing.T) { + if got, err := validatePort(22); err != nil || got != 22 { + t.Fatalf("expected port 22, got %d err=%v", got, err) + } + + if _, err := validatePort(0); err == nil { + t.Fatalf("expected out of range error") + } + if _, err := validatePort(65536); err == nil { + t.Fatalf("expected out of range error") + } +} diff --git a/cmd/root.go b/cmd/root.go index f160466..a156dfd 100644 --- a/cmd/root.go +++ b/cmd/root.go @@ -34,6 +34,14 @@ var cfgFile string var hostsFile string var hostGroup string +var loadSSHConfigFunc = sshConn.LoadSSHConfig + +var resolveHostFunc = func(resolver *sshConn.SSHConfigResolver, spec sshConn.HostSpec, fallbackUser string) (sshConn.ResolvedHost, error) { + return resolver.ResolveHost(spec, fallbackUser) +} + +var spawnShellFunc = shell.Spawn + // RootCmd represents the base command when called without any subcommands var RootCmd = &cobra.Command{ Use: "pretty", @@ -50,19 +58,17 @@ usage: } return nil }, - Run: func(cmd *cobra.Command, args []string) { + RunE: func(cmd *cobra.Command, args []string) error { argsLen := len(args) hostSpecs, err := parseArgsHosts(args) if err != nil { - fmt.Println(err) - os.Exit(1) + return err } if hostGroup != "" { groupSpecs, err := parseGroupSpecs(viper.Get(fmt.Sprintf("groups.%s", hostGroup)), hostGroup) if err != nil { - fmt.Println(err) - os.Exit(1) + return err } if argsLen > 1 { hostSpecs = append(hostSpecs, groupSpecs...) @@ -74,13 +80,11 @@ usage: if hostsFile != "" { data, err := ioutil.ReadFile(hostsFile) if err != nil { - fmt.Printf("unable to read hostsFile: %v\n", err) - os.Exit(1) + return fmt.Errorf("unable to read hostsFile: %w", err) } fileSpecs, err := parseHostsFile(data) if err != nil { - fmt.Println(err) - os.Exit(1) + return err } hostSpecs = append(hostSpecs, fileSpecs...) } @@ -110,13 +114,12 @@ usage: if home, err := os.UserHomeDir(); err == nil { userConfigPath = filepath.Join(home, ".ssh", "config") } - resolver, err := sshConn.LoadSSHConfig(sshConn.SSHConfigPaths{ + resolver, err := loadSSHConfigFunc(sshConn.SSHConfigPaths{ User: userConfigPath, System: "/etc/ssh/ssh_config", }) if err != nil { - fmt.Printf("unable to load ssh config: %v\n", err) - os.Exit(1) + return fmt.Errorf("unable to load ssh config: %w", err) } globalUser := strings.TrimSpace(viper.GetString("username")) @@ -135,10 +138,9 @@ usage: resolveSpec.User = globalUser resolveSpec.UserSet = true } - resolved, err := resolver.ResolveHost(resolveSpec, "") + resolved, err := resolveHostFunc(resolver, resolveSpec, "") if err != nil { - fmt.Printf("unable to resolve host %q: %v\n", spec.Host, err) - os.Exit(1) + return fmt.Errorf("unable to resolve host %q: %w", spec.Host, err) } jumps := make([]sshConn.ResolvedHost, 0, len(resolved.ProxyJump)) for _, jumpAlias := range resolved.ProxyJump { @@ -147,10 +149,9 @@ usage: jumpSpec.User = globalUser jumpSpec.UserSet = true } - jumpResolved, err := resolver.ResolveHost(jumpSpec, "") + jumpResolved, err := resolveHostFunc(resolver, jumpSpec, "") if err != nil { - fmt.Printf("unable to resolve jump host %q: %v\n", jumpAlias, err) - os.Exit(1) + return fmt.Errorf("unable to resolve jump host %q: %w", jumpAlias, err) } jumps = append(jumps, jumpResolved) } @@ -167,17 +168,15 @@ usage: } hostList.AddHost(host) } - shell.Spawn(hostList) + spawnShellFunc(hostList) + return nil }, } // Execute adds all child commands to the root command and sets flags appropriately. // This is called by main.main(). It only needs to happen once to the rootCmd. -func Execute() { - if err := RootCmd.Execute(); err != nil { - fmt.Println(err) - os.Exit(1) - } +func Execute() error { + return RootCmd.Execute() } func init() { diff --git a/cmd/root_test.go b/cmd/root_test.go new file mode 100644 index 0000000..6a6d1d1 --- /dev/null +++ b/cmd/root_test.go @@ -0,0 +1,324 @@ +package cmd + +import ( + "errors" + "os" + "strings" + "testing" + + "github.com/ncode/pretty/internal/sshConn" + "github.com/spf13/viper" +) + +func TestRootCmdUsesRunE(t *testing.T) { + if RootCmd.RunE == nil { + t.Fatalf("expected RootCmd.RunE to be configured") + } +} + +func TestExecuteReturnsErrorForMissingHosts(t *testing.T) { + prevHostGroup := hostGroup + prevHostsFile := hostsFile + t.Cleanup(func() { + hostGroup = prevHostGroup + hostsFile = prevHostsFile + RootCmd.SetArgs(nil) + }) + + hostGroup = "" + hostsFile = "" + RootCmd.SetArgs([]string{}) + + err := Execute() + if err == nil { + t.Fatalf("expected error") + } + if !strings.Contains(err.Error(), "requires at least one host") { + t.Fatalf("unexpected error: %v", err) + } +} + +func TestExecuteReturnsErrorForInvalidHostSpec(t *testing.T) { + prevHostGroup := hostGroup + prevHostsFile := hostsFile + t.Cleanup(func() { + hostGroup = prevHostGroup + hostsFile = prevHostsFile + RootCmd.SetArgs(nil) + }) + + hostGroup = "" + hostsFile = "" + RootCmd.SetArgs([]string{":"}) + + err := Execute() + if err == nil { + t.Fatalf("expected error") + } + if !strings.Contains(err.Error(), "invalid host") { + t.Fatalf("unexpected error: %v", err) + } +} + +func TestExecuteReturnsErrorWhenHostsFileCannotBeRead(t *testing.T) { + prevHostGroup := hostGroup + prevHostsFile := hostsFile + t.Cleanup(func() { + hostGroup = prevHostGroup + hostsFile = prevHostsFile + RootCmd.SetArgs(nil) + }) + + hostGroup = "" + hostsFile = "/path/that/does/not/exist" + RootCmd.SetArgs([]string{"host1"}) + + err := Execute() + if err == nil { + t.Fatalf("expected error") + } + if !strings.Contains(err.Error(), "unable to read hostsFile") { + t.Fatalf("unexpected error: %v", err) + } +} + +func TestExecuteReturnsErrorForInvalidGroupSpec(t *testing.T) { + prevHostGroup := hostGroup + prevHostsFile := hostsFile + t.Cleanup(func() { + hostGroup = prevHostGroup + hostsFile = prevHostsFile + viper.Set("groups.bad", nil) + RootCmd.SetArgs(nil) + }) + + viper.Set("groups.bad", map[string]interface{}{"user": "deploy"}) + hostGroup = "bad" + hostsFile = "" + RootCmd.SetArgs([]string{"host1"}) + + err := Execute() + if err == nil { + t.Fatalf("expected error") + } + if !strings.Contains(err.Error(), "missing hosts") { + t.Fatalf("unexpected error: %v", err) + } +} + +func TestExecuteReturnsErrorForInvalidHostsFileContent(t *testing.T) { + prevHostGroup := hostGroup + prevHostsFile := hostsFile + t.Cleanup(func() { + hostGroup = prevHostGroup + hostsFile = prevHostsFile + RootCmd.SetArgs(nil) + }) + + f, err := os.CreateTemp(t.TempDir(), "hosts-*.txt") + if err != nil { + t.Fatalf("unexpected temp file error: %v", err) + } + if _, err := f.WriteString("host1:badport\n"); err != nil { + t.Fatalf("unexpected write error: %v", err) + } + if err := f.Close(); err != nil { + t.Fatalf("unexpected close error: %v", err) + } + + hostGroup = "" + hostsFile = f.Name() + RootCmd.SetArgs([]string{"host2"}) + + err = Execute() + if err == nil { + t.Fatalf("expected error") + } + if !strings.Contains(err.Error(), "invalid port") { + t.Fatalf("unexpected error: %v", err) + } +} + +func TestExecuteReturnsErrorWhenLoadingSSHConfigFails(t *testing.T) { + prevHostGroup := hostGroup + prevHostsFile := hostsFile + prevLoad := loadSSHConfigFunc + t.Cleanup(func() { + hostGroup = prevHostGroup + hostsFile = prevHostsFile + loadSSHConfigFunc = prevLoad + RootCmd.SetArgs(nil) + }) + + loadSSHConfigFunc = func(paths sshConn.SSHConfigPaths) (*sshConn.SSHConfigResolver, error) { + return nil, errors.New("config boom") + } + + hostGroup = "" + hostsFile = "" + RootCmd.SetArgs([]string{"host1"}) + + err := Execute() + if err == nil { + t.Fatalf("expected error") + } + if !strings.Contains(err.Error(), "unable to load ssh config") { + t.Fatalf("unexpected error: %v", err) + } +} + +func TestExecuteReturnsErrorWhenResolveHostFails(t *testing.T) { + prevHostGroup := hostGroup + prevHostsFile := hostsFile + prevLoad := loadSSHConfigFunc + prevResolve := resolveHostFunc + t.Cleanup(func() { + hostGroup = prevHostGroup + hostsFile = prevHostsFile + loadSSHConfigFunc = prevLoad + resolveHostFunc = prevResolve + RootCmd.SetArgs(nil) + }) + + loadSSHConfigFunc = func(paths sshConn.SSHConfigPaths) (*sshConn.SSHConfigResolver, error) { + return &sshConn.SSHConfigResolver{}, nil + } + resolveHostFunc = func(resolver *sshConn.SSHConfigResolver, spec sshConn.HostSpec, fallbackUser string) (sshConn.ResolvedHost, error) { + return sshConn.ResolvedHost{}, errors.New("resolve boom") + } + + hostGroup = "" + hostsFile = "" + RootCmd.SetArgs([]string{"host1"}) + + err := Execute() + if err == nil { + t.Fatalf("expected error") + } + if !strings.Contains(err.Error(), "unable to resolve host") { + t.Fatalf("unexpected error: %v", err) + } +} + +func TestExecuteReturnsErrorWhenResolveJumpFails(t *testing.T) { + prevHostGroup := hostGroup + prevHostsFile := hostsFile + prevLoad := loadSSHConfigFunc + prevResolve := resolveHostFunc + t.Cleanup(func() { + hostGroup = prevHostGroup + hostsFile = prevHostsFile + loadSSHConfigFunc = prevLoad + resolveHostFunc = prevResolve + RootCmd.SetArgs(nil) + }) + + loadSSHConfigFunc = func(paths sshConn.SSHConfigPaths) (*sshConn.SSHConfigResolver, error) { + return &sshConn.SSHConfigResolver{}, nil + } + call := 0 + resolveHostFunc = func(resolver *sshConn.SSHConfigResolver, spec sshConn.HostSpec, fallbackUser string) (sshConn.ResolvedHost, error) { + call++ + if call == 1 { + return sshConn.ResolvedHost{Alias: spec.Alias, Host: spec.Host, Port: 22, ProxyJump: []string{"jump1"}}, nil + } + return sshConn.ResolvedHost{}, errors.New("jump boom") + } + + hostGroup = "" + hostsFile = "" + RootCmd.SetArgs([]string{"host1"}) + + err := Execute() + if err == nil { + t.Fatalf("expected error") + } + if !strings.Contains(err.Error(), "unable to resolve jump host") { + t.Fatalf("unexpected error: %v", err) + } +} + +func TestExecuteReturnsNilOnSuccessfulSetup(t *testing.T) { + prevHostGroup := hostGroup + prevHostsFile := hostsFile + prevLoad := loadSSHConfigFunc + prevSpawn := spawnShellFunc + t.Cleanup(func() { + hostGroup = prevHostGroup + hostsFile = prevHostsFile + loadSSHConfigFunc = prevLoad + spawnShellFunc = prevSpawn + RootCmd.SetArgs(nil) + }) + + loadSSHConfigFunc = func(paths sshConn.SSHConfigPaths) (*sshConn.SSHConfigResolver, error) { + return &sshConn.SSHConfigResolver{}, nil + } + spawnCalled := false + spawnShellFunc = func(hostList *sshConn.HostList) { + spawnCalled = true + } + + hostGroup = "" + hostsFile = "" + RootCmd.SetArgs([]string{"host1"}) + + err := Execute() + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if !spawnCalled { + t.Fatalf("expected spawn to be called") + } +} + +func TestExecuteAppliesGlobalUserToHostAndJumpSpecs(t *testing.T) { + prevHostGroup := hostGroup + prevHostsFile := hostsFile + prevLoad := loadSSHConfigFunc + prevResolve := resolveHostFunc + prevSpawn := spawnShellFunc + prevUsername := viper.Get("username") + t.Cleanup(func() { + hostGroup = prevHostGroup + hostsFile = prevHostsFile + loadSSHConfigFunc = prevLoad + resolveHostFunc = prevResolve + spawnShellFunc = prevSpawn + viper.Set("username", prevUsername) + RootCmd.SetArgs(nil) + }) + + loadSSHConfigFunc = func(paths sshConn.SSHConfigPaths) (*sshConn.SSHConfigResolver, error) { + return &sshConn.SSHConfigResolver{}, nil + } + spawnShellFunc = func(hostList *sshConn.HostList) {} + viper.Set("username", "deploy") + + call := 0 + resolveHostFunc = func(resolver *sshConn.SSHConfigResolver, spec sshConn.HostSpec, fallbackUser string) (sshConn.ResolvedHost, error) { + call++ + if call == 1 { + if !spec.UserSet || spec.User != "deploy" { + t.Fatalf("expected global user on host resolve, got user=%q userSet=%v", spec.User, spec.UserSet) + } + return sshConn.ResolvedHost{Alias: spec.Alias, Host: spec.Host, Port: 22, User: spec.User, ProxyJump: []string{"jump1"}}, nil + } + if !spec.UserSet || spec.User != "deploy" { + t.Fatalf("expected global user on jump resolve, got user=%q userSet=%v", spec.User, spec.UserSet) + } + return sshConn.ResolvedHost{Alias: spec.Alias, Host: spec.Host, Port: 22, User: spec.User}, nil + } + + hostGroup = "" + hostsFile = "" + RootCmd.SetArgs([]string{"host1"}) + + err := Execute() + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if call != 2 { + t.Fatalf("expected 2 resolve calls, got %d", call) + } +} diff --git a/internal/shell/model_test.go b/internal/shell/model_test.go index 7ad1906..5bb496d 100644 --- a/internal/shell/model_test.go +++ b/internal/shell/model_test.go @@ -487,3 +487,60 @@ func BenchmarkAppendOutputs(b *testing.B) { m.appendOutputs(lines...) } } + +func TestConnectedHosts(t *testing.T) { + hostList := sshConn.NewHostList() + hostList.AddHost(&sshConn.Host{Hostname: "host1", IsConnected: 1}) + hostList.AddHost(&sshConn.Host{Hostname: "host2", IsConnected: 0}) + + hosts := connectedHosts(hostList) + if len(hosts) != 1 { + t.Fatalf("expected 1 connected host, got %d", len(hosts)) + } + if hosts[0].Hostname != "host1" { + t.Fatalf("unexpected host: %s", hosts[0].Hostname) + } +} + +func TestConnectedHostsNilList(t *testing.T) { + if got := connectedHosts(nil); got != nil { + t.Fatalf("expected nil hosts, got %#v", got) + } +} + +func TestHostnames(t *testing.T) { + hosts := []*sshConn.Host{{Hostname: "alpha"}, {Hostname: "beta"}} + got := hostnames(hosts) + if len(got) != 2 || got[0] != "alpha" || got[1] != "beta" { + t.Fatalf("unexpected hostnames: %#v", got) + } +} + +func TestInitReturnsNilCmdWithoutEvents(t *testing.T) { + m := initialModel(nil, nil, nil) + if cmd := m.Init(); cmd != nil { + t.Fatalf("expected nil cmd when events channel is nil") + } +} + +func TestInitReturnsOutputBatchFromEvents(t *testing.T) { + events := make(chan sshConn.OutputEvent, 2) + m := initialModel(nil, nil, events) + + events <- sshConn.OutputEvent{Hostname: "h1", Line: "one"} + events <- sshConn.OutputEvent{Hostname: "h1", Line: "two"} + close(events) + + cmd := m.Init() + if cmd == nil { + t.Fatalf("expected non-nil cmd") + } + msg := cmd() + got, ok := msg.(outputMsg) + if !ok { + t.Fatalf("expected outputMsg, got %T", msg) + } + if len(got.events) != 2 { + t.Fatalf("expected 2 events, got %d", len(got.events)) + } +} diff --git a/internal/shell/view_test.go b/internal/shell/view_test.go new file mode 100644 index 0000000..9c1b927 --- /dev/null +++ b/internal/shell/view_test.go @@ -0,0 +1,23 @@ +package shell + +import "testing" + +func TestViewReturnsEmptyWhenQuit(t *testing.T) { + m := initialModel(nil, nil, nil) + m.quit = true + + if got := m.View(); got != "" { + t.Fatalf("expected empty view on quit, got %q", got) + } +} + +func TestViewIncludesViewportAndInput(t *testing.T) { + m := initialModel(nil, nil, nil) + m.viewport.SetContent("output-line") + m.input.SetValue("echo hi") + + got := m.View() + if got == "" { + t.Fatalf("expected non-empty view") + } +} diff --git a/internal/sshConn/message.go b/internal/sshConn/message.go index 74d7726..796395b 100644 --- a/internal/sshConn/message.go +++ b/internal/sshConn/message.go @@ -6,6 +6,12 @@ import ( "sync/atomic" ) +var ( + connectionFunc = Connection + sessionFunc = Session + workerRunner = worker +) + type ProxyWriter struct { events chan<- OutputEvent host *Host @@ -62,7 +68,7 @@ func emitSystem(events chan<- OutputEvent, host *Host, line string) { } func worker(host *Host, input <-chan CommandRequest, events chan<- OutputEvent) { - connection, err := Connection(host) + connection, err := connectionFunc(host) if err != nil { emitSystem(events, host, fmt.Sprintf("error connection to host %s: %v", host.Hostname, err)) return @@ -72,7 +78,7 @@ func worker(host *Host, input <-chan CommandRequest, events chan<- OutputEvent) stdoutWriter := NewProxyWriter(events, host, 0) stderrWriter := NewProxyWriter(events, host, 0) stderrWriter.system = true - stdin, session, err := Session(connection, host, stdoutWriter, stderrWriter) + stdin, session, err := sessionFunc(connection, host, stdoutWriter, stderrWriter) if err != nil { emitSystem(events, host, fmt.Sprintf("unable to open session: %v", err)) atomic.StoreInt32(&host.IsConnected, 0) @@ -100,7 +106,7 @@ func worker(host *Host, input <-chan CommandRequest, events chan<- OutputEvent) func Broker(hostList *HostList, input <-chan CommandRequest, events chan<- OutputEvent) { for _, host := range hostList.Hosts() { host.Channel = make(chan CommandRequest) - go worker(host, host.Channel, events) + go workerRunner(host, host.Channel, events) } for request := range input { diff --git a/internal/sshConn/message_test.go b/internal/sshConn/message_test.go new file mode 100644 index 0000000..9c478e9 --- /dev/null +++ b/internal/sshConn/message_test.go @@ -0,0 +1,135 @@ +package sshConn + +import ( + "errors" + "io" + "strings" + "sync/atomic" + "testing" + + "golang.org/x/crypto/ssh" +) + +type captureWriteCloser struct { + buf []byte +} + +func (w *captureWriteCloser) Write(p []byte) (int, error) { + w.buf = append(w.buf, p...) + return len(p), nil +} + +func (w *captureWriteCloser) Close() error { return nil } + +func TestWorkerEmitsConnectionError(t *testing.T) { + prevConnection := connectionFunc + prevSession := sessionFunc + t.Cleanup(func() { + connectionFunc = prevConnection + sessionFunc = prevSession + }) + + connectionFunc = func(host *Host) (*ssh.Client, error) { + return nil, errors.New("dial failed") + } + sessionFunc = prevSession + + host := &Host{Hostname: "host1"} + events := make(chan OutputEvent, 1) + input := make(chan CommandRequest) + close(input) + + worker(host, input, events) + + select { + case evt := <-events: + if !evt.System { + t.Fatalf("expected system event") + } + if !strings.Contains(evt.Line, "dial failed") { + t.Fatalf("unexpected line: %q", evt.Line) + } + default: + t.Fatalf("expected connection error event") + } +} + +func TestWorkerHandlesRequestsWithStubSession(t *testing.T) { + prevConnection := connectionFunc + prevSession := sessionFunc + t.Cleanup(func() { + connectionFunc = prevConnection + sessionFunc = prevSession + }) + + connectionFunc = func(host *Host) (*ssh.Client, error) { + return &ssh.Client{}, nil + } + stdin := &captureWriteCloser{} + sessionFunc = func(connection *ssh.Client, host *Host, stdout, stderr io.Writer) (io.WriteCloser, *ssh.Session, error) { + return stdin, nil, nil + } + + host := &Host{Hostname: "host1"} + events := make(chan OutputEvent, 1) + input := make(chan CommandRequest, 2) + input <- CommandRequest{Kind: CommandKindRun, JobID: 7, Command: "uptime"} + input <- CommandRequest{Kind: CommandKindControl, JobID: 8, ControlByte: 0x03} + close(input) + + worker(host, input, events) + + written := string(stdin.buf) + if !strings.Contains(written, "uptime\n") { + t.Fatalf("expected command write, got %q", written) + } + if !strings.Contains(written, string([]byte{0x03})) { + t.Fatalf("expected control byte write, got %q", written) + } + if atomic.LoadInt32(&host.IsConnected) != 1 { + t.Fatalf("expected host connected") + } + if atomic.LoadInt32(&host.IsWaiting) != 0 { + t.Fatalf("expected host not waiting") + } +} + +func TestBrokerDispatchesOnlyToConnectedHosts(t *testing.T) { + prevWorker := workerRunner + t.Cleanup(func() { + workerRunner = prevWorker + }) + + dispatched := make(chan string, 2) + workerRunner = func(host *Host, input <-chan CommandRequest, events chan<- OutputEvent) { + request := <-input + dispatched <- host.Hostname + ":" + request.Command + } + + hostList := NewHostList() + host1 := &Host{Hostname: "host1", IsConnected: 1} + host2 := &Host{Hostname: "host2", IsConnected: 0} + hostList.AddHost(host1) + hostList.AddHost(host2) + + input := make(chan CommandRequest, 1) + done := make(chan struct{}) + go func() { + Broker(hostList, input, nil) + close(done) + }() + + input <- CommandRequest{Kind: CommandKindRun, JobID: 1, Command: "date"} + close(input) + <-done + + got := <-dispatched + if got != "host1:date" { + t.Fatalf("unexpected dispatch %q", got) + } + select { + case extra := <-dispatched: + t.Fatalf("unexpected extra dispatch: %s", extra) + default: + } +} diff --git a/internal/sshConn/output_test.go b/internal/sshConn/output_test.go index 5bcb3b8..fb04dce 100644 --- a/internal/sshConn/output_test.go +++ b/internal/sshConn/output_test.go @@ -1,6 +1,11 @@ package sshConn -import "testing" +import ( + "io" + "os" + "strings" + "testing" +) func TestProxyWriterEmitsLines(t *testing.T) { events := make(chan OutputEvent, 1) @@ -52,3 +57,46 @@ func BenchmarkProxyWriterWrite(b *testing.B) { } } } + +func TestEmitSystemSendsEventWhenChannelExists(t *testing.T) { + events := make(chan OutputEvent, 1) + host := &Host{Hostname: "host1"} + + emitSystem(events, host, "failed") + + select { + case evt := <-events: + if !evt.System { + t.Fatalf("expected system event") + } + if evt.Hostname != "host1" || evt.Line != "failed" { + t.Fatalf("unexpected event: %+v", evt) + } + default: + t.Fatalf("expected event") + } +} + +func TestEmitSystemPrintsWhenNoChannel(t *testing.T) { + host := &Host{Hostname: "host1"} + + oldStdout := os.Stdout + r, w, err := os.Pipe() + if err != nil { + t.Fatalf("unexpected pipe error: %v", err) + } + os.Stdout = w + + emitSystem(nil, host, "hello") + + _ = w.Close() + os.Stdout = oldStdout + + data, err := io.ReadAll(r) + if err != nil { + t.Fatalf("unexpected read error: %v", err) + } + if !strings.Contains(string(data), "hello") { + t.Fatalf("expected output to contain message, got %q", string(data)) + } +} diff --git a/internal/sshConn/ssh_test.go b/internal/sshConn/ssh_test.go index ff2591d..277ee42 100644 --- a/internal/sshConn/ssh_test.go +++ b/internal/sshConn/ssh_test.go @@ -9,3 +9,43 @@ func TestDialAddressUsesPort(t *testing.T) { t.Fatalf("unexpected address: %s", got) } } + +func TestResolvedAddressUsesPort(t *testing.T) { + host := ResolvedHost{Host: "example.com", Port: 2200} + got := resolvedAddress(host) + if got != "example.com:2200" { + t.Fatalf("unexpected address: %s", got) + } +} + +func TestHostListLifecycleAndState(t *testing.T) { + hostList := NewHostList() + if hostList == nil { + t.Fatal("expected host list") + } + if hostList.Len() != 0 { + t.Fatalf("expected empty list, got %d", hostList.Len()) + } + + h1 := &Host{Hostname: "h1", IsConnected: 1, IsWaiting: 1} + h2 := &Host{Hostname: "h2", IsConnected: 1, IsWaiting: 0} + h3 := &Host{Hostname: "h3", IsConnected: 0, IsWaiting: 1} + + hostList.AddHost(h1) + hostList.AddHost(h2) + hostList.AddHost(h3) + + if hostList.Len() != 3 { + t.Fatalf("expected len 3, got %d", hostList.Len()) + } + + hosts := hostList.Hosts() + if len(hosts) != 3 { + t.Fatalf("expected 3 hosts, got %d", len(hosts)) + } + + connected, waiting := hostList.State() + if connected != 2 || waiting != 1 { + t.Fatalf("expected connected=2 waiting=1, got connected=%d waiting=%d", connected, waiting) + } +} diff --git a/main.go b/main.go index b48e2fa..8c910d3 100644 --- a/main.go +++ b/main.go @@ -15,9 +15,15 @@ package main import ( + "fmt" + "os" + "github.com/ncode/pretty/cmd" ) func main() { - cmd.Execute() + if err := cmd.Execute(); err != nil { + fmt.Println(err) + os.Exit(1) + } } diff --git a/main_test.go b/main_test.go new file mode 100644 index 0000000..a96fad3 --- /dev/null +++ b/main_test.go @@ -0,0 +1,32 @@ +package main + +import ( + "errors" + "os" + "os/exec" + "strings" + "testing" +) + +func TestMainExitsOnExecuteError(t *testing.T) { + if os.Getenv("PRETTY_MAIN_CHILD") == "1" { + os.Args = []string{"pretty"} + main() + return + } + + cmd := exec.Command(os.Args[0], "-test.run=TestMainExitsOnExecuteError") + cmd.Env = append(os.Environ(), "PRETTY_MAIN_CHILD=1") + out, err := cmd.CombinedOutput() + + var exitErr *exec.ExitError + if !errors.As(err, &exitErr) { + t.Fatalf("expected process exit error, got %v", err) + } + if exitErr.ExitCode() != 1 { + t.Fatalf("expected exit code 1, got %d", exitErr.ExitCode()) + } + if !strings.Contains(string(out), "requires at least one host") { + t.Fatalf("expected error output, got %q", string(out)) + } +}