mirror of
https://github.com/block/goose.git
synced 2026-06-02 06:19:33 +02:00
gpt 5-3-Codex model support in databricks (#7516)
This commit is contained in:
@@ -1,18 +1,28 @@
|
||||
use anyhow::Result;
|
||||
use async_stream::try_stream;
|
||||
use async_trait::async_trait;
|
||||
use futures::future::BoxFuture;
|
||||
use futures::{StreamExt, TryStreamExt};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::Value;
|
||||
use std::io;
|
||||
use std::time::Duration;
|
||||
use tokio::pin;
|
||||
use tokio_util::codec::{FramedRead, LinesCodec};
|
||||
use tokio_util::io::StreamReader;
|
||||
|
||||
use super::api_client::{ApiClient, AuthMethod, AuthProvider};
|
||||
use super::base::{ConfigKey, MessageStream, Provider, ProviderDef, ProviderMetadata};
|
||||
use super::embedding::EmbeddingCapable;
|
||||
use super::errors::ProviderError;
|
||||
use super::formats::databricks::create_request;
|
||||
use super::formats::openai_responses::{
|
||||
create_responses_request, responses_api_to_streaming_message,
|
||||
};
|
||||
use super::oauth;
|
||||
use super::openai_compatible::{
|
||||
handle_response_openai_compat, map_http_error_to_provider_error, stream_openai_compat,
|
||||
handle_response_openai_compat, handle_status_openai_compat, map_http_error_to_provider_error,
|
||||
stream_openai_compat,
|
||||
};
|
||||
use super::retry::ProviderRetry;
|
||||
use super::utils::{ImageFormat, RequestLog};
|
||||
@@ -208,9 +218,16 @@ impl DatabricksProvider {
|
||||
})
|
||||
}
|
||||
|
||||
fn is_responses_model(model_name: &str) -> bool {
|
||||
let normalized = model_name.to_ascii_lowercase();
|
||||
normalized.contains("codex")
|
||||
}
|
||||
|
||||
fn get_endpoint_path(&self, model_name: &str, is_embedding: bool) -> String {
|
||||
if is_embedding {
|
||||
"serving-endpoints/text-embedding-3-small/invocations".to_string()
|
||||
} else if Self::is_responses_model(model_name) {
|
||||
"serving-endpoints/responses".to_string()
|
||||
} else {
|
||||
format!("serving-endpoints/{}/invocations", model_name)
|
||||
}
|
||||
@@ -282,42 +299,79 @@ impl Provider for DatabricksProvider {
|
||||
messages: &[Message],
|
||||
tools: &[Tool],
|
||||
) -> Result<MessageStream, ProviderError> {
|
||||
let mut payload =
|
||||
create_request(model_config, system, messages, tools, &self.image_format)?;
|
||||
payload
|
||||
.as_object_mut()
|
||||
.expect("payload should have model key")
|
||||
.remove("model");
|
||||
|
||||
payload
|
||||
.as_object_mut()
|
||||
.unwrap()
|
||||
.insert("stream".to_string(), Value::Bool(true));
|
||||
|
||||
let path = self.get_endpoint_path(&model_config.model_name, false);
|
||||
let mut log = RequestLog::start(model_config, &payload)?;
|
||||
let response = self
|
||||
.with_retry(|| async {
|
||||
let resp = self
|
||||
.api_client
|
||||
.response_post(Some(session_id), &path, &payload)
|
||||
.await?;
|
||||
if !resp.status().is_success() {
|
||||
let status = resp.status();
|
||||
let error_text = resp.text().await.unwrap_or_default();
|
||||
|
||||
// Parse as JSON if possible to pass to map_http_error_to_provider_error
|
||||
let json_payload = serde_json::from_str::<Value>(&error_text).ok();
|
||||
return Err(map_http_error_to_provider_error(status, json_payload));
|
||||
if Self::is_responses_model(&model_config.model_name) {
|
||||
let mut payload = create_responses_request(model_config, system, messages, tools)?;
|
||||
payload["stream"] = Value::Bool(true);
|
||||
|
||||
let mut log = RequestLog::start(model_config, &payload)?;
|
||||
|
||||
let response = self
|
||||
.with_retry(|| async {
|
||||
let payload_clone = payload.clone();
|
||||
let resp = self
|
||||
.api_client
|
||||
.response_post(Some(session_id), &path, &payload_clone)
|
||||
.await?;
|
||||
handle_status_openai_compat(resp).await
|
||||
})
|
||||
.await
|
||||
.inspect_err(|e| {
|
||||
let _ = log.error(e);
|
||||
})?;
|
||||
|
||||
let stream = response.bytes_stream().map_err(io::Error::other);
|
||||
|
||||
Ok(Box::pin(try_stream! {
|
||||
let stream_reader = StreamReader::new(stream);
|
||||
let framed = FramedRead::new(stream_reader, LinesCodec::new()).map_err(anyhow::Error::from);
|
||||
|
||||
let message_stream = responses_api_to_streaming_message(framed);
|
||||
pin!(message_stream);
|
||||
while let Some(message) = message_stream.next().await {
|
||||
let (message, usage) = message.map_err(|e| ProviderError::RequestFailed(format!("Stream decode error: {}", e)))?;
|
||||
log.write(&message, usage.as_ref().map(|f| f.usage).as_ref())?;
|
||||
yield (message, usage);
|
||||
}
|
||||
Ok(resp)
|
||||
})
|
||||
.await
|
||||
.inspect_err(|e| {
|
||||
let _ = log.error(e);
|
||||
})?;
|
||||
}))
|
||||
} else {
|
||||
let mut payload =
|
||||
create_request(model_config, system, messages, tools, &self.image_format)?;
|
||||
payload
|
||||
.as_object_mut()
|
||||
.expect("payload should have model key")
|
||||
.remove("model");
|
||||
|
||||
stream_openai_compat(response, log)
|
||||
payload
|
||||
.as_object_mut()
|
||||
.unwrap()
|
||||
.insert("stream".to_string(), Value::Bool(true));
|
||||
|
||||
let mut log = RequestLog::start(model_config, &payload)?;
|
||||
let response = self
|
||||
.with_retry(|| async {
|
||||
let resp = self
|
||||
.api_client
|
||||
.response_post(Some(session_id), &path, &payload)
|
||||
.await?;
|
||||
if !resp.status().is_success() {
|
||||
let status = resp.status();
|
||||
let error_text = resp.text().await.unwrap_or_default();
|
||||
|
||||
// Parse as JSON if possible to pass to map_http_error_to_provider_error
|
||||
let json_payload = serde_json::from_str::<Value>(&error_text).ok();
|
||||
return Err(map_http_error_to_provider_error(status, json_payload));
|
||||
}
|
||||
Ok(resp)
|
||||
})
|
||||
.await
|
||||
.inspect_err(|e| {
|
||||
let _ = log.error(e);
|
||||
})?;
|
||||
|
||||
stream_openai_compat(response, log)
|
||||
}
|
||||
}
|
||||
|
||||
fn supports_embeddings(&self) -> bool {
|
||||
|
||||
Reference in New Issue
Block a user