244 lines
8.4 KiB
Python
244 lines
8.4 KiB
Python
|
#!/usr/bin/env python3
|
||
|
# SPDX-License-Identifier: GPL-2.0
|
||
|
|
||
|
import argparse
|
||
|
import errno
|
||
|
import logging
|
||
|
import socket
|
||
|
import struct
|
||
|
import time
|
||
|
|
||
|
import usb.core
|
||
|
import usb.util
|
||
|
|
||
|
|
||
|
def path_from_usb_dev(dev):
    """Return the location of a pyUSB device in the USB bus tree as a string.

    The result is the bus number of the USB controller, a dash, and then the
    dot-separated chain of hub port numbers leading to the device, e.g.
    ``"1-2.3"``. It is used to pick one specific device on the bus, or every
    device attached behind a given hub.

    Returns an empty string when the device reports no port numbers.
    """
    ports = dev.port_numbers
    if not ports:
        return ""
    port_chain = ".".join(map(str, ports))
    return f"{dev.bus}-{port_chain}"
|
||
|
|
||
|
|
||
|
# Lookup table for the ASCII column of a hex dump, indexed by byte value
# (0-255): printable 7-bit characters map to themselves; control characters,
# DEL and every byte >= 0x80 map to ".".
HEXDUMP_FILTER = "".join(chr(x) if chr(x).isprintable() else "." for x in range(128)) + "." * 128
|
||
|
|
||
|
|
||
|
class Forwarder:
    """Relay 9P messages between a usb9pfs USB gadget and a TCP 9P server.

    Every message starts with a 4-byte little-endian length that counts the
    whole message, header included. c2s() moves one request from the USB
    device to the server, s2c() moves one response back, and ``stats`` keeps
    per-direction message and byte counters.
    """

    @staticmethod
    def _log_hexdump(data):
        """Log *data* as a hex dump (offset, hex bytes, ASCII) at TRACE level.

        Returns immediately unless TRACE is enabled, so the formatting work is
        only done when the output will actually be emitted.
        """
        # logging.TRACE is not a stdlib level; main() registers it at startup.
        if not logging.root.isEnabledFor(logging.TRACE):
            return
        width = 16
        for offset in range(0, len(data), width):
            chars = data[offset : offset + width]
            dump = " ".join(f"{x:02x}" for x in chars)
            printable = "".join(HEXDUMP_FILTER[x] for x in chars)
            line = f"{offset:08x} {dump:{width * 3}s} |{printable:{width}s}|"
            logging.root.log(logging.TRACE, "%s", line)

    def __init__(self, server, vid, pid, path):
        """Open the usb9pfs gadget and connect to the 9P server.

        server: (host, port) address of the TCP 9P server.
        vid, pid: USB vendor and product ID of the gadget.
        path: bus-tree location as produced by path_from_usb_dev(), or None to
            accept any device matching vid:pid.

        Raises ValueError if the device, the 9pfs interface, or its IN/OUT
        endpoints cannot be found; socket errors from connect() propagate.
        """
        # forwarded message/byte counters, dumped by log_stats()
        self.stats = {
            "c2s packets": 0,
            "c2s bytes": 0,
            "s2c packets": 0,
            "s2c bytes": 0,
        }
        # time of the last periodic statistics dump (see log_stats_interval)
        self.stats_logged = time.monotonic()

        def find_filter(dev):
            # Without a path any vid:pid match is accepted; with one, only the
            # device at exactly that bus location.
            if path is None:
                return True
            return path_from_usb_dev(dev) == path

        dev = usb.core.find(idVendor=vid, idProduct=pid, custom_match=find_filter)
        if dev is None:
            raise ValueError("Device not found")

        logging.info(f"found device: {dev.bus}/{dev.address} located at {path_from_usb_dev(dev)}")

        # dev.set_configuration() is not necessary since g_multi has only one
        usb9pfs = None
        # g_multi adds 9pfs as last interface
        cfg = dev.get_active_configuration()
        for intf in cfg:
            # we have to detach the usb-storage driver from multi gadget since
            # stall option could be set, which will lead to spontaneous port
            # resets and our transfers will run dead
            if intf.bInterfaceClass == 0x08:
                if dev.is_kernel_driver_active(intf.bInterfaceNumber):
                    dev.detach_kernel_driver(intf.bInterfaceNumber)

            # vendor-specific class/subclass (0xFF) with protocol 0x09 identifies 9pfs
            if intf.bInterfaceClass == 0xFF and intf.bInterfaceSubClass == 0xFF and intf.bInterfaceProtocol == 0x09:
                usb9pfs = intf
        if usb9pfs is None:
            raise ValueError("Interface not found")

        logging.info(f"claiming interface:\n{usb9pfs}")
        usb.util.claim_interface(dev, usb9pfs.bInterfaceNumber)
        ep_out = usb.util.find_descriptor(
            usb9pfs,
            custom_match=lambda e: usb.util.endpoint_direction(e.bEndpointAddress) == usb.util.ENDPOINT_OUT,
        )
        # explicit checks rather than assert, which is stripped under -O
        if ep_out is None:
            raise ValueError("OUT endpoint not found")
        ep_in = usb.util.find_descriptor(
            usb9pfs,
            custom_match=lambda e: usb.util.endpoint_direction(e.bEndpointAddress) == usb.util.ENDPOINT_IN,
        )
        if ep_in is None:
            raise ValueError("IN endpoint not found")
        logging.info("interface claimed")

        self.ep_out = ep_out
        self.ep_in = ep_in
        self.dev = dev

        # create and connect socket
        self.s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        self.s.connect(server)

        logging.info("connected to server")

    def c2s(self):
        """Forward one request from the USB client to the TCP server.

        Blocks until the device delivers a message: USB read timeouts are
        retried immediately and EIO errors after a 0.5 s pause; any other USB
        error is logged and re-raised.
        """
        data = None
        while data is None:
            try:
                logging.log(logging.TRACE, "c2s: reading")
                data = self.ep_in.read(self.ep_in.wMaxPacketSize)
            except usb.core.USBTimeoutError:
                logging.log(logging.TRACE, "c2s: reading timed out")
                continue
            except usb.core.USBError as e:
                if e.errno == errno.EIO:
                    logging.debug("c2s: reading failed with %s, retrying", repr(e))
                    time.sleep(0.5)
                    continue
                logging.error("c2s: reading failed with %s, aborting", repr(e))
                raise
        # The first packet starts with the 4-byte total length; keep reading
        # until the whole message has arrived.
        size = struct.unpack("<I", data[:4])[0]
        while len(data) < size:
            data += self.ep_in.read(size - len(data))
        logging.log(logging.TRACE, "c2s: writing")
        self._log_hexdump(data)
        # sendall(): a single send() may transmit only part of the buffer.
        self.s.sendall(data)
        logging.debug("c2s: forwarded %i bytes", size)
        self.stats["c2s packets"] += 1
        self.stats["c2s bytes"] += size

    def _recv_exact(self, size):
        """Receive exactly *size* bytes from the TCP server.

        Raises ConnectionError if the server closes the connection first
        (recv() returning b"" would otherwise make a read loop spin forever).
        A non-positive *size* returns b"".
        """
        buf = bytearray()
        while len(buf) < size:
            chunk = self.s.recv(size - len(buf))
            if not chunk:
                raise ConnectionError(f"s2c: server closed connection after {len(buf)} of {size} bytes")
            buf += chunk
        return bytes(buf)

    def s2c(self):
        """Forward one response from the TCP server to the USB client.

        Raises ConnectionError if the server closes the connection before a
        complete message has been received, and RuntimeError if a USB write
        makes no progress.
        """
        logging.log(logging.TRACE, "s2c: reading")
        header = self._recv_exact(4)
        size = struct.unpack("<I", header)[0]
        data = header + self._recv_exact(size - len(header))
        logging.log(logging.TRACE, "s2c: writing")
        self._log_hexdump(data)
        while data:
            written = self.ep_out.write(data)
            # explicit check: an assert is stripped under -O and a zero-byte
            # write would then loop forever
            if written <= 0:
                raise RuntimeError("s2c: USB write made no progress")
            data = data[written:]
        # A transfer that is an exact multiple of the max packet size is
        # terminated with a zero-length packet.
        if size % self.ep_out.wMaxPacketSize == 0:
            logging.log(logging.TRACE, "sending zero length packet")
            self.ep_out.write(b"")
        logging.debug("s2c: forwarded %i bytes", size)
        self.stats["s2c packets"] += 1
        self.stats["s2c bytes"] += size

    def log_stats(self):
        """Log the accumulated forwarding counters at INFO level."""
        logging.info("statistics:")
        for k, v in self.stats.items():
            logging.info(f" {k+':':14s} {v}")

    def log_stats_interval(self, interval=5):
        """Call log_stats() if at least *interval* seconds passed since the last periodic dump."""
        if (time.monotonic() - self.stats_logged) < interval:
            return

        self.log_stats()
        self.stats_logged = time.monotonic()
