diff --git a/main.go b/main.go index a3b71d0..36f1054 100644 --- a/main.go +++ b/main.go @@ -3,16 +3,17 @@ package main import ( "context" "fmt" - "github.com/cloverstd/tcping/ping" - "github.com/cloverstd/tcping/ping/http" - "github.com/cloverstd/tcping/ping/tcp" - "github.com/spf13/cobra" "net" "net/url" "os" "os/signal" "strconv" "syscall" + + "github.com/cloverstd/tcping/ping" + "github.com/cloverstd/tcping/ping/http" + "github.com/cloverstd/tcping/ping/tcp" + "github.com/spf13/cobra" ) var ( @@ -79,7 +80,8 @@ var rootCmd = cobra.Command{ cmd.Printf("%s is invalid port.\n", defaultPort) return } - url.Host = fmt.Sprintf("%s:%d", url.Hostname(), port) + url.Host = ping.GetUrlHost(url.Hostname(), port) + timeoutDuration, err := ping.ParseDuration(timeout) if err != nil { @@ -110,7 +112,11 @@ var rootCmd = cobra.Command{ PreferGo: true, Dial: func(ctx context.Context, network, address string) (conn net.Conn, err error) { for _, addr := range dnsServer { - if conn, err = net.Dial("udp", addr+":53"); err != nil { + ipAddr, err := ping.FormatIP(addr) + if err != nil { + ipAddr = addr + } + if conn, err = net.Dial("udp", ipAddr+":53"); err != nil { continue } else { return conn, nil diff --git a/ping/tcp/tcp.go b/ping/tcp/tcp.go index f74ce1b..9de5a68 100644 --- a/ping/tcp/tcp.go +++ b/ping/tcp/tcp.go @@ -61,17 +61,17 @@ func (p *Ping) Ping(ctx context.Context) *ping.Stats { tlsErr error ) if p.tls { - tlsConn, err = tls.DialWithDialer(p.dialer, "tcp", fmt.Sprintf("%s:%d", p.host, p.port), &tls.Config{ + tlsConn, err = tls.DialWithDialer(p.dialer, "tcp", ping.GetUrlHost(p.host, p.port), &tls.Config{ InsecureSkipVerify: true, }) if err == nil { conn = tlsConn.NetConn() } else { tlsErr = err - conn, err = p.dialer.DialContext(ctx, "tcp", fmt.Sprintf("%s:%d", p.host, p.port)) + conn, err = p.dialer.DialContext(ctx, "tcp", ping.GetUrlHost(p.host, p.port)) } } else { - conn, err = p.dialer.DialContext(ctx, "tcp", fmt.Sprintf("%s:%d", p.host, p.port)) + conn, err = p.dialer.DialContext(ctx, "tcp", ping.GetUrlHost(p.host, p.port)) } stats.Duration = time.Since(start) if err != nil { diff --git a/ping/utils.go b/ping/utils.go index 2c7a987..fc83485 100644 --- a/ping/utils.go +++ b/ping/utils.go @@ -30,6 +30,7 @@ func FormatIP(IP string) (string, error) { return "", fmt.Errorf("error IP format") } + // ParseDuration parse the t as time.Duration, it will parse t as mills when missing unit. func ParseDuration(t string) (time.Duration, error) { if timeout, err := strconv.ParseInt(t, 10, 64); err == nil { @@ -38,8 +39,26 @@ func ParseDuration(t string) (time.Duration, error) { return time.ParseDuration(t) } +func GetUrlHost(host string, port int) string { + if port <= 0 { + return host + } + if ipAddr, err := FormatIP(host); err == nil { + return fmt.Sprintf("%s:%d", ipAddr, port) + } + + return fmt.Sprintf("%s:%d", host, port) + +} + // ParseAddress will try to parse addr as url.URL. func ParseAddress(addr string) (*url.URL, error) { + if ipAddr, err := FormatIP(addr); err == nil { + return &url.URL{ + Scheme: "tcp", + Host: ipAddr, + }, nil + } if strings.Contains(addr, "://") { // it maybe with scheme, try url.Parse return url.Parse(addr)