266 lines
7.9 KiB
Python
266 lines
7.9 KiB
Python
|
#! /usr/bin/env python3
|
||
|
# SPDX-License-Identifier: GPL-2.0
|
||
|
|
||
|
import argparse
|
||
|
import ctypes
|
||
|
import errno
|
||
|
import hashlib
|
||
|
import os
|
||
|
import select
|
||
|
import signal
|
||
|
import socket
|
||
|
import subprocess
|
||
|
import sys
|
||
|
import atexit
|
||
|
from pwd import getpwuid
|
||
|
from os import stat
|
||
|
|
||
|
# Allow utils module to be imported from different directory
|
||
|
this_dir = os.path.dirname(os.path.realpath(__file__))
|
||
|
sys.path.append(os.path.join(this_dir, "../"))
|
||
|
from lib.py.utils import ip
|
||
|
|
||
|
# Load the C library so that setns(2) can be called directly.
# use_errno=True makes ctypes record errno after every foreign call, so a
# failed setns() can be diagnosed with ctypes.get_errno().
libc = ctypes.CDLL('libc.so.6', use_errno=True)
setns = libc.setns
# int setns(int fd, int nstype)
setns.argtypes = [ctypes.c_int, ctypes.c_int]
setns.restype = ctypes.c_int

# Names of the two network namespaces used by the test.
net0 = 'net0'
net1 = 'net1'

# Names of the two veth interfaces used by the test.
veth0 = 'veth0'
veth1 = 'veth1'
|
||
|
|
||
|
# Helper function for creating a socket inside a network namespace.
# We need this because otherwise RDS will detect that the two TCP
# sockets are on the same interface and use the loop transport instead
# of the TCP transport.
def netns_socket(netns, *args):
    """Create a socket inside the network namespace *netns*.

    A child process is forked, joins the namespace with setns(), creates
    the socket there and passes its file descriptor back to the parent
    over a Unix socket pair.

    Args:
        netns: name of a namespace file under /var/run/netns.
        *args: positional arguments for socket.socket() (family, type, ...).

    Returns:
        A socket.socket object referring to the socket the child created.

    Raises:
        RuntimeError: if the child exited without handing over a socket
            (for example because setns() or socket creation failed).
    """
    u0, u1 = socket.socketpair(socket.AF_UNIX, socket.SOCK_SEQPACKET)

    child = os.fork()
    if child == 0:
        status = 1
        try:
            u1.close()

            # change network namespace; ctypes does not raise on failure,
            # so the return value has to be checked explicitly
            with open(f'/var/run/netns/{netns}') as f:
                if setns(f.fileno(), 0) != 0:
                    # errno is only available if libc was loaded with
                    # use_errno=True; otherwise get_errno() yields 0
                    err = ctypes.get_errno()
                    raise OSError(err, os.strerror(err) if err
                                  else 'setns() failed')

            # create socket in target namespace
            s = socket.socket(*args)

            # send resulting socket to parent
            socket.send_fds(u0, [], [s.fileno()])
            status = 0
        except Exception as e:
            print(f'netns_socket({netns}): {e}', file=sys.stderr)
        finally:
            # Leave without running the parent's exit handlers and without
            # flushing stdout buffers inherited across fork().
            sys.stderr.flush()
            os._exit(status)

    # Close our copy of the sending end so that recv_fds() sees EOF instead
    # of blocking forever if the child dies before sending anything.
    u0.close()
    try:
        # receive socket from child
        _, fds, _, _ = socket.recv_fds(u1, 0, 1)
    finally:
        u1.close()
        os.waitpid(child, 0)

    if not fds:
        raise RuntimeError(f'failed to create socket in netns {netns}')

    try:
        return socket.fromfd(fds[0], *args)
    finally:
        # socket.fromfd() duplicates the descriptor, so the received one
        # must be closed here or it leaks.
        os.close(fds[0])
|
||
|
|
||
|
def signal_handler(sig, frame):
    """Signal handler: report that the test timed out and exit with status 1.

    Both arguments (signal number and interrupted stack frame) are unused.
    """
    print('Test timed out')
    raise SystemExit(1)
