From 62da1e76826276b493ce7f8a9581d482cd7c16ee Mon Sep 17 00:00:00 2001 From: Aiden Cline <63023139+rekram1-node@users.noreply.github.com> Date: Wed, 27 May 2026 15:02:37 -0500 Subject: [PATCH] feat(openai): add responses websocket transport (#29477) --- bun.lock | 10 +- packages/opencode/package.json | 2 + packages/opencode/src/effect/runtime-flags.ts | 1 + packages/opencode/src/plugin/index.ts | 37 +- packages/opencode/src/plugin/openai/README.md | 31 + .../opencode/src/plugin/{ => openai}/codex.ts | 61 +- .../opencode/src/plugin/openai/ws-pool.ts | 247 +++++++ packages/opencode/src/plugin/openai/ws.ts | 315 +++++++++ .../test/effect/runtime-flags.test.ts | 11 + packages/opencode/test/plugin/codex.test.ts | 20 +- .../test/plugin/openai-rollout.test.ts | 17 + .../opencode/test/plugin/openai-ws.test.ts | 619 ++++++++++++++++++ 12 files changed, 1326 insertions(+), 45 deletions(-) create mode 100644 packages/opencode/src/plugin/openai/README.md rename packages/opencode/src/plugin/{ => openai}/codex.ts (91%) create mode 100644 packages/opencode/src/plugin/openai/ws-pool.ts create mode 100644 packages/opencode/src/plugin/openai/ws.ts create mode 100644 packages/opencode/test/plugin/openai-rollout.test.ts create mode 100644 packages/opencode/test/plugin/openai-ws.test.ts diff --git a/bun.lock b/bun.lock index a5764fc15f..aee242b042 100644 --- a/bun.lock +++ b/bun.lock @@ -478,6 +478,7 @@ "@solid-primitives/event-bus": "1.1.2", "@solid-primitives/scheduled": "1.5.2", "@standard-schema/spec": "1.0.0", + "@types/ws": "8.18.1", "@zip.js/zip.js": "2.7.62", "ai": "catalog:", "ai-gateway-provider": "3.1.2", @@ -519,6 +520,7 @@ "vscode-jsonrpc": "8.2.1", "web-tree-sitter": "0.25.10", "which": "6.0.1", + "ws": "8.21.0", "xdg-basedir": "5.1.0", "yargs": "18.0.0", "zod": "catalog:", @@ -5138,7 +5140,7 @@ "write-file-atomic": ["write-file-atomic@7.0.1", "", { "dependencies": { "signal-exit": "^4.0.1" } }, "sha512-OTIk8iR8/aCRWBqvxrzxR0hgxWpnYBblY1S5hDWBQfk/VFmJwzmJgQFN3WsoUKHISv2eAwe+PpbUzyL1CKTLXg=="], - "ws": ["ws@8.18.0", "", { "peerDependencies": { "bufferutil": "^4.0.1", "utf-8-validate": ">=5.0.2" }, "optionalPeers": ["bufferutil", "utf-8-validate"] }, "sha512-8VbfWfHLbbwu3+N6OKsOMpBdT4kXPDDB9cJk2bJ6mh9ucxdlnNvH1e+roYkKmN9Nxw2yjz7VzeO9oOz2zJ04Pw=="], + "ws": ["ws@8.21.0", "", { "peerDependencies": { "bufferutil": "^4.0.1", "utf-8-validate": ">=5.0.2" }, "optionalPeers": ["bufferutil", "utf-8-validate"] }, "sha512-Vsp28b7DRcimFQvrqu2Wek3z1iYxDCWqHYB8Qsnk/S4RfaCQzPGPyBNuVjJV3cd6UiKtUtp6sNM77gWvzcCH+g=="], "wsl-utils": ["wsl-utils@0.3.1", "", { "dependencies": { "is-wsl": "^3.1.0", "powershell-utils": "^0.1.0" } }, "sha512-g/eziiSUNBSsdDJtCLB8bdYEUMj4jR7AGeUo96p/3dTafgjHhpF4RiCFPiRILwjQoDXx5MqkBr4fwWtR3Ky4Wg=="], @@ -5496,6 +5498,8 @@ "@cloudflare/kv-asset-handler/mime": ["mime@3.0.0", "", { "bin": { "mime": "cli.js" } }, "sha512-jSCU7/VB1loIWBZe14aEYHU/+1UMEHoaO7qxCOVJOw9GgH72VAWppxNcjU+x9a2k3GSIBXNKxXQFqRvvZ7vr3A=="], + "@cloudflare/vite-plugin/ws": ["ws@8.18.0", "", { "peerDependencies": { "bufferutil": "^4.0.1", "utf-8-validate": ">=5.0.2" }, "optionalPeers": ["bufferutil", "utf-8-validate"] }, "sha512-8VbfWfHLbbwu3+N6OKsOMpBdT4kXPDDB9cJk2bJ6mh9ucxdlnNvH1e+roYkKmN9Nxw2yjz7VzeO9oOz2zJ04Pw=="], + "@cspotcode/source-map-support/@jridgewell/trace-mapping": ["@jridgewell/trace-mapping@0.3.9", "", { "dependencies": { "@jridgewell/resolve-uri": "^3.0.3", "@jridgewell/sourcemap-codec": "^1.4.10" } }, "sha512-3Belt6tdc8bPgAtbcmdtNJlirVoTmEb5e2gC94PnkwEW9jI6CAHUeoG85tjWP5WquqfavoMtMwiG4P926ZKKuQ=="], "@develar/schema-utils/ajv": ["ajv@6.14.0", "", { "dependencies": { "fast-deep-equal": "^3.1.1", "fast-json-stable-stringify": "^2.0.0", "json-schema-traverse": "^0.4.1", "uri-js": "^4.2.2" } }, "sha512-IWrosm/yrn43eiKqkfkHis7QioDleaXQHdDVPKg0FSwwd/DuvyX79TZnFOnYpB7dcsFAMmtFztZuXPDvSePkFw=="], @@ -5958,6 +5962,8 @@ "miniflare/undici": ["undici@7.14.0", "", {}, "sha512-Vqs8HTzjpQXZeXdpsfChQTlafcMQaaIwnGwLam1wudSSjlJeQ3bw1j+TLPePgrCnCpUXx7Ba5Pdpf5OBih62NQ=="], + "miniflare/ws": ["ws@8.18.0", "", { "peerDependencies": { "bufferutil": "^4.0.1", "utf-8-validate": ">=5.0.2" }, "optionalPeers": ["bufferutil", "utf-8-validate"] }, "sha512-8VbfWfHLbbwu3+N6OKsOMpBdT4kXPDDB9cJk2bJ6mh9ucxdlnNvH1e+roYkKmN9Nxw2yjz7VzeO9oOz2zJ04Pw=="], + "miniflare/zod": ["zod@3.22.3", "", {}, "sha512-EjIevzuJRiRPbVH4mGc8nApb/lVLKVpmUhAaR5R5doKGfAnGJ6Gr3CViAVjP+4FWSxCsybeWQdcgCtbX+7oZug=="], "minipass-flush/minipass": ["minipass@3.3.6", "", { "dependencies": { "yallist": "^4.0.0" } }, "sha512-DxiNidxSEK+tHG6zOIklvNOwm3hvCrbUrdtzY74U6HKTJxvIDfOUL5W5P2Ghd3DTkhhKPYGqeNUIh5qcM4YBfw=="], @@ -6074,6 +6080,8 @@ "storybook/open": ["open@10.2.0", "", { "dependencies": { "default-browser": "^5.2.1", "define-lazy-prop": "^3.0.0", "is-inside-container": "^1.0.0", "wsl-utils": "^0.1.0" } }, "sha512-YgBpdJHPyQ2UE5x+hlSXcnejzAvD0b22U2OuAP+8OnlJT+PjWPxtgmGqKKc+RgTM63U9gN0YzrYc71R2WT/hTA=="], + "storybook/ws": ["ws@8.18.0", "", { "peerDependencies": { "bufferutil": "^4.0.1", "utf-8-validate": ">=5.0.2" }, "optionalPeers": ["bufferutil", "utf-8-validate"] }, "sha512-8VbfWfHLbbwu3+N6OKsOMpBdT4kXPDDB9cJk2bJ6mh9ucxdlnNvH1e+roYkKmN9Nxw2yjz7VzeO9oOz2zJ04Pw=="], + "storybook-solidjs-vite/vite-plugin-solid": ["vite-plugin-solid@2.11.12", "", { "dependencies": { "@babel/core": "^7.23.3", "@types/babel__core": "^7.20.4", "babel-preset-solid": "^1.8.4", "merge-anything": "^5.1.7", "solid-refresh": "^0.6.3", "vitefu": "^1.0.4" }, "peerDependencies": { "@testing-library/jest-dom": "^5.16.6 || ^5.17.0 || ^6.*", "solid-js": "^1.7.2", "vite": "^3.0.0 || ^4.0.0 || ^5.0.0 || ^6.0.0 || ^7.0.0 || ^8.0.0" }, "optionalPeers": ["@testing-library/jest-dom"] }, "sha512-FgjPcx2OwX9h6f28jli7A4bG7PP3te8uyakE5iqsmpq3Jqi1TWLgSroC9N6cMfGRU2zXsl4Q6ISvTr2VL0QHpA=="], "string-width-cjs/emoji-regex": ["emoji-regex@8.0.0", "", {}, "sha512-MSjYzcWNOA0ewAHpz0MxpYFvwg6yjy1NG3xteoqz644VCo/RPgnr1/GGt+ic3iJTzQ8Eu3TdM14SawnVUmGE6A=="], diff --git a/packages/opencode/package.json b/packages/opencode/package.json index 991a6c4571..4dcc3da664 100644 --- a/packages/opencode/package.json +++ b/packages/opencode/package.json @@ -122,6 +122,7 @@ "@solid-primitives/event-bus": "1.1.2", "@solid-primitives/scheduled": "1.5.2", "@standard-schema/spec": "1.0.0", + "@types/ws": "8.18.1", "@zip.js/zip.js": "2.7.62", "ai": "catalog:", "ai-gateway-provider": "3.1.2", @@ -163,6 +164,7 @@ "vscode-jsonrpc": "8.2.1", "web-tree-sitter": "0.25.10", "which": "6.0.1", + "ws": "8.21.0", "xdg-basedir": "5.1.0", "yargs": "18.0.0", "zod": "catalog:" diff --git a/packages/opencode/src/effect/runtime-flags.ts b/packages/opencode/src/effect/runtime-flags.ts index b12a5d5707..e0360ea722 100644 --- a/packages/opencode/src/effect/runtime-flags.ts +++ b/packages/opencode/src/effect/runtime-flags.ts @@ -54,6 +54,7 @@ export class Service extends ConfigService.Service()("@opencode/Runtime outputTokenMax: positiveInteger("OPENCODE_EXPERIMENTAL_OUTPUT_TOKEN_MAX"), bashDefaultTimeoutMs: positiveInteger("OPENCODE_EXPERIMENTAL_BASH_DEFAULT_TIMEOUT_MS"), experimentalNativeLlm: bool("OPENCODE_EXPERIMENTAL_NATIVE_LLM"), + experimentalWebSockets: bool("OPENCODE_EXPERIMENTAL_WEBSOCKETS"), client: Config.string("OPENCODE_CLIENT").pipe(Config.withDefault("cli")), }) {} diff --git a/packages/opencode/src/plugin/index.ts b/packages/opencode/src/plugin/index.ts index 793a072b29..717dff8db2 100644 --- a/packages/opencode/src/plugin/index.ts +++ b/packages/opencode/src/plugin/index.ts @@ -10,7 +10,7 @@ import { Bus } from "../bus" import * as Log from "@opencode-ai/core/util/log" import { createOpencodeClient } from "@opencode-ai/sdk" import { ServerAuth } from "@/server/auth" -import { CodexAuthPlugin } from "./codex" +import { CodexAuthPlugin } from "./openai/codex" import { Session } from "@/session/session" import { NamedError } from "@opencode-ai/core/util/error" import { CopilotAuthPlugin } from "./github-copilot/copilot" @@ -29,6 +29,7 @@ import { parsePluginSpecifier, readPluginId, readV1Plugin, resolvePluginId } fro import { registerAdapter } from "@/control-plane/adapters" import type { WorkspaceAdapter } from "@/control-plane/types" import { RuntimeFlags } from "@/effect/runtime-flags" +import { InstallationChannel } from "@opencode-ai/core/installation/version" const log = Log.create({ service: "plugin" }) @@ -57,18 +58,28 @@ export interface Interface { export class Service extends Context.Service()("@opencode/Plugin") {} +export function experimentalWebSocketsEnabled(input: { enabled: boolean; channel?: string }) { + return input.enabled || ["local", "dev", "beta"].includes(input.channel ?? InstallationChannel) +} + // Built-in plugins that are directly imported (not installed from npm) -const INTERNAL_PLUGINS: PluginInstance[] = [ - CodexAuthPlugin, - CopilotAuthPlugin, - GitlabAuthPlugin, - PoeAuthPlugin, - CloudflareWorkersAuthPlugin, - CloudflareAIGatewayAuthPlugin, - AzureAuthPlugin, - DigitalOceanAuthPlugin, - XaiAuthPlugin, -] +function internalPlugins(flags: RuntimeFlags.Info): PluginInstance[] { + return [ + // Temporary rollout: pre-release builds use WebSockets by default; releases require explicit opt-in. + (input) => + CodexAuthPlugin(input, { + experimentalWebSockets: experimentalWebSocketsEnabled({ enabled: flags.experimentalWebSockets }), + }), + CopilotAuthPlugin, + GitlabAuthPlugin, + PoeAuthPlugin, + CloudflareWorkersAuthPlugin, + CloudflareAIGatewayAuthPlugin, + AzureAuthPlugin, + DigitalOceanAuthPlugin, + XaiAuthPlugin, + ] +} function isServerPlugin(value: unknown): value is PluginInstance { return typeof value === "function" @@ -151,7 +162,7 @@ export const layer = Layer.effect( $: typeof Bun === "undefined" ? undefined : Bun.$, } - for (const plugin of flags.disableDefaultPlugins ? [] : INTERNAL_PLUGINS) { + for (const plugin of flags.disableDefaultPlugins ? [] : internalPlugins(flags)) { log.info("loading internal plugin", { name: plugin.name }) const init = yield* Effect.tryPromise({ try: () => plugin(input), diff --git a/packages/opencode/src/plugin/openai/README.md b/packages/opencode/src/plugin/openai/README.md new file mode 100644 index 0000000000..c359f8c60c --- /dev/null +++ b/packages/opencode/src/plugin/openai/README.md @@ -0,0 +1,31 @@ +# OpenAI Responses WebSocket + +Enabled by default on `local`, `dev`, and `beta`. On `latest` and `prod`, set `OPENCODE_EXPERIMENTAL_WEBSOCKETS=true`. + +## Flow + +1. A streamed `POST /responses` request arrives. +2. If it has no `session-id` or `x-session-affinity` header, use HTTP. +3. Title requests use HTTP. +4. If that session's socket is busy or already in fallback mode, use HTTP. +5. Otherwise, reuse its open socket or open a new one. +6. Send `response.create` and return WebSocket events as SSE. + +## Lifetime + +- Connect timeout: 15 seconds. +- Idle timeout: 5 minutes. +- After a completed response, keep the socket for reuse. +- Reuse a socket for up to 55 minutes, then replace it on the next request. + +## Retries + +- If WebSocket setup fails or it fails before its first event, replay over HTTP and keep that session on HTTP until idle-pruned. +- If the server returns `websocket_connection_limit_reached` before output, reconnect up to 5 times, then follow the same HTTP fallback. +- If a WebSocket fails after its first event, fail the stream. Do not replay partial output. +- Abort or cancel closes the socket. + +## Next Steps + +- `previous_response_id` continuation. +- Optional second WebSocket for concurrent requests in one session. Currently these use HTTP. diff --git a/packages/opencode/src/plugin/codex.ts b/packages/opencode/src/plugin/openai/codex.ts similarity index 91% rename from packages/opencode/src/plugin/codex.ts rename to packages/opencode/src/plugin/openai/codex.ts index bbd65e0f00..b6edf173c8 100644 --- a/packages/opencode/src/plugin/codex.ts +++ b/packages/opencode/src/plugin/openai/codex.ts @@ -1,10 +1,11 @@ import type { Hooks, PluginInput } from "@opencode-ai/plugin" import * as Log from "@opencode-ai/core/util/log" import { InstallationVersion } from "@opencode-ai/core/installation/version" -import { OAUTH_DUMMY_KEY } from "../auth" +import { OAUTH_DUMMY_KEY } from "../../auth" import os from "os" import { setTimeout as sleep } from "node:timers/promises" import { createServer } from "http" +import { OpenAIWebSocketPool } from "./ws-pool" const log = Log.create({ service: "plugin.codex" }) @@ -28,20 +29,12 @@ interface PkceCodes { } async function generatePKCE(): Promise { - const verifier = generateRandomString(43) - const encoder = new TextEncoder() - const data = encoder.encode(verifier) - const hash = await crypto.subtle.digest("SHA-256", data) - const challenge = base64UrlEncode(hash) - return { verifier, challenge } -} - -function generateRandomString(length: number): string { const chars = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-._~" - const bytes = crypto.getRandomValues(new Uint8Array(length)) - return Array.from(bytes) + const verifier = Array.from(crypto.getRandomValues(new Uint8Array(43))) .map((b) => chars[b % chars.length]) .join("") + const challenge = base64UrlEncode(await crypto.subtle.digest("SHA-256", new TextEncoder().encode(verifier))) + return { verifier, challenge } } function base64UrlEncode(buffer: ArrayBuffer): string { @@ -50,10 +43,6 @@ function base64UrlEncode(buffer: ArrayBuffer): string { return btoa(binary).replace(/\+/g, "-").replace(/\//g, "_").replace(/=+$/, "") } -function generateState(): string { - return base64UrlEncode(crypto.getRandomValues(new Uint8Array(32)).buffer) -} - export interface IdTokenClaims { chatgpt_account_id?: string organizations?: Array<{ id: string }> @@ -120,6 +109,7 @@ interface TokenResponse { interface CodexAuthPluginOptions { issuer?: string codexApiEndpoint?: string + experimentalWebSockets?: boolean } async function exchangeCodeForTokens(code: string, redirectUri: string, pkce: PkceCodes): Promise { @@ -371,8 +361,14 @@ function waitForOAuthCallback(pkce: PkceCodes, state: string): Promise { const issuer = options.issuer ?? ISSUER const codexApiEndpoint = options.codexApiEndpoint ?? CODEX_API_ENDPOINT + let websocketFetchInstalled = false + const websocketFetches: Array> = [] return { + async dispose() { + for (const websocketFetch of websocketFetches) websocketFetch.close() + websocketFetches.length = 0 + }, provider: { id: "openai", async models(provider, ctx) { @@ -410,7 +406,14 @@ export async function CodexAuthPlugin(input: PluginInput, options: CodexAuthPlug provider: "openai", async loader(getAuth) { const auth = await getAuth() - if (auth.type !== "oauth") return {} + const websocketFetch = options.experimentalWebSockets + ? OpenAIWebSocketPool.createWebSocketFetch({ httpFetch: fetch }) + : undefined + if (websocketFetch) { + websocketFetches.push(websocketFetch) + websocketFetchInstalled = true + } + if (auth.type !== "oauth") return websocketFetch ? { fetch: websocketFetch } : {} let refreshPromise: | Promise<{ @@ -422,7 +425,6 @@ export async function CodexAuthPlugin(input: PluginInput, options: CodexAuthPlug return { apiKey: OAUTH_DUMMY_KEY, async fetch(requestInput: RequestInfo | URL, init?: RequestInit) { - // Remove dummy API key authorization header if (init?.headers) { if (init.headers instanceof Headers) { init.headers.delete("authorization") @@ -436,12 +438,11 @@ export async function CodexAuthPlugin(input: PluginInput, options: CodexAuthPlug } const currentAuth = await getAuth() - if (currentAuth.type !== "oauth") return fetch(requestInput, init) + if (currentAuth.type !== "oauth") + return websocketFetch ? websocketFetch(requestInput, init) : fetch(requestInput, init) - // Cast to include accountId field const authWithAccount = currentAuth as typeof currentAuth & { accountId?: string } - // Check if token needs refresh if (!currentAuth.access || currentAuth.expires < Date.now()) { if (!refreshPromise) { log.info("refreshing codex access token") @@ -473,7 +474,6 @@ export async function CodexAuthPlugin(input: PluginInput, options: CodexAuthPlug authWithAccount.accountId = refreshed.accountId } - // Build headers const headers = new Headers() if (init?.headers) { if (init.headers instanceof Headers) { @@ -488,16 +488,11 @@ export async function CodexAuthPlugin(input: PluginInput, options: CodexAuthPlug } } } - - // Set authorization header with access token headers.set("authorization", `Bearer ${currentAuth.access}`) - - // Set ChatGPT-Account-Id header for organization subscriptions if (authWithAccount.accountId) { headers.set("ChatGPT-Account-Id", authWithAccount.accountId) } - // Rewrite URL to Codex endpoint const parsed = requestInput instanceof URL ? requestInput @@ -507,10 +502,12 @@ export async function CodexAuthPlugin(input: PluginInput, options: CodexAuthPlug ? new URL(codexApiEndpoint) : parsed - return fetch(url, { + const requestInit = { ...init, headers, - }) + } + if (websocketFetch && parsed.pathname.includes("/v1/responses")) return websocketFetch(url, requestInit) + return fetch(url, OpenAIWebSocketPool.withoutInternalHeaders(requestInit)) }, } }, @@ -521,7 +518,7 @@ export async function CodexAuthPlugin(input: PluginInput, options: CodexAuthPlug authorize: async () => { const { redirectUri } = await startOAuthServer() const pkce = await generatePKCE() - const state = generateState() + const state = base64UrlEncode(crypto.getRandomValues(new Uint8Array(32)).buffer) const authUrl = buildAuthorizeUrl(redirectUri, pkce, state) const callbackPromise = waitForOAuthCallback(pkce, state) @@ -639,6 +636,10 @@ export async function CodexAuthPlugin(input: PluginInput, options: CodexAuthPlug output.headers.originator = "opencode" output.headers["User-Agent"] = `opencode/${InstallationVersion} (${os.platform()} ${os.release()}; ${os.arch()})` output.headers["session-id"] = input.sessionID + // Temporary fetch-layer hack: title generation currently shares the conversation + // session ID, so the OpenAI plugin marks it for HTTP fallback until transport + // context can be passed directly instead of smuggled through headers. + if (websocketFetchInstalled && input.agent === "title") output.headers[OpenAIWebSocketPool.TITLE_HEADER] = "true" }, "chat.params": async (input, output) => { if (input.model.providerID !== "openai") return diff --git a/packages/opencode/src/plugin/openai/ws-pool.ts b/packages/opencode/src/plugin/openai/ws-pool.ts new file mode 100644 index 0000000000..ae18eb5407 --- /dev/null +++ b/packages/opencode/src/plugin/openai/ws-pool.ts @@ -0,0 +1,247 @@ +import WebSocket from "ws" +import * as Log from "@opencode-ai/core/util/log" +import { isRecord } from "@/util/record" +import { OpenAIWebSocket } from "./ws" + +export const TITLE_HEADER = "x-opencode-title" + +const log = Log.create({ service: "plugin.openai.ws" }) + +export interface CreateWebSocketFetchOptions { + httpFetch?: typeof globalThis.fetch + url?: string + connectTimeout?: number + idleTimeout?: number + maxConnectionAge?: number + connectionLimitRetries?: number +} + +interface PoolEntry { + socket?: WebSocket + connectedAt?: number + lastUsedAt: number + busy: boolean + fallback: boolean +} + +const DEFAULT_CONNECT_TIMEOUT = 15_000 +const DEFAULT_IDLE_TIMEOUT = 5 * 60 * 1000 +const DEFAULT_MAX_CONNECTION_AGE = 55 * 60 * 1000 +const CONNECTION_LIMIT_REACHED_CODE = "websocket_connection_limit_reached" + +export function createWebSocketFetch(options?: CreateWebSocketFetchOptions) { + const httpFetch = options?.httpFetch ?? globalThis.fetch + const pool = new Map() + const connectTimeout = options?.connectTimeout ?? DEFAULT_CONNECT_TIMEOUT + const idleTimeout = options?.idleTimeout ?? DEFAULT_IDLE_TIMEOUT + const maxConnectionAge = options?.maxConnectionAge ?? DEFAULT_MAX_CONNECTION_AGE + const connectionLimitRetries = options?.connectionLimitRetries ?? 5 + const pruneTimer = setInterval(() => prune(), Math.min(idleTimeout, 60_000)) + if (typeof pruneTimer === "object" && "unref" in pruneTimer && typeof pruneTimer.unref === "function") { + pruneTimer.unref() + } + + async function websocketFetch(input: RequestInfo | URL, init?: RequestInit): Promise { + const url = input instanceof URL ? input.toString() : typeof input === "string" ? input : input.url + const internalHeaders = OpenAIWebSocket.normalizeHeaders(init?.headers) + const httpInit = withoutInternalHeaders(init) + + if (init?.method !== "POST" || !new URL(url).pathname.endsWith("/responses")) { + return httpFetch(input, httpInit) + } + + const body = (() => { + try { + if (typeof init?.body !== "string") return undefined + const parsed = JSON.parse(init.body) + return typeof parsed === "object" && parsed !== null ? parsed : undefined + } catch { + return undefined + } + })() + if (!body?.stream) return httpFetch(input, httpInit) + if (internalHeaders[TITLE_HEADER] === "true") { + log.debug("http fallback", { reason: "title" }) + return httpFetch(input, httpInit) + } + + const sessionID = internalHeaders["x-session-affinity"] ?? internalHeaders["session-id"] + if (!sessionID) { + log.debug("http fallback", { reason: "missing_session" }) + return httpFetch(input, httpInit) + } + const key = `${sessionID}:conversation` + + const entry = pool.get(key) ?? { lastUsedAt: Date.now(), busy: false, fallback: false } + pool.set(key, entry) + + if (entry.fallback) { + log.debug("http fallback", { key, reason: "fallback_active" }) + return httpFetch(input, httpInit) + } + if (entry.busy) { + log.debug("http fallback", { key, reason: "busy" }) + return httpFetch(input, httpInit) + } + + entry.busy = true + entry.lastUsedAt = Date.now() + try { + let connectionLimitAttempts = 0 + entry.socket = await socket( + entry, + options?.url ?? url, + OpenAIWebSocket.normalizeHeaders(httpInit?.headers), + connectTimeout, + maxConnectionAge, + init?.signal, + ) + let resolveFirstEvent: (started: boolean) => void = () => {} + let rejectFirstEvent: (error: Error) => void = () => {} + const firstEvent = new Promise((resolve, reject) => { + resolveFirstEvent = resolve + rejectFirstEvent = reject + }) + const response = OpenAIWebSocket.streamResponsesWebSocket({ + socket: entry.socket, + body, + idleTimeout, + signal: init?.signal ?? undefined, + onFirstEvent: () => resolveFirstEvent(true), + onTerminal: (event) => { + entry.busy = false + entry.lastUsedAt = Date.now() + if (event.type !== "response.completed" && event.type !== "response.done") { + log.warn("websocket terminal failure", { key, type: event.type }) + invalidate(entry) + } + }, + onConnectionInvalid: (error) => { + log.warn("websocket invalidated", { key, error: error instanceof Error ? error.message : String(error) }) + entry.busy = false + entry.fallback = true + invalidate(entry) + resolveFirstEvent(false) + }, + onAbort: (error) => { + log.debug("websocket aborted", { key }) + entry.busy = false + entry.lastUsedAt = Date.now() + invalidate(entry) + rejectFirstEvent(error) + }, + onRetryableTerminal: async (event) => { + const error = connectionLimitError(event) + if (!error) return undefined + if (connectionLimitAttempts >= connectionLimitRetries) throw error + + connectionLimitAttempts++ + log.warn("websocket connection limit reached", { key, attempt: connectionLimitAttempts }) + invalidate(entry) + entry.socket = await socket( + entry, + options?.url ?? url, + OpenAIWebSocket.normalizeHeaders(httpInit?.headers), + connectTimeout, + maxConnectionAge, + init?.signal, + ) + entry.lastUsedAt = Date.now() + return entry.socket + }, + }) + if (await firstEvent) return response + log.debug("http fallback", { key, reason: "websocket_failed_before_first_event" }) + return httpFetch(input, httpInit) + } catch (error) { + entry.busy = false + entry.lastUsedAt = Date.now() + if (OpenAIWebSocket.isAbortError(error)) { + invalidate(entry) + throw error + } + + entry.fallback = true + log.warn("websocket setup failed", { key, error: error instanceof Error ? error.message : String(error), fallback: "http" }) + invalidate(entry) + return httpFetch(input, httpInit) + } + } + + function prune() { + const now = Date.now() + for (const [key, entry] of pool) { + if (entry.busy) continue + if (now - entry.lastUsedAt < idleTimeout) continue + log.debug("websocket idle prune", { key }) + invalidate(entry) + pool.delete(key) + } + } + + function close() { + log.debug("websocket pool close", { count: pool.size }) + clearInterval(pruneTimer) + for (const entry of pool.values()) invalidate(entry) + pool.clear() + } + + return Object.assign(websocketFetch, { close }) +} + +function connectionLimitError(event: Record) { + if (event.type !== "error" || !isRecord(event.error) || event.error.code !== CONNECTION_LIMIT_REACHED_CODE) return + return new Error(typeof event.error.message === "string" ? event.error.message : CONNECTION_LIMIT_REACHED_CODE) +} + +async function socket( + entry: PoolEntry, + url: string, + headers: Record, + connectTimeout: number, + maxConnectionAge: number, + signal?: AbortSignal | null, +) { + if (entry.socket?.readyState === WebSocket.OPEN && entry.connectedAt && Date.now() - entry.connectedAt < maxConnectionAge) { + return entry.socket + } + + invalidate(entry) + const next = await OpenAIWebSocket.connectResponsesWebSocket({ + url: OpenAIWebSocket.toWebSocketUrl(url), + headers, + timeout: connectTimeout, + signal: signal ?? undefined, + }) + entry.connectedAt = Date.now() + return next +} + +function invalidate(entry: PoolEntry) { + if (entry.socket) { + entry.socket.on("error", () => {}) + entry.socket.terminate() + entry.socket = undefined + } + entry.connectedAt = undefined +} + +export function withoutInternalHeaders(init: T | undefined): T | undefined { + if (!init?.headers) return init + if (init.headers instanceof Headers) { + const headers = new Headers(init.headers) + headers.delete(TITLE_HEADER) + return { ...init, headers } + } + + if (Array.isArray(init.headers)) { + return { ...init, headers: init.headers.filter((item) => item[0].toLowerCase() !== TITLE_HEADER) } + } + + return { + ...init, + headers: Object.fromEntries(Object.entries(init.headers).filter(([key]) => key.toLowerCase() !== TITLE_HEADER)), + } +} + +export * as OpenAIWebSocketPool from "./ws-pool" diff --git a/packages/opencode/src/plugin/openai/ws.ts b/packages/opencode/src/plugin/openai/ws.ts new file mode 100644 index 0000000000..9d434d89f4 --- /dev/null +++ b/packages/opencode/src/plugin/openai/ws.ts @@ -0,0 +1,315 @@ +// Low-level OpenAI Responses WebSocket protocol helpers. Session pooling, +// fallback, and continuation state intentionally live above this file. + +import WebSocket from "ws" + +export const PROTOCOL_HEADER = "responses_websockets=2026-02-06" + +export interface ConnectResponsesWebSocketOptions { + url: string + headers: Record + timeout?: number + signal?: AbortSignal +} + +export interface StreamResponsesWebSocketOptions { + socket: WebSocket + body: Record + idleTimeout?: number + signal?: AbortSignal + onFirstEvent?: () => void + onComplete?: (event: Record) => void + onTerminal?: (event: Record) => void + onRetryableTerminal?: (event: Record) => Promise + onConnectionInvalid?: (error: Error) => void + onAbort?: (error: Error) => void +} + +export function toWebSocketUrl(url: string) { + return url.replace(/^http/, "ws") +} + +export function normalizeHeaders(headers: HeadersInit | undefined): Record { + const result: Record = {} + if (!headers) return result + + if (headers instanceof Headers) { + headers.forEach((value, key) => { + result[key.toLowerCase()] = value + }) + return result + } + + if (Array.isArray(headers)) { + for (const [key, value] of headers) { + result[key.toLowerCase()] = value + } + return result + } + + for (const [key, value] of Object.entries(headers)) { + if (value != null) result[key.toLowerCase()] = value + } + return result +} + +export function isAbortError(error: unknown): error is DOMException { + return error instanceof DOMException && error.name === "AbortError" +} + +export function connectResponsesWebSocket(options: ConnectResponsesWebSocketOptions) { + return new Promise((resolve, reject) => { + if (options.signal?.aborted) { + reject(abortError(options.signal)) + return + } + + const headers: Record = { + ...options.headers, + "openai-beta": options.headers["openai-beta"] ?? PROTOCOL_HEADER, + } + delete headers["content-length"] + + const socket = new WebSocket(options.url, { headers }) + const timeout = options.timeout + ? setTimeout(() => { + cleanup() + socket.on("error", () => {}) + socket.terminate() + reject(new Error("WebSocket connect timed out")) + }, options.timeout) + : undefined + + function cleanup() { + if (timeout) clearTimeout(timeout) + socket.off("open", onOpen) + socket.off("error", onError) + socket.off("close", onClose) + options.signal?.removeEventListener("abort", onAbort) + } + + function onOpen() { + cleanup() + resolve(socket) + } + + function onError(error: Error) { + socket.on("error", () => {}) + cleanup() + reject(error) + } + + function onClose(code: number, reason: Buffer) { + cleanup() + reject(closeError("WebSocket closed before open", code, reason)) + } + + function onAbort() { + cleanup() + socket.on("error", () => {}) + socket.terminate() + reject(abortError(options.signal)) + } + + socket.once("open", onOpen) + socket.once("error", onError) + socket.once("close", onClose) + options.signal?.addEventListener("abort", onAbort, { once: true }) + }) +} + +export function streamResponsesWebSocket(options: StreamResponsesWebSocketOptions) { + const encoder = new TextEncoder() + + let socket = options.socket + let controller: ReadableStreamDefaultController | undefined + let cleanupSocket = () => {} + let completed = false + let emitted = false + let idleTimer: ReturnType | undefined + + function cleanup() { + if (idleTimer) clearTimeout(idleTimer) + cleanupSocket() + options.signal?.removeEventListener("abort", onAbort) + } + + function terminateSocket(target = socket) { + target.on("error", () => {}) + target.terminate() + } + + function closeCompleted() { + cleanup() + controller?.enqueue(encoder.encode("data: [DONE]\n\n")) + controller?.close() + } + + function invalidate(error: Error) { + if (completed) return + completed = true + cleanup() + options.onConnectionInvalid?.(error) + controller?.error(error) + } + + function resetIdleTimeout(message: string) { + if (completed) return + if (!options.idleTimeout) return + if (idleTimer) clearTimeout(idleTimer) + idleTimer = setTimeout(() => invalidate(new Error(message)), options.idleTimeout) + if (typeof idleTimer === "object" && "unref" in idleTimer && typeof idleTimer.unref === "function") { + idleTimer.unref() + } + } + + async function onMessage(data: WebSocket.RawData, isBinary: boolean) { + if (completed) return + if (isBinary) { + invalidate(new Error("Unexpected binary WebSocket frame")) + return + } + + const text = data.toString() + const event = (() => { + try { + const parsed = JSON.parse(text) + return typeof parsed === "object" && parsed !== null ? parsed : undefined + } catch { + return undefined + } + })() + + if (event?.type === "error" && !emitted && options.onRetryableTerminal) { + cleanupSocket() + if (idleTimer) clearTimeout(idleTimer) + idleTimer = undefined + try { + const next = await options.onRetryableTerminal(event) + if (completed) { + if (next) terminateSocket(next) + return + } + if (next) { + attach(next) + return + } + } catch (error) { + invalidate(error instanceof Error ? error : new Error(String(error))) + return + } + } + + if (!emitted) options.onFirstEvent?.() + controller?.enqueue(encoder.encode(`${text.split(/\r?\n/).map((line) => `data: ${line}`).join("\n")}\n\n`)) + emitted = true + resetIdleTimeout("idle timeout waiting for websocket") + + if (!event) return + + if (event.type === "response.completed" || event.type === "response.done") { + completed = true + options.onComplete?.(event) + options.onTerminal?.(event) + closeCompleted() + return + } + + if (event.type === "response.failed" || event.type === "response.incomplete" || event.type === "error") { + completed = true + options.onTerminal?.(event) + closeCompleted() + } + } + + function onError(error: Error) { + invalidate(error) + } + + function onClose(code: number, reason: Buffer) { + if (completed) return + invalidate(closeError("WebSocket closed before response.completed", code, reason)) + } + + function onAbort() { + const error = abortError(options.signal) + if (completed) return + completed = true + cleanup() + terminateSocket() + options.onAbort?.(error) + controller?.error(error) + } + + function onCancel(reason: unknown) { + if (completed) return + completed = true + cleanup() + terminateSocket() + options.onAbort?.(cancelError(reason)) + } + + function attach(next: WebSocket) { + cleanupSocket() + socket = next + socket.on("message", onMessage) + socket.once("error", onError) + socket.once("close", onClose) + cleanupSocket = () => { + socket.off("message", onMessage) + socket.off("error", onError) + socket.off("close", onClose) + } + const { stream: _stream, background: _background, ...payload } = options.body + resetIdleTimeout("idle timeout sending websocket request") + socket.send(JSON.stringify({ type: "response.create", ...payload }), (error) => { + if (completed) return + resetIdleTimeout("idle timeout waiting for websocket") + if (error) invalidate(error) + }) + } + + return new Response( + new ReadableStream({ + start(next) { + controller = next + options.signal?.addEventListener("abort", onAbort, { once: true }) + + if (options.signal?.aborted) { + onAbort() + return + } + + attach(socket) + }, + cancel(reason) { + onCancel(reason) + }, + }), + { + status: 200, + headers: { "content-type": "text/event-stream" }, + }, + ) +} + +function cancelError(reason: unknown) { + if (isAbortError(reason)) return reason + if (reason instanceof Error) return reason + return new DOMException(typeof reason === "string" ? reason : "Aborted", "AbortError") +} + +function abortError(signal: AbortSignal | undefined) { + const reason = signal?.reason + if (isAbortError(reason)) return reason + return new DOMException(reason instanceof Error ? reason.message : "Aborted", "AbortError") +} + +function closeError(message: string, code: number, reason: Buffer) { + const details = [`code ${code}`] + if (code === 1009) details.push("message too big") + if (reason.length > 0) details.push(reason.toString()) + return new Error(`${message} (${details.join(": ")})`) +} + +export * as OpenAIWebSocket from "./ws" diff --git a/packages/opencode/test/effect/runtime-flags.test.ts b/packages/opencode/test/effect/runtime-flags.test.ts index b044d07f23..36579f48ea 100644 --- a/packages/opencode/test/effect/runtime-flags.test.ts +++ b/packages/opencode/test/effect/runtime-flags.test.ts @@ -63,6 +63,7 @@ describe("RuntimeFlags", () => { expect(flags.experimentalWorkspaces).toBe(true) expect(flags.experimentalIconDiscovery).toBe(true) expect(flags.experimentalNativeLlm).toBe(false) + expect(flags.experimentalWebSockets).toBe(false) expect(flags.client).toBe("desktop") }), ) @@ -91,6 +92,16 @@ describe("RuntimeFlags", () => { }), ) + it.effect("enables WebSockets via dedicated flag only", () => + Effect.gen(function* () { + const explicit = yield* readFlags.pipe(Effect.provide(fromConfig({ OPENCODE_EXPERIMENTAL_WEBSOCKETS: "true" }))) + const umbrella = yield* readFlags.pipe(Effect.provide(fromConfig({ OPENCODE_EXPERIMENTAL: "true" }))) + + expect(explicit.experimentalWebSockets).toBe(true) + expect(umbrella.experimentalWebSockets).toBe(false) + }), + ) + it.effect("layer accepts partial test overrides and fills defaults from Config definitions", () => Effect.gen(function* () { const flags = yield* readFlags.pipe( diff --git a/packages/opencode/test/plugin/codex.test.ts b/packages/opencode/test/plugin/codex.test.ts index 271bcde99b..a375fe4ee1 100644 --- a/packages/opencode/test/plugin/codex.test.ts +++ b/packages/opencode/test/plugin/codex.test.ts @@ -5,7 +5,7 @@ import { extractAccountIdFromClaims, extractAccountId, type IdTokenClaims, -} from "../../src/plugin/codex" +} from "../../src/plugin/openai/codex" function createTestJwt(payload: object): string { const header = Buffer.from(JSON.stringify({ alg: "none" })).toString("base64url") @@ -122,6 +122,24 @@ describe("plugin.codex", () => { }) }) + test("installs websocket transport only when experimental websockets are enabled", async () => { + const disabled = await CodexAuthPlugin({} as never) + const enabled = await CodexAuthPlugin({} as never, { experimentalWebSockets: true }) + + const disabledOptions = await disabled.auth!.loader!( + async () => ({ type: "api", key: "sk-test" }) as never, + {} as never, + ) + const enabledOptions = await enabled.auth!.loader!( + async () => ({ type: "api", key: "sk-test" }) as never, + {} as never, + ) + + expect(disabledOptions.fetch).toBeUndefined() + expect(enabledOptions.fetch).toBeFunction() + await enabled.dispose?.() + }) + test("deduplicates concurrent Codex token refreshes", async () => { let auth = { type: "oauth" as const, diff --git a/packages/opencode/test/plugin/openai-rollout.test.ts b/packages/opencode/test/plugin/openai-rollout.test.ts new file mode 100644 index 0000000000..1278e1cfbb --- /dev/null +++ b/packages/opencode/test/plugin/openai-rollout.test.ts @@ -0,0 +1,17 @@ +import { describe, expect, test } from "bun:test" +import { experimentalWebSocketsEnabled } from "../../src/plugin" + +describe("plugin.openai.websocket rollout", () => { + test("enables websockets by default only on pre-release channels", () => { + expect(experimentalWebSocketsEnabled({ enabled: false, channel: "local" })).toBe(true) + expect(experimentalWebSocketsEnabled({ enabled: false, channel: "dev" })).toBe(true) + expect(experimentalWebSocketsEnabled({ enabled: false, channel: "beta" })).toBe(true) + expect(experimentalWebSocketsEnabled({ enabled: false, channel: "latest" })).toBe(false) + expect(experimentalWebSocketsEnabled({ enabled: false, channel: "prod" })).toBe(false) + }) + + test("allows releases to opt in through the experimental flag", () => { + expect(experimentalWebSocketsEnabled({ enabled: true, channel: "latest" })).toBe(true) + expect(experimentalWebSocketsEnabled({ enabled: true, channel: "prod" })).toBe(true) + }) +}) diff --git a/packages/opencode/test/plugin/openai-ws.test.ts b/packages/opencode/test/plugin/openai-ws.test.ts new file mode 100644 index 0000000000..65eaca655d --- /dev/null +++ b/packages/opencode/test/plugin/openai-ws.test.ts @@ -0,0 +1,619 @@ +import { describe, expect, test } from "bun:test" +import { EventEmitter } from "node:events" +import type { IncomingMessage } from "node:http" +import net, { type AddressInfo, type Socket } from "node:net" +import WebSocket, { WebSocketServer } from "ws" +import { OpenAIWebSocket } from "../../src/plugin/openai/ws" +import { OpenAIWebSocketPool, TITLE_HEADER } from "../../src/plugin/openai/ws-pool" + +describe("plugin.openai.ws", () => { + test("derives websocket URLs and sends auth plus protocol headers", async () => { + let headers: IncomingMessage["headers"] | undefined + await using server = await createWebSocketServer((_socket, request) => { + headers = request.headers + }) + + const socket = await OpenAIWebSocket.connectResponsesWebSocket({ + url: server.wsUrl, + headers: { authorization: "Bearer test", "content-length": "123" }, + }) + + expect(OpenAIWebSocket.toWebSocketUrl("http://example.com/v1/responses")).toBe("ws://example.com/v1/responses") + expect(OpenAIWebSocket.toWebSocketUrl("https://example.com/v1/responses")).toBe("wss://example.com/v1/responses") + expect(headers?.authorization).toBe("Bearer test") + expect(headers?.["openai-beta"]).toBe(OpenAIWebSocket.PROTOCOL_HEADER) + expect(headers?.["content-length"]).toBeUndefined() + socket.terminate() + }) + + test("enforces websocket connect timeout", async () => { + await using server = await createHangingTcpServer() + + await expect( + OpenAIWebSocket.connectResponsesWebSocket({ + url: server.wsUrl, + headers: {}, + timeout: 20, + }), + ).rejects.toThrow("WebSocket connect timed out") + }) + + test("enforces websocket send idle timeout", async () => { + const socket = new (class extends EventEmitter { + send(_data: string, _callback: (error?: Error) => void) {} + })() as unknown as WebSocket + const invalid: string[] = [] + const response = OpenAIWebSocket.streamResponsesWebSocket({ + socket, + body: { stream: true, input: "hi" }, + idleTimeout: 20, + onConnectionInvalid: (error) => invalid.push(error.message), + }) + + await expect(response.text()).rejects.toThrow("idle timeout sending websocket request") + expect(invalid).toEqual(["idle timeout sending websocket request"]) + }) + + test("streams websocket events as SSE and handles response.done", async () => { + let requestBody: unknown + await using server = await createWebSocketServer((socket) => { + socket.once("message", (data) => { + requestBody = JSON.parse(data.toString()) + socket.send(JSON.stringify({ type: "response.output_text.delta", delta: "hello" })) + socket.send(JSON.stringify({ type: "response.done", response: { id: "resp_123" } })) + socket.close(1000, "done") + }) + }) + + const socket = await OpenAIWebSocket.connectResponsesWebSocket({ + url: server.wsUrl, + headers: { authorization: "Bearer test", "content-length": "123" }, + }) + const completed: Record[] = [] + const response = OpenAIWebSocket.streamResponsesWebSocket({ + socket, + body: { stream: true, background: true, input: "hi" }, + onComplete: (event) => completed.push(event), + }) + + expect(await response.text()).toBe( + 'data: {"type":"response.output_text.delta","delta":"hello"}\n\ndata: {"type":"response.done","response":{"id":"resp_123"}}\n\ndata: [DONE]\n\n', + ) + expect(requestBody).toEqual({ type: "response.create", input: "hi" }) + expect(completed).toHaveLength(1) + expect(completed[0]?.type).toBe("response.done") + }) + + test("errors the SSE stream when the server closes before a terminal event", async () => { + const invalid: string[] = [] + await using server = await createWebSocketServer((socket) => { + socket.once("message", () => { + socket.close(1009, "payload too large") + }) + }) + + const socket = await OpenAIWebSocket.connectResponsesWebSocket({ url: server.wsUrl, headers: {} }) + const response = OpenAIWebSocket.streamResponsesWebSocket({ + socket, + body: { stream: true, input: "hi" }, + onConnectionInvalid: (error) => invalid.push(error.message), + }) + + await expect(response.text()).rejects.toThrow( + "WebSocket closed before response.completed (code 1009: message too big: payload too large)", + ) + expect(invalid).toEqual([ + "WebSocket closed before response.completed (code 1009: message too big: payload too large)", + ]) + }) + + test("rejects unexpected binary websocket frames", async () => { + const invalid: string[] = [] + await using server = await createWebSocketServer((socket) => { + socket.once("message", () => { + socket.send(Buffer.from("not json text")) + }) + }) + + const socket = await OpenAIWebSocket.connectResponsesWebSocket({ url: server.wsUrl, headers: {} }) + const response = OpenAIWebSocket.streamResponsesWebSocket({ + socket, + body: { stream: true, input: "hi" }, + onConnectionInvalid: (error) => invalid.push(error.message), + }) + + await expect(response.text()).rejects.toThrow("Unexpected binary WebSocket frame") + expect(invalid).toEqual(["Unexpected binary WebSocket frame"]) + }) +}) + +describe("plugin.openai.ws-pool", () => { + test("reuses one healthy websocket for sequential requests", async () => { + let connections = 0 + let messages = 0 + await using server = await createWebSocketServer((socket) => { + connections += 1 + socket.on("message", () => { + messages += 1 + socket.send(JSON.stringify({ type: "response.completed", response: { id: `resp_${messages}` } })) + }) + }) + const fetch = OpenAIWebSocketPool.createWebSocketFetch({ + url: server.url, + httpFetch: mockFetch(async () => new Response("http")), + }) + + const first = await fetch("https://api.openai.com/v1/responses", streamRequest()) + expect(await first.text()).toContain("data: [DONE]") + + const second = await fetch("https://api.openai.com/v1/responses", streamRequest()) + expect(await second.text()).toContain("data: [DONE]") + expect(connections).toBe(1) + expect(messages).toBe(2) + fetch.close() + }) + + test("rotates a socket that exceeds max connection age", async () => { + let connections = 0 + await using server = await createWebSocketServer((socket) => { + connections += 1 + socket.on("message", () => { + socket.send(JSON.stringify({ type: "response.completed", response: { id: `resp_${connections}` } })) + }) + }) + const fetch = OpenAIWebSocketPool.createWebSocketFetch({ + url: server.url, + httpFetch: mockFetch(async () => new Response("http")), + maxConnectionAge: 0, + }) + + const first = await fetch("https://api.openai.com/v1/responses", streamRequest()) + expect(await first.text()).toContain("data: [DONE]") + + const second = await fetch("https://api.openai.com/v1/responses", streamRequest()) + expect(await second.text()).toContain("data: [DONE]") + expect(connections).toBe(2) + fetch.close() + }) + + test("falls back to HTTP when websocket setup fails and keeps the fallback sticky", async () => { + const attempts: string[] = [] + await using server = await createRejectingWebSocketServer(() => attempts.push("websocket")) + const httpRequests: Headers[] = [] + const fetch = OpenAIWebSocketPool.createWebSocketFetch({ + url: server.url, + httpFetch: mockFetch(async (_input, init) => { + httpRequests.push(new Headers(init?.headers)) + return new Response("http") + }), + connectTimeout: 100, + }) + + const first = await fetch("https://api.openai.com/v1/responses", streamRequest({ [TITLE_HEADER]: "false" })) + const second = await fetch("https://api.openai.com/v1/responses", streamRequest({ [TITLE_HEADER]: "false" })) + + expect(await first.text()).toBe("http") + expect(await second.text()).toBe("http") + expect(attempts).toEqual(["websocket"]) + expect(httpRequests).toHaveLength(2) + expect(httpRequests[0]?.get(TITLE_HEADER)).toBeNull() + expect(httpRequests[1]?.get(TITLE_HEADER)).toBeNull() + fetch.close() + }) + + test("invalidates but does not reuse a socket after terminal failure frames", async () => { + let connections = 0 + await using server = await createWebSocketServer((socket) => { + connections += 1 + socket.once("message", () => { + socket.send(JSON.stringify({ type: connections === 1 ? "response.failed" : "response.completed" })) + }) + }) + const httpRequests: Headers[] = [] + const fetch = OpenAIWebSocketPool.createWebSocketFetch({ + url: server.url, + httpFetch: mockFetch(async (_input, init) => { + httpRequests.push(new Headers(init?.headers)) + return new Response("http") + }), + }) + + const first = await fetch("https://api.openai.com/v1/responses", streamRequest()) + expect(await first.text()).toContain('data: {"type":"response.failed"}') + + const second = await fetch("https://api.openai.com/v1/responses", streamRequest()) + expect(await second.text()).toContain('data: {"type":"response.completed"}') + expect(connections).toBe(2) + expect(httpRequests).toHaveLength(0) + fetch.close() + }) + + test("reconnects and replays after websocket connection limit errors", async () => { + let connections = 0 + let messages = 0 + await using server = await createWebSocketServer((socket) => { + connections += 1 + socket.once("message", () => { + messages += 1 + if (connections === 1) { + socket.send( + JSON.stringify({ + type: "error", + status: 400, + error: { + type: "invalid_request_error", + code: "websocket_connection_limit_reached", + message: "Responses websocket connection limit reached", + }, + }), + ) + return + } + socket.send(JSON.stringify({ type: "response.completed", response: { id: "resp_retry" } })) + }) + }) + const httpRequests: Headers[] = [] + const fetch = OpenAIWebSocketPool.createWebSocketFetch({ + url: server.url, + httpFetch: mockFetch(async (_input, init) => { + httpRequests.push(new Headers(init?.headers)) + return new Response("http") + }), + }) + + const response = await fetch("https://api.openai.com/v1/responses", streamRequest()) + const text = await response.text() + + expect(text).not.toContain("websocket_connection_limit_reached") + expect(text).toContain('data: {"type":"response.completed","response":{"id":"resp_retry"}}') + expect(text).toContain("data: [DONE]") + expect(connections).toBe(2) + expect(messages).toBe(2) + expect(httpRequests).toHaveLength(0) + fetch.close() + }) + + test("falls back to HTTP after websocket connection limit retries are exhausted", async () => { + let connections = 0 + await using server = await createWebSocketServer((socket) => { + connections += 1 + socket.once("message", () => { + socket.send( + JSON.stringify({ + type: "error", + status: 400, + error: { + type: "invalid_request_error", + code: "websocket_connection_limit_reached", + message: "Responses websocket connection limit reached", + }, + }), + ) + }) + }) + let httpRequests = 0 + const fetch = OpenAIWebSocketPool.createWebSocketFetch({ + url: server.url, + connectionLimitRetries: 2, + httpFetch: mockFetch(async () => { + httpRequests += 1 + return new Response("http") + }), + }) + + const first = await fetch("https://api.openai.com/v1/responses", streamRequest()) + const second = await fetch("https://api.openai.com/v1/responses", streamRequest()) + + expect(await first.text()).toBe("http") + expect(await second.text()).toBe("http") + expect(connections).toBe(3) + expect(httpRequests).toBe(2) + fetch.close() + }) + + test("replays over HTTP when websocket idles before its first event", async () => { + let connections = 0 + await using server = await createWebSocketServer((socket) => { + connections += 1 + socket.once("message", () => {}) + }) + const httpRequests: Headers[] = [] + const fetch = OpenAIWebSocketPool.createWebSocketFetch({ + url: server.url, + idleTimeout: 20, + httpFetch: mockFetch(async (_input, init) => { + httpRequests.push(new Headers(init?.headers)) + return new Response("http") + }), + }) + + const first = await fetch("https://api.openai.com/v1/responses", streamRequest()) + expect(await first.text()).toBe("http") + const second = await fetch("https://api.openai.com/v1/responses", streamRequest()) + + expect(await second.text()).toBe("http") + expect(connections).toBe(1) + expect(httpRequests).toHaveLength(2) + fetch.close() + }) + + test("does not replay over HTTP after a websocket event was emitted", async () => { + await using server = await createWebSocketServer((socket) => { + socket.once("message", () => { + socket.send(JSON.stringify({ type: "response.output_text.delta", delta: "started" })) + }) + }) + const httpRequests: Headers[] = [] + const fetch = OpenAIWebSocketPool.createWebSocketFetch({ + url: server.url, + idleTimeout: 20, + httpFetch: mockFetch(async (_input, init) => { + httpRequests.push(new Headers(init?.headers)) + return new Response("http") + }), + }) + + const first = await fetch("https://api.openai.com/v1/responses", streamRequest()) + await expect(first.text()).rejects.toThrow("idle timeout waiting for websocket") + const second = await fetch("https://api.openai.com/v1/responses", streamRequest()) + + expect(await second.text()).toBe("http") + expect(httpRequests).toHaveLength(1) + fetch.close() + }) + + test("falls back to HTTP for missing session and title requests", async () => { + const httpRequests: Headers[] = [] + const fetch = OpenAIWebSocketPool.createWebSocketFetch({ + httpFetch: mockFetch(async (_input, init) => { + httpRequests.push(new Headers(init?.headers)) + return new Response("http") + }), + }) + + const missingSession = await fetch("https://api.openai.com/v1/responses", { + method: "POST", + headers: { [TITLE_HEADER]: "false" }, + body: JSON.stringify({ stream: true }), + }) + const title = await fetch("https://api.openai.com/v1/responses", streamRequest({ [TITLE_HEADER]: "true" })) + + expect(await missingSession.text()).toBe("http") + expect(await title.text()).toBe("http") + expect(httpRequests).toHaveLength(2) + expect(httpRequests[0]?.get(TITLE_HEADER)).toBeNull() + expect(httpRequests[1]?.get(TITLE_HEADER)).toBeNull() + fetch.close() + }) + + test("falls back to HTTP while a websocket lane is busy", async () => { + let connections = 0 + await using server = await createWebSocketServer((socket) => { + connections += 1 + socket.once("message", () => { + socket.send(JSON.stringify({ type: "response.output_text.delta", delta: "started" })) + }) + }) + const abort = new AbortController() + const httpRequests: Headers[] = [] + const fetch = OpenAIWebSocketPool.createWebSocketFetch({ + url: server.url, + httpFetch: mockFetch(async (_input, init) => { + httpRequests.push(new Headers(init?.headers)) + return new Response("http") + }), + }) + + const first = await fetch("https://api.openai.com/v1/responses", streamRequest({}, abort.signal)) + const firstText = first.text() + await waitFor(() => connections === 1, "websocket did not connect") + const second = await fetch("https://api.openai.com/v1/responses", streamRequest()) + + expect(await second.text()).toBe("http") + expect(httpRequests).toHaveLength(1) + expect(connections).toBe(1) + abort.abort(new Error("stop")) + await expect(firstText).rejects.toThrow("stop") + fetch.close() + }) + + test("reserves a websocket lane while its socket is connecting", async () => { + await using server = await createHangingTcpServer() + let httpRequests = 0 + const fetch = OpenAIWebSocketPool.createWebSocketFetch({ + url: server.url, + connectTimeout: 20, + httpFetch: mockFetch(async () => { + httpRequests += 1 + return new Response("http") + }), + }) + + const first = fetch("https://api.openai.com/v1/responses", streamRequest()) + await waitFor(() => server.connections() === 1, "first websocket did not begin connecting") + const second = fetch("https://api.openai.com/v1/responses", streamRequest()) + + expect(await (await second).text()).toBe("http") + expect(await (await first).text()).toBe("http") + expect(server.connections()).toBe(1) + expect(httpRequests).toBe(2) + fetch.close() + }) + + test("replays over HTTP after an unexpected close before the first event", async () => { + let connections = 0 + await using server = await createWebSocketServer((socket) => { + connections += 1 + socket.once("message", () => { + socket.close(1001, "server shutdown") + }) + }) + const httpRequests: Headers[] = [] + const fetch = OpenAIWebSocketPool.createWebSocketFetch({ + url: server.url, + httpFetch: mockFetch(async (_input, init) => { + httpRequests.push(new Headers(init?.headers)) + return new Response("http") + }), + }) + + const first = await fetch("https://api.openai.com/v1/responses", streamRequest()) + expect(await first.text()).toBe("http") + const second = await fetch("https://api.openai.com/v1/responses", streamRequest()) + + expect(await second.text()).toBe("http") + expect(connections).toBe(1) + expect(httpRequests).toHaveLength(2) + fetch.close() + }) + + test("does not keep HTTP fallback active after aborting a websocket response", async () => { + let connections = 0 + await using server = await createWebSocketServer((socket) => { + connections += 1 + socket.once("message", () => { + if (connections === 1) { + socket.send(JSON.stringify({ type: "response.output_text.delta", delta: "started" })) + return + } + socket.send(JSON.stringify({ type: "response.completed", response: { id: "resp_456" } })) + }) + }) + const httpRequests: Headers[] = [] + const abort = new AbortController() + const fetch = OpenAIWebSocketPool.createWebSocketFetch({ + url: server.url, + httpFetch: mockFetch(async (_input, init) => { + httpRequests.push(new Headers(init?.headers)) + return new Response("http") + }), + }) + + const first = await fetch("https://api.openai.com/v1/responses", streamRequest({}, abort.signal)) + const firstText = first.text() + await waitFor(() => connections === 1, "first websocket did not connect") + abort.abort(new Error("stop")) + await expect(firstText).rejects.toThrow("stop") + + const second = await fetch("https://api.openai.com/v1/responses", streamRequest()) + + expect(await second.text()).toContain("data: [DONE]") + expect(connections).toBe(2) + expect(httpRequests).toHaveLength(0) + fetch.close() + }) + + test("releases the websocket lane when the response body is cancelled", async () => { + let connections = 0 + await using server = await createWebSocketServer((socket) => { + connections += 1 + socket.once("message", () => { + if (connections === 1) { + socket.send(JSON.stringify({ type: "response.output_text.delta", delta: "started" })) + return + } + socket.send(JSON.stringify({ type: "response.completed", response: { id: "resp_after_cancel" } })) + }) + }) + const httpRequests: Headers[] = [] + const fetch = OpenAIWebSocketPool.createWebSocketFetch({ + url: server.url, + httpFetch: mockFetch(async (_input, init) => { + httpRequests.push(new Headers(init?.headers)) + return new Response("http") + }), + }) + + const first = await fetch("https://api.openai.com/v1/responses", streamRequest()) + await waitFor(() => connections === 1, "first websocket did not connect") + await first.body!.cancel("stop") + + const second = await fetch("https://api.openai.com/v1/responses", streamRequest()) + + expect(await second.text()).toContain("data: [DONE]") + expect(connections).toBe(2) + expect(httpRequests).toHaveLength(0) + fetch.close() + }) +}) + +function streamRequest(headers?: Record, signal?: AbortSignal): RequestInit { + return { + method: "POST", + headers: { + "session-id": "session-1", + authorization: "Bearer test", + ...headers, + }, + body: JSON.stringify({ stream: true, input: "hi" }), + signal, + } +} + +function mockFetch( + fn: (input: Parameters[0], init: Parameters[1]) => ReturnType, +): typeof globalThis.fetch { + return Object.assign(fn, { preconnect: globalThis.fetch.preconnect }) +} + +async function createWebSocketServer(onConnection: (socket: WebSocket, request: IncomingMessage) => void) { + const server = new WebSocketServer({ host: "127.0.0.1", port: 0 }) + server.on("connection", onConnection) + await new Promise((resolve) => server.once("listening", resolve)) + return websocketServerHandle(server) +} + +async function createHangingTcpServer() { + const sockets = new Set() + let connections = 0 + const server = net.createServer((socket) => { + connections += 1 + sockets.add(socket) + socket.on("close", () => sockets.delete(socket)) + }) + await new Promise((resolve) => server.listen(0, "127.0.0.1", resolve)) + const address = server.address() as AddressInfo + return { + url: `http://127.0.0.1:${address.port}/v1/responses`, + wsUrl: `ws://127.0.0.1:${address.port}/v1/responses`, + connections: () => connections, + async [Symbol.asyncDispose]() { + for (const socket of sockets) socket.destroy() + server.close() + }, + } +} + +async function createRejectingWebSocketServer(onAttempt: () => void) { + const server = new WebSocketServer({ + host: "127.0.0.1", + port: 0, + verifyClient(_info, callback) { + onAttempt() + callback(false, 401, "denied") + }, + }) + await new Promise((resolve) => server.once("listening", resolve)) + return websocketServerHandle(server) +} + +function websocketServerHandle(server: WebSocketServer) { + const address = server.address() as AddressInfo + const url = `http://127.0.0.1:${address.port}/v1/responses` + return { + url, + wsUrl: url.replace(/^http/, "ws"), + async [Symbol.asyncDispose]() { + for (const socket of server.clients) socket.terminate() + server.close() + }, + } +} + +async function waitFor(predicate: () => boolean, message: string) { + const started = Date.now() + while (!predicate()) { + if (Date.now() - started > 1_000) throw new Error(message) + await new Promise((resolve) => setTimeout(resolve, 1)) + } +}