Skip to content

Commit 47dad55

Browse files
committed
Add methods to add and remove extensions
Added `AddExtension` and `RemoveExtension` methods to `ICECandidate`, allowing extensions to be managed dynamically. Ensure that `TCPType` is stored in one place (candidate.TCPType)
1 parent cad1676 commit 47dad55

File tree

3 files changed

+228
-10
lines changed

3 files changed

+228
-10
lines changed

candidate.go

+7-1
Original file line numberDiff line numberDiff line change
@@ -58,12 +58,18 @@ type Candidate interface {
5858
// https://datatracker.ietf.org/doc/html/rfc5245#section-15.1
5959
//.
6060
Extensions() []CandidateExtension
61-
6261
// GetExtension returns the value of the extension attribute associated with the ICECandidate.
6362
// Extension attributes are defined in RFC 5245, Section 15.1:
6463
// https://datatracker.ietf.org/doc/html/rfc5245#section-15.1
6564
//.
6665
GetExtension(key string) (value CandidateExtension, ok bool)
66+
// AddExtension adds an extension attribute to the ICECandidate.
67+
// If an extension with the same key already exists, it will be overwritten.
68+
// Extension attributes are defined in RFC 5245, Section 15.1:
69+
AddExtension(extension CandidateExtension) error
70+
// RemoveExtension removes an extension attribute from the ICECandidate.
71+
// Extension attributes are defined in RFC 5245, Section 15.1:
72+
RemoveExtension(key string) (ok bool)
6773

6874
String() string
6975
Type() CandidateType

candidate_base.go

+65-9
Original file line numberDiff line numberDiff line change
@@ -548,17 +548,22 @@ type CandidateExtension struct {
548548
}
549549

550550
func (c *candidateBase) Extensions() []CandidateExtension {
551-
// IF Extensions were not parsed using UnmarshalCandidate
552-
// For backwards compatibility when the TCPType is set manually
553-
if len(c.extensions) == 0 && c.TCPType() != TCPTypeUnspecified {
554-
return []CandidateExtension{{
551+
tcpType := c.TCPType()
552+
hasTCPType := 0
553+
if tcpType != TCPTypeUnspecified {
554+
hasTCPType = 1
555+
}
556+
557+
extensions := make([]CandidateExtension, len(c.extensions)+hasTCPType)
558+
// We store the TCPType in c.tcpType, but we need to return it as an extension.
559+
if hasTCPType == 1 {
560+
extensions[0] = CandidateExtension{
555561
Key: "tcptype",
556-
Value: c.TCPType().String(),
557-
}}
562+
Value: tcpType.String(),
563+
}
558564
}
559565

560-
extensions := make([]CandidateExtension, len(c.extensions))
561-
copy(extensions, c.extensions)
566+
copy(extensions[hasTCPType:], c.extensions)
562567

563568
return extensions
564569
}
@@ -576,7 +581,7 @@ func (c *candidateBase) GetExtension(key string) (CandidateExtension, bool) {
576581
}
577582

578583
// TCPType was manually set.
579-
if key == "tcptype" && c.TCPType() != TCPTypeUnspecified {
584+
if key == "tcptype" && c.TCPType() != TCPTypeUnspecified { //nolint:goconst
580585
extension.Value = c.TCPType().String()
581586

582587
return extension, true
@@ -585,6 +590,55 @@ func (c *candidateBase) GetExtension(key string) (CandidateExtension, bool) {
585590
return extension, false
586591
}
587592

593+
func (c *candidateBase) AddExtension(ext CandidateExtension) error {
594+
if ext.Key == "tcptype" {
595+
tcpType := NewTCPType(ext.Value)
596+
if tcpType == TCPTypeUnspecified {
597+
return fmt.Errorf("%w: invalid or unsupported TCPtype %s", errParseTCPType, ext.Value)
598+
}
599+
600+
c.tcpType = tcpType
601+
602+
return nil
603+
}
604+
605+
if ext.Key == "" {
606+
return fmt.Errorf("%w: key is empty", errParseExtension)
607+
}
608+
609+
// per spec, Extensions aren't explicitly unique, we only set the first one.
610+
// If the exteion is set multiple times.
611+
for i := range c.extensions {
612+
if c.extensions[i].Key == ext.Key {
613+
c.extensions[i] = ext
614+
615+
return nil
616+
}
617+
}
618+
619+
c.extensions = append(c.extensions, ext)
620+
621+
return nil
622+
}
623+
624+
func (c *candidateBase) RemoveExtension(key string) (ok bool) {
625+
if key == "tcptype" {
626+
c.tcpType = TCPTypeUnspecified
627+
ok = true
628+
}
629+
630+
for i := range c.extensions {
631+
if c.extensions[i].Key == key {
632+
c.extensions = append(c.extensions[:i], c.extensions[i+1:]...)
633+
ok = true
634+
635+
break
636+
}
637+
}
638+
639+
return ok
640+
}
641+
588642
// marshalExtensions returns the string representation of the candidate extensions.
589643
func (c *candidateBase) marshalExtensions() string {
590644
value := ""
@@ -994,6 +1048,8 @@ func unmarshalCandidateExtensions(raw string) (extensions []CandidateExtension,
9941048

9951049
if key == "tcptype" {
9961050
rawTCPTypeRaw = value
1051+
1052+
continue
9971053
}
9981054

9991055
extensions = append(extensions, CandidateExtension{key, value})

candidate_test.go

+156
Original file line numberDiff line numberDiff line change
@@ -1271,3 +1271,159 @@ func TestBaseCandidateExtensionsEqual(t *testing.T) {
12711271
})
12721272
}
12731273
}
1274+
1275+
func TestCandidateAddExtension(t *testing.T) {
1276+
t.Run("Add extension", func(t *testing.T) {
1277+
candidate, err := NewCandidateHost(&CandidateHostConfig{
1278+
Network: NetworkTypeUDP4.String(),
1279+
Address: "fcd9:e3b8:12ce:9fc5:74a5:c6bb:d8b:e08a",
1280+
Port: 53987,
1281+
Priority: 500,
1282+
Foundation: "750",
1283+
})
1284+
if err != nil {
1285+
t.Error(err)
1286+
}
1287+
1288+
require.NoError(t, candidate.AddExtension(CandidateExtension{"a", "b"}))
1289+
require.NoError(t, candidate.AddExtension(CandidateExtension{"c", "d"}))
1290+
1291+
extensions := candidate.Extensions()
1292+
require.Equal(t, []CandidateExtension{{"a", "b"}, {"c", "d"}}, extensions)
1293+
})
1294+
1295+
t.Run("Add extension with existing key", func(t *testing.T) {
1296+
candidate, err := NewCandidateHost(&CandidateHostConfig{
1297+
Network: NetworkTypeUDP4.String(),
1298+
Address: "fcd9:e3b8:12ce:9fc5:74a5:c6bb:d8b:e08a",
1299+
Port: 53987,
1300+
Priority: 500,
1301+
Foundation: "750",
1302+
})
1303+
if err != nil {
1304+
t.Error(err)
1305+
}
1306+
1307+
require.NoError(t, candidate.AddExtension(CandidateExtension{"a", "b"}))
1308+
require.NoError(t, candidate.AddExtension(CandidateExtension{"a", "d"}))
1309+
1310+
extensions := candidate.Extensions()
1311+
require.Equal(t, []CandidateExtension{{"a", "d"}}, extensions)
1312+
})
1313+
1314+
t.Run("Keep tcptype extension", func(t *testing.T) {
1315+
candidate, err := NewCandidateHost(&CandidateHostConfig{
1316+
Network: NetworkTypeTCP4.String(),
1317+
Address: "fcd9:e3b8:12ce:9fc5:74a5:c6bb:d8b:e08a",
1318+
Port: 53987,
1319+
Priority: 500,
1320+
Foundation: "750",
1321+
TCPType: TCPTypeActive,
1322+
})
1323+
if err != nil {
1324+
t.Error(err)
1325+
}
1326+
1327+
ext, ok := candidate.GetExtension("tcptype")
1328+
require.True(t, ok)
1329+
require.Equal(t, ext, CandidateExtension{"tcptype", "active"})
1330+
require.Equal(t, candidate.Extensions(), []CandidateExtension{{"tcptype", "active"}})
1331+
1332+
require.NoError(t, candidate.AddExtension(CandidateExtension{"a", "b"}))
1333+
1334+
ext, ok = candidate.GetExtension("tcptype")
1335+
require.True(t, ok)
1336+
require.Equal(t, ext, CandidateExtension{"tcptype", "active"})
1337+
require.Equal(t, candidate.Extensions(), []CandidateExtension{{"tcptype", "active"}, {"a", "b"}})
1338+
})
1339+
1340+
t.Run("TcpType change extension", func(t *testing.T) {
1341+
candidate, err := NewCandidateHost(&CandidateHostConfig{
1342+
Network: NetworkTypeTCP4.String(),
1343+
Address: "fcd9:e3b8:12ce:9fc5:74a5:c6bb:d8b:e08a",
1344+
Port: 53987,
1345+
Priority: 500,
1346+
Foundation: "750",
1347+
})
1348+
if err != nil {
1349+
t.Error(err)
1350+
}
1351+
1352+
require.NoError(t, candidate.AddExtension(CandidateExtension{"tcptype", "active"}))
1353+
1354+
extensions := candidate.Extensions()
1355+
require.Equal(t, []CandidateExtension{{"tcptype", "active"}}, extensions)
1356+
require.Equal(t, TCPTypeActive, candidate.TCPType())
1357+
1358+
require.Error(t, candidate.AddExtension(CandidateExtension{"tcptype", "INVALID"}))
1359+
})
1360+
}
1361+
1362+
func TestCandidateRemoveExtension(t *testing.T) {
1363+
t.Run("Remove extension", func(t *testing.T) {
1364+
candidate, err := NewCandidateHost(&CandidateHostConfig{
1365+
Network: NetworkTypeUDP4.String(),
1366+
Address: "fcd9:e3b8:12ce:9fc5:74a5:c6bb:d8b:e08a",
1367+
Port: 53987,
1368+
Priority: 500,
1369+
Foundation: "750",
1370+
})
1371+
if err != nil {
1372+
t.Error(err)
1373+
}
1374+
1375+
require.NoError(t, candidate.AddExtension(CandidateExtension{"a", "b"}))
1376+
require.NoError(t, candidate.AddExtension(CandidateExtension{"c", "d"}))
1377+
1378+
require.True(t, candidate.RemoveExtension("a"))
1379+
1380+
extensions := candidate.Extensions()
1381+
require.Equal(t, []CandidateExtension{{"c", "d"}}, extensions)
1382+
})
1383+
1384+
t.Run("Remove extension that does not exist", func(t *testing.T) {
1385+
candidate, err := NewCandidateHost(&CandidateHostConfig{
1386+
Network: NetworkTypeUDP4.String(),
1387+
Address: "fcd9:e3b8:12ce:9fc5:74a5:c6bb:d8b:e08a",
1388+
Port: 53987,
1389+
Priority: 500,
1390+
Foundation: "750",
1391+
})
1392+
if err != nil {
1393+
t.Error(err)
1394+
}
1395+
1396+
require.NoError(t, candidate.AddExtension(CandidateExtension{"a", "b"}))
1397+
require.NoError(t, candidate.AddExtension(CandidateExtension{"c", "d"}))
1398+
1399+
require.False(t, candidate.RemoveExtension("b"))
1400+
1401+
extensions := candidate.Extensions()
1402+
require.Equal(t, []CandidateExtension{{"a", "b"}, {"c", "d"}}, extensions)
1403+
})
1404+
1405+
t.Run("Remove tcptype extension", func(t *testing.T) {
1406+
candidate, err := NewCandidateHost(&CandidateHostConfig{
1407+
Network: NetworkTypeTCP4.String(),
1408+
Address: "fcd9:e3b8:12ce:9fc5:74a5:c6bb:d8b:e08a",
1409+
Port: 53987,
1410+
Priority: 500,
1411+
Foundation: "750",
1412+
TCPType: TCPTypeActive,
1413+
})
1414+
if err != nil {
1415+
t.Error(err)
1416+
}
1417+
1418+
// tcptype extension should be removed, even if it's not in the extensions list (Not Parsed)
1419+
require.True(t, candidate.RemoveExtension("tcptype"))
1420+
require.Equal(t, TCPTypeUnspecified, candidate.TCPType())
1421+
require.Empty(t, candidate.Extensions())
1422+
1423+
require.NoError(t, candidate.AddExtension(CandidateExtension{"tcptype", "passive"}))
1424+
1425+
require.True(t, candidate.RemoveExtension("tcptype"))
1426+
require.Equal(t, TCPTypeUnspecified, candidate.TCPType())
1427+
require.Empty(t, candidate.Extensions())
1428+
})
1429+
}

0 commit comments

Comments
 (0)