#!/usr/bin/env python3

import sys, re, subprocess
from collections import namedtuple

# Per-function record produced by parse():
#   length              - size in bytes (encoded bytes summed over the function's lines)
#   length_instructions - number of disassembly lines that carry an opcode
#   offset              - start address, kept as the hex string printed in the header line
#   refs                - set of symbol names referenced from the function's operands
Function = namedtuple('Function', 'length length_instructions offset refs')

# Function header line of the disassembly: "<hex address> <name>:".
# Group 1 is the address, group 2 the symbol name.
start_func = re.compile(r'^([0-9a-f]+) <(.+)>:$')

# One disassembly line: "<addr>: <hex byte pairs> [opcode [operands]]".
# Group 1 is the run of space-terminated byte pairs, group 2 the opcode,
# group 3 the remainder of the line (operands and any trailing annotation).
# NOTE(review): the `\w+` in front of group 3 looks like it was meant to be
# `\s+`; as written the split between groups 2 and 3 relies on regex
# backtracking (group 2 may lose its last character). Confirm before changing.
instruction = re.compile(r'^\s*[0-9a-f]+:\s+((?:[0-9a-f]{2} )+)\s*([0-9a-z()]+)?(?:\w+(.+))?\w*$')
# Applied to an operand string: captures the symbol (group 1) and the optional
# "+offset" part (group 2) of a "<symbol+offset>" annotation at the end of the text.
instruction_ref = re.compile(r'^[^<]*(?:<([^>+]+)(?:[+]([^>]+))?>)?$')

# "Entry point address: 0x..." line; group 1 is the address without the 0x prefix.
entry_point = re.compile(r'^\s*Entry point address:\s+0x([0-9a-f]+)$')

# Hex-dump line: a 0x-prefixed address followed by four hex words (groups 1-4)
# and at least one more character. Lines with fewer than four words do not match.
data_line = re.compile(r'^\s*0x[0-9a-f]+ ([0-9a-f]+) ([0-9a-f]+) ([0-9a-f]+) ([0-9a-f]+) .+$')

# Symbol-table line of type "T": "<hex address> T <name>"; group 1 is the name.
global_text = re.compile(r'^[0-9a-f]+ T (.+)$')

def parse(disasm):
    """Parse disassembly text into a mapping of function name -> Function.

    Args:
        disasm: iterable of text lines (each optionally ending in a newline),
            as produced by ``objdump -d``.

    Returns:
        dict mapping each function name to a ``Function`` record holding its
        byte length, instruction count, start offset (hex string taken from
        the header line) and the set of symbol names its operands reference.
        Entries whose name starts with '.' are dropped.

    Raises:
        AssertionError: a line inside a function matches neither the function
            header pattern nor the instruction pattern.
    """
    funcs = {}

    # State of the function currently being accumulated.
    name = None
    offset = 0
    n_bytes = 0
    n_insns = 0
    refs = None

    for raw in disasm:
        line = raw.rstrip('\n')

        if not line or line.startswith('Disassembly of section '):
            continue

        header = start_func.match(line)
        if header:
            # A new header closes the previous function, if any.
            if name:
                funcs[name] = Function(length=n_bytes,
                                       length_instructions=n_insns,
                                       offset=offset,
                                       refs=refs)
            offset, name = header.groups()
            n_bytes = 0
            n_insns = 0
            refs = set()
            continue

        # Anything before the first function header is ignored.
        if not name:
            continue

        insn = instruction.match(line)
        if not insn:
            print("Couldn't parse '%s' " % line)
            # Raised explicitly instead of `assert False`: an assert statement
            # is removed under `python -O`, which would let the bad line be
            # skipped silently and the function's size be miscounted.
            raise AssertionError("unparseable disassembly line: %r" % line)

        insn_bytes, opcode, operands = insn.groups()

        # Lines without an opcode (padding / byte-only lines) contribute bytes
        # but are not counted as instructions.
        if opcode:
            n_insns += 1

        if operands:
            ref_match = instruction_ref.match(operands)
            # Group 1 cannot contain '+' (the pattern excludes it), so it is
            # already the bare symbol name.
            if ref_match and ref_match.group(1):
                refs.add(ref_match.group(1))

        # Every byte pair in group 1 is followed by exactly one space.
        n_bytes += insn_bytes.count(' ')

    if name:
        funcs[name] = Function(length=n_bytes,
                               length_instructions=n_insns,
                               offset=offset,
                               refs=refs)

    return {key: func for key, func in funcs.items() if not key.startswith('.')}

