@@ -23,6 +23,10 @@ limitations under the License.
23
23
#include < utility>
24
24
#include < vector>
25
25
26
+ #include " absl/strings/match.h"
27
+ #include " absl/strings/numbers.h"
28
+ #include " absl/strings/str_join.h"
29
+ #include " absl/strings/str_split.h"
26
30
#include " flatbuffers/flatbuffer_builder.h" // from @flatbuffers
27
31
#include " tensorflow/lite/experimental/acceleration/configuration/configuration_generated.h"
28
32
#include " tensorflow/lite/experimental/acceleration/mini_benchmark/benchmark_result_evaluator.h"
@@ -54,6 +58,56 @@ std::unique_ptr<FlatBufferBuilder> CopyModel(
54
58
return copy;
55
59
}
56
60
61
+ // A simple holder for file descriptor that will close the file descriptor at
62
+ // destruction time.
63
+ class FdHolder {
64
+ public:
65
+ explicit FdHolder (int fd) : fd_(fd) {}
66
+
67
+ // Move only.
68
+ FdHolder (FdHolder&& other) = default ;
69
+ FdHolder& operator =(FdHolder&& other) = default ;
70
+
71
+ ~FdHolder () {
72
+ if (fd_ > 0 ) {
73
+ close (fd_);
74
+ }
75
+ }
76
+
77
+ private:
78
+ int fd_;
79
+ };
80
+
81
+ // Returns a FdHolder that will close the duped file descriptor when going out
82
+ // of scope. If the model is passed in as a file descriptor, update the
83
+ // model_path with a duped file descriptor. The original file descriptor may be
84
+ // opened with FD_CLOEXEC, and cannot be read from the child process.
85
+ std::unique_ptr<FdHolder> UpdateModelPathIfUsingFd (std::string& model_path) {
86
+ if (!absl::StartsWith (model_path, " fd:" )) {
87
+ return nullptr ;
88
+ }
89
+ std::vector<std::string> parts = absl::StrSplit (model_path, ' :' );
90
+ int model_fd;
91
+ if (!absl::SimpleAtoi (parts[1 ], &model_fd)) {
92
+ TFLITE_LOG_PROD (TFLITE_LOG_ERROR,
93
+ " Failed to parse file descriptor %s from model_path %s" ,
94
+ parts[1 ].c_str (), model_path.c_str ());
95
+ return nullptr ;
96
+ }
97
+ int new_fd = dup (model_fd);
98
+ if (new_fd < 0 ) {
99
+ TFLITE_LOG_PROD (
100
+ TFLITE_LOG_ERROR,
101
+ " Failed to dup() file descriptor. Original fd: %d errno: %d" , model_fd,
102
+ errno);
103
+ return nullptr ;
104
+ }
105
+
106
+ parts[1 ] = std::to_string (new_fd);
107
+ model_path = absl::StrJoin (parts, " :" );
108
+ return std::make_unique<FdHolder>(new_fd);
109
+ }
110
+
57
111
} // namespace
58
112
59
113
MinibenchmarkStatus ValidatorRunnerImpl::Init () {
@@ -143,7 +197,7 @@ void ValidatorRunnerImpl::TriggerValidationAsync(
143
197
// error_reporter is not passed in because the ownership cannot be passed to
144
198
// the thread.
145
199
std::thread detached_thread (
146
- [model_path = fd_or_model_path_, storage_path = storage_path_,
200
+ [original_model_path = fd_or_model_path_, storage_path = storage_path_,
147
201
data_directory_path = data_directory_path_,
148
202
tflite_settings = std::move (tflite_settings),
149
203
validation_entrypoint_name =
@@ -157,6 +211,10 @@ void ValidatorRunnerImpl::TriggerValidationAsync(
157
211
if (!lock.TryLock ()) {
158
212
return ;
159
213
}
214
+
215
+ std::string model_path = original_model_path;
216
+ std::unique_ptr<FdHolder> fd_holder =
217
+ UpdateModelPathIfUsingFd (model_path);
160
218
for (auto & one_setting : *tflite_settings) {
161
219
FlatbufferStorage<BenchmarkEvent> storage (storage_path);
162
220
TFLiteSettingsT tflite_settings_obj;
0 commit comments