mirror of
https://github.com/wgtunnel/android.git
synced 2026-06-02 00:29:08 +02:00
+6
-1
@@ -32,6 +32,7 @@ import kotlinx.coroutines.sync.withLock
|
|||||||
class TunnelCoordinator(
|
class TunnelCoordinator(
|
||||||
private val tunnelProvider: TunnelProvider,
|
private val tunnelProvider: TunnelProvider,
|
||||||
private val serviceManager: ServiceManager,
|
private val serviceManager: ServiceManager,
|
||||||
|
private val bootstrapCoordinator: AppBoostrapCoordinator,
|
||||||
settingsRepository: GeneralSettingRepository,
|
settingsRepository: GeneralSettingRepository,
|
||||||
private val tunnelRepository: TunnelRepository,
|
private val tunnelRepository: TunnelRepository,
|
||||||
dnsSettingsRepository: RoomDnsSettingsRepository,
|
dnsSettingsRepository: RoomDnsSettingsRepository,
|
||||||
@@ -86,7 +87,11 @@ class TunnelCoordinator(
|
|||||||
suspend fun startTunnel(
|
suspend fun startTunnel(
|
||||||
config: TunnelConfig,
|
config: TunnelConfig,
|
||||||
source: TunnelActionSource = TunnelActionSource.USER,
|
source: TunnelActionSource = TunnelActionSource.USER,
|
||||||
) = tunnelMutex.withLock { startTunnelInternal(config, source) }
|
) = tunnelMutex.withLock {
|
||||||
|
// wait for app to be bootstrapped
|
||||||
|
bootstrapCoordinator.isReady.first { it }
|
||||||
|
startTunnelInternal(config, source)
|
||||||
|
}
|
||||||
|
|
||||||
suspend fun stopTunnel(id: Int, source: TunnelActionSource = TunnelActionSource.USER) =
|
suspend fun stopTunnel(id: Int, source: TunnelActionSource = TunnelActionSource.USER) =
|
||||||
tunnelMutex.withLock {
|
tunnelMutex.withLock {
|
||||||
|
|||||||
@@ -24,6 +24,7 @@ val coordinatorModule = module {
|
|||||||
get(),
|
get(),
|
||||||
get(),
|
get(),
|
||||||
get(),
|
get(),
|
||||||
|
get(),
|
||||||
get(named(Scope.APPLICATION)),
|
get(named(Scope.APPLICATION)),
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -3,25 +3,27 @@ package com.zaneschepke.tunnel
|
|||||||
import com.zaneschepke.tunnel.model.DnsBootstrapResult
|
import com.zaneschepke.tunnel.model.DnsBootstrapResult
|
||||||
import kotlinx.coroutines.Dispatchers
|
import kotlinx.coroutines.Dispatchers
|
||||||
import kotlinx.coroutines.withContext
|
import kotlinx.coroutines.withContext
|
||||||
import org.json.JSONObject
|
|
||||||
|
|
||||||
internal object DnsConfigManager {
|
internal object DnsConfigManager {
|
||||||
private external fun setDNSConfig(configJson: String)
|
|
||||||
|
|
||||||
fun update(protocol: String, upstream: String) {
|
private external fun resolveBootstrap(
|
||||||
val config =
|
host: String,
|
||||||
JSONObject().apply {
|
protocol: String,
|
||||||
put("protocol", protocol)
|
upstream: String,
|
||||||
put("upstream", upstream)
|
underlyingDnsServers: String,
|
||||||
}
|
bypass: Int,
|
||||||
setDNSConfig(config.toString())
|
): String
|
||||||
}
|
|
||||||
|
|
||||||
private external fun resolveBootstrap(host: String, bypass: Int): String
|
suspend fun resolveHostBootstrap(
|
||||||
|
host: String,
|
||||||
suspend fun resolveHostBootstrap(host: String, bypass: Boolean): DnsBootstrapResult =
|
protocol: String,
|
||||||
|
upstream: String,
|
||||||
|
underlyingDnsServers: String,
|
||||||
|
bypass: Boolean,
|
||||||
|
): DnsBootstrapResult =
|
||||||
withContext(Dispatchers.IO) {
|
withContext(Dispatchers.IO) {
|
||||||
val raw = resolveBootstrap(host, if (bypass) 1 else 0)
|
val bypassOption = if (bypass) 1 else 0
|
||||||
|
val raw = resolveBootstrap(host, protocol, upstream, underlyingDnsServers, bypassOption)
|
||||||
|
|
||||||
if (raw.startsWith("ERR|")) {
|
if (raw.startsWith("ERR|")) {
|
||||||
throw RuntimeException(raw.removePrefix("ERR|"))
|
throw RuntimeException(raw.removePrefix("ERR|"))
|
||||||
|
|||||||
@@ -5,6 +5,7 @@ import com.zaneschepke.tunnel.DnsConfigManager
|
|||||||
import com.zaneschepke.tunnel.Tunnel
|
import com.zaneschepke.tunnel.Tunnel
|
||||||
import com.zaneschepke.tunnel.event.ActorEvent
|
import com.zaneschepke.tunnel.event.ActorEvent
|
||||||
import com.zaneschepke.tunnel.event.ActorEvent.ActiveConfigUpdated
|
import com.zaneschepke.tunnel.event.ActorEvent.ActiveConfigUpdated
|
||||||
|
import com.zaneschepke.tunnel.event.ActorEvent.BootstrapConfigUpdated
|
||||||
import com.zaneschepke.tunnel.event.ActorEvent.BootstrapStateChanged
|
import com.zaneschepke.tunnel.event.ActorEvent.BootstrapStateChanged
|
||||||
import com.zaneschepke.tunnel.event.ActorEvent.EngineStatus
|
import com.zaneschepke.tunnel.event.ActorEvent.EngineStatus
|
||||||
import com.zaneschepke.tunnel.event.ActorEvent.KillSwitchStateChanged
|
import com.zaneschepke.tunnel.event.ActorEvent.KillSwitchStateChanged
|
||||||
@@ -12,7 +13,9 @@ import com.zaneschepke.tunnel.event.ActorEvent.PeersUpdated
|
|||||||
import com.zaneschepke.tunnel.event.ActorEvent.ResolvedPeersApplied
|
import com.zaneschepke.tunnel.event.ActorEvent.ResolvedPeersApplied
|
||||||
import com.zaneschepke.tunnel.event.ActorEvent.TunnelStarted
|
import com.zaneschepke.tunnel.event.ActorEvent.TunnelStarted
|
||||||
import com.zaneschepke.tunnel.event.ActorEvent.TunnelStopped
|
import com.zaneschepke.tunnel.event.ActorEvent.TunnelStopped
|
||||||
|
import com.zaneschepke.tunnel.event.ActorEvent.UnderlyingDnsServersUpdated
|
||||||
import com.zaneschepke.tunnel.event.TunnelEvent
|
import com.zaneschepke.tunnel.event.TunnelEvent
|
||||||
|
import com.zaneschepke.tunnel.event.TunnelEvent.NoRootShellAccess
|
||||||
import com.zaneschepke.tunnel.model.BackendMode
|
import com.zaneschepke.tunnel.model.BackendMode
|
||||||
import com.zaneschepke.tunnel.model.DnsBootstrapResult
|
import com.zaneschepke.tunnel.model.DnsBootstrapResult
|
||||||
import com.zaneschepke.tunnel.model.PublicKey
|
import com.zaneschepke.tunnel.model.PublicKey
|
||||||
@@ -122,6 +125,10 @@ internal class TunnelActor(
|
|||||||
engine.stop(runtime.running.handle, runtime.running.mode)
|
engine.stop(runtime.running.handle, runtime.running.mode)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
is TunnelCommand.SetBootstrapConfig -> {
|
||||||
|
apply(BootstrapConfigUpdated(cmd.config))
|
||||||
|
}
|
||||||
|
|
||||||
is TunnelCommand.UpdatePeers -> {
|
is TunnelCommand.UpdatePeers -> {
|
||||||
val runtime = _state.value.byTunnelId[cmd.tunnelId] ?: continue
|
val runtime = _state.value.byTunnelId[cmd.tunnelId] ?: continue
|
||||||
val running = runtime.running
|
val running = runtime.running
|
||||||
@@ -209,9 +216,7 @@ internal class TunnelActor(
|
|||||||
} catch (t: Throwable) {
|
} catch (t: Throwable) {
|
||||||
Timber.w(t, "Root shell commands failed")
|
Timber.w(t, "Root shell commands failed")
|
||||||
if (t is RootShellException.NoRootAccess) {
|
if (t is RootShellException.NoRootAccess) {
|
||||||
_events.emit(
|
_events.emit(NoRootShellAccess(tunnelId = cmd.tunnelId))
|
||||||
TunnelEvent.NoRootShellAccess(tunnelId = cmd.tunnelId)
|
|
||||||
)
|
|
||||||
}
|
}
|
||||||
} finally {
|
} finally {
|
||||||
if (isPostDown) {
|
if (isPostDown) {
|
||||||
@@ -219,6 +224,10 @@ internal class TunnelActor(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
is TunnelCommand.UpdateUnderlyingDnsServers -> {
|
||||||
|
apply(UnderlyingDnsServersUpdated(cmd.servers))
|
||||||
|
}
|
||||||
}
|
}
|
||||||
} catch (t: Throwable) {
|
} catch (t: Throwable) {
|
||||||
Timber.e(t, "Tunnel command failed: $cmd")
|
Timber.e(t, "Tunnel command failed: $cmd")
|
||||||
@@ -621,6 +630,14 @@ internal class TunnelActor(
|
|||||||
is KillSwitchStateChanged -> {
|
is KillSwitchStateChanged -> {
|
||||||
state.copy(killSwitchEnabled = event.enabled)
|
state.copy(killSwitchEnabled = event.enabled)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
is BootstrapConfigUpdated -> {
|
||||||
|
state.copy(dnsConfig = event.config)
|
||||||
|
}
|
||||||
|
|
||||||
|
is UnderlyingDnsServersUpdated -> {
|
||||||
|
state.copy(dnsConfig = state.dnsConfig.copy(underlyingDnsServers = event.servers))
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -667,9 +684,17 @@ internal class TunnelActor(
|
|||||||
val endpoint = peer.endpoint ?: continue
|
val endpoint = peer.endpoint ?: continue
|
||||||
val host = endpoint.substringBeforeLast(":")
|
val host = endpoint.substringBeforeLast(":")
|
||||||
|
|
||||||
|
val dnsConfig = state.value.dnsConfig
|
||||||
|
|
||||||
val dnsResult =
|
val dnsResult =
|
||||||
try {
|
try {
|
||||||
DnsConfigManager.resolveHostBootstrap(host = host, bypass = bypassNeeded)
|
DnsConfigManager.resolveHostBootstrap(
|
||||||
|
host = host,
|
||||||
|
dnsConfig.protocol,
|
||||||
|
dnsConfig.upstream,
|
||||||
|
dnsConfig.underlyingDnsServers,
|
||||||
|
bypass = bypassNeeded,
|
||||||
|
)
|
||||||
} catch (e: Exception) {
|
} catch (e: Exception) {
|
||||||
Timber.w(e, "DNS failed for $host")
|
Timber.w(e, "DNS failed for $host")
|
||||||
continue
|
continue
|
||||||
|
|||||||
@@ -1,8 +1,9 @@
|
|||||||
package com.zaneschepke.tunnel.backend
|
package com.zaneschepke.tunnel.backend
|
||||||
|
|
||||||
|
import com.zaneschepke.networkmonitor.ActiveNetwork
|
||||||
|
import com.zaneschepke.networkmonitor.DnsInfo
|
||||||
import com.zaneschepke.networkmonitor.NetworkMonitor
|
import com.zaneschepke.networkmonitor.NetworkMonitor
|
||||||
import com.zaneschepke.networkmonitor.PrivateDnsMode
|
import com.zaneschepke.networkmonitor.PrivateDnsMode
|
||||||
import com.zaneschepke.tunnel.DnsConfigManager
|
|
||||||
import com.zaneschepke.tunnel.NotificationProvider
|
import com.zaneschepke.tunnel.NotificationProvider
|
||||||
import com.zaneschepke.tunnel.Tunnel
|
import com.zaneschepke.tunnel.Tunnel
|
||||||
import com.zaneschepke.tunnel.event.TunnelEvent
|
import com.zaneschepke.tunnel.event.TunnelEvent
|
||||||
@@ -14,6 +15,7 @@ import com.zaneschepke.tunnel.model.TunnelCommand
|
|||||||
import com.zaneschepke.tunnel.service.VpnService
|
import com.zaneschepke.tunnel.service.VpnService
|
||||||
import com.zaneschepke.tunnel.state.BackendStatus
|
import com.zaneschepke.tunnel.state.BackendStatus
|
||||||
import com.zaneschepke.tunnel.state.KillSwitchState
|
import com.zaneschepke.tunnel.state.KillSwitchState
|
||||||
|
import com.zaneschepke.tunnel.state.RuntimeDnsConfig
|
||||||
import com.zaneschepke.tunnel.util.BackendException
|
import com.zaneschepke.tunnel.util.BackendException
|
||||||
import java.lang.ref.WeakReference
|
import java.lang.ref.WeakReference
|
||||||
import kotlin.reflect.KClass
|
import kotlin.reflect.KClass
|
||||||
@@ -27,8 +29,10 @@ import kotlinx.coroutines.flow.asSharedFlow
|
|||||||
import kotlinx.coroutines.flow.asStateFlow
|
import kotlinx.coroutines.flow.asStateFlow
|
||||||
import kotlinx.coroutines.flow.distinctUntilChangedBy
|
import kotlinx.coroutines.flow.distinctUntilChangedBy
|
||||||
import kotlinx.coroutines.flow.first
|
import kotlinx.coroutines.flow.first
|
||||||
|
import kotlinx.coroutines.flow.firstOrNull
|
||||||
import kotlinx.coroutines.flow.update
|
import kotlinx.coroutines.flow.update
|
||||||
import kotlinx.coroutines.launch
|
import kotlinx.coroutines.launch
|
||||||
|
import kotlinx.coroutines.withTimeoutOrNull
|
||||||
import org.koin.java.KoinJavaComponent.inject
|
import org.koin.java.KoinJavaComponent.inject
|
||||||
import timber.log.Timber
|
import timber.log.Timber
|
||||||
|
|
||||||
@@ -191,24 +195,90 @@ class TunnelBackend(
|
|||||||
override suspend fun setBootstrapDnsMode(mode: DnsBoostrapMode) {
|
override suspend fun setBootstrapDnsMode(mode: DnsBoostrapMode) {
|
||||||
_status.update { it.copy(dnsMode = mode) }
|
_status.update { it.copy(dnsMode = mode) }
|
||||||
|
|
||||||
|
dnsConfigJob?.cancel()
|
||||||
|
dnsConfigJob = null
|
||||||
|
|
||||||
when (mode) {
|
when (mode) {
|
||||||
is DnsBoostrapMode.Custom -> {
|
is DnsBoostrapMode.Custom -> {
|
||||||
Timber.d("DNS Boostrap mode set to custom, disabling system dns monitoring")
|
Timber.d(
|
||||||
dnsConfigJob?.cancel()
|
"DNS Bootstrap mode set to custom: ${mode.config.protocol} -> ${mode.config.upstream}"
|
||||||
dnsConfigJob = null
|
|
||||||
|
|
||||||
DnsConfigManager.update(
|
|
||||||
mode.config.protocol,
|
|
||||||
mode.config.upstream ?: DnsBoostrapConfig.DEFAULT_UPSTREAM,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
emitInitialDnsConfig(mode.config.protocol, mode.config.upstream)
|
||||||
|
startUnderlyingDnsMonitoring()
|
||||||
}
|
}
|
||||||
|
|
||||||
DnsBoostrapMode.System -> {
|
DnsBoostrapMode.System -> {
|
||||||
|
Timber.d("DNS Bootstrap mode set to System")
|
||||||
|
emitInitialDnsConfig()
|
||||||
startSystemDnsMonitoring()
|
startSystemDnsMonitoring()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
private suspend fun emitInitialDnsConfig(protocol: String? = null, upstream: String? = null) {
|
||||||
|
val state =
|
||||||
|
withTimeoutOrNull(2_500L) {
|
||||||
|
networkMonitor.connectivityStateFlow.first { connectivityState ->
|
||||||
|
val dns = connectivityState.underlyingDnsInfo
|
||||||
|
dns.servers.isNotEmpty() ||
|
||||||
|
connectivityState.activeNetwork is ActiveNetwork.Disconnected
|
||||||
|
}
|
||||||
|
} ?: networkMonitor.connectivityStateFlow.firstOrNull()
|
||||||
|
|
||||||
|
val dns = state?.underlyingDnsInfo ?: DnsInfo()
|
||||||
|
|
||||||
|
val finalProtocol =
|
||||||
|
protocol
|
||||||
|
?: when (dns.privateDnsMode) {
|
||||||
|
PrivateDnsMode.HOSTNAME -> "dot"
|
||||||
|
else -> "plain"
|
||||||
|
}
|
||||||
|
|
||||||
|
val finalUpstream =
|
||||||
|
upstream
|
||||||
|
?: when (dns.privateDnsMode) {
|
||||||
|
PrivateDnsMode.HOSTNAME ->
|
||||||
|
dns.privateDnsHostname?.takeIf { it.isNotBlank() }
|
||||||
|
?: DnsBoostrapConfig.DEFAULT_UPSTREAM
|
||||||
|
else -> dns.servers.firstOrNull() ?: DnsBoostrapConfig.DEFAULT_UPSTREAM
|
||||||
|
}
|
||||||
|
|
||||||
|
val underlying =
|
||||||
|
dns.servers.joinToString(",").ifBlank { DnsBoostrapConfig.DEFAULT_UPSTREAM }
|
||||||
|
|
||||||
|
Timber.d(
|
||||||
|
"DNS initial emission: protocol=$finalProtocol upstream=$finalUpstream underlying=$underlying"
|
||||||
|
)
|
||||||
|
|
||||||
|
actor.send(
|
||||||
|
TunnelCommand.SetBootstrapConfig(
|
||||||
|
RuntimeDnsConfig(
|
||||||
|
protocol = finalProtocol,
|
||||||
|
upstream = finalUpstream,
|
||||||
|
underlyingDnsServers = underlying,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
private fun startUnderlyingDnsMonitoring() {
|
||||||
|
if (dnsConfigJob?.isActive == true) return
|
||||||
|
|
||||||
|
dnsConfigJob = scope.launch {
|
||||||
|
networkMonitor.connectivityStateFlow
|
||||||
|
.distinctUntilChangedBy { it.underlyingDnsInfo.servers }
|
||||||
|
.collect { state ->
|
||||||
|
val dns = state.underlyingDnsInfo
|
||||||
|
val underlying = dns.servers.joinToString(",")
|
||||||
|
|
||||||
|
Timber.d("Underlying DNS servers changed: $underlying")
|
||||||
|
|
||||||
|
actor.send(TunnelCommand.UpdateUnderlyingDnsServers(underlying))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
override fun emergencyStopAllOfTypeSync(modeClass: KClass<out BackendMode>) {
|
override fun emergencyStopAllOfTypeSync(modeClass: KClass<out BackendMode>) {
|
||||||
actor.emergencyStopAllOfType(modeClass)
|
actor.emergencyStopAllOfType(modeClass)
|
||||||
}
|
}
|
||||||
@@ -236,13 +306,12 @@ class TunnelBackend(
|
|||||||
val config =
|
val config =
|
||||||
when (dns.privateDnsMode) {
|
when (dns.privateDnsMode) {
|
||||||
PrivateDnsMode.OFF,
|
PrivateDnsMode.OFF,
|
||||||
PrivateDnsMode.AUTOMATIC -> {
|
PrivateDnsMode.AUTOMATIC ->
|
||||||
DnsBoostrapConfig.Plain(
|
DnsBoostrapConfig.Plain(
|
||||||
dns.servers.firstOrNull() ?: DnsBoostrapConfig.DEFAULT_UPSTREAM
|
dns.servers.firstOrNull() ?: DnsBoostrapConfig.DEFAULT_UPSTREAM
|
||||||
)
|
)
|
||||||
}
|
|
||||||
|
|
||||||
PrivateDnsMode.HOSTNAME -> {
|
PrivateDnsMode.HOSTNAME ->
|
||||||
dns.privateDnsHostname
|
dns.privateDnsHostname
|
||||||
?.takeIf { it.isNotBlank() }
|
?.takeIf { it.isNotBlank() }
|
||||||
?.let { DnsBoostrapConfig.DoT(it) }
|
?.let { DnsBoostrapConfig.DoT(it) }
|
||||||
@@ -250,12 +319,16 @@ class TunnelBackend(
|
|||||||
dns.servers.firstOrNull()
|
dns.servers.firstOrNull()
|
||||||
?: DnsBoostrapConfig.DEFAULT_UPSTREAM
|
?: DnsBoostrapConfig.DEFAULT_UPSTREAM
|
||||||
)
|
)
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
DnsConfigManager.update(
|
actor.send(
|
||||||
config.protocol,
|
TunnelCommand.SetBootstrapConfig(
|
||||||
config.upstream ?: DnsBoostrapConfig.DEFAULT_UPSTREAM,
|
RuntimeDnsConfig(
|
||||||
|
protocol = config.protocol,
|
||||||
|
upstream = config.upstream ?: DnsBoostrapConfig.DEFAULT_UPSTREAM,
|
||||||
|
underlyingDnsServers = dns.servers.joinToString(","),
|
||||||
|
)
|
||||||
|
)
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -6,6 +6,7 @@ import com.zaneschepke.tunnel.model.TunnelCommand
|
|||||||
import com.zaneschepke.tunnel.state.BootstrapState
|
import com.zaneschepke.tunnel.state.BootstrapState
|
||||||
import com.zaneschepke.tunnel.state.EngineStartResult
|
import com.zaneschepke.tunnel.state.EngineStartResult
|
||||||
import com.zaneschepke.tunnel.state.NativeTunnelStatus
|
import com.zaneschepke.tunnel.state.NativeTunnelStatus
|
||||||
|
import com.zaneschepke.tunnel.state.RuntimeDnsConfig
|
||||||
import com.zaneschepke.wireguardautotunnel.parser.ActiveConfig
|
import com.zaneschepke.wireguardautotunnel.parser.ActiveConfig
|
||||||
import com.zaneschepke.wireguardautotunnel.parser.PeerSection
|
import com.zaneschepke.wireguardautotunnel.parser.PeerSection
|
||||||
|
|
||||||
@@ -23,6 +24,10 @@ sealed class ActorEvent {
|
|||||||
val preferIpv6: Boolean,
|
val preferIpv6: Boolean,
|
||||||
) : ActorEvent()
|
) : ActorEvent()
|
||||||
|
|
||||||
|
data class BootstrapConfigUpdated(val config: RuntimeDnsConfig) : ActorEvent()
|
||||||
|
|
||||||
|
data class UnderlyingDnsServersUpdated(val servers: String) : ActorEvent()
|
||||||
|
|
||||||
data class ResolvedPeersApplied(
|
data class ResolvedPeersApplied(
|
||||||
val tunnelId: Int,
|
val tunnelId: Int,
|
||||||
val cache: Map<PublicKey, DnsBootstrapResult>,
|
val cache: Map<PublicKey, DnsBootstrapResult>,
|
||||||
|
|||||||
@@ -17,7 +17,7 @@ sealed class DnsBoostrapConfig(open val upstream: String?) {
|
|||||||
|
|
||||||
data class DoH(override val upstream: String?) : DnsBoostrapConfig(upstream) {
|
data class DoH(override val upstream: String?) : DnsBoostrapConfig(upstream) {
|
||||||
override val protocol: String
|
override val protocol: String
|
||||||
get() = "dot"
|
get() = "doh"
|
||||||
}
|
}
|
||||||
|
|
||||||
data class DoT(override val upstream: String?) : DnsBoostrapConfig(upstream) {
|
data class DoT(override val upstream: String?) : DnsBoostrapConfig(upstream) {
|
||||||
|
|||||||
@@ -2,6 +2,7 @@ package com.zaneschepke.tunnel.model
|
|||||||
|
|
||||||
import com.zaneschepke.tunnel.Tunnel
|
import com.zaneschepke.tunnel.Tunnel
|
||||||
import com.zaneschepke.tunnel.state.BootstrapState
|
import com.zaneschepke.tunnel.state.BootstrapState
|
||||||
|
import com.zaneschepke.tunnel.state.RuntimeDnsConfig
|
||||||
import com.zaneschepke.wireguardautotunnel.parser.PeerSection
|
import com.zaneschepke.wireguardautotunnel.parser.PeerSection
|
||||||
|
|
||||||
sealed class TunnelCommand {
|
sealed class TunnelCommand {
|
||||||
@@ -24,6 +25,10 @@ sealed class TunnelCommand {
|
|||||||
|
|
||||||
data class SetBootstrapState(val tunnelId: Int, val state: BootstrapState) : TunnelCommand()
|
data class SetBootstrapState(val tunnelId: Int, val state: BootstrapState) : TunnelCommand()
|
||||||
|
|
||||||
|
data class SetBootstrapConfig(val config: RuntimeDnsConfig) : TunnelCommand()
|
||||||
|
|
||||||
|
data class UpdateUnderlyingDnsServers(val servers: String) : TunnelCommand()
|
||||||
|
|
||||||
data class RunHook(val tunnelId: Int, val phase: Phase, val cmds: List<String>?) :
|
data class RunHook(val tunnelId: Int, val phase: Phase, val cmds: List<String>?) :
|
||||||
TunnelCommand() {
|
TunnelCommand() {
|
||||||
enum class Phase {
|
enum class Phase {
|
||||||
|
|||||||
@@ -4,4 +4,5 @@ data class ActorState(
|
|||||||
val byTunnelId: Map<Int, TunnelRuntimeState>,
|
val byTunnelId: Map<Int, TunnelRuntimeState>,
|
||||||
val byHandle: Map<Int, Int>,
|
val byHandle: Map<Int, Int>,
|
||||||
val killSwitchEnabled: Boolean = false,
|
val killSwitchEnabled: Boolean = false,
|
||||||
|
val dnsConfig: RuntimeDnsConfig = RuntimeDnsConfig(),
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -0,0 +1,9 @@
|
|||||||
|
package com.zaneschepke.tunnel.state
|
||||||
|
|
||||||
|
import com.zaneschepke.tunnel.model.DnsBoostrapConfig
|
||||||
|
|
||||||
|
data class RuntimeDnsConfig(
|
||||||
|
val protocol: String = "plain",
|
||||||
|
val upstream: String = DnsBoostrapConfig.DEFAULT_UPSTREAM,
|
||||||
|
val underlyingDnsServers: String = DnsBoostrapConfig.DEFAULT_UPSTREAM,
|
||||||
|
)
|
||||||
+286
-322
@@ -9,7 +9,6 @@ import (
|
|||||||
"bytes"
|
"bytes"
|
||||||
"context"
|
"context"
|
||||||
"crypto/tls"
|
"crypto/tls"
|
||||||
"encoding/json"
|
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
@@ -18,7 +17,6 @@ import (
|
|||||||
"net/netip"
|
"net/netip"
|
||||||
"net/url"
|
"net/url"
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
|
||||||
"syscall"
|
"syscall"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
@@ -27,21 +25,6 @@ import (
|
|||||||
"golang.org/x/sys/unix"
|
"golang.org/x/sys/unix"
|
||||||
)
|
)
|
||||||
|
|
||||||
const defaultPlain = "udp://1.1.1.1:53"
|
|
||||||
|
|
||||||
var (
|
|
||||||
currentConfig DNSConfig = DNSConfig{
|
|
||||||
"plain",
|
|
||||||
"1.1.1.1:53",
|
|
||||||
}
|
|
||||||
configMu sync.RWMutex
|
|
||||||
)
|
|
||||||
|
|
||||||
type DNSConfig struct {
|
|
||||||
Protocol string `json:"protocol"` // plain, doh, or dot
|
|
||||||
Upstream string `json:"upstream"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type Resolved struct {
|
type Resolved struct {
|
||||||
V4 []netip.Addr
|
V4 []netip.Addr
|
||||||
V6 []netip.Addr
|
V6 []netip.Addr
|
||||||
@@ -52,36 +35,34 @@ type ResolverOptions struct {
|
|||||||
Timeout time.Duration
|
Timeout time.Duration
|
||||||
}
|
}
|
||||||
|
|
||||||
func DefaultOptions() ResolverOptions {
|
type Transport interface {
|
||||||
return ResolverOptions{
|
Query(ctx context.Context, msg *dns.Msg) (*dns.Msg, error)
|
||||||
UpstreamURL: defaultPlain,
|
|
||||||
Timeout: 5 * time.Second,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
//export SetDNSConfig
|
|
||||||
func SetDNSConfig(config string) {
|
|
||||||
var cfg DNSConfig
|
|
||||||
if err := json.Unmarshal([]byte(config), &cfg); err != nil {
|
|
||||||
shared.LogError("DNS", "Failed to parse DNSConfig: %v", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if cfg.Protocol != "plain" && cfg.Protocol != "doh" && cfg.Protocol != "dot" {
|
|
||||||
cfg.Protocol = "plain"
|
|
||||||
}
|
|
||||||
configMu.Lock()
|
|
||||||
currentConfig = cfg
|
|
||||||
configMu.Unlock()
|
|
||||||
shared.LogDebug("DNS", "DNS config updated: %s %s", cfg.Protocol, cfg.Upstream)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
//export ResolveBootstrap
|
//export ResolveBootstrap
|
||||||
func ResolveBootstrap(host *C.char, bypass C.int) *C.char {
|
func ResolveBootstrap(
|
||||||
|
host *C.char,
|
||||||
|
protocol *C.char,
|
||||||
|
upstream *C.char,
|
||||||
|
underlyingDnsServers *C.char,
|
||||||
|
bypass C.int,
|
||||||
|
) *C.char {
|
||||||
h := C.GoString(host)
|
h := C.GoString(host)
|
||||||
|
p := C.GoString(protocol)
|
||||||
|
u := C.GoString(upstream)
|
||||||
|
underlying := C.GoString(underlyingDnsServers)
|
||||||
bp := bypass == 1
|
bp := bypass == 1
|
||||||
shared.LogDebug("DNS", "ResolveBootstrap called for host=%s (bypass=%t)", h, bp)
|
|
||||||
|
|
||||||
v4, v6, err := Resolve(h, bp)
|
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
shared.LogDebug(
|
||||||
|
"DNS",
|
||||||
|
"ResolveBootstrap called host=%s protocol=%s upstream=%s bypass=%t",
|
||||||
|
h, p, u, bp,
|
||||||
|
)
|
||||||
|
|
||||||
|
v4, v6, err := Resolve(ctx, h, p, u, bp, underlying)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
shared.LogError("DNS", "ResolveBootstrap failed for %s: %v", h, err)
|
shared.LogError("DNS", "ResolveBootstrap failed for %s: %v", h, err)
|
||||||
return C.CString("ERR|" + err.Error())
|
return C.CString("ERR|" + err.Error())
|
||||||
@@ -96,15 +77,76 @@ func ResolveBootstrap(host *C.char, bypass C.int) *C.char {
|
|||||||
v6Str[i] = ip.String()
|
v6Str[i] = ip.String()
|
||||||
}
|
}
|
||||||
|
|
||||||
result := "v4=" + strings.Join(v4Str, ",") + ";v6=" + strings.Join(v6Str, ",")
|
result := "v4=" + strings.Join(v4Str, ",") +
|
||||||
|
";v6=" + strings.Join(v6Str, ",")
|
||||||
|
|
||||||
shared.LogDebug("DNS", "ResolveBootstrap success for %s: %s", h, result)
|
shared.LogDebug("DNS", "ResolveBootstrap success for %s: %s", h, result)
|
||||||
return C.CString(result)
|
return C.CString(result)
|
||||||
}
|
}
|
||||||
|
|
||||||
func getConfig() DNSConfig {
|
type DoTTransport struct {
|
||||||
configMu.RLock()
|
Client *dns.Client
|
||||||
defer configMu.RUnlock()
|
Servers []string
|
||||||
return currentConfig
|
}
|
||||||
|
|
||||||
|
type DoHTransport struct {
|
||||||
|
Client *http.Client
|
||||||
|
URL string
|
||||||
|
Servers []string // IPv4 first, IPv6 fallback
|
||||||
|
Hostname string // for SNI and Host header
|
||||||
|
}
|
||||||
|
|
||||||
|
type PlainTransport struct {
|
||||||
|
Client *dns.Client
|
||||||
|
Servers []string
|
||||||
|
}
|
||||||
|
|
||||||
|
func resolveHost(
|
||||||
|
ctx context.Context,
|
||||||
|
t Transport,
|
||||||
|
host string,
|
||||||
|
) (v4, v6 []netip.Addr, err error) {
|
||||||
|
a4, e4 := resolveQ(ctx, t, host, dns.TypeA)
|
||||||
|
if e4 == nil {
|
||||||
|
v4 = a4
|
||||||
|
}
|
||||||
|
a6, e6 := resolveQ(ctx, t, host, dns.TypeAAAA)
|
||||||
|
if e6 == nil {
|
||||||
|
v6 = a6
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(v4) > 0 || len(v6) > 0 {
|
||||||
|
return v4, v6, nil
|
||||||
|
}
|
||||||
|
return nil, nil, errors.Join(e4, e6)
|
||||||
|
}
|
||||||
|
|
||||||
|
func resolveQ(
|
||||||
|
ctx context.Context,
|
||||||
|
t Transport,
|
||||||
|
host string,
|
||||||
|
qtype uint16,
|
||||||
|
) ([]netip.Addr, error) {
|
||||||
|
req := &dns.Msg{}
|
||||||
|
req.SetQuestion(dns.Fqdn(host), qtype)
|
||||||
|
req.SetEdns0(4096, true)
|
||||||
|
|
||||||
|
res, err := t.Query(ctx, req)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if res == nil {
|
||||||
|
return nil, fmt.Errorf("nil DNS response")
|
||||||
|
}
|
||||||
|
if res.Rcode != dns.RcodeSuccess {
|
||||||
|
return nil, fmt.Errorf("rcode %d", res.Rcode)
|
||||||
|
}
|
||||||
|
|
||||||
|
addrs := parseDNSAnswers(res, qtype)
|
||||||
|
if len(addrs) == 0 {
|
||||||
|
return nil, fmt.Errorf("no answers for qtype %d", qtype)
|
||||||
|
}
|
||||||
|
return addrs, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func parseUpstream(upstreamURL string) (network, address string, err error) {
|
func parseUpstream(upstreamURL string) (network, address string, err error) {
|
||||||
@@ -113,7 +155,6 @@ func parseUpstream(upstreamURL string) (network, address string, err error) {
|
|||||||
if !strings.Contains(u, "://") {
|
if !strings.Contains(u, "://") {
|
||||||
u = "udp://" + u
|
u = "udp://" + u
|
||||||
}
|
}
|
||||||
|
|
||||||
parsed, err := url.Parse(u)
|
parsed, err := url.Parse(u)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
shared.LogError("DNS", "parseUpstream failed for %q: %v", upstreamURL, err)
|
shared.LogError("DNS", "parseUpstream failed for %q: %v", upstreamURL, err)
|
||||||
@@ -141,189 +182,127 @@ func parseUpstream(upstreamURL string) (network, address string, err error) {
|
|||||||
return network, address, nil
|
return network, address, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func resolveServerAddr(ctx context.Context, address string, bypass bool) (string, error) {
|
func newUnderlyingResolver(bypass bool, underlying string) *net.Resolver {
|
||||||
|
if !bypass {
|
||||||
|
return &net.Resolver{PreferGo: false}
|
||||||
|
}
|
||||||
|
|
||||||
|
rawServers := strings.Split(underlying, ",")
|
||||||
|
var servers []string
|
||||||
|
|
||||||
|
for _, s := range rawServers {
|
||||||
|
s = strings.TrimSpace(s)
|
||||||
|
if s == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if !strings.Contains(s, ":") {
|
||||||
|
s = net.JoinHostPort(s, "53")
|
||||||
|
}
|
||||||
|
servers = append(servers, s)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(servers) == 0 {
|
||||||
|
servers = []string{"1.1.1.1:53"}
|
||||||
|
}
|
||||||
|
|
||||||
|
return &net.Resolver{
|
||||||
|
PreferGo: true,
|
||||||
|
Dial: func(ctx context.Context, network, _ string) (net.Conn, error) {
|
||||||
|
for _, server := range servers {
|
||||||
|
conn, err := GetDialer(true).DialContext(ctx, network, server)
|
||||||
|
if err == nil {
|
||||||
|
shared.LogDebug("DNS", "Using underlying bootstrap resolver: %s", server)
|
||||||
|
return conn, nil
|
||||||
|
}
|
||||||
|
shared.LogDebug("DNS", "Bootstrap resolver failed for %s: %v", server, err)
|
||||||
|
}
|
||||||
|
return nil, fmt.Errorf("all underlying DNS servers failed")
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func resolveServerAddrs(
|
||||||
|
ctx context.Context,
|
||||||
|
address string,
|
||||||
|
bypass bool,
|
||||||
|
defaultPort string,
|
||||||
|
underlying string,
|
||||||
|
) ([]string, string, error) {
|
||||||
host, port, err := net.SplitHostPort(address)
|
host, port, err := net.SplitHostPort(address)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
shared.LogError("DNS", "resolveServerAddr: invalid address %q: %v", address, err)
|
host = address
|
||||||
return "", err
|
port = defaultPort
|
||||||
}
|
}
|
||||||
|
|
||||||
if net.ParseIP(host) != nil {
|
if net.ParseIP(host) != nil {
|
||||||
return address, nil
|
return []string{net.JoinHostPort(host, port)}, host, nil
|
||||||
}
|
|
||||||
|
|
||||||
shared.LogDebug("DNS", "resolveServerAddr: bootstrapping upstream hostname %s (bypass=%t)", host, bypass)
|
|
||||||
|
|
||||||
bootstrapDialer := GetDialer(bypass)
|
|
||||||
resolver := &net.Resolver{
|
|
||||||
PreferGo: true,
|
|
||||||
Dial: func(ctx context.Context, network, _ string) (net.Conn, error) {
|
|
||||||
return bootstrapDialer.DialContext(ctx, network, "1.1.1.1:53")
|
|
||||||
},
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
resolver := newUnderlyingResolver(bypass, underlying)
|
||||||
ips, err := resolver.LookupIP(ctx, "ip", host)
|
ips, err := resolver.LookupIP(ctx, "ip", host)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
shared.LogError("DNS", "Failed to resolve upstream hostname %s (bypass=%t): %v", host, bypass, err)
|
shared.LogError("DNS", "Failed to resolve upstream %s (bypass=%t): %v", host, bypass, err)
|
||||||
return "", fmt.Errorf("failed to resolve upstream hostname %s: %w", host, err)
|
return nil, "", err
|
||||||
}
|
|
||||||
if len(ips) == 0 {
|
|
||||||
err = errors.New("no IPs found for upstream hostname")
|
|
||||||
shared.LogError("DNS", "%v for %s", err, host)
|
|
||||||
return "", err
|
|
||||||
}
|
}
|
||||||
|
|
||||||
addr := net.JoinHostPort(ips[0].String(), port)
|
var v4, v6 []string
|
||||||
shared.LogDebug("DNS", "Resolved upstream %s -> %s", host, addr)
|
for _, ip := range ips {
|
||||||
return addr, nil
|
addr := net.JoinHostPort(ip.String(), port)
|
||||||
}
|
if ip.To4() != nil {
|
||||||
func resolveInner(host string, ipType uint16, network, serverAddr string, bypass bool) ([]netip.Addr, error) {
|
v4 = append(v4, addr)
|
||||||
req := &dns.Msg{}
|
} else {
|
||||||
req.Id = dns.Id()
|
v6 = append(v6, addr)
|
||||||
req.RecursionDesired = true
|
|
||||||
req.SetQuestion(dns.Fqdn(host), ipType)
|
|
||||||
req.SetEdns0(4096, true)
|
|
||||||
|
|
||||||
client := &dns.Client{
|
|
||||||
Net: network,
|
|
||||||
Dialer: GetDialer(bypass),
|
|
||||||
Timeout: 5 * time.Second,
|
|
||||||
UDPSize: 4096,
|
|
||||||
}
|
|
||||||
|
|
||||||
res, _, err := client.Exchange(req, serverAddr)
|
|
||||||
if err != nil {
|
|
||||||
shared.LogError("DNS", "resolveInner: DNS exchange failed for %s (type=%d, server=%s, bypass=%t): %v", host, ipType, serverAddr, bypass, err)
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
if res.Rcode != dns.RcodeSuccess {
|
|
||||||
shared.LogError("DNS", "resolveInner: DNS query failed with Rcode %d for %s", res.Rcode, host)
|
|
||||||
return nil, fmt.Errorf("DNS query failed with Rcode: %d", res.Rcode)
|
|
||||||
}
|
|
||||||
|
|
||||||
var addr []netip.Addr
|
|
||||||
for _, ans := range res.Answer {
|
|
||||||
switch ipType {
|
|
||||||
case dns.TypeA:
|
|
||||||
if a, ok := ans.(*dns.A); ok {
|
|
||||||
if ip, err := netip.ParseAddr(a.A.String()); err == nil {
|
|
||||||
addr = append(addr, ip)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
case dns.TypeAAAA:
|
|
||||||
if aaaa, ok := ans.(*dns.AAAA); ok {
|
|
||||||
if ip, err := netip.ParseAddr(aaaa.AAAA.String()); err == nil {
|
|
||||||
addr = append(addr, ip)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return addr, nil
|
|
||||||
|
return append(v4, v6...), host, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func resolvePlain(host, upstreamURL string, bypass bool) ([]netip.Addr, []netip.Addr, error) {
|
func (t PlainTransport) Query(ctx context.Context, msg *dns.Msg) (*dns.Msg, error) {
|
||||||
shared.LogDebug("DNS", "resolvePlain: %s with upstream=%s (bypass=%t)", host, upstreamURL, bypass)
|
for _, server := range t.Servers {
|
||||||
network, addr, err := parseUpstream(upstreamURL)
|
m, _, err := t.Client.Exchange(msg, server)
|
||||||
if err != nil {
|
if err == nil && m != nil && m.Rcode == dns.RcodeSuccess {
|
||||||
return nil, nil, err
|
return m, nil
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
return nil, fmt.Errorf("all DNS servers failed")
|
||||||
serverAddr, err := resolveServerAddr(context.Background(), addr, bypass)
|
|
||||||
if err != nil {
|
|
||||||
return nil, nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
var wg sync.WaitGroup
|
|
||||||
var v4, v6 []netip.Addr
|
|
||||||
var v4Err, v6Err error
|
|
||||||
wg.Add(2)
|
|
||||||
go func() { v4, v4Err = resolveInner(host, dns.TypeA, network, serverAddr, bypass); wg.Done() }()
|
|
||||||
go func() { v6, v6Err = resolveInner(host, dns.TypeAAAA, network, serverAddr, bypass); wg.Done() }()
|
|
||||||
wg.Wait()
|
|
||||||
|
|
||||||
if v4Err != nil && v6Err != nil {
|
|
||||||
shared.LogError("DNS", "resolvePlain failed for %s: both A and AAAA failed", host)
|
|
||||||
return nil, nil, errors.Join(v4Err, v6Err)
|
|
||||||
}
|
|
||||||
if len(v4) == 0 && len(v6) == 0 {
|
|
||||||
err = errors.New("no IP addresses found")
|
|
||||||
shared.LogError("DNS", "%v for %s", err, host)
|
|
||||||
return nil, nil, err
|
|
||||||
}
|
|
||||||
return v4, v6, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func resolveDoH(host, dohURL string, bypass bool) ([]netip.Addr, []netip.Addr, error) {
|
func (t DoTTransport) Query(ctx context.Context, msg *dns.Msg) (*dns.Msg, error) {
|
||||||
shared.LogDebug("DNS", "Resolving DOH: %s with %s", host, dohURL)
|
for _, server := range t.Servers {
|
||||||
var v4, v6 []netip.Addr
|
m, _, err := t.Client.Exchange(msg, server)
|
||||||
var v4Err, v6Err error
|
if err == nil && m != nil && m.Rcode == dns.RcodeSuccess {
|
||||||
|
return m, nil
|
||||||
var wg sync.WaitGroup
|
}
|
||||||
wg.Add(2)
|
|
||||||
|
|
||||||
go func() {
|
|
||||||
defer wg.Done()
|
|
||||||
v4, v4Err = doDoHQuery(dohURL, host, dns.TypeA, bypass)
|
|
||||||
}()
|
|
||||||
go func() {
|
|
||||||
defer wg.Done()
|
|
||||||
v6, v6Err = doDoHQuery(dohURL, host, dns.TypeAAAA, bypass)
|
|
||||||
}()
|
|
||||||
wg.Wait()
|
|
||||||
|
|
||||||
if v4Err != nil && v6Err != nil {
|
|
||||||
shared.LogError("DNS", "resolveDoH failed for %s: both queries failed", host)
|
|
||||||
return nil, nil, errors.Join(v4Err, v6Err)
|
|
||||||
}
|
}
|
||||||
return v4, v6, nil
|
return nil, fmt.Errorf("all DoT servers failed")
|
||||||
}
|
}
|
||||||
|
|
||||||
func doDoHQuery(dohURL, host string, qtype uint16, bypass bool) ([]netip.Addr, error) {
|
func (t DoHTransport) Query(ctx context.Context, msg *dns.Msg) (*dns.Msg, error) {
|
||||||
req := &dns.Msg{}
|
wire, err := msg.Pack()
|
||||||
req.Id = dns.Id()
|
|
||||||
req.RecursionDesired = true
|
|
||||||
req.SetEdns0(4096, true)
|
|
||||||
req.SetQuestion(dns.Fqdn(host), qtype)
|
|
||||||
|
|
||||||
transport := &http.Transport{
|
|
||||||
DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
|
|
||||||
h, port, _ := net.SplitHostPort(addr)
|
|
||||||
if net.ParseIP(h) == nil {
|
|
||||||
ips, err := CustomResolver(bypass).LookupIP(ctx, "ip", h)
|
|
||||||
if err == nil && len(ips) > 0 {
|
|
||||||
h = ips[0].String()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return GetDialer(bypass).DialContext(ctx, network, net.JoinHostPort(h, port))
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
client := &http.Client{
|
|
||||||
Transport: transport,
|
|
||||||
Timeout: 5 * time.Second,
|
|
||||||
}
|
|
||||||
|
|
||||||
wire, err := req.Pack()
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
httpReq, err := http.NewRequestWithContext(context.Background(), "POST", dohURL, bytes.NewReader(wire))
|
req, err := http.NewRequestWithContext(
|
||||||
|
ctx, "POST", t.URL, bytes.NewReader(wire),
|
||||||
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
httpReq.Header.Set("Content-Type", "application/dns-message")
|
req.Header.Set("Content-Type", "application/dns-message")
|
||||||
httpReq.Header.Set("Accept", "application/dns-message")
|
req.Header.Set("Accept", "application/dns-message")
|
||||||
|
req.Host = t.Hostname // important for virtual hosting and cert validation
|
||||||
|
|
||||||
resp, err := client.Do(httpReq)
|
resp, err := t.Client.Do(req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
defer resp.Body.Close()
|
defer resp.Body.Close()
|
||||||
|
|
||||||
if resp.StatusCode != http.StatusOK {
|
if resp.StatusCode != http.StatusOK {
|
||||||
shared.LogError("DNS", "doDoHQuery: DoH server returned HTTP %d for %s", resp.StatusCode, host)
|
return nil, fmt.Errorf("doh status %d", resp.StatusCode)
|
||||||
return nil, fmt.Errorf("DoH HTTP %d", resp.StatusCode)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
body, err := io.ReadAll(resp.Body)
|
body, err := io.ReadAll(resp.Body)
|
||||||
@@ -335,158 +314,143 @@ func doDoHQuery(dohURL, host string, qtype uint16, bypass bool) ([]netip.Addr, e
|
|||||||
if err := res.Unpack(body); err != nil {
|
if err := res.Unpack(body); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
if res.Rcode != dns.RcodeSuccess {
|
return &res, nil
|
||||||
shared.LogError("DNS", "doDoHQuery: DoH Rcode %d for %s", res.Rcode, host)
|
|
||||||
return nil, fmt.Errorf("DoH Rcode %d", res.Rcode)
|
|
||||||
}
|
|
||||||
|
|
||||||
var addrs []netip.Addr
|
|
||||||
for _, ans := range res.Answer {
|
|
||||||
if qtype == dns.TypeA {
|
|
||||||
if a, ok := ans.(*dns.A); ok {
|
|
||||||
if ip, _ := netip.ParseAddr(a.A.String()); ip.Is4() {
|
|
||||||
addrs = append(addrs, ip)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
} else if qtype == dns.TypeAAAA {
|
|
||||||
if aaaa, ok := ans.(*dns.AAAA); ok {
|
|
||||||
if ip, _ := netip.ParseAddr(aaaa.AAAA.String()); ip.Is6() {
|
|
||||||
addrs = append(addrs, ip)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return addrs, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func resolveDoT(host, dotUpstream string, bypass bool) ([]netip.Addr, []netip.Addr, error) {
|
func parseDNSAnswers(msg *dns.Msg, qtype uint16) []netip.Addr {
|
||||||
shared.LogDebug("DNS", "resolveDoT: %s with upstream=%s (bypass=%t)", host, dotUpstream, bypass)
|
var out []netip.Addr
|
||||||
|
for _, ans := range msg.Answer {
|
||||||
var v4, v6 []netip.Addr
|
|
||||||
var v4Err, v6Err error
|
|
||||||
|
|
||||||
var wg sync.WaitGroup
|
|
||||||
wg.Add(2)
|
|
||||||
|
|
||||||
go func() {
|
|
||||||
defer wg.Done()
|
|
||||||
v4, v4Err = doDoTQuery(dotUpstream, host, dns.TypeA, bypass)
|
|
||||||
}()
|
|
||||||
go func() {
|
|
||||||
defer wg.Done()
|
|
||||||
v6, v6Err = doDoTQuery(dotUpstream, host, dns.TypeAAAA, bypass)
|
|
||||||
}()
|
|
||||||
wg.Wait()
|
|
||||||
|
|
||||||
if v4Err != nil && v6Err != nil {
|
|
||||||
shared.LogError("DNS", "resolveDoT failed for %s: both A and AAAA queries failed (bypass=%t)", host, bypass)
|
|
||||||
return nil, nil, errors.Join(v4Err, v6Err)
|
|
||||||
}
|
|
||||||
|
|
||||||
shared.LogDebug("DNS", "resolveDoT success for %s: %d v4, %d v6 (bypass=%t)", host, len(v4), len(v6), bypass)
|
|
||||||
return v4, v6, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func doDoTQuery(dotUpstream, host string, qtype uint16, bypass bool) ([]netip.Addr, error) {
|
|
||||||
// Normalize upstream to host:port
|
|
||||||
sni, port, err := net.SplitHostPort(dotUpstream)
|
|
||||||
if err != nil {
|
|
||||||
sni = dotUpstream
|
|
||||||
port = "853"
|
|
||||||
dotUpstream = net.JoinHostPort(sni, port)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Resolve hostname using bypass resolver
|
|
||||||
serverAddr, err := resolveServerAddr(context.Background(), dotUpstream, bypass)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
req := &dns.Msg{}
|
|
||||||
req.Id = dns.Id()
|
|
||||||
req.RecursionDesired = true
|
|
||||||
req.SetEdns0(4096, true)
|
|
||||||
req.SetQuestion(dns.Fqdn(host), qtype)
|
|
||||||
|
|
||||||
client := &dns.Client{
|
|
||||||
Net: "tcp-tls",
|
|
||||||
Dialer: GetDialer(bypass),
|
|
||||||
Timeout: 5 * time.Second,
|
|
||||||
TLSConfig: &tls.Config{
|
|
||||||
ServerName: sni,
|
|
||||||
InsecureSkipVerify: false,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
res, _, err := client.Exchange(req, serverAddr)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
if res.Rcode != dns.RcodeSuccess {
|
|
||||||
return nil, fmt.Errorf("DoT query failed with Rcode: %d", res.Rcode)
|
|
||||||
}
|
|
||||||
|
|
||||||
var addrs []netip.Addr
|
|
||||||
for _, ans := range res.Answer {
|
|
||||||
switch qtype {
|
switch qtype {
|
||||||
case dns.TypeA:
|
case dns.TypeA:
|
||||||
if a, ok := ans.(*dns.A); ok {
|
if a, ok := ans.(*dns.A); ok {
|
||||||
if ip, _ := netip.ParseAddr(a.A.String()); ip.Is4() {
|
if ip, err := netip.ParseAddr(a.A.String()); err == nil {
|
||||||
addrs = append(addrs, ip)
|
out = append(out, ip)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
case dns.TypeAAAA:
|
case dns.TypeAAAA:
|
||||||
if aaaa, ok := ans.(*dns.AAAA); ok {
|
if aaaa, ok := ans.(*dns.AAAA); ok {
|
||||||
if ip, _ := netip.ParseAddr(aaaa.AAAA.String()); ip.Is6() {
|
if ip, err := netip.ParseAddr(aaaa.AAAA.String()); err == nil {
|
||||||
addrs = append(addrs, ip)
|
out = append(out, ip)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return addrs, nil
|
return out
|
||||||
}
|
}
|
||||||
|
|
||||||
// Resolve runs the correct protocol based on the global config
|
func Resolve(
|
||||||
func Resolve(host string, bypass bool) ([]netip.Addr, []netip.Addr, error) {
|
ctx context.Context,
|
||||||
cfg := getConfig()
|
host, protocol, upstream string,
|
||||||
shared.LogDebug("DNS", "Resolve(%s, bypass=%t) protocol=%s upstream=%s", host, bypass, cfg.Protocol, cfg.Upstream)
|
bypass bool,
|
||||||
|
underlying string,
|
||||||
var v4, v6 []netip.Addr
|
) ([]netip.Addr, []netip.Addr, error) {
|
||||||
var err error
|
t, err := buildTransport(ctx, protocol, upstream, bypass, underlying)
|
||||||
switch cfg.Protocol {
|
|
||||||
case "doh":
|
|
||||||
v4, v6, err = resolveDoH(host, cfg.Upstream, bypass)
|
|
||||||
case "dot":
|
|
||||||
v4, v6, err = resolveDoT(host, cfg.Upstream, bypass)
|
|
||||||
default:
|
|
||||||
v4, v6, err = resolvePlain(host, cfg.Upstream, bypass)
|
|
||||||
}
|
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
shared.LogError("DNS", "Final Resolve failed for %s: %v", host, err)
|
return nil, nil, err
|
||||||
} else {
|
|
||||||
shared.LogDebug("DNS", "Resolve success for %s: %d v4, %d v6", host, len(v4), len(v6))
|
|
||||||
}
|
}
|
||||||
return v4, v6, err
|
return resolveHost(ctx, t, host)
|
||||||
}
|
}
|
||||||
|
|
||||||
func CustomResolver(bypass bool) *net.Resolver {
|
func buildTransport(
|
||||||
return &net.Resolver{
|
ctx context.Context,
|
||||||
PreferGo: true,
|
protocol, upstream string,
|
||||||
Dial: func(ctx context.Context, network, address string) (net.Conn, error) {
|
bypass bool,
|
||||||
d := GetDialer(bypass)
|
underlying string,
|
||||||
return d.DialContext(ctx, network, address)
|
) (Transport, error) {
|
||||||
},
|
switch protocol {
|
||||||
|
case "doh":
|
||||||
|
u, err := url.Parse(upstream)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
hostname := u.Hostname()
|
||||||
|
port := u.Port()
|
||||||
|
if port == "" {
|
||||||
|
port = "443"
|
||||||
|
}
|
||||||
|
u.Host = net.JoinHostPort(hostname, port)
|
||||||
|
|
||||||
|
// Pre-resolve with IPv4-first ordering + bypass
|
||||||
|
servers, _, err := resolveServerAddrs(ctx, u.Host, bypass, "443", underlying)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if len(servers) == 0 {
|
||||||
|
return nil, fmt.Errorf("no addresses resolved for DoH server")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Custom dialer that tries servers in order (IPv4 → IPv6)
|
||||||
|
dialer := GetDialer(bypass)
|
||||||
|
transport := &http.Transport{
|
||||||
|
DialContext: func(ctx context.Context, network, _ string) (net.Conn, error) {
|
||||||
|
for _, addr := range servers {
|
||||||
|
conn, err := dialer.DialContext(ctx, network, addr)
|
||||||
|
if err == nil {
|
||||||
|
return conn, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil, fmt.Errorf("all DoH addresses failed")
|
||||||
|
},
|
||||||
|
TLSClientConfig: &tls.Config{
|
||||||
|
ServerName: hostname,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
return DoHTransport{
|
||||||
|
Client: &http.Client{Timeout: 5 * time.Second, Transport: transport},
|
||||||
|
URL: u.String(),
|
||||||
|
Servers: servers,
|
||||||
|
Hostname: hostname,
|
||||||
|
}, nil
|
||||||
|
|
||||||
|
case "dot":
|
||||||
|
servers, sni, err := resolveServerAddrs(ctx, upstream, bypass, "853", underlying)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if len(servers) == 0 {
|
||||||
|
return nil, fmt.Errorf("no addresses resolved for DoT server")
|
||||||
|
}
|
||||||
|
|
||||||
|
client := &dns.Client{
|
||||||
|
Net: "tcp-tls",
|
||||||
|
Dialer: GetDialer(bypass),
|
||||||
|
Timeout: 5 * time.Second,
|
||||||
|
TLSConfig: &tls.Config{
|
||||||
|
ServerName: sni,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
return DoTTransport{
|
||||||
|
Client: client,
|
||||||
|
Servers: servers,
|
||||||
|
}, nil
|
||||||
|
|
||||||
|
default: // plain DNS
|
||||||
|
_, addr, err := parseUpstream(upstream)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
servers, _, err := resolveServerAddrs(ctx, addr, bypass, "53", underlying)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
client := &dns.Client{
|
||||||
|
Net: "udp",
|
||||||
|
Dialer: GetDialer(bypass),
|
||||||
|
Timeout: 5 * time.Second,
|
||||||
|
}
|
||||||
|
return PlainTransport{
|
||||||
|
Client: client,
|
||||||
|
Servers: servers,
|
||||||
|
}, nil
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func GetDialer(bypass bool) *net.Dialer {
|
func GetDialer(bypass bool) *net.Dialer {
|
||||||
if !bypass {
|
if !bypass {
|
||||||
return &net.Dialer{
|
return &net.Dialer{LocalAddr: nil}
|
||||||
LocalAddr: nil,
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return &net.Dialer{
|
return &net.Dialer{
|
||||||
Control: func(network, address string, c syscall.RawConn) error {
|
Control: func(network, address string, c syscall.RawConn) error {
|
||||||
var opErr error
|
var opErr error
|
||||||
|
|||||||
@@ -3,54 +3,51 @@
|
|||||||
|
|
||||||
struct go_string { const char *str; long n; };
|
struct go_string { const char *str; long n; };
|
||||||
|
|
||||||
extern void SetDNSConfig(struct go_string handle);
|
extern char* ResolveBootstrap(
|
||||||
extern char* ResolveBootstrap(const char* host, int bypass);
|
const char* host,
|
||||||
|
const char* protocol,
|
||||||
JNIEXPORT void JNICALL Java_com_zaneschepke_tunnel_DnsConfigManager_setDNSConfig(
|
const char* upstream,
|
||||||
JNIEnv* env, jclass clazz, jstring json)
|
const char* underlyingDnsServers,
|
||||||
{
|
int bypass);
|
||||||
if (json == NULL) {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
const char* cjson = (*env)->GetStringUTFChars(env, json, 0);
|
|
||||||
if (cjson != NULL) {
|
|
||||||
size_t len = (*env)->GetStringUTFLength(env, json);
|
|
||||||
|
|
||||||
SetDNSConfig((struct go_string){
|
|
||||||
.str = cjson,
|
|
||||||
.n = (long)len
|
|
||||||
});
|
|
||||||
|
|
||||||
(*env)->ReleaseStringUTFChars(env, json, cjson);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
JNIEXPORT jstring JNICALL
|
JNIEXPORT jstring JNICALL
|
||||||
Java_com_zaneschepke_tunnel_DnsConfigManager_resolveBootstrap(
|
Java_com_zaneschepke_tunnel_DnsConfigManager_resolveBootstrap(
|
||||||
JNIEnv* env,
|
JNIEnv* env,
|
||||||
jclass clazz,
|
jclass clazz,
|
||||||
jstring host,
|
jstring host,
|
||||||
jboolean bypass)
|
jstring protocol,
|
||||||
|
jstring upstream,
|
||||||
|
jstring underlyingDnsServers,
|
||||||
|
jint bypass)
|
||||||
{
|
{
|
||||||
if (host == NULL) {
|
if (host == NULL || protocol == NULL || upstream == NULL || underlyingDnsServers == NULL) {
|
||||||
return (*env)->NewStringUTF(env, "{\"error\":\"invalid host\"}");
|
return (*env)->NewStringUTF(env, "ERR|invalid arguments");
|
||||||
}
|
}
|
||||||
|
|
||||||
const char* chost = (*env)->GetStringUTFChars(env, host, NULL);
|
const char* chost = (*env)->GetStringUTFChars(env, host, NULL);
|
||||||
if (chost == NULL) {
|
const char* cprotocol = (*env)->GetStringUTFChars(env, protocol, NULL);
|
||||||
return (*env)->NewStringUTF(env, "{\"error\":\"out of memory\"}");
|
const char* cupstream = (*env)->GetStringUTFChars(env, upstream, NULL);
|
||||||
|
const char* cunderlying = (*env)->GetStringUTFChars(env, underlyingDnsServers, NULL);
|
||||||
|
|
||||||
|
if (chost == NULL || cprotocol == NULL || cupstream == NULL || cunderlying == NULL) {
|
||||||
|
return (*env)->NewStringUTF(env, "ERR|out of memory");
|
||||||
}
|
}
|
||||||
|
|
||||||
char* resultC = ResolveBootstrap(
|
char* resultC = ResolveBootstrap(
|
||||||
(char*)chost,
|
chost,
|
||||||
bypass ? 1 : 0
|
cprotocol,
|
||||||
|
cupstream,
|
||||||
|
cunderlying,
|
||||||
|
bypass ? 1 : 0
|
||||||
);
|
);
|
||||||
|
|
||||||
(*env)->ReleaseStringUTFChars(env, host, chost);
|
(*env)->ReleaseStringUTFChars(env, host, chost);
|
||||||
|
(*env)->ReleaseStringUTFChars(env, protocol, cprotocol);
|
||||||
|
(*env)->ReleaseStringUTFChars(env, upstream, cupstream);
|
||||||
|
(*env)->ReleaseStringUTFChars(env, underlyingDnsServers, cunderlying);
|
||||||
|
|
||||||
if (resultC == NULL) {
|
if (resultC == NULL) {
|
||||||
return (*env)->NewStringUTF(env, "{\"error\":\"null response\"}");
|
return (*env)->NewStringUTF(env, "ERR|null response");
|
||||||
}
|
}
|
||||||
|
|
||||||
jstring jresult = (*env)->NewStringUTF(env, resultC);
|
jstring jresult = (*env)->NewStringUTF(env, resultC);
|
||||||
|
|||||||
Reference in New Issue
Block a user