]> ruderich.org/simon Gitweb - safcm/safcm.git/blob - remote/sync/sync_test.go
Use SPDX license identifiers
[safcm/safcm.git] / remote / sync / sync_test.go
1 // SPDX-License-Identifier: GPL-3.0-or-later
2 // Copyright (C) 2021-2024  Simon Ruderich
3
4 package sync
5
6 import (
7         "bytes"
8         "fmt"
9         "os/exec"
10         "reflect"
11         "sync"
12         "testing"
13
14         "github.com/google/go-cmp/cmp"
15         "github.com/google/go-cmp/cmp/cmpopts"
16
17         "ruderich.org/simon/safcm"
18         "ruderich.org/simon/safcm/remote/log"
19         "ruderich.org/simon/safcm/remote/run"
20 )
21
22 // testRunner implements run.Runner to test commands without actually running
23 // them.
24 type testRunner struct {
25         t         *testing.T
26         expCmds   []*exec.Cmd
27         resStdout [][]byte
28         resStderr [][]byte
29         resError  []error
30 }
31
32 func (r *testRunner) Run(cmd *exec.Cmd) error {
33         stdout, stderr, resErr := r.check("run", cmd)
34         _, err := cmd.Stdout.Write(stdout)
35         if err != nil {
36                 panic(err)
37         }
38         _, err = cmd.Stderr.Write(stderr)
39         if err != nil {
40                 panic(err)
41         }
42         return resErr
43 }
44 func (r *testRunner) CombinedOutput(cmd *exec.Cmd) ([]byte, error) {
45         r.t.Helper()
46
47         stdout, stderr, err := r.check("combinedOutput", cmd)
48         if stderr != nil {
49                 // stdout also contains stderr
50                 r.t.Fatalf("CombinedOutput: stderr != nil, but %v", stderr)
51         }
52         return stdout, err
53 }
54 func (r *testRunner) check(method string, cmd *exec.Cmd) (
55         []byte, []byte, error) {
56         r.t.Helper()
57
58         if len(r.expCmds) == 0 {
59                 r.t.Fatalf("%s: empty expCmds", method)
60         }
61         if len(r.resStdout) == 0 {
62                 r.t.Fatalf("%s: empty resStdout", method)
63         }
64         if len(r.resStderr) == 0 {
65                 r.t.Fatalf("%s: empty resStderr", method)
66         }
67         if len(r.resError) == 0 {
68                 r.t.Fatalf("%s: empty resError", method)
69         }
70
71         exp := r.expCmds[0]
72         r.expCmds = r.expCmds[1:]
73         if !reflect.DeepEqual(exp, cmd) {
74                 opts := cmpopts.IgnoreUnexported(exec.Cmd{}, bytes.Buffer{})
75                 r.t.Errorf("%s: %s", method, cmp.Diff(exp, cmd, opts))
76         }
77
78         var stdout, stderr []byte
79         var err error
80
81         stdout, r.resStdout = r.resStdout[0], r.resStdout[1:]
82         stderr, r.resStderr = r.resStderr[0], r.resStderr[1:]
83         err, r.resError = r.resError[0], r.resError[1:]
84
85         return stdout, stderr, err
86 }
87
88 type syncTestResult struct {
89         ch     chan string
90         wg     sync.WaitGroup
91         dbg    []string
92         runner *testRunner
93 }
94
95 func prepareSync(req safcm.MsgSyncReq, runner *testRunner) (
96         *Sync, *syncTestResult) {
97
98         res := &syncTestResult{
99                 ch:     make(chan string),
100                 runner: runner,
101         }
102         res.wg.Add(1)
103         go func() {
104                 for {
105                         x, ok := <-res.ch
106                         if !ok {
107                                 break
108                         }
109                         res.dbg = append(res.dbg, x)
110                 }
111                 res.wg.Done()
112         }()
113
114         logger := log.NewLogger(func(level safcm.LogLevel, msg string) {
115                 res.ch <- fmt.Sprintf("%d: %s", level, msg)
116         })
117         return &Sync{
118                 req: req,
119                 cmd: run.NewCmd(runner, logger),
120                 log: logger,
121         }, res
122 }
123
124 func (s *syncTestResult) Wait() []string {
125         s.runner.t.Helper()
126
127         close(s.ch)
128         s.wg.Wait()
129
130         // All expected commands must have been executed
131         if len(s.runner.expCmds) != 0 {
132                 s.runner.t.Errorf("expCmds left: %v", s.runner.expCmds)
133         }
134         if len(s.runner.resStdout) != 0 {
135                 s.runner.t.Errorf("resStdout left: %v", s.runner.resStdout)
136         }
137         if len(s.runner.resStderr) != 0 {
138                 s.runner.t.Errorf("resStderr left: %v", s.runner.resStderr)
139         }
140         if len(s.runner.resError) != 0 {
141                 s.runner.t.Errorf("resError left: %v", s.runner.resError)
142         }
143
144         return s.dbg
145 }