]> ruderich.org/simon Gitweb - nsscash/nsscash.git/blob - group.go
nsscash: main_test: add special tests
[nsscash/nsscash.git] / group.go
1 // Parse /etc/group files and serialize them
2
3 // Copyright (C) 2019  Simon Ruderich
4 //
5 // This program is free software: you can redistribute it and/or modify
6 // it under the terms of the GNU Affero General Public License as published by
7 // the Free Software Foundation, either version 3 of the License, or
8 // (at your option) any later version.
9 //
10 // This program is distributed in the hope that it will be useful,
11 // but WITHOUT ANY WARRANTY; without even the implied warranty of
12 // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
13 // GNU Affero General Public License for more details.
14 //
15 // You should have received a copy of the GNU Affero General Public License
16 // along with this program.  If not, see <https://www.gnu.org/licenses/>.
17
18 package main
19
20 import (
21         "bufio"
22         "bytes"
23         "encoding/binary"
24         "fmt"
25         "io"
26         "math"
27         "sort"
28         "strconv"
29         "strings"
30
31         "github.com/pkg/errors"
32 )
33
34 // Version written in SerializeGroups()
35 const GroupVersion = 1
36
37 type Group struct {
38         Name    string
39         Passwd  string
40         Gid     uint64
41         Members []string
42 }
43
44 // Go does not support slices in map keys; therefore, use this separate struct
45 type GroupKey struct {
46         Name    string
47         Passwd  string
48         Gid     uint64
49         Members string // "," separated
50 }
51
52 func toKey(g Group) GroupKey {
53         return GroupKey{
54                 Name:    g.Name,
55                 Passwd:  g.Passwd,
56                 Gid:     g.Gid,
57                 Members: strings.Join(g.Members, ","),
58         }
59 }
60
61 // ParseGroups parses a file in the format of /etc/group and returns all
62 // entries as slice of Group structs.
63 func ParseGroups(r io.Reader) ([]Group, error) {
64         var res []Group
65
66         s := bufio.NewReader(r)
67         for {
68                 t, err := s.ReadString('\n')
69                 if err != nil {
70                         if err == io.EOF {
71                                 if t != "" {
72                                         return nil, fmt.Errorf(
73                                                 "no newline in last line: %q",
74                                                 t)
75                                 }
76                                 break
77                         }
78                         return nil, err
79                 }
80
81                 x := strings.Split(t, ":")
82                 if len(x) != 4 {
83                         return nil, fmt.Errorf("invalid line %q", t)
84                 }
85
86                 gid, err := strconv.ParseUint(x[2], 10, 64)
87                 if err != nil {
88                         return nil, errors.Wrapf(err, "invalid gid in line %q", t)
89                 }
90
91                 // ReadString() contains the delimiter
92                 x[3] = strings.TrimSuffix(x[3], "\n")
93
94                 var members []string
95                 // No members must result in empty slice, not slice with the
96                 // empty string
97                 if x[3] != "" {
98                         members = strings.Split(x[3], ",")
99                 }
100                 res = append(res, Group{
101                         Name:    x[0],
102                         Passwd:  x[1],
103                         Gid:     gid,
104                         Members: members,
105                 })
106         }
107
108         return res, nil
109 }
110
111 func SerializeGroup(g Group) ([]byte, error) {
112         le := binary.LittleEndian
113
114         // Concatenate all (NUL-terminated) strings and store the offsets
115         var mems bytes.Buffer
116         var mems_off []uint16
117         for _, m := range g.Members {
118                 mems_off = append(mems_off, uint16(mems.Len()))
119                 mems.Write([]byte(m))
120                 mems.WriteByte(0)
121         }
122         var data bytes.Buffer
123         data.Write([]byte(g.Name))
124         data.WriteByte(0)
125         offPasswd := uint16(data.Len())
126         data.Write([]byte(g.Passwd))
127         data.WriteByte(0)
128         alignBufferTo(&data, 2) // align the following uint16
129         offMemOff := uint16(data.Len())
130         // Offsets for group members
131         offMem := offMemOff + 2*uint16(len(mems_off))
132         for _, o := range mems_off {
133                 tmp := make([]byte, 2)
134                 le.PutUint16(tmp, offMem+o)
135                 data.Write(tmp)
136         }
137         // And the group members concatenated as above
138         data.Write(mems.Bytes())
139         // Ensure the offsets can fit the length of this entry
140         if data.Len() > math.MaxUint16 {
141                 return nil, fmt.Errorf("group too large to serialize: %v, %v",
142                         data.Len(), g)
143         }
144         size := uint16(data.Len())
145
146         var res bytes.Buffer // serialized result
147
148         id := make([]byte, 8)
149         // gid
150         le.PutUint64(id, g.Gid)
151         res.Write(id)
152
153         off := make([]byte, 2)
154         // off_passwd
155         le.PutUint16(off, offPasswd)
156         res.Write(off)
157         // off_mem_off
158         le.PutUint16(off, offMemOff)
159         res.Write(off)
160         // mem_count
161         le.PutUint16(off, uint16(len(g.Members)))
162         res.Write(off)
163         // data_size
164         le.PutUint16(off, size)
165         res.Write(off)
166
167         res.Write(data.Bytes())
168         // We must pad each entry so that all uint64 at the beginning of the
169         // struct are 8 byte aligned
170         alignBufferTo(&res, 8)
171
172         return res.Bytes(), nil
173 }
174
175 func SerializeGroups(w io.Writer, grs []Group) error {
176         // Serialize groups and store offsets
177         var data bytes.Buffer
178         offsets := make(map[GroupKey]uint64)
179         for _, g := range grs {
180                 // TODO: warn about duplicate entries
181                 offsets[toKey(g)] = uint64(data.Len())
182                 x, err := SerializeGroup(g)
183                 if err != nil {
184                         return err
185                 }
186                 data.Write(x)
187         }
188
189         // Copy to prevent sorting from modifying the argument
190         sorted := make([]Group, len(grs))
191         copy(sorted, grs)
192
193         le := binary.LittleEndian
194         tmp := make([]byte, 8)
195
196         // Create index "sorted" in input order, used when iterating over all
197         // passwd entries (getgrent_r); keeping the original order makes
198         // debugging easier
199         var indexOrig bytes.Buffer
200         for _, g := range grs {
201                 le.PutUint64(tmp, offsets[toKey(g)])
202                 indexOrig.Write(tmp)
203         }
204
205         // Create index sorted after id
206         var indexId bytes.Buffer
207         sort.Slice(sorted, func(i, j int) bool {
208                 return sorted[i].Gid < sorted[j].Gid
209         })
210         for _, g := range sorted {
211                 le.PutUint64(tmp, offsets[toKey(g)])
212                 indexId.Write(tmp)
213         }
214
215         // Create index sorted after name
216         var indexName bytes.Buffer
217         sort.Slice(sorted, func(i, j int) bool {
218                 return sorted[i].Name < sorted[j].Name
219         })
220         for _, g := range sorted {
221                 le.PutUint64(tmp, offsets[toKey(g)])
222                 indexName.Write(tmp)
223         }
224
225         // Sanity check
226         if len(grs)*8 != indexOrig.Len() ||
227                 indexOrig.Len() != indexId.Len() ||
228                 indexId.Len() != indexName.Len() {
229                 return fmt.Errorf("indexes have inconsistent length")
230         }
231
232         // Write result
233
234         // magic
235         w.Write([]byte("NSS-CASH"))
236         // version
237         le.PutUint64(tmp, GroupVersion)
238         w.Write(tmp)
239         // count
240         le.PutUint64(tmp, uint64(len(grs)))
241         w.Write(tmp)
242         // off_orig_index
243         offset := uint64(0)
244         le.PutUint64(tmp, offset)
245         w.Write(tmp)
246         // off_id_index
247         offset += uint64(indexOrig.Len())
248         le.PutUint64(tmp, offset)
249         w.Write(tmp)
250         // off_name_index
251         offset += uint64(indexId.Len())
252         le.PutUint64(tmp, offset)
253         w.Write(tmp)
254         // off_data
255         offset += uint64(indexName.Len())
256         le.PutUint64(tmp, offset)
257         w.Write(tmp)
258
259         _, err := indexOrig.WriteTo(w)
260         if err != nil {
261                 return err
262         }
263         _, err = indexId.WriteTo(w)
264         if err != nil {
265                 return err
266         }
267         _, err = indexName.WriteTo(w)
268         if err != nil {
269                 return err
270         }
271         _, err = data.WriteTo(w)
272         if err != nil {
273                 return err
274         }
275
276         return nil
277 }