Skip to content

Commit

Permalink
feat(ssh, agent): handle commands with tty
Browse files Browse the repository at this point in the history
  • Loading branch information
heiytor authored and gustavosbarreto committed Sep 28, 2023
1 parent 7e66ead commit c747bad
Show file tree
Hide file tree
Showing 5 changed files with 153 additions and 59 deletions.
66 changes: 58 additions & 8 deletions pkg/agent/server/modes/host/pty.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,13 @@ import (
"os/exec"
"syscall"

"github.com/creack/pty"
"github.com/gliderlabs/ssh"
"github.com/sirupsen/logrus"
creackpty "github.com/creack/pty"
glidderssh "github.com/gliderlabs/ssh"
log "github.com/sirupsen/logrus"
)

func openPty(c *exec.Cmd) (*os.File, *os.File, error) {
ptmx, tty, err := pty.Open()
ptmx, tty, err := creackpty.Open()
if err != nil {
return nil, nil, err
}
Expand Down Expand Up @@ -44,31 +44,81 @@ func openPty(c *exec.Cmd) (*os.File, *os.File, error) {
return ptmx, tty, err
}

func startPty(c *exec.Cmd, out io.ReadWriter, winCh <-chan ssh.Window) (*os.File, error) {
func startPty(c *exec.Cmd, out io.ReadWriter, winCh <-chan glidderssh.Window) (*os.File, error) {
f, tty, err := openPty(c)
if err != nil {
return nil, err
}

go func() {
for win := range winCh {
_ = pty.Setsize(f, &pty.Winsize{uint16(win.Height), uint16(win.Width), 0, 0})
_ = creackpty.Setsize(f, &creackpty.Winsize{uint16(win.Height), uint16(win.Width), 0, 0})
}
}()

go func() {
_, err := io.Copy(out, f)
if err != nil {
logrus.Warn(err)
log.Warn(err)
}
}()

go func() {
_, err := io.Copy(f, out)
if err != nil {
logrus.Warn(err)
log.Warn(err)
}
}()

return tty, nil
}

// initPty initializes and configures a new pseudo-terminal (PTY) for the provided command. Returns a pty and its corresponding tty.
func initPty(c *exec.Cmd, sess io.ReadWriter, winCh <-chan glidderssh.Window) (*os.File, *os.File, error) {
pty, tty, err := creackpty.Open()
if err != nil {
return nil, nil, err
}

if c.Stdout == nil {
c.Stdout = tty
}
if c.Stderr == nil {
c.Stderr = tty
}
if c.Stdin == nil {
c.Stdin = tty
}

if c.SysProcAttr == nil {
c.SysProcAttr = &syscall.SysProcAttr{}
}

c.SysProcAttr.Setsid = true
c.SysProcAttr.Setctty = true

// listen for window size changes from the SSH client and update the PTY's dimensions.
go func() {
for win := range winCh {
_ = creackpty.Setsize(pty, &creackpty.Winsize{uint16(win.Height), uint16(win.Width), 0, 0})
}
}()

// forward the command's output to the SSH session
go func() {
_, err := io.Copy(sess, pty)
if err != nil {
log.Warn(err)
}
}()

// forward the input from the SSH session to the command
go func() {
_, err := io.Copy(pty, sess)
if err != nil {
log.Warn(err)
}
}()

return pty, tty, nil
}
106 changes: 68 additions & 38 deletions pkg/agent/server/modes/host/sessioner.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"os"
"os/exec"
"os/user"
"strings"
"sync"

gliderssh "github.com/gliderlabs/ssh"
Expand Down Expand Up @@ -59,7 +60,7 @@ func NewSessioner(deviceName *string, cmds map[string]*exec.Cmd) *Sessioner {
}
}

// Shell handles the server's SSH shell session when server is running in host mode.
// Shell manages the SSH shell session of the server when operating in host mode.
func (s *Sessioner) Shell(session gliderssh.Session) error {
sspty, winCh, isPty := session.Pty()

Expand Down Expand Up @@ -183,7 +184,6 @@ func (s *Sessioner) Heredoc(session gliderssh.Session) error {

// Exec handles the SSH's server exec session when server is running in host mode.
func (s *Sessioner) Exec(session gliderssh.Session) error {
u := new(osauth.OSAuth).LookupUser(session.User())
if len(session.Command()) == 0 {
log.WithFields(log.Fields{
"user": session.User(),
Expand All @@ -196,65 +196,95 @@ func (s *Sessioner) Exec(session gliderssh.Session) error {
return nil
}

cmd := command.NewCmd(u, "", "", *s.deviceName, session.Command()...)
user := new(osauth.OSAuth).LookupUser(session.User())
sPty, sWinCh, sIsPty := session.Pty()

stdout, _ := cmd.StdoutPipe()
stdin, _ := cmd.StdinPipe()
stderr, _ := cmd.StderrPipe()
shell := os.Getenv("SHELL")
if shell == "" {
shell = user.Shell
}

serverConn, ok := session.Context().Value(gliderssh.ContextKeyConn).(*gossh.ServerConn)
if !ok {
return fmt.Errorf("failed to get server connection from session context")
term := sPty.Term
if sIsPty && term == "" {
term = "xterm"
}

cmd := command.NewCmd(user, shell, term, *s.deviceName, shell, "-c", strings.Join(session.Command(), " "))
defer session.Exit(cmd.ProcessState.ExitCode()) //nolint:errcheck

wg := &sync.WaitGroup{}
if sIsPty {
pty, tty, err := initPty(cmd, session, sWinCh)
if err != nil {
log.Warn(err)
}

defer tty.Close()
defer pty.Close()

if err := os.Chown(tty.Name(), int(user.UID), -1); err != nil {
log.Warn(err)
}
} else {
stdout, _ := cmd.StdoutPipe()
stdin, _ := cmd.StdinPipe()
stderr, _ := cmd.StderrPipe()

// relay input from the SSH session to the command.
go func() {
if _, err := io.Copy(stdin, session); err != nil {
fmt.Println(err) //nolint:forbidigo
}

stdin.Close()
}()

wg.Add(1)

// relay the command's combined output and error streams back to the SSH session.
go func() {
defer wg.Done()
combinedOutput := io.MultiReader(stdout, stderr)
if _, err := io.Copy(session, combinedOutput); err != nil {
fmt.Println(err) //nolint:forbidigo
}
}()
}

log.WithFields(log.Fields{
"user": session.User(),
"ispty": sIsPty,
"remoteaddr": session.RemoteAddr(),
"localaddr": session.LocalAddr(),
"Raw command": session.RawCommand(),
}).Info("Command started")

err := cmd.Start()
if err != nil {
log.Warn(err)
if err := cmd.Start(); err != nil {
return err
}

go func() {
serverConn.Wait() // nolint:errcheck
cmd.Process.Kill() // nolint:errcheck
}()

go func() {
if _, err := io.Copy(stdin, session); err != nil {
fmt.Println(err) //nolint:forbidigo
}

stdin.Close()
}()
if !sIsPty {
wg.Wait()
}

wg := &sync.WaitGroup{}
wg.Add(1)
serverConn, ok := session.Context().Value(gliderssh.ContextKeyConn).(*gossh.ServerConn)
if !ok {
return fmt.Errorf("failed to get server connection from session context")
}

// kill the process if the SSH connection is interrupted
go func() {
combinedOutput := io.MultiReader(stdout, stderr)
if _, err := io.Copy(session, combinedOutput); err != nil {
fmt.Println(err) //nolint:forbidigo
}

wg.Done()
serverConn.Wait() // nolint:errcheck
cmd.Process.Kill() // nolint:errcheck
}()

wg.Wait()

err = cmd.Wait()
if err != nil {
if err := cmd.Wait(); err != nil {
log.Warn(err)
}

session.Exit(cmd.ProcessState.ExitCode()) //nolint:errcheck

log.WithFields(log.Fields{
"user": session.User(),
"ispty": sIsPty,
"remoteaddr": session.RemoteAddr(),
"localaddr": session.LocalAddr(),
"Raw command": session.RawCommand(),
Expand Down
1 change: 1 addition & 0 deletions pkg/agent/server/session.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ const (
// GetSessionType returns the session's type based on the SSH client session.
func GetSessionType(session gliderssh.Session) (Type, error) {
_, _, isPty := session.Pty()

requestType, ok := session.Context().Value("request_type").(string)
if !ok {
return SessionTypeUnknown, fmt.Errorf("failed to get request type from session context")
Expand Down
32 changes: 23 additions & 9 deletions ssh/server/handler/ssh.go
Original file line number Diff line number Diff line change
Expand Up @@ -233,6 +233,16 @@ func isUnknownExitError(err error) bool {
return err != nil
}

func resizeWindow(uid string, agent *gossh.Session, winCh <-chan gliderssh.Window) {
for win := range winCh {
if err := agent.WindowChange(win.Height, win.Width); err != nil {
log.WithError(err).
WithFields(log.Fields{"client": uid}).
Error("failed to send WindowChange")
}
}
}

// shell handles an interactive terminal session.
func shell(api internalclient.Client, sess *session.Session, agent *gossh.Session, client gliderssh.Session, opts ConfigOptions) error {
uid := sess.UID
Expand All @@ -247,15 +257,7 @@ func shell(api internalclient.Client, sess *session.Session, agent *gossh.Sessio
return err
}

go func() {
for win := range winCh {
if err := agent.WindowChange(win.Height, win.Width); err != nil {
log.WithError(err).
WithFields(log.Fields{"client": uid}).
Error("failed to send WindowChange")
}
}
}()
go resizeWindow(uid, agent, winCh)

flw, err := flow.NewFlow(agent)
if err != nil {
Expand Down Expand Up @@ -394,6 +396,14 @@ func exec(api internalclient.Client, sess *session.Session, device *models.Devic
return err
}

// request a new pty when isPty is true
pty, winCh, isPty := client.Pty()
if isPty {
if err := agent.RequestPty(pty.Term, pty.Window.Height, pty.Window.Width, gossh.TerminalModes{}); err != nil {
return err
}
}

dev, err := api.GetDevice(device.UID)
if err != nil {
log.WithError(err).
Expand All @@ -403,6 +413,10 @@ func exec(api internalclient.Client, sess *session.Session, device *models.Devic
return err
}

if isPty {
go resizeWindow(uid, agent, winCh)
}

waitPipeIn := make(chan bool)
waitPipeOut := make(chan bool)

Expand Down
7 changes: 3 additions & 4 deletions ssh/session/session.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,8 +50,8 @@ type Session struct {

const (
Web = "web" // web terminal.
Term = "term" // iterative pty.
Exec = "exec" // non-iterative pty.
Term = "term" // interactive session
Exec = "exec" // command execution
HereDoc = "heredoc" // heredoc pty.
SCP = "scp" // scp.
SFTP = "sftp" // sftp subsystem.
Expand Down Expand Up @@ -99,7 +99,7 @@ func (s *Session) setType() {
s.Type = SCP
case !s.Pty && metadata.RestoreRequest(ctx) == "shell":
s.Type = HereDoc
case !s.Pty && cmd != "":
case cmd != "":
s.Type = Exec
case s.Pty:
s.Type = Term
Expand Down Expand Up @@ -130,7 +130,6 @@ func NewSession(client gliderssh.Session, tunnel *httptunnel.Tunnel) (*Session,
lookup := metadata.RestoreLookup(clientCtx)

lookup["username"] = tag.Username
// TODO: probabily this need an if
lookup["ip_address"] = hos.Host

if envs.IsCloud() || envs.IsEnterprise() {
Expand Down

0 comments on commit c747bad

Please sign in to comment.