diff --git a/bluecellulab/cell/core.py b/bluecellulab/cell/core.py index b11389aa..306ac3d1 100644 --- a/bluecellulab/cell/core.py +++ b/bluecellulab/cell/core.py @@ -850,19 +850,22 @@ def add_synapse_replay( if not file_path.exists(): raise FileNotFoundError(f"Spike file not found: {str(file_path)}") - synapse_spikes: dict = get_synapse_replay_spikes(str(file_path)) + + synapse_spikes: dict[CellId, np.ndarray] = get_synapse_replay_spikes(str(file_path)) + for synapse_id, synapse in self.synapses.items(): - source_population = synapse.syn_description["source_population_name"] - pre_gid = CellId( - source_population, int(synapse.syn_description[SynapseProperty.PRE_GID]) + pre_cell_id = CellId( + str(synapse.syn_description["source_population_name"]), + int(synapse.syn_description[SynapseProperty.PRE_GID]), ) - if pre_gid.id in synapse_spikes: - spikes_of_interest = synapse_spikes[pre_gid.id] - # filter spikes of interest >=stimulus.delay, <=stimulus.duration + + if pre_cell_id in synapse_spikes: + spikes_of_interest = synapse_spikes[pre_cell_id] spikes_of_interest = spikes_of_interest[ (spikes_of_interest >= stimulus.delay) & (spikes_of_interest <= stimulus.duration) ] + connection = bluecellulab.Connection( synapse, pre_spiketrain=spikes_of_interest, @@ -872,7 +875,7 @@ def add_synapse_replay( spike_location=spike_location, ) logger.debug( - f"Added synapse replay from {pre_gid} to {self.cell_id.id}, {synapse_id}" + f"Added synapse replay from {pre_cell_id} to {self.cell_id.id}, {synapse_id}" ) self.connections[synapse_id] = connection diff --git a/bluecellulab/circuit/simulation_access.py b/bluecellulab/circuit/simulation_access.py index bf0eeec1..099b6110 100644 --- a/bluecellulab/circuit/simulation_access.py +++ b/bluecellulab/circuit/simulation_access.py @@ -171,17 +171,17 @@ def get_spikes(self) -> dict[CellId, np.ndarray]: return outdat.to_dict() -def get_synapse_replay_spikes(f_name: str) -> dict: +def get_synapse_replay_spikes(f_name: str) -> dict[CellId, np.ndarray]: """Read the .h5 file containing the spike replays. Args: f_name: Path to SONATA .h5 spike file. Returns: - Dictionary mapping node_id to np.array of spike times. + Dictionary mapping CellId(population, node_id) to np.array of spike times. """ all_spikes = [] - with h5py.File(f_name, 'r') as f: + with h5py.File(f_name, "r") as f: if "spikes" not in f: raise ValueError("spike file is missing required 'spikes' group.") @@ -190,7 +190,13 @@ def get_synapse_replay_spikes(f_name: str) -> dict: timestamps = pop_group["timestamps"][:] node_ids = pop_group["node_ids"][:] - pop_spikes = pd.DataFrame({"t": timestamps, "node_id": node_ids}) + pop_spikes = pd.DataFrame( + { + "t": timestamps, + "population": str(population), + "node_id": node_ids, + } + ) pop_spikes = pop_spikes.astype({"node_id": int}) all_spikes.append(pop_spikes) @@ -201,9 +207,12 @@ def get_synapse_replay_spikes(f_name: str) -> dict: if (spikes["t"] < 0).any(): logger.warning("Found negative spike times... Clipping them to 0") - spikes["t"].clip(lower=0., inplace=True) - - # Group spikes by node_id and ensure spike times are sorted in ascending order. - # This is critical because NEURON's VecStim requires monotonically increasing times per train. - grouped = spikes.groupby("node_id")["t"] - return {k: np.sort(np.asarray(v.values)) for k, v in grouped} + spikes["t"] = spikes["t"].clip(lower=0.0) + + # Group spikes by CellId (population, node_id) and sort each spike train, + # since NEURON VecStim requires monotonically increasing times. + grouped = spikes.groupby(["population", "node_id"])["t"] + return { + CellId(str(population), int(node_id)): np.sort(np.asarray(times.values)) + for (population, node_id), times in grouped + } diff --git a/tests/examples/sonata_unit_test_sims/synapse_replay/synapse_replay.h5 b/tests/examples/sonata_unit_test_sims/synapse_replay/synapse_replay.h5 index aeb3a7b8..567d11fb 100644 Binary files a/tests/examples/sonata_unit_test_sims/synapse_replay/synapse_replay.h5 and b/tests/examples/sonata_unit_test_sims/synapse_replay/synapse_replay.h5 differ diff --git a/tests/test_cell/test_core.py b/tests/test_cell/test_core.py index 13f913e9..c190d4d9 100644 --- a/tests/test_cell/test_core.py +++ b/tests/test_cell/test_core.py @@ -20,6 +20,8 @@ from unittest.mock import MagicMock, patch import uuid +from bluecellulab.circuit.node_id import CellId +from bluecellulab.circuit.synapse_properties import SynapseProperty import neuron import numpy as np import pytest @@ -923,3 +925,63 @@ def test_add_currents_recordings_with_point_process(self): called = [c.args[0] for c in mock_add.call_args_list] assert "ina" in called assert "i_ExpSyn" in called + + +@pytest.mark.v6 +def test_add_synapse_replay_matches_population_and_node_id(monkeypatch, tmp_path): + """Cell: test add_synapse_replay matches using population and node_id.""" + import h5py + + spike_file = tmp_path / "spikes.h5" + with h5py.File(spike_file, "w") as f: + spikes = f.create_group("spikes") + + vpm = spikes.create_group("VPM") + vpm.create_dataset("timestamps", data=np.array([10.0, 20.0])) + vpm.create_dataset("node_ids", data=np.array([14, 14])) + + class DummyStimulus: + def __init__(self, spike_file): + self.spike_file = str(spike_file) + self.config_dir = None + self.delay = 0.0 + self.duration = 100.0 + + class DummySynapse: + def __init__(self, source_population, pre_gid): + self.syn_description = { + "source_population_name": source_population, + SynapseProperty.PRE_GID: pre_gid, + } + + class DummyConnection: + def __init__( + self, + synapse, + pre_spiketrain, + pre_cell, + stim_dt, + spike_threshold, + spike_location, + ): + self.synapse = synapse + self.pre_spiketrain = pre_spiketrain + + monkeypatch.setattr(bluecellulab, "Connection", DummyConnection) + + cell = Cell.__new__(Cell) + cell.sonata_proxy = object() + cell.cell_id = CellId("S1nonbarrel_neurons", 3) + cell.record_dt = 0.1 + cell.connections = {} + cell.synapses = { + ("vpm", 0): DummySynapse("VPM", 14), + ("s1", 0): DummySynapse("S1nonbarrel_neurons", 14), + } + + stimulus = DummyStimulus(spike_file) + Cell.add_synapse_replay(cell, stimulus, -20.0, "soma") + + assert set(cell.connections.keys()) == {("vpm", 0)} + conn = cell.connections[("vpm", 0)] + assert np.allclose(conn.pre_spiketrain, [10.0, 20.0]) diff --git a/tests/test_circuit/test_simulation_access.py b/tests/test_circuit/test_simulation_access.py index f491c855..8263dec5 100644 --- a/tests/test_circuit/test_simulation_access.py +++ b/tests/test_circuit/test_simulation_access.py @@ -96,9 +96,36 @@ def test_get_spikes(self): def test_get_synapse_replay_spikes(): - """.Test get_synapse_replay_spikes.""" + """Test get_synapse_replay_spikes.""" res = get_synapse_replay_spikes( parent_dir / "data" / "synapse_replay_file" / "spikes.h5" ) - assert set(res.keys()) == {5382} - assert res[5382].tolist() == [1500.0, 2000.0, 2500.0] + key = CellId("All", 5382) + assert set(res.keys()) == {key} + assert res[key].tolist() == [1500.0, 2000.0, 2500.0] + + +def test_get_synapse_replay_spikes_keeps_population(tmp_path): + """Test get_synapse_replay_spikes keeps population in the key.""" + import h5py + + spike_file = tmp_path / "spikes.h5" + with h5py.File(spike_file, "w") as f: + spikes = f.create_group("spikes") + + s1 = spikes.create_group("S1nonbarrel_neurons") + s1.create_dataset("timestamps", data=np.array([1.0, 2.0])) + s1.create_dataset("node_ids", data=np.array([1, 1])) + + vpm = spikes.create_group("VPM") + vpm.create_dataset("timestamps", data=np.array([3.0])) + vpm.create_dataset("node_ids", data=np.array([1])) + + res = get_synapse_replay_spikes(spike_file) + + assert set(res.keys()) == { + CellId("S1nonbarrel_neurons", 1), + CellId("VPM", 1), + } + assert res[CellId("S1nonbarrel_neurons", 1)].tolist() == [1.0, 2.0] + assert res[CellId("VPM", 1)].tolist() == [3.0]