|
1 | 1 | #include <gtest/gtest.h>
|
2 | 2 |
|
| 3 | +#include <ATen/code_template.h> |
| 4 | +#include <c10/core/DeviceType.h> |
3 | 5 | #include <test/cpp/tensorexpr/test_base.h>
|
4 | 6 | #include <torch/csrc/jit/ir/ir.h>
|
5 | 7 | #include <torch/csrc/jit/ir/irparser.h>
|
@@ -629,59 +631,68 @@ TEST(DynamicShapes, GraphFromModel) {
|
629 | 631 |
|
630 | 632 | TEST(DynamicShapes, MultiThreadedExecution) {
|
631 | 633 | #ifdef TORCH_ENABLE_LLVM
|
632 |
| - std::shared_ptr<Graph> graph = std::make_shared<Graph>(); |
633 |
| - const auto graph_string = R"IR( |
634 |
| - graph(%x : Float(SS(-2), SS(-3), requires_grad=0, device=cpu), |
635 |
| - %y : Float(SS(-2), SS(-3), requires_grad=0, device=cpu), |
| 634 | + const auto graph_template = R"IR( |
| 635 | + graph(%x : Float(SS(-2), SS(-3), requires_grad=0, device=${device}), |
| 636 | + %y : Float(SS(-2), SS(-3), requires_grad=0, device=${device}), |
636 | 637 | %SS_2 : int,
|
637 | 638 | %SS_3 : int):
|
638 |
| - %3 : Float(SS(-2), SS(-3), requires_grad=0, device=cpu) = aten::tanh(%x) |
639 |
| - %4 : Float(SS(-2), SS(-3), requires_grad=0, device=cpu) = aten::erf(%3) |
640 |
| - %5 : Float(SS(-2), SS(-3), requires_grad=0, device=cpu) = aten::mul(%4, %y) |
| 639 | + %3 : Float(SS(-2), SS(-3), requires_grad=0, device=${device}) = aten::tanh(%x) |
| 640 | + %4 : Float(SS(-2), SS(-3), requires_grad=0, device=${device}) = aten::erf(%3) |
| 641 | + %5 : Float(SS(-2), SS(-3), requires_grad=0, device=${device}) = aten::mul(%4, %y) |
641 | 642 | return (%5))IR";
|
642 |
| - torch::jit::parseIR(graph_string, graph.get()); |
643 |
| - |
644 |
| - std::vector<int64_t> symbolic_shape_inputs = {-2, -3}; |
645 |
| - |
646 |
| - std::vector<torch::jit::StrideInput> input_desc = { |
647 |
| - torch::jit::StrideInput::TENSOR_CONT}; |
648 |
| - std::unordered_map< |
649 |
| - const torch::jit::Value*, |
650 |
| - std::vector<torch::jit::StrideInput>> |
651 |
| - symbolic_strides; |
652 |
| - symbolic_strides[graph->inputs().at(0)] = input_desc; |
653 |
| - symbolic_strides[graph->inputs().at(1)] = input_desc; |
654 |
| - symbolic_strides[graph->outputs().at(0)] = input_desc; |
655 |
| - |
656 |
| - TensorExprKernel kernel( |
657 |
| - graph, {}, symbolic_shape_inputs, false, symbolic_strides); |
658 |
| - |
659 |
| - auto run_kernel = [&](int dim1, int dim2) { |
660 |
| - auto a = |
661 |
| - at::rand({dim1, dim2}, at::TensorOptions(at::kCPU).dtype(at::kFloat)); |
662 |
| - auto b = |
663 |
| - at::rand({dim1, dim2}, at::TensorOptions(at::kCPU).dtype(at::kFloat)); |
664 |
| - |
665 |
| - auto ref = at::mul(at::erf(at::tanh(a)), b); |
666 |
| - |
667 |
| - std::vector<IValue> stack = fmap<IValue>(std::vector<at::Tensor>({a, b})); |
668 |
| - stack.emplace_back(dim1); |
669 |
| - stack.emplace_back(dim2); |
670 |
| - kernel.run(stack); |
671 |
| - |
672 |
| - auto o = stack[0].toTensor(); |
673 |
| - ASSERT_TRUE(at::allclose(o, ref)); |
674 |
| - }; |
675 |
| - |
676 |
| - // Run the kernel in parallel to ensure that the run() method calls in |
677 |
| - // TensorExprKernel are not changing any state. |
678 |
| - constexpr size_t kNumThreads = 4; |
679 |
| - std::vector<std::thread> threads; |
680 |
| - for (size_t id = 0; id < kNumThreads; ++id) { |
681 |
| - threads.emplace_back(run_kernel, id + 5, id + 20); |
682 |
| - } |
683 |
| - for (auto& t : threads) { |
684 |
| - t.join(); |
| 643 | + for (bool use_cuda : {false, true}) { |
| 644 | + if (!torch::cuda::is_available() && use_cuda) { |
| 645 | + continue; |
| 646 | + } |
| 647 | + auto device = use_cuda ? at::kCUDA : at::kCPU; |
| 648 | + at::jit::TemplateEnv env; |
| 649 | + env.s("device", use_cuda ? "cuda:0" : "cpu"); |
| 650 | + const auto graph_string = format(graph_template, env); |
| 651 | + std::shared_ptr<Graph> graph = std::make_shared<Graph>(); |
| 652 | + torch::jit::parseIR(graph_string, graph.get()); |
| 653 | + |
| 654 | + std::vector<int64_t> symbolic_shape_inputs = {-2, -3}; |
| 655 | + |
| 656 | + std::vector<torch::jit::StrideInput> input_desc = { |
| 657 | + torch::jit::StrideInput::TENSOR_CONT}; |
| 658 | + std::unordered_map< |
| 659 | + const torch::jit::Value*, |
| 660 | + std::vector<torch::jit::StrideInput>> |
| 661 | + symbolic_strides; |
| 662 | + symbolic_strides[graph->inputs().at(0)] = input_desc; |
| 663 | + symbolic_strides[graph->inputs().at(1)] = input_desc; |
| 664 | + symbolic_strides[graph->outputs().at(0)] = input_desc; |
| 665 | + |
| 666 | + TensorExprKernel kernel( |
| 667 | + graph, {}, symbolic_shape_inputs, false, symbolic_strides); |
| 668 | + |
| 669 | + auto run_kernel = [&](int dim1, int dim2) { |
| 670 | + auto a = |
| 671 | + at::rand({dim1, dim2}, at::TensorOptions(device).dtype(at::kFloat)); |
| 672 | + auto b = |
| 673 | + at::rand({dim1, dim2}, at::TensorOptions(device).dtype(at::kFloat)); |
| 674 | + |
| 675 | + auto ref = at::mul(at::erf(at::tanh(a)), b); |
| 676 | + |
| 677 | + std::vector<IValue> stack = fmap<IValue>(std::vector<at::Tensor>({a, b})); |
| 678 | + stack.emplace_back(dim1); |
| 679 | + stack.emplace_back(dim2); |
| 680 | + kernel.run(stack); |
| 681 | + |
| 682 | + auto o = stack[0].toTensor(); |
| 683 | + ASSERT_TRUE(at::allclose(o, ref)); |
| 684 | + }; |
| 685 | + |
| 686 | + // Run the kernel in parallel to ensure that the run() method calls in |
| 687 | + // TensorExprKernel are not changing any state. |
| 688 | + constexpr size_t kNumThreads = 4; |
| 689 | + std::vector<std::thread> threads; |
| 690 | + for (size_t id = 0; id < kNumThreads; ++id) { |
| 691 | + threads.emplace_back(run_kernel, id + 5, id + 20); |
| 692 | + } |
| 693 | + for (auto& t : threads) { |
| 694 | + t.join(); |
| 695 | + } |
685 | 696 | }
|
686 | 697 | #endif
|
687 | 698 | }
|
|
0 commit comments