//+build linux

package netlink

import (
	"os"
	"runtime"
	"syscall"
	"time"
	"unsafe"

	"github.com/mdlayher/socket"
	"golang.org/x/net/bpf"
	"golang.org/x/sys/unix"
)

// Compile-time assertion that *conn satisfies Socket.
var _ Socket = (*conn)(nil)

// A conn is the Linux implementation of a netlink sockets connection.
type conn struct {
	// s is the underlying socket; the methods of conn perform their I/O and
	// option handling through it.
	s *socket.Conn
}

// dial is the entry point for Dial. It creates a raw AF_NETLINK socket for
// the given netlink family and hands it to newConn, returning whatever
// newConn reports (the connection, its PID, and any error).
//
// A nil config is treated as an empty Config. When config.NetNS is non-zero,
// the socket is created while the calling OS thread is temporarily moved
// into that network namespace.
func dial(family int, config *Config) (*conn, uint32, error) {
	cfg := config
	if cfg == nil {
		cfg = &Config{}
	}

	if cfg.NetNS != 0 {
		// Namespace membership is a per-thread property, so pin this
		// goroutine to its OS thread for the duration of the switch.
		runtime.LockOSThread()
		defer runtime.UnlockOSThread()

		// Capture a handle to the thread's current namespace so it can be
		// restored once the socket has been created.
		origNS, err := threadNetNS()
		if err != nil {
			return nil, 0, err
		}
		// The handle is released on every return path from here on.
		defer origNS.Close()

		// Move the locked thread into the requested namespace.
		if err := origNS.Set(cfg.NetNS); err != nil {
			return nil, 0, err
		}

		// The switch succeeded: schedule the move back. Deferred calls run
		// in reverse order, so Restore runs before Close and before the
		// thread is unlocked.
		defer origNS.Restore()
	}

	// Open the netlink socket itself.
	sock, err := socket.Socket(unix.AF_NETLINK, unix.SOCK_RAW, family, "netlink")
	if err != nil {
		return nil, 0, err
	}

	// newConn runs to completion before any deferred call above fires.
	c, pid, err := newConn(sock, cfg)
	return c, pid, err
}

// newConn binds s to netlink and wraps it in a conn. The multicast group
// bitmask used for the bind comes from config.Groups; a nil config is
// treated as an empty Config. The returned uint32 is the Pid field of the
// address reported by Getsockname after binding.
//
// If Bind or Getsockname fails, s is closed before the error is returned.
func newConn(s *socket.Conn, config *Config) (*conn, uint32, error) {
	var groups uint32
	if config != nil {
		groups = config.Groups
	}

	// fail closes the socket so that its file descriptor is not leaked when
	// a system call fails, then reports err to the caller.
	fail := func(err error) (*conn, uint32, error) {
		_ = s.Close()
		return nil, 0, err
	}

	bindAddr := &unix.SockaddrNetlink{
		Family: unix.AF_NETLINK,
		Groups: groups,
	}
	if err := s.Bind(bindAddr); err != nil {
		return fail(err)
	}

	local, err := s.Getsockname()
	if err != nil {
		return fail(err)
	}

	pid := local.(*unix.SockaddrNetlink).Pid
	return &conn{s: s}, pid, nil
}

// SendMessages marshals each Message in order, concatenates the encoded
// bytes, and sends the result to netlink with a single Sendmsg call. The
// first marshaling error aborts the send and is returned as-is.
func (c *conn) SendMessages(messages []Message) error {
	var payload []byte
	for _, msg := range messages {
		encoded, err := msg.MarshalBinary()
		if err != nil {
			return err
		}
		payload = append(payload, encoded...)
	}

	return c.s.Sendmsg(payload, nil, &unix.SockaddrNetlink{Family: unix.AF_NETLINK}, 0)
}

// Send marshals m and sends the encoded bytes to netlink. A marshaling
// error is returned without sending anything.
func (c *conn) Send(m Message) error {
	encoded, err := m.MarshalBinary()
	if err != nil {
		return err
	}

	return c.s.Sendmsg(encoded, nil, &unix.SockaddrNetlink{Family: unix.AF_NETLINK}, 0)
}

// Receive reads from the netlink socket and returns the Messages parsed
// from what was read.
//
// The receive buffer starts at one memory page. The socket is first read
// with MSG_PEEK; as long as a peek fills the whole buffer, the buffer is
// doubled and the peek repeated. A final read without MSG_PEEK then fetches
// the data for parsing.
func (c *conn) Receive() ([]Message, error) {
	buf := make([]byte, os.Getpagesize())
	for {
		// TODO(mdlayher): deal with OOB message data if available, such as
		// when PacketInfo ConnOption is true.
		peeked, _, _, _, err := c.s.Recvmsg(buf, nil, unix.MSG_PEEK)
		if err != nil {
			return nil, err
		}

		if peeked >= len(buf) {
			// The peek filled the buffer completely, so the data may have
			// been cut short: retry with twice the space.
			buf = make([]byte, 2*len(buf))
			continue
		}

		// The peek left room to spare; the buffer is large enough.
		break
	}

	// Perform the actual read now that the buffer has been sized.
	got, _, _, _, err := c.s.Recvmsg(buf, nil, 0)
	if err != nil {
		return nil, err
	}

	parsed, err := syscall.ParseNetlinkMessage(buf[:nlmsgAlign(got)])
	if err != nil {
		return nil, err
	}

	// Convert the syscall representation into this package's Message type.
	out := make([]Message, len(parsed))
	for i := range parsed {
		out[i] = Message{
			Header: sysToHeader(parsed[i].Header),
			Data:   parsed[i].Data,
		}
	}

	return out, nil
}

