]> ruderich.org/simon Gitweb - safcm/safcm.git/blob - cmd/safcm/sync.go
Improve and add comments
[safcm/safcm.git] / cmd / safcm / sync.go
1 // "sync" sub-command: sync data to remote hosts
2
3 // Copyright (C) 2021  Simon Ruderich
4 //
5 // This program is free software: you can redistribute it and/or modify
6 // it under the terms of the GNU General Public License as published by
7 // the Free Software Foundation, either version 3 of the License, or
8 // (at your option) any later version.
9 //
10 // This program is distributed in the hope that it will be useful,
11 // but WITHOUT ANY WARRANTY; without even the implied warranty of
12 // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
13 // GNU General Public License for more details.
14 //
15 // You should have received a copy of the GNU General Public License
16 // along with this program.  If not, see <http://www.gnu.org/licenses/>.
17
18 package main
19
20 import (
21         "flag"
22         "fmt"
23         "log"
24         "os"
25         "os/signal"
26         "sort"
27         "strings"
28         "sync"
29
30         "golang.org/x/term"
31
32         "ruderich.org/simon/safcm"
33         "ruderich.org/simon/safcm/cmd/safcm/config"
34         "ruderich.org/simon/safcm/rpc"
35 )
36
37 type Sync struct {
38         host *config.Host
39
40         config    *config.Config      // global configuration
41         allHosts  *config.Hosts       // known hosts
42         allGroups map[string][]string // known groups
43
44         events chan<- Event // all events generated by/for this host
45
46         isTTY bool
47 }
48
49 type Event struct {
50         Host *config.Host
51
52         // Only one of Error, Log and ConnEvent is set in a single event
53         Error     error
54         Log       Log
55         ConnEvent rpc.ConnEvent
56
57         Escaped bool // true if untrusted input is already escaped
58 }
59
60 type Log struct {
61         Level safcm.LogLevel
62         Text  string
63 }
64
65 func MainSync(args []string) error {
66         flag.Usage = func() {
67                 fmt.Fprintf(os.Stderr,
68                         "usage: %s sync [<options>] <host|group...>\n",
69                         args[0])
70                 flag.PrintDefaults()
71         }
72
73         optionDryRun := flag.Bool("n", false,
74                 "dry-run, show diff but don't perform any changes")
75         optionQuiet := flag.Bool("q", false,
76                 "hide successful, non-trigger commands with no output from host changes listing")
77         optionLog := flag.String("log", "info", "set log `level`; "+
78                 "levels: error, info, verbose, debug, debug2, debug3")
79         optionSshConfig := flag.String("sshconfig", "",
80                 "`path` to ssh configuration file; used for tests")
81
82         flag.CommandLine.Parse(args[2:])
83
84         var level safcm.LogLevel
85         switch *optionLog {
86         case "error":
87                 level = safcm.LogError
88         case "info":
89                 level = safcm.LogInfo
90         case "verbose":
91                 level = safcm.LogVerbose
92         case "debug":
93                 level = safcm.LogDebug
94         case "debug2":
95                 level = safcm.LogDebug2
96         case "debug3":
97                 level = safcm.LogDebug3
98         default:
99                 return fmt.Errorf("invalid -log value %q", *optionLog)
100         }
101
102         names := flag.Args()
103         if len(names) == 0 {
104                 flag.Usage()
105                 os.Exit(1)
106         }
107
108         cfg, allHosts, allGroups, err := LoadBaseFiles()
109         if err != nil {
110                 return err
111         }
112         cfg.DryRun = *optionDryRun
113         cfg.Quiet = *optionQuiet
114         cfg.LogLevel = level
115         cfg.SshConfig = *optionSshConfig
116
117         toSync, err := hostsToSync(names, allHosts, allGroups)
118         if err != nil {
119                 return err
120         }
121         if len(toSync) == 0 {
122                 return fmt.Errorf("no hosts found")
123         }
124
125         isTTY := term.IsTerminal(int(os.Stdout.Fd()))
126
127         done := make(chan bool)
128         // Collect events from all hosts and print them
129         events := make(chan Event)
130         go func() {
131                 var failed bool
132                 for {
133                         x := <-events
134                         if x.Host == nil {
135                                 break
136                         }
137                         logEvent(x, cfg.LogLevel, isTTY, &failed)
138                 }
139                 done <- failed
140         }()
141
142         hostsLeft := make(map[string]bool)
143         for _, x := range toSync {
144                 hostsLeft[x.Name] = true
145         }
146         var hostsLeftMutex sync.Mutex // protects hostsLeft
147
148         // Show unfinished hosts on Ctrl-C
149         sigint := make(chan os.Signal, 1)   // buffered for Notify()
150         signal.Notify(sigint, os.Interrupt) // = SIGINT = Ctrl-C
151         go func() {
152                 // Running `ssh` processes get killed by SIGINT which is sent
153                 // to all processes
154
155                 <-sigint
156                 log.Print("Received SIGINT, aborting ...")
157
158                 // Print all queued events
159                 events <- Event{} // poison pill
160                 <-done
161                 // "races" with <-done in the main function and will hang here
162                 // if the other is faster. This is fine because then all hosts
163                 // were synced successfully.
164
165                 hostsLeftMutex.Lock()
166                 var hosts []string
167                 for x := range hostsLeft {
168                         hosts = append(hosts, x)
169                 }
170                 sort.Strings(hosts)
171                 log.Fatalf("Failed to sync %s", strings.Join(hosts, ", "))
172         }()
173
174         // Sync all hosts concurrently
175         var wg sync.WaitGroup
176         for _, x := range toSync {
177                 x := x
178
179                 // Once in sync.Host() and once in the go func below
180                 wg.Add(2)
181
182                 go func() {
183                         sync := Sync{
184                                 host:      x,
185                                 config:    cfg,
186                                 allHosts:  allHosts,
187                                 allGroups: allGroups,
188                                 events:    events,
189                                 isTTY:     isTTY,
190                         }
191                         err := sync.Host(&wg)
192                         if err != nil {
193                                 events <- Event{
194                                         Host:  x,
195                                         Error: err,
196                                 }
197                         }
198                         wg.Done()
199
200                         hostsLeftMutex.Lock()
201                         defer hostsLeftMutex.Unlock()
202                         delete(hostsLeft, x.Name)
203                 }()
204         }
205
206         wg.Wait()
207         events <- Event{} // poison pill
208         failed := <-done
209
210         if failed {
211                 // Exit instead of returning an error to prevent an extra log
212                 // message from main()
213                 os.Exit(1)
214         }
215         return nil
216 }
217
218 // hostsToSync returns the list of hosts to sync based on the command line
219 // arguments.
220 //
221 // Full host and group matches are required to prevent unexpected behavior. No
222 // arguments does not expand to all hosts to prevent accidents; "all" can be
223 // used instead. Both host and group names are permitted as these are unique.
224 //
225 // TODO: Add option to permit partial/glob matches
226 func hostsToSync(names []string, allHosts *config.Hosts,
227         allGroups map[string][]string) ([]*config.Host, error) {
228
229         detectedMap := make(map[string]bool)
230         for _, x := range config.TransitivelyDetectedGroups(allGroups) {
231                 detectedMap[x] = true
232         }
233
234         const detectedErr = `
235
236 Groups depending on "detected" groups cannot be used to select hosts as these
237 are only available after the hosts were contacted.
238 `
239
240         nameMap := make(map[string]bool)
241         for _, x := range names {
242                 if detectedMap[x] {
243                         return nil, fmt.Errorf(
244                                 "group %q depends on \"detected\" groups%s",
245                                 x, detectedErr)
246                 }
247                 nameMap[x] = true
248         }
249         nameMatched := make(map[string]bool)
250         // To detect typos we must check all given names but one host can be
251         // matched by multiple names (e.g. two groups with overlapping hosts)
252         hostMatched := make(map[string]bool)
253
254         var res []*config.Host
255         for _, host := range allHosts.List {
256                 if nameMap[host.Name] {
257                         res = append(res, host)
258                         hostMatched[host.Name] = true
259                         nameMatched[host.Name] = true
260                 }
261
262                 groups, err := config.ResolveHostGroups(host.Name,
263                         allGroups, nil)
264                 if err != nil {
265                         return nil, err
266                 }
267                 for _, x := range groups {
268                         if nameMap[x] {
269                                 if !hostMatched[host.Name] {
270                                         res = append(res, host)
271                                         hostMatched[host.Name] = true
272                                 }
273                                 nameMatched[x] = true
274                         }
275                 }
276         }
277
278         // Warn about unmatched names to detect typos
279         if len(nameMap) != len(nameMatched) {
280                 var unmatched []string
281                 for x := range nameMap {
282                         if !nameMatched[x] {
283                                 unmatched = append(unmatched,
284                                         fmt.Sprintf("%q", x))
285                         }
286                 }
287                 sort.Strings(unmatched)
288                 return nil, fmt.Errorf("hosts/groups not found: %s",
289                         strings.Join(unmatched, " "))
290         }
291
292         return res, nil
293 }
294
295 func logEvent(x Event, level safcm.LogLevel, isTTY bool, failed *bool) {
296         // We have multiple event sources so this is somewhat ugly.
297         var prefix, data string
298         var color Color
299         if x.Error != nil {
300                 prefix = "[error]"
301                 data = x.Error.Error()
302                 color = ColorRed
303                 // We logged an error, tell the caller
304                 *failed = true
305         } else if x.Log.Level != 0 {
306                 // LogError and LogDebug3 should not occur here
307                 switch x.Log.Level {
308                 case safcm.LogInfo:
309                         prefix = "[info]"
310                 case safcm.LogVerbose:
311                         prefix = "[verbose]"
312                 case safcm.LogDebug:
313                         prefix = "[debug]"
314                 case safcm.LogDebug2:
315                         prefix = "[debug2]"
316                 default:
317                         prefix = fmt.Sprintf("[INVALID=%d]", x.Log.Level)
318                         color = ColorRed
319                 }
320                 data = x.Log.Text
321         } else {
322                 switch x.ConnEvent.Type {
323                 case rpc.ConnEventStderr:
324                         prefix = "[stderr]"
325                 case rpc.ConnEventDebug:
326                         prefix = "[debug3]"
327                 case rpc.ConnEventUpload:
328                         if level < safcm.LogInfo {
329                                 return
330                         }
331                         prefix = "[info]"
332                         x.ConnEvent.Data = "remote helper upload in progress"
333                 default:
334                         prefix = fmt.Sprintf("[INVALID=%d]", x.ConnEvent.Type)
335                         color = ColorRed
336                 }
337                 data = x.ConnEvent.Data
338         }
339
340         host := x.Host.Name
341         if color != 0 {
342                 host = ColorString(isTTY, color, host)
343         }
344         // Make sure to escape control characters to prevent terminal
345         // injection attacks
346         if !x.Escaped {
347                 data = EscapeControlCharacters(isTTY, data)
348         }
349         log.Printf("%-9s [%s] %s", prefix, host, data)
350 }
351
352 func (s *Sync) Host(wg *sync.WaitGroup) error {
353         conn := rpc.NewConn(s.config.LogLevel >= safcm.LogDebug3)
354         // Pass all connection events to main loop
355         go func() {
356                 for {
357                         x, ok := <-conn.Events
358                         if !ok {
359                                 break
360                         }
361                         s.events <- Event{
362                                 Host:      s.host,
363                                 ConnEvent: x,
364                         }
365                 }
366                 wg.Done()
367         }()
368
369         // Connect to remote host
370         err := conn.DialSSH(s.host.SshUser, s.host.Name, s.config.SshConfig)
371         if err != nil {
372                 return err
373         }
374         defer conn.Kill()
375
376         // Collect information about remote host
377         detectedGroups, err := s.hostInfo(conn)
378         if err != nil {
379                 return err
380         }
381
382         // Sync state to remote host
383         err = s.hostSync(conn, detectedGroups)
384         if err != nil {
385                 return err
386         }
387
388         // Terminate connection to remote host
389         err = conn.Send(safcm.MsgQuitReq{})
390         if err != nil {
391                 return err
392         }
393         _, err = conn.Recv()
394         if err != nil {
395                 return err
396         }
397         err = conn.Wait()
398         if err != nil {
399                 return err
400         }
401
402         return nil
403 }
404
405 func (s *Sync) logf(level safcm.LogLevel, escaped bool,
406         format string, a ...interface{}) {
407
408         if s.config.LogLevel < level {
409                 return
410         }
411         s.events <- Event{
412                 Host: s.host,
413                 Log: Log{
414                         Level: level,
415                         Text:  fmt.Sprintf(format, a...),
416                 },
417                 Escaped: escaped,
418         }
419 }
420 func (s *Sync) logDebugf(format string, a ...interface{}) {
421         s.logf(safcm.LogDebug, false, format, a...)
422 }
423 func (s *Sync) logVerbosef(format string, a ...interface{}) {
424         s.logf(safcm.LogVerbose, false, format, a...)
425 }
426
427 // sendRecv sends a message over conn and waits for the response. Any MsgLog
428 // messages received before the final (non MsgLog) response are passed to
429 // s.log.
430 func (s *Sync) sendRecv(conn *rpc.Conn, msg safcm.Msg) (safcm.Msg, error) {
431         err := conn.Send(msg)
432         if err != nil {
433                 return nil, err
434         }
435         for {
436                 x, err := conn.Recv()
437                 if err != nil {
438                         return nil, err
439                 }
440                 log, ok := x.(safcm.MsgLog)
441                 if ok {
442                         s.logf(log.Level, false, "%s", log.Text)
443                         continue
444                 }
445                 return x, nil
446         }
447 }