Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions .changelog/_unreleased.toml
Original file line number Diff line number Diff line change
Expand Up @@ -9,3 +9,9 @@ id = "ab7f2766-e6f5-4236-946d-bddedcd73433"
type = "fix"
description = "Fix ClassDecoratorSetting and get_class_settings to correctly handle __databind_settings__ on subclasses"
author = "@NiklasRosenstein"

[[entries]]
id = "6d0f41f2-f7f9-4808-af65-196a7a909b4f"
type = "fix"
description = "Fix #47: Union with Literal in them can now de/serialize"
author = "@rhaps0dy"
35 changes: 34 additions & 1 deletion databind/src/databind/core/union.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import types
import typing as t

from typeapi import ClassTypeHint, TypeHint
from typeapi import ClassTypeHint, LiteralTypeHint, TypeHint

from databind.core.utils import T

Expand Down Expand Up @@ -47,6 +47,19 @@ def get_type_by_id(self, type_id: str) -> t.Any:
ValueError: If the *type_id* is not an ID among the union members.
"""

def get_type_id_for_value(self, value: t.Any) -> str:
"""Given a Python value, return the ID of the type among the union members.

This method allows matching values against Literal type members, which cannot be
resolved by type alone. The default implementation falls back to `get_type_id(type(value))`.

Arguments:
value: The Python value to retrieve the type ID for.
Raises:
ValueError: If no member matches the *value*.
"""
return self.get_type_id(type(value))

@abc.abstractmethod
def get_type_ids(self) -> t.List[str]:
"""
Expand Down Expand Up @@ -101,6 +114,17 @@ def get_type_id(self, type_: t.Any) -> str:
return type_id
raise ValueError(f"type {type_} is not a member of {self}")

def get_type_id_for_value(self, value: t.Any) -> str:
# Check LiteralTypeHint members first (more specific — match value in reference_type.values).
# Members may be stored as raw typing annotations or as TypeHint instances, so wrap if needed.
for type_id in self.members:
reference_type = self.get_type_by_id(type_id)
hint = reference_type if isinstance(reference_type, TypeHint) else TypeHint(reference_type)
if isinstance(hint, LiteralTypeHint) and value in hint.values:
return type_id
# Then fall back to class/type members (existing type-based matching)
return self.get_type_id(type(value))

def get_type_by_id(self, type_id: str) -> t.Any:
try:
return self._eval_cache[type_id]
Expand Down Expand Up @@ -226,6 +250,15 @@ def get_type_id(self, type_: t.Any) -> str:
errors.append(exc)
raise ValueError(f"{type_!r} is not a member of {self}\n" + "- \n".join(map(str, errors)))

def get_type_id_for_value(self, value: t.Any) -> str:
errors = []
for delegate in self.delegates:
try:
return delegate.get_type_id_for_value(value)
except ValueError as exc:
errors.append(exc)
raise ValueError(f"{value!r} is not a member of {self}\n" + "- \n".join(map(str, errors)))

def get_type_by_id(self, type_id: str) -> t.Any:
errors = []
for delegate in self.delegates:
Expand Down
22 changes: 16 additions & 6 deletions databind/src/databind/json/converters.py
Original file line number Diff line number Diff line change
Expand Up @@ -764,14 +764,21 @@ def _check_style_compatibility(self, ctx: Context, style: str, value: t.Any) ->
def convert(self, ctx: Context) -> t.Any:
datatype = ctx.datatype
union: t.Optional[Union]

if isinstance(datatype, UnionTypeHint):
if datatype.has_none_type():
raise NotImplementedError("unable to handle Union type with None in it")
if not all(isinstance(a, ClassTypeHint) for a in datatype):
raise NotImplementedError(f"members of plain Union must be concrete types: {datatype}")
members = {t.cast(ClassTypeHint, a).type.__name__: a for a in datatype}
if len(members) != len(datatype):

literal_types = [a for a in datatype if isinstance(a, LiteralTypeHint)]
non_literal_types = [a for a in datatype if not isinstance(a, LiteralTypeHint)]
if not all(isinstance(a, ClassTypeHint) for a in non_literal_types):
raise NotImplementedError(f"members of plain Union must be concrete or Literal types: {datatype}")

members: t.Dict[str, t.Any] = {t.cast(ClassTypeHint, a).type.__name__: a for a in non_literal_types}
if len(members) != len(non_literal_types):
raise NotImplementedError(f"members of plain Union cannot have overlapping type names: {datatype}")
for lit in literal_types:
members[type_repr(lit.hint)] = lit
union = Union(members, Union.BEST_MATCH)
elif isinstance(datatype, (AnnotatedTypeHint, ClassTypeHint)):
union = ctx.get_setting(Union)
Expand Down Expand Up @@ -807,8 +814,11 @@ def convert(self, ctx: Context) -> t.Any:
member_type = union.members.get_type_by_id(member_name)

