Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make TCPMux IPv4/IPv6 aware #424

Merged
merged 1 commit into from
Feb 21, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions AUTHORS.txt
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ Luke Curley <[email protected]>
Meelap Shah <[email protected]>
Michael MacDonald <[email protected]>
Michael MacDonald <[email protected]>
Mikhail Bragin <[email protected]>
Nevio Vesic <[email protected]>
Ori Bernstein <[email protected]>
Robert Eperjesi <[email protected]>
Expand Down
2 changes: 1 addition & 1 deletion gather.go
Original file line number Diff line number Diff line change
Expand Up @@ -177,7 +177,7 @@ func (a *Agent) gatherCandidatesLocal(ctx context.Context, networkTypes []Networ
case tcp:
// Handle ICE TCP passive mode
a.log.Debugf("GetConn by ufrag: %s\n", a.localUfrag)
conn, err = a.tcpMux.GetConnByUfrag(a.localUfrag)
conn, err = a.tcpMux.GetConnByUfrag(a.localUfrag, mappedIP.To4() == nil)
if err != nil {
if !errors.Is(err, ErrTCPMuxNotInitialized) {
a.log.Warnf("error getting tcp conn by ufrag: %s %s %s\n", network, ip, a.localUfrag)
Expand Down
71 changes: 50 additions & 21 deletions tcp_mux.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ import (
// 2. allow mocking in tests.
type TCPMux interface {
io.Closer
GetConnByUfrag(ufrag string) (net.PacketConn, error)
GetConnByUfrag(ufrag string, isIPv6 bool) (net.PacketConn, error)
RemoveConnByUfrag(ufrag string)
}

Expand All @@ -36,7 +36,7 @@ func (m *invalidTCPMux) Close() error {
}

// GetConnByUfrag implements TCPMux interface.
func (m *invalidTCPMux) GetConnByUfrag(ufrag string) (net.PacketConn, error) {
func (m *invalidTCPMux) GetConnByUfrag(ufrag string, isIPv6 bool) (net.PacketConn, error) {
return nil, ErrTCPMuxNotInitialized
}

Expand All @@ -49,8 +49,8 @@ type TCPMuxDefault struct {
params *TCPMuxParams
closed bool

// conns is a map of all tcpPacketConns indexed by ufrag
conns map[string]*tcpPacketConn
// connsIPv4 and connsIPv6 are maps of all tcpPacketConns indexed by ufrag
connsIPv4, connsIPv6 map[string]*tcpPacketConn

mu sync.Mutex
wg sync.WaitGroup
Expand All @@ -72,7 +72,8 @@ func NewTCPMuxDefault(params TCPMuxParams) *TCPMuxDefault {
m := &TCPMuxDefault{
params: &params,

conns: map[string]*tcpPacketConn{},
connsIPv4: map[string]*tcpPacketConn{},
connsIPv6: map[string]*tcpPacketConn{},
}

m.wg.Add(1)
Expand Down Expand Up @@ -109,33 +110,33 @@ func (m *TCPMuxDefault) LocalAddr() net.Addr {
}

// GetConnByUfrag retrieves an existing or creates a new net.PacketConn.
func (m *TCPMuxDefault) GetConnByUfrag(ufrag string) (net.PacketConn, error) {
func (m *TCPMuxDefault) GetConnByUfrag(ufrag string, isIPv6 bool) (net.PacketConn, error) {
m.mu.Lock()
defer m.mu.Unlock()

if m.closed {
return nil, io.ErrClosedPipe
}

conn, ok := m.conns[ufrag]

if ok {
if conn, ok := m.getConn(ufrag, isIPv6); ok {
return conn, nil
// return nil, fmt.Errorf("duplicate ufrag %v", ufrag)
}

conn = m.createConn(ufrag, m.LocalAddr())

return conn, nil
return m.createConn(ufrag, m.LocalAddr(), isIPv6), nil
}

func (m *TCPMuxDefault) createConn(ufrag string, localAddr net.Addr) *tcpPacketConn {
func (m *TCPMuxDefault) createConn(ufrag string, localAddr net.Addr, isIPv6 bool) *tcpPacketConn {
conn := newTCPPacketConn(tcpPacketParams{
ReadBuffer: m.params.ReadBufferSize,
LocalAddr: localAddr,
Logger: m.params.Logger,
})
m.conns[ufrag] = conn

if isIPv6 {
m.connsIPv6[ufrag] = conn
} else {
m.connsIPv4[ufrag] = conn
}

m.wg.Add(1)
go func() {
Expand Down Expand Up @@ -199,9 +200,17 @@ func (m *TCPMuxDefault) handleConn(conn net.Conn) {
m.mu.Lock()
defer m.mu.Unlock()

packetConn, ok := m.conns[ufrag]
host, _, err := net.SplitHostPort(conn.RemoteAddr().String())
if err != nil {
m.closeAndLogError(conn)
m.params.Logger.Warnf("Failed to get host in STUN message from %s to %s\n", conn.RemoteAddr(), conn.LocalAddr())
return
}

isIPv6 := net.ParseIP(host).To4() == nil
packetConn, ok := m.getConn(ufrag, isIPv6)
if !ok {
packetConn = m.createConn(ufrag, conn.LocalAddr())
packetConn = m.createConn(ufrag, conn.LocalAddr(), isIPv6)
}

if err := packetConn.AddConn(conn, buf); err != nil {
Expand All @@ -216,10 +225,15 @@ func (m *TCPMuxDefault) Close() error {
m.mu.Lock()
m.closed = true

for _, conn := range m.conns {
for _, conn := range m.connsIPv4 {
m.closeAndLogError(conn)
}
m.conns = map[string]*tcpPacketConn{}
for _, conn := range m.connsIPv6 {
m.closeAndLogError(conn)
}

m.connsIPv4 = map[string]*tcpPacketConn{}
m.connsIPv6 = map[string]*tcpPacketConn{}

err := m.params.Listener.Close()

Expand All @@ -235,12 +249,27 @@ func (m *TCPMuxDefault) RemoveConnByUfrag(ufrag string) {
m.mu.Lock()
defer m.mu.Unlock()

if conn, ok := m.conns[ufrag]; ok {
if conn, ok := m.connsIPv4[ufrag]; ok {
m.closeAndLogError(conn)
delete(m.connsIPv4, ufrag)
}

if conn, ok := m.connsIPv6[ufrag]; ok {
m.closeAndLogError(conn)
delete(m.conns, ufrag)
delete(m.connsIPv6, ufrag)
}
}

func (m *TCPMuxDefault) getConn(ufrag string, isIPv6 bool) (val *tcpPacketConn, ok bool) {
if isIPv6 {
val, ok = m.connsIPv6[ufrag]
} else {
val, ok = m.connsIPv4[ufrag]
}

return
}

const streamingPacketHeaderLen = 2

// readStreamingPacket reads 1 packet from stream
Expand Down
6 changes: 3 additions & 3 deletions tcp_mux_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ func TestTCPMux_Recv(t *testing.T) {
n, err := writeStreamingPacket(conn, msg.Raw)
require.NoError(t, err, "error writing tcp stun packet")

pktConn, err := tcpMux.GetConnByUfrag("myufrag")
pktConn, err := tcpMux.GetConnByUfrag("myufrag", false)
require.NoError(t, err, "error retrieving muxed connection for ufrag")
defer func() {
_ = pktConn.Close()
Expand Down Expand Up @@ -90,12 +90,12 @@ func TestTCPMux_NoDeadlockWhenClosingUnusedPacketConn(t *testing.T) {
ReadBufferSize: 20,
})

_, err = tcpMux.GetConnByUfrag("test")
_, err = tcpMux.GetConnByUfrag("test", false)
require.NoError(t, err, "error getting conn by ufrag")

require.NoError(t, tcpMux.Close(), "error closing tcpMux")

conn, err := tcpMux.GetConnByUfrag("test")
conn, err := tcpMux.GetConnByUfrag("test", false)
assert.Nil(t, conn, "should receive nil because mux is closed")
assert.Equal(t, io.ErrClosedPipe, err, "should receive error because mux is closed")
}