diff --git a/awslambdaric/__main__.py b/awslambdaric/__main__.py index 5cbbaab..9a0ef21 100644 --- a/awslambdaric/__main__.py +++ b/awslambdaric/__main__.py @@ -2,23 +2,31 @@ Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved. """ -import os import sys +from .lambda_config import LambdaConfigProvider +from .lambda_runtime_client import LambdaRuntimeClient +from .lambda_multi_concurrent_utils import MultiConcurrentRunner from . import bootstrap def main(args): - app_root = os.getcwd() - - try: - handler = args[1] - except IndexError: - raise ValueError("Handler not set") - - lambda_runtime_api_addr = os.environ["AWS_LAMBDA_RUNTIME_API"] - - bootstrap.run(app_root, handler, lambda_runtime_api_addr) + config = LambdaConfigProvider(args) + handler = config.handler + api_addr = config.api_address + use_thread = config.use_thread_polling + + if config.is_multi_concurrent: + # Multi-concurrent mode: redirect fork, stdout/stderr and run + max_conc = int(config.max_concurrency) + socket_path = config.lmi_socket_path + MultiConcurrentRunner.run_concurrent( + handler, api_addr, use_thread, socket_path, max_conc + ) + else: + # Standard Lambda mode: single call + client = LambdaRuntimeClient(api_addr, use_thread) + bootstrap.run(handler, client) if __name__ == "__main__": diff --git a/awslambdaric/bootstrap.py b/awslambdaric/bootstrap.py index cb8d5c3..d90f2a3 100644 --- a/awslambdaric/bootstrap.py +++ b/awslambdaric/bootstrap.py @@ -37,7 +37,7 @@ def _get_handler(handler): try: - (modname, fname) = handler.rsplit(".", 1) + modname, fname = handler.rsplit(".", 1) except ValueError as e: raise FaultException( FaultException.MALFORMED_HANDLER_NAME, @@ -477,19 +477,11 @@ def _setup_logging(log_format, log_level, log_sink): logger.addHandler(logger_handler) -def run(app_root, handler, lambda_runtime_api_addr): +def run(handler, lambda_runtime_client): sys.stdout = Unbuffered(sys.stdout) sys.stderr = Unbuffered(sys.stderr) - use_thread_for_polling_next = os.environ.get("AWS_EXECUTION_ENV") in { - "AWS_Lambda_python3.12", - "AWS_Lambda_python3.13", - } - with create_log_sink() as log_sink: - lambda_runtime_client = LambdaRuntimeClient( - lambda_runtime_api_addr, use_thread_for_polling_next - ) error_result = None try: diff --git a/awslambdaric/lambda_config.py b/awslambdaric/lambda_config.py new file mode 100644 index 0000000..c9922d8 --- /dev/null +++ b/awslambdaric/lambda_config.py @@ -0,0 +1,70 @@ +""" +Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved. +""" + +import os + + +class LambdaConfigProvider: + SUPPORTED_THREADPOLLING_ENVS = { + "AWS_Lambda_python3.12", + "AWS_Lambda_python3.13", + "AWS_Lambda_python3.14", + } + SOCKET_PATH_ENV = "_LAMBDA_TELEMETRY_LOG_FD_PROVIDER_SOCKET" + AWS_LAMBDA_RUNTIME_API = "AWS_LAMBDA_RUNTIME_API" + AWS_LAMBDA_MAX_CONCURRENCY = "AWS_LAMBDA_MAX_CONCURRENCY" + AWS_EXECUTION_ENV = "AWS_EXECUTION_ENV" + + def __init__(self, args, environ=None): + self._environ = environ if environ is not None else os.environ + self._handler = self._parse_handler(args) + self._api_address = self._parse_api_address() + self._max_concurrency = self._parse_concurrency() + self._use_thread_polling = self._parse_thread_polling() + self._lmi_socket_path = self._parse_lmi_socket_path() + + def _parse_handler(self, args): + try: + return args[1] + except IndexError: + raise ValueError("Handler not set") + + def _parse_api_address(self): + return self._environ[self.AWS_LAMBDA_RUNTIME_API] + + def _parse_concurrency(self): + return self._environ.get(self.AWS_LAMBDA_MAX_CONCURRENCY) + + def _parse_thread_polling(self): + return ( + self._environ.get(self.AWS_EXECUTION_ENV) + in self.SUPPORTED_THREADPOLLING_ENVS + ) + + def _parse_lmi_socket_path(self): + return self._environ.get(self.SOCKET_PATH_ENV) + + @property + def handler(self): + return self._handler + + @property + def api_address(self): + return self._api_address + + @property + def max_concurrency(self): + return self._max_concurrency + + @property + def use_thread_polling(self): + return self._use_thread_polling + + @property + def is_multi_concurrent(self): + return self._max_concurrency is not None + + @property + def lmi_socket_path(self): + return self._lmi_socket_path diff --git a/awslambdaric/lambda_multi_concurrent_utils.py b/awslambdaric/lambda_multi_concurrent_utils.py new file mode 100644 index 0000000..b6678d9 --- /dev/null +++ b/awslambdaric/lambda_multi_concurrent_utils.py @@ -0,0 +1,53 @@ +""" +Copyright 2025 Amazon.com, Inc. or its affiliates. All Rights Reserved. +""" + +import os +import sys +import socket +import multiprocessing + +from . import bootstrap +from .lambda_runtime_client import LambdaMultiConcurrentRuntimeClient + + +class MultiConcurrentRunner: + @staticmethod + def _redirect_stream_to_fd(stream_fd: int, socket_path: str): + with socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) as s: + s.connect(socket_path) + os.dup2(s.fileno(), stream_fd) + + @classmethod + def _redirect_output(cls, socket_path: str): + for std_fd in (sys.stdout.fileno(), sys.stderr.fileno()): + cls._redirect_stream_to_fd(std_fd, socket_path) + + @classmethod + def run_single( + cls, handler: str, api_addr: str, use_thread: bool, socket_path: str + ): + if socket_path: + cls._redirect_output(socket_path) + client = LambdaMultiConcurrentRuntimeClient(api_addr, use_thread) + bootstrap.run(handler, client) + + @classmethod + def run_concurrent( + cls, + handler: str, + api_addr: str, + use_thread: bool, + socket_path: str, + max_concurrency: int, + ): + processes = [] + for _ in range(max_concurrency): + p = multiprocessing.Process( + target=cls.run_single, + args=(handler, api_addr, use_thread, socket_path), + ) + p.start() + processes.append(p) + for p in processes: + p.join() diff --git a/awslambdaric/lambda_runtime_client.py b/awslambdaric/lambda_runtime_client.py index ba4ad92..2cc8f3b 100644 --- a/awslambdaric/lambda_runtime_client.py +++ b/awslambdaric/lambda_runtime_client.py @@ -6,8 +6,14 @@ from awslambdaric import __version__ from .lambda_runtime_exception import FaultException from .lambda_runtime_marshaller import to_json +import logging +import time ERROR_TYPE_HEADER = "Lambda-Runtime-Function-Error-Type" +# Retry config constants +DEFAULT_RETRY_MAX_ATTEMPTS = 5 +DEFAULT_RETRY_INITIAL_DELAY = 0.1 # seconds +DEFAULT_RETRY_BACKOFF_FACTOR = 2.0 def _user_agent(): @@ -46,13 +52,17 @@ def __init__(self, endpoint, response_code, response_body): ) -class LambdaRuntimeClient(object): +class BaseLambdaRuntimeClient(object): marshaller = LambdaMarshaller() """marshaller is a class attribute that determines the unmarshalling and marshalling logic of a function's event and response. It allows for function authors to override the the default implementation, LambdaMarshaller which unmarshals and marshals JSON, to an instance of a class that implements the same interface.""" - def __init__(self, lambda_runtime_address, use_thread_for_polling_next=False): + def __init__( + self, + lambda_runtime_address, + use_thread_for_polling_next=False, + ): self.lambda_runtime_address = lambda_runtime_address self.use_thread_for_polling_next = use_thread_for_polling_next if self.use_thread_for_polling_next: @@ -94,9 +104,16 @@ def post_init_error(self, error_response_data, error_type_override=None): else error_response_data["errorType"] ) } - self.call_rapid( - "POST", endpoint, http.HTTPStatus.ACCEPTED, error_response_data, headers - ) + try: + self.call_rapid( + "POST", endpoint, http.HTTPStatus.ACCEPTED, error_response_data, headers + ) + except Exception as e: + self.handle_init_error(e) + + def handle_init_error(self, exc): + """Override in subclasses to customize init error handling.""" + raise NotImplementedError def restore_next(self): import http @@ -113,6 +130,16 @@ def report_restore_error(self, restore_error_data): "POST", endpoint, http.HTTPStatus.ACCEPTED, restore_error_data, headers ) + def handle_exception(self, exc, func_to_retry=None, use_backoff=False): + """Override in subclasses to customize error handling.""" + raise NotImplementedError + + def _get_next(self): + try: + return runtime_client.next() + except Exception as e: + return self.handle_exception(e, runtime_client.next, True) + def wait_next_invocation(self): # Calling runtime_client.next() from a separate thread unblocks the main thread, # which can then process signals. @@ -120,7 +147,7 @@ def wait_next_invocation(self): try: # TPE class is supposed to be registered at construction time and be ready to use. with self.ThreadPoolExecutor(max_workers=1) as executor: - future = executor.submit(runtime_client.next) + future = executor.submit(self._get_next) response_body, headers = future.result() except Exception as e: raise FaultException( @@ -145,17 +172,66 @@ def wait_next_invocation(self): def post_invocation_result( self, invoke_id, result_data, content_type="application/json" ): - runtime_client.post_invocation_result( - invoke_id, - ( - result_data - if isinstance(result_data, bytes) - else result_data.encode("utf-8") - ), - content_type, - ) + try: + runtime_client.post_invocation_result( + invoke_id, + ( + result_data + if isinstance(result_data, bytes) + else result_data.encode("utf-8") + ), + content_type, + ) + except Exception as e: + self.handle_exception(e) def post_invocation_error(self, invoke_id, error_response_data, xray_fault): - max_header_size = 1024 * 1024 # 1MiB - xray_fault = xray_fault if len(xray_fault.encode()) < max_header_size else "" - runtime_client.post_error(invoke_id, error_response_data, xray_fault) + try: + max_header_size = 1024 * 1024 + xray_fault = ( + xray_fault if len(xray_fault.encode()) < max_header_size else "" + ) + runtime_client.post_error(invoke_id, error_response_data, xray_fault) + except Exception as e: + self.handle_exception(e) + + +class LambdaRuntimeClient(BaseLambdaRuntimeClient): + def handle_exception(self, exc, func_to_retry=None, use_backoff=False): + raise exc + + def handle_init_error(self, exc): + raise exc + + +class LambdaMultiConcurrentRuntimeClient(BaseLambdaRuntimeClient): + def _get_next_with_backoff(self, e, func_to_retry): + logging.warning(f"Initial runtime_client.next() failed: {e}") + delay = DEFAULT_RETRY_INITIAL_DELAY + latest_exception = None + for attempt in range(1, DEFAULT_RETRY_MAX_ATTEMPTS): + try: + logging.info( + f"Retrying runtime_client.next() [attempt {attempt + 1}]..." + ) + time.sleep(delay) + return func_to_retry() + except Exception as e: + logging.warning(f"Attempt {attempt + 1} failed: {e}") + delay *= DEFAULT_RETRY_BACKOFF_FACTOR + latest_exception = e + + raise latest_exception + + # In multi-concurrent mode we don't want to raise unhandled exception and crash the worker on non-2xx responses from RAPID + def handle_exception(self, exc, func_to_retry=None, use_backoff=False): + if use_backoff: + return self._get_next_with_backoff(exc, func_to_retry) + # We retry if getting next invoke failed, but if posting response to RAPID failed we just log it and continue + logging.warning(f"{exc}: This won't kill the Runtime loop") + + def handle_init_error(self, exc): + if isinstance(exc, LambdaRuntimeClientError) and exc.response_code == 403: + # Suppress 403 errors from RAPID during init - indicates another runtime worker has already posted init error + return + raise exc diff --git a/awslambdaric/lambda_runtime_marshaller.py b/awslambdaric/lambda_runtime_marshaller.py index 4256066..a527674 100644 --- a/awslambdaric/lambda_runtime_marshaller.py +++ b/awslambdaric/lambda_runtime_marshaller.py @@ -18,6 +18,7 @@ def __init__(self): if os.environ.get("AWS_EXECUTION_ENV") in { "AWS_Lambda_python3.12", "AWS_Lambda_python3.13", + "AWS_Lambda_python3.14", }: super().__init__(use_decimal=False, ensure_ascii=False, allow_nan=True) else: diff --git a/deps/aws-lambda-cpp-0.2.6.tar.gz b/deps/aws-lambda-cpp-0.2.6.tar.gz index 51d7f51..a055b74 100644 Binary files a/deps/aws-lambda-cpp-0.2.6.tar.gz and b/deps/aws-lambda-cpp-0.2.6.tar.gz differ diff --git a/deps/patches/aws-lambda-cpp-logging-error.patch b/deps/patches/aws-lambda-cpp-logging-error.patch new file mode 100644 index 0000000..ac9dc1b --- /dev/null +++ b/deps/patches/aws-lambda-cpp-logging-error.patch @@ -0,0 +1,16 @@ +diff --git a/src/runtime.cpp b/src/runtime.cpp +index 9763282..9fe78d8 100644 +--- a/src/runtime.cpp ++++ b/src/runtime.cpp +@@ -379,7 +379,10 @@ runtime::post_outcome runtime::do_post( + + if (!is_success(aws::http::response_code(http_response_code))) { + logging::log_error( +- LOG_TAG, "Failed to post handler success response. Http response code: %ld.", http_response_code); ++ LOG_TAG, ++ "Failed to post handler success response. Http response code: %ld. %s", ++ http_response_code, ++ resp.get_body().c_str()); + return aws::http::response_code(http_response_code); + } + diff --git a/scripts/update_deps.sh b/scripts/update_deps.sh index 4799a6f..841d320 100755 --- a/scripts/update_deps.sh +++ b/scripts/update_deps.sh @@ -31,7 +31,8 @@ wget -c https://github.com/awslabs/aws-lambda-cpp/archive/v$AWS_LAMBDA_CPP_RELEA patch -p1 < ../patches/aws-lambda-cpp-make-the-runtime-client-user-agent-overrideable.patch && \ patch -p1 < ../patches/aws-lambda-cpp-make-lto-optional.patch && \ patch -p1 < ../patches/aws-lambda-cpp-add-content-type.patch && \ - patch -p1 < ../patches/aws-lambda-cpp-add-tenant-id.patch + patch -p1 < ../patches/aws-lambda-cpp-add-tenant-id.patch && \ + patch -p1 < ../patches/aws-lambda-cpp-logging-error.patch ) ## Pack again and remove the folder diff --git a/tests/test_bootstrap.py b/tests/test_bootstrap.py index 33afb1c..1eb2bb0 100644 --- a/tests/test_bootstrap.py +++ b/tests/test_bootstrap.py @@ -1496,24 +1496,20 @@ def test_set_log_level_with_dictConfig(self, mock_stderr, mock_stdout): class TestBootstrapModule(unittest.TestCase): - @patch("awslambdaric.bootstrap.LambdaRuntimeClient") - def test_run(self, mock_runtime_client): - expected_app_root = "/tmp/test/app_root" + def test_run(self): expected_handler = "app.my_test_handler" - expected_lambda_runtime_api_addr = "test_addr" mock_event_request = MagicMock() mock_event_request.x_amzn_trace_id = "123" + mock_runtime_client = MagicMock() mock_runtime_client.return_value.wait_next_invocation.side_effect = [ mock_event_request, MagicMock(), ] with self.assertRaises(SystemExit) as cm: - bootstrap.run( - expected_app_root, expected_handler, expected_lambda_runtime_api_addr - ) + bootstrap.run(expected_handler, mock_runtime_client) self.assertEqual(cm.exception.code, 1) @@ -1523,23 +1519,20 @@ def test_run(self, mock_runtime_client): ) @patch("awslambdaric.bootstrap.build_fault_result") @patch("awslambdaric.bootstrap.log_error", MagicMock()) - @patch("awslambdaric.bootstrap.LambdaRuntimeClient", MagicMock()) @patch("awslambdaric.bootstrap.sys") def test_run_exception(self, mock_sys, mock_build_fault_result): class TestException(Exception): pass - expected_app_root = "/tmp/test/app_root" expected_handler = "app.my_test_handler" - expected_lambda_runtime_api_addr = "test_addr" + + mock_runtime_client = MagicMock() mock_build_fault_result.return_value = {} mock_sys.exit.side_effect = TestException("Boom!") with self.assertRaises(TestException): - bootstrap.run( - expected_app_root, expected_handler, expected_lambda_runtime_api_addr - ) + bootstrap.run(expected_handler, mock_runtime_client) mock_sys.exit.assert_called_once_with(1) @@ -1561,8 +1554,8 @@ def tearDown(self): def raise_type_error(self): raise TypeError("This is a Dummy type error") - @patch("awslambdaric.bootstrap.LambdaRuntimeClient") - def test_before_snapshot_exception(self, mock_runtime_client): + def test_before_snapshot_exception(self): + mock_runtime_client = MagicMock() snapshot_restore_py.register_before_snapshot(self.raise_type_error) with self.assertRaises(SystemExit) as cm: @@ -1576,8 +1569,8 @@ def test_before_snapshot_exception(self, mock_runtime_client): FaultException.BEFORE_SNAPSHOT_ERROR, ) - @patch("awslambdaric.bootstrap.LambdaRuntimeClient") - def test_after_restore_exception(self, mock_runtime_client): + def test_after_restore_exception(self): + mock_runtime_client = MagicMock() snapshot_restore_py.register_after_restore(self.raise_type_error) with self.assertRaises(SystemExit) as cm: diff --git a/tests/test_concurrency.py b/tests/test_concurrency.py new file mode 100644 index 0000000..9e74793 --- /dev/null +++ b/tests/test_concurrency.py @@ -0,0 +1,51 @@ +""" +Copyright 2025 Amazon.com, Inc. or its affiliates. All Rights Reserved. +""" + +import multiprocessing +import unittest +from unittest.mock import patch, MagicMock + +from awslambdaric.lambda_multi_concurrent_utils import MultiConcurrentRunner + + +class LambdaRuntimeConcurrencyTest(unittest.TestCase): + def setUp(self): + # common args + self.handler = "h.fn" + self.addr = "addr" + self.use_thread = False + self.socket = "/tmp/sock" + + def test_success_and_failure_isolation(self): + success_counter = multiprocessing.Value("i", 0) + fail_counter = multiprocessing.Value("i", 0) + + def fake_bootstrap_run(handler, lambda_runtime_client): + pid = multiprocessing.current_process().pid + if pid % 2 == 0: + for _ in range(3): + with success_counter.get_lock(): + success_counter.value += 1 + else: + with fail_counter.get_lock(): + fail_counter.value += 1 + raise RuntimeError("Simulated failure") + + with patch( + "awslambdaric.lambda_multi_concurrent_utils.MultiConcurrentRunner._redirect_output" + ), patch( + "awslambdaric.lambda_multi_concurrent_utils.bootstrap.run", + side_effect=fake_bootstrap_run, + ): + # spawn 4 multi-concurrent processes + MultiConcurrentRunner.run_concurrent( + self.handler, self.addr, self.use_thread, self.socket, max_concurrency=4 + ) + + self.assertEqual(success_counter.value, 6) + self.assertEqual(fail_counter.value, 2) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_lambda_config.py b/tests/test_lambda_config.py new file mode 100644 index 0000000..6e33afd --- /dev/null +++ b/tests/test_lambda_config.py @@ -0,0 +1,70 @@ +""" +Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved. +""" + +import os +import unittest +from awslambdaric.lambda_config import LambdaConfigProvider + + +class TestLambdaConfigProvider(unittest.TestCase): + def setUp(self): + self.orig = os.environ.copy() + + def tearDown(self): + os.environ.clear() + os.environ.update(self.orig) + + def test_handler_property_and_missing(self): + cfg = LambdaConfigProvider( + ["prog", "h.fn"], environ={"AWS_LAMBDA_RUNTIME_API": "a"} + ) + self.assertEqual(cfg.handler, "h.fn") + with self.assertRaises(ValueError): + LambdaConfigProvider(["prog"], environ={"AWS_LAMBDA_RUNTIME_API": "a"}) + + def test_api_address_property_and_missing(self): + cfg = LambdaConfigProvider( + ["prog", "h.fn"], environ={"AWS_LAMBDA_RUNTIME_API": "endpoint"} + ) + self.assertEqual(cfg.api_address, "endpoint") + with self.assertRaises(KeyError): + LambdaConfigProvider(["prog", "h.fn"], environ={}) + + def test_concurrency_and_is_multi_concurrent(self): + env = {"AWS_LAMBDA_RUNTIME_API": "a", "AWS_LAMBDA_MAX_CONCURRENCY": "4"} + cfg = LambdaConfigProvider(["p", "h.fn"], environ=env) + self.assertEqual(cfg.max_concurrency, "4") + self.assertTrue(cfg.is_multi_concurrent) + env2 = {"AWS_LAMBDA_RUNTIME_API": "a"} + cfg2 = LambdaConfigProvider(["p", "h.fn"], environ=env2) + self.assertIsNone(cfg2.max_concurrency) + self.assertFalse(cfg2.is_multi_concurrent) + + def test_use_thread_polling_flag(self): + env = { + "AWS_LAMBDA_RUNTIME_API": "a", + "AWS_EXECUTION_ENV": "AWS_Lambda_python3.12", + } + cfg = LambdaConfigProvider(["p", "h.fn"], environ=env) + self.assertTrue(cfg.use_thread_polling) + env2 = {"AWS_LAMBDA_RUNTIME_API": "a", "AWS_EXECUTION_ENV": "OTHER"} + cfg2 = LambdaConfigProvider(["p", "h.fn"], environ=env2) + self.assertFalse(cfg2.use_thread_polling) + + def test_lmi_socket_path_property(self): + env = { + "AWS_LAMBDA_RUNTIME_API": "a", + "_LAMBDA_TELEMETRY_LOG_FD_PROVIDER_SOCKET": "/sock", + } + cfg = LambdaConfigProvider(["p", "h.fn"], environ=env) + self.assertEqual(cfg.lmi_socket_path, "/sock") + + # Test case where socket path env var is not set + env2 = {"AWS_LAMBDA_RUNTIME_API": "a"} + cfg2 = LambdaConfigProvider(["p", "h.fn"], environ=env2) + self.assertIsNone(cfg2.lmi_socket_path) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_lambda_runtime_client.py b/tests/test_lambda_runtime_client.py index fc4af65..c25e581 100644 --- a/tests/test_lambda_runtime_client.py +++ b/tests/test_lambda_runtime_client.py @@ -5,12 +5,14 @@ import http import http.client import unittest.mock +import threading from unittest.mock import MagicMock, patch from awslambdaric import __version__ from awslambdaric.lambda_runtime_client import ( InvocationRequest, LambdaRuntimeClient, + LambdaMultiConcurrentRuntimeClient, LambdaRuntimeClientError, _user_agent, ) @@ -61,20 +63,21 @@ def test_constructor(self): class TestLambdaRuntime(unittest.TestCase): + get_next_headers = { + "Lambda-Runtime-Aws-Request-Id": "RID1234", + "Lambda-Runtime-Trace-Id": "TID1234", + "Lambda-Runtime-Invoked-Function-Arn": "FARN1234", + "Lambda-Runtime-Deadline-Ms": 12, + "Lambda-Runtime-Client-Context": "client_context", + "Lambda-Runtime-Cognito-Identity": "cognito_identity", + "Lambda-Runtime-Aws-Tenant-Id": "tenant_id", + "Content-Type": "application/json", + } + @patch("awslambdaric.lambda_runtime_client.runtime_client") def test_wait_next_invocation(self, mock_runtime_client): response_body = b"{}" - headears = { - "Lambda-Runtime-Aws-Request-Id": "RID1234", - "Lambda-Runtime-Trace-Id": "TID1234", - "Lambda-Runtime-Invoked-Function-Arn": "FARN1234", - "Lambda-Runtime-Deadline-Ms": 12, - "Lambda-Runtime-Client-Context": "client_context", - "Lambda-Runtime-Cognito-Identity": "cognito_identity", - "Lambda-Runtime-Aws-Tenant-Id": "tenant_id", - "Content-Type": "application/json", - } - mock_runtime_client.next.return_value = response_body, headears + mock_runtime_client.next.return_value = response_body, self.get_next_headers runtime_client = LambdaRuntimeClient("localhost:1234") event_request = runtime_client.wait_next_invocation() @@ -106,6 +109,30 @@ def test_wait_next_invocation(self, mock_runtime_client): self.assertEqual(event_request.content_type, "application/json") self.assertEqual(event_request.event_body, response_body) + @patch("awslambdaric.lambda_runtime_client.runtime_client") + def test_wait_next_invocation_calls_next_from_separate_thread( + self, mock_runtime_client + ): + thread_ids = [] + + def record_thread_id(): + thread_ids.append(threading.get_ident()) + return b"{}", self.get_next_headers + + mock_runtime_client.next.side_effect = record_thread_id + + main_thread_id = threading.get_ident() + + runtime_client = LambdaRuntimeClient("localhost:1234", True) + runtime_client.wait_next_invocation() + + self.assertEqual(len(thread_ids), 1) + self.assertNotEqual( + thread_ids[0], + main_thread_id, + "runtime_client.next() was not called from a separate thread", + ) + @patch("awslambdaric.lambda_runtime_client.runtime_client") def test_wait_next_invocation_without_tenant_id_header(self, mock_runtime_client): response_body = b"{}" @@ -285,6 +312,145 @@ def test_post_invocation_error(self, mock_runtime_client): invoke_id, error_data, xray_fault ) + @patch("awslambdaric.lambda_runtime_client.runtime_client") + def test_get_next_falls_back_to_backoff_if_multi_concurrent( + self, mock_runtime_client + ): + # First call raises, second call succeeds + mock_runtime_client.next.side_effect = [RuntimeError("first fail"), (b"{}", {})] + client = LambdaMultiConcurrentRuntimeClient( + "localhost:1234", use_thread_for_polling_next=True + ) + + result = client._get_next() + self.assertEqual(result, (b"{}", {})) + self.assertEqual(mock_runtime_client.next.call_count, 2) + + @patch("awslambdaric.lambda_runtime_client.runtime_client") + def test_get_next_raises_if_not_multi_concurrent(self, mock_runtime_client): + mock_runtime_client.next.side_effect = RuntimeError("fail") + + client = LambdaRuntimeClient("localhost:1234", use_thread_for_polling_next=True) + + with self.assertRaises(RuntimeError): + client._get_next() + + self.assertEqual(mock_runtime_client.next.call_count, 1) + + @patch("awslambdaric.lambda_runtime_client.runtime_client") + @patch("time.sleep", return_value=None) + def test_get_next_retries_with_exponential_backoff( + self, mock_sleep, mock_runtime_client + ): + # Simulate all attempts failing + mock_runtime_client.next.side_effect = RuntimeError("always fail") + client = LambdaMultiConcurrentRuntimeClient( + "localhost:1234", use_thread_for_polling_next=True + ) + + with self.assertRaises(RuntimeError): + client._get_next() + + # 1 initial + 4 retries + self.assertEqual(mock_runtime_client.next.call_count, 5) + + expected_delays = [0.1, 0.2, 0.4, 0.8] + actual_delays = [call.args[0] for call in mock_sleep.call_args_list] + self.assertEqual(actual_delays, expected_delays) + + @patch("awslambdaric.lambda_runtime_client.runtime_client") + def test_post_invocation_result_suppresses_error_if_multi_concurrent( + self, mock_runtime_client + ): + mock_runtime_client.post_invocation_result.side_effect = RuntimeError("failure") + + client = LambdaMultiConcurrentRuntimeClient( + "localhost:1234", use_thread_for_polling_next=True + ) + + with self.assertLogs(level="WARNING"): + client.post_invocation_result("invoke_id", "result") + + @patch("awslambdaric.lambda_runtime_client.runtime_client") + def test_post_invocation_result_raises_if_not_multi_concurrent( + self, mock_runtime_client + ): + mock_runtime_client.post_invocation_result.side_effect = RuntimeError("failure") + + client = LambdaRuntimeClient("localhost:1234", use_thread_for_polling_next=True) + + with self.assertRaises(RuntimeError): + client.post_invocation_result("invoke_id", "result") + + @patch("awslambdaric.lambda_runtime_client.runtime_client") + def test_post_invocation_error_suppresses_error_if_multi_concurrent( + self, mock_runtime_client + ): + mock_runtime_client.post_error.side_effect = RuntimeError("post error") + + client = LambdaMultiConcurrentRuntimeClient( + "localhost:1234", use_thread_for_polling_next=True + ) + + with self.assertLogs(level="WARNING"): + client.post_invocation_error("invoke_id", "error_data", "xray_data") + + @patch("awslambdaric.lambda_runtime_client.runtime_client") + def test_post_invocation_error_raises_if_not_multi_concurrent( + self, mock_runtime_client + ): + mock_runtime_client.post_error.side_effect = RuntimeError("post error") + + client = LambdaRuntimeClient("localhost:1234", use_thread_for_polling_next=True) + + with self.assertRaises(RuntimeError): + client.post_invocation_error("invoke_id", "error_data", "xray_data") + + @patch("http.client.HTTPConnection", autospec=http.client.HTTPConnection) + def test_post_init_error_suppresses_403_if_multi_concurrent( + self, MockHTTPConnection + ): + mock_conn = MockHTTPConnection.return_value + mock_response = MagicMock(autospec=http.client.HTTPResponse) + mock_conn.getresponse.return_value = mock_response + mock_response.read.return_value = b"" + mock_response.code = http.HTTPStatus.FORBIDDEN + + client = LambdaMultiConcurrentRuntimeClient("localhost:1234") + + # Should not raise exception for 403 error + client.post_init_error(self.error_result) + + @patch("http.client.HTTPConnection", autospec=http.client.HTTPConnection) + def test_post_init_error_raises_non_403_if_multi_concurrent( + self, MockHTTPConnection + ): + mock_conn = MockHTTPConnection.return_value + mock_response = MagicMock(autospec=http.client.HTTPResponse) + mock_conn.getresponse.return_value = mock_response + mock_response.read.return_value = b"" + mock_response.code = http.HTTPStatus.INTERNAL_SERVER_ERROR + + client = LambdaMultiConcurrentRuntimeClient("localhost:1234") + + with self.assertRaises(LambdaRuntimeClientError): + client.post_init_error(self.error_result) + + @patch("http.client.HTTPConnection", autospec=http.client.HTTPConnection) + def test_post_init_error_raises_403_if_not_multi_concurrent( + self, MockHTTPConnection + ): + mock_conn = MockHTTPConnection.return_value + mock_response = MagicMock(autospec=http.client.HTTPResponse) + mock_conn.getresponse.return_value = mock_response + mock_response.read.return_value = b"" + mock_response.code = http.HTTPStatus.FORBIDDEN + + client = LambdaRuntimeClient("localhost:1234") + + with self.assertRaises(LambdaRuntimeClientError): + client.post_init_error(self.error_result) + @patch("awslambdaric.lambda_runtime_client.runtime_client") def test_post_invocation_error_with_large_xray_cause(self, mock_runtime_client): runtime_client = LambdaRuntimeClient("localhost:1234") diff --git a/tests/test_lambda_runtime_marshaller.py b/tests/test_lambda_runtime_marshaller.py index 843bcee..118d535 100644 --- a/tests/test_lambda_runtime_marshaller.py +++ b/tests/test_lambda_runtime_marshaller.py @@ -11,6 +11,7 @@ class TestLambdaRuntimeMarshaller(unittest.TestCase): execution_envs = ( + "AWS_Lambda_python3.14", "AWS_Lambda_python3.13", "AWS_Lambda_python3.12", "AWS_Lambda_python3.11", @@ -21,6 +22,7 @@ class TestLambdaRuntimeMarshaller(unittest.TestCase): envs_lambda_marshaller_ensure_ascii_false = { "AWS_Lambda_python3.12", "AWS_Lambda_python3.13", + "AWS_Lambda_python3.14", } execution_envs_lambda_marshaller_ensure_ascii_true = tuple( diff --git a/tests/test_main.py b/tests/test_main.py index 8a17a5d..d148b78 100644 --- a/tests/test_main.py +++ b/tests/test_main.py @@ -2,32 +2,55 @@ Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved. """ -import os import unittest -from unittest.mock import patch +from unittest.mock import MagicMock, patch import awslambdaric.__main__ as package_entry -class TestEnvVars(unittest.TestCase): - def setUp(self): - self.org_os_environ = os.environ - - def tearDown(self): - os.environ = self.org_os_environ - +class TestMain(unittest.TestCase): @patch("awslambdaric.__main__.bootstrap") - def test_main(self, mock_bootstrap): - expected_app_root = os.getcwd() - expected_handler = "app.my_test_handler" - expected_lambda_runtime_api_addr = "test_addr" - - args = ["dummy", expected_handler, "other_dummy"] + @patch("awslambdaric.__main__.LambdaRuntimeClient") + @patch("awslambdaric.__main__.LambdaConfigProvider") + def test_default_path_invokes_runtime_client_and_bootstrap( + self, mock_config_provider, mock_client_cls, mock_bootstrap + ): + # Non-multi-concurrent mode + cfg = MagicMock() + cfg.handler = "my.handler" + cfg.api_address = "http://addr" + cfg.use_thread_polling = False + cfg.is_multi_concurrent = False + mock_config_provider.return_value = cfg + + package_entry.main(["prog", "my.handler"]) + + mock_client_cls.assert_called_once_with("http://addr", False) + mock_bootstrap.run.assert_called_once_with( + "my.handler", mock_client_cls.return_value + ) - os.environ["AWS_LAMBDA_RUNTIME_API"] = expected_lambda_runtime_api_addr + @patch("awslambdaric.__main__.MultiConcurrentRunner") + @patch("awslambdaric.__main__.LambdaConfigProvider") + def test_multi_concurrent_path_dispatches_to_multi_concurrent_runner( + self, mock_config_provider, mock_runner + ): + # Multi-concurrent mode + cfg = MagicMock() + cfg.handler = "my.handler" + cfg.api_address = "http://addr" + cfg.use_thread_polling = True + cfg.is_multi_concurrent = True + cfg.max_concurrency = "2" + cfg.lmi_socket_path = "/tmp/lmi.sock" + mock_config_provider.return_value = cfg + + package_entry.main(["prog", "my.handler"]) + + mock_runner.run_concurrent.assert_called_once_with( + "my.handler", "http://addr", True, "/tmp/lmi.sock", 2 + ) - package_entry.main(args) - mock_bootstrap.run.assert_called_once_with( - expected_app_root, expected_handler, expected_lambda_runtime_api_addr - ) +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_multi_concurrent_runner.py b/tests/test_multi_concurrent_runner.py new file mode 100644 index 0000000..4a9d023 --- /dev/null +++ b/tests/test_multi_concurrent_runner.py @@ -0,0 +1,123 @@ +""" +Copyright 2025 Amazon.com, Inc. or its affiliates. All Rights Reserved. +""" + +import sys +import unittest +from unittest.mock import patch, MagicMock + +from awslambdaric.lambda_multi_concurrent_utils import MultiConcurrentRunner + + +class TestMultiConcurrentRunnerRedirect(unittest.TestCase): + @patch("socket.socket") + @patch("os.dup2") + def test_redirect_output_opens_two_sockets_and_dup2s( + self, mock_dup2, mock_socket_cls + ): + sock1 = MagicMock() + sock1.fileno.return_value = 10 + sock1.__enter__.return_value = sock1 # <-- key line + sock1.__exit__.return_value = None + + sock2 = MagicMock() + sock2.fileno.return_value = 11 + sock2.__enter__.return_value = sock2 # <-- key line + sock2.__exit__.return_value = None + + mock_socket_cls.side_effect = [sock1, sock2] + + MultiConcurrentRunner._redirect_output("/fake/path") + + self.assertEqual(mock_socket_cls.call_count, 2) + sock1.connect.assert_called_once_with("/fake/path") + sock2.connect.assert_called_once_with("/fake/path") + mock_dup2.assert_any_call(10, sys.stdout.fileno()) + mock_dup2.assert_any_call(11, sys.stderr.fileno()) + + # With a context manager, prefer asserting __exit__ was called: + self.assertEqual(sock1.__enter__.call_count, 1) + self.assertEqual(sock1.__exit__.call_count, 1) + self.assertEqual(sock2.__enter__.call_count, 1) + self.assertEqual(sock2.__exit__.call_count, 1) + + @patch( + "awslambdaric.lambda_multi_concurrent_utils.LambdaMultiConcurrentRuntimeClient" + ) + @patch("awslambdaric.lambda_multi_concurrent_utils.bootstrap") + def test_run_single_creates_client_and_calls_bootstrap( + self, mock_bootstrap, mock_client_cls + ): + mock_client = MagicMock() + mock_client_cls.return_value = mock_client + + # stub out redirect + with patch.object(MultiConcurrentRunner, "_redirect_output"): + MultiConcurrentRunner.run_single("h.fn", "addr", True, "/socket") + + mock_client_cls.assert_called_once_with("addr", True) + mock_bootstrap.run.assert_called_once_with("h.fn", mock_client) + + @patch("multiprocessing.Process") + def test_run_concurrent_spawns_and_joins(self, mock_process): + fake_proc = MagicMock() + mock_process.return_value = fake_proc + + MultiConcurrentRunner.run_concurrent( + "h", "a", False, "/sock", max_concurrency=3 + ) + + self.assertEqual(mock_process.call_count, 3) + self.assertEqual(fake_proc.start.call_count, 3) + self.assertEqual(fake_proc.join.call_count, 3) + + for call_args in mock_process.call_args_list: + target = call_args.kwargs.get("target") or call_args[1].get("target") + args = call_args.kwargs.get("args") or call_args[1].get("args") + self.assertEqual(target, MultiConcurrentRunner.run_single) + self.assertEqual(args, ("h", "a", False, "/sock")) + + @patch( + "awslambdaric.lambda_multi_concurrent_utils.LambdaMultiConcurrentRuntimeClient" + ) + @patch("awslambdaric.lambda_multi_concurrent_utils.bootstrap") + def test_run_single_skips_redirect_when_socket_path_is_none( + self, mock_bootstrap, mock_client_cls + ): + mock_client = MagicMock() + mock_client_cls.return_value = mock_client + + with patch.object(MultiConcurrentRunner, "_redirect_output") as mock_redirect: + MultiConcurrentRunner.run_single("h.fn", "addr", True, None) + + # Verify _redirect_output was not called + mock_redirect.assert_not_called() + + # Verify client and bootstrap are still called normally + mock_client_cls.assert_called_once_with("addr", True) + mock_bootstrap.run.assert_called_once_with("h.fn", mock_client) + + @patch( + "awslambdaric.lambda_multi_concurrent_utils.LambdaMultiConcurrentRuntimeClient" + ) + @patch("awslambdaric.lambda_multi_concurrent_utils.bootstrap") + def test_run_single_calls_redirect_when_socket_path_is_provided( + self, mock_bootstrap, mock_client_cls + ): + """Test that _redirect_output is called when socket_path is provided""" + mock_client = MagicMock() + mock_client_cls.return_value = mock_client + + with patch.object(MultiConcurrentRunner, "_redirect_output") as mock_redirect: + MultiConcurrentRunner.run_single("h.fn", "addr", True, "/valid/socket/path") + + # Verify _redirect_output was called with the socket path + mock_redirect.assert_called_once_with("/valid/socket/path") + + # Verify client and bootstrap are still called normally + mock_client_cls.assert_called_once_with("addr", True) + mock_bootstrap.run.assert_called_once_with("h.fn", mock_client) + + +if __name__ == "__main__": + unittest.main()