]> ruderich.org/simon Gitweb - safcm/safcm.git/blob - cmd/safcm/sync.go
7fa2c936b637c7059e5429b273f310233ee92ed4
[safcm/safcm.git] / cmd / safcm / sync.go
1 // "sync" sub-command: sync data to remote hosts
2
3 // Copyright (C) 2021  Simon Ruderich
4 //
5 // This program is free software: you can redistribute it and/or modify
6 // it under the terms of the GNU General Public License as published by
7 // the Free Software Foundation, either version 3 of the License, or
8 // (at your option) any later version.
9 //
10 // This program is distributed in the hope that it will be useful,
11 // but WITHOUT ANY WARRANTY; without even the implied warranty of
12 // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
13 // GNU General Public License for more details.
14 //
15 // You should have received a copy of the GNU General Public License
16 // along with this program.  If not, see <http://www.gnu.org/licenses/>.
17
18 package main
19
20 import (
21         "flag"
22         "fmt"
23         "log"
24         "os"
25         "sort"
26         "strings"
27         "sync"
28
29         "golang.org/x/term"
30
31         "ruderich.org/simon/safcm"
32         "ruderich.org/simon/safcm/cmd/safcm/config"
33         "ruderich.org/simon/safcm/rpc"
34 )
35
36 type Sync struct {
37         host *config.Host
38
39         config    *config.Config      // global configuration
40         allHosts  *config.Hosts       // known hosts
41         allGroups map[string][]string // known groups
42
43         events chan<- Event // all events generated by/for this host
44
45         isTTY bool
46 }
47
48 type Event struct {
49         Host *config.Host
50
51         // Only one of Error, Log and ConnEvent is set in a single event
52         Error     error
53         Log       Log
54         ConnEvent rpc.ConnEvent
55
56         Escaped bool // true if untrusted input is already escaped
57 }
58
59 type Log struct {
60         Level safcm.LogLevel
61         Text  string
62 }
63
64 func MainSync(args []string) error {
65         flag.Usage = func() {
66                 fmt.Fprintf(os.Stderr,
67                         "usage: %s sync [<options>] <host|group...>\n",
68                         args[0])
69                 flag.PrintDefaults()
70         }
71
72         optionDryRun := flag.Bool("n", false,
73                 "dry-run, show diff but don't perform any changes")
74         optionQuiet := flag.Bool("q", false,
75                 "hide successful, non-trigger commands with no output from host changes listing")
76         optionLog := flag.String("log", "info", "set log `level`; "+
77                 "levels: error, info, verbose, debug, debug2, debug3")
78
79         flag.CommandLine.Parse(args[2:])
80
81         var level safcm.LogLevel
82         switch *optionLog {
83         case "error":
84                 level = safcm.LogError
85         case "info":
86                 level = safcm.LogInfo
87         case "verbose":
88                 level = safcm.LogVerbose
89         case "debug":
90                 level = safcm.LogDebug
91         case "debug2":
92                 level = safcm.LogDebug2
93         case "debug3":
94                 level = safcm.LogDebug3
95         default:
96                 return fmt.Errorf("invalid -log value %q", *optionLog)
97         }
98
99         names := flag.Args()
100         if len(names) == 0 {
101                 flag.Usage()
102                 os.Exit(1)
103         }
104
105         cfg, allHosts, allGroups, err := LoadBaseFiles()
106         if err != nil {
107                 return err
108         }
109         cfg.DryRun = *optionDryRun
110         cfg.Quiet = *optionQuiet
111         cfg.LogLevel = level
112
113         toSync, err := hostsToSync(names, allHosts, allGroups)
114         if err != nil {
115                 return err
116         }
117         if len(toSync) == 0 {
118                 return fmt.Errorf("no hosts found")
119         }
120
121         isTTY := term.IsTerminal(int(os.Stdout.Fd()))
122
123         done := make(chan bool)
124         // Collect events from all hosts and print them
125         events := make(chan Event)
126         go func() {
127                 var failed bool
128                 for {
129                         x := <-events
130                         if x.Host == nil {
131                                 break
132                         }
133                         logEvent(x, cfg.LogLevel, isTTY, &failed)
134                 }
135                 done <- failed
136         }()
137
138         // Sync all hosts concurrently
139         var wg sync.WaitGroup
140         for _, x := range toSync {
141                 x := x
142
143                 // Once in sync.Host() and once in the go func below
144                 wg.Add(2)
145
146                 go func() {
147                         sync := Sync{
148                                 host:      x,
149                                 config:    cfg,
150                                 allHosts:  allHosts,
151                                 allGroups: allGroups,
152                                 events:    events,
153                                 isTTY:     isTTY,
154                         }
155                         err := sync.Host(&wg)
156                         if err != nil {
157                                 events <- Event{
158                                         Host:  x,
159                                         Error: err,
160                                 }
161                         }
162                         wg.Done()
163                 }()
164         }
165
166         wg.Wait()
167         events <- Event{} // poison pill
168         failed := <-done
169
170         if failed {
171                 // Exit instead of returning an error to prevent an extra log
172                 // message from main()
173                 os.Exit(1)
174         }
175         return nil
176 }
177
178 // hostsToSync returns the list of hosts to sync based on the command line
179 // arguments.
180 //
181 // Full host and group matches are required to prevent unexpected behavior. No
182 // arguments does not expand to all hosts to prevent accidents; "all" can be
183 // used instead. Both host and group names are permitted as these are unique.
184 //
185 // TODO: Add option to permit partial/glob matches
186 func hostsToSync(names []string, allHosts *config.Hosts,
187         allGroups map[string][]string) ([]*config.Host, error) {
188
189         detectedMap := make(map[string]bool)
190         for _, x := range config.TransitivelyDetectedGroups(allGroups) {
191                 detectedMap[x] = true
192         }
193
194         const detectedErr = `
195
196 Groups depending on "detected" groups cannot be used to select hosts as these
197 are only available after the hosts were contacted.
198 `
199
200         nameMap := make(map[string]bool)
201         for _, x := range names {
202                 if detectedMap[x] {
203                         return nil, fmt.Errorf(
204                                 "group %q depends on \"detected\" groups%s",
205                                 x, detectedErr)
206                 }
207                 nameMap[x] = true
208         }
209         nameMatched := make(map[string]bool)
210         // To detect typos we must check all given names but only want to add
211         // each match once
212         hostMatched := make(map[string]bool)
213
214         var res []*config.Host
215         for _, host := range allHosts.List {
216                 if nameMap[host.Name] {
217                         res = append(res, host)
218                         hostMatched[host.Name] = true
219                         nameMatched[host.Name] = true
220                 }
221
222                 groups, err := config.ResolveHostGroups(host.Name,
223                         allGroups, nil)
224                 if err != nil {
225                         return nil, err
226                 }
227                 for _, x := range groups {
228                         if nameMap[x] {
229                                 if !hostMatched[host.Name] {
230                                         res = append(res, host)
231                                         hostMatched[host.Name] = true
232                                 }
233                                 nameMatched[x] = true
234                         }
235                 }
236         }
237
238         // Warn about unmatched names to detect typos
239         if len(nameMap) != len(nameMatched) {
240                 var unmatched []string
241                 for x := range nameMap {
242                         if !nameMatched[x] {
243                                 unmatched = append(unmatched,
244                                         fmt.Sprintf("%q", x))
245                         }
246                 }
247                 sort.Strings(unmatched)
248                 return nil, fmt.Errorf("hosts/groups not found: %s",
249                         strings.Join(unmatched, " "))
250         }
251
252         return res, nil
253 }
254
255 func logEvent(x Event, level safcm.LogLevel, isTTY bool, failed *bool) {
256         // We have multiple event sources so this is somewhat ugly.
257         var prefix, data string
258         var color Color
259         if x.Error != nil {
260                 prefix = "[error]"
261                 data = x.Error.Error()
262                 color = ColorRed
263                 // We logged an error, tell the caller
264                 *failed = true
265         } else if x.Log.Level != 0 {
266                 // LogError and LogDebug3 should not occur here
267                 switch x.Log.Level {
268                 case safcm.LogInfo:
269                         prefix = "[info]"
270                 case safcm.LogVerbose:
271                         prefix = "[verbose]"
272                 case safcm.LogDebug:
273                         prefix = "[debug]"
274                 case safcm.LogDebug2:
275                         prefix = "[debug2]"
276                 default:
277                         prefix = fmt.Sprintf("[INVALID=%d]", x.Log.Level)
278                         color = ColorRed
279                 }
280                 data = x.Log.Text
281         } else {
282                 switch x.ConnEvent.Type {
283                 case rpc.ConnEventStderr:
284                         prefix = "[stderr]"
285                 case rpc.ConnEventDebug:
286                         prefix = "[debug3]"
287                 case rpc.ConnEventUpload:
288                         if level < safcm.LogInfo {
289                                 return
290                         }
291                         prefix = "[info]"
292                         x.ConnEvent.Data = "remote helper upload in progress"
293                 default:
294                         prefix = fmt.Sprintf("[INVALID=%d]", x.ConnEvent.Type)
295                         color = ColorRed
296                 }
297                 data = x.ConnEvent.Data
298         }
299
300         host := x.Host.Name
301         if color != 0 {
302                 host = ColorString(isTTY, color, host)
303         }
304         // Make sure to escape control characters to prevent terminal
305         // injection attacks
306         if !x.Escaped {
307                 data = EscapeControlCharacters(isTTY, data)
308         }
309         log.Printf("%-9s [%s] %s", prefix, host, data)
310 }
311
312 func (s *Sync) Host(wg *sync.WaitGroup) error {
313         conn := rpc.NewConn(s.config.LogLevel >= safcm.LogDebug3)
314         // Pass all connection events to main loop
315         go func() {
316                 for {
317                         x, ok := <-conn.Events
318                         if !ok {
319                                 break
320                         }
321                         s.events <- Event{
322                                 Host:      s.host,
323                                 ConnEvent: x,
324                         }
325                 }
326                 wg.Done()
327         }()
328
329         // Connect to remote host
330         err := conn.DialSSH(s.host.SshUser, s.host.Name)
331         if err != nil {
332                 return err
333         }
334         defer conn.Kill()
335
336         // Collect information about remote host
337         detectedGroups, err := s.hostInfo(conn)
338         if err != nil {
339                 return err
340         }
341
342         // Sync state to remote host
343         err = s.hostSync(conn, detectedGroups)
344         if err != nil {
345                 return err
346         }
347
348         // Terminate connection to remote host
349         err = conn.Send(safcm.MsgQuitReq{})
350         if err != nil {
351                 return err
352         }
353         _, err = conn.Recv()
354         if err != nil {
355                 return err
356         }
357         err = conn.Wait()
358         if err != nil {
359                 return err
360         }
361
362         return nil
363 }
364
365 func (s *Sync) logf(level safcm.LogLevel, escaped bool,
366         format string, a ...interface{}) {
367
368         if s.config.LogLevel < level {
369                 return
370         }
371         s.events <- Event{
372                 Host: s.host,
373                 Log: Log{
374                         Level: level,
375                         Text:  fmt.Sprintf(format, a...),
376                 },
377                 Escaped: escaped,
378         }
379 }
380 func (s *Sync) logDebugf(format string, a ...interface{}) {
381         s.logf(safcm.LogDebug, false, format, a...)
382 }
383 func (s *Sync) logVerbosef(format string, a ...interface{}) {
384         s.logf(safcm.LogVerbose, false, format, a...)
385 }
386
387 // sendRecv sends a message over conn and waits for the response. Any MsgLog
388 // messages received before the final (non MsgLog) response are passed to
389 // s.log.
390 func (s *Sync) sendRecv(conn *rpc.Conn, msg safcm.Msg) (safcm.Msg, error) {
391         err := conn.Send(msg)
392         if err != nil {
393                 return nil, err
394         }
395         for {
396                 x, err := conn.Recv()
397                 if err != nil {
398                         return nil, err
399                 }
400                 log, ok := x.(safcm.MsgLog)
401                 if ok {
402                         s.logf(log.Level, false, "%s", log.Text)
403                         continue
404                 }
405                 return x, nil
406         }
407 }