Skip to content

Commit

Permalink
Make TCPMux IPv4/IPv6 aware
Browse files Browse the repository at this point in the history
TCPMux before would create one internal connection per ufrag. This could
cause remote IPv6 traffic to be dispatched to a local IPv4 handler (or
the inverse). The ice.Agent would then discard the traffic since a
candidate pair must be the same IP version.

This commit now creates two connections per ufrag. When requesting a
connection for a ufrag the user must specify if they want IPv4 or IPv6.

Resolves pion/webrtc#2125
Resolves pion/webrtc#1356
  • Loading branch information
Sean-Der committed Feb 21, 2022
1 parent 2d70ec8 commit a813428
Show file tree
Hide file tree
Showing 3 changed files with 54 additions and 25 deletions.
2 changes: 1 addition & 1 deletion gather.go
Original file line number Diff line number Diff line change
Expand Up @@ -177,7 +177,7 @@ func (a *Agent) gatherCandidatesLocal(ctx context.Context, networkTypes []Networ
case tcp:
// Handle ICE TCP passive mode
a.log.Debugf("GetConn by ufrag: %s\n", a.localUfrag)
conn, err = a.tcpMux.GetConnByUfrag(a.localUfrag)
conn, err = a.tcpMux.GetConnByUfrag(a.localUfrag, mappedIP.To4() == nil)
if err != nil {
if !errors.Is(err, ErrTCPMuxNotInitialized) {
a.log.Warnf("error getting tcp conn by ufrag: %s %s %s\n", network, ip, a.localUfrag)
Expand Down
71 changes: 50 additions & 21 deletions tcp_mux.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ import (
// 2. allow mocking in tests.
type TCPMux interface {
io.Closer
GetConnByUfrag(ufrag string) (net.PacketConn, error)
GetConnByUfrag(ufrag string, isIPv6 bool) (net.PacketConn, error)
RemoveConnByUfrag(ufrag string)
}

Expand All @@ -36,7 +36,7 @@ func (m *invalidTCPMux) Close() error {
}

// GetConnByUfrag implements TCPMux interface.
func (m *invalidTCPMux) GetConnByUfrag(ufrag string) (net.PacketConn, error) {
func (m *invalidTCPMux) GetConnByUfrag(ufrag string, isIPv6 bool) (net.PacketConn, error) {
return nil, ErrTCPMuxNotInitialized
}

Expand All @@ -49,8 +49,8 @@ type TCPMuxDefault struct {
params *TCPMuxParams
closed bool

// conns is a map of all tcpPacketConns indexed by ufrag
conns map[string]*tcpPacketConn
// connsIPv4 and connsIPv6 are maps of all tcpPacketConns indexed by ufrag
connsIPv4, connsIPv6 map[string]*tcpPacketConn

mu sync.Mutex
wg sync.WaitGroup
Expand All @@ -72,7 +72,8 @@ func NewTCPMuxDefault(params TCPMuxParams) *TCPMuxDefault {
m := &TCPMuxDefault{
params: &params,

conns: map[string]*tcpPacketConn{},
connsIPv4: map[string]*tcpPacketConn{},
connsIPv6: map[string]*tcpPacketConn{},
}

m.wg.Add(1)
Expand Down Expand Up @@ -109,33 +110,33 @@ func (m *TCPMuxDefault) LocalAddr() net.Addr {
}

// GetConnByUfrag retrieves an existing or creates a new net.PacketConn.
func (m *TCPMuxDefault) GetConnByUfrag(ufrag string) (net.PacketConn, error) {
func (m *TCPMuxDefault) GetConnByUfrag(ufrag string, isIPv6 bool) (net.PacketConn, error) {
m.mu.Lock()
defer m.mu.Unlock()

if m.closed {
return nil, io.ErrClosedPipe
}

conn, ok := m.conns[ufrag]

if ok {
if conn, ok := m.getConn(ufrag, isIPv6); ok {
return conn, nil
// return nil, fmt.Errorf("duplicate ufrag %v", ufrag)
}

conn = m.createConn(ufrag, m.LocalAddr())

return conn, nil
return m.createConn(ufrag, m.LocalAddr(), isIPv6), nil
}

func (m *TCPMuxDefault) createConn(ufrag string, localAddr net.Addr) *tcpPacketConn {
func (m *TCPMuxDefault) createConn(ufrag string, localAddr net.Addr, isIPv6 bool) *tcpPacketConn {
conn := newTCPPacketConn(tcpPacketParams{
ReadBuffer: m.params.ReadBufferSize,
LocalAddr: localAddr,
Logger: m.params.Logger,
})
m.conns[ufrag] = conn

if isIPv6 {
m.connsIPv6[ufrag] = conn
} else {
m.connsIPv4[ufrag] = conn
}

m.wg.Add(1)
go func() {
Expand Down Expand Up @@ -199,9 +200,17 @@ func (m *TCPMuxDefault) handleConn(conn net.Conn) {
m.mu.Lock()
defer m.mu.Unlock()

packetConn, ok := m.conns[ufrag]
host, _, err := net.SplitHostPort(conn.RemoteAddr().String())
if err != nil {
m.closeAndLogError(conn)
m.params.Logger.Warnf("Failed to get host in STUN message from %s to %s\n", conn.RemoteAddr(), conn.LocalAddr())
return
}

isIPv6 := net.ParseIP(host).To4() == nil
packetConn, ok := m.getConn(ufrag, isIPv6)
if !ok {
packetConn = m.createConn(ufrag, conn.LocalAddr())
packetConn = m.createConn(ufrag, conn.LocalAddr(), isIPv6)
}

if err := packetConn.AddConn(conn, buf); err != nil {
Expand All @@ -216,10 +225,15 @@ func (m *TCPMuxDefault) Close() error {
m.mu.Lock()
m.closed = true

for _, conn := range m.conns {
for _, conn := range m.connsIPv4 {
m.closeAndLogError(conn)
}
m.conns = map[string]*tcpPacketConn{}
for _, conn := range m.connsIPv6 {
m.closeAndLogError(conn)
}

m.connsIPv4 = map[string]*tcpPacketConn{}
m.connsIPv6 = map[string]*tcpPacketConn{}

err := m.params.Listener.Close()

Expand All @@ -235,12 +249,27 @@ func (m *TCPMuxDefault) RemoveConnByUfrag(ufrag string) {
m.mu.Lock()
defer m.mu.Unlock()

if conn, ok := m.conns[ufrag]; ok {
if conn, ok := m.connsIPv4[ufrag]; ok {
m.closeAndLogError(conn)
delete(m.connsIPv4, ufrag)
}

if conn, ok := m.connsIPv6[ufrag]; ok {
m.closeAndLogError(conn)
delete(m.conns, ufrag)
delete(m.connsIPv6, ufrag)
}
}

func (m *TCPMuxDefault) getConn(ufrag string, isIPv6 bool) (val *tcpPacketConn, ok bool) {
if isIPv6 {
val, ok = m.connsIPv6[ufrag]
} else {
val, ok = m.connsIPv4[ufrag]
}

return
}

const streamingPacketHeaderLen = 2

// readStreamingPacket reads 1 packet from stream
Expand Down
6 changes: 3 additions & 3 deletions tcp_mux_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ func TestTCPMux_Recv(t *testing.T) {
n, err := writeStreamingPacket(conn, msg.Raw)
require.NoError(t, err, "error writing tcp stun packet")

pktConn, err := tcpMux.GetConnByUfrag("myufrag")
pktConn, err := tcpMux.GetConnByUfrag("myufrag", false)
require.NoError(t, err, "error retrieving muxed connection for ufrag")
defer func() {
_ = pktConn.Close()
Expand Down Expand Up @@ -90,12 +90,12 @@ func TestTCPMux_NoDeadlockWhenClosingUnusedPacketConn(t *testing.T) {
ReadBufferSize: 20,
})

_, err = tcpMux.GetConnByUfrag("test")
_, err = tcpMux.GetConnByUfrag("test", false)
require.NoError(t, err, "error getting conn by ufrag")

require.NoError(t, tcpMux.Close(), "error closing tcpMux")

conn, err := tcpMux.GetConnByUfrag("test")
conn, err := tcpMux.GetConnByUfrag("test", false)
assert.Nil(t, conn, "should receive nil because mux is closed")
assert.Equal(t, io.ErrClosedPipe, err, "should receive error because mux is closed")
}

0 comments on commit a813428

Please sign in to comment.