Skip to content

Commit 52ccbf4

Browse files
Elias Ellison authored and pytorchmergebot committed
Lock thread/block computation (pytorch#73800)
Summary: Pull Request resolved: pytorch#73800. Test Plan: Imported from OSS. Reviewed By: navahgar. Differential Revision: D34647281. Pulled By: eellison. fbshipit-source-id: adbdaf24191c4c1b85e0b62564388f2481002ed2 (cherry picked from commit 6cf3801)
1 parent 4a74285 commit 52ccbf4

File tree

3 files changed

+74
-53
lines changed

3 files changed

+74
-53
lines changed

test/cpp/tensorexpr/test_dynamic_shapes.cpp

+61-50
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
#include <gtest/gtest.h>
22

3+
#include <ATen/code_template.h>
4+
#include <c10/core/DeviceType.h>
35
#include <test/cpp/tensorexpr/test_base.h>
46
#include <torch/csrc/jit/ir/ir.h>
57
#include <torch/csrc/jit/ir/irparser.h>
@@ -629,59 +631,68 @@ TEST(DynamicShapes, GraphFromModel) {
629631

630632
TEST(DynamicShapes, MultiThreadedExecution) {
631633
#ifdef TORCH_ENABLE_LLVM
632-
std::shared_ptr<Graph> graph = std::make_shared<Graph>();
633-
const auto graph_string = R"IR(
634-
graph(%x : Float(SS(-2), SS(-3), requires_grad=0, device=cpu),
635-
%y : Float(SS(-2), SS(-3), requires_grad=0, device=cpu),
634+
const auto graph_template = R"IR(
635+
graph(%x : Float(SS(-2), SS(-3), requires_grad=0, device=${device}),
636+
%y : Float(SS(-2), SS(-3), requires_grad=0, device=${device}),
636637
%SS_2 : int,
637638
%SS_3 : int):
638-
%3 : Float(SS(-2), SS(-3), requires_grad=0, device=cpu) = aten::tanh(%x)
639-
%4 : Float(SS(-2), SS(-3), requires_grad=0, device=cpu) = aten::erf(%3)
640-
%5 : Float(SS(-2), SS(-3), requires_grad=0, device=cpu) = aten::mul(%4, %y)
639+
%3 : Float(SS(-2), SS(-3), requires_grad=0, device=${device}) = aten::tanh(%x)
640+
%4 : Float(SS(-2), SS(-3), requires_grad=0, device=${device}) = aten::erf(%3)
641+
%5 : Float(SS(-2), SS(-3), requires_grad=0, device=${device}) = aten::mul(%4, %y)
641642
return (%5))IR";
642-
torch::jit::parseIR(graph_string, graph.get());
643-
644-
std::vector<int64_t> symbolic_shape_inputs = {-2, -3};
645-
646-
std::vector<torch::jit::StrideInput> input_desc = {
647-
torch::jit::StrideInput::TENSOR_CONT};
648-
std::unordered_map<
649-
const torch::jit::Value*,
650-
std::vector<torch::jit::StrideInput>>
651-
symbolic_strides;
652-
symbolic_strides[graph->inputs().at(0)] = input_desc;
653-
symbolic_strides[graph->inputs().at(1)] = input_desc;
654-
symbolic_strides[graph->outputs().at(0)] = input_desc;
655-
656-
TensorExprKernel kernel(
657-
graph, {}, symbolic_shape_inputs, false, symbolic_strides);
658-
659-
auto run_kernel = [&](int dim1, int dim2) {
660-
auto a =
661-
at::rand({dim1, dim2}, at::TensorOptions(at::kCPU).dtype(at::kFloat));
662-
auto b =
663-
at::rand({dim1, dim2}, at::TensorOptions(at::kCPU).dtype(at::kFloat));
664-
665-
auto ref = at::mul(at::erf(at::tanh(a)), b);
666-
667-
std::vector<IValue> stack = fmap<IValue>(std::vector<at::Tensor>({a, b}));
668-
stack.emplace_back(dim1);
669-
stack.emplace_back(dim2);
670-
kernel.run(stack);
671-
672-
auto o = stack[0].toTensor();
673-
ASSERT_TRUE(at::allclose(o, ref));
674-
};
675-
676-
// Run the kernel in parallel to ensure that the run() method calls in
677-
// TensorExprKernel are not changing any state.
678-
constexpr size_t kNumThreads = 4;
679-
std::vector<std::thread> threads;
680-
for (size_t id = 0; id < kNumThreads; ++id) {
681-
threads.emplace_back(run_kernel, id + 5, id + 20);
682-
}
683-
for (auto& t : threads) {
684-
t.join();
643+
for (bool use_cuda : {false, true}) {
644+
if (!torch::cuda::is_available() && use_cuda) {
645+
continue;
646+
}
647+
auto device = use_cuda ? at::kCUDA : at::kCPU;
648+
at::jit::TemplateEnv env;
649+
env.s("device", use_cuda ? "cuda:0" : "cpu");
650+
const auto graph_string = format(graph_template, env);
651+
std::shared_ptr<Graph> graph = std::make_shared<Graph>();
652+
torch::jit::parseIR(graph_string, graph.get());
653+
654+
std::vector<int64_t> symbolic_shape_inputs = {-2, -3};
655+
656+
std::vector<torch::jit::StrideInput> input_desc = {
657+
torch::jit::StrideInput::TENSOR_CONT};
658+
std::unordered_map<
659+
const torch::jit::Value*,
660+
std::vector<torch::jit::StrideInput>>
661+
symbolic_strides;
662+
symbolic_strides[graph->inputs().at(0)] = input_desc;
663+
symbolic_strides[graph->inputs().at(1)] = input_desc;
664+
symbolic_strides[graph->outputs().at(0)] = input_desc;
665+
666+
TensorExprKernel kernel(
667+
graph, {}, symbolic_shape_inputs, false, symbolic_strides);
668+
669+
auto run_kernel = [&](int dim1, int dim2) {
670+
auto a =
671+
at::rand({dim1, dim2}, at::TensorOptions(device).dtype(at::kFloat));
672+
auto b =
673+
at::rand({dim1, dim2}, at::TensorOptions(device).dtype(at::kFloat));
674+
675+
auto ref = at::mul(at::erf(at::tanh(a)), b);
676+
677+
std::vector<IValue> stack = fmap<IValue>(std::vector<at::Tensor>({a, b}));
678+
stack.emplace_back(dim1);
679+
stack.emplace_back(dim2);
680+
kernel.run(stack);
681+
682+
auto o = stack[0].toTensor();
683+
ASSERT_TRUE(at::allclose(o, ref));
684+
};
685+
686+
// Run the kernel in parallel to ensure that the run() method calls in
687+
// TensorExprKernel are not changing any state.
688+
constexpr size_t kNumThreads = 4;
689+
std::vector<std::thread> threads;
690+
for (size_t id = 0; id < kNumThreads; ++id) {
691+
threads.emplace_back(run_kernel, id + 5, id + 20);
692+
}
693+
for (auto& t : threads) {
694+
t.join();
695+
}
685696
}
686697
#endif
687698
}

torch/csrc/jit/tensorexpr/cuda_codegen.cpp

+12-3
Original file line numberDiff line numberDiff line change
@@ -1170,15 +1170,24 @@ void CudaCodeGen::call_raw(const std::vector<void*>& raw_args) {
11701170
gpu_block_extents_v[i] = immediateAs<int64_t>(gpu_block_extents[i]);
11711171
continue;
11721172
}
1173-
gpu_block_extents_v[i] = block_extents_eval_[i].value<int64_t>(extent_args);
1173+
{
1174+
// invocation of block_extents_eval_ isn't thread safe and this function
1175+
// may be invoked by multiple threads
1176+
std::lock_guard<std::mutex> guard(eval_lock_);
1177+
gpu_block_extents_v[i] =
1178+
block_extents_eval_[i].value<int64_t>(extent_args);
1179+
}
11741180
}
11751181
for (size_t i = 0; i < gpu_thread_extents.size(); i++) {
11761182
if (gpu_thread_extents[i]->isConstant()) {
11771183
gpu_thread_extents_v[i] = immediateAs<int64_t>(gpu_thread_extents[i]);
11781184
continue;
11791185
}
1180-
gpu_thread_extents_v[i] =
1181-
thread_extents_eval_[i].value<int64_t>(extent_args);
1186+
{
1187+
std::lock_guard<std::mutex> guard(eval_lock_);
1188+
gpu_thread_extents_v[i] =
1189+
thread_extents_eval_[i].value<int64_t>(extent_args);
1190+
}
11821191
}
11831192

11841193
// Skip launching the kernel if there are no elements to process.

torch/csrc/jit/tensorexpr/cuda_codegen.h

+1
Original file line numberDiff line numberDiff line change
@@ -273,6 +273,7 @@ class TORCH_CUDA_CU_API CudaCodeGen : public CodeGen {
273273
std::unique_ptr<CudaAnalysis> cuda_analysis_;
274274
std::unique_ptr<GPUMetaVarRewriter> metavar_rewriter_;
275275
std::unordered_set<std::string> taken_func_names;
276+
std::mutex eval_lock_;
276277
CUfunction function_;
277278
bool has_random_ = false;
278279
int thread_block_size_ = -1;

0 commit comments

Comments (0)