package netlink

import (
	"math/rand"
	"sync"
	"sync/atomic"
	"syscall"
	"time"

	"golang.org/x/net/bpf"
)

// A Conn is a connection to netlink.  A Conn can be used to send and
// receive messages to and from netlink.
//
// A Conn is safe for concurrent use, but to avoid contention in
// high-throughput applications, the caller should almost certainly create a
// pool of Conns and distribute them among workers.
//
// A Conn is capable of manipulating netlink subsystems from within a specific
// Linux network namespace, but special care must be taken when doing so. See
// the documentation of Config for details.
type Conn struct {
	// Atomics must come first.
	//
	// seq is an atomically incremented integer used to provide sequence
	// numbers for outgoing messages whose Header.Sequence is 0. NewConn seeds
	// it with a random value and nextSequence advances it.
	seq uint32

	// mu serializes access to the netlink socket for the request/response
	// transaction within Execute. Execute holds the write lock for its whole
	// send/receive/validate sequence. Send, SendMessages, and Receive hold
	// the read lock, so they wait for an in-progress Execute but not for each
	// other. Close deliberately does not take mu.
	mu sync.RWMutex

	// sock is the operating system-specific implementation of
	// a netlink sockets connection.
	sock Socket

	// pid is the PID assigned by netlink. fixMsg copies it into any outgoing
	// message whose Header.PID is 0.
	pid uint32

	// d provides debugging capabilities for a Conn if not nil. NewConn sets
	// it only when debugArgs is non-empty.
	d *debugger
}

// A Socket is an operating-system specific implementation of netlink
// sockets used by Conn.
//
// Conn forwards its core operations to these methods. Optional capabilities
// (multicast groups, BPF filters, deadlines, socket options, buffer sizes,
// and syscall.Conn) are discovered with type assertions on the Socket. When
// the Socket lacks one, Conn returns a notSupported error.
type Socket interface {
	// Close closes the socket. Conn.Close calls it without holding the
	// Conn's lock, so it may run concurrently with the other methods.
	Close() error

	// Send sends one message whose header fields have already been
	// populated by the Conn.
	Send(m Message) error

	// SendMessages sends several messages whose header fields have already
	// been populated by the Conn.
	SendMessages(m []Message) error

	// Receive returns the messages obtained from one read. Conn calls it
	// repeatedly while a multi-part reply is still in progress.
	Receive() ([]Message, error)
}

// Dial opens a connection to netlink for the given netlink family.
//
// config carries optional settings for the Conn; when it is nil a default
// configuration is used.
func Dial(family int, config *Config) (*Conn, error) {
	// The OS-specific dial creates the Socket and its PID; any error it
	// reports is returned unchanged.
	sock, pid, err := dial(family, config)
	if err == nil {
		return NewConn(sock, pid), nil
	}

	return nil, err
}

// NewConn builds a Conn on top of sock. pid is stamped into every outgoing
// message whose Header.PID is left at 0.
//
// NewConn is primarily useful for tests. Most applications should use
// Dial instead.
func NewConn(sock Socket, pid uint32) *Conn {
	conn := &Conn{
		sock: sock,
		pid:  pid,
	}

	// Start the sequence counter at a random point rather than zero, using a
	// generator seeded from the current wall-clock time.
	conn.seq = rand.New(rand.NewSource(time.Now().UnixNano())).Uint32()

	// A debugger is attached only when debugging arguments were supplied.
	if len(debugArgs) > 0 {
		conn.d = newDebugger(debugArgs)
	}

	return conn
}

// debug calls fn with the Conn's debugger; it is a no-op when no debugger
// has been configured.
func (c *Conn) debug(fn func(d *debugger)) {
	if d := c.d; d != nil {
		fn(d)
	}
}

// Close closes the connection and unblocks any pending read operations.
func (c *Conn) Close() error {
	// c.mu is intentionally not acquired: Close must be able to interrupt
	// system calls that are already blocked, such as a Receive waiting on a
	// multicast group message. Concurrent operations on the netlink socket
	// itself are left to the kernel.
	err := c.sock.Close()
	return newOpError("close", err)
}

