import math
import os
import pytest
import socket
import subprocess
import sys
import trio
import unshare
import netfilterqueue
from functools import partial
from typing import AsyncIterator, Callable, Optional
from async_generator import asynccontextmanager
from pytest_trio.enable_trio_mode import *


# We'll create three network namespaces, representing a router (which
# has interfaces on ROUTER_IP[1, 2]) and two hosts connected to it
# (PEER_IP[1, 2] respectively). The router (in the parent pytest
# process) will configure netfilterqueue iptables rules and use them
# to intercept and modify traffic between the two hosts (each of which
# is implemented in a subprocess).
#
# The 'peer' subprocesses communicate with each other over UDP, and
# with the router parent over a UNIX domain SOCK_SEQPACKET socketpair.
# Each packet sent from the parent to one peer over the UNIX domain
# socket will be forwarded to the other peer over UDP. Each packet
# received over UDP by either of the peers will be forwarded to its
# parent.

# Addresses of the router's two interfaces: the one facing peer 1 and the one
# facing peer 2 (each on its own /24).
ROUTER_IP = {
    1: "172.16.101.1",
    2: "172.16.102.1",
}
# Address of each peer, on the same /24 as the router interface it faces.
PEER_IP = {
    1: "172.16.101.2",
    2: "172.16.102.2",
}


def enter_netns() -> None:
    """Move the calling process into new mount and network namespaces.

    Used for the pytest (router) process and, via ``preexec_fn``, for each
    peer subprocess. Afterwards the process has its own sysfs and /run
    mounts and an up loopback interface.
    """
    # Create new namespaces of the other types we need
    unshare.unshare(unshare.CLONE_NEWNS | unshare.CLONE_NEWNET)

    # A new mount namespace starts with the same propagation types as the one
    # it was copied from. If / is a shared mount (commonly the case under
    # systemd) and we are real root, the mounts below would propagate back
    # out and shadow the host's /sys and /run. Make everything private first,
    # which is also what unshare(1) does by default.
    subprocess.run("/bin/mount --make-rprivate /".split(), check=True)

    # Mount a sysfs for the new network namespace so network tools work
    subprocess.run("/bin/mount -t sysfs sys /sys".split(), check=True)

    # Mount a fresh tmpfs on /run so iptables can get its lock
    subprocess.run("/bin/mount -t tmpfs tmpfs /run".split(), check=True)

    # Set up loopback interface
    subprocess.run("/sbin/ip link set lo up".split(), check=True)


@pytest.hookimpl(tryfirst=True)
def pytest_runtestloop():
    """Isolate the whole pytest session in its own namespaces before tests run.

    When not already root, first enter a new user namespace that maps the
    current uid/gid to 0, so the following namespace and mount operations are
    permitted. Then enter fresh mount and network namespaces via
    ``enter_netns`` and enable IPv4 forwarding, because this process acts as
    the router between the two peers. Returns None, so pytest's regular test
    loop still runs afterwards.
    """
    if os.getuid() != 0:
        host_ids = {"uid": os.getuid(), "gid": os.getgid()}
        unshare.unshare(unshare.CLONE_NEWUSER)
        # An unprivileged process must deny setgroups(2) before it is allowed
        # to write gid_map (see user_namespaces(7)).
        with open("/proc/self/setgroups", "wb") as setgroups_file:
            setgroups_file.write(b"deny")
        # Map id 0 inside the namespace onto our original id outside it.
        for kind, outside_id in host_ids.items():
            with open(f"/proc/self/{kind}_map", "wb") as map_file:
                map_file.write(b"0 %d 1" % (outside_id,))
        assert os.getuid() == os.getgid() == 0

    # The router lives in a network namespace private to this pytest process.
    enter_netns()
    with open("/proc/sys/net/ipv4/ip_forward", "wb") as forward_file:
        forward_file.write(b"1\n")


