mirror of
https://github.com/wgtunnel/android.git
synced 2026-06-02 00:29:08 +02:00
+6
-1
@@ -32,6 +32,7 @@ import kotlinx.coroutines.sync.withLock
|
||||
class TunnelCoordinator(
|
||||
private val tunnelProvider: TunnelProvider,
|
||||
private val serviceManager: ServiceManager,
|
||||
private val bootstrapCoordinator: AppBoostrapCoordinator,
|
||||
settingsRepository: GeneralSettingRepository,
|
||||
private val tunnelRepository: TunnelRepository,
|
||||
dnsSettingsRepository: RoomDnsSettingsRepository,
|
||||
@@ -86,7 +87,11 @@ class TunnelCoordinator(
|
||||
suspend fun startTunnel(
|
||||
config: TunnelConfig,
|
||||
source: TunnelActionSource = TunnelActionSource.USER,
|
||||
) = tunnelMutex.withLock { startTunnelInternal(config, source) }
|
||||
) = tunnelMutex.withLock {
|
||||
// wait for app to be bootstrapped
|
||||
bootstrapCoordinator.isReady.first { it }
|
||||
startTunnelInternal(config, source)
|
||||
}
|
||||
|
||||
suspend fun stopTunnel(id: Int, source: TunnelActionSource = TunnelActionSource.USER) =
|
||||
tunnelMutex.withLock {
|
||||
|
||||
@@ -24,6 +24,7 @@ val coordinatorModule = module {
|
||||
get(),
|
||||
get(),
|
||||
get(),
|
||||
get(),
|
||||
get(named(Scope.APPLICATION)),
|
||||
)
|
||||
}
|
||||
|
||||
@@ -3,25 +3,27 @@ package com.zaneschepke.tunnel
|
||||
import com.zaneschepke.tunnel.model.DnsBootstrapResult
|
||||
import kotlinx.coroutines.Dispatchers
|
||||
import kotlinx.coroutines.withContext
|
||||
import org.json.JSONObject
|
||||
|
||||
internal object DnsConfigManager {
|
||||
private external fun setDNSConfig(configJson: String)
|
||||
|
||||
fun update(protocol: String, upstream: String) {
|
||||
val config =
|
||||
JSONObject().apply {
|
||||
put("protocol", protocol)
|
||||
put("upstream", upstream)
|
||||
}
|
||||
setDNSConfig(config.toString())
|
||||
}
|
||||
private external fun resolveBootstrap(
|
||||
host: String,
|
||||
protocol: String,
|
||||
upstream: String,
|
||||
underlyingDnsServers: String,
|
||||
bypass: Int,
|
||||
): String
|
||||
|
||||
private external fun resolveBootstrap(host: String, bypass: Int): String
|
||||
|
||||
suspend fun resolveHostBootstrap(host: String, bypass: Boolean): DnsBootstrapResult =
|
||||
suspend fun resolveHostBootstrap(
|
||||
host: String,
|
||||
protocol: String,
|
||||
upstream: String,
|
||||
underlyingDnsServers: String,
|
||||
bypass: Boolean,
|
||||
): DnsBootstrapResult =
|
||||
withContext(Dispatchers.IO) {
|
||||
val raw = resolveBootstrap(host, if (bypass) 1 else 0)
|
||||
val bypassOption = if (bypass) 1 else 0
|
||||
val raw = resolveBootstrap(host, protocol, upstream, underlyingDnsServers, bypassOption)
|
||||
|
||||
if (raw.startsWith("ERR|")) {
|
||||
throw RuntimeException(raw.removePrefix("ERR|"))
|
||||
|
||||
@@ -5,6 +5,7 @@ import com.zaneschepke.tunnel.DnsConfigManager
|
||||
import com.zaneschepke.tunnel.Tunnel
|
||||
import com.zaneschepke.tunnel.event.ActorEvent
|
||||
import com.zaneschepke.tunnel.event.ActorEvent.ActiveConfigUpdated
|
||||
import com.zaneschepke.tunnel.event.ActorEvent.BootstrapConfigUpdated
|
||||
import com.zaneschepke.tunnel.event.ActorEvent.BootstrapStateChanged
|
||||
import com.zaneschepke.tunnel.event.ActorEvent.EngineStatus
|
||||
import com.zaneschepke.tunnel.event.ActorEvent.KillSwitchStateChanged
|
||||
@@ -12,7 +13,9 @@ import com.zaneschepke.tunnel.event.ActorEvent.PeersUpdated
|
||||
import com.zaneschepke.tunnel.event.ActorEvent.ResolvedPeersApplied
|
||||
import com.zaneschepke.tunnel.event.ActorEvent.TunnelStarted
|
||||
import com.zaneschepke.tunnel.event.ActorEvent.TunnelStopped
|
||||
import com.zaneschepke.tunnel.event.ActorEvent.UnderlyingDnsServersUpdated
|
||||
import com.zaneschepke.tunnel.event.TunnelEvent
|
||||
import com.zaneschepke.tunnel.event.TunnelEvent.NoRootShellAccess
|
||||
import com.zaneschepke.tunnel.model.BackendMode
|
||||
import com.zaneschepke.tunnel.model.DnsBootstrapResult
|
||||
import com.zaneschepke.tunnel.model.PublicKey
|
||||
@@ -122,6 +125,10 @@ internal class TunnelActor(
|
||||
engine.stop(runtime.running.handle, runtime.running.mode)
|
||||
}
|
||||
|
||||
is TunnelCommand.SetBootstrapConfig -> {
|
||||
apply(BootstrapConfigUpdated(cmd.config))
|
||||
}
|
||||
|
||||
is TunnelCommand.UpdatePeers -> {
|
||||
val runtime = _state.value.byTunnelId[cmd.tunnelId] ?: continue
|
||||
val running = runtime.running
|
||||
@@ -209,9 +216,7 @@ internal class TunnelActor(
|
||||
} catch (t: Throwable) {
|
||||
Timber.w(t, "Root shell commands failed")
|
||||
if (t is RootShellException.NoRootAccess) {
|
||||
_events.emit(
|
||||
TunnelEvent.NoRootShellAccess(tunnelId = cmd.tunnelId)
|
||||
)
|
||||
_events.emit(NoRootShellAccess(tunnelId = cmd.tunnelId))
|
||||
}
|
||||
} finally {
|
||||
if (isPostDown) {
|
||||
@@ -219,6 +224,10 @@ internal class TunnelActor(
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
is TunnelCommand.UpdateUnderlyingDnsServers -> {
|
||||
apply(UnderlyingDnsServersUpdated(cmd.servers))
|
||||
}
|
||||
}
|
||||
} catch (t: Throwable) {
|
||||
Timber.e(t, "Tunnel command failed: $cmd")
|
||||
@@ -621,6 +630,14 @@ internal class TunnelActor(
|
||||
is KillSwitchStateChanged -> {
|
||||
state.copy(killSwitchEnabled = event.enabled)
|
||||
}
|
||||
|
||||
is BootstrapConfigUpdated -> {
|
||||
state.copy(dnsConfig = event.config)
|
||||
}
|
||||
|
||||
is UnderlyingDnsServersUpdated -> {
|
||||
state.copy(dnsConfig = state.dnsConfig.copy(underlyingDnsServers = event.servers))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -667,9 +684,17 @@ internal class TunnelActor(
|
||||
val endpoint = peer.endpoint ?: continue
|
||||
val host = endpoint.substringBeforeLast(":")
|
||||
|
||||
val dnsConfig = state.value.dnsConfig
|
||||
|
||||
val dnsResult =
|
||||
try {
|
||||
DnsConfigManager.resolveHostBootstrap(host = host, bypass = bypassNeeded)
|
||||
DnsConfigManager.resolveHostBootstrap(
|
||||
host = host,
|
||||
dnsConfig.protocol,
|
||||
dnsConfig.upstream,
|
||||
dnsConfig.underlyingDnsServers,
|
||||
bypass = bypassNeeded,
|
||||
)
|
||||
} catch (e: Exception) {
|
||||
Timber.w(e, "DNS failed for $host")
|
||||
continue
|
||||
|
||||
@@ -1,8 +1,9 @@
|
||||
package com.zaneschepke.tunnel.backend
|
||||
|
||||
import com.zaneschepke.networkmonitor.ActiveNetwork
|
||||
import com.zaneschepke.networkmonitor.DnsInfo
|
||||
import com.zaneschepke.networkmonitor.NetworkMonitor
|
||||
import com.zaneschepke.networkmonitor.PrivateDnsMode
|
||||
import com.zaneschepke.tunnel.DnsConfigManager
|
||||
import com.zaneschepke.tunnel.NotificationProvider
|
||||
import com.zaneschepke.tunnel.Tunnel
|
||||
import com.zaneschepke.tunnel.event.TunnelEvent
|
||||
@@ -14,6 +15,7 @@ import com.zaneschepke.tunnel.model.TunnelCommand
|
||||
import com.zaneschepke.tunnel.service.VpnService
|
||||
import com.zaneschepke.tunnel.state.BackendStatus
|
||||
import com.zaneschepke.tunnel.state.KillSwitchState
|
||||
import com.zaneschepke.tunnel.state.RuntimeDnsConfig
|
||||
import com.zaneschepke.tunnel.util.BackendException
|
||||
import java.lang.ref.WeakReference
|
||||
import kotlin.reflect.KClass
|
||||
@@ -27,8 +29,10 @@ import kotlinx.coroutines.flow.asSharedFlow
|
||||
import kotlinx.coroutines.flow.asStateFlow
|
||||
import kotlinx.coroutines.flow.distinctUntilChangedBy
|
||||
import kotlinx.coroutines.flow.first
|
||||
import kotlinx.coroutines.flow.firstOrNull
|
||||
import kotlinx.coroutines.flow.update
|
||||
import kotlinx.coroutines.launch
|
||||
import kotlinx.coroutines.withTimeoutOrNull
|
||||
import org.koin.java.KoinJavaComponent.inject
|
||||
import timber.log.Timber
|
||||
|
||||
@@ -191,24 +195,90 @@ class TunnelBackend(
|
||||
override suspend fun setBootstrapDnsMode(mode: DnsBoostrapMode) {
|
||||
_status.update { it.copy(dnsMode = mode) }
|
||||
|
||||
dnsConfigJob?.cancel()
|
||||
dnsConfigJob = null
|
||||
|
||||
when (mode) {
|
||||
is DnsBoostrapMode.Custom -> {
|
||||
Timber.d("DNS Boostrap mode set to custom, disabling system dns monitoring")
|
||||
dnsConfigJob?.cancel()
|
||||
dnsConfigJob = null
|
||||
|
||||
DnsConfigManager.update(
|
||||
mode.config.protocol,
|
||||
mode.config.upstream ?: DnsBoostrapConfig.DEFAULT_UPSTREAM,
|
||||
Timber.d(
|
||||
"DNS Bootstrap mode set to custom: ${mode.config.protocol} -> ${mode.config.upstream}"
|
||||
)
|
||||
|
||||
emitInitialDnsConfig(mode.config.protocol, mode.config.upstream)
|
||||
startUnderlyingDnsMonitoring()
|
||||
}
|
||||
|
||||
DnsBoostrapMode.System -> {
|
||||
Timber.d("DNS Bootstrap mode set to System")
|
||||
emitInitialDnsConfig()
|
||||
startSystemDnsMonitoring()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private suspend fun emitInitialDnsConfig(protocol: String? = null, upstream: String? = null) {
|
||||
val state =
|
||||
withTimeoutOrNull(2_500L) {
|
||||
networkMonitor.connectivityStateFlow.first { connectivityState ->
|
||||
val dns = connectivityState.underlyingDnsInfo
|
||||
dns.servers.isNotEmpty() ||
|
||||
connectivityState.activeNetwork is ActiveNetwork.Disconnected
|
||||
}
|
||||
} ?: networkMonitor.connectivityStateFlow.firstOrNull()
|
||||
|
||||
val dns = state?.underlyingDnsInfo ?: DnsInfo()
|
||||
|
||||
val finalProtocol =
|
||||
protocol
|
||||
?: when (dns.privateDnsMode) {
|
||||
PrivateDnsMode.HOSTNAME -> "dot"
|
||||
else -> "plain"
|
||||
}
|
||||
|
||||
val finalUpstream =
|
||||
upstream
|
||||
?: when (dns.privateDnsMode) {
|
||||
PrivateDnsMode.HOSTNAME ->
|
||||
dns.privateDnsHostname?.takeIf { it.isNotBlank() }
|
||||
?: DnsBoostrapConfig.DEFAULT_UPSTREAM
|
||||
else -> dns.servers.firstOrNull() ?: DnsBoostrapConfig.DEFAULT_UPSTREAM
|
||||
}
|
||||
|
||||
val underlying =
|
||||
dns.servers.joinToString(",").ifBlank { DnsBoostrapConfig.DEFAULT_UPSTREAM }
|
||||
|
||||
Timber.d(
|
||||
"DNS initial emission: protocol=$finalProtocol upstream=$finalUpstream underlying=$underlying"
|
||||
)
|
||||
|
||||
actor.send(
|
||||
TunnelCommand.SetBootstrapConfig(
|
||||
RuntimeDnsConfig(
|
||||
protocol = finalProtocol,
|
||||
upstream = finalUpstream,
|
||||
underlyingDnsServers = underlying,
|
||||
)
|
||||
)
|
||||
)
|
||||
}
|
||||
|
||||
private fun startUnderlyingDnsMonitoring() {
|
||||
if (dnsConfigJob?.isActive == true) return
|
||||
|
||||
dnsConfigJob = scope.launch {
|
||||
networkMonitor.connectivityStateFlow
|
||||
.distinctUntilChangedBy { it.underlyingDnsInfo.servers }
|
||||
.collect { state ->
|
||||
val dns = state.underlyingDnsInfo
|
||||
val underlying = dns.servers.joinToString(",")
|
||||
|
||||
Timber.d("Underlying DNS servers changed: $underlying")
|
||||
|
||||
actor.send(TunnelCommand.UpdateUnderlyingDnsServers(underlying))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
override fun emergencyStopAllOfTypeSync(modeClass: KClass<out BackendMode>) {
|
||||
actor.emergencyStopAllOfType(modeClass)
|
||||
}
|
||||
@@ -236,13 +306,12 @@ class TunnelBackend(
|
||||
val config =
|
||||
when (dns.privateDnsMode) {
|
||||
PrivateDnsMode.OFF,
|
||||
PrivateDnsMode.AUTOMATIC -> {
|
||||
PrivateDnsMode.AUTOMATIC ->
|
||||
DnsBoostrapConfig.Plain(
|
||||
dns.servers.firstOrNull() ?: DnsBoostrapConfig.DEFAULT_UPSTREAM
|
||||
)
|
||||
}
|
||||
|
||||
PrivateDnsMode.HOSTNAME -> {
|
||||
PrivateDnsMode.HOSTNAME ->
|
||||
dns.privateDnsHostname
|
||||
?.takeIf { it.isNotBlank() }
|
||||
?.let { DnsBoostrapConfig.DoT(it) }
|
||||
@@ -250,12 +319,16 @@ class TunnelBackend(
|
||||
dns.servers.firstOrNull()
|
||||
?: DnsBoostrapConfig.DEFAULT_UPSTREAM
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
DnsConfigManager.update(
|
||||
config.protocol,
|
||||
config.upstream ?: DnsBoostrapConfig.DEFAULT_UPSTREAM,
|
||||
actor.send(
|
||||
TunnelCommand.SetBootstrapConfig(
|
||||
RuntimeDnsConfig(
|
||||
protocol = config.protocol,
|
||||
upstream = config.upstream ?: DnsBoostrapConfig.DEFAULT_UPSTREAM,
|
||||
underlyingDnsServers = dns.servers.joinToString(","),
|
||||
)
|
||||
)
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -6,6 +6,7 @@ import com.zaneschepke.tunnel.model.TunnelCommand
|
||||
import com.zaneschepke.tunnel.state.BootstrapState
|
||||
import com.zaneschepke.tunnel.state.EngineStartResult
|
||||
import com.zaneschepke.tunnel.state.NativeTunnelStatus
|
||||
import com.zaneschepke.tunnel.state.RuntimeDnsConfig
|
||||
import com.zaneschepke.wireguardautotunnel.parser.ActiveConfig
|
||||
import com.zaneschepke.wireguardautotunnel.parser.PeerSection
|
||||
|
||||
@@ -23,6 +24,10 @@ sealed class ActorEvent {
|
||||
val preferIpv6: Boolean,
|
||||
) : ActorEvent()
|
||||
|
||||
data class BootstrapConfigUpdated(val config: RuntimeDnsConfig) : ActorEvent()
|
||||
|
||||
data class UnderlyingDnsServersUpdated(val servers: String) : ActorEvent()
|
||||
|
||||
data class ResolvedPeersApplied(
|
||||
val tunnelId: Int,
|
||||
val cache: Map<PublicKey, DnsBootstrapResult>,
|
||||
|
||||
@@ -17,7 +17,7 @@ sealed class DnsBoostrapConfig(open val upstream: String?) {
|
||||
|
||||
data class DoH(override val upstream: String?) : DnsBoostrapConfig(upstream) {
|
||||
override val protocol: String
|
||||
get() = "dot"
|
||||
get() = "doh"
|
||||
}
|
||||
|
||||
data class DoT(override val upstream: String?) : DnsBoostrapConfig(upstream) {
|
||||
|
||||
@@ -2,6 +2,7 @@ package com.zaneschepke.tunnel.model
|
||||
|
||||
import com.zaneschepke.tunnel.Tunnel
|
||||
import com.zaneschepke.tunnel.state.BootstrapState
|
||||
import com.zaneschepke.tunnel.state.RuntimeDnsConfig
|
||||
import com.zaneschepke.wireguardautotunnel.parser.PeerSection
|
||||
|
||||
sealed class TunnelCommand {
|
||||
@@ -24,6 +25,10 @@ sealed class TunnelCommand {
|
||||
|
||||
data class SetBootstrapState(val tunnelId: Int, val state: BootstrapState) : TunnelCommand()
|
||||
|
||||
data class SetBootstrapConfig(val config: RuntimeDnsConfig) : TunnelCommand()
|
||||
|
||||
data class UpdateUnderlyingDnsServers(val servers: String) : TunnelCommand()
|
||||
|
||||
data class RunHook(val tunnelId: Int, val phase: Phase, val cmds: List<String>?) :
|
||||
TunnelCommand() {
|
||||
enum class Phase {
|
||||
|
||||
@@ -4,4 +4,5 @@ data class ActorState(
|
||||
val byTunnelId: Map<Int, TunnelRuntimeState>,
|
||||
val byHandle: Map<Int, Int>,
|
||||
val killSwitchEnabled: Boolean = false,
|
||||
val dnsConfig: RuntimeDnsConfig = RuntimeDnsConfig(),
|
||||
)
|
||||
|
||||
@@ -0,0 +1,9 @@
|
||||
package com.zaneschepke.tunnel.state
|
||||
|
||||
import com.zaneschepke.tunnel.model.DnsBoostrapConfig
|
||||
|
||||
data class RuntimeDnsConfig(
|
||||
val protocol: String = "plain",
|
||||
val upstream: String = DnsBoostrapConfig.DEFAULT_UPSTREAM,
|
||||
val underlyingDnsServers: String = DnsBoostrapConfig.DEFAULT_UPSTREAM,
|
||||
)
|
||||
+286
-322
@@ -9,7 +9,6 @@ import (
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
@@ -18,7 +17,6 @@ import (
|
||||
"net/netip"
|
||||
"net/url"
|
||||
"strings"
|
||||
"sync"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
@@ -27,21 +25,6 @@ import (
|
||||
"golang.org/x/sys/unix"
|
||||
)
|
||||
|
||||
const defaultPlain = "udp://1.1.1.1:53"
|
||||
|
||||
var (
|
||||
currentConfig DNSConfig = DNSConfig{
|
||||
"plain",
|
||||
"1.1.1.1:53",
|
||||
}
|
||||
configMu sync.RWMutex
|
||||
)
|
||||
|
||||
type DNSConfig struct {
|
||||
Protocol string `json:"protocol"` // plain, doh, or dot
|
||||
Upstream string `json:"upstream"`
|
||||
}
|
||||
|
||||
type Resolved struct {
|
||||
V4 []netip.Addr
|
||||
V6 []netip.Addr
|
||||
@@ -52,36 +35,34 @@ type ResolverOptions struct {
|
||||
Timeout time.Duration
|
||||
}
|
||||
|
||||
func DefaultOptions() ResolverOptions {
|
||||
return ResolverOptions{
|
||||
UpstreamURL: defaultPlain,
|
||||
Timeout: 5 * time.Second,
|
||||
}
|
||||
}
|
||||
|
||||
//export SetDNSConfig
|
||||
func SetDNSConfig(config string) {
|
||||
var cfg DNSConfig
|
||||
if err := json.Unmarshal([]byte(config), &cfg); err != nil {
|
||||
shared.LogError("DNS", "Failed to parse DNSConfig: %v", err)
|
||||
return
|
||||
}
|
||||
if cfg.Protocol != "plain" && cfg.Protocol != "doh" && cfg.Protocol != "dot" {
|
||||
cfg.Protocol = "plain"
|
||||
}
|
||||
configMu.Lock()
|
||||
currentConfig = cfg
|
||||
configMu.Unlock()
|
||||
shared.LogDebug("DNS", "DNS config updated: %s %s", cfg.Protocol, cfg.Upstream)
|
||||
type Transport interface {
|
||||
Query(ctx context.Context, msg *dns.Msg) (*dns.Msg, error)
|
||||
}
|
||||
|
||||
//export ResolveBootstrap
|
||||
func ResolveBootstrap(host *C.char, bypass C.int) *C.char {
|
||||
func ResolveBootstrap(
|
||||
host *C.char,
|
||||
protocol *C.char,
|
||||
upstream *C.char,
|
||||
underlyingDnsServers *C.char,
|
||||
bypass C.int,
|
||||
) *C.char {
|
||||
h := C.GoString(host)
|
||||
p := C.GoString(protocol)
|
||||
u := C.GoString(upstream)
|
||||
underlying := C.GoString(underlyingDnsServers)
|
||||
bp := bypass == 1
|
||||
shared.LogDebug("DNS", "ResolveBootstrap called for host=%s (bypass=%t)", h, bp)
|
||||
|
||||
v4, v6, err := Resolve(h, bp)
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
shared.LogDebug(
|
||||
"DNS",
|
||||
"ResolveBootstrap called host=%s protocol=%s upstream=%s bypass=%t",
|
||||
h, p, u, bp,
|
||||
)
|
||||
|
||||
v4, v6, err := Resolve(ctx, h, p, u, bp, underlying)
|
||||
if err != nil {
|
||||
shared.LogError("DNS", "ResolveBootstrap failed for %s: %v", h, err)
|
||||
return C.CString("ERR|" + err.Error())
|
||||
@@ -96,15 +77,76 @@ func ResolveBootstrap(host *C.char, bypass C.int) *C.char {
|
||||
v6Str[i] = ip.String()
|
||||
}
|
||||
|
||||
result := "v4=" + strings.Join(v4Str, ",") + ";v6=" + strings.Join(v6Str, ",")
|
||||
result := "v4=" + strings.Join(v4Str, ",") +
|
||||
";v6=" + strings.Join(v6Str, ",")
|
||||
|
||||
shared.LogDebug("DNS", "ResolveBootstrap success for %s: %s", h, result)
|
||||
return C.CString(result)
|
||||
}
|
||||
|
||||
func getConfig() DNSConfig {
|
||||
configMu.RLock()
|
||||
defer configMu.RUnlock()
|
||||
return currentConfig
|
||||
type DoTTransport struct {
|
||||
Client *dns.Client
|
||||
Servers []string
|
||||
}
|
||||
|
||||
type DoHTransport struct {
|
||||
Client *http.Client
|
||||
URL string
|
||||
Servers []string // IPv4 first, IPv6 fallback
|
||||
Hostname string // for SNI and Host header
|
||||
}
|
||||
|
||||
type PlainTransport struct {
|
||||
Client *dns.Client
|
||||
Servers []string
|
||||
}
|
||||
|
||||
func resolveHost(
|
||||
ctx context.Context,
|
||||
t Transport,
|
||||
host string,
|
||||
) (v4, v6 []netip.Addr, err error) {
|
||||
a4, e4 := resolveQ(ctx, t, host, dns.TypeA)
|
||||
if e4 == nil {
|
||||
v4 = a4
|
||||
}
|
||||
a6, e6 := resolveQ(ctx, t, host, dns.TypeAAAA)
|
||||
if e6 == nil {
|
||||
v6 = a6
|
||||
}
|
||||
|
||||
if len(v4) > 0 || len(v6) > 0 {
|
||||
return v4, v6, nil
|
||||
}
|
||||
return nil, nil, errors.Join(e4, e6)
|
||||
}
|
||||
|
||||
func resolveQ(
|
||||
ctx context.Context,
|
||||
t Transport,
|
||||
host string,
|
||||
qtype uint16,
|
||||
) ([]netip.Addr, error) {
|
||||
req := &dns.Msg{}
|
||||
req.SetQuestion(dns.Fqdn(host), qtype)
|
||||
req.SetEdns0(4096, true)
|
||||
|
||||
res, err := t.Query(ctx, req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if res == nil {
|
||||
return nil, fmt.Errorf("nil DNS response")
|
||||
}
|
||||
if res.Rcode != dns.RcodeSuccess {
|
||||
return nil, fmt.Errorf("rcode %d", res.Rcode)
|
||||
}
|
||||
|
||||
addrs := parseDNSAnswers(res, qtype)
|
||||
if len(addrs) == 0 {
|
||||
return nil, fmt.Errorf("no answers for qtype %d", qtype)
|
||||
}
|
||||
return addrs, nil
|
||||
}
|
||||
|
||||
func parseUpstream(upstreamURL string) (network, address string, err error) {
|
||||
@@ -113,7 +155,6 @@ func parseUpstream(upstreamURL string) (network, address string, err error) {
|
||||
if !strings.Contains(u, "://") {
|
||||
u = "udp://" + u
|
||||
}
|
||||
|
||||
parsed, err := url.Parse(u)
|
||||
if err != nil {
|
||||
shared.LogError("DNS", "parseUpstream failed for %q: %v", upstreamURL, err)
|
||||
@@ -141,189 +182,127 @@ func parseUpstream(upstreamURL string) (network, address string, err error) {
|
||||
return network, address, nil
|
||||
}
|
||||
|
||||
func resolveServerAddr(ctx context.Context, address string, bypass bool) (string, error) {
|
||||
func newUnderlyingResolver(bypass bool, underlying string) *net.Resolver {
|
||||
if !bypass {
|
||||
return &net.Resolver{PreferGo: false}
|
||||
}
|
||||
|
||||
rawServers := strings.Split(underlying, ",")
|
||||
var servers []string
|
||||
|
||||
for _, s := range rawServers {
|
||||
s = strings.TrimSpace(s)
|
||||
if s == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
if !strings.Contains(s, ":") {
|
||||
s = net.JoinHostPort(s, "53")
|
||||
}
|
||||
servers = append(servers, s)
|
||||
}
|
||||
|
||||
if len(servers) == 0 {
|
||||
servers = []string{"1.1.1.1:53"}
|
||||
}
|
||||
|
||||
return &net.Resolver{
|
||||
PreferGo: true,
|
||||
Dial: func(ctx context.Context, network, _ string) (net.Conn, error) {
|
||||
for _, server := range servers {
|
||||
conn, err := GetDialer(true).DialContext(ctx, network, server)
|
||||
if err == nil {
|
||||
shared.LogDebug("DNS", "Using underlying bootstrap resolver: %s", server)
|
||||
return conn, nil
|
||||
}
|
||||
shared.LogDebug("DNS", "Bootstrap resolver failed for %s: %v", server, err)
|
||||
}
|
||||
return nil, fmt.Errorf("all underlying DNS servers failed")
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func resolveServerAddrs(
|
||||
ctx context.Context,
|
||||
address string,
|
||||
bypass bool,
|
||||
defaultPort string,
|
||||
underlying string,
|
||||
) ([]string, string, error) {
|
||||
host, port, err := net.SplitHostPort(address)
|
||||
if err != nil {
|
||||
shared.LogError("DNS", "resolveServerAddr: invalid address %q: %v", address, err)
|
||||
return "", err
|
||||
host = address
|
||||
port = defaultPort
|
||||
}
|
||||
|
||||
if net.ParseIP(host) != nil {
|
||||
return address, nil
|
||||
}
|
||||
|
||||
shared.LogDebug("DNS", "resolveServerAddr: bootstrapping upstream hostname %s (bypass=%t)", host, bypass)
|
||||
|
||||
bootstrapDialer := GetDialer(bypass)
|
||||
resolver := &net.Resolver{
|
||||
PreferGo: true,
|
||||
Dial: func(ctx context.Context, network, _ string) (net.Conn, error) {
|
||||
return bootstrapDialer.DialContext(ctx, network, "1.1.1.1:53")
|
||||
},
|
||||
return []string{net.JoinHostPort(host, port)}, host, nil
|
||||
}
|
||||
|
||||
resolver := newUnderlyingResolver(bypass, underlying)
|
||||
ips, err := resolver.LookupIP(ctx, "ip", host)
|
||||
if err != nil {
|
||||
shared.LogError("DNS", "Failed to resolve upstream hostname %s (bypass=%t): %v", host, bypass, err)
|
||||
return "", fmt.Errorf("failed to resolve upstream hostname %s: %w", host, err)
|
||||
}
|
||||
if len(ips) == 0 {
|
||||
err = errors.New("no IPs found for upstream hostname")
|
||||
shared.LogError("DNS", "%v for %s", err, host)
|
||||
return "", err
|
||||
shared.LogError("DNS", "Failed to resolve upstream %s (bypass=%t): %v", host, bypass, err)
|
||||
return nil, "", err
|
||||
}
|
||||
|
||||
addr := net.JoinHostPort(ips[0].String(), port)
|
||||
shared.LogDebug("DNS", "Resolved upstream %s -> %s", host, addr)
|
||||
return addr, nil
|
||||
}
|
||||
func resolveInner(host string, ipType uint16, network, serverAddr string, bypass bool) ([]netip.Addr, error) {
|
||||
req := &dns.Msg{}
|
||||
req.Id = dns.Id()
|
||||
req.RecursionDesired = true
|
||||
req.SetQuestion(dns.Fqdn(host), ipType)
|
||||
req.SetEdns0(4096, true)
|
||||
|
||||
client := &dns.Client{
|
||||
Net: network,
|
||||
Dialer: GetDialer(bypass),
|
||||
Timeout: 5 * time.Second,
|
||||
UDPSize: 4096,
|
||||
}
|
||||
|
||||
res, _, err := client.Exchange(req, serverAddr)
|
||||
if err != nil {
|
||||
shared.LogError("DNS", "resolveInner: DNS exchange failed for %s (type=%d, server=%s, bypass=%t): %v", host, ipType, serverAddr, bypass, err)
|
||||
return nil, err
|
||||
}
|
||||
if res.Rcode != dns.RcodeSuccess {
|
||||
shared.LogError("DNS", "resolveInner: DNS query failed with Rcode %d for %s", res.Rcode, host)
|
||||
return nil, fmt.Errorf("DNS query failed with Rcode: %d", res.Rcode)
|
||||
}
|
||||
|
||||
var addr []netip.Addr
|
||||
for _, ans := range res.Answer {
|
||||
switch ipType {
|
||||
case dns.TypeA:
|
||||
if a, ok := ans.(*dns.A); ok {
|
||||
if ip, err := netip.ParseAddr(a.A.String()); err == nil {
|
||||
addr = append(addr, ip)
|
||||
}
|
||||
}
|
||||
case dns.TypeAAAA:
|
||||
if aaaa, ok := ans.(*dns.AAAA); ok {
|
||||
if ip, err := netip.ParseAddr(aaaa.AAAA.String()); err == nil {
|
||||
addr = append(addr, ip)
|
||||
}
|
||||
}
|
||||
var v4, v6 []string
|
||||
for _, ip := range ips {
|
||||
addr := net.JoinHostPort(ip.String(), port)
|
||||
if ip.To4() != nil {
|
||||
v4 = append(v4, addr)
|
||||
} else {
|
||||
v6 = append(v6, addr)
|
||||
}
|
||||
}
|
||||
return addr, nil
|
||||
|
||||
return append(v4, v6...), host, nil
|
||||
}
|
||||
|
||||
func resolvePlain(host, upstreamURL string, bypass bool) ([]netip.Addr, []netip.Addr, error) {
|
||||
shared.LogDebug("DNS", "resolvePlain: %s with upstream=%s (bypass=%t)", host, upstreamURL, bypass)
|
||||
network, addr, err := parseUpstream(upstreamURL)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
func (t PlainTransport) Query(ctx context.Context, msg *dns.Msg) (*dns.Msg, error) {
|
||||
for _, server := range t.Servers {
|
||||
m, _, err := t.Client.Exchange(msg, server)
|
||||
if err == nil && m != nil && m.Rcode == dns.RcodeSuccess {
|
||||
return m, nil
|
||||
}
|
||||
}
|
||||
|
||||
serverAddr, err := resolveServerAddr(context.Background(), addr, bypass)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
var wg sync.WaitGroup
|
||||
var v4, v6 []netip.Addr
|
||||
var v4Err, v6Err error
|
||||
wg.Add(2)
|
||||
go func() { v4, v4Err = resolveInner(host, dns.TypeA, network, serverAddr, bypass); wg.Done() }()
|
||||
go func() { v6, v6Err = resolveInner(host, dns.TypeAAAA, network, serverAddr, bypass); wg.Done() }()
|
||||
wg.Wait()
|
||||
|
||||
if v4Err != nil && v6Err != nil {
|
||||
shared.LogError("DNS", "resolvePlain failed for %s: both A and AAAA failed", host)
|
||||
return nil, nil, errors.Join(v4Err, v6Err)
|
||||
}
|
||||
if len(v4) == 0 && len(v6) == 0 {
|
||||
err = errors.New("no IP addresses found")
|
||||
shared.LogError("DNS", "%v for %s", err, host)
|
||||
return nil, nil, err
|
||||
}
|
||||
return v4, v6, nil
|
||||
return nil, fmt.Errorf("all DNS servers failed")
|
||||
}
|
||||
|
||||
func resolveDoH(host, dohURL string, bypass bool) ([]netip.Addr, []netip.Addr, error) {
|
||||
shared.LogDebug("DNS", "Resolving DOH: %s with %s", host, dohURL)
|
||||
var v4, v6 []netip.Addr
|
||||
var v4Err, v6Err error
|
||||
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(2)
|
||||
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
v4, v4Err = doDoHQuery(dohURL, host, dns.TypeA, bypass)
|
||||
}()
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
v6, v6Err = doDoHQuery(dohURL, host, dns.TypeAAAA, bypass)
|
||||
}()
|
||||
wg.Wait()
|
||||
|
||||
if v4Err != nil && v6Err != nil {
|
||||
shared.LogError("DNS", "resolveDoH failed for %s: both queries failed", host)
|
||||
return nil, nil, errors.Join(v4Err, v6Err)
|
||||
func (t DoTTransport) Query(ctx context.Context, msg *dns.Msg) (*dns.Msg, error) {
|
||||
for _, server := range t.Servers {
|
||||
m, _, err := t.Client.Exchange(msg, server)
|
||||
if err == nil && m != nil && m.Rcode == dns.RcodeSuccess {
|
||||
return m, nil
|
||||
}
|
||||
}
|
||||
return v4, v6, nil
|
||||
return nil, fmt.Errorf("all DoT servers failed")
|
||||
}
|
||||
|
||||
func doDoHQuery(dohURL, host string, qtype uint16, bypass bool) ([]netip.Addr, error) {
|
||||
req := &dns.Msg{}
|
||||
req.Id = dns.Id()
|
||||
req.RecursionDesired = true
|
||||
req.SetEdns0(4096, true)
|
||||
req.SetQuestion(dns.Fqdn(host), qtype)
|
||||
|
||||
transport := &http.Transport{
|
||||
DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
|
||||
h, port, _ := net.SplitHostPort(addr)
|
||||
if net.ParseIP(h) == nil {
|
||||
ips, err := CustomResolver(bypass).LookupIP(ctx, "ip", h)
|
||||
if err == nil && len(ips) > 0 {
|
||||
h = ips[0].String()
|
||||
}
|
||||
}
|
||||
return GetDialer(bypass).DialContext(ctx, network, net.JoinHostPort(h, port))
|
||||
},
|
||||
}
|
||||
|
||||
client := &http.Client{
|
||||
Transport: transport,
|
||||
Timeout: 5 * time.Second,
|
||||
}
|
||||
|
||||
wire, err := req.Pack()
|
||||
func (t DoHTransport) Query(ctx context.Context, msg *dns.Msg) (*dns.Msg, error) {
|
||||
wire, err := msg.Pack()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
httpReq, err := http.NewRequestWithContext(context.Background(), "POST", dohURL, bytes.NewReader(wire))
|
||||
req, err := http.NewRequestWithContext(
|
||||
ctx, "POST", t.URL, bytes.NewReader(wire),
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
httpReq.Header.Set("Content-Type", "application/dns-message")
|
||||
httpReq.Header.Set("Accept", "application/dns-message")
|
||||
req.Header.Set("Content-Type", "application/dns-message")
|
||||
req.Header.Set("Accept", "application/dns-message")
|
||||
req.Host = t.Hostname // important for virtual hosting and cert validation
|
||||
|
||||
resp, err := client.Do(httpReq)
|
||||
resp, err := t.Client.Do(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
shared.LogError("DNS", "doDoHQuery: DoH server returned HTTP %d for %s", resp.StatusCode, host)
|
||||
return nil, fmt.Errorf("DoH HTTP %d", resp.StatusCode)
|
||||
return nil, fmt.Errorf("doh status %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
@@ -335,158 +314,143 @@ func doDoHQuery(dohURL, host string, qtype uint16, bypass bool) ([]netip.Addr, e
|
||||
if err := res.Unpack(body); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if res.Rcode != dns.RcodeSuccess {
|
||||
shared.LogError("DNS", "doDoHQuery: DoH Rcode %d for %s", res.Rcode, host)
|
||||
return nil, fmt.Errorf("DoH Rcode %d", res.Rcode)
|
||||
}
|
||||
|
||||
var addrs []netip.Addr
|
||||
for _, ans := range res.Answer {
|
||||
if qtype == dns.TypeA {
|
||||
if a, ok := ans.(*dns.A); ok {
|
||||
if ip, _ := netip.ParseAddr(a.A.String()); ip.Is4() {
|
||||
addrs = append(addrs, ip)
|
||||
}
|
||||
}
|
||||
} else if qtype == dns.TypeAAAA {
|
||||
if aaaa, ok := ans.(*dns.AAAA); ok {
|
||||
if ip, _ := netip.ParseAddr(aaaa.AAAA.String()); ip.Is6() {
|
||||
addrs = append(addrs, ip)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return addrs, nil
|
||||
return &res, nil
|
||||
}
|
||||
|
||||
func resolveDoT(host, dotUpstream string, bypass bool) ([]netip.Addr, []netip.Addr, error) {
|
||||
shared.LogDebug("DNS", "resolveDoT: %s with upstream=%s (bypass=%t)", host, dotUpstream, bypass)
|
||||
|
||||
var v4, v6 []netip.Addr
|
||||
var v4Err, v6Err error
|
||||
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(2)
|
||||
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
v4, v4Err = doDoTQuery(dotUpstream, host, dns.TypeA, bypass)
|
||||
}()
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
v6, v6Err = doDoTQuery(dotUpstream, host, dns.TypeAAAA, bypass)
|
||||
}()
|
||||
wg.Wait()
|
||||
|
||||
if v4Err != nil && v6Err != nil {
|
||||
shared.LogError("DNS", "resolveDoT failed for %s: both A and AAAA queries failed (bypass=%t)", host, bypass)
|
||||
return nil, nil, errors.Join(v4Err, v6Err)
|
||||
}
|
||||
|
||||
shared.LogDebug("DNS", "resolveDoT success for %s: %d v4, %d v6 (bypass=%t)", host, len(v4), len(v6), bypass)
|
||||
return v4, v6, nil
|
||||
}
|
||||
|
||||
func doDoTQuery(dotUpstream, host string, qtype uint16, bypass bool) ([]netip.Addr, error) {
|
||||
// Normalize upstream to host:port
|
||||
sni, port, err := net.SplitHostPort(dotUpstream)
|
||||
if err != nil {
|
||||
sni = dotUpstream
|
||||
port = "853"
|
||||
dotUpstream = net.JoinHostPort(sni, port)
|
||||
}
|
||||
|
||||
// Resolve hostname using bypass resolver
|
||||
serverAddr, err := resolveServerAddr(context.Background(), dotUpstream, bypass)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
req := &dns.Msg{}
|
||||
req.Id = dns.Id()
|
||||
req.RecursionDesired = true
|
||||
req.SetEdns0(4096, true)
|
||||
req.SetQuestion(dns.Fqdn(host), qtype)
|
||||
|
||||
client := &dns.Client{
|
||||
Net: "tcp-tls",
|
||||
Dialer: GetDialer(bypass),
|
||||
Timeout: 5 * time.Second,
|
||||
TLSConfig: &tls.Config{
|
||||
ServerName: sni,
|
||||
InsecureSkipVerify: false,
|
||||
},
|
||||
}
|
||||
|
||||
res, _, err := client.Exchange(req, serverAddr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if res.Rcode != dns.RcodeSuccess {
|
||||
return nil, fmt.Errorf("DoT query failed with Rcode: %d", res.Rcode)
|
||||
}
|
||||
|
||||
var addrs []netip.Addr
|
||||
for _, ans := range res.Answer {
|
||||
func parseDNSAnswers(msg *dns.Msg, qtype uint16) []netip.Addr {
|
||||
var out []netip.Addr
|
||||
for _, ans := range msg.Answer {
|
||||
switch qtype {
|
||||
case dns.TypeA:
|
||||
if a, ok := ans.(*dns.A); ok {
|
||||
if ip, _ := netip.ParseAddr(a.A.String()); ip.Is4() {
|
||||
addrs = append(addrs, ip)
|
||||
if ip, err := netip.ParseAddr(a.A.String()); err == nil {
|
||||
out = append(out, ip)
|
||||
}
|
||||
}
|
||||
case dns.TypeAAAA:
|
||||
if aaaa, ok := ans.(*dns.AAAA); ok {
|
||||
if ip, _ := netip.ParseAddr(aaaa.AAAA.String()); ip.Is6() {
|
||||
addrs = append(addrs, ip)
|
||||
if ip, err := netip.ParseAddr(aaaa.AAAA.String()); err == nil {
|
||||
out = append(out, ip)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return addrs, nil
|
||||
return out
|
||||
}
|
||||
|
||||
// Resolve runs the correct protocol based on the global config
|
||||
func Resolve(host string, bypass bool) ([]netip.Addr, []netip.Addr, error) {
|
||||
cfg := getConfig()
|
||||
shared.LogDebug("DNS", "Resolve(%s, bypass=%t) protocol=%s upstream=%s", host, bypass, cfg.Protocol, cfg.Upstream)
|
||||
|
||||
var v4, v6 []netip.Addr
|
||||
var err error
|
||||
switch cfg.Protocol {
|
||||
case "doh":
|
||||
v4, v6, err = resolveDoH(host, cfg.Upstream, bypass)
|
||||
case "dot":
|
||||
v4, v6, err = resolveDoT(host, cfg.Upstream, bypass)
|
||||
default:
|
||||
v4, v6, err = resolvePlain(host, cfg.Upstream, bypass)
|
||||
}
|
||||
|
||||
func Resolve(
|
||||
ctx context.Context,
|
||||
host, protocol, upstream string,
|
||||
bypass bool,
|
||||
underlying string,
|
||||
) ([]netip.Addr, []netip.Addr, error) {
|
||||
t, err := buildTransport(ctx, protocol, upstream, bypass, underlying)
|
||||
if err != nil {
|
||||
shared.LogError("DNS", "Final Resolve failed for %s: %v", host, err)
|
||||
} else {
|
||||
shared.LogDebug("DNS", "Resolve success for %s: %d v4, %d v6", host, len(v4), len(v6))
|
||||
return nil, nil, err
|
||||
}
|
||||
return v4, v6, err
|
||||
return resolveHost(ctx, t, host)
|
||||
}
|
||||
|
||||
func CustomResolver(bypass bool) *net.Resolver {
|
||||
return &net.Resolver{
|
||||
PreferGo: true,
|
||||
Dial: func(ctx context.Context, network, address string) (net.Conn, error) {
|
||||
d := GetDialer(bypass)
|
||||
return d.DialContext(ctx, network, address)
|
||||
},
|
||||
func buildTransport(
|
||||
ctx context.Context,
|
||||
protocol, upstream string,
|
||||
bypass bool,
|
||||
underlying string,
|
||||
) (Transport, error) {
|
||||
switch protocol {
|
||||
case "doh":
|
||||
u, err := url.Parse(upstream)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
hostname := u.Hostname()
|
||||
port := u.Port()
|
||||
if port == "" {
|
||||
port = "443"
|
||||
}
|
||||
u.Host = net.JoinHostPort(hostname, port)
|
||||
|
||||
// Pre-resolve with IPv4-first ordering + bypass
|
||||
servers, _, err := resolveServerAddrs(ctx, u.Host, bypass, "443", underlying)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if len(servers) == 0 {
|
||||
return nil, fmt.Errorf("no addresses resolved for DoH server")
|
||||
}
|
||||
|
||||
// Custom dialer that tries servers in order (IPv4 → IPv6)
|
||||
dialer := GetDialer(bypass)
|
||||
transport := &http.Transport{
|
||||
DialContext: func(ctx context.Context, network, _ string) (net.Conn, error) {
|
||||
for _, addr := range servers {
|
||||
conn, err := dialer.DialContext(ctx, network, addr)
|
||||
if err == nil {
|
||||
return conn, nil
|
||||
}
|
||||
}
|
||||
return nil, fmt.Errorf("all DoH addresses failed")
|
||||
},
|
||||
TLSClientConfig: &tls.Config{
|
||||
ServerName: hostname,
|
||||
},
|
||||
}
|
||||
|
||||
return DoHTransport{
|
||||
Client: &http.Client{Timeout: 5 * time.Second, Transport: transport},
|
||||
URL: u.String(),
|
||||
Servers: servers,
|
||||
Hostname: hostname,
|
||||
}, nil
|
||||
|
||||
case "dot":
|
||||
servers, sni, err := resolveServerAddrs(ctx, upstream, bypass, "853", underlying)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if len(servers) == 0 {
|
||||
return nil, fmt.Errorf("no addresses resolved for DoT server")
|
||||
}
|
||||
|
||||
client := &dns.Client{
|
||||
Net: "tcp-tls",
|
||||
Dialer: GetDialer(bypass),
|
||||
Timeout: 5 * time.Second,
|
||||
TLSConfig: &tls.Config{
|
||||
ServerName: sni,
|
||||
},
|
||||
}
|
||||
return DoTTransport{
|
||||
Client: client,
|
||||
Servers: servers,
|
||||
}, nil
|
||||
|
||||
default: // plain DNS
|
||||
_, addr, err := parseUpstream(upstream)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
servers, _, err := resolveServerAddrs(ctx, addr, bypass, "53", underlying)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
client := &dns.Client{
|
||||
Net: "udp",
|
||||
Dialer: GetDialer(bypass),
|
||||
Timeout: 5 * time.Second,
|
||||
}
|
||||
return PlainTransport{
|
||||
Client: client,
|
||||
Servers: servers,
|
||||
}, nil
|
||||
}
|
||||
}
|
||||
|
||||
func GetDialer(bypass bool) *net.Dialer {
|
||||
if !bypass {
|
||||
return &net.Dialer{
|
||||
LocalAddr: nil,
|
||||
}
|
||||
return &net.Dialer{LocalAddr: nil}
|
||||
}
|
||||
|
||||
return &net.Dialer{
|
||||
Control: func(network, address string, c syscall.RawConn) error {
|
||||
var opErr error
|
||||
|
||||
@@ -3,54 +3,51 @@
|
||||
|
||||
struct go_string { const char *str; long n; };
|
||||
|
||||
extern void SetDNSConfig(struct go_string handle);
|
||||
extern char* ResolveBootstrap(const char* host, int bypass);
|
||||
|
||||
JNIEXPORT void JNICALL Java_com_zaneschepke_tunnel_DnsConfigManager_setDNSConfig(
|
||||
JNIEnv* env, jclass clazz, jstring json)
|
||||
{
|
||||
if (json == NULL) {
|
||||
return;
|
||||
}
|
||||
|
||||
const char* cjson = (*env)->GetStringUTFChars(env, json, 0);
|
||||
if (cjson != NULL) {
|
||||
size_t len = (*env)->GetStringUTFLength(env, json);
|
||||
|
||||
SetDNSConfig((struct go_string){
|
||||
.str = cjson,
|
||||
.n = (long)len
|
||||
});
|
||||
|
||||
(*env)->ReleaseStringUTFChars(env, json, cjson);
|
||||
}
|
||||
}
|
||||
extern char* ResolveBootstrap(
|
||||
const char* host,
|
||||
const char* protocol,
|
||||
const char* upstream,
|
||||
const char* underlyingDnsServers,
|
||||
int bypass);
|
||||
|
||||
JNIEXPORT jstring JNICALL
|
||||
Java_com_zaneschepke_tunnel_DnsConfigManager_resolveBootstrap(
|
||||
JNIEnv* env,
|
||||
jclass clazz,
|
||||
jstring host,
|
||||
jboolean bypass)
|
||||
jstring protocol,
|
||||
jstring upstream,
|
||||
jstring underlyingDnsServers,
|
||||
jint bypass)
|
||||
{
|
||||
if (host == NULL) {
|
||||
return (*env)->NewStringUTF(env, "{\"error\":\"invalid host\"}");
|
||||
if (host == NULL || protocol == NULL || upstream == NULL || underlyingDnsServers == NULL) {
|
||||
return (*env)->NewStringUTF(env, "ERR|invalid arguments");
|
||||
}
|
||||
|
||||
const char* chost = (*env)->GetStringUTFChars(env, host, NULL);
|
||||
if (chost == NULL) {
|
||||
return (*env)->NewStringUTF(env, "{\"error\":\"out of memory\"}");
|
||||
const char* chost = (*env)->GetStringUTFChars(env, host, NULL);
|
||||
const char* cprotocol = (*env)->GetStringUTFChars(env, protocol, NULL);
|
||||
const char* cupstream = (*env)->GetStringUTFChars(env, upstream, NULL);
|
||||
const char* cunderlying = (*env)->GetStringUTFChars(env, underlyingDnsServers, NULL);
|
||||
|
||||
if (chost == NULL || cprotocol == NULL || cupstream == NULL || cunderlying == NULL) {
|
||||
return (*env)->NewStringUTF(env, "ERR|out of memory");
|
||||
}
|
||||
|
||||
char* resultC = ResolveBootstrap(
|
||||
(char*)chost,
|
||||
bypass ? 1 : 0
|
||||
chost,
|
||||
cprotocol,
|
||||
cupstream,
|
||||
cunderlying,
|
||||
bypass ? 1 : 0
|
||||
);
|
||||
|
||||
(*env)->ReleaseStringUTFChars(env, host, chost);
|
||||
(*env)->ReleaseStringUTFChars(env, protocol, cprotocol);
|
||||
(*env)->ReleaseStringUTFChars(env, upstream, cupstream);
|
||||
(*env)->ReleaseStringUTFChars(env, underlyingDnsServers, cunderlying);
|
||||
|
||||
if (resultC == NULL) {
|
||||
return (*env)->NewStringUTF(env, "{\"error\":\"null response\"}");
|
||||
return (*env)->NewStringUTF(env, "ERR|null response");
|
||||
}
|
||||
|
||||
jstring jresult = (*env)->NewStringUTF(env, resultC);
|
||||
|
||||
Reference in New Issue
Block a user