fix: tunnel lock (#765)

fix: start up logger bug
refactor: switch to bound services
refactor: expose resolved peer endpoint
This commit is contained in:
Zane Schepke
2025-04-28 15:06:43 -04:00
committed by GitHub
parent 0c90b33813
commit 25fd31e252
21 changed files with 246 additions and 187 deletions
+1 -4
View File
@@ -166,15 +166,12 @@
<receiver
android:name=".core.broadcast.RestartReceiver"
android:enabled="true"
android:exported="false"
android:directBootAware="true">
android:exported="false">
<intent-filter>
<action android:name="android.intent.action.SCREEN_ON" />
<action android:name="android.intent.action.USER_PRESENT" />
<action android:name="android.intent.action.BOOT_COMPLETED" />
<action android:name="android.intent.action.QUICKBOOT_POWERON" />
<action android:name="com.htc.intent.action.QUICKBOOT_POWERON" />
<action android:name="android.intent.action.LOCKED_BOOT_COMPLETED" />
<action android:name="android.intent.action.MY_PACKAGE_REPLACED" />
</intent-filter>
</receiver>
@@ -39,7 +39,6 @@ import androidx.navigation.compose.rememberNavController
import androidx.navigation.toRoute
import com.zaneschepke.networkmonitor.NetworkMonitor
import com.zaneschepke.wireguardautotunnel.core.tunnel.TunnelManager
import com.zaneschepke.wireguardautotunnel.domain.enums.TunnelStatus
import com.zaneschepke.wireguardautotunnel.domain.repository.AppStateRepository
import com.zaneschepke.wireguardautotunnel.ui.Route
import com.zaneschepke.wireguardautotunnel.ui.common.dialog.VpnDeniedDialog
@@ -110,6 +109,7 @@ class MainActivity : AppCompatActivity() {
val isTv = isRunningOnTv()
val appUiState by viewModel.uiState.collectAsStateWithLifecycle()
val appViewState by viewModel.appViewState.collectAsStateWithLifecycle()
val tunnelError by viewModel.tunnelManager.errorEvents.collectAsStateWithLifecycle(null)
val navController = rememberNavController()
val backStackEntry by navController.currentBackStackEntryAsState()
@@ -151,6 +151,15 @@ class MainActivity : AppCompatActivity() {
viewModel.handleEvent(AppEvent.SetBatteryOptimizeDisableShown)
}
LaunchedEffect(tunnelError) {
if (tunnelError == null) return@LaunchedEffect
val message = tunnelError!!.second.toStringRes()
val context = this@MainActivity
snackbar.showSnackbar(
context.getString(R.string.tunnel_error_template, context.getString(message))
)
}
with(appViewState) {
LaunchedEffect(isConfigChanged) {
if (isConfigChanged) {
@@ -166,21 +175,6 @@ class MainActivity : AppCompatActivity() {
viewModel.handleEvent(AppEvent.MessageShown)
}
}
LaunchedEffect(appUiState.activeTunnels) {
appUiState.activeTunnels.mapNotNull { (tunnelConf, tunnelState) ->
(tunnelState.status as? TunnelStatus.Error)?.let { error ->
val message = error.error.toStringRes()
val context = this@MainActivity
snackbar.showSnackbar(
context.getString(
R.string.tunnel_error_template,
context.getString(message),
)
)
viewModel.handleEvent(AppEvent.ClearTunnelError(tunnelConf))
}
}
}
LaunchedEffect(popBackStack) {
if (popBackStack) {
navController.popBackStack()
@@ -32,14 +32,15 @@ class RestartReceiver : BroadcastReceiver() {
Timber.d("RestartReceiver triggered with action: ${intent.action}")
// screen on for Android TV only to help with sleep shutdowns
val isTv = context.isRunningOnTv()
if (intent.action == Intent.ACTION_SCREEN_ON && !isTv) return
if (intent.action == Intent.ACTION_USER_PRESENT && !isTv) return
serviceManager.updateTunnelTile()
serviceManager.updateAutoTunnelTile()
applicationScope.launch(ioDispatcher) {
val settings = appDataRepository.settings.get()
if (settings.isRestoreOnBootEnabled) {
if (settings.isAutoTunnelEnabled && !serviceManager.autoTunnelActive.value) {
if (
settings.isAutoTunnelEnabled && serviceManager.autoTunnelService.value == null
) {
Timber.d("Starting auto-tunnel on boot/update")
serviceManager.startAutoTunnel()
} else {
@@ -1,10 +1,11 @@
package com.zaneschepke.wireguardautotunnel.core.service
import android.app.Service
import android.content.ComponentName
import android.content.Context
import android.content.Intent
import android.content.ServiceConnection
import android.net.VpnService
import com.zaneschepke.wireguardautotunnel.WireGuardAutoTunnel
import android.os.IBinder
import com.zaneschepke.wireguardautotunnel.core.service.autotunnel.AutoTunnelService
import com.zaneschepke.wireguardautotunnel.di.ApplicationScope
import com.zaneschepke.wireguardautotunnel.di.IoDispatcher
@@ -13,7 +14,6 @@ import com.zaneschepke.wireguardautotunnel.domain.repository.AppDataRepository
import com.zaneschepke.wireguardautotunnel.util.extensions.requestAutoTunnelTileServiceUpdate
import com.zaneschepke.wireguardautotunnel.util.extensions.requestTunnelTileServiceStateUpdate
import jakarta.inject.Inject
import kotlinx.coroutines.CompletableDeferred
import kotlinx.coroutines.CoroutineDispatcher
import kotlinx.coroutines.CoroutineScope
import kotlinx.coroutines.flow.MutableStateFlow
@@ -37,23 +37,37 @@ constructor(
private val autoTunnelMutex = Mutex()
private val _autoTunnelActive = MutableStateFlow(false)
val autoTunnelActive = _autoTunnelActive.asStateFlow()
private val _tunnelService = MutableStateFlow<TunnelForegroundService?>(null)
private val _autoTunnelService = MutableStateFlow<AutoTunnelService?>(null)
val autoTunnelService = _autoTunnelService.asStateFlow()
var autoTunnelService = CompletableDeferred<AutoTunnelService>()
var backgroundService = CompletableDeferred<TunnelForegroundService>()
private fun <T : Service> startService(cls: Class<T>, background: Boolean) {
runCatching {
val intent = Intent(context, cls)
if (background) {
context.startForegroundService(intent)
} else {
context.startService(intent)
}
private val tunnelServiceConnection =
object : ServiceConnection {
override fun onServiceConnected(name: ComponentName, service: IBinder) {
val binder = service as? TunnelForegroundService.LocalBinder
_tunnelService.value = binder?.service
Timber.d("TunnelForegroundService connected")
}
.onFailure { Timber.e(it) }
}
override fun onServiceDisconnected(name: ComponentName) {
_tunnelService.value = null
Timber.d("TunnelForegroundService disconnected")
}
}
private val autoTunnelServiceConnection =
object : ServiceConnection {
override fun onServiceConnected(name: ComponentName, service: IBinder) {
val binder = service as? AutoTunnelService.LocalBinder
_autoTunnelService.value = binder?.service
Timber.d("AutoTunnelService connected")
}
override fun onServiceDisconnected(name: ComponentName) {
_autoTunnelService.value = null
Timber.d("AutoTunnelService disconnected")
}
}
fun hasVpnPermission(): Boolean {
return VpnService.prepare(context) == null
@@ -63,20 +77,13 @@ constructor(
autoTunnelMutex.withLock {
val settings = appDataRepository.settings.get()
appDataRepository.settings.save(settings.copy(isAutoTunnelEnabled = true))
if (autoTunnelService.isCompleted) {
_autoTunnelActive.update { true }
return
if (_autoTunnelService.value != null) return
withContext(ioDispatcher) {
val intent = Intent(context, AutoTunnelService::class.java)
context.startForegroundService(intent)
context.bindService(intent, autoTunnelServiceConnection, Context.BIND_AUTO_CREATE)
withContext(mainDispatcher) { updateAutoTunnelTile() }
}
runCatching {
autoTunnelService = CompletableDeferred()
startService(AutoTunnelService::class.java, !WireGuardAutoTunnel.isForeground())
_autoTunnelActive.update { true }
}
.onFailure {
Timber.e(it)
_autoTunnelActive.update { false }
}
withContext(mainDispatcher) { updateAutoTunnelTile() }
}
}
@@ -84,43 +91,44 @@ constructor(
autoTunnelMutex.withLock {
val settings = appDataRepository.settings.get()
appDataRepository.settings.save(settings.copy(isAutoTunnelEnabled = false))
if (!autoTunnelService.isCompleted) return
runCatching {
val service = autoTunnelService.await()
service.stop()
_autoTunnelActive.update { false }
autoTunnelService = CompletableDeferred()
if (_autoTunnelService.value == null) return
_autoTunnelService.value?.let { service ->
service.stop()
try {
context.unbindService(autoTunnelServiceConnection)
} finally {
_tunnelService.value = null
}
.onFailure { Timber.e(it) }
}
withContext(mainDispatcher) { updateAutoTunnelTile() }
}
}
fun startTunnelForegroundService() {
if (backgroundService.isCompleted) return
runCatching {
backgroundService = CompletableDeferred()
startService(
TunnelForegroundService::class.java,
!WireGuardAutoTunnel.isForeground(),
)
suspend fun startTunnelForegroundService() {
if (_tunnelService.value != null) return
withContext(ioDispatcher) {
applicationScope.launch(ioDispatcher) {
val intent = Intent(context, TunnelForegroundService::class.java)
context.startForegroundService(intent)
context.bindService(intent, tunnelServiceConnection, Context.BIND_AUTO_CREATE)
}
.onFailure { Timber.e(it) }
}
}
suspend fun stopTunnelForegroundService() {
if (!backgroundService.isCompleted) return
runCatching {
val service = backgroundService.await()
service.stop()
backgroundService = CompletableDeferred()
fun stopTunnelForegroundService() {
_tunnelService.value?.let { service ->
service.stop()
try {
context.unbindService(tunnelServiceConnection)
} finally {
_tunnelService.value = null
}
.onFailure { Timber.e(it) }
}
}
fun toggleAutoTunnel() {
applicationScope.launch(ioDispatcher) {
if (_autoTunnelActive.value) stopAutoTunnel() else startAutoTunnel()
if (_autoTunnelService.value != null) stopAutoTunnel() else startAutoTunnel()
}
}
@@ -131,4 +139,12 @@ constructor(
fun updateTunnelTile() {
context.requestTunnelTileServiceStateUpdate()
}
fun handleTunnelServiceDestroy() {
_tunnelService.update { null }
}
fun handleAutoTunnelServiceDestroy() {
_autoTunnelService.update { null }
}
}
@@ -2,6 +2,7 @@ package com.zaneschepke.wireguardautotunnel.core.service
import android.app.Notification
import android.content.Intent
import android.os.Binder
import android.os.IBinder
import androidx.core.app.ServiceCompat
import androidx.lifecycle.LifecycleService
@@ -23,7 +24,6 @@ import com.zaneschepke.wireguardautotunnel.util.extensions.distinctByKeys
import dagger.hilt.android.AndroidEntryPoint
import java.util.concurrent.ConcurrentHashMap
import javax.inject.Inject
import kotlinx.coroutines.CompletableDeferred
import kotlinx.coroutines.CoroutineDispatcher
import kotlinx.coroutines.Job
import kotlinx.coroutines.NonCancellable
@@ -64,9 +64,12 @@ class TunnelForegroundService : LifecycleService() {
private val jobsMutex = Mutex()
class LocalBinder(val service: TunnelForegroundService) : Binder()
private val binder = LocalBinder(this)
override fun onCreate() {
super.onCreate()
serviceManager.backgroundService.complete(this)
ServiceCompat.startForeground(
this@TunnelForegroundService,
NotificationManager.VPN_NOTIFICATION_ID,
@@ -75,14 +78,13 @@ class TunnelForegroundService : LifecycleService() {
)
}
override fun onBind(intent: Intent): IBinder? {
override fun onBind(intent: Intent): IBinder {
super.onBind(intent)
return null
return binder
}
override fun onStartCommand(intent: Intent?, flags: Int, startId: Int): Int {
super.onStartCommand(intent, flags, startId)
serviceManager.backgroundService.complete(this)
ServiceCompat.startForeground(
this@TunnelForegroundService,
NotificationManager.VPN_NOTIFICATION_ID,
@@ -273,7 +275,7 @@ class TunnelForegroundService : LifecycleService() {
}
override fun onDestroy() {
serviceManager.backgroundService = CompletableDeferred()
serviceManager.handleTunnelServiceDestroy()
ServiceCompat.stopForeground(this, ServiceCompat.STOP_FOREGROUND_REMOVE)
super.onDestroy()
}
@@ -1,6 +1,7 @@
package com.zaneschepke.wireguardautotunnel.core.service.autotunnel
import android.content.Intent
import android.os.Binder
import android.os.IBinder
import android.os.PowerManager
import androidx.core.app.ServiceCompat
@@ -28,7 +29,6 @@ import com.zaneschepke.wireguardautotunnel.util.extensions.Tunnels
import dagger.hilt.android.AndroidEntryPoint
import javax.inject.Inject
import javax.inject.Provider
import kotlinx.coroutines.CompletableDeferred
import kotlinx.coroutines.CoroutineDispatcher
import kotlinx.coroutines.ExperimentalCoroutinesApi
import kotlinx.coroutines.FlowPreview
@@ -68,21 +68,23 @@ class AutoTunnelService : LifecycleService() {
private var killSwitchJob: Job? = null
class LocalBinder(val service: AutoTunnelService) : Binder()
private val binder = LocalBinder(this)
override fun onCreate() {
super.onCreate()
serviceManager.autoTunnelService.complete(this)
launchWatcherNotification()
}
override fun onBind(intent: Intent): IBinder? {
override fun onBind(intent: Intent): IBinder {
super.onBind(intent)
return null
return binder
}
override fun onStartCommand(intent: Intent?, flags: Int, startId: Int): Int {
super.onStartCommand(intent, flags, startId)
Timber.d("onStartCommand executed with startId: $startId")
serviceManager.autoTunnelService.complete(this)
start()
return START_STICKY
}
@@ -105,7 +107,7 @@ class AutoTunnelService : LifecycleService() {
}
override fun onDestroy() {
serviceManager.autoTunnelService = CompletableDeferred()
serviceManager.handleAutoTunnelServiceDestroy()
restoreVpnKillSwitch()
super.onDestroy()
}
@@ -38,8 +38,8 @@ class AutoTunnelControlTile : TileService(), LifecycleOwner {
lifecycleRegistry.handleLifecycleEvent(Lifecycle.Event.ON_START)
Timber.d("Start listening called for auto tunnel tile")
lifecycleScope.launch {
serviceManager.autoTunnelActive.collect {
if (it) return@collect setActive()
serviceManager.autoTunnelService.collect {
if (it != null) return@collect setActive()
setInactive()
}
}
@@ -56,7 +56,7 @@ class AutoTunnelControlTile : TileService(), LifecycleOwner {
super.onClick()
unlockAndRun {
lifecycleScope.launch {
if (serviceManager.autoTunnelActive.value) {
if (serviceManager.autoTunnelService.value != null) {
serviceManager.stopAutoTunnel()
setInactive()
} else {
@@ -1,6 +1,5 @@
package com.zaneschepke.wireguardautotunnel.core.tunnel
import com.wireguard.android.backend.BackendException
import com.wireguard.android.backend.Tunnel
import com.zaneschepke.wireguardautotunnel.core.service.ServiceManager
import com.zaneschepke.wireguardautotunnel.di.ApplicationScope
@@ -11,14 +10,11 @@ import com.zaneschepke.wireguardautotunnel.domain.repository.AppDataRepository
import com.zaneschepke.wireguardautotunnel.domain.state.TunnelState
import com.zaneschepke.wireguardautotunnel.domain.state.TunnelStatistics
import com.zaneschepke.wireguardautotunnel.util.extensions.asTunnelState
import com.zaneschepke.wireguardautotunnel.util.extensions.toBackendError
import java.util.concurrent.ConcurrentHashMap
import kotlin.concurrent.thread
import kotlinx.coroutines.CoroutineScope
import kotlinx.coroutines.delay
import kotlinx.coroutines.flow.MutableStateFlow
import kotlinx.coroutines.flow.asStateFlow
import kotlinx.coroutines.flow.update
import kotlinx.coroutines.flow.*
import kotlinx.coroutines.launch
import kotlinx.coroutines.runBlocking
import kotlinx.coroutines.sync.Mutex
@@ -31,6 +27,10 @@ abstract class BaseTunnel(
private val serviceManager: ServiceManager,
) : TunnelProvider {
private val _errorEvents =
MutableSharedFlow<Pair<TunnelConf, BackendError>>(replay = 0, extraBufferCapacity = 1)
override val errorEvents = _errorEvents.asSharedFlow()
private val activeTuns = MutableStateFlow<Map<TunnelConf, TunnelState>>(emptyMap())
private val tunThreads = ConcurrentHashMap<Int, Thread>()
override val activeTunnels = activeTuns.asStateFlow()
@@ -45,37 +45,34 @@ abstract class BaseTunnel(
abstract fun stopBackend(tunnel: TunnelConf)
override suspend fun clearError(tunnelConf: TunnelConf) =
updateTunnelStatus(tunnelConf, TunnelStatus.Down)
override fun hasVpnPermission(): Boolean {
return serviceManager.hasVpnPermission()
}
protected suspend fun updateTunnelStatus(
tunnelConf: TunnelConf,
state: TunnelStatus? = null,
status: TunnelStatus? = null,
stats: TunnelStatistics? = null,
) {
tunStatusMutex.withLock {
activeTuns.update { current ->
val originalConf = current.getKeyById(tunnelConf.id) ?: tunnelConf
val existingState = current.getValueById(tunnelConf.id) ?: TunnelState()
val newState = state ?: existingState.status
activeTuns.update { currentTuns ->
val originalConf = currentTuns.getKeyById(tunnelConf.id) ?: tunnelConf
val existingState = currentTuns.getValueById(tunnelConf.id) ?: TunnelState()
val newState = status ?: existingState.status
if (newState == TunnelStatus.Down) {
Timber.d("Removing tunnel ${tunnelConf.id} from activeTunnels as state is DOWN")
cleanUpTunThread(tunnelConf)
current - originalConf
currentTuns - originalConf
} else if (existingState.status == newState && stats == null) {
Timber.d("Skipping redundant state update for ${tunnelConf.id}: $newState")
current
currentTuns
} else {
val updated =
existingState.copy(
status = newState,
statistics = stats ?: existingState.statistics,
)
current + (originalConf to updated)
currentTuns + (originalConf to updated)
}
}
}
@@ -117,23 +114,17 @@ abstract class BaseTunnel(
if (this@BaseTunnel is UserspaceTunnel) stopActiveTunnels()
tunMutex.withLock {
tunThreads[tunnelConf.id] = thread {
runCatching {
runBlocking {
try {
Timber.d("Starting tunnel ${tunnelConf.id}...")
startTunnelInner(tunnelConf)
Timber.d("Started complete for tunnel ${tunnelConf.name}...")
} catch (e: BackendError) {
Timber.e(e, "Failed to start tunnel ${tunnelConf.name} userspace")
updateTunnelStatus(tunnelConf, TunnelStatus.Error(e))
} catch (e: InterruptedException) {
Timber.w(
"Tunnel start has been interrupted as ${tunnelConf.name} failed to start"
)
}
}
runBlocking {
try {
Timber.d("Starting tunnel ${tunnelConf.id}...")
startTunnelInner(tunnelConf)
Timber.d("Started complete for tunnel ${tunnelConf.name}...")
} catch (e: InterruptedException) {
Timber.w(
"Tunnel start has been interrupted as ${tunnelConf.name} failed to start"
)
}
.onFailure { Timber.w("Tunnel start has been interrupted") }
}
}
}
}
@@ -147,11 +138,10 @@ abstract class BaseTunnel(
Timber.d("Started for tun ${tunnelConf.id}...")
saveTunnelActiveState(tunnelConf, true)
serviceManager.startTunnelForegroundService()
} catch (e: BackendException) {
} catch (e: BackendError) {
Timber.e(e, "Failed to start backend for ${tunnelConf.name}")
val backendError = e.toBackendError()
updateTunnelStatus(tunnelConf, TunnelStatus.Error(backendError))
throw backendError
_errorEvents.emit(tunnelConf to e)
updateTunnelStatus(tunnelConf, TunnelStatus.Down)
}
}
@@ -163,26 +153,27 @@ abstract class BaseTunnel(
override suspend fun stopTunnel(tunnelConf: TunnelConf?, reason: TunnelStatus.StopReason) {
if (tunnelConf == null) return stopActiveTunnels()
tunMutex.withLock {
try {
if (activeTuns.isStarting(tunnelConf.id))
return handleStuckStartingTunnelShutdown(tunnelConf)
updateTunnelStatus(tunnelConf, TunnelStatus.Stopping(reason))
stopTunnelInner(tunnelConf)
} catch (e: BackendError) {
Timber.e(e, "Failed to stop tunnel ${tunnelConf.id}")
updateTunnelStatus(tunnelConf, TunnelStatus.Error(e))
}
if (activeTuns.isStarting(tunnelConf.id))
return handleStuckStartingTunnelShutdown(tunnelConf)
updateTunnelStatus(tunnelConf, TunnelStatus.Stopping(reason))
stopTunnelInner(tunnelConf)
}
}
private suspend fun stopTunnelInner(tunnelConf: TunnelConf) {
val tunnel = activeTuns.findTunnel(tunnelConf.id) ?: return
stopBackend(tunnel)
saveTunnelActiveState(tunnelConf, false)
removeActiveTunnel(tunnel)
try {
val tunnel = activeTuns.findTunnel(tunnelConf.id) ?: return
stopBackend(tunnel)
saveTunnelActiveState(tunnelConf, false)
removeActiveTunnel(tunnel)
} catch (e: BackendError) {
Timber.e(e, "Failed to stop tunnel ${tunnelConf.id}")
_errorEvents.emit(tunnelConf to e)
updateTunnelStatus(tunnelConf, TunnelStatus.Down)
}
}
private suspend fun handleServiceStateOnChange() {
private fun handleServiceStateOnChange() {
if (activeTuns.value.isEmpty() && bouncingTunnelIds.isEmpty())
serviceManager.stopTunnelForegroundService()
}
@@ -193,15 +184,15 @@ abstract class BaseTunnel(
tunThreads[tunnel.id]?.let {
if (it.state != Thread.State.TERMINATED) {
it.interrupt()
updateTunnelStatus(tunnel, TunnelStatus.Down)
} else {
Timber.d("Thread already terminated")
}
}
} catch (e: Exception) {
Timber.e(e, "Failed to stop tunnel thread for ${tunnel.name}")
} finally {
updateTunnelStatus(tunnel, TunnelStatus.Down)
}
cleanUpTunThread(tunnel)
}
private fun cleanUpTunThread(tunnel: TunnelConf) {
@@ -221,7 +212,7 @@ abstract class BaseTunnel(
bouncingTunnelIds[tunnelConf.id] = reason
try {
stopTunnel(tunnelConf, reason)
delay(300L)
delay(BOUNCE_DELAY)
startTunnel(tunnelConf)
} finally {
bouncingTunnelIds.remove(tunnelConf.id)
@@ -235,4 +226,8 @@ abstract class BaseTunnel(
override suspend fun runningTunnelNames(): Set<String> =
activeTuns.value.keys.map { it.tunName }.toSet()
companion object {
const val BOUNCE_DELAY = 300L
}
}
@@ -5,6 +5,7 @@ import com.zaneschepke.wireguardautotunnel.di.IoDispatcher
import com.zaneschepke.wireguardautotunnel.di.Kernel
import com.zaneschepke.wireguardautotunnel.di.Userspace
import com.zaneschepke.wireguardautotunnel.domain.entity.TunnelConf
import com.zaneschepke.wireguardautotunnel.domain.enums.BackendError
import com.zaneschepke.wireguardautotunnel.domain.enums.BackendState
import com.zaneschepke.wireguardautotunnel.domain.enums.TunnelStatus
import com.zaneschepke.wireguardautotunnel.domain.repository.AppDataRepository
@@ -15,6 +16,7 @@ import kotlinx.coroutines.CoroutineDispatcher
import kotlinx.coroutines.CoroutineScope
import kotlinx.coroutines.ExperimentalCoroutinesApi
import kotlinx.coroutines.flow.MutableStateFlow
import kotlinx.coroutines.flow.SharedFlow
import kotlinx.coroutines.flow.SharingStarted
import kotlinx.coroutines.flow.filterNotNull
import kotlinx.coroutines.flow.flatMapLatest
@@ -62,6 +64,9 @@ constructor(
initialValue = emptyMap(),
)
override val errorEvents: SharedFlow<Pair<TunnelConf, BackendError>>
get() = tunnelProviderFlow.value.errorEvents
override val bouncingTunnelIds: ConcurrentHashMap<Int, TunnelStatus.StopReason> =
tunnelProviderFlow.value.bouncingTunnelIds
@@ -69,10 +74,6 @@ constructor(
return userspaceTunnel.hasVpnPermission()
}
override suspend fun clearError(tunnelConf: TunnelConf) {
tunnelProviderFlow.value.clearError(tunnelConf)
}
override suspend fun updateTunnelStatistics(tunnel: TunnelConf) {
tunnelProviderFlow.value.updateTunnelStatistics(tunnel)
}
@@ -1,11 +1,13 @@
package com.zaneschepke.wireguardautotunnel.core.tunnel
import com.zaneschepke.wireguardautotunnel.domain.entity.TunnelConf
import com.zaneschepke.wireguardautotunnel.domain.enums.BackendError
import com.zaneschepke.wireguardautotunnel.domain.enums.BackendState
import com.zaneschepke.wireguardautotunnel.domain.enums.TunnelStatus
import com.zaneschepke.wireguardautotunnel.domain.state.TunnelState
import com.zaneschepke.wireguardautotunnel.domain.state.TunnelStatistics
import java.util.concurrent.ConcurrentHashMap
import kotlinx.coroutines.flow.SharedFlow
import kotlinx.coroutines.flow.StateFlow
interface TunnelProvider {
@@ -46,11 +48,11 @@ interface TunnelProvider {
val activeTunnels: StateFlow<Map<TunnelConf, TunnelState>>
val errorEvents: SharedFlow<Pair<TunnelConf, BackendError>>
val bouncingTunnelIds: ConcurrentHashMap<Int, TunnelStatus.StopReason>
fun hasVpnPermission(): Boolean
suspend fun clearError(tunnelConf: TunnelConf)
suspend fun updateTunnelStatistics(tunnel: TunnelConf)
}
@@ -50,8 +50,9 @@ constructor(
} catch (e: BackendException) {
Timber.e(e, "Failed to stop tunnel ${tunnel.id}")
throw e.toBackendError()
} finally {
handlePreviouslyEnabledVpnKillSwitch()
}
handlePreviouslyEnabledVpnKillSwitch()
}
// stop vpn kill switch if we need to resolve DNS for peer endpoints
@@ -69,7 +70,7 @@ constructor(
// restore vpn kill switch if needed
private fun handlePreviouslyEnabledVpnKillSwitch() {
// let auto tunnel handle this if it is active
if (!serviceManager.autoTunnelActive.value) {
if (serviceManager.autoTunnelService.value == null) {
previousBackendState?.let { (state, lanEnabled) ->
Timber.d("Restoring kill switch configuration")
val lan = if (lanEnabled) TunnelConf.LAN_BYPASS_ALLOWED_IPS else emptyList()
@@ -57,7 +57,7 @@ constructor(
withContext(ioDispatcher) {
Timber.i("Service worker started")
with(appDataRepository.settings.get()) {
if (isAutoTunnelEnabled && !serviceManager.autoTunnelActive.value)
if (isAutoTunnelEnabled && serviceManager.autoTunnelService.value == null)
return@with serviceManager.startAutoTunnel()
if (tunnelManager.activeTunnels.value.isEmpty())
tunnelManager.restorePreviousState()
@@ -1,7 +1,6 @@
package com.zaneschepke.wireguardautotunnel.domain.enums
sealed class TunnelStatus {
data class Error(val error: BackendError) : TunnelStatus()
data object Up : TunnelStatus()
@@ -12,6 +12,7 @@ class AmneziaStatistics(private val statistics: Statistics) : TunnelStatistics()
rxBytes = stats.rxBytes,
txBytes = stats.txBytes,
latestHandshakeEpochMillis = stats.latestHandshakeEpochMillis,
resolvedEndpoint = stats.resolvedEndpoint,
)
}
}
@@ -8,6 +8,7 @@ abstract class TunnelStatistics {
val rxBytes: Long,
val txBytes: Long,
val latestHandshakeEpochMillis: Long,
val resolvedEndpoint: String,
)
abstract fun peerStats(peer: Key): PeerStats?
@@ -12,6 +12,7 @@ class WireGuardStatistics(private val statistics: Statistics) : TunnelStatistics
txBytes = peerStats.txBytes,
rxBytes = peerStats.rxBytes,
latestHandshakeEpochMillis = peerStats.latestHandshakeEpochMillis,
resolvedEndpoint = peerStats.resolvedEndpoint,
)
}
}
@@ -8,6 +8,9 @@ import androidx.compose.foundation.layout.padding
import androidx.compose.material3.MaterialTheme
import androidx.compose.material3.Text
import androidx.compose.runtime.Composable
import androidx.compose.runtime.derivedStateOf
import androidx.compose.runtime.getValue
import androidx.compose.runtime.remember
import androidx.compose.ui.Alignment
import androidx.compose.ui.Modifier
import androidx.compose.ui.res.stringResource
@@ -21,49 +24,90 @@ import com.zaneschepke.wireguardautotunnel.util.extensions.toThreeDecimalPlaceSt
@Composable
fun TunnelStatisticsRow(statistics: TunnelStatistics?, tunnelConf: TunnelConf) {
val config = TunnelConf.configFromAmQuick(tunnelConf.wgQuick)
config.peers.forEach { peer ->
Row(
modifier = Modifier.fillMaxWidth().padding(end = 10.dp, bottom = 10.dp, start = 45.dp),
verticalAlignment = Alignment.CenterVertically,
horizontalArrangement = Arrangement.spacedBy(5.dp, Alignment.Start),
) {
val peerId = peer.publicKey.toBase64().subSequence(0, 3).toString() + "***"
val peerRx = statistics?.peerStats(peer.publicKey)?.rxBytes ?: 0
val peerTx = statistics?.peerStats(peer.publicKey)?.txBytes ?: 0
val peerTxMB = NumberUtils.bytesToMB(peerTx).toThreeDecimalPlaceString()
val peerRxMB = NumberUtils.bytesToMB(peerRx).toThreeDecimalPlaceString()
val handshake =
statistics?.peerStats(peer.publicKey)?.latestHandshakeEpochMillis?.let {
if (it == 0L) {
stringResource(R.string.never)
} else {
"${NumberUtils.getSecondsBetweenTimestampAndNow(it)} ${stringResource(R.string.sec)}"
Column(
modifier = Modifier.fillMaxWidth().padding(start = 45.dp, bottom = 10.dp, end = 10.dp),
verticalArrangement = Arrangement.spacedBy(10.dp, Alignment.CenterVertically),
horizontalAlignment = Alignment.Start,
) {
config.peers.forEach { peer ->
val peerId = remember { peer.publicKey.toBase64().subSequence(0, 3).toString() + "***" }
val endpoint =
remember(statistics) { statistics?.peerStats(peer.publicKey)?.resolvedEndpoint }
val peerRxMB by
remember(statistics) {
derivedStateOf {
statistics
?.peerStats(peer.publicKey)
?.rxBytes
?.let { NumberUtils.bytesToMB(it) }
?.toThreeDecimalPlaceString()
}
} ?: stringResource(R.string.never)
Column(verticalArrangement = Arrangement.spacedBy(10.dp)) {
}
val peerTxMB by
remember(statistics) {
derivedStateOf {
statistics
?.peerStats(peer.publicKey)
?.txBytes
?.let { NumberUtils.bytesToMB(it) }
?.toThreeDecimalPlaceString()
}
}
val handshake by
remember(statistics) {
derivedStateOf {
statistics?.peerStats(peer.publicKey)?.latestHandshakeEpochMillis?.let {
if (it == 0L) {
null
} else {
"${NumberUtils.getSecondsBetweenTimestampAndNow(it)}"
}
}
}
}
Row(
verticalAlignment = Alignment.CenterVertically,
horizontalArrangement = Arrangement.spacedBy(16.dp, Alignment.Start),
) {
Text(
stringResource(R.string.peer).lowercase() + ": $peerId",
style = MaterialTheme.typography.bodySmall,
color = MaterialTheme.colorScheme.outline,
)
Text(
"tx: $peerTxMB MB",
stringResource(R.string.handshake) +
": ${if(handshake == null) stringResource(R.string.never) else handshake + " " + stringResource(R.string.sec)}",
style = MaterialTheme.typography.bodySmall,
color = MaterialTheme.colorScheme.outline,
)
}
Column(verticalArrangement = Arrangement.spacedBy(10.dp)) {
Row(
verticalAlignment = Alignment.CenterVertically,
horizontalArrangement = Arrangement.spacedBy(16.dp, Alignment.Start),
) {
Text(
stringResource(R.string.handshake) + ": $handshake",
"rx: ${peerRxMB ?: 0.00} MB",
style = MaterialTheme.typography.bodySmall,
color = MaterialTheme.colorScheme.outline,
)
Text(
"rx: $peerRxMB MB",
"tx: ${peerTxMB ?: 0.00} MB",
style = MaterialTheme.typography.bodySmall,
color = MaterialTheme.colorScheme.outline,
)
}
if (endpoint != null) {
Row(
verticalAlignment = Alignment.CenterVertically,
horizontalArrangement = Arrangement.spacedBy(16.dp, Alignment.Start),
) {
Text(
"endpoint: $endpoint",
style = MaterialTheme.typography.bodySmall,
color = MaterialTheme.colorScheme.outline,
)
}
}
}
}
}
@@ -39,6 +39,8 @@ import java.time.Instant
import java.util.*
import javax.inject.Inject
import javax.inject.Provider
import kotlin.collections.component1
import kotlin.collections.component2
import kotlinx.coroutines.*
import kotlinx.coroutines.flow.*
import kotlinx.coroutines.sync.Mutex
@@ -56,7 +58,7 @@ constructor(
@IoDispatcher private val ioDispatcher: CoroutineDispatcher,
@MainDispatcher private val mainDispatcher: CoroutineDispatcher,
@AppShell private val rootShell: Provider<RootShell>,
private val tunnelManager: TunnelManager,
val tunnelManager: TunnelManager,
private val serviceManager: ServiceManager,
private val logReader: LogReader,
private val fileUtils: FileUtils,
@@ -86,7 +88,7 @@ constructor(
appDataRepository.tunnels.flow,
appDataRepository.appState.flow,
tunnelManager.activeTunnels,
serviceManager.autoTunnelActive,
serviceManager.autoTunnelService.map { it != null },
networkMonitor.networkStatusFlow,
) { array ->
val settings = array[0] as AppSettings
@@ -206,7 +208,6 @@ constructor(
is AppEvent.ShowMessage -> handleShowMessage(event.message)
is AppEvent.PopBackStack ->
_appViewState.update { it.copy(popBackStack = event.pop) }
is AppEvent.ClearTunnelError -> tunnelManager.clearError(event.tunnel)
AppEvent.ToggleRemoteControl -> handleToggleRemoteControl(state.appState)
AppEvent.ClearSelectedTunnels -> clearSelectedTunnels()
is AppEvent.SetShowModal ->
@@ -265,6 +266,9 @@ constructor(
}
}
private fun handleTunnelErrors() =
viewModelScope.launch { tunnelManager.errorEvents.collect { errorEvent -> } }
private suspend fun handleAppReadyCheck(tunnels: List<TunnelConf>) {
if (tunnels.size == appDataRepository.tunnels.count()) {
_appViewState.update { it.copy(isAppReady = true) }
@@ -106,8 +106,6 @@ sealed class AppEvent {
data class ShowMessage(val message: StringValue) : AppEvent()
data class ClearTunnelError(val tunnel: TunnelConf) : AppEvent()
data class PopBackStack(val pop: Boolean) : AppEvent()
data class SetBottomSheet(val showSheet: AppViewState.BottomSheet) : AppEvent()