feat: support mistral (#324)

sigoden 4 months ago committed by GitHub
parent c538533014
commit 75fe0b9205

@@ -45,11 +45,12 @@ Download it from [GitHub Releases](https://github.com/sigoden/aichat/releases),
- OpenAI: gpt-3.5/gpt-4/gpt-4-vision
- Gemini: gemini-pro/gemini-pro-vision
- Claude: claude-instant-1.2/claude-2.0/claude-2.1
- Mistral: mistral-small/mistral-medium/mistral-large
- LocalAI: opensource LLMs and other openai-compatible LLMs
- Ollama: opensource LLMs
- Azure-OpenAI: user deployed gpt-3.5/gpt-4
- VertexAI: gemini-pro/gemini-pro-vision/gemini-ultra/gemini-ultra-vision
- Ernie: ernie-bot-turbo/ernie-bot/ernie-bot-8k/ernie-bot-4
- Qianwen: qwen-turbo/qwen-plus/qwen-max/qwen-max-longcontext/qwen-vl-plus

@@ -31,13 +31,20 @@ clients:
  - type: gemini
    api_key: AIxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx
  # See https://docs.anthropic.com/claude/reference/getting-started-with-the-api
  - type: claude
    api_key: sk-xxx
  - type: mistral
    api_key: xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx
  # Any openai-compatible API providers or https://github.com/go-skynet/LocalAI
  - type: localai
    api_base: http://localhost:8080/v1
    api_key: xxx
    chat_endpoint: /chat/completions # Optional field
    models:
      - name: mistral
      - name: llama2
        max_tokens: 8192
        extra_fields: # Optional field, set custom parameters
          key: value
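For reference, the new `type: mistral` entry above only needs an API key. Below is a minimal stand-alone sketch, not aichat's actual config loader, of how such an entry can be deserialized into a struct shaped like the MistralConfig added in this patch, using the serde_yaml crate; the local struct and sample value are illustrative stand-ins.

use serde::Deserialize;

// Local stand-in mirroring the fields of the new MistralConfig struct.
#[derive(Debug, Deserialize)]
struct MistralConfig {
    name: Option<String>,
    api_key: Option<String>,
}

fn main() -> Result<(), serde_yaml::Error> {
    // Same shape as the `- type: mistral` entry in the config example above;
    // the `type` tag selects the client, so it is left out of this sketch.
    let yaml = "api_key: xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx";
    let cfg: MistralConfig = serde_yaml::from_str(yaml)?;
    assert!(cfg.name.is_none());
    assert!(cfg.api_key.is_some());
    println!("{cfg:?}");
    Ok(())
}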
@@ -62,6 +69,14 @@ clients:
      - name: MyGPT4 # Model deployment name
        max_tokens: 8192
  # See https://cloud.google.com/vertex-ai
  - type: vertexai
    api_base: https://{REGION}-aiplatform.googleapis.com/v1/projects/{PROJECT_ID}/locations/{REGION}/publishers/google/models
    # Setup Application Default Credentials (ADC) file, Optional field
    # Run `gcloud auth application-default login` to setup adc
    # see https://cloud.google.com/docs/authentication/external/set-up-adc
    adc_file: <path-to/gcloud/application_default_credentials.json>
  # See https://cloud.baidu.com/doc/WENXINWORKSHOP/index.html
  - type: ernie
    api_key: xxxxxxxxxxxxxxxxxxxxxxxx
@@ -70,15 +85,3 @@ clients:
  # See https://help.aliyun.com/zh/dashscope/
  - type: qianwen
    api_key: sk-xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx
  # See https://docs.anthropic.com/claude/reference/getting-started-with-the-api
  - type: claude
    api_key: xxx
  # See https://cloud.google.com/vertex-ai
  - type: vertexai
    api_base: https://{REGION}-aiplatform.googleapis.com/v1/projects/{PROJECT_ID}/locations/{REGION}/publishers/google/models
    # Setup Application Default Credentials (ADC) file, Optional field
    # Run `gcloud auth application-default login` to setup adc
    # see https://cloud.google.com/docs/authentication/external/set-up-adc
    adc_file: <path-to/gcloud/application_default_credentials.json>

@@ -0,0 +1,68 @@
use super::openai::{openai_build_body, OPENAI_TOKENS_COUNT_FACTORS};
use super::{ExtraConfig, MistralClient, Model, PromptType, SendData};
use crate::utils::PromptKind;

use anyhow::Result;
use async_trait::async_trait;
use reqwest::{Client as ReqwestClient, RequestBuilder};
use serde::Deserialize;

const API_URL: &str = "https://api.mistral.ai/v1/chat/completions";
const MODELS: [(&str, usize, &str); 5] = [
    // (model name, max tokens, capabilities)
    ("mistral-small-latest", 32000, "text"),
    ("mistral-medium-latest", 32000, "text"),
    ("mistral-large-latest", 32000, "text"),
    ("open-mistral-7b", 32000, "text"),
    ("open-mixtral-8x7b", 32000, "text"),
];
#[derive(Debug, Clone, Deserialize)]
pub struct MistralConfig {
    pub name: Option<String>,
    pub api_key: Option<String>,
    pub extra: Option<ExtraConfig>,
}
openai_compatible_client!(MistralClient);
impl MistralClient {
    config_get_fn!(api_key, get_api_key);

    pub const PROMPTS: [PromptType<'static>; 1] = [
        ("api_key", "API Key:", false, PromptKind::String),
    ];

    pub fn list_models(local_config: &MistralConfig) -> Vec<Model> {
        let client_name = Self::name(local_config);
        MODELS
            .into_iter()
            .map(|(name, max_tokens, capabilities)| {
                Model::new(client_name, name)
                    .set_capabilities(capabilities.into())
                    .set_max_tokens(Some(max_tokens))
                    .set_tokens_count_factors(OPENAI_TOKENS_COUNT_FACTORS)
            })
            .collect()
    }

    fn request_builder(&self, client: &ReqwestClient, data: SendData) -> Result<RequestBuilder> {
        let api_key = self.get_api_key().ok();

        let mut body = openai_build_body(data, self.model.name.clone());
        self.model.merge_extra_fields(&mut body);

        let url = API_URL;

        debug!("Mistral Request: {url} {body}");

        let mut builder = client.post(url).json(&body);
        if let Some(api_key) = api_key {
            builder = builder.bearer_auth(api_key);
        }

        Ok(builder)
    }
}
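To make the request wiring above concrete, here is a stand-alone sketch of the HTTP call this client ultimately issues: a POST to API_URL with a bearer token and an OpenAI-compatible JSON body of the general shape openai_build_body produces. The model name, prompt, environment variable, and the tokio/serde_json/anyhow dependencies are assumptions for the sketch, not part of the patch.

use reqwest::Client;
use serde_json::json;

#[tokio::main]
async fn main() -> anyhow::Result<()> {
    // Placeholder: read the key from the environment instead of aichat's config.
    let api_key = std::env::var("MISTRAL_API_KEY")?;

    // OpenAI-compatible chat body; model and message are illustrative only.
    let body = json!({
        "model": "mistral-small-latest",
        "messages": [{ "role": "user", "content": "Hello" }],
    });

    let res = Client::new()
        .post("https://api.mistral.ai/v1/chat/completions")
        .bearer_auth(api_key)
        .json(&body)
        .send()
        .await?
        .error_for_status()?;

    println!("{}", res.text().await?);
    Ok(())
}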

@@ -11,6 +11,7 @@ register_client!(
    (openai, "openai", OpenAIConfig, OpenAIClient),
    (gemini, "gemini", GeminiConfig, GeminiClient),
    (claude, "claude", ClaudeConfig, ClaudeClient),
    (mistral, "mistral", MistralConfig, MistralClient),
    (localai, "localai", LocalAIConfig, LocalAIClient),
    (ollama, "ollama", OllamaConfig, OllamaClient),
    (
