Skip to content

Commit

Permalink
fix (JAX backend)(sorting.py): fixing the sort function to handle `…
Browse files Browse the repository at this point in the history
…jax.__version__ > 0.4.28` wherein the `kind` argument has been deprecated.
  • Loading branch information
YushaArif99 committed Sep 24, 2024
1 parent bce9202 commit 54fbb25
Showing 1 changed file with 7 additions and 2 deletions.
9 changes: 7 additions & 2 deletions ivy/functional/backends/jax/sorting.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
# global
import jax
import jax.numpy as jnp
from typing import Optional, Literal, Union, List

from packaging.version import parse as _parse
# local
import ivy
from ivy.func_wrapper import with_unsupported_dtypes
Expand Down Expand Up @@ -31,7 +32,11 @@ def sort(
out: Optional[JaxArray] = None,
) -> JaxArray:
kind = "stable" if stable else "quicksort"
ret = jnp.asarray(jnp.sort(x, axis=axis, kind=kind))
if _parse(jax.__version__) > _parse("0.4.28"):
# kind argument has been removed in jax 0.4.28
ret = jnp.asarray(jnp.sort(x, axis=axis, stable=stable))
else:
ret = jnp.asarray(jnp.sort(x, axis=axis, kind=kind))
if descending:
ret = jnp.asarray(jnp.flip(ret, axis=axis))
return ret
Expand Down

0 comments on commit 54fbb25

Please sign in to comment.