diff --git a/gather.go b/gather.go index 8d2c40db..a1dc437b 100644 --- a/gather.go +++ b/gather.go @@ -177,7 +177,7 @@ func (a *Agent) gatherCandidatesLocal(ctx context.Context, networkTypes []Networ case tcp: // Handle ICE TCP passive mode a.log.Debugf("GetConn by ufrag: %s\n", a.localUfrag) - conn, err = a.tcpMux.GetConnByUfrag(a.localUfrag) + conn, err = a.tcpMux.GetConnByUfrag(a.localUfrag, mappedIP.To4() == nil) if err != nil { if !errors.Is(err, ErrTCPMuxNotInitialized) { a.log.Warnf("error getting tcp conn by ufrag: %s %s %s\n", network, ip, a.localUfrag) diff --git a/tcp_mux.go b/tcp_mux.go index e5da608c..86cf530e 100644 --- a/tcp_mux.go +++ b/tcp_mux.go @@ -19,7 +19,7 @@ import ( // 2. allow mocking in tests. type TCPMux interface { io.Closer - GetConnByUfrag(ufrag string) (net.PacketConn, error) + GetConnByUfrag(ufrag string, isIPv6 bool) (net.PacketConn, error) RemoveConnByUfrag(ufrag string) } @@ -36,7 +36,7 @@ func (m *invalidTCPMux) Close() error { } // GetConnByUfrag implements TCPMux interface. -func (m *invalidTCPMux) GetConnByUfrag(ufrag string) (net.PacketConn, error) { +func (m *invalidTCPMux) GetConnByUfrag(ufrag string, isIPv6 bool) (net.PacketConn, error) { return nil, ErrTCPMuxNotInitialized } @@ -49,8 +49,8 @@ type TCPMuxDefault struct { params *TCPMuxParams closed bool - // conns is a map of all tcpPacketConns indexed by ufrag - conns map[string]*tcpPacketConn + // connsIPv4 and connsIPv6 are maps of all tcpPacketConns indexed by ufrag + connsIPv4, connsIPv6 map[string]*tcpPacketConn mu sync.Mutex wg sync.WaitGroup @@ -72,7 +72,8 @@ func NewTCPMuxDefault(params TCPMuxParams) *TCPMuxDefault { m := &TCPMuxDefault{ params: ¶ms, - conns: map[string]*tcpPacketConn{}, + connsIPv4: map[string]*tcpPacketConn{}, + connsIPv6: map[string]*tcpPacketConn{}, } m.wg.Add(1) @@ -109,7 +110,7 @@ func (m *TCPMuxDefault) LocalAddr() net.Addr { } // GetConnByUfrag retrieves an existing or creates a new net.PacketConn. -func (m *TCPMuxDefault) GetConnByUfrag(ufrag string) (net.PacketConn, error) { +func (m *TCPMuxDefault) GetConnByUfrag(ufrag string, isIPv6 bool) (net.PacketConn, error) { m.mu.Lock() defer m.mu.Unlock() @@ -117,25 +118,25 @@ func (m *TCPMuxDefault) GetConnByUfrag(ufrag string) (net.PacketConn, error) { return nil, io.ErrClosedPipe } - conn, ok := m.conns[ufrag] - - if ok { + if conn, ok := m.getConn(ufrag, isIPv6); ok { return conn, nil - // return nil, fmt.Errorf("duplicate ufrag %v", ufrag) } - conn = m.createConn(ufrag, m.LocalAddr()) - - return conn, nil + return m.createConn(ufrag, m.LocalAddr(), isIPv6), nil } -func (m *TCPMuxDefault) createConn(ufrag string, localAddr net.Addr) *tcpPacketConn { +func (m *TCPMuxDefault) createConn(ufrag string, localAddr net.Addr, isIPv6 bool) *tcpPacketConn { conn := newTCPPacketConn(tcpPacketParams{ ReadBuffer: m.params.ReadBufferSize, LocalAddr: localAddr, Logger: m.params.Logger, }) - m.conns[ufrag] = conn + + if isIPv6 { + m.connsIPv6[ufrag] = conn + } else { + m.connsIPv4[ufrag] = conn + } m.wg.Add(1) go func() { @@ -199,9 +200,17 @@ func (m *TCPMuxDefault) handleConn(conn net.Conn) { m.mu.Lock() defer m.mu.Unlock() - packetConn, ok := m.conns[ufrag] + host, _, err := net.SplitHostPort(conn.RemoteAddr().String()) + if err != nil { + m.closeAndLogError(conn) + m.params.Logger.Warnf("Failed to get host in STUN message from %s to %s\n", conn.RemoteAddr(), conn.LocalAddr()) + return + } + + isIPv6 := net.ParseIP(host).To4() == nil + packetConn, ok := m.getConn(ufrag, isIPv6) if !ok { - packetConn = m.createConn(ufrag, conn.LocalAddr()) + packetConn = m.createConn(ufrag, conn.LocalAddr(), isIPv6) } if err := packetConn.AddConn(conn, buf); err != nil { @@ -216,10 +225,15 @@ func (m *TCPMuxDefault) Close() error { m.mu.Lock() m.closed = true - for _, conn := range m.conns { + for _, conn := range m.connsIPv4 { m.closeAndLogError(conn) } - m.conns = map[string]*tcpPacketConn{} + for _, conn := range m.connsIPv6 { + m.closeAndLogError(conn) + } + + m.connsIPv4 = map[string]*tcpPacketConn{} + m.connsIPv6 = map[string]*tcpPacketConn{} err := m.params.Listener.Close() @@ -235,12 +249,27 @@ func (m *TCPMuxDefault) RemoveConnByUfrag(ufrag string) { m.mu.Lock() defer m.mu.Unlock() - if conn, ok := m.conns[ufrag]; ok { + if conn, ok := m.connsIPv4[ufrag]; ok { + m.closeAndLogError(conn) + delete(m.connsIPv4, ufrag) + } + + if conn, ok := m.connsIPv6[ufrag]; ok { m.closeAndLogError(conn) - delete(m.conns, ufrag) + delete(m.connsIPv6, ufrag) } } +func (m *TCPMuxDefault) getConn(ufrag string, isIPv6 bool) (val *tcpPacketConn, ok bool) { + if isIPv6 { + val, ok = m.connsIPv6[ufrag] + } else { + val, ok = m.connsIPv4[ufrag] + } + + return +} + const streamingPacketHeaderLen = 2 // readStreamingPacket reads 1 packet from stream diff --git a/tcp_mux_test.go b/tcp_mux_test.go index 8340b943..56fee6ee 100644 --- a/tcp_mux_test.go +++ b/tcp_mux_test.go @@ -55,7 +55,7 @@ func TestTCPMux_Recv(t *testing.T) { n, err := writeStreamingPacket(conn, msg.Raw) require.NoError(t, err, "error writing tcp stun packet") - pktConn, err := tcpMux.GetConnByUfrag("myufrag") + pktConn, err := tcpMux.GetConnByUfrag("myufrag", false) require.NoError(t, err, "error retrieving muxed connection for ufrag") defer func() { _ = pktConn.Close() @@ -90,12 +90,12 @@ func TestTCPMux_NoDeadlockWhenClosingUnusedPacketConn(t *testing.T) { ReadBufferSize: 20, }) - _, err = tcpMux.GetConnByUfrag("test") + _, err = tcpMux.GetConnByUfrag("test", false) require.NoError(t, err, "error getting conn by ufrag") require.NoError(t, tcpMux.Close(), "error closing tcpMux") - conn, err := tcpMux.GetConnByUfrag("test") + conn, err := tcpMux.GetConnByUfrag("test", false) assert.Nil(t, conn, "should receive nil because mux is closed") assert.Equal(t, io.ErrClosedPipe, err, "should receive error because mux is closed") }