Skip to content

Commit 3cb6f5d

Browse files
authored
Merge pull request #91 from InfiniTensor/dev-exp
feat(kernel): 添加exp算子
2 parents 34ed834 + 3199eb5 commit 3cb6f5d

File tree

11 files changed

+57
-4
lines changed

11 files changed

+57
-4
lines changed

src/04kernel/include/kernel/collectors/simple_unary.h

+1
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ namespace refactor::kernel {
2626
Neg,
2727
Not,
2828
HardSwish,
29+
Exp,
2930
};
3031

3132
std::string_view unaryName(SimpleUnaryType type);

src/04kernel/src/collectors/simple_unary.cc

+1
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ namespace refactor::kernel {
2929
CASE(Sqrt);
3030
CASE(Sigmoid);
3131
CASE(Erf);
32+
CASE(Exp);
3233
CASE(Neg);
3334
CASE(Not);
3435
CASE(HardSwish);

src/04kernel/src/kernels/simple_unary/cpu_kernel.cc

+8
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ namespace refactor::kernel {
2020
Op::Neg,
2121
Op::Erf,
2222
Op::HardSwish,
23+
Op::Exp,
2324
};
2425
return supportedOp.contains(op) && a.dataType.isCpuNumberic()
2526
? std::make_unique<K>(op, a.dataType, a.elementsSize())
@@ -185,6 +186,13 @@ namespace refactor::kernel {
185186
default:
186187
UNREACHABLE();
187188
}
189+
case Op::Exp:
190+
switch (dataType) {
191+
CASE(std::exp, F32);
192+
CASE(std::exp, F64);
193+
default:
194+
UNREACHABLE();
195+
}
188196
default:
189197
UNREACHABLE();
190198
}

src/04kernel/src/kernels/simple_unary/cuda_kernel.cc

