sm64

A Super Mario 64 decompilation
Log | Files | Refs | README | LICENSE

diff.py (48905B)


      1 #!/usr/bin/env python3
      2 # PYTHON_ARGCOMPLETE_OK
      3 import argparse
      4 import sys
      5 from typing import (
      6     Any,
      7     Dict,
      8     List,
      9     Match,
     10     NamedTuple,
     11     NoReturn,
     12     Optional,
     13     Set,
     14     Tuple,
     15     Union,
     16     Callable,
     17     Pattern,
     18 )
     19 
     20 
     21 def fail(msg: str) -> NoReturn:
     22     print(msg, file=sys.stderr)
     23     sys.exit(1)
     24 
     25 
     26 # Prefer to use diff_settings.py from the current working directory
     27 sys.path.insert(0, ".")
     28 try:
     29     import diff_settings
     30 except ModuleNotFoundError:
     31     fail("Unable to find diff_settings.py in the same directory.")
     32 sys.path.pop(0)
     33 
     34 # ==== COMMAND-LINE ====
     35 
     36 try:
     37     import argcomplete  # type: ignore
     38 except ModuleNotFoundError:
     39     argcomplete = None
     40 
     41 parser = argparse.ArgumentParser(description="Diff MIPS or AArch64 assembly.")
     42 
     43 start_argument = parser.add_argument(
     44     "start",
     45     help="Function name or address to start diffing from.",
     46 )
     47 
     48 if argcomplete:
     49 
     50     def complete_symbol(
     51         prefix: str, parsed_args: argparse.Namespace, **kwargs: object
     52     ) -> List[str]:
     53         if not prefix or prefix.startswith("-"):
     54             # skip reading the map file, which would
     55             # result in a lot of useless completions
     56             return []
     57         config: Dict[str, Any] = {}
     58         diff_settings.apply(config, parsed_args)  # type: ignore
     59         mapfile = config.get("mapfile")
     60         if not mapfile:
     61             return []
     62         completes = []
     63         with open(mapfile) as f:
     64             data = f.read()
     65             # assume symbols are prefixed by a space character
     66             search = f" {prefix}"
     67             pos = data.find(search)
     68             while pos != -1:
     69                 # skip the space character in the search string
     70                 pos += 1
     71                 # assume symbols are suffixed by either a space
     72                 # character or a (unix-style) line return
     73                 spacePos = data.find(" ", pos)
     74                 lineReturnPos = data.find("\n", pos)
     75                 if lineReturnPos == -1:
     76                     endPos = spacePos
     77                 elif spacePos == -1:
     78                     endPos = lineReturnPos
     79                 else:
     80                     endPos = min(spacePos, lineReturnPos)
     81                 if endPos == -1:
     82                     match = data[pos:]
     83                     pos = -1
     84                 else:
     85                     match = data[pos:endPos]
     86                     pos = data.find(search, endPos)
     87                 completes.append(match)
     88         return completes
     89 
     90     setattr(start_argument, "completer", complete_symbol)
     91 
     92 parser.add_argument(
     93     "end",
     94     nargs="?",
     95     help="Address to end diff at.",
     96 )
     97 parser.add_argument(
     98     "-o",
     99     dest="diff_obj",
    100     action="store_true",
    101     help="Diff .o files rather than a whole binary. This makes it possible to "
    102     "see symbol names. (Recommended)",
    103 )
    104 parser.add_argument(
    105     "--elf",
    106     dest="diff_elf_symbol",
    107     metavar="SYMBOL",
    108     help="Diff a given function in two ELFs, one being stripped and the other "
    109     "one non-stripped. Requires objdump from binutils 2.33+.",
    110 )
    111 parser.add_argument(
    112     "--source",
    113     action="store_true",
    114     help="Show source code (if possible). Only works with -o and -e.",
    115 )
    116 parser.add_argument(
    117     "--inlines",
    118     action="store_true",
    119     help="Show inline function calls (if possible). Only works with -o and -e.",
    120 )
    121 parser.add_argument(
    122     "--base-asm",
    123     dest="base_asm",
    124     metavar="FILE",
    125     help="Read assembly from given file instead of configured base img.",
    126 )
    127 parser.add_argument(
    128     "--write-asm",
    129     dest="write_asm",
    130     metavar="FILE",
    131     help="Write the current assembly output to file, e.g. for use with --base-asm.",
    132 )
    133 parser.add_argument(
    134     "-m",
    135     "--make",
    136     dest="make",
    137     action="store_true",
    138     help="Automatically run 'make' on the .o file or binary before diffing.",
    139 )
    140 parser.add_argument(
    141     "-l",
    142     "--skip-lines",
    143     dest="skip_lines",
    144     type=int,
    145     default=0,
    146     metavar="LINES",
    147     help="Skip the first N lines of output.",
    148 )
    149 parser.add_argument(
    150     "-f",
    151     "--stop-jr-ra",
    152     dest="stop_jrra",
    153     action="store_true",
    154     help="Stop disassembling at the first 'jr ra'. Some functions have multiple return points, so use with care!",
    155 )
    156 parser.add_argument(
    157     "-i",
    158     "--ignore-large-imms",
    159     dest="ignore_large_imms",
    160     action="store_true",
    161     help="Pretend all large enough immediates are the same.",
    162 )
    163 parser.add_argument(
    164     "-I",
    165     "--ignore-addr-diffs",
    166     action="store_true",
    167     help="Ignore address differences. Currently only affects AArch64.",
    168 )
    169 parser.add_argument(
    170     "-B",
    171     "--no-show-branches",
    172     dest="show_branches",
    173     action="store_false",
    174     help="Don't visualize branches/branch targets.",
    175 )
    176 parser.add_argument(
    177     "-S",
    178     "--base-shift",
    179     dest="base_shift",
    180     type=str,
    181     default="0",
    182     help="Diff position X in our img against position X + shift in the base img. "
    183     'Arithmetic is allowed, so e.g. |-S "0x1234 - 0x4321"| is a reasonable '
    184     "flag to pass if it is known that position 0x1234 in the base img syncs "
    185     "up with position 0x4321 in our img. Not supported together with -o.",
    186 )
    187 parser.add_argument(
    188     "-w",
    189     "--watch",
    190     dest="watch",
    191     action="store_true",
    192     help="Automatically update when source/object files change. "
    193     "Recommended in combination with -m.",
    194 )
    195 parser.add_argument(
    196     "-3",
    197     "--threeway=prev",
    198     dest="threeway",
    199     action="store_const",
    200     const="prev",
    201     help="Show a three-way diff between target asm, current asm, and asm "
    202     "prior to -w rebuild. Requires -w.",
    203 )
    204 parser.add_argument(
    205     "-b",
    206     "--threeway=base",
    207     dest="threeway",
    208     action="store_const",
    209     const="base",
    210     help="Show a three-way diff between target asm, current asm, and asm "
    211     "when diff.py was started. Requires -w.",
    212 )
    213 parser.add_argument(
    214     "--width",
    215     dest="column_width",
    216     type=int,
    217     default=50,
    218     help="Sets the width of the left and right view column.",
    219 )
    220 parser.add_argument(
    221     "--algorithm",
    222     dest="algorithm",
    223     default="levenshtein",
    224     choices=["levenshtein", "difflib"],
    225     help="Diff algorithm to use. Levenshtein gives the minimum diff, while difflib "
    226     "aims for long sections of equal opcodes. Defaults to %(default)s.",
    227 )
    228 parser.add_argument(
    229     "--max-size",
    230     "--max-lines",
    231     dest="max_lines",
    232     type=int,
    233     default=1024,
    234     help="The maximum length of the diff, in lines.",
    235 )
    236 
    237 # Project-specific flags, e.g. different versions/make arguments.
    238 add_custom_arguments_fn = getattr(diff_settings, "add_custom_arguments", None)
    239 if add_custom_arguments_fn:
    240     add_custom_arguments_fn(parser)
    241 
    242 if argcomplete:
    243     argcomplete.autocomplete(parser)
    244 
    245 # ==== IMPORTS ====
    246 
    247 # (We do imports late to optimize auto-complete performance.)
    248 
    249 import re
    250 import os
    251 import ast
    252 import subprocess
    253 import difflib
    254 import string
    255 import itertools
    256 import threading
    257 import queue
    258 import time
    259 
    260 
    261 MISSING_PREREQUISITES = (
    262     "Missing prerequisite python module {}. "
    263     "Run `python3 -m pip install --user colorama ansiwrap watchdog python-Levenshtein cxxfilt` to install prerequisites (cxxfilt only needed with --source)."
    264 )
    265 
    266 try:
    267     from colorama import Fore, Style, Back  # type: ignore
    268     import ansiwrap  # type: ignore
    269     import watchdog  # type: ignore
    270 except ModuleNotFoundError as e:
    271     fail(MISSING_PREREQUISITES.format(e.name))
    272 
    273 # ==== CONFIG ====
    274 
    275 args = parser.parse_args()
    276 
    277 # Set imgs, map file and make flags in a project-specific manner.
    278 config: Dict[str, Any] = {}
    279 diff_settings.apply(config, args)  # type: ignore
    280 
    281 arch: str = config.get("arch", "mips")
    282 baseimg: Optional[str] = config.get("baseimg")
    283 myimg: Optional[str] = config.get("myimg")
    284 mapfile: Optional[str] = config.get("mapfile")
    285 makeflags: List[str] = config.get("makeflags", [])
    286 source_directories: Optional[List[str]] = config.get("source_directories")
    287 objdump_executable: Optional[str] = config.get("objdump_executable")
    288 
    289 MAX_FUNCTION_SIZE_LINES: int = args.max_lines
    290 MAX_FUNCTION_SIZE_BYTES: int = MAX_FUNCTION_SIZE_LINES * 4
    291 
    292 COLOR_ROTATION: List[str] = [
    293     Fore.MAGENTA,
    294     Fore.CYAN,
    295     Fore.GREEN,
    296     Fore.RED,
    297     Fore.LIGHTYELLOW_EX,
    298     Fore.LIGHTMAGENTA_EX,
    299     Fore.LIGHTCYAN_EX,
    300     Fore.LIGHTGREEN_EX,
    301     Fore.LIGHTBLACK_EX,
    302 ]
    303 
    304 BUFFER_CMD: List[str] = ["tail", "-c", str(10 ** 9)]
    305 LESS_CMD: List[str] = ["less", "-SRic", "-#6"]
    306 
    307 DEBOUNCE_DELAY: float = 0.1
    308 FS_WATCH_EXTENSIONS: List[str] = [".c", ".h", ".s"]
    309 
    310 # ==== LOGIC ====
    311 
    312 ObjdumpCommand = Tuple[List[str], str, Optional[str]]
    313 
    314 if args.algorithm == "levenshtein":
    315     try:
    316         import Levenshtein  # type: ignore
    317     except ModuleNotFoundError as e:
    318         fail(MISSING_PREREQUISITES.format(e.name))
    319 
    320 if args.source:
    321     try:
    322         import cxxfilt  # type: ignore
    323     except ModuleNotFoundError as e:
    324         fail(MISSING_PREREQUISITES.format(e.name))
    325 
    326 if args.threeway and not args.watch:
    327     fail("Threeway diffing requires -w.")
    328 
    329 if objdump_executable is None:
    330     for objdump_cand in ["mips-linux-gnu-objdump", "mips64-elf-objdump"]:
    331         try:
    332             subprocess.check_call(
    333                 [objdump_cand, "--version"],
    334                 stdout=subprocess.DEVNULL,
    335                 stderr=subprocess.DEVNULL,
    336             )
    337             objdump_executable = objdump_cand
    338             break
    339         except subprocess.CalledProcessError:
    340             pass
    341         except FileNotFoundError:
    342             pass
    343 
    344 if not objdump_executable:
    345     fail(
    346         "Missing binutils; please ensure mips-linux-gnu-objdump or mips64-elf-objdump exist, or configure objdump_executable."
    347     )
    348 
    349 
    350 def maybe_eval_int(expr: str) -> Optional[int]:
    351     try:
    352         ret = ast.literal_eval(expr)
    353         if not isinstance(ret, int):
    354             raise Exception("not an integer")
    355         return ret
    356     except Exception:
    357         return None
    358 
    359 
    360 def eval_int(expr: str, emsg: str) -> int:
    361     ret = maybe_eval_int(expr)
    362     if ret is None:
    363         fail(emsg)
    364     return ret
    365 
    366 
    367 def eval_line_num(expr: str) -> int:
    368     return int(expr.strip().replace(":", ""), 16)
    369 
    370 
    371 def run_make(target: str) -> None:
    372     subprocess.check_call(["make"] + makeflags + [target])
    373 
    374 
    375 def run_make_capture_output(target: str) -> "subprocess.CompletedProcess[bytes]":
    376     return subprocess.run(
    377         ["make"] + makeflags + [target],
    378         stderr=subprocess.PIPE,
    379         stdout=subprocess.PIPE,
    380     )
    381 
    382 
    383 def restrict_to_function(dump: str, fn_name: str) -> str:
    384     out: List[str] = []
    385     search = f"<{fn_name}>:"
    386     found = False
    387     for line in dump.split("\n"):
    388         if found:
    389             if len(out) >= MAX_FUNCTION_SIZE_LINES:
    390                 break
    391             out.append(line)
    392         elif search in line:
    393             found = True
    394     return "\n".join(out)
    395 
    396 
    397 def maybe_get_objdump_source_flags() -> List[str]:
    398     if not args.source:
    399         return []
    400 
    401     flags = [
    402         "--source",
    403         "--source-comment=│ ",
    404         "-l",
    405     ]
    406 
    407     if args.inlines:
    408         flags.append("--inlines")
    409 
    410     return flags
    411 
    412 
    413 def run_objdump(cmd: ObjdumpCommand) -> str:
    414     flags, target, restrict = cmd
    415     assert objdump_executable, "checked previously"
    416     out = subprocess.check_output(
    417         [objdump_executable] + arch_flags + flags + [target], universal_newlines=True
    418     )
    419     if restrict is not None:
    420         return restrict_to_function(out, restrict)
    421     return out
    422 
    423 
    424 base_shift: int = eval_int(
    425     args.base_shift, "Failed to parse --base-shift (-S) argument as an integer."
    426 )
    427 
    428 
    429 def search_map_file(fn_name: str) -> Tuple[Optional[str], Optional[int]]:
    430     if not mapfile:
    431         fail(f"No map file configured; cannot find function {fn_name}.")
    432 
    433     try:
    434         with open(mapfile) as f:
    435             lines = f.read().split("\n")
    436     except Exception:
    437         fail(f"Failed to open map file {mapfile} for reading.")
    438 
    439     try:
    440         cur_objfile = None
    441         ram_to_rom = None
    442         cands = []
    443         last_line = ""
    444         for line in lines:
    445             if line.startswith(" .text"):
    446                 cur_objfile = line.split()[3]
    447             if "load address" in line:
    448                 tokens = last_line.split() + line.split()
    449                 ram = int(tokens[1], 0)
    450                 rom = int(tokens[5], 0)
    451                 ram_to_rom = rom - ram
    452             if line.endswith(" " + fn_name):
    453                 ram = int(line.split()[0], 0)
    454                 if cur_objfile is not None and ram_to_rom is not None:
    455                     cands.append((cur_objfile, ram + ram_to_rom))
    456             last_line = line
    457     except Exception as e:
    458         import traceback
    459 
    460         traceback.print_exc()
    461         fail(f"Internal error while parsing map file")
    462 
    463     if len(cands) > 1:
    464         fail(f"Found multiple occurrences of function {fn_name} in map file.")
    465     if len(cands) == 1:
    466         return cands[0]
    467     return None, None
    468 
    469 
    470 def dump_elf() -> Tuple[str, ObjdumpCommand, ObjdumpCommand]:
    471     if not baseimg or not myimg:
    472         fail("Missing myimg/baseimg in config.")
    473     if base_shift:
    474         fail("--base-shift not compatible with -e")
    475 
    476     start_addr = eval_int(args.start, "Start address must be an integer expression.")
    477 
    478     if args.end is not None:
    479         end_addr = eval_int(args.end, "End address must be an integer expression.")
    480     else:
    481         end_addr = start_addr + MAX_FUNCTION_SIZE_BYTES
    482 
    483     flags1 = [
    484         f"--start-address={start_addr}",
    485         f"--stop-address={end_addr}",
    486     ]
    487 
    488     flags2 = [
    489         f"--disassemble={args.diff_elf_symbol}",
    490     ]
    491 
    492     objdump_flags = ["-drz", "-j", ".text"]
    493     return (
    494         myimg,
    495         (objdump_flags + flags1, baseimg, None),
    496         (objdump_flags + flags2 + maybe_get_objdump_source_flags(), myimg, None),
    497     )
    498 
    499 
    500 def dump_objfile() -> Tuple[str, ObjdumpCommand, ObjdumpCommand]:
    501     if base_shift:
    502         fail("--base-shift not compatible with -o")
    503     if args.end is not None:
    504         fail("end address not supported together with -o")
    505     if args.start.startswith("0"):
    506         fail("numerical start address not supported with -o; pass a function name")
    507 
    508     objfile, _ = search_map_file(args.start)
    509     if not objfile:
    510         fail("Not able to find .o file for function.")
    511 
    512     if args.make:
    513         run_make(objfile)
    514 
    515     if not os.path.isfile(objfile):
    516         fail(f"Not able to find .o file for function: {objfile} is not a file.")
    517 
    518     refobjfile = "expected/" + objfile
    519     if not os.path.isfile(refobjfile):
    520         fail(f'Please ensure an OK .o file exists at "{refobjfile}".')
    521 
    522     objdump_flags = ["-drz"]
    523     return (
    524         objfile,
    525         (objdump_flags, refobjfile, args.start),
    526         (objdump_flags + maybe_get_objdump_source_flags(), objfile, args.start),
    527     )
    528 
    529 
    530 def dump_binary() -> Tuple[str, ObjdumpCommand, ObjdumpCommand]:
    531     if not baseimg or not myimg:
    532         fail("Missing myimg/baseimg in config.")
    533     if args.make:
    534         run_make(myimg)
    535     start_addr = maybe_eval_int(args.start)
    536     if start_addr is None:
    537         _, start_addr = search_map_file(args.start)
    538         if start_addr is None:
    539             fail("Not able to find function in map file.")
    540     if args.end is not None:
    541         end_addr = eval_int(args.end, "End address must be an integer expression.")
    542     else:
    543         end_addr = start_addr + MAX_FUNCTION_SIZE_BYTES
    544     objdump_flags = ["-Dz", "-bbinary", "-EB"]
    545     flags1 = [
    546         f"--start-address={start_addr + base_shift}",
    547         f"--stop-address={end_addr + base_shift}",
    548     ]
    549     flags2 = [f"--start-address={start_addr}", f"--stop-address={end_addr}"]
    550     return (
    551         myimg,
    552         (objdump_flags + flags1, baseimg, None),
    553         (objdump_flags + flags2, myimg, None),
    554     )
    555 
    556 
    557 def ansi_ljust(s: str, width: int) -> str:
    558     """Like s.ljust(width), but accounting for ANSI colors."""
    559     needed: int = width - ansiwrap.ansilen(s)
    560     if needed > 0:
    561         return s + " " * needed
    562     else:
    563         return s
    564 
    565 
