]> ruderich.org/simon Gitweb - safcm/safcm.git/blob - rpc/dial.go
36092793e005e06991a64fe440b97ca74fac6a88
[safcm/safcm.git] / rpc / dial.go
1 // Simple RPC-like protocol: establish new connection and upload helper
2
3 // Copyright (C) 2021  Simon Ruderich
4 //
5 // This program is free software: you can redistribute it and/or modify
6 // it under the terms of the GNU General Public License as published by
7 // the Free Software Foundation, either version 3 of the License, or
8 // (at your option) any later version.
9 //
10 // This program is distributed in the hope that it will be useful,
11 // but WITHOUT ANY WARRANTY; without even the implied warranty of
12 // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
13 // GNU General Public License for more details.
14 //
15 // You should have received a copy of the GNU General Public License
16 // along with this program.  If not, see <http://www.gnu.org/licenses/>.
17
18 package rpc
19
20 import (
21         "bufio"
22         "bytes"
23         "crypto/sha512"
24         "encoding/hex"
25         "fmt"
26         "io"
27         "os/exec"
28         "strconv"
29         "strings"
30
31         "ruderich.org/simon/safcm"
32         "ruderich.org/simon/safcm/remote"
33 )
34
35 func (c *Conn) DialSSH(remote string) error {
36         if c.events == nil {
37                 return fmt.Errorf("cannot reuse Conn")
38         }
39
40         c.debugf("DialSSH: connecting to %q", remote)
41
42         opts := "-eu"
43         if c.debug {
44                 // Help debugging by showing executed shell commands
45                 opts += "x"
46         }
47         c.cmd = exec.Command("ssh", remote, "/bin/sh", opts)
48         c.remote = remote
49
50         stdin, err := c.cmd.StdinPipe()
51         if err != nil {
52                 return err
53         }
54         stdout, err := c.cmd.StdoutPipe()
55         if err != nil {
56                 return err
57         }
58         err = c.handleStderrAsEvents(c.cmd)
59         if err != nil {
60                 return err
61         }
62
63         err = c.cmd.Start()
64         if err != nil {
65                 return err
66         }
67
68         err = c.dialSSH(stdin, stdout)
69         if err != nil {
70                 c.Kill()
71                 return err
72         }
73         c.conn = safcm.NewGobConn(stdout, stdin)
74
75         return nil
76 }
77
78 func (c *Conn) dialSSH(stdin io.Writer, stdout_ io.Reader) error {
79         stdout := bufio.NewReader(stdout_)
80
81         goos, err := connGetGoos(stdin, stdout)
82         if err != nil {
83                 return err
84         }
85         goarch, err := connGetGoarch(stdin, stdout)
86         if err != nil {
87                 return err
88         }
89         uid, err := connGetUID(stdin, stdout)
90         if err != nil {
91                 return err
92         }
93
94         path := fmt.Sprintf("/tmp/safcm-remote-%d", uid)
95
96         c.debugf("DialSSH: probing remote at %q", path)
97         // Use a function so the shell cannot execute the input line-wise.
98         // This is important because we're also using stdin to send data to
99         // the script. If the shell executes the input line-wise then our
100         // script is interpreted as input for `read`.
101         //
102         // The target directory must no permit other users to delete our files
103         // or symlink attacks and arbitrary code execution is possible. For
104         // /tmp this is guaranteed by the sticky bit. Make sure it has the
105         // proper permissions.
106         //
107         // We cannot use `test -f && test -O` because this is open to TOCTOU
108         // attacks. `stat` gives use the full file state. If the file is owned
109         // by us and not a symlink then it's safe to use (assuming sticky or
110         // directory not writable by others).
111         //
112         // `test -e` is only used to prevent error messages if the file
113         // doesn't exist. It does not guard against any races.
114         _, err = fmt.Fprintf(stdin, `
115 f() {
116         x=%q
117
118         dir="$(dirname "$x")"
119         if ! test "$(stat -c '%%A %%u %%g' "$dir")" = 'drwxrwxrwt 0 0'; then
120                 echo "unsafe permissions on $dir, aborting" >&2
121                 exit 1
122         fi
123
124         if test -e "$x" && test "$(stat -c '%%A %%u' "$x")" = "-rwx------ $(id -u)"; then
125                 # Report checksum
126                 sha512sum "$x"
127         else
128                 # Empty checksum to request upload
129                 echo
130         fi
131
132         # Wait for signal to continue
133         read upload
134
135         if test -n "$upload"; then
136                 tmp="$(mktemp "$x.XXXXXX")"
137                 # Report filename for upload
138                 echo "$tmp"
139
140                 # Wait for upload to complete
141                 read unused
142
143                 # Safely create new file (ln does not follow symlinks)
144                 rm -f "$x"
145                 ln "$tmp" "$x"
146                 rm "$tmp"
147                 # Make file executable
148                 chmod 0700 "$x"
149         fi
150
151         exec "$x"
152 }
153 f
154 `, path)
155         if err != nil {
156                 return err
157         }
158         remoteSum, err := stdout.ReadString('\n')
159         if err != nil {
160                 return err
161         }
162
163         // Get embedded helper binary
164         helper, err := remote.Helpers.ReadFile(
165                 fmt.Sprintf("helpers/%s-%s", goos, goarch))
166         if err != nil {
167                 return fmt.Errorf("remote not built for GOOS/GOARCH %s/%s",
168                         goos, goarch)
169         }
170
171         var upload bool
172         if remoteSum == "\n" {
173                 upload = true
174                 c.debugf("DialSSH: remote not present or invalid permissions")
175
176         } else {
177                 x := strings.Fields(remoteSum)
178                 if len(x) < 1 {
179                         return fmt.Errorf("got unexpected checksum line %q",
180                                 remoteSum)
181                 }
182                 sha := sha512.Sum512(helper)
183                 hex := hex.EncodeToString(sha[:])
184                 if hex == x[0] {
185                         c.debugf("DialSSH: remote checksum matches")
186                 } else {
187                         upload = true
188                         c.debugf("DialSSH: remote checksum does not match")
189                 }
190         }
191
192         if upload {
193                 // Notify user that an upload is going to take place.
194                 c.events <- ConnEvent{
195                         Type: ConnEventUpload,
196                 }
197
198                 // Tell script we want to upload a new file.
199                 _, err = fmt.Fprintln(stdin, "upload")
200                 if err != nil {
201                         return err
202                 }
203                 // Get path to temporary file for upload.
204                 //
205                 // Write to the temporary file instead of the final path so
206                 // that a concurrent run of this function won't use a
207                 // partially written file. The rm in the script could still
208                 // cause a missing file but at least no file with unknown
209                 // content is executed.
210                 path, err := stdout.ReadString('\n')
211                 if err != nil {
212                         return err
213                 }
214                 path = strings.TrimSuffix(path, "\n")
215
216                 c.debugf("DialSSH: uploading new remote to %q at %q",
217                         c.remote, path)
218
219                 cmd := exec.Command("ssh", c.remote,
220                         fmt.Sprintf("cat > %q", path))
221                 cmd.Stdin = bytes.NewReader(helper)
222                 err = c.handleStderrAsEvents(cmd)
223                 if err != nil {
224                         return err
225                 }
226                 err = cmd.Run()
227                 if err != nil {
228                         return err
229                 }
230         }
231
232         // Tell script to continue and execute the remote helper
233         _, err = fmt.Fprintln(stdin, "")
234         if err != nil {
235                 return err
236         }
237
238         return nil
239 }
240
241 func connGetGoos(stdin io.Writer, stdout *bufio.Reader) (string, error) {
242         _, err := fmt.Fprintln(stdin, "uname -o")
243         if err != nil {
244                 return "", err
245         }
246         x, err := stdout.ReadString('\n')
247         if err != nil {
248                 return "", err
249         }
250         x = strings.TrimSpace(x)
251
252         // NOTE: Adapt helper uploading in dialSSH() when adding new systems
253         var goos string
254         switch x {
255         case "GNU/Linux":
256                 goos = "linux"
257         default:
258                 return "", fmt.Errorf("unsupported OS %q (`uname -o`)", x)
259         }
260         return goos, nil
261 }
262
263 func connGetGoarch(stdin io.Writer, stdout *bufio.Reader) (string, error) {
264         _, err := fmt.Fprintln(stdin, "uname -m")
265         if err != nil {
266                 return "", err
267         }
268         x, err := stdout.ReadString('\n')
269         if err != nil {
270                 return "", err
271         }
272         x = strings.TrimSpace(x)
273
274         // NOTE: Adapt cmd/safcm-remote/build.sh when adding new architectures
275         var goarch string
276         switch x {
277         case "x86_64":
278                 goarch = "amd64"
279         default:
280                 return "", fmt.Errorf("unsupported arch %q (`uname -m`)", x)
281         }
282         return goarch, nil
283 }
284
285 func connGetUID(stdin io.Writer, stdout *bufio.Reader) (int, error) {
286         _, err := fmt.Fprintln(stdin, "id -u")
287         if err != nil {
288                 return -1, err
289         }
290         x, err := stdout.ReadString('\n')
291         if err != nil {
292                 return -1, err
293         }
294         x = strings.TrimSpace(x)
295
296         uid, err := strconv.Atoi(x)
297         if err != nil {
298                 return -1, fmt.Errorf("invalid UID %q (`id -u`)", x)
299         }
300         return uid, nil
301 }