|
||
|
|
||
|
|
||
|
def try_get_usb_str(dev, name):
    """Return the sysfs attribute *name* (e.g. "manufacturer") of a USB device.

    Sysfs names USB device directories after their bus-tree location
    ("<bus>-<port>[.<port>...]", or "usb<bus>" for a root hub), not after the
    device address, so the directory is derived from path_from_usb_dev().

    Returns the stripped attribute text, or None if the attribute file cannot
    be opened or read.
    """
    location = path_from_usb_dev(dev) or f"usb{dev.bus}"
    try:
        # decode explicitly so non-ASCII strings don't fail under a C/POSIX locale
        with open(f"/sys/bus/usb/devices/{location}/{name}", encoding="utf-8", errors="replace") as f:
            return f.read().strip()
    except OSError:
        # best effort: missing attribute, permissions, device unplugged, ...
        return None
|
||
|
|
||
|
|
||
|
def list_usb(args):
    """Print a table of every attached USB device whose ID matches ``args.id``.

    ``args.id`` is "vid:pid" in hexadecimal. Each row shows bus, address,
    manufacturer and product strings (or "unknown"), the ID and the bus-tree
    path usable with --path.
    """
    vid, pid = (int(part, 16) for part in args.id.split(":", 1))

    print("Bus | Addr | Manufacturer | Product | ID | Path")
    print("--- | ---- | ---------------- | ---------------- | --------- | ----")
    for device in usb.core.find(find_all=True, idVendor=vid, idProduct=pid):
        location = path_from_usb_dev(device) or ""
        vendor_name = try_get_usb_str(device, "manufacturer") or "unknown"
        product_name = try_get_usb_str(device, "product") or "unknown"
        usb_id = f"{device.idVendor:04x}:{device.idProduct:04x}"
        print(
            f"{device.bus:3} | {device.address:4} | {vendor_name:16} | {product_name:16} | {usb_id} | {location:18}"
        )
