Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion docs/source/user/data/data_formats_help.rst
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ separated by whitespaces or commas or semicolons, in the following order:

where *Q* is assumed to have units of 1/Angstrom, *I(Q)* is assumed to have
units of 1/cm, *dI(Q)* is the uncertainty on the intensity value (also as 1/cm),
and *dQ(Q)* **is the one-sigma FWHM Gaussian instrumental resolution in** *Q*,
and *dQ(Q)* **is the one-sigma Gaussian instrumental resolution in** *Q*,
**assumed to have arisen from pinhole geometry**. If the data are slit-smeared,
see `Slit-Smeared Data`_.

Expand Down
2 changes: 1 addition & 1 deletion sasdata/data_util/err1d.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# This program is public domain
"""
Error propogation algorithms for simple arithmetic
Error propagation algorithms for simple arithmetic

Warning: like the underlying numpy library, the inplace operations
may return values of the wrong type if some of the arguments are
Expand Down
17 changes: 11 additions & 6 deletions sasdata/metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -458,14 +458,19 @@ def to_string(self, header=""):
)
else:
attributes = ""
if self.contents:
if type(self.contents) is str:
match self.contents:
case str():
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The string, Quantity and ndarray cases are exactly the same as the default case. Maybe just remove all but the default?

children = f"\n{header} {self.contents}"
else:
case list() | tuple():
children = "".join([n.to_string(header + " ") for n in self.contents])
else:
children = ""

case Quantity():
children = f"\n{header} {self.contents}"
case ndarray():
children = f"\n{header} {self.contents}"
case None:
children = ""
case _:
children = f"\n{header} {self.contents}"
return f"\n{header}{self.name}:{attributes}{children}"

def filter(self, name: str) -> list[ndarray | Quantity | str]:
Expand Down
160 changes: 134 additions & 26 deletions sasdata/trend.py
Original file line number Diff line number Diff line change
@@ -1,48 +1,156 @@

from dataclasses import dataclass

import numpy as np

from sasdata.data import SasData
from sasdata.data_backing import Dataset, Group
from sasdata.quantities.quantity import Quantity
from sasdata.transforms.rebinning import calculate_interpolation_matrix_1d

# Axis strs refer to the name of their associated NamedQuantity.

# TODO: This probably shouldn't be here but will keep it here for now.
# TODO: This probably shouldn't be here but will keep it here for now. --> In sasdata/data.py?
# TODO: Similarity/relation to __getitem__ in SasData class?
# TODO: Or a method of Metadata class?
# TODO: Not sure how to type hint the return.
def get_metadatum_from_path(data: SasData, metadata_path: list[str]):
current_group = data._raw_metadata
current_node = data.metadata.raw
for path_item in metadata_path:
current_item = current_group.children.get(path_item, None)
if current_item is None or (isinstance(current_item, Dataset) and path_item != metadata_path[-1]):
raise ValueError('Path does not lead to valid a metadatum.')
elif isinstance(current_item, Group):
current_group = current_item
if isinstance(current_node.contents, list):
# Search through list of MetaNodes
current_item = None
for node in current_node.contents:
if node.name == path_item:
current_item = node
break
else:
return current_item.data
raise ValueError('End of path without finding a dataset.')
# Not a list, can't navigate further
raise ValueError('Path does not lead to a valid metadatum.')

if current_item is None:
raise ValueError('Path does not lead to a valid metadatum.')
Comment on lines 26 to +31
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Duplicate errors. Combine?


# Check if we're at the end of the path
if path_item == metadata_path[-1]:
return current_item.contents

current_node = current_item
raise ValueError('End of path without finding a dataset.')

@dataclass
class Trend:
data: list[SasData]
# This is going to be a path to a specific metadatum.
#
# TODO: But what if the trend axis will be a particular NamedQuantity? Will probably need to think on this.
trend_axis: list[str]

# Designed to take in a particular value of the trend axis, and return the SasData object that matches it.
# TODO: Not exactly sure what item's type will be. It could depend on where it is pointing to.
def __getitem__(self, item) -> SasData:
for datum in self.data:
metadatum = get_metadatum_from_path(datum, self.trend_axis)
if metadatum == item:
return datum
raise KeyError()
trend_axes: dict[str, list[str] | list] # Path or manual values

def __post_init__(self):
    """Validate the trend right after construction.

    The steps run in a fixed order: invalid data items are filtered out
    first, so the manual-value length check and the metadata-path check
    that follow both see the filtered ``self.data``.
    """
    self._filter_and_validate_data()
    self._validate_manual_values()
    self._validate_metadata_paths()

def _filter_and_validate_data(self):
"""Filter out non-SasData objects and validate data integrity"""
valid_data = []
invalid_indices = []

for i, datum in enumerate(self.data):
if not isinstance(datum, SasData):
invalid_indices.append(i)
continue

# Check if datum has metadata
if not hasattr(datum, 'metadata') or datum.metadata is None:
invalid_indices.append(i)
continue

# Check if datum has raw metadata
if not hasattr(datum.metadata, 'raw') or datum.metadata.raw is None:
invalid_indices.append(i)
continue

valid_data.append(datum)

# Update data with only valid items
self.data = valid_data

# Warn about filtered items
if invalid_indices:
print(f"Warning: Removed data items at indices {invalid_indices} - not SasData objects or missing/invalid metadata")
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The print statements should probably be changed to logging.warning.


# Additional validation
if not self.data:
raise ValueError("No valid data items remain after filtering")

if len(self.data) < 2:
print(f"Warning: Only {len(self.data)} valid data items remain")

def _validate_manual_values(self):
"""Ensure manual value lists match data length"""
for axis_name, axis_config in self.trend_axes.items():
if isinstance(axis_config, list) and not isinstance(axis_config[0], str):
# This is a manual value list (not a path)
if len(axis_config) != len(self.data):
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Unnecessary nesting

raise ValueError(f"Manual values for axis '{axis_name}' must have same length as data ({len(self.data)} items, got {len(axis_config)})")

def _validate_metadata_paths(self):
    """Validate metadata paths.

    An axis whose config is a non-empty list starting with a ``str`` is
    treated as a metadata path and is resolved against every item in
    ``self.data``. All other axes are skipped.

    :raises ValueError: if a path cannot be resolved for some data item;
        the underlying lookup error is chained as the cause.
    """
    for axis_name, axis_config in self.trend_axes.items():
        is_path = (
            isinstance(axis_config, list)
            and len(axis_config) > 0
            and isinstance(axis_config[0], str)
        )
        if not is_path:
            # Manual value lists are not checked here.
            continue

        for i, datum in enumerate(self.data):
            try:
                get_metadatum_from_path(datum, axis_config)
            except ValueError as e:
                # Chain the original error so the traceback shows where the lookup failed.
                raise ValueError(f"trend_axes['{axis_name}'] path {axis_config} invalid for data item {i}: {e}") from e

def get_trend_values(self, axis_name: str) -> list:
"""Get values for a named trend axis"""
if axis_name not in self.trend_axes:
raise KeyError(f"Axis '{axis_name}' not found")

axis_config = self.trend_axes[axis_name]

if isinstance(axis_config, list) and len(axis_config) > 0 and isinstance(axis_config[0], str):
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This check is already performed when creating the trend. Is this necessary?

# Metadata path - extract from data
return [get_metadatum_from_path(datum, axis_config) for datum in self.data]
else:
# Manual values - return as-is
return axis_config.copy() # Return copy to prevent modification

def add_manual_axis(self, axis_name: str, values: list):
    """Add a new manual trend axis.

    Stores a fresh list of ``values`` under ``axis_name`` in
    ``self.trend_axes``, replacing any axis already registered under that
    name.

    :param axis_name: name of the axis to add.
    :param values: one value per item in ``self.data``; any sized sequence
        (list, tuple, ...) is accepted.
    :raises ValueError: if ``values`` and ``self.data`` differ in length.
    """
    expected = len(self.data)
    if len(values) != expected:
        raise ValueError(f"Manual values must have same length as data ({expected} items, got {len(values)})")

    # NOTE(review): a list whose first element is a str is classified as a
    # metadata path by is_manual_axis/get_trend_values, not as manual values.
    # list() instead of values.copy(): also works for tuples, and always
    # stores a real list, which the isinstance(..., list) checks in this
    # class expect.
    self.trend_axes[axis_name] = list(values)

def add_metadata_axis(self, axis_name: str, path: list[str]):
    """Add a new metadata trend axis.

    The path is resolved against every item in ``self.data`` before it is
    registered, so a bad path leaves ``self.trend_axes`` untouched.

    :param axis_name: name of the axis to add.
    :param path: sequence of metadata node names leading to the metadatum.
    :raises ValueError: if the path cannot be resolved for some data item;
        the underlying lookup error is chained as the cause.
    """
    for i, datum in enumerate(self.data):
        try:
            get_metadatum_from_path(datum, path)
        except ValueError as e:
            # Chain the original error so the traceback shows where the lookup failed.
            raise ValueError(f"Path {path} invalid for data item {i}: {e}") from e

    # Store a copy so later changes to the caller's list cannot alter the axis.
    self.trend_axes[axis_name] = list(path)

@property
def trend_axes(self) -> list[float]:
return [get_metadatum_from_path(datum, self.trend_axis) for datum in self.data]
def axis_names(self) -> list[str]:
return list(self.trend_axes.keys())

def is_manual_axis(self, axis_name: str) -> bool:
    """Tell whether ``axis_name`` holds manual values rather than a metadata path.

    An axis counts as a metadata path only when its config is a non-empty
    list whose first element is a ``str``; anything else is manual.

    :raises KeyError: if no axis with that name exists.
    """
    axes = self.trend_axes
    if axis_name not in axes:
        raise KeyError(f"Axis '{axis_name}' not found")

    config = axes[axis_name]
    if not isinstance(config, list):
        return True
    if len(config) == 0:
        return True
    return not isinstance(config[0], str)

# TODO: Assumes there are at least 2 items in data. Is this reasonable to assume? Should there be error handling for
# situations where this may not be the case?
Expand Down Expand Up @@ -85,5 +193,5 @@ def interpolate(self, axis: str) -> "Trend":
)
new_data.append(new_datum)
new_trend = Trend(new_data,
self.trend_axis)
self.trend_axes)
return new_trend
4 changes: 2 additions & 2 deletions test/utest_trend.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ def test_trend_build_interpolate(directory_name: str):
data = ascii_reader.load_data(params)
trend = Trend(
data=data,
trend_axis=['magnetic', 'applied_magnetic_field']
trend_axes={'applied_magnetic_field': ['magnetic', 'applied_magnetic_field']}
)
# Initially, the q axes in this data don't exactly match
to_interpolate_on = 'Q'
Expand All @@ -64,6 +64,6 @@ def test_trend_q_axis_match():
data = ascii_reader.load_data(params)
trend = Trend(
data=data,
trend_axis=['magnetic', 'counting_index']
trend_axes={'counting_index': ['magnetic', 'counting_index']}
)
assert trend.all_axis_match('Q')
Loading