-
Notifications
You must be signed in to change notification settings - Fork 4
Refactor 24 trend #198
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: refactor_24
Are you sure you want to change the base?
Refactor 24 trend #198
Changes from all commits
d2bd866
4322834
20c9131
6112d6a
967085a
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -458,14 +458,19 @@ def to_string(self, header=""): | |
| ) | ||
| else: | ||
| attributes = "" | ||
| if self.contents: | ||
| if type(self.contents) is str: | ||
| match self.contents: | ||
| case str(): | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. The string, Quantity and ndarray cases are exactly the same as the default case. Maybe just remove all but the default? |
||
| children = f"\n{header} {self.contents}" | ||
| else: | ||
| case list() | tuple(): | ||
| children = "".join([n.to_string(header + " ") for n in self.contents]) | ||
| else: | ||
| children = "" | ||
|
|
||
| case Quantity(): | ||
| children = f"\n{header} {self.contents}" | ||
| case ndarray(): | ||
| children = f"\n{header} {self.contents}" | ||
| case None: | ||
| children = "" | ||
| case _: | ||
| children = f"\n{header} {self.contents}" | ||
| return f"\n{header}{self.name}:{attributes}{children}" | ||
|
|
||
| def filter(self, name: str) -> list[ndarray | Quantity | str]: | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,48 +1,156 @@ | ||
|
|
||
| from dataclasses import dataclass | ||
|
|
||
| import numpy as np | ||
|
|
||
| from sasdata.data import SasData | ||
| from sasdata.data_backing import Dataset, Group | ||
| from sasdata.quantities.quantity import Quantity | ||
| from sasdata.transforms.rebinning import calculate_interpolation_matrix_1d | ||
|
|
||
| # Axis strs refer to the name of their associated NamedQuantity. | ||
|
|
||
| # TODO: This probably shouldn't be here but will keep it here for now. | ||
| # TODO: This probably shouldn't be here but will keep it here for now. --> In sasdata/data.py? | ||
| # TODO: Similarity/relation to __getitem__ in SasData class? | ||
| # TODO: Or a method of Metadata class? | ||
| # TODO: Not sure how to type hint the return. | ||
| def get_metadatum_from_path(data: SasData, metadata_path: list[str]): | ||
| current_group = data._raw_metadata | ||
| current_node = data.metadata.raw | ||
| for path_item in metadata_path: | ||
| current_item = current_group.children.get(path_item, None) | ||
| if current_item is None or (isinstance(current_item, Dataset) and path_item != metadata_path[-1]): | ||
| raise ValueError('Path does not lead to valid a metadatum.') | ||
| elif isinstance(current_item, Group): | ||
| current_group = current_item | ||
| if isinstance(current_node.contents, list): | ||
| # Search through list of MetaNodes | ||
| current_item = None | ||
| for node in current_node.contents: | ||
| if node.name == path_item: | ||
| current_item = node | ||
| break | ||
| else: | ||
| return current_item.data | ||
| raise ValueError('End of path without finding a dataset.') | ||
| # Not a list, can't navigate further | ||
| raise ValueError('Path does not lead to a valid metadatum.') | ||
|
|
||
| if current_item is None: | ||
| raise ValueError('Path does not lead to a valid metadatum.') | ||
|
Comment on lines
26
to
+31
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Duplicate errors. Combine? |
||
|
|
||
| # Check if we're at the end of the path | ||
| if path_item == metadata_path[-1]: | ||
| return current_item.contents | ||
|
|
||
| current_node = current_item | ||
| raise ValueError('End of path without finding a dataset.') | ||
|
|
||
| @dataclass | ||
| class Trend: | ||
| data: list[SasData] | ||
| # This is going to be a path to a specific metadatum. | ||
| # | ||
| # TODO: But what if the trend axis will be a particular NamedQuantity? Will probably need to think on this. | ||
| trend_axis: list[str] | ||
|
|
||
| # Designed to take in a particular value of the trend axis, and return the SasData object that matches it. | ||
| # TODO: Not exactly sure what item's type will be. It could depend on where it is pointing to. | ||
| def __getitem__(self, item) -> SasData: | ||
| for datum in self.data: | ||
| metadatum = get_metadatum_from_path(datum, self.trend_axis) | ||
| if metadatum == item: | ||
| return datum | ||
| raise KeyError() | ||
| trend_axes: dict[str, list[str] | list] # Path or manual values | ||
|
|
||
| def __post_init__(self): | ||
|
|
||
| # First, filter out invalid data items | ||
| self._filter_and_validate_data() | ||
|
|
||
| # Validate data length matches manual value lists | ||
| self._validate_manual_values() | ||
|
|
||
| # Validate metadata paths | ||
| self._validate_metadata_paths() | ||
|
|
||
| def _filter_and_validate_data(self): | ||
| """Filter out non-SasData objects and validate data integrity""" | ||
| valid_data = [] | ||
| invalid_indices = [] | ||
|
|
||
| for i, datum in enumerate(self.data): | ||
| if not isinstance(datum, SasData): | ||
| invalid_indices.append(i) | ||
| continue | ||
|
|
||
| # Check if datum has metadata | ||
| if not hasattr(datum, 'metadata') or datum.metadata is None: | ||
| invalid_indices.append(i) | ||
| continue | ||
|
|
||
| # Check if datum has raw metadata | ||
| if not hasattr(datum.metadata, 'raw') or datum.metadata.raw is None: | ||
| invalid_indices.append(i) | ||
| continue | ||
|
|
||
| valid_data.append(datum) | ||
|
|
||
| # Update data with only valid items | ||
| self.data = valid_data | ||
|
|
||
| # Warn about filtered items | ||
| if invalid_indices: | ||
| print(f"Warning: Removed data items at indices {invalid_indices} - not SasData objects or missing/invalid metadata") | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. The |
||
|
|
||
| # Additional validation | ||
| if not self.data: | ||
| raise ValueError("No valid data items remain after filtering") | ||
|
|
||
| if len(self.data) < 2: | ||
| print(f"Warning: Only {len(self.data)} valid data items remain") | ||
|
|
||
| def _validate_manual_values(self): | ||
| """Ensure manual value lists match data length""" | ||
| for axis_name, axis_config in self.trend_axes.items(): | ||
| if isinstance(axis_config, list) and not isinstance(axis_config[0], str): | ||
| # This is a manual value list (not a path) | ||
| if len(axis_config) != len(self.data): | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Unnecessary nesting |
||
| raise ValueError(f"Manual values for axis '{axis_name}' must have same length as data ({len(self.data)} items, got {len(axis_config)})") | ||
|
|
||
| def _validate_metadata_paths(self): | ||
| """Validate metadata paths""" | ||
| for axis_name, axis_config in self.trend_axes.items(): | ||
| if isinstance(axis_config, list) and len(axis_config) > 0 and isinstance(axis_config[0], str): | ||
| # This is a metadata path | ||
| for i, datum in enumerate(self.data): | ||
| try: | ||
| get_metadatum_from_path(datum, axis_config) | ||
| except ValueError as e: | ||
| raise ValueError(f"trend_axes['{axis_name}'] path {axis_config} invalid for data item {i}: {e}") | ||
|
|
||
| def get_trend_values(self, axis_name: str) -> list: | ||
| """Get values for a named trend axis""" | ||
| if axis_name not in self.trend_axes: | ||
| raise KeyError(f"Axis '{axis_name}' not found") | ||
|
|
||
| axis_config = self.trend_axes[axis_name] | ||
|
|
||
| if isinstance(axis_config, list) and len(axis_config) > 0 and isinstance(axis_config[0], str): | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. This check is already performed when creating the trend. Is this necessary? |
||
| # Metadata path - extract from data | ||
| return [get_metadatum_from_path(datum, axis_config) for datum in self.data] | ||
| else: | ||
| # Manual values - return as-is | ||
| return axis_config.copy() # Return copy to prevent modification | ||
|
|
||
| def add_manual_axis(self, axis_name: str, values: list): | ||
| """Add a new manual trend axis""" | ||
| if len(values) != len(self.data): | ||
| raise ValueError(f"Manual values must have same length as data ({len(self.data)} items, got {len(values)})") | ||
|
|
||
| self.trend_axes[axis_name] = values.copy() | ||
|
|
||
| def add_metadata_axis(self, axis_name: str, path: list[str]): | ||
| """Add a new metadata trend axis""" | ||
| # Validate the path first | ||
| for i, datum in enumerate(self.data): | ||
| try: | ||
| get_metadatum_from_path(datum, path) | ||
| except ValueError as e: | ||
| raise ValueError(f"Path {path} invalid for data item {i}: {e}") | ||
|
|
||
| self.trend_axes[axis_name] = path | ||
|
|
||
| @property | ||
| def trend_axes(self) -> list[float]: | ||
| return [get_metadatum_from_path(datum, self.trend_axis) for datum in self.data] | ||
| def axis_names(self) -> list[str]: | ||
| return list(self.trend_axes.keys()) | ||
|
|
||
| def is_manual_axis(self, axis_name: str) -> bool: | ||
| """Check if an axis uses manual values or metadata path""" | ||
| if axis_name not in self.trend_axes: | ||
| raise KeyError(f"Axis '{axis_name}' not found") | ||
|
|
||
| axis_config = self.trend_axes[axis_name] | ||
| return not (isinstance(axis_config, list) and len(axis_config) > 0 and isinstance(axis_config[0], str)) | ||
|
|
||
| # TODO: Assumes there are at least 2 items in data. Is this reasonable to assume? Should there be error handling for | ||
| # situations where this may not be the case? | ||
|
|
@@ -85,5 +193,5 @@ def interpolate(self, axis: str) -> "Trend": | |
| ) | ||
| new_data.append(new_datum) | ||
| new_trend = Trend(new_data, | ||
| self.trend_axis) | ||
| self.trend_axes) | ||
| return new_trend | ||
Uh oh!
There was an error while loading. Please reload this page.