Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 12 additions & 1 deletion js/llm.ts
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,16 @@ export function templateUsesThreadVariables(template: string): boolean {
return THREAD_VARIABLE_PATTERN.test(template);
}

/**
 * Returns a new array containing every entry of `thread` except non-array
 * objects whose `role` property is exactly `"system"`.
 *
 * Entries that are not objects (primitives, `null`, `undefined`) and entries
 * that are arrays are always kept. The input array is not modified.
 *
 * @param thread - Thread entries of unknown shape.
 * @returns A filtered copy of `thread`, in the original order.
 */
function filterSystemMessagesFromThread(thread: unknown[]): unknown[] {
  // True only for non-null, non-array objects carrying role === "system".
  const hasSystemRole = (entry: unknown): boolean =>
    typeof entry === "object" &&
    entry !== null &&
    !Array.isArray(entry) &&
    Reflect.get(entry, "role") === "system";

  return thread.filter((entry) => !hasSystemRole(entry));
}

// Prompt suffix instructing the model to answer by calling `select_choice`
// with exactly one choice.
// NOTE(review): `{{__choices}}` looks like a template placeholder substituted
// at render time, and the name suggests this suffix is used when
// chain-of-thought is disabled — confirm both against the call sites.
const NO_COT_SUFFIX =
"Answer the question by calling `select_choice` with a single choice from {{__choices}}.";

Expand Down Expand Up @@ -311,7 +321,8 @@ export function LLMClassifierFromTemplate<RenderArgs>({
let threadVars: Record<string, unknown> = {};
if (runtimeArgs.trace && templateUsesThreadVariables(promptTemplate)) {
const thread = await runtimeArgs.trace.getThread();
const computed = computeThreadTemplateVars(thread);
const scorerThread = filterSystemMessagesFromThread(thread);
const computed = computeThreadTemplateVars(scorerThread);
// Build threadVars from THREAD_VARIABLE_NAMES to keep in sync with the pattern
for (const name of THREAD_VARIABLE_NAMES) {
threadVars[name] = computed[name as keyof ThreadTemplateVars];
Expand Down
5 changes: 3 additions & 2 deletions py/autoevals/llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@
from .thread_utils import (
THREAD_VARIABLE_NAMES,
compute_thread_template_vars,
filter_system_messages_from_thread,
template_uses_thread_variables,
)

Expand Down Expand Up @@ -427,7 +428,7 @@ def _compute_thread_vars_sync(self, trace) -> dict[str, object]:
if not isinstance(thread, list):
thread = list(thread)

computed = compute_thread_template_vars(thread)
computed = compute_thread_template_vars(filter_system_messages_from_thread(thread))
return {name: computed[name] for name in self._thread_variable_names}

async def _compute_thread_vars_async(self, trace) -> dict[str, object]:
Expand All @@ -443,7 +444,7 @@ async def _compute_thread_vars_async(self, trace) -> dict[str, object]:
if not isinstance(thread, list):
thread = list(thread)

computed = compute_thread_template_vars(thread)
computed = compute_thread_template_vars(filter_system_messages_from_thread(thread))
return {name: computed[name] for name in self._thread_variable_names}

def _request_args(self, output, expected, **kwargs):
Expand Down
4 changes: 4 additions & 0 deletions py/autoevals/thread_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,10 @@ def is_llm_message_array(value: Any) -> bool:
return isinstance(value, list) and all(is_role_content_message(item) for item in value)


def filter_system_messages_from_thread(thread: list[Any]) -> list[Any]:
    """Return a new list with system-role messages removed.

    An entry is dropped only when it is a ``Mapping`` whose ``"role"`` value
    equals ``"system"``. Every other entry, including non-mapping values, is
    kept in its original position. The input list is not modified.

    Args:
        thread: Thread entries of arbitrary shape.

    Returns:
        A new list holding the entries that were not dropped.
    """
    kept = []
    for entry in thread:
        # Only mappings can carry a role; anything else passes through.
        if isinstance(entry, Mapping) and entry.get("role") == "system":
            continue
        kept.append(entry)
    return kept


def _indent(text: str, prefix: str = " ") -> str:
return "\n".join(prefix + line for line in text.split("\n"))

Expand Down