import math
import os
import socket
import subprocess
import sys
from functools import partial
from typing import AsyncIterator, Callable, Optional, Tuple

import netfilterqueue
import pytest
import trio
import unshare
from async_generator import asynccontextmanager
from pytest_trio.enable_trio_mode import *


# We'll create three network namespaces, representing a router (which
# has interfaces on ROUTER_IP[1, 2]) and two hosts connected to it
# (PEER_IP[1, 2] respectively). The router (in the parent pytest
# process) will configure netfilterqueue iptables rules and use them
# to intercept and modify traffic between the two hosts (each of which
# is implemented in a subprocess).
#
# The 'peer' subprocesses communicate with each other over UDP, and
# with the router parent over a UNIX domain SOCK_SEQPACKET socketpair.
# Each packet sent from the parent to one peer over the UNIX domain
# socket will be forwarded to the other peer over UDP. Each packet
# received over UDP by either of the peers will be forwarded to its
# parent.

# Addresses of the router's two interfaces, keyed by peer index.
ROUTER_IP = {1: "172.16.101.1", 2: "172.16.102.1"}
# Addresses of the two peers, keyed by peer index.
PEER_IP = {1: "172.16.101.2", 2: "172.16.102.2"}


def enter_netns() -> None:
    """Move the calling process into new mount and network namespaces.

    Called directly by the pytest process and passed as ``preexec_fn`` when
    spawning each peer subprocess.

    Raises:
        subprocess.CalledProcessError: if any of the setup commands fails.
    """
    # Create new namespaces of the other types we need
    unshare.unshare(unshare.CLONE_NEWNS | unshare.CLONE_NEWNET)

    # Mount /sys so network tools work
    subprocess.run("/bin/mount -t sysfs sys /sys".split(), check=True)

    # Mount a fresh tmpfs on /run so iptables can get its lock
    subprocess.run("/bin/mount -t tmpfs tmpfs /run".split(), check=True)

    # Set up loopback interface
    subprocess.run("/sbin/ip link set lo up".split(), check=True)


@pytest.hookimpl(tryfirst=True)
def pytest_runtestloop():
    """Put the pytest process into its own namespaces and enable forwarding.

    When not running as root, a user namespace is created first in which the
    current uid/gid are mapped to 0.
    """
    if os.getuid() != 0:
        # Create a new user namespace for the whole test session
        outer = {"uid": os.getuid(), "gid": os.getgid()}
        unshare.unshare(unshare.CLONE_NEWUSER)
        with open("/proc/self/setgroups", "wb") as fp:
            # This is required since we're unprivileged outside the namespace
            fp.write(b"deny")
        for idtype in ("uid", "gid"):
            with open(f"/proc/self/{idtype}_map", "wb") as fp:
                fp.write(b"0 %d 1" % (outer[idtype],))
        assert os.getuid() == os.getgid() == 0

    # Create a new network namespace for this pytest process
    enter_netns()
    with open("/proc/sys/net/ipv4/ip_forward", "wb") as fp:
        fp.write(b"1\n")


async def peer_main(idx: int, parent_fd: int) -> None:
    """Run one peer: configure its veth, then relay packets both ways.

    Args:
        idx: Peer index (1 or 2); selects this peer's and the router's IP.
        parent_fd: File descriptor of the SOCK_SEQPACKET socket shared with
            the parent process.

    Raises:
        AssertionError: if the parent's handshake reply is not ``b"ok"``.
    """
    parent = trio.socket.fromfd(parent_fd, socket.AF_UNIX, socket.SOCK_SEQPACKET)

    # Tell parent we've set up our netns, wait for it to confirm it's
    # created our veth interface
    await parent.send(b"ok")
    # Receive outside of an `assert` so the handshake still happens under -O.
    reply = await parent.recv(4096)
    if reply != b"ok":
        raise AssertionError(f"unexpected handshake reply from parent: {reply!r}")

    my_ip = PEER_IP[idx]
    router_ip = ROUTER_IP[idx]
    peer_ip = PEER_IP[3 - idx]

    for cmd in (
        "ip link set veth0 up",
        f"ip addr add {my_ip}/24 dev veth0",
        f"ip route add default via {router_ip} dev veth0",
    ):
        await trio.run_process(cmd.split(), capture_stdout=True, capture_stderr=True)

    peer = trio.socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
    await peer.bind((my_ip, 0))

    # Tell the parent our port and get our peer's port
    await parent.send(b"%d" % peer.getsockname()[1])
    peer_port = int(await parent.recv(4096))
    await peer.connect((peer_ip, peer_port))

    # Enter the message-forwarding loop
    async def proxy_one_way(src, dest):
        """Copy messages from ``src`` to ``dest`` until either side goes away."""
        while src.fileno() >= 0:
            try:
                msg = await src.recv(4096)
            except trio.ClosedResourceError:
                return
            if not msg:
                # Empty read: close the other side too so its loop stops.
                dest.close()
                return
            try:
                await dest.send(msg)
            except BrokenPipeError:
                return

    async with trio.open_nursery() as nursery:
        nursery.start_soon(proxy_one_way, parent, peer)
        nursery.start_soon(proxy_one_way, peer, parent)


