fix: dns server boostrapping with underlying dns servers

#1241
This commit is contained in:
zaneschepke
2026-05-24 01:15:35 -04:00
parent 68dc57422c
commit bf432cca0d
12 changed files with 475 additions and 388 deletions
@@ -32,6 +32,7 @@ import kotlinx.coroutines.sync.withLock
class TunnelCoordinator(
private val tunnelProvider: TunnelProvider,
private val serviceManager: ServiceManager,
private val bootstrapCoordinator: AppBoostrapCoordinator,
settingsRepository: GeneralSettingRepository,
private val tunnelRepository: TunnelRepository,
dnsSettingsRepository: RoomDnsSettingsRepository,
@@ -86,7 +87,11 @@ class TunnelCoordinator(
suspend fun startTunnel(
config: TunnelConfig,
source: TunnelActionSource = TunnelActionSource.USER,
) = tunnelMutex.withLock { startTunnelInternal(config, source) }
) = tunnelMutex.withLock {
// wait for app to be bootstrapped
bootstrapCoordinator.isReady.first { it }
startTunnelInternal(config, source)
}
suspend fun stopTunnel(id: Int, source: TunnelActionSource = TunnelActionSource.USER) =
tunnelMutex.withLock {
@@ -24,6 +24,7 @@ val coordinatorModule = module {
get(),
get(),
get(),
get(),
get(named(Scope.APPLICATION)),
)
}
@@ -3,25 +3,27 @@ package com.zaneschepke.tunnel
import com.zaneschepke.tunnel.model.DnsBootstrapResult
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.withContext
import org.json.JSONObject
internal object DnsConfigManager {
private external fun setDNSConfig(configJson: String)
fun update(protocol: String, upstream: String) {
val config =
JSONObject().apply {
put("protocol", protocol)
put("upstream", upstream)
}
setDNSConfig(config.toString())
}
private external fun resolveBootstrap(
host: String,
protocol: String,
upstream: String,
underlyingDnsServers: String,
bypass: Int,
): String
private external fun resolveBootstrap(host: String, bypass: Int): String
suspend fun resolveHostBootstrap(host: String, bypass: Boolean): DnsBootstrapResult =
suspend fun resolveHostBootstrap(
host: String,
protocol: String,
upstream: String,
underlyingDnsServers: String,
bypass: Boolean,
): DnsBootstrapResult =
withContext(Dispatchers.IO) {
val raw = resolveBootstrap(host, if (bypass) 1 else 0)
val bypassOption = if (bypass) 1 else 0
val raw = resolveBootstrap(host, protocol, upstream, underlyingDnsServers, bypassOption)
if (raw.startsWith("ERR|")) {
throw RuntimeException(raw.removePrefix("ERR|"))
@@ -5,6 +5,7 @@ import com.zaneschepke.tunnel.DnsConfigManager
import com.zaneschepke.tunnel.Tunnel
import com.zaneschepke.tunnel.event.ActorEvent
import com.zaneschepke.tunnel.event.ActorEvent.ActiveConfigUpdated
import com.zaneschepke.tunnel.event.ActorEvent.BootstrapConfigUpdated
import com.zaneschepke.tunnel.event.ActorEvent.BootstrapStateChanged
import com.zaneschepke.tunnel.event.ActorEvent.EngineStatus
import com.zaneschepke.tunnel.event.ActorEvent.KillSwitchStateChanged
@@ -12,7 +13,9 @@ import com.zaneschepke.tunnel.event.ActorEvent.PeersUpdated
import com.zaneschepke.tunnel.event.ActorEvent.ResolvedPeersApplied
import com.zaneschepke.tunnel.event.ActorEvent.TunnelStarted
import com.zaneschepke.tunnel.event.ActorEvent.TunnelStopped
import com.zaneschepke.tunnel.event.ActorEvent.UnderlyingDnsServersUpdated
import com.zaneschepke.tunnel.event.TunnelEvent
import com.zaneschepke.tunnel.event.TunnelEvent.NoRootShellAccess
import com.zaneschepke.tunnel.model.BackendMode
import com.zaneschepke.tunnel.model.DnsBootstrapResult
import com.zaneschepke.tunnel.model.PublicKey
@@ -122,6 +125,10 @@ internal class TunnelActor(
engine.stop(runtime.running.handle, runtime.running.mode)
}
is TunnelCommand.SetBootstrapConfig -> {
apply(BootstrapConfigUpdated(cmd.config))
}
is TunnelCommand.UpdatePeers -> {
val runtime = _state.value.byTunnelId[cmd.tunnelId] ?: continue
val running = runtime.running
@@ -209,9 +216,7 @@ internal class TunnelActor(
} catch (t: Throwable) {
Timber.w(t, "Root shell commands failed")
if (t is RootShellException.NoRootAccess) {
_events.emit(
TunnelEvent.NoRootShellAccess(tunnelId = cmd.tunnelId)
)
_events.emit(NoRootShellAccess(tunnelId = cmd.tunnelId))
}
} finally {
if (isPostDown) {
@@ -219,6 +224,10 @@ internal class TunnelActor(
}
}
}
is TunnelCommand.UpdateUnderlyingDnsServers -> {
apply(UnderlyingDnsServersUpdated(cmd.servers))
}
}
} catch (t: Throwable) {
Timber.e(t, "Tunnel command failed: $cmd")
@@ -621,6 +630,14 @@ internal class TunnelActor(
is KillSwitchStateChanged -> {
state.copy(killSwitchEnabled = event.enabled)
}
is BootstrapConfigUpdated -> {
state.copy(dnsConfig = event.config)
}
is UnderlyingDnsServersUpdated -> {
state.copy(dnsConfig = state.dnsConfig.copy(underlyingDnsServers = event.servers))
}
}
}
@@ -667,9 +684,17 @@ internal class TunnelActor(
val endpoint = peer.endpoint ?: continue
val host = endpoint.substringBeforeLast(":")
val dnsConfig = state.value.dnsConfig
val dnsResult =
try {
DnsConfigManager.resolveHostBootstrap(host = host, bypass = bypassNeeded)
DnsConfigManager.resolveHostBootstrap(
host = host,
dnsConfig.protocol,
dnsConfig.upstream,
dnsConfig.underlyingDnsServers,
bypass = bypassNeeded,
)
} catch (e: Exception) {
Timber.w(e, "DNS failed for $host")
continue
@@ -1,8 +1,9 @@
package com.zaneschepke.tunnel.backend
import com.zaneschepke.networkmonitor.ActiveNetwork
import com.zaneschepke.networkmonitor.DnsInfo
import com.zaneschepke.networkmonitor.NetworkMonitor
import com.zaneschepke.networkmonitor.PrivateDnsMode
import com.zaneschepke.tunnel.DnsConfigManager
import com.zaneschepke.tunnel.NotificationProvider
import com.zaneschepke.tunnel.Tunnel
import com.zaneschepke.tunnel.event.TunnelEvent
@@ -14,6 +15,7 @@ import com.zaneschepke.tunnel.model.TunnelCommand
import com.zaneschepke.tunnel.service.VpnService
import com.zaneschepke.tunnel.state.BackendStatus
import com.zaneschepke.tunnel.state.KillSwitchState
import com.zaneschepke.tunnel.state.RuntimeDnsConfig
import com.zaneschepke.tunnel.util.BackendException
import java.lang.ref.WeakReference
import kotlin.reflect.KClass
@@ -27,8 +29,10 @@ import kotlinx.coroutines.flow.asSharedFlow
import kotlinx.coroutines.flow.asStateFlow
import kotlinx.coroutines.flow.distinctUntilChangedBy
import kotlinx.coroutines.flow.first
import kotlinx.coroutines.flow.firstOrNull
import kotlinx.coroutines.flow.update
import kotlinx.coroutines.launch
import kotlinx.coroutines.withTimeoutOrNull
import org.koin.java.KoinJavaComponent.inject
import timber.log.Timber
@@ -191,24 +195,90 @@ class TunnelBackend(
override suspend fun setBootstrapDnsMode(mode: DnsBoostrapMode) {
_status.update { it.copy(dnsMode = mode) }
dnsConfigJob?.cancel()
dnsConfigJob = null
when (mode) {
is DnsBoostrapMode.Custom -> {
Timber.d("DNS Boostrap mode set to custom, disabling system dns monitoring")
dnsConfigJob?.cancel()
dnsConfigJob = null
DnsConfigManager.update(
mode.config.protocol,
mode.config.upstream ?: DnsBoostrapConfig.DEFAULT_UPSTREAM,
Timber.d(
"DNS Bootstrap mode set to custom: ${mode.config.protocol} -> ${mode.config.upstream}"
)
emitInitialDnsConfig(mode.config.protocol, mode.config.upstream)
startUnderlyingDnsMonitoring()
}
DnsBoostrapMode.System -> {
Timber.d("DNS Bootstrap mode set to System")
emitInitialDnsConfig()
startSystemDnsMonitoring()
}
}
}
private suspend fun emitInitialDnsConfig(protocol: String? = null, upstream: String? = null) {
val state =
withTimeoutOrNull(2_500L) {
networkMonitor.connectivityStateFlow.first { connectivityState ->
val dns = connectivityState.underlyingDnsInfo
dns.servers.isNotEmpty() ||
connectivityState.activeNetwork is ActiveNetwork.Disconnected
}
} ?: networkMonitor.connectivityStateFlow.firstOrNull()
val dns = state?.underlyingDnsInfo ?: DnsInfo()
val finalProtocol =
protocol
?: when (dns.privateDnsMode) {
PrivateDnsMode.HOSTNAME -> "dot"
else -> "plain"
}
val finalUpstream =
upstream
?: when (dns.privateDnsMode) {
PrivateDnsMode.HOSTNAME ->
dns.privateDnsHostname?.takeIf { it.isNotBlank() }
?: DnsBoostrapConfig.DEFAULT_UPSTREAM
else -> dns.servers.firstOrNull() ?: DnsBoostrapConfig.DEFAULT_UPSTREAM
}
val underlying =
dns.servers.joinToString(",").ifBlank { DnsBoostrapConfig.DEFAULT_UPSTREAM }
Timber.d(
"DNS initial emission: protocol=$finalProtocol upstream=$finalUpstream underlying=$underlying"
)
actor.send(
TunnelCommand.SetBootstrapConfig(
RuntimeDnsConfig(
protocol = finalProtocol,
upstream = finalUpstream,
underlyingDnsServers = underlying,
)
)
)
}
private fun startUnderlyingDnsMonitoring() {
if (dnsConfigJob?.isActive == true) return
dnsConfigJob = scope.launch {
networkMonitor.connectivityStateFlow
.distinctUntilChangedBy { it.underlyingDnsInfo.servers }
.collect { state ->
val dns = state.underlyingDnsInfo
val underlying = dns.servers.joinToString(",")
Timber.d("Underlying DNS servers changed: $underlying")
actor.send(TunnelCommand.UpdateUnderlyingDnsServers(underlying))
}
}
}
override fun emergencyStopAllOfTypeSync(modeClass: KClass<out BackendMode>) {
actor.emergencyStopAllOfType(modeClass)
}
@@ -236,13 +306,12 @@ class TunnelBackend(
val config =
when (dns.privateDnsMode) {
PrivateDnsMode.OFF,
PrivateDnsMode.AUTOMATIC -> {
PrivateDnsMode.AUTOMATIC ->
DnsBoostrapConfig.Plain(
dns.servers.firstOrNull() ?: DnsBoostrapConfig.DEFAULT_UPSTREAM
)
}
PrivateDnsMode.HOSTNAME -> {
PrivateDnsMode.HOSTNAME ->
dns.privateDnsHostname
?.takeIf { it.isNotBlank() }
?.let { DnsBoostrapConfig.DoT(it) }
@@ -250,12 +319,16 @@ class TunnelBackend(
dns.servers.firstOrNull()
?: DnsBoostrapConfig.DEFAULT_UPSTREAM
)
}
}
DnsConfigManager.update(
config.protocol,
config.upstream ?: DnsBoostrapConfig.DEFAULT_UPSTREAM,
actor.send(
TunnelCommand.SetBootstrapConfig(
RuntimeDnsConfig(
protocol = config.protocol,
upstream = config.upstream ?: DnsBoostrapConfig.DEFAULT_UPSTREAM,
underlyingDnsServers = dns.servers.joinToString(","),
)
)
)
}
}
@@ -6,6 +6,7 @@ import com.zaneschepke.tunnel.model.TunnelCommand
import com.zaneschepke.tunnel.state.BootstrapState
import com.zaneschepke.tunnel.state.EngineStartResult
import com.zaneschepke.tunnel.state.NativeTunnelStatus
import com.zaneschepke.tunnel.state.RuntimeDnsConfig
import com.zaneschepke.wireguardautotunnel.parser.ActiveConfig
import com.zaneschepke.wireguardautotunnel.parser.PeerSection
@@ -23,6 +24,10 @@ sealed class ActorEvent {
val preferIpv6: Boolean,
) : ActorEvent()
data class BootstrapConfigUpdated(val config: RuntimeDnsConfig) : ActorEvent()
data class UnderlyingDnsServersUpdated(val servers: String) : ActorEvent()
data class ResolvedPeersApplied(
val tunnelId: Int,
val cache: Map<PublicKey, DnsBootstrapResult>,
@@ -17,7 +17,7 @@ sealed class DnsBoostrapConfig(open val upstream: String?) {
data class DoH(override val upstream: String?) : DnsBoostrapConfig(upstream) {
override val protocol: String
get() = "dot"
get() = "doh"
}
data class DoT(override val upstream: String?) : DnsBoostrapConfig(upstream) {
@@ -2,6 +2,7 @@ package com.zaneschepke.tunnel.model
import com.zaneschepke.tunnel.Tunnel
import com.zaneschepke.tunnel.state.BootstrapState
import com.zaneschepke.tunnel.state.RuntimeDnsConfig
import com.zaneschepke.wireguardautotunnel.parser.PeerSection
sealed class TunnelCommand {
@@ -24,6 +25,10 @@ sealed class TunnelCommand {
data class SetBootstrapState(val tunnelId: Int, val state: BootstrapState) : TunnelCommand()
data class SetBootstrapConfig(val config: RuntimeDnsConfig) : TunnelCommand()
data class UpdateUnderlyingDnsServers(val servers: String) : TunnelCommand()
data class RunHook(val tunnelId: Int, val phase: Phase, val cmds: List<String>?) :
TunnelCommand() {
enum class Phase {
@@ -4,4 +4,5 @@ data class ActorState(
val byTunnelId: Map<Int, TunnelRuntimeState>,
val byHandle: Map<Int, Int>,
val killSwitchEnabled: Boolean = false,
val dnsConfig: RuntimeDnsConfig = RuntimeDnsConfig(),
)
@@ -0,0 +1,9 @@
package com.zaneschepke.tunnel.state
import com.zaneschepke.tunnel.model.DnsBoostrapConfig
data class RuntimeDnsConfig(
val protocol: String = "plain",
val upstream: String = DnsBoostrapConfig.DEFAULT_UPSTREAM,
val underlyingDnsServers: String = DnsBoostrapConfig.DEFAULT_UPSTREAM,
)
+286 -322
View File
@@ -9,7 +9,6 @@ import (
"bytes"
"context"
"crypto/tls"
"encoding/json"
"errors"
"fmt"
"io"
@@ -18,7 +17,6 @@ import (
"net/netip"
"net/url"
"strings"
"sync"
"syscall"
"time"
@@ -27,21 +25,6 @@ import (
"golang.org/x/sys/unix"
)
const defaultPlain = "udp://1.1.1.1:53"
var (
currentConfig DNSConfig = DNSConfig{
"plain",
"1.1.1.1:53",
}
configMu sync.RWMutex
)
type DNSConfig struct {
Protocol string `json:"protocol"` // plain, doh, or dot
Upstream string `json:"upstream"`
}
type Resolved struct {
V4 []netip.Addr
V6 []netip.Addr
@@ -52,36 +35,34 @@ type ResolverOptions struct {
Timeout time.Duration
}
func DefaultOptions() ResolverOptions {
return ResolverOptions{
UpstreamURL: defaultPlain,
Timeout: 5 * time.Second,
}
}
//export SetDNSConfig
func SetDNSConfig(config string) {
var cfg DNSConfig
if err := json.Unmarshal([]byte(config), &cfg); err != nil {
shared.LogError("DNS", "Failed to parse DNSConfig: %v", err)
return
}
if cfg.Protocol != "plain" && cfg.Protocol != "doh" && cfg.Protocol != "dot" {
cfg.Protocol = "plain"
}
configMu.Lock()
currentConfig = cfg
configMu.Unlock()
shared.LogDebug("DNS", "DNS config updated: %s %s", cfg.Protocol, cfg.Upstream)
type Transport interface {
Query(ctx context.Context, msg *dns.Msg) (*dns.Msg, error)
}
//export ResolveBootstrap
func ResolveBootstrap(host *C.char, bypass C.int) *C.char {
func ResolveBootstrap(
host *C.char,
protocol *C.char,
upstream *C.char,
underlyingDnsServers *C.char,
bypass C.int,
) *C.char {
h := C.GoString(host)
p := C.GoString(protocol)
u := C.GoString(upstream)
underlying := C.GoString(underlyingDnsServers)
bp := bypass == 1
shared.LogDebug("DNS", "ResolveBootstrap called for host=%s (bypass=%t)", h, bp)
v4, v6, err := Resolve(h, bp)
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
shared.LogDebug(
"DNS",
"ResolveBootstrap called host=%s protocol=%s upstream=%s bypass=%t",
h, p, u, bp,
)
v4, v6, err := Resolve(ctx, h, p, u, bp, underlying)
if err != nil {
shared.LogError("DNS", "ResolveBootstrap failed for %s: %v", h, err)
return C.CString("ERR|" + err.Error())
@@ -96,15 +77,76 @@ func ResolveBootstrap(host *C.char, bypass C.int) *C.char {
v6Str[i] = ip.String()
}
result := "v4=" + strings.Join(v4Str, ",") + ";v6=" + strings.Join(v6Str, ",")
result := "v4=" + strings.Join(v4Str, ",") +
";v6=" + strings.Join(v6Str, ",")
shared.LogDebug("DNS", "ResolveBootstrap success for %s: %s", h, result)
return C.CString(result)
}
func getConfig() DNSConfig {
configMu.RLock()
defer configMu.RUnlock()
return currentConfig
type DoTTransport struct {
Client *dns.Client
Servers []string
}
type DoHTransport struct {
Client *http.Client
URL string
Servers []string // IPv4 first, IPv6 fallback
Hostname string // for SNI and Host header
}
type PlainTransport struct {
Client *dns.Client
Servers []string
}
func resolveHost(
ctx context.Context,
t Transport,
host string,
) (v4, v6 []netip.Addr, err error) {
a4, e4 := resolveQ(ctx, t, host, dns.TypeA)
if e4 == nil {
v4 = a4
}
a6, e6 := resolveQ(ctx, t, host, dns.TypeAAAA)
if e6 == nil {
v6 = a6
}
if len(v4) > 0 || len(v6) > 0 {
return v4, v6, nil
}
return nil, nil, errors.Join(e4, e6)
}
func resolveQ(
ctx context.Context,
t Transport,
host string,
qtype uint16,
) ([]netip.Addr, error) {
req := &dns.Msg{}
req.SetQuestion(dns.Fqdn(host), qtype)
req.SetEdns0(4096, true)
res, err := t.Query(ctx, req)
if err != nil {
return nil, err
}
if res == nil {
return nil, fmt.Errorf("nil DNS response")
}
if res.Rcode != dns.RcodeSuccess {
return nil, fmt.Errorf("rcode %d", res.Rcode)
}
addrs := parseDNSAnswers(res, qtype)
if len(addrs) == 0 {
return nil, fmt.Errorf("no answers for qtype %d", qtype)
}
return addrs, nil
}
func parseUpstream(upstreamURL string) (network, address string, err error) {
@@ -113,7 +155,6 @@ func parseUpstream(upstreamURL string) (network, address string, err error) {
if !strings.Contains(u, "://") {
u = "udp://" + u
}
parsed, err := url.Parse(u)
if err != nil {
shared.LogError("DNS", "parseUpstream failed for %q: %v", upstreamURL, err)
@@ -141,189 +182,127 @@ func parseUpstream(upstreamURL string) (network, address string, err error) {
return network, address, nil
}
func resolveServerAddr(ctx context.Context, address string, bypass bool) (string, error) {
func newUnderlyingResolver(bypass bool, underlying string) *net.Resolver {
if !bypass {
return &net.Resolver{PreferGo: false}
}
rawServers := strings.Split(underlying, ",")
var servers []string
for _, s := range rawServers {
s = strings.TrimSpace(s)
if s == "" {
continue
}
if !strings.Contains(s, ":") {
s = net.JoinHostPort(s, "53")
}
servers = append(servers, s)
}
if len(servers) == 0 {
servers = []string{"1.1.1.1:53"}
}
return &net.Resolver{
PreferGo: true,
Dial: func(ctx context.Context, network, _ string) (net.Conn, error) {
for _, server := range servers {
conn, err := GetDialer(true).DialContext(ctx, network, server)
if err == nil {
shared.LogDebug("DNS", "Using underlying bootstrap resolver: %s", server)
return conn, nil
}
shared.LogDebug("DNS", "Bootstrap resolver failed for %s: %v", server, err)
}
return nil, fmt.Errorf("all underlying DNS servers failed")
},
}
}
func resolveServerAddrs(
ctx context.Context,
address string,
bypass bool,
defaultPort string,
underlying string,
) ([]string, string, error) {
host, port, err := net.SplitHostPort(address)
if err != nil {
shared.LogError("DNS", "resolveServerAddr: invalid address %q: %v", address, err)
return "", err
host = address
port = defaultPort
}
if net.ParseIP(host) != nil {
return address, nil
}
shared.LogDebug("DNS", "resolveServerAddr: bootstrapping upstream hostname %s (bypass=%t)", host, bypass)
bootstrapDialer := GetDialer(bypass)
resolver := &net.Resolver{
PreferGo: true,
Dial: func(ctx context.Context, network, _ string) (net.Conn, error) {
return bootstrapDialer.DialContext(ctx, network, "1.1.1.1:53")
},
return []string{net.JoinHostPort(host, port)}, host, nil
}
resolver := newUnderlyingResolver(bypass, underlying)
ips, err := resolver.LookupIP(ctx, "ip", host)
if err != nil {
shared.LogError("DNS", "Failed to resolve upstream hostname %s (bypass=%t): %v", host, bypass, err)
return "", fmt.Errorf("failed to resolve upstream hostname %s: %w", host, err)
}
if len(ips) == 0 {
err = errors.New("no IPs found for upstream hostname")
shared.LogError("DNS", "%v for %s", err, host)
return "", err
shared.LogError("DNS", "Failed to resolve upstream %s (bypass=%t): %v", host, bypass, err)
return nil, "", err
}
addr := net.JoinHostPort(ips[0].String(), port)
shared.LogDebug("DNS", "Resolved upstream %s -> %s", host, addr)
return addr, nil
}
func resolveInner(host string, ipType uint16, network, serverAddr string, bypass bool) ([]netip.Addr, error) {
req := &dns.Msg{}
req.Id = dns.Id()
req.RecursionDesired = true
req.SetQuestion(dns.Fqdn(host), ipType)
req.SetEdns0(4096, true)
client := &dns.Client{
Net: network,
Dialer: GetDialer(bypass),
Timeout: 5 * time.Second,
UDPSize: 4096,
}
res, _, err := client.Exchange(req, serverAddr)
if err != nil {
shared.LogError("DNS", "resolveInner: DNS exchange failed for %s (type=%d, server=%s, bypass=%t): %v", host, ipType, serverAddr, bypass, err)
return nil, err
}
if res.Rcode != dns.RcodeSuccess {
shared.LogError("DNS", "resolveInner: DNS query failed with Rcode %d for %s", res.Rcode, host)
return nil, fmt.Errorf("DNS query failed with Rcode: %d", res.Rcode)
}
var addr []netip.Addr
for _, ans := range res.Answer {
switch ipType {
case dns.TypeA:
if a, ok := ans.(*dns.A); ok {
if ip, err := netip.ParseAddr(a.A.String()); err == nil {
addr = append(addr, ip)
}
}
case dns.TypeAAAA:
if aaaa, ok := ans.(*dns.AAAA); ok {
if ip, err := netip.ParseAddr(aaaa.AAAA.String()); err == nil {
addr = append(addr, ip)
}
}
var v4, v6 []string
for _, ip := range ips {
addr := net.JoinHostPort(ip.String(), port)
if ip.To4() != nil {
v4 = append(v4, addr)
} else {
v6 = append(v6, addr)
}
}
return addr, nil
return append(v4, v6...), host, nil
}
func resolvePlain(host, upstreamURL string, bypass bool) ([]netip.Addr, []netip.Addr, error) {
shared.LogDebug("DNS", "resolvePlain: %s with upstream=%s (bypass=%t)", host, upstreamURL, bypass)
network, addr, err := parseUpstream(upstreamURL)
if err != nil {
return nil, nil, err
func (t PlainTransport) Query(ctx context.Context, msg *dns.Msg) (*dns.Msg, error) {
for _, server := range t.Servers {
m, _, err := t.Client.Exchange(msg, server)
if err == nil && m != nil && m.Rcode == dns.RcodeSuccess {
return m, nil
}
}
serverAddr, err := resolveServerAddr(context.Background(), addr, bypass)
if err != nil {
return nil, nil, err
}
var wg sync.WaitGroup
var v4, v6 []netip.Addr
var v4Err, v6Err error
wg.Add(2)
go func() { v4, v4Err = resolveInner(host, dns.TypeA, network, serverAddr, bypass); wg.Done() }()
go func() { v6, v6Err = resolveInner(host, dns.TypeAAAA, network, serverAddr, bypass); wg.Done() }()
wg.Wait()
if v4Err != nil && v6Err != nil {
shared.LogError("DNS", "resolvePlain failed for %s: both A and AAAA failed", host)
return nil, nil, errors.Join(v4Err, v6Err)
}
if len(v4) == 0 && len(v6) == 0 {
err = errors.New("no IP addresses found")
shared.LogError("DNS", "%v for %s", err, host)
return nil, nil, err
}
return v4, v6, nil
return nil, fmt.Errorf("all DNS servers failed")
}
func resolveDoH(host, dohURL string, bypass bool) ([]netip.Addr, []netip.Addr, error) {
shared.LogDebug("DNS", "Resolving DOH: %s with %s", host, dohURL)
var v4, v6 []netip.Addr
var v4Err, v6Err error
var wg sync.WaitGroup
wg.Add(2)
go func() {
defer wg.Done()
v4, v4Err = doDoHQuery(dohURL, host, dns.TypeA, bypass)
}()
go func() {
defer wg.Done()
v6, v6Err = doDoHQuery(dohURL, host, dns.TypeAAAA, bypass)
}()
wg.Wait()
if v4Err != nil && v6Err != nil {
shared.LogError("DNS", "resolveDoH failed for %s: both queries failed", host)
return nil, nil, errors.Join(v4Err, v6Err)
func (t DoTTransport) Query(ctx context.Context, msg *dns.Msg) (*dns.Msg, error) {
for _, server := range t.Servers {
m, _, err := t.Client.Exchange(msg, server)
if err == nil && m != nil && m.Rcode == dns.RcodeSuccess {
return m, nil
}
}
return v4, v6, nil
return nil, fmt.Errorf("all DoT servers failed")
}
func doDoHQuery(dohURL, host string, qtype uint16, bypass bool) ([]netip.Addr, error) {
req := &dns.Msg{}
req.Id = dns.Id()
req.RecursionDesired = true
req.SetEdns0(4096, true)
req.SetQuestion(dns.Fqdn(host), qtype)
transport := &http.Transport{
DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
h, port, _ := net.SplitHostPort(addr)
if net.ParseIP(h) == nil {
ips, err := CustomResolver(bypass).LookupIP(ctx, "ip", h)
if err == nil && len(ips) > 0 {
h = ips[0].String()
}
}
return GetDialer(bypass).DialContext(ctx, network, net.JoinHostPort(h, port))
},
}
client := &http.Client{
Transport: transport,
Timeout: 5 * time.Second,
}
wire, err := req.Pack()
func (t DoHTransport) Query(ctx context.Context, msg *dns.Msg) (*dns.Msg, error) {
wire, err := msg.Pack()
if err != nil {
return nil, err
}
httpReq, err := http.NewRequestWithContext(context.Background(), "POST", dohURL, bytes.NewReader(wire))
req, err := http.NewRequestWithContext(
ctx, "POST", t.URL, bytes.NewReader(wire),
)
if err != nil {
return nil, err
}
httpReq.Header.Set("Content-Type", "application/dns-message")
httpReq.Header.Set("Accept", "application/dns-message")
req.Header.Set("Content-Type", "application/dns-message")
req.Header.Set("Accept", "application/dns-message")
req.Host = t.Hostname // important for virtual hosting and cert validation
resp, err := client.Do(httpReq)
resp, err := t.Client.Do(req)
if err != nil {
return nil, err
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
shared.LogError("DNS", "doDoHQuery: DoH server returned HTTP %d for %s", resp.StatusCode, host)
return nil, fmt.Errorf("DoH HTTP %d", resp.StatusCode)
return nil, fmt.Errorf("doh status %d", resp.StatusCode)
}
body, err := io.ReadAll(resp.Body)
@@ -335,158 +314,143 @@ func doDoHQuery(dohURL, host string, qtype uint16, bypass bool) ([]netip.Addr, e
if err := res.Unpack(body); err != nil {
return nil, err
}
if res.Rcode != dns.RcodeSuccess {
shared.LogError("DNS", "doDoHQuery: DoH Rcode %d for %s", res.Rcode, host)
return nil, fmt.Errorf("DoH Rcode %d", res.Rcode)
}
var addrs []netip.Addr
for _, ans := range res.Answer {
if qtype == dns.TypeA {
if a, ok := ans.(*dns.A); ok {
if ip, _ := netip.ParseAddr(a.A.String()); ip.Is4() {
addrs = append(addrs, ip)
}
}
} else if qtype == dns.TypeAAAA {
if aaaa, ok := ans.(*dns.AAAA); ok {
if ip, _ := netip.ParseAddr(aaaa.AAAA.String()); ip.Is6() {
addrs = append(addrs, ip)
}
}
}
}
return addrs, nil
return &res, nil
}
func resolveDoT(host, dotUpstream string, bypass bool) ([]netip.Addr, []netip.Addr, error) {
shared.LogDebug("DNS", "resolveDoT: %s with upstream=%s (bypass=%t)", host, dotUpstream, bypass)
var v4, v6 []netip.Addr
var v4Err, v6Err error
var wg sync.WaitGroup
wg.Add(2)
go func() {
defer wg.Done()
v4, v4Err = doDoTQuery(dotUpstream, host, dns.TypeA, bypass)
}()
go func() {
defer wg.Done()
v6, v6Err = doDoTQuery(dotUpstream, host, dns.TypeAAAA, bypass)
}()
wg.Wait()
if v4Err != nil && v6Err != nil {
shared.LogError("DNS", "resolveDoT failed for %s: both A and AAAA queries failed (bypass=%t)", host, bypass)
return nil, nil, errors.Join(v4Err, v6Err)
}
shared.LogDebug("DNS", "resolveDoT success for %s: %d v4, %d v6 (bypass=%t)", host, len(v4), len(v6), bypass)
return v4, v6, nil
}
func doDoTQuery(dotUpstream, host string, qtype uint16, bypass bool) ([]netip.Addr, error) {
// Normalize upstream to host:port
sni, port, err := net.SplitHostPort(dotUpstream)
if err != nil {
sni = dotUpstream
port = "853"
dotUpstream = net.JoinHostPort(sni, port)
}
// Resolve hostname using bypass resolver
serverAddr, err := resolveServerAddr(context.Background(), dotUpstream, bypass)
if err != nil {
return nil, err
}
req := &dns.Msg{}
req.Id = dns.Id()
req.RecursionDesired = true
req.SetEdns0(4096, true)
req.SetQuestion(dns.Fqdn(host), qtype)
client := &dns.Client{
Net: "tcp-tls",
Dialer: GetDialer(bypass),
Timeout: 5 * time.Second,
TLSConfig: &tls.Config{
ServerName: sni,
InsecureSkipVerify: false,
},
}
res, _, err := client.Exchange(req, serverAddr)
if err != nil {
return nil, err
}
if res.Rcode != dns.RcodeSuccess {
return nil, fmt.Errorf("DoT query failed with Rcode: %d", res.Rcode)
}
var addrs []netip.Addr
for _, ans := range res.Answer {
func parseDNSAnswers(msg *dns.Msg, qtype uint16) []netip.Addr {
var out []netip.Addr
for _, ans := range msg.Answer {
switch qtype {
case dns.TypeA:
if a, ok := ans.(*dns.A); ok {
if ip, _ := netip.ParseAddr(a.A.String()); ip.Is4() {
addrs = append(addrs, ip)
if ip, err := netip.ParseAddr(a.A.String()); err == nil {
out = append(out, ip)
}
}
case dns.TypeAAAA:
if aaaa, ok := ans.(*dns.AAAA); ok {
if ip, _ := netip.ParseAddr(aaaa.AAAA.String()); ip.Is6() {
addrs = append(addrs, ip)
if ip, err := netip.ParseAddr(aaaa.AAAA.String()); err == nil {
out = append(out, ip)
}
}
}
}
return addrs, nil
return out
}
// Resolve runs the correct protocol based on the global config
func Resolve(host string, bypass bool) ([]netip.Addr, []netip.Addr, error) {
cfg := getConfig()
shared.LogDebug("DNS", "Resolve(%s, bypass=%t) protocol=%s upstream=%s", host, bypass, cfg.Protocol, cfg.Upstream)
var v4, v6 []netip.Addr
var err error
switch cfg.Protocol {
case "doh":
v4, v6, err = resolveDoH(host, cfg.Upstream, bypass)
case "dot":
v4, v6, err = resolveDoT(host, cfg.Upstream, bypass)
default:
v4, v6, err = resolvePlain(host, cfg.Upstream, bypass)
}
func Resolve(
ctx context.Context,
host, protocol, upstream string,
bypass bool,
underlying string,
) ([]netip.Addr, []netip.Addr, error) {
t, err := buildTransport(ctx, protocol, upstream, bypass, underlying)
if err != nil {
shared.LogError("DNS", "Final Resolve failed for %s: %v", host, err)
} else {
shared.LogDebug("DNS", "Resolve success for %s: %d v4, %d v6", host, len(v4), len(v6))
return nil, nil, err
}
return v4, v6, err
return resolveHost(ctx, t, host)
}
func CustomResolver(bypass bool) *net.Resolver {
return &net.Resolver{
PreferGo: true,
Dial: func(ctx context.Context, network, address string) (net.Conn, error) {
d := GetDialer(bypass)
return d.DialContext(ctx, network, address)
},
func buildTransport(
ctx context.Context,
protocol, upstream string,
bypass bool,
underlying string,
) (Transport, error) {
switch protocol {
case "doh":
u, err := url.Parse(upstream)
if err != nil {
return nil, err
}
hostname := u.Hostname()
port := u.Port()
if port == "" {
port = "443"
}
u.Host = net.JoinHostPort(hostname, port)
// Pre-resolve with IPv4-first ordering + bypass
servers, _, err := resolveServerAddrs(ctx, u.Host, bypass, "443", underlying)
if err != nil {
return nil, err
}
if len(servers) == 0 {
return nil, fmt.Errorf("no addresses resolved for DoH server")
}
// Custom dialer that tries servers in order (IPv4 → IPv6)
dialer := GetDialer(bypass)
transport := &http.Transport{
DialContext: func(ctx context.Context, network, _ string) (net.Conn, error) {
for _, addr := range servers {
conn, err := dialer.DialContext(ctx, network, addr)
if err == nil {
return conn, nil
}
}
return nil, fmt.Errorf("all DoH addresses failed")
},
TLSClientConfig: &tls.Config{
ServerName: hostname,
},
}
return DoHTransport{
Client: &http.Client{Timeout: 5 * time.Second, Transport: transport},
URL: u.String(),
Servers: servers,
Hostname: hostname,
}, nil
case "dot":
servers, sni, err := resolveServerAddrs(ctx, upstream, bypass, "853", underlying)
if err != nil {
return nil, err
}
if len(servers) == 0 {
return nil, fmt.Errorf("no addresses resolved for DoT server")
}
client := &dns.Client{
Net: "tcp-tls",
Dialer: GetDialer(bypass),
Timeout: 5 * time.Second,
TLSConfig: &tls.Config{
ServerName: sni,
},
}
return DoTTransport{
Client: client,
Servers: servers,
}, nil
default: // plain DNS
_, addr, err := parseUpstream(upstream)
if err != nil {
return nil, err
}
servers, _, err := resolveServerAddrs(ctx, addr, bypass, "53", underlying)
if err != nil {
return nil, err
}
client := &dns.Client{
Net: "udp",
Dialer: GetDialer(bypass),
Timeout: 5 * time.Second,
}
return PlainTransport{
Client: client,
Servers: servers,
}, nil
}
}
func GetDialer(bypass bool) *net.Dialer {
if !bypass {
return &net.Dialer{
LocalAddr: nil,
}
return &net.Dialer{LocalAddr: nil}
}
return &net.Dialer{
Control: func(network, address string, c syscall.RawConn) error {
var opErr error
+28 -31
View File
@@ -3,54 +3,51 @@
struct go_string { const char *str; long n; };
extern void SetDNSConfig(struct go_string handle);
extern char* ResolveBootstrap(const char* host, int bypass);
JNIEXPORT void JNICALL Java_com_zaneschepke_tunnel_DnsConfigManager_setDNSConfig(
JNIEnv* env, jclass clazz, jstring json)
{
if (json == NULL) {
return;
}
const char* cjson = (*env)->GetStringUTFChars(env, json, 0);
if (cjson != NULL) {
size_t len = (*env)->GetStringUTFLength(env, json);
SetDNSConfig((struct go_string){
.str = cjson,
.n = (long)len
});
(*env)->ReleaseStringUTFChars(env, json, cjson);
}
}
extern char* ResolveBootstrap(
const char* host,
const char* protocol,
const char* upstream,
const char* underlyingDnsServers,
int bypass);
JNIEXPORT jstring JNICALL
Java_com_zaneschepke_tunnel_DnsConfigManager_resolveBootstrap(
JNIEnv* env,
jclass clazz,
jstring host,
jboolean bypass)
jstring protocol,
jstring upstream,
jstring underlyingDnsServers,
jint bypass)
{
if (host == NULL) {
return (*env)->NewStringUTF(env, "{\"error\":\"invalid host\"}");
if (host == NULL || protocol == NULL || upstream == NULL || underlyingDnsServers == NULL) {
return (*env)->NewStringUTF(env, "ERR|invalid arguments");
}
const char* chost = (*env)->GetStringUTFChars(env, host, NULL);
if (chost == NULL) {
return (*env)->NewStringUTF(env, "{\"error\":\"out of memory\"}");
const char* chost = (*env)->GetStringUTFChars(env, host, NULL);
const char* cprotocol = (*env)->GetStringUTFChars(env, protocol, NULL);
const char* cupstream = (*env)->GetStringUTFChars(env, upstream, NULL);
const char* cunderlying = (*env)->GetStringUTFChars(env, underlyingDnsServers, NULL);
if (chost == NULL || cprotocol == NULL || cupstream == NULL || cunderlying == NULL) {
return (*env)->NewStringUTF(env, "ERR|out of memory");
}
char* resultC = ResolveBootstrap(
(char*)chost,
bypass ? 1 : 0
chost,
cprotocol,
cupstream,
cunderlying,
bypass ? 1 : 0
);
(*env)->ReleaseStringUTFChars(env, host, chost);
(*env)->ReleaseStringUTFChars(env, protocol, cprotocol);
(*env)->ReleaseStringUTFChars(env, upstream, cupstream);
(*env)->ReleaseStringUTFChars(env, underlyingDnsServers, cunderlying);
if (resultC == NULL) {
return (*env)->NewStringUTF(env, "{\"error\":\"null response\"}");
return (*env)->NewStringUTF(env, "ERR|null response");
}
jstring jresult = (*env)->NewStringUTF(env, resultC);