# Per-architecture tables: the regexes used to normalize disassembly rows,
# the extra objdump flags, and the sets of branch-related mnemonics.
if arch == "mips":
    # Runs of decimal digits (candidates for conversion to hex).
    re_int = re.compile(r"[0-9]+")
    # Shortest "<...>" span; such spans are removed from rows as comments.
    re_comment = re.compile(r"<.*?>")
    # Register names, optionally prefixed by "$".
    re_reg = re.compile(
        r"\$?\b(a[0-3]|t[0-9]|s[0-8]|at|v[01]|f[12]?[0-9]|f3[01]|k[01]|fp|ra|zero)\b"
    )
    # An "imm(sp)" operand that directly follows a comma.
    re_sprel = re.compile(r"(?<=,)([0-9]+|0x[0-9a-f]+)\(sp\)")
    # Decimal numbers of 3+ digits or hex numbers of 3+ hex digits, optionally
    # negative.
    re_large_imm = re.compile(r"-?[1-9][0-9]{2,}|-?0x[0-9a-f]{3,}")
    # Immediates not followed by "(sp", or %lo(...)/%hi(...) expressions.
    re_imm = re.compile(r"(\b|-)([0-9]+|0x[0-9a-fA-F]+)\b(?!\(sp)|%(lo|hi)\([^)]*\)")
    # A number adjacent to one of these characters is left untouched by
    # hexify_int.
    forbidden = set(string.ascii_letters + "_")
    # Extra flags placed on every objdump command line.
    arch_flags = ["-m", "mips:4300"]
    # After one of these, process() replaces the following row with
    # "<delay-slot>".
    branch_likely_instructions = {
        "beql",
        "bnel",
        "beqzl",
        "bnezl",
        "bgezl",
        "bgtzl",
        "blezl",
        "bltzl",
        "bc1tl",
        "bc1fl",
    }
    branch_instructions = branch_likely_instructions.union(
        {
            "b",
            "beq",
            "bne",
            "beqz",
            "bnez",
            "bgez",
            "bgtz",
            "blez",
            "bltz",
            "bc1t",
            "bc1f",
        }
    )
    # Mnemonics whose rows are not hexified and whose trailing operand is
    # replaced by "<imm>" in the diff row.
    instructions_with_address_immediates = branch_instructions.union({"jal", "j"})
