diff --git a/track.go b/track.go index 5417c29a..8feddb18 100644 --- a/track.go +++ b/track.go @@ -23,6 +23,7 @@ import ( const ( rtpOutboundMTU = 1200 + rtcpInboundMTU = 1500 ) var ( @@ -223,38 +224,48 @@ func (track *baseTrack) bind(ctx webrtc.TrackLocalContext, specializedTrack Trac keyFrameController, ok := encodedReader.Controller().(codec.KeyFrameController) if ok { stopRead = make(chan struct{}) - go func() { - reader := ctx.RTCPReader() - for { - select { - case <-stopRead: - return - default: - } + go track.rtcpReadLoop(ctx.RTCPReader(), keyFrameController, stopRead) + } - var readerBuffer []byte - _, _, err := reader.Read(readerBuffer, interceptor.Attributes{}) - if err != nil { - track.onError(err) - return - } + return selectedCodec, nil +} - pkts, err := rtcp.Unmarshal(readerBuffer) +func (track *baseTrack) rtcpReadLoop(reader interceptor.RTCPReader, keyFrameController codec.KeyFrameController, stopRead chan struct{}) { + readerBuffer := make([]byte, rtcpInboundMTU) - for _, pkt := range pkts { - switch pkt.(type) { - case *rtcp.PictureLossIndication, *rtcp.FullIntraRequest: - if err := keyFrameController.ForceKeyFrame(); err != nil { - track.onError(err) - return - } - } +readLoop: + for { + select { + case <-stopRead: + return + default: + } + + readLength, _, err := reader.Read(readerBuffer, interceptor.Attributes{}) + if err != nil { + if errors.Is(err, io.EOF) { + return + } + logger.Warnf("failed to read rtcp packet: %s", err) + continue + } + + pkts, err := rtcp.Unmarshal(readerBuffer[:readLength]) + if err != nil { + logger.Warnf("failed to unmarshal rtcp packet: %s", err) + continue + } + + for _, pkt := range pkts { + switch pkt.(type) { + case *rtcp.PictureLossIndication, *rtcp.FullIntraRequest: + if err := keyFrameController.ForceKeyFrame(); err != nil { + logger.Warnf("failed to force key frame: %s", err) + continue readLoop } } - }() + } } - - return selectedCodec, nil } func (track *baseTrack) unbind(ctx webrtc.TrackLocalContext) error { diff --git a/track_test.go b/track_test.go index 18b58697..d01bdc6f 100644 --- a/track_test.go +++ b/track_test.go @@ -2,6 +2,8 @@ package mediadevices import ( "errors" + "github.com/pion/interceptor" + "io" "testing" "time" ) @@ -53,3 +55,101 @@ func TestOnEnded(t *testing.T) { } }) } + +type fakeRTCPReader struct { + mockReturn chan []byte + end chan struct{} +} + +func (mock *fakeRTCPReader) Read(buffer []byte, attributes interceptor.Attributes) (int, interceptor.Attributes, error) { + select { + case <-mock.end: + return 0, nil, io.EOF + case mockReturn := <-mock.mockReturn: + if len(buffer) < len(mock.mockReturn) { + return 0, nil, io.ErrShortBuffer + } + + return copy(buffer, mockReturn), attributes, nil + } +} + +type fakeKeyFrameController struct { + called chan struct{} +} + +func (mock *fakeKeyFrameController) ForceKeyFrame() error { + mock.called <- struct{}{} + return nil +} + +func TestRtcpHandler(t *testing.T) { + + t.Run("ShouldStopReading", func(t *testing.T) { + tr := &baseTrack{} + stop := make(chan struct{}, 1) + stopped := make(chan struct{}) + go func() { + tr.rtcpReadLoop(&fakeRTCPReader{end: stop}, &fakeKeyFrameController{}, stop) + stopped <- struct{}{} + }() + + stop <- struct{}{} + + select { + case <-time.After(100 * time.Millisecond): + t.Error("Timeout") + case <-stopped: + } + }) + + t.Run("ShouldForceKeyFrame", func(t *testing.T) { + for packetType, packet := range map[string][]byte{ + "PLI": { + // v=2, p=0, FMT=1, PSFB, len=1 + 0x81, 0xce, 0x00, 0x02, + // ssrc=0x0 + 0x00, 0x00, 0x00, 0x00, + // ssrc=0x4bc4fcb4 + 0x4b, 0xc4, 0xfc, 0xb4, + }, + "FIR": { + // v=2, p=0, FMT=4, PSFB, len=3 + 0x84, 0xce, 0x00, 0x04, + // ssrc=0x0 + 0x00, 0x00, 0x00, 0x00, + // ssrc=0x4bc4fcb4 + 0x4b, 0xc4, 0xfc, 0xb4, + // ssrc=0x12345678 + 0x12, 0x34, 0x56, 0x78, + // Seqno=0x42 + 0x42, 0x00, 0x00, 0x00, + }, + } { + t.Run(packetType, func(t *testing.T) { + tr := &baseTrack{} + tr.OnEnded(func(err error) { + if err != io.EOF { + t.Error(err) + } + }) + stop := make(chan struct{}, 1) + defer func() { + stop <- struct{}{} + }() + mockKeyFrameController := &fakeKeyFrameController{called: make(chan struct{}, 1)} + mockRTCPReader := &fakeRTCPReader{end: stop, mockReturn: make(chan []byte, 1)} + + go tr.rtcpReadLoop(mockRTCPReader, mockKeyFrameController, stop) + + mockRTCPReader.mockReturn <- packet + + select { + case <-time.After(1000 * time.Millisecond): + t.Error("Timeout") + case <-mockKeyFrameController.called: + } + }) + } + }) +}