Skip to content

Commit

Permalink
[Refactor - JIT] Gather Scatter new implementations (#1356)
Browse files Browse the repository at this point in the history
  • Loading branch information
nathanielsimard authored Feb 26, 2024
1 parent 576bb44 commit bdec8d5
Show file tree
Hide file tree
Showing 21 changed files with 609 additions and 437 deletions.
32 changes: 25 additions & 7 deletions crates/burn-wgpu/src/codegen/compilation.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ pub struct InplaceMapping {

#[derive(Default)]
pub struct CompilationSettings {
vectorization: Vectorization,
vectorization: Option<Vectorization>,
inplace_available: bool,
workgroup_size: WorkgroupSize,
}
Expand All @@ -43,7 +43,7 @@ impl CompilationSettings {
/// Compile the shader with vectorization enabled.
#[allow(dead_code)]
pub fn vectorize(mut self, vectorization: Vectorization) -> Self {
self.vectorization = vectorization;
self.vectorization = Some(vectorization);
self
}
/// Compile the shader with inplace enabled.
Expand Down Expand Up @@ -100,7 +100,9 @@ impl Compilation {

/// Performs the compilation with the provided [settings](CompilationSettings).
pub fn compile(mut self, settings: CompilationSettings) -> ComputeShader {
self.info.scope.vectorize(settings.vectorization);
if let Some(vectorization) = settings.vectorization {
self.info.scope.vectorize(vectorization);
}

self.register_inputs(&settings);
self.register_outputs(&settings);
Expand Down Expand Up @@ -137,7 +139,11 @@ impl Compilation {
for input in self.info.inputs.drain(..) {
match input {
InputInfo::Array { item, visibility } => {
let item = item.vectorize(settings.vectorization);
let item = if let Some(vectorization) = settings.vectorization {
item.vectorize(vectorization)
} else {
item
};

self.input_bindings.push(Binding {
item: bool_item(item),
Expand Down Expand Up @@ -178,7 +184,11 @@ impl Compilation {
for array in self.info.outputs.drain(..) {
match array {
OutputInfo::ArrayWrite { item, local } => {
let item = item.vectorize(settings.vectorization);
let item = if let Some(vectorization) = settings.vectorization {
item.vectorize(vectorization)
} else {
item
};
let elem_adapted = bool_item(item);

self.output_bindings.push(Binding {
Expand All @@ -194,15 +204,23 @@ impl Compilation {
index += 1;
}
OutputInfo::InputArrayWrite { item, input, local } => {
let item = item.vectorize(settings.vectorization);
let item = if let Some(vectorization) = settings.vectorization {
item.vectorize(vectorization)
} else {
item
};

self.info.scope.write_global(
Variable::Local(local, item, self.info.scope.depth),
Variable::GlobalInputArray(input, bool_item(item)),
);
}
OutputInfo::Array { item } => {
let item = item.vectorize(settings.vectorization);
let item = if let Some(vectorization) = settings.vectorization {
item.vectorize(vectorization)
} else {
item
};
let elem_adapted = bool_item(item);

self.output_bindings.push(Binding {
Expand Down
57 changes: 0 additions & 57 deletions crates/burn-wgpu/src/codegen/dialect/gpu/algorithm.rs

This file was deleted.

8 changes: 4 additions & 4 deletions crates/burn-wgpu/src/codegen/dialect/gpu/branch.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ use super::{Elem, Item, Scope, Variable};
use serde::{Deserialize, Serialize};

/// All branching types.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub enum Branch {
// An if statement.
If(If),
Expand All @@ -16,20 +16,20 @@ pub enum Branch {
Break,
}

#[derive(Debug, Clone, Serialize, Deserialize)]
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct If {
pub cond: Variable,
pub scope: Scope,
}

#[derive(Debug, Clone, Serialize, Deserialize)]
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct IfElse {
pub cond: Variable,
pub scope_if: Scope,
pub scope_else: Scope,
}

#[derive(Debug, Clone, Serialize, Deserialize)]
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct RangeLoop {
pub i: Variable,
pub start: Variable,
Expand Down
11 changes: 11 additions & 0 deletions crates/burn-wgpu/src/codegen/dialect/gpu/macros.rs
Original file line number Diff line number Diff line change
Expand Up @@ -197,6 +197,10 @@ macro_rules! gpu {
};
// out = input
($scope:expr, $out:ident = $input:ident) => {
gpu!($scope, $out = cast($input))
};
// out = cast(input)
($scope:expr, $out:ident = cast($input:expr)) => {
$scope.register($crate::codegen::dialect::gpu::Operator::Assign(
gpu!(unary $input, $out)
));
Expand All @@ -217,6 +221,13 @@ macro_rules! gpu {
out: $out.into(),
});
};
// out = len(array)
($scope:expr, $out:ident = len($input:expr)) => {
$scope.register($crate::codegen::dialect::gpu::Metadata::ArrayLength {
var: $input.into(),
out: $out.into(),
});
};
// range(start, end).for_each(|scope| { ... })
($scope:expr, range($start:expr, $end:expr).for_each($arg:expr)) => {
$crate::codegen::dialect::gpu::RangeLoop::register($scope, $start.into(), $end.into(), $arg);
Expand Down
20 changes: 12 additions & 8 deletions crates/burn-wgpu/src/codegen/dialect/gpu/operation.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ use serde::{Deserialize, Serialize};
/// Therefore, during tracing, only operators and procedures can be registered.
///
/// [Procedure] expansions can safely use all operation variants.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[allow(dead_code)] // Some variants might not be used with different flags
pub enum Operation {
Operator(Operator),
Expand All @@ -19,7 +19,7 @@ pub enum Operation {
}

/// All operators that can be used in a GPU compute shader.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[allow(dead_code)] // Some variants might not be used with different flags
pub enum Operator {
Add(BinaryOperator),
Expand Down Expand Up @@ -50,7 +50,7 @@ pub enum Operator {
}

/// All metadata that can be accessed in a shader.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub enum Metadata {
/// The stride of an array at the given dimension.
Stride {
Expand All @@ -64,35 +64,39 @@ pub enum Metadata {
var: Variable,
out: Variable,
},
ArrayLength {
var: Variable,
out: Variable,
},
}

#[derive(Debug, Clone, Serialize, Deserialize)]
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct BinaryOperator {
pub lhs: Variable,
pub rhs: Variable,
pub out: Variable,
}

#[derive(Debug, Clone, Serialize, Deserialize)]
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct UnaryOperator {
pub input: Variable,
pub out: Variable,
}

#[derive(Debug, Clone, Serialize, Deserialize)]
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct ClampOperator {
pub input: Variable,
pub min_value: Variable,
pub max_value: Variable,
pub out: Variable,
}

#[derive(Debug, Clone, Serialize, Deserialize)]
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct ReadGlobalOperator {
pub variable: Variable,
}

#[derive(Debug, Clone, Serialize, Deserialize)]
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct ReadGlobalWithLayoutOperator {
pub variable: Variable,
pub tensor_read_pos: usize,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ use crate::codegen::dialect::gpu::{macros::gpu, Item, Scope, Variable, Vectoriza
use serde::{Deserialize, Serialize};

/// Assign value to a variable based on a given condition.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct ConditionalAssign {
pub cond: Variable,
pub lhs: Variable,
Expand Down
19 changes: 8 additions & 11 deletions crates/burn-wgpu/src/codegen/dialect/gpu/procedure/base.rs
Original file line number Diff line number Diff line change
@@ -1,22 +1,17 @@
use super::{ConditionalAssign, Matmul, ReadGlobal, ReadGlobalWithLayout, WriteGlobal};
use super::{
ConditionalAssign, IndexOffsetGlobalWithLayout, ReadGlobal, ReadGlobalWithLayout, WriteGlobal,
};
use crate::codegen::dialect::gpu::Vectorization;
use serde::{Deserialize, Serialize};

/// Tensor operations that can't be executed with a simple [operator](super::super::Operator) should use a
/// procedure.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub enum Procedure {
/// Read a global array with the given layout.
///
/// Crucial to read arrays that aren't contiguous and to perform correct broadcasting.
ReadGlobalWithLayout(ReadGlobalWithLayout),
/// Read a global array.
IndexOffsetGlobalWithLayout(IndexOffsetGlobalWithLayout),
ReadGlobal(ReadGlobal),
/// Matrix Multiplication procedure.
Matmul(Matmul),
/// Write to a global array.
WriteGlobal(WriteGlobal),
/// Assign value to a variable based on a given condition.
ConditionalAssign(ConditionalAssign),
}

Expand All @@ -27,11 +22,13 @@ impl Procedure {
Procedure::ReadGlobalWithLayout(op.vectorize(vectorization))
}
Procedure::ReadGlobal(op) => Procedure::ReadGlobal(op.vectorize(vectorization)),
Procedure::Matmul(op) => Procedure::Matmul(op.vectorize(vectorization)),
Procedure::WriteGlobal(op) => Procedure::WriteGlobal(op.vectorize(vectorization)),
Procedure::ConditionalAssign(proc) => {
Procedure::ConditionalAssign(proc.vectorize(vectorization))
}
Procedure::IndexOffsetGlobalWithLayout(op) => {
Procedure::IndexOffsetGlobalWithLayout(op.vectorize(vectorization))
}
}
}
}
Loading

0 comments on commit bdec8d5

Please sign in to comment.