Hello,

I have a function `f` that I auto-parallelize with `jit` over a `(8,)` mesh. Now I want to use it on a `(4, 8)` mesh by replicating `f` over the first axis. Using `vmap` would mess up the auto-sharding (and seems like an overcomplication). Is there any mechanism to do a `pmap`-ish thing over the first axis?

Thanks in advance.
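A minimal sketch of one possible approach, assuming a JAX version in which `jax.vmap` accepts `spmd_axis_name`. The body of `f`, the array shapes, and the axis names `replica` and `model` are hypothetical stand-ins, and 32 CPU devices are simulated with an XLA flag so the `(4, 8)` mesh can be built without accelerators. `vmap` adds the leading dimension for the four copies, and `spmd_axis_name="replica"` maps that new dimension onto the first mesh axis inside any sharding constraints `f` carries. The intent is that `f`'s existing 8-way sharding is kept rather than dropped, but this is a sketch, not a confirmed answer to the question.

```python
import os

# Assumption: no accelerators, so ask XLA for 32 host (CPU) devices to build a (4, 8) mesh.
# This must run before jax is imported.
os.environ["XLA_FLAGS"] = (
    os.environ.get("XLA_FLAGS", "") + " --xla_force_host_platform_device_count=32"
)

import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# (4, 8) mesh: "replica" indexes the 4 independent copies of f,
# "model" is the 8-way axis f was already auto-sharded over.
devices = np.array(jax.devices("cpu")[:32]).reshape(4, 8)
mesh = Mesh(devices, axis_names=("replica", "model"))


def f(x, w):
    """Hypothetical stand-in for the original f, written for ONE unbatched input."""
    y = jnp.tanh(x @ w)
    # The kind of sharding hint f might carry on the (8,) mesh: rows over "model".
    return jax.lax.with_sharding_constraint(y, NamedSharding(mesh, P("model", None)))


# vmap adds the leading "copy" dimension; spmd_axis_name="replica" makes JAX
# shard that new dimension over "replica" in f's sharding constraints
# instead of leaving it unconstrained. w is shared by all copies (in_axes=None).
batched_f = jax.vmap(f, in_axes=(0, None), spmd_axis_name="replica")

x_sharding = NamedSharding(mesh, P("replica", "model", None))  # (copies, rows, features)
w_sharding = NamedSharding(mesh, P())  # weights replicated on every device

run = jax.jit(
    batched_f,
    in_shardings=(x_sharding, w_sharding),
    out_shardings=x_sharding,
)

rng = np.random.default_rng(0)
x = jax.device_put(rng.normal(size=(4, 64, 16)).astype(np.float32), x_sharding)
w = jax.device_put(rng.normal(size=(16, 16)).astype(np.float32), w_sharding)
out = run(x, w)
print(out.shape, out.sharding.spec)
```

Passing `out_shardings` pins the result to the same layout as the input; leaving it out would hand the output layout to XLA's sharding propagation instead.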