diff --git a/docs/source/user/data/data_formats_help.rst b/docs/source/user/data/data_formats_help.rst index f45c6cfcc..123b4d1f8 100644 --- a/docs/source/user/data/data_formats_help.rst +++ b/docs/source/user/data/data_formats_help.rst @@ -55,7 +55,7 @@ separated by whitespaces or commas or semicolons, in the following order: where *Q* is assumed to have units of 1/Angstrom, *I(Q)* is assumed to have units of 1/cm, *dI(Q)* is the uncertainty on the intensity value (also as 1/cm), -and *dQ(Q)* **is the one-sigma FWHM Gaussian instrumental resolution in** *Q*, +and *dQ(Q)* **is the one-sigma Gaussian instrumental resolution in** *Q*, **assumed to have arisen from pinhole geometry**. If the data are slit-smeared, see `Slit-Smeared Data`_. diff --git a/sasdata/data_util/err1d.py b/sasdata/data_util/err1d.py index bf164f117..bb9dd5b1a 100644 --- a/sasdata/data_util/err1d.py +++ b/sasdata/data_util/err1d.py @@ -1,6 +1,6 @@ # This program is public domain """ -Error propogation algorithms for simple arithmetic +Error propagation algorithms for simple arithmetic Warning: like the underlying numpy library, the inplace operations may return values of the wrong type if some of the arguments are diff --git a/sasdata/metadata.py b/sasdata/metadata.py index d53c3102c..5f9f96029 100644 --- a/sasdata/metadata.py +++ b/sasdata/metadata.py @@ -458,14 +458,19 @@ def to_string(self, header=""): ) else: attributes = "" - if self.contents: - if type(self.contents) is str: + match self.contents: + case str(): children = f"\n{header} {self.contents}" - else: + case list() | tuple(): children = "".join([n.to_string(header + " ") for n in self.contents]) - else: - children = "" - + case Quantity(): + children = f"\n{header} {self.contents}" + case ndarray(): + children = f"\n{header} {self.contents}" + case None: + children = "" + case _: + children = f"\n{header} {self.contents}" return f"\n{header}{self.name}:{attributes}{children}" def filter(self, name: str) -> list[ndarray | Quantity 
| str]: diff --git a/sasdata/trend.py b/sasdata/trend.py index 9b1a371a4..9d5542f5c 100644 --- a/sasdata/trend.py +++ b/sasdata/trend.py @@ -1,48 +1,156 @@ + from dataclasses import dataclass import numpy as np from sasdata.data import SasData -from sasdata.data_backing import Dataset, Group from sasdata.quantities.quantity import Quantity from sasdata.transforms.rebinning import calculate_interpolation_matrix_1d # Axis strs refer to the name of their associated NamedQuantity. -# TODO: This probably shouldn't be here but will keep it here for now. +# TODO: This probably shouldn't be here but will keep it here for now. --> In sasdata/data.py? +# TODO: Similarity/relation to __getitem__ in SasData class? +# TODO: Or a method of Metadata class? # TODO: Not sure how to type hint the return. def get_metadatum_from_path(data: SasData, metadata_path: list[str]): - current_group = data._raw_metadata + current_node = data.metadata.raw for path_item in metadata_path: - current_item = current_group.children.get(path_item, None) - if current_item is None or (isinstance(current_item, Dataset) and path_item != metadata_path[-1]): - raise ValueError('Path does not lead to valid a metadatum.') - elif isinstance(current_item, Group): - current_group = current_item + if isinstance(current_node.contents, list): + # Search through list of MetaNodes + current_item = None + for node in current_node.contents: + if node.name == path_item: + current_item = node + break else: - return current_item.data - raise ValueError('End of path without finding a dataset.') + # Not a list, can't navigate further + raise ValueError('Path does not lead to a valid metadatum.') + + if current_item is None: + raise ValueError('Path does not lead to a valid metadatum.') + # Check if we're at the end of the path + if path_item == metadata_path[-1]: + return current_item.contents + + current_node = current_item + raise ValueError('End of path without finding a dataset.') @dataclass class Trend: data:
list[SasData] - # This is going to be a path to a specific metadatum. - # - # TODO: But what if the trend axis will be a particular NamedQuantity? Will probably need to think on this. - trend_axis: list[str] - - # Designed to take in a particular value of the trend axis, and return the SasData object that matches it. - # TODO: Not exaclty sure what item's type will be. It could depend on where it is pointing to. - def __getitem__(self, item) -> SasData: - for datum in self.data: - metadatum = get_metadatum_from_path(datum, self.trend_axis) - if metadatum == item: - return datum - raise KeyError() + trend_axes: dict[str, list[str] | list] # Path or manual values + + def __post_init__(self): + + # First, filter out invalid data items + self._filter_and_validate_data() + + # Validate data length matches manual value lists + self._validate_manual_values() + + # Validate metadata paths + self._validate_metadata_paths() + + def _filter_and_validate_data(self): + """Filter out non-SasData objects and validate data integrity""" + valid_data = [] + invalid_indices = [] + + for i, datum in enumerate(self.data): + if not isinstance(datum, SasData): + invalid_indices.append(i) + continue + + # Check if datum has metadata + if not hasattr(datum, 'metadata') or datum.metadata is None: + invalid_indices.append(i) + continue + + # Check if datum has raw metadata + if not hasattr(datum.metadata, 'raw') or datum.metadata.raw is None: + invalid_indices.append(i) + continue + + valid_data.append(datum) + + # Update data with only valid items + self.data = valid_data + + # Warn about filtered items + if invalid_indices: + print(f"Warning: Removed data items at indices {invalid_indices} - not SasData objects or missing/invalid metadata") + + # Additional validation + if not self.data: + raise ValueError("No valid data items remain after filtering") + + if len(self.data) < 2: + print(f"Warning: Only {len(self.data)} valid data items remain") + + def _validate_manual_values(self): + 
"""Ensure manual value lists match data length""" + for axis_name, axis_config in self.trend_axes.items(): + if isinstance(axis_config, list) and not isinstance(axis_config[0], str): + # This is a manual value list (not a path) + if len(axis_config) != len(self.data): + raise ValueError(f"Manual values for axis '{axis_name}' must have same length as data ({len(self.data)} items, got {len(axis_config)})") + + def _validate_metadata_paths(self): + """Validate metadata paths""" + for axis_name, axis_config in self.trend_axes.items(): + if isinstance(axis_config, list) and len(axis_config) > 0 and isinstance(axis_config[0], str): + # This is a metadata path + for i, datum in enumerate(self.data): + try: + get_metadatum_from_path(datum, axis_config) + except ValueError as e: + raise ValueError(f"trend_axes['{axis_name}'] path {axis_config} invalid for data item {i}: {e}") + + def get_trend_values(self, axis_name: str) -> list: + """Get values for a named trend axis""" + if axis_name not in self.trend_axes: + raise KeyError(f"Axis '{axis_name}' not found") + + axis_config = self.trend_axes[axis_name] + + if isinstance(axis_config, list) and len(axis_config) > 0 and isinstance(axis_config[0], str): + # Metadata path - extract from data + return [get_metadatum_from_path(datum, axis_config) for datum in self.data] + else: + # Manual values - return as-is + return axis_config.copy() # Return copy to prevent modification + + def add_manual_axis(self, axis_name: str, values: list): + """Add a new manual trend axis""" + if len(values) != len(self.data): + raise ValueError(f"Manual values must have same length as data ({len(self.data)} items, got {len(values)})") + + self.trend_axes[axis_name] = values.copy() + + def add_metadata_axis(self, axis_name: str, path: list[str]): + """Add a new metadata trend axis""" + # Validate the path first + for i, datum in enumerate(self.data): + try: + get_metadatum_from_path(datum, path) + except ValueError as e: + raise ValueError(f"Path 
{path} invalid for data item {i}: {e}") + + self.trend_axes[axis_name] = path + @property - def trend_axes(self) -> list[float]: - return [get_metadatum_from_path(datum, self.trend_axis) for datum in self.data] + def axis_names(self) -> list[str]: + return list(self.trend_axes.keys()) + + def is_manual_axis(self, axis_name: str) -> bool: + """Check if an axis uses manual values or metadata path""" + if axis_name not in self.trend_axes: + raise KeyError(f"Axis '{axis_name}' not found") + + axis_config = self.trend_axes[axis_name] + return not (isinstance(axis_config, list) and len(axis_config) > 0 and isinstance(axis_config[0], str)) # TODO: Assumes there are at least 2 items in data. Is this reasonable to assume? Should there be error handling for # situations where this may not be the case? @@ -85,5 +193,5 @@ def interpolate(self, axis: str) -> "Trend": ) new_data.append(new_datum) new_trend = Trend(new_data, - self.trend_axis) + self.trend_axes) return new_trend diff --git a/test/utest_trend.py b/test/utest_trend.py index b079bf53c..221cc31b7 100644 --- a/test/utest_trend.py +++ b/test/utest_trend.py @@ -42,7 +42,7 @@ def test_trend_build_interpolate(directory_name: str): data = ascii_reader.load_data(params) trend = Trend( data=data, - trend_axis=['magnetic', 'applied_magnetic_field'] + trend_axes={'applied_magnetic_field': ['magnetic', 'applied_magnetic_field']} ) # Initially, the q axes in this date don't exactly match to_interpolate_on = 'Q' @@ -64,6 +64,6 @@ def test_trend_q_axis_match(): data = ascii_reader.load_data(params) trend = Trend( data=data, - trend_axis=['magnetic', 'counting_index'] + trend_axes={'counting_index': ['magnetic', 'counting_index']} ) assert trend.all_axis_match('Q')