elif arch == "aarch64":
    # Runs of decimal digits (candidates for conversion to hex).
    re_int = re.compile(r"[0-9]+")
    # "<...>" spans and "//" comments running to the end of the line.
    re_comment = re.compile(r"(<.*?>|//.*$)")
    # GPRs and FP registers: X0-X30, W0-W30, [DSHQ]0..31
    # The zero registers and SP should not be in this list.
    re_reg = re.compile(r"\$?\b([dshq][12]?[0-9]|[dshq]3[01]|[xw][12]?[0-9]|[xw]30)\b")
    # "sp, #imm" operands.
    re_sprel = re.compile(r"sp, #-?(0x[0-9a-fA-F]+|[0-9]+)\b")
    # Decimal numbers of 3+ digits or hex numbers of 3+ hex digits, optionally
    # negative.
    re_large_imm = re.compile(r"-?[1-9][0-9]{2,}|-?0x[0-9a-f]{3,}")
    # "#imm" immediates that are not preceded by "sp, ".
    re_imm = re.compile(r"(?<!sp, )#-?(0x[0-9a-fA-F]+|[0-9]+)\b")
    # No extra objdump flags for this architecture.
    arch_flags = []
    # A number adjacent to one of these characters is left untouched by
    # hexify_int.
    forbidden = set(string.ascii_letters + "_")
    # Empty: no row is turned into "<delay-slot>" for this architecture.
    branch_likely_instructions = set()
    branch_instructions = {
        "bl",
        "b",
        "b.eq",
        "b.ne",
        "b.cs",
        "b.hs",
        "b.cc",
        "b.lo",
        "b.mi",
        "b.pl",
        "b.vs",
        "b.vc",
        "b.hi",
        "b.ls",
        "b.ge",
        "b.lt",
        "b.gt",
        "b.le",
        "cbz",
        "cbnz",
        "tbz",
        "tbnz",
    }
    # Mnemonics whose rows are not hexified and whose trailing operand is
    # replaced by "<imm>" in the diff row.
    instructions_with_address_immediates = branch_instructions.union({"adrp"})
