Add gpt-5-3-codex model support to the Databricks provider (#7516)

This commit is contained in:
David Katz
2026-02-25 15:25:34 -05:00
committed by GitHub
parent 785818bb87
commit ca34455e6e
+87 -33
View File
@@ -1,18 +1,28 @@
use anyhow::Result;
use async_stream::try_stream;
use async_trait::async_trait;
use futures::future::BoxFuture;
use futures::{StreamExt, TryStreamExt};
use serde::{Deserialize, Serialize};
use serde_json::Value;
use std::io;
use std::time::Duration;
use tokio::pin;
use tokio_util::codec::{FramedRead, LinesCodec};
use tokio_util::io::StreamReader;
use super::api_client::{ApiClient, AuthMethod, AuthProvider};
use super::base::{ConfigKey, MessageStream, Provider, ProviderDef, ProviderMetadata};
use super::embedding::EmbeddingCapable;
use super::errors::ProviderError;
use super::formats::databricks::create_request;
use super::formats::openai_responses::{
create_responses_request, responses_api_to_streaming_message,
};
use super::oauth;
use super::openai_compatible::{
handle_response_openai_compat, map_http_error_to_provider_error, stream_openai_compat,
handle_response_openai_compat, handle_status_openai_compat, map_http_error_to_provider_error,
stream_openai_compat,
};
use super::retry::ProviderRetry;
use super::utils::{ImageFormat, RequestLog};
@@ -208,9 +218,16 @@ impl DatabricksProvider {
})
}
/// Reports whether `model_name` names a "codex" model.
///
/// The match is an ASCII case-insensitive substring search for `codex`
/// anywhere in the name; non-ASCII bytes are compared verbatim.
fn is_responses_model(model_name: &str) -> bool {
    // The marker is pure ASCII, so a byte-window comparison gives the same
    // answer as lowercasing the whole name and calling `str::contains`.
    const MARKER: &[u8] = b"codex";
    model_name
        .as_bytes()
        .windows(MARKER.len())
        .any(|candidate| candidate.eq_ignore_ascii_case(MARKER))
}
fn get_endpoint_path(&self, model_name: &str, is_embedding: bool) -> String {
if is_embedding {
"serving-endpoints/text-embedding-3-small/invocations".to_string()
} else if Self::is_responses_model(model_name) {
"serving-endpoints/responses".to_string()
} else {
format!("serving-endpoints/{}/invocations", model_name)
}
@@ -282,42 +299,79 @@ impl Provider for DatabricksProvider {
messages: &[Message],
tools: &[Tool],
) -> Result<MessageStream, ProviderError> {
let mut payload =
create_request(model_config, system, messages, tools, &self.image_format)?;
payload
.as_object_mut()
.expect("payload should have model key")
.remove("model");
payload
.as_object_mut()
.unwrap()
.insert("stream".to_string(), Value::Bool(true));
let path = self.get_endpoint_path(&model_config.model_name, false);
let mut log = RequestLog::start(model_config, &payload)?;
let response = self
.with_retry(|| async {
let resp = self
.api_client
.response_post(Some(session_id), &path, &payload)
.await?;
if !resp.status().is_success() {
let status = resp.status();
let error_text = resp.text().await.unwrap_or_default();
// Parse as JSON if possible to pass to map_http_error_to_provider_error
let json_payload = serde_json::from_str::<Value>(&error_text).ok();
return Err(map_http_error_to_provider_error(status, json_payload));
if Self::is_responses_model(&model_config.model_name) {
let mut payload = create_responses_request(model_config, system, messages, tools)?;
payload["stream"] = Value::Bool(true);
let mut log = RequestLog::start(model_config, &payload)?;
let response = self
.with_retry(|| async {
let payload_clone = payload.clone();
let resp = self
.api_client
.response_post(Some(session_id), &path, &payload_clone)
.await?;
handle_status_openai_compat(resp).await
})
.await
.inspect_err(|e| {
let _ = log.error(e);
})?;
let stream = response.bytes_stream().map_err(io::Error::other);
Ok(Box::pin(try_stream! {
let stream_reader = StreamReader::new(stream);
let framed = FramedRead::new(stream_reader, LinesCodec::new()).map_err(anyhow::Error::from);
let message_stream = responses_api_to_streaming_message(framed);
pin!(message_stream);
while let Some(message) = message_stream.next().await {
let (message, usage) = message.map_err(|e| ProviderError::RequestFailed(format!("Stream decode error: {}", e)))?;
log.write(&message, usage.as_ref().map(|f| f.usage).as_ref())?;
yield (message, usage);
}
Ok(resp)
})
.await
.inspect_err(|e| {
let _ = log.error(e);
})?;
}))
} else {
let mut payload =
create_request(model_config, system, messages, tools, &self.image_format)?;
payload
.as_object_mut()
.expect("payload should have model key")
.remove("model");
stream_openai_compat(response, log)
payload
.as_object_mut()
.unwrap()
.insert("stream".to_string(), Value::Bool(true));
let mut log = RequestLog::start(model_config, &payload)?;
let response = self
.with_retry(|| async {
let resp = self
.api_client
.response_post(Some(session_id), &path, &payload)
.await?;
if !resp.status().is_success() {
let status = resp.status();
let error_text = resp.text().await.unwrap_or_default();
// Parse as JSON if possible to pass to map_http_error_to_provider_error
let json_payload = serde_json::from_str::<Value>(&error_text).ok();
return Err(map_http_error_to_provider_error(status, json_payload));
}
Ok(resp)
})
.await
.inspect_err(|e| {
let _ = log.error(e);
})?;
stream_openai_compat(response, log)
}
}
fn supports_embeddings(&self) -> bool {