|
||
|
|
||
|
|
||
|
def connect(args):
    """Forward 9P traffic between the selected gadget and the TCP server.

    The loop only ends through an exception (a USB/socket error or Ctrl-C);
    the accumulated statistics are logged on the way out in every case.
    """
    vid, pid = (int(part, 16) for part in args.id.split(":", 1))
    forwarder = Forwarder(server=(args.server, args.port), vid=vid, pid=pid, path=args.path)

    try:
        while True:
            # one request/response round trip per iteration
            forwarder.c2s()
            forwarder.s2c()
            forwarder.log_stats_interval()
    finally:
        forwarder.log_stats()
|
||
|
|
||
|
|
||
|
def main():
    """Parse the command line, set up logging and run the chosen subcommand.

    The global options (--id, --path, -v) belong to the top-level parser and
    therefore go before the subcommand name; "list" prints matching devices
    and "connect" starts forwarding.
    """
    parser = argparse.ArgumentParser(
        description="Forward 9PFS requests from USB to TCP",
    )

    parser.add_argument("--id", type=str, default="1d6b:0109", help="vid:pid of target device")
    parser.add_argument("--path", type=str, required=False, help="path of target device")
    parser.add_argument(
        "-v", "--verbose", action="count", default=0, help="increase log verbosity (-v: DEBUG, -vv: TRACE)"
    )

    subparsers = parser.add_subparsers()
    subparsers.required = True
    subparsers.dest = "command"

    parser_list = subparsers.add_parser("list", help="List all connected 9p gadgets")
    parser_list.set_defaults(func=list_usb)

    parser_connect = subparsers.add_parser(
        "connect", help="Forward messages between the usb9pfs gadget and the 9p server"
    )
    parser_connect.set_defaults(func=connect)
    parser_connect.add_argument("-s", "--server", type=str, default="127.0.0.1", help="server hostname")
    parser_connect.add_argument("-p", "--port", type=int, default=564, help="server port")

    args = parser.parse_args()

    # Custom TRACE level below DEBUG; Forwarder logs per-transfer messages and
    # hex dumps at this level.
    logging.TRACE = logging.DEBUG - 5
    logging.addLevelName(logging.TRACE, "TRACE")

    if args.verbose >= 2:
        level = logging.TRACE
    elif args.verbose:
        level = logging.DEBUG
    else:
        level = logging.INFO
    logging.basicConfig(level=level, format="%(asctime)-15s %(levelname)-8s %(message)s")

    args.func(args)
|
||
|
|
||
|
|
||
|
if __name__ == "__main__":
    # Run the command-line interface only when executed as a script, not on import.
    main()
|