m1n1/proxyclient/proxy.py
Hector Martin 7dfe24ee2c Rework kboot/chainload flow to shut down before calling the next stage
Next stage boots now exit back to main() after replying to the proxy
command, allowing shutdown functions to be called. Introduces a new
P_VECTOR proxy op, distinct from P_CALL. The Python side is reworked
to remain compatible with older versions that do not support this.

Signed-off-by: Hector Martin <marcan@marcan.st>
2021-04-17 18:12:59 +09:00

716 lines
22 KiB
Python
Executable file

#!/usr/bin/env python3
# SPDX-License-Identifier: MIT
import os, sys, struct
from serial.tools.miniterm import Miniterm
def hexdump(s, sep=" "):
    """Return the bytes of *s* as two-digit lowercase hex, joined by *sep*."""
    return sep.join("%02x" % byte for byte in s)
def hexdump32(s, sep=" "):
    """Return *s* as little-endian 32-bit words in 8-digit hex, joined by *sep*.

    If ``len(s)`` is not a multiple of 4, the trailing 1-3 bytes are rendered
    as one final, shorter little-endian group (two hex digits per byte)
    instead of raising ``struct.error``, so dumps of arbitrary-length buffers
    do not crash on their last row.
    """
    whole = len(s) - (len(s) % 4)
    words = struct.unpack("<%dI" % (whole // 4), s[:whole])
    groups = ["%08x" % w for w in words]
    tail = s[whole:]
    if tail:
        # Partial word: same little-endian interpretation, narrower field.
        groups.append("%0*x" % (2 * len(tail), int.from_bytes(tail, "little")))
    return sep.join(groups)
def ascii(s):
    """Return a printable rendering of the byte sequence *s*.

    Bytes in the range 0x20..0x7e map to their ASCII character; every other
    byte is shown as ``"."``.
    """
    return "".join(
        chr(byte) if 0x20 <= byte <= 0x7e else "."
        for byte in s
    )
def pad(s,c,l):
    """Right-pad *s* with copies of *c* until it is at least *l* long.

    *s* is returned unchanged when it is already *l* items or longer.
    """
    shortfall = l - len(s)
    if shortfall > 0:
        # In-place += is kept so mutable sequences behave as before.
        s += c * shortfall
    return s
def chexdump(s,st=0):
    """Print a canonical hex+ASCII dump of *s*, 16 bytes per row.

    Each row shows the offset (``st`` plus the position in *s*), two groups
    of up to eight hex bytes, and the printable rendering between ``|`` bars.

    The hex and ASCII fields are left-justified so that a short final row
    keeps every byte under the column matching its offset (right-justifying
    would shift the bytes of a partial row to the wrong columns).
    """
    for i in range(0,len(s),16):
        row = s[i:i+16]
        print("%08x %s %s |%s|" % (
            i + st,
            hexdump(row[:8], ' ').ljust(23),
            hexdump(row[8:], ' ').ljust(23),
            ascii(row).ljust(16)))
def chexdump32(s, st=0, abbreviate=True):
    """Print *s* as rows of 32-bit words, 32 bytes per row.

    Offsets are printed relative to *st*. When *abbreviate* is true, a run of
    rows identical to the last printed row is collapsed into a single ``*``
    line.
    """
    previous_row = None
    in_repeat = False
    for offset in range(0, len(s), 32):
        row = s[offset:offset + 32]
        if abbreviate and row == previous_row:
            # Emit the marker only once per run of duplicates.
            if not in_repeat:
                print("%08x *" % (offset + st))
                in_repeat = True
            continue
        print("%08x %s" % (offset + st, hexdump32(row, ' ')))
        previous_row = row
        in_repeat = False
class UartError(RuntimeError):
    """Base class for errors raised by the UART transport layer."""


class UartTimeout(UartError):
    """The device returned no data before the expected byte count arrived."""


class UartCMDError(UartError):
    """A reply carried a different command code than the one awaited."""


class UartChecksumError(UartError):
    """A received reply or data block failed checksum verification."""


class UartRemoteError(UartError):
    """The remote side answered with a non-OK status code."""
class UartInterface:
    """Framing layer for the request/reply protocol spoken over a serial port.

    A command frame is a 32-bit command code, ``CMD_LEN`` payload bytes
    (zero padded) and a 32-bit checksum. A reply frame is ``REPLY_LEN`` bytes:
    command code, signed status, 24 data bytes and a checksum. All command
    codes start with the bytes ``ff 55 aa`` on the wire, which ``reply`` uses
    to find frames inside other serial output.
    """

    # Command codes (little-endian on the wire: ff 55 aa <n>).
    REQ_NOP = 0x00AA55FF
    REQ_PROXY = 0x01AA55FF
    REQ_MEMREAD = 0x02AA55FF
    REQ_MEMWRITE = 0x03AA55FF
    REQ_BOOT = 0x04AA55FF

    # Reply status codes.
    ST_OK = 0
    ST_BADCMD = -1
    ST_INVAL = -2
    ST_XFERERR = -3
    ST_CSUMERR = -4

    CMD_LEN = 56    # payload bytes in a command frame
    REPLY_LEN = 36  # total bytes in a reply frame

    # Messages for the known non-OK status codes.
    _STATUS_MESSAGES = {
        ST_BADCMD: "Reply error: Bad Command",
        ST_INVAL: "Reply error: Invalid argument",
        ST_XFERERR: "Reply error: Data transfer failed",
        ST_CSUMERR: "Reply error: Data checksum failed",
    }

    def __init__(self, device, debug=False):
        """Wrap the serial *device*, flushing its buffers.

        *device* must offer ``read``/``write``, a ``timeout`` attribute and
        ``flushInput``/``flushOutput``. With *debug* set, every frame sent or
        received is hex-dumped to stdout.
        """
        self.debug = debug
        self.dev = device
        self.dev.timeout = 0
        self.dev.flushOutput()
        self.dev.flushInput()
        self.pted = False  # True while a "TTY> " prefixed line is open
        self.dev.timeout = 3
        self.tty_enable = True

    def checksum(self, data):
        """Return the 32-bit checksum of the byte sequence *data*."""
        acc = 0xDEADBEEF
        for c in data:
            acc = (acc * 31337 + (c ^ 0x5a)) & 0xFFFFFFFF
        return (acc ^ 0xADDEDBAD) & 0xFFFFFFFF

    def readfull(self, size):
        """Read exactly *size* bytes from the device.

        Raises ``UartTimeout`` if a read returns nothing before *size* bytes
        have been collected.
        """
        d = b''
        while len(d) < size:
            block = self.dev.read(size - len(d))
            if not block:
                raise UartTimeout("Expected %d bytes, got %d bytes"%(size,len(d)))
            d += block
        return d

    def cmd(self, cmd, payload=b""):
        """Send command code *cmd* with *payload* (at most ``CMD_LEN`` bytes).

        Raises ``ValueError`` if the payload is too long.
        """
        if len(payload) > self.CMD_LEN:
            raise ValueError("Incorrect payload size %d"%len(payload))
        payload = payload.ljust(self.CMD_LEN, b"\x00")
        command = struct.pack("<I", cmd) + payload
        command += struct.pack("<I", self.checksum(command))
        if self.debug:
            print("<<", hexdump(command))
        self.dev.write(command)

    def unkhandler(self, s):
        """Echo bytes that are not part of a reply frame to stdout.

        Each output line is prefixed with ``"TTY> "``. Nothing is printed
        while ``tty_enable`` is false.
        """
        if not self.tty_enable:
            return
        for c in s:
            if not self.pted:
                sys.stdout.write("TTY> ")
                self.pted = True
            if c == 10:
                # Newline closes the current line; the next byte re-prefixes.
                self.pted = False
            sys.stdout.write(chr(c))
        sys.stdout.flush()

    def ttymode(self):
        """Run an interactive Miniterm session on the device until it exits.

        The device timeout is disabled for the session and restored
        afterwards, even if the terminal raises. ``tty_enable`` is left
        false once the session ends normally.
        """
        tout = self.dev.timeout
        self.tty_enable = True
        self.dev.timeout = None
        try:
            term = Miniterm(self.dev, eol='cr')
            term.exit_character = chr(0x1d) # GS/CTRL+]
            term.menu_character = chr(0x14) # Menu: CTRL+T
            term.raw = True
            term.set_rx_encoding('UTF-8')
            term.set_tx_encoding('UTF-8')
            print('--- TTY mode | Quit: CTRL+] | Menu: CTRL+T ---')
            term.start()
            try:
                term.join(True)
            except KeyboardInterrupt:
                pass
            print('--- Exit TTY mode ---')
            term.join()
            term.close()
        finally:
            # Never leave the port in blocking mode for later reads.
            self.dev.timeout = tout
        self.tty_enable = False

    def reply(self, cmd):
        """Wait for the reply frame to *cmd* and return its 24 data bytes.

        Bytes received before the ``ff 55 aa`` frame marker go to
        ``unkhandler``. A ``REQ_BOOT`` frame arriving while another command
        is awaited is skipped and the wait continues.

        Raises ``UartChecksumError`` on a bad frame checksum,
        ``UartCMDError`` on an unexpected command code, ``UartRemoteError``
        on a non-OK status, and ``UartTimeout`` (from ``readfull``) when the
        device stops sending.
        """
        reply = b''
        while True:
            if not reply or reply[-1] != 255:
                reply = self.readfull(1)
                if reply != b"\xff":
                    self.unkhandler(reply)
                    continue
            else:
                # The rejected sequence ended in 0xff: it may start a frame.
                reply = b'\xff'
            reply += self.readfull(1)
            if reply != b"\xff\x55":
                self.unkhandler(reply)
                continue
            reply += self.readfull(1)
            if reply != b"\xff\x55\xaa":
                self.unkhandler(reply)
                continue
            reply += self.readfull(self.REPLY_LEN - 3)
            if self.debug:
                print(">>", hexdump(reply))
            cmdin, status, data, checksum = struct.unpack("<Ii24sI", reply)
            ccsum = self.checksum(reply[:-4])
            if checksum != ccsum:
                raise UartChecksumError("Reply checksum error: Expected 0x%08x, got 0x%08x"%(checksum, ccsum))
            if cmdin != cmd:
                if cmdin == self.REQ_BOOT:
                    # Proxy rebooted in the meantime, try again
                    reply = b''
                    continue
                raise UartCMDError("Reply command mismatch: Expected 0x%08x, got 0x%08x"%(cmd, cmdin))
            if status != self.ST_OK:
                message = self._STATUS_MESSAGES.get(status)
                if message is None:
                    message = "Reply error: Unknown error (%d)"%status
                raise UartRemoteError(message)
            return data

    def wait_boot(self):
        """Block until a ``REQ_BOOT`` frame is received."""
        self.reply(self.REQ_BOOT)

    def nop(self):
        """Send a NOP command and wait for its reply."""
        self.cmd(self.REQ_NOP)
        self.reply(self.REQ_NOP)

    def proxyreq(self, req, reboot=False, no_reply=False, pre_reply=None):
        """Send the proxy request payload *req* and return the reply data.

        *pre_reply*, if given, is called after sending and before waiting.
        With *no_reply* nothing is awaited and ``None`` is returned. With
        *reboot* a ``REQ_BOOT`` frame is awaited instead of a proxy reply.
        """
        self.cmd(self.REQ_PROXY, req)
        if pre_reply:
            pre_reply()
        if no_reply:
            return
        elif reboot:
            return self.reply(self.REQ_BOOT)
        else:
            return self.reply(self.REQ_PROXY)

    def writemem(self, addr, data, progress=False):
        """Write *data* to remote address *addr*.

        The data is streamed in 8192-byte chunks; with *progress* a dot is
        printed per chunk. Errors reported by the remote side surface as
        exceptions from ``reply``.
        """
        checksum = self.checksum(data)
        size = len(data)
        req = struct.pack("<QQI", addr, size, checksum)
        self.cmd(self.REQ_MEMWRITE, req)
        if self.debug:
            print("<< DATA:")
            chexdump(data)
        for i in range(0, len(data), 8192):
            self.dev.write(data[i:i + 8192])
            if progress:
                sys.stdout.write(".")
                sys.stdout.flush()
        if progress:
            print()
        # should automatically report a CRC failure
        self.reply(self.REQ_MEMWRITE)

    def readmem(self, addr, size):
        """Read *size* bytes from remote address *addr* and return them.

        Raises ``UartChecksumError`` if the data does not match the checksum
        announced in the reply header.
        """
        req = struct.pack("<QQ", addr, size)
        self.cmd(self.REQ_MEMREAD, req)
        reply = self.reply(self.REQ_MEMREAD)
        checksum = struct.unpack("<I",reply[:4])[0]
        data = self.readfull(size)
        if self.debug:
            print(">> DATA:")
            chexdump(data)
        ccsum = self.checksum(data)
        if checksum != ccsum:
            raise UartChecksumError("Reply data checksum error: Expected 0x%08x, got 0x%08x"%(checksum, ccsum))
        return data

    def readstruct(self, addr, stype):
        """Read ``stype.sizeof()`` bytes at *addr* and return ``stype.parse`` of them."""
        return stype.parse(self.readmem(addr, stype.sizeof()))
class ProxyError(RuntimeError):
    """Base class for errors raised by the proxy request layer."""


class ProxyReplyError(ProxyError):
    """A proxy reply carried a different opcode than the request."""


class ProxyRemoteError(ProxyError):
    """The proxy answered a request with a non-OK status."""


class ProxyCommandError(ProxyRemoteError):
    """The proxy rejected the request opcode as a bad command."""


class AlignmentError(Exception):
    """An address was not aligned to the width of the requested access."""
class M1N1Proxy:
    """Client for the proxy opcodes carried over a ``UartInterface``-like link.

    Every request is an opcode plus up to six 64-bit arguments; every reply
    is the echoed opcode, a signed status and one 64-bit return value. The
    methods below are thin wrappers that send one opcode each.

    ``heap`` may be set by the caller to an object with ``malloc``/``free``;
    when set, ``str``/``bytes`` arguments are copied to remote memory
    automatically (see ``request``).
    """

    # Proxy status codes.
    S_OK = 0
    S_BADCMD = -1

    # General opcodes.
    P_NOP = 0x000
    P_EXIT = 0x001
    P_CALL = 0x002
    P_GET_BOOTARGS = 0x003
    P_GET_BASE = 0x004
    P_SET_BAUD = 0x005
    P_UDELAY = 0x006
    P_SET_EXC_GUARD = 0x007
    P_GET_EXC_COUNT = 0x008
    P_EL0_CALL = 0x009
    P_EL1_CALL = 0x00a
    P_VECTOR = 0x00b

    # Values for set_exc_guard().
    GUARD_OFF = 0
    GUARD_SKIP = 1
    GUARD_MARK = 2
    GUARD_RETURN = 3
    GUARD_SILENT = 0x100

    # I/O device ids and usage flags for the iodev_* calls.
    IODEV_UART = 0
    IODEV_FB = 1
    USAGE_CONSOLE = (1 << 0)
    USAGE_UARTPROXY = (1 << 1)

    # Memory access opcodes.
    P_WRITE64 = 0x100
    P_WRITE32 = 0x101
    P_WRITE16 = 0x102
    P_WRITE8 = 0x103
    P_READ64 = 0x104
    P_READ32 = 0x105
    P_READ16 = 0x106
    P_READ8 = 0x107
    P_SET64 = 0x108
    P_SET32 = 0x109
    P_SET16 = 0x10a
    P_SET8 = 0x10b
    P_CLEAR64 = 0x10c
    P_CLEAR32 = 0x10d
    P_CLEAR16 = 0x10e
    P_CLEAR8 = 0x10f
    P_MASK64 = 0x110
    P_MASK32 = 0x111
    P_MASK16 = 0x112
    P_MASK8 = 0x113
    P_WRITEREAD64 = 0x114
    P_WRITEREAD32 = 0x115
    P_WRITEREAD16 = 0x116
    P_WRITEREAD8 = 0x117

    # Block memory opcodes.
    P_MEMCPY64 = 0x200
    P_MEMCPY32 = 0x201
    P_MEMCPY16 = 0x202
    P_MEMCPY8 = 0x203
    P_MEMSET64 = 0x204
    P_MEMSET32 = 0x205
    P_MEMSET16 = 0x206
    P_MEMSET8 = 0x207

    # Cache maintenance and MMU opcodes.
    P_IC_IALLUIS = 0x300
    P_IC_IALLU = 0x301
    P_IC_IVAU = 0x302
    P_DC_IVAC = 0x303
    P_DC_ISW = 0x304
    P_DC_CSW = 0x305
    P_DC_CISW = 0x306
    P_DC_ZVA = 0x307
    P_DC_CVAC = 0x308
    P_DC_CVAU = 0x309
    P_DC_CIVAC = 0x30a
    P_MMU_SHUTDOWN = 0x30b

    # Decompression opcodes.
    P_XZDEC = 0x400
    P_GZDEC = 0x401

    # SMP opcodes.
    P_SMP_START_SECONDARIES = 0x500
    P_SMP_CALL = 0x501
    P_SMP_CALL_SYNC = 0x502

    # Heap opcodes.
    P_HEAPBLOCK_ALLOC = 0x600
    P_MALLOC = 0x601
    P_MEMALIGN = 0x602
    # Next value in sequence; it must not share P_MEMALIGN's opcode, or
    # free() would issue a memalign request instead.
    # NOTE(review): confirm 0x603 against the firmware's opcode enumeration.
    P_FREE = 0x603

    # Kernel boot opcodes.
    P_KBOOT_BOOT = 0x700
    P_KBOOT_SET_BOOTARGS = 0x701
    P_KBOOT_SET_INITRD = 0x702
    P_KBOOT_PREPARE_DT = 0x703

    # Power manager opcodes.
    P_PMGR_CLOCK_ENABLE = 0x800
    P_PMGR_CLOCK_DISABLE = 0x801
    P_PMGR_ADT_CLOCKS_ENABLE = 0x802
    P_PMGR_ADT_CLOCKS_DISABLE = 0x803

    # I/O device opcodes.
    P_IODEV_SET_USAGE = 0x900
    P_IODEV_CAN_READ = 0x901
    P_IODEV_CAN_WRITE = 0x902
    P_IODEV_READ = 0x903
    P_IODEV_WRITE = 0x904

    # Tunables opcodes.
    P_TUNABLES_APPLY_GLOBAL = 0xa00
    P_TUNABLES_APPLY_LOCAL = 0xa01

    # DART opcodes.
    P_DART_INIT = 0xb00
    P_DART_SHUTDOWN = 0xb01
    P_DART_MAP = 0xb02
    P_DART_UNMAP = 0xb03

    def __init__(self, iface, debug=False):
        """Bind the proxy to the transport *iface*.

        *iface* must provide ``proxyreq``, ``writemem`` and ``wait_boot``.
        With *debug* set, every request and reply is printed.
        """
        self.debug = debug
        self.iface = iface
        self.heap = None

    def _request(self, opcode, *args, reboot=False, signed=False, no_reply=False, pre_reply=None):
        """Send one raw request and return its 64-bit result.

        *args* are up to six integers, zero-padded to six. *signed* selects
        signed interpretation of the return value. *no_reply* and *reboot*
        return ``None`` (see ``UartInterface.proxyreq``). *pre_reply* is
        forwarded to the transport, which calls it between sending the
        request and waiting for the reply.

        Raises ``ValueError`` for more than six arguments,
        ``ProxyReplyError`` on an opcode mismatch, ``ProxyCommandError`` when
        the opcode is rejected and ``ProxyRemoteError`` for any other non-OK
        status.
        """
        if len(args) > 6:
            raise ValueError("Too many arguments")
        args = list(args) + [0] * (6 - len(args))
        req = struct.pack("<7Q", opcode, *args)
        if self.debug:
            print("<<<< %08x: %08x %08x %08x %08x %08x %08x"%tuple([opcode] + args))
        # pre_reply must be passed through: set_baud() relies on it to switch
        # the host baud rate before the reply arrives.
        reply = self.iface.proxyreq(req, reboot=reboot, no_reply=no_reply, pre_reply=pre_reply)
        if no_reply:
            return
        ret_fmt = "q" if signed else "Q"
        rop, status, retval = struct.unpack("<Qq" + ret_fmt, reply)
        if self.debug:
            print(">>>> %08x: %d %08x"%(rop, status, retval))
        if reboot:
            return
        if rop != opcode:
            raise ProxyReplyError("Reply opcode mismatch: Expected 0x%08x, got 0x%08x"%(opcode,rop))
        if status != self.S_OK:
            if status == self.S_BADCMD:
                raise ProxyCommandError("Reply error: Bad Command")
            else:
                raise ProxyRemoteError("Reply error: Unknown error (%d)"%status)
        return retval

    def request(self, opcode, *args, **kwargs):
        """Send a request, marshalling string/bytes arguments via the heap.

        ``str`` arguments are UTF-8 encoded and NUL-terminated. When
        ``self.heap`` is set, each ``bytes`` argument is copied into a
        temporary remote allocation and replaced by its address; if the
        following argument is ``None`` it is replaced by the byte length.
        Temporary allocations are freed after the request, even on error.
        """
        free = []
        args = list(args)
        args2 = []
        for i, arg in enumerate(args):
            if isinstance(arg, str):
                arg = arg.encode("utf-8") + b"\0"
            if isinstance(arg, bytes) and self.heap:
                p = self.heap.malloc(len(arg))
                free.append(p)
                self.iface.writemem(p, arg)
                # A None right after a buffer is a placeholder for its size;
                # patch it in the list so the next iteration picks it up.
                if (i < (len(args) - 1)) and args[i + 1] is None:
                    args[i + 1] = len(arg)
                arg = p
            args2.append(arg)
        try:
            return self._request(opcode, *args2, **kwargs)
        finally:
            for ptr in free:
                self.heap.free(ptr)

    def nop(self):
        """Send P_NOP."""
        self.request(self.P_NOP)

    def exit(self):
        """Send P_EXIT."""
        self.request(self.P_EXIT)

    def call(self, addr, *args):
        """Send P_CALL for *addr* with up to four arguments; return the result."""
        if len(args) > 4:
            raise ValueError("Too many arguments")
        return self.request(self.P_CALL, addr, *args)

    def reboot(self, addr, *args, el1=False):
        """Hand control to the code at *addr* (up to four arguments).

        With *el1*, P_EL1_CALL is sent without waiting for a reply.
        Otherwise P_VECTOR is tried first, followed by a wait for the boot
        notification; if the remote rejects P_VECTOR, the MMU is shut down
        (when supported) and P_CALL is used instead.
        """
        if len(args) > 4:
            raise ValueError("Too many arguments")
        if el1:
            self.request(self.P_EL1_CALL, addr, *args, no_reply=True)
        else:
            try:
                self.request(self.P_VECTOR, addr, *args)
                self.iface.wait_boot()
            except ProxyCommandError: # old m1n1 does not support P_VECTOR
                try:
                    self.mmu_shutdown()
                except ProxyCommandError: # older m1n1 does not support MMU
                    pass
                self.request(self.P_CALL, addr, *args, reboot=True)

    def get_bootargs(self):
        """Send P_GET_BOOTARGS and return the result."""
        return self.request(self.P_GET_BOOTARGS)

    def get_base(self):
        """Send P_GET_BASE and return the result."""
        return self.request(self.P_GET_BASE)

    def set_baud(self, baudrate):
        """Ask the remote to change baud rate and switch the local port too.

        The local port is reconfigured after the request is sent and before
        the reply is awaited. TTY echo is suppressed during the switch and
        re-enabled afterwards.
        """
        self.iface.tty_enable = False
        def change():
            self.iface.dev.baudrate = baudrate
        try:
            self.request(self.P_SET_BAUD, baudrate, 16, 0x005aa5f0, pre_reply=change)
        finally:
            self.iface.tty_enable = True

    def udelay(self, usec):
        """Send P_UDELAY with *usec*."""
        self.request(self.P_UDELAY, usec)

    def set_exc_guard(self, mode):
        """Send P_SET_EXC_GUARD with *mode* (see the GUARD_* constants)."""
        self.request(self.P_SET_EXC_GUARD, mode)

    def get_exc_count(self):
        """Send P_GET_EXC_COUNT and return the result."""
        return self.request(self.P_GET_EXC_COUNT)

    def el0_call(self, addr, *args):
        """Send P_EL0_CALL for *addr* with up to four arguments; return the result."""
        if len(args) > 4:
            raise ValueError("Too many arguments")
        return self.request(self.P_EL0_CALL, addr, *args)

    def el1_call(self, addr, *args):
        """Send P_EL1_CALL for *addr* with up to four arguments; return the result."""
        if len(args) > 4:
            raise ValueError("Too many arguments")
        return self.request(self.P_EL1_CALL, addr, *args)

    # --- Single-value memory accessors -----------------------------------
    # The 64/32/16-bit variants raise AlignmentError when addr is not a
    # multiple of the access width; the 8-bit variants have no constraint.

    def write64(self, addr, data):
        if addr & 7:
            raise AlignmentError()
        self.request(self.P_WRITE64, addr, data)

    def write32(self, addr, data):
        if addr & 3:
            raise AlignmentError()
        self.request(self.P_WRITE32, addr, data)

    def write16(self, addr, data):
        if addr & 1:
            raise AlignmentError()
        self.request(self.P_WRITE16, addr, data)

    def write8(self, addr, data):
        self.request(self.P_WRITE8, addr, data)

    def read64(self, addr):
        if addr & 7:
            raise AlignmentError()
        return self.request(self.P_READ64, addr)

    def read32(self, addr):
        if addr & 3:
            raise AlignmentError()
        return self.request(self.P_READ32, addr)

    def read16(self, addr):
        if addr & 1:
            raise AlignmentError()
        return self.request(self.P_READ16, addr)

    def read8(self, addr):
        return self.request(self.P_READ8, addr)

    def set64(self, addr, data):
        if addr & 7:
            raise AlignmentError()
        self.request(self.P_SET64, addr, data)

    def set32(self, addr, data):
        if addr & 3:
            raise AlignmentError()
        self.request(self.P_SET32, addr, data)

    def set16(self, addr, data):
        if addr & 1:
            raise AlignmentError()
        self.request(self.P_SET16, addr, data)

    def set8(self, addr, data):
        self.request(self.P_SET8, addr, data)

    def clear64(self, addr, data):
        if addr & 7:
            raise AlignmentError()
        self.request(self.P_CLEAR64, addr, data)

    def clear32(self, addr, data):
        if addr & 3:
            raise AlignmentError()
        self.request(self.P_CLEAR32, addr, data)

    def clear16(self, addr, data):
        if addr & 1:
            raise AlignmentError()
        self.request(self.P_CLEAR16, addr, data)

    def clear8(self, addr, data):
        self.request(self.P_CLEAR8, addr, data)

    def mask64(self, addr, clear, set):
        if addr & 7:
            raise AlignmentError()
        self.request(self.P_MASK64, addr, clear, set)

    def mask32(self, addr, clear, set):
        if addr & 3:
            raise AlignmentError()
        self.request(self.P_MASK32, addr, clear, set)

    def mask16(self, addr, clear, set):
        if addr & 1:
            raise AlignmentError()
        self.request(self.P_MASK16, addr, clear, set)

    def mask8(self, addr, clear, set):
        self.request(self.P_MASK8, addr, clear, set)

    # The writeread variants perform no host-side alignment check.

    def writeread64(self, addr, data):
        return self.request(self.P_WRITEREAD64, addr, data)

    def writeread32(self, addr, data):
        return self.request(self.P_WRITEREAD32, addr, data)

    def writeread16(self, addr, data):
        return self.request(self.P_WRITEREAD16, addr, data)

    def writeread8(self, addr, data):
        return self.request(self.P_WRITEREAD8, addr, data)

    # --- Block memory operations ------------------------------------------
    # memcpy* checks both addresses, memset* only the destination.

    def memcpy64(self, dst, src, size):
        if src & 7 or dst & 7:
            raise AlignmentError()
        self.request(self.P_MEMCPY64, dst, src, size)

    def memcpy32(self, dst, src, size):
        if src & 3 or dst & 3:
            raise AlignmentError()
        self.request(self.P_MEMCPY32, dst, src, size)

    def memcpy16(self, dst, src, size):
        if src & 1 or dst & 1:
            raise AlignmentError()
        self.request(self.P_MEMCPY16, dst, src, size)

    def memcpy8(self, dst, src, size):
        self.request(self.P_MEMCPY8, dst, src, size)

    def memset64(self, dst, src, size):
        if dst & 7:
            raise AlignmentError()
        self.request(self.P_MEMSET64, dst, src, size)

    def memset32(self, dst, src, size):
        if dst & 3:
            raise AlignmentError()
        self.request(self.P_MEMSET32, dst, src, size)

    def memset16(self, dst, src, size):
        if dst & 1:
            raise AlignmentError()
        self.request(self.P_MEMSET16, dst, src, size)

    def memset8(self, dst, src, size):
        self.request(self.P_MEMSET8, dst, src, size)

    # --- Cache maintenance / MMU ------------------------------------------

    def ic_ialluis(self):
        self.request(self.P_IC_IALLUIS)

    def ic_iallu(self):
        self.request(self.P_IC_IALLU)

    def ic_ivau(self, addr, size):
        self.request(self.P_IC_IVAU, addr, size)

    def dc_ivac(self, addr, size):
        """Send P_DC_IVAC for the range *addr*/*size*."""
        self.request(self.P_DC_IVAC, addr, size)

    def ic_ivac(self, addr, size):
        """Backward-compatible alias of ``dc_ivac``.

        The opcode table defines only P_DC_IVAC (there is no P_IC_IVAC), so
        this name forwards to the same request instead of failing with
        ``AttributeError``.
        """
        self.dc_ivac(addr, size)

    def dc_isw(self, sw):
        self.request(self.P_DC_ISW, sw)

    def dc_csw(self, sw):
        self.request(self.P_DC_CSW, sw)

    def dc_cisw(self, sw):
        self.request(self.P_DC_CISW, sw)

    def dc_zva(self, addr, size):
        self.request(self.P_DC_ZVA, addr, size)

    def dc_cvac(self, addr, size):
        self.request(self.P_DC_CVAC, addr, size)

    def dc_cvau(self, addr, size):
        self.request(self.P_DC_CVAU, addr, size)

    def dc_civac(self, addr, size):
        self.request(self.P_DC_CIVAC, addr, size)

    def mmu_shutdown(self):
        self.request(self.P_MMU_SHUTDOWN)

    # --- Decompression (results are interpreted as signed) -------------------

    def xzdec(self, inbuf, insize, outbuf=0, outsize=0):
        return self.request(self.P_XZDEC, inbuf, insize, outbuf,
                            outsize, signed=True)

    def gzdec(self, inbuf, insize, outbuf, outsize):
        return self.request(self.P_GZDEC, inbuf, insize, outbuf,
                            outsize, signed=True)

    # --- SMP ------------------------------------------------------------------

    def smp_start_secondaries(self):
        self.request(self.P_SMP_START_SECONDARIES)

    def smp_call(self, cpu, addr, *args):
        """Send P_SMP_CALL for *cpu*/*addr* with up to four arguments."""
        if len(args) > 4:
            raise ValueError("Too many arguments")
        self.request(self.P_SMP_CALL, cpu, addr, *args)

    def smp_call_sync(self, cpu, addr, *args):
        """Send P_SMP_CALL_SYNC for *cpu*/*addr*; return the result."""
        if len(args) > 4:
            raise ValueError("Too many arguments")
        return self.request(self.P_SMP_CALL_SYNC, cpu, addr, *args)

    # --- Remote heap ------------------------------------------------------------

    def heapblock_alloc(self, size):
        return self.request(self.P_HEAPBLOCK_ALLOC, size)

    def malloc(self, size):
        return self.request(self.P_MALLOC, size)

    def memalign(self, align, size):
        return self.request(self.P_MEMALIGN, align, size)

    def free(self, ptr):
        self.request(self.P_FREE, ptr)

    # --- Kernel boot --------------------------------------------------------------

    def kboot_boot(self, kernel):
        self.request(self.P_KBOOT_BOOT, kernel)

    def kboot_set_bootargs(self, bootargs):
        self.request(self.P_KBOOT_SET_BOOTARGS, bootargs)

    def kboot_set_initrd(self, base, size):
        self.request(self.P_KBOOT_SET_INITRD, base, size)

    def kboot_prepare_dt(self, dt_addr):
        return self.request(self.P_KBOOT_PREPARE_DT, dt_addr)

    # --- Power manager --------------------------------------------------------------

    def pmgr_clock_enable(self, clkid):
        return self.request(self.P_PMGR_CLOCK_ENABLE, clkid)

    def pmgr_clock_disable(self, clkid):
        return self.request(self.P_PMGR_CLOCK_DISABLE, clkid)

    def pmgr_adt_clocks_enable(self, path):
        return self.request(self.P_PMGR_ADT_CLOCKS_ENABLE, path)

    def pmgr_adt_clocks_disable(self, path):
        return self.request(self.P_PMGR_ADT_CLOCKS_DISABLE, path)

    # --- I/O devices ------------------------------------------------------------------
    # For iodev_read/iodev_write, size=None is filled in with the buffer
    # length by request() when buf is str/bytes and a heap is configured.

    def iodev_set_usage(self, iodev, usage):
        return self.request(self.P_IODEV_SET_USAGE, iodev, usage)

    def iodev_can_read(self, iodev):
        return self.request(self.P_IODEV_CAN_READ, iodev)

    def iodev_can_write(self, iodev):
        return self.request(self.P_IODEV_CAN_WRITE, iodev)

    def iodev_read(self, iodev, buf, size=None):
        return self.request(self.P_IODEV_READ, iodev, buf, size)

    def iodev_write(self, iodev, buf, size=None):
        return self.request(self.P_IODEV_WRITE, iodev, buf, size)

    # --- Tunables -----------------------------------------------------------------------

    def tunables_apply_global(self, path, prop):
        return self.request(self.P_TUNABLES_APPLY_GLOBAL, path, prop)

    def tunables_apply_local(self, path, prop, reg_offset):
        return self.request(self.P_TUNABLES_APPLY_LOCAL, path, prop, reg_offset)

    def tunables_apply_local_addr(self, path, prop, base):
        # NOTE(review): this sends the same opcode as tunables_apply_local
        # although the third argument is a base address rather than a
        # register offset - confirm whether a dedicated opcode is expected.
        return self.request(self.P_TUNABLES_APPLY_LOCAL, path, prop, base)

    # --- DART ---------------------------------------------------------------------------

    def dart_init(self, base, sid):
        return self.request(self.P_DART_INIT, base, sid)

    def dart_shutdown(self, dart):
        return self.request(self.P_DART_SHUTDOWN, dart)

    def dart_map(self, dart, iova, bfr, len):
        return self.request(self.P_DART_MAP, dart, iova, bfr, len)

    def dart_unmap(self, dart, iova, len):
        return self.request(self.P_DART_UNMAP, dart, iova, len)
if __name__ == "__main__":
    # Smoke test: open the serial port named by $M1N1DEVICE (default
    # /dev/ttyUSB0) at 115200 baud, then exercise the transport-level NOP,
    # the proxy-level NOP and one real query, with debug dumps enabled.
    import serial

    uartdev = os.environ.get("M1N1DEVICE", "/dev/ttyUSB0")
    uartif = UartInterface(serial.Serial(uartdev, 115200), debug=True)

    print("Sending NOP...", end=' ')
    uartif.nop()
    print("OK")

    proxy = M1N1Proxy(uartif, debug=True)

    print("Sending Proxy NOP...", end=' ')
    proxy.nop()
    print("OK")

    bootargs = proxy.get_bootargs()
    print("Boot args: 0x%x" % bootargs)