diff --git a/src/vws/query.py b/src/vws/query.py index 6c2dc481e..b9d42d4c8 100644 --- a/src/vws/query.py +++ b/src/vws/query.py @@ -49,16 +49,20 @@ def __init__( client_access_key: str, client_secret_key: str, base_vwq_url: str = "https://cloudreco.vuforia.com", + request_timeout_seconds: float = 30.0, ) -> None: """ Args: client_access_key: A VWS client access key. client_secret_key: A VWS client secret key. base_vwq_url: The base URL for the VWQ API. + request_timeout_seconds: The timeout in seconds for each HTTP + request made to the Cloud Reco API. """ self._client_access_key = client_access_key self._client_secret_key = client_secret_key self._base_vwq_url = base_vwq_url + self.request_timeout_seconds = request_timeout_seconds def query( self, @@ -141,8 +145,7 @@ def query( url=urljoin(base=self._base_vwq_url, url=request_path), headers=headers, data=content, - # We should make the timeout customizable. - timeout=30, + timeout=self.request_timeout_seconds, ) response = Response( text=requests_response.text, diff --git a/src/vws/vws.py b/src/vws/vws.py index 9b3d19b4e..82eb7c5bc 100644 --- a/src/vws/vws.py +++ b/src/vws/vws.py @@ -68,6 +68,7 @@ def _target_api_request( data: bytes, request_path: str, base_vws_url: str, + request_timeout_seconds: float, ) -> Response: """Make a request to the Vuforia Target API. @@ -82,6 +83,7 @@ def _target_api_request( request_path: The path to the endpoint which will be used in the request. base_vws_url: The base URL for the VWS API. + request_timeout_seconds: The timeout in seconds for the request. Returns: The response to the request made by `requests`. @@ -111,8 +113,7 @@ def _target_api_request( url=url, headers=headers, data=data, - # We should make the timeout customizable. - timeout=30, + timeout=request_timeout_seconds, ) return Response( @@ -134,16 +135,20 @@ def __init__( server_access_key: str, server_secret_key: str, base_vws_url: str = "https://vws.vuforia.com", + request_timeout_seconds: float = 30.0, ) -> None: """ Args: server_access_key: A VWS server access key. server_secret_key: A VWS server secret key. base_vws_url: The base URL for the VWS API. + request_timeout_seconds: The timeout in seconds for each HTTP + request made to the VWS API. """ self._server_access_key = server_access_key self._server_secret_key = server_secret_key self._base_vws_url = base_vws_url + self.request_timeout_seconds = request_timeout_seconds def make_request( self, @@ -187,6 +192,7 @@ def make_request( data=data, request_path=request_path, base_vws_url=self._base_vws_url, + request_timeout_seconds=self.request_timeout_seconds, ) if ( diff --git a/tests/test_query.py b/tests/test_query.py index a79a712c0..85b2d03ab 100644 --- a/tests/test_query.py +++ b/tests/test_query.py @@ -1,9 +1,13 @@ """Tests for the ``CloudRecoService`` querying functionality.""" import io +import time import uuid from typing import BinaryIO +from unittest.mock import patch +import pytest +import requests from mock_vws import MockVWS from mock_vws.database import VuforiaDatabase @@ -42,6 +46,154 @@ def test_match( assert matching_target.target_id == target_id +class TestCustomRequestTimeout: + """Tests for using a custom request timeout.""" + + @staticmethod + def test_default_timeout() -> None: + """By default, the request timeout is 30 seconds.""" + default_timeout_seconds = 30.0 + with MockVWS() as mock: + database = VuforiaDatabase() + mock.add_database(database=database) + cloud_reco_client = CloudRecoService( + client_access_key=database.client_access_key, + client_secret_key=database.client_secret_key, + ) + expected = default_timeout_seconds + assert cloud_reco_client.request_timeout_seconds == expected + + @staticmethod + def test_custom_timeout(image: io.BytesIO | BinaryIO) -> None: + """It is possible to set a custom request timeout.""" + with MockVWS() as mock: + database = VuforiaDatabase() + mock.add_database(database=database) + vws_client = VWS( + server_access_key=database.server_access_key, + server_secret_key=database.server_secret_key, + ) + custom_timeout = 60.5 + cloud_reco_client = CloudRecoService( + client_access_key=database.client_access_key, + client_secret_key=database.client_secret_key, + request_timeout_seconds=custom_timeout, + ) + assert cloud_reco_client.request_timeout_seconds == custom_timeout + + # Verify requests work with the custom timeout + target_id = vws_client.add_target( + name="x", + width=1, + image=image, + active_flag=True, + application_metadata=None, + ) + vws_client.wait_for_target_processed(target_id=target_id) + matches = cloud_reco_client.query(image=image) + assert len(matches) == 1 + + @staticmethod + def test_timeout_raises_on_slow_response( + image: io.BytesIO | BinaryIO, + ) -> None: + """A short timeout raises an error when the server is slow.""" + with MockVWS() as mock: + database = VuforiaDatabase() + mock.add_database(database=database) + vws_client = VWS( + server_access_key=database.server_access_key, + server_secret_key=database.server_secret_key, + ) + cloud_reco_client = CloudRecoService( + client_access_key=database.client_access_key, + client_secret_key=database.client_secret_key, + request_timeout_seconds=0.1, + ) + + target_id = vws_client.add_target( + name="x", + width=1, + image=image, + active_flag=True, + application_metadata=None, + ) + vws_client.wait_for_target_processed(target_id=target_id) + + simulated_slow_threshold = 0.5 + original_request = requests.request + + def slow_request( + *args: object, + **kwargs: float | None, + ) -> requests.Response: + """Simulate a slow server response.""" + timeout = kwargs.get("timeout") + if timeout is not None and timeout < simulated_slow_threshold: + time.sleep(0.2) + raise requests.exceptions.Timeout + return original_request(*args, **kwargs) # type: ignore[arg-type] # pyright: ignore[reportArgumentType] + + with ( + patch.object( + target=requests, + attribute="request", + side_effect=slow_request, + ), + pytest.raises(expected_exception=requests.exceptions.Timeout), + ): + cloud_reco_client.query(image=image) + + @staticmethod + def test_longer_timeout_succeeds(image: io.BytesIO | BinaryIO) -> None: + """A longer timeout allows slow responses to complete.""" + simulated_slow_threshold = 0.5 + + with MockVWS() as mock: + database = VuforiaDatabase() + mock.add_database(database=database) + vws_client = VWS( + server_access_key=database.server_access_key, + server_secret_key=database.server_secret_key, + ) + cloud_reco_client = CloudRecoService( + client_access_key=database.client_access_key, + client_secret_key=database.client_secret_key, + request_timeout_seconds=1.0, + ) + + target_id = vws_client.add_target( + name="x", + width=1, + image=image, + active_flag=True, + application_metadata=None, + ) + vws_client.wait_for_target_processed(target_id=target_id) + + original_request = requests.request + + def slow_request( + *args: object, + **kwargs: float | None, + ) -> requests.Response: + """Simulate a slow server response.""" + timeout = kwargs.get("timeout") + if timeout is not None and timeout < simulated_slow_threshold: + time.sleep(0.2) + raise requests.exceptions.Timeout + return original_request(*args, **kwargs) # type: ignore[arg-type] # pyright: ignore[reportArgumentType] + + with patch.object( + target=requests, + attribute="request", + side_effect=slow_request, + ): + # This should succeed because timeout is 1.0 > 0.5 + matches = cloud_reco_client.query(image=image) + assert len(matches) == 1 + + class TestCustomBaseVWQURL: """Tests for using a custom base VWQ URL.""" diff --git a/tests/test_vws.py b/tests/test_vws.py index 00936bde8..bbf157b4b 100644 --- a/tests/test_vws.py +++ b/tests/test_vws.py @@ -4,10 +4,13 @@ import datetime import io import secrets +import time import uuid from typing import BinaryIO +from unittest.mock import patch import pytest +import requests from freezegun import freeze_time from mock_vws import MockVWS from mock_vws.database import VuforiaDatabase @@ -92,6 +95,133 @@ def test_add_two_targets( ) +class TestCustomRequestTimeout: + """Tests for using a custom request timeout.""" + + @staticmethod + def test_default_timeout() -> None: + """By default, the request timeout is 30 seconds.""" + default_timeout_seconds = 30.0 + with MockVWS() as mock: + database = VuforiaDatabase() + mock.add_database(database=database) + vws_client = VWS( + server_access_key=database.server_access_key, + server_secret_key=database.server_secret_key, + ) + expected = default_timeout_seconds + assert vws_client.request_timeout_seconds == expected + + @staticmethod + def test_custom_timeout(image: io.BytesIO | BinaryIO) -> None: + """It is possible to set a custom request timeout.""" + with MockVWS() as mock: + database = VuforiaDatabase() + mock.add_database(database=database) + custom_timeout = 60.5 + vws_client = VWS( + server_access_key=database.server_access_key, + server_secret_key=database.server_secret_key, + request_timeout_seconds=custom_timeout, + ) + assert vws_client.request_timeout_seconds == custom_timeout + + # Verify requests work with the custom timeout + vws_client.add_target( + name="x", + width=1, + image=image, + active_flag=True, + application_metadata=None, + ) + + @staticmethod + def test_timeout_raises_on_slow_response( + image: io.BytesIO | BinaryIO, + ) -> None: + """A short timeout raises an error when the server is slow.""" + simulated_slow_threshold = 0.5 + + with MockVWS() as mock: + database = VuforiaDatabase() + mock.add_database(database=database) + vws_client = VWS( + server_access_key=database.server_access_key, + server_secret_key=database.server_secret_key, + request_timeout_seconds=0.1, + ) + + original_request = requests.request + + def slow_request( + *args: object, + **kwargs: float | None, + ) -> requests.Response: + """Simulate a slow server response.""" + timeout = kwargs.get("timeout") + if timeout is not None and timeout < simulated_slow_threshold: + time.sleep(0.2) + raise requests.exceptions.Timeout + return original_request(*args, **kwargs) # type: ignore[arg-type] # pyright: ignore[reportArgumentType] + + with ( + patch.object( + target=requests, + attribute="request", + side_effect=slow_request, + ), + pytest.raises(expected_exception=requests.exceptions.Timeout), + ): + vws_client.add_target( + name="x", + width=1, + image=image, + active_flag=True, + application_metadata=None, + ) + + @staticmethod + def test_longer_timeout_succeeds(image: io.BytesIO | BinaryIO) -> None: + """A longer timeout allows slow responses to complete.""" + simulated_slow_threshold = 0.5 + + with MockVWS() as mock: + database = VuforiaDatabase() + mock.add_database(database=database) + vws_client = VWS( + server_access_key=database.server_access_key, + server_secret_key=database.server_secret_key, + request_timeout_seconds=1.0, + ) + + original_request = requests.request + + def slow_request( + *args: object, + **kwargs: float | None, + ) -> requests.Response: + """Simulate a slow server response.""" + timeout = kwargs.get("timeout") + if timeout is not None and timeout < simulated_slow_threshold: + time.sleep(0.2) + raise requests.exceptions.Timeout + return original_request(*args, **kwargs) # type: ignore[arg-type] # pyright: ignore[reportArgumentType] + + with patch.object( + target=requests, + attribute="request", + side_effect=slow_request, + ): + # This should succeed because timeout is 1.0 > 0.5 + vws_client.add_target( + name="x", + width=1, + image=image, + active_flag=True, + application_metadata=None, + ) + + class TestCustomBaseVWSURL: """Tests for using a custom base VWS URL."""