From c98fa04f731de0dd146dc7b156e24d80d5a42d09 Mon Sep 17 00:00:00 2001 From: Zane Schepke Date: Sun, 16 Mar 2025 20:10:44 -0400 Subject: [PATCH] fix: auto tunnel and tunnel regressions --- .../core/service/tile/TunnelControlTile.kt | 2 +- .../core/tunnel/BaseTunnel.kt | 143 ++++++++---------- .../core/tunnel/Extensions.kt | 25 +++ .../core/tunnel/KernelTunnel.kt | 65 +++----- .../core/tunnel/TunnelManager.kt | 2 +- .../core/tunnel/TunnelProvider.kt | 2 +- .../core/tunnel/UserspaceTunnel.kt | 83 ++++------ .../domain/entity/TunnelConf.kt | 2 - .../domain/state/AutoTunnelState.kt | 27 ++-- .../ui/screens/main/MainScreen.kt | 3 +- .../ui/screens/main/TunnelOptionsScreen.kt | 3 +- .../ui/state/AppUiState.kt | 2 +- .../viewmodel/AppViewModel.kt | 10 +- .../networkmonitor/AndroidNetworkMonitor.kt | 19 +-- 14 files changed, 165 insertions(+), 223 deletions(-) create mode 100644 app/src/main/java/com/zaneschepke/wireguardautotunnel/core/tunnel/Extensions.kt diff --git a/app/src/main/java/com/zaneschepke/wireguardautotunnel/core/service/tile/TunnelControlTile.kt b/app/src/main/java/com/zaneschepke/wireguardautotunnel/core/service/tile/TunnelControlTile.kt index d006e69f..38614d21 100644 --- a/app/src/main/java/com/zaneschepke/wireguardautotunnel/core/service/tile/TunnelControlTile.kt +++ b/app/src/main/java/com/zaneschepke/wireguardautotunnel/core/service/tile/TunnelControlTile.kt @@ -56,7 +56,7 @@ class TunnelControlTile : TileService() { if (tunnels.isEmpty()) return@launch setUnavailable() with(tunnelManager.activeTunnels.value) { if (isNotEmpty()) if (size == 1) { - tunnels.firstOrNull { it.id == keys.first() }?.let { return@launch updateTile(it.tunName, true) } + tunnels.firstOrNull { it.id == keys.first().id }?.let { return@launch updateTile(it.tunName, true) } } else { return@launch updateTile(getString(R.string.multiple), true) } diff --git a/app/src/main/java/com/zaneschepke/wireguardautotunnel/core/tunnel/BaseTunnel.kt b/app/src/main/java/com/zaneschepke/wireguardautotunnel/core/tunnel/BaseTunnel.kt index 72ca7e9d..6e7e802f 100644 --- a/app/src/main/java/com/zaneschepke/wireguardautotunnel/core/tunnel/BaseTunnel.kt +++ b/app/src/main/java/com/zaneschepke/wireguardautotunnel/core/tunnel/BaseTunnel.kt @@ -21,11 +21,11 @@ import com.zaneschepke.wireguardautotunnel.ui.common.snackbar.SnackbarController import com.zaneschepke.wireguardautotunnel.util.Constants import com.zaneschepke.wireguardautotunnel.util.StringValue import com.zaneschepke.wireguardautotunnel.util.extensions.asTunnelState +import com.zaneschepke.wireguardautotunnel.util.extensions.cancelWithMessage import com.zaneschepke.wireguardautotunnel.util.extensions.toBackendError import kotlinx.coroutines.CoroutineDispatcher import kotlinx.coroutines.CoroutineScope import kotlinx.coroutines.Job -import kotlinx.coroutines.cancelAndJoin import kotlinx.coroutines.delay import kotlinx.coroutines.flow.MutableStateFlow import kotlinx.coroutines.flow.asStateFlow @@ -51,49 +51,22 @@ abstract class BaseTunnel( const val CHECK_INTERVAL = 1000L } - internal val tunnels = MutableStateFlow>(emptyList()) - private val _activeTunnels = MutableStateFlow>(emptyMap()) - override val activeTunnels = _activeTunnels.asStateFlow() + protected val activeTuns = MutableStateFlow>(emptyMap()) + override val activeTunnels = activeTuns.asStateFlow() - protected val tunnelJobs = ConcurrentHashMap>() - private val mutex = Mutex() + private val tunnelJobs = ConcurrentHashMap>() + + protected val mutex = Mutex() private val isNetworkConnected = MutableStateFlow(true) init { applicationScope.launch(ioDispatcher) { launch { monitorNetworkStatus() } launch { monitorTunnelConfigChanges() } - launch { monitorTunnels() } } } - private suspend fun monitorTunnels() { - tunnels.collectLatest { newTunnels -> - mutex.withLock { - val previousIds = tunnelJobs.keys - val currentIds = newTunnels.map { it.id }.toSet() - val added = newTunnels.filter { it.id !in previousIds && it.isActive } - val removed = previousIds - currentIds - - added.forEach { tunnel -> - Timber.d("Starting jobs for tunnel ${tunnel.id}: ${tunnel.tunName}") - if (tunnelJobs[tunnel.id] == null) { - tunnelJobs[tunnel.id] = mutableListOf(startTunnelJobs(tunnel)) - } - } - - removed.forEach { id -> - Timber.d("Stopping jobs for tunnel $id") - tunnelJobs[id]?.forEach { it.cancelAndJoin() } - tunnelJobs.remove(id) - _activeTunnels.update { it - id } - serviceManager.updateTunnelTile() - } - } - } - } - - protected fun startTunnelJobs(tunnel: TunnelConf): Job { + private fun startTunnelJobs(tunnel: TunnelConf): Job { return applicationScope.launch(ioDispatcher) { val jobs = mutableListOf() jobs += launch { updateTunnelStatistics(tunnel) } @@ -106,7 +79,7 @@ abstract class BaseTunnel( while (true) { runCatching { val stats = getStatistics(tunnel) - updateTunnelState(tunnel.id, stats = stats) + updateTunnelState(tunnel, stats = stats) }.onFailure { e -> Timber.e(e, "Failed to update stats for ${tunnel.tunName}") } @@ -154,21 +127,26 @@ abstract class BaseTunnel( } } - protected fun updateTunnelState(tunnelId: Int, state: TunnelStatus? = null, stats: TunnelStatistics? = null) { + protected fun updateTunnelState(tunnelConf: TunnelConf, state: TunnelStatus? = null, stats: TunnelStatistics? = null) { applicationScope.launch(ioDispatcher) { mutex.withLock { - _activeTunnels.update { current -> - val existing = current[tunnelId] ?: TunnelState() - val newState = state ?: existing.state - if (existing.state == newState && stats == null) { - Timber.d("Skipping redundant state update for $tunnelId: $newState") + activeTuns.update { current -> + val originalConf = current.getKeyById(tunnelConf.id) ?: tunnelConf + val existingState = current.getValueById(tunnelConf.id) ?: TunnelState() + val newState = state ?: existingState.state + if (newState == TunnelStatus.DOWN) { + // Remove tunnel from activeTunnels when it goes DOWN + Timber.d("Removing tunnel ${tunnelConf.id} from activeTunnels as state is DOWN") + current - originalConf + } else if (existingState.state == newState && stats == null) { + Timber.d("Skipping redundant state update for ${tunnelConf.id}: $newState") current } else { - val updated = existing.copy( + val updated = existingState.copy( state = newState, - statistics = stats ?: existing.statistics, + statistics = stats ?: existingState.statistics, ) - current + (tunnelId to updated) + current + (originalConf to updated) } } } @@ -176,67 +154,60 @@ abstract class BaseTunnel( } protected suspend fun configureTunnel(tunnelConf: TunnelConf) { + // setup state change callback tunnelConf.setStateChangeCallback { state -> Timber.d("State change callback triggered for tunnel ${tunnelConf.id}: ${tunnelConf.tunName} with state $state at ${System.currentTimeMillis()}") when (state) { - is Tunnel.State -> updateTunnelState(tunnelConf.id, state.asTunnelState()) - is org.amnezia.awg.backend.Tunnel.State -> updateTunnelState(tunnelConf.id, state.asTunnelState()) + is Tunnel.State -> updateTunnelState(tunnelConf, state.asTunnelState()) + is org.amnezia.awg.backend.Tunnel.State -> updateTunnelState(tunnelConf, state.asTunnelState()) } applicationScope.launch(ioDispatcher) { serviceManager.updateTunnelTile() } } + + activeTuns.update { current -> + current.filter { it.key != tunnelConf } + (tunnelConf to TunnelState()) + } + } + + protected suspend fun onStartSuccess(tunnelConf: TunnelConf) { + val tunnelCopy = tunnelConf.copyWithCallback(isActive = true) + + // start service + if (activeTuns.value.isEmpty()) { + serviceManager.startTunnelForegroundService(tunnelCopy) + } else { + serviceManager.updateTunnelForegroundServiceNotification(tunnelCopy) + } + // save active + appDataRepository.tunnels.save(tunnelCopy) + // start tunnel jobs + tunnelJobs[tunnelCopy.id] = mutableListOf(startTunnelJobs(tunnelConf)) } override fun startTunnel(tunnelConf: TunnelConf) { - applicationScope.launch(ioDispatcher) { - mutex.withLock { - val currentState = _activeTunnels.value[tunnelConf.id]?.state - if (currentState == TunnelStatus.UP || currentState == TunnelStatus.STARTING) { - Timber.w("Tunnel ${tunnelConf.id} is already $currentState, skipping start (possible duplicate call)") - return@launch - } - - val existingNames = tunnels.value.map { it.tunName }.toSet() - if (tunnelConf.tunName in existingNames && tunnels.value.any { it.id != tunnelConf.id && it.tunName == tunnelConf.tunName }) { - Timber.w("Duplicate tunName ${tunnelConf.tunName} detected for tunnel ${tunnelConf.id}") - } - - val lockedConf = tunnelConf.copyWithCallback(isActive = true) - Timber.d("Starting tunnel with TunnelConf: $lockedConf") - if (tunnels.value.isEmpty()) { - Timber.d("No active tunnels, starting background service for ${lockedConf.id}") - serviceManager.startTunnelForegroundService(lockedConf) - } else { - Timber.d("Tunnels already active, updating service notification for ${lockedConf.id}") - serviceManager.updateTunnelForegroundServiceNotification(lockedConf) - } - appDataRepository.tunnels.save(lockedConf) - tunnels.update { current -> - current.filter { it.id != lockedConf.id } + lockedConf - } - } - } + throw NotImplementedError("Must be implemented by subclass") } override fun stopTunnel(tunnelConf: TunnelConf?) { tunnelConf?.let { applicationScope.launch(ioDispatcher) { mutex.withLock { + removeActiveTunnel(tunnelConf) + tunnelJobs[tunnelConf.id]?.forEach { it.cancelWithMessage("Cancel tunnel job") } + tunnelJobs.remove(tunnelConf.id) val lockedConf = it.copyWithCallback(isActive = false) - Timber.d("Stopping tunnel with TunnelConf: $lockedConf") - tunnels.update { tunnels -> tunnels.filter { t -> t.id != lockedConf.id } } appDataRepository.tunnels.save(lockedConf) - if (tunnels.value.isEmpty()) { + + // TODO improve to handle multiple tunnels + if (activeTuns.value.isEmpty()) { Timber.d("No tunnels active, stopping background service") serviceManager.stopTunnelForegroundService() } else { Timber.d("Other tunnels still active, updating service notification") - val nextActive = tunnels.value.firstOrNull { it.isActive } + val nextActive = activeTuns.value.keys.firstOrNull() if (nextActive != null) { Timber.d("Next active tunnel: ${nextActive.id}") serviceManager.updateTunnelForegroundServiceNotification(nextActive) - } else { - Timber.w("No active tunnels found in _tunnels, forcing service stop") - serviceManager.stopTunnelForegroundService() } } } @@ -244,6 +215,12 @@ abstract class BaseTunnel( } } + private fun removeActiveTunnel(tunnelConf: TunnelConf) { + activeTuns.update { current -> + current.toMutableMap().apply { remove(tunnelConf) } + } + } + override suspend fun bounceTunnel(tunnelConf: TunnelConf) { stopTunnel(tunnelConf) delay(1000) @@ -264,7 +241,7 @@ abstract class BaseTunnel( appDataRepository.tunnels.flow.collectLatest { storedTunnels -> mutex.withLock { storedTunnels.forEach { stored -> - val current = tunnels.value.find { it.id == stored.id } + val current = activeTuns.value.keys.find { it.id == stored.id } if (current != null && !current.isQuickConfigMatching(stored)) { Timber.d("Config changed for ${stored.id}, bouncing") bounceTunnel(stored) @@ -278,5 +255,5 @@ abstract class BaseTunnel( throw NotImplementedError("Must be implemented by subclass") } - override suspend fun runningTunnelNames(): Set = tunnels.value.map { it.tunName }.toSet() + override suspend fun runningTunnelNames(): Set = activeTuns.value.keys.map { it.tunName }.toSet() } diff --git a/app/src/main/java/com/zaneschepke/wireguardautotunnel/core/tunnel/Extensions.kt b/app/src/main/java/com/zaneschepke/wireguardautotunnel/core/tunnel/Extensions.kt new file mode 100644 index 00000000..8dd7fbc4 --- /dev/null +++ b/app/src/main/java/com/zaneschepke/wireguardautotunnel/core/tunnel/Extensions.kt @@ -0,0 +1,25 @@ +package com.zaneschepke.wireguardautotunnel.core.tunnel + +import com.zaneschepke.wireguardautotunnel.domain.entity.TunnelConf +import com.zaneschepke.wireguardautotunnel.domain.state.TunnelState + +fun Map.allDown(): Boolean { + return this.all { it.value.state.isDown() } +} + +fun Map.hasActive(): Boolean { + return this.any { it.value.state.isUp() } +} + +fun Map.getValueById(id: Int): TunnelState? { + val key = this.keys.find { it.id == id } + return key?.let { this@getValueById[it] } +} + +fun Map.getKeyById(id: Int): TunnelConf? { + return this.keys.find { it.id == id } +} + +fun Map.isUp(tunnelConf: TunnelConf): Boolean { + return this.getValueById(tunnelConf.id)?.state?.isUp() ?: false +} diff --git a/app/src/main/java/com/zaneschepke/wireguardautotunnel/core/tunnel/KernelTunnel.kt b/app/src/main/java/com/zaneschepke/wireguardautotunnel/core/tunnel/KernelTunnel.kt index 0c8b8f2c..884e4c97 100644 --- a/app/src/main/java/com/zaneschepke/wireguardautotunnel/core/tunnel/KernelTunnel.kt +++ b/app/src/main/java/com/zaneschepke/wireguardautotunnel/core/tunnel/KernelTunnel.kt @@ -13,12 +13,12 @@ import com.zaneschepke.wireguardautotunnel.domain.enums.TunnelStatus import com.zaneschepke.wireguardautotunnel.domain.repository.AppDataRepository import com.zaneschepke.wireguardautotunnel.domain.state.TunnelStatistics import com.zaneschepke.wireguardautotunnel.domain.state.WireGuardStatistics +import com.zaneschepke.wireguardautotunnel.util.extensions.asTunnelState import kotlinx.coroutines.CoroutineDispatcher import kotlinx.coroutines.CoroutineScope import kotlinx.coroutines.launch -import kotlinx.coroutines.runBlocking +import kotlinx.coroutines.sync.withLock import timber.log.Timber -import java.util.concurrent.ConcurrentHashMap import javax.inject.Inject class KernelTunnel @Inject constructor( @@ -31,34 +31,24 @@ class KernelTunnel @Inject constructor( networkMonitor: NetworkMonitor, ) : BaseTunnel(ioDispatcher, applicationScope, networkMonitor, appDataRepository, serviceManager, notificationManager) { - private val startedTunnels = ConcurrentHashMap() - override fun startTunnel(tunnelConf: TunnelConf) { Timber.i("Starting tunnel ${tunnelConf.id} kernel") applicationScope.launch(ioDispatcher) { runCatching { - updateTunnelState(tunnelConf.id, TunnelStatus.STARTING) - Timber.d("Set STARTING state for tunnel ${tunnelConf.id} at ${System.currentTimeMillis()}") + // tunnel already active + if (activeTuns.value.any { it.key.id == tunnelConf.id }) return@launch - runBlocking { configureTunnel(tunnelConf) } - Timber.d("Callback set for tunnel ${tunnelConf.id} at ${System.currentTimeMillis()}") + mutex.withLock { + updateTunnelState(tunnelConf, TunnelStatus.STARTING) - super.startTunnel(tunnelConf) - Timber.d("Calling backend.setState UP for tunnel ${tunnelConf.id}") - backend.setState(tunnelConf, Tunnel.State.UP, tunnelConf.toWgConfig()) - startedTunnels[tunnelConf.id] = tunnelConf + // configure state callback and add to tunnels + configureTunnel(tunnelConf) - val backendState = backend.getState(tunnelConf) - if (backendState == Tunnel.State.UP) { - updateTunnelState(tunnelConf.id, TunnelStatus.UP) - Timber.d("Confirmed UP state for tunnel ${tunnelConf.id} at ${System.currentTimeMillis()}") - } else { - Timber.w("Tunnel ${tunnelConf.id} not UP after setState, state: $backendState") + updateTunnelState(tunnelConf, backend.setState(tunnelConf, Tunnel.State.UP, tunnelConf.toWgConfig()).asTunnelState()) + + // run some actions after start success + onStartSuccess(tunnelConf) } - - // Start stats jobs only after UP is confirmed - tunnelJobs[tunnelConf.id] = mutableListOf(startTunnelJobs(tunnelConf)) - Timber.d("Started stats jobs for tunnel ${tunnelConf.id} at ${System.currentTimeMillis()}") }.onFailure { exception -> Timber.e(exception, "Failed to start tunnel ${tunnelConf.id} kernel") stopTunnel(tunnelConf) @@ -81,34 +71,17 @@ class KernelTunnel @Inject constructor( override fun stopTunnel(tunnelConf: TunnelConf?) { applicationScope.launch(ioDispatcher) { runCatching { - val originalTunnel = tunnelConf?.let { startedTunnels.getOrDefault(it.id, null) } + val originalTunnel = activeTuns.value.keys.find { it.id == tunnelConf?.id } if (originalTunnel != null) { - Timber.i( - "Stopping tunnel ${originalTunnel.id} kernel", - ) -// updateTunnelState(tunnelConf.id, TunnelStatus.STOPPING) - backend.setState(originalTunnel, Tunnel.State.DOWN, originalTunnel.toWgConfig()) - super.stopTunnel(originalTunnel) - startedTunnels.remove(originalTunnel.id) - tunnelJobs[originalTunnel.id]?.forEach { it.cancel() } - tunnelJobs.remove(originalTunnel.id) - if (backend.getState(originalTunnel) == Tunnel.State.DOWN) { - updateTunnelState(originalTunnel.id, TunnelStatus.DOWN) - Timber.d("Confirmed DOWN state for tunnel ${originalTunnel.id}") + Timber.i("Stopping tunnel ${originalTunnel.id} kernel") + mutex.withLock { + updateTunnelState(originalTunnel, backend.setState(originalTunnel, Tunnel.State.DOWN, originalTunnel.toWgConfig()).asTunnelState()) + super.stopTunnel(originalTunnel) } } else { Timber.w("Tunnel not found in startedTunnels, stopping all tunnels") - startedTunnels.forEach { (_, config) -> -// updateTunnelState(config.id, TunnelStatus.STOPPING) - val state = backend.setState(config, Tunnel.State.DOWN, config.toWgConfig()) - super.stopTunnel(tunnelConf) - if (state == Tunnel.State.DOWN) { - startedTunnels.remove(config.id) - tunnelJobs[config.id]?.forEach { it.cancel() } - tunnelJobs.remove(config.id) - updateTunnelState(config.id, TunnelStatus.DOWN) - Timber.d("Confirmed DOWN state for tunnel ${config.id} after fallback") - } + activeTuns.value.keys.forEach { config -> + stopTunnel(config) } } }.onFailure { e -> diff --git a/app/src/main/java/com/zaneschepke/wireguardautotunnel/core/tunnel/TunnelManager.kt b/app/src/main/java/com/zaneschepke/wireguardautotunnel/core/tunnel/TunnelManager.kt index 0339e1fd..00304bd7 100644 --- a/app/src/main/java/com/zaneschepke/wireguardautotunnel/core/tunnel/TunnelManager.kt +++ b/app/src/main/java/com/zaneschepke/wireguardautotunnel/core/tunnel/TunnelManager.kt @@ -84,7 +84,7 @@ class TunnelManager @Inject constructor( val settings = appDataRepository.settings.get() if (settings.isRestoreOnBootEnabled) { val previouslyActiveTuns = appDataRepository.tunnels.getActive() - val tunsToStart = previouslyActiveTuns.filterNot { tun -> activeTunnels.value.any { tun.id == it.key } } + val tunsToStart = previouslyActiveTuns.filterNot { tun -> activeTunnels.value.any { tun.id == it.key.id } } if (settings.isKernelEnabled) { return@launch tunsToStart.forEach { startTunnel(it) diff --git a/app/src/main/java/com/zaneschepke/wireguardautotunnel/core/tunnel/TunnelProvider.kt b/app/src/main/java/com/zaneschepke/wireguardautotunnel/core/tunnel/TunnelProvider.kt index 82bff610..7fd0be9f 100644 --- a/app/src/main/java/com/zaneschepke/wireguardautotunnel/core/tunnel/TunnelProvider.kt +++ b/app/src/main/java/com/zaneschepke/wireguardautotunnel/core/tunnel/TunnelProvider.kt @@ -13,5 +13,5 @@ interface TunnelProvider { suspend fun setBackendState(backendState: BackendState, allowedIps: Collection) suspend fun runningTunnelNames(): Set fun getStatistics(tunnelConf: TunnelConf): TunnelStatistics? - val activeTunnels: StateFlow> + val activeTunnels: StateFlow> } diff --git a/app/src/main/java/com/zaneschepke/wireguardautotunnel/core/tunnel/UserspaceTunnel.kt b/app/src/main/java/com/zaneschepke/wireguardautotunnel/core/tunnel/UserspaceTunnel.kt index 731539e4..6f2cf7d3 100644 --- a/app/src/main/java/com/zaneschepke/wireguardautotunnel/core/tunnel/UserspaceTunnel.kt +++ b/app/src/main/java/com/zaneschepke/wireguardautotunnel/core/tunnel/UserspaceTunnel.kt @@ -12,14 +12,15 @@ import com.zaneschepke.wireguardautotunnel.domain.repository.AppDataRepository import com.zaneschepke.wireguardautotunnel.domain.state.AmneziaStatistics import com.zaneschepke.wireguardautotunnel.domain.state.TunnelStatistics import com.zaneschepke.wireguardautotunnel.util.extensions.asAmBackendState +import com.zaneschepke.wireguardautotunnel.util.extensions.asTunnelState import kotlinx.coroutines.CoroutineDispatcher import kotlinx.coroutines.CoroutineScope import kotlinx.coroutines.delay import kotlinx.coroutines.launch -import kotlinx.coroutines.runBlocking +import kotlinx.coroutines.sync.withLock import org.amnezia.awg.backend.Backend +import org.amnezia.awg.backend.Tunnel import timber.log.Timber -import java.util.concurrent.ConcurrentHashMap import javax.inject.Inject class UserspaceTunnel @Inject constructor( @@ -32,35 +33,27 @@ class UserspaceTunnel @Inject constructor( networkMonitor: NetworkMonitor, ) : BaseTunnel(ioDispatcher, applicationScope, networkMonitor, appDataRepository, serviceManager, notificationManager) { - private val startedTunnels = ConcurrentHashMap() - override fun startTunnel(tunnelConf: TunnelConf) { Timber.i("Starting tunnel ${tunnelConf.id} userspace") applicationScope.launch(ioDispatcher) { runCatching { - stopActiveTunnels(tunnelConf) - updateTunnelState(tunnelConf.id, TunnelStatus.STARTING) - Timber.d("Set STARTING state for tunnel ${tunnelConf.id} at ${System.currentTimeMillis()}") + // tunnel already active + if (activeTuns.value.any { it.key.id == tunnelConf.id }) return@launch - runBlocking { configureTunnel(tunnelConf) } - Timber.d("Callback set for tunnel ${tunnelConf.id} at ${System.currentTimeMillis()}") + // stop any active tunnels that aren't this one, userspace only + stopActiveTunnels() - super.startTunnel(tunnelConf) - Timber.d("Calling backend.setState UP for tunnel ${tunnelConf.id}") - backend.setState(tunnelConf, org.amnezia.awg.backend.Tunnel.State.UP, tunnelConf.toAmConfig()) - startedTunnels[tunnelConf.id] = tunnelConf + mutex.withLock { + updateTunnelState(tunnelConf, TunnelStatus.STARTING) - val backendState = backend.getState(tunnelConf) - if (backendState == org.amnezia.awg.backend.Tunnel.State.UP) { - updateTunnelState(tunnelConf.id, TunnelStatus.UP) - Timber.d("Confirmed UP state for tunnel ${tunnelConf.id} at ${System.currentTimeMillis()}") - } else { - Timber.w("Tunnel ${tunnelConf.id} not UP after setState, state: $backendState") + // configure state callback and add to tunnels + configureTunnel(tunnelConf) + + updateTunnelState(tunnelConf, backend.setState(tunnelConf, Tunnel.State.UP, tunnelConf.toAmConfig()).asTunnelState()) + + // run some actions after start success + onStartSuccess(tunnelConf) } - - // Start stats jobs only after UP is confirmed - tunnelJobs[tunnelConf.id] = mutableListOf(startTunnelJobs(tunnelConf)) - Timber.d("Started stats jobs for tunnel ${tunnelConf.id} at ${System.currentTimeMillis()}") }.onFailure { exception -> Timber.e(exception, "Failed to start tunnel ${tunnelConf.id} userspace") stopTunnel(tunnelConf) @@ -71,15 +64,10 @@ class UserspaceTunnel @Inject constructor( } } - private suspend fun stopActiveTunnels(tunnelConf: TunnelConf) { - val runningTunnels = activeTunnels.value.filter { (id, state) -> - id != tunnelConf.id && state.state == TunnelStatus.UP - } - runningTunnels.forEach { (id, _) -> - val runningTunnel = startedTunnels[id] - if (runningTunnel != null) { - Timber.i("Stopping running tunnel ${runningTunnel.id} before starting ${tunnelConf.id}") - stopTunnel(runningTunnel) + private suspend fun stopActiveTunnels() { + activeTunnels.value.forEach { (config, state) -> + if (state.state.isUp()) { + stopTunnel(config) delay(300) } } @@ -88,34 +76,17 @@ class UserspaceTunnel @Inject constructor( override fun stopTunnel(tunnelConf: TunnelConf?) { applicationScope.launch(ioDispatcher) { runCatching { - val originalTunnel = tunnelConf?.let { startedTunnels.getOrDefault(it.id, null) } + val originalTunnel = activeTuns.value.keys.find { it.id == tunnelConf?.id } if (originalTunnel != null) { - Timber.i( - "Stopping tunnel ${originalTunnel.id} userspace", - ) -// updateTunnelState(tunnelConf.id, TunnelStatus.STOPPING) - backend.setState(originalTunnel, org.amnezia.awg.backend.Tunnel.State.DOWN, originalTunnel.toAmConfig()) - super.stopTunnel(originalTunnel) - startedTunnels.remove(originalTunnel.id) - tunnelJobs[originalTunnel.id]?.forEach { it.cancel() } - tunnelJobs.remove(originalTunnel.id) - if (backend.getState(originalTunnel) == org.amnezia.awg.backend.Tunnel.State.DOWN) { - updateTunnelState(originalTunnel.id, TunnelStatus.DOWN) - Timber.d("Confirmed DOWN state for tunnel ${originalTunnel.id}") + Timber.i("Stopping tunnel ${originalTunnel.id} userspace") + mutex.withLock { + updateTunnelState(originalTunnel, backend.setState(originalTunnel, Tunnel.State.DOWN, originalTunnel.toAmConfig()).asTunnelState()) + super.stopTunnel(originalTunnel) } } else { Timber.w("Tunnel not found in startedTunnels, stopping all tunnels") - startedTunnels.forEach { (_, config) -> -// updateTunnelState(config.id, TunnelStatus.STOPPING) - val state = backend.setState(config, org.amnezia.awg.backend.Tunnel.State.DOWN, config.toAmConfig()) - super.stopTunnel(tunnelConf) - if (state == org.amnezia.awg.backend.Tunnel.State.DOWN) { - startedTunnels.remove(config.id) - tunnelJobs[config.id]?.forEach { it.cancel() } - tunnelJobs.remove(config.id) - updateTunnelState(config.id, TunnelStatus.DOWN) - Timber.d("Confirmed DOWN state for tunnel ${config.id} after fallback") - } + activeTuns.value.keys.forEach { config -> + stopTunnel(config) } } }.onFailure { e -> diff --git a/app/src/main/java/com/zaneschepke/wireguardautotunnel/domain/entity/TunnelConf.kt b/app/src/main/java/com/zaneschepke/wireguardautotunnel/domain/entity/TunnelConf.kt index efe628b2..35a4be12 100644 --- a/app/src/main/java/com/zaneschepke/wireguardautotunnel/domain/entity/TunnelConf.kt +++ b/app/src/main/java/com/zaneschepke/wireguardautotunnel/domain/entity/TunnelConf.kt @@ -34,7 +34,6 @@ data class TunnelConf( private var stateChangeCallback: ((Any) -> Unit)? = null, ) : Tunnel, org.amnezia.awg.backend.Tunnel { - // Mutex to protect stateChangeCallback access private val callbackMutex = Mutex() suspend fun setStateChangeCallback(callback: (Any) -> Unit) { @@ -43,7 +42,6 @@ data class TunnelConf( } } - // Ensure callback is copied over fun copyWithCallback( id: Int = this.id, tunName: String = this.tunName, diff --git a/app/src/main/java/com/zaneschepke/wireguardautotunnel/domain/state/AutoTunnelState.kt b/app/src/main/java/com/zaneschepke/wireguardautotunnel/domain/state/AutoTunnelState.kt index ca414e6e..dd68881c 100644 --- a/app/src/main/java/com/zaneschepke/wireguardautotunnel/domain/state/AutoTunnelState.kt +++ b/app/src/main/java/com/zaneschepke/wireguardautotunnel/domain/state/AutoTunnelState.kt @@ -1,5 +1,8 @@ package com.zaneschepke.wireguardautotunnel.domain.state +import com.zaneschepke.wireguardautotunnel.core.tunnel.allDown +import com.zaneschepke.wireguardautotunnel.core.tunnel.hasActive +import com.zaneschepke.wireguardautotunnel.core.tunnel.isUp import com.zaneschepke.wireguardautotunnel.domain.events.KillSwitchEvent import com.zaneschepke.wireguardautotunnel.domain.entity.AppSettings import com.zaneschepke.wireguardautotunnel.domain.entity.TunnelConf @@ -7,7 +10,7 @@ import com.zaneschepke.wireguardautotunnel.domain.events.AutoTunnelEvent import com.zaneschepke.wireguardautotunnel.util.extensions.isMatchingToWildcardList data class AutoTunnelState( - val activeTunnels: Map = emptyMap(), + val activeTunnels: Map = emptyMap(), val networkState: NetworkState = NetworkState(), val settings: AppSettings = AppSettings(), val tunnels: List = emptyList(), @@ -20,12 +23,12 @@ data class AutoTunnelState( private fun isMobileTunnelDataChangeNeeded(): Boolean { val preferredTunnel = preferredMobileDataTunnel() return preferredTunnel != null && - activeTunnels.isNotEmpty() && !activeTunnels.any { it.key == preferredTunnel.id } + activeTunnels.isNotEmpty() && !activeTunnels.isUp(preferredTunnel) } private fun isEthernetTunnelChangeNeeded(): Boolean { val preferredTunnel = preferredEthernetTunnel() - return preferredTunnel != null && activeTunnels.isNotEmpty() && !activeTunnels.any { it.key == preferredTunnel.id } + return preferredTunnel != null && activeTunnels.isNotEmpty() && !activeTunnels.isUp(preferredTunnel) } private fun preferredMobileDataTunnel(): TunnelConf? { @@ -45,11 +48,11 @@ data class AutoTunnelState( } private fun startOnEthernet(): Boolean { - return networkState.isEthernetConnected && settings.isTunnelOnEthernetEnabled && activeTunnels.isEmpty() + return networkState.isEthernetConnected && settings.isTunnelOnEthernetEnabled && activeTunnels.allDown() } private fun stopOnEthernet(): Boolean { - return networkState.isEthernetConnected && !settings.isTunnelOnEthernetEnabled && activeTunnels.isNotEmpty() + return networkState.isEthernetConnected && !settings.isTunnelOnEthernetEnabled && activeTunnels.hasActive() } // TODO test removed kill switch state check @@ -67,11 +70,11 @@ data class AutoTunnelState( } private fun stopOnMobileData(): Boolean { - return isMobileDataActive() && !settings.isTunnelOnMobileDataEnabled && activeTunnels.isNotEmpty() + return isMobileDataActive() && !settings.isTunnelOnMobileDataEnabled && activeTunnels.hasActive() } private fun startOnMobileData(): Boolean { - return isMobileDataActive() && settings.isTunnelOnMobileDataEnabled && activeTunnels.isEmpty() + return isMobileDataActive() && settings.isTunnelOnMobileDataEnabled && activeTunnels.allDown() } private fun changeOnMobileData(): Boolean { @@ -83,24 +86,24 @@ data class AutoTunnelState( } private fun stopOnWifi(): Boolean { - return isWifiActive() && !settings.isTunnelOnWifiEnabled && activeTunnels.isNotEmpty() + return isWifiActive() && !settings.isTunnelOnWifiEnabled && activeTunnels.hasActive() } private fun stopOnTrustedWifi(): Boolean { - return isWifiActive() && settings.isTunnelOnWifiEnabled && activeTunnels.isNotEmpty() && isCurrentSSIDTrusted() + return isWifiActive() && settings.isTunnelOnWifiEnabled && activeTunnels.hasActive() && isCurrentSSIDTrusted() } private fun startOnUntrustedWifi(): Boolean { - return isWifiActive() && settings.isTunnelOnWifiEnabled && activeTunnels.isEmpty() && !isCurrentSSIDTrusted() + return isWifiActive() && settings.isTunnelOnWifiEnabled && activeTunnels.allDown() && !isCurrentSSIDTrusted() } private fun changeOnUntrustedWifi(): Boolean { - return isWifiActive() && settings.isTunnelOnWifiEnabled && activeTunnels.isNotEmpty() && !isCurrentSSIDTrusted() && !isWifiTunnelPreferred() + return isWifiActive() && settings.isTunnelOnWifiEnabled && activeTunnels.hasActive() && !isCurrentSSIDTrusted() && !isWifiTunnelPreferred() } private fun isWifiTunnelPreferred(): Boolean { val preferred = preferredWifiTunnel() - return activeTunnels.any { it.key == preferred?.id } + return preferred?.let { activeTunnels.isUp(it) } ?: true } fun asAutoTunnelEvent(): AutoTunnelEvent { diff --git a/app/src/main/java/com/zaneschepke/wireguardautotunnel/ui/screens/main/MainScreen.kt b/app/src/main/java/com/zaneschepke/wireguardautotunnel/ui/screens/main/MainScreen.kt index fc76e19f..c4fd9b4a 100644 --- a/app/src/main/java/com/zaneschepke/wireguardautotunnel/ui/screens/main/MainScreen.kt +++ b/app/src/main/java/com/zaneschepke/wireguardautotunnel/ui/screens/main/MainScreen.kt @@ -38,6 +38,7 @@ import androidx.compose.ui.unit.dp import androidx.hilt.navigation.compose.hiltViewModel import androidx.lifecycle.compose.collectAsStateWithLifecycle import com.zaneschepke.wireguardautotunnel.R +import com.zaneschepke.wireguardautotunnel.core.tunnel.getValueById import com.zaneschepke.wireguardautotunnel.domain.entity.TunnelConf import com.zaneschepke.wireguardautotunnel.domain.state.TunnelState import com.zaneschepke.wireguardautotunnel.ui.state.AppUiState @@ -227,7 +228,7 @@ fun MainScreen(viewModel: MainViewModel = hiltViewModel(), uiState: AppUiState) key = { tunnel -> tunnel.id }, ) { tunnel -> val expanded = uiState.generalState.isTunnelStatsExpanded - val tunnelState = activeTunnels.getOrDefault(tunnel.id, TunnelState()) + val tunnelState = activeTunnels.getValueById(tunnel.id) ?: TunnelState() TunnelRowItem( tunnelState.state.isUp(), expanded, diff --git a/app/src/main/java/com/zaneschepke/wireguardautotunnel/ui/screens/main/TunnelOptionsScreen.kt b/app/src/main/java/com/zaneschepke/wireguardautotunnel/ui/screens/main/TunnelOptionsScreen.kt index 9b0e3870..dafa2bf5 100644 --- a/app/src/main/java/com/zaneschepke/wireguardautotunnel/ui/screens/main/TunnelOptionsScreen.kt +++ b/app/src/main/java/com/zaneschepke/wireguardautotunnel/ui/screens/main/TunnelOptionsScreen.kt @@ -32,6 +32,7 @@ import androidx.compose.ui.text.input.KeyboardType import androidx.compose.ui.unit.dp import androidx.hilt.navigation.compose.hiltViewModel import com.zaneschepke.wireguardautotunnel.R +import com.zaneschepke.wireguardautotunnel.core.tunnel.isUp import com.zaneschepke.wireguardautotunnel.domain.entity.TunnelConf import com.zaneschepke.wireguardautotunnel.ui.Route import com.zaneschepke.wireguardautotunnel.ui.common.button.ScaledSwitch @@ -195,7 +196,7 @@ fun OptionsScreen(tunnelConf: TunnelConf, appUiState: AppUiState, viewModel: Tun trailing = { ScaledSwitch( checked = tunnelConf.isPingEnabled, - enabled = !appUiState.activeTunnels.containsKey(tunnelConf.id), + enabled = !appUiState.activeTunnels.isUp(tunnelConf), onClick = { onPingToggle() }, ) }, diff --git a/app/src/main/java/com/zaneschepke/wireguardautotunnel/ui/state/AppUiState.kt b/app/src/main/java/com/zaneschepke/wireguardautotunnel/ui/state/AppUiState.kt index ea982988..f78fbc7a 100644 --- a/app/src/main/java/com/zaneschepke/wireguardautotunnel/ui/state/AppUiState.kt +++ b/app/src/main/java/com/zaneschepke/wireguardautotunnel/ui/state/AppUiState.kt @@ -8,7 +8,7 @@ import com.zaneschepke.wireguardautotunnel.domain.state.TunnelState data class AppUiState( val appSettings: AppSettings = AppSettings(), val tunnels: List = emptyList(), - val activeTunnels: Map = emptyMap(), + val activeTunnels: Map = emptyMap(), val generalState: GeneralState = GeneralState(), val autoTunnelActive: Boolean = false, ) diff --git a/app/src/main/java/com/zaneschepke/wireguardautotunnel/viewmodel/AppViewModel.kt b/app/src/main/java/com/zaneschepke/wireguardautotunnel/viewmodel/AppViewModel.kt index 023a5a1c..52301b16 100644 --- a/app/src/main/java/com/zaneschepke/wireguardautotunnel/viewmodel/AppViewModel.kt +++ b/app/src/main/java/com/zaneschepke/wireguardautotunnel/viewmodel/AppViewModel.kt @@ -147,6 +147,8 @@ constructor( appDataRepository.appState.setLocalLogsEnabled(toggledOn) if (!toggledOn) { logReader.stop() + } else { + logReader.start() } } } @@ -275,13 +277,13 @@ constructor( } } - suspend fun requestRoot(): Result { + private suspend fun requestRoot(): Result { return withContext(ioDispatcher) { runCatching { rootShell.get().start() - SnackbarController.Companion.showMessage(StringValue.StringResource(R.string.root_accepted)) + SnackbarController.showMessage(StringValue.StringResource(R.string.root_accepted)) }.onFailure { - SnackbarController.Companion.showMessage(StringValue.StringResource(R.string.error_root_denied)) + SnackbarController.showMessage(StringValue.StringResource(R.string.error_root_denied)) } } } @@ -351,7 +353,7 @@ constructor( runCatching { val amConfig = tunnelConfig.toAmConfig() val wgConfig = tunnelConfig.toWgConfig() - val proxy = InterfaceProxy.Companion.from(amConfig.`interface`) + val proxy = InterfaceProxy.from(amConfig.`interface`) if (proxy.includedApplications.isEmpty() && proxy.excludedApplications.isEmpty()) return@launch if (proxy.includedApplications.retainAll(packages.toSet()) || proxy.excludedApplications.retainAll(packages.toSet())) { updateTunnelConfig(tunnelConfig, amConfig = amConfig, wgConfig = wgConfig, `interface` = proxy) diff --git a/networkmonitor/src/main/java/com/zaneschepke/networkmonitor/AndroidNetworkMonitor.kt b/networkmonitor/src/main/java/com/zaneschepke/networkmonitor/AndroidNetworkMonitor.kt index 2b01bcbb..42485658 100644 --- a/networkmonitor/src/main/java/com/zaneschepke/networkmonitor/AndroidNetworkMonitor.kt +++ b/networkmonitor/src/main/java/com/zaneschepke/networkmonitor/AndroidNetworkMonitor.kt @@ -30,7 +30,8 @@ class AndroidNetworkMonitor( data class TransportState(val connected: Boolean = false) private val wifiFlow: Flow = callbackFlow { - var currentSsid: String? = null + + var currentSsid: String? @Suppress("DEPRECATION") fun getWifiSsid(): String? { @@ -47,13 +48,8 @@ class AndroidNetworkMonitor( val callback = object : ConnectivityManager.NetworkCallback() { override fun onAvailable(network: Network) { Timber.d("Wi-Fi onAvailable: network=$network") - val capabilities = connectivityManager.getNetworkCapabilities(network) - val connected = capabilities?.hasTransport(NetworkCapabilities.TRANSPORT_WIFI) == true - if (connected) { - val ssid = getWifiSsid() - currentSsid = ssid - trySend(WifiState(connected = true, ssid = ssid)) - } + currentSsid = getWifiSsid() + trySend(WifiState(connected = true, ssid = currentSsid)) } override fun onLost(network: Network) { @@ -63,12 +59,7 @@ class AndroidNetworkMonitor( } override fun onCapabilitiesChanged(network: Network, networkCapabilities: NetworkCapabilities) { - val connected = networkCapabilities.hasTransport(NetworkCapabilities.TRANSPORT_WIFI) - val ssid = if (connected) getWifiSsid() else null - if (ssid != currentSsid) { - currentSsid = ssid - trySend(WifiState(connected = connected, ssid = ssid)) - } + Timber.d("Wi-Fi onCapabilitiesChanged: network=$network, networkCapabilities=$networkCapabilities") } }