]> ruderich.org/simon Gitweb - nsscash/nsscash.git/blobdiff - main.go
.github: update upstream actions to latest version
[nsscash/nsscash.git] / main.go
diff --git a/main.go b/main.go
index 6322157f2ca9564d6523606095ee0967fd2d8daa..7a7d8b0da8e6855b8545c609f6169184819db7ae 100644 (file)
--- a/main.go
+++ b/main.go
@@ -1,6 +1,6 @@
 // Main file for nsscash
 
-// Copyright (C) 2019  Simon Ruderich
+// Copyright (C) 2019-2021  Simon Ruderich
 //
 // This program is free software: you can redistribute it and/or modify
 // it under the terms of the GNU Affero General Public License as published by
@@ -24,6 +24,9 @@ import (
        "io/ioutil"
        "log"
        "os"
+       "path/filepath"
+
+       "github.com/google/renameio"
 )
 
 func main() {
@@ -49,7 +52,10 @@ func main() {
                        break
                }
 
-               mainFetch(args[1])
+               err := mainFetch(args[1])
+               if err != nil {
+                       log.Fatal(err)
+               }
                return
 
        case "convert":
@@ -57,7 +63,10 @@ func main() {
                        break
                }
 
-               mainConvert(args[1], args[2], args[3])
+               err := mainConvert(args[1], args[2], args[3])
+               if err != nil {
+                       log.Fatal(err)
+               }
                return
        }
 
@@ -65,77 +74,86 @@ func main() {
        os.Exit(1)
 }
 
-func mainFetch(cfgPath string) {
-               cfg, err := LoadConfig(cfgPath)
+func mainFetch(cfgPath string) error {
+       cfg, err := LoadConfig(cfgPath)
+       if err != nil {
+               return err
+       }
+       state, err := LoadState(cfg.StatePath)
+       if err != nil {
+               return err
+       }
+       err = handleFiles(cfg, state)
+       if err != nil {
+               return err
+       }
+       // NOTE: Make sure to call WriteState() only if there were no
+       // errors (see WriteState() and README)
+       err = WriteState(cfg.StatePath, state)
+       if err != nil {
+               return err
+       }
+       return nil
+}
+
+func mainConvert(typ, srcPath, dstPath string) error {
+       var t FileType
+       err := t.UnmarshalText([]byte(typ))
+       if err != nil {
+               return err
+       }
+
+       src, err := ioutil.ReadFile(srcPath)
+       if err != nil {
+               return err
+       }
+       var x bytes.Buffer
+       if t == FileTypePlain {
+               x.Write(src)
+       } else if t == FileTypePasswd {
+               pws, err := ParsePasswds(bytes.NewReader(src))
                if err != nil {
-                       log.Fatal(err)
+                       return err
                }
-               state, err := LoadState(cfg.StatePath)
+               err = SerializePasswds(&x, pws)
                if err != nil {
-                       log.Fatal(err)
+                       return err
                }
-               err = handleFiles(cfg, state)
+       } else if t == FileTypeGroup {
+               grs, err := ParseGroups(bytes.NewReader(src))
                if err != nil {
-                       log.Fatal(err)
+                       return err
                }
-               // NOTE: Make sure to call WriteState() only if there were no
-               // errors (see WriteState() and README)
-               err = WriteState(cfg.StatePath, state)
+               err = SerializeGroups(&x, grs)
                if err != nil {
-                       log.Fatal(err)
+                       return err
                }
-}
+       } else {
+               return fmt.Errorf("unsupported file type %v", t)
+       }
 
-func mainConvert(typ, srcPath, dstPath string) {
-               var t FileType
-               err := t.UnmarshalText([]byte(typ))
-               if err != nil {
-                       log.Fatal(err)
-               }
+       // We must create the file first or deployFile() will abort; this is
+       // ugly because deployFile() already performs an atomic replacement
+       // but the simplest solution with the least duplicate code
+       f, err := renameio.TempFile(filepath.Dir(dstPath), dstPath)
+       if err != nil {
+               return err
+       }
+       defer f.Cleanup()
 
-               src, err := ioutil.ReadFile(srcPath)
-               if err != nil {
-                       log.Fatal(err)
-               }
-               var x bytes.Buffer
-               if t == FileTypePlain {
-                       x.Write(src)
-               } else if t == FileTypePasswd {
-                       pws, err := ParsePasswds(bytes.NewReader(src))
-                       if err != nil {
-                               log.Fatal(err)
-                       }
-                       err = SerializePasswds(&x, pws)
-                       if err != nil {
-                               log.Fatal(err)
-                       }
-               } else if t == FileTypeGroup {
-                       grs, err := ParseGroups(bytes.NewReader(src))
-                       if err != nil {
-                               log.Fatal(err)
-                       }
-                       err = SerializeGroups(&x, grs)
-                       if err != nil {
-                               log.Fatal(err)
-                       }
-               } else {
-                       log.Fatalf("unsupported file type %v", t)
-               }
+       err = deployFile(&File{
+               Type: t,
+               Url:  srcPath,
+               Path: f.Name(),
+               body: x.Bytes(),
+       })
+       if err != nil {
+               return err
+       }
 
-               // We must create the file first or deployFile() will abort
-               f, err := os.Create(dstPath)
-               if err != nil {
-                       log.Fatal(err)
-               }
-               f.Close()
-
-               err = deployFile(&File{
-                       Type: t,
-                       Url:  srcPath,
-                       Path: dstPath,
-                       body: x.Bytes(),
-               })
-               if err != nil {
-                       log.Fatal(err)
-               }
+       err = f.CloseAtomicallyReplace()
+       if err != nil {
+               return err
+       }
+       return syncPath(filepath.Dir(dstPath))
 }