mirror of
https://github.com/aaif-goose/goose.git
synced 2026-06-01 22:09:18 +02:00
fix: make azure api-version query param optional (#9221)
Signed-off-by: Douwe Osinga <douwe@squareup.com> Co-authored-by: Douwe Osinga <douwe@squareup.com>
This commit is contained in:
@@ -12,9 +12,15 @@ const AZURE_PROVIDER_NAME: &str = "azure_openai";
|
||||
pub const AZURE_DEFAULT_MODEL: &str = "gpt-4o";
|
||||
pub const AZURE_DOC_URL: &str =
|
||||
"https://learn.microsoft.com/en-us/azure/ai-services/openai/concepts/models";
|
||||
pub const AZURE_DEFAULT_API_VERSION: &str = "2024-10-21";
|
||||
const AZURE_DEFAULT_API_VERSION: &str = "2024-10-21";
|
||||
pub const AZURE_OPENAI_KNOWN_MODELS: &[&str] = &["gpt-4o", "gpt-4o-mini", "gpt-4"];
|
||||
|
||||
/// New-style Azure AI endpoints use `/v1/` paths and reject the `api-version` query param.
|
||||
fn is_v1_endpoint(endpoint: &str) -> bool {
|
||||
let normalized = endpoint.trim_end_matches('/');
|
||||
normalized.ends_with("/v1") || endpoint.contains("/v1/")
|
||||
}
|
||||
|
||||
pub struct AzureProvider;
|
||||
|
||||
// Custom auth provider that wraps AzureAuth
|
||||
@@ -57,13 +63,7 @@ impl ProviderDef for AzureProvider {
|
||||
vec![
|
||||
ConfigKey::new("AZURE_OPENAI_ENDPOINT", true, false, None, true),
|
||||
ConfigKey::new("AZURE_OPENAI_DEPLOYMENT_NAME", true, false, None, true),
|
||||
ConfigKey::new(
|
||||
"AZURE_OPENAI_API_VERSION",
|
||||
true,
|
||||
false,
|
||||
Some("2024-10-21"),
|
||||
false,
|
||||
),
|
||||
ConfigKey::new("AZURE_OPENAI_API_VERSION", false, false, None, false),
|
||||
ConfigKey::new("AZURE_OPENAI_API_KEY", false, true, Some(""), true),
|
||||
],
|
||||
)
|
||||
@@ -77,9 +77,16 @@ impl ProviderDef for AzureProvider {
|
||||
let config = crate::config::Config::global();
|
||||
let endpoint: String = config.get_param("AZURE_OPENAI_ENDPOINT")?;
|
||||
let deployment_name: String = config.get_param("AZURE_OPENAI_DEPLOYMENT_NAME")?;
|
||||
let api_version: String = config
|
||||
let api_version: Option<String> = config
|
||||
.get_param("AZURE_OPENAI_API_VERSION")
|
||||
.unwrap_or_else(|_| AZURE_DEFAULT_API_VERSION.to_string());
|
||||
.ok()
|
||||
.or_else(|| {
|
||||
if is_v1_endpoint(&endpoint) {
|
||||
None
|
||||
} else {
|
||||
Some(AZURE_DEFAULT_API_VERSION.to_string())
|
||||
}
|
||||
});
|
||||
|
||||
let api_key = config
|
||||
.get_secret("AZURE_OPENAI_API_KEY")
|
||||
@@ -92,8 +99,10 @@ impl ProviderDef for AzureProvider {
|
||||
|
||||
let auth_provider = AzureAuthProvider { auth };
|
||||
let host = format!("{}/openai", endpoint.trim_end_matches('/'));
|
||||
let api_client = ApiClient::new(host, AuthMethod::Custom(Box::new(auth_provider)))?
|
||||
.with_query(vec![("api-version".to_string(), api_version)]);
|
||||
let mut api_client = ApiClient::new(host, AuthMethod::Custom(Box::new(auth_provider)))?;
|
||||
if let Some(version) = api_version {
|
||||
api_client = api_client.with_query(vec![("api-version".to_string(), version)]);
|
||||
}
|
||||
|
||||
Ok(OpenAiCompatibleProvider::new(
|
||||
AZURE_PROVIDER_NAME.to_string(),
|
||||
@@ -104,3 +113,27 @@ impl ProviderDef for AzureProvider {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_is_v1_endpoint() {
|
||||
assert!(is_v1_endpoint(
|
||||
"https://my-resource.services.ai.azure.com/api/projects/my-proj/openai/v1"
|
||||
));
|
||||
assert!(is_v1_endpoint(
|
||||
"https://my-resource.services.ai.azure.com/api/projects/my-proj/openai/v1/"
|
||||
));
|
||||
assert!(is_v1_endpoint(
|
||||
"https://my-resource.services.ai.azure.com/v1/some/path"
|
||||
));
|
||||
|
||||
assert!(!is_v1_endpoint("https://my-resource.openai.azure.com"));
|
||||
assert!(!is_v1_endpoint("https://my-resource.openai.azure.com/"));
|
||||
assert!(!is_v1_endpoint(
|
||||
"https://my-resource.openai.azure.com/openai"
|
||||
));
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user