else:
    fail("Unknown architecture.")
    644 
    645 
    646 def hexify_int(row: str, pat: Match[str]) -> str:
    647     full = pat.group(0)
    648     if len(full) <= 1:
    649         # leave one-digit ints alone
    650         return full
    651     start, end = pat.span()
    652     if start and row[start - 1] in forbidden:
    653         return full
    654     if end < len(row) and row[end] in forbidden:
    655         return full
    656     return hex(int(full))
    657 
    658 
    659 def parse_relocated_line(line: str) -> Tuple[str, str, str]:
    660     try:
    661         ind2 = line.rindex(",")
    662     except ValueError:
    663         ind2 = line.rindex("\t")
    664     before = line[: ind2 + 1]
    665     after = line[ind2 + 1 :]
    666     ind2 = after.find("(")
    667     if ind2 == -1:
    668         imm, after = after, ""
    669     else:
    670         imm, after = after[:ind2], after[ind2:]
    671     if imm == "0x0":
    672         imm = "0"
    673     return before, imm, after
    674 
    675 
    676 def process_mips_reloc(row: str, prev: str) -> str:
    677     before, imm, after = parse_relocated_line(prev)
    678     repl = row.split()[-1]
    679     if imm != "0":
    680         # MIPS uses relocations with addends embedded in the code as immediates.
    681         # If there is an immediate, show it as part of the relocation. Ideally
    682         # we'd show this addend in both %lo/%hi, but annoyingly objdump's output
    683         # doesn't include enough information to pair up %lo's and %hi's...
    684         # TODO: handle unambiguous cases where all addends for a symbol are the
    685         # same, or show "+???".
    686         mnemonic = prev.split()[0]
    687         if mnemonic in instructions_with_address_immediates and not imm.startswith(
    688             "0x"
    689         ):
    690             imm = "0x" + imm
    691         repl += "+" + imm if int(imm, 0) > 0 else imm
    692     if "R_MIPS_LO16" in row:
    693         repl = f"%lo({repl})"
    694     elif "R_MIPS_HI16" in row:
    695         # Ideally we'd pair up R_MIPS_LO16 and R_MIPS_HI16 to generate a
    696         # correct addend for each, but objdump doesn't give us the order of
    697         # the relocations, so we can't find the right LO16. :(
    698         repl = f"%hi({repl})"
    699     elif "R_MIPS_26" in row:
    700         # Function calls
    701         pass
    702     elif "R_MIPS_PC16" in row:
    703         # Branch to glabel. This gives confusing output, but there's not much
    704         # we can do here.
    705         pass
    706     else:
    707         assert False, f"unknown relocation type '{row}' for line '{prev}'"
    708     return before + repl + after
    709 
    710 
    711 def pad_mnemonic(line: str) -> str:
    712     if "\t" not in line:
    713         return line
    714     mn, args = line.split("\t", 1)
    715     return f"{mn:<7s} {args}"
    716 
    717 
class Line(NamedTuple):
    """One disassembled instruction, as built by process()."""

    # Instruction mnemonic, or "<delay-slot>" for the row following a
    # branch-likely instruction.
    mnemonic: str
    # Row with registers, stack offsets and immediates replaced by
    # placeholders.
    diff_row: str
    # Instruction text after hexification; process() rewrites it when a MIPS
    # relocation row follows.
    original: str
    # `original` passed through the DifferenceNormalizer.
    normalized_original: str
    # First tab-separated field of the objdump row, stripped.
    line_num: str
    # Last operand of a branch instruction, or None for non-branches.
    branch_target: Optional[str]
    # Source rows collected since the previous instruction (with --source).
    source_lines: List[str]
    # First match of re_comment in the raw row, if any.
    comment: Optional[str]
    727 
    728 
    729 class DifferenceNormalizer:
    730     def normalize(self, mnemonic: str, row: str) -> str:
    731         """This should be called exactly once for each line."""
    732         row = self._normalize_arch_specific(mnemonic, row)
    733         if args.ignore_large_imms:
    734             row = re.sub(re_large_imm, "<imm>", row)
    735         return row
    736 
    737     def _normalize_arch_specific(self, mnemonic: str, row: str) -> str:
    738         return row
    739 
    740 
    741 class DifferenceNormalizerAArch64(DifferenceNormalizer):
    742     def __init__(self) -> None:
    743         super().__init__()
    744         self._adrp_pair_registers: Set[str] = set()
    745 
    746     def _normalize_arch_specific(self, mnemonic: str, row: str) -> str:
    747         if args.ignore_addr_diffs:
    748             row = self._normalize_adrp_differences(mnemonic, row)
    749             row = self._normalize_bl(mnemonic, row)
    750         return row
    751 
    752     def _normalize_bl(self, mnemonic: str, row: str) -> str:
    753         if mnemonic != "bl":
    754             return row
    755 
    756         row, _ = split_off_branch(row)
    757         return row
    758 
    759     def _normalize_adrp_differences(self, mnemonic: str, row: str) -> str:
    760         """Identifies ADRP + LDR/ADD pairs that are used to access the GOT and
    761         suppresses any immediate differences.
    762 
    763         Whenever an ADRP is seen, the destination register is added to the set of registers
    764         that are part of an ADRP + LDR/ADD pair. Registers are removed from the set as soon
    765         as they are used for an LDR or ADD instruction which completes the pair.
    766 
    767         This method is somewhat crude but should manage to detect most such pairs.
    768         """
    769         row_parts = row.split("\t", 1)
    770         if mnemonic == "adrp":
    771             self._adrp_pair_registers.add(row_parts[1].strip().split(",")[0])
    772             row, _ = split_off_branch(row)
    773         elif mnemonic == "ldr":
    774             for reg in self._adrp_pair_registers:
    775                 # ldr xxx, [reg]
    776                 # ldr xxx, [reg, <imm>]
    777                 if f", [{reg}" in row_parts[1]:
    778                     self._adrp_pair_registers.remove(reg)
    779                     return normalize_imms(row)
    780         elif mnemonic == "add":
    781             for reg in self._adrp_pair_registers:
    782                 # add reg, reg, <imm>
    783                 if row_parts[1].startswith(f"{reg}, {reg}, "):
    784                     self._adrp_pair_registers.remove(reg)
    785                     return normalize_imms(row)
    786 
    787         return row
    788 
    789 
    790 def make_difference_normalizer() -> DifferenceNormalizer:
    791     if arch == "aarch64":
    792         return DifferenceNormalizerAArch64()
    793     return DifferenceNormalizer()
    794 
    795 
