refactor: improve serve responses (#455)

pull/456/head
sigoden 3 weeks ago committed by GitHub
parent d9b8eabf23
commit b33e2da75e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

@ -28,18 +28,23 @@ test-without-config() {
# @arg clients+[`_choice_client`]
test-clients() {
for c in "${argc_clients[@]}"; do
echo "### $c streaming"
echo "### $c stream"
aichat -m "$c" 1 + 2 = ?
echo "### $c non-streaming"
echo "### $c non-stream"
aichat -m "$c" -S 1 + 2 = ?
done
}
# @cmd Test proxy server
# @option --model=default
# @option -m --model=default
# @flag -S --no-stream
# @arg text~
test-server() {
argc generic-chat \
args=()
if [[ -n "$argc_no_stream" ]]; then
args+=("-S")
fi
argc generic-chat "${args[@]}" \
--api-base http://localhost:8000/v1 \
--model $argc_model \
"$@"

@ -30,6 +30,7 @@ use tokio_graceful::Shutdown;
use tokio_stream::wrappers::UnboundedReceiverStream;
const DEFAULT_ADDRESS: &str = "127.0.0.1:8000";
const DEFAULT_MODEL_NAME: &str = "default";
type AppResponse = Response<BoxBody<Bytes, Infallible>>;
@ -145,8 +146,17 @@ impl Server {
..Default::default()
};
let config = Arc::new(RwLock::new(config));
if model != "default" && model != self.model.id() {
config.write().set_model(&model)?;
let (model_name, change) = if model == DEFAULT_MODEL_NAME {
(self.model.id(), true)
} else if self.model.id() == model {
(model, false)
} else {
(model, true)
};
if change {
config.write().set_model(&model_name)?;
}
let mut client = init_client(&config)?;
@ -198,6 +208,7 @@ impl Server {
if let Err(err) = ret {
send_first_event(&tx, Some(format!("{err:?}")), &mut is_first)
}
let _ = tx.send(ResEvent::Done);
}
}
});
@ -208,16 +219,23 @@ impl Server {
bail!("{err}");
}
let shared: Arc<(String, i64)> = Arc::new((completion_id, created));
let shared: Arc<(String, String, i64)> = Arc::new((completion_id, model_name, created));
let stream = UnboundedReceiverStream::new(rx);
let stream = stream.filter_map(move |res_event| {
let shared = shared.clone();
async move {
let (completion_id, model, created) = shared.as_ref();
match res_event {
ResEvent::Text(text) => {
Some(Ok(create_frame(&shared.0, shared.1, &text, false)))
ResEvent::Text(text) => Some(Ok(create_frame(
completion_id,
model,
*created,
&text,
false,
))),
ResEvent::Done => {
Some(Ok(create_frame(completion_id, model, *created, "", true)))
}
ResEvent::Done => Some(Ok(create_frame(&shared.0, shared.1, "", true))),
_ => None,
}
}
@ -290,7 +308,7 @@ fn set_cors_header(res: &mut AppResponse) {
);
}
fn create_frame(id: &str, created: i64, content: &str, done: bool) -> Frame<Bytes> {
fn create_frame(id: &str, model: &str, created: i64, content: &str, done: bool) -> Frame<Bytes> {
let (delta, finish_reason) = if done {
(json!({}), "stop".into())
} else {
@ -301,11 +319,11 @@ fn create_frame(id: &str, created: i64, content: &str, done: bool) -> Frame<Byte
};
(delta, Value::Null)
};
let mut value = json!({
let value = json!({
"id": id,
"object": "chat.completion.chunk",
"created": created,
"model": "gpt-3.5-turbo",
"model": model,
"choices": [
{
"index": 0,
@ -315,11 +333,6 @@ fn create_frame(id: &str, created: i64, content: &str, done: bool) -> Frame<Byte
],
});
let output = if done {
value["usage"] = json!({
"prompt_tokens": 0,
"completion_tokens": 0,
"total_tokens": 0,
});
format!("data: {value}\n\ndata: [DONE]\n\n")
} else {
format!("data: {value}\n\n")

Loading…
Cancel
Save