def get_roots(binary, funcs):
    """Collect the names of functions that must be treated as reachable.

    Roots come from three places:
      * the ELF entry point reported by ``readelf -h``,
      * every 8-byte little-endian value in the hex dumps of '.data.rel.ro',
        '.rodata' and '.data' that equals the start offset of a known
        function (a possible function pointer),
      * every symbol listed with type 'T' by ``nm -g``.

    Args:
        binary: path of the ELF file to inspect.
        funcs: mapping of function name -> Function, as returned by parse().

    Returns:
        set of root names. Names taken from ``nm`` are added as-is and are
        not checked against ``funcs``.
    """
    roots = set()
    candidates = []

    # Start offset (hex string) -> function name.
    inv_funcs = {func.offset: name for name, func in funcs.items()}

    # Get the entry point, and scan various data sections for possible function pointers.
    sections = ['.data.rel.ro', '.rodata', '.data']
    section_args = []
    for section in sections:
        section_args += ['-x', section]
    readelf = subprocess.run(['readelf', '-W', '-h', *section_args, binary],
                             stdout=subprocess.PIPE, universal_newlines=True)
    for line in readelf.stdout.split('\n'):
        m = entry_point.match(line)
        if m:
            # The entry point need not be the start of a disassembled function
            # (it can be 0x0, for instance); a plain dict lookup would raise
            # KeyError here, so an unknown entry point is simply not a root.
            # Offsets are compared as 16-digit, zero-padded hex strings.
            entry = inv_funcs.get(m.group(1).rjust(16, '0'))
            if entry:
                roots.add(entry)

        m = data_line.match(line)
        if m:
            # XXX This misses the trailing line if it has less than 16 bytes.
            # Each dump line holds four 4-byte words in memory order; pair
            # them up and byte-swap to read two little-endian 64-bit values.
            for i in 1, 3:
                hi = m.group(i + 1)
                lo = m.group(i)
                candidates.append(hi[6:8] + hi[4:6] + hi[2:4] + hi[0:2] +
                                  lo[6:8] + lo[4:6] + lo[2:4] + lo[0:2])

    for candidate in candidates:
        name = inv_funcs.get(candidate)
        if name:
            roots.add(name)

    # Get exported .text segment symbols.
    nm = subprocess.run(['nm', '-g', binary], stdout=subprocess.PIPE, universal_newlines=True)
    for line in nm.stdout.split('\n'):
        m = global_text.match(line)
        if m:
            roots.add(m.group(1))

    return roots

def mark(root, funcs, markset):
    """Add ``root`` and everything reachable from it to ``markset``.

    Reachability follows the ``refs`` of each function. References to names
    that are not keys of ``funcs`` are ignored, and names already present in
    ``markset`` are not traversed again.

    The walk uses an explicit worklist rather than recursion, so a long
    chain of calls cannot exceed the interpreter's recursion limit.

    Args:
        root: name of the function to start from.
        funcs: mapping of function name -> record with a ``refs`` iterable.
        markset: set of names, updated in place.

    Raises:
        KeyError: ``root`` is not a key of ``funcs`` (it is still added to
            ``markset`` first).
    """
    markset.add(root)
    pending = [root]
    while pending:
        current = pending.pop()
        for ref in funcs[current].refs:
            # Covers self-references too: `current` is already marked.
            if ref in markset:
                continue
            if ref in funcs:
                markset.add(ref)
                pending.append(ref)

def clobber(binary, clobbers, funcs):
    """Overwrite the named functions in ``binary`` in place.

    Each function's bytes are replaced by 0xcc for every byte but the last
    and 0xc3 for the last byte, so the overwritten range has exactly the
    function's recorded length.

    Args:
        binary: path of the file to patch; opened in 'r+b' mode.
        clobbers: iterable of function names to overwrite.
        funcs: mapping of name -> record with ``offset`` (hex string) and
            ``length`` (byte count).

    NOTE(review): ``offset`` is used directly as a file position. It is taken
    from the disassembly header lines, so this assumes addresses printed
    there equal file offsets for the code being patched -- confirm.
    """
    with open(binary, 'r+b') as binfile:
        for name in clobbers:
            func = funcs[name]
            # A function with no bytes has nothing to overwrite. Without this
            # guard the lone 0xc3 would land on a byte that belongs to
            # whatever follows the (empty) function.
            if func.length <= 0:
                continue
            binfile.seek(int(func.offset, 16))
            binfile.write(b'\xcc' * (func.length - 1) + b'\xc3')

def _percent(part, whole):
    """Return ``part`` as a percentage of ``whole``; 0.0 when ``whole`` is zero."""
    return part / whole * 100.0 if whole else 0.0


def main():
    """Report the dead functions of a shared object and optionally clobber them.

    Command line: ``<script> LIBRARY.so [--clobber]``.

    The library is disassembled with ``objdump -d``, roots are collected with
    get_roots(), everything reachable from a root is marked, and the
    remaining functions are reported as dead. With ``--clobber`` as the
    second argument the dead functions are overwritten in the file.
    """
    # Explicit check instead of `assert`, which is removed under `python -O`
    # and gave a bare IndexError when no argument was supplied.
    if len(sys.argv) < 2 or not sys.argv[1].endswith('.so'):
        sys.exit('usage: %s LIBRARY.so [--clobber]' % sys.argv[0])
    binary = sys.argv[1]

    # The context manager closes the pipe and waits for objdump, so neither
    # the file descriptor nor the child process is leaked.
    with subprocess.Popen(['objdump', '-d', binary],
                          stdout=subprocess.PIPE,
                          universal_newlines=True) as disasm:
        funcs = parse(disasm.stdout)

    roots = get_roots(binary, funcs)
    print("Found %s roots" % len(roots))

    marks = set()
    for root in roots:
        mark(root, funcs, marks)

    to_clobber = set(funcs) - marks

    byte_cnt = 0
    # Sorted so the report is the same from run to run.
    for dead in sorted(to_clobber):
        func = funcs[dead]
        byte_cnt += func.length
        print("Dead: %s - %s isns %s:%s" % (dead, func.length_instructions,
                                          func.offset, func.length))

    print('Dead/Total functions: %s/%s (%.2f%%)' % (len(to_clobber), len(funcs),
                                                    _percent(len(to_clobber), len(funcs))))

    byte_total = sum(func.length for func in funcs.values())
    print('Dead bytes/Total bytes: %s/%s (%.2f%%)' % (byte_cnt, byte_total,
                                                      _percent(byte_cnt, byte_total)))

    if len(sys.argv) >= 3 and sys.argv[2] == '--clobber':
        clobber(binary, to_clobber, funcs)
        print('Clobbered %s functions' % len(to_clobber))

# Run the analysis only when this file is executed as a script, not on import.
if __name__ == '__main__':
    main()
