mirror of
https://github.com/block/goose.git
synced 2026-06-02 06:19:33 +02:00
Ollama improvements (#5609)
Co-authored-by: Douwe Osinga <douwe@squareup.com>
This commit is contained in:
@@ -35,7 +35,35 @@ pub enum ProviderError {
|
||||
impl From<anyhow::Error> for ProviderError {
|
||||
fn from(error: anyhow::Error) -> Self {
|
||||
if let Some(reqwest_err) = error.downcast_ref::<reqwest::Error>() {
|
||||
return ProviderError::RequestFailed(reqwest_err.to_string());
|
||||
let mut details = vec![];
|
||||
|
||||
if let Some(status) = reqwest_err.status() {
|
||||
details.push(format!("status: {}", status));
|
||||
}
|
||||
if reqwest_err.is_timeout() {
|
||||
details.push("timeout".to_string());
|
||||
}
|
||||
if reqwest_err.is_connect() {
|
||||
if let Some(url) = reqwest_err.url() {
|
||||
if let Some(host) = url.host_str() {
|
||||
let port_info = url.port().map(|p| format!(":{}", p)).unwrap_or_default();
|
||||
|
||||
details.push(format!("failed to connect to {}{}", host, port_info));
|
||||
|
||||
if url.port().is_some() {
|
||||
details.push("check that the port is correct".to_string());
|
||||
}
|
||||
}
|
||||
} else {
|
||||
details.push("connection failed".to_string());
|
||||
}
|
||||
}
|
||||
let msg = if details.is_empty() {
|
||||
reqwest_err.to_string()
|
||||
} else {
|
||||
format!("{} ({})", reqwest_err, details.join(", "))
|
||||
};
|
||||
return ProviderError::RequestFailed(msg);
|
||||
}
|
||||
ProviderError::ExecutionError(error.to_string())
|
||||
}
|
||||
|
||||
@@ -55,9 +55,6 @@ struct StreamingChunk {
|
||||
model: Option<String>,
|
||||
}
|
||||
|
||||
/// Convert internal Message format to OpenAI's API message specification
|
||||
/// some openai compatible endpoints use the anthropic image spec at the content level
|
||||
/// even though the message structure is otherwise following openai, the enum switches this
|
||||
pub fn format_messages(messages: &[Message], image_format: &ImageFormat) -> Vec<Value> {
|
||||
let mut messages_spec = Vec::new();
|
||||
for message in messages.iter().filter(|m| m.is_agent_visible()) {
|
||||
@@ -256,7 +253,6 @@ pub fn format_messages(messages: &[Message], image_format: &ImageFormat) -> Vec<
|
||||
messages_spec
|
||||
}
|
||||
|
||||
/// Convert internal Tool format to OpenAI's API tool specification
|
||||
pub fn format_tools(tools: &[Tool]) -> anyhow::Result<Vec<Value>> {
|
||||
let mut tool_names = std::collections::HashSet::new();
|
||||
let mut result = Vec::new();
|
||||
@@ -270,7 +266,6 @@ pub fn format_tools(tools: &[Tool]) -> anyhow::Result<Vec<Value>> {
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": tool.name,
|
||||
// do not silently truncate description
|
||||
"description": tool.description,
|
||||
"parameters": tool.input_schema,
|
||||
}
|
||||
|
||||
@@ -31,10 +31,9 @@ use tokio_util::io::StreamReader;
|
||||
use url::Url;
|
||||
|
||||
pub const OLLAMA_HOST: &str = "localhost";
|
||||
pub const OLLAMA_TIMEOUT: u64 = 600; // seconds
|
||||
pub const OLLAMA_TIMEOUT: u64 = 600;
|
||||
pub const OLLAMA_DEFAULT_PORT: u16 = 11434;
|
||||
pub const OLLAMA_DEFAULT_MODEL: &str = "qwen3";
|
||||
// Ollama can run many models, we only provide the default
|
||||
pub const OLLAMA_KNOWN_MODELS: &[&str] = &[
|
||||
OLLAMA_DEFAULT_MODEL,
|
||||
"qwen3-coder:30b",
|
||||
@@ -61,7 +60,6 @@ impl OllamaProvider {
|
||||
let timeout: Duration =
|
||||
Duration::from_secs(config.get_param("OLLAMA_TIMEOUT").unwrap_or(OLLAMA_TIMEOUT));
|
||||
|
||||
// OLLAMA_HOST is sometimes just the 'host' or 'host:port' without a scheme
|
||||
let base = if host.starts_with("http://") || host.starts_with("https://") {
|
||||
host.clone()
|
||||
} else {
|
||||
@@ -71,11 +69,6 @@ impl OllamaProvider {
|
||||
let mut base_url =
|
||||
Url::parse(&base).map_err(|e| anyhow::anyhow!("Invalid base URL: {e}"))?;
|
||||
|
||||
// Set the default port if missing
|
||||
// Don't add default port if:
|
||||
// 1. URL/host explicitly contains ports
|
||||
// 2. URL/host uses HTTP/S
|
||||
// 3. only set it for localhost
|
||||
let explicit_port = host.contains(':');
|
||||
let is_localhost = host == "localhost" || host == "127.0.0.1" || host == "::1";
|
||||
|
||||
@@ -86,7 +79,6 @@ impl OllamaProvider {
|
||||
.map_err(|_| anyhow::anyhow!("Failed to set default port"))?;
|
||||
}
|
||||
|
||||
// No authentication for Ollama
|
||||
let auth = AuthMethod::Custom(Box::new(NoAuth));
|
||||
let api_client = ApiClient::with_timeout(base_url.to_string(), auth, timeout)?;
|
||||
|
||||
@@ -104,7 +96,6 @@ impl OllamaProvider {
|
||||
) -> Result<Self> {
|
||||
let timeout = Duration::from_secs(config.timeout_seconds.unwrap_or(OLLAMA_TIMEOUT));
|
||||
|
||||
// Parse and normalize the custom URL
|
||||
let base =
|
||||
if config.base_url.starts_with("http://") || config.base_url.starts_with("https://") {
|
||||
config.base_url.clone()
|
||||
@@ -115,7 +106,6 @@ impl OllamaProvider {
|
||||
let mut base_url = Url::parse(&base)
|
||||
.map_err(|e| anyhow::anyhow!("Invalid base URL '{}': {}", config.base_url, e))?;
|
||||
|
||||
// Set default port if missing and not using standard ports
|
||||
let explicit_default_port =
|
||||
config.base_url.ends_with(":80") || config.base_url.ends_with(":443");
|
||||
let is_https = base_url.scheme() == "https";
|
||||
@@ -126,7 +116,6 @@ impl OllamaProvider {
|
||||
.map_err(|_| anyhow::anyhow!("Failed to set default port"))?;
|
||||
}
|
||||
|
||||
// No authentication for Ollama
|
||||
let auth = AuthMethod::Custom(Box::new(NoAuth));
|
||||
let api_client = ApiClient::with_timeout(base_url.to_string(), auth, timeout)?;
|
||||
|
||||
@@ -147,13 +136,11 @@ impl OllamaProvider {
|
||||
}
|
||||
}
|
||||
|
||||
// No authentication provider for Ollama
|
||||
struct NoAuth;
|
||||
|
||||
#[async_trait]
|
||||
impl super::api_client::AuthProvider for NoAuth {
|
||||
async fn get_auth_header(&self) -> Result<(String, String)> {
|
||||
// Return a dummy header that won't be used
|
||||
Ok(("X-No-Auth".to_string(), "true".to_string()))
|
||||
}
|
||||
}
|
||||
@@ -208,32 +195,35 @@ impl Provider for OllamaProvider {
|
||||
};
|
||||
|
||||
let payload = create_request(
|
||||
&self.model,
|
||||
model_config,
|
||||
system,
|
||||
messages,
|
||||
filtered_tools,
|
||||
&super::utils::ImageFormat::OpenAi,
|
||||
)?;
|
||||
|
||||
let mut log = RequestLog::start(model_config, &payload)?;
|
||||
let response = self
|
||||
.with_retry(|| async {
|
||||
let payload_clone = payload.clone();
|
||||
self.post(&payload_clone).await
|
||||
})
|
||||
.await?;
|
||||
let message = response_to_message(&response.clone())?;
|
||||
.await
|
||||
.inspect_err(|e| {
|
||||
let _ = log.error(e);
|
||||
})?;
|
||||
|
||||
let message = response_to_message(&response)?;
|
||||
|
||||
let usage = response.get("usage").map(get_usage).unwrap_or_else(|| {
|
||||
tracing::debug!("Failed to get usage data");
|
||||
Usage::default()
|
||||
});
|
||||
let response_model = get_model(&response);
|
||||
let mut log = RequestLog::start(model_config, &payload)?;
|
||||
log.write(&response, Some(&usage))?;
|
||||
Ok((message, ProviderUsage::new(response_model, usage)))
|
||||
}
|
||||
|
||||
/// Generate a session name based on the conversation history
|
||||
/// This override filters out reasoning tokens that some Ollama models produce
|
||||
async fn generate_session_name(
|
||||
&self,
|
||||
messages: &Conversation,
|
||||
@@ -275,17 +265,29 @@ impl Provider for OllamaProvider {
|
||||
payload["stream_options"] = json!({
|
||||
"include_usage": true,
|
||||
});
|
||||
let mut log = RequestLog::start(&self.model, &payload)?;
|
||||
|
||||
let response = self
|
||||
.api_client
|
||||
.response_post("v1/chat/completions", &payload)
|
||||
.await?;
|
||||
.with_retry(|| async {
|
||||
let resp = self
|
||||
.api_client
|
||||
.response_post("v1/chat/completions", &payload)
|
||||
.await?;
|
||||
let status = resp.status();
|
||||
if !status.is_success() {
|
||||
return Err(super::utils::map_http_error_to_provider_error(status, None));
|
||||
}
|
||||
Ok(resp)
|
||||
})
|
||||
.await
|
||||
.inspect_err(|e| {
|
||||
let _ = log.error(e);
|
||||
})?;
|
||||
let response = handle_status_openai_compat(response).await?;
|
||||
|
||||
let stream = response.bytes_stream().map_err(io::Error::other);
|
||||
let model_config = self.model.clone();
|
||||
|
||||
Ok(Box::pin(try_stream! {
|
||||
let mut log = RequestLog::start(&model_config, &payload)?;
|
||||
let stream_reader = StreamReader::new(stream);
|
||||
let framed = FramedRead::new(stream_reader, LinesCodec::new()).map_err(anyhow::Error::from);
|
||||
let message_stream = response_to_streaming_message(framed);
|
||||
@@ -328,7 +330,6 @@ impl Provider for OllamaProvider {
|
||||
.filter_map(|model| model.get("name").and_then(|n| n.as_str()).map(String::from))
|
||||
.collect();
|
||||
|
||||
// Sort alphabetically
|
||||
model_names.sort();
|
||||
|
||||
Ok(Some(model_names))
|
||||
@@ -336,11 +337,9 @@ impl Provider for OllamaProvider {
|
||||
}
|
||||
|
||||
impl OllamaProvider {
|
||||
/// Filter out reasoning tokens and thinking patterns from model responses
|
||||
fn filter_reasoning_tokens(text: &str) -> String {
|
||||
let mut filtered = text.to_string();
|
||||
|
||||
// Remove common reasoning patterns
|
||||
let reasoning_patterns = [
|
||||
r"<think>.*?</think>",
|
||||
r"<thinking>.*?</thinking>",
|
||||
@@ -361,13 +360,11 @@ impl OllamaProvider {
|
||||
filtered = re.replace_all(&filtered, "").to_string();
|
||||
}
|
||||
}
|
||||
// Remove any remaining thinking markers
|
||||
filtered = filtered
|
||||
.replace("<think>", "")
|
||||
.replace("</think>", "")
|
||||
.replace("<thinking>", "")
|
||||
.replace("</thinking>", "");
|
||||
// Clean up extra whitespace
|
||||
filtered = filtered
|
||||
.lines()
|
||||
.map(|line| line.trim())
|
||||
|
||||
@@ -225,11 +225,6 @@ impl Provider for VeniceProvider {
|
||||
let response = self.api_client.response_get(&self.models_path).await?;
|
||||
let json: serde_json::Value = response.json().await?;
|
||||
|
||||
// Print legend once so users know what flags mean
|
||||
println!(
|
||||
"Capabilities:\n c=code\n f=function calls (goose supported models)\n s=schema\n v=vision\n w=web search\n r=reasoning"
|
||||
);
|
||||
|
||||
let mut models = json["data"]
|
||||
.as_array()
|
||||
.ok_or_else(|| ProviderError::RequestFailed("No data field in JSON".to_string()))?
|
||||
|
||||
@@ -186,9 +186,7 @@ export const ConfigProvider: React.FC<ConfigProviderProps> = ({ children }) => {
|
||||
try {
|
||||
const response = await apiGetProviderModels({
|
||||
path: { name: providerName },
|
||||
headers: {
|
||||
'X-Secret-Key': await window.electron.getSecretKey(),
|
||||
},
|
||||
throwOnError: true,
|
||||
});
|
||||
return response.data || [];
|
||||
} catch (error) {
|
||||
|
||||
Reference in New Issue
Block a user