Fix non-periodic cell construction in get_neighborhood #1395

LarsSchaaf wants to merge 1 commit into ACEsuit:develop

Conversation
Pull request overview
Fixes non-periodic neighbor-list cell construction and adds infrastructure for accelerated/compiled execution paths (padding, TorchSim interface, OEQ/CuEQ/hybrid conversions), along with accompanying tests and CI coverage.
Changes:
- Fix non-PBC artificial cell sizing in `get_neighborhood` to be additive per-dimension with a small padding.
- Add graph padding utilities + calculator support to stabilize `torch.compile` shapes (and new tests).
- Introduce TorchSim model interface + hybrid (cueq+oeq) conversion tooling; expand compile/stress tests and CI.
Reviewed changes
Copilot reviewed 19 out of 19 changed files in this pull request and generated 7 comments.
Show a summary per file
| File | Description |
|---|---|
| mace/data/neighborhood.py | Fixes non-PBC cell construction to avoid huge artificial cells. |
| mace/data/padding_tools.py | Adds helper to generate padding graphs for fixed-size batching. |
| mace/data/__init__.py | Exposes padding helper as part of public mace.data API. |
| mace/modules/utils.py | Allows externally provided displacement tensor for stress/autograd flows (compile-friendly). |
| mace/modules/wrapper_ops.py | Adds OEQ conv fusion/scatter wrappers + hooks OEQ config into TensorProduct creation. |
| mace/modules/blocks.py | Propagates oeq_config into TensorProduct construction. |
| mace/tools/scripts_utils.py | Changes model extraction to load state dict with strict=False. |
| mace/calculators/mace.py | Adds padding support, hybrid backend option, and updates compile/stress handling. |
| mace/calculators/mace_torchsim.py | New TorchSim wrapper with optional padding budgets + compile integration. |
| mace/cli/convert_e3nn_oeq.py | Enables OEQ by default and makes state transfer more defensive. |
| mace/cli/convert_e3nn_hybrid.py | Adds hybrid cueq+oeq converter utility. |
| tests/test_padding.py | New unit tests for padding graph creation and batching. |
| tests/test_calculator.py | Adds regression test ensuring padded calculator matches unpadded outputs. |
| tests/test_compile.py | Adds torch.compile stress test. |
| tests/test_torchsim.py | Adds TorchSim integration tests with a minimal trained model. |
| setup.cfg | Adds optional torchsim extra. |
| .github/workflows/unittest.yaml | Adds a TorchSim CI job on Python 3.12. |
| mace/__version__.py | Bumps version to 0.3.16. |
| mace/calculators/foundations_models.py | Minor formatting/whitespace adjustments. |
```diff
+ max_positions = np.max(positions, axis=0) - np.min(positions, axis=0)
+ padding = 1  # 1 angstrom padding

  if not pbc_x:
-     cell[0, :] = max_positions * 5 * cutoff * identity[0, :]
+     cell[0, :] = (max_positions[0] + cutoff + padding) * identity[0, :]
  if not pbc_y:
-     cell[1, :] = max_positions * 5 * cutoff * identity[1, :]
+     cell[1, :] = (max_positions[1] + cutoff + padding) * identity[1, :]
  if not pbc_z:
-     cell[2, :] = max_positions * 5 * cutoff * identity[2, :]
+     cell[2, :] = (max_positions[2] + cutoff + padding) * identity[2, :]
```
The PR fixes a concrete failure mode for pbc=False, but there’s no regression test covering the “large absolute coordinates” example from the PR description. Adding a focused test that constructs a translated cluster (e.g., +5000 Å shift) and asserts get_neighborhood succeeds (and returns a reasonable non-pathological cell / neighbor list) would prevent reintroducing the oversized-cell behavior.
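A minimal sketch of the regression test the comment asks for. It does not import MACE; instead it reimplements the PR's additive cell formula standalone, so the exact helper name `artificial_cell` is an assumption, not the real `get_neighborhood` API:

```python
import numpy as np

def artificial_cell(positions: np.ndarray, cutoff: float, padding: float = 1.0) -> np.ndarray:
    # Mirrors the PR's fix: per-dimension span + cutoff + padding (hypothetical helper).
    span = np.max(positions, axis=0) - np.min(positions, axis=0)
    cell = np.zeros((3, 3))
    for i in range(3):
        cell[i, i] = span[i] + cutoff + padding
    return cell

# A ~10 Å cluster translated by +5000 Å on every axis.
rng = np.random.default_rng(0)
cluster = rng.uniform(0.0, 10.0, size=(32, 3))
shifted = cluster + 5000.0

cell = artificial_cell(shifted, cutoff=5.0)
# The cell depends only on the cluster's extent, not its absolute position,
# so it stays ~16 Å per side instead of scaling with the coordinates.
assert np.allclose(cell, artificial_cell(cluster, cutoff=5.0))
assert cell.max() < 100.0
```

A test of this shape on the real `get_neighborhood` would catch any reintroduction of the coordinate-dependent cell.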
```diff
  def extract_model(model: torch.nn.Module, map_location: str = "cpu") -> torch.nn.Module:
      model_copy = model.__class__(**extract_config_mace_model(model))
-     model_copy.load_state_dict(model.state_dict())
+     model_copy.load_state_dict(model.state_dict(), strict=False)
      return model_copy.to(map_location)
```
Switching to strict=False can silently drop missing/unexpected keys and return a partially initialized model without any signal, which can lead to incorrect inference results that are hard to debug. Consider keeping strict=True, or if strict=False is required for specific conversion/acceleration paths, capture and validate the missing_keys / unexpected_keys from load_state_dict(...) and raise (or at least log a warning) when they’re non-empty.
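A sketch of the validation the comment suggests, assuming a standalone helper (`load_checked` is an illustrative name, not the actual `extract_model` code). `load_state_dict(strict=False)` returns an `_IncompatibleKeys` result whose `missing_keys`/`unexpected_keys` can be checked explicitly:

```python
import torch

def load_checked(model: torch.nn.Module, state_dict) -> torch.nn.Module:
    # Load non-strictly, but fail loudly on any mismatch instead of
    # silently leaving parameters at their freshly-initialized values.
    result = model.load_state_dict(state_dict, strict=False)
    if result.missing_keys or result.unexpected_keys:
        raise RuntimeError(
            f"load_state_dict mismatch: missing={result.missing_keys}, "
            f"unexpected={result.unexpected_keys}"
        )
    return model

model = torch.nn.Linear(4, 4)
load_checked(model, torch.nn.Linear(4, 4).state_dict())  # complete dict: fine
try:
    load_checked(model, {"weight": torch.zeros(4, 4)})  # partial dict: raises
except RuntimeError as err:
    print(err)
```

Downgrading the `raise` to `logging.warning` would give the conversion paths the flexibility they need while still leaving a signal in the logs.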
```python
conv_tp.original_forward = conv_tp.forward
if not hasattr(conv_tp, "weight_numel"):
    conv_tp.weight_numel = conv_tp.input_args["problem"].weight_numel
conv_tp.layout_transpose_in = transpose_in
conv_tp.layout_transpose_out = transpose_out

def forward(
    self,
    node_feats: torch.Tensor,
    edge_attrs: torch.Tensor,
    tp_weights: torch.Tensor,
    edge_index: torch.Tensor,
) -> torch.Tensor:
    sender = edge_index[0]
    receiver = edge_index[1]
    if self.layout_transpose_in is not None:
        node_feats = self.layout_transpose_in(node_feats)
    out = self.original_forward(
        node_feats,
        edge_attrs,
        tp_weights,
        receiver,
        sender,
    )
    if self.layout_transpose_out is not None:
        out = self.layout_transpose_out(out)
    return out

conv_tp.forward = types.MethodType(forward, conv_tp)
return conv_tp
```
The wrapper overwrites conv_tp.original_forward unconditionally. If with_oeq_conv_fusion is applied twice to the same module (directly or indirectly), original_forward will point to the already-wrapped method and can cause incorrect behavior (including recursion depending on the wrapping order). Add a guard (e.g., check for an existing original_forward attribute and avoid re-wrapping, or store the true original only once).
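The guard the comment describes can be sketched with a generic module (the real code wraps an OEQ `TensorProduct`'s forward; `Identity` and the `+ 1` stand-in for the layout transposes are illustrative only):

```python
import types

def wrap_forward_once(module):
    # Idempotent wrapping: if original_forward already exists the module
    # was wrapped before, so don't re-wrap the wrapper.
    if hasattr(module, "original_forward"):
        return module
    module.original_forward = module.forward

    def forward(self, x):
        return self.original_forward(x) + 1  # stand-in for the transposes

    module.forward = types.MethodType(forward, module)
    return module

class Identity:
    def forward(self, x):
        return x

m = Identity()
wrap_forward_once(m)
wrap_forward_once(m)  # second call is a no-op thanks to the guard
print(m.forward(1))   # 2, not 3
```

Without the `hasattr` check, the second call would set `original_forward` to the already-wrapped method and the `+ 1` would be applied twice.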
```python
transferred_keys = set()
remaining_keys = (
    set(source_dict.keys()) & set(target_dict.keys()) - transferred_keys
)
```
transferred_keys is currently always empty, so it doesn’t affect remaining_keys. Either populate it with the keys handled by transfer_symmetric_contractions(...) (if that’s the intent) or remove it to reduce dead/placeholder logic and avoid confusion about what weights were transferred vs. copied.
Suggested change:

```diff
- transferred_keys = set()
- remaining_keys = (
-     set(source_dict.keys()) & set(target_dict.keys()) - transferred_keys
- )
+ remaining_keys = set(source_dict.keys()) & set(target_dict.keys())
```
```python
pad_count = A - n_real_atoms
pad_pos = torch.zeros(pad_count, 3, device=self._device, dtype=self._dtype)
padded_positions = torch.cat([data_dict["positions"], pad_pos])

padded: Dict[str, torch.Tensor] = {
    "positions": padded_positions,
    "node_attrs": self._buf_node_attrs,
    "batch": self._buf_batch,
    "edge_index": self._buf_edge_index,
    "shifts": self._buf_shifts,
    "unit_shifts": self._buf_unit_shifts,
    "ptr": self._buf_ptr,
    "cell": self._buf_cell,
    "head": self._buf_head,
    "pbc": data_dict["pbc"],
}
```
_fill_padded_data allocates pad_pos and concatenates every forward, even though _allocate_buffers already creates self._buf_positions. This adds per-step allocations and undermines the fixed-buffer goal (especially for large A). Consider copying real positions into self._buf_positions[:n_real_atoms], leaving the tail as zeros, and passing self._buf_positions directly.
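The buffer-reuse pattern being suggested, sketched in isolation (class and method names here are illustrative, not the actual calculator API, which keeps the buffer in `self._buf_positions`):

```python
import torch

class PositionBuffer:
    # Preallocate a fixed (max_atoms, 3) tensor once; every step copies the
    # real positions into the head and zeroes the padded tail, avoiding a
    # fresh allocation + torch.cat per forward pass.
    def __init__(self, max_atoms: int, dtype=torch.float64):
        self._buf_positions = torch.zeros(max_atoms, 3, dtype=dtype)

    def fill(self, positions: torch.Tensor) -> torch.Tensor:
        n = positions.shape[0]
        self._buf_positions[:n].copy_(positions)
        self._buf_positions[n:].zero_()  # keep pad atoms at the origin
        return self._buf_positions

buf = PositionBuffer(max_atoms=8)
out = buf.fill(torch.ones(3, 3, dtype=torch.float64))
print(out.shape)  # torch.Size([8, 3])
```

Because `fill` returns the same tensor object every step, downstream code sees a fixed shape, which is exactly what `torch.compile` needs to avoid recompilation.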
```python
run_env["PYTHONPATH"] = ":".join(sys.path)

cmd = (
    sys.executable
    + " "
    + str(run_train)
    + " "
    + " ".join(
        [
            (f"--{k}={v}" if v is not None else f"--{k}")
            for k, v in mace_params.items()
        ]
    )
)
p = subprocess.run(cmd.split(), env=run_env, check=True)
```
Building a shell-like command string and using cmd.split() is brittle (paths/args with spaces will break) and hardcodes PYTHONPATH separator as :. Prefer constructing an argv list (e.g., [sys.executable, str(run_train), ...]) and use os.pathsep.join(sys.path) for PYTHONPATH to be robust across environments.
Suggested change:

```diff
- run_env["PYTHONPATH"] = ":".join(sys.path)
- cmd = (
-     sys.executable
-     + " "
-     + str(run_train)
-     + " "
-     + " ".join(
-         [
-             (f"--{k}={v}" if v is not None else f"--{k}")
-             for k, v in mace_params.items()
-         ]
-     )
- )
- p = subprocess.run(cmd.split(), env=run_env, check=True)
+ run_env["PYTHONPATH"] = os.pathsep.join(sys.path)
+ args = [
+     sys.executable,
+     str(run_train),
+ ] + [
+     (f"--{k}={v}" if v is not None else f"--{k}")
+     for k, v in mace_params.items()
+ ]
+ p = subprocess.run(args, env=run_env, check=True)
```
```python
# For models with more than 5 layers, the multiplicative constant needs to be increased.
# temp_cell = np.copy(cell)
max_positions = np.max(positions, axis=0) - np.min(positions, axis=0)
padding = 1  # 1 angstrom padding
```
padding is a hard-coded magic number and is currently an int while the rest of the computation is floating-point. Consider using 1.0 and/or factoring this into a named constant (or parameter) so it’s easier to tune/document and consistent with the cell’s float dtype.
Pull request overview
Copilot reviewed 1 out of 1 changed files in this pull request and generated 2 comments.
```python
max_positions = np.max(positions, axis=0) - np.min(positions, axis=0)
padding = 1  # 1 angstrom padding
```
max_positions now stores the per-axis position span (max-min), not a maximum. Renaming it (e.g., position_span / extent) would avoid confusion and make the subsequent indexing (max_positions[0], etc.) clearer.
```python
padding = 1  # 1 angstrom padding
```
padding = 1 is a new hard-coded magic number. Consider making this a module-level constant or a function parameter (and use 1.0 to keep everything explicitly float in Angstrom units) so callers/tests can tune it if needed.
Leads to out-of-memory errors for electrostatic (e.g. MACE-Polar) models @WillBaldwin0.

Traceback provided by @GabrielGreenstein01:

```
File ~/projects/mace-field/mace/modules/extensions.py:635, in FieldFukuiMACE.forward(self, data, training, compute_force, compute_virials, compute_stress, compute_displacement, compute_hessian, compute_edge_forces, compute_atomic_stresses, lammps_mliap, use_pbc_evaluator, fermi_level, external_field)
    632 inter_e = scatter_sum(node_inter_es, data["batch"], dim=-1, dim_size=num_graphs)
    634 # Build k-grid
--> 635 k_vectors, kv_norms_squared, kv_mask = compute_k_vectors(
    636     self.kspace_cutoff, cell.view(-1, 3, 3), data["rcell"].view(-1, 3, 3)
    637 )
    639 # SCF fixed point
    640 features_mixed = self.layer_feature_mixer(torch.stack(node_feats_list, dim=0))

File ~/projects/graph_longrange_repo/graph_longrange/kspace.py:64, in compute_k_vectors(cutoff, cell_vectors, r_cell_vectors)
     57 # make a single superset of all coeficients
     58 origin = torch.cartesian_prod(
     59     torch.arange(0, 1, 1, device=device),
     60     torch.arange(0, 1, 1, device=device),
     61     torch.arange(0, 1, 1, device=device),
     62 ).type(torch.float32)
---> 64 open_half_sphere = torch.cartesian_prod(
     65     torch.arange(1, n1max, 1, device=device),
     66     torch.arange(-n2max, n2max, 1, device=device),
     67     torch.arange(-n3max, n3max, 1, device=device),
     68 ).type(torch.float32)
     70 open_half_plane = torch.cartesian_prod(
     71     torch.arange(0, 1, 1, device=device),
     72     torch.arange(1, n2max, 1, device=device),
     73     torch.arange(-n3max, n3max, 1, device=device),
     74 ).type(torch.float32)
     76 open_half_line = torch.cartesian_prod(
     77     torch.arange(0, 1, 1, device=device),
     78     torch.arange(0, 1, 1, device=device),
     79     torch.arange(1, n3max, 1, device=device),
     80 ).type(torch.float32)

File ~/.conda/envs/enzts/lib/python3.12/site-packages/torch/functional.py:1419, in cartesian_prod(*tensors)
   1417 if has_torch_function(tensors):
   1418     return handle_torch_function(cartesian_prod, tensors, *tensors)
-> 1419 return _VF.cartesian_prod(tensors)

OutOfMemoryError: CUDA out of memory.
```
@ilyes319 I wonder if we should also add a warning when the atom locations are too large, since for small applied fields this could lead to nonsense without appropriate shifts.
When pbc=False, the artificial cell for the neighbor list was computed as:

```python
cell[i, :] = max_positions * 5 * cutoff * identity[i, :]
```

Fix: Use an additive, per-dimension formula:

```python
cell[i, :] = (max_positions[i] + cutoff + padding) * identity[i, :]
```

Minimal failure example
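The original example is not reproduced above; the following is a hedged reconstruction of the failure mode from the surrounding discussion (a small cluster with large absolute coordinates), not the author's exact script:

```python
import numpy as np

# Hypothetical reconstruction: before the fix, the artificial cell edge
# scaled with the absolute coordinate maximum, so translating a 10 Å
# cluster by +5000 Å produced a ~1e5 Å cell and downstream OOM errors.
rng = np.random.default_rng(0)
cluster = rng.uniform(0.0, 10.0, size=(64, 3))
positions = cluster + 5000.0  # same 10 Å cluster, shifted far from the origin
cutoff = 5.0

old_edge = np.max(positions, axis=0) * 5 * cutoff            # ~1.25e5 Å per side
span = np.max(positions, axis=0) - np.min(positions, axis=0)
new_edge = span + cutoff + 1.0                               # ~16 Å per side

print(old_edge.max() > 1e5, new_edge.max() < 20)
```

The additive formula depends only on the cluster's extent, so the cell stays bounded no matter where the cluster sits in space.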