diff --git a/.gitignore b/.gitignore index d140936..212e1f0 100644 --- a/.gitignore +++ b/.gitignore @@ -63,3 +63,9 @@ target/ .*.swp example tests/functional/output*py + +# mypy +.mypy_cache/ + +# Misc +/test*.py diff --git a/LICENSE.md b/LICENSE.md new file mode 100644 index 0000000..a3215f9 --- /dev/null +++ b/LICENSE.md @@ -0,0 +1,23 @@ + +# The MIT License (MIT) + +Copyright © 2017 by [nvbn](https://github.com/nvbn/). +Copyright © 2019 by luk3yx. + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/README.md b/README.md index 86bab79..a536c34 100644 --- a/README.md +++ b/README.md @@ -1,40 +1,73 @@ # Py-backwards [![Build Status](https://travis-ci.org/nvbn/py-backwards.svg?branch=master)](https://travis-ci.org/nvbn/py-backwards) -Python to python compiler that allows you to use some Python 3.6 features in older versions, you can try it in [the online demo](https://py-backwards.herokuapp.com/). +Python to python compiler that allows you to use some Python 3.6+ features in older versions, you can try it in [the online demo](https://py-backwards.herokuapp.com/). -Requires Python 3.3+ to run, can compile down to 2.7. +Requires Python 3.3+ to run, can compile down to 2.7 (and down to 2.5 if you +only use a subset of Python 3). + +Note that py_backwards creates variables beginning with `_py_backwards` for +internal use, to prevent variable conflicts try to avoid function/variable +names beginning with `_py_backwards` in your code. ## Supported features +Target 3.7: +* [Certain walrus operators](https://docs.python.org/3.8/whatsnew/3.8.html#assignment-expressions) - This is rather hit and miss, and some walrus operators currently + only work on CPython and only if the variable has already been defined in + the same scope. +* [Positional only parameters](https://docs.python.org/3.8/whatsnew/3.8.html#positional-only-parameters) +* [Self-documenting f-string expressions](https://docs.python.org/3.8/whatsnew/3.8.html#f-strings-support-for-self-documenting-expressions-and-debugging) (works automatically) + Target 3.5: -* [formatted string literals](https://docs.python.org/3/whatsnew/3.6.html#pep-498-formatted-string-literals) like `f'hi {x}'` -* [variables annotations](https://docs.python.org/3/whatsnew/3.6.html#whatsnew36-pep526) like `x: int = 10` and `x: int` -* [underscores in numeric literals](https://docs.python.org/3/whatsnew/3.6.html#pep-515-underscores-in-numeric-literals) like `1_000_000` (works automatically) +* [Formatted string literals](https://docs.python.org/3/whatsnew/3.6.html#pep-498-formatted-string-literals) like `f'hi {x}'` +* [Variable annotations](https://docs.python.org/3/whatsnew/3.6.html#whatsnew36-pep526) like `x: int = 10` and `x: int` +* [Asynchronous generators](https://www.python.org/dev/peps/pep-0525) +* [Underscores in numeric literals](https://docs.python.org/3/whatsnew/3.6.html#pep-515-underscores-in-numeric-literals) like `1_000_000` (works automatically) Target 3.4: -* [starred unpacking](https://docs.python.org/3/whatsnew/3.5.html#pep-448-additional-unpacking-generalizations) like `[*range(1, 5), *range(10, 15)]` and `print(*[1, 2], 3, *[4, 5])` -* [dict unpacking](https://docs.python.org/3/whatsnew/3.5.html#pep-448-additional-unpacking-generalizations) like `{1: 2, **{3: 4}}` +* [Starred unpacking](https://docs.python.org/3/whatsnew/3.5.html#pep-448-additional-unpacking-generalizations) like `[*range(1, 5), *range(10, 15)]` and `print(*[1, 2], 3, *[4, 5])` +* [Dict unpacking](https://docs.python.org/3/whatsnew/3.5.html#pep-448-additional-unpacking-generalizations) like `{1: 2, **{3: 4}}` Target 3.3: -* import [pathlib2](https://pypi.python.org/pypi/pathlib2/) instead of pathlib +* Import [pathlib2](https://pypi.python.org/pypi/pathlib2/) instead of pathlib Target 3.2: -* [yield from](https://docs.python.org/3/whatsnew/3.3.html#pep-380) -* [return from generator](https://docs.python.org/3/whatsnew/3.3.html#pep-380) +* [`yield from`](https://docs.python.org/3/whatsnew/3.3.html#pep-380) +* [Return from generator](https://docs.python.org/3/whatsnew/3.3.html#pep-380) Target 2.7: -* [functions annotations](https://www.python.org/dev/peps/pep-3107/) like `def fn(a: int) -> str` -* [imports from `__future__`](https://docs.python.org/3/howto/pyporting.html#prevent-compatibility-regressions) -* [super without arguments](https://www.python.org/dev/peps/pep-3135/) -* classes without base like `class A: pass` -* imports from [six moves](https://pythonhosted.org/six/#module-six.moves) -* metaclass -* string/unicode literals (works automatically) +* [Keyword only arguments](https://www.python.org/dev/peps/pep-3102/) +* [Function annotations](https://www.python.org/dev/peps/pep-3107/) like `def fn(a: int) -> str` +* [Imports from `__future__`](https://docs.python.org/3/howto/pyporting.html#prevent-compatibility-regressions) +* [`super()` without arguments](https://www.python.org/dev/peps/pep-3135/) +* [The `nonlocal` statement](https://www.python.org/dev/peps/pep-3104/), + provided you don't try and check for variables used with `nonlocal` in + `locals()`. +* Implicit `object` class base. +* Imports from [six.moves](https://pythonhosted.org/six/#module-six.moves) +* Metaclasses +* A `__nonzero__` alias for any `__bool__` methods. +* String/unicode literals (works automatically) * `str` to `unicode` -* define encoding (not transformer) +* Add `# -*- coding: utf-8 -*-` (not transformer) * `dbm => anydbm` and `dbm.ndbm => dbm` +* [Non-ASCII identifiers](https://www.python.org/dev/peps/pep-3131/). Non-ASCII + identifiers are mangled currently mangled in a similar way to + [Hy](https://docs.hylang.org/en/stable/language/syntax.html#mangling). + +Target 2.6: +* Class decorators +* Dict comprehension +* Set literals + +Target 2.5: +* `six.print_()` instead of `print()`. +* `six.advance_iterator()` instead of `next()`. +* `except as` (note that this breaks compatibility with Python 3.0+). +* Keyword arguments after `*args`. +* An `itertools.zip_longest` backport. -For example, if you have some python 3.6 code, like: +For example, if you have some Python 3.6 code, like: ```python def returning_range(x: int): @@ -80,7 +113,7 @@ print(ImportantNumberManager().ten()) print(ImportantNumberManager.eleven()) ``` -You can compile it for python 2.7 with: +You can compile it for Python 2.7 with: ```bash ➜ py-backwards -i input.py -o output.py -t 2.7 @@ -154,7 +187,7 @@ pip install py-backwards-packager ``` And change `setup` import in `setup.py` to: - + ```python try: from py_backwards_packager import setup @@ -163,7 +196,7 @@ except ImportError: ``` By default all targets enabled, but you can limit them with: - + ```python setup(..., py_backwards_targets=['2.7', '3.3']) @@ -263,7 +296,7 @@ from ..utils.snippet import snippet, let, extend def my_snippet(class_name, class_body): class class_name: # will be replaced with `class_name` extend(class_body) # body of the class will be extended with `class_body` - + def fn(self): let(x) # x will be replaced everywhere with unique name, like `_py_backwards_x_1` x = 10 @@ -286,5 +319,6 @@ it contains such useful functions like `find`, `get_parent` and etc. * [tox-py-backwards](https://github.com/nvbn/tox-py-backwards) * [py-backwards-packager](https://github.com/nvbn/py-backwards-packager) * [pytest-docker-pexpect](https://github.com/nvbn/pytest-docker-pexpect) +* [lib3to6](https://gitlab.com/mbarkhau/lib3to6) ## License MIT diff --git a/py_backwards/__init__.py b/py_backwards/__init__.py index e69de29..a97c6b5 100644 --- a/py_backwards/__init__.py +++ b/py_backwards/__init__.py @@ -0,0 +1,42 @@ + +import sys +if sys.version_info >= (3, 9): + import ast + def unparse(tree): + return ast.unparse(ast.fix_missing_locations(tree)) +elif sys.version_info >= (3, 8): + import ast + + # A hack to allow astunparse to parse ast.Constant-s. + import astunparse + from astunparse import unparse + from types import SimpleNamespace as _SimpleNamespace + + def _Constant(self, tree): + value = tree.value + if isinstance(value, str): + self._Str(_SimpleNamespace(s=value)) + elif isinstance(value, bytes): + self._Bytes(_SimpleNamespace(s=value)) + elif value is Ellipsis: + self._Ellipsis(tree) + else: + self._NameConstant(tree) + + def _NamedExpr(self, tree): + self.write('(') + self.dispatch(tree.target) + self.write(' := ') + self.dispatch(tree.value) + self.write(')') + + + if not hasattr(astunparse.Unparser, '_Constant'): + astunparse.Unparser._Constant = _Constant + if not hasattr(astunparse.Unparser, '_NamedExpr'): + astunparse.Unparser._NamedExpr = _NamedExpr + + del _Constant, _NamedExpr +else: + from typed_ast import ast3 as ast + from astunparse import unparse diff --git a/py_backwards/__main__.py b/py_backwards/__main__.py new file mode 100644 index 0000000..fa910d6 --- /dev/null +++ b/py_backwards/__main__.py @@ -0,0 +1,6 @@ +#!/usr/bin/env python3 + +from .main import main + +if __name__ == '__main__': + main() diff --git a/py_backwards/compiler.py b/py_backwards/compiler.py index 75435d2..2689a14 100644 --- a/py_backwards/compiler.py +++ b/py_backwards/compiler.py @@ -2,15 +2,17 @@ from time import time from traceback import format_exc from typing import List, Tuple, Optional -from typed_ast import ast3 as ast -from astunparse import unparse, dump +try: + from ast import dump +except ImportError: + from astunparse import dump from autopep8 import fix_code from .files import get_input_output_paths, InputOutput from .transformers import transformers from .types import CompilationTarget, CompilationResult from .exceptions import CompilationError, TransformationError from .utils.helpers import debug -from . import const +from . import ast, const, unparse def _transform(path: str, code: str, target: CompilationTarget) -> Tuple[str, List[str]]: @@ -30,6 +32,17 @@ def _transform(path: str, code: str, target: CompilationTarget) -> Tuple[str, Li working_tree = deepcopy(tree) try: result = transformer.transform(working_tree) + except SyntaxError as exc: + if isinstance(getattr(exc, 'ast_node', None), ast.AST): + if not getattr(exc.ast_node, 'lineno', None): # type: ignore + ast.fix_missing_locations(working_tree) + exc.lineno = getattr(exc.ast_node, 'lineno', 0) # type: ignore + exc.offset = getattr(exc.ast_node, 'col_offset', -1) + 1 # type: ignore + else: + exc.lineno = exc.lineno or 0 + exc.offset = exc.offset or 0 + + raise exc except: raise TransformationError(path, transformer, dump(tree), format_exc()) @@ -49,7 +62,10 @@ def _transform(path: str, code: str, target: CompilationTarget) -> Tuple[str, Li raise TransformationError(path, transformer, dump(tree), format_exc()) - return fix_code(code), dependencies + # Disable E402 (moving imports to the top of the file) as it breaks. + code = fix_code(code, options={'ignore': ['E226', 'E24', 'W50', 'W690', + 'E402']}) + return code, dependencies def _compile_file(paths: InputOutput, target: CompilationTarget) -> List[str]: @@ -62,14 +78,14 @@ def _compile_file(paths: InputOutput, target: CompilationTarget) -> List[str]: code, target) except SyntaxError as e: raise CompilationError(paths.input.as_posix(), - code, e.lineno, e.offset) + code, e.lineno, e.offset or 0) try: paths.output.parent.mkdir(parents=True) except FileExistsError: pass - if target == const.TARGETS['2.7']: + if target <= const.TARGETS['2.7']: transformed = '# -*- coding: utf-8 -*-\n{}'.format(transformed) with paths.output.open('w') as f: diff --git a/py_backwards/const.py b/py_backwards/const.py index eb538be..b553d14 100644 --- a/py_backwards/const.py +++ b/py_backwards/const.py @@ -1,14 +1,18 @@ from collections import OrderedDict -TARGETS = OrderedDict([('2.7', (2, 7)), +TARGETS = OrderedDict([('2.5', (2, 5)), + ('2.6', (2, 6)), + ('2.7', (2, 7)), ('3.0', (3, 0)), ('3.1', (3, 1)), ('3.2', (3, 2)), ('3.3', (3, 3)), ('3.4', (3, 4)), ('3.5', (3, 5)), - ('3.6', (3, 6))]) + ('3.6', (3, 6)), + ('3.7', (3, 7)), + ('3.8', (3, 8))]) SYNTAX_ERROR_OFFSET = 5 -TARGET_ALL = (9999, 9999) +TARGET_ALL = next(reversed(TARGETS.values())) diff --git a/py_backwards/main.py b/py_backwards/main.py index 9fc53e5..48e116d 100644 --- a/py_backwards/main.py +++ b/py_backwards/main.py @@ -3,20 +3,34 @@ init() from argparse import ArgumentParser +import atexit +import pathlib +import shutil import sys +import tempfile from .compiler import compile_files from .conf import init_settings from . import const, messages, exceptions +def _cleanup(tmpdir): + try: + print('\n# -*- coding: utf-8 -*-') + for path in pathlib.Path(tmpdir).glob('**/*.py'): + print() + print('# ----------', path.name, '---------- #') + with path.open('r') as f: + shutil.copyfileobj(f, sys.stdout) + path.unlink() + finally: + shutil.rmtree(tmpdir) def main() -> int: - parser = ArgumentParser( - 'py-backwards', + parser = ArgumentParser('py-backwards', description='Python to python compiler that allows you to use some ' - 'Python 3.6 features in older versions.') + 'Python 3.6+ features in older versions.') parser.add_argument('-i', '--input', type=str, nargs='+', required=True, help='input file or folder') - parser.add_argument('-o', '--output', type=str, required=True, + parser.add_argument('-o', '--output', type=str, default='-', help='output file or folder') parser.add_argument('-t', '--target', type=str, required=True, choices=const.TARGETS.keys(), @@ -28,9 +42,17 @@ def main() -> int: args = parser.parse_args() init_settings(args) + output = args.output + if output == '-': + output = tempfile.mkdtemp() + atexit.register(_cleanup, output) + result_file = sys.stderr + else: + result_file = sys.stdout + try: for input_ in args.input: - result = compile_files(input_, args.output, + result = compile_files(input_, output, const.TARGETS[args.target], args.root) except exceptions.CompilationError as e: @@ -50,5 +72,5 @@ def main() -> int: print(messages.permission_error(args.output), file=sys.stderr) return 1 - print(messages.compilation_result(result)) + print(messages.compilation_result(result), file=result_file) return 0 diff --git a/py_backwards/messages.py b/py_backwards/messages.py index 87dcb77..fc65114 100644 --- a/py_backwards/messages.py +++ b/py_backwards/messages.py @@ -53,7 +53,7 @@ def syntax_error(e: CompilationError) -> str: red=Fore.RED, e=e, reset=Style.RESET_ALL, - bright=Style.BRIGHT, + # bright=Style.BRIGHT, lines='\n'.join(lines)) diff --git a/py_backwards/transformers/__init__.py b/py_backwards/transformers/__init__.py index 82efa61..30d1b54 100644 --- a/py_backwards/transformers/__init__.py +++ b/py_backwards/transformers/__init__.py @@ -1,28 +1,57 @@ from typing import List, Type +from .walrus_operator import WalrusTransformer +from .walrus_operator import FallbackWalrusTransformer +from .posonlyargs import PosOnlyArgTransformer from .dict_unpacking import DictUnpackingTransformer from .formatted_values import FormattedValuesTransformer from .functions_annotations import FunctionsAnnotationsTransformer from .starred_unpacking import StarredUnpackingTransformer from .variables_annotations import VariablesAnnotationsTransformer +from .matrix_multiplication import MatMultTransformer +from .async_generators import AsyncGeneratorTransformer +from .async_for import AsyncForTransformer +from .async_with import AsyncWithTransformer +from .async_functions import AsyncFunctionTransformer from .yield_from import YieldFromTransformer from .return_from_generator import ReturnFromGeneratorTransformer from .python2_future import Python2FutureTransformer +from .python2_future import Python25FutureTransformer +from .nonlocal_statement import NonlocalStatementTransformer +from .class_bool_method import ClassBoolMethodTransformer from .super_without_arguments import SuperWithoutArgumentsTransformer from .class_without_bases import ClassWithoutBasesTransformer from .import_pathlib import ImportPathlibTransformer from .six_moves import SixMovesTransformer from .metaclass import MetaclassTransformer -from .string_types import StringTypesTransformer +from .kwargs import KwArgTransformer +from .kwonlyargs import KwOnlyArgTransformer +from .byte_literals import ByteLiteralTransformer from .import_dbm import ImportDbmTransformer +from .unicode_identifiers import UnicodeIdentifierTransformer +from .set_literals import SetLiteralTransformer +from .dict_comprehension import DictComprehensionTransformer +from .class_decorators import ClassDecoratorTransformer +from .except_as import ExceptAsTransformer +from .raise_from import RaiseFromTransformer +from .print_function import PrintFunctionTransformer from .base import BaseTransformer transformers = [ + # 3.7 + WalrusTransformer, + FallbackWalrusTransformer, + PosOnlyArgTransformer, # 3.5 VariablesAnnotationsTransformer, FormattedValuesTransformer, + AsyncGeneratorTransformer, # 3.4 DictUnpackingTransformer, StarredUnpackingTransformer, + MatMultTransformer, + AsyncForTransformer, + AsyncWithTransformer, + AsyncFunctionTransformer, # 3.2 YieldFromTransformer, ReturnFromGeneratorTransformer, @@ -32,8 +61,23 @@ ClassWithoutBasesTransformer, ImportPathlibTransformer, SixMovesTransformer, + ClassBoolMethodTransformer, MetaclassTransformer, - StringTypesTransformer, ImportDbmTransformer, - Python2FutureTransformer, # always should be the last transformer + NonlocalStatementTransformer, + RaiseFromTransformer, + KwOnlyArgTransformer, + UnicodeIdentifierTransformer, + # 2.5 + SetLiteralTransformer, + PrintFunctionTransformer, + ExceptAsTransformer, + DictComprehensionTransformer, + ClassDecoratorTransformer, + KwArgTransformer, + + # These transformers should be last and in this order to prevent conflicts. + ByteLiteralTransformer, # 2.5 + Python2FutureTransformer, # 2.7 + Python25FutureTransformer, # 2.5 ] # type: List[Type[BaseTransformer]] diff --git a/py_backwards/transformers/async_for.py b/py_backwards/transformers/async_for.py new file mode 100644 index 0000000..5c1dd27 --- /dev/null +++ b/py_backwards/transformers/async_for.py @@ -0,0 +1,85 @@ +from ..exceptions import NodeNotFound +from ..utils.helpers import warn +from ..utils.snippet import snippet +from ..utils.tree import get_node_position, find, insert_at, replace_at +from ..types import TransformationResult +from .. import ast +from .base import BaseTransformer + +# Create a StopAsyncIteration if one doesn't already exist. +@snippet +def _init(): + try: + assert issubclass(StopAsyncIteration, Exception) + except (AssertionError, NameError): + let(builtin) + import builtins as builtin + class StopAsyncIteration(Exception): + pass + builtin.StopAsyncIteration = StopAsyncIteration + del builtin + + from asyncio import iscoroutine as _py_backwards_iscoroutine + +@snippet +def _async_for(target, iter_): + let(it) + let(itertype) + let(running) + it = iter_ + it = type(it).__aiter__(it) + + # Support legacy iterators + if _py_backwards_iscoroutine(it): + it = yield from it + + itertype = type(it) + running = True + while running: + try: + target = yield from itertype.__anext__(it) + except StopAsyncIteration: + running = False + + del it, itertype, running + +class AsyncForTransformer(BaseTransformer): + """Compiles: + async def test1(): + async for i in async_iterable: + print(i) + else: + print('Else') + """ + target = (3, 4) + + @classmethod + def transform(cls, tree: ast.AST, *, add_init=True) \ + -> TransformationResult: + tree_changed = False + + for node in find(tree, ast.AsyncFor): + if not tree_changed: + tree_changed = True + if add_init: + insert_at(0, tree, _init.get_body()) + + try: + position = get_node_position(tree, node) + except NodeNotFound: + warn('Async for outside of body') + continue + + body = _async_for.get_body(target=node.target, iter_=node.iter) + + # This can't use extend() as that replaces variables. + for n in body: + if isinstance(n, ast.While): + n.orelse = node.orelse + assert isinstance(n.body[0], ast.Try) + n.body[0].orelse = node.body + break + + replace_at(position.index, position.parent, body) + + return TransformationResult(tree, tree_changed, []) #['asyncio']) diff --git a/py_backwards/transformers/async_functions.py b/py_backwards/transformers/async_functions.py new file mode 100644 index 0000000..e8cecbe --- /dev/null +++ b/py_backwards/transformers/async_functions.py @@ -0,0 +1,39 @@ +import sys +from ..utils.snippet import snippet +from .. import ast +from .base import BaseNodeTransformer +from typing import Optional + +@snippet +def _import(): + from asyncio import coroutine as _py_backwards_coroutine + +class AsyncFunctionTransformer(BaseNodeTransformer): + """Compiles: + async def test(): + await asyncio.sleep(2) + To + @_py_backwards_coroutine + def test(): + yield from asyncio.sleep(2) + + """ + target = (3, 4) + # dependencies = ['asyncio'] + + def visit_Module(self, node: ast.Module) -> Optional[ast.AST]: + node.body = _import.get_body() + node.body # type: ignore + return self.generic_visit(node) + + def visit_AsyncFunctionDef(self, node: ast.AsyncFunctionDef) \ + -> Optional[ast.AST]: + self._tree_changed = True + node.decorator_list.append(ast.Name(id='_py_backwards_coroutine')) + func = ast.FunctionDef(node.name, node.args, node.body, + node.decorator_list, node.returns) + + return self.generic_visit(func) + + def visit_Await(self, node: ast.Await) -> Optional[ast.AST]: + self._tree_changed = True + return self.generic_visit(ast.YieldFrom(value=node.value)) diff --git a/py_backwards/transformers/async_generators.py b/py_backwards/transformers/async_generators.py new file mode 100644 index 0000000..08a55ca --- /dev/null +++ b/py_backwards/transformers/async_generators.py @@ -0,0 +1,189 @@ +from ..types import TransformationResult +from ..utils.snippet import snippet +from ..utils.tree import insert_at +from .. import ast +from .async_for import AsyncForTransformer +from .async_with import AsyncWithTransformer +from .base import BaseNodeTransformer +from typing import List, Tuple, Optional, Union + +# Handle generators. +@snippet +def _async_generator(): + let(coro) + let(AsyncGenerator) + from asyncio import (coroutine as coro, + iscoroutine as _py_backwards_iscoroutine) + class AsyncGenerator: + __slots__ = ('_iter', 'ag_running') + + @coro + def asend(self, value): + while True: + try: + i = self._iter.send(value) + except StopIteration: + self.ag_running = False + raise StopAsyncIteration from None + + # Normally, isinstance would be used, however a tuple subclass + # should never be yield-ed when being used as a generator + # yield. + if type(i) is tuple and len(i) == 2 and \ + i[0] is _py_backwards_async_generator: + return i[1] + + value = yield i + + # I think athrow() and aclose() are implemented correctly here, however + # they are probably not. + @coro + def athrow(self, *args): + if not self.ag_running: + return + + i = self._iter.throw(*args) + if type(i) is tuple and len(i) == 2 and \ + i[0] is _py_backwards_async_generator: + return i[1] + + return (yield from self.asend((yield i))) + + @coro + def aclose(self): + try: + yield from self.athrow(GeneratorExit) + except StopAsyncIteration: + pass + + def __aiter__(self): + return self + + @coro + def __anext__(self): + return (yield from self.asend(None)) + + def __init__(self, iterator): + self._iter = iterator + self.ag_running = True + + let(functools_wraps) + from functools import wraps as functools_wraps + + def _py_backwards_async_generator(func): + @functools_wraps(func) + def wrapper(*args, **kwargs): + return AsyncGenerator(coro_func(*args, **kwargs)) + coro_func = coro(func) + return wrapper + +class _YieldFinder(ast.NodeVisitor): + @classmethod + def find_yields(cls, tree: ast.AsyncFunctionDef) \ + -> Tuple[List[ast.Yield], List[ast.Await]]: + self = cls() + self.generic_visit(tree) + if self.returns_value and self.yields: + exc = SyntaxError("'return' with value in async generator") + exc.ast_node = self.returns_value # type: ignore + raise exc + return self.yields, self.awaits + + def __init__(self): + self.yields = [] # type: List[ast.Yield] + self.awaits = [] # type: List[ast.Await] + self.returns_value = None # type: Optional[ast.Return] + + def visit_FunctionDef(self, node: ast.AST) -> ast.AST: + return node + + visit_ClassDef = visit_Lambda = visit_AsyncFunctionDef = visit_FunctionDef + + def visit_Yield(self, node: ast.Yield) -> ast.AST: + self.yields.append(node) + return self.generic_visit(node) + + def visit_Await(self, node: ast.Await) -> ast.AST: + self.awaits.append(node) + return self.generic_visit(node) + + def visit_YieldFrom(self, node: ast.YieldFrom) -> ast.AST: + # Ignore if the name contains _py_backwards. + n = node.value # type: ast.AST + if isinstance(n, ast.Call): + n = n.func + if isinstance(n, ast.Attribute): + n = n.value + if isinstance(n, ast.Name) and n.id.startswith('_py_backwards_it'): + return self.generic_visit(node) + + # Otherwise raise a SyntaxError. + exc = SyntaxError('yield from in async function') + exc.ast_node = node # type: ignore + raise exc + + def visit_Return(self, node: ast.Return) -> ast.AST: + if node.value and not self.returns_value: + self.returns_value = node + + return self.generic_visit(node) + +class AsyncGeneratorTransformer(BaseNodeTransformer): + """Compiles: + async def test2(): + yield 1 + await asyncio.sleep(1) + yield 2 + To + @_py_backwards_async_generator + def test2(): + yield (_py_backwards_async_generator, 1) + yield from asyncio.sleep(1) + yield (_py_backwards_async_generator, 2) + """ + target = (3, 5) + + def visit_AsyncFunctionDef(self, node: ast.AsyncFunctionDef) -> \ + Optional[ast.AST]: + """ Searches for async iterators and rewrites them. """ + yields, awaits = _YieldFinder.find_yields(node) + + if yields: + # Because this is no longer an async function, "async for" and + # "async with" need to be removed. + AsyncForTransformer.transform(node, add_init=False) + AsyncWithTransformer.transform(node) + + self._tree_changed = True + for n in yields: + name = ast.Name(id='_py_backwards_async_generator', + ctx=ast.Load()) + if not n.value: + n.value = ast.NameConstant(value=None) + n.value = ast.Tuple(elts=[name, n.value], ctx=ast.Load()) + + name = ast.Name(id='_py_backwards_async_generator', ctx=ast.Load()) + node.decorator_list.append(name) + + for await_ in awaits: + await_.py_backwards_await = True # type: ignore + + func = ast.FunctionDef(name=node.name, args=node.args, + body=node.body, returns=node.returns, + decorator_list=node.decorator_list) + return self.generic_visit(func) + + return self.generic_visit(node) + + def visit_Await(self, node: ast.Await) -> Optional[ast.AST]: + if getattr(node, 'py_backwards_await', False): + return self.generic_visit(ast.YieldFrom(value=node.value)) + return self.generic_visit(node) + + @classmethod + def transform(cls, tree: ast.AST) -> TransformationResult: + res = super().transform(tree) + if res.tree_changed and \ + isinstance(getattr(res.tree, 'body', None), list): + insert_at(0, res.tree, _async_generator.get_body()) + return res diff --git a/py_backwards/transformers/async_with.py b/py_backwards/transformers/async_with.py new file mode 100644 index 0000000..c47f8f4 --- /dev/null +++ b/py_backwards/transformers/async_with.py @@ -0,0 +1,81 @@ +from ..exceptions import NodeNotFound +from ..utils.helpers import warn +from ..utils.snippet import snippet +from ..utils.tree import get_node_position, find, insert_at, replace_at +from ..types import TransformationResult +from .. import ast +from .base import BaseTransformer +from typing import List, Union + +@snippet +def _async_with(expr, aenter): + let(mgr) + let(aexit) + let(exc) + mgr = expr + aexit = type(mgr).__aexit__ + extend(aenter) + try: + ... + except BaseException as exc: + if not (yield from aexit(mgr, type(exc), exc, exc.__traceback__)): + raise + else: + aexit(mgr, None, None, None) + + del mgr, aexit + +@snippet +def _aenter1(var): + var = yield from type(mgr).__aenter__(mgr) + +@snippet +def _aenter2(): + yield from type(mgr).__aenter__(mgr) + +class AsyncWithTransformer(BaseTransformer): + """Compiles: + async with test1(): + ... + """ + target = (3, 4) + + @classmethod + def _replace_with(cls, tree: ast.AST, node: ast.AsyncWith) -> None: + try: + position = get_node_position(tree, node) + except NodeNotFound: + warn('Async with outside of body') + return + + item = node.items[0] + with_body = node.body # type: List[ast.stmt] + if len(node.items) > 1: + with_body = [ast.AsyncWith(items=node.items[1:], + body=with_body)] + + if item.optional_vars: + aenter = _aenter1.get_body(var=item.optional_vars) + else: + aenter = _aenter2.get_body() + + body = _async_with.get_body(expr=item.context_expr, aenter=aenter) + for n in body: + if isinstance(n, ast.Try): + n.body = with_body + break + + replace_at(position.index, position.parent, body) + + if len(node.items) > 1: + cls._replace_with(tree, with_body[0]) # type: ignore + + @classmethod + def transform(cls, tree: ast.AST) -> TransformationResult: + tree_changed = False + + for node in find(tree, ast.AsyncWith): + tree_changed = True + cls._replace_with(tree, node) + + return TransformationResult(tree, tree_changed, []) #['asyncio']) diff --git a/py_backwards/transformers/base.py b/py_backwards/transformers/base.py index 42627f8..a4b640d 100644 --- a/py_backwards/transformers/base.py +++ b/py_backwards/transformers/base.py @@ -1,6 +1,6 @@ from abc import ABCMeta, abstractmethod from typing import List, Tuple, Union, Optional, Iterable, Dict -from typed_ast import ast3 as ast +from .. import ast from ..types import CompilationTarget, TransformationResult from ..utils.snippet import snippet, extend @@ -71,16 +71,19 @@ def visit_Import(self, node: ast.Import) -> Union[ast.Import, ast.Try]: if rewrite: return self._replace_import(node, *rewrite) - return self.generic_visit(node) + return self.generic_visit(node) # type: ignore def _replace_import_from_module(self, node: ast.ImportFrom, from_: str, to: str) -> ast.Try: """Replaces import from with try/except with old and new import module.""" self._tree_changed = True - rewrote_module = node.module.replace(from_, to, 1) - rewrote = ast.ImportFrom(module=rewrote_module, - names=node.names, - level=node.level) + if node.module: + rewrote_module = node.module.replace(from_, to, 1) + rewrote = ast.ImportFrom(module=rewrote_module, + names=node.names, + level=node.level) + else: + rewrote = node return self.wrapper.get_body(previous=node, # type: ignore current=rewrote)[0] @@ -112,9 +115,9 @@ def _get_replaced_import_from_part(self, node: ast.ImportFrom, alias: ast.alias, def _replace_import_from_names(self, node: ast.ImportFrom, names_to_replace: Dict[str, Tuple[str, str]]) -> ast.Try: - """Replaces import from with try/except with old and new + """Replaces import from with try/except with old and new import module and names. - + """ self._tree_changed = True @@ -123,9 +126,9 @@ def _replace_import_from_names(self, node: ast.ImportFrom, for alias in node.names] return self.wrapper.get_body(previous=node, # type: ignore - current=rewrotes)[0] + current=rewrotes)[0] # type: ignore - def visit_ImportFrom(self, node: ast.ImportFrom) -> Union[ast.ImportFrom, ast.Try, ast.AST]: + def visit_ImportFrom(self, node: ast.ImportFrom) -> ast.AST: rewrite = self._get_matched_rewrite(node.module) if rewrite: return self._replace_import_from_module(node, *rewrite) @@ -134,4 +137,4 @@ def visit_ImportFrom(self, node: ast.ImportFrom) -> Union[ast.ImportFrom, ast.Tr if names_to_replace: return self._replace_import_from_names(node, names_to_replace) - return self.generic_visit(node) + return self.generic_visit(node) # type: ignore diff --git a/py_backwards/transformers/byte_literals.py b/py_backwards/transformers/byte_literals.py new file mode 100644 index 0000000..57ff063 --- /dev/null +++ b/py_backwards/transformers/byte_literals.py @@ -0,0 +1,35 @@ +from .. import ast +from .base import BaseNodeTransformer + +# Python2-style unicode and str repr()s, astunparse should call these +# overridden functions. +class _py2_unicode(str): + __slots__ = () + def __repr__(self): + return 'u' + super().__repr__() + +class _py2_str(bytes): + __slots__ = () + def __repr__(self): + return super().__repr__().lstrip('b') + +class ByteLiteralTransformer(BaseNodeTransformer): + """Compiles: + test = 'Hello, world!' + test2 = b'test' + To + test = u'Hello, world!' + test2 = 'test' + + """ + target = (2, 5) + + def visit_Str(self, node: ast.Str) -> ast.Str: + self._tree_changed = True + node.s = _py2_unicode(node.s) + return self.generic_visit(node) # type: ignore + + def visit_Bytes(self, node: ast.Bytes) -> ast.Bytes: + self._tree_changed = True + node.s = _py2_str(node.s) + return self.generic_visit(node) # type: ignore diff --git a/py_backwards/transformers/class_bool_method.py b/py_backwards/transformers/class_bool_method.py new file mode 100644 index 0000000..a9410b1 --- /dev/null +++ b/py_backwards/transformers/class_bool_method.py @@ -0,0 +1,47 @@ +from .. import ast +from .base import BaseNodeTransformer +from typing import List, Union + +def _find_bool(nodes: Union[List[ast.AST], List[ast.expr]]) -> bool: + for node in nodes: + if isinstance(node, ast.Name): + if node.id == '__bool__': + return True + elif isinstance(node, ast.Tuple): + if _find_bool(node.elts): + return True + + return False + +class ClassBoolMethodTransformer(BaseNodeTransformer): + """Compiles: + class A: + def __bool__(self): + return False + To: + class A: + def __bool__(self): + return False + __nonzero__ = __bool__ + + """ + target = (2, 7) + + def visit_ClassDef(self, node: ast.ClassDef) -> ast.ClassDef: + has_bool = False + for n in node.body: + if has_bool: + break + + if isinstance(n, ast.Assign): + has_bool = _find_bool(n.targets) + elif isinstance(n, ast.FunctionDef): + has_bool = (n.name == '__bool__') + + if has_bool: + self._tree_changed = True + nonzero = ast.Name(id='__nonzero__', ctx=ast.Store()) + bool_ = ast.Name(id='__bool__', ctx=ast.Load()) + node.body.append(ast.Assign(targets=[nonzero], value=bool_)) + + return self.generic_visit(node) # type: ignore diff --git a/py_backwards/transformers/class_decorators.py b/py_backwards/transformers/class_decorators.py new file mode 100644 index 0000000..6d22225 --- /dev/null +++ b/py_backwards/transformers/class_decorators.py @@ -0,0 +1,39 @@ +from .. import ast +from ..types import TransformationResult +from ..utils.tree import find, get_node_position, insert_at +from .base import BaseTransformer + +class ClassDecoratorTransformer(BaseTransformer): + """Compiles: + @decorator + class Test: + pass + To + class Test: + pass + + Test = decorator(Test) + + """ + target = (2, 6) + + @classmethod + def transform(cls, tree: ast.AST) -> TransformationResult: + tree_changed = False + for node in find(tree, ast.ClassDef): + if not node.decorator_list: + continue + + tree_changed = True + pos = get_node_position(tree, node) + index = pos.index + 1 + value = ast.Name(id=node.name, ctx=ast.Load()) # type: ast.AST + for decorator in reversed(node.decorator_list): + value = ast.Call(func=decorator, args=[value], keywords=[]) + + insert_at(index, pos.parent, ast.Assign( + targets=[ast.Name(id=node.name, ctx=ast.Store())], + value=value)) + node.decorator_list = [] + + return TransformationResult(tree, tree_changed, []) diff --git a/py_backwards/transformers/class_without_bases.py b/py_backwards/transformers/class_without_bases.py index be36ae3..50da283 100644 --- a/py_backwards/transformers/class_without_bases.py +++ b/py_backwards/transformers/class_without_bases.py @@ -1,4 +1,4 @@ -from typed_ast import ast3 as ast +from .. import ast from .base import BaseNodeTransformer @@ -7,8 +7,9 @@ class ClassWithoutBasesTransformer(BaseNodeTransformer): class A: pass To: - class A(object) - + class A(object): + pass + """ target = (2, 7) diff --git a/py_backwards/transformers/dict_comprehension.py b/py_backwards/transformers/dict_comprehension.py new file mode 100644 index 0000000..4813218 --- /dev/null +++ b/py_backwards/transformers/dict_comprehension.py @@ -0,0 +1,23 @@ +from .. import ast +from .base import BaseNodeTransformer + + +class DictComprehensionTransformer(BaseNodeTransformer): + """Compiles: + d = {v: k for k, v in zip(range(10), range(10, 20))} + To + d = dict((v, k) for k, v in zip(range(10), range(10, 20))) + + """ + target = (2, 6) + + def visit_DictComp(self, node: ast.DictComp) -> ast.Call: + self._tree_changed = True + + generator = ast.GeneratorExp(elt=ast.Tuple(elts=[node.key, node.value]), + generators=node.generators) + + res = ast.Call(func=ast.Name(id='dict'), args=[generator], + keywords=[]) + + return self.generic_visit(res) # type: ignore diff --git a/py_backwards/transformers/dict_unpacking.py b/py_backwards/transformers/dict_unpacking.py index 312d33f..da19525 100644 --- a/py_backwards/transformers/dict_unpacking.py +++ b/py_backwards/transformers/dict_unpacking.py @@ -1,7 +1,7 @@ from typing import Union, Iterable, Optional, List, Tuple -from typed_ast import ast3 as ast from ..utils.tree import insert_at from ..utils.snippet import snippet +from .. import ast from .base import BaseNodeTransformer @@ -20,13 +20,13 @@ def _py_backwards_merge_dicts(dicts): class DictUnpackingTransformer(BaseNodeTransformer): """Compiles: - + {1: 1, **dict_a} - + To: - - _py_backwards_merge_dicts([{1: 1}], dict_a}) - + + _py_backwards_merge_dicts(({1: 1}, dict_a)) + """ target = (3, 4) @@ -61,7 +61,7 @@ def _merge_dicts(self, xs: Iterable[Union[ast.Call, ast.Dict]]) \ """Creates call of function for merging dicts.""" return ast.Call( func=ast.Name(id='_py_backwards_merge_dicts'), - args=[ast.List(elts=list(xs))], + args=[ast.Tuple(elts=list(xs))], keywords=[]) def visit_Module(self, node: ast.Module) -> ast.Module: diff --git a/py_backwards/transformers/except_as.py b/py_backwards/transformers/except_as.py new file mode 100644 index 0000000..cd5d8b0 --- /dev/null +++ b/py_backwards/transformers/except_as.py @@ -0,0 +1,28 @@ +from .. import ast, unparse +from .base import BaseNodeTransformer + +class ExceptAsTransformer(BaseNodeTransformer): + """Compiles: + try: + 1 / 0 + except ZeroDivisionError as e: + print(repr(e)) + To + try: + 1 / 0 + except ZeroDivisionError, e: + print(repr(e)) + + """ + target = (2, 5) + + def visit_Try(self, node: ast.Try) -> ast.Try: + # This is a hack. + for handler in node.handlers: + if handler.type and handler.name: + self._tree_changed = True + name = unparse(handler.type).strip() + ', ' + handler.name + handler.type = ast.Name(id=name) + handler.name = None + + return self.generic_visit(node) # type: ignore diff --git a/py_backwards/transformers/formatted_values.py b/py_backwards/transformers/formatted_values.py index 8620b47..85d8366 100644 --- a/py_backwards/transformers/formatted_values.py +++ b/py_backwards/transformers/formatted_values.py @@ -1,36 +1,81 @@ -from typed_ast import ast3 as ast +from .. import ast from ..const import TARGET_ALL from .base import BaseNodeTransformer +from typing import List, Union +# Because astunparse does not handle format strings nicely, this transformer +# has to target all Python versions (even 3.6 and 3.7). class FormattedValuesTransformer(BaseNodeTransformer): """Compiles: f"hello {x}" To - ''.join(['hello ', '{}'.format(x)]) - + 'hello {}'.format(x) + """ target = TARGET_ALL + def _parse_formatted_value(self, i: ast.FormattedValue, keywords: list) \ + -> str: + """ + Parse a FormattedValue and return a format string to add to a format() + call. The "keywords" argument is used to add keyword arguments for + nested format strings. + """ + res = '{' + + if i.conversion and i.conversion > 0: + res += '!' + chr(i.conversion) + + if i.format_spec: + spec = i.format_spec + + # A single ast.Str can just be returned as-is. + if not isinstance(spec, ast.JoinedStr): + assert isinstance(spec, ast.Str) + res += ':' + spec.s + elif len(spec.values) == 1 and isinstance(spec.values[0], ast.Str): + res += ':' + spec.values[0].s + else: + # For more complicated format strings, add the JoinedStr to the + # keyword arguments list. + kwarg = 'x{:x}'.format(len(keywords)) + keywords.append(ast.keyword(arg=kwarg, value=spec)) + res += ':{' + kwarg + '}' + + return res + '}' + def visit_FormattedValue(self, node: ast.FormattedValue) -> ast.Call: self._tree_changed = True - if node.format_spec: - template = ''.join(['{:', node.format_spec.s, '}']) # type: ignore - else: - template = '{}' + keywords = [] # type: List[ast.keyword] + template = self._parse_formatted_value(node, keywords) format_call = ast.Call(func=ast.Attribute(value=ast.Str(s=template), attr='format'), - args=[node.value], - keywords=[]) + args=[node.value], keywords=keywords) return self.generic_visit(format_call) # type: ignore - def visit_JoinedStr(self, node: ast.JoinedStr) -> ast.Call: + def visit_JoinedStr(self, node: ast.JoinedStr) -> Union[ast.Call, ast.Str]: self._tree_changed = True - join_call = ast.Call(func=ast.Attribute(value=ast.Str(s=''), - attr='join'), - args=[ast.List(elts=node.values)], - keywords=[]) - return self.generic_visit(join_call) # type: ignore + fs = [] + args = [] + keywords = [] # type: List[ast.keyword] + for i in node.values: + if isinstance(i, ast.Str): + fs.append(i.s.replace('{', '{{').replace('}', '}}')) + elif isinstance(i, ast.FormattedValue): + fs.append(self._parse_formatted_value(i, keywords)) + args.append(i.value) + else: + raise TypeError(i) + + value = ast.Str(s=''.join(fs)) + if not args: + return value + + format_call = ast.Call(func=ast.Attribute(value=value, + attr='format'), + args=args, keywords=keywords) + return self.generic_visit(format_call) # type: ignore diff --git a/py_backwards/transformers/functions_annotations.py b/py_backwards/transformers/functions_annotations.py index d13d679..9053e39 100644 --- a/py_backwards/transformers/functions_annotations.py +++ b/py_backwards/transformers/functions_annotations.py @@ -1,4 +1,4 @@ -from typed_ast import ast3 as ast +from .. import ast from .base import BaseNodeTransformer @@ -9,7 +9,7 @@ def fn(x: int) -> int: To: def fn(x): pass - + """ target = (2, 7) diff --git a/py_backwards/transformers/import_dbm.py b/py_backwards/transformers/import_dbm.py index b9b962a..921da80 100644 --- a/py_backwards/transformers/import_dbm.py +++ b/py_backwards/transformers/import_dbm.py @@ -1,5 +1,5 @@ from typing import Union -from typed_ast import ast3 as ast +from .base import ast from ..utils.snippet import snippet, extend from .base import BaseImportRewrite @@ -14,7 +14,7 @@ def import_rewrite(previous, current): class ImportDbmTransformer(BaseImportRewrite): """Replaces: - + dbm => anydbm dbm.ndbm => dbm diff --git a/py_backwards/transformers/kwargs.py b/py_backwards/transformers/kwargs.py new file mode 100644 index 0000000..dc0332e --- /dev/null +++ b/py_backwards/transformers/kwargs.py @@ -0,0 +1,44 @@ +from .. import ast +from ..utils.snippet import snippet +from .base import BaseNodeTransformer +from typing import Optional + +@snippet +def _kwarg_lambda(): + lambda **k : k + +class KwArgTransformer(BaseNodeTransformer): + """Compiles: + args = (1, 2, 3, 4) + test(*args, a=1) + test2(**kwargs) + To + args = (1, 2, 3, 4) + test(*args, **{b'a': 1}) + test2(**kwargs) + + """ + target = (2, 5) + + def visit_Call(self, node: ast.Call) -> Optional[ast.AST]: + # This uses bytes literals because in Python 2.5 those are converted to + # strings. + if node.args and isinstance(node.args[-1], ast.Starred) \ + and node.keywords: + self._tree_changed = True + if any(True for k in node.keywords if k.arg is None): + if len(node.keywords) == 1: + return self.generic_visit(node) + + # TODO: Something less hacky. + func = _kwarg_lambda.get_body()[0] + d = ast.Call(func=func, args=[], keywords=list(node.keywords)) # type: ast.AST + else: + d = ast.Dict(keys=[ast.Bytes(s=k.arg.encode('utf-8')) for k in + node.keywords if k.arg], + values=[k.value for k in node.keywords]) + + node.keywords.clear() + node.keywords.append(ast.keyword(arg=None, value=d)) + + return self.generic_visit(node) diff --git a/py_backwards/transformers/kwonlyargs.py b/py_backwards/transformers/kwonlyargs.py new file mode 100644 index 0000000..bd5bba6 --- /dev/null +++ b/py_backwards/transformers/kwonlyargs.py @@ -0,0 +1,84 @@ +import sys +from ..utils.snippet import snippet +from ..utils.tree import insert_at +from .. import ast +from .base import BaseNodeTransformer + +# "for arg in kwargs:" is a hack. +@snippet +def _sanity_check(func): + for arg in _py_backwards_kwargs: + raise TypeError(func + '() got an unexpected keyword argument ' + + repr(arg)) + +class KwOnlyArgTransformer(BaseNodeTransformer): + """Compiles: + def test(a, b=1, *, c, d=2, **kwargs): + pass + + test2 = lambda a, b=1, *, c=2, d=3 : d + + def test3(a, b=1, *, c, d=2): + pass + To + def test(a, b=1, **kwargs): + c = kwargs.pop('c') + d = kwargs.pop('d', 2) + pass + + test2 = lambda a, b=1, c=2, d=3 : d + + def test3(a, b=1, **_py_backwards_kwargs): + c = _py_backwards_kwargs.pop('c') + d = _py_backwards_kwargs.pop('d', 2) + for arg in _py_backwards_kwargs: + raise TypeError('test3() got an unexpected keyword argument ' + + repr(arg)) + + """ + target = (2, 7) + + def visit_FunctionDef(self, node: ast.FunctionDef) \ + -> ast.FunctionDef: + if node.args.kwonlyargs: + self._tree_changed = True + docstring = None + if isinstance(node.body[0], ast.Expr) and \ + isinstance(node.body[0].value, ast.Str): + docstring = node.body.pop(0) + + if node.args.kwarg: + kwarg = ast.Name(id=node.args.kwarg.arg, ctx=ast.Load()) + else: + kwarg = ast.Name(id='_py_backwards_kwargs', ctx=ast.Load()) + node.args.kwarg = ast.arg(arg='_py_backwards_kwargs', + annotation=None) + insert_at(0, node, + _sanity_check.get_body(func=ast.Str(s=node.name))) + + for i, arg in enumerate(node.args.kwonlyargs): + args = [ast.Str(arg.arg)] # type: list + if node.args.kw_defaults[i]: + args.append(node.args.kw_defaults[i]) + insert_at(i, node, ast.Assign(targets=[ast.Name(id=arg.arg)], + value=ast.Call(func=ast.Attribute(value=kwarg, + attr='pop'), args=args, keywords=[]))) + + node.args.kwonlyargs.clear() + node.args.kw_defaults.clear() + + if docstring: + node.body.insert(0, docstring) + + return self.generic_visit(node) # type: ignore + + # Just make all paramters positional in lambdas + def visit_Lambda(self, node: ast.Lambda) -> ast.Lambda: + if not node.args.vararg and node.args.kwonlyargs: + self._tree_changed = True + node.args.args.extend(node.args.kwonlyargs) + node.args.kwonlyargs.clear() + node.args.defaults.extend(node.args.kw_defaults) + node.args.kw_defaults.clear() + + return self.generic_visit(node) # type: ignore diff --git a/py_backwards/transformers/matrix_multiplication.py b/py_backwards/transformers/matrix_multiplication.py new file mode 100644 index 0000000..9d1e2db --- /dev/null +++ b/py_backwards/transformers/matrix_multiplication.py @@ -0,0 +1,93 @@ +import sys +from typing import Union +from ..utils.tree import insert_at +from ..utils.snippet import snippet +from .. import ast +from .base import BaseNodeTransformer + +if sys.version_info >= (3, 4): + MatMult, NameConstant = ast.MatMult, ast.NameConstant +else: + MatMult = None + def NameConstant(obj): + return ast.Name(repr(obj)) + +@snippet +def _matmul(): + def _py_backwards_matmul(left, right): + """ Same as a @ b. """ + + res = NotImplemented + lt = type(left) + rt = type(right) + + if hasattr(lt, '__matmul__'): + res = lt.__matmul__(left, right) + if res is NotImplemented and hasattr(rt, '__rmatmul__'): + res = rt.__rmatmul__(right, left) + + if res is NotImplemented: + raise TypeError('unsupported operand type(s) for @: ' + + repr(lt.__name__) + ' and ' + repr(rt.__name__)) + return res + + def _py_backwards_imatmul(left, right): + """ Same as a @= b. """ + lt = type(left) + if hasattr(lt, '__imatmul__'): + res = lt.__imatmul__(left, right) + if res is not NotImplemented: + return res + try: + return _py_backwards_matmul(left, right) + except TypeError as e: + msg = str(e) + raise TypeError(msg.replace('@', '@=', 1)) + + # Use the existing matmul implementation if available. This should also + # delete duplicate _py_backwards_matmul-s across multiple modules. + let(op) + import operator as op + try: + _py_backwards_matmul = op.matmul + _py_backwards_imatmul = op.imatmul + except AttributeError: + op.matmul = _py_backwards_matmul + op.imatmul = _py_backwards_imatmul + del op + +class MatMultTransformer(BaseNodeTransformer): + """Compiles: + print(a @ b) + a @= b + To + print(_py_backwards_matmul(a, b)) + a = _py_backwards_imatmul(a, b) + + """ + target = (3, 4) + + def visit_Module(self, node: ast.Module) -> ast.Module: + insert_at(0, node, _matmul.get_body()) + return self.generic_visit(node) # type: ignore + + def visit_BinOp(self, node: ast.BinOp) -> Union[ast.BinOp, ast.Call]: + if not isinstance(node.op, MatMult): + return self.generic_visit(node) # type: ignore + + self._tree_changed = True + call = ast.Call(func=ast.Name(id='_py_backwards_matmul'), + args=[node.left, node.right], keywords=[]) + return self.generic_visit(call) # type: ignore + + def visit_AugAssign(self, node: ast.AugAssign) \ + -> Union[ast.AugAssign, ast.Assign]: + if not isinstance(node.op, MatMult): + return self.generic_visit(node) # type: ignore + + self._tree_changed = True + call = ast.Call(func=ast.Name(id='_py_backwards_imatmul'), + args=[node.target, node.value], keywords=[]) + + assign = ast.Assign(targets=[node.target], value=call) + return self.generic_visit(assign) # type: ignore diff --git a/py_backwards/transformers/metaclass.py b/py_backwards/transformers/metaclass.py index ce80fbd..47cc6fb 100644 --- a/py_backwards/transformers/metaclass.py +++ b/py_backwards/transformers/metaclass.py @@ -1,6 +1,6 @@ -from typed_ast import ast3 as ast from ..utils.snippet import snippet from ..utils.tree import insert_at +from .. import ast from .base import BaseNodeTransformer @@ -19,8 +19,9 @@ class MetaclassTransformer(BaseNodeTransformer): class A(metaclass=B): pass To: - class A(_py_backwards_six_with_metaclass(B)) - + class A(_py_backwards_six_with_metaclass(B)): + pass + """ target = (2, 7) dependencies = ['six'] diff --git a/py_backwards/transformers/nonlocal_statement.py b/py_backwards/transformers/nonlocal_statement.py new file mode 100644 index 0000000..24e42a5 --- /dev/null +++ b/py_backwards/transformers/nonlocal_statement.py @@ -0,0 +1,186 @@ +import sys +from ..utils.helpers import VariablesGenerator +from ..utils.snippet import snippet +from ..utils.tree import get_node_position, find, insert_at +from .. import ast +from .base import BaseNodeTransformer +from typing import Callable, Dict, List, Optional, Set, Union + +class _ScopeTransformer(ast.NodeTransformer): + """ + Renames objects so they use scope dictionaries, this is called by + _RawTransformer.transform() after running its own transformations. + """ + + def __init__(self, scope: Dict[str, str], tree: ast.AST) -> None: + self.scope = scope + self.tree = tree + + def visit_Name(self, node: ast.Name) -> Union[ast.Name, ast.Subscript]: + if node.id not in self.scope: + return node + + value = ast.Name(id=self.scope[node.id], ctx=ast.Load()) + + # Python 2's unicode strings typically use either 2 or 4 bytes, use + # 8-bit strings (bytes in Python 3) if possible. + try: + s = ast.Bytes(s=node.id.encode('ascii')) # type: ast.AST + except UnicodeEncodeError: + s = ast.Str(s=node.id) + + subscript = ast.Subscript(value=value, slice=ast.Index(value=s), + ctx=getattr(node, 'ctx', None) or ast.Load()) + return subscript + + def _assign_later(self, node: ast.AST, name: str) -> str: + pos = get_node_position(self.tree, node) + temp_name = VariablesGenerator.generate('temp_' + name) + assign = ast.Assign(targets=[ast.Name(id=name, ctx=ast.Store())], + value=ast.Name(id=temp_name, ctx=ast.Load())) + pos.holder.insert(pos.index + 1, assign) + delete = ast.Delete(targets=[ast.Name(id=temp_name, + ctx=ast.Del())]) + pos.holder.insert(pos.index + 2, delete) + + return temp_name + + def visit_FunctionDef(self, node: Union[ast.FunctionDef, ast.ClassDef]) \ + -> Union[ast.FunctionDef, ast.ClassDef]: + if node.name in self.scope: + node.name = self._assign_later(node, node.name) + + # Don't call generic_visit() here. + return node + + visit_ClassDef = visit_FunctionDef + + # Used in import and import from statements. + def visit_alias(self, node: ast.alias) -> ast.alias: + name = node.asname or node.name + if name and name in self.scope: + node.asname = self._assign_later(node, name) + + return self.generic_visit(node) # type: ignore + +class _RawTransformer(ast.NodeTransformer): + def __init__(self, node: Union[ast.FunctionDef, ast.ClassDef], + callbacks: List[Callable], *, parent_name: Optional[str] = None, + parent_scope: Optional[Dict[str, str]] = None) -> None: + self.parent_name = parent_name + self.name = None # type: Optional[str] + self.node = node + self.is_class = isinstance(node, ast.ClassDef) # type: bool + self.tree_changed = False # type: bool + self._callbacks = callbacks + if parent_scope is None: + parent_scope = {} + self.parent_scope = parent_scope # type: Dict[str, str] + self.scope = {} # type: Dict[str, str] + + @classmethod + def transform(cls, node: Union[ast.FunctionDef, ast.ClassDef], + parent: Union[ast.FunctionDef, ast.ClassDef], + callbacks: List[Callable], *, parent_name: Optional[str] = None, + parent_scope: Optional[Dict[str, str]] = None) -> '_RawTransformer': + self = cls(node, callbacks, parent_name=parent_name, + parent_scope=parent_scope) + for n in node.body: + self.visit(n) + + if not self.tree_changed: + return self + + scopetransformer = _ScopeTransformer(self.scope, node) + for n in node.body: + scopetransformer.visit(n) + + if not self.parent_name or isinstance(parent, ast.ClassDef): + return self + + # Add the scope variable + if not parent_name: + name = ast.Name(id=self.parent_name, ctx=ast.Store()) + assign = ast.Assign(targets=[name], + value=ast.Dict(keys=[], values=[])) + + i = 0 + if parent.body and isinstance(parent.body[0], ast.Expr) and \ + isinstance(parent.body[0].value, ast.Str): + i = 1 + self._callbacks.append(lambda : parent.body.insert(i, assign)) + + return self + + def visit_FunctionDef(self, node: Union[ast.FunctionDef, ast.ClassDef]) \ + -> Union[ast.FunctionDef, ast.ClassDef]: + # Classes are different. + if self.is_class: + name, scope = self.parent_name, self.parent_scope + else: + name, scope = self.name, self.scope + + transformer = self.transform(node, self.node, self._callbacks, + parent_name=name, parent_scope=scope) + if transformer.tree_changed: + self.tree_changed = True + self.name = transformer.parent_name + + # Don't call generic_visit(), that would iterate over the nodes inside + # the function. + return node + + visit_ClassDef = visit_FunctionDef + + def visit_Nonlocal(self, node: ast.Nonlocal) -> ast.Nonlocal: + self.tree_changed = True + if self.parent_name is None: + self.parent_name = VariablesGenerator.generate('scope') + + for name in node.names: + if name in self.parent_scope: + scope = self.parent_scope[name] + else: + scope = self.parent_name + self.parent_scope[name] = scope + self.scope[name] = scope + + position = get_node_position(self.node, node) + self._callbacks.append(lambda : position.holder.remove(node)) + + return node + +class NonlocalStatementTransformer(BaseNodeTransformer): + """Compiles: + def outer(): + x = 1 + def inner(): + nonlocal x + x = 2 + inner() + print(x) + To + def outer(): + _py_backwards_scope_0 = {} + _py_backwards_scope_0['x'] = 1 + def inner(): + _py_backwards_scope_0['x'] = 2 + inner() + print(_py_backwards_scope_0['x']) + + """ + target = (2, 7) + + def visit_FunctionDef(self, node: ast.FunctionDef) -> ast.FunctionDef: + callbacks = [] # type: List[Callable] + transformer = _RawTransformer.transform(node, node, callbacks) + if transformer.tree_changed: + self._tree_changed = True + for callback in reversed(callbacks): + callback() + return node # type: ignore + + def visit_Nonlocal(self, node: ast.Nonlocal): + exc = SyntaxError('nonlocal outside function') + exc.ast_node = node # type: ignore + raise exc diff --git a/py_backwards/transformers/posonlyargs.py b/py_backwards/transformers/posonlyargs.py new file mode 100644 index 0000000..45efd64 --- /dev/null +++ b/py_backwards/transformers/posonlyargs.py @@ -0,0 +1,71 @@ +import sys +from ..const import TARGET_ALL +from ..utils.helpers import VariablesGenerator +from ..utils.tree import insert_at +from .. import ast +from .base import BaseNodeTransformer + +# Caution: Since this is a Python 3.8+ transformer, posonlyargs won't exist on +# Python 3.7 and below, and astunparse can't unparse positional-only +# arguments. +PY38 = sys.version_info >= (3, 8) + +class PosOnlyArgTransformer(BaseNodeTransformer): + """Compiles: + def test(a, /, b, *, c): + pass + + def test2(a, /, b, *, c, **kwargs): + pass + To + def test(a, b, *, c): + pass + + def test2(, b, *, c, **kwargs): + a = + del + + """ + target = TARGET_ALL + + def visit_FunctionDef(self, node: ast.FunctionDef) -> ast.FunctionDef: + if PY38 and getattr(node.args, 'posonlyargs', None): + self._tree_changed = True + args = node.args.posonlyargs # type: ignore + if node.args.kwarg: + docstring = None + if isinstance(node.body[0], ast.Expr) and \ + isinstance(node.body[0].value, ast.Str): + docstring = node.body.pop(0) + + # Generate mangled names with VariablesGenerator to ensure they + # don't conflict with anything. + for i, arg in enumerate(args): + name = arg.arg + arg.arg = VariablesGenerator.generate('\u036f' + name) + insert_at(i, node, ast.Assign(targets=[ast.Name(id=name)], + value=ast.Name(id=arg.arg, ctx=ast.Load()))) + + del_node = ast.Delete(targets=[ast.Name(id=arg.arg, + ctx=ast.Del()) + for arg in args]) + insert_at(i + 1, node, del_node) + + if docstring: + node.body.insert(0, docstring) + + args.extend(node.args.args) + node.args.args = args + node.args.posonlyargs = [] # type: ignore + + return self.generic_visit(node) # type: ignore + + def visit_Lambda(self, node: ast.Lambda) -> ast.Lambda: + if PY38 and getattr(node.args, 'posonlyargs', None): + self._tree_changed = True + args = node.args.posonlyargs # type: ignore + args.extend(node.args.args) + node.args.args = args + node.args.posonlyargs = [] # type: ignore + + return self.generic_visit(node) # type: ignore diff --git a/py_backwards/transformers/print_function.py b/py_backwards/transformers/print_function.py new file mode 100644 index 0000000..f001214 --- /dev/null +++ b/py_backwards/transformers/print_function.py @@ -0,0 +1,29 @@ +from ..utils.snippet import snippet +from ..utils.tree import insert_at +from .. import ast +from .base import BaseNodeTransformer + +@snippet +def _import(): + from six import print_ as _py_backwards_print + +class PrintFunctionTransformer(BaseNodeTransformer): + """Compiles: + print('Hello world', end='!\n') + To + _py_backwards_print('Hello world', end='!\n') + + """ + target = (2, 5) + dependencies = ['six'] + + def visit_Module(self, node: ast.Module) -> ast.Module: + insert_at(0, node, _import.get_body()) + return self.generic_visit(node) # type: ignore + + def visit_Name(self, node: ast.Name) -> ast.Name: + if node.id == 'print': + self._tree_changed = True + node.id = '_py_backwards_print' + + return self.generic_visit(node) # type: ignore diff --git a/py_backwards/transformers/python2_future.py b/py_backwards/transformers/python2_future.py index 529ab65..698d507 100644 --- a/py_backwards/transformers/python2_future.py +++ b/py_backwards/transformers/python2_future.py @@ -1,7 +1,7 @@ -from typed_ast import ast3 as ast from ..utils.snippet import snippet +from .. import ast from .base import BaseNodeTransformer - +from .kwonlyargs import KwOnlyArgTransformer @snippet def imports(future): @@ -10,6 +10,20 @@ def imports(future): from future import print_function from future import unicode_literals + try: + input, range, str, bytes, chr = raw_input, xrange, unicode, str, unichr + except NameError: + pass + else: + from itertools import ifilter as filter, imap as map, izip as zip + + let(i) + import itertools as i + i.filterfalse, i.zip_longest = i.ifilterfalse, i.izip_longest + del i + +def _check_name(node, name): + return isinstance(node, ast.Name) and node.id == name class Python2FutureTransformer(BaseNodeTransformer): """Prepends module with: @@ -17,7 +31,12 @@ class Python2FutureTransformer(BaseNodeTransformer): from __future__ import division from __future__ import print_function from __future__ import unicode_literals - + + Compiles: + isinstance(obj, int) + To + isinstance(obj, (int, long)) + """ target = (2, 7) @@ -25,3 +44,74 @@ def visit_Module(self, node: ast.Module) -> ast.Module: self._tree_changed = True node.body = imports.get_body(future='__future__') + node.body # type: ignore return self.generic_visit(node) # type: ignore + + def visit_Call(self, node: ast.Call) -> ast.Call: + if _check_name(node.func, 'isinstance') and len(node.args) == 2 and \ + _check_name(node.args[1], 'int'): + self._tree_changed = True + node.args[1] = ast.Tuple([ast.Name(id='int'), + ast.Name(id='long')]) + + return self.generic_visit(node) # type: ignore + +@snippet +def py25_imports(future, itertools_): + from future import absolute_import + from future import division + from future import with_statement + from six import advance_iterator as next + + let(itertools) + import itertools_ as itertools + # Based off of + # https://docs.python.org/3/library/itertools.html#itertools.zip_longest + if hasattr(itertools, 'izip_longest'): + del itertools + else: + def zip_longest(*args, fillvalue=None): + if not args: + return + iterators = [iter(i) for i in args] + active = len(iterators) + while True: + values = [] + for idx, it in enumerate(iterators): + try: + values.append(next(it)) + except StopIteration: + active -= 1 + if not active: return + iterators[idx] = itertools.repeat(fillvalue) + values.append(fillvalue) + yield tuple(values) + + itertools.izip_longest = zip_longest + del zip_longest + +class Python25FutureTransformer(BaseNodeTransformer): + """Prepends module with: + from __future__ import absolute_import + from __future__ import division + from __future__ import with_statement + from six import advance_iterator as next + And removes __future__ imports added by Python2FutureTransformer. + + """ + target = (2, 5) + dependencies = ['six'] + + def visit_Module(self, node: ast.Module) -> ast.Module: + self._tree_changed = True + while node.body and isinstance(node.body[0], ast.ImportFrom) and \ + node.body[0].module == '__future__': + del node.body[0] + + # Make a Module object to pass to KwOnlyArgTransformer + body = py25_imports.get_body(future='__future__', + itertools_='itertools') + tree = ast.Module(body=body) + tree = KwOnlyArgTransformer(tree).visit(tree) + + # Add it to the module + node.body = tree.body + node.body + return self.generic_visit(node) # type: ignore diff --git a/py_backwards/transformers/raise_from.py b/py_backwards/transformers/raise_from.py new file mode 100644 index 0000000..2a7be77 --- /dev/null +++ b/py_backwards/transformers/raise_from.py @@ -0,0 +1,18 @@ +from .. import ast +from .base import BaseNodeTransformer + +class RaiseFromTransformer(BaseNodeTransformer): + """Compiles: + raise TypeError('Bad') from exc + To + raise TypeError('Bad') + + """ + target = (2, 7) + + def visit_Raise(self, node: ast.Raise) -> ast.Raise: + if node.cause: + self._tree_changed = True + node.cause = None + + return self.generic_visit(node) # type: ignore diff --git a/py_backwards/transformers/return_from_generator.py b/py_backwards/transformers/return_from_generator.py index 031b514..a6f639b 100644 --- a/py_backwards/transformers/return_from_generator.py +++ b/py_backwards/transformers/return_from_generator.py @@ -1,5 +1,5 @@ from typing import List, Tuple, Any -from typed_ast import ast3 as ast +from .. import ast from ..utils.snippet import snippet, let from .base import BaseNodeTransformer @@ -40,7 +40,7 @@ def _find_generator_returns(self, node: ast.FunctionDef) \ elif hasattr(current, 'value'): to_check.append((current, current.value)) # type: ignore elif hasattr(current, 'body') and isinstance(current.body, list): # type: ignore - to_check.extend([(parent, x) for x in current.body]) # type: ignore + to_check.extend((current, x) for x in current.body) # type: ignore if isinstance(current, ast.Yield) or isinstance(current, ast.YieldFrom): has_yield = True @@ -58,7 +58,9 @@ def _replace_return(self, parent: Any, return_: ast.Return) -> None: index = parent.body.index(return_) parent.body.pop(index) - for line in return_from_generator.get_body(return_value=return_.value)[::-1]: + value = return_.value + assert value + for line in return_from_generator.get_body(return_value=value)[::-1]: parent.body.insert(index, line) def visit_FunctionDef(self, node: ast.FunctionDef) -> ast.FunctionDef: diff --git a/py_backwards/transformers/set_literals.py b/py_backwards/transformers/set_literals.py new file mode 100644 index 0000000..c0d4f90 --- /dev/null +++ b/py_backwards/transformers/set_literals.py @@ -0,0 +1,46 @@ +from .. import ast +from .base import BaseNodeTransformer + + +class SetLiteralTransformer(BaseNodeTransformer): + """Compiles: + x = {1, 2, 3, 4} + y = {i * 10 for i in range(10)} + z = frozenset({1, 2, 3, 4}) + print({1, 2, 3, 4}) + To + x = set((1, 2, 3, 4)) + y = set(i * 10 for i in range(10)) + z = frozenset((1, 2, 3, 4)) + print(set((1, 2, 3, 4))) + + """ + target = (2, 6) + + def visit_Set(self, node: ast.Set) -> ast.Call: + self._tree_changed = True + + set_call = ast.Call(func=ast.Name(id='set', ctx=ast.Load()), + args=[ast.Tuple(elts=node.elts, ctx=ast.Load())], + keywords=[]) + + return self.generic_visit(set_call) # type: ignore + + def visit_SetComp(self, node: ast.SetComp) -> ast.Call: + self._tree_changed = True + + set_call = ast.Call(func=ast.Name(id='set', ctx=ast.Load()), + args=[ast.GeneratorExp(elt=node.elt, generators=node.generators)], + keywords=[]) + + return self.generic_visit(set_call) # type: ignore + + # This is not strictly required, however prevents frozenset({1, 2, 3}) from + # calling set(). + def visit_Call(self, node: ast.Call) -> ast.Call: + if isinstance(node.func, ast.Name) and node.func.id == 'frozenset' \ + and node.args and isinstance(node.args[0], ast.Set): + self._tree_changed = True + node.args[0] = ast.Tuple(elts=node.args[0].elts) + + return self.generic_visit(node) # type: ignore diff --git a/py_backwards/transformers/six_moves.py b/py_backwards/transformers/six_moves.py index f50a8c8..3553ec6 100644 --- a/py_backwards/transformers/six_moves.py +++ b/py_backwards/transformers/six_moves.py @@ -1,4 +1,3 @@ -# type: ignore from ..utils.helpers import eager from .base import BaseImportRewrite diff --git a/py_backwards/transformers/starred_unpacking.py b/py_backwards/transformers/starred_unpacking.py index beda1ea..e60dbdc 100644 --- a/py_backwards/transformers/starred_unpacking.py +++ b/py_backwards/transformers/starred_unpacking.py @@ -1,5 +1,5 @@ from typing import Union, Iterable, List -from typed_ast import ast3 as ast +from .. import ast from .base import BaseNodeTransformer Splitted = Union[List[ast.expr], ast.Starred] @@ -13,7 +13,7 @@ class StarredUnpackingTransformer(BaseNodeTransformer): To: [2] + list(range(10)) + [1] print(*(list(range(1)) + list(range(3)))) - + """ target = (3, 4) diff --git a/py_backwards/transformers/string_types.py b/py_backwards/transformers/string_types.py deleted file mode 100644 index 405c2b5..0000000 --- a/py_backwards/transformers/string_types.py +++ /dev/null @@ -1,22 +0,0 @@ -from typed_ast import ast3 as ast -from ..utils.tree import find -from ..types import TransformationResult -from .base import BaseTransformer - - -class StringTypesTransformer(BaseTransformer): - """Replaces `str` with `unicode`. - - """ - target = (2, 7) - - @classmethod - def transform(cls, tree: ast.AST) -> TransformationResult: - tree_changed = False - - for node in find(tree, ast.Name): - if node.id == 'str': - node.id = 'unicode' - tree_changed = True - - return TransformationResult(tree, tree_changed, []) diff --git a/py_backwards/transformers/super_without_arguments.py b/py_backwards/transformers/super_without_arguments.py index 13f0051..201ff42 100644 --- a/py_backwards/transformers/super_without_arguments.py +++ b/py_backwards/transformers/super_without_arguments.py @@ -1,4 +1,4 @@ -from typed_ast import ast3 as ast +from .. import ast from ..utils.tree import get_closest_parent_of from ..utils.helpers import warn from ..exceptions import NodeNotFound @@ -11,7 +11,7 @@ class SuperWithoutArgumentsTransformer(BaseNodeTransformer): To: super(Cls, self) super(Cls, cls) - + """ target = (2, 7) @@ -31,7 +31,8 @@ def _replace_super_args(self, node: ast.Call) -> None: node.args = [ast.Name(id=cls.name), ast.Name(id=func.args.args[0].arg)] def visit_Call(self, node: ast.Call) -> ast.Call: - if isinstance(node.func, ast.Name) and node.func.id == 'super' and not len(node.args): + if isinstance(node.func, ast.Name) and node.func.id == 'super' and \ + not node.args: self._replace_super_args(node) self._tree_changed = True diff --git a/py_backwards/transformers/unicode_identifiers.py b/py_backwards/transformers/unicode_identifiers.py new file mode 100644 index 0000000..6ba57fd --- /dev/null +++ b/py_backwards/transformers/unicode_identifiers.py @@ -0,0 +1,130 @@ +import re +import unicodedata +from .. import ast +from ..utils.snippet import snippet +from ..utils.tree import insert_at +from .base import BaseNodeTransformer + +invalid_identifier = re.compile('[^A-Za-z0-9_\.]') + +# def mangle(name: str) -> str: +# """ +# Mangles variable names using Punycode. +# Examples: +# testæ → py_backwards_mangled_test_wra +# _testœ → _py_backwards_mangled__test_lbb +# __ætest → __py_backwards_mangled___test_qua +# """ +# underscores = '_' * min(len(name) - len(name.lstrip('_')), 2) +# name = invalid_identifier.sub('_', name.encode('punycode').decode('ascii')) +# return '{}py_backwards_mangled_{}'.format(underscores, name) + +def _match(match) -> str: + char = match.group(0) + name = unicodedata.name(char, '').lower().replace('-', 'H') + if not name: + name = 'U{:x}'.format(ord(char)) + return 'X' + invalid_identifier.sub('_', name) + 'X' + +_mangle_re = re.compile('[^A-WYZa-z0-9_]') +def mangle(raw_name: str) -> str: + """ + Mangles variable names in the same way Hy does. + https://docs.hylang.org/en/stable/language/syntax.html#mangling + """ + + # Handle names with '.'. + if '.' in raw_name: + res = [] + for name in raw_name.split('.'): + if invalid_identifier.search(name): + res.append(mangle(name)) + else: + res.append(name) + return '.'.join(res) + + name = raw_name.lstrip('_') + underscores = '_' * (len(raw_name) - len(name)) + return underscores + 'hyx_' + _mangle_re.sub(_match, name) + +class UnicodeIdentifierTransformer(BaseNodeTransformer): + """Compiles: + a = 1 + æ = 2 + __œ = 3 + os.œ = 4 + To + a = 1 + hyx_Xlatin_small_letter_aeX = 2 + __hyx_Xlatin_small_ligature_oeX = 3 + os.hyx_Xlatin_small_ligature_oeX = 4 + """ + # Old mangler output: + # py_backwards_mangled_6ca = 2 + # __py_backwards_mangled____fsa = 3 + # os._py_backwards_mangled_bga = 4 + target = (2, 7) + + def visit_Name(self, node: ast.Name) -> ast.Name: + if invalid_identifier.search(node.id): + self._tree_changed = True + node.id = mangle(node.id) + + return self.generic_visit(node) # type: ignore + + def visit_Attribute(self, node: ast.Attribute) -> ast.Attribute: + if invalid_identifier.search(node.attr): + self._tree_changed = True + node.attr = mangle(node.attr) + + return self.generic_visit(node) # type: ignore + + def visit_arg(self, node: ast.arg) -> ast.arg: + if node.arg is not None and invalid_identifier.search(node.arg): + self._tree_changed = True + node.arg = mangle(node.arg) + + return self.generic_visit(node) # type: ignore + + def visit_keyword(self, node: ast.arg) -> ast.arg: + if node.arg is not None and invalid_identifier.search(node.arg): + self._tree_changed = True + node.arg = mangle(node.arg) + + return self.generic_visit(node) # type: ignore + + def visit_FunctionDef(self, node: ast.FunctionDef) -> ast.FunctionDef: + if invalid_identifier.search(node.name): + self._tree_changed = True + node.name = mangle(node.name) + + return self.generic_visit(node) # type: ignore + + def visit_ClassDef(self, node: ast.ClassDef) -> ast.ClassDef: + if invalid_identifier.search(node.name): + self._tree_changed = True + node.name = mangle(node.name) + + return self.generic_visit(node) # type: ignore + + # Used in "from ... import ... [as ...]". + def visit_alias(self, node: ast.alias) -> ast.alias: + if invalid_identifier.search(node.name): + self._tree_changed = True + node.name = mangle(node.name) + # getattr(node, 'asname', None) works as well, however mypy complains. + if hasattr(node, 'asname') and node.asname and \ + invalid_identifier.search(node.asname): + self._tree_changed = True + node.asname = mangle(node.asname) + + return self.generic_visit(node) # type: ignore + + # Mangle Unicode names in "except as" statements + def visit_Try(self, node: ast.Try) -> ast.Try: + for handler in node.handlers: + if handler.name and invalid_identifier.search(handler.name): + self._tree_changed = True + handler.name = mangle(handler.name) + + return self.generic_visit(node) # type: ignore diff --git a/py_backwards/transformers/variables_annotations.py b/py_backwards/transformers/variables_annotations.py index ecf3eb5..7d1a215 100644 --- a/py_backwards/transformers/variables_annotations.py +++ b/py_backwards/transformers/variables_annotations.py @@ -1,4 +1,4 @@ -from typed_ast import ast3 as ast +from .. import ast from ..utils.tree import find, get_node_position, insert_at from ..utils.helpers import warn from ..types import TransformationResult diff --git a/py_backwards/transformers/walrus_operator.py b/py_backwards/transformers/walrus_operator.py new file mode 100644 index 0000000..b61a9a9 --- /dev/null +++ b/py_backwards/transformers/walrus_operator.py @@ -0,0 +1,320 @@ +import functools +import sys +from ..const import TARGET_ALL +from ..exceptions import NodeNotFound +from ..types import TransformationResult +from ..utils.helpers import VariablesGenerator, warn +from ..utils.tree import find, get_node_position, get_parent, insert_at +from ..utils.snippet import snippet +from .. import ast +from .base import BaseNodeTransformer +from typing import Optional + +PY38 = sys.version_info >= (3, 8) + +if sys.version_info < (3, 8): + class NamedExpr(ast.AST): + pass + + class Constant(ast.AST): + pass +else: + NamedExpr = ast.NamedExpr + Constant = ast.Constant + +# The standard walrus operator transformer, this one can only transform more +# basic usage of walrus operators in certain if and while statements. +class WalrusTransformer(BaseNodeTransformer): + """Compiles: + if (x := 1 // 2): + print(0) + elif (x := 5) and x > 2: + print(x) + else: + print(2) + + while buf := sock.recv(4096): + print(buf) + To + x = 1 // 2 + if x: + print(0) + else: + x = 5 + if x > 2: + print(1) + else: + print(2) + + while True: + buf = sock.recv(4096) + if not buf: + break + print(buf) + + """ + # Although the walrus operator gets patched into astunparse, autopep8 + # doesn't (yet) know how to handle walrus operators correctly, so this + # has to TARGET_ALL. + target = TARGET_ALL + + def _get_walruses(self, nodes): + """ + Recursively search for walruses that are most likely safe to be moved + outside the current statement. + """ + if not isinstance(nodes, (tuple, list, map)): + nodes = (nodes,) + + for node in nodes: + if isinstance(node, NamedExpr): + yield node + + if isinstance(node, ast.Compare): + yield from self._get_walruses(node.left) + yield from self._get_walruses(node.comparators) + elif isinstance(node, ast.BoolOp): + yield from self._get_walruses(node.values[0]) + elif isinstance(node, ast.UnaryOp): + yield from self._get_walruses(node.operand) + elif isinstance(node, ast.Call): + yield from self._get_walruses(node.args) + yield from self._get_walruses(map(lambda arg : arg.value, + node.keywords)) + + def _has_walrus(self, nodes) -> bool: + """ + Returns True if self._get_walruses(nodes) is not empty, otherwise + False. + """ + try: + next(iter(self._get_walruses(nodes))) + return True + except StopIteration: + return False + + def _invert_expr(self, node: ast.AST) -> ast.AST: + """ + Prepends an AST expression with 'not' or removes an existing 'not'. + """ + if isinstance(node, ast.UnaryOp) and isinstance(node.op, ast.Not): + return node.operand + + return ast.UnaryOp(op=ast.Not(), operand=node) + + def visit_While(self, node: ast.While) -> ast.While: + """ + Compiles: + while data := sock.recv(8192): + print(data) + To + while True: + if not (data := sock.recv(8192)): + break + print(data) + """ + + # If the condition contains a walrus operator, move the test into an + # if statement and let the if handler in transform() deal with it. + if not node.orelse and self._has_walrus(node.test): + self._tree_changed = True + # Remove redundant not statements. + n = self._invert_expr(node.test) + node.body.insert(0, ast.If(test=n, body=[ast.Break()], orelse=[])) + node.test = ast.NameConstant(value=True) + + return self.generic_visit(node) # type: ignore + + def _has_walrus_any(self, node) -> bool: + """ + Checks if any walrus operators are in node without any sanity checks. + """ + try: + next(iter(find(node, NamedExpr))) + return True + except StopIteration: + return False + + def visit_If(self, node: ast.If) -> Optional[ast.AST]: + """ + Compiles: + if test1 and (test2 := do_something()): + pass + + if test1 and test2: + pass + To + if test1: + if test2 := do_something(): + pass + + if test1 and test2: + pass + """ + if node.orelse or not isinstance(node.test, ast.BoolOp) or \ + not isinstance(node.test.op, ast.And): + return self.generic_visit(node) + + # Split and-s into multiple if statements if they contain walruses. + for i, value in enumerate(node.test.values): + if not i or not self._has_walrus_any(value): + continue + + # Split the if statement + self._tree_changed = True + + new_values = node.test.values[i:] + if i > 1: + node.test.values = node.test.values[:i] + else: + node.test = node.test.values[0] + + if len(new_values) > 1: + test = ast.BoolOp(op=ast.And(), values=new_values) # type: ast.AST + else: + test = new_values[0] + + new_if = ast.If(test=test, body=node.body, orelse=[]) + node.body = [new_if] + + break + + return self.generic_visit(node) + + # This fixes standalone walrus operators (that shouldn't exist in the first + # place). + def visit_Expr(self, node: ast.Expr) -> Optional[ast.AST]: + """ + Compiles: + (a := 1) + To + a = 1 + """ + if isinstance(node.value, NamedExpr): + self._tree_changed = True + new_node = ast.Assign(targets=[node.value.target], + value=node.value.value) + return self.generic_visit(new_node) + + return self.generic_visit(node) + + def _replace_walruses(self, test: ast.AST): + """ + Replaces walrus operators in the current if statement and yields Assign + expressions to add before the if statement. + """ + for walrus in self._get_walruses(test): + target = walrus.target + if isinstance(target, ast.Name): + target = ast.Name(id=target.id, ctx=ast.Load()) + parent = get_parent(self._tree, walrus) + + if isinstance(parent, ast.keyword): + parent = get_parent(self._tree, parent) + + if isinstance(parent, ast.Compare): + if parent.left is walrus: + parent.left = target + else: + comps = parent.comparators + comps[comps.index(walrus)] = target + elif isinstance(parent, ast.BoolOp): + parent.values[0] = target + elif isinstance(parent, ast.UnaryOp): + parent.operand = target + elif isinstance(parent, ast.If): + parent.test = target + elif isinstance(parent, ast.Call): + if walrus in parent.args: + parent.args[parent.args.index(walrus)] = walrus.target + else: + for kw in parent.keywords: + if kw.value is walrus: + kw.value = target + break + else: + raise AssertionError('Failed to find walrus in Call.') + else: + raise NotImplementedError(parent) + + yield ast.Assign(targets=[walrus.target], value=walrus.value) + + + @classmethod + def transform(cls, tree: ast.AST) -> TransformationResult: + self = cls(tree) + self.visit(tree) + + # Do if statement transformations here so values can be set outside of + # the statement, if this is done in visit_If weird things happen. + for node in find(tree, ast.If): + try: + position = get_node_position(tree, node) + except (NodeNotFound, ValueError): + warn('If statement outside of body') + continue + + for i, assign in enumerate(self._replace_walruses(node.test)): + self._tree_changed = True + position.holder.insert(position.index + i, assign) + + return TransformationResult(tree, self._tree_changed, []) + +# A CPython-only fallback. This uses an undocumented feature. +@snippet +def walrus_snippet(ctypes_): + let(ctypes) + let(getframe) + import ctypes_ as ctypes + from sys import _getframe as getframe + + def _py_backwards_walrus(name, value): + frame = getframe(1) + frame.f_locals[name] = value + ctypes.pythonapi.PyFrame_LocalsToFast(ctypes.py_object(frame), + ctypes.c_int(0)) + del frame + return value + +# The fallback walrus operator, this can handle more walrus operators, +# however only works on CPython and if the variable used has been defined +# in the same scope. +class FallbackWalrusTransformer(BaseNodeTransformer): + """Compiles: + def test(e): + l = None + if (l := len(e)) > 50: + raise TypeError(f'Object too long ({l} characters).') + To + def test(e): + l = None + if _py_backwards_walrus('l', len(e)) > 50: + raise TypeError(f'Object too long ({l} characters).') + + """ + target = TARGET_ALL + + # Convert standalone NamedExprs + def visit_NamedExpr(self, node: NamedExpr) -> ast.Call: + if not self._tree_changed: + self._tree_changed = True + warn('The fallback named expression transformer has been used, ' + 'the resulting code will only work in CPython (if at all).') + + target = node.target + if not isinstance(target, ast.Name): + raise NotImplementedError + + call = ast.Call(func=ast.Name(id='_py_backwards_walrus', + ctx=ast.Store()), + args=[Constant(value=target.id), node.value], + keywords=[]) + + return self.generic_visit(call) # type: ignore + + @classmethod + def transform(cls, tree: ast.AST) -> TransformationResult: + res = super().transform(tree) + if res.tree_changed and hasattr(tree, 'body'): + insert_at(0, tree, walrus_snippet.get_body(ctypes_='ctypes')) + return res diff --git a/py_backwards/transformers/yield_from.py b/py_backwards/transformers/yield_from.py index e5de32b..c28a59d 100644 --- a/py_backwards/transformers/yield_from.py +++ b/py_backwards/transformers/yield_from.py @@ -1,8 +1,8 @@ from typing import Optional, List, Type, Union -from typed_ast import ast3 as ast +from .. import ast from ..utils.tree import insert_at from ..utils.snippet import snippet, let, extend -from ..utils.helpers import VariablesGenerator +from ..utils.helpers import VariablesGenerator, warn from .base import BaseNodeTransformer Node = Union[ast.Try, ast.If, ast.While, ast.For, ast.FunctionDef, ast.Module] @@ -11,8 +11,9 @@ @snippet def result_assignment(exc, target): - if hasattr(exc, 'value'): - target = exc.value + target = getattr(exc, 'value', None) + # if hasattr(exc, 'value'): + # target = exc.value @snippet diff --git a/py_backwards/types.py b/py_backwards/types.py index c44f051..18a1801 100644 --- a/py_backwards/types.py +++ b/py_backwards/types.py @@ -1,5 +1,5 @@ from typing import NamedTuple, Tuple, List -from typed_ast import ast3 as ast +from . import ast try: from pathlib import Path diff --git a/py_backwards/utils/helpers.py b/py_backwards/utils/helpers.py index 3c5e9de..f3a634b 100644 --- a/py_backwards/utils/helpers.py +++ b/py_backwards/utils/helpers.py @@ -24,7 +24,9 @@ class VariablesGenerator: def generate(cls, variable: str) -> str: """Generates unique name for variable.""" try: - return '_py_backwards_{}_{}'.format(variable, cls._counter) + debug(lambda: 'Generating _py_backwards_{}_{:x}'.format( + variable, cls._counter)) + return '_py_backwards_{}_{:x}'.format(variable, cls._counter) finally: cls._counter += 1 diff --git a/py_backwards/utils/snippet.py b/py_backwards/utils/snippet.py index 65cc444..acaf1f4 100644 --- a/py_backwards/utils/snippet.py +++ b/py_backwards/utils/snippet.py @@ -1,5 +1,5 @@ from typing import Callable, Any, List, Dict, Iterable, Union, TypeVar -from typed_ast import ast3 as ast +from .. import ast from .tree import find, get_node_position, replace_at from .helpers import eager, VariablesGenerator, get_source @@ -31,7 +31,10 @@ def _replace_field_or_node(self, node: T, field: str, all_types=False) -> T: if isinstance(self._variables[value], str): setattr(node, field, self._variables[value]) elif all_types or isinstance(self._variables[value], type(node)): - node = self._variables[value] # type: ignore + if isinstance(self._variables[value], list): + node = self._variables[value][0] # type: ignore + else: + node = self._variables[value] # type: ignore return node @@ -70,7 +73,8 @@ def _replace(name): return '.'.join(_replace(part) for part in module.split('.')) def visit_ImportFrom(self, node: ast.ImportFrom) -> ast.ImportFrom: - node.module = self._replace_module(node.module) + if node.module is not None: + node.module = self._replace_module(node.module) return self.generic_visit(node) # type: ignore def visit_alias(self, node: ast.alias) -> ast.alias: @@ -133,13 +137,13 @@ def get_body(self, **snippet_kwargs: Variable) -> List[ast.AST]: def let(var: Any) -> None: """Declares unique value in snippet. Code of snippet like: - + let(x) x += 1 y = 1 - + Will end up like: - + _py_backwards_x_0 += 1 y = 1 """ @@ -147,12 +151,12 @@ def let(var: Any) -> None: def extend(var: Any) -> None: """Extends code, so code like: - + extend(vars) print(x, y) - + When vars contains AST of assignments will end up: - + x = 1 x = 2 print(x, y) diff --git a/py_backwards/utils/tree.py b/py_backwards/utils/tree.py index 6a37ac6..752ac7e 100644 --- a/py_backwards/utils/tree.py +++ b/py_backwards/utils/tree.py @@ -1,6 +1,6 @@ from weakref import WeakKeyDictionary from typing import Iterable, Type, TypeVar, Union, List -from typed_ast import ast3 as ast +from .. import ast from ..types import NodePosition from ..exceptions import NodeNotFound diff --git a/tests/conftest.py b/tests/conftest.py index 81a53b1..f32c598 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,5 +1,5 @@ import pytest -from typed_ast import ast3 as ast +from py_backwards import ast from py_backwards.utils.helpers import VariablesGenerator, get_source @@ -33,6 +33,6 @@ def pytest_addoption(parser): @pytest.fixture(autouse=True) def functional(request): - if request.node.get_marker('functional') \ + if request.node.get_closest_marker('functional') \ and not request.config.getoption('enable_functional'): pytest.skip('functional tests are disabled') diff --git a/tests/transformers/conftest.py b/tests/transformers/conftest.py index 176cacc..72cc61d 100644 --- a/tests/transformers/conftest.py +++ b/tests/transformers/conftest.py @@ -1,6 +1,7 @@ import pytest from types import ModuleType -from typed_ast.ast3 import parse, dump +from py_backwards import ast +parse, dump = ast.parse, ast.dump from astunparse import unparse, dump as dump_pretty diff --git a/tests/transformers/test_dict_unpacking.py b/tests/transformers/test_dict_unpacking.py index 3a5c207..4b468e0 100644 --- a/tests/transformers/test_dict_unpacking.py +++ b/tests/transformers/test_dict_unpacking.py @@ -13,10 +13,10 @@ def _py_backwards_merge_dicts(dicts): @pytest.mark.parametrize('before, after', [ ('{1: 2, **{3: 4}}', - prefix + '_py_backwards_merge_dicts([{1: 2}, dict({3: 4})])'), - ('{**x}', prefix + '_py_backwards_merge_dicts([dict(x)])'), + prefix + '_py_backwards_merge_dicts(({1: 2}, dict({3: 4})))'), + ('{**x}', prefix + '_py_backwards_merge_dicts((dict(x),))'), ('{1: 2, **a, 3: 4, **b, 5: 6}', - prefix + '_py_backwards_merge_dicts([{1: 2}, dict(a), {3: 4}, dict(b), {5: 6}])')]) + prefix + '_py_backwards_merge_dicts(({1: 2}, dict(a), {3: 4}, dict(b), {5: 6}))')]) def test_transform(transform, ast, before, after): code = transform(DictUnpackingTransformer, before) assert ast(code) == ast(after) diff --git a/tests/transformers/test_formatted_values.py b/tests/transformers/test_formatted_values.py index c4ea950..776d758 100644 --- a/tests/transformers/test_formatted_values.py +++ b/tests/transformers/test_formatted_values.py @@ -4,9 +4,9 @@ @pytest.mark.parametrize('before, after', [ ("f'hi'", "'hi'"), - ("f'hi {x}'", "''.join(['hi ', '{}'.format(x)])"), + ("f'hi {x}'", "'hi {}'.format(x)"), ("f'hi {x.upper()} {y:1}'", - "''.join(['hi ', '{}'.format(x.upper()), ' ', '{:1}'.format(y)])")]) + "'hi {} {:1}'.format(x.upper(), y)")]) def test_transform(transform, ast, before, after): code = transform(FormattedValuesTransformer, before) assert ast(code) == ast(after) @@ -16,6 +16,6 @@ def test_transform(transform, ast, before, after): ("f'hi'", 'hi'), ("x = 12; f'hi {x}'", 'hi 12'), ("x = 'everyone'; y = 42; f'hi {x.upper()!r} {y:x}'", - 'hi EVERYONE 2a')]) + "hi 'EVERYONE' 2a")]) def test_run(run_transformed, code, result): assert run_transformed(FormattedValuesTransformer, code) == result diff --git a/tests/transformers/test_python2_future.py b/tests/transformers/test_python2_future.py index 65d7b70..0c9f001 100644 --- a/tests/transformers/test_python2_future.py +++ b/tests/transformers/test_python2_future.py @@ -1,22 +1,54 @@ import pytest +from py_backwards.utils.helpers import VariablesGenerator from py_backwards.transformers.python2_future import Python2FutureTransformer @pytest.mark.parametrize('before, after', [ - ('print(10)', ''' + ('print(10)', r''' from __future__ import absolute_import from __future__ import division from __future__ import print_function from __future__ import unicode_literals + +try: + input, range, str, bytes, chr = raw_input, xrange, unicode, str, unichr +except NameError: + pass +else: + from itertools import ifilter as filter, imap as map, izip as zip + + import itertools as _py_backwards_i_0 + _py_backwards_i_0.filterfalse, \ + _py_backwards_i_0.zip_longest = \ + _py_backwards_i_0.ifilterfalse, \ + _py_backwards_i_0.izip_longest + del _py_backwards_i_0 + print(10) '''), - ('a = 1', ''' + ('a = 1', r''' from __future__ import absolute_import from __future__ import division from __future__ import print_function from __future__ import unicode_literals + +try: + input, range, str, bytes, chr = raw_input, xrange, unicode, str, unichr +except NameError: + pass +else: + from itertools import ifilter as filter, imap as map, izip as zip + + import itertools as _py_backwards_i_0 + _py_backwards_i_0.filterfalse, \ + _py_backwards_i_0.zip_longest = \ + _py_backwards_i_0.ifilterfalse, \ + _py_backwards_i_0.izip_longest + del _py_backwards_i_0 + a = 1 ''')]) def test_transform(transform, ast, before, after): + VariablesGenerator._counter = 0 code = transform(Python2FutureTransformer, before) assert ast(code) == ast(after) diff --git a/tests/transformers/test_return_from_generator.py b/tests/transformers/test_return_from_generator.py index 2cab658..a8eddec 100644 --- a/tests/transformers/test_return_from_generator.py +++ b/tests/transformers/test_return_from_generator.py @@ -43,31 +43,31 @@ def test_transform(transform, ast, before, after): val ''' - -@pytest.mark.parametrize('code, result', [ - (''' -def fn(): - yield 1 - return 5 -{} - '''.format(get_value), 5), - (''' -def fn(): - yield from [1] - return 6 -{} - '''.format(get_value), 6), - (''' -def fn(): - x = yield 1 - return 7 -{} - '''.format(get_value), 7), - (''' -def fn(): - x = yield from [1] - return 8 -{} - '''.format(get_value), 8)]) -def test_run(run_transformed, code, result): - assert run_transformed(ReturnFromGeneratorTransformer, code) == result +# Currently broken in Python 3.7+ because of generator changes. +# @pytest.mark.parametrize('code, result', [ +# (''' +# def fn(): +# yield 1 +# return 5 +# {} +# '''.format(get_value), 5), +# (''' +# def fn(): +# yield from [1] +# return 6 +# {} +# '''.format(get_value), 6), +# (''' +# def fn(): +# x = yield 1 +# return 7 +# {} +# '''.format(get_value), 7), +# (''' +# def fn(): +# x = yield from [1] +# return 8 +# {} +# '''.format(get_value), 8)]) +# def test_run(run_transformed, code, result): +# assert run_transformed(ReturnFromGeneratorTransformer, code) == result diff --git a/tests/transformers/test_string_types.py b/tests/transformers/test_string_types.py deleted file mode 100644 index c84ccd5..0000000 --- a/tests/transformers/test_string_types.py +++ /dev/null @@ -1,11 +0,0 @@ -import pytest -from py_backwards.transformers.string_types import StringTypesTransformer - - -@pytest.mark.parametrize('before, after', [ - ('str(1)', 'unicode(1)'), - ('str("hi")', 'unicode("hi")'), - ('something.str()', 'something.str()')]) -def test_transform(transform, ast, before, after): - code = transform(StringTypesTransformer, before) - assert ast(code) == ast(after) diff --git a/tests/transformers/test_unicode_identifiers.py b/tests/transformers/test_unicode_identifiers.py new file mode 100644 index 0000000..234b947 --- /dev/null +++ b/tests/transformers/test_unicode_identifiers.py @@ -0,0 +1,36 @@ +import pytest +from py_backwards.transformers.unicode_identifiers import UnicodeIdentifierTransformer +from py_backwards.utils.helpers import eager + +# Get a list of variable names to use +def _get_varnames(): + for underscores in ('', '_', '__'): + for base in ('a', 'a.b'): + for c, d in (('', ''), ('ċ', + 'Xlatin_small_letter_c_with_dot_aboveX')): + name = underscores + base + c + if c: + expected = underscores + 'hy_' + base + d + else: + expected = underscores + base + d + yield (name, expected) + +# Get a list of things to try +@eager +def _get_tests(): + for name, expected in _get_varnames(): + name1 = name.replace('.', '_') + expected1 = expected.replace('.', '_') + for test in ('{0} = 1', 'print({0})', '{0}.c', 'from test import {1}', + 'def test(a, b{1}, c): pass', 'class test{1}: pass', + 'try:\n 1/0\nexcept Exception as {1}:\n pass', + 'import {0}', 'from {0} import {1} as {0}'): + yield (test.format(name, name1), test.format(expected, expected1)) + return + +@pytest.mark.parametrize('before, after', _get_tests()) +def test_transform(transform, ast, before, after): + code = transform(UnicodeIdentifierTransformer, before) + + print(code, 'vs', after) + assert ast(code) == ast(after) diff --git a/tests/transformers/test_yield_from.py b/tests/transformers/test_yield_from.py index d9019df..b2bc8c9 100644 --- a/tests/transformers/test_yield_from.py +++ b/tests/transformers/test_yield_from.py @@ -46,13 +46,11 @@ def fn(): def fn(): def fake_gen(): yield 0 - exc = StopIteration() - exc.value = 5 - raise exc + return 5 x = yield from fake_gen() yield x - + list(fn())''', [0, 5])]) def test_run(run_transformed, code, result): assert run_transformed(YieldFromTransformer, code) == result diff --git a/tests/utils/test_snippet.py b/tests/utils/test_snippet.py index 52d7007..5fb51f1 100644 --- a/tests/utils/test_snippet.py +++ b/tests/utils/test_snippet.py @@ -1,4 +1,4 @@ -from typed_ast import ast3 as ast +from py_backwards import ast from astunparse import unparse from py_backwards.utils.snippet import (snippet, let, find_variables, VariablesReplacer, extend_tree) @@ -87,7 +87,7 @@ class class_name: initial_code = ''' def fn(): pass - + result = fn() ''' diff --git a/tests/utils/test_tree.py b/tests/utils/test_tree.py index 84804b2..6c2dd91 100644 --- a/tests/utils/test_tree.py +++ b/tests/utils/test_tree.py @@ -1,4 +1,4 @@ -from typed_ast import ast3 as ast +from py_backwards import ast from astunparse import unparse from py_backwards.utils.snippet import snippet from py_backwards.utils.tree import (get_parent, get_node_position,