async def peer_main(idx: int, parent_fd: int) -> None:
    """Entry point of peer subprocess ``idx`` (1 or 2).

    Runs inside the namespaces set up by ``enter_netns`` (the router passes it
    as ``preexec_fn``). The peer talks to the router over the inherited
    SOCK_SEQPACKET descriptor ``parent_fd``, in these steps:

    1. Send ``b"ok"`` and wait for ``b"ok"`` back. The router creates our
       ``veth0`` in between.
    2. Configure ``veth0`` with ``PEER_IP[idx]`` and a default route via
       ``ROUTER_IP[idx]``.
    3. Bind a UDP socket and report its port as ASCII digits. Receive the
       other peer's port and ``connect()`` to it.
    4. Forward parent -> UDP and UDP -> parent until one side ends.
    """
    parent = trio.socket.fromfd(parent_fd, socket.AF_UNIX, socket.SOCK_SEQPACKET)
    # fromfd() operates on a duplicate of the descriptor. Close the inherited
    # original so the trio socket holds the only reference; closing `parent`
    # then really hangs up on the router.
    os.close(parent_fd)

    # Tell parent we've set up our netns, wait for it to confirm it's
    # created our veth interface
    await parent.send(b"ok")
    assert b"ok" == await parent.recv(4096)

    my_ip = PEER_IP[idx]
    router_ip = ROUTER_IP[idx]
    peer_ip = PEER_IP[3 - idx]

    for cmd in (
        "ip link set veth0 up",
        f"ip addr add {my_ip}/24 dev veth0",
        f"ip route add default via {router_ip} dev veth0",
    ):
        # Output is captured; a failing command still raises (check=True).
        await trio.run_process(cmd.split(), capture_stdout=True, capture_stderr=True)

    peer = trio.socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
    await peer.bind((my_ip, 0))

    # Tell the parent our port and get our peer's port
    await parent.send(b"%d" % peer.getsockname()[1])
    peer_port = int(await parent.recv(4096))
    await peer.connect((peer_ip, peer_port))

    async def proxy_one_way(src, dest):
        """Copy messages from ``src`` to ``dest`` until ``src`` closes or ends.

        An empty message from ``src`` is treated as EOF and closes ``dest``.
        That in turn stops the opposite-direction proxy, which reads from
        ``dest``.
        """
        while src.fileno() >= 0:
            try:
                msg = await src.recv(4096)
            except trio.ClosedResourceError:
                return
            if not msg:
                dest.close()
                return
            try:
                await dest.send(msg)
            except (BrokenPipeError, trio.ClosedResourceError):
                # dest went away: the router hung up, or the other proxy
                # direction closed it while we were sending. Stop quietly.
                return

    # Enter the message-forwarding loop
    async with trio.open_nursery() as nursery:
        nursery.start_soon(proxy_one_way, parent, peer)
        nursery.start_soon(proxy_one_way, peer, parent)


def _default_capture_cb(
    target: "trio.MemorySendChannel[netfilterqueue.Packet]",
    packet: netfilterqueue.Packet,
) -> None:
    """Default ``capture_packets_to`` callback: hand each packet to the test.

    ``retain()`` is called before queuing. Presumably this keeps the packet's
    payload usable after this callback returns; see netfilterqueue's
    ``Packet.retain`` docs to confirm. ``capture_packets_to`` creates the
    channel with an unbounded buffer, so ``send_nowait`` is not expected to
    block there.
    """
    # Order matters: retain first, then publish to the consumer.
    packet.retain()
    target.send_nowait(packet)


