Skip to content

Commit

Permalink
Initial tensorflow stubs (#8974)
Browse files Browse the repository at this point in the history
Co-authored-by: Alex Waygood <[email protected]>
  • Loading branch information
hmc-cs-mdrissi and AlexWaygood authored Jan 14, 2023
1 parent 0a291da commit ea0ae21
Show file tree
Hide file tree
Showing 7 changed files with 318 additions and 0 deletions.
22 changes: 22 additions & 0 deletions stubs/tensorflow/@tests/stubtest_allowlist.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
# Some methods are dynamically patched onto to instances as they
# may depend on whether code is executed in graph/eager/v1/v2/etc.
# Tensorflow supports multiple modes of execution which changes some
# of the attributes/methods/even class hierachies.
tensorflow.Tensor.__int__
tensorflow.Tensor.numpy
tensorflow.Tensor.__index__
# Incomplete
tensorflow.sparse.SparseTensor.__getattr__
tensorflow.SparseTensor.__getattr__
tensorflow.TensorShape.__getattr__
tensorflow.dtypes.DType.__getattr__
tensorflow.RaggedTensor.__getattr__
tensorflow.DType.__getattr__
tensorflow.Graph.__getattr__
tensorflow.Operation.__getattr__
tensorflow.Variable.__getattr__
# Internal undocumented API
tensorflow.RaggedTensor.__init__
# Has an undocumented extra argument that tf.Variable which acts like subclass
# (by dynamically patching tf.Tensor methods) does not preserve.
tensorflow.Tensor.__getitem__
3 changes: 3 additions & 0 deletions stubs/tensorflow/METADATA.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
version = "2.10.*"
# requires a version of numpy with a `py.typed` file
requires = ["numpy>=1.20"]
195 changes: 195 additions & 0 deletions stubs/tensorflow/tensorflow/__init__.pyi
Original file line number Diff line number Diff line change
@@ -0,0 +1,195 @@
from _typeshed import Incomplete, Self, Unused
from abc import ABCMeta
from builtins import bool as _bool
from collections.abc import Callable, Iterable, Iterator, Sequence
from contextlib import contextmanager
from enum import Enum
from typing import Any, NoReturn, overload
from typing_extensions import TypeAlias

import numpy
from tensorflow.dtypes import *

# Most tf.math functions are exported as tf, but sadly not all are.
from tensorflow.math import abs as abs
from tensorflow.sparse import SparseTensor

# Tensors ideally should be a generic type, but properly typing data type/shape
# will be a lot of work. Until we have good non-generic tensorflow stubs,
# we will skip making Tensor generic. Also good type hints for shapes will
# run quickly into many places where type system is not strong enough today.
# So shape typing is probably not worth doing anytime soon.
_Slice: TypeAlias = int | slice | None

_FloatDataSequence: TypeAlias = Sequence[float] | Sequence[_FloatDataSequence]
_StrDataSequence: TypeAlias = Sequence[str] | Sequence[_StrDataSequence]
_ScalarTensorCompatible: TypeAlias = Tensor | str | float | numpy.ndarray[Any, Any] | numpy.number[Any]
_TensorCompatible: TypeAlias = _ScalarTensorCompatible | Sequence[_TensorCompatible]
_ShapeLike: TypeAlias = TensorShape | Iterable[_ScalarTensorCompatible | None] | int | Tensor
_DTypeLike: TypeAlias = DType | str | numpy.dtype[Any]

class Tensor:
def __init__(self, op: Operation, value_index: int, dtype: DType) -> None: ...
def consumers(self) -> list[Incomplete]: ...
@property
def shape(self) -> TensorShape: ...
def get_shape(self) -> TensorShape: ...
@property
def dtype(self) -> DType: ...
@property
def graph(self) -> Graph: ...
@property
def name(self) -> str: ...
@property
def op(self) -> Operation: ...
def numpy(self) -> numpy.ndarray[Any, Any]: ...
def __int__(self) -> int: ...
def __abs__(self, name: str | None = None) -> Tensor: ...
def __add__(self, other: _TensorCompatible) -> Tensor: ...
def __radd__(self, other: _TensorCompatible) -> Tensor: ...
def __sub__(self, other: _TensorCompatible) -> Tensor: ...
def __rsub__(self, other: _TensorCompatible) -> Tensor: ...
def __mul__(self, other: _TensorCompatible) -> Tensor: ...
def __rmul__(self, other: _TensorCompatible) -> Tensor: ...
def __pow__(self, other: _TensorCompatible) -> Tensor: ...
def __matmul__(self, other: _TensorCompatible) -> Tensor: ...
def __rmatmul__(self, other: _TensorCompatible) -> Tensor: ...
def __floordiv__(self, other: _TensorCompatible) -> Tensor: ...
def __rfloordiv__(self, other: _TensorCompatible) -> Tensor: ...
def __truediv__(self, other: _TensorCompatible) -> Tensor: ...
def __rtruediv__(self, other: _TensorCompatible) -> Tensor: ...
def __neg__(self, name: str | None = None) -> Tensor: ...
def __and__(self, other: _TensorCompatible) -> Tensor: ...
def __rand__(self, other: _TensorCompatible) -> Tensor: ...
def __or__(self, other: _TensorCompatible) -> Tensor: ...
def __ror__(self, other: _TensorCompatible) -> Tensor: ...
def __eq__(self, other: _TensorCompatible) -> Tensor: ... # type: ignore[override]
def __ne__(self, other: _TensorCompatible) -> Tensor: ... # type: ignore[override]
def __ge__(self, other: _TensorCompatible, name: str | None = None) -> Tensor: ...
def __gt__(self, other: _TensorCompatible, name: str | None = None) -> Tensor: ...
def __le__(self, other: _TensorCompatible, name: str | None = None) -> Tensor: ...
def __lt__(self, other: _TensorCompatible, name: str | None = None) -> Tensor: ...
def __bool__(self) -> NoReturn: ...
def __getitem__(self, slice_spec: _Slice | tuple[_Slice, ...]) -> Tensor: ...
def __len__(self) -> int: ...
# This only works for rank 0 tensors.
def __index__(self) -> int: ...
def __getattr__(self, name: str) -> Incomplete: ...

class VariableSynchronization(Enum):
AUTO = 0
NONE = 1
ON_WRITE = 2
ON_READ = 3

class VariableAggregation(Enum):
AUTO = 0
NONE = 1
ON_WRITE = 2
ON_READ = 3

class _VariableMetaclass(type): ...

# Variable class in intent/documentation is a Tensor. In implementation there's
# TODO comment to make it Tensor. It is not actually Tensor type wise, but even
# dynamically patches on most methods of tf.Tensor
# https://github.com/tensorflow/tensorflow/blob/9524a636cae9ae3f0554203c1ba7ee29c85fcf12/tensorflow/python/ops/variables.py#L1086.
class Variable(Tensor, metaclass=_VariableMetaclass):
def __init__(
self,
initial_value: Tensor | Callable[[], Tensor] | None = None,
trainable: _bool | None = None,
validate_shape: _bool = True,
# Valid non-None values are deprecated.
caching_device: None = None,
name: str | None = None,
# Real type is VariableDef protobuf type. Can be added after adding script
# to generate tensorflow protobuf stubs with mypy-protobuf.
variable_def: Incomplete | None = None,
dtype: _DTypeLike | None = None,
import_scope: str | None = None,
constraint: Callable[[Tensor], Tensor] | None = None,
synchronization: VariableSynchronization = VariableSynchronization.AUTO,
aggregation: VariableAggregation = VariableAggregation.NONE,
shape: _ShapeLike | None = None,
) -> None: ...
def __getattr__(self, name: str) -> Incomplete: ...

class RaggedTensor(metaclass=ABCMeta):
def bounding_shape(
self, axis: _TensorCompatible | None = None, name: str | None = None, out_type: _DTypeLike | None = None
) -> Tensor: ...
@classmethod
def from_sparse(
cls, st_input: SparseTensor, name: str | None = None, row_splits_dtype: _DTypeLike = int64
) -> RaggedTensor: ...
def to_sparse(self, name: str | None = None) -> SparseTensor: ...
def to_tensor(
self, default_value: float | str | None = None, name: str | None = None, shape: _ShapeLike | None = None
) -> Tensor: ...
def __add__(self, other: RaggedTensor | float, name: str | None = None) -> RaggedTensor: ...
def __radd__(self, other: RaggedTensor | float, name: str | None = None) -> RaggedTensor: ...
def __sub__(self, other: RaggedTensor | float, name: str | None = None) -> RaggedTensor: ...
def __mul__(self, other: RaggedTensor | float, name: str | None = None) -> RaggedTensor: ...
def __rmul__(self, other: RaggedTensor | float, name: str | None = None) -> RaggedTensor: ...
def __floordiv__(self, other: RaggedTensor | float, name: str | None = None) -> RaggedTensor: ...
def __truediv__(self, other: RaggedTensor | float, name: str | None = None) -> RaggedTensor: ...
def __getitem__(self, slice_spec: _Slice | tuple[_Slice, ...]) -> RaggedTensor: ...
def __getattr__(self, name: str) -> Incomplete: ...

class Operation:
def __init__(
self,
node_def: Incomplete,
g: Graph,
# isinstance is used so can not be Sequence/Iterable.
inputs: list[Tensor] | None = None,
output_types: Unused = None,
control_inputs: Iterable[Tensor | Operation] | None = None,
input_types: Iterable[DType] | None = None,
original_op: Operation | None = None,
op_def: Incomplete = None,
) -> None: ...
@property
def inputs(self) -> list[Tensor]: ...
@property
def outputs(self) -> list[Tensor]: ...
@property
def device(self) -> str: ...
@property
def name(self) -> str: ...
@property
def type(self) -> str: ...
def __getattr__(self, name: str) -> Incomplete: ...

class TensorShape(metaclass=ABCMeta):
def __init__(self, dims: _ShapeLike) -> None: ...
@property
def rank(self) -> int: ...
def as_list(self) -> list[int | None]: ...
def assert_has_rank(self, rank: int) -> None: ...
def assert_is_compatible_with(self, other: Iterable[int | None]) -> None: ...
def __bool__(self) -> _bool: ...
@overload
def __getitem__(self, key: int) -> int | None: ...
@overload
def __getitem__(self, key: slice) -> TensorShape: ...
def __iter__(self) -> Iterator[int | None]: ...
def __len__(self) -> int: ...
def __add__(self, other: Iterable[int | None]) -> TensorShape: ...
def __radd__(self, other: Iterable[int | None]) -> TensorShape: ...
def __getattr__(self, name: str) -> Incomplete: ...

class Graph:
def add_to_collection(self, name: str, value: object) -> None: ...
def add_to_collections(self, names: Iterable[str] | str, value: object) -> None: ...
@contextmanager
def as_default(self: Self) -> Iterator[Self]: ...
def finalize(self) -> None: ...
def get_tensor_by_name(self, name: str) -> Tensor: ...
def get_operation_by_name(self, name: str) -> Operation: ...
def get_operations(self) -> list[Operation]: ...
def get_name_scope(self) -> str: ...
def __getattr__(self, name: str) -> Incomplete: ...

def __getattr__(name: str) -> Incomplete: ...
Empty file.
55 changes: 55 additions & 0 deletions stubs/tensorflow/tensorflow/dtypes.pyi
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
from _typeshed import Incomplete
from abc import ABCMeta
from builtins import bool as _bool
from typing import Any

import numpy as np
from tensorflow import _DTypeLike

class _DTypeMeta(ABCMeta): ...

class DType(metaclass=_DTypeMeta):
@property
def name(self) -> str: ...
@property
def as_numpy_dtype(self) -> type[np.number[Any]]: ...
@property
def is_numpy_compatible(self) -> _bool: ...
@property
def is_bool(self) -> _bool: ...
@property
def is_floating(self) -> _bool: ...
@property
def is_integer(self) -> _bool: ...
@property
def is_quantized(self) -> _bool: ...
@property
def is_unsigned(self) -> _bool: ...
def __getattr__(self, name: str) -> Incomplete: ...

bool: DType
complex128: DType
complex64: DType
bfloat16: DType
float16: DType
half: DType
float32: DType
float64: DType
double: DType
int8: DType
int16: DType
int32: DType
int64: DType
uint8: DType
uint16: DType
uint32: DType
uint64: DType
qint8: DType
qint16: DType
qint32: DType
quint8: DType
quint16: DType
string: DType

def as_dtype(type_value: _DTypeLike) -> DType: ...
def __getattr__(name: str) -> Incomplete: ...
13 changes: 13 additions & 0 deletions stubs/tensorflow/tensorflow/math.pyi
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
from _typeshed import Incomplete
from typing import overload

from tensorflow import RaggedTensor, Tensor, _TensorCompatible
from tensorflow.sparse import SparseTensor

@overload
def abs(x: _TensorCompatible, name: str | None = None) -> Tensor: ...
@overload
def abs(x: SparseTensor, name: str | None = None) -> SparseTensor: ...
@overload
def abs(x: RaggedTensor, name: str | None = None) -> RaggedTensor: ...
def __getattr__(name: str) -> Incomplete: ...
30 changes: 30 additions & 0 deletions stubs/tensorflow/tensorflow/sparse.pyi
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
from _typeshed import Incomplete
from abc import ABCMeta
from typing_extensions import TypeAlias

from tensorflow import Tensor, TensorShape, _TensorCompatible
from tensorflow.dtypes import DType

_SparseTensorCompatible: TypeAlias = _TensorCompatible | SparseTensor

class SparseTensor(metaclass=ABCMeta):
@property
def indices(self) -> Tensor: ...
@property
def values(self) -> Tensor: ...
@property
def dense_shape(self) -> Tensor: ...
@property
def shape(self) -> TensorShape: ...
@property
def dtype(self) -> DType: ...
name: str
def __init__(self, indices: _TensorCompatible, values: _TensorCompatible, dense_shape: _TensorCompatible) -> None: ...
def get_shape(self) -> TensorShape: ...
# Many arithmetic operations are not directly supported. Some have alternatives like tf.sparse.add instead of +.
def __div__(self, y: _SparseTensorCompatible) -> SparseTensor: ...
def __truediv__(self, y: _SparseTensorCompatible) -> SparseTensor: ...
def __mul__(self, y: _SparseTensorCompatible) -> SparseTensor: ...
def __getattr__(self, name: str) -> Incomplete: ...

def __getattr__(name: str) -> Incomplete: ...

0 comments on commit ea0ae21

Please sign in to comment.