Skip to content

Commit

Permalink
Fix get_type_hints for _: dataclasses.KW_ONLY
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 526934911
  • Loading branch information
Conchylicultor authored and The etils Authors committed Apr 25, 2023
1 parent 96a5309 commit 0249a70
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 1 deletion.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ Changelog follow https://keepachangelog.com/ format.
* `edc`:
* Add `contextvars` option: Fields annotated as `edc.ContextVars[T]` are
wrapped in `contextvars.ContextVars`.
* Fix error when using `_: dataclasses.KW_ONLY`
* `epy`:
* Better `epy.Lines.block` for custom pretty print classes, list,...

Expand Down
24 changes: 23 additions & 1 deletion etils/edc/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,7 @@ def _get_type_hints(cls, *, include_extras: bool = False) -> _Hints:
"""`get_type_hints` with better error reporting."""
# At this point, `ForwardRef` should have been resolved.
try:
return typing_extensions.get_type_hints(cls, include_extras=include_extras)
return _get_type_hints_fix(cls, include_extras=include_extras)
except Exception as e: # pylint: disable=broad-except
msg = (
f'Could not infer typing annotation of {cls.__qualname__} '
Expand All @@ -129,3 +129,25 @@ def _get_type_hints(cls, *, include_extras: bool = False) -> _Hints:
lines = '\n'.join(lines)

epy.reraise(e, prefix=msg + lines + '\n')


def _get_type_hints_fix(cls, *, include_extras: bool = False) -> _Hints:
"""`get_type_hints` with bug fixes."""
# TODO(py311): `get_type_hints` fail for `_: dataclasses.KW_ONLY`
old_annotations = [_fix_annotations(subcls) for subcls in cls.mro()]
try:
return typing_extensions.get_type_hints(cls, include_extras=include_extras)
finally:
# Restore the annotations
for subcls, annotations in zip(cls.mro(), old_annotations):
if annotations:
subcls.__annotations__ = annotations


def _fix_annotations(cls):
"""Remove the `_: dataclasses.KW_ONLY` annotation."""
if cls is object or '_' not in getattr(cls, '__annotations__', {}):
return
old_annotations = dict(cls.__annotations__)
cls.__annotations__.pop('_')
return old_annotations

0 comments on commit 0249a70

Please sign in to comment.