Skip to content

Commit 8c22f58

Browse files
committed
Add SetExtensionWithProfile
Allows for conversion between RTP Header Extension Profiles easily. 'One Byte' packets can be updated without doing a full re-creating of the Extension Headers. Resolves #255 Resolves #249
1 parent 1a05037 commit 8c22f58

File tree

2 files changed

+87
-0
lines changed

2 files changed

+87
-0
lines changed

packet.go

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -435,6 +435,31 @@ func (h *Header) SetExtension(id uint8, payload []byte) error { //nolint:gocogni
435435
return nil
436436
}
437437

438+
// SetExtensionWithProfile sets an RTP header extension and converts Header Extension Profile if needed.
439+
func (h *Header) SetExtensionWithProfile(id uint8, payload []byte, intendedProfile uint16) error {
440+
if !h.Extension || h.ExtensionProfile == intendedProfile {
441+
return h.SetExtension(id, payload)
442+
}
443+
444+
// Don't mutate the packet if Set is going to fail anyway
445+
if err := headerExtensionCheck(intendedProfile, id, payload); err != nil {
446+
return err
447+
}
448+
449+
// If downgrading assert that existing Extensions will work
450+
if intendedProfile == ExtensionProfileOneByte {
451+
for i := range h.Extensions {
452+
if err := headerExtensionCheck(intendedProfile, h.Extensions[i].id, h.Extensions[i].payload); err != nil {
453+
return err
454+
}
455+
}
456+
}
457+
458+
h.ExtensionProfile = intendedProfile
459+
460+
return h.SetExtension(id, payload)
461+
}
462+
438463
// GetExtensionIDs returns an extension id array.
439464
func (h *Header) GetExtensionIDs() []uint8 {
440465
if !h.Extension {

packet_test.go

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1414,6 +1414,68 @@ func TestDeprecatedPaddingSizeField(t *testing.T) {
14141414
assert.EqualValues(t, 0, parsedPacket2.Header.PaddingSize)
14151415
}
14161416

1417+
func TestSetExtensionWithProfile(t *testing.T) {
1418+
t.Run("add two-byte extension due to the size > 16", func(t *testing.T) {
1419+
h := Header{}
1420+
assert.NoError(t, h.SetExtension(1, make([]byte, 2)))
1421+
assert.NoError(t, h.SetExtension(2, make([]byte, 3)))
1422+
1423+
// Adding another extension that requires two-byte header extension
1424+
assert.NoError(t, h.SetExtensionWithProfile(3, make([]byte, 20), ExtensionProfileTwoByte))
1425+
assert.Equal(t, h.ExtensionProfile, uint16(ExtensionProfileTwoByte))
1426+
})
1427+
1428+
t.Run("add two-byte extension due to id > 14", func(t *testing.T) {
1429+
h := Header{}
1430+
assert.NoError(t, h.SetExtension(1, make([]byte, 2)))
1431+
assert.NoError(t, h.SetExtension(2, make([]byte, 3)))
1432+
1433+
// Adding another extension that requires two-byte header extension
1434+
// because the extmap ID is greater than 14.
1435+
assert.NoError(t, h.SetExtensionWithProfile(16, make([]byte, 4), ExtensionProfileTwoByte))
1436+
assert.Equal(t, h.ExtensionProfile, uint16(ExtensionProfileTwoByte))
1437+
})
1438+
1439+
t.Run("Downgrade 2 byte header Extension", func(t *testing.T) {
1440+
pkt := []byte{
1441+
0x90, 0x60, 0x00, 0x01, // V=2, P=0, X=1, CC=0; M=0, PT=96; sequence=1
1442+
0x00, 0x00, 0x00, 0x01, // timestamp=1
1443+
0x12, 0x34, 0x56, 0x78, // SSRC=0x12345678
1444+
0x10, 0x00, 0x00, 0x01, // profile=0x1000 (two-byte), length=1 (4 bytes)
1445+
0x01, 0x02, 0x00, 0x01, // id=1, len=2, data=0x00,0x01 (padded to 32-bit)
1446+
}
1447+
h := Header{}
1448+
1449+
_, err := h.Unmarshal(pkt)
1450+
assert.NoError(t, err)
1451+
assert.Equal(t, h.ExtensionProfile, uint16(ExtensionProfileTwoByte))
1452+
1453+
assert.NoError(t, h.SetExtensionWithProfile(1, []byte{0x02, 0x03}, ExtensionProfileOneByte))
1454+
assert.Equal(t, h.ExtensionProfile, uint16(ExtensionProfileOneByte))
1455+
1456+
pkt, err = h.Marshal()
1457+
assert.NoError(t, err)
1458+
1459+
assert.Equal(t, pkt, []byte{
1460+
0x90, 0x60, 0x00, 0x01,
1461+
0x00, 0x00, 0x00, 0x01,
1462+
0x12, 0x34, 0x56, 0x78,
1463+
0xbe, 0xde, 0x00, 0x01,
1464+
0x11, 0x02, 0x03, 0x00,
1465+
})
1466+
})
1467+
1468+
t.Run("Do not mutate packet for invalid extension", func(t *testing.T) {
1469+
h := Header{}
1470+
assert.NoError(t, h.SetExtension(1, make([]byte, 2)))
1471+
1472+
assert.Error(t, h.SetExtensionWithProfile(16, make([]byte, 4096), ExtensionProfileTwoByte))
1473+
1474+
assert.Equal(t, h.ExtensionProfile, uint16(ExtensionProfileOneByte))
1475+
assert.Len(t, h.Extensions, 1)
1476+
})
1477+
}
1478+
14171479
func BenchmarkMarshal(b *testing.B) {
14181480
rawPkt := []byte{
14191481
0x90, 0x60, 0x69, 0x8f, 0xd9, 0xc2, 0x93, 0xda, 0x1c, 0x64,

0 commit comments

Comments
 (0)