@@ -602,16 +602,22 @@ FrameBatchOutput SingleStreamDecoder::getFramesInRange(
602
602
const auto & streamMetadata =
603
603
containerMetadata_.allStreamMetadata [activeStreamIndex_];
604
604
const auto & streamInfo = streamInfos_[activeStreamIndex_];
605
- int64_t numFrames = getNumFrames (streamMetadata);
606
605
TORCH_CHECK (
607
606
start >= 0 , " Range start, " + std::to_string (start) + " is less than 0." );
608
- TORCH_CHECK (
609
- stop <= numFrames,
610
- " Range stop, " + std::to_string (stop) +
611
- " , is more than the number of frames, " + std::to_string (numFrames));
612
607
TORCH_CHECK (
613
608
step > 0 , " Step must be greater than 0; is " + std::to_string (step));
614
609
610
+ // Note that if we do not have the number of frames available in our metadata,
611
+ // then we assume that the upper part of the range is valid.
612
+ std::optional<int64_t > numFrames = getNumFrames (streamMetadata);
613
+ if (numFrames.has_value ()) {
614
+ TORCH_CHECK (
615
+ stop <= numFrames.value (),
616
+ " Range stop, " + std::to_string (stop) +
617
+ " , is more than the number of frames, " +
618
+ std::to_string (numFrames.value ()));
619
+ }
620
+
615
621
int64_t numOutputFrames = std::ceil ((stop - start) / double (step));
616
622
const auto & videoStreamOptions = streamInfo.videoStreamOptions ;
617
623
FrameBatchOutput frameBatchOutput (
@@ -676,7 +682,7 @@ FrameBatchOutput SingleStreamDecoder::getFramesPlayedAt(
676
682
containerMetadata_.allStreamMetadata [activeStreamIndex_];
677
683
678
684
double minSeconds = getMinSeconds (streamMetadata);
679
- double maxSeconds = getMaxSeconds (streamMetadata);
685
+ std::optional< double > maxSeconds = getMaxSeconds (streamMetadata);
680
686
681
687
// The frame played at timestamp t and the one played at timestamp `t +
682
688
// eps` are probably the same frame, with the same index. The easiest way to
@@ -687,10 +693,20 @@ FrameBatchOutput SingleStreamDecoder::getFramesPlayedAt(
687
693
for (size_t i = 0 ; i < timestamps.size (); ++i) {
688
694
auto frameSeconds = timestamps[i];
689
695
TORCH_CHECK (
690
- frameSeconds >= minSeconds && frameSeconds < maxSeconds ,
696
+ frameSeconds >= minSeconds,
691
697
" frame pts is " + std::to_string (frameSeconds) +
692
- " ; must be in range [" + std::to_string (minSeconds) + " , " +
693
- std::to_string (maxSeconds) + " )." );
698
+ " ; must be greater than or equal to " + std::to_string (minSeconds) +
699
+ " ." );
700
+
701
+ // Note that if we can't determine the maximum number of seconds from the
702
+ // metadata, then we assume the frame's pts is valid.
703
+ if (maxSeconds.has_value ()) {
704
+ TORCH_CHECK (
705
+ frameSeconds < maxSeconds.value (),
706
+ " frame pts is " + std::to_string (frameSeconds) +
707
+ " ; must be less than " + std::to_string (maxSeconds.value ()) +
708
+ " ." );
709
+ }
694
710
695
711
frameIndices[i] = secondsToIndexLowerBound (frameSeconds);
696
712
}
@@ -737,17 +753,26 @@ FrameBatchOutput SingleStreamDecoder::getFramesPlayedInRange(
737
753
}
738
754
739
755
double minSeconds = getMinSeconds (streamMetadata);
740
- double maxSeconds = getMaxSeconds (streamMetadata);
741
756
TORCH_CHECK (
742
- startSeconds >= minSeconds && startSeconds < maxSeconds ,
757
+ startSeconds >= minSeconds,
743
758
" Start seconds is " + std::to_string (startSeconds) +
744
- " ; must be in range [" + std::to_string (minSeconds) + " , " +
745
- std::to_string (maxSeconds) + " )." );
746
- TORCH_CHECK (
747
- stopSeconds <= maxSeconds,
748
- " Stop seconds (" + std::to_string (stopSeconds) +
749
- " ; must be less than or equal to " + std::to_string (maxSeconds) +
750
- " )." );
759
+ " ; must be greater than or equal to " + std::to_string (minSeconds) +
760
+ " ." );
761
+
762
+ // Note that if we can't determine the maximum seconds from the metadata, then
763
+ // we assume the upper bound of the range is valid.
764
+ std::optional<double > maxSeconds = getMaxSeconds (streamMetadata);
765
+ if (maxSeconds.has_value ()) {
766
+ TORCH_CHECK (
767
+ startSeconds < maxSeconds.value (),
768
+ " Start seconds is " + std::to_string (startSeconds) +
769
+ " ; must be less than " + std::to_string (maxSeconds.value ()) + " ." );
770
+ TORCH_CHECK (
771
+ stopSeconds <= maxSeconds.value (),
772
+ " Stop seconds (" + std::to_string (stopSeconds) +
773
+ " ; must be less than or equal to " +
774
+ std::to_string (maxSeconds.value ()) + " )." );
775
+ }
751
776
752
777
// Note that we look at nextPts for a frame, and not its pts or duration.
753
778
// Our abstract player displays frames starting at the pts for that frame
@@ -1456,16 +1481,13 @@ int64_t SingleStreamDecoder::getPts(int64_t frameIndex) {
1456
1481
// STREAM AND METADATA APIS
1457
1482
// --------------------------------------------------------------------------
1458
1483
1459
- int64_t SingleStreamDecoder::getNumFrames (
1484
+ std::optional< int64_t > SingleStreamDecoder::getNumFrames (
1460
1485
const StreamMetadata& streamMetadata) {
1461
1486
switch (seekMode_) {
1462
1487
case SeekMode::exact:
1463
1488
return streamMetadata.numFramesFromScan .value ();
1464
1489
case SeekMode::approximate: {
1465
- TORCH_CHECK (
1466
- streamMetadata.numFrames .has_value (),
1467
- " Cannot use approximate mode since we couldn't find the number of frames from the metadata." );
1468
- return streamMetadata.numFrames .value ();
1490
+ return streamMetadata.numFrames ;
1469
1491
}
1470
1492
default :
1471
1493
throw std::runtime_error (" Unknown SeekMode" );
@@ -1484,16 +1506,13 @@ double SingleStreamDecoder::getMinSeconds(
1484
1506
}
1485
1507
}
1486
1508
1487
- double SingleStreamDecoder::getMaxSeconds (
1509
+ std::optional< double > SingleStreamDecoder::getMaxSeconds (
1488
1510
const StreamMetadata& streamMetadata) {
1489
1511
switch (seekMode_) {
1490
1512
case SeekMode::exact:
1491
1513
return streamMetadata.maxPtsSecondsFromScan .value ();
1492
1514
case SeekMode::approximate: {
1493
- TORCH_CHECK (
1494
- streamMetadata.durationSeconds .has_value (),
1495
- " Cannot use approximate mode since we couldn't find the duration from the metadata." );
1496
- return streamMetadata.durationSeconds .value ();
1515
+ return streamMetadata.durationSeconds ;
1497
1516
}
1498
1517
default :
1499
1518
throw std::runtime_error (" Unknown SeekMode" );
@@ -1539,12 +1558,22 @@ void SingleStreamDecoder::validateScannedAllStreams(const std::string& msg) {
1539
1558
void SingleStreamDecoder::validateFrameIndex (
1540
1559
const StreamMetadata& streamMetadata,
1541
1560
int64_t frameIndex) {
1542
- int64_t numFrames = getNumFrames (streamMetadata);
1543
1561
TORCH_CHECK (
1544
- frameIndex >= 0 && frameIndex < numFrames ,
1562
+ frameIndex >= 0 ,
1545
1563
" Invalid frame index=" + std::to_string (frameIndex) +
1546
1564
" for streamIndex=" + std::to_string (streamMetadata.streamIndex ) +
1547
- " numFrames=" + std::to_string (numFrames));
1565
+ " ; must be greater than or equal to 0" );
1566
+
1567
+ // Note that if we do not have the number of frames available in our metadata,
1568
+ // then we assume that the frameIndex is valid.
1569
+ std::optional<int64_t > numFrames = getNumFrames (streamMetadata);
1570
+ if (numFrames.has_value ()) {
1571
+ TORCH_CHECK (
1572
+ frameIndex < numFrames.value (),
1573
+ " Invalid frame index=" + std::to_string (frameIndex) +
1574
+ " for streamIndex=" + std::to_string (streamMetadata.streamIndex ) +
1575
+ " ; must be less than " + std::to_string (numFrames.value ()));
1576
+ }
1548
1577
}
1549
1578
1550
1579
// --------------------------------------------------------------------------
0 commit comments