
Fix non-periodic cell construction in get_neighborhood#1395

Draft
LarsSchaaf wants to merge 1 commit into ACEsuit:develop from LarsSchaaf:fix-neighbourhood-no-pbc

Conversation

@LarsSchaaf
Collaborator

@LarsSchaaf LarsSchaaf commented Mar 3, 2026

When pbc=False, the artificial cell for the neighbor list was computed as:

max_positions * 5 * cutoff * identity[i, :]                                                                                                                                     
  • This is multiplicative in system size and cutoff, producing unnecessarily large cells for large molecules/clusters (e.g., a 100 Å system with a 6 Å cutoff gives ~3000 Å cells). This can cause box-skew errors and excessive memory usage (CUDA out-of-memory) in electrostatic models.
  • The hard-coded factor of 5 (tied to the maximum number of MACE model layers) is irrelevant, as the neighbor list operates on a single cutoff, not the model's receptive field. Instead, we use a padding of cutoff + 1 Å.

Fix: Use an additive, per-dimension formula:

  max_positions = np.max(positions, axis=0) - np.min(positions, axis=0)
  cell[i, :] = (max_positions[i] + cutoff + padding) * identity[i, :]
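For scale, here is a standalone sketch of the arithmetic, using the 100 Å system and 6 Å cutoff mentioned above (illustrative values only, not library code):

```python
# Per-axis position span (max - min) of a 100 Å system, in Angstrom
extent = 100.0
cutoff = 6.0
padding = 1.0

old_cell_length = extent * 5 * cutoff        # multiplicative: 3000 Å
new_cell_length = extent + cutoff + padding  # additive: 107 Å

print(old_cell_length, new_cell_length)
```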

Minimal failure example

from ase import Atoms

from mace.data import neighborhood  # neighborhood module lives in mace/data per this PR

at1 = Atoms('H3O', positions=[[0, 0, 0], [1, 0, 0], [0, 1, 0], [0, 0, 1]], cell=None, pbc=False)
at = at1.copy()
at.set_positions(at1.get_positions() + [5000, 0, 0])
out = neighborhood.get_neighborhood(
    positions=at.get_positions(),
    cell=at.get_cell(),
    pbc=at.get_pbc(),
    cutoff=5.0,
)
print(out[-1])

@LarsSchaaf LarsSchaaf requested review from Copilot and ilyes319 March 3, 2026 15:43

Copilot AI left a comment


Pull request overview

Fixes non-periodic neighbor-list cell construction and adds infrastructure for accelerated/compiled execution paths (padding, TorchSim interface, OEQ/CuEQ/hybrid conversions), along with accompanying tests and CI coverage.

Changes:

  • Fix non-PBC artificial cell sizing in get_neighborhood to be additive per-dimension with a small padding.
  • Add graph padding utilities + calculator support to stabilize torch.compile shapes (and new tests).
  • Introduce TorchSim model interface + hybrid (cueq+oeq) conversion tooling; expand compile/stress tests and CI.

Reviewed changes

Copilot reviewed 19 out of 19 changed files in this pull request and generated 7 comments.

Show a summary per file
File | Description
mace/data/neighborhood.py | Fixes non-PBC cell construction to avoid huge artificial cells.
mace/data/padding_tools.py | Adds helper to generate padding graphs for fixed-size batching.
mace/data/__init__.py | Exposes padding helper as part of public mace.data API.
mace/modules/utils.py | Allows externally provided displacement tensor for stress/autograd flows (compile-friendly).
mace/modules/wrapper_ops.py | Adds OEQ conv fusion/scatter wrappers + hooks OEQ config into TensorProduct creation.
mace/modules/blocks.py | Propagates oeq_config into TensorProduct construction.
mace/tools/scripts_utils.py | Changes model extraction to load state dict with strict=False.
mace/calculators/mace.py | Adds padding support, hybrid backend option, and updates compile/stress handling.
mace/calculators/mace_torchsim.py | New TorchSim wrapper with optional padding budgets + compile integration.
mace/cli/convert_e3nn_oeq.py | Enables OEQ by default and makes state transfer more defensive.
mace/cli/convert_e3nn_hybrid.py | Adds hybrid cueq+oeq converter utility.
tests/test_padding.py | New unit tests for padding graph creation and batching.
tests/test_calculator.py | Adds regression test ensuring padded calculator matches unpadded outputs.
tests/test_compile.py | Adds torch.compile stress test.
tests/test_torchsim.py | Adds TorchSim integration tests with a minimal trained model.
setup.cfg | Adds optional torchsim extra.
.github/workflows/unittest.yaml | Adds a TorchSim CI job on Python 3.12.
mace/__version__.py | Bumps version to 0.3.16.
mace/calculators/foundations_models.py | Minor formatting/whitespace adjustments.


Comment on lines +27 to +35
max_positions = np.max(positions, axis=0) - np.min(positions, axis=0)
padding = 1  # 1 angstrom padding

if not pbc_x:
-    cell[0, :] = max_positions * 5 * cutoff * identity[0, :]
+    cell[0, :] = (max_positions[0] + cutoff + padding) * identity[0, :]
if not pbc_y:
-    cell[1, :] = max_positions * 5 * cutoff * identity[1, :]
+    cell[1, :] = (max_positions[1] + cutoff + padding) * identity[1, :]
if not pbc_z:
-    cell[2, :] = max_positions * 5 * cutoff * identity[2, :]
+    cell[2, :] = (max_positions[2] + cutoff + padding) * identity[2, :]

Copilot AI Mar 3, 2026


The PR fixes a concrete failure mode for pbc=False, but there’s no regression test covering the “large absolute coordinates” example from the PR description. Adding a focused test that constructs a translated cluster (e.g., +5000 Å shift) and asserts get_neighborhood succeeds (and returns a reasonable non-pathological cell / neighbor list) would prevent reintroducing the oversized-cell behavior.
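Such a test could even be sketched without importing mace, by reimplementing the additive formula and checking it is insensitive to a large translation (the function name and padding value here are illustrative, not the library's):

```python
import numpy as np

def build_nonpbc_cell(positions: np.ndarray, cutoff: float, padding: float = 1.0) -> np.ndarray:
    # Additive per-dimension cell from the per-axis position span;
    # a reimplementation for illustration, not the library function.
    span = np.max(positions, axis=0) - np.min(positions, axis=0)
    return np.diag(span + cutoff + padding)

# Cluster from the PR description, shifted 5000 Å along x.
pos = np.array([[0.0, 0, 0], [1, 0, 0], [0, 1, 0], [0, 0, 1]]) + [5000.0, 0, 0]
cell = build_nonpbc_cell(pos, cutoff=5.0)
print(np.diag(cell))  # [7. 7. 7.] -- independent of the translation
```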

Comment on lines 499 to 502
def extract_model(model: torch.nn.Module, map_location: str = "cpu") -> torch.nn.Module:
    model_copy = model.__class__(**extract_config_mace_model(model))
-    model_copy.load_state_dict(model.state_dict())
+    model_copy.load_state_dict(model.state_dict(), strict=False)
    return model_copy.to(map_location)

Copilot AI Mar 3, 2026


Switching to strict=False can silently drop missing/unexpected keys and return a partially initialized model without any signal, which can lead to incorrect inference results that are hard to debug. Consider keeping strict=True, or if strict=False is required for specific conversion/acceleration paths, capture and validate the missing_keys / unexpected_keys from load_state_dict(...) and raise (or at least log a warning) when they’re non-empty.
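One way to keep the flexibility while surfacing mismatches could be a small checked loader (a sketch; the helper name is hypothetical):

```python
import torch

def load_state_dict_checked(model: torch.nn.Module, state_dict: dict) -> None:
    # Load non-strictly, but fail loudly if anything was dropped, instead
    # of silently returning a partially initialized model.
    result = model.load_state_dict(state_dict, strict=False)
    if result.missing_keys or result.unexpected_keys:
        raise RuntimeError(
            f"state_dict mismatch: missing={result.missing_keys}, "
            f"unexpected={result.unexpected_keys}"
        )
```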

Comment on lines +170 to +199
    conv_tp.original_forward = conv_tp.forward
    if not hasattr(conv_tp, "weight_numel"):
        conv_tp.weight_numel = conv_tp.input_args["problem"].weight_numel
    conv_tp.layout_transpose_in = transpose_in
    conv_tp.layout_transpose_out = transpose_out

    def forward(
        self,
        node_feats: torch.Tensor,
        edge_attrs: torch.Tensor,
        tp_weights: torch.Tensor,
        edge_index: torch.Tensor,
    ) -> torch.Tensor:
        sender = edge_index[0]
        receiver = edge_index[1]
        if self.layout_transpose_in is not None:
            node_feats = self.layout_transpose_in(node_feats)
        out = self.original_forward(
            node_feats,
            edge_attrs,
            tp_weights,
            receiver,
            sender,
        )
        if self.layout_transpose_out is not None:
            out = self.layout_transpose_out(out)
        return out

    conv_tp.forward = types.MethodType(forward, conv_tp)
    return conv_tp

Copilot AI Mar 3, 2026


The wrapper overwrites conv_tp.original_forward unconditionally. If with_oeq_conv_fusion is applied twice to the same module (directly or indirectly), original_forward will point to the already-wrapped method and can cause incorrect behavior (including recursion depending on the wrapping order). Add a guard (e.g., check for an existing original_forward attribute and avoid re-wrapping, or store the true original only once).
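A possible guard, sketched on a plain object (the attribute name mirrors the snippet above, but the helper itself is illustrative):

```python
import types

def wrap_forward_once(module):
    # Idempotent wrapping: stash the true original forward exactly once,
    # so a second application cannot chain or recurse into the wrapper.
    if getattr(module, "original_forward", None) is not None:
        return module  # already wrapped; leave as-is
    module.original_forward = module.forward

    def forward(self, *args, **kwargs):
        return self.original_forward(*args, **kwargs)

    module.forward = types.MethodType(forward, module)
    return module
```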

Comment on lines +92 to +95
    transferred_keys = set()
    remaining_keys = (
        set(source_dict.keys()) & set(target_dict.keys()) - transferred_keys
    )

Copilot AI Mar 3, 2026


transferred_keys is currently always empty, so it doesn’t affect remaining_keys. Either populate it with the keys handled by transfer_symmetric_contractions(...) (if that’s the intent) or remove it to reduce dead/placeholder logic and avoid confusion about what weights were transferred vs. copied.

Suggested change
-    transferred_keys = set()
-    remaining_keys = (
-        set(source_dict.keys()) & set(target_dict.keys()) - transferred_keys
-    )
+    remaining_keys = set(source_dict.keys()) & set(target_dict.keys())

Comment on lines +321 to +336
        pad_count = A - n_real_atoms
        pad_pos = torch.zeros(pad_count, 3, device=self._device, dtype=self._dtype)
        padded_positions = torch.cat([data_dict["positions"], pad_pos])

        padded: Dict[str, torch.Tensor] = {
            "positions": padded_positions,
            "node_attrs": self._buf_node_attrs,
            "batch": self._buf_batch,
            "edge_index": self._buf_edge_index,
            "shifts": self._buf_shifts,
            "unit_shifts": self._buf_unit_shifts,
            "ptr": self._buf_ptr,
            "cell": self._buf_cell,
            "head": self._buf_head,
            "pbc": data_dict["pbc"],
        }

Copilot AI Mar 3, 2026


_fill_padded_data allocates pad_pos and concatenates every forward, even though _allocate_buffers already creates self._buf_positions. This adds per-step allocations and undermines the fixed-buffer goal (especially for large A). Consider copying real positions into self._buf_positions[:n_real_atoms], leaving the tail as zeros, and passing self._buf_positions directly.
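The suggested fixed-buffer pattern could look roughly like this (names are illustrative, not the PR's code):

```python
import torch

def fill_positions_buffer(buf: torch.Tensor, positions: torch.Tensor) -> torch.Tensor:
    # Copy the real positions into the preallocated (A, 3) buffer in place
    # and zero the padded tail, avoiding a fresh zeros + cat every forward.
    n = positions.shape[0]
    buf[:n].copy_(positions)
    buf[n:].zero_()
    return buf
```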

Comment on lines +94 to +108
    run_env["PYTHONPATH"] = ":".join(sys.path)

    cmd = (
        sys.executable
        + " "
        + str(run_train)
        + " "
        + " ".join(
            [
                (f"--{k}={v}" if v is not None else f"--{k}")
                for k, v in mace_params.items()
            ]
        )
    )
    p = subprocess.run(cmd.split(), env=run_env, check=True)

Copilot AI Mar 3, 2026


Building a shell-like command string and using cmd.split() is brittle (paths/args with spaces will break) and hardcodes PYTHONPATH separator as :. Prefer constructing an argv list (e.g., [sys.executable, str(run_train), ...]) and use os.pathsep.join(sys.path) for PYTHONPATH to be robust across environments.

Suggested change
-    run_env["PYTHONPATH"] = ":".join(sys.path)
-    cmd = (
-        sys.executable
-        + " "
-        + str(run_train)
-        + " "
-        + " ".join(
-            [
-                (f"--{k}={v}" if v is not None else f"--{k}")
-                for k, v in mace_params.items()
-            ]
-        )
-    )
-    p = subprocess.run(cmd.split(), env=run_env, check=True)
+    run_env["PYTHONPATH"] = os.pathsep.join(sys.path)
+    args = [
+        sys.executable,
+        str(run_train),
+    ] + [
+        (f"--{k}={v}" if v is not None else f"--{k}")
+        for k, v in mace_params.items()
+    ]
+    p = subprocess.run(args, env=run_env, check=True)

# For models with more than 5 layers, the multiplicative constant needs to be increased.
# temp_cell = np.copy(cell)
max_positions = np.max(positions, axis=0) - np.min(positions, axis=0)
padding = 1 # 1 angstrom padding

Copilot AI Mar 3, 2026


padding is a hard-coded magic number and is currently an int while the rest of the computation is floating-point. Consider using 1.0 and/or factoring this into a named constant (or parameter) so it’s easier to tune/document and consistent with the cell’s float dtype.


Copilot AI left a comment


Pull request overview

Copilot reviewed 1 out of 1 changed files in this pull request and generated 2 comments.



Comment on lines +27 to +28
max_positions = np.max(positions, axis=0) - np.min(positions, axis=0)
padding = 1 # 1 angstrom padding

Copilot AI Mar 3, 2026


max_positions now stores the per-axis position span (max-min), not a maximum. Renaming it (e.g., position_span / extent) would avoid confusion and make the subsequent indexing (max_positions[0], etc.) clearer.

Comment on lines +28 to +29
padding = 1 # 1 angstrom padding


Copilot AI Mar 3, 2026


padding = 1 is a new hard-coded magic number. Consider making this a module-level constant or a function parameter (and use 1.0 to keep everything explicitly float in Angstrom units) so callers/tests can tune it if needed.

@LarsSchaaf
Collaborator Author

Leads to out-of-memory errors for electrostatic models (e.g. MACE-Polar) @WillBaldwin0.

Traceback

Provided by @GabrielGreenstein01

File ~/projects/mace-field/mace/modules/extensions.py:635, in FieldFukuiMACE.forward(self, data, training, compute_force, compute_virials, compute_stress, compute_displacement, compute_hessian, compute_edge_forces, compute_atomic_stresses, lammps_mliap, use_pbc_evaluator, fermi_level, external_field)
    632 inter_e = scatter_sum(node_inter_es, data["batch"], dim=-1, dim_size=num_graphs)
    634 # Build k-grid
--> 635 k_vectors, kv_norms_squared, kv_mask = compute_k_vectors(
    636     self.kspace_cutoff, cell.view(-1, 3, 3), data["rcell"].view(-1, 3, 3)
    637 )
    639 # SCF fixed point
    640 features_mixed = self.layer_feature_mixer(torch.stack(node_feats_list, dim=0))
File ~/projects/graph_longrange_repo/graph_longrange/kspace.py:64, in compute_k_vectors(cutoff, cell_vectors, r_cell_vectors)
     57 # make a single superset of all coeficients
     58 origin = torch.cartesian_prod(
     59     torch.arange(0, 1, 1, device=device),
     60     torch.arange(0, 1, 1, device=device),
     61     torch.arange(0, 1, 1, device=device),
     62 ).type(torch.float32)
---> 64 open_half_sphere = torch.cartesian_prod(
     65     torch.arange(1, n1max, 1, device=device),
     66     torch.arange(-n2max, n2max, 1, device=device),
     67     torch.arange(-n3max, n3max, 1, device=device),
     68 ).type(torch.float32)
     70 open_half_plane = torch.cartesian_prod(
     71     torch.arange(0, 1, 1, device=device),
     72     torch.arange(1, n2max, 1, device=device),
     73     torch.arange(-n3max, n3max, 1, device=device),
     74 ).type(torch.float32)
     76 open_half_line = torch.cartesian_prod(
     77     torch.arange(0, 1, 1, device=device),
     78     torch.arange(0, 1, 1, device=device),
     79     torch.arange(1, n3max, 1, device=device),
     80 ).type(torch.float32)

File ~/.conda/envs/enzts/lib/python3.12/site-packages/torch/functional.py:1419, in cartesian_prod(*tensors)
   1417 if has_torch_function(tensors):
   1418     return handle_torch_function(cartesian_prod, tensors, *tensors)
-> 1419 return _VF.cartesian_prod(tensors)

OutOfMemoryError: CUDA out of memory.

@LarsSchaaf LarsSchaaf requested a review from WillBaldwin0 March 12, 2026 17:55
@LarsSchaaf LarsSchaaf added the bug Something isn't working label Mar 12, 2026
@WillBaldwin0
Collaborator

@ilyes319 I wonder if we should also add a warning when the atom locations are too large, since for small applied fields this could lead to nonsense without appropriate shifts.
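Such a warning might be sketched as follows (the threshold and wording are placeholders, not an agreed design):

```python
import warnings
import numpy as np

LARGE_COORDINATE_THRESHOLD = 1.0e3  # Angstrom; placeholder value

def warn_if_far_from_origin(positions: np.ndarray) -> None:
    # Flag geometries whose coordinates are very far from the origin, since
    # small applied fields can then produce nonsense without a shift.
    if np.abs(positions).max() > LARGE_COORDINATE_THRESHOLD:
        warnings.warn(
            "Atom coordinates are far from the origin; consider shifting "
            "the system before applying an external field."
        )
```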


Labels

bug Something isn't working
