X-Git-Url: https://ruderich.org/simon/gitweb/?p=nsscash%2Fnsscash.git;a=blobdiff_plain;f=main.go;h=7a7d8b0da8e6855b8545c609f6169184819db7ae;hp=6322157f2ca9564d6523606095ee0967fd2d8daa;hb=HEAD;hpb=33e06a7ef39ce42d9933ca9106a453daa7eb58ac diff --git a/main.go b/main.go index 6322157..7a7d8b0 100644 --- a/main.go +++ b/main.go @@ -1,6 +1,6 @@ // Main file for nsscash -// Copyright (C) 2019 Simon Ruderich +// Copyright (C) 2019-2021 Simon Ruderich // // This program is free software: you can redistribute it and/or modify // it under the terms of the GNU Affero General Public License as published by @@ -24,6 +24,9 @@ import ( "io/ioutil" "log" "os" + "path/filepath" + + "github.com/google/renameio" ) func main() { @@ -49,7 +52,10 @@ func main() { break } - mainFetch(args[1]) + err := mainFetch(args[1]) + if err != nil { + log.Fatal(err) + } return case "convert": @@ -57,7 +63,10 @@ func main() { break } - mainConvert(args[1], args[2], args[3]) + err := mainConvert(args[1], args[2], args[3]) + if err != nil { + log.Fatal(err) + } return } @@ -65,77 +74,86 @@ func main() { os.Exit(1) } -func mainFetch(cfgPath string) { - cfg, err := LoadConfig(cfgPath) +func mainFetch(cfgPath string) error { + cfg, err := LoadConfig(cfgPath) + if err != nil { + return err + } + state, err := LoadState(cfg.StatePath) + if err != nil { + return err + } + err = handleFiles(cfg, state) + if err != nil { + return err + } + // NOTE: Make sure to call WriteState() only if there were no + // errors (see WriteState() and README) + err = WriteState(cfg.StatePath, state) + if err != nil { + return err + } + return nil +} + +func mainConvert(typ, srcPath, dstPath string) error { + var t FileType + err := t.UnmarshalText([]byte(typ)) + if err != nil { + return err + } + + src, err := ioutil.ReadFile(srcPath) + if err != nil { + return err + } + var x bytes.Buffer + if t == FileTypePlain { + x.Write(src) + } else if t == FileTypePasswd { + pws, err := ParsePasswds(bytes.NewReader(src)) if err != nil { - log.Fatal(err) + return err } - state, err := LoadState(cfg.StatePath) + err = SerializePasswds(&x, pws) if err != nil { - log.Fatal(err) + return err } - err = handleFiles(cfg, state) + } else if t == FileTypeGroup { + grs, err := ParseGroups(bytes.NewReader(src)) if err != nil { - log.Fatal(err) + return err } - // NOTE: Make sure to call WriteState() only if there were no - // errors (see WriteState() and README) - err = WriteState(cfg.StatePath, state) + err = SerializeGroups(&x, grs) if err != nil { - log.Fatal(err) + return err } -} + } else { + return fmt.Errorf("unsupported file type %v", t) + } -func mainConvert(typ, srcPath, dstPath string) { - var t FileType - err := t.UnmarshalText([]byte(typ)) - if err != nil { - log.Fatal(err) - } + // We must create the file first or deployFile() will abort; this is + // ugly because deployFile() already performs an atomic replacement + // but the simplest solution with the least duplicate code + f, err := renameio.TempFile(filepath.Dir(dstPath), dstPath) + if err != nil { + return err + } + defer f.Cleanup() - src, err := ioutil.ReadFile(srcPath) - if err != nil { - log.Fatal(err) - } - var x bytes.Buffer - if t == FileTypePlain { - x.Write(src) - } else if t == FileTypePasswd { - pws, err := ParsePasswds(bytes.NewReader(src)) - if err != nil { - log.Fatal(err) - } - err = SerializePasswds(&x, pws) - if err != nil { - log.Fatal(err) - } - } else if t == FileTypeGroup { - grs, err := ParseGroups(bytes.NewReader(src)) - if err != nil { - log.Fatal(err) - } - err = SerializeGroups(&x, grs) - if err != nil { - log.Fatal(err) - } - } else { - log.Fatalf("unsupported file type %v", t) - } + err = deployFile(&File{ + Type: t, + Url: srcPath, + Path: f.Name(), + body: x.Bytes(), + }) + if err != nil { + return err + } - // We must create the file first or deployFile() will abort - f, err := os.Create(dstPath) - if err != nil { - log.Fatal(err) - } - f.Close() - - err = deployFile(&File{ - Type: t, - Url: srcPath, - Path: dstPath, - body: x.Bytes(), - }) - if err != nil { - log.Fatal(err) - } + err = f.CloseAtomicallyReplace() + if err != nil { + return err + } + return syncPath(filepath.Dir(dstPath)) }