
Commit 34ed834

Merge pull request #92 from InfiniTensor/dist-merge
feat (dist): NCCL communication library integration, allreduce operator
2 parents e3febd9 + 16870d6 commit 34ed834

File tree

17 files changed: +666 −9 lines changed

src/04kernel/CMakeLists.txt

+3
@@ -28,6 +28,9 @@ if(USE_CUDA)
     # cudnn for conv and others
     target_link_libraries(kernel PUBLIC cuda nvrtc cublas cublasLt cudnn kernel_cuda)
     target_include_directories(kernel PUBLIC ${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES})
+    list(APPEND CMAKE_MODULE_PATH ${PROJECT_SOURCE_DIR}/cmake)
+    find_package(NCCL REQUIRED)
+    target_link_libraries(kernel PUBLIC nccl)
 endif()
 if(USE_KUNLUN)
     include_directories(${KUNLUN_HOME}/XTDK/include/)

src/04kernel/cmake/FindNCCL.cmake

+165
@@ -0,0 +1,165 @@
+# Copyright (c) 2021-2022, NVIDIA CORPORATION. All rights reserved.
+#
+# From PyTorch:
+#
+# Copyright (c) 2016- Facebook, Inc (Adam Paszke)
+# Copyright (c) 2014- Facebook, Inc (Soumith Chintala)
+# Copyright (c) 2011-2014 Idiap Research Institute (Ronan Collobert)
+# Copyright (c) 2012-2014 Deepmind Technologies (Koray Kavukcuoglu)
+# Copyright (c) 2011-2012 NEC Laboratories America (Koray Kavukcuoglu)
+# Copyright (c) 2011-2013 NYU (Clement Farabet)
+# Copyright (c) 2006-2010 NEC Laboratories America (Ronan Collobert, Leon Bottou, Iain Melvin, Jason Weston)
+# Copyright (c) 2006 Idiap Research Institute (Samy Bengio)
+# Copyright (c) 2001-2004 Idiap Research Institute (Ronan Collobert, Samy Bengio, Johnny Mariethoz)
+#
+# From Caffe2:
+#
+# Copyright (c) 2016-present, Facebook Inc. All rights reserved.
+#
+# All contributions by Facebook:
+# Copyright (c) 2016 Facebook Inc.
+#
+# All contributions by Google:
+# Copyright (c) 2015 Google Inc.
+# All rights reserved.
+#
+# All contributions by Yangqing Jia:
+# Copyright (c) 2015 Yangqing Jia
+# All rights reserved.
+#
+# All contributions by Kakao Brain:
+# Copyright 2019-2020 Kakao Brain
+#
+# All contributions from Caffe:
+# Copyright(c) 2013, 2014, 2015, the respective contributors
+# All rights reserved.
+#
+# All other contributions:
+# Copyright(c) 2015, 2016 the respective contributors
+# All rights reserved.
+#
+# Caffe2 uses a copyright model similar to Caffe: each contributor holds
+# copyright over their contributions to Caffe2. The project versioning records
+# all such contribution and copyright details. If a contributor wants to further
+# mark their specific copyright on a particular contribution, they should
+# indicate their copyright solely in the commit message of the change when it is
+# committed.
+#
+# All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are met:
+#
+# 1. Redistributions of source code must retain the above copyright
+#    notice, this list of conditions and the following disclaimer.
+#
+# 2. Redistributions in binary form must reproduce the above copyright
+#    notice, this list of conditions and the following disclaimer in the
+#    documentation and/or other materials provided with the distribution.
+#
+# 3. Neither the names of Facebook, Deepmind Technologies, NYU, NEC Laboratories America
+#    and IDIAP Research Institute nor the names of its contributors may be
+#    used to endorse or promote products derived from this software without
+#    specific prior written permission.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
+# ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE
+# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
+# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
+# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
+# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
+# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
+# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
+# POSSIBILITY OF SUCH DAMAGE.
+#
+# Find the nccl libraries
+#
+# The following variables are optionally searched for defaults
+#  NCCL_ROOT: Base directory where all NCCL components are found
+#  NCCL_INCLUDE_DIR: Directory where NCCL header is found
+#  NCCL_LIB_DIR: Directory where NCCL library is found
+#
+# The following are set after configuration is done:
+#  NCCL_FOUND
+#  NCCL_INCLUDE_DIRS
+#  NCCL_LIBRARIES
+#
+# The path hints include CUDA_TOOLKIT_ROOT_DIR seeing as some folks
+# install NCCL in the same location as the CUDA toolkit.
+# See https://github.com/caffe2/caffe2/issues/1601
+
+set(NCCL_INCLUDE_DIR $ENV{NCCL_INCLUDE_DIR} CACHE PATH "Folder contains NVIDIA NCCL headers")
+set(NCCL_LIB_DIR $ENV{NCCL_LIB_DIR} CACHE PATH "Folder contains NVIDIA NCCL libraries")
+set(NCCL_VERSION $ENV{NCCL_VERSION} CACHE STRING "Version of NCCL to build with")
+
+if ($ENV{NCCL_ROOT_DIR})
+  message(WARNING "NCCL_ROOT_DIR is deprecated. Please set NCCL_ROOT instead.")
+endif()
+list(APPEND NCCL_ROOT $ENV{NCCL_ROOT_DIR} ${CUDA_TOOLKIT_ROOT_DIR})
+# Compatible layer for CMake <3.12. NCCL_ROOT will be accounted in for searching paths and libraries for CMake >=3.12.
+list(APPEND CMAKE_PREFIX_PATH ${NCCL_ROOT})
+
+find_path(NCCL_INCLUDE_DIRS
+  NAMES nccl.h
+  HINTS ${NCCL_INCLUDE_DIR})
+
+if (USE_STATIC_NCCL)
+  MESSAGE(STATUS "USE_STATIC_NCCL is set. Linking with static NCCL library.")
+  SET(NCCL_LIBNAME "nccl_static")
+  if (NCCL_VERSION)  # Prefer the versioned library if a specific NCCL version is specified
+    set(CMAKE_FIND_LIBRARY_SUFFIXES ".a.${NCCL_VERSION}" ${CMAKE_FIND_LIBRARY_SUFFIXES})
+  endif()
+else()
+  SET(NCCL_LIBNAME "nccl")
+  if (NCCL_VERSION)  # Prefer the versioned library if a specific NCCL version is specified
+    set(CMAKE_FIND_LIBRARY_SUFFIXES ".so.${NCCL_VERSION}" ${CMAKE_FIND_LIBRARY_SUFFIXES})
+  endif()
+endif()
+
+find_library(NCCL_LIBRARIES
+  NAMES ${NCCL_LIBNAME}
+  HINTS ${NCCL_LIB_DIR})
+
+include(FindPackageHandleStandardArgs)
+find_package_handle_standard_args(NCCL DEFAULT_MSG NCCL_INCLUDE_DIRS NCCL_LIBRARIES)
+
+if(NCCL_FOUND)  # obtaining NCCL version and some sanity checks
+  set (NCCL_HEADER_FILE "${NCCL_INCLUDE_DIRS}/nccl.h")
+  message (STATUS "Determining NCCL version from ${NCCL_HEADER_FILE}...")
+  set (OLD_CMAKE_REQUIRED_INCLUDES ${CMAKE_REQUIRED_INCLUDES})
+  list (APPEND CMAKE_REQUIRED_INCLUDES ${NCCL_INCLUDE_DIRS})
+  include(CheckCXXSymbolExists)
+  check_cxx_symbol_exists(NCCL_VERSION_CODE nccl.h NCCL_VERSION_DEFINED)
+
+  if (NCCL_VERSION_DEFINED)
+    set(file "${PROJECT_BINARY_DIR}/detect_nccl_version.cc")
+    file(WRITE ${file} "
+      #include <iostream>
+      #include <nccl.h>
+      int main()
+      {
+        std::cout << NCCL_MAJOR << '.' << NCCL_MINOR << '.' << NCCL_PATCH << std::endl;
+        int x;
+        ncclGetVersion(&x);
+        return x == NCCL_VERSION_CODE;
+      }
+")
+    try_run(NCCL_VERSION_MATCHED compile_result ${PROJECT_BINARY_DIR} ${file}
+      RUN_OUTPUT_VARIABLE NCCL_VERSION_FROM_HEADER
+      CMAKE_FLAGS "-DINCLUDE_DIRECTORIES=${NCCL_INCLUDE_DIRS}"
+      LINK_LIBRARIES ${NCCL_LIBRARIES})
+    if (NOT NCCL_VERSION_MATCHED)
+      message(FATAL_ERROR "Found NCCL header version and library version do not match! \
+(include: ${NCCL_INCLUDE_DIRS}, library: ${NCCL_LIBRARIES}) Please set NCCL_INCLUDE_DIR and NCCL_LIB_DIR manually.")
+    endif()
+    message(STATUS "NCCL version: ${NCCL_VERSION_FROM_HEADER}")
+  else()
+    # message(STATUS "NCCL version < 2.3.5-5")
+  endif ()
+  set (CMAKE_REQUIRED_INCLUDES ${OLD_CMAKE_REQUIRED_INCLUDES})
+
+  message(STATUS "Found NCCL (include: ${NCCL_INCLUDE_DIRS}, library: ${NCCL_LIBRARIES})")
+  mark_as_advanced(NCCL_ROOT_DIR NCCL_INCLUDE_DIRS NCCL_LIBRARIES)
+endif()

src/04kernel/cuda/include/kernel/cuda/functions.cuh

+2
@@ -6,6 +6,8 @@ namespace refactor::kernel::cuda {
     int currentDevice();
 
     void sync();
+
+    void setCudaDevice(int);
 
     void copyOut(void *dst, const void *src, size_t size);
 
src/04kernel/cuda/src/functions.cu

+4
@@ -19,4 +19,8 @@ namespace refactor::kernel::cuda {
         CUDA_ASSERT(cudaMemcpy(dst, src, size, cudaMemcpyDeviceToHost));
     }
 
+    void setCudaDevice(int id) {
+        cudaSetDevice(id);
+    }
+
 }// namespace refactor::kernel::cuda
@@ -0,0 +1,14 @@
+#ifndef KERNEL_COMMUNICATION_ATTRIBUTES_H
+#define KERNEL_COMMUNICATION_ATTRIBUTES_H
+
+namespace refactor::kernel {
+    enum class AllReduceType {
+        Sum,
+        Avg,
+        Min,
+        Max,
+        Prod
+    };
+}
+
+#endif
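
For reference, the NCCL kernel later in this commit converts these enumerators through a getRedOp helper that lives in the NCCL utilities and is not among the files shown here. A minimal sketch of what such a mapping could look like follows; the function name toNcclRedOp is purely illustrative and not the repository's code, and ncclAvg requires NCCL 2.10 or newer.

#include <nccl.h>
#include "kernel/attributes/communication.h"// AllReduceType, as included by the collector header below

namespace refactor::kernel {
    // Illustrative only: map AllReduceType onto NCCL's reduction operators.
    inline ncclRedOp_t toNcclRedOp(AllReduceType t) {
        switch (t) {
            case AllReduceType::Sum: return ncclSum;
            case AllReduceType::Avg: return ncclAvg;// NCCL >= 2.10
            case AllReduceType::Min: return ncclMin;
            case AllReduceType::Max: return ncclMax;
            case AllReduceType::Prod: return ncclProd;
        }
        return ncclSum;// unreachable for valid enumerators
    }
}// namespace refactor::kernel
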
@@ -0,0 +1,21 @@
+#ifndef KERNEL_COLLECTOR_ALL_REDUCE_H
+#define KERNEL_COLLECTOR_ALL_REDUCE_H
+
+#include "../collector.h"
+#include "kernel/attributes/communication.h"
+
+namespace refactor::kernel {
+
+    struct AllReduceCollector final : public InfoCollector {
+
+        AllReduceType type;
+
+        constexpr AllReduceCollector(decltype(_target) target, AllReduceType type_) noexcept
+            : InfoCollector(target), type(type_) {}
+
+        std::vector<KernelBox>
+        filter(TensorRefs inputs, TensorRefs outputs) const final;
+    };
+}// namespace refactor::kernel
+
+#endif
+20
@@ -0,0 +1,20 @@
+#include "kernel/collectors/all_reduce.h"
+#include "../kernels/all_reduce/nccl_kernel.hh"
+namespace refactor::kernel {
+    std::vector<KernelBox>
+    AllReduceCollector::filter(TensorRefs inputs, TensorRefs outputs) const {
+        std::vector<KernelBox> ans;
+        switch (_target) {
+            case decltype(_target)::Cpu:
+                break;
+            case decltype(_target)::Nvidia:
+                if (auto ptr = AllReduceNccl::build(type, inputs[0], outputs[0]); ptr) {
+                    ans.emplace_back(std::move(ptr));
+                }
+                break;
+            default:
+                UNREACHABLEX(void, "Unknown target");
+        }
+        return ans;
+    }
+}// namespace refactor::kernel
@@ -0,0 +1,32 @@
+#include "nccl_kernel.hh"
+
+namespace refactor::kernel {
+    using K = AllReduceNccl;
+    using DT = DataType;
+
+    K::AllReduceNccl(AllReduceType opType_, DT dataType_, size_t size_) noexcept
+        : opType(opType_), dataType(dataType_), size(size_) {}
+
+    auto K::build(AllReduceType opType_, Tensor const &input, Tensor const &output) noexcept -> KernelBox {
+#ifndef USE_CUDA
+        return nullptr;
+#endif
+        if (input.elementsSize() != output.elementsSize() ||
+            input.dataType != output.dataType) {
+            return nullptr;
+        }
+
+        return std::make_unique<K>(opType_, input.dataType, input.elementsSize());
+    }
+
+    auto K::typeId() noexcept -> size_t {
+        static uint8_t ID = 1;
+        return reinterpret_cast<size_t>(&ID);
+    }
+
+    auto K::kernelTypeId() const noexcept -> size_t { return typeId(); }
+    auto K::description() const noexcept -> std::string_view {
+        return "Performing AllReduce using NCCL";
+    }
+
+}// namespace refactor::kernel
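
A note on the typeId implementation above (the same idiom appears in NcclCommunicator further down): each class defines its own function-local static byte, so that byte's address is a process-unique identifier that cannot collide with another kernel's id. A tiny standalone illustration, not taken from the repository:

#include <cstddef>
#include <cstdint>
#include <cstdio>

// Each function owns a distinct static, so the two addresses never compare equal.
static size_t idA() { static uint8_t ID = 1; return reinterpret_cast<size_t>(&ID); }
static size_t idB() { static uint8_t ID = 1; return reinterpret_cast<size_t>(&ID); }

int main() {
    std::printf("%d %d\n", idA() == idA(), idA() == idB());// prints "1 0"
    return 0;
}
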
@@ -0,0 +1,20 @@
+#include "nccl_kernel.hh"
+#include "../../utilities/cuda/nccl_communicator.hh"
+#include <nccl.h>
+namespace refactor::kernel {
+    using K = AllReduceNccl;
+    using DT = DataType;
+    using namespace nccl;
+
+    auto K::lower(Resources &res) const noexcept -> RoutineWorkspace {
+        return [count = size,
+                redOp = getRedOp(opType),
+                ncclDataType = getNcclDataType(dataType)](Resources &res, void *workspace, void const *const *inputs, void *const *outputs) {
+            auto communicator = res.fetch<NcclCommunicator>();
+            auto input = inputs[0];
+            auto output = outputs[0];
+            checkNcclError(ncclAllReduce(input, output, count, ncclDataType,
+                                         redOp, communicator->get(), 0));// TODO: use default stream for now
+        };
+    }
+}// namespace refactor::kernel
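
For context on the ncclAllReduce call that this lambda wraps, its arguments are (send buffer, receive buffer, element count, NCCL datatype, reduction op, communicator, stream); the trailing 0 is the default stream that the TODO refers to. Below is a minimal standalone sketch, independent of this repository and illustrative only, that issues the same collective across all visible GPUs from a single process via ncclCommInitAll; error checking is omitted for brevity, and it assumes linking against cudart and nccl.

#include <cuda_runtime.h>
#include <nccl.h>
#include <cstdio>
#include <vector>

int main() {
    int nDev = 0;
    cudaGetDeviceCount(&nDev);
    if (nDev == 0) { std::printf("no CUDA device visible\n"); return 0; }

    std::vector<int> devs(nDev);
    for (int i = 0; i < nDev; ++i) devs[i] = i;
    std::vector<ncclComm_t> comms(nDev);
    ncclCommInitAll(comms.data(), nDev, devs.data());// one communicator per GPU, single process

    size_t count = 1024;
    std::vector<float *> sendBuf(nDev), recvBuf(nDev);
    std::vector<cudaStream_t> streams(nDev);
    for (int i = 0; i < nDev; ++i) {
        cudaSetDevice(i);
        cudaMalloc((void **) &sendBuf[i], count * sizeof(float));
        cudaMalloc((void **) &recvBuf[i], count * sizeof(float));
        cudaMemset(sendBuf[i], 0, count * sizeof(float));
        cudaStreamCreate(&streams[i]);
    }

    ncclGroupStart();// group the per-device calls into one collective
    for (int i = 0; i < nDev; ++i) {
        // Same argument order as the lowered routine above: send, recv, count, dtype, op, comm, stream.
        ncclAllReduce(sendBuf[i], recvBuf[i], count, ncclFloat32, ncclSum, comms[i], streams[i]);
    }
    ncclGroupEnd();

    for (int i = 0; i < nDev; ++i) {
        cudaSetDevice(i);
        cudaStreamSynchronize(streams[i]);
        cudaFree(sendBuf[i]);
        cudaFree(recvBuf[i]);
        ncclCommDestroy(comms[i]);
    }
    std::printf("allreduce finished on %d device(s)\n", nDev);
    return 0;
}
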
@@ -0,0 +1,28 @@
+#ifndef KERNEL_ALLREDUCE_NCCL_KERNEL_HH
+#define KERNEL_ALLREDUCE_NCCL_KERNEL_HH
+
+#include "kernel/collectors/all_reduce.h"
+#include "kernel/tensor.h"
+
+namespace refactor::kernel {
+
+    struct AllReduceNccl final : public Kernel {
+        AllReduceType opType;
+        DataType dataType;
+        size_t size;
+
+        AllReduceNccl(AllReduceType, DataType, size_t) noexcept;
+
+        static KernelBox build(AllReduceType, Tensor const &, Tensor const &) noexcept;
+        static size_t typeId() noexcept;
+
+        size_t kernelTypeId() const noexcept final;
+        std::string_view description() const noexcept final;
+#ifdef USE_CUDA
+        RoutineWorkspace lower(Resources &) const noexcept final;
+#endif
+    };
+
+}// namespace refactor::kernel
+
+#endif// KERNEL_ALLREDUCE_NCCL_KERNEL_HH
@@ -0,0 +1,61 @@
+#include "common.h"
+#include "nccl_communicator.hh"
+#include <chrono>
+#include <cstdlib>
+#include <filesystem>
+#include <fstream>
+#include <thread>
+
+
+namespace refactor::kernel::nccl {
+    NcclCommunicator::NcclCommunicator(int worldSize, int rank) : worldSize_(worldSize), rank_(rank) {
+        const std::string filePath("./nccl_id.bin");
+
+        ncclUniqueId commId;
+
+        if (rank == 0) {
+            checkNcclError(ncclGetUniqueId(&commId));
+            std::ofstream ofs(filePath, std::ios::binary);
+            ofs.write((char *) &commId, sizeof(ncclUniqueId));
+
+        } else {
+            auto begin = std::chrono::steady_clock::now();
+            while (!std::filesystem::exists(filePath)) {
+                auto now = std::chrono::steady_clock::now();
+                ASSERT(now < begin + std::chrono::seconds(10),
+                       "time limit (10s) exceeded.");
+                std::this_thread::sleep_for(std::chrono::milliseconds(100));
+            }
+            std::ifstream ifs(filePath, std::ios::binary);
+            ifs.read((char *) &commId, sizeof(ncclUniqueId));
+        }
+        checkNcclError(ncclCommInitRank(&comm, worldSize, commId, rank));
+
+        if (rank == 0) {
+            std::filesystem::remove(filePath);
+        }
+
+        printf("Rank %d established NCCL communicator.\n", rank);
+    }
+
+    NcclCommunicator::~NcclCommunicator() {
+        checkNcclError(ncclCommFinalize(comm));
+        checkNcclError(ncclCommDestroy(comm));
+    }
+
+    auto NcclCommunicator::typeId() noexcept -> size_t {
+        static uint8_t ID = 1;
+        return reinterpret_cast<size_t>(&ID);
+    }
+    auto NcclCommunicator::build(int worldSize, int rank) noexcept -> runtime::ResourceBox {
+        return std::make_unique<NcclCommunicator>(worldSize, rank);
+    }
+
+    auto NcclCommunicator::resourceTypeId() const noexcept -> size_t {
+        return typeId();
+    }
+    auto NcclCommunicator::description() const noexcept -> std::string_view {
+        return "NcclCommunicator";
+    }
+
+}// namespace refactor::kernel::nccl
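
Putting the pieces together: each rank runs in its own process, selects a GPU, and builds this communicator resource before any NCCL kernel executes; the communicator is then what res.fetch<NcclCommunicator>() retrieves inside the lowered allreduce routine (how the ResourceBox gets registered into Resources is not part of this diff). The sketch below is hedged and illustrative: initRank, the 1:1 rank-to-GPU mapping, and the include paths are assumptions, not the repository's runner code.

// Hedged sketch, not part of the commit: per-rank setup before launching NCCL kernels.
#include "kernel/cuda/functions.cuh"// setCudaDevice, added earlier in this commit
#include "nccl_communicator.hh"     // include path illustrative

// All worldSize processes must start from the same working directory so that
// rank 0's ./nccl_id.bin handshake file is visible to the other ranks.
auto initRank(int worldSize, int rank) {
    refactor::kernel::cuda::setCudaDevice(rank);// illustrative 1:1 rank-to-GPU mapping
    // ncclCommInitRank inside the constructor blocks until all worldSize ranks have joined.
    return refactor::kernel::nccl::NcclCommunicator::build(worldSize, rank);// runtime::ResourceBox
}
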
