Skip to content

Commit 8ad810e

Browse files
authored
fix PLI and FIR handling, wrongly triggering track.OnEnded (#420)
1 parent 6f204fa commit 8ad810e

File tree

2 files changed

+137
-26
lines changed

2 files changed

+137
-26
lines changed

track.go

Lines changed: 37 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ import (
2323

2424
const (
2525
rtpOutboundMTU = 1200
26+
rtcpInboundMTU = 1500
2627
)
2728

2829
var (
@@ -223,38 +224,48 @@ func (track *baseTrack) bind(ctx webrtc.TrackLocalContext, specializedTrack Trac
223224
keyFrameController, ok := encodedReader.Controller().(codec.KeyFrameController)
224225
if ok {
225226
stopRead = make(chan struct{})
226-
go func() {
227-
reader := ctx.RTCPReader()
228-
for {
229-
select {
230-
case <-stopRead:
231-
return
232-
default:
233-
}
227+
go track.rtcpReadLoop(ctx.RTCPReader(), keyFrameController, stopRead)
228+
}
234229

235-
var readerBuffer []byte
236-
_, _, err := reader.Read(readerBuffer, interceptor.Attributes{})
237-
if err != nil {
238-
track.onError(err)
239-
return
240-
}
230+
return selectedCodec, nil
231+
}
241232

242-
pkts, err := rtcp.Unmarshal(readerBuffer)
233+
func (track *baseTrack) rtcpReadLoop(reader interceptor.RTCPReader, keyFrameController codec.KeyFrameController, stopRead chan struct{}) {
234+
readerBuffer := make([]byte, rtcpInboundMTU)
243235

244-
for _, pkt := range pkts {
245-
switch pkt.(type) {
246-
case *rtcp.PictureLossIndication, *rtcp.FullIntraRequest:
247-
if err := keyFrameController.ForceKeyFrame(); err != nil {
248-
track.onError(err)
249-
return
250-
}
251-
}
236+
readLoop:
237+
for {
238+
select {
239+
case <-stopRead:
240+
return
241+
default:
242+
}
243+
244+
readLength, _, err := reader.Read(readerBuffer, interceptor.Attributes{})
245+
if err != nil {
246+
if errors.Is(err, io.EOF) {
247+
return
248+
}
249+
logger.Warnf("failed to read rtcp packet: %s", err)
250+
continue
251+
}
252+
253+
pkts, err := rtcp.Unmarshal(readerBuffer[:readLength])
254+
if err != nil {
255+
logger.Warnf("failed to unmarshal rtcp packet: %s", err)
256+
continue
257+
}
258+
259+
for _, pkt := range pkts {
260+
switch pkt.(type) {
261+
case *rtcp.PictureLossIndication, *rtcp.FullIntraRequest:
262+
if err := keyFrameController.ForceKeyFrame(); err != nil {
263+
logger.Warnf("failed to force key frame: %s", err)
264+
continue readLoop
252265
}
253266
}
254-
}()
267+
}
255268
}
256-
257-
return selectedCodec, nil
258269
}
259270

260271
func (track *baseTrack) unbind(ctx webrtc.TrackLocalContext) error {

track_test.go

Lines changed: 100 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@ package mediadevices
22

33
import (
44
"errors"
5+
"github.com/pion/interceptor"
6+
"io"
57
"testing"
68
"time"
79
)
@@ -53,3 +55,101 @@ func TestOnEnded(t *testing.T) {
5355
}
5456
})
5557
}
58+
59+
type fakeRTCPReader struct {
60+
mockReturn chan []byte
61+
end chan struct{}
62+
}
63+
64+
func (mock *fakeRTCPReader) Read(buffer []byte, attributes interceptor.Attributes) (int, interceptor.Attributes, error) {
65+
select {
66+
case <-mock.end:
67+
return 0, nil, io.EOF
68+
case mockReturn := <-mock.mockReturn:
69+
if len(buffer) < len(mock.mockReturn) {
70+
return 0, nil, io.ErrShortBuffer
71+
}
72+
73+
return copy(buffer, mockReturn), attributes, nil
74+
}
75+
}
76+
77+
type fakeKeyFrameController struct {
78+
called chan struct{}
79+
}
80+
81+
func (mock *fakeKeyFrameController) ForceKeyFrame() error {
82+
mock.called <- struct{}{}
83+
return nil
84+
}
85+
86+
func TestRtcpHandler(t *testing.T) {
87+
88+
t.Run("ShouldStopReading", func(t *testing.T) {
89+
tr := &baseTrack{}
90+
stop := make(chan struct{}, 1)
91+
stopped := make(chan struct{})
92+
go func() {
93+
tr.rtcpReadLoop(&fakeRTCPReader{end: stop}, &fakeKeyFrameController{}, stop)
94+
stopped <- struct{}{}
95+
}()
96+
97+
stop <- struct{}{}
98+
99+
select {
100+
case <-time.After(100 * time.Millisecond):
101+
t.Error("Timeout")
102+
case <-stopped:
103+
}
104+
})
105+
106+
t.Run("ShouldForceKeyFrame", func(t *testing.T) {
107+
for packetType, packet := range map[string][]byte{
108+
"PLI": {
109+
// v=2, p=0, FMT=1, PSFB, len=1
110+
0x81, 0xce, 0x00, 0x02,
111+
// ssrc=0x0
112+
0x00, 0x00, 0x00, 0x00,
113+
// ssrc=0x4bc4fcb4
114+
0x4b, 0xc4, 0xfc, 0xb4,
115+
},
116+
"FIR": {
117+
// v=2, p=0, FMT=4, PSFB, len=3
118+
0x84, 0xce, 0x00, 0x04,
119+
// ssrc=0x0
120+
0x00, 0x00, 0x00, 0x00,
121+
// ssrc=0x4bc4fcb4
122+
0x4b, 0xc4, 0xfc, 0xb4,
123+
// ssrc=0x12345678
124+
0x12, 0x34, 0x56, 0x78,
125+
// Seqno=0x42
126+
0x42, 0x00, 0x00, 0x00,
127+
},
128+
} {
129+
t.Run(packetType, func(t *testing.T) {
130+
tr := &baseTrack{}
131+
tr.OnEnded(func(err error) {
132+
if err != io.EOF {
133+
t.Error(err)
134+
}
135+
})
136+
stop := make(chan struct{}, 1)
137+
defer func() {
138+
stop <- struct{}{}
139+
}()
140+
mockKeyFrameController := &fakeKeyFrameController{called: make(chan struct{}, 1)}
141+
mockRTCPReader := &fakeRTCPReader{end: stop, mockReturn: make(chan []byte, 1)}
142+
143+
go tr.rtcpReadLoop(mockRTCPReader, mockKeyFrameController, stop)
144+
145+
mockRTCPReader.mockReturn <- packet
146+
147+
select {
148+
case <-time.After(1000 * time.Millisecond):
149+
t.Error("Timeout")
150+
case <-mockKeyFrameController.called:
151+
}
152+
})
153+
}
154+
})
155+
}

0 commit comments

Comments
 (0)