]> ruderich.org/simon Gitweb - safcm/safcm.git/blobdiff - cmd/safcm/config/groups_test.go
config: remove unnecessary os.Chdir in test
[safcm/safcm.git] / cmd / safcm / config / groups_test.go
index 68f18bad3bc5ec3d6ff93fa9ec6010ec9951c073..cf6cd6e088cd48d3789fb2f618f903eb1e420dcc 100644 (file)
@@ -18,6 +18,7 @@ package config
 import (
        "fmt"
        "os"
+       "path/filepath"
        "reflect"
        "testing"
 
@@ -186,27 +187,21 @@ func TestLoadGroups(t *testing.T) {
 
        for _, tc := range tests {
                t.Run(tc.path, func(t *testing.T) {
-               err := os.Chdir(tc.path)
-               if err != nil {
-                       t.Fatal(err)
-               }
+                       err := os.Chdir(filepath.Join(cwd, tc.path))
+                       if err != nil {
+                               t.Fatal(err)
+                       }
 
-               res, err := LoadGroups(tc.cfg, tc.hosts)
+                       res, err := LoadGroups(tc.cfg, tc.hosts)
 
-               if !reflect.DeepEqual(tc.exp, res) {
-                       t.Errorf("res: %s",
-                               cmp.Diff(tc.exp, res))
-               }
-               // Ugly but the simplest way to compare errors (including nil)
-               if fmt.Sprintf("%s", err) != fmt.Sprintf("%s", tc.expErr) {
-                       t.Errorf("err = %#v, want %#v",
-                               err, tc.expErr)
-               }
-
-               err = os.Chdir(cwd)
-               if err != nil {
-                       t.Fatal(err)
-               }
+                       if !reflect.DeepEqual(tc.exp, res) {
+                               t.Errorf("res: %s", cmp.Diff(tc.exp, res))
+                       }
+                       // Ugly but the simplest way to compare errors (including nil)
+                       if fmt.Sprintf("%s", err) != fmt.Sprintf("%s", tc.expErr) {
+                               t.Errorf("err = %#v, want %#v",
+                                       err, tc.expErr)
+                       }
                })
        }
 }
@@ -318,16 +313,16 @@ func TestResolveHostGroups(t *testing.T) {
 
        for _, tc := range tests {
                t.Run(tc.name, func(t *testing.T) {
-               res, err := ResolveHostGroups(tc.host, allGroups, tc.detected)
-               if !reflect.DeepEqual(tc.exp, res) {
-                       t.Errorf("res: %s",
-                               cmp.Diff(tc.exp, res))
-               }
-               // Ugly but the simplest way to compare errors (including nil)
-               if fmt.Sprintf("%s", err) != fmt.Sprintf("%s", tc.expErr) {
-                       t.Errorf("err = %#v, want %#v",
-                               err, tc.expErr)
-               }
+                       res, err := ResolveHostGroups(tc.host, allGroups,
+                               tc.detected)
+                       if !reflect.DeepEqual(tc.exp, res) {
+                               t.Errorf("res: %s", cmp.Diff(tc.exp, res))
+                       }
+                       // Ugly but the simplest way to compare errors (including nil)
+                       if fmt.Sprintf("%s", err) != fmt.Sprintf("%s", tc.expErr) {
+                               t.Errorf("err = %#v, want %#v",
+                                       err, tc.expErr)
+                       }
                })
        }
 }