Skip to content

Commit

Permalink
[ntuple] Refactor templating away
Browse files Browse the repository at this point in the history
This makes the interface more low-level, but that's exactly what we
want, e.g. for eventual integration with the `RNTupleProcessor`.

This commit also introduces the policy that only trivial types (PODs and
`std::strings`) can be used as indices.
  • Loading branch information
enirolf committed Apr 10, 2024
1 parent 6333581 commit 4dc20b9
Show file tree
Hide file tree
Showing 8 changed files with 178 additions and 105 deletions.
1 change: 1 addition & 0 deletions tree/ntuple/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ SOURCES
v7/src/RNTupleDescriptor.cxx
v7/src/RNTupleDescriptorFmt.cxx
v7/src/RNTupleFillContext.cxx
v7/src/RNTupleIndex.cxx
v7/src/RNTupleMerger.cxx
v7/src/RNTupleMetrics.cxx
v7/src/RNTupleModel.cxx
Expand Down
26 changes: 26 additions & 0 deletions tree/ntuple/v7/inc/ROOT/RField.hxx
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@ namespace Internal {
struct RFieldCallbackInjector;
class RPageSink;
class RPageSource;
class RNTupleIndex;
// TODO(jblomer): find a better way to not have these three methods in the RFieldBase public API
void CallCommitClusterOnField(RFieldBase &);
void CallConnectPageSinkOnField(RFieldBase &, RPageSink &, NTupleSize_t firstEntry = 0);
Expand Down Expand Up @@ -94,6 +95,7 @@ This is and can only be partially enforced through C++.
// clang-format on
class RFieldBase {
friend class ROOT::Experimental::RCollectionField; // to move the fields from the collection model
friend class ROOT::Experimental::Internal::RNTupleIndex;
friend struct ROOT::Experimental::Internal::RFieldCallbackInjector; // used for unit tests
friend void Internal::CallCommitClusterOnField(RFieldBase &);
friend void Internal::CallConnectPageSinkOnField(RFieldBase &, Internal::RPageSink &, NTupleSize_t);
Expand Down Expand Up @@ -530,6 +532,8 @@ protected:
/// as appropriate
virtual void OnConnectPageSource() {}

virtual NTupleIndexValue_t GetIndexRepresentation(void *from);

/// Factory method to resurrect a field from the stored on-disk type information. This overload takes an already
/// normalized type name and type alias
/// TODO(jalopezg): this overload may eventually be removed leaving only the `RFieldBase::Create()` that takes a
Expand Down Expand Up @@ -1874,6 +1878,7 @@ protected:
void GenerateColumnsImpl() final;
void GenerateColumnsImpl(const RNTupleDescriptor &desc) final;
void ConstructValue(void *where) const final { new (where) bool(false); }
NTupleIndexValue_t GetIndexRepresentation(void *from) final { return *static_cast<NTupleIndexValue_t *>(from); }

public:
static std::string TypeName() { return "bool"; }
Expand Down Expand Up @@ -1904,6 +1909,9 @@ public:

template <>
class RField<float> final : public RFieldBase {
private:
std::hash<float> fHashFunc = std::hash<float>();

protected:
std::unique_ptr<RFieldBase> CloneImpl(std::string_view newName) const final
{
Expand All @@ -1914,6 +1922,7 @@ protected:
void GenerateColumnsImpl() final;
void GenerateColumnsImpl(const RNTupleDescriptor &desc) final;
void ConstructValue(void *where) const final { new (where) float(0.0); }
NTupleIndexValue_t GetIndexRepresentation(void *from) final { return fHashFunc(*static_cast<float *>(from)); }

public:
static std::string TypeName() { return "float"; }
Expand Down Expand Up @@ -1946,6 +1955,9 @@ public:

template <>
class RField<double> final : public RFieldBase {
private:
std::hash<double> fHashFunc = std::hash<double>();

protected:
std::unique_ptr<RFieldBase> CloneImpl(std::string_view newName) const final
{
Expand All @@ -1956,6 +1968,7 @@ protected:
void GenerateColumnsImpl() final;
void GenerateColumnsImpl(const RNTupleDescriptor &desc) final;
void ConstructValue(void *where) const final { new (where) double(0.0); }
NTupleIndexValue_t GetIndexRepresentation(void *from) final { return fHashFunc(*static_cast<double *>(from)); }

public:
static std::string TypeName() { return "double"; }
Expand Down Expand Up @@ -1999,6 +2012,7 @@ protected:
void GenerateColumnsImpl() final;
void GenerateColumnsImpl(const RNTupleDescriptor &desc) final;
void ConstructValue(void *where) const final { new (where) std::byte{0}; }
NTupleIndexValue_t GetIndexRepresentation(void *from) final { return *static_cast<NTupleIndexValue_t *>(from); }

public:
static std::string TypeName() { return "std::byte"; }
Expand Down Expand Up @@ -2038,6 +2052,7 @@ protected:
void GenerateColumnsImpl() final;
void GenerateColumnsImpl(const RNTupleDescriptor &desc) final;
void ConstructValue(void *where) const final { new (where) char(0); }
NTupleIndexValue_t GetIndexRepresentation(void *from) final { return *static_cast<NTupleIndexValue_t *>(from); }

public:
static std::string TypeName() { return "char"; }
Expand Down Expand Up @@ -2078,6 +2093,7 @@ protected:
void GenerateColumnsImpl() final;
void GenerateColumnsImpl(const RNTupleDescriptor &desc) final;
void ConstructValue(void *where) const final { new (where) int8_t(0); }
NTupleIndexValue_t GetIndexRepresentation(void *from) final { return *static_cast<NTupleIndexValue_t *>(from); }

public:
static std::string TypeName() { return "std::int8_t"; }
Expand Down Expand Up @@ -2118,6 +2134,7 @@ protected:
void GenerateColumnsImpl() final;
void GenerateColumnsImpl(const RNTupleDescriptor &desc) final;
void ConstructValue(void *where) const final { new (where) uint8_t(0); }
NTupleIndexValue_t GetIndexRepresentation(void *from) final { return *static_cast<NTupleIndexValue_t *>(from); }

public:
static std::string TypeName() { return "std::uint8_t"; }
Expand Down Expand Up @@ -2158,6 +2175,7 @@ protected:
void GenerateColumnsImpl() final;
void GenerateColumnsImpl(const RNTupleDescriptor &desc) final;
void ConstructValue(void *where) const final { new (where) int16_t(0); }
NTupleIndexValue_t GetIndexRepresentation(void *from) final { return *static_cast<NTupleIndexValue_t *>(from); }

public:
static std::string TypeName() { return "std::int16_t"; }
Expand Down Expand Up @@ -2198,6 +2216,7 @@ protected:
void GenerateColumnsImpl() final;
void GenerateColumnsImpl(const RNTupleDescriptor &desc) final;
void ConstructValue(void *where) const final { new (where) int16_t(0); }
NTupleIndexValue_t GetIndexRepresentation(void *from) final { return *static_cast<NTupleIndexValue_t *>(from); }

public:
static std::string TypeName() { return "std::uint16_t"; }
Expand Down Expand Up @@ -2238,6 +2257,7 @@ protected:
void GenerateColumnsImpl() final;
void GenerateColumnsImpl(const RNTupleDescriptor &desc) final;
void ConstructValue(void *where) const final { new (where) int32_t(0); }
NTupleIndexValue_t GetIndexRepresentation(void *from) final { return *static_cast<NTupleIndexValue_t *>(from); }

public:
static std::string TypeName() { return "std::int32_t"; }
Expand Down Expand Up @@ -2278,6 +2298,7 @@ protected:
void GenerateColumnsImpl() final;
void GenerateColumnsImpl(const RNTupleDescriptor &desc) final;
void ConstructValue(void *where) const final { new (where) uint32_t(0); }
NTupleIndexValue_t GetIndexRepresentation(void *from) final { return *static_cast<NTupleIndexValue_t *>(from); }

public:
static std::string TypeName() { return "std::uint32_t"; }
Expand Down Expand Up @@ -2320,6 +2341,7 @@ protected:
void GenerateColumnsImpl() final;
void GenerateColumnsImpl(const RNTupleDescriptor &desc) final;
void ConstructValue(void *where) const final { new (where) uint64_t(0); }
NTupleIndexValue_t GetIndexRepresentation(void *from) final { return *static_cast<NTupleIndexValue_t *>(from); }

public:
static std::string TypeName() { return "std::uint64_t"; }
Expand Down Expand Up @@ -2360,6 +2382,7 @@ protected:
void GenerateColumnsImpl() final;
void GenerateColumnsImpl(const RNTupleDescriptor &desc) final;
void ConstructValue(void *where) const final { new (where) int64_t(0); }
NTupleIndexValue_t GetIndexRepresentation(void *from) final { return *static_cast<NTupleIndexValue_t *>(from); }

public:
static std::string TypeName() { return "std::int64_t"; }
Expand Down Expand Up @@ -2392,6 +2415,7 @@ template <>
class RField<std::string> final : public RFieldBase {
private:
ClusterSize_t fIndex;
std::hash<std::string> fHashFunc = std::hash<std::string>();

std::unique_ptr<RFieldBase> CloneImpl(std::string_view newName) const final
{
Expand Down Expand Up @@ -2423,6 +2447,8 @@ public:
size_t GetValueSize() const final { return sizeof(std::string); }
size_t GetAlignment() const final { return std::alignment_of<std::string>(); }
void AcceptVisitor(Detail::RFieldVisitor &visitor) const final;

NTupleIndexValue_t GetIndexRepresentation(void *from) final { return fHashFunc(*static_cast<std::string *>(from)); }
};

/// TObject requires special handling of the fBits and fUniqueID members
Expand Down
89 changes: 29 additions & 60 deletions tree/ntuple/v7/inc/ROOT/RNTupleIndex.hxx
Original file line number Diff line number Diff line change
@@ -1,10 +1,18 @@
/// \file ROOT/RNTupleView.hxx
/// \file ROOT/RNTupleIndex.hxx
/// \ingroup NTuple ROOT7
/// \author Florine de Geus <[email protected]>
/// \date 2024-02-08
/// \date 2024-04-02
/// \warning This is part of the ROOT 7 prototype! It will change without notice. It might trigger earthquakes. Feedback
/// is welcome!

/*************************************************************************
* Copyright (C) 1995-2024, Rene Brun and Fons Rademakers. *
* All rights reserved. *
* *
* For the licensing terms see $ROOTSYS/LICENSE. *
* For the list of contributors see $ROOTSYS/README/CREDITS. *
*************************************************************************/

#ifndef ROOT7_RNTupleIndex
#define ROOT7_RNTupleIndex

Expand All @@ -20,66 +28,36 @@
namespace ROOT {
namespace Experimental {
namespace Internal {

// clang-format off
/**
\class ROOT::Experimental::RNTupleIndex
\class ROOT::Experimental::Internal::RNTupleIndex
\ingroup NTuple
\brief Build an index for an RNTuple so it can be joined onto other RNTuples.
*/
// clang-format on
template <class IndexValueT>
class RNTupleIndex {
friend class RNTupleReader;

private:
/// The maximum number of index elements we allow to be kept in memory. Used as a failsafe.
std::size_t fMaxElemsInMemory = 64 * 1024 * 1024;
std::size_t fNElems = 0;
std::string fFieldName; // TODO store more field info (for merging checks)
std::unordered_map<IndexValueT, std::set<NTupleSize_t>> fIndex;
std::unique_ptr<RFieldBase> fField;
std::unordered_map<NTupleIndexValue_t, std::set<NTupleSize_t>> fIndex;

void Merge(const RNTupleIndex<IndexValueT> &other)
{
if (fFieldName != other.fFieldName)
throw RException(R__FAIL("can only merge indices for the same field"));

fNElems += other.fNElems;

for (const auto &val : other.fIndex) {
auto res = fIndex.insert(val);
if (!res.second) {
// The index value is already present so the insertion failed. Instead, we have to merge the values.
fIndex.at(val.first).insert(val.second.begin(), val.second.end());
}
}
}
void Merge(const RNTupleIndex &other);

public:
RNTupleIndex(std::string_view fieldName) : fFieldName(std::string(fieldName)) {}
RNTupleIndex() : fFieldName("") {}
RNTupleIndex<IndexValueT> &operator=(const RNTupleIndex<IndexValueT> &other)
{
fFieldName = other.fFieldName;
fNElems = other.fNElems;
fIndex = other.fIndex;
return *this;
}
RNTupleIndex(std::unique_ptr<RFieldBase> field) : fField(std::move(field)) {}
RNTupleIndex() : fField(nullptr) {}
RNTupleIndex(const RNTupleIndex &other) { *this = other; }
RNTupleIndex &operator=(const RNTupleIndex &other);

void SetMaxElementsInMemory(std::size_t maxElems) { fMaxElemsInMemory = maxElems; }

std::size_t GetNElems() const { return fNElems; }

void Add(const IndexValueT &value, NTupleSize_t entry)
{
if (fNElems > fMaxElemsInMemory) {
throw RException(
R__FAIL("in-memory index exceeds maximum allowed size (" + std::to_string(fMaxElemsInMemory) + ")"));
}

fIndex[value].insert(entry);
fNElems++;
}
void Add(void *objPtr, NTupleSize_t entry);

/////////////////////////////////////////////////////////////////////////////
/// \brief Get the entry number containing the given index value **after** the provided minimum entry.
Expand All @@ -88,18 +66,16 @@ public:
/// \param[in] lowerBound The minimum entry number (inclusive) to retrieve. By default, all entries are considered.
/// \return The entry number, starting from `lowerBound`, containing the specified index value. When no such entry
/// exists, return `kInvalidNTupleIndex`
NTupleSize_t GetEntry(const IndexValueT &value, NTupleSize_t lowerBound = 0) const
{
if (!fIndex.count(value)) {
return kInvalidNTupleIndex;
}

auto indexEntries = fIndex.at(value);

if (auto entry = indexEntries.lower_bound(lowerBound); entry != indexEntries.end())
return *entry;
NTupleSize_t GetEntry(void *valuePtr, NTupleSize_t lowerBound = 0) const;

return kInvalidNTupleIndex;
/////////////////////////////////////////////////////////////////////////////
/// \brief Get the entry number containing the given index value **after** the provided minimum entry.
///
/// \sa GetEntry(void *valuePtr, NTupleSize_t lowerBound = 0)
template <typename T>
NTupleSize_t GetEntry(std::shared_ptr<T> valuePtr, NTupleSize_t lowerBound = 0) const
{
return GetEntry(valuePtr.get(), lowerBound);
}

/////////////////////////////////////////////////////////////////////////////
Expand All @@ -110,14 +86,7 @@ public:
///
/// \return A new RNTupleIndex resulting from the concatenation
///
static RNTupleIndex<IndexValueT>
Concatenate(const RNTupleIndex<IndexValueT> &left, const RNTupleIndex<IndexValueT> &right)
{
RNTupleIndex<IndexValueT> index(left.fFieldName);
index.Merge(left);
index.Merge(right);
return index;
}
static RNTupleIndex Concatenate(const RNTupleIndex &left, const RNTupleIndex &right);
};

} // namespace Internal
Expand Down
27 changes: 15 additions & 12 deletions tree/ntuple/v7/inc/ROOT/RNTupleReader.hxx
Original file line number Diff line number Diff line change
Expand Up @@ -349,21 +349,24 @@ public:
/// \param[in] fieldName The name of the field for which to create the index
/// \return A pointer to the newly created index
///
/// \note Building the index can be sped up significantly by enabling implicit multithreading!
template <typename T>
std::unique_ptr<Internal::RNTupleIndex<T>> CreateIndex(std::string_view fieldName)
/// \note Building the index can be sped up significantly by enabling implicit multithreading
/// (`ROOT::EnableImplicitMT()`).
std::unique_ptr<Internal::RNTupleIndex> CreateIndex(std::string_view fieldName)
{
using Internal::RNTupleIndex;

auto makeIndex = [&](std::pair<std::uint64_t, std::uint64_t> range) -> RNTupleIndex<T> {
auto partialIndex = RNTupleIndex<T>(fieldName);
const RFieldBase &field = GetModel().GetField(fieldName);

auto makeIndex = [&](std::pair<std::uint64_t, std::uint64_t> range) -> RNTupleIndex {
RNTupleIndex partialIndex(field.Clone(fieldName));
auto reader = this->Clone();
auto entry = reader->GetModel().CreateEntry();
auto token = entry->GetToken(fieldName);

for (std::uint64_t i = range.first; i < range.second; ++i) {
reader->LoadEntry(i, *entry);
auto ptr = entry->GetPtr<T>(fieldName);
partialIndex.Add(*ptr, i);
auto ptr = entry->GetPtr<void>(token);
partialIndex.Add(ptr.get(), i);
}

return partialIndex;
Expand All @@ -386,19 +389,19 @@ public:
}
}

auto reducePartialIndices = [fieldName](const std::vector<RNTupleIndex<T>> indices) -> RNTupleIndex<T> {
return std::accumulate(indices.begin(), indices.end(), RNTupleIndex<T>(fieldName),
RNTupleIndex<T>::Concatenate);
auto reducePartialIndices = [&field, &fieldName](const std::vector<RNTupleIndex> indices) -> RNTupleIndex {
return std::accumulate(indices.begin(), indices.end(), RNTupleIndex(field.Clone(fieldName)),
RNTupleIndex::Concatenate);
};

auto index = ROOT::TThreadExecutor{}.MapReduce(makeIndex, ranges, reducePartialIndices);

return std::make_unique<RNTupleIndex<T>>(index);
return std::make_unique<RNTupleIndex>(index);
}
#endif

auto index = makeIndex(std::pair<std::uint64_t, std::uint64_t>{0, GetNEntries()});
return std::unique_ptr<RNTupleIndex<T>>(new RNTupleIndex<T>(index));
return std::unique_ptr<RNTupleIndex>(new RNTupleIndex(index));
}

// template <typename T>
Expand Down
2 changes: 2 additions & 0 deletions tree/ntuple/v7/inc/ROOT/RNTupleUtil.hxx
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,8 @@ constexpr ColumnId_t kInvalidColumnId = -1;
using DescriptorId_t = std::uint64_t;
constexpr DescriptorId_t kInvalidDescriptorId = std::uint64_t(-1);

using NTupleIndexValue_t = std::size_t;

/// Addresses a column element or field item relative to a particular cluster, instead of a global NTupleSize_t index
class RClusterIndex {
private:
Expand Down
6 changes: 6 additions & 0 deletions tree/ntuple/v7/src/RField.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -1076,6 +1076,12 @@ void ROOT::Experimental::RFieldBase::AcceptVisitor(Detail::RFieldVisitor &visito
visitor.VisitField(*this);
}

ROOT::Experimental::NTupleIndexValue_t ROOT::Experimental::RFieldBase::GetIndexRepresentation(void * /*from*/)
{
R__ASSERT(false && "indexing is not supported for this field type");
return 0;
}

//-----------------------------------------------------------------------------

std::unique_ptr<ROOT::Experimental::RFieldBase>
Expand Down
Loading

0 comments on commit 4dc20b9

Please sign in to comment.