Migrate/jit/mask (#1456)
louisfd authored Mar 12, 2024
1 parent fa0dfec commit 278fcb3
Showing 15 changed files with 373 additions and 303 deletions.
10 changes: 10 additions & 0 deletions crates/burn-jit/src/codegen/dialect/gpu/macros.rs
@@ -145,6 +145,16 @@ macro_rules! gpu {
             gpu!(binary $lhs, $rhs, $out)
         ));
     };
+    // out = lhs != rhs
+    ($scope:expr, $out:ident = $lhs:ident != $rhs:expr) => {
+        gpu!($scope, $out = not_equal($lhs, $rhs))
+    };
+    // out = not_equal(lhs, rhs)
+    ($scope:expr, $out:ident = not_equal($lhs:expr, $rhs:expr)) => {
+        $scope.register($crate::codegen::dialect::gpu::Operator::NotEqual(
+            gpu!(binary $lhs, $rhs, $out)
+        ));
+    };
     // out = lhs > rhs
     ($scope:expr, $out:ident = $lhs:ident > $rhs:expr) => {
         gpu!($scope, $out = greater($lhs, $rhs))
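
Note: the two new arms give the gpu! DSL a `!=` comparison alongside the existing equal/greater/lower arms. A minimal usage sketch (hedged: the scope setup follows patterns used by other shaders in this crate, and the variable names are illustrative, not an excerpt from this commit):

    // Inside a shader's expand step, with a gpu-dialect scope in hand:
    let lhs = scope.create_local(Elem::Float);
    let rhs = scope.create_local(Elem::Float);
    let out = scope.create_local(Elem::Bool);

    // Expands to: scope.register(Operator::NotEqual(gpu!(binary lhs, rhs, out))).
    gpu!(scope, out = lhs != rhs);
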
1 change: 1 addition & 0 deletions crates/burn-jit/src/codegen/dialect/gpu/operation.rs
@@ -39,6 +39,7 @@ pub enum Operator {
     Erf(UnaryOperator),
     Recip(UnaryOperator),
     Equal(BinaryOperator),
+    NotEqual(BinaryOperator),
     Lower(BinaryOperator),
     Clamp(ClampOperator),
     Greater(BinaryOperator),
1 change: 1 addition & 0 deletions crates/burn-jit/src/codegen/dialect/gpu/vectorization.rs
@@ -55,6 +55,7 @@ impl Operator {
             Operator::Erf(op) => Operator::Erf(op.vectorize(vectorization)),
             Operator::Recip(op) => Operator::Recip(op.vectorize(vectorization)),
             Operator::Equal(op) => Operator::Equal(op.vectorize(vectorization)),
+            Operator::NotEqual(op) => Operator::NotEqual(op.vectorize(vectorization)),
             Operator::Lower(op) => Operator::Lower(op.vectorize(vectorization)),
             Operator::Clamp(op) => Operator::Clamp(op.vectorize(vectorization)),
             Operator::Greater(op) => Operator::Greater(op.vectorize(vectorization)),
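
Note: Operator::vectorize dispatches to each variant's operands, so NotEqual needs only this one match arm. A hedged sketch of the underlying pattern for binary operators (field names mirror BinaryOperator as it appears in this diff; the real impl lives elsewhere in the crate and may differ in detail):

    impl BinaryOperator {
        pub fn vectorize(&self, vectorization: Vectorization) -> Self {
            // Widen lhs, rhs, and out to the requested vector width.
            Self {
                lhs: self.lhs.vectorize(vectorization),
                rhs: self.rhs.vectorize(vectorization),
                out: self.out.vectorize(vectorization),
            }
        }
    }
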
5 changes: 5 additions & 0 deletions crates/burn-jit/src/fusion/tracing/builder.rs
@@ -311,6 +311,11 @@ impl TraceBuilder {
                 &mut local_tensor_ids_input,
                 &mut local_tensor_ids_output,
             ),
+            gpu::Operator::NotEqual(op) => mark_binary(
+                op,
+                &mut local_tensor_ids_input,
+                &mut local_tensor_ids_output,
+            ),
             gpu::Operator::Sqrt(op) => mark_unary(
                 op,
                 &mut local_tensor_ids_input,
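
Note: the trace builder walks every registered operation and records which locals are read and written, so the fusion engine knows each op's data dependencies; NotEqual is bookkept exactly like the other binary comparisons. A hedged, self-contained sketch of the idea (mark_binary's real signature is not shown in this diff; plain usize ids stand in for burn's variable ids):

    // Toy restatement: a binary op reads lhs and rhs and writes out.
    struct BinOp {
        lhs: usize,
        rhs: usize,
        out: usize,
    }

    fn mark_binary(op: &BinOp, inputs: &mut Vec<usize>, outputs: &mut Vec<usize>) {
        inputs.push(op.lhs);
        inputs.push(op.rhs);
        outputs.push(op.out);
    }
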
1 change: 0 additions & 1 deletion crates/burn-jit/src/kernel/contiguous.rs
@@ -105,7 +105,6 @@ impl IntoContiguousShader {

         let offset_input = scope.zero(Elem::UInt);

-        // Batch offset for the lhs & rhs matrices.
         IndexOffsetGlobalWithLayout {
             tensors: vec![tensor],
             indexes: vec![offset_input],
81 changes: 36 additions & 45 deletions crates/burn-jit/src/kernel/mask/mask_fill.rs
@@ -1,15 +1,12 @@
 use crate::{
-    compute::StaticKernel,
+    codegen::{EagerHandle, Execution, WorkgroupLaunch},
     element::JitElement,
-    kernel::{build_info, elemwise_workgroup, KernelSettings, WORKGROUP_DEFAULT},
-    kernel_wgsl,
     ops::numeric::empty_device,
     tensor::JitTensor,
     Runtime,
 };

-kernel_wgsl!(MaskFill, "../../template/mask/fill.wgsl");
-kernel_wgsl!(MaskFillInplace, "../../template/mask/fill_inplace.wgsl");
+use super::{MaskFill, MaskInplaceEagerKernel, MaskReadOnlyEagerKernel};

 #[derive(Clone, Copy, Debug)]
 /// Define how to run the mask fill kernel.
@@ -37,58 +34,52 @@ pub fn mask_fill<R: Runtime, E: JitElement, const D: usize>(
     }
 }

-fn mask_fill_readonly<R: Runtime, E: JitElement, const D: usize>(
-    input: JitTensor<R, E, D>,
-    mask: JitTensor<R, u32, D>,
-    value: E,
-) -> JitTensor<R, E, D> {
-    let num_elems = input.shape.num_elements();
+fn mask_fill_readonly<R: Runtime, EI: JitElement, EM: JitElement, const D: usize>(
+    input: JitTensor<R, EI, D>,
+    mask: JitTensor<R, EM, D>,
+    value: EI,
+) -> JitTensor<R, EI, D> {
+    let client = input.client.clone();
+    let kernel = MaskReadOnlyEagerKernel::<MaskFill, R, EI, EM>::new(false);

     let output = empty_device(
         input.client.clone(),
         input.device.clone(),
         input.shape.clone(),
     );

-    let value_handle = output.client.create(E::as_bytes(&[value]));
-    let kernel = StaticKernel::<
-        KernelSettings<MaskFill, E, i32, WORKGROUP_DEFAULT, WORKGROUP_DEFAULT, 1>,
-    >::new(elemwise_workgroup(num_elems, WORKGROUP_DEFAULT));
-    let mask = JitTensor::new(mask.client, mask.device, mask.shape, mask.handle);
-    let info = build_info(&[&input, &mask, &output]);
-    let info_handle = input.client.create(bytemuck::cast_slice(&info));
-
-    input.client.execute(
-        Box::new(kernel),
-        &[
-            &input.handle,
-            &value_handle,
-            &mask.handle,
+    Execution::start(kernel, client)
+        .inputs(&[
+            EagerHandle::<R>::new(&input.handle, &input.strides, &input.shape.dims),
+            EagerHandle::new(&mask.handle, &mask.strides, &mask.shape.dims),
+        ])
+        .outputs(&[EagerHandle::new(
             &output.handle,
-            &info_handle,
-        ],
-    );
+            &output.strides,
+            &output.shape.dims,
+        )])
+        .with_scalars(&[value])
+        .execute(WorkgroupLaunch::Output { pos: 0 });

     output
 }

-fn mask_fill_inplace<R: Runtime, E: JitElement, const D: usize>(
-    input: JitTensor<R, E, D>,
-    mask: JitTensor<R, u32, D>,
-    value: E,
-) -> JitTensor<R, E, D> {
-    let num_elems = input.shape.num_elements();
-    let value_handle = input.client.create(E::as_bytes(&[value]));
-    let kernel = StaticKernel::<
-        KernelSettings<MaskFillInplace, E, i32, WORKGROUP_DEFAULT, WORKGROUP_DEFAULT, 1>,
-    >::new(elemwise_workgroup(num_elems, WORKGROUP_DEFAULT));
-    let mask = JitTensor::new(mask.client, mask.device, mask.shape, mask.handle);
-    let info = build_info(&[&input, &mask]);
-    let info_handle = input.client.create(bytemuck::cast_slice(&info));
+fn mask_fill_inplace<R: Runtime, EI: JitElement, EM: JitElement, const D: usize>(
+    input: JitTensor<R, EI, D>,
+    mask: JitTensor<R, EM, D>,
+    value: EI,
+) -> JitTensor<R, EI, D> {
+    let kernel = MaskInplaceEagerKernel::<MaskFill, R, EI, EM>::new(false);

-    input.client.execute(
-        Box::new(kernel),
-        &[&input.handle, &value_handle, &mask.handle, &info_handle],
-    );
+    let client = input.client.clone();
+
+    Execution::start(kernel, client)
+        .inputs(&[
+            EagerHandle::<R>::new(&input.handle, &input.strides, &input.shape.dims),
+            EagerHandle::new(&mask.handle, &mask.strides, &mask.shape.dims),
+        ])
+        .with_scalars(&[value])
+        .execute(WorkgroupLaunch::Input { pos: 0 });

     input
 }
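
Note: both mask_fill variants now use the same eager-execution builder flow that the rest of this commit migrates to. Annotated restatement of the read-only call chain above (comments added; code otherwise as in the diff):

    Execution::start(kernel, client)      // pair the compiled eager kernel with the tensor's compute client
        .inputs(&[
            EagerHandle::<R>::new(&input.handle, &input.strides, &input.shape.dims),
            EagerHandle::new(&mask.handle, &mask.strides, &mask.shape.dims),
        ])
        .outputs(&[EagerHandle::new(&output.handle, &output.strides, &output.shape.dims)])
        .with_scalars(&[value])           // scalar arguments, here the fill value
        .execute(WorkgroupLaunch::Output { pos: 0 }); // size the launch from output tensor 0

The in-place variant skips .outputs(...) and launches with WorkgroupLaunch::Input { pos: 0 }, since the result is written into the first input buffer.
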
85 changes: 36 additions & 49 deletions crates/burn-jit/src/kernel/mask/mask_where.rs
@@ -1,15 +1,12 @@
 use crate::{
-    compute::StaticKernel,
+    codegen::{EagerHandle, Execution, WorkgroupLaunch},
     element::JitElement,
-    kernel::{build_info, elemwise_workgroup, KernelSettings, WORKGROUP_DEFAULT},
-    kernel_wgsl,
     ops::numeric::empty_device,
     tensor::JitTensor,
     Runtime,
 };

-kernel_wgsl!(MaskWhere, "../../template/mask/where.wgsl");
-kernel_wgsl!(MaskWhereInplace, "../../template/mask/where_inplace.wgsl");
+use super::{MaskInplaceEagerKernel, MaskReadOnlyEagerKernel, MaskWhere};

 #[derive(Clone, Copy, Debug)]
 /// Define how to run the mask where kernel.
@@ -40,63 +37,53 @@ pub fn mask_where<R: Runtime, E: JitElement, const D: usize>(
     }
 }

-fn mask_where_readonly<R: Runtime, E: JitElement, const D: usize>(
-    input: JitTensor<R, E, D>,
-    mask: JitTensor<R, u32, D>,
-    value: JitTensor<R, E, D>,
-) -> JitTensor<R, E, D> {
-    let num_elems = input.shape.num_elements();
+fn mask_where_readonly<R: Runtime, EI: JitElement, EM: JitElement, const D: usize>(
+    input: JitTensor<R, EI, D>,
+    mask: JitTensor<R, EM, D>,
+    value: JitTensor<R, EI, D>,
+) -> JitTensor<R, EI, D> {
+    let client = input.client.clone();
+    let kernel = MaskReadOnlyEagerKernel::<MaskWhere, R, EI, EM>::new(false);

     let output = empty_device(
         input.client.clone(),
         input.device.clone(),
         input.shape.clone(),
     );

-    let kernel = StaticKernel::<
-        KernelSettings<MaskWhere, E, i32, WORKGROUP_DEFAULT, WORKGROUP_DEFAULT, 1>,
-    >::new(elemwise_workgroup(num_elems, WORKGROUP_DEFAULT));
-    let mask = JitTensor::new(mask.client, mask.device, mask.shape, mask.handle);
-    let info = build_info(&[&input, &value, &mask, &output]);
-    let info_handle = input.client.create(bytemuck::cast_slice(&info));
-
-    input.client.execute(
-        Box::new(kernel),
-        &[
-            &input.handle,
-            &value.handle,
-            &mask.handle,
+    Execution::start(kernel, client)
+        .inputs(&[
+            EagerHandle::<R>::new(&input.handle, &input.strides, &input.shape.dims),
+            EagerHandle::new(&mask.handle, &mask.strides, &mask.shape.dims),
+            EagerHandle::new(&value.handle, &value.strides, &value.shape.dims),
+        ])
+        .outputs(&[EagerHandle::new(
             &output.handle,
-            &info_handle,
-        ],
-    );
+            &output.strides,
+            &output.shape.dims,
+        )])
+        .execute(WorkgroupLaunch::Output { pos: 0 });

     output
 }

-fn mask_where_inplace<R: Runtime, E: JitElement, const D: usize>(
-    input: JitTensor<R, E, D>,
-    mask: JitTensor<R, u32, D>,
-    value: JitTensor<R, E, D>,
+fn mask_where_inplace<R: Runtime, EI: JitElement, EM: JitElement, const D: usize>(
+    input: JitTensor<R, EI, D>,
+    mask: JitTensor<R, EM, D>,
+    value: JitTensor<R, EI, D>,
     reverse: bool,
-) -> JitTensor<R, E, D> {
-    let kernel = StaticKernel::<
-        KernelSettings<MaskWhereInplace, E, i32, WORKGROUP_DEFAULT, WORKGROUP_DEFAULT, 1>,
-    >::new(elemwise_workgroup(
-        input.shape.num_elements(),
-        WORKGROUP_DEFAULT,
-    ));
-    let mask = JitTensor::new(mask.client, mask.device, mask.shape, mask.handle);
-    let mut info = build_info(&[&input, &value, &mask]);
-    info.push(match reverse {
-        true => 1,
-        false => 0,
-    });
-    let info_handle = input.client.create(bytemuck::cast_slice(&info));
+) -> JitTensor<R, EI, D> {
+    let kernel = MaskInplaceEagerKernel::<MaskWhere, R, EI, EM>::new(reverse);

-    input.client.execute(
-        Box::new(kernel),
-        &[&input.handle, &value.handle, &mask.handle, &info_handle],
-    );
+    let client = input.client.clone();
+
+    Execution::start(kernel, client)
+        .inputs(&[
+            EagerHandle::<R>::new(&input.handle, &input.strides, &input.shape.dims),
+            EagerHandle::new(&mask.handle, &mask.strides, &mask.shape.dims),
+            EagerHandle::new(&value.handle, &value.strides, &value.shape.dims),
+        ])
+        .execute(WorkgroupLaunch::Input { pos: 0 });

     input
 }
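
Note: the reverse flag is now threaded through MaskInplaceEagerKernel::new(reverse) rather than pushed into the info buffer. Semantically it flips the mask test, which lets a caller swap the operands and mutate whichever tensor is safe to reuse. A hedged element-wise reference on plain slices (a CPU restatement for clarity, not code from this commit):

    // Where the (possibly reversed) mask holds, `value` wins; elsewhere `input` is kept.
    fn mask_where_reference(input: &mut [f32], mask: &[bool], value: &[f32], reverse: bool) {
        for i in 0..input.len() {
            if mask[i] != reverse {
                input[i] = value[i];
            }
        }
    }
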
2 changes: 2 additions & 0 deletions crates/burn-jit/src/kernel/mask/mod.rs
@@ -1,8 +1,10 @@
 mod base;
 mod mask_fill;
 mod mask_where;
+mod shader;

 pub(crate) use base::*;
+pub(crate) use shader::*;

 pub use mask_fill::*;
 pub use mask_where::*;