// "sync" sub-command: sync data to remote hosts
// Copyright (C) 2021 Simon Ruderich
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with this program. If not, see <https://www.gnu.org/licenses/>.
package main
import (
"flag"
"fmt"
"log"
"os"
"os/signal"
"sort"
"strings"
"sync"
"golang.org/x/term"
"ruderich.org/simon/safcm"
"ruderich.org/simon/safcm/cmd/safcm/config"
"ruderich.org/simon/safcm/rpc"
)
type (
	// Sync holds everything required to sync a single host.
	Sync struct {
		host      *config.Host
		config    *config.Config      // global configuration
		allHosts  *config.Hosts       // known hosts
		allGroups map[string][]string // known groups
		events    chan<- Event        // all events generated by/for this host
		isTTY     bool                // true if stdout is a terminal
	}

	// Event is a single event generated by or for a host; all events are
	// printed by the main event loop. An Event with a nil Host is used as
	// a poison pill to stop that loop.
	Event struct {
		Host *config.Host

		// Only one of Error, Log and ConnEvent is set in a single event
		Error     error
		Log       Log
		ConnEvent rpc.ConnEvent

		Escaped bool // true if untrusted input is already escaped
	}

	// Log is a log message together with its log level.
	Log struct {
		Level safcm.LogLevel
		Text  string
	}
)
// MainSync implements the "sync" sub-command. args is the full command line
// as used in the usage message (args[0] is the program name); options and
// host/group names are parsed from args[2:].
//
// All selected hosts are synced concurrently and their events are printed
// as they arrive. If an error was logged for any host the process exits with
// status 1 instead of returning an error (to prevent an extra log message
// from main()). On SIGINT the unfinished hosts are reported and the process
// is terminated via log.Fatalf.
func MainSync(args []string) error {
	flag.Usage = func() {
		fmt.Fprintf(os.Stderr,
			"usage: %s sync [<options>] <host|group...>\n",
			args[0])
		flag.PrintDefaults()
	}

	optionDryRun := flag.Bool("n", false,
		"dry-run, show diff but don't perform any changes")
	optionQuiet := flag.Bool("q", false,
		"hide successful, non-trigger commands with no output from host changes listing")
	optionLog := flag.String("log", "info", "set log `level`; "+
		"levels: error, info, verbose, debug, debug2, debug3")
	optionSshConfig := flag.String("sshconfig", "",
		"`path` to ssh configuration file; used for tests")

	// The default flag.CommandLine uses flag.ExitOnError: parse errors
	// terminate the program, so the returned error needs no handling.
	flag.CommandLine.Parse(args[2:])

	var level safcm.LogLevel
	switch *optionLog {
	case "error":
		level = safcm.LogError
	case "info":
		level = safcm.LogInfo
	case "verbose":
		level = safcm.LogVerbose
	case "debug":
		level = safcm.LogDebug
	case "debug2":
		level = safcm.LogDebug2
	case "debug3":
		level = safcm.LogDebug3
	default:
		return fmt.Errorf("invalid -log value %q", *optionLog)
	}

	names := flag.Args()
	if len(names) == 0 {
		flag.Usage()
		os.Exit(1)
	}

	cfg, allHosts, allGroups, err := LoadBaseFiles()
	if err != nil {
		return err
	}
	cfg.DryRun = *optionDryRun
	cfg.Quiet = *optionQuiet
	cfg.LogLevel = level
	cfg.SshConfig = *optionSshConfig

	toSync, err := hostsToSync(names, allHosts, allGroups)
	if err != nil {
		return err
	}
	if len(toSync) == 0 {
		return fmt.Errorf("no hosts found")
	}

	isTTY := term.IsTerminal(int(os.Stdout.Fd()))

	done := make(chan bool)
	// Collect events from all hosts and print them. An event without a
	// host is the poison pill which stops this loop; afterwards it reports
	// whether any error was logged via done.
	events := make(chan Event)
	go func() {
		var failed bool
		for {
			x := <-events
			if x.Host == nil {
				break
			}
			logEvent(x, cfg.LogLevel, isTTY, &failed)
		}
		done <- failed
	}()

	hostsLeft := make(map[string]bool)
	for _, x := range toSync {
		hostsLeft[x.Name] = true
	}
	var hostsLeftMutex sync.Mutex // protects hostsLeft

	// Show unfinished hosts on Ctrl-C
	sigint := make(chan os.Signal, 1)   // buffered for Notify()
	signal.Notify(sigint, os.Interrupt) // = SIGINT = Ctrl-C
	go func() {
		// Running `ssh` processes get killed by SIGINT which is sent
		// to all processes

		<-sigint
		log.Print("Received SIGINT, aborting ...")

		// Print all queued events
		events <- Event{} // poison pill
		<-done
		// "races" with <-done in the main function and will hang here
		// if the other is faster. This is fine because then all hosts
		// were synced successfully.

		// The mutex is never released: log.Fatalf terminates the
		// process.
		hostsLeftMutex.Lock()
		var hosts []string
		for x := range hostsLeft {
			hosts = append(hosts, x)
		}
		sort.Strings(hosts)
		log.Fatalf("Failed to sync %s", strings.Join(hosts, ", "))
	}()

	// Sync all hosts concurrently
	var wg sync.WaitGroup
	for _, x := range toSync {
		x := x

		// Once in Sync.Host() (by its event forwarding goroutine) and
		// once in the goroutine below
		wg.Add(2)
		go func() {
			// Not named "sync" to avoid shadowing the sync package
			s := Sync{
				host:      x,
				config:    cfg,
				allHosts:  allHosts,
				allGroups: allGroups,
				events:    events,
				isTTY:     isTTY,
			}
			err := s.Host(&wg)
			if err != nil {
				events <- Event{
					Host:  x,
					Error: err,
				}
			}

			// Mark the host as finished before signalling completion
			// so hostsLeft is up to date once wg.Wait() returns.
			hostsLeftMutex.Lock()
			delete(hostsLeft, x.Name)
			hostsLeftMutex.Unlock()
			wg.Done()
		}()
	}

	wg.Wait()
	events <- Event{} // poison pill
	failed := <-done
	if failed {
		// Exit instead of returning an error to prevent an extra log
		// message from main()
		os.Exit(1)
	}
	return nil
}
// hostsToSync returns the hosts selected by the host and group names given
// on the command line, in the order of allHosts.List.
//
// Only full host and group names match to prevent unexpected behavior. An
// empty list of names does not select all hosts to prevent accidents; the
// group "all" can be used instead. Host and group names may be mixed as they
// are unique. Groups depending on "detected" groups are rejected, and all
// names which match neither a host nor a group are reported as an error to
// catch typos.
//
// TODO: Add option to permit partial/glob matches
func hostsToSync(names []string, allHosts *config.Hosts,
	allGroups map[string][]string) ([]*config.Host, error) {

	const detectedErr = `
Groups depending on "detected" groups cannot be used to select hosts as these
are only available after the hosts were contacted.
`

	detected := make(map[string]bool)
	for _, g := range config.TransitivelyDetectedGroups(allGroups) {
		detected[g] = true
	}

	wanted := make(map[string]bool)
	for _, name := range names {
		if detected[name] {
			return nil, fmt.Errorf(
				"group %q depends on \"detected\" groups%s",
				name, detectedErr)
		}
		wanted[name] = true
	}

	// Every given name is checked against all hosts so typos can be
	// detected, but each host is added to the result only once.
	found := make(map[string]bool)
	added := make(map[string]bool)
	var res []*config.Host
	for _, host := range allHosts.List {
		byName := wanted[host.Name]
		if byName {
			found[host.Name] = true
		}

		groups, err := config.ResolveHostGroups(host.Name,
			allGroups, nil)
		if err != nil {
			return nil, err
		}
		byGroup := false
		for _, g := range groups {
			if wanted[g] {
				found[g] = true
				byGroup = true
			}
		}

		// A direct name match always adds the host, a group match only
		// if a host with this name was not added yet.
		if byName || (byGroup && !added[host.Name]) {
			res = append(res, host)
			added[host.Name] = true
		}
	}

	// Report unmatched names to detect typos
	var missing []string
	for name := range wanted {
		if !found[name] {
			missing = append(missing, fmt.Sprintf("%q", name))
		}
	}
	if len(missing) > 0 {
		sort.Strings(missing)
		return nil, fmt.Errorf("hosts/groups not found: %s",
			strings.Join(missing, " "))
	}

	return res, nil
}
// logEvent prints the event x via the log package as "<prefix> [<host>]
// <text>". Only one of x.Error, x.Log and x.ConnEvent is expected to be set.
// level is the configured log level; upload connection events are dropped
// when it is below safcm.LogInfo. isTTY is passed on to the coloring and
// control character escaping helpers. *failed is set to true if x carries an
// error.
func logEvent(x Event, level safcm.LogLevel, isTTY bool, failed *bool) {
	// There are multiple event sources, determine prefix and text for
	// each of them.
	var (
		prefix string
		text   string
		color  Color
	)
	switch {
	case x.Error != nil:
		// Tell the caller an error was logged
		*failed = true
		prefix = "[error]"
		text = x.Error.Error()
		color = ColorRed

	case x.Log.Level != 0:
		// LogError and LogDebug3 are not expected here
		text = x.Log.Text
		switch x.Log.Level {
		case safcm.LogInfo:
			prefix = "[info]"
		case safcm.LogVerbose:
			prefix = "[verbose]"
		case safcm.LogDebug:
			prefix = "[debug]"
		case safcm.LogDebug2:
			prefix = "[debug2]"
		default:
			prefix = fmt.Sprintf("[INVALID=%d]", x.Log.Level)
			color = ColorRed
		}

	default:
		text = x.ConnEvent.Data
		switch x.ConnEvent.Type {
		case rpc.ConnEventStderr:
			prefix = "[stderr]"
		case rpc.ConnEventDebug:
			prefix = "[debug3]"
		case rpc.ConnEventUpload:
			if level < safcm.LogInfo {
				return
			}
			prefix = "[info]"
			text = "remote helper upload in progress"
		default:
			prefix = fmt.Sprintf("[INVALID=%d]", x.ConnEvent.Type)
			color = ColorRed
		}
	}

	name := x.Host.Name
	if color != 0 {
		name = ColorString(isTTY, color, name)
	}
	// Untrusted text must have its control characters escaped to prevent
	// terminal injection attacks
	if !x.Escaped {
		text = EscapeControlCharacters(isTTY, text)
	}
	log.Printf("%-9s [%s] %s", prefix, name, text)
}
// Host syncs s.host. It connects to the host via SSH, collects information
// about it, syncs the state and finally asks the remote side to quit and
// waits for the connection to terminate. The connection is killed when
// returning after a successful connect.
//
// All connection events are forwarded to s.events by a separate goroutine
// which calls wg.Done() once conn.Events is closed; the caller must account
// for this call with wg.Add().
func (s *Sync) Host(wg *sync.WaitGroup) error {
	conn := rpc.NewConn(s.config.LogLevel >= safcm.LogDebug3)

	// Forward all connection events to the main event loop
	go func() {
		defer wg.Done()
		for ev := range conn.Events {
			s.events <- Event{
				Host:      s.host,
				ConnEvent: ev,
			}
		}
	}()

	if err := conn.DialSSH(s.host.SshUser, s.host.Name,
		s.config.SshConfig); err != nil {
		return err
	}
	defer conn.Kill()

	// Gather information about the remote host, then sync the state
	detectedGroups, err := s.hostInfo(conn)
	if err != nil {
		return err
	}
	if err := s.hostSync(conn, detectedGroups); err != nil {
		return err
	}

	// Ask the remote side to terminate; the reply itself is not inspected
	if err := conn.Send(safcm.MsgQuitReq{}); err != nil {
		return err
	}
	if _, err := conn.Recv(); err != nil {
		return err
	}
	return conn.Wait()
}
// logf formats a message with fmt.Sprintf and sends it as a log event for
// s.host with the given level. Messages whose level is greater than the
// configured s.config.LogLevel are dropped. escaped reports whether the
// formatted text is already escaped.
func (s *Sync) logf(level safcm.LogLevel, escaped bool,
	format string, a ...interface{}) {

	if level > s.config.LogLevel {
		return // filtered by the configured log level
	}

	ev := Event{
		Host:    s.host,
		Escaped: escaped,
	}
	ev.Log.Level = level
	ev.Log.Text = fmt.Sprintf(format, a...)
	s.events <- ev
}
// logDebugf logs a not yet escaped message at level safcm.LogDebug.
func (s *Sync) logDebugf(format string, args ...interface{}) {
	const escaped = false
	s.logf(safcm.LogDebug, escaped, format, args...)
}
// logVerbosef logs a not yet escaped message at level safcm.LogVerbose.
func (s *Sync) logVerbosef(format string, args ...interface{}) {
	const escaped = false
	s.logf(safcm.LogVerbose, escaped, format, args...)
}
// sendRecv sends msg over conn and waits for the response. Any MsgLog
// messages received before the final (non-MsgLog) response are passed to
// s.logf (as not yet escaped text); the final response is returned. Errors
// from sending or receiving are returned as is.
func (s *Sync) sendRecv(conn *rpc.Conn, msg safcm.Msg) (safcm.Msg, error) {
	if err := conn.Send(msg); err != nil {
		return nil, err
	}
	for {
		resp, err := conn.Recv()
		if err != nil {
			return nil, err
		}
		// Named logMsg to avoid shadowing the log package
		logMsg, ok := resp.(safcm.MsgLog)
		if !ok {
			return resp, nil
		}
		s.logf(logMsg.Level, false, "%s", logMsg.Text)
	}
}