commit 478aef6952f18524a9527526b0a1db2d4a3cdb1e Author: zaneschepke Date: Tue Feb 3 06:40:04 2026 -0500 initial commit diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..367c7a6 --- /dev/null +++ b/.gitignore @@ -0,0 +1,21 @@ +*.iml +.kotlin +.gradle +**/build/ +xcuserdata +!src/**/build/ +local.properties +output +.idea +.DS_Store +captures +.externalNativeBuild +.cxx +*.xcodeproj/* +!*.xcodeproj/project.pbxproj +!*.xcodeproj/xcshareddata/ +!*.xcodeproj/project.xcworkspace/ +!*.xcworkspace/contents.xcworkspacedata +**/xcshareddata/WorkspaceSettings.xcsettings +node_modules/ +composeApp/generated.conveyor.conf diff --git a/.gitmodules b/.gitmodules new file mode 100644 index 0000000..604db4d --- /dev/null +++ b/.gitmodules @@ -0,0 +1,6 @@ +[submodule "daemon/winsw"] + path = daemon/winsw + url = https://github.com/wgtunnel/winsw +[submodule "tunnel/tools/amneziawg-tools"] + path = tunnel/tools/amneziawg-tools + url = https://github.com/amnezia-vpn/amneziawg-tools diff --git a/README.md b/README.md new file mode 100644 index 0000000..6525db6 --- /dev/null +++ b/README.md @@ -0,0 +1,8 @@ +# WG Tunnel - Desktop + +A WIP project for WG Tunnel desktop. + +## Supported Platforms +- macOS (Future) +- Windows +- Linux \ No newline at end of file diff --git a/build.gradle.kts b/build.gradle.kts new file mode 100644 index 0000000..06c7951 --- /dev/null +++ b/build.gradle.kts @@ -0,0 +1,52 @@ +// build.gradle.kts +plugins { + alias(libs.plugins.composeHotReload) apply false + alias(libs.plugins.jetbrainsCompose) apply false + alias(libs.plugins.composeCompiler) apply false + alias(libs.plugins.kotlinMultiplatform) apply false + alias(libs.plugins.conveyor) apply false + alias(libs.plugins.moko) apply false + alias(libs.plugins.buildconfig) apply false +} + +val jvmVersion = libs.versions.jvm.get().toInt() +version = libs.versions.app.get() + +allprojects { + group = "com.zaneschepke.wireguardautotunnel" + version = version + plugins.withId("org.jetbrains.kotlin.jvm") { + extensions.configure { + jvmToolchain(jvmVersion) + } + } + + plugins.withId("org.jetbrains.kotlin.multiplatform") { + extensions.configure { + jvmToolchain(jvmVersion) + } + } +} + +registerConveyorTask( + taskName = "buildLinuxDeb", + packageType = "debian-package", + subDir = "deb", +) + +registerConveyorTask( + taskName = "buildWindowsMsix", + packageType = "windows-msix", + subDir = "windows", +) + +registerConveyorTask( + taskName = "buildConveyorSite", + packageType = "site", + subDir = "site" +) + + +tasks.register("clean") { + delete(layout.buildDirectory) +} \ No newline at end of file diff --git a/buildSrc/build.gradle.kts b/buildSrc/build.gradle.kts new file mode 100644 index 0000000..e16cc6d --- /dev/null +++ b/buildSrc/build.gradle.kts @@ -0,0 +1,13 @@ +plugins { + `kotlin-dsl` // enable the Kotlin-DSL +} + +repositories { + gradlePluginPortal() + mavenCentral() + google() +} + +dependencies { + implementation("org.apache.commons:commons-lang3:3.20.0") +} diff --git a/buildSrc/settings.gradle.kts b/buildSrc/settings.gradle.kts new file mode 100644 index 0000000..3f87d39 --- /dev/null +++ b/buildSrc/settings.gradle.kts @@ -0,0 +1 @@ +rootProject.name = "buildSrc" \ No newline at end of file diff --git a/buildSrc/src/main/kotlin/LocalProperties.kt b/buildSrc/src/main/kotlin/LocalProperties.kt new file mode 100644 index 0000000..7e4b8e5 --- /dev/null +++ b/buildSrc/src/main/kotlin/LocalProperties.kt @@ -0,0 +1,19 @@ +import java.io.File +import java.io.FileInputStream +import java.util.* + +object LocalProperties { + + private val properties by lazy { + val props = Properties() + val file = File("local.properties") + if (file.exists()) { + FileInputStream(file).use { props.load(it) } + } + props + } + + fun get(key: String): String? = properties.getProperty(key) + + fun getOrDefault(key: String, default: String): String = properties.getProperty(key, default) +} \ No newline at end of file diff --git a/buildSrc/src/main/kotlin/System.kt b/buildSrc/src/main/kotlin/System.kt new file mode 100644 index 0000000..fc831ef --- /dev/null +++ b/buildSrc/src/main/kotlin/System.kt @@ -0,0 +1,5 @@ +object SystemVar { + fun fromEnvironment(envVar : String) : String? { + return System.getenv(envVar) + } +} \ No newline at end of file diff --git a/buildSrc/src/main/kotlin/Tasks.kt b/buildSrc/src/main/kotlin/Tasks.kt new file mode 100644 index 0000000..8e5a92b --- /dev/null +++ b/buildSrc/src/main/kotlin/Tasks.kt @@ -0,0 +1,38 @@ +import org.gradle.api.Project +import org.gradle.api.tasks.Exec +import org.gradle.kotlin.dsl.register + + +fun Project.registerConveyorTask( + taskName: String, + packageType: String, + subDir: String, +) { + tasks.register(taskName) { + group = "distribution" + val outputDir = layout.buildDirectory.dir("conveyor/$subDir") + outputs.dir(outputDir) + + environment( + "CONVEYOR_PASSPHRASE", + SystemVar.fromEnvironment("CONVEYOR_PASSPHRASE") ?: LocalProperties.get("conveyor.passphrase") ?: "" + ) + + val args = mutableListOf( + "conveyor", + "--passphrase=env:CONVEYOR_PASSPHRASE", + "make", + "--output-dir", outputDir.get().asFile.absolutePath, + packageType + ) + + commandLine(args) + + dependsOn( + ":composeApp:createDistributable", + ":cli:installDist", + ":daemon:installDist", + ":composeApp:writeConveyorConfig" + ) + } +} \ No newline at end of file diff --git a/cli/.gitignore b/cli/.gitignore new file mode 100644 index 0000000..42afabf --- /dev/null +++ b/cli/.gitignore @@ -0,0 +1 @@ +/build \ No newline at end of file diff --git a/cli/build.gradle.kts b/cli/build.gradle.kts new file mode 100644 index 0000000..d2ff5d4 --- /dev/null +++ b/cli/build.gradle.kts @@ -0,0 +1,49 @@ +plugins { + application + alias(libs.plugins.serialization) + kotlin("jvm") + kotlin("kapt") +} + +dependencies { + implementation(project(":client")) + // CLI + implementation(libs.picocli) + kapt(libs.picocli.codegen) + + // DI + implementation(libs.koin.core) + + // Logging + implementation(libs.kermit) + implementation(libs.logback.classic) + + implementation(libs.kotlinx.serialization) + implementation(libs.kotlinx.coroutines.core) +} + +kapt { + arguments { + arg("project", "${project.group}/${project.name}") + } +} + +tasks.named("installDist") { + duplicatesStrategy = DuplicatesStrategy.EXCLUDE +} + +application { + mainClass.set("com.zaneschepke.wireguardautotunnel.cli.MainKt") +} + +tasks.withType { + manifest { + attributes( + "Implementation-Title" to project.name, + "Implementation-Version" to libs.versions.app.get(), + "Main-Class" to application.mainClass.get() + ) + } +} + + diff --git a/cli/src/main/java/com/zaneschepke/wireguardautotunnel/cli/CliRoot.kt b/cli/src/main/java/com/zaneschepke/wireguardautotunnel/cli/CliRoot.kt new file mode 100644 index 0000000..69e9fcf --- /dev/null +++ b/cli/src/main/java/com/zaneschepke/wireguardautotunnel/cli/CliRoot.kt @@ -0,0 +1,32 @@ +package com.zaneschepke.wireguardautotunnel.cli + +import com.zaneschepke.wireguardautotunnel.cli.CliRoot.Companion.BANNER +import com.zaneschepke.wireguardautotunnel.cli.commands.tunnel.TunnelCommand +import com.zaneschepke.wireguardautotunnel.cli.provider.ManifestVersionProvider +import picocli.CommandLine.Command + +@Command( + name = "wgtunnel", + description = ["CLI client for WG Tunnel."], + mixinStandardHelpOptions = true, + versionProvider = ManifestVersionProvider::class, + header = [BANNER], + subcommands = [ + TunnelCommand::class + ] +) +class CliRoot : Runnable { + override fun run() { + + } + companion object { + const val BANNER: String = ("" + + "██╗ ██╗ ██████╗ ████████╗██╗ ██╗███╗ ██╗███╗ ██╗███████╗██╗ \n" + + "██║ ██║██╔════╝ ╚══██╔══╝██║ ██║████╗ ██║████╗ ██║██╔════╝██║ \n" + + "██║ █╗ ██║██║ ███╗ ██║ ██║ ██║██╔██╗ ██║██╔██╗ ██║█████╗ ██║ \n" + + "██║███╗██║██║ ██║ ██║ ██║ ██║██║╚██╗██║██║╚██╗██║██╔══╝ ██║ \n" + + "╚███╔███╔╝╚██████╔╝ ██║ ╚██████╔╝██║ ╚████║██║ ╚████║███████╗███████╗\n" + + " ╚══╝╚══╝ ╚═════╝ ╚═╝ ╚═════╝ ╚═╝ ╚═══╝╚═╝ ╚═══╝╚══════╝╚══════╝\n") + + } +} \ No newline at end of file diff --git a/cli/src/main/java/com/zaneschepke/wireguardautotunnel/cli/Main.kt b/cli/src/main/java/com/zaneschepke/wireguardautotunnel/cli/Main.kt new file mode 100644 index 0000000..7309e3b --- /dev/null +++ b/cli/src/main/java/com/zaneschepke/wireguardautotunnel/cli/Main.kt @@ -0,0 +1,24 @@ +package com.zaneschepke.wireguardautotunnel.cli + +import co.touchlab.kermit.Logger +import co.touchlab.kermit.Severity +import co.touchlab.kermit.platformLogWriter +import com.zaneschepke.wireguardautotunnel.cli.commands.handler.CliExceptionHandler +import com.zaneschepke.wireguardautotunnel.cli.strategy.CliExecutionStrategy +import com.zaneschepke.wireguardautotunnel.client.di.databaseModule +import com.zaneschepke.wireguardautotunnel.client.di.serviceModule +import org.koin.core.context.startKoin +import picocli.CommandLine + +fun main(args: Array) { + Logger.setLogWriters(platformLogWriter()) + Logger.setMinSeverity(Severity.Debug) + Logger.setTag("CLI") + startKoin { + modules(databaseModule, serviceModule) + } + val commandLine = CommandLine(CliRoot::class.java) + commandLine.executionStrategy = CliExecutionStrategy(commandLine.executionStrategy) + commandLine.executionExceptionHandler = CliExceptionHandler() + commandLine.execute(*args) +} \ No newline at end of file diff --git a/cli/src/main/java/com/zaneschepke/wireguardautotunnel/cli/commands/handler/CliExceptionHandler.kt b/cli/src/main/java/com/zaneschepke/wireguardautotunnel/cli/commands/handler/CliExceptionHandler.kt new file mode 100644 index 0000000..c7a038b --- /dev/null +++ b/cli/src/main/java/com/zaneschepke/wireguardautotunnel/cli/commands/handler/CliExceptionHandler.kt @@ -0,0 +1,18 @@ +package com.zaneschepke.wireguardautotunnel.cli.commands.handler + +import picocli.CommandLine +import picocli.CommandLine.IExecutionExceptionHandler +import picocli.CommandLine.ParseResult + +class CliExceptionHandler : IExecutionExceptionHandler { + override fun handleExecutionException( + ex: Exception, + commandLine: CommandLine, + parseResult: ParseResult + ): Int { + commandLine.err.println( + commandLine.colorScheme.errorText("Error completing command: ${ex.message}") + ) + return CommandLine.ExitCode.SOFTWARE + } +} diff --git a/cli/src/main/java/com/zaneschepke/wireguardautotunnel/cli/commands/tunnel/TunnelCommand.kt b/cli/src/main/java/com/zaneschepke/wireguardautotunnel/cli/commands/tunnel/TunnelCommand.kt new file mode 100644 index 0000000..a67e003 --- /dev/null +++ b/cli/src/main/java/com/zaneschepke/wireguardautotunnel/cli/commands/tunnel/TunnelCommand.kt @@ -0,0 +1,20 @@ +package com.zaneschepke.wireguardautotunnel.cli.commands.tunnel + +import picocli.CommandLine.Command + +@Command( + name = "tunnel", + mixinStandardHelpOptions = true, + subcommands = [ + TunnelUpCommand::class, + TunnelDownCommand::class, + TunnelImportCommand::class, + TunnelListCommand::class, + TunnelDeleteCommand::class, + ] +) +class TunnelCommand : Runnable { + override fun run() { + println("Please specify a subcommand: start, stop, list, etc..") + } +} \ No newline at end of file diff --git a/cli/src/main/java/com/zaneschepke/wireguardautotunnel/cli/commands/tunnel/TunnelDeleteCommand.kt b/cli/src/main/java/com/zaneschepke/wireguardautotunnel/cli/commands/tunnel/TunnelDeleteCommand.kt new file mode 100644 index 0000000..eba4e2d --- /dev/null +++ b/cli/src/main/java/com/zaneschepke/wireguardautotunnel/cli/commands/tunnel/TunnelDeleteCommand.kt @@ -0,0 +1,38 @@ +package com.zaneschepke.wireguardautotunnel.cli.commands.tunnel + +import com.zaneschepke.wireguardautotunnel.client.domain.repository.TunnelRepository +import kotlinx.coroutines.runBlocking +import org.koin.java.KoinJavaComponent.inject +import picocli.CommandLine.* +import java.util.concurrent.Callable + +@Command( + name = "delete", + description = ["Delete a tunnel."], +) +class TunnelDeleteCommand : Callable { + + private val tunnelRepository: TunnelRepository by inject(TunnelRepository::class.java) + + @Option(names = ["-y", "--yes"], description = ["Delete without additional prompts."]) + var yes: Boolean? = null + + @Parameters(index = "0", paramLabel = "", description = ["The name of the tunnel to bring up."]) + lateinit var tunnelName: String + + override fun call(): Int = runBlocking { + if(yes == null) { + print("Are you sure you want to delete $tunnelName? [y/N]: ") + val userInput = readlnOrNull()?.trim()?.lowercase() + if (userInput != "y" && userInput != "yes") return@runBlocking 0 + } + try { + tunnelRepository.deleteByName(tunnelName) + } catch (_: Exception) { + System.err.println("Failed to delete $tunnelName! Check that the service is running.") + return@runBlocking 1 + } + + return@runBlocking 0 + } +} \ No newline at end of file diff --git a/cli/src/main/java/com/zaneschepke/wireguardautotunnel/cli/commands/tunnel/TunnelDownCommand.kt b/cli/src/main/java/com/zaneschepke/wireguardautotunnel/cli/commands/tunnel/TunnelDownCommand.kt new file mode 100644 index 0000000..2050446 --- /dev/null +++ b/cli/src/main/java/com/zaneschepke/wireguardautotunnel/cli/commands/tunnel/TunnelDownCommand.kt @@ -0,0 +1,30 @@ +package com.zaneschepke.wireguardautotunnel.cli.commands.tunnel + +import com.zaneschepke.wireguardautotunnel.client.domain.repository.TunnelRepository +import com.zaneschepke.wireguardautotunnel.client.service.TunnelCommandService +import kotlinx.coroutines.runBlocking +import org.koin.java.KoinJavaComponent.inject +import picocli.CommandLine.Command +import picocli.CommandLine.Parameters + +@Command(name = "down", description = ["Bring a tunnel down."]) +class TunnelDownCommand : Runnable { + private val tunnelService: TunnelCommandService by inject(TunnelCommandService::class.java) + + private val tunnelRepository: TunnelRepository by inject(TunnelRepository::class.java) + + @Parameters(index = "0", paramLabel = "", description = ["The name of the tunnel to bring down."]) + lateinit var tunnelName: String + + override fun run() { + runBlocking { + val tunnel = tunnelRepository.getTunnelByName(tunnelName) ?: return@runBlocking println("Tunnel $tunnelName not found") + val result = tunnelService.stopTunnel(tunnel.id) + if (result.isSuccess) { + println("Tunnel stopped successfully.") + } else { + println("Failed to stop tunnel: ${result.exceptionOrNull()?.message ?: "Unknown error"}") + } + } + } +} \ No newline at end of file diff --git a/cli/src/main/java/com/zaneschepke/wireguardautotunnel/cli/commands/tunnel/TunnelImportCommand.kt b/cli/src/main/java/com/zaneschepke/wireguardautotunnel/cli/commands/tunnel/TunnelImportCommand.kt new file mode 100644 index 0000000..9a12231 --- /dev/null +++ b/cli/src/main/java/com/zaneschepke/wireguardautotunnel/cli/commands/tunnel/TunnelImportCommand.kt @@ -0,0 +1,67 @@ +package com.zaneschepke.wireguardautotunnel.cli.commands.tunnel + +import com.zaneschepke.wireguardautotunnel.client.service.TunnelImportService +import kotlinx.coroutines.runBlocking +import org.koin.java.KoinJavaComponent.inject +import picocli.CommandLine.* +import java.io.File +import java.util.concurrent.Callable + +@Command( + name = "import", + description = ["Import configuration from a file, string, or stdin."] +) +class TunnelImportCommand : Callable { + + private val tunnelImportService: TunnelImportService by inject(TunnelImportService::class.java) + + @ArgGroup(exclusive = true, multiplicity = "1") + lateinit var input: Input + + class Input { + @Option(names = ["--file"], description = ["Import config from file"]) + var file: File? = null + + @Option(names = ["--string"], description = ["Import config from string literal"]) + var string: String? = null + } + + @Option(names = ["--name"], description = ["Specify a tunnel name"]) + var name: String? = null + + override fun call(): Int = runBlocking { + val config : String = try { + when { + input.file != null -> { + val f = input.file!! + if (!f.exists()) { + System.err.println("Error: File does not exist: ${f.absolutePath}") + return@runBlocking 1 + } + if (!f.isFile) { + System.err.println("Error: Not a file: ${f.absolutePath}") + return@runBlocking 1 + } + f.readText() + } + + input.string != null -> input.string!! + + else -> { + System.err.println("Error: No input source provided. Use --file, --string, or - for stdin.") + return@runBlocking 1 + } + } + } catch (e: Exception) { + System.err.println("Error reading input: ${e.message}") + return@runBlocking 1 + } + + val name = name ?: input.file?.nameWithoutExtension + + tunnelImportService.import(config , name) + + return@runBlocking 0 + } + +} diff --git a/cli/src/main/java/com/zaneschepke/wireguardautotunnel/cli/commands/tunnel/TunnelListCommand.kt b/cli/src/main/java/com/zaneschepke/wireguardautotunnel/cli/commands/tunnel/TunnelListCommand.kt new file mode 100644 index 0000000..9e6d791 --- /dev/null +++ b/cli/src/main/java/com/zaneschepke/wireguardautotunnel/cli/commands/tunnel/TunnelListCommand.kt @@ -0,0 +1,48 @@ +package com.zaneschepke.wireguardautotunnel.cli.commands.tunnel + +import co.touchlab.kermit.Logger +import com.zaneschepke.wireguardautotunnel.client.domain.repository.TunnelRepository +import kotlinx.coroutines.runBlocking +import kotlinx.serialization.json.Json +import org.koin.java.KoinJavaComponent.inject +import picocli.CommandLine.Command +import picocli.CommandLine.Option +import java.util.concurrent.Callable + +@Command( + name = "list", + description = ["List configured WG Tunnel tunnels."] +) +class TunnelListCommand : Callable { + + private val tunnelRepository: TunnelRepository by inject(TunnelRepository::class.java) + + @Option(names = ["--json"], description = ["Output in JSON format for scripting."]) + var json: Boolean = false + + override fun call(): Int = runBlocking { + val tunnels = try { + tunnelRepository.getAll().sortedBy { it.position } + } catch (e: Exception) { + Logger.e("failed to load tunnels", e) + System.err.println("Error: Failed to retrieve tunnels. ${e.message}") + return@runBlocking 1 + } + + if (tunnels.isEmpty()) { + println("No tunnels found") + return@runBlocking 0 + } + + if (json) { + val names = tunnels.map { it.name } + println(Json.encodeToString(names)) + } else { + // TODO better strategy for large number of tunnels + println("Configured Tunnels:") + tunnels.forEach { println(it.name) } + } + + return@runBlocking 0 + } +} diff --git a/cli/src/main/java/com/zaneschepke/wireguardautotunnel/cli/commands/tunnel/TunnelUpCommand.kt b/cli/src/main/java/com/zaneschepke/wireguardautotunnel/cli/commands/tunnel/TunnelUpCommand.kt new file mode 100644 index 0000000..e704f46 --- /dev/null +++ b/cli/src/main/java/com/zaneschepke/wireguardautotunnel/cli/commands/tunnel/TunnelUpCommand.kt @@ -0,0 +1,29 @@ +package com.zaneschepke.wireguardautotunnel.cli.commands.tunnel + +import com.zaneschepke.wireguardautotunnel.client.domain.repository.TunnelRepository +import com.zaneschepke.wireguardautotunnel.client.service.TunnelCommandService +import kotlinx.coroutines.runBlocking +import org.koin.java.KoinJavaComponent.inject +import picocli.CommandLine.Command +import picocli.CommandLine.Parameters + +@Command(name = "up", description = ["Bring a tunnel up."]) +class TunnelUpCommand : Runnable { + private val tunnelService: TunnelCommandService by inject(TunnelCommandService::class.java) + private val tunnelRepository: TunnelRepository by inject(TunnelRepository::class.java) + + @Parameters(index = "0", paramLabel = "", description = ["The name of the tunnel to bring up."]) + lateinit var tunnelName: String + + override fun run() { + runBlocking { + val tunnel = tunnelRepository.getTunnelByName(tunnelName) ?: return@runBlocking println("Failed to find the $tunnelName") + val result = tunnelService.startTunnel(tunnel.id) + if (result.isSuccess) { + println("Tunnel start triggered successfully.") + } else { + println("Failed to start tunnel: ${result.exceptionOrNull()?.message ?: "Unknown error"}") + } + } + } +} \ No newline at end of file diff --git a/cli/src/main/java/com/zaneschepke/wireguardautotunnel/cli/provider/ManifestVersionProvider.kt b/cli/src/main/java/com/zaneschepke/wireguardautotunnel/cli/provider/ManifestVersionProvider.kt new file mode 100644 index 0000000..a6177d6 --- /dev/null +++ b/cli/src/main/java/com/zaneschepke/wireguardautotunnel/cli/provider/ManifestVersionProvider.kt @@ -0,0 +1,10 @@ +package com.zaneschepke.wireguardautotunnel.cli.provider + +import picocli.CommandLine + +class ManifestVersionProvider : CommandLine.IVersionProvider { + override fun getVersion(): Array { + val version = ManifestVersionProvider::class.java.getPackage().implementationVersion + return if (version != null) arrayOf(version) else arrayOf("Unknown version") + } +} \ No newline at end of file diff --git a/cli/src/main/java/com/zaneschepke/wireguardautotunnel/cli/strategy/CliExecutionStrategy.kt b/cli/src/main/java/com/zaneschepke/wireguardautotunnel/cli/strategy/CliExecutionStrategy.kt new file mode 100644 index 0000000..1a903cb --- /dev/null +++ b/cli/src/main/java/com/zaneschepke/wireguardautotunnel/cli/strategy/CliExecutionStrategy.kt @@ -0,0 +1,33 @@ +package com.zaneschepke.wireguardautotunnel.cli.strategy + +import com.zaneschepke.wireguardautotunnel.client.service.DaemonHealthService +import kotlinx.coroutines.runBlocking +import org.koin.java.KoinJavaComponent.inject +import picocli.CommandLine.* + +class CliExecutionStrategy(private val defaultStrategy: IExecutionStrategy) : IExecutionStrategy { + + val daemonHealthService : DaemonHealthService by inject(DaemonHealthService::class.java) + + override fun execute(parseResult: ParseResult): Int = runBlocking { + // Drill down to the deepest subcommand + var current = parseResult + while (current.hasSubcommand()) { + current = current.subcommand() + } + val commandSpec = current.commandSpec() + + val skipCheck = parseResult.isUsageHelpRequested || parseResult.isVersionHelpRequested + +// if (!skipCheck && !daemonHealthService.alive()) { +// throw ExecutionException( +// commandSpec.commandLine(), +// "The WG Tunnel service must be installed and started to execute this command. " + +// "Install and start it with 'wgtunnel service install -y' or, if already installed, " + +// "start the service with 'wgtunnel service start'." +// ) +// } + + return@runBlocking defaultStrategy.execute(parseResult) + } +} \ No newline at end of file diff --git a/client/.gitignore b/client/.gitignore new file mode 100644 index 0000000..42afabf --- /dev/null +++ b/client/.gitignore @@ -0,0 +1 @@ +/build \ No newline at end of file diff --git a/client/build.gradle.kts b/client/build.gradle.kts new file mode 100644 index 0000000..5f1557f --- /dev/null +++ b/client/build.gradle.kts @@ -0,0 +1,56 @@ +import dev.icerock.gradle.MRVisibility + +plugins { + alias(libs.plugins.kotlinMultiplatform) + alias(libs.plugins.ksp) + alias(libs.plugins.room) + alias(libs.plugins.serialization) + alias(libs.plugins.moko) +} + +kotlin { + jvm() + + sourceSets { + val commonMain by getting { + dependencies { + implementation(project(":parser")) + implementation(project(":keyring")) + implementation(project(":core")) + implementation(libs.androidx.room.runtime) + implementation(libs.androidx.sqlite.bundled) + + implementation(libs.kermit) + implementation(libs.logback.classic) + + implementation(libs.kotlinx.serialization) + + api(libs.moko.core) + api(libs.moko.compose) + + // DI + implementation(libs.koin.core) + + implementation(libs.bundles.ktor.client.jvm) + + // Util + implementation(libs.apache.commons.lang3) + } + } + } +} + +dependencies { + "kspJvm"(libs.androidx.room.compiler) +} + +room { schemaDirectory("$projectDir/schemas") } + +multiplatformResources { + resourcesPackage.set("com.zaneschepke.wireguardautotunnel") + resourcesClassName.set("SharedRes") + resourcesVisibility.set(MRVisibility.Public) +} + + + diff --git a/client/schemas/com.zaneschepke.wireguardautotunnel.client.data.AppDatabase/1.json b/client/schemas/com.zaneschepke.wireguardautotunnel.client.data.AppDatabase/1.json new file mode 100644 index 0000000..23dd2f2 --- /dev/null +++ b/client/schemas/com.zaneschepke.wireguardautotunnel.client.data.AppDatabase/1.json @@ -0,0 +1,341 @@ +{ + "formatVersion": 1, + "database": { + "version": 1, + "identityHash": "d66aaaa9eeab5a2e84406838017246b1", + "entities": [ + { + "tableName": "tunnel_config", + "createSql": "CREATE TABLE IF NOT EXISTS `${TABLE_NAME}` (`id` INTEGER PRIMARY KEY AUTOINCREMENT NOT NULL, `name` TEXT NOT NULL, `quick_config` TEXT NOT NULL, `tunnel_networks` TEXT NOT NULL DEFAULT '', `is_primary_tunnel` INTEGER NOT NULL DEFAULT false, `active` INTEGER NOT NULL DEFAULT false, `ping_target` TEXT DEFAULT null, `is_ethernet_tunnel` INTEGER NOT NULL DEFAULT false, `is_ipv4_preferred` INTEGER NOT NULL DEFAULT true, `position` INTEGER NOT NULL DEFAULT 0)", + "fields": [ + { + "fieldPath": "id", + "columnName": "id", + "affinity": "INTEGER", + "notNull": true + }, + { + "fieldPath": "name", + "columnName": "name", + "affinity": "TEXT", + "notNull": true + }, + { + "fieldPath": "quickConfig", + "columnName": "quick_config", + "affinity": "TEXT", + "notNull": true + }, + { + "fieldPath": "tunnelNetworks", + "columnName": "tunnel_networks", + "affinity": "TEXT", + "notNull": true, + "defaultValue": "''" + }, + { + "fieldPath": "isPrimaryTunnel", + "columnName": "is_primary_tunnel", + "affinity": "INTEGER", + "notNull": true, + "defaultValue": "false" + }, + { + "fieldPath": "active", + "columnName": "active", + "affinity": "INTEGER", + "notNull": true, + "defaultValue": "false" + }, + { + "fieldPath": "pingTarget", + "columnName": "ping_target", + "affinity": "TEXT", + "defaultValue": "null" + }, + { + "fieldPath": "isEthernetTunnel", + "columnName": "is_ethernet_tunnel", + "affinity": "INTEGER", + "notNull": true, + "defaultValue": "false" + }, + { + "fieldPath": "isIpv4Preferred", + "columnName": "is_ipv4_preferred", + "affinity": "INTEGER", + "notNull": true, + "defaultValue": "true" + }, + { + "fieldPath": "position", + "columnName": "position", + "affinity": "INTEGER", + "notNull": true, + "defaultValue": "0" + } + ], + "primaryKey": { + "autoGenerate": true, + "columnNames": [ + "id" + ] + }, + "indices": [ + { + "name": "index_tunnel_config_name", + "unique": true, + "columnNames": [ + "name" + ], + "orders": [], + "createSql": "CREATE UNIQUE INDEX IF NOT EXISTS `index_tunnel_config_name` ON `${TABLE_NAME}` (`name`)" + } + ] + }, + { + "tableName": "proxy_settings", + "createSql": "CREATE TABLE IF NOT EXISTS `${TABLE_NAME}` (`id` INTEGER PRIMARY KEY AUTOINCREMENT NOT NULL, `socks5_proxy_enabled` INTEGER NOT NULL DEFAULT 0, `socks5_proxy_bind_address` TEXT, `http_proxy_enable` INTEGER NOT NULL DEFAULT 0, `http_proxy_bind_address` TEXT, `proxy_username` TEXT, `proxy_password` TEXT)", + "fields": [ + { + "fieldPath": "id", + "columnName": "id", + "affinity": "INTEGER", + "notNull": true + }, + { + "fieldPath": "socks5ProxyEnabled", + "columnName": "socks5_proxy_enabled", + "affinity": "INTEGER", + "notNull": true, + "defaultValue": "0" + }, + { + "fieldPath": "socks5ProxyBindAddress", + "columnName": "socks5_proxy_bind_address", + "affinity": "TEXT" + }, + { + "fieldPath": "httpProxyEnabled", + "columnName": "http_proxy_enable", + "affinity": "INTEGER", + "notNull": true, + "defaultValue": "0" + }, + { + "fieldPath": "httpProxyBindAddress", + "columnName": "http_proxy_bind_address", + "affinity": "TEXT" + }, + { + "fieldPath": "proxyUsername", + "columnName": "proxy_username", + "affinity": "TEXT" + }, + { + "fieldPath": "proxyPassword", + "columnName": "proxy_password", + "affinity": "TEXT" + } + ], + "primaryKey": { + "autoGenerate": true, + "columnNames": [ + "id" + ] + } + }, + { + "tableName": "lockdown_settings", + "createSql": "CREATE TABLE IF NOT EXISTS `${TABLE_NAME}` (`id` INTEGER PRIMARY KEY AUTOINCREMENT NOT NULL, `bypass_lan` INTEGER NOT NULL DEFAULT 0)", + "fields": [ + { + "fieldPath": "id", + "columnName": "id", + "affinity": "INTEGER", + "notNull": true + }, + { + "fieldPath": "bypassLan", + "columnName": "bypass_lan", + "affinity": "INTEGER", + "notNull": true, + "defaultValue": "0" + } + ], + "primaryKey": { + "autoGenerate": true, + "columnNames": [ + "id" + ] + } + }, + { + "tableName": "general_settings", + "createSql": "CREATE TABLE IF NOT EXISTS `${TABLE_NAME}` (`id` INTEGER PRIMARY KEY AUTOINCREMENT NOT NULL, `is_restore_on_boot_enabled` INTEGER NOT NULL DEFAULT 0, `app_mode` INTEGER NOT NULL DEFAULT 0, `theme` TEXT NOT NULL DEFAULT 'AUTOMATIC', `locale` TEXT, `already_donated` INTEGER NOT NULL DEFAULT 0)", + "fields": [ + { + "fieldPath": "id", + "columnName": "id", + "affinity": "INTEGER", + "notNull": true + }, + { + "fieldPath": "isRestoreOnBootEnabled", + "columnName": "is_restore_on_boot_enabled", + "affinity": "INTEGER", + "notNull": true, + "defaultValue": "0" + }, + { + "fieldPath": "appMode", + "columnName": "app_mode", + "affinity": "INTEGER", + "notNull": true, + "defaultValue": "0" + }, + { + "fieldPath": "theme", + "columnName": "theme", + "affinity": "TEXT", + "notNull": true, + "defaultValue": "'AUTOMATIC'" + }, + { + "fieldPath": "locale", + "columnName": "locale", + "affinity": "TEXT" + }, + { + "fieldPath": "alreadyDonated", + "columnName": "already_donated", + "affinity": "INTEGER", + "notNull": true, + "defaultValue": "0" + } + ], + "primaryKey": { + "autoGenerate": true, + "columnNames": [ + "id" + ] + } + }, + { + "tableName": "dns_settings", + "createSql": "CREATE TABLE IF NOT EXISTS `${TABLE_NAME}` (`id` INTEGER PRIMARY KEY AUTOINCREMENT NOT NULL, `dns_protocol` INTEGER NOT NULL DEFAULT 0, `dns_endpoint` TEXT, `global_tunnel_dns_enabled` INTEGER NOT NULL DEFAULT 0)", + "fields": [ + { + "fieldPath": "id", + "columnName": "id", + "affinity": "INTEGER", + "notNull": true + }, + { + "fieldPath": "dnsProtocol", + "columnName": "dns_protocol", + "affinity": "INTEGER", + "notNull": true, + "defaultValue": "0" + }, + { + "fieldPath": "dnsEndpoint", + "columnName": "dns_endpoint", + "affinity": "TEXT" + }, + { + "fieldPath": "isGlobalTunnelDnsEnabled", + "columnName": "global_tunnel_dns_enabled", + "affinity": "INTEGER", + "notNull": true, + "defaultValue": "0" + } + ], + "primaryKey": { + "autoGenerate": true, + "columnNames": [ + "id" + ] + } + }, + { + "tableName": "auto_tunnel_settings", + "createSql": "CREATE TABLE IF NOT EXISTS `${TABLE_NAME}` (`id` INTEGER PRIMARY KEY AUTOINCREMENT NOT NULL, `is_tunnel_enabled` INTEGER NOT NULL DEFAULT 0, `trusted_network_ssids` TEXT NOT NULL DEFAULT '', `is_tunnel_on_ethernet_enabled` INTEGER NOT NULL DEFAULT 0, `is_tunnel_on_wifi_enabled` INTEGER NOT NULL DEFAULT 0, `is_wildcards_enabled` INTEGER NOT NULL DEFAULT 0, `is_stop_on_no_internet_enabled` INTEGER NOT NULL DEFAULT 0, `is_tunnel_on_unsecure_enabled` INTEGER NOT NULL DEFAULT 0, `start_on_boot` INTEGER NOT NULL DEFAULT 0)", + "fields": [ + { + "fieldPath": "id", + "columnName": "id", + "affinity": "INTEGER", + "notNull": true + }, + { + "fieldPath": "isAutoTunnelEnabled", + "columnName": "is_tunnel_enabled", + "affinity": "INTEGER", + "notNull": true, + "defaultValue": "0" + }, + { + "fieldPath": "trustedNetworkSSIDs", + "columnName": "trusted_network_ssids", + "affinity": "TEXT", + "notNull": true, + "defaultValue": "''" + }, + { + "fieldPath": "isTunnelOnEthernetEnabled", + "columnName": "is_tunnel_on_ethernet_enabled", + "affinity": "INTEGER", + "notNull": true, + "defaultValue": "0" + }, + { + "fieldPath": "isTunnelOnWifiEnabled", + "columnName": "is_tunnel_on_wifi_enabled", + "affinity": "INTEGER", + "notNull": true, + "defaultValue": "0" + }, + { + "fieldPath": "isWildcardsEnabled", + "columnName": "is_wildcards_enabled", + "affinity": "INTEGER", + "notNull": true, + "defaultValue": "0" + }, + { + "fieldPath": "isStopOnNoInternetEnabled", + "columnName": "is_stop_on_no_internet_enabled", + "affinity": "INTEGER", + "notNull": true, + "defaultValue": "0" + }, + { + "fieldPath": "isTunnelOnUnsecureEnabled", + "columnName": "is_tunnel_on_unsecure_enabled", + "affinity": "INTEGER", + "notNull": true, + "defaultValue": "0" + }, + { + "fieldPath": "startOnBoot", + "columnName": "start_on_boot", + "affinity": "INTEGER", + "notNull": true, + "defaultValue": "0" + } + ], + "primaryKey": { + "autoGenerate": true, + "columnNames": [ + "id" + ] + } + } + ], + "setupQueries": [ + "CREATE TABLE IF NOT EXISTS room_master_table (id INTEGER PRIMARY KEY,identity_hash TEXT)", + "INSERT OR REPLACE INTO room_master_table (id,identity_hash) VALUES(42, 'd66aaaa9eeab5a2e84406838017246b1')" + ] + } +} \ No newline at end of file diff --git a/client/schemas/com.zaneschepke.wireguardautotunnel.shared.data.AppDatabase/1.json b/client/schemas/com.zaneschepke.wireguardautotunnel.shared.data.AppDatabase/1.json new file mode 100644 index 0000000..23dd2f2 --- /dev/null +++ b/client/schemas/com.zaneschepke.wireguardautotunnel.shared.data.AppDatabase/1.json @@ -0,0 +1,341 @@ +{ + "formatVersion": 1, + "database": { + "version": 1, + "identityHash": "d66aaaa9eeab5a2e84406838017246b1", + "entities": [ + { + "tableName": "tunnel_config", + "createSql": "CREATE TABLE IF NOT EXISTS `${TABLE_NAME}` (`id` INTEGER PRIMARY KEY AUTOINCREMENT NOT NULL, `name` TEXT NOT NULL, `quick_config` TEXT NOT NULL, `tunnel_networks` TEXT NOT NULL DEFAULT '', `is_primary_tunnel` INTEGER NOT NULL DEFAULT false, `active` INTEGER NOT NULL DEFAULT false, `ping_target` TEXT DEFAULT null, `is_ethernet_tunnel` INTEGER NOT NULL DEFAULT false, `is_ipv4_preferred` INTEGER NOT NULL DEFAULT true, `position` INTEGER NOT NULL DEFAULT 0)", + "fields": [ + { + "fieldPath": "id", + "columnName": "id", + "affinity": "INTEGER", + "notNull": true + }, + { + "fieldPath": "name", + "columnName": "name", + "affinity": "TEXT", + "notNull": true + }, + { + "fieldPath": "quickConfig", + "columnName": "quick_config", + "affinity": "TEXT", + "notNull": true + }, + { + "fieldPath": "tunnelNetworks", + "columnName": "tunnel_networks", + "affinity": "TEXT", + "notNull": true, + "defaultValue": "''" + }, + { + "fieldPath": "isPrimaryTunnel", + "columnName": "is_primary_tunnel", + "affinity": "INTEGER", + "notNull": true, + "defaultValue": "false" + }, + { + "fieldPath": "active", + "columnName": "active", + "affinity": "INTEGER", + "notNull": true, + "defaultValue": "false" + }, + { + "fieldPath": "pingTarget", + "columnName": "ping_target", + "affinity": "TEXT", + "defaultValue": "null" + }, + { + "fieldPath": "isEthernetTunnel", + "columnName": "is_ethernet_tunnel", + "affinity": "INTEGER", + "notNull": true, + "defaultValue": "false" + }, + { + "fieldPath": "isIpv4Preferred", + "columnName": "is_ipv4_preferred", + "affinity": "INTEGER", + "notNull": true, + "defaultValue": "true" + }, + { + "fieldPath": "position", + "columnName": "position", + "affinity": "INTEGER", + "notNull": true, + "defaultValue": "0" + } + ], + "primaryKey": { + "autoGenerate": true, + "columnNames": [ + "id" + ] + }, + "indices": [ + { + "name": "index_tunnel_config_name", + "unique": true, + "columnNames": [ + "name" + ], + "orders": [], + "createSql": "CREATE UNIQUE INDEX IF NOT EXISTS `index_tunnel_config_name` ON `${TABLE_NAME}` (`name`)" + } + ] + }, + { + "tableName": "proxy_settings", + "createSql": "CREATE TABLE IF NOT EXISTS `${TABLE_NAME}` (`id` INTEGER PRIMARY KEY AUTOINCREMENT NOT NULL, `socks5_proxy_enabled` INTEGER NOT NULL DEFAULT 0, `socks5_proxy_bind_address` TEXT, `http_proxy_enable` INTEGER NOT NULL DEFAULT 0, `http_proxy_bind_address` TEXT, `proxy_username` TEXT, `proxy_password` TEXT)", + "fields": [ + { + "fieldPath": "id", + "columnName": "id", + "affinity": "INTEGER", + "notNull": true + }, + { + "fieldPath": "socks5ProxyEnabled", + "columnName": "socks5_proxy_enabled", + "affinity": "INTEGER", + "notNull": true, + "defaultValue": "0" + }, + { + "fieldPath": "socks5ProxyBindAddress", + "columnName": "socks5_proxy_bind_address", + "affinity": "TEXT" + }, + { + "fieldPath": "httpProxyEnabled", + "columnName": "http_proxy_enable", + "affinity": "INTEGER", + "notNull": true, + "defaultValue": "0" + }, + { + "fieldPath": "httpProxyBindAddress", + "columnName": "http_proxy_bind_address", + "affinity": "TEXT" + }, + { + "fieldPath": "proxyUsername", + "columnName": "proxy_username", + "affinity": "TEXT" + }, + { + "fieldPath": "proxyPassword", + "columnName": "proxy_password", + "affinity": "TEXT" + } + ], + "primaryKey": { + "autoGenerate": true, + "columnNames": [ + "id" + ] + } + }, + { + "tableName": "lockdown_settings", + "createSql": "CREATE TABLE IF NOT EXISTS `${TABLE_NAME}` (`id` INTEGER PRIMARY KEY AUTOINCREMENT NOT NULL, `bypass_lan` INTEGER NOT NULL DEFAULT 0)", + "fields": [ + { + "fieldPath": "id", + "columnName": "id", + "affinity": "INTEGER", + "notNull": true + }, + { + "fieldPath": "bypassLan", + "columnName": "bypass_lan", + "affinity": "INTEGER", + "notNull": true, + "defaultValue": "0" + } + ], + "primaryKey": { + "autoGenerate": true, + "columnNames": [ + "id" + ] + } + }, + { + "tableName": "general_settings", + "createSql": "CREATE TABLE IF NOT EXISTS `${TABLE_NAME}` (`id` INTEGER PRIMARY KEY AUTOINCREMENT NOT NULL, `is_restore_on_boot_enabled` INTEGER NOT NULL DEFAULT 0, `app_mode` INTEGER NOT NULL DEFAULT 0, `theme` TEXT NOT NULL DEFAULT 'AUTOMATIC', `locale` TEXT, `already_donated` INTEGER NOT NULL DEFAULT 0)", + "fields": [ + { + "fieldPath": "id", + "columnName": "id", + "affinity": "INTEGER", + "notNull": true + }, + { + "fieldPath": "isRestoreOnBootEnabled", + "columnName": "is_restore_on_boot_enabled", + "affinity": "INTEGER", + "notNull": true, + "defaultValue": "0" + }, + { + "fieldPath": "appMode", + "columnName": "app_mode", + "affinity": "INTEGER", + "notNull": true, + "defaultValue": "0" + }, + { + "fieldPath": "theme", + "columnName": "theme", + "affinity": "TEXT", + "notNull": true, + "defaultValue": "'AUTOMATIC'" + }, + { + "fieldPath": "locale", + "columnName": "locale", + "affinity": "TEXT" + }, + { + "fieldPath": "alreadyDonated", + "columnName": "already_donated", + "affinity": "INTEGER", + "notNull": true, + "defaultValue": "0" + } + ], + "primaryKey": { + "autoGenerate": true, + "columnNames": [ + "id" + ] + } + }, + { + "tableName": "dns_settings", + "createSql": "CREATE TABLE IF NOT EXISTS `${TABLE_NAME}` (`id` INTEGER PRIMARY KEY AUTOINCREMENT NOT NULL, `dns_protocol` INTEGER NOT NULL DEFAULT 0, `dns_endpoint` TEXT, `global_tunnel_dns_enabled` INTEGER NOT NULL DEFAULT 0)", + "fields": [ + { + "fieldPath": "id", + "columnName": "id", + "affinity": "INTEGER", + "notNull": true + }, + { + "fieldPath": "dnsProtocol", + "columnName": "dns_protocol", + "affinity": "INTEGER", + "notNull": true, + "defaultValue": "0" + }, + { + "fieldPath": "dnsEndpoint", + "columnName": "dns_endpoint", + "affinity": "TEXT" + }, + { + "fieldPath": "isGlobalTunnelDnsEnabled", + "columnName": "global_tunnel_dns_enabled", + "affinity": "INTEGER", + "notNull": true, + "defaultValue": "0" + } + ], + "primaryKey": { + "autoGenerate": true, + "columnNames": [ + "id" + ] + } + }, + { + "tableName": "auto_tunnel_settings", + "createSql": "CREATE TABLE IF NOT EXISTS `${TABLE_NAME}` (`id` INTEGER PRIMARY KEY AUTOINCREMENT NOT NULL, `is_tunnel_enabled` INTEGER NOT NULL DEFAULT 0, `trusted_network_ssids` TEXT NOT NULL DEFAULT '', `is_tunnel_on_ethernet_enabled` INTEGER NOT NULL DEFAULT 0, `is_tunnel_on_wifi_enabled` INTEGER NOT NULL DEFAULT 0, `is_wildcards_enabled` INTEGER NOT NULL DEFAULT 0, `is_stop_on_no_internet_enabled` INTEGER NOT NULL DEFAULT 0, `is_tunnel_on_unsecure_enabled` INTEGER NOT NULL DEFAULT 0, `start_on_boot` INTEGER NOT NULL DEFAULT 0)", + "fields": [ + { + "fieldPath": "id", + "columnName": "id", + "affinity": "INTEGER", + "notNull": true + }, + { + "fieldPath": "isAutoTunnelEnabled", + "columnName": "is_tunnel_enabled", + "affinity": "INTEGER", + "notNull": true, + "defaultValue": "0" + }, + { + "fieldPath": "trustedNetworkSSIDs", + "columnName": "trusted_network_ssids", + "affinity": "TEXT", + "notNull": true, + "defaultValue": "''" + }, + { + "fieldPath": "isTunnelOnEthernetEnabled", + "columnName": "is_tunnel_on_ethernet_enabled", + "affinity": "INTEGER", + "notNull": true, + "defaultValue": "0" + }, + { + "fieldPath": "isTunnelOnWifiEnabled", + "columnName": "is_tunnel_on_wifi_enabled", + "affinity": "INTEGER", + "notNull": true, + "defaultValue": "0" + }, + { + "fieldPath": "isWildcardsEnabled", + "columnName": "is_wildcards_enabled", + "affinity": "INTEGER", + "notNull": true, + "defaultValue": "0" + }, + { + "fieldPath": "isStopOnNoInternetEnabled", + "columnName": "is_stop_on_no_internet_enabled", + "affinity": "INTEGER", + "notNull": true, + "defaultValue": "0" + }, + { + "fieldPath": "isTunnelOnUnsecureEnabled", + "columnName": "is_tunnel_on_unsecure_enabled", + "affinity": "INTEGER", + "notNull": true, + "defaultValue": "0" + }, + { + "fieldPath": "startOnBoot", + "columnName": "start_on_boot", + "affinity": "INTEGER", + "notNull": true, + "defaultValue": "0" + } + ], + "primaryKey": { + "autoGenerate": true, + "columnNames": [ + "id" + ] + } + } + ], + "setupQueries": [ + "CREATE TABLE IF NOT EXISTS room_master_table (id INTEGER PRIMARY KEY,identity_hash TEXT)", + "INSERT OR REPLACE INTO room_master_table (id,identity_hash) VALUES(42, 'd66aaaa9eeab5a2e84406838017246b1')" + ] + } +} \ No newline at end of file diff --git a/client/src/commonMain/kotlin/com/zaneschepke/wireguardautotunnel/client/data/AppKeyringConverter.kt b/client/src/commonMain/kotlin/com/zaneschepke/wireguardautotunnel/client/data/AppKeyringConverter.kt new file mode 100644 index 0000000..fc471d7 --- /dev/null +++ b/client/src/commonMain/kotlin/com/zaneschepke/wireguardautotunnel/client/data/AppKeyringConverter.kt @@ -0,0 +1,24 @@ +package com.zaneschepke.wireguardautotunnel.client.data + +import androidx.room.ProvidedTypeConverter +import androidx.room.TypeConverter +import com.zaneschepke.wireguardautotunnel.client.data.model.EncryptedField +import com.zaneschepke.wireguardautotunnel.core.crypto.Crypto +import org.koin.java.KoinJavaComponent.inject +import javax.crypto.SecretKey + +@ProvidedTypeConverter +class AppKeyringConverter { + + private val secretKey: SecretKey by inject(SecretKey::class.java) + + @TypeConverter + fun decryptQuick(encryptedQuick: String): EncryptedField { + return EncryptedField(Crypto.decryptWithMasterKey(encryptedQuick, secretKey)) + } + + @TypeConverter + fun encryptQuick(quick: EncryptedField): String { + return Crypto.encryptWithMasterKey(quick.value, secretKey) + } +} diff --git a/client/src/commonMain/kotlin/com/zaneschepke/wireguardautotunnel/client/data/Database.kt b/client/src/commonMain/kotlin/com/zaneschepke/wireguardautotunnel/client/data/Database.kt new file mode 100644 index 0000000..0008ed2 --- /dev/null +++ b/client/src/commonMain/kotlin/com/zaneschepke/wireguardautotunnel/client/data/Database.kt @@ -0,0 +1,69 @@ +package com.zaneschepke.wireguardautotunnel.client.data + +import androidx.room.ConstructedBy +import androidx.room.Database +import androidx.room.RoomDatabase +import androidx.room.RoomDatabaseConstructor +import androidx.room.TypeConverters +import com.zaneschepke.wireguardautotunnel.client.data.dao.AutoTunnelSettingsDao +import com.zaneschepke.wireguardautotunnel.client.data.dao.DnsSettingsDao +import com.zaneschepke.wireguardautotunnel.client.data.dao.GeneralSettingsDao +import com.zaneschepke.wireguardautotunnel.client.data.dao.LockdownSettingsDao +import com.zaneschepke.wireguardautotunnel.client.data.dao.ProxySettingsDao +import com.zaneschepke.wireguardautotunnel.client.data.dao.TunnelConfigDao +import com.zaneschepke.wireguardautotunnel.client.data.entity.AutoTunnelSettings +import com.zaneschepke.wireguardautotunnel.client.data.entity.DnsSettings +import com.zaneschepke.wireguardautotunnel.client.data.entity.GeneralSettings +import com.zaneschepke.wireguardautotunnel.client.data.entity.LockdownSettings +import com.zaneschepke.wireguardautotunnel.client.data.entity.ProxySettings +import com.zaneschepke.wireguardautotunnel.client.data.entity.TunnelConfig +import com.zaneschepke.wireguardautotunnel.keyring.Keyring +import org.apache.commons.lang3.SystemUtils +import java.io.File + +@Database(entities = [TunnelConfig::class, ProxySettings::class, LockdownSettings::class, + GeneralSettings::class, DnsSettings::class, AutoTunnelSettings::class], version = 1, exportSchema = true) +@TypeConverters(DatabaseConverters::class, AppKeyringConverter::class) +@ConstructedBy(AppDatabaseConstructor::class) +abstract class AppDatabase : RoomDatabase() { + abstract fun tunnelConfigDao(): TunnelConfigDao + + abstract fun proxySettingsDao(): ProxySettingsDao + + abstract fun generalSettingsDao(): GeneralSettingsDao + + abstract fun autoTunnelSettingsDao(): AutoTunnelSettingsDao + + abstract fun lockdownSettingsDao(): LockdownSettingsDao + + abstract fun dnsSettingsDao(): DnsSettingsDao + + companion object { + const val DB_SECRET_KEY = "db_secret" + const val DB_KEYRING = "wg_tunnel" + const val DB_FILE_NAME = "wg_tunnel.db" + const val APP_NAME = "WGTunnel" // macos convention + + fun getDatabaseDir() : File { + val home = System.getProperty("user.home") + return when { + SystemUtils.IS_OS_WINDOWS -> { + val appData = System.getenv("APPDATA") ?: "${System.getProperty("user.home")}\\AppData\\Roaming" + File("$appData\\$APP_NAME") + } + SystemUtils.IS_OS_MAC -> { + File("$home/Library/Application Support/$APP_NAME") + } + else -> { + val xdgDataHome = System.getenv("XDG_DATA_HOME") ?: "$home/.local/share" + File("$xdgDataHome/${APP_NAME.lowercase()}") // linux lowercase convention + } + } + } + } +} + +@Suppress("NO_ACTUAL_FOR_EXPECT") +expect object AppDatabaseConstructor : RoomDatabaseConstructor { + override fun initialize(): AppDatabase +} \ No newline at end of file diff --git a/client/src/commonMain/kotlin/com/zaneschepke/wireguardautotunnel/client/data/DatabaseCallback.kt b/client/src/commonMain/kotlin/com/zaneschepke/wireguardautotunnel/client/data/DatabaseCallback.kt new file mode 100644 index 0000000..9d405ca --- /dev/null +++ b/client/src/commonMain/kotlin/com/zaneschepke/wireguardautotunnel/client/data/DatabaseCallback.kt @@ -0,0 +1,15 @@ +package com.zaneschepke.wireguardautotunnel.client.data + +import androidx.room.RoomDatabase +import androidx.sqlite.SQLiteConnection +import androidx.sqlite.execSQL + +class DatabaseCallback(private val databaseProvider: Lazy) : RoomDatabase.Callback() { + override fun onCreate(connection: SQLiteConnection) { + super.onCreate(connection) + connection.execSQL("INSERT INTO proxy_settings DEFAULT VALUES") + connection.execSQL("INSERT INTO general_settings DEFAULT VALUES") + connection.execSQL("INSERT INTO auto_tunnel_settings DEFAULT VALUES") + connection.execSQL("INSERT INTO dns_settings DEFAULT VALUES") + } +} \ No newline at end of file diff --git a/client/src/commonMain/kotlin/com/zaneschepke/wireguardautotunnel/client/data/DatabaseConverters.kt b/client/src/commonMain/kotlin/com/zaneschepke/wireguardautotunnel/client/data/DatabaseConverters.kt new file mode 100644 index 0000000..985b36c --- /dev/null +++ b/client/src/commonMain/kotlin/com/zaneschepke/wireguardautotunnel/client/data/DatabaseConverters.kt @@ -0,0 +1,45 @@ +package com.zaneschepke.wireguardautotunnel.client.data + +import androidx.room.ProvidedTypeConverter +import androidx.room.TypeConverter +import com.zaneschepke.wireguardautotunnel.client.data.model.AppMode +import com.zaneschepke.wireguardautotunnel.client.data.model.DnsProtocol +import kotlinx.serialization.json.Json + +@ProvidedTypeConverter +class DatabaseConverters { + @TypeConverter + fun listToString(value: List): String { + return Json.encodeToString(value) + } + + @TypeConverter + fun stringToList(value: String): List { + if (value.isBlank() || value.isEmpty()) return mutableListOf() + return try { + Json.decodeFromString>(value) + } catch (e: Exception) { + val list = value.split(",").toMutableList() + val json = listToString(list) + Json.decodeFromString>(json) + } + } + + @TypeConverter + fun setToString(value: Set): String { + return listToString(value.toList()) + } + + @TypeConverter + fun stringToSet(value: String): Set { + return stringToList(value).toSet() + } + + @TypeConverter fun toMode(value: Int): AppMode = AppMode.fromValue(value) + + @TypeConverter fun fromMode(mode: AppMode): Int = mode.value + + @TypeConverter fun toDnsProtocol(value: Int): DnsProtocol = DnsProtocol.fromValue(value) + + @TypeConverter fun fromDnsProtocol(mode: DnsProtocol): Int = mode.value +} diff --git a/client/src/commonMain/kotlin/com/zaneschepke/wireguardautotunnel/client/data/dao/AutoTunnelSettingsDao.kt b/client/src/commonMain/kotlin/com/zaneschepke/wireguardautotunnel/client/data/dao/AutoTunnelSettingsDao.kt new file mode 100644 index 0000000..6de6ad0 --- /dev/null +++ b/client/src/commonMain/kotlin/com/zaneschepke/wireguardautotunnel/client/data/dao/AutoTunnelSettingsDao.kt @@ -0,0 +1,21 @@ +package com.zaneschepke.wireguardautotunnel.client.data.dao + +import androidx.room.Dao +import androidx.room.Query +import androidx.room.Upsert +import com.zaneschepke.wireguardautotunnel.client.data.entity.AutoTunnelSettings +import kotlinx.coroutines.flow.Flow + +@Dao +interface AutoTunnelSettingsDao { + @Query("SELECT * FROM auto_tunnel_settings LIMIT 1") + suspend fun getAutoTunnelSettings(): AutoTunnelSettings? + + @Upsert suspend fun upsert(autoTunnelSettings: AutoTunnelSettings) + + @Query("SELECT * FROM auto_tunnel_settings LIMIT 1") + fun getAutoTunnelSettingsFlow(): Flow + + @Query("UPDATE auto_tunnel_settings SET is_tunnel_enabled = :enabled") + suspend fun updateAutoTunnelEnabled(enabled: Boolean) +} diff --git a/client/src/commonMain/kotlin/com/zaneschepke/wireguardautotunnel/client/data/dao/DnsSettingsDao.kt b/client/src/commonMain/kotlin/com/zaneschepke/wireguardautotunnel/client/data/dao/DnsSettingsDao.kt new file mode 100644 index 0000000..5457cec --- /dev/null +++ b/client/src/commonMain/kotlin/com/zaneschepke/wireguardautotunnel/client/data/dao/DnsSettingsDao.kt @@ -0,0 +1,16 @@ +package com.zaneschepke.wireguardautotunnel.client.data.dao + +import androidx.room.Dao +import androidx.room.Query +import androidx.room.Upsert +import com.zaneschepke.wireguardautotunnel.client.data.entity.DnsSettings +import kotlinx.coroutines.flow.Flow + +@Dao +interface DnsSettingsDao { + @Query("SELECT * FROM dns_settings LIMIT 1") suspend fun getDnsSettings(): DnsSettings? + + @Upsert suspend fun upsert(dnsSettings: DnsSettings) + + @Query("SELECT * FROM dns_settings LIMIT 1") fun getDnsSettingsFlow(): Flow +} diff --git a/client/src/commonMain/kotlin/com/zaneschepke/wireguardautotunnel/client/data/dao/GeneralSettingsDao.kt b/client/src/commonMain/kotlin/com/zaneschepke/wireguardautotunnel/client/data/dao/GeneralSettingsDao.kt new file mode 100644 index 0000000..d51d577 --- /dev/null +++ b/client/src/commonMain/kotlin/com/zaneschepke/wireguardautotunnel/client/data/dao/GeneralSettingsDao.kt @@ -0,0 +1,28 @@ +package com.zaneschepke.wireguardautotunnel.client.data.dao + +import androidx.room.Dao +import androidx.room.Query +import androidx.room.Upsert +import com.zaneschepke.wireguardautotunnel.client.data.entity.GeneralSettings +import com.zaneschepke.wireguardautotunnel.client.data.model.AppMode +import kotlinx.coroutines.flow.Flow + +@Dao +interface GeneralSettingsDao { + @Query("SELECT * FROM general_settings LIMIT 1") + suspend fun getGeneralSettings(): GeneralSettings? + + @Upsert suspend fun upsert(generalSettings: GeneralSettings) + + @Query("SELECT * FROM general_settings LIMIT 1") + fun getGeneralSettingsFlow(): Flow + + @Query("UPDATE general_settings SET theme = :theme WHERE id = 1") + suspend fun updateTheme(theme: String) + + @Query("UPDATE general_settings SET locale = :locale WHERE id = 1") + suspend fun updateLocale(locale: String) + + @Query("UPDATE general_settings SET app_mode = :appMode WHERE id = 1") + suspend fun updateAppMode(appMode: AppMode) +} diff --git a/client/src/commonMain/kotlin/com/zaneschepke/wireguardautotunnel/client/data/dao/LockdownSettingsDao.kt b/client/src/commonMain/kotlin/com/zaneschepke/wireguardautotunnel/client/data/dao/LockdownSettingsDao.kt new file mode 100644 index 0000000..96d2c7b --- /dev/null +++ b/client/src/commonMain/kotlin/com/zaneschepke/wireguardautotunnel/client/data/dao/LockdownSettingsDao.kt @@ -0,0 +1,18 @@ +package com.zaneschepke.wireguardautotunnel.client.data.dao + +import androidx.room.Dao +import androidx.room.Query +import androidx.room.Upsert +import com.zaneschepke.wireguardautotunnel.client.data.entity.LockdownSettings +import kotlinx.coroutines.flow.Flow + +@Dao +interface LockdownSettingsDao { + @Query("SELECT * FROM lockdown_settings LIMIT 1") + suspend fun getLockdownSettings(): LockdownSettings? + + @Upsert suspend fun upsert(lockdownSettings: LockdownSettings) + + @Query("SELECT * FROM lockdown_settings LIMIT 1") + fun getLockdownSettingsFlow(): Flow +} diff --git a/client/src/commonMain/kotlin/com/zaneschepke/wireguardautotunnel/client/data/dao/ProxySettingsDao.kt b/client/src/commonMain/kotlin/com/zaneschepke/wireguardautotunnel/client/data/dao/ProxySettingsDao.kt new file mode 100644 index 0000000..52dc8f4 --- /dev/null +++ b/client/src/commonMain/kotlin/com/zaneschepke/wireguardautotunnel/client/data/dao/ProxySettingsDao.kt @@ -0,0 +1,16 @@ +package com.zaneschepke.wireguardautotunnel.client.data.dao + +import androidx.room.Dao +import androidx.room.Query +import androidx.room.Upsert +import com.zaneschepke.wireguardautotunnel.client.data.entity.ProxySettings +import kotlinx.coroutines.flow.Flow + +@Dao +interface ProxySettingsDao { + @Upsert suspend fun upsert(proxySettings: ProxySettings) + + @Query("SELECT * FROM proxy_settings LIMIT 1") suspend fun getProxySettings(): ProxySettings? + + @Query("SELECT * FROM proxy_settings LIMIT 1") fun getProxySettingsFlow(): Flow +} diff --git a/client/src/commonMain/kotlin/com/zaneschepke/wireguardautotunnel/client/data/dao/TunnelConfigDao.kt b/client/src/commonMain/kotlin/com/zaneschepke/wireguardautotunnel/client/data/dao/TunnelConfigDao.kt new file mode 100644 index 0000000..abfc505 --- /dev/null +++ b/client/src/commonMain/kotlin/com/zaneschepke/wireguardautotunnel/client/data/dao/TunnelConfigDao.kt @@ -0,0 +1,84 @@ +package com.zaneschepke.wireguardautotunnel.client.data.dao + +import androidx.room.* +import com.zaneschepke.wireguardautotunnel.client.data.entity.TunnelConfig +import kotlinx.coroutines.flow.Flow + +@Dao +interface TunnelConfigDao { + + @Upsert suspend fun upsert(t: TunnelConfig) + + @Insert(onConflict = OnConflictStrategy.REPLACE) suspend fun saveAll(t: List) + + @Query("SELECT * FROM tunnel_config WHERE id=:id") suspend fun getById(id: Long): TunnelConfig? + + @Query("UPDATE tunnel_config SET active = 0 WHERE active = 1") + suspend fun resetActiveTunnels() + + @Query("SELECT * FROM tunnel_config WHERE name=:name") + suspend fun getByName(name: String): TunnelConfig? + + @Query("SELECT * FROM tunnel_config WHERE active=1") + suspend fun getActive(): List + + @Query("SELECT * FROM tunnel_config") suspend fun getAll(): List + + @Delete suspend fun delete(t: TunnelConfig) + + @Delete suspend fun delete(t: List) + + @Query("DELETE FROM tunnel_config WHERE name = :name") + suspend fun deleteByName(name: String) + + @Query("SELECT COUNT('id') FROM tunnel_config") suspend fun count(): Long + + @Query("SELECT * FROM tunnel_config WHERE tunnel_networks LIKE '%' || :name || '%'") + suspend fun findByTunnelNetworkName(name: String): List + + @Query("UPDATE tunnel_config SET is_primary_tunnel = 0 WHERE is_primary_tunnel =1") + suspend fun resetPrimaryTunnel() + + @Query("UPDATE tunnel_config SET is_ethernet_tunnel = 0 WHERE is_ethernet_tunnel =1") + suspend fun resetEthernetTunnel() + + @Query("SELECT * FROM tunnel_config WHERE is_primary_tunnel=1") + suspend fun findByPrimary(): List + + @Query( + """ + SELECT * FROM tunnel_config + WHERE name != '${TunnelConfig.GLOBAL_CONFIG_NAME}' + ORDER BY + CASE WHEN is_primary_tunnel = 1 THEN 0 ELSE 1 END, + position ASC + LIMIT 1 + """ + ) + suspend fun getDefaultTunnel(): TunnelConfig? + + @Query( + """ + SELECT * FROM tunnel_config + WHERE name != '${TunnelConfig.GLOBAL_CONFIG_NAME}' + ORDER BY + CASE WHEN active = 1 THEN 0 + WHEN is_primary_tunnel = 1 THEN 1 + ELSE 2 END, + position ASC + LIMIT 1 + """ + ) + suspend fun getStartTunnel(): TunnelConfig? + + @Query("SELECT * FROM tunnel_config ORDER BY position") + fun getAllFlow(): Flow> + + @Query("SELECT * FROM tunnel_config WHERE name != :globalName ORDER BY position") + fun getAllTunnelsExceptGlobal( + globalName: String = TunnelConfig.GLOBAL_CONFIG_NAME + ): Flow> + + @Query("SELECT * FROM tunnel_config WHERE name = :globalName LIMIT 1") + fun getGlobalTunnel(globalName: String = TunnelConfig.GLOBAL_CONFIG_NAME): Flow +} diff --git a/client/src/commonMain/kotlin/com/zaneschepke/wireguardautotunnel/client/data/entity/AutoTunnelSettings.kt b/client/src/commonMain/kotlin/com/zaneschepke/wireguardautotunnel/client/data/entity/AutoTunnelSettings.kt new file mode 100644 index 0000000..efe4c09 --- /dev/null +++ b/client/src/commonMain/kotlin/com/zaneschepke/wireguardautotunnel/client/data/entity/AutoTunnelSettings.kt @@ -0,0 +1,25 @@ +package com.zaneschepke.wireguardautotunnel.client.data.entity + +import androidx.room.ColumnInfo +import androidx.room.Entity +import androidx.room.PrimaryKey + +@Entity(tableName = "auto_tunnel_settings") +data class AutoTunnelSettings( + @PrimaryKey(autoGenerate = true) val id: Int = 0, + @ColumnInfo(name = "is_tunnel_enabled", defaultValue = "0") + val isAutoTunnelEnabled: Boolean = false, + @ColumnInfo(name = "trusted_network_ssids", defaultValue = "") + val trustedNetworkSSIDs: Set = emptySet(), + @ColumnInfo(name = "is_tunnel_on_ethernet_enabled", defaultValue = "0") + val isTunnelOnEthernetEnabled: Boolean = false, + @ColumnInfo(name = "is_tunnel_on_wifi_enabled", defaultValue = "0") + val isTunnelOnWifiEnabled: Boolean = false, + @ColumnInfo(name = "is_wildcards_enabled", defaultValue = "0") + val isWildcardsEnabled: Boolean = false, + @ColumnInfo(name = "is_stop_on_no_internet_enabled", defaultValue = "0") + val isStopOnNoInternetEnabled: Boolean = false, + @ColumnInfo(name = "is_tunnel_on_unsecure_enabled", defaultValue = "0") + val isTunnelOnUnsecureEnabled: Boolean = false, + @ColumnInfo(name = "start_on_boot", defaultValue = "0") val startOnBoot: Boolean = false, +) diff --git a/client/src/commonMain/kotlin/com/zaneschepke/wireguardautotunnel/client/data/entity/DnsSettings.kt b/client/src/commonMain/kotlin/com/zaneschepke/wireguardautotunnel/client/data/entity/DnsSettings.kt new file mode 100644 index 0000000..a453198 --- /dev/null +++ b/client/src/commonMain/kotlin/com/zaneschepke/wireguardautotunnel/client/data/entity/DnsSettings.kt @@ -0,0 +1,16 @@ +package com.zaneschepke.wireguardautotunnel.client.data.entity + +import androidx.room.ColumnInfo +import androidx.room.Entity +import androidx.room.PrimaryKey +import com.zaneschepke.wireguardautotunnel.client.data.model.DnsProtocol + +@Entity(tableName = "dns_settings") +data class DnsSettings( + @PrimaryKey(autoGenerate = true) val id: Int = 0, + @ColumnInfo(name = "dns_protocol", defaultValue = "0") + val dnsProtocol: DnsProtocol = DnsProtocol.fromValue(0), + @ColumnInfo(name = "dns_endpoint") val dnsEndpoint: String? = null, + @ColumnInfo(name = "global_tunnel_dns_enabled", defaultValue = "0") + val isGlobalTunnelDnsEnabled: Boolean = false, +) diff --git a/client/src/commonMain/kotlin/com/zaneschepke/wireguardautotunnel/client/data/entity/GeneralSettings.kt b/client/src/commonMain/kotlin/com/zaneschepke/wireguardautotunnel/client/data/entity/GeneralSettings.kt new file mode 100644 index 0000000..a0ee85b --- /dev/null +++ b/client/src/commonMain/kotlin/com/zaneschepke/wireguardautotunnel/client/data/entity/GeneralSettings.kt @@ -0,0 +1,17 @@ +package com.zaneschepke.wireguardautotunnel.client.data.entity + +import androidx.room.ColumnInfo +import androidx.room.Entity +import androidx.room.PrimaryKey +import com.zaneschepke.wireguardautotunnel.client.data.model.AppMode + +@Entity(tableName = "general_settings") +data class GeneralSettings( + @PrimaryKey(autoGenerate = true) val id: Int = 0, + @ColumnInfo(name = "is_restore_on_boot_enabled", defaultValue = "0") + val isRestoreOnBootEnabled: Boolean = false, + @ColumnInfo(name = "app_mode", defaultValue = "0") val appMode: AppMode = AppMode.fromValue(0), + @ColumnInfo(name = "theme", defaultValue = "AUTOMATIC") val theme: String = "AUTOMATIC", + @ColumnInfo(name = "locale") val locale: String? = null, + @ColumnInfo(name = "already_donated", defaultValue = "0") val alreadyDonated: Boolean = false, +) diff --git a/client/src/commonMain/kotlin/com/zaneschepke/wireguardautotunnel/client/data/entity/LockdownSettings.kt b/client/src/commonMain/kotlin/com/zaneschepke/wireguardautotunnel/client/data/entity/LockdownSettings.kt new file mode 100644 index 0000000..2e6216d --- /dev/null +++ b/client/src/commonMain/kotlin/com/zaneschepke/wireguardautotunnel/client/data/entity/LockdownSettings.kt @@ -0,0 +1,11 @@ +package com.zaneschepke.wireguardautotunnel.client.data.entity + +import androidx.room.ColumnInfo +import androidx.room.Entity +import androidx.room.PrimaryKey + +@Entity(tableName = "lockdown_settings") +data class LockdownSettings( + @PrimaryKey(autoGenerate = true) val id: Long = 0, + @ColumnInfo(name = "bypass_lan", defaultValue = "0") val bypassLan: Boolean = false +) diff --git a/client/src/commonMain/kotlin/com/zaneschepke/wireguardautotunnel/client/data/entity/ProxySettings.kt b/client/src/commonMain/kotlin/com/zaneschepke/wireguardautotunnel/client/data/entity/ProxySettings.kt new file mode 100644 index 0000000..60bc5f0 --- /dev/null +++ b/client/src/commonMain/kotlin/com/zaneschepke/wireguardautotunnel/client/data/entity/ProxySettings.kt @@ -0,0 +1,18 @@ +package com.zaneschepke.wireguardautotunnel.client.data.entity + +import androidx.room.ColumnInfo +import androidx.room.Entity +import androidx.room.PrimaryKey + +@Entity(tableName = "proxy_settings") +data class ProxySettings( + @PrimaryKey(autoGenerate = true) val id: Long = 0, + @ColumnInfo(name = "socks5_proxy_enabled", defaultValue = "0") + val socks5ProxyEnabled: Boolean = false, + @ColumnInfo(name = "socks5_proxy_bind_address") val socks5ProxyBindAddress: String? = null, + @ColumnInfo(name = "http_proxy_enable", defaultValue = "0") + val httpProxyEnabled: Boolean = false, + @ColumnInfo(name = "http_proxy_bind_address") val httpProxyBindAddress: String? = null, + @ColumnInfo(name = "proxy_username") val proxyUsername: String? = null, + @ColumnInfo(name = "proxy_password") val proxyPassword: String? = null, +) diff --git a/client/src/commonMain/kotlin/com/zaneschepke/wireguardautotunnel/client/data/entity/TunnelConfig.kt b/client/src/commonMain/kotlin/com/zaneschepke/wireguardautotunnel/client/data/entity/TunnelConfig.kt new file mode 100644 index 0000000..902d5ec --- /dev/null +++ b/client/src/commonMain/kotlin/com/zaneschepke/wireguardautotunnel/client/data/entity/TunnelConfig.kt @@ -0,0 +1,32 @@ +package com.zaneschepke.wireguardautotunnel.client.data.entity + +import androidx.room.ColumnInfo +import androidx.room.Entity +import androidx.room.Index +import androidx.room.PrimaryKey +import androidx.room.TypeConverters +import com.zaneschepke.wireguardautotunnel.client.data.AppKeyringConverter +import com.zaneschepke.wireguardautotunnel.client.data.model.EncryptedField + +@Entity(tableName = "tunnel_config", indices = [Index(value = ["name"], unique = true)]) +data class TunnelConfig( + @PrimaryKey(autoGenerate = true) val id: Int = 0, + @ColumnInfo(name = "name") val name: String, + @field:TypeConverters(AppKeyringConverter::class) + @ColumnInfo(name = "quick_config") val quickConfig: EncryptedField, + @ColumnInfo(name = "tunnel_networks", defaultValue = "") + val tunnelNetworks: Set = setOf(), + @ColumnInfo(name = "is_primary_tunnel", defaultValue = "false") + val isPrimaryTunnel: Boolean = false, + @ColumnInfo(name = "active", defaultValue = "false") val active: Boolean = false, + @ColumnInfo(name = "ping_target", defaultValue = "null") var pingTarget: String? = null, + @ColumnInfo(name = "is_ethernet_tunnel", defaultValue = "false") + val isEthernetTunnel: Boolean = false, + @ColumnInfo(name = "is_ipv4_preferred", defaultValue = "true") + val isIpv4Preferred: Boolean = true, + @ColumnInfo(name = "position", defaultValue = "0") val position: Int = 0, +) { + companion object { + const val GLOBAL_CONFIG_NAME = "4675ab06-903a-438b-8485-6ea4187a9512" + } +} diff --git a/client/src/commonMain/kotlin/com/zaneschepke/wireguardautotunnel/client/data/mapper/AutoTunnelSettingsMapper.kt b/client/src/commonMain/kotlin/com/zaneschepke/wireguardautotunnel/client/data/mapper/AutoTunnelSettingsMapper.kt new file mode 100644 index 0000000..b5105f1 --- /dev/null +++ b/client/src/commonMain/kotlin/com/zaneschepke/wireguardautotunnel/client/data/mapper/AutoTunnelSettingsMapper.kt @@ -0,0 +1,30 @@ +package com.zaneschepke.wireguardautotunnel.client.data.mapper + +import com.zaneschepke.wireguardautotunnel.client.data.entity.AutoTunnelSettings as Entity +import com.zaneschepke.wireguardautotunnel.client.domain.model.AutoTunnelSettings as Domain + +fun Entity.toDomain(): Domain = + Domain( + id = id, + isAutoTunnelEnabled = isAutoTunnelEnabled, + trustedNetworkSSIDs = trustedNetworkSSIDs, + isTunnelOnEthernetEnabled = isTunnelOnEthernetEnabled, + isTunnelOnWifiEnabled = isTunnelOnWifiEnabled, + isWildcardsEnabled = isWildcardsEnabled, + isStopOnNoInternetEnabled = isStopOnNoInternetEnabled, + isTunnelOnUnsecureEnabled = isTunnelOnUnsecureEnabled, + startOnBoot = startOnBoot, + ) + +fun Domain.toEntity(): Entity = + Entity( + id = id, + isAutoTunnelEnabled = isAutoTunnelEnabled, + trustedNetworkSSIDs = trustedNetworkSSIDs, + isTunnelOnEthernetEnabled = isTunnelOnEthernetEnabled, + isTunnelOnWifiEnabled = isTunnelOnWifiEnabled, + isWildcardsEnabled = isWildcardsEnabled, + isStopOnNoInternetEnabled = isStopOnNoInternetEnabled, + isTunnelOnUnsecureEnabled = isTunnelOnUnsecureEnabled, + startOnBoot = startOnBoot, + ) diff --git a/client/src/commonMain/kotlin/com/zaneschepke/wireguardautotunnel/client/data/mapper/DnsSettingsMapper.kt b/client/src/commonMain/kotlin/com/zaneschepke/wireguardautotunnel/client/data/mapper/DnsSettingsMapper.kt new file mode 100644 index 0000000..0d8925e --- /dev/null +++ b/client/src/commonMain/kotlin/com/zaneschepke/wireguardautotunnel/client/data/mapper/DnsSettingsMapper.kt @@ -0,0 +1,21 @@ +package com.zaneschepke.wireguardautotunnel.client.data.mapper + +import com.zaneschepke.wireguardautotunnel.client.data.model.DnsProtocol +import com.zaneschepke.wireguardautotunnel.client.data.entity.DnsSettings as Entity +import com.zaneschepke.wireguardautotunnel.client.domain.model.DnsSettings as Domain + +fun Entity.toDomain(): Domain = + Domain( + id = id, + dnsProtocol = dnsProtocol.value, + dnsEndpoint = dnsEndpoint, + isGlobalTunnelDnsEnabled = isGlobalTunnelDnsEnabled, + ) + +fun Domain.toEntity(): Entity = + Entity( + id = id, + dnsProtocol = DnsProtocol.fromValue(dnsProtocol), + dnsEndpoint = dnsEndpoint, + isGlobalTunnelDnsEnabled = isGlobalTunnelDnsEnabled, + ) diff --git a/client/src/commonMain/kotlin/com/zaneschepke/wireguardautotunnel/client/data/mapper/LockdownMapper.kt b/client/src/commonMain/kotlin/com/zaneschepke/wireguardautotunnel/client/data/mapper/LockdownMapper.kt new file mode 100644 index 0000000..68003fa --- /dev/null +++ b/client/src/commonMain/kotlin/com/zaneschepke/wireguardautotunnel/client/data/mapper/LockdownMapper.kt @@ -0,0 +1,13 @@ +package com.zaneschepke.wireguardautotunnel.client.data.mapper + +import com.zaneschepke.wireguardautotunnel.client.data.entity.LockdownSettings as Entity +import com.zaneschepke.wireguardautotunnel.client.domain.model.LockdownSettings as Domain + +fun Entity.toDomain(): Domain = + Domain(id = id, bypassLan = bypassLan) + +fun Domain.toEntity(): Entity = + Entity( + id = id, + bypassLan = bypassLan + ) diff --git a/client/src/commonMain/kotlin/com/zaneschepke/wireguardautotunnel/client/data/mapper/ProxySettingsMapper.kt b/client/src/commonMain/kotlin/com/zaneschepke/wireguardautotunnel/client/data/mapper/ProxySettingsMapper.kt new file mode 100644 index 0000000..f837af3 --- /dev/null +++ b/client/src/commonMain/kotlin/com/zaneschepke/wireguardautotunnel/client/data/mapper/ProxySettingsMapper.kt @@ -0,0 +1,26 @@ +package com.zaneschepke.wireguardautotunnel.client.data.mapper + +import com.zaneschepke.wireguardautotunnel.client.data.entity.ProxySettings as Entity +import com.zaneschepke.wireguardautotunnel.client.domain.model.ProxySettings as Domain + +fun Entity.toDomain(): Domain = + Domain( + id = id, + socks5ProxyEnabled = socks5ProxyEnabled, + socks5ProxyBindAddress = socks5ProxyBindAddress, + httpProxyEnabled = httpProxyEnabled, + httpProxyBindAddress = httpProxyBindAddress, + proxyUsername = proxyUsername, + proxyPassword = proxyPassword, + ) + +fun Domain.toEntity(): Entity = + Entity( + id = id, + socks5ProxyEnabled = socks5ProxyEnabled, + socks5ProxyBindAddress = socks5ProxyBindAddress, + httpProxyEnabled = httpProxyEnabled, + httpProxyBindAddress = httpProxyBindAddress, + proxyUsername = proxyUsername, + proxyPassword = proxyPassword, + ) diff --git a/client/src/commonMain/kotlin/com/zaneschepke/wireguardautotunnel/client/data/mapper/SettingsMapper.kt b/client/src/commonMain/kotlin/com/zaneschepke/wireguardautotunnel/client/data/mapper/SettingsMapper.kt new file mode 100644 index 0000000..f430af8 --- /dev/null +++ b/client/src/commonMain/kotlin/com/zaneschepke/wireguardautotunnel/client/data/mapper/SettingsMapper.kt @@ -0,0 +1,25 @@ +package com.zaneschepke.wireguardautotunnel.client.data.mapper + +import com.zaneschepke.wireguardautotunnel.client.data.model.Theme +import com.zaneschepke.wireguardautotunnel.client.data.entity.GeneralSettings as Entity +import com.zaneschepke.wireguardautotunnel.client.domain.model.GeneralSettings as Domain + +fun Entity.toDomain(): Domain = + Domain( + id = id, + isRestoreOnBootEnabled = isRestoreOnBootEnabled, + appMode = appMode, + theme = Theme.valueOf(theme.uppercase()), + locale = locale, + alreadyDonated = alreadyDonated, + ) + +fun Domain.toEntity(): Entity = + Entity( + id = id, + isRestoreOnBootEnabled = isRestoreOnBootEnabled, + appMode = appMode, + theme = theme.name, + locale = locale, + alreadyDonated = alreadyDonated, + ) diff --git a/client/src/commonMain/kotlin/com/zaneschepke/wireguardautotunnel/client/data/mapper/TunnelConfigMapper.kt b/client/src/commonMain/kotlin/com/zaneschepke/wireguardautotunnel/client/data/mapper/TunnelConfigMapper.kt new file mode 100644 index 0000000..2cf7804 --- /dev/null +++ b/client/src/commonMain/kotlin/com/zaneschepke/wireguardautotunnel/client/data/mapper/TunnelConfigMapper.kt @@ -0,0 +1,33 @@ +package com.zaneschepke.wireguardautotunnel.client.data.mapper + +import com.zaneschepke.wireguardautotunnel.client.data.model.EncryptedField +import com.zaneschepke.wireguardautotunnel.client.data.entity.TunnelConfig as Entity +import com.zaneschepke.wireguardautotunnel.client.domain.model.TunnelConfig as Domain + +fun Entity.toDomain(): Domain = + Domain( + id = id, + name = name, + quickConfig = quickConfig.value, + tunnelNetworks = tunnelNetworks, + isPrimaryTunnel = isPrimaryTunnel, + active = active, + pingTarget = pingTarget, + isEthernetTunnel = isEthernetTunnel, + isIpv4Preferred = isIpv4Preferred, + position = position, + ) + +fun Domain.toEntity(): Entity = + Entity( + id = id, + name = name, + quickConfig = EncryptedField(quickConfig), + tunnelNetworks = tunnelNetworks, + isPrimaryTunnel = isPrimaryTunnel, + active = active, + pingTarget = pingTarget, + isEthernetTunnel = isEthernetTunnel, + isIpv4Preferred = isIpv4Preferred, + position = position, + ) diff --git a/client/src/commonMain/kotlin/com/zaneschepke/wireguardautotunnel/client/data/model/AppMode.kt b/client/src/commonMain/kotlin/com/zaneschepke/wireguardautotunnel/client/data/model/AppMode.kt new file mode 100644 index 0000000..1f249fd --- /dev/null +++ b/client/src/commonMain/kotlin/com/zaneschepke/wireguardautotunnel/client/data/model/AppMode.kt @@ -0,0 +1,12 @@ +package com.zaneschepke.wireguardautotunnel.client.data.model + +enum class AppMode(val value: Int) { + VPN(0), + PROXY(1), + LOCK_DOWN(2), + KERNEL(3); + + companion object { + fun fromValue(value: Int): com.zaneschepke.wireguardautotunnel.client.data.model.AppMode = entries.find { it.value == value } ?: VPN + } +} diff --git a/client/src/commonMain/kotlin/com/zaneschepke/wireguardautotunnel/client/data/model/Dns.kt b/client/src/commonMain/kotlin/com/zaneschepke/wireguardautotunnel/client/data/model/Dns.kt new file mode 100644 index 0000000..91561ae --- /dev/null +++ b/client/src/commonMain/kotlin/com/zaneschepke/wireguardautotunnel/client/data/model/Dns.kt @@ -0,0 +1,30 @@ +package com.zaneschepke.wireguardautotunnel.client.data.model + +enum class DnsProtocol(val value: Int) { + SYSTEM(0), + DOH(1); + + companion object { + fun fromValue(value: Int): com.zaneschepke.wireguardautotunnel.client.data.model.DnsProtocol = + _root_ide_package_.com.zaneschepke.wireguardautotunnel.client.data.model.DnsProtocol.entries.find { it.value == value } ?: SYSTEM + } +} + +enum class DnsProvider(private val systemAddress: String, private val dohAddress: String) { + CLOUDFLARE("1.1.1.1", "https://1.1.1.1/dns-query"), + ADGUARD("94.140.14.14", "https://94.140.14.14/dns-query"); + + fun asAddress(protocol: com.zaneschepke.wireguardautotunnel.client.data.model.DnsProtocol): String { + return when (protocol) { + _root_ide_package_.com.zaneschepke.wireguardautotunnel.client.data.model.DnsProtocol.SYSTEM -> systemAddress + _root_ide_package_.com.zaneschepke.wireguardautotunnel.client.data.model.DnsProtocol.DOH -> dohAddress + } + } + + companion object { + fun fromAddress(address: String): com.zaneschepke.wireguardautotunnel.client.data.model.DnsProvider { + return entries.find { it.systemAddress == address || it.dohAddress == address } + ?: CLOUDFLARE + } + } +} \ No newline at end of file diff --git a/client/src/commonMain/kotlin/com/zaneschepke/wireguardautotunnel/client/data/model/EncryptedField.kt b/client/src/commonMain/kotlin/com/zaneschepke/wireguardautotunnel/client/data/model/EncryptedField.kt new file mode 100644 index 0000000..d9ff6b6 --- /dev/null +++ b/client/src/commonMain/kotlin/com/zaneschepke/wireguardautotunnel/client/data/model/EncryptedField.kt @@ -0,0 +1,4 @@ +package com.zaneschepke.wireguardautotunnel.client.data.model + +@JvmInline +value class EncryptedField(val value: String) diff --git a/client/src/commonMain/kotlin/com/zaneschepke/wireguardautotunnel/client/data/model/Theme.kt b/client/src/commonMain/kotlin/com/zaneschepke/wireguardautotunnel/client/data/model/Theme.kt new file mode 100644 index 0000000..6c6ce2b --- /dev/null +++ b/client/src/commonMain/kotlin/com/zaneschepke/wireguardautotunnel/client/data/model/Theme.kt @@ -0,0 +1,10 @@ +package com.zaneschepke.wireguardautotunnel.client.data.model + +enum class Theme { + AUTOMATIC, + LIGHT, + DARK, + DARKER, + AMOLED, + DYNAMIC, +} \ No newline at end of file diff --git a/client/src/commonMain/kotlin/com/zaneschepke/wireguardautotunnel/client/data/repository/RoomAutoTunnelSettingsRepository.kt b/client/src/commonMain/kotlin/com/zaneschepke/wireguardautotunnel/client/data/repository/RoomAutoTunnelSettingsRepository.kt new file mode 100644 index 0000000..5baab72 --- /dev/null +++ b/client/src/commonMain/kotlin/com/zaneschepke/wireguardautotunnel/client/data/repository/RoomAutoTunnelSettingsRepository.kt @@ -0,0 +1,29 @@ +package com.zaneschepke.wireguardautotunnel.client.data.repository + +import com.zaneschepke.wireguardautotunnel.client.data.dao.AutoTunnelSettingsDao +import com.zaneschepke.wireguardautotunnel.client.data.mapper.toDomain +import com.zaneschepke.wireguardautotunnel.client.data.mapper.toEntity +import com.zaneschepke.wireguardautotunnel.client.domain.repository.AutoTunnelSettingsRepository +import kotlinx.coroutines.flow.Flow +import kotlinx.coroutines.flow.map +import com.zaneschepke.wireguardautotunnel.client.data.entity.AutoTunnelSettings as Entity +import com.zaneschepke.wireguardautotunnel.client.domain.model.AutoTunnelSettings as Domain + +class RoomAutoTunnelSettingsRepository(private val autoTunnelSettingsDao: AutoTunnelSettingsDao) : + AutoTunnelSettingsRepository { + override suspend fun upsert(autoTunnelSettings: Domain) { + autoTunnelSettingsDao.upsert(autoTunnelSettings.toEntity()) + } + + override val flow: Flow + get() = + autoTunnelSettingsDao.getAutoTunnelSettingsFlow().map { (it ?: Entity()).toDomain() } + + override suspend fun getAutoTunnelSettings(): Domain { + return (autoTunnelSettingsDao.getAutoTunnelSettings() ?: Entity()).toDomain() + } + + override suspend fun updateAutoTunnelEnabled(enabled: Boolean) { + autoTunnelSettingsDao.updateAutoTunnelEnabled(enabled) + } +} diff --git a/client/src/commonMain/kotlin/com/zaneschepke/wireguardautotunnel/client/data/repository/RoomDnsSettingsRepository.kt b/client/src/commonMain/kotlin/com/zaneschepke/wireguardautotunnel/client/data/repository/RoomDnsSettingsRepository.kt new file mode 100644 index 0000000..e218996 --- /dev/null +++ b/client/src/commonMain/kotlin/com/zaneschepke/wireguardautotunnel/client/data/repository/RoomDnsSettingsRepository.kt @@ -0,0 +1,24 @@ +package com.zaneschepke.wireguardautotunnel.client.data.repository + +import com.zaneschepke.wireguardautotunnel.client.data.dao.DnsSettingsDao +import com.zaneschepke.wireguardautotunnel.client.data.mapper.toDomain +import com.zaneschepke.wireguardautotunnel.client.data.mapper.toEntity +import com.zaneschepke.wireguardautotunnel.client.domain.repository.DnsSettingsRepository +import kotlinx.coroutines.flow.Flow +import kotlinx.coroutines.flow.map +import com.zaneschepke.wireguardautotunnel.client.data.entity.DnsSettings as Entity +import com.zaneschepke.wireguardautotunnel.client.domain.model.DnsSettings as Domain + +class RoomDnsSettingsRepository(private val dnsSettingsDao: DnsSettingsDao) : + DnsSettingsRepository { + override suspend fun upsert(dnsSettings: Domain) { + dnsSettingsDao.upsert(dnsSettings.toEntity()) + } + + override val flow: Flow + get() = dnsSettingsDao.getDnsSettingsFlow().map { (it ?: Entity()).toDomain() } + + override suspend fun getDnsSettings(): Domain { + return (dnsSettingsDao.getDnsSettings() ?: Entity()).toDomain() + } +} diff --git a/client/src/commonMain/kotlin/com/zaneschepke/wireguardautotunnel/client/data/repository/RoomLockdownSettingsRepository.kt b/client/src/commonMain/kotlin/com/zaneschepke/wireguardautotunnel/client/data/repository/RoomLockdownSettingsRepository.kt new file mode 100644 index 0000000..31e9afc --- /dev/null +++ b/client/src/commonMain/kotlin/com/zaneschepke/wireguardautotunnel/client/data/repository/RoomLockdownSettingsRepository.kt @@ -0,0 +1,23 @@ +package com.zaneschepke.wireguardautotunnel.client.data.repository + +import com.zaneschepke.wireguardautotunnel.client.data.dao.LockdownSettingsDao +import com.zaneschepke.wireguardautotunnel.client.data.mapper.toDomain +import com.zaneschepke.wireguardautotunnel.client.data.mapper.toEntity +import com.zaneschepke.wireguardautotunnel.client.domain.repository.LockdownSettingsRepository +import kotlinx.coroutines.flow.map +import com.zaneschepke.wireguardautotunnel.client.data.entity.LockdownSettings as Entity +import com.zaneschepke.wireguardautotunnel.client.domain.model.LockdownSettings as Domain + +class RoomLockdownSettingsRepository(private val lockdownSettingsDao: LockdownSettingsDao) : + LockdownSettingsRepository { + override suspend fun upsert(lockdownSettings: Domain) { + lockdownSettingsDao.upsert(lockdownSettings.toEntity()) + } + + override val flow = + lockdownSettingsDao.getLockdownSettingsFlow().map { (it ?: Entity()).toDomain() } + + override suspend fun getLockdownSettings(): Domain { + return (lockdownSettingsDao.getLockdownSettings() ?: Entity()).toDomain() + } +} diff --git a/client/src/commonMain/kotlin/com/zaneschepke/wireguardautotunnel/client/data/repository/RoomProxySettingsRepository.kt b/client/src/commonMain/kotlin/com/zaneschepke/wireguardautotunnel/client/data/repository/RoomProxySettingsRepository.kt new file mode 100644 index 0000000..2fdeeb3 --- /dev/null +++ b/client/src/commonMain/kotlin/com/zaneschepke/wireguardautotunnel/client/data/repository/RoomProxySettingsRepository.kt @@ -0,0 +1,23 @@ +package com.zaneschepke.wireguardautotunnel.client.data.repository + +import com.zaneschepke.wireguardautotunnel.client.data.dao.ProxySettingsDao +import com.zaneschepke.wireguardautotunnel.client.data.mapper.toDomain +import com.zaneschepke.wireguardautotunnel.client.data.mapper.toEntity +import com.zaneschepke.wireguardautotunnel.client.domain.repository.ProxySettingsRepository +import kotlinx.coroutines.flow.map +import com.zaneschepke.wireguardautotunnel.client.data.entity.ProxySettings as Entity +import com.zaneschepke.wireguardautotunnel.client.domain.model.ProxySettings as Domain + +class RoomProxySettingsRepository(private val proxySettingsDao: ProxySettingsDao) : + ProxySettingsRepository { + + override suspend fun upsert(proxySettings: Domain) { + proxySettingsDao.upsert(proxySettings.toEntity()) + } + + override val flow = proxySettingsDao.getProxySettingsFlow().map { (it ?: Entity()).toDomain() } + + override suspend fun getProxySettings(): Domain { + return (proxySettingsDao.getProxySettings() ?: Entity()).toDomain() + } +} diff --git a/client/src/commonMain/kotlin/com/zaneschepke/wireguardautotunnel/client/data/repository/RoomSettingsRepository.kt b/client/src/commonMain/kotlin/com/zaneschepke/wireguardautotunnel/client/data/repository/RoomSettingsRepository.kt new file mode 100644 index 0000000..1fb168d --- /dev/null +++ b/client/src/commonMain/kotlin/com/zaneschepke/wireguardautotunnel/client/data/repository/RoomSettingsRepository.kt @@ -0,0 +1,37 @@ +package com.zaneschepke.wireguardautotunnel.client.data.repository + +import com.zaneschepke.wireguardautotunnel.client.data.dao.GeneralSettingsDao +import com.zaneschepke.wireguardautotunnel.client.data.mapper.toDomain +import com.zaneschepke.wireguardautotunnel.client.data.mapper.toEntity +import com.zaneschepke.wireguardautotunnel.client.data.model.AppMode +import com.zaneschepke.wireguardautotunnel.client.data.model.Theme +import com.zaneschepke.wireguardautotunnel.client.domain.model.GeneralSettings as Domain +import com.zaneschepke.wireguardautotunnel.client.data.entity.GeneralSettings as Entity +import com.zaneschepke.wireguardautotunnel.client.domain.repository.GeneralSettingRepository +import kotlinx.coroutines.flow.map + +class RoomSettingsRepository(private val settingsDao: GeneralSettingsDao) : + GeneralSettingRepository { + + override suspend fun upsert(generalSettings: Domain) { + settingsDao.upsert(generalSettings.toEntity()) + } + + override val flow = settingsDao.getGeneralSettingsFlow().map { (it ?: Entity()).toDomain() } + + override suspend fun getGeneralSettings(): Domain { + return (settingsDao.getGeneralSettings() ?: Entity()).toDomain() + } + + override suspend fun updateTheme(theme: Theme) { + settingsDao.updateTheme(theme.name) + } + + override suspend fun updateLocale(locale: String) { + settingsDao.updateLocale(locale) + } + + override suspend fun updateAppMode(appMode: AppMode) { + settingsDao.updateAppMode(appMode) + } +} diff --git a/client/src/commonMain/kotlin/com/zaneschepke/wireguardautotunnel/client/data/repository/RoomTunnelRepository.kt b/client/src/commonMain/kotlin/com/zaneschepke/wireguardautotunnel/client/data/repository/RoomTunnelRepository.kt new file mode 100644 index 0000000..4d1fcc6 --- /dev/null +++ b/client/src/commonMain/kotlin/com/zaneschepke/wireguardautotunnel/client/data/repository/RoomTunnelRepository.kt @@ -0,0 +1,97 @@ +package com.zaneschepke.wireguardautotunnel.client.data.repository + +import com.zaneschepke.wireguardautotunnel.client.data.dao.TunnelConfigDao +import com.zaneschepke.wireguardautotunnel.client.data.mapper.toDomain +import com.zaneschepke.wireguardautotunnel.client.data.mapper.toEntity +import com.zaneschepke.wireguardautotunnel.client.domain.repository.TunnelRepository +import kotlinx.coroutines.flow.Flow +import kotlinx.coroutines.flow.map +import com.zaneschepke.wireguardautotunnel.client.domain.model.TunnelConfig as Domain + +class RoomTunnelRepository(private val tunnelConfigDao: TunnelConfigDao) : TunnelRepository { + + override val flow = + tunnelConfigDao.getAllFlow().map { it.map { tunnelConfig -> tunnelConfig.toDomain() } } + + override val userTunnelsFlow = + tunnelConfigDao.getAllTunnelsExceptGlobal().map { + it.map { tunnelConfig -> tunnelConfig.toDomain() } + } + + override val globalTunnelFlow: Flow = + tunnelConfigDao.getGlobalTunnel().map { it?.toDomain() } + + override suspend fun getAll(): List { + return tunnelConfigDao.getAll().map { it.toDomain() } + } + + override suspend fun save(tunnelConfig: Domain) { + tunnelConfigDao.upsert(tunnelConfig.toEntity()) + } + + override suspend fun saveAll(tunnelConfigList: List) { + tunnelConfigDao.saveAll(tunnelConfigList.map { tunnelConfig -> tunnelConfig.toEntity() }) + } + + override suspend fun updatePrimaryTunnel(tunnelConfig: Domain?) { + tunnelConfigDao.resetPrimaryTunnel() + tunnelConfig?.let { save(it.copy(isPrimaryTunnel = true)) } + } + + override suspend fun resetActiveTunnels() { + tunnelConfigDao.resetActiveTunnels() + } + + override suspend fun updateEthernetTunnel(tunnelConfig: Domain?) { + tunnelConfigDao.resetEthernetTunnel() + tunnelConfig?.let { save(it.copy(isEthernetTunnel = true)) } + } + + override suspend fun delete(tunnelConfig: Domain) { + tunnelConfigDao.delete(tunnelConfig.toEntity()) + } + + override suspend fun deleteByName(name: String) { + tunnelConfigDao.deleteByName(name) + } + + override suspend fun getById(id: Int): Domain? { + return tunnelConfigDao.getById(id.toLong())?.toDomain() + } + + override suspend fun getActive(): List { + return tunnelConfigDao.getActive().map { it.toDomain() } + } + + override suspend fun getDefaultTunnel(): Domain? { + return tunnelConfigDao.getDefaultTunnel()?.toDomain() + } + + override suspend fun getStartTunnel(): Domain? { + return tunnelConfigDao.getStartTunnel()?.toDomain() + } + + override suspend fun getTunnelByName(name: String): Domain? { + return tunnelConfigDao.getByName(name)?.toDomain() + } + + override suspend fun count(): Int { + return tunnelConfigDao.count().toInt() + } + + override suspend fun findByTunnelName(name: String): Domain? { + return tunnelConfigDao.getByName(name)?.toDomain() + } + + override suspend fun findByTunnelNetworksName(name: String): List { + return tunnelConfigDao.findByTunnelNetworkName(name).map { it.toDomain() } + } + + override suspend fun findPrimary(): List { + return tunnelConfigDao.findByPrimary().map { it.toDomain() } + } + + override suspend fun delete(tunnels: List) { + tunnelConfigDao.delete(tunnels.map { it.toEntity() }) + } +} diff --git a/client/src/commonMain/kotlin/com/zaneschepke/wireguardautotunnel/client/data/service/DefaultTunnelImportService.kt b/client/src/commonMain/kotlin/com/zaneschepke/wireguardautotunnel/client/data/service/DefaultTunnelImportService.kt new file mode 100644 index 0000000..ea1b731 --- /dev/null +++ b/client/src/commonMain/kotlin/com/zaneschepke/wireguardautotunnel/client/data/service/DefaultTunnelImportService.kt @@ -0,0 +1,25 @@ +package com.zaneschepke.wireguardautotunnel.client.data.service + +import com.zaneschepke.wireguardautotunnel.client.domain.model.TunnelConfig +import com.zaneschepke.wireguardautotunnel.client.domain.repository.TunnelRepository +import com.zaneschepke.wireguardautotunnel.client.domain.repository.extensions.saveTunnelsUniquely +import com.zaneschepke.wireguardautotunnel.client.service.QuickConfigMap +import com.zaneschepke.wireguardautotunnel.client.service.QuickString +import com.zaneschepke.wireguardautotunnel.client.service.TunnelImportService +import com.zaneschepke.wireguardautotunnel.client.service.TunnelName + +class DefaultTunnelImportService( + private val tunnelRepository: TunnelRepository, +) : TunnelImportService { + + override suspend fun import(config: QuickString, name: TunnelName?) { + import(mapOf(config to name)) + } + + override suspend fun import(configs: QuickConfigMap) { + val tunnelConfigs = + configs.map { (config, name) -> TunnelConfig.fromQuickString(config, name) } + val existingNames = tunnelRepository.getAll().map { it.name } + tunnelRepository.saveTunnelsUniquely(tunnelConfigs, existingNames) + } +} \ No newline at end of file diff --git a/client/src/commonMain/kotlin/com/zaneschepke/wireguardautotunnel/client/data/service/UdsDaemonHealthService.kt b/client/src/commonMain/kotlin/com/zaneschepke/wireguardautotunnel/client/data/service/UdsDaemonHealthService.kt new file mode 100644 index 0000000..d0bf95c --- /dev/null +++ b/client/src/commonMain/kotlin/com/zaneschepke/wireguardautotunnel/client/data/service/UdsDaemonHealthService.kt @@ -0,0 +1,19 @@ +package com.zaneschepke.wireguardautotunnel.client.data.service + +import com.zaneschepke.wireguardautotunnel.client.service.DaemonHealthService +import io.ktor.client.* +import io.ktor.client.request.* +import io.ktor.http.* + +class UdsDaemonHealthService( + private val client : HttpClient +) : DaemonHealthService { + override suspend fun alive(): Boolean { + return try { + client.get("/status") { + }.status.isSuccess() + } catch (_ : Exception) { + false + } + } +} \ No newline at end of file diff --git a/client/src/commonMain/kotlin/com/zaneschepke/wireguardautotunnel/client/data/service/UdsTunnelCommandService.kt b/client/src/commonMain/kotlin/com/zaneschepke/wireguardautotunnel/client/data/service/UdsTunnelCommandService.kt new file mode 100644 index 0000000..2cae94a --- /dev/null +++ b/client/src/commonMain/kotlin/com/zaneschepke/wireguardautotunnel/client/data/service/UdsTunnelCommandService.kt @@ -0,0 +1,107 @@ +package com.zaneschepke.wireguardautotunnel.client.data.service + +import com.zaneschepke.wireguardautotunnel.client.domain.repository.TunnelRepository +import com.zaneschepke.wireguardautotunnel.client.service.TunnelCommandService +import com.zaneschepke.wireguardautotunnel.core.ipc.dto.BackendMode +import com.zaneschepke.wireguardautotunnel.core.ipc.dto.BackendStatus +import com.zaneschepke.wireguardautotunnel.core.ipc.dto.StartTunnelRequest +import io.ktor.client.* +import io.ktor.client.call.* +import io.ktor.client.plugins.websocket.* +import io.ktor.client.request.* +import io.ktor.client.statement.* +import io.ktor.http.* +import io.ktor.websocket.* +import kotlinx.coroutines.Dispatchers +import kotlinx.coroutines.channels.awaitClose +import kotlinx.coroutines.flow.Flow +import kotlinx.coroutines.flow.callbackFlow +import kotlinx.coroutines.flow.flowOn +import kotlinx.serialization.json.Json +import okio.IOException + +class UdsTunnelCommandService( + private val client: HttpClient, + private val tunnelRepository: TunnelRepository +) : TunnelCommandService { + + private val json = Json { ignoreUnknownKeys = true } + + override suspend fun startTunnel(id: Int): Result = runCatching { + val tunnelConfig = tunnelRepository.getById(id) + ?: throw IOException("Tunnel $id not found") + + val request = StartTunnelRequest( + id = id, + name = tunnelConfig.name, + quickConfig = tunnelConfig.quickConfig + ) + + val response = client.post("/tunnel/start") { + setBody(json.encodeToString(request)) + contentType(ContentType.Application.Json) + } + + if (!response.status.isSuccess()) { + throw IOException("Failed to start tunnel $id: ${response.status.value} - ${response.bodyAsText()}") + } + } + + override suspend fun stopTunnel(id: Int): Result = runCatching { + val response = client.post("/tunnel/stop/$id") + + if (!response.status.isSuccess()) { + throw IOException("Failed to stop tunnel $id: ${response.status.value} - ${response.bodyAsText()}") + } + } + + override suspend fun setMode(mode: BackendMode): Result = runCatching { + val response = client.post("/tunnel/mode") { + setBody(json.encodeToString(mode)) + contentType(ContentType.Text.Plain) + } + + if (!response.status.isSuccess()) { + throw IOException("Failed to set mode: ${response.bodyAsText()}") + } + } + + override suspend fun setKillSwitch(enabled: Boolean): Result = runCatching { + val response = client.post("/tunnel/kill-switch") { + setBody(enabled.toString()) + contentType(ContentType.Text.Plain) + } + + if (!response.status.isSuccess()) { + throw IOException("Failed to set kill switch: ${response.bodyAsText()}") + } + } + + override suspend fun getStatus(): Result = runCatching { + val response = client.get("/tunnel/status") + + if (!response.status.isSuccess()) { + throw IOException("Failed to get status: ${response.status.value} - ${response.bodyAsText()}") + } + + response.body() + } + + override fun statusFlow(): Flow = callbackFlow { + val session = client.webSocketSession("/tunnel/status/stream") + + try { + for (frame in session.incoming) { + if (frame is Frame.Text) { + val dto = json.decodeFromString(frame.readText()) + trySend(dto) + } + } + } catch (e: Exception) { + close(e) + } finally { + session.close() + awaitClose() + } + }.flowOn(Dispatchers.IO) +} \ No newline at end of file diff --git a/client/src/commonMain/kotlin/com/zaneschepke/wireguardautotunnel/client/di/Qualifiers.kt b/client/src/commonMain/kotlin/com/zaneschepke/wireguardautotunnel/client/di/Qualifiers.kt new file mode 100644 index 0000000..b979692 --- /dev/null +++ b/client/src/commonMain/kotlin/com/zaneschepke/wireguardautotunnel/client/di/Qualifiers.kt @@ -0,0 +1,5 @@ +package com.zaneschepke.wireguardautotunnel.client.di + +enum class Secret { + IPC +} \ No newline at end of file diff --git a/client/src/commonMain/kotlin/com/zaneschepke/wireguardautotunnel/client/di/databaseModule.kt b/client/src/commonMain/kotlin/com/zaneschepke/wireguardautotunnel/client/di/databaseModule.kt new file mode 100644 index 0000000..16431b1 --- /dev/null +++ b/client/src/commonMain/kotlin/com/zaneschepke/wireguardautotunnel/client/di/databaseModule.kt @@ -0,0 +1,59 @@ +package com.zaneschepke.wireguardautotunnel.client.di + +import androidx.room.Room +import androidx.room.RoomDatabase +import androidx.sqlite.driver.bundled.BundledSQLiteDriver +import com.zaneschepke.wireguardautotunnel.client.data.AppDatabase +import com.zaneschepke.wireguardautotunnel.client.data.AppKeyringConverter +import com.zaneschepke.wireguardautotunnel.client.data.DatabaseCallback +import com.zaneschepke.wireguardautotunnel.client.data.DatabaseConverters +import com.zaneschepke.wireguardautotunnel.client.data.dao.* +import com.zaneschepke.wireguardautotunnel.client.data.repository.* +import com.zaneschepke.wireguardautotunnel.client.domain.repository.* +import com.zaneschepke.wireguardautotunnel.core.crypto.Crypto +import com.zaneschepke.wireguardautotunnel.keyring.Keyring +import kotlinx.coroutines.Dispatchers +import org.koin.dsl.module +import java.io.File +import javax.crypto.SecretKey + +val databaseModule = module { + single { DatabaseCallback(lazy { get() }) } + single { + val dbKey = AppDatabase.DB_SECRET_KEY + val keyring = Keyring(AppDatabase.DB_KEYRING) + val encodedSecret = keyring.get(dbKey) ?: run { + val secret = Crypto.generateRandomBase64EncodedAesKey() + keyring.put(dbKey, secret) + secret + } + Crypto.decodeKey(encodedSecret) + } + single { + val dbFileName = AppDatabase.DB_FILE_NAME + val dbDir = AppDatabase.getDatabaseDir() + dbDir.mkdirs() + val dbFile = File(dbDir, dbFileName) + Room.databaseBuilder(dbFile.absolutePath) + .setDriver(BundledSQLiteDriver()) + .fallbackToDestructiveMigration(true) + .addCallback(get()) + .addTypeConverter(DatabaseConverters()) + .addTypeConverter(AppKeyringConverter()) + .setQueryCoroutineContext(Dispatchers.IO) + .build() + } + + single { get().tunnelConfigDao() } + single { get().autoTunnelSettingsDao() } + single { get().dnsSettingsDao() } + single { get().lockdownSettingsDao() } + single { get().proxySettingsDao() } + single { get().generalSettingsDao() } + + single() { RoomTunnelRepository(get()) } + single() { RoomAutoTunnelSettingsRepository(get()) } + single() { RoomDnsSettingsRepository(get()) } + single() { RoomLockdownSettingsRepository(get()) } + single() { RoomProxySettingsRepository(get()) } +} \ No newline at end of file diff --git a/client/src/commonMain/kotlin/com/zaneschepke/wireguardautotunnel/client/di/serviceModule.kt b/client/src/commonMain/kotlin/com/zaneschepke/wireguardautotunnel/client/di/serviceModule.kt new file mode 100644 index 0000000..651f4b1 --- /dev/null +++ b/client/src/commonMain/kotlin/com/zaneschepke/wireguardautotunnel/client/di/serviceModule.kt @@ -0,0 +1,73 @@ +package com.zaneschepke.wireguardautotunnel.client.di + +import com.zaneschepke.wireguardautotunnel.client.data.service.DefaultTunnelImportService +import com.zaneschepke.wireguardautotunnel.client.data.service.UdsDaemonHealthService +import com.zaneschepke.wireguardautotunnel.client.data.service.UdsTunnelCommandService +import com.zaneschepke.wireguardautotunnel.client.service.DaemonHealthService +import com.zaneschepke.wireguardautotunnel.client.service.TunnelCommandService +import com.zaneschepke.wireguardautotunnel.client.service.TunnelImportService +import com.zaneschepke.wireguardautotunnel.core.crypto.HmacProtector +import com.zaneschepke.wireguardautotunnel.core.ipc.IPC +import com.zaneschepke.wireguardautotunnel.core.ipc.dto.SecureCommand +import io.ktor.client.* +import io.ktor.client.engine.cio.* +import io.ktor.client.plugins.* +import io.ktor.client.plugins.contentnegotiation.* +import io.ktor.client.plugins.websocket.* +import io.ktor.client.request.* +import io.ktor.http.* +import io.ktor.http.content.* +import io.ktor.serialization.kotlinx.json.* +import kotlinx.serialization.json.Json +import org.koin.dsl.module + +val serviceModule = module { + single { + // so daemon knows where to look for secret + val user = System.getProperty("user.name") + HttpClient(CIO) { + defaultRequest { + unixSocket(IPC.getDaemonSocketPath()) + } + install(ContentNegotiation) { + json(Json { + ignoreUnknownKeys = true + encodeDefaults = true + }) + } + install(WebSockets) + install("HmacSigner") { + requestPipeline.intercept(HttpRequestPipeline.Before) { + + if (subject is SecureCommand) { + return@intercept + } + + val payload = when (val body = subject) { + is String -> body + is TextContent -> body.text + else -> "" + } + + val timestamp = System.currentTimeMillis() / 1000 + val signature = HmacProtector.generateSignature( + IPC.getIPCSecret(), + timestamp, + payload + ) + + val secureCommand = SecureCommand(timestamp, signature, user, payload) + context.contentType(ContentType.Application.Json) + context.setBody(secureCommand) + + proceedWith(secureCommand) + } + } + } + } + single { UdsDaemonHealthService(get()) } + single { UdsTunnelCommandService(get(), tunnelRepository = get()) } + single { + DefaultTunnelImportService(get()) + } +} \ No newline at end of file diff --git a/client/src/commonMain/kotlin/com/zaneschepke/wireguardautotunnel/client/domain/model/AutoTunnelSettings.kt b/client/src/commonMain/kotlin/com/zaneschepke/wireguardautotunnel/client/domain/model/AutoTunnelSettings.kt new file mode 100644 index 0000000..862c91d --- /dev/null +++ b/client/src/commonMain/kotlin/com/zaneschepke/wireguardautotunnel/client/domain/model/AutoTunnelSettings.kt @@ -0,0 +1,19 @@ +package com.zaneschepke.wireguardautotunnel.client.domain.model + +import kotlinx.serialization.Serializable + +@Serializable +data class AutoTunnelSettings( + val id: Int = 0, + val isAutoTunnelEnabled: Boolean = false, + val isTunnelOnMobileDataEnabled: Boolean = false, + val trustedNetworkSSIDs: Set = emptySet(), + val isTunnelOnEthernetEnabled: Boolean = false, + val isTunnelOnWifiEnabled: Boolean = false, + val isWildcardsEnabled: Boolean = false, + val isStopOnNoInternetEnabled: Boolean = false, + val debounceDelaySeconds: Int = 3, + val isTunnelOnUnsecureEnabled: Boolean = false, + val wifiDetectionMethod: Int = 0, + val startOnBoot: Boolean = false, +) diff --git a/client/src/commonMain/kotlin/com/zaneschepke/wireguardautotunnel/client/domain/model/DnsSettings.kt b/client/src/commonMain/kotlin/com/zaneschepke/wireguardautotunnel/client/domain/model/DnsSettings.kt new file mode 100644 index 0000000..3ad5ee3 --- /dev/null +++ b/client/src/commonMain/kotlin/com/zaneschepke/wireguardautotunnel/client/domain/model/DnsSettings.kt @@ -0,0 +1,11 @@ +package com.zaneschepke.wireguardautotunnel.client.domain.model + +import kotlinx.serialization.Serializable + +@Serializable +data class DnsSettings( + val id: Int = 0, + val dnsProtocol: Int = 0, + val dnsEndpoint: String? = null, + val isGlobalTunnelDnsEnabled: Boolean = false, +) diff --git a/client/src/commonMain/kotlin/com/zaneschepke/wireguardautotunnel/client/domain/model/GeneralSettings.kt b/client/src/commonMain/kotlin/com/zaneschepke/wireguardautotunnel/client/domain/model/GeneralSettings.kt new file mode 100644 index 0000000..a6fd91c --- /dev/null +++ b/client/src/commonMain/kotlin/com/zaneschepke/wireguardautotunnel/client/domain/model/GeneralSettings.kt @@ -0,0 +1,15 @@ +package com.zaneschepke.wireguardautotunnel.client.domain.model + +import com.zaneschepke.wireguardautotunnel.client.data.model.AppMode +import com.zaneschepke.wireguardautotunnel.client.data.model.Theme +import kotlinx.serialization.Serializable + +@Serializable +data class GeneralSettings( + val id: Int = 0, + val isRestoreOnBootEnabled: Boolean = false, + val appMode: AppMode = AppMode.fromValue(0), + val theme: Theme = Theme.AUTOMATIC, + val locale: String? = null, + val alreadyDonated: Boolean = false, +) diff --git a/client/src/commonMain/kotlin/com/zaneschepke/wireguardautotunnel/client/domain/model/LockdownSettings.kt b/client/src/commonMain/kotlin/com/zaneschepke/wireguardautotunnel/client/domain/model/LockdownSettings.kt new file mode 100644 index 0000000..0b150ec --- /dev/null +++ b/client/src/commonMain/kotlin/com/zaneschepke/wireguardautotunnel/client/domain/model/LockdownSettings.kt @@ -0,0 +1,11 @@ +package com.zaneschepke.wireguardautotunnel.client.domain.model + +import kotlinx.serialization.Serializable + +@Serializable +data class LockdownSettings( + val id: Long = 0L, + val bypassLan: Boolean = false, + val metered: Boolean = false, + val dualStack: Boolean = false, +) diff --git a/client/src/commonMain/kotlin/com/zaneschepke/wireguardautotunnel/client/domain/model/ProxySettings.kt b/client/src/commonMain/kotlin/com/zaneschepke/wireguardautotunnel/client/domain/model/ProxySettings.kt new file mode 100644 index 0000000..338f6d0 --- /dev/null +++ b/client/src/commonMain/kotlin/com/zaneschepke/wireguardautotunnel/client/domain/model/ProxySettings.kt @@ -0,0 +1,19 @@ +package com.zaneschepke.wireguardautotunnel.client.domain.model + +import kotlinx.serialization.Serializable + +@Serializable +data class ProxySettings( + val id: Long = 0, + val socks5ProxyEnabled: Boolean = false, + val socks5ProxyBindAddress: String? = null, + val httpProxyEnabled: Boolean = false, + val httpProxyBindAddress: String? = null, + val proxyUsername: String? = null, + val proxyPassword: String? = null, +) { + companion object { + const val DEFAULT_SOCKS_BIND_ADDRESS = "127.0.0.1:25344" + const val DEFAULT_HTTP_BIND_ADDRESS = "127.0.0.1:25345" + } +} diff --git a/client/src/commonMain/kotlin/com/zaneschepke/wireguardautotunnel/client/domain/model/TunnelConfig.kt b/client/src/commonMain/kotlin/com/zaneschepke/wireguardautotunnel/client/domain/model/TunnelConfig.kt new file mode 100644 index 0000000..91e2b6f --- /dev/null +++ b/client/src/commonMain/kotlin/com/zaneschepke/wireguardautotunnel/client/domain/model/TunnelConfig.kt @@ -0,0 +1,109 @@ +package com.zaneschepke.wireguardautotunnel.client.domain.model + +import com.zaneschepke.wireguardautotunnel.parser.Config +import kotlinx.serialization.Serializable +import kotlin.collections.get + +@Serializable +data class TunnelConfig( + val id: Int = 0, + val name: String, + val quickConfig: String, + val tunnelNetworks: Set = setOf(), + val isPrimaryTunnel: Boolean = false, + val active: Boolean = false, + val pingTarget: String? = null, + val isEthernetTunnel: Boolean = false, + val isIpv4Preferred: Boolean = true, + val position: Int = 0, +) { + + override fun equals(other: Any?): Boolean { + if (this === other) return true + if (other !is TunnelConfig) return false + return id == other.id && + name == other.name && + quickConfig == other.quickConfig && + isPrimaryTunnel == other.isPrimaryTunnel && + isEthernetTunnel == other.isEthernetTunnel && + pingTarget == other.pingTarget && + tunnelNetworks == other.tunnelNetworks && + isIpv4Preferred == other.isIpv4Preferred + } + + override fun hashCode(): Int { + var result = id + result = 31 * result + name.hashCode() + result = 31 * result + quickConfig.hashCode() + return result + } + + fun asConfig(): Config { + return Config.parseQuickString(quickConfig) + } + + companion object { + + fun generateRandom8Digits(): String { + val digits = ('0'..'9').toList() + return (1..8).map { digits.random() }.joinToString("") + } + + private fun generateDefaultTunnelName(config: Config? = null): String { + return config?.peers[0]?.host ?: generateRandom8Digits() + } + + fun configFromQuick(quick: String): Config { + return Config.parseQuickString(quick) + } + + fun fromQuickString(quick: String, name: String? = null): TunnelConfig { + val config = configFromQuick(quick) + return tunnelConfFromConfig(config, name) + } + + private fun tunnelConfFromConfig(config: Config, name: String? = null): TunnelConfig { + return TunnelConfig( + name = name ?: generateDefaultTunnelName(config), + quickConfig = config.asQuickString(), + ) + } + private const val IPV6_ALL_NETWORKS = "::/0" + private const val IPV4_ALL_NETWORKS = "0.0.0.0/0" + val ALL_IPS = listOf(IPV4_ALL_NETWORKS, IPV6_ALL_NETWORKS) + val IPV4_PUBLIC_NETWORKS = + setOf( + "0.0.0.0/5", + "8.0.0.0/7", + "11.0.0.0/8", + "12.0.0.0/6", + "16.0.0.0/4", + "32.0.0.0/3", + "64.0.0.0/2", + "128.0.0.0/3", + "160.0.0.0/5", + "168.0.0.0/6", + "172.0.0.0/12", + "172.32.0.0/11", + "172.64.0.0/10", + "172.128.0.0/9", + "173.0.0.0/8", + "174.0.0.0/7", + "176.0.0.0/4", + "192.0.0.0/9", + "192.128.0.0/11", + "192.160.0.0/13", + "192.169.0.0/16", + "192.170.0.0/15", + "192.172.0.0/14", + "192.176.0.0/12", + "192.192.0.0/10", + "193.0.0.0/8", + "194.0.0.0/7", + "196.0.0.0/6", + "200.0.0.0/5", + "208.0.0.0/4", + ) + val LAN_BYPASS_ALLOWED_IPS = setOf(IPV6_ALL_NETWORKS) + IPV4_PUBLIC_NETWORKS + } +} diff --git a/client/src/commonMain/kotlin/com/zaneschepke/wireguardautotunnel/client/domain/repository/AutoTunnelSettingsRepository.kt b/client/src/commonMain/kotlin/com/zaneschepke/wireguardautotunnel/client/domain/repository/AutoTunnelSettingsRepository.kt new file mode 100644 index 0000000..9efe8af --- /dev/null +++ b/client/src/commonMain/kotlin/com/zaneschepke/wireguardautotunnel/client/domain/repository/AutoTunnelSettingsRepository.kt @@ -0,0 +1,14 @@ +package com.zaneschepke.wireguardautotunnel.client.domain.repository + +import com.zaneschepke.wireguardautotunnel.client.domain.model.AutoTunnelSettings +import kotlinx.coroutines.flow.Flow + +interface AutoTunnelSettingsRepository { + suspend fun upsert(autoTunnelSettings: AutoTunnelSettings) + + val flow: Flow + + suspend fun getAutoTunnelSettings(): AutoTunnelSettings + + suspend fun updateAutoTunnelEnabled(enabled: Boolean) +} diff --git a/client/src/commonMain/kotlin/com/zaneschepke/wireguardautotunnel/client/domain/repository/DnsSettingsRepository.kt b/client/src/commonMain/kotlin/com/zaneschepke/wireguardautotunnel/client/domain/repository/DnsSettingsRepository.kt new file mode 100644 index 0000000..e460768 --- /dev/null +++ b/client/src/commonMain/kotlin/com/zaneschepke/wireguardautotunnel/client/domain/repository/DnsSettingsRepository.kt @@ -0,0 +1,12 @@ +package com.zaneschepke.wireguardautotunnel.client.domain.repository + +import com.zaneschepke.wireguardautotunnel.client.domain.model.DnsSettings +import kotlinx.coroutines.flow.Flow + +interface DnsSettingsRepository { + suspend fun upsert(dnsSettings: DnsSettings) + + val flow: Flow + + suspend fun getDnsSettings(): DnsSettings +} diff --git a/client/src/commonMain/kotlin/com/zaneschepke/wireguardautotunnel/client/domain/repository/GeneralSettingRepository.kt b/client/src/commonMain/kotlin/com/zaneschepke/wireguardautotunnel/client/domain/repository/GeneralSettingRepository.kt new file mode 100644 index 0000000..28df0db --- /dev/null +++ b/client/src/commonMain/kotlin/com/zaneschepke/wireguardautotunnel/client/domain/repository/GeneralSettingRepository.kt @@ -0,0 +1,20 @@ +package com.zaneschepke.wireguardautotunnel.client.domain.repository + +import com.zaneschepke.wireguardautotunnel.client.data.model.AppMode +import com.zaneschepke.wireguardautotunnel.client.data.model.Theme +import com.zaneschepke.wireguardautotunnel.client.domain.model.GeneralSettings +import kotlinx.coroutines.flow.Flow + +interface GeneralSettingRepository { + suspend fun upsert(generalSettings: GeneralSettings) + + val flow: Flow + + suspend fun getGeneralSettings(): GeneralSettings + + suspend fun updateTheme(theme: Theme) + + suspend fun updateLocale(locale: String) + + suspend fun updateAppMode(appMode: AppMode) +} \ No newline at end of file diff --git a/client/src/commonMain/kotlin/com/zaneschepke/wireguardautotunnel/client/domain/repository/LockdownSettingsRepository.kt b/client/src/commonMain/kotlin/com/zaneschepke/wireguardautotunnel/client/domain/repository/LockdownSettingsRepository.kt new file mode 100644 index 0000000..8d21055 --- /dev/null +++ b/client/src/commonMain/kotlin/com/zaneschepke/wireguardautotunnel/client/domain/repository/LockdownSettingsRepository.kt @@ -0,0 +1,12 @@ +package com.zaneschepke.wireguardautotunnel.client.domain.repository + +import com.zaneschepke.wireguardautotunnel.client.domain.model.LockdownSettings +import kotlinx.coroutines.flow.Flow + +interface LockdownSettingsRepository { + suspend fun upsert(lockdownSettings: LockdownSettings) + + val flow: Flow + + suspend fun getLockdownSettings(): LockdownSettings +} diff --git a/client/src/commonMain/kotlin/com/zaneschepke/wireguardautotunnel/client/domain/repository/ProxySettingsRepository.kt b/client/src/commonMain/kotlin/com/zaneschepke/wireguardautotunnel/client/domain/repository/ProxySettingsRepository.kt new file mode 100644 index 0000000..d9fdb75 --- /dev/null +++ b/client/src/commonMain/kotlin/com/zaneschepke/wireguardautotunnel/client/domain/repository/ProxySettingsRepository.kt @@ -0,0 +1,12 @@ +package com.zaneschepke.wireguardautotunnel.client.domain.repository + +import com.zaneschepke.wireguardautotunnel.client.domain.model.ProxySettings +import kotlinx.coroutines.flow.Flow + +interface ProxySettingsRepository { + suspend fun upsert(proxySettings: ProxySettings) + + val flow: Flow + + suspend fun getProxySettings(): ProxySettings +} diff --git a/client/src/commonMain/kotlin/com/zaneschepke/wireguardautotunnel/client/domain/repository/TunnelRepository.kt b/client/src/commonMain/kotlin/com/zaneschepke/wireguardautotunnel/client/domain/repository/TunnelRepository.kt new file mode 100644 index 0000000..ee5baf9 --- /dev/null +++ b/client/src/commonMain/kotlin/com/zaneschepke/wireguardautotunnel/client/domain/repository/TunnelRepository.kt @@ -0,0 +1,48 @@ +package com.zaneschepke.wireguardautotunnel.client.domain.repository + +import com.zaneschepke.wireguardautotunnel.client.domain.model.TunnelConfig +import kotlinx.coroutines.flow.Flow + +interface TunnelRepository { + val flow: Flow> + + val userTunnelsFlow: Flow> + + val globalTunnelFlow: Flow + + suspend fun getAll(): List + + suspend fun save(tunnelConfig: TunnelConfig) + + suspend fun saveAll(tunnelConfigList: List) + + suspend fun updatePrimaryTunnel(tunnelConfig: TunnelConfig?) + + suspend fun resetActiveTunnels() + + suspend fun updateEthernetTunnel(tunnelConfig: TunnelConfig?) + + suspend fun delete(tunnelConfig: TunnelConfig) + + suspend fun deleteByName(name: String) + + suspend fun getById(id: Int): TunnelConfig? + + suspend fun getActive(): List + + suspend fun getDefaultTunnel(): TunnelConfig? + + suspend fun getStartTunnel(): TunnelConfig? + + suspend fun getTunnelByName(name: String): TunnelConfig? + + suspend fun count(): Int + + suspend fun findByTunnelName(name: String): TunnelConfig? + + suspend fun findByTunnelNetworksName(name: String): List + + suspend fun findPrimary(): List + + suspend fun delete(tunnels: List) +} diff --git a/client/src/commonMain/kotlin/com/zaneschepke/wireguardautotunnel/client/domain/repository/extensions/TunnelRepository.kt b/client/src/commonMain/kotlin/com/zaneschepke/wireguardautotunnel/client/domain/repository/extensions/TunnelRepository.kt new file mode 100644 index 0000000..471f3a8 --- /dev/null +++ b/client/src/commonMain/kotlin/com/zaneschepke/wireguardautotunnel/client/domain/repository/extensions/TunnelRepository.kt @@ -0,0 +1,47 @@ +package com.zaneschepke.wireguardautotunnel.client.domain.repository.extensions + +import com.zaneschepke.wireguardautotunnel.client.domain.model.TunnelConfig +import com.zaneschepke.wireguardautotunnel.client.domain.repository.TunnelRepository + +suspend fun TunnelRepository.saveTunnelsUniquely( + tunnels: List, + existingNames: List, +) { + val uniqueTunnels = + generateUniquelyNamedConfigs( + tunnels, + existingNames + ) + saveAll(uniqueTunnels) +} + +private fun generateUniquelyNamedConfigs( + incoming: List, + existingNames: List, +): List { + val usedNames = existingNames.toMutableSet() + val result = mutableListOf() + val regex = Regex("(.+)\\s*\\((\\d+)\\)$") + + for (tun in incoming) { + var baseName = tun.name + var uniqueName = tun.name + var counter = 1 + + val matchResult = regex.find(baseName) + if (matchResult != null) { + baseName = matchResult.groupValues[1].trimEnd() + counter = matchResult.groupValues[2].toIntOrNull()?.plus(1) ?: 1 + uniqueName = "$baseName ($counter)" + } + + while (uniqueName in usedNames) { + uniqueName = "$baseName ($counter)" + counter++ + } + + usedNames.add(uniqueName) + result.add(tun.copy(name = uniqueName)) + } + return result +} diff --git a/client/src/commonMain/kotlin/com/zaneschepke/wireguardautotunnel/client/service/DaemonHealthService.kt b/client/src/commonMain/kotlin/com/zaneschepke/wireguardautotunnel/client/service/DaemonHealthService.kt new file mode 100644 index 0000000..e695a1a --- /dev/null +++ b/client/src/commonMain/kotlin/com/zaneschepke/wireguardautotunnel/client/service/DaemonHealthService.kt @@ -0,0 +1,5 @@ +package com.zaneschepke.wireguardautotunnel.client.service + +interface DaemonHealthService { + suspend fun alive(): Boolean +} \ No newline at end of file diff --git a/client/src/commonMain/kotlin/com/zaneschepke/wireguardautotunnel/client/service/TunnelCommandService.kt b/client/src/commonMain/kotlin/com/zaneschepke/wireguardautotunnel/client/service/TunnelCommandService.kt new file mode 100644 index 0000000..cfeed16 --- /dev/null +++ b/client/src/commonMain/kotlin/com/zaneschepke/wireguardautotunnel/client/service/TunnelCommandService.kt @@ -0,0 +1,14 @@ +package com.zaneschepke.wireguardautotunnel.client.service + +import com.zaneschepke.wireguardautotunnel.core.ipc.dto.BackendMode +import com.zaneschepke.wireguardautotunnel.core.ipc.dto.BackendStatus +import kotlinx.coroutines.flow.Flow + +interface TunnelCommandService { + suspend fun startTunnel(id: Int): Result + suspend fun stopTunnel(id: Int): Result + suspend fun setMode(mode: BackendMode): Result + suspend fun setKillSwitch(enabled: Boolean): Result + suspend fun getStatus(): Result + fun statusFlow(): Flow +} \ No newline at end of file diff --git a/client/src/commonMain/kotlin/com/zaneschepke/wireguardautotunnel/client/service/TunnelImportService.kt b/client/src/commonMain/kotlin/com/zaneschepke/wireguardautotunnel/client/service/TunnelImportService.kt new file mode 100644 index 0000000..fc44cbb --- /dev/null +++ b/client/src/commonMain/kotlin/com/zaneschepke/wireguardautotunnel/client/service/TunnelImportService.kt @@ -0,0 +1,11 @@ +package com.zaneschepke.wireguardautotunnel.client.service + + +typealias QuickString = String +typealias TunnelName = String +typealias QuickConfigMap = Map + +interface TunnelImportService { + suspend fun import(config: QuickString, name: TunnelName? = null) + suspend fun import(configs: QuickConfigMap) +} \ No newline at end of file diff --git a/client/src/commonMain/moko-resources/base/strings.xml b/client/src/commonMain/moko-resources/base/strings.xml new file mode 100644 index 0000000..b96ca50 --- /dev/null +++ b/client/src/commonMain/moko-resources/base/strings.xml @@ -0,0 +1,4 @@ + + + WG Tunnel + \ No newline at end of file diff --git a/composeApp/build.gradle.kts b/composeApp/build.gradle.kts new file mode 100644 index 0000000..90bff1b --- /dev/null +++ b/composeApp/build.gradle.kts @@ -0,0 +1,63 @@ +import org.jetbrains.compose.desktop.application.dsl.TargetFormat + +plugins { + alias(libs.plugins.kotlinMultiplatform) + alias(libs.plugins.jetbrainsCompose) + alias(libs.plugins.composeCompiler) + alias(libs.plugins.composeHotReload) + alias(libs.plugins.conveyor) +} + +group = "com.zaneschepke.wireguardautotunnel" +version = libs.versions.app.get() + +kotlin { + jvm() + + sourceSets { + commonMain.dependencies { + implementation(project(":client")) + implementation(libs.compose.runtime) + implementation(libs.compose.foundation) + implementation(libs.compose.material3) + implementation(libs.compose.ui) + implementation(libs.compose.components.resources) + implementation(libs.compose.uiToolingPreview) + implementation(libs.androidx.lifecycle.viewmodelCompose) + implementation(libs.androidx.lifecycle.runtimeCompose) + } + commonTest.dependencies { + implementation(libs.kotlin.test) + } + jvmMain.dependencies { + implementation(compose.desktop.currentOs) + implementation(libs.kotlinx.coroutinesSwing) + } + } +} + + +compose.desktop { + application { + mainClass = "com.zaneschepke.wireguardautotunnel.desktop.MainKt" + + nativeDistributions { + targetFormats(TargetFormat.Dmg, TargetFormat.Msi, TargetFormat.Deb, TargetFormat.Rpm, TargetFormat.AppImage) + packageName = "com.zaneschepke.wireguardautotunnel.desktop" + packageVersion = libs.versions.app.get() + } + } +} + +// Conveyor +dependencies { + linuxAmd64(libs.desktop.jvm.linux.x64) + macAmd64(libs.desktop.jvm.macos.x64) + macAarch64(libs.desktop.jvm.macos.arm64) + windowsAmd64(libs.desktop.jvm.windows.x64) + windowsAarch64(libs.desktop.jvm.windows.arm64) +} + +tasks.named("clean") { + delete(file("generated.conveyor.conf")) +} diff --git a/composeApp/src/jvmMain/composeResources/drawable/compose-multiplatform.xml b/composeApp/src/jvmMain/composeResources/drawable/compose-multiplatform.xml new file mode 100644 index 0000000..eba956a --- /dev/null +++ b/composeApp/src/jvmMain/composeResources/drawable/compose-multiplatform.xml @@ -0,0 +1,44 @@ + + + + + + + + + + + + + + + + + + + + + \ No newline at end of file diff --git a/composeApp/src/jvmMain/kotlin/com/zaneschepke/wireguardautotunnel/desktop/App.kt b/composeApp/src/jvmMain/kotlin/com/zaneschepke/wireguardautotunnel/desktop/App.kt new file mode 100644 index 0000000..da577c1 --- /dev/null +++ b/composeApp/src/jvmMain/kotlin/com/zaneschepke/wireguardautotunnel/desktop/App.kt @@ -0,0 +1,27 @@ +package com.zaneschepke.wireguardautotunnel.desktop + +import androidx.compose.foundation.background +import androidx.compose.foundation.layout.Column +import androidx.compose.foundation.layout.fillMaxSize +import androidx.compose.foundation.layout.safeContentPadding +import androidx.compose.material3.MaterialTheme +import androidx.compose.runtime.Composable +import androidx.compose.ui.Alignment +import androidx.compose.ui.Modifier +import androidx.compose.ui.tooling.preview.Preview + +@Composable +@Preview +fun App() { + MaterialTheme { + Column( + modifier = Modifier + .background(MaterialTheme.colorScheme.primaryContainer) + .safeContentPadding() + .fillMaxSize(), + horizontalAlignment = Alignment.CenterHorizontally, + ) { + + } + } +} \ No newline at end of file diff --git a/composeApp/src/jvmMain/kotlin/com/zaneschepke/wireguardautotunnel/desktop/Main.kt b/composeApp/src/jvmMain/kotlin/com/zaneschepke/wireguardautotunnel/desktop/Main.kt new file mode 100644 index 0000000..91427dd --- /dev/null +++ b/composeApp/src/jvmMain/kotlin/com/zaneschepke/wireguardautotunnel/desktop/Main.kt @@ -0,0 +1,15 @@ +package com.zaneschepke.wireguardautotunnel.desktop + +import androidx.compose.ui.window.Window +import androidx.compose.ui.window.application +import dev.icerock.moko.resources.compose.stringResource +import com.zaneschepke.wireguardautotunnel.SharedRes + +fun main() = application { + Window( + onCloseRequest = ::exitApplication, + title = stringResource(SharedRes.strings.app_name) + ) { + App() + } +} \ No newline at end of file diff --git a/composeApp/src/jvmMain/kotlin/com/zaneschepke/wireguardautotunnel/desktop/Platform.kt b/composeApp/src/jvmMain/kotlin/com/zaneschepke/wireguardautotunnel/desktop/Platform.kt new file mode 100644 index 0000000..5fe7b01 --- /dev/null +++ b/composeApp/src/jvmMain/kotlin/com/zaneschepke/wireguardautotunnel/desktop/Platform.kt @@ -0,0 +1,7 @@ +package com.zaneschepke.wireguardautotunnel.desktop + +class JVMPlatform { + val name: String = "Java ${System.getProperty("java.version")}" +} + +fun getPlatform() = JVMPlatform() \ No newline at end of file diff --git a/composeApp/src/jvmTest/kotlin/com/zaneschepke/wireguardautotunnel/desktop/ComposeAppDesktopTest.kt b/composeApp/src/jvmTest/kotlin/com/zaneschepke/wireguardautotunnel/desktop/ComposeAppDesktopTest.kt new file mode 100644 index 0000000..da7f94d --- /dev/null +++ b/composeApp/src/jvmTest/kotlin/com/zaneschepke/wireguardautotunnel/desktop/ComposeAppDesktopTest.kt @@ -0,0 +1,12 @@ +package com.zaneschepke.wireguardautotunnel.desktop + +import kotlin.test.Test +import kotlin.test.assertEquals + +class ComposeAppDesktopTest { + + @Test + fun example() { + assertEquals(3, 1 + 2) + } +} \ No newline at end of file diff --git a/conveyor.conf b/conveyor.conf new file mode 100644 index 0000000..a681ea8 --- /dev/null +++ b/conveyor.conf @@ -0,0 +1,204 @@ +include required("https://raw.githubusercontent.com/hydraulic-software/conveyor/master/configs/jvm/extract-native-libraries.conf") +include required("composeApp/generated.conveyor.conf") + +app { + fsname = wgtunnel + display-name = "WG Tunnel" + description = "WG Tunnel: WireGuard and AmneziaWG VPN client with auto-tunneling, lockdown and proxying." + license = MIT + homepage = "https://wgtunnel.com" + + site.base-url = "http://localhost" + + icons = ["icon.png"] + + jvm { + # for performance + options += "-XX:+UseG1GC" + options += "-XX:+UseStringDeduplication" + + # for high-res displays + system-properties { + "sun.java2d.uiScale" = "1.0" + "apple.laf.useScreenMenuBar" = "true" + } + + modules = [ detect ] + + gui { + main-class = com.zaneschepke.wireguardautotunnel.desktop.MainKt + } + + cli { + wgtctl { + main-class = com.zaneschepke.wireguardautotunnel.cli.MainKt + exe-name = wgtctl + } + daemon { + main-class = com.zaneschepke.wireguardautotunnel.daemon.MainKt + console = false + } + } + } + + inputs += "composeApp/build/libs/*.jar" + inputs += "daemon/build/install/daemon/lib/*.jar" + inputs += "cli/build/install/cli/lib/*.jar" + + // Target platforms + machines = [ + linux.amd64.glibc, + windows.amd64, +// windows.aarch64, + mac.amd64, + mac.aarch64 + ] + + linux { + deb.depends = ["systemd"] + rpm.requires = ["systemd"] + + desktop-file { + "Desktop Entry" { + Categories = "Network;Security;Settings;Utility;" + } + } + + # for CLI + symlinks = [ + /usr/bin/wgtunnel -> ${app.linux.install-path}/bin/wgtunnel, + /usr/bin/wgtctl -> ${app.linux.install-path}/bin/wgtctl, + /usr/bin/wgt -> ${app.linux.install-path}/bin/wgtctl + ] + + services { + daemon { + include "/stdlib/linux/service.conf" + + file-name = "wgtunnel-daemon.service" + + Unit { + Description = "WG Tunnel Daemon" + Documentation = "https://wgtunnel.com" + Before = "network-online.target" + After = "NetworkManager.service systemd-resolved.service" + StartLimitBurst = 5 + StartLimitIntervalSec = 20 + } + + Service { + Restart = always + RestartSec = 1s + ExecStart = ${app.linux.install-path}/bin/daemon + Type = exec + + StandardOutput = journal + StandardError = journal + + Environment = [ + "WG_TUNNEL_SERVICE=1", + "HOME=%S/wgtunnel" + ] + + WorkingDirectory = ${app.linux.install-path} + + # Allow socket access + UMask = 0000 + + ProtectSystem = full + + StateDirectory = "wgtunnel" + LogsDirectory = "wgtunnel" + ConfigurationDirectory = "wgtunnel" + RuntimeDirectory = "wgtunnel" + RuntimeDirectoryMode = 0755 + RuntimeDirectoryPreserve = "restart" + + # Added CAP_DAC_OVERRIDE for per user IPC key read + CapabilityBoundingSet = "CAP_NET_ADMIN CAP_NET_BIND_SERVICE CAP_NET_RAW CAP_DAC_OVERRIDE" + AmbientCapabilities = "CAP_NET_ADMIN CAP_NET_BIND_SERVICE CAP_NET_RAW CAP_DAC_OVERRIDE" + + RestrictAddressFamilies = "AF_INET AF_INET6 AF_NETLINK AF_UNIX" + + KillSignal = SIGTERM + TimeoutStopSec = 30 + + ReadWritePaths = [ + "/run/wgtunnel", + "/etc/resolv.conf", + "/var/lib/wgtunnel", + "/home", # Need home to be able to read user's IPC key + "/etc/resolv.conf", + "/run/systemd/resolve", + "/run/systemd/resolve/stub-resolv.conf", + "/run/systemd/resolve/resolv.conf" + ] + } + + Install { + WantedBy = "multi-user.target" + } + } + } + } + + mac { + + entitlements-plist = { + "com.apple.security.network.client" = true + "com.apple.security.network.server" = true + } + } + + windows { + + inputs += daemon/winsw/artifacts/publish/WinSW-x64.exe -> service-wrapper.exe + + aarch64 { + inputs += tunnel/tools/wintun/arm64/wintun.dll -> wintun.dll + } + amd64 { + inputs += tunnel/tools/wintun/amd64/wintun.dll -> wintun.dll + } + + manifests { + + exe { + requested-execution-level = asInvoker + } + msix { + display-name = "WG Tunnel" + description = "WireGuard and AmneziaWG VPN client with auto-tunneling, lockdown and proxying." + + min-version = "10.0.19041.0" + capabilities += "rescap:allowElevation" + capabilities += "rescap:localSystemServices" + capabilities += "rescap:packagedServices" + + namespaces { + desktop6 = "http://schemas.microsoft.com/appx/manifest/desktop/windows10/6" + uap3 = "http://schemas.microsoft.com/appx/manifest/uap/windows10/3" + } + + ignorable-namespaces += "desktop6" + ignorable-namespaces += "uap3" + + extensions-xml = """ + + + + """ + + virtualization { + excluded-directories += "LocalAppData/Temp" + excluded-directories += "CommonAppData/wgtunnel" + excluded-directories += "CommonAppData/wgtunnel/logs" + } + } + } + + start-on-login = false + updates = background + } +} +conveyor.compatibility-level = 21 \ No newline at end of file diff --git a/core/.gitignore b/core/.gitignore new file mode 100644 index 0000000..42afabf --- /dev/null +++ b/core/.gitignore @@ -0,0 +1 @@ +/build \ No newline at end of file diff --git a/core/build.gradle.kts b/core/build.gradle.kts new file mode 100644 index 0000000..b374523 --- /dev/null +++ b/core/build.gradle.kts @@ -0,0 +1,18 @@ +plugins { + kotlin("jvm") + alias(libs.plugins.serialization) +} + +dependencies { + implementation(libs.kotlinx.serialization) + implementation(libs.apache.commons.lang3) + + // Logging + implementation(libs.kermit) + implementation(libs.logback.classic) + + implementation(libs.kotlinx.coroutines.core) + + // Backoff + implementation(libs.kotlin.retry) +} \ No newline at end of file diff --git a/core/src/main/java/com/zaneschepke/wireguardautotunnel/core/crypto/Crypto.kt b/core/src/main/java/com/zaneschepke/wireguardautotunnel/core/crypto/Crypto.kt new file mode 100644 index 0000000..d75ebb1 --- /dev/null +++ b/core/src/main/java/com/zaneschepke/wireguardautotunnel/core/crypto/Crypto.kt @@ -0,0 +1,60 @@ +package com.zaneschepke.wireguardautotunnel.core.crypto + +import java.security.SecureRandom +import javax.crypto.Cipher +import javax.crypto.SecretKey +import javax.crypto.spec.GCMParameterSpec +import javax.crypto.spec.SecretKeySpec +import kotlin.io.encoding.Base64 + +object Crypto { + + const val KEY_ALGORITHM = "AES" + const val CYPHER = "AES/GCM/NoPadding" + + private val random = SecureRandom() + + fun generateRandomBase64(byteLength: Int = 32): String { + val bytes = ByteArray(byteLength) + random.nextBytes(bytes) + return Base64.encode(bytes) + } + + fun generateRandomAESKey() : SecretKey { + val keyBytes = ByteArray(32) + random.nextBytes(keyBytes) + return SecretKeySpec(keyBytes, KEY_ALGORITHM) + } + + fun generateRandomBase64EncodedAesKey() : String { + return Base64.encode(generateRandomAESKey().encoded) + } + + fun decodeKey(key: String): SecretKey { + return SecretKeySpec(Base64.decode(key), KEY_ALGORITHM) + } + + fun encryptWithMasterKey(plainText: String, key: SecretKey): String { + val cipher = Cipher.getInstance(CYPHER) + val iv = ByteArray(12) // 96-bit IV for GCM + random.nextBytes(iv) + val spec = GCMParameterSpec(128, iv) + cipher.init(Cipher.ENCRYPT_MODE, key, spec) + val cipherText = cipher.doFinal(plainText.toByteArray(Charsets.UTF_8)) + + // store IV + ciphertext together, base64-encoded + val combined = iv + cipherText + return Base64.encode(combined) + } + + fun decryptWithMasterKey(encrypted: String, key: SecretKey): String { + val combined = Base64.decode(encrypted) + val iv = combined.copyOfRange(0, 12) + val cipherText = combined.copyOfRange(12, combined.size) + val cipher = Cipher.getInstance(CYPHER) + val spec = GCMParameterSpec(128, iv) + cipher.init(Cipher.DECRYPT_MODE, key, spec) + val decrypted = cipher.doFinal(cipherText) + return String(decrypted, Charsets.UTF_8) + } +} \ No newline at end of file diff --git a/core/src/main/java/com/zaneschepke/wireguardautotunnel/core/crypto/HmacProtector.kt b/core/src/main/java/com/zaneschepke/wireguardautotunnel/core/crypto/HmacProtector.kt new file mode 100644 index 0000000..be79185 --- /dev/null +++ b/core/src/main/java/com/zaneschepke/wireguardautotunnel/core/crypto/HmacProtector.kt @@ -0,0 +1,27 @@ +package com.zaneschepke.wireguardautotunnel.core.crypto + +import com.zaneschepke.wireguardautotunnel.core.ipc.dto.SecureCommand +import javax.crypto.Mac +import javax.crypto.spec.SecretKeySpec +import kotlin.io.encoding.Base64 +import kotlin.math.abs + +object HmacProtector { + private const val ALGORITHM = "HmacSHA256" + + fun generateSignature(key: String, timestamp: Long, payload: String?): String { + val mac = Mac.getInstance(ALGORITHM) + mac.init(SecretKeySpec(key.toByteArray(), ALGORITHM)) + val dataToSign = "$timestamp${payload ?: ""}" + return Base64.encode(mac.doFinal(dataToSign.toByteArray())) + } + + fun verify(key: String, command: SecureCommand): Boolean { + val now = System.currentTimeMillis() / 1000 + // 30 seconds window to prevent replay attacks + if (abs(now - command.timestamp) > 30) return false + + val expected = generateSignature(key, command.timestamp, command.payload) + return expected == command.signature + } +} \ No newline at end of file diff --git a/core/src/main/java/com/zaneschepke/wireguardautotunnel/core/helper/PermissionsHelper.kt b/core/src/main/java/com/zaneschepke/wireguardautotunnel/core/helper/PermissionsHelper.kt new file mode 100644 index 0000000..15e8203 --- /dev/null +++ b/core/src/main/java/com/zaneschepke/wireguardautotunnel/core/helper/PermissionsHelper.kt @@ -0,0 +1,198 @@ +package com.zaneschepke.wireguardautotunnel.core.helper + +import co.touchlab.kermit.Logger +import com.github.michaelbull.retry.policy.binaryExponentialBackoff +import com.github.michaelbull.retry.policy.plus +import com.github.michaelbull.retry.policy.stopAtAttempts +import com.github.michaelbull.retry.retry +import kotlinx.coroutines.Dispatchers +import kotlinx.coroutines.withContext +import org.apache.commons.lang3.SystemUtils +import java.io.File +import java.io.FileNotFoundException +import java.nio.file.Files +import java.nio.file.Path +import java.nio.file.Paths +import java.nio.file.attribute.PosixFilePermissions + +object PermissionsHelper { + + val socketRetryPolicy = binaryExponentialBackoff(min = 10L, max = 250L) + stopAtAttempts(25) + + // unix + const val WORLD_WRITABLE_OCTAL = "666" + const val WORLD_READWRITE_SYMBOLIC = "rw-rw-rw-" + const val OWNER_FULL_CONTROL_OCTAL = "755" + const val OWNER_FULL_CONTROL_SYMBOLIC = "rwxr-xr-x" + const val OWNER_ONLY_PRIVATE_FILE = "rw-------" + const val OWNER_ONLY_PRIVATE_DIR = "rwx------" + + // windows universal SIDs + private const val SID_SYSTEM = "*S-1-5-18" + private const val SID_ADMINISTRATORS = "*S-1-5-32-544" + private const val SID_USERS = "*S-1-5-32-545" + private const val SID_CREATOR_OWNER = "*S-1-3-0" + + // windows permission flags + private const val WIN_DIR_MODIFY_INHERIT = ":(OI)(CI)(M)" + private const val WIN_FULL_CONTROL_INHERIT = ":(OI)(CI)(F)" + + fun setupDirectoryPermissionsUnix(runtimeDirPath: String) { + val path = Paths.get(runtimeDirPath) + + if (Files.exists(path)) { + try { + Files.setPosixFilePermissions(path, PosixFilePermissions.fromString(OWNER_FULL_CONTROL_SYMBOLIC)) + Logger.i { "Successfully set directory permissions to " } + } catch (e: Exception) { + Logger.e { "POSIX native permissions failed: ${e.message} → falling back to chmod" } + try { + val exitCode = ProcessBuilder("chmod", OWNER_FULL_CONTROL_OCTAL, runtimeDirPath) + .start() + .waitFor() + + if (exitCode == 0) { + Logger.i { "Successfully set directory permissions using chmod" } + } else { + Logger.e { "chmod failed with exit code $exitCode" } + } + } catch (chmodEx: Exception) { + Logger.e { "Failed to execute chmod: ${chmodEx.message}" } + } + } + } else { + Logger.w { "Runtime directory $runtimeDirPath not found" } + } + } + + fun setupDirectoryPermissionsWindows(runtimeDirPath: String) { + try { + val process = ProcessBuilder( + "icacls", runtimeDirPath, + "/grant", "$SID_USERS$WIN_DIR_MODIFY_INHERIT", + "/grant", "$SID_SYSTEM$WIN_FULL_CONTROL_INHERIT", + "/grant", "$SID_ADMINISTRATORS$WIN_FULL_CONTROL_INHERIT" + ).start() + + if (process.waitFor() != 0) { + val error = process.errorStream.bufferedReader().use { it.readText() } + Logger.e { "icacls directory setup failed: $error" } + } + } catch (e: Exception) { + Logger.e(e) { "Failed to set Windows directory ACLs" } + } + } + + suspend fun setupSocketPermissionsWithPollUnix(socketPath: String) = withContext(Dispatchers.IO) { + val socketFile = File(socketPath) + + runCatching { + retry(socketRetryPolicy) { + if (!socketFile.exists()) { + throw FileNotFoundException("Socket $socketPath not found yet") + } + setupSocketPermissionsUnix(socketPath) + } + + val socketPerms = Files.getPosixFilePermissions(Paths.get(socketPath)) + Logger.i { "Final socket permissions: $socketPerms" } + + }.onFailure { + Logger.e { "Socket $socketPath failed to appear. Daemon likely failed to start: ${it.message}" } + } + } + + suspend fun setupSocketPermissionsWithPollWindows(socketPath: String) = withContext(Dispatchers.IO) { + val socketFile = File(socketPath) + runCatching { + retry(socketRetryPolicy) { + if (!socketFile.exists()) throw FileNotFoundException("Socket not found yet") + setupDirectoryPermissionsWindows(socketPath) + } + logWindowsACLs(socketPath) + }.onFailure { + Logger.e { "Socket $socketPath failed to appear on Windows: ${it.message}" } + } + } + + + + fun setupSocketPermissionsUnix(socketPath: String) { + val path = Paths.get(socketPath) + try { + Files.setPosixFilePermissions(path, PosixFilePermissions.fromString(WORLD_READWRITE_SYMBOLIC)) + Logger.i { "Successfully set socket permissions to 0666" } + } catch (e: Exception) { + Logger.e { "POSIX native permissions failed: ${e.message} → falling back to chmod" } + + try { + val exitCode = ProcessBuilder("chmod", WORLD_WRITABLE_OCTAL, socketPath) + .start() + .waitFor() + + if (exitCode == 0) { + Logger.i { "Successfully set socket permissions using chmod" } + } else { + Logger.e { "chmod failed with exit code $exitCode" } + throw IllegalStateException("chmod exited with non-zero status") + } + } catch (chmodEx: Exception) { + Logger.e { "All POSIX methods failed: ${chmodEx.message} → using JVM fallback" } + + // try file API + val socketFile = path.toFile() + val readOk = socketFile.setReadable(true, false) + val writeOk = socketFile.setWritable(true, false) + + if (readOk && writeOk) { + Logger.w { "Applied weak Java fallback permissions (readable/writable for all)" } + } else { + Logger.e { "Failed to set any permissions on socket $socketPath" } + } + } + } + } + + fun setOwnerOnly(path: Path) { + try { + if (SystemUtils.IS_OS_WINDOWS) { + applyWindowsOwnerOnlyPermissions(path) + } else { + applyPosixOwnerOnlyPermissions(path) + } + } catch (e: Exception) { + Logger.e(e) { "Failed to set permissions for: $path" } + } + } + + private fun applyPosixOwnerOnlyPermissions(path: Path) { + val isDir = Files.isDirectory(path) + val permsString = if (isDir) OWNER_ONLY_PRIVATE_DIR else OWNER_ONLY_PRIVATE_FILE + Files.setPosixFilePermissions(path, PosixFilePermissions.fromString(permsString)) + } + + private fun applyWindowsOwnerOnlyPermissions(path: Path) { + try { + val process = ProcessBuilder( + "icacls", path.toString(), + "/inheritance:r", // remove inherited perms + "/grant:r", "$SID_SYSTEM$WIN_FULL_CONTROL_INHERIT", + "/grant:r", "$SID_CREATOR_OWNER$WIN_FULL_CONTROL_INHERIT", + "/grant:r", "$SID_ADMINISTRATORS$WIN_FULL_CONTROL_INHERIT" + ).start() + + if (process.waitFor() != 0) { + Logger.e { "icacls owner-only failed for $path" } + } + } catch (e: Exception) { + Logger.e(e) { "Error applying owner-only Windows perms" } + } + } + private fun logWindowsACLs(path: String) { + runCatching { + val output = ProcessBuilder("icacls", path).start().inputStream.bufferedReader().readText() + Logger.i { "Final ACLs for $path: $output" } + } + } + +} \ No newline at end of file diff --git a/core/src/main/java/com/zaneschepke/wireguardautotunnel/core/ipc/IPC.kt b/core/src/main/java/com/zaneschepke/wireguardautotunnel/core/ipc/IPC.kt new file mode 100644 index 0000000..79a849c --- /dev/null +++ b/core/src/main/java/com/zaneschepke/wireguardautotunnel/core/ipc/IPC.kt @@ -0,0 +1,76 @@ +package com.zaneschepke.wireguardautotunnel.core.ipc + +import co.touchlab.kermit.Logger +import com.zaneschepke.wireguardautotunnel.core.crypto.Crypto +import com.zaneschepke.wireguardautotunnel.core.helper.PermissionsHelper +import org.apache.commons.lang3.SystemUtils +import java.io.File +import java.nio.file.Files +import java.nio.file.Paths + +object IPC { + const val KEY_FILE = "ipc.key" + const val USER_FOLDER = ".wgtunnel" + const val SOCKET_FILE_NAME = "daemon.sock" + + fun resolveKeyForUser(user: String): String? { + if (!user.matches(Regex("^[a-zA-Z0-9._-]+$"))) { + Logger.w { "Invalid username format: $user" } + return null + } + + return try { + val userHome = getUserHome(user) + val keyPath = Paths.get(userHome, USER_FOLDER, KEY_FILE) + + if (Files.exists(keyPath)) { + keyPath.toFile().readText().trim().takeIf { it.isNotBlank() } + } else { + Logger.w { "IPC key not found for user: $user → $keyPath" } + null + } + } catch (e: Exception) { + Logger.Companion.e(e) { "Failed to resolve IPC key for user: $user" } + null + } + } + + // should be called by client ONLY + fun getIPCSecret() : String { + val ipcFile = File(System.getProperty("user.home"), "${IPC.USER_FOLDER}/${IPC.KEY_FILE}") + if (!ipcFile.parentFile.exists()) ipcFile.parentFile.mkdirs() + + return if (!ipcFile.exists()) { + val secret = Crypto.generateRandomBase64(32) + ipcFile.writeText(secret) + // Set 600 permissions immediately + PermissionsHelper.setOwnerOnly(ipcFile.toPath()) + secret + } else { + ipcFile.readText() + } + } + + private fun getUserHome(user: String): String { + return when { + SystemUtils.IS_OS_WINDOWS -> "C:\\Users\\$user" + SystemUtils.IS_OS_MAC_OSX -> "/Users/$user" + else -> "/home/$user" + } + } + + fun getDaemonSocketPath(): String { + return when { + SystemUtils.IS_OS_WINDOWS -> { + val baseDir = System.getenv("PROGRAMDATA") + "\\wgtunnel" + "$baseDir\\$SOCKET_FILE_NAME" + } + SystemUtils.IS_OS_MAC_OSX -> { + "/tmp/wgtunnel/$SOCKET_FILE_NAME" + } + else -> { + "/run/wgtunnel/$SOCKET_FILE_NAME" + } + } + } +} \ No newline at end of file diff --git a/core/src/main/java/com/zaneschepke/wireguardautotunnel/core/ipc/dto/BackendMode.kt b/core/src/main/java/com/zaneschepke/wireguardautotunnel/core/ipc/dto/BackendMode.kt new file mode 100644 index 0000000..e239cff --- /dev/null +++ b/core/src/main/java/com/zaneschepke/wireguardautotunnel/core/ipc/dto/BackendMode.kt @@ -0,0 +1,8 @@ +package com.zaneschepke.wireguardautotunnel.core.ipc.dto + +import kotlinx.serialization.Serializable + +@Serializable +enum class BackendMode { + KERNEL, USERSPACE, PROXY +} \ No newline at end of file diff --git a/core/src/main/java/com/zaneschepke/wireguardautotunnel/core/ipc/dto/BackendStatus.kt b/core/src/main/java/com/zaneschepke/wireguardautotunnel/core/ipc/dto/BackendStatus.kt new file mode 100644 index 0000000..6aec62f --- /dev/null +++ b/core/src/main/java/com/zaneschepke/wireguardautotunnel/core/ipc/dto/BackendStatus.kt @@ -0,0 +1,10 @@ +package com.zaneschepke.wireguardautotunnel.core.ipc.dto + +import kotlinx.serialization.Serializable + +@Serializable +data class BackendStatus( + val killSwitchEnabled: Boolean, + val mode: BackendMode, + val activeTunnels: List +) \ No newline at end of file diff --git a/core/src/main/java/com/zaneschepke/wireguardautotunnel/core/ipc/dto/SecureCommand.kt b/core/src/main/java/com/zaneschepke/wireguardautotunnel/core/ipc/dto/SecureCommand.kt new file mode 100644 index 0000000..366cd85 --- /dev/null +++ b/core/src/main/java/com/zaneschepke/wireguardautotunnel/core/ipc/dto/SecureCommand.kt @@ -0,0 +1,11 @@ +package com.zaneschepke.wireguardautotunnel.core.ipc.dto + +import kotlinx.serialization.Serializable + +@Serializable +data class SecureCommand( + val timestamp: Long, + val signature: String, + val userHint: String, + val payload: String? = null +) \ No newline at end of file diff --git a/core/src/main/java/com/zaneschepke/wireguardautotunnel/core/ipc/dto/StartTunnelRequest.kt b/core/src/main/java/com/zaneschepke/wireguardautotunnel/core/ipc/dto/StartTunnelRequest.kt new file mode 100644 index 0000000..23dac73 --- /dev/null +++ b/core/src/main/java/com/zaneschepke/wireguardautotunnel/core/ipc/dto/StartTunnelRequest.kt @@ -0,0 +1,10 @@ +package com.zaneschepke.wireguardautotunnel.core.ipc.dto + +import kotlinx.serialization.Serializable + +@Serializable +data class StartTunnelRequest( + val id: Int, + val name: String, + val quickConfig: String +) \ No newline at end of file diff --git a/core/src/main/java/com/zaneschepke/wireguardautotunnel/core/ipc/dto/TunnelState.kt b/core/src/main/java/com/zaneschepke/wireguardautotunnel/core/ipc/dto/TunnelState.kt new file mode 100644 index 0000000..d298cc9 --- /dev/null +++ b/core/src/main/java/com/zaneschepke/wireguardautotunnel/core/ipc/dto/TunnelState.kt @@ -0,0 +1,8 @@ +package com.zaneschepke.wireguardautotunnel.core.ipc.dto + +import kotlinx.serialization.Serializable + +@Serializable +enum class TunnelState { + DOWN, STARTING, HEALTHY, HANDSHAKE_FAILURE, RESOLVING_DNS, UNKNOWN +} \ No newline at end of file diff --git a/core/src/main/java/com/zaneschepke/wireguardautotunnel/core/ipc/dto/TunnelStatus.kt b/core/src/main/java/com/zaneschepke/wireguardautotunnel/core/ipc/dto/TunnelStatus.kt new file mode 100644 index 0000000..8e837b3 --- /dev/null +++ b/core/src/main/java/com/zaneschepke/wireguardautotunnel/core/ipc/dto/TunnelStatus.kt @@ -0,0 +1,10 @@ +package com.zaneschepke.wireguardautotunnel.core.ipc.dto + +import kotlinx.serialization.Serializable + +@Serializable +data class TunnelStatus( + val id: Int, + val name: String, + val state: TunnelState +) \ No newline at end of file diff --git a/daemon/.gitignore b/daemon/.gitignore new file mode 100644 index 0000000..4029495 --- /dev/null +++ b/daemon/.gitignore @@ -0,0 +1 @@ +/output \ No newline at end of file diff --git a/daemon/build.gradle.kts b/daemon/build.gradle.kts new file mode 100644 index 0000000..c974931 --- /dev/null +++ b/daemon/build.gradle.kts @@ -0,0 +1,87 @@ +plugins { + kotlin("jvm") + application + alias(libs.plugins.serialization) +} + +dependencies { + implementation(project(":tunnel")) + implementation(project(":parser")) + implementation(project(":core")) + + // DI + implementation(libs.koin.core) + + implementation(libs.bundles.ktor.server.jvm) + + implementation(libs.kotlinx.coroutines.core) + + // Logging + implementation(libs.kermit) + implementation(libs.logback.classic) + + testImplementation(kotlin("test")) + + // secure caching + implementation(libs.kstore) + implementation(libs.kstore.file) + + implementation(libs.kotlinx.serialization) + + // Util + implementation(libs.apache.commons.lang3) +} + +application { + mainClass.set("com.zaneschepke.wireguardautotunnel.daemon.MainKt") +} + +tasks.test { + useJUnitPlatform() +} + +val cleanDotNet = tasks.register("cleanDotNet") { + group = "build" + workingDir = file("winsw/src") + commandLine("dotnet", "clean", "-c", "Release") +} + +tasks.named("clean") { + dependsOn(cleanDotNet) + + delete(file("output")) + // Clean up WinSW specific artifacts + delete(file("winsw/src/WinSW/bin")) + delete(file("winsw/src/WinSW/obj")) + delete(file("winsw/artifacts")) +} + +tasks.named("installDist") { + dependsOn("buildWinSW") +} + +tasks.register("buildWinSW") { + val winSwDir = "winsw/src/WinSW" + group = "build" + description = "Build Windows service wrapper." + workingDir = file(winSwDir) + + inputs.files( + fileTree(winSwDir) { + include("**/*.cs", "**/*.csproj", "**/appsettings.json") + exclude("bin/**", "obj/**") + } + ).withPropertyName("winSwSourceFiles") + .withPathSensitivity(PathSensitivity.RELATIVE) + + outputs.dir(file("$winSwDir/bin/Release/net7.0-windows/win-x64/publish")) + .withPropertyName("winSwPublishDir") + + commandLine("dotnet", "publish", "WinSW.csproj", + "-f", "net7.0-windows", + "-c", "Release", + "-r", "win-x64", + "--self-contained", "true", + "-p:PublishSingleFile=true" + ) +} \ No newline at end of file diff --git a/daemon/src/main/kotlin/com/zaneschepke/wireguardautotunnel/daemon/Main.kt b/daemon/src/main/kotlin/com/zaneschepke/wireguardautotunnel/daemon/Main.kt new file mode 100644 index 0000000..916dbbf --- /dev/null +++ b/daemon/src/main/kotlin/com/zaneschepke/wireguardautotunnel/daemon/Main.kt @@ -0,0 +1,32 @@ +package com.zaneschepke.wireguardautotunnel.daemon + +import co.touchlab.kermit.Logger +import com.zaneschepke.wireguardautotunnel.daemon.di.daemonModule +import com.zaneschepke.wireguardautotunnel.daemon.util.initLogger +import org.koin.core.context.startKoin +import org.koin.java.KoinJavaComponent.inject +import kotlin.system.exitProcess + +fun main() { + initLogger() + startKoin { + modules(daemonModule) + } + + val daemon : TunnelDaemon by inject(TunnelDaemon::class.java) + + try { + Runtime.getRuntime().addShutdownHook( + Thread { + Logger.i { "Stopping daemon..." } + daemon.stop() + } + ) + + Logger.i { "Starting daemon..." } + daemon.run() + } catch (e: Exception) { + Logger.e(e) { "Shutting down..." } + exitProcess(1) + } +} \ No newline at end of file diff --git a/daemon/src/main/kotlin/com/zaneschepke/wireguardautotunnel/daemon/TunnelDaemon.kt b/daemon/src/main/kotlin/com/zaneschepke/wireguardautotunnel/daemon/TunnelDaemon.kt new file mode 100644 index 0000000..c6e1797 --- /dev/null +++ b/daemon/src/main/kotlin/com/zaneschepke/wireguardautotunnel/daemon/TunnelDaemon.kt @@ -0,0 +1,105 @@ +package com.zaneschepke.wireguardautotunnel.daemon + +import co.touchlab.kermit.Logger +import com.zaneschepke.wireguardautotunnel.core.helper.PermissionsHelper +import com.zaneschepke.wireguardautotunnel.daemon.data.DaemonCacheRepository +import com.zaneschepke.wireguardautotunnel.daemon.plugin.UDSPlugins +import com.zaneschepke.wireguardautotunnel.daemon.routes.tunnelCommandRoutes +import com.zaneschepke.wireguardautotunnel.tunnel.Backend +import io.ktor.http.* +import io.ktor.serialization.kotlinx.json.* +import io.ktor.server.application.* +import io.ktor.server.cio.* +import io.ktor.server.engine.* +import io.ktor.server.plugins.contentnegotiation.* +import io.ktor.server.routing.* +import io.ktor.server.websocket.* +import kotlinx.coroutines.CoroutineScope +import kotlinx.coroutines.Dispatchers +import kotlinx.coroutines.launch +import kotlinx.serialization.json.Json +import org.apache.commons.lang3.SystemUtils +import java.io.File +import java.util.concurrent.CountDownLatch +import java.util.concurrent.atomic.AtomicBoolean + +class TunnelDaemon(private val json: Json, + private val backend: Backend, + private val cacheRepository: DaemonCacheRepository, + private val socketPath: String +) { + private var server: EmbeddedServer<*, *>? = null + private val running = AtomicBoolean(false) + private val shutdownLatch = CountDownLatch(1) + private val scope = CoroutineScope(Dispatchers.IO) + + // run the daemon + internal fun run() { + startUdsServer() + shutdownLatch.await() // block main thread until stop() + } + + + fun startUdsServer() { + if (!running.compareAndSet(false, true)) return + + Logger.i { "Starting IPC server" } + + val socketFile = File(socketPath) + val runtimeDir = socketFile.parentFile + runtimeDir.mkdirs() + + when { + SystemUtils.IS_OS_WINDOWS -> PermissionsHelper.setupDirectoryPermissionsWindows(runtimeDir.absolutePath) + SystemUtils.IS_OS_UNIX -> PermissionsHelper.setupDirectoryPermissionsUnix(runtimeDir.absolutePath) + } + + socketFile.delete() // delete old socket if exists + + + server = embeddedServer(CIO, configure = { + unixConnector(socketPath) + }) { + install(ContentNegotiation) { + json(json) + } + install(WebSockets) + routing { + get("/status") { call.response.status(HttpStatusCode.OK) } + route("/tunnel") { + install(UDSPlugins.hmacShieldPlugin) + tunnelCommandRoutes(json, backend) + } + } + monitor.subscribe(ApplicationStarted) { + Logger.i { "IPC server started successfully" } + } + }.start(wait = false) + + scope.launch { + when { + SystemUtils.IS_OS_UNIX -> PermissionsHelper.setupSocketPermissionsWithPollUnix(socketPath) + SystemUtils.IS_OS_WINDOWS -> PermissionsHelper.setupSocketPermissionsWithPollWindows(socketPath) + } + } + + scope.launch { + // TODO handle startup with cached settings + val settings = cacheRepository.getKillSwitchSettings() + val startConfigs = cacheRepository.getStartConfigs() + Logger.d { "Got kill switch settings $settings" } + Logger.d { "Got start configs of size ${startConfigs.size}" } + } + } + + + fun stop() { + if (!running.compareAndSet(true, false)) return + Logger.i { "Daemon stop initiated - closing all tunnels" } + backend.shutdown() + Logger.i { "All tunnels closed - stopping server" } + server?.stop(gracePeriodMillis = 1_000, timeoutMillis = 2_000) + shutdownLatch.countDown() + Logger.i { "UDS server fully stopped" } + } +} \ No newline at end of file diff --git a/daemon/src/main/kotlin/com/zaneschepke/wireguardautotunnel/daemon/data/DaemonCacheRepository.kt b/daemon/src/main/kotlin/com/zaneschepke/wireguardautotunnel/daemon/data/DaemonCacheRepository.kt new file mode 100644 index 0000000..09e22a2 --- /dev/null +++ b/daemon/src/main/kotlin/com/zaneschepke/wireguardautotunnel/daemon/data/DaemonCacheRepository.kt @@ -0,0 +1,10 @@ +package com.zaneschepke.wireguardautotunnel.daemon.data + +import com.zaneschepke.wireguardautotunnel.daemon.data.model.KillSwitchSettings + +interface DaemonCacheRepository { + suspend fun getKillSwitchSettings(): KillSwitchSettings + suspend fun setKillSwitchSettings(settings: KillSwitchSettings) + suspend fun getStartConfigs(): Set + suspend fun setStartConfigs(configs: Set) +} \ No newline at end of file diff --git a/daemon/src/main/kotlin/com/zaneschepke/wireguardautotunnel/daemon/data/KStoreDaemonCacheRepository.kt b/daemon/src/main/kotlin/com/zaneschepke/wireguardautotunnel/daemon/data/KStoreDaemonCacheRepository.kt new file mode 100644 index 0000000..3a73c2f --- /dev/null +++ b/daemon/src/main/kotlin/com/zaneschepke/wireguardautotunnel/daemon/data/KStoreDaemonCacheRepository.kt @@ -0,0 +1,105 @@ +package com.zaneschepke.wireguardautotunnel.daemon.data + +import co.touchlab.kermit.Logger +import com.zaneschepke.wireguardautotunnel.daemon.data.model.DaemonCacheData +import com.zaneschepke.wireguardautotunnel.daemon.data.model.KillSwitchSettings +import io.github.xxfast.kstore.KStore +import io.github.xxfast.kstore.file.storeOf +import kotlinx.io.files.Path +import kotlinx.serialization.json.Json +import org.apache.commons.lang3.SystemUtils +import java.nio.file.Files +import java.nio.file.Paths +import java.nio.file.attribute.PosixFilePermissions + +class KStoreDaemonCacheRepository( + private val baseCacheDir: java.nio.file.Path = getCacheBaseDir() +) : DaemonCacheRepository { + + companion object { + const val CACHE_FILE_NAME = "cache.json" + + private fun getCacheBaseDir(): java.nio.file.Path { + return when { + SystemUtils.IS_OS_MAC_OSX -> Paths.get("/Library/Application Support/wgtunnel") + SystemUtils.IS_OS_WINDOWS -> Paths.get(System.getenv("PROGRAMDATA") + "\\wgtunnel") + else -> Paths.get("/var/lib/wgtunnel") + } + } + } + + init { + Files.createDirectories(baseCacheDir) + setSecurePermissions(baseCacheDir) + } + + private fun getStore(): KStore { + val storePathNio = baseCacheDir.resolve(CACHE_FILE_NAME) + val storeKPath = Path(storePathNio.toString()) + + if (!Files.exists(storePathNio)) { + Files.createFile(storePathNio) + } + + if (Files.size(storePathNio) == 0L) { + val defaultData = DaemonCacheData() + val defaultJson = Json.encodeToString(defaultData) + Files.writeString(storePathNio, defaultJson) + } + + setSecurePermissions(storePathNio) + + return storeOf( + file = storeKPath, + default = DaemonCacheData() + ) + } + + + private fun setSecurePermissions(path: java.nio.file.Path) { + val os = System.getProperty("os.name").lowercase() + try { + if (!os.contains("win")) { + val isDirectory = Files.isDirectory(path) + val permsString = if (isDirectory) "rwx------" else "rw-------" // 700 for dirs, 600 for files + val perms = PosixFilePermissions.fromString(permsString) + Files.setPosixFilePermissions(path, perms) + } else { + val process = ProcessBuilder( + "icacls", path.toString(), + "/inheritance:r", // remove inherited permissions + "/grant:r", "SYSTEM:(F)", // full control to system + "/grant:r", "Administrators:(F)" // full control to admin + ).start() + val exitCode = process.waitFor() + if (exitCode != 0) { + Logger.e { "icacls failed with code $exitCode"} + } + } + } catch (e: Exception) { + Logger.e(e) { "Failed to set permissions"} + } + } + + override suspend fun getKillSwitchSettings(): KillSwitchSettings { + return getStore().get()?.killSwitch ?: KillSwitchSettings(false, false) + } + + override suspend fun setKillSwitchSettings(settings: KillSwitchSettings) { + val store = getStore() + store.update { current -> + current?.copy(killSwitch = settings) ?: DaemonCacheData(killSwitch = settings) + } + } + + override suspend fun getStartConfigs(): Set { + return getStore().get()?.startConfigs ?: emptySet() + } + + override suspend fun setStartConfigs(configs: Set) { + val store = getStore() + store.update { current -> + current?.copy(startConfigs = configs) ?: DaemonCacheData(startConfigs = configs) + } + } +} \ No newline at end of file diff --git a/daemon/src/main/kotlin/com/zaneschepke/wireguardautotunnel/daemon/data/model/DaemonCacheData.kt b/daemon/src/main/kotlin/com/zaneschepke/wireguardautotunnel/daemon/data/model/DaemonCacheData.kt new file mode 100644 index 0000000..9a71922 --- /dev/null +++ b/daemon/src/main/kotlin/com/zaneschepke/wireguardautotunnel/daemon/data/model/DaemonCacheData.kt @@ -0,0 +1,9 @@ +package com.zaneschepke.wireguardautotunnel.daemon.data.model + +import kotlinx.serialization.Serializable + +@Serializable +data class DaemonCacheData( + val killSwitch: KillSwitchSettings = KillSwitchSettings(false, false), + val startConfigs: Set = emptySet() +) diff --git a/daemon/src/main/kotlin/com/zaneschepke/wireguardautotunnel/daemon/data/model/KillSwitchSettings.kt b/daemon/src/main/kotlin/com/zaneschepke/wireguardautotunnel/daemon/data/model/KillSwitchSettings.kt new file mode 100644 index 0000000..c95ea31 --- /dev/null +++ b/daemon/src/main/kotlin/com/zaneschepke/wireguardautotunnel/daemon/data/model/KillSwitchSettings.kt @@ -0,0 +1,6 @@ +package com.zaneschepke.wireguardautotunnel.daemon.data.model + +import kotlinx.serialization.Serializable + +@Serializable +data class KillSwitchSettings(val enabled: Boolean, val bypassLan: Boolean) diff --git a/daemon/src/main/kotlin/com/zaneschepke/wireguardautotunnel/daemon/di/daemonModule.kt b/daemon/src/main/kotlin/com/zaneschepke/wireguardautotunnel/daemon/di/daemonModule.kt new file mode 100644 index 0000000..21d4a6b --- /dev/null +++ b/daemon/src/main/kotlin/com/zaneschepke/wireguardautotunnel/daemon/di/daemonModule.kt @@ -0,0 +1,23 @@ +package com.zaneschepke.wireguardautotunnel.daemon.di + +import com.zaneschepke.wireguardautotunnel.core.ipc.IPC +import com.zaneschepke.wireguardautotunnel.daemon.TunnelDaemon +import com.zaneschepke.wireguardautotunnel.daemon.data.DaemonCacheRepository +import com.zaneschepke.wireguardautotunnel.daemon.data.KStoreDaemonCacheRepository +import com.zaneschepke.wireguardautotunnel.tunnel.AmneziaBackend +import com.zaneschepke.wireguardautotunnel.tunnel.Backend +import kotlinx.serialization.json.Json +import org.koin.core.module.dsl.singleOf +import org.koin.dsl.module + +val daemonModule = module { + single { + Json { + ignoreUnknownKeys = true + encodeDefaults = true + } + } + single { AmneziaBackend() } + single { KStoreDaemonCacheRepository() } + single { TunnelDaemon(get(), get(), get(), IPC.getDaemonSocketPath()) } +} \ No newline at end of file diff --git a/daemon/src/main/kotlin/com/zaneschepke/wireguardautotunnel/daemon/dto/Extensions.kt b/daemon/src/main/kotlin/com/zaneschepke/wireguardautotunnel/daemon/dto/Extensions.kt new file mode 100644 index 0000000..ad82b25 --- /dev/null +++ b/daemon/src/main/kotlin/com/zaneschepke/wireguardautotunnel/daemon/dto/Extensions.kt @@ -0,0 +1,37 @@ +package com.zaneschepke.wireguardautotunnel.daemon.dto + +import com.zaneschepke.wireguardautotunnel.core.ipc.dto.BackendMode +import com.zaneschepke.wireguardautotunnel.core.ipc.dto.BackendStatus +import com.zaneschepke.wireguardautotunnel.core.ipc.dto.TunnelState +import com.zaneschepke.wireguardautotunnel.core.ipc.dto.TunnelStatus +import com.zaneschepke.wireguardautotunnel.tunnel.Backend +import com.zaneschepke.wireguardautotunnel.tunnel.Tunnel + +fun Tunnel.State.toDto(): TunnelState = when (this) { + Tunnel.State.Down -> TunnelState.DOWN + Tunnel.State.Starting -> TunnelState.STARTING + is Tunnel.State.Up.Healthy -> TunnelState.HEALTHY + is Tunnel.State.Up.HandshakeFailure -> TunnelState.HANDSHAKE_FAILURE + is Tunnel.State.Up.ResolvingDns -> TunnelState.RESOLVING_DNS + is Tunnel.State.Up.Unknown -> TunnelState.UNKNOWN +} + +fun Backend.Mode.toDto(): BackendMode = when (this) { + Backend.Mode.Userspace -> BackendMode.USERSPACE + Backend.Mode.Proxy -> BackendMode.PROXY +} + +fun Backend.Status.toDto(): BackendStatus { + val activeList = activeTunnels.map { (tunnel, state) -> + TunnelStatus( + id = tunnel.id, + name = tunnel.name, + state = state.toDto() + ) + } + return BackendStatus( + killSwitchEnabled = killSwitchEnabled, + mode = mode.toDto(), + activeTunnels = activeList + ) +} \ No newline at end of file diff --git a/daemon/src/main/kotlin/com/zaneschepke/wireguardautotunnel/daemon/plugin/UDSPlugins.kt b/daemon/src/main/kotlin/com/zaneschepke/wireguardautotunnel/daemon/plugin/UDSPlugins.kt new file mode 100644 index 0000000..6b41bcd --- /dev/null +++ b/daemon/src/main/kotlin/com/zaneschepke/wireguardautotunnel/daemon/plugin/UDSPlugins.kt @@ -0,0 +1,44 @@ +package com.zaneschepke.wireguardautotunnel.daemon.plugin + +import co.touchlab.kermit.Logger +import com.zaneschepke.wireguardautotunnel.core.ipc.IPC +import com.zaneschepke.wireguardautotunnel.core.crypto.HmacProtector +import com.zaneschepke.wireguardautotunnel.core.ipc.dto.SecureCommand +import io.ktor.http.HttpStatusCode +import io.ktor.server.application.createRouteScopedPlugin +import io.ktor.server.request.receive +import io.ktor.server.response.respond +import io.ktor.util.AttributeKey + +object UDSPlugins { + + const val VERIFIED_PAYLOAD_KEY = "verifiedPayload" + + val hmacShieldPlugin = createRouteScopedPlugin("HmacShield") { + val payloadKey = AttributeKey(VERIFIED_PAYLOAD_KEY) + + onCall { call -> + try { + Logger.d { "Verifying request..." } + val command = call.receive() + + Logger.d { "Resolving users secret..." } + val secret = IPC.resolveKeyForUser(command.userHint) + ?: return@onCall call.respond(HttpStatusCode.Unauthorized, "Unable to resolve key for user") + + Logger.d { "Verifying users secret..." } + if (!HmacProtector.verify(secret, command)) { + Logger.e { "Invalid user secret. Unauthorized request." } + call.respond(HttpStatusCode.Unauthorized, "Invalid HMAC Handshake") + return@onCall + } + Logger.d { "Constructing verified payload..." } + command.payload?.let { call.attributes.put(payloadKey, it) } + + } catch (e: Exception) { + Logger.e(e){ "Invalid user secret. Unauthorized request." } + call.respond(HttpStatusCode.BadRequest, "Secure envelope missing or malformed: ${e.message}") + } + } + } +} \ No newline at end of file diff --git a/daemon/src/main/kotlin/com/zaneschepke/wireguardautotunnel/daemon/routes/tunnelCommandRoutes.kt b/daemon/src/main/kotlin/com/zaneschepke/wireguardautotunnel/daemon/routes/tunnelCommandRoutes.kt new file mode 100644 index 0000000..2892407 --- /dev/null +++ b/daemon/src/main/kotlin/com/zaneschepke/wireguardautotunnel/daemon/routes/tunnelCommandRoutes.kt @@ -0,0 +1,109 @@ +// 3. Updated tunnelCommandRoutes.kt (no TunnelRepository dependency) +package com.zaneschepke.wireguardautotunnel.daemon.routes + +import co.touchlab.kermit.Logger +import com.zaneschepke.wireguardautotunnel.core.ipc.dto.BackendMode +import com.zaneschepke.wireguardautotunnel.core.ipc.dto.StartTunnelRequest +import com.zaneschepke.wireguardautotunnel.daemon.dto.toDto +import com.zaneschepke.wireguardautotunnel.daemon.tunnel.RunningTunnel +import com.zaneschepke.wireguardautotunnel.daemon.util.unwrapVerifiedPayload +import com.zaneschepke.wireguardautotunnel.daemon.util.verifiedPayload +import com.zaneschepke.wireguardautotunnel.tunnel.Backend +import io.ktor.http.* +import io.ktor.server.response.* +import io.ktor.server.routing.* +import io.ktor.server.websocket.* +import io.ktor.websocket.* +import kotlinx.coroutines.Dispatchers +import kotlinx.coroutines.flow.first +import kotlinx.coroutines.flow.map +import kotlinx.coroutines.withContext +import kotlinx.serialization.json.Json +import org.koin.java.KoinJavaComponent.inject + +fun Route.tunnelCommandRoutes(json: Json, backend: Backend) { + + val logger = Logger.withTag("TunnelCommands") + + // START TUNNEL + post("/start") { + val request = call.unwrapVerifiedPayload( + json = json, + logger = logger, + logMessage = "Failed to parse start request" + ) ?: return@post + + logger.i { "Starting tunnel ${request.id} (${request.name})" } + + val tunnel = RunningTunnel(request.id, request.name) + + val result = backend.start(tunnel, request.quickConfig) + + if (result.isFailure) { + logger.e(result.exceptionOrNull()) { "Failed to start tunnel ${request.id}" } + call.respond(HttpStatusCode.InternalServerError, "Failed to start tunnel") + } else { + call.respond(HttpStatusCode.OK, "Tunnel ${request.id} started") + } + } + + // STOP TUNNEL + post("/stop/{id}") { + val id = call.parameters["id"]?.toIntOrNull() + ?: return@post call.respond(HttpStatusCode.BadRequest, "Missing or invalid id") + + logger.i { "Stopping tunnel $id" } + + backend.stop(id) + + call.respond(HttpStatusCode.OK, "Tunnel $id stopped") + } + +// post("/mode") { +// val mode = call.unwrapVerifiedPayload( +// json = json, +// logger = logger, +// logMessage = "Failed to parse mode request" +// ) ?: return@post +// +// logger.i { "Setting backend mode to $mode" } +// backend.setMode(mode) +// call.respond(HttpStatusCode.OK, "Mode set to $mode") +// } +// +// post("/kill-switch") { +// val enabledStr = call.verifiedPayload()?.trim() +// ?: return@post call.respond(HttpStatusCode.BadRequest, "Missing enabled value") +// +// val enabled = enabledStr.equals("true", ignoreCase = true) +// +// logger.i { "Setting kill switch to $enabled" } +// val result = backend.setKillSwitch(enabled) +// if (result.isFailure) { +// call.respond(HttpStatusCode.InternalServerError, "Failed to toggle kill switch") +// } else { +// call.respond(HttpStatusCode.OK, "Kill switch set to $enabled") +// } +// } +// +// get("/status") { +// val status = backend.status.first() +// call.respond(HttpStatusCode.OK, status.toDto()) +// } +// +// webSocket("/status/stream") { +// logger.i { "Client connected to /tunnel/status/stream" } +// try { +// backend.status +// .map { it.toDto() } +// .collect { dto -> +// val text = json.encodeToString(dto) +// send(Frame.Text(text)) +// } +// } catch (e: Exception) { +// logger.e(e) { "Error streaming status" } +// } finally { +// logger.i { "Client disconnected from status stream" } +// } +// } +} \ No newline at end of file diff --git a/daemon/src/main/kotlin/com/zaneschepke/wireguardautotunnel/daemon/tunnel/RunningTunnel.kt b/daemon/src/main/kotlin/com/zaneschepke/wireguardautotunnel/daemon/tunnel/RunningTunnel.kt new file mode 100644 index 0000000..9eb2697 --- /dev/null +++ b/daemon/src/main/kotlin/com/zaneschepke/wireguardautotunnel/daemon/tunnel/RunningTunnel.kt @@ -0,0 +1,15 @@ +package com.zaneschepke.wireguardautotunnel.daemon.tunnel + +import co.touchlab.kermit.Logger +import com.zaneschepke.wireguardautotunnel.tunnel.Tunnel + +class RunningTunnel( + override val id: Int, + override val name: String, + override val features: Set = emptySet() +) : Tunnel { + + override fun updateState(state: Tunnel.State) { + Logger.i { "Tunnel $id ($name) state changed → $state" } + } +} \ No newline at end of file diff --git a/daemon/src/main/kotlin/com/zaneschepke/wireguardautotunnel/daemon/util/Logger.kt b/daemon/src/main/kotlin/com/zaneschepke/wireguardautotunnel/daemon/util/Logger.kt new file mode 100644 index 0000000..9977cff --- /dev/null +++ b/daemon/src/main/kotlin/com/zaneschepke/wireguardautotunnel/daemon/util/Logger.kt @@ -0,0 +1,11 @@ +package com.zaneschepke.wireguardautotunnel.daemon.util + +import co.touchlab.kermit.Logger +import co.touchlab.kermit.Severity +import co.touchlab.kermit.platformLogWriter + +fun initLogger() { + Logger.setLogWriters(platformLogWriter()) + Logger.setMinSeverity(Severity.Debug) + Logger.setTag("Daemon") +} \ No newline at end of file diff --git a/daemon/src/main/kotlin/com/zaneschepke/wireguardautotunnel/daemon/util/UdsExtensions.kt b/daemon/src/main/kotlin/com/zaneschepke/wireguardautotunnel/daemon/util/UdsExtensions.kt new file mode 100644 index 0000000..ce99b02 --- /dev/null +++ b/daemon/src/main/kotlin/com/zaneschepke/wireguardautotunnel/daemon/util/UdsExtensions.kt @@ -0,0 +1,35 @@ +package com.zaneschepke.wireguardautotunnel.daemon.util + +import co.touchlab.kermit.Logger +import com.zaneschepke.wireguardautotunnel.daemon.plugin.UDSPlugins +import io.ktor.http.HttpStatusCode +import io.ktor.server.application.ApplicationCall +import io.ktor.server.response.respond +import io.ktor.util.AttributeKey +import kotlinx.serialization.json.Json + +inline fun ApplicationCall.verifiedPayload(json: Json): Result { + val payloadStr = attributes.getOrNull(AttributeKey(UDSPlugins.VERIFIED_PAYLOAD_KEY)) + ?: return Result.failure(IllegalArgumentException("Missing payload")) + + return try { + Result.success(json.decodeFromString(payloadStr)) + } catch (e: Exception) { + Result.failure(e) + } +} + +suspend inline fun ApplicationCall.unwrapVerifiedPayload( + json: Json, + logger: Logger, + logMessage: String = "Failed to parse payload", + errorResponseMessage: String = "Invalid JSON payload" +): T? { + val result = verifiedPayload(json) + if (result.isFailure) { + logger.e(result.exceptionOrNull()) { logMessage } + respond(HttpStatusCode.BadRequest, errorResponseMessage) + return null + } + return result.getOrThrow() +} \ No newline at end of file diff --git a/daemon/src/main/resources/macos/cli.entitlements b/daemon/src/main/resources/macos/cli.entitlements new file mode 100644 index 0000000..7266d08 --- /dev/null +++ b/daemon/src/main/resources/macos/cli.entitlements @@ -0,0 +1,10 @@ + + + + + com.apple.security.network.client + + com.apple.security.files.user-selected.read-write + + + \ No newline at end of file diff --git a/daemon/src/main/resources/macos/daemon.entitlements b/daemon/src/main/resources/macos/daemon.entitlements new file mode 100644 index 0000000..4568f32 --- /dev/null +++ b/daemon/src/main/resources/macos/daemon.entitlements @@ -0,0 +1,13 @@ + + + + + com.apple.security.network.client + + com.apple.security.network.server + + com.apple.security.files.user-selected.read-write + + + + \ No newline at end of file diff --git a/daemon/src/main/resources/macos/wgtunnel-daemon.plist b/daemon/src/main/resources/macos/wgtunnel-daemon.plist new file mode 100644 index 0000000..6e286cc --- /dev/null +++ b/daemon/src/main/resources/macos/wgtunnel-daemon.plist @@ -0,0 +1,34 @@ + + + + + Label + com.zaneschepke.wgtunnel.daemon + ProgramArguments + + /Applications/WireGuard AutoTunnel.app/Contents/MacOS/wgtunnel + _run + + RunAtLoad + + KeepAlive + + StandardOutPath + /Library/Logs/wgtunnel/stdout.log + StandardErrorPath + /Library/Logs/wgtunnel/stderr.log + EnvironmentVariables + + WG_TUNNEL_SERVICE + 1 + HOME + /Library/Application Support/wgtunnel + + WorkingDirectory + /Library/Application Support/wgtunnel + ThrottleInterval + 1 + ExitTimeOut + 30 + + \ No newline at end of file diff --git a/daemon/winsw b/daemon/winsw new file mode 160000 index 0000000..8b5db6e --- /dev/null +++ b/daemon/winsw @@ -0,0 +1 @@ +Subproject commit 8b5db6eaa050d31cda12afc6d870855f23180d37 diff --git a/gradle.properties b/gradle.properties new file mode 100644 index 0000000..5dbc22c --- /dev/null +++ b/gradle.properties @@ -0,0 +1,7 @@ +#Kotlin +kotlin.code.style=official +kotlin.daemon.jvmargs=-Xmx3072M +#Gradle +org.gradle.jvmargs=-Xmx3072M -Dfile.encoding=UTF-8 +org.gradle.configuration-cache=true +org.gradle.caching=true \ No newline at end of file diff --git a/gradle/libs.versions.toml b/gradle/libs.versions.toml new file mode 100644 index 0000000..d5ed3c8 --- /dev/null +++ b/gradle/libs.versions.toml @@ -0,0 +1,127 @@ +[versions] +app = "1.0.0" +jvm = "21" +androidx-lifecycle = "2.9.6" +composeHotReload = "1.0.0" +compose-plugin = "1.11.0-alpha01" +junit = "4.13.2" +kotlin = "2.3.0" +kotlinx-coroutines = "1.10.2" +material3 = "1.10.0-alpha05" +serialization = "1.9.0" +cryptoRand = "0.6.0" +curve25519Kotlin = "0.0.8" +koin = "4.2.0-beta2" +ktor = "3.3.3" +conveyor = "1.13" +ksp = "2.3.0" +desktopJvm = "1.11.0-alpha01" +jnaPlatform = "5.18.1" +buildconfig = "6.0.7" +retry = "2.0.2" + +moko = "0.25.2" + +# Logging +logbackClassic = "1.5.24" +kermit = "2.0.8" + +picocli = "4.7.7" + +androidx-room = "2.8.4" +androidx-sqlite = "2.6.2" +kstore = "1.0.0" + +lang3 = "3.20.0" + +[bundles] +ktor-client-jvm = ["ktor-client-core-jvm", "ktor-client-cio-jvm", "ktor-client-content-negotiation-jvm", "ktor-serialization-json-jvm", "ktor-client-okhttp", "ktor-client-websockets-jvm"] +ktor-server-jvm = ["ktor-server-cio-jvm", "ktor-serialization-json-jvm", "ktor-server-content-negotiation-jvm", "ktor-server-core-jvm", "ktor-server-websockets-jvm"] + +[libraries] +kotlin-test = { module = "org.jetbrains.kotlin:kotlin-test", version.ref = "kotlin" } +kotlin-testJunit = { module = "org.jetbrains.kotlin:kotlin-test-junit", version.ref = "kotlin" } +junit = { module = "junit:junit", version.ref = "junit" } +androidx-lifecycle-viewmodelCompose = { module = "org.jetbrains.androidx.lifecycle:lifecycle-viewmodel-compose", version.ref = "androidx-lifecycle" } +androidx-lifecycle-runtimeCompose = { module = "org.jetbrains.androidx.lifecycle:lifecycle-runtime-compose", version.ref = "androidx-lifecycle" } +compose-runtime = { module = "org.jetbrains.compose.runtime:runtime", version.ref = "compose-plugin" } +compose-foundation = { module = "org.jetbrains.compose.foundation:foundation", version.ref = "compose-plugin" } +compose-material3 = { module = "org.jetbrains.compose.material3:material3", version.ref = "material3" } +compose-ui = { module = "org.jetbrains.compose.ui:ui", version.ref = "compose-plugin" } +compose-components-resources = { module = "org.jetbrains.compose.components:components-resources", version.ref = "compose-plugin" } +compose-uiToolingPreview = { module = "org.jetbrains.compose.ui:ui-tooling-preview", version.ref = "compose-plugin" } +kotlinx-coroutinesSwing = { module = "org.jetbrains.kotlinx:kotlinx-coroutines-swing", version.ref = "kotlinx-coroutines" } + +# Serialization +kotlinx-serialization = { group = "org.jetbrains.kotlinx", name = "kotlinx-serialization-json", version.ref = "serialization" } +kotlinx-serialization-core = { group = "org.jetbrains.kotlinx", name = "kotlinx-serialization-core", version.ref = "serialization" } + +# cryto +crypto-rand = { module = "org.kotlincrypto.random:crypto-rand", version.ref = "cryptoRand" } +curve25519-kotlin = { module = "io.github.andreypfau:curve25519-kotlin", version.ref = "curve25519Kotlin" } + +# DI +koin-core = { module = "io.insert-koin:koin-core", version.ref = "koin" } + +# Targets +desktop-jvm-linux-x64 = { module = "org.jetbrains.compose.desktop:desktop-jvm-linux-x64", version.ref = "desktopJvm" } +desktop-jvm-macos-arm64 = { module = "org.jetbrains.compose.desktop:desktop-jvm-macos-arm64", version.ref = "desktopJvm" } +desktop-jvm-macos-x64 = { module = "org.jetbrains.compose.desktop:desktop-jvm-macos-x64", version.ref = "desktopJvm" } +desktop-jvm-windows-x64 = { module = "org.jetbrains.compose.desktop:desktop-jvm-windows-x64", version.ref = "desktopJvm" } +desktop-jvm-windows-arm64 = { module = "org.jetbrains.compose.desktop:desktop-jvm-windows-arm64", version.ref = "desktopJvm" } + +# Ktor Server +ktor-server-core-jvm = { module = "io.ktor:ktor-server-core-jvm", version.ref = "ktor" } +ktor-server-cio-jvm = { module = "io.ktor:ktor-server-cio-jvm", version.ref = "ktor" } +ktor-server-content-negotiation-jvm = { module = "io.ktor:ktor-server-content-negotiation-jvm", version.ref = "ktor" } +ktor-server-websockets-jvm = { module = "io.ktor:ktor-server-websockets-jvm", version.ref = "ktor" } + +# Ktor Client +ktor-client-core-jvm = { module = "io.ktor:ktor-client-core-jvm", version.ref = "ktor" } +ktor-client-cio-jvm = { module = "io.ktor:ktor-client-cio-jvm", version.ref = "ktor" } +ktor-client-content-negotiation-jvm = { module = "io.ktor:ktor-client-content-negotiation-jvm", version.ref = "ktor" } +ktor-serialization-json-jvm = { module = "io.ktor:ktor-serialization-kotlinx-json-jvm", version.ref = "ktor" } +ktor-client-okhttp = { module = "io.ktor:ktor-client-okhttp", version.ref = "ktor" } +ktor-client-darwin = { module = "io.ktor:ktor-client-darwin", version.ref = "ktor" } +ktor-client-websockets-jvm = { module = "io.ktor:ktor-client-websockets-jvm", version.ref = "ktor" } + +# Coroutines +kotlinx-coroutines-core = { module = "org.jetbrains.kotlinx:kotlinx-coroutines-core", version.ref = "kotlinx-coroutines" } + +# Logging +kermit = { module = "co.touchlab:kermit", version.ref = "kermit" } +logback-classic = { module = "ch.qos.logback:logback-classic", version.ref = "logbackClassic" } + +# CLI +picocli = { module = "info.picocli:picocli", version.ref = "picocli" } +picocli-codegen = { module = "info.picocli:picocli-codegen", version.ref = "picocli" } + +jna-platform = { module = "net.java.dev.jna:jna-platform", version.ref = "jnaPlatform" } + +# Storage +androidx-room-runtime = { module = "androidx.room:room-runtime", version.ref = "androidx-room" } +androidx-sqlite-bundled = { module = "androidx.sqlite:sqlite-bundled", version.ref = "androidx-sqlite" } +androidx-room-compiler = { module = "androidx.room:room-compiler", version.ref = "androidx-room" } +kstore = { module = "io.github.xxfast:kstore", version.ref = "kstore" } +kstore-file = { module = "io.github.xxfast:kstore-file", version.ref = "kstore" } + +# Util +apache-commons-lang3 = { module = "org.apache.commons:commons-lang3", version.ref = "lang3" } + +# Resources +moko-core = { module = "dev.icerock.moko:resources", version.ref = "moko" } +moko-compose = { module = "dev.icerock.moko:resources-compose", version.ref = "moko" } + +kotlin-retry = { module = "com.michael-bull.kotlin-retry:kotlin-retry", version.ref = "retry" } + +[plugins] +composeHotReload = { id = "org.jetbrains.compose.hot-reload", version.ref = "composeHotReload" } +jetbrainsCompose = { id = "org.jetbrains.compose", version.ref = "compose-plugin" } +composeCompiler = { id = "org.jetbrains.kotlin.plugin.compose", version.ref = "kotlin" } +kotlinMultiplatform = { id = "org.jetbrains.kotlin.multiplatform", version.ref = "kotlin" } +serialization = { id = "org.jetbrains.kotlin.plugin.serialization", version.ref = "ksp" } +conveyor = { id = "dev.hydraulic.conveyor", version.ref = "conveyor" } +room = { id = "androidx.room", version.ref = "androidx-room" } +ksp = { id = "com.google.devtools.ksp", version.ref = "ksp" } +moko = { id = "dev.icerock.mobile.multiplatform-resources", version.ref = "moko" } +buildconfig = { id = "com.github.gmazzo.buildconfig", version.ref = "buildconfig"} \ No newline at end of file diff --git a/gradle/wrapper/gradle-wrapper.jar b/gradle/wrapper/gradle-wrapper.jar new file mode 100644 index 0000000..1b33c55 Binary files /dev/null and b/gradle/wrapper/gradle-wrapper.jar differ diff --git a/gradle/wrapper/gradle-wrapper.properties b/gradle/wrapper/gradle-wrapper.properties new file mode 100644 index 0000000..d4081da --- /dev/null +++ b/gradle/wrapper/gradle-wrapper.properties @@ -0,0 +1,7 @@ +distributionBase=GRADLE_USER_HOME +distributionPath=wrapper/dists +distributionUrl=https\://services.gradle.org/distributions/gradle-8.14.3-bin.zip +networkTimeout=10000 +validateDistributionUrl=true +zipStoreBase=GRADLE_USER_HOME +zipStorePath=wrapper/dists diff --git a/gradlew b/gradlew new file mode 100755 index 0000000..23d15a9 --- /dev/null +++ b/gradlew @@ -0,0 +1,251 @@ +#!/bin/sh + +# +# Copyright © 2015-2021 the original authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# SPDX-License-Identifier: Apache-2.0 +# + +############################################################################## +# +# Gradle start up script for POSIX generated by Gradle. +# +# Important for running: +# +# (1) You need a POSIX-compliant shell to run this script. If your /bin/sh is +# noncompliant, but you have some other compliant shell such as ksh or +# bash, then to run this script, type that shell name before the whole +# command line, like: +# +# ksh Gradle +# +# Busybox and similar reduced shells will NOT work, because this script +# requires all of these POSIX shell features: +# * functions; +# * expansions «$var», «${var}», «${var:-default}», «${var+SET}», +# «${var#prefix}», «${var%suffix}», and «$( cmd )»; +# * compound commands having a testable exit status, especially «case»; +# * various built-in commands including «command», «set», and «ulimit». +# +# Important for patching: +# +# (2) This script targets any POSIX shell, so it avoids extensions provided +# by Bash, Ksh, etc; in particular arrays are avoided. +# +# The "traditional" practice of packing multiple parameters into a +# space-separated string is a well documented source of bugs and security +# problems, so this is (mostly) avoided, by progressively accumulating +# options in "$@", and eventually passing that to Java. +# +# Where the inherited environment variables (DEFAULT_JVM_OPTS, JAVA_OPTS, +# and GRADLE_OPTS) rely on word-splitting, this is performed explicitly; +# see the in-line comments for details. +# +# There are tweaks for specific operating systems such as AIX, CygWin, +# Darwin, MinGW, and NonStop. +# +# (3) This script is generated from the Groovy template +# https://github.com/gradle/gradle/blob/HEAD/platforms/jvm/plugins-application/src/main/resources/org/gradle/api/internal/plugins/unixStartScript.txt +# within the Gradle project. +# +# You can find Gradle at https://github.com/gradle/gradle/. +# +############################################################################## + +# Attempt to set APP_HOME + +# Resolve links: $0 may be a link +app_path=$0 + +# Need this for daisy-chained symlinks. +while + APP_HOME=${app_path%"${app_path##*/}"} # leaves a trailing /; empty if no leading path + [ -h "$app_path" ] +do + ls=$( ls -ld "$app_path" ) + link=${ls#*' -> '} + case $link in #( + /*) app_path=$link ;; #( + *) app_path=$APP_HOME$link ;; + esac +done + +# This is normally unused +# shellcheck disable=SC2034 +APP_BASE_NAME=${0##*/} +# Discard cd standard output in case $CDPATH is set (https://github.com/gradle/gradle/issues/25036) +APP_HOME=$( cd -P "${APP_HOME:-./}" > /dev/null && printf '%s\n' "$PWD" ) || exit + +# Use the maximum available, or set MAX_FD != -1 to use that value. +MAX_FD=maximum + +warn () { + echo "$*" +} >&2 + +die () { + echo + echo "$*" + echo + exit 1 +} >&2 + +# OS specific support (must be 'true' or 'false'). +cygwin=false +msys=false +darwin=false +nonstop=false +case "$( uname )" in #( + CYGWIN* ) cygwin=true ;; #( + Darwin* ) darwin=true ;; #( + MSYS* | MINGW* ) msys=true ;; #( + NONSTOP* ) nonstop=true ;; +esac + +CLASSPATH="\\\"\\\"" + + +# Determine the Java command to use to start the JVM. +if [ -n "$JAVA_HOME" ] ; then + if [ -x "$JAVA_HOME/jre/sh/java" ] ; then + # IBM's JDK on AIX uses strange locations for the executables + JAVACMD=$JAVA_HOME/jre/sh/java + else + JAVACMD=$JAVA_HOME/bin/java + fi + if [ ! -x "$JAVACMD" ] ; then + die "ERROR: JAVA_HOME is set to an invalid directory: $JAVA_HOME + +Please set the JAVA_HOME variable in your environment to match the +location of your Java installation." + fi +else + JAVACMD=java + if ! command -v java >/dev/null 2>&1 + then + die "ERROR: JAVA_HOME is not set and no 'java' command could be found in your PATH. + +Please set the JAVA_HOME variable in your environment to match the +location of your Java installation." + fi +fi + +# Increase the maximum file descriptors if we can. +if ! "$cygwin" && ! "$darwin" && ! "$nonstop" ; then + case $MAX_FD in #( + max*) + # In POSIX sh, ulimit -H is undefined. That's why the result is checked to see if it worked. + # shellcheck disable=SC2039,SC3045 + MAX_FD=$( ulimit -H -n ) || + warn "Could not query maximum file descriptor limit" + esac + case $MAX_FD in #( + '' | soft) :;; #( + *) + # In POSIX sh, ulimit -n is undefined. That's why the result is checked to see if it worked. + # shellcheck disable=SC2039,SC3045 + ulimit -n "$MAX_FD" || + warn "Could not set maximum file descriptor limit to $MAX_FD" + esac +fi + +# Collect all arguments for the java command, stacking in reverse order: +# * args from the command line +# * the main class name +# * -classpath +# * -D...appname settings +# * --module-path (only if needed) +# * DEFAULT_JVM_OPTS, JAVA_OPTS, and GRADLE_OPTS environment variables. + +# For Cygwin or MSYS, switch paths to Windows format before running java +if "$cygwin" || "$msys" ; then + APP_HOME=$( cygpath --path --mixed "$APP_HOME" ) + CLASSPATH=$( cygpath --path --mixed "$CLASSPATH" ) + + JAVACMD=$( cygpath --unix "$JAVACMD" ) + + # Now convert the arguments - kludge to limit ourselves to /bin/sh + for arg do + if + case $arg in #( + -*) false ;; # don't mess with options #( + /?*) t=${arg#/} t=/${t%%/*} # looks like a POSIX filepath + [ -e "$t" ] ;; #( + *) false ;; + esac + then + arg=$( cygpath --path --ignore --mixed "$arg" ) + fi + # Roll the args list around exactly as many times as the number of + # args, so each arg winds up back in the position where it started, but + # possibly modified. + # + # NB: a `for` loop captures its iteration list before it begins, so + # changing the positional parameters here affects neither the number of + # iterations, nor the values presented in `arg`. + shift # remove old arg + set -- "$@" "$arg" # push replacement arg + done +fi + + +# Add default JVM options here. You can also use JAVA_OPTS and GRADLE_OPTS to pass JVM options to this script. +DEFAULT_JVM_OPTS='"-Xmx64m" "-Xms64m"' + +# Collect all arguments for the java command: +# * DEFAULT_JVM_OPTS, JAVA_OPTS, and optsEnvironmentVar are not allowed to contain shell fragments, +# and any embedded shellness will be escaped. +# * For example: A user cannot expect ${Hostname} to be expanded, as it is an environment variable and will be +# treated as '${Hostname}' itself on the command line. + +set -- \ + "-Dorg.gradle.appname=$APP_BASE_NAME" \ + -classpath "$CLASSPATH" \ + -jar "$APP_HOME/gradle/wrapper/gradle-wrapper.jar" \ + "$@" + +# Stop when "xargs" is not available. +if ! command -v xargs >/dev/null 2>&1 +then + die "xargs is not available" +fi + +# Use "xargs" to parse quoted args. +# +# With -n1 it outputs one arg per line, with the quotes and backslashes removed. +# +# In Bash we could simply go: +# +# readarray ARGS < <( xargs -n1 <<<"$var" ) && +# set -- "${ARGS[@]}" "$@" +# +# but POSIX shell has neither arrays nor command substitution, so instead we +# post-process each arg (as a line of input to sed) to backslash-escape any +# character that might be a shell metacharacter, then use eval to reverse +# that process (while maintaining the separation between arguments), and wrap +# the whole thing up as a single "set" statement. +# +# This will of course break if any of these variables contains a newline or +# an unmatched quote. +# + +eval "set -- $( + printf '%s\n' "$DEFAULT_JVM_OPTS $JAVA_OPTS $GRADLE_OPTS" | + xargs -n1 | + sed ' s~[^-[:alnum:]+,./:=@_]~\\&~g; ' | + tr '\n' ' ' + )" '"$@"' + +exec "$JAVACMD" "$@" diff --git a/gradlew.bat b/gradlew.bat new file mode 100644 index 0000000..db3a6ac --- /dev/null +++ b/gradlew.bat @@ -0,0 +1,94 @@ +@rem +@rem Copyright 2015 the original author or authors. +@rem +@rem Licensed under the Apache License, Version 2.0 (the "License"); +@rem you may not use this file except in compliance with the License. +@rem You may obtain a copy of the License at +@rem +@rem https://www.apache.org/licenses/LICENSE-2.0 +@rem +@rem Unless required by applicable law or agreed to in writing, software +@rem distributed under the License is distributed on an "AS IS" BASIS, +@rem WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +@rem See the License for the specific language governing permissions and +@rem limitations under the License. +@rem +@rem SPDX-License-Identifier: Apache-2.0 +@rem + +@if "%DEBUG%"=="" @echo off +@rem ########################################################################## +@rem +@rem Gradle startup script for Windows +@rem +@rem ########################################################################## + +@rem Set local scope for the variables with windows NT shell +if "%OS%"=="Windows_NT" setlocal + +set DIRNAME=%~dp0 +if "%DIRNAME%"=="" set DIRNAME=. +@rem This is normally unused +set APP_BASE_NAME=%~n0 +set APP_HOME=%DIRNAME% + +@rem Resolve any "." and ".." in APP_HOME to make it shorter. +for %%i in ("%APP_HOME%") do set APP_HOME=%%~fi + +@rem Add default JVM options here. You can also use JAVA_OPTS and GRADLE_OPTS to pass JVM options to this script. +set DEFAULT_JVM_OPTS="-Xmx64m" "-Xms64m" + +@rem Find java.exe +if defined JAVA_HOME goto findJavaFromJavaHome + +set JAVA_EXE=java.exe +%JAVA_EXE% -version >NUL 2>&1 +if %ERRORLEVEL% equ 0 goto execute + +echo. 1>&2 +echo ERROR: JAVA_HOME is not set and no 'java' command could be found in your PATH. 1>&2 +echo. 1>&2 +echo Please set the JAVA_HOME variable in your environment to match the 1>&2 +echo location of your Java installation. 1>&2 + +goto fail + +:findJavaFromJavaHome +set JAVA_HOME=%JAVA_HOME:"=% +set JAVA_EXE=%JAVA_HOME%/bin/java.exe + +if exist "%JAVA_EXE%" goto execute + +echo. 1>&2 +echo ERROR: JAVA_HOME is set to an invalid directory: %JAVA_HOME% 1>&2 +echo. 1>&2 +echo Please set the JAVA_HOME variable in your environment to match the 1>&2 +echo location of your Java installation. 1>&2 + +goto fail + +:execute +@rem Setup the command line + +set CLASSPATH= + + +@rem Execute Gradle +"%JAVA_EXE%" %DEFAULT_JVM_OPTS% %JAVA_OPTS% %GRADLE_OPTS% "-Dorg.gradle.appname=%APP_BASE_NAME%" -classpath "%CLASSPATH%" -jar "%APP_HOME%\gradle\wrapper\gradle-wrapper.jar" %* + +:end +@rem End local scope for the variables with windows NT shell +if %ERRORLEVEL% equ 0 goto mainEnd + +:fail +rem Set variable GRADLE_EXIT_CONSOLE if you need the _script_ return code instead of +rem the _cmd.exe /c_ return code! +set EXIT_CODE=%ERRORLEVEL% +if %EXIT_CODE% equ 0 set EXIT_CODE=1 +if not ""=="%GRADLE_EXIT_CONSOLE%" exit %EXIT_CODE% +exit /b %EXIT_CODE% + +:mainEnd +if "%OS%"=="Windows_NT" endlocal + +:omega diff --git a/icon.png b/icon.png new file mode 100644 index 0000000..ca3b402 Binary files /dev/null and b/icon.png differ diff --git a/keyring/.gitignore b/keyring/.gitignore new file mode 100644 index 0000000..91be53f --- /dev/null +++ b/keyring/.gitignore @@ -0,0 +1,3 @@ +/build +tools/keyring-go/out +src/main/resources/* \ No newline at end of file diff --git a/keyring/build.gradle.kts b/keyring/build.gradle.kts new file mode 100644 index 0000000..98bd9d6 --- /dev/null +++ b/keyring/build.gradle.kts @@ -0,0 +1,38 @@ +plugins { + kotlin("jvm") +} + +dependencies { + implementation(libs.jna.platform) +} + +tasks.register("buildGoLibs") { + val libDir = "tools/keyring-go" + group = "build" + description = "Builds Go shared libs using Makefile" + workingDir = file(libDir) + + inputs.dir(file(libDir)) + .withPropertyName("goSourceDir") + .withPathSensitivity(PathSensitivity.RELATIVE) + + outputs.dir(file("src/main/resources")) + .withPropertyName("outputResourcesDir") + + commandLine("make", "all") +} + +tasks.named("processResources") { + dependsOn("buildGoLibs") +} + +val cleanGoLibs = tasks.register("cleanGoLibs") { + workingDir = file("tools/keyring-go") + commandLine("make", "clean") +} + +tasks.named("clean") { + dependsOn(cleanGoLibs) + delete(file("tools/keyring-go/out")) + delete(file("src/main/resources")) +} diff --git a/keyring/src/main/kotlin/com/zaneschepke/wireguardautotunnel/keyring/Keyring.kt b/keyring/src/main/kotlin/com/zaneschepke/wireguardautotunnel/keyring/Keyring.kt new file mode 100644 index 0000000..972d3fd --- /dev/null +++ b/keyring/src/main/kotlin/com/zaneschepke/wireguardautotunnel/keyring/Keyring.kt @@ -0,0 +1,29 @@ +package com.zaneschepke.wireguardautotunnel.keyring + +import com.sun.jna.Native +import com.sun.jna.Pointer + +class Keyring(private val service: String) { + + private val native = NativeKeyring.INSTANCE + + fun put(name: String, value: String) { + val result = native.storeSecret(service, name, value) + check(result == 1) { + "Failed to store secret: $name" + } + } + + fun get(name: String): String? { + val ptr: Pointer = native.getSecret(service, name) ?: return null + return try { + ptr.getString(0) + } finally { + Native.free(Pointer.nativeValue(ptr)) + } + } + + fun delete(name: String) { + native.deleteSecret(service, name) + } +} diff --git a/keyring/src/main/kotlin/com/zaneschepke/wireguardautotunnel/keyring/NativeKeyring.kt b/keyring/src/main/kotlin/com/zaneschepke/wireguardautotunnel/keyring/NativeKeyring.kt new file mode 100644 index 0000000..0956541 --- /dev/null +++ b/keyring/src/main/kotlin/com/zaneschepke/wireguardautotunnel/keyring/NativeKeyring.kt @@ -0,0 +1,29 @@ +package com.zaneschepke.wireguardautotunnel.keyring + +import com.sun.jna.Library +import com.sun.jna.Native +import com.sun.jna.Pointer + +interface NativeKeyring : Library { + + fun storeSecret( + service: String, + name: String, + value: String + ): Int + + fun getSecret( + service: String, + name: String + ): Pointer? + + fun deleteSecret( + service: String, + name: String + ): Int + + companion object { + val INSTANCE: NativeKeyring = + Native.load("keyring", NativeKeyring::class.java) + } +} diff --git a/keyring/tools/keyring-go/Makefile b/keyring/tools/keyring-go/Makefile new file mode 100644 index 0000000..04e7a59 --- /dev/null +++ b/keyring/tools/keyring-go/Makefile @@ -0,0 +1,70 @@ +# SPDX-License-Identifier: Apache-2.0 + +DESTDIR ?= $(CURDIR)/out +RESOURCEDIR ?= $(CURDIR)/../../src/main/resources +HOST_OS := $(shell uname -s | tr '[:upper:]' '[:lower:]') + +ifeq ($(HOST_OS),darwin) + PLATFORMS := darwin-amd64 darwin-arm64 linux-amd64 windows-amd64 windows-arm64 +else + PLATFORMS := linux-amd64 windows-amd64 +endif + +export CGO_ENABLED := 1 + +# Standard Go build command with -a (force rebuild) +# Added -v so you can see the progress in Gradle logs +GOBUILD := go build -a -v -buildmode=c-shared -trimpath -ldflags="-buildid=" + +default: all + +# Macro to keep the build targets clean +# Usage: $(call go-build,GOOS,GOARCH,CC,OUTPUT) +define go-build + @mkdir -p "$(DESTDIR)" + GOOS=$(1) GOARCH=$(2) CC=$(3) $(GOBUILD) -o "$(4)" . +endef + +$(DESTDIR)/libkeyring-linux-amd64.so: $(wildcard *.go) + $(call go-build,linux,amd64,gcc,$@) + +$(DESTDIR)/libkeyring-windows-amd64.dll: $(wildcard *.go) + $(call go-build,windows,amd64,x86_64-w64-mingw32-gcc,$@) + +$(DESTDIR)/libkeyring-windows-arm64.dll: $(wildcard *.go) + $(call go-build,windows,arm64,aarch64-w64-mingw32-gcc,$@) + +$(DESTDIR)/libkeyring-darwin-amd64.dylib: $(wildcard *.go) + $(call go-build,darwin,amd64,clang,$@) + +$(DESTDIR)/libkeyring-darwin-arm64.dylib: $(wildcard *.go) + $(call go-build,darwin,arm64,clang,$@) + +build_all: $(foreach plat,$(PLATFORMS),$(DESTDIR)/libkeyring-$(plat).$(if $(findstring windows,$(plat)),dll,$(if $(findstring darwin,$(plat)),dylib,so))) + +copy_to_resources: build_all + @for plat in $(PLATFORMS); do \ + os=`echo "$$plat" | cut -d- -f1`; \ + arch=`echo "$$plat" | cut -d- -f2`; \ + jna_dir=; \ + if [ "$$os" = "linux" ] && [ "$$arch" = "amd64" ]; then jna_dir=linux-x86-64; fi; \ + if [ "$$os" = "windows" ] && [ "$$arch" = "amd64" ]; then jna_dir=win32-x86-64; fi; \ + if [ "$$os" = "windows" ] && [ "$$arch" = "arm64" ]; then jna_dir=win32-aarch64; fi; \ + if [ "$$os" = "darwin" ] && [ "$$arch" = "amd64" ]; then jna_dir=darwin-x86-64; fi; \ + if [ "$$os" = "darwin" ] && [ "$$arch" = "arm64" ]; then jna_dir=darwin-aarch64; fi; \ + libext=so; \ + if [ "$$os" = "windows" ]; then libext=dll; fi; \ + if [ "$$os" = "darwin" ]; then libext=dylib; fi; \ + dest_dir="$(RESOURCEDIR)/$$jna_dir"; \ + mkdir -p "$$dest_dir"; \ + cp "$(DESTDIR)/libkeyring-$$plat.$$libext" "$$dest_dir/libkeyring.$$libext"; \ + echo "Copied $$plat -> $$dest_dir/libkeyring.$$libext"; \ + done + +all: copy_to_resources + +clean: + rm -rf "$(DESTDIR)" + +.PHONY: default all build_all copy_to_resources clean +.DELETE_ON_ERROR: \ No newline at end of file diff --git a/keyring/tools/keyring-go/go.mod b/keyring/tools/keyring-go/go.mod new file mode 100644 index 0000000..4cfa222 --- /dev/null +++ b/keyring/tools/keyring-go/go.mod @@ -0,0 +1,12 @@ +module github.com/wgtunnel/desktop/keyring + +go 1.25.5 + +require github.com/zalando/go-keyring v0.2.6 + +require ( + al.essio.dev/pkg/shellescape v1.5.1 // indirect + github.com/danieljoos/wincred v1.2.2 // indirect + github.com/godbus/dbus/v5 v5.1.0 // indirect + golang.org/x/sys v0.26.0 // indirect +) diff --git a/keyring/tools/keyring-go/go.sum b/keyring/tools/keyring-go/go.sum new file mode 100644 index 0000000..8c1e0e0 --- /dev/null +++ b/keyring/tools/keyring-go/go.sum @@ -0,0 +1,22 @@ +al.essio.dev/pkg/shellescape v1.5.1 h1:86HrALUujYS/h+GtqoB26SBEdkWfmMI6FubjXlsXyho= +al.essio.dev/pkg/shellescape v1.5.1/go.mod h1:6sIqp7X2P6mThCQ7twERpZTuigpr6KbZWtls1U8I890= +github.com/danieljoos/wincred v1.2.2 h1:774zMFJrqaeYCK2W57BgAem/MLi6mtSE47MB6BOJ0i0= +github.com/danieljoos/wincred v1.2.2/go.mod h1:w7w4Utbrz8lqeMbDAK0lkNJUv5sAOkFi7nd/ogr0Uh8= +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/godbus/dbus/v5 v5.1.0 h1:4KLkAxT3aOY8Li4FRJe/KvhoNFFxo0m6fNuFUO8QJUk= +github.com/godbus/dbus/v5 v5.1.0/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA= +github.com/google/shlex v0.0.0-20191202100458-e7afc7fbc510 h1:El6M4kTTCOh6aBiKaUGG7oYTSPP8MxqL4YI3kZKwcP4= +github.com/google/shlex v0.0.0-20191202100458-e7afc7fbc510/go.mod h1:pupxD2MaaD3pAXIBCelhxNneeOaAeabZDe5s4K6zSpQ= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/stretchr/objx v0.5.2 h1:xuMeJ0Sdp5ZMRXx/aWO6RZxdr3beISkG5/G/aIRr3pY= +github.com/stretchr/objx v0.5.2/go.mod h1:FRsXN1f5AsAjCGJKqEizvkpNtU+EGNCLh3NxZ/8L+MA= +github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg= +github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= +github.com/zalando/go-keyring v0.2.6 h1:r7Yc3+H+Ux0+M72zacZoItR3UDxeWfKTcabvkI8ua9s= +github.com/zalando/go-keyring v0.2.6/go.mod h1:2TCrxYrbUNYfNS/Kgy/LSrkSQzZ5UPVH85RwfczwvcI= +golang.org/x/sys v0.26.0 h1:KHjCJyddX0LoSTb3J+vWpupP9p0oznkqVk/IfjymZbo= +golang.org/x/sys v0.26.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/keyring/tools/keyring-go/keyring.go b/keyring/tools/keyring-go/keyring.go new file mode 100644 index 0000000..02d6dd7 --- /dev/null +++ b/keyring/tools/keyring-go/keyring.go @@ -0,0 +1,75 @@ +package main + +/* +#include +*/ +import "C" + +import ( + "errors" + + "github.com/zalando/go-keyring" +) + +//export storeSecret +func storeSecret(service *C.char, name *C.char, value *C.char) C.int { + if service == nil || name == nil || value == nil { + return C.int(-1) + } + + err := keyring.Set( + C.GoString(service), + C.GoString(name), + C.GoString(value), + ) + + if err != nil { + return C.int(-1) + } + + return C.int(1) +} + +//export getSecret +func getSecret(service *C.char, name *C.char) *C.char { + if service == nil || name == nil { + return nil + } + + value, err := keyring.Get( + C.GoString(service), + C.GoString(name), + ) + + if err != nil { + if errors.Is(err, keyring.ErrNotFound) { + return nil + } + return nil + } + + return C.CString(value) +} + +//export deleteSecret +func deleteSecret(service *C.char, name *C.char) C.int { + if service == nil || name == nil { + return C.int(-1) + } + + err := keyring.Delete( + C.GoString(service), + C.GoString(name), + ) + + if err != nil { + if errors.Is(err, keyring.ErrNotFound) { + return C.int(-1) + } + return C.int(-1) + } + + return C.int(1) +} + +func main() {} diff --git a/parser/build.gradle.kts b/parser/build.gradle.kts new file mode 100644 index 0000000..f4d3fbf --- /dev/null +++ b/parser/build.gradle.kts @@ -0,0 +1,17 @@ +plugins { + kotlin("jvm") + alias(libs.plugins.serialization) +} + +dependencies { + testImplementation(kotlin("test")) + + implementation(libs.kotlinx.serialization.core) + + implementation(libs.crypto.rand) + implementation(libs.curve25519.kotlin) +} + +tasks.test { + useJUnitPlatform() +} \ No newline at end of file diff --git a/parser/src/main/kotlin/com/zaneschepke/wireguardautotunnel/parser/Config.kt b/parser/src/main/kotlin/com/zaneschepke/wireguardautotunnel/parser/Config.kt new file mode 100644 index 0000000..b4916d4 --- /dev/null +++ b/parser/src/main/kotlin/com/zaneschepke/wireguardautotunnel/parser/Config.kt @@ -0,0 +1,157 @@ +package com.zaneschepke.wireguardautotunnel.parser + +import com.zaneschepke.wireguardautotunnel.parser.crypto.Key +import com.zaneschepke.wireguardautotunnel.parser.util.getBool +import com.zaneschepke.wireguardautotunnel.parser.util.getInt +import com.zaneschepke.wireguardautotunnel.parser.util.getList +import kotlinx.serialization.SerialName +import kotlinx.serialization.Serializable + +@Serializable +data class Config( + @SerialName("Interface") val `interface`: InterfaceSection, + @SerialName("Peer") val peers: List = emptyList() +) { + + @Throws(ConfigParseException::class) + fun validate() { + `interface`.validate() + peers.forEachIndexed { index, peer -> peer.validate(index) } + } + + fun asQuickString(): String = buildString { + appendLine("[Interface]") + appendLine("PrivateKey = ${`interface`.privateKey}") + `interface`.address?.let { appendLine("Address = $it") } + `interface`.dns?.let { appendLine("DNS = $it") } + `interface`.listenPort?.let { appendLine("ListenPort = $it") } + `interface`.mtu?.let { appendLine("MTU = $it") } + `interface`.fwMark?.let { appendLine("FwMark = $it") } + `interface`.table?.let { appendLine("Table = $it") } + `interface`.saveConfig?.let { appendLine("SaveConfig = $it") } + + // AmneziaWG + `interface`.jC?.let { appendLine("Jc = $it") } + `interface`.jMin?.let { appendLine("Jmin = $it") } + `interface`.jMax?.let { appendLine("Jmax = $it") } + `interface`.s1?.let { appendLine("S1 = $it") } + `interface`.s2?.let { appendLine("S2 = $it") } + `interface`.s3?.let { appendLine("S3 = $it") } + `interface`.s4?.let { appendLine("S4 = $it") } + `interface`.h1?.let { appendLine("H1 = $it") } + `interface`.h2?.let { appendLine("H2 = $it") } + `interface`.h3?.let { appendLine("H3 = $it") } + `interface`.h4?.let { appendLine("H4 = $it") } + `interface`.i1?.let { appendLine("I1 = $it") } + `interface`.i2?.let { appendLine("I2 = $it") } + `interface`.i3?.let { appendLine("I3 = $it") } + `interface`.i4?.let { appendLine("I4 = $it") } + `interface`.i5?.let { appendLine("I5 = $it") } + + `interface`.includedApplications?.let { appendLine("IncludedApplications = ${it.joinToString(",")}") } + `interface`.excludedApplications?.let { appendLine("ExcludedApplications = ${it.joinToString(",")}") } + + peers.forEach { peer -> + append("\n[Peer]\n") + appendLine("PublicKey = ${peer.publicKey}") + peer.endpoint?.let { appendLine("Endpoint = $it") } + peer.allowedIPs?.let { appendLine("AllowedIPs = $it") } + peer.presharedKey?.let { appendLine("PresharedKey = $it") } + peer.persistentKeepalive?.let { appendLine("PersistentKeepalive = $it") } + } + }.trim() + + fun rotateInterfaceKey(): Config { + val privateKey = Key.generatePrivateKey() + val newInterface = `interface`.copy(privateKey = privateKey.toBase64()) + return copy(`interface` = newInterface) + } + + companion object { + @Throws(ConfigParseException::class) + fun parseQuickString(configString: String): Config { + val interfaceMap = mutableMapOf() + val peerMaps = mutableListOf>() + var currentSection: MutableMap? = null + + configString.lines().forEach { line -> + val trimmed = line.split("#", ";")[0].trim() + if (trimmed.isEmpty()) return@forEach + + if (trimmed.startsWith("[") && trimmed.endsWith("]")) { + currentSection = when (trimmed.substring(1, trimmed.length - 1).lowercase()) { + "interface" -> interfaceMap + "peer" -> mutableMapOf().also { peerMaps.add(it) } + else -> null // ignore unknown + } + return@forEach + } + + val parts = trimmed.split("=", limit = 2) + if (parts.size == 2) { + currentSection?.put(parts[0].trim(), parts[1].trim()) + } + } + + if (interfaceMap.isEmpty()) throw ConfigParseException(ErrorType.MISSING_REQUIRED_FIELD, "Interface") + + return Config( + `interface` = buildInterface(interfaceMap), + peers = peerMaps.map { buildPeer(it) } + ).also { it.validate() } + } + + private fun buildInterface(m: Map) = InterfaceSection( + privateKey = m["PrivateKey"] ?: "", + address = m["Address"], + dns = m["DNS"], + listenPort = m.getInt("ListenPort", "Interface"), + mtu = m.getInt("MTU", "Interface"), + fwMark = m.getInt("FwMark", "Interface"), + table = m["Table"], + saveConfig = m.getBool("SaveConfig", "Interface"), + jC = m.getInt("Jc", "Interface"), + jMin = m.getInt("Jmin", "Interface"), + jMax = m.getInt("Jmax", "Interface"), + s1 = m.getInt("S1", "Interface"), + s2 = m.getInt("S2", "Interface"), + s3 = m.getInt("S3", "Interface"), + s4 = m.getInt("S4", "Interface"), + h1 = m["H1"], h2 = m["H2"], h3 = m["H3"], h4 = m["H4"], + i1 = m["I1"], i2 = m["I2"], i3 = m["I3"], i4 = m["I4"], i5 = m["I5"], + includedApplications = m.getList("IncludedApplications"), + excludedApplications = m.getList("ExcludedApplications") + ) + + private fun buildPeer(m: Map) = PeerSection( + publicKey = m["PublicKey"] ?: "", + allowedIPs = m["AllowedIPs"], + endpoint = m["Endpoint"], + presharedKey = m["PresharedKey"], + persistentKeepalive = m.getInt("PersistentKeepalive", "Peer") + ) + + fun parseEndpoint(endpoint: String): Pair { + var host: String + var portStr: String? + if (endpoint.startsWith("[")) { + val endBracket = endpoint.lastIndexOf("]") + if (endBracket == -1 || !endpoint.substring(endBracket + 1).startsWith(":")) return null to null + host = endpoint.take(endBracket + 1) + portStr = endpoint.substring(endBracket + 2) + } else { + val parts = endpoint.split(":", limit = 2) + if (parts.size != 2) return null to null + host = parts[0] + portStr = parts[1] + } + return host to portStr + } + + internal fun generatePublicKeyFromPrivate(privateBase64: String): String { + val privateKey = Key.fromBase64(privateBase64) + val publicKey = Key.generatePublicKey(privateKey) + return publicKey.toBase64() + } + } +} \ No newline at end of file diff --git a/parser/src/main/kotlin/com/zaneschepke/wireguardautotunnel/parser/ConfigParseException.kt b/parser/src/main/kotlin/com/zaneschepke/wireguardautotunnel/parser/ConfigParseException.kt new file mode 100644 index 0000000..f78a64c --- /dev/null +++ b/parser/src/main/kotlin/com/zaneschepke/wireguardautotunnel/parser/ConfigParseException.kt @@ -0,0 +1,9 @@ +package com.zaneschepke.wireguardautotunnel.parser + +class ConfigParseException( + val errorType: ErrorType, + val field: String, + val value: Any? = null, + val extra: String? = null, + message: String = "$field: $errorType${value?.let { " (value: $it)" } ?: ""}${extra?.let { " ($it)" } ?: ""}" +) : Exception(message) \ No newline at end of file diff --git a/parser/src/main/kotlin/com/zaneschepke/wireguardautotunnel/parser/ErrorType.kt b/parser/src/main/kotlin/com/zaneschepke/wireguardautotunnel/parser/ErrorType.kt new file mode 100644 index 0000000..68482a7 --- /dev/null +++ b/parser/src/main/kotlin/com/zaneschepke/wireguardautotunnel/parser/ErrorType.kt @@ -0,0 +1,22 @@ +package com.zaneschepke.wireguardautotunnel.parser + +enum class ErrorType { + MISSING_REQUIRED_FIELD, + INVALID_BASE64_KEY, + INVALID_PORT_RANGE, + INVALID_MTU_RANGE, + INVALID_FWMARK, + INVALID_JC_RANGE, + INVALID_JMIN_JMAX_ORDER, + INVALID_JMAX_MTU, + INVALID_PADDING_NEGATIVE, + INVALID_HEADER_FORMAT, + INVALID_SIGNATURE_FORMAT, + INVALID_ENDPOINT_FORMAT, + INVALID_KEEPALIVE_NEGATIVE, + INVALID_CIDR, + INVALID_IP, + INVALID_HOSTNAME, + INVALID_DNS_ENTRY, + INVALID_VALUE_FORMAT +} \ No newline at end of file diff --git a/parser/src/main/kotlin/com/zaneschepke/wireguardautotunnel/parser/InterfaceSection.kt b/parser/src/main/kotlin/com/zaneschepke/wireguardautotunnel/parser/InterfaceSection.kt new file mode 100644 index 0000000..d99edb9 --- /dev/null +++ b/parser/src/main/kotlin/com/zaneschepke/wireguardautotunnel/parser/InterfaceSection.kt @@ -0,0 +1,84 @@ +package com.zaneschepke.wireguardautotunnel.parser + +import com.zaneschepke.wireguardautotunnel.parser.util.NetworkUtils +import kotlinx.serialization.SerialName +import kotlinx.serialization.Serializable + +@Serializable +data class InterfaceSection( + @SerialName("PrivateKey") val privateKey: String, + @SerialName("Address") val address: String? = null, + @SerialName("ListenPort") val listenPort: Int? = null, + @SerialName("DNS") val dns: String? = null, + @SerialName("MTU") val mtu: Int? = null, + // Linux + @SerialName("FwMark") val fwMark: Int? = null, + @SerialName("Table") val table: String? = null, + @SerialName("SaveConfig") val saveConfig: Boolean? = null, + // Desktop or Rooted Android + @SerialName("PreUp") val preUp: String? = null, + @SerialName("PostUp") val postUp: String? = null, + @SerialName("PreDown") val preDown: String? = null, + @SerialName("PostDown") val postDown: String? = null, + // Android + @SerialName("IncludedApplications") val includedApplications: List? = null, + @SerialName("ExcludedApplications") val excludedApplications: List? = null, + // Amnezia + @SerialName("Jc") val jC: Int? = null, + @SerialName("Jmin") val jMin: Int? = null, + @SerialName("Jmax") val jMax: Int? = null, + @SerialName("S1") val s1: Int? = null, + @SerialName("S2") val s2: Int? = null, + @SerialName("S3") val s3: Int? = null, + @SerialName("S4") val s4: Int? = null, + @SerialName("H1") val h1: String? = null, + @SerialName("H2") val h2: String? = null, + @SerialName("H3") val h3: String? = null, + @SerialName("H4") val h4: String? = null, + @SerialName("I1") val i1: String? = null, + @SerialName("I2") val i2: String? = null, + @SerialName("I3") val i3: String? = null, + @SerialName("I4") val i4: String? = null, + @SerialName("I5") val i5: String? = null, +) { + val publicKey: String = Config.generatePublicKeyFromPrivate(privateKey) + + @Throws(ConfigParseException::class) + fun validate() { + if (privateKey.isBlank()) throw ConfigParseException(ErrorType.MISSING_REQUIRED_FIELD, "Interface.PrivateKey") + if (!NetworkUtils.isValidBase64(privateKey)) throw ConfigParseException(ErrorType.INVALID_BASE64_KEY, "Interface.PrivateKey", privateKey) + + listenPort?.let { if (it !in 0..65535) throw ConfigParseException(ErrorType.INVALID_PORT_RANGE, "Interface.ListenPort", it) } + mtu?.let { if (it !in 576..9000) throw ConfigParseException(ErrorType.INVALID_MTU_RANGE, "Interface.MTU", it) } + fwMark?.let { if (it < 0) throw ConfigParseException(ErrorType.INVALID_FWMARK, "Interface.FwMark", it) } + + jC?.let { if (it !in 4..12) throw ConfigParseException(ErrorType.INVALID_JC_RANGE, "Interface.Jc", it) } + if (jMin != null && jMax != null) { + if (jMin > jMax) throw ConfigParseException(ErrorType.INVALID_JMIN_JMAX_ORDER, "Interface.Jmin/Jmax") + if (jMax >= (mtu ?: 1500)) throw ConfigParseException(ErrorType.INVALID_JMAX_MTU, "Interface.Jmax", jMax) + } + + listOf(s1, s2, s3, s4).forEachIndexed { i, s -> + if (s != null && s < 0) throw ConfigParseException(ErrorType.INVALID_PADDING_NEGATIVE, "Interface.S${i + 1}", s) + } + + listOf(h1, h2, h3, h4).forEachIndexed { i, h -> + if (h != null && !NetworkUtils.isValidAmneziaHeader(h)) { + throw ConfigParseException(ErrorType.INVALID_HEADER_FORMAT, "Interface.H${i + 1}", h) + } + } + + listOf(i1, i2, i3, i4, i5).forEachIndexed { i, sig -> + if (sig != null && !NetworkUtils.isValidHexSignature(sig)) { + throw ConfigParseException(ErrorType.INVALID_SIGNATURE_FORMAT, "Interface.I${i + 1}", sig) + } + } + + address?.split(",")?.map { it.trim() }?.forEach { + if (it.isNotBlank() && !NetworkUtils.isValidCidr(it)) throw ConfigParseException(ErrorType.INVALID_CIDR, "Interface.Address", it) + } + dns?.split(",")?.map { it.trim() }?.forEach { + if (it.isNotBlank() && !NetworkUtils.isValidDnsEntry(it)) throw ConfigParseException(ErrorType.INVALID_DNS_ENTRY, "Interface.DNS", it) + } + } +} \ No newline at end of file diff --git a/parser/src/main/kotlin/com/zaneschepke/wireguardautotunnel/parser/PeerSection.kt b/parser/src/main/kotlin/com/zaneschepke/wireguardautotunnel/parser/PeerSection.kt new file mode 100644 index 0000000..f82a327 --- /dev/null +++ b/parser/src/main/kotlin/com/zaneschepke/wireguardautotunnel/parser/PeerSection.kt @@ -0,0 +1,53 @@ +package com.zaneschepke.wireguardautotunnel.parser + +import com.zaneschepke.wireguardautotunnel.parser.util.NetworkUtils +import kotlinx.serialization.SerialName +import kotlinx.serialization.Serializable + +@Serializable +data class PeerSection( + @SerialName("PublicKey") val publicKey: String, + @SerialName("AllowedIPs") val allowedIPs: String? = null, + @SerialName("Endpoint") val endpoint: String? = null, + @SerialName("PresharedKey") val presharedKey: String? = null, + @SerialName("PersistentKeepalive") val persistentKeepalive: Int? = null +) { + + @Throws(ConfigParseException::class) + fun validate(index: Int) { + val prefix = "Peer[$index]" + if (publicKey.isBlank()) throw ConfigParseException(ErrorType.MISSING_REQUIRED_FIELD, "$prefix.PublicKey") + if (!NetworkUtils.isValidBase64(publicKey)) throw ConfigParseException(ErrorType.INVALID_BASE64_KEY, "$prefix.PublicKey", publicKey) + + persistentKeepalive?.let { if (it !in 0..65535) throw ConfigParseException(ErrorType.INVALID_KEEPALIVE_NEGATIVE, "$prefix.PersistentKeepalive", it) } + + endpoint?.let { + val (host, portStr) = Config.parseEndpoint(it) + val port = portStr?.toIntOrNull() + if (host == null || port == null || port !in 0..65535) { + throw ConfigParseException(ErrorType.INVALID_ENDPOINT_FORMAT, "$prefix.Endpoint", it) + } + if (!NetworkUtils.isValidDnsEntry(host)) { + throw ConfigParseException(ErrorType.INVALID_HOSTNAME, "$prefix.Endpoint host", host) + } + } + + allowedIPs?.split(",")?.map { it.trim() }?.forEach { + if (it.isNotBlank() && !NetworkUtils.isValidCidr(it)) { + throw ConfigParseException(ErrorType.INVALID_CIDR, "$prefix.AllowedIPs", it) + } + } + } + + val host: String? get() { + val (h, _) = endpoint?.let { Config.parseEndpoint(it) } ?: return null + return h + } + + val port: Int? get() { + val (_, p) = endpoint?.let { Config.parseEndpoint(it) } ?: return null + return p?.toIntOrNull() + } + val isStaticallyConfigured: Boolean + get() = host?.let { NetworkUtils.isValidIp(it) } ?: false +} \ No newline at end of file diff --git a/parser/src/main/kotlin/com/zaneschepke/wireguardautotunnel/parser/crypto/Key.kt b/parser/src/main/kotlin/com/zaneschepke/wireguardautotunnel/parser/crypto/Key.kt new file mode 100644 index 0000000..8e96ead --- /dev/null +++ b/parser/src/main/kotlin/com/zaneschepke/wireguardautotunnel/parser/crypto/Key.kt @@ -0,0 +1,143 @@ +// SPDX-License-Identifier: Apache-2.0 +// Copyright © 2026 WG Tunnel. +// Adapted from WireGuard LLC. + +package com.zaneschepke.wireguardautotunnel.parser.crypto + +import io.github.andreypfau.curve25519.x25519.X25519 +import org.kotlincrypto.random.CryptoRand +import kotlin.experimental.and +import kotlin.experimental.or + +class KeyFormatException : Exception { + constructor(format: Key.Format, type: Key.Type) : super("Invalid key format: $format, type: $type") +} + +class Key private constructor(private val key: ByteArray) { + + fun getBytes(): ByteArray = key.copyOf() + + fun toBase64(): String { + val output = CharArray(Format.BASE64.length) + var i = 0 + while (i < key.size / 3) { + encodeBase64(key, i * 3, output, i * 4) + i++ + } + val endSegment = byteArrayOf(key[i * 3], key[i * 3 + 1], 0) + encodeBase64(endSegment, 0, output, i * 4) + output[Format.BASE64.length - 1] = '=' + return output.concatToString() + } + + override fun equals(other: Any?): Boolean { + if (this === other) return true + if (other !is Key) return false + return key.contentEquals(other.key) + } + + override fun hashCode(): Int { + var ret = 0 + var i = 0 + while (i < key.size / 4) { + ret = ret xor ((key[i * 4 + 0].toInt() shr 0) + (key[i * 4 + 1].toInt() shr 8) + + (key[i * 4 + 2].toInt() shr 16) + (key[i * 4 + 3].toInt() shr 24)) + i++ + } + return ret + } + + companion object { + fun fromBase64(str: String): Key { + val input = str.toCharArray() + if (input.size != Format.BASE64.length || input[Format.BASE64.length - 1] != '=') { + throw KeyFormatException(Format.BASE64, Type.LENGTH) + } + val key = ByteArray(Format.BINARY.length) + var ret = 0 + var i = 0 + while (i < key.size / 3) { + val value = decodeBase64(input, i * 4) + ret = ret or (value ushr 31) + key[i * 3] = ((value ushr 16) and 0xff).toByte() + key[i * 3 + 1] = ((value ushr 8) and 0xff).toByte() + key[i * 3 + 2] = (value and 0xff).toByte() + i++ + } + val endSegment = charArrayOf(input[i * 4], input[i * 4 + 1], input[i * 4 + 2], 'A') + val value = decodeBase64(endSegment, 0) + ret = ret or ((value ushr 31) or (value and 0xff)) + key[i * 3] = ((value ushr 16) and 0xff).toByte() + key[i * 3 + 1] = ((value ushr 8) and 0xff).toByte() + + if (ret != 0) { + throw KeyFormatException(Format.BASE64, Type.CONTENTS) + } + return Key(key) + } + + fun fromBytes(bytes: ByteArray): Key { + if (bytes.size != Format.BINARY.length) { + throw KeyFormatException(Format.BINARY, Type.LENGTH) + } + return Key(bytes) + } + + fun generatePrivateKey(): Key { + val privateKey = ByteArray(Format.BINARY.length) + CryptoRand.nextBytes(privateKey) + privateKey[0] = privateKey[0] and 248.toByte() + privateKey[31] = privateKey[31] and 127.toByte() + privateKey[31] = privateKey[31] or 64.toByte() + return Key(privateKey) + } + + fun generatePublicKey(privateKey: Key): Key { + val publicKey = ByteArray(Format.BINARY.length) + X25519.x25519(privateKey.getBytes(), output = publicKey) + return Key(publicKey) + } + + private fun decodeBase64(src: CharArray, srcOffset: Int): Int { + var value = 0 + for (i in 0 until 4) { + val c = src[i + srcOffset].code + value = value or (-1 + + ((((('A'.code - 1) - c) and (c - ('Z'.code + 1))) ushr 8) and (c - 64)) + + ((((('a'.code - 1) - c) and (c - ('z'.code + 1))) ushr 8) and (c - 70)) + + ((((('0'.code - 1) - c) and (c - ('9'.code + 1))) ushr 8) and (c + 5)) + + (((('+'.code - 1) - c) and (c - ('+'.code + 1))) ushr 8 and 63) + + (((('/'.code - 1) - c) and (c - ('/'.code + 1))) ushr 8 and 64) + ) shl (18 - 6 * i) + } + return value + } + + private fun encodeBase64(src: ByteArray, srcOffset: Int, dest: CharArray, destOffset: Int) { + val input = byteArrayOf( + (src[srcOffset].toInt() shr 2 and 63).toByte(), + ((src[srcOffset].toInt() shl 4 or (src[1 + srcOffset].toInt() and 0xff ushr 4)) and 63).toByte(), + ((src[1 + srcOffset].toInt() shl 2 or (src[2 + srcOffset].toInt() and 0xff ushr 6)) and 63).toByte(), + (src[2 + srcOffset].toInt() and 63).toByte() + ) + for (i in 0 until 4) { + dest[i + destOffset] = (input[i].toInt() + 'A'.code + + (((25 - input[i].toInt()) ushr 8) and 6) - + (((51 - input[i].toInt()) ushr 8) and 75) - + (((61 - input[i].toInt()) ushr 8) and 15) + + (((62 - input[i].toInt()) ushr 8) and 3)).toChar() + } + } + } + + enum class Format(val length: Int) { + BASE64(44), + BINARY(32), + HEX(64) + } + + enum class Type { + LENGTH, + CONTENTS + } +} \ No newline at end of file diff --git a/parser/src/main/kotlin/com/zaneschepke/wireguardautotunnel/parser/util/Extensions.kt b/parser/src/main/kotlin/com/zaneschepke/wireguardautotunnel/parser/util/Extensions.kt new file mode 100644 index 0000000..efa4620 --- /dev/null +++ b/parser/src/main/kotlin/com/zaneschepke/wireguardautotunnel/parser/util/Extensions.kt @@ -0,0 +1,22 @@ +package com.zaneschepke.wireguardautotunnel.parser.util + +import com.zaneschepke.wireguardautotunnel.parser.ConfigParseException +import com.zaneschepke.wireguardautotunnel.parser.ErrorType + +fun Map.getInt(key: String, section: String): Int? { + val value = this[key] ?: return null + return value.toIntOrNull() ?: throw ConfigParseException(ErrorType.INVALID_VALUE_FORMAT, "$section.$key", value) +} + +fun Map.getBool(key: String, section: String): Boolean? { + val value = this[key] ?: return null + return when (value.lowercase()) { + "true", "yes", "on" -> true + "false", "no", "off" -> false + else -> throw ConfigParseException(ErrorType.INVALID_VALUE_FORMAT, "$section.$key", value) + } +} + +fun Map.getList(key: String): List? { + return this[key]?.split(",")?.map { it.trim() }?.filter { it.isNotEmpty() } +} \ No newline at end of file diff --git a/parser/src/main/kotlin/com/zaneschepke/wireguardautotunnel/parser/util/NetworkUtils.kt b/parser/src/main/kotlin/com/zaneschepke/wireguardautotunnel/parser/util/NetworkUtils.kt new file mode 100644 index 0000000..58c0015 --- /dev/null +++ b/parser/src/main/kotlin/com/zaneschepke/wireguardautotunnel/parser/util/NetworkUtils.kt @@ -0,0 +1,82 @@ +package com.zaneschepke.wireguardautotunnel.parser.util + +import java.net.InetAddress + +object NetworkUtils { + private val hostnameRegex = Regex("^(?:[a-zA-Z0-9]|[a-zA-Z0-9][a-zA-Z0-9\\-]{0,61}[a-zA-Z0-9])(\\.[a-zA-Z0-9]|[a-zA-Z0-9][a-zA-Z0-9\\-]{0,61}[a-zA-Z0-9])*$") + + + fun isValidIp(ip: String): Boolean { + val sanitized = ip.removeSurrounding("[", "]") + if (sanitized.any { it.lowercaseChar() in 'g'..'z' }) return false + + return try { + InetAddress.getAllByName(sanitized).isNotEmpty() + } catch (e: Exception) { + false + } + } + + fun isValidCidr(cidr: String): Boolean { + val parts = cidr.split("/", limit = 2) + val ip = parts[0] + + if (parts.size == 1) { + return isValidIp(ip) + } + + val prefix = parts[1].toIntOrNull() ?: return false + if (!isValidIp(ip)) return false + + return try { + val addr = InetAddress.getByName(ip.removeSurrounding("[", "]")) + val maxPrefix = if (addr is java.net.Inet4Address) 32 else 128 + prefix in 0..maxPrefix + } catch (e: Exception) { + false + } + } + + fun isValidDnsEntry(entry: String): Boolean { + if (entry.isBlank()) return false + // Safe: isValidIp is offline, isValidHostname is regex. + return isValidIp(entry) || isValidHostname(entry) + } + + + fun isValidHostname(host: String): Boolean { + val cleaned = host.removeSurrounding("[", "]") + return hostnameRegex.matches(cleaned) && cleaned.length <= 253 + } + + fun isValidBase64(str: String): Boolean { + // WireGuard keys are always 44 chars (32 bytes encoded) + if (str.length != 44 || !str.endsWith("=")) return false + val base64Chars = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/=" + return str.all { it in base64Chars } + } + + fun isValidAmneziaHeader(header: String): Boolean { + val maxUInt32 = 4294967295L + return try { + if (header.contains("-")) { + val parts = header.split("-") + if (parts.size != 2) return false + val start = parts[0].trim().toLong() + val end = parts[1].trim().toLong() + start in 0..maxUInt32 && end in 0..maxUInt32 && start <= end + } else { + header.trim().toLong() in 0..maxUInt32 + } + } catch (_: Exception) { + false + } + } + + fun isValidHexSignature(signature: String): Boolean { + val hex = signature.removePrefix("0x").trim() + if (hex.isEmpty() || hex.length % 2 != 0) return false + val hexChars = "0123456789abcdefABCDEF" + return hex.all { it in hexChars } + } +} \ No newline at end of file diff --git a/settings.gradle.kts b/settings.gradle.kts new file mode 100644 index 0000000..2ba0702 --- /dev/null +++ b/settings.gradle.kts @@ -0,0 +1,24 @@ +rootProject.name = "wgtunnel" +enableFeaturePreview("TYPESAFE_PROJECT_ACCESSORS") + +pluginManagement { + repositories { + google() + mavenCentral() + gradlePluginPortal() + maven("https://maven.hq.hydraulic.software") + } +} + +dependencyResolutionManagement { + repositories { + google() + mavenCentral() + } +} + +plugins { + id("org.gradle.toolchains.foojay-resolver-convention") version "1.0.0" +} + +include(":composeApp", ":parser", ":daemon", ":tunnel", ":cli", ":client", ":keyring", ":core") diff --git a/tunnel/.gitignore b/tunnel/.gitignore new file mode 100644 index 0000000..e5d0f92 --- /dev/null +++ b/tunnel/.gitignore @@ -0,0 +1,3 @@ +src/main/resources/* +/tools/libwg-go/build +/tools/libwg-go/out \ No newline at end of file diff --git a/tunnel/build.gradle.kts b/tunnel/build.gradle.kts new file mode 100644 index 0000000..d8034c2 --- /dev/null +++ b/tunnel/build.gradle.kts @@ -0,0 +1,53 @@ +plugins { + kotlin("jvm") +} + +dependencies { + testImplementation(kotlin("test")) + + implementation(libs.kotlinx.coroutines.core) + implementation(libs.jna.platform) +} + +tasks.test { + useJUnitPlatform() +} + + +tasks.register("buildGoLibs") { + val goDir = "tools/libwg-go" + group = "build" + description = "Builds Go shared libs using Makefile" + workingDir = file(goDir) + + // Track only source files + inputs.files( + fileTree(goDir) { + include("**/*.go", "**/go.mod", "**/go.sum", "Makefile") + exclude("out/**", "build/**", ".gocache/**") + } + ).withPropertyName("goSourceFiles") + .withPathSensitivity(PathSensitivity.RELATIVE) + + outputs.dir(file("src/main/resources")) + .withPropertyName("outputResourcesDir") + + commandLine("make", "all") +} + +tasks.named("processResources") { + dependsOn("buildGoLibs") +} + +val cleanGoLibs = tasks.register("cleanGoLibs") { + workingDir = file("tools/libwg-go") + commandLine("make", "clean") +} + +// 3. Update the main clean task +tasks.named("clean") { + dependsOn(cleanGoLibs) + delete(file("tools/libwg-go/build")) + delete(file("tools/libwg-go/out")) + delete(file("src/main/resources")) +} diff --git a/tunnel/src/main/kotlin/com/zaneschepke/wireguardautotunnel/tunnel/AmneziaBackend.kt b/tunnel/src/main/kotlin/com/zaneschepke/wireguardautotunnel/tunnel/AmneziaBackend.kt new file mode 100644 index 0000000..4ec45e5 --- /dev/null +++ b/tunnel/src/main/kotlin/com/zaneschepke/wireguardautotunnel/tunnel/AmneziaBackend.kt @@ -0,0 +1,141 @@ +package com.zaneschepke.wireguardautotunnel.tunnel + +import com.zaneschepke.wireguardautotunnel.tunnel.native.AwgTunnel +import com.zaneschepke.wireguardautotunnel.tunnel.native.StatusCodeCallback +import com.zaneschepke.wireguardautotunnel.tunnel.util.BackendException +import kotlinx.coroutines.* +import kotlinx.coroutines.channels.Channel +import kotlinx.coroutines.flow.* +import kotlinx.coroutines.sync.Mutex +import kotlinx.coroutines.sync.withLock +import java.util.concurrent.ConcurrentHashMap +import kotlin.concurrent.withLock + +class AmneziaBackend : Backend { + private val tun = AwgTunnel.INSTANCE + + private var currentMode: Backend.Mode = Backend.Mode.Userspace + + private val _status = MutableStateFlow(Backend.Status(false, currentMode, emptyMap())) + + override val status: Flow = _status.asStateFlow() + + private val tunnelHandles = ConcurrentHashMap() + private val tunnelJobs = ConcurrentHashMap() + + private val backendScope = CoroutineScope(SupervisorJob() + Dispatchers.IO) + + init { + initKillSwitchStatus() + } + + private fun initKillSwitchStatus() { + val status = tun.getKillSwitchStatus() + val enabled = status == 1 + _status.update { it.copy(killSwitchEnabled = enabled) } + } + + @Synchronized + override fun start(tunnel: Tunnel, config: String): Result = runCatching { + if (_status.value.activeTunnels.any { it.key.id == tunnel.id }) { + return Result.success(Unit) + } + + tunnel.updateState(Tunnel.State.Starting) + _status.update { it.copy(activeTunnels = it.activeTunnels + (tunnel to Tunnel.State.Starting)) } + + val statusFlow = callbackFlow { + val statusCallback = object : StatusCodeCallback { + override fun onTunnelStatusCode(handle: Int, statusCode: Int) { + trySend(statusCode) + } + } + + val handle = when(currentMode) { + Backend.Mode.Proxy -> tun.awgProxyTurnOn(config, statusCallback) + Backend.Mode.Userspace -> tun.awgTurnOn(config, statusCallback) + } + + if (handle < 0) { + close(BackendException.BackendFailure(IllegalStateException("Tunnel failed to start with handle: $handle"))) + tunnel.updateState(Tunnel.State.Down) + _status.update { it.copy(activeTunnels = it.activeTunnels - tunnel) } + } else { + tunnelHandles[tunnel] = handle + _status.update { it.copy(activeTunnels = it.activeTunnels + (tunnel to Tunnel.State.Up.Unknown)) } + } + awaitCancellation() + }.buffer(Channel.BUFFERED) + + tunnelJobs[tunnel] = backendScope.launch { + statusFlow.collect { statusCode -> + val tunnelState = mapStatusCodeToState(statusCode) + _status.update { it.copy(activeTunnels = it.activeTunnels + (tunnel to tunnelState)) } + tunnel.updateState(tunnelState) + } + }.apply { invokeOnCompletion { + tunnelJobs.remove(tunnel) + } } + }.onFailure { throwable -> + tunnel.updateState(Tunnel.State.Down) + tunnelJobs.remove(tunnel)?.cancel() + _status.update { it.copy(activeTunnels = it.activeTunnels - tunnel) } + return Result.failure(BackendException.BackendFailure(throwable)) + } + + @Synchronized + override fun stop(id: Int) { + val tunnel = tunnelHandles.keys.firstOrNull { t -> t.id == id } ?: return + val handle = tunnelHandles.remove(tunnel) ?: return + + when(currentMode) { + Backend.Mode.Proxy -> tun.awgProxyTurnOff(handle) + Backend.Mode.Userspace -> tun.awgTurnOff(handle) + } + + tunnelJobs.remove(tunnel)?.cancel() + + tunnel.updateState(Tunnel.State.Down) + _status.update { it.copy(activeTunnels = it.activeTunnels - tunnel) } + } + + override fun setMode(mode: Backend.Mode) { + if (mode == currentMode) return + shutdown() + currentMode = mode + } + + override fun shutdown() { + + when(currentMode) { + Backend.Mode.Proxy -> tun.awgProxyTurnOffAll() + Backend.Mode.Userspace -> tun.awgTurnOffAll() + } + + tunnelJobs.values.forEach { it.cancel() } + tunnelJobs.clear() + tunnelHandles.clear() + _status.update { it.copy(activeTunnels = emptyMap()) } + } + + override fun setKillSwitch(enabled: Boolean): Result { + if (_status.value.killSwitchEnabled == enabled) return Result.success(Unit) + val setValue = if (enabled) 1 else 0 + val status = tun.setKillSwitch(setValue) + if (status == -1) return Result.failure(BackendException.KillSwitchSetFailed("")) + val killSwitchEnabled = status == 1 + _status.update { it.copy(killSwitchEnabled = killSwitchEnabled) } + return Result.success(Unit) + } + + private fun mapStatusCodeToState(statusCode: Int): Tunnel.State { + // Matching native status codes + return when (statusCode) { + 0 -> Tunnel.State.Up.Healthy + 1 -> Tunnel.State.Up.HandshakeFailure + 2 -> Tunnel.State.Up.ResolvingDns + 3 -> Tunnel.State.Up.Unknown + else -> Tunnel.State.Down // unknow or negative error code consider down + } + } +} \ No newline at end of file diff --git a/tunnel/src/main/kotlin/com/zaneschepke/wireguardautotunnel/tunnel/Backend.kt b/tunnel/src/main/kotlin/com/zaneschepke/wireguardautotunnel/tunnel/Backend.kt new file mode 100644 index 0000000..21af197 --- /dev/null +++ b/tunnel/src/main/kotlin/com/zaneschepke/wireguardautotunnel/tunnel/Backend.kt @@ -0,0 +1,26 @@ +package com.zaneschepke.wireguardautotunnel.tunnel + +import kotlinx.coroutines.flow.Flow + +interface Backend { + fun start(tunnel: Tunnel, config : String) : Result + fun stop(id : Int) + fun setMode(mode: Mode) + + fun setKillSwitch(enabled: Boolean) : Result + + fun shutdown() + + val status : Flow + + sealed interface Mode { + data object Userspace: Mode + data object Proxy : Mode + } + + data class Status( + val killSwitchEnabled: Boolean, + val mode: Mode, + val activeTunnels: Map + ) +} \ No newline at end of file diff --git a/tunnel/src/main/kotlin/com/zaneschepke/wireguardautotunnel/tunnel/Tunnel.kt b/tunnel/src/main/kotlin/com/zaneschepke/wireguardautotunnel/tunnel/Tunnel.kt new file mode 100644 index 0000000..a7f946e --- /dev/null +++ b/tunnel/src/main/kotlin/com/zaneschepke/wireguardautotunnel/tunnel/Tunnel.kt @@ -0,0 +1,32 @@ +package com.zaneschepke.wireguardautotunnel.tunnel + +import kotlinx.coroutines.flow.StateFlow + +interface Tunnel { + val id: Int + val name: String + val features : Set + + fun updateState(state: State) + + sealed interface State { + sealed class Up : State { + data object Healthy : Up() + data object ResolvingDns : Up() + data object HandshakeFailure : Up() + data object Unknown : Up() + } + data object Down : State + data object Starting : State + } + + sealed interface Feature { + data object DynamicDNS : Feature + data class PingMonitor( + val intervalSeconds: Int = 30, + val attempts: Int = 3, + val timeoutSeconds: Int? = null, + val target: String? = null, + ) : Feature + } +} \ No newline at end of file diff --git a/tunnel/src/main/kotlin/com/zaneschepke/wireguardautotunnel/tunnel/native/AwgTunnel.kt b/tunnel/src/main/kotlin/com/zaneschepke/wireguardautotunnel/tunnel/native/AwgTunnel.kt new file mode 100644 index 0000000..81db0d8 --- /dev/null +++ b/tunnel/src/main/kotlin/com/zaneschepke/wireguardautotunnel/tunnel/native/AwgTunnel.kt @@ -0,0 +1,35 @@ +package com.zaneschepke.wireguardautotunnel.tunnel.native + +import com.sun.jna.Library +import com.sun.jna.Native +import com.sun.jna.Pointer + +interface AwgTunnel : Library { + + // Normal tunnel methods + fun awgTurnOn( + cfg: String?, + callback: StatusCodeCallback? + ): Int + fun awgTurnOff(handle: Int) + fun awgGetConfig(handle: Int): Pointer? + fun awgTurnOffAll() + + // Proxy tunnel methods + fun awgProxyTurnOn( + cfg: String?, + callback: StatusCodeCallback? + ): Int + fun awgProxyGetConfig(handle: Int): Pointer? + fun awgProxyTurnOffAll() + fun awgProxyTurnOff(handle: Int) + + + fun setKillSwitch(value: Int) : Int // 1 for enable, 0 for disable, return 1 or -1 for error + + fun getKillSwitchStatus() : Int // 1 for enabled, 0 for disabled + + companion object { + val INSTANCE: AwgTunnel = Native.load("wg", AwgTunnel::class.java) + } +} \ No newline at end of file diff --git a/tunnel/src/main/kotlin/com/zaneschepke/wireguardautotunnel/tunnel/native/StatusCodeCallback.kt b/tunnel/src/main/kotlin/com/zaneschepke/wireguardautotunnel/tunnel/native/StatusCodeCallback.kt new file mode 100644 index 0000000..e1fb9ef --- /dev/null +++ b/tunnel/src/main/kotlin/com/zaneschepke/wireguardautotunnel/tunnel/native/StatusCodeCallback.kt @@ -0,0 +1,7 @@ +package com.zaneschepke.wireguardautotunnel.tunnel.native + +import com.sun.jna.Callback + +interface StatusCodeCallback : Callback{ + fun onTunnelStatusCode(handle: Int, statusCode: Int) +} \ No newline at end of file diff --git a/tunnel/src/main/kotlin/com/zaneschepke/wireguardautotunnel/tunnel/util/BackendException.kt b/tunnel/src/main/kotlin/com/zaneschepke/wireguardautotunnel/tunnel/util/BackendException.kt new file mode 100644 index 0000000..c9dd4f9 --- /dev/null +++ b/tunnel/src/main/kotlin/com/zaneschepke/wireguardautotunnel/tunnel/util/BackendException.kt @@ -0,0 +1,8 @@ +package com.zaneschepke.wireguardautotunnel.tunnel.util + +sealed class BackendException : Exception() { + data class InvalidConfig(val reason: String) : BackendException() + data class PermissionDenied(val requiredPermission: String) : BackendException() + data class KillSwitchSetFailed(val reason : String) : BackendException() + data class BackendFailure(override val cause: Throwable) : BackendException() +} \ No newline at end of file diff --git a/tunnel/tools/amneziawg-tools b/tunnel/tools/amneziawg-tools new file mode 160000 index 0000000..5c6ffd6 --- /dev/null +++ b/tunnel/tools/amneziawg-tools @@ -0,0 +1 @@ +Subproject commit 5c6ffd6168f7c69199200a91803fa02e1b8c4152 diff --git a/tunnel/tools/libwg-go/Makefile b/tunnel/tools/libwg-go/Makefile new file mode 100644 index 0000000..415d4b5 --- /dev/null +++ b/tunnel/tools/libwg-go/Makefile @@ -0,0 +1,111 @@ +# SPDX-License-Identifier: Apache-2.0 +# +# Copyright © 2026 WG Tunnel. +# Adapted from WireGuard LLC. + +BUILDDIR ?= $(CURDIR)/build +DESTDIR ?= $(CURDIR)/out +RESOURCEDIR ?= $(CURDIR)/../../src/main/resources + +NDK_GO_ARCH_MAP_x86_64 := amd64 +NDK_GO_ARCH_MAP_aarch64 := arm64 + +GO_VERSION := 1.25.5 +GO_DIR := $(BUILDDIR)/go-$(GO_VERSION) +export GOCACHE := $(BUILDDIR)/go-cache + +GO_PLATFORM := $(shell uname -s | tr '[:upper:]' '[:lower:]')-$(NDK_GO_ARCH_MAP_$(shell uname -m)) + +GO_TARBALL := go$(GO_VERSION).$(GO_PLATFORM).tar.gz +GO_HASH_darwin-amd64 := b69d51bce599e5381a94ce15263ae644ec84667a5ce23d58dc2e63e2c12a9f56 +GO_HASH_darwin-arm64 := bed8ebe824e3d3b27e8471d1307f803fc6ab8e1d0eb7a4ae196979bd9b801dd3 +GO_HASH_linux-amd64 := 9e9b755d63b36acf30c12a9a3fc379243714c1c6d3dd72861da637f336ebb35b + +export GOROOT := $(GO_DIR) +export PATH := $(GO_DIR)/bin:$(PATH) +export CGO_ENABLED := 1 + +HOST_OS := $(shell uname -s | tr '[:upper:]' '[:lower:]') +ifeq ($(HOST_OS),darwin) +PLATFORMS := darwin-amd64 darwin-arm64 linux-amd64 windows-amd64 windows-arm64 +else +PLATFORMS := linux-amd64 windows-amd64 +endif + +default: all + +$(BUILDDIR)/$(GO_TARBALL): + mkdir -p "$(dir $@)" + curl -o "$@.tmp" "https://dl.google.com/go/$(GO_TARBALL)" + echo "$(GO_HASH_$(GO_PLATFORM)) $@.tmp" | sha256sum -c + mv "$@.tmp" "$@" + +$(GO_DIR)/.prepared: $(BUILDDIR)/$(GO_TARBALL) + mkdir -p "$(dir $@)" + tar -C "$(dir $@)" --strip-components=1 -xzf "$^" + cd "$(dir $@)/src/runtime" && sed -i 's/CLOCK_MONOTONIC/BOOTTIME/g' sys_linux_*.s + cd "$(dir $@)/src/runtime" && sed -i 's/ $1/ $7/g' sys_linux_*.s + cd "$(dir $@)/src/runtime" && sed -i '/libc_mach_absolute_time/a \//go:cgo_import_dynamic libc_mach_continuous_time mach_continuous_time "/usr/lib/libSystem.B.dylib"' sys_darwin.go + cd "$(dir $@)/src/runtime" && sed -i 's/mach_absolute_time/mach_continuous_time/g' sys_darwin_amd64.s sys_darwin_arm64.s + touch "$@" + +$(DESTDIR)/libwg-linux-amd64.so: $(GO_DIR)/.prepared go.mod go.sum $(wildcard *.go) $(wildcard **/*.go) + @mkdir -p "$(DESTDIR)" + @if [ "$(HOST_OS)" = "darwin" ]; then \ + GOOS=linux GOARCH=amd64 CC="zig cc -target x86_64-linux-gnu-musl" CXX="zig c++ -target x86_64-linux-gnu-musl" go build -ldflags="-buildid=" -v -trimpath -buildvcs=false -o "$@" -buildmode=c-shared .; \ + else \ + GOOS=linux GOARCH=amd64 go build -ldflags="-buildid=" -v -trimpath -buildvcs=false -o "$@" -buildmode=c-shared .; \ + fi + +$(DESTDIR)/libwg-windows-amd64.dll: $(GO_DIR)/.prepared go.mod go.sum $(wildcard *.go) $(wildcard **/*.go) + @mkdir -p "$(DESTDIR)" + GOOS=windows GOARCH=amd64 CC=x86_64-w64-mingw32-gcc CXX=x86_64-w64-mingw32-g++ \ + go build -ldflags="-buildid=" -v -trimpath -buildvcs=false -o "$@" -buildmode=c-shared . + +$(DESTDIR)/libwg-windows-arm64.dll: $(GO_DIR)/.prepared go.mod go.sum $(wildcard *.go) $(wildcard **/*.go) + @mkdir -p "$(DESTDIR)" + GOOS=windows GOARCH=arm64 CC=aarch64-w64-mingw32-gcc CXX=aarch64-w64-mingw32-g++ \ + go build -ldflags="-buildid=" -v -trimpath -buildvcs=false -o "$@" -buildmode=c-shared . + +$(DESTDIR)/libwg-darwin-amd64.dylib: $(GO_DIR)/.prepared go.mod go.sum $(wildcard *.go) $(wildcard **/*.go) + @mkdir -p "$(DESTDIR)" + GOOS=darwin GOARCH=amd64 go build -ldflags="-buildid=" -v -trimpath -buildvcs=false -o "$@" -buildmode=c-shared . + +$(DESTDIR)/libwg-darwin-arm64.dylib: $(GO_DIR)/.prepared go.mod go.sum $(wildcard *.go) $(wildcard **/*.go) + @mkdir -p "$(DESTDIR)" + GOOS=darwin GOARCH=arm64 go build -ldflags="-buildid=" -v -trimpath -buildvcs=false -o "$@" -buildmode=c-shared . + +go-clean-cache: $(GO_DIR)/.prepared + @mkdir -p $(GOCACHE) + -@go clean -cache > /dev/null 2>&1 + +build_all: go-clean-cache \ + $(foreach plat,$(PLATFORMS),$(DESTDIR)/libwg-$(plat).$(if $(findstring windows,$(plat)),dll,$(if $(findstring darwin,$(plat)),dylib,so))) + +copy_to_resources: build_all + @for plat in $(PLATFORMS); do \ + os=`echo "$$plat" | cut -d- -f1`; \ + arch=`echo "$$plat" | cut -d- -f2`; \ + jna_dir=; \ + if [ "$$os" = "linux" ] && [ "$$arch" = "amd64" ]; then jna_dir=linux-x86-64; fi; \ + if [ "$$os" = "windows" ] && [ "$$arch" = "amd64" ]; then jna_dir=win32-x86-64; fi; \ + if [ "$$os" = "windows" ] && [ "$$arch" = "arm64" ]; then jna_dir=win32-aarch64; fi; \ + if [ "$$os" = "darwin" ] && [ "$$arch" = "amd64" ]; then jna_dir=darwin-x86-64; fi; \ + if [ "$$os" = "darwin" ] && [ "$$arch" = "arm64" ]; then jna_dir=darwin-aarch64; fi; \ + libext=so; \ + if [ "$$os" = "windows" ]; then libext=dll; fi; \ + if [ "$$os" = "darwin" ]; then libext=dylib; fi; \ + dest_dir="$(RESOURCEDIR)/$$jna_dir"; \ + mkdir -p "$$dest_dir"; \ + cp "$(DESTDIR)/libwg-$$plat.$$libext" "$$dest_dir/libwg.$$libext"; \ + echo "Copied $$plat -> $$dest_dir/libwg.$$libext"; \ + done + + +all: copy_to_resources + +clean: + rm -rf $(BUILDDIR) $(DESTDIR) + +.PHONY: default all build_all copy_to_resources clean +.DELETE_ON_ERROR: \ No newline at end of file diff --git a/tunnel/tools/libwg-go/constants/constants.go b/tunnel/tools/libwg-go/constants/constants.go new file mode 100644 index 0000000..3651fa5 --- /dev/null +++ b/tunnel/tools/libwg-go/constants/constants.go @@ -0,0 +1,7 @@ +package constants + +const ( + IfacePrefix = "wgtun" + IfaceName = IfacePrefix + "%d" + DummyAddress = "100.64.0.1" +) diff --git a/tunnel/tools/libwg-go/dns/dns.go b/tunnel/tools/libwg-go/dns/dns.go new file mode 100644 index 0000000..ed58602 --- /dev/null +++ b/tunnel/tools/libwg-go/dns/dns.go @@ -0,0 +1,165 @@ +// Package dnsresolver provides modular DNS resolution with backoff retries using AdguardTeam/dnsproxy. +// It supports various protocols (plain DNS, DoT, DoH, DoQ, DNSCrypt) via upstream URLs. +// Example upstream formats: +// - Plain UDP: "udp://1.1.1.1:53" +// - Plain TCP: "tcp://1.1.1.1:53" +// - DoT: "tls://1.1.1.1:853" +// - DoH: "https://cloudflare-dns.com/dns-query" +// - DoQ: "quic://dns.adguard-dns.com:853" +// - DNSCrypt: "sdns://AQIAAAAAAAAAFDEuZTAuMC4xOjg0NDMg04wIk9UdC5pYol3Wg92WwgQzOKk8J6SxvE-rO4jDW56HAgBgML0pB4" + +package dns + +import ( + "context" + "errors" + "fmt" + "net" + "net/netip" + "sync" + "time" + + "github.com/AdguardTeam/dnsproxy/upstream" + "github.com/amnezia-vpn/amneziawg-go/device" + "github.com/cenkalti/backoff/v5" + "github.com/miekg/dns" +) + +// ResolverOptions configures the DNS resolver. +type ResolverOptions struct { + UpstreamURL string + Timeout time.Duration +} + +// DefaultOptions returns default resolver options with 1.1.1.1 over UDP. +func DefaultOptions() ResolverOptions { + return ResolverOptions{ + UpstreamURL: "udp://1.1.1.1:53", + Timeout: 5 * time.Second, + } +} + +type Resolved struct { + V4 []netip.Addr + V6 []netip.Addr +} + +func resolveInner(host string, ipType uint16, u upstream.Upstream, dialer *net.Dialer, wg *sync.WaitGroup) ([]netip.Addr, error) { + var addr []netip.Addr + defer wg.Done() + + req := &dns.Msg{} + req.Id = dns.Id() + req.RecursionDesired = true + req.SetQuestion(dns.Fqdn(host), ipType) + + req.SetEdns0(4096, true) + + // Since upstream.Options doesn't take a dialer, we use the miekg/dns client + // directly with our custom dialer to ensure the SO_MARK/Binding is applied. + client := &dns.Client{ + Net: "udp", + Dialer: dialer, + Timeout: 5 * time.Second, + UDPSize: 4096, + } + + // We use the Address from the upstream (e.g., "1.1.1.1:53") + res, _, err := client.Exchange(req, u.Address()) + if err != nil { + return nil, err + } + + if res.Rcode != dns.RcodeSuccess { + return nil, fmt.Errorf("DNS query failed with Rcode: %d", res.Rcode) + } + + for _, ans := range res.Answer { + switch ipType { + case dns.TypeA: + if a, ok := ans.(*dns.A); ok { + if ip, err := netip.ParseAddr(a.A.String()); err == nil { + addr = append(addr, ip) + } + } + case dns.TypeAAAA: + if aaaa, ok := ans.(*dns.AAAA); ok { + if ip, err := netip.ParseAddr(aaaa.AAAA.String()); err == nil { + addr = append(addr, ip) + } + } + } + } + return addr, nil +} + +func Resolve(host string, opts ResolverOptions, preferIpv6 bool) ([]netip.Addr, []netip.Addr, error) { + dialer, err := GetBypassDialer(preferIpv6) + if err != nil { + return nil, nil, fmt.Errorf("bypass dialer failed: %w", err) + } + + // 2. Setup the library just to handle URL parsing and certificates + // We pass the CustomResolver (which uses our bypass dialer) for bootstrapping + u, err := upstream.AddressToUpstream(opts.UpstreamURL, &upstream.Options{ + Bootstrap: CustomResolver(preferIpv6), + Timeout: opts.Timeout, + PreferIPv6: preferIpv6, + }) + if err != nil { + return nil, nil, err + } + defer u.Close() + + var wg sync.WaitGroup + var v4, v6 []netip.Addr + var v4Err, v6Err error + + wg.Add(2) + // 3. We use the 'dialer' directly in resolveInner + go func() { v4, v4Err = resolveInner(host, dns.TypeA, u, dialer, &wg) }() + go func() { v6, v6Err = resolveInner(host, dns.TypeAAAA, u, dialer, &wg) }() + wg.Wait() + + if v4Err != nil && v6Err != nil { + return nil, nil, errors.Join(v4Err, v6Err) + } + + if len(v4) == 0 && len(v6) == 0 { + if v4Err != nil { + return nil, nil, v4Err + } + if v6Err != nil { + return nil, nil, v6Err + } + return nil, nil, errors.New("no IP addresses found") + } + + return v4, v6, nil +} + +// ResolveWithBackoff retries resolution with exponential backoff until success +func ResolveWithBackoff(ctx context.Context, host string, opts ResolverOptions, preferIpv6 bool, logger *device.Logger) (Resolved, error) { + logger.Verbosef("Starting DNS resolution...") + operation := func() (Resolved, error) { + if err := ctx.Err(); err != nil { + return Resolved{}, backoff.Permanent(err) + } + v4, v6, err := Resolve(host, opts, preferIpv6) + if err != nil { + logger.Errorf("Error resolving host %s: %v, retrying...", host, err) + return Resolved{}, err + } + if len(v4) == 0 && len(v6) == 0 { + logger.Errorf("No IPs resolved for host %s, retrying...", host) + return Resolved{}, errors.New("no IPs resolved") + } + logger.Verbosef("Host successfully resolved.") + return Resolved{V4: v4, V6: v6}, nil + } + + return backoff.Retry(ctx, operation, + backoff.WithBackOff(backoff.NewExponentialBackOff()), + backoff.WithMaxElapsedTime(0), // retry forever + ) +} diff --git a/tunnel/tools/libwg-go/dns/resolver_unix.go b/tunnel/tools/libwg-go/dns/resolver_unix.go new file mode 100644 index 0000000..e882b98 --- /dev/null +++ b/tunnel/tools/libwg-go/dns/resolver_unix.go @@ -0,0 +1,41 @@ +//go:build unix + +package dns + +import ( + "context" + "net" + "syscall" + + "github.com/wgtunnel/desktop/tunnel/vpn/firewall/mark" +) + +// GetBypassDialer returns a dialer that bypasses the VPN via SO_MARK +func GetBypassDialer(preferIpv6 bool) (*net.Dialer, error) { + return &net.Dialer{ + Control: func(network, address string, c syscall.RawConn) error { + var opErr error + err := c.Control(func(fd uintptr) { + opErr = syscall.SetsockoptInt(int(fd), syscall.SOL_SOCKET, syscall.SO_MARK, mark.LinuxBootstrapMarkNum) + }) + if err != nil { + return err + } + return opErr + }, + }, nil +} + +// CustomResolver is still needed for the dnsproxy Bootstrap field +func CustomResolver(preferIpv6 bool) *net.Resolver { + return &net.Resolver{ + PreferGo: true, + Dial: func(ctx context.Context, network, address string) (net.Conn, error) { + d, err := GetBypassDialer(preferIpv6) + if err != nil { + return nil, err + } + return d.DialContext(ctx, network, address) + }, + } +} diff --git a/tunnel/tools/libwg-go/dns/resolver_windows.go b/tunnel/tools/libwg-go/dns/resolver_windows.go new file mode 100644 index 0000000..d2eb766 --- /dev/null +++ b/tunnel/tools/libwg-go/dns/resolver_windows.go @@ -0,0 +1,26 @@ +//go:build windows + +package dns + +import ( + "context" + "net" +) + +// GetBypassDialer returns a standard dialer for Windows. +// Since the process is already bypassed in the Windows Firewall, +// no special socket marking or binding is required. +func GetBypassDialer(preferIpv6 bool) (*net.Dialer, error) { + return &net.Dialer{}, nil +} + +// CustomResolver returns a standard net.Resolver for Windows. +func CustomResolver(preferIpv6 bool) *net.Resolver { + return &net.Resolver{ + PreferGo: true, + Dial: func(ctx context.Context, network, address string) (net.Conn, error) { + d, _ := GetBypassDialer(preferIpv6) + return d.DialContext(ctx, network, address) + }, + } +} diff --git a/tunnel/tools/libwg-go/go.mod b/tunnel/tools/libwg-go/go.mod new file mode 100755 index 0000000..a80768d --- /dev/null +++ b/tunnel/tools/libwg-go/go.mod @@ -0,0 +1,61 @@ +module github.com/wgtunnel/desktop/tunnel + +go 1.25.5 + +require ( + github.com/AdguardTeam/dnsproxy v0.78.2 + github.com/amnezia-vpn/amneziawg-go v0.2.16 + github.com/artem-russkikh/wireproxy-awg v1.0.12 + github.com/cenkalti/backoff/v5 v5.0.3 + github.com/godbus/dbus/v5 v5.1.1-0.20230522191255-76236955d466 + github.com/google/nftables v0.3.0 + github.com/vishvananda/netlink v1.3.1 + go4.org/netipx v0.0.0-20231129151722-fdeea329fbba + golang.zx2c4.com/wireguard/windows v0.5.3 + inet.af/wf v0.0.0-20221017222439-36129f591884 + tailscale.com v1.94.1 +) + +require ( + github.com/AdguardTeam/golibs v0.35.7 // indirect + github.com/BurntSushi/toml v1.6.0 // indirect + github.com/ameshkov/dnscrypt/v2 v2.4.0 // indirect + github.com/ameshkov/dnsstamps v1.0.3 // indirect + github.com/google/go-cmp v0.7.0 // indirect + github.com/mdlayher/netlink v1.8.0 // indirect + github.com/mdlayher/socket v0.5.1 // indirect + github.com/quic-go/qpack v0.6.0 // indirect + github.com/quic-go/quic-go v0.59.0 // indirect + github.com/vishvananda/netns v0.0.5 // indirect + golang.org/x/exp v0.0.0-20260112195511-716be5621a96 // indirect + golang.org/x/exp/typeparams v0.0.0-20260112195511-716be5621a96 // indirect + golang.org/x/mod v0.32.0 // indirect + golang.org/x/text v0.33.0 // indirect + golang.org/x/tools v0.41.0 // indirect + honnef.co/go/tools v0.7.0-0.dev.0.20251022135355-8273271481d0 // indirect +) + +require ( + github.com/MakeNowJust/heredoc/v2 v2.0.1 // indirect + github.com/go-ini/ini v1.67.0 // indirect + github.com/google/btree v1.1.3 // indirect + github.com/miekg/dns v1.1.72 + github.com/things-go/go-socks5 v0.1.0 // indirect + golang.org/x/crypto v0.47.0 // indirect + golang.org/x/net v0.49.0 + golang.org/x/sync v0.19.0 // indirect + golang.org/x/sys v0.40.0 + golang.org/x/time v0.14.0 // indirect + golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 // indirect + gvisor.dev/gvisor v0.0.0-20250205023644-9414b50a5633 // indirect; ind +) + +//replace github.com/amnezia-vpn/amneziawg-go => github.com/wgtunnel/amneziawg-go v0.0.0-20251225080458-6a08ea62878d + +//replace github.com/artem-russkikh/wireproxy-awg => github.com/wgtunnel/wireproxy-awg v0.0.0-20251215030122-ffaf05dda47f + +// local dev +replace github.com/amnezia-vpn/amneziawg-go => ../../../../amneziawg-go + +// +replace github.com/artem-russkikh/wireproxy-awg => ../../../../wireproxy-awg diff --git a/tunnel/tools/libwg-go/go.sum b/tunnel/tools/libwg-go/go.sum new file mode 100644 index 0000000..c591e86 --- /dev/null +++ b/tunnel/tools/libwg-go/go.sum @@ -0,0 +1,91 @@ +github.com/AdguardTeam/dnsproxy v0.78.2 h1:g+ba4vh72hAv9zIE+OPSEnu77utSKxIF6u2jNhYAR7g= +github.com/AdguardTeam/dnsproxy v0.78.2/go.mod h1:gwr+7Dc0e7QddQLC9JLGjL5NSKcqw0ESsNMRI5Q67Ps= +github.com/AdguardTeam/golibs v0.35.7 h1:pTQpixUos7mALr3jqb0pigfrkiqPAX1hiYUi/yeBWiA= +github.com/AdguardTeam/golibs v0.35.7/go.mod h1:meFdRqMtG/PLW6LD20MYAlcRbwAVowlbunHgE17xz9s= +github.com/BurntSushi/toml v1.6.0 h1:dRaEfpa2VI55EwlIW72hMRHdWouJeRF7TPYhI+AUQjk= +github.com/BurntSushi/toml v1.6.0/go.mod h1:ukJfTF/6rtPPRCnwkur4qwRxa8vTRFBF0uk2lLoLwho= +github.com/MakeNowJust/heredoc/v2 v2.0.1 h1:rlCHh70XXXv7toz95ajQWOWQnN4WNLt0TdpZYIR/J6A= +github.com/MakeNowJust/heredoc/v2 v2.0.1/go.mod h1:6/2Abh5s+hc3g9nbWLe9ObDIOhaRrqsyY9MWy+4JdRM= +github.com/ameshkov/dnscrypt/v2 v2.4.0 h1:if6ZG2cuQmcP2TwSY+D0+8+xbPfoatufGlOQTMNkI9o= +github.com/ameshkov/dnscrypt/v2 v2.4.0/go.mod h1:WpEFV2uhebXb8Jhes/5/fSdpmhGV8TL22RDaeWwV6hI= +github.com/ameshkov/dnsstamps v1.0.3 h1:Srzik+J9mivH1alRACTbys2xOxs0lRH9qnTA7Y1OYVo= +github.com/ameshkov/dnsstamps v1.0.3/go.mod h1:Ii3eUu73dx4Vw5O4wjzmT5+lkCwovjzaEZZ4gKyIH5A= +github.com/cenkalti/backoff/v5 v5.0.3 h1:ZN+IMa753KfX5hd8vVaMixjnqRZ3y8CuJKRKj1xcsSM= +github.com/cenkalti/backoff/v5 v5.0.3/go.mod h1:rkhZdG3JZukswDf7f0cwqPNk4K0sa+F97BxZthm/crw= +github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc h1:U9qPSI2PIWSS1VwoXQT9A3Wy9MM3WgvqSxFWenqJduM= +github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/go-ini/ini v1.67.0 h1:z6ZrTEZqSWOTyH2FlglNbNgARyHG8oLW9gMELqKr06A= +github.com/go-ini/ini v1.67.0/go.mod h1:ByCAeIL28uOIIG0E3PJtZPDL8WnHpFKFOtgjp+3Ies8= +github.com/godbus/dbus/v5 v5.1.1-0.20230522191255-76236955d466 h1:sQspH8M4niEijh3PFscJRLDnkL547IeP7kpPe3uUhEg= +github.com/godbus/dbus/v5 v5.1.1-0.20230522191255-76236955d466/go.mod h1:ZiQxhyQ+bbbfxUKVvjfO498oPYvtYhZzycal3G/NHmU= +github.com/google/btree v1.1.3 h1:CVpQJjYgC4VbzxeGVHfvZrv1ctoYCAI8vbl07Fcxlyg= +github.com/google/btree v1.1.3/go.mod h1:qOPhT0dTNdNzV6Z/lhRX0YXUafgPLFUh+gZMl761Gm4= +github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8= +github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU= +github.com/google/nftables v0.3.0 h1:bkyZ0cbpVeMHXOrtlFc8ISmfVqq5gPJukoYieyVmITg= +github.com/google/nftables v0.3.0/go.mod h1:BCp9FsrbF1Fn/Yu6CLUc9GGZFw/+hsxfluNXXmxBfRM= +github.com/mdlayher/netlink v1.8.0 h1:e7XNIYJKD7hUct3Px04RuIGJbBxy1/c4nX7D5YyvvlM= +github.com/mdlayher/netlink v1.8.0/go.mod h1:UhgKXUlDQhzb09DrCl2GuRNEglHmhYoWAHid9HK3594= +github.com/mdlayher/socket v0.5.1 h1:VZaqt6RkGkt2OE9l3GcC6nZkqD3xKeQLyfleW/uBcos= +github.com/mdlayher/socket v0.5.1/go.mod h1:TjPLHI1UgwEv5J1B5q0zTZq12A/6H7nKmtTanQE37IQ= +github.com/miekg/dns v1.1.72 h1:vhmr+TF2A3tuoGNkLDFK9zi36F2LS+hKTRW0Uf8kbzI= +github.com/miekg/dns v1.1.72/go.mod h1:+EuEPhdHOsfk6Wk5TT2CzssZdqkmFhf8r+aVyDEToIs= +github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 h1:Jamvg5psRIccs7FGNTlIRMkT8wgtp5eCXdBlqhYGL6U= +github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/quic-go/qpack v0.6.0 h1:g7W+BMYynC1LbYLSqRt8PBg5Tgwxn214ZZR34VIOjz8= +github.com/quic-go/qpack v0.6.0/go.mod h1:lUpLKChi8njB4ty2bFLX2x4gzDqXwUpaO1DP9qMDZII= +github.com/quic-go/quic-go v0.59.0 h1:OLJkp1Mlm/aS7dpKgTc6cnpynnD2Xg7C1pwL6vy/SAw= +github.com/quic-go/quic-go v0.59.0/go.mod h1:upnsH4Ju1YkqpLXC305eW3yDZ4NfnNbmQRCMWS58IKU= +github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U= +github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U= +github.com/things-go/go-socks5 v0.1.0 h1:4f5dz0iMQ6cA4wseFmyLmCHmg3SWJTW92ndrKS6oERg= +github.com/things-go/go-socks5 v0.1.0/go.mod h1:Riabiyu52kLsla0YmJqunt1c1JEl6iXSr4bRd7swFEA= +github.com/vishvananda/netlink v1.3.1 h1:3AEMt62VKqz90r0tmNhog0r/PpWKmrEShJU0wJW6bV0= +github.com/vishvananda/netlink v1.3.1/go.mod h1:ARtKouGSTGchR8aMwmkzC0qiNPrrWO5JS/XMVl45+b4= +github.com/vishvananda/netns v0.0.5 h1:DfiHV+j8bA32MFM7bfEunvT8IAqQ/NzSJHtcmW5zdEY= +github.com/vishvananda/netns v0.0.5/go.mod h1:SpkAiCQRtJ6TvvxPnOSyH3BMl6unz3xZlaprSwhNNJM= +go.uber.org/atomic v1.11.0 h1:ZvwS0R+56ePWxUNi+Atn9dWONBPp/AUETXlHW0DxSjE= +go.uber.org/atomic v1.11.0/go.mod h1:LUxbIzbOniOlMKjJjyPfpl4v+PKK2cNJn91OQbhoJI0= +go.uber.org/mock v0.6.0 h1:hyF9dfmbgIX5EfOdasqLsWD6xqpNZlXblLB/Dbnwv3Y= +go.uber.org/mock v0.6.0/go.mod h1:KiVJ4BqZJaMj4svdfmHM0AUx4NJYO8ZNpPnZn1Z+BBU= +go4.org/netipx v0.0.0-20231129151722-fdeea329fbba h1:0b9z3AuHCjxk0x/opv64kcgZLBseWJUpBw5I82+2U4M= +go4.org/netipx v0.0.0-20231129151722-fdeea329fbba/go.mod h1:PLyyIXexvUFg3Owu6p/WfdlivPbZJsZdgWZlrGope/Y= +golang.org/x/crypto v0.47.0 h1:V6e3FRj+n4dbpw86FJ8Fv7XVOql7TEwpHapKoMJ/GO8= +golang.org/x/crypto v0.47.0/go.mod h1:ff3Y9VzzKbwSSEzWqJsJVBnWmRwRSHt/6Op5n9bQc4A= +golang.org/x/exp v0.0.0-20260112195511-716be5621a96 h1:Z/6YuSHTLOHfNFdb8zVZomZr7cqNgTJvA8+Qz75D8gU= +golang.org/x/exp v0.0.0-20260112195511-716be5621a96/go.mod h1:nzimsREAkjBCIEFtHiYkrJyT+2uy9YZJB7H1k68CXZU= +golang.org/x/exp/typeparams v0.0.0-20260112195511-716be5621a96 h1:RMc8anw0hCPcg5CZYN2PEQ8nMwosk461R6vFwPrCFVg= +golang.org/x/exp/typeparams v0.0.0-20260112195511-716be5621a96/go.mod h1:4Mzdyp/6jzw9auFDJ3OMF5qksa7UvPnzKqTVGcb04ms= +golang.org/x/mod v0.32.0 h1:9F4d3PHLljb6x//jOyokMv3eX+YDeepZSEo3mFJy93c= +golang.org/x/mod v0.32.0/go.mod h1:SgipZ/3h2Ci89DlEtEXWUk/HteuRin+HHhN+WbNhguU= +golang.org/x/net v0.49.0 h1:eeHFmOGUTtaaPSGNmjBKpbng9MulQsJURQUAfUwY++o= +golang.org/x/net v0.49.0/go.mod h1:/ysNB2EvaqvesRkuLAyjI1ycPZlQHM3q01F02UY/MV8= +golang.org/x/sync v0.19.0 h1:vV+1eWNmZ5geRlYjzm2adRgW2/mcpevXNg50YZtPCE4= +golang.org/x/sync v0.19.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI= +golang.org/x/sys v0.0.0-20220817070843-5a390386f1f2/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.2.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.10.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.40.0 h1:DBZZqJ2Rkml6QMQsZywtnjnnGvHza6BTfYFWY9kjEWQ= +golang.org/x/sys v0.40.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks= +golang.org/x/text v0.33.0 h1:B3njUFyqtHDUI5jMn1YIr5B0IE2U0qck04r6d4KPAxE= +golang.org/x/text v0.33.0/go.mod h1:LuMebE6+rBincTi9+xWTY8TztLzKHc/9C1uBCG27+q8= +golang.org/x/time v0.14.0 h1:MRx4UaLrDotUKUdCIqzPC48t1Y9hANFKIRpNx+Te8PI= +golang.org/x/time v0.14.0/go.mod h1:eL/Oa2bBBK0TkX57Fyni+NgnyQQN4LitPmob2Hjnqw4= +golang.org/x/tools v0.41.0 h1:a9b8iMweWG+S0OBnlU36rzLp20z1Rp10w+IY2czHTQc= +golang.org/x/tools v0.41.0/go.mod h1:XSY6eDqxVNiYgezAVqqCeihT4j1U2CCsqvH3WhQpnlg= +golang.org/x/tools/go/expect v0.1.1-deprecated h1:jpBZDwmgPhXsKZC6WhL20P4b/wmnpsEAGHaNy0n/rJM= +golang.org/x/tools/go/expect v0.1.1-deprecated/go.mod h1:eihoPOH+FgIqa3FpoTwguz/bVUSGBlGQU67vpBeOrBY= +golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 h1:B82qJJgjvYKsXS9jeunTOisW56dUokqW/FOteYJJ/yg= +golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2/go.mod h1:deeaetjYA+DHMHg+sMSMI58GrEteJUUzzw7en6TJQcI= +golang.zx2c4.com/wireguard/windows v0.5.3 h1:On6j2Rpn3OEMXqBq00QEDC7bWSZrPIHKIus8eIuExIE= +golang.zx2c4.com/wireguard/windows v0.5.3/go.mod h1:9TEe8TJmtwyQebdFwAkEWOPr3prrtqm+REGFifP60hI= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +gvisor.dev/gvisor v0.0.0-20250205023644-9414b50a5633 h1:2gap+Kh/3F47cO6hAu3idFvsJ0ue6TRcEi2IUkv/F8k= +gvisor.dev/gvisor v0.0.0-20250205023644-9414b50a5633/go.mod h1:5DMfjtclAbTIjbXqO1qCe2K5GKKxWz2JHvCChuTcJEM= +honnef.co/go/tools v0.7.0-0.dev.0.20251022135355-8273271481d0 h1:5SXjd4ET5dYijLaf0O3aOenC0Z4ZafIWSpjUzsQaNho= +honnef.co/go/tools v0.7.0-0.dev.0.20251022135355-8273271481d0/go.mod h1:EPDDhEZqVHhWuPI5zPAsjU0U7v9xNIWjoOVyZ5ZcniQ= +inet.af/wf v0.0.0-20221017222439-36129f591884 h1:zg9snq3Cpy50lWuVqDYM7AIRVTtU50y5WXETMFohW/Q= +inet.af/wf v0.0.0-20221017222439-36129f591884/go.mod h1:bSAQ38BYbY68uwpasXOTZo22dKGy9SNvI6PZFeKomZE= +tailscale.com v1.94.1 h1:0dAst/ozTuFkgmxZULc3oNwR9+qPIt5ucvzH7kaM0Jw= +tailscale.com v1.94.1/go.mod h1:gLnVrEOP32GWvroaAHHGhjSGMPJ1i4DvqNwEg+Yuov4= diff --git a/tunnel/tools/libwg-go/ipc/ipc_bsd.go b/tunnel/tools/libwg-go/ipc/ipc_bsd.go new file mode 100644 index 0000000..c7b1dfb --- /dev/null +++ b/tunnel/tools/libwg-go/ipc/ipc_bsd.go @@ -0,0 +1,29 @@ +//go:build darwin || freebsd || openbsd + +package ipc + +import ( + "net" + + "github.com/amnezia-vpn/amneziawg-go/ipc" + "github.com/wgtunnel/desktop/tunnel/shared" +) + +func SetupIPC(name string) (net.Listener, error) { + var socketDirectory = "/run/wgtunnel" + + uapiFile, err := ipc.UAPIOpen(socketDirectory, name) + if err != nil { + shared.LogError("IPC", "UAPIOpen: %v", err) + return nil, err + } + + uapi, err := ipc.UAPIListen(socketDirectory, name, uapiFile) + if err != nil { + uapiFile.Close() + shared.LogError("IPC", "UAPIListen: %v", err) + return nil, err + } + + return uapi, nil +} diff --git a/tunnel/tools/libwg-go/ipc/ipc_linux.go b/tunnel/tools/libwg-go/ipc/ipc_linux.go new file mode 100644 index 0000000..c4667eb --- /dev/null +++ b/tunnel/tools/libwg-go/ipc/ipc_linux.go @@ -0,0 +1,29 @@ +//go:build linux + +package ipc + +import ( + "net" + + "github.com/amnezia-vpn/amneziawg-go/ipc" + "github.com/wgtunnel/desktop/tunnel/shared" +) + +func SetupIPC(name string) (net.Listener, error) { + var socketDirectory = "/run/wgtunnel" + + uapiFile, err := ipc.UAPIOpen(socketDirectory, name) + if err != nil { + shared.LogError("IPC", "UAPIOpen: %v", err) + return nil, err + } + + uapi, err := ipc.UAPIListen(socketDirectory, name, uapiFile) + if err != nil { + uapiFile.Close() + shared.LogError("IPC", "UAPIListen: %v", err) + return nil, err + } + + return uapi, nil +} diff --git a/tunnel/tools/libwg-go/ipc/ipc_windows.go b/tunnel/tools/libwg-go/ipc/ipc_windows.go new file mode 100644 index 0000000..902abb7 --- /dev/null +++ b/tunnel/tools/libwg-go/ipc/ipc_windows.go @@ -0,0 +1,20 @@ +//go:build windows + +package ipc + +import ( + "net" + + "github.com/amnezia-vpn/amneziawg-go/ipc" + "github.com/wgtunnel/desktop/tunnel/shared" +) + +func SetupIPC(name string) (net.Listener, error) { + uapi, err := ipc.UAPIListen(name) + if err != nil { + shared.LogError("IPC", "UAPIListen: %v", err) + return nil, err + } + + return uapi, nil +} diff --git a/tunnel/tools/libwg-go/killswitch/killswitch.go b/tunnel/tools/libwg-go/killswitch/killswitch.go new file mode 100644 index 0000000..49ad9e9 --- /dev/null +++ b/tunnel/tools/libwg-go/killswitch/killswitch.go @@ -0,0 +1,51 @@ +//go:build !android + +package killswitch + +import "C" +import ( + "github.com/wgtunnel/desktop/tunnel/shared" + "github.com/wgtunnel/desktop/tunnel/vpn/firewall/osfirewall/firewallmgr" +) + +var logger = shared.NewLogger("KillSwitch") + +//export setKillSwitch +func setKillSwitch(enabled C.int) C.int { + fw, err := firewallmgr.Get() + if err != nil { + logger.Errorf("Failed to get firewall: %v", err) + return C.int(-1) + } + + if enabled == 1 { + err := fw.Enable() + if err != nil { + logger.Errorf("Failed to enable kill switch: %v", err) + return C.int(-1) + } + logger.Verbosef("Kill switch enabled") + } else { + err := fw.Disable() + if err != nil { + logger.Errorf("Failed to disable kill switch: %v", err) + return C.int(-1) + } + logger.Verbosef("Kill switch disabled") + } + return enabled +} + +//export getKillSwitchStatus +func getKillSwitchStatus() C.int { + fw, err := firewallmgr.Get() + if err != nil { + logger.Errorf("Failed to get firewall: %v", err) + return C.int(0) + } + + if fw.IsEnabled() { + return C.int(1) + } + return C.int(0) +} diff --git a/tunnel/tools/libwg-go/main.go b/tunnel/tools/libwg-go/main.go new file mode 100644 index 0000000..5f78def --- /dev/null +++ b/tunnel/tools/libwg-go/main.go @@ -0,0 +1,9 @@ +package main + +import ( + _ "github.com/wgtunnel/desktop/tunnel/killswitch" + _ "github.com/wgtunnel/desktop/tunnel/proxy" + _ "github.com/wgtunnel/desktop/tunnel/vpn" +) + +func main() {} diff --git a/tunnel/tools/libwg-go/proxy/proxy.go b/tunnel/tools/libwg-go/proxy/proxy.go new file mode 100755 index 0000000..6cbd8ab --- /dev/null +++ b/tunnel/tools/libwg-go/proxy/proxy.go @@ -0,0 +1,231 @@ +//go:build !android + +package proxy + +/* +#include +typedef void (*StatusCodeCallback)(int32_t handle, int32_t status); +*/ +import "C" +import ( + "context" + "sync" + "syscall" + + "os" + "os/signal" + + "github.com/amnezia-vpn/amneziawg-go/conn" + "github.com/amnezia-vpn/amneziawg-go/device" + "github.com/amnezia-vpn/amneziawg-go/tun/netstack" + wireproxyawg "github.com/artem-russkikh/wireproxy-awg" + ipc "github.com/wgtunnel/desktop/tunnel/ipc" + "github.com/wgtunnel/desktop/tunnel/shared" + "github.com/wgtunnel/desktop/tunnel/util" +) + +var ( + tag = "AwgProxy" + virtualTunnelHandles = make(map[int32]*wireproxyawg.VirtualTun) + ctx context.Context + cancelFunc context.CancelFunc +) + +func init() { + // Handle signals for clean shutdown + go handleSignals() +} + +func handleSignals() { + sigs := make(chan os.Signal, 1) + signal.Notify(sigs, syscall.SIGINT, syscall.SIGTERM) + <-sigs + awgProxyTurnOffAll() + os.Exit(0) +} + +//export awgProxyTurnOn +func awgProxyTurnOn(config *C.char, callback C.StatusCodeCallback) C.int { + handle, err2 := util.GenerateHandle(virtualTunnelHandles) + if err2 != nil { + shared.LogError(tag, "Unable to find empty handle", err2) + return C.int(-1) + } + + shared.StoreTunnelCallback(handle, shared.StatusCodeCallback(callback)) + + goConfig := C.GoString(config) + + conf, err := wireproxyawg.ParseConfigString(goConfig) + if err != nil { + shared.LogError(tag, "Invalid config file", err) + return C.int(-1) + } + + setting, err := wireproxyawg.CreateIPCRequest(conf.Device, false) + if err != nil { + shared.LogError(tag, "Create IPC request failed", err) + return C.int(-1) + } + + tun, tnet, err := netstack.CreateNetTUN(setting.DeviceAddr, setting.DNS, setting.MTU) + if err != nil { + shared.LogError(tag, "Create TUN failed", err) + return C.int(-1) + } + + name, err := tun.Name() + if err != nil { + shared.LogError(tag, "Get TUN name failed", err) + return C.int(-1) + } + + bind := conn.NewDefaultBind() + + statusCB := func(code device.StatusCode) { + // use goroutine to avoid any blocking from JNA + go shared.NotifyStatusCode(handle, int32(code)) + } + + dev := device.NewDevice(tun, bind, shared.NewLogger("Tun/"+name), false, statusCB) + + err = dev.IpcSet(setting.IpcRequest) + if err != nil { + shared.LogError(tag, "Ipc setting failed", err) + return C.int(-1) + } + + uapi, _ := ipc.SetupIPC(name) + + go func() { + for { + connection, err := uapi.Accept() + if err != nil { + return + } + go dev.IpcHandle(connection) + } + }() + + err = dev.Up() + if err != nil { + shared.LogError(tag, "Failed to bring up device", err) + uapi.Close() + dev.Close() + return C.int(-1) + } + + virtualTun := &wireproxyawg.VirtualTun{ + Tnet: tnet, + Dev: dev, + Logger: shared.NewLogger("Proxy"), + Uapi: uapi, + Conf: conf.Device, + PingRecord: make(map[string]uint64), + PingRecordLock: new(sync.Mutex), + } + + virtualTunnelHandles[handle] = virtualTun + + // Create cancellable context + ctx, cancelFunc = context.WithCancel(context.Background()) + + // Spawn all routines with context + for _, spawner := range conf.Routines { + shared.LogDebug(tag, "Spawning routine..") + go func(s wireproxyawg.RoutineSpawner) { + if err := s.SpawnRoutine(ctx, virtualTun); err != nil { + shared.LogError(tag, "Routine failed: %v", err) + } + }(spawner) + } + + shared.LogDebug(tag, "Done starting proxy and tunnel") + return C.int(handle) +} + +func awgUpdateProxyTunnelPeers(tunnelHandle int32, settings string) int32 { + handle, ok := virtualTunnelHandles[tunnelHandle] + if !ok { + shared.LogError(tag, "Tunnel is not up") + return -1 + } + + conf, err := wireproxyawg.ParseConfigString(settings) + if err != nil { + shared.LogError(tag, "Invalid config file", err) + return -1 + } + + ipcRequest, err := wireproxyawg.CreatePeerIPCRequest(conf.Device) + if err != nil { + shared.LogError(tag, "CreateIPCRequest: %v", err) + return -1 + } + + err = handle.Dev.IpcSet(ipcRequest.IpcRequest) + if err != nil { + shared.LogError(tag, "IpcSet: %v", err) + return -1 + } + + shared.LogDebug(tag, "Configuration updated successfully") + return 0 +} + +//export awgProxyGetConfig +func awgProxyGetConfig(tunnelHandle C.int) *C.char { + goTunnelHandle := int32(tunnelHandle) + handle, ok := virtualTunnelHandles[goTunnelHandle] + if !ok { + shared.LogError(tag, "Tunnel is not up") + return nil + } + settings, err := handle.Dev.IpcGet() + if err != nil { + shared.LogError(tag, "Failed to get device config: %v", err) + return nil + } + return C.CString(settings) +} + +//export awgProxyTurnOffAll +func awgProxyTurnOffAll() { + if cancelFunc != nil { + shared.LogDebug(tag, "Stopping proxy routines..") + cancelFunc() + cancelFunc = nil + } + handles := make([]int32, 0, len(virtualTunnelHandles)) + for h := range virtualTunnelHandles { + handles = append(handles, h) + } + for _, handle := range handles { + awgProxyTurnOff(C.int(handle)) + } + virtualTunnelHandles = make(map[int32]*wireproxyawg.VirtualTun) + shared.LogDebug(tag, "Proxy fully reset: %d handles closed", len(handles)) +} + +//export awgProxyTurnOff +func awgProxyTurnOff(virtualTunnelHandle C.int) { + goVirtualTunnelHandle := int32(virtualTunnelHandle) + virtualTun, ok := virtualTunnelHandles[goVirtualTunnelHandle] + if !ok { + shared.LogError(tag, "Tunnel handle %d not found", goVirtualTunnelHandle) + return + } + shared.LogDebug(tag, "Tearing down tunnel %d", goVirtualTunnelHandle) + + // Disable UAPI listener and underlying file + if virtualTun.Uapi != nil { + virtualTun.Uapi.Close() + } + + if virtualTun.Dev != nil { + virtualTun.Dev.Close() + } + + delete(virtualTunnelHandles, goVirtualTunnelHandle) + shared.LogDebug(tag, "Tunnel %d fully closed (UAPI/Dev/Bind purged)", goVirtualTunnelHandle) +} diff --git a/tunnel/tools/libwg-go/shared/shared.go b/tunnel/tools/libwg-go/shared/shared.go new file mode 100755 index 0000000..529f439 --- /dev/null +++ b/tunnel/tools/libwg-go/shared/shared.go @@ -0,0 +1,63 @@ +package shared + +/* +#include +typedef void (*StatusCodeCallback)(int32_t handle, int32_t status); + +void callStatusCallback(StatusCodeCallback cb, int32_t handle, int32_t status) { + if (cb) cb(handle, status); +} +*/ +import "C" +import ( + "log" + + "github.com/amnezia-vpn/amneziawg-go/device" +) + +var tag = "AmneziaWG" + +func LogDebug(format string, args ...interface{}) { + log.Printf("[DEBUG] %s: "+format+"\n", append([]interface{}{tag}, args...)...) +} + +func LogWarn(format string, args ...interface{}) { + log.Printf("[WARN] %s: "+format+"\n", append([]interface{}{tag}, args...)...) +} + +func LogError(format string, args ...interface{}) { + log.Printf("[ERROR] %s: "+format+"\n", append([]interface{}{tag}, args...)...) +} + +func NewLogger(prefix string) *device.Logger { + return &device.Logger{ + Verbosef: func(format string, args ...any) { + LogDebug(prefix+": "+format, args...) + }, + Errorf: func(format string, args ...any) { + LogError(prefix+": "+format, args...) + }, + } +} + +type StatusCodeCallback C.StatusCodeCallback + +var tunnelCallbacks = make(map[int32]StatusCodeCallback) + +func StoreTunnelCallback(handle int32, cb StatusCodeCallback) { + if cb != nil { + tunnelCallbacks[handle] = cb + } +} + +func NotifyStatusCode(handle int32, status int32) { + if cb, ok := tunnelCallbacks[handle]; ok && cb != nil { + C.callStatusCallback(cb, C.int32_t(handle), C.int32_t(status)) + } +} + +const ( + StatusHealthy = iota + StatusHandshakeFailure + StatusResolvingDNS +) diff --git a/tunnel/tools/libwg-go/util/util.go b/tunnel/tools/libwg-go/util/util.go new file mode 100755 index 0000000..ebd493d --- /dev/null +++ b/tunnel/tools/libwg-go/util/util.go @@ -0,0 +1,16 @@ +package util + +import ( + "fmt" + "math" +) + +// GenerateHandle generates a unique int32 handle for a given map. +func GenerateHandle[K int32, V any](handles map[K]V) (int32, error) { + for i := int32(0); i < math.MaxInt32; i++ { + if _, exists := handles[K(i)]; !exists { + return i, nil + } + } + return -1, fmt.Errorf("unable to find handle") +} diff --git a/tunnel/tools/libwg-go/vpn/bind/bind_darwin.go b/tunnel/tools/libwg-go/vpn/bind/bind_darwin.go new file mode 100644 index 0000000..c21f6d1 --- /dev/null +++ b/tunnel/tools/libwg-go/vpn/bind/bind_darwin.go @@ -0,0 +1,13 @@ +//go:build darwin + +package bind + +import ( + "github.com/amnezia-vpn/amneziawg-go/conn" + "github.com/amnezia-vpn/amneziawg-go/device" +) + +func SetupBind(logger *device.Logger, bind conn.Bind) error { + + return nil // No fwmark on non-Linux; no-op +} diff --git a/tunnel/tools/libwg-go/vpn/bind/linux_bind.go b/tunnel/tools/libwg-go/vpn/bind/linux_bind.go new file mode 100644 index 0000000..a164126 --- /dev/null +++ b/tunnel/tools/libwg-go/vpn/bind/linux_bind.go @@ -0,0 +1,38 @@ +//go:build linux && !android + +package bind + +import ( + "fmt" + "syscall" + + "github.com/amnezia-vpn/amneziawg-go/conn" + "github.com/amnezia-vpn/amneziawg-go/device" + "github.com/wgtunnel/desktop/tunnel/vpn/firewall/mark" + "golang.org/x/sys/unix" +) + +func SetupBind(logger *device.Logger, bind conn.Bind) error { + stdBind, ok := bind.(*conn.StdNetBind) + if !ok { + return fmt.Errorf("failed to cast to StdNetBind") + } + stdBind.SetControl(func(network, address string, c syscall.RawConn) error { + var opErr error + err := c.Control(func(fd uintptr) { + logger.Verbosef("Control called on socket FD %d - setting fwmark...", fd) + if err := unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_MARK, mark.LinuxBypassMarkNum); err != nil { + opErr = err + logger.Errorf("Failed to set fwmark on FD %d: %v", fd, err) + } else { + logger.Verbosef("Fwmark %d set on FD %d", mark.LinuxBypassMarkNum, fd) + } + }) + if err != nil { + return err + } + return opErr + }) + logger.Verbosef("Set control func on bind to apply fwmark on socket ops") + return nil +} diff --git a/tunnel/tools/libwg-go/vpn/bind/windows_bind.go b/tunnel/tools/libwg-go/vpn/bind/windows_bind.go new file mode 100644 index 0000000..0edde61 --- /dev/null +++ b/tunnel/tools/libwg-go/vpn/bind/windows_bind.go @@ -0,0 +1,12 @@ +//go:build windows + +package bind + +import ( + "github.com/amnezia-vpn/amneziawg-go/conn" + "github.com/amnezia-vpn/amneziawg-go/device" +) + +func SetupBind(logger *device.Logger, bind conn.Bind) error { + return nil +} diff --git a/tunnel/tools/libwg-go/vpn/dns/dns_linux.go b/tunnel/tools/libwg-go/vpn/dns/dns_linux.go new file mode 100644 index 0000000..06ab24f --- /dev/null +++ b/tunnel/tools/libwg-go/vpn/dns/dns_linux.go @@ -0,0 +1,294 @@ +//go:build linux + +package dns + +import ( + "context" + "fmt" + "net/netip" + "os" + "os/exec" + "path/filepath" + "strconv" + "strings" + + "github.com/amnezia-vpn/amneziawg-go/device" + "github.com/godbus/dbus/v5" + "github.com/vishvananda/netlink" + "golang.org/x/sys/unix" +) + +const ( + dbusDest = "org.freedesktop.resolve1" + dbusInterface = "org.freedesktop.resolve1.Manager" + dbusPath = "/org/freedesktop/resolve1" + + resolvConfPath = "/etc/resolv.conf" + resolvConfBak = "/etc/resolv.conf.bak.wgt" +) + +// Conn represents a systemd-resolved dbus connection. +type Conn struct { + conn *dbus.Conn + obj dbus.BusObject +} + +func newConn() (*Conn, error) { + conn, err := dbus.SystemBusPrivate() + if err != nil { + return nil, fmt.Errorf("failed to init private conn to system bus: %w", err) + } + methods := []dbus.Auth{dbus.AuthExternal(strconv.Itoa(os.Getuid()))} + err = conn.Auth(methods) + if err != nil { + conn.Close() + return nil, fmt.Errorf("failed to auth with external method: %w", err) + } + err = conn.Hello() + if err != nil { + conn.Close() + return nil, fmt.Errorf("failed to make hello call: %w", err) + } + return &Conn{ + conn: conn, + obj: conn.Object(dbusDest, dbus.ObjectPath(dbusPath)), + }, nil +} + +// Call wraps obj.CallWithContext by using 0 as flags and formats the method with the dbus manager interface. +func (c *Conn) Call(ctx context.Context, method string, args ...interface{}) *dbus.Call { + return c.obj.CallWithContext(ctx, fmt.Sprintf("%s.%s", dbusInterface, method), 0, args...) +} + +// Close closes the current dbus connection. +func (c *Conn) Close() error { + return c.conn.Close() +} + +// SetDns configures DNS servers and search domains, using systemd-resolved if available (per-interface), +// falling back to overwriting /etc/resolv.conf otherwise. +func SetDns(iface string, dns []netip.Addr, searchDomains []string, fullTunnel bool, logger *device.Logger) error { + index, err := getInterfaceIndex(iface) + if isSystemdResolvedActive() { + if err != nil { + logger.Errorf("Failed to get interface name, falling back to resolv.conf: %w", err) + return setDnsFile(dns, searchDomains, fullTunnel) + } + logger.Verbosef("Configuring systemd-resolver...") + return setDnsSystemd(index, dns, searchDomains, fullTunnel) + } + logger.Verbosef("Systemd-resolver not detected, falling back to resolv.conf...") + return setDnsFile(dns, searchDomains, fullTunnel) +} + +func getInterfaceIndex(ifName string) (int, error) { + link, err := netlink.LinkByName(ifName) + if err != nil { + return 0, fmt.Errorf("failed to get link for %s: %w", ifName, err) + } + return link.Attrs().Index, nil +} + +// RevertDns reverts DNS configuration, using systemd-resolved if available, or restoring the resolv.conf backup otherwise. +func RevertDns(iface string, logger *device.Logger) error { + index, err := getInterfaceIndex(iface) + if isSystemdResolvedActive() { + if err != nil { + logger.Errorf("Failed to get interface name, attempting to revert resolv.conf from backup...") + return revertDnsFile() + } + logger.Verbosef("Reverting systemd-resolver...") + return revertDnsSystemd(index) + } + logger.Verbosef("Systemd-resolver not detected, attempting to revert dns from backup...") + return revertDnsFile() +} + +// isSystemdResolvedActive checks if systemd-resolved is available and responsive via DBus. +func isSystemdResolvedActive() bool { + conn, err := newConn() + if err != nil { + return false + } + defer conn.Close() + + // Test with a simple local resolve (flags=0) + var addresses []struct { + IfIndex int + Family int + Address []byte + } + var canonical string + var outflags uint64 + call := conn.Call(context.Background(), "ResolveHostname", 0, "localhost", unix.AF_UNSPEC, uint64(0)) + if call.Err != nil { + return false + } + err = call.Store(&addresses, &canonical, &outflags) + return err == nil +} + +// setDnsSystemd configures DNS via systemd-resolved DBus (per-interface). +func setDnsSystemd(ifIndex int, dns []netip.Addr, searchDomains []string, fullTunnel bool) error { + conn, err := newConn() + if err != nil { + return fmt.Errorf("dbus connect: %w", err) + } + defer conn.Close() + + type dnsEntry struct { + Family int32 + Address []byte + } + + var linkDNS []dnsEntry + for _, ip := range dns { + fam := int32(unix.AF_INET) + if ip.Is6() { + fam = int32(unix.AF_INET6) + } + linkDNS = append(linkDNS, dnsEntry{ + Family: fam, + Address: ip.AsSlice(), + }) + } + call := conn.Call(context.Background(), "SetLinkDNS", ifIndex, linkDNS) + if call.Err != nil { + return fmt.Errorf("set link DNS: %w", call.Err) + } + + type domainEntry struct { + Domain string + Routing bool + } + + var linkDomains []domainEntry + for _, domain := range searchDomains { + linkDomains = append(linkDomains, domainEntry{ + Domain: domain, + Routing: false, + }) + } + // full tunnel, add "~." as routing domain to capture all queries + if fullTunnel && len(dns) > 0 { + linkDomains = append(linkDomains, domainEntry{ + Domain: "~.", + Routing: true, + }) + } + call = conn.Call(context.Background(), "SetLinkDomains", ifIndex, linkDomains) + if call.Err != nil { + return fmt.Errorf("set link domains: %w", call.Err) + } + + // set the link as the default DNS route for full tunnel + if fullTunnel { + call = conn.Call(context.Background(), "SetLinkDefaultRoute", ifIndex, true) + if call.Err != nil { + return fmt.Errorf("set link default route: %w", call.Err) + } + } + + return nil +} + +// revertDnsSystemd reverts DNS configuration via systemd-resolved DBus. +func revertDnsSystemd(ifIndex int) error { + conn, err := newConn() + if err != nil { + return fmt.Errorf("dbus connect: %w", err) + } + defer conn.Close() + + // revert default route + call := conn.Call(context.Background(), "SetLinkDefaultRoute", ifIndex, false) + if call.Err != nil { + return fmt.Errorf("revert link default route: %w", call.Err) + } + + // revert all settings for the link + call = conn.Call(context.Background(), "RevertLink", ifIndex) + if call.Err != nil { + return fmt.Errorf("revert link: %w", call.Err) + } + + return nil +} + +// setDnsFile is the fallback: overwrites /etc/resolv.conf and locks if fullTunnel. +func setDnsFile(dns []netip.Addr, searchDomains []string, fullTunnel bool) error { + if err := backupResolvConf(); err != nil { + return err + } + + // Write new conf + f, err := os.Create(resolvConfPath) + if err != nil { + return err + } + defer f.Close() + + for _, d := range dns { + fmt.Fprintf(f, "nameserver %s\n", d.String()) + } + if len(searchDomains) > 0 { + fmt.Fprintf(f, "search %s\n", strings.Join(searchDomains, " ")) + } + + // attempt lock if full tunnel + if fullTunnel { + if err := lockResolvConf(true); err != nil { + } + } + + return nil +} + +// revertDnsFile is the fallback: restores backup and unlocks. +func revertDnsFile() error { + lockResolvConf(false) + + if _, err := os.Stat(resolvConfBak); os.IsNotExist(err) { + return nil + } + + src, err := os.ReadFile(resolvConfBak) + if err != nil { + return err + } + if err := os.WriteFile(resolvConfPath, src, 0644); err != nil { + return err + } + os.Remove(resolvConfBak) + return nil +} + +// backupResolvConf backs up resolv.conf if not already done. +func backupResolvConf() error { + if _, err := os.Stat(resolvConfBak); err == nil { + return nil + } + src, err := os.ReadFile(resolvConfPath) + if err != nil { + return err + } + return os.WriteFile(resolvConfBak, src, 0644) +} + +// lockResolvConf locks/unlocks with chattr (immutable). +func lockResolvConf(lock bool) error { + arg := "-i" + if lock { + arg = "+i" + } + // use filepath.Abs to handle symlinks properly + absPath, err := filepath.Abs(resolvConfPath) + if err != nil { + return err + } + cmd := exec.Command("chattr", arg, absPath) + if err := cmd.Run(); err != nil { + return fmt.Errorf("chattr %s: %w", arg, err) + } + return nil +} diff --git a/tunnel/tools/libwg-go/vpn/dns/dns_windows.go b/tunnel/tools/libwg-go/vpn/dns/dns_windows.go new file mode 100644 index 0000000..33defd4 --- /dev/null +++ b/tunnel/tools/libwg-go/vpn/dns/dns_windows.go @@ -0,0 +1,79 @@ +//go:build windows + +package dns + +import ( + "fmt" + "net/netip" + "os/exec" + "strings" + + "github.com/amnezia-vpn/amneziawg-go/device" + "golang.org/x/net/nettest" + "golang.org/x/sys/windows" + "golang.zx2c4.com/wireguard/windows/tunnel/winipcfg" +) + +func SetDNS(luid winipcfg.LUID, dns []netip.Addr, searchDomains []string, fullTunnel bool, logger *device.Logger) error { + if fullTunnel { + // se global search domains + if len(searchDomains) > 0 { + pscmd := "Set-DnsClientGlobalSetting -SuffixSearchList @('" + strings.Join(searchDomains, "','") + "')" + cmd := exec.Command("powershell", "-Command", pscmd) + if err := cmd.Run(); err != nil { + logger.Errorf("set global search: %v", err) + } + } + } + + // set DNS on interface + v4dns, v6dns := []netip.Addr{}, []netip.Addr{} + for _, d := range dns { + if d.Is4() { + v4dns = append(v4dns, d) + } else if d.Is6() && nettest.SupportsIPv6() { + v6dns = append(v6dns, d) + } + } + + // v4 + if len(v4dns) > 0 || len(searchDomains) > 0 { + err := luid.SetDNS(windows.AF_INET, v4dns, searchDomains) + if err != nil { + return fmt.Errorf("set v4 dns: %w", err) + } + } + + // v6 + if len(v6dns) > 0 || len(searchDomains) > 0 { + err := luid.SetDNS(windows.AF_INET6, v6dns, searchDomains) + if err != nil { + return fmt.Errorf("set v6 dns: %w", err) + } + } + + return nil +} + +func RevertDNS(luid winipcfg.LUID, fullTunnel bool, originalSearchDomains []string, logger *device.Logger) error { + if fullTunnel && originalSearchDomains != nil { + // restore original global search + pscmd := "Set-DnsClientGlobalSetting -SuffixSearchList @('" + strings.Join(originalSearchDomains, "','") + "')" + cmd := exec.Command("powershell", "-Command", pscmd) + if err := cmd.Run(); err != nil { + logger.Errorf("restore global search: %v", err) + } + originalSearchDomains = nil + } else if fullTunnel { + // clear if no original + pscmd := "Set-DnsClientGlobalSetting -SuffixSearchList @()" + cmd := exec.Command("powershell", "-Command", pscmd) + cmd.Run() + } + + // clear DNS interface + luid.FlushDNS(windows.AF_INET) + luid.FlushDNS(windows.AF_INET6) + + return nil +} diff --git a/tunnel/tools/libwg-go/vpn/firewall/firewall.go b/tunnel/tools/libwg-go/vpn/firewall/firewall.go new file mode 100644 index 0000000..1eb969d --- /dev/null +++ b/tunnel/tools/libwg-go/vpn/firewall/firewall.go @@ -0,0 +1,21 @@ +package firewall + +import "net/netip" + +// Firewall is responsible for managing the system's firewall rules, especially the kill switch. It operates independently of the router. +type Firewall interface { + + // Enable activates the kill switch, blocking all outbound traffic except + // explicitly allowed bypasses. + Enable() error + + // IsEnabled reports whether the kill switch is currently active. + IsEnabled() bool + + // Disable deactivates the kill switch and cleans up all rules. + Disable() error + + // AllowLocalNetworks adds bypass rules for the specified local network prefixes. Requires kill switch enabled and + // operates independently of tunnel/router bypasses. + AllowLocalNetworks([]netip.Prefix) error +} diff --git a/tunnel/tools/libwg-go/vpn/firewall/mark/mark.go b/tunnel/tools/libwg-go/vpn/firewall/mark/mark.go new file mode 100644 index 0000000..c6eb995 --- /dev/null +++ b/tunnel/tools/libwg-go/vpn/firewall/mark/mark.go @@ -0,0 +1,16 @@ +package mark + +const ( + // LinuxFwmarkMaskNum Used to isolate bits 16-23 which are the safe range for custom marks + LinuxFwmarkMaskNum = 0xff0000 + // Our mark num + LinuxBypassMarkNum = 0x100000 + // LinuxBootstrapMarkNum is specifically for the DNS Resolver + LinuxBootstrapMarkNum = 0x200000 +) + +var ( + // LinuxBootstrapMarkBytes is the Little Endian representation for nftables + // 0x200000 -> [00, 00, 20, 00] + LinuxBootstrapMarkBytes = []byte{0x00, 0x00, 0x20, 0x00} +) diff --git a/tunnel/tools/libwg-go/vpn/firewall/osfirewall/firewall_linux.go b/tunnel/tools/libwg-go/vpn/firewall/osfirewall/firewall_linux.go new file mode 100644 index 0000000..dec66ff --- /dev/null +++ b/tunnel/tools/libwg-go/vpn/firewall/osfirewall/firewall_linux.go @@ -0,0 +1,1085 @@ +// SPDX-License-Identifier: Apache-2.0 +// Copyright © 2026 WG Tunnel. +// Adapted from Tailscale + +//go:build linux && !android + +package osfirewall + +import ( + "encoding/binary" + "errors" + "fmt" + "net/netip" + "reflect" + "sync/atomic" + + "github.com/amnezia-vpn/amneziawg-go/device" + "github.com/google/nftables" + "github.com/google/nftables/expr" + "github.com/wgtunnel/desktop/tunnel/vpn/firewall" + "github.com/wgtunnel/desktop/tunnel/vpn/firewall/mark" + "golang.org/x/net/nettest" + "golang.org/x/sys/unix" +) + +const ( + baseChainInput = "INPUT" + baseChainOutput = "OUTPUT" + baseChainPostrouting = "POSTROUTING" + baseChainForward = "FORWARD" + chainNameForward = "wgtunnel-forward" + chainNameInput = "wgtunnel-input" + chainNamePostrouting = "wgtunnel-postrouting" + chainNameOutput = "wgtunnel-output" + chainTypeRegular = "" +) + +type LinuxFirewall struct { + conn *nftables.Conn + nft4 *nftable // IPv4 tables, never nil + nft6 *nftable // IPv6 tables or nil if no IPv6 support + + v6Available bool + + tunnelPort uint16 + + killSwitchEnabled atomic.Bool + logger *device.Logger + + localAddrRules []*nftables.Rule // For tracking AllowedLocalNetworks rules + tunnelRules map[string][]*nftables.Rule // For tracking iface tunnel bypass rules +} + +func New(logger *device.Logger) (firewall.Firewall, error) { + conn, err := nftables.New() + if err != nil { + return nil, fmt.Errorf("nftables connection: %w", err) + } + + nft4 := &nftable{Proto: nftables.TableFamilyIPv4} + + supportsV6 := nettest.SupportsIPv6() + + var nft6 *nftable + if supportsV6 { + nft6 = &nftable{Proto: nftables.TableFamilyIPv6} + } + logger.Verbosef("nftables mode, v6 support: %v", supportsV6) + + f := &LinuxFirewall{ + conn: conn, + nft4: nft4, + nft6: nft6, + v6Available: supportsV6, + logger: logger, + tunnelRules: make(map[string][]*nftables.Rule), + } + return f, nil +} + +func (f *LinuxFirewall) AddTunnelBypasses(iface string) error { + if !f.IsEnabled() { + return errors.New("kill switch must be enabled to add tunnel bypasses") + } + + // remove old rules + _ = f.RemoveTunnelBypasses(iface) + + var newRules []*nftables.Rule + + for _, table := range f.getTables() { + outputChain, err := getChainFromTable(f.conn, table.Filter, chainNameOutput) + inputChain, _ := getChainFromTable(f.conn, table.Filter, chainNameInput) + if err != nil { + return fmt.Errorf("get output chain: %w", err) + } + + // apply tunnel mark + bootstrapRule := createFwmarkRule(table.Filter, outputChain, mark.LinuxBootstrapMarkNum) + f.conn.InsertRule(bootstrapRule) + newRules = append(newRules, bootstrapRule) + + // allow input for DNS boostrap + stateRule := &nftables.Rule{ + Table: table.Filter, + Chain: inputChain, + Exprs: []expr.Any{ + &expr.Ct{Key: expr.CtKeySTATE, Register: 1}, + &expr.Bitwise{ + SourceRegister: 1, + DestRegister: 1, + Len: 4, + Mask: []byte{0x06, 0x00, 0x00, 0x00}, // ESTABLISHED (2) | RELATED (4) + Xor: []byte{0x00, 0x00, 0x00, 0x00}, + }, + &expr.Cmp{ + Op: expr.CmpOpNeq, + Register: 1, + Data: []byte{0x00, 0x00, 0x00, 0x00}, + }, + &expr.Counter{}, + &expr.Verdict{Kind: expr.VerdictAccept}, + }, + } + f.conn.InsertRule(stateRule) + newRules = append(newRules, stateRule) + + // add tunnel interface bypass rule + tunnelBypassRule := &nftables.Rule{ + Table: table.Filter, + Chain: outputChain, + Exprs: []expr.Any{ + &expr.Meta{ + Key: expr.MetaKeyOIFNAME, + Register: 1, + }, + &expr.Cmp{ + Op: expr.CmpOpEq, + Register: 1, + Data: []byte(iface + "\x00"), + }, + &expr.Counter{}, + &expr.Verdict{Kind: expr.VerdictAccept}, + }, + } + existing, _ := findRule(f.conn, tunnelBypassRule) + if existing == nil { + f.conn.InsertRule(tunnelBypassRule) + newRules = append(newRules, tunnelBypassRule) + } + + // Add prefix bypass rules + //for _, prefix := range prefixes { + // if prefix.Addr().Is6() && !f.v6Available { + // continue + // } + // rule, err := createRangeRule(table.Filter, outputChain, prefix, expr.VerdictAccept) + // if err != nil { + // return fmt.Errorf("create bypass rule for %v: %w", prefix, err) + // } + // existing, _ = findRule(f.conn, rule) + // if existing == nil { + // f.conn.InsertRule(rule) + // newRules = append(newRules, rule) + // } + //} + } + + if err := f.conn.Flush(); err != nil { + return fmt.Errorf("flush after adding tunnel bypasses: %w", err) + } + + if f.tunnelRules == nil { + f.tunnelRules = make(map[string][]*nftables.Rule) + } + f.tunnelRules[iface] = newRules + + f.logger.Verbosef("Added/Updated tunnel bypasses for iface %s", iface) + return nil +} + +func (f *LinuxFirewall) RemoveTunnelBypasses(iface string) error { + if !f.IsEnabled() { + f.logger.Verbosef("Firewall is not enabled, skipping") + return nil + } + + rules, ok := f.tunnelRules[iface] + if !ok { + return nil + } + + for _, rule := range rules { + f.conn.DelRule(rule) + } + + if err := f.conn.Flush(); err != nil { + return fmt.Errorf("flush after removing tunnel bypasses: %w", err) + } + + delete(f.tunnelRules, iface) + + f.logger.Verbosef("Removed tunnel bypasses for iface %s", iface) + return nil +} + +func (f *LinuxFirewall) Disable() error { + if !f.IsEnabled() { + f.logger.Verbosef("Firewall is not enabled, skipping") + return nil + } + + // remove hooks + if err := f.deleteCustomHooks(); err != nil { + f.logger.Errorf("del hooks: %v", err) + } + + // flush base rules + if err := f.flushCustomChains(); err != nil { + f.logger.Errorf("del base: %v", err) + } + + // delete chains + if err := f.deleteCustomChains(); err != nil { + f.logger.Errorf("del chains: %v", err) + } + + // delete tables + for _, family := range []nftables.TableFamily{nftables.TableFamilyIPv4, nftables.TableFamilyIPv6} { + if err := deleteTableIfExists(f.conn, family, "filter"); err != nil { + f.logger.Errorf("delete filter table (%v): %v", family, err) + } + if err := deleteTableIfExists(f.conn, family, "nat"); err != nil { + f.logger.Errorf("delete nat table (%v): %v", family, err) + } + } + + if err := f.conn.Flush(); err != nil { + return fmt.Errorf("final flush: %w", err) + } + + f.localAddrRules = nil + f.tunnelRules = make(map[string][]*nftables.Rule) + + f.killSwitchEnabled.Store(false) + + f.logger.Verbosef("Firewall cleaned up and kill switch disabled") + return nil +} + +func (f *LinuxFirewall) AllowLocalNetworks(prefixes []netip.Prefix) error { + if !f.IsEnabled() { + return errors.New("kill switch must be enabled to allow local networks") + } + + // remove any old rules + for _, rule := range f.localAddrRules { + f.conn.DelRule(rule) + } + f.localAddrRules = nil + + // add bypass rules for each prefix + for _, table := range f.getTables() { + outputChain, err := getChainFromTable(f.conn, table.Filter, chainNameOutput) + if err != nil { + return fmt.Errorf("get output chain: %w", err) + } + for _, prefix := range prefixes { + if prefix.Addr().Is6() && !f.v6Available { + continue + } + rule, err := createRangeRule(table.Filter, outputChain, prefix, expr.VerdictAccept) + if err != nil { + return fmt.Errorf("create bypass rule for %v: %w", prefix, err) + } + existing, _ := findRule(f.conn, rule) + if existing == nil { + f.conn.AddRule(rule) + f.localAddrRules = append(f.localAddrRules, rule) + } + } + } + if err := f.conn.Flush(); err != nil { + return fmt.Errorf("flush after bypassing local addrs: %w", err) + } + f.logger.Verbosef("Bypassed local addrs: %v", prefixes) + return nil +} + +func (f *LinuxFirewall) IsEnabled() bool { + return f.killSwitchEnabled.Load() +} + +type nftable struct { + Proto nftables.TableFamily + Filter *nftables.Table + Nat *nftables.Table +} + +type chainInfo struct { + table *nftables.Table + name string + chainType nftables.ChainType + chainHook *nftables.ChainHook + chainPriority *nftables.ChainPriority + chainPolicy *nftables.ChainPolicy +} + +var ErrChainNotFound = errors.New("chain not found") + +type errorChainNotFound struct { + chainName string + tableName string +} + +func (e errorChainNotFound) Error() string { + return fmt.Sprintf("chain %s not found in table %s", e.chainName, e.tableName) +} + +func (e errorChainNotFound) Is(target error) bool { + return target == ErrChainNotFound +} + +// SetTunnelPort adds punch rules for inbound UDP on the port. +func (f *LinuxFirewall) SetTunnelPort(port uint16) error { + for _, table := range f.getTables() { + inputChain, err := getChainFromTable(f.conn, table.Filter, chainNameInput) + if err != nil { + return fmt.Errorf("get input chain: %w", err) + } + if err := addAcceptOnPortRule(f.conn, table.Filter, inputChain, port); err != nil { + return fmt.Errorf("add accept on port rule: %w", err) + } + } + if err := f.conn.Flush(); err != nil { + return fmt.Errorf("flush after adding port punch: %w", err) + } + f.tunnelPort = port + f.logger.Verbosef("Added tunnel port punch for UDP port %d", port) + return nil +} + +// addAcceptOnPortRule adds the rule if not exist +func addAcceptOnPortRule(conn *nftables.Conn, table *nftables.Table, chain *nftables.Chain, port uint16) error { + rule := createAcceptOnPortRule(table, chain, port) + existing, err := findRule(conn, rule) + if err != nil { + return fmt.Errorf("find rule: %w", err) + } + if existing != nil { + return nil // Already exists + } + conn.InsertRule(rule) + return nil // Flush called outside +} + +// createAcceptOnPortRule creates ACCEPT rule for UDP dport. +func createAcceptOnPortRule(table *nftables.Table, chain *nftables.Chain, port uint16) *nftables.Rule { + portBytes := make([]byte, 2) + // for network byte order + binary.BigEndian.PutUint16(portBytes, port) + return &nftables.Rule{ + Table: table, + Chain: chain, + Exprs: []expr.Any{ + // load layer 4 protocol (for UDP/TCP) in register 1 temp storage + &expr.Meta{ + Key: expr.MetaKeyL4PROTO, + Register: 1, + }, + // check if loaded register 1 storage is UDP + &expr.Cmp{ + Op: expr.CmpOpEq, + Register: 1, + Data: []byte{unix.IPPROTO_UDP}, + }, + // load the destination port from register 1 + newLoadDportExpr(1), + // check if the port matches our wg listener port + &expr.Cmp{ + Op: expr.CmpOpEq, + Register: 1, + Data: portBytes, + }, + &expr.Counter{}, + // allow it on the firewall + &expr.Verdict{ + Kind: expr.VerdictAccept, + }, + }, + } +} + +// newLoadDportExpr loads dport to register +func newLoadDportExpr(destReg uint32) expr.Any { + return &expr.Payload{ + DestRegister: destReg, + Base: expr.PayloadBaseTransportHeader, + Offset: 2, + Len: 2, + } +} + +// deleteTableIfExists deletes a nftables table if it exists. +func deleteTableIfExists(c *nftables.Conn, family nftables.TableFamily, name string) error { + t, err := getTableIfExists(c, family, name) + if err != nil { + return fmt.Errorf("get table: %w", err) + } + if t == nil { + return nil // Not exist + } + c.DelTable(t) + if err := c.Flush(); err != nil { + return fmt.Errorf("del table: %w", err) + } + return nil +} + +// getTableIfExists returns the table if it exists. +func getTableIfExists(c *nftables.Conn, family nftables.TableFamily, name string) (*nftables.Table, error) { + tables, err := c.ListTables() + if err != nil { + return nil, fmt.Errorf("get tables: %w", err) + } + for _, table := range tables { + if table.Name == name && table.Family == family { + return table, nil + } + } + return nil, nil +} + +// createTableIfNotExist creates a nftables table if not exist. +func createTableIfNotExist(c *nftables.Conn, family nftables.TableFamily, name string) (*nftables.Table, error) { + if t, err := getTableIfExists(c, family, name); err != nil { + return nil, fmt.Errorf("get table: %w", err) + } else if t != nil { + return t, nil + } + t := c.AddTable(&nftables.Table{ + Family: family, + Name: name, + }) + if err := c.Flush(); err != nil { + return nil, fmt.Errorf("add table: %w", err) + } + return t, nil +} + +// getChainFromTable returns the chain if it exists. +func getChainFromTable(c *nftables.Conn, table *nftables.Table, name string) (*nftables.Chain, error) { + chains, err := c.ListChainsOfTableFamily(table.Family) + if err != nil { + return nil, fmt.Errorf("list chains: %w", err) + } + for _, chain := range chains { + if chain.Table.Name == table.Name && chain.Name == name { + return chain, nil + } + } + return nil, errorChainNotFound{chainName: name, tableName: table.Name} +} + +// createChainIfNotExist creates a chain if not exist. +func createChainIfNotExist(c *nftables.Conn, cinfo chainInfo) error { + _, err := getOrCreateChain(c, cinfo) + return err +} + +func getOrCreateChain(c *nftables.Conn, cinfo chainInfo) (*nftables.Chain, error) { + chain, err := getChainFromTable(c, cinfo.table, cinfo.name) + if err != nil && !errors.Is(err, ErrChainNotFound) { + return nil, fmt.Errorf("get chain: %w", err) + } else if err == nil { + // Existing chain; check compatibility if needed + return chain, nil + } + + chain = c.AddChain(&nftables.Chain{ + Name: cinfo.name, + Table: cinfo.table, + Type: cinfo.chainType, + Hooknum: cinfo.chainHook, + Priority: cinfo.chainPriority, + Policy: cinfo.chainPolicy, + }) + + if err := c.Flush(); err != nil { + return nil, fmt.Errorf("add chain: %w", err) + } + + return chain, nil +} + +// deleteChainIfExists deletes a chain if it exists. +func deleteChainIfExists(c *nftables.Conn, table *nftables.Table, name string) error { + chain, err := getChainFromTable(c, table, name) + if err != nil && !errors.Is(err, errorChainNotFound{table.Name, name}) { + return fmt.Errorf("get chain: %w", err) + } else if err != nil { + return nil // Not exist + } + + c.FlushChain(chain) + c.DelChain(chain) + + if err := c.Flush(); err != nil { + return fmt.Errorf("flush and delete chain: %w", err) + } + + return nil +} + +// getTables returns v4/v6 tables based on system support. +func (f *LinuxFirewall) getTables() []*nftable { + if f.v6Available { + return []*nftable{f.nft4, f.nft6} + } + return []*nftable{f.nft4} +} + +// getNFTByAddr selects v4/v6 table by addr family. +func (f *LinuxFirewall) getNFTByAddr(addr netip.Addr) (*nftable, error) { + if addr.Is6() && !f.v6Available { + return nil, fmt.Errorf("nftables for IPv6 not available") + } + if addr.Is6() { + return f.nft6, nil + } + return f.nft4, nil +} + +// findRule finds a rule by matching expressions. +func findRule(conn *nftables.Conn, rule *nftables.Rule) (*nftables.Rule, error) { + rules, err := conn.GetRules(rule.Table, rule.Chain) + if err != nil { + return nil, fmt.Errorf("get rules: %w", err) + } + for _, r := range rules { + if len(r.Exprs) != len(rule.Exprs) { + continue + } + match := true + for i, e := range r.Exprs { + if _, ok := e.(*expr.Counter); ok { + continue // Skip counters + } + if !reflect.DeepEqual(e, rule.Exprs[i]) { + match = false + break + } + } + if match { + return r, nil + } + } + return nil, nil +} + +func (f *LinuxFirewall) Enable() error { + if f.IsEnabled() { + f.logger.Verbosef("Kill switch already active, skipping activation") + return nil + } + + polAccept := nftables.ChainPolicyAccept + for _, table := range f.getTables() { + // Create filter table + filter, err := createTableIfNotExist(f.conn, table.Proto, "filter") + if err != nil { + return fmt.Errorf("create filter table: %w", err) + } + table.Filter = filter + + _, err = getOrCreateChain(f.conn, chainInfo{filter, baseChainForward, nftables.ChainTypeFilter, nftables.ChainHookForward, nftables.ChainPriorityFilter, &polAccept}) + if err != nil { + return fmt.Errorf("create FORWARD chain: %w", err) + } + _, err = getOrCreateChain(f.conn, chainInfo{filter, baseChainInput, nftables.ChainTypeFilter, nftables.ChainHookInput, nftables.ChainPriorityFilter, &polAccept}) + if err != nil { + return fmt.Errorf("create INPUT chain: %w", err) + } + _, err = getOrCreateChain(f.conn, chainInfo{filter, baseChainOutput, nftables.ChainTypeFilter, nftables.ChainHookOutput, nftables.ChainPriorityFilter, &polAccept}) + if err != nil { + return fmt.Errorf("create OUTPUT chain: %w", err) + } + + // Custom chains (regular, jumped to from conventional) + if err = createChainIfNotExist(f.conn, chainInfo{filter, chainNameForward, chainTypeRegular, nil, nil, nil}); err != nil { + return fmt.Errorf("create wgtunnel-forward chain: %w", err) + } + if err = createChainIfNotExist(f.conn, chainInfo{filter, chainNameInput, chainTypeRegular, nil, nil, nil}); err != nil { + return fmt.Errorf("create wgtunnel-input chain: %w", err) + } + if err = createChainIfNotExist(f.conn, chainInfo{filter, chainNameOutput, chainTypeRegular, nil, nil, nil}); err != nil { + return fmt.Errorf("create wgtunnel-output chain: %w", err) + } + + nat, err := createTableIfNotExist(f.conn, table.Proto, "nat") + if err != nil { + return fmt.Errorf("create nat table: %w", err) + } + table.Nat = nat + + _, err = getOrCreateChain(f.conn, chainInfo{nat, baseChainPostrouting, nftables.ChainTypeNAT, nftables.ChainHookPostrouting, nftables.ChainPriorityNATSource, &polAccept}) + if err != nil { + return fmt.Errorf("create POSTROUTING chain: %w", err) + } + if err = createChainIfNotExist(f.conn, chainInfo{nat, chainNamePostrouting, chainTypeRegular, nil, nil, nil}); err != nil { + return fmt.Errorf("create wgtunnel-postrouting chain: %w", err) + } + } + + if err := f.conn.Flush(); err != nil { + return fmt.Errorf("flush after chain creation: %w", err) + } + + if err := f.addHooks(); err != nil { + return fmt.Errorf("add hooks: %w", err) + } + + if err := f.addKillSwitchRules(); err != nil { + return fmt.Errorf("add kill switch rules: %w", err) + } + + f.killSwitchEnabled.Store(true) + return nil +} + +// addHooks adds jump rules from conventional chains to custom ones. +func (f *LinuxFirewall) addHooks() error { + conn := f.conn + + for _, table := range f.getTables() { + inputChain, err := getChainFromTable(conn, table.Filter, baseChainInput) + if err != nil { + return fmt.Errorf("get INPUT chain: %w", err) + } + if err = addHookRule(conn, table.Filter, inputChain, chainNameInput); err != nil { + return fmt.Errorf("add INPUT hook: %w", err) + } + + forwardChain, err := getChainFromTable(conn, table.Filter, baseChainForward) + if err != nil { + return fmt.Errorf("get FORWARD chain: %w", err) + } + if err = addHookRule(conn, table.Filter, forwardChain, chainNameForward); err != nil { + return fmt.Errorf("add FORWARD hook: %w", err) + } + + outputChain, err := getChainFromTable(conn, table.Filter, baseChainOutput) + if err != nil { + return fmt.Errorf("get OUTPUT chain: %w", err) + } + if err = addHookRule(conn, table.Filter, outputChain, chainNameOutput); err != nil { + return fmt.Errorf("add OUTPUT hook: %w", err) + } + + postroutingChain, err := getChainFromTable(conn, table.Nat, baseChainPostrouting) + if err != nil { + return fmt.Errorf("get POSTROUTING chain: %w", err) + } + if err = addHookRule(conn, table.Nat, postroutingChain, chainNamePostrouting); err != nil { + return fmt.Errorf("add POSTROUTING hook: %w", err) + } + } + return nil +} + +// createHookRule creates a jump rule. +func createHookRule(table *nftables.Table, fromChain *nftables.Chain, toChainName string) *nftables.Rule { + return &nftables.Rule{ + Table: table, + Chain: fromChain, + Exprs: []expr.Any{ + &expr.Counter{}, + &expr.Verdict{ + Kind: expr.VerdictJump, + Chain: toChainName, + }, + }, + } +} + +// addHookRule inserts a jump rule at the top. +func addHookRule(conn *nftables.Conn, table *nftables.Table, fromChain *nftables.Chain, toChainName string) error { + rule := createHookRule(table, fromChain, toChainName) + conn.InsertRule(rule) + return conn.Flush() +} + +// addKillSwitchRules adds bypass for fwmark and DROP at end (private helper). +func (f *LinuxFirewall) addKillSwitchRules() error { + f.logger.Verbosef("Adding kill switch rules...") + + for _, table := range f.getTables() { + + inputChain, err := getChainFromTable(f.conn, table.Filter, chainNameInput) + if err != nil { + return fmt.Errorf("get input chain: %w", err) + } + + // allow loopback + if err := f.addLoopbackRule(table.Filter, inputChain); err != nil { + return err + } + + // allow Established/Related traffic for reply + if err := f.addEstablishedRule(table.Filter, inputChain); err != nil { + return err + } + + // drop everything else + dropRule := createDropRule(table.Filter, inputChain) + f.conn.AddRule(dropRule) + + outputChain, err := getChainFromTable(f.conn, table.Filter, chainNameOutput) + if err != nil { + return fmt.Errorf("get output chain: %w", err) + } + + // allow loopback on output + if err := f.addLoopbackRule(table.Filter, outputChain); err != nil { + return err + } + + // allow the marked tunnel traffic + bypassRule := createFwmarkRule(table.Filter, outputChain, mark.LinuxBypassMarkNum) + f.conn.InsertRule(bypassRule) + + // drop everything else + dropRule = createDropRule(table.Filter, outputChain) + f.conn.AddRule(dropRule) + + forwardChain, err := getChainFromTable(f.conn, table.Filter, chainNameForward) + if err != nil { + return fmt.Errorf("get forward chain: %w", err) + } + + // drop all forwarded traffic + dropRule = createDropRule(table.Filter, forwardChain) + f.conn.AddRule(dropRule) + } + + if err := f.conn.Flush(); err != nil { + return fmt.Errorf("flush after adding kill switch: %w", err) + } + f.logger.Verbosef("Kill switch rules added.") + return nil +} + +// addTunnelInterfaceRule adds a rule to let our tun interface escape firewall +func (f *LinuxFirewall) addTunnelInterfaceRule(iface string, table *nftables.Table, chain *nftables.Chain) error { + tunnelBypassRule := &nftables.Rule{ + Table: table, + Chain: chain, + Exprs: []expr.Any{ + &expr.Meta{ + Key: expr.MetaKeyOIFNAME, + Register: 1, + }, + &expr.Cmp{ + Op: expr.CmpOpEq, + Register: 1, + Data: []byte(iface + "\x00"), + }, + &expr.Counter{}, + &expr.Verdict{Kind: expr.VerdictAccept}, + }, + } + existing, _ := findRule(f.conn, tunnelBypassRule) + if existing == nil { + f.conn.InsertRule(tunnelBypassRule) + } + return nil +} + +func (f *LinuxFirewall) addLoopbackRule(table *nftables.Table, chain *nftables.Chain) error { + loRule := &nftables.Rule{ + Table: table, + Chain: chain, + Exprs: []expr.Any{ + &expr.Meta{ + Key: getIfKeyForChain(chain), + Register: 1, + }, + // Compare Register 1 to "lo" + &expr.Cmp{ + Op: expr.CmpOpEq, + Register: 1, + Data: []byte("lo\x00"), // Null-terminated string + }, + &expr.Counter{}, + &expr.Verdict{Kind: expr.VerdictAccept}, + }, + } + f.conn.InsertRule(loRule) + return nil +} + +// Helper to determine if we should look at Input or Output interface +func getIfKeyForChain(chain *nftables.Chain) expr.MetaKey { + if chain.Name == chainNameInput { + return expr.MetaKeyIIFNAME + } + return expr.MetaKeyOIFNAME +} + +func (f *LinuxFirewall) addEstablishedRule(table *nftables.Table, chain *nftables.Chain) error { + rule := &nftables.Rule{ + Table: table, + Chain: chain, + Exprs: []expr.Any{ + // Load Connection Tracking State + &expr.Ct{ + Key: expr.CtKeySTATE, + Register: 1, + }, + // Bitwise check for Established (0x02) | Related (0x04) = 0x06 + &expr.Bitwise{ + SourceRegister: 1, + DestRegister: 1, + Len: 4, + Mask: []byte{0x06, 0x00, 0x00, 0x00}, // Bits 1 and 2 (Est/Rel) + Xor: []byte{0x00, 0x00, 0x00, 0x00}, + }, + &expr.Cmp{ + Op: expr.CmpOpNeq, // If result is NOT 0, it matched one of the bits + Register: 1, + Data: []byte{0x00, 0x00, 0x00, 0x00}, + }, + &expr.Counter{}, + &expr.Verdict{Kind: expr.VerdictAccept}, + }, + } + + f.conn.InsertRule(rule) + return nil +} + +// delKillSwitchRules removes kill switch by flushing chains +func (f *LinuxFirewall) delKillSwitchRules() error { + f.logger.Verbosef("Removing kill switch rules...") + + for _, table := range f.getTables() { + if outputChain, err := getChainFromTable(f.conn, table.Filter, chainNameOutput); err == nil { + f.conn.FlushChain(outputChain) + } + + if inputChain, err := getChainFromTable(f.conn, table.Filter, chainNameInput); err == nil { + f.conn.FlushChain(inputChain) + } + + if forwardChain, err := getChainFromTable(f.conn, table.Filter, chainNameForward); err == nil { + f.conn.FlushChain(forwardChain) + } + } + + if err := f.conn.Flush(); err != nil { + return fmt.Errorf("flush after deleting kill switch: %w", err) + } + + f.logger.Verbosef("Kill switch rules removed.") + + return nil +} + +// createDropRule creates a simple DROP rule with counter +func createDropRule(table *nftables.Table, chain *nftables.Chain) *nftables.Rule { + return &nftables.Rule{ + Table: table, + Chain: chain, + Exprs: []expr.Any{ + &expr.Counter{}, + &expr.Verdict{Kind: expr.VerdictDrop}, + }, + } +} + +// createRangeRule creates ACCEPT for dst IP in prefix/range (adapted for daddr). +func createRangeRule( + table *nftables.Table, + chain *nftables.Chain, + rng netip.Prefix, + decision expr.VerdictKind, +) (*nftables.Rule, error) { + var loadExpr expr.Any + var maskLen uint32 + var mask []byte + var xor []byte + var err error + + if rng.Addr().Is4() { + loadExpr, err = newLoadDaddrExpr(nftables.TableFamilyIPv4, 1) + if err != nil { + return nil, fmt.Errorf("newLoadDaddrExpr: %w", err) + } + maskLen = 4 + mask = maskOf(rng) + xor = []byte{0x00, 0x00, 0x00, 0x00} + } else { + loadExpr, err = newLoadDaddrExpr(nftables.TableFamilyIPv6, 1) + if err != nil { + return nil, fmt.Errorf("newLoadDaddrExpr: %w", err) + } + maskLen = 16 + bits := rng.Bits() + mask = make([]byte, 16) + for i := 0; i < bits/8; i++ { + mask[i] = 0xff + } + if bits%8 != 0 { + mask[bits/8] = 0xff << (8 - uint(bits%8)) + } + xor = make([]byte, 16) + } + + netip := rng.Addr().AsSlice() + rule := &nftables.Rule{ + Table: table, + Chain: chain, + Exprs: []expr.Any{ + loadExpr, + &expr.Bitwise{ + SourceRegister: 1, + DestRegister: 1, + Len: maskLen, + Mask: mask, + Xor: xor, + }, + &expr.Cmp{ + Op: expr.CmpOpEq, + Register: 1, + Data: netip, + }, + &expr.Counter{}, + &expr.Verdict{ + Kind: decision, + }, + }, + } + return rule, nil +} + +// newLoadDaddrExpr loads destination addr into register. +func newLoadDaddrExpr(proto nftables.TableFamily, destReg uint32) (expr.Any, error) { + switch proto { + case nftables.TableFamilyIPv4: + return &expr.Payload{ + DestRegister: destReg, + Base: expr.PayloadBaseNetworkHeader, + Offset: 16, // IPv4 offset + Len: 4, + }, nil + case nftables.TableFamilyIPv6: + return &expr.Payload{ + DestRegister: destReg, + Base: expr.PayloadBaseNetworkHeader, + Offset: 24, // IPv6 offset + Len: 16, + }, nil + default: + return nil, fmt.Errorf("unsupported family %v", proto) + } +} + +// createFwmarkRule generates a rule for a specific mark within our mask +func createFwmarkRule(table *nftables.Table, chain *nftables.Chain, markVal uint32) *nftables.Rule { + maskBytes := make([]byte, 4) + binary.LittleEndian.PutUint32(maskBytes, mark.LinuxFwmarkMaskNum) + + markBytes := make([]byte, 4) + binary.LittleEndian.PutUint32(markBytes, markVal) + + return &nftables.Rule{ + Table: table, + Chain: chain, + Exprs: []expr.Any{ + &expr.Meta{Key: expr.MetaKeyMARK, Register: 1}, + &expr.Bitwise{ + SourceRegister: 1, + DestRegister: 1, + Len: 4, + Mask: maskBytes, + Xor: []byte{0x00, 0x00, 0x00, 0x00}, + }, + &expr.Cmp{ + Op: expr.CmpOpEq, + Register: 1, + Data: markBytes, + }, + &expr.Counter{}, + &expr.Verdict{Kind: expr.VerdictAccept}, + }, + } +} + +// maskOf returns CIDR mask bytes +func maskOf(pfx netip.Prefix) []byte { + mask := make([]byte, 4) + binary.BigEndian.PutUint32(mask, ^(uint32(0xffffffff) >> pfx.Bits())) + return mask +} + +// deleteCustomHooks removes jump rules from base to custom chains +func (f *LinuxFirewall) deleteCustomHooks() error { + conn := f.conn + for _, table := range f.getTables() { + if table == nil || table.Filter == nil { + continue // skip if table or filter not initialized + } + inputChain, err := getChainFromTable(conn, table.Filter, baseChainInput) + if err == nil && inputChain != nil { + deleteHookRule(conn, table.Filter, inputChain, chainNameInput) + } + + forwardChain, err := getChainFromTable(conn, table.Filter, baseChainForward) + if err == nil && forwardChain != nil { + deleteHookRule(conn, table.Filter, forwardChain, chainNameForward) + } + + outputChain, err := getChainFromTable(conn, table.Filter, baseChainOutput) + if err == nil && outputChain != nil { + deleteHookRule(conn, table.Filter, outputChain, chainNameOutput) + } + + if table.Nat == nil { + continue + } + postroutingChain, err := getChainFromTable(conn, table.Nat, baseChainPostrouting) + if err == nil && postroutingChain != nil { + deleteHookRule(conn, table.Nat, postroutingChain, chainNamePostrouting) + } + } + return conn.Flush() +} + +// deleteHookRule deletes a specific jump rule if it exists +func deleteHookRule(conn *nftables.Conn, table *nftables.Table, fromChain *nftables.Chain, toChainName string) error { + rule := createHookRule(table, fromChain, toChainName) + existing, err := findRule(conn, rule) + if err != nil || existing == nil { + return err // Or nil if not found + } + conn.DelRule(existing) + return nil +} + +// deleteCustomChains deletes custom chains +func (f *LinuxFirewall) deleteCustomChains() error { + for _, table := range f.getTables() { + deleteChainIfExists(f.conn, table.Filter, chainNameForward) + deleteChainIfExists(f.conn, table.Filter, chainNameInput) + deleteChainIfExists(f.conn, table.Filter, chainNameOutput) + deleteChainIfExists(f.conn, table.Nat, chainNamePostrouting) + } + return f.conn.Flush() +} + +// flushCustomChains flushes rules from custom chains +func (f *LinuxFirewall) flushCustomChains() error { + for _, table := range f.getTables() { + inputChain, err := getChainFromTable(f.conn, table.Filter, chainNameInput) + if err == nil { + f.conn.FlushChain(inputChain) + } + + forwardChain, err := getChainFromTable(f.conn, table.Filter, chainNameForward) + if err == nil { + f.conn.FlushChain(forwardChain) + } + + outputChain, err := getChainFromTable(f.conn, table.Filter, chainNameOutput) + if err == nil { + f.conn.FlushChain(outputChain) + } + + postrouteChain, err := getChainFromTable(f.conn, table.Nat, chainNamePostrouting) + if err == nil { + f.conn.FlushChain(postrouteChain) + } + } + return f.conn.Flush() +} diff --git a/tunnel/tools/libwg-go/vpn/firewall/osfirewall/firewall_macos.go b/tunnel/tools/libwg-go/vpn/firewall/osfirewall/firewall_macos.go new file mode 100644 index 0000000..b067bfe --- /dev/null +++ b/tunnel/tools/libwg-go/vpn/firewall/osfirewall/firewall_macos.go @@ -0,0 +1,290 @@ +// SPDX-License-Identifier: Apache-2.0 +// Copyright © 2026 WG Tunnel. +// Adapted from Tailscale + +//go:build darwin + +package osfirewall + +import ( + "errors" + "fmt" + "net/netip" + "os" + "os/exec" + "path/filepath" + "strings" + "unsafe" + + "github.com/amnezia-vpn/amneziawg-go/device" + "github.com/wgtunnel/desktop/tunnel/vpn/firewall" + "golang.org/x/net/bpf" + "golang.org/x/sys/unix" +) + +const ( + anchorName = "wgtunnel" + pfConfPath = "/etc/pf.conf" // System PF config; we'll append our anchor +) + +// macFirewall implements the firewall.Firewall interface for macOS using PF (Packet Filter). +type macFirewall struct { + tunnelPort uint16 // WireGuard listen port for inbound punch + killSwitchEnabled bool // Track if kill switch is active (not atomic, as PF is stateful) + v6Available bool // Whether the host supports IPv6 + logger *device.Logger +} + +func New(logger *device.Logger) (firewall.Firewall, error) { + v6err := CheckIPv6(logger) + supportsV6 := v6err == nil + logger.Verbosef("PF mode, v6 support: %v", supportsV6) + + return &macFirewall{ + v6Available: supportsV6, + logger: logger, + }, nil +} + +func (f *macFirewall) HasV6Available() bool { + return f.v6Available +} + +func (f *macFirewall) Active() bool { + return f.killSwitchEnabled +} + +// Enable initializes the firewall (e.g., ensures PF is enabled and our anchor is referenced). +func (f *macFirewall) Up() error { + // Ensure PF is enabled (macOS default is off; enable if needed) + if err := execSudoCommand("pfctl", "-e"); err != nil && !strings.Contains(err.Error(), "already enabled") { + return fmt.Errorf("enable PF: %w", err) + } + + // Add our anchor to /etc/pf.conf if not present (append if needed) + conf, err := os.ReadFile(pfConfPath) + if err != nil { + return fmt.Errorf("read pf.conf: %w", err) + } + if !strings.Contains(string(conf), fmt.Sprintf(`anchor "%s"`, anchorName)) { + if err := os.AppendFile(pfConfPath, []byte(fmt.Sprintf(`\nanchor "%s"\nload anchor "%s" from "/etc/pf.anchors/%s"\n`, anchorName, anchorName, anchorName)), 0644); err != nil { + return fmt.Errorf("append to pf.conf: %w", err) + } + if err := execSudoCommand("pfctl", "-f", pfConfPath); err != nil { + return fmt.Errorf("reload pf.conf: %w", err) + } + } + + f.logger.Verbosef("PF initialized") + return nil +} + +// SetTunnelPort sets the UDP port for the WireGuard tunnel and adds punch rules. +func (f *macFirewall) SetTunnelPort(port uint16) error { + rule := fmt.Sprintf("pass in quick proto udp to any port %d keep state", port) + if err := f.addRuleToAnchor(rule); err != nil { + return fmt.Errorf("add port punch rule: %w", err) + } + f.tunnelPort = port + f.logger.Verbosef("Added tunnel port punch for UDP port %d", port) + return nil +} + +// ToggleKillSwitch enables/disables the kill switch. +func (f *macFirewall) ToggleKillSwitch(enable bool) error { + if enable == f.killSwitchEnabled { + return nil + } + + if enable { + if err := f.addKillSwitchRules(); err != nil { + return fmt.Errorf("add kill switch rules: %w", err) + } + } else { + if err := f.delKillSwitchRules(); err != nil { + return fmt.Errorf("del kill switch rules: %w", err) + } + } + + f.killSwitchEnabled = enable + f.logger.Verbosef("Kill switch toggled: %v", enable) + return nil +} + +// addKillSwitchRules adds PF rules for kill switch (block non-tunnel outbound, with exemptions). +func (f *macFirewall) addKillSwitchRules() error { + rules := []string{ + "block out all", // Default block outbound + "pass out quick on utun0 all", // Allow on tunnel (adjust 'utun0' dynamically if needed) + "pass out quick to all", // Placeholder for SetBypassRoutes + // Add loopback allowance + "pass out quick on lo0 all", + "pass in quick on lo0 all", + // Established/related (PF handles keep state) + "pass out all keep state", + "pass in all keep state", + } + + if err := f.writeRulesToAnchor(rules); err != nil { + return err + } + f.logger.Verbosef("Kill switch rules added") + return nil +} + +// delKillSwitchRules removes kill switch rules by clearing the anchor. +func (f *macFirewall) delKillSwitchRules() error { + if err := f.clearAnchor(); err != nil { + return err + } + f.logger.Verbosef("Kill switch rules removed") + return nil +} + +// SetBypassRoutes adds exemptions for bootstrap routes. +func (f *macFirewall) SetBypassRoutes(bypassRoutes []netip.Prefix) error { + var rules []string + for _, route := range bypassRoutes { + rules = append(rules, fmt.Sprintf("pass out quick to %s all", route.String())) + } + if err := f.addRulesToAnchor(rules); err != nil { + return fmt.Errorf("add bypass routes: %w", err) + } + f.logger.Verbosef("Added bypass routes: %v", bypassRoutes) + return nil +} + +// TemporaryBypassSocket uses BPF to attach a filter to the socket for bypass. +func (f *macFirewall) TemporaryBypassSocket(fd int) (func() error, error) { + // Compile a simple BPF program to allow specific traffic (e.g., UDP to VPN ports) + // Example: Allow UDP dport == your tunnel port (adjust as needed) + // This is a basic UDP check; extend for port/IP as needed + instructions := []bpf.Instruction{ + bpf.LoadAbsolute{Off: 9, Size: 1}, // Load IP protocol (offset 9 in IP header) + bpf.JumpIf{Cond: bpf.JumpEqual, Val: 17, SkipFalse: 2}, // Check if UDP (17), jump to reject if not + // Add port check here if needed, e.g.: + // bpf.LoadAbsolute{Off: 22, Size: 2}, // Load UDP dport (network byte order, offset 20 src +2 dst in UDP) + // bpf.JumpIf{Cond: bpf.JumpEqual, Val: uint32(f.tunnelPort), SkipFalse: 1}, + bpf.RetConstant{Val: 65535}, // Accept (return max packet length) + bpf.RetConstant{Val: 0}, // Reject + } + + prog, err := bpf.Assemble(instructions) // Compile to machine code + if err != nil { + return nil, fmt.Errorf("assemble BPF: %w", err) + } + + // Prepare SockFprog struct + sockFprog := unix.SockFprog{ + Len: uint16(len(prog)), + Filter: (*unix.Sockfilter)(unsafe.Pointer(&prog[0])), + } + + // Attach to socket using exported SetsockoptSockFprog + if err := unix.SetsockoptSockFprog(fd, unix.SOL_SOCKET, unix.SO_ATTACH_FILTER, &sockFprog); err != nil { + return nil, fmt.Errorf("attach BPF to fd %d: %w", fd, err) + } + f.logger.Verbosef("BPF bypass attached to fd %d", fd) + + return func() error { + // Detach BPF + if err := unix.SetsockoptInt(fd, unix.SOL_SOCKET, unix.SO_DETACH_FILTER, 0); err != nil { + f.logger.Errorf("Failed to detach BPF on fd %d: %v", fd, err) + return err + } + f.logger.Verbosef("BPF detached from fd %d", fd) + return nil + }, nil +} + +// Helper: writeRulesToAnchor writes rules to anchor file and reloads. +func (f *macFirewall) writeRulesToAnchor(rules []string) error { + anchorPath := filepath.Join("/etc/pf.anchors", anchorName) + content := strings.Join(rules, "\n") + if err := os.WriteFile(anchorPath, []byte(content), 0644); err != nil { + return fmt.Errorf("write anchor file: %w", err) + } + return f.reloadAnchor() +} + +// addRuleToAnchor appends a single rule and reloads. +func (f *macFirewall) addRuleToAnchor(rule string) error { + anchorPath := filepath.Join("/etc/pf.anchors", anchorName) + file, err := os.OpenFile(anchorPath, os.O_APPEND|os.O_WRONLY|os.O_CREATE, 0644) + if err != nil { + return fmt.Errorf("open anchor file: %w", err) + } + defer file.Close() + if _, err := file.WriteString(rule + "\n"); err != nil { + return fmt.Errorf("append rule: %w", err) + } + return f.reloadAnchor() +} + +// addRulesToAnchor appends multiple rules and reloads. +func (f *macFirewall) addRulesToAnchor(rules []string) error { + anchorPath := filepath.Join("/etc/pf.anchors", anchorName) + file, err := os.OpenFile(anchorPath, os.O_APPEND|os.O_WRONLY|os.O_CREATE, 0644) + if err != nil { + return fmt.Errorf("open anchor file: %w", err) + } + defer file.Close() + for _, rule := range rules { + if _, err := file.WriteString(rule + "\n"); err != nil { + return fmt.Errorf("append rule: %w", err) + } + } + return f.reloadAnchor() +} + +// clearAnchor clears the anchor file and reloads. +func (f *macFirewall) clearAnchor() error { + anchorPath := filepath.Join("/etc/pf.anchors", anchorName) + if err := os.WriteFile(anchorPath, []byte{}, 0644); err != nil { + return fmt.Errorf("clear anchor file: %w", err) + } + return f.reloadAnchor() +} + +// reloadAnchor reloads the PF anchor. +func (f *macFirewall) reloadAnchor() error { + if err := execSudoCommand("pfctl", "-a", anchorName, "-F", "all"); err != nil { + return fmt.Errorf("flush anchor: %w", err) + } + if err := execSudoCommand("pfctl", "-a", anchorName, "-f", filepath.Join("/etc/pf.anchors", anchorName)); err != nil { + return fmt.Errorf("load anchor: %w", err) + } + return nil +} + +// execSudoCommand runs a command with sudo (assumes sudo is available; handle prompts in app if needed). +func execSudoCommand(name string, args ...string) error { + cmd := exec.Command("sudo", append([]string{name}, args...)...) + output, err := cmd.CombinedOutput() + if err != nil { + return fmt.Errorf("%s: %w\nOutput: %s", name, err, output) + } + return nil +} + +func CheckIPv6(logger *device.Logger) error { + // Similar to Linux: Check sysctl or interfaces for IPv6 + interfaces, err := net.Interfaces() + if err != nil { + return err + } + for _, iface := range interfaces { + addrs, err := iface.Addrs() + if err != nil { + continue + } + for _, addr := range addrs { + ip, _, err := net.ParseCIDR(addr.String()) + if err == nil && ip.To16() != nil && ip.To4() == nil { + logger.Verbosef("IPv6 detected on interface %s", iface.Name) + return nil + } + } + } + return errors.New("no IPv6 interfaces found") +} diff --git a/tunnel/tools/libwg-go/vpn/firewall/osfirewall/firewall_windows.go b/tunnel/tools/libwg-go/vpn/firewall/osfirewall/firewall_windows.go new file mode 100644 index 0000000..cc0e1a5 --- /dev/null +++ b/tunnel/tools/libwg-go/vpn/firewall/osfirewall/firewall_windows.go @@ -0,0 +1,677 @@ +// SPDX-License-Identifier: Apache-2.0 +// Copyright © 2026 WG Tunnel. +// Adapted from Tailscale + +//go:build windows + +package osfirewall + +import ( + "fmt" + "net/netip" + "os" + "sync" + "sync/atomic" + + "github.com/amnezia-vpn/amneziawg-go/device" + "github.com/wgtunnel/desktop/tunnel/vpn/firewall" + "golang.org/x/net/nettest" + "golang.org/x/sys/windows" + "inet.af/wf" + "tailscale.com/net/netaddr" +) + +type WindowsFirewall struct { + mu sync.Mutex // protect shared state + + logger *device.Logger + + session *wf.Session + + providerID wf.ProviderID + sublayerID wf.SublayerID + + iface string + + luid uint64 + + killSwitchEnabled atomic.Bool + + tunRules []*wf.Rule + localAddrRules []*wf.Rule + permittedRoutes map[netip.Prefix][]*wf.Rule +} + +type weight uint64 + +const ( + weightDaemonTraffic weight = 15 + weightKnownTraffic weight = 12 + weightCatchAll weight = 0 +) + +type protocol int + +const ( + protocolV4 protocol = iota + protocolV6 + protocolAll +) + +type direction int + +const ( + directionInbound direction = iota + directionOutbound + directionBoth +) + +// Known addresses. +var ( + linkLocalRange = netip.MustParsePrefix("fe80::/10") + linkLocalDHCPMulticast = netip.MustParseAddr("ff02::1:2") + siteLocalDHCPMulticast = netip.MustParseAddr("ff05::1:3") + linkLocalRouterMulticast = netip.MustParseAddr("ff02::2") +) + +func New(logger *device.Logger) (firewall.Firewall, error) { + session, err := wf.New(&wf.Options{ + Name: "WG Tunnel firewall", + Description: "Manages WG Tunnel firewall rules", + Dynamic: true, // Removes rules on close + }) + + if err != nil { + return nil, fmt.Errorf("create WFP session: %w", err) + } + + guid, err := windows.GenerateGUID() + if err != nil { + return nil, err + } + + providerID := wf.ProviderID(guid) + if err := session.AddProvider(&wf.Provider{ + ID: providerID, + Name: "WG Tunnel provider", + }); err != nil { + return nil, err + } + + guid, err = windows.GenerateGUID() + if err != nil { + return nil, err + } + + sublayerID := wf.SublayerID(guid) + if err := session.AddSublayer(&wf.Sublayer{ + ID: sublayerID, + Name: "WG Tunnel permissive and blocking filters", + Weight: uint16(weightCatchAll), + }); err != nil { + return nil, err + } + + f := &WindowsFirewall{ + logger: logger, + session: session, + providerID: providerID, + sublayerID: sublayerID, + permittedRoutes: make(map[netip.Prefix][]*wf.Rule), + tunRules: make([]*wf.Rule, 0), + } + return f, nil +} + +// addPermissiveRulesForPrefixes is a helper to add permissive rules for a list of prefixes +func (f *WindowsFirewall) addPermissiveRulesForPrefixes(prefixes []netip.Prefix, namePrefix string) (map[netip.Prefix][]*wf.Rule, error) { + f.mu.Lock() + defer f.mu.Unlock() + + addedByPrefix := make(map[netip.Prefix][]*wf.Rule) + var partialAdds []netip.Prefix // rollback tracking + for _, prefix := range prefixes { + if prefix.Addr().Is6() && !nettest.SupportsIPv6() { + continue + } + conditions := []*wf.Match{ + { + Field: wf.FieldIPRemoteAddress, + Op: wf.MatchTypeEqual, + Value: prefix, + }, + } + var p protocol + if prefix.Addr().Is4() { + p = protocolV4 + } else { + p = protocolV6 + } + rules, err := f.addRules(namePrefix+prefix.String(), weightKnownTraffic, conditions, wf.ActionPermit, p, directionBoth) + if err != nil { + for _, addedPrefix := range partialAdds { + if delErr := f.removeRules(addedByPrefix[addedPrefix]); delErr != nil { + f.logger.Errorf("Failed to delete partial rules for %v during rollback: %v", addedPrefix, delErr) + } + } + return nil, fmt.Errorf("add permissive rules for %v: %w", prefix, err) + } + addedByPrefix[prefix] = rules + partialAdds = append(partialAdds, prefix) + } + return addedByPrefix, nil +} + +// removeRules is a helper to remove a list of rules +func (f *WindowsFirewall) removeRules(rules []*wf.Rule) error { + f.mu.Lock() + defer f.mu.Unlock() + + for _, rule := range rules { + if err := f.session.DeleteRule(rule.ID); err != nil { + f.logger.Errorf("Failed to delete rule %s: %v", rule.Name, err) + // Continue to try deleting others + } + } + return nil +} + +func (f *WindowsFirewall) AllowLocalNetworks(addrs []netip.Prefix) error { + // cleanup old local addr rules + if err := f.removeRules(f.localAddrRules); err != nil { + f.logger.Errorf("Failed to remove old local addr rules: %v", err) + } + f.mu.Lock() + f.localAddrRules = nil + f.mu.Unlock() + + // add new rules + addedByPrefix, err := f.addPermissiveRulesForPrefixes(addrs, "bypass for local addr ") + if err != nil { + return err + } + f.mu.Lock() + f.localAddrRules = nil + for _, rules := range addedByPrefix { + f.localAddrRules = append(f.localAddrRules, rules...) + } + f.mu.Unlock() + f.logger.Verbosef("Bypassed local addrs in firewall") + return nil +} + +func (f *WindowsFirewall) UpdatePermittedRoutes(newRoutes []netip.Prefix) error { + f.mu.Lock() + // routes to remove + var routesToRemove []netip.Prefix + for existing := range f.permittedRoutes { + found := false + for _, newRoute := range newRoutes { + if existing == newRoute { + found = true + break + } + } + if !found { + routesToRemove = append(routesToRemove, existing) + } + } + f.mu.Unlock() + for _, r := range routesToRemove { + f.mu.Lock() + rules := f.permittedRoutes[r] + f.mu.Unlock() + if err := f.removeRules(rules); err != nil { + f.logger.Errorf("Failed to remove permitted route %v: %v", r, err) + } + f.mu.Lock() + delete(f.permittedRoutes, r) + f.mu.Unlock() + } + + // routes to add + var routesToAdd []netip.Prefix + f.mu.Lock() + for _, newRoute := range newRoutes { + if _, exists := f.permittedRoutes[newRoute]; !exists { + routesToAdd = append(routesToAdd, newRoute) + } + } + f.mu.Unlock() + + // add new rules + addedByPrefix, err := f.addPermissiveRulesForPrefixes(routesToAdd, "permitted route - ") + if err != nil { + return err + } + f.mu.Lock() + for prefix, rules := range addedByPrefix { + f.permittedRoutes[prefix] = rules + } + f.mu.Unlock() + + f.logger.Verbosef("Updated permitted routes: %v", newRoutes) + return nil +} + +// permitDaemon allows the daemon process through firewall +func (f *WindowsFirewall) permitDaemon(w weight) error { + f.mu.Lock() + defer f.mu.Unlock() + currentFile, err := os.Executable() + if err != nil { + return err + } + + appID, err := wf.AppID(currentFile) + if err != nil { + return fmt.Errorf("could not get app id for %q: %w", currentFile, err) + } + conditions := []*wf.Match{ + { + Field: wf.FieldALEAppID, + Op: wf.MatchTypeEqual, + Value: appID, + }, + } + _, err = f.addRules("unrestricted traffic for daemon", w, conditions, wf.ActionPermit, protocolAll, directionBoth) + return err +} + +func (f *WindowsFirewall) BypassTunnel(luid uint64, listenPort uint16) error { + f.mu.Lock() + f.luid = luid + f.mu.Unlock() + if err := f.permitTunInterface(weightDaemonTraffic); err != nil { + return fmt.Errorf("permitTunInterface failed: %w", err) + } + if err := f.permitListenPort(weightDaemonTraffic, listenPort); err != nil { + return fmt.Errorf("permitListenPort failed: %w", err) + } + return nil +} + +func (f *WindowsFirewall) Enable() error { + f.mu.Lock() + if f.killSwitchEnabled.Load() { + f.mu.Unlock() + f.logger.Verbosef("Kill switch already active, skipping activation") + return nil + } + f.mu.Unlock() + if err := f.permitDaemon(weightDaemonTraffic); err != nil { + return fmt.Errorf("permitTailscaleService failed: %w", err) + } + if err := f.permitLoopback(weightDaemonTraffic); err != nil { + return fmt.Errorf("permitLoopback failed: %w", err) + } + if err := f.permitDHCPv4(weightKnownTraffic); err != nil { + return fmt.Errorf("permitDHCPv4 failed: %w", err) + } + + if nettest.SupportsIPv6() { + if err := f.permitDHCPv6(weightKnownTraffic); err != nil { + return fmt.Errorf("permitDHCPv6 failed: %w", err) + } + + if err := f.permitNDP(weightKnownTraffic); err != nil { + return fmt.Errorf("permitNDP failed: %w", err) + } + } + + if err := f.blockAll(weightCatchAll); err != nil { + return fmt.Errorf("blockAll failed: %w", err) + } + + f.killSwitchEnabled.Store(true) + return nil +} + +func (f *WindowsFirewall) IsEnabled() bool { + return f.killSwitchEnabled.Load() +} + +func (f *WindowsFirewall) RemoveTunnelRules() error { + f.mu.Lock() + tunRulesCopy := make([]*wf.Rule, len(f.tunRules)) + copy(tunRulesCopy, f.tunRules) + f.tunRules = nil + f.mu.Unlock() + if err := f.removeRules(tunRulesCopy); err != nil { + f.logger.Errorf("Failed to remove tun rules: %v", err) + } + + f.mu.Lock() + permittedCopy := make(map[netip.Prefix][]*wf.Rule, len(f.permittedRoutes)) + for k, v := range f.permittedRoutes { + permittedCopy[k] = v + } + f.permittedRoutes = make(map[netip.Prefix][]*wf.Rule) + f.mu.Unlock() + for prefix, rules := range permittedCopy { + if err := f.removeRules(rules); err != nil { + f.logger.Errorf("Failed to remove permitted route %s: %v", prefix, err) + } + } + + f.logger.Verbosef("Tunnel rules and permitted routes removed") + return nil +} + +func (f *WindowsFirewall) Disable() error { + f.mu.Lock() + defer f.mu.Unlock() + if err := f.session.Close(); err != nil { + f.logger.Errorf("Failed to close WFP session: %v", err) + } + f.killSwitchEnabled.Store(false) + f.logger.Verbosef("Firewall rules and kill switch cleaned up") + return nil +} + +func (f *WindowsFirewall) permitLoopback(w weight) error { + f.mu.Lock() + defer f.mu.Unlock() + condition := []*wf.Match{ + { + Field: wf.FieldFlags, + Op: wf.MatchTypeFlagsAllSet, + Value: wf.ConditionFlagIsLoopback, + }, + } + _, err := f.addRules("on loopback", w, condition, wf.ActionPermit, protocolAll, directionBoth) + return err +} + +func (f *WindowsFirewall) permitListenPort(w weight, listenPort uint16) error { + f.mu.Lock() + defer f.mu.Unlock() + conditions := []*wf.Match{ + {Field: wf.FieldIPLocalInterface, Op: wf.MatchTypeEqual, Value: f.luid}, + {Field: wf.FieldIPProtocol, Op: wf.MatchTypeEqual, Value: wf.IPProtoUDP}, + {Field: wf.FieldIPLocalPort, Op: wf.MatchTypeEqual, Value: listenPort}, + } + rules, err := f.addRules("WireGuard UDP", w, conditions, wf.ActionPermit, protocolAll, directionInbound) + if err != nil { + return err + } + f.tunRules = append(f.tunRules, rules...) + return nil +} + +func (f *WindowsFirewall) permitDHCPv6(w weight) error { + f.mu.Lock() + defer f.mu.Unlock() + var dhcpConditions = func(remoteAddrs ...any) []*wf.Match { + conditions := []*wf.Match{ + { + Field: wf.FieldIPProtocol, + Op: wf.MatchTypeEqual, + Value: wf.IPProtoUDP, + }, + { + Field: wf.FieldIPLocalAddress, + Op: wf.MatchTypeEqual, + Value: linkLocalRange, + }, + { + Field: wf.FieldIPLocalPort, + Op: wf.MatchTypeEqual, + Value: uint16(546), + }, + { + Field: wf.FieldIPRemotePort, + Op: wf.MatchTypeEqual, + Value: uint16(547), + }, + } + for _, a := range remoteAddrs { + conditions = append(conditions, &wf.Match{ + Field: wf.FieldIPRemoteAddress, + Op: wf.MatchTypeEqual, + Value: a, + }) + } + return conditions + } + conditions := dhcpConditions(linkLocalDHCPMulticast, siteLocalDHCPMulticast) + if _, err := f.addRules("DHCP request", w, conditions, wf.ActionPermit, protocolV6, directionOutbound); err != nil { + return err + } + conditions = dhcpConditions(linkLocalRange) + if _, err := f.addRules("DHCP response", w, conditions, wf.ActionPermit, protocolV6, directionInbound); err != nil { + return err + } + return nil +} + +func (f *WindowsFirewall) permitDHCPv4(w weight) error { + f.mu.Lock() + defer f.mu.Unlock() + var dhcpConditions = func(remoteAddrs ...any) []*wf.Match { + conditions := []*wf.Match{ + { + Field: wf.FieldIPProtocol, + Op: wf.MatchTypeEqual, + Value: wf.IPProtoUDP, + }, + { + Field: wf.FieldIPLocalPort, + Op: wf.MatchTypeEqual, + Value: uint16(68), + }, + { + Field: wf.FieldIPRemotePort, + Op: wf.MatchTypeEqual, + Value: uint16(67), + }, + } + for _, a := range remoteAddrs { + conditions = append(conditions, &wf.Match{ + Field: wf.FieldIPRemoteAddress, + Op: wf.MatchTypeEqual, + Value: a, + }) + } + return conditions + } + conditions := dhcpConditions(netaddr.IPv4(255, 255, 255, 255)) + if _, err := f.addRules("DHCP request", w, conditions, wf.ActionPermit, protocolV4, directionOutbound); err != nil { + return err + } + + conditions = dhcpConditions() + if _, err := f.addRules("DHCP response", w, conditions, wf.ActionPermit, protocolV4, directionInbound); err != nil { + return err + } + return nil +} + +func (f *WindowsFirewall) permitNDP(w weight) error { + f.mu.Lock() + defer f.mu.Unlock() + // These are aliased according to: + // https://social.msdn.microsoft.com/Forums/azure/en-US/eb2aa3cd-5f1c-4461-af86-61e7d43ccc23/filtering-icmp-by-type-code?forum=wfp + fieldICMPType := wf.FieldIPLocalPort + fieldICMPCode := wf.FieldIPRemotePort + + var icmpConditions = func(t, c uint16, remoteAddress any) []*wf.Match { + conditions := []*wf.Match{ + { + Field: wf.FieldIPProtocol, + Op: wf.MatchTypeEqual, + Value: wf.IPProtoICMPV6, + }, + { + Field: fieldICMPType, + Op: wf.MatchTypeEqual, + Value: t, + }, + { + Field: fieldICMPCode, + Op: wf.MatchTypeEqual, + Value: c, + }, + } + if remoteAddress != nil { + conditions = append(conditions, &wf.Match{ + Field: wf.FieldIPRemoteAddress, + Op: wf.MatchTypeEqual, + Value: linkLocalRouterMulticast, + }) + } + return conditions + } + /* TODO: actually handle the hop limit somehow! The rules should vaguely be: + * - icmpv6 133: must be outgoing, dst must be FF02::2/128, hop limit must be 255 + * - icmpv6 134: must be incoming, src must be FE80::/10, hop limit must be 255 + * - icmpv6 135: either incoming or outgoing, hop limit must be 255 + * - icmpv6 136: either incoming or outgoing, hop limit must be 255 + * - icmpv6 137: must be incoming, src must be FE80::/10, hop limit must be 255 + */ + + // + // Router Solicitation Message + // ICMP type 133, code 0. Outgoing. + // + conditions := icmpConditions(133, 0, linkLocalRouterMulticast) + if _, err := f.addRules("NDP type 133", w, conditions, wf.ActionPermit, protocolV6, directionOutbound); err != nil { + return err + } + + // + // Router Advertisement Message + // ICMP type 134, code 0. Incoming. + // + conditions = icmpConditions(134, 0, linkLocalRange) + if _, err := f.addRules("NDP type 134", w, conditions, wf.ActionPermit, protocolV6, directionInbound); err != nil { + return err + } + + // + // Neighbor Solicitation Message + // ICMP type 135, code 0. Bi-directional. + // + conditions = icmpConditions(135, 0, nil) + if _, err := f.addRules("NDP type 135", w, conditions, wf.ActionPermit, protocolV6, directionBoth); err != nil { + return err + } + + // + // Neighbor Advertisement Message + // ICMP type 136, code 0. Bi-directional. + // + conditions = icmpConditions(136, 0, nil) + if _, err := f.addRules("NDP type 136", w, conditions, wf.ActionPermit, protocolV6, directionBoth); err != nil { + return err + } + + // + // Redirect Message + // ICMP type 137, code 0. Incoming. + // + conditions = icmpConditions(137, 0, linkLocalRange) + if _, err := f.addRules("NDP type 137", w, conditions, wf.ActionPermit, protocolV6, directionInbound); err != nil { + return err + } + return nil +} + +func (f *WindowsFirewall) blockAll(w weight) error { + f.mu.Lock() + defer f.mu.Unlock() + _, err := f.addRules("all", w, nil, wf.ActionBlock, protocolAll, directionBoth) + return err +} + +// addRules adds WFP rules with the given parameters +func (f *WindowsFirewall) addRules(name string, w weight, conditions []*wf.Match, action wf.Action, p protocol, d direction) ([]*wf.Rule, error) { + f.mu.Lock() + defer f.mu.Unlock() + var rules []*wf.Rule + for _, layer := range p.getLayers(d) { + r, err := f.newRule(name, w, layer, conditions, action) + if err != nil { + return nil, err + } + if err := f.session.AddRule(r); err != nil { + return nil, err + } + rules = append(rules, r) + } + return rules, nil +} + +// getLayers returns the wf.LayerIDs where the rules should be added based on the protocol and direction. +func (p protocol) getLayers(d direction) []wf.LayerID { + var layers []wf.LayerID + if p == protocolAll || p == protocolV4 { + if d == directionBoth || d == directionInbound { + layers = append(layers, wf.LayerALEAuthRecvAcceptV4) + } + if d == directionBoth || d == directionOutbound { + layers = append(layers, wf.LayerALEAuthConnectV4) + } + } + if p == protocolAll || p == protocolV6 { + if d == directionBoth || d == directionInbound { + layers = append(layers, wf.LayerALEAuthRecvAcceptV6) + } + if d == directionBoth || d == directionOutbound { + layers = append(layers, wf.LayerALEAuthConnectV6) + } + } + return layers +} + +func (f *WindowsFirewall) newRule(name string, w weight, layer wf.LayerID, conditions []*wf.Match, action wf.Action) (*wf.Rule, error) { + f.mu.Lock() + defer f.mu.Unlock() + id, err := windows.GenerateGUID() + if err != nil { + return nil, err + } + return &wf.Rule{ + Name: "WGTunnel-" + ruleName(action, layer, name), + ID: wf.RuleID(id), + Provider: f.providerID, + Sublayer: f.sublayerID, + Layer: layer, + Weight: uint64(w), + Conditions: conditions, + Action: action, + }, nil +} + +func ruleName(action wf.Action, layerID wf.LayerID, name string) string { + switch layerID { + case wf.LayerALEAuthConnectV4: + return fmt.Sprintf("%s outbound %s (IPv4)", action, name) + case wf.LayerALEAuthConnectV6: + return fmt.Sprintf("%s outbound %s (IPv6)", action, name) + case wf.LayerALEAuthRecvAcceptV4: + return fmt.Sprintf("%s inbound %s (IPv4)", action, name) + case wf.LayerALEAuthRecvAcceptV6: + return fmt.Sprintf("%s inbound %s (IPv6)", action, name) + } + return "" +} + +// permitTunInterface allows tun interface through firewall, requires luid to be set +func (f *WindowsFirewall) permitTunInterface(w weight) error { + f.mu.Lock() + defer f.mu.Unlock() + condition := []*wf.Match{ + { + Field: wf.FieldIPLocalInterface, + Op: wf.MatchTypeEqual, + Value: f.luid, + }, + } + rules, err := f.addRules("on TUN", w, condition, wf.ActionPermit, protocolAll, directionBoth) + if err != nil { + return err + } + f.tunRules = append(f.tunRules, rules...) + return nil +} diff --git a/tunnel/tools/libwg-go/vpn/firewall/osfirewall/firewallmgr/manager.go b/tunnel/tools/libwg-go/vpn/firewall/osfirewall/firewallmgr/manager.go new file mode 100644 index 0000000..7876b05 --- /dev/null +++ b/tunnel/tools/libwg-go/vpn/firewall/osfirewall/firewallmgr/manager.go @@ -0,0 +1,25 @@ +package firewallmgr + +import ( + "sync" + + "github.com/wgtunnel/desktop/tunnel/shared" + "github.com/wgtunnel/desktop/tunnel/vpn/firewall" + "github.com/wgtunnel/desktop/tunnel/vpn/firewall/osfirewall" +) + +var ( + instance firewall.Firewall + once sync.Once + initErr error +) + +func Get() (firewall.Firewall, error) { + once.Do(func() { + instance, initErr = osfirewall.New( + shared.NewLogger("Firewall"), + ) + }) + + return instance, initErr +} diff --git a/tunnel/tools/libwg-go/vpn/router/osrouter/router_linux.go b/tunnel/tools/libwg-go/vpn/router/osrouter/router_linux.go new file mode 100644 index 0000000..da1c797 --- /dev/null +++ b/tunnel/tools/libwg-go/vpn/router/osrouter/router_linux.go @@ -0,0 +1,476 @@ +// SPDX-License-Identifier: Apache-2.0 +// Copyright © 2026 WG Tunnel. +// Adapted from Tailscale + +//go:build linux + +package osrouter + +import ( + "fmt" + "net" + "net/netip" + "slices" + + "github.com/amnezia-vpn/amneziawg-go/device" + "github.com/amnezia-vpn/amneziawg-go/tun" + "github.com/vishvananda/netlink" + "github.com/wgtunnel/desktop/tunnel/vpn/dns" + "github.com/wgtunnel/desktop/tunnel/vpn/firewall" + "github.com/wgtunnel/desktop/tunnel/vpn/firewall/mark" + "github.com/wgtunnel/desktop/tunnel/vpn/firewall/osfirewall" + "github.com/wgtunnel/desktop/tunnel/vpn/router" + "golang.org/x/net/nettest" + "golang.org/x/sys/unix" +) + +const ( + rulePrioBootstrap = 50 + tunnelTableID = 52 + rulePrioMark = 100 + rulePrioExclude = 150 + rulePrioDefault = 200 +) + +type linuxRouter struct { + iface string + fw *osfirewall.LinuxFirewall + logger *device.Logger + prevConfig *router.Config + weEngagedKS bool + v4Full bool + v6Full bool + v6Available bool + + policyRules map[int][]*netlink.Rule +} + +func New(iface string, fw firewall.Firewall, _ tun.Device, logger *device.Logger) (router.Router, error) { + return &linuxRouter{ + iface: iface, + fw: fw.(*osfirewall.LinuxFirewall), + logger: logger, + v6Available: nettest.SupportsIPv6(), + policyRules: make(map[int][]*netlink.Rule), + }, nil +} + +func (r *linuxRouter) Set(c *router.Config) error { + newC := r.normalizeConfig(c) + prevC := r.normalizeConfig(r.prevConfig) + + if r.isUnchanged(newC) { + r.logger.Verbosef("Config unchanged, skipping") + return nil + } + + link, err := netlink.LinkByName(r.iface) + if err != nil { + return fmt.Errorf("get link %s: %w", r.iface, err) + } + + if err := netlink.LinkSetUp(link); err != nil { + return fmt.Errorf("set link up: %w", err) + } + + if err := r.syncFirewallState(newC); err != nil { + return err + } + + r.syncDeviceParams(link, newC, prevC) + + r.cleanupPreviousState(link, newC, prevC) + if err := r.applyNewAddresses(link, newC); err != nil { + return err + } + + if err := r.syncRoutingAndRules(link, newC); err != nil { + return err + } + + if err := r.syncDNS(newC, prevC); err != nil { + return err + } + + r.updatePrevState(newC) + return nil +} + +// Close closes the router. +func (r *linuxRouter) Close() error { + // revert DNS before cleanup + if r.prevConfig != nil { + if err := dns.RevertDns(r.iface, r.logger); err != nil { + r.logger.Errorf("revert DNS on close: %v", err) + } + } + + // cleanup + if err := r.Set(nil); err != nil { + r.logger.Errorf("cleanup set nil: %v", err) + } + + if r.weEngagedKS && r.fw.IsEnabled() { + r.logger.Verbosef("Disabling full tunnel kill switch for iface: %s", r.iface) + if err := r.fw.Disable(); err != nil { + return fmt.Errorf("failed to disable firewall: %w", err) + } + } else if r.fw.IsEnabled() { + r.logger.Verbosef("Removing firewall rules for iface: %s", r.iface) + if err := r.fw.RemoveTunnelBypasses(r.iface); err != nil { + return fmt.Errorf("failed remove firewall rules for iface %s : %v", r.iface, err) + } + } + + r.deletePolicyRules(netlink.FAMILY_V4) + r.deletePolicyRules(netlink.FAMILY_V6) + + r.logger.Verbosef("Router closed") + return nil +} + +func (r *linuxRouter) cleanupPreviousState(link netlink.Link, newC, prevC *router.Config) { + if r.prevConfig == nil { + return + } + + // remove old addresses + for _, a := range prevC.TunnelAddrs { + if !slices.Contains(newC.TunnelAddrs, a) { + ipnet := prefixToIPNet(a) + if err := netlink.AddrDel(link, &netlink.Addr{IPNet: ipnet}); err != nil { + r.logger.Errorf("del addr %v: %v", a, err) + } + } + } + + // remove old routes + prevV4Full := hasDefault(prevC, true) + prevV6Full := hasDefault(prevC, false) + newV4Full := hasDefault(newC, true) + newV6Full := hasDefault(newC, false) + + for _, rt := range prevC.Routes { + if !slices.Contains(newC.Routes, rt) { + table := unix.RT_TABLE_MAIN + if (rt.Addr().Is4() && prevV4Full) || (rt.Addr().Is6() && prevV6Full) { + table = tunnelTableID + } + dst := prefixToIPNet(rt) + route := &netlink.Route{LinkIndex: link.Attrs().Index, Dst: dst, Table: table} + _ = netlink.RouteDel(route) + } + } + + // clean up marks + if prevV4Full && !newV4Full { + r.deletePolicyRules(netlink.FAMILY_V4) + r.deleteBootstrapPolicyRules(netlink.FAMILY_V4) + } + if prevV6Full && !newV6Full { + r.deletePolicyRules(netlink.FAMILY_V6) + r.deleteBootstrapPolicyRules(netlink.FAMILY_V6) + } +} + +func (r *linuxRouter) normalizeConfig(c *router.Config) *router.Config { + if c == nil { + return &router.Config{} + } + return c +} + +func addrExists(existing []netlink.Addr, target *net.IPNet) bool { + for _, a := range existing { + if a.IPNet != nil && a.IPNet.String() == target.String() { + return true + } + } + return false +} + +func (r *linuxRouter) deleteExcludeRule(lr netip.Prefix) { + fam := netlink.FAMILY_V4 + if lr.Addr().Is6() { + fam = netlink.FAMILY_V6 + } + + dst := prefixToIPNet(lr) + rule := netlink.NewRule() + rule.Family = fam + rule.Priority = rulePrioExclude + rule.Dst = dst + rule.Table = unix.RT_TABLE_MAIN + + // ignore the error if rule is already gone + if err := netlink.RuleDel(rule); err != nil { + r.logger.Verbosef("del exclude rule %v: %v (ignored)", lr, err) + } +} + +func (r *linuxRouter) isUnchanged(newC *router.Config) bool { + if r.prevConfig == nil { + return false + } + return newC.Equal(r.prevConfig) +} + +func (r *linuxRouter) updatePrevState(newC *router.Config) { + r.v4Full = hasDefault(newC, true) + r.v6Full = hasDefault(newC, false) + r.prevConfig = newC.Clone() + r.logger.Verbosef("Router state updated: full v4=%v v6=%v", r.v4Full, r.v6Full) +} + +func (r *linuxRouter) syncFirewallState(newC *router.Config) error { + v4Full := hasDefault(newC, true) + v6Full := hasDefault(newC, false) + requiresKS := v4Full || v6Full + + if requiresKS && !r.fw.IsEnabled() { + if err := r.fw.Enable(); err != nil { + return fmt.Errorf("enable firewall: %w", err) + } + r.weEngagedKS = true + // add our marks for the tunnel and bootstrap + if err := r.fw.AddTunnelBypasses(r.iface); err != nil { + return fmt.Errorf("add firewall bypasses: %w", err) + } + } else if !requiresKS && r.weEngagedKS { + if err := r.fw.Disable(); err != nil { + return fmt.Errorf("disable firewall: %w", err) + } + r.weEngagedKS = false + } + return nil +} + +func (r *linuxRouter) syncDeviceParams(link netlink.Link, newC, prevC *router.Config) { + // sync mtu + if newC.MTU > 0 && newC.MTU != prevC.MTU { + _ = netlink.LinkSetMTU(link, newC.MTU) + } + + // sync ListenPort for fw + if newC.ListenPort != 0 && newC.ListenPort != prevC.ListenPort { + _ = r.fw.SetTunnelPort(newC.ListenPort) + } +} + +func (r *linuxRouter) syncDNS(newC, prevC *router.Config) error { + v4Full := hasDefault(newC, true) + v6Full := hasDefault(newC, false) + prevV4Full := hasDefault(prevC, true) + prevV6Full := hasDefault(prevC, false) + + // handle if DNS settings or tunnel state changed + dnsChanged := !slices.Equal(newC.DNS, prevC.DNS) || + !slices.Equal(newC.SearchDomains, prevC.SearchDomains) + stateChanged := (v4Full != prevV4Full) || (v6Full != prevV6Full) + + if dnsChanged || stateChanged { + return dns.SetDns(r.iface, newC.DNS, newC.SearchDomains, v4Full || v6Full, r.logger) + } + return nil +} + +func (r *linuxRouter) applyNewAddresses(link netlink.Link, newC *router.Config) error { + existingAddrs, _ := netlink.AddrList(link, netlink.FAMILY_ALL) + + for _, a := range newC.TunnelAddrs { + if a.Addr().Is6() && !r.v6Available { + continue + } + + ipNet := prefixToIPNet(a) + + if !addrExists(existingAddrs, ipNet) { + if err := netlink.AddrReplace(link, &netlink.Addr{IPNet: ipNet}); err != nil { + return fmt.Errorf("failed to add addr %v: %w", a, err) + } + } + } + return nil +} + +func (r *linuxRouter) syncRoutingAndRules(link netlink.Link, newC *router.Config) error { + v4Full := hasDefault(newC, true) + v6Full := hasDefault(newC, false) + + families := []int{netlink.FAMILY_V4} + if r.v6Available { + families = append(families, netlink.FAMILY_V6) + } + + for _, fam := range families { + isFull := (fam == netlink.FAMILY_V4 && v4Full) || (fam == netlink.FAMILY_V6 && v6Full) + + if isFull { + // add unnel rules + if err := r.addPolicyRules(fam); err != nil { + return err + } + // add bootstrap mark rule for DNS bootstrap + if err := r.addBootstrapPolicyRules(fam); err != nil { + return err + } + } + + routes := filterRoutes(newC.Routes, fam == netlink.FAMILY_V4) + table := unix.RT_TABLE_MAIN + if isFull { + table = tunnelTableID + } + + for _, rt := range routes { + if err := r.replaceRouteIdempotent(link, rt, table); err != nil { + return err + } + } + } + return nil +} + +func (r *linuxRouter) addBootstrapPolicyRules(family int) error { + mask := uint32(mark.LinuxFwmarkMaskNum) + rule := netlink.NewRule() + rule.Family = family + rule.Mark = mark.LinuxBootstrapMarkNum + rule.Mask = &mask + rule.Priority = 50 // set as high priority, above main tunnel rules + rule.Table = unix.RT_TABLE_MAIN // force bypass to ISP table + + return r.addRuleIdempotent(rule) +} + +func (r *linuxRouter) deleteBootstrapPolicyRules(family int) error { + rule := netlink.NewRule() + rule.Family = family + rule.Mark = mark.LinuxBootstrapMarkNum + rule.Priority = rulePrioBootstrap + return netlink.RuleDel(rule) +} + +func (r *linuxRouter) addRuleIdempotent(rule *netlink.Rule) error { + rules, err := netlink.RuleList(rule.Family) + if err != nil { + return err + } + + for _, existing := range rules { + if existing.Mark == rule.Mark && existing.Priority == rule.Priority && existing.Table == rule.Table { + return nil // Already exists + } + } + return netlink.RuleAdd(rule) +} + +func (r *linuxRouter) replaceRouteIdempotent(link netlink.Link, rt netip.Prefix, table int) error { + dst := prefixToIPNet(rt) + route := &netlink.Route{ + LinkIndex: link.Attrs().Index, + Dst: dst, + Table: table, + Type: unix.RTN_UNICAST, + } + return netlink.RouteReplace(route) +} + +// hasDefault returns true if config has default route for v4 (true) or v6 (false). +func hasDefault(c *router.Config, v4 bool) bool { + if c == nil { + return false + } + for _, rt := range c.Routes { + if rt.Bits() == 0 && ((v4 && rt.Addr().Is4()) || (!v4 && rt.Addr().Is6())) { + return true + } + } + return false +} + +// filterRoutes returns routes for v4 (true) or v6 (false). +func filterRoutes(routes []netip.Prefix, v4 bool) []netip.Prefix { + var filtered []netip.Prefix + for _, rt := range routes { + if (v4 && rt.Addr().Is4()) || (!v4 && rt.Addr().Is6()) { + filtered = append(filtered, rt) + } + } + return filtered +} + +// prefixToIPNet converts netip.Prefix to *net.IPNet. +func prefixToIPNet(p netip.Prefix) *net.IPNet { + if !p.IsValid() { + return nil + } + bits := p.Bits() + addr := p.Addr() + ip := net.IP(addr.AsSlice()) + mask := net.CIDRMask(bits, addr.BitLen()) + return &net.IPNet{IP: ip, Mask: mask} +} + +// addPolicyRules adds mark-based and default tunnel table rules for the family. +func (r *linuxRouter) addPolicyRules(fam int) error { + rules, err := netlink.RuleList(fam) + if err != nil { + return fmt.Errorf("list rules fam %d: %w", fam, err) + } + + // Mark rule: fwmark bypass -> main + markRule := netlink.NewRule() + markRule.Family = fam + markRule.Priority = rulePrioMark + markRule.Mark = mark.LinuxBypassMarkNum + markRule.Table = unix.RT_TABLE_MAIN + + markExists := false + for _, existing := range rules { + if existing.Priority == markRule.Priority && existing.Mark == markRule.Mark && existing.Table == markRule.Table { + markExists = true + break + } + } + if !markExists { + if err := netlink.RuleAdd(markRule); err != nil { + return fmt.Errorf("add mark rule fam %d: %w", fam, err) + } + r.policyRules[fam] = append(r.policyRules[fam], markRule) + } else { + r.logger.Verbosef("Mark rule fam %d already exists, skipping", fam) + } + + defaultRule := netlink.NewRule() + defaultRule.Family = fam + defaultRule.Priority = rulePrioDefault + defaultRule.Table = tunnelTableID + + defaultExists := false + for _, existing := range rules { + if existing.Priority == defaultRule.Priority && existing.Table == defaultRule.Table && existing.Dst == nil { + defaultExists = true + break + } + } + if !defaultExists { + if err := netlink.RuleAdd(defaultRule); err != nil { + return fmt.Errorf("add default tunnel rule fam %d: %w", fam, err) + } + r.policyRules[fam] = append(r.policyRules[fam], defaultRule) + } else { + r.logger.Verbosef("Default tunnel rule fam %d already exists, skipping", fam) + } + return nil +} + +// deletePolicyRules deletes the policy rules for the family. +func (r *linuxRouter) deletePolicyRules(fam int) { + for _, rule := range r.policyRules[fam] { + if err := netlink.RuleDel(rule); err != nil { + r.logger.Verbosef("del policy rule fam %d (prio %d): %v (ignored)", fam, rule.Priority, err) + } + } + r.policyRules[fam] = nil +} diff --git a/tunnel/tools/libwg-go/vpn/router/osrouter/router_windows.go b/tunnel/tools/libwg-go/vpn/router/osrouter/router_windows.go new file mode 100644 index 0000000..7e4bca8 --- /dev/null +++ b/tunnel/tools/libwg-go/vpn/router/osrouter/router_windows.go @@ -0,0 +1,766 @@ +// SPDX-License-Identifier: Apache-2.0 +// Copyright © 2026 WG Tunnel. +// Adapted from Tailscale + +//go:build windows + +package osrouter + +import ( + "errors" + "fmt" + "net/netip" + "os/exec" + "slices" + "sort" + "strings" + "syscall" + + "github.com/amnezia-vpn/amneziawg-go/device" + "github.com/amnezia-vpn/amneziawg-go/tun" + "github.com/wgtunnel/desktop/tunnel/vpn/dns" + "github.com/wgtunnel/desktop/tunnel/vpn/firewall" + "github.com/wgtunnel/desktop/tunnel/vpn/firewall/osfirewall" + "github.com/wgtunnel/desktop/tunnel/vpn/router" + "go4.org/netipx" + "golang.org/x/net/nettest" + "golang.org/x/sys/windows" + "golang.zx2c4.com/wireguard/windows/tunnel/winipcfg" +) + +type windowsRouter struct { + iface string + fw *osfirewall.WindowsFirewall + logger *device.Logger + prevConfig *router.Config + weEngagedKS bool + v6Available bool + nativeTun *tun.NativeTun + luid winipcfg.LUID + rawLuid uint64 + originalSearchDomains []string +} + +func New(iface string, fw firewall.Firewall, tunnel tun.Device, logger *device.Logger) (router.Router, error) { + nativeTun := tunnel.(*tun.NativeTun) + // get windows interface id + rawLuid := nativeTun.LUID() + return &windowsRouter{ + iface: iface, + fw: fw.(*osfirewall.WindowsFirewall), + logger: logger, + v6Available: nettest.SupportsIPv6(), + nativeTun: nativeTun, + rawLuid: rawLuid, + luid: winipcfg.LUID(rawLuid), + }, nil +} + +func (r *windowsRouter) Set(c *router.Config) error { + newC := c + if newC == nil { + newC = &router.Config{} + } + prevC := r.prevConfig + if prevC == nil { + prevC = &router.Config{} + } + + if newC.Equal(prevC) { + r.logger.Verbosef("Config unchanged, skipping") + return nil + } + + err := r.configureInterface(newC) + if err != nil { + r.logger.Errorf("ConfigureInterface: %v", err) + return err + } + + // dns + prevFull := prevC.HasAnyDefaultRoute() + newFull := newC.HasAnyDefaultRoute() + if !slices.Equal(newC.DNS, prevC.DNS) || !slices.Equal(newC.SearchDomains, prevC.SearchDomains) || newFull != prevFull { + if newFull && r.originalSearchDomains == nil { + var err error + r.originalSearchDomains, err = r.getGlobalSearchDomains() + if err != nil { + r.logger.Errorf("Failed to get original search domains: %v", err) + } + } + if err := dns.SetDNS(r.luid, newC.DNS, newC.SearchDomains, newFull, r.logger); err != nil { + return err + } + } + + requiresKS := newFull + if requiresKS && !r.fw.IsEnabled() { + if err := r.fw.Enable(); err != nil { + return err + } + if err := r.fw.BypassTunnel(r.rawLuid, newC.ListenPort); err != nil { + return err + } + //if err := r.fw.AllowLocalNetworks(newC.ExcludedRoutes); err != nil { + // return err + //} + r.weEngagedKS = true + } else if !requiresKS && r.weEngagedKS { + if err := r.fw.Disable(); err != nil { + return err + } + r.weEngagedKS = false + } + + if err := flushCaches(); err != nil { + r.logger.Errorf("flush dns: %v", err) + } + + r.prevConfig = newC.Clone() + return nil +} + +// subtractPrefixes returns the list of prefixes that cover "super" minus all "exclusions" +func subtractPrefixes(super netip.Prefix, exclusions []netip.Prefix) []netip.Prefix { + if !super.IsValid() { + return nil + } + + result := []netip.Prefix{super} + + for _, excl := range exclusions { + if !excl.IsValid() || excl.Bits() != excl.Addr().BitLen() { // skip non-host + continue + } + if !super.Contains(excl.Addr()) { + continue + } + + var newResult []netip.Prefix + for _, r := range result { + if !r.Contains(excl.Addr()) { + newResult = append(newResult, r) + continue + } + + // split the containing prefix + current := r + for current.Bits() < excl.Bits() { + splitBit := current.Bits() + lowMask := netip.PrefixFrom(current.Addr(), splitBit+1) + + if lowMask.Contains(excl.Addr()) { + // Add the high (sibling) half + siblingAddr := flipBit(current.Addr(), splitBit) + sibling := netip.PrefixFrom(siblingAddr, splitBit+1) + newResult = append(newResult, sibling) + current = lowMask // split the the low half + } else { + // add the low half, continue with high + newResult = append(newResult, lowMask) + highAddr := flipBit(current.Addr(), splitBit) + current = netip.PrefixFrom(highAddr, splitBit+1) + } + } + // drop the matching prefix to excluded + } + result = newResult + } + return result +} + +// getBit returns the value of the i-th bit +func getBit(addr netip.Addr, i int) bool { + if i < 0 || i >= addr.BitLen() { + return false + } + if addr.Is4() { + b := addr.As4() + byteIdx := i / 8 + bitIdx := 7 - (i % 8) // MSB first + return (b[byteIdx] & (1 << bitIdx)) != 0 + } + b := addr.As16() + byteIdx := i / 8 + bitIdx := 7 - (i % 8) + return (b[byteIdx] & (1 << bitIdx)) != 0 +} + +// setBit returns a new Addr with the i-th bit set to value +func setBit(addr netip.Addr, i int, value bool) netip.Addr { + if addr.Is4() { + b := addr.As4() + byteIdx := i / 8 + bitIdx := 7 - (i % 8) + if value { + b[byteIdx] |= 1 << bitIdx + } else { + b[byteIdx] &^= 1 << bitIdx + } + return netip.AddrFrom4(b) + } + b := addr.As16() + byteIdx := i / 8 + bitIdx := 7 - (i % 8) + if value { + b[byteIdx] |= 1 << bitIdx + } else { + b[byteIdx] &^= 1 << bitIdx + } + return netip.AddrFrom16(b) +} + +// flipBit returns a new Addr with the i-th bit flipped +func flipBit(addr netip.Addr, i int) netip.Addr { + return setBit(addr, i, !getBit(addr, i)) +} + +func (r *windowsRouter) Close() error { + if r.prevConfig != nil { + dns.RevertDNS(r.luid, r.prevConfig.HasAnyDefaultRoute(), r.originalSearchDomains, r.logger) + } + + r.Set(nil) + + r.logger.Verbosef("Router closed") + return nil +} + +// configureInterface uses the split route specificity approach to prevent routing loops +func (r *windowsRouter) configureInterface(cfg *router.Config) error { + iface, err := interfaceFromLUID(r.luid, winipcfg.GAAFlagIncludeAllInterfaces) + if err != nil { + return fmt.Errorf("getting interface: %v", err) + } + + _, err = r.setPrivateNetwork() + if err != nil { + r.logger.Verbosef("**WARNING** failed to set private network: %v", err) + } + + ipif4, err := r.luid.IPInterface(windows.AF_INET) + if err != nil && !errors.Is(err, windows.ERROR_NOT_FOUND) { + return fmt.Errorf("getting AF_INET interface: %v", err) + } + ipif6, err := r.luid.IPInterface(windows.AF_INET6) + if err != nil && !errors.Is(err, windows.ERROR_NOT_FOUND) { + return fmt.Errorf("getting AF_INET6 interface: %v", err) + } + + // Set up local tunnel addresses and gateways + var localAddr4, localAddr6 netip.Addr + var gatewayAddr4, gatewayAddr6 netip.Addr + addresses := make([]netip.Prefix, 0, len(cfg.TunnelAddrs)) + for _, addr := range cfg.TunnelAddrs { + if (addr.Addr().Is4() && ipif4 == nil) || (addr.Addr().Is6() && ipif6 == nil) { + continue + } + addresses = append(addresses, addr) + if addr.Addr().Is4() && !gatewayAddr4.IsValid() { + localAddr4 = addr.Addr() + gatewayAddr4 = netip.MustParseAddr("192.0.2.1") + } else if addr.Addr().Is6() && !gatewayAddr6.IsValid() { + localAddr6 = addr.Addr() + gatewayAddr6 = netip.MustParseAddr("fc00::1") + } + } + + var routes []*routeData + foundDefault4 := false + foundDefault6 := false + + for _, route := range cfg.Routes { + if (route.Addr().Is4() && ipif4 == nil) || (route.Addr().Is6() && ipif6 == nil) { + continue + } + + // Initialize IPv6 gateway if needed + if route.Addr().Is6() && !gatewayAddr6.IsValid() { + ip := netip.MustParseAddr("fc00::dead:beef") + addresses = append(addresses, netip.PrefixFrom(ip, ip.BitLen())) + gatewayAddr6 = ip + } + + var gateway, localAddr netip.Addr + if route.Addr().Is4() { + localAddr = localAddr4 + gateway = gatewayAddr4 + } else if route.Addr().Is6() { + localAddr = localAddr6 + gateway = gatewayAddr6 + } + + // split route for higher specificity over default route + if route.Bits() == 0 { + var splits []netip.Prefix + if route.Addr().Is4() { + splits = []netip.Prefix{ + netip.MustParsePrefix("0.0.0.0/1"), + netip.MustParsePrefix("128.0.0.0/1"), + } + foundDefault4 = true + } else { + splits = []netip.Prefix{ + netip.MustParsePrefix("::/1"), + netip.MustParsePrefix("8000::/1"), + } + foundDefault6 = true + } + + for _, p := range splits { + routes = append(routes, &routeData{ + RouteData: winipcfg.RouteData{ + Destination: p, + NextHop: gateway, + Metric: 0, + }, + }) + } + continue + } + + // non-default routes + if route.Addr().Unmap() == localAddr { + continue + } + if route.IsSingleIP() { + gateway = localAddr + } + + routes = append(routes, &routeData{ + RouteData: winipcfg.RouteData{ + Destination: route, + NextHop: gateway, + Metric: 0, + }, + }) + } + + err = syncAddresses(iface, addresses) + if err != nil { + return fmt.Errorf("syncAddresses: %v", err) + } + + slices.SortFunc(routes, (*routeData).Compare) + + var deduplicatedRoutes []*routeData + for i := 0; i < len(routes); i++ { + if i > 0 && routes[i].Destination == routes[i-1].Destination { + continue + } + deduplicatedRoutes = append(deduplicatedRoutes, routes[i]) + } + + iface, err = interfaceFromLUID(r.luid, winipcfg.GAAFlagIncludeAllInterfaces) + if err != nil { + return fmt.Errorf("getting interface after syncAddresses: %v", err) + } + + var errAcc error + err = syncRoutes(iface, deduplicatedRoutes, cfg.TunnelAddrs) + if err != nil { + errAcc = errors.Join(errAcc, err) + } + + if ipif4 != nil { + ipif4, err = r.luid.IPInterface(windows.AF_INET) + if err != nil { + return fmt.Errorf("getting AF_INET interface: %v", err) + } + if foundDefault4 { + ipif4.UseAutomaticMetric = false + ipif4.Metric = 0 + } + ipif4.NLMTU = uint32(cfg.MTU) + err = ipif4.Set() + if err != nil { + errAcc = errors.Join(errAcc, err) + } + } + + if ipif6 != nil { + ipif6, err = r.luid.IPInterface(windows.AF_INET6) + if err != nil { + return fmt.Errorf("getting AF_INET6 interface: %v", err) + } + if foundDefault6 { + ipif6.UseAutomaticMetric = false + ipif6.Metric = 0 + } + ipif6.NLMTU = uint32(cfg.MTU) + ipif6.DadTransmits = 0 + ipif6.RouterDiscoveryBehavior = winipcfg.RouterDiscoveryDisabled + err = ipif6.Set() + if err != nil { + errAcc = errors.Join(errAcc, err) + } + } + + return errAcc +} + +func isIPv6LinkLocal(a netip.Prefix) bool { + return a.Addr().Is6() && a.Addr().IsLinkLocalUnicast() +} + +// ipAdapterUnicastAddressToPrefix converts windows IpAdapterUnicastAddress to netip.Prefix +func ipAdapterUnicastAddressToPrefix(u *windows.IpAdapterUnicastAddress) netip.Prefix { + ip, _ := netip.AddrFromSlice(u.Address.IP()) + return netip.PrefixFrom(ip.Unmap(), int(u.OnLinkPrefixLength)) +} + +// unicastIPNets returns all unicast net.IPNet for ifc interface. +func unicastIPNets(ifc *winipcfg.IPAdapterAddresses) []netip.Prefix { + var nets []netip.Prefix + for addr := ifc.FirstUnicastAddress; addr != nil; addr = addr.Next { + nets = append(nets, ipAdapterUnicastAddressToPrefix(addr)) + } + return nets +} + +func syncAddresses(ifc *winipcfg.IPAdapterAddresses, want []netip.Prefix) error { + got := unicastIPNets(ifc) + add, del := deltaNets(got, want) + var erracc error + ll := make([]netip.Prefix, 0) + for _, a := range del { + if isIPv6LinkLocal(a) { + ll = append(ll, a) + continue + } + err := ifc.LUID.DeleteIPAddress(a) + if err != nil { + erracc = errors.Join(erracc, fmt.Errorf("deleting IP %q: %v", a, err)) + } + } + for _, a := range add { + err := ifc.LUID.AddIPAddress(a) + if err != nil { + erracc = errors.Join(erracc, fmt.Errorf("adding IP %q: %v", a, err)) + } + } + for _, a := range ll { + mib, err := ifc.LUID.IPAddress(a.Addr()) + if err != nil { + erracc = errors.Join(erracc, fmt.Errorf("setting skip-as-source on IP %q: unable to retrieve MIB: %v", a, err)) + continue + } + if !mib.SkipAsSource { + mib.SkipAsSource = true + if err := mib.Set(); err != nil { + erracc = errors.Join(erracc, fmt.Errorf("setting skip-as-source on IP %q: unable to set MIB: %v", a, err)) + } + } + } + return erracc +} + +// routeData wraps winipcfg.RouteData with an additional field that permits +// caching of the associated MibIPForwardRow2; by keeping it around, we can +// avoid unnecessary lookups of information that we already have. +type routeData struct { + winipcfg.RouteData + Row *winipcfg.MibIPforwardRow2 +} + +func (rd *routeData) Compare(other *routeData) int { + v := rd.Destination.Addr().Compare(other.Destination.Addr()) + if v != 0 { + return v + } + b1, b2 := rd.Destination.Bits(), other.Destination.Bits() + if b1 != b2 { + if b1 > b2 { + return -1 + } + return 1 + } + v = rd.NextHop.Compare(other.NextHop) + if v != 0 { + return v + } + if rd.Metric < other.Metric { + return -1 + } else if rd.Metric > other.Metric { + return 1 + } + return 0 +} + +func deltaRouteData(a, b []*routeData) (add, del []*routeData) { + add = make([]*routeData, 0, len(b)) + del = make([]*routeData, 0, len(a)) + slices.SortFunc(a, (*routeData).Compare) + slices.SortFunc(b, (*routeData).Compare) + + i, j := 0, 0 + for i < len(a) && j < len(b) { + switch a[i].Compare(b[j]) { + case -1: + del = append(del, a[i]) + i++ + case 0: + i++ + j++ + case 1: + add = append(add, b[j]) + j++ + } + } + del = append(del, a[i:]...) + add = append(add, b[j:]...) + return +} + +// getInterfaceRoutes returns all the interface's routes. +func getInterfaceRoutes(ifc *winipcfg.IPAdapterAddresses, family winipcfg.AddressFamily) (matches []*winipcfg.MibIPforwardRow2, err error) { + routes, err := winipcfg.GetIPForwardTable2(family) + if err != nil { + return nil, err + } + for i := range routes { + if routes[i].InterfaceLUID == ifc.LUID { + matches = append(matches, &routes[i]) + } + } + return +} + +func getAllInterfaceRoutes(ifc *winipcfg.IPAdapterAddresses) ([]*routeData, error) { + routes4, err := getInterfaceRoutes(ifc, windows.AF_INET) + if err != nil { + return nil, err + } + + routes6, err := getInterfaceRoutes(ifc, windows.AF_INET6) + if err != nil { + // TODO: what if v6 unavailable? + return nil, err + } + + rd := make([]*routeData, 0, len(routes4)+len(routes6)) + for _, r := range routes4 { + rd = append(rd, &routeData{ + RouteData: winipcfg.RouteData{ + Destination: r.DestinationPrefix.Prefix(), + NextHop: r.NextHop.Addr(), + Metric: r.Metric, + }, + Row: r, + }) + } + for _, r := range routes6 { + rd = append(rd, &routeData{ + RouteData: winipcfg.RouteData{ + Destination: r.DestinationPrefix.Prefix(), + NextHop: r.NextHop.Addr(), + Metric: r.Metric, + }, + Row: r, + }) + } + return rd, nil +} + +func filterRoutes(routes []*routeData, dontDelete []netip.Prefix) []*routeData { + ddm := make(map[netip.Prefix]bool) + for _, dd := range dontDelete { + ddm[dd] = true + } + for _, r := range routes { + nr := r.Destination + if !nr.IsValid() { + continue + } + if nr.IsSingleIP() { + continue + } + lastIP := netipx.RangeOfPrefix(nr).To() + ddm[netip.PrefixFrom(lastIP, lastIP.BitLen())] = true + } + filtered := make([]*routeData, 0, len(routes)) + for _, r := range routes { + rr := r.Destination + if rr.IsValid() && ddm[rr] { + continue + } + filtered = append(filtered, r) + } + return filtered +} + +// syncRoutes incrementally sets multiples routes on an interface. +// This avoids a full ifc.FlushRoutes call. +// dontDelete is a list of interface address routes that the +// synchronization logic should never delete. +func syncRoutes(ifc *winipcfg.IPAdapterAddresses, want []*routeData, dontDelete []netip.Prefix) error { + existingRoutes, err := getAllInterfaceRoutes(ifc) + if err != nil { + return err + } + got := filterRoutes(existingRoutes, dontDelete) + + add, del := deltaRouteData(got, want) + + var errs []error + for _, a := range del { + var err error + if a.Row == nil { + // DeleteRoute requires a routing table lookup, so only do that if + // a does not already have the row. + err = ifc.LUID.DeleteRoute(a.Destination, a.NextHop) + } else { + // delete the row directly. + err = a.Row.Delete() + } + if err != nil { + dstStr := a.Destination.String() + if dstStr == "169.254.255.255/32" { + // Issue 785 (Tailscale). Ignore these routes + // failing to delete. Harmless. + continue + } + errs = append(errs, fmt.Errorf("deleting route %v: %v", dstStr, err)) + } + } + + for _, a := range add { + err := ifc.LUID.AddRoute(a.Destination, a.NextHop, a.Metric) + if err != nil { + errs = append(errs, fmt.Errorf("adding route %v: %v", &a.Destination, err)) + } + } + + return errors.Join(errs...) +} + +// deltaNets returns the changes to turn a into b. +func deltaNets(a, b []netip.Prefix) (add, del []netip.Prefix) { + add = make([]netip.Prefix, 0, len(b)) + del = make([]netip.Prefix, 0, len(a)) + sortNets(a) + sortNets(b) + + i, j := 0, 0 + for i < len(a) && j < len(b) { + switch netCompare(a[i], b[j]) { + case -1: + del = append(del, a[i]) + i++ + case 0: + i++ + j++ + case 1: + add = append(add, b[j]) + j++ + default: + panic("unexpected compare result") + } + } + del = append(del, a[i:]...) + add = append(add, b[j:]...) + return +} + +func netCompare(a, b netip.Prefix) int { + aip, bip := a.Addr().Unmap(), b.Addr().Unmap() + v := aip.Compare(bip) + if v != 0 { + return v + } + if a.Bits() == b.Bits() { + return 0 + } + if a.Bits() > b.Bits() { + return -1 + } + return 1 +} + +func sortNets(s []netip.Prefix) { + sort.Slice(s, func(i, j int) bool { + return netCompare(s[i], s[j]) < 0 + }) +} + +func (r *windowsRouter) getGlobalSearchDomains() ([]string, error) { + cmd := exec.Command("powershell", "-Command", "(Get-DnsClientGlobalSetting).SuffixSearchList") + output, err := cmd.Output() + if err != nil { + return nil, err + } + lines := strings.Split(strings.TrimSpace(string(output)), "\r\n") + var domains []string + for _, line := range lines { + trimmed := strings.TrimSpace(line) + if trimmed != "" { + domains = append(domains, trimmed) + } + } + return domains, nil +} + +func (r *windowsRouter) setPrivateNetwork() (bool, error) { + alias := r.iface + + // Check if visible and get current category + cmd := exec.Command("powershell", "-Command", fmt.Sprintf(`Get-NetConnectionProfile -InterfaceAlias "%s" | Select-Object -ExpandProperty NetworkCategory`, alias)) + output, err := cmd.CombinedOutput() + if err != nil { + r.logger.Verbosef("setPrivateNetwork: Get-NetConnectionProfile failed: %v", err) + return false, err + } + + category := strings.TrimSpace(string(output)) + if category == "" { + r.logger.Verbosef("setPrivateNetwork: Adapter not found") + return false, nil + } + + if category == "Private" || category == "DomainAuthenticated" { + r.logger.Verbosef("setPrivateNetwork: Already private/domain, skipping") + return true, nil + } + + // Set to Private + cmd = exec.Command("powershell", "-Command", fmt.Sprintf(`Set-NetConnectionProfile -InterfaceAlias "%s" -NetworkCategory Private`, alias)) + output, err = cmd.CombinedOutput() + if err != nil { + r.logger.Errorf("setPrivateNetwork: Set-NetConnectionProfile failed: %v (output: %s)", err, output) + return false, err + } + + r.logger.Verbosef("setPrivateNetwork: Success") + return true, nil +} + +// interfaceFromLUID returns IPAdapterAddresses with specified LUID. +func interfaceFromLUID(luid winipcfg.LUID, flags winipcfg.GAAFlags) (*winipcfg.IPAdapterAddresses, error) { + addresses, err := winipcfg.GetAdaptersAddresses(windows.AF_UNSPEC, flags) + if err != nil { + return nil, err + } + for _, addr := range addresses { + if addr.LUID == luid { + return addr, nil + } + } + return nil, fmt.Errorf("interfaceFromLUID: interface with LUID %v not found", luid) +} + +// Flush clears the local resolver cache. +// Only Windows has a public dns.Flush, needed in router_windows.go. Other +// platforms like Linux need a different flush implementation depending on +// the DNS manager. There is a FlushCaches method on the manager which +// can be used on all platforms. +func flushCaches() error { + cmd := exec.Command("ipconfig", "/flushdns") + cmd.SysProcAttr = &syscall.SysProcAttr{ + CreationFlags: windows.DETACHED_PROCESS, + } + out, err := cmd.CombinedOutput() + if err != nil { + return fmt.Errorf("%v (output: %s)", err, out) + } + return nil +} diff --git a/tunnel/tools/libwg-go/vpn/router/router.go b/tunnel/tools/libwg-go/vpn/router/router.go new file mode 100644 index 0000000..07a39af --- /dev/null +++ b/tunnel/tools/libwg-go/vpn/router/router.go @@ -0,0 +1,78 @@ +package router + +import ( + "net/netip" + "reflect" + "slices" +) + +// Router is responsible for managing the system network stack. +type Router interface { + // Set updates the OS network stack with a new Config. It may be + // called multiple times with identical Configs, which the + // implementation should handle gracefully. If it is a full tunnel config, kill switch is enabled + // for the duration of the tunnel already independently enabled. + Set(*Config) error + + // Close closes the router, cleaning up routes and disabling kill switch if it was enabled by the router. + Close() error +} + +// Config is the subset of configuration that is relevant to our Router +type Config struct { + // TunnelAddrs are the addresses for the tunnel interface + TunnelAddrs []netip.Prefix + + // DNS configured for the tunnel, falls back to system if not set + DNS []netip.Addr + + SearchDomains []string + + // Routes are the routes that point into the tunnel + // interface. These are the /32 and /128 routes to peers, (AllowedIps). + Routes []netip.Prefix + + // Falls back to WG default if not set + MTU int + + // Generated by system if not set + ListenPort uint16 +} + +func (c *Config) Equal(b *Config) bool { + if c == nil && b == nil { + return true + } + if (c == nil) != (b == nil) { + return false + } + return reflect.DeepEqual(c, b) +} + +func (c *Config) Clone() *Config { + if c == nil { + return nil + } + c2 := *c + c2.TunnelAddrs = slices.Clone(c.TunnelAddrs) + c2.DNS = slices.Clone(c.DNS) + c2.Routes = slices.Clone(c.Routes) + return &c2 +} + +// HasDefaultRoute checks if tunnel is full tunnel +func (c *Config) hasDefaultRoute(v4 bool) bool { + if c == nil { + return false + } + for _, rt := range c.Routes { + if rt.Bits() == 0 && ((v4 && rt.Addr().Is4()) || (!v4 && rt.Addr().Is6())) { + return true + } + } + return false +} + +func (c *Config) HasAnyDefaultRoute() bool { + return c.hasDefaultRoute(true) || c.hasDefaultRoute(false) +} diff --git a/tunnel/tools/libwg-go/vpn/vpn.go b/tunnel/tools/libwg-go/vpn/vpn.go new file mode 100755 index 0000000..d3d5a6b --- /dev/null +++ b/tunnel/tools/libwg-go/vpn/vpn.go @@ -0,0 +1,378 @@ +//go:build !android + +package vpn + +/* +#include +typedef void (*StatusCodeCallback)(int32_t handle, int32_t status); +*/ +import "C" +import ( + "context" + "errors" + "net" + "net/netip" + "os" + "os/signal" + "sync" + "sync/atomic" + "syscall" + + "github.com/amnezia-vpn/amneziawg-go/conn" + "github.com/amnezia-vpn/amneziawg-go/device" + "github.com/amnezia-vpn/amneziawg-go/tun" + wireproxyawg "github.com/artem-russkikh/wireproxy-awg" + "github.com/wgtunnel/desktop/tunnel/constants" + "github.com/wgtunnel/desktop/tunnel/dns" + "github.com/wgtunnel/desktop/tunnel/ipc" + "github.com/wgtunnel/desktop/tunnel/shared" + "github.com/wgtunnel/desktop/tunnel/util" + bind2 "github.com/wgtunnel/desktop/tunnel/vpn/bind" + "github.com/wgtunnel/desktop/tunnel/vpn/firewall" + "github.com/wgtunnel/desktop/tunnel/vpn/firewall/osfirewall/firewallmgr" + "github.com/wgtunnel/desktop/tunnel/vpn/router" + "github.com/wgtunnel/desktop/tunnel/vpn/router/osrouter" +) + +type TunnelHandle struct { + device *device.Device + uapi net.Listener + router router.Router + cancel context.CancelFunc + needsResolving atomic.Bool +} + +var ( + tag = "AwgVPN" + tunnelHandles = make(map[int32]*TunnelHandle) + resolvingHandles = sync.Map{} + logger = shared.NewLogger(tag) +) + +func init() { + // handle shutdown signals + go handleSignals() +} + +func handleSignals() { + sigs := make(chan os.Signal, 1) + signal.Notify(sigs, syscall.SIGINT, syscall.SIGTERM) + <-sigs + awgTurnOffAll() + os.Exit(0) +} + +//export awgTurnOn +func awgTurnOn(settings *C.char, callback C.StatusCodeCallback) C.int { + handleID, err := util.GenerateHandle(tunnelHandles) + if err != nil { + shared.LogError(tag, "Unable to find empty handle", err) + return C.int(-1) + } + + shared.StoreTunnelCallback(handleID, shared.StatusCodeCallback(callback)) + + h := &TunnelHandle{} + var success bool + + defer func() { + if !success { + shared.LogDebug(tag, "Startup failed, cleaning up partial resources for handle %d", handleID) + h.close() + resolvingHandles.Delete(handleID) + } + }() + + goSettings := C.GoString(settings) + conf, err := wireproxyawg.ParseConfigString(goSettings) + if err != nil { + shared.LogError(tag, "Invalid config file", err) + return C.int(-1) + } + + // Create a context to manage resolution goroutines + tunnelCtx, tunnelCancel := context.WithCancel(context.Background()) + h.cancel = tunnelCancel + + // Check for peers needing resolution, but we wait to start resolution until the firewall bypasses are set + type peerToResolve struct { + index int + host string + } + var resolutionQueue []peerToResolve + + for i := range conf.Device.Peers { + peer := &conf.Device.Peers[i] + if peer.NeedsResolution() { + host, port, err := net.SplitHostPort(*peer.Endpoint) + if err != nil { + shared.LogError(tag, "Failed to parse endpoint", err) + continue + } + // set dummy, non-routable address with original port + dummyEndpoint := constants.DummyAddress + ":" + port + peer.Endpoint = &dummyEndpoint + + resolutionQueue = append(resolutionQueue, peerToResolve{i, host}) + } + } + + tunnel, err := tun.CreateTUN(constants.IfaceName, conf.Device.MTU) + if err != nil { + shared.LogError(tag, "Create TUN failed", err) + return C.int(-1) + } + + bind := conn.NewDefaultBind() + if err := bind2.SetupBind(logger, bind); err != nil { + tunnel.Close() + return C.int(-1) + } + + statusCB := func(code device.StatusCode) { + go shared.NotifyStatusCode(handleID, int32(code)) + } + + h.device = device.NewDevice(tunnel, bind, logger, false, statusCB) + + var listenPort uint16 = 0 + if conf.Device.ListenPort != nil { + listenPort = uint16(*conf.Device.ListenPort) + } + + _, port, err := h.device.Bind().Open(listenPort) + if err != nil { + shared.LogError(tag, "Failed to open bind", err) + return C.int(-1) + } + + ifaceName, _ := tunnel.Name() + uapi, err := ipc.SetupIPC(ifaceName) + if err != nil { + shared.LogError(tag, "Setup IPC failed", err) + return C.int(-1) + } + h.uapi = uapi + + go func(d *device.Device, l net.Listener) { + for { + connection, err := l.Accept() + if err != nil { + return + } + go d.IpcHandle(connection) + } + }(h.device, h.uapi) + + ipcRequest, err := wireproxyawg.CreateIPCRequest(conf.Device, false) + if err != nil { + return C.int(-1) + } + if err := h.device.IpcSet(ipcRequest.IpcRequest); err != nil { + return C.int(-1) + } + + fw, err := newFirewall() + if err != nil { + return C.int(-1) + } + + r, err := newRouter(ifaceName, fw, tunnel) + if err != nil { + return C.int(-1) + } + h.router = r + + if err := h.device.Up(); err != nil { + return C.int(-1) + } + + // parse config to router config for router/fw + routerCfg, err := parseToRouterConfig(conf, port) + if err != nil { + return C.int(-1) + } + if err := h.router.Set(routerCfg); err != nil { + return C.int(-1) + } + + // try to resolve DNS to replace our dummy endpoints + for _, p := range resolutionQueue { + go resolveAndUpdatePeer(tunnelCtx, handleID, conf, p.index, p.host) + } + + success = true + tunnelHandles[handleID] = h + shared.LogDebug(tag, "Device started successfully; DNS bypasses active for handle %d", handleID) + + return C.int(handleID) +} + +// resolveAndUpdatePeer resolves the host and updates the peer's endpoint if successful. +func resolveAndUpdatePeer(ctx context.Context, tunnelHandle int32, conf *wireproxyawg.Configuration, peerIndex int, host string) { + + resolvingHandles.Store(tunnelHandle, true) + shared.NotifyStatusCode(tunnelHandle, shared.StatusResolvingDNS) + + select { + case <-ctx.Done(): + shared.LogDebug(tag, "Tunnel context cancelled, stopping resolver for %s", host) + resolvingHandles.Delete(tunnelHandle) + return + default: + } + + opts := dns.DefaultOptions() + // TODO make configurable by user + preferIPv6 := false + + resolved, err := dns.ResolveWithBackoff(ctx, host, opts, preferIPv6, logger) + if err != nil { + shared.LogError(tag, "Permanent failure resolving %s: %v", host, err) + return + } + shared.LogDebug(tag, "Successfully resolved the tunnel peer endpoints..") + + var ip netip.Addr + if preferIPv6 && len(resolved.V6) > 0 { + ip = resolved.V6[0] + shared.LogDebug(tag, "Successfully set peer endpoint to preferred resolved ipv6..") + } else if len(resolved.V4) > 0 { + ip = resolved.V4[0] + shared.LogDebug(tag, "Successfully set peer endpoint to resolved ipv4..") + } else { + shared.LogError(tag, "No suitable IP resolved for %s", host) + return + } + + shared.LogDebug(tag, "Updating config with resolved peer endpoints..") + // Update the peer config's peer endpoint from dummy + peer := &conf.Device.Peers[peerIndex] + if err := peer.UpdateEndpointIP(ip); err != nil { + shared.LogError(tag, "Failed to update endpoint for peer %s: %v", peer.PublicKey, err) + return + } + + // Update peers via UAPI + ipcRequest, err := wireproxyawg.CreatePeerIPCRequest(conf.Device) + if err != nil { + shared.LogError(tag, "CreatePeerIPCRequest: %v", err) + return + } + + handle, ok := tunnelHandles[tunnelHandle] + if !ok || handle.cancel == nil { + shared.LogDebug(tag, "Tunnel down, skipping update for %s", host) + return + } + if err := handle.device.IpcSet(ipcRequest.IpcRequest); err != nil { + shared.LogError(tag, "Failed to update peers: %v", err) + return + } + + shared.LogDebug(tag, "Successfully updated peer with resolved endpoint for %s", host) + resolvingHandles.Delete(tunnelHandle) +} + +func (h *TunnelHandle) close() { + if h == nil { + return + } + + // stop all goroutines + if h.cancel != nil { + h.cancel() + } + + // close UAPI listener + if h.uapi != nil { + _ = h.uapi.Close() + } + + // close router to clean up router and firewall rules + if h.router != nil { + _ = h.router.Close() + } + + // close tun device + if h.device != nil { + h.device.Close() + } +} + +//export awgTurnOff +func awgTurnOff(tunnelHandle C.int) { + id := int32(tunnelHandle) + handle, ok := tunnelHandles[id] + if !ok { + shared.LogError(tag, "Tunnel is not up") + return + } + + delete(tunnelHandles, id) + handle.close() + resolvingHandles.Delete(id) +} + +//export awgGetConfig +func awgGetConfig(tunnelHandle C.int) *C.char { + goTunnelHandle := int32(tunnelHandle) + handle, ok := tunnelHandles[goTunnelHandle] + if !ok { + return nil + } + settings, err := handle.device.IpcGet() + if err != nil { + shared.LogError(tag, "Failed to get device config: %v", err) + return nil + } + return C.CString(settings) +} + +//export awgTurnOffAll +func awgTurnOffAll() { + for handle := range tunnelHandles { + awgTurnOff(C.int(handle)) + } + tunnelHandles = make(map[int32]*TunnelHandle) +} + +func newRouter(iface string, fw firewall.Firewall, tunnel tun.Device) (router.Router, error) { + return osrouter.New(iface, fw, tunnel, shared.NewLogger("Router")) +} + +func newFirewall() (firewall.Firewall, error) { + return firewallmgr.Get() +} + +func parseToRouterConfig(conf *wireproxyawg.Configuration, listenPort uint16) (*router.Config, error) { + device := conf.Device + if device == nil { + return nil, errors.New("no [Interface] section found in config") + } + + cfg := &router.Config{ + MTU: device.MTU, + } + + // Normalize and add tunnel addresses for router + for _, addr := range device.Address { + bitLen := 32 + if addr.Is6() { + bitLen = 128 + } + prefix := netip.PrefixFrom(addr, bitLen).Masked() + cfg.TunnelAddrs = append(cfg.TunnelAddrs, prefix) + } + + cfg.DNS = device.DNS + cfg.SearchDomains = device.SearchDomains + cfg.ListenPort = listenPort + + // Add peer routes (AllowedIPs) to router routes + for _, peer := range device.Peers { + cfg.Routes = append(cfg.Routes, peer.AllowedIPs...) + } + + return cfg, nil +} diff --git a/tunnel/tools/wintun/amd64/wintun.dll b/tunnel/tools/wintun/amd64/wintun.dll new file mode 100644 index 0000000..aee04e7 Binary files /dev/null and b/tunnel/tools/wintun/amd64/wintun.dll differ diff --git a/tunnel/tools/wintun/arm64/wintun.dll b/tunnel/tools/wintun/arm64/wintun.dll new file mode 100644 index 0000000..dc4e4ae Binary files /dev/null and b/tunnel/tools/wintun/arm64/wintun.dll differ