]> ruderich.org/simon Gitweb - safcm/safcm.git/blob - cmd/safcm/sync.go
Move embedded remote helpers to cmd/safcm/
[safcm/safcm.git] / cmd / safcm / sync.go
1 // "sync" sub-command: sync data to remote hosts
2
3 // Copyright (C) 2021  Simon Ruderich
4 //
5 // This program is free software: you can redistribute it and/or modify
6 // it under the terms of the GNU General Public License as published by
7 // the Free Software Foundation, either version 3 of the License, or
8 // (at your option) any later version.
9 //
10 // This program is distributed in the hope that it will be useful,
11 // but WITHOUT ANY WARRANTY; without even the implied warranty of
12 // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
13 // GNU General Public License for more details.
14 //
15 // You should have received a copy of the GNU General Public License
16 // along with this program.  If not, see <http://www.gnu.org/licenses/>.
17
18 package main
19
20 import (
21         "flag"
22         "fmt"
23         "io/fs"
24         "log"
25         "os"
26         "os/signal"
27         "runtime"
28         "sort"
29         "strings"
30         "sync"
31
32         "golang.org/x/term"
33
34         "ruderich.org/simon/safcm"
35         "ruderich.org/simon/safcm/cmd/safcm/config"
36         "ruderich.org/simon/safcm/rpc"
37 )
38
39 type Sync struct {
40         host *config.Host
41
42         config    *config.Config      // global configuration
43         allHosts  *config.Hosts       // known hosts
44         allGroups map[string][]string // known groups
45
46         events chan<- Event // all events generated by/for this host
47
48         isTTY bool
49 }
50
51 type Event struct {
52         Host *config.Host
53
54         // Only one of Error, Log and ConnEvent is set in a single event
55         Error     error
56         Log       Log
57         ConnEvent rpc.ConnEvent
58
59         Escaped bool // true if untrusted input is already escaped
60 }
61
62 type Log struct {
63         Level safcm.LogLevel
64         Text  string
65 }
66
67 func MainSync(args []string) error {
68         flag.Usage = func() {
69                 fmt.Fprintf(os.Stderr,
70                         "usage: %s sync [<options>] <host|group...>\n",
71                         args[0])
72                 flag.PrintDefaults()
73         }
74
75         optionDryRun := flag.Bool("n", false,
76                 "dry-run, show diff but don't perform any changes")
77         optionQuiet := flag.Bool("q", false,
78                 "hide successful, non-trigger commands with no output from host changes listing")
79         optionLog := flag.String("log", "info", "set log `level`; "+
80                 "levels: error, info, verbose, debug, debug2, debug3")
81         optionSshConfig := flag.String("sshconfig", "",
82                 "`path` to ssh configuration file; used for tests")
83
84         flag.CommandLine.Parse(args[2:])
85
86         var level safcm.LogLevel
87         switch *optionLog {
88         case "error":
89                 level = safcm.LogError
90         case "info":
91                 level = safcm.LogInfo
92         case "verbose":
93                 level = safcm.LogVerbose
94         case "debug":
95                 level = safcm.LogDebug
96         case "debug2":
97                 level = safcm.LogDebug2
98         case "debug3":
99                 level = safcm.LogDebug3
100         default:
101                 return fmt.Errorf("invalid -log value %q", *optionLog)
102         }
103
104         names := flag.Args()
105         if len(names) == 0 {
106                 flag.Usage()
107                 os.Exit(1)
108         }
109
110         if runtime.GOOS == "windows" {
111                 log.Print("WARNING: Windows support is experimental!")
112         }
113
114         cfg, allHosts, allGroups, err := LoadBaseFiles()
115         if err != nil {
116                 return err
117         }
118         cfg.DryRun = *optionDryRun
119         cfg.Quiet = *optionQuiet
120         cfg.LogLevel = level
121         cfg.SshConfig = *optionSshConfig
122
123         toSync, err := hostsToSync(names, allHosts, allGroups)
124         if err != nil {
125                 return err
126         }
127         if len(toSync) == 0 {
128                 return fmt.Errorf("no hosts found")
129         }
130
131         isTTY := term.IsTerminal(int(os.Stdout.Fd())) &&
132                 term.IsTerminal(int(os.Stderr.Fd()))
133
134         done := make(chan bool)
135         // Collect events from all hosts and print them
136         events := make(chan Event)
137         go func() {
138                 var failed bool
139                 for {
140                         x := <-events
141                         if x.Host == nil {
142                                 break
143                         }
144                         logEvent(x, cfg.LogLevel, isTTY, &failed)
145                 }
146                 done <- failed
147         }()
148
149         hostsLeft := make(map[string]bool)
150         for _, x := range toSync {
151                 hostsLeft[x.Name] = true
152         }
153         var hostsLeftMutex sync.Mutex // protects hostsLeft
154
155         // Show unfinished hosts on Ctrl-C
156         sigint := make(chan os.Signal, 1)   // buffered for Notify()
157         signal.Notify(sigint, os.Interrupt) // = SIGINT = Ctrl-C
158         go func() {
159                 // Running `ssh` processes get killed by SIGINT which is sent
160                 // to all processes
161
162                 <-sigint
163                 log.Print("Received SIGINT, aborting ...")
164
165                 // Print all queued events
166                 events <- Event{} // poison pill
167                 <-done
168                 // "races" with <-done in the main function and will hang here
169                 // if the other is faster. This is fine because then all hosts
170                 // were synced successfully.
171
172                 hostsLeftMutex.Lock()
173                 var hosts []string
174                 for x := range hostsLeft {
175                         hosts = append(hosts, x)
176                 }
177                 sort.Strings(hosts)
178                 log.Fatalf("Failed to sync %s", strings.Join(hosts, ", "))
179         }()
180
181         // Sync all hosts concurrently
182         var wg sync.WaitGroup
183         for _, x := range toSync {
184                 x := x
185
186                 // Once in sync.Host() and once in the go func below
187                 wg.Add(2)
188
189                 go func() {
190                         sync := Sync{
191                                 host:      x,
192                                 config:    cfg,
193                                 allHosts:  allHosts,
194                                 allGroups: allGroups,
195                                 events:    events,
196                                 isTTY:     isTTY,
197                         }
198                         err := sync.Host(&wg)
199                         if err != nil {
200                                 events <- Event{
201                                         Host:  x,
202                                         Error: err,
203                                 }
204                         }
205                         wg.Done()
206
207                         hostsLeftMutex.Lock()
208                         defer hostsLeftMutex.Unlock()
209                         delete(hostsLeft, x.Name)
210                 }()
211         }
212
213         wg.Wait()
214         events <- Event{} // poison pill
215         failed := <-done
216
217         if failed {
218                 // Exit instead of returning an error to prevent an extra log
219                 // message from main()
220                 os.Exit(1)
221         }
222         return nil
223 }
224
225 // hostsToSync returns the list of hosts to sync based on the command line
226 // arguments.
227 //
228 // Full host and group matches are required to prevent unexpected behavior. No
229 // arguments does not expand to all hosts to prevent accidents; "all" can be
230 // used instead. Both host and group names are permitted as these are unique.
231 //
232 // TODO: Add option to permit partial/glob matches
233 func hostsToSync(names []string, allHosts *config.Hosts,
234         allGroups map[string][]string) ([]*config.Host, error) {
235
236         detectedMap := config.TransitivelyDetectedGroups(allGroups)
237
238         const detectedErr = `
239
240 Groups depending on "detected" groups cannot be used to select hosts as these
241 are only available after the hosts were contacted.
242 `
243
244         nameMap := make(map[string]bool)
245         for _, x := range names {
246                 if detectedMap[x] {
247                         return nil, fmt.Errorf(
248                                 "group %q depends on \"detected\" groups%s",
249                                 x, detectedErr)
250                 }
251                 nameMap[x] = true
252         }
253         nameMatched := make(map[string]bool)
254         // To detect typos we must check all given names but one host can be
255         // matched by multiple names (e.g. two groups with overlapping hosts)
256         hostAdded := make(map[string]bool)
257
258         var res []*config.Host
259         for _, host := range allHosts.List {
260                 if nameMap[host.Name] {
261                         res = append(res, host)
262                         hostAdded[host.Name] = true
263                         nameMatched[host.Name] = true
264                 }
265
266                 groups, err := config.ResolveHostGroups(host.Name,
267                         allGroups, nil)
268                 if err != nil {
269                         return nil, err
270                 }
271                 for _, x := range groups {
272                         if nameMap[x] {
273                                 if !hostAdded[host.Name] {
274                                         res = append(res, host)
275                                         hostAdded[host.Name] = true
276                                 }
277                                 nameMatched[x] = true
278                         }
279                 }
280         }
281
282         // Warn about unmatched names to detect typos
283         if len(nameMap) != len(nameMatched) {
284                 var unmatched []string
285                 for x := range nameMap {
286                         if !nameMatched[x] {
287                                 unmatched = append(unmatched,
288                                         fmt.Sprintf("%q", x))
289                         }
290                 }
291                 sort.Strings(unmatched)
292                 return nil, fmt.Errorf("hosts/groups not found: %s",
293                         strings.Join(unmatched, " "))
294         }
295
296         return res, nil
297 }
298
299 func logEvent(x Event, level safcm.LogLevel, isTTY bool, failed *bool) {
300         // We have multiple event sources so this is somewhat ugly.
301         var prefix, data string
302         var color Color
303         if x.Error != nil {
304                 prefix = "[error]"
305                 data = x.Error.Error()
306                 color = ColorRed
307                 // We logged an error, tell the caller
308                 *failed = true
309         } else if x.Log.Level != 0 {
310                 // LogError and LogDebug3 should not occur here
311                 switch x.Log.Level {
312                 case safcm.LogInfo:
313                         prefix = "[info]"
314                 case safcm.LogVerbose:
315                         prefix = "[verbose]"
316                 case safcm.LogDebug:
317                         prefix = "[debug]"
318                 case safcm.LogDebug2:
319                         prefix = "[debug2]"
320                 default:
321                         prefix = fmt.Sprintf("[INVALID=%d]", x.Log.Level)
322                         color = ColorRed
323                 }
324                 data = x.Log.Text
325         } else {
326                 switch x.ConnEvent.Type {
327                 case rpc.ConnEventStderr:
328                         prefix = "[stderr]"
329                 case rpc.ConnEventDebug:
330                         prefix = "[debug3]"
331                 case rpc.ConnEventUpload:
332                         if level < safcm.LogInfo {
333                                 return
334                         }
335                         prefix = "[info]"
336                         x.ConnEvent.Data = "remote helper upload in progress"
337                 default:
338                         prefix = fmt.Sprintf("[INVALID=%d]", x.ConnEvent.Type)
339                         color = ColorRed
340                 }
341                 data = x.ConnEvent.Data
342         }
343
344         host := x.Host.Name
345         if color != 0 {
346                 host = ColorString(isTTY, color, host)
347         }
348         // Make sure to escape control characters to prevent terminal
349         // injection attacks
350         if !x.Escaped {
351                 data = EscapeControlCharacters(isTTY, data)
352         }
353         log.Printf("%-9s [%s] %s", prefix, host, data)
354 }
355
356 func (s *Sync) Host(wg *sync.WaitGroup) error {
357         conn := rpc.NewConn(s.config.LogLevel >= safcm.LogDebug3)
358         // Pass all connection events to main loop
359         go func() {
360                 for {
361                         x, ok := <-conn.Events
362                         if !ok {
363                                 break
364                         }
365                         s.events <- Event{
366                                 Host:      s.host,
367                                 ConnEvent: x,
368                         }
369                 }
370                 wg.Done()
371         }()
372
373         helpers, err := fs.Sub(RemoteHelpers, "remote")
374         if err != nil {
375                 conn.Kill()
376                 return err
377         }
378
379         // Connect to remote host
380         user := s.host.SshUser
381         if user == "" {
382                 user = s.config.SshUser
383         }
384         err = conn.DialSSH(rpc.SSHConfig{
385                 Host:          s.host.Name,
386                 User:          user,
387                 SshConfig:     s.config.SshConfig,
388                 RemoteHelpers: helpers,
389         })
390         if err != nil {
391                 conn.Kill()
392                 return err
393         }
394         defer conn.Kill()
395
396         // Collect information about remote host
397         detectedGroups, err := s.hostInfo(conn)
398         if err != nil {
399                 return err
400         }
401
402         // Sync state to remote host
403         err = s.hostSync(conn, detectedGroups)
404         if err != nil {
405                 return err
406         }
407
408         // Terminate connection to remote host
409         err = conn.Send(safcm.MsgQuitReq{})
410         if err != nil {
411                 return err
412         }
413         _, err = conn.Recv()
414         if err != nil {
415                 return err
416         }
417         err = conn.Wait()
418         if err != nil {
419                 return err
420         }
421
422         return nil
423 }
424
425 func (s *Sync) log(level safcm.LogLevel, escaped bool, msg string) {
426         if s.config.LogLevel < level {
427                 return
428         }
429         s.events <- Event{
430                 Host: s.host,
431                 Log: Log{
432                         Level: level,
433                         Text:  msg,
434                 },
435                 Escaped: escaped,
436         }
437 }
438 func (s *Sync) logDebugf(format string, a ...interface{}) {
439         s.log(safcm.LogDebug, false, fmt.Sprintf(format, a...))
440 }
441 func (s *Sync) logVerbosef(format string, a ...interface{}) {
442         s.log(safcm.LogVerbose, false, fmt.Sprintf(format, a...))
443 }
444
445 // sendRecv sends a message over conn and waits for the response. Any MsgLog
446 // messages received before the final (non MsgLog) response are passed to
447 // s.log.
448 func (s *Sync) sendRecv(conn *rpc.Conn, msg safcm.Msg) (safcm.Msg, error) {
449         err := conn.Send(msg)
450         if err != nil {
451                 return nil, err
452         }
453         for {
454                 x, err := conn.Recv()
455                 if err != nil {
456                         return nil, err
457                 }
458                 log, ok := x.(safcm.MsgLog)
459                 if ok {
460                         s.log(log.Level, false, log.Text)
461                         continue
462                 }
463                 return x, nil
464         }
465 }