def _default_capture_cb(
    target: "trio.MemorySendChannel[netfilterqueue.Packet]",
    packet: netfilterqueue.Packet,
) -> None:
    """Default queue callback: retain ``packet`` and push it onto ``target``."""
    packet.retain()
    target.send_nowait(packet)


class Harness:
    """Test harness that runs two peer subprocesses behind this process.

    The pytest process acts as the router; helper methods bind netfilter
    queues, install iptables rules, and exchange packets with the peers.
    """

    def __init__(self):
        # Per-peer receive channel of messages read from that peer's socket.
        self._received = {}
        # Per-peer SOCK_SEQPACKET connection to the subprocess.
        self._conn = {}
        # Per-peer (ip, udp_port) tuple, filled in by run().
        self.dest_addr = {}
        # Set by expect() on failure; suppresses the leftover-packet check.
        self.failed = False

    async def _run_peer(self, idx: int, *, task_status):
        """Spawn peer ``idx``, wire up its veth pair, and wait for it to exit.

        ``task_status.started()`` is called once the handshake is complete and
        ``self._conn[idx]`` is usable.

        Raises:
            AssertionError: if the peer's handshake message is not ``b"ok"``.
            RuntimeError: if the peer subprocess exits with a non-zero code.
        """
        my_ip = ROUTER_IP[idx]
        conn, child_conn = trio.socket.socketpair(socket.AF_UNIX, socket.SOCK_SEQPACKET)
        with conn:
            try:
                process = await trio.open_process(
                    [sys.executable, __file__, str(idx), str(child_conn.fileno())],
                    stdin=subprocess.DEVNULL,
                    pass_fds=[child_conn.fileno()],
                    preexec_fn=enter_netns,
                )
            finally:
                # The child has its own copy (or we failed); drop ours.
                child_conn.close()

            # Receive outside of an `assert` so the handshake survives -O.
            greeting = await conn.recv(4096)
            if greeting != b"ok":
                raise AssertionError(
                    f"unexpected handshake from peer {idx}: {greeting!r}"
                )

            for cmd in (
                f"ip link add veth{idx} type veth peer netns {process.pid} name veth0",
                f"ip link set veth{idx} up",
                f"ip addr add {my_ip}/24 dev veth{idx}",
            ):
                await trio.run_process(cmd.split())

            try:
                await conn.send(b"ok")
                self._conn[idx] = conn
                task_status.started()
                retval = await process.wait()
            except BaseException:
                process.kill()
                with trio.CancelScope(shield=True):
                    await process.wait()
                raise
            else:
                if retval != 0:
                    raise RuntimeError(
                        "peer subprocess exited with code {}".format(retval)
                    )
            finally:
                # On some kernels the veth device is removed when the subprocess exits
                # and its netns goes away. check=False to suppress that error.
                # Shielded so the cleanup still runs when this task is cancelled.
                with trio.CancelScope(shield=True):
                    await trio.run_process(
                        f"ip link delete veth{idx}".split(), check=False
                    )

    async def _manage_peer(self, idx: int, *, task_status):
        """Start peer ``idx`` and pump everything it sends into a channel.

        The receive end of the channel is stored in ``self._received[idx]``;
        the send end is closed once the peer connection reports an empty read.
        """
        async with trio.open_nursery() as nursery:
            await nursery.start(self._run_peer, idx)
            packets_w, packets_r = trio.open_memory_channel(math.inf)
            self._received[idx] = packets_r
            task_status.started()
            async with packets_w:
                while True:
                    msg = await self._conn[idx].recv(4096)
                    if not msg:
                        break
                    await packets_w.send(msg)

    @asynccontextmanager
    async def run(self):
        """Async context manager that runs both peers for the body's duration.

        On normal exit both connections are shut down for writing, and, unless
        a previous ``expect()`` failed, any message still queued from a peer
        raises ``AssertionError``.
        """
        async with trio.open_nursery() as nursery:
            # Start both peers concurrently and wait until both are ready.
            async with trio.open_nursery() as start_nursery:
                start_nursery.start_soon(nursery.start, self._manage_peer, 1)
                start_nursery.start_soon(nursery.start, self._manage_peer, 2)

            # Tell each peer about the other one's port
            for idx in (1, 2):
                self.dest_addr[idx] = (
                    PEER_IP[idx],
                    int(await self._received[idx].receive()),
                )
                await self._conn[3 - idx].send(b"%d" % self.dest_addr[idx][1])

            yield

            self._conn[1].shutdown(socket.SHUT_WR)
            self._conn[2].shutdown(socket.SHUT_WR)

        if not self.failed:
            for idx in (1, 2):
                async for remainder in self._received[idx]:
                    raise AssertionError(
                        f"Peer {idx} received unexpected packet {remainder!r}"
                    )

    def bind_queue(
        self,
        cb: Callable[[netfilterqueue.Packet], None],
        *,
        queue_num: int = -1,
        **options: int,
    ) -> Tuple[int, netfilterqueue.NetfilterQueue]:
        """Create a ``NetfilterQueue`` and bind ``cb`` to a queue number.

        Args:
            cb: Callback passed to ``NetfilterQueue.bind``.
            queue_num: Queue number to bind; if negative, the first number in
                0-15 that binds successfully is used.
            **options: Extra keyword arguments for ``bind``; ``sock_len``
                defaults to 131072.

        Returns:
            The ``(queue_num, nfq)`` pair that was bound.

        Raises:
            RuntimeError: if no queue number in 0-15 could be bound.
        """
        nfq = netfilterqueue.NetfilterQueue()
        # Use a smaller socket buffer to avoid a warning in CI
        options.setdefault("sock_len", 131072)
        if queue_num >= 0:
            nfq.bind(queue_num, cb, **options)
        else:
            last_error: Optional[Exception] = None
            for queue_num in range(16):
                try:
                    nfq.bind(queue_num, cb, **options)
                    break
                except Exception as ex:
                    last_error = ex
            else:
                raise RuntimeError(
                    "Couldn't bind any netfilter queue number between 0-15"
                ) from last_error
        return queue_num, nfq

    @asynccontextmanager
    async def enqueue_packets_to(
        self, idx: int, queue_num: int, *, forwarded: bool = True
    ) -> AsyncIterator[None]:
        """Install an iptables NFQUEUE rule for traffic destined to peer ``idx``.

        Args:
            idx: Peer whose address is matched with ``-d``.
            queue_num: Netfilter queue number to send matching packets to.
            forwarded: Use the FORWARD chain when true, OUTPUT otherwise.

        The rule is appended on entry and deleted again on exit.
        """
        if forwarded:
            chain = "FORWARD"
        else:
            chain = "OUTPUT"

        rule = f"{chain} -d {PEER_IP[idx]} -j NFQUEUE --queue-num {queue_num}"
        await trio.run_process(f"/sbin/iptables -A {rule}".split())
        try:
            yield
        finally:
            # Shielded so the rule is removed even if we are being cancelled;
            # otherwise it would linger in this process's network namespace.
            with trio.CancelScope(shield=True):
                await trio.run_process(f"/sbin/iptables -D {rule}".split())

    @asynccontextmanager
    async def capture_packets_to(
        self,
        idx: int,
        cb: Callable[
            ["trio.MemorySendChannel[netfilterqueue.Packet]", netfilterqueue.Packet],
            None,
        ] = _default_capture_cb,
        **options: int,
    ) -> AsyncIterator["trio.MemoryReceiveChannel[netfilterqueue.Packet]"]:
        """Queue packets headed to peer ``idx`` and yield a channel fed by ``cb``.

        Args:
            idx: Peer whose inbound traffic is captured.
            cb: Called as ``cb(send_channel, packet)`` for each queued packet.
            **options: Forwarded to ``bind_queue``.

        Yields:
            The receive end of the channel whose send end is given to ``cb``.
        """
        packets_w, packets_r = trio.open_memory_channel(math.inf)
        queue_num, nfq = self.bind_queue(partial(cb, packets_w), **options)
        try:
            async with self.enqueue_packets_to(idx, queue_num):
                async with packets_w, trio.open_nursery() as nursery:

                    @nursery.start_soon
                    async def listen_for_packets():
                        # Drain the queue whenever its fd becomes readable.
                        while True:
                            await trio.lowlevel.wait_readable(nfq.get_fd())
                            nfq.run(block=False)

                    yield packets_r
                    # Stop the listener so the nursery can exit.
                    nursery.cancel_scope.cancel()
        finally:
            nfq.unbind()

    async def expect(self, idx: int, *packets: bytes):
        """Assert that peer ``idx`` reports receiving ``packets`` in order.

        Each packet is awaited for up to 5 seconds.

        Raises:
            AssertionError: on timeout or on a mismatching packet; in both
                cases ``self.failed`` is set first.
        """
        for expected in packets:
            with trio.move_on_after(5) as scope:
                received = await self._received[idx].receive()
            if scope.cancelled_caught:
                self.failed = True
                raise AssertionError(
                    f"Timeout waiting for peer {idx} to receive {expected!r}"
                )
            if received != expected:
                self.failed = True
                raise AssertionError(
                    f"Expected peer {idx} to receive {expected!r} but it "
                    f"received {received!r}"
                )

    async def send(self, idx: int, *packets: bytes):
        """Send ``packets`` toward peer ``idx`` by handing them to the other peer."""
        for packet in packets:
            await self._conn[3 - idx].send(packet)


@pytest.fixture
async def harness() -> AsyncIterator[Harness]:
    """Fixture yielding a ``Harness`` with both peers running."""
    h = Harness()
    async with h.run():
        yield h


if __name__ == "__main__":
    # Invoked by Harness._run_peer as: <python> <this file> <idx> <parent_fd>
    trio.run(peer_main, int(sys.argv[1]), int(sys.argv[2]))