diff --git a/codetransformer/assembler.py b/codetransformer/assembler.py new file mode 100644 index 0000000..75fa7b2 --- /dev/null +++ b/codetransformer/assembler.py @@ -0,0 +1,99 @@ +import sys +import types +from toolz import mapcat + +from . import instructions as instrs +from .code import Code + + +Label = instrs.Label + + +def assemble_function(signature, objs, code_kwargs=None, function_kwargs=None): + """TODO + """ + if code_kwargs is None: + code_kwargs = {} + if function_kwargs is None: + function_kwargs = {} + + code_kwargs.setdefault('argnames', list(gen_argnames_for_code(signature))) + + # Default to using the globals of the calling stack frame. + function_kwargs.setdefault('globals', sys._getframe(1).f_globals) + + function_kwargs.setdefault('argdefs', tuple(extract_defaults(signature))) + + code = assemble_code(objs, **code_kwargs).to_pycode() + + return types.FunctionType(code, **function_kwargs) + + +def assemble_code(objs, **code_kwargs): + """TODO + """ + instrs = resolve_labels(assemble_instructions(objs)) + return Code(instrs, **code_kwargs) + + +def assemble_instructions(objs): + """Assemble a sequence of Instructions or iterables of instructions. + """ + return list(mapcat(_validate_instructions, objs)) + + +def resolve_labels(objs): + """TODO + """ + out = [] + last_instr = None + for i in reversed(objs): + if isinstance(i, Label): + if last_instr is None: + # TODO: Better error here. + raise ValueError("Can't end with a Label!") + # Make any jumps to `i` resolve to `last_instr`. + last_instr.steal(i) + elif isinstance(i, instrs.Instruction): + last_instr = i + out.append(i) + else: + raise TypeError("Unknown type: {}", i) + + for i in out: + if isinstance(i.arg, Label): + raise ValueError("Unresolved label for {}".format(i)) + + return reversed(out) + + +def _validate_instructions(obj): + """TODO + """ + Instruction = instrs.Instruction + if isinstance(obj, (Label, Instruction)): + yield obj + else: + for instr in obj: + if not isinstance(instr, (Instruction, Label)): + raise TypeError( + "Expected an Instruction or Label. Got %s" % obj, + ) + yield instr + + +def gen_argnames_for_code(sig): + """Get argnames from an inspect.signature to pass to a Code object. """ + for name, param in sig.parameters.items(): + if param.kind == param.VAR_POSITIONAL: + yield '*' + name + elif param.kind == param.VAR_KEYWORD: + yield '**' + name + else: + yield name + + +def extract_defaults(sig): + """Get default parameters from an inspect.signature. + """ + return (p.default for p in sig.parameters.values() if p.default != p.empty) diff --git a/codetransformer/code.py b/codetransformer/code.py index c576206..541abbe 100644 --- a/codetransformer/code.py +++ b/codetransformer/code.py @@ -344,21 +344,18 @@ def __init__(self, if kwarg is not None: raise ValueError('cannot specify **kwargs more than once') kwarg = argname[2:] + append_argname(argname) continue elif argname.startswith('*'): if varg is not None: raise ValueError('cannot specify *args more than once') varg = argname[1:] argcounter = kwonlyargcount # all following args are kwonly. + append_argname(argname) continue argcounter[0] += 1 append_argname(argname) - if varg is not None: - append_argname(varg) - if kwarg is not None: - append_argname(kwarg) - cellvar_names = set(cellvars) freevar_names = set(freevars) for instr in filter(op.attrgetter('uses_free'), instrs): diff --git a/codetransformer/instructions.py b/codetransformer/instructions.py index 624d70b..a8f1309 100644 --- a/codetransformer/instructions.py +++ b/codetransformer/instructions.py @@ -64,30 +64,29 @@ def _vartype(self): class InstructionMeta(ABCMeta, matchable): - _marker = object() # sentinel _type_cache = {} - def __init__(self, *args, opcode=None): + def __init__(self, *args, opcode=None, synthetic=False): return super().__init__(*args) - def __new__(mcls, name, bases, dict_, *, opcode=None): + def __new__(mcls, name, bases, dict_, *, opcode=None, synthetic=False): try: return mcls._type_cache[opcode] except KeyError: pass - if len(bases) != 1: + if len(bases) > 1: raise TypeError( '{} does not support multiple inheritance'.format( mcls.__name__, ), ) - if bases[0] is mcls._marker: + if synthetic: dict_['_reprname'] = immutableattr(name) for attr in ('absjmp', 'have_arg', 'opcode', 'opname', 'reljmp'): dict_[attr] = _notimplemented(attr) - return super().__new__(mcls, name, (object,), dict_) + return super().__new__(mcls, name, bases, dict_) if opcode not in opmap.values(): raise TypeError('Invalid opcode: {}'.format(opcode)) @@ -123,7 +122,45 @@ def __repr__(self): __str__ = __repr__ -class Instruction(InstructionMeta._marker, metaclass=InstructionMeta): +class JumpTarget: + """Base class for objects that can be targets of jump instructions. + + This is the base for both Instruction and Label. + """ + + def __init__(self): + self._target_of = set() + self._stolen_by = None # used for lnotab recalculation + + def steal(self, instr): + """Steal the jump index off of `instr`. + + This makes anything that would have jumped to `instr` jump to + this Instruction instead. + + Parameters + ---------- + instr : JumpTarget + The target to steal the jump sources from. + + Returns + ------- + self : JumpTarget + The object that owns this method. + + Notes + ----- + This mutates self and ``instr`` inplace. + """ + instr._stolen_by = self + for jmp in instr._target_of: + jmp.arg = self + self._target_of = instr._target_of + instr._target_of = set() + return self + + +class Instruction(JumpTarget, metaclass=InstructionMeta, synthetic=True): """ Base class for all instruction types. @@ -139,13 +176,12 @@ class Instruction(InstructionMeta._marker, metaclass=InstructionMeta): _no_arg = no_default def __init__(self, arg=_no_arg): + super().__init__() if self.have_arg and arg is self._no_arg: raise TypeError( "{} missing 1 required argument: 'arg'".format(self.opname), ) self.arg = self._normalize_arg(arg) - self._target_of = set() - self._stolen_by = None # used for lnotab recalculation def __repr__(self): arg = self.arg @@ -158,33 +194,6 @@ def __repr__(self): def _normalize_arg(arg): return arg - def steal(self, instr): - """Steal the jump index off of `instr`. - - This makes anything that would have jumped to `instr` jump to - this Instruction instead. - - Parameters - ---------- - instr : Instruction - The instruction to steal the jump sources from. - - Returns - ------- - self : Instruction - The instruction that owns this method. - - Notes - ----- - This mutates self and ``instr`` inplace. - """ - instr._stolen_by = self - for jmp in instr._target_of: - jmp.arg = self - self._target_of = instr._target_of - instr._target_of = set() - return self - @classmethod def from_opcode(cls, opcode, arg=_no_arg): """ @@ -302,13 +311,13 @@ def _call_repr(self): def _check_jmp_arg(self, arg): - if not isinstance(arg, (Instruction, _RawArg)): + if not isinstance(arg, (JumpTarget, _RawArg)): raise TypeError( - 'argument to %s must be an instruction, got: %r' % ( + 'argument to %s must be a valid jump target, got: %r' % ( type(self).__name__, arg, ), ) - if isinstance(arg, Instruction): + if isinstance(arg, JumpTarget): arg._target_of.add(self) return arg @@ -424,6 +433,18 @@ def __get__(self, instance, owner): del class_ +class Label(JumpTarget): + """A "pseudo-instruction" that can be the target of a jump instruction. + """ + + def __init__(self, debug_name='anonymous'): + super().__init__() + self.debug_name = debug_name + + def __repr__(self): + return "Label({!r})".format(self.debug_name) + + # Clean up the namespace del name del globals_ diff --git a/codetransformer/macros.py b/codetransformer/macros.py new file mode 100644 index 0000000..d09bef5 --- /dev/null +++ b/codetransformer/macros.py @@ -0,0 +1,141 @@ +"""Macros for building high-level operations in bytecode. +""" +from . import instructions as instrs +from .assembler import assemble_instructions + +Label = instrs.Label + + +class Macro: + """TODO + """ + def assemble(self): + raise NotImplementedError('assemble') + + def __iter__(self): + return self.assemble() + + +class ForLoop(Macro): + """Macro for assembling for-loops. + """ + + def __init__(self, init, unpack, body, else_body=()): + self.init = init + self.unpack = unpack + self.body = body + self.else_body = else_body + + def assemble(self): + top_of_loop = Label('top') + cleanup = Label('cleanup') + end = Label('end') + + # Loop setup. + yield instrs.SETUP_LOOP(end) + yield from assemble_instructions(self.init) + yield instrs.GET_ITER() + + # Loop iteration setup. + yield top_of_loop + yield instrs.FOR_ITER(cleanup) + yield from assemble_instructions(self.unpack) + + # Loop body. + yield from assemble_instructions(self.body) + yield instrs.JUMP_ABSOLUTE(top_of_loop) + + # Cleanup. + yield cleanup + yield instrs.POP_BLOCK() + yield from self.else_body + + # End of Loop. + yield end + + +class IfStatement(Macro): + """Macro for assembling an if block. + """ + def __init__(self, test, body, else_body=()): + self.test = test + self.body = body + self.else_body = else_body + + def assemble(self): + done = Label('done') + + # Setup Test. + yield from self.test + + if self.else_body: + # Test. + start_of_else = Label('start_of_else') + yield instrs.POP_JUMP_IF_FALSE(start_of_else) + + # Main Branch. + yield from self.body + yield instrs.JUMP_FORWARD(done) + + # Else Branch. + yield start_of_else + yield from self.else_body + else: + # Test. + yield instrs.POP_JUMP_IF_FALSE(done) + + # Body. + yield from self.body + + yield done + + +class PrintVariable(Macro): + """Macro for printing a local variable by name. + + This is mostly useful for debugging. + """ + def __init__(self, name): + self.name = name + + def assemble(self): + yield instrs.LOAD_FAST(self.name) + yield instrs.PRINT_EXPR() + + def __repr__(self): + return "{}({!r})".format(type(self).__name__, self.name) + + +class PrintStack(Macro): + """Macro for printing the toptop N values on the stack. + """ + def __init__(self, n=1): + self.n = n + + def assemble(self): + # Pop the top N values off the stack into a tuple. + yield instrs.BUILD_TUPLE(self.n) + # Make a copy of the tuple. + yield instrs.DUP_TOP() + # Print it. This pops the copy. + yield instrs.PRINT_EXPR() + # Unpack the tuple back onto the stack. We call reversed here because + # UNPACK_SEQUENCE unpacks in reverse. + yield instrs.LOAD_CONST(reversed) + yield instrs.ROT_TWO() + yield instrs.CALL_FUNCTION(1) + yield instrs.UNPACK_SEQUENCE(self.n) + + def __repr__(self): + return "{}({!r})".format(type(self).__name__, self.n) + + +class AssertFail(Macro): + def __init__(self, message): + self.message = message + + def assemble(self): + yield instrs.LOAD_CONST(AssertionError) + yield instrs.LOAD_CONST(self.message) + yield instrs.CALL_FUNCTION(1) + yield instrs.RAISE_VARARGS(1) diff --git a/codetransformer/tests/test_macros.py b/codetransformer/tests/test_macros.py new file mode 100644 index 0000000..3f861a0 --- /dev/null +++ b/codetransformer/tests/test_macros.py @@ -0,0 +1,264 @@ +import pytest + +from functools import total_ordering +from inspect import signature + +from .. import instructions as instrs +from ..assembler import assemble_function +from ..macros import AssertFail, IfStatement, ForLoop +from ..utils.instance import instance + + +def assert_same_result(f1, f2, *args, **kwargs): + try: + result1 = f1(*args, **kwargs) + f1_raised = False + except Exception as e: + result1 = e + f1_raised = True + + try: + result2 = f2(*args, **kwargs) + f2_raised = False + except Exception as e: + result2 = e + f2_raised = True + + if f1_raised and not f2_raised: + raise AssertionError("\n{} raised {}\n{} returned {}".format( + f1.__name__, result1, f2.__name__, result2 + )) + elif not f1_raised and f2_raised: + raise AssertionError("\n{} returned {}\n {} raised {}".format( + f1.__name__, result1, f2.__name__, result2 + )) + elif f1_raised and f2_raised: + assert type(result1) == type(result2) and result1.args == result2.args + else: + assert result1 == result2 + + +def test_simple_if_statement(): + + def goal(x, y): + z = 0 + if x > y: + z = 1 + z = z + 1 + return z + + assembly = [ + instrs.LOAD_CONST(0), + instrs.STORE_FAST('z'), + IfStatement( + test=[ + instrs.LOAD_FAST('x'), + instrs.LOAD_FAST('y'), + instrs.COMPARE_OP.GT, + ], + body=[ + instrs.LOAD_CONST(1), + instrs.STORE_FAST('z'), + ], + ), + instrs.LOAD_FAST('z'), + instrs.LOAD_CONST(1), + instrs.BINARY_ADD(), + instrs.STORE_FAST('z'), + instrs.LOAD_FAST('z'), + instrs.RETURN_VALUE(), + ] + func = assemble_function(signature(goal), assembly) + + assert_same_result(func, goal, 1, 1) + assert_same_result(func, goal, 1, 2) + assert_same_result(func, goal, 2, 1) + assert_same_result(func, goal, incomparable, 1) + assert_same_result(func, goal, 1, incomparable) + + +def test_if_else(): + + def goal(x): + if x > 0: + return x + else: + return -x + + assembly = [ + IfStatement( + test=[ + instrs.LOAD_FAST('x'), + instrs.LOAD_CONST(0), + instrs.COMPARE_OP.GT, + ], + body=[ + instrs.LOAD_FAST('x'), + instrs.RETURN_VALUE(), + ], + else_body=[ + instrs.LOAD_FAST('x'), + instrs.UNARY_NEGATIVE(), + instrs.RETURN_VALUE(), + ], + ), + AssertFail("Shouldn't ever get here!"), + ] + + func = assemble_function(signature(goal), assembly) + + assert_same_result(func, goal, 1) + assert_same_result(func, goal, 0) + assert_same_result(func, goal, -1) + assert_same_result(func, goal, object()) + + +def test_simple_for_loop(): + + def goal(x): + result = [] + for i in range(x): + result.append(i * 2) + return result + + assembly = [ + instrs.BUILD_LIST(0), + instrs.STORE_FAST('result'), + ForLoop( + init=[ + instrs.LOAD_GLOBAL('range'), + instrs.LOAD_FAST('x'), + instrs.CALL_FUNCTION(1), + ], + unpack=[ + instrs.STORE_FAST('i'), + ], + body=[ + instrs.LOAD_FAST('result'), + instrs.LOAD_ATTR('append'), + instrs.LOAD_FAST('i'), + instrs.LOAD_CONST(2), + instrs.BINARY_MULTIPLY(), + instrs.CALL_FUNCTION(1), + instrs.POP_TOP(), + ], + ), + instrs.LOAD_FAST('result'), + instrs.RETURN_VALUE(), + ] + + func = assemble_function(signature(goal), assembly) + + assert_same_result(func, goal, 1) + assert_same_result(func, goal, 2) + assert_same_result(func, goal, 3) + assert_same_result(func, goal, -1) + assert_same_result(func, goal, "this should crash") + + +def test_nested_for_loop(): + + def goal(x, y): + result = [] + for i in range(x): + for j in range(y): + result.append(i + j) + return result + + assembly = [ + instrs.BUILD_LIST(0), + instrs.STORE_FAST('result'), + ForLoop( + init=[ + instrs.LOAD_GLOBAL('range'), + instrs.LOAD_FAST('x'), + instrs.CALL_FUNCTION(1), + ], + unpack=[ + instrs.STORE_FAST('i'), + ], + body=[ + ForLoop( + init=[ + instrs.LOAD_GLOBAL('range'), + instrs.LOAD_FAST('y'), + instrs.CALL_FUNCTION(1), + ], + unpack=[ + instrs.STORE_FAST('j'), + ], + body=[ + instrs.LOAD_FAST('result'), + instrs.LOAD_ATTR('append'), + instrs.LOAD_FAST('i'), + instrs.LOAD_FAST('j'), + instrs.BINARY_ADD(), + instrs.CALL_FUNCTION(1), + instrs.POP_TOP(), + ], + ), + ], + ), + instrs.LOAD_FAST('result'), + instrs.RETURN_VALUE(), + ] + + func = assemble_function(signature(goal), assembly) + + assert_same_result(func, goal, 0, 0) + assert_same_result(func, goal, 5, 3) + assert_same_result(func, goal, 3, 5) + assert_same_result(func, goal, -1, -1) + assert_same_result(func, goal, "this should", "crash") + + +def test_for_else(): + + def goal(x): + for obj in ('a', 'b', 'c'): + if x == obj: + return 'found' + else: + return 'not found' + + assembly = [ + ForLoop( + init=[instrs.LOAD_CONST(('a', 'b', 'c'))], + unpack=[instrs.STORE_FAST('obj')], + body=[ + IfStatement( + test=[ + instrs.LOAD_FAST('x'), + instrs.LOAD_FAST('obj'), + instrs.COMPARE_OP.EQ, + ], + body=[instrs.LOAD_CONST('found'), instrs.RETURN_VALUE()], + ) + ], + else_body=[instrs.LOAD_CONST('not found'), instrs.RETURN_VALUE()], + ), + AssertFail("Shouldn't ever get here!"), + ] + + func = assemble_function(signature(goal), assembly) + + assert_same_result(func, goal, 'a') + assert_same_result(func, goal, 'b') + assert_same_result(func, goal, 'not in the tuple') + + +def test_assert_fail(): + assembly = [AssertFail('message')] + func = assemble_function(signature(lambda: None), assembly) + + with pytest.raises(AssertionError) as e: + func() + + assert e.value.args == ('message',) + + +@instance +@total_ordering +class incomparable: + def __lt__(self, other): + raise TypeError("Nothing compares to me!") diff --git a/tox.ini b/tox.ini index 33642d7..8042ac3 100644 --- a/tox.ini +++ b/tox.ini @@ -8,6 +8,6 @@ commands= py.test [pytest] -addopts = --doctest-modules --cov codetransformer --cov-report term-missing --ignore setup.py +addopts = --doctest-modules --ignore setup.py --tb=short testpaths = codetransformer norecursedirs = decompiler