Skip to content

Commit

Permalink
#276: reduce: simplify stamp and message extraction
Browse files Browse the repository at this point in the history
  • Loading branch information
lifflander committed Jun 13, 2023
1 parent 60896c3 commit 32a1bd2
Show file tree
Hide file tree
Showing 5 changed files with 44 additions and 58 deletions.
33 changes: 22 additions & 11 deletions src/vt/collective/reduce/get_reduce_stamp.h
Original file line number Diff line number Diff line change
Expand Up @@ -57,8 +57,12 @@ namespace vt { namespace collective { namespace reduce {
template <typename enable = void, typename... Args>
struct GetReduceStamp : std::false_type {
template <typename MsgT>
static auto getMsg(Args&&... args) {
return vt::makeMessage<MsgT>(std::tuple{std::forward<Args>(args)...});
static auto getStampMsg(Args&&... args) {
return
std::make_tuple(
collective::reduce::ReduceStamp{},
vt::makeMessage<MsgT>(std::tuple{std::forward<Args>(args)...})
);
}
};

Expand All @@ -67,8 +71,11 @@ struct GetReduceStamp<
std::enable_if_t<std::is_same_v<void, void>>
> : std::false_type {
template <typename MsgT>
static auto getMsg() {
return vt::makeMessage<MsgT>(std::tuple<>{});
static auto getStampMsg() {
return std::make_tuple(
collective::reduce::ReduceStamp{},
vt::makeMessage<MsgT>(std::tuple<>{})
);
}
};

Expand All @@ -90,13 +97,17 @@ struct GetReduceStamp<
}

template <typename MsgT>
static auto getMsg(Args&&... args) {
return vt::makeMessage<MsgT>(
getMsgHelper(
std::tie(std::forward<Args>(args)...),
std::make_index_sequence<sizeof...(Args) - 1>{}
)
);
static auto getStampMsg(Args&&... args) {
auto tp = std::make_tuple(std::forward<Args>(args)...);
return
std::make_tuple(
std::get<sizeof...(Args) - 1>(tp),
vt::makeMessage<MsgT>(
getMsgHelper(
std::move(tp), std::make_index_sequence<sizeof...(Args) - 1>{}
)
)
);
}
};

Expand Down
6 changes: 1 addition & 5 deletions src/vt/collective/reduce/reduce.h
Original file line number Diff line number Diff line change
Expand Up @@ -181,11 +181,7 @@ struct Reduce : virtual collective::tree::Tree {
template <typename Arg> class Op = NoneOp,
typename... Params
>
PendingSendType reduce(Node root, Params&&... params) {
using Tuple = typename FuncTraits<decltype(f)>::TupleType;
using OpT = Op<Tuple>;
return reduce<OpT, f>(root, std::forward<Params>(params)...);
}
PendingSendType reduce(Node root, Params&&... params);

template <typename Op, auto f, typename... Params>
PendingSendType reduce(Node root, Params&&... params);
Expand Down
20 changes: 15 additions & 5 deletions src/vt/collective/reduce/reduce.impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@
#include "vt/messaging/active.h"
#include "vt/messaging/message.h"
#include "vt/runnable/make_runnable.h"
#include "vt/collective/reduce/get_reduce_stamp.h"

namespace vt { namespace collective { namespace reduce {

Expand Down Expand Up @@ -83,20 +84,29 @@ void Reduce::reduceRootRecv(MsgT* msg) {
.run();
}

template <
auto f,
template <typename Arg> class Op,
typename... Params
>
Reduce::PendingSendType Reduce::reduce(Node root, Params&&... params) {
using Tuple = typename FuncTraits<decltype(f)>::TupleType;
using OpT = Op<Tuple>;
return reduce<OpT, f>(root, std::forward<Params>(params)...);
}

template <typename Op, auto f, typename... Params>
Reduce::PendingSendType Reduce::reduce(Node root, Params&&... params) {
using Tuple = typename FuncTraits<decltype(f)>::TupleType;
using MsgT = ReduceTMsg<Tuple>;

auto msg = vt::makeMessage<MsgT>(std::tuple{std::forward<Params>(params)...});
auto id = detail::ReduceStamp{};
using GetReduceStamp = collective::reduce::GetReduceStamp<void, Params...>;
auto [stamp, msg] = GetReduceStamp::template getStampMsg<MsgT>(std::forward<Params>(params)...);
auto han = auto_registry::makeAutoHandlerParam<decltype(f), f, MsgT>();
msg->root_handler_ = han;

return reduce<Op, operators::NoCombine, MsgT>(root.get(), msg.get(), id, 1);
return reduce<Op, operators::NoCombine, MsgT>(root.get(), msg.get(), stamp, 1);
}


template <typename OpT, typename MsgT, ActiveTypedFnType<MsgT> *f>
Reduce::PendingSendType Reduce::reduce(
NodeType const& root, MsgT* msg, Callback<MsgT> cb, detail::ReduceStamp id,
Expand Down
21 changes: 3 additions & 18 deletions src/vt/objgroup/proxy/proxy_objgroup.impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -124,12 +124,7 @@ Proxy<ObjT>::allreduce(
using GetReduceStamp = collective::reduce::GetReduceStamp<void, Args...>;
auto cb = theCB()->makeBcast<f>(*this);

ReduceStamp stamp{};
if constexpr (GetReduceStamp::value) {
stamp = std::get<sizeof...(Args) - 1>(std::tie(std::forward<Args>(args)...));
}

auto msg = GetReduceStamp::template getMsg<MsgT>(std::forward<Args>(args)...);
auto [stamp, msg] = GetReduceStamp::template getStampMsg<MsgT>(std::forward<Args>(args)...);
msg->setCallback(cb);
auto proxy = Proxy<ObjT>(*this);
return theObjGroup()->reduce<
Expand Down Expand Up @@ -158,12 +153,7 @@ Proxy<ObjT>::reduce(
using GetReduceStamp = collective::reduce::GetReduceStamp<void, Args...>;
auto cb = theCB()->makeSend<f>(target);

ReduceStamp stamp{};
if constexpr (GetReduceStamp::value ) {
stamp = std::get<sizeof...(Args) - 1>(std::tie(std::forward<Args>(args)...));
}

auto msg = GetReduceStamp::template getMsg<MsgT>(std::forward<Args>(args)...);
auto [stamp, msg] = GetReduceStamp::template getStampMsg<MsgT>(std::forward<Args>(args)...);
msg->setCallback(cb);
auto proxy = Proxy<ObjT>(*this);
return theObjGroup()->reduce<
Expand Down Expand Up @@ -191,12 +181,7 @@ Proxy<ObjT>::reduce(
using MsgT = collective::ReduceTMsg<Tuple>;
using GetReduceStamp = collective::reduce::GetReduceStamp<void, Args...>;

ReduceStamp stamp{};
if constexpr (GetReduceStamp::value ) {
stamp = std::get<sizeof...(Args) - 1>(std::tie(std::forward<Args>(args)...));
}

auto msg = GetReduceStamp::template getMsg<MsgT>(std::forward<Args>(args)...);
auto [stamp, msg] = GetReduceStamp::template getStampMsg<MsgT>(std::forward<Args>(args)...);
msg->setCallback(cb);
auto proxy = Proxy<ObjT>(*this);
return theObjGroup()->reduce<
Expand Down
22 changes: 3 additions & 19 deletions src/vt/vrt/collection/reducable/reducable.impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -67,13 +67,7 @@ messaging::PendingSend Reducable<ColT,IndexT,BaseProxyT>::allreduce(
using MsgT = collective::ReduceTMsg<Tuple>;
using GetReduceStamp = collective::reduce::GetReduceStamp<void, Args...>;
auto cb = theCB()->makeBcast<f>(*this);

ReduceStamp stamp{};
if constexpr (GetReduceStamp::value) {
stamp = std::get<sizeof...(Args) - 1>(std::tie(std::forward<Args>(args)...));
}

auto msg = GetReduceStamp::template getMsg<MsgT>(std::forward<Args>(args)...);
auto [stamp, msg] = GetReduceStamp::template getStampMsg<MsgT>(std::forward<Args>(args)...);
msg->setCallback(cb);
auto const root_node = 0;
auto const proxy = this->getProxy();
Expand All @@ -98,12 +92,7 @@ messaging::PendingSend Reducable<ColT,IndexT,BaseProxyT>::reduce(

auto cb = theCB()->makeSend<f>(target);

ReduceStamp stamp{};
if constexpr (GetReduceStamp::value) {
stamp = std::get<sizeof...(Args) - 1>(std::tie(std::forward<Args>(args)...));
}

auto msg = GetReduceStamp::template getMsg<MsgT>(std::forward<Args>(args)...);
auto [stamp, msg] = GetReduceStamp::template getStampMsg<MsgT>(std::forward<Args>(args)...);
msg->setCallback(cb);
auto const root_node = 0;
auto const proxy = this->getProxy();
Expand All @@ -128,12 +117,7 @@ messaging::PendingSend Reducable<ColT,IndexT,BaseProxyT>::reduce(
using MsgT = collective::ReduceTMsg<Tuple>;
using GetReduceStamp = collective::reduce::GetReduceStamp<void, Args...>;

ReduceStamp stamp{};
if constexpr (GetReduceStamp::value) {
stamp = std::get<sizeof...(Args) - 1>(std::tie(std::forward<Args>(args)...));
}

auto msg = GetReduceStamp::template getMsg<MsgT>(std::forward<Args>(args)...);
auto [stamp, msg] = GetReduceStamp::template getStampMsg<MsgT>(std::forward<Args>(args)...);
msg->setCallback(cb);
auto const root_node = 0;
auto const proxy = this->getProxy();
Expand Down

0 comments on commit 32a1bd2

Please sign in to comment.