Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 11 additions & 8 deletions bluecellulab/cell/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -850,19 +850,22 @@ def add_synapse_replay(

if not file_path.exists():
raise FileNotFoundError(f"Spike file not found: {str(file_path)}")
synapse_spikes: dict = get_synapse_replay_spikes(str(file_path))

synapse_spikes: dict[CellId, np.ndarray] = get_synapse_replay_spikes(str(file_path))

for synapse_id, synapse in self.synapses.items():
source_population = synapse.syn_description["source_population_name"]
pre_gid = CellId(
source_population, int(synapse.syn_description[SynapseProperty.PRE_GID])
pre_cell_id = CellId(
str(synapse.syn_description["source_population_name"]),
int(synapse.syn_description[SynapseProperty.PRE_GID]),
)
if pre_gid.id in synapse_spikes:
spikes_of_interest = synapse_spikes[pre_gid.id]
# filter spikes of interest >=stimulus.delay, <=stimulus.duration

if pre_cell_id in synapse_spikes:
spikes_of_interest = synapse_spikes[pre_cell_id]
spikes_of_interest = spikes_of_interest[
(spikes_of_interest >= stimulus.delay)
& (spikes_of_interest <= stimulus.duration)
]

connection = bluecellulab.Connection(
synapse,
pre_spiketrain=spikes_of_interest,
Expand All @@ -872,7 +875,7 @@ def add_synapse_replay(
spike_location=spike_location,
)
logger.debug(
f"Added synapse replay from {pre_gid} to {self.cell_id.id}, {synapse_id}"
f"Added synapse replay from {pre_cell_id} to {self.cell_id.id}, {synapse_id}"
)

self.connections[synapse_id] = connection
Expand Down
29 changes: 19 additions & 10 deletions bluecellulab/circuit/simulation_access.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,17 +171,17 @@ def get_spikes(self) -> dict[CellId, np.ndarray]:
return outdat.to_dict()


def get_synapse_replay_spikes(f_name: str) -> dict:
def get_synapse_replay_spikes(f_name: str) -> dict[CellId, np.ndarray]:
"""Read the .h5 file containing the spike replays.

Args:
f_name: Path to SONATA .h5 spike file.

Returns:
Dictionary mapping node_id to np.array of spike times.
Dictionary mapping CellId(population, node_id) to np.array of spike times.
"""
all_spikes = []
with h5py.File(f_name, 'r') as f:
with h5py.File(f_name, "r") as f:
if "spikes" not in f:
raise ValueError("spike file is missing required 'spikes' group.")

Expand All @@ -190,7 +190,13 @@ def get_synapse_replay_spikes(f_name: str) -> dict:
timestamps = pop_group["timestamps"][:]
node_ids = pop_group["node_ids"][:]

pop_spikes = pd.DataFrame({"t": timestamps, "node_id": node_ids})
pop_spikes = pd.DataFrame(
{
"t": timestamps,
"population": str(population),
"node_id": node_ids,
}
)
pop_spikes = pop_spikes.astype({"node_id": int})
all_spikes.append(pop_spikes)

Expand All @@ -201,9 +207,12 @@ def get_synapse_replay_spikes(f_name: str) -> dict:

if (spikes["t"] < 0).any():
logger.warning("Found negative spike times... Clipping them to 0")
spikes["t"].clip(lower=0., inplace=True)

# Group spikes by node_id and ensure spike times are sorted in ascending order.
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You don't want to keep this comment?

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, that's a good idea, I added it back and updated it.

# This is critical because NEURON's VecStim requires monotonically increasing times per train.
grouped = spikes.groupby("node_id")["t"]
return {k: np.sort(np.asarray(v.values)) for k, v in grouped}
spikes["t"] = spikes["t"].clip(lower=0.0)

# Group spikes by CellId (population, node_id) and sort each spike train,
# since NEURON VecStim requires monotonically increasing times.
grouped = spikes.groupby(["population", "node_id"])["t"]
return {
CellId(str(population), int(node_id)): np.sort(np.asarray(times.values))
for (population, node_id), times in grouped
}
Binary file not shown.
62 changes: 62 additions & 0 deletions tests/test_cell/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@
from unittest.mock import MagicMock, patch
import uuid

from bluecellulab.circuit.node_id import CellId
from bluecellulab.circuit.synapse_properties import SynapseProperty
import neuron
import numpy as np
import pytest
Expand Down Expand Up @@ -923,3 +925,63 @@ def test_add_currents_recordings_with_point_process(self):
called = [c.args[0] for c in mock_add.call_args_list]
assert "ina" in called
assert "i_ExpSyn" in called


@pytest.mark.v6
def test_add_synapse_replay_matches_population_and_node_id(monkeypatch, tmp_path):
"""Cell: test add_synapse_replay matches using population and node_id."""
import h5py

spike_file = tmp_path / "spikes.h5"
with h5py.File(spike_file, "w") as f:
spikes = f.create_group("spikes")

vpm = spikes.create_group("VPM")
vpm.create_dataset("timestamps", data=np.array([10.0, 20.0]))
vpm.create_dataset("node_ids", data=np.array([14, 14]))

class DummyStimulus:
def __init__(self, spike_file):
self.spike_file = str(spike_file)
self.config_dir = None
self.delay = 0.0
self.duration = 100.0

class DummySynapse:
def __init__(self, source_population, pre_gid):
self.syn_description = {
"source_population_name": source_population,
SynapseProperty.PRE_GID: pre_gid,
}

class DummyConnection:
def __init__(
self,
synapse,
pre_spiketrain,
pre_cell,
stim_dt,
spike_threshold,
spike_location,
):
self.synapse = synapse
self.pre_spiketrain = pre_spiketrain

monkeypatch.setattr(bluecellulab, "Connection", DummyConnection)

cell = Cell.__new__(Cell)
cell.sonata_proxy = object()
cell.cell_id = CellId("S1nonbarrel_neurons", 3)
cell.record_dt = 0.1
cell.connections = {}
cell.synapses = {
("vpm", 0): DummySynapse("VPM", 14),
("s1", 0): DummySynapse("S1nonbarrel_neurons", 14),
}

stimulus = DummyStimulus(spike_file)
Cell.add_synapse_replay(cell, stimulus, -20.0, "soma")

assert set(cell.connections.keys()) == {("vpm", 0)}
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Quick question to be sure that I understand how this works correctly. Here, we have only vpm and we do not have s1 because the only stimulus we gave to add_synapse_replay comes from an h5 file that only has VPM as the population.
Is that correct?

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes exactly, the replay file only contains VPM (L939) spikes, so only the synapse coming from VPM should get connected. The S1 synapse has the same pre_gid but since its source population is different, it should not match. That’s why we only expect {("vpm", 0)} here.

conn = cell.connections[("vpm", 0)]
assert np.allclose(conn.pre_spiketrain, [10.0, 20.0])
33 changes: 30 additions & 3 deletions tests/test_circuit/test_simulation_access.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,9 +96,36 @@ def test_get_spikes(self):


def test_get_synapse_replay_spikes():
    """Test get_synapse_replay_spikes on the bundled sample spike file.

    The sample file contains a single population ("All") with one spiking
    node (5382), so the result must map exactly one CellId key to that
    node's spike times in ascending order.
    """
    res = get_synapse_replay_spikes(
        parent_dir / "data" / "synapse_replay_file" / "spikes.h5"
    )
    # Keys are CellId(population, node_id), not bare integer node ids.
    key = CellId("All", 5382)
    assert set(res.keys()) == {key}
    assert res[key].tolist() == [1500.0, 2000.0, 2500.0]


def test_get_synapse_replay_spikes_keeps_population(tmp_path):
    """Test get_synapse_replay_spikes keeps population in the key."""
    import h5py

    # Two populations deliberately share node_id 1: only the population
    # part of the CellId key can tell their spike trains apart.
    population_data = {
        "S1nonbarrel_neurons": ([1.0, 2.0], [1, 1]),
        "VPM": ([3.0], [1]),
    }

    spike_file = tmp_path / "spikes.h5"
    with h5py.File(spike_file, "w") as h5f:
        root = h5f.create_group("spikes")
        for pop_name, (times, nodes) in population_data.items():
            grp = root.create_group(pop_name)
            grp.create_dataset("timestamps", data=np.array(times))
            grp.create_dataset("node_ids", data=np.array(nodes))

    res = get_synapse_replay_spikes(spike_file)

    expected_keys = {CellId(pop, 1) for pop in population_data}
    assert set(res.keys()) == expected_keys
    assert res[CellId("S1nonbarrel_neurons", 1)].tolist() == [1.0, 2.0]
    assert res[CellId("VPM", 1)].tolist() == [3.0]