diff --git a/py_backwards/transformers/__init__.py b/py_backwards/transformers/__init__.py index 82efa61..5df43ab 100644 --- a/py_backwards/transformers/__init__.py +++ b/py_backwards/transformers/__init__.py @@ -3,6 +3,7 @@ from .formatted_values import FormattedValuesTransformer from .functions_annotations import FunctionsAnnotationsTransformer from .starred_unpacking import StarredUnpackingTransformer +from .async_await import AsyncAwaitTransformer from .variables_annotations import VariablesAnnotationsTransformer from .yield_from import YieldFromTransformer from .return_from_generator import ReturnFromGeneratorTransformer @@ -23,6 +24,7 @@ # 3.4 DictUnpackingTransformer, StarredUnpackingTransformer, + AsyncAwaitTransformer, # 3.2 YieldFromTransformer, ReturnFromGeneratorTransformer, diff --git a/py_backwards/transformers/async_await.py b/py_backwards/transformers/async_await.py new file mode 100644 index 0000000..4851350 --- /dev/null +++ b/py_backwards/transformers/async_await.py @@ -0,0 +1,64 @@ +from typed_ast import ast3 as ast +from .base import BaseNodeTransformer + + +ASYNCIO_MODULE = '__py_backwards_asyncio__' + +def asyncio_decorator(): + return ast.Attribute( + value=ast.Name( + id=ASYNCIO_MODULE, + ctx=ast.Load(), + ), + attr='coroutine', + ) + + +def splice(fn, vars): + for i, v in enumerate(vars): + nv = fn(v) + if nv is None: + yield v + else: + yield nv + for v in vars[i:]: + yield v + break + + +class AsyncAwaitTransformer(BaseNodeTransformer): + """Compiles: + async def ham(): + await foo() + To + + import asyncio as __py_backwards_asyncio__ + + @__py_backwards_asyncio__.coroutine + def ham(): + yield from foo() + """ + target = (3, 4) + + def visit_Module(self, n: ast.Module) -> ast.Module: + self._tree_changed = True + def dosplice(v): + # insert just after the first node that's not a future import. + if not (isinstance(v, ast.ImportFrom) and v.module == '__future__'): + return ast.Import(names=[ast.alias('asyncio', ASYNCIO_MODULE)]) + + return self.generic_visit(ast.Module(body=list(splice(dosplice, n.body)))) + + def visit_AsyncFunctionDef(self, node: ast.AsyncFunctionDef) -> ast.FunctionDef: + self._tree_changed = True + return self.generic_visit(ast.FunctionDef( + name=node.name, + args=node.args, + body=node.body, + decorator_list=node.decorator_list + [asyncio_decorator()], + returns=node.returns, + )) + + def visit_Await(self, node: ast.Await) -> ast.Await: + self._tree_changed = True + return self.generic_visit(ast.YieldFrom(value=node.value)) diff --git a/tests/transformers/test_async_await.py b/tests/transformers/test_async_await.py new file mode 100644 index 0000000..dba6d15 --- /dev/null +++ b/tests/transformers/test_async_await.py @@ -0,0 +1,27 @@ +import pytest +from py_backwards.transformers.async_await import AsyncAwaitTransformer + + +@pytest.mark.parametrize('before, after', [ + (''' +async def fn(): + await range(10) + ''', '''import asyncio as __py_backwards_asyncio__ + +@__py_backwards_asyncio__.coroutine +def fn(): + yield from range(10) +'''), + (''' +async def fn(): + a = await range(10) + ''', '''import asyncio as __py_backwards_asyncio__ + +@__py_backwards_asyncio__.coroutine +def fn(): + a = yield from range(10) +'''), +]) +def test_transform(transform, ast, before, after): + code = transform(AsyncAwaitTransformer, before) + assert ast(code) == ast(after)