class Harness:
    """Router-side test driver for two peer subprocesses.

    ``run()`` spawns peers 1 and 2. Each runs in its own network namespace,
    linked to this process by a veth pair. ``run()`` also exchanges the peers'
    UDP ports. Tests then use:

    - ``send(idx, ...)`` to deliver packets *to* peer ``idx``; they originate
      from the other peer.
    - ``expect(idx, ...)`` to check what peer ``idx`` received.
    - ``capture_packets_to(idx)`` to route traffic destined for peer ``idx``
      through a netfilter queue.
    """

    def __init__(self):
        # idx -> receive channel of datagrams that peer idx got over UDP and
        # forwarded back to us
        self._received = {}
        # idx -> our end of the SOCK_SEQPACKET socketpair shared with peer idx
        self._conn = {}
        # Set by expect() before it raises, so run() skips its check for
        # leftover packets after a test has already failed
        self.failed = False

    async def _run_peer(self, idx: int, *, task_status):
        """Spawn peer ``idx`` and wire it to this namespace with a veth pair.

        Calls ``task_status.started()`` once the peer's interface exists and
        ``self._conn[idx]`` is set, then waits for the subprocess to exit. If
        anything fails, or the task is cancelled, after the subprocess was
        spawned, the subprocess is killed and reaped before the error
        propagates.

        Raises:
            RuntimeError: if the peer subprocess exits with a nonzero status.
        """
        my_ip = ROUTER_IP[idx]
        conn, child_conn = trio.socket.socketpair(socket.AF_UNIX, socket.SOCK_SEQPACKET)
        with conn:
            try:
                process = await trio.open_process(
                    [sys.executable, __file__, str(idx), str(child_conn.fileno())],
                    stdin=subprocess.DEVNULL,
                    pass_fds=[child_conn.fileno()],
                    preexec_fn=enter_netns,
                )
            finally:
                # Drop our copy of the child's end so conn.recv() reports EOF
                # once the peer process exits.
                child_conn.close()

            try:
                # The peer says "ok" once it runs in its own netns; only then
                # can we move one end of a veth pair into it by pid.
                assert b"ok" == await conn.recv(4096)
                for cmd in (
                    f"ip link add veth{idx} type veth peer netns {process.pid} name veth0",
                    f"ip link set veth{idx} up",
                    f"ip addr add {my_ip}/24 dev veth{idx}",
                ):
                    await trio.run_process(cmd.split())

                await conn.send(b"ok")
                self._conn[idx] = conn
                task_status.started()
                retval = await process.wait()
            except BaseException:
                process.kill()
                with trio.CancelScope(shield=True):
                    await process.wait()
                raise
            else:
                if retval != 0:
                    raise RuntimeError(
                        "peer subprocess exited with code {}".format(retval)
                    )
            finally:
                # On some kernels the veth device is removed when the subprocess exits
                # and its netns goes away. check=False to suppress that error.
                # Shielded so this cleanup still runs while we are being cancelled.
                with trio.CancelScope(shield=True):
                    await trio.run_process(
                        f"ip link delete veth{idx}".split(), check=False
                    )

    async def _manage_peer(self, idx: int, *, task_status):
        """Run peer ``idx`` and pump what it forwards into ``self._received[idx]``.

        Signals ``task_status.started()`` once the peer is connected and its
        receive channel exists. The send side of the channel is closed when
        the peer's connection reports EOF, which ends iteration for readers
        after the remaining messages are drained.
        """
        async with trio.open_nursery() as nursery:
            await nursery.start(self._run_peer, idx)
            packets_w, packets_r = trio.open_memory_channel(math.inf)
            self._received[idx] = packets_r
            task_status.started()
            async with packets_w:
                while True:
                    msg = await self._conn[idx].recv(4096)
                    if not msg:
                        break
                    await packets_w.send(msg)

    @asynccontextmanager
    async def run(self):
        """Start both peers, introduce them to each other, and tear down on exit.

        On a clean exit the harness:

        1. Half-closes both router connections so the peers wind down.
        2. Waits for the peers to finish.
        3. Unless ``expect`` already recorded a failure, raises AssertionError
           if either peer received a packet that no test consumed.
        """
        async with trio.open_nursery() as nursery:
            async with trio.open_nursery() as start_nursery:
                # Start both peers concurrently; each nursery.start() returns
                # once that peer is connected and has a receive channel.
                start_nursery.start_soon(nursery.start, self._manage_peer, 1)
                start_nursery.start_soon(nursery.start, self._manage_peer, 2)
            # Tell each peer about the other one's port (each peer's first
            # forwarded message is its own UDP port)
            await self._conn[2].send(await self._received[1].receive())
            await self._conn[1].send(await self._received[2].receive())
            yield
            # Half-close: each peer sees EOF from us and shuts itself down.
            self._conn[1].shutdown(socket.SHUT_WR)
            self._conn[2].shutdown(socket.SHUT_WR)

        if not self.failed:
            for idx in (1, 2):
                async for remainder in self._received[idx]:
                    raise AssertionError(
                        f"Peer {idx} received unexpected packet {remainder!r}"
                    )

    @asynccontextmanager
    async def capture_packets_to(
        self,
        idx: int,
        cb: Callable[
            ["trio.MemorySendChannel[netfilterqueue.Packet]", netfilterqueue.Packet],
            None,
        ] = _default_capture_cb,
        *,
        queue_num: int = -1,
        **options: int,
    ) -> AsyncIterator["trio.MemoryReceiveChannel[netfilterqueue.Packet]"]:
        """Divert packets forwarded to peer ``idx`` through a netfilter queue.

        Binds a ``NetfilterQueue`` to ``queue_num`` if it is non-negative;
        otherwise to the first of queues 0-15 that binds successfully. Then
        appends an iptables FORWARD rule that sends traffic destined for
        ``PEER_IP[idx]`` to that queue.

        Yields an unbounded receive channel fed by ``cb``, called as
        ``cb(send_channel, packet)``. ``options`` are passed through to
        ``NetfilterQueue.bind``, with ``sock_len`` defaulting to 131072.

        On exit, the rule is deleted and the queue unbound.

        Raises:
            RuntimeError: if no queue in 0-15 could be bound, chained from the
                last bind error.
        """

        packets_w, packets_r = trio.open_memory_channel(math.inf)

        nfq = netfilterqueue.NetfilterQueue()
        # Use a smaller socket buffer to avoid a warning in CI
        options.setdefault("sock_len", 131072)
        if queue_num >= 0:
            nfq.bind(queue_num, partial(cb, packets_w), **options)
        else:
            for queue_num in range(16):
                try:
                    nfq.bind(queue_num, partial(cb, packets_w), **options)
                    break
                except Exception as ex:
                    last_error = ex
            else:
                raise RuntimeError(
                    "Couldn't bind any netfilter queue number between 0-15"
                ) from last_error
        try:
            rule = f"-d {PEER_IP[idx]} -j NFQUEUE --queue-num {queue_num}"
            await trio.run_process(f"/sbin/iptables -A FORWARD {rule}".split())
            try:
                async with packets_w, trio.open_nursery() as nursery:

                    @nursery.start_soon
                    async def listen_for_packets():
                        # Dispatch whatever is pending each time the queue's
                        # fd becomes readable, without blocking trio.
                        while True:
                            await trio.lowlevel.wait_readable(nfq.get_fd())
                            nfq.run(block=False)

                    yield packets_r
                    nursery.cancel_scope.cancel()
            finally:
                await trio.run_process(f"/sbin/iptables -D FORWARD {rule}".split())
        finally:
            nfq.unbind()

    async def expect(self, idx: int, *packets: bytes):
        """Assert that peer ``idx`` next receives exactly ``packets``, in order.

        Waits up to 5 seconds for each packet. On a timeout or mismatch, sets
        ``self.failed`` and raises AssertionError.
        """
        for expected in packets:
            with trio.move_on_after(5) as scope:
                received = await self._received[idx].receive()
            if scope.cancelled_caught:
                self.failed = True
                raise AssertionError(
                    f"Timeout waiting for peer {idx} to receive {expected!r}"
                )
            if received != expected:
                self.failed = True
                raise AssertionError(
                    f"Expected peer {idx} to receive {expected!r} but it "
                    f"received {received!r}"
                )

    async def send(self, idx: int, *packets: bytes):
        """Deliver ``packets`` to peer ``idx``.

        Each packet is written to the *other* peer's connection; that peer
        forwards it over UDP (through this router) to peer ``idx``.
        """
        for packet in packets:
            await self._conn[3 - idx].send(packet)


@pytest.fixture
async def harness() -> AsyncIterator[Harness]:
    """Yield a running :class:`Harness`.

    Both peers are started before the test body runs and are torn down, with
    the leftover-packet check, when the test finishes.
    """
    h = Harness()
    async with h.run():
        yield h


if __name__ == "__main__":
    # Executed as a peer subprocess by Harness._run_peer, with
    #   argv[1] = peer index, argv[2] = inherited SOCK_SEQPACKET descriptor
    peer_idx = int(sys.argv[1])
    parent_fd = int(sys.argv[2])
    trio.run(peer_main, peer_idx, parent_fd)