diff --git a/js/llm.ts b/js/llm.ts
index 4682a4e..0c65c77 100644
--- a/js/llm.ts
+++ b/js/llm.ts
@@ -51,6 +51,16 @@ export function templateUsesThreadVariables(template: string): boolean {
   return THREAD_VARIABLE_PATTERN.test(template);
 }
 
+function filterSystemMessagesFromThread(thread: unknown[]): unknown[] {
+  return thread.filter((message) => {
+    if (!message || typeof message !== "object" || Array.isArray(message)) {
+      return true;
+    }
+    const role = Reflect.get(message, "role");
+    return role !== "system";
+  });
+}
+
 const NO_COT_SUFFIX =
   "Answer the question by calling `select_choice` with a single choice from {{__choices}}.";
 
@@ -311,7 +321,8 @@ export function LLMClassifierFromTemplate({
     let threadVars: Record<string, unknown> = {};
     if (runtimeArgs.trace && templateUsesThreadVariables(promptTemplate)) {
       const thread = await runtimeArgs.trace.getThread();
-      const computed = computeThreadTemplateVars(thread);
+      const scorerThread = filterSystemMessagesFromThread(thread);
+      const computed = computeThreadTemplateVars(scorerThread);
       // Build threadVars from THREAD_VARIABLE_NAMES to keep in sync with the pattern
       for (const name of THREAD_VARIABLE_NAMES) {
         threadVars[name] = computed[name as keyof ThreadTemplateVars];
diff --git a/py/autoevals/llm.py b/py/autoevals/llm.py
index e3a6482..c69d273 100644
--- a/py/autoevals/llm.py
+++ b/py/autoevals/llm.py
@@ -64,6 +64,7 @@ from .thread_utils import (
     THREAD_VARIABLE_NAMES,
     compute_thread_template_vars,
+    filter_system_messages_from_thread,
     template_uses_thread_variables,
 )
 
 
@@ -427,7 +428,7 @@ def _compute_thread_vars_sync(self, trace) -> dict[str, object]:
         if not isinstance(thread, list):
             thread = list(thread)
 
-        computed = compute_thread_template_vars(thread)
+        computed = compute_thread_template_vars(filter_system_messages_from_thread(thread))
         return {name: computed[name] for name in self._thread_variable_names}
 
     async def _compute_thread_vars_async(self, trace) -> dict[str, object]:
@@ -443,7 +444,7 @@ async def _compute_thread_vars_async(self, trace) -> dict[str, object]:
         if not isinstance(thread, list):
             thread = list(thread)
 
-        computed = compute_thread_template_vars(thread)
+        computed = compute_thread_template_vars(filter_system_messages_from_thread(thread))
        return {name: computed[name] for name in self._thread_variable_names}
 
     def _request_args(self, output, expected, **kwargs):
diff --git a/py/autoevals/thread_utils.py b/py/autoevals/thread_utils.py
index e816348..bdfb13c 100644
--- a/py/autoevals/thread_utils.py
+++ b/py/autoevals/thread_utils.py
@@ -38,6 +38,10 @@ def is_llm_message_array(value: Any) -> bool:
     return isinstance(value, list) and all(is_role_content_message(item) for item in value)
 
 
+def filter_system_messages_from_thread(thread: list[Any]) -> list[Any]:
+    return [message for message in thread if not (isinstance(message, Mapping) and message.get("role") == "system")]
+
+
 def _indent(text: str, prefix: str = " ") -> str:
     return "\n".join(prefix + line for line in text.split("\n"))