]> ruderich.org/simon Gitweb - safcm/safcm.git/blob - cmd/safcm/sync.go
safcm: add experimental support to sync from Windows hosts
[safcm/safcm.git] / cmd / safcm / sync.go
1 // "sync" sub-command: sync data to remote hosts
2
3 // Copyright (C) 2021  Simon Ruderich
4 //
5 // This program is free software: you can redistribute it and/or modify
6 // it under the terms of the GNU General Public License as published by
7 // the Free Software Foundation, either version 3 of the License, or
8 // (at your option) any later version.
9 //
10 // This program is distributed in the hope that it will be useful,
11 // but WITHOUT ANY WARRANTY; without even the implied warranty of
12 // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
13 // GNU General Public License for more details.
14 //
15 // You should have received a copy of the GNU General Public License
16 // along with this program.  If not, see <http://www.gnu.org/licenses/>.
17
18 package main
19
20 import (
21         "flag"
22         "fmt"
23         "log"
24         "os"
25         "os/signal"
26         "runtime"
27         "sort"
28         "strings"
29         "sync"
30
31         "golang.org/x/term"
32
33         "ruderich.org/simon/safcm"
34         "ruderich.org/simon/safcm/cmd/safcm/config"
35         "ruderich.org/simon/safcm/rpc"
36 )
37
38 type Sync struct {
39         host *config.Host
40
41         config    *config.Config      // global configuration
42         allHosts  *config.Hosts       // known hosts
43         allGroups map[string][]string // known groups
44
45         events chan<- Event // all events generated by/for this host
46
47         isTTY bool
48 }
49
50 type Event struct {
51         Host *config.Host
52
53         // Only one of Error, Log and ConnEvent is set in a single event
54         Error     error
55         Log       Log
56         ConnEvent rpc.ConnEvent
57
58         Escaped bool // true if untrusted input is already escaped
59 }
60
61 type Log struct {
62         Level safcm.LogLevel
63         Text  string
64 }
65
66 func MainSync(args []string) error {
67         flag.Usage = func() {
68                 fmt.Fprintf(os.Stderr,
69                         "usage: %s sync [<options>] <host|group...>\n",
70                         args[0])
71                 flag.PrintDefaults()
72         }
73
74         optionDryRun := flag.Bool("n", false,
75                 "dry-run, show diff but don't perform any changes")
76         optionQuiet := flag.Bool("q", false,
77                 "hide successful, non-trigger commands with no output from host changes listing")
78         optionLog := flag.String("log", "info", "set log `level`; "+
79                 "levels: error, info, verbose, debug, debug2, debug3")
80         optionSshConfig := flag.String("sshconfig", "",
81                 "`path` to ssh configuration file; used for tests")
82
83         flag.CommandLine.Parse(args[2:])
84
85         var level safcm.LogLevel
86         switch *optionLog {
87         case "error":
88                 level = safcm.LogError
89         case "info":
90                 level = safcm.LogInfo
91         case "verbose":
92                 level = safcm.LogVerbose
93         case "debug":
94                 level = safcm.LogDebug
95         case "debug2":
96                 level = safcm.LogDebug2
97         case "debug3":
98                 level = safcm.LogDebug3
99         default:
100                 return fmt.Errorf("invalid -log value %q", *optionLog)
101         }
102
103         names := flag.Args()
104         if len(names) == 0 {
105                 flag.Usage()
106                 os.Exit(1)
107         }
108
109         if runtime.GOOS == "windows" {
110                 log.Print("WARNING: Windows support is experimental!")
111         }
112
113         cfg, allHosts, allGroups, err := LoadBaseFiles()
114         if err != nil {
115                 return err
116         }
117         cfg.DryRun = *optionDryRun
118         cfg.Quiet = *optionQuiet
119         cfg.LogLevel = level
120         cfg.SshConfig = *optionSshConfig
121
122         toSync, err := hostsToSync(names, allHosts, allGroups)
123         if err != nil {
124                 return err
125         }
126         if len(toSync) == 0 {
127                 return fmt.Errorf("no hosts found")
128         }
129
130         isTTY := term.IsTerminal(int(os.Stdout.Fd())) &&
131                 term.IsTerminal(int(os.Stderr.Fd()))
132
133         done := make(chan bool)
134         // Collect events from all hosts and print them
135         events := make(chan Event)
136         go func() {
137                 var failed bool
138                 for {
139                         x := <-events
140                         if x.Host == nil {
141                                 break
142                         }
143                         logEvent(x, cfg.LogLevel, isTTY, &failed)
144                 }
145                 done <- failed
146         }()
147
148         hostsLeft := make(map[string]bool)
149         for _, x := range toSync {
150                 hostsLeft[x.Name] = true
151         }
152         var hostsLeftMutex sync.Mutex // protects hostsLeft
153
154         // Show unfinished hosts on Ctrl-C
155         sigint := make(chan os.Signal, 1)   // buffered for Notify()
156         signal.Notify(sigint, os.Interrupt) // = SIGINT = Ctrl-C
157         go func() {
158                 // Running `ssh` processes get killed by SIGINT which is sent
159                 // to all processes
160
161                 <-sigint
162                 log.Print("Received SIGINT, aborting ...")
163
164                 // Print all queued events
165                 events <- Event{} // poison pill
166                 <-done
167                 // "races" with <-done in the main function and will hang here
168                 // if the other is faster. This is fine because then all hosts
169                 // were synced successfully.
170
171                 hostsLeftMutex.Lock()
172                 var hosts []string
173                 for x := range hostsLeft {
174                         hosts = append(hosts, x)
175                 }
176                 sort.Strings(hosts)
177                 log.Fatalf("Failed to sync %s", strings.Join(hosts, ", "))
178         }()
179
180         // Sync all hosts concurrently
181         var wg sync.WaitGroup
182         for _, x := range toSync {
183                 x := x
184
185                 // Once in sync.Host() and once in the go func below
186                 wg.Add(2)
187
188                 go func() {
189                         sync := Sync{
190                                 host:      x,
191                                 config:    cfg,
192                                 allHosts:  allHosts,
193                                 allGroups: allGroups,
194                                 events:    events,
195                                 isTTY:     isTTY,
196                         }
197                         err := sync.Host(&wg)
198                         if err != nil {
199                                 events <- Event{
200                                         Host:  x,
201                                         Error: err,
202                                 }
203                         }
204                         wg.Done()
205
206                         hostsLeftMutex.Lock()
207                         defer hostsLeftMutex.Unlock()
208                         delete(hostsLeft, x.Name)
209                 }()
210         }
211
212         wg.Wait()
213         events <- Event{} // poison pill
214         failed := <-done
215
216         if failed {
217                 // Exit instead of returning an error to prevent an extra log
218                 // message from main()
219                 os.Exit(1)
220         }
221         return nil
222 }
223
224 // hostsToSync returns the list of hosts to sync based on the command line
225 // arguments.
226 //
227 // Full host and group matches are required to prevent unexpected behavior. No
228 // arguments does not expand to all hosts to prevent accidents; "all" can be
229 // used instead. Both host and group names are permitted as these are unique.
230 //
231 // TODO: Add option to permit partial/glob matches
232 func hostsToSync(names []string, allHosts *config.Hosts,
233         allGroups map[string][]string) ([]*config.Host, error) {
234
235         detectedMap := config.TransitivelyDetectedGroups(allGroups)
236
237         const detectedErr = `
238
239 Groups depending on "detected" groups cannot be used to select hosts as these
240 are only available after the hosts were contacted.
241 `
242
243         nameMap := make(map[string]bool)
244         for _, x := range names {
245                 if detectedMap[x] {
246                         return nil, fmt.Errorf(
247                                 "group %q depends on \"detected\" groups%s",
248                                 x, detectedErr)
249                 }
250                 nameMap[x] = true
251         }
252         nameMatched := make(map[string]bool)
253         // To detect typos we must check all given names but one host can be
254         // matched by multiple names (e.g. two groups with overlapping hosts)
255         hostAdded := make(map[string]bool)
256
257         var res []*config.Host
258         for _, host := range allHosts.List {
259                 if nameMap[host.Name] {
260                         res = append(res, host)
261                         hostAdded[host.Name] = true
262                         nameMatched[host.Name] = true
263                 }
264
265                 groups, err := config.ResolveHostGroups(host.Name,
266                         allGroups, nil)
267                 if err != nil {
268                         return nil, err
269                 }
270                 for _, x := range groups {
271                         if nameMap[x] {
272                                 if !hostAdded[host.Name] {
273                                         res = append(res, host)
274                                         hostAdded[host.Name] = true
275                                 }
276                                 nameMatched[x] = true
277                         }
278                 }
279         }
280
281         // Warn about unmatched names to detect typos
282         if len(nameMap) != len(nameMatched) {
283                 var unmatched []string
284                 for x := range nameMap {
285                         if !nameMatched[x] {
286                                 unmatched = append(unmatched,
287                                         fmt.Sprintf("%q", x))
288                         }
289                 }
290                 sort.Strings(unmatched)
291                 return nil, fmt.Errorf("hosts/groups not found: %s",
292                         strings.Join(unmatched, " "))
293         }
294
295         return res, nil
296 }
297
298 func logEvent(x Event, level safcm.LogLevel, isTTY bool, failed *bool) {
299         // We have multiple event sources so this is somewhat ugly.
300         var prefix, data string
301         var color Color
302         if x.Error != nil {
303                 prefix = "[error]"
304                 data = x.Error.Error()
305                 color = ColorRed
306                 // We logged an error, tell the caller
307                 *failed = true
308         } else if x.Log.Level != 0 {
309                 // LogError and LogDebug3 should not occur here
310                 switch x.Log.Level {
311                 case safcm.LogInfo:
312                         prefix = "[info]"
313                 case safcm.LogVerbose:
314                         prefix = "[verbose]"
315                 case safcm.LogDebug:
316                         prefix = "[debug]"
317                 case safcm.LogDebug2:
318                         prefix = "[debug2]"
319                 default:
320                         prefix = fmt.Sprintf("[INVALID=%d]", x.Log.Level)
321                         color = ColorRed
322                 }
323                 data = x.Log.Text
324         } else {
325                 switch x.ConnEvent.Type {
326                 case rpc.ConnEventStderr:
327                         prefix = "[stderr]"
328                 case rpc.ConnEventDebug:
329                         prefix = "[debug3]"
330                 case rpc.ConnEventUpload:
331                         if level < safcm.LogInfo {
332                                 return
333                         }
334                         prefix = "[info]"
335                         x.ConnEvent.Data = "remote helper upload in progress"
336                 default:
337                         prefix = fmt.Sprintf("[INVALID=%d]", x.ConnEvent.Type)
338                         color = ColorRed
339                 }
340                 data = x.ConnEvent.Data
341         }
342
343         host := x.Host.Name
344         if color != 0 {
345                 host = ColorString(isTTY, color, host)
346         }
347         // Make sure to escape control characters to prevent terminal
348         // injection attacks
349         if !x.Escaped {
350                 data = EscapeControlCharacters(isTTY, data)
351         }
352         log.Printf("%-9s [%s] %s", prefix, host, data)
353 }
354
355 func (s *Sync) Host(wg *sync.WaitGroup) error {
356         conn := rpc.NewConn(s.config.LogLevel >= safcm.LogDebug3)
357         // Pass all connection events to main loop
358         go func() {
359                 for {
360                         x, ok := <-conn.Events
361                         if !ok {
362                                 break
363                         }
364                         s.events <- Event{
365                                 Host:      s.host,
366                                 ConnEvent: x,
367                         }
368                 }
369                 wg.Done()
370         }()
371
372         // Connect to remote host
373         err := conn.DialSSH(s.host.SshUser, s.host.Name, s.config.SshConfig)
374         if err != nil {
375                 return err
376         }
377         defer conn.Kill()
378
379         // Collect information about remote host
380         detectedGroups, err := s.hostInfo(conn)
381         if err != nil {
382                 return err
383         }
384
385         // Sync state to remote host
386         err = s.hostSync(conn, detectedGroups)
387         if err != nil {
388                 return err
389         }
390
391         // Terminate connection to remote host
392         err = conn.Send(safcm.MsgQuitReq{})
393         if err != nil {
394                 return err
395         }
396         _, err = conn.Recv()
397         if err != nil {
398                 return err
399         }
400         err = conn.Wait()
401         if err != nil {
402                 return err
403         }
404
405         return nil
406 }
407
408 func (s *Sync) log(level safcm.LogLevel, escaped bool, msg string) {
409         if s.config.LogLevel < level {
410                 return
411         }
412         s.events <- Event{
413                 Host: s.host,
414                 Log: Log{
415                         Level: level,
416                         Text:  msg,
417                 },
418                 Escaped: escaped,
419         }
420 }
421 func (s *Sync) logDebugf(format string, a ...interface{}) {
422         s.log(safcm.LogDebug, false, fmt.Sprintf(format, a...))
423 }
424 func (s *Sync) logVerbosef(format string, a ...interface{}) {
425         s.log(safcm.LogVerbose, false, fmt.Sprintf(format, a...))
426 }
427
428 // sendRecv sends a message over conn and waits for the response. Any MsgLog
429 // messages received before the final (non MsgLog) response are passed to
430 // s.log.
431 func (s *Sync) sendRecv(conn *rpc.Conn, msg safcm.Msg) (safcm.Msg, error) {
432         err := conn.Send(msg)
433         if err != nil {
434                 return nil, err
435         }
436         for {
437                 x, err := conn.Recv()
438                 if err != nil {
439                         return nil, err
440                 }
441                 log, ok := x.(safcm.MsgLog)
442                 if ok {
443                         s.log(log.Level, false, log.Text)
444                         continue
445                 }
446                 return x, nil
447         }
448 }