diff --git a/.changelog/_unreleased.toml b/.changelog/_unreleased.toml index f22e064..05e827f 100644 --- a/.changelog/_unreleased.toml +++ b/.changelog/_unreleased.toml @@ -9,3 +9,9 @@ id = "ab7f2766-e6f5-4236-946d-bddedcd73433" type = "fix" description = "Fix ClassDecoratorSetting and get_class_settings to correctly handle __databind_settings__ on subclasses" author = "@NiklasRosenstein" + +[[entries]] +id = "6d0f41f2-f7f9-4808-af65-196a7a909b4f" +type = "fix" +description = "Fix #47: Union with Literal in them can now de/serialize" +author = "@rhaps0dy" diff --git a/databind/src/databind/core/union.py b/databind/src/databind/core/union.py index e39f53e..16853b3 100644 --- a/databind/src/databind/core/union.py +++ b/databind/src/databind/core/union.py @@ -7,7 +7,7 @@ import types import typing as t -from typeapi import ClassTypeHint, TypeHint +from typeapi import ClassTypeHint, LiteralTypeHint, TypeHint from databind.core.utils import T @@ -47,6 +47,19 @@ def get_type_by_id(self, type_id: str) -> t.Any: ValueError: If the *type_id* is not an ID among the union members. """ + def get_type_id_for_value(self, value: t.Any) -> str: + """Given a Python value, return the ID of the type among the union members. + + This method allows matching values against Literal type members, which cannot be + resolved by type alone. The default implementation falls back to `get_type_id(type(value))`. + + Arguments: + value: The Python value to retrieve the type ID for. + Raises: + ValueError: If no member matches the *value*. + """ + return self.get_type_id(type(value)) + @abc.abstractmethod def get_type_ids(self) -> t.List[str]: """ @@ -101,6 +114,17 @@ def get_type_id(self, type_: t.Any) -> str: return type_id raise ValueError(f"type {type_} is not a member of {self}") + def get_type_id_for_value(self, value: t.Any) -> str: + # Check LiteralTypeHint members first (more specific — match value in reference_type.values). + # Members may be stored as raw typing annotations or as TypeHint instances, so wrap if needed. + for type_id in self.members: + reference_type = self.get_type_by_id(type_id) + hint = reference_type if isinstance(reference_type, TypeHint) else TypeHint(reference_type) + if isinstance(hint, LiteralTypeHint) and value in hint.values: + return type_id + # Then fall back to class/type members (existing type-based matching) + return self.get_type_id(type(value)) + def get_type_by_id(self, type_id: str) -> t.Any: try: return self._eval_cache[type_id] @@ -226,6 +250,15 @@ def get_type_id(self, type_: t.Any) -> str: errors.append(exc) raise ValueError(f"{type_!r} is not a member of {self}\n" + "- \n".join(map(str, errors))) + def get_type_id_for_value(self, value: t.Any) -> str: + errors = [] + for delegate in self.delegates: + try: + return delegate.get_type_id_for_value(value) + except ValueError as exc: + errors.append(exc) + raise ValueError(f"{value!r} is not a member of {self}\n" + "- \n".join(map(str, errors))) + def get_type_by_id(self, type_id: str) -> t.Any: errors = [] for delegate in self.delegates: diff --git a/databind/src/databind/json/converters.py b/databind/src/databind/json/converters.py index 2d4a16f..136c216 100644 --- a/databind/src/databind/json/converters.py +++ b/databind/src/databind/json/converters.py @@ -764,14 +764,21 @@ def _check_style_compatibility(self, ctx: Context, style: str, value: t.Any) -> def convert(self, ctx: Context) -> t.Any: datatype = ctx.datatype union: t.Optional[Union] + if isinstance(datatype, UnionTypeHint): if datatype.has_none_type(): raise NotImplementedError("unable to handle Union type with None in it") - if not all(isinstance(a, ClassTypeHint) for a in datatype): - raise NotImplementedError(f"members of plain Union must be concrete types: {datatype}") - members = {t.cast(ClassTypeHint, a).type.__name__: a for a in datatype} - if len(members) != len(datatype): + + literal_types = [a for a in datatype if isinstance(a, LiteralTypeHint)] + non_literal_types = [a for a in datatype if not isinstance(a, LiteralTypeHint)] + if not all(isinstance(a, ClassTypeHint) for a in non_literal_types): + raise NotImplementedError(f"members of plain Union must be concrete or Literal types: {datatype}") + + members: t.Dict[str, t.Any] = {t.cast(ClassTypeHint, a).type.__name__: a for a in non_literal_types} + if len(members) != len(non_literal_types): raise NotImplementedError(f"members of plain Union cannot have overlapping type names: {datatype}") + for lit in literal_types: + members[type_repr(lit.hint)] = lit union = Union(members, Union.BEST_MATCH) elif isinstance(datatype, (AnnotatedTypeHint, ClassTypeHint)): union = ctx.get_setting(Union) @@ -807,8 +814,11 @@ def convert(self, ctx: Context) -> t.Any: member_type = union.members.get_type_by_id(member_name) else: - # Identify the member type based on the Python value type. - member_name = union.members.get_type_id(type(ctx.value)) + # Identify the member type based on the Python value. + try: + member_name = union.members.get_type_id_for_value(ctx.value) + except ValueError as exc: + raise ConversionError(self, ctx, str(exc)) member_type = union.members.get_type_by_id(member_name) nesting_key = union.nesting_key or member_name diff --git a/databind/src/databind/json/tests/converters_test.py b/databind/src/databind/json/tests/converters_test.py index 4730eeb..8d17d70 100644 --- a/databind/src/databind/json/tests/converters_test.py +++ b/databind/src/databind/json/tests/converters_test.py @@ -30,6 +30,7 @@ DatetimeConverter, DecimalConverter, EnumConverter, + LiteralConverter, MappingConverter, OptionalConverter, PlainDatatypeConverter, @@ -332,6 +333,17 @@ def test_union_converter_best_match(direction: Direction) -> None: assert mapper.convert(direction, 42, t.Union[int, str]) == 42 +@pytest.mark.parametrize("direction", (Direction.SERIALIZE, Direction.DESERIALIZE)) +def test_union_converter_best_match_literal(direction: Direction) -> None: + mapper = make_mapper([UnionConverter(), PlainDatatypeConverter(), LiteralConverter()]) + + LiteralUnionType = t.Union[int, t.Literal["hi"], t.Literal["bye"]] + + assert mapper.convert(direction, 42, LiteralUnionType) == 42 + assert mapper.convert(direction, "hi", LiteralUnionType) == "hi" + assert mapper.convert(direction, "bye", LiteralUnionType) == "bye" + + @pytest.mark.parametrize("direction", (Direction.SERIALIZE, Direction.DESERIALIZE)) def test_union_converter_keyed(direction: Direction) -> None: mapper = make_mapper([UnionConverter(), PlainDatatypeConverter()]) @@ -343,6 +355,30 @@ def test_union_converter_keyed(direction: Direction) -> None: assert mapper.convert(direction, 42, th) == {"int": 42} +@pytest.mark.parametrize("direction", (Direction.SERIALIZE, Direction.DESERIALIZE)) +def test_union_converter_keyed_literal(direction: Direction) -> None: + mapper = make_mapper([UnionConverter(), PlainDatatypeConverter(), LiteralConverter()]) + + th = te.Annotated[ + t.Union[int, t.Literal["hi"], t.Literal["bye"]], + Union({"int": int, "HiType": t.Literal["hi"], "ByeType": t.Literal["bye"]}, style=Union.KEYED), + ] + if direction == Direction.DESERIALIZE: + assert mapper.convert(direction, {"int": 42}, th) == 42 + assert mapper.convert(direction, {"HiType": "hi"}, th) == "hi" + assert mapper.convert(direction, {"ByeType": "bye"}, th) == "bye" + + with pytest.raises(ConversionError): + mapper.convert(direction, {"ByeType": "hi"}, th) + else: + assert mapper.convert(direction, 42, th) == {"int": 42} + assert mapper.convert(direction, "hi", th) == {"HiType": "hi"} + assert mapper.convert(direction, "bye", th) == {"ByeType": "bye"} + + with pytest.raises(ConversionError): + mapper.convert(direction, {"ByeType": "hi"}, th) + + @pytest.mark.parametrize("direction", (Direction.SERIALIZE, Direction.DESERIALIZE)) def test_union_converter_flat_plain_types_not_supported(direction: Direction) -> None: mapper = make_mapper([UnionConverter(), PlainDatatypeConverter()]) @@ -821,3 +857,22 @@ class ChildOverriding(Parent): with pytest.raises(ConversionError) as excinfo: mapper.deserialize({"a": 1, "b": "hello", "extra": "ignored"}, ChildOverriding) assert "extra" in str(excinfo.value) + + +def test_union_literal() -> None: + mapper = make_mapper([UnionConverter(), PlainDatatypeConverter(), LiteralConverter()]) + + IntType = t.Union[int, t.Literal["hi", "bye"]] + StrType = t.Union[str, t.Literal["hi", "bye"]] + + assert mapper.serialize("hi", IntType) == "hi" + assert mapper.serialize(2, IntType) == 2 + + assert mapper.serialize("bye", StrType) == "bye" + assert mapper.serialize("other", StrType) == "other" + + assert mapper.deserialize("hi", IntType) == "hi" + assert mapper.deserialize(2, IntType) == 2 + + assert mapper.deserialize("bye", StrType) == "bye" + assert mapper.deserialize("other", StrType) == "other"