// Execute performs one request/reply exchange. It sends m exactly as Send
// would, gathers the replies exactly as Receive would, and then checks those
// replies against the sent request with Validate.
//
// The Conn's write lock is held for the entire call. Concurrent Send,
// SendMessages, and Receive calls therefore wait until Execute returns,
// which keeps each netlink request paired with its replies.
//
// See Send, Receive, and Validate for the details of each step.
func (c *Conn) Execute(m Message) ([]Message, error) {
	c.mu.Lock()
	defer c.mu.Unlock()

	// The locked* variants assume c.mu is already held.
	sent, err := c.lockedSend(m)
	if err != nil {
		return nil, err
	}

	replies, err := c.lockedReceive()
	if err != nil {
		return nil, err
	}

	// Validate against the populated request (with its assigned sequence
	// number and PID), not the caller's original m.
	if err = Validate(sent, replies); err != nil {
		return nil, err
	}

	return replies, nil
}

// SendMessages sends multiple Messages to netlink. Each Header's Length,
// Sequence and PID fields are handled the same way as in Send.
//
// Unlike Send, the headers are filled in place: the elements of msgs are
// modified, and on success that same slice is returned.
func (c *Conn) SendMessages(msgs []Message) ([]Message, error) {
	// The read lock makes this call wait for any in-progress Execute.
	c.mu.RLock()
	defer c.mu.RUnlock()

	for i := 0; i < len(msgs); i++ {
		msg := &msgs[i]
		c.fixMsg(msg, nlmsgLength(len(msg.Data)))
	}

	c.debug(func(d *debugger) {
		for i := range msgs {
			d.debugf(1, "send msgs: %+v", msgs[i])
		}
	})

	err := c.sock.SendMessages(msgs)
	if err == nil {
		return msgs, nil
	}

	c.debug(func(d *debugger) {
		d.debugf(1, "send msgs: err: %v", err)
	})

	return nil, newOpError("send-messages", err)
}

// Send sends a single Message to netlink. Usually the caller leaves the
// Header's Length, Sequence, and PID fields at 0 so that Send can fill them
// in. On success, Send returns a copy of the Message with every field
// populated, suitable for later use with Validate.
//
// A zero Header.Length is replaced by the correct aligned length of the
// Message, payload included.
//
// A zero Header.Sequence is replaced by this connection's next sequence
// number.
//
// A zero Header.PID is replaced by the PID netlink assigned to this
// connection.
func (c *Conn) Send(m Message) (Message, error) {
	// Taking the read lock makes Send wait for any in-progress Execute.
	c.mu.RLock()
	defer c.mu.RUnlock()

	sent, err := c.lockedSend(m)
	return sent, err
}

// lockedSend is the body of Send. The caller must already hold c.mu, either
// the read lock (Send) or the write lock (Execute). Concurrent reads and
// writes on the netlink socket itself are left to the kernel.
//
// m is received by value, so filling in its header never affects the
// caller's copy; the populated message is returned on success.
func (c *Conn) lockedSend(m Message) (Message, error) {
	c.fixMsg(&m, nlmsgLength(len(m.Data)))

	c.debug(func(d *debugger) { d.debugf(1, "send: %+v", m) })

	err := c.sock.Send(m)
	if err == nil {
		return m, nil
	}

	c.debug(func(d *debugger) { d.debugf(1, "send: err: %v", err) })

	return Message{}, newOpError("send", err)
}

// Receive reads one or more messages from netlink. A multi-part reply is
// collected transparently and returned as one slice of Messages; a trailing
// empty "multi-part done" message is dropped.
//
// If any received message indicates a netlink error, that error is returned.
func (c *Conn) Receive() ([]Message, error) {
	// Taking the read lock makes Receive wait for any in-progress Execute.
	c.mu.RLock()
	defer c.mu.RUnlock()

	msgs, err := c.lockedReceive()
	return msgs, err
}

