Skip to content

Commit

Permalink
Feature: bind socket to interface by native API on Windows (#2662)
Browse files Browse the repository at this point in the history
  • Loading branch information
yaling888 authored Apr 8, 2023
1 parent 95bbfe3 commit 20a521f
Show file tree
Hide file tree
Showing 6 changed files with 224 additions and 50 deletions.
12 changes: 11 additions & 1 deletion component/dhcp/conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,5 +14,15 @@ func ListenDHCPClient(ctx context.Context, ifaceName string) (net.PacketConn, er
listenAddr = "255.255.255.255:68"
}

return dialer.ListenPacket(ctx, "udp4", listenAddr, dialer.WithInterface(ifaceName), dialer.WithAddrReuse(true))
options := []dialer.Option{
dialer.WithInterface(ifaceName),
dialer.WithAddrReuse(true),
}

// fallback bind on windows, because syscall bind can not receive broadcast
if runtime.GOOS == "windows" {
options = append(options, dialer.WithFallbackBind(true))
}

return dialer.ListenPacket(ctx, "udp4", listenAddr, options...)
}
47 changes: 1 addition & 46 deletions component/dialer/bind_others.go
Original file line number Diff line number Diff line change
@@ -1,57 +1,12 @@
//go:build !linux && !darwin
//go:build !linux && !darwin && !windows

package dialer

import (
"net"
"strconv"
"strings"

"github.com/Dreamacro/clash/component/iface"
)

func lookupLocalAddr(ifaceName string, network string, destination net.IP, port int) (net.Addr, error) {
ifaceObj, err := iface.ResolveInterface(ifaceName)
if err != nil {
return nil, err
}

var addr *net.IPNet
switch network {
case "udp4", "tcp4":
addr, err = ifaceObj.PickIPv4Addr(destination)
case "tcp6", "udp6":
addr, err = ifaceObj.PickIPv6Addr(destination)
default:
if destination != nil {
if destination.To4() != nil {
addr, err = ifaceObj.PickIPv4Addr(destination)
} else {
addr, err = ifaceObj.PickIPv6Addr(destination)
}
} else {
addr, err = ifaceObj.PickIPv4Addr(destination)
}
}
if err != nil {
return nil, err
}

if strings.HasPrefix(network, "tcp") {
return &net.TCPAddr{
IP: addr.IP,
Port: port,
}, nil
} else if strings.HasPrefix(network, "udp") {
return &net.UDPAddr{
IP: addr.IP,
Port: port,
}, nil
}

return nil, iface.ErrAddrNotFound
}

