feat: add kill switch and tunnel restore on boot

This commit is contained in:
zaneschepke
2026-02-15 14:44:03 -05:00
parent 49b4af2b54
commit 0a288b8ad1
62 changed files with 885 additions and 689 deletions
@@ -2,7 +2,7 @@ package com.zaneschepke.wireguardautotunnel.cli.commands.killswitch
import com.zaneschepke.wireguardautotunnel.cli.util.CliUtils import com.zaneschepke.wireguardautotunnel.cli.util.CliUtils
import com.zaneschepke.wireguardautotunnel.cli.util.CliUtils.renderAnsi import com.zaneschepke.wireguardautotunnel.cli.util.CliUtils.renderAnsi
import com.zaneschepke.wireguardautotunnel.client.service.BackendCommandService import com.zaneschepke.wireguardautotunnel.client.service.BackendService
import java.util.concurrent.Callable import java.util.concurrent.Callable
import kotlinx.coroutines.runBlocking import kotlinx.coroutines.runBlocking
import org.koin.java.KoinJavaComponent.inject import org.koin.java.KoinJavaComponent.inject
@@ -14,7 +14,7 @@ import picocli.CommandLine.*
mixinStandardHelpOptions = true, mixinStandardHelpOptions = true,
) )
class KillSwitchCommand : Callable<Int> { class KillSwitchCommand : Callable<Int> {
private val backendService: BackendCommandService by inject(BackendCommandService::class.java) private val backendService: BackendService by inject(BackendService::class.java)
@Parameters(index = "0", description = ["The desired state: 'on' or 'off' (or true/false)."]) @Parameters(index = "0", description = ["The desired state: 'on' or 'off' (or true/false)."])
lateinit var state: String lateinit var state: String
@@ -2,7 +2,7 @@ package com.zaneschepke.wireguardautotunnel.cli.commands.status
import com.zaneschepke.wireguardautotunnel.cli.util.CliUtils import com.zaneschepke.wireguardautotunnel.cli.util.CliUtils
import com.zaneschepke.wireguardautotunnel.cli.util.CliUtils.renderAnsi import com.zaneschepke.wireguardautotunnel.cli.util.CliUtils.renderAnsi
import com.zaneschepke.wireguardautotunnel.client.service.BackendCommandService import com.zaneschepke.wireguardautotunnel.client.service.BackendService
import com.zaneschepke.wireguardautotunnel.core.ipc.dto.BackendMode import com.zaneschepke.wireguardautotunnel.core.ipc.dto.BackendMode
import com.zaneschepke.wireguardautotunnel.core.ipc.dto.BackendStatus import com.zaneschepke.wireguardautotunnel.core.ipc.dto.BackendStatus
import java.time.format.DateTimeFormatter import java.time.format.DateTimeFormatter
@@ -17,7 +17,7 @@ import picocli.CommandLine.*
mixinStandardHelpOptions = true, mixinStandardHelpOptions = true,
) )
class StatusCommand : Callable<Int> { class StatusCommand : Callable<Int> {
private val backendService: BackendCommandService by inject(BackendCommandService::class.java) private val backendService: BackendService by inject(BackendService::class.java)
override fun call(): Int = runBlocking { fetchSnapshot() } override fun call(): Int = runBlocking { fetchSnapshot() }
@@ -39,7 +39,7 @@ class TunnelDeleteCommand : Callable<Int> {
return@runBlocking 0 return@runBlocking 0
} }
CliUtils.withSpinner("Deleting '$tunnelName'...") { tunnelRepository.delete(tunnel) } CliUtils.withSpinner("Deleting '$tunnelName'...") { tunnelRepository.delete(tunnel.id) }
CliUtils.printSuccess("Tunnel '$tunnelName' deleted successfully.") CliUtils.printSuccess("Tunnel '$tunnelName' deleted successfully.")
0 0
@@ -3,7 +3,7 @@ package com.zaneschepke.wireguardautotunnel.cli.commands.tunnel
import com.zaneschepke.wireguardautotunnel.cli.util.CliUtils import com.zaneschepke.wireguardautotunnel.cli.util.CliUtils
import com.zaneschepke.wireguardautotunnel.client.domain.error.ClientException import com.zaneschepke.wireguardautotunnel.client.domain.error.ClientException
import com.zaneschepke.wireguardautotunnel.client.domain.repository.TunnelRepository import com.zaneschepke.wireguardautotunnel.client.domain.repository.TunnelRepository
import com.zaneschepke.wireguardautotunnel.client.service.TunnelCommandService import com.zaneschepke.wireguardautotunnel.client.service.TunnelService
import java.util.concurrent.Callable import java.util.concurrent.Callable
import kotlinx.coroutines.runBlocking import kotlinx.coroutines.runBlocking
import org.koin.java.KoinJavaComponent.inject import org.koin.java.KoinJavaComponent.inject
@@ -12,7 +12,7 @@ import picocli.CommandLine.Parameters
@Command(name = "down", description = ["Bring a tunnel down."]) @Command(name = "down", description = ["Bring a tunnel down."])
class TunnelDownCommand : Callable<Int> { class TunnelDownCommand : Callable<Int> {
private val tunnelService: TunnelCommandService by inject(TunnelCommandService::class.java) private val tunnelService: TunnelService by inject(TunnelService::class.java)
private val tunnelRepository: TunnelRepository by inject(TunnelRepository::class.java) private val tunnelRepository: TunnelRepository by inject(TunnelRepository::class.java)
@Parameters( @Parameters(
@@ -3,7 +3,7 @@ package com.zaneschepke.wireguardautotunnel.cli.commands.tunnel
import com.zaneschepke.wireguardautotunnel.cli.util.CliUtils import com.zaneschepke.wireguardautotunnel.cli.util.CliUtils
import com.zaneschepke.wireguardautotunnel.client.domain.error.ClientException import com.zaneschepke.wireguardautotunnel.client.domain.error.ClientException
import com.zaneschepke.wireguardautotunnel.client.domain.repository.TunnelRepository import com.zaneschepke.wireguardautotunnel.client.domain.repository.TunnelRepository
import com.zaneschepke.wireguardautotunnel.client.service.TunnelCommandService import com.zaneschepke.wireguardautotunnel.client.service.TunnelService
import java.util.concurrent.Callable import java.util.concurrent.Callable
import kotlinx.coroutines.runBlocking import kotlinx.coroutines.runBlocking
import org.koin.java.KoinJavaComponent.inject import org.koin.java.KoinJavaComponent.inject
@@ -12,7 +12,7 @@ import picocli.CommandLine.Parameters
@Command(name = "up", description = ["Bring a tunnel up."]) @Command(name = "up", description = ["Bring a tunnel up."])
class TunnelUpCommand : Callable<Int> { class TunnelUpCommand : Callable<Int> {
private val tunnelService: TunnelCommandService by inject(TunnelCommandService::class.java) private val tunnelService: TunnelService by inject(TunnelService::class.java)
private val tunnelRepository: TunnelRepository by inject(TunnelRepository::class.java) private val tunnelRepository: TunnelRepository by inject(TunnelRepository::class.java)
@Parameters( @Parameters(
@@ -1,14 +1,14 @@
package com.zaneschepke.wireguardautotunnel.cli.strategy package com.zaneschepke.wireguardautotunnel.cli.strategy
import com.zaneschepke.wireguardautotunnel.cli.util.CliUtils import com.zaneschepke.wireguardautotunnel.cli.util.CliUtils
import com.zaneschepke.wireguardautotunnel.client.service.DaemonHealthService import com.zaneschepke.wireguardautotunnel.client.service.DaemonService
import kotlinx.coroutines.runBlocking import kotlinx.coroutines.runBlocking
import org.koin.java.KoinJavaComponent.inject import org.koin.java.KoinJavaComponent.inject
import picocli.CommandLine.* import picocli.CommandLine.*
class CliExecutionStrategy(private val defaultStrategy: IExecutionStrategy) : IExecutionStrategy { class CliExecutionStrategy(private val defaultStrategy: IExecutionStrategy) : IExecutionStrategy {
private val daemonHealthService: DaemonHealthService by inject(DaemonHealthService::class.java) private val daemonService: DaemonService by inject(DaemonService::class.java)
override fun execute(parseResult: ParseResult): Int = runBlocking { override fun execute(parseResult: ParseResult): Int = runBlocking {
// skip help and version // skip help and version
@@ -18,7 +18,7 @@ class CliExecutionStrategy(private val defaultStrategy: IExecutionStrategy) : IE
val isAlive = val isAlive =
try { try {
daemonHealthService.alive() daemonService.alive()
} catch (e: Exception) { } catch (e: Exception) {
false false
} }
@@ -2,7 +2,7 @@
"formatVersion": 1, "formatVersion": 1,
"database": { "database": {
"version": 1, "version": 1,
"identityHash": "57c05ed7dd6615f7812c499f8486ae43", "identityHash": "b662918cb8aa7ed02d3d3050c1fba46e",
"entities": [ "entities": [
{ {
"tableName": "tunnel_config", "tableName": "tunnel_config",
@@ -106,7 +106,7 @@
}, },
{ {
"tableName": "general_settings", "tableName": "general_settings",
"createSql": "CREATE TABLE IF NOT EXISTS `${TABLE_NAME}` (`id` INTEGER PRIMARY KEY AUTOINCREMENT NOT NULL, `theme` TEXT NOT NULL DEFAULT 'DARK', `locale` TEXT, `already_donated` INTEGER NOT NULL DEFAULT 0)", "createSql": "CREATE TABLE IF NOT EXISTS `${TABLE_NAME}` (`id` INTEGER PRIMARY KEY AUTOINCREMENT NOT NULL, `theme` TEXT NOT NULL DEFAULT 'DARK', `locale` TEXT, `already_donated` INTEGER NOT NULL DEFAULT 0, `restore_tunnel_on_boot` INTEGER NOT NULL DEFAULT 0)",
"fields": [ "fields": [
{ {
"fieldPath": "id", "fieldPath": "id",
@@ -132,6 +132,13 @@
"affinity": "INTEGER", "affinity": "INTEGER",
"notNull": true, "notNull": true,
"defaultValue": "0" "defaultValue": "0"
},
{
"fieldPath": "restoreTunnelOnBoot",
"columnName": "restore_tunnel_on_boot",
"affinity": "INTEGER",
"notNull": true,
"defaultValue": "0"
} }
], ],
"primaryKey": { "primaryKey": {
@@ -144,7 +151,7 @@
], ],
"setupQueries": [ "setupQueries": [
"CREATE TABLE IF NOT EXISTS room_master_table (id INTEGER PRIMARY KEY,identity_hash TEXT)", "CREATE TABLE IF NOT EXISTS room_master_table (id INTEGER PRIMARY KEY,identity_hash TEXT)",
"INSERT OR REPLACE INTO room_master_table (id,identity_hash) VALUES(42, '57c05ed7dd6615f7812c499f8486ae43')" "INSERT OR REPLACE INTO room_master_table (id,identity_hash) VALUES(42, 'b662918cb8aa7ed02d3d3050c1fba46e')"
] ]
} }
} }
@@ -20,6 +20,9 @@ interface GeneralSettingsDao {
@Query("UPDATE general_settings SET locale = :locale WHERE id = 1") @Query("UPDATE general_settings SET locale = :locale WHERE id = 1")
suspend fun updateLocale(locale: String) suspend fun updateLocale(locale: String)
@Query("UPDATE general_settings SET restore_tunnel_on_boot = :enabled WHERE id = 1")
suspend fun updateRestoreTunnelOnBoot(enabled: Boolean)
@Query("UPDATE general_settings SET already_donated = :donated WHERE id = 1") @Query("UPDATE general_settings SET already_donated = :donated WHERE id = 1")
suspend fun updateAlreadyDonated(donated: Boolean) suspend fun updateAlreadyDonated(donated: Boolean)
} }
@@ -10,4 +10,6 @@ data class GeneralSettings(
@ColumnInfo(name = "theme", defaultValue = "DARK") val theme: String = "DARK", @ColumnInfo(name = "theme", defaultValue = "DARK") val theme: String = "DARK",
@ColumnInfo(name = "locale") val locale: String? = null, @ColumnInfo(name = "locale") val locale: String? = null,
@ColumnInfo(name = "already_donated", defaultValue = "0") val alreadyDonated: Boolean = false, @ColumnInfo(name = "already_donated", defaultValue = "0") val alreadyDonated: Boolean = false,
@ColumnInfo(name = "restore_tunnel_on_boot", defaultValue = "0")
val restoreTunnelOnBoot: Boolean = false,
) )
@@ -10,7 +10,14 @@ fun Entity.toDomain(): Domain =
theme = Theme.valueOf(theme.uppercase()), theme = Theme.valueOf(theme.uppercase()),
locale = locale, locale = locale,
alreadyDonated = alreadyDonated, alreadyDonated = alreadyDonated,
restoreTunnelOnBoot = restoreTunnelOnBoot,
) )
fun Domain.toEntity(): Entity = fun Domain.toEntity(): Entity =
Entity(id = id, theme = theme.name, locale = locale, alreadyDonated = alreadyDonated) Entity(
id = id,
theme = theme.name,
locale = locale,
alreadyDonated = alreadyDonated,
restoreTunnelOnBoot = restoreTunnelOnBoot,
)
@@ -33,4 +33,8 @@ class RoomSettingsRepository(private val settingsDao: GeneralSettingsDao) :
override suspend fun updateAlreadyDonated(donated: Boolean) { override suspend fun updateAlreadyDonated(donated: Boolean) {
settingsDao.updateAlreadyDonated(donated) settingsDao.updateAlreadyDonated(donated)
} }
override suspend fun updateRestoreTunnelOnBoot(enabled: Boolean) {
settingsDao.updateRestoreTunnelOnBoot(enabled)
}
} }
@@ -1,14 +1,14 @@
package com.zaneschepke.wireguardautotunnel.client.data.service package com.zaneschepke.wireguardautotunnel.client.data.service
import co.touchlab.kermit.Logger import co.touchlab.kermit.Logger
import com.zaneschepke.wireguardautotunnel.client.data.service.UdsDaemonHealthService.Companion.DAEMON_WS_RECONNECT_DELAY_MILLIS import com.zaneschepke.wireguardautotunnel.client.data.service.UdsDaemonService.Companion.DAEMON_WS_RECONNECT_DELAY_MILLIS
import com.zaneschepke.wireguardautotunnel.client.domain.repository.LockdownSettingsRepository import com.zaneschepke.wireguardautotunnel.client.domain.repository.LockdownSettingsRepository
import com.zaneschepke.wireguardautotunnel.client.service.BackendCommandService import com.zaneschepke.wireguardautotunnel.client.service.BackendService
import com.zaneschepke.wireguardautotunnel.core.ipc.Routes import com.zaneschepke.wireguardautotunnel.core.ipc.Routes
import com.zaneschepke.wireguardautotunnel.core.ipc.dto.BackendMode import com.zaneschepke.wireguardautotunnel.core.ipc.dto.BackendMode
import com.zaneschepke.wireguardautotunnel.core.ipc.dto.BackendStatus import com.zaneschepke.wireguardautotunnel.core.ipc.dto.BackendStatus
import com.zaneschepke.wireguardautotunnel.core.ipc.dto.TunnelState import com.zaneschepke.wireguardautotunnel.core.ipc.dto.TunnelState
import com.zaneschepke.wireguardautotunnel.core.ipc.dto.request.KillSwitchRequest import com.zaneschepke.wireguardautotunnel.core.ipc.dto.request.FlagRequest
import com.zaneschepke.wireguardautotunnel.parser.ActiveConfig import com.zaneschepke.wireguardautotunnel.parser.ActiveConfig
import io.ktor.client.* import io.ktor.client.*
import io.ktor.client.call.* import io.ktor.client.call.*
@@ -16,40 +16,45 @@ import io.ktor.client.plugins.websocket.*
import io.ktor.client.request.* import io.ktor.client.request.*
import io.ktor.http.* import io.ktor.http.*
import io.ktor.utils.io.* import io.ktor.utils.io.*
import io.ktor.websocket.Frame import io.ktor.websocket.*
import io.ktor.websocket.readText
import kotlinx.coroutines.Dispatchers import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.ExperimentalCoroutinesApi import kotlinx.coroutines.ExperimentalCoroutinesApi
import kotlinx.coroutines.channels.awaitClose import kotlinx.coroutines.channels.awaitClose
import kotlinx.coroutines.delay import kotlinx.coroutines.delay
import kotlinx.coroutines.flow.Flow import kotlinx.coroutines.flow.*
import kotlinx.coroutines.flow.callbackFlow
import kotlinx.coroutines.flow.flatMapLatest
import kotlinx.coroutines.flow.flow
import kotlinx.coroutines.flow.flowOn
import kotlinx.coroutines.isActive import kotlinx.coroutines.isActive
import kotlinx.serialization.json.Json import kotlinx.serialization.json.Json
class UdsBackendCommandService( class UdsBackendService(
private val client: HttpClient, private val client: HttpClient,
private val json: Json, private val json: Json,
private val lockdownSettingsRepository: LockdownSettingsRepository, private val lockdownSettingsRepository: LockdownSettingsRepository,
) : BackendCommandService { ) : BackendService {
override suspend fun setMode(mode: BackendMode): Result<Unit> = safeDaemonCall { override suspend fun setMode(mode: BackendMode): Result<Unit> = safeDaemonCall {
client.post(Routes.BACKEND_MODE) { setBody(mode) } client.put(Routes.BACKEND_MODE) { setBody(mode) }
} }
override suspend fun setKillSwitch(enabled: Boolean): Result<Unit> { override suspend fun setKillSwitch(enabled: Boolean): Result<Unit> {
lockdownSettingsRepository.updateEnabled(enabled) lockdownSettingsRepository.updateEnabled(enabled)
return safeDaemonCall { return safeDaemonCall {
val request = KillSwitchRequest(enabled) val request = FlagRequest(enabled)
client.post(Routes.BACKEND_KILL_SWITCH) { setBody(request) } client.put(Routes.BACKEND_KILL_SWITCH) { setBody(request) }
Unit Unit
} }
.onFailure { lockdownSettingsRepository.updateEnabled(!enabled) } .onFailure { lockdownSettingsRepository.updateEnabled(!enabled) }
} }
override suspend fun setKillSwitchLanBypass(enabled: Boolean): Result<Unit> {
lockdownSettingsRepository.updateBypassLan(enabled)
return safeDaemonCall {
val request = FlagRequest(enabled)
client.put(Routes.BACKEND_KILL_SWITCH_BYPASS) { setBody(request) }
Unit
}
.onFailure { lockdownSettingsRepository.updateBypassLan(!enabled) }
}
override suspend fun getStatus(): Result<BackendStatus> = runCatching { override suspend fun getStatus(): Result<BackendStatus> = runCatching {
val response = client.get(Routes.BACKEND_STATUS) val response = client.get(Routes.BACKEND_STATUS)
response.body<BackendStatus>() response.body<BackendStatus>()
@@ -1,8 +1,11 @@
package com.zaneschepke.wireguardautotunnel.client.data.service package com.zaneschepke.wireguardautotunnel.client.data.service
import co.touchlab.kermit.Logger import co.touchlab.kermit.Logger
import com.zaneschepke.wireguardautotunnel.client.service.DaemonHealthService import com.zaneschepke.wireguardautotunnel.client.domain.repository.GeneralSettingRepository
import com.zaneschepke.wireguardautotunnel.client.domain.repository.LockdownSettingsRepository
import com.zaneschepke.wireguardautotunnel.client.service.DaemonService
import com.zaneschepke.wireguardautotunnel.core.ipc.Routes import com.zaneschepke.wireguardautotunnel.core.ipc.Routes
import com.zaneschepke.wireguardautotunnel.core.ipc.dto.request.FlagRequest
import io.ktor.client.* import io.ktor.client.*
import io.ktor.client.plugins.websocket.* import io.ktor.client.plugins.websocket.*
import io.ktor.client.request.* import io.ktor.client.request.*
@@ -16,7 +19,11 @@ import kotlinx.coroutines.flow.distinctUntilChanged
import kotlinx.coroutines.flow.flowOn import kotlinx.coroutines.flow.flowOn
import kotlinx.coroutines.isActive import kotlinx.coroutines.isActive
class UdsDaemonHealthService(private val client: HttpClient) : DaemonHealthService { class UdsDaemonService(
private val client: HttpClient,
private val lockdownSettingsRepository: LockdownSettingsRepository,
private val generalSettingsRepository: GeneralSettingRepository,
) : DaemonService {
override suspend fun alive(): Boolean { override suspend fun alive(): Boolean {
return try { return try {
@@ -27,6 +34,26 @@ class UdsDaemonHealthService(private val client: HttpClient) : DaemonHealthServi
} }
} }
override suspend fun setRestoreKillSwitch(enabled: Boolean): Result<Unit> {
lockdownSettingsRepository.updateRestoreOnBoot(enabled)
return safeDaemonCall {
val request = FlagRequest(enabled)
client.put(Routes.DAEMON_RESTORE_KILL_SWITCH) { setBody(request) }
Unit
}
.onFailure { lockdownSettingsRepository.updateRestoreOnBoot(!enabled) }
}
override suspend fun setRestoreTunnel(enabled: Boolean): Result<Unit> {
generalSettingsRepository.updateRestoreTunnelOnBoot(enabled)
return safeDaemonCall {
val request = FlagRequest(enabled)
client.put(Routes.DAEMON_RESTORE_TUNNEL) { setBody(request) }
Unit
}
.onFailure { generalSettingsRepository.updateRestoreTunnelOnBoot(!enabled) }
}
override val alive: Flow<Boolean> = override val alive: Flow<Boolean> =
callbackFlow { callbackFlow {
while (isActive) { while (isActive) {
@@ -2,7 +2,7 @@ package com.zaneschepke.wireguardautotunnel.client.data.service
import com.zaneschepke.wireguardautotunnel.client.domain.error.ClientException import com.zaneschepke.wireguardautotunnel.client.domain.error.ClientException
import com.zaneschepke.wireguardautotunnel.client.domain.repository.TunnelRepository import com.zaneschepke.wireguardautotunnel.client.domain.repository.TunnelRepository
import com.zaneschepke.wireguardautotunnel.client.service.TunnelCommandService import com.zaneschepke.wireguardautotunnel.client.service.TunnelService
import com.zaneschepke.wireguardautotunnel.core.ipc.Routes import com.zaneschepke.wireguardautotunnel.core.ipc.Routes
import com.zaneschepke.wireguardautotunnel.core.ipc.dto.request.StartTunnelRequest import com.zaneschepke.wireguardautotunnel.core.ipc.dto.request.StartTunnelRequest
import io.ktor.client.* import io.ktor.client.*
@@ -11,10 +11,10 @@ import io.ktor.http.*
import kotlinx.coroutines.sync.Mutex import kotlinx.coroutines.sync.Mutex
import kotlinx.coroutines.sync.withLock import kotlinx.coroutines.sync.withLock
class UdsTunnelCommandService( class UdsTunnelService(
private val client: HttpClient, private val client: HttpClient,
private val tunnelRepository: TunnelRepository, private val tunnelRepository: TunnelRepository,
) : TunnelCommandService { ) : TunnelService {
val mutex = Mutex() val mutex = Mutex()
@@ -2,14 +2,14 @@ package com.zaneschepke.wireguardautotunnel.client.di
import co.touchlab.kermit.Logger import co.touchlab.kermit.Logger
import com.zaneschepke.wireguardautotunnel.client.data.service.DefaultTunnelImportService import com.zaneschepke.wireguardautotunnel.client.data.service.DefaultTunnelImportService
import com.zaneschepke.wireguardautotunnel.client.data.service.UdsBackendCommandService import com.zaneschepke.wireguardautotunnel.client.data.service.UdsBackendService
import com.zaneschepke.wireguardautotunnel.client.data.service.UdsDaemonHealthService import com.zaneschepke.wireguardautotunnel.client.data.service.UdsDaemonService
import com.zaneschepke.wireguardautotunnel.client.data.service.UdsTunnelCommandService import com.zaneschepke.wireguardautotunnel.client.data.service.UdsTunnelService
import com.zaneschepke.wireguardautotunnel.client.domain.error.ClientException import com.zaneschepke.wireguardautotunnel.client.domain.error.ClientException
import com.zaneschepke.wireguardautotunnel.client.service.BackendCommandService import com.zaneschepke.wireguardautotunnel.client.service.BackendService
import com.zaneschepke.wireguardautotunnel.client.service.DaemonHealthService import com.zaneschepke.wireguardautotunnel.client.service.DaemonService
import com.zaneschepke.wireguardautotunnel.client.service.TunnelCommandService
import com.zaneschepke.wireguardautotunnel.client.service.TunnelImportService import com.zaneschepke.wireguardautotunnel.client.service.TunnelImportService
import com.zaneschepke.wireguardautotunnel.client.service.TunnelService
import com.zaneschepke.wireguardautotunnel.core.crypto.HmacProtector import com.zaneschepke.wireguardautotunnel.core.crypto.HmacProtector
import com.zaneschepke.wireguardautotunnel.core.ipc.Headers import com.zaneschepke.wireguardautotunnel.core.ipc.Headers
import com.zaneschepke.wireguardautotunnel.core.ipc.IPC import com.zaneschepke.wireguardautotunnel.core.ipc.IPC
@@ -68,6 +68,8 @@ val serviceModule = module {
HttpStatusCode.InternalServerError -> HttpStatusCode.InternalServerError ->
ClientException.BadRequestException(bodyText) ClientException.BadRequestException(bodyText)
HttpStatusCode.Conflict -> ClientException.ConflictException(bodyText) HttpStatusCode.Conflict -> ClientException.ConflictException(bodyText)
HttpStatusCode.Unauthorized ->
ClientException.UnauthorizedException(bodyText)
else -> ClientException.UnknownError(bodyText) else -> ClientException.UnknownError(bodyText)
} }
} }
@@ -81,7 +83,7 @@ val serviceModule = module {
install("HmacSigner") { install("HmacSigner") {
requestPipeline.intercept(HttpRequestPipeline.Render) { payload -> requestPipeline.intercept(HttpRequestPipeline.Render) { payload ->
val path = context.url.encodedPath val path = context.url.encodedPath
if (path.startsWith(Routes.DAEMON_BASE)) return@intercept if (path == Routes.DAEMON_BASE) return@intercept
val secret = IPC.getIPCSecret() val secret = IPC.getIPCSecret()
val user = System.getProperty("user.name") val user = System.getProperty("user.name")
@@ -106,9 +108,9 @@ val serviceModule = module {
} }
} }
} }
single<DaemonHealthService> { UdsDaemonHealthService(get()) } single<DaemonService> { UdsDaemonService(get(), get(), get()) }
single<TunnelCommandService> { UdsTunnelCommandService(get(), tunnelRepository = get()) } single<TunnelService> { UdsTunnelService(get(), tunnelRepository = get()) }
single<BackendCommandService> { UdsBackendCommandService(get(), get(), get()) } single<BackendService> { UdsBackendService(get(), get(), get()) }
single<TunnelImportService> { DefaultTunnelImportService(get()) } single<TunnelImportService> { DefaultTunnelImportService(get()) }
} }
@@ -9,5 +9,7 @@ sealed class ClientException : Exception() {
class UnknownError(override val message: String) : ClientException() class UnknownError(override val message: String) : ClientException()
class UnauthorizedException(override val message: String) : ClientException()
class DaemonCommsException : ClientException() class DaemonCommsException : ClientException()
} }
@@ -9,4 +9,5 @@ data class GeneralSettings(
val theme: Theme = Theme.DARK, val theme: Theme = Theme.DARK,
val locale: String? = null, val locale: String? = null,
val alreadyDonated: Boolean = false, val alreadyDonated: Boolean = false,
val restoreTunnelOnBoot: Boolean = false,
) )
@@ -16,4 +16,6 @@ interface GeneralSettingRepository {
suspend fun updateLocale(locale: String) suspend fun updateLocale(locale: String)
suspend fun updateAlreadyDonated(donated: Boolean) suspend fun updateAlreadyDonated(donated: Boolean)
suspend fun updateRestoreTunnelOnBoot(enabled: Boolean)
} }
@@ -4,11 +4,13 @@ import com.zaneschepke.wireguardautotunnel.core.ipc.dto.BackendMode
import com.zaneschepke.wireguardautotunnel.core.ipc.dto.BackendStatus import com.zaneschepke.wireguardautotunnel.core.ipc.dto.BackendStatus
import kotlinx.coroutines.flow.Flow import kotlinx.coroutines.flow.Flow
interface BackendCommandService { interface BackendService {
suspend fun setMode(mode: BackendMode): Result<Unit> suspend fun setMode(mode: BackendMode): Result<Unit>
suspend fun setKillSwitch(enabled: Boolean): Result<Unit> suspend fun setKillSwitch(enabled: Boolean): Result<Unit>
suspend fun setKillSwitchLanBypass(enabled: Boolean): Result<Unit>
suspend fun getStatus(): Result<BackendStatus> suspend fun getStatus(): Result<BackendStatus>
fun statusFlow(): Flow<BackendStatus> fun statusFlow(): Flow<BackendStatus>
@@ -2,8 +2,12 @@ package com.zaneschepke.wireguardautotunnel.client.service
import kotlinx.coroutines.flow.Flow import kotlinx.coroutines.flow.Flow
interface DaemonHealthService { interface DaemonService {
suspend fun alive(): Boolean suspend fun alive(): Boolean
suspend fun setRestoreKillSwitch(enabled: Boolean): Result<Unit>
suspend fun setRestoreTunnel(enabled: Boolean): Result<Unit>
val alive: Flow<Boolean> val alive: Flow<Boolean>
} }
@@ -1,6 +1,6 @@
package com.zaneschepke.wireguardautotunnel.client.service package com.zaneschepke.wireguardautotunnel.client.service
interface TunnelCommandService { interface TunnelService {
suspend fun startTunnel(id: Long): Result<Unit> suspend fun startTunnel(id: Long): Result<Unit>
suspend fun stopTunnel(id: Long): Result<Unit> suspend fun stopTunnel(id: Long): Result<Unit>
@@ -891,6 +891,32 @@
"Apache-2.0" "Apache-2.0"
] ]
}, },
{
"uniqueId": "com.ibm.icu:icu4j",
"funding": [
],
"developers": [
{
"name": "Markus Scherer"
},
{
"name": "Richard Gillam"
}
],
"artifactVersion": "77.1",
"description": "International Components for Unicode for Java (ICU4J) is a mature, widely used Java library\n providing Unicode and Globalization support",
"scm": {
"connection": "scm:git:git://github.com/unicode-org/icu.git",
"url": "https://github.com/unicode-org/icu",
"developerConnection": "scm:git:git@github.com:unicode-org/icu.git"
},
"name": "ICU4J",
"website": "https://icu.unicode.org/",
"licenses": [
"a953261c1ab0a39c092979a5ef28d565"
]
},
{ {
"uniqueId": "com.materialkolor:material-kolor", "uniqueId": "com.materialkolor:material-kolor",
"funding": [ "funding": [
@@ -1296,6 +1322,29 @@
"MIT" "MIT"
] ]
}, },
{
"uniqueId": "io.github.skeptick.libres:libres",
"funding": [
],
"developers": [
{
"name": "Danil Yudov"
}
],
"artifactVersion": "1.2.4",
"description": "Resources generation in Kotlin Multiplatform.",
"scm": {
"connection": "scm:git:ssh://git@github.com/skeptick/libres.git",
"url": "https://github.com/skeptick/libres",
"developerConnection": "scm:git:ssh://git@github.com/skeptick/libres.git"
},
"name": "Libres Core",
"website": "https://github.com/skeptick/libres",
"licenses": [
"Apache-2.0"
]
},
{ {
"uniqueId": "io.github.skolson:kmp-io", "uniqueId": "io.github.skolson:kmp-io",
"funding": [ "funding": [
@@ -1915,6 +1964,27 @@
"Apache-2.0" "Apache-2.0"
] ]
}, },
{
"uniqueId": "nl.jacobras:Human-Readable",
"funding": [
],
"developers": [
{
"name": "Jacob Ras"
}
],
"artifactVersion": "1.12.3",
"description": "A small set of data formatting utilities for Kotlin Multiplatform (KMP)",
"scm": {
"url": "https://github.com/jacobras/human-readable"
},
"name": "Human Readable",
"website": "https://github.com/jacobras/human-readable",
"licenses": [
"MIT"
]
},
{ {
"uniqueId": "org.apache.commons:commons-lang3", "uniqueId": "org.apache.commons:commons-lang3",
"funding": [ "funding": [
@@ -4602,6 +4672,11 @@
"spdxId": "MIT", "spdxId": "MIT",
"name": "MIT License" "name": "MIT License"
}, },
"a953261c1ab0a39c092979a5ef28d565": {
"hash": "a953261c1ab0a39c092979a5ef28d565",
"url": "https://raw.githubusercontent.com/unicode-org/icu/maint/maint-77/LICENSE",
"name": "Unicode-3.0"
},
"dd74d358d0b6f8b5099019a55929b63f": { "dd74d358d0b6f8b5099019a55929b63f": {
"hash": "dd74d358d0b6f8b5099019a55929b63f", "hash": "dd74d358d0b6f8b5099019a55929b63f",
"url": "https://www.gnu.org/licenses/lgpl-2.1.html", "url": "https://www.gnu.org/licenses/lgpl-2.1.html",
@@ -27,7 +27,7 @@ fun SwitchWithDivider(
color = MaterialTheme.colorScheme.outline, color = MaterialTheme.colorScheme.outline,
) )
Box(modifier = Modifier.pointerInput(Unit) { detectTapGestures {} }) { Box(modifier = Modifier.pointerInput(Unit) { detectTapGestures {} }) {
com.zaneschepke.wireguardautotunnel.ui.common.button.ThemedSwitch( ThemedSwitch(
checked = checked, checked = checked,
onClick = onClick, onClick = onClick,
enabled = enabled, enabled = enabled,
@@ -17,6 +17,7 @@ import androidx.compose.runtime.getValue
import androidx.compose.ui.Alignment import androidx.compose.ui.Alignment
import androidx.compose.ui.Modifier import androidx.compose.ui.Modifier
import androidx.compose.ui.unit.dp import androidx.compose.ui.unit.dp
import com.dokar.sonner.Toast
import com.zaneschepke.wireguardautotunnel.composeapp.generated.resources.Res import com.zaneschepke.wireguardautotunnel.composeapp.generated.resources.Res
import com.zaneschepke.wireguardautotunnel.composeapp.generated.resources.appearance import com.zaneschepke.wireguardautotunnel.composeapp.generated.resources.appearance
import com.zaneschepke.wireguardautotunnel.composeapp.generated.resources.general import com.zaneschepke.wireguardautotunnel.composeapp.generated.resources.general
@@ -24,23 +25,35 @@ import com.zaneschepke.wireguardautotunnel.composeapp.generated.resources.lockdo
import com.zaneschepke.wireguardautotunnel.composeapp.generated.resources.settings import com.zaneschepke.wireguardautotunnel.composeapp.generated.resources.settings
import com.zaneschepke.wireguardautotunnel.composeapp.generated.resources.tunnel import com.zaneschepke.wireguardautotunnel.composeapp.generated.resources.tunnel
import com.zaneschepke.wireguardautotunnel.desktop.ui.common.LocalNavController import com.zaneschepke.wireguardautotunnel.desktop.ui.common.LocalNavController
import com.zaneschepke.wireguardautotunnel.desktop.ui.common.LocalToaster
import com.zaneschepke.wireguardautotunnel.desktop.ui.common.button.SurfaceRow import com.zaneschepke.wireguardautotunnel.desktop.ui.common.button.SurfaceRow
import com.zaneschepke.wireguardautotunnel.desktop.ui.common.button.ThemedSwitch
import com.zaneschepke.wireguardautotunnel.desktop.ui.common.label.GroupLabel
import com.zaneschepke.wireguardautotunnel.desktop.ui.navigation.Route import com.zaneschepke.wireguardautotunnel.desktop.ui.navigation.Route
import com.zaneschepke.wireguardautotunnel.desktop.ui.screens.settings.appearance.LockdownIntent import com.zaneschepke.wireguardautotunnel.desktop.ui.screens.settings.appearance.LockdownIntent
import com.zaneschepke.wireguardautotunnel.desktop.ui.sideeffects.AppSideEffect
import com.zaneschepke.wireguardautotunnel.desktop.viewmodel.SettingsViewModel import com.zaneschepke.wireguardautotunnel.desktop.viewmodel.SettingsViewModel
import com.zaneschepke.wireguardautotunnel.ui.common.button.ThemedSwitch
import com.zaneschepke.wireguardautotunnel.ui.common.label.GroupLabel
import org.jetbrains.compose.resources.stringResource import org.jetbrains.compose.resources.stringResource
import org.koin.compose.viewmodel.koinViewModel import org.koin.compose.viewmodel.koinViewModel
import org.orbitmvi.orbit.compose.collectAsState import org.orbitmvi.orbit.compose.collectAsState
import org.orbitmvi.orbit.compose.collectSideEffect
@OptIn(ExperimentalMaterial3Api::class) @OptIn(ExperimentalMaterial3Api::class)
@Composable @Composable
fun SettingsScreen(viewModel: SettingsViewModel = koinViewModel()) { fun SettingsScreen(viewModel: SettingsViewModel = koinViewModel()) {
val navController = LocalNavController.current val navController = LocalNavController.current
val toaster = LocalToaster.current
val uiState by viewModel.collectAsState() val uiState by viewModel.collectAsState()
viewModel.collectSideEffect { sideEffect ->
when (sideEffect) {
is AppSideEffect.Toast -> {
toaster.show(Toast(sideEffect.message, sideEffect.type))
}
}
}
if (!uiState.isLoaded) return if (!uiState.isLoaded) return
Scaffold(topBar = { TopAppBar(title = { Text(stringResource(Res.string.settings)) }) }) { Scaffold(topBar = { TopAppBar(title = { Text(stringResource(Res.string.settings)) }) }) {
@@ -94,14 +107,14 @@ fun SettingsScreen(viewModel: SettingsViewModel = koinViewModel()) {
title = "Allow local network access", title = "Allow local network access",
onClick = { onClick = {
viewModel.onLockdownAction( viewModel.onLockdownAction(
LockdownIntent.TogglePersist(!uiState.lockdownRestoreOnBootEnabled) LockdownIntent.ToggleBypassLan(!uiState.lockdownRestoreOnBootEnabled)
) )
}, },
trailing = { trailing = {
ThemedSwitch( ThemedSwitch(
checked = uiState.lockdownBypassEnabled, checked = uiState.lockdownBypassEnabled,
onClick = { onClick = {
viewModel.onLockdownAction(LockdownIntent.TogglePersist(it)) viewModel.onLockdownAction(LockdownIntent.ToggleBypassLan(it))
}, },
) )
}, },
@@ -116,16 +129,12 @@ fun SettingsScreen(viewModel: SettingsViewModel = koinViewModel()) {
leading = { Icon(Icons.Default.RestartAlt, contentDescription = null) }, leading = { Icon(Icons.Default.RestartAlt, contentDescription = null) },
title = "Restore tunnel on system startup", title = "Restore tunnel on system startup",
onClick = { onClick = {
viewModel.onLockdownAction( viewModel.onRestoreTunnelOnBoot(!uiState.tunnelRestoreOnBootEnabled)
LockdownIntent.TogglePersist(!uiState.lockdownRestoreOnBootEnabled)
)
}, },
trailing = { trailing = {
ThemedSwitch( ThemedSwitch(
checked = uiState.lockdownRestoreOnBootEnabled, checked = uiState.tunnelRestoreOnBootEnabled,
onClick = { onClick = { viewModel.onRestoreTunnelOnBoot(it) },
viewModel.onLockdownAction(LockdownIntent.TogglePersist(it))
},
) )
}, },
) )
@@ -63,11 +63,11 @@ import com.zaneschepke.wireguardautotunnel.composeapp.generated.resources.websit
import com.zaneschepke.wireguardautotunnel.desktop.ui.common.LocalNavController import com.zaneschepke.wireguardautotunnel.desktop.ui.common.LocalNavController
import com.zaneschepke.wireguardautotunnel.desktop.ui.common.LocalToaster import com.zaneschepke.wireguardautotunnel.desktop.ui.common.LocalToaster
import com.zaneschepke.wireguardautotunnel.desktop.ui.common.button.SurfaceRow import com.zaneschepke.wireguardautotunnel.desktop.ui.common.button.SurfaceRow
import com.zaneschepke.wireguardautotunnel.desktop.ui.common.label.GroupLabel
import com.zaneschepke.wireguardautotunnel.desktop.ui.common.text.DescriptionText
import com.zaneschepke.wireguardautotunnel.desktop.ui.navigation.Route import com.zaneschepke.wireguardautotunnel.desktop.ui.navigation.Route
import com.zaneschepke.wireguardautotunnel.desktop.util.DesktopUtils import com.zaneschepke.wireguardautotunnel.desktop.util.DesktopUtils
import com.zaneschepke.wireguardautotunnel.desktop.util.toClipEntry import com.zaneschepke.wireguardautotunnel.desktop.util.toClipEntry
import com.zaneschepke.wireguardautotunnel.ui.common.label.GroupLabel
import com.zaneschepke.wireguardautotunnel.ui.common.text.DescriptionText
import kotlinx.coroutines.launch import kotlinx.coroutines.launch
import org.jetbrains.compose.resources.stringResource import org.jetbrains.compose.resources.stringResource
import org.jetbrains.compose.resources.vectorResource import org.jetbrains.compose.resources.vectorResource
@@ -40,12 +40,12 @@ import com.zaneschepke.wireguardautotunnel.composeapp.generated.resources.libera
import com.zaneschepke.wireguardautotunnel.composeapp.generated.resources.options import com.zaneschepke.wireguardautotunnel.composeapp.generated.resources.options
import com.zaneschepke.wireguardautotunnel.desktop.ui.common.LocalNavController import com.zaneschepke.wireguardautotunnel.desktop.ui.common.LocalNavController
import com.zaneschepke.wireguardautotunnel.desktop.ui.common.button.SurfaceRow import com.zaneschepke.wireguardautotunnel.desktop.ui.common.button.SurfaceRow
import com.zaneschepke.wireguardautotunnel.desktop.ui.common.button.ThemedSwitch
import com.zaneschepke.wireguardautotunnel.desktop.ui.common.label.GroupLabel
import com.zaneschepke.wireguardautotunnel.desktop.ui.common.text.DescriptionText
import com.zaneschepke.wireguardautotunnel.desktop.ui.navigation.Route import com.zaneschepke.wireguardautotunnel.desktop.ui.navigation.Route
import com.zaneschepke.wireguardautotunnel.desktop.ui.screens.support.donate.components.DonationHeroSection
import com.zaneschepke.wireguardautotunnel.desktop.viewmodel.AppViewModel import com.zaneschepke.wireguardautotunnel.desktop.viewmodel.AppViewModel
import com.zaneschepke.wireguardautotunnel.ui.common.button.ThemedSwitch
import com.zaneschepke.wireguardautotunnel.ui.common.label.GroupLabel
import com.zaneschepke.wireguardautotunnel.ui.common.text.DescriptionText
import com.zaneschepke.wireguardautotunnel.ui.screens.support.donate.components.DonationHeroSection
import org.jetbrains.compose.resources.stringResource import org.jetbrains.compose.resources.stringResource
import org.jetbrains.compose.resources.vectorResource import org.jetbrains.compose.resources.vectorResource
import org.orbitmvi.orbit.compose.collectAsState import org.orbitmvi.orbit.compose.collectAsState
@@ -1,4 +1,4 @@
package com.zaneschepke.wireguardautotunnel.ui.screens.support.donate.components package com.zaneschepke.wireguardautotunnel.desktop.ui.screens.support.donate.components
import androidx.compose.foundation.layout.Arrangement import androidx.compose.foundation.layout.Arrangement
import androidx.compose.foundation.layout.Column import androidx.compose.foundation.layout.Column
@@ -168,18 +168,19 @@ fun TunnelsScreen(viewModel: TunnelsViewModel = koinViewModel()) {
detectTapGestures(onPress = { viewModel.onClearSelectionMode() }) detectTapGestures(onPress = { viewModel.onClearSelectionMode() })
} }
} }
) ) {
TunnelList( TunnelList(
uiState = uiState, uiState = uiState,
startTunnel = viewModel::onStartTunnel, startTunnel = viewModel::onStartTunnel,
stopTunnel = viewModel::onStopTunnel, stopTunnel = viewModel::onStopTunnel,
viewModel::onItemsReordered, viewModel::onItemsReordered,
viewModel::onPersistReorder, viewModel::onPersistReorder,
viewModel::onSelectTunnel, viewModel::onSelectTunnel,
viewModel::onDeselectTunnel, viewModel::onDeselectTunnel,
viewModel::onClearSelectionMode, viewModel::onClearSelectionMode,
{ intent -> pendingDeleteIntent = intent }, { intent -> pendingDeleteIntent = intent },
viewModel::onExportIntent, viewModel::onExportIntent,
) )
}
} }
} }
@@ -16,3 +16,14 @@ fun TunnelState.asColor(): Color {
TunnelState.RESOLVING_DNS -> Straw TunnelState.RESOLVING_DNS -> Straw
} }
} }
fun TunnelState.asTooltipMessage(): String? {
return when (this) {
TunnelState.DOWN,
TunnelState.STARTING,
TunnelState.UNKNOWN -> null
TunnelState.HEALTHY -> "Healthy"
TunnelState.HANDSHAKE_FAILURE -> "Handshake failure"
TunnelState.RESOLVING_DNS -> "Resolving DNS"
}
}
@@ -8,15 +8,21 @@ import androidx.compose.foundation.background
import androidx.compose.foundation.gestures.awaitEachGesture import androidx.compose.foundation.gestures.awaitEachGesture
import androidx.compose.foundation.gestures.awaitFirstDown import androidx.compose.foundation.gestures.awaitFirstDown
import androidx.compose.foundation.gestures.waitForUpOrCancellation import androidx.compose.foundation.gestures.waitForUpOrCancellation
import androidx.compose.foundation.layout.Box
import androidx.compose.foundation.layout.fillMaxSize
import androidx.compose.foundation.layout.padding
import androidx.compose.foundation.layout.size import androidx.compose.foundation.layout.size
import androidx.compose.foundation.lazy.LazyColumn import androidx.compose.foundation.lazy.LazyColumn
import androidx.compose.foundation.lazy.items import androidx.compose.foundation.lazy.items
import androidx.compose.foundation.lazy.rememberLazyListState import androidx.compose.foundation.lazy.rememberLazyListState
import androidx.compose.material.icons.Icons import androidx.compose.material.icons.Icons
import androidx.compose.material.icons.rounded.Circle import androidx.compose.material.icons.rounded.Circle
import androidx.compose.material3.ExperimentalMaterial3Api
import androidx.compose.material3.Icon import androidx.compose.material3.Icon
import androidx.compose.material3.MaterialTheme import androidx.compose.material3.MaterialTheme
import androidx.compose.material3.Text
import androidx.compose.runtime.* import androidx.compose.runtime.*
import androidx.compose.ui.Alignment
import androidx.compose.ui.ExperimentalComposeUiApi import androidx.compose.ui.ExperimentalComposeUiApi
import androidx.compose.ui.Modifier import androidx.compose.ui.Modifier
import androidx.compose.ui.draw.shadow import androidx.compose.ui.draw.shadow
@@ -30,15 +36,20 @@ import com.zaneschepke.wireguardautotunnel.client.domain.model.TunnelConfig
import com.zaneschepke.wireguardautotunnel.core.ipc.dto.TunnelState import com.zaneschepke.wireguardautotunnel.core.ipc.dto.TunnelState
import com.zaneschepke.wireguardautotunnel.desktop.ui.common.LocalNavController import com.zaneschepke.wireguardautotunnel.desktop.ui.common.LocalNavController
import com.zaneschepke.wireguardautotunnel.desktop.ui.common.button.SurfaceRow import com.zaneschepke.wireguardautotunnel.desktop.ui.common.button.SurfaceRow
import com.zaneschepke.wireguardautotunnel.desktop.ui.common.button.SwitchWithDivider
import com.zaneschepke.wireguardautotunnel.desktop.ui.common.tooltip.CustomTooltip
import com.zaneschepke.wireguardautotunnel.desktop.ui.navigation.Route import com.zaneschepke.wireguardautotunnel.desktop.ui.navigation.Route
import com.zaneschepke.wireguardautotunnel.desktop.ui.screens.tunnels.DeleteIntent import com.zaneschepke.wireguardautotunnel.desktop.ui.screens.tunnels.DeleteIntent
import com.zaneschepke.wireguardautotunnel.desktop.ui.screens.tunnels.ExportIntent import com.zaneschepke.wireguardautotunnel.desktop.ui.screens.tunnels.ExportIntent
import com.zaneschepke.wireguardautotunnel.desktop.ui.state.TunnelsUiState import com.zaneschepke.wireguardautotunnel.desktop.ui.state.TunnelsUiState
import com.zaneschepke.wireguardautotunnel.ui.common.button.SwitchWithDivider
import sh.calvin.reorderable.ReorderableItem import sh.calvin.reorderable.ReorderableItem
import sh.calvin.reorderable.rememberReorderableLazyListState import sh.calvin.reorderable.rememberReorderableLazyListState
@OptIn(ExperimentalFoundationApi::class, ExperimentalComposeUiApi::class) @OptIn(
ExperimentalFoundationApi::class,
ExperimentalComposeUiApi::class,
ExperimentalMaterial3Api::class,
)
@Composable @Composable
fun TunnelList( fun TunnelList(
uiState: TunnelsUiState, uiState: TunnelsUiState,
@@ -61,14 +72,14 @@ fun TunnelList(
onReorder(from.index, to.index) onReorder(from.index, to.index)
} }
val tunnelIndicatorColors by val tunnelIndicators by
remember(uiState.tunnelStates, uiState.tunnels) { remember(uiState.tunnelStates, uiState.tunnels) {
derivedStateOf { derivedStateOf {
uiState.tunnels.associate { tunnel -> uiState.tunnels.associate { tunnel ->
val state = val state =
uiState.tunnelStates.firstOrNull { it.id == tunnel.id }?.state uiState.tunnelStates.firstOrNull { it.id == tunnel.id }?.state
?: TunnelState.UNKNOWN ?: TunnelState.UNKNOWN
tunnel.id to state.asColor() tunnel.id to (state.asColor() to state.asTooltipMessage())
} }
} }
} }
@@ -82,17 +93,32 @@ fun TunnelList(
LazyColumn( LazyColumn(
state = lazyListState, state = lazyListState,
modifier = modifier =
modifier.background(MaterialTheme.colorScheme.background).onKeyEvent { modifier
if (it.key == Key.Escape && uiState.isSelectionMode) { .background(MaterialTheme.colorScheme.background)
onExitSelectionMode() .onKeyEvent {
true if (it.key == Key.Escape && uiState.isSelectionMode) {
} else false onExitSelectionMode()
}, true
} else false
}
.fillMaxSize(),
) { ) {
if (uiState.tunnels.isEmpty()) {
item {
Box(
modifier = Modifier.fillMaxSize().padding(top = 80.dp),
contentAlignment = Alignment.Center,
) {
Text("No tunnels added yet! Click the + symbol to add a tunnel.")
}
}
return@LazyColumn
}
items(uiState.tunnels, key = { it.id }) { tunnel -> items(uiState.tunnels, key = { it.id }) { tunnel ->
val isSelected = uiState.selectedTunnels.contains(tunnel)
ReorderableItem(reorderableState, key = tunnel.id) { isDragging -> ReorderableItem(reorderableState, key = tunnel.id) { isDragging ->
val elevation by animateDpAsState(if (isDragging) 8.dp else 0.dp) val elevation by animateDpAsState(if (isDragging) 8.dp else 0.dp)
val isSelected = uiState.selectedTunnels.contains(tunnel)
ContextMenuArea( ContextMenuArea(
items = { items = {
@@ -144,22 +170,25 @@ fun TunnelList(
.pointerInput(tunnel.id, uiState.isSelectionMode, isSelected) { .pointerInput(tunnel.id, uiState.isSelectionMode, isSelected) {
awaitEachGesture { awaitEachGesture {
val down = awaitFirstDown() val down = awaitFirstDown()
down.consume()
val up = waitForUpOrCancellation() val up = waitForUpOrCancellation()
if (up != null) {
up.consume()
if ((up.position - down.position).getDistance() < 5f) {
val modifiers = currentEvent.keyboardModifiers
val isMultiSelectModifier =
modifiers.isCtrlPressed ||
modifiers.isMetaPressed
if ( if (
up != null && isMultiSelectModifier || uiState.isSelectionMode
(up.position - down.position).getDistance() < 5f ) {
) { if (isSelected) onDeselected(tunnel)
else onSelected(tunnel)
val modifiers = currentEvent.keyboardModifiers } else {
val isMultiSelectModifier = navController.push(Route.Tunnel(tunnel.id))
modifiers.isCtrlPressed || modifiers.isMetaPressed }
if (isMultiSelectModifier || uiState.isSelectionMode) {
if (isSelected) onDeselected(tunnel)
else onSelected(tunnel)
} else {
navController.push(Route.Tunnel(tunnel.id))
} }
} }
} }
@@ -167,14 +196,22 @@ fun TunnelList(
.pointerHoverIcon(PointerIcon.Hand) .pointerHoverIcon(PointerIcon.Hand)
.then(if (isDragging) Modifier.zIndex(1f) else Modifier), .then(if (isDragging) Modifier.zIndex(1f) else Modifier),
leading = { leading = {
Icon( val tooltip = tunnelIndicators[tunnel.id]?.second
Icons.Rounded.Circle, val indicatorColor = tunnelIndicators[tunnel.id]?.first
contentDescription = null, @Composable
tint = fun icon() {
tunnelIndicatorColors[tunnel.id] Icon(
?: TunnelState.UNKNOWN.asColor(), Icons.Rounded.Circle,
modifier = Modifier.size(14.dp), contentDescription = null,
) tint = indicatorColor ?: TunnelState.UNKNOWN.asColor(),
modifier = Modifier.size(14.dp),
)
}
if (tooltip != null) {
CustomTooltip(text = tooltip) { icon() }
} else {
icon()
}
}, },
selected = isSelected, selected = isSelected,
trailing = { trailing = {
@@ -5,4 +5,5 @@ data class SettingsUiState(
val lockdownEnabled: Boolean = false, val lockdownEnabled: Boolean = false,
val lockdownRestoreOnBootEnabled: Boolean = false, val lockdownRestoreOnBootEnabled: Boolean = false,
val lockdownBypassEnabled: Boolean = false, val lockdownBypassEnabled: Boolean = false,
val tunnelRestoreOnBootEnabled: Boolean = false,
) )
@@ -18,6 +18,7 @@ fun ClientException?.asUserMessage(): String {
is ClientException.DaemonCommsException -> is ClientException.DaemonCommsException ->
"Daemon communication error, please check the daemon status." "Daemon communication error, please check the daemon status."
is ClientException.InternalServerError -> "An internal error occurred, please try again." is ClientException.InternalServerError -> "An internal error occurred, please try again."
is ClientException.UnauthorizedException -> "Unauthorized, please try again."
is ClientException.UnknownError, is ClientException.UnknownError,
null -> "An unknown error occurred, please try again." null -> "An unknown error occurred, please try again."
} }
@@ -3,8 +3,8 @@ package com.zaneschepke.wireguardautotunnel.desktop.viewmodel
import androidx.lifecycle.ViewModel import androidx.lifecycle.ViewModel
import com.zaneschepke.wireguardautotunnel.client.data.model.Theme import com.zaneschepke.wireguardautotunnel.client.data.model.Theme
import com.zaneschepke.wireguardautotunnel.client.domain.repository.GeneralSettingRepository import com.zaneschepke.wireguardautotunnel.client.domain.repository.GeneralSettingRepository
import com.zaneschepke.wireguardautotunnel.client.service.BackendCommandService import com.zaneschepke.wireguardautotunnel.client.service.BackendService
import com.zaneschepke.wireguardautotunnel.client.service.DaemonHealthService import com.zaneschepke.wireguardautotunnel.client.service.DaemonService
import com.zaneschepke.wireguardautotunnel.desktop.ui.sideeffects.AppSideEffect import com.zaneschepke.wireguardautotunnel.desktop.ui.sideeffects.AppSideEffect
import com.zaneschepke.wireguardautotunnel.desktop.ui.state.AppUiState import com.zaneschepke.wireguardautotunnel.desktop.ui.state.AppUiState
import io.github.sudarshanmhasrup.localina.api.LocaleUpdater import io.github.sudarshanmhasrup.localina.api.LocaleUpdater
@@ -13,8 +13,8 @@ import org.orbitmvi.orbit.viewmodel.container
class AppViewModel( class AppViewModel(
private val settingsRepository: GeneralSettingRepository, private val settingsRepository: GeneralSettingRepository,
private val daemonHealthService: DaemonHealthService, private val daemonService: DaemonService,
private val backendCommandService: BackendCommandService, private val backendService: BackendService,
) : ContainerHost<AppUiState, AppSideEffect>, ViewModel() { ) : ContainerHost<AppUiState, AppSideEffect>, ViewModel() {
override val container = override val container =
@@ -37,11 +37,9 @@ class AppViewModel(
} }
} }
} }
intent { daemonService.alive.collect { reduce { state.copy(daemonConnected = it) } } }
intent { intent {
daemonHealthService.alive.collect { reduce { state.copy(daemonConnected = it) } } backendService.statusFlow().collect {
}
intent {
backendCommandService.statusFlow().collect {
reduce { state.copy(lockdownActive = it.killSwitchEnabled) } reduce { state.copy(lockdownActive = it.killSwitchEnabled) }
} }
} }
@@ -1,12 +1,16 @@
package com.zaneschepke.wireguardautotunnel.desktop.viewmodel package com.zaneschepke.wireguardautotunnel.desktop.viewmodel
import androidx.lifecycle.ViewModel import androidx.lifecycle.ViewModel
import com.dokar.sonner.ToastType
import com.zaneschepke.wireguardautotunnel.client.domain.error.ClientException
import com.zaneschepke.wireguardautotunnel.client.domain.repository.GeneralSettingRepository import com.zaneschepke.wireguardautotunnel.client.domain.repository.GeneralSettingRepository
import com.zaneschepke.wireguardautotunnel.client.domain.repository.LockdownSettingsRepository import com.zaneschepke.wireguardautotunnel.client.domain.repository.LockdownSettingsRepository
import com.zaneschepke.wireguardautotunnel.client.service.BackendCommandService import com.zaneschepke.wireguardautotunnel.client.service.BackendService
import com.zaneschepke.wireguardautotunnel.client.service.DaemonService
import com.zaneschepke.wireguardautotunnel.desktop.ui.screens.settings.appearance.LockdownIntent import com.zaneschepke.wireguardautotunnel.desktop.ui.screens.settings.appearance.LockdownIntent
import com.zaneschepke.wireguardautotunnel.desktop.ui.sideeffects.AppSideEffect import com.zaneschepke.wireguardautotunnel.desktop.ui.sideeffects.AppSideEffect
import com.zaneschepke.wireguardautotunnel.desktop.ui.state.SettingsUiState import com.zaneschepke.wireguardautotunnel.desktop.ui.state.SettingsUiState
import com.zaneschepke.wireguardautotunnel.desktop.util.asUserMessage
import kotlinx.coroutines.flow.combine import kotlinx.coroutines.flow.combine
import org.orbitmvi.orbit.ContainerHost import org.orbitmvi.orbit.ContainerHost
import org.orbitmvi.orbit.viewmodel.container import org.orbitmvi.orbit.viewmodel.container
@@ -14,7 +18,8 @@ import org.orbitmvi.orbit.viewmodel.container
class SettingsViewModel( class SettingsViewModel(
private val settingsRepository: GeneralSettingRepository, private val settingsRepository: GeneralSettingRepository,
private val lockdownRepository: LockdownSettingsRepository, private val lockdownRepository: LockdownSettingsRepository,
private val backendCommandService: BackendCommandService, private val backendService: BackendService,
private val daemonService: DaemonService,
) : ContainerHost<SettingsUiState, AppSideEffect>, ViewModel() { ) : ContainerHost<SettingsUiState, AppSideEffect>, ViewModel() {
override val container = override val container =
@@ -33,19 +38,34 @@ class SettingsViewModel(
lockdownEnabled = lockdown.enabled, lockdownEnabled = lockdown.enabled,
lockdownRestoreOnBootEnabled = lockdown.restoreOnBoot, lockdownRestoreOnBootEnabled = lockdown.restoreOnBoot,
lockdownBypassEnabled = lockdown.bypassLan, lockdownBypassEnabled = lockdown.bypassLan,
tunnelRestoreOnBootEnabled = settings.restoreTunnelOnBoot,
) )
} }
} }
} }
} }
fun onRestoreTunnelOnBoot(enabled: Boolean) = intent {
daemonService.setRestoreTunnel(enabled).onFailure {
val message = (it as? ClientException).asUserMessage()
postSideEffect(AppSideEffect.Toast(message, ToastType.Error))
}
}
fun onLockdownAction(intent: LockdownIntent) = intent { fun onLockdownAction(intent: LockdownIntent) = intent {
when (intent) { when (intent) {
is LockdownIntent.ToggleBypassLan -> {} is LockdownIntent.ToggleBypassLan -> {
is LockdownIntent.ToggleMaster -> { backendService.setKillSwitchLanBypass(intent.enabled)
backendCommandService.setKillSwitch(intent.enabled)
} }
is LockdownIntent.TogglePersist -> {} is LockdownIntent.ToggleMaster -> {
backendService.setKillSwitch(intent.enabled)
}
is LockdownIntent.TogglePersist -> {
daemonService.setRestoreKillSwitch(intent.enabled)
}
}.onFailure {
val message = (it as? ClientException).asUserMessage()
postSideEffect(AppSideEffect.Toast(message, ToastType.Error))
} }
} }
} }
@@ -3,17 +3,16 @@ package com.zaneschepke.wireguardautotunnel.desktop.viewmodel
import androidx.lifecycle.ViewModel import androidx.lifecycle.ViewModel
import com.dokar.sonner.ToastType import com.dokar.sonner.ToastType
import com.zaneschepke.wireguardautotunnel.client.domain.repository.TunnelRepository import com.zaneschepke.wireguardautotunnel.client.domain.repository.TunnelRepository
import com.zaneschepke.wireguardautotunnel.client.service.BackendCommandService import com.zaneschepke.wireguardautotunnel.client.service.BackendService
import com.zaneschepke.wireguardautotunnel.desktop.ui.sideeffects.AppSideEffect import com.zaneschepke.wireguardautotunnel.desktop.ui.sideeffects.AppSideEffect
import com.zaneschepke.wireguardautotunnel.desktop.ui.state.TunnelUiState import com.zaneschepke.wireguardautotunnel.desktop.ui.state.TunnelUiState
import com.zaneschepke.wireguardautotunnel.parser.Config import com.zaneschepke.wireguardautotunnel.parser.Config
import kotlinx.coroutines.flow.combine
import kotlinx.coroutines.flow.map import kotlinx.coroutines.flow.map
import org.orbitmvi.orbit.ContainerHost import org.orbitmvi.orbit.ContainerHost
import org.orbitmvi.orbit.viewmodel.container import org.orbitmvi.orbit.viewmodel.container
class TunnelViewModel( class TunnelViewModel(
private val backendCommandService: BackendCommandService, private val backendService: BackendService,
private val tunnelRepository: TunnelRepository, private val tunnelRepository: TunnelRepository,
val tunnelId: Long, val tunnelId: Long,
) : ContainerHost<TunnelUiState, AppSideEffect>, ViewModel() { ) : ContainerHost<TunnelUiState, AppSideEffect>, ViewModel() {
@@ -23,44 +22,44 @@ class TunnelViewModel(
TunnelUiState(), TunnelUiState(),
buildSettings = { repeatOnSubscribedStopTimeout = 5000L }, buildSettings = { repeatOnSubscribedStopTimeout = 5000L },
) { ) {
combine( intent {
tunnelRepository.flow.map { it.firstOrNull { tun -> tun.id == tunnelId } }, tunnelRepository.flow
backendCommandService.statusFlow().map { status -> .map { it.firstOrNull { tun -> tun.id == tunnelId } }
status.activeTunnels.firstOrNull { tunnel -> tunnel.id == tunnelId } .collect { tunnel ->
}, reduce {
) { tunnel, activeTunnel -> state.copy(
tunnel to activeTunnel isLoaded = true,
} originalConfig = tunnel ?: state.originalConfig,
.collect { (tunnel, status) -> editedConfig = tunnel ?: state.editedConfig,
reduce { )
state.copy( }
isLoaded = true,
originalConfig = tunnel ?: state.originalConfig,
editedConfig = tunnel ?: state.editedConfig,
tunnelState = status?.state,
activeConfig = status?.activeConfig ?: state.activeConfig,
)
} }
} }
intent {
backendService
.statusFlow()
.map { status ->
status.activeTunnels.firstOrNull { tunnel -> tunnel.id == tunnelId }
}
.collect {
reduce {
state.copy(
tunnelState = it?.state ?: state.tunnelState,
activeConfig = it?.activeConfig ?: state.activeConfig,
)
}
}
}
} }
fun onConfigUpdate(newText: String) = intent { fun onConfigUpdate(newText: String) = intent {
reduce { val newEdited = state.editedConfig.copy(quickConfig = newText)
state.copy( reduce { state.copy(editedConfig = newEdited, isDirty = state.originalConfig != newEdited) }
editedConfig = state.editedConfig.copy(quickConfig = newText),
isDirty = state.originalConfig != state.editedConfig,
)
}
} }
fun onNameUpdated(name: String) = intent { fun onNameUpdated(name: String) = intent {
reduce { val newEdited = state.editedConfig.copy(name = name)
val updatedConfig = state.editedConfig.copy(name = name) reduce { state.copy(editedConfig = newEdited, isDirty = state.originalConfig != newEdited) }
state.copy(
editedConfig = updatedConfig,
isDirty = state.originalConfig != state.editedConfig,
)
}
} }
fun saveChanges() = intent { fun saveChanges() = intent {
@@ -5,9 +5,9 @@ import com.dokar.sonner.ToastType
import com.zaneschepke.wireguardautotunnel.client.domain.error.ClientException import com.zaneschepke.wireguardautotunnel.client.domain.error.ClientException
import com.zaneschepke.wireguardautotunnel.client.domain.model.TunnelConfig import com.zaneschepke.wireguardautotunnel.client.domain.model.TunnelConfig
import com.zaneschepke.wireguardautotunnel.client.domain.repository.TunnelRepository import com.zaneschepke.wireguardautotunnel.client.domain.repository.TunnelRepository
import com.zaneschepke.wireguardautotunnel.client.service.BackendCommandService import com.zaneschepke.wireguardautotunnel.client.service.BackendService
import com.zaneschepke.wireguardautotunnel.client.service.TunnelCommandService
import com.zaneschepke.wireguardautotunnel.client.service.TunnelImportService import com.zaneschepke.wireguardautotunnel.client.service.TunnelImportService
import com.zaneschepke.wireguardautotunnel.client.service.TunnelService
import com.zaneschepke.wireguardautotunnel.desktop.ui.screens.tunnels.DeleteIntent import com.zaneschepke.wireguardautotunnel.desktop.ui.screens.tunnels.DeleteIntent
import com.zaneschepke.wireguardautotunnel.desktop.ui.screens.tunnels.ExportIntent import com.zaneschepke.wireguardautotunnel.desktop.ui.screens.tunnels.ExportIntent
import com.zaneschepke.wireguardautotunnel.desktop.ui.sideeffects.AppSideEffect import com.zaneschepke.wireguardautotunnel.desktop.ui.sideeffects.AppSideEffect
@@ -18,7 +18,7 @@ import io.github.vinceglb.filekit.FileKit
import io.github.vinceglb.filekit.dialogs.openFileSaver import io.github.vinceglb.filekit.dialogs.openFileSaver
import io.github.vinceglb.filekit.name import io.github.vinceglb.filekit.name
import io.github.vinceglb.filekit.write import io.github.vinceglb.filekit.write
import kotlinx.coroutines.flow.combine import kotlinx.coroutines.flow.collect
import kotlinx.coroutines.flow.distinctUntilChanged import kotlinx.coroutines.flow.distinctUntilChanged
import kotlinx.coroutines.flow.map import kotlinx.coroutines.flow.map
import org.orbitmvi.orbit.ContainerHost import org.orbitmvi.orbit.ContainerHost
@@ -26,8 +26,8 @@ import org.orbitmvi.orbit.viewmodel.container
class TunnelsViewModel( class TunnelsViewModel(
private val tunnelRepository: TunnelRepository, private val tunnelRepository: TunnelRepository,
private val backendCommandService: BackendCommandService, private val backendService: BackendService,
private val tunnelCommandService: TunnelCommandService, private val tunnelService: TunnelService,
private val tunnelImportService: TunnelImportService, private val tunnelImportService: TunnelImportService,
) : ContainerHost<TunnelsUiState, AppSideEffect>, ViewModel() { ) : ContainerHost<TunnelsUiState, AppSideEffect>, ViewModel() {
@@ -37,24 +37,16 @@ class TunnelsViewModel(
buildSettings = { repeatOnSubscribedStopTimeout = 5_000L }, buildSettings = { repeatOnSubscribedStopTimeout = 5_000L },
) { ) {
intent { intent {
combine( tunnelRepository.flow.collect { tunnels ->
tunnelRepository.flow, reduce { state.copy(tunnels = tunnels, isLoaded = true) }
backendCommandService }
.statusFlow() }
.map { it.activeTunnels } intent {
.distinctUntilChanged(), backendService
) { tunnels, tunnelStates -> .statusFlow()
Pair(tunnels.sortedBy { it.position }, tunnelStates) .map { it.activeTunnels }
} .distinctUntilChanged()
.collect { (tunnels, tunnelStates) -> .collect { reduce { state.copy(tunnelStates = it) } }
reduce {
state.copy(
tunnels = tunnels,
tunnelStates = tunnelStates,
isLoaded = true,
)
}
}
} }
} }
@@ -73,14 +65,14 @@ class TunnelsViewModel(
} }
fun onStartTunnel(id: Long) = intent { fun onStartTunnel(id: Long) = intent {
tunnelCommandService.startTunnel(id).onFailure { tunnelService.startTunnel(id).onFailure {
val message = (it as? ClientException).asUserMessage() val message = (it as? ClientException).asUserMessage()
postSideEffect(AppSideEffect.Toast(message, ToastType.Error)) postSideEffect(AppSideEffect.Toast(message, ToastType.Error))
} }
} }
fun onStopTunnel(id: Long) = intent { fun onStopTunnel(id: Long) = intent {
tunnelCommandService.stopTunnel(id).onFailure { tunnelService.stopTunnel(id).onFailure {
val message = (it as? ClientException).asUserMessage() val message = (it as? ClientException).asUserMessage()
postSideEffect(AppSideEffect.Toast(message, ToastType.Error)) postSideEffect(AppSideEffect.Toast(message, ToastType.Error))
} }
+12 -3
View File
@@ -17,6 +17,13 @@ app {
options += "-XX:+UseG1GC" options += "-XX:+UseG1GC"
options += "-XX:+UseStringDeduplication" options += "-XX:+UseStringDeduplication"
jlink-flags = [
"--compress=zip-9",
"--strip-debug"
]
# for high-res displays # for high-res displays
system-properties { system-properties {
"sun.java2d.uiScale" = "1.0" "sun.java2d.uiScale" = "1.0"
@@ -45,7 +52,7 @@ app {
inputs += "daemon/build/install/daemon/lib/*.jar" inputs += "daemon/build/install/daemon/lib/*.jar"
inputs += "cli/build/install/cli/lib/*.jar" inputs += "cli/build/install/cli/lib/*.jar"
// Target platforms # Target platforms
machines = [ machines = [
linux.amd64.glibc, linux.amd64.glibc,
windows.amd64, windows.amd64,
@@ -77,11 +84,13 @@ app {
file-name = "wgtunnel-daemon.service" file-name = "wgtunnel-daemon.service"
# start early to avoid leaks
Unit { Unit {
Description = "WG Tunnel Daemon" Description = "WG Tunnel Daemon"
Documentation = "https://wgtunnel.com" Documentation = "https://wgtunnel.com"
Before = "network-online.target" Before= network.target network-pre.target
After = "NetworkManager.service systemd-resolved.service" Wants= network.target
After= local-fs.target
StartLimitBurst = 5 StartLimitBurst = 5
StartLimitIntervalSec = 20 StartLimitIntervalSec = 20
} }
@@ -54,7 +54,7 @@ object PermissionsHelper {
path, path,
PosixFilePermissions.fromString(OWNER_FULL_CONTROL_SYMBOLIC), PosixFilePermissions.fromString(OWNER_FULL_CONTROL_SYMBOLIC),
) )
Logger.i { "Successfully set directory permissions to " } Logger.i { "Successfully set daemon data directory permission" }
} catch (e: Exception) { } catch (e: Exception) {
Logger.e { "POSIX native permissions failed: ${e.message} → falling back to chmod" } Logger.e { "POSIX native permissions failed: ${e.message} → falling back to chmod" }
try { try {
@@ -77,6 +77,58 @@ object PermissionsHelper {
} }
} }
fun secureDaemonDataDirectory(path: Path) {
val pathString = path.toString()
try {
if (SystemUtils.IS_OS_WINDOWS) {
val process =
ProcessBuilder(
ICACLS,
pathString,
WIN_INHERIT_REPLACE,
WIN_GRANT_REPLACE,
"$SID_SYSTEM$WIN_FULL_CONTROL_INHERIT",
WIN_GRANT_REPLACE,
"$SID_ADMINISTRATORS$WIN_FULL_CONTROL_INHERIT",
)
.start()
val exitCode = process.waitFor()
if (exitCode == 0) {
Logger.i { "Successfully secured Windows directory: $pathString" }
logWindowsACLs(pathString)
} else {
val error = process.errorStream.bufferedReader().use { it.readText() }
Logger.e { "Failed to secure Windows directory: $error" }
}
} else {
try {
Files.setPosixFilePermissions(
path,
PosixFilePermissions.fromString(OWNER_ONLY_PRIVATE_DIR),
)
Logger.i { "Successfully set POSIX permissions for directory: $pathString" }
} catch (e: Exception) {
Logger.e {
"POSIX native permissions failed: ${e.message} → falling back to chmod"
}
val exitCode = ProcessBuilder("chmod", "700", pathString).start().waitFor()
if (exitCode == 0) {
Logger.i {
"Successfully set directory permissions using chmod: $pathString"
}
} else {
Logger.e { "chmod failed with exit code $exitCode for: $pathString" }
}
}
val finalPerms = Files.getPosixFilePermissions(path)
Logger.i { "Final directory permissions: $finalPerms for $pathString" }
}
} catch (e: Exception) {
Logger.e(e) { "Error securing directory: $pathString" }
}
}
fun setupDirectoryPermissionsWindows(runtimeDirPath: String) { fun setupDirectoryPermissionsWindows(runtimeDirPath: String) {
try { try {
val process = val process =
@@ -5,12 +5,17 @@ object Routes {
const val DAEMON_STATUS = "$DAEMON_BASE/status" const val DAEMON_STATUS = "$DAEMON_BASE/status"
const val DAEMON_STATUS_WS = "$DAEMON_BASE/status/ws" const val DAEMON_STATUS_WS = "$DAEMON_BASE/status/ws"
const val DAEMON_RESTORE_TUNNEL = "$DAEMON_BASE/restore/tunnel"
const val DAEMON_RESTORE_KILL_SWITCH = "$DAEMON_BASE/restore/kill-switch"
const val BACKEND_BASE = "/backend" const val BACKEND_BASE = "/backend"
const val BACKEND_STATUS = "$BACKEND_BASE/status" const val BACKEND_STATUS = "$BACKEND_BASE/status"
const val BACKEND_ACTIVE_CONFIG = "$BACKEND_BASE/config/{id}/active" const val BACKEND_ACTIVE_CONFIG = "$BACKEND_BASE/config/{id}/active"
const val BACKEND_STATUS_WS = "$BACKEND_BASE/status/ws" const val BACKEND_STATUS_WS = "$BACKEND_BASE/status/ws"
const val BACKEND_KILL_SWITCH = "$BACKEND_BASE/kill-switch" const val BACKEND_KILL_SWITCH = "$BACKEND_BASE/kill-switch"
const val BACKEND_KILL_SWITCH_BYPASS = "$BACKEND_BASE/kill-switch/bypass-lan"
const val BACKEND_MODE = "$BACKEND_BASE/mode" const val BACKEND_MODE = "$BACKEND_BASE/mode"
object Tunnels { object Tunnels {
@@ -2,4 +2,4 @@ package com.zaneschepke.wireguardautotunnel.core.ipc.dto.request
import kotlinx.serialization.Serializable import kotlinx.serialization.Serializable
@Serializable data class KillSwitchRequest(val enable: Boolean, val bypassLan: Boolean = false) @Serializable data class FlagRequest(val value: Boolean)
+2 -3
View File
@@ -22,9 +22,8 @@ dependencies {
testImplementation(kotlin("test")) testImplementation(kotlin("test"))
// secure caching // caching
implementation(libs.kstore) implementation(libs.multiplatform.settings)
implementation(libs.kstore.file)
implementation(libs.kotlinx.serialization) implementation(libs.kotlinx.serialization)
@@ -7,6 +7,7 @@ import com.zaneschepke.wireguardautotunnel.daemon.plugin.hmacShieldPlugin
import com.zaneschepke.wireguardautotunnel.daemon.routes.backendRoutes import com.zaneschepke.wireguardautotunnel.daemon.routes.backendRoutes
import com.zaneschepke.wireguardautotunnel.daemon.routes.daemonRoutes import com.zaneschepke.wireguardautotunnel.daemon.routes.daemonRoutes
import com.zaneschepke.wireguardautotunnel.daemon.routes.tunnelRoutes import com.zaneschepke.wireguardautotunnel.daemon.routes.tunnelRoutes
import com.zaneschepke.wireguardautotunnel.daemon.tunnel.RunningTunnel
import com.zaneschepke.wireguardautotunnel.tunnel.Backend import com.zaneschepke.wireguardautotunnel.tunnel.Backend
import io.ktor.http.HttpStatusCode import io.ktor.http.HttpStatusCode
import io.ktor.serialization.kotlinx.* import io.ktor.serialization.kotlinx.*
@@ -94,8 +95,8 @@ class TunnelDaemon(
} }
install(hmacShieldPlugin) install(hmacShieldPlugin)
routing { routing {
daemonRoutes() daemonRoutes(cacheRepository)
tunnelRoutes(backend) tunnelRoutes(backend, cacheRepository)
backendRoutes(backend) backendRoutes(backend)
} }
monitor.subscribe(ApplicationStarted) { monitor.subscribe(ApplicationStarted) {
@@ -114,11 +115,23 @@ class TunnelDaemon(
} }
scope.launch { scope.launch {
// TODO handle startup with cached settings val restoreTun = cacheRepository.getRestoreTunnelOnBoot()
val settings = cacheRepository.getKillSwitchSettings() val restoreKillSwitch = cacheRepository.getKillSwitchRestore()
val startConfigs = cacheRepository.getStartConfigs() if (restoreKillSwitch) {
Logger.d { "Got kill switch settings $settings" } Logger.i { "Attempting to restore kill switch" }
Logger.d { "Got start configs of size ${startConfigs.size}" } backend
.setKillSwitch(true)
.onFailure { Logger.e(it) { "Failed to restore kill switch" } }
.onSuccess { Logger.i { "Kill switch successfully restored" } }
}
if (restoreTun) {
Logger.i { "Attempting to restore previous tunnel" }
val config = cacheRepository.getLastActiveTunnelConfig() ?: return@launch
val name = cacheRepository.getLastActiveTunnelName() ?: return@launch
val id = cacheRepository.getLastActiveTunnelId() ?: return@launch
val tunnel = RunningTunnel(id, name)
backend.start(tunnel, config)
}
} }
} }
@@ -1,13 +1,31 @@
package com.zaneschepke.wireguardautotunnel.daemon.data package com.zaneschepke.wireguardautotunnel.daemon.data
import com.zaneschepke.wireguardautotunnel.daemon.data.model.KillSwitchSettings
interface DaemonCacheRepository { interface DaemonCacheRepository {
suspend fun getKillSwitchSettings(): KillSwitchSettings suspend fun updateKillSwitchEnabled(enabled: Boolean)
suspend fun setKillSwitchSettings(settings: KillSwitchSettings) suspend fun updateKillSwitchBypassLan(enabled: Boolean)
suspend fun getStartConfigs(): Set<String> suspend fun updateKillSwitchRestore(enabled: Boolean)
suspend fun setStartConfigs(configs: Set<String>) suspend fun getKillSwitchEnabled(): Boolean
suspend fun getKillSwitchBypassLan(): Boolean
suspend fun getKillSwitchRestore(): Boolean
suspend fun updateLastActiveTunnelConfig(quick: String)
suspend fun getLastActiveTunnelConfig(): String?
suspend fun updateLastActiveTunnelId(tunnelId: Long)
suspend fun getLastActiveTunnelId(): Long?
suspend fun updateLastActiveTunnelName(tunnelName: String)
suspend fun getLastActiveTunnelName(): String?
suspend fun setRestoreTunnelOnBoot(enabled: Boolean)
suspend fun getRestoreTunnelOnBoot(): Boolean
} }
@@ -1,109 +0,0 @@
package com.zaneschepke.wireguardautotunnel.daemon.data
import co.touchlab.kermit.Logger
import com.zaneschepke.wireguardautotunnel.daemon.data.model.DaemonCacheData
import com.zaneschepke.wireguardautotunnel.daemon.data.model.KillSwitchSettings
import io.github.xxfast.kstore.KStore
import io.github.xxfast.kstore.file.storeOf
import java.nio.file.Files
import java.nio.file.Paths
import java.nio.file.attribute.PosixFilePermissions
import kotlinx.io.files.Path
import kotlinx.serialization.json.Json
import org.apache.commons.lang3.SystemUtils
class KStoreDaemonCacheRepository(
private val baseCacheDir: java.nio.file.Path = getCacheBaseDir()
) : DaemonCacheRepository {
companion object {
const val CACHE_FILE_NAME = "cache.json"
private fun getCacheBaseDir(): java.nio.file.Path {
return when {
SystemUtils.IS_OS_MAC_OSX -> Paths.get("/Library/Application Support/wgtunnel")
SystemUtils.IS_OS_WINDOWS -> Paths.get(System.getenv("PROGRAMDATA") + "\\wgtunnel")
else -> Paths.get("/var/lib/wgtunnel")
}
}
}
init {
if (Files.notExists(baseCacheDir)) {
Files.createDirectories(baseCacheDir)
setSecurePermissions(baseCacheDir)
}
}
private fun getStore(): KStore<DaemonCacheData> {
val storePathNio = baseCacheDir.resolve(CACHE_FILE_NAME)
val storeKPath = Path(storePathNio.toString())
if (!Files.exists(storePathNio)) {
Files.createFile(storePathNio)
}
if (Files.size(storePathNio) == 0L) {
val defaultData = DaemonCacheData()
val defaultJson = Json.encodeToString(defaultData)
Files.writeString(storePathNio, defaultJson)
}
setSecurePermissions(storePathNio)
return storeOf(file = storeKPath, default = DaemonCacheData())
}
private fun setSecurePermissions(path: java.nio.file.Path) {
val os = System.getProperty("os.name").lowercase()
try {
if (!os.contains("win")) {
val isDirectory = Files.isDirectory(path)
val permsString =
if (isDirectory) "rwx------" else "rw-------" // 700 for dirs, 600 for files
val perms = PosixFilePermissions.fromString(permsString)
Files.setPosixFilePermissions(path, perms)
} else {
val process =
ProcessBuilder(
"icacls",
path.toString(),
"/inheritance:r", // remove inherited permissions
"/grant:r",
"SYSTEM:(F)", // full control to system
"/grant:r",
"Administrators:(F)", // full control to admin
)
.start()
val exitCode = process.waitFor()
if (exitCode != 0) {
Logger.e { "icacls failed with code $exitCode" }
}
}
} catch (e: Exception) {
Logger.e(e) { "Failed to set permissions" }
}
}
override suspend fun getKillSwitchSettings(): KillSwitchSettings {
return getStore().get()?.killSwitch ?: KillSwitchSettings(false, false)
}
override suspend fun setKillSwitchSettings(settings: KillSwitchSettings) {
val store = getStore()
store.update { current ->
current?.copy(killSwitch = settings) ?: DaemonCacheData(killSwitch = settings)
}
}
override suspend fun getStartConfigs(): Set<String> {
return getStore().get()?.startConfigs ?: emptySet()
}
override suspend fun setStartConfigs(configs: Set<String>) {
val store = getStore()
store.update { current ->
current?.copy(startConfigs = configs) ?: DaemonCacheData(startConfigs = configs)
}
}
}
@@ -0,0 +1,107 @@
package com.zaneschepke.wireguardautotunnel.daemon.data
import co.touchlab.kermit.Logger
import com.russhwolf.settings.PropertiesSettings
import com.russhwolf.settings.Settings
import com.russhwolf.settings.set
import com.zaneschepke.wireguardautotunnel.core.helper.PermissionsHelper
import java.io.FileOutputStream
import java.nio.file.Files
import java.nio.file.Path
import java.nio.file.Paths
import java.util.Properties
import org.apache.commons.lang3.SystemUtils
class SettingsDaemonCacheRepository(private val baseCacheDir: Path = getCacheBaseDir()) :
DaemonCacheRepository {
private val settings: Settings by lazy {
val storePathNio = baseCacheDir.resolve(CACHE_FILE_NAME)
// create cache dir
if (Files.notExists(baseCacheDir)) {
Files.createDirectories(baseCacheDir)
}
// secure the cache dir for admin/root only
PermissionsHelper.secureDaemonDataDirectory(baseCacheDir)
// load data
val props = Properties()
if (Files.exists(storePathNio) && Files.size(storePathNio) > 0) {
Files.newInputStream(storePathNio).use { props.load(it) }
}
// save
PropertiesSettings(props) {
try {
FileOutputStream(storePathNio.toFile()).use { output ->
props.store(output, "WireGuard AutoTunnel Daemon Cache")
}
} catch (e: Exception) {
Logger.e(e) { "Failed to save settings to disk" }
}
}
}
companion object {
const val CACHE_FILE_NAME = "cache.properties"
private const val KEY_KS_ENABLED = "killswitch_enabled"
private const val KEY_KS_BYPASS_LAN = "killswitch_bypass_lan"
private const val KEY_KS_RESTORE = "killswitch_restore"
private const val KEY_LAST_TUNNEL = "last_active_tunnel"
private const val KEY_LAST_TUNNEL_ID = "last_active_tunnel_id"
private const val KEY_LAST_TUNNEL_NAME = "last_active_tunnel_name"
private const val KEY_RESTORE_ON_BOOT = "restore_on_boot"
private fun getCacheBaseDir(): Path {
return when {
SystemUtils.IS_OS_MAC_OSX -> Paths.get("/Library/Application Support/wgtunnel")
SystemUtils.IS_OS_WINDOWS -> Paths.get(System.getenv("PROGRAMDATA") + "\\wgtunnel")
else -> Paths.get("/var/lib/wgtunnel")
}
}
}
override suspend fun updateKillSwitchEnabled(enabled: Boolean) =
settings.set(KEY_KS_ENABLED, enabled)
override suspend fun getKillSwitchEnabled(): Boolean =
settings.getBoolean(KEY_KS_ENABLED, false)
override suspend fun updateKillSwitchBypassLan(enabled: Boolean) =
settings.set(KEY_KS_BYPASS_LAN, enabled)
override suspend fun getKillSwitchBypassLan(): Boolean =
settings.getBoolean(KEY_KS_BYPASS_LAN, false)
override suspend fun updateKillSwitchRestore(enabled: Boolean) =
settings.set(KEY_KS_RESTORE, enabled)
override suspend fun getKillSwitchRestore(): Boolean =
settings.getBoolean(KEY_KS_RESTORE, false)
override suspend fun updateLastActiveTunnelConfig(quick: String) =
settings.set(KEY_LAST_TUNNEL, quick)
override suspend fun getLastActiveTunnelConfig(): String? =
settings.getStringOrNull(KEY_LAST_TUNNEL)
override suspend fun updateLastActiveTunnelId(tunnelId: Long) =
settings.set(KEY_LAST_TUNNEL_ID, tunnelId)
override suspend fun getLastActiveTunnelId(): Long? = settings.getLongOrNull(KEY_LAST_TUNNEL_ID)
override suspend fun updateLastActiveTunnelName(tunnelName: String) =
settings.set(KEY_LAST_TUNNEL_NAME, tunnelName)
override suspend fun getLastActiveTunnelName(): String? =
settings.getStringOrNull(KEY_LAST_TUNNEL_NAME)
override suspend fun setRestoreTunnelOnBoot(enabled: Boolean) =
settings.set(KEY_RESTORE_ON_BOOT, enabled)
override suspend fun getRestoreTunnelOnBoot(): Boolean =
settings.getBoolean(KEY_RESTORE_ON_BOOT, false)
}
@@ -1,9 +0,0 @@
package com.zaneschepke.wireguardautotunnel.daemon.data.model
import kotlinx.serialization.Serializable
@Serializable
data class DaemonCacheData(
val killSwitch: KillSwitchSettings = KillSwitchSettings(false, false),
val startConfigs: Set<String> = emptySet(),
)
@@ -1,5 +0,0 @@
package com.zaneschepke.wireguardautotunnel.daemon.data.model
import kotlinx.serialization.Serializable
@Serializable data class KillSwitchSettings(val enabled: Boolean, val bypassLan: Boolean)
@@ -3,7 +3,7 @@ package com.zaneschepke.wireguardautotunnel.daemon.di
import com.zaneschepke.wireguardautotunnel.core.ipc.IPC import com.zaneschepke.wireguardautotunnel.core.ipc.IPC
import com.zaneschepke.wireguardautotunnel.daemon.TunnelDaemon import com.zaneschepke.wireguardautotunnel.daemon.TunnelDaemon
import com.zaneschepke.wireguardautotunnel.daemon.data.DaemonCacheRepository import com.zaneschepke.wireguardautotunnel.daemon.data.DaemonCacheRepository
import com.zaneschepke.wireguardautotunnel.daemon.data.KStoreDaemonCacheRepository import com.zaneschepke.wireguardautotunnel.daemon.data.SettingsDaemonCacheRepository
import com.zaneschepke.wireguardautotunnel.tunnel.AmneziaBackend import com.zaneschepke.wireguardautotunnel.tunnel.AmneziaBackend
import com.zaneschepke.wireguardautotunnel.tunnel.Backend import com.zaneschepke.wireguardautotunnel.tunnel.Backend
import kotlinx.serialization.json.Json import kotlinx.serialization.json.Json
@@ -17,6 +17,6 @@ val daemonModule = module {
} }
} }
single<Backend> { AmneziaBackend() } single<Backend> { AmneziaBackend() }
single<DaemonCacheRepository> { KStoreDaemonCacheRepository() } single<DaemonCacheRepository> { SettingsDaemonCacheRepository() }
single { TunnelDaemon(get(), get(), get(), IPC.getDaemonSocketPath()) } single { TunnelDaemon(get(), get(), get(), IPC.getDaemonSocketPath()) }
} }
@@ -14,8 +14,8 @@ val hmacShieldPlugin =
createApplicationPlugin("HmacShield") { createApplicationPlugin("HmacShield") {
onCall { call -> onCall { call ->
// ignore daemon routes // ignore daemon health calls
if (call.request.path().contains(Routes.DAEMON_BASE)) { if (call.request.path() == Routes.DAEMON_BASE) {
return@onCall return@onCall
} }
@@ -5,7 +5,7 @@ import com.zaneschepke.wireguardautotunnel.core.ipc.Routes
import com.zaneschepke.wireguardautotunnel.core.ipc.dto.BackendMode import com.zaneschepke.wireguardautotunnel.core.ipc.dto.BackendMode
import com.zaneschepke.wireguardautotunnel.core.ipc.dto.BackendStatus import com.zaneschepke.wireguardautotunnel.core.ipc.dto.BackendStatus
import com.zaneschepke.wireguardautotunnel.core.ipc.dto.TunnelStatus import com.zaneschepke.wireguardautotunnel.core.ipc.dto.TunnelStatus
import com.zaneschepke.wireguardautotunnel.core.ipc.dto.request.KillSwitchRequest import com.zaneschepke.wireguardautotunnel.core.ipc.dto.request.FlagRequest
import com.zaneschepke.wireguardautotunnel.daemon.dto.toDto import com.zaneschepke.wireguardautotunnel.daemon.dto.toDto
import com.zaneschepke.wireguardautotunnel.daemon.dto.toInternal import com.zaneschepke.wireguardautotunnel.daemon.dto.toInternal
import com.zaneschepke.wireguardautotunnel.parser.ActiveConfig import com.zaneschepke.wireguardautotunnel.parser.ActiveConfig
@@ -17,19 +17,17 @@ import io.ktor.server.request.*
import io.ktor.server.response.* import io.ktor.server.response.*
import io.ktor.server.routing.* import io.ktor.server.routing.*
import io.ktor.server.websocket.* import io.ktor.server.websocket.*
import io.ktor.websocket.*
import kotlinx.coroutines.ExperimentalCoroutinesApi import kotlinx.coroutines.ExperimentalCoroutinesApi
import kotlinx.coroutines.delay import kotlinx.coroutines.delay
import kotlinx.coroutines.flow.distinctUntilChanged import kotlinx.coroutines.flow.distinctUntilChanged
import kotlinx.coroutines.flow.first import kotlinx.coroutines.flow.first
import kotlinx.coroutines.flow.flatMapLatest import kotlinx.coroutines.flow.flatMapLatest
import kotlinx.coroutines.flow.flow import kotlinx.coroutines.flow.flow
import kotlinx.coroutines.flow.map
@OptIn(ExperimentalCoroutinesApi::class) @OptIn(ExperimentalCoroutinesApi::class)
fun Route.backendRoutes(backend: Backend) { fun Route.backendRoutes(backend: Backend) {
post(Routes.BACKEND_MODE) { put(Routes.BACKEND_MODE) {
val mode = call.receive<BackendMode>() val mode = call.receive<BackendMode>()
Logger.i { "Setting backend mode to $mode" } Logger.i { "Setting backend mode to $mode" }
@@ -37,16 +35,28 @@ fun Route.backendRoutes(backend: Backend) {
call.respond(HttpStatusCode.OK, "Backend mode set to $mode") call.respond(HttpStatusCode.OK, "Backend mode set to $mode")
} }
post(Routes.BACKEND_KILL_SWITCH) { put(Routes.BACKEND_KILL_SWITCH_BYPASS) {
val request = call.receive<KillSwitchRequest>() val request = call.receive<FlagRequest>()
Logger.i { "Setting backend bypass lan to $request" }
Logger.i {
"Setting kill switch to enabled: ${request.enable} and bypassLan: ${request.bypassLan}"
}
backend backend
.setKillSwitch(request.enable) .setKillSwitchLanBypass(request.value)
.onSuccess { .onSuccess {
call.respond(HttpStatusCode.OK, "Kill switch set to ${request.enable} successfully") call.respond(
HttpStatusCode.OK,
"Bypass LAN for kill switch set to ${request.value} successfully",
)
}
.onFailure { call.respond(HttpStatusCode.BadRequest, it.message ?: "Unknown error") }
}
put(Routes.BACKEND_KILL_SWITCH) {
val request = call.receive<FlagRequest>()
Logger.i { "Setting kill switch to enabled: ${request.value}" }
backend
.setKillSwitch(request.value)
.onSuccess {
call.respond(HttpStatusCode.OK, "Kill switch set to ${request.value} successfully")
} }
.onFailure { .onFailure {
if (it is BackendException.StateConflict) if (it is BackendException.StateConflict)
@@ -1,16 +1,37 @@
package com.zaneschepke.wireguardautotunnel.daemon.routes package com.zaneschepke.wireguardautotunnel.daemon.routes
import co.touchlab.kermit.Logger
import com.zaneschepke.wireguardautotunnel.core.ipc.Routes import com.zaneschepke.wireguardautotunnel.core.ipc.Routes
import com.zaneschepke.wireguardautotunnel.core.ipc.dto.request.FlagRequest
import com.zaneschepke.wireguardautotunnel.daemon.data.DaemonCacheRepository
import io.ktor.http.* import io.ktor.http.*
import io.ktor.server.request.receive
import io.ktor.server.response.respond
import io.ktor.server.routing.* import io.ktor.server.routing.*
import io.ktor.server.websocket.* import io.ktor.server.websocket.*
import kotlinx.coroutines.awaitCancellation import kotlinx.coroutines.awaitCancellation
fun Route.daemonRoutes() { fun Route.daemonRoutes(daemonCacheRepository: DaemonCacheRepository) {
get(Routes.DAEMON_STATUS) { call.response.status(HttpStatusCode.OK) } get(Routes.DAEMON_STATUS) { call.response.status(HttpStatusCode.OK) }
webSocket(Routes.DAEMON_STATUS_WS) { webSocket(Routes.DAEMON_STATUS_WS) {
try { try {
awaitCancellation() awaitCancellation()
} finally {} } finally {}
} }
put(Routes.DAEMON_RESTORE_TUNNEL) {
val request = call.receive<FlagRequest>()
Logger.d { "Updating restore tunnel to ${request.value}" }
daemonCacheRepository.setRestoreTunnelOnBoot(request.value)
Logger.d { "Successfully updated restore tunnel to ${request.value}" }
call.respond(HttpStatusCode.OK, "Tunnel restore updated to ${request.value}")
}
put(Routes.DAEMON_RESTORE_KILL_SWITCH) {
val request = call.receive<FlagRequest>()
Logger.d { "Updating restore kill switch to ${request.value}" }
daemonCacheRepository.updateKillSwitchRestore(request.value)
Logger.d { "Successfully updated restore kill switch to ${request.value}" }
call.respond(HttpStatusCode.OK, "Kill switch restore updated to ${request.value}")
}
} }
@@ -3,6 +3,7 @@ package com.zaneschepke.wireguardautotunnel.daemon.routes
import co.touchlab.kermit.Logger import co.touchlab.kermit.Logger
import com.zaneschepke.wireguardautotunnel.core.ipc.Routes import com.zaneschepke.wireguardautotunnel.core.ipc.Routes
import com.zaneschepke.wireguardautotunnel.core.ipc.dto.request.StartTunnelRequest import com.zaneschepke.wireguardautotunnel.core.ipc.dto.request.StartTunnelRequest
import com.zaneschepke.wireguardautotunnel.daemon.data.DaemonCacheRepository
import com.zaneschepke.wireguardautotunnel.daemon.tunnel.RunningTunnel import com.zaneschepke.wireguardautotunnel.daemon.tunnel.RunningTunnel
import com.zaneschepke.wireguardautotunnel.tunnel.Backend import com.zaneschepke.wireguardautotunnel.tunnel.Backend
import com.zaneschepke.wireguardautotunnel.tunnel.util.BackendException import com.zaneschepke.wireguardautotunnel.tunnel.util.BackendException
@@ -11,7 +12,7 @@ import io.ktor.server.request.*
import io.ktor.server.response.* import io.ktor.server.response.*
import io.ktor.server.routing.* import io.ktor.server.routing.*
fun Route.tunnelRoutes(backend: Backend) { fun Route.tunnelRoutes(backend: Backend, daemonCacheRepository: DaemonCacheRepository) {
post(Routes.Tunnels.START_TEMPLATE) { post(Routes.Tunnels.START_TEMPLATE) {
val id = val id =
@@ -23,6 +24,11 @@ fun Route.tunnelRoutes(backend: Backend) {
val tunnel = RunningTunnel(id, request.name) val tunnel = RunningTunnel(id, request.name)
Logger.d { "Updating daemon cache" }
daemonCacheRepository.updateLastActiveTunnelConfig(request.quickConfig)
daemonCacheRepository.updateLastActiveTunnelId(id)
daemonCacheRepository.updateLastActiveTunnelName(request.name)
backend backend
.start(tunnel, request.quickConfig) .start(tunnel, request.quickConfig)
.onSuccess { call.respond(HttpStatusCode.OK, "Tunnel ${request.name} started") } .onSuccess { call.respond(HttpStatusCode.OK, "Tunnel ${request.name} started") }
+3 -3
View File
@@ -27,6 +27,7 @@ orbitCompose = "11.0.0"
sonner = "0.3.9" sonner = "0.3.9"
materialKolor = "4.1.1" materialKolor = "4.1.1"
nativeTray = "1.1.0" nativeTray = "1.1.0"
mps = "1.3.0"
# Files # Files
kmpIo = "0.3.0" kmpIo = "0.3.0"
@@ -46,7 +47,6 @@ picocli = "4.7.7"
androidx-room = "2.8.4" androidx-room = "2.8.4"
androidx-sqlite = "2.6.2" androidx-sqlite = "2.6.2"
kstore = "1.0.0"
lang3 = "3.20.0" lang3 = "3.20.0"
@@ -128,8 +128,6 @@ jna-platform = { module = "net.java.dev.jna:jna-platform", version.ref = "jnaPla
androidx-room-runtime = { module = "androidx.room:room-runtime", version.ref = "androidx-room" } androidx-room-runtime = { module = "androidx.room:room-runtime", version.ref = "androidx-room" }
androidx-sqlite-bundled = { module = "androidx.sqlite:sqlite-bundled", version.ref = "androidx-sqlite" } androidx-sqlite-bundled = { module = "androidx.sqlite:sqlite-bundled", version.ref = "androidx-sqlite" }
androidx-room-compiler = { module = "androidx.room:room-compiler", version.ref = "androidx-room" } androidx-room-compiler = { module = "androidx.room:room-compiler", version.ref = "androidx-room" }
kstore = { module = "io.github.xxfast:kstore", version.ref = "kstore" }
kstore-file = { module = "io.github.xxfast:kstore-file", version.ref = "kstore" }
# Util # Util
apache-commons-lang3 = { module = "org.apache.commons:commons-lang3", version.ref = "lang3" } apache-commons-lang3 = { module = "org.apache.commons:commons-lang3", version.ref = "lang3" }
@@ -173,6 +171,8 @@ filekit-core = { module = "io.github.vinceglb:filekit-core", version.ref = "file
filekit-dialogs = { module = "io.github.vinceglb:filekit-dialogs", version.ref = "filekit" } filekit-dialogs = { module = "io.github.vinceglb:filekit-dialogs", version.ref = "filekit" }
filekit-dialogs-compose = { module = "io.github.vinceglb:filekit-dialogs-compose", version.ref = "filekit" } filekit-dialogs-compose = { module = "io.github.vinceglb:filekit-dialogs-compose", version.ref = "filekit" }
multiplatform-settings = { module = "com.russhwolf:multiplatform-settings", version.ref = "mps" }
[plugins] [plugins]
composeHotReload = { id = "org.jetbrains.compose.hot-reload", version.ref = "composeHotReload" } composeHotReload = { id = "org.jetbrains.compose.hot-reload", version.ref = "composeHotReload" }
@@ -11,13 +11,26 @@ import java.util.concurrent.ConcurrentHashMap
import kotlinx.coroutines.* import kotlinx.coroutines.*
import kotlinx.coroutines.channels.Channel import kotlinx.coroutines.channels.Channel
import kotlinx.coroutines.flow.* import kotlinx.coroutines.flow.*
import kotlinx.coroutines.sync.Mutex
import kotlinx.coroutines.sync.withLock
class AmneziaBackend : Backend { class AmneziaBackend : Backend {
private val tun = AwgTunnel.INSTANCE private val tun = AwgTunnel.INSTANCE
private val tunnelMutex = Mutex()
private val killSwitchMutex = Mutex()
private var currentMode: Backend.Mode = Backend.Mode.Userspace private var currentMode: Backend.Mode = Backend.Mode.Userspace
private val _status = MutableStateFlow(Backend.Status(false, currentMode, emptyMap())) private val _status =
MutableStateFlow(
Backend.Status(
killSwitchEnabled = false,
killSwitchLanBypassEnabled = false,
mode = currentMode,
activeTunnels = emptyMap(),
)
)
override val status: Flow<Backend.Status> = _status.asStateFlow() override val status: Flow<Backend.Status> = _status.asStateFlow()
@@ -27,17 +40,27 @@ class AmneziaBackend : Backend {
private val backendScope = CoroutineScope(SupervisorJob() + Dispatchers.IO) private val backendScope = CoroutineScope(SupervisorJob() + Dispatchers.IO)
init { init {
initKillSwitchStatus() backendScope.launch { initKillSwitchStatus() }
} }
private fun initKillSwitchStatus() { private suspend fun initKillSwitchStatus() =
val status = tun.getKillSwitchStatus() killSwitchMutex.withLock {
val enabled = status == 1 val killSwitchStatus = tun.getKillSwitchStatus()
_status.update { it.copy(killSwitchEnabled = enabled) } val killSwitchEnabled = killSwitchStatus == 1
} val bypassEnabled =
if (killSwitchEnabled) {
val bypassStatus = tun.getKillSwitchLanBypassStatus()
bypassStatus == 1
} else false
_status.update {
it.copy(
killSwitchEnabled = killSwitchEnabled,
killSwitchLanBypassEnabled = bypassEnabled,
)
}
}
@Synchronized override suspend fun start(tunnel: Tunnel, config: String): Result<Unit> = runCatching {
override fun start(tunnel: Tunnel, config: String): Result<Unit> = runCatching {
if (_status.value.activeTunnels.any { it.key.id == tunnel.id }) { if (_status.value.activeTunnels.any { it.key.id == tunnel.id }) {
throw BackendException.StateConflict("Tunnel ${tunnel.id} is already in use") throw BackendException.StateConflict("Tunnel ${tunnel.id} is already in use")
} }
@@ -58,11 +81,12 @@ class AmneziaBackend : Backend {
trySend(statusCode) trySend(statusCode)
} }
} }
val handle = val handle =
when (currentMode) { tunnelMutex.withLock {
Backend.Mode.Proxy -> tun.awgProxyTurnOn(config, statusCallback) when (currentMode) {
Backend.Mode.Userspace -> tun.awgTurnOn(config, statusCallback) Backend.Mode.Proxy -> tun.awgProxyTurnOn(config, statusCallback)
Backend.Mode.Userspace -> tun.awgTurnOn(config, statusCallback)
}
} }
if (handle < 0) { if (handle < 0) {
@@ -115,8 +139,7 @@ class AmneziaBackend : Backend {
} }
} }
@Synchronized override suspend fun stop(id: Long): Result<Unit> = runCatching {
override fun stop(id: Long): Result<Unit> = runCatching {
val tunnel = val tunnel =
tunnelHandles.keys.find { it.id == id } tunnelHandles.keys.find { it.id == id }
?: return Result.failure( ?: return Result.failure(
@@ -129,9 +152,11 @@ class AmneziaBackend : Backend {
BackendException.StateConflict("Tunnel with $id is not active.") BackendException.StateConflict("Tunnel with $id is not active.")
) )
when (currentMode) { tunnelMutex.withLock {
Backend.Mode.Proxy -> tun.awgProxyTurnOff(handle) when (currentMode) {
Backend.Mode.Userspace -> tun.awgTurnOff(handle) Backend.Mode.Proxy -> tun.awgProxyTurnOff(handle)
Backend.Mode.Userspace -> tun.awgTurnOff(handle)
}
} }
tunnelJobs.remove(tunnel)?.cancel() tunnelJobs.remove(tunnel)?.cancel()
@@ -141,7 +166,7 @@ class AmneziaBackend : Backend {
_status.update { it.copy(activeTunnels = it.activeTunnels - key) } _status.update { it.copy(activeTunnels = it.activeTunnels - key) }
} }
override fun setMode(mode: Backend.Mode) { override suspend fun setMode(mode: Backend.Mode) {
if (mode == currentMode) return if (mode == currentMode) return
shutdown() shutdown()
currentMode = mode currentMode = mode
@@ -160,7 +185,7 @@ class AmneziaBackend : Backend {
_status.update { it.copy(activeTunnels = emptyMap()) } _status.update { it.copy(activeTunnels = emptyMap()) }
} }
override fun getActiveConfig(id: Long): Result<String?> { override suspend fun getActiveConfig(id: Long): Result<String?> {
val handle = val handle =
tunnelHandles.keys.find { it.id == id }?.let { tunnelHandles[it] } tunnelHandles.keys.find { it.id == id }?.let { tunnelHandles[it] }
?: return Result.failure( ?: return Result.failure(
@@ -179,31 +204,44 @@ class AmneziaBackend : Backend {
} }
} }
override fun setKillSwitch(enabled: Boolean): Result<Unit> { override suspend fun setKillSwitch(enabled: Boolean): Result<Unit> {
if (_status.value.killSwitchEnabled == enabled) if (_status.value.killSwitchEnabled == enabled)
return Result.failure( return Result.failure(
BackendException.StateConflict("Kill switch enable: $enabled is already set.") BackendException.StateConflict("Kill switch enable: $enabled is already set.")
) )
val setValue = if (enabled) 1 else 0 val killSwitchEnabled =
val status = tun.setKillSwitch(setValue) killSwitchMutex.withLock {
if (status == -1) val setValue = if (enabled) 1 else 0
return Result.failure( val status = tun.setKillSwitch(setValue)
BackendException.InternalError( if (status == -1)
"Kill switch failed to start with error code: $status" return Result.failure(
) BackendException.InternalError(
) "Kill switch failed to start with error code: $status"
val killSwitchEnabled = status == 1 )
)
status == 1
}
_status.update { it.copy(killSwitchEnabled = killSwitchEnabled) } _status.update { it.copy(killSwitchEnabled = killSwitchEnabled) }
return Result.success(Unit) return Result.success(Unit)
} }
override suspend fun setKillSwitchLanBypass(enabled: Boolean): Result<Unit> {
if (!_status.value.killSwitchEnabled)
return Result.failure(BackendException.StateConflict("Kill switch is not active."))
killSwitchMutex.withLock {
val setValue = if (enabled) 1 else 0
tun.setKillSwitchLanBypass(setValue)
}
return Result.success(Unit)
}
private fun mapStatusCodeToState(statusCode: Int): Tunnel.State { private fun mapStatusCodeToState(statusCode: Int): Tunnel.State {
return when (statusCode) { return when (statusCode) {
0 -> Tunnel.State.Up.Healthy 0 -> Tunnel.State.Up.Healthy
1 -> Tunnel.State.Up.HandshakeFailure 1 -> Tunnel.State.Up.HandshakeFailure
2 -> Tunnel.State.Up.ResolvingDns 2 -> Tunnel.State.Up.ResolvingDns
3 -> Tunnel.State.Up.Unknown 3 -> Tunnel.State.Up.Unknown
else -> Tunnel.State.Down // unknown or negative error code consider down else -> Tunnel.State.Down
} }
} }
} }
@@ -4,17 +4,19 @@ import com.zaneschepke.wireguardautotunnel.tunnel.model.TunnelKey
import kotlinx.coroutines.flow.Flow import kotlinx.coroutines.flow.Flow
interface Backend { interface Backend {
fun start(tunnel: Tunnel, config: String): Result<Unit> suspend fun start(tunnel: Tunnel, config: String): Result<Unit>
fun stop(id: Long): Result<Unit> suspend fun stop(id: Long): Result<Unit>
fun setMode(mode: Mode) suspend fun setMode(mode: Mode)
fun setKillSwitch(enabled: Boolean): Result<Unit> suspend fun setKillSwitch(enabled: Boolean): Result<Unit>
suspend fun setKillSwitchLanBypass(enabled: Boolean): Result<Unit>
fun shutdown() fun shutdown()
fun getActiveConfig(id: Long): Result<String?> suspend fun getActiveConfig(id: Long): Result<String?>
val status: Flow<Status> val status: Flow<Status>
@@ -26,6 +28,7 @@ interface Backend {
data class Status( data class Status(
val killSwitchEnabled: Boolean, val killSwitchEnabled: Boolean,
val killSwitchLanBypassEnabled: Boolean,
val mode: Mode, val mode: Mode,
val activeTunnels: Map<TunnelKey, Tunnel.State>, val activeTunnels: Map<TunnelKey, Tunnel.State>,
) )
@@ -26,6 +26,10 @@ interface AwgTunnel : Library {
fun setKillSwitch(value: Int): Int // 1 for enable, 0 for disable, return 1 or -1 for error fun setKillSwitch(value: Int): Int // 1 for enable, 0 for disable, return 1 or -1 for error
fun setKillSwitchLanBypass(value: Int): Int
fun getKillSwitchLanBypassStatus(): Int
fun getKillSwitchStatus(): Int // 1 for enabled, 0 for disabled fun getKillSwitchStatus(): Int // 1 for enabled, 0 for disabled
companion object { companion object {
@@ -4,6 +4,8 @@ package killswitch
import "C" import "C"
import ( import (
"net/netip"
"github.com/wgtunnel/desktop/tunnel/shared" "github.com/wgtunnel/desktop/tunnel/shared"
"github.com/wgtunnel/desktop/tunnel/vpn/firewall/osfirewall/firewallmgr" "github.com/wgtunnel/desktop/tunnel/vpn/firewall/osfirewall/firewallmgr"
) )
@@ -50,3 +52,70 @@ func getKillSwitchStatus() C.int {
} }
return C.int(0) return C.int(0)
} }
//export setKillSwitchLanBypass
func setKillSwitchLanBypass(enabled C.int) C.int {
fw, err := firewallmgr.Get()
if err != nil {
logger.Errorf("Failed to get firewall: %v", err)
return C.int(-1)
}
if !fw.IsEnabled() {
logger.Errorf("Firewall is not active")
return C.int(-1)
}
if enabled == 1 {
localPrefixes := []netip.Prefix{
// IPv4 Private Ranges (RFC 1918)
netip.MustParsePrefix("10.0.0.0/8"),
netip.MustParsePrefix("172.16.0.0/12"),
netip.MustParsePrefix("192.168.0.0/16"),
// IPv4 Link-Local (APIPA)
netip.MustParsePrefix("169.254.0.0/16"),
// IPv4 Loopback
netip.MustParsePrefix("127.0.0.0/8"),
// IPv4 Multicast (for local discovery, e.g., mDNS)
netip.MustParsePrefix("224.0.0.0/4"),
// IPv6 Unique Local Addresses (ULA, RFC 4193)
netip.MustParsePrefix("fc00::/7"),
// IPv6 Link-Local (RFC 4291)
netip.MustParsePrefix("fe80::/10"),
// IPv6 Loopback
netip.MustParsePrefix("::1/128"),
// IPv6 Multicast (for local discovery)
netip.MustParsePrefix("ff00::/8"),
}
err := fw.AllowLocalNetworks(localPrefixes)
if err != nil {
logger.Errorf("Failed to enable kill switch: %v", err)
return C.int(-1)
}
logger.Verbosef("Kill switch enabled")
} else {
fw.RemoveLocalNetworks()
}
return enabled
}
//export getKillSwitchLanBypassStatus
func getKillSwitchLanBypassStatus() C.int {
fw, err := firewallmgr.Get()
if err != nil {
logger.Errorf("Failed to get firewall: %v", err)
return C.int(0)
}
if fw.IsAllowLocalNetworksEnabled() {
return C.int(1)
}
return C.int(0)
}
+1 -1
View File
@@ -51,7 +51,7 @@ func newConn() (*Conn, error) {
} }
return &Conn{ return &Conn{
conn: conn, conn: conn,
obj: conn.Object(dbusDest, dbus.ObjectPath(dbusPath)), obj: conn.Object(dbusDest, dbusPath),
}, nil }, nil
} }
@@ -24,4 +24,9 @@ type Firewall interface {
// AllowLocalNetworks adds bypass rules for the specified local network prefixes. Requires kill switch enabled and // AllowLocalNetworks adds bypass rules for the specified local network prefixes. Requires kill switch enabled and
// operates independently of tunnel/router bypasses. // operates independently of tunnel/router bypasses.
AllowLocalNetworks([]netip.Prefix) error AllowLocalNetworks([]netip.Prefix) error
// RemoveLocalNetworks removes any rules set by AllowLocalNetworks
RemoveLocalNetworks() error
IsAllowLocalNetworksEnabled() bool
} }
@@ -247,10 +247,7 @@ func (f *LinuxFirewall) AllowLocalNetworks(prefixes []netip.Prefix) error {
} }
// remove any old rules // remove any old rules
for _, rule := range f.localAddrRules { f.RemoveLocalNetworks()
f.conn.DelRule(rule)
}
f.localAddrRules = nil
// add bypass rules for each prefix // add bypass rules for each prefix
for _, table := range f.getTables() { for _, table := range f.getTables() {
@@ -258,6 +255,18 @@ func (f *LinuxFirewall) AllowLocalNetworks(prefixes []netip.Prefix) error {
if err != nil { if err != nil {
return fmt.Errorf("get output chain: %w", err) return fmt.Errorf("get output chain: %w", err)
} }
// temp remove drop rules
dropTemplate := createDropRule(table.Filter, outputChain)
existingDrop, err := findRule(f.conn, dropTemplate)
if err != nil {
return fmt.Errorf("find drop rule: %w", err)
}
if existingDrop != nil {
f.conn.DelRule(existingDrop)
}
// add the local bypass rules
for _, prefix := range prefixes { for _, prefix := range prefixes {
if prefix.Addr().Is6() && !f.v6Available { if prefix.Addr().Is6() && !f.v6Available {
continue continue
@@ -272,14 +281,33 @@ func (f *LinuxFirewall) AllowLocalNetworks(prefixes []netip.Prefix) error {
f.localAddrRules = append(f.localAddrRules, rule) f.localAddrRules = append(f.localAddrRules, rule)
} }
} }
// add drop rule back
dropRule := createDropRule(table.Filter, outputChain)
f.conn.AddRule(dropRule)
} }
if err := f.conn.Flush(); err != nil { if err := f.conn.Flush(); err != nil {
return fmt.Errorf("flush after bypassing local addrs: %w", err) return fmt.Errorf("flush after bypassing local addrs: %w", err)
} }
f.logger.Verbosef("Bypassed local addrs: %v", prefixes) f.logger.Verbosef("Bypassed local addrs: %v", prefixes)
return nil return nil
} }
func (f *LinuxFirewall) RemoveLocalNetworks() error {
for _, rule := range f.localAddrRules {
f.conn.DelRule(rule)
}
f.localAddrRules = nil
return nil
}
func (f *LinuxFirewall) IsAllowLocalNetworksEnabled() bool {
return f.localAddrRules != nil
}
func (f *LinuxFirewall) IsEnabled() bool { func (f *LinuxFirewall) IsEnabled() bool {
return f.killSwitchEnabled.Load() return f.killSwitchEnabled.Load()
} }
@@ -1,290 +0,0 @@
// SPDX-License-Identifier: Apache-2.0
// Copyright © 2026 WG Tunnel.
// Adapted from Tailscale
//go:build darwin
package osfirewall
import (
"errors"
"fmt"
"net/netip"
"os"
"os/exec"
"path/filepath"
"strings"
"unsafe"
"github.com/amnezia-vpn/amneziawg-go/device"
"github.com/wgtunnel/desktop/tunnel/vpn/firewall"
"golang.org/x/net/bpf"
"golang.org/x/sys/unix"
)
const (
anchorName = "wgtunnel"
pfConfPath = "/etc/pf.conf" // System PF config; we'll append our anchor
)
// macFirewall implements the firewall.Firewall interface for macOS using PF (Packet Filter).
type macFirewall struct {
tunnelPort uint16 // WireGuard listen port for inbound punch
killSwitchEnabled bool // Track if kill switch is active (not atomic, as PF is stateful)
v6Available bool // Whether the host supports IPv6
logger *device.Logger
}
func New(logger *device.Logger) (firewall.Firewall, error) {
v6err := CheckIPv6(logger)
supportsV6 := v6err == nil
logger.Verbosef("PF mode, v6 support: %v", supportsV6)
return &macFirewall{
v6Available: supportsV6,
logger: logger,
}, nil
}
func (f *macFirewall) HasV6Available() bool {
return f.v6Available
}
func (f *macFirewall) Active() bool {
return f.killSwitchEnabled
}
// Enable initializes the firewall (e.g., ensures PF is enabled and our anchor is referenced).
func (f *macFirewall) Up() error {
// Ensure PF is enabled (macOS default is off; enable if needed)
if err := execSudoCommand("pfctl", "-e"); err != nil && !strings.Contains(err.Error(), "already enabled") {
return fmt.Errorf("enable PF: %w", err)
}
// Add our anchor to /etc/pf.conf if not present (append if needed)
conf, err := os.ReadFile(pfConfPath)
if err != nil {
return fmt.Errorf("read pf.conf: %w", err)
}
if !strings.Contains(string(conf), fmt.Sprintf(`anchor "%s"`, anchorName)) {
if err := os.AppendFile(pfConfPath, []byte(fmt.Sprintf(`\nanchor "%s"\nload anchor "%s" from "/etc/pf.anchors/%s"\n`, anchorName, anchorName, anchorName)), 0644); err != nil {
return fmt.Errorf("append to pf.conf: %w", err)
}
if err := execSudoCommand("pfctl", "-f", pfConfPath); err != nil {
return fmt.Errorf("reload pf.conf: %w", err)
}
}
f.logger.Verbosef("PF initialized")
return nil
}
// SetTunnelPort sets the UDP port for the WireGuard tunnel and adds punch rules.
func (f *macFirewall) SetTunnelPort(port uint16) error {
rule := fmt.Sprintf("pass in quick proto udp to any port %d keep state", port)
if err := f.addRuleToAnchor(rule); err != nil {
return fmt.Errorf("add port punch rule: %w", err)
}
f.tunnelPort = port
f.logger.Verbosef("Added tunnel port punch for UDP port %d", port)
return nil
}
// ToggleKillSwitch enables/disables the kill switch.
func (f *macFirewall) ToggleKillSwitch(enable bool) error {
if enable == f.killSwitchEnabled {
return nil
}
if enable {
if err := f.addKillSwitchRules(); err != nil {
return fmt.Errorf("add kill switch rules: %w", err)
}
} else {
if err := f.delKillSwitchRules(); err != nil {
return fmt.Errorf("del kill switch rules: %w", err)
}
}
f.killSwitchEnabled = enable
f.logger.Verbosef("Kill switch toggled: %v", enable)
return nil
}
// addKillSwitchRules adds PF rules for kill switch (block non-tunnel outbound, with exemptions).
func (f *macFirewall) addKillSwitchRules() error {
rules := []string{
"block out all", // Default block outbound
"pass out quick on utun0 all", // Allow on tunnel (adjust 'utun0' dynamically if needed)
"pass out quick to <bypass_ips> all", // Placeholder for SetBypassRoutes
// Add loopback allowance
"pass out quick on lo0 all",
"pass in quick on lo0 all",
// Established/related (PF handles keep state)
"pass out all keep state",
"pass in all keep state",
}
if err := f.writeRulesToAnchor(rules); err != nil {
return err
}
f.logger.Verbosef("Kill switch rules added")
return nil
}
// delKillSwitchRules removes kill switch rules by clearing the anchor.
func (f *macFirewall) delKillSwitchRules() error {
if err := f.clearAnchor(); err != nil {
return err
}
f.logger.Verbosef("Kill switch rules removed")
return nil
}
// SetBypassRoutes adds exemptions for bootstrap routes.
func (f *macFirewall) SetBypassRoutes(bypassRoutes []netip.Prefix) error {
var rules []string
for _, route := range bypassRoutes {
rules = append(rules, fmt.Sprintf("pass out quick to %s all", route.String()))
}
if err := f.addRulesToAnchor(rules); err != nil {
return fmt.Errorf("add bypass routes: %w", err)
}
f.logger.Verbosef("Added bypass routes: %v", bypassRoutes)
return nil
}
// TemporaryBypassSocket uses BPF to attach a filter to the socket for bypass.
func (f *macFirewall) TemporaryBypassSocket(fd int) (func() error, error) {
// Compile a simple BPF program to allow specific traffic (e.g., UDP to VPN ports)
// Example: Allow UDP dport == your tunnel port (adjust as needed)
// This is a basic UDP check; extend for port/IP as needed
instructions := []bpf.Instruction{
bpf.LoadAbsolute{Off: 9, Size: 1}, // Load IP protocol (offset 9 in IP header)
bpf.JumpIf{Cond: bpf.JumpEqual, Val: 17, SkipFalse: 2}, // Check if UDP (17), jump to reject if not
// Add port check here if needed, e.g.:
// bpf.LoadAbsolute{Off: 22, Size: 2}, // Load UDP dport (network byte order, offset 20 src +2 dst in UDP)
// bpf.JumpIf{Cond: bpf.JumpEqual, Val: uint32(f.tunnelPort), SkipFalse: 1},
bpf.RetConstant{Val: 65535}, // Accept (return max packet length)
bpf.RetConstant{Val: 0}, // Reject
}
prog, err := bpf.Assemble(instructions) // Compile to machine code
if err != nil {
return nil, fmt.Errorf("assemble BPF: %w", err)
}
// Prepare SockFprog struct
sockFprog := unix.SockFprog{
Len: uint16(len(prog)),
Filter: (*unix.Sockfilter)(unsafe.Pointer(&prog[0])),
}
// Attach to socket using exported SetsockoptSockFprog
if err := unix.SetsockoptSockFprog(fd, unix.SOL_SOCKET, unix.SO_ATTACH_FILTER, &sockFprog); err != nil {
return nil, fmt.Errorf("attach BPF to fd %d: %w", fd, err)
}
f.logger.Verbosef("BPF bypass attached to fd %d", fd)
return func() error {
// Detach BPF
if err := unix.SetsockoptInt(fd, unix.SOL_SOCKET, unix.SO_DETACH_FILTER, 0); err != nil {
f.logger.Errorf("Failed to detach BPF on fd %d: %v", fd, err)
return err
}
f.logger.Verbosef("BPF detached from fd %d", fd)
return nil
}, nil
}
// Helper: writeRulesToAnchor writes rules to anchor file and reloads.
func (f *macFirewall) writeRulesToAnchor(rules []string) error {
anchorPath := filepath.Join("/etc/pf.anchors", anchorName)
content := strings.Join(rules, "\n")
if err := os.WriteFile(anchorPath, []byte(content), 0644); err != nil {
return fmt.Errorf("write anchor file: %w", err)
}
return f.reloadAnchor()
}
// addRuleToAnchor appends a single rule and reloads.
func (f *macFirewall) addRuleToAnchor(rule string) error {
anchorPath := filepath.Join("/etc/pf.anchors", anchorName)
file, err := os.OpenFile(anchorPath, os.O_APPEND|os.O_WRONLY|os.O_CREATE, 0644)
if err != nil {
return fmt.Errorf("open anchor file: %w", err)
}
defer file.Close()
if _, err := file.WriteString(rule + "\n"); err != nil {
return fmt.Errorf("append rule: %w", err)
}
return f.reloadAnchor()
}
// addRulesToAnchor appends multiple rules and reloads.
func (f *macFirewall) addRulesToAnchor(rules []string) error {
anchorPath := filepath.Join("/etc/pf.anchors", anchorName)
file, err := os.OpenFile(anchorPath, os.O_APPEND|os.O_WRONLY|os.O_CREATE, 0644)
if err != nil {
return fmt.Errorf("open anchor file: %w", err)
}
defer file.Close()
for _, rule := range rules {
if _, err := file.WriteString(rule + "\n"); err != nil {
return fmt.Errorf("append rule: %w", err)
}
}
return f.reloadAnchor()
}
// clearAnchor clears the anchor file and reloads.
func (f *macFirewall) clearAnchor() error {
anchorPath := filepath.Join("/etc/pf.anchors", anchorName)
if err := os.WriteFile(anchorPath, []byte{}, 0644); err != nil {
return fmt.Errorf("clear anchor file: %w", err)
}
return f.reloadAnchor()
}
// reloadAnchor reloads the PF anchor.
func (f *macFirewall) reloadAnchor() error {
if err := execSudoCommand("pfctl", "-a", anchorName, "-F", "all"); err != nil {
return fmt.Errorf("flush anchor: %w", err)
}
if err := execSudoCommand("pfctl", "-a", anchorName, "-f", filepath.Join("/etc/pf.anchors", anchorName)); err != nil {
return fmt.Errorf("load anchor: %w", err)
}
return nil
}
// execSudoCommand runs a command with sudo (assumes sudo is available; handle prompts in app if needed).
func execSudoCommand(name string, args ...string) error {
cmd := exec.Command("sudo", append([]string{name}, args...)...)
output, err := cmd.CombinedOutput()
if err != nil {
return fmt.Errorf("%s: %w\nOutput: %s", name, err, output)
}
return nil
}
func CheckIPv6(logger *device.Logger) error {
// Similar to Linux: Check sysctl or interfaces for IPv6
interfaces, err := net.Interfaces()
if err != nil {
return err
}
for _, iface := range interfaces {
addrs, err := iface.Addrs()
if err != nil {
continue
}
for _, addr := range addrs {
ip, _, err := net.ParseCIDR(addr.String())
if err == nil && ip.To16() != nil && ip.To4() == nil {
logger.Verbosef("IPv6 detected on interface %s", iface.Name)
return nil
}
}
}
return errors.New("no IPv6 interfaces found")
}
@@ -134,12 +134,7 @@ func New(logger *device.Logger) (firewall.Firewall, error) {
func (f *WindowsFirewall) AllowLocalNetworks(addrs []netip.Prefix) error { func (f *WindowsFirewall) AllowLocalNetworks(addrs []netip.Prefix) error {
// cleanup old local addr rules // cleanup old local addr rules
if err := f.removeRules(f.localAddrRules); err != nil { f.RemoveLocalNetworks()
f.logger.Errorf("Failed to remove old local addr rules: %v", err)
}
f.mu.Lock()
f.localAddrRules = nil
f.mu.Unlock()
// add new rules // add new rules
addedByPrefix, err := f.addPermissiveRulesForPrefixes(addrs, "bypass for local addr ") addedByPrefix, err := f.addPermissiveRulesForPrefixes(addrs, "bypass for local addr ")
@@ -156,6 +151,21 @@ func (f *WindowsFirewall) AllowLocalNetworks(addrs []netip.Prefix) error {
return nil return nil
} }
func (f *WindowsFirewall) RemoveLocalNetworks() error {
if err := f.removeRules(f.localAddrRules); err != nil {
f.logger.Errorf("Failed to remove old local addr rules: %v", err)
}
f.mu.Lock()
f.localAddrRules = nil
f.mu.Unlock()
return nil
}
func (f *WindowsFirewall) IsAllowLocalNetworksEnabled() bool {
return f.localAddrRules != nil
}
func (f *WindowsFirewall) UpdatePermittedRoutes(newRoutes []netip.Prefix) error { func (f *WindowsFirewall) UpdatePermittedRoutes(newRoutes []netip.Prefix) error {
f.mu.Lock() f.mu.Lock()
// routes to remove // routes to remove