def process(lines: List[str]) -> List[Line]:
    """Turn raw objdump output lines into a list of Line records.

    Each instruction row is stripped of comments, split on tabs, hexified and
    normalized; relocation rows are folded into the preceding instruction and
    source rows (with --source) are attached to the next instruction.

    With --stop-jr-ra, processing stops one row after the first ``jr ra``.
    """
    normalizer = make_difference_normalizer()
    # True when the previous instruction was a branch-likely, so the current
    # row is to be shown as "<delay-slot>".
    skip_next = False
    source_lines: List[str] = []
    if not args.diff_obj:
        # Drop the first 7 lines of the dump (presumably objdump's header —
        # TODO confirm) and a trailing empty line.
        lines = lines[7:]
        if lines and not lines[-1]:
            lines.pop()

    output: List[Line] = []
    stop_after_delay_slot = False
    for row in lines:
        # In object-file mode, skip label lines ("...>:") and blank lines.
        if args.diff_obj and (">:" in row or not row):
            continue

        # With --source, rows that do not start with a space are collected as
        # source lines for the next instruction.
        if args.source and (row and row[0] != " "):
            source_lines.append(row)
            continue

        if "R_AARCH64_" in row:
            # TODO: handle relocation
            continue

        if "R_MIPS_" in row:
            # N.B. Don't transform the diff rows, they already ignore immediates
            # if output[-1].diff_row != "<delay-slot>":
            # output[-1] = output[-1].replace(diff_row=process_mips_reloc(row, output[-1].row_with_imm))
            # The relocation applies to the instruction emitted just before it.
            new_original = process_mips_reloc(row, output[-1].original)
            output[-1] = output[-1]._replace(original=new_original)
            continue

        # Remember the first comment, then remove all comments from the row.
        m_comment = re.search(re_comment, row)
        comment = m_comment[0] if m_comment else None
        row = re.sub(re_comment, "", row)
        row = row.rstrip()
        # Field 0 is the line label, field 1 is dropped, and the instruction
        # text starts at field 2.
        tabs = row.split("\t")
        row = "\t".join(tabs[2:])
        line_num = tabs[0].strip()
        row_parts = row.split("\t", 1)
        mnemonic = row_parts[0].strip()
        # Address operands are left as printed; other numbers become hex.
        if mnemonic not in instructions_with_address_immediates:
            row = re.sub(re_int, lambda m: hexify_int(row, m), row)
        original = row
        normalized_original = normalizer.normalize(mnemonic, original)
        if skip_next:
            # This row follows a branch-likely: only `original` keeps its text.
            skip_next = False
            row = "<delay-slot>"
            mnemonic = "<delay-slot>"
        if mnemonic in branch_likely_instructions:
            skip_next = True
        # Build the diff row: mask registers and stack-relative operands.
        row = re.sub(re_reg, "<reg>", row)
        row = re.sub(re_sprel, "addr(sp)", row)
        row_with_imm = row
        if mnemonic in instructions_with_address_immediates:
            # Replace the address operand with a placeholder.
            row = row.strip()
            row, _ = split_off_branch(row)
            row += "<imm>"
        else:
            row = normalize_imms(row)

        branch_target = None
        if mnemonic in branch_instructions:
            # The target is the last comma-separated operand.
            target = row_parts[1].strip().split(",")[-1]
            if mnemonic in branch_likely_instructions:
                # Branch-likely targets are recorded 4 bytes earlier, parsed
                # and re-emitted as hex without the "0x" prefix.
                target = hex(int(target, 16) - 4)[2:]
            branch_target = target.strip()

        output.append(
            Line(
                mnemonic=mnemonic,
                diff_row=row,
                original=original,
                normalized_original=normalized_original,
                line_num=line_num,
                branch_target=branch_target,
                source_lines=source_lines,
                comment=comment,
            )
        )
        # Start a fresh list; the old one now belongs to the Line just added.
        source_lines = []

        # After "jr ra", emit one more row (its delay slot) and then stop.
        if args.stop_jrra and mnemonic == "jr" and row_parts[1].strip() == "ra":
            stop_after_delay_slot = True
        elif stop_after_delay_slot:
            break

    return output
    883 
    884 
    885 def format_single_line_diff(line1: str, line2: str, column_width: int) -> str:
    886     return ansi_ljust(line1, column_width) + line2
    887 
    888 
    889 class SymbolColorer:
    890     symbol_colors: Dict[str, str]
    891 
    892     def __init__(self, base_index: int) -> None:
    893         self.color_index = base_index
    894         self.symbol_colors = {}
    895 
    896     def color_symbol(self, s: str, t: Optional[str] = None) -> str:
    897         try:
    898             color = self.symbol_colors[s]
    899         except:
    900             color = COLOR_ROTATION[self.color_index % len(COLOR_ROTATION)]
    901             self.color_index += 1
    902             self.symbol_colors[s] = color
    903         t = t or s
    904         return f"{color}{t}{Fore.RESET}"
    905 
    906 
    907 def normalize_imms(row: str) -> str:
    908     return re.sub(re_imm, "<imm>", row)
    909 
    910 
    911 def normalize_stack(row: str) -> str:
    912     return re.sub(re_sprel, "addr(sp)", row)
    913 
    914 
    915 def split_off_branch(line: str) -> Tuple[str, str]:
    916     parts = line.split(",")
    917     if len(parts) < 2:
    918         parts = line.split(None, 1)
    919     off = len(line) - len(parts[-1])
    920     return line[:off], line[off:]
    921 
    922 
    923 ColorFunction = Callable[[str], str]
    924 
    925 
    926 def color_fields(
    927     pat: Pattern[str],
    928     out1: str,
    929     out2: str,
    930     color1: ColorFunction,
    931     color2: Optional[ColorFunction] = None,
    932 ) -> Tuple[str, str]:
    933     diffs = [
    934         of.group() != nf.group()
    935         for (of, nf) in zip(pat.finditer(out1), pat.finditer(out2))
    936     ]
    937 
    938     it = iter(diffs)
    939 
    940     def maybe_color(color: ColorFunction, s: str) -> str:
    941         return color(s) if next(it, False) else f"{Style.RESET_ALL}{s}"
    942 
    943     out1 = pat.sub(lambda m: maybe_color(color1, m.group()), out1)
    944     it = iter(diffs)
    945     out2 = pat.sub(lambda m: maybe_color(color2 or color1, m.group()), out2)
    946 
    947     return out1, out2
    948 
    949 
    950 def color_branch_imms(br1: str, br2: str) -> Tuple[str, str]:
    951     if br1 != br2:
    952         br1 = f"{Fore.LIGHTBLUE_EX}{br1}{Style.RESET_ALL}"
    953         br2 = f"{Fore.LIGHTBLUE_EX}{br2}{Style.RESET_ALL}"
    954     return br1, br2
    955 
    956 
    957 def diff_sequences_difflib(
    958     seq1: List[str], seq2: List[str]
    959 ) -> List[Tuple[str, int, int, int, int]]:
    960     differ = difflib.SequenceMatcher(a=seq1, b=seq2, autojunk=False)
    961     return differ.get_opcodes()
    962 
    963 
    964 def diff_sequences(
    965     seq1: List[str], seq2: List[str]
    966 ) -> List[Tuple[str, int, int, int, int]]:
    967     if (
    968         args.algorithm != "levenshtein"
    969         or len(seq1) * len(seq2) > 4 * 10 ** 8
    970         or len(seq1) + len(seq2) >= 0x110000
    971     ):
    972         return diff_sequences_difflib(seq1, seq2)
    973 
    974     # The Levenshtein library assumes that we compare strings, not lists. Convert.
    975     # (Per the check above we know we have fewer than 0x110000 unique elements, so chr() works.)
    976     remapping: Dict[str, str] = {}
    977 
    978     def remap(seq: List[str]) -> str:
    979         seq = seq[:]
    980         for i in range(len(seq)):
    981             val = remapping.get(seq[i])
    982             if val is None:
    983                 val = chr(len(remapping))
    984                 remapping[seq[i]] = val
    985             seq[i] = val
    986         return "".join(seq)
    987 
    988     rem1 = remap(seq1)
    989     rem2 = remap(seq2)
    990     return Levenshtein.opcodes(rem1, rem2)  # type: ignore
    991 
    992 
    993 def diff_lines(
    994     lines1: List[Line],
    995     lines2: List[Line],
    996 ) -> List[Tuple[Optional[Line], Optional[Line]]]:
    997     ret = []
    998     for (tag, i1, i2, j1, j2) in diff_sequences(
    999         [line.mnemonic for line in lines1],
   1000         [line.mnemonic for line in lines2],
   1001     ):
   1002         for line1, line2 in itertools.zip_longest(lines1[i1:i2], lines2[j1:j2]):
   1003             if tag == "replace":
   1004                 if line1 is None:
   1005                     tag = "insert"
   1006                 elif line2 is None:
   1007                     tag = "delete"
   1008             elif tag == "insert":
   1009                 assert line1 is None
   1010             elif tag == "delete":
   1011                 assert line2 is None
   1012             ret.append((line1, line2))
   1013 
   1014     return ret
   1015 
   1016 
   1017 class OutputLine:
   1018     base: Optional[str]
   1019     fmt2: str
   1020     key2: Optional[str]
   1021 
   1022     def __init__(self, base: Optional[str], fmt2: str, key2: Optional[str]) -> None:
   1023         self.base = base
   1024         self.fmt2 = fmt2
   1025         self.key2 = key2
   1026 
   1027     def __eq__(self, other: object) -> bool:
   1028         if not isinstance(other, OutputLine):
   1029             return NotImplemented
   1030         return self.key2 == other.key2
   1031 
   1032     def __hash__(self) -> int:
   1033         return hash(self.key2)
   1034 
   1035 
   1036 def do_diff(basedump: str, mydump: str) -> List[OutputLine]:
   1037     output: List[OutputLine] = []
   1038 
   1039     lines1 = process(basedump.split("\n"))
   1040     lines2 = process(mydump.split("\n"))
   1041 
   1042     sc1 = SymbolColorer(0)
   1043     sc2 = SymbolColorer(0)
   1044     sc3 = SymbolColorer(4)
   1045     sc4 = SymbolColorer(4)
   1046     sc5 = SymbolColorer(0)
   1047     sc6 = SymbolColorer(0)
   1048     bts1: Set[str] = set()
   1049     bts2: Set[str] = set()
   1050 
   1051     if args.show_branches:
   1052         for (lines, btset, sc) in [
   1053             (lines1, bts1, sc5),
   1054             (lines2, bts2, sc6),
   1055         ]:
   1056             for line in lines:
   1057                 bt = line.branch_target
   1058                 if bt is not None:
   1059                     btset.add(bt + ":")
   1060                     sc.color_symbol(bt + ":")
   1061 
   1062     for (line1, line2) in diff_lines(lines1, lines2):
   1063         line_color1 = line_color2 = sym_color = Fore.RESET
   1064         line_prefix = " "
   1065         if line1 and line2 and line1.diff_row == line2.diff_row:
   1066             if line1.normalized_original == line2.normalized_original:
   1067                 out1 = line1.original
   1068                 out2 = line2.original
   1069             elif line1.diff_row == "<delay-slot>":
   1070                 out1 = f"{Style.BRIGHT}{Fore.LIGHTBLACK_EX}{line1.original}"
   1071                 out2 = f"{Style.BRIGHT}{Fore.LIGHTBLACK_EX}{line2.original}"
   1072             else:
   1073                 mnemonic = line1.original.split()[0]
   1074                 out1, out2 = line1.original, line2.original
   1075                 branch1 = branch2 = ""
   1076                 if mnemonic in instructions_with_address_immediates:
   1077                     out1, branch1 = split_off_branch(line1.original)
   1078                     out2, branch2 = split_off_branch(line2.original)
   1079                 branchless1 = out1
   1080                 branchless2 = out2
   1081                 out1, out2 = color_fields(
   1082                     re_imm,
   1083                     out1,
   1084                     out2,
   1085                     lambda s: f"{Fore.LIGHTBLUE_EX}{s}{Style.RESET_ALL}",
   1086                 )
   1087 
   1088                 same_relative_target = False
   1089                 if line1.branch_target is not None and line2.branch_target is not None:
   1090                     relative_target1 = eval_line_num(
   1091                         line1.branch_target
   1092                     ) - eval_line_num(line1.line_num)
   1093                     relative_target2 = eval_line_num(
   1094                         line2.branch_target
   1095                     ) - eval_line_num(line2.line_num)
   1096                     same_relative_target = relative_target1 == relative_target2
   1097 
   1098                 if not same_relative_target:
   1099                     branch1, branch2 = color_branch_imms(branch1, branch2)
   1100 
   1101                 out1 += branch1
   1102                 out2 += branch2
   1103                 if normalize_imms(branchless1) == normalize_imms(branchless2):
   1104                     if not same_relative_target:
   1105                         # only imms differences
   1106                         sym_color = Fore.LIGHTBLUE_EX
   1107                         line_prefix = "i"
   1108                 else:
   1109                     out1, out2 = color_fields(
   1110                         re_sprel, out1, out2, sc3.color_symbol, sc4.color_symbol
   1111                     )
   1112                     if normalize_stack(branchless1) == normalize_stack(branchless2):
   1113                         # only stack differences (luckily stack and imm
   1114                         # differences can't be combined in MIPS, so we
   1115                         # don't have to think about that case)
   1116                         sym_color = Fore.YELLOW
   1117                         line_prefix = "s"
   1118                     else:
   1119                         # regs differences and maybe imms as well
   1120                         out1, out2 = color_fields(
   1121                             re_reg, out1, out2, sc1.color_symbol, sc2.color_symbol
   1122                         )
   1123                         line_color1 = line_color2 = sym_color = Fore.YELLOW
   1124                         line_prefix = "r"
   1125         elif line1 and line2:
   1126             line_prefix = "|"
   1127             line_color1 = Fore.LIGHTBLUE_EX
   1128             line_color2 = Fore.LIGHTBLUE_EX
   1129             sym_color = Fore.LIGHTBLUE_EX
   1130             out1 = line1.original
   1131             out2 = line2.original
   1132         elif line1:
   1133             line_prefix = "<"
   1134             line_color1 = sym_color = Fore.RED
   1135             out1 = line1.original
   1136             out2 = ""
   1137         elif line2:
   1138             line_prefix = ">"
   1139             line_color2 = sym_color = Fore.GREEN
   1140             out1 = ""
   1141             out2 = line2.original
   1142 
   1143         if args.source and line2 and line2.comment:
   1144             out2 += f" {line2.comment}"
   1145 
   1146         def format_part(
   1147             out: str,
   1148             line: Optional[Line],
   1149             line_color: str,
   1150             btset: Set[str],
   1151             sc: SymbolColorer,
   1152         ) -> Optional[str]:
   1153             if line is None:
   1154                 return None
   1155             in_arrow = "  "
   1156             out_arrow = ""
   1157             if args.show_branches:
   1158                 if line.line_num in btset:
   1159                     in_arrow = sc.color_symbol(line.line_num, "~>") + line_color
   1160                 if line.branch_target is not None:
   1161                     out_arrow = " " + sc.color_symbol(line.branch_target + ":", "~>")
   1162             out = pad_mnemonic(out)
   1163             return f"{line_color}{line.line_num} {in_arrow} {out}{Style.RESET_ALL}{out_arrow}"
   1164 
   1165         part1 = format_part(out1, line1, line_color1, bts1, sc5)
   1166         part2 = format_part(out2, line2, line_color2, bts2, sc6)
   1167         key2 = line2.original if line2 else None
   1168 
   1169         mid = f"{sym_color}{line_prefix}"
   1170 
   1171         if line2:
   1172             for source_line in line2.source_lines:
   1173                 color = Style.DIM
   1174                 # File names and function names
   1175                 if source_line and source_line[0] != "│":
   1176                     color += Style.BRIGHT
   1177                     # Function names
   1178                     if source_line.endswith("():"):
   1179                         # Underline. Colorama does not provide this feature, unfortunately.
   1180                         color += "\u001b[4m"
   1181                         try:
   1182                             source_line = cxxfilt.demangle(
   1183                                 source_line[:-3], external_only=False
   1184                             )
   1185                         except:
   1186                             pass
   1187                 output.append(
   1188                     OutputLine(
   1189                         None,
   1190                         f"  {color}{source_line}{Style.RESET_ALL}",
   1191                         source_line,
   1192                     )
   1193                 )
   1194 
   1195         fmt2 = mid + " " + (part2 or "")
   1196         output.append(OutputLine(part1, fmt2, key2))
   1197 
   1198     return output
   1199 
   1200 
   1201 def chunk_diff(diff: List[OutputLine]) -> List[Union[List[OutputLine], OutputLine]]:
   1202     cur_right: List[OutputLine] = []
   1203     chunks: List[Union[List[OutputLine], OutputLine]] = []
   1204     for output_line in diff:
   1205         if output_line.base is not None:
   1206             chunks.append(cur_right)
   1207             chunks.append(output_line)
   1208             cur_right = []
   1209         else:
   1210             cur_right.append(output_line)
   1211     chunks.append(cur_right)
   1212     return chunks
   1213 
   1214 
   1215 def format_diff(
   1216     old_diff: List[OutputLine], new_diff: List[OutputLine]
   1217 ) -> Tuple[str, List[str]]:
   1218     old_chunks = chunk_diff(old_diff)
   1219     new_chunks = chunk_diff(new_diff)
   1220     output: List[Tuple[str, OutputLine, OutputLine]] = []
   1221     assert len(old_chunks) == len(new_chunks), "same target"
   1222     empty = OutputLine("", "", None)
   1223     for old_chunk, new_chunk in zip(old_chunks, new_chunks):
   1224         if isinstance(old_chunk, list):
   1225             assert isinstance(new_chunk, list)
   1226             if not old_chunk and not new_chunk:
   1227                 # Most of the time lines sync up without insertions/deletions,
   1228                 # and there's no interdiffing to be done.
   1229                 continue
   1230             differ = difflib.SequenceMatcher(a=old_chunk, b=new_chunk, autojunk=False)
   1231             for (tag, i1, i2, j1, j2) in differ.get_opcodes():
   1232                 if tag in ["equal", "replace"]:
   1233                     for i, j in zip(range(i1, i2), range(j1, j2)):
   1234                         output.append(("", old_chunk[i], new_chunk[j]))
   1235                 if tag in ["insert", "replace"]:
   1236                     for j in range(j1 + i2 - i1, j2):
   1237                         output.append(("", empty, new_chunk[j]))
   1238                 if tag in ["delete", "replace"]:
   1239                     for i in range(i1 + j2 - j1, i2):
   1240                         output.append(("", old_chunk[i], empty))
   1241         else:
   1242             assert isinstance(new_chunk, OutputLine)
   1243             assert new_chunk.base
   1244             # old_chunk.base and new_chunk.base have the same text since
   1245             # both diffs are based on the same target, but they might
   1246             # differ in color. Use the new version.
   1247             output.append((new_chunk.base, old_chunk, new_chunk))
   1248 
   1249     # TODO: status line, with e.g. approximate permuter score?
   1250     width = args.column_width
   1251     if args.threeway:
   1252         header_line = "TARGET".ljust(width) + "  CURRENT".ljust(width) + "  PREVIOUS"
   1253         diff_lines = [
   1254             ansi_ljust(base, width)
   1255             + ansi_ljust(new.fmt2, width)
   1256             + (old.fmt2 or "-" if old != new else "")
   1257             for (base, old, new) in output
   1258         ]
   1259     else:
   1260         header_line = ""
   1261         diff_lines = [
   1262             ansi_ljust(base, width) + new.fmt2
   1263             for (base, old, new) in output
   1264             if base or new.key2 is not None
   1265         ]
   1266     return header_line, diff_lines
   1267 
   1268 
