feat: support model capabilities (#297)

1. Automatically switch to a model that has the necessary capabilities.
2. Throw an error if the client does not have a model with the necessary capabilities.
sigoden committed
parent 4e99df4c1b
commit fe35cfd941
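
Under the hood a capability is a bitflag: an input states the flags it requires and each model advertises the flags it supports, so both rules above reduce to a containment check. A minimal, self-contained sketch of that check (toy values; the real type is the ModelCapabilities introduced later in this diff):

use bitflags::bitflags; // bitflags = "2.4.1", as added to Cargo.toml below

bitflags! {
    #[derive(Debug, Clone, Copy, PartialEq)]
    pub struct ModelCapabilities: u32 {
        const Text = 0b01;
        const Vision = 0b10;
    }
}

fn main() {
    let llava = ModelCapabilities::Text | ModelCapabilities::Vision;
    let mistral = ModelCapabilities::Text;

    // An input that carries images requires the Vision flag.
    let required = ModelCapabilities::Vision;
    assert!(llava.contains(required)); // capable: keep it, or switch to it
    assert!(!mistral.contains(required)); // incapable: skip it, or error out
}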

Cargo.lock

@@ -36,6 +36,7 @@ dependencies = [
"async-trait",
"base64",
"bincode",
"bitflags 2.4.1",
"bstr",
"bytes",
"chrono",

@@ -43,6 +43,7 @@ log = "0.4.20"
shell-words = "1.1.0"
mime_guess = "2.0.4"
sha2 = "0.10.8"
bitflags = "2.4.1"
[dependencies.reqwest]
version = "0.11.14"

@@ -36,8 +36,11 @@ clients:
api_key: xxx
chat_endpoint: /chat/completions # Optional field
models:
- name: gpt4all-j
- name: mistral
max_tokens: 8192
- name: llava
max_tokens: 8192
capabilities: text,vision # Optional field, possible values: text, vision
# See https://github.com/jmorganca/ollama
- type: ollama
@@ -45,7 +48,7 @@ clients:
api_key: Basic xxx # Set authorization header
chat_endpoint: /chat # Optional field
models:
- name: gpt4all-j
- name: mistral
max_tokens: 8192
# See https://learn.microsoft.com/en-us/azure/ai-services/openai/chatgpt-quickstart
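
For reference, here is how a models entry like llava above gets parsed, in a sketch that mirrors the ModelConfig this commit introduces further down. The serde_yaml call is an assumption made for the demo, since the snippet above is YAML config:

use bitflags::bitflags;
use serde::{Deserialize, Deserializer};

bitflags! {
    #[derive(Debug, Clone, Copy, PartialEq)]
    pub struct ModelCapabilities: u32 {
        const Text = 0b01;
        const Vision = 0b10;
    }
}

#[derive(Debug, Deserialize)]
pub struct ModelConfig {
    pub name: String,
    pub max_tokens: Option<usize>,
    // Omitted field => Text; "text,vision" => Text | Vision.
    #[serde(default = "default_capabilities")]
    #[serde(deserialize_with = "deserialize_capabilities")]
    pub capabilities: ModelCapabilities,
}

fn default_capabilities() -> ModelCapabilities {
    ModelCapabilities::Text
}

fn deserialize_capabilities<'de, D: Deserializer<'de>>(
    deserializer: D,
) -> Result<ModelCapabilities, D::Error> {
    let value: String = Deserialize::deserialize(deserializer)?;
    // An explicit empty string also falls back to "text".
    let value = if value.is_empty() { "text".into() } else { value };
    let mut caps = ModelCapabilities::empty();
    if value.contains("text") {
        caps |= ModelCapabilities::Text;
    }
    if value.contains("vision") {
        caps |= ModelCapabilities::Vision;
    }
    Ok(caps)
}

fn main() -> anyhow::Result<()> {
    let yaml = "name: llava\nmax_tokens: 8192\ncapabilities: text,vision";
    let entry: ModelConfig = serde_yaml::from_str(yaml)?;
    assert!(entry.capabilities.contains(ModelCapabilities::Vision));
    Ok(())
}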

@@ -1,5 +1,5 @@
use super::openai::{openai_build_body, OPENAI_TOKENS_COUNT_FACTORS};
use super::{AzureOpenAIClient, ExtraConfig, PromptType, SendData, Model};
use super::{AzureOpenAIClient, ExtraConfig, Model, ModelConfig, PromptType, SendData};
use crate::utils::PromptKind;
@@ -13,16 +13,10 @@ pub struct AzureOpenAIConfig {
pub name: Option<String>,
pub api_base: Option<String>,
pub api_key: Option<String>,
pub models: Vec<AzureOpenAIModel>,
pub models: Vec<ModelConfig>,
pub extra: Option<ExtraConfig>,
}
#[derive(Debug, Clone, Deserialize)]
pub struct AzureOpenAIModel {
name: String,
max_tokens: Option<usize>,
}
openai_compatible_client!(AzureOpenAIClient);
impl AzureOpenAIClient {
@@ -50,6 +44,7 @@ impl AzureOpenAIClient {
.map(|v| {
Model::new(client_name, &v.name)
.set_max_tokens(v.max_tokens)
.set_capabilities(v.capabilities)
.set_tokens_count_factors(OPENAI_TOKENS_COUNT_FACTORS)
})
.collect()

@@ -1,4 +1,4 @@
use super::{openai::OpenAIConfig, ClientConfig, Message, MessageContent};
use super::{openai::OpenAIConfig, ClientConfig, Message, MessageContent, Model};
use crate::{
config::{GlobalConfig, Input},
@@ -78,12 +78,26 @@ macro_rules! register_client {
)+
pub fn init_client(config: &$crate::config::GlobalConfig) -> anyhow::Result<Box<dyn Client>> {
None
$(.or_else(|| $client::init(config)))+
.ok_or_else(|| {
let model = config.read().model.clone();
anyhow::anyhow!("Unknown client '{}'", &model.client_name)
})
}
pub fn ensure_model_capabilities(client: &mut dyn Client, capabilities: $crate::client::ModelCapabilities) -> anyhow::Result<()> {
if !client.model().capabilities.contains(capabilities) {
let models = client.models();
if let Some(model) = models.into_iter().find(|v| v.capabilities.contains(capabilities)) {
client.set_model(model);
} else {
anyhow::bail!(
"The current model lacks the corresponding capability."
);
}
}
Ok(())
}
pub fn list_client_types() -> Vec<&'static str> {
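
Because the function takes &mut dyn Client, a capability-driven switch mutates the client in place before any request is built; the two call sites at the end of this diff invoke it via client.as_mut(). A self-contained sketch of the same logic against a toy client:

use anyhow::{bail, Result};
use bitflags::bitflags;

bitflags! {
    #[derive(Debug, Clone, Copy, PartialEq)]
    pub struct ModelCapabilities: u32 {
        const Text = 0b01;
        const Vision = 0b10;
    }
}

#[derive(Debug, Clone)]
pub struct Model {
    pub name: &'static str,
    pub capabilities: ModelCapabilities,
}

pub trait Client {
    fn models(&self) -> Vec<Model>;
    fn model(&self) -> &Model;
    fn set_model(&mut self, model: Model);
}

// Same shape as the macro-generated function above: keep the current
// model if it is capable, otherwise switch to the first capable one,
// otherwise fail.
pub fn ensure_model_capabilities(
    client: &mut dyn Client,
    capabilities: ModelCapabilities,
) -> Result<()> {
    if !client.model().capabilities.contains(capabilities) {
        let models = client.models();
        if let Some(model) = models
            .into_iter()
            .find(|v| v.capabilities.contains(capabilities))
        {
            client.set_model(model);
        } else {
            bail!("The current model lacks the corresponding capability.");
        }
    }
    Ok(())
}

struct ToyClient {
    current: Model,
    all: Vec<Model>,
}

impl Client for ToyClient {
    fn models(&self) -> Vec<Model> {
        self.all.clone()
    }
    fn model(&self) -> &Model {
        &self.current
    }
    fn set_model(&mut self, model: Model) {
        self.current = model;
    }
}

fn main() -> Result<()> {
    let mistral = Model { name: "mistral", capabilities: ModelCapabilities::Text };
    let llava = Model {
        name: "llava",
        capabilities: ModelCapabilities::Text | ModelCapabilities::Vision,
    };
    let mut client = ToyClient { current: mistral.clone(), all: vec![mistral, llava] };

    // A request with an image attached requires Vision; the client
    // transparently switches from mistral to llava.
    ensure_model_capabilities(&mut client, ModelCapabilities::Vision)?;
    assert_eq!(client.model().name, "llava");
    Ok(())
}
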
@@ -113,19 +127,38 @@ macro_rules! register_client {
};
}
#[macro_export]
macro_rules! client_common_fns {
() => {
fn config(
&self,
) -> (
&$crate::config::GlobalConfig,
&Option<$crate::client::ExtraConfig>,
) {
(&self.global_config, &self.config.extra)
}
fn models(&self) -> Vec<Model> {
Self::list_models(&self.config)
}
fn model(&self) -> &Model {
&self.model
}
fn set_model(&mut self, model: Model) {
self.model = model;
}
};
}
#[macro_export]
macro_rules! openai_compatible_client {
($client:ident) => {
#[async_trait]
impl $crate::client::Client for $crate::client::$client {
fn config(
&self,
) -> (
&$crate::config::GlobalConfig,
&Option<$crate::client::ExtraConfig>,
) {
(&self.global_config, &self.config.extra)
}
client_common_fns!();
async fn send_message_inner(
&self,
@@ -170,6 +203,12 @@ macro_rules! config_get_fn {
pub trait Client {
fn config(&self) -> (&GlobalConfig, &Option<ExtraConfig>);
fn models(&self) -> Vec<Model>;
fn model(&self) -> &Model;
fn set_model(&mut self, model: Model);
fn build_client(&self) -> Result<ReqwestClient> {
let mut builder = ReqwestClient::builder();
let options = self.config().1;

@@ -1,10 +1,6 @@
use super::{ErnieClient, Client, ExtraConfig, PromptType, SendData, Model, patch_system_message};
use super::{patch_system_message, Client, ErnieClient, ExtraConfig, Model, PromptType, SendData};
use crate::{
config::GlobalConfig,
render::ReplyHandler,
utils::PromptKind,
};
use crate::{render::ReplyHandler, utils::PromptKind};
use anyhow::{anyhow, bail, Context, Result};
use async_trait::async_trait;
@@ -37,9 +33,7 @@ pub struct ErnieConfig {
#[async_trait]
impl Client for ErnieClient {
fn config(&self) -> (&GlobalConfig, &Option<ExtraConfig>) {
(&self.global_config, &self.config.extra)
}
client_common_fns!();
async fn send_message_inner(&self, client: &ReqwestClient, data: SendData) -> Result<String> {
self.prepare_access_token().await?;
@@ -127,10 +121,7 @@ async fn send_message(builder: RequestBuilder) -> Result<String> {
Ok(output.to_string())
}
async fn send_message_streaming(
builder: RequestBuilder,
handler: &mut ReplyHandler,
) -> Result<()> {
async fn send_message_streaming(builder: RequestBuilder, handler: &mut ReplyHandler) -> Result<()> {
let mut es = builder.eventsource()?;
while let Some(event) = es.next().await {
match event {
@@ -216,13 +207,12 @@ fn build_body(data: SendData, _model: String) -> Value {
async fn fetch_access_token(api_key: &str, secret_key: &str) -> Result<String> {
let url = format!("{ACCESS_TOKEN_URL}?grant_type=client_credentials&client_id={api_key}&client_secret={secret_key}");
let value: Value = reqwest::get(&url).await?.json().await?;
let result = value["access_token"].as_str()
.ok_or_else(|| {
if let Some(err_msg) = value["error_description"].as_str() {
anyhow!("{err_msg}")
} else {
anyhow!("Invalid response data")
}
})?;
let result = value["access_token"].as_str().ok_or_else(|| {
if let Some(err_msg) = value["error_description"].as_str() {
anyhow!("{err_msg}")
} else {
anyhow!("Invalid response data")
}
})?;
Ok(result.to_string())
}

@@ -3,7 +3,7 @@ use super::{
SendData, TokensCountFactors,
};
use crate::{config::GlobalConfig, render::ReplyHandler, utils::PromptKind};
use crate::{render::ReplyHandler, utils::PromptKind};
use anyhow::{anyhow, bail, Result};
use async_trait::async_trait;
@@ -14,10 +14,10 @@ use serde_json::{json, Value};
const API_BASE: &str = "https://generativelanguage.googleapis.com/v1beta/models/";
const MODELS: [(&str, usize); 3] = [
("gemini-pro", 32768),
("gemini-pro-vision", 16384),
("gemini-ultra", 32768),
const MODELS: [(&str, usize, &str); 3] = [
("gemini-pro", 32768, "text"),
("gemini-pro-vision", 16384, "vision"),
("gemini-ultra", 32768, "text"),
];
const TOKENS_COUNT_FACTORS: TokensCountFactors = (5, 2);
@@ -31,9 +31,7 @@ pub struct GeminiConfig {
#[async_trait]
impl Client for GeminiClient {
fn config(&self) -> (&GlobalConfig, &Option<ExtraConfig>) {
(&self.global_config, &self.config.extra)
}
client_common_fns!();
async fn send_message_inner(&self, client: &ReqwestClient, data: SendData) -> Result<String> {
let builder = self.request_builder(client, data)?;
@@ -61,8 +59,9 @@ impl GeminiClient {
let client_name = Self::name(local_config);
MODELS
.into_iter()
.map(|(name, max_tokens)| {
.map(|(name, max_tokens, capabilities)| {
Model::new(client_name, name)
.set_capabilities(capabilities.into())
.set_max_tokens(Some(max_tokens))
.set_tokens_count_factors(TOKENS_COUNT_FACTORS)
})

@@ -1,5 +1,5 @@
use super::openai::{openai_build_body, OPENAI_TOKENS_COUNT_FACTORS};
use super::{ExtraConfig, LocalAIClient, PromptType, SendData, Model};
use super::{ExtraConfig, LocalAIClient, Model, ModelConfig, PromptType, SendData};
use crate::utils::PromptKind;
@@ -14,16 +14,10 @@ pub struct LocalAIConfig {
pub api_base: String,
pub api_key: Option<String>,
pub chat_endpoint: Option<String>,
pub models: Vec<LocalAIModel>,
pub models: Vec<ModelConfig>,
pub extra: Option<ExtraConfig>,
}
#[derive(Debug, Clone, Deserialize)]
pub struct LocalAIModel {
name: String,
max_tokens: Option<usize>,
}
openai_compatible_client!(LocalAIClient);
impl LocalAIClient {
@@ -49,6 +43,7 @@ impl LocalAIClient {
.iter()
.map(|v| {
Model::new(client_name, &v.name)
.set_capabilities(v.capabilities)
.set_max_tokens(v.max_tokens)
.set_tokens_count_factors(OPENAI_TOKENS_COUNT_FACTORS)
})

@@ -3,6 +3,7 @@ use super::message::{Message, MessageContent};
use crate::utils::count_tokens;
use anyhow::{bail, Result};
use serde::{Deserialize, Deserializer};
pub type TokensCountFactors = (usize, usize); // (per-message, bias)
@@ -12,6 +13,7 @@ pub struct Model {
pub name: String,
pub max_tokens: Option<usize>,
pub tokens_count_factors: TokensCountFactors,
pub capabilities: ModelCapabilities,
}
impl Default for Model {
@@ -27,6 +29,7 @@ impl Model {
name: name.into(),
max_tokens: None,
tokens_count_factors: Default::default(),
capabilities: ModelCapabilities::Text,
}
}
@@ -65,6 +68,11 @@ impl Model {
format!("{}:{}", self.client_name, self.name)
}
pub fn set_capabilities(mut self, capabilities: ModelCapabilities) -> Self {
self.capabilities = capabilities;
self
}
pub fn set_max_tokens(mut self, max_tokens: Option<usize>) -> Self {
match max_tokens {
None | Some(0) => self.max_tokens = None,
@@ -115,3 +123,46 @@ impl Model {
Ok(())
}
}
#[derive(Debug, Clone, Deserialize)]
pub struct ModelConfig {
pub name: String,
pub max_tokens: Option<usize>,
#[serde(deserialize_with = "deserialize_capabilities")]
#[serde(default = "default_capabilities")]
pub capabilities: ModelCapabilities,
}
bitflags::bitflags! {
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct ModelCapabilities: u32 {
const Text = 0b00000001;
const Vision = 0b00000010;
}
}
impl From<&str> for ModelCapabilities {
fn from(value: &str) -> Self {
let value = if value.is_empty() { "text" } else { value };
let mut output = ModelCapabilities::empty();
if value.contains("text") {
output |= ModelCapabilities::Text;
}
if value.contains("vision") {
output |= ModelCapabilities::Vision;
}
output
}
}
fn deserialize_capabilities<'de, D>(deserializer: D) -> Result<ModelCapabilities, D::Error>
where
D: Deserializer<'de>,
{
let value: String = Deserialize::deserialize(deserializer)?;
Ok(value.as_str().into())
}
fn default_capabilities() -> ModelCapabilities {
ModelCapabilities::Text
}
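
Note that the From<&str> impl matches substrings, so a model registered with just "vision" (as gemini-pro-vision is in the Gemini table above) carries the Vision flag but not Text, while an empty string falls back to "text". A few spot checks, assuming the definitions from this hunk are in scope:

fn main() {
    let both: ModelCapabilities = "text,vision".into();
    assert!(both.contains(ModelCapabilities::Text | ModelCapabilities::Vision));

    let vision_only: ModelCapabilities = "vision".into();
    assert!(vision_only.contains(ModelCapabilities::Vision));
    assert!(!vision_only.contains(ModelCapabilities::Text));

    // The empty string is treated as "text".
    let fallback: ModelCapabilities = "".into();
    assert_eq!(fallback, ModelCapabilities::Text);
}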

@@ -1,9 +1,9 @@
use super::{
message::*, patch_system_message, Client, ExtraConfig, Model, OllamaClient, PromptType,
SendData, TokensCountFactors,
message::*, patch_system_message, Client, ExtraConfig, Model, ModelConfig, OllamaClient,
PromptType, SendData, TokensCountFactors,
};
use crate::{config::GlobalConfig, render::ReplyHandler, utils::PromptKind};
use crate::{render::ReplyHandler, utils::PromptKind};
use anyhow::{anyhow, bail, Result};
use async_trait::async_trait;
@@ -20,21 +20,13 @@ pub struct OllamaConfig {
pub api_base: String,
pub api_key: Option<String>,
pub chat_endpoint: Option<String>,
pub models: Vec<LocalAIModel>,
pub models: Vec<ModelConfig>,
pub extra: Option<ExtraConfig>,
}
#[derive(Debug, Clone, Deserialize)]
pub struct LocalAIModel {
name: String,
max_tokens: Option<usize>,
}
#[async_trait]
impl Client for OllamaClient {
fn config(&self) -> (&GlobalConfig, &Option<ExtraConfig>) {
(&self.global_config, &self.config.extra)
}
client_common_fns!();
async fn send_message_inner(&self, client: &ReqwestClient, data: SendData) -> Result<String> {
let builder = self.request_builder(client, data)?;
@@ -75,6 +67,7 @@ impl OllamaClient {
.iter()
.map(|v| {
Model::new(client_name, &v.name)
.set_capabilities(v.capabilities)
.set_max_tokens(v.max_tokens)
.set_tokens_count_factors(TOKENS_COUNT_FACTORS)
})

@@ -1,12 +1,6 @@
use super::{
ExtraConfig, OpenAIClient, PromptType, SendData,
Model, TokensCountFactors,
};
use super::{ExtraConfig, Model, OpenAIClient, PromptType, SendData, TokensCountFactors};
use crate::{
render::ReplyHandler,
utils::PromptKind,
};
use crate::{render::ReplyHandler, utils::PromptKind};
use anyhow::{anyhow, bail, Result};
use async_trait::async_trait;
@@ -19,14 +13,14 @@ use std::env;
const API_BASE: &str = "https://api.openai.com/v1";
const MODELS: [(&str, usize); 7] = [
("gpt-3.5-turbo", 4096),
("gpt-3.5-turbo-16k", 16385),
("gpt-3.5-turbo-1106", 16385),
("gpt-4", 8192),
("gpt-4-32k", 32768),
("gpt-4-1106-preview", 128000),
("gpt-4-vision-preview", 128000),
const MODELS: [(&str, usize, &str); 7] = [
("gpt-3.5-turbo", 4096, "text"),
("gpt-3.5-turbo-16k", 16385, "text"),
("gpt-3.5-turbo-1106", 16385, "text"),
("gpt-4", 8192, "text"),
("gpt-4-32k", 32768, "text"),
("gpt-4-1106-preview", 128000, "text"),
("gpt-4-vision-preview", 128000, "text,vision"),
];
pub const OPENAI_TOKENS_COUNT_FACTORS: TokensCountFactors = (5, 2);
@@ -51,8 +45,9 @@ impl OpenAIClient {
let client_name = Self::name(local_config);
MODELS
.into_iter()
.map(|(name, max_tokens)| {
.map(|(name, max_tokens, capabilities)| {
Model::new(client_name, name)
.set_capabilities(capabilities.into())
.set_max_tokens(Some(max_tokens))
.set_tokens_count_factors(OPENAI_TOKENS_COUNT_FACTORS)
})

@@ -1,7 +1,6 @@
use super::{message::*, Client, ExtraConfig, Model, PromptType, QianwenClient, SendData};
use crate::{
config::GlobalConfig,
render::ReplyHandler,
utils::{sha256sum, PromptKind},
};
@@ -25,12 +24,12 @@ const API_URL: &str =
const API_URL_VL: &str =
"https://dashscope.aliyuncs.com/api/v1/services/aigc/multimodal-generation/generation";
const MODELS: [(&str, usize); 5] = [
("qwen-turbo", 8192),
("qwen-plus", 32768),
("qwen-max", 8192),
("qwen-max-longcontext", 30720),
("qwen-vl-plus", 0),
const MODELS: [(&str, usize, &str); 5] = [
("qwen-turbo", 8192, "text"),
("qwen-plus", 32768, "text"),
("qwen-max", 8192, "text"),
("qwen-max-longcontext", 30720, "text"),
("qwen-vl-plus", 0, "text,vision"),
];
#[derive(Debug, Clone, Deserialize, Default)]
@@ -42,9 +41,7 @@ pub struct QianwenConfig {
#[async_trait]
impl Client for QianwenClient {
fn config(&self) -> (&GlobalConfig, &Option<ExtraConfig>) {
(&self.global_config, &self.config.extra)
}
client_common_fns!();
async fn send_message_inner(
&self,
@@ -80,8 +77,10 @@ impl QianwenClient {
let client_name = Self::name(local_config);
MODELS
.into_iter()
.map(|(name, max_tokens)| {
Model::new(client_name, name).set_max_tokens(Some(max_tokens))
.map(|(name, max_tokens, capabilities)| {
Model::new(client_name, name)
.set_capabilities(capabilities.into())
.set_max_tokens(Some(max_tokens))
})
.collect()
}

@@ -1,4 +1,4 @@
use crate::client::{ImageUrl, MessageContent, MessageContentPart};
use crate::client::{ImageUrl, MessageContent, MessageContentPart, ModelCapabilities};
use crate::utils::sha256sum;
use anyhow::{bail, Context, Result};
@@ -119,6 +119,14 @@ impl Input {
MessageContent::Array(list)
}
}
pub fn required_capabilities(&self) -> ModelCapabilities {
if !self.medias.is_empty() {
ModelCapabilities::Vision
} else {
ModelCapabilities::Text
}
}
}
pub fn resolve_data_url(data_urls: &HashMap<String, String>, data_url: String) -> String {
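
required_capabilities reports Vision alone rather than Text | Vision when media are attached, which is exactly what lets a vision-only model like gemini-pro-vision pass the containment check. A toy illustration, with a hypothetical stand-in for Input and ModelCapabilities assumed in scope:

// Hypothetical stand-in for config::Input, reduced to the one field
// the capability check inspects.
struct ToyInput {
    medias: Vec<String>, // paths/URLs of attached images
}

impl ToyInput {
    fn required_capabilities(&self) -> ModelCapabilities {
        if !self.medias.is_empty() {
            ModelCapabilities::Vision
        } else {
            ModelCapabilities::Text
        }
    }
}

fn main() {
    let plain = ToyInput { medias: vec![] };
    let visual = ToyInput { medias: vec!["cat.png".into()] };
    assert_eq!(plain.required_capabilities(), ModelCapabilities::Text);
    assert_eq!(visual.required_capabilities(), ModelCapabilities::Vision);
}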

@@ -14,7 +14,7 @@ use crate::config::{Config, GlobalConfig};
use anyhow::Result;
use clap::Parser;
use client::{init_client, list_models};
use client::{ensure_model_capabilities, init_client, list_models};
use config::Input;
use is_terminal::IsTerminal;
use parking_lot::RwLock;
@@ -114,7 +114,8 @@ fn start_directive(
session.guard_save()?;
}
let input = Input::new(text, include.unwrap_or_default())?;
let client = init_client(config)?;
let mut client = init_client(config)?;
ensure_model_capabilities(client.as_mut(), input.required_capabilities())?;
config.read().maybe_print_send_tokens(&input);
let output = if no_stream {
let output = client.send_message(input.clone())?;

@@ -6,7 +6,7 @@ use self::completer::ReplCompleter;
use self::highlighter::ReplHighlighter;
use self::prompt::ReplPrompt;
use crate::client::init_client;
use crate::client::{ensure_model_capabilities, init_client};
use crate::config::{GlobalConfig, Input, State};
use crate::render::{render_error, render_stream};
use crate::utils::{create_abort_signal, set_text, AbortSignal};
@@ -268,7 +268,8 @@ impl Repl {
Input::new(text, files)?
};
self.config.read().maybe_print_send_tokens(&input);
let client = init_client(&self.config)?;
let mut client = init_client(&self.config)?;
ensure_model_capabilities(client.as_mut(), input.required_capabilities())?;
let output = render_stream(&input, client.as_ref(), &self.config, self.abort.clone())?;
self.config.write().save_message(input, &output)?;
if self.config.read().auto_copy {
