]> ruderich.org/simon Gitweb - safcm/safcm.git/blobdiff - remote/ainsl/ainsl.go
remote: guard against symlinks in earlier path components
[safcm/safcm.git] / remote / ainsl / ainsl.go
index d716b2b4d5ce0868b4006a19ed47acb3f7da5ece..2a64d63439871954379b26630dabc763c1a73f44 100644 (file)
@@ -28,10 +28,11 @@ import (
        "io"
        "io/fs"
        "os"
-       "path/filepath"
        "strings"
        "syscall"
 
+       "golang.org/x/sys/unix"
+
        "ruderich.org/simon/safcm/remote/sync"
 )
 
@@ -74,14 +75,20 @@ func handle(path string, line string, create bool) ([]string, error) {
                        line)
        }
 
+       parentFd, baseName, err := sync.OpenParentDirectoryNoSymlinks(path)
+       if err != nil {
+               return nil, err
+       }
+       defer unix.Close(parentFd)
+
        var changes []string
 
        var uid, gid int
        var mode fs.FileMode
-       data, stat, err := readFileNoFollow(path)
+       data, stat, err := readFileAtNoFollow(parentFd, baseName)
        if err != nil {
                if !os.IsNotExist(err) {
-                       return nil, err
+                       return nil, fmt.Errorf("%q: %v", path, err)
                }
                if !create {
                        return nil, fmt.Errorf(
@@ -147,18 +154,17 @@ func handle(path string, line string, create bool) ([]string, error) {
        }
 
        // Write via temporary file and rename
-       dir := filepath.Dir(path)
-       base := filepath.Base(path)
-       tmpPath, err := sync.WriteTemp(dir, "."+base, data, uid, gid, mode)
+       tmpBase, err := sync.WriteTempAt(parentFd, "."+baseName,
+               data, uid, gid, mode)
        if err != nil {
                return nil, err
        }
-       err = os.Rename(tmpPath, path)
+       err = unix.Renameat(parentFd, tmpBase, parentFd, baseName)
        if err != nil {
-               os.Remove(tmpPath)
+               unix.Unlinkat(parentFd, tmpBase, 0)
                return nil, err
        }
-       err = sync.SyncPath(dir)
+       err = unix.Fsync(parentFd)
        if err != nil {
                return nil, err
        }
@@ -166,8 +172,8 @@ func handle(path string, line string, create bool) ([]string, error) {
        return changes, nil
 }
 
-func readFileNoFollow(path string) ([]byte, fs.FileInfo, error) {
-       fh, err := sync.OpenFileNoFollow(path)
+func readFileAtNoFollow(dirfd int, base string) ([]byte, fs.FileInfo, error) {
+       fh, err := sync.OpenAtNoFollow(dirfd, base)
        if err != nil {
                return nil, nil, err
        }
@@ -178,13 +184,13 @@ func readFileNoFollow(path string) ([]byte, fs.FileInfo, error) {
                return nil, nil, err
        }
        if stat.Mode().Type() != 0 /* regular file */ {
-               return nil, nil, fmt.Errorf("%q is not a regular file but %s",
-                       path, stat.Mode().Type())
+               return nil, nil, fmt.Errorf("not a regular file but %s",
+                       stat.Mode().Type())
        }
 
        x, err := io.ReadAll(fh)
        if err != nil {
-               return nil, nil, fmt.Errorf("%q: %v", path, err)
+               return nil, nil, err
        }
 
        return x, stat, nil