diff --git a/Framework/Core/include/Framework/AnalysisManagers.h b/Framework/Core/include/Framework/AnalysisManagers.h index 938f23e9ce5b0..64265c5a13b51 100644 --- a/Framework/Core/include/Framework/AnalysisManagers.h +++ b/Framework/Core/include/Framework/AnalysisManagers.h @@ -12,6 +12,7 @@ #ifndef FRAMEWORK_ANALYSISMANAGERS_H #define FRAMEWORK_ANALYSISMANAGERS_H #include "Framework/AnalysisHelpers.h" +#include "Framework/GroupedCombinations.h" #include "Framework/Kernels.h" #include "Framework/ASoA.h" #include "Framework/ProcessingContext.h" @@ -29,6 +30,29 @@ namespace o2::framework { +template +struct GroupedCombinationManager { + template + static void setGroupedCombination(ANY&, TG&, T2s&...) + { + } +}; + +template +struct GroupedCombinationManager, As...>> { + template + static void setGroupedCombination(GroupedCombinationsGenerator, As...>& comb, TH& hashes, TG& grouping, std::tuple& associated) + { + static_assert(sizeof...(T2s) > 0, "There must be associated tables in process() for a correct pair"); + static_assert(!soa::is_soa_iterator_t>::value, "Only full tables can be in process(), no grouping"); + if constexpr (std::conjunction_v, std::is_same>) { + // Take respective unique associated tables for grouping + auto associatedTuple = std::tuple(std::get(associated)...); + comb.setTables(hashes, grouping, associatedTuple); + } + } +}; + template struct PartitionManager { template diff --git a/Framework/Core/include/Framework/AnalysisTask.h b/Framework/Core/include/Framework/AnalysisTask.h index 2a0fb65641e45..021ccbdff959e 100644 --- a/Framework/Core/include/Framework/AnalysisTask.h +++ b/Framework/Core/include/Framework/AnalysisTask.h @@ -55,7 +55,6 @@ struct AnalysisTask { // Helper struct which builds a DataProcessorSpec from // the contents of an AnalysisTask... - struct AnalysisDataProcessorBuilder { template static ConfigParamSpec getSpec() @@ -262,6 +261,7 @@ struct AnalysisDataProcessorBuilder { // single argument to process homogeneous_apply_refs([&groupingTable](auto& x) { PartitionManager>::bindExternalIndices(x, &groupingTable); + GroupedCombinationManager>::setGroupedCombination(x, groupingTable); return true; }, task); @@ -310,6 +310,15 @@ struct AnalysisDataProcessorBuilder { }, associatedTables); + // GroupedCombinations bound separately, as they should be set once for all associated tables + auto hashes = std::get<0>(associatedTables); + auto realAssociated = tuple_tail(associatedTables); + homogeneous_apply_refs([&groupingTable, &hashes, &realAssociated](auto& t) { + GroupedCombinationManager>::setGroupedCombination(t, hashes, groupingTable, realAssociated); + return true; + }, + task); + if constexpr (soa::is_soa_iterator_t>::value) { // grouping case auto slicer = GroupSlicer(groupingTable, associatedTables); diff --git a/Framework/Core/include/Framework/GroupedCombinations.h b/Framework/Core/include/Framework/GroupedCombinations.h new file mode 100644 index 0000000000000..6dd0efc4f5cf2 --- /dev/null +++ b/Framework/Core/include/Framework/GroupedCombinations.h @@ -0,0 +1,267 @@ +// Copyright 2019-2020 CERN and copyright holders of ALICE O2. +// See https://alice-o2.web.cern.ch/copyright for details of the copyright holders. +// All rights not expressly granted are reserved. +// +// This software is distributed under the terms of the GNU General Public +// License v3 (GPL Version 3), copied verbatim in the file "COPYING". +// +// In applying this license CERN does not waive the privileges and immunities +// granted to it by virtue of its status as an Intergovernmental Organization +// or submit itself to any jurisdiction. + +#ifndef FRAMEWORK_GROUPEDCOMBINATIONS_H +#define FRAMEWORK_GROUPEDCOMBINATIONS_H + +#include "Framework/ASoAHelpers.h" +#include "Framework/GroupSlicer.h" +#include "Framework/Pack.h" + +namespace o2::framework +{ + +// Create an instance of a tuple interleaved from given tuples +template +auto interleaveTuplesImpl(std::tuple& t1, std::tuple& t2, std::index_sequence) +{ + return std::tuple_cat(std::make_tuple(std::get(t1), std::get(t2))...); +} + +template +auto interleaveTuples(std::tuple& t1, std::tuple& t2) +{ + return interleaveTuplesImpl(t1, t2, std::index_sequence_for()); +} + +// Functions to create a tuple from N runs of a function that returns a value +template +R execFunctionWithDummyIndex(R (C::*f)(Args...), C& obj, Args... args) +{ + return (obj.*f)(args...); +} + +template +auto functionToTupleImpl(R (C::*f)(Args...), C& obj, Args... args, std::index_sequence) +{ + return std::make_tuple((execFunctionWithDummyIndex(f, obj, args...))...); +} + +template +auto functionToTuple(R (C::*f)(Args...), C& obj, Args... args) +{ + return functionToTupleImpl(f, obj, args..., std::make_index_sequence()); +} + +template +struct GroupedCombinationsGenerator { +}; + +template +struct GroupedCombinationsGenerator, As...> { + using joinIterator = typename soa::Join::table_t::iterator; + using GroupedIteratorType = pack_to_tuple_t, pack>>; + + struct GroupedIterator : public std::iterator, public GroupingPolicy { + public: + using reference = GroupedIteratorType&; + using value_type = GroupedIteratorType; + using pointer = GroupedIteratorType*; + using iterator_category = std::forward_iterator_tag; + + GroupedIterator(const GroupingPolicy& groupingPolicy) : GroupingPolicy(groupingPolicy) {} + GroupedIterator(const GroupingPolicy& groupingPolicy, const H& hashes, const G& grouping, const std::shared_ptr>&& slicer_ptr) : GroupingPolicy(groupingPolicy), mSlicer{std::move(slicer_ptr)}, mGrouping{std::make_shared(std::vector{grouping.asArrowTable()})} + { + GroupingPolicy::setTables(join(hashes, grouping), join(hashes, grouping)); + if (!this->mIsEnd) { + setCurrentGroupedCombination(); + } + } + + GroupedIterator(GroupedIterator const&) = default; + GroupedIterator& operator=(GroupedIterator const&) = default; + ~GroupedIterator() = default; + + void setTables(const H& hashes, const G& grouping, std::shared_ptr> slicer_ptr) + { + mGrouping = std::make_shared(std::vector{grouping.asArrowTable()}); + mSlicer = slicer_ptr; + setMultipleGroupingTables(join(hashes, grouping)); + if (!this->mIsEnd) { + setCurrentGroupedCombination(); + } + } + + template + void setMultipleGroupingTables(const T& param, const Args&... args) + { + if constexpr (N == 1) { + GroupingPolicy::setTables(param, args...); + } else { + setMultipleGroupingTables(param, param, args...); + } + } + + void moveToEnd() + { + GroupingPolicy::moveToEnd(); + } + + // prefix increment + GroupedIterator& operator++() + { + if (!this->mIsEnd) { + this->addOne(); + setCurrentGroupedCombination(); + } + return *this; + } + // postfix increment + GroupedIterator operator++(int /*unused*/) + { + GroupedIterator copy(*this); + operator++(); + return copy; + } + // return reference + reference operator*() + { + return *mCurrentGrouped; + } + bool operator==(const GroupedIterator& rh) + { + return (this->mIsEnd && rh.mIsEnd) || (this->mCurrent == rh.mCurrent); + } + bool operator!=(const GroupedIterator& rh) + { + return !(*this == rh); + } + + private: + std::tuple getAssociatedTables() + { + auto& currentGrouping = GroupingPolicy::mCurrent; + constexpr auto k = sizeof...(As); + auto slicerIterators = functionToTuple(&GroupSlicer::begin, *mSlicer); + o2::soa::for_([&](auto i) { + auto col = std::get(currentGrouping); + for (auto& slice : *mSlicer) { + if (slice.groupingElement().globalIndex() == col.globalIndex()) { + std::get(slicerIterators) = slice; + break; + } + } + }); + + return getSlices(slicerIterators, std::index_sequence_for()); + } + + template + auto getSliceAt(std::tuple& t) + { + auto it = std::get(t); // Get the tables corresponding to the grouping at index I + auto associatedType = it.template prepareArgument(); + return associatedType; + } + + template + std::tuple getSlices(std::tuple& t, std::index_sequence is) + { + return std::make_tuple(getSliceAt(t)...); + } + + void setCurrentGroupedCombination() + { + std::tuple initAssociatedTables = getAssociatedTables(); + constexpr auto k = sizeof...(As); + bool moveForward = false; + o2::soa::for_([&](auto i) { + if (std::get(initAssociatedTables).size() == 0) { + moveForward = true; + } + }); + while (!this->mIsEnd && moveForward) { + GroupingPolicy::addOne(); + std::tuple temp = getAssociatedTables(); + moveForward = false; + o2::soa::for_([&](auto i) { + if (std::get(temp).size() == 0) { + moveForward = true; + } + }); + } + std::tuple associatedTables = getAssociatedTables(); + + if (!this->mIsEnd) { + auto& currentGrouping = GroupingPolicy::mCurrent; + o2::soa::for_([&](auto i) { + std::get(associatedTables).bindExternalIndices(mGrouping.get()); + }); + + mCurrentGrouped.emplace(interleaveTuples(currentGrouping, associatedTables)); + } + } + + std::shared_ptr> mSlicer = nullptr; + std::shared_ptr mGrouping; + std::optional mCurrentGrouped; + }; + + using iterator = GroupedIterator; + using const_iterator = GroupedIterator; + + inline iterator begin() + { + return iterator(mBegin); + } + inline iterator end() + { + return iterator(mEnd); + } + inline const_iterator begin() const + { + return iterator(mBegin); + } + inline const_iterator end() const + { + return iterator(mEnd); + } + + GroupedCombinationsGenerator(const char* category, int catNeighbours, const T1& outsider) : mBegin(GroupingPolicy(category, catNeighbours, outsider)), mEnd(GroupingPolicy(category, catNeighbours, outsider)), mCategory(category), mCatNeighbours(catNeighbours), mOutsider(outsider) {} + GroupedCombinationsGenerator(const char* category, int catNeighbours, const T1& outsider, H& hashes, G& grouping, std::tuple& associated) : GroupedCombinationsGenerator(category, catNeighbours, outsider) + { + setTables(hashes, grouping, associated); + } + ~GroupedCombinationsGenerator() = default; + + void setTables(H& hashes, G& grouping, std::tuple& associated) + { + std::shared_ptr slicer_ptr = std::make_shared>(grouping, associated); + mBegin.setTables(hashes, grouping, slicer_ptr); + mEnd.setTables(hashes, grouping, slicer_ptr); + mEnd.moveToEnd(); + } + + private: + iterator mBegin; + iterator mEnd; + const char* mCategory; + const int mCatNeighbours; + const T1 mOutsider; +}; + +// Aliases for 2-particle correlations +// 'Pair' and 'Triple' can be used for same kind pair/triple, too, just specify the same type twice +template +using joinedCollisions = typename soa::Join::table_t; +template , joinedCollisions>> +using Pair = GroupedCombinationsGenerator>>; +template , joinedCollisions>> +using SameKindPair = GroupedCombinationsGenerator, A, A>; + +// Aliases for 3-particle correlations +template , joinedCollisions, joinedCollisions>> +using Triple = GroupedCombinationsGenerator>>; +template , joinedCollisions, joinedCollisions>> +using SameKindTriple = GroupedCombinationsGenerator, A, A, A>; + +} // namespace o2::framework +#endif // FRAMEWORK_GROUPEDCOMBINATIONS_H_ diff --git a/Framework/Foundation/include/Framework/Pack.h b/Framework/Foundation/include/Framework/Pack.h index 7b3899060c491..bf331d40fea42 100644 --- a/Framework/Foundation/include/Framework/Pack.h +++ b/Framework/Foundation/include/Framework/Pack.h @@ -272,6 +272,24 @@ constexpr auto unique_pack(pack, PT p2) template using unique_pack_t = decltype(unique_pack(P{}, pack<>{})); +template +inline constexpr std::tuple pack_to_tuple(pack) +{ + return std::tuple{}; +} + +template +using pack_to_tuple_t = decltype(pack_to_tuple(P{})); + +template +inline auto sequence_to_pack(std::integer_sequence) +{ + return pack{}; +}; + +template +using repeated_type_pack_t = decltype(sequence_to_pack(std::make_index_sequence())); + } // namespace o2::framework #endif // O2_FRAMEWORK_PACK_H_ diff --git a/Framework/Foundation/test/test_FunctionalHelpers.cxx b/Framework/Foundation/test/test_FunctionalHelpers.cxx index 7871ed0f337e8..f42f77715a0ac 100644 --- a/Framework/Foundation/test/test_FunctionalHelpers.cxx +++ b/Framework/Foundation/test/test_FunctionalHelpers.cxx @@ -53,6 +53,8 @@ BOOST_AUTO_TEST_CASE(TestOverride) static_assert(std::is_same_v>, pack>, "pack should not have duplicated types"); static_assert(std::is_same_v, pack>, pack>, "interleaved packs of the same size"); + static_assert(std::is_same_v>, std::tuple>, "pack should become a tuple"); + static_assert(std::is_same_v, pack>, "pack should have float repeated 5 times"); struct ForwardDeclared; static_assert(is_type_complete_v == false, "This should not be complete because the struct is simply forward declared.");