Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix build for TensorRT 8.x but it works only with TensorRT 10.x #464

Merged
merged 1 commit into from
Feb 14, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions src/Detector/tensorrt_yolo/common/common.h
Original file line number Diff line number Diff line change
Expand Up @@ -541,17 +541,23 @@ inline uint32_t getElementSize(nvinfer1::DataType t) noexcept
{
switch (t)
{
#if (NV_TENSORRT_MAJOR > 8)
case nvinfer1::DataType::kINT64: return 8;
#endif
case nvinfer1::DataType::kINT32:
case nvinfer1::DataType::kFLOAT: return 4;
#if (NV_TENSORRT_MAJOR > 8)
case nvinfer1::DataType::kBF16:
#endif
case nvinfer1::DataType::kHALF: return 2;
case nvinfer1::DataType::kBOOL:
case nvinfer1::DataType::kUINT8:
case nvinfer1::DataType::kINT8:
case nvinfer1::DataType::kFP8: return 1;
#if (NV_TENSORRT_MAJOR > 8)
case nvinfer1::DataType::kINT4:
ASSERT(false && "Element size is not implemented for sub-byte data-types");
#endif
}
return 0;
}
Expand Down
6 changes: 6 additions & 0 deletions src/Detector/tensorrt_yolo/common/safeCommon.h
Original file line number Diff line number Diff line change
Expand Up @@ -153,17 +153,23 @@ inline uint32_t elementSize(nvinfer1::DataType t)
{
switch (t)
{
#if (NV_TENSORRT_MAJOR > 8)
case nvinfer1::DataType::kINT64: return 8;
#endif
case nvinfer1::DataType::kINT32:
case nvinfer1::DataType::kFLOAT: return 4;
case nvinfer1::DataType::kHALF:
#if (NV_TENSORRT_MAJOR > 8)
case nvinfer1::DataType::kBF16: return 2;
#endif
case nvinfer1::DataType::kINT8:
case nvinfer1::DataType::kUINT8:
case nvinfer1::DataType::kBOOL:
case nvinfer1::DataType::kFP8: return 1;
#if (NV_TENSORRT_MAJOR > 8)
case nvinfer1::DataType::kINT4:
SAFE_ASSERT(false && "Element size is not implemented for sub-byte data-types");
#endif
}
return 0;
}
Expand Down
8 changes: 8 additions & 0 deletions src/Detector/tensorrt_yolo/common/sampleDevice.h
Original file line number Diff line number Diff line change
Expand Up @@ -503,11 +503,19 @@ class OutputAllocator : public nvinfer1::IOutputAllocator
}

//! IMirroredBuffer does not implement Async allocation, hence this is just a wrap around
#if (NV_TENSORRT_MAJOR > 8)
void* reallocateOutputAsync(char const* tensorName, void* currentMemory, uint64_t size, uint64_t alignment,
cudaStream_t /*stream*/) noexcept override
{
return reallocateOutput(tensorName, currentMemory, size, alignment);
}
#else
void* reallocateOutputAsync(char const* tensorName, void* currentMemory, uint64_t size, uint64_t alignment,
cudaStream_t /*stream*/) noexcept
{
return reallocateOutput(tensorName, currentMemory, size, alignment);
}
#endif

