]> ruderich.org/simon Gitweb - safcm/safcm.git/blob - remote/ainsl/ainsl.go
d716b2b4d5ce0868b4006a19ed47acb3f7da5ece
[safcm/safcm.git] / remote / ainsl / ainsl.go
1 // "ainsl" sub-command: "append if no such line" (inspired by FAI's ainsl),
2 // append lines to files (if not present) without replacing the file
3 // completely
4 //
5 // FAI: https://fai-project.org
6
7 // Copyright (C) 2021  Simon Ruderich
8 //
9 // This program is free software: you can redistribute it and/or modify
10 // it under the terms of the GNU General Public License as published by
11 // the Free Software Foundation, either version 3 of the License, or
12 // (at your option) any later version.
13 //
14 // This program is distributed in the hope that it will be useful,
15 // but WITHOUT ANY WARRANTY; without even the implied warranty of
16 // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
17 // GNU General Public License for more details.
18 //
19 // You should have received a copy of the GNU General Public License
20 // along with this program.  If not, see <http://www.gnu.org/licenses/>.
21
22 package ainsl
23
24 import (
25         "bytes"
26         "flag"
27         "fmt"
28         "io"
29         "io/fs"
30         "os"
31         "path/filepath"
32         "strings"
33         "syscall"
34
35         "ruderich.org/simon/safcm/remote/sync"
36 )
37
38 func Main(args []string) error {
39         flag.Usage = func() {
40                 fmt.Fprintf(os.Stderr,
41                         "usage: %s ainsl [<options>] <path> <line>\n",
42                         args[0])
43                 flag.PrintDefaults()
44         }
45
46         optionCreate := flag.Bool("create", false,
47                 "create the path if it does not exist")
48
49         flag.CommandLine.Parse(args[2:])
50
51         if flag.NArg() != 2 {
52                 flag.Usage()
53                 os.Exit(1)
54         }
55
56         changes, err := handle(flag.Args()[0], flag.Args()[1], *optionCreate)
57         if err != nil {
58                 return fmt.Errorf("ainsl: %v", err)
59         }
60         for _, x := range changes {
61                 fmt.Fprintln(os.Stderr, x)
62         }
63         return nil
64 }
65
66 func handle(path string, line string, create bool) ([]string, error) {
67         // See safcm-remote/sync/files.go for details on the implementation
68
69         if line == "" {
70                 return nil, fmt.Errorf("empty line")
71         }
72         if strings.Contains(line, "\n") {
73                 return nil, fmt.Errorf("line must not contain newlines: %q",
74                         line)
75         }
76
77         var changes []string
78
79         var uid, gid int
80         var mode fs.FileMode
81         data, stat, err := readFileNoFollow(path)
82         if err != nil {
83                 if !os.IsNotExist(err) {
84                         return nil, err
85                 }
86                 if !create {
87                         return nil, fmt.Errorf(
88                                 "%q: file does not exist, use -create",
89                                 path)
90                 }
91
92                 uid, gid = os.Getuid(), os.Getgid()
93                 // Read current umask. Don't do this in programs where
94                 // multiple goroutines create files because this is inherently
95                 // racy! No goroutines here, so it's fine.
96                 umask := syscall.Umask(0)
97                 syscall.Umask(umask)
98                 // Apply umask to created file
99                 mode = 0666 & ^fs.FileMode(umask)
100
101                 changes = append(changes,
102                         fmt.Sprintf("%q: created file (%d/%d %s)",
103                                 path, uid, gid, mode))
104
105         } else {
106                 // Preserve user/group and mode of existing file
107                 x, ok := stat.Sys().(*syscall.Stat_t)
108                 if !ok {
109                         return nil, fmt.Errorf("unsupported Stat().Sys()")
110                 }
111                 uid = int(x.Uid)
112                 gid = int(x.Gid)
113                 mode = stat.Mode()
114         }
115         stat = nil // prevent accidental use
116
117         // Check if the expected line is present
118         var found bool
119         for _, x := range bytes.Split(data, []byte("\n")) {
120                 if string(x) == line {
121                         found = true
122                         break
123                 }
124         }
125         // Make sure the file has a trailing newline. This enforces symmetry
126         // with our changes. Whenever we add a line we also append a trailing
127         // newline. When we conclude that no changes are necessary the file
128         // should be in the same state as we would leave it if there were
129         // changes.
130         if len(data) != 0 && data[len(data)-1] != '\n' {
131                 data = append(data, '\n')
132                 changes = append(changes,
133                         fmt.Sprintf("%q: added missing trailing newline",
134                                 path))
135         }
136
137         // Line present, nothing to do
138         if found && len(changes) == 0 {
139                 return nil, nil
140         }
141
142         // Append line
143         if !found {
144                 data = append(data, []byte(line+"\n")...)
145                 changes = append(changes,
146                         fmt.Sprintf("%q: added line %q", path, line))
147         }
148
149         // Write via temporary file and rename
150         dir := filepath.Dir(path)
151         base := filepath.Base(path)
152         tmpPath, err := sync.WriteTemp(dir, "."+base, data, uid, gid, mode)
153         if err != nil {
154                 return nil, err
155         }
156         err = os.Rename(tmpPath, path)
157         if err != nil {
158                 os.Remove(tmpPath)
159                 return nil, err
160         }
161         err = sync.SyncPath(dir)
162         if err != nil {
163                 return nil, err
164         }
165
166         return changes, nil
167 }
168
169 func readFileNoFollow(path string) ([]byte, fs.FileInfo, error) {
170         fh, err := sync.OpenFileNoFollow(path)
171         if err != nil {
172                 return nil, nil, err
173         }
174         defer fh.Close()
175
176         stat, err := fh.Stat()
177         if err != nil {
178                 return nil, nil, err
179         }
180         if stat.Mode().Type() != 0 /* regular file */ {
181                 return nil, nil, fmt.Errorf("%q is not a regular file but %s",
182                         path, stat.Mode().Type())
183         }
184
185         x, err := io.ReadAll(fh)
186         if err != nil {
187                 return nil, nil, fmt.Errorf("%q: %v", path, err)
188         }
189
190         return x, stat, nil
191 }