Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor: Replaced lambda expressions with def functions #27738

Merged
merged 4 commits into from
Dec 25, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion ivy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -1213,7 +1213,10 @@ def dynamic_backend_as(value):
downcast_dtypes = False
upcast_dtypes = False
crosscast_dtypes = False
cast_dtypes = lambda: downcast_dtypes and upcast_dtypes and crosscast_dtypes


def cast_dtypes():
return downcast_dtypes and upcast_dtypes and crosscast_dtypes


def downcast_data_types(val=True):
Expand Down
10 changes: 7 additions & 3 deletions ivy/func_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,9 +174,13 @@ def try_array_function_override(func, overloaded_args, types, args, kwargs):

def _get_first_array(*args, **kwargs):
# ToDo: make this more efficient, with function ivy.nested_nth_index_where
array_fn = lambda x: (
ivy.is_array(x) if not hasattr(x, "_ivy_array") else ivy.is_array(x.ivy_array)
)
def array_fn(x):
return (
ivy.is_array(x)
if not hasattr(x, "_ivy_array")
else ivy.is_array(x.ivy_array)
)

array_fn = array_fn if "array_fn" not in kwargs else kwargs["array_fn"]
arr = None
if args:
Expand Down
4 changes: 3 additions & 1 deletion ivy/functional/frontends/jax/numpy/statistical.py
Original file line number Diff line number Diff line change
Expand Up @@ -404,7 +404,9 @@ def apply_along_axis(func1d, axis, arr, *args, **kwargs):
if axis < 0:
axis = axis + ndim

func = lambda elem: func1d(elem, *args, **kwargs)
def func(elem):
return func1d(elem, *args, **kwargs)

for i in range(1, ndim - axis):
func = ivy.vmap(func, in_axes=i, out_axes=-1)
for i in range(axis):
Expand Down
33 changes: 24 additions & 9 deletions ivy/functional/frontends/numpy/func_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,10 @@ def _assert_array(args, dtype, scalar_check=False, casting="safe"):
if ivy.is_bool_dtype(dtype):
assert_fn = ivy.is_bool_dtype
if ivy.is_int_dtype(dtype):
assert_fn = lambda x: not ivy.is_float_dtype(x)

def assert_fn(x): # noqa F811
return not ivy.is_float_dtype(x)

