Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support fp8 direct quantization #69

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Support direct quantization for gated einsum
  • Loading branch information
wenscarl committed Sep 25, 2024
commit 951b2b74a3bf849a49aa240be913cbc5f3795f13
145 changes: 110 additions & 35 deletions praxis/layers/injection/fp8_nvidia_gpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,33 +139,59 @@ def __call__(self, equation: str, *args: JTensor) -> JTensor:
k.dtype == comp_dtype
), f'k dtype has to be {comp_dtype}, but got {k.dtype}'
x = jnp.asarray(x, comp_dtype)

if self.use_direct_quant:
def _quantized_dot_general(
lhs, rhs, dimension_numbers, precision=None,
preferred_element_type=None
):
theta = self.theta
return fp8_ops.q_dot_dq(
lhs,
rhs,
lhs_scale=theta.input_scale,
rhs_scale=theta.kernel_scale,
out_grad_scale=theta.output_grad_scale,
lhs_amax_history=theta.input_amax_history,
rhs_amax_history=theta.kernel_amax_history,
out_grad_amax_history=theta.output_grad_amax_history,
compute_dtype=comp_dtype,
dimension_numbers=dimension_numbers,
precision=precision,
preferred_element_type=preferred_element_type,
lhs,
rhs,
lhs_scale=theta.input_scale,
rhs_scale=theta.kernel_scale,
out_grad_scale=theta.output_grad_scale,
lhs_amax_history=theta.input_amax_history,
rhs_amax_history=theta.kernel_amax_history,
out_grad_amax_history=theta.output_grad_amax_history,
compute_dtype=comp_dtype,
dimension_numbers=dimension_numbers,
precision=precision,
preferred_element_type=preferred_element_type,
)
y = jnp.einsum(equation, x, k, _dot_general=_quantized_dot_general)
else:
y = self.quantized_einsum(equation, x, k, return_quantized_x=False)

return y

# This decorator wraps a function to perform quantized dot product.
# It prepares the arguments for quantized_dot, including the pre-quantized input,
# scales, and amax histories. This allows for efficient FP8 matrix multiplication while
# managing quantization parameters.
def quantized_dot_config(
compute_dtype, q_lhs, lhs_scale, q_rhs, rhs_scale, out_grad_scale,
out_grad_amax_history
):
def decorator(func):
def wrapper(lhs, rhs, dimension_numbers, precision=None, preferred_element_type=None):
return fp8_ops.quantized_dot(
lhs=lhs,
q_lhs=q_lhs,
lhs_scale=lhs_scale,
rhs=rhs,
q_rhs=q_rhs,
rhs_scale=rhs_scale,
out_grad_scale=out_grad_scale,
out_grad_amax_history=out_grad_amax_history,
compute_dtype=compute_dtype,
dimension_numbers=dimension_numbers,
precision=precision,
preferred_element_type=preferred_element_type
)
return wrapper
return decorator

class Fp8EinsumGatedOp(Fp8EinsumOp):
"""Wrapper around two jnp.einsum for gated FFN."""
Expand Down Expand Up @@ -202,29 +228,78 @@ def __call__(self, equation: str, *args: JTensor) -> tuple[JTensor, JTensor]:
), f'k dtype has to be {comp_dtype}, but got {k.dtype} and {k_gated.dtype}'
x = jnp.asarray(x, comp_dtype)

y, x_qdq = self.quantized_einsum(equation, x, k, return_quantized_x=True)

theta = self.theta

k_gated_qdq = fp8_ops.in_qdq(
comp_dtype,
jnp.float8_e4m3fn,
k_gated,
theta.kernel_scale_gated,
theta.kernel_amax_history_gated,
)
y_gated_qdq = jnp.einsum(
equation,
x_qdq,
k_gated_qdq,
_dot_general=fp8_ops.dot_general_with_precision,
)
y_gated = fp8_ops.out_qdq(
comp_dtype,
jnp.float8_e5m2,
y_gated_qdq,
theta.output_grad_scale_gated,
theta.output_grad_amax_history_gated,
)
if self.use_direct_quant:
q_x, new_input_scale = fp8_ops.in_q(
comp_dtype, jnp.float8_e4m3fn, x, theta.input_scale, theta.input_amax_history
)
q_k, new_kernel_scale = fp8_ops.in_q(
comp_dtype, jnp.float8_e4m3fn, k, theta.kernel_scale, theta.kernel_amax_history
)
q_k_gated, new_kernel_scale_gated = fp8_ops.in_q(
comp_dtype, jnp.float8_e4m3fn, k_gated, theta.kernel_scale_gated, theta.kernel_amax_history_gated
)
common_args = (comp_dtype, q_x, new_input_scale)
main_fp8_metas = (
q_k,
new_kernel_scale,
theta.output_grad_scale,
theta.output_grad_amax_history
)
gated_fp8_metas = (
q_k_gated,
new_kernel_scale_gated,
theta.output_grad_scale_gated,
theta.output_grad_amax_history_gated
)

@quantized_dot_config(*common_args, *main_fp8_metas)
def _quantized_dot_general(lhs, rhs, dimension_numbers, precision=None, preferred_element_type=None):
pass

@quantized_dot_config(*common_args, *gated_fp8_metas)
def _quantized_dot_general_gated(lhs, rhs, dimension_numbers, precision=None, preferred_element_type=None):
pass

y = jnp.einsum(equation, x, k, _dot_general=_quantized_dot_general)
y_gated = jnp.einsum(equation, x, k_gated, _dot_general=_quantized_dot_general_gated)

y = out_dq(
dq_type=x.dtype,
lhs_scale=new_input_scale,
rhs_scale=new_kernel_scale,
out=y
)
y_gated = out_dq(
dq_type=x.dtype,
lhs_scale=new_input_scale,
rhs_scale=new_kernel_scale_gated,
out=y
)
else:
y, x_qdq = self.quantized_einsum(
equation, x, k, return_quantized_x=True
)
k_gated_qdq = fp8_ops.in_qdq(
comp_dtype,
jnp.float8_e4m3fn,
k_gated,
theta.kernel_scale_gated,
theta.kernel_amax_history_gated,
)
y_gated_qdq = jnp.einsum(
equation,
x_qdq,
k_gated_qdq,
_dot_general=fp8_ops.dot_general_with_precision,
)
y_gated = fp8_ops.out_qdq(
comp_dtype,
jnp.float8_e5m2,
y_gated_qdq,
theta.output_grad_scale_gated,
theta.output_grad_amax_history_gated,
)

return y, y_gated