Skip to content

Commit

Permalink
Merge pull request rapidsai#602 from rongou/fix-device-scalar
Browse files Browse the repository at this point in the history
[REVIEW] fix `device_scalar` and its tests so that they use the correct CUDA stream
  • Loading branch information
jrhemstad committed Oct 16, 2020
2 parents 4ec890f + 29eed8c commit 8cdd176
Show file tree
Hide file tree
Showing 3 changed files with 32 additions and 14 deletions.
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
## Bug Fixes

- PR #592 Add `auto_flush` to `make_logging_adaptor`

- PR #602 Fix `device_scalar` and its tests so that they use the correct CUDA stream

# RMM 0.16.0 (Date TBD)

Expand Down
26 changes: 22 additions & 4 deletions include/rmm/device_scalar.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,25 @@ class device_scalar {
set_value(initial_value, stream);
}

/**
* @brief Construct a new `device_scalar` by deep copying the contents of
* another `device_scalar`, using the specified stream and memory
* resource.
*
* @throws rmm::bad_alloc If creating the new allocation fails.
* @throws rmm::cuda_error if copying from `other` fails.
*
* @param other The `device_scalar` whose contents will be copied
* @param stream The stream to use for the allocation and copy
* @param mr The resource to use for allocating the new `device_scalar`
*/
device_scalar(device_scalar const &other,
cudaStream_t stream = 0,
rmm::mr::device_memory_resource *mr = rmm::mr::get_current_device_resource())
: buffer{other.buffer, stream, mr}
{
}

/**
* @brief Copies the value from device to host, synchronizes, and returns the value.
*
Expand Down Expand Up @@ -212,10 +231,9 @@ class device_scalar {
*/
T const *data() const noexcept { return static_cast<T const *>(buffer.data()); }

device_scalar() = default;
~device_scalar() = default;
device_scalar(device_scalar const &) = default;
device_scalar(device_scalar &&) = default;
device_scalar() = default;
~device_scalar() = default;
device_scalar(device_scalar &&) = default;
device_scalar &operator=(device_scalar const &) = delete;
device_scalar &operator=(device_scalar &&) = delete;

Expand Down
18 changes: 9 additions & 9 deletions tests/device_scalar_tests.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -57,34 +57,34 @@ TYPED_TEST(DeviceScalarTest, InitialValue)
{
rmm::device_scalar<TypeParam> scalar{this->value, this->stream, this->mr};
EXPECT_NE(nullptr, scalar.data());
EXPECT_EQ(this->value, scalar.value());
EXPECT_EQ(this->value, scalar.value(this->stream));
}

TYPED_TEST(DeviceScalarTest, CopyCtor)
{
rmm::device_scalar<TypeParam> scalar{this->value, this->stream, this->mr};
EXPECT_NE(nullptr, scalar.data());
EXPECT_EQ(this->value, scalar.value());
EXPECT_EQ(this->value, scalar.value(this->stream));

rmm::device_scalar<TypeParam> copy{scalar};
rmm::device_scalar<TypeParam> copy{scalar, this->stream, this->mr};
EXPECT_NE(nullptr, copy.data());
EXPECT_NE(copy.data(), scalar.data());
EXPECT_EQ(copy.value(), scalar.value());
EXPECT_EQ(copy.value(this->stream), scalar.value(this->stream));
}

TYPED_TEST(DeviceScalarTest, MoveCtor)
{
rmm::device_scalar<TypeParam> scalar{this->value, this->stream, this->mr};
EXPECT_NE(nullptr, scalar.data());
EXPECT_EQ(this->value, scalar.value());
EXPECT_EQ(this->value, scalar.value(this->stream));

auto original_pointer = scalar.data();
auto original_value = scalar.value();
auto original_value = scalar.value(this->stream);

rmm::device_scalar<TypeParam> moved_to{std::move(scalar)};
EXPECT_NE(nullptr, moved_to.data());
EXPECT_EQ(moved_to.data(), original_pointer);
EXPECT_EQ(moved_to.value(), original_value);
EXPECT_EQ(moved_to.value(this->stream), original_value);
EXPECT_EQ(nullptr, scalar.data());
}

Expand All @@ -95,6 +95,6 @@ TYPED_TEST(DeviceScalarTest, SetValue)

auto expected = this->distribution(this->generator);

scalar.set_value(expected);
EXPECT_EQ(expected, scalar.value());
scalar.set_value(expected, this->stream);
EXPECT_EQ(expected, scalar.value(this->stream));
}

0 comments on commit 8cdd176

Please sign in to comment.