+16-3
Original file line numberDiff line numberDiff line change
@@ -17,9 +17,17 @@ namespace refactor::kernel {
1717

1818
auto K::build(Op op, Tensor const &a) noexcept -> KernelBox {
1919
static const std::unordered_set<Op>
20-
supportedOp{Op::Abs, Op::Relu, Op::Sqrt,
21-
Op::Sigmoid, Op::Tanh, Op::Neg,
22-
Op::Erf, Op::HardSwish};
20+
supportedOp{
21+
Op::Abs,
22+
Op::Relu,
23+
Op::Sqrt,
24+
Op::Sigmoid,
25+
Op::Tanh,
26+
Op::Neg,
27+
Op::Erf,
28+
Op::HardSwish,
29+
Op::Exp,
30+
};
2331
#ifndef USE_CUDA
2432
return nullptr;
2533
#endif
@@ -159,6 +167,11 @@ extern "C" __global__ void kernel(
159167
{__(Op::HardSwish, DT::FP16), "x * __hmax(CUDART_ZERO_FP16, __hmin(CUDART_ONE_FP16, hrcp(__float2half(6.f)) * x + hrcp(__float2half(2.f))))"},
160168
{__(Op::HardSwish, DT::F64 ), "x * fmax(0.0, fmin(1.0, fma(1.0/6.0, x, 0.5)))"},
161169

170+
{__(Op::Exp, DT::F32 ), "expf(x)"},
171+
{__(Op::Exp, DT::F64 ), "exp(x)"},
172+
{__(Op::Exp, DT::FP16), "hexp(x)"},
173+
{__(Op::Exp, DT::BF16), "hexp(x)"},
174+
162175
};
163176
// clang-format on
164177

src/04kernel/test/kernels/simple_unary/test_cpu.cpp

+1
Original file line numberDiff line numberDiff line change
@@ -59,4 +59,5 @@ TEST(kernel, SimpleUnaryCpu) {
5959
testOp(SimpleUnaryType::Erf, std::erf);
6060
testOpWithData(SimpleUnaryType::HardSwish,
6161
VecFloat{0.000000, 0.666667, 1.666667, 3.000000, 4.000000, 5.000000});
62+
testOpWithData(SimpleUnaryType::Exp, VecFloat{1.000000, 2.718282, 7.389056, 20.085537, 54.598148, 148.413162});
6263
}

src/04kernel/test/kernels/simple_unary/test_cuda.cpp

+1
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@ TEST(kernel, SimpleUnaryCuda) {
5353
testOp(SimpleUnaryType::Tanh);
5454
testOp(SimpleUnaryType::Erf);
5555
testOp(SimpleUnaryType::HardSwish);
56+
testOp(SimpleUnaryType::Exp);
5657
}
5758

5859
#endif

src/05computation/src/operators/simple_unary.cc

+6
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,10 @@ namespace refactor::computation {
8585
static uint8_t ID = 20;
8686
return reinterpret_cast<size_t>(&ID);
8787
}
88+
case SimpleUnaryType::Exp: {
89+
static uint8_t ID = 21;
90+
return reinterpret_cast<size_t>(&ID);
91+
}
8892
default:
8993
UNREACHABLE();
9094
}
@@ -134,6 +138,8 @@ namespace refactor::computation {
134138
return "Not";
135139
case SimpleUnaryType::HardSwish:
136140
return "HardSwish";
141+
case SimpleUnaryType::Exp:
142+
return "Exp";
137143
default:
138144
UNREACHABLE();
139145
}

src/07onnx/src/operators.cpp

+1
Original file line numberDiff line numberDiff line change
@@ -120,6 +120,7 @@ namespace refactor::onnx {
120120
REGISTER(Neg , SimpleUnary );
121121
REGISTER(Identity , SimpleUnary );
122122
REGISTER(HardSwish , SimpleUnary );
123+
REGISTER(Exp , SimpleUnary );
123124
REGISTER(Slice , Slice );
124125
REGISTER(Softmax , Softmax );
125126
REGISTER(Split , Split );

src/07onnx/src/operators/simple_unary.cc

+8-1
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@ namespace refactor::onnx {
3838
opType == "onnx::Neg" ? Ty::Neg :
3939
opType == "onnx::Identity"? Ty::Identity:
4040
opType == "onnx::HardSwish" ? Ty::HardSwish :
41+
opType == "onnx::Exp" ? Ty::Exp :
4142
UNREACHABLEX(Ty, "Unsupported unary operator: {}", opType);
4243
// clang-format on
4344

@@ -134,6 +135,10 @@ namespace refactor::onnx {
134135
static uint8_t ID = 22;
135136
return reinterpret_cast<size_t>(&ID);
136137
}
138+
case Ty::Exp: {
139+
static uint8_t ID = 23;
140+
return reinterpret_cast<size_t>(&ID);
141+
}
137142
default:
138143
UNREACHABLE();
139144
}
@@ -165,6 +170,7 @@ namespace refactor::onnx {
165170
case Ty::Neg : return "onnx::Neg";
166171
case Ty::Identity : return "onnx::Identity";
167172
case Ty::HardSwish : return "onnx::HardSwish";
173+
case Ty::Exp : return "onnx::Exp";
168174
default: UNREACHABLE();
169175
}
170176
// clang-format on
@@ -194,7 +200,7 @@ namespace refactor::onnx {
194200
Ty::Cos, Ty::Cosh,
195201
Ty::Sin, Ty::Sinh,
196202
Ty::Tan, Ty::HardSwish},
197-
{Ty::Tanh, Ty::Sqrt, Ty::Sigmoid, Ty::Log},
203+
{Ty::Tanh, Ty::Sqrt, Ty::Sigmoid, Ty::Log, Ty::Exp},
198204
{Ty::Neg},
199205
{Ty::Identity}};
200206
if (SET[0].contains(type)) {
@@ -294,6 +300,7 @@ namespace refactor::onnx {
294300
case Ty::Neg : type_ = Ty_::Neg ; break;
295301
case Ty::Identity : return std::make_unique<computation::Identity>();
296302
case Ty::HardSwish : type_ = Ty_::HardSwish ; break;
303+
case Ty::Exp : type_ = Ty_::Exp ; break;
297304
default: UNREACHABLE();
298305
}
299306
// clang-format on

src/07onnx/src/operators/simple_unary.hh

+1
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ namespace refactor::onnx {
1717
Cos,
1818
Cosh,
1919
Erf,
20+
Exp,
2021
HardSwish,
2122
Identity,
2223
Log,

src/07onnx/test/test_simple_unary.cpp

+13
Original file line numberDiff line numberDiff line change
@@ -36,4 +36,17 @@ TEST(infer, SimpleUnary) {
3636
ASSERT_EQ(y->dataType, DataType::F32);
3737
ASSERT_EQ(y->shape, (Shape{DimExpr(2), DimExpr(3)}));
3838
}
39+
{
40+
// Exp Test
41+
auto edges = Edges{
42+
{Tensor::share(DataType::F32, Shape{DimExpr(2), DimExpr(3)}, {}), ""}};
43+
count_t inputs[]{0};
44+
auto infered = SimpleUnary(SimpleUnaryType::Exp).infer(TensorRefs(edges, inputs), {true});
45+
ASSERT_TRUE(infered.isOk());
46+
auto outputs = std::move(infered.unwrap());
47+
ASSERT_EQ(outputs.size(), 1);
48+
auto y = std::move(outputs[0]);
49+
ASSERT_EQ(y->dataType, DataType::F32);
50+
ASSERT_EQ(y->shape, (Shape{DimExpr(2), DimExpr(3)}));
51+
}
3952
}

0 commit comments

Comments
 (0)