]> ruderich.org/simon Gitweb - linux-network-namespace-labs/linux-network-namespace-labs.git/commitdiff
Move setup code to separate file
authorSimon Ruderich <simon@ruderich.org>
Mon, 4 Nov 2024 07:35:16 +0000 (08:35 +0100)
committerSimon Ruderich <simon@ruderich.org>
Mon, 4 Nov 2024 07:35:16 +0000 (08:35 +0100)
main.go
setup.go [new file with mode: 0644]

diff --git a/main.go b/main.go
index 0a1428beb07b7c3200d871f772e738fec4f5b404..eb3f15075a3c3d04a029081522f61c9e5c3420ea 100644 (file)
--- a/main.go
+++ b/main.go
@@ -8,16 +8,8 @@
 package main
 
 import (
-       "bufio"
-       "encoding/json"
-       "fmt"
        "log"
        "os"
-       "os/exec"
-       "path/filepath"
-       "strconv"
-       "strings"
-       "syscall"
 )
 
 func main() {
@@ -39,228 +31,7 @@ func main() {
                }
        }
 
-       for _, node := range cfg.Nodes {
-               log.Printf("Setting up node %q ...", node.Name)
-
-               // One namespace per node, named as the node's name
-               ns := node.Name
-               if netnsExists(ns) {
-                       // Terminate processes in old namespaces
-                       for _, x := range netnsPids(ns) {
-                               log.Printf("  Killing old PID %d", x)
-                               err := syscall.Kill(x, syscall.SIGTERM)
-                               if err != nil {
-                                       log.Fatalf("failed to kill %d: %v", x, err)
-                               }
-                               // Also try SIGHUP to terminate shells which ignore SIGTERM
-                               _ = syscall.Kill(x, syscall.SIGHUP)
-                       }
-                       // Prevent any conflicts with existing data
-                       ip("netns", "del", ns)
-                       // Don't remove anything in /etc/netns/ as the user might store
-                       // configuration there!
-               }
-               ip("netns", "add", ns)
-
-               // Write /etc/netns/$netns/hosts with all known hosts
-               nsCfgPath := filepath.Join("/etc/netns", ns)
-               nsHostsPath := filepath.Join(nsCfgPath, "hosts")
-               log.Printf("  Writing %q", nsHostsPath)
-               err := os.MkdirAll(nsCfgPath, 0755)
-               if err != nil {
-                       log.Fatal(err)
-               }
-               err = writeHosts(cfg, nsHostsPath)
-               if err != nil {
-                       log.Fatal(err)
-               }
-
-               // Enable IPv4 and IPv6 forwarding
-               ip("netns", "exec", ns, "sysctl", "-q", "net.ipv4.ip_forward=1")
-               ip("netns", "exec", ns, "sysctl", "-q", "net.ipv6.conf.all.forwarding=1")
-
-               ip("-n", ns, "link", "set", "lo", "up")
-               // Extra interface for our loopback addresses; keeping them separate
-               // can make things easier (e.g. using the interface for protocols).
-               lo := "lo2"
-               ip("-n", ns, "link", "add", lo, "type", "dummy")
-               ip("-n", ns, "link", "set", lo, "up")
-               for _, x := range node.Loopbacks {
-                       ip("-n", ns, "addr", "add", x.String(), "dev", lo)
-               }
-       }
-
-       log.Printf("Setting up links ...")
-       for _, link := range cfg.Links {
-               nsa := link.A.Node.Name
-               nsb := link.B.Node.Name
-
-               log.Printf(" Link between %q and %q ...", nsa, nsb)
-
-               // Use name of other node for the interface
-               la := nsb
-               lb := nsa
-               // Support multiple links between nodes
-               if ifaceExists(nsa, la) {
-                       la = nextFreeIface(nsa, la)
-               }
-               if ifaceExists(nsb, lb) {
-                       lb = nextFreeIface(nsb, lb)
-               }
-
-               t := "veth"
-               if link.Layer3 {
-                       t = "netkit" // since kernel 6.7, iproute2 6.8
-               }
-               ip("link", "add", "tmpa", "type", t, "peer", "name", "tmpb")
-               ip("link", "set", "tmpa", "netns", nsa)
-               ip("link", "set", "tmpb", "netns", nsb)
-               ip("-n", nsa, "link", "set", "tmpa", "name", la)
-               ip("-n", nsb, "link", "set", "tmpb", "name", lb)
-               for _, x := range link.A.Addrs {
-                       ip("-n", nsa, "addr", "add", x.String(), "dev", la)
-               }
-               for _, x := range link.B.Addrs {
-                       ip("-n", nsb, "addr", "add", x.String(), "dev", lb)
-               }
-               ip("-n", nsa, "link", "set", la, "up")
-               ip("-n", nsb, "link", "set", lb, "up")
-       }
-
-       log.Printf("Starting cmds ...")
-       for _, node := range cfg.Nodes {
-               log.Printf(" For node %q", node.Name)
-               ns := node.Name
-               for _, cmd := range cfg.Cmds {
-                       ip("netns", "exec", ns, "sh", "-c", cmd, "argv0", ns)
-               }
-       }
-}
-
-func ip(args ...string) {
-       xargs := append([]string{"ip"}, args...)
-       log.Printf("  Running %q", xargs)
-
-       cmd := exec.Command("ip", args...)
-       cmd.Stdout = os.Stdout
-       cmd.Stderr = os.Stderr
-       err := cmd.Run()
-       if err != nil {
-               log.Fatalf("failed to run %q: %v", xargs, err)
-       }
-}
-
-func netnsExists(name string) bool {
-       _, err := os.Stat(filepath.Join("/run/netns", name))
-       if err != nil {
-               if os.IsNotExist(err) {
-                       return false
-               }
-               log.Fatal(err)
-       }
-       return true
-}
-
-func netnsPids(netns string) []int {
-       args := []string{"netns", "pids", netns}
-       xargs := append([]string{"ip"}, args...)
-
-       cmd := exec.Command("ip", args...)
-       out, err := cmd.Output()
-       if err != nil {
-               log.Fatalf("failed to run %q: %v", xargs, err)
-       }
-
-       isNewline := func(c rune) bool {
-               return c == '\n'
-       }
-
-       var res []int
-       xs := strings.FieldsFunc(string(out), isNewline)
-       for _, x := range xs {
-               y, err := strconv.Atoi(x)
-               if err != nil {
-                       log.Fatalf("invalid output from %q: %q: %v", xargs, out, err)
-               }
-               res = append(res, y)
-       }
-       return res
-}
-
-func ifaceExists(netns, name string) bool {
-       args := []string{"-n", netns, "-json", "link"}
-       xargs := append([]string{"ip"}, args...)
-
-       cmd := exec.Command("ip", args...)
-       out, err := cmd.Output()
-       if err != nil {
-               log.Fatalf("failed to run %q: %v", xargs, err)
-       }
-
-       var ifaces []struct {
-               Ifname string `json:"ifname"`
-       }
-       err = json.Unmarshal(out, &ifaces)
-       if err != nil {
-               log.Fatalf("failed to parse output from ip (%q): %v", out, err)
-       }
-
-       for _, x := range ifaces {
-               if x.Ifname == name {
-                       return true
-               }
-       }
-       return false
-}
-
-func nextFreeIface(netns, name string) string {
-       i := 2
-       x := name
-       for ifaceExists(netns, x) {
-               x = fmt.Sprintf("%s_%d", name, i)
-               i++
-       }
-       return x
-}
-
-// writeHosts writes a hosts-file (i.e. /etc/hosts) with all known addresses
-// and their corresponding node names.
-func writeHosts(cfg *Config, path string) error {
-       f, err := os.Create(path)
-       if err != nil {
-               return err
-       }
-       defer f.Close()
-       w := bufio.NewWriter(f)
-
-       // Standard entries
-       fmt.Fprintf(w, "127.0.0.1 localhost\n")
-       fmt.Fprintf(w, "::1 localhost ip6-localhost ip6-loopback\n")
-
-       for _, node := range cfg.Nodes {
-               for _, x := range node.Loopbacks {
-                       fmt.Fprintf(w, "%s %s-loop\n", x.String(), node.Name)
-               }
-       }
-
-       for _, link := range cfg.Links {
-               for _, x := range link.A.Addrs {
-                       fmt.Fprintf(w, "%s %s\n", x.Addr().String(), link.A.Node.Name)
-               }
-               for _, x := range link.B.Addrs {
-                       fmt.Fprintf(w, "%s %s\n", x.Addr().String(), link.B.Node.Name)
-               }
-       }
-
-       err = w.Flush()
-       if err != nil {
-               return err
-       }
-       err = f.Sync()
-       if err != nil {
-               return err
-       }
-       return nil
+       mustUp(cfg)
 }
 
 // vi: set noet ts=4 sw=4 sts=4:
