]> ruderich.org/simon Gitweb - nsscash/nsscash.git/blobdiff - passwd.go
.github: update upstream actions to latest version
[nsscash/nsscash.git] / passwd.go
index c07e9c73914d378d8964756994542f92eeefb3b1..6da64762ad2dae302dd9e8c6c097bec08e0a46fb 100644 (file)
--- a/passwd.go
+++ b/passwd.go
@@ -1,6 +1,6 @@
 // Parse /etc/passwd files and serialize them
 
-// Copyright (C) 2019  Simon Ruderich
+// Copyright (C) 2019-2021  Simon Ruderich
 //
 // This program is free software: you can redistribute it and/or modify
 // it under the terms of the GNU Affero General Public License as published by
@@ -23,6 +23,7 @@ import (
        "encoding/binary"
        "fmt"
        "io"
+       "math"
        "sort"
        "strconv"
        "strings"
@@ -44,13 +45,24 @@ type Passwd struct {
 }
 
 // ParsePasswds parses a file in the format of /etc/passwd and returns all
-// entries as Passwd structs.
+// entries as slice of Passwd structs.
 func ParsePasswds(r io.Reader) ([]Passwd, error) {
        var res []Passwd
 
-       s := bufio.NewScanner(r)
-       for s.Scan() {
-               t := s.Text()
+       s := bufio.NewReader(r)
+       for {
+               t, err := s.ReadString('\n')
+               if err != nil {
+                       if err == io.EOF {
+                               if t != "" {
+                                       return nil, fmt.Errorf(
+                                               "no newline in last line: %q",
+                                               t)
+                               }
+                               break
+                       }
+                       return nil, err
+               }
 
                x := strings.Split(t, ":")
                if len(x) != 7 {
@@ -73,18 +85,14 @@ func ParsePasswds(r io.Reader) ([]Passwd, error) {
                        Gid:    gid,
                        Gecos:  x[4],
                        Dir:    x[5],
-                       Shell:  x[6],
+                       // ReadString() contains the delimiter
+                       Shell: strings.TrimSuffix(x[6], "\n"),
                })
        }
-       err := s.Err()
-       if err != nil {
-               return nil, err
-       }
-
        return res, nil
 }
 
-func SerializePasswd(p Passwd) []byte {
+func SerializePasswd(p Passwd) ([]byte, error) {
        // Concatenate all (NUL-terminated) strings and store the offsets
        var data bytes.Buffer
        data.Write([]byte(p.Name))
@@ -101,6 +109,11 @@ func SerializePasswd(p Passwd) []byte {
        offShell := uint16(data.Len())
        data.Write([]byte(p.Shell))
        data.WriteByte(0)
+       // Ensure the offsets can fit the length of this entry
+       if data.Len() > math.MaxUint16 {
+               return nil, fmt.Errorf("passwd too large to serialize: %v, %v",
+                       data.Len(), p)
+       }
        size := uint16(data.Len())
 
        var res bytes.Buffer // serialized result
@@ -134,14 +147,9 @@ func SerializePasswd(p Passwd) []byte {
        res.Write(data.Bytes())
        // We must pad each entry so that all uint64 at the beginning of the
        // struct are 8 byte aligned
-       l := res.Len()
-       if l%8 != 0 {
-               for i := 0; i < 8-l%8; i++ {
-                       res.Write([]byte{'0'})
-               }
-       }
+       alignBufferTo(&res, 8)
 
-       return res.Bytes()
+       return res.Bytes(), nil
 }
 
 func SerializePasswds(w io.Writer, pws []Passwd) error {
@@ -151,7 +159,11 @@ func SerializePasswds(w io.Writer, pws []Passwd) error {
        for _, p := range pws {
                // TODO: warn about duplicate entries
                offsets[p] = uint64(data.Len())
-               data.Write(SerializePasswd(p))
+               x, err := SerializePasswd(p)
+               if err != nil {
+                       return err
+               }
+               data.Write(x)
        }
 
        // Copy to prevent sorting from modifying the argument
@@ -191,7 +203,8 @@ func SerializePasswds(w io.Writer, pws []Passwd) error {
        }
 
        // Sanity check
-       if indexOrig.Len() != indexId.Len() ||
+       if len(pws)*8 != indexOrig.Len() ||
+               indexOrig.Len() != indexId.Len() ||
                indexId.Len() != indexName.Len() {
                return fmt.Errorf("indexes have inconsistent length")
        }