--- /dev/null
+// Simple RPC-like protocol: establish new connection and upload helper
+
+// Copyright (C) 2021 Simon Ruderich
+//
+// This program is free software: you can redistribute it and/or modify
+// it under the terms of the GNU General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// This program is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU General Public License for more details.
+//
+// You should have received a copy of the GNU General Public License
+// along with this program. If not, see <http://www.gnu.org/licenses/>.
+
+package rpc
+
+import (
+ "bufio"
+ "bytes"
+ "crypto/sha512"
+ "encoding/hex"
+ "fmt"
+ "io"
+ "os/exec"
+ "strconv"
+ "strings"
+
+ "ruderich.org/simon/safcm"
+ "ruderich.org/simon/safcm/remote"
+)
+
+func (c *Conn) DialSSH(remote string) error {
+ if c.events == nil {
+ return fmt.Errorf("cannot reuse Conn")
+ }
+
+ c.debugf("DialSSH: connecting to %q", remote)
+
+ opts := "-eu"
+ if c.debug {
+ // Help debugging by showing executed shell commands
+ opts += "x"
+ }
+ c.cmd = exec.Command("ssh", remote, "/bin/sh", opts)
+ c.remote = remote
+
+ stdin, err := c.cmd.StdinPipe()
+ if err != nil {
+ return err
+ }
+ stdout, err := c.cmd.StdoutPipe()
+ if err != nil {
+ return err
+ }
+ err = c.handleStderrAsEvents(c.cmd)
+ if err != nil {
+ return err
+ }
+
+ err = c.cmd.Start()
+ if err != nil {
+ return err
+ }
+
+ err = c.dialSSH(stdin, stdout)
+ if err != nil {
+ c.Kill()
+ return err
+ }
+ c.conn = safcm.NewGobConn(stdout, stdin)
+
+ return nil
+}
+
+func (c *Conn) dialSSH(stdin io.Writer, stdout_ io.Reader) error {
+ stdout := bufio.NewReader(stdout_)
+
+ goos, err := connGetGoos(stdin, stdout)
+ if err != nil {
+ return err
+ }
+ goarch, err := connGetGoarch(stdin, stdout)
+ if err != nil {
+ return err
+ }
+ uid, err := connGetUID(stdin, stdout)
+ if err != nil {
+ return err
+ }
+
+ path := fmt.Sprintf("/tmp/safcm-remote-%d", uid)
+
+ c.debugf("DialSSH: probing remote at %q", path)
+ // Use a function so the shell cannot execute the input line-wise.
+ // This is important because we're also using stdin to send data to
+ // the script. If the shell executes the input line-wise then our
+ // script is interpreted as input for `read`.
+ //
+ // The target directory must no permit other users to delete our files
+ // or symlink attacks and arbitrary code execution is possible. For
+ // /tmp this is guaranteed by the sticky bit. Make sure it has the
+ // proper permissions.
+ //
+ // We cannot use `test -f && test -O` because this is open to TOCTOU
+ // attacks. `stat` gives use the full file state. If the file is owned
+ // by us and not a symlink then it's safe to use (assuming sticky or
+ // directory not writable by others).
+ //
+ // `test -e` is only used to prevent error messages if the file
+ // doesn't exist. It does not guard against any races.
+ _, err = fmt.Fprintf(stdin, `
+f() {
+ x=%q
+
+ dir="$(dirname "$x")"
+ if ! test "$(stat -c '%%A %%u %%g' "$dir")" = 'drwxrwxrwt 0 0'; then
+ echo "unsafe permissions on $dir, aborting" >&2
+ exit 1
+ fi
+
+ if test -e "$x" && test "$(stat -c '%%A %%u' "$x")" = "-rwx------ $(id -u)"; then
+ # Report checksum
+ sha512sum "$x"
+ else
+ # Empty checksum to request upload
+ echo
+ fi
+
+ # Wait for signal to continue
+ read upload
+
+ if test -n "$upload"; then
+ tmp="$(mktemp "$x.XXXXXX")"
+ # Report filename for upload
+ echo "$tmp"
+
+ # Wait for upload to complete
+ read unused
+
+ # Safely create new file (ln does not follow symlinks)
+ rm -f "$x"
+ ln "$tmp" "$x"
+ rm "$tmp"
+ # Make file executable
+ chmod 0700 "$x"
+ fi
+
+ exec "$x"
+}
+f
+`, path)
+ if err != nil {
+ return err
+ }
+ remoteSum, err := stdout.ReadString('\n')
+ if err != nil {
+ return err
+ }
+
+ // Get embedded helper binary
+ helper, err := remote.Helpers.ReadFile(
+ fmt.Sprintf("helpers/%s-%s", goos, goarch))
+ if err != nil {
+ return fmt.Errorf("remote not built for GOOS/GOARCH %s/%s",
+ goos, goarch)
+ }
+
+ var upload bool
+ if remoteSum == "\n" {
+ upload = true
+ c.debugf("DialSSH: remote not present or invalid permissions")
+
+ } else {
+ x := strings.Fields(remoteSum)
+ if len(x) < 1 {
+ return fmt.Errorf("got unexpected checksum line %q",
+ remoteSum)
+ }
+ sha := sha512.Sum512(helper)
+ hex := hex.EncodeToString(sha[:])
+ if hex == x[0] {
+ c.debugf("DialSSH: remote checksum matches")
+ } else {
+ upload = true
+ c.debugf("DialSSH: remote checksum does not match")
+ }
+ }
+
+ if upload {
+ // Notify user that an upload is going to take place.
+ c.events <- ConnEvent{
+ Type: ConnEventUpload,
+ }
+
+ // Tell script we want to upload a new file.
+ _, err = fmt.Fprintln(stdin, "upload")
+ if err != nil {
+ return err
+ }
+ // Get path to temporary file for upload.
+ //
+ // Write to the temporary file instead of the final path so
+ // that a concurrent run of this function won't use a
+ // partially written file. The rm in the script could still
+ // cause a missing file but at least no file with unknown
+ // content is executed.
+ path, err := stdout.ReadString('\n')
+ if err != nil {
+ return err
+ }
+ path = strings.TrimSuffix(path, "\n")
+
+ c.debugf("DialSSH: uploading new remote to %q at %q",
+ c.remote, path)
+
+ cmd := exec.Command("ssh", c.remote,
+ fmt.Sprintf("cat > %q", path))
+ cmd.Stdin = bytes.NewReader(helper)
+ err = c.handleStderrAsEvents(cmd)
+ if err != nil {
+ return err
+ }
+ err = cmd.Run()
+ if err != nil {
+ return err
+ }
+ }
+
+ // Tell script to continue and execute the remote helper
+ _, err = fmt.Fprintln(stdin, "")
+ if err != nil {
+ return err
+ }
+
+ return nil
+}
+
+func connGetGoos(stdin io.Writer, stdout *bufio.Reader) (string, error) {
+ _, err := fmt.Fprintln(stdin, "uname -o")
+ if err != nil {
+ return "", err
+ }
+ x, err := stdout.ReadString('\n')
+ if err != nil {
+ return "", err
+ }
+ x = strings.TrimSpace(x)
+
+ // NOTE: Adapt helper uploading in dialSSH() when adding new systems
+ var goos string
+ switch x {
+ case "GNU/Linux":
+ goos = "linux"
+ default:
+ return "", fmt.Errorf("unsupported OS %q (`uname -o`)", x)
+ }
+ return goos, nil
+}
+
+func connGetGoarch(stdin io.Writer, stdout *bufio.Reader) (string, error) {
+ _, err := fmt.Fprintln(stdin, "uname -m")
+ if err != nil {
+ return "", err
+ }
+ x, err := stdout.ReadString('\n')
+ if err != nil {
+ return "", err
+ }
+ x = strings.TrimSpace(x)
+
+ // NOTE: Adapt cmd/safcm-remote/build.sh when adding new architectures
+ var goarch string
+ switch x {
+ case "x86_64":
+ goarch = "amd64"
+ default:
+ return "", fmt.Errorf("unsupported arch %q (`uname -m`)", x)
+ }
+ return goarch, nil
+}
+
+func connGetUID(stdin io.Writer, stdout *bufio.Reader) (int, error) {
+ _, err := fmt.Fprintln(stdin, "id -u")
+ if err != nil {
+ return -1, err
+ }
+ x, err := stdout.ReadString('\n')
+ if err != nil {
+ return -1, err
+ }
+ x = strings.TrimSpace(x)
+
+ uid, err := strconv.Atoi(x)
+ if err != nil {
+ return -1, fmt.Errorf("invalid UID %q (`id -u`)", x)
+ }
+ return uid, nil
+}