void notifyShape(char const* tensorName, nvinfer1::Dims const& dims) noexcept override
{
Expand Down
81 changes: 64 additions & 17 deletions src/Detector/tensorrt_yolo/common/sampleEngines.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -80,9 +80,10 @@ nvinfer1::ICudaEngine* LazilyDeserializedEngine::get()

if (mEngine == nullptr)
{
#if (NV_TENSORRT_MAJOR > 8)
SMP_RETVAL_IF_FALSE(getFileReader().isOpen() || !getBlob().empty(), "Engine is empty. Nothing to deserialize!",
nullptr, sample::gLogError);

#endif
using time_point = std::chrono::time_point<std::chrono::high_resolution_clock>;
using duration = std::chrono::duration<float>;
time_point const deserializeStartTime{std::chrono::high_resolution_clock::now()};
Expand Down Expand Up @@ -126,6 +127,7 @@ nvinfer1::ICudaEngine* LazilyDeserializedEngine::get()
}
#endif

#if (NV_TENSORRT_MAJOR > 8)
if (getFileReader().isOpen())
{
mEngine.reset(mRuntime->deserializeCudaEngine(getFileReader()));
Expand All @@ -135,6 +137,11 @@ nvinfer1::ICudaEngine* LazilyDeserializedEngine::get()
auto const& engineBlob = getBlob();
mEngine.reset(mRuntime->deserializeCudaEngine(engineBlob.data, engineBlob.size));
}
#else
auto const& engineBlob = getBlob();
mEngine.reset(mRuntime->deserializeCudaEngine(engineBlob.data, engineBlob.size));
std::cerr << "getFileReader is not implemented! Use TensorRT 10.x and higher" << std::endl;
#endif
SMP_RETVAL_IF_FALSE(mEngine != nullptr, "Engine deserialization failed", nullptr, sample::gLogError);

time_point const deserializeEndTime{std::chrono::high_resolution_clock::now()};
Expand Down Expand Up @@ -405,8 +412,12 @@ bool setTensorDynamicRange(INetworkDefinition const& network, float inRange = 2.

bool isNonActivationType(nvinfer1::DataType const type)
{
return type == nvinfer1::DataType::kINT32 || type == nvinfer1::DataType::kINT64 || type == nvinfer1::DataType::kBOOL
|| type == nvinfer1::DataType::kUINT8;
return type == nvinfer1::DataType::kINT32
#if (NV_TENSORRT_MAJOR > 8)
|| type == nvinfer1::DataType::kINT64
#endif
|| type == nvinfer1::DataType::kBOOL
|| type == nvinfer1::DataType::kUINT8;
}

void setLayerPrecisions(INetworkDefinition& network, LayerPrecisions const& layerPrecisions)
Expand Down Expand Up @@ -567,6 +578,7 @@ void setLayerDeviceTypes(

void markDebugTensors(INetworkDefinition& network, StringSet const& debugTensors)
{
#if (NV_TENSORRT_MAJOR > 8)
for (int64_t inputIndex = 0; inputIndex < network.getNbInputs(); ++inputIndex)
{
auto* t = network.getInput(inputIndex);
Expand All @@ -589,6 +601,9 @@ void markDebugTensors(INetworkDefinition& network, StringSet const& debugTensors
}
}
}
#else
std::cerr << "Can not markDebugTensors. Use TensorRT 10.x or higher" << std::endl;
#endif
}

void setMemoryPoolLimits(IBuilderConfig& config, BuildOptions const& build)
Expand Down Expand Up @@ -626,10 +641,12 @@ void setMemoryPoolLimits(IBuilderConfig& config, BuildOptions const& build)
{
config.setMemoryPoolLimit(MemoryPoolType::kDLA_GLOBAL_DRAM, roundToBytes(build.dlaGlobalDRAM));
}
#if (NV_TENSORRT_MAJOR > 8)
if (build.tacticSharedMem >= 0)
{
config.setMemoryPoolLimit(MemoryPoolType::kTACTIC_SHARED_MEMORY, roundToBytes(build.tacticSharedMem, false));
}
#endif
}

void setPreviewFeatures(IBuilderConfig& config, BuildOptions const& build)
Expand All @@ -641,7 +658,9 @@ void setPreviewFeatures(IBuilderConfig& config, BuildOptions const& build)
config.setPreviewFeature(feat, build.previewFeatures.at(featVal));
}
};
#if (NV_TENSORRT_MAJOR > 8)
setFlag(PreviewFeature::kALIASED_PLUGIN_IO_10_03);
#endif
}

} // namespace
Expand Down Expand Up @@ -845,7 +864,7 @@ bool setupNetworkAndConfig(BuildOptions const& build, SystemOptions const& sys,

if (build.maxTactics != defaultMaxTactics)
{
#if (NV_TENSORRT_MAJOR < 9)
#if (NV_TENSORRT_MAJOR < 8)
config.setMaxNbTactics(build.maxTactics);
#else
config.setTacticSources(build.maxTactics);
Expand All @@ -856,7 +875,7 @@ bool setupNetworkAndConfig(BuildOptions const& build, SystemOptions const& sys,
{
config.setFlag(BuilderFlag::kDISABLE_TIMING_CACHE);
}

#if (NV_TENSORRT_MAJOR > 8)
if (build.disableCompilationCache)
{
config.setFlag(BuilderFlag::kDISABLE_COMPILATION_CACHE);
Expand All @@ -866,7 +885,7 @@ bool setupNetworkAndConfig(BuildOptions const& build, SystemOptions const& sys,
{
config.setFlag(BuilderFlag::kERROR_ON_TIMING_CACHE_MISS);
}

#endif
if (!build.tf32)
{
config.clearFlag(BuilderFlag::kTF32);
Expand All @@ -876,13 +895,13 @@ bool setupNetworkAndConfig(BuildOptions const& build, SystemOptions const& sys,
{
config.setFlag(BuilderFlag::kREFIT);
}

#if (NV_TENSORRT_MAJOR > 8)
if (build.stripWeights)
{
// The kREFIT_IDENTICAL is enabled by default when kSTRIP_PLAN is on.
config.setFlag(BuilderFlag::kSTRIP_PLAN);
}

#endif
if (build.versionCompatible)
{
config.setFlag(BuilderFlag::kVERSION_COMPATIBLE);
Expand Down Expand Up @@ -924,23 +943,25 @@ bool setupNetworkAndConfig(BuildOptions const& build, SystemOptions const& sys,
{
config.setFlag(BuilderFlag::kINT8);
}
#if (NV_TENSORRT_MAJOR > 8)
if (build.bf16)
{
config.setFlag(BuilderFlag::kBF16);
}
#endif

SMP_RETVAL_IF_FALSE(!(build.int8 && build.fp8), "FP8 and INT8 precisions have been specified", false, err);

if (build.fp8)
{
config.setFlag(BuilderFlag::kFP8);
}

#if (NV_TENSORRT_MAJOR > 8)
if (build.int4)
{
config.setFlag(BuilderFlag::kINT4);
}

#endif
if (build.int8 && !build.fp16)
{
sample::gLogInfo
Expand Down Expand Up @@ -1136,7 +1157,9 @@ bool setupNetworkAndConfig(BuildOptions const& build, SystemOptions const& sys,
}

config.setHardwareCompatibilityLevel(build.hardwareCompatibilityLevel);
#if (NV_TENSORRT_MAJOR > 8)
config.setRuntimePlatform(build.runtimePlatform);
#endif

if (build.maxAuxStreams != defaultMaxAuxStreams)
{
Expand All @@ -1145,7 +1168,11 @@ bool setupNetworkAndConfig(BuildOptions const& build, SystemOptions const& sys,

if (build.allowWeightStreaming)
{
#if (NV_TENSORRT_MAJOR > 8)
config.setFlag(BuilderFlag::kWEIGHT_STREAMING);
#else
std::cerr << "BuilderFlag::kWEIGHT_STREAMING not allowed in TensorRT with version less than 10.x" << std::endl;
#endif
}

return true;
Expand Down Expand Up @@ -1208,9 +1235,13 @@ bool modelToBuildEnv(
env.builder.reset(createBuilder());
SMP_RETVAL_IF_FALSE(env.builder != nullptr, "Builder creation failed", false, err);
env.builder->setErrorRecorder(&gRecorder);
#if (NV_TENSORRT_MAJOR > 8)
auto networkFlags = (build.stronglyTyped)
? 1U << static_cast<uint32_t>(nvinfer1::NetworkDefinitionCreationFlag::kSTRONGLY_TYPED)
: 0U;
#else
auto networkFlags = 0U;
#endif
#if !TRT_WINML
for (auto const& pluginPath : sys.dynamicPlugins)
{
Expand Down Expand Up @@ -1304,8 +1335,12 @@ std::pair<std::vector<std::string>, std::vector<WeightsRole>> getMissingLayerWei

bool loadStreamingEngineToBuildEnv(std::string const& filepath, BuildEnvironment& env, std::ostream& err)
{
#if (NV_TENSORRT_MAJOR > 8)
auto& reader = env.engine.getFileReader();
SMP_RETVAL_IF_FALSE(reader.open(filepath), "", false, err << "Error opening engine file: " << filepath);
#else
SMP_RETVAL_IF_FALSE(false, "", false, err << "Error opening engine file: " << filepath);
#endif
return true;
}

Expand Down Expand Up @@ -1337,12 +1372,14 @@ bool printPlanVersion(BuildEnvironment& env, std::ostream& err)
std::vector<uint8_t> data(kPLAN_SIZE);
auto blob = data.data();

#if (NV_TENSORRT_MAJOR > 8)
auto& reader = env.engine.getFileReader();
if (reader.isOpen())
{
SMP_RETVAL_IF_FALSE(reader.read(data.data(), kPLAN_SIZE) == kPLAN_SIZE, "Failed to read plan file", false, err);
}
else
#endif
{
SMP_RETVAL_IF_FALSE(env.engine.getBlob().data != nullptr, "Plan file is empty", false, err);
SMP_RETVAL_IF_FALSE(env.engine.getBlob().size >= 28, "Plan file is incorrect", false, err);
Expand Down Expand Up @@ -1473,14 +1510,21 @@ std::vector<std::pair<WeightsRole, Weights>> getAllRefitWeightsForLayer(const IL
{
case DataType::kFLOAT:
case DataType::kHALF:
#if (NV_TENSORRT_MAJOR > 8)
case DataType::kBF16:
#endif
case DataType::kINT8:
case DataType::kINT32:
case DataType::kINT64: return {std::make_pair(WeightsRole::kCONSTANT, weights)};
#if (NV_TENSORRT_MAJOR > 8)
case DataType::kINT64:
#endif
return {std::make_pair(WeightsRole::kCONSTANT, weights)};
case DataType::kBOOL:
case DataType::kUINT8:
case DataType::kFP8:
#if (NV_TENSORRT_MAJOR > 8)
case DataType::kINT4:
#endif
// Refit not supported for these types.
break;
}
Expand Down Expand Up @@ -1530,7 +1574,9 @@ std::vector<std::pair<WeightsRole, Weights>> getAllRefitWeightsForLayer(const IL
case LayerType::kPARAMETRIC_RELU:
case LayerType::kPLUGIN:
case LayerType::kPLUGIN_V2:
#if (NV_TENSORRT_MAJOR > 8)
case LayerType::kPLUGIN_V3:
#endif
case LayerType::kPOOLING:
case LayerType::kQUANTIZE:
case LayerType::kRAGGED_SOFTMAX:
Expand Down Expand Up @@ -1610,11 +1656,10 @@ bool timeRefit(INetworkDefinition const& network, nvinfer1::ICudaEngine& engine,
}
return layerNames.empty();
};

// Skip weights validation since we are confident that the new weights are similar to the weights used to build
// engine.
#if (NV_TENSORRT_MAJOR > 8)
// Skip weights validation since we are confident that the new weights are similar to the weights used to build engine.
refitter->setWeightsValidation(false);

#endif
// Warm up and report missing weights
// We only need to set weights for the first time and that can be reused in later refitting process.
bool const success = setWeights() && reportMissingWeights() && refitter->refitCudaEngine();
Expand All @@ -1623,9 +1668,10 @@ bool timeRefit(INetworkDefinition const& network, nvinfer1::ICudaEngine& engine,
return false;
}

TrtCudaStream stream;
constexpr int32_t kLOOP = 10;
time_point const refitStartTime{std::chrono::steady_clock::now()};
constexpr int32_t kLOOP = 10;
#if (NV_TENSORRT_MAJOR > 8)
TrtCudaStream stream;
{
for (int32_t l = 0; l < kLOOP; l++)
{
Expand All @@ -1636,6 +1682,7 @@ bool timeRefit(INetworkDefinition const& network, nvinfer1::ICudaEngine& engine,
}
}
stream.synchronize();
#endif
time_point const refitEndTime{std::chrono::steady_clock::now()};

sample::gLogInfo << "Engine refitted"
Expand Down
Loading
Loading