Skip to content

Commit bffabe6

Browse files
committed
[tmva][sofie] Apply fixes when batch size is given by user
When input shape is parametric, give possiibility to fix batch size if provided by user in RModel::Generate
1 parent ba52dca commit bffabe6

File tree

1 file changed

+5
-6
lines changed

1 file changed

+5
-6
lines changed

tmva/sofie/src/RModel.cxx

+5-6
Original file line numberDiff line numberDiff line change
@@ -262,7 +262,8 @@ void RModel::Initialize(int batchSize, bool verbose) {
262262
// loop on inputs and see if shape can be full specified
263263
// if the batch size is provided it can be used to specify the full shape
264264
// Add the full specified tensors in fReadyInputTensors collection
265-
for (auto &input : fInputTensorInfos) {
265+
auto originalInputTensorInfos = fInputTensorInfos; // need to copy because we may delete elements
266+
for (auto &input : originalInputTensorInfos) {
266267
std::cout << "looking at the tensor " << input.first << std::endl;
267268
// if a batch size is provided convert batch size
268269
// assume is parametrized as "bs" or "batch_size"
@@ -288,12 +289,10 @@ void RModel::Initialize(int batchSize, bool verbose) {
288289
}
289290
auto shape = ConvertShapeToInt(input.second.shape);
290291
if (!shape.empty()) {
291-
#if 0
292292
// remove from the tensor info old dynamic shape
293293
fInputTensorInfos.erase(input.first);
294294
// add to the ready input tensor information the new fixed shape
295-
AddInputTensorInfo(input.first, input.second.type, shape);
296-
#endif
295+
AddInputTensorInfo(input.first, input.second.type, shape);
297296
}
298297
// store the parameters of the input tensors
299298
else {
@@ -647,7 +646,7 @@ void RModel::ReadInitializedTensorsFromFile(long pos) {
647646
fGC += " std::ifstream f;\n";
648647
fGC += " f.open(filename);\n";
649648
fGC += " if (!f.is_open()) {\n";
650-
fGC += " throw std::runtime_error(\"tmva-sofie failed to open file for input weights\");\n";
649+
fGC += " throw std::runtime_error(\"tmva-sofie failed to open file \" + filename + \" for input weights\");\n";
651650
fGC += " }\n";
652651

653652
if(fIsGNNComponent) {
@@ -783,7 +782,7 @@ long RModel::WriteInitializedTensorsToFile(std::string filename) {
783782
}
784783
if (!f.is_open())
785784
throw
786-
std::runtime_error("tmva-sofie failed to open file for tensor weight data");
785+
std::runtime_error("tmva-sofie failed to open file " + filename + " for tensor weight data");
787786
for (auto& i: fInitializedTensors) {
788787
if (i.second.type() == ETensorType::FLOAT) {
789788
size_t length = 1;

0 commit comments

Comments
 (0)