if assert_fn:
ivy.utils.assertions.check_all_or_any_fn(
*args,
Expand All @@ -51,13 +54,19 @@ def _assert_no_array(args, dtype, scalar_check=False, none=False):
if args:
first_arg = args[0]
fn_func = ivy.as_ivy_dtype(dtype) if ivy.exists(dtype) else ivy.dtype(first_arg)
assert_fn = lambda x: ivy.dtype(x) == fn_func

def assert_fn(x):
return ivy.dtype(x) == fn_func

if scalar_check:
assert_fn = lambda x: (
ivy.dtype(x) == fn_func
if ivy.shape(x) != ()
else _casting_no_special_case(ivy.dtype(x), fn_func, none)
)

def assert_fn(x): # noqa F811
return (
ivy.dtype(x) == fn_func
if ivy.shape(x) != ()
else _casting_no_special_case(ivy.dtype(x), fn_func, none)
)

ivy.utils.assertions.check_all_or_any_fn(
*args,
fn=assert_fn,
Expand Down Expand Up @@ -105,9 +114,15 @@ def _assert_scalar(args, dtype):
if args and dtype:
assert_fn = None
if ivy.is_int_dtype(dtype):
assert_fn = lambda x: not isinstance(x, float)

def assert_fn(x): # noqa F811
return not isinstance(x, float)

elif ivy.is_bool_dtype(dtype):
assert_fn = lambda x: isinstance(x, bool)

def assert_fn(x):
return isinstance(x, bool)

if assert_fn:
ivy.utils.assertions.check_all_or_any_fn(
*args,
Expand Down
15 changes: 12 additions & 3 deletions ivy/functional/ivy/experimental/layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -1689,11 +1689,20 @@ def area_interpolate(x, dims, size, scale):
def get_interpolate_kernel(mode):
kernel_func = _triangle_kernel
if mode == "tf_bicubic":
kernel_func = lambda inputs: _cubic_kernel(inputs)

def kernel_func(inputs): # noqa F811
return _cubic_kernel(inputs)

elif mode == "lanczos3":
kernel_func = lambda inputs: _lanczos_kernel(3, inputs)

def kernel_func(inputs):
return _lanczos_kernel(3, inputs)

elif mode == "lanczos5":
kernel_func = lambda inputs: _lanczos_kernel(5, inputs)

def kernel_func(inputs):
return _lanczos_kernel(5, inputs)

return kernel_func


Expand Down
53 changes: 37 additions & 16 deletions ivy/utils/assertions.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,13 +22,19 @@ def _broadcast_inputs(x1, x2):


def check_less(x1, x2, allow_equal=False, message="", as_array=True):
comp_fn = lambda x1, x2: (ivy.any(x1 > x2), ivy.any(x1 >= x2))
def comp_fn(x1, x2):
return ivy.any(x1 > x2), ivy.any(x1 >= x2)

if not as_array:
iter_comp_fn = lambda x1_, x2_: (
any(x1 > x2 for x1, x2 in zip(x1_, x2_)),
any(x1 >= x2 for x1, x2 in zip(x1_, x2_)),
)
comp_fn = lambda x1, x2: iter_comp_fn(*_broadcast_inputs(x1, x2))

def iter_comp_fn(x1_, x2_):
return any(x1 > x2 for x1, x2 in zip(x1_, x2_)), any(
x1 >= x2 for x1, x2 in zip(x1_, x2_)
)

def comp_fn(x1, x2): # noqa F811
return iter_comp_fn(*_broadcast_inputs(x1, x2))

gt, gt_eq = comp_fn(x1, x2)
# less_equal
if allow_equal and gt:
Expand All @@ -42,13 +48,19 @@ def check_less(x1, x2, allow_equal=False, message="", as_array=True):


def check_greater(x1, x2, allow_equal=False, message="", as_array=True):
comp_fn = lambda x1, x2: (ivy.any(x1 < x2), ivy.any(x1 <= x2))
def comp_fn(x1, x2):
return ivy.any(x1 < x2), ivy.any(x1 <= x2)

if not as_array:
iter_comp_fn = lambda x1_, x2_: (
any(x1 < x2 for x1, x2 in zip(x1_, x2_)),
any(x1 <= x2 for x1, x2 in zip(x1_, x2_)),
)
comp_fn = lambda x1, x2: iter_comp_fn(*_broadcast_inputs(x1, x2))

def iter_comp_fn(x1_, x2_):
return any(x1 < x2 for x1, x2 in zip(x1_, x2_)), any(
x1 <= x2 for x1, x2 in zip(x1_, x2_)
)

def comp_fn(x1, x2): # noqa F811
return iter_comp_fn(*_broadcast_inputs(x1, x2))

lt, lt_eq = comp_fn(x1, x2)
# greater_equal
if allow_equal and lt:
Expand All @@ -63,11 +75,20 @@ def check_greater(x1, x2, allow_equal=False, message="", as_array=True):

def check_equal(x1, x2, inverse=False, message="", as_array=True):
# not_equal
eq_fn = lambda x1, x2: (x1 == x2 if inverse else x1 != x2)
comp_fn = lambda x1, x2: ivy.any(eq_fn(x1, x2))
def eq_fn(x1, x2):
return x1 == x2 if inverse else x1 != x2

def comp_fn(x1, x2):
return ivy.any(eq_fn(x1, x2))

if not as_array:
iter_comp_fn = lambda x1_, x2_: any(eq_fn(x1, x2) for x1, x2 in zip(x1_, x2_))
comp_fn = lambda x1, x2: iter_comp_fn(*_broadcast_inputs(x1, x2))

def iter_comp_fn(x1_, x2_):
return any(eq_fn(x1, x2) for x1, x2 in zip(x1_, x2_))

def comp_fn(x1, x2): # noqa F811
return iter_comp_fn(*_broadcast_inputs(x1, x2))

eq = comp_fn(x1, x2)
if inverse and eq:
raise ivy.utils.exceptions.IvyException(
Expand Down
Loading