Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Makes UDPMux IPv4/IPv6 aware #431

Merged
merged 1 commit into from
Mar 2, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions gather.go
Original file line number Diff line number Diff line change
Expand Up @@ -236,7 +236,7 @@ func (a *Agent) gatherCandidatesLocalUDPMux(ctx context.Context) error {
return errUDPMuxDisabled
}

localIPs, err := localInterfaces(a.net, a.interfaceFilter, []NetworkType{NetworkTypeUDP4})
localIPs, err := localInterfaces(a.net, a.interfaceFilter, a.networkTypes)
switch {
case err != nil:
return err
Expand All @@ -254,7 +254,7 @@ func (a *Agent) gatherCandidatesLocalUDPMux(ctx context.Context) error {
}
}

conn, err := a.udpMux.GetConn(a.localUfrag)
conn, err := a.udpMux.GetConn(a.localUfrag, candidateIP.To4() == nil)
if err != nil {
return err
}
Expand Down Expand Up @@ -351,7 +351,7 @@ func (a *Agent) gatherCandidatesSrflxUDPMux(ctx context.Context, urls []*URL, ne

for i := range urls {
wg.Add(1)
go func(url URL, network string) {
go func(url URL, network string, isIPv6 bool) {
defer wg.Done()

hostPort := fmt.Sprintf("%s:%d", url.Host, url.Port)
Expand All @@ -367,7 +367,7 @@ func (a *Agent) gatherCandidatesSrflxUDPMux(ctx context.Context, urls []*URL, ne
return
}

conn, err := a.udpMuxSrflx.GetConnForURL(a.localUfrag, url.String())
conn, err := a.udpMuxSrflx.GetConnForURL(a.localUfrag, url.String(), isIPv6)
if err != nil {
a.log.Warnf("could not find connection in UDPMuxSrflx %s %s: %v\n", network, url, err)
return
Expand Down Expand Up @@ -397,7 +397,7 @@ func (a *Agent) gatherCandidatesSrflxUDPMux(ctx context.Context, urls []*URL, ne
}
a.log.Warnf("Failed to append to localCandidates and run onCandidateHdlr: %v\n", err)
}
}(*urls[i], networkType.String())
}(*urls[i], networkType.String(), networkType.IsIPv6())
}
}
}
Expand Down
2 changes: 1 addition & 1 deletion gather_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -557,7 +557,7 @@ func (m *universalUDPMuxMock) GetRelayedAddr(turnAddr net.Addr, deadline time.Du
return nil, errNotImplemented
}

