diff --git a/Framework/Core/include/Framework/Expressions.h b/Framework/Core/include/Framework/Expressions.h index 41acd54ef7487..150ab8d8c2aa5 100644 --- a/Framework/Core/include/Framework/Expressions.h +++ b/Framework/Core/include/Framework/Expressions.h @@ -64,10 +64,12 @@ constexpr auto selectArrowType() return atype::FLOAT; } else if constexpr (std::is_same_v) { return atype::DOUBLE; - } else if constexpr (std::is_same_v) { + } else if constexpr (std::is_same_v) { return atype::INT8; } else if constexpr (std::is_same_v) { return atype::INT16; + } else if constexpr (std::is_same_v) { + return atype::UINT8; } else { return atype::NA; } @@ -236,6 +238,16 @@ inline Node operator/(Node left, Node right) return Node{OpNode{BasicOp::Division}, std::move(left), std::move(right)}; } +inline Node operator/(BindingNode left, Node right) +{ + return Node{OpNode{BasicOp::Division}, left, std::move(right)}; +} + +inline Node operator/(Node left, BindingNode right) +{ + return Node{OpNode{BasicOp::Division}, std::move(left), right}; +} + inline Node operator+(Node left, Node right) { return Node{OpNode{BasicOp::Addition}, std::move(left), std::move(right)}; @@ -246,6 +258,11 @@ inline Node operator-(Node left, Node right) return Node{OpNode{BasicOp::Subtraction}, std::move(left), std::move(right)}; } +inline Node operator-(BindingNode left, BindingNode right) +{ + return Node{OpNode{BasicOp::Subtraction}, left, right}; +} + /// arithmetical operations between node and literal template inline Node operator*(Node left, T right) diff --git a/Framework/Core/src/AODReaderHelpers.cxx b/Framework/Core/src/AODReaderHelpers.cxx index 5807ac1286476..169479525c1cb 100644 --- a/Framework/Core/src/AODReaderHelpers.cxx +++ b/Framework/Core/src/AODReaderHelpers.cxx @@ -41,25 +41,6 @@ namespace o2::framework::readers { - -namespace -{ -auto tableTypeFromInput(InputSpec const& spec) -{ - auto description = std::visit( - overloaded{ - [](ConcreteDataMatcher const& matcher) { return matcher.description; }, - [](auto&&) { return header::DataDescription{""}; }}, - spec.matcher); - - if (description == header::DataDescription{"TRACKPAR"}) { - return o2::aod::TracksMetadata{}; - } else { - throw std::runtime_error("Not an extended table"); - } -} -} // namespace - enum AODTypeMask : uint64_t { None = 0, Track = 1 << 0, @@ -192,12 +173,6 @@ auto spawner(framework::pack columns, arrow::Table* atable) return results; } -template -auto extractTable(ProcessingContext& pc) -{ - return pc.inputs().get(aod::MetadataTrait::metadata::tableLabel())->asArrowTable(); -} - AlgorithmSpec AODReaderHelpers::aodSpawnerCallback(std::vector requested) { return AlgorithmSpec::InitCallback{[requested](InitContext& ic) { @@ -213,26 +188,41 @@ AlgorithmSpec AODReaderHelpers::aodSpawnerCallback(std::vector reques auto outputs = pc.outputs(); // spawn tables for (auto& input : requested) { - using metadata = decltype(tableTypeFromInput(input)); - using base_t = metadata::base_table_t; - using expressions = metadata::expression_pack_t; - auto extra_schema = o2::soa::createSchemaFromColumns(expressions{}); - auto original_table = extractTable(pc); - auto original_schema = original_table->schema(); - auto num_fields = original_schema->num_fields(); - std::vector> fields; - auto arrays = spawner(expressions{}, original_table.get()); - std::vector> columns = original_table->columns(); - for (auto i = 0; i < num_fields; ++i) { - fields.emplace_back(original_schema->field(i)); - } - for (auto i = 0u; i < framework::pack_size(expressions{}); ++i) { - columns.push_back(arrays[i]); - fields.emplace_back(extra_schema->field(i)); + auto description = std::visit( + overloaded{ + [](ConcreteDataMatcher const& matcher) { return matcher.description; }, + [](auto&&) { return header::DataDescription{""}; }}, + input.matcher); + + auto origin = std::visit( + overloaded{ + [](ConcreteDataMatcher const& matcher) { return matcher.origin; }, + [](auto&&) { return header::DataOrigin{""}; }}, + input.matcher); + + auto maker = [&](auto metadata) { + using metadata_t = decltype(metadata); + using expressions = typename metadata_t::expression_pack_t; + auto extra_schema = o2::soa::createSchemaFromColumns(expressions{}); + auto original_table = pc.inputs().get(input.binding)->asArrowTable(); + auto original_fields = original_table->schema()->fields(); + std::vector> fields; + auto arrays = spawner(expressions{}, original_table.get()); + std::vector> columns = original_table->columns(); + std::copy(original_fields.begin(), original_fields.end(), std::back_inserter(fields)); + for (auto i = 0u; i < framework::pack_size(expressions{}); ++i) { + columns.push_back(arrays[i]); + fields.emplace_back(extra_schema->field(i)); + } + auto new_schema = std::make_shared(fields); + return arrow::Table::Make(new_schema, columns); + }; + + if (description == header::DataDescription{"TRACKPAR"}) { + outputs.adopt(Output{origin, description}, maker(o2::aod::TracksMetadata{})); + } else { + throw std::runtime_error("Not an extended table"); } - auto new_schema = std::make_shared(fields); - auto new_table = arrow::Table::Make(new_schema, columns); - outputs.adopt(Output{metadata::origin(), metadata::description()}, new_table); } }; }}; diff --git a/Framework/Core/src/ExpressionHelpers.h b/Framework/Core/src/ExpressionHelpers.h index 58587a177d9d5..df2aa56a19ed4 100644 --- a/Framework/Core/src/ExpressionHelpers.h +++ b/Framework/Core/src/ExpressionHelpers.h @@ -81,6 +81,8 @@ struct ColumnOperationSpec { case BasicOp::NotEqual: type = atype::BOOL; break; + case BasicOp::Division: + type = atype::FLOAT; default: type = atype::NA; } diff --git a/Framework/Core/src/Expressions.cxx b/Framework/Core/src/Expressions.cxx index 3edea56c8c29e..f401333a2f2bf 100644 --- a/Framework/Core/src/Expressions.cxx +++ b/Framework/Core/src/Expressions.cxx @@ -51,6 +51,8 @@ struct OpNodeHelper { std::shared_ptr concreteArrowType(atype::type type) { switch (type) { + case atype::UINT8: + return arrow::uint8(); case atype::INT8: return arrow::int8(); case atype::INT16: @@ -199,17 +201,20 @@ Operations createOperations(Filter const& expression) return atype::FLOAT; } - if (t1 == t2) + if (t1 == t2) { return t1; + } - if (t1 == atype::INT32) { + if (t1 == atype::INT32 || t1 == atype::INT8 || t1 == atype::INT16 || t1 == atype::UINT8) { + if (t2 == atype::INT32 || t2 == atype::INT8 || t2 == atype::INT16 || t2 == atype::UINT8) + return atype::FLOAT; if (t2 == atype::FLOAT) return atype::FLOAT; if (t2 == atype::DOUBLE) return atype::DOUBLE; } if (t1 == atype::FLOAT) { - if (t2 == atype::INT32) + if (t2 == atype::INT32 || t2 == atype::INT8 || t2 == atype::INT16 || t2 == atype::UINT8) return atype::FLOAT; if (t2 == atype::DOUBLE) return atype::DOUBLE;