Skip to content

Commit 0f22b2b

Browse files
authored
Allow num_frames and duration to be absent in C++ decoder (#708)
1 parent ae50558 commit 0f22b2b

File tree

3 files changed

+64
-35
lines changed

3 files changed

+64
-35
lines changed

src/torchcodec/_core/SingleStreamDecoder.cpp

Lines changed: 60 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -602,16 +602,22 @@ FrameBatchOutput SingleStreamDecoder::getFramesInRange(
602602
const auto& streamMetadata =
603603
containerMetadata_.allStreamMetadata[activeStreamIndex_];
604604
const auto& streamInfo = streamInfos_[activeStreamIndex_];
605-
int64_t numFrames = getNumFrames(streamMetadata);
606605
TORCH_CHECK(
607606
start >= 0, "Range start, " + std::to_string(start) + " is less than 0.");
608-
TORCH_CHECK(
609-
stop <= numFrames,
610-
"Range stop, " + std::to_string(stop) +
611-
", is more than the number of frames, " + std::to_string(numFrames));
612607
TORCH_CHECK(
613608
step > 0, "Step must be greater than 0; is " + std::to_string(step));
614609

610+
// Note that if we do not have the number of frames available in our metadata,
611+
// then we assume that the upper part of the range is valid.
612+
std::optional<int64_t> numFrames = getNumFrames(streamMetadata);
613+
if (numFrames.has_value()) {
614+
TORCH_CHECK(
615+
stop <= numFrames.value(),
616+
"Range stop, " + std::to_string(stop) +
617+
", is more than the number of frames, " +
618+
std::to_string(numFrames.value()));
619+
}
620+
615621
int64_t numOutputFrames = std::ceil((stop - start) / double(step));
616622
const auto& videoStreamOptions = streamInfo.videoStreamOptions;
617623
FrameBatchOutput frameBatchOutput(
@@ -676,7 +682,7 @@ FrameBatchOutput SingleStreamDecoder::getFramesPlayedAt(
676682
containerMetadata_.allStreamMetadata[activeStreamIndex_];
677683

678684
double minSeconds = getMinSeconds(streamMetadata);
679-
double maxSeconds = getMaxSeconds(streamMetadata);
685+
std::optional<double> maxSeconds = getMaxSeconds(streamMetadata);
680686

681687
// The frame played at timestamp t and the one played at timestamp `t +
682688
// eps` are probably the same frame, with the same index. The easiest way to
@@ -687,10 +693,20 @@ FrameBatchOutput SingleStreamDecoder::getFramesPlayedAt(
687693
for (size_t i = 0; i < timestamps.size(); ++i) {
688694
auto frameSeconds = timestamps[i];
689695
TORCH_CHECK(
690-
frameSeconds >= minSeconds && frameSeconds < maxSeconds,
696+
frameSeconds >= minSeconds,
691697
"frame pts is " + std::to_string(frameSeconds) +
692-
"; must be in range [" + std::to_string(minSeconds) + ", " +
693-
std::to_string(maxSeconds) + ").");
698+
"; must be greater than or equal to " + std::to_string(minSeconds) +
699+
".");
700+
701+
// Note that if we can't determine the maximum number of seconds from the
702+
// metadata, then we assume the frame's pts is valid.
703+
if (maxSeconds.has_value()) {
704+
TORCH_CHECK(
705+
frameSeconds < maxSeconds.value(),
706+
"frame pts is " + std::to_string(frameSeconds) +
707+
"; must be less than " + std::to_string(maxSeconds.value()) +
708+
".");
709+
}
694710

695711
frameIndices[i] = secondsToIndexLowerBound(frameSeconds);
696712
}
@@ -737,17 +753,26 @@ FrameBatchOutput SingleStreamDecoder::getFramesPlayedInRange(
737753
}
738754

739755
double minSeconds = getMinSeconds(streamMetadata);
740-
double maxSeconds = getMaxSeconds(streamMetadata);
741756
TORCH_CHECK(
742-
startSeconds >= minSeconds && startSeconds < maxSeconds,
757+
startSeconds >= minSeconds,
743758
"Start seconds is " + std::to_string(startSeconds) +
744-
"; must be in range [" + std::to_string(minSeconds) + ", " +
745-
std::to_string(maxSeconds) + ").");
746-
TORCH_CHECK(
747-
stopSeconds <= maxSeconds,
748-
"Stop seconds (" + std::to_string(stopSeconds) +
749-
"; must be less than or equal to " + std::to_string(maxSeconds) +
750-
").");
759+
"; must be greater than or equal to " + std::to_string(minSeconds) +
760+
".");
761+
762+
// Note that if we can't determine the maximum seconds from the metadata, then
763+
// we assume upper range is valid.
764+
std::optional<double> maxSeconds = getMaxSeconds(streamMetadata);
765+
if (maxSeconds.has_value()) {
766+
TORCH_CHECK(
767+
startSeconds < maxSeconds.value(),
768+
"Start seconds is " + std::to_string(startSeconds) +
769+
"; must be less than " + std::to_string(maxSeconds.value()) + ".");
770+
TORCH_CHECK(
771+
stopSeconds <= maxSeconds.value(),
772+
"Stop seconds (" + std::to_string(stopSeconds) +
773+
"; must be less than or equal to " +
774+
std::to_string(maxSeconds.value()) + ").");
775+
}
751776

752777
// Note that we look at nextPts for a frame, and not its pts or duration.
753778
// Our abstract player displays frames starting at the pts for that frame
@@ -1456,16 +1481,13 @@ int64_t SingleStreamDecoder::getPts(int64_t frameIndex) {
14561481
// STREAM AND METADATA APIS
14571482
// --------------------------------------------------------------------------
14581483

1459-
int64_t SingleStreamDecoder::getNumFrames(
1484+
std::optional<int64_t> SingleStreamDecoder::getNumFrames(
14601485
const StreamMetadata& streamMetadata) {
14611486
switch (seekMode_) {
14621487
case SeekMode::exact:
14631488
return streamMetadata.numFramesFromScan.value();
14641489
case SeekMode::approximate: {
1465-
TORCH_CHECK(
1466-
streamMetadata.numFrames.has_value(),
1467-
"Cannot use approximate mode since we couldn't find the number of frames from the metadata.");
1468-
return streamMetadata.numFrames.value();
1490+
return streamMetadata.numFrames;
14691491
}
14701492
default:
14711493
throw std::runtime_error("Unknown SeekMode");
@@ -1484,16 +1506,13 @@ double SingleStreamDecoder::getMinSeconds(
14841506
}
14851507
}
14861508

1487-
double SingleStreamDecoder::getMaxSeconds(
1509+
std::optional<double> SingleStreamDecoder::getMaxSeconds(
14881510
const StreamMetadata& streamMetadata) {
14891511
switch (seekMode_) {
14901512
case SeekMode::exact:
14911513
return streamMetadata.maxPtsSecondsFromScan.value();
14921514
case SeekMode::approximate: {
1493-
TORCH_CHECK(
1494-
streamMetadata.durationSeconds.has_value(),
1495-
"Cannot use approximate mode since we couldn't find the duration from the metadata.");
1496-
return streamMetadata.durationSeconds.value();
1515+
return streamMetadata.durationSeconds;
14971516
}
14981517
default:
14991518
throw std::runtime_error("Unknown SeekMode");
@@ -1539,12 +1558,22 @@ void SingleStreamDecoder::validateScannedAllStreams(const std::string& msg) {
15391558
void SingleStreamDecoder::validateFrameIndex(
15401559
const StreamMetadata& streamMetadata,
15411560
int64_t frameIndex) {
1542-
int64_t numFrames = getNumFrames(streamMetadata);
15431561
TORCH_CHECK(
1544-
frameIndex >= 0 && frameIndex < numFrames,
1562+
frameIndex >= 0,
15451563
"Invalid frame index=" + std::to_string(frameIndex) +
15461564
" for streamIndex=" + std::to_string(streamMetadata.streamIndex) +
1547-
" numFrames=" + std::to_string(numFrames));
1565+
"; must be greater than or equal to 0");
1566+
1567+
// Note that if we do not have the number of frames available in our metadata,
1568+
// then we assume that the frameIndex is valid.
1569+
std::optional<int64_t> numFrames = getNumFrames(streamMetadata);
1570+
if (numFrames.has_value()) {
1571+
TORCH_CHECK(
1572+
frameIndex < numFrames.value(),
1573+
"Invalid frame index=" + std::to_string(frameIndex) +
1574+
" for streamIndex=" + std::to_string(streamMetadata.streamIndex) +
1575+
"; must be less than " + std::to_string(numFrames.value()));
1576+
}
15481577
}
15491578

15501579
// --------------------------------------------------------------------------

src/torchcodec/_core/SingleStreamDecoder.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -304,9 +304,9 @@ class SingleStreamDecoder {
304304
// index. Note that this index may be truncated for some files.
305305
int getBestStreamIndex(AVMediaType mediaType);
306306

307-
int64_t getNumFrames(const StreamMetadata& streamMetadata);
307+
std::optional<int64_t> getNumFrames(const StreamMetadata& streamMetadata);
308308
double getMinSeconds(const StreamMetadata& streamMetadata);
309-
double getMaxSeconds(const StreamMetadata& streamMetadata);
309+
std::optional<double> getMaxSeconds(const StreamMetadata& streamMetadata);
310310

311311
// --------------------------------------------------------------------------
312312
// VALIDATION UTILS

test/test_decoders.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -597,10 +597,10 @@ def test_get_frames_played_at(self, device, seek_mode):
597597
def test_get_frames_played_at_fails(self, device, seek_mode):
598598
decoder = VideoDecoder(NASA_VIDEO.path, device=device, seek_mode=seek_mode)
599599

600-
with pytest.raises(RuntimeError, match="must be in range"):
600+
with pytest.raises(RuntimeError, match="must be greater than or equal to"):
601601
decoder.get_frames_played_at([-1])
602602

603-
with pytest.raises(RuntimeError, match="must be in range"):
603+
with pytest.raises(RuntimeError, match="must be less than"):
604604
decoder.get_frames_played_at([14])
605605

606606
with pytest.raises(RuntimeError, match="Expected a value of type"):

0 commit comments

Comments
 (0)