diff --git a/PERF_IDEAS.md b/PERF_IDEAS.md new file mode 100644 index 00000000..09caaaca --- /dev/null +++ b/PERF_IDEAS.md @@ -0,0 +1,145 @@ +# Tracing Performance Optimization Ideas + +Baseline: **967 us/req** user thread, **39 us/item** flush (5000 reqs e2e) + +## User Thread Bottlenecks (profiled with cProfile, 3000 reqs) + +Total: 10.88s for 3000 requests = 3627 us/req + +### 1. `_to_bt_safe` primitives-last (1.63s / 15%) + +`_to_bt_safe` is called for every leaf value in `_deep_copy_object`. It checks +Span/Experiment/Dataset/Logger isinstance, dataclasses, Pydantic model_dump +(with warnings.catch_warnings + filterwarnings!), and Pydantic v1 `.dict()` -- +all before checking if the value is a simple int/str/float/bool/None. + +**Fix**: Move primitive checks (`type(v) is int/str/float`) to the top of +`_to_bt_safe`. Guard Pydantic attempts with `hasattr(v, "model_dump")`. + +**Impact**: Eliminates ~5s of isinstance/warnings/regex overhead. Estimated 3-5x +improvement on user thread. + +### 2. `_deep_copy_object` uses `isinstance(v, Mapping)` (0.62s + isinstance overhead) + +Every dict goes through `isinstance(v, (Mapping, list, tuple, set))` which is +slow for abstract types from `collections.abc`. Then a second +`isinstance(v, Mapping)` check. + +**Fix**: Use `type(v) is dict` / `type(v) is list` for the common fast path. +Also inline the primitive check at the top of `_deep_copy_object` to skip +calling `_to_bt_safe` entirely for leaf values. + +**Impact**: Combined with #1, reduces `_deep_copy_object` from ~9.2s to ~0.3s. + +### 3. `warnings.catch_warnings` + `filterwarnings` in `_to_bt_safe` (0.51s + 0.49s) + +Every call to `_to_bt_safe` on a non-primitive does +`warnings.catch_warnings()` + `warnings.filterwarnings(...)` which involves +regex compilation (`re.compile`), list manipulation, and lock acquisition. +Called 195k times for 3000 requests. + +**Fix**: Already fixed by #1 (primitives skip this entirely). 
Additionally,
+guard with `hasattr(v, "model_dump")` so only actual Pydantic models pay
+the cost.
+
+### 4. `get_caller_location()` always called (visible in `__init__`)
+
+`get_caller_location()` walks the stack with `inspect.currentframe()` on every
+span creation, even when the caller provides an explicit `name=`.
+
+**Fix**: Only call `get_caller_location()` when `name is None`.
+
+**Impact**: Small, but the fix is essentially free (~5 us saved per span).
+
+### 5. `bt_safe_deep_copy` called on internal-only data (end/set_attributes)
+
+`end()` calls `log_internal(internal_data={metrics: {end: time}})` which goes
+through the full `bt_safe_deep_copy`. This data is all primitives -- no user
+object references to break.
+
+**Fix**: Skip `bt_safe_deep_copy` when `event` is None/empty (internal-only).
+
+**Impact**: Saves ~15-20 us per `end()` call.
+
+### 6. `_strip_nones` recurses unnecessarily (0.11s)
+
+Called with `deep=True` on internal_data, it recurses into every nested dict
+even when there are no Nones. It also always allocates a new dict even when
+there is nothing to strip.
+
+**Fix**: Fast path: check whether any values are None before copying. Use
+`type(d) is dict` instead of `isinstance`. Skip recursion when there are no
+nested dicts.
+
+### 7. `split_logging_data` does redundant work for empty event/internal_data
+
+When `event=None` (from `end()`), it still calls `_validate_and_sanitize({})`,
+`_strip_nones({})`, and `merge_dicts({}, ...)`.
+
+**Fix**: Short-circuit when one side is empty. Add an early return to
+`_validate_and_sanitize` for empty events.
+
+### 8. `_EXEC_COUNTER` uses `threading.Lock` (small)
+
+The global counter is protected by a lock. Under the CPython GIL,
+`itertools.count()` with `next()` is atomic and lock-free.
+
+**Fix**: Replace `threading.Lock` + global int with `itertools.count(1)`.
+
+### 9. `merge_dicts` path tracking overhead (small)
+
+`merge_dicts` delegates to `merge_dicts_with_paths(... (), set())`, which creates
+tuples for every key path. The simple `merge_dicts` call never uses merge_paths.
+ +**Fix**: Inline the simple merge logic in `merge_dicts` without path tracking. + +## Flush Thread Bottlenecks (profiled with cProfile, 3000 reqs) + +Total: 0.756s for 6000 items = 126 us/item (includes merge of 18000 -> 6000) + +### 10. `_get_exporter` calls `os.getenv` every time (0.048s) + +Called 18000 times in flush. Does `os.getenv("BRAINTRUST_OTEL_COMPAT")` + +`.lower()` comparison each time. + +**Fix**: Cache the result in a module-level variable. Add `_reset_cached_exporter()` +for tests. Also reuse in `export()` which has a duplicate env var check. + +### 11. `compute_record` creates SpanComponentsV3 per item (in _get_exporter cost) + +Each queued item's `compute_record()` closure calls `_get_exporter()(object_type=..., +object_id=...).object_id_fields()`, creating a new dataclass with `__post_init__` +assertions and then a small dict. This is constant per span. + +**Fix**: Cache `object_id_fields` result per span in a `LazyValue`, reuse across +all `compute_record` closures from the same span. + +### 12. `merge_row_batch` with merge_dicts_with_paths (0.134s) + +The merge step uses the full `merge_dicts_with_paths` with tuple path tracking. +Also `_pop_merge_row_skip_fields` / `_restore_merge_row_skip_fields` do field-by-field +dict manipulation. + +**Fix**: Already partially addressed by #9 (merge_dicts fast path). Further +optimization possible but lower priority since flush is already fast. + +## Implementation Priority + +High impact (implement first): +1. `_to_bt_safe` primitives-first + hasattr guards (#1, #3) +2. `_deep_copy_object` type-identity fast paths (#2) +3. Skip deep copy for internal-only data (#5) +4. Lazy `get_caller_location` (#4) + +Medium impact: +5. `_strip_nones` / `split_logging_data` / `_validate_and_sanitize` fast paths (#6, #7) +6. `merge_dicts` inline fast path (#9) +7. Cache `_get_exporter` (#10) +8. Cache `object_id_fields` per span (#11) + +Low impact: +9. 
`itertools.count` for exec counter (#8) + +## Expected Combined Result + +Based on isolated testing of each change: +- User thread: ~967 us/req -> ~200 us/req (4-5x improvement) +- Flush: ~39 us/item -> ~25 us/item (1.5x improvement, more with orjson) diff --git a/py/bench_e2e.py b/py/bench_e2e.py new file mode 100644 index 00000000..c7c478e8 --- /dev/null +++ b/py/bench_e2e.py @@ -0,0 +1,136 @@ +"""End-to-end CPU time benchmark + cProfile analysis for tracing. + +Usage: + python bench_e2e.py # benchmark only + python bench_e2e.py --profile # benchmark + cProfile breakdown +""" + +import cProfile +import os +import pstats +import sys +import time + + +os.environ["BRAINTRUST_DISABLE_ATEXIT_FLUSH"] = "true" +sys.path.insert(0, "src") + +from braintrust.logger import ( + BraintrustState, + SpanImpl, + SpanObjectTypeV3, + _MemoryBackgroundLogger, + stringify_with_overflow_meta, +) +from braintrust.merge_row_batch import merge_row_batch +from braintrust.util import LazyValue + + +def make_state(): + state = BraintrustState() + ml = _MemoryBackgroundLogger() + state._override_bg_logger.logger = ml + pid = LazyValue(lambda: "proj-abc123", use_mutex=False) + pid.get() + return state, ml, pid + + +def run_workload(state, ml, pid, num_requests): + """Simulate num_requests LLM calls with root + child spans.""" + t_start = time.perf_counter() + for i in range(num_requests): + root = SpanImpl( + parent_object_type=SpanObjectTypeV3.PROJECT_LOGS, + parent_object_id=pid, + parent_compute_object_metadata_args=None, + parent_span_ids=None, + name="handle_request", + state=state, + event={ + "input": { + "messages": [ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": f"Question {i}: What is {i} + {i}?"}, + ] + }, + "metadata": {"user_id": f"user_{i % 100}", "session_id": "sess_abc"}, + }, + lookup_span_parent=False, + ) + child = root.start_span( + name="llm_call", + input={"model": "gpt-4", "temperature": 0.7, "max_tokens": 500}, + ) + 
child.log( + output={ + "choices": [{"message": {"role": "assistant", "content": f"The answer is {i * 2}."}}], + "usage": {"prompt_tokens": 50, "completion_tokens": 20, "total_tokens": 70}, + }, + metrics={"latency": 0.234, "tokens_per_second": 85.5}, + ) + child.end() + root.log( + output=f"The answer is {i * 2}.", + scores={"accuracy": 0.95, "relevance": 0.88}, + ) + root.end() + t_user = time.perf_counter() - t_start + return t_user + + +def run_flush(ml): + """Simulate the flush path (unwrap lazy values, merge, stringify).""" + items = ml.logs[:] + t0 = time.perf_counter() + unwrapped = [it.get() for it in items] + merged = merge_row_batch(unwrapped) + _ = [stringify_with_overflow_meta(m) for m in merged] + t_flush = time.perf_counter() - t0 + return t_flush, len(items), len(merged) + + +def benchmark(): + # Warmup + s, ml, pid = make_state() + run_workload(s, ml, pid, 10) + + print("End-to-end benchmark") + print("=" * 70) + for n in [100, 1000, 5000]: + s, ml, pid = make_state() + t_user = run_workload(s, ml, pid, n) + t_flush, num_items, num_merged = run_flush(ml) + t_total = t_user + t_flush + print( + f" {n:5d} reqs: " + f"user={t_user * 1000:7.1f}ms ({t_user / n * 1e6:5.0f} us/req) " + f"flush={t_flush * 1000:7.1f}ms ({t_flush / num_merged * 1e6:5.0f} us/item) " + f"total={t_total * 1000:7.1f}ms" + ) + + +def profile(): + N = 3000 + + # Profile user thread + s, ml, pid = make_state() + pr = cProfile.Profile() + pr.enable() + run_workload(s, ml, pid, N) + pr.disable() + print(f"\n=== User thread profile ({N} requests) ===") + pstats.Stats(pr).sort_stats("tottime").print_stats(30) + + # Profile flush + pr2 = cProfile.Profile() + pr2.enable() + run_flush(ml) + pr2.disable() + print(f"\n=== Flush profile ({N} requests) ===") + pstats.Stats(pr2).sort_stats("tottime").print_stats(20) + + +if __name__ == "__main__": + benchmark() + if "--profile" in sys.argv: + profile() diff --git a/py/src/braintrust/bt_json.py b/py/src/braintrust/bt_json.py index 
e0c7be13..84b31474 100644 --- a/py/src/braintrust/bt_json.py +++ b/py/src/braintrust/bt_json.py @@ -19,6 +19,29 @@ def _to_bt_safe(v: Any) -> Any: """ Converts the object to a Braintrust-safe representation (i.e. Attachment objects are safe (specially handled by background logger)). """ + # Fast path: check primitives first via type identity. These are the + # vast majority of values in logged data and must not pay the cost of + # isinstance checks against abstract classes or Pydantic model_dump. + if v is None or v is True or v is False: + return v + tv = type(v) + if tv is int or tv is str or tv is float: + if tv is float: + if math.isnan(v): + return "NaN" + if math.isinf(v): + return "Infinity" if v > 0 else "-Infinity" + return v + # Also catch str/int subclasses (e.g. str-enums like SpanTypeAttribute) + if isinstance(v, (int, str, bool)): + return v + if isinstance(v, float): + if math.isnan(v): + return "NaN" + if math.isinf(v): + return "Infinity" if v > 0 else "-Infinity" + return v + # avoid circular imports from braintrust.logger import BaseAttachment, Dataset, Experiment, Logger, ReadonlyAttachment, Span @@ -57,32 +80,20 @@ def _to_bt_safe(v: Any) -> Any: # Suppress Pydantic serializer warnings that arise from generic/discriminated-union # models (e.g. OpenAI's ParsedResponse[T]). See # https://github.com/braintrustdata/braintrust-sdk-python/issues/60 - try: - with warnings.catch_warnings(): - warnings.filterwarnings("ignore", message="Pydantic serializer warnings", category=UserWarning) - return cast(Any, v).model_dump(exclude_none=True) - except (AttributeError, TypeError): - pass + if hasattr(v, "model_dump"): + try: + with warnings.catch_warnings(): + warnings.filterwarnings("ignore", message="Pydantic serializer warnings", category=UserWarning) + return cast(Any, v).model_dump(exclude_none=True) + except (AttributeError, TypeError): + pass # Attempt to dump a Pydantic v1 `BaseModel`. 
-    try:
-        return cast(Any, v).dict(exclude_none=True)
-    except (AttributeError, TypeError):
-        pass
-
-    if isinstance(v, float):
-        # Handle NaN and Infinity for JSON compatibility
-        if math.isnan(v):
-            return "NaN"
-
-        if math.isinf(v):
-            return "Infinity" if v > 0 else "-Infinity"
-
-        return v
-
-    if isinstance(v, (int, str, bool)) or v is None:
-        # Skip roundtrip for primitive types.
-        return v
+    if hasattr(v, "dict") and not isinstance(v, type):
+        try:
+            return cast(Any, v).dict(exclude_none=True)
+        except (AttributeError, TypeError):
+            pass
 
     # Note: we avoid using copy.deepcopy, because it's difficult to
     # guarantee the independence of such copied types from their origin.
@@ -127,44 +138,94 @@
     """
     # Track visited objects to detect circular references
     visited: set[int] = set()
+    visited_add = visited.add
+    visited_discard = visited.discard
+    _to_bt_safe_fn = _to_bt_safe
 
     def _deep_copy_object(v: Any, depth: int = 0) -> Any:
-        # Check depth limit - use >= to stop before exceeding
+        # Fast path: primitives don't need deep copy or circular ref tracking.
+        if v is None or v is True or v is False:
+            return v
+        tv = type(v)
+        if tv is int or tv is str or tv is float:
+            if tv is float:
+                if math.isnan(v):
+                    return "NaN"
+                if math.isinf(v):
+                    return "Infinity" if v > 0 else "-Infinity"
+            return v
+        # Also catch str/int subclasses (e.g. str-enums)
+        if isinstance(v, (int, str, bool)):
+            return v
+        if isinstance(v, float):
+            if math.isnan(v):
+                return "NaN"
+            if math.isinf(v):
+                return "Infinity" if v > 0 else "-Infinity"
+            return v
+
         if depth >= max_depth:
             return "<max depth exceeded>"
 
-        # Check for circular references in mutable containers
-        # Use id() to track object identity
-        if isinstance(v, (Mapping, list, tuple, set)):
+        # Fast path for dict (the most common container in log data).
+        # Uses type identity instead of isinstance(v, Mapping) which is slow.
+        if tv is dict:
             obj_id = id(v)
             if obj_id in visited:
                 return "<circular reference>"
-            visited.add(obj_id)
+            visited_add(obj_id)
             try:
-                if isinstance(v, Mapping):
-                    # Prevent dict keys from holding references to user data. Note that
-                    # `bt_json` already coerces keys to string, a behavior that comes from
-                    # `json.dumps`. However, that runs at log upload time, while we want to
-                    # cut out all the references to user objects synchronously in this
-                    # function.
-                    result = {}
-                    for k in v:
+                result = {}
+                for k in v:
+                    if type(k) is str:
+                        key_str = k
+                    else:
                         try:
                             key_str = str(k)
                         except Exception:
-                            # If str() fails on the key, use a fallback representation
                             key_str = f"<unstringifiable key of type {type(k).__name__}>"
-                    result[key_str] = _deep_copy_object(v[k], depth + 1)
-                    return result
-                elif isinstance(v, (list, tuple, set)):
-                    return [_deep_copy_object(x, depth + 1) for x in v]
+                    result[key_str] = _deep_copy_object(v[k], depth + 1)
+                return result
+            finally:
+                visited_discard(obj_id)
+        elif tv is list or tv is tuple:
+            obj_id = id(v)
+            if obj_id in visited:
+                return "<circular reference>"
+            visited_add(obj_id)
+            try:
+                return [_deep_copy_object(x, depth + 1) for x in v]
+            finally:
+                visited_discard(obj_id)
+        # Slow path for non-builtin Mapping/set types.
+        elif isinstance(v, Mapping):
+            obj_id = id(v)
+            if obj_id in visited:
+                return "<circular reference>"
+            visited_add(obj_id)
+            try:
+                result = {}
+                for k in v:
+                    try:
+                        key_str = str(k)
+                    except Exception:
+                        key_str = f"<unstringifiable key of type {type(k).__name__}>"
+                    result[key_str] = _deep_copy_object(v[k], depth + 1)
+                return result
+            finally:
+                visited_discard(obj_id)
+        elif isinstance(v, (set,)):
+            obj_id = id(v)
+            if obj_id in visited:
+                return "<circular reference>"
+            visited_add(obj_id)
+            try:
+                return [_deep_copy_object(x, depth + 1) for x in v]
             finally:
-                # Remove from visited set after processing to allow the same object
-                # to appear in different branches of the tree
-                visited.discard(obj_id)
+                visited_discard(obj_id)
 
         try:
-            return _to_bt_safe(v)
+            return _to_bt_safe_fn(v)
         except Exception:
             return f"<unserializable object of type {type(v).__name__}>"
diff --git a/py/src/braintrust/logger.py b/py/src/braintrust/logger.py
index a9ba479b..e9d286c7 100644
--- a/py/src/braintrust/logger.py
+++ b/py/src/braintrust/logger.py
@@ -2768,6 +2768,8 @@ def _helper(v: Any) -> Any:
 
 
 def _validate_and_sanitize_experiment_log_partial_args(event: Mapping[str, Any]) -> dict[str, Any]:
+    if not event:
+        return {}
     scores = event.get("scores")
     if scores:
         for name, score in scores.items():
@@ -4192,7 +4194,7 @@ def __init__(
         if self.propagated_event:
             merge_dicts(event, self.propagated_event)
 
-        caller_location = get_caller_location()
+        caller_location = get_caller_location() if name is None else None
         if name is None:
             if not parent_span_ids:
                 name = "root"
@@ -4300,7 +4302,13 @@ def log_internal(self, event: dict[str, Any] | None = None, internal_data: dict[
             **{IS_MERGE_FIELD: self._is_merge},
         )
 
-        serializable_partial_record = bt_safe_deep_copy(partial_record)
+        # Only deep copy when user event data is present. Internal-only data
+        # (metrics, span_attributes, created, context) contains only primitives
+        # and doesn't reference user objects, so deep copy is unnecessary.
+ if event: + serializable_partial_record = bt_safe_deep_copy(partial_record) + else: + serializable_partial_record = partial_record if serializable_partial_record.get("metrics", {}).get("end") is not None: self._logged_end_time = serializable_partial_record["metrics"]["end"] @@ -4556,9 +4564,23 @@ def stringify_exception(exc_type: type[BaseException], exc_value: BaseException, def _strip_nones(d: T, deep: bool) -> T: - if not isinstance(d, dict): + if type(d) is not dict: + return d + has_none = any(v is None for v in d.values()) + if not has_none and not deep: + return d + if deep: + if has_none: + return {k: (_strip_nones(v, True) if type(v) is dict else v) for k, v in d.items() if v is not None} # type: ignore + if any(type(v) is dict for v in d.values()): + return {k: (_strip_nones(v, True) if type(v) is dict else v) for k, v in d.items()} # type: ignore return d - return {k: (_strip_nones(v, deep) if deep else v) for (k, v) in d.items() if v is not None} # type: ignore + if has_none: + return {k: v for k, v in d.items() if v is not None} # type: ignore + return d + + +_EMPTY_DICT: dict[str, Any] = {} def split_logging_data( @@ -4567,24 +4589,34 @@ def split_logging_data( # There should be no overlap between the dictionaries being merged, # except for `sanitized` and `internal_data`, where the former overrides # the latter. 
- sanitized = _validate_and_sanitize_experiment_log_partial_args(event or {}) - sanitized_and_internal_data = _strip_nones(internal_data or {}, deep=True) - merge_dicts(sanitized_and_internal_data, _strip_nones(sanitized, deep=False)) + sanitized = _validate_and_sanitize_experiment_log_partial_args(event or _EMPTY_DICT) + + if internal_data and sanitized: + sanitized_and_internal_data = _strip_nones(internal_data, deep=True) + merge_dicts(sanitized_and_internal_data, _strip_nones(sanitized, deep=False)) + elif internal_data: + sanitized_and_internal_data = _strip_nones(internal_data, deep=True) + elif sanitized: + sanitized_and_internal_data = _strip_nones(sanitized, deep=False) + else: + return _EMPTY_DICT, _EMPTY_DICT - serializable_partial_record: dict[str, Any] = {} + # Fast path: no BraintrustStream values (the common case) lazy_partial_record: dict[str, Any] = {} - for k, v in sanitized_and_internal_data.items(): + for v in sanitized_and_internal_data.values(): if isinstance(v, BraintrustStream): - # Python has weird semantics with loop variables and lambda functions, so we - # capture `v` by plugging it through a closure that itself returns the LazyValue - def make_final_value_callback(v): - return LazyValue(lambda: v.copy().final_value(), use_mutex=False) + serializable_partial_record: dict[str, Any] = {} + for k2, v2 in sanitized_and_internal_data.items(): + if isinstance(v2, BraintrustStream): - lazy_partial_record[k] = make_final_value_callback(v) - else: - serializable_partial_record[k] = v + def make_final_value_callback(v2): + return LazyValue(lambda: v2.copy().final_value(), use_mutex=False) - return serializable_partial_record, lazy_partial_record + lazy_partial_record[k2] = make_final_value_callback(v2) + else: + serializable_partial_record[k2] = v2 + return serializable_partial_record, lazy_partial_record + return sanitized_and_internal_data, lazy_partial_record class Dataset(ObjectFetcher[DatasetEvent]): diff --git a/py/src/braintrust/util.py 
b/py/src/braintrust/util.py index 516cb9b6..068817e1 100644 --- a/py/src/braintrust/util.py +++ b/py/src/braintrust/util.py @@ -98,9 +98,26 @@ def merge_dicts_with_paths( def merge_dicts(merge_into: dict[str, Any], merge_from: Mapping[str, Any]) -> dict[str, Any]: - """Merges merge_from into merge_into, destructively updating merge_into.""" + """Merges merge_from into merge_into, destructively updating merge_into. - return merge_dicts_with_paths(merge_into, merge_from, (), set()) + Inlines the common fast path to avoid tuple path tracking overhead. + """ + for k, merge_from_v in merge_from.items(): + merge_into_v = merge_into.get(k) + if type(merge_into_v) is dict and type(merge_from_v) is dict: + merge_dicts(merge_into_v, merge_from_v) + elif k in _SET_UNION_FIELDS and isinstance(merge_into_v, list) and isinstance(merge_from_v, list): + seen: set[str] = set() + combined = [] + for item in merge_into_v + list(merge_from_v): + item_key = json.dumps(item, sort_keys=True) if isinstance(item, (dict, list)) else str(item) + if item_key not in seen: + seen.add(item_key) + combined.append(item) + merge_into[k] = combined + else: + merge_into[k] = merge_from_v + return merge_into def encode_uri_component(name: str) -> str:
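
---

The core claim behind ideas #1 and #2 -- that exact-type dispatch with primitives checked first beats `isinstance` checks against `collections.abc` ABCs -- is easy to validate outside the SDK. A minimal standalone sketch (the `convert_slow`/`convert_fast` helpers below are illustrative stand-ins for the `_to_bt_safe` ordering change, not Braintrust code):

```python
"""Micro-benchmark: primitives-first type-identity dispatch vs. ABC isinstance."""
import timeit
from collections.abc import Mapping


def convert_slow(v):
    # Mirrors the old ordering: every leaf pays for ABC isinstance checks
    # before being returned unchanged.
    if isinstance(v, Mapping):
        return {str(k): convert_slow(x) for k, x in v.items()}
    if isinstance(v, (list, tuple, set)):
        return [convert_slow(x) for x in v]
    return v


def convert_fast(v):
    # Primitives first, via exact-type identity; ABC checks only as a
    # fallback for dict-like subclasses. (NaN/Infinity handling omitted --
    # this sketch only demonstrates the dispatch order.)
    tv = type(v)
    if v is None or tv is int or tv is str or tv is float or tv is bool:
        return v
    if tv is dict:
        return {str(k): convert_fast(x) for k, x in v.items()}
    if tv is list or tv is tuple:
        return [convert_fast(x) for x in v]
    if isinstance(v, Mapping):  # slow path for non-builtin mappings
        return {str(k): convert_fast(x) for k, x in v.items()}
    return v


# A record shaped roughly like the benchmark's span event data.
record = {
    "input": {"messages": [{"role": "user", "content": "hi"} for _ in range(10)]},
    "metrics": {"latency": 0.2, "tokens": 70},
}

if __name__ == "__main__":
    slow = timeit.timeit(lambda: convert_slow(record), number=20_000)
    fast = timeit.timeit(lambda: convert_fast(record), number=20_000)
    print(f"slow={slow:.3f}s fast={fast:.3f}s speedup={slow / fast:.1f}x")
```

Both functions produce identical output; the difference is purely dispatch cost, which is why the PERF_IDEAS estimates treat #1 and #2 as the highest-impact changes.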