From b33e2da75efab4246f9f1d74a0be56aba1d8bd07 Mon Sep 17 00:00:00 2001 From: sigoden Date: Sun, 28 Apr 2024 20:37:46 +0800 Subject: [PATCH] refactor: improve serve responses (#455) --- Argcfile.sh | 13 +++++++++---- src/serve.rs | 41 +++++++++++++++++++++++++++-------------- 2 files changed, 36 insertions(+), 18 deletions(-) diff --git a/Argcfile.sh b/Argcfile.sh index a438bbb..1e5229a 100755 --- a/Argcfile.sh +++ b/Argcfile.sh @@ -28,18 +28,23 @@ test-without-config() { # @arg clients+[`_choice_client`] test-clients() { for c in "${argc_clients[@]}"; do - echo "### $c streaming" + echo "### $c stream" aichat -m "$c" 1 + 2 = ? - echo "### $c non-streaming" + echo "### $c non-stream" aichat -m "$c" -S 1 + 2 = ? done } # @cmd Test proxy server -# @option --model=default +# @option -m --model=default +# @flag -S --no-stream # @arg text~ test-server() { - argc generic-chat \ + args=() + if [[ -n "$argc_no_stream" ]]; then + args+=("-S") + fi + argc generic-chat "${args[@]}" \ --api-base http://localhost:8000/v1 \ --model $argc_model \ "$@" diff --git a/src/serve.rs b/src/serve.rs index 2d5ed95..394a0e4 100644 --- a/src/serve.rs +++ b/src/serve.rs @@ -30,6 +30,7 @@ use tokio_graceful::Shutdown; use tokio_stream::wrappers::UnboundedReceiverStream; const DEFAULT_ADDRESS: &str = "127.0.0.1:8000"; +const DEFAULT_MODEL_NAME: &str = "default"; type AppResponse = Response>; @@ -145,8 +146,17 @@ impl Server { ..Default::default() }; let config = Arc::new(RwLock::new(config)); - if model != "default" && model != self.model.id() { - config.write().set_model(&model)?; + + let (model_name, change) = if model == DEFAULT_MODEL_NAME { + (self.model.id(), true) + } else if self.model.id() == model { + (model, false) + } else { + (model, true) + }; + + if change { + config.write().set_model(&model_name)?; } let mut client = init_client(&config)?; @@ -198,6 +208,7 @@ impl Server { if let Err(err) = ret { send_first_event(&tx, Some(format!("{err:?}")), &mut is_first) } + let _ = tx.send(ResEvent::Done); } } }); @@ -208,16 +219,23 @@ impl Server { bail!("{err}"); } - let shared: Arc<(String, i64)> = Arc::new((completion_id, created)); + let shared: Arc<(String, String, i64)> = Arc::new((completion_id, model_name, created)); let stream = UnboundedReceiverStream::new(rx); let stream = stream.filter_map(move |res_event| { let shared = shared.clone(); async move { + let (completion_id, model, created) = shared.as_ref(); match res_event { - ResEvent::Text(text) => { - Some(Ok(create_frame(&shared.0, shared.1, &text, false))) + ResEvent::Text(text) => Some(Ok(create_frame( + completion_id, + model, + *created, + &text, + false, + ))), + ResEvent::Done => { + Some(Ok(create_frame(completion_id, model, *created, "", true))) } - ResEvent::Done => Some(Ok(create_frame(&shared.0, shared.1, "", true))), _ => None, } } @@ -290,7 +308,7 @@ fn set_cors_header(res: &mut AppResponse) { ); } -fn create_frame(id: &str, created: i64, content: &str, done: bool) -> Frame { +fn create_frame(id: &str, model: &str, created: i64, content: &str, done: bool) -> Frame { let (delta, finish_reason) = if done { (json!({}), "stop".into()) } else { @@ -301,11 +319,11 @@ fn create_frame(id: &str, created: i64, content: &str, done: bool) -> Frame Frame