]> ruderich.org/simon Gitweb - safcm/safcm.git/blob - remote/ainsl/ainsl.go
Use SPDX license identifiers
[safcm/safcm.git] / remote / ainsl / ainsl.go
1 // "ainsl" sub-command: "append if no such line" (inspired by FAI's ainsl),
2 // append lines to files (if not present) without replacing the file
3 // completely
4 //
5 // FAI: https://fai-project.org
6
7 // SPDX-License-Identifier: GPL-3.0-or-later
8 // Copyright (C) 2021-2024  Simon Ruderich
9
10 package ainsl
11
12 import (
13         "bytes"
14         "flag"
15         "fmt"
16         "io"
17         "io/fs"
18         "os"
19         "strings"
20         "syscall"
21
22         "golang.org/x/sys/unix"
23
24         "ruderich.org/simon/safcm/remote/sync"
25 )
26
27 func Main(args []string) error {
28         flag.Usage = func() {
29                 fmt.Fprintf(os.Stderr,
30                         "usage: %s ainsl [<options>] <path> <line>\n",
31                         args[0])
32                 flag.PrintDefaults()
33         }
34
35         optionCreate := flag.Bool("create", false,
36                 "create the path if it does not exist")
37
38         flag.CommandLine.Parse(args[2:]) //nolint:errcheck
39
40         if flag.NArg() != 2 {
41                 flag.Usage()
42                 os.Exit(1)
43         }
44
45         changes, err := handle(flag.Args()[0], flag.Args()[1], *optionCreate)
46         if err != nil {
47                 return fmt.Errorf("ainsl: %v", err)
48         }
49         for _, x := range changes {
50                 fmt.Fprintln(os.Stderr, x)
51         }
52         return nil
53 }
54
55 func handle(path string, line string, create bool) ([]string, error) {
56         // See safcm-remote/sync/files.go for details on the implementation
57
58         if line == "" {
59                 return nil, fmt.Errorf("empty line")
60         }
61         if strings.Contains(line, "\n") {
62                 return nil, fmt.Errorf("line must not contain newlines: %q",
63                         line)
64         }
65
66         parentFd, baseName, err := sync.OpenParentDirectoryNoSymlinks(path)
67         if err != nil {
68                 return nil, err
69         }
70         defer unix.Close(parentFd)
71
72         var changes []string
73
74         var uid, gid int
75         var mode fs.FileMode
76         data, stat, err := readFileAtNoFollow(parentFd, baseName)
77         if err != nil {
78                 if !os.IsNotExist(err) {
79                         return nil, fmt.Errorf("%q: %v", path, err)
80                 }
81                 if !create {
82                         return nil, fmt.Errorf(
83                                 "%q: file does not exist, use -create",
84                                 path)
85                 }
86
87                 uid, gid = os.Getuid(), os.Getgid()
88                 // Read current umask. Don't do this in programs where
89                 // multiple goroutines create files because this is inherently
90                 // racy! No goroutines here, so it's fine.
91                 umask := syscall.Umask(0)
92                 syscall.Umask(umask)
93                 // Apply umask to created file
94                 mode = 0666 & ^fs.FileMode(umask)
95
96                 changes = append(changes,
97                         fmt.Sprintf("%q: created file (%d/%d %s)",
98                                 path, uid, gid, mode))
99
100         } else {
101                 // Preserve user/group and mode of existing file
102                 x, ok := stat.Sys().(*syscall.Stat_t)
103                 if !ok {
104                         return nil, fmt.Errorf("unsupported Stat().Sys()")
105                 }
106                 uid = int(x.Uid)
107                 gid = int(x.Gid)
108                 mode = stat.Mode()
109         }
110         stat = nil //nolint:wastedassign // prevent accidental use
111
112         // Check if the expected line is present
113         var found bool
114         for _, x := range bytes.Split(data, []byte("\n")) {
115                 if string(x) == line {
116                         found = true
117                         break
118                 }
119         }
120         // Make sure the file has a trailing newline. This enforces symmetry
121         // with our changes. Whenever we add a line we also append a trailing
122         // newline. When we conclude that no changes are necessary the file
123         // should be in the same state as we would leave it if there were
124         // changes.
125         if len(data) != 0 && data[len(data)-1] != '\n' {
126                 data = append(data, '\n')
127                 changes = append(changes,
128                         fmt.Sprintf("%q: added missing trailing newline",
129                                 path))
130         }
131
132         // Line present, nothing to do
133         if found && len(changes) == 0 {
134                 return nil, nil
135         }
136
137         // Append line
138         if !found {
139                 data = append(data, []byte(line+"\n")...)
140                 changes = append(changes,
141                         fmt.Sprintf("%q: added line %q", path, line))
142         }
143
144         // Write via temporary file and rename
145         tmpBase, err := sync.WriteTempAt(parentFd, "."+baseName,
146                 data, uid, gid, mode)
147         if err != nil {
148                 return nil, err
149         }
150         err = unix.Renameat(parentFd, tmpBase, parentFd, baseName)
151         if err != nil {
152                 unix.Unlinkat(parentFd, tmpBase, 0 /* flags */) //nolint:errcheck
153                 return nil, err
154         }
155         err = unix.Fsync(parentFd)
156         if err != nil {
157                 return nil, err
158         }
159
160         return changes, nil
161 }
162
163 func readFileAtNoFollow(dirfd int, base string) ([]byte, fs.FileInfo, error) {
164         fh, err := sync.OpenAtNoFollow(dirfd, base)
165         if err != nil {
166                 return nil, nil, err
167         }
168         defer fh.Close()
169
170         stat, err := fh.Stat()
171         if err != nil {
172                 return nil, nil, err
173         }
174         if stat.Mode().Type() != 0 /* regular file */ {
175                 return nil, nil, fmt.Errorf("not a regular file but %s",
176                         stat.Mode().Type())
177         }
178
179         x, err := io.ReadAll(fh)
180         if err != nil {
181                 return nil, nil, err
182         }
183
184         return x, stat, nil
185 }