@@ -38,6 +38,7 @@ namespace refactor::onnx {
38
38
opType == " onnx::Neg" ? Ty::Neg :
39
39
opType == " onnx::Identity" ? Ty::Identity:
40
40
opType == " onnx::HardSwish" ? Ty::HardSwish :
41
+ opType == " onnx::Exp" ? Ty::Exp :
41
42
UNREACHABLEX (Ty, " Unsupported unary operator: {}" , opType);
42
43
// clang-format on
43
44
@@ -134,6 +135,10 @@ namespace refactor::onnx {
134
135
static uint8_t ID = 22 ;
135
136
return reinterpret_cast <size_t >(&ID);
136
137
}
138
+ case Ty::Exp: {
139
+ static uint8_t ID = 23 ;
140
+ return reinterpret_cast <size_t >(&ID);
141
+ }
137
142
default :
138
143
UNREACHABLE ();
139
144
}
@@ -165,6 +170,7 @@ namespace refactor::onnx {
165
170
case Ty::Neg : return " onnx::Neg" ;
166
171
case Ty::Identity : return " onnx::Identity" ;
167
172
case Ty::HardSwish : return " onnx::HardSwish" ;
173
+ case Ty::Exp : return " onnx::Exp" ;
168
174
default : UNREACHABLE ();
169
175
}
170
176
// clang-format on
@@ -194,7 +200,7 @@ namespace refactor::onnx {
194
200
Ty::Cos, Ty::Cosh,
195
201
Ty::Sin, Ty::Sinh,
196
202
Ty::Tan, Ty::HardSwish},
197
- {Ty::Tanh, Ty::Sqrt, Ty::Sigmoid, Ty::Log},
203
+ {Ty::Tanh, Ty::Sqrt, Ty::Sigmoid, Ty::Log, Ty::Exp },
198
204
{Ty::Neg},
199
205
{Ty::Identity}};
200
206
if (SET[0 ].contains (type)) {
@@ -294,6 +300,7 @@ namespace refactor::onnx {
294
300
case Ty::Neg : type_ = Ty_::Neg ; break ;
295
301
case Ty::Identity : return std::make_unique<computation::Identity>();
296
302
case Ty::HardSwish : type_ = Ty_::HardSwish ; break ;
303
+ case Ty::Exp : type_ = Ty_::Exp ; break ;
297
304
default : UNREACHABLE ();
298
305
}
299
306
// clang-format on
0 commit comments