|
|
|
@ -1,12 +1,11 @@
|
|
|
|
|
use super::{
|
|
|
|
|
catch_error, extract_system_message, json_stream, message::*, Client, CohereClient,
|
|
|
|
|
catch_error, extract_system_message, json_stream, message::*, CohereClient,
|
|
|
|
|
ExtraConfig, Model, ModelConfig, PromptType, ReplyHandler, SendData,
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
use crate::utils::PromptKind;
|
|
|
|
|
|
|
|
|
|
use anyhow::{bail, Result};
|
|
|
|
|
use async_trait::async_trait;
|
|
|
|
|
use reqwest::{Client as ReqwestClient, RequestBuilder};
|
|
|
|
|
use serde::Deserialize;
|
|
|
|
|
use serde_json::{json, Value};
|
|
|
|
@ -22,26 +21,6 @@ pub struct CohereConfig {
|
|
|
|
|
pub extra: Option<ExtraConfig>,
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
#[async_trait]
|
|
|
|
|
impl Client for CohereClient {
|
|
|
|
|
client_common_fns!();
|
|
|
|
|
|
|
|
|
|
async fn send_message_inner(&self, client: &ReqwestClient, data: SendData) -> Result<String> {
|
|
|
|
|
let builder = self.request_builder(client, data)?;
|
|
|
|
|
send_message(builder).await
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
async fn send_message_streaming_inner(
|
|
|
|
|
&self,
|
|
|
|
|
client: &ReqwestClient,
|
|
|
|
|
handler: &mut ReplyHandler,
|
|
|
|
|
data: SendData,
|
|
|
|
|
) -> Result<()> {
|
|
|
|
|
let builder = self.request_builder(client, data)?;
|
|
|
|
|
send_message_streaming(builder, handler).await
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
impl CohereClient {
|
|
|
|
|
list_models_fn!(
|
|
|
|
|
CohereConfig,
|
|
|
|
@ -74,7 +53,9 @@ impl CohereClient {
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
pub(crate) async fn send_message(builder: RequestBuilder) -> Result<String> {
|
|
|
|
|
impl_client_trait!(CohereClient, send_message, send_message_streaming);
|
|
|
|
|
|
|
|
|
|
async fn send_message(builder: RequestBuilder) -> Result<String> {
|
|
|
|
|
let res = builder.send().await?;
|
|
|
|
|
let status = res.status();
|
|
|
|
|
let data: Value = res.json().await?;
|
|
|
|
@ -85,10 +66,7 @@ pub(crate) async fn send_message(builder: RequestBuilder) -> Result<String> {
|
|
|
|
|
Ok(output.to_string())
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
pub(crate) async fn send_message_streaming(
|
|
|
|
|
builder: RequestBuilder,
|
|
|
|
|
handler: &mut ReplyHandler,
|
|
|
|
|
) -> Result<()> {
|
|
|
|
|
async fn send_message_streaming(builder: RequestBuilder, handler: &mut ReplyHandler) -> Result<()> {
|
|
|
|
|
let res = builder.send().await?;
|
|
|
|
|
let status = res.status();
|
|
|
|
|
if status != 200 {
|
|
|
|
|