]> ruderich.org/simon Gitweb - safcm/safcm.git/blob - remote/ainsl/ainsl.go
remote: guard against symlinks in earlier path components
[safcm/safcm.git] / remote / ainsl / ainsl.go
1 // "ainsl" sub-command: "append if no such line" (inspired by FAI's ainsl),
2 // append lines to files (if not present) without replacing the file
3 // completely
4 //
5 // FAI: https://fai-project.org
6
7 // Copyright (C) 2021  Simon Ruderich
8 //
9 // This program is free software: you can redistribute it and/or modify
10 // it under the terms of the GNU General Public License as published by
11 // the Free Software Foundation, either version 3 of the License, or
12 // (at your option) any later version.
13 //
14 // This program is distributed in the hope that it will be useful,
15 // but WITHOUT ANY WARRANTY; without even the implied warranty of
16 // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
17 // GNU General Public License for more details.
18 //
19 // You should have received a copy of the GNU General Public License
20 // along with this program.  If not, see <http://www.gnu.org/licenses/>.
21
22 package ainsl
23
24 import (
25         "bytes"
26         "flag"
27         "fmt"
28         "io"
29         "io/fs"
30         "os"
31         "strings"
32         "syscall"
33
34         "golang.org/x/sys/unix"
35
36         "ruderich.org/simon/safcm/remote/sync"
37 )
38
39 func Main(args []string) error {
40         flag.Usage = func() {
41                 fmt.Fprintf(os.Stderr,
42                         "usage: %s ainsl [<options>] <path> <line>\n",
43                         args[0])
44                 flag.PrintDefaults()
45         }
46
47         optionCreate := flag.Bool("create", false,
48                 "create the path if it does not exist")
49
50         flag.CommandLine.Parse(args[2:])
51
52         if flag.NArg() != 2 {
53                 flag.Usage()
54                 os.Exit(1)
55         }
56
57         changes, err := handle(flag.Args()[0], flag.Args()[1], *optionCreate)
58         if err != nil {
59                 return fmt.Errorf("ainsl: %v", err)
60         }
61         for _, x := range changes {
62                 fmt.Fprintln(os.Stderr, x)
63         }
64         return nil
65 }
66
67 func handle(path string, line string, create bool) ([]string, error) {
68         // See safcm-remote/sync/files.go for details on the implementation
69
70         if line == "" {
71                 return nil, fmt.Errorf("empty line")
72         }
73         if strings.Contains(line, "\n") {
74                 return nil, fmt.Errorf("line must not contain newlines: %q",
75                         line)
76         }
77
78         parentFd, baseName, err := sync.OpenParentDirectoryNoSymlinks(path)
79         if err != nil {
80                 return nil, err
81         }
82         defer unix.Close(parentFd)
83
84         var changes []string
85
86         var uid, gid int
87         var mode fs.FileMode
88         data, stat, err := readFileAtNoFollow(parentFd, baseName)
89         if err != nil {
90                 if !os.IsNotExist(err) {
91                         return nil, fmt.Errorf("%q: %v", path, err)
92                 }
93                 if !create {
94                         return nil, fmt.Errorf(
95                                 "%q: file does not exist, use -create",
96                                 path)
97                 }
98
99                 uid, gid = os.Getuid(), os.Getgid()
100                 // Read current umask. Don't do this in programs where
101                 // multiple goroutines create files because this is inherently
102                 // racy! No goroutines here, so it's fine.
103                 umask := syscall.Umask(0)
104                 syscall.Umask(umask)
105                 // Apply umask to created file
106                 mode = 0666 & ^fs.FileMode(umask)
107
108                 changes = append(changes,
109                         fmt.Sprintf("%q: created file (%d/%d %s)",
110                                 path, uid, gid, mode))
111
112         } else {
113                 // Preserve user/group and mode of existing file
114                 x, ok := stat.Sys().(*syscall.Stat_t)
115                 if !ok {
116                         return nil, fmt.Errorf("unsupported Stat().Sys()")
117                 }
118                 uid = int(x.Uid)
119                 gid = int(x.Gid)
120                 mode = stat.Mode()
121         }
122         stat = nil // prevent accidental use
123
124         // Check if the expected line is present
125         var found bool
126         for _, x := range bytes.Split(data, []byte("\n")) {
127                 if string(x) == line {
128                         found = true
129                         break
130                 }
131         }
132         // Make sure the file has a trailing newline. This enforces symmetry
133         // with our changes. Whenever we add a line we also append a trailing
134         // newline. When we conclude that no changes are necessary the file
135         // should be in the same state as we would leave it if there were
136         // changes.
137         if len(data) != 0 && data[len(data)-1] != '\n' {
138                 data = append(data, '\n')
139                 changes = append(changes,
140                         fmt.Sprintf("%q: added missing trailing newline",
141                                 path))
142         }
143
144         // Line present, nothing to do
145         if found && len(changes) == 0 {
146                 return nil, nil
147         }
148
149         // Append line
150         if !found {
151                 data = append(data, []byte(line+"\n")...)
152                 changes = append(changes,
153                         fmt.Sprintf("%q: added line %q", path, line))
154         }
155
156         // Write via temporary file and rename
157         tmpBase, err := sync.WriteTempAt(parentFd, "."+baseName,
158                 data, uid, gid, mode)
159         if err != nil {
160                 return nil, err
161         }
162         err = unix.Renameat(parentFd, tmpBase, parentFd, baseName)
163         if err != nil {
164                 unix.Unlinkat(parentFd, tmpBase, 0)
165                 return nil, err
166         }
167         err = unix.Fsync(parentFd)
168         if err != nil {
169                 return nil, err
170         }
171
172         return changes, nil
173 }
174
175 func readFileAtNoFollow(dirfd int, base string) ([]byte, fs.FileInfo, error) {
176         fh, err := sync.OpenAtNoFollow(dirfd, base)
177         if err != nil {
178                 return nil, nil, err
179         }
180         defer fh.Close()
181
182         stat, err := fh.Stat()
183         if err != nil {
184                 return nil, nil, err
185         }
186         if stat.Mode().Type() != 0 /* regular file */ {
187                 return nil, nil, fmt.Errorf("not a regular file but %s",
188                         stat.Mode().Type())
189         }
190
191         x, err := io.ReadAll(fh)
192         if err != nil {
193                 return nil, nil, err
194         }
195
196         return x, stat, nil
197 }