func (m *universalUDPMuxMock) GetConnForURL(ufrag string, url string) (net.PacketConn, error) {
func (m *universalUDPMuxMock) GetConnForURL(ufrag string, url string, isIPv6 bool) (net.PacketConn, error) {
m.mu.Lock()
defer m.mu.Unlock()
m.getConnForURLTimes++
Expand Down
87 changes: 60 additions & 27 deletions udp_mux.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ import (
// UDPMux allows multiple connections to go over a single UDP port
type UDPMux interface {
io.Closer
GetConn(ufrag string) (net.PacketConn, error)
GetConn(ufrag string, isIPv6 bool) (net.PacketConn, error)
RemoveConnByUfrag(ufrag string)
}

Expand All @@ -25,8 +25,8 @@ type UDPMuxDefault struct {
closedChan chan struct{}
closeOnce sync.Once

// conns is a map of all udpMuxedConn indexed by ufrag|network|candidateType
conns map[string]*udpMuxedConn
// connsIPv4 and connsIPv6 are maps of all udpMuxedConn indexed by ufrag|network|candidateType
connsIPv4, connsIPv6 map[string]*udpMuxedConn

addressMapMu sync.RWMutex
addressMap map[string]*udpMuxedConn
Expand Down Expand Up @@ -54,7 +54,8 @@ func NewUDPMuxDefault(params UDPMuxParams) *UDPMuxDefault {
m := &UDPMuxDefault{
addressMap: map[string]*udpMuxedConn{},
params: params,
conns: make(map[string]*udpMuxedConn),
connsIPv4: make(map[string]*udpMuxedConn),
connsIPv6: make(map[string]*udpMuxedConn),
closedChan: make(chan struct{}, 1),
pool: &sync.Pool{
New: func() interface{} {
Expand All @@ -76,43 +77,47 @@ func (m *UDPMuxDefault) LocalAddr() net.Addr {

// GetConn returns a PacketConn given the connection's ufrag and network
// creates the connection if an existing one can't be found
func (m *UDPMuxDefault) GetConn(ufrag string) (net.PacketConn, error) {
func (m *UDPMuxDefault) GetConn(ufrag string, isIPv6 bool) (net.PacketConn, error) {
m.mu.Lock()
defer m.mu.Unlock()

if m.IsClosed() {
return nil, io.ErrClosedPipe
}

if c, ok := m.conns[ufrag]; ok {
return c, nil
if conn, ok := m.getConn(ufrag, isIPv6); ok {
return conn, nil
}

c := m.createMuxedConn(ufrag)
go func() {
<-c.CloseChannel()
m.removeConn(ufrag)
}()
m.conns[ufrag] = c

if isIPv6 {
m.connsIPv6[ufrag] = c
} else {
m.connsIPv4[ufrag] = c
}

return c, nil
}

// RemoveConnByUfrag stops and removes the muxed packet connection
func (m *UDPMuxDefault) RemoveConnByUfrag(ufrag string) {
m.mu.Lock()
removedConns := make([]*udpMuxedConn, 0)
for key := range m.conns {
if key != ufrag {
continue
}
removedConns := make([]*udpMuxedConn, 0, 2)

c := m.conns[key]
delete(m.conns, key)
if c != nil {
removedConns = append(removedConns, c)
}
// Keep lock section small to avoid deadlock with conn lock
m.mu.Lock()
if c, ok := m.connsIPv4[ufrag]; ok {
delete(m.connsIPv4, ufrag)
removedConns = append(removedConns, c)
}
if c, ok := m.connsIPv6[ufrag]; ok {
delete(m.connsIPv6, ufrag)
removedConns = append(removedConns, c)
}
// keep lock section small to avoid deadlock with conn lock
m.mu.Unlock()

m.addressMapMu.Lock()
Expand Down Expand Up @@ -143,21 +148,39 @@ func (m *UDPMuxDefault) Close() error {
m.mu.Lock()
defer m.mu.Unlock()

for _, c := range m.conns {
for _, c := range m.connsIPv4 {
_ = c.Close()
}
m.conns = make(map[string]*udpMuxedConn)
for _, c := range m.connsIPv6 {
_ = c.Close()
}

m.connsIPv4 = make(map[string]*udpMuxedConn)
m.connsIPv6 = make(map[string]*udpMuxedConn)

close(m.closedChan)
})
return err
}

func (m *UDPMuxDefault) removeConn(key string) {
m.mu.Lock()
c := m.conns[key]
delete(m.conns, key)
// keep lock section small to avoid deadlock with conn lock
m.mu.Unlock()
c := func() *udpMuxedConn {
m.mu.Lock()
defer m.mu.Unlock()

if c, ok := m.connsIPv4[key]; ok {
delete(m.connsIPv4, key)
return c
}

if c, ok := m.connsIPv6[key]; ok {
delete(m.connsIPv6, key)
return c
}

return nil
}()

if c == nil {
return
Expand Down Expand Up @@ -255,9 +278,10 @@ func (m *UDPMuxDefault) connWorker() {
}

ufrag := strings.Split(string(attr), ":")[0]
isIPv6 := udpAddr.IP.To4() == nil

m.mu.Lock()
destinationConn = m.conns[ufrag]
destinationConn, _ = m.getConn(ufrag, isIPv6)
m.mu.Unlock()
}

Expand All @@ -272,6 +296,15 @@ func (m *UDPMuxDefault) connWorker() {
}
}

func (m *UDPMuxDefault) getConn(ufrag string, isIPv6 bool) (val *udpMuxedConn, ok bool) {
if isIPv6 {
val, ok = m.connsIPv6[ufrag]
} else {
val, ok = m.connsIPv4[ufrag]
}
return
}

type bufferHolder struct {
buffer []byte
}
Expand Down
4 changes: 2 additions & 2 deletions udp_mux_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ func TestUDPMux(t *testing.T) {
require.NoError(t, udpMux.Close())

// can't create more connections
_, err = udpMux.GetConn("failufrag")
_, err = udpMux.GetConn("failufrag", false)
require.Error(t, err)
}

Expand Down Expand Up @@ -110,7 +110,7 @@ func TestAddressEncoding(t *testing.T) {
}

func testMuxConnection(t *testing.T, udpMux *UDPMuxDefault, ufrag string, network string) {
pktConn, err := udpMux.GetConn(ufrag)
pktConn, err := udpMux.GetConn(ufrag, false)
require.NoError(t, err, "error retrieving muxed connection for ufrag")
defer func() {
_ = pktConn.Close()
Expand Down
6 changes: 3 additions & 3 deletions udp_mux_universal.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ type UniversalUDPMux interface {
UDPMux
GetXORMappedAddr(stunAddr net.Addr, deadline time.Duration) (*stun.XORMappedAddress, error)
GetRelayedAddr(turnAddr net.Addr, deadline time.Duration) (*net.Addr, error)
GetConnForURL(ufrag string, url string) (net.PacketConn, error)
GetConnForURL(ufrag string, url string, isIPv6 bool) (net.PacketConn, error)
}

// UniversalUDPMuxDefault handles STUN and TURN servers packets by wrapping the original UDPConn overriding ReadFrom.
Expand Down Expand Up @@ -84,8 +84,8 @@ func (m *UniversalUDPMuxDefault) GetRelayedAddr(turnAddr net.Addr, deadline time

// GetConnForURL add uniques to the muxed connection by concatenating ufrag and URL (e.g. STUN URL) to be able to support multiple STUN/TURN servers
// and return a unique connection per server.
func (m *UniversalUDPMuxDefault) GetConnForURL(ufrag string, url string) (net.PacketConn, error) {
return m.UDPMuxDefault.GetConn(fmt.Sprintf("%s%s", ufrag, url))
func (m *UniversalUDPMuxDefault) GetConnForURL(ufrag string, url string, isIPv6 bool) (net.PacketConn, error) {
return m.UDPMuxDefault.GetConn(fmt.Sprintf("%s%s", ufrag, url), isIPv6)
}

// ReadFrom is called by UDPMux connWorker and handles packets coming from the STUN server discovering a mapped address.
Expand Down
2 changes: 1 addition & 1 deletion udp_mux_universal_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ func TestUniversalUDPMux(t *testing.T) {
}

func testMuxSrflxConnection(t *testing.T, udpMux *UniversalUDPMuxDefault, ufrag string, network string) {
pktConn, err := udpMux.GetConn(ufrag)
pktConn, err := udpMux.GetConn(ufrag, false)
require.NoError(t, err, "error retrieving muxed connection for ufrag")
defer func() {
_ = pktConn.Close()
Expand Down