// lockedReceive is the body of Receive. The caller must already hold c.mu,
// either the read lock (Receive) or the write lock (Execute). Concurrent
// reads and writes on the netlink socket itself are left to the kernel.
//
// When the final message carries the Multi flag and is of type Done, it is
// removed from the returned slice.
func (c *Conn) lockedReceive() ([]Message, error) {
	msgs, err := c.receive()
	if err != nil {
		c.debug(func(d *debugger) { d.debugf(1, "recv: err: %v", err) })
		return nil, err
	}

	c.debug(func(d *debugger) {
		for i := range msgs {
			d.debugf(1, "recv: %+v", msgs[i])
		}
	})

	// receive may legitimately yield no messages (for example under nltest),
	// in which case there is nothing to trim.
	n := len(msgs)
	if n == 0 {
		return msgs, nil
	}

	last := msgs[n-1].Header
	if last.Flags&Multi != 0 && last.Type == Done {
		msgs = msgs[:n-1]
	}

	return msgs, nil
}

// receive is the internal implementation of Conn.Receive. It reads from the
// socket in a loop (not recursively) and returns every message read, in
// order, including any trailing Done message.
//
// After each read, the last message in that batch carrying the Multi flag
// decides whether to keep reading: if it is not a Done message, more parts
// are expected. A batch with no Multi-flagged message ends the loop.
func (c *Conn) receive() ([]Message, error) {
	// NB: All non-nil errors returned from this function *must* be of type
	// OpError in order to maintain the appropriate contract with callers of
	// this package.
	//
	// This contract also applies to functions called within this function,
	// such as checkMessage.

	var all []Message
	for {
		batch, err := c.sock.Receive()
		if err != nil {
			return nil, newOpError("receive", err)
		}

		// pending reports whether a multi-part reply is still unfinished
		// after this batch.
		pending := false
		for _, m := range batch {
			if err := checkMessage(m); err != nil {
				return nil, err
			}

			if m.Header.Flags&Multi != 0 {
				pending = m.Header.Type != Done
			}
		}

		all = append(all, batch...)

		if !pending {
			return all, nil
		}
	}
}

// A groupJoinLeaver is a Socket that supports joining and leaving
// netlink multicast groups.
//
// Conn.JoinGroup and Conn.LeaveGroup type-assert the Conn's Socket to this
// interface and return a notSupported error when the assertion fails.
type groupJoinLeaver interface {
	Socket
	JoinGroup(group uint32) error
	LeaveGroup(group uint32) error
}

// JoinGroup subscribes the Conn to the netlink multicast group with the given
// ID. It returns a notSupported error when the underlying Socket cannot join
// groups.
func (c *Conn) JoinGroup(group uint32) error {
	if gjl, ok := c.sock.(groupJoinLeaver); ok {
		return newOpError("join-group", gjl.JoinGroup(group))
	}

	return notSupported("join-group")
}

// LeaveGroup unsubscribes the Conn from the netlink multicast group with the
// given ID. It returns a notSupported error when the underlying Socket cannot
// leave groups.
func (c *Conn) LeaveGroup(group uint32) error {
	if gjl, ok := c.sock.(groupJoinLeaver); ok {
		return newOpError("leave-group", gjl.LeaveGroup(group))
	}

	return notSupported("leave-group")
}

// A bpfSetter is a Socket that supports setting and removing BPF filters.
//
// Conn.SetBPF and Conn.RemoveBPF type-assert the Conn's Socket to this
// interface and return a notSupported error when the assertion fails.
type bpfSetter interface {
	Socket
	bpf.Setter
	RemoveBPF() error
}

// SetBPF attaches the assembled BPF program filter to the Conn. It returns a
// notSupported error when the underlying Socket cannot accept BPF filters.
func (c *Conn) SetBPF(filter []bpf.RawInstruction) error {
	if bs, ok := c.sock.(bpfSetter); ok {
		return newOpError("set-bpf", bs.SetBPF(filter))
	}

	return notSupported("set-bpf")
}

// RemoveBPF detaches the BPF filter from the Conn. It returns a notSupported
// error when the underlying Socket cannot manage BPF filters.
func (c *Conn) RemoveBPF() error {
	if bs, ok := c.sock.(bpfSetter); ok {
		return newOpError("remove-bpf", bs.RemoveBPF())
	}

	return notSupported("remove-bpf")
}

