Skip to content

Commit

Permalink
Merge branch 'main' of https://github.com/pytorch/torchcodec into cuda8
Browse files Browse the repository at this point in the history
  • Loading branch information
ahmadsharif1 committed Oct 16, 2024
2 parents 6be7b76 + c6a0a5a commit 4bdc851
Show file tree
Hide file tree
Showing 3 changed files with 61 additions and 12 deletions.
67 changes: 58 additions & 9 deletions src/torchcodec/decoders/_core/CudaDevice.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -77,17 +77,40 @@ AVBufferRef* getFromCache(const torch::Device& device) {
return nullptr;
}

AVBufferRef* getCudaContext(const torch::Device& device) {
enum AVHWDeviceType type = av_hwdevice_find_type_by_name("cuda");
TORCH_CHECK(type != AV_HWDEVICE_TYPE_NONE, "Failed to find cuda device");
torch::DeviceIndex deviceIndex = getFFMPEGCompatibleDeviceIndex(device);

AVBufferRef* hw_device_ctx = getFromCache(device);
if (hw_device_ctx != nullptr) {
return hw_device_ctx;
AVBufferRef* getFFMPEGContextFromExistingCudaContext(
const torch::Device& device,
torch::DeviceIndex nonNegativeDeviceIndex,
enum AVHWDeviceType type) {
c10::cuda::CUDAGuard deviceGuard(device);
// Valid values for the argument to cudaSetDevice are 0 to maxDevices - 1:
// https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__DEVICE.html#group__CUDART__DEVICE_1g159587909ffa0791bbe4b40187a4c6bb
// So we ensure the deviceIndex is not negative.
// We set the device because we may be called from a different thread than
// the one that initialized the cuda context.
cudaSetDevice(nonNegativeDeviceIndex);
AVBufferRef* hw_device_ctx = nullptr;
std::string deviceOrdinal = std::to_string(nonNegativeDeviceIndex);
int err = av_hwdevice_ctx_create(
&hw_device_ctx,
type,
deviceOrdinal.c_str(),
nullptr,
AV_CUDA_USE_CURRENT_CONTEXT);
if (err < 0) {
TORCH_CHECK(
false,
"Failed to create specified HW device",
getFFMPEGErrorStringFromErrorCode(err));
}
return hw_device_ctx;
}

std::string deviceOrdinal = std::to_string(deviceIndex);
AVBufferRef* getFFMPEGContextFromNewCudaContext(
const torch::Device& device,
torch::DeviceIndex nonNegativeDeviceIndex,
enum AVHWDeviceType type) {
AVBufferRef* hw_device_ctx = nullptr;
std::string deviceOrdinal = std::to_string(nonNegativeDeviceIndex);
int err = av_hwdevice_ctx_create(
&hw_device_ctx, type, deviceOrdinal.c_str(), nullptr, 0);
if (err < 0) {
Expand All @@ -99,6 +122,32 @@ AVBufferRef* getCudaContext(const torch::Device& device) {
return hw_device_ctx;
}

AVBufferRef* getCudaContext(const torch::Device& device) {
enum AVHWDeviceType type = av_hwdevice_find_type_by_name("cuda");
TORCH_CHECK(type != AV_HWDEVICE_TYPE_NONE, "Failed to find cuda device");
torch::DeviceIndex nonNegativeDeviceIndex =
getFFMPEGCompatibleDeviceIndex(device);

AVBufferRef* hw_device_ctx = getFromCache(device);
if (hw_device_ctx != nullptr) {
return hw_device_ctx;
}

// 58.26.100 introduced the concept of reusing the existing cuda context
// which is much faster and lower memory than creating a new cuda context.
// So we try to use that if it is available.
// FFMPEG 6.1.2 appears to be the earliest release that contains version
// 58.26.100 of avutil.
// https://github.com/FFmpeg/FFmpeg/blob/4acb9b7d1046944345ae506165fb55883d04d8a6/doc/APIchanges#L265
#if LIBAVUTIL_VERSION_INT >= AV_VERSION_INT(58, 26, 100)
return getFFMPEGContextFromExistingCudaContext(
device, nonNegativeDeviceIndex, type);
#else
return getFFMPEGContextFromNewCudaContext(
device, nonNegativeDeviceIndex, type);
#endif
}

torch::Tensor allocateDeviceTensor(
at::IntArrayRef shape,
torch::Device device,
Expand Down
2 changes: 1 addition & 1 deletion src/torchcodec/decoders/_core/VideoDecoder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1168,7 +1168,7 @@ VideoDecoder::RawDecodedOutput VideoDecoder::getNextRawDecodedOutputNoDemux() {
getDecodedOutputWithFilter([this](int frameStreamIndex, AVFrame* frame) {
StreamInfo& activeStream = streams_[frameStreamIndex];
return frame->pts >=
activeStream.discardFramesBeforePts.value_or(INT64_MIN);
activeStream.discardFramesBeforePts;
});
return rawOutput;
}
Expand Down
4 changes: 2 additions & 2 deletions src/torchcodec/decoders/_core/VideoDecoder.h
Original file line number Diff line number Diff line change
Expand Up @@ -307,8 +307,8 @@ class VideoDecoder {
int64_t currentDuration = 0;
// The desired position of the cursor in the stream. We send frames >=
// this pts to the user when they request a frame.
// We set this field if the user requested a seek.
std::optional<int64_t> discardFramesBeforePts = 0;
// We update this field if the user requested a seek.
int64_t discardFramesBeforePts = INT64_MIN;
VideoStreamDecoderOptions options;
// The filter state associated with this stream (for video streams). The
// actual graph will be nullptr for inactive streams.
Expand Down

0 comments on commit 4bdc851

Please sign in to comment.