Skip to content

Commit a71cb0a

Browse files
committed
Add support for TWCC
1 parent 1fb32cc commit a71cb0a

File tree

4 files changed

+297
-20
lines changed

4 files changed

+297
-20
lines changed

pkg/ccfb/history.go

+19-10
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ package ccfb
33
import (
44
"container/list"
55
"errors"
6+
"sync"
67
"time"
78

89
"github.com/pion/interceptor/internal/sequencenumber"
@@ -30,6 +31,7 @@ type sentPacket struct {
3031
}
3132

3233
type history struct {
34+
lock sync.Mutex
3335
size int
3436
evictList *list.List
3537
seqNrToPacket map[int64]*list.Element
@@ -39,6 +41,7 @@ type history struct {
3941

4042
func newHistory(size int) *history {
4143
return &history{
44+
lock: sync.Mutex{},
4245
size: size,
4346
evictList: list.New(),
4447
seqNrToPacket: make(map[int64]*list.Element),
@@ -48,6 +51,9 @@ func newHistory(size int) *history {
4851
}
4952

5053
func (h *history) add(seqNr uint16, size uint16, departure time.Time) error {
54+
h.lock.Lock()
55+
defer h.lock.Unlock()
56+
5157
sn := h.sentSeqNr.Unwrap(seqNr)
5258
last := h.evictList.Back()
5359
if last != nil {
@@ -65,11 +71,23 @@ func (h *history) add(seqNr uint16, size uint16, departure time.Time) error {
6571
if h.evictList.Len() > h.size {
6672
h.removeOldest()
6773
}
68-
6974
return nil
7075
}
7176

77+
// Must be called while holding the lock
78+
func (h *history) removeOldest() {
79+
if ent := h.evictList.Front(); ent != nil {
80+
v := h.evictList.Remove(ent)
81+
if sp, ok := v.(sentPacket); ok {
82+
delete(h.seqNrToPacket, sp.seqNr)
83+
}
84+
}
85+
}
86+
7287
func (h *history) getReportForAck(al acknowledgementList) PacketReportList {
88+
h.lock.Lock()
89+
defer h.lock.Unlock()
90+
7391
var reports []PacketReport
7492
for _, pr := range al.acks {
7593
sn := h.ackedSeqNr.Unwrap(pr.seqNr)
@@ -95,12 +113,3 @@ func (h *history) getReportForAck(al acknowledgementList) PacketReportList {
95113
Reports: reports,
96114
}
97115
}
98-
99-
func (h *history) removeOldest() {
100-
if ent := h.evictList.Front(); ent != nil {
101-
v := h.evictList.Remove(ent)
102-
if sp, ok := v.(sentPacket); ok {
103-
delete(h.seqNrToPacket, sp.seqNr)
104-
}
105-
}
106-
}

pkg/ccfb/interceptor.go

+43-10
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,8 @@ import (
99
"github.com/pion/rtp"
1010
)
1111

12+
const transportCCURI = "http://www.ietf.org/id/draft-holmer-rmcat-transport-wide-cc-extensions-01"
13+
1214
type ccfbAttributesKeyType uint32
1315

1416
const CCFBAttributesKey ccfbAttributesKeyType = iota
@@ -48,14 +50,42 @@ type Interceptor struct {
4850

4951
// BindLocalStream implements interceptor.Interceptor.
5052
func (i *Interceptor) BindLocalStream(info *interceptor.StreamInfo, writer interceptor.RTPWriter) interceptor.RTPWriter {
53+
var twccHdrExtID uint8
54+
var useTWCC bool
55+
for _, e := range info.RTPHeaderExtensions {
56+
if e.URI == transportCCURI {
57+
twccHdrExtID = uint8(e.ID)
58+
useTWCC = true
59+
break
60+
}
61+
}
62+
5163
i.lock.Lock()
5264
defer i.lock.Unlock()
53-
i.ssrcToHistory[info.SSRC] = newHistory(200)
65+
66+
ssrc := info.SSRC
67+
if useTWCC {
68+
ssrc = 0
69+
}
70+
i.ssrcToHistory[ssrc] = newHistory(200)
5471

5572
return interceptor.RTPWriterFunc(func(header *rtp.Header, payload []byte, attributes interceptor.Attributes) (int, error) {
5673
i.lock.Lock()
5774
defer i.lock.Unlock()
58-
i.ssrcToHistory[header.SSRC].add(header.SequenceNumber, uint16(header.MarshalSize()+len(payload)), i.timestamp())
75+
76+
// If we are using TWCC, we use the sequence number from the TWCC header
77+
// extension and save all TWCC sequence numbers with the same SSRC (0).
78+
// If we are not using TWCC, we save a history per SSRC and use the
79+
// normal RTP sequence numbers.
80+
ssrc := header.SSRC
81+
seqNr := header.SequenceNumber
82+
if useTWCC {
83+
ssrc = 0
84+
var twccHdrExt rtp.TransportCCExtension
85+
twccHdrExt.Unmarshal(header.GetExtension(twccHdrExtID))
86+
seqNr = twccHdrExt.TransportSequence
87+
}
88+
i.ssrcToHistory[ssrc].add(seqNr, uint16(header.MarshalSize()+len(payload)), i.timestamp())
5989
return writer.Write(header, payload, attributes)
6090
})
6191
}
@@ -80,16 +110,19 @@ func (i *Interceptor) BindRTCPReader(reader interceptor.RTCPReader) interceptor.
80110

81111
pkts, err := attr.GetRTCPPackets(buf)
82112
for _, pkt := range pkts {
113+
var reportLists map[uint32]acknowledgementList
83114
switch fb := pkt.(type) {
84115
case *rtcp.CCFeedbackReport:
85-
reportLists := convertCCFB(now, fb)
86-
for ssrc, reportList := range reportLists {
87-
prl := i.ssrcToHistory[ssrc].getReportForAck(reportList)
88-
if l, ok := pktReportLists[ssrc]; !ok {
89-
pktReportLists[ssrc] = &prl
90-
} else {
91-
l.Reports = append(l.Reports, prl.Reports...)
92-
}
116+
reportLists = convertCCFB(now, fb)
117+
case *rtcp.TransportLayerCC:
118+
reportLists = convertTWCC(now, fb)
119+
}
120+
for ssrc, reportList := range reportLists {
121+
prl := i.ssrcToHistory[ssrc].getReportForAck(reportList)
122+
if l, ok := pktReportLists[ssrc]; !ok {
123+
pktReportLists[ssrc] = &prl
124+
} else {
125+
l.Reports = append(l.Reports, prl.Reports...)
93126
}
94127
}
95128
}

pkg/ccfb/twcc_receiver.go

+94
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,94 @@
1+
package ccfb
2+
3+
import (
4+
"log"
5+
"time"
6+
7+
"github.com/pion/rtcp"
8+
)
9+
10+
func convertTWCC(ts time.Time, feedback *rtcp.TransportLayerCC) map[uint32]acknowledgementList {
11+
log.Printf("got twcc report: %v", feedback)
12+
if feedback == nil {
13+
return nil
14+
}
15+
var acks []acknowledgement
16+
17+
nextTimestamp := time.Time{}.Add(time.Duration(feedback.ReferenceTime) * 64 * time.Millisecond)
18+
recvDeltaIndex := 0
19+
20+
offset := 0
21+
for _, pc := range feedback.PacketChunks {
22+
switch chunk := pc.(type) {
23+
case *rtcp.RunLengthChunk:
24+
for i := uint16(0); i < chunk.RunLength; i++ {
25+
seqNr := feedback.BaseSequenceNumber + uint16(offset)
26+
offset++
27+
switch chunk.PacketStatusSymbol {
28+
case rtcp.TypeTCCPacketNotReceived:
29+
acks = append(acks, acknowledgement{
30+
seqNr: seqNr,
31+
arrived: false,
32+
arrival: time.Time{},
33+
ecn: 0,
34+
})
35+
case rtcp.TypeTCCPacketReceivedSmallDelta, rtcp.TypeTCCPacketReceivedLargeDelta:
36+
delta := feedback.RecvDeltas[recvDeltaIndex]
37+
nextTimestamp = nextTimestamp.Add(time.Duration(delta.Delta) * time.Microsecond)
38+
recvDeltaIndex++
39+
acks = append(acks, acknowledgement{
40+
seqNr: seqNr,
41+
arrived: true,
42+
arrival: nextTimestamp,
43+
ecn: 0,
44+
})
45+
case rtcp.TypeTCCPacketReceivedWithoutDelta:
46+
acks = append(acks, acknowledgement{
47+
seqNr: seqNr,
48+
arrived: true,
49+
arrival: time.Time{},
50+
ecn: 0,
51+
})
52+
}
53+
}
54+
case *rtcp.StatusVectorChunk:
55+
for _, s := range chunk.SymbolList {
56+
seqNr := feedback.BaseSequenceNumber + uint16(offset)
57+
offset++
58+
switch s {
59+
case rtcp.TypeTCCPacketNotReceived:
60+
acks = append(acks, acknowledgement{
61+
seqNr: seqNr,
62+
arrived: false,
63+
arrival: time.Time{},
64+
ecn: 0,
65+
})
66+
case rtcp.TypeTCCPacketReceivedSmallDelta, rtcp.TypeTCCPacketReceivedLargeDelta:
67+
delta := feedback.RecvDeltas[recvDeltaIndex]
68+
nextTimestamp = nextTimestamp.Add(time.Duration(delta.Delta) * time.Microsecond)
69+
recvDeltaIndex++
70+
acks = append(acks, acknowledgement{
71+
seqNr: seqNr,
72+
arrived: true,
73+
arrival: nextTimestamp,
74+
ecn: 0,
75+
})
76+
case rtcp.TypeTCCPacketReceivedWithoutDelta:
77+
acks = append(acks, acknowledgement{
78+
seqNr: seqNr,
79+
arrived: true,
80+
arrival: time.Time{},
81+
ecn: 0,
82+
})
83+
}
84+
}
85+
}
86+
}
87+
88+
return map[uint32]acknowledgementList{
89+
0: {
90+
ts: ts,
91+
acks: acks,
92+
},
93+
}
94+
}

pkg/ccfb/twcc_receiver_test.go

+141
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,141 @@
1+
package ccfb
2+
3+
import (
4+
"fmt"
5+
"testing"
6+
"time"
7+
8+
"github.com/pion/rtcp"
9+
"github.com/stretchr/testify/assert"
10+
)
11+
12+
func TestConvertTWCC(t *testing.T) {
13+
timeZero := time.Now()
14+
cases := []struct {
15+
ts time.Time
16+
feedback *rtcp.TransportLayerCC
17+
expect map[uint32]acknowledgementList
18+
}{
19+
{},
20+
{
21+
ts: timeZero.Add(2 * time.Second),
22+
feedback: &rtcp.TransportLayerCC{
23+
SenderSSRC: 1,
24+
MediaSSRC: 2,
25+
BaseSequenceNumber: 178,
26+
PacketStatusCount: 0,
27+
ReferenceTime: 0,
28+
FbPktCount: 0,
29+
PacketChunks: []rtcp.PacketStatusChunk{},
30+
RecvDeltas: []*rtcp.RecvDelta{},
31+
},
32+
expect: map[uint32]acknowledgementList{
33+
2: {
34+
ts: timeZero.Add(2 * time.Second),
35+
acks: []acknowledgement{},
36+
},
37+
},
38+
},
39+
{
40+
ts: timeZero.Add(2 * time.Second),
41+
feedback: &rtcp.TransportLayerCC{
42+
SenderSSRC: 1,
43+
MediaSSRC: 2,
44+
BaseSequenceNumber: 178,
45+
PacketStatusCount: 3,
46+
ReferenceTime: 0,
47+
FbPktCount: 0,
48+
PacketChunks: []rtcp.PacketStatusChunk{
49+
&rtcp.RunLengthChunk{
50+
PacketStatusSymbol: rtcp.TypeTCCPacketReceivedSmallDelta,
51+
RunLength: 3,
52+
},
53+
&rtcp.StatusVectorChunk{
54+
SymbolSize: rtcp.TypeTCCSymbolSizeOneBit,
55+
SymbolList: []uint16{
56+
rtcp.TypeTCCPacketReceivedSmallDelta,
57+
rtcp.TypeTCCPacketReceivedSmallDelta,
58+
rtcp.TypeTCCPacketReceivedSmallDelta,
59+
rtcp.TypeTCCPacketNotReceived,
60+
rtcp.TypeTCCPacketNotReceived,
61+
rtcp.TypeTCCPacketNotReceived,
62+
rtcp.TypeTCCPacketNotReceived,
63+
rtcp.TypeTCCPacketNotReceived,
64+
},
65+
},
66+
&rtcp.StatusVectorChunk{
67+
SymbolSize: rtcp.TypeTCCSymbolSizeTwoBit,
68+
SymbolList: []uint16{
69+
rtcp.TypeTCCPacketReceivedLargeDelta,
70+
rtcp.TypeTCCPacketReceivedLargeDelta,
71+
rtcp.TypeTCCPacketNotReceived,
72+
rtcp.TypeTCCPacketNotReceived,
73+
rtcp.TypeTCCPacketNotReceived,
74+
rtcp.TypeTCCPacketNotReceived,
75+
rtcp.TypeTCCPacketNotReceived,
76+
},
77+
},
78+
},
79+
RecvDeltas: []*rtcp.RecvDelta{
80+
{Type: rtcp.TypeTCCPacketReceivedSmallDelta, Delta: 0},
81+
{Type: rtcp.TypeTCCPacketReceivedSmallDelta, Delta: 0},
82+
{Type: rtcp.TypeTCCPacketReceivedSmallDelta, Delta: 0},
83+
{Type: rtcp.TypeTCCPacketReceivedSmallDelta, Delta: 0},
84+
{Type: rtcp.TypeTCCPacketReceivedSmallDelta, Delta: 0},
85+
{Type: rtcp.TypeTCCPacketReceivedSmallDelta, Delta: 0},
86+
{Type: rtcp.TypeTCCPacketReceivedLargeDelta, Delta: 0},
87+
{Type: rtcp.TypeTCCPacketReceivedLargeDelta, Delta: 0},
88+
},
89+
},
90+
expect: map[uint32]acknowledgementList{
91+
2: {
92+
ts: timeZero.Add(2 * time.Second),
93+
acks: []acknowledgement{
94+
// first run length chunk
95+
{seqNr: 178, arrived: true, arrival: time.Time{}, ecn: 0},
96+
{seqNr: 179, arrived: true, arrival: time.Time{}, ecn: 0},
97+
{seqNr: 180, arrived: true, arrival: time.Time{}, ecn: 0},
98+
99+
// first status vector chunk
100+
{seqNr: 181, arrived: true, arrival: time.Time{}, ecn: 0},
101+
{seqNr: 182, arrived: true, arrival: time.Time{}, ecn: 0},
102+
{seqNr: 183, arrived: true, arrival: time.Time{}, ecn: 0},
103+
{seqNr: 184, arrived: false, arrival: time.Time{}, ecn: 0},
104+
{seqNr: 185, arrived: false, arrival: time.Time{}, ecn: 0},
105+
{seqNr: 186, arrived: false, arrival: time.Time{}, ecn: 0},
106+
{seqNr: 187, arrived: false, arrival: time.Time{}, ecn: 0},
107+
{seqNr: 188, arrived: false, arrival: time.Time{}, ecn: 0},
108+
109+
// second status vector chunk
110+
{seqNr: 189, arrived: true, arrival: time.Time{}, ecn: 0},
111+
{seqNr: 190, arrived: true, arrival: time.Time{}, ecn: 0},
112+
{seqNr: 191, arrived: false, arrival: time.Time{}, ecn: 0},
113+
{seqNr: 192, arrived: false, arrival: time.Time{}, ecn: 0},
114+
{seqNr: 193, arrived: false, arrival: time.Time{}, ecn: 0},
115+
{seqNr: 194, arrived: false, arrival: time.Time{}, ecn: 0},
116+
{seqNr: 195, arrived: false, arrival: time.Time{}, ecn: 0},
117+
},
118+
},
119+
},
120+
},
121+
}
122+
for i, tc := range cases {
123+
t.Run(fmt.Sprintf("%v", i), func(t *testing.T) {
124+
res := convertTWCC(tc.ts, tc.feedback)
125+
126+
// Can't directly check equality since arrival timestamp conversions
127+
// may be slightly off due to ntp conversions.
128+
assert.Equal(t, len(tc.expect), len(res))
129+
for i, ee := range tc.expect {
130+
assert.Equal(t, ee.ts, res[i].ts)
131+
for j, ack := range ee.acks {
132+
assert.Equal(t, ack.seqNr, res[i].acks[j].seqNr)
133+
assert.Equal(t, ack.arrived, res[i].acks[j].arrived)
134+
assert.Equal(t, ack.ecn, res[i].acks[j].ecn)
135+
assert.InDelta(t, ack.arrival.UnixNano(), res[i].acks[j].arrival.UnixNano(), float64(time.Millisecond.Nanoseconds()))
136+
}
137+
}
138+
})
139+
}
140+
141+
}

0 commit comments

Comments
 (0)