1 change: 1 addition & 0 deletions py/setup.py
@@ -12,6 +12,7 @@
long_description = f.read()

install_requires = [
"aiohttp",
"GitPython",
"requests",
"chevron",
233 changes: 232 additions & 1 deletion py/src/braintrust/logger.py
@@ -1,3 +1,4 @@
import asyncio
import atexit
import base64
import concurrent.futures
@@ -34,7 +35,9 @@
cast,
overload,
)
from urllib.error import HTTPError, URLError
from urllib.parse import urlencode
from urllib.request import Request, urlopen

import chevron
import exceptiongroup
@@ -781,6 +784,77 @@ def post_json(self, object_type: str, args: Mapping[str, Any] | None = None) ->
response_raise_for_status(resp)
return resp.json()

async def aget_json(
self, object_type: str, args: Optional[Mapping[str, Any]] = None, retries: int = 0
) -> Mapping[str, Any] | None:
"""
Async version of get_json. Makes an async HTTP GET request and returns the parsed JSON response, using aiohttp when it is installed and falling back to urllib in a thread pool otherwise.
"""
tries = retries + 1

for i in range(tries):
try:
# Build the URL using the same logic as the sync version
url = _urljoin(self.base_url, f"/{object_type}")
if args:
url += "?" + urlencode(_strip_nones(args))

# Check whether aiohttp is available; otherwise fall back to the asyncio + urllib approach
from importlib.util import find_spec
if find_spec("aiohttp") is None:
# Fall back to asyncio + urllib approach
return await self._make_asyncio_request(url)
return await self._make_aiohttp_request(url)

except Exception as e:
if i < tries - 1:
_logger.warning(f"Retrying async API request {object_type} {args}: {e}")
await asyncio.sleep(0.1 * (i + 1)) # Progressive backoff
continue
raise

# Needed for type checking.
raise Exception("unreachable")

async def _make_aiohttp_request(self, url: str) -> Mapping[str, Any]:
"""Make async HTTP request using aiohttp"""
import aiohttp

headers = {}
if self.token:
headers["Authorization"] = f"Bearer {self.token}"

async with aiohttp.ClientSession() as session:
async with session.get(url, headers=headers) as response:
if response.status >= 400:
error_text = await response.text()
raise Exception(f"HTTP {response.status}: {error_text}")
return await response.json()

async def _make_asyncio_request(self, url: str) -> Mapping[str, Any]:
"""Make async HTTP request using asyncio and urllib (fallback)"""
loop = asyncio.get_running_loop()
timeout_secs = parse_env_var_float("BRAINTRUST_HTTP_TIMEOUT", 60.0)

def sync_request():
request = Request(url)
if self.token:
request.add_header("Authorization", f"Bearer {self.token}")

try:
response_obj = urlopen(request, timeout=timeout_secs)
response_data = response_obj.read()
return json.loads(response_data.decode("utf-8"))
except HTTPError as e:
if e.code >= 400:
error_body = e.read().decode("utf-8") if hasattr(e, "read") else str(e)
raise Exception(f"HTTP {e.code}: {error_body}")
raise
except URLError as e:
raise Exception(f"URL Error: {e}")

return await loop.run_in_executor(HTTP_REQUEST_THREAD_POOL, sync_request)
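# Usage sketch (editorial note, not part of this diff): aget_json is awaited from a
# running event loop and picks aiohttp or the urllib-in-executor fallback automatically.
# The endpoint and argument names below are illustrative assumptions, mirroring how the
# synchronous get_json is called elsewhere in this module.
#
#     async def fetch_prompts(conn):
#         # One retry on failure; args are stripped of None values and URL-encoded by aget_json itself.
#         return await conn.aget_json("/v1/prompt", {"project_name": "my-project"}, retries=1)
#
#     # asyncio.run(fetch_prompts(_state.api_conn()))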


# Sometimes we'd like to launch network requests concurrently. We provide a
# thread pool to accomplish this. Use a multiple of number of CPU cores to limit
@@ -2001,6 +2075,163 @@ def compute_metadata():
)


async def aload_prompt(
project: Optional[str] = None,
slug: Optional[str] = None,
version: Optional[Union[str, int]] = None,
project_id: Optional[str] = None,
prompt_id: str | None = None,
defaults: Optional[Mapping[str, Any]] = None,
no_trace: bool = False,
environment: str | None = None,
app_url: Optional[str] = None,
api_key: Optional[str] = None,
org_name: Optional[str] = None,
) -> "Prompt":
"""
Async version of load_prompt. Loads a prompt from the specified project.

:param project: The name of the project to load the prompt from. Must specify at least one of `project` or `project_id`.
:param slug: The slug of the prompt to load.
:param version: An optional version of the prompt (to read). If not specified, the latest version will be used.
:param project_id: The id of the project to load the prompt from. This takes precedence over `project` if specified.
:param prompt_id: The id of a specific prompt to load. If specified, this takes precedence over all other parameters (project, slug, version).
:param defaults: (Optional) A dictionary of default values to use when rendering the prompt. Prompt values will override these defaults.
:param no_trace: If true, do not include logging metadata for this prompt when build() is called.
:param environment: The environment to load the prompt from. Cannot be used together with version.
:param app_url: The URL of the Braintrust App. Defaults to https://www.braintrust.dev.
:param api_key: The API key to use. If the parameter is not specified, will try to use the `BRAINTRUST_API_KEY` environment variable. If no API
key is specified, will prompt the user to log in.
:param org_name: (Optional) The name of a specific organization to connect to. This is useful if you belong to multiple.
:returns: The prompt object.
"""

if version is not None and environment is not None:
raise ValueError(
"Cannot specify both 'version' and 'environment' parameters. Please use only one (remove the other)."
)

if prompt_id:
pass
elif not project and not project_id:
raise ValueError("Must specify at least one of project or project_id")
elif not slug:
raise ValueError("Must specify slug")

loop = asyncio.get_running_loop()
response = None

try:
# Run login in thread pool since it's synchronous
await loop.run_in_executor(HTTP_REQUEST_THREAD_POOL, login, app_url, api_key, org_name)
if prompt_id:
args = _populate_args({
"version": version,
"environment": environment
})

response = await _state.api_conn().aget_json(f"/v1/prompt/{prompt_id}", args)

if response:
response = {"objects": [response]}

else:
args = _populate_args(
{
"project_name": project,
"project_id": project_id,
"slug": slug,
"version": version,
"environment": environment
},
)

response = await _state.api_conn().aget_json("/v1/prompt", args)

except Exception as server_error:
# If environment was specified, don't fall back to cache
if environment is not None:
raise ValueError(f"Prompt not found for specified environment {environment}") from server_error

eprint(f"Failed to load prompt, attempting to fall back to cache: {server_error}")
try:
if prompt_id:
cache_result = await loop.run_in_executor(
HTTP_REQUEST_THREAD_POOL,
lambda: _state._prompt_cache.get(
id=prompt_id
),
)
else:
cache_result = await loop.run_in_executor(
HTTP_REQUEST_THREAD_POOL,
lambda: _state._prompt_cache.get(
slug,
version=str(version) if version else "latest",
project_id=project_id,
project_name=project,
),
)
# Return Prompt with pre-computed metadata from cache
return Prompt(
lazy_metadata=LazyValue(lambda: cache_result, use_mutex=True),
defaults=defaults or {},
no_trace=no_trace,
)
except Exception as cache_error:
if prompt_id:
raise ValueError(
f"Prompt with id {prompt_id} not found (not found on server or in local cache): {cache_error}"
) from server_error
raise ValueError(
f"Prompt {slug} (version {version or 'latest'}) not found in {project or project_id} (not found on server or in local cache): {cache_error}"
) from server_error

if response is None or "objects" not in response or len(response["objects"]) == 0:
if prompt_id:
raise ValueError(f"Prompt with id {prompt_id} not found.")

raise ValueError(f"Prompt {slug} not found in project {project or project_id}.")
elif len(response["objects"]) > 1:
if prompt_id:
raise ValueError(f"Multiple prompts found with id {prompt_id}. This should never happen.")

raise ValueError(
f"Multiple prompts found with slug {slug} in project {project or project_id}. This should never happen."
)

resp_prompt = response["objects"][0]
prompt_metadata = PromptSchema.from_dict_deep(resp_prompt)
try:
# Save the prompt to the cache
if prompt_id:
await loop.run_in_executor(
HTTP_REQUEST_THREAD_POOL,
lambda: _state._prompt_cache.set(
prompt_metadata,
id=prompt_id
),
)
else:
await loop.run_in_executor(
HTTP_REQUEST_THREAD_POOL,
lambda: _state._prompt_cache.set(
prompt_metadata,
slug=slug,
version=str(version) if version else "latest",
project_id=project_id,
project_name=project,
),
)
except Exception as e:
eprint(f"Failed to store prompt in cache: {e}")

# Return Prompt with pre-computed metadata
return Prompt(
lazy_metadata=LazyValue(lambda: prompt_metadata, use_mutex=True), defaults=defaults or {}, no_trace=no_trace
)
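# Usage sketch (editorial note, not part of this diff): aload_prompt is the awaited
# counterpart of load_prompt. The project name, slug, and template variable below are
# illustrative assumptions; build() renders the prompt just as it does on the synchronous path.
#
#     async def main():
#         prompt = await aload_prompt(project="my-project", slug="summarize-v1")
#         params = prompt.build(question="What changed in this release?")
#         print(params)
#
#     # asyncio.run(main())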


login_lock = threading.RLock()


@@ -4398,7 +4629,7 @@ def stringify_exception(exc_type: type[BaseException], exc_value: BaseException,
return "".join(traceback.format_exception(exc_type, exc_value, tb))


def _strip_nones(d: T, deep: bool) -> T:
def _strip_nones(d: T, deep: bool = False) -> T:
if not isinstance(d, dict):
return d
return {k: (_strip_nones(v, deep) if deep else v) for (k, v) in d.items() if v is not None} # type: ignore
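# Editorial note: giving `deep` a default of False lets callers such as aget_json invoke
# _strip_nones(args) with a single argument. A minimal illustration of the shallow behavior,
# assuming a plain dict input:
#
#     _strip_nones({"version": None, "slug": "summarize-v1"})
#     # -> {"slug": "summarize-v1"}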