]> ruderich.org/simon Gitweb - safcm/safcm.git/blob - rpc/dial.go
fc702d3518d0c875ba2b9ce1104971e740290fec
[safcm/safcm.git] / rpc / dial.go
1 // Simple RPC-like protocol: establish new connection and upload helper
2
3 // Copyright (C) 2021  Simon Ruderich
4 //
5 // This program is free software: you can redistribute it and/or modify
6 // it under the terms of the GNU General Public License as published by
7 // the Free Software Foundation, either version 3 of the License, or
8 // (at your option) any later version.
9 //
10 // This program is distributed in the hope that it will be useful,
11 // but WITHOUT ANY WARRANTY; without even the implied warranty of
12 // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
13 // GNU General Public License for more details.
14 //
15 // You should have received a copy of the GNU General Public License
16 // along with this program.  If not, see <http://www.gnu.org/licenses/>.
17
18 package rpc
19
20 import (
21         "bufio"
22         "bytes"
23         "crypto/sha512"
24         "encoding/hex"
25         "fmt"
26         "io"
27         "io/fs"
28         "os/exec"
29         "strconv"
30         "strings"
31
32         "ruderich.org/simon/safcm"
33 )
34
35 type SSHConfig struct {
36         Host      string
37         User      string // optional
38         SshConfig string // optional
39
40         RemoteHelpers fs.FS
41 }
42
43 func (c *Conn) DialSSH(cfg SSHConfig) error {
44         if c.events == nil {
45                 return fmt.Errorf("cannot reuse Conn")
46         }
47
48         if cfg.RemoteHelpers == nil {
49                 return fmt.Errorf("SSHConfig.RemoteHelpers not set")
50         }
51         c.remoteHelpers = cfg.RemoteHelpers
52
53         remote := cfg.Host
54         if cfg.User != "" {
55                 remote = cfg.User + "@" + cfg.Host
56         }
57         c.debugf("DialSSH: connecting to %q", remote)
58
59         opts := "-eu"
60         if c.debug {
61                 // Help debugging by showing executed shell commands
62                 opts += "x"
63         }
64
65         c.sshRemote = remote
66         if cfg.SshConfig != "" {
67                 c.sshOpts = []string{"-F", cfg.SshConfig}
68         }
69         c.cmd = exec.Command("ssh",
70                 append(append([]string{}, c.sshOpts...),
71                         c.sshRemote, "/bin/sh", opts)...)
72
73         stdin, err := c.cmd.StdinPipe()
74         if err != nil {
75                 return err
76         }
77         stdout, err := c.cmd.StdoutPipe()
78         if err != nil {
79                 return err
80         }
81         err = c.handleStderrAsEvents(c.cmd)
82         if err != nil {
83                 return err
84         }
85
86         err = c.cmd.Start()
87         if err != nil {
88                 return err
89         }
90
91         err = c.dialSSH(stdin, stdout)
92         if err != nil {
93                 c.Kill() //nolint:errcheck
94                 return err
95         }
96         c.conn = safcm.NewGobConn(stdout, stdin)
97
98         return nil
99 }
100
101 func (c *Conn) dialSSH(stdin io.Writer, stdout_ io.Reader) error {
102         stdout := bufio.NewReader(stdout_)
103
104         goos, err := connGetGoos(stdin, stdout)
105         if err != nil {
106                 return err
107         }
108         goarch, err := connGetGoarch(stdin, stdout)
109         if err != nil {
110                 return err
111         }
112         uid, err := connGetUID(stdin, stdout)
113         if err != nil {
114                 return err
115         }
116
117         path := fmt.Sprintf("/tmp/safcm-remote-%d", uid)
118
119         c.debugf("DialSSH: probing remote at %q", path)
120
121         // Compatibility for different operating systems
122         var compat string
123         switch goos {
124         case "linux":
125                 compat = `
126 dir_stat='drwxrwxrwt 0 0'
127 file_stat="-rwx------ $(id -u) $(id -g)"
128 compat_stat() {
129         stat -c '%A %u %g' "$1"
130 }
131 compat_sha512sum() {
132         sha512sum "$1"
133 }
134 `
135         case "freebsd", "openbsd":
136                 compat = `
137 dir_stat='41777 0 0'
138 file_stat="100700 $(id -u) $(id -g)"
139 compat_stat() {
140         stat -f '%p %u %g' "$1"
141 }
142 compat_sha512sum() {
143         sha512 -q "$1"
144 }
145 `
146         default:
147                 return fmt.Errorf("internal error: no support for %q", goos)
148         }
149
150         // Use a function so the shell cannot execute the input line-wise.
151         // This is important because we're also using stdin to send data to
152         // the script. If the shell executes the input line-wise then our
153         // script is interpreted as input for `read`.
154         //
155         // The target directory must no permit other users to delete our files
156         // or symlink attacks and arbitrary code execution is possible. For
157         // /tmp this is guaranteed by the sticky bit. The code verifies the
158         // directory has the proper permissions.
159         //
160         // We cannot use `test -f && test -O` because this is open to TOCTOU
161         // attacks. `stat` gives use the full file state. If the file is owned
162         // by us and not a symlink then it's safe to use (assuming sticky
163         // directory or directory not writable by others).
164         //
165         // `test -e` is only used to prevent error messages if the file
166         // doesn't exist. It does not guard against any races.
167         _, err = fmt.Fprintf(stdin, `
168 %s
169 f() {
170         x=%q
171
172         dir="$(dirname "$x")"
173         if ! test "$(compat_stat "$dir")" = "$dir_stat"; then
174                 echo "unsafe permissions on $dir, aborting" >&2
175                 exit 1
176         fi
177
178         if test -e "$x" && test "$(compat_stat "$x")" = "$file_stat"; then
179                 # Report checksum
180                 compat_sha512sum "$x"
181         else
182                 # Empty checksum to request upload
183                 echo
184         fi
185
186         # Wait for signal to continue
187         read upload
188
189         if test -n "$upload"; then
190                 tmp="$(mktemp "$x.XXXXXX")"
191                 # Report filename for upload
192                 echo "$tmp"
193                 # Wait for upload to complete
194                 read unused
195
196                 # Safely create new file (ln does not follow symlinks)
197                 rm -f "$x"
198                 ln "$tmp" "$x"
199                 rm "$tmp"
200                 # Make file executable
201                 chmod 0700 "$x"
202                 # Some BSD create files with group wheel in /tmp
203                 chgrp "$(id -g)" "$x"
204         fi
205
206         exec "$x" sync
207 }
208 f
209 `, compat, path)
210         if err != nil {
211                 return err
212         }
213         remoteSum, err := stdout.ReadString('\n')
214         if err != nil {
215                 return err
216         }
217
218         // Get remote helper binary
219         helper, err := fs.ReadFile(c.remoteHelpers,
220                 fmt.Sprintf("%s-%s", goos, goarch))
221         if err != nil {
222                 return fmt.Errorf("remote not built for GOOS/GOARCH %s/%s",
223                         goos, goarch)
224         }
225
226         var upload bool
227         if remoteSum == "\n" {
228                 upload = true
229                 c.debugf("DialSSH: remote not present or invalid permissions")
230
231         } else {
232                 x := strings.Fields(remoteSum)
233                 if len(x) < 1 {
234                         return fmt.Errorf("got unexpected checksum line %q",
235                                 remoteSum)
236                 }
237                 sha := sha512.Sum512(helper)
238                 hex := hex.EncodeToString(sha[:])
239                 if hex == x[0] {
240                         c.debugf("DialSSH: remote checksum matches")
241                 } else {
242                         upload = true
243                         c.debugf("DialSSH: remote checksum does not match")
244                 }
245         }
246
247         if upload {
248                 // Notify user that an upload is going to take place.
249                 c.events <- ConnEvent{
250                         Type: ConnEventUpload,
251                 }
252
253                 // Tell script we want to upload a new file.
254                 _, err = fmt.Fprintln(stdin, "upload")
255                 if err != nil {
256                         return err
257                 }
258                 // Get path to temporary file for upload.
259                 //
260                 // Write to the temporary file instead of the final path so
261                 // that a concurrent run of this function won't use a
262                 // partially written file. The rm in the script could still
263                 // cause a missing file but at least no file with unknown
264                 // content is executed.
265                 path, err := stdout.ReadString('\n')
266                 if err != nil {
267                         return err
268                 }
269                 path = strings.TrimSuffix(path, "\n")
270
271                 c.debugf("DialSSH: uploading new remote to %q at %q",
272                         c.sshRemote, path)
273
274                 cmd := exec.Command("ssh",
275                         append(append([]string{}, c.sshOpts...),
276                                 c.sshRemote,
277                                 fmt.Sprintf("cat > %q", path))...)
278                 cmd.Stdin = bytes.NewReader(helper)
279                 err = c.handleStderrAsEvents(cmd) // cmd.Stderr
280                 if err != nil {
281                         return err
282                 }
283                 err = cmd.Run()
284                 if err != nil {
285                         return err
286                 }
287         }
288
289         // Tell script to continue and execute the remote helper
290         _, err = fmt.Fprintln(stdin, "")
291         if err != nil {
292                 return err
293         }
294
295         return nil
296 }
297
298 func connGetGoos(stdin io.Writer, stdout *bufio.Reader) (string, error) {
299         _, err := fmt.Fprintln(stdin, "uname")
300         if err != nil {
301                 return "", err
302         }
303         x, err := stdout.ReadString('\n')
304         if err != nil {
305                 return "", err
306         }
307         x = strings.TrimSpace(x)
308
309         // NOTE: Adapt helper uploading in dialSSH() when adding new systems
310         var goos string
311         switch x {
312         case "Linux":
313                 goos = "linux"
314         case "FreeBSD":
315                 goos = "freebsd"
316         case "OpenBSD":
317                 goos = "openbsd"
318         default:
319                 return "", fmt.Errorf("unsupported OS %q (`uname`)", x)
320         }
321         return goos, nil
322 }
323
324 func connGetGoarch(stdin io.Writer, stdout *bufio.Reader) (string, error) {
325         _, err := fmt.Fprintln(stdin, "uname -m")
326         if err != nil {
327                 return "", err
328         }
329         x, err := stdout.ReadString('\n')
330         if err != nil {
331                 return "", err
332         }
333         x = strings.TrimSpace(x)
334
335         // NOTE: Adapt cmd/safcm-remote/build.sh when adding new architectures
336         var goarch string
337         switch x {
338         case "x86_64", "amd64":
339                 goarch = "amd64"
340         case "armv7l":
341                 goarch = "armv7l"
342         default:
343                 return "", fmt.Errorf("unsupported arch %q (`uname -m`)", x)
344         }
345         return goarch, nil
346 }
347
348 func connGetUID(stdin io.Writer, stdout *bufio.Reader) (int, error) {
349         _, err := fmt.Fprintln(stdin, "id -u")
350         if err != nil {
351                 return -1, err
352         }
353         x, err := stdout.ReadString('\n')
354         if err != nil {
355                 return -1, err
356         }
357         x = strings.TrimSpace(x)
358
359         uid, err := strconv.Atoi(x)
360         if err != nil {
361                 return -1, fmt.Errorf("invalid UID %q (`id -u`)", x)
362         }
363         return uid, nil
364 }