diff --git a/mindee/parsing/v2/field/object_field.py b/mindee/parsing/v2/field/object_field.py index 2e002194..564dc8b3 100644 --- a/mindee/parsing/v2/field/object_field.py +++ b/mindee/parsing/v2/field/object_field.py @@ -1,8 +1,13 @@ +from typing import TYPE_CHECKING, Dict, cast from mindee.parsing.common.string_dict import StringDict from mindee.parsing.v2.field.base_field import BaseField from mindee.parsing.v2.field.dynamic_field import FieldType from mindee.parsing.v2.field.inference_fields import InferenceFields +if TYPE_CHECKING: + from mindee.parsing.v2.field.list_field import ListField + from mindee.parsing.v2.field.simple_field import SimpleField + class ObjectField(BaseField): """Object field containing multiple fields.""" @@ -37,5 +42,86 @@ def multi_str(self) -> str: first = False return out_str + @property + def simple_fields(self) -> Dict[str, "SimpleField"]: + """ + Extract and return all SimpleField fields from the `fields` attribute. + + :return: A dictionary containing all fields that have a type of `FieldType.SIMPLE`. + """ + simple_fields = {} + for field_key, field_value in self.fields.items(): + if field_value.field_type == FieldType.SIMPLE: + simple_fields[field_key] = cast("SimpleField", field_value) + return simple_fields + + @property + def list_fields(self) -> Dict[str, "ListField"]: + """ + Retrieves all ListField fields from the `fields` attribute. + + :return: A dictionary containing all fields of type `LIST`, with keys + representing field keys and values being the corresponding field + objects. + """ + list_fields = {} + for field_key, field_value in self.fields.items(): + if field_value.field_type == FieldType.LIST: + list_fields[field_key] = cast("ListField", field_value) + return list_fields + + @property + def object_fields(self) -> Dict[str, "ObjectField"]: + """ + Retrieves all ObjectField fields from the `fields` attribute of the instance. + + :returns: A dictionary containing fields of type `FieldType.OBJECT`. The keys represent + the field names, and the values are corresponding ObjectField objects. + """ + object_fields = {} + for field_key, field_value in self.fields.items(): + if field_value.field_type == FieldType.OBJECT: + object_fields[field_key] = cast("ObjectField", field_value) + return object_fields + + def get_simple_field(self, field_name: str) -> "SimpleField": + """ + Retrieves a SimpleField from the provided field name. + + :param field_name: The name of the field to retrieve. + :type field_name: str + :return: The SimpleField object corresponding to the given field name. + :raises ValueError: If the specified field is not of type SimpleField. + """ + if self.fields[field_name].field_type != FieldType.SIMPLE: + raise ValueError(f"Field {field_name} is not a SimpleField.") + return cast("SimpleField", self.fields[field_name]) + + def get_list_field(self, field_name: str) -> "ListField": + """ + Retrieves the ``ListField`` for the specified field name. + + :param field_name: The name of the field to retrieve. + :type field_name: str + :return: The corresponding ``ListField`` for the given field name. + :raises ValueError: If the field is not of type ``ListField``. + """ + if self.fields[field_name].field_type != FieldType.LIST: + raise ValueError(f"Field {field_name} is not a ListField.") + return cast("ListField", self.fields[field_name]) + + def get_object_field(self, field_name: str) -> "ObjectField": + """ + Retrieves the `ObjectField` associated with the specified field name. + + :param field_name: The name of the field to retrieve. + :type field_name: str + :return: The `ObjectField` associated with the given field name. + :raises ValueError: If the field specified by `field_name` is not an `ObjectField`. + """ + if self.fields[field_name].field_type != FieldType.OBJECT: + raise ValueError(f"Field {field_name} is not an ObjectField.") + return cast("ObjectField", self.fields[field_name]) + def __str__(self) -> str: return self.single_str() diff --git a/tests/utils.py b/tests/utils.py index 058e3595..252a699c 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -17,7 +17,6 @@ V2_DATA_DIR = ROOT_DATA_DIR / "v2" V2_PRODUCT_DATA_DIR = V2_DATA_DIR / "products" -V2_UTILITIES_DATA_DIR = V2_DATA_DIR / "utilities" def clear_envvars(monkeypatch) -> None: diff --git a/tests/v2/input/__init__.py b/tests/v2/input/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/v2/product/extraction/test_extraction_response.py b/tests/v2/product/extraction/test_extraction_response.py index e35091f9..912d0846 100644 --- a/tests/v2/product/extraction/test_extraction_response.py +++ b/tests/v2/product/extraction/test_extraction_response.py @@ -6,7 +6,10 @@ from mindee import InferenceResponse from mindee.parsing.v2 import InferenceActiveOptions -from mindee.parsing.v2.field import FieldConfidence, ListField, ObjectField, SimpleField +from mindee.parsing.v2.field.field_confidence import FieldConfidence +from mindee.parsing.v2.field.list_field import ListField +from mindee.parsing.v2.field.object_field import ObjectField +from mindee.parsing.v2.field.simple_field import SimpleField from mindee.parsing.v2.field.inference_fields import InferenceFields from mindee.parsing.v2.inference import Inference from mindee.parsing.v2.inference_file import InferenceFile @@ -53,34 +56,46 @@ def test_deep_nested_fields(): response.inference.result.fields["field_object"].fields["sub_object_object"], ObjectField, ) + fields = response.inference.result.fields + assert isinstance(fields.get("field_object"), ObjectField) assert isinstance( - response.inference.result.fields["field_object"] - .fields["sub_object_object"] - .fields, + fields.get("field_object").get_simple_field("sub_object_simple"), SimpleField + ) + assert isinstance( + fields.get("field_object").get_list_field("sub_object_list"), ListField + ) + assert isinstance( + fields.get("field_object").get_object_field("sub_object_object"), ObjectField + ) + assert len(fields.get("field_object").simple_fields) == 1 + assert len(fields.get("field_object").list_fields) == 1 + assert len(fields.get("field_object").object_fields) == 1 + assert isinstance( + fields["field_object"].fields["sub_object_object"].fields, dict, ) assert isinstance( - response.inference.result.fields["field_object"] + fields["field_object"] .fields["sub_object_object"] .fields["sub_object_object_sub_object_list"], ListField, ) assert isinstance( - response.inference.result.fields["field_object"] + fields["field_object"] .fields["sub_object_object"] .fields["sub_object_object_sub_object_list"] .items, list, ) assert isinstance( - response.inference.result.fields["field_object"] + fields["field_object"] .fields["sub_object_object"] .fields["sub_object_object_sub_object_list"] .items[0], ObjectField, ) assert isinstance( - response.inference.result.fields["field_object"] + fields["field_object"] .fields["sub_object_object"] .fields["sub_object_object_sub_object_list"] .items[0] @@ -88,7 +103,7 @@ def test_deep_nested_fields(): SimpleField, ) assert ( - response.inference.result.fields["field_object"] + fields["field_object"] .fields["sub_object_object"] .fields["sub_object_object_sub_object_list"] .items[0] @@ -103,6 +118,7 @@ def test_standard_field_types(): json_sample, rst_sample = _get_inference_samples("standard_field_types") response = InferenceResponse(json_sample) assert isinstance(response.inference, Inference) + field_simple_string = response.inference.result.fields["field_simple_string"] assert isinstance(field_simple_string, SimpleField) assert field_simple_string.value == "field_simple_string-value"