Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Refactor - JIT] Gather Scatter new implementations #1356

Merged
merged 9 commits into from
Feb 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 25 additions & 7 deletions crates/burn-wgpu/src/codegen/compilation.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ pub struct InplaceMapping {

#[derive(Default)]
pub struct CompilationSettings {
vectorization: Vectorization,
vectorization: Option<Vectorization>,
inplace_available: bool,
workgroup_size: WorkgroupSize,
}
Expand All @@ -43,7 +43,7 @@ impl CompilationSettings {
/// Compile the shader with vectorization enabled.
#[allow(dead_code)]
pub fn vectorize(mut self, vectorization: Vectorization) -> Self {
self.vectorization = vectorization;
self.vectorization = Some(vectorization);
self
}
/// Compile the shader with inplace enabled.
Expand Down Expand Up @@ -100,7 +100,9 @@ impl Compilation {

/// Performs the compilation with the provided [settings](CompilationSettings).
pub fn compile(mut self, settings: CompilationSettings) -> ComputeShader {
self.info.scope.vectorize(settings.vectorization);
if let Some(vectorization) = settings.vectorization {
self.info.scope.vectorize(vectorization);
}

self.register_inputs(&settings);
self.register_outputs(&settings);
Expand Down Expand Up @@ -137,7 +139,11 @@ impl Compilation {
for input in self.info.inputs.drain(..) {
match input {
InputInfo::Array { item, visibility } => {
let item = item.vectorize(settings.vectorization);
let item = if let Some(vectorization) = settings.vectorization {
item.vectorize(vectorization)
} else {
item
};

self.input_bindings.push(Binding {
item: bool_item(item),
Expand Down Expand Up @@ -178,7 +184,11 @@ impl Compilation {
for array in self.info.outputs.drain(..) {
match array {
OutputInfo::ArrayWrite { item, local } => {
let item = item.vectorize(settings.vectorization);
let item = if let Some(vectorization) = settings.vectorization {
item.vectorize(vectorization)
} else {
item
};
let elem_adapted = bool_item(item);

self.output_bindings.push(Binding {
Expand All @@ -194,15 +204,23 @@ impl Compilation {
index += 1;
}
OutputInfo::InputArrayWrite { item, input, local } => {
let item = item.vectorize(settings.vectorization);
let item = if let Some(vectorization) = settings.vectorization {
item.vectorize(vectorization)
} else {
item
};

self.info.scope.write_global(
Variable::Local(local, item, self.info.scope.depth),
Variable::GlobalInputArray(input, bool_item(item)),
);
}
OutputInfo::Array { item } => {
let item = item.vectorize(settings.vectorization);
let item = if let Some(vectorization) = settings.vectorization {
item.vectorize(vectorization)
} else {
item
};
let elem_adapted = bool_item(item);

self.output_bindings.push(Binding {
Expand Down
57 changes: 0 additions & 57 deletions crates/burn-wgpu/src/codegen/dialect/gpu/algorithm.rs

This file was deleted.

8 changes: 4 additions & 4 deletions crates/burn-wgpu/src/codegen/dialect/gpu/branch.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ use super::{Elem, Item, Scope, Variable};
use serde::{Deserialize, Serialize};

/// All branching types.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub enum Branch {
// An if statement.
If(If),
Expand All @@ -16,20 +16,20 @@ pub enum Branch {
Break,
}

#[derive(Debug, Clone, Serialize, Deserialize)]
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct If {
pub cond: Variable,
pub scope: Scope,
}

#[derive(Debug, Clone, Serialize, Deserialize)]
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct IfElse {
pub cond: Variable,
pub scope_if: Scope,
pub scope_else: Scope,
}

#[derive(Debug, Clone, Serialize, Deserialize)]
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct RangeLoop {
pub i: Variable,
pub start: Variable,
Expand Down
11 changes: 11 additions & 0 deletions crates/burn-wgpu/src/codegen/dialect/gpu/macros.rs
Original file line number Diff line number Diff line change
Expand Up @@ -197,6 +197,10 @@ macro_rules! gpu {
};
// out = input
($scope:expr, $out:ident = $input:ident) => {
gpu!($scope, $out = cast($input))
};
// out = cast(input)
($scope:expr, $out:ident = cast($input:expr)) => {
$scope.register($crate::codegen::dialect::gpu::Operator::Assign(
gpu!(unary $input, $out)
));
Expand All @@ -217,6 +221,13 @@ macro_rules! gpu {
out: $out.into(),
});
};
// out = len(array)
($scope:expr, $out:ident = len($input:expr)) => {
$scope.register($crate::codegen::dialect::gpu::Metadata::ArrayLength {
var: $input.into(),
out: $out.into(),
});
};
// range(start, end).for_each(|scope| { ... })
($scope:expr, range($start:expr, $end:expr).for_each($arg:expr)) => {
$crate::codegen::dialect::gpu::RangeLoop::register($scope, $start.into(), $end.into(), $arg);
Expand Down
20 changes: 12 additions & 8 deletions crates/burn-wgpu/src/codegen/dialect/gpu/operation.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ use serde::{Deserialize, Serialize};
/// Therefore, during tracing, only operators and procedures can be registered.
///
/// [Procedure] expansions can safely use all operation variants.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[allow(dead_code)] // Some variants might not be used with different flags
pub enum Operation {
Operator(Operator),
Expand All @@ -19,7 +19,7 @@ pub enum Operation {
}

/// All operators that can be used in a GPU compute shader.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[allow(dead_code)] // Some variants might not be used with different flags
pub enum Operator {
Add(BinaryOperator),
Expand Down Expand Up @@ -50,7 +50,7 @@ pub enum Operator {
}

/// All metadata that can be access in a shader.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub enum Metadata {
/// The stride of an array at the given dimension.
Stride {
Expand All @@ -64,35 +64,39 @@ pub enum Metadata {
var: Variable,
out: Variable,
},
ArrayLength {
var: Variable,
out: Variable,
},
}

#[derive(Debug, Clone, Serialize, Deserialize)]
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct BinaryOperator {
pub lhs: Variable,
pub rhs: Variable,
pub out: Variable,
}

#[derive(Debug, Clone, Serialize, Deserialize)]
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct UnaryOperator {
pub input: Variable,
pub out: Variable,
}

#[derive(Debug, Clone, Serialize, Deserialize)]
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct ClampOperator {
pub input: Variable,
pub min_value: Variable,
pub max_value: Variable,
pub out: Variable,
}

#[derive(Debug, Clone, Serialize, Deserialize)]
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct ReadGlobalOperator {
pub variable: Variable,
}

#[derive(Debug, Clone, Serialize, Deserialize)]
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct ReadGlobalWithLayoutOperator {
pub variable: Variable,
pub tensor_read_pos: usize,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ use crate::codegen::dialect::gpu::{macros::gpu, Item, Scope, Variable, Vectoriza
use serde::{Deserialize, Serialize};

/// Assign value to a variable based on a given condition.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct ConditionalAssign {
pub cond: Variable,
pub lhs: Variable,
Expand Down
19 changes: 8 additions & 11 deletions crates/burn-wgpu/src/codegen/dialect/gpu/procedure/base.rs
Original file line number Diff line number Diff line change
@@ -1,22 +1,17 @@
use super::{ConditionalAssign, Matmul, ReadGlobal, ReadGlobalWithLayout, WriteGlobal};
use super::{
ConditionalAssign, IndexOffsetGlobalWithLayout, ReadGlobal, ReadGlobalWithLayout, WriteGlobal,
};
use crate::codegen::dialect::gpu::Vectorization;
use serde::{Deserialize, Serialize};

/// Tensor operations that can't be executed with a simple [operator](super::super::Operator) should use a
/// procedure.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub enum Procedure {
/// Read a global array with the given layout.
///
/// Crucial to read arrays that aren't contiguous and to perform correct broadcasting.
ReadGlobalWithLayout(ReadGlobalWithLayout),
/// Read a global array.
IndexOffsetGlobalWithLayout(IndexOffsetGlobalWithLayout),
ReadGlobal(ReadGlobal),
/// Matrix Multiplication procedure.
Matmul(Matmul),
/// Write to a global array.
WriteGlobal(WriteGlobal),
/// Assign value to a variable based on a given condition.
ConditionalAssign(ConditionalAssign),
}

Expand All @@ -27,11 +22,13 @@ impl Procedure {
Procedure::ReadGlobalWithLayout(op.vectorize(vectorization))
}
Procedure::ReadGlobal(op) => Procedure::ReadGlobal(op.vectorize(vectorization)),
Procedure::Matmul(op) => Procedure::Matmul(op.vectorize(vectorization)),
Procedure::WriteGlobal(op) => Procedure::WriteGlobal(op.vectorize(vectorization)),
Procedure::ConditionalAssign(proc) => {
Procedure::ConditionalAssign(proc.vectorize(vectorization))
}
Procedure::IndexOffsetGlobalWithLayout(op) => {
Procedure::IndexOffsetGlobalWithLayout(op.vectorize(vectorization))
}
}
}
}
Loading
Loading