|
||
|
|
||
|
# Parse the command line. Every option is optional: a log output
# directory, a timeout for hung tests, and percentages for simulated
# TCP packet loss, corruption and duplication.
parser = argparse.ArgumentParser(
    description="init script args",
    formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument("-d", "--logdir", action="store", default="/tmp",
                    help="directory to store logs")
parser.add_argument('--timeout', type=int, default=0,
                    help="timeout to terminate hung test")
# The three impairment options share the same shape: integer, default 0.
for short_opt, long_opt, description in (
        ('-l', '--loss', "Simulate tcp packet loss"),
        ('-c', '--corruption', "Simulate tcp packet corruption"),
        ('-u', '--duplicate', "Simulate tcp packet duplication"),
):
    parser.add_argument(short_opt, long_opt, help=description,
                        type=int, default=0)

args = parser.parse_args()
logdir = args.logdir
# The impairment values are used later as percentage strings.
packet_loss = f'{args.loss}%'
packet_corruption = f'{args.corruption}%'
packet_duplicate = f'{args.duplicate}%'
|
||
|
|
||
|
# Create both network namespaces and a veth pair. No interface names are
# given here; the rest of the script refers to the pair as veth0/veth1.
for net in (net0, net1):
    ip(f"netns add {net}")
ip("link add type veth")

addrs = [
    # we technically don't need different port numbers, but this will
    # help identify traffic in the network analyzer
    ('10.0.0.1', 10000),
    ('10.0.0.2', 20000),
]
host0, host1 = addrs[0][0], addrs[1][0]

# move interfaces to separate namespaces so they can no longer be
# bound directly; this prevents rds from switching over from the tcp
# transport to the loop transport.
for net, iface in ((net0, veth0), (net1, veth1)):
    ip(f"link set {iface} netns {net} up")

# add addresses
for net, iface, local in ((net0, veth0, host0), (net1, veth1, host1)):
    ip(f"-n {net} addr add {local}/32 dev {iface}")

# add routes: each namespace gets a host route to the other side
for net, iface, peer in ((net0, veth0, host1), (net1, veth1, host0)):
    ip(f"-n {net} route add {peer}/32 dev {iface}")

# sanity check that our two interfaces/addresses are correctly set up
# and communicating by doing a single ping
ip(f"netns exec {net0} ping -c 1 {host1}")
|
||
|
|
||
|
# Start a packet capture on each network. Each capture runs in its own
# forked child; the parent simply moves on to the next namespace.
for net in (net0, net1):
    tcpdump_pid = os.fork()
    if tcpdump_pid != 0:
        continue

    # Child: make sure the capture file exists, then run tcpdump with -Z
    # set to the user that owns that file.
    pcap = f'{logdir}/{net}.pcap'
    subprocess.check_call(['touch', pcap])
    user = getpwuid(stat(pcap).st_uid).pw_name
    ip(f"netns exec {net} /usr/sbin/tcpdump -Z {user} -i any -w {pcap}")
    sys.exit(0)
|
||
|
|
||
|
# simulate packet loss, duplication and corruption
# The command is assembled from adjacent string literals rather than with
# backslash continuations inside one literal, which would splice the source
# indentation of the continuation lines into the command text.
for net, iface in [(net0, veth0), (net1, veth1)]:
    ip(f"netns exec {net} /usr/sbin/tc qdisc add dev {iface} root netem "
       f"corrupt {packet_corruption} loss {packet_loss} "
       f"duplicate {packet_duplicate}")
|
||
|
|
||
|
# add a timeout
if args.timeout > 0:
    # Install the handler before arming the alarm, so SIGALRM can never be
    # delivered while the default disposition (terminate the process without
    # the "Test timed out" message) is still in place.
    signal.signal(signal.SIGALRM, signal_handler)
    signal.alarm(args.timeout)
|
||
|
|
||
|
# One RDS socket per namespace, in the same order as addrs.
sockets = [
    netns_socket(net, socket.AF_RDS, socket.SOCK_SEQPACKET)
    for net in (net0, net1)
]

# Bind each socket to its address and make it non-blocking.
for s, addr in zip(sockets, addrs):
    s.bind(addr)
    s.setblocking(False)

# Lookup tables used by the send/receive loop.
fileno_to_socket = {sock.fileno(): sock for sock in sockets}
addr_to_socket = dict(zip(addrs, sockets))
socket_to_addr = dict(zip(sockets, addrs))

# Running SHA-256 digests of sent/received data, keyed by
# (sender fileno, receiver fileno).
send_hashes = {}
recv_hashes = {}

# Watch both sockets for readable data.
ep = select.epoll()
for s in sockets:
    ep.register(s, select.EPOLLRDNORM)
|
||
|
|
||
|
# Total number of packets to send, and running send/receive counters.
n = 50000
nr_send = 0
nr_recv = 0

while nr_send < n:
    # Send as much as we can without blocking
    print("sending...", nr_send, nr_recv)
    while nr_send < n:
        # payload is the hex SHA-256 of the packet's sequence label
        label = f'packet {nr_send}'.encode('utf-8')
        send_data = hashlib.sha256(label).hexdigest().encode('utf-8')

        # pseudo-random send/receive pattern
        sender = sockets[nr_send % 2]
        receiver = sockets[1 - (nr_send % 3) % 2]

        try:
            sender.sendto(send_data, socket_to_addr[receiver])
        except BlockingIOError:
            # would block: switch over to receiving
            break
        except OSError as e:
            # these errors also just end the current send burst
            if e.errno not in (errno.ENOBUFS, errno.ECONNRESET, errno.EPIPE):
                raise
            break
        else:
            # only packets that were actually sent go into the send digest
            pair = (sender.fileno(), receiver.fileno())
            digest = send_hashes.setdefault(pair, hashlib.sha256())
            digest.update(f'<{send_data}>'.encode('utf-8'))
            nr_send += 1

    # Receive as much as we can without blocking
    print("receiving...", nr_send, nr_recv)
    while nr_recv < nr_send:
        for fileno, eventmask in ep.poll():
            receiver = fileno_to_socket[fileno]

            if not eventmask & select.EPOLLRDNORM:
                continue

            # drain this socket until it would block
            while True:
                try:
                    recv_data, address = receiver.recvfrom(1024)
                except BlockingIOError:
                    break
                sender = addr_to_socket[address]
                pair = (sender.fileno(), receiver.fileno())
                digest = recv_hashes.setdefault(pair, hashlib.sha256())
                digest.update(f'<{recv_data}>'.encode('utf-8'))
                nr_recv += 1

    # exercise net/rds/tcp.c:rds_tcp_sysctl_reset()
    for net in (net0, net1):
        ip(f"netns exec {net} /usr/sbin/sysctl net.rds.tcp.rds_tcp_rcvbuf=10000")
        ip(f"netns exec {net} /usr/sbin/sysctl net.rds.tcp.rds_tcp_sndbuf=10000")

print("done", nr_send, nr_recv)
|
||
|
|
||
|
# the Python socket module doesn't know these
RDS_INFO_FIRST = 10000
RDS_INFO_LAST = 10017

nr_success = 0
nr_error = 0

# Query every RDS_INFO_* option on every socket and count the outcomes.
for s in sockets:
    for optname in range(RDS_INFO_FIRST, RDS_INFO_LAST + 1):
        # Sigh, the Python socket module doesn't allow us to pass
        # buffer lengths greater than 1024 for some reason. RDS
        # wants multiple pages.
        try:
            s.getsockopt(socket.SOL_RDS, optname, 1024)
        except OSError:
            # Every failure is only counted; ENOSPC gets no special
            # treatment beyond being ignored like the rest.
            nr_error += 1
        else:
            nr_success += 1

print(f"getsockopt(): {nr_success}/{nr_error}")
|
||
|
|
||
|
# Terminate the tcpdump processes; a non-zero exit status from killall
# raises CalledProcessError.
print("Stopping network packet captures")
subprocess.run(['killall', '-q', 'tcpdump'], check=True)
|
||
|
|
||
|
# We're done sending and receiving stuff, now let's check if what
# we received is what we sent: for every (sender, receiver) pair that
# sent data, the receive digest must exist and match the send digest.
for pair, send_hash in send_hashes.items():
    sender, receiver = pair
    recv_hash = recv_hashes.get(pair)

    if recv_hash is None:
        print("FAIL: No data received")
        raise SystemExit(1)

    expected = send_hash.hexdigest()
    received = recv_hash.hexdigest()
    if expected != received:
        print("FAIL: Send/recv mismatch")
        print("hash expected:", expected)
        print("hash received:", received)
        raise SystemExit(1)

    print(f"{sender}/{receiver}: ok")

print("Success")
raise SystemExit(0)
|