// "sync" sub-command: sync data to remote hosts // Copyright (C) 2021 Simon Ruderich // // This program is free software: you can redistribute it and/or modify // it under the terms of the GNU General Public License as published by // the Free Software Foundation, either version 3 of the License, or // (at your option) any later version. // // This program is distributed in the hope that it will be useful, // but WITHOUT ANY WARRANTY; without even the implied warranty of // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the // GNU General Public License for more details. // // You should have received a copy of the GNU General Public License // along with this program. If not, see . package main import ( "flag" "fmt" "log" "os" "sort" "strings" "sync" "golang.org/x/term" "ruderich.org/simon/safcm" "ruderich.org/simon/safcm/cmd/safcm/config" "ruderich.org/simon/safcm/rpc" ) type Sync struct { host *config.Host config *config.Config // global configuration allHosts *config.Hosts // known hosts allGroups map[string][]string // known groups events chan<- Event // all events generated by/for this host isTTY bool } type Event struct { Host *config.Host // Only one of Error, Log and ConnEvent is set in a single event Error error Log Log ConnEvent rpc.ConnEvent Escaped bool // true if untrusted input is already escaped } type Log struct { Level safcm.LogLevel Text string } func MainSync(args []string) error { flag.Usage = func() { fmt.Fprintf(os.Stderr, "usage: %s sync [] \n", args[0]) flag.PrintDefaults() } optionDryRun := flag.Bool("n", false, "dry-run, show diff but don't perform any changes") optionLog := flag.String("log", "info", "set log `level`; "+ "levels: error, info, verbose, debug, debug2, debug3") flag.CommandLine.Parse(args[2:]) var level safcm.LogLevel switch *optionLog { case "error": level = safcm.LogError case "info": level = safcm.LogInfo case "verbose": level = safcm.LogVerbose case "debug": level = safcm.LogDebug case "debug2": level = safcm.LogDebug2 case "debug3": level = safcm.LogDebug3 default: return fmt.Errorf("invalid -log value %q", *optionLog) } names := flag.Args() if len(names) == 0 { flag.Usage() os.Exit(1) } cfg, allHosts, allGroups, err := LoadBaseFiles() if err != nil { return err } cfg.DryRun = *optionDryRun cfg.LogLevel = level toSync, err := hostsToSync(names, allHosts, allGroups) if err != nil { return err } if len(toSync) == 0 { return fmt.Errorf("no hosts found") } isTTY := term.IsTerminal(int(os.Stdout.Fd())) done := make(chan bool) // Collect events from all hosts and print them events := make(chan Event) go func() { var failed bool for { x := <-events if x.Host == nil { break } logEvent(x, cfg.LogLevel, isTTY, &failed) } done <- failed }() // Sync all hosts concurrently var wg sync.WaitGroup for _, x := range toSync { x := x // Once in sync.Host() and once in the go func below wg.Add(2) go func() { sync := Sync{ host: x, config: cfg, allHosts: allHosts, allGroups: allGroups, events: events, isTTY: isTTY, } err := sync.Host(&wg) if err != nil { events <- Event{ Host: x, Error: err, } } wg.Done() }() } wg.Wait() events <- Event{} // poison pill failed := <-done if failed { // Exit instead of returning an error to prevent an extra log // message from main() os.Exit(1) } return nil } // hostsToSync returns the list of hosts to sync based on the command line // arguments. // // Full host and group matches are required to prevent unexpected behavior. No // arguments does not expand to all hosts to prevent accidents; "all" can be // used instead. Both host and group names are permitted as these are unique. // // TODO: Add option to permit partial/glob matches func hostsToSync(names []string, allHosts *config.Hosts, allGroups map[string][]string) ([]*config.Host, error) { nameMap := make(map[string]bool) for _, x := range names { nameMap[x] = true } nameMatched := make(map[string]bool) // To detect typos we must check all given names but only want to add // each match once hostMatched := make(map[string]bool) var res []*config.Host for _, host := range allHosts.List { if nameMap[host.Name] { res = append(res, host) hostMatched[host.Name] = true nameMatched[host.Name] = true } // TODO: don't permit groups which contain "detected" groups // because these are not available yet groups, err := config.ResolveHostGroups(host.Name, allGroups, nil) if err != nil { return nil, err } for _, x := range groups { if nameMap[x] { if !hostMatched[host.Name] { res = append(res, host) hostMatched[host.Name] = true } nameMatched[x] = true } } } // Warn about unmatched names to detect typos if len(nameMap) != len(nameMatched) { var unmatched []string for x := range nameMap { if !nameMatched[x] { unmatched = append(unmatched, fmt.Sprintf("%q", x)) } } sort.Strings(unmatched) return nil, fmt.Errorf("hosts/groups not found: %s", strings.Join(unmatched, " ")) } return res, nil } func logEvent(x Event, level safcm.LogLevel, isTTY bool, failed *bool) { // We have multiple event sources so this is somewhat ugly. var prefix, data string var color Color if x.Error != nil { prefix = "[error]" data = x.Error.Error() color = ColorRed // We logged an error, tell the caller *failed = true } else if x.Log.Level != 0 { // LogError and LogDebug3 should not occur here switch x.Log.Level { case safcm.LogInfo: prefix = "[info]" case safcm.LogVerbose: prefix = "[verbose]" case safcm.LogDebug: prefix = "[debug]" case safcm.LogDebug2: prefix = "[debug2]" default: prefix = fmt.Sprintf("[INVALID=%d]", x.Log.Level) color = ColorRed } data = x.Log.Text } else { switch x.ConnEvent.Type { case rpc.ConnEventStderr: prefix = "[stderr]" case rpc.ConnEventDebug: prefix = "[debug3]" case rpc.ConnEventUpload: if level < safcm.LogInfo { return } prefix = "[info]" x.ConnEvent.Data = "remote helper upload in progress" default: prefix = fmt.Sprintf("[INVALID=%d]", x.ConnEvent.Type) color = ColorRed } data = x.ConnEvent.Data } host := x.Host.Name if color != 0 { host = ColorString(isTTY, color, host) } // Make sure to escape control characters to prevent terminal // injection attacks if !x.Escaped { data = EscapeControlCharacters(isTTY, data) } log.Printf("%-9s [%s] %s", prefix, host, data) } func (s *Sync) Host(wg *sync.WaitGroup) error { conn := rpc.NewConn(s.config.LogLevel >= safcm.LogDebug3) // Pass all connection events to main loop go func() { for { x, ok := <-conn.Events if !ok { break } s.events <- Event{ Host: s.host, ConnEvent: x, } } wg.Done() }() // Connect to remote host err := conn.DialSSH(s.host.SshUser, s.host.Name) if err != nil { return err } defer conn.Kill() // Collect information about remote host detectedGroups, err := s.hostInfo(conn) if err != nil { return err } // Sync state to remote host err = s.hostSync(conn, detectedGroups) if err != nil { return err } // Terminate connection to remote host err = conn.Send(safcm.MsgQuitReq{}) if err != nil { return err } _, err = conn.Recv() if err != nil { return err } err = conn.Wait() if err != nil { return err } return nil } func (s *Sync) logf(level safcm.LogLevel, escaped bool, format string, a ...interface{}) { if s.config.LogLevel < level { return } s.events <- Event{ Host: s.host, Log: Log{ Level: level, Text: fmt.Sprintf(format, a...), }, Escaped: escaped, } } func (s *Sync) logDebugf(format string, a ...interface{}) { s.logf(safcm.LogDebug, false, format, a...) } func (s *Sync) logVerbosef(format string, a ...interface{}) { s.logf(safcm.LogVerbose, false, format, a...) } // sendRecv sends a message over conn and waits for the response. Any MsgLog // messages received before the final (non MsgLog) response are passed to // s.log. func (s *Sync) sendRecv(conn *rpc.Conn, msg safcm.Msg) (safcm.Msg, error) { err := conn.Send(msg) if err != nil { return nil, err } for { x, err := conn.Recv() if err != nil { return nil, err } log, ok := x.(safcm.MsgLog) if ok { s.logf(log.Level, false, "%s", log.Text) continue } return x, nil } }