]> ruderich.org/simon Gitweb - safcm/safcm.git/blob - cmd/safcm/sync.go
config: add Host option "ssh_user"
[safcm/safcm.git] / cmd / safcm / sync.go
1 // "sync" sub-command: sync data to remote hosts
2
3 // Copyright (C) 2021  Simon Ruderich
4 //
5 // This program is free software: you can redistribute it and/or modify
6 // it under the terms of the GNU General Public License as published by
7 // the Free Software Foundation, either version 3 of the License, or
8 // (at your option) any later version.
9 //
10 // This program is distributed in the hope that it will be useful,
11 // but WITHOUT ANY WARRANTY; without even the implied warranty of
12 // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
13 // GNU General Public License for more details.
14 //
15 // You should have received a copy of the GNU General Public License
16 // along with this program.  If not, see <http://www.gnu.org/licenses/>.
17
18 package main
19
20 import (
21         "flag"
22         "fmt"
23         "log"
24         "os"
25         "sort"
26         "strings"
27         "sync"
28
29         "golang.org/x/term"
30
31         "ruderich.org/simon/safcm"
32         "ruderich.org/simon/safcm/cmd/safcm/config"
33         "ruderich.org/simon/safcm/rpc"
34 )
35
36 type Sync struct {
37         host *config.Host
38
39         config    *config.Config      // global configuration
40         allHosts  *config.Hosts       // known hosts
41         allGroups map[string][]string // known groups
42
43         events chan<- Event // all events generated by/for this host
44
45         isTTY bool
46 }
47
48 type Event struct {
49         Host *config.Host
50
51         // Only one of Error, Log and ConnEvent is set in a single event
52         Error     error
53         Log       Log
54         ConnEvent rpc.ConnEvent
55
56         Escaped bool // true if untrusted input is already escaped
57 }
58
59 type Log struct {
60         Level safcm.LogLevel
61         Text  string
62 }
63
64 func MainSync(args []string) error {
65         flag.Usage = func() {
66                 fmt.Fprintf(os.Stderr,
67                         "usage: %s sync [<options>] <host|group...>\n",
68                         args[0])
69                 flag.PrintDefaults()
70         }
71
72         optionDryRun := flag.Bool("n", false,
73                 "dry-run, show diff but don't perform any changes")
74         optionLog := flag.String("log", "info", "set log `level`; "+
75                 "levels: error, info, verbose, debug, debug2, debug3")
76
77         flag.CommandLine.Parse(args[2:])
78
79         var level safcm.LogLevel
80         switch *optionLog {
81         case "error":
82                 level = safcm.LogError
83         case "info":
84                 level = safcm.LogInfo
85         case "verbose":
86                 level = safcm.LogVerbose
87         case "debug":
88                 level = safcm.LogDebug
89         case "debug2":
90                 level = safcm.LogDebug2
91         case "debug3":
92                 level = safcm.LogDebug3
93         default:
94                 return fmt.Errorf("invalid -log value %q", *optionLog)
95         }
96
97         names := flag.Args()
98         if len(names) == 0 {
99                 flag.Usage()
100                 os.Exit(1)
101         }
102
103         cfg, allHosts, allGroups, err := LoadBaseFiles()
104         if err != nil {
105                 return err
106         }
107         cfg.DryRun = *optionDryRun
108         cfg.LogLevel = level
109
110         toSync, err := hostsToSync(names, allHosts, allGroups)
111         if err != nil {
112                 return err
113         }
114         if len(toSync) == 0 {
115                 return fmt.Errorf("no hosts found")
116         }
117
118         isTTY := term.IsTerminal(int(os.Stdout.Fd()))
119
120         done := make(chan bool)
121         // Collect events from all hosts and print them
122         events := make(chan Event)
123         go func() {
124                 var failed bool
125                 for {
126                         x := <-events
127                         if x.Host == nil {
128                                 break
129                         }
130                         logEvent(x, cfg.LogLevel, isTTY, &failed)
131                 }
132                 done <- failed
133         }()
134
135         // Sync all hosts concurrently
136         var wg sync.WaitGroup
137         for _, x := range toSync {
138                 x := x
139
140                 // Once in sync.Host() and once in the go func below
141                 wg.Add(2)
142
143                 go func() {
144                         sync := Sync{
145                                 host:      x,
146                                 config:    cfg,
147                                 allHosts:  allHosts,
148                                 allGroups: allGroups,
149                                 events:    events,
150                                 isTTY:     isTTY,
151                         }
152                         err := sync.Host(&wg)
153                         if err != nil {
154                                 events <- Event{
155                                         Host:  x,
156                                         Error: err,
157                                 }
158                         }
159                         wg.Done()
160                 }()
161         }
162
163         wg.Wait()
164         events <- Event{} // poison pill
165         failed := <-done
166
167         if failed {
168                 // Exit instead of returning an error to prevent an extra log
169                 // message from main()
170                 os.Exit(1)
171         }
172         return nil
173 }
174
175 // hostsToSync returns the list of hosts to sync based on the command line
176 // arguments.
177 //
178 // Full host and group matches are required to prevent unexpected behavior. No
179 // arguments does not expand to all hosts to prevent accidents; "all" can be
180 // used instead. Both host and group names are permitted as these are unique.
181 //
182 // TODO: Add option to permit partial/glob matches
183 func hostsToSync(names []string, allHosts *config.Hosts,
184         allGroups map[string][]string) ([]*config.Host, error) {
185
186         nameMap := make(map[string]bool)
187         for _, x := range names {
188                 nameMap[x] = true
189         }
190         nameMatched := make(map[string]bool)
191         // To detect typos we must check all given names but only want to add
192         // each match once
193         hostMatched := make(map[string]bool)
194
195         var res []*config.Host
196         for _, host := range allHosts.List {
197                 if nameMap[host.Name] {
198                         res = append(res, host)
199                         hostMatched[host.Name] = true
200                         nameMatched[host.Name] = true
201                 }
202
203                 // TODO: don't permit groups which contain "detected" groups
204                 // because these are not available yet
205                 groups, err := config.ResolveHostGroups(host.Name,
206                         allGroups, nil)
207                 if err != nil {
208                         return nil, err
209                 }
210                 for _, x := range groups {
211                         if nameMap[x] {
212                                 if !hostMatched[host.Name] {
213                                         res = append(res, host)
214                                         hostMatched[host.Name] = true
215                                 }
216                                 nameMatched[x] = true
217                         }
218                 }
219         }
220
221         // Warn about unmatched names to detect typos
222         if len(nameMap) != len(nameMatched) {
223                 var unmatched []string
224                 for x := range nameMap {
225                         if !nameMatched[x] {
226                                 unmatched = append(unmatched,
227                                         fmt.Sprintf("%q", x))
228                         }
229                 }
230                 sort.Strings(unmatched)
231                 return nil, fmt.Errorf("hosts/groups not found: %s",
232                         strings.Join(unmatched, " "))
233         }
234
235         return res, nil
236 }
237
238 func logEvent(x Event, level safcm.LogLevel, isTTY bool, failed *bool) {
239         // We have multiple event sources so this is somewhat ugly.
240         var prefix, data string
241         var color Color
242         if x.Error != nil {
243                 prefix = "[error]"
244                 data = x.Error.Error()
245                 color = ColorRed
246                 // We logged an error, tell the caller
247                 *failed = true
248         } else if x.Log.Level != 0 {
249                 // LogError and LogDebug3 should not occur here
250                 switch x.Log.Level {
251                 case safcm.LogInfo:
252                         prefix = "[info]"
253                 case safcm.LogVerbose:
254                         prefix = "[verbose]"
255                 case safcm.LogDebug:
256                         prefix = "[debug]"
257                 case safcm.LogDebug2:
258                         prefix = "[debug2]"
259                 default:
260                         prefix = fmt.Sprintf("[INVALID=%d]", x.Log.Level)
261                         color = ColorRed
262                 }
263                 data = x.Log.Text
264         } else {
265                 switch x.ConnEvent.Type {
266                 case rpc.ConnEventStderr:
267                         prefix = "[stderr]"
268                 case rpc.ConnEventDebug:
269                         prefix = "[debug3]"
270                 case rpc.ConnEventUpload:
271                         if level < safcm.LogInfo {
272                                 return
273                         }
274                         prefix = "[info]"
275                         x.ConnEvent.Data = "remote helper upload in progress"
276                 default:
277                         prefix = fmt.Sprintf("[INVALID=%d]", x.ConnEvent.Type)
278                         color = ColorRed
279                 }
280                 data = x.ConnEvent.Data
281         }
282
283         host := x.Host.Name
284         if color != 0 {
285                 host = ColorString(isTTY, color, host)
286         }
287         // Make sure to escape control characters to prevent terminal
288         // injection attacks
289         if !x.Escaped {
290                 data = EscapeControlCharacters(isTTY, data)
291         }
292         log.Printf("%-9s [%s] %s", prefix, host, data)
293 }
294
295 func (s *Sync) Host(wg *sync.WaitGroup) error {
296         conn := rpc.NewConn(s.config.LogLevel >= safcm.LogDebug3)
297         // Pass all connection events to main loop
298         go func() {
299                 for {
300                         x, ok := <-conn.Events
301                         if !ok {
302                                 break
303                         }
304                         s.events <- Event{
305                                 Host:      s.host,
306                                 ConnEvent: x,
307                         }
308                 }
309                 wg.Done()
310         }()
311
312         // Connect to remote host
313         err := conn.DialSSH(s.host.SshUser, s.host.Name)
314         if err != nil {
315                 return err
316         }
317         defer conn.Kill()
318
319         // Collect information about remote host
320         detectedGroups, err := s.hostInfo(conn)
321         if err != nil {
322                 return err
323         }
324
325         // Sync state to remote host
326         err = s.hostSync(conn, detectedGroups)
327         if err != nil {
328                 return err
329         }
330
331         // Terminate connection to remote host
332         err = conn.Send(safcm.MsgQuitReq{})
333         if err != nil {
334                 return err
335         }
336         _, err = conn.Recv()
337         if err != nil {
338                 return err
339         }
340         err = conn.Wait()
341         if err != nil {
342                 return err
343         }
344
345         return nil
346 }
347
348 func (s *Sync) logf(level safcm.LogLevel, escaped bool,
349         format string, a ...interface{}) {
350
351         if s.config.LogLevel < level {
352                 return
353         }
354         s.events <- Event{
355                 Host: s.host,
356                 Log: Log{
357                         Level: level,
358                         Text:  fmt.Sprintf(format, a...),
359                 },
360                 Escaped: escaped,
361         }
362 }
363 func (s *Sync) logDebugf(format string, a ...interface{}) {
364         s.logf(safcm.LogDebug, false, format, a...)
365 }
366 func (s *Sync) logVerbosef(format string, a ...interface{}) {
367         s.logf(safcm.LogVerbose, false, format, a...)
368 }
369
370 // sendRecv sends a message over conn and waits for the response. Any MsgLog
371 // messages received before the final (non MsgLog) response are passed to
372 // s.log.
373 func (s *Sync) sendRecv(conn *rpc.Conn, msg safcm.Msg) (safcm.Msg, error) {
374         err := conn.Send(msg)
375         if err != nil {
376                 return nil, err
377         }
378         for {
379                 x, err := conn.Recv()
380                 if err != nil {
381                         return nil, err
382                 }
383                 log, ok := x.(safcm.MsgLog)
384                 if ok {
385                         s.logf(log.Level, false, "%s", log.Text)
386                         continue
387                 }
388                 return x, nil
389         }
390 }