]> ruderich.org/simon Gitweb - nsscash/nsscash.git/blob - group.go
nsscash: support longer lines in passwd/group files
[nsscash/nsscash.git] / group.go
1 // Parse /etc/group files and serialize them
2
3 // Copyright (C) 2019  Simon Ruderich
4 //
5 // This program is free software: you can redistribute it and/or modify
6 // it under the terms of the GNU Affero General Public License as published by
7 // the Free Software Foundation, either version 3 of the License, or
8 // (at your option) any later version.
9 //
10 // This program is distributed in the hope that it will be useful,
11 // but WITHOUT ANY WARRANTY; without even the implied warranty of
12 // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
13 // GNU Affero General Public License for more details.
14 //
15 // You should have received a copy of the GNU Affero General Public License
16 // along with this program.  If not, see <https://www.gnu.org/licenses/>.
17
18 package main
19
20 import (
21         "bufio"
22         "bytes"
23         "encoding/binary"
24         "fmt"
25         "io"
26         "sort"
27         "strconv"
28         "strings"
29
30         "github.com/pkg/errors"
31 )
32
33 // Version written in SerializeGroups()
34 const GroupVersion = 1
35
36 type Group struct {
37         Name    string
38         Passwd  string
39         Gid     uint64
40         Members []string
41 }
42
43 // Go does not support slices in map keys; therefore, use this separate struct
44 type GroupKey struct {
45         Name    string
46         Passwd  string
47         Gid     uint64
48         Members string // "," separated
49 }
50
51 func toKey(g Group) GroupKey {
52         return GroupKey{
53                 Name:    g.Name,
54                 Passwd:  g.Passwd,
55                 Gid:     g.Gid,
56                 Members: strings.Join(g.Members, ","),
57         }
58 }
59
60 // ParseGroups parses a file in the format of /etc/group and returns all
61 // entries as slice of Group structs.
62 func ParseGroups(r io.Reader) ([]Group, error) {
63         var res []Group
64
65         s := bufio.NewReader(r)
66         for {
67                 t, err := s.ReadString('\n')
68                 if err != nil {
69                         if err == io.EOF {
70                                 break
71                         }
72                         return nil, err
73                 }
74
75                 x := strings.Split(t, ":")
76                 if len(x) != 4 {
77                         return nil, fmt.Errorf("invalid line %q", t)
78                 }
79
80                 gid, err := strconv.ParseUint(x[2], 10, 64)
81                 if err != nil {
82                         return nil, errors.Wrapf(err, "invalid gid in line %q", t)
83                 }
84
85                 // ReadString() contains the delimiter
86                 x[3] = strings.TrimSuffix(x[3], "\n")
87
88                 var members []string
89                 // No members must result in empty slice, not slice with the
90                 // empty string
91                 if x[3] != "" {
92                         members = strings.Split(x[3], ",")
93                 }
94                 res = append(res, Group{
95                         Name:    x[0],
96                         Passwd:  x[1],
97                         Gid:     gid,
98                         Members: members,
99                 })
100         }
101
102         return res, nil
103 }
104
105 func SerializeGroup(g Group) []byte {
106         le := binary.LittleEndian
107
108         // Concatenate all (NUL-terminated) strings and store the offsets
109         var mems bytes.Buffer
110         var mems_off []uint16
111         for _, m := range g.Members {
112                 mems_off = append(mems_off, uint16(mems.Len()))
113                 mems.Write([]byte(m))
114                 mems.WriteByte(0)
115         }
116         var data bytes.Buffer
117         data.Write([]byte(g.Name))
118         data.WriteByte(0)
119         offPasswd := uint16(data.Len())
120         data.Write([]byte(g.Passwd))
121         data.WriteByte(0)
122         alignBufferTo(&data, 2) // align the following uint16
123         offMemOff := uint16(data.Len())
124         // Offsets for group members
125         offMem := offMemOff + 2*uint16(len(mems_off))
126         for _, o := range mems_off {
127                 tmp := make([]byte, 2)
128                 le.PutUint16(tmp, offMem+o)
129                 data.Write(tmp)
130         }
131         // And the group members concatenated as above
132         data.Write(mems.Bytes())
133         size := uint16(data.Len())
134
135         var res bytes.Buffer // serialized result
136
137         id := make([]byte, 8)
138         // gid
139         le.PutUint64(id, g.Gid)
140         res.Write(id)
141
142         off := make([]byte, 2)
143         // off_passwd
144         le.PutUint16(off, offPasswd)
145         res.Write(off)
146         // off_mem_off
147         le.PutUint16(off, offMemOff)
148         res.Write(off)
149         // mem_count
150         le.PutUint16(off, uint16(len(g.Members)))
151         res.Write(off)
152         // data_size
153         le.PutUint16(off, size)
154         res.Write(off)
155
156         res.Write(data.Bytes())
157         // We must pad each entry so that all uint64 at the beginning of the
158         // struct are 8 byte aligned
159         alignBufferTo(&res, 8)
160
161         return res.Bytes()
162 }
163
164 func SerializeGroups(w io.Writer, grs []Group) error {
165         // Serialize groups and store offsets
166         var data bytes.Buffer
167         offsets := make(map[GroupKey]uint64)
168         for _, g := range grs {
169                 // TODO: warn about duplicate entries
170                 offsets[toKey(g)] = uint64(data.Len())
171                 data.Write(SerializeGroup(g))
172         }
173
174         // Copy to prevent sorting from modifying the argument
175         sorted := make([]Group, len(grs))
176         copy(sorted, grs)
177
178         le := binary.LittleEndian
179         tmp := make([]byte, 8)
180
181         // Create index "sorted" in input order, used when iterating over all
182         // passwd entries (getgrent_r); keeping the original order makes
183         // debugging easier
184         var indexOrig bytes.Buffer
185         for _, g := range grs {
186                 le.PutUint64(tmp, offsets[toKey(g)])
187                 indexOrig.Write(tmp)
188         }
189
190         // Create index sorted after id
191         var indexId bytes.Buffer
192         sort.Slice(sorted, func(i, j int) bool {
193                 return sorted[i].Gid < sorted[j].Gid
194         })
195         for _, g := range sorted {
196                 le.PutUint64(tmp, offsets[toKey(g)])
197                 indexId.Write(tmp)
198         }
199
200         // Create index sorted after name
201         var indexName bytes.Buffer
202         sort.Slice(sorted, func(i, j int) bool {
203                 return sorted[i].Name < sorted[j].Name
204         })
205         for _, g := range sorted {
206                 le.PutUint64(tmp, offsets[toKey(g)])
207                 indexName.Write(tmp)
208         }
209
210         // Sanity check
211         if len(grs)*8 != indexOrig.Len() ||
212                 indexOrig.Len() != indexId.Len() ||
213                 indexId.Len() != indexName.Len() {
214                 return fmt.Errorf("indexes have inconsistent length")
215         }
216
217         // Write result
218
219         // magic
220         w.Write([]byte("NSS-CASH"))
221         // version
222         le.PutUint64(tmp, GroupVersion)
223         w.Write(tmp)
224         // count
225         le.PutUint64(tmp, uint64(len(grs)))
226         w.Write(tmp)
227         // off_orig_index
228         offset := uint64(0)
229         le.PutUint64(tmp, offset)
230         w.Write(tmp)
231         // off_id_index
232         offset += uint64(indexOrig.Len())
233         le.PutUint64(tmp, offset)
234         w.Write(tmp)
235         // off_name_index
236         offset += uint64(indexId.Len())
237         le.PutUint64(tmp, offset)
238         w.Write(tmp)
239         // off_data
240         offset += uint64(indexName.Len())
241         le.PutUint64(tmp, offset)
242         w.Write(tmp)
243
244         _, err := indexOrig.WriteTo(w)
245         if err != nil {
246                 return err
247         }
248         _, err = indexId.WriteTo(w)
249         if err != nil {
250                 return err
251         }
252         _, err = indexName.WriteTo(w)
253         if err != nil {
254                 return err
255         }
256         _, err = data.WriteTo(w)
257         if err != nil {
258                 return err
259         }
260
261         return nil
262 }