Add prod and prod_dim tensor ops #1460

Merged · merged 3 commits on Mar 12, 2024
116 changes: 59 additions & 57 deletions burn-book/src/building-blocks/tensor.md
@@ -134,37 +134,37 @@ for the sake of simplicity, we ignore type signatures. For more details, refer t

Those operations are available for all tensor kinds: `Int`, `Float`, and `Bool`.

| Burn | PyTorch Equivalent |
| ------------------------------------- | ------------------------------------ |
| `Tensor::cat(tensors, dim)` | `torch.cat(tensors, dim)` |
| `Tensor::empty(shape, device)` | `torch.empty(shape, device=device)` |
| `Tensor::from_primitive(primitive)` | N/A |
| `Tensor::stack(tensors, dim)` | `torch.stack(tensors, dim)` |
| `tensor.all()` | `tensor.all()` |
| `tensor.all_dim(dim)` | `tensor.all(dim)` |
| `tensor.any()` | `tensor.any()` |
| `tensor.any_dim(dim)` | `tensor.any(dim)` |
| `tensor.chunk(num_chunks, dim)` | `tensor.chunk(num_chunks, dim)` |
| `tensor.device()` | `tensor.device` |
| `tensor.dims()` | `tensor.size()` |
| `tensor.equal(other)` | `x == y` |
| `tensor.flatten(start_dim, end_dim)` | `tensor.flatten(start_dim, end_dim)` |
| `tensor.into_data()` | N/A |
| `tensor.into_primitive()` | N/A |
| `tensor.into_scalar()` | `tensor.item()` |
| `tensor.narrow(dim, start, length)` | `tensor.narrow(dim, start, length)` |
| `tensor.not_equal(other)` | `x != y` |
| `tensor.permute(axes)` | `tensor.permute(axes)` |
| `tensor.repeat(2, 4)` | `tensor.repeat([1, 1, 4])` |
| `tensor.reshape(shape)` | `tensor.view(shape)` |
| `tensor.shape()` | `tensor.shape` |
| `tensor.slice(ranges)` | `tensor[(*ranges,)]` |
| `tensor.slice_assign(ranges, values)` | `tensor[(*ranges,)] = values` |
| `tensor.squeeze(dim)` | `tensor.squeeze(dim)` |
| `tensor.to_data()` | N/A |
| `tensor.to_device(device)` | `tensor.to(device)` |
| `tensor.unsqueeze()` | `tensor.unsqueeze(0)` |
| `tensor.unsqueeze_dim(dim)` | `tensor.unsqueeze(dim)` |
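
The mapping is mostly one-to-one. As a rough illustration, here is a short sketch of a few of the rows above on the Burn side; the helper name `basic_ops` is hypothetical, and any backend `B` with 2-d float tensors of matching shape is assumed:

```rust
use burn::tensor::{backend::Backend, Tensor};

// Hypothetical helper (not from this PR) exercising a few rows of the table above.
fn basic_ops<B: Backend>(a: Tensor<B, 2>, b: Tensor<B, 2>) {
    // torch.cat([a, b], dim=0)
    let concatenated = Tensor::cat(vec![a.clone(), b], 0);

    // tensor.unsqueeze(0): [rows, cols] -> [1, rows, cols]
    let batched: Tensor<B, 3> = a.clone().unsqueeze();

    // tensor.permute(1, 0): swap the two axes
    let transposed = a.permute([1, 0]);

    println!(
        "{:?} {:?} {:?}",
        concatenated.dims(),
        batched.dims(),
        transposed.dims()
    );
}
```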

### Numeric Operations

@@ -203,13 +203,13 @@ Those operations are available for numeric tensor kinds: `Float` and `Int`.
| `tensor.mask_fill(mask, value)` | `tensor.masked_fill(mask, value)` |
| `tensor.mask_where(mask, value_tensor)` | `torch.where(mask, value_tensor, tensor)` |
| `tensor.max()` | `tensor.max()` |
| `tensor.max_dim(dim)` | `tensor.max(dim, keepdim=True)` |
| `tensor.max_dim_with_indices(dim)` | N/A |
| `tensor.max_pair(other)` | `torch.Tensor.max(a,b)` |
| `tensor.mean()` | `tensor.mean()` |
| `tensor.mean_dim(dim)` | `tensor.mean(dim, keepdim=True)` |
| `tensor.min()` | `tensor.min()` |
| `tensor.min_dim(dim)` | `tensor.min(dim, keepdim=True)` |
| `tensor.min_dim_with_indices(dim)` | N/A |
| `tensor.min_pair(other)` | `torch.Tensor.min(a,b)` |
| `tensor.mul(other)` or `tensor * other` | `tensor * other` |
@@ -218,14 +218,16 @@ Those operations are available for numeric tensor kinds: `Float` and `Int`.
| `tensor.not_equal_elem(scalar)` | `tensor.ne(scalar)` |
| `tensor.powf(other)` or `tensor.powi(intother)` | `tensor.pow(other)` |
| `tensor.powf_scalar(scalar)` or `tensor.powi_scalar(intscalar)` | `tensor.pow(scalar)` |
| `tensor.prod()` | `tensor.prod()` |
| `tensor.prod_dim(dim)` | `tensor.prod(dim, keepdim=True)` |
| `tensor.scatter(dim, indices, values)` | `tensor.scatter_add(dim, indices, values)` |
| `tensor.select(dim, indices)` | `tensor.index_select(dim, indices)` |
| `tensor.select_assign(dim, indices, values)` | N/A |
| `tensor.sign()` | `tensor.sign()` |
| `tensor.sub(other)` or `tensor - other` | `tensor - other` |
| `tensor.sub_scalar(scalar)` or `tensor - scalar` | `tensor - scalar` |
| `tensor.sum()` | `tensor.sum()` |
| `tensor.sum_dim(dim)` | `tensor.sum(dim, keepdim=True)` |
| `tensor.tril(diagonal)` | `torch.tril(tensor, diagonal)` |
| `tensor.triu(diagonal)` | `torch.triu(tensor, diagonal)` |
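
`prod` and `prod_dim` are the two ops added in this PR. A minimal usage sketch follows; the function name `product_example` is hypothetical, and the `from_floats(data, &device)` constructor plus a generic backend `B` are assumptions, not part of the book change:

```rust
use burn::tensor::{backend::Backend, Tensor};

// Hypothetical example, not part of the book change in this PR.
fn product_example<B: Backend>(device: &B::Device) {
    let tensor = Tensor::<B, 2>::from_floats([[1.0, 2.0], [3.0, 4.0]], device);

    // Product of every element: a single-element, 1-d tensor holding 24.0.
    let _total = tensor.clone().prod();

    // Product along dim 1, keeping the reduced dim: shape [2, 1] holding [[2.0], [12.0]].
    let _per_row = tensor.prod_dim(1);
}
```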

@@ -269,35 +269,35 @@ Those operations are only available for `Int` tensors.
| ------------------------------------------------ | ------------------------------------------------------- |
| `tensor.arange(5..10, device) ` | `tensor.arange(start=5, end=10, device=device)` |
| `tensor.arange_step(5..10, 2, device)` | `tensor.arange(start=5, end=10, step=2, device=device)` |
| `tensor.float()` | `tensor.to(torch.float)` |
| `tensor.from_ints(ints)` | N/A |
| `tensor.int_random(shape, distribution, device)` | N/A |
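
A possible sketch of the `Int`-only constructors (the helper name `int_ranges` is hypothetical; `arange` taking a `Range<i64>` plus a device, as listed above, is assumed):

```rust
use burn::tensor::{backend::Backend, Int, Tensor};

// Hypothetical example, not from this PR.
fn int_ranges<B: Backend>(device: &B::Device) {
    // torch.arange(start=5, end=10, device=device) -> [5, 6, 7, 8, 9]
    let a = Tensor::<B, 1, Int>::arange(5..10, device);

    // torch.arange(start=5, end=10, step=2, device=device) -> [5, 7, 9]
    let b = Tensor::<B, 1, Int>::arange_step(5..10, 2, device);

    println!("{a}\n{b}");
}
```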

### Bool Operations

Those operations are only available for `Bool` tensors.

| Burn API | PyTorch Equivalent |
| ------------------- | ------------------------------- |
| `tensor.float()` | `tensor.to(torch.float)` |
| `tensor.int()` | `tensor.to(torch.long)` |
| `tensor.not()` | `tensor.logical_not()` |
| `tensor.argwhere()` | `tensor.argwhere()` |
| `tensor.nonzero()` | `tensor.nonzero(as_tuple=True)` |

## Activation Functions

| Burn API | PyTorch Equivalent |
| ---------------------------------------- | ------------------------------------------ |
| `activation::gelu(tensor)` | `nn.functional.gelu(tensor)` |
| `activation::log_sigmoid(tensor)` | `nn.functional.log_sigmoid(tensor)` |
| `activation::log_softmax(tensor, dim)` | `nn.functional.log_softmax(tensor, dim)` |
| `activation::mish(tensor)` | `nn.functional.mish(tensor)` |
| `activation::prelu(tensor,alpha)` | `nn.functional.prelu(tensor,weight)` |
| `activation::quiet_softmax(tensor, dim)` | `nn.functional.quiet_softmax(tensor, dim)` |
| `activation::relu(tensor)` | `nn.functional.relu(tensor)` |
| `activation::sigmoid(tensor)` | `nn.functional.sigmoid(tensor)` |
| `activation::silu(tensor)` | `nn.functional.silu(tensor)` |
| `activation::softmax(tensor, dim)` | `nn.functional.softmax(tensor, dim)` |
| `activation::softplus(tensor, beta)` | `nn.functional.softplus(tensor, beta)` |
| `activation::tanh(tensor)` | `nn.functional.tanh(tensor)` |
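
These live as free functions in `burn::tensor::activation` rather than as tensor methods. A small sketch (the function name `classify` is hypothetical; a 2-d float tensor of logits is assumed):

```rust
use burn::tensor::{activation, backend::Backend, Tensor};

// Hypothetical example, not from this PR: relu followed by a row-wise softmax.
fn classify<B: Backend>(logits: Tensor<B, 2>) -> Tensor<B, 2> {
    let hidden = activation::relu(logits);
    // Like nn.functional.softmax(hidden, dim=1).
    activation::softmax(hidden, 1)
}
```
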
8 changes: 8 additions & 0 deletions crates/burn-autodiff/src/ops/int_tensor.rs
@@ -356,4 +356,12 @@ impl<B: Backend, C: CheckpointStrategy> IntTensorOps<Self> for Autodiff<B, C> {
fn int_sign<const D: usize>(tensor: IntTensor<Self, D>) -> IntTensor<Self, D> {
B::int_sign(tensor)
}

fn int_prod<const D: usize>(tensor: IntTensor<Self, D>) -> IntTensor<Self, 1> {
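// Int tensors are not tracked by autodiff, so the op simply delegates to the inner backend.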
B::int_prod(tensor)
}

fn int_prod_dim<const D: usize>(tensor: IntTensor<Self, D>, dim: usize) -> IntTensor<Self, D> {
B::int_prod_dim(tensor, dim)
}
}
3 changes: 3 additions & 0 deletions crates/burn-autodiff/src/ops/tensor.rs
@@ -2379,6 +2379,9 @@ impl<B: Backend, C: CheckpointStrategy> FloatTensorOps<Self> for Autodiff<B, C>
.parents([&tensor])
.stateless(B::float_sign(tensor.primitive))
}

// TODO: Implement float_prod and float_sum
// https://github.com/tracel-ai/burn/issues/1458
}

#[derive(Debug, Clone)]
8 changes: 8 additions & 0 deletions crates/burn-candle/src/ops/int_tensor.rs
@@ -313,6 +313,14 @@ impl<F: FloatCandleElement, I: IntCandleElement> IntTensorOps<Self> for Candle<F
CandleTensor::new(tensor.tensor.sum_keepdim(dim).unwrap())
}

fn int_prod<const D: usize>(tensor: IntTensor<Self, D>) -> IntTensor<Self, 1> {
todo!("prod is not implemented for Candle IntTensor (see https://github.com/tracel-ai/burn/issues/1454)")
}

fn int_prod_dim<const D: usize>(tensor: IntTensor<Self, D>, dim: usize) -> IntTensor<Self, D> {
todo!("prod_int is not implemented for Candle IntTensor (see https://github.com/tracel-ai/burn/issues/1454)")
}

fn int_mean_dim<const D: usize>(tensor: IntTensor<Self, D>, dim: usize) -> IntTensor<Self, D> {
// Candle implements scalar a/b as a * (1/b). With ints 1/b is rounded to 0 so we always obtain 0.
panic!("Not supported by Candle")
41 changes: 41 additions & 0 deletions crates/burn-fusion/src/ops/int.rs
@@ -1060,6 +1060,47 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
out
}

fn int_prod<const D: usize>(tensor: IntTensor<Self, D>) -> IntTensor<Self, 1> {
unary_int_ops!(ProdOps, B::int_prod);

let stream = tensor.stream;
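// A full product reduces every element, so the output is a single-element, 1-d tensor.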
let out = tensor.client.tensor_uninitialized(vec![1]);

let desc = UnaryOperationDescription {
input: tensor.into_description(),
out: out.to_description_out(),
};
out.client.register(
vec![stream],
OperationDescription::NumericInt(NumericOperationDescription::Prod(desc.clone())),
ProdOps::<D>::new(desc),
);

out
}

fn int_prod_dim<const D: usize>(tensor: IntTensor<Self, D>, dim: usize) -> IntTensor<Self, D> {
scalar_int_ops!(ProdDimOps, B::int_prod_dim, usize, noconvert);

let stream = tensor.stream;
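// The reduced dimension is kept with size 1, mirroring the keepdim-style output of `prod_dim`.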
let mut shape = tensor.shape.clone();
shape[dim] = 1;
let out = tensor.client.tensor_uninitialized(shape);

let desc = ScalarOperationDescription {
lhs: tensor.into_description(),
rhs: dim,
out: out.to_description_out(),
};
out.client.register(
vec![stream],
OperationDescription::NumericInt(NumericOperationDescription::ProdDim(desc.clone())),
ProdDimOps::<D>::new(desc),
);

out
}

fn int_mean<const D: usize>(tensor: IntTensor<Self, D>) -> IntTensor<Self, 1> {
unary_int_ops!(MeanOps, B::int_mean);

13 changes: 13 additions & 0 deletions crates/burn-fusion/src/stream/context.rs
@@ -614,6 +614,19 @@ impl<E: Element> NumericOperationDescription<E> {
out: desc.out.to_relative(converter),
})
}
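// Like the other reductions, Prod / ProdDim only remap their tensor descriptions to the
// relative form used when comparing streams; the reduction dim is kept verbatim (see below).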
NumericOperationDescription::Prod(desc) => {
NumericOperationDescription::Prod(UnaryOperationDescription {
input: desc.input.to_relative(converter),
out: desc.out.to_relative(converter),
})
}
NumericOperationDescription::ProdDim(desc) => {
NumericOperationDescription::ProdDim(ScalarOperationDescription {
lhs: desc.lhs.to_relative(converter),
rhs: desc.rhs, // Dim should stay the same.
out: desc.out.to_relative(converter),
})
}
NumericOperationDescription::EqualElem(desc) => {
NumericOperationDescription::EqualElem(ScalarOperationDescription {
lhs: desc.lhs.to_relative(converter),