Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add get_overloads() #1140

Merged
merged 5 commits into from
Apr 16, 2022
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next Next commit
Add get_overloads()
  • Loading branch information
JelleZijlstra committed Apr 16, 2022
commit de68224bbe91c6dcfddea7e647c20e625379bb9a
5 changes: 4 additions & 1 deletion typing_extensions/CHANGELOG
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
# Unreleased

- Add `typing.assert_type`. Backport from bpo-46480.
- Add `typing_extensions.get_overloads` and
`typing_extensions.clear_overloads`, and add registry support to
`typing_extensions.overload`. Backport from python/cpython#89263.
- Add `typing_extensions.assert_type`. Backport from bpo-46480.
- Drop support for Python 3.6. Original patch by Adam Turner (@AA-Turner).

# Release 4.1.1 (February 13, 2022)
Expand Down
5 changes: 5 additions & 0 deletions typing_extensions/README.rst
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,8 @@ This module currently contains the following:

- ``assert_never``
- ``assert_type``
- ``clear_overloads``
- ``get_overloads``
- ``LiteralString`` (see PEP 675)
- ``Never``
- ``NotRequired`` (see PEP 655)
Expand Down Expand Up @@ -122,6 +124,9 @@ Certain objects were changed after they were added to ``typing``, and
Python 3.8 and lack support for ``ParamSpecArgs`` and ``ParamSpecKwargs``
in 3.9.
- ``@final`` was changed in Python 3.11 to set the ``.__final__`` attribute.
- ``@overload`` was changed in Python 3.11 to register overload function.
In order to access overloads with ``typing_extensions.get_overloads()``,
you must use ``@typing_extensions.overload``.
JelleZijlstra marked this conversation as resolved.
Show resolved Hide resolved

There are a few types whose interface was modified between different
versions of typing. For example, ``typing.Sequence`` was modified to
Expand Down
72 changes: 71 additions & 1 deletion typing_extensions/src/test_typing_extensions.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,15 @@
import abc
import contextlib
import collections
from collections import defaultdict
import collections.abc
from functools import lru_cache
import inspect
import pickle
import subprocess
import types
from unittest import TestCase, main, skipUnless, skipIf
from unittest.mock import patch
from test import ann_module, ann_module2, ann_module3
import typing
from typing import TypeVar, Optional, Union, Any, AnyStr
Expand All @@ -21,9 +23,10 @@
from typing_extensions import NoReturn, ClassVar, Final, IntVar, Literal, Type, NewType, TypedDict, Self
from typing_extensions import TypeAlias, ParamSpec, Concatenate, ParamSpecArgs, ParamSpecKwargs, TypeGuard
from typing_extensions import Awaitable, AsyncIterator, AsyncContextManager, Required, NotRequired
from typing_extensions import Protocol, runtime, runtime_checkable, Annotated, overload, final, is_typeddict
from typing_extensions import Protocol, runtime, runtime_checkable, Annotated, final, is_typeddict
from typing_extensions import TypeVarTuple, Unpack, dataclass_transform, reveal_type, Never, assert_never, LiteralString
from typing_extensions import assert_type, get_type_hints, get_origin, get_args
from typing_extensions import clear_overloads, get_overloads, overload

# Flags used to mark tests that only apply after a specific
# version of the typing module.
Expand Down Expand Up @@ -403,6 +406,20 @@ def test_no_multiple_subscripts(self):
Literal[1][1]


class MethodHolder:
@classmethod
def clsmethod(cls): ...
@staticmethod
def stmethod(): ...
def method(self): ...


if TYPING_3_11_0:
registry_holder = typing
else:
registry_holder = typing_extensions


class OverloadTests(BaseTestCase):

def test_overload_fails(self):
Expand All @@ -424,6 +441,59 @@ def blah():

blah()

def set_up_overloads(self):
def blah():
pass

overload1 = blah
overload(blah)

def blah():
pass

overload2 = blah
overload(blah)

def blah():
pass

return blah, [overload1, overload2]

# Make sure we don't clear the global overload registry
@patch(f"{registry_holder.__name__}._overload_registry",
defaultdict(lambda: defaultdict(dict)))
def test_overload_registry(self):
registry = getattr(registry_holder, "_overload_registry")
# The registry starts out empty
self.assertEqual(registry, {})

impl, overloads = self.set_up_overloads()
self.assertNotEqual(registry, {})
self.assertEqual(list(get_overloads(impl)), overloads)

def some_other_func(): pass
overload(some_other_func)
other_overload = some_other_func
def some_other_func(): pass
self.assertEqual(list(get_overloads(some_other_func)), [other_overload])

# Make sure that after we clear all overloads, the registry is
# completely empty.
clear_overloads()
self.assertEqual(registry, {})
self.assertEqual(get_overloads(impl), [])

# Querying a function with no overloads shouldn't change the registry.
def the_only_one(): pass
self.assertEqual(get_overloads(the_only_one), [])
self.assertEqual(registry, {})

def test_overload_registry_repeated(self):
for _ in range(2):
impl, overloads = self.set_up_overloads()

self.assertEqual(list(get_overloads(impl)), overloads)


class AssertTypeTests(BaseTestCase):

Expand Down
70 changes: 69 additions & 1 deletion typing_extensions/src/typing_extensions.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import abc
import collections
import collections.abc
import functools
import operator
import sys
import types as _types
Expand Down Expand Up @@ -46,7 +47,9 @@
'Annotated',
'assert_never',
'assert_type',
'clear_overloads',
'dataclass_transform',
'get_overloads',
'final',
'get_args',
'get_origin',
Expand Down Expand Up @@ -249,7 +252,72 @@ def __getitem__(self, parameters):


_overload_dummy = typing._overload_dummy # noqa
overload = typing.overload

if hasattr(typing, "get_overloads"): # 3.11+
overload = typing.overload
get_overloads = typing.get_overloads
clear_overloads = typing.clear_overloads
else:
# {module: {qualname: {firstlineno: func}}}
_overload_registry = collections.defaultdict(
functools.partial(collections.defaultdict, dict)
)


def overload(func):
"""Decorator for overloaded functions/methods.

In a stub file, place two or more stub definitions for the same
function in a row, each decorated with @overload. For example:

@overload
def utf8(value: None) -> None: ...
@overload
def utf8(value: bytes) -> bytes: ...
@overload
def utf8(value: str) -> bytes: ...

In a non-stub file (i.e. a regular .py file), do the same but
follow it with an implementation. The implementation should *not*
be decorated with @overload. For example:

@overload
def utf8(value: None) -> None: ...
@overload
def utf8(value: bytes) -> bytes: ...
@overload
def utf8(value: str) -> bytes: ...
def utf8(value):
# implementation goes here

The overloads for a function can be retrieved at runtime using the
get_overloads() function.
"""
# classmethod and staticmethod
f = getattr(func, "__func__", func)
try:
_overload_registry[f.__module__][f.__qualname__][f.__code__.co_firstlineno] = func
except AttributeError:
# Not a normal function; ignore.
pass
return _overload_dummy


def get_overloads(func):
"""Return all defined overloads for *func* as a sequence."""
# classmethod and staticmethod
f = getattr(func, "__func__", func)
if f.__module__ not in _overload_registry:
return []
mod_dict = _overload_registry[f.__module__]
if f.__qualname__ not in mod_dict:
return []
return list(mod_dict[f.__qualname__].values())


def clear_overloads():
"""Clear all overloads in the registry."""
_overload_registry.clear()


# This is not a real generic class. Don't use outside annotations.
Expand Down