From 32a1bd23eec49b1aaf171cd591999088e13d6e3a Mon Sep 17 00:00:00 2001 From: Jonathan Lifflander Date: Wed, 31 May 2023 14:15:59 -0700 Subject: [PATCH] #276: reduce: simplify stamp and message extraction --- src/vt/collective/reduce/get_reduce_stamp.h | 33 ++++++++++++------- src/vt/collective/reduce/reduce.h | 6 +--- src/vt/collective/reduce/reduce.impl.h | 20 ++++++++--- src/vt/objgroup/proxy/proxy_objgroup.impl.h | 21 ++---------- .../vrt/collection/reducable/reducable.impl.h | 22 ++----------- 5 files changed, 44 insertions(+), 58 deletions(-) diff --git a/src/vt/collective/reduce/get_reduce_stamp.h b/src/vt/collective/reduce/get_reduce_stamp.h index 03032b2406..1d3783cece 100644 --- a/src/vt/collective/reduce/get_reduce_stamp.h +++ b/src/vt/collective/reduce/get_reduce_stamp.h @@ -57,8 +57,12 @@ namespace vt { namespace collective { namespace reduce { template struct GetReduceStamp : std::false_type { template - static auto getMsg(Args&&... args) { - return vt::makeMessage(std::tuple{std::forward(args)...}); + static auto getStampMsg(Args&&... args) { + return + std::make_tuple( + collective::reduce::ReduceStamp{}, + vt::makeMessage(std::tuple{std::forward(args)...}) + ); } }; @@ -67,8 +71,11 @@ struct GetReduceStamp< std::enable_if_t> > : std::false_type { template - static auto getMsg() { - return vt::makeMessage(std::tuple<>{}); + static auto getStampMsg() { + return std::make_tuple( + collective::reduce::ReduceStamp{}, + vt::makeMessage(std::tuple<>{}) + ); } }; @@ -90,13 +97,17 @@ struct GetReduceStamp< } template - static auto getMsg(Args&&... args) { - return vt::makeMessage( - getMsgHelper( - std::tie(std::forward(args)...), - std::make_index_sequence{} - ) - ); + static auto getStampMsg(Args&&... args) { + auto tp = std::make_tuple(std::forward(args)...); + return + std::make_tuple( + std::get(tp), + vt::makeMessage( + getMsgHelper( + std::move(tp), std::make_index_sequence{} + ) + ) + ); } }; diff --git a/src/vt/collective/reduce/reduce.h b/src/vt/collective/reduce/reduce.h index 1293782edb..37b1905b2f 100644 --- a/src/vt/collective/reduce/reduce.h +++ b/src/vt/collective/reduce/reduce.h @@ -181,11 +181,7 @@ struct Reduce : virtual collective::tree::Tree { template class Op = NoneOp, typename... Params > - PendingSendType reduce(Node root, Params&&... params) { - using Tuple = typename FuncTraits::TupleType; - using OpT = Op; - return reduce(root, std::forward(params)...); - } + PendingSendType reduce(Node root, Params&&... params); template PendingSendType reduce(Node root, Params&&... params); diff --git a/src/vt/collective/reduce/reduce.impl.h b/src/vt/collective/reduce/reduce.impl.h index 692297efa0..d91de33e12 100644 --- a/src/vt/collective/reduce/reduce.impl.h +++ b/src/vt/collective/reduce/reduce.impl.h @@ -50,6 +50,7 @@ #include "vt/messaging/active.h" #include "vt/messaging/message.h" #include "vt/runnable/make_runnable.h" +#include "vt/collective/reduce/get_reduce_stamp.h" namespace vt { namespace collective { namespace reduce { @@ -83,20 +84,29 @@ void Reduce::reduceRootRecv(MsgT* msg) { .run(); } +template < + auto f, + template class Op, + typename... Params +> +Reduce::PendingSendType Reduce::reduce(Node root, Params&&... params) { + using Tuple = typename FuncTraits::TupleType; + using OpT = Op; + return reduce(root, std::forward(params)...); +} + template Reduce::PendingSendType Reduce::reduce(Node root, Params&&... params) { using Tuple = typename FuncTraits::TupleType; using MsgT = ReduceTMsg; - - auto msg = vt::makeMessage(std::tuple{std::forward(params)...}); - auto id = detail::ReduceStamp{}; + using GetReduceStamp = collective::reduce::GetReduceStamp; + auto [stamp, msg] = GetReduceStamp::template getStampMsg(std::forward(params)...); auto han = auto_registry::makeAutoHandlerParam(); msg->root_handler_ = han; - return reduce(root.get(), msg.get(), id, 1); + return reduce(root.get(), msg.get(), stamp, 1); } - template *f> Reduce::PendingSendType Reduce::reduce( NodeType const& root, MsgT* msg, Callback cb, detail::ReduceStamp id, diff --git a/src/vt/objgroup/proxy/proxy_objgroup.impl.h b/src/vt/objgroup/proxy/proxy_objgroup.impl.h index 21c380608f..50d28298f0 100644 --- a/src/vt/objgroup/proxy/proxy_objgroup.impl.h +++ b/src/vt/objgroup/proxy/proxy_objgroup.impl.h @@ -124,12 +124,7 @@ Proxy::allreduce( using GetReduceStamp = collective::reduce::GetReduceStamp; auto cb = theCB()->makeBcast(*this); - ReduceStamp stamp{}; - if constexpr (GetReduceStamp::value) { - stamp = std::get(std::tie(std::forward(args)...)); - } - - auto msg = GetReduceStamp::template getMsg(std::forward(args)...); + auto [stamp, msg] = GetReduceStamp::template getStampMsg(std::forward(args)...); msg->setCallback(cb); auto proxy = Proxy(*this); return theObjGroup()->reduce< @@ -158,12 +153,7 @@ Proxy::reduce( using GetReduceStamp = collective::reduce::GetReduceStamp; auto cb = theCB()->makeSend(target); - ReduceStamp stamp{}; - if constexpr (GetReduceStamp::value ) { - stamp = std::get(std::tie(std::forward(args)...)); - } - - auto msg = GetReduceStamp::template getMsg(std::forward(args)...); + auto [stamp, msg] = GetReduceStamp::template getStampMsg(std::forward(args)...); msg->setCallback(cb); auto proxy = Proxy(*this); return theObjGroup()->reduce< @@ -191,12 +181,7 @@ Proxy::reduce( using MsgT = collective::ReduceTMsg; using GetReduceStamp = collective::reduce::GetReduceStamp; - ReduceStamp stamp{}; - if constexpr (GetReduceStamp::value ) { - stamp = std::get(std::tie(std::forward(args)...)); - } - - auto msg = GetReduceStamp::template getMsg(std::forward(args)...); + auto [stamp, msg] = GetReduceStamp::template getStampMsg(std::forward(args)...); msg->setCallback(cb); auto proxy = Proxy(*this); return theObjGroup()->reduce< diff --git a/src/vt/vrt/collection/reducable/reducable.impl.h b/src/vt/vrt/collection/reducable/reducable.impl.h index dd0481fe7e..39afc6fca6 100644 --- a/src/vt/vrt/collection/reducable/reducable.impl.h +++ b/src/vt/vrt/collection/reducable/reducable.impl.h @@ -67,13 +67,7 @@ messaging::PendingSend Reducable::allreduce( using MsgT = collective::ReduceTMsg; using GetReduceStamp = collective::reduce::GetReduceStamp; auto cb = theCB()->makeBcast(*this); - - ReduceStamp stamp{}; - if constexpr (GetReduceStamp::value) { - stamp = std::get(std::tie(std::forward(args)...)); - } - - auto msg = GetReduceStamp::template getMsg(std::forward(args)...); + auto [stamp, msg] = GetReduceStamp::template getStampMsg(std::forward(args)...); msg->setCallback(cb); auto const root_node = 0; auto const proxy = this->getProxy(); @@ -98,12 +92,7 @@ messaging::PendingSend Reducable::reduce( auto cb = theCB()->makeSend(target); - ReduceStamp stamp{}; - if constexpr (GetReduceStamp::value) { - stamp = std::get(std::tie(std::forward(args)...)); - } - - auto msg = GetReduceStamp::template getMsg(std::forward(args)...); + auto [stamp, msg] = GetReduceStamp::template getStampMsg(std::forward(args)...); msg->setCallback(cb); auto const root_node = 0; auto const proxy = this->getProxy(); @@ -128,12 +117,7 @@ messaging::PendingSend Reducable::reduce( using MsgT = collective::ReduceTMsg; using GetReduceStamp = collective::reduce::GetReduceStamp; - ReduceStamp stamp{}; - if constexpr (GetReduceStamp::value) { - stamp = std::get(std::tie(std::forward(args)...)); - } - - auto msg = GetReduceStamp::template getMsg(std::forward(args)...); + auto [stamp, msg] = GetReduceStamp::template getStampMsg(std::forward(args)...); msg->setCallback(cb); auto const root_node = 0; auto const proxy = this->getProxy();