kilo/pkg/wireguard/conf.go
Julien Viard de Galbert f8f7d0b0c7 add skip on error
2021-07-06 10:22:43 +02:00

565 lines
13 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

// Copyright 2019 the Kilo authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package wireguard
import (
"bufio"
"bytes"
"errors"
"fmt"
"net"
"sort"
"strconv"
"strings"
"time"
"k8s.io/apimachinery/pkg/util/validation"
)
type section string
type key string
const (
separator = "="
dumpSeparator = "\t"
dumpNone = "(none)"
dumpOff = "off"
interfaceSection section = "Interface"
peerSection section = "Peer"
listenPortKey key = "ListenPort"
allowedIPsKey key = "AllowedIPs"
endpointKey key = "Endpoint"
persistentKeepaliveKey key = "PersistentKeepalive"
presharedKeyKey key = "PresharedKey"
privateKeyKey key = "PrivateKey"
publicKeyKey key = "PublicKey"
)
// Conf represents a WireGuard configuration file.
type Conf struct {
Interface *Interface
Peers []*Peer
}
// Interface represents the `interface` section of a WireGuard configuration.
type Interface struct {
ListenPort uint32
PrivateKey []byte
}
// Peer represents a `peer` section of a WireGuard configuration.
type Peer struct {
AllowedIPs []*net.IPNet
Endpoint *Endpoint
PersistentKeepalive int
PresharedKey []byte
PublicKey []byte
// The following fields are part of the runtime information, not the configuration.
LatestHandshake time.Time
}
// DeduplicateIPs eliminates duplicate allowed IPs.
func (p *Peer) DeduplicateIPs() {
var ips []*net.IPNet
seen := make(map[string]struct{})
for _, ip := range p.AllowedIPs {
if _, ok := seen[ip.String()]; ok {
continue
}
ips = append(ips, ip)
seen[ip.String()] = struct{}{}
}
p.AllowedIPs = ips
}
// Endpoint represents an `endpoint` key of a `peer` section.
type Endpoint struct {
DNSOrIP
Port uint32
}
// String prints the string representation of the endpoint.
func (e *Endpoint) String() string {
if e == nil {
return ""
}
dnsOrIP := e.DNSOrIP.String()
if e.IP != nil && len(e.IP) == net.IPv6len {
dnsOrIP = "[" + dnsOrIP + "]"
}
return dnsOrIP + ":" + strconv.FormatUint(uint64(e.Port), 10)
}
// Equal compares two endpoints.
func (e *Endpoint) Equal(b *Endpoint, DNSFirst bool) bool {
if (e == nil) != (b == nil) {
return false
}
if e != nil {
if e.Port != b.Port {
return false
}
if DNSFirst {
// Check the DNS name first if it was resolved.
if e.DNS != b.DNS {
return false
}
if e.DNS == "" && !e.IP.Equal(b.IP) {
return false
}
} else {
// IPs take priority, so check them first.
if !e.IP.Equal(b.IP) {
return false
}
// Only check the DNS name if the IP is empty.
if e.IP == nil && e.DNS != b.DNS {
return false
}
}
}
return true
}
// DNSOrIP represents either a DNS name or an IP address.
// IPs, as they are more specific, are preferred.
type DNSOrIP struct {
DNS string
IP net.IP
}
// String prints the string representation of the struct.
func (d DNSOrIP) String() string {
if d.IP != nil {
return d.IP.String()
}
return d.DNS
}
// Parse parses a given WireGuard configuration file and produces a Conf struct.
func Parse(buf []byte) *Conf {
var (
active section
kv []string
c Conf
err error
iface *Interface
i int
k key
line, v string
peer *Peer
port uint64
)
s := bufio.NewScanner(bytes.NewBuffer(buf))
for s.Scan() {
line = strings.TrimSpace(s.Text())
// Skip comments.
if strings.HasPrefix(line, "#") {
continue
}
// Line is a section title.
if strings.HasPrefix(line, "[") {
if peer != nil {
c.Peers = append(c.Peers, peer)
peer = nil
}
if iface != nil {
c.Interface = iface
iface = nil
}
active = section(strings.TrimSpace(strings.Trim(line, "[]")))
switch active {
case interfaceSection:
iface = new(Interface)
case peerSection:
peer = new(Peer)
}
continue
}
kv = strings.SplitN(line, separator, 2)
if len(kv) != 2 {
continue
}
k = key(strings.TrimSpace(kv[0]))
v = strings.TrimSpace(kv[1])
switch active {
case interfaceSection:
switch k {
case listenPortKey:
port, err = strconv.ParseUint(v, 10, 32)
if err != nil {
continue
}
iface.ListenPort = uint32(port)
case privateKeyKey:
iface.PrivateKey = []byte(v)
}
case peerSection:
switch k {
case allowedIPsKey:
err = peer.parseAllowedIPs(v)
if err != nil {
continue
}
case endpointKey:
err = peer.parseEndpoint(v)
if err != nil {
continue
}
case persistentKeepaliveKey:
i, err = strconv.Atoi(v)
if err != nil {
continue
}
peer.PersistentKeepalive = i
case presharedKeyKey:
peer.PresharedKey = []byte(v)
case publicKeyKey:
peer.PublicKey = []byte(v)
}
}
}
if peer != nil {
c.Peers = append(c.Peers, peer)
}
if iface != nil {
c.Interface = iface
}
return &c
}
// Bytes renders a WireGuard configuration to bytes.
func (c *Conf) Bytes() ([]byte, error) {
var err error
buf := bytes.NewBuffer(make([]byte, 0, 512))
if c.Interface != nil {
if err = writeSection(buf, interfaceSection); err != nil {
return nil, fmt.Errorf("failed to write interface: %v", err)
}
if err = writePKey(buf, privateKeyKey, c.Interface.PrivateKey); err != nil {
return nil, fmt.Errorf("failed to write private key: %v", err)
}
if err = writeValue(buf, listenPortKey, strconv.FormatUint(uint64(c.Interface.ListenPort), 10)); err != nil {
return nil, fmt.Errorf("failed to write listen port: %v", err)
}
}
for i, p := range c.Peers {
// Add newlines to make the formatting nicer.
if i == 0 && c.Interface != nil || i != 0 {
if err = buf.WriteByte('\n'); err != nil {
return nil, err
}
}
if err = writeSection(buf, peerSection); err != nil {
return nil, fmt.Errorf("failed to write interface: %v", err)
}
if err = writeAllowedIPs(buf, p.AllowedIPs); err != nil {
return nil, fmt.Errorf("failed to write allowed IPs: %v", err)
}
if err = writeEndpoint(buf, p.Endpoint); err != nil {
return nil, fmt.Errorf("failed to write endpoint: %v", err)
}
if err = writeValue(buf, persistentKeepaliveKey, strconv.Itoa(p.PersistentKeepalive)); err != nil {
return nil, fmt.Errorf("failed to write persistent keepalive: %v", err)
}
if err = writePKey(buf, presharedKeyKey, p.PresharedKey); err != nil {
return nil, fmt.Errorf("failed to write preshared key: %v", err)
}
if err = writePKey(buf, publicKeyKey, p.PublicKey); err != nil {
return nil, fmt.Errorf("failed to write public key: %v", err)
}
}
return buf.Bytes(), nil
}
// Equal checks if two WireGuard configurations are equivalent.
func (c *Conf) Equal(b *Conf) bool {
if (c.Interface == nil) != (b.Interface == nil) {
return false
}
if c.Interface != nil {
if c.Interface.ListenPort != b.Interface.ListenPort || !bytes.Equal(c.Interface.PrivateKey, b.Interface.PrivateKey) {
return false
}
}
if len(c.Peers) != len(b.Peers) {
return false
}
sortPeers(c.Peers)
sortPeers(b.Peers)
for i := range c.Peers {
if len(c.Peers[i].AllowedIPs) != len(b.Peers[i].AllowedIPs) {
return false
}
sortCIDRs(c.Peers[i].AllowedIPs)
sortCIDRs(b.Peers[i].AllowedIPs)
for j := range c.Peers[i].AllowedIPs {
if c.Peers[i].AllowedIPs[j].String() != b.Peers[i].AllowedIPs[j].String() {
return false
}
}
if !c.Peers[i].Endpoint.Equal(b.Peers[i].Endpoint, false) {
return false
}
if c.Peers[i].PersistentKeepalive != b.Peers[i].PersistentKeepalive || !bytes.Equal(c.Peers[i].PresharedKey, b.Peers[i].PresharedKey) || !bytes.Equal(c.Peers[i].PublicKey, b.Peers[i].PublicKey) {
return false
}
}
return true
}
func sortPeers(peers []*Peer) {
sort.Slice(peers, func(i, j int) bool {
if bytes.Compare(peers[i].PublicKey, peers[j].PublicKey) < 0 {
return true
}
return false
})
}
func sortCIDRs(cidrs []*net.IPNet) {
sort.Slice(cidrs, func(i, j int) bool {
return cidrs[i].String() < cidrs[j].String()
})
}
func writeAllowedIPs(buf *bytes.Buffer, ais []*net.IPNet) error {
if len(ais) == 0 {
return nil
}
var err error
if err = writeKey(buf, allowedIPsKey); err != nil {
return err
}
for i := range ais {
if i != 0 {
if _, err = buf.WriteString(", "); err != nil {
return err
}
}
if _, err = buf.WriteString(ais[i].String()); err != nil {
return err
}
}
return buf.WriteByte('\n')
}
func writePKey(buf *bytes.Buffer, k key, b []byte) error {
if len(b) == 0 {
return nil
}
var err error
if err = writeKey(buf, k); err != nil {
return err
}
if _, err = buf.Write(b); err != nil {
return err
}
return buf.WriteByte('\n')
}
func writeValue(buf *bytes.Buffer, k key, v string) error {
var err error
if err = writeKey(buf, k); err != nil {
return err
}
if _, err = buf.WriteString(v); err != nil {
return err
}
return buf.WriteByte('\n')
}
func writeEndpoint(buf *bytes.Buffer, e *Endpoint) error {
if e == nil {
return nil
}
var err error
if err = writeKey(buf, endpointKey); err != nil {
return err
}
if _, err = buf.WriteString(e.String()); err != nil {
return err
}
return buf.WriteByte('\n')
}
func writeSection(buf *bytes.Buffer, s section) error {
var err error
if err = buf.WriteByte('['); err != nil {
return err
}
if _, err = buf.WriteString(string(s)); err != nil {
return err
}
if err = buf.WriteByte(']'); err != nil {
return err
}
return buf.WriteByte('\n')
}
func writeKey(buf *bytes.Buffer, k key) error {
var err error
if _, err = buf.WriteString(string(k)); err != nil {
return err
}
_, err = buf.WriteString(" = ")
return err
}
var (
errParseEndpoint = errors.New("could not parse Endpoint")
)
func (p *Peer) parseEndpoint(v string) error {
var (
kv []string
err error
ip, ip4 net.IP
port uint64
)
kv = strings.Split(v, ":")
if len(kv) < 2 {
return errParseEndpoint
}
port, err = strconv.ParseUint(kv[len(kv)-1], 10, 32)
if err != nil {
return err
}
d := DNSOrIP{}
ip = net.ParseIP(strings.Trim(strings.Join(kv[:len(kv)-1], ":"), "[]"))
if ip == nil {
if len(validation.IsDNS1123Subdomain(kv[0])) != 0 {
return errParseEndpoint
}
d.DNS = kv[0]
} else {
if ip4 = ip.To4(); ip4 != nil {
d.IP = ip4
} else {
d.IP = ip.To16()
}
}
p.Endpoint = &Endpoint{
DNSOrIP: d,
Port: uint32(port),
}
return nil
}
func (p *Peer) parseAllowedIPs(v string) error {
var (
ai *net.IPNet
kv []string
err error
i int
ip, ip4 net.IP
)
kv = strings.Split(v, ",")
for i = range kv {
ip, ai, err = net.ParseCIDR(strings.TrimSpace(kv[i]))
if err != nil {
return err
}
if ip4 = ip.To4(); ip4 != nil {
ip = ip4
} else {
ip = ip.To16()
}
ai.IP = ip
p.AllowedIPs = append(p.AllowedIPs, ai)
}
return nil
}
// ParseDump parses a given WireGuard dump and produces a Conf struct.
func ParseDump(buf []byte) *Conf {
// from man wg, show section:
// If dump is specified, then several lines are printed;
// the first contains in order separated by tab: private-key, public-key, listen-port, fwmark.
// Subsequent lines are printed for each peer and contain in order separated by tab:
// public-key, preshared-key, endpoint, allowed-ips, latest-handshake, transfer-rx, transfer-tx, persistent-keepalive.
var (
active section
values []string
c Conf
err error
iface *Interface
i int
peer *Peer
port uint64
sec int64
)
// First line is Interface
active = interfaceSection
s := bufio.NewScanner(bytes.NewBuffer(buf))
for s.Scan() {
values = strings.Split(s.Text(), dumpSeparator)
switch active {
case interfaceSection:
if len(values) < 4 {
break
}
iface = new(Interface)
iface.PrivateKey = []byte(values[0])
port, err = strconv.ParseUint(values[2], 10, 32)
if err == nil {
iface.ListenPort = uint32(port)
}
c.Interface = iface
// Next lines are Peers
active = peerSection
case peerSection:
if len(values) < 8 {
break
}
peer = new(Peer)
peer.PublicKey = []byte(values[0])
if values[1] != dumpNone {
peer.PresharedKey = []byte(values[1])
}
if values[2] != dumpNone {
peer.parseEndpoint(values[2])
}
if values[3] != dumpNone {
peer.parseAllowedIPs(values[3])
}
if values[4] != "0" {
sec, err = strconv.ParseInt(values[4], 10, 64)
if err == nil {
peer.LatestHandshake = time.Unix(sec, 0)
}
}
if values[7] != dumpOff {
i, err = strconv.Atoi(values[7])
if err == nil {
peer.PersistentKeepalive = i
}
}
c.Peers = append(c.Peers, peer)
peer = nil
}
}
return &c
}