]> ruderich.org/simon Gitweb - nsscash/nsscash.git/blobdiff - main_test.go
nsscash: add "ca" option for files
[nsscash/nsscash.git] / main_test.go
index 2b48bcc7c9d37315910f8efdac80633ed269a6af..38499da197bb6fd0be3276e48d11527215a30a8c 100644 (file)
@@ -17,6 +17,7 @@ package main
 
 import (
        "crypto/sha1"
+       "crypto/tls"
        "encoding/hex"
        "fmt"
        "io/ioutil"
@@ -37,6 +38,10 @@ const (
        passwdPath = "testdata/passwd.nsscash"
        plainPath  = "testdata/plain"
        groupPath  = "testdata/group.nsscash"
+       tlsCAPath   = "testdata/ca.crt"
+       tlsCertPath = "testdata/server.crt"
+       tlsKeyPath  = "testdata/server.key"
+       tlsCA2Path  = "testdata/ca2.crt"
 )
 
 type args struct {
@@ -104,7 +109,8 @@ statepath = "%[1]s"
 type = "passwd"
 url = "%[2]s/passwd"
 path = "%[3]s"
-`, statePath, url, passwdPath))
+ca = "%[4]s"
+`, statePath, url, passwdPath, tlsCAPath))
 }
 
 func mustWriteGroupConfig(t *testing.T, url string) {
@@ -115,7 +121,8 @@ statepath = "%[1]s"
 type = "group"
 url = "%[2]s/group"
 path = "%[3]s"
-`, statePath, url, groupPath))
+ca = "%[4]s"
+`, statePath, url, groupPath, tlsCAPath))
 }
 
 // mustCreate creates a file, truncating it if it exists. It then changes the
@@ -207,12 +214,30 @@ func TestMainFetch(t *testing.T) {
                fetchSecondFetchFails,
        }
 
+       // HTTP tests
+
+       for _, f := range tests {
+               runMainTest(t, f, nil)
+       }
+
+       // HTTPS tests
+
+       tests = append(tests, fetchInvalidCA)
+
+       cert, err := tls.LoadX509KeyPair(tlsCertPath, tlsKeyPath)
+       if err != nil {
+               t.Fatal(err)
+       }
+       tls := &tls.Config{
+               Certificates: []tls.Certificate{cert},
+       }
+
        for _, f := range tests {
-               runMainTest(t, f)
+               runMainTest(t, f, tls)
        }
 }
 
-func runMainTest(t *testing.T, f func(args)) {
+func runMainTest(t *testing.T, f func(args), tls *tls.Config) {
        cleanup := []string{
                configPath,
                statePath,
@@ -226,6 +251,9 @@ func runMainTest(t *testing.T, f func(args)) {
        fn := runtime.FuncForPC(reflect.ValueOf(f).Pointer())
        name := fn.Name()
        name = name[strings.LastIndex(name, ".")+1:]
+       if tls != nil {
+               name = "tls" + name
+       }
 
        t.Run(name, func(t *testing.T) {
                // Preparation & cleanup
@@ -240,10 +268,16 @@ func runMainTest(t *testing.T, f func(args)) {
                }
 
                var handler func(http.ResponseWriter, *http.Request)
-               ts := httptest.NewServer(http.HandlerFunc(
+               ts := httptest.NewUnstartedServer(http.HandlerFunc(
                        func(w http.ResponseWriter, r *http.Request) {
                                handler(w, r)
                        }))
+               if tls == nil {
+                       ts.Start()
+               } else {
+                       ts.TLS = tls
+                       ts.StartTLS()
+               }
                defer ts.Close()
 
                f(args{
@@ -464,7 +498,8 @@ statepath = "%[1]s"
 type = "plain"
 url = "%[2]s/plain"
 path = "%[3]s"
-`, statePath, a.url, plainPath))
+ca = "%[4]s"
+`, statePath, a.url, plainPath, tlsCAPath))
        mustCreate(t, plainPath)
 
        *a.handler = func(w http.ResponseWriter, r *http.Request) {
@@ -488,7 +523,8 @@ statepath = "%[1]s"
 type = "plain"
 url = "%[2]s/plain"
 path = "%[3]s"
-`, statePath, a.url, plainPath))
+ca = "%[4]s"
+`, statePath, a.url, plainPath, tlsCAPath))
        mustCreate(t, plainPath)
        mustHaveHash(t, plainPath, "da39a3ee5e6b4b0d3255bfef95601890afd80709")
 
@@ -707,12 +743,14 @@ statepath = "%[1]s"
 type = "passwd"
 url = "%[2]s/passwd"
 path = "%[3]s"
+ca = "%[5]s"
 
 [[file]]
 type = "group"
 url = "%[2]s/group"
 path = "%[4]s"
-`, statePath, a.url, passwdPath, groupPath))
+ca = "%[5]s"
+`, statePath, a.url, passwdPath, groupPath, tlsCAPath))
        mustCreate(t, passwdPath)
        mustCreate(t, groupPath)
        mustHaveHash(t, passwdPath, "da39a3ee5e6b4b0d3255bfef95601890afd80709")
@@ -736,3 +774,60 @@ path = "%[4]s"
        // because the second fetch failed
        mustBeOld(t, passwdPath, groupPath)
 }
+
+func fetchInvalidCA(a args) {
+       t := a.t
+
+       // System CA
+
+       mustWriteConfig(t, fmt.Sprintf(`
+statepath = "%[1]s"
+
+[[file]]
+type = "passwd"
+url = "%[2]s/passwd"
+path = "%[3]s"
+`, statePath, a.url, passwdPath))
+       mustCreate(t, passwdPath)
+       mustHaveHash(t, passwdPath, "da39a3ee5e6b4b0d3255bfef95601890afd80709")
+
+       *a.handler = func(w http.ResponseWriter, r *http.Request) {
+               if r.URL.Path == "/passwd" {
+                       fmt.Fprintln(w, "root:x:0:0:root:/root:/bin/bash")
+               }
+       }
+
+       err := mainFetch(configPath)
+       mustBeErrorWithSubstring(t, err,
+               "x509: certificate signed by unknown authority")
+
+       mustNotExist(t, statePath, plainPath, groupPath)
+       mustBeOld(t, passwdPath)
+
+       // Invalid CA
+
+       mustWriteConfig(t, fmt.Sprintf(`
+statepath = "%[1]s"
+
+[[file]]
+type = "passwd"
+url = "%[2]s/passwd"
+path = "%[3]s"
+ca = "%[4]s"
+`, statePath, a.url, passwdPath, tlsCA2Path))
+       mustCreate(t, passwdPath)
+       mustHaveHash(t, passwdPath, "da39a3ee5e6b4b0d3255bfef95601890afd80709")
+
+       *a.handler = func(w http.ResponseWriter, r *http.Request) {
+               if r.URL.Path == "/passwd" {
+                       fmt.Fprintln(w, "root:x:0:0:root:/root:/bin/bash")
+               }
+       }
+
+       err = mainFetch(configPath)
+       mustBeErrorWithSubstring(t, err,
+               "x509: certificate signed by unknown authority")
+
+       mustNotExist(t, statePath, plainPath, groupPath)
+       mustBeOld(t, passwdPath)
+}