]> ruderich.org/simon Gitweb - safcm/safcm.git/blob - cmd/safcm/sync.go
go fmt
[safcm/safcm.git] / cmd / safcm / sync.go
1 // "sync" sub-command: sync data to remote hosts
2
3 // Copyright (C) 2021  Simon Ruderich
4 //
5 // This program is free software: you can redistribute it and/or modify
6 // it under the terms of the GNU General Public License as published by
7 // the Free Software Foundation, either version 3 of the License, or
8 // (at your option) any later version.
9 //
10 // This program is distributed in the hope that it will be useful,
11 // but WITHOUT ANY WARRANTY; without even the implied warranty of
12 // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
13 // GNU General Public License for more details.
14 //
15 // You should have received a copy of the GNU General Public License
16 // along with this program.  If not, see <http://www.gnu.org/licenses/>.
17
18 package main
19
20 import (
21         "flag"
22         "fmt"
23         "io/fs"
24         "log"
25         "os"
26         "runtime"
27         "sort"
28         "strings"
29
30         "golang.org/x/term"
31
32         "ruderich.org/simon/safcm"
33         "ruderich.org/simon/safcm/cmd/safcm/config"
34         "ruderich.org/simon/safcm/frontend"
35         "ruderich.org/simon/safcm/rpc"
36 )
37
38 type Sync struct {
39         host *config.Host
40
41         config    *config.Config      // global configuration
42         allHosts  *config.Hosts       // known hosts
43         allGroups map[string][]string // known groups
44
45         isTTY bool
46
47         logFunc func(level safcm.LogLevel, escaped bool, msg string)
48 }
49
50 func MainSync(args []string) error {
51         flag.Usage = func() {
52                 fmt.Fprintf(os.Stderr,
53                         "usage: %s sync [<options>] <host|group...>\n",
54                         args[0])
55                 flag.PrintDefaults()
56         }
57
58         optionDryRun := flag.Bool("n", false,
59                 "dry-run, show diff but don't perform any changes")
60         optionQuiet := flag.Bool("q", false,
61                 "hide successful, non-trigger commands with no output from host changes listing")
62         optionLog := flag.String("log", "info", "set log `level`; "+
63                 "levels: error, info, verbose, debug, debug2, debug3")
64         optionSshConfig := flag.String("sshconfig", "",
65                 "`path` to ssh configuration file; used for tests")
66
67         flag.CommandLine.Parse(args[2:])
68
69         var level safcm.LogLevel
70         switch *optionLog {
71         case "error":
72                 level = safcm.LogError
73         case "info":
74                 level = safcm.LogInfo
75         case "verbose":
76                 level = safcm.LogVerbose
77         case "debug":
78                 level = safcm.LogDebug
79         case "debug2":
80                 level = safcm.LogDebug2
81         case "debug3":
82                 level = safcm.LogDebug3
83         default:
84                 return fmt.Errorf("invalid -log value %q", *optionLog)
85         }
86
87         names := flag.Args()
88         if len(names) == 0 {
89                 flag.Usage()
90                 os.Exit(1)
91         }
92
93         if runtime.GOOS == "windows" {
94                 log.Print("WARNING: Windows support is experimental!")
95         }
96
97         cfg, allHosts, allGroups, err := LoadBaseFiles()
98         if err != nil {
99                 return err
100         }
101         cfg.DryRun = *optionDryRun
102         cfg.Quiet = *optionQuiet
103         cfg.LogLevel = level
104         cfg.SshConfig = *optionSshConfig
105
106         toSync, err := hostsToSync(names, allHosts, allGroups)
107         if err != nil {
108                 return err
109         }
110         if len(toSync) == 0 {
111                 return fmt.Errorf("no hosts found")
112         }
113
114         isTTY := term.IsTerminal(int(os.Stdout.Fd())) &&
115                 term.IsTerminal(int(os.Stderr.Fd()))
116
117         loop := &frontend.Loop{
118                 DebugConn: cfg.LogLevel >= safcm.LogDebug3,
119                 LogEventFunc: func(x frontend.Event, failed *bool) {
120                         logEvent(x, cfg.LogLevel, isTTY, failed)
121                 },
122                 SyncHostFunc: func(conn *rpc.Conn, host frontend.Host) error {
123                         return host.(*Sync).Host(conn)
124                 },
125         }
126
127         var hosts []frontend.Host
128         for _, x := range toSync {
129                 s := &Sync{
130                         host:      x,
131                         config:    cfg,
132                         allHosts:  allHosts,
133                         allGroups: allGroups,
134                         isTTY:     isTTY,
135                 }
136                 s.logFunc = func(level safcm.LogLevel, escaped bool,
137                         msg string) {
138                         loop.Log(s, level, escaped, msg)
139                 }
140                 hosts = append(hosts, s)
141         }
142
143         succ := loop.Run(hosts)
144
145         if !succ {
146                 // Exit instead of returning an error to prevent an extra log
147                 // message from main()
148                 os.Exit(1)
149         }
150         return nil
151 }
152
153 // hostsToSync returns the list of hosts to sync based on the command line
154 // arguments.
155 //
156 // Full host and group matches are required to prevent unexpected behavior. No
157 // arguments does not expand to all hosts to prevent accidents; "all" can be
158 // used instead. Both host and group names are permitted as these are unique.
159 //
160 // TODO: Add option to permit partial/glob matches
161 func hostsToSync(names []string, allHosts *config.Hosts,
162         allGroups map[string][]string) ([]*config.Host, error) {
163
164         detectedMap := config.TransitivelyDetectedGroups(allGroups)
165
166         const detectedErr = `
167
168 Groups depending on "detected" groups cannot be used to select hosts as these
169 are only available after the hosts were contacted.
170 `
171
172         nameMap := make(map[string]bool)
173         for _, x := range names {
174                 if detectedMap[x] {
175                         return nil, fmt.Errorf(
176                                 "group %q depends on \"detected\" groups%s",
177                                 x, detectedErr)
178                 }
179                 nameMap[x] = true
180         }
181         nameMatched := make(map[string]bool)
182         // To detect typos we must check all given names but one host can be
183         // matched by multiple names (e.g. two groups with overlapping hosts)
184         hostAdded := make(map[string]bool)
185
186         var res []*config.Host
187         for _, host := range allHosts.List {
188                 if nameMap[host.Name] {
189                         res = append(res, host)
190                         hostAdded[host.Name] = true
191                         nameMatched[host.Name] = true
192                 }
193
194                 groups, err := config.ResolveHostGroups(host.Name,
195                         allGroups, nil)
196                 if err != nil {
197                         return nil, err
198                 }
199                 for _, x := range groups {
200                         if nameMap[x] {
201                                 if !hostAdded[host.Name] {
202                                         res = append(res, host)
203                                         hostAdded[host.Name] = true
204                                 }
205                                 nameMatched[x] = true
206                         }
207                 }
208         }
209
210         // Warn about unmatched names to detect typos
211         if len(nameMap) != len(nameMatched) {
212                 var unmatched []string
213                 for x := range nameMap {
214                         if !nameMatched[x] {
215                                 unmatched = append(unmatched,
216                                         fmt.Sprintf("%q", x))
217                         }
218                 }
219                 sort.Strings(unmatched)
220                 return nil, fmt.Errorf("hosts/groups not found: %s",
221                         strings.Join(unmatched, " "))
222         }
223
224         return res, nil
225 }
226
227 func logEvent(x frontend.Event, level safcm.LogLevel, isTTY bool, failed *bool) {
228         // We have multiple event sources so this is somewhat ugly.
229         var prefix, data string
230         var color Color
231         if x.Error != nil {
232                 prefix = "[error]"
233                 data = x.Error.Error()
234                 color = ColorRed
235                 // We logged an error, tell the caller
236                 *failed = true
237         } else if x.Log.Level != 0 {
238                 if level < x.Log.Level {
239                         return
240                 }
241                 // LogError and LogDebug3 should not occur here
242                 switch x.Log.Level {
243                 case safcm.LogInfo:
244                         prefix = "[info]"
245                 case safcm.LogVerbose:
246                         prefix = "[verbose]"
247                 case safcm.LogDebug:
248                         prefix = "[debug]"
249                 case safcm.LogDebug2:
250                         prefix = "[debug2]"
251                 default:
252                         prefix = fmt.Sprintf("[INVALID=%d]", x.Log.Level)
253                         color = ColorRed
254                 }
255                 data = x.Log.Text
256         } else {
257                 switch x.ConnEvent.Type {
258                 case rpc.ConnEventStderr:
259                         prefix = "[stderr]"
260                 case rpc.ConnEventDebug:
261                         prefix = "[debug3]"
262                 case rpc.ConnEventUpload:
263                         if level < safcm.LogInfo {
264                                 return
265                         }
266                         prefix = "[info]"
267                         x.ConnEvent.Data = "remote helper upload in progress"
268                 default:
269                         prefix = fmt.Sprintf("[INVALID=%d]", x.ConnEvent.Type)
270                         color = ColorRed
271                 }
272                 data = x.ConnEvent.Data
273         }
274
275         host := x.Host.Name()
276         if color != 0 {
277                 host = ColorString(isTTY, color, host)
278         }
279         // Make sure to escape control characters to prevent terminal
280         // injection attacks
281         if !x.Escaped {
282                 data = EscapeControlCharacters(isTTY, data)
283         }
284         log.Printf("%-9s [%s] %s", prefix, host, data)
285 }
286
287 func (s *Sync) Name() string {
288         return s.host.Name
289 }
290
291 func (s *Sync) Dial(conn *rpc.Conn) error {
292         helpers, err := fs.Sub(RemoteHelpers, "remote")
293         if err != nil {
294                 return err
295         }
296
297         // Connect to remote host
298         user := s.host.SshUser
299         if user == "" {
300                 user = s.config.SshUser
301         }
302         return conn.DialSSH(rpc.SSHConfig{
303                 Host:          s.host.Name,
304                 User:          user,
305                 SshConfig:     s.config.SshConfig,
306                 RemoteHelpers: helpers,
307         })
308 }
309
310 func (s *Sync) Host(conn *rpc.Conn) error {
311         // Collect information about remote host
312         detectedGroups, err := s.hostInfo(conn)
313         if err != nil {
314                 return err
315         }
316
317         // Sync state to remote host
318         err = s.hostSync(conn, detectedGroups)
319         if err != nil {
320                 return err
321         }
322
323         return nil
324 }
325
326 func (s *Sync) log(level safcm.LogLevel, escaped bool, msg string) {
327         s.logFunc(level, escaped, msg)
328 }
329 func (s *Sync) logDebugf(format string, a ...interface{}) {
330         s.log(safcm.LogDebug, false, fmt.Sprintf(format, a...))
331 }
332 func (s *Sync) logVerbosef(format string, a ...interface{}) {
333         s.log(safcm.LogVerbose, false, fmt.Sprintf(format, a...))
334 }
335
336 // sendRecv sends a message over conn and waits for the response. Any MsgLog
337 // messages received before the final (non MsgLog) response are passed to
338 // s.log.
339 func (s *Sync) sendRecv(conn *rpc.Conn, msg safcm.Msg) (safcm.Msg, error) {
340         err := conn.Send(msg)
341         if err != nil {
342                 return nil, err
343         }
344         for {
345                 x, err := conn.Recv()
346                 if err != nil {
347                         return nil, err
348                 }
349                 log, ok := x.(safcm.MsgLog)
350                 if ok {
351                         s.log(log.Level, false, log.Text)
352                         continue
353                 }
354                 return x, nil
355         }
356 }