// Close closes the underlying socket and returns the error from doing so.
func (c *conn) Close() error {
	return c.s.Close()
}

// JoinGroup joins the multicast group identified by group by setting the
// NETLINK_ADD_MEMBERSHIP socket option at the SOL_NETLINK level.
func (c *conn) JoinGroup(group uint32) error {
	id := int(group)
	return c.s.SetsockoptInt(unix.SOL_NETLINK, unix.NETLINK_ADD_MEMBERSHIP, id)
}

// LeaveGroup leaves the multicast group identified by group by setting the
// NETLINK_DROP_MEMBERSHIP socket option at the SOL_NETLINK level.
func (c *conn) LeaveGroup(group uint32) error {
	id := int(group)
	return c.s.SetsockoptInt(unix.SOL_NETLINK, unix.NETLINK_DROP_MEMBERSHIP, id)
}

// SetBPF attaches the assembled BPF program in filter to the underlying
// socket.
func (c *conn) SetBPF(filter []bpf.RawInstruction) error {
	return c.s.SetBPF(filter)
}

// RemoveBPF detaches any BPF filter from the underlying socket.
func (c *conn) RemoveBPF() error {
	return c.s.RemoveBPF()
}

// SetOption turns a netlink socket option on or off for the Conn. The
// option is translated with linuxOption and written at the SOL_NETLINK
// level as 1 (enable) or 0 (disable). An option that linuxOption does not
// recognize yields a "setsockopt" syscall error wrapping ENOPROTOOPT.
func (c *conn) SetOption(option ConnOption, enable bool) error {
	name, known := linuxOption(option)
	if !known {
		// Mirror the error Linux typically reports for an unknown option.
		return os.NewSyscallError("setsockopt", unix.ENOPROTOOPT)
	}

	value := 0
	if enable {
		value = 1
	}

	return c.s.SetsockoptInt(unix.SOL_NETLINK, name, value)
}

// SetDeadline forwards t to the underlying socket's SetDeadline.
func (c *conn) SetDeadline(t time.Time) error {
	return c.s.SetDeadline(t)
}
// SetReadDeadline forwards t to the underlying socket's SetReadDeadline.
func (c *conn) SetReadDeadline(t time.Time) error {
	return c.s.SetReadDeadline(t)
}
// SetWriteDeadline forwards t to the underlying socket's SetWriteDeadline.
func (c *conn) SetWriteDeadline(t time.Time) error {
	return c.s.SetWriteDeadline(t)
}

// SetReadBuffer sets the size, in bytes, of the operating system's receive
// buffer associated with the Conn.
func (c *conn) SetReadBuffer(bytes int) error {
	return c.s.SetReadBuffer(bytes)
}

// SetWriteBuffer sets the size, in bytes, of the operating system's
// transmit buffer associated with the Conn.
func (c *conn) SetWriteBuffer(bytes int) error {
	return c.s.SetWriteBuffer(bytes)
}

// SyscallConn returns the raw network connection obtained from the
// underlying socket's SyscallConn.
func (c *conn) SyscallConn() (syscall.RawConn, error) {
	return c.s.SyscallConn()
}

// linuxOption maps a ConnOption to the matching Linux NETLINK_* socket
// option number. The boolean is false, and the number 0, when o is not one
// of the options handled here.
func linuxOption(o ConnOption) (int, bool) {
	var opt int

	switch o {
	case PacketInfo:
		opt = unix.NETLINK_PKTINFO
	case BroadcastError:
		opt = unix.NETLINK_BROADCAST_ERROR
	case NoENOBUFS:
		opt = unix.NETLINK_NO_ENOBUFS
	case ListenAllNSID:
		opt = unix.NETLINK_LISTEN_ALL_NSID
	case CapAcknowledge:
		opt = unix.NETLINK_CAP_ACK
	case ExtendedAcknowledge:
		opt = unix.NETLINK_EXT_ACK
	case GetStrictCheck:
		opt = unix.NETLINK_GET_STRICT_CHK
	default:
		return 0, false
	}

	return opt, true
}

// sysToHeader converts a syscall.NlMsghdr to a Header by reinterpreting the
// bytes of r (a by-value copy of the caller's header) as a Header.
func sysToHeader(r syscall.NlMsghdr) Header {
	// NB: this unsafe cast is only valid while Header and syscall.NlMsghdr
	// share exactly the same memory layout.
	hdr := (*Header)(unsafe.Pointer(&r))
	return *hdr
}

// newError turns an error number received from netlink into the matching
// Linux system call error, a syscall.Errno holding the same numeric value.
// The result is always a non-nil error, including for errno 0.
func newError(errno int) error {
	var sysErr error = syscall.Errno(errno)
	return sysErr
}