]> ruderich.org/simon Gitweb - nsscash/nsscash.git/blob - passwd.go
nsscash: handle errors in SerializePasswd(), SerializeGroup()
[nsscash/nsscash.git] / passwd.go
1 // Parse /etc/passwd files and serialize them
2
3 // Copyright (C) 2019  Simon Ruderich
4 //
5 // This program is free software: you can redistribute it and/or modify
6 // it under the terms of the GNU Affero General Public License as published by
7 // the Free Software Foundation, either version 3 of the License, or
8 // (at your option) any later version.
9 //
10 // This program is distributed in the hope that it will be useful,
11 // but WITHOUT ANY WARRANTY; without even the implied warranty of
12 // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
13 // GNU Affero General Public License for more details.
14 //
15 // You should have received a copy of the GNU Affero General Public License
16 // along with this program.  If not, see <https://www.gnu.org/licenses/>.
17
18 package main
19
20 import (
21         "bufio"
22         "bytes"
23         "encoding/binary"
24         "fmt"
25         "io"
26         "sort"
27         "strconv"
28         "strings"
29
30         "github.com/pkg/errors"
31 )
32
33 // Version written in SerializePasswds()
34 const PasswdVersion = 1
35
36 type Passwd struct {
37         Name   string
38         Passwd string
39         Uid    uint64
40         Gid    uint64
41         Gecos  string
42         Dir    string
43         Shell  string
44 }
45
46 // ParsePasswds parses a file in the format of /etc/passwd and returns all
47 // entries as slice of Passwd structs.
48 func ParsePasswds(r io.Reader) ([]Passwd, error) {
49         var res []Passwd
50
51         s := bufio.NewReader(r)
52         for {
53                 t, err := s.ReadString('\n')
54                 if err != nil {
55                         if err == io.EOF {
56                                 break
57                         }
58                         return nil, err
59                 }
60
61                 x := strings.Split(t, ":")
62                 if len(x) != 7 {
63                         return nil, fmt.Errorf("invalid line %q", t)
64                 }
65
66                 uid, err := strconv.ParseUint(x[2], 10, 64)
67                 if err != nil {
68                         return nil, errors.Wrapf(err, "invalid uid in line %q", t)
69                 }
70                 gid, err := strconv.ParseUint(x[3], 10, 64)
71                 if err != nil {
72                         return nil, errors.Wrapf(err, "invalid gid in line %q", t)
73                 }
74
75                 res = append(res, Passwd{
76                         Name:   x[0],
77                         Passwd: x[1],
78                         Uid:    uid,
79                         Gid:    gid,
80                         Gecos:  x[4],
81                         Dir:    x[5],
82                         // ReadString() contains the delimiter
83                         Shell: strings.TrimSuffix(x[6], "\n"),
84                 })
85         }
86         return res, nil
87 }
88
89 func SerializePasswd(p Passwd) ([]byte, error) {
90         // Concatenate all (NUL-terminated) strings and store the offsets
91         var data bytes.Buffer
92         data.Write([]byte(p.Name))
93         data.WriteByte(0)
94         offPasswd := uint16(data.Len())
95         data.Write([]byte(p.Passwd))
96         data.WriteByte(0)
97         offGecos := uint16(data.Len())
98         data.Write([]byte(p.Gecos))
99         data.WriteByte(0)
100         offDir := uint16(data.Len())
101         data.Write([]byte(p.Dir))
102         data.WriteByte(0)
103         offShell := uint16(data.Len())
104         data.Write([]byte(p.Shell))
105         data.WriteByte(0)
106         size := uint16(data.Len())
107
108         var res bytes.Buffer // serialized result
109         le := binary.LittleEndian
110
111         id := make([]byte, 8)
112         // uid
113         le.PutUint64(id, p.Uid)
114         res.Write(id)
115         // gid
116         le.PutUint64(id, p.Gid)
117         res.Write(id)
118
119         off := make([]byte, 2)
120         // off_passwd
121         le.PutUint16(off, offPasswd)
122         res.Write(off)
123         // off_gecos
124         le.PutUint16(off, offGecos)
125         res.Write(off)
126         // off_dir
127         le.PutUint16(off, offDir)
128         res.Write(off)
129         // off_shell
130         le.PutUint16(off, offShell)
131         res.Write(off)
132         // data_size
133         le.PutUint16(off, size)
134         res.Write(off)
135
136         res.Write(data.Bytes())
137         // We must pad each entry so that all uint64 at the beginning of the
138         // struct are 8 byte aligned
139         alignBufferTo(&res, 8)
140
141         return res.Bytes(), nil
142 }
143
144 func SerializePasswds(w io.Writer, pws []Passwd) error {
145         // Serialize passwords and store offsets
146         var data bytes.Buffer
147         offsets := make(map[Passwd]uint64)
148         for _, p := range pws {
149                 // TODO: warn about duplicate entries
150                 offsets[p] = uint64(data.Len())
151                 x, err := SerializePasswd(p)
152                 if err != nil {
153                         return err
154                 }
155                 data.Write(x)
156         }
157
158         // Copy to prevent sorting from modifying the argument
159         sorted := make([]Passwd, len(pws))
160         copy(sorted, pws)
161
162         le := binary.LittleEndian
163         tmp := make([]byte, 8)
164
165         // Create index "sorted" in input order, used when iterating over all
166         // passwd entries (getpwent_r); keeping the original order makes
167         // debugging easier
168         var indexOrig bytes.Buffer
169         for _, p := range pws {
170                 le.PutUint64(tmp, offsets[p])
171                 indexOrig.Write(tmp)
172         }
173
174         // Create index sorted after id
175         var indexId bytes.Buffer
176         sort.Slice(sorted, func(i, j int) bool {
177                 return sorted[i].Uid < sorted[j].Uid
178         })
179         for _, p := range sorted {
180                 le.PutUint64(tmp, offsets[p])
181                 indexId.Write(tmp)
182         }
183
184         // Create index sorted after name
185         var indexName bytes.Buffer
186         sort.Slice(sorted, func(i, j int) bool {
187                 return sorted[i].Name < sorted[j].Name
188         })
189         for _, p := range sorted {
190                 le.PutUint64(tmp, offsets[p])
191                 indexName.Write(tmp)
192         }
193
194         // Sanity check
195         if len(pws)*8 != indexOrig.Len() ||
196                 indexOrig.Len() != indexId.Len() ||
197                 indexId.Len() != indexName.Len() {
198                 return fmt.Errorf("indexes have inconsistent length")
199         }
200
201         // Write result
202
203         // magic
204         w.Write([]byte("NSS-CASH"))
205         // version
206         le.PutUint64(tmp, PasswdVersion)
207         w.Write(tmp)
208         // count
209         le.PutUint64(tmp, uint64(len(pws)))
210         w.Write(tmp)
211         // off_orig_index
212         offset := uint64(0)
213         le.PutUint64(tmp, offset)
214         w.Write(tmp)
215         // off_id_index
216         offset += uint64(indexOrig.Len())
217         le.PutUint64(tmp, offset)
218         w.Write(tmp)
219         // off_name_index
220         offset += uint64(indexId.Len())
221         le.PutUint64(tmp, offset)
222         w.Write(tmp)
223         // off_data
224         offset += uint64(indexName.Len())
225         le.PutUint64(tmp, offset)
226         w.Write(tmp)
227
228         _, err := indexOrig.WriteTo(w)
229         if err != nil {
230                 return err
231         }
232         _, err = indexId.WriteTo(w)
233         if err != nil {
234                 return err
235         }
236         _, err = indexName.WriteTo(w)
237         if err != nil {
238                 return err
239         }
240         _, err = data.WriteTo(w)
241         if err != nil {
242                 return err
243         }
244
245         return nil
246 }