]> ruderich.org/simon Gitweb - safcm/safcm.git/blobdiff - cmd/safcm/sync.go
safcm: move sync_changes.go and term.go to frontend package
[safcm/safcm.git] / cmd / safcm / sync.go
index 6358f0cb60feb4570cc5dae1a9e0e24bce6a69ed..f1877cffde6b9b42b18955587192a19fda43127d 100644 (file)
@@ -20,17 +20,18 @@ package main
 import (
        "flag"
        "fmt"
+       "io/fs"
        "log"
        "os"
-       "os/signal"
+       "runtime"
        "sort"
        "strings"
-       "sync"
 
        "golang.org/x/term"
 
        "ruderich.org/simon/safcm"
        "ruderich.org/simon/safcm/cmd/safcm/config"
+       "ruderich.org/simon/safcm/frontend"
        "ruderich.org/simon/safcm/rpc"
 )
 
@@ -41,25 +42,9 @@ type Sync struct {
        allHosts  *config.Hosts       // known hosts
        allGroups map[string][]string // known groups
 
-       events chan<- Event // all events generated by/for this host
-
        isTTY bool
-}
-
-type Event struct {
-       Host *config.Host
-
-       // Only one of Error, Log and ConnEvent is set in a single event
-       Error     error
-       Log       Log
-       ConnEvent rpc.ConnEvent
-
-       Escaped bool // true if untrusted input is already escaped
-}
 
