]> ruderich.org/simon Gitweb - nsscash/nsscash.git/blob - group.go
Add support for group files
[nsscash/nsscash.git] / group.go
1 // Parse /etc/group files and serialize them
2
3 // Copyright (C) 2019  Simon Ruderich
4 //
5 // This program is free software: you can redistribute it and/or modify
6 // it under the terms of the GNU Affero General Public License as published by
7 // the Free Software Foundation, either version 3 of the License, or
8 // (at your option) any later version.
9 //
10 // This program is distributed in the hope that it will be useful,
11 // but WITHOUT ANY WARRANTY; without even the implied warranty of
12 // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
13 // GNU Affero General Public License for more details.
14 //
15 // You should have received a copy of the GNU Affero General Public License
16 // along with this program.  If not, see <https://www.gnu.org/licenses/>.
17
18 package main
19
20 import (
21         "bufio"
22         "bytes"
23         "encoding/binary"
24         "fmt"
25         "io"
26         "sort"
27         "strconv"
28         "strings"
29
30         "github.com/pkg/errors"
31 )
32
33 // Version written in SerializeGroups()
34 const GroupVersion = 1
35
36 type Group struct {
37         Name    string
38         Passwd  string
39         Gid     uint64
40         Members []string
41 }
42 type GroupKey struct {
43         Name    string
44         Passwd  string
45         Gid     uint64
46         Members string
47 }
48
49 func toKey(g Group) GroupKey {
50         return GroupKey{
51                 Name:    g.Name,
52                 Passwd:  g.Passwd,
53                 Gid:     g.Gid,
54                 Members: strings.Join(g.Members, ","),
55         }
56 }
57
58 // ParseGroups parses a file in the format of /etc/group and returns all
59 // entries as Group structs.
60 func ParseGroups(r io.Reader) ([]Group, error) {
61         var res []Group
62
63         s := bufio.NewScanner(r)
64         for s.Scan() {
65                 t := s.Text()
66
67                 x := strings.Split(t, ":")
68                 if len(x) != 4 {
69                         return nil, fmt.Errorf("invalid line %q", t)
70                 }
71
72                 gid, err := strconv.ParseUint(x[2], 10, 64)
73                 if err != nil {
74                         return nil, errors.Wrapf(err, "invalid gid in line %q", t)
75                 }
76
77                 var members []string
78                 // No members must result in empty slice, not slice with the
79                 // empty string
80                 if x[3] != "" {
81                         members = strings.Split(x[3], ",")
82                 }
83                 res = append(res, Group{
84                         Name:    x[0],
85                         Passwd:  x[1],
86                         Gid:     gid,
87                         Members: members,
88                 })
89         }
90         err := s.Err()
91         if err != nil {
92                 return nil, err
93         }
94
95         return res, nil
96 }
97
98 func SerializeGroup(g Group) []byte {
99         le := binary.LittleEndian
100
101         // Concatenate all (NUL-terminated) strings and store the offsets
102         var mems bytes.Buffer
103         var mems_off []uint16
104         for _, m := range g.Members {
105                 mems_off = append(mems_off, uint16(mems.Len()))
106                 mems.Write([]byte(m))
107                 mems.WriteByte(0)
108         }
109         var data bytes.Buffer
110         data.Write([]byte(g.Name))
111         data.WriteByte(0)
112         offPasswd := uint16(data.Len())
113         data.Write([]byte(g.Passwd))
114         data.WriteByte(0)
115         // Padding to align the following uint16
116         if data.Len()%2 != 0 {
117                 data.WriteByte(0)
118         }
119         offMemOff := uint16(data.Len())
120         // Offsets for group members
121         offMem := offMemOff + 2*uint16(len(mems_off))
122         for _, o := range mems_off {
123                 tmp := make([]byte, 2)
124                 le.PutUint16(tmp, offMem+o)
125                 data.Write(tmp)
126         }
127         // And the group members concatenated as above
128         data.Write(mems.Bytes())
129         size := uint16(data.Len())
130
131         var res bytes.Buffer // serialized result
132
133         id := make([]byte, 8)
134         // gid
135         le.PutUint64(id, g.Gid)
136         res.Write(id)
137
138         off := make([]byte, 2)
139         // off_passwd
140         le.PutUint16(off, offPasswd)
141         res.Write(off)
142         // off_mem_off
143         le.PutUint16(off, offMemOff)
144         res.Write(off)
145         // mem_count
146         le.PutUint16(off, uint16(len(g.Members)))
147         res.Write(off)
148         // data_size
149         le.PutUint16(off, size)
150         res.Write(off)
151
152         res.Write(data.Bytes())
153         // We must pad each entry so that all uint64 at the beginning of the
154         // struct are 8 byte aligned
155         l := res.Len()
156         if l%8 != 0 {
157                 for i := 0; i < 8-l%8; i++ {
158                         res.WriteByte(0)
159                 }
160         }
161
162         return res.Bytes()
163 }
164
165 func SerializeGroups(w io.Writer, grs []Group) error {
166         // Serialize groups and store offsets
167         var data bytes.Buffer
168         offsets := make(map[GroupKey]uint64)
169         for _, g := range grs {
170                 // TODO: warn about duplicate entries
171                 offsets[toKey(g)] = uint64(data.Len())
172                 data.Write(SerializeGroup(g))
173         }
174
175         // Copy to prevent sorting from modifying the argument
176         sorted := make([]Group, len(grs))
177         copy(sorted, grs)
178
179         le := binary.LittleEndian
180         tmp := make([]byte, 8)
181
182         // Create index "sorted" in input order, used when iterating over all
183         // passwd entries (getgrent_r); keeping the original order makes
184         // debugging easier
185         var indexOrig bytes.Buffer
186         for _, g := range grs {
187                 le.PutUint64(tmp, offsets[toKey(g)])
188                 indexOrig.Write(tmp)
189         }
190
191         // Create index sorted after id
192         var indexId bytes.Buffer
193         sort.Slice(sorted, func(i, j int) bool {
194                 return sorted[i].Gid < sorted[j].Gid
195         })
196         for _, g := range sorted {
197                 le.PutUint64(tmp, offsets[toKey(g)])
198                 indexId.Write(tmp)
199         }
200
201         // Create index sorted after name
202         var indexName bytes.Buffer
203         sort.Slice(sorted, func(i, j int) bool {
204                 return sorted[i].Name < sorted[j].Name
205         })
206         for _, g := range sorted {
207                 le.PutUint64(tmp, offsets[toKey(g)])
208                 indexName.Write(tmp)
209         }
210
211         // Sanity check
212         if indexOrig.Len() != indexId.Len() ||
213                 indexId.Len() != indexName.Len() {
214                 return fmt.Errorf("indexes have inconsistent length")
215         }
216
217         // Write result
218
219         // magic
220         w.Write([]byte("NSS-CASH"))
221         // version
222         le.PutUint64(tmp, GroupVersion)
223         w.Write(tmp)
224         // count
225         le.PutUint64(tmp, uint64(len(grs)))
226         w.Write(tmp)
227         // off_orig_index
228         offset := uint64(0)
229         le.PutUint64(tmp, offset)
230         w.Write(tmp)
231         // off_id_index
232         offset += uint64(indexOrig.Len())
233         le.PutUint64(tmp, offset)
234         w.Write(tmp)
235         // off_name_index
236         offset += uint64(indexId.Len())
237         le.PutUint64(tmp, offset)
238         w.Write(tmp)
239         // off_data
240         offset += uint64(indexName.Len())
241         le.PutUint64(tmp, offset)
242         w.Write(tmp)
243
244         _, err := indexOrig.WriteTo(w)
245         if err != nil {
246                 return err
247         }
248         _, err = indexId.WriteTo(w)
249         if err != nil {
250                 return err
251         }
252         _, err = indexName.WriteTo(w)
253         if err != nil {
254                 return err
255         }
256         _, err = data.WriteTo(w)
257         if err != nil {
258                 return err
259         }
260
261         return nil
262 }