else:
# Identify the member type based on the Python value type.
member_name = union.members.get_type_id(type(ctx.value))
# Identify the member type based on the Python value.
try:
member_name = union.members.get_type_id_for_value(ctx.value)
except ValueError as exc:
raise ConversionError(self, ctx, str(exc))
member_type = union.members.get_type_by_id(member_name)

nesting_key = union.nesting_key or member_name
Expand Down
55 changes: 55 additions & 0 deletions databind/src/databind/json/tests/converters_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
DatetimeConverter,
DecimalConverter,
EnumConverter,
LiteralConverter,
MappingConverter,
OptionalConverter,
PlainDatatypeConverter,
Expand Down Expand Up @@ -332,6 +333,17 @@ def test_union_converter_best_match(direction: Direction) -> None:
assert mapper.convert(direction, 42, t.Union[int, str]) == 42


@pytest.mark.parametrize("direction", (Direction.SERIALIZE, Direction.DESERIALIZE))
def test_union_converter_best_match_literal(direction: Direction) -> None:
mapper = make_mapper([UnionConverter(), PlainDatatypeConverter(), LiteralConverter()])

LiteralUnionType = t.Union[int, t.Literal["hi"], t.Literal["bye"]]

assert mapper.convert(direction, 42, LiteralUnionType) == 42
assert mapper.convert(direction, "hi", LiteralUnionType) == "hi"
assert mapper.convert(direction, "bye", LiteralUnionType) == "bye"


@pytest.mark.parametrize("direction", (Direction.SERIALIZE, Direction.DESERIALIZE))
def test_union_converter_keyed(direction: Direction) -> None:
mapper = make_mapper([UnionConverter(), PlainDatatypeConverter()])
Expand All @@ -343,6 +355,30 @@ def test_union_converter_keyed(direction: Direction) -> None:
assert mapper.convert(direction, 42, th) == {"int": 42}


@pytest.mark.parametrize("direction", (Direction.SERIALIZE, Direction.DESERIALIZE))
def test_union_converter_keyed_literal(direction: Direction) -> None:
mapper = make_mapper([UnionConverter(), PlainDatatypeConverter(), LiteralConverter()])

th = te.Annotated[
t.Union[int, t.Literal["hi"], t.Literal["bye"]],
Union({"int": int, "HiType": t.Literal["hi"], "ByeType": t.Literal["bye"]}, style=Union.KEYED),
]
if direction == Direction.DESERIALIZE:
assert mapper.convert(direction, {"int": 42}, th) == 42
assert mapper.convert(direction, {"HiType": "hi"}, th) == "hi"
assert mapper.convert(direction, {"ByeType": "bye"}, th) == "bye"

with pytest.raises(ConversionError):
mapper.convert(direction, {"ByeType": "hi"}, th)
else:
assert mapper.convert(direction, 42, th) == {"int": 42}
assert mapper.convert(direction, "hi", th) == {"HiType": "hi"}
assert mapper.convert(direction, "bye", th) == {"ByeType": "bye"}

with pytest.raises(ConversionError):
mapper.convert(direction, {"ByeType": "hi"}, th)


@pytest.mark.parametrize("direction", (Direction.SERIALIZE, Direction.DESERIALIZE))
def test_union_converter_flat_plain_types_not_supported(direction: Direction) -> None:
mapper = make_mapper([UnionConverter(), PlainDatatypeConverter()])
Expand Down Expand Up @@ -821,3 +857,22 @@ class ChildOverriding(Parent):
with pytest.raises(ConversionError) as excinfo:
mapper.deserialize({"a": 1, "b": "hello", "extra": "ignored"}, ChildOverriding)
assert "extra" in str(excinfo.value)


def test_union_literal() -> None:
mapper = make_mapper([UnionConverter(), PlainDatatypeConverter(), LiteralConverter()])

IntType = t.Union[int, t.Literal["hi", "bye"]]
StrType = t.Union[str, t.Literal["hi", "bye"]]

assert mapper.serialize("hi", IntType) == "hi"
assert mapper.serialize(2, IntType) == 2

assert mapper.serialize("bye", StrType) == "bye"
assert mapper.serialize("other", StrType) == "other"

assert mapper.deserialize("hi", IntType) == "hi"
assert mapper.deserialize(2, IntType) == 2

assert mapper.deserialize("bye", StrType) == "bye"
assert mapper.deserialize("other", StrType) == "other"
Loading