]> ruderich.org/simon Gitweb - nsscash/nsscash.git/blob - passwd.go
nsscash: main_test: use existing t variable instead of a.t
[nsscash/nsscash.git] / passwd.go
1 // Parse /etc/passwd files and serialize them
2
3 // Copyright (C) 2019  Simon Ruderich
4 //
5 // This program is free software: you can redistribute it and/or modify
6 // it under the terms of the GNU Affero General Public License as published by
7 // the Free Software Foundation, either version 3 of the License, or
8 // (at your option) any later version.
9 //
10 // This program is distributed in the hope that it will be useful,
11 // but WITHOUT ANY WARRANTY; without even the implied warranty of
12 // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
13 // GNU Affero General Public License for more details.
14 //
15 // You should have received a copy of the GNU Affero General Public License
16 // along with this program.  If not, see <https://www.gnu.org/licenses/>.
17
18 package main
19
20 import (
21         "bufio"
22         "bytes"
23         "encoding/binary"
24         "fmt"
25         "io"
26         "math"
27         "sort"
28         "strconv"
29         "strings"
30
31         "github.com/pkg/errors"
32 )
33
34 // Version written in SerializePasswds()
35 const PasswdVersion = 1
36
37 type Passwd struct {
38         Name   string
39         Passwd string
40         Uid    uint64
41         Gid    uint64
42         Gecos  string
43         Dir    string
44         Shell  string
45 }
46
47 // ParsePasswds parses a file in the format of /etc/passwd and returns all
48 // entries as slice of Passwd structs.
49 func ParsePasswds(r io.Reader) ([]Passwd, error) {
50         var res []Passwd
51
52         s := bufio.NewReader(r)
53         for {
54                 t, err := s.ReadString('\n')
55                 if err != nil {
56                         if err == io.EOF {
57                                 if t != "" {
58                                         return nil, fmt.Errorf(
59                                                 "no newline in last line: %q",
60                                                 t)
61                                 }
62                                 break
63                         }
64                         return nil, err
65                 }
66
67                 x := strings.Split(t, ":")
68                 if len(x) != 7 {
69                         return nil, fmt.Errorf("invalid line %q", t)
70                 }
71
72                 uid, err := strconv.ParseUint(x[2], 10, 64)
73                 if err != nil {
74                         return nil, errors.Wrapf(err, "invalid uid in line %q", t)
75                 }
76                 gid, err := strconv.ParseUint(x[3], 10, 64)
77                 if err != nil {
78                         return nil, errors.Wrapf(err, "invalid gid in line %q", t)
79                 }
80
81                 res = append(res, Passwd{
82                         Name:   x[0],
83                         Passwd: x[1],
84                         Uid:    uid,
85                         Gid:    gid,
86                         Gecos:  x[4],
87                         Dir:    x[5],
88                         // ReadString() contains the delimiter
89                         Shell: strings.TrimSuffix(x[6], "\n"),
90                 })
91         }
92         return res, nil
93 }
94
95 func SerializePasswd(p Passwd) ([]byte, error) {
96         // Concatenate all (NUL-terminated) strings and store the offsets
97         var data bytes.Buffer
98         data.Write([]byte(p.Name))
99         data.WriteByte(0)
100         offPasswd := uint16(data.Len())
101         data.Write([]byte(p.Passwd))
102         data.WriteByte(0)
103         offGecos := uint16(data.Len())
104         data.Write([]byte(p.Gecos))
105         data.WriteByte(0)
106         offDir := uint16(data.Len())
107         data.Write([]byte(p.Dir))
108         data.WriteByte(0)
109         offShell := uint16(data.Len())
110         data.Write([]byte(p.Shell))
111         data.WriteByte(0)
112         // Ensure the offsets can fit the length of this entry
113         if data.Len() > math.MaxUint16 {
114                 return nil, fmt.Errorf("passwd too large to serialize: %v, %v",
115                         data.Len(), p)
116         }
117         size := uint16(data.Len())
118
119         var res bytes.Buffer // serialized result
120         le := binary.LittleEndian
121
122         id := make([]byte, 8)
123         // uid
124         le.PutUint64(id, p.Uid)
125         res.Write(id)
126         // gid
127         le.PutUint64(id, p.Gid)
128         res.Write(id)
129
130         off := make([]byte, 2)
131         // off_passwd
132         le.PutUint16(off, offPasswd)
133         res.Write(off)
134         // off_gecos
135         le.PutUint16(off, offGecos)
136         res.Write(off)
137         // off_dir
138         le.PutUint16(off, offDir)
139         res.Write(off)
140         // off_shell
141         le.PutUint16(off, offShell)
142         res.Write(off)
143         // data_size
144         le.PutUint16(off, size)
145         res.Write(off)
146
147         res.Write(data.Bytes())
148         // We must pad each entry so that all uint64 at the beginning of the
149         // struct are 8 byte aligned
150         alignBufferTo(&res, 8)
151
152         return res.Bytes(), nil
153 }
154
155 func SerializePasswds(w io.Writer, pws []Passwd) error {
156         // Serialize passwords and store offsets
157         var data bytes.Buffer
158         offsets := make(map[Passwd]uint64)
159         for _, p := range pws {
160                 // TODO: warn about duplicate entries
161                 offsets[p] = uint64(data.Len())
162                 x, err := SerializePasswd(p)
163                 if err != nil {
164                         return err
165                 }
166                 data.Write(x)
167         }
168
169         // Copy to prevent sorting from modifying the argument
170         sorted := make([]Passwd, len(pws))
171         copy(sorted, pws)
172
173         le := binary.LittleEndian
174         tmp := make([]byte, 8)
175
176         // Create index "sorted" in input order, used when iterating over all
177         // passwd entries (getpwent_r); keeping the original order makes
178         // debugging easier
179         var indexOrig bytes.Buffer
180         for _, p := range pws {
181                 le.PutUint64(tmp, offsets[p])
182                 indexOrig.Write(tmp)
183         }
184
185         // Create index sorted after id
186         var indexId bytes.Buffer
187         sort.Slice(sorted, func(i, j int) bool {
188                 return sorted[i].Uid < sorted[j].Uid
189         })
190         for _, p := range sorted {
191                 le.PutUint64(tmp, offsets[p])
192                 indexId.Write(tmp)
193         }
194
195         // Create index sorted after name
196         var indexName bytes.Buffer
197         sort.Slice(sorted, func(i, j int) bool {
198                 return sorted[i].Name < sorted[j].Name
199         })
200         for _, p := range sorted {
201                 le.PutUint64(tmp, offsets[p])
202                 indexName.Write(tmp)
203         }
204
205         // Sanity check
206         if len(pws)*8 != indexOrig.Len() ||
207                 indexOrig.Len() != indexId.Len() ||
208                 indexId.Len() != indexName.Len() {
209                 return fmt.Errorf("indexes have inconsistent length")
210         }
211
212         // Write result
213
214         // magic
215         w.Write([]byte("NSS-CASH"))
216         // version
217         le.PutUint64(tmp, PasswdVersion)
218         w.Write(tmp)
219         // count
220         le.PutUint64(tmp, uint64(len(pws)))
221         w.Write(tmp)
222         // off_orig_index
223         offset := uint64(0)
224         le.PutUint64(tmp, offset)
225         w.Write(tmp)
226         // off_id_index
227         offset += uint64(indexOrig.Len())
228         le.PutUint64(tmp, offset)
229         w.Write(tmp)
230         // off_name_index
231         offset += uint64(indexId.Len())
232         le.PutUint64(tmp, offset)
233         w.Write(tmp)
234         // off_data
235         offset += uint64(indexName.Len())
236         le.PutUint64(tmp, offset)
237         w.Write(tmp)
238
239         _, err := indexOrig.WriteTo(w)
240         if err != nil {
241                 return err
242         }
243         _, err = indexId.WriteTo(w)
244         if err != nil {
245                 return err
246         }
247         _, err = indexName.WriteTo(w)
248         if err != nil {
249                 return err
250         }
251         _, err = data.WriteTo(w)
252         if err != nil {
253                 return err
254         }
255
256         return nil
257 }