]> ruderich.org/simon Gitweb - safcm/safcm.git/blobdiff - rpc/dial.go
rpc: use SSHConfig struct as argument to DialSSH()
[safcm/safcm.git] / rpc / dial.go
index b723e55c5657e78609d79b8019448636c3de5eef..c0c6cd4088b2859d758aa0d06cf64405a2b70a16 100644 (file)
@@ -32,11 +32,21 @@ import (
        "ruderich.org/simon/safcm/remote"
 )
 
-func (c *Conn) DialSSH(remote string) error {
+type SSHConfig struct {
+       Host      string
+       User      string // optional
+       SshConfig string // optional
+}
+
+func (c *Conn) DialSSH(cfg SSHConfig) error {
        if c.events == nil {
                return fmt.Errorf("cannot reuse Conn")
        }
 
+       remote := cfg.Host
+       if cfg.User != "" {
+               remote = cfg.User + "@" + cfg.Host
+       }
        c.debugf("DialSSH: connecting to %q", remote)
 
        opts := "-eu"
@@ -44,8 +54,14 @@ func (c *Conn) DialSSH(remote string) error {
                // Help debugging by showing executed shell commands
                opts += "x"
        }
-       c.cmd = exec.Command("ssh", remote, "/bin/sh", opts)
-       c.remote = remote
+
+       c.sshRemote = remote
+       if cfg.SshConfig != "" {
+               c.sshOpts = []string{"-F", cfg.SshConfig}
+       }
+       c.cmd = exec.Command("ssh",
+               append(append([]string{}, c.sshOpts...),
+                       c.sshRemote, "/bin/sh", opts)...)
 
        stdin, err := c.cmd.StdinPipe()
        if err != nil {
@@ -94,6 +110,36 @@ func (c *Conn) dialSSH(stdin io.Writer, stdout_ io.Reader) error {
        path := fmt.Sprintf("/tmp/safcm-remote-%d", uid)
 
        c.debugf("DialSSH: probing remote at %q", path)
+
+       // Compatibility for different operating systems
+       var compat string
+       switch goos {
+       case "linux":
+               compat = `
+dir_stat='drwxrwxrwt 0 0'
+file_stat="-rwx------ $(id -u) $(id -g)"
+compat_stat() {
+       stat -c '%A %u %g' "$1"
+}
+compat_sha512sum() {
+       sha512sum "$1"
+}
+`
+       case "freebsd", "openbsd":
+               compat = `
+dir_stat='41777 0 0'
+file_stat="100700 $(id -u) $(id -g)"
+compat_stat() {
+       stat -f '%p %u %g' "$1"
+}
+compat_sha512sum() {
+       sha512 -q "$1"
+}
+`
+       default:
+               return fmt.Errorf("internal error: no support for %q", goos)
+       }
+
        // Use a function so the shell cannot execute the input line-wise.
        // This is important because we're also using stdin to send data to
        // the script. If the shell executes the input line-wise then our
@@ -101,29 +147,30 @@ func (c *Conn) dialSSH(stdin io.Writer, stdout_ io.Reader) error {
        //
        // The target directory must no permit other users to delete our files
        // or symlink attacks and arbitrary code execution is possible. For
-       // /tmp this is guaranteed by the sticky bit. Make sure it has the
-       // proper permissions.
+       // /tmp this is guaranteed by the sticky bit. The code verifies the
+       // directory has the proper permissions.
        //
        // We cannot use `test -f && test -O` because this is open to TOCTOU
        // attacks. `stat` gives use the full file state. If the file is owned
-       // by us and not a symlink then it's safe to use (assuming sticky or
-       // directory not writable by others).
+       // by us and not a symlink then it's safe to use (assuming sticky
+       // directory or directory not writable by others).
        //
        // `test -e` is only used to prevent error messages if the file
        // doesn't exist. It does not guard against any races.
        _, err = fmt.Fprintf(stdin, `
+%s
 f() {
        x=%q
 
        dir="$(dirname "$x")"
-       if ! test "$(stat -c '%%A %%u %%g' "$dir")" = 'drwxrwxrwt 0 0'; then
+       if ! test "$(compat_stat "$dir")" = "$dir_stat"; then
                echo "unsafe permissions on $dir, aborting" >&2
                exit 1
        fi
 
-       if test -e "$x" && test "$(stat -c '%%A %%u' "$x")" = "-rwx------ $(id -u)"; then
+       if test -e "$x" && test "$(compat_stat "$x")" = "$file_stat"; then
                # Report checksum
-               sha512sum "$x"
+               compat_sha512sum "$x"
        else
                # Empty checksum to request upload
                echo
@@ -136,7 +183,6 @@ f() {
                tmp="$(mktemp "$x.XXXXXX")"
                # Report filename for upload
                echo "$tmp"
-
                # Wait for upload to complete
                read unused
 
@@ -146,12 +192,14 @@ f() {
                rm "$tmp"
                # Make file executable
                chmod 0700 "$x"
+               # Some BSD create files with group wheel in /tmp
+               chgrp "$(id -g)" "$x"
        fi
 
-       exec "$x"
+       exec "$x" sync
 }
 f
-`, path)
+`, compat, path)
        if err != nil {
                return err
        }
@@ -214,12 +262,14 @@ f
                path = strings.TrimSuffix(path, "\n")
 
                c.debugf("DialSSH: uploading new remote to %q at %q",
-                       c.remote, path)
+                       c.sshRemote, path)
 
-               cmd := exec.Command("ssh", c.remote,
-                       fmt.Sprintf("cat > %q", path))
+               cmd := exec.Command("ssh",
+                       append(append([]string{}, c.sshOpts...),
+                               c.sshRemote,
+                               fmt.Sprintf("cat > %q", path))...)
                cmd.Stdin = bytes.NewReader(helper)
-               err = c.handleStderrAsEvents(cmd)
+               err = c.handleStderrAsEvents(cmd) // cmd.Stderr
                if err != nil {
                        return err
                }
@@ -239,7 +289,7 @@ f
 }
 
 func connGetGoos(stdin io.Writer, stdout *bufio.Reader) (string, error) {
-       _, err := fmt.Fprintln(stdin, "uname -o")
+       _, err := fmt.Fprintln(stdin, "uname")
        if err != nil {
                return "", err
        }
@@ -252,10 +302,14 @@ func connGetGoos(stdin io.Writer, stdout *bufio.Reader) (string, error) {
        // NOTE: Adapt helper uploading in dialSSH() when adding new systems
        var goos string
        switch x {
-       case "GNU/Linux":
+       case "Linux":
                goos = "linux"
+       case "FreeBSD":
+               goos = "freebsd"
+       case "OpenBSD":
+               goos = "openbsd"
        default:
-               return "", fmt.Errorf("unsupported OS %q (`uname -o`)", x)
+               return "", fmt.Errorf("unsupported OS %q (`uname`)", x)
        }
        return goos, nil
 }
@@ -274,7 +328,7 @@ func connGetGoarch(stdin io.Writer, stdout *bufio.Reader) (string, error) {
        // NOTE: Adapt cmd/safcm-remote/build.sh when adding new architectures
        var goarch string
        switch x {
-       case "x86_64":
+       case "x86_64", "amd64":
                goarch = "amd64"
        case "armv7l":
                goarch = "armv7l"