#!/usr/bin/env python3 # SPDX-License-Identifier: MIT import os, sys, struct, serial, time from construct import * from utils import * from sysreg import * from enum import IntEnum, IntFlag from serial.tools.miniterm import Miniterm def hexdump(s, sep=" "): return sep.join(["%02x"%x for x in s]) def hexdump32(s, sep=" "): vals = struct.unpack("<%dI" % (len(s)//4), s) return sep.join(["%08x"%x for x in vals]) def ascii(s): s2 = "" for c in s: if c < 0x20 or c > 0x7e: s2 += "." else: s2 += chr(c) return s2 def pad(s,c,l): if len(s) < l: s += c * (l - len(s)) return s def chexdump(s,st=0): for i in range(0,len(s),16): print("%08x %s %s |%s|" % ( i + st, hexdump(s[i:i+8], ' ').rjust(23), hexdump(s[i+8:i+16], ' ').rjust(23), ascii(s[i:i+16]).rjust(16))) def chexdump32(s, st=0, abbreviate=True): last = None skip = False for i in range(0,len(s),32): val = s[i:i+32] if val == last and abbreviate: if not skip: print("%08x *" % (i + st)) skip = True else: print("%08x %s" % ( i + st, hexdump32(val, ' '))) last = val skip = False # Hack to disable input buffer flushing class Serial(serial.Serial): def _reset_input_buffer(self): return def reset_input_buffer(self): super()._reset_input_buffer() class UartError(RuntimeError): pass class UartTimeout(UartError): pass class UartCMDError(UartError): pass class UartChecksumError(UartError): pass class UartRemoteError(UartError): pass class START(IntEnum): BOOT = 0 EXCEPTION = 1 EXCEPTION_LOWER = 2 class EXC(IntEnum): SYNC = 0 IRQ = 1 FIQ = 2 SERROR = 3 class EXC_RET(IntEnum): UNHANDLED = 1 HANDLED = 2 EXIT_GUEST = 3 ExcInfo = Struct( "spsr" / RegAdapter(SPSR), "elr" / Int64ul, "esr" / RegAdapter(ESR), "far" / Int64ul, "regs" / Array(31, Int64ul), "sp" / Array(3, Int64ul), "mpidr" / Int64ul, "elr_phys" / Int64ul, "far_phys" / Int64ul, "sp_phys" / Int64ul, ) class UartInterface: REQ_NOP = 0x00AA55FF REQ_PROXY = 0x01AA55FF REQ_MEMREAD = 0x02AA55FF REQ_MEMWRITE = 0x03AA55FF REQ_BOOT = 0x04AA55FF ST_OK = 0 ST_BADCMD = -1 ST_INVAL = -2 ST_XFERERR = -3 ST_CSUMERR = -4 CMD_LEN = 56 REPLY_LEN = 36 def __init__(self, device=None, debug=False): self.debug = debug self.devpath = None if device is None: device = os.environ.get("M1N1DEVICE", "/dev/ttyUSB0:115200") if isinstance(device, str): baud = 115200 if ":" in device: device, baud = device.rsplit(":", 1) baud = int(baud) self.devpath = device self.baudrate = baud device = Serial(self.devpath, baud) self.dev = device self.dev.timeout = 0 self.dev.flushOutput() self.dev.flushInput() self.pted = False #d = self.dev.read(1) #while d != "": #d = self.dev.read(1) self.dev.timeout = 3 self.tty_enable = True self.handlers = {} def checksum(self, data): sum = 0xDEADBEEF; for c in data: sum *= 31337 sum += c ^ 0x5a sum &= 0xFFFFFFFF return (sum ^ 0xADDEDBAD) & 0xFFFFFFFF def readfull(self, size): d = b'' while len(d) < size: block = self.dev.read(size - len(d)) if not block: raise UartTimeout("Expected %d bytes, got %d bytes"%(size,len(d))) d += block return d def cmd(self, cmd, payload=b""): if len(payload) > self.CMD_LEN: raise ValueError("Incorrect payload size %d"%len(payload)) payload = payload.ljust(self.CMD_LEN, b"\x00") command = struct.pack(" ") self.pted = True if c == 10: self.pted = False sys.stdout.write(chr(c)) sys.stdout.flush() def ttymode(self, dev=None): if dev is None: dev = self.dev tout = dev.timeout self.tty_enable = True dev.timeout = None term = Miniterm(dev, eol='cr') term.exit_character = chr(0x1d) # GS/CTRL+] term.menu_character = chr(0x14) # Menu: CTRL+T term.raw = True term.set_rx_encoding('UTF-8') term.set_tx_encoding('UTF-8') print('--- TTY mode | Quit: CTRL+] | Menu: CTRL+T ---') term.start() try: term.join(True) except KeyboardInterrupt: pass print('--- Exit TTY mode ---') term.join() term.close() dev.timeout = tout self.tty_enable = False def reply(self, cmd): reply = b'' while True: if not reply or reply[-1] != 255: reply = b'' reply += self.readfull(1) if reply != b"\xff": self.unkhandler(reply) continue else: reply = b'\xff' reply += self.readfull(1) if reply != b"\xff\x55": self.unkhandler(reply) continue reply += self.readfull(1) if reply != b"\xff\x55\xaa": self.unkhandler(reply) continue reply += self.readfull(self.REPLY_LEN - 3) if self.debug: print(">>", hexdump(reply)) cmdin, status, data, checksum = struct.unpack("> DATA:") chexdump(data) ccsum = self.checksum(data) if checksum != ccsum: raise UartChecksumError("Reply data checksum error: Expected 0x%08x, got 0x%08x"%(checksum, ccsum)) return data def readstruct(self, addr, stype): return stype.parse(self.readmem(addr, stype.sizeof())) class ProxyError(RuntimeError): pass class ProxyReplyError(ProxyError): pass class ProxyRemoteError(ProxyError): pass class ProxyCommandError(ProxyRemoteError): pass class AlignmentError(Exception): pass class IODEV(IntEnum): UART = 0 FB = 1 USB0 = 2 USB1 = 3 class USAGE(IntFlag): CONSOLE = (1 << 0) UARTPROXY = (1 << 1) class GUARD(IntFlag): OFF = 0 SKIP = 1 MARK = 2 RETURN = 3 SILENT = 0x100 class M1N1Proxy: S_OK = 0 S_BADCMD = -1 P_NOP = 0x000 P_EXIT = 0x001 P_CALL = 0x002 P_GET_BOOTARGS = 0x003 P_GET_BASE = 0x004 P_SET_BAUD = 0x005 P_UDELAY = 0x006 P_SET_EXC_GUARD = 0x007 P_GET_EXC_COUNT = 0x008 P_EL0_CALL = 0x009 P_EL1_CALL = 0x00a P_VECTOR = 0x00b P_WRITE64 = 0x100 P_WRITE32 = 0x101 P_WRITE16 = 0x102 P_WRITE8 = 0x103 P_READ64 = 0x104 P_READ32 = 0x105 P_READ16 = 0x106 P_READ8 = 0x107 P_SET64 = 0x108 P_SET32 = 0x109 P_SET16 = 0x10a P_SET8 = 0x10b P_CLEAR64 = 0x10c P_CLEAR32 = 0x10d P_CLEAR16 = 0x10e P_CLEAR8 = 0x10f P_MASK64 = 0x110 P_MASK32 = 0x111 P_MASK16 = 0x112 P_MASK8 = 0x113 P_WRITEREAD64 = 0x114 P_WRITEREAD32 = 0x115 P_WRITEREAD16 = 0x116 P_WRITEREAD8 = 0x117 P_MEMCPY64 = 0x200 P_MEMCPY32 = 0x201 P_MEMCPY16 = 0x202 P_MEMCPY8 = 0x203 P_MEMSET64 = 0x204 P_MEMSET32 = 0x205 P_MEMSET16 = 0x206 P_MEMSET8 = 0x207 P_IC_IALLUIS = 0x300 P_IC_IALLU = 0x301 P_IC_IVAU = 0x302 P_DC_IVAC = 0x303 P_DC_ISW = 0x304 P_DC_CSW = 0x305 P_DC_CISW = 0x306 P_DC_ZVA = 0x307 P_DC_CVAC = 0x308 P_DC_CVAU = 0x309 P_DC_CIVAC = 0x30a P_MMU_SHUTDOWN = 0x30b P_XZDEC = 0x400 P_GZDEC = 0x401 P_SMP_START_SECONDARIES = 0x500 P_SMP_CALL = 0x501 P_SMP_CALL_SYNC = 0x502 P_HEAPBLOCK_ALLOC = 0x600 P_MALLOC = 0x601 P_MEMALIGN = 0x602 P_FREE = 0x602 P_KBOOT_BOOT = 0x700 P_KBOOT_SET_BOOTARGS = 0x701 P_KBOOT_SET_INITRD = 0x702 P_KBOOT_PREPARE_DT = 0x703 P_PMGR_CLOCK_ENABLE = 0x800 P_PMGR_CLOCK_DISABLE = 0x801 P_PMGR_ADT_CLOCKS_ENABLE = 0x802 P_PMGR_ADT_CLOCKS_DISABLE = 0x803 P_IODEV_SET_USAGE = 0x900 P_IODEV_CAN_READ = 0x901 P_IODEV_CAN_WRITE = 0x902 P_IODEV_READ = 0x903 P_IODEV_WRITE = 0x904 P_IODEV_WHOAMI = 0x905 P_TUNABLES_APPLY_GLOBAL = 0xa00 P_TUNABLES_APPLY_LOCAL = 0xa01 P_DART_INIT = 0xb00 P_DART_SHUTDOWN = 0xb01 P_DART_MAP = 0xb02 P_DART_UNMAP = 0xb03 P_HV_INIT = 0xc00 P_HV_MAP = 0xc01 P_HV_START = 0xc02 P_HV_TRANSLATE = 0xc03 def __init__(self, iface, debug=False): self.debug = debug self.iface = iface self.heap = None def _request(self, opcode, *args, reboot=False, signed=False, no_reply=False, pre_reply=None): if len(args) > 6: raise ValueError("Too many arguments") args = list(args) + [0] * (6 - len(args)) req = struct.pack("<7Q", opcode, *args) if self.debug: print("<<<< %08x: %08x %08x %08x %08x %08x %08x"%tuple([opcode] + args)) reply = self.iface.proxyreq(req, reboot=reboot, no_reply=no_reply, pre_reply=None) if no_reply or reboot and reply is None: return ret_fmt = "q" if signed else "Q" rop, status, retval = struct.unpack(">>> %08x: %d %08x"%(rop, status, retval)) if reboot: return if rop != opcode: raise ProxyReplyError("Reply opcode mismatch: Expected 0x%08x, got 0x%08x"%(opcode,rop)) if status != self.S_OK: if status == self.S_BADCMD: raise ProxyCommandError("Reply error: Bad Command") else: raise ProxyRemoteError("Reply error: Unknown error (%d)"%status) return retval def request(self, opcode, *args, **kwargs): free = [] args = list(args) args2 = [] for i, arg in enumerate(args): if isinstance(arg, str): arg = arg.encode("utf-8") + b"\0" if isinstance(arg, bytes) and self.heap: p = self.heap.malloc(len(arg)) free.append(p) self.iface.writemem(p, arg) if (i < (len(args) - 1)) and args[i + 1] is None: args[i + 1] = len(arg) arg = p args2.append(arg) try: return self._request(opcode, *args2, **kwargs) finally: for i in free: self.heap.free(i) def nop(self): self.request(self.P_NOP) def exit(self, retval=0): self.request(self.P_EXIT, retval) def call(self, addr, *args, reboot=False): if len(args) > 4: raise ValueError("Too many arguments") return self.request(self.P_CALL, addr, *args, reboot=reboot) def reboot(self, addr, *args, el1=False): if len(args) > 4: raise ValueError("Too many arguments") if el1: self.request(self.P_EL1_CALL, addr, *args, no_reply=True) else: try: self.request(self.P_VECTOR, addr, *args) self.iface.wait_boot() except ProxyCommandError: # old m1n1 does not support P_VECTOR try: self.mmu_shutdown() except ProxyCommandError: # older m1n1 does not support MMU pass self.request(self.P_CALL, addr, *args, reboot=True) def get_bootargs(self): return self.request(self.P_GET_BOOTARGS) def get_base(self): return self.request(self.P_GET_BASE) def set_baud(self, baudrate): self.iface.tty_enable = False def change(): self.iface.dev.baudrate = baudrate try: self.request(self.P_SET_BAUD, baudrate, 16, 0x005aa5f0, pre_reply=change) finally: self.iface.tty_enable = True def udelay(self, usec): self.request(self.P_UDELAY, usec) def set_exc_guard(self, mode): self.request(self.P_SET_EXC_GUARD, mode) def get_exc_count(self): return self.request(self.P_GET_EXC_COUNT) def el0_call(self, addr, *args): if len(args) > 4: raise ValueError("Too many arguments") return self.request(self.P_EL0_CALL, addr, *args) def el1_call(self, addr, *args): if len(args) > 4: raise ValueError("Too many arguments") return self.request(self.P_EL1_CALL, addr, *args) def write64(self, addr, data): if addr & 7: raise AlignmentError() self.request(self.P_WRITE64, addr, data) def write32(self, addr, data): if addr & 3: raise AlignmentError() self.request(self.P_WRITE32, addr, data) def write16(self, addr, data): if addr & 1: raise AlignmentError() self.request(self.P_WRITE16, addr, data) def write8(self, addr, data): self.request(self.P_WRITE8, addr, data) def read64(self, addr): if addr & 7: raise AlignmentError() return self.request(self.P_READ64, addr) def read32(self, addr): if addr & 3: raise AlignmentError() return self.request(self.P_READ32, addr) def read16(self, addr): if addr & 1: raise AlignmentError() return self.request(self.P_READ16, addr) def read8(self, addr): return self.request(self.P_READ8, addr) def set64(self, addr, data): if addr & 7: raise AlignmentError() self.request(self.P_SET64, addr, data) def set32(self, addr, data): if addr & 3: raise AlignmentError() self.request(self.P_SET32, addr, data) def set16(self, addr, data): if addr & 1: raise AlignmentError() self.request(self.P_SET16, addr, data) def set8(self, addr, data): self.request(self.P_SET8, addr, data) def clear64(self, addr, data): if addr & 7: raise AlignmentError() self.request(self.P_CLEAR64, addr, data) def clear32(self, addr, data): if addr & 3: raise AlignmentError() self.request(self.P_CLEAR32, addr, data) def clear16(self, addr, data): if addr & 1: raise AlignmentError() self.request(self.P_CLEAR16, addr, data) def clear8(self, addr, data): self.request(self.P_CLEAR8, addr, data) def mask64(self, addr, clear, set): if addr & 7: raise AlignmentError() self.request(self.P_MASK64, addr, clear, set) def mask32(self, addr, clear, set): if addr & 3: raise AlignmentError() self.request(self.P_MASK32, addr, clear, set) def mask16(self, addr, clear, set): if addr & 1: raise AlignmentError() self.request(self.P_MASK16, addr, clear, set) def mask8(self, addr, clear, set): self.request(self.P_MASK8, addr, clear, set) def writeread64(self, addr, data): return self.request(self.P_WRITEREAD64, addr, data) def writeread32(self, addr, data): return self.request(self.P_WRITEREAD32, addr, data) def writeread16(self, addr, data): return self.request(self.P_WRITEREAD16, addr, data) def writeread8(self, addr, data): return self.request(self.P_WRITEREAD8, addr, data) def memcpy64(self, dst, src, size): if src & 7 or dst & 7: raise AlignmentError() self.request(self.P_MEMCPY64, dst, src, size) def memcpy32(self, dst, src, size): if src & 3 or dst & 3: raise AlignmentError() self.request(self.P_MEMCPY32, dst, src, size) def memcpy16(self, dst, src, size): if src & 1 or dst & 1: raise AlignmentError() self.request(self.P_MEMCPY16, dst, src, size) def memcpy8(self, dst, src, size): self.request(self.P_MEMCPY8, dst, src, size) def memset64(self, dst, src, size): if dst & 7: raise AlignmentError() self.request(self.P_MEMSET64, dst, src, size) def memset32(self, dst, src, size): if dst & 3: raise AlignmentError() self.request(self.P_MEMSET32, dst, src, size) def memset16(self, dst, src, size): if dst & 1: raise AlignmentError() self.request(self.P_MEMSET16, dst, src, size) def memset8(self, dst, src, size): self.request(self.P_MEMSET8, dst, src, size) def ic_ialluis(self): self.request(self.P_IC_IALLUIS) def ic_iallu(self): self.request(self.P_IC_IALLU) def ic_ivau(self, addr, size): self.request(self.P_IC_IVAU, addr, size) def ic_ivac(self, addr, size): self.request(self.P_IC_IVAC, addr, size) def dc_isw(self, sw): self.request(self.P_DC_ISW, sw) def dc_csw(self, sw): self.request(self.P_DC_CSW, sw) def dc_cisw(self, sw): self.request(self.P_DC_CISW, sw) def dc_zva(self, addr, size): self.request(self.P_DC_ZVA, addr, size) def dc_cvac(self, addr, size): self.request(self.P_DC_CVAC, addr, size) def dc_cvau(self, addr, size): self.request(self.P_DC_CVAU, addr, size) def dc_civac(self, addr, size): self.request(self.P_DC_CIVAC, addr, size) def mmu_shutdown(self): self.request(self.P_MMU_SHUTDOWN) def xzdec(self, inbuf, insize, outbuf=0, outsize=0): return self.request(self.P_XZDEC, inbuf, insize, outbuf, outsize, signed=True) def gzdec(self, inbuf, insize, outbuf, outsize): return self.request(self.P_GZDEC, inbuf, insize, outbuf, outsize, signed=True) def smp_start_secondaries(self): self.request(self.P_SMP_START_SECONDARIES) def smp_call(self, cpu, addr, *args): if len(args) > 4: raise ValueError("Too many arguments") self.request(self.P_SMP_CALL, cpu, addr, *args) def smp_call_sync(self, cpu, addr, *args): if len(args) > 4: raise ValueError("Too many arguments") return self.request(self.P_SMP_CALL_SYNC, cpu, addr, *args) def heapblock_alloc(self, size): return self.request(self.P_HEAPBLOCK_ALLOC, size) def malloc(self, size): return self.request(self.P_MALLOC, size) def memalign(self, align, size): return self.request(self.P_MEMALIGN, align, size) def free(self, ptr): self.request(self.P_FREE, ptr) def kboot_boot(self, kernel): self.request(self.P_KBOOT_BOOT, kernel) def kboot_set_bootargs(self, bootargs): self.request(self.P_KBOOT_SET_BOOTARGS, bootargs) def kboot_set_initrd(self, base, size): self.request(self.P_KBOOT_SET_INITRD, base, size) def kboot_prepare_dt(self, dt_addr): return self.request(self.P_KBOOT_PREPARE_DT, dt_addr) def pmgr_clock_enable(self, clkid): return self.request(self.P_PMGR_CLOCK_ENABLE, clkid) def pmgr_clock_disable(self, clkid): return self.request(self.P_PMGR_CLOCK_DISABLE, clkid) def pmgr_adt_clocks_enable(self, path): return self.request(self.P_PMGR_ADT_CLOCKS_ENABLE, path) def pmgr_adt_clocks_disable(self, path): return self.request(self.P_PMGR_ADT_CLOCKS_DISABLE, path) def iodev_set_usage(self, iodev, usage): return self.request(self.P_IODEV_SET_USAGE, iodev, usage) def iodev_can_read(self, iodev): return self.request(self.P_IODEV_CAN_READ, iodev) def iodev_can_write(self, iodev): return self.request(self.P_IODEV_CAN_WRITE, iodev) def iodev_read(self, iodev, buf, size=None): return self.request(self.P_IODEV_READ, iodev, buf, size) def iodev_write(self, iodev, buf, size=None): return self.request(self.P_IODEV_WRITE, iodev, buf, size) def iodev_whoami(self): return IODEV(self.request(self.P_IODEV_WHOAMI)) def tunables_apply_global(self, path, prop): return self.request(self.P_TUNABLES_APPLY_GLOBAL, path, prop) def tunables_apply_local(self, path, prop, reg_offset): return self.request(self.P_TUNABLES_APPLY_LOCAL, path, prop, reg_offset) def tunables_apply_local_addr(self, path, prop, base): return self.request(self.P_TUNABLES_APPLY_LOCAL, path, prop, base) def dart_init(self, base, sid): return self.request(self.P_DART_INIT, base, sid) def dart_shutdown(self, dart): return self.request(self.P_DART_SHUTDOWN, dart) def dart_map(self, dart, iova, bfr, len): return self.request(self.P_DART_MAP, dart, iova, bfr, len) def dart_unmap(self, dart, iova, len): return self.request(self.P_DART_UNMAP, dart, iova, len) def hv_init(self): return self.request(self.P_HV_INIT) def hv_map(self, from_, to, size, incr): return self.request(self.P_HV_MAP, from_, to, size, incr) def hv_start(self, entry, *args): return self.request(self.P_HV_START, entry, *args) def hv_translate(self, addr, s1, w): return self.request(self.P_HV_TRANSLATE, addr, s1, w) if __name__ == "__main__": import serial uartdev = os.environ.get("M1N1DEVICE", "/dev/ttyUSB0") usbuart = serial.Serial(uartdev, 115200) uartif = UartInterface(usbuart, debug=True) print("Sending NOP...", end=' ') uartif.nop() print("OK") proxy = M1N1Proxy(uartif, debug=True) print("Sending Proxy NOP...", end=' ') proxy.nop() print("OK") print("Boot args: 0x%x" % proxy.get_bootargs())