Commit cd96123

Merge branch 'main' of github.com:pytorch/torchcodec into mac_wheels_ci

scotts committed Oct 23, 2024
2 parents 571ed17 + c8de21c

Showing 8 changed files with 93 additions and 25 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/macos_wheel.yaml

```diff
@@ -55,7 +55,7 @@ jobs:
       package-name: torchcodec
       trigger-event: ${{ github.event_name }}
       build-platform: "python-build-package"
-      build-command: "BUILD_AGAINST_ALL_FFMPEG_FROM_S3=1 ${CONDA_RUN} python3 -m build --wheel -vvv --no-isolation"
+      build-command: "BUILD_AGAINST_ALL_FFMPEG_FROM_S3=1 python -m build --wheel -vvv --no-isolation"

   install-and-test:
     runs-on: macos-m1-stable
```
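Note: the new command drops the `${CONDA_RUN}` wrapper and calls `python` instead of `python3`; presumably the `python-build-package` build platform already runs the command in the intended Python environment, making the conda indirection unnecessary.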
4 changes: 2 additions & 2 deletions benchmarks/decoders/benchmark_decoders.py

```diff
@@ -209,7 +209,7 @@ def get_frames_from_video(self, video_file, pts_list):
         best_video_stream = metadata["bestVideoStreamIndex"]
         indices_list = [int(pts * average_fps) for pts in pts_list]
         frames = []
-        frames = get_frames_at_indices(
+        frames, *_ = get_frames_at_indices(
             decoder, stream_index=best_video_stream, frame_indices=indices_list
         )
         return frames
@@ -226,7 +226,7 @@ def get_consecutive_frames_from_video(self, video_file, numFramesToDecode):
         best_video_stream = metadata["bestVideoStreamIndex"]
         frames = []
         indices_list = list(range(numFramesToDecode))
-        frames = get_frames_at_indices(
+        frames, *_ = get_frames_at_indices(
             decoder, stream_index=best_video_stream, frame_indices=indices_list
         )
         return frames
```
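These benchmark call sites change because `get_frames_at_indices` now returns a `(frames, pts_seconds, duration_seconds)` tuple instead of a bare frames tensor (see the `VideoDecoderOps` changes below), so callers that only want the frames use star-unpacking. A minimal sketch of the new calling convention, assuming the core ops import from `torchcodec.decoders._core` as in the test module below, with a placeholder file path and stream index:

```python
from torchcodec.decoders._core import (
    add_video_stream,
    create_from_file,
    get_frames_at_indices,
    scan_all_streams_to_update_metadata,
)

decoder = create_from_file("video.mp4")  # placeholder path
scan_all_streams_to_update_metadata(decoder)
add_video_stream(decoder)

# New return convention: (frames, pts_seconds, duration_seconds).
frames, pts_seconds, duration_seconds = get_frames_at_indices(
    decoder, stream_index=0, frame_indices=[0, 10, 20]
)

# Callers that only need the frames, like the benchmarks above:
frames, *_ = get_frames_at_indices(
    decoder, stream_index=0, frame_indices=[0, 10, 20]
)
```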
2 changes: 1 addition & 1 deletion src/torchcodec/_samplers/video_clip_sampler.py

```diff
@@ -240,7 +240,7 @@ def _get_clips_for_index_based_sampling(
                 clip_start_idx + i * index_based_sampler_args.video_frame_dilation
                 for i in range(index_based_sampler_args.frames_per_clip)
             ]
-            frames = get_frames_at_indices(
+            frames, *_ = get_frames_at_indices(
                 video_decoder,
                 stream_index=metadata_json["bestVideoStreamIndex"],
                 frame_indices=batch_indexes,
```
59 changes: 46 additions & 13 deletions src/torchcodec/decoders/_core/VideoDecoder.cpp

```diff
@@ -191,14 +191,14 @@ VideoDecoder::BatchDecodedOutput::BatchDecodedOutput(
     int64_t numFrames,
     const VideoStreamDecoderOptions& options,
     const StreamMetadata& metadata)
-    : ptsSeconds(torch::empty({numFrames}, {torch::kFloat64})),
-      durationSeconds(torch::empty({numFrames}, {torch::kFloat64})),
-      frames(torch::empty(
+    : frames(torch::empty(
           {numFrames,
            options.height.value_or(*metadata.height),
            options.width.value_or(*metadata.width),
            3},
-          {torch::kUInt8})) {}
+          {torch::kUInt8})),
+      ptsSeconds(torch::empty({numFrames}, {torch::kFloat64})),
+      durationSeconds(torch::empty({numFrames}, {torch::kFloat64})) {}

 VideoDecoder::VideoDecoder() {}
```
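The initializer-list reshuffle above likely brings the list in line with the members' declaration order in `BatchDecodedOutput`, with `frames` declared first: C++ initializes members in declaration order regardless of how the initializer list is written, so a mismatched list is misleading and draws `-Wreorder` warnings.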

```diff
@@ -1017,24 +1017,57 @@ VideoDecoder::BatchDecodedOutput VideoDecoder::getFramesAtIndices(
   validateUserProvidedStreamIndex(streamIndex);
   validateScannedAllStreams("getFramesAtIndices");

+  auto indicesAreSorted =
+      std::is_sorted(frameIndices.begin(), frameIndices.end());
+
+  std::vector<size_t> argsort;
+  if (!indicesAreSorted) {
+    // if frameIndices is [13, 10, 12, 11]
+    // when sorted, it's [10, 11, 12, 13] <-- this is the sorted order we want
+    // to use to decode the frames
+    // and argsort is [ 1, 3, 2, 0]
+    argsort.resize(frameIndices.size());
+    for (size_t i = 0; i < argsort.size(); ++i) {
+      argsort[i] = i;
+    }
+    std::sort(
+        argsort.begin(), argsort.end(), [&frameIndices](size_t a, size_t b) {
+          return frameIndices[a] < frameIndices[b];
+        });
+  }
+
   const auto& streamMetadata = containerMetadata_.streams[streamIndex];
   const auto& stream = streams_[streamIndex];
   const auto& options = stream.options;
   BatchDecodedOutput output(frameIndices.size(), options, streamMetadata);

+  auto previousIndexInVideo = -1;
   for (auto f = 0; f < frameIndices.size(); ++f) {
-    auto frameIndex = frameIndices[f];
-    if (frameIndex < 0 || frameIndex >= stream.allFrames.size()) {
+    auto indexInOutput = indicesAreSorted ? f : argsort[f];
+    auto indexInVideo = frameIndices[indexInOutput];
+    if (indexInVideo < 0 || indexInVideo >= stream.allFrames.size()) {
       throw std::runtime_error(
-          "Invalid frame index=" + std::to_string(frameIndex));
+          "Invalid frame index=" + std::to_string(indexInVideo));
     }
-    DecodedOutput singleOut =
-        getFrameAtIndex(streamIndex, frameIndex, output.frames[f]);
-    if (options.colorConversionLibrary == ColorConversionLibrary::FILTERGRAPH) {
-      output.frames[f] = singleOut.frame;
+    if ((f > 0) && (indexInVideo == previousIndexInVideo)) {
+      // Avoid decoding the same frame twice
+      auto previousIndexInOutput = indicesAreSorted ? f - 1 : argsort[f - 1];
+      output.frames[indexInOutput].copy_(output.frames[previousIndexInOutput]);
+      output.ptsSeconds[indexInOutput] =
+          output.ptsSeconds[previousIndexInOutput];
+      output.durationSeconds[indexInOutput] =
+          output.durationSeconds[previousIndexInOutput];
+    } else {
+      DecodedOutput singleOut = getFrameAtIndex(
+          streamIndex, indexInVideo, output.frames[indexInOutput]);
+      if (options.colorConversionLibrary ==
+          ColorConversionLibrary::FILTERGRAPH) {
+        output.frames[indexInOutput] = singleOut.frame;
+      }
+      output.ptsSeconds[indexInOutput] = singleOut.ptsSeconds;
+      output.durationSeconds[indexInOutput] = singleOut.durationSeconds;
     }
-    // Note that for now we ignore the pts and duration parts of the output,
-    // because they're never used in any caller.
+    previousIndexInVideo = indexInVideo;
   }
   output.frames = MaybePermuteHWC2CHW(options, output.frames);
   return output;
```
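The rewritten loop decodes the requested frames in ascending index order, which keeps the decoder seeking forward, and copies duplicate requests instead of decoding the same frame twice; `argsort` maps each decode step back to the caller's requested output slot. A Python sketch of the same idea, simplified to always compute the argsort, with a hypothetical `decode_one` standing in for `getFrameAtIndex`:

```python
import torch

def frames_at_indices(decode_one, frame_indices):
    # Decode in sorted order so the decoder only ever seeks forward;
    # argsort[f] is the output slot for the f-th smallest requested index.
    argsort = sorted(range(len(frame_indices)), key=lambda i: frame_indices[i])

    out = [None] * len(frame_indices)
    previous_index_in_video = None
    for f, index_in_output in enumerate(argsort):
        index_in_video = frame_indices[index_in_output]
        if index_in_video < 0:
            raise RuntimeError(f"Invalid frame index={index_in_video}")
        if f > 0 and index_in_video == previous_index_in_video:
            # Duplicate request: copy the frame decoded on the previous
            # iteration instead of decoding it again. The copy matters,
            # since callers may mutate one output frame independently.
            out[index_in_output] = out[argsort[f - 1]].clone()
        else:
            out[index_in_output] = decode_one(index_in_video)
        previous_index_in_video = index_in_video
    return torch.stack(out)

# E.g. frame_indices=[2, 0, 1, 0, 2] decodes frames 0, 1, 2 once each and
# copies the two duplicates, matching the new test below.
frames = frames_at_indices(lambda i: torch.full((2, 2), i), [2, 0, 1, 0, 2])
assert torch.equal(frames[0], frames[-1])
```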
6 changes: 3 additions & 3 deletions src/torchcodec/decoders/_core/VideoDecoderOps.cpp

```diff
@@ -40,7 +40,7 @@ TORCH_LIBRARY(torchcodec_ns, m) {
   m.def(
       "get_frame_at_index(Tensor(a!) decoder, *, int stream_index, int frame_index) -> (Tensor, Tensor, Tensor)");
   m.def(
-      "get_frames_at_indices(Tensor(a!) decoder, *, int stream_index, int[] frame_indices) -> Tensor");
+      "get_frames_at_indices(Tensor(a!) decoder, *, int stream_index, int[] frame_indices) -> (Tensor, Tensor, Tensor)");
   m.def(
       "get_frames_in_range(Tensor(a!) decoder, *, int stream_index, int start, int stop, int? step=None) -> (Tensor, Tensor, Tensor)");
   m.def(
@@ -218,15 +218,15 @@ OpsDecodedOutput get_frame_at_index(
   return makeOpsDecodedOutput(result);
 }

-at::Tensor get_frames_at_indices(
+OpsBatchDecodedOutput get_frames_at_indices(
     at::Tensor& decoder,
     int64_t stream_index,
     at::IntArrayRef frame_indices) {
   auto videoDecoder = unwrapTensorToGetDecoder(decoder);
   std::vector<int64_t> frameIndicesVec(
       frame_indices.begin(), frame_indices.end());
   auto result = videoDecoder->getFramesAtIndices(stream_index, frameIndicesVec);
-  return result.frames;
+  return makeOpsBatchDecodedOutput(result);
 }

 OpsBatchDecodedOutput get_frames_in_range(
```
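`OpsBatchDecodedOutput` is the three-tensor (frames, pts, durations) tuple that `get_frames_in_range` already returns, so routing the result through `makeOpsBatchDecodedOutput` keeps the two batch ops consistent at the op-registration boundary.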
2 changes: 1 addition & 1 deletion src/torchcodec/decoders/_core/VideoDecoderOps.h

```diff
@@ -87,7 +87,7 @@ OpsDecodedOutput get_next_frame(at::Tensor& decoder);

 // Return the frames at a given index for a given stream as a single stacked
 // Tensor.
-at::Tensor get_frames_at_indices(
+OpsBatchDecodedOutput get_frames_at_indices(
     at::Tensor& decoder,
     int64_t stream_index,
     at::IntArrayRef frame_indices);
```
8 changes: 6 additions & 2 deletions src/torchcodec/decoders/_core/video_decoder_ops.py

```diff
@@ -190,9 +190,13 @@ def get_frames_at_indices_abstract(
     *,
     stream_index: int,
     frame_indices: List[int],
-) -> torch.Tensor:
+) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
     image_size = [get_ctx().new_dynamic_size() for _ in range(4)]
-    return torch.empty(image_size)
+    return (
+        torch.empty(image_size),
+        torch.empty([], dtype=torch.float),
+        torch.empty([], dtype=torch.float),
+    )


 @register_fake("torchcodec_ns::get_frames_in_range")
```
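The `register_fake` ("abstract") implementation is what FakeTensor tracing and `torch.compile` see in place of the real C++ kernel, so its return structure must mirror the updated schema: a three-tuple of tensors rather than a single tensor.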
35 changes: 33 additions & 2 deletions test/decoders/test_video_decoder_ops.py

```diff
@@ -116,14 +116,45 @@ def test_get_frames_at_indices(self):
         decoder = create_from_file(str(NASA_VIDEO.path))
         scan_all_streams_to_update_metadata(decoder)
         add_video_stream(decoder)
-        frames0and180 = get_frames_at_indices(
+        frames0and180, *_ = get_frames_at_indices(
             decoder, stream_index=3, frame_indices=[0, 180]
         )
         reference_frame0 = NASA_VIDEO.get_frame_data_by_index(0)
         reference_frame180 = NASA_VIDEO.get_frame_by_name("time6.000000")
         assert_tensor_equal(frames0and180[0], reference_frame0)
         assert_tensor_equal(frames0and180[1], reference_frame180)

+    def test_get_frames_at_indices_unsorted_indices(self):
+        decoder = create_from_file(str(NASA_VIDEO.path))
+        _add_video_stream(decoder)
+        scan_all_streams_to_update_metadata(decoder)
+        stream_index = 3
+
+        frame_indices = [2, 0, 1, 0, 2]
+
+        expected_frames = [
+            get_frame_at_index(
+                decoder, stream_index=stream_index, frame_index=frame_index
+            )[0]
+            for frame_index in frame_indices
+        ]
+
+        frames, *_ = get_frames_at_indices(
+            decoder,
+            stream_index=stream_index,
+            frame_indices=frame_indices,
+        )
+        for frame, expected_frame in zip(frames, expected_frames):
+            assert_tensor_equal(frame, expected_frame)
+
+        # first and last frame should be equal, at index 2. We then modify the
+        # first frame and assert that it's now different from the last frame.
+        # This ensures a copy was properly made during the de-duplication logic.
+        assert_tensor_equal(frames[0], frames[-1])
+        frames[0] += 20
+        with pytest.raises(AssertionError):
+            assert_tensor_equal(frames[0], frames[-1])
+
     def test_get_frames_in_range(self):
         decoder = create_from_file(str(NASA_VIDEO.path))
         scan_all_streams_to_update_metadata(decoder)
@@ -425,7 +456,7 @@ def test_color_conversion_library_with_dimension_order(
         assert frames.shape[1:] == expected_shape
         assert_tensor_equal(frames[0], frame0_ref)

-        frames = get_frames_at_indices(
+        frames, *_ = get_frames_at_indices(
             decoder, stream_index=stream_index, frame_indices=[0, 1, 3, 4]
         )
         assert frames.shape[1:] == expected_shape
```
