-
-
Notifications
You must be signed in to change notification settings - Fork 1.7k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Co-authored-by: Alex Waygood <[email protected]>
- Loading branch information
1 parent
0a291da
commit ea0ae21
Showing
7 changed files
with
318 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,22 @@ | ||
# Some methods are dynamically patched onto to instances as they | ||
# may depend on whether code is executed in graph/eager/v1/v2/etc. | ||
# Tensorflow supports multiple modes of execution which changes some | ||
# of the attributes/methods/even class hierachies. | ||
tensorflow.Tensor.__int__ | ||
tensorflow.Tensor.numpy | ||
tensorflow.Tensor.__index__ | ||
# Incomplete | ||
tensorflow.sparse.SparseTensor.__getattr__ | ||
tensorflow.SparseTensor.__getattr__ | ||
tensorflow.TensorShape.__getattr__ | ||
tensorflow.dtypes.DType.__getattr__ | ||
tensorflow.RaggedTensor.__getattr__ | ||
tensorflow.DType.__getattr__ | ||
tensorflow.Graph.__getattr__ | ||
tensorflow.Operation.__getattr__ | ||
tensorflow.Variable.__getattr__ | ||
# Internal undocumented API | ||
tensorflow.RaggedTensor.__init__ | ||
# Has an undocumented extra argument that tf.Variable which acts like subclass | ||
# (by dynamically patching tf.Tensor methods) does not preserve. | ||
tensorflow.Tensor.__getitem__ |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
version = "2.10.*" | ||
# requires a version of numpy with a `py.typed` file | ||
requires = ["numpy>=1.20"] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,195 @@ | ||
from _typeshed import Incomplete, Self, Unused | ||
from abc import ABCMeta | ||
from builtins import bool as _bool | ||
from collections.abc import Callable, Iterable, Iterator, Sequence | ||
from contextlib import contextmanager | ||
from enum import Enum | ||
from typing import Any, NoReturn, overload | ||
from typing_extensions import TypeAlias | ||
|
||
import numpy | ||
from tensorflow.dtypes import * | ||
|
||
# Most tf.math functions are exported as tf, but sadly not all are. | ||
from tensorflow.math import abs as abs | ||
from tensorflow.sparse import SparseTensor | ||
|
||
# Tensors ideally should be a generic type, but properly typing data type/shape | ||
# will be a lot of work. Until we have good non-generic tensorflow stubs, | ||
# we will skip making Tensor generic. Also good type hints for shapes will | ||
# run quickly into many places where type system is not strong enough today. | ||
# So shape typing is probably not worth doing anytime soon. | ||
_Slice: TypeAlias = int | slice | None | ||
|
||
_FloatDataSequence: TypeAlias = Sequence[float] | Sequence[_FloatDataSequence] | ||
_StrDataSequence: TypeAlias = Sequence[str] | Sequence[_StrDataSequence] | ||
_ScalarTensorCompatible: TypeAlias = Tensor | str | float | numpy.ndarray[Any, Any] | numpy.number[Any] | ||
_TensorCompatible: TypeAlias = _ScalarTensorCompatible | Sequence[_TensorCompatible] | ||
_ShapeLike: TypeAlias = TensorShape | Iterable[_ScalarTensorCompatible | None] | int | Tensor | ||
_DTypeLike: TypeAlias = DType | str | numpy.dtype[Any] | ||
|
||
class Tensor: | ||
def __init__(self, op: Operation, value_index: int, dtype: DType) -> None: ... | ||
def consumers(self) -> list[Incomplete]: ... | ||
@property | ||
def shape(self) -> TensorShape: ... | ||
def get_shape(self) -> TensorShape: ... | ||
@property | ||
def dtype(self) -> DType: ... | ||
@property | ||
def graph(self) -> Graph: ... | ||
@property | ||
def name(self) -> str: ... | ||
@property | ||
def op(self) -> Operation: ... | ||
def numpy(self) -> numpy.ndarray[Any, Any]: ... | ||
def __int__(self) -> int: ... | ||
def __abs__(self, name: str | None = None) -> Tensor: ... | ||
def __add__(self, other: _TensorCompatible) -> Tensor: ... | ||
def __radd__(self, other: _TensorCompatible) -> Tensor: ... | ||
def __sub__(self, other: _TensorCompatible) -> Tensor: ... | ||
def __rsub__(self, other: _TensorCompatible) -> Tensor: ... | ||
def __mul__(self, other: _TensorCompatible) -> Tensor: ... | ||
def __rmul__(self, other: _TensorCompatible) -> Tensor: ... | ||
def __pow__(self, other: _TensorCompatible) -> Tensor: ... | ||
def __matmul__(self, other: _TensorCompatible) -> Tensor: ... | ||
def __rmatmul__(self, other: _TensorCompatible) -> Tensor: ... | ||
def __floordiv__(self, other: _TensorCompatible) -> Tensor: ... | ||
def __rfloordiv__(self, other: _TensorCompatible) -> Tensor: ... | ||
def __truediv__(self, other: _TensorCompatible) -> Tensor: ... | ||
def __rtruediv__(self, other: _TensorCompatible) -> Tensor: ... | ||
def __neg__(self, name: str | None = None) -> Tensor: ... | ||
def __and__(self, other: _TensorCompatible) -> Tensor: ... | ||
def __rand__(self, other: _TensorCompatible) -> Tensor: ... | ||
def __or__(self, other: _TensorCompatible) -> Tensor: ... | ||
def __ror__(self, other: _TensorCompatible) -> Tensor: ... | ||
def __eq__(self, other: _TensorCompatible) -> Tensor: ... # type: ignore[override] | ||
def __ne__(self, other: _TensorCompatible) -> Tensor: ... # type: ignore[override] | ||
def __ge__(self, other: _TensorCompatible, name: str | None = None) -> Tensor: ... | ||
def __gt__(self, other: _TensorCompatible, name: str | None = None) -> Tensor: ... | ||
def __le__(self, other: _TensorCompatible, name: str | None = None) -> Tensor: ... | ||
def __lt__(self, other: _TensorCompatible, name: str | None = None) -> Tensor: ... | ||
def __bool__(self) -> NoReturn: ... | ||
def __getitem__(self, slice_spec: _Slice | tuple[_Slice, ...]) -> Tensor: ... | ||
def __len__(self) -> int: ... | ||
# This only works for rank 0 tensors. | ||
def __index__(self) -> int: ... | ||
def __getattr__(self, name: str) -> Incomplete: ... | ||
|
||
class VariableSynchronization(Enum): | ||
AUTO = 0 | ||
NONE = 1 | ||
ON_WRITE = 2 | ||
ON_READ = 3 | ||
|
||
class VariableAggregation(Enum): | ||
AUTO = 0 | ||
NONE = 1 | ||
ON_WRITE = 2 | ||
ON_READ = 3 | ||
|
||
class _VariableMetaclass(type): ... | ||
|
||
# Variable class in intent/documentation is a Tensor. In implementation there's | ||
# TODO comment to make it Tensor. It is not actually Tensor type wise, but even | ||
# dynamically patches on most methods of tf.Tensor | ||
# https://github.com/tensorflow/tensorflow/blob/9524a636cae9ae3f0554203c1ba7ee29c85fcf12/tensorflow/python/ops/variables.py#L1086. | ||
class Variable(Tensor, metaclass=_VariableMetaclass): | ||
def __init__( | ||
self, | ||
initial_value: Tensor | Callable[[], Tensor] | None = None, | ||
trainable: _bool | None = None, | ||
validate_shape: _bool = True, | ||
# Valid non-None values are deprecated. | ||
caching_device: None = None, | ||
name: str | None = None, | ||
# Real type is VariableDef protobuf type. Can be added after adding script | ||
# to generate tensorflow protobuf stubs with mypy-protobuf. | ||
variable_def: Incomplete | None = None, | ||
dtype: _DTypeLike | None = None, | ||
import_scope: str | None = None, | ||
constraint: Callable[[Tensor], Tensor] | None = None, | ||
synchronization: VariableSynchronization = VariableSynchronization.AUTO, | ||
aggregation: VariableAggregation = VariableAggregation.NONE, | ||
shape: _ShapeLike | None = None, | ||
) -> None: ... | ||
def __getattr__(self, name: str) -> Incomplete: ... | ||
|
||
class RaggedTensor(metaclass=ABCMeta): | ||
def bounding_shape( | ||
self, axis: _TensorCompatible | None = None, name: str | None = None, out_type: _DTypeLike | None = None | ||
) -> Tensor: ... | ||
@classmethod | ||
def from_sparse( | ||
cls, st_input: SparseTensor, name: str | None = None, row_splits_dtype: _DTypeLike = int64 | ||
) -> RaggedTensor: ... | ||
def to_sparse(self, name: str | None = None) -> SparseTensor: ... | ||
def to_tensor( | ||
self, default_value: float | str | None = None, name: str | None = None, shape: _ShapeLike | None = None | ||
) -> Tensor: ... | ||
def __add__(self, other: RaggedTensor | float, name: str | None = None) -> RaggedTensor: ... | ||
def __radd__(self, other: RaggedTensor | float, name: str | None = None) -> RaggedTensor: ... | ||
def __sub__(self, other: RaggedTensor | float, name: str | None = None) -> RaggedTensor: ... | ||
def __mul__(self, other: RaggedTensor | float, name: str | None = None) -> RaggedTensor: ... | ||
def __rmul__(self, other: RaggedTensor | float, name: str | None = None) -> RaggedTensor: ... | ||
def __floordiv__(self, other: RaggedTensor | float, name: str | None = None) -> RaggedTensor: ... | ||
def __truediv__(self, other: RaggedTensor | float, name: str | None = None) -> RaggedTensor: ... | ||
def __getitem__(self, slice_spec: _Slice | tuple[_Slice, ...]) -> RaggedTensor: ... | ||
def __getattr__(self, name: str) -> Incomplete: ... | ||
|
||
class Operation: | ||
def __init__( | ||
self, | ||
node_def: Incomplete, | ||
g: Graph, | ||
# isinstance is used so can not be Sequence/Iterable. | ||
inputs: list[Tensor] | None = None, | ||
output_types: Unused = None, | ||
control_inputs: Iterable[Tensor | Operation] | None = None, | ||
input_types: Iterable[DType] | None = None, | ||
original_op: Operation | None = None, | ||
op_def: Incomplete = None, | ||
) -> None: ... | ||
@property | ||
def inputs(self) -> list[Tensor]: ... | ||
@property | ||
def outputs(self) -> list[Tensor]: ... | ||
@property | ||
def device(self) -> str: ... | ||
@property | ||
def name(self) -> str: ... | ||
@property | ||
def type(self) -> str: ... | ||
def __getattr__(self, name: str) -> Incomplete: ... | ||
|
||
class TensorShape(metaclass=ABCMeta): | ||
def __init__(self, dims: _ShapeLike) -> None: ... | ||
@property | ||
def rank(self) -> int: ... | ||
def as_list(self) -> list[int | None]: ... | ||
def assert_has_rank(self, rank: int) -> None: ... | ||
def assert_is_compatible_with(self, other: Iterable[int | None]) -> None: ... | ||
def __bool__(self) -> _bool: ... | ||
@overload | ||
def __getitem__(self, key: int) -> int | None: ... | ||
@overload | ||
def __getitem__(self, key: slice) -> TensorShape: ... | ||
def __iter__(self) -> Iterator[int | None]: ... | ||
def __len__(self) -> int: ... | ||
def __add__(self, other: Iterable[int | None]) -> TensorShape: ... | ||
def __radd__(self, other: Iterable[int | None]) -> TensorShape: ... | ||
def __getattr__(self, name: str) -> Incomplete: ... | ||
|
||
class Graph: | ||
def add_to_collection(self, name: str, value: object) -> None: ... | ||
def add_to_collections(self, names: Iterable[str] | str, value: object) -> None: ... | ||
@contextmanager | ||
def as_default(self: Self) -> Iterator[Self]: ... | ||
def finalize(self) -> None: ... | ||
def get_tensor_by_name(self, name: str) -> Tensor: ... | ||
def get_operation_by_name(self, name: str) -> Operation: ... | ||
def get_operations(self) -> list[Operation]: ... | ||
def get_name_scope(self) -> str: ... | ||
def __getattr__(self, name: str) -> Incomplete: ... | ||
|
||
def __getattr__(name: str) -> Incomplete: ... |
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,55 @@ | ||
from _typeshed import Incomplete | ||
from abc import ABCMeta | ||
from builtins import bool as _bool | ||
from typing import Any | ||
|
||
import numpy as np | ||
from tensorflow import _DTypeLike | ||
|
||
class _DTypeMeta(ABCMeta): ... | ||
|
||
class DType(metaclass=_DTypeMeta): | ||
@property | ||
def name(self) -> str: ... | ||
@property | ||
def as_numpy_dtype(self) -> type[np.number[Any]]: ... | ||
@property | ||
def is_numpy_compatible(self) -> _bool: ... | ||
@property | ||
def is_bool(self) -> _bool: ... | ||
@property | ||
def is_floating(self) -> _bool: ... | ||
@property | ||
def is_integer(self) -> _bool: ... | ||
@property | ||
def is_quantized(self) -> _bool: ... | ||
@property | ||
def is_unsigned(self) -> _bool: ... | ||
def __getattr__(self, name: str) -> Incomplete: ... | ||
|
||
bool: DType | ||
complex128: DType | ||
complex64: DType | ||
bfloat16: DType | ||
float16: DType | ||
half: DType | ||
float32: DType | ||
float64: DType | ||
double: DType | ||
int8: DType | ||
int16: DType | ||
int32: DType | ||
int64: DType | ||
uint8: DType | ||
uint16: DType | ||
uint32: DType | ||
uint64: DType | ||
qint8: DType | ||
qint16: DType | ||
qint32: DType | ||
quint8: DType | ||
quint16: DType | ||
string: DType | ||
|
||
def as_dtype(type_value: _DTypeLike) -> DType: ... | ||
def __getattr__(name: str) -> Incomplete: ... |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,13 @@ | ||
from _typeshed import Incomplete | ||
from typing import overload | ||
|
||
from tensorflow import RaggedTensor, Tensor, _TensorCompatible | ||
from tensorflow.sparse import SparseTensor | ||
|
||
@overload | ||
def abs(x: _TensorCompatible, name: str | None = None) -> Tensor: ... | ||
@overload | ||
def abs(x: SparseTensor, name: str | None = None) -> SparseTensor: ... | ||
@overload | ||
def abs(x: RaggedTensor, name: str | None = None) -> RaggedTensor: ... | ||
def __getattr__(name: str) -> Incomplete: ... |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,30 @@ | ||
from _typeshed import Incomplete | ||
from abc import ABCMeta | ||
from typing_extensions import TypeAlias | ||
|
||
from tensorflow import Tensor, TensorShape, _TensorCompatible | ||
from tensorflow.dtypes import DType | ||
|
||
_SparseTensorCompatible: TypeAlias = _TensorCompatible | SparseTensor | ||
|
||
class SparseTensor(metaclass=ABCMeta): | ||
@property | ||
def indices(self) -> Tensor: ... | ||
@property | ||
def values(self) -> Tensor: ... | ||
@property | ||
def dense_shape(self) -> Tensor: ... | ||
@property | ||
def shape(self) -> TensorShape: ... | ||
@property | ||
def dtype(self) -> DType: ... | ||
name: str | ||
def __init__(self, indices: _TensorCompatible, values: _TensorCompatible, dense_shape: _TensorCompatible) -> None: ... | ||
def get_shape(self) -> TensorShape: ... | ||
# Many arithmetic operations are not directly supported. Some have alternatives like tf.sparse.add instead of +. | ||
def __div__(self, y: _SparseTensorCompatible) -> SparseTensor: ... | ||
def __truediv__(self, y: _SparseTensorCompatible) -> SparseTensor: ... | ||
def __mul__(self, y: _SparseTensorCompatible) -> SparseTensor: ... | ||
def __getattr__(self, name: str) -> Incomplete: ... | ||
|
||
def __getattr__(name: str) -> Incomplete: ... |