import (
"crypto/sha1"
+ "crypto/tls"
"encoding/hex"
"fmt"
"io/ioutil"
)
const (
- configPath = "testdata/config.toml"
- statePath = "testdata/state.json"
- passwdPath = "testdata/passwd.nsscash"
- plainPath = "testdata/plain"
- groupPath = "testdata/group.nsscash"
+ configPath = "testdata/config.toml"
+ statePath = "testdata/state.json"
+ passwdPath = "testdata/passwd.nsscash"
+ plainPath = "testdata/plain"
+ groupPath = "testdata/group.nsscash"
+ tlsCAPath = "testdata/ca.crt"
+ tlsCertPath = "testdata/server.crt"
+ tlsKeyPath = "testdata/server.key"
+ tlsCA2Path = "testdata/ca2.crt"
)
type args struct {
}
}
+func hashAsHex(x []byte) string {
+ h := sha1.New()
+ h.Write(x)
+ return hex.EncodeToString(h.Sum(nil))
+}
+
// mustHaveHash checks if the given path content has the given SHA-1 string
// (in hex).
func mustHaveHash(t *testing.T, path string, hash string) {
t.Fatal(err)
}
- h := sha1.New()
- h.Write(x)
- y := hex.EncodeToString(h.Sum(nil))
-
+ y := hashAsHex(x)
if y != hash {
t.Errorf("%q has unexpected hash %q", path, y)
}
type = "passwd"
url = "%[2]s/passwd"
path = "%[3]s"
-`, statePath, url, passwdPath))
+ca = "%[4]s"
+`, statePath, url, passwdPath, tlsCAPath))
}
func mustWriteGroupConfig(t *testing.T, url string) {
type = "group"
url = "%[2]s/group"
path = "%[3]s"
-`, statePath, url, groupPath))
+ca = "%[4]s"
+`, statePath, url, groupPath, tlsCAPath))
}
// mustCreate creates a file, truncating it if it exists. It then changes the
fetchStateCannotWrite,
fetchCannotDeploy,
fetchSecondFetchFails,
+ fetchBasicAuth,
+ }
+
+ // HTTP tests
+
+ for _, f := range tests {
+ runMainTest(t, f, nil)
+ }
+
+ // HTTPS tests
+
+ tests = append(tests, fetchInvalidCA)
+
+ cert, err := tls.LoadX509KeyPair(tlsCertPath, tlsKeyPath)
+ if err != nil {
+ t.Fatal(err)
+ }
+ tls := &tls.Config{
+ Certificates: []tls.Certificate{cert},
+ }
+
+ for _, f := range tests {
+ runMainTest(t, f, tls)
}
+}
+func runMainTest(t *testing.T, f func(args), tls *tls.Config) {
cleanup := []string{
configPath,
statePath,
groupPath,
}
- for _, f := range tests {
- // NOTE: This is not guaranteed to work according to reflect's
- // documentation but seems to work reliable for normal
- // functions.
- fn := runtime.FuncForPC(reflect.ValueOf(f).Pointer())
- name := fn.Name()
- name = name[strings.LastIndex(name, ".")+1:]
-
- t.Run(name, func(t *testing.T) {
- // Preparation & cleanup
- for _, p := range cleanup {
- err := os.Remove(p)
- if err != nil && !os.IsNotExist(err) {
- t.Fatal(err)
- }
- // Remove the file at the end of this test
- // run, if it was created
- defer os.Remove(p)
+ // NOTE: This is not guaranteed to work according to reflect's
+ // documentation but seems to work reliable for normal functions.
+ fn := runtime.FuncForPC(reflect.ValueOf(f).Pointer())
+ name := fn.Name()
+ name = name[strings.LastIndex(name, ".")+1:]
+ if tls != nil {
+ name = "tls" + name
+ }
+
+ t.Run(name, func(t *testing.T) {
+ // Preparation & cleanup
+ for _, p := range cleanup {
+ err := os.Remove(p)
+ if err != nil && !os.IsNotExist(err) {
+ t.Fatal(err)
}
+ // Remove the file at the end of this test run, if it
+ // was created
+ defer os.Remove(p)
+ }
+
+ var handler func(http.ResponseWriter, *http.Request)
+ ts := httptest.NewUnstartedServer(http.HandlerFunc(
+ func(w http.ResponseWriter, r *http.Request) {
+ handler(w, r)
+ }))
+ if tls == nil {
+ ts.Start()
+ } else {
+ ts.TLS = tls
+ ts.StartTLS()
+ }
+ defer ts.Close()
- var handler func(http.ResponseWriter, *http.Request)
- ts := httptest.NewServer(http.HandlerFunc(
- func(w http.ResponseWriter, r *http.Request) {
- handler(w, r)
- }))
- defer ts.Close()
-
- f(args{
- t: t,
- url: ts.URL,
- handler: &handler,
- })
+ f(args{
+ t: t,
+ url: ts.URL,
+ handler: &handler,
})
- }
+ })
}
func fetchPasswdCacheFileDoesNotExist(a args) {
mustMakeOld(t, passwdPath, statePath)
lastChange := time.Now()
+ change := false
*a.handler = func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path != "/passwd" {
return
t.Fatalf("invalid If-Modified-Since %v",
modified)
}
- if !x.Before(lastChange) {
+ if !x.Before(lastChange.Truncate(time.Second)) {
w.WriteHeader(http.StatusNotModified)
return
}
}
w.Header().Add("Last-Modified",
- lastChange.Format(http.TimeFormat))
+ lastChange.UTC().Format(http.TimeFormat))
fmt.Fprintln(w, "root:x:0:0:root:/root:/bin/bash")
fmt.Fprintln(w, "daemon:x:1:1:daemon:/usr/sbin:/usr/sbin/nologin")
+ if change {
+ fmt.Fprintln(w, "bin:x:2:2:bin:/bin:/usr/sbin/nologin")
+ }
}
err = mainFetch(configPath)
mustNotExist(t, plainPath, groupPath)
mustBeNew(t, passwdPath, statePath)
mustHaveHash(t, passwdPath, "bbb7db67469b111200400e2470346d5515d64c23")
+
+ t.Log("Fetch again with newer server response")
+
+ change = true
+ lastChange = time.Now().Add(time.Second)
+
+ mustMakeOld(t, passwdPath, statePath)
+
+ err = mainFetch(configPath)
+ if err != nil {
+ t.Error(err)
+ }
+
+ mustNotExist(t, plainPath, groupPath)
+ mustBeNew(t, passwdPath, statePath)
+ mustHaveHash(t, passwdPath, "ca9c7477cb425667fc9ecbd79e8e1c2ad0e84423")
}
func fetchPlainEmpty(a args) {
type = "plain"
url = "%[2]s/plain"
path = "%[3]s"
-`, statePath, a.url, plainPath))
+ca = "%[4]s"
+`, statePath, a.url, plainPath, tlsCAPath))
mustCreate(t, plainPath)
*a.handler = func(w http.ResponseWriter, r *http.Request) {
type = "plain"
url = "%[2]s/plain"
path = "%[3]s"
-`, statePath, a.url, plainPath))
+ca = "%[4]s"
+`, statePath, a.url, plainPath, tlsCAPath))
mustCreate(t, plainPath)
mustHaveHash(t, plainPath, "da39a3ee5e6b4b0d3255bfef95601890afd80709")
type = "passwd"
url = "%[2]s/passwd"
path = "%[3]s"
+ca = "%[5]s"
[[file]]
type = "group"
url = "%[2]s/group"
path = "%[4]s"
-`, statePath, a.url, passwdPath, groupPath))
+ca = "%[5]s"
+`, statePath, a.url, passwdPath, groupPath, tlsCAPath))
mustCreate(t, passwdPath)
mustCreate(t, groupPath)
mustHaveHash(t, passwdPath, "da39a3ee5e6b4b0d3255bfef95601890afd80709")
// because the second fetch failed
mustBeOld(t, passwdPath, groupPath)
}
+
+func fetchBasicAuth(a args) {
+ t := a.t
+ mustWritePasswdConfig(t, a.url)
+ mustCreate(t, passwdPath)
+ mustHaveHash(t, passwdPath, "da39a3ee5e6b4b0d3255bfef95601890afd80709")
+
+ validUser := "username"
+ validPass := "password"
+
+ *a.handler = func(w http.ResponseWriter, r *http.Request) {
+ if r.URL.Path != "/passwd" {
+ return
+ }
+
+ user, pass, ok := r.BasicAuth()
+ // NOTE: Do not use this in production because it permits
+ // attackers to determine the length of user/pass. Instead use
+ // hashes and subtle.ConstantTimeCompare().
+ if !ok || user != validUser || pass != validPass {
+ w.Header().Set("WWW-Authenticate", `Basic realm="Test"`)
+ w.WriteHeader(http.StatusUnauthorized)
+ return
+ }
+
+ fmt.Fprintln(w, "root:x:0:0:root:/root:/bin/bash")
+ fmt.Fprintln(w, "daemon:x:1:1:daemon:/usr/sbin:/usr/sbin/nologin")
+ }
+
+ t.Log("Missing authentication")
+
+ err := mainFetch(configPath)
+ mustBeErrorWithSubstring(t, err,
+ "status code 401")
+
+ mustNotExist(t, statePath, groupPath, plainPath)
+ mustBeOld(t, passwdPath)
+
+ t.Log("Unsafe config permissions")
+
+ mustWriteConfig(t, fmt.Sprintf(`
+statepath = "%[1]s"
+
+[[file]]
+type = "passwd"
+url = "%[2]s/passwd"
+path = "%[3]s"
+ca = "%[4]s"
+username = "%[5]s"
+password = "%[6]s"
+`, statePath, a.url, passwdPath, tlsCAPath, validUser, validPass))
+
+ err = os.Chmod(configPath, 0644)
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ err = mainFetch(configPath)
+ mustBeErrorWithSubstring(t, err,
+ "file[0].username/passsword in use and unsafe permissions "+
+ "-rw-r--r-- on \"testdata/config.toml\"")
+
+ mustNotExist(t, statePath, groupPath, plainPath)
+ mustBeOld(t, passwdPath)
+
+ t.Log("Working authentication")
+
+ err = os.Chmod(configPath, 0600)
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ err = mainFetch(configPath)
+ if err != nil {
+ t.Error(err)
+ }
+
+ mustNotExist(t, plainPath, groupPath)
+ mustBeNew(t, passwdPath, statePath)
+ mustHaveHash(t, passwdPath, "bbb7db67469b111200400e2470346d5515d64c23")
+}
+
+func fetchInvalidCA(a args) {
+ t := a.t
+
+ // System CA
+
+ mustWriteConfig(t, fmt.Sprintf(`
+statepath = "%[1]s"
+
+[[file]]
+type = "passwd"
+url = "%[2]s/passwd"
+path = "%[3]s"
+`, statePath, a.url, passwdPath))
+ mustCreate(t, passwdPath)
+ mustHaveHash(t, passwdPath, "da39a3ee5e6b4b0d3255bfef95601890afd80709")
+
+ *a.handler = func(w http.ResponseWriter, r *http.Request) {
+ if r.URL.Path == "/passwd" {
+ fmt.Fprintln(w, "root:x:0:0:root:/root:/bin/bash")
+ }
+ }
+
+ err := mainFetch(configPath)
+ mustBeErrorWithSubstring(t, err,
+ "x509: certificate signed by unknown authority")
+
+ mustNotExist(t, statePath, plainPath, groupPath)
+ mustBeOld(t, passwdPath)
+
+ // Invalid CA
+
+ mustWriteConfig(t, fmt.Sprintf(`
+statepath = "%[1]s"
+
+[[file]]
+type = "passwd"
+url = "%[2]s/passwd"
+path = "%[3]s"
+ca = "%[4]s"
+`, statePath, a.url, passwdPath, tlsCA2Path))
+ mustCreate(t, passwdPath)
+ mustHaveHash(t, passwdPath, "da39a3ee5e6b4b0d3255bfef95601890afd80709")
+
+ *a.handler = func(w http.ResponseWriter, r *http.Request) {
+ if r.URL.Path == "/passwd" {
+ fmt.Fprintln(w, "root:x:0:0:root:/root:/bin/bash")
+ }
+ }
+
+ err = mainFetch(configPath)
+ mustBeErrorWithSubstring(t, err,
+ "x509: certificate signed by unknown authority")
+
+ mustNotExist(t, statePath, plainPath, groupPath)
+ mustBeOld(t, passwdPath)
+}