]> ruderich.org/simon Gitweb - safcm/safcm.git/blobdiff - cmd/safcm/sync.go
safcm: use better variable name in hostsToSync()
[safcm/safcm.git] / cmd / safcm / sync.go
index 09ffe854e1e7491727882f963a894ff9c801994e..70aaf47e4c8f5c364247fa4fe0c05e239bb5d79b 100644 (file)
@@ -22,6 +22,7 @@ import (
        "fmt"
        "log"
        "os"
+       "os/signal"
        "sort"
        "strings"
        "sync"
@@ -71,8 +72,12 @@ func MainSync(args []string) error {
 
        optionDryRun := flag.Bool("n", false,
                "dry-run, show diff but don't perform any changes")
+       optionQuiet := flag.Bool("q", false,
+               "hide successful, non-trigger commands with no output from host changes listing")
        optionLog := flag.String("log", "info", "set log `level`; "+
                "levels: error, info, verbose, debug, debug2, debug3")
+       optionSshConfig := flag.String("sshconfig", "",
+               "`path` to ssh configuration file; used for tests")
 
        flag.CommandLine.Parse(args[2:])
 
@@ -105,7 +110,9 @@ func MainSync(args []string) error {
                return err
        }
        cfg.DryRun = *optionDryRun
+       cfg.Quiet = *optionQuiet
        cfg.LogLevel = level
+       cfg.SshConfig = *optionSshConfig
 
        toSync, err := hostsToSync(names, allHosts, allGroups)
        if err != nil {
@@ -115,7 +122,8 @@ func MainSync(args []string) error {
                return fmt.Errorf("no hosts found")
        }
 
-       isTTY := term.IsTerminal(int(os.Stdout.Fd()))
+       isTTY := term.IsTerminal(int(os.Stdout.Fd())) &&
+               term.IsTerminal(int(os.Stderr.Fd()))
 
        done := make(chan bool)
        // Collect events from all hosts and print them
@@ -132,6 +140,38 @@ func MainSync(args []string) error {
                done <- failed
        }()
 
+       hostsLeft := make(map[string]bool)
+       for _, x := range toSync {
+               hostsLeft[x.Name] = true
+       }
+       var hostsLeftMutex sync.Mutex // protects hostsLeft
+
+       // Show unfinished hosts on Ctrl-C
+       sigint := make(chan os.Signal, 1)   // buffered for Notify()
+       signal.Notify(sigint, os.Interrupt) // = SIGINT = Ctrl-C
+       go func() {
+               // Running `ssh` processes get killed by SIGINT which is sent
+               // to all processes
+
+               <-sigint
+               log.Print("Received SIGINT, aborting ...")
+
+               // Print all queued events
+               events <- Event{} // poison pill
+               <-done
+               // "races" with <-done in the main function and will hang here
+               // if the other is faster. This is fine because then all hosts
+               // were synced successfully.
+
+               hostsLeftMutex.Lock()
+               var hosts []string
+               for x := range hostsLeft {
+                       hosts = append(hosts, x)
+               }
+               sort.Strings(hosts)
+               log.Fatalf("Failed to sync %s", strings.Join(hosts, ", "))
+       }()
+
        // Sync all hosts concurrently
        var wg sync.WaitGroup
        for _, x := range toSync {
@@ -157,6 +197,10 @@ func MainSync(args []string) error {
                                }
                        }
                        wg.Done()
+
+                       hostsLeftMutex.Lock()
+                       defer hostsLeftMutex.Unlock()
+                       delete(hostsLeft, x.Name)
                }()
        }
 
@@ -183,25 +227,36 @@ func MainSync(args []string) error {
 func hostsToSync(names []string, allHosts *config.Hosts,
        allGroups map[string][]string) ([]*config.Host, error) {
 
+       detectedMap := config.TransitivelyDetectedGroups(allGroups)
+
+       const detectedErr = `
+
+Groups depending on "detected" groups cannot be used to select hosts as these
+are only available after the hosts were contacted.
+`
+
        nameMap := make(map[string]bool)
        for _, x := range names {
+               if detectedMap[x] {
+                       return nil, fmt.Errorf(
+                               "group %q depends on \"detected\" groups%s",
+                               x, detectedErr)
+               }
                nameMap[x] = true
        }
        nameMatched := make(map[string]bool)
-       // To detect typos we must check all given names but only want to add
-       // each match once
-       hostMatched := make(map[string]bool)
+       // To detect typos we must check all given names but one host can be
+       // matched by multiple names (e.g. two groups with overlapping hosts)
+       hostAdded := make(map[string]bool)
 
        var res []*config.Host
        for _, host := range allHosts.List {
                if nameMap[host.Name] {
                        res = append(res, host)
-                       hostMatched[host.Name] = true
+                       hostAdded[host.Name] = true
                        nameMatched[host.Name] = true
                }
 
-               // TODO: don't permit groups which contain "detected" groups
-               // because these are not available yet
                groups, err := config.ResolveHostGroups(host.Name,
                        allGroups, nil)
                if err != nil {
@@ -209,9 +264,9 @@ func hostsToSync(names []string, allHosts *config.Hosts,
                }
                for _, x := range groups {
                        if nameMap[x] {
-                               if !hostMatched[host.Name] {
+                               if !hostAdded[host.Name] {
                                        res = append(res, host)
-                                       hostMatched[host.Name] = true
+                                       hostAdded[host.Name] = true
                                }
                                nameMatched[x] = true
                        }
@@ -310,7 +365,7 @@ func (s *Sync) Host(wg *sync.WaitGroup) error {
        }()
 
        // Connect to remote host
-       err := conn.DialSSH(s.host.Name)
+       err := conn.DialSSH(s.host.SshUser, s.host.Name, s.config.SshConfig)
        if err != nil {
                return err
        }