mirror of
https://github.com/shtorm-7/sing-box-extended.git
synced 2026-06-01 23:38:50 +02:00
343 lines
10 KiB
Go
343 lines
10 KiB
Go
package wireguard
|
|
|
|
import (
|
|
"context"
|
|
"encoding/base64"
|
|
"encoding/hex"
|
|
"fmt"
|
|
"net"
|
|
"net/netip"
|
|
"os"
|
|
"reflect"
|
|
"strconv"
|
|
"strings"
|
|
"time"
|
|
"unsafe"
|
|
|
|
"github.com/sagernet/sing-box/adapter"
|
|
"github.com/sagernet/sing-box/common/dialer"
|
|
"github.com/sagernet/sing-tun"
|
|
"github.com/sagernet/sing/common"
|
|
E "github.com/sagernet/sing/common/exceptions"
|
|
F "github.com/sagernet/sing/common/format"
|
|
M "github.com/sagernet/sing/common/metadata"
|
|
"github.com/sagernet/sing/common/x/list"
|
|
"github.com/sagernet/sing/service"
|
|
"github.com/sagernet/sing/service/pause"
|
|
"github.com/sagernet/wireguard-go/conn"
|
|
"github.com/sagernet/wireguard-go/device"
|
|
|
|
"go4.org/netipx"
|
|
)
|
|
|
|
type Endpoint struct {
|
|
options EndpointOptions
|
|
peers []peerConfig
|
|
ipcConf string
|
|
allowedAddress []netip.Prefix
|
|
tunDevice Device
|
|
natDevice NatDevice
|
|
device *device.Device
|
|
allowedIPs *device.AllowedIPs
|
|
pause pause.Manager
|
|
pauseCallback *list.Element[pause.Callback]
|
|
}
|
|
|
|
func NewEndpoint(options EndpointOptions) (*Endpoint, error) {
|
|
if options.PrivateKey == "" {
|
|
return nil, E.New("missing private key")
|
|
}
|
|
privateKeyBytes, err := base64.StdEncoding.DecodeString(options.PrivateKey)
|
|
if err != nil {
|
|
return nil, E.Cause(err, "decode private key")
|
|
}
|
|
privateKey := hex.EncodeToString(privateKeyBytes)
|
|
ipcConf := "private_key=" + privateKey
|
|
if options.ListenPort != 0 {
|
|
ipcConf += "\nlisten_port=" + F.ToString(options.ListenPort)
|
|
}
|
|
var peers []peerConfig
|
|
for peerIndex, rawPeer := range options.Peers {
|
|
peer := peerConfig{
|
|
allowedIPs: rawPeer.AllowedIPs,
|
|
keepalive: rawPeer.PersistentKeepaliveInterval,
|
|
}
|
|
if rawPeer.Endpoint.Addr.IsValid() {
|
|
peer.endpoint = rawPeer.Endpoint.AddrPort()
|
|
} else if rawPeer.Endpoint.IsDomain() {
|
|
peer.destination = rawPeer.Endpoint
|
|
}
|
|
publicKeyBytes, err := base64.StdEncoding.DecodeString(rawPeer.PublicKey)
|
|
if err != nil {
|
|
return nil, E.Cause(err, "decode public key for peer ", peerIndex)
|
|
}
|
|
peer.publicKeyHex = hex.EncodeToString(publicKeyBytes)
|
|
if rawPeer.PreSharedKey != "" {
|
|
preSharedKeyBytes, err := base64.StdEncoding.DecodeString(rawPeer.PreSharedKey)
|
|
if err != nil {
|
|
return nil, E.Cause(err, "decode pre shared key for peer ", peerIndex)
|
|
}
|
|
peer.preSharedKeyHex = hex.EncodeToString(preSharedKeyBytes)
|
|
}
|
|
if len(rawPeer.AllowedIPs) == 0 {
|
|
return nil, E.New("missing allowed ips for peer ", peerIndex)
|
|
}
|
|
peers = append(peers, peer)
|
|
}
|
|
var allowedPrefixBuilder netipx.IPSetBuilder
|
|
for _, peer := range options.Peers {
|
|
for _, prefix := range peer.AllowedIPs {
|
|
allowedPrefixBuilder.AddPrefix(prefix)
|
|
}
|
|
}
|
|
allowedIPSet, err := allowedPrefixBuilder.IPSet()
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
allowedAddresses := allowedIPSet.Prefixes()
|
|
if options.MTU == 0 {
|
|
options.MTU = 1408
|
|
}
|
|
deviceOptions := DeviceOptions{
|
|
Context: options.Context,
|
|
Logger: options.Logger,
|
|
System: options.System,
|
|
Handler: options.Handler,
|
|
UDPTimeout: options.UDPTimeout,
|
|
CreateDialer: options.CreateDialer,
|
|
Name: options.Name,
|
|
MTU: options.MTU,
|
|
Address: options.Address,
|
|
AllowedAddress: allowedAddresses,
|
|
}
|
|
tunDevice, err := NewDevice(deviceOptions)
|
|
if err != nil {
|
|
return nil, E.Cause(err, "create WireGuard device")
|
|
}
|
|
natDevice, isNatDevice := tunDevice.(NatDevice)
|
|
if !isNatDevice {
|
|
natDevice = NewNATDevice(options.Context, options.Logger, tunDevice)
|
|
}
|
|
return &Endpoint{
|
|
options: options,
|
|
peers: peers,
|
|
ipcConf: ipcConf,
|
|
allowedAddress: allowedAddresses,
|
|
tunDevice: tunDevice,
|
|
natDevice: natDevice,
|
|
}, nil
|
|
}
|
|
|
|
func (e *Endpoint) Start(resolve bool) error {
|
|
if common.Any(e.peers, func(peer peerConfig) bool {
|
|
return !peer.endpoint.IsValid() && peer.destination.IsDomain()
|
|
}) {
|
|
if !resolve {
|
|
return nil
|
|
}
|
|
for peerIndex, peer := range e.peers {
|
|
if peer.endpoint.IsValid() || !peer.destination.IsDomain() {
|
|
continue
|
|
}
|
|
destinationAddress, err := e.options.ResolvePeer(peer.destination.Fqdn)
|
|
if err != nil {
|
|
return E.Cause(err, "resolve endpoint domain for peer[", peerIndex, "]: ", peer.destination)
|
|
}
|
|
e.peers[peerIndex].endpoint = netip.AddrPortFrom(destinationAddress, peer.destination.Port)
|
|
}
|
|
} else if resolve {
|
|
return nil
|
|
}
|
|
var bind conn.Bind
|
|
wgListener, isWgListener := common.Cast[dialer.WireGuardListener](e.options.Dialer)
|
|
if isWgListener {
|
|
bind = conn.NewStdNetBind(wgListener.WireGuardControl())
|
|
} else {
|
|
var (
|
|
isConnect bool
|
|
connectAddr netip.AddrPort
|
|
)
|
|
if len(e.peers) == 1 && e.peers[0].endpoint.IsValid() {
|
|
isConnect = true
|
|
connectAddr = e.peers[0].endpoint
|
|
}
|
|
bind = NewClientBind(e.options.Context, e.options.Logger, e.options.Dialer, isConnect, connectAddr)
|
|
}
|
|
err := e.tunDevice.Start()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
logger := &device.Logger{
|
|
Verbosef: func(format string, args ...any) {
|
|
e.options.Logger.Debug(fmt.Sprintf(strings.ToLower(format), args...))
|
|
},
|
|
Errorf: func(format string, args ...any) {
|
|
e.options.Logger.Error(fmt.Sprintf(strings.ToLower(format), args...))
|
|
},
|
|
}
|
|
var deviceInput Device
|
|
if e.natDevice != nil {
|
|
deviceInput = e.natDevice
|
|
} else {
|
|
deviceInput = e.tunDevice
|
|
}
|
|
wgDevice := device.NewDevice(e.options.Context, deviceInput, bind, logger, e.options.Workers, e.options.PreallocatedBuffersPerPool, e.options.DisablePauses)
|
|
e.tunDevice.SetDevice(wgDevice)
|
|
var ipcConf strings.Builder
|
|
ipcConf.WriteString(e.ipcConf)
|
|
if e.options.Amnezia != nil {
|
|
if e.options.Amnezia.JC > 0 {
|
|
ipcConf.WriteString("\njc=" + strconv.Itoa(e.options.Amnezia.JC))
|
|
}
|
|
if e.options.Amnezia.JMin > 0 {
|
|
ipcConf.WriteString("\njmin=" + strconv.Itoa(e.options.Amnezia.JMin))
|
|
}
|
|
if e.options.Amnezia.JMax > 0 {
|
|
ipcConf.WriteString("\njmax=" + strconv.Itoa(e.options.Amnezia.JMax))
|
|
}
|
|
if e.options.Amnezia.S1 > 0 {
|
|
ipcConf.WriteString("\ns1=" + strconv.Itoa(e.options.Amnezia.S1))
|
|
}
|
|
if e.options.Amnezia.S2 > 0 {
|
|
ipcConf.WriteString("\ns2=" + strconv.Itoa(e.options.Amnezia.S2))
|
|
}
|
|
if e.options.Amnezia.S3 > 0 {
|
|
ipcConf.WriteString("\ns3=" + strconv.Itoa(e.options.Amnezia.S3))
|
|
}
|
|
if e.options.Amnezia.S4 > 0 {
|
|
ipcConf.WriteString("\ns4=" + strconv.Itoa(e.options.Amnezia.S4))
|
|
}
|
|
if e.options.Amnezia.H1 != nil {
|
|
ipcConf.WriteString("\nh1=" + e.options.Amnezia.H1.String())
|
|
}
|
|
if e.options.Amnezia.H2 != nil {
|
|
ipcConf.WriteString("\nh2=" + e.options.Amnezia.H2.String())
|
|
}
|
|
if e.options.Amnezia.H3 != nil {
|
|
ipcConf.WriteString("\nh3=" + e.options.Amnezia.H3.String())
|
|
}
|
|
if e.options.Amnezia.H4 != nil {
|
|
ipcConf.WriteString("\nh4=" + e.options.Amnezia.H4.String())
|
|
}
|
|
if e.options.Amnezia.I1 != "" {
|
|
ipcConf.WriteString("\ni1=" + e.options.Amnezia.I1)
|
|
}
|
|
if e.options.Amnezia.I2 != "" {
|
|
ipcConf.WriteString("\ni2=" + e.options.Amnezia.I2)
|
|
}
|
|
if e.options.Amnezia.I3 != "" {
|
|
ipcConf.WriteString("\ni3=" + e.options.Amnezia.I3)
|
|
}
|
|
if e.options.Amnezia.I4 != "" {
|
|
ipcConf.WriteString("\ni4=" + e.options.Amnezia.I4)
|
|
}
|
|
if e.options.Amnezia.I5 != "" {
|
|
ipcConf.WriteString("\ni5=" + e.options.Amnezia.I5)
|
|
}
|
|
if e.options.Amnezia.J1 != "" {
|
|
ipcConf.WriteString("\nj1=" + e.options.Amnezia.J1)
|
|
}
|
|
if e.options.Amnezia.J2 != "" {
|
|
ipcConf.WriteString("\nj2=" + e.options.Amnezia.J2)
|
|
}
|
|
if e.options.Amnezia.J3 != "" {
|
|
ipcConf.WriteString("\nj3=" + e.options.Amnezia.J3)
|
|
}
|
|
if e.options.Amnezia.ITime > 0 {
|
|
ipcConf.WriteString("\nitime=" + strconv.FormatInt(e.options.Amnezia.ITime, 10))
|
|
}
|
|
}
|
|
for _, peer := range e.peers {
|
|
ipcConf.WriteString(peer.GenerateIpcLines())
|
|
}
|
|
err = wgDevice.IpcSet(ipcConf.String())
|
|
if err != nil {
|
|
wgDevice.Close()
|
|
return E.Cause(err, "setup wireguard: \n", ipcConf.String())
|
|
}
|
|
e.device = wgDevice
|
|
e.pause = service.FromContext[pause.Manager](e.options.Context)
|
|
if e.pause != nil {
|
|
e.pauseCallback = e.pause.RegisterCallback(e.onPauseUpdated)
|
|
}
|
|
e.allowedIPs = (*device.AllowedIPs)(unsafe.Pointer(reflect.Indirect(reflect.ValueOf(wgDevice)).FieldByName("allowedips").UnsafeAddr()))
|
|
return nil
|
|
}
|
|
|
|
func (e *Endpoint) DialContext(ctx context.Context, network string, destination M.Socksaddr) (net.Conn, error) {
|
|
if !destination.Addr.IsValid() {
|
|
return nil, E.Cause(os.ErrInvalid, "invalid non-IP destination")
|
|
}
|
|
return e.tunDevice.DialContext(ctx, network, destination)
|
|
}
|
|
|
|
func (e *Endpoint) ListenPacket(ctx context.Context, destination M.Socksaddr) (net.PacketConn, error) {
|
|
if !destination.Addr.IsValid() {
|
|
return nil, E.Cause(os.ErrInvalid, "invalid non-IP destination")
|
|
}
|
|
return e.tunDevice.ListenPacket(ctx, destination)
|
|
}
|
|
|
|
func (e *Endpoint) Close() error {
|
|
if e.pauseCallback != nil {
|
|
e.pause.UnregisterCallback(e.pauseCallback)
|
|
e.pauseCallback = nil
|
|
}
|
|
if e.device != nil {
|
|
e.device.Down()
|
|
e.device.Close()
|
|
e.device = nil
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (e *Endpoint) Lookup(address netip.Addr) *device.Peer {
|
|
if e.allowedIPs == nil {
|
|
return nil
|
|
}
|
|
return e.allowedIPs.Lookup(address.AsSlice())
|
|
}
|
|
|
|
func (e *Endpoint) NewDirectRouteConnection(metadata adapter.InboundContext, routeContext tun.DirectRouteContext, timeout time.Duration) (tun.DirectRouteDestination, error) {
|
|
if e.natDevice == nil {
|
|
return nil, os.ErrInvalid
|
|
}
|
|
return e.natDevice.CreateDestination(metadata, routeContext, timeout)
|
|
}
|
|
|
|
func (e *Endpoint) onPauseUpdated(event int) {
|
|
switch event {
|
|
case pause.EventDevicePaused, pause.EventNetworkPause:
|
|
e.device.Down()
|
|
case pause.EventDeviceWake, pause.EventNetworkWake:
|
|
e.device.Up()
|
|
}
|
|
}
|
|
|
|
type peerConfig struct {
|
|
destination M.Socksaddr
|
|
endpoint netip.AddrPort
|
|
publicKeyHex string
|
|
preSharedKeyHex string
|
|
allowedIPs []netip.Prefix
|
|
keepalive uint16
|
|
}
|
|
|
|
func (c peerConfig) GenerateIpcLines() string {
|
|
var ipcLines strings.Builder
|
|
ipcLines.WriteString("\npublic_key=" + c.publicKeyHex)
|
|
if c.endpoint.IsValid() {
|
|
ipcLines.WriteString("\nendpoint=" + c.endpoint.String())
|
|
}
|
|
if c.preSharedKeyHex != "" {
|
|
ipcLines.WriteString("\npreshared_key=" + c.preSharedKeyHex)
|
|
}
|
|
for _, allowedIP := range c.allowedIPs {
|
|
ipcLines.WriteString("\nallowed_ip=" + allowedIP.String())
|
|
}
|
|
if c.keepalive > 0 {
|
|
ipcLines.WriteString("\npersistent_keepalive_interval=" + F.ToString(c.keepalive))
|
|
}
|
|
return ipcLines.String()
|
|
}
|