diff --git a/faust/__init__.py b/faust/__init__.py index c20b05903..f3ef14c75 100644 --- a/faust/__init__.py +++ b/faust/__init__.py @@ -1,5 +1,6 @@ # -*- coding: utf-8 -*- """Python Stream processing.""" + # :copyright: (c) 2017-2020, Robinhood Markets, Inc. # All rights reserved. # :license: BSD (3 Clause), see LICENSE for more details. diff --git a/faust/stores/__init__.py b/faust/stores/__init__.py index 56cc0dc61..dd3f0feb9 100644 --- a/faust/stores/__init__.py +++ b/faust/stores/__init__.py @@ -12,6 +12,7 @@ memory="faust.stores.memory:Store", rocksdb="faust.stores.rocksdb:Store", aerospike="faust.stores.aerospike:AeroSpikeStore", + bigtable="faust.stores.bigtable:BigTableStore", ) STORES.include_setuptools_namespace("faust.stores") by_name = STORES.by_name diff --git a/faust/stores/aerospike.py b/faust/stores/aerospike.py index 291ccbe9f..198f78dc7 100644 --- a/faust/stores/aerospike.py +++ b/faust/stores/aerospike.py @@ -98,7 +98,7 @@ def _get(self, key: bytes) -> Optional[bytes]: key = (self.namespace, self.table_name, key) fun = self.client.get try: - (key, meta, bins) = self.aerospike_fun_call_with_retry(fun=fun, key=key) + key, meta, bins = self.aerospike_fun_call_with_retry(fun=fun, key=key) if bins: return bins[self.BIN_KEY] return None @@ -173,7 +173,7 @@ def _itervalues(self) -> Iterator[bytes]: fun=fun, namespace=self.namespace, set=self.table_name ) for result in scan.results(): - (key, meta, bins) = result + key, meta, bins = result if bins: yield bins[self.BIN_KEY] else: @@ -193,8 +193,8 @@ def _iteritems(self) -> Iterator[Tuple[bytes, bytes]]: fun=fun, namespace=self.namespace, set=self.table_name ) for result in scan.results(): - (key_data, meta, bins) = result - (ns, set, policy, key) = key_data + key_data, meta, bins = result + ns, set, policy, key = key_data if bins: bins = bins[self.BIN_KEY] @@ -214,7 +214,7 @@ def _contains(self, key: bytes) -> bool: try: if self.app.conf.store_check_exists: key = (self.namespace, self.table_name, key) - 
(key, meta) = self.aerospike_fun_call_with_retry( + key, meta = self.aerospike_fun_call_with_retry( fun=self.client.exists, key=key ) if meta: diff --git a/faust/stores/bigtable.py b/faust/stores/bigtable.py new file mode 100644 index 000000000..ded7ed76d --- /dev/null +++ b/faust/stores/bigtable.py @@ -0,0 +1,630 @@ +"""BigTable storage.""" + +from __future__ import annotations + +import gc +import logging +import threading +import time +import traceback +from typing import ( + Any, + Callable, + Dict, + Iterable, + Iterator, + List, + Optional, + Set, + Tuple, + Union, +) + +from mode.utils.collections import LRUCache + +try: # pragma: no cover + from google.api_core.exceptions import AlreadyExists + from google.cloud.bigtable import column_family + from google.cloud.bigtable.batcher import MutationsBatcher + from google.cloud.bigtable.client import Client + from google.cloud.bigtable.instance import Instance + from google.cloud.bigtable.row import DirectRow + from google.cloud.bigtable.row_filters import CellsColumnLimitFilter + from google.cloud.bigtable.row_set import RowSet + from google.cloud.bigtable.table import Table + + # Make one container for all imported functions + # This is needed for testing and controlling the imports + class BT: + column_family = column_family + Client = Client + Instance = Instance + DirectRow = DirectRow + CellsColumnLimitFilter = CellsColumnLimitFilter + RowSet = RowSet + Table = Table + +except ImportError as e: # pragma: no cover + logger = logging.getLogger(__name__).error(e) + BT = None # noqa + +from yarl import URL + +from faust.stores import base +from faust.streams import current_event +from faust.types import TP, AppT, CollectionT, EventT + +COLUMN_FAMILY_ID = "FaustColumnFamily" +COLUMN_NAME = "DATA" + + +class BigTableStore(base.SerializedStore): + """Bigtable table storage.""" + + client: BT.Client + instance: BT.Instance + bt_table: BT.Table + + BT_COLUMN_NAME_KEY = "bt_column_name_key" + BT_INSTANCE_KEY = 
"bt_instance_key" + BT_OFFSET_KEY_PREFIX = "bt_offset_key_prefix" + BT_PROJECT_KEY = "bt_project_key" + BT_TABLE_NAME_GENERATOR_KEY = "bt_table_name_generator_key" + BT_STARTUP_CACHE_ENABLE_KEY = "bt_startup_cache_enable_key" + BT_STARTUP_CACHE_TTL_KEY = "bt_startup_cache_ttl_key" + BT_MUTATION_BATCHER_ENABLE_KEY = "bt_mutation_batcher_enable_key" + BT_MUTATION_BATCHER_FLUSH_COUNT_KEY = "bt_mutation_batcher_flush_count_key" + BT_MUTATION_BATCHER_FLUSH_INTERVAL_KEY = "bt_mutation_batcher_flush_interval_key" + + def __init__( + self, + url: Union[str, URL], + app: AppT, + table: CollectionT, + options: Dict[str, Any], + **kwargs: Any, + ) -> None: + self._set_options(options) + try: + self._setup_bigtable(table, options) + self._setup_caches(options) + self._setup_mutation_batcher(options) + self.key_index_size = app.conf.table_key_index_size + self._key_index = LRUCache(limit=self.key_index_size) + except Exception as ex: # pragma: no cover + logging.getLogger(__name__).error(f"Error in Bigtable init {ex}") + raise ex + super().__init__(url, app, table, **kwargs) + + @staticmethod + def default_translator(user_key): + return user_key + + def _setup_mutation_batcher(self, options): + self._mutation_batcher_enable = options.get( + BigTableStore.BT_MUTATION_BATCHER_ENABLE_KEY, False + ) + self._mutation_batcher_cache = {} + if self._mutation_batcher_enable: + flush_count = options.get( + BigTableStore.BT_MUTATION_BATCHER_FLUSH_COUNT_KEY, 10_000 + ) + flush_interval = options.get( + BigTableStore.BT_MUTATION_BATCHER_FLUSH_INTERVAL_KEY, 300 + ) + self._mutation_batcher = MutationsBatcher( + self.bt_table, + flush_count=flush_count, + flush_interval=flush_interval, + batch_completed_callback=lambda x: self._mutation_batcher_cache.clear(), + ) + + def _setup_caches( + self, + options: Dict[str, Any] = None, + ): + self._startup_cache_enable = options.get( + BigTableStore.BT_STARTUP_CACHE_ENABLE_KEY, False + ) + + self._startup_cache = None + self._startup_cache_partitions: 
def _remove_partition_prefix_from_bigtable_key(self, key: bytes) -> bytes:
    """Return the user key stored inside a Bigtable row key.

    Row keys are built by ``_add_partition_prefix_to_key`` as
    ``<partition>_..._<user key>``.

    Arguments:
        key: Row key as stored in Bigtable.

    Returns:
        Everything after the first ``_..._`` separator, or ``key``
        unchanged when it carries no separator (e.g. offset keys).
    """
    separator = b"_..._"
    # Split on the FIRST separator.  The partition prefix consists of
    # digits only, so the first occurrence always ends the prefix, while
    # the user key itself may legitimately contain the separator bytes.
    # (Splitting from the right would cut such user keys short.)
    return key.split(separator, 1)[-1]


def _get_partition_from_bigtable_key(self, key: bytes) -> int:
    """Return the partition number encoded at the front of a row key.

    Arguments:
        key: Row key of the form ``<partition>_..._<user key>``.

    Returns:
        The partition as an integer.

    Raises:
        ValueError: If ``key`` has no ``_..._`` separator or the prefix
            in front of it is not an integer.
    """
    separator = b"_..._"
    # Same reasoning as above: only the first separator delimits the
    # partition prefix, later occurrences belong to the user key.
    partition_bytes, _ = key.split(separator, 1)
    return int(partition_bytes)
def _set(self, key: bytes, value: bytes) -> None:
    """Store ``value`` under ``key`` for the partition of the current event.

    The value is written to the startup cache (if any) and to Bigtable
    under the partition-prefixed row key, and the partition is remembered
    in the key index.

    Arguments:
        key: Serialized user key (without partition prefix).
        value: Serialized value.

    Raises:
        AssertionError: If there is no current event to take the
            partition from.
        Exception: Any error raised while writing is logged and re-raised.
    """
    try:
        self._set_cache(key, value)

        event = current_event()
        assert event is not None
        partition = event.message.partition
        bt_key = self._add_partition_prefix_to_key(key, partition)

        self._bigtable_set(bt_key, value)
        # Index by the *user* key.  ``_get`` and ``_get_partitions_for_key``
        # query the index with the unprefixed key, so storing the prefixed
        # row key here would create entries that are never found.
        self._key_index[key] = partition
    except Exception as ex:  # pragma: no cover
        self.log.error(
            f"FaustBigtableException Error in set for "
            f"table {self.table_name} exception {ex} key {key=} "
            f"{value=} Traceback: {traceback.format_exc()}"
        )
        # Bare ``raise`` keeps the original traceback intact.
        raise
def _iteritems(
    self, partitions: Optional[List[int]] = None
) -> Iterator[Tuple[bytes, bytes]]:
    """Iterate over ``(key, value)`` pairs of this store.

    When the startup cache is enabled, all non-deleted cache entries are
    yielded first and only the partitions that are not covered by the
    cache are read from Bigtable afterwards.

    Arguments:
        partitions: Partitions to iterate.  ``None`` means the active
            partitions when the startup cache is enabled; otherwise the
            value is handed to ``_bigtable_iteritems`` as is.

    Yields:
        Tuples of ``(key, value)``.
    """
    if self._startup_cache is not None:
        requested = self._active_partitions() if partitions is None else partitions
        # Always work on a private set: the argument is annotated as a
        # list (which has no ``difference_update``) and a caller-owned
        # set must not be modified behind the caller's back.
        remaining = set(requested)
        for cached_key, cached_value in self._startup_cache.items():
            # ``None`` marks a key that was deleted while cached.
            if cached_value is not None:
                yield cached_key, cached_value
        remaining.difference_update(self._startup_cache_partitions)
        partitions = remaining

    if partitions is None or len(partitions) > 0:
        yield from self._bigtable_iteritems(partitions)
def apply_changelog_batch(
    self,
    batch: Iterable[EventT],
    to_key: Callable[[Any], Any],
    to_value: Callable[[Any], Any],
) -> None:
    """Replay a batch of changelog events into the Bigtable store.

    Every event is applied to the startup cache and to Bigtable (a
    ``None`` value is a delete, anything else a write).  Afterwards the
    highest offset seen per topic partition is persisted.

    Arguments:
        batch: Iterable of changelog events (:class:`faust.Event`).
        to_key: Callable to deserialize the key of a changelog event
            (not used by this store).
        to_value: Callable to deserialize the value of a changelog event
            (not used by this store).
    """
    highest_offsets: Dict[TP, int] = {}
    for changelog_event in batch:
        message = changelog_event.message

        # Track the largest offset per topic partition.
        topic_partition = message.tp
        previous = highest_offsets.get(topic_partition)
        if previous is None or message.offset > previous:
            highest_offsets[topic_partition] = message.offset

        row_key = self._add_partition_prefix_to_key(message.key, message.partition)
        if message.value is not None:
            self._set_cache(message.key, message.value)
            self._bigtable_set(row_key, message.value)
        else:
            # Tombstone: remove the key from cache and table.
            self._del_cache(message.key)
            self._bigtable_del(row_key)

    for topic_partition, highest in highest_offsets.items():
        self.set_persisted_offset(topic_partition, highest)
def _fill_caches(self, partitions):
    """Load all rows of ``partitions`` into the startup cache.

    After loading, the partitions are recorded as cached and the
    invalidation timer is (re)armed: a running timer is cancelled and,
    if ``_startup_cache_ttl`` is positive, a new one is started that
    calls ``_invalidate_startup_cache`` after that many seconds.

    Arguments:
        partitions: Iterable of partition numbers to preload.
    """
    for key, value in self._bigtable_iteritems(partitions=partitions):
        self._set_cache(key, value)

    self._startup_cache_partitions |= set(partitions)

    # Reset a timer that is already running so the TTL counts from the
    # most recent fill.
    if self._invalidation_timer is not None:
        self._invalidation_timer.cancel()
        self._invalidation_timer = None

    # A TTL <= 0 keeps the cache forever, so no timer is needed.
    if self._startup_cache_ttl > 0:
        timer = threading.Timer(
            self._startup_cache_ttl, self._invalidate_startup_cache
        )
        # Daemon thread: a pending cache invalidation only touches
        # in-memory state and must not keep the process alive on exit.
        timer.daemon = True
        self._invalidation_timer = timer
        timer.start()
async def stop(self) -> None:
    """Flush pending mutations and stop the store.

    Buffered mutations are flushed to Bigtable when the mutation batcher
    is enabled and a pending startup-cache invalidation timer is
    cancelled.  The base class ``stop()`` is always awaited afterwards,
    even if flushing fails, so that overriding this method does not skip
    the inherited shutdown logic.
    """
    try:
        if self._mutation_batcher_enable:
            self.log.info("Flushing to bigtable on stop")
            self._mutation_batcher.flush()

        # The timer attribute only exists when the startup cache is
        # enabled, hence the guarded lookup.
        timer = getattr(self, "_invalidation_timer", None)
        if timer is not None:
            timer.cancel()
    finally:
        # NOTE(review): presumably resolves to the mode ``Service.stop``
        # inherited through ``base.SerializedStore`` — confirm.
        await super().stop()
def _producer_accepts_api_version(self) -> bool:
    """Tell whether ``AIOKafkaProducer.__init__`` has an ``api_version`` parameter.

    Returns:
        ``True`` if ``api_version`` is among the parameters reported by
        :func:`inspect.signature` for the producer constructor.
    """
    init_signature = inspect.signature(aiokafka.AIOKafkaProducer.__init__)
    return "api_version" in init_signature.parameters
def from_bt_key(key):
    """Strip everything in front of the four-zero-byte marker from ``key``.

    Arguments:
        key: Bytes key, optionally carrying a prefix in front of the
            first run of four zero bytes.

    Returns:
        ``key`` starting at the marker.  The key is returned unchanged
        when it already starts with the marker or does not contain it.
    """
    magic_byte_pos = key.find(bytes(4))
    # ``find`` returns -1 when the marker is missing; slicing with -1
    # would return only the last byte, so treat "not found" like
    # "nothing to strip".
    if magic_byte_pos <= 0:
        return key
    return key[magic_byte_pos:]
@pytest.mark.asyncio
async def test_bigtable_set_options(self, bt_imports):
    """Explicitly passed options are stored by ``_set_options``."""
    self_mock = MagicMock()
    bt_imports.CellsColumnLimitFilter = MagicMock(return_value="a_filter")

    def name_generator(table):
        # Only the identity of this callable is checked below.
        return table.name

    options = {
        BigTableStore.BT_TABLE_NAME_GENERATOR_KEY: name_generator,
        BigTableStore.BT_OFFSET_KEY_PREFIX: "offset_test",
    }
    BigTableStore._set_options(self_mock, options)
    assert self_mock.offset_key_prefix == "offset_test"
    assert self_mock.row_filter == "a_filter"
    assert self_mock.table_name_generator == name_generator
table_mock.create.assert_not_called() + assert return_value is None + + # Test with no existing table + self_mock.reset_mock() + self_mock.table_name_generator = table_name_gen + self_mock.bt_table_name = self_mock.table_name_generator(faust_table_mock) + table_mock.exists = MagicMock(return_value=False) + return_value = BigTableStore._setup_bigtable( + self_mock, faust_table_mock, options + ) + instance_mock.table.assert_called_once_with(self_mock.bt_table_name) + table_mock.create.assert_called_once_with( + column_families={"FaustColumnFamily": "a_rule"} + ) + assert return_value is None + + @pytest.fixture() + def store(self, bt_imports): + with patch("faust.stores.bigtable.BT", bt_imports): + options = {} + options[BigTableStore.BT_INSTANCE_KEY] = "bt_instance" + options[BigTableStore.BT_PROJECT_KEY] = "bt_project" + store = BigTableStore( + "bigtable://", MagicMock(), MagicMock(), options=options + ) + store.bt_table = BigTableMock() + return store + + def test_bigtable_get(self, store, bt_imports): + keys = [self.TEST_KEY1, self.TEST_KEY2] + for idx, k in enumerate(keys): + keys[idx] = store._add_partition_prefix_to_key(k, 2) + store.bt_table.add_test_data(keys) + + # Test get from bigtable + value, partition = store._bigtable_get([keys[1]]) + store.bt_table.read_rows.assert_called_once() + assert partition == 2 + assert value == keys[1] + + # Test get from mutation buffer + store._mutation_batcher_enable = True + store._mutation_batcher_cache = {keys[1]: b"123"} + value, partition = store._bigtable_get([keys[1]]) + store.bt_table.read_rows.assert_called_once() + assert value == b"123" + assert partition == 2 + + def test_bigtable_get_on_empty(self, store, bt_imports): + return_value = store._bigtable_get([self.TEST_KEY1, self.TEST_KEY2]) + store.bt_table.read_rows.assert_called_once() + assert return_value == (None, None) + + def test_bigtable_delete(self, store): + row_mock = MagicMock() + row_mock.commit = MagicMock() + row_mock.delete = MagicMock() + 
store.bt_table.direct_row = MagicMock(return_value=row_mock) + store._set_mutation = MagicMock() + + store._bigtable_del(self.TEST_KEY1) + store._set_mutation.assert_not_called() + row_mock.delete.assert_called_once() + row_mock.commit.assert_called_once() + + # Test with mutation buffer + store._mutation_batcher_enable = True + store._bigtable_del(self.TEST_KEY1) + store._set_mutation.assert_called_once_with(self.TEST_KEY1, None, row_mock) + assert row_mock.delete.call_count == 2 + assert row_mock.commit.call_count == 1 + + def test_bigtable_set(self, store): + row_mock = MagicMock() + row_mock.set_cell = MagicMock() + row_mock.commit = MagicMock() + store.bt_table.direct_row = MagicMock(return_value=row_mock) + store._set_mutation = MagicMock() + + store._bigtable_set(self.TEST_KEY1, b"a_value") + store._set_mutation.assert_not_called() + row_mock.set_cell.assert_called_once() + row_mock.commit.assert_called_once() + + # Test with mutation buffer + store._mutation_batcher_enable = True + store._bigtable_set(self.TEST_KEY1, "a_value") + store._set_mutation.assert_called_once_with(self.TEST_KEY1, "a_value", row_mock) + assert row_mock.set_cell.call_count == 2 + assert row_mock.commit.call_count == 1 + + def test_get_partition_from_message(self, store): + event_mock = MagicMock() + event_mock.message = MagicMock() + event_mock.message.partition = 69 + current_event_mock = MagicMock(return_value=event_mock) + + store.table.is_global = False + store.table.use_partitioner = False + topic = store.table.changelog_topic_name + store.app.assignor.assigned_actives = MagicMock( + return_value={TP(topic, 123), TP(topic, 69)} + ) + store.app.conf.topic_partitions = 123 + with patch("faust.stores.bigtable.current_event", current_event_mock): + return_value = store._get_current_partitions() + assert return_value == [69] + + store.table.is_global = True + with patch("faust.stores.bigtable.current_event", current_event_mock): + return_value = store._get_current_partitions() + 
assert return_value == list(range(123)) + + store.table.is_global = False + current_event_mock = MagicMock(return_value=None) + + with patch("faust.stores.bigtable.current_event", current_event_mock): + return_value = store._get_current_partitions() + assert return_value == [69] + + def test_get_faust_key(self, store): + key_with_partition = b"\x13_..._THEACTUALKEY" + res = store._remove_partition_prefix_from_bigtable_key(key_with_partition) + assert res == b"THEACTUALKEY" + + def test_get_key_with_partition(self, store): + partition = 19 + res = store._add_partition_prefix_to_key(self.TEST_KEY1, partition) + extracted_partition = store._get_partition_from_bigtable_key(res) + assert extracted_partition == partition + assert store._remove_partition_prefix_from_bigtable_key(res) == self.TEST_KEY1 + + def test_partitions_for_key(self, store): + store._get_current_partitions = MagicMock(return_value=[19]) + res = list(store._get_partitions_for_key(self.TEST_KEY1)) + assert res == [19] + + def test_get_keyerror(self, store): + partition = None + store._get_current_partitions = MagicMock(return_value=[partition]) + store._bigtable_get = MagicMock(return_value=(None, None)) + with pytest.raises(KeyError): + key = "123" + store[key] + + def test_get_with_known_partition(self, store): + partitions = [19, 20] + store._get_cache = MagicMock(return_value=(b"this is ignored", False)) + store._key_index = {} + store._get_current_partitions = MagicMock(return_value=partitions) + # Scenario: Found + store._bigtable_get = MagicMock(return_value=(b"a_value", 19)) + + res = store._get(self.TEST_KEY1) + get_keys = [ + store._add_partition_prefix_to_key(self.TEST_KEY1, p) for p in partitions + ] + store._bigtable_get.assert_called_once_with(get_keys) + assert res == b"a_value" + + # Scenario: Not Found + store._bigtable_get = MagicMock(return_value=(None, None)) + res = store._get(self.TEST_KEY1) + store._bigtable_get.assert_called_with( + [get_keys[0]] + ) # because the partition is 
in key_index + assert res is None + + # Scenario: Cache hit on value + store._get_cache = MagicMock(return_value=(b"a_value_from_cache", True)) + store._bigtable_get = MagicMock(return_value=(None, None)) + res = store._get(self.TEST_KEY1) + store._bigtable_get.assert_not_called() + assert res == b"a_value_from_cache" + + # Scenario: Cache hit on None value + store._get_cache = MagicMock(return_value=(None, True)) + store._bigtable_get = MagicMock(return_value=(None, None)) + res = store._get(self.TEST_KEY2) + store._bigtable_get.assert_not_called() + assert res is None + + # Scenario: Cache miss, but partition should be in startup cache + store._startup_cache_partitions = {19, 20} + store._get_cache = MagicMock(return_value=(None, False)) + store._bigtable_get = MagicMock(return_value=(None, None)) + res = store._get(self.TEST_KEY2) + store._bigtable_get.assert_not_called() + assert res is None + + def test_set(self, store): + # Scenario: No cache + event_mock = MagicMock() + event_mock.message = MagicMock() + event_mock.message.partition = 69 + current_event_mock = MagicMock(return_value=event_mock) + no_event_mock = MagicMock(return_value=None) + + # Test assertion withour current event + with patch("faust.stores.bigtable.current_event", no_event_mock): + with pytest.raises(AssertionError): + store["123"] = "000" + + with patch("faust.stores.bigtable.current_event", current_event_mock): + store._key_index = {} + store._set_cache = MagicMock() + store._bigtable_set = MagicMock() + store._set(self.TEST_KEY1, b"a_value") + + key = store._add_partition_prefix_to_key(self.TEST_KEY1, 69) + store._set_cache.assert_called_with(self.TEST_KEY1, b"a_value") + store._bigtable_set.assert_called_once_with(key, b"a_value") + + def test_del(self, store): + # Scenario: No cache + store._bigtable_del = MagicMock() + store._del_cache = MagicMock(return_value=None) + store._get_partitions_for_key = MagicMock(return_value=[1, 2, 3]) + store._del(self.TEST_KEY1) + # Check one call 
for each partition + keys = [ + store._add_partition_prefix_to_key(self.TEST_KEY1, p) for p in [1, 2, 3] + ] + store._del_cache.assert_called_once_with(self.TEST_KEY1) + assert store._bigtable_del.call_count == 3 + expected_calls = [call(key) for key in keys] + for call_args in store._bigtable_del.call_args_list: + assert call_args in expected_calls + + def test_active_partitions(self, store): + active_topics = [ + TP("a_changelogtopic", 19), + TP("a_different_chaneglogtopic", 19), + ] + store.app.assignor.assigned_actives = MagicMock(return_value=active_topics) + store.app.conf.topic_partitions = 20 + store.table.changelog_topic_name = "a_changelogtopic" + store.table.is_global = False + + # Scenario: No global table + res = store._active_partitions() + all_res = list(res) + assert all_res == [19] + + # Scenario: Global table + store.table.is_global = True + res = store._active_partitions() + all_res = list(res) + assert list(range(store.app.conf.topic_partitions)) == all_res + + def test_iteritems(self, store): + store._active_partitions = MagicMock(return_value=[1, 3]) + store._bigtable_iteritems = MagicMock(wraps=store._bigtable_iteritems) + store.bt_table.read_rows = MagicMock() + _ = sorted(store._iteritems()) + store.bt_table.read_rows.assert_called_once() + # Calling with None means get all rows + store._bigtable_iteritems.assert_called_once_with(None) + + def test_iteritems_with_startup_cache(self, store, bt_imports): + store._active_partitions = MagicMock(return_value=[1, 3]) + store._startup_cache = { + self.TEST_KEY1: b"this is a value", + self.TEST_KEY2: b"this is another value", + b"Dont return this": None, + } + store._startup_cache_partitions = [1] + + store._bigtable_iteritems = MagicMock(wraps=store._bigtable_iteritems) + store.bt_table.read_rows = MagicMock( + return_value=[ + MagicMock( + row_key=store._add_partition_prefix_to_key(self.TEST_KEY3, 3), + to_dict=MagicMock(return_value={"x": [MagicMock(value=b"1")]}), + commit=MagicMock(), + ), + 
MagicMock( + row_key=store._add_partition_prefix_to_key(self.TEST_KEY4, 3), + to_dict=MagicMock(return_value={"x": [MagicMock(value=b"2")]}), + commit=MagicMock(), + ), + ] + ) + res = sorted(store._iteritems()) + store._bigtable_iteritems.assert_called_once_with({3}) + all_entries = { + self.TEST_KEY1: b"this is a value", + self.TEST_KEY2: b"this is another value", + self.TEST_KEY3: b"1", + self.TEST_KEY4: b"2", + } + assert res == sorted(all_entries.items()) + keys = sorted(store._iterkeys()) + values = sorted(store._itervalues()) + + assert keys == sorted(all_entries.keys()) + assert values == sorted(all_entries.values()) + + def test_iterkeys(self, store): + values = [("K1", "V1"), ("K2", "V2")] + store._iteritems = MagicMock(return_value=values) + all_res = sorted(store._iterkeys()) + assert all_res == ["K1", "K2"] + + def test_itervalues(self, store): + values = [("K1", "V1"), ("K2", "V2")] + store._iteritems = MagicMock(return_value=values) + all_res = sorted(store._itervalues()) + assert all_res == ["V1", "V2"] + + def test_size(self, store): + assert 0 == store._size() + + def test_get_offset_key(self, store): + tp = TP("AAAA", 19) + assert store.get_offset_key(tp)[-2:] == "19" + + def test_set_persisted_offset(self, store): + tp = TP("a_topic", 19) + expected_offset_key = store.get_offset_key(tp).encode() + store._bigtable_set = MagicMock() + + store.set_persisted_offset(tp, 123) + store._bigtable_set.assert_called_once_with(expected_offset_key, b"123") + + def test_apply_changelog_batch(self, store): + row_mock = MagicMock() + row_mock.delete = MagicMock() + row_mock.set_cell = MagicMock() + store.bt_table.direct_row = MagicMock(return_value=row_mock) + store._bigtable_del = MagicMock() + store._bigtable_set = MagicMock() + store.set_persisted_offset = MagicMock() + store._set_cache = MagicMock() + store._del_cache = MagicMock() + + class TestMessage: + def __init__(self, value, key, tp, offset, partition): + self.value = value + self.key = key + self.tp 
= tp + self.offset = offset + self.partition = partition + + class TestEvent: + def __init__(self, message): + self.message = message + + tp = TP("a", 19) + tp2 = TP("b", 19) + messages = [ + TestEvent(TestMessage("a", self.TEST_KEY1, tp, 0, 1)), + TestEvent(TestMessage(None, self.TEST_KEY1, tp, 1, 1)), # Delete + TestEvent(TestMessage("a", self.TEST_KEY1, tp, 3, 1)), # Out of order + TestEvent(TestMessage("b", self.TEST_KEY2, tp2, 4, 2)), + TestEvent(TestMessage("a", self.TEST_KEY1, tp, 2, 1)), + ] + store.apply_changelog_batch(messages, lambda x: x, lambda x: x) + assert store._bigtable_set.call_count == 4 + assert store._bigtable_del.call_count == 1 + assert store.set_persisted_offset.call_count == 2 + + @pytest.mark.asyncio + async def test_fill_caches(self, store, bt_imports): + store._bigtable_iteritems = MagicMock( + return_value=[(b"key1", b"value1"), (b"key2", b"value2")] + ) + store._set_cache = MagicMock() + store._startup_cache_ttl = 1800 + store._invalidation_timer = None + store._startup_cache_partitions = set() + store._startup_cache = {} + + partitions = {TP("topic", 0), TP("topic", 1)} + partitions2 = {TP("topic", 0), TP("topic", 2)} + store._fill_caches(partitions) + + assert store._bigtable_iteritems.call_args == call(partitions=partitions) + assert store._set_cache.call_args_list == [ + call(b"key1", b"value1"), + call(b"key2", b"value2"), + ] + assert store._startup_cache_partitions == partitions + assert store._invalidation_timer is not None + assert store._invalidation_timer.is_alive() + + # Test with different partitions + # This should reset the _invalidation_timer + old_invalid_timer = store._invalidation_timer.__hash__() + + store._bigtable_iteritems = MagicMock( + return_value=[(b"key3", b"value3"), (b"key4", b"value4")] + ) + store._set_cache = MagicMock() + store._fill_caches(partitions2) + new_invalid_timer = store._invalidation_timer.__hash__() + # Check if old invalidation timer is different from new one + assert old_invalid_timer 
!= new_invalid_timer + assert store._invalidation_timer is not None + assert store._invalidation_timer.is_alive() + + assert store._bigtable_iteritems.call_args == call(partitions=partitions2) + assert store._set_cache.call_args_list == [ + call(b"key3", b"value3"), + call(b"key4", b"value4"), + ] + assert store._startup_cache_partitions == partitions | partitions2 + assert store._invalidation_timer is not None + assert store._invalidation_timer.is_alive() + + # Wait for the invalidation timer to expire + store._invalidation_timer.cancel() + store._invalidate_startup_cache() + + assert store._startup_cache == {} + assert store._startup_cache_partitions == set() + assert store._invalidation_timer is None + + @pytest.mark.asyncio + async def test_fill_caches_no_ttl(self, store, bt_imports): + store._bigtable_iteritems = MagicMock( + return_value=[(b"key1", b"value1"), (b"key2", b"value2")] + ) + store._set_cache = MagicMock() + store._startup_cache_ttl = 0 + store._invalidation_timer = None + store._startup_cache_partitions = set() + store._startup_cache = {} + + partitions = {TP("topic", 0), TP("topic", 1)} + + store._fill_caches(partitions) + + assert store._bigtable_iteritems.call_args == call(partitions=partitions) + assert store._set_cache.call_args_list == [ + call(b"key1", b"value1"), + call(b"key2", b"value2"), + ] + assert store._startup_cache_partitions == partitions + assert store._invalidation_timer is None + + @pytest.mark.asyncio + async def test__get_active_changelogtopic_partitions(self, store): + tps_table = { + "changelog_topic", + "other_topic", + "other_topic2", + } + store.table = MagicMock(changelog_topic=MagicMock(topics=tps_table)) + + tps = {TP("changelog_topic", 0), TP("other_topic", 1)} + active_partitions = store._get_active_changelogtopic_partitions( + store.table, tps + ) + assert active_partitions == {0, 1} + + @pytest.mark.asyncio + async def test_bigtable_on_rebalance(self, store, bt_imports): + store.assign_partitions = 
MagicMock(wraps=store.assign_partitions) + store.revoke_partitions = MagicMock(wraps=store.revoke_partitions) + + tps_table = { + "topic1", + "topic2", + "topic3", + "topic4", + "topic5", + } + store.table = MagicMock(changelog_topic=MagicMock(topics=tps_table)) + + store._fill_caches = MagicMock() + assigned = {TP("topic1", 0), TP("topic2", 1)} + revoked = {TP("topic3", 2)} + newly_assigned = {TP("topic4", 3), TP("topic5", 4)} + store._startup_cache_enable = False + await store.on_rebalance(assigned, revoked, newly_assigned, generation_id=1) + store.assign_partitions.assert_called_once_with(store.table, newly_assigned, 1) + store.revoke_partitions.assert_called_once_with(store.table, revoked) + store._fill_caches.assert_not_called() + newly_assigned = set() + + # Test with empty newly_assigned + store._startup_cache_enable = True + await store.on_rebalance(assigned, revoked, newly_assigned, generation_id=2) + store.assign_partitions.assert_called_with(store.table, newly_assigned, 2) + store._fill_caches.assert_not_called() + + store._startup_cache_enable = True + newly_assigned = {TP("topic4", 3), TP("topic5", 4)} + await store.on_rebalance(assigned, revoked, newly_assigned, generation_id=3) + store.assign_partitions.assert_called_with(store.table, newly_assigned, 3) + store._fill_caches.assert_called_once_with({3, 4}) + + def test_revoke_partitions(self, store): + store._startup_cache_partitions = {1, 2, 3} + store._startup_cache = {b"key1": b"value1", b"key2": b"value2"} + revoked = {TP("topic", 1), TP("topic", 2)} + store.table = MagicMock(changelog_topic=MagicMock(topics={"topic"})) + store.revoke_partitions(store.table, revoked) + assert store._startup_cache_partitions == {3} + + def test_contains(self, store, bt_imports): + store._get = MagicMock(return_value=b"test_value") + + # Test that _contains returns True when store_check_exists is False + store.app.conf.store_check_exists = False + assert store._contains(b"test_key") is True + + # Test that _contains 
returns True when _get returns a value + store.app.conf.store_check_exists = True + assert store._contains(b"test_key") is True + + # Test that _contains returns False when _get returns None + store._get = MagicMock(return_value=None) + assert store._contains(b"test_key") is False + + def test_setup_caches_startup_cache_enable(self, store): + options = { + BigTableStore.BT_STARTUP_CACHE_ENABLE_KEY: True, + BigTableStore.BT_STARTUP_CACHE_TTL_KEY: 60, + } + store._setup_caches(options=options) + assert store._startup_cache_enable is True + assert store._startup_cache_ttl == 60 + assert isinstance(store._startup_cache, dict) + assert isinstance(store._startup_cache_partitions, set) + assert store._invalidation_timer is None + + def test_setup_caches_startup_cache_disable(self, store): + options = { + BigTableStore.BT_STARTUP_CACHE_ENABLE_KEY: False, + } + store._setup_caches(options=options) + assert store._startup_cache_enable is False + assert store._startup_cache_ttl == -1 # Default value + assert store._startup_cache is None + assert store._startup_cache_partitions == set() + assert store._startup_cache_enable is False + + def test_set_del_get_cache(self, store): + store._startup_cache_enable = False + store._startup_cache = None + store._startup_cache_partitions = set() + + key = self.TEST_KEY1 + + store._set_cache(key, b"123") + res = store._get_cache(key) + assert store._startup_cache is None + assert store._startup_cache_partitions == set() + assert res == (None, False) + + store._del_cache(key) + res = store._get_cache(key) + assert res == (None, False) + assert store._startup_cache is None + assert store._startup_cache_partitions == set() + + # Now with enabled startup cache + store._startup_cache_enable = True + store._startup_cache = {} + store._startup_cache_partitions = {1, 2} + + store._set_cache(key, b"123") + res = store._get_cache(key) + assert store._startup_cache == {key: b"123"} + assert store._startup_cache_partitions == {1, 2} + assert res == 
(b"123", True) + store._del_cache(key) + res = store._get_cache(key) + assert store._startup_cache == {key: None} + assert store._startup_cache_partitions == {1, 2} + assert res == (None, True) + + def test_persisted_offset(self, store): + tp = TP("topic", 0) + offset_key = store.get_offset_key(tp).encode() + store.bt_table.data = {offset_key: b"1"} + store._mutation_batcher = MagicMock(flush=MagicMock()) + + assert store.persisted_offset(tp) == 1 + store._mutation_batcher.flush.assert_not_called() + + store._mutation_batcher_enable = True + assert store.persisted_offset(tp) == 1 + store._mutation_batcher.flush.assert_called_once() + + @pytest.mark.asyncio + async def test_stop(self, store): + store._mutation_batcher = MagicMock(flush=MagicMock()) + store._mutation_batcher_enable = False + await store.stop() + store._mutation_batcher.flush.assert_not_called() + + store._mutation_batcher_enable = True + await store.stop() + store._mutation_batcher.flush.assert_called_once() + + def test_set_mutation(self, store): + store._mutation_batcher = MagicMock(flush=MagicMock()) + store._set_mutation(self.TEST_KEY1, b"123", MagicMock()) + store._mutation_batcher.flush.assert_not_called() + assert store._mutation_batcher_cache[self.TEST_KEY1] == b"123" + + def test_bigtable_iteritems_with_global_table(self, store, bt_imports): + store.table.is_global = True + store._active_partitions = MagicMock(return_value=[1, 3]) + # Add table to data fro partition 1 to 5 with corresponding offset keys + store.bt_table.data = {} + for i in range(1, 5): + key = store.get_offset_key(TP("topic", i)).encode() + store.bt_table.data[key] = str(i).encode() + tp_key = store._add_partition_prefix_to_key(f"key{i}".encode(), i) + store.bt_table.data[tp_key] = str(i).encode() + + res = sorted(store._iteritems()) + assert res == [(f"key{i}".encode(), str(i).encode()) for i in range(1, 5)] + store.bt_table.read_rows.assert_called_once() + + def test_bigtable_iteritems_with_global_table2(self, store, 
bt_imports): + store.table.is_global = False + store.table.use_partitioner = False + store._mutation_batcher_enable = True + store._mutation_batcher = MagicMock(flush=MagicMock()) + store._active_partitions = MagicMock(return_value={1, 3}) + # Add table to data fro partition 1 to 5 with corresponding offset keys + store.bt_table.data = {} + for i in range(1, 5): + key = store.get_offset_key(TP("topic", i)).encode() + store.bt_table.data[key] = str(i).encode() + tp_key = store._add_partition_prefix_to_key(f"key{i}".encode(), i) + store.bt_table.data[tp_key] = str(i).encode() + + res = sorted(store._iteritems()) + assert res == [(f"key{i}".encode(), str(i).encode()) for i in [1, 3]] + store.bt_table.read_rows.assert_called_once() + store._mutation_batcher.flush.assert_called_once() + + def test_get_after_delete(self, store, bt_imports): + partitions = [19, 20] + store._get_cache = MagicMock(return_value=(b"this is ignored", False)) + store._key_index = {} + store._get_current_partitions = MagicMock(return_value=partitions) + row_mock = MagicMock() + row_mock.commit = MagicMock() + row_mock.delete = MagicMock() + store.bt_table.direct_row = MagicMock(return_value=row_mock) + store.bt_table.direct_row = MagicMock(return_value=row_mock) + store._mutation_batcher_enable = True + + key_right = b"20_..._" + self.TEST_KEY1 + key_wrong = b"19_..._" + self.TEST_KEY1 + + # This is the case if a delete happened before + store._mutation_batcher_cache = {key_right: b"123", key_wrong: None} + store.bt_table.add_test_data(key_right) + + res = store._get(self.TEST_KEY1) + assert res is not None diff --git a/tests/unit/transport/drivers/test_aiokafka.py b/tests/unit/transport/drivers/test_aiokafka.py index 35fe87ba4..aec4dd1ad 100644 --- a/tests/unit/transport/drivers/test_aiokafka.py +++ b/tests/unit/transport/drivers/test_aiokafka.py @@ -1,3 +1,4 @@ +import inspect import random import string from contextlib import contextmanager @@ -1379,26 +1380,32 @@ def assert_new_producer( 
security_protocol="PLAINTEXT", **kwargs, ): + expected_kwargs = dict( + bootstrap_servers=bootstrap_servers, + client_id=client_id, + acks=acks, + linger_ms=linger_ms, + max_batch_size=max_batch_size, + max_request_size=max_request_size, + compression_type=compression_type, + security_protocol=security_protocol, + partitioner=producer.partitioner, + transactional_id=None, + metadata_max_age_ms=metadata_max_age_ms, + connections_max_idle_ms=connections_max_idle_ms, + request_timeout_ms=request_timeout_ms, + **kwargs, + ) + if ( + "api_version" + in inspect.signature(aiokafka.AIOKafkaProducer.__init__).parameters + ): + expected_kwargs["api_version"] = api_version + with patch("aiokafka.AIOKafkaProducer") as AIOKafkaProducer: p = producer._new_producer() assert p is AIOKafkaProducer.return_value - AIOKafkaProducer.assert_called_once_with( - bootstrap_servers=bootstrap_servers, - client_id=client_id, - acks=acks, - linger_ms=linger_ms, - max_batch_size=max_batch_size, - max_request_size=max_request_size, - compression_type=compression_type, - security_protocol=security_protocol, - partitioner=producer.partitioner, - transactional_id=None, - api_version=api_version, - metadata_max_age_ms=metadata_max_age_ms, - connections_max_idle_ms=connections_max_idle_ms, - request_timeout_ms=request_timeout_ms, - **kwargs, - ) + AIOKafkaProducer.assert_called_once_with(**expected_kwargs) class TestProducer(ProducerBaseTest):