Skip to content

Commit

Permalink
Add internal/net test (#615)
Browse files Browse the repository at this point in the history
* refactor and add net test

* fix

* revert changes

* add comments

* fix type

* fix

* fix deepsource

* add comment

* fix CI
  • Loading branch information
kevindiu committed Aug 13, 2020
1 parent 13c5365 commit aca45ea
Show file tree
Hide file tree
Showing 3 changed files with 537 additions and 237 deletions.
5 changes: 5 additions & 0 deletions internal/errors/net.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,4 +28,9 @@ var (
ErrInvalidDNSConfig = func(dnsRefreshDur, dnsCacheExp time.Duration) error {
return Errorf("dnsRefreshDuration > dnsCacheExp, %s, %s", dnsRefreshDur, dnsCacheExp)
}

// net

// ErrNoPortAvailiable defines no port available error
ErrNoPortAvailable = New("no port available")
)
72 changes: 59 additions & 13 deletions internal/net/net.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@ import (
"syscall"

"github.com/vdaas/vald/internal/errgroup"
"github.com/vdaas/vald/internal/errors"
"github.com/vdaas/vald/internal/log"
)

const (
Expand All @@ -36,30 +38,48 @@ const (
defaultPort = 80
)

type Conn = net.Conn
type Dialer = net.Dialer
type ListenConfig = net.ListenConfig
type Listener = net.Listener
type Resolver = net.Resolver
type (
// Conn is an alias of net.Conn
Conn = net.Conn

// Dialer is an alias of net.Dialer
Dialer = net.Dialer

// ListenConfig is an alias of net.ListenConfig
ListenConfig = net.ListenConfig

// Listener is an alias of net.Listener
Listener = net.Listener

// Resolver is an alias of net.Resolver
Resolver = net.Resolver
)

var (
// DefaultResolver is an alias of net.DefaultResolver
DefaultResolver = net.DefaultResolver
)

// Listen is a wrapper function of the net.Listen function.
func Listen(network, address string) (Listener, error) {
return net.Listen(network, address)
}

// IsLocal returns if the host is the localhost address.
func IsLocal(host string) bool {
return host == localHost ||
host == localIPv4 ||
host == localIPv6
}

func Dial(network string, addr string) (conn Conn, err error) {
// Dial is a wrapper function of the net.Dial function.
func Dial(network, addr string) (conn Conn, err error) {
return net.Dial(network, addr)
}

// Parse parses the hostname, IPv4 or IPv6 address and return the hostname/IP, port number,
// whether the address is IP, and any parsing error occurred.
// The address should contains the port number, otherwise an error will return.
func Parse(addr string) (host string, port uint16, isIP bool, err error) {
host, port, err = SplitHostPort(addr)
isIP = IsIPv6(host) || IsIPv4(host)
Expand All @@ -69,18 +89,27 @@ func Parse(addr string) (host string, port uint16, isIP bool, err error) {
return host, port, isIP, err
}

// IsIPv6 returns weather the address is IPv6 address.
func IsIPv6(addr string) bool {
return net.ParseIP(addr) != nil && strings.Count(addr, ":") >= 2
}

// IsIPv4 returns weather the address is IPv4 address.
func IsIPv4(addr string) bool {
return net.ParseIP(addr) != nil && strings.Count(addr, ":") < 2
}

// SplitHostPort splits the address, and return the host/IP address and the port number,
// and any error occurred.
// If it is the loopback address, it will return the loopback address and corresponding port number.
// IPv6 loopback address is not supported yet.
// For more information, please read https://github.com/vdaas/vald/projects/3#card-43504189
func SplitHostPort(hostport string) (host string, port uint16, err error) {
switch {
/* TODO: IPv6 loopback address support
case strings.HasPrefix(hostport, "::"):
hostport = localIPv6 + hostport
*/
case strings.HasPrefix(hostport, ":"):
hostport = localIPv4 + hostport
}
Expand All @@ -98,16 +127,23 @@ func SplitHostPort(hostport string) (host string, port uint16, err error) {
return host, port, err
}

// ScanPorts scans the given range of port numbers from the host (inclusively),
// and return the list of ports that can be connected through TCP, or any error occurred.
func ScanPorts(ctx context.Context, start, end uint16, host string) (ports []uint16, err error) {
if start > end {
start, end = end, start
}

var rl syscall.Rlimit
err = syscall.Getrlimit(syscall.RLIMIT_NOFILE, &rl)
if err != nil {
if err = syscall.Getrlimit(syscall.RLIMIT_NOFILE, &rl); err != nil {
return nil, err
}
eg, egctx := errgroup.New(ctx)
eg.Limitation(int(rl.Max) / 2)

var mu sync.Mutex
for i := start; i <= end; i++ {

for i := start; i >= start && i <= end; i++ {
port := i
eg.Go(func() error {
select {
Expand All @@ -116,19 +152,29 @@ func ScanPorts(ctx context.Context, start, end uint16, host string) (ports []uin
default:
conn, err := net.Dial("tcp", fmt.Sprintf("%s:%d", host, port))
if err != nil {
return err
log.Warn(err)
return nil
}

mu.Lock()
ports = append(ports, port)
mu.Unlock()
return conn.Close()

if err = conn.Close(); err != nil {
log.Warn(err)
}
return nil
}
})
}
err = eg.Wait()
if err != nil {

if err = eg.Wait(); err != nil {
return nil, err
}

if len(ports) == 0 {
return nil, errors.ErrNoPortAvailable
}

return ports, nil
}
Loading

0 comments on commit aca45ea

Please sign in to comment.