]> ruderich.org/simon Gitweb - safcm/safcm.git/blobdiff - remote/ainsl/ainsl.go
Use SPDX license identifiers
[safcm/safcm.git] / remote / ainsl / ainsl.go
index d716b2b4d5ce0868b4006a19ed47acb3f7da5ece..6437a5a08a5745237731325e084af4dcfa3be5f9 100644 (file)
@@ -4,20 +4,8 @@
 //
 // FAI: https://fai-project.org
 
-// Copyright (C) 2021  Simon Ruderich
-//
-// This program is free software: you can redistribute it and/or modify
-// it under the terms of the GNU General Public License as published by
-// the Free Software Foundation, either version 3 of the License, or
-// (at your option) any later version.
-//
-// This program is distributed in the hope that it will be useful,
-// but WITHOUT ANY WARRANTY; without even the implied warranty of
-// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
-// GNU General Public License for more details.
-//
-// You should have received a copy of the GNU General Public License
-// along with this program.  If not, see <http://www.gnu.org/licenses/>.
+// SPDX-License-Identifier: GPL-3.0-or-later
+// Copyright (C) 2021-2024  Simon Ruderich
 
 package ainsl
 
@@ -28,10 +16,11 @@ import (
        "io"
        "io/fs"
        "os"
-       "path/filepath"
        "strings"
        "syscall"
 
+       "golang.org/x/sys/unix"
+
        "ruderich.org/simon/safcm/remote/sync"
 )
 
@@ -46,7 +35,7 @@ func Main(args []string) error {
        optionCreate := flag.Bool("create", false,
                "create the path if it does not exist")
 
-       flag.CommandLine.Parse(args[2:])
+       flag.CommandLine.Parse(args[2:]) //nolint:errcheck
 
        if flag.NArg() != 2 {
                flag.Usage()
@@ -74,14 +63,20 @@ func handle(path string, line string, create bool) ([]string, error) {
                        line)
        }
 
+       parentFd, baseName, err := sync.OpenParentDirectoryNoSymlinks(path)
+       if err != nil {
+               return nil, err
+       }
+       defer unix.Close(parentFd)
+
        var changes []string
 
        var uid, gid int
        var mode fs.FileMode
-       data, stat, err := readFileNoFollow(path)
+       data, stat, err := readFileAtNoFollow(parentFd, baseName)
        if err != nil {
                if !os.IsNotExist(err) {
-                       return nil, err
+                       return nil, fmt.Errorf("%q: %v", path, err)
                }
                if !create {
                        return nil, fmt.Errorf(
@@ -112,7 +107,7 @@ func handle(path string, line string, create bool) ([]string, error) {
                gid = int(x.Gid)
                mode = stat.Mode()
        }
-       stat = nil // prevent accidental use
+       stat = nil //nolint:wastedassign // prevent accidental use
 
        // Check if the expected line is present
        var found bool
@@ -147,18 +142,17 @@ func handle(path string, line string, create bool) ([]string, error) {
        }
 
        // Write via temporary file and rename
-       dir := filepath.Dir(path)
-       base := filepath.Base(path)
-       tmpPath, err := sync.WriteTemp(dir, "."+base, data, uid, gid, mode)
+       tmpBase, err := sync.WriteTempAt(parentFd, "."+baseName,
+               data, uid, gid, mode)
        if err != nil {
                return nil, err
        }
-       err = os.Rename(tmpPath, path)
+       err = unix.Renameat(parentFd, tmpBase, parentFd, baseName)
        if err != nil {
-               os.Remove(tmpPath)
+               unix.Unlinkat(parentFd, tmpBase, 0 /* flags */) //nolint:errcheck
                return nil, err
        }
-       err = sync.SyncPath(dir)
+       err = unix.Fsync(parentFd)
        if err != nil {
                return nil, err
        }
@@ -166,8 +160,8 @@ func handle(path string, line string, create bool) ([]string, error) {
        return changes, nil
 }
 
-func readFileNoFollow(path string) ([]byte, fs.FileInfo, error) {
-       fh, err := sync.OpenFileNoFollow(path)
+func readFileAtNoFollow(dirfd int, base string) ([]byte, fs.FileInfo, error) {
+       fh, err := sync.OpenAtNoFollow(dirfd, base)
        if err != nil {
                return nil, nil, err
        }
@@ -178,13 +172,13 @@ func readFileNoFollow(path string) ([]byte, fs.FileInfo, error) {
                return nil, nil, err
        }
        if stat.Mode().Type() != 0 /* regular file */ {
-               return nil, nil, fmt.Errorf("%q is not a regular file but %s",
-                       path, stat.Mode().Type())
+               return nil, nil, fmt.Errorf("not a regular file but %s",
+                       stat.Mode().Type())
        }
 
        x, err := io.ReadAll(fh)
        if err != nil {
-               return nil, nil, fmt.Errorf("%q: %v", path, err)
+               return nil, nil, err
        }
 
        return x, stat, nil