]> ruderich.org/simon Gitweb - nsscash/nsscash.git/blob - group.go
nss: Makefile: fix typo in LD_PRELOAD variable name
[nsscash/nsscash.git] / group.go
1 // Parse /etc/group files and serialize them
2
3 // Copyright (C) 2019  Simon Ruderich
4 //
5 // This program is free software: you can redistribute it and/or modify
6 // it under the terms of the GNU Affero General Public License as published by
7 // the Free Software Foundation, either version 3 of the License, or
8 // (at your option) any later version.
9 //
10 // This program is distributed in the hope that it will be useful,
11 // but WITHOUT ANY WARRANTY; without even the implied warranty of
12 // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
13 // GNU Affero General Public License for more details.
14 //
15 // You should have received a copy of the GNU Affero General Public License
16 // along with this program.  If not, see <https://www.gnu.org/licenses/>.
17
18 package main
19
20 import (
21         "bufio"
22         "bytes"
23         "encoding/binary"
24         "fmt"
25         "io"
26         "sort"
27         "strconv"
28         "strings"
29
30         "github.com/pkg/errors"
31 )
32
33 // Version written in SerializeGroups()
34 const GroupVersion = 1
35
36 type Group struct {
37         Name    string
38         Passwd  string
39         Gid     uint64
40         Members []string
41 }
42
43 // Go does not support slices in map keys; therefore, use this separate struct
44 type GroupKey struct {
45         Name    string
46         Passwd  string
47         Gid     uint64
48         Members string
49 }
50
51 func toKey(g Group) GroupKey {
52         return GroupKey{
53                 Name:    g.Name,
54                 Passwd:  g.Passwd,
55                 Gid:     g.Gid,
56                 Members: strings.Join(g.Members, ","),
57         }
58 }
59
60 // ParseGroups parses a file in the format of /etc/group and returns all
61 // entries as Group structs.
62 func ParseGroups(r io.Reader) ([]Group, error) {
63         var res []Group
64
65         s := bufio.NewScanner(r)
66         for s.Scan() {
67                 t := s.Text()
68
69                 x := strings.Split(t, ":")
70                 if len(x) != 4 {
71                         return nil, fmt.Errorf("invalid line %q", t)
72                 }
73
74                 gid, err := strconv.ParseUint(x[2], 10, 64)
75                 if err != nil {
76                         return nil, errors.Wrapf(err, "invalid gid in line %q", t)
77                 }
78
79                 var members []string
80                 // No members must result in empty slice, not slice with the
81                 // empty string
82                 if x[3] != "" {
83                         members = strings.Split(x[3], ",")
84                 }
85                 res = append(res, Group{
86                         Name:    x[0],
87                         Passwd:  x[1],
88                         Gid:     gid,
89                         Members: members,
90                 })
91         }
92         err := s.Err()
93         if err != nil {
94                 return nil, err
95         }
96
97         return res, nil
98 }
99
100 func SerializeGroup(g Group) []byte {
101         le := binary.LittleEndian
102
103         // Concatenate all (NUL-terminated) strings and store the offsets
104         var mems bytes.Buffer
105         var mems_off []uint16
106         for _, m := range g.Members {
107                 mems_off = append(mems_off, uint16(mems.Len()))
108                 mems.Write([]byte(m))
109                 mems.WriteByte(0)
110         }
111         var data bytes.Buffer
112         data.Write([]byte(g.Name))
113         data.WriteByte(0)
114         offPasswd := uint16(data.Len())
115         data.Write([]byte(g.Passwd))
116         data.WriteByte(0)
117         // Padding to align the following uint16
118         if data.Len()%2 != 0 {
119                 data.WriteByte(0)
120         }
121         offMemOff := uint16(data.Len())
122         // Offsets for group members
123         offMem := offMemOff + 2*uint16(len(mems_off))
124         for _, o := range mems_off {
125                 tmp := make([]byte, 2)
126                 le.PutUint16(tmp, offMem+o)
127                 data.Write(tmp)
128         }
129         // And the group members concatenated as above
130         data.Write(mems.Bytes())
131         size := uint16(data.Len())
132
133         var res bytes.Buffer // serialized result
134
135         id := make([]byte, 8)
136         // gid
137         le.PutUint64(id, g.Gid)
138         res.Write(id)
139
140         off := make([]byte, 2)
141         // off_passwd
142         le.PutUint16(off, offPasswd)
143         res.Write(off)
144         // off_mem_off
145         le.PutUint16(off, offMemOff)
146         res.Write(off)
147         // mem_count
148         le.PutUint16(off, uint16(len(g.Members)))
149         res.Write(off)
150         // data_size
151         le.PutUint16(off, size)
152         res.Write(off)
153
154         res.Write(data.Bytes())
155         // We must pad each entry so that all uint64 at the beginning of the
156         // struct are 8 byte aligned
157         l := res.Len()
158         if l%8 != 0 {
159                 for i := 0; i < 8-l%8; i++ {
160                         res.WriteByte(0)
161                 }
162         }
163
164         return res.Bytes()
165 }
166
167 func SerializeGroups(w io.Writer, grs []Group) error {
168         // Serialize groups and store offsets
169         var data bytes.Buffer
170         offsets := make(map[GroupKey]uint64)
171         for _, g := range grs {
172                 // TODO: warn about duplicate entries
173                 offsets[toKey(g)] = uint64(data.Len())
174                 data.Write(SerializeGroup(g))
175         }
176
177         // Copy to prevent sorting from modifying the argument
178         sorted := make([]Group, len(grs))
179         copy(sorted, grs)
180
181         le := binary.LittleEndian
182         tmp := make([]byte, 8)
183
184         // Create index "sorted" in input order, used when iterating over all
185         // passwd entries (getgrent_r); keeping the original order makes
186         // debugging easier
187         var indexOrig bytes.Buffer
188         for _, g := range grs {
189                 le.PutUint64(tmp, offsets[toKey(g)])
190                 indexOrig.Write(tmp)
191         }
192
193         // Create index sorted after id
194         var indexId bytes.Buffer
195         sort.Slice(sorted, func(i, j int) bool {
196                 return sorted[i].Gid < sorted[j].Gid
197         })
198         for _, g := range sorted {
199                 le.PutUint64(tmp, offsets[toKey(g)])
200                 indexId.Write(tmp)
201         }
202
203         // Create index sorted after name
204         var indexName bytes.Buffer
205         sort.Slice(sorted, func(i, j int) bool {
206                 return sorted[i].Name < sorted[j].Name
207         })
208         for _, g := range sorted {
209                 le.PutUint64(tmp, offsets[toKey(g)])
210                 indexName.Write(tmp)
211         }
212
213         // Sanity check
214         if len(grs)*8 != indexOrig.Len() ||
215                 indexOrig.Len() != indexId.Len() ||
216                 indexId.Len() != indexName.Len() {
217                 return fmt.Errorf("indexes have inconsistent length")
218         }
219
220         // Write result
221
222         // magic
223         w.Write([]byte("NSS-CASH"))
224         // version
225         le.PutUint64(tmp, GroupVersion)
226         w.Write(tmp)
227         // count
228         le.PutUint64(tmp, uint64(len(grs)))
229         w.Write(tmp)
230         // off_orig_index
231         offset := uint64(0)
232         le.PutUint64(tmp, offset)
233         w.Write(tmp)
234         // off_id_index
235         offset += uint64(indexOrig.Len())
236         le.PutUint64(tmp, offset)
237         w.Write(tmp)
238         // off_name_index
239         offset += uint64(indexId.Len())
240         le.PutUint64(tmp, offset)
241         w.Write(tmp)
242         // off_data
243         offset += uint64(indexName.Len())
244         le.PutUint64(tmp, offset)
245         w.Write(tmp)
246
247         _, err := indexOrig.WriteTo(w)
248         if err != nil {
249                 return err
250         }
251         _, err = indexId.WriteTo(w)
252         if err != nil {
253                 return err
254         }
255         _, err = indexName.WriteTo(w)
256         if err != nil {
257                 return err
258         }
259         _, err = data.WriteTo(w)
260         if err != nil {
261                 return err
262         }
263
264         return nil
265 }