diff --git a/setup.go b/setup.go
new file mode 100644 (file)
index 0000000..7d17520
--- /dev/null
+++ b/setup.go
@@ -0,0 +1,246 @@
+// Apply configuration and setup everything.
+
+// SPDX-License-Identifier: GPL-3.0-or-later
+// Copyright (C) 2024  Simon Ruderich
+
+package main
+
+import (
+       "bufio"
+       "encoding/json"
+       "fmt"
+       "log"
+       "os"
+       "os/exec"
+       "path/filepath"
+       "strconv"
+       "strings"
+       "syscall"
+)
+
+func mustUp(cfg *Config) {
+       for _, node := range cfg.Nodes {
+               log.Printf("Setting up node %q ...", node.Name)
+
+               // One namespace per node, named as the node's name
+               ns := node.Name
+               if netnsExists(ns) {
+                       // Terminate processes in old namespaces
+                       for _, x := range netnsPids(ns) {
+                               log.Printf("  Killing old PID %d", x)
+                               err := syscall.Kill(x, syscall.SIGTERM)
+                               if err != nil {
+                                       log.Fatalf("failed to kill %d: %v", x, err)
+                               }
+                               // Also try SIGHUP to terminate shells which ignore SIGTERM
+                               _ = syscall.Kill(x, syscall.SIGHUP)
+                       }
+                       // Prevent any conflicts with existing data
+                       ip("netns", "del", ns)
+                       // Don't remove anything in /etc/netns/ as the user might store
+                       // configuration there!
+               }
+               ip("netns", "add", ns)
+
+               // Write /etc/netns/$netns/hosts with all known hosts
+               nsCfgPath := filepath.Join("/etc/netns", ns)
+               nsHostsPath := filepath.Join(nsCfgPath, "hosts")
+               log.Printf("  Writing %q", nsHostsPath)
+               err := os.MkdirAll(nsCfgPath, 0755)
+               if err != nil {
+                       log.Fatal(err)
+               }
+               err = writeHosts(cfg, nsHostsPath)
+               if err != nil {
+                       log.Fatal(err)
+               }
+
+               // Enable IPv4 and IPv6 forwarding
+               ip("netns", "exec", ns, "sysctl", "-q", "net.ipv4.ip_forward=1")
+               ip("netns", "exec", ns, "sysctl", "-q", "net.ipv6.conf.all.forwarding=1")
+
+               ip("-n", ns, "link", "set", "lo", "up")
+               // Extra interface for our loopback addresses; keeping them separate
+               // can make things easier (e.g. using the interface for protocols).
+               lo := "lo2"
+               ip("-n", ns, "link", "add", lo, "type", "dummy")
+               ip("-n", ns, "link", "set", lo, "up")
+               for _, x := range node.Loopbacks {
+                       ip("-n", ns, "addr", "add", x.String(), "dev", lo)
+               }
+       }
+
+       log.Printf("Setting up links ...")
+       for _, link := range cfg.Links {
+               nsa := link.A.Node.Name
+               nsb := link.B.Node.Name
+
+               log.Printf(" Link between %q and %q ...", nsa, nsb)
+
+               // Use name of other node for the interface
+               la := nsb
+               lb := nsa
+               // Support multiple links between nodes
+               if ifaceExists(nsa, la) {
+                       la = nextFreeIface(nsa, la)
+               }
+               if ifaceExists(nsb, lb) {
+                       lb = nextFreeIface(nsb, lb)
+               }
+
+               t := "veth"
+               if link.Layer3 {
+                       t = "netkit" // since kernel 6.7, iproute2 6.8
+               }
+               ip("link", "add", "tmpa", "type", t, "peer", "name", "tmpb")
+               ip("link", "set", "tmpa", "netns", nsa)
+               ip("link", "set", "tmpb", "netns", nsb)
+               ip("-n", nsa, "link", "set", "tmpa", "name", la)
+               ip("-n", nsb, "link", "set", "tmpb", "name", lb)
+               for _, x := range link.A.Addrs {
+                       ip("-n", nsa, "addr", "add", x.String(), "dev", la)
+               }
+               for _, x := range link.B.Addrs {
+                       ip("-n", nsb, "addr", "add", x.String(), "dev", lb)
+               }
+               ip("-n", nsa, "link", "set", la, "up")
+               ip("-n", nsb, "link", "set", lb, "up")
+       }
+
+       log.Printf("Starting cmds ...")
+       for _, node := range cfg.Nodes {
+               log.Printf(" For node %q", node.Name)
+               ns := node.Name
+               for _, cmd := range cfg.Cmds {
+                       ip("netns", "exec", ns, "sh", "-c", cmd, "argv0", ns)
+               }
+       }
+}
+
+func ip(args ...string) {
+       xargs := append([]string{"ip"}, args...)
+       log.Printf("  Running %q", xargs)
+
+       cmd := exec.Command("ip", args...)
+       cmd.Stdout = os.Stdout
+       cmd.Stderr = os.Stderr
+       err := cmd.Run()
+       if err != nil {
+               log.Fatalf("failed to run %q: %v", xargs, err)
+       }
+}
+
+func netnsExists(name string) bool {
+       _, err := os.Stat(filepath.Join("/run/netns", name))
+       if err != nil {
+               if os.IsNotExist(err) {
+                       return false
+               }
+               log.Fatal(err)
+       }
+       return true
+}
+
+func netnsPids(netns string) []int {
+       args := []string{"netns", "pids", netns}
+       xargs := append([]string{"ip"}, args...)
+
+       cmd := exec.Command("ip", args...)
+       out, err := cmd.Output()
+       if err != nil {
+               log.Fatalf("failed to run %q: %v", xargs, err)
+       }
+
+       isNewline := func(c rune) bool {
+               return c == '\n'
+       }
+
+       var res []int
+       xs := strings.FieldsFunc(string(out), isNewline)
+       for _, x := range xs {
+               y, err := strconv.Atoi(x)
+               if err != nil {
+                       log.Fatalf("invalid output from %q: %q: %v", xargs, out, err)
+               }
+               res = append(res, y)
+       }
+       return res
+}
+
+func ifaceExists(netns, name string) bool {
+       args := []string{"-n", netns, "-json", "link"}
+       xargs := append([]string{"ip"}, args...)
+
+       cmd := exec.Command("ip", args...)
+       out, err := cmd.Output()
+       if err != nil {
+               log.Fatalf("failed to run %q: %v", xargs, err)
+       }
+
+       var ifaces []struct {
+               Ifname string `json:"ifname"`
+       }
+       err = json.Unmarshal(out, &ifaces)
+       if err != nil {
+               log.Fatalf("failed to parse output from ip (%q): %v", out, err)
+       }
+
+       for _, x := range ifaces {
+               if x.Ifname == name {
+                       return true
+               }
+       }
+       return false
+}
+
+func nextFreeIface(netns, name string) string {
+       i := 2
+       x := name
+       for ifaceExists(netns, x) {
+               x = fmt.Sprintf("%s_%d", name, i)
+               i++
+       }
+       return x
+}
+
+// writeHosts writes a hosts-file (i.e. /etc/hosts) with all known addresses
+// and their corresponding node names.
+func writeHosts(cfg *Config, path string) error {
+       f, err := os.Create(path)
+       if err != nil {
+               return err
+       }
+       defer f.Close()
+       w := bufio.NewWriter(f)
+
+       // Standard entries
+       fmt.Fprintf(w, "127.0.0.1 localhost\n")
+       fmt.Fprintf(w, "::1 localhost ip6-localhost ip6-loopback\n")
+
+       for _, node := range cfg.Nodes {
+               for _, x := range node.Loopbacks {
+                       fmt.Fprintf(w, "%s %s-loop\n", x.String(), node.Name)
+               }
+       }
+
+       for _, link := range cfg.Links {
+               for _, x := range link.A.Addrs {
+                       fmt.Fprintf(w, "%s %s\n", x.Addr().String(), link.A.Node.Name)
+               }
+               for _, x := range link.B.Addrs {
+                       fmt.Fprintf(w, "%s %s\n", x.Addr().String(), link.B.Node.Name)
+               }
+       }
+
+       err = w.Flush()
+       if err != nil {
+               return err
+       }
+       err = f.Sync()
+       if err != nil {
+               return err
+       }
+       return nil
+}
+
+// vi: set noet ts=4 sw=4 sts=4: