diff --git a/client_integration_test.go b/client_integration_test.go index ab4bedc..6eb6071 100644 --- a/client_integration_test.go +++ b/client_integration_test.go @@ -1,10 +1,10 @@ package wgctrl_test import ( - "bytes" "errors" "fmt" "net" + "net/netip" "os" "sort" "strings" @@ -144,9 +144,9 @@ func testGet(t *testing.T, c *wgctrl.Client, d *wgtypes.Device) { func testConfigure(t *testing.T, c *wgctrl.Client, d *wgtypes.Device) { var ( port = 8888 - ips = []net.IPNet{ - wgtest.MustCIDR("192.0.2.0/32"), - wgtest.MustCIDR("2001:db8::/128"), + ips = []netip.Prefix{ + netip.MustParsePrefix("192.0.2.0/32"), + netip.MustParsePrefix("2001:db8::/128"), } priv = wgtest.MustPrivateKey() @@ -194,7 +194,7 @@ func testConfigure(t *testing.T, c *wgctrl.Client, d *wgtypes.Device) { for i := range dn.Peers { ips := dn.Peers[i].AllowedIPs sort.Slice(ips, func(i, j int) bool { - return bytes.Compare(ips[i].IP, ips[j].IP) > 0 + return ips[i].Addr().Compare(ips[j].Addr()) > 0 }) } @@ -229,17 +229,19 @@ func testConfigureManyIPs(t *testing.T, c *wgctrl.Client, d *wgtypes.Device) { t.Fatalf("failed to create cursor: %v", err) } - var ips []net.IPNet + var ips []netip.Prefix for pos := cur.Next(); pos != nil; pos = cur.Next() { bits := 128 if pos.IP.To4() != nil { bits = 32 } - ips = append(ips, net.IPNet{ - IP: pos.IP, - Mask: net.CIDRMask(bits, bits), - }) + addr, ok := netip.AddrFromSlice(pos.IP) + if !ok { + t.Fatalf("failed to convert net.IP to netip.Addr: %s", pos.IP) + } + + ips = append(ips, netip.PrefixFrom(addr, bits)) } peers = append(peers, wgtypes.PeerConfig{ @@ -291,7 +293,7 @@ func testConfigureManyPeers(t *testing.T, c *wgctrl.Client, d *wgtypes.Device) { PresharedKey: &pk, ReplaceAllowedIPs: true, Endpoint: &net.UDPAddr{ - IP: ips[0].IP, + IP: ips[0].Addr().AsSlice(), Port: 1111, }, PersistentKeepaliveInterval: &dur, @@ -370,7 +372,6 @@ func testConfigurePeersUpdateOnly(t *testing.T, c *wgctrl.Client, d *wgtypes.Dev t.Skip("FreeBSD kernel devices do not support UpdateOnly flag") } - t.Fatalf("failed to configure second time on %q: %v", d.Name, err) } @@ -428,7 +429,7 @@ func countPeerIPs(d *wgtypes.Device) int { return count } -func ipsString(ipns []net.IPNet) string { +func ipsString(ipns []netip.Prefix) string { ss := make([]string, 0, len(ipns)) for _, ipn := range ipns { ss = append(ss, ipn.String()) @@ -437,23 +438,25 @@ func ipsString(ipns []net.IPNet) string { return strings.Join(ss, ", ") } -func generateIPs(n int) []net.IPNet { +func generateIPs(n int) []netip.Prefix { cur, err := ipaddr.Parse("2001:db8::/64") if err != nil { panicf("failed to create cursor: %v", err) } - ips := make([]net.IPNet, 0, n) + ips := make([]netip.Prefix, 0, n) for i := 0; i < n; i++ { pos := cur.Next() if pos == nil { panic("hit nil IP during IP generation") } - ips = append(ips, net.IPNet{ - IP: pos.IP, - Mask: net.CIDRMask(128, 128), - }) + addr, ok := netip.AddrFromSlice(pos.IP) + if !ok { + panicf("failed to convert net.IP to netip.Addr: %s", pos.IP) + } + + ips = append(ips, netip.PrefixFrom(addr, 128)) } return ips diff --git a/cmd/wgctrl/main.go b/cmd/wgctrl/main.go index 9fbba3a..5830fa6 100644 --- a/cmd/wgctrl/main.go +++ b/cmd/wgctrl/main.go @@ -6,7 +6,7 @@ import ( "flag" "fmt" "log" - "net" + "net/netip" "strings" "golang.zx2c4.com/wireguard/wgctrl" @@ -83,7 +83,7 @@ func printPeer(p wgtypes.Peer) { ) } -func ipsString(ipns []net.IPNet) string { +func ipsString(ipns []netip.Prefix) string { ss := make([]string, 0, len(ipns)) for _, ipn := range ipns { ss = append(ss, ipn.String()) diff --git a/internal/wglinux/client_linux_test.go b/internal/wglinux/client_linux_test.go index 5f6c054..4aedfea 100644 --- a/internal/wglinux/client_linux_test.go +++ b/internal/wglinux/client_linux_test.go @@ -6,7 +6,7 @@ package wglinux import ( "errors" "fmt" - "net" + "net/netip" "os" "os/user" "syscall" @@ -325,7 +325,7 @@ func diffAttrs(x, y []netlink.Attribute) string { return cmp.Diff(xPrime, yPrime) } -func mustAllowedIPs(ipns []net.IPNet) []byte { +func mustAllowedIPs(ipns []netip.Prefix) []byte { ae := netlink.NewAttributeEncoder() if err := encodeAllowedIPs(ipns)(ae); err != nil { panicf("failed to create allowed IP attributes: %v", err) diff --git a/internal/wglinux/configure_linux.go b/internal/wglinux/configure_linux.go index bf29092..f8f281c 100644 --- a/internal/wglinux/configure_linux.go +++ b/internal/wglinux/configure_linux.go @@ -7,6 +7,7 @@ import ( "encoding/binary" "fmt" "net" + "net/netip" "unsafe" "github.com/mdlayher/netlink" @@ -101,16 +102,16 @@ func buildBatches(cfg wgtypes.Config) []wgtypes.Config { // Iterate until no more allowed IPs. var done bool for !done { - var tmp []net.IPNet + var tmp []netip.Prefix if len(p.AllowedIPs) < ipBatchChunk { // IPs all fit within a batch; we are done. - tmp = make([]net.IPNet, len(p.AllowedIPs)) + tmp = make([]netip.Prefix, len(p.AllowedIPs)) copy(tmp, p.AllowedIPs) done = true } else { // IPs are larger than a single batch, copy a batch out and // advance the cursor. - tmp = make([]net.IPNet, ipBatchChunk) + tmp = make([]netip.Prefix, ipBatchChunk) copy(tmp, p.AllowedIPs[:ipBatchChunk]) p.AllowedIPs = p.AllowedIPs[ipBatchChunk:] @@ -214,32 +215,26 @@ func encodePeer(p wgtypes.PeerConfig) func(ae *netlink.AttributeEncoder) error { // sockaddr_in or sockaddr_in6 bytes. func encodeSockaddr(endpoint net.UDPAddr) func() ([]byte, error) { return func() ([]byte, error) { - if !isValidIP(endpoint.IP) { + addrPort := endpoint.AddrPort() + if !addrPort.Addr().IsValid() { return nil, fmt.Errorf("wglinux: invalid endpoint IP: %s", endpoint.IP.String()) } // Is this an IPv6 address? - if isIPv6(endpoint.IP) { - var addr [16]byte - copy(addr[:], endpoint.IP.To16()) - + if addrPort.Addr().Is6() { sa := unix.RawSockaddrInet6{ Family: unix.AF_INET6, Port: sockaddrPort(endpoint.Port), - Addr: addr, + Addr: addrPort.Addr().As16(), } return (*(*[unix.SizeofSockaddrInet6]byte)(unsafe.Pointer(&sa)))[:], nil } - // IPv4 address handling. - var addr [4]byte - copy(addr[:], endpoint.IP.To4()) - sa := unix.RawSockaddrInet4{ Family: unix.AF_INET, Port: sockaddrPort(endpoint.Port), - Addr: addr, + Addr: addrPort.Addr().As4(), } return (*(*[unix.SizeofSockaddrInet4]byte)(unsafe.Pointer(&sa)))[:], nil @@ -247,26 +242,25 @@ func encodeSockaddr(endpoint net.UDPAddr) func() ([]byte, error) { } // encodeAllowedIPs returns a function to encode allowed IP nested attributes. -func encodeAllowedIPs(ipns []net.IPNet) func(ae *netlink.AttributeEncoder) error { +func encodeAllowedIPs(ipns []netip.Prefix) func(ae *netlink.AttributeEncoder) error { return func(ae *netlink.AttributeEncoder) error { for i, ipn := range ipns { - if !isValidIP(ipn.IP) { - return fmt.Errorf("wglinux: invalid allowed IP: %s", ipn.IP.String()) + if !ipn.Addr().IsValid() { + return fmt.Errorf("wglinux: invalid allowed IP: %s", ipn.Addr()) } family := uint16(unix.AF_INET6) - if !isIPv6(ipn.IP) { + if ipn.Addr().Is4() { // Make sure address is 4 bytes if IPv4. family = unix.AF_INET - ipn.IP = ipn.IP.To4() } // Netlink arrays use type as an array index. ae.Nested(uint16(i), func(nae *netlink.AttributeEncoder) error { nae.Uint16(unix.WGALLOWEDIP_A_FAMILY, family) - nae.Bytes(unix.WGALLOWEDIP_A_IPADDR, ipn.IP) + nae.Bytes(unix.WGALLOWEDIP_A_IPADDR, ipn.Addr().AsSlice()) - ones, _ := ipn.Mask.Size() + ones := ipn.Bits() nae.Uint8(unix.WGALLOWEDIP_A_CIDR_MASK, uint8(ones)) return nil }) @@ -276,16 +270,6 @@ func encodeAllowedIPs(ipns []net.IPNet) func(ae *netlink.AttributeEncoder) error } } -// isValidIP determines if IP is a valid IPv4 or IPv6 address. -func isValidIP(ip net.IP) bool { - return ip.To16() != nil -} - -// isIPv6 determines if IP is a valid IPv6 address. -func isIPv6(ip net.IP) bool { - return isValidIP(ip) && ip.To4() == nil -} - // sockaddrPort interprets port as a big endian uint16 for use passing sockaddr // structures to the kernel. func sockaddrPort(port int) uint16 { diff --git a/internal/wglinux/configure_linux_test.go b/internal/wglinux/configure_linux_test.go index 858ad30..a3d5f67 100644 --- a/internal/wglinux/configure_linux_test.go +++ b/internal/wglinux/configure_linux_test.go @@ -5,6 +5,7 @@ package wglinux import ( "net" + "net/netip" "testing" "time" "unsafe" @@ -45,9 +46,9 @@ func TestLinuxClientConfigureDevice(t *testing.T) { name: "bad peer allowed IP", cfg: wgtypes.Config{ Peers: []wgtypes.PeerConfig{{ - AllowedIPs: []net.IPNet{{ - IP: net.IP{0xff}, - }}, + AllowedIPs: []netip.Prefix{ + {}, + }, }}, }, }, @@ -71,8 +72,8 @@ func TestLinuxClientConfigureDevice(t *testing.T) { PresharedKey: keyPtr(wgtest.MustHexKey("188515093e952f5f22e865cef3012e72f8b5f0b598ac0309d5dacce3b70fcf52")), Endpoint: wgtest.MustUDPAddr("[abcd:23::33%2]:51820"), ReplaceAllowedIPs: true, - AllowedIPs: []net.IPNet{ - wgtest.MustCIDR("192.168.4.4/32"), + AllowedIPs: []netip.Prefix{ + netip.MustParsePrefix("192.168.4.4/32"), }, }, { @@ -81,17 +82,17 @@ func TestLinuxClientConfigureDevice(t *testing.T) { Endpoint: wgtest.MustUDPAddr("182.122.22.19:3233"), PersistentKeepaliveInterval: durPtr(111 * time.Second), ReplaceAllowedIPs: true, - AllowedIPs: []net.IPNet{ - wgtest.MustCIDR("192.168.4.6/32"), + AllowedIPs: []netip.Prefix{ + netip.MustParsePrefix("192.168.4.6/32"), }, }, { PublicKey: wgtest.MustHexKey("662e14fd594556f522604703340351258903b64f35553763f19426ab2a515c58"), Endpoint: wgtest.MustUDPAddr("5.152.198.39:51820"), ReplaceAllowedIPs: true, - AllowedIPs: []net.IPNet{ - wgtest.MustCIDR("192.168.4.10/32"), - wgtest.MustCIDR("192.168.4.11/32"), + AllowedIPs: []netip.Prefix{ + netip.MustParsePrefix("192.168.4.10/32"), + netip.MustParsePrefix("192.168.4.11/32"), }, }, { @@ -151,8 +152,8 @@ func TestLinuxClientConfigureDevice(t *testing.T) { }, { Type: netlink.Nested | unix.WGPEER_A_ALLOWEDIPS, - Data: mustAllowedIPs([]net.IPNet{ - wgtest.MustCIDR("192.168.4.4/32"), + Data: mustAllowedIPs([]netip.Prefix{ + netip.MustParsePrefix("192.168.4.4/32"), }), }, }...), @@ -182,8 +183,8 @@ func TestLinuxClientConfigureDevice(t *testing.T) { }, { Type: netlink.Nested | unix.WGPEER_A_ALLOWEDIPS, - Data: mustAllowedIPs([]net.IPNet{ - wgtest.MustCIDR("192.168.4.6/32"), + Data: mustAllowedIPs([]netip.Prefix{ + netip.MustParsePrefix("192.168.4.6/32"), }), }, }...), @@ -209,9 +210,9 @@ func TestLinuxClientConfigureDevice(t *testing.T) { }, { Type: netlink.Nested | unix.WGPEER_A_ALLOWEDIPS, - Data: mustAllowedIPs([]net.IPNet{ - wgtest.MustCIDR("192.168.4.10/32"), - wgtest.MustCIDR("192.168.4.11/32"), + Data: mustAllowedIPs([]netip.Prefix{ + netip.MustParsePrefix("192.168.4.10/32"), + netip.MustParsePrefix("192.168.4.11/32"), }), }, }...), @@ -513,23 +514,25 @@ func keyBytes(s string) []byte { return k[:] } -func generateIPs(n int) []net.IPNet { +func generateIPs(n int) []netip.Prefix { cur, err := ipaddr.Parse("2001:db8::/64") if err != nil { panicf("failed to create cursor: %v", err) } - ips := make([]net.IPNet, 0, n) + ips := make([]netip.Prefix, 0, n) for i := 0; i < n; i++ { pos := cur.Next() if pos == nil { panic("hit nil IP during IP generation") } - ips = append(ips, net.IPNet{ - IP: pos.IP, - Mask: net.CIDRMask(128, 128), - }) + addr, ok := netip.AddrFromSlice(pos.IP) + if !ok { + panicf("failed to convert net.IP to netip.Addr: %s", pos.IP) + } + + ips = append(ips, netip.PrefixFrom(addr, 128)) } return ips diff --git a/internal/wglinux/parse_linux.go b/internal/wglinux/parse_linux.go index 9630ffc..d8893bc 100644 --- a/internal/wglinux/parse_linux.go +++ b/internal/wglinux/parse_linux.go @@ -6,6 +6,7 @@ package wglinux import ( "fmt" "net" + "net/netip" "time" "unsafe" @@ -130,24 +131,26 @@ func parsePeer(ad *netlink.AttributeDecoder) wgtypes.Peer { } // parseAllowedIPs parses a slice of net.IPNet from a netlink attribute payload. -func parseAllowedIPs(ipns *[]net.IPNet) func(ad *netlink.AttributeDecoder) error { +func parseAllowedIPs(ipns *[]netip.Prefix) func(ad *netlink.AttributeDecoder) error { return func(ad *netlink.AttributeDecoder) error { // Initialize to the number of allowed IPs and begin iterating through // the netlink array to decode each one. - *ipns = make([]net.IPNet, 0, ad.Len()) + *ipns = make([]netip.Prefix, 0, ad.Len()) for ad.Next() { // Allowed IP nested attributes. ad.Nested(func(nad *netlink.AttributeDecoder) error { var ( - ipn net.IPNet - mask int + ipn netip.Addr + mask int + // TODO: we already have the family stored in ipn, is this needed? family int + _ = family ) for nad.Next() { switch nad.Type() { case unix.WGALLOWEDIP_A_IPADDR: - nad.Do(parseAddr(&ipn.IP)) + nad.Do(parseAddr(&ipn)) case unix.WGALLOWEDIP_A_CIDR_MASK: mask = int(nad.Uint8()) case unix.WGALLOWEDIP_A_FAMILY: @@ -159,16 +162,9 @@ func parseAllowedIPs(ipns *[]net.IPNet) func(ad *netlink.AttributeDecoder) error return err } - // The address family determines the correct number of bits in - // the mask. - switch family { - case unix.AF_INET: - ipn.Mask = net.CIDRMask(mask, 32) - case unix.AF_INET6: - ipn.Mask = net.CIDRMask(mask, 128) - } + ipp := netip.PrefixFrom(ipn, mask) - *ipns = append(*ipns, ipn) + *ipns = append(*ipns, ipp) return nil }) } @@ -191,17 +187,14 @@ func parseKey(key *wgtypes.Key) func(b []byte) error { } // parseAddr parses a net.IP from raw in_addr or in6_addr struct bytes. -func parseAddr(ip *net.IP) func(b []byte) error { +func parseAddr(ip *netip.Addr) func(b []byte) error { return func(b []byte) error { - switch len(b) { - case net.IPv4len, net.IPv6len: - // Okay to convert directly to net.IP; memory layout is identical. - *ip = make(net.IP, len(b)) - copy(*ip, b) - return nil - default: + parsedIP, ok := netip.AddrFromSlice(b) + if !ok { return fmt.Errorf("wglinux: unexpected IP address size: %d", len(b)) } + *ip = parsedIP + return nil } } diff --git a/internal/wglinux/parse_linux_test.go b/internal/wglinux/parse_linux_test.go index fcb4eeb..67ba8d9 100644 --- a/internal/wglinux/parse_linux_test.go +++ b/internal/wglinux/parse_linux_test.go @@ -5,6 +5,7 @@ package wglinux import ( "net" + "net/netip" "runtime" "testing" "time" @@ -229,9 +230,9 @@ func TestLinuxClientDevicesOK(t *testing.T) { }, { Type: unix.WGPEER_A_ALLOWEDIPS, - Data: mustAllowedIPs([]net.IPNet{ - wgtest.MustCIDR("192.168.1.10/32"), - wgtest.MustCIDR("fd00::1/128"), + Data: mustAllowedIPs([]netip.Prefix{ + netip.MustParsePrefix("192.168.1.10/32"), + netip.MustParsePrefix("fd00::1/128"), }), }, { @@ -286,9 +287,9 @@ func TestLinuxClientDevicesOK(t *testing.T) { LastHandshakeTime: time.Unix(10, 20), ReceiveBytes: 100, TransmitBytes: 200, - AllowedIPs: []net.IPNet{ - wgtest.MustCIDR("192.168.1.10/32"), - wgtest.MustCIDR("fd00::1/128"), + AllowedIPs: []netip.Prefix{ + netip.MustParsePrefix("192.168.1.10/32"), + netip.MustParsePrefix("fd00::1/128"), }, ProtocolVersion: 1, }, @@ -328,9 +329,9 @@ func TestLinuxClientDevicesOK(t *testing.T) { }, { Type: unix.WGPEER_A_ALLOWEDIPS, - Data: mustAllowedIPs([]net.IPNet{ - wgtest.MustCIDR("192.168.1.10/32"), - wgtest.MustCIDR("192.168.1.11/32"), + Data: mustAllowedIPs([]netip.Prefix{ + netip.MustParsePrefix("192.168.1.10/32"), + netip.MustParsePrefix("192.168.1.11/32"), }), }, }...), @@ -352,9 +353,9 @@ func TestLinuxClientDevicesOK(t *testing.T) { }, { Type: unix.WGPEER_A_ALLOWEDIPS, - Data: mustAllowedIPs([]net.IPNet{ - wgtest.MustCIDR("fd00:dead:beef:dead::/64"), - wgtest.MustCIDR("fd00:dead:beef:ffff::/64"), + Data: mustAllowedIPs([]netip.Prefix{ + netip.MustParsePrefix("fd00:dead:beef:dead::/64"), + netip.MustParsePrefix("fd00:dead:beef:ffff::/64"), }), }, }...), @@ -368,9 +369,9 @@ func TestLinuxClientDevicesOK(t *testing.T) { }, { Type: unix.WGPEER_A_ALLOWEDIPS, - Data: mustAllowedIPs([]net.IPNet{ - wgtest.MustCIDR("10.10.10.0/24"), - wgtest.MustCIDR("10.10.11.0/24"), + Data: mustAllowedIPs([]netip.Prefix{ + netip.MustParsePrefix("10.10.10.0/24"), + netip.MustParsePrefix("10.10.11.0/24"), }), }, }...), @@ -392,9 +393,9 @@ func TestLinuxClientDevicesOK(t *testing.T) { }, { Type: unix.WGPEER_A_ALLOWEDIPS, - Data: mustAllowedIPs([]net.IPNet{ - wgtest.MustCIDR("10.10.12.0/24"), - wgtest.MustCIDR("10.10.13.0/24"), + Data: mustAllowedIPs([]netip.Prefix{ + netip.MustParsePrefix("10.10.12.0/24"), + netip.MustParsePrefix("10.10.13.0/24"), }), }, }...), @@ -408,9 +409,9 @@ func TestLinuxClientDevicesOK(t *testing.T) { }, { Type: unix.WGPEER_A_ALLOWEDIPS, - Data: mustAllowedIPs([]net.IPNet{ - wgtest.MustCIDR("fd00:1234::/32"), - wgtest.MustCIDR("fd00:4567::/32"), + Data: mustAllowedIPs([]netip.Prefix{ + netip.MustParsePrefix("fd00:1234::/32"), + netip.MustParsePrefix("fd00:4567::/32"), }), }, }...), @@ -427,27 +428,27 @@ func TestLinuxClientDevicesOK(t *testing.T) { Peers: []wgtypes.Peer{ { PublicKey: keyA, - AllowedIPs: []net.IPNet{ - wgtest.MustCIDR("192.168.1.10/32"), - wgtest.MustCIDR("192.168.1.11/32"), - wgtest.MustCIDR("fd00:dead:beef:dead::/64"), - wgtest.MustCIDR("fd00:dead:beef:ffff::/64"), + AllowedIPs: []netip.Prefix{ + netip.MustParsePrefix("192.168.1.10/32"), + netip.MustParsePrefix("192.168.1.11/32"), + netip.MustParsePrefix("fd00:dead:beef:dead::/64"), + netip.MustParsePrefix("fd00:dead:beef:ffff::/64"), }, }, { PublicKey: keyB, - AllowedIPs: []net.IPNet{ - wgtest.MustCIDR("10.10.10.0/24"), - wgtest.MustCIDR("10.10.11.0/24"), - wgtest.MustCIDR("10.10.12.0/24"), - wgtest.MustCIDR("10.10.13.0/24"), + AllowedIPs: []netip.Prefix{ + netip.MustParsePrefix("10.10.10.0/24"), + netip.MustParsePrefix("10.10.11.0/24"), + netip.MustParsePrefix("10.10.12.0/24"), + netip.MustParsePrefix("10.10.13.0/24"), }, }, { PublicKey: keyC, - AllowedIPs: []net.IPNet{ - wgtest.MustCIDR("fd00:1234::/32"), - wgtest.MustCIDR("fd00:4567::/32"), + AllowedIPs: []netip.Prefix{ + netip.MustParsePrefix("fd00:1234::/32"), + netip.MustParsePrefix("fd00:4567::/32"), }, }, }, diff --git a/internal/wgtest/wgtest.go b/internal/wgtest/wgtest.go index c288e7f..62f2b68 100644 --- a/internal/wgtest/wgtest.go +++ b/internal/wgtest/wgtest.go @@ -3,21 +3,10 @@ package wgtest import ( "encoding/hex" "fmt" - "net" - "golang.zx2c4.com/wireguard/wgctrl/wgtypes" + "net" ) -// MustCIDR converts CIDR string s into a net.IPNet or panics. -func MustCIDR(s string) net.IPNet { - _, cidr, err := net.ParseCIDR(s) - if err != nil { - panicf("wgtest: failed to parse CIDR: %v", err) - } - - return *cidr -} - // MustHexKey decodes a hex string s as a key or panics. func MustHexKey(s string) wgtypes.Key { b, err := hex.DecodeString(s) diff --git a/internal/wguser/configure_test.go b/internal/wguser/configure_test.go index 7058c8b..a5dbf38 100644 --- a/internal/wguser/configure_test.go +++ b/internal/wguser/configure_test.go @@ -2,7 +2,7 @@ package wguser import ( "errors" - "net" + "net/netip" "os" "testing" "time" @@ -109,8 +109,8 @@ func TestClientConfigureDeviceOK(t *testing.T) { PresharedKey: keyPtr(wgtest.MustHexKey("188515093e952f5f22e865cef3012e72f8b5f0b598ac0309d5dacce3b70fcf52")), Endpoint: wgtest.MustUDPAddr("[abcd:23::33%2]:51820"), ReplaceAllowedIPs: true, - AllowedIPs: []net.IPNet{ - wgtest.MustCIDR("192.168.4.4/32"), + AllowedIPs: []netip.Prefix{ + netip.MustParsePrefix("192.168.4.4/32"), }, }, { @@ -119,17 +119,17 @@ func TestClientConfigureDeviceOK(t *testing.T) { Endpoint: wgtest.MustUDPAddr("182.122.22.19:3233"), PersistentKeepaliveInterval: durPtr(111 * time.Second), ReplaceAllowedIPs: true, - AllowedIPs: []net.IPNet{ - wgtest.MustCIDR("192.168.4.6/32"), + AllowedIPs: []netip.Prefix{ + netip.MustParsePrefix("192.168.4.6/32"), }, }, { PublicKey: wgtest.MustHexKey("662e14fd594556f522604703340351258903b64f35553763f19426ab2a515c58"), Endpoint: wgtest.MustUDPAddr("5.152.198.39:51820"), ReplaceAllowedIPs: true, - AllowedIPs: []net.IPNet{ - wgtest.MustCIDR("192.168.4.10/32"), - wgtest.MustCIDR("192.168.4.11/32"), + AllowedIPs: []netip.Prefix{ + netip.MustParsePrefix("192.168.4.10/32"), + netip.MustParsePrefix("192.168.4.11/32"), }, }, { diff --git a/internal/wguser/parse.go b/internal/wguser/parse.go index dc996b2..e68d3e2 100644 --- a/internal/wguser/parse.go +++ b/internal/wguser/parse.go @@ -7,6 +7,7 @@ import ( "fmt" "io" "net" + "net/netip" "os" "strconv" "time" @@ -243,16 +244,16 @@ func (dp *deviceParser) parseAddr(s string) *net.UDPAddr { } // parseInt parses an address CIDR from a string. -func (dp *deviceParser) parseCIDR(s string) *net.IPNet { +func (dp *deviceParser) parseCIDR(s string) *netip.Prefix { if dp.err != nil { return nil } - _, cidr, err := net.ParseCIDR(s) + prefix, err := netip.ParsePrefix(s) if err != nil { dp.err = err return nil } - return cidr + return &prefix } diff --git a/internal/wguser/parse_test.go b/internal/wguser/parse_test.go index 79e6b97..563a77a 100644 --- a/internal/wguser/parse_test.go +++ b/internal/wguser/parse_test.go @@ -2,6 +2,7 @@ package wguser import ( "net" + "net/netip" "testing" "time" @@ -99,11 +100,8 @@ func TestClientDevices(t *testing.T) { Zone: "2", }, LastHandshakeTime: time.Unix(1, 2), - AllowedIPs: []net.IPNet{ - { - IP: net.IP{0xc0, 0xa8, 0x4, 0x4}, - Mask: net.IPMask{0xff, 0xff, 0xff, 0xff}, - }, + AllowedIPs: []netip.Prefix{ + netip.PrefixFrom(netip.AddrFrom4([4]byte{0xc0, 0xa8, 0x4, 0x4}), 32), }, }, { @@ -119,11 +117,8 @@ func TestClientDevices(t *testing.T) { PersistentKeepaliveInterval: 111000000000, ReceiveBytes: 2224, TransmitBytes: 38333, - AllowedIPs: []net.IPNet{ - { - IP: net.IP{0xc0, 0xa8, 0x4, 0x6}, - Mask: net.IPMask{0xff, 0xff, 0xff, 0xff}, - }, + AllowedIPs: []netip.Prefix{ + netip.PrefixFrom(netip.AddrFrom4([4]byte{0xc0, 0xa8, 0x4, 0x6}), 32), }, }, { @@ -134,15 +129,9 @@ func TestClientDevices(t *testing.T) { }, ReceiveBytes: 1929999999, TransmitBytes: 1212111, - AllowedIPs: []net.IPNet{ - { - IP: net.IP{0xc0, 0xa8, 0x4, 0xa}, - Mask: net.IPMask{0xff, 0xff, 0xff, 0xff}, - }, - { - IP: net.IP{0xc0, 0xa8, 0x4, 0xb}, - Mask: net.IPMask{0xff, 0xff, 0xff, 0xff}, - }, + AllowedIPs: []netip.Prefix{ + netip.PrefixFrom(netip.AddrFrom4([4]byte{0xc0, 0xa8, 0x4, 0xa}), 32), + netip.PrefixFrom(netip.AddrFrom4([4]byte{0xc0, 0xa8, 0x4, 0xb}), 32), }, ProtocolVersion: 1, }, diff --git a/wgtypes/types.go b/wgtypes/types.go index 3b33b54..fe35e81 100644 --- a/wgtypes/types.go +++ b/wgtypes/types.go @@ -5,6 +5,7 @@ import ( "encoding/base64" "fmt" "net" + "net/netip" "time" "golang.org/x/crypto/curve25519" @@ -195,7 +196,7 @@ type Peer struct { // // 0.0.0.0/0 indicates that all IPv4 addresses are allowed, and ::/0 // indicates that all IPv6 addresses are allowed. - AllowedIPs []net.IPNet + AllowedIPs []netip.Prefix // ProtocolVersion specifies which version of the WireGuard protocol is used // for this Peer. @@ -272,5 +273,5 @@ type PeerConfig struct { // AllowedIPs specifies a list of allowed IP addresses in CIDR notation // for this peer. - AllowedIPs []net.IPNet + AllowedIPs []netip.Prefix }