diff --git a/memsync/sync.py b/memsync/sync.py index a52847f..78d5d9e 100644 --- a/memsync/sync.py +++ b/memsync/sync.py @@ -65,14 +65,18 @@ def harvest_memory_content(transcript: str, current_memory: str, config: Config) SESSION TRANSCRIPT: {transcript}""" + prefill = _build_prefill(current_memory) response = client.messages.create( model=config.model, max_tokens=4096, system=HARVEST_SYSTEM_PROMPT, - messages=[{"role": "user", "content": user_prompt}], + messages=[ + {"role": "user", "content": user_prompt}, + {"role": "assistant", "content": prefill}, + ], ) - updated_content = _strip_model_wrapper(response.content[0].text) + updated_content = _strip_model_wrapper(prefill + response.content[0].text) if not _looks_like_memory_file(updated_content): return { @@ -98,6 +102,20 @@ def harvest_memory_content(transcript: str, current_memory: str, config: Config) } +def _build_prefill(current_memory: str) -> str: + """ + Build an assistant prefill string that forces the model to start outputting + the memory file rather than a narrative summary. + + Uses the first line of the current memory if it looks like a valid start + (heading or comment marker), otherwise falls back to the memsync comment. + """ + first_line = current_memory.strip().splitlines()[0] if current_memory.strip() else "" + if first_line.startswith("#") or first_line.startswith("" + + def _strip_model_wrapper(content: str) -> str: """ Strip wrapper artifacts the model sometimes adds around the memory file: @@ -154,14 +172,18 @@ def refresh_memory_content(notes: str, current_memory: str, config: Config) -> d SESSION NOTES: {notes}""" + prefill = _build_prefill(current_memory) response = client.messages.create( model=config.model, max_tokens=4096, system=SYSTEM_PROMPT, - messages=[{"role": "user", "content": user_prompt}], + messages=[ + {"role": "user", "content": user_prompt}, + {"role": "assistant", "content": prefill}, + ], ) - updated_content = _strip_model_wrapper(response.content[0].text) + updated_content = _strip_model_wrapper(prefill + response.content[0].text) # Reject responses that look like narrative explanations rather than a memory file. # The model occasionally ignores "no preamble" and returns prose — writing that diff --git a/tests/test_harvest.py b/tests/test_harvest.py index e1d345d..10fb2ba 100644 --- a/tests/test_harvest.py +++ b/tests/test_harvest.py @@ -308,7 +308,19 @@ def test_ignores_invalid_dict_entries(self, tmp_path): # --------------------------------------------------------------------------- class TestHarvestMemoryContent: - def _make_mock_response(self, text: str, stop_reason: str = "end_turn") -> MagicMock: + def _make_mock_response(self, text: str, stop_reason: str = "end_turn", + current_memory: str = SAMPLE_MEMORY) -> MagicMock: + """ + Simulate the API returning a continuation after the prefill. + + With assistant prefill, the API only returns the text *after* the prefill. + The code then combines: prefill + response.content[0].text. + So the mock must strip the prefill line from the expected output. + """ + from memsync.sync import _build_prefill + prefill = _build_prefill(current_memory) + if text.startswith(prefill): + text = text[len(prefill):] mock = MagicMock() mock.content = [MagicMock(text=text)] mock.stop_reason = stop_reason diff --git a/tests/test_sync.py b/tests/test_sync.py index bf3bfd7..6b9326b 100644 --- a/tests/test_sync.py +++ b/tests/test_sync.py @@ -136,7 +136,19 @@ def test_content_includes_notes(self, tmp_path): class TestRefreshMemoryContent: - def _make_mock_response(self, text: str, stop_reason: str = "end_turn") -> MagicMock: + def _make_mock_response(self, text: str, stop_reason: str = "end_turn", + current_memory: str = SAMPLE_MEMORY) -> MagicMock: + """ + Simulate the API returning a continuation after the prefill. + + With assistant prefill, the API only returns the text *after* the prefill. + The code then combines: prefill + response.content[0].text. + So the mock must strip the prefill line from the expected output. + """ + from memsync.sync import _build_prefill + prefill = _build_prefill(current_memory) + if text.startswith(prefill): + text = text[len(prefill):] mock_response = MagicMock() mock_response.content = [MagicMock(text=text)] mock_response.stop_reason = stop_reason