From 10eb833ba9b1b274ae9510a8b13d14f192140fb4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?C=C3=A9dric=20Halber?= Date: Thu, 19 Mar 2026 17:01:11 -0700 Subject: [PATCH 1/3] fix: langchain python pkg got updated, some submodules are now accessible with langchain_classic module --- py/src/braintrust/wrappers/langchain.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/py/src/braintrust/wrappers/langchain.py b/py/src/braintrust/wrappers/langchain.py index 6beeb578..7a4c9f6d 100644 --- a/py/src/braintrust/wrappers/langchain.py +++ b/py/src/braintrust/wrappers/langchain.py @@ -9,11 +9,11 @@ _logger = logging.getLogger("braintrust.wrappers.langchain") try: - from langchain.callbacks.base import BaseCallbackHandler - from langchain.schema import Document - from langchain.schema.agent import AgentAction - from langchain.schema.messages import BaseMessage - from langchain.schema.output import LLMResult + from langchain_classic.callbacks.base import BaseCallbackHandler + from langchain_classic.schema import Document + from langchain_classic.schema.agent import AgentAction + from langchain_classic.schema.messages import BaseMessage + from langchain_classic.schema.output import LLMResult except ImportError: _logger.warning("Failed to import langchain, using stubs") BaseCallbackHandler = object From 28ad13e3ba143d0d37503521a5d0ba4f8172bda0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?C=C3=A9dric=20Halber?= Date: Fri, 20 Mar 2026 22:32:18 +0000 Subject: [PATCH 2/3] chore: start removing ty errors --- py/src/braintrust/__init__.py | 28 +++--- py/src/braintrust/bt_json.py | 6 +- py/src/braintrust/cli/install/api.py | 28 +++--- py/src/braintrust/cli/push.py | 2 +- py/src/braintrust/devserver/dataset.py | 21 +++-- py/src/braintrust/devserver/server.py | 32 ++++--- py/src/braintrust/framework.py | 43 +++++---- py/src/braintrust/framework2.py | 18 ++-- py/src/braintrust/logger.py | 91 ++++++++++--------- py/src/braintrust/oai.py | 22 +++-- py/src/braintrust/otel/__init__.py | 30 +++--- py/src/braintrust/otel/context.py | 16 ++-- py/src/braintrust/parameters.py | 29 +++--- py/src/braintrust/serializable_data_class.py | 25 ++--- py/src/braintrust/span_identifier_v3.py | 7 +- py/src/braintrust/span_identifier_v4.py | 9 +- py/src/braintrust/test_bt_json.py | 18 ++-- py/src/braintrust/test_logger.py | 54 +++++------ py/src/braintrust/wrappers/adk/__init__.py | 2 +- py/src/braintrust/wrappers/agno/utils.py | 4 +- .../claude_agent_sdk/_test_transport.py | 2 +- .../wrappers/claude_agent_sdk/_wrapper.py | 8 +- .../wrappers/claude_agent_sdk/test_wrapper.py | 16 ++-- .../braintrust/wrappers/langsmith_wrapper.py | 16 ++-- py/src/braintrust/wrappers/litellm.py | 10 +- py/src/braintrust/wrappers/test_openai.py | 39 ++++---- 26 files changed, 308 insertions(+), 268 deletions(-) diff --git a/py/src/braintrust/__init__.py b/py/src/braintrust/__init__.py index 32ef4999..8b24fd54 100644 --- a/py/src/braintrust/__init__.py +++ b/py/src/braintrust/__init__.py @@ -63,7 +63,7 @@ def is_equal(expected, output): from .audit import * from .auto import ( - auto_instrument, # noqa: F401 # type: ignore[reportUnusedImport] + auto_instrument, # noqa: F401 ) from .framework import * from .framework2 import * @@ -72,29 +72,29 @@ def is_equal(expected, output): from .generated_types import * from .logger import * from .logger import ( - _internal_get_global_state, # noqa: F401 # type: ignore[reportUnusedImport] - _internal_reset_global_state, # noqa: F401 # type: ignore[reportUnusedImport] - _internal_with_custom_background_logger, # noqa: F401 # type: ignore[reportUnusedImport] + _internal_get_global_state, # noqa: F401 + _internal_reset_global_state, # noqa: F401 + _internal_with_custom_background_logger, # noqa: F401 ) from .oai import ( - wrap_openai, # noqa: F401 # type: ignore[reportUnusedImport] + wrap_openai, # noqa: F401 ) from .sandbox import ( - RegisteredSandboxFunction, # noqa: F401 # type: ignore[reportUnusedImport] - RegisterSandboxResult, # noqa: F401 # type: ignore[reportUnusedImport] - SandboxConfig, # noqa: F401 # type: ignore[reportUnusedImport] - register_sandbox, # noqa: F401 # type: ignore[reportUnusedImport] + RegisteredSandboxFunction, # noqa: F401 + RegisterSandboxResult, # noqa: F401 + SandboxConfig, # noqa: F401 + register_sandbox, # noqa: F401 ) from .util import ( - BT_IS_ASYNC_ATTRIBUTE, # noqa: F401 # type: ignore[reportUnusedImport] - MarkAsyncWrapper, # noqa: F401 # type: ignore[reportUnusedImport] + BT_IS_ASYNC_ATTRIBUTE, # noqa: F401 + MarkAsyncWrapper, # noqa: F401 ) from .wrappers.anthropic import ( - wrap_anthropic, # noqa: F401 # type: ignore[reportUnusedImport] + wrap_anthropic, # noqa: F401 ) from .wrappers.litellm import ( - wrap_litellm, # noqa: F401 # type: ignore[reportUnusedImport] + wrap_litellm, # noqa: F401 ) from .wrappers.pydantic_ai import ( - setup_pydantic_ai, # noqa: F401 # type: ignore[reportUnusedImport] + setup_pydantic_ai, # noqa: F401 ) diff --git a/py/src/braintrust/bt_json.py b/py/src/braintrust/bt_json.py index e0c7be13..a9702f50 100644 --- a/py/src/braintrust/bt_json.py +++ b/py/src/braintrust/bt_json.py @@ -245,11 +245,11 @@ def bt_dumps(obj: Any, encoder: Encoder | None = _json_encoder, **kwargs: Any) - # Try orjson first for better performance try: # pylint: disable=no-member # orjson is a C extension, pylint can't introspect it - return orjson.dumps( # type: ignore[possibly-unbound] + return orjson.dumps( obj, default=encoder.orjson if encoder else None, # options match json.dumps behavior for bc - option=orjson.OPT_SORT_KEYS | orjson.OPT_SERIALIZE_NUMPY | orjson.OPT_NON_STR_KEYS, # type: ignore[possibly-unbound] + option=orjson.OPT_SORT_KEYS | orjson.OPT_SERIALIZE_NUMPY | orjson.OPT_NON_STR_KEYS, ).decode("utf-8") except Exception: # If orjson fails, fall back to standard json @@ -278,7 +278,7 @@ def bt_loads(s: str, **kwargs) -> Any: # Try orjson first for better performance try: # pylint: disable=no-member # orjson is a C extension, pylint can't introspect it - return orjson.loads(s) # type: ignore[possibly-unbound] + return orjson.loads(s) except Exception: # If orjson fails, fall back to standard json pass diff --git a/py/src/braintrust/cli/install/api.py b/py/src/braintrust/cli/install/api.py index 6c95774a..fda1a507 100644 --- a/py/src/braintrust/cli/install/api.py +++ b/py/src/braintrust/cli/install/api.py @@ -2,6 +2,7 @@ import os import textwrap import time +from typing import Any from botocore.exceptions import ClientError from braintrust.logger import app_conn, login @@ -281,7 +282,7 @@ def build_parser(subparsers, parents): def main(args): template = args.template or LATEST_TEMPLATE - status = None + status: Any = None try: statuses = cloudformation.describe_stacks(StackName=args.name)["Stacks"] if len(statuses) == 1: @@ -377,9 +378,11 @@ def main(args): _logger.info(f"Stack with name {args.name} has been created with status: {status['StackStatus']}") exit(0) - _logger.info(f"Stack with name {args.name} has status: {status['StackStatus']}") + from typing import cast + status_any = cast(Any, status) + _logger.info(f"Stack with name {args.name} has status: {status_any['StackStatus']}") - if not ("_COMPLETE" in status["StackStatus"] or "_FAILED" in status["StackStatus"]): + if not ("_COMPLETE" in status_any["StackStatus"] or "_FAILED" in status_any["StackStatus"]): _logger.info(f"Please re-run this command once the stack has finished creating or updating") exit(0) @@ -400,7 +403,7 @@ def main(args): new_template = cloudformation.get_template_summary(TemplateURL=template) new_params = set(x["ParameterKey"] for x in new_template["Parameters"]) else: - new_params = set(x["ParameterKey"] for x in status["Parameters"]) + new_params = set(x["ParameterKey"] for x in cast(Any, status)["Parameters"]) stack = cloudformation.describe_stacks(StackName=args.name)["Stacks"][0] try: @@ -459,7 +462,7 @@ def main(args): _logger.info(f"Stack with name {args.name} has been updated with status: {status['StackStatus']}") _logger.info(f"Universal URL: {universal_url}") - org_info = [] + org_info: Any = [] if args.api_key: login(api_key=args.api_key) resp = app_conn().post("api/apikey/login") @@ -484,16 +487,17 @@ def main(args): if len(org_info) == 1: org_info = org_info[0] - if org_info and (universal_url and org_info["api_url"] != universal_url): + org_info_any = cast(Any, org_info) + if org_info_any and (universal_url and org_info_any["api_url"] != universal_url): if args.update_stack_url: - _logger.info(f"Will update org {org_info['name']}'s urls.") + _logger.info(f"Will update org {org_info_any['name']}'s urls.") _logger.info(f" They are currently set to:") - _logger.info(f" API URL: {org_info['api_url']}") - _logger.info(f" Proxy URL: {org_info['proxy_url']}") + _logger.info(f" API URL: {org_info_any['api_url']}") + _logger.info(f" Proxy URL: {org_info_any['proxy_url']}") _logger.info(f"And will update them to:") - patch_args = {"id": org_info["id"]} - if universal_url and org_info["api_url"] != universal_url: + patch_args = {"id": org_info_any["id"]} + if universal_url and org_info_any["api_url"] != universal_url: patch_args["api_url"] = universal_url patch_args["is_universal_api"] = True _logger.info(f" API URL: {universal_url}") @@ -510,6 +514,6 @@ def main(args): ) else: _logger.info(f"Stack URL differs from organization API URL:") - _logger.info(f" Current API URL: {org_info['api_url']}") + _logger.info(f" Current API URL: {org_info_any['api_url']}") _logger.info(f" Stack Universal URL: {universal_url}") _logger.info(f"To update the organization's API URL, rerun with --update-stack-url flag") diff --git a/py/src/braintrust/cli/push.py b/py/src/braintrust/cli/push.py index 6c95e2cd..2f00214c 100644 --- a/py/src/braintrust/cli/push.py +++ b/py/src/braintrust/cli/push.py @@ -51,7 +51,7 @@ def _pydantic_to_json_schema(m): def _check_uv(): try: - import uv as _ # noqa: F401 # type: ignore[reportUnusedImport] + import uv as _ # noqa: F401 except ImportError: raise ValueError( textwrap.dedent( diff --git a/py/src/braintrust/devserver/dataset.py b/py/src/braintrust/devserver/dataset.py index de222efb..926d5e0f 100644 --- a/py/src/braintrust/devserver/dataset.py +++ b/py/src/braintrust/devserver/dataset.py @@ -1,4 +1,4 @@ -from typing import Any +from typing import Any, cast from braintrust import init_dataset from braintrust._generated_types import RunEvalData, RunEvalData1, RunEvalData2 @@ -34,28 +34,29 @@ async def get_dataset(state: BraintrustState, data: RunEvalData | RunEvalData1 | """ # Handle dict-based data (common case) if isinstance(data, dict): - if "project_name" in data and "dataset_name" in data: + data_dict = cast(dict[str, Any], data) + if "project_name" in data_dict and "dataset_name" in data_dict: # Dataset reference by name return init_dataset( state=state, - project=data["project_name"], - name=data["dataset_name"], + project=data_dict["project_name"], + name=data_dict["dataset_name"], # _internal_btql is optional - **({"_internal_btql": data["_internal_btql"]} if "_internal_btql" in data else {}), + **({"_internal_btql": data_dict["_internal_btql"]} if "_internal_btql" in data_dict else {}), ) - elif "dataset_id" in data: + elif "dataset_id" in data_dict: # Dataset reference by ID - dataset_info = await get_dataset_by_id(state, data["dataset_id"]) + dataset_info = await get_dataset_by_id(state, data_dict["dataset_id"]) return init_dataset( state=state, project_id=dataset_info["project_id"], name=dataset_info["dataset"], # _internal_btql is optional - **({"_internal_btql": data["_internal_btql"]} if "_internal_btql" in data else {}), + **({"_internal_btql": data_dict["_internal_btql"]} if "_internal_btql" in data_dict else {}), ) - elif "data" in data: + elif "data" in data_dict: # Inline data - return data["data"] + return data_dict["data"] # If it's not a dict, assume it's inline data return data diff --git a/py/src/braintrust/devserver/server.py b/py/src/braintrust/devserver/server.py index 4b3ff79a..457a5389 100644 --- a/py/src/braintrust/devserver/server.py +++ b/py/src/braintrust/devserver/server.py @@ -2,7 +2,7 @@ import json import sys import textwrap -from typing import Any +from typing import Any, cast try: @@ -183,17 +183,21 @@ async def run_eval(request: Request) -> JSONResponse | StreamingResponse: async def task(input: Any, hooks: EvalHooks[Any]): task_hooks = hooks if validated_parameters is None else _ParameterOverrideHooks(hooks, validated_parameters) + task_hooks_typed = cast(EvalHooks[Any], task_hooks) if bt_iscoroutinefunction(evaluator.task): - result = await evaluator.task(input, task_hooks) + result = await cast(Any, evaluator.task)(input, task_hooks_typed) else: - result = evaluator.task(input, task_hooks) + result = cast(Any, evaluator.task)(input, task_hooks_typed) task_hooks.report_progress( - { - "format": "code", - "output_type": "completion", - "event": "json_delta", - "data": json.dumps(result), - } + cast( + Any, + { + "format": "code", + "output_type": "completion", + "event": "json_delta", + "data": json.dumps(result), + }, + ) ) return result @@ -212,7 +216,7 @@ def stream_fn(event: SSEProgressEvent): parent = eval_data.get("parent") if parent: - parent = parse_parent(parent) + parent = parse_parent(cast(str | dict | None, parent)) eval_kwargs = { k: v for (k, v) in evaluator.__dict__.items() if k not in ["eval_name", "project_name", "parameter_values"] @@ -222,14 +226,14 @@ def stream_fn(event: SSEProgressEvent): try: eval_task = asyncio.create_task( - EvalAsync( + cast(Any, EvalAsync)( name=eval_data["name"], **{ **eval_kwargs, "state": state, "scores": evaluator.scores + [ - make_scorer(state, score["name"], score["function_id"], ctx.project_id) + make_scorer(state, cast(str, score["name"]), cast(Any, score["function_id"]), ctx.project_id) for score in eval_data.get("scores", []) ], "stream": stream_fn, @@ -310,8 +314,8 @@ def create_app(evaluators: list[Evaluator[Any, Any]], org_name: str | None = Non app = Starlette(routes=routes) # Add middlewares in reverse order (last added is executed first) - app.add_middleware(CheckAuthorizedMiddleware, allowed_org_name=org_name) - app.add_middleware(AuthorizationMiddleware) + cast(Any, app).add_middleware(CheckAuthorizedMiddleware, allowed_org_name=org_name) + cast(Any, app).add_middleware(AuthorizationMiddleware) app.add_middleware(create_cors_middleware()) return app diff --git a/py/src/braintrust/framework.py b/py/src/braintrust/framework.py index d80fb1f9..3871f667 100644 --- a/py/src/braintrust/framework.py +++ b/py/src/braintrust/framework.py @@ -20,6 +20,7 @@ Optional, TypeVar, Union, + cast, ) from tqdm.asyncio import tqdm as async_tqdm @@ -131,7 +132,6 @@ class EvalResult(SerializableDataClass, Generic[Input, Output]): exc_info: str | None = None -@dataclasses.dataclass class TaskProgressEvent(TypedDict): """Progress event that can be reported during task execution.""" @@ -610,14 +610,14 @@ def pluralize(n, singular, plural): return plural -def report_failures(evaluator: Evaluator, failing_results: Iterable[EvalResult], verbose: bool, jsonl: bool) -> None: +def report_failures(evaluator: Evaluator, failing_results: list[EvalResult], verbose: bool, jsonl: bool) -> None: eprint( f"{bcolors.FAIL}Evaluator {evaluator.eval_name} failed with {len(failing_results)} {pluralize(len(failing_results), 'error', 'errors')}{bcolors.ENDC}" ) - errors = [ + errors: list[str] = [ ( - result.exc_info + result.exc_info or "" if verbose or jsonl else "\n".join(traceback.format_exception_only(type(result.error), result.error)) ) @@ -706,7 +706,7 @@ def _EvalCommon( project_name=name, data=data, task=task, - scores=scores, + scores=list(scores), experiment_name=experiment_name, trial_count=trial_count, metadata=metadata, @@ -740,7 +740,7 @@ async def make_empty_summary(): "Must specify a reporter object, not a name. Can only specify reporter names when running 'braintrust eval'" ) - reporter = reporter or default_reporter + reporter = cast(Any, reporter or default_reporter) if base_experiment_name is None and isinstance(evaluator.data, BaseExperiment): base_experiment_name = evaluator.data.name @@ -1167,7 +1167,7 @@ def __init__( expected: Any | None = None, trial_index: int = 0, tags: Sequence[str] | None = None, - report_progress: Callable[[TaskProgressEvent], None] = None, + report_progress: Optional[Callable[[TaskProgressEvent], None]] = None, parameters: ValidatedParameters | None = None, ): if metadata is not None: @@ -1384,9 +1384,9 @@ async def await_or_run_scorer(root_span, scorer, name, **kwargs): raise ValueError( f"When returning an array of scores, each score must be a valid Score object. Got: {s}" ) - result = list(result) + result = cast(list[Score], list(result)) elif is_score(result): - result = [result] + result = [cast(Score, result)] else: result = [Score(name=name, score=result)] @@ -1452,10 +1452,10 @@ async def run_evaluator_task(datum, trial_index=0): ) if experiment: - root_span = experiment.start_span(**base_event) + root_span = cast(Any, experiment).start_span(**base_event) else: # In most cases this will be a no-op span, but if the parent is set, it will use that ctx. - root_span = start_span(state=state, **base_event) + root_span = cast(Any, start_span)(state=state, **base_event) with root_span: try: @@ -1464,8 +1464,15 @@ def report_progress(event: TaskProgressEvent): if not stream: return stream( - SSEProgressEvent( - id=root_span.id, origin=origin, name=evaluator.eval_name, object_type="task", **event + cast( + Any, + SSEProgressEvent( + id=root_span.id, + origin=cast(Any, origin), + name=evaluator.eval_name, + object_type="task", + **event, + ), ) ) @@ -1473,9 +1480,9 @@ def report_progress(event: TaskProgressEvent): metadata, expected=datum.expected, trial_index=trial_index, - tags=tags, + tags=cast(Any, tags), report_progress=report_progress, - parameters=resolved_evaluator_parameters, + parameters=cast(Any, resolved_evaluator_parameters), ) # Check if the task takes a hooks argument @@ -1608,7 +1615,7 @@ async def ensure_spans_flushed(): ) except Exception as e: exc_type, exc_value, tb = sys.exc_info() - root_span.log(error=stringify_exception(exc_type, exc_value, tb)) + root_span.log(error=stringify_exception(cast(Any, exc_type), cast(Any, exc_value), tb)) error = e # Python3.10 has a different set of arguments to format_exception than earlier versions, @@ -1619,7 +1626,7 @@ async def ensure_spans_flushed(): input=datum.input, expected=datum.expected, metadata=metadata, - tags=tags, + tags=cast(Any, tags), output=output, scores={ **( @@ -1702,7 +1709,7 @@ async def with_max_concurrency(coro): def build_local_summary( - evaluator: Evaluator[Input, Output], results: list[EvalResultWithSummary[Input, Output]] + evaluator: Evaluator[Input, Output], results: list[EvalResult[Input, Output]] ) -> ExperimentSummary: scores_by_name = defaultdict(lambda: (0, 0)) for result in results: diff --git a/py/src/braintrust/framework2.py b/py/src/braintrust/framework2.py index 64399696..751d8ef2 100644 --- a/py/src/braintrust/framework2.py +++ b/py/src/braintrust/framework2.py @@ -1,12 +1,12 @@ import dataclasses import json from collections.abc import Callable, Sequence -from typing import Any, overload +from typing import Any, cast, overload import slugify from braintrust.logger import api_conn, app_conn, login -from .framework import _is_lazy_load, bcolors # type: ignore +from .framework import _is_lazy_load, bcolors from .generated_types import ( ChatCompletionMessageParam, IfExists, @@ -187,8 +187,9 @@ def create( """ self._task_counter += 1 if not name: - if handler.__name__ and handler.__name__ != "": - name = handler.__name__ + handler_name = getattr(handler, "__name__", "") + if handler_name and handler_name != "": + name = handler_name else: name = f"Tool {self._task_counter}" assert name is not None @@ -300,7 +301,7 @@ def create( tool_functions.append(tool) else: # ToolFunctionDefinition - raw_tools.append(tool) + raw_tools.append(cast(ToolFunctionDefinition, tool)) prompt_data: PromptData = {} if messages is not None: @@ -477,9 +478,10 @@ def create( choice_scores: The scores for each choice. Required. """ self._task_counter += 1 - if name is None or len(name) == 0: - if handler and handler.__name__ and handler.__name__ != "": - name = handler.__name__ + if not name: + handler_name = getattr(handler, "__name__", "") if handler else "" + if handler_name and handler_name != "": + name = handler_name else: name = f"Scorer {self._task_counter}" if slug is None or len(slug) == 0: diff --git a/py/src/braintrust/logger.py b/py/src/braintrust/logger.py index a9ba479b..c12b9868 100644 --- a/py/src/braintrust/logger.py +++ b/py/src/braintrust/logger.py @@ -68,7 +68,7 @@ from .merge_row_batch import batch_items, merge_row_batch from .object import DEFAULT_IS_LEGACY_DATASET, ensure_dataset_record from .parameters import RemoteEvalParameters -from .prompt import BRAINTRUST_PARAMS, ImagePart, PromptBlockData, PromptData, PromptMessage, PromptSchema, TextPart +from .prompt import BRAINTRUST_PARAMS, ImagePart, PromptBlockData, PromptChatBlock, PromptCompletionBlock, PromptData, PromptMessage, PromptSchema, TextPart from .prompt_cache.disk_cache import DiskCache from .prompt_cache.lru_cache import LRUCache from .prompt_cache.parameters_cache import ParametersCache @@ -915,6 +915,9 @@ def log(self, *args: LazyValue[dict[str, Any]]) -> None: def flush(self, batch_size: int | None = None): pass + def set_masking_function(self, masking_function: "Callable[[Any], Any] | None") -> None: + pass + class _MemoryBackgroundLogger(_BackgroundLogger): def __init__(self): @@ -1036,7 +1039,7 @@ def __init__(self, api_conn: LazyValue[HTTPConnection]): except: self.queue_drop_logging_period = 60 - self._queue_drop_logging_state = dict(lock=threading.Lock(), num_dropped=0, last_logged_timestamp=0) + self._queue_drop_logging_state: dict[str, Any] = dict(lock=threading.Lock(), num_dropped=0, last_logged_timestamp=0) try: self.failed_publish_payloads_dir = os.environ["BRAINTRUST_FAILED_PUBLISH_PAYLOADS_DIR"] @@ -1667,7 +1670,7 @@ def compute_metadata(): # pylint: disable=function-redefined def compute_metadata(): state.login(org_name=org_name, api_key=api_key, app_url=app_url) - args = { + args: dict[str, Any] = { "project_name": project, "project_id": project_id, "org_id": state.org_id, @@ -1687,7 +1690,7 @@ def compute_metadata(): merged_git_metadata_settings = state.git_metadata_settings if git_metadata_settings is not None: merged_git_metadata_settings = GitMetadataSettings.merge( - merged_git_metadata_settings, git_metadata_settings + cast(Any, merged_git_metadata_settings), git_metadata_settings ) repo_info_arg = get_repo_info(merged_git_metadata_settings) @@ -1704,13 +1707,14 @@ def compute_metadata(): if dataset is not None: if isinstance(dataset, dict): # Simple {"id": ..., "version": ...} dict - args["dataset_id"] = dataset["id"] - if "version" in dataset: - args["dataset_version"] = dataset["version"] + dataset_dict = cast(dict[str, Any], dataset) + args["dataset_id"] = dataset_dict["id"] + if "version" in dataset_dict: + args["dataset_version"] = dataset_dict["version"] else: # Full Dataset object - args["dataset_id"] = dataset.id - args["dataset_version"] = dataset.version + args["dataset_id"] = cast(Any, dataset).id + args["dataset_version"] = cast(Any, dataset).version parameters_ref = _get_parameters_ref(parameters) if parameters_ref is not None: @@ -1839,17 +1843,17 @@ def _compute_logger_metadata(project_name: str | None = None, project_id: str | ) resp_project = response["project"] return OrgProjectMetadata( - org_id=org_id, + org_id=cast(str, org_id), project=ObjectMetadata(id=resp_project["id"], name=resp_project["name"], full_info=resp_project), ) elif project_name is None: response = _state.app_conn().get_json("api/project", {"id": project_id}) return OrgProjectMetadata( - org_id=org_id, project=ObjectMetadata(id=project_id, name=response["name"], full_info=response) + org_id=cast(str, org_id), project=ObjectMetadata(id=project_id, name=response["name"], full_info=response) ) else: return OrgProjectMetadata( - org_id=org_id, project=ObjectMetadata(id=project_id, name=project_name, full_info=dict()) + org_id=cast(str, org_id), project=ObjectMetadata(id=project_id, name=project_name, full_info=dict()) ) @@ -2583,9 +2587,9 @@ def wrapper_sync_gen(*f_args, **f_kwargs): # We determine if the decorator is invoked bare or with arguments by # checking if the first positional argument to the decorator is a callable. if len(span_args) == 1 and len(span_kwargs) == 0 and callable(span_args[0]): - return decorator(span_args[1:], span_kwargs, cast(F, span_args[0])) + return cast(Any, decorator)(span_args[1:], span_kwargs, span_args[0]) else: - return cast(Callable[[F], F], partial(decorator, span_args, span_kwargs)) + return cast(Any, partial(decorator, span_args, span_kwargs)) def start_span( @@ -3135,7 +3139,7 @@ def error_wrapper() -> AttachmentStatus: status["error_message"] = str(e) request_params = { - "key": self._reference["key"], + "key": cast(Any, self._reference)["key"], "org_id": org_id, "status": status, } @@ -3683,7 +3687,8 @@ def compute_parent_object_id(): arg_parent_object_id = LazyValue(compute_parent_object_id, use_mutex=False) if parent_components.row_id: arg_parent_span_ids = ParentSpanIds( - span_id=parent_components.span_id, root_span_id=parent_components.root_span_id + span_id=cast(str, parent_components.span_id), + root_span_id=cast(str, parent_components.root_span_id), ) else: arg_parent_span_ids = None @@ -3736,14 +3741,17 @@ def __next__(self) -> _ExperimentDatasetEvent: continue output, expected = value.get("output"), value.get("expected") - ret: _ExperimentDatasetEvent = { - "input": value.get("input"), - "expected": expected if expected is not None else output, - "tags": value.get("tags"), - "metadata": value.get("metadata"), - "id": value["id"], - "_xact_id": value["_xact_id"], - } + ret: _ExperimentDatasetEvent = cast( + Any, + { + "input": value.get("input"), + "expected": expected if expected is not None else output, + "tags": value.get("tags"), + "metadata": value.get("metadata"), + "id": value["id"], + "_xact_id": value["_xact_id"], + }, + ) return ret @@ -3977,7 +3985,7 @@ def summarize( self.flush() state = self._get_state() - project_url = f"{state.app_public_url}/app/{encode_uri_component(state.org_name)}/p/{encode_uri_component(self.project.name)}" + project_url = f"{state.app_public_url}/app/{encode_uri_component(cast(str, state.org_name))}/p/{encode_uri_component(self.project.name)}" experiment_url = f"{project_url}/experiments/{encode_uri_component(self.name)}" score_summary = {} @@ -4820,7 +4828,7 @@ def summarize(self, summarize_data: bool = True) -> "DatasetSummary": # includes the new experiment. self.flush() state = self._get_state() - project_url = f"{state.app_public_url}/app/{encode_uri_component(state.org_name)}/p/{encode_uri_component(self.project.name)}" + project_url = f"{state.app_public_url}/app/{encode_uri_component(cast(str, state.org_name))}/p/{encode_uri_component(self.project.name)}" dataset_url = f"{project_url}/datasets/{encode_uri_component(self.name)}" data_summary = None @@ -4929,7 +4937,7 @@ def render_message(render: Callable[[str], str], message: PromptMessage): def _create_custom_render(): def _get_key(key: str, scopes: list[dict[str, Any]], warn: bool) -> Any: - thing = chevron.renderer._get_key(key, scopes, warn) # type: ignore + thing = chevron.renderer._get_key(key, scopes, warn) if isinstance(thing, str): return thing return json.dumps(thing) @@ -4962,9 +4970,9 @@ def render_templated_object(obj: Any, args: Any) -> Any: if isinstance(obj, str): return render_mustache(obj, data=args, renderer=_custom_render, strict=strict) elif isinstance(obj, list): - return [render_templated_object(item, args) for item in obj] # type: ignore + return [render_templated_object(item, args) for item in obj] elif isinstance(obj, dict): - return {str(k): render_templated_object(v, args) for k, v in obj.items()} # type: ignore + return {str(k): render_templated_object(v, args) for k, v in obj.items()} return obj @@ -5059,7 +5067,7 @@ def from_prompt_data( @property def id(self) -> str: - return self._lazy_metadata.get().id + return cast(str, self._lazy_metadata.get().id) @property def name(self) -> str: @@ -5075,11 +5083,11 @@ def prompt(self) -> PromptBlockData | None: @property def version(self) -> str: - return self._lazy_metadata.get()._xact_id + return cast(str, self._lazy_metadata.get()._xact_id) @property def options(self) -> PromptOptions: - return self._lazy_metadata.get().prompt_data.options or {} + return cast(Any, self._lazy_metadata.get().prompt_data.options or {}) # Capture all metadata attributes which aren't covered by existing methods. def __getattr__(self, name: str) -> Any: @@ -5125,9 +5133,9 @@ def build(self, **build_args: Any) -> Mapping[str, Any]: if not self.prompt: raise ValueError("Empty prompt") - if self.prompt.type == "completion": + if isinstance(self.prompt, PromptCompletionBlock): ret["prompt"] = render_mustache(self.prompt.content, data=build_args, strict=strict) - elif self.prompt.type == "chat": + elif isinstance(self.prompt, PromptChatBlock): def render(template: str): return render_mustache(template, data=build_args, strict=strict) @@ -5141,7 +5149,7 @@ def render(template: str): def _make_iter_list(self) -> Sequence[str]: meta_keys = list(self.options.keys()) - if self.prompt.type == "completion": + if isinstance(self.prompt, PromptCompletionBlock): meta_keys.append("prompt") else: meta_keys.append("chat") @@ -5157,11 +5165,11 @@ def __len__(self) -> int: def __getitem__(self, x): if x == "prompt": - return self.prompt.prompt + return cast(Any, self.prompt).prompt elif x == "chat": - return self.prompt.messages + return cast(Any, self.prompt).messages elif x == "tools": - return self.prompt.tools + return cast(Any, self.prompt).tools else: return self.options[x] @@ -5194,7 +5202,7 @@ def lazy_init(self): @property def id(self) -> str: self.lazy_init() - return self._id + return cast(str, self._id) @property def name(self): @@ -5449,8 +5457,9 @@ def _get_link_base_url(self) -> str | None: # the url and org name can be passed into init_logger, resolved by login or provided as env variables # so this resolves all of those things. It's possible we never have an org name if the user has not # yet logged in and there is nothing else configured. - app_url = self.state.app_url or self._link_args.get("app_url") or _get_app_url() - org_name = self.state.org_name or self._link_args.get("org_name") or _get_org_name() + link_args = self._link_args or {} + app_url = self.state.app_url or link_args.get("app_url") or _get_app_url() + org_name = self.state.org_name or link_args.get("org_name") or _get_org_name() if not app_url or not org_name: return None return f"{app_url}/app/{org_name}" diff --git a/py/src/braintrust/oai.py b/py/src/braintrust/oai.py index be8c3b17..09de540b 100644 --- a/py/src/braintrust/oai.py +++ b/py/src/braintrust/oai.py @@ -4,7 +4,7 @@ import time import warnings from collections.abc import Callable -from typing import Any +from typing import Any, cast from wrapt import wrap_function_wrapper @@ -54,7 +54,7 @@ def __getattr__(self, name: str) -> Any: return getattr(self._response, name) @property - def __class__(self): # type: ignore + def __class__(self): return self._response.__class__ def __str__(self) -> str: @@ -160,7 +160,7 @@ def create(self, *args: Any, **kwargs: Any) -> Any: try: start = time.time() - create_response = self.create_fn(*args, **kwargs) + create_response = cast(Callable[..., Any], self.create_fn)(*args, **kwargs) if hasattr(create_response, "parse"): raw_response = create_response.parse() log_headers(create_response, span) @@ -213,7 +213,7 @@ async def acreate(self, *args: Any, **kwargs: Any) -> Any: try: start = time.time() - create_response = await self.acreate_fn(*args, **kwargs) + create_response = await cast(Callable[..., Any], self.acreate_fn)(*args, **kwargs) if hasattr(create_response, "parse"): raw_response = create_response.parse() @@ -415,7 +415,7 @@ def create(self, *args: Any, **kwargs: Any) -> Any: try: start = time.time() - create_response = self.create_fn(*args, **kwargs) + create_response = cast(Callable[..., Any], self.create_fn)(*args, **kwargs) if hasattr(create_response, "parse"): raw_response = create_response.parse() log_headers(create_response, span) @@ -467,7 +467,7 @@ async def acreate(self, *args: Any, **kwargs: Any) -> Any: try: start = time.time() - create_response = await self.acreate_fn(*args, **kwargs) + create_response = await cast(Callable[..., Any], self.acreate_fn)(*args, **kwargs) if hasattr(create_response, "parse"): raw_response = create_response.parse() log_headers(create_response, span) @@ -656,7 +656,7 @@ def create(self, *args: Any, **kwargs: Any) -> Any: with start_span( **merge_dicts(dict(name=self._name, span_attributes={"type": SpanTypeAttribute.LLM}), params) ) as span: - create_response = self._create_fn(*args, **kwargs) + create_response = cast(Callable[..., Any], self._create_fn)(*args, **kwargs) if hasattr(create_response, "parse"): raw_response = create_response.parse() log_headers(create_response, span) @@ -673,7 +673,7 @@ async def acreate(self, *args: Any, **kwargs: Any) -> Any: with start_span( **merge_dicts(dict(name=self._name, span_attributes={"type": SpanTypeAttribute.LLM}), params) ) as span: - create_response = await self._acreate_fn(*args, **kwargs) + create_response = await cast(Callable[..., Any], self._acreate_fn)(*args, **kwargs) if hasattr(create_response, "parse"): raw_response = create_response.parse() log_headers(create_response, span) @@ -837,6 +837,7 @@ def __init__(self, chat: Any): super().__init__(chat) import openai + import openai.resources.chat.completions # ensure submodule is imported if type(chat.completions) == openai.resources.chat.completions.AsyncCompletions: self.completions = AsyncCompletionsV1Wrapper(chat.completions) @@ -925,6 +926,9 @@ class OpenAIV1Wrapper(NamedWrapper): def __init__(self, openai: Any): super().__init__(openai) import openai as oai + import openai.resources.embeddings + import openai.resources.moderations + import openai.resources.responses.responses self.chat = ChatV1Wrapper(openai.chat) @@ -1106,7 +1110,7 @@ def patch_openai() -> bool: wrap_function_wrapper("openai", "OpenAI.__init__", _openai_init_wrapper) wrap_function_wrapper("openai", "AsyncOpenAI.__init__", _openai_init_wrapper) - openai.__braintrust_wrapped__ = True + setattr(openai, "__braintrust_wrapped__", True) return True except ImportError: diff --git a/py/src/braintrust/otel/__init__.py b/py/src/braintrust/otel/__init__.py index e6fe7f3e..12d2cb68 100644 --- a/py/src/braintrust/otel/__init__.py +++ b/py/src/braintrust/otel/__init__.py @@ -1,6 +1,7 @@ import logging import os import warnings +from typing import Any, cast from urllib.parse import urljoin @@ -330,7 +331,7 @@ def _get_parent_otel_braintrust_parent(self, parent_context): if current_span and hasattr(current_span, "attributes") and current_span.attributes: # Check if parent span has braintrust.parent attribute - attributes = dict(current_span.attributes) + attributes = dict(cast(Any, current_span.attributes)) return attributes.get("braintrust.parent") return None @@ -440,8 +441,8 @@ def context_from_span_export(export_str: str): ) # Convert hex strings to OTEL integers - trace_id_int = int(components.root_span_id, 16) - span_id_int = int(components.span_id, 16) + trace_id_int = int(cast(str, components.root_span_id), 16) + span_id_int = int(cast(str, components.span_id), 16) # Create OTEL SpanContext marked as remote span_context = SpanContext( @@ -631,37 +632,38 @@ def parent_from_headers(headers: dict[str, str], propagator=None) -> str | None: return None if braintrust_parent: + braintrust_parent_str = cast(str, braintrust_parent) from braintrust.span_identifier_v3 import SpanObjectTypeV3 # Parse braintrust.parent format: "project_id:abc", "project_name:xyz", or "experiment_id:123" - if braintrust_parent.startswith("project_id:"): + if braintrust_parent_str.startswith("project_id:"): object_type = SpanObjectTypeV3.PROJECT_LOGS - object_id = braintrust_parent[len("project_id:") :] + object_id = braintrust_parent_str[len("project_id:") :] if not object_id: logging.error( - f"parent_from_headers: Invalid braintrust.parent format (empty project_id): {braintrust_parent}" + f"parent_from_headers: Invalid braintrust.parent format (empty project_id): {braintrust_parent_str}" ) return None - elif braintrust_parent.startswith("project_name:"): + elif braintrust_parent_str.startswith("project_name:"): object_type = SpanObjectTypeV3.PROJECT_LOGS - project_name = braintrust_parent[len("project_name:") :] + project_name = braintrust_parent_str[len("project_name:") :] if not project_name: logging.error( - f"parent_from_headers: Invalid braintrust.parent format (empty project_name): {braintrust_parent}" + f"parent_from_headers: Invalid braintrust.parent format (empty project_name): {braintrust_parent_str}" ) return None compute_args = {"project_name": project_name} - elif braintrust_parent.startswith("experiment_id:"): + elif braintrust_parent_str.startswith("experiment_id:"): object_type = SpanObjectTypeV3.EXPERIMENT - object_id = braintrust_parent[len("experiment_id:") :] + object_id = braintrust_parent_str[len("experiment_id:") :] if not object_id: logging.error( - f"parent_from_headers: Invalid braintrust.parent format (empty experiment_id): {braintrust_parent}" + f"parent_from_headers: Invalid braintrust.parent format (empty experiment_id): {braintrust_parent_str}" ) return None else: logging.error( - f"parent_from_headers: Invalid braintrust.parent format: {braintrust_parent}. " + f"parent_from_headers: Invalid braintrust.parent format: {braintrust_parent_str}. " "Expected format: 'project_id:ID', 'project_name:NAME', or 'experiment_id:ID'" ) return None @@ -669,7 +671,7 @@ def parent_from_headers(headers: dict[str, str], propagator=None) -> str | None: # Create SpanComponentsV4 and export as string # Set row_id to enable span_id/root_span_id (required for parent linking) components = SpanComponentsV4( - object_type=object_type, + object_type=cast(Any, object_type), object_id=object_id, compute_object_metadata_args=compute_args, row_id="otel", # Dummy row_id to enable span_id/root_span_id fields diff --git a/py/src/braintrust/otel/context.py b/py/src/braintrust/otel/context.py index bb65be77..3dea2fdc 100644 --- a/py/src/braintrust/otel/context.py +++ b/py/src/braintrust/otel/context.py @@ -1,7 +1,7 @@ """Unified context management using OTEL's built-in context.""" import logging -from typing import Any, Optional +from typing import Any, Optional, cast from braintrust.context import ParentSpanIds, SpanInfo from braintrust.logger import Span @@ -39,7 +39,8 @@ def get_current_span_info(self) -> Optional["SpanInfo"]: # If there's a BT span stored AND the current OTEL span is a NonRecordingSpan # (which means it's our BT->OTEL wrapper), then return BT span info if bt_span and isinstance(current_span, trace.NonRecordingSpan): - return SpanInfo(trace_id=bt_span.root_span_id, span_id=bt_span.span_id, span_object=bt_span) + bt_span_any = cast(Any, bt_span) + return SpanInfo(trace_id=bt_span_any.root_span_id, span_id=bt_span_any.span_id, span_object=bt_span) else: # Return OTEL span info - this is a real OTEL span, not our wrapper otel_trace_id = format(span_context.trace_id, "032x") @@ -55,22 +56,23 @@ def set_current_span(self, span: Span) -> Any: # This is an OTEL span - it will manage its own context return None else: + bt_span = cast(Any, span) try: - trace_id_int = int(span.root_span_id, 16) + trace_id_int = int(bt_span.root_span_id, 16) except ValueError: - log.debug(f"Invalid root_span_id: {span.root_span_id}") + log.debug(f"Invalid root_span_id: {bt_span.root_span_id}") return None try: - span_id_int = int(span.span_id, 16) + span_id_int = int(bt_span.span_id, 16) except ValueError: - log.debug(f"Invalid span_id: {span.span_id}") + log.debug(f"Invalid span_id: {bt_span.span_id}") return None # This is a BT span - store it in OTEL context AND set as current OTEL span # First store the BT span ctx = context.set_value("braintrust_span", span) - parent_value = span._get_otel_parent() + parent_value = bt_span._get_otel_parent() ctx = context.set_value("braintrust.parent", parent_value, ctx) otel_span_context = SpanContext( diff --git a/py/src/braintrust/parameters.py b/py/src/braintrust/parameters.py index 595ba3ce..dd2581d3 100644 --- a/py/src/braintrust/parameters.py +++ b/py/src/braintrust/parameters.py @@ -2,11 +2,11 @@ from copy import deepcopy from dataclasses import dataclass -from typing import TYPE_CHECKING, Any, Literal, TypedDict +from typing import TYPE_CHECKING, Any, Literal, TypedDict, cast from jsonschema import Draft7Validator from jsonschema.exceptions import ValidationError as JSONSchemaValidationError -from typing_extensions import NotRequired +from typing_extensions import NotRequired, TypeGuard from .prompt import PromptData from .serializable_data_class import SerializableDataClass @@ -80,11 +80,11 @@ def _pydantic_to_json_schema(model: Any) -> dict[str, Any]: raise ValueError(f"Cannot convert {model} to JSON schema - not a pydantic model") -def _is_prompt_parameter(schema: Any) -> bool: +def _is_prompt_parameter(schema: Any) -> TypeGuard[PromptParameter]: return isinstance(schema, dict) and schema.get("type") == "prompt" -def _is_model_parameter(schema: Any) -> bool: +def _is_model_parameter(schema: Any) -> TypeGuard[ModelParameter]: return isinstance(schema, dict) and schema.get("type") == "model" @@ -151,7 +151,7 @@ def _resolve_local_json_schema_refs( def _serialize_pydantic_parameter_schema(schema: Any) -> dict[str, Any]: schema_json = _pydantic_to_json_schema(schema) - schema_json = _resolve_local_json_schema_refs(schema_json, schema_json) + schema_json = cast(dict[str, Any], _resolve_local_json_schema_refs(schema_json, schema_json)) schema_json.pop("$defs", None) schema_json.pop("definitions", None) fields = _get_pydantic_fields(schema) @@ -279,28 +279,29 @@ def _validate_local_parameters( elif schema is None: result[name] = value elif _is_pydantic_model(schema): + schema_cls = cast(Any, schema) fields = _get_pydantic_fields(schema) if len(fields) == 1 and "value" in fields: if value is None: try: - default_instance = schema() + default_instance = schema_cls() result[name] = default_instance.value except Exception as exc: raise ValueError(f"Parameter '{name}' is required") from exc elif hasattr(schema, "parse_obj"): - result[name] = schema.parse_obj({"value": value}).value + result[name] = schema_cls.parse_obj({"value": value}).value else: - result[name] = schema.model_validate({"value": value}).value + result[name] = schema_cls.model_validate({"value": value}).value else: if value is None: try: - result[name] = schema() + result[name] = schema_cls() except Exception as exc: raise ValueError(f"Parameter '{name}' is required") from exc elif hasattr(schema, "parse_obj"): - result[name] = schema.parse_obj(value) + result[name] = schema_cls.parse_obj(value) else: - result[name] = schema.model_validate(value) + result[name] = schema_cls.model_validate(value) else: result[name] = value except JSONSchemaValidationError as exc: @@ -344,7 +345,7 @@ def serialize_eval_parameters(parameters: EvalParameters) -> dict[str, Any]: for name, schema in parameters.items(): if _is_prompt_parameter(schema): - parameter_data = { + parameter_data: dict[str, Any] = { "type": "prompt", "description": schema.get("description"), } @@ -416,10 +417,10 @@ def parameters_to_json_schema(parameters: EvalParameters) -> ParametersSchema: property_schema["description"] = schema["description"] properties[name] = property_schema elif _is_model_parameter(schema): - property_schema = { + property_schema = cast(dict[str, Any], { "type": "string", "x-bt-type": "model", - } + }) if "default" in schema: property_schema["default"] = schema.get("default") if schema.get("description") is not None: diff --git a/py/src/braintrust/serializable_data_class.py b/py/src/braintrust/serializable_data_class.py index 8f9eeefc..060d7221 100644 --- a/py/src/braintrust/serializable_data_class.py +++ b/py/src/braintrust/serializable_data_class.py @@ -1,12 +1,12 @@ import dataclasses import json -from typing import Union, get_origin +from typing import Any, Union, cast, get_origin class SerializableDataClass: def as_dict(self): """Serialize the object to a dictionary.""" - return dataclasses.asdict(self) + return dataclasses.asdict(cast(Any, self)) def as_json(self, **kwargs): """Serialize the object to JSON.""" @@ -33,14 +33,15 @@ def from_dict_deep(cls, d: dict): if k not in fields: continue + field_type = cast(Any, fields[k].type) if ( isinstance(v, dict) - and isinstance(fields[k].type, type) - and issubclass(fields[k].type, SerializableDataClass) + and isinstance(field_type, type) + and issubclass(field_type, SerializableDataClass) ): - filtered[k] = fields[k].type.from_dict_deep(v) - elif get_origin(fields[k].type) == Union: - for t in fields[k].type.__args__: + filtered[k] = field_type.from_dict_deep(v) + elif get_origin(field_type) == Union: + for t in field_type.__args__: if t == type(None) and v is None: filtered[k] = None break @@ -54,12 +55,12 @@ def from_dict_deep(cls, d: dict): filtered[k] = v elif ( isinstance(v, list) - and get_origin(fields[k].type) == list - and len(fields[k].type.__args__) == 1 - and isinstance(fields[k].type.__args__[0], type) - and issubclass(fields[k].type.__args__[0], SerializableDataClass) + and get_origin(field_type) == list + and len(field_type.__args__) == 1 + and isinstance(field_type.__args__[0], type) + and issubclass(field_type.__args__[0], SerializableDataClass) ): - filtered[k] = [fields[k].type.__args__[0].from_dict_deep(i) for i in v] + filtered[k] = [field_type.__args__[0].from_dict_deep(i) for i in v] else: filtered[k] = v return cls(**filtered) diff --git a/py/src/braintrust/span_identifier_v3.py b/py/src/braintrust/span_identifier_v3.py index d8690315..7ad600ea 100644 --- a/py/src/braintrust/span_identifier_v3.py +++ b/py/src/braintrust/span_identifier_v3.py @@ -6,6 +6,7 @@ import dataclasses import json from enum import Enum +from typing import Any, cast from uuid import UUID from .span_identifier_v2 import SpanComponentsV2 @@ -217,7 +218,7 @@ def _from_json_obj(json_obj: dict) -> "SpanComponentsV3": **json_obj, "object_type": SpanObjectTypeV3(json_obj["object_type"]), } - return SpanComponentsV3(**kwargs) + return SpanComponentsV3(**cast(Any, kwargs)) def parse_parent(parent: str | dict | None) -> str | None: @@ -243,7 +244,7 @@ def parse_parent(parent: str | dict | None) -> str | None: "project_logs": SpanObjectTypeV3.PROJECT_LOGS, } - object_type = object_type_map.get(parent.get("object_type")) + object_type = object_type_map.get(cast(str, parent.get("object_type"))) if not object_type: raise ValueError(f"Invalid object_type: {parent.get('object_type')}") @@ -275,6 +276,6 @@ def parse_parent(parent: str | dict | None) -> str | None: if "propagated_event" in parent: kwargs["propagated_event"] = parent.get("propagated_event") - return SpanComponentsV3(**kwargs).to_str() + return SpanComponentsV3(**cast(Any, kwargs)).to_str() else: return None diff --git a/py/src/braintrust/span_identifier_v4.py b/py/src/braintrust/span_identifier_v4.py index c881ef49..39106824 100644 --- a/py/src/braintrust/span_identifier_v4.py +++ b/py/src/braintrust/span_identifier_v4.py @@ -5,6 +5,7 @@ import dataclasses import json from enum import Enum +from typing import Any, cast from .span_identifier_v3 import ( SpanComponentsV3, @@ -123,7 +124,7 @@ def add_hex_field(orig_val, field_id): hex_bytes, is_hex = None, False if is_hex: - hex_entries.append(bytes([field_id.value]) + hex_bytes) + hex_entries.append(bytes([field_id.value]) + cast(bytes, hex_bytes)) else: json_obj[_FIELDS_ID_TO_NAME[field_id]] = orig_val @@ -230,7 +231,7 @@ def _from_json_obj(json_obj: dict) -> "SpanComponentsV4": **json_obj, "object_type": SpanObjectTypeV3(json_obj["object_type"]), } - return SpanComponentsV4(**kwargs) + return SpanComponentsV4(**cast(Any, kwargs)) def parse_parent(parent: str | dict | None) -> str | None: @@ -245,7 +246,7 @@ def parse_parent(parent: str | dict | None) -> str | None: "project_logs": SpanObjectTypeV3.PROJECT_LOGS, } - object_type = object_type_map.get(parent.get("object_type")) + object_type = object_type_map.get(cast(str, parent.get("object_type"))) if not object_type: raise ValueError(f"Invalid object_type: {parent.get('object_type')}") @@ -275,6 +276,6 @@ def parse_parent(parent: str | dict | None) -> str | None: if "propagated_event" in parent: kwargs["propagated_event"] = parent.get("propagated_event") - return SpanComponentsV4(**kwargs).to_str() + return SpanComponentsV4(**cast(Any, kwargs)).to_str() else: return None diff --git a/py/src/braintrust/test_bt_json.py b/py/src/braintrust/test_bt_json.py index 0064fb41..b867caca 100644 --- a/py/src/braintrust/test_bt_json.py +++ b/py/src/braintrust/test_bt_json.py @@ -3,7 +3,7 @@ # pyright: reportPrivateUsage=false import json import warnings -from typing import Any +from typing import Any, cast from unittest import TestCase import pytest @@ -25,7 +25,7 @@ def testdeep_copy_event_basic(self): def test_deep_copy_mutation_independence(self): """Test that mutating the copy doesn't affect the original (true dereferencing).""" - original = { + original: dict[str, Any] = { "top_level": "value", "nested_dict": {"inner": "data", "deep": {"level": 3}}, "nested_list": [1, 2, [3, 4]], @@ -47,7 +47,7 @@ def test_deep_copy_mutation_independence(self): self.assertEqual(original["nested_dict"]["inner"], "data") self.assertEqual(original["nested_dict"]["deep"]["level"], 3) self.assertEqual(original["nested_list"][0], 1) - self.assertEqual(original["nested_list"][2][0], 3) + self.assertEqual(cast(list, original["nested_list"])[2][0], 3) self.assertEqual(original["nested_in_list"][0]["key"], "val") # Add new keys to copy @@ -194,8 +194,8 @@ def test_deep_copy_empty_containers(self): def test_deep_copy_exactly_max_depth(self): """Test behavior at exactly MAX_DEPTH (200).""" # Create nested structure at depth 199 (just under limit) - nested = {"level": 0} - current = nested + nested: dict[str, Any] = {"level": 0} + current: dict[str, Any] = nested for i in range(1, 199): current["child"] = {"level": i} current = current["child"] @@ -217,8 +217,8 @@ def test_deep_copy_exactly_max_depth(self): def test_deep_copy_exceeds_max_depth(self): """Test behavior exceeding MAX_DEPTH (200).""" # Create nested structure at depth 201 (exceeds limit) - nested = {"level": 0} - current = nested + nested: dict[str, Any] = {"level": 0} + current: dict[str, Any] = nested for i in range(1, 201): current["child"] = {"level": i} current = current["child"] @@ -381,7 +381,7 @@ def test_to_bt_safe_attachments(self): "filename": "readonly.txt", "content_type": "text/plain", } - readonly = ReadonlyAttachment(reference) + readonly = ReadonlyAttachment(cast(Any, reference)) result_readonly = _to_bt_safe(readonly) self.assertEqual(result_readonly, reference) self.assertIsNot(result_readonly, readonly) @@ -503,7 +503,7 @@ def test_bt_safe_deep_copy_mixed_attachment_types(self): "filename": "readonly.txt", "content_type": "text/plain", } - readonly_attachment = ReadonlyAttachment(reference) + readonly_attachment = ReadonlyAttachment(cast(Any, reference)) original = { "base": base_attachment, diff --git a/py/src/braintrust/test_logger.py b/py/src/braintrust/test_logger.py index 7662ad77..226f6195 100644 --- a/py/src/braintrust/test_logger.py +++ b/py/src/braintrust/test_logger.py @@ -5,7 +5,7 @@ import logging import os import time -from typing import AsyncGenerator, List +from typing import Any, AsyncGenerator, List, cast from unittest import TestCase from unittest.mock import MagicMock, patch @@ -38,10 +38,10 @@ init_test_exp, init_test_logger, preserve_env_vars, - simulate_login, # noqa: F401 # type: ignore[reportUnusedImport] + simulate_login, # noqa: F401 simulate_logout, - with_memory_logger, # noqa: F401 # type: ignore[reportUnusedImport] - with_simulate_login, # noqa: F401 # type: ignore[reportUnusedImport] + with_memory_logger, # noqa: F401 + with_simulate_login, # noqa: F401 ) @@ -134,7 +134,7 @@ def test_init_enable_atexit_flush(self): } with patch("atexit.register") as mock_register: - _HTTPBackgroundLogger(LazyValue(api_con_response, use_mutex=False)) # type: ignore + _HTTPBackgroundLogger(LazyValue(api_con_response, use_mutex=False)) mock_register.assert_called() def test_init_disable_atexit_flush(self): @@ -147,17 +147,17 @@ def test_init_disable_atexit_flush(self): with patch.dict(os.environ, {"BRAINTRUST_DISABLE_ATEXIT_FLUSH": "True"}): with patch("atexit.register") as mock_register: - _HTTPBackgroundLogger(LazyValue(api_con_response, use_mutex=False)) # type: ignore + _HTTPBackgroundLogger(LazyValue(api_con_response, use_mutex=False)) mock_register.assert_not_called() with patch.dict(os.environ, {"BRAINTRUST_DISABLE_ATEXIT_FLUSH": "1"}): with patch("atexit.register") as mock_register: - _HTTPBackgroundLogger(LazyValue(api_con_response, use_mutex=False)) # type: ignore + _HTTPBackgroundLogger(LazyValue(api_con_response, use_mutex=False)) mock_register.assert_not_called() with patch.dict(os.environ, {"BRAINTRUST_DISABLE_ATEXIT_FLUSH": "yes"}): with patch("atexit.register") as mock_register: - _HTTPBackgroundLogger(LazyValue(api_con_response, use_mutex=False)) # type: ignore + _HTTPBackgroundLogger(LazyValue(api_con_response, use_mutex=False)) mock_register.assert_not_called() def test_init_with_saved_parameters_attaches_reference(self): @@ -764,7 +764,7 @@ def test_span_log_with_simple_circular_reference(with_memory_logger): with logger.start_span(name="test_span") as span: # Create simple circular reference - data = {"key": "value"} + data: dict[str, Any] = {"key": "value"} data["self"] = data # Should handle circular reference gracefully @@ -790,8 +790,8 @@ def test_span_log_with_nested_circular_reference(with_memory_logger): with logger.start_span(name="test_span") as span: # Create nested structure with circular reference - page = {"page_number": 1, "content": "text"} - document = {"pages": [page]} + page: dict[str, Any] = {"page_number": 1, "content": "text"} + document: dict[str, Any] = {"pages": [page]} page["document"] = document # Should handle circular reference gracefully @@ -819,13 +819,13 @@ def test_span_log_with_deep_document_structure(with_memory_logger): with logger.start_span(name="test_span") as span: # Create deeply nested document structure with circular reference - doc_data = { + doc_data: dict[str, Any] = { "model_id": "document-model", "content": "Document content", "pages": [], } - page = { + page: dict[str, Any] = { "page_number": 1, "lines": [{"content": "Line 1"}], } @@ -867,8 +867,8 @@ def test_span_log_with_extremely_deep_nesting(with_memory_logger): recursion_limit = sys.getrecursionlimit() # Create structure deeper than recursion limit - deeply_nested = {"level": 0} - current = deeply_nested + deeply_nested: dict[str, Any] = {"level": 0} + current: dict[str, Any] = deeply_nested for i in range(1, recursion_limit + 100): current["nested"] = {"level": i} current = current["nested"] @@ -1082,7 +1082,7 @@ def test_span_link_logged_out_org_name(with_memory_logger): link = span.link() assert ( link - == f"https://www.braintrust.dev/app/test-org-name/object?object_type=project_logs&object_id=test-project-id&id={span._id}" + == f"https://www.braintrust.dev/app/test-org-name/object?object_type=project_logs&object_id=test-project-id&id={span.id}" ) @@ -1101,7 +1101,7 @@ def test_span_link_logged_out_org_name_env_vars(with_memory_logger): link = span.link() assert ( link - == f"https://my-own-thing.ca/foo/bar/app/my-own-thing/object?object_type=project_logs&object_id=test-project-id&id={span._id}" + == f"https://my-own-thing.ca/foo/bar/app/my-own-thing/object?object_type=project_logs&object_id=test-project-id&id={span.id}" ) finally: for k, v in originals.items(): @@ -1122,7 +1122,7 @@ def test_span_project_id_logged_in(with_memory_logger, with_simulate_login): link = span.link() assert ( link - == f"https://www.braintrust.dev/app/test-org-name/object?object_type=project_logs&object_id=test-project-id&id={span._id}" + == f"https://www.braintrust.dev/app/test-org-name/object?object_type=project_logs&object_id=test-project-id&id={span.id}" ) @@ -1142,7 +1142,7 @@ def test_span_project_name_logged_in(with_simulate_login, with_memory_logger): span.end() link = span.link() - assert link == f"https://www.braintrust.dev/app/test-org-name/p/test-project/logs?oid={span._id}" + assert link == f"https://www.braintrust.dev/app/test-org-name/p/test-project/logs?oid={span.id}" def test_span_link_with_resolved_experiment(with_simulate_login, with_memory_logger): @@ -1156,13 +1156,13 @@ def test_span_link_with_resolved_experiment(with_simulate_login, with_memory_log assert eid == "test-experiment-id" span = experiment.start_span(name="test-span") - span.parent_object_id = id_lazy_value + cast(Any, span).parent_object_id = id_lazy_value span.end() link = span.link() assert ( link - == f"https://www.braintrust.dev/app/test-org-name/object?object_type=experiment&object_id=test-experiment-id&id={span._id}" + == f"https://www.braintrust.dev/app/test-org-name/object?object_type=experiment&object_id=test-experiment-id&id={span.id}" ) @@ -1197,7 +1197,7 @@ def test_experiment_span_link_uses_env_vars_when_logged_out(with_memory_logger): # Create span with resolved experiment ID span = experiment.start_span(name="test-span") - span.parent_object_id = LazyValue(lambda: "test-exp-id", use_mutex=False) + cast(Any, span).parent_object_id = LazyValue(lambda: "test-exp-id", use_mutex=False) span.end() link = span.link() @@ -1226,7 +1226,7 @@ def test_permalink_with_valid_span_logged_in(with_simulate_login, with_memory_lo link = braintrust.permalink(span_export, org_name="test-org-name", app_url="https://www.braintrust.dev") - expected_link = f"https://www.braintrust.dev/app/test-org-name/object?object_type=project_logs&object_id=test-project-id&id={span._id}" + expected_link = f"https://www.braintrust.dev/app/test-org-name/object?object_type=project_logs&object_id=test-project-id&id={span.id}" assert link == expected_link @@ -1260,7 +1260,7 @@ async def get_link_in_async(): # The link should NOT be the noop link assert link != "https://www.braintrust.dev/noop-span" # The link should contain the span ID - assert span._id in link + assert span.id in link # The link should contain the project ID assert "test-project-id" in link @@ -1346,7 +1346,7 @@ async def level1(): span.end() assert link != "https://www.braintrust.dev/noop-span" - assert span._id in link + assert span.id in link def test_current_logger_in_thread(with_simulate_login, with_memory_logger): @@ -1397,7 +1397,7 @@ def get_link_in_thread(): # The link should NOT be the noop link assert thread_result["link"] != "https://www.braintrust.dev/noop-span" # The link should contain the span ID - assert span._id in thread_result["link"] + assert span.id in thread_result["link"] @pytest.mark.asyncio @@ -3623,7 +3623,7 @@ def test_span_exit_logs_exception_group_sub_exceptions(with_memory_logger): init_test_logger(__name__) with pytest.raises(exceptiongroup.ExceptionGroup): - with braintrust.current_logger().start_span(name="eg-span"): + with cast(Any, braintrust.current_logger()).start_span(name="eg-span"): raise _raise_test_exception_group() logs = with_memory_logger.pop() diff --git a/py/src/braintrust/wrappers/adk/__init__.py b/py/src/braintrust/wrappers/adk/__init__.py index 6c6b8a14..96b7e0cb 100644 --- a/py/src/braintrust/wrappers/adk/__init__.py +++ b/py/src/braintrust/wrappers/adk/__init__.py @@ -608,7 +608,7 @@ def _serialize_config(config: Any) -> dict[str, Any] | Any: # Serialize the config config_dict = bt_safe_deep_copy(config) if not isinstance(config_dict, dict): - return config_dict # type: ignore + return config_dict # Replace schema fields with serialized versions config_dict.update(serialized_schemas) diff --git a/py/src/braintrust/wrappers/agno/utils.py b/py/src/braintrust/wrappers/agno/utils.py index 7951ac7c..1845d275 100644 --- a/py/src/braintrust/wrappers/agno/utils.py +++ b/py/src/braintrust/wrappers/agno/utils.py @@ -360,7 +360,7 @@ def _aggregate_agent_chunks(chunks: list[Any]) -> dict[str, Any]: elif event == "RunContent": if hasattr(chunk, "content") and chunk.content: - aggregated["content"] += str(chunk.content) # type: ignore + aggregated["content"] += str(chunk.content) if hasattr(chunk, "reasoning_content") and chunk.reasoning_content: aggregated["reasoning_content"] += chunk.reasoning_content if hasattr(chunk, "citations"): @@ -379,7 +379,7 @@ def _aggregate_agent_chunks(chunks: list[Any]) -> dict[str, Any]: elif event == "ToolCallStarted": if hasattr(chunk, "tool_call"): - aggregated["tool_calls"].append( # type:ignore + aggregated["tool_calls"].append( { "id": getattr(chunk.tool_call, "id", None), "type": "function", diff --git a/py/src/braintrust/wrappers/claude_agent_sdk/_test_transport.py b/py/src/braintrust/wrappers/claude_agent_sdk/_test_transport.py index 3a516568..7f24353e 100644 --- a/py/src/braintrust/wrappers/claude_agent_sdk/_test_transport.py +++ b/py/src/braintrust/wrappers/claude_agent_sdk/_test_transport.py @@ -69,7 +69,7 @@ def _normalize_write(data: str, *, sanitize: bool = False) -> dict[str, Any]: async def _empty_stream(): return - yield {} # type: ignore[unreachable] + yield {} def _normalize_for_match(value: Any) -> Any: diff --git a/py/src/braintrust/wrappers/claude_agent_sdk/_wrapper.py b/py/src/braintrust/wrappers/claude_agent_sdk/_wrapper.py index e019241d..7a268e49 100644 --- a/py/src/braintrust/wrappers/claude_agent_sdk/_wrapper.py +++ b/py/src/braintrust/wrappers/claude_agent_sdk/_wrapper.py @@ -170,7 +170,7 @@ def _serialize_system_message(message: Any) -> dict[str, Any]: def _create_tool_wrapper_class(original_tool_class: Any) -> Any: """Creates a wrapper class for SdkMcpTool that re-enters active TOOL spans.""" - class WrappedSdkMcpTool(original_tool_class): # type: ignore[valid-type,misc] + class WrappedSdkMcpTool(original_tool_class): def __init__( self, name: Any, @@ -180,9 +180,9 @@ def __init__( **kwargs: Any, ): wrapped_handler = _wrap_tool_handler(handler, name) - super().__init__(name, description, input_schema, wrapped_handler, **kwargs) # type: ignore[call-arg] + super().__init__(name, description, input_schema, wrapped_handler, **kwargs) - __class_getitem__ = classmethod(lambda cls, params: cls) # type: ignore[assignment] + __class_getitem__ = classmethod(lambda cls, params: cls) return WrappedSdkMcpTool @@ -232,7 +232,7 @@ async def wrapped_handler(args: Any) -> Any: finally: active_tool_span.release() - wrapped_handler._braintrust_wrapped = True # type: ignore[attr-defined] + wrapped_handler._braintrust_wrapped = True return wrapped_handler diff --git a/py/src/braintrust/wrappers/claude_agent_sdk/test_wrapper.py b/py/src/braintrust/wrappers/claude_agent_sdk/test_wrapper.py index eb12fa3d..2b9bec74 100644 --- a/py/src/braintrust/wrappers/claude_agent_sdk/test_wrapper.py +++ b/py/src/braintrust/wrappers/claude_agent_sdk/test_wrapper.py @@ -289,7 +289,7 @@ async def test_query_async_iterable(memory_logger, cassette_name, input_factory, wrapped_client_class = _create_client_wrapper_class(FakeClaudeSDKClient) client = wrapped_client_class() - client._WrappedClaudeSDKClient__client.messages = [ # type: ignore[attr-defined] + client._WrappedClaudeSDKClient__client.messages = [ AssistantMessage(content=[TextBlock("done")]), ResultMessage(), ] @@ -537,7 +537,7 @@ async def test_delegated_subagent_llm_and_tool_spans_nest_under_task_span(memory wrapped_client_class = _create_client_wrapper_class(FakeClaudeSDKClient) client = wrapped_client_class() - client._WrappedClaudeSDKClient__client.messages = [ # type: ignore[attr-defined] + client._WrappedClaudeSDKClient__client.messages = [ AssistantMessage( content=[ ToolUseBlock( @@ -618,7 +618,7 @@ async def test_multiple_subagent_orchestration_keeps_outer_agent_tool_calls_outs wrapped_client_class = _create_client_wrapper_class(FakeClaudeSDKClient) client = wrapped_client_class() - client._WrappedClaudeSDKClient__client.messages = [ # type: ignore[attr-defined] + client._WrappedClaudeSDKClient__client.messages = [ AssistantMessage( content=[ TextBlock("Launching the first delegated agent."), @@ -754,7 +754,7 @@ async def test_relay_user_messages_between_parallel_agent_calls_do_not_split_llm wrapped_client_class = _create_client_wrapper_class(FakeClaudeSDKClient) client = wrapped_client_class() - client._WrappedClaudeSDKClient__client.messages = [ # type: ignore[attr-defined] + client._WrappedClaudeSDKClient__client.messages = [ # Orchestrator responds with thinking + text + first Agent call AssistantMessage( content=[ @@ -893,7 +893,7 @@ async def test_agent_tool_spans_encapsulate_child_task_spans(memory_logger): wrapped_client_class = _create_client_wrapper_class(FakeClaudeSDKClient) client = wrapped_client_class() - client._WrappedClaudeSDKClient__client.messages = [ # type: ignore[attr-defined] + client._WrappedClaudeSDKClient__client.messages = [ # Orchestrator responds with text + first Agent call AssistantMessage( content=[ @@ -1209,7 +1209,7 @@ async def test_receive_response_suppresses_cancelled_error_after_messages(memory wrapped_client_class = _create_client_wrapper_class(FakeCancelledClaudeSDKClient) client = wrapped_client_class() - client._WrappedClaudeSDKClient__client.messages = [ # type: ignore[attr-defined] + client._WrappedClaudeSDKClient__client.messages = [ AssistantMessage(content=[TextBlock("The answer is 42.")]), ResultMessage(), ] @@ -1251,7 +1251,7 @@ async def test_receive_response_suppresses_cancelled_error_mid_stream(memory_log wrapped_client_class = _create_client_wrapper_class(FakeCancelledMidStreamClaudeSDKClient) client = wrapped_client_class() - client._WrappedClaudeSDKClient__client.messages = [ # type: ignore[attr-defined] + client._WrappedClaudeSDKClient__client.messages = [ AssistantMessage(content=[TextBlock("Partial answer.")]), # Second message never arrives — CancelledError fires instead. AssistantMessage(content=[TextBlock("This should not be received.")]), @@ -1289,7 +1289,7 @@ async def test_genuine_task_cancel_propagates_after_receive_response(memory_logg wrapped_client_class = _create_client_wrapper_class(FakeCancelledClaudeSDKClient) client = wrapped_client_class() - client._WrappedClaudeSDKClient__client.messages = [ # type: ignore[attr-defined] + client._WrappedClaudeSDKClient__client.messages = [ AssistantMessage(content=[TextBlock("Hello.")]), ResultMessage(), ] diff --git a/py/src/braintrust/wrappers/langsmith_wrapper.py b/py/src/braintrust/wrappers/langsmith_wrapper.py index b22117df..5f5ffe77 100644 --- a/py/src/braintrust/wrappers/langsmith_wrapper.py +++ b/py/src/braintrust/wrappers/langsmith_wrapper.py @@ -158,7 +158,7 @@ def decorator(fn: Callable[P, R]) -> Callable[P, R]: fn = traceable(fn, **kwargs) # Always apply Braintrust tracing - return traced(name=span_name)(fn) # type: ignore[return-value] + return traced(name=span_name)(fn) if func is not None: return decorator(func) @@ -190,7 +190,7 @@ def wrap_client( "evaluate", make_evaluate_wrapper(standalone=standalone, project_name=project_name, project_id=project_id), ) - Client.evaluate._braintrust_patched = True # type: ignore[attr-defined] + Client.evaluate._braintrust_patched = True if hasattr(Client, "aevaluate") and not _is_patched(Client.aevaluate): wrap_function_wrapper( @@ -198,7 +198,7 @@ def wrap_client( "aevaluate", make_aevaluate_wrapper(standalone=standalone, project_name=project_name, project_id=project_id), ) - Client.aevaluate._braintrust_patched = True # type: ignore[attr-defined] + Client.aevaluate._braintrust_patched = True return Client @@ -276,8 +276,8 @@ def wrap_evaluate( return evaluate evaluate_wrapper = make_evaluate_wrapper(standalone=standalone, project_name=project_name, project_id=project_id) - evaluate_wrapper._braintrust_patched = True # type: ignore[attr-defined] - return evaluate_wrapper # type: ignore[return-value] + evaluate_wrapper._braintrust_patched = True + return evaluate_wrapper def wrap_aevaluate( @@ -302,8 +302,8 @@ def wrap_aevaluate( return aevaluate aevaluate_wrapper = make_aevaluate_wrapper(standalone=standalone, project_name=project_name, project_id=project_id) - aevaluate_wrapper._braintrust_patched = True # type: ignore[attr-defined] - return aevaluate_wrapper # type: ignore[return-value] + aevaluate_wrapper._braintrust_patched = True + return aevaluate_wrapper def _is_patched(obj: Any) -> bool: @@ -443,7 +443,7 @@ def load_data() -> Iterator[EvalCase[Any, Any]]: # Determine the source iterable without loading everything into memory source: Iterable[Any] if callable(data): - source = data() # type: ignore + source = data() elif isinstance(data, str): # Load examples from LangSmith dataset by name try: diff --git a/py/src/braintrust/wrappers/litellm.py b/py/src/braintrust/wrappers/litellm.py index 236df998..12d0e25c 100644 --- a/py/src/braintrust/wrappers/litellm.py +++ b/py/src/braintrust/wrappers/litellm.py @@ -3,7 +3,7 @@ import time from collections.abc import AsyncGenerator, Callable, Generator from types import TracebackType -from typing import Any +from typing import Any, cast from braintrust.logger import Span, start_span from braintrust.span_types import SpanTypeAttribute @@ -145,7 +145,7 @@ def completion(self, *args: Any, **kwargs: Any) -> Any: try: start = time.time() - completion_response = self.completion_fn(*args, **kwargs) + completion_response = cast(Any, self.completion_fn)(*args, **kwargs) # if hasattr(completion_response, "parse"): # raw_response = completion_response.parse() # log_headers(completion_response, span) @@ -175,7 +175,7 @@ async def acompletion(self, *args: Any, **kwargs: Any) -> Any: try: start = time.time() - completion_response = await self.acompletion_fn(*args, **kwargs) + completion_response = await cast(Any, self.acompletion_fn)(*args, **kwargs) # if hasattr(completion_response, "parse"): # raw_response = completion_response.parse() @@ -323,7 +323,7 @@ def responses(self, *args: Any, **kwargs: Any) -> Any: try: start = time.time() - response = self.responses_fn(*args, **kwargs) + response = cast(Any, self.responses_fn)(*args, **kwargs) if is_streaming: should_end = False @@ -346,7 +346,7 @@ async def aresponses(self, *args: Any, **kwargs: Any) -> Any: try: start = time.time() - response = await self.aresponses_fn(*args, **kwargs) + response = await cast(Any, self.aresponses_fn)(*args, **kwargs) if is_streaming: should_end = False diff --git a/py/src/braintrust/wrappers/test_openai.py b/py/src/braintrust/wrappers/test_openai.py index 6ab9b343..bd729936 100644 --- a/py/src/braintrust/wrappers/test_openai.py +++ b/py/src/braintrust/wrappers/test_openai.py @@ -1,5 +1,6 @@ import asyncio import time +from typing import Any, cast import braintrust import openai @@ -47,11 +48,11 @@ def export(self): with braintrust.start_span(name="parent-span") as parent_span: assert braintrust.current_span() == parent_span - processor.on_trace_start(trace) + processor.on_trace_start(cast(Any, trace)) created_span = processor._spans[trace.trace_id] assert braintrust.current_span() == created_span - processor.on_trace_end(trace) + processor.on_trace_end(cast(Any, trace)) assert braintrust.current_span() == parent_span spans = memory_logger.pop() @@ -109,7 +110,7 @@ def test_openai_responses_metrics(memory_logger): assert unwrapped_response assert unwrapped_response.output assert len(unwrapped_response.output) > 0 - unwrapped_content = unwrapped_response.output[0].content[0].text + unwrapped_content = cast(Any, unwrapped_response.output[0]).content[0].text # No spans should be generated with unwrapped client assert not memory_logger.pop() @@ -128,7 +129,7 @@ def test_openai_responses_metrics(memory_logger): # Extract content from output field assert response.output assert len(response.output) > 0 - wrapped_content = response.output[0].content[0].text + wrapped_content = cast(Any, response.output[0]).content[0].text # Both should contain a numeric response for the math question assert "24" in unwrapped_content or "twenty-four" in unwrapped_content.lower() @@ -479,7 +480,7 @@ def test_openai_chat_with_system_prompt(memory_logger): assert response assert response.choices - assert "24" in response.choices[0].message.content + assert "24" in cast(str, response.choices[0].message.content) if not is_wrapped: assert not memory_logger.pop() @@ -620,7 +621,7 @@ async def test_openai_responses_async(memory_logger): assert len(resp.output) > 0 # Extract the text from the output - content = resp.output[0].content[0].text + content = cast(Any, resp.output[0]).content[0].text # Verify response contains correct answer assert "24" in content or "twenty-four" in content.lower() @@ -793,7 +794,7 @@ async def test_openai_chat_async_with_system_prompt(memory_logger): assert response assert response.choices - assert "24" in response.choices[0].message.content + assert "24" in cast(str, response.choices[0].message.content) if not is_wrapped: assert not memory_logger.pop() @@ -1000,7 +1001,7 @@ async def test_openai_response_streaming_async(memory_logger): stream = await client.responses.create(model=TEST_MODEL, input="What's 12 + 12?", stream=True) chunks = [] - async for chunk in stream: + async for chunk in cast(Any, stream): if chunk.type == "response.output_text.delta": chunks.append(chunk.delta) end = time.time() @@ -1133,7 +1134,7 @@ def test_openai_responses_not_given_filtering(memory_logger): assert response assert response.output assert len(response.output) > 0 - content = response.output[0].content[0].text + content = cast(Any, response.output[0]).content[0].text assert "24" in content or "twenty-four" in content.lower() # Check the logged span @@ -1225,7 +1226,7 @@ def test_openai_responses_with_raw_response_create(memory_logger): instructions="Just the number please", ) assert raw.headers # HTTP response headers are accessible - response = raw.parse() + response = cast(Any, raw.parse()) assert response.output content = response.output[0].content[0].text assert "24" in content or "twenty-four" in content.lower() @@ -1245,7 +1246,7 @@ def test_openai_responses_with_raw_response_create(memory_logger): assert raw.headers response = raw.parse() assert response.output - content = response.output[0].content[0].text + content = cast(Any, response.output[0]).content[0].text assert "24" in content or "twenty-four" in content.lower() # A span must have been recorded with correct metrics and metadata. @@ -1276,7 +1277,7 @@ def test_openai_responses_with_raw_response_create_stream(memory_logger): ) assert raw.headers chunks = [] - for chunk in raw.parse(): + for chunk in cast(Any, raw.parse()): if chunk.type == "response.output_text.delta": chunks.append(chunk.delta) assert "24" in "".join(chunks) or "twenty-four" in "".join(chunks).lower() @@ -1364,7 +1365,7 @@ async def test_openai_responses_with_raw_response_async(memory_logger): instructions="Just the number please", ) assert raw.headers - response = raw.parse() + response = cast(Any, raw.parse()) assert response.output content = response.output[0].content[0].text assert "24" in content or "twenty-four" in content.lower() @@ -1382,7 +1383,7 @@ async def test_openai_responses_with_raw_response_async(memory_logger): assert raw.headers response = raw.parse() assert response.output - content = response.output[0].content[0].text + content = cast(Any, response.output[0]).content[0].text assert "24" in content or "twenty-four" in content.lower() spans = memory_logger.pop() @@ -1412,7 +1413,7 @@ async def test_openai_responses_with_raw_response_create_stream_async(memory_log ) assert raw.headers chunks = [] - async for chunk in raw.parse(): + async for chunk in cast(Any, raw.parse()): if chunk.type == "response.output_text.delta": chunks.append(chunk.delta) assert "24" in "".join(chunks) or "twenty-four" in "".join(chunks).lower() @@ -1487,7 +1488,7 @@ def test_openai_parallel_tool_calls(memory_logger): for client in clients: start = time.time() - resp = client.chat.completions.create( + resp = cast(Any, client).chat.completions.create( model=TEST_MODEL, messages=[{"role": "user", "content": "What's the weather in New York and the time in Tokyo?"}], tools=tools, @@ -1498,7 +1499,7 @@ def test_openai_parallel_tool_calls(memory_logger): if stream: # Consume the stream - for chunk in resp: # type: ignore + for chunk in resp: # Exhaust the stream pass @@ -1912,8 +1913,8 @@ def export(self): trace = MockTrace("test-trace", "Test Trace", {"conversation_id": "test-12345"}) # Execute trace lifecycle - processor.on_trace_start(trace) - processor.on_trace_end(trace) + processor.on_trace_start(cast(Any, trace)) + processor.on_trace_end(cast(Any, trace)) # Verify metadata was logged to root span spans = memory_logger.pop() From 8b056186b2aebcce857ca7fe1f7c6904b02ba623 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?C=C3=A9dric=20Halber?= Date: Fri, 20 Mar 2026 22:50:07 +0000 Subject: [PATCH 3/3] chore: the agent added cast everywhere, removing them --- py/src/braintrust/cli/install/api.py | 23 +++---- py/src/braintrust/devserver/server.py | 32 ++++----- py/src/braintrust/framework.py | 34 +++++----- py/src/braintrust/logger.py | 70 +++++++++++--------- py/src/braintrust/oai.py | 14 ++-- py/src/braintrust/otel/__init__.py | 31 +++++---- py/src/braintrust/otel/context.py | 16 ++--- py/src/braintrust/parameters.py | 25 ++++--- py/src/braintrust/serializable_data_class.py | 12 ++-- py/src/braintrust/span_identifier_v3.py | 8 +-- py/src/braintrust/span_identifier_v4.py | 10 +-- py/src/braintrust/test_bt_json.py | 8 +-- py/src/braintrust/test_logger.py | 8 +-- py/src/braintrust/wrappers/litellm.py | 10 +-- py/src/braintrust/wrappers/test_openai.py | 38 +++++------ 15 files changed, 164 insertions(+), 175 deletions(-) diff --git a/py/src/braintrust/cli/install/api.py b/py/src/braintrust/cli/install/api.py index fda1a507..a34f5d59 100644 --- a/py/src/braintrust/cli/install/api.py +++ b/py/src/braintrust/cli/install/api.py @@ -378,11 +378,9 @@ def main(args): _logger.info(f"Stack with name {args.name} has been created with status: {status['StackStatus']}") exit(0) - from typing import cast - status_any = cast(Any, status) - _logger.info(f"Stack with name {args.name} has status: {status_any['StackStatus']}") + _logger.info(f"Stack with name {args.name} has status: {status['StackStatus']}") - if not ("_COMPLETE" in status_any["StackStatus"] or "_FAILED" in status_any["StackStatus"]): + if not ("_COMPLETE" in status["StackStatus"] or "_FAILED" in status["StackStatus"]): _logger.info(f"Please re-run this command once the stack has finished creating or updating") exit(0) @@ -403,7 +401,7 @@ def main(args): new_template = cloudformation.get_template_summary(TemplateURL=template) new_params = set(x["ParameterKey"] for x in new_template["Parameters"]) else: - new_params = set(x["ParameterKey"] for x in cast(Any, status)["Parameters"]) + new_params = set(x["ParameterKey"] for x in status["Parameters"]) stack = cloudformation.describe_stacks(StackName=args.name)["Stacks"][0] try: @@ -487,17 +485,16 @@ def main(args): if len(org_info) == 1: org_info = org_info[0] - org_info_any = cast(Any, org_info) - if org_info_any and (universal_url and org_info_any["api_url"] != universal_url): + if org_info and (universal_url and org_info["api_url"] != universal_url): if args.update_stack_url: - _logger.info(f"Will update org {org_info_any['name']}'s urls.") + _logger.info(f"Will update org {org_info['name']}'s urls.") _logger.info(f" They are currently set to:") - _logger.info(f" API URL: {org_info_any['api_url']}") - _logger.info(f" Proxy URL: {org_info_any['proxy_url']}") + _logger.info(f" API URL: {org_info['api_url']}") + _logger.info(f" Proxy URL: {org_info['proxy_url']}") _logger.info(f"And will update them to:") - patch_args = {"id": org_info_any["id"]} - if universal_url and org_info_any["api_url"] != universal_url: + patch_args = {"id": org_info["id"]} + if universal_url and org_info["api_url"] != universal_url: patch_args["api_url"] = universal_url patch_args["is_universal_api"] = True _logger.info(f" API URL: {universal_url}") @@ -514,6 +511,6 @@ def main(args): ) else: _logger.info(f"Stack URL differs from organization API URL:") - _logger.info(f" Current API URL: {org_info_any['api_url']}") + _logger.info(f" Current API URL: {org_info['api_url']}") _logger.info(f" Stack Universal URL: {universal_url}") _logger.info(f"To update the organization's API URL, rerun with --update-stack-url flag") diff --git a/py/src/braintrust/devserver/server.py b/py/src/braintrust/devserver/server.py index 457a5389..4b3ff79a 100644 --- a/py/src/braintrust/devserver/server.py +++ b/py/src/braintrust/devserver/server.py @@ -2,7 +2,7 @@ import json import sys import textwrap -from typing import Any, cast +from typing import Any try: @@ -183,21 +183,17 @@ async def run_eval(request: Request) -> JSONResponse | StreamingResponse: async def task(input: Any, hooks: EvalHooks[Any]): task_hooks = hooks if validated_parameters is None else _ParameterOverrideHooks(hooks, validated_parameters) - task_hooks_typed = cast(EvalHooks[Any], task_hooks) if bt_iscoroutinefunction(evaluator.task): - result = await cast(Any, evaluator.task)(input, task_hooks_typed) + result = await evaluator.task(input, task_hooks) else: - result = cast(Any, evaluator.task)(input, task_hooks_typed) + result = evaluator.task(input, task_hooks) task_hooks.report_progress( - cast( - Any, - { - "format": "code", - "output_type": "completion", - "event": "json_delta", - "data": json.dumps(result), - }, - ) + { + "format": "code", + "output_type": "completion", + "event": "json_delta", + "data": json.dumps(result), + } ) return result @@ -216,7 +212,7 @@ def stream_fn(event: SSEProgressEvent): parent = eval_data.get("parent") if parent: - parent = parse_parent(cast(str | dict | None, parent)) + parent = parse_parent(parent) eval_kwargs = { k: v for (k, v) in evaluator.__dict__.items() if k not in ["eval_name", "project_name", "parameter_values"] @@ -226,14 +222,14 @@ def stream_fn(event: SSEProgressEvent): try: eval_task = asyncio.create_task( - cast(Any, EvalAsync)( + EvalAsync( name=eval_data["name"], **{ **eval_kwargs, "state": state, "scores": evaluator.scores + [ - make_scorer(state, cast(str, score["name"]), cast(Any, score["function_id"]), ctx.project_id) + make_scorer(state, score["name"], score["function_id"], ctx.project_id) for score in eval_data.get("scores", []) ], "stream": stream_fn, @@ -314,8 +310,8 @@ def create_app(evaluators: list[Evaluator[Any, Any]], org_name: str | None = Non app = Starlette(routes=routes) # Add middlewares in reverse order (last added is executed first) - cast(Any, app).add_middleware(CheckAuthorizedMiddleware, allowed_org_name=org_name) - cast(Any, app).add_middleware(AuthorizationMiddleware) + app.add_middleware(CheckAuthorizedMiddleware, allowed_org_name=org_name) + app.add_middleware(AuthorizationMiddleware) app.add_middleware(create_cors_middleware()) return app diff --git a/py/src/braintrust/framework.py b/py/src/braintrust/framework.py index 3871f667..69be4d93 100644 --- a/py/src/braintrust/framework.py +++ b/py/src/braintrust/framework.py @@ -20,7 +20,6 @@ Optional, TypeVar, Union, - cast, ) from tqdm.asyncio import tqdm as async_tqdm @@ -740,7 +739,7 @@ async def make_empty_summary(): "Must specify a reporter object, not a name. Can only specify reporter names when running 'braintrust eval'" ) - reporter = cast(Any, reporter or default_reporter) + reporter = reporter or default_reporter if base_experiment_name is None and isinstance(evaluator.data, BaseExperiment): base_experiment_name = evaluator.data.name @@ -1384,9 +1383,9 @@ async def await_or_run_scorer(root_span, scorer, name, **kwargs): raise ValueError( f"When returning an array of scores, each score must be a valid Score object. Got: {s}" ) - result = cast(list[Score], list(result)) + result = list(result) elif is_score(result): - result = [cast(Score, result)] + result = [result] else: result = [Score(name=name, score=result)] @@ -1452,10 +1451,10 @@ async def run_evaluator_task(datum, trial_index=0): ) if experiment: - root_span = cast(Any, experiment).start_span(**base_event) + root_span = experiment.start_span(**base_event) else: # In most cases this will be a no-op span, but if the parent is set, it will use that ctx. - root_span = cast(Any, start_span)(state=state, **base_event) + root_span = start_span(state=state, **base_event) with root_span: try: @@ -1464,15 +1463,12 @@ def report_progress(event: TaskProgressEvent): if not stream: return stream( - cast( - Any, - SSEProgressEvent( - id=root_span.id, - origin=cast(Any, origin), - name=evaluator.eval_name, - object_type="task", - **event, - ), + SSEProgressEvent( + id=root_span.id, + origin=origin, + name=evaluator.eval_name, + object_type="task", + **event, ) ) @@ -1480,9 +1476,9 @@ def report_progress(event: TaskProgressEvent): metadata, expected=datum.expected, trial_index=trial_index, - tags=cast(Any, tags), + tags=tags, report_progress=report_progress, - parameters=cast(Any, resolved_evaluator_parameters), + parameters=resolved_evaluator_parameters, ) # Check if the task takes a hooks argument @@ -1615,7 +1611,7 @@ async def ensure_spans_flushed(): ) except Exception as e: exc_type, exc_value, tb = sys.exc_info() - root_span.log(error=stringify_exception(cast(Any, exc_type), cast(Any, exc_value), tb)) + root_span.log(error=stringify_exception(exc_type, exc_value, tb)) error = e # Python3.10 has a different set of arguments to format_exception than earlier versions, @@ -1626,7 +1622,7 @@ async def ensure_spans_flushed(): input=datum.input, expected=datum.expected, metadata=metadata, - tags=cast(Any, tags), + tags=tags, output=output, scores={ **( diff --git a/py/src/braintrust/logger.py b/py/src/braintrust/logger.py index c12b9868..19bdeff0 100644 --- a/py/src/braintrust/logger.py +++ b/py/src/braintrust/logger.py @@ -68,7 +68,17 @@ from .merge_row_batch import batch_items, merge_row_batch from .object import DEFAULT_IS_LEGACY_DATASET, ensure_dataset_record from .parameters import RemoteEvalParameters -from .prompt import BRAINTRUST_PARAMS, ImagePart, PromptBlockData, PromptChatBlock, PromptCompletionBlock, PromptData, PromptMessage, PromptSchema, TextPart +from .prompt import ( + BRAINTRUST_PARAMS, + ImagePart, + PromptBlockData, + PromptChatBlock, + PromptCompletionBlock, + PromptData, + PromptMessage, + PromptSchema, + TextPart, +) from .prompt_cache.disk_cache import DiskCache from .prompt_cache.lru_cache import LRUCache from .prompt_cache.parameters_cache import ParametersCache @@ -1039,7 +1049,9 @@ def __init__(self, api_conn: LazyValue[HTTPConnection]): except: self.queue_drop_logging_period = 60 - self._queue_drop_logging_state: dict[str, Any] = dict(lock=threading.Lock(), num_dropped=0, last_logged_timestamp=0) + self._queue_drop_logging_state: dict[str, Any] = dict( + lock=threading.Lock(), num_dropped=0, last_logged_timestamp=0 + ) try: self.failed_publish_payloads_dir = os.environ["BRAINTRUST_FAILED_PUBLISH_PAYLOADS_DIR"] @@ -1690,7 +1702,7 @@ def compute_metadata(): merged_git_metadata_settings = state.git_metadata_settings if git_metadata_settings is not None: merged_git_metadata_settings = GitMetadataSettings.merge( - cast(Any, merged_git_metadata_settings), git_metadata_settings + merged_git_metadata_settings, git_metadata_settings ) repo_info_arg = get_repo_info(merged_git_metadata_settings) @@ -1707,14 +1719,13 @@ def compute_metadata(): if dataset is not None: if isinstance(dataset, dict): # Simple {"id": ..., "version": ...} dict - dataset_dict = cast(dict[str, Any], dataset) - args["dataset_id"] = dataset_dict["id"] - if "version" in dataset_dict: - args["dataset_version"] = dataset_dict["version"] + args["dataset_id"] = dataset["id"] + if "version" in dataset: + args["dataset_version"] = dataset["version"] else: # Full Dataset object - args["dataset_id"] = cast(Any, dataset).id - args["dataset_version"] = cast(Any, dataset).version + args["dataset_id"] = dataset.id + args["dataset_version"] = dataset.version parameters_ref = _get_parameters_ref(parameters) if parameters_ref is not None: @@ -1843,17 +1854,17 @@ def _compute_logger_metadata(project_name: str | None = None, project_id: str | ) resp_project = response["project"] return OrgProjectMetadata( - org_id=cast(str, org_id), + org_id=org_id, project=ObjectMetadata(id=resp_project["id"], name=resp_project["name"], full_info=resp_project), ) elif project_name is None: response = _state.app_conn().get_json("api/project", {"id": project_id}) return OrgProjectMetadata( - org_id=cast(str, org_id), project=ObjectMetadata(id=project_id, name=response["name"], full_info=response) + org_id=org_id, project=ObjectMetadata(id=project_id, name=response["name"], full_info=response) ) else: return OrgProjectMetadata( - org_id=cast(str, org_id), project=ObjectMetadata(id=project_id, name=project_name, full_info=dict()) + org_id=org_id, project=ObjectMetadata(id=project_id, name=project_name, full_info=dict()) ) @@ -2587,9 +2598,9 @@ def wrapper_sync_gen(*f_args, **f_kwargs): # We determine if the decorator is invoked bare or with arguments by # checking if the first positional argument to the decorator is a callable. if len(span_args) == 1 and len(span_kwargs) == 0 and callable(span_args[0]): - return cast(Any, decorator)(span_args[1:], span_kwargs, span_args[0]) + return decorator(span_args[1:], span_kwargs, cast(F, span_args[0])) else: - return cast(Any, partial(decorator, span_args, span_kwargs)) + return cast(Callable[[F], F], partial(decorator, span_args, span_kwargs)) def start_span( @@ -3139,7 +3150,7 @@ def error_wrapper() -> AttachmentStatus: status["error_message"] = str(e) request_params = { - "key": cast(Any, self._reference)["key"], + "key": self._reference["key"], "org_id": org_id, "status": status, } @@ -3687,8 +3698,8 @@ def compute_parent_object_id(): arg_parent_object_id = LazyValue(compute_parent_object_id, use_mutex=False) if parent_components.row_id: arg_parent_span_ids = ParentSpanIds( - span_id=cast(str, parent_components.span_id), - root_span_id=cast(str, parent_components.root_span_id), + span_id=parent_components.span_id, + root_span_id=parent_components.root_span_id, ) else: arg_parent_span_ids = None @@ -3741,17 +3752,14 @@ def __next__(self) -> _ExperimentDatasetEvent: continue output, expected = value.get("output"), value.get("expected") - ret: _ExperimentDatasetEvent = cast( - Any, - { + ret: _ExperimentDatasetEvent = { "input": value.get("input"), "expected": expected if expected is not None else output, "tags": value.get("tags"), "metadata": value.get("metadata"), "id": value["id"], "_xact_id": value["_xact_id"], - }, - ) + } return ret @@ -3985,7 +3993,7 @@ def summarize( self.flush() state = self._get_state() - project_url = f"{state.app_public_url}/app/{encode_uri_component(cast(str, state.org_name))}/p/{encode_uri_component(self.project.name)}" + project_url = f"{state.app_public_url}/app/{encode_uri_component(state.org_name)}/p/{encode_uri_component(self.project.name)}" experiment_url = f"{project_url}/experiments/{encode_uri_component(self.name)}" score_summary = {} @@ -4828,7 +4836,7 @@ def summarize(self, summarize_data: bool = True) -> "DatasetSummary": # includes the new experiment. self.flush() state = self._get_state() - project_url = f"{state.app_public_url}/app/{encode_uri_component(cast(str, state.org_name))}/p/{encode_uri_component(self.project.name)}" + project_url = f"{state.app_public_url}/app/{encode_uri_component(state.org_name)}/p/{encode_uri_component(self.project.name)}" dataset_url = f"{project_url}/datasets/{encode_uri_component(self.name)}" data_summary = None @@ -5067,7 +5075,7 @@ def from_prompt_data( @property def id(self) -> str: - return cast(str, self._lazy_metadata.get().id) + return self._lazy_metadata.get().id @property def name(self) -> str: @@ -5083,11 +5091,11 @@ def prompt(self) -> PromptBlockData | None: @property def version(self) -> str: - return cast(str, self._lazy_metadata.get()._xact_id) + return self._lazy_metadata.get()._xact_id @property def options(self) -> PromptOptions: - return cast(Any, self._lazy_metadata.get().prompt_data.options or {}) + return self._lazy_metadata.get().prompt_data.options or {} # Capture all metadata attributes which aren't covered by existing methods. def __getattr__(self, name: str) -> Any: @@ -5165,11 +5173,11 @@ def __len__(self) -> int: def __getitem__(self, x): if x == "prompt": - return cast(Any, self.prompt).prompt + return self.prompt.prompt elif x == "chat": - return cast(Any, self.prompt).messages + return self.prompt.messages elif x == "tools": - return cast(Any, self.prompt).tools + return self.prompt.tools else: return self.options[x] @@ -5202,7 +5210,7 @@ def lazy_init(self): @property def id(self) -> str: self.lazy_init() - return cast(str, self._id) + return self._id @property def name(self): diff --git a/py/src/braintrust/oai.py b/py/src/braintrust/oai.py index 09de540b..2e9c0895 100644 --- a/py/src/braintrust/oai.py +++ b/py/src/braintrust/oai.py @@ -4,7 +4,7 @@ import time import warnings from collections.abc import Callable -from typing import Any, cast +from typing import Any from wrapt import wrap_function_wrapper @@ -160,7 +160,7 @@ def create(self, *args: Any, **kwargs: Any) -> Any: try: start = time.time() - create_response = cast(Callable[..., Any], self.create_fn)(*args, **kwargs) + create_response = self.create_fn(*args, **kwargs) if hasattr(create_response, "parse"): raw_response = create_response.parse() log_headers(create_response, span) @@ -213,7 +213,7 @@ async def acreate(self, *args: Any, **kwargs: Any) -> Any: try: start = time.time() - create_response = await cast(Callable[..., Any], self.acreate_fn)(*args, **kwargs) + create_response = await self.acreate_fn(*args, **kwargs) if hasattr(create_response, "parse"): raw_response = create_response.parse() @@ -415,7 +415,7 @@ def create(self, *args: Any, **kwargs: Any) -> Any: try: start = time.time() - create_response = cast(Callable[..., Any], self.create_fn)(*args, **kwargs) + create_response = self.create_fn(*args, **kwargs) if hasattr(create_response, "parse"): raw_response = create_response.parse() log_headers(create_response, span) @@ -467,7 +467,7 @@ async def acreate(self, *args: Any, **kwargs: Any) -> Any: try: start = time.time() - create_response = await cast(Callable[..., Any], self.acreate_fn)(*args, **kwargs) + create_response = await self.acreate_fn(*args, **kwargs) if hasattr(create_response, "parse"): raw_response = create_response.parse() log_headers(create_response, span) @@ -656,7 +656,7 @@ def create(self, *args: Any, **kwargs: Any) -> Any: with start_span( **merge_dicts(dict(name=self._name, span_attributes={"type": SpanTypeAttribute.LLM}), params) ) as span: - create_response = cast(Callable[..., Any], self._create_fn)(*args, **kwargs) + create_response = self._create_fn(*args, **kwargs) if hasattr(create_response, "parse"): raw_response = create_response.parse() log_headers(create_response, span) @@ -673,7 +673,7 @@ async def acreate(self, *args: Any, **kwargs: Any) -> Any: with start_span( **merge_dicts(dict(name=self._name, span_attributes={"type": SpanTypeAttribute.LLM}), params) ) as span: - create_response = await cast(Callable[..., Any], self._acreate_fn)(*args, **kwargs) + create_response = await self._acreate_fn(*args, **kwargs) if hasattr(create_response, "parse"): raw_response = create_response.parse() log_headers(create_response, span) diff --git a/py/src/braintrust/otel/__init__.py b/py/src/braintrust/otel/__init__.py index 12d2cb68..ec7212df 100644 --- a/py/src/braintrust/otel/__init__.py +++ b/py/src/braintrust/otel/__init__.py @@ -1,7 +1,7 @@ import logging import os import warnings -from typing import Any, cast +from typing import Any from urllib.parse import urljoin @@ -331,7 +331,7 @@ def _get_parent_otel_braintrust_parent(self, parent_context): if current_span and hasattr(current_span, "attributes") and current_span.attributes: # Check if parent span has braintrust.parent attribute - attributes = dict(cast(Any, current_span.attributes)) + attributes = dict(current_span.attributes) return attributes.get("braintrust.parent") return None @@ -441,8 +441,8 @@ def context_from_span_export(export_str: str): ) # Convert hex strings to OTEL integers - trace_id_int = int(cast(str, components.root_span_id), 16) - span_id_int = int(cast(str, components.span_id), 16) + trace_id_int = int(components.root_span_id, 16) + span_id_int = int(components.span_id, 16) # Create OTEL SpanContext marked as remote span_context = SpanContext( @@ -632,38 +632,37 @@ def parent_from_headers(headers: dict[str, str], propagator=None) -> str | None: return None if braintrust_parent: - braintrust_parent_str = cast(str, braintrust_parent) from braintrust.span_identifier_v3 import SpanObjectTypeV3 # Parse braintrust.parent format: "project_id:abc", "project_name:xyz", or "experiment_id:123" - if braintrust_parent_str.startswith("project_id:"): + if braintrust_parent.startswith("project_id:"): object_type = SpanObjectTypeV3.PROJECT_LOGS - object_id = braintrust_parent_str[len("project_id:") :] + object_id = braintrust_parent[len("project_id:"):] if not object_id: logging.error( - f"parent_from_headers: Invalid braintrust.parent format (empty project_id): {braintrust_parent_str}" + f"parent_from_headers: Invalid braintrust.parent format (empty project_id): {braintrust_parent}" ) return None - elif braintrust_parent_str.startswith("project_name:"): + elif braintrust_parent.startswith("project_name:"): object_type = SpanObjectTypeV3.PROJECT_LOGS - project_name = braintrust_parent_str[len("project_name:") :] + project_name = braintrust_parent[len("project_name:"):] if not project_name: logging.error( - f"parent_from_headers: Invalid braintrust.parent format (empty project_name): {braintrust_parent_str}" + f"parent_from_headers: Invalid braintrust.parent format (empty project_name): {braintrust_parent}" ) return None compute_args = {"project_name": project_name} - elif braintrust_parent_str.startswith("experiment_id:"): + elif braintrust_parent.startswith("experiment_id:"): object_type = SpanObjectTypeV3.EXPERIMENT - object_id = braintrust_parent_str[len("experiment_id:") :] + object_id = braintrust_parent[len("experiment_id:"):] if not object_id: logging.error( - f"parent_from_headers: Invalid braintrust.parent format (empty experiment_id): {braintrust_parent_str}" + f"parent_from_headers: Invalid braintrust.parent format (empty experiment_id): {braintrust_parent}" ) return None else: logging.error( - f"parent_from_headers: Invalid braintrust.parent format: {braintrust_parent_str}. " + f"parent_from_headers: Invalid braintrust.parent format: {braintrust_parent}. " "Expected format: 'project_id:ID', 'project_name:NAME', or 'experiment_id:ID'" ) return None @@ -671,7 +670,7 @@ def parent_from_headers(headers: dict[str, str], propagator=None) -> str | None: # Create SpanComponentsV4 and export as string # Set row_id to enable span_id/root_span_id (required for parent linking) components = SpanComponentsV4( - object_type=cast(Any, object_type), + object_type=object_type, object_id=object_id, compute_object_metadata_args=compute_args, row_id="otel", # Dummy row_id to enable span_id/root_span_id fields diff --git a/py/src/braintrust/otel/context.py b/py/src/braintrust/otel/context.py index 3dea2fdc..bb65be77 100644 --- a/py/src/braintrust/otel/context.py +++ b/py/src/braintrust/otel/context.py @@ -1,7 +1,7 @@ """Unified context management using OTEL's built-in context.""" import logging -from typing import Any, Optional, cast +from typing import Any, Optional from braintrust.context import ParentSpanIds, SpanInfo from braintrust.logger import Span @@ -39,8 +39,7 @@ def get_current_span_info(self) -> Optional["SpanInfo"]: # If there's a BT span stored AND the current OTEL span is a NonRecordingSpan # (which means it's our BT->OTEL wrapper), then return BT span info if bt_span and isinstance(current_span, trace.NonRecordingSpan): - bt_span_any = cast(Any, bt_span) - return SpanInfo(trace_id=bt_span_any.root_span_id, span_id=bt_span_any.span_id, span_object=bt_span) + return SpanInfo(trace_id=bt_span.root_span_id, span_id=bt_span.span_id, span_object=bt_span) else: # Return OTEL span info - this is a real OTEL span, not our wrapper otel_trace_id = format(span_context.trace_id, "032x") @@ -56,23 +55,22 @@ def set_current_span(self, span: Span) -> Any: # This is an OTEL span - it will manage its own context return None else: - bt_span = cast(Any, span) try: - trace_id_int = int(bt_span.root_span_id, 16) + trace_id_int = int(span.root_span_id, 16) except ValueError: - log.debug(f"Invalid root_span_id: {bt_span.root_span_id}") + log.debug(f"Invalid root_span_id: {span.root_span_id}") return None try: - span_id_int = int(bt_span.span_id, 16) + span_id_int = int(span.span_id, 16) except ValueError: - log.debug(f"Invalid span_id: {bt_span.span_id}") + log.debug(f"Invalid span_id: {span.span_id}") return None # This is a BT span - store it in OTEL context AND set as current OTEL span # First store the BT span ctx = context.set_value("braintrust_span", span) - parent_value = bt_span._get_otel_parent() + parent_value = span._get_otel_parent() ctx = context.set_value("braintrust.parent", parent_value, ctx) otel_span_context = SpanContext( diff --git a/py/src/braintrust/parameters.py b/py/src/braintrust/parameters.py index dd2581d3..c38296d4 100644 --- a/py/src/braintrust/parameters.py +++ b/py/src/braintrust/parameters.py @@ -2,7 +2,7 @@ from copy import deepcopy from dataclasses import dataclass -from typing import TYPE_CHECKING, Any, Literal, TypedDict, cast +from typing import TYPE_CHECKING, Any, Literal, TypedDict from jsonschema import Draft7Validator from jsonschema.exceptions import ValidationError as JSONSchemaValidationError @@ -151,7 +151,7 @@ def _resolve_local_json_schema_refs( def _serialize_pydantic_parameter_schema(schema: Any) -> dict[str, Any]: schema_json = _pydantic_to_json_schema(schema) - schema_json = cast(dict[str, Any], _resolve_local_json_schema_refs(schema_json, schema_json)) + schema_json = _resolve_local_json_schema_refs(schema_json, schema_json) schema_json.pop("$defs", None) schema_json.pop("definitions", None) fields = _get_pydantic_fields(schema) @@ -279,29 +279,28 @@ def _validate_local_parameters( elif schema is None: result[name] = value elif _is_pydantic_model(schema): - schema_cls = cast(Any, schema) fields = _get_pydantic_fields(schema) if len(fields) == 1 and "value" in fields: if value is None: try: - default_instance = schema_cls() + default_instance = schema() result[name] = default_instance.value except Exception as exc: raise ValueError(f"Parameter '{name}' is required") from exc elif hasattr(schema, "parse_obj"): - result[name] = schema_cls.parse_obj({"value": value}).value + result[name] = schema.parse_obj({"value": value}).value else: - result[name] = schema_cls.model_validate({"value": value}).value + result[name] = schema.model_validate({"value": value}).value else: if value is None: try: - result[name] = schema_cls() + result[name] = schema() except Exception as exc: raise ValueError(f"Parameter '{name}' is required") from exc elif hasattr(schema, "parse_obj"): - result[name] = schema_cls.parse_obj(value) + result[name] = schema.parse_obj(value) else: - result[name] = schema_cls.model_validate(value) + result[name] = schema.model_validate(value) else: result[name] = value except JSONSchemaValidationError as exc: @@ -417,10 +416,10 @@ def parameters_to_json_schema(parameters: EvalParameters) -> ParametersSchema: property_schema["description"] = schema["description"] properties[name] = property_schema elif _is_model_parameter(schema): - property_schema = cast(dict[str, Any], { - "type": "string", - "x-bt-type": "model", - }) + property_schema = { + "type": "string", + "x-bt-type": "model", + } if "default" in schema: property_schema["default"] = schema.get("default") if schema.get("description") is not None: diff --git a/py/src/braintrust/serializable_data_class.py b/py/src/braintrust/serializable_data_class.py index 060d7221..044cdfab 100644 --- a/py/src/braintrust/serializable_data_class.py +++ b/py/src/braintrust/serializable_data_class.py @@ -1,12 +1,12 @@ import dataclasses import json -from typing import Any, Union, cast, get_origin +from typing import Any, Union, get_origin class SerializableDataClass: def as_dict(self): """Serialize the object to a dictionary.""" - return dataclasses.asdict(cast(Any, self)) + return dataclasses.asdict(self) def as_json(self, **kwargs): """Serialize the object to JSON.""" @@ -33,12 +33,8 @@ def from_dict_deep(cls, d: dict): if k not in fields: continue - field_type = cast(Any, fields[k].type) - if ( - isinstance(v, dict) - and isinstance(field_type, type) - and issubclass(field_type, SerializableDataClass) - ): + field_type = fields[k].type + if isinstance(v, dict) and isinstance(field_type, type) and issubclass(field_type, SerializableDataClass): filtered[k] = field_type.from_dict_deep(v) elif get_origin(field_type) == Union: for t in field_type.__args__: diff --git a/py/src/braintrust/span_identifier_v3.py b/py/src/braintrust/span_identifier_v3.py index 7ad600ea..01ec0bab 100644 --- a/py/src/braintrust/span_identifier_v3.py +++ b/py/src/braintrust/span_identifier_v3.py @@ -6,7 +6,7 @@ import dataclasses import json from enum import Enum -from typing import Any, cast +from typing import Any from uuid import UUID from .span_identifier_v2 import SpanComponentsV2 @@ -218,7 +218,7 @@ def _from_json_obj(json_obj: dict) -> "SpanComponentsV3": **json_obj, "object_type": SpanObjectTypeV3(json_obj["object_type"]), } - return SpanComponentsV3(**cast(Any, kwargs)) + return SpanComponentsV3(**kwargs) def parse_parent(parent: str | dict | None) -> str | None: @@ -244,7 +244,7 @@ def parse_parent(parent: str | dict | None) -> str | None: "project_logs": SpanObjectTypeV3.PROJECT_LOGS, } - object_type = object_type_map.get(cast(str, parent.get("object_type"))) + object_type = object_type_map.get(parent.get("object_type")) if not object_type: raise ValueError(f"Invalid object_type: {parent.get('object_type')}") @@ -276,6 +276,6 @@ def parse_parent(parent: str | dict | None) -> str | None: if "propagated_event" in parent: kwargs["propagated_event"] = parent.get("propagated_event") - return SpanComponentsV3(**cast(Any, kwargs)).to_str() + return SpanComponentsV3(**kwargs).to_str() else: return None diff --git a/py/src/braintrust/span_identifier_v4.py b/py/src/braintrust/span_identifier_v4.py index 39106824..49b37f06 100644 --- a/py/src/braintrust/span_identifier_v4.py +++ b/py/src/braintrust/span_identifier_v4.py @@ -5,7 +5,7 @@ import dataclasses import json from enum import Enum -from typing import Any, cast +from typing import Any from .span_identifier_v3 import ( SpanComponentsV3, @@ -124,7 +124,7 @@ def add_hex_field(orig_val, field_id): hex_bytes, is_hex = None, False if is_hex: - hex_entries.append(bytes([field_id.value]) + cast(bytes, hex_bytes)) + hex_entries.append(bytes([field_id.value]) + hex_bytes) else: json_obj[_FIELDS_ID_TO_NAME[field_id]] = orig_val @@ -231,7 +231,7 @@ def _from_json_obj(json_obj: dict) -> "SpanComponentsV4": **json_obj, "object_type": SpanObjectTypeV3(json_obj["object_type"]), } - return SpanComponentsV4(**cast(Any, kwargs)) + return SpanComponentsV4(**kwargs) def parse_parent(parent: str | dict | None) -> str | None: @@ -246,7 +246,7 @@ def parse_parent(parent: str | dict | None) -> str | None: "project_logs": SpanObjectTypeV3.PROJECT_LOGS, } - object_type = object_type_map.get(cast(str, parent.get("object_type"))) + object_type = object_type_map.get(parent.get("object_type")) if not object_type: raise ValueError(f"Invalid object_type: {parent.get('object_type')}") @@ -276,6 +276,6 @@ def parse_parent(parent: str | dict | None) -> str | None: if "propagated_event" in parent: kwargs["propagated_event"] = parent.get("propagated_event") - return SpanComponentsV4(**cast(Any, kwargs)).to_str() + return SpanComponentsV4(**kwargs).to_str() else: return None diff --git a/py/src/braintrust/test_bt_json.py b/py/src/braintrust/test_bt_json.py index b867caca..bfc9fe54 100644 --- a/py/src/braintrust/test_bt_json.py +++ b/py/src/braintrust/test_bt_json.py @@ -3,7 +3,7 @@ # pyright: reportPrivateUsage=false import json import warnings -from typing import Any, cast +from typing import Any from unittest import TestCase import pytest @@ -47,7 +47,7 @@ def test_deep_copy_mutation_independence(self): self.assertEqual(original["nested_dict"]["inner"], "data") self.assertEqual(original["nested_dict"]["deep"]["level"], 3) self.assertEqual(original["nested_list"][0], 1) - self.assertEqual(cast(list, original["nested_list"])[2][0], 3) + self.assertEqual(original["nested_list"][2][0], 3) self.assertEqual(original["nested_in_list"][0]["key"], "val") # Add new keys to copy @@ -381,7 +381,7 @@ def test_to_bt_safe_attachments(self): "filename": "readonly.txt", "content_type": "text/plain", } - readonly = ReadonlyAttachment(cast(Any, reference)) + readonly = ReadonlyAttachment(reference) result_readonly = _to_bt_safe(readonly) self.assertEqual(result_readonly, reference) self.assertIsNot(result_readonly, readonly) @@ -503,7 +503,7 @@ def test_bt_safe_deep_copy_mixed_attachment_types(self): "filename": "readonly.txt", "content_type": "text/plain", } - readonly_attachment = ReadonlyAttachment(cast(Any, reference)) + readonly_attachment = ReadonlyAttachment(reference) original = { "base": base_attachment, diff --git a/py/src/braintrust/test_logger.py b/py/src/braintrust/test_logger.py index 226f6195..68464bcf 100644 --- a/py/src/braintrust/test_logger.py +++ b/py/src/braintrust/test_logger.py @@ -5,7 +5,7 @@ import logging import os import time -from typing import Any, AsyncGenerator, List, cast +from typing import Any, AsyncGenerator, List from unittest import TestCase from unittest.mock import MagicMock, patch @@ -1156,7 +1156,7 @@ def test_span_link_with_resolved_experiment(with_simulate_login, with_memory_log assert eid == "test-experiment-id" span = experiment.start_span(name="test-span") - cast(Any, span).parent_object_id = id_lazy_value + span.parent_object_id = id_lazy_value span.end() link = span.link() @@ -1197,7 +1197,7 @@ def test_experiment_span_link_uses_env_vars_when_logged_out(with_memory_logger): # Create span with resolved experiment ID span = experiment.start_span(name="test-span") - cast(Any, span).parent_object_id = LazyValue(lambda: "test-exp-id", use_mutex=False) + span.parent_object_id = LazyValue(lambda: "test-exp-id", use_mutex=False) span.end() link = span.link() @@ -3623,7 +3623,7 @@ def test_span_exit_logs_exception_group_sub_exceptions(with_memory_logger): init_test_logger(__name__) with pytest.raises(exceptiongroup.ExceptionGroup): - with cast(Any, braintrust.current_logger()).start_span(name="eg-span"): + with braintrust.current_logger().start_span(name="eg-span"): raise _raise_test_exception_group() logs = with_memory_logger.pop() diff --git a/py/src/braintrust/wrappers/litellm.py b/py/src/braintrust/wrappers/litellm.py index 12d0e25c..236df998 100644 --- a/py/src/braintrust/wrappers/litellm.py +++ b/py/src/braintrust/wrappers/litellm.py @@ -3,7 +3,7 @@ import time from collections.abc import AsyncGenerator, Callable, Generator from types import TracebackType -from typing import Any, cast +from typing import Any from braintrust.logger import Span, start_span from braintrust.span_types import SpanTypeAttribute @@ -145,7 +145,7 @@ def completion(self, *args: Any, **kwargs: Any) -> Any: try: start = time.time() - completion_response = cast(Any, self.completion_fn)(*args, **kwargs) + completion_response = self.completion_fn(*args, **kwargs) # if hasattr(completion_response, "parse"): # raw_response = completion_response.parse() # log_headers(completion_response, span) @@ -175,7 +175,7 @@ async def acompletion(self, *args: Any, **kwargs: Any) -> Any: try: start = time.time() - completion_response = await cast(Any, self.acompletion_fn)(*args, **kwargs) + completion_response = await self.acompletion_fn(*args, **kwargs) # if hasattr(completion_response, "parse"): # raw_response = completion_response.parse() @@ -323,7 +323,7 @@ def responses(self, *args: Any, **kwargs: Any) -> Any: try: start = time.time() - response = cast(Any, self.responses_fn)(*args, **kwargs) + response = self.responses_fn(*args, **kwargs) if is_streaming: should_end = False @@ -346,7 +346,7 @@ async def aresponses(self, *args: Any, **kwargs: Any) -> Any: try: start = time.time() - response = await cast(Any, self.aresponses_fn)(*args, **kwargs) + response = await self.aresponses_fn(*args, **kwargs) if is_streaming: should_end = False diff --git a/py/src/braintrust/wrappers/test_openai.py b/py/src/braintrust/wrappers/test_openai.py index bd729936..7b4b4a25 100644 --- a/py/src/braintrust/wrappers/test_openai.py +++ b/py/src/braintrust/wrappers/test_openai.py @@ -1,6 +1,6 @@ import asyncio import time -from typing import Any, cast +from typing import Any import braintrust import openai @@ -48,11 +48,11 @@ def export(self): with braintrust.start_span(name="parent-span") as parent_span: assert braintrust.current_span() == parent_span - processor.on_trace_start(cast(Any, trace)) + processor.on_trace_start(trace) created_span = processor._spans[trace.trace_id] assert braintrust.current_span() == created_span - processor.on_trace_end(cast(Any, trace)) + processor.on_trace_end(trace) assert braintrust.current_span() == parent_span spans = memory_logger.pop() @@ -110,7 +110,7 @@ def test_openai_responses_metrics(memory_logger): assert unwrapped_response assert unwrapped_response.output assert len(unwrapped_response.output) > 0 - unwrapped_content = cast(Any, unwrapped_response.output[0]).content[0].text + unwrapped_content = unwrapped_response.output[0].content[0].text # No spans should be generated with unwrapped client assert not memory_logger.pop() @@ -129,7 +129,7 @@ def test_openai_responses_metrics(memory_logger): # Extract content from output field assert response.output assert len(response.output) > 0 - wrapped_content = cast(Any, response.output[0]).content[0].text + wrapped_content = response.output[0].content[0].text # Both should contain a numeric response for the math question assert "24" in unwrapped_content or "twenty-four" in unwrapped_content.lower() @@ -480,7 +480,7 @@ def test_openai_chat_with_system_prompt(memory_logger): assert response assert response.choices - assert "24" in cast(str, response.choices[0].message.content) + assert "24" in response.choices[0].message.content if not is_wrapped: assert not memory_logger.pop() @@ -621,7 +621,7 @@ async def test_openai_responses_async(memory_logger): assert len(resp.output) > 0 # Extract the text from the output - content = cast(Any, resp.output[0]).content[0].text + content = resp.output[0].content[0].text # Verify response contains correct answer assert "24" in content or "twenty-four" in content.lower() @@ -794,7 +794,7 @@ async def test_openai_chat_async_with_system_prompt(memory_logger): assert response assert response.choices - assert "24" in cast(str, response.choices[0].message.content) + assert "24" in response.choices[0].message.content if not is_wrapped: assert not memory_logger.pop() @@ -1001,7 +1001,7 @@ async def test_openai_response_streaming_async(memory_logger): stream = await client.responses.create(model=TEST_MODEL, input="What's 12 + 12?", stream=True) chunks = [] - async for chunk in cast(Any, stream): + async for chunk in stream: if chunk.type == "response.output_text.delta": chunks.append(chunk.delta) end = time.time() @@ -1134,7 +1134,7 @@ def test_openai_responses_not_given_filtering(memory_logger): assert response assert response.output assert len(response.output) > 0 - content = cast(Any, response.output[0]).content[0].text + content = response.output[0].content[0].text assert "24" in content or "twenty-four" in content.lower() # Check the logged span @@ -1226,7 +1226,7 @@ def test_openai_responses_with_raw_response_create(memory_logger): instructions="Just the number please", ) assert raw.headers # HTTP response headers are accessible - response = cast(Any, raw.parse()) + response = raw.parse() assert response.output content = response.output[0].content[0].text assert "24" in content or "twenty-four" in content.lower() @@ -1246,7 +1246,7 @@ def test_openai_responses_with_raw_response_create(memory_logger): assert raw.headers response = raw.parse() assert response.output - content = cast(Any, response.output[0]).content[0].text + content = response.output[0].content[0].text assert "24" in content or "twenty-four" in content.lower() # A span must have been recorded with correct metrics and metadata. @@ -1277,7 +1277,7 @@ def test_openai_responses_with_raw_response_create_stream(memory_logger): ) assert raw.headers chunks = [] - for chunk in cast(Any, raw.parse()): + for chunk in raw.parse(): if chunk.type == "response.output_text.delta": chunks.append(chunk.delta) assert "24" in "".join(chunks) or "twenty-four" in "".join(chunks).lower() @@ -1365,7 +1365,7 @@ async def test_openai_responses_with_raw_response_async(memory_logger): instructions="Just the number please", ) assert raw.headers - response = cast(Any, raw.parse()) + response = raw.parse() assert response.output content = response.output[0].content[0].text assert "24" in content or "twenty-four" in content.lower() @@ -1383,7 +1383,7 @@ async def test_openai_responses_with_raw_response_async(memory_logger): assert raw.headers response = raw.parse() assert response.output - content = cast(Any, response.output[0]).content[0].text + content = response.output[0].content[0].text assert "24" in content or "twenty-four" in content.lower() spans = memory_logger.pop() @@ -1413,7 +1413,7 @@ async def test_openai_responses_with_raw_response_create_stream_async(memory_log ) assert raw.headers chunks = [] - async for chunk in cast(Any, raw.parse()): + async for chunk in raw.parse(): if chunk.type == "response.output_text.delta": chunks.append(chunk.delta) assert "24" in "".join(chunks) or "twenty-four" in "".join(chunks).lower() @@ -1488,7 +1488,7 @@ def test_openai_parallel_tool_calls(memory_logger): for client in clients: start = time.time() - resp = cast(Any, client).chat.completions.create( + resp = client.chat.completions.create( model=TEST_MODEL, messages=[{"role": "user", "content": "What's the weather in New York and the time in Tokyo?"}], tools=tools, @@ -1913,8 +1913,8 @@ def export(self): trace = MockTrace("test-trace", "Test Trace", {"conversation_id": "test-12345"}) # Execute trace lifecycle - processor.on_trace_start(cast(Any, trace)) - processor.on_trace_end(cast(Any, trace)) + processor.on_trace_start(trace) + processor.on_trace_end(trace) # Verify metadata was logged to root span spans = memory_logger.pop()