// A deadlineSetter is a Socket that supports setting deadlines.
//
// Conn.SetDeadline, Conn.SetReadDeadline, and Conn.SetWriteDeadline
// type-assert the Conn's Socket to this interface and return a notSupported
// error when the assertion fails.
type deadlineSetter interface {
	Socket
	SetDeadline(time.Time) error
	SetReadDeadline(time.Time) error
	SetWriteDeadline(time.Time) error
}

// SetDeadline sets both the read and the write deadline of the connection to
// t. It returns a notSupported error when the underlying Socket cannot set
// deadlines.
func (c *Conn) SetDeadline(t time.Time) error {
	if ds, ok := c.sock.(deadlineSetter); ok {
		return newOpError("set-deadline", ds.SetDeadline(t))
	}

	return notSupported("set-deadline")
}

// SetReadDeadline sets the read deadline of the connection to t. It returns a
// notSupported error when the underlying Socket cannot set deadlines.
func (c *Conn) SetReadDeadline(t time.Time) error {
	if ds, ok := c.sock.(deadlineSetter); ok {
		return newOpError("set-read-deadline", ds.SetReadDeadline(t))
	}

	return notSupported("set-read-deadline")
}

// SetWriteDeadline sets the write deadline of the connection to t. It returns
// a notSupported error when the underlying Socket cannot set deadlines.
func (c *Conn) SetWriteDeadline(t time.Time) error {
	if ds, ok := c.sock.(deadlineSetter); ok {
		return newOpError("set-write-deadline", ds.SetWriteDeadline(t))
	}

	return notSupported("set-write-deadline")
}

// A ConnOption is a boolean option that may be set for a Conn.
type ConnOption int

// Possible ConnOption values.  These constants correspond to the Linux
// setsockopt boolean options for netlink sockets.
//
// Their numeric values are assigned by iota starting at 0 and do not
// necessarily match the kernel's option numbers. NOTE(review): presumably the
// OS-specific Socket's SetOption translates them to the kernel values; confirm
// in the platform implementation.
const (
	PacketInfo ConnOption = iota
	BroadcastError
	NoENOBUFS
	ListenAllNSID
	CapAcknowledge
	ExtendedAcknowledge
	GetStrictCheck
)

// An optionSetter is a Socket that supports setting netlink options.
//
// Conn.SetOption type-asserts the Conn's Socket to this interface and
// returns a notSupported error when the assertion fails.
type optionSetter interface {
	Socket
	SetOption(option ConnOption, enable bool) error
}

// SetOption turns the netlink socket option on (enable == true) or off for
// the Conn. It returns a notSupported error when the underlying Socket cannot
// set options.
func (c *Conn) SetOption(option ConnOption, enable bool) error {
	if os, ok := c.sock.(optionSetter); ok {
		return newOpError("set-option", os.SetOption(option, enable))
	}

	return notSupported("set-option")
}

// A bufferSetter is a Socket that supports setting connection buffer sizes.
//
// Conn.SetReadBuffer and Conn.SetWriteBuffer type-assert the Conn's Socket to
// this interface and return a notSupported error when the assertion fails.
type bufferSetter interface {
	Socket
	SetReadBuffer(bytes int) error
	SetWriteBuffer(bytes int) error
}

// SetReadBuffer sets the size, in bytes, of the operating system's receive
// buffer for the Conn. It returns a notSupported error when the underlying
// Socket cannot resize its buffers.
func (c *Conn) SetReadBuffer(bytes int) error {
	if bs, ok := c.sock.(bufferSetter); ok {
		return newOpError("set-read-buffer", bs.SetReadBuffer(bytes))
	}

	return notSupported("set-read-buffer")
}

// SetWriteBuffer sets the size, in bytes, of the operating system's transmit
// buffer for the Conn. It returns a notSupported error when the underlying
// Socket cannot resize its buffers.
func (c *Conn) SetWriteBuffer(bytes int) error {
	if bs, ok := c.sock.(bufferSetter); ok {
		return newOpError("set-write-buffer", bs.SetWriteBuffer(bytes))
	}

	return notSupported("set-write-buffer")
}

