diff --git a/.changelog/_unreleased.toml b/.changelog/_unreleased.toml index e9fc08d..f22e064 100644 --- a/.changelog/_unreleased.toml +++ b/.changelog/_unreleased.toml @@ -1,3 +1,9 @@ +[[entries]] +id = "693674ea-b2b2-4733-bce6-4d5bae59b164" +type = "fix" +description = "Fix #66: dataclasses inheriting from uninstantiated Generic did not get all their fields serialized" +author = "@rhaps0dy" + [[entries]] id = "ab7f2766-e6f5-4236-946d-bddedcd73433" type = "fix" diff --git a/databind/src/databind/core/schema.py b/databind/src/databind/core/schema.py index dafd295..1128edd 100644 --- a/databind/src/databind/core/schema.py +++ b/databind/src/databind/core/schema.py @@ -199,6 +199,7 @@ class A(Generic[T]): # Collect the members from the dataclass and its base classes. queue = [hint] + seen: t.Set[type] = set() fields: t.Dict[str, Field] = {} while queue: hint = queue.pop(0) @@ -244,11 +245,16 @@ class A(Generic[T]): # are overwritten by other fields. pass - # Continue with the base classes. - for base in hint.bases or hint.type.__bases__: + # Continue with the base classes. We iterate both hint.bases (which provides + # parameterized generic type info from __orig_bases__) and hint.type.__bases__ + # (the actual base classes). This is necessary because when a class inherits from + # a Generic without parameterizing it (e.g. `class B(A)` where A is Generic[T]), + # hint.bases only contains Generic[T] and misses the actual parent dataclass A. + for base in (*hint.bases, *hint.type.__bases__): base_hint = TypeHint(base, source=hint.type).evaluate().parameterize(parameter_map) assert isinstance(base_hint, ClassTypeHint), f"nani? {base_hint}" - if dataclasses.is_dataclass(base_hint.type): + if dataclasses.is_dataclass(base_hint.type) and base_hint.type not in seen: + seen.add(base_hint.type) queue.append(base_hint) return Schema(fields, t.cast("Constructor", dataclass_type), dataclass_type) diff --git a/databind/src/databind/core/tests/schema_test.py b/databind/src/databind/core/tests/schema_test.py index 8387311..406d759 100644 --- a/databind/src/databind/core/tests/schema_test.py +++ b/databind/src/databind/core/tests/schema_test.py @@ -456,3 +456,27 @@ def test_parse_dataclass_with_forward_ref() -> None: ClassWithForwardRef, ClassWithForwardRef, ) + + +UnboundTypeVar = t.TypeVar("UnboundTypeVar") + + +@dataclasses.dataclass +class GenericClass(t.Generic[UnboundTypeVar]): + a_field: int + + +@dataclasses.dataclass +class InheritGeneric(GenericClass): # type: ignore[type-arg] + b_field: str + + +def test_schema_generic_dataclass() -> None: + """Regression test for #66: dataclasses inheriting from Generic with an uninstantiated TypeVar don't get their + parents' fields. + """ + assert convert_dataclass_to_schema(InheritGeneric) == Schema( + {"a_field": Field(TypeHint(int), True), "b_field": Field(TypeHint(str), True)}, + InheritGeneric, + InheritGeneric, + ) diff --git a/databind/src/databind/json/tests/converters_test.py b/databind/src/databind/json/tests/converters_test.py index fe14e4d..4730eeb 100644 --- a/databind/src/databind/json/tests/converters_test.py +++ b/databind/src/databind/json/tests/converters_test.py @@ -719,6 +719,34 @@ def of(cls, v: str) -> "MyCls": assert mapper.deserialize("MyCls", MyCls) == MyCls() +UnboundTypeVar = t.TypeVar("UnboundTypeVar") + + +@dataclasses.dataclass +class GenericClass(t.Generic[UnboundTypeVar]): + a_field: int + + +@dataclasses.dataclass +class InheritGeneric(GenericClass): # type: ignore[type-arg] + b_field: str + + +@pytest.mark.parametrize("direction", (Direction.SERIALIZE, Direction.DESERIALIZE)) +def test_convert_generic_dataclass(direction: Direction) -> None: + """Regression test for #66: dataclasses inheriting from Generic with an uninstantiated TypeVar don't get their + parents' fields. + """ + mapper = make_mapper([SchemaConverter(), PlainDatatypeConverter()]) + + if direction == Direction.SERIALIZE: + obj = InheritGeneric(2, "hi") + assert mapper.convert(direction, obj, InheritGeneric) == {"a_field": obj.a_field, "b_field": obj.b_field} + else: + obj = InheritGeneric(4, "something") + assert mapper.convert(direction, {"a_field": obj.a_field, "b_field": obj.b_field}, InheritGeneric) == obj + + def test_extra_keys_on_subclass_creates_own_settings() -> None: """Regression test: ExtraKeys() applied to a subclass must create its own __databind_settings__, not append to the parent's list via MRO traversal."""