def debounced_fs_watch(
    targets: List[str],
    outq: "queue.Queue[Optional[float]]",
    debounce_delay: float,
) -> None:
    """Watch ``targets`` for changes and report them, debounced, on ``outq``.

    A daemon thread is started that observes the given paths with watchdog.
    Bursts of relevant file events are collapsed: once no new event has
    arrived for ``debounce_delay`` seconds, the timestamp of the most recent
    event is put on ``outq``.

    Args:
        targets: Directories (watched recursively) and/or individual files
            (their parent directory is watched, non-recursively).
        outq: Queue that receives one timestamp per debounced burst.
        debounce_delay: Quiet period, in seconds, required before reporting.
    """
    import watchdog.events  # type: ignore
    import watchdog.observers  # type: ignore

    class WatchEventHandler(watchdog.events.FileSystemEventHandler):  # type: ignore
        """Forwards timestamps of relevant modify/move events to a queue."""

        def __init__(
            self, queue: "queue.Queue[float]", file_targets: List[str]
        ) -> None:
            self.queue = queue
            # Shared with debounce_thread(), which appends to this list
            # after the handler has been constructed.
            self.file_targets = file_targets

        def on_modified(self, ev: object) -> None:
            if isinstance(ev, watchdog.events.FileModifiedEvent):
                self.changed(ev.src_path)

        def on_moved(self, ev: object) -> None:
            # For moves, the destination path is the one that is checked.
            if isinstance(ev, watchdog.events.FileMovedEvent):
                self.changed(ev.dest_path)

        def should_notify(self, path: str) -> bool:
            """Return True if a change to ``path`` should trigger a refresh."""
            # Explicitly watched files always count (exact path match).
            for target in self.file_targets:
                if path == target:
                    return True
            # With auto-make enabled, any file with a watched extension counts.
            if args.make and any(
                path.endswith(suffix) for suffix in FS_WATCH_EXTENSIONS
            ):
                return True
            return False

        def changed(self, path: str) -> None:
            if self.should_notify(path):
                self.queue.put(time.time())

    def debounce_thread() -> NoReturn:
        listenq: "queue.Queue[float]" = queue.Queue()
        file_targets: List[str] = []
        event_handler = WatchEventHandler(listenq, file_targets)
        observer = watchdog.observers.Observer()
        # Parent directories already scheduled on behalf of file targets.
        observed = set()
        for target in targets:
            if os.path.isdir(target):
                observer.schedule(event_handler, target, recursive=True)
            else:
                file_targets.append(target)
                target = os.path.dirname(target) or "."
                if target not in observed:
                    observed.add(target)
                    observer.schedule(event_handler, target)
        observer.start()
        while True:
            # Block until the first event of a burst arrives.
            t = listenq.get()
            more = True
            while more:
                # Wait until debounce_delay has passed since the latest event.
                delay = t + debounce_delay - time.time()
                if delay > 0:
                    time.sleep(delay)
                # consume entire queue
                more = False
                try:
                    while True:
                        t = listenq.get(block=False)
                        # Something arrived meanwhile: restart the wait
                        # using the newest timestamp.
                        more = True
                except queue.Empty:
                    pass
            outq.put(t)

    # Daemon thread, so it does not keep the process alive on exit.
    th = threading.Thread(target=debounce_thread, daemon=True)
    th.start()
   1341 
   1342 
   1343 class Display:
   1344     basedump: str
   1345     mydump: str
   1346     emsg: Optional[str]
   1347     last_diff_output: Optional[List[OutputLine]]
   1348     pending_update: Optional[Tuple[str, bool]]
   1349     ready_queue: "queue.Queue[None]"
   1350     watch_queue: "queue.Queue[Optional[float]]"
   1351     less_proc: "Optional[subprocess.Popen[bytes]]"
   1352 
   1353     def __init__(self, basedump: str, mydump: str) -> None:
   1354         self.basedump = basedump
   1355         self.mydump = mydump
   1356         self.emsg = None
   1357         self.last_diff_output = None
   1358 
   1359     def run_less(self) -> "Tuple[subprocess.Popen[bytes], subprocess.Popen[bytes]]":
   1360         if self.emsg is not None:
   1361             output = self.emsg
   1362         else:
   1363             diff_output = do_diff(self.basedump, self.mydump)
   1364             last_diff_output = self.last_diff_output or diff_output
   1365             if args.threeway != "base" or not self.last_diff_output:
   1366                 self.last_diff_output = diff_output
   1367             header, diff_lines = format_diff(last_diff_output, diff_output)
   1368             header_lines = [header] if header else []
   1369             output = "\n".join(header_lines + diff_lines[args.skip_lines :])
   1370 
   1371         # Pipe the output through 'tail' and only then to less, to ensure the
   1372         # write call doesn't block. ('tail' has to buffer all its input before
   1373         # it starts writing.) This also means we don't have to deal with pipe
   1374         # closure errors.
   1375         buffer_proc = subprocess.Popen(
   1376             BUFFER_CMD, stdin=subprocess.PIPE, stdout=subprocess.PIPE
   1377         )
   1378         less_proc = subprocess.Popen(LESS_CMD, stdin=buffer_proc.stdout)
   1379         assert buffer_proc.stdin
   1380         assert buffer_proc.stdout
   1381         buffer_proc.stdin.write(output.encode())
   1382         buffer_proc.stdin.close()
   1383         buffer_proc.stdout.close()
   1384         return (buffer_proc, less_proc)
   1385 
   1386     def run_sync(self) -> None:
   1387         proca, procb = self.run_less()
   1388         procb.wait()
   1389         proca.wait()
   1390 
   1391     def run_async(self, watch_queue: "queue.Queue[Optional[float]]") -> None:
   1392         self.watch_queue = watch_queue
   1393         self.ready_queue = queue.Queue()
   1394         self.pending_update = None
   1395         dthread = threading.Thread(target=self.display_thread)
   1396         dthread.start()
   1397         self.ready_queue.get()
   1398 
   1399     def display_thread(self) -> None:
   1400         proca, procb = self.run_less()
   1401         self.less_proc = procb
   1402         self.ready_queue.put(None)
   1403         while True:
   1404             ret = procb.wait()
   1405             proca.wait()
   1406             self.less_proc = None
   1407             if ret != 0:
   1408                 # fix the terminal
   1409                 os.system("tput reset")
   1410             if ret != 0 and self.pending_update is not None:
   1411                 # killed by program with the intent to refresh
   1412                 msg, error = self.pending_update
   1413                 self.pending_update = None
   1414                 if not error:
   1415                     self.mydump = msg
   1416                     self.emsg = None
   1417                 else:
   1418                     self.emsg = msg
   1419                 proca, procb = self.run_less()
   1420                 self.less_proc = procb
   1421                 self.ready_queue.put(None)
   1422             else:
   1423                 # terminated by user, or killed
   1424                 self.watch_queue.put(None)
   1425                 self.ready_queue.put(None)
   1426                 break
   1427 
   1428     def progress(self, msg: str) -> None:
   1429         # Write message to top-left corner
   1430         sys.stdout.write("\x1b7\x1b[1;1f{}\x1b8".format(msg + " "))
   1431         sys.stdout.flush()
   1432 
   1433     def update(self, text: str, error: bool) -> None:
   1434         if not error and not self.emsg and text == self.mydump:
   1435             self.progress("Unchanged. ")
   1436             return
   1437         self.pending_update = (text, error)
   1438         if not self.less_proc:
   1439             return
   1440         self.less_proc.kill()
   1441         self.ready_queue.get()
   1442 
   1443     def terminate(self) -> None:
   1444         if not self.less_proc:
   1445             return
   1446         self.less_proc.kill()
   1447         self.ready_queue.get()
   1448 
   1449 
