fix: dns server boostrapping with underlying dns servers

#1241
This commit is contained in:
zaneschepke
2026-05-24 01:15:35 -04:00
parent 68dc57422c
commit bf432cca0d
12 changed files with 475 additions and 388 deletions
@@ -32,6 +32,7 @@ import kotlinx.coroutines.sync.withLock
class TunnelCoordinator( class TunnelCoordinator(
private val tunnelProvider: TunnelProvider, private val tunnelProvider: TunnelProvider,
private val serviceManager: ServiceManager, private val serviceManager: ServiceManager,
private val bootstrapCoordinator: AppBoostrapCoordinator,
settingsRepository: GeneralSettingRepository, settingsRepository: GeneralSettingRepository,
private val tunnelRepository: TunnelRepository, private val tunnelRepository: TunnelRepository,
dnsSettingsRepository: RoomDnsSettingsRepository, dnsSettingsRepository: RoomDnsSettingsRepository,
@@ -86,7 +87,11 @@ class TunnelCoordinator(
suspend fun startTunnel( suspend fun startTunnel(
config: TunnelConfig, config: TunnelConfig,
source: TunnelActionSource = TunnelActionSource.USER, source: TunnelActionSource = TunnelActionSource.USER,
) = tunnelMutex.withLock { startTunnelInternal(config, source) } ) = tunnelMutex.withLock {
// wait for app to be bootstrapped
bootstrapCoordinator.isReady.first { it }
startTunnelInternal(config, source)
}
suspend fun stopTunnel(id: Int, source: TunnelActionSource = TunnelActionSource.USER) = suspend fun stopTunnel(id: Int, source: TunnelActionSource = TunnelActionSource.USER) =
tunnelMutex.withLock { tunnelMutex.withLock {
@@ -24,6 +24,7 @@ val coordinatorModule = module {
get(), get(),
get(), get(),
get(), get(),
get(),
get(named(Scope.APPLICATION)), get(named(Scope.APPLICATION)),
) )
} }
@@ -3,25 +3,27 @@ package com.zaneschepke.tunnel
import com.zaneschepke.tunnel.model.DnsBootstrapResult import com.zaneschepke.tunnel.model.DnsBootstrapResult
import kotlinx.coroutines.Dispatchers import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.withContext import kotlinx.coroutines.withContext
import org.json.JSONObject
internal object DnsConfigManager { internal object DnsConfigManager {
private external fun setDNSConfig(configJson: String)
fun update(protocol: String, upstream: String) { private external fun resolveBootstrap(
val config = host: String,
JSONObject().apply { protocol: String,
put("protocol", protocol) upstream: String,
put("upstream", upstream) underlyingDnsServers: String,
} bypass: Int,
setDNSConfig(config.toString()) ): String
}
private external fun resolveBootstrap(host: String, bypass: Int): String suspend fun resolveHostBootstrap(
host: String,
suspend fun resolveHostBootstrap(host: String, bypass: Boolean): DnsBootstrapResult = protocol: String,
upstream: String,
underlyingDnsServers: String,
bypass: Boolean,
): DnsBootstrapResult =
withContext(Dispatchers.IO) { withContext(Dispatchers.IO) {
val raw = resolveBootstrap(host, if (bypass) 1 else 0) val bypassOption = if (bypass) 1 else 0
val raw = resolveBootstrap(host, protocol, upstream, underlyingDnsServers, bypassOption)
if (raw.startsWith("ERR|")) { if (raw.startsWith("ERR|")) {
throw RuntimeException(raw.removePrefix("ERR|")) throw RuntimeException(raw.removePrefix("ERR|"))
@@ -5,6 +5,7 @@ import com.zaneschepke.tunnel.DnsConfigManager
import com.zaneschepke.tunnel.Tunnel import com.zaneschepke.tunnel.Tunnel
import com.zaneschepke.tunnel.event.ActorEvent import com.zaneschepke.tunnel.event.ActorEvent
import com.zaneschepke.tunnel.event.ActorEvent.ActiveConfigUpdated import com.zaneschepke.tunnel.event.ActorEvent.ActiveConfigUpdated
import com.zaneschepke.tunnel.event.ActorEvent.BootstrapConfigUpdated
import com.zaneschepke.tunnel.event.ActorEvent.BootstrapStateChanged import com.zaneschepke.tunnel.event.ActorEvent.BootstrapStateChanged
import com.zaneschepke.tunnel.event.ActorEvent.EngineStatus import com.zaneschepke.tunnel.event.ActorEvent.EngineStatus
import com.zaneschepke.tunnel.event.ActorEvent.KillSwitchStateChanged import com.zaneschepke.tunnel.event.ActorEvent.KillSwitchStateChanged
@@ -12,7 +13,9 @@ import com.zaneschepke.tunnel.event.ActorEvent.PeersUpdated
import com.zaneschepke.tunnel.event.ActorEvent.ResolvedPeersApplied import com.zaneschepke.tunnel.event.ActorEvent.ResolvedPeersApplied
import com.zaneschepke.tunnel.event.ActorEvent.TunnelStarted import com.zaneschepke.tunnel.event.ActorEvent.TunnelStarted
import com.zaneschepke.tunnel.event.ActorEvent.TunnelStopped import com.zaneschepke.tunnel.event.ActorEvent.TunnelStopped
import com.zaneschepke.tunnel.event.ActorEvent.UnderlyingDnsServersUpdated
import com.zaneschepke.tunnel.event.TunnelEvent import com.zaneschepke.tunnel.event.TunnelEvent
import com.zaneschepke.tunnel.event.TunnelEvent.NoRootShellAccess
import com.zaneschepke.tunnel.model.BackendMode import com.zaneschepke.tunnel.model.BackendMode
import com.zaneschepke.tunnel.model.DnsBootstrapResult import com.zaneschepke.tunnel.model.DnsBootstrapResult
import com.zaneschepke.tunnel.model.PublicKey import com.zaneschepke.tunnel.model.PublicKey
@@ -122,6 +125,10 @@ internal class TunnelActor(
engine.stop(runtime.running.handle, runtime.running.mode) engine.stop(runtime.running.handle, runtime.running.mode)
} }
is TunnelCommand.SetBootstrapConfig -> {
apply(BootstrapConfigUpdated(cmd.config))
}
is TunnelCommand.UpdatePeers -> { is TunnelCommand.UpdatePeers -> {
val runtime = _state.value.byTunnelId[cmd.tunnelId] ?: continue val runtime = _state.value.byTunnelId[cmd.tunnelId] ?: continue
val running = runtime.running val running = runtime.running
@@ -209,9 +216,7 @@ internal class TunnelActor(
} catch (t: Throwable) { } catch (t: Throwable) {
Timber.w(t, "Root shell commands failed") Timber.w(t, "Root shell commands failed")
if (t is RootShellException.NoRootAccess) { if (t is RootShellException.NoRootAccess) {
_events.emit( _events.emit(NoRootShellAccess(tunnelId = cmd.tunnelId))
TunnelEvent.NoRootShellAccess(tunnelId = cmd.tunnelId)
)
} }
} finally { } finally {
if (isPostDown) { if (isPostDown) {
@@ -219,6 +224,10 @@ internal class TunnelActor(
} }
} }
} }
is TunnelCommand.UpdateUnderlyingDnsServers -> {
apply(UnderlyingDnsServersUpdated(cmd.servers))
}
} }
} catch (t: Throwable) { } catch (t: Throwable) {
Timber.e(t, "Tunnel command failed: $cmd") Timber.e(t, "Tunnel command failed: $cmd")
@@ -621,6 +630,14 @@ internal class TunnelActor(
is KillSwitchStateChanged -> { is KillSwitchStateChanged -> {
state.copy(killSwitchEnabled = event.enabled) state.copy(killSwitchEnabled = event.enabled)
} }
is BootstrapConfigUpdated -> {
state.copy(dnsConfig = event.config)
}
is UnderlyingDnsServersUpdated -> {
state.copy(dnsConfig = state.dnsConfig.copy(underlyingDnsServers = event.servers))
}
} }
} }
@@ -667,9 +684,17 @@ internal class TunnelActor(
val endpoint = peer.endpoint ?: continue val endpoint = peer.endpoint ?: continue
val host = endpoint.substringBeforeLast(":") val host = endpoint.substringBeforeLast(":")
val dnsConfig = state.value.dnsConfig
val dnsResult = val dnsResult =
try { try {
DnsConfigManager.resolveHostBootstrap(host = host, bypass = bypassNeeded) DnsConfigManager.resolveHostBootstrap(
host = host,
dnsConfig.protocol,
dnsConfig.upstream,
dnsConfig.underlyingDnsServers,
bypass = bypassNeeded,
)
} catch (e: Exception) { } catch (e: Exception) {
Timber.w(e, "DNS failed for $host") Timber.w(e, "DNS failed for $host")
continue continue
@@ -1,8 +1,9 @@
package com.zaneschepke.tunnel.backend package com.zaneschepke.tunnel.backend
import com.zaneschepke.networkmonitor.ActiveNetwork
import com.zaneschepke.networkmonitor.DnsInfo
import com.zaneschepke.networkmonitor.NetworkMonitor import com.zaneschepke.networkmonitor.NetworkMonitor
import com.zaneschepke.networkmonitor.PrivateDnsMode import com.zaneschepke.networkmonitor.PrivateDnsMode
import com.zaneschepke.tunnel.DnsConfigManager
import com.zaneschepke.tunnel.NotificationProvider import com.zaneschepke.tunnel.NotificationProvider
import com.zaneschepke.tunnel.Tunnel import com.zaneschepke.tunnel.Tunnel
import com.zaneschepke.tunnel.event.TunnelEvent import com.zaneschepke.tunnel.event.TunnelEvent
@@ -14,6 +15,7 @@ import com.zaneschepke.tunnel.model.TunnelCommand
import com.zaneschepke.tunnel.service.VpnService import com.zaneschepke.tunnel.service.VpnService
import com.zaneschepke.tunnel.state.BackendStatus import com.zaneschepke.tunnel.state.BackendStatus
import com.zaneschepke.tunnel.state.KillSwitchState import com.zaneschepke.tunnel.state.KillSwitchState
import com.zaneschepke.tunnel.state.RuntimeDnsConfig
import com.zaneschepke.tunnel.util.BackendException import com.zaneschepke.tunnel.util.BackendException
import java.lang.ref.WeakReference import java.lang.ref.WeakReference
import kotlin.reflect.KClass import kotlin.reflect.KClass
@@ -27,8 +29,10 @@ import kotlinx.coroutines.flow.asSharedFlow
import kotlinx.coroutines.flow.asStateFlow import kotlinx.coroutines.flow.asStateFlow
import kotlinx.coroutines.flow.distinctUntilChangedBy import kotlinx.coroutines.flow.distinctUntilChangedBy
import kotlinx.coroutines.flow.first import kotlinx.coroutines.flow.first
import kotlinx.coroutines.flow.firstOrNull
import kotlinx.coroutines.flow.update import kotlinx.coroutines.flow.update
import kotlinx.coroutines.launch import kotlinx.coroutines.launch
import kotlinx.coroutines.withTimeoutOrNull
import org.koin.java.KoinJavaComponent.inject import org.koin.java.KoinJavaComponent.inject
import timber.log.Timber import timber.log.Timber
@@ -191,24 +195,90 @@ class TunnelBackend(
override suspend fun setBootstrapDnsMode(mode: DnsBoostrapMode) { override suspend fun setBootstrapDnsMode(mode: DnsBoostrapMode) {
_status.update { it.copy(dnsMode = mode) } _status.update { it.copy(dnsMode = mode) }
dnsConfigJob?.cancel()
dnsConfigJob = null
when (mode) { when (mode) {
is DnsBoostrapMode.Custom -> { is DnsBoostrapMode.Custom -> {
Timber.d("DNS Boostrap mode set to custom, disabling system dns monitoring") Timber.d(
dnsConfigJob?.cancel() "DNS Bootstrap mode set to custom: ${mode.config.protocol} -> ${mode.config.upstream}"
dnsConfigJob = null
DnsConfigManager.update(
mode.config.protocol,
mode.config.upstream ?: DnsBoostrapConfig.DEFAULT_UPSTREAM,
) )
emitInitialDnsConfig(mode.config.protocol, mode.config.upstream)
startUnderlyingDnsMonitoring()
} }
DnsBoostrapMode.System -> { DnsBoostrapMode.System -> {
Timber.d("DNS Bootstrap mode set to System")
emitInitialDnsConfig()
startSystemDnsMonitoring() startSystemDnsMonitoring()
} }
} }
} }
private suspend fun emitInitialDnsConfig(protocol: String? = null, upstream: String? = null) {
val state =
withTimeoutOrNull(2_500L) {
networkMonitor.connectivityStateFlow.first { connectivityState ->
val dns = connectivityState.underlyingDnsInfo
dns.servers.isNotEmpty() ||
connectivityState.activeNetwork is ActiveNetwork.Disconnected
}
} ?: networkMonitor.connectivityStateFlow.firstOrNull()
val dns = state?.underlyingDnsInfo ?: DnsInfo()
val finalProtocol =
protocol
?: when (dns.privateDnsMode) {
PrivateDnsMode.HOSTNAME -> "dot"
else -> "plain"
}
val finalUpstream =
upstream
?: when (dns.privateDnsMode) {
PrivateDnsMode.HOSTNAME ->
dns.privateDnsHostname?.takeIf { it.isNotBlank() }
?: DnsBoostrapConfig.DEFAULT_UPSTREAM
else -> dns.servers.firstOrNull() ?: DnsBoostrapConfig.DEFAULT_UPSTREAM
}
val underlying =
dns.servers.joinToString(",").ifBlank { DnsBoostrapConfig.DEFAULT_UPSTREAM }
Timber.d(
"DNS initial emission: protocol=$finalProtocol upstream=$finalUpstream underlying=$underlying"
)
actor.send(
TunnelCommand.SetBootstrapConfig(
RuntimeDnsConfig(
protocol = finalProtocol,
upstream = finalUpstream,
underlyingDnsServers = underlying,
)
)
)
}
private fun startUnderlyingDnsMonitoring() {
if (dnsConfigJob?.isActive == true) return
dnsConfigJob = scope.launch {
networkMonitor.connectivityStateFlow
.distinctUntilChangedBy { it.underlyingDnsInfo.servers }
.collect { state ->
val dns = state.underlyingDnsInfo
val underlying = dns.servers.joinToString(",")
Timber.d("Underlying DNS servers changed: $underlying")
actor.send(TunnelCommand.UpdateUnderlyingDnsServers(underlying))
}
}
}
override fun emergencyStopAllOfTypeSync(modeClass: KClass<out BackendMode>) { override fun emergencyStopAllOfTypeSync(modeClass: KClass<out BackendMode>) {
actor.emergencyStopAllOfType(modeClass) actor.emergencyStopAllOfType(modeClass)
} }
@@ -236,13 +306,12 @@ class TunnelBackend(
val config = val config =
when (dns.privateDnsMode) { when (dns.privateDnsMode) {
PrivateDnsMode.OFF, PrivateDnsMode.OFF,
PrivateDnsMode.AUTOMATIC -> { PrivateDnsMode.AUTOMATIC ->
DnsBoostrapConfig.Plain( DnsBoostrapConfig.Plain(
dns.servers.firstOrNull() ?: DnsBoostrapConfig.DEFAULT_UPSTREAM dns.servers.firstOrNull() ?: DnsBoostrapConfig.DEFAULT_UPSTREAM
) )
}
PrivateDnsMode.HOSTNAME -> { PrivateDnsMode.HOSTNAME ->
dns.privateDnsHostname dns.privateDnsHostname
?.takeIf { it.isNotBlank() } ?.takeIf { it.isNotBlank() }
?.let { DnsBoostrapConfig.DoT(it) } ?.let { DnsBoostrapConfig.DoT(it) }
@@ -250,12 +319,16 @@ class TunnelBackend(
dns.servers.firstOrNull() dns.servers.firstOrNull()
?: DnsBoostrapConfig.DEFAULT_UPSTREAM ?: DnsBoostrapConfig.DEFAULT_UPSTREAM
) )
}
} }
DnsConfigManager.update( actor.send(
config.protocol, TunnelCommand.SetBootstrapConfig(
config.upstream ?: DnsBoostrapConfig.DEFAULT_UPSTREAM, RuntimeDnsConfig(
protocol = config.protocol,
upstream = config.upstream ?: DnsBoostrapConfig.DEFAULT_UPSTREAM,
underlyingDnsServers = dns.servers.joinToString(","),
)
)
) )
} }
} }
@@ -6,6 +6,7 @@ import com.zaneschepke.tunnel.model.TunnelCommand
import com.zaneschepke.tunnel.state.BootstrapState import com.zaneschepke.tunnel.state.BootstrapState
import com.zaneschepke.tunnel.state.EngineStartResult import com.zaneschepke.tunnel.state.EngineStartResult
import com.zaneschepke.tunnel.state.NativeTunnelStatus import com.zaneschepke.tunnel.state.NativeTunnelStatus
import com.zaneschepke.tunnel.state.RuntimeDnsConfig
import com.zaneschepke.wireguardautotunnel.parser.ActiveConfig import com.zaneschepke.wireguardautotunnel.parser.ActiveConfig
import com.zaneschepke.wireguardautotunnel.parser.PeerSection import com.zaneschepke.wireguardautotunnel.parser.PeerSection
@@ -23,6 +24,10 @@ sealed class ActorEvent {
val preferIpv6: Boolean, val preferIpv6: Boolean,
) : ActorEvent() ) : ActorEvent()
data class BootstrapConfigUpdated(val config: RuntimeDnsConfig) : ActorEvent()
data class UnderlyingDnsServersUpdated(val servers: String) : ActorEvent()
data class ResolvedPeersApplied( data class ResolvedPeersApplied(
val tunnelId: Int, val tunnelId: Int,
val cache: Map<PublicKey, DnsBootstrapResult>, val cache: Map<PublicKey, DnsBootstrapResult>,
@@ -17,7 +17,7 @@ sealed class DnsBoostrapConfig(open val upstream: String?) {
data class DoH(override val upstream: String?) : DnsBoostrapConfig(upstream) { data class DoH(override val upstream: String?) : DnsBoostrapConfig(upstream) {
override val protocol: String override val protocol: String
get() = "dot" get() = "doh"
} }
data class DoT(override val upstream: String?) : DnsBoostrapConfig(upstream) { data class DoT(override val upstream: String?) : DnsBoostrapConfig(upstream) {
@@ -2,6 +2,7 @@ package com.zaneschepke.tunnel.model
import com.zaneschepke.tunnel.Tunnel import com.zaneschepke.tunnel.Tunnel
import com.zaneschepke.tunnel.state.BootstrapState import com.zaneschepke.tunnel.state.BootstrapState
import com.zaneschepke.tunnel.state.RuntimeDnsConfig
import com.zaneschepke.wireguardautotunnel.parser.PeerSection import com.zaneschepke.wireguardautotunnel.parser.PeerSection
sealed class TunnelCommand { sealed class TunnelCommand {
@@ -24,6 +25,10 @@ sealed class TunnelCommand {
data class SetBootstrapState(val tunnelId: Int, val state: BootstrapState) : TunnelCommand() data class SetBootstrapState(val tunnelId: Int, val state: BootstrapState) : TunnelCommand()
data class SetBootstrapConfig(val config: RuntimeDnsConfig) : TunnelCommand()
data class UpdateUnderlyingDnsServers(val servers: String) : TunnelCommand()
data class RunHook(val tunnelId: Int, val phase: Phase, val cmds: List<String>?) : data class RunHook(val tunnelId: Int, val phase: Phase, val cmds: List<String>?) :
TunnelCommand() { TunnelCommand() {
enum class Phase { enum class Phase {
@@ -4,4 +4,5 @@ data class ActorState(
val byTunnelId: Map<Int, TunnelRuntimeState>, val byTunnelId: Map<Int, TunnelRuntimeState>,
val byHandle: Map<Int, Int>, val byHandle: Map<Int, Int>,
val killSwitchEnabled: Boolean = false, val killSwitchEnabled: Boolean = false,
val dnsConfig: RuntimeDnsConfig = RuntimeDnsConfig(),
) )
@@ -0,0 +1,9 @@
package com.zaneschepke.tunnel.state
import com.zaneschepke.tunnel.model.DnsBoostrapConfig
data class RuntimeDnsConfig(
val protocol: String = "plain",
val upstream: String = DnsBoostrapConfig.DEFAULT_UPSTREAM,
val underlyingDnsServers: String = DnsBoostrapConfig.DEFAULT_UPSTREAM,
)
+286 -322
View File
@@ -9,7 +9,6 @@ import (
"bytes" "bytes"
"context" "context"
"crypto/tls" "crypto/tls"
"encoding/json"
"errors" "errors"
"fmt" "fmt"
"io" "io"
@@ -18,7 +17,6 @@ import (
"net/netip" "net/netip"
"net/url" "net/url"
"strings" "strings"
"sync"
"syscall" "syscall"
"time" "time"
@@ -27,21 +25,6 @@ import (
"golang.org/x/sys/unix" "golang.org/x/sys/unix"
) )
const defaultPlain = "udp://1.1.1.1:53"
var (
currentConfig DNSConfig = DNSConfig{
"plain",
"1.1.1.1:53",
}
configMu sync.RWMutex
)
type DNSConfig struct {
Protocol string `json:"protocol"` // plain, doh, or dot
Upstream string `json:"upstream"`
}
type Resolved struct { type Resolved struct {
V4 []netip.Addr V4 []netip.Addr
V6 []netip.Addr V6 []netip.Addr
@@ -52,36 +35,34 @@ type ResolverOptions struct {
Timeout time.Duration Timeout time.Duration
} }
func DefaultOptions() ResolverOptions { type Transport interface {
return ResolverOptions{ Query(ctx context.Context, msg *dns.Msg) (*dns.Msg, error)
UpstreamURL: defaultPlain,
Timeout: 5 * time.Second,
}
}
//export SetDNSConfig
func SetDNSConfig(config string) {
var cfg DNSConfig
if err := json.Unmarshal([]byte(config), &cfg); err != nil {
shared.LogError("DNS", "Failed to parse DNSConfig: %v", err)
return
}
if cfg.Protocol != "plain" && cfg.Protocol != "doh" && cfg.Protocol != "dot" {
cfg.Protocol = "plain"
}
configMu.Lock()
currentConfig = cfg
configMu.Unlock()
shared.LogDebug("DNS", "DNS config updated: %s %s", cfg.Protocol, cfg.Upstream)
} }
//export ResolveBootstrap //export ResolveBootstrap
func ResolveBootstrap(host *C.char, bypass C.int) *C.char { func ResolveBootstrap(
host *C.char,
protocol *C.char,
upstream *C.char,
underlyingDnsServers *C.char,
bypass C.int,
) *C.char {
h := C.GoString(host) h := C.GoString(host)
p := C.GoString(protocol)
u := C.GoString(upstream)
underlying := C.GoString(underlyingDnsServers)
bp := bypass == 1 bp := bypass == 1
shared.LogDebug("DNS", "ResolveBootstrap called for host=%s (bypass=%t)", h, bp)
v4, v6, err := Resolve(h, bp) ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
shared.LogDebug(
"DNS",
"ResolveBootstrap called host=%s protocol=%s upstream=%s bypass=%t",
h, p, u, bp,
)
v4, v6, err := Resolve(ctx, h, p, u, bp, underlying)
if err != nil { if err != nil {
shared.LogError("DNS", "ResolveBootstrap failed for %s: %v", h, err) shared.LogError("DNS", "ResolveBootstrap failed for %s: %v", h, err)
return C.CString("ERR|" + err.Error()) return C.CString("ERR|" + err.Error())
@@ -96,15 +77,76 @@ func ResolveBootstrap(host *C.char, bypass C.int) *C.char {
v6Str[i] = ip.String() v6Str[i] = ip.String()
} }
result := "v4=" + strings.Join(v4Str, ",") + ";v6=" + strings.Join(v6Str, ",") result := "v4=" + strings.Join(v4Str, ",") +
";v6=" + strings.Join(v6Str, ",")
shared.LogDebug("DNS", "ResolveBootstrap success for %s: %s", h, result) shared.LogDebug("DNS", "ResolveBootstrap success for %s: %s", h, result)
return C.CString(result) return C.CString(result)
} }
func getConfig() DNSConfig { type DoTTransport struct {
configMu.RLock() Client *dns.Client
defer configMu.RUnlock() Servers []string
return currentConfig }
type DoHTransport struct {
Client *http.Client
URL string
Servers []string // IPv4 first, IPv6 fallback
Hostname string // for SNI and Host header
}
type PlainTransport struct {
Client *dns.Client
Servers []string
}
func resolveHost(
ctx context.Context,
t Transport,
host string,
) (v4, v6 []netip.Addr, err error) {
a4, e4 := resolveQ(ctx, t, host, dns.TypeA)
if e4 == nil {
v4 = a4
}
a6, e6 := resolveQ(ctx, t, host, dns.TypeAAAA)
if e6 == nil {
v6 = a6
}
if len(v4) > 0 || len(v6) > 0 {
return v4, v6, nil
}
return nil, nil, errors.Join(e4, e6)
}
func resolveQ(
ctx context.Context,
t Transport,
host string,
qtype uint16,
) ([]netip.Addr, error) {
req := &dns.Msg{}
req.SetQuestion(dns.Fqdn(host), qtype)
req.SetEdns0(4096, true)
res, err := t.Query(ctx, req)
if err != nil {
return nil, err
}
if res == nil {
return nil, fmt.Errorf("nil DNS response")
}
if res.Rcode != dns.RcodeSuccess {
return nil, fmt.Errorf("rcode %d", res.Rcode)
}
addrs := parseDNSAnswers(res, qtype)
if len(addrs) == 0 {
return nil, fmt.Errorf("no answers for qtype %d", qtype)
}
return addrs, nil
} }
func parseUpstream(upstreamURL string) (network, address string, err error) { func parseUpstream(upstreamURL string) (network, address string, err error) {
@@ -113,7 +155,6 @@ func parseUpstream(upstreamURL string) (network, address string, err error) {
if !strings.Contains(u, "://") { if !strings.Contains(u, "://") {
u = "udp://" + u u = "udp://" + u
} }
parsed, err := url.Parse(u) parsed, err := url.Parse(u)
if err != nil { if err != nil {
shared.LogError("DNS", "parseUpstream failed for %q: %v", upstreamURL, err) shared.LogError("DNS", "parseUpstream failed for %q: %v", upstreamURL, err)
@@ -141,189 +182,127 @@ func parseUpstream(upstreamURL string) (network, address string, err error) {
return network, address, nil return network, address, nil
} }
func resolveServerAddr(ctx context.Context, address string, bypass bool) (string, error) { func newUnderlyingResolver(bypass bool, underlying string) *net.Resolver {
if !bypass {
return &net.Resolver{PreferGo: false}
}
rawServers := strings.Split(underlying, ",")
var servers []string
for _, s := range rawServers {
s = strings.TrimSpace(s)
if s == "" {
continue
}
if !strings.Contains(s, ":") {
s = net.JoinHostPort(s, "53")
}
servers = append(servers, s)
}
if len(servers) == 0 {
servers = []string{"1.1.1.1:53"}
}
return &net.Resolver{
PreferGo: true,
Dial: func(ctx context.Context, network, _ string) (net.Conn, error) {
for _, server := range servers {
conn, err := GetDialer(true).DialContext(ctx, network, server)
if err == nil {
shared.LogDebug("DNS", "Using underlying bootstrap resolver: %s", server)
return conn, nil
}
shared.LogDebug("DNS", "Bootstrap resolver failed for %s: %v", server, err)
}
return nil, fmt.Errorf("all underlying DNS servers failed")
},
}
}
func resolveServerAddrs(
ctx context.Context,
address string,
bypass bool,
defaultPort string,
underlying string,
) ([]string, string, error) {
host, port, err := net.SplitHostPort(address) host, port, err := net.SplitHostPort(address)
if err != nil { if err != nil {
shared.LogError("DNS", "resolveServerAddr: invalid address %q: %v", address, err) host = address
return "", err port = defaultPort
} }
if net.ParseIP(host) != nil { if net.ParseIP(host) != nil {
return address, nil return []string{net.JoinHostPort(host, port)}, host, nil
}
shared.LogDebug("DNS", "resolveServerAddr: bootstrapping upstream hostname %s (bypass=%t)", host, bypass)
bootstrapDialer := GetDialer(bypass)
resolver := &net.Resolver{
PreferGo: true,
Dial: func(ctx context.Context, network, _ string) (net.Conn, error) {
return bootstrapDialer.DialContext(ctx, network, "1.1.1.1:53")
},
} }
resolver := newUnderlyingResolver(bypass, underlying)
ips, err := resolver.LookupIP(ctx, "ip", host) ips, err := resolver.LookupIP(ctx, "ip", host)
if err != nil { if err != nil {
shared.LogError("DNS", "Failed to resolve upstream hostname %s (bypass=%t): %v", host, bypass, err) shared.LogError("DNS", "Failed to resolve upstream %s (bypass=%t): %v", host, bypass, err)
return "", fmt.Errorf("failed to resolve upstream hostname %s: %w", host, err) return nil, "", err
}
if len(ips) == 0 {
err = errors.New("no IPs found for upstream hostname")
shared.LogError("DNS", "%v for %s", err, host)
return "", err
} }
addr := net.JoinHostPort(ips[0].String(), port) var v4, v6 []string
shared.LogDebug("DNS", "Resolved upstream %s -> %s", host, addr) for _, ip := range ips {
return addr, nil addr := net.JoinHostPort(ip.String(), port)
} if ip.To4() != nil {
func resolveInner(host string, ipType uint16, network, serverAddr string, bypass bool) ([]netip.Addr, error) { v4 = append(v4, addr)
req := &dns.Msg{} } else {
req.Id = dns.Id() v6 = append(v6, addr)
req.RecursionDesired = true
req.SetQuestion(dns.Fqdn(host), ipType)
req.SetEdns0(4096, true)
client := &dns.Client{
Net: network,
Dialer: GetDialer(bypass),
Timeout: 5 * time.Second,
UDPSize: 4096,
}
res, _, err := client.Exchange(req, serverAddr)
if err != nil {
shared.LogError("DNS", "resolveInner: DNS exchange failed for %s (type=%d, server=%s, bypass=%t): %v", host, ipType, serverAddr, bypass, err)
return nil, err
}
if res.Rcode != dns.RcodeSuccess {
shared.LogError("DNS", "resolveInner: DNS query failed with Rcode %d for %s", res.Rcode, host)
return nil, fmt.Errorf("DNS query failed with Rcode: %d", res.Rcode)
}
var addr []netip.Addr
for _, ans := range res.Answer {
switch ipType {
case dns.TypeA:
if a, ok := ans.(*dns.A); ok {
if ip, err := netip.ParseAddr(a.A.String()); err == nil {
addr = append(addr, ip)
}
}
case dns.TypeAAAA:
if aaaa, ok := ans.(*dns.AAAA); ok {
if ip, err := netip.ParseAddr(aaaa.AAAA.String()); err == nil {
addr = append(addr, ip)
}
}
} }
} }
return addr, nil
return append(v4, v6...), host, nil
} }
func resolvePlain(host, upstreamURL string, bypass bool) ([]netip.Addr, []netip.Addr, error) { func (t PlainTransport) Query(ctx context.Context, msg *dns.Msg) (*dns.Msg, error) {
shared.LogDebug("DNS", "resolvePlain: %s with upstream=%s (bypass=%t)", host, upstreamURL, bypass) for _, server := range t.Servers {
network, addr, err := parseUpstream(upstreamURL) m, _, err := t.Client.Exchange(msg, server)
if err != nil { if err == nil && m != nil && m.Rcode == dns.RcodeSuccess {
return nil, nil, err return m, nil
}
} }
return nil, fmt.Errorf("all DNS servers failed")
serverAddr, err := resolveServerAddr(context.Background(), addr, bypass)
if err != nil {
return nil, nil, err
}
var wg sync.WaitGroup
var v4, v6 []netip.Addr
var v4Err, v6Err error
wg.Add(2)
go func() { v4, v4Err = resolveInner(host, dns.TypeA, network, serverAddr, bypass); wg.Done() }()
go func() { v6, v6Err = resolveInner(host, dns.TypeAAAA, network, serverAddr, bypass); wg.Done() }()
wg.Wait()
if v4Err != nil && v6Err != nil {
shared.LogError("DNS", "resolvePlain failed for %s: both A and AAAA failed", host)
return nil, nil, errors.Join(v4Err, v6Err)
}
if len(v4) == 0 && len(v6) == 0 {
err = errors.New("no IP addresses found")
shared.LogError("DNS", "%v for %s", err, host)
return nil, nil, err
}
return v4, v6, nil
} }
func resolveDoH(host, dohURL string, bypass bool) ([]netip.Addr, []netip.Addr, error) { func (t DoTTransport) Query(ctx context.Context, msg *dns.Msg) (*dns.Msg, error) {
shared.LogDebug("DNS", "Resolving DOH: %s with %s", host, dohURL) for _, server := range t.Servers {
var v4, v6 []netip.Addr m, _, err := t.Client.Exchange(msg, server)
var v4Err, v6Err error if err == nil && m != nil && m.Rcode == dns.RcodeSuccess {
return m, nil
var wg sync.WaitGroup }
wg.Add(2)
go func() {
defer wg.Done()
v4, v4Err = doDoHQuery(dohURL, host, dns.TypeA, bypass)
}()
go func() {
defer wg.Done()
v6, v6Err = doDoHQuery(dohURL, host, dns.TypeAAAA, bypass)
}()
wg.Wait()
if v4Err != nil && v6Err != nil {
shared.LogError("DNS", "resolveDoH failed for %s: both queries failed", host)
return nil, nil, errors.Join(v4Err, v6Err)
} }
return v4, v6, nil return nil, fmt.Errorf("all DoT servers failed")
} }
func doDoHQuery(dohURL, host string, qtype uint16, bypass bool) ([]netip.Addr, error) { func (t DoHTransport) Query(ctx context.Context, msg *dns.Msg) (*dns.Msg, error) {
req := &dns.Msg{} wire, err := msg.Pack()
req.Id = dns.Id()
req.RecursionDesired = true
req.SetEdns0(4096, true)
req.SetQuestion(dns.Fqdn(host), qtype)
transport := &http.Transport{
DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
h, port, _ := net.SplitHostPort(addr)
if net.ParseIP(h) == nil {
ips, err := CustomResolver(bypass).LookupIP(ctx, "ip", h)
if err == nil && len(ips) > 0 {
h = ips[0].String()
}
}
return GetDialer(bypass).DialContext(ctx, network, net.JoinHostPort(h, port))
},
}
client := &http.Client{
Transport: transport,
Timeout: 5 * time.Second,
}
wire, err := req.Pack()
if err != nil { if err != nil {
return nil, err return nil, err
} }
httpReq, err := http.NewRequestWithContext(context.Background(), "POST", dohURL, bytes.NewReader(wire)) req, err := http.NewRequestWithContext(
ctx, "POST", t.URL, bytes.NewReader(wire),
)
if err != nil { if err != nil {
return nil, err return nil, err
} }
httpReq.Header.Set("Content-Type", "application/dns-message") req.Header.Set("Content-Type", "application/dns-message")
httpReq.Header.Set("Accept", "application/dns-message") req.Header.Set("Accept", "application/dns-message")
req.Host = t.Hostname // important for virtual hosting and cert validation
resp, err := client.Do(httpReq) resp, err := t.Client.Do(req)
if err != nil { if err != nil {
return nil, err return nil, err
} }
defer resp.Body.Close() defer resp.Body.Close()
if resp.StatusCode != http.StatusOK { if resp.StatusCode != http.StatusOK {
shared.LogError("DNS", "doDoHQuery: DoH server returned HTTP %d for %s", resp.StatusCode, host) return nil, fmt.Errorf("doh status %d", resp.StatusCode)
return nil, fmt.Errorf("DoH HTTP %d", resp.StatusCode)
} }
body, err := io.ReadAll(resp.Body) body, err := io.ReadAll(resp.Body)
@@ -335,158 +314,143 @@ func doDoHQuery(dohURL, host string, qtype uint16, bypass bool) ([]netip.Addr, e
if err := res.Unpack(body); err != nil { if err := res.Unpack(body); err != nil {
return nil, err return nil, err
} }
if res.Rcode != dns.RcodeSuccess { return &res, nil
shared.LogError("DNS", "doDoHQuery: DoH Rcode %d for %s", res.Rcode, host)
return nil, fmt.Errorf("DoH Rcode %d", res.Rcode)
}
var addrs []netip.Addr
for _, ans := range res.Answer {
if qtype == dns.TypeA {
if a, ok := ans.(*dns.A); ok {
if ip, _ := netip.ParseAddr(a.A.String()); ip.Is4() {
addrs = append(addrs, ip)
}
}
} else if qtype == dns.TypeAAAA {
if aaaa, ok := ans.(*dns.AAAA); ok {
if ip, _ := netip.ParseAddr(aaaa.AAAA.String()); ip.Is6() {
addrs = append(addrs, ip)
}
}
}
}
return addrs, nil
} }
func resolveDoT(host, dotUpstream string, bypass bool) ([]netip.Addr, []netip.Addr, error) { func parseDNSAnswers(msg *dns.Msg, qtype uint16) []netip.Addr {
shared.LogDebug("DNS", "resolveDoT: %s with upstream=%s (bypass=%t)", host, dotUpstream, bypass) var out []netip.Addr
for _, ans := range msg.Answer {
var v4, v6 []netip.Addr
var v4Err, v6Err error
var wg sync.WaitGroup
wg.Add(2)
go func() {
defer wg.Done()
v4, v4Err = doDoTQuery(dotUpstream, host, dns.TypeA, bypass)
}()
go func() {
defer wg.Done()
v6, v6Err = doDoTQuery(dotUpstream, host, dns.TypeAAAA, bypass)
}()
wg.Wait()
if v4Err != nil && v6Err != nil {
shared.LogError("DNS", "resolveDoT failed for %s: both A and AAAA queries failed (bypass=%t)", host, bypass)
return nil, nil, errors.Join(v4Err, v6Err)
}
shared.LogDebug("DNS", "resolveDoT success for %s: %d v4, %d v6 (bypass=%t)", host, len(v4), len(v6), bypass)
return v4, v6, nil
}
func doDoTQuery(dotUpstream, host string, qtype uint16, bypass bool) ([]netip.Addr, error) {
// Normalize upstream to host:port
sni, port, err := net.SplitHostPort(dotUpstream)
if err != nil {
sni = dotUpstream
port = "853"
dotUpstream = net.JoinHostPort(sni, port)
}
// Resolve hostname using bypass resolver
serverAddr, err := resolveServerAddr(context.Background(), dotUpstream, bypass)
if err != nil {
return nil, err
}
req := &dns.Msg{}
req.Id = dns.Id()
req.RecursionDesired = true
req.SetEdns0(4096, true)
req.SetQuestion(dns.Fqdn(host), qtype)
client := &dns.Client{
Net: "tcp-tls",
Dialer: GetDialer(bypass),
Timeout: 5 * time.Second,
TLSConfig: &tls.Config{
ServerName: sni,
InsecureSkipVerify: false,
},
}
res, _, err := client.Exchange(req, serverAddr)
if err != nil {
return nil, err
}
if res.Rcode != dns.RcodeSuccess {
return nil, fmt.Errorf("DoT query failed with Rcode: %d", res.Rcode)
}
var addrs []netip.Addr
for _, ans := range res.Answer {
switch qtype { switch qtype {
case dns.TypeA: case dns.TypeA:
if a, ok := ans.(*dns.A); ok { if a, ok := ans.(*dns.A); ok {
if ip, _ := netip.ParseAddr(a.A.String()); ip.Is4() { if ip, err := netip.ParseAddr(a.A.String()); err == nil {
addrs = append(addrs, ip) out = append(out, ip)
} }
} }
case dns.TypeAAAA: case dns.TypeAAAA:
if aaaa, ok := ans.(*dns.AAAA); ok { if aaaa, ok := ans.(*dns.AAAA); ok {
if ip, _ := netip.ParseAddr(aaaa.AAAA.String()); ip.Is6() { if ip, err := netip.ParseAddr(aaaa.AAAA.String()); err == nil {
addrs = append(addrs, ip) out = append(out, ip)
} }
} }
} }
} }
return addrs, nil return out
} }
// Resolve runs the correct protocol based on the global config func Resolve(
func Resolve(host string, bypass bool) ([]netip.Addr, []netip.Addr, error) { ctx context.Context,
cfg := getConfig() host, protocol, upstream string,
shared.LogDebug("DNS", "Resolve(%s, bypass=%t) protocol=%s upstream=%s", host, bypass, cfg.Protocol, cfg.Upstream) bypass bool,
underlying string,
var v4, v6 []netip.Addr ) ([]netip.Addr, []netip.Addr, error) {
var err error t, err := buildTransport(ctx, protocol, upstream, bypass, underlying)
switch cfg.Protocol {
case "doh":
v4, v6, err = resolveDoH(host, cfg.Upstream, bypass)
case "dot":
v4, v6, err = resolveDoT(host, cfg.Upstream, bypass)
default:
v4, v6, err = resolvePlain(host, cfg.Upstream, bypass)
}
if err != nil { if err != nil {
shared.LogError("DNS", "Final Resolve failed for %s: %v", host, err) return nil, nil, err
} else {
shared.LogDebug("DNS", "Resolve success for %s: %d v4, %d v6", host, len(v4), len(v6))
} }
return v4, v6, err return resolveHost(ctx, t, host)
} }
func CustomResolver(bypass bool) *net.Resolver { func buildTransport(
return &net.Resolver{ ctx context.Context,
PreferGo: true, protocol, upstream string,
Dial: func(ctx context.Context, network, address string) (net.Conn, error) { bypass bool,
d := GetDialer(bypass) underlying string,
return d.DialContext(ctx, network, address) ) (Transport, error) {
}, switch protocol {
case "doh":
u, err := url.Parse(upstream)
if err != nil {
return nil, err
}
hostname := u.Hostname()
port := u.Port()
if port == "" {
port = "443"
}
u.Host = net.JoinHostPort(hostname, port)
// Pre-resolve with IPv4-first ordering + bypass
servers, _, err := resolveServerAddrs(ctx, u.Host, bypass, "443", underlying)
if err != nil {
return nil, err
}
if len(servers) == 0 {
return nil, fmt.Errorf("no addresses resolved for DoH server")
}
// Custom dialer that tries servers in order (IPv4 → IPv6)
dialer := GetDialer(bypass)
transport := &http.Transport{
DialContext: func(ctx context.Context, network, _ string) (net.Conn, error) {
for _, addr := range servers {
conn, err := dialer.DialContext(ctx, network, addr)
if err == nil {
return conn, nil
}
}
return nil, fmt.Errorf("all DoH addresses failed")
},
TLSClientConfig: &tls.Config{
ServerName: hostname,
},
}
return DoHTransport{
Client: &http.Client{Timeout: 5 * time.Second, Transport: transport},
URL: u.String(),
Servers: servers,
Hostname: hostname,
}, nil
case "dot":
servers, sni, err := resolveServerAddrs(ctx, upstream, bypass, "853", underlying)
if err != nil {
return nil, err
}
if len(servers) == 0 {
return nil, fmt.Errorf("no addresses resolved for DoT server")
}
client := &dns.Client{
Net: "tcp-tls",
Dialer: GetDialer(bypass),
Timeout: 5 * time.Second,
TLSConfig: &tls.Config{
ServerName: sni,
},
}
return DoTTransport{
Client: client,
Servers: servers,
}, nil
default: // plain DNS
_, addr, err := parseUpstream(upstream)
if err != nil {
return nil, err
}
servers, _, err := resolveServerAddrs(ctx, addr, bypass, "53", underlying)
if err != nil {
return nil, err
}
client := &dns.Client{
Net: "udp",
Dialer: GetDialer(bypass),
Timeout: 5 * time.Second,
}
return PlainTransport{
Client: client,
Servers: servers,
}, nil
} }
} }
func GetDialer(bypass bool) *net.Dialer { func GetDialer(bypass bool) *net.Dialer {
if !bypass { if !bypass {
return &net.Dialer{ return &net.Dialer{LocalAddr: nil}
LocalAddr: nil,
}
} }
return &net.Dialer{ return &net.Dialer{
Control: func(network, address string, c syscall.RawConn) error { Control: func(network, address string, c syscall.RawConn) error {
var opErr error var opErr error
+28 -31
View File
@@ -3,54 +3,51 @@
struct go_string { const char *str; long n; }; struct go_string { const char *str; long n; };
extern void SetDNSConfig(struct go_string handle); extern char* ResolveBootstrap(
extern char* ResolveBootstrap(const char* host, int bypass); const char* host,
const char* protocol,
JNIEXPORT void JNICALL Java_com_zaneschepke_tunnel_DnsConfigManager_setDNSConfig( const char* upstream,
JNIEnv* env, jclass clazz, jstring json) const char* underlyingDnsServers,
{ int bypass);
if (json == NULL) {
return;
}
const char* cjson = (*env)->GetStringUTFChars(env, json, 0);
if (cjson != NULL) {
size_t len = (*env)->GetStringUTFLength(env, json);
SetDNSConfig((struct go_string){
.str = cjson,
.n = (long)len
});
(*env)->ReleaseStringUTFChars(env, json, cjson);
}
}
JNIEXPORT jstring JNICALL JNIEXPORT jstring JNICALL
Java_com_zaneschepke_tunnel_DnsConfigManager_resolveBootstrap( Java_com_zaneschepke_tunnel_DnsConfigManager_resolveBootstrap(
JNIEnv* env, JNIEnv* env,
jclass clazz, jclass clazz,
jstring host, jstring host,
jboolean bypass) jstring protocol,
jstring upstream,
jstring underlyingDnsServers,
jint bypass)
{ {
if (host == NULL) { if (host == NULL || protocol == NULL || upstream == NULL || underlyingDnsServers == NULL) {
return (*env)->NewStringUTF(env, "{\"error\":\"invalid host\"}"); return (*env)->NewStringUTF(env, "ERR|invalid arguments");
} }
const char* chost = (*env)->GetStringUTFChars(env, host, NULL); const char* chost = (*env)->GetStringUTFChars(env, host, NULL);
if (chost == NULL) { const char* cprotocol = (*env)->GetStringUTFChars(env, protocol, NULL);
return (*env)->NewStringUTF(env, "{\"error\":\"out of memory\"}"); const char* cupstream = (*env)->GetStringUTFChars(env, upstream, NULL);
const char* cunderlying = (*env)->GetStringUTFChars(env, underlyingDnsServers, NULL);
if (chost == NULL || cprotocol == NULL || cupstream == NULL || cunderlying == NULL) {
return (*env)->NewStringUTF(env, "ERR|out of memory");
} }
char* resultC = ResolveBootstrap( char* resultC = ResolveBootstrap(
(char*)chost, chost,
bypass ? 1 : 0 cprotocol,
cupstream,
cunderlying,
bypass ? 1 : 0
); );
(*env)->ReleaseStringUTFChars(env, host, chost); (*env)->ReleaseStringUTFChars(env, host, chost);
(*env)->ReleaseStringUTFChars(env, protocol, cprotocol);
(*env)->ReleaseStringUTFChars(env, upstream, cupstream);
(*env)->ReleaseStringUTFChars(env, underlyingDnsServers, cunderlying);
if (resultC == NULL) { if (resultC == NULL) {
return (*env)->NewStringUTF(env, "{\"error\":\"null response\"}"); return (*env)->NewStringUTF(env, "ERR|null response");
} }
jstring jresult = (*env)->NewStringUTF(env, resultC); jstring jresult = (*env)->NewStringUTF(env, resultC);