// A syscallConner is a Socket that supports syscall.Conn.
//
// Conn.SyscallConn type-asserts the Conn's Socket to this interface and
// returns a notSupported error when the assertion fails.
type syscallConner interface {
	Socket
	syscall.Conn
}

// Compile-time check that *Conn implements syscall.Conn. A typed nil pointer
// is enough for the check; no Conn value needs to be constructed.
var _ syscall.Conn = (*Conn)(nil)

// SyscallConn returns a raw network connection, satisfying the syscall.Conn
// interface.
//
// It is meant for advanced use cases, such as reading or setting arbitrary
// socket options through the netlink socket's file descriptor.
//
// After calling it, the caller is responsible for ensuring that operations
// performed via the Conn and via the syscall.RawConn do not conflict.
//
// It returns a notSupported error when the underlying Socket does not
// implement syscall.Conn. Any other error comes straight from the Socket,
// without newOpError wrapping.
func (c *Conn) SyscallConn() (syscall.RawConn, error) {
	if sc, ok := c.sock.(syscallConner); ok {
		// TODO(mdlayher): mutex or similar to enforce syscall.RawConn contract of
		// FD remaining valid for duration of calls?
		return sc.SyscallConn()
	}

	return nil, notSupported("syscall-conn")
}

// fixMsg fills in the zero-valued header fields of m as described by Send.
// A zero Length becomes msgLen rounded up by nlmsgAlign, a zero Sequence
// becomes the next sequence number, and a zero PID becomes the Conn's PID.
// Non-zero fields are left untouched.
func (c *Conn) fixMsg(m *Message, msgLen int) {
	h := &m.Header

	if h.Length == 0 {
		h.Length = uint32(nlmsgAlign(msgLen))
	}

	if h.Sequence == 0 {
		h.Sequence = c.nextSequence()
	}

	if h.PID == 0 {
		h.PID = c.pid
	}
}

// nextSequence atomically advances the Conn's sequence counter by one and
// returns the new value, so concurrent callers each receive a distinct number
// (the uint32 counter wraps around on overflow).
func (c *Conn) nextSequence() uint32 {
	next := atomic.AddUint32(&c.seq, 1)
	return next
}

// Validate checks each reply against request and returns an error for the
// first reply whose sequence number or PID does not match. An empty replies
// slice is always valid.
//
// A zero sequence number in request disables the sequence check, which is
// the case when validating a multicast reply. A zero PID in either the
// request or the reply disables the PID check. That happens when validating
// a multicast reply, before netlink has assigned a PID, or for
// kernel-originated multicast replies.
func Validate(request Message, replies []Message) error {
	wantSeq := request.Header.Sequence
	wantPID := request.Header.PID

	for _, reply := range replies {
		gotSeq, gotPID := reply.Header.Sequence, reply.Header.PID

		if wantSeq != 0 && gotSeq != wantSeq {
			return newOpError("validate", errMismatchedSequence)
		}

		if wantPID != 0 && gotPID != 0 && gotPID != wantPID {
			return newOpError("validate", errMismatchedPID)
		}
	}

	return nil
}

// Config contains options for a Conn. A nil *Config passed to Dial selects
// the default configuration.
type Config struct {
	// Groups is a bitmask which specifies multicast groups. If set to 0,
	// no multicast group subscriptions will be made.
	Groups uint32

	// NetNS specifies the network namespace the Conn will operate in.
	//
	// If set (non-zero), Conn will enter the specified network namespace and
	// an error will occur in Dial if the operation fails.
	//
	// If not set (zero), a best-effort attempt will be made to enter the
	// network namespace of the calling thread: this means that any changes made
	// to the calling thread's network namespace will also be reflected in Conn.
	// If this operation fails (due to lack of permissions or because network
	// namespaces are disabled by kernel configuration), Dial will not return
	// an error, and the Conn will operate in the default network namespace of
	// the process. This enables non-privileged use of Conn in applications
	// which do not require elevated privileges.
	//
	// Entering a network namespace is a privileged operation (root or
	// CAP_SYS_ADMIN are required), and most applications should leave this set
	// to 0.
	NetNS int

	// DisableNSLockThread is ignored.
	//
	// Deprecated: this field has no effect.
	DisableNSLockThread bool
}