func bindIfaceToDialer(ifaceName string, dialer *net.Dialer, network string, destination net.IP) error {
if !destination.IsGlobalUnicast() {
return nil
Expand Down
98 changes: 98 additions & 0 deletions component/dialer/bind_windows.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
package dialer

import (
"encoding/binary"
"net"
"strings"
"syscall"
"unsafe"

"github.com/Dreamacro/clash/component/iface"

"golang.org/x/sys/windows"
)

const (
IP_UNICAST_IF = 31
IPV6_UNICAST_IF = 31
)

type controlFn = func(network, address string, c syscall.RawConn) error

func bindControl(ifaceIdx int, chain controlFn) controlFn {
return func(network, address string, c syscall.RawConn) (err error) {
defer func() {
if err == nil && chain != nil {
err = chain(network, address, c)
}
}()

ipStr, _, err := net.SplitHostPort(address)
if err == nil {
ip := net.ParseIP(ipStr)
if ip != nil && !ip.IsGlobalUnicast() {
return
}
}

var innerErr error
err = c.Control(func(fd uintptr) {
if ipStr == "" && strings.HasPrefix(network, "udp") {
// When listening udp ":0", we should bind socket to interface4 and interface6 at the same time
// and ignore the error of bind6
_ = bindSocketToInterface6(windows.Handle(fd), ifaceIdx)
innerErr = bindSocketToInterface4(windows.Handle(fd), ifaceIdx)
return
}
switch network {
case "tcp4", "udp4":
innerErr = bindSocketToInterface4(windows.Handle(fd), ifaceIdx)
case "tcp6", "udp6":
innerErr = bindSocketToInterface6(windows.Handle(fd), ifaceIdx)
}
})

if innerErr != nil {
err = innerErr
}

return
}
}

func bindSocketToInterface4(handle windows.Handle, ifaceIdx int) error {
// MSDN says for IPv4 this needs to be in net byte order, so that it's like an IP address with leading zeros.
// Ref: https://learn.microsoft.com/en-us/windows/win32/winsock/ipproto-ip-socket-options
var bytes [4]byte
binary.BigEndian.PutUint32(bytes[:], uint32(ifaceIdx))
index := *(*uint32)(unsafe.Pointer(&bytes[0]))
err := windows.SetsockoptInt(handle, windows.IPPROTO_IP, IP_UNICAST_IF, int(index))
if err != nil {
return err
}
return nil
}

func bindSocketToInterface6(handle windows.Handle, ifaceIdx int) error {
return windows.SetsockoptInt(handle, windows.IPPROTO_IPV6, IPV6_UNICAST_IF, ifaceIdx)
}

func bindIfaceToDialer(ifaceName string, dialer *net.Dialer, _ string, _ net.IP) error {
ifaceObj, err := iface.ResolveInterface(ifaceName)
if err != nil {
return err
}

dialer.Control = bindControl(ifaceObj.Index, dialer.Control)
return nil
}

func bindIfaceToListenConfig(ifaceName string, lc *net.ListenConfig, _, address string) (string, error) {
ifaceObj, err := iface.ResolveInterface(ifaceName)
if err != nil {
return "", err
}

lc.Control = bindControl(ifaceObj.Index, lc.Control)
return address, nil
}
20 changes: 17 additions & 3 deletions component/dialer/dialer.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,15 @@ func ListenPacket(ctx context.Context, network, address string, options ...Optio

lc := &net.ListenConfig{}
if cfg.interfaceName != "" {
addr, err := bindIfaceToListenConfig(cfg.interfaceName, lc, network, address)
var (
addr string
err error
)
if cfg.fallbackBind {
addr, err = fallbackBindIfaceToListenConfig(cfg.interfaceName, lc, network, address)
} else {
addr, err = bindIfaceToListenConfig(cfg.interfaceName, lc, network, address)
}
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -83,8 +91,14 @@ func dialContext(ctx context.Context, network string, destination net.IP, port s

dialer := &net.Dialer{}
if opt.interfaceName != "" {
if err := bindIfaceToDialer(opt.interfaceName, dialer, network, destination); err != nil {
return nil, err
if opt.fallbackBind {
if err := fallbackBindIfaceToDialer(opt.interfaceName, dialer, network, destination); err != nil {
return nil, err
}
} else {
if err := bindIfaceToDialer(opt.interfaceName, dialer, network, destination); err != nil {
return nil, err
}
}
}
if opt.routingMark != 0 {
Expand Down
90 changes: 90 additions & 0 deletions component/dialer/fallbackbind.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
package dialer

import (
"net"
"strconv"
"strings"

"github.com/Dreamacro/clash/component/iface"
)

func lookupLocalAddr(ifaceName string, network string, destination net.IP, port int) (net.Addr, error) {
ifaceObj, err := iface.ResolveInterface(ifaceName)
if err != nil {
return nil, err
}

var addr *net.IPNet
switch network {
case "udp4", "tcp4":
addr, err = ifaceObj.PickIPv4Addr(destination)
case "tcp6", "udp6":
addr, err = ifaceObj.PickIPv6Addr(destination)
default:
if destination != nil {
if destination.To4() != nil {
addr, err = ifaceObj.PickIPv4Addr(destination)
} else {
addr, err = ifaceObj.PickIPv6Addr(destination)
}
} else {
addr, err = ifaceObj.PickIPv4Addr(destination)
}
}
if err != nil {
return nil, err
}

if strings.HasPrefix(network, "tcp") {
return &net.TCPAddr{
IP: addr.IP,
Port: port,
}, nil
} else if strings.HasPrefix(network, "udp") {
return &net.UDPAddr{
IP: addr.IP,
Port: port,
}, nil
}

return nil, iface.ErrAddrNotFound
}

func fallbackBindIfaceToDialer(ifaceName string, dialer *net.Dialer, network string, destination net.IP) error {
if !destination.IsGlobalUnicast() {
return nil
}

local := uint64(0)
if dialer.LocalAddr != nil {
_, port, err := net.SplitHostPort(dialer.LocalAddr.String())
if err == nil {
local, _ = strconv.ParseUint(port, 10, 16)
}
}

addr, err := lookupLocalAddr(ifaceName, network, destination, int(local))
if err != nil {
return err
}

dialer.LocalAddr = addr

return nil
}

func fallbackBindIfaceToListenConfig(ifaceName string, _ *net.ListenConfig, network, address string) (string, error) {
_, port, err := net.SplitHostPort(address)
if err != nil {
port = "0"
}

local, _ := strconv.ParseUint(port, 10, 16)

addr, err := lookupLocalAddr(ifaceName, network, nil, int(local))
if err != nil {
return "", err
}

return addr.String(), nil
}
7 changes: 7 additions & 0 deletions component/dialer/options.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ var (

type option struct {
interfaceName string
fallbackBind bool
addrReuse bool
routingMark int
}
Expand All @@ -22,6 +23,12 @@ func WithInterface(name string) Option {
}
}

func WithFallbackBind(fallback bool) Option {
return func(opt *option) {
opt.fallbackBind = fallback
}
}

func WithAddrReuse(reuse bool) Option {
return func(opt *option) {
opt.addrReuse = reuse
Expand Down

0 comments on commit 20a521f

Please sign in to comment.