From c747bad9778f4902b380e790b9588e301dfd8390 Mon Sep 17 00:00:00 2001 From: Heitor Danilo Date: Tue, 26 Sep 2023 12:02:36 -0300 Subject: [PATCH] feat(ssh, agent): handle commands with tty --- pkg/agent/server/modes/host/pty.go | 66 ++++++++++++-- pkg/agent/server/modes/host/sessioner.go | 106 +++++++++++++++-------- pkg/agent/server/session.go | 1 + ssh/server/handler/ssh.go | 32 +++++-- ssh/session/session.go | 7 +- 5 files changed, 153 insertions(+), 59 deletions(-) diff --git a/pkg/agent/server/modes/host/pty.go b/pkg/agent/server/modes/host/pty.go index 99d88c6af60..957fd908fb4 100644 --- a/pkg/agent/server/modes/host/pty.go +++ b/pkg/agent/server/modes/host/pty.go @@ -6,13 +6,13 @@ import ( "os/exec" "syscall" - "github.com/creack/pty" - "github.com/gliderlabs/ssh" - "github.com/sirupsen/logrus" + creackpty "github.com/creack/pty" + glidderssh "github.com/gliderlabs/ssh" + log "github.com/sirupsen/logrus" ) func openPty(c *exec.Cmd) (*os.File, *os.File, error) { - ptmx, tty, err := pty.Open() + ptmx, tty, err := creackpty.Open() if err != nil { return nil, nil, err } @@ -44,7 +44,7 @@ func openPty(c *exec.Cmd) (*os.File, *os.File, error) { return ptmx, tty, err } -func startPty(c *exec.Cmd, out io.ReadWriter, winCh <-chan ssh.Window) (*os.File, error) { +func startPty(c *exec.Cmd, out io.ReadWriter, winCh <-chan glidderssh.Window) (*os.File, error) { f, tty, err := openPty(c) if err != nil { return nil, err @@ -52,23 +52,73 @@ func startPty(c *exec.Cmd, out io.ReadWriter, winCh <-chan ssh.Window) (*os.File go func() { for win := range winCh { - _ = pty.Setsize(f, &pty.Winsize{uint16(win.Height), uint16(win.Width), 0, 0}) + _ = creackpty.Setsize(f, &creackpty.Winsize{uint16(win.Height), uint16(win.Width), 0, 0}) } }() go func() { _, err := io.Copy(out, f) if err != nil { - logrus.Warn(err) + log.Warn(err) } }() go func() { _, err := io.Copy(f, out) if err != nil { - logrus.Warn(err) + log.Warn(err) } }() return tty, nil } + +// initPty initializes and configures a new pseudo-terminal (PTY) for the provided command. Returns a pty and its corresponding tty. +func initPty(c *exec.Cmd, sess io.ReadWriter, winCh <-chan glidderssh.Window) (*os.File, *os.File, error) { + pty, tty, err := creackpty.Open() + if err != nil { + return nil, nil, err + } + + if c.Stdout == nil { + c.Stdout = tty + } + if c.Stderr == nil { + c.Stderr = tty + } + if c.Stdin == nil { + c.Stdin = tty + } + + if c.SysProcAttr == nil { + c.SysProcAttr = &syscall.SysProcAttr{} + } + + c.SysProcAttr.Setsid = true + c.SysProcAttr.Setctty = true + + // listen for window size changes from the SSH client and update the PTY's dimensions. + go func() { + for win := range winCh { + _ = creackpty.Setsize(pty, &creackpty.Winsize{uint16(win.Height), uint16(win.Width), 0, 0}) + } + }() + + // forward the command's output to the SSH session + go func() { + _, err := io.Copy(sess, pty) + if err != nil { + log.Warn(err) + } + }() + + // forward the input from the SSH session to the command + go func() { + _, err := io.Copy(pty, sess) + if err != nil { + log.Warn(err) + } + }() + + return pty, tty, nil +} diff --git a/pkg/agent/server/modes/host/sessioner.go b/pkg/agent/server/modes/host/sessioner.go index baaef188aa4..88c73cad4f0 100644 --- a/pkg/agent/server/modes/host/sessioner.go +++ b/pkg/agent/server/modes/host/sessioner.go @@ -7,6 +7,7 @@ import ( "os" "os/exec" "os/user" + "strings" "sync" gliderssh "github.com/gliderlabs/ssh" @@ -59,7 +60,7 @@ func NewSessioner(deviceName *string, cmds map[string]*exec.Cmd) *Sessioner { } } -// Shell handles the server's SSH shell session when server is running in host mode. +// Shell manages the SSH shell session of the server when operating in host mode. func (s *Sessioner) Shell(session gliderssh.Session) error { sspty, winCh, isPty := session.Pty() @@ -183,7 +184,6 @@ func (s *Sessioner) Heredoc(session gliderssh.Session) error { // Exec handles the SSH's server exec session when server is running in host mode. func (s *Sessioner) Exec(session gliderssh.Session) error { - u := new(osauth.OSAuth).LookupUser(session.User()) if len(session.Command()) == 0 { log.WithFields(log.Fields{ "user": session.User(), @@ -196,65 +196,95 @@ func (s *Sessioner) Exec(session gliderssh.Session) error { return nil } - cmd := command.NewCmd(u, "", "", *s.deviceName, session.Command()...) + user := new(osauth.OSAuth).LookupUser(session.User()) + sPty, sWinCh, sIsPty := session.Pty() - stdout, _ := cmd.StdoutPipe() - stdin, _ := cmd.StdinPipe() - stderr, _ := cmd.StderrPipe() + shell := os.Getenv("SHELL") + if shell == "" { + shell = user.Shell + } - serverConn, ok := session.Context().Value(gliderssh.ContextKeyConn).(*gossh.ServerConn) - if !ok { - return fmt.Errorf("failed to get server connection from session context") + term := sPty.Term + if sIsPty && term == "" { + term = "xterm" + } + + cmd := command.NewCmd(user, shell, term, *s.deviceName, shell, "-c", strings.Join(session.Command(), " ")) + defer session.Exit(cmd.ProcessState.ExitCode()) //nolint:errcheck + + wg := &sync.WaitGroup{} + if sIsPty { + pty, tty, err := initPty(cmd, session, sWinCh) + if err != nil { + log.Warn(err) + } + + defer tty.Close() + defer pty.Close() + + if err := os.Chown(tty.Name(), int(user.UID), -1); err != nil { + log.Warn(err) + } + } else { + stdout, _ := cmd.StdoutPipe() + stdin, _ := cmd.StdinPipe() + stderr, _ := cmd.StderrPipe() + + // relay input from the SSH session to the command. + go func() { + if _, err := io.Copy(stdin, session); err != nil { + fmt.Println(err) //nolint:forbidigo + } + + stdin.Close() + }() + + wg.Add(1) + + // relay the command's combined output and error streams back to the SSH session. + go func() { + defer wg.Done() + combinedOutput := io.MultiReader(stdout, stderr) + if _, err := io.Copy(session, combinedOutput); err != nil { + fmt.Println(err) //nolint:forbidigo + } + }() } log.WithFields(log.Fields{ "user": session.User(), + "ispty": sIsPty, "remoteaddr": session.RemoteAddr(), "localaddr": session.LocalAddr(), "Raw command": session.RawCommand(), }).Info("Command started") - err := cmd.Start() - if err != nil { - log.Warn(err) + if err := cmd.Start(); err != nil { + return err } - go func() { - serverConn.Wait() // nolint:errcheck - cmd.Process.Kill() // nolint:errcheck - }() - - go func() { - if _, err := io.Copy(stdin, session); err != nil { - fmt.Println(err) //nolint:forbidigo - } - - stdin.Close() - }() + if !sIsPty { + wg.Wait() + } - wg := &sync.WaitGroup{} - wg.Add(1) + serverConn, ok := session.Context().Value(gliderssh.ContextKeyConn).(*gossh.ServerConn) + if !ok { + return fmt.Errorf("failed to get server connection from session context") + } + // kill the process if the SSH connection is interrupted go func() { - combinedOutput := io.MultiReader(stdout, stderr) - if _, err := io.Copy(session, combinedOutput); err != nil { - fmt.Println(err) //nolint:forbidigo - } - - wg.Done() + serverConn.Wait() // nolint:errcheck + cmd.Process.Kill() // nolint:errcheck }() - wg.Wait() - - err = cmd.Wait() - if err != nil { + if err := cmd.Wait(); err != nil { log.Warn(err) } - session.Exit(cmd.ProcessState.ExitCode()) //nolint:errcheck - log.WithFields(log.Fields{ "user": session.User(), + "ispty": sIsPty, "remoteaddr": session.RemoteAddr(), "localaddr": session.LocalAddr(), "Raw command": session.RawCommand(), diff --git a/pkg/agent/server/session.go b/pkg/agent/server/session.go index ee657ed57c4..9e8f449775e 100644 --- a/pkg/agent/server/session.go +++ b/pkg/agent/server/session.go @@ -28,6 +28,7 @@ const ( // GetSessionType returns the session's type based on the SSH client session. func GetSessionType(session gliderssh.Session) (Type, error) { _, _, isPty := session.Pty() + requestType, ok := session.Context().Value("request_type").(string) if !ok { return SessionTypeUnknown, fmt.Errorf("failed to get request type from session context") diff --git a/ssh/server/handler/ssh.go b/ssh/server/handler/ssh.go index ffd3f88c9f3..24e408b7545 100644 --- a/ssh/server/handler/ssh.go +++ b/ssh/server/handler/ssh.go @@ -233,6 +233,16 @@ func isUnknownExitError(err error) bool { return err != nil } +func resizeWindow(uid string, agent *gossh.Session, winCh <-chan gliderssh.Window) { + for win := range winCh { + if err := agent.WindowChange(win.Height, win.Width); err != nil { + log.WithError(err). + WithFields(log.Fields{"client": uid}). + Error("failed to send WindowChange") + } + } +} + // shell handles an interactive terminal session. func shell(api internalclient.Client, sess *session.Session, agent *gossh.Session, client gliderssh.Session, opts ConfigOptions) error { uid := sess.UID @@ -247,15 +257,7 @@ func shell(api internalclient.Client, sess *session.Session, agent *gossh.Sessio return err } - go func() { - for win := range winCh { - if err := agent.WindowChange(win.Height, win.Width); err != nil { - log.WithError(err). - WithFields(log.Fields{"client": uid}). - Error("failed to send WindowChange") - } - } - }() + go resizeWindow(uid, agent, winCh) flw, err := flow.NewFlow(agent) if err != nil { @@ -394,6 +396,14 @@ func exec(api internalclient.Client, sess *session.Session, device *models.Devic return err } + // request a new pty when isPty is true + pty, winCh, isPty := client.Pty() + if isPty { + if err := agent.RequestPty(pty.Term, pty.Window.Height, pty.Window.Width, gossh.TerminalModes{}); err != nil { + return err + } + } + dev, err := api.GetDevice(device.UID) if err != nil { log.WithError(err). @@ -403,6 +413,10 @@ func exec(api internalclient.Client, sess *session.Session, device *models.Devic return err } + if isPty { + go resizeWindow(uid, agent, winCh) + } + waitPipeIn := make(chan bool) waitPipeOut := make(chan bool) diff --git a/ssh/session/session.go b/ssh/session/session.go index 6b6eca32f3f..0f705950872 100644 --- a/ssh/session/session.go +++ b/ssh/session/session.go @@ -50,8 +50,8 @@ type Session struct { const ( Web = "web" // web terminal. - Term = "term" // iterative pty. - Exec = "exec" // non-iterative pty. + Term = "term" // interactive session + Exec = "exec" // command execution HereDoc = "heredoc" // heredoc pty. SCP = "scp" // scp. SFTP = "sftp" // sftp subsystem. @@ -99,7 +99,7 @@ func (s *Session) setType() { s.Type = SCP case !s.Pty && metadata.RestoreRequest(ctx) == "shell": s.Type = HereDoc - case !s.Pty && cmd != "": + case cmd != "": s.Type = Exec case s.Pty: s.Type = Term @@ -130,7 +130,6 @@ func NewSession(client gliderssh.Session, tunnel *httptunnel.Tunnel) (*Session, lookup := metadata.RestoreLookup(clientCtx) lookup["username"] = tag.Username - // TODO: probabily this need an if lookup["ip_address"] = hos.Host if envs.IsCloud() || envs.IsEnterprise() {