-type Log struct {
-       Level safcm.LogLevel
-       Text  string
+       logFunc func(level safcm.LogLevel, escaped bool, msg string)
 }
 
 func MainSync(args []string) error {
@@ -81,22 +66,9 @@ func MainSync(args []string) error {
 
        flag.CommandLine.Parse(args[2:])
 
-       var level safcm.LogLevel
-       switch *optionLog {
-       case "error":
-               level = safcm.LogError
-       case "info":
-               level = safcm.LogInfo
-       case "verbose":
-               level = safcm.LogVerbose
-       case "debug":
-               level = safcm.LogDebug
-       case "debug2":
-               level = safcm.LogDebug2
-       case "debug3":
-               level = safcm.LogDebug3
-       default:
-               return fmt.Errorf("invalid -log value %q", *optionLog)
+       level, err := safcm.ParseLogLevel(*optionLog)
+       if err != nil {
+               return fmt.Errorf("-log: %v", err)
        }
 
        names := flag.Args()
@@ -105,6 +77,10 @@ func MainSync(args []string) error {
                os.Exit(1)
        }
 
+       if runtime.GOOS == "windows" {
+               log.Print("WARNING: Windows support is experimental!")
+       }
+
        cfg, allHosts, allGroups, err := LoadBaseFiles()
        if err != nil {
                return err
@@ -122,92 +98,38 @@ func MainSync(args []string) error {
                return fmt.Errorf("no hosts found")
        }
 
-       isTTY := term.IsTerminal(int(os.Stdout.Fd()))
-
-       done := make(chan bool)
-       // Collect events from all hosts and print them
-       events := make(chan Event)
-       go func() {
-               var failed bool
-               for {
-                       x := <-events
-                       if x.Host == nil {
-                               break
-                       }
-                       logEvent(x, cfg.LogLevel, isTTY, &failed)
-               }
-               done <- failed
-       }()
+       isTTY := term.IsTerminal(int(os.Stdout.Fd())) &&
+               term.IsTerminal(int(os.Stderr.Fd()))
 
-       hostsLeft := make(map[string]bool)
-       for _, x := range toSync {
-               hostsLeft[x.Name] = true
+       loop := &frontend.Loop{
+               DebugConn: cfg.LogLevel >= safcm.LogDebug3,
+               LogEventFunc: func(x frontend.Event, failed *bool) {
+                       logEvent(x, cfg.LogLevel, isTTY, failed)
+               },
+               SyncHostFunc: func(conn *rpc.Conn, host frontend.Host) error {
+                       return host.(*Sync).Host(conn)
+               },
        }
-       var hostsLeftMutex sync.Mutex // protects hostsLeft
-
-       // Show unfinished hosts on Ctrl-C
-       sigint := make(chan os.Signal, 1)   // buffered for Notify()
-       signal.Notify(sigint, os.Interrupt) // = SIGINT = Ctrl-C
-       go func() {
-               // Running `ssh` processes get killed by SIGINT which is sent
-               // to all processes
-
-               <-sigint
-               log.Print("Received SIGINT, aborting ...")
-
-               // Print all queued events
-               events <- Event{} // poison pill
-               <-done
-               // "races" with <-done in the main function and will hang here
-               // if the other is faster. This is fine because then all hosts
-               // were synced successfully.
-
-               hostsLeftMutex.Lock()
-               var hosts []string
-               for x := range hostsLeft {
-                       hosts = append(hosts, x)
-               }
-               sort.Strings(hosts)
-               log.Fatalf("Failed to sync %s", strings.Join(hosts, ", "))
-       }()
 
-       // Sync all hosts concurrently
-       var wg sync.WaitGroup
+       var hosts []frontend.Host
        for _, x := range toSync {
-               x := x
-
-               // Once in sync.Host() and once in the go func below
-               wg.Add(2)
-
-               go func() {
-                       sync := Sync{
-                               host:      x,
-                               config:    cfg,
-                               allHosts:  allHosts,
-                               allGroups: allGroups,
-                               events:    events,
-                               isTTY:     isTTY,
-                       }
-                       err := sync.Host(&wg)
-                       if err != nil {
-                               events <- Event{
-                                       Host:  x,
-                                       Error: err,
-                               }
-                       }
-                       wg.Done()
-
-                       hostsLeftMutex.Lock()
-                       defer hostsLeftMutex.Unlock()
-                       delete(hostsLeft, x.Name)
-               }()
+               s := &Sync{
+                       host:      x,
+                       config:    cfg,
+                       allHosts:  allHosts,
+                       allGroups: allGroups,
+                       isTTY:     isTTY,
+               }
+               s.logFunc = func(level safcm.LogLevel, escaped bool,
+                       msg string) {
+                       loop.Log(s, level, escaped, msg)
+               }
+               hosts = append(hosts, s)
        }
 
-       wg.Wait()
-       events <- Event{} // poison pill
-       failed := <-done
+       succ := loop.Run(hosts)
 
-       if failed {
+       if !succ {
                // Exit instead of returning an error to prevent an extra log
                // message from main()
                os.Exit(1)
@@ -246,13 +168,13 @@ are only available after the hosts were contacted.
        nameMatched := make(map[string]bool)
        // To detect typos we must check all given names but one host can be
        // matched by multiple names (e.g. two groups with overlapping hosts)
-       hostMatched := make(map[string]bool)
+       hostAdded := make(map[string]bool)
 
        var res []*config.Host
        for _, host := range allHosts.List {
                if nameMap[host.Name] {
                        res = append(res, host)
-                       hostMatched[host.Name] = true
+                       hostAdded[host.Name] = true
                        nameMatched[host.Name] = true
                }
 
@@ -263,9 +185,9 @@ are only available after the hosts were contacted.
                }
                for _, x := range groups {
                        if nameMap[x] {
-                               if !hostMatched[host.Name] {
+                               if !hostAdded[host.Name] {
                                        res = append(res, host)
-                                       hostMatched[host.Name] = true
+                                       hostAdded[host.Name] = true
                                }
                                nameMatched[x] = true
                        }
@@ -289,17 +211,20 @@ are only available after the hosts were contacted.
        return res, nil
 }
 
-func logEvent(x Event, level safcm.LogLevel, isTTY bool, failed *bool) {
+func logEvent(x frontend.Event, level safcm.LogLevel, isTTY bool, failed *bool) {
        // We have multiple event sources so this is somewhat ugly.
        var prefix, data string
-       var color Color
+       var color frontend.Color
        if x.Error != nil {
                prefix = "[error]"
                data = x.Error.Error()
-               color = ColorRed
+               color = frontend.ColorRed
                // We logged an error, tell the caller
                *failed = true
        } else if x.Log.Level != 0 {
+               if level < x.Log.Level {
+                       return
+               }
                // LogError and LogDebug3 should not occur here
                switch x.Log.Level {
                case safcm.LogInfo:
@@ -312,7 +237,7 @@ func logEvent(x Event, level safcm.LogLevel, isTTY bool, failed *bool) {
                        prefix = "[debug2]"
                default:
                        prefix = fmt.Sprintf("[INVALID=%d]", x.Log.Level)
-                       color = ColorRed
+                       color = frontend.ColorRed
                }
                data = x.Log.Text
        } else {
@@ -329,47 +254,47 @@ func logEvent(x Event, level safcm.LogLevel, isTTY bool, failed *bool) {
                        x.ConnEvent.Data = "remote helper upload in progress"
                default:
                        prefix = fmt.Sprintf("[INVALID=%d]", x.ConnEvent.Type)
-                       color = ColorRed
+                       color = frontend.ColorRed
                }
                data = x.ConnEvent.Data
        }
 
-       host := x.Host.Name
+       host := x.Host.Name()
        if color != 0 {
-               host = ColorString(isTTY, color, host)
+               host = frontend.ColorString(isTTY, color, host)
        }
        // Make sure to escape control characters to prevent terminal
        // injection attacks
        if !x.Escaped {
-               data = EscapeControlCharacters(isTTY, data)
+               data = frontend.EscapeControlCharacters(isTTY, data)
        }
        log.Printf("%-9s [%s] %s", prefix, host, data)
 }
 
-func (s *Sync) Host(wg *sync.WaitGroup) error {
-       conn := rpc.NewConn(s.config.LogLevel >= safcm.LogDebug3)
-       // Pass all connection events to main loop
-       go func() {
-               for {
-                       x, ok := <-conn.Events
-                       if !ok {
-                               break
-                       }
-                       s.events <- Event{
-                               Host:      s.host,
-                               ConnEvent: x,
-                       }
-               }
-               wg.Done()
-       }()
+func (s *Sync) Name() string {
+       return s.host.Name
+}
 
-       // Connect to remote host
-       err := conn.DialSSH(s.host.SshUser, s.host.Name, s.config.SshConfig)
+func (s *Sync) Dial(conn *rpc.Conn) error {
+       helpers, err := fs.Sub(RemoteHelpers, "remote")
        if err != nil {
                return err
        }
-       defer conn.Kill()
 
+       // Connect to remote host
+       user := s.host.SshUser
+       if user == "" {
+               user = s.config.SshUser
+       }
+       return conn.DialSSH(rpc.SSHConfig{
+               Host:          s.host.Name,
+               User:          user,
+               SshConfig:     s.config.SshConfig,
+               RemoteHelpers: helpers,
+       })
+}
+
+func (s *Sync) Host(conn *rpc.Conn) error {
        // Collect information about remote host
        detectedGroups, err := s.hostInfo(conn)
        if err != nil {
@@ -382,43 +307,17 @@ func (s *Sync) Host(wg *sync.WaitGroup) error {
                return err
        }
 
-       // Terminate connection to remote host
-       err = conn.Send(safcm.MsgQuitReq{})
-       if err != nil {
-               return err
-       }
-       _, err = conn.Recv()
-       if err != nil {
-               return err
-       }
-       err = conn.Wait()
-       if err != nil {
-               return err
-       }
-
        return nil
 }
 
-func (s *Sync) logf(level safcm.LogLevel, escaped bool,
-       format string, a ...interface{}) {
-
-       if s.config.LogLevel < level {
-               return
-       }
-       s.events <- Event{
-               Host: s.host,
-               Log: Log{
-                       Level: level,
-                       Text:  fmt.Sprintf(format, a...),
-               },
-               Escaped: escaped,
-       }
+func (s *Sync) log(level safcm.LogLevel, escaped bool, msg string) {
+       s.logFunc(level, escaped, msg)
 }
 func (s *Sync) logDebugf(format string, a ...interface{}) {
-       s.logf(safcm.LogDebug, false, format, a...)
+       s.log(safcm.LogDebug, false, fmt.Sprintf(format, a...))
 }
 func (s *Sync) logVerbosef(format string, a ...interface{}) {
-       s.logf(safcm.LogVerbose, false, format, a...)
+       s.log(safcm.LogVerbose, false, fmt.Sprintf(format, a...))
 }
 
 // sendRecv sends a message over conn and waits for the response. Any MsgLog
@@ -436,7 +335,7 @@ func (s *Sync) sendRecv(conn *rpc.Conn, msg safcm.Msg) (safcm.Msg, error) {
                }
                log, ok := x.(safcm.MsgLog)
                if ok {
-                       s.logf(log.Level, false, "%s", log.Text)
+                       s.log(log.Level, false, log.Text)
                        continue
                }
                return x, nil