]> ruderich.org/simon Gitweb - nsscash/nsscash.git/blob - passwd.go
nss: Makefile: don't link against asan
[nsscash/nsscash.git] / passwd.go
1 // Parse /etc/passwd files and serialize them
2
3 // Copyright (C) 2019  Simon Ruderich
4 //
5 // This program is free software: you can redistribute it and/or modify
6 // it under the terms of the GNU Affero General Public License as published by
7 // the Free Software Foundation, either version 3 of the License, or
8 // (at your option) any later version.
9 //
10 // This program is distributed in the hope that it will be useful,
11 // but WITHOUT ANY WARRANTY; without even the implied warranty of
12 // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
13 // GNU Affero General Public License for more details.
14 //
15 // You should have received a copy of the GNU Affero General Public License
16 // along with this program.  If not, see <https://www.gnu.org/licenses/>.
17
18 package main
19
20 import (
21         "bufio"
22         "bytes"
23         "encoding/binary"
24         "fmt"
25         "io"
26         "math"
27         "sort"
28         "strconv"
29         "strings"
30
31         "github.com/pkg/errors"
32 )
33
34 // Version written in SerializePasswds()
35 const PasswdVersion = 1
36
37 type Passwd struct {
38         Name   string
39         Passwd string
40         Uid    uint64
41         Gid    uint64
42         Gecos  string
43         Dir    string
44         Shell  string
45 }
46
47 // ParsePasswds parses a file in the format of /etc/passwd and returns all
48 // entries as slice of Passwd structs.
49 func ParsePasswds(r io.Reader) ([]Passwd, error) {
50         var res []Passwd
51
52         s := bufio.NewReader(r)
53         for {
54                 t, err := s.ReadString('\n')
55                 if err != nil {
56                         if err == io.EOF {
57                                 break
58                         }
59                         return nil, err
60                 }
61
62                 x := strings.Split(t, ":")
63                 if len(x) != 7 {
64                         return nil, fmt.Errorf("invalid line %q", t)
65                 }
66
67                 uid, err := strconv.ParseUint(x[2], 10, 64)
68                 if err != nil {
69                         return nil, errors.Wrapf(err, "invalid uid in line %q", t)
70                 }
71                 gid, err := strconv.ParseUint(x[3], 10, 64)
72                 if err != nil {
73                         return nil, errors.Wrapf(err, "invalid gid in line %q", t)
74                 }
75
76                 res = append(res, Passwd{
77                         Name:   x[0],
78                         Passwd: x[1],
79                         Uid:    uid,
80                         Gid:    gid,
81                         Gecos:  x[4],
82                         Dir:    x[5],
83                         // ReadString() contains the delimiter
84                         Shell: strings.TrimSuffix(x[6], "\n"),
85                 })
86         }
87         return res, nil
88 }
89
90 func SerializePasswd(p Passwd) ([]byte, error) {
91         // Concatenate all (NUL-terminated) strings and store the offsets
92         var data bytes.Buffer
93         data.Write([]byte(p.Name))
94         data.WriteByte(0)
95         offPasswd := uint16(data.Len())
96         data.Write([]byte(p.Passwd))
97         data.WriteByte(0)
98         offGecos := uint16(data.Len())
99         data.Write([]byte(p.Gecos))
100         data.WriteByte(0)
101         offDir := uint16(data.Len())
102         data.Write([]byte(p.Dir))
103         data.WriteByte(0)
104         offShell := uint16(data.Len())
105         data.Write([]byte(p.Shell))
106         data.WriteByte(0)
107         // Ensure the offsets can fit the length of this entry
108         if data.Len() > math.MaxUint16 {
109                 return nil, fmt.Errorf("passwd too large to serialize: %v, %v",
110                         data.Len(), p)
111         }
112         size := uint16(data.Len())
113
114         var res bytes.Buffer // serialized result
115         le := binary.LittleEndian
116
117         id := make([]byte, 8)
118         // uid
119         le.PutUint64(id, p.Uid)
120         res.Write(id)
121         // gid
122         le.PutUint64(id, p.Gid)
123         res.Write(id)
124
125         off := make([]byte, 2)
126         // off_passwd
127         le.PutUint16(off, offPasswd)
128         res.Write(off)
129         // off_gecos
130         le.PutUint16(off, offGecos)
131         res.Write(off)
132         // off_dir
133         le.PutUint16(off, offDir)
134         res.Write(off)
135         // off_shell
136         le.PutUint16(off, offShell)
137         res.Write(off)
138         // data_size
139         le.PutUint16(off, size)
140         res.Write(off)
141
142         res.Write(data.Bytes())
143         // We must pad each entry so that all uint64 at the beginning of the
144         // struct are 8 byte aligned
145         alignBufferTo(&res, 8)
146
147         return res.Bytes(), nil
148 }
149
150 func SerializePasswds(w io.Writer, pws []Passwd) error {
151         // Serialize passwords and store offsets
152         var data bytes.Buffer
153         offsets := make(map[Passwd]uint64)
154         for _, p := range pws {
155                 // TODO: warn about duplicate entries
156                 offsets[p] = uint64(data.Len())
157                 x, err := SerializePasswd(p)
158                 if err != nil {
159                         return err
160                 }
161                 data.Write(x)
162         }
163
164         // Copy to prevent sorting from modifying the argument
165         sorted := make([]Passwd, len(pws))
166         copy(sorted, pws)
167
168         le := binary.LittleEndian
169         tmp := make([]byte, 8)
170
171         // Create index "sorted" in input order, used when iterating over all
172         // passwd entries (getpwent_r); keeping the original order makes
173         // debugging easier
174         var indexOrig bytes.Buffer
175         for _, p := range pws {
176                 le.PutUint64(tmp, offsets[p])
177                 indexOrig.Write(tmp)
178         }
179
180         // Create index sorted after id
181         var indexId bytes.Buffer
182         sort.Slice(sorted, func(i, j int) bool {
183                 return sorted[i].Uid < sorted[j].Uid
184         })
185         for _, p := range sorted {
186                 le.PutUint64(tmp, offsets[p])
187                 indexId.Write(tmp)
188         }
189
190         // Create index sorted after name
191         var indexName bytes.Buffer
192         sort.Slice(sorted, func(i, j int) bool {
193                 return sorted[i].Name < sorted[j].Name
194         })
195         for _, p := range sorted {
196                 le.PutUint64(tmp, offsets[p])
197                 indexName.Write(tmp)
198         }
199
200         // Sanity check
201         if len(pws)*8 != indexOrig.Len() ||
202                 indexOrig.Len() != indexId.Len() ||
203                 indexId.Len() != indexName.Len() {
204                 return fmt.Errorf("indexes have inconsistent length")
205         }
206
207         // Write result
208
209         // magic
210         w.Write([]byte("NSS-CASH"))
211         // version
212         le.PutUint64(tmp, PasswdVersion)
213         w.Write(tmp)
214         // count
215         le.PutUint64(tmp, uint64(len(pws)))
216         w.Write(tmp)
217         // off_orig_index
218         offset := uint64(0)
219         le.PutUint64(tmp, offset)
220         w.Write(tmp)
221         // off_id_index
222         offset += uint64(indexOrig.Len())
223         le.PutUint64(tmp, offset)
224         w.Write(tmp)
225         // off_name_index
226         offset += uint64(indexId.Len())
227         le.PutUint64(tmp, offset)
228         w.Write(tmp)
229         // off_data
230         offset += uint64(indexName.Len())
231         le.PutUint64(tmp, offset)
232         w.Write(tmp)
233
234         _, err := indexOrig.WriteTo(w)
235         if err != nil {
236                 return err
237         }
238         _, err = indexId.WriteTo(w)
239         if err != nil {
240                 return err
241         }
242         _, err = indexName.WriteTo(w)
243         if err != nil {
244                 return err
245         }
246         _, err = data.WriteTo(w)
247         if err != nil {
248                 return err
249         }
250
251         return nil
252 }