def main() -> None:
    """Produce both dumps, then show the diff once or keep it updated.

    Depending on the command-line options this either writes the current
    assembly to a file and exits, shows a single diff, or enters watch mode,
    where file changes trigger an optional rebuild and a refreshed diff.
    """
    # Pick the dump strategy; each returns the make target plus the
    # commands used to dump the base and the current build.
    if args.diff_elf_symbol:
        make_target, basecmd, mycmd = dump_elf()
    elif args.diff_obj:
        make_target, basecmd, mycmd = dump_objfile()
    else:
        make_target, basecmd, mycmd = dump_binary()

    if args.write_asm is not None:
        mydump = run_objdump(mycmd)
        with open(args.write_asm, "w") as f:
            f.write(mydump)
        print(f"Wrote assembly to {args.write_asm}.")
        sys.exit(0)

    # The base side comes from a saved file when given, otherwise from objdump.
    if args.base_asm is not None:
        with open(args.base_asm) as f:
            basedump = f.read()
    else:
        basedump = run_objdump(basecmd)

    mydump = run_objdump(mycmd)

    display = Display(basedump, mydump)

    if not args.watch:
        display.run_sync()
    else:
        if not args.make:
            yn = input(
                "Warning: watch-mode (-w) enabled without auto-make (-m). "
                "You will have to run make manually. Ok? (Y/n) "
            )
            if yn.lower() == "n":
                return
        if args.make:
            # Watch the sources: ask diff_settings for target-specific ones
            # if it offers that hook, else use the configured directories.
            watch_sources = None
            watch_sources_for_target_fn = getattr(
                diff_settings, "watch_sources_for_target", None
            )
            if watch_sources_for_target_fn:
                watch_sources = watch_sources_for_target_fn(make_target)
            watch_sources = watch_sources or source_directories
            if not watch_sources:
                fail("Missing source_directories config, don't know what to watch.")
        else:
            # Without auto-make, watch the build output itself.
            watch_sources = [make_target]
        q: "queue.Queue[Optional[float]]" = queue.Queue()
        debounced_fs_watch(watch_sources, q, DEBOUNCE_DELAY)
        display.run_async(q)
        last_build = 0.0
        try:
            while True:
                t = q.get()
                # None means the display has shut down.
                if t is None:
                    break
                # Ignore change events that predate the start of the last build.
                if t < last_build:
                    continue
                last_build = time.time()
                if args.make:
                    display.progress("Building...")
                    ret = run_make_capture_output(make_target)
                    if ret.returncode != 0:
                        # Show the build failure (stderr, else stdout) in
                        # place of the diff.
                        display.update(
                            ret.stderr.decode("utf-8-sig", "replace")
                            or ret.stdout.decode("utf-8-sig", "replace"),
                            error=True,
                        )
                        continue
                mydump = run_objdump(mycmd)
                display.update(mydump, error=False)
        except KeyboardInterrupt:
            display.terminate()
   1523 
   1524 
   1525 main()