]> ruderich.org/simon Gitweb - safcm/safcm.git/blobdiff - cmd/safcm-remote/sync/files.go
sync: refactor temporary file creation into WriteTemp()
[safcm/safcm.git] / cmd / safcm-remote / sync / files.go
index 06bc4066643b1394b286363e2bd7665a136ed722..6e001a8186d14675864b4d90794d77694269cee0 100644 (file)
@@ -110,10 +110,7 @@ func (s *Sync) syncFile(file *safcm.File, changed *bool) error {
 
        var oldStat fs.FileInfo
 reopen:
-       oldFh, err := os.OpenFile(file.Path,
-               // O_NOFOLLOW prevents symlink attacks
-               // O_NONBLOCK is necessary to prevent blocking on FIFOs
-               os.O_RDONLY|syscall.O_NOFOLLOW|syscall.O_NONBLOCK, 0)
+       oldFh, err := OpenFileNoFollow(file.Path)
        if err != nil {
                err := err.(*fs.PathError)
                if err.Err == syscall.ELOOP {
@@ -273,8 +270,7 @@ reopen:
                // a symlink at this point. There's no lchmod so open the
                // directory.
                debugf("chmodding %s", file.Mode)
-               dh, err := os.OpenFile(file.Path,
-                       os.O_RDONLY|syscall.O_NOFOLLOW|syscall.O_NONBLOCK, 0)
+               dh, err := OpenFileNoFollow(file.Path)
                if err != nil {
                        return err
                }
@@ -332,52 +328,20 @@ reopen:
        }
 
        dir := filepath.Dir(file.Path)
-       base := filepath.Base(file.Path)
+       // Create hidden file which should be ignored by most other tools and
+       // thus not affect anything during creation
+       base := "." + filepath.Base(file.Path)
 
        var tmpPath string
        switch file.Mode.Type() {
        case 0: // regular file
                debugf("creating temporary file %q",
-                       filepath.Join(dir, "."+base+"*"))
-               // Create hidden file which should be ignored by most other
-               // tools and thus not affect anything during creation
-               newFh, err := os.CreateTemp(dir, "."+base)
+                       filepath.Join(dir, base+"*"))
+               tmpPath, err = WriteTemp(dir, base, file.Data,
+                       file.Uid, file.Gid, file.Mode)
                if err != nil {
                        return err
                }
-               tmpPath = newFh.Name()
-
-               _, err = newFh.Write(file.Data)
-               if err != nil {
-                       newFh.Close()
-                       os.Remove(tmpPath)
-                       return err
-               }
-               // CreateTemp() creates the file with 0600
-               err = newFh.Chown(file.Uid, file.Gid)
-               if err != nil {
-                       newFh.Close()
-                       os.Remove(tmpPath)
-                       return err
-               }
-               err = newFh.Chmod(file.Mode)
-               if err != nil {
-                       newFh.Close()
-                       os.Remove(tmpPath)
-                       return err
-               }
-               err = newFh.Sync()
-               if err != nil {
-                       newFh.Close()
-                       os.Remove(tmpPath)
-                       return err
-               }
-               err = newFh.Close()
-               if err != nil {
-                       newFh.Close()
-                       os.Remove(tmpPath)
-                       return err
-               }
 
        case fs.ModeSymlink:
                i := 0
@@ -385,7 +349,7 @@ reopen:
                // Similar to os.CreateTemp() but for symlinks which we cannot
                // open as file
                tmpPath = filepath.Join(dir,
-                       "."+base+strconv.Itoa(rand.Int()))
+                       base+strconv.Itoa(rand.Int()))
                debugf("creating temporary symlink %q", tmpPath)
                err := os.Symlink(string(file.Data), tmpPath)
                if err != nil {
@@ -502,6 +466,57 @@ func diffData(oldData []byte, newData []byte) (string, error) {
        return result, nil
 }
 
+func OpenFileNoFollow(path string) (*os.File, error) {
+       return os.OpenFile(path,
+               // O_NOFOLLOW prevents symlink attacks
+               // O_NONBLOCK is necessary to prevent blocking on FIFOs
+               os.O_RDONLY|syscall.O_NOFOLLOW|syscall.O_NONBLOCK, 0)
+}
+
+func WriteTemp(dir, base string, data []byte, uid, gid int, mode fs.FileMode) (
+       string, error) {
+
+       fh, err := os.CreateTemp(dir, base)
+       if err != nil {
+               return "", err
+       }
+       tmpPath := fh.Name()
+
+       _, err = fh.Write(data)
+       if err != nil {
+               fh.Close()
+               os.Remove(tmpPath)
+               return "", err
+       }
+       // CreateTemp() creates the file with 0600
+       err = fh.Chown(uid, gid)
+       if err != nil {
+               fh.Close()
+               os.Remove(tmpPath)
+               return "", err
+       }
+       err = fh.Chmod(mode)
+       if err != nil {
+               fh.Close()
+               os.Remove(tmpPath)
+               return "", err
+       }
+       err = fh.Sync()
+       if err != nil {
+               fh.Close()
+               os.Remove(tmpPath)
+               return "", err
+       }
+       err = fh.Close()
+       if err != nil {
+               fh.Close()
+               os.Remove(tmpPath)
+               return "", err
+       }
+
+       return tmpPath, nil
+}
+
 // syncPath syncs path, which should be a directory. To guarantee durability
 // it must be called on a parent directory after adding, renaming or removing
 // files therein.