feat: serve all LLMs as OpenAI-compatible API (#431)

pull/432/head
sigoden 2 weeks ago committed by GitHub
parent 9c6c9f10a2
commit 0a4c0413ef
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

308
Cargo.lock generated

@ -46,12 +46,16 @@ dependencies = [
"dirs",
"fancy-regex",
"futures-util",
"http",
"http-body-util",
"hyper",
"hyper-util",
"inquire",
"is-terminal",
"lazy_static",
"log",
"mime_guess",
"nu-ansi-term",
"nu-ansi-term 0.50.0",
"parking_lot",
"reedline",
"reqwest",
@ -65,7 +69,10 @@ dependencies = [
"simplelog",
"syntect",
"textwrap",
"time",
"tokio",
"tokio-graceful",
"tokio-stream",
"unicode-width",
]
@ -280,7 +287,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "05efc5cfd9110c8416e471df0e96702d58690178e206e61b7173706673c93706"
dependencies = [
"memchr",
"regex-automata",
"regex-automata 0.4.6",
"serde",
]
@ -338,7 +345,7 @@ dependencies = [
"num-traits",
"serde",
"wasm-bindgen",
"windows-targets 0.52.4",
"windows-targets 0.52.5",
]
[[package]]
@ -601,8 +608,8 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "531e46835a22af56d1e3b66f04844bed63158bc094a628bec1d321d9b4c44bf2"
dependencies = [
"bit-set",
"regex-automata",
"regex-syntax",
"regex-automata 0.4.6",
"regex-syntax 0.8.3",
]
[[package]]
@ -749,6 +756,20 @@ dependencies = [
"byteorder",
]
[[package]]
name = "generator"
version = "0.8.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "186014d53bc231d0090ef8d6f03e0920c54d85a5ed22f4f2f74315ec56cf83fb"
dependencies = [
"cc",
"cfg-if",
"libc",
"log",
"rustversion",
"windows",
]
[[package]]
name = "generic-array"
version = "0.14.7"
@ -786,6 +807,25 @@ version = "0.28.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4271d37baee1b8c7e4b708028c57d816cf9d2434acb33a549475f78c181f6253"
[[package]]
name = "h2"
version = "0.4.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "816ec7294445779408f36fe57bc5b7fc1cf59664059096c65f905c1c61f58069"
dependencies = [
"bytes",
"fnv",
"futures-core",
"futures-sink",
"futures-util",
"http",
"indexmap",
"slab",
"tokio",
"tokio-util",
"tracing",
]
[[package]]
name = "hashbrown"
version = "0.14.3"
@ -859,6 +899,12 @@ version = "1.8.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d897f394bad6a705d5f4104762e116a75639e470d80901eed05a860a95cb1904"
[[package]]
name = "httpdate"
version = "1.0.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "df3b46402a9d5adb4c86a0cf463f42e19994e3ee891101b1841f30a545cb49a9"
[[package]]
name = "hyper"
version = "1.2.0"
@ -868,9 +914,11 @@ dependencies = [
"bytes",
"futures-channel",
"futures-util",
"h2",
"http",
"http-body",
"httparse",
"httpdate",
"itoa",
"pin-project-lite",
"smallvec",
@ -926,7 +974,7 @@ dependencies = [
"iana-time-zone-haiku",
"js-sys",
"wasm-bindgen",
"windows-core",
"windows-core 0.52.0",
]
[[package]]
@ -1035,7 +1083,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0c2a198fb6b0eada2a8df47933734e6d35d350665a33a3593d7164fa52c75c19"
dependencies = [
"cfg-if",
"windows-targets 0.52.4",
"windows-targets 0.52.5",
]
[[package]]
@ -1076,6 +1124,22 @@ version = "0.4.21"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "90ed8c1e510134f979dbc4f070f87d4313098b704861a105fe34231c70a3901c"
[[package]]
name = "loom"
version = "0.7.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "419e0dc8046cb947daa77eb95ae174acfbddb7673b4151f56d1eed8e93fbfaca"
dependencies = [
"cfg-if",
"generator",
"pin-utils",
"scoped-tls",
"serde",
"serde_json",
"tracing",
"tracing-subscriber",
]
[[package]]
name = "malloc_buf"
version = "0.0.6"
@ -1085,6 +1149,15 @@ dependencies = [
"libc",
]
[[package]]
name = "matchers"
version = "0.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8263075bb86c5a1b1427b5ae862e8889656f126e9f77c484496e8b47cf5c5558"
dependencies = [
"regex-automata 0.1.10",
]
[[package]]
name = "memchr"
version = "2.7.2"
@ -1165,6 +1238,16 @@ dependencies = [
"minimal-lexical",
]
[[package]]
name = "nu-ansi-term"
version = "0.46.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "77a8165726e8236064dbb45459242600304b42a5ea24ee2948e18e023bf7ba84"
dependencies = [
"overload",
"winapi",
]
[[package]]
name = "nu-ansi-term"
version = "0.50.0"
@ -1296,6 +1379,12 @@ dependencies = [
"windows-sys 0.52.0",
]
[[package]]
name = "overload"
version = "0.1.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b15813163c1d831bf4a13c3610c05c0d03b39feb07f7e09fa234dac9b15aaf39"
[[package]]
name = "parking_lot"
version = "0.12.1"
@ -1450,7 +1539,7 @@ dependencies = [
"crossterm 0.27.0",
"fd-lock",
"itertools",
"nu-ansi-term",
"nu-ansi-term 0.50.0",
"serde",
"strip-ansi-escapes",
"strum",
@ -1460,6 +1549,27 @@ dependencies = [
"unicode-width",
]
[[package]]
name = "regex"
version = "1.10.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c117dbdfde9c8308975b6a18d71f3f385c89461f7b3fb054288ecf2a2058ba4c"
dependencies = [
"aho-corasick",
"memchr",
"regex-automata 0.4.6",
"regex-syntax 0.8.3",
]
[[package]]
name = "regex-automata"
version = "0.1.10"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6c230d73fb8d8c1b9c0b3135c5142a8acee3a0558fb8db5cf1cb65f8d7862132"
dependencies = [
"regex-syntax 0.6.29",
]
[[package]]
name = "regex-automata"
version = "0.4.6"
@ -1468,9 +1578,15 @@ checksum = "86b83b8b9847f9bf95ef68afb0b8e6cdb80f498442f5179a29fad448fcc1eaea"
dependencies = [
"aho-corasick",
"memchr",
"regex-syntax",
"regex-syntax 0.8.3",
]
[[package]]
name = "regex-syntax"
version = "0.6.29"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f162c6dd7b008981e4d40210aca20b4bd0f9b60ca9271061b07f78537722f2e1"
[[package]]
name = "regex-syntax"
version = "0.8.3"
@ -1775,6 +1891,15 @@ dependencies = [
"digest",
]
[[package]]
name = "sharded-slab"
version = "0.1.7"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f40ca3c46823713e0d4209592e8d6e826aa57e928f09752619fc696c499637f6"
dependencies = [
"lazy_static",
]
[[package]]
name = "shell-words"
version = "1.1.0"
@ -1929,7 +2054,7 @@ dependencies = [
"once_cell",
"onig",
"plist",
"regex-syntax",
"regex-syntax 0.8.3",
"serde",
"serde_derive",
"serde_json",
@ -2001,9 +2126,9 @@ dependencies = [
[[package]]
name = "time"
version = "0.3.35"
version = "0.3.36"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ef89ece63debf11bc32d1ed8d078ac870cbeb44da02afb02a9ff135ae7ca0582"
checksum = "5dfd88e563464686c916c7e46e623e520ddc6d79fa6641390f2e3fa86e83e885"
dependencies = [
"deranged",
"itoa",
@ -2065,6 +2190,19 @@ dependencies = [
"windows-sys 0.48.0",
]
[[package]]
name = "tokio-graceful"
version = "0.1.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "627ba4daa4cbce14740603401c895e72d47ecd86690a18e3f0841266e9340de7"
dependencies = [
"loom",
"pin-project-lite",
"slab",
"tokio",
"tracing",
]
[[package]]
name = "tokio-macros"
version = "2.2.0"
@ -2099,6 +2237,18 @@ dependencies = [
"tokio",
]
[[package]]
name = "tokio-stream"
version = "0.1.15"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "267ac89e0bec6e691e5813911606935d77c476ff49024f98abcea3e7b15e37af"
dependencies = [
"futures-core",
"pin-project-lite",
"tokio",
"tokio-util",
]
[[package]]
name = "tokio-util"
version = "0.7.10"
@ -2149,9 +2299,21 @@ checksum = "c3523ab5a71916ccf420eebdf5521fcef02141234bbc0b8a49f2fdc4544364ef"
dependencies = [
"log",
"pin-project-lite",
"tracing-attributes",
"tracing-core",
]
[[package]]
name = "tracing-attributes"
version = "0.1.27"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "34704c8d6ebcbc939824180af020566b01a7c01f80641264eba0999f6c2b6be7"
dependencies = [
"proc-macro2",
"quote",
"syn",
]
[[package]]
name = "tracing-core"
version = "0.1.32"
@ -2159,6 +2321,36 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c06d3da6113f116aaee68e4d601191614c9053067f9ab7f6edbcb161237daa54"
dependencies = [
"once_cell",
"valuable",
]
[[package]]
name = "tracing-log"
version = "0.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ee855f1f400bd0e5c02d150ae5de3840039a3f54b025156404e34c23c03f47c3"
dependencies = [
"log",
"once_cell",
"tracing-core",
]
[[package]]
name = "tracing-subscriber"
version = "0.3.18"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ad0f048c97dbd9faa9b7df56362b8ebcaa52adb06b498c050d2f4e32f90a7a8b"
dependencies = [
"matchers",
"nu-ansi-term 0.46.0",
"once_cell",
"regex",
"sharded-slab",
"smallvec",
"thread_local",
"tracing",
"tracing-core",
"tracing-log",
]
[[package]]
@ -2264,6 +2456,12 @@ version = "0.2.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "711b9620af191e0cdc7468a8d14e709c3dcdb115b36f838e601583af800a370a"
[[package]]
name = "valuable"
version = "0.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "830b7e5d4d90034032940e4ace0d9a9a057e7a45cd94e6c007832e39edb82f6d"
[[package]]
name = "version_check"
version = "0.9.4"
@ -2517,13 +2715,42 @@ version = "0.4.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "712e227841d057c1ee1cd2fb22fa7e5a5461ae8e48fa2ca79ec42cfc1931183f"
[[package]]
name = "windows"
version = "0.54.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9252e5725dbed82865af151df558e754e4a3c2c30818359eb17465f1346a1b49"
dependencies = [
"windows-core 0.54.0",
"windows-targets 0.52.5",
]
[[package]]
name = "windows-core"
version = "0.52.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "33ab640c8d7e35bf8ba19b884ba838ceb4fba93a4e8c65a9059d08afcfc683d9"
dependencies = [
"windows-targets 0.52.4",
"windows-targets 0.52.5",
]
[[package]]
name = "windows-core"
version = "0.54.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "12661b9c89351d684a50a8a643ce5f608e20243b9fb84687800163429f161d65"
dependencies = [
"windows-result",
"windows-targets 0.52.5",
]
[[package]]
name = "windows-result"
version = "0.1.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "749f0da9cc72d82e600d8d2e44cadd0b9eedb9038f71a1c58556ac1c5791813b"
dependencies = [
"windows-targets 0.52.5",
]
[[package]]
@ -2541,7 +2768,7 @@ version = "0.52.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "282be5f36a8ce781fad8c8ae18fa3f9beff57ec1b52cb3de0789201425d9a33d"
dependencies = [
"windows-targets 0.52.4",
"windows-targets 0.52.5",
]
[[package]]
@ -2561,17 +2788,18 @@ dependencies = [
[[package]]
name = "windows-targets"
version = "0.52.4"
version = "0.52.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7dd37b7e5ab9018759f893a1952c9420d060016fc19a472b4bb20d1bdd694d1b"
checksum = "6f0713a46559409d202e70e28227288446bf7841d3211583a4b53e3f6d96e7eb"
dependencies = [
"windows_aarch64_gnullvm 0.52.4",
"windows_aarch64_msvc 0.52.4",
"windows_i686_gnu 0.52.4",
"windows_i686_msvc 0.52.4",
"windows_x86_64_gnu 0.52.4",
"windows_x86_64_gnullvm 0.52.4",
"windows_x86_64_msvc 0.52.4",
"windows_aarch64_gnullvm 0.52.5",
"windows_aarch64_msvc 0.52.5",
"windows_i686_gnu 0.52.5",
"windows_i686_gnullvm",
"windows_i686_msvc 0.52.5",
"windows_x86_64_gnu 0.52.5",
"windows_x86_64_gnullvm 0.52.5",
"windows_x86_64_msvc 0.52.5",
]
[[package]]
@ -2582,9 +2810,9 @@ checksum = "2b38e32f0abccf9987a4e3079dfb67dcd799fb61361e53e2882c3cbaf0d905d8"
[[package]]
name = "windows_aarch64_gnullvm"
version = "0.52.4"
version = "0.52.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "bcf46cf4c365c6f2d1cc93ce535f2c8b244591df96ceee75d8e83deb70a9cac9"
checksum = "7088eed71e8b8dda258ecc8bac5fb1153c5cffaf2578fc8ff5d61e23578d3263"
[[package]]
name = "windows_aarch64_msvc"
@ -2594,9 +2822,9 @@ checksum = "dc35310971f3b2dbbf3f0690a219f40e2d9afcf64f9ab7cc1be722937c26b4bc"
[[package]]
name = "windows_aarch64_msvc"
version = "0.52.4"
version = "0.52.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "da9f259dd3bcf6990b55bffd094c4f7235817ba4ceebde8e6d11cd0c5633b675"
checksum = "9985fd1504e250c615ca5f281c3f7a6da76213ebd5ccc9561496568a2752afb6"
[[package]]
name = "windows_i686_gnu"
@ -2606,9 +2834,15 @@ checksum = "a75915e7def60c94dcef72200b9a8e58e5091744960da64ec734a6c6e9b3743e"
[[package]]
name = "windows_i686_gnu"
version = "0.52.4"
version = "0.52.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "88ba073cf16d5372720ec942a8ccbf61626074c6d4dd2e745299726ce8b89670"
[[package]]
name = "windows_i686_gnullvm"
version = "0.52.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b474d8268f99e0995f25b9f095bc7434632601028cf86590aea5c8a5cb7801d3"
checksum = "87f4261229030a858f36b459e748ae97545d6f1ec60e5e0d6a3d32e0dc232ee9"
[[package]]
name = "windows_i686_msvc"
@ -2618,9 +2852,9 @@ checksum = "8f55c233f70c4b27f66c523580f78f1004e8b5a8b659e05a4eb49d4166cca406"
[[package]]
name = "windows_i686_msvc"
version = "0.52.4"
version = "0.52.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1515e9a29e5bed743cb4415a9ecf5dfca648ce85ee42e15873c3cd8610ff8e02"
checksum = "db3c2bf3d13d5b658be73463284eaf12830ac9a26a90c717b7f771dfe97487bf"
[[package]]
name = "windows_x86_64_gnu"
@ -2630,9 +2864,9 @@ checksum = "53d40abd2583d23e4718fddf1ebec84dbff8381c07cae67ff7768bbf19c6718e"
[[package]]
name = "windows_x86_64_gnu"
version = "0.52.4"
version = "0.52.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5eee091590e89cc02ad514ffe3ead9eb6b660aedca2183455434b93546371a03"
checksum = "4e4246f76bdeff09eb48875a0fd3e2af6aada79d409d33011886d3e1581517d9"
[[package]]
name = "windows_x86_64_gnullvm"
@ -2642,9 +2876,9 @@ checksum = "0b7b52767868a23d5bab768e390dc5f5c55825b6d30b86c844ff2dc7414044cc"
[[package]]
name = "windows_x86_64_gnullvm"
version = "0.52.4"
version = "0.52.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "77ca79f2451b49fa9e2af39f0747fe999fcda4f5e241b2898624dca97a1f2177"
checksum = "852298e482cd67c356ddd9570386e2862b5673c85bd5f88df9ab6802b334c596"
[[package]]
name = "windows_x86_64_msvc"
@ -2654,9 +2888,9 @@ checksum = "ed94fce61571a4006852b7389a063ab983c02eb1bb37b47f8272ce92d06d9538"
[[package]]
name = "windows_x86_64_msvc"
version = "0.52.4"
version = "0.52.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "32b752e52a2da0ddfbdbcc6fceadfeede4c939ed16d13e648833a61dfb611ed8"
checksum = "bec47e5bfd1bff0eeaf6d8b485cc1074891a197ab4225d504cb7a1ab88b02bf0"
[[package]]
name = "winreg"

@ -23,6 +23,8 @@ serde = { version = "1.0.152", features = ["derive"] }
serde_json = { version = "1.0.93", features = ["preserve_order"] }
serde_yaml = "0.9.17"
tokio = { version = "1.34.0", features = ["rt", "time", "macros", "signal", "rt-multi-thread"] }
tokio-graceful = "0.1.6"
tokio-stream = { version = "0.1.15", default-features = false, features = ["sync"] }
crossterm = "0.27.0"
chrono = "0.4.23"
bincode = "1.3.3"
@ -45,6 +47,11 @@ sha2 = "0.10.8"
bitflags = "2.4.1"
unicode-width = "0.1.11"
async-recursion = "1.1.0"
http = "1.1.0"
http-body-util = "0.1"
hyper = { version = "1.0", features = ["full"] }
hyper-util = { version = "0.1", features = ["server-auto", "client-legacy"] }
time = { version = "0.3.36", features = ["macros"] }
[dependencies.reqwest]
version = "0.12.0"

@ -19,7 +19,7 @@ _aichat() {
case "${cmd}" in
aichat)
opts="-m -r -s -e -c -f -H -S -w -h -V --model --role --session --save-session --execute --code --file --no-highlight --no-stream --wrap --light-theme --dry-run --info --list-models --list-roles --list-sessions --help --version"
opts="-m -r -s -e -c -f -H -S -w -h -V --model --role --session --save-session --serve --execute --code --file --no-highlight --no-stream --wrap --light-theme --dry-run --info --list-models --list-roles --list-sessions --help --version"
if [[ ${cur} == -* || ${COMP_CWORD} -eq 1 ]] ; then
COMPREPLY=( $(compgen -W "${opts}" -- "${cur}") )
return 0

@ -4,6 +4,7 @@ complete -c aichat -s s -l session -x -a"(aichat --list-sessions)" -d 'Start or
complete -c aichat -s f -l file -d 'Include files with the message' -r -F
complete -c aichat -s w -l wrap -d 'Control text wrapping (no, auto, <max-width>)'
complete -c aichat -l save-session -d 'Forces the session to be saved'
complete -c aichat -l serve -d 'Serve all LLMs as OpenAI-compatible API'
complete -c aichat -s e -l execute -d 'Execute commands in natural language'
complete -c aichat -s c -l code -d 'Output code only'
complete -c aichat -s H -l no-highlight -d 'Turn off syntax highlighting'

@ -28,6 +28,7 @@ module completions {
--role(-r): string@"nu-complete aichat role" # Select a role
--session(-s): string@"nu-complete aichat role" # Start or join a session
--save-session # Forces the session to be saved
--serve # Serve all LLMs as OpenAI-compatible API
--execute(-e) # Execute commands in natural language
--code(-c) # Output code only
--file(-f): string # Include files with the message

@ -31,6 +31,7 @@ Register-ArgumentCompleter -Native -CommandName 'aichat' -ScriptBlock {
[CompletionResult]::new('-w', '-w', [CompletionResultType]::ParameterName, 'Control text wrapping (no, auto, <max-width>)')
[CompletionResult]::new('--wrap', '--wrap', [CompletionResultType]::ParameterName, 'Control text wrapping (no, auto, <max-width>)')
[CompletionResult]::new('--save-session', '--save-session', [CompletionResultType]::ParameterName, 'Forces the session to be saved')
[CompletionResult]::new('--serve', '--serve', [CompletionResultType]::ParameterName, 'Serve all LLMs as OpenAI-compatible API')
[CompletionResult]::new('-e', '-e', [CompletionResultType]::ParameterName, 'Execute commands in natural language')
[CompletionResult]::new('--execute', '--execute', [CompletionResultType]::ParameterName, 'Execute commands in natural language')
[CompletionResult]::new('-c', '-c', [CompletionResultType]::ParameterName, 'Output code only')

@ -26,6 +26,7 @@ _aichat() {
'-w+[Control text wrapping (no, auto, <max-width>)]:WRAP: ' \
'--wrap=[Control text wrapping (no, auto, <max-width>)]:WRAP: ' \
'--save-session[Forces the session to be saved]' \
'--serve[Serve all LLMs as OpenAI-compatible API]' \
'-e[Execute commands in natural language]' \
'--execute[Execute commands in natural language]' \
'-c[Output code only]' \

@ -15,6 +15,9 @@ pub struct Cli {
/// Forces the session to be saved
#[clap(long)]
pub save_session: bool,
/// Serve all LLMs as OpenAI-compatible API
#[clap(long, value_name = "ADDRESS")]
pub serve: Option<Option<String>>,
/// Execute commands in natural language
#[clap(short = 'e', long)]
pub execute: bool,

@ -19,7 +19,7 @@ impl ReplyHandler {
}
pub fn text(&mut self, text: &str) -> Result<()> {
debug!("ReplyText: {}", text);
// debug!("ReplyText: {}", text);
if text.is_empty() {
return Ok(());
}
@ -33,7 +33,7 @@ impl ReplyHandler {
}
pub fn done(&mut self) -> Result<()> {
debug!("ReplyDone");
// debug!("ReplyDone");
let ret = self
.sender
.send(ReplyEvent::Done)

@ -79,9 +79,9 @@ pub struct Config {
#[serde(skip)]
pub model: Model,
#[serde(skip)]
pub last_message: Option<(Input, String)>,
pub working_mode: WorkingMode,
#[serde(skip)]
pub in_repl: bool,
pub last_message: Option<(Input, String)>,
}
impl Default for Config {
@ -110,8 +110,8 @@ impl Default for Config {
role: None,
session: None,
model: Default::default(),
working_mode: WorkingMode::Command,
last_message: None,
in_repl: false,
}
}
}
@ -119,13 +119,13 @@ impl Default for Config {
pub type GlobalConfig = Arc<RwLock<Config>>;
impl Config {
pub fn init(is_interactive: bool) -> Result<Self> {
pub fn init(working_mode: WorkingMode) -> Result<Self> {
let config_path = Self::config_file()?;
let api_key = env::var("OPENAI_API_KEY").ok();
let exist_config_path = config_path.exists();
if is_interactive && api_key.is_none() && !exist_config_path {
if working_mode != WorkingMode::Command && api_key.is_none() && !exist_config_path {
create_config_file(&config_path)?;
}
let mut config = if api_key.is_some() && !exist_config_path {
@ -143,14 +143,13 @@ impl Config {
config.set_wrap(&wrap)?;
}
config.working_mode = working_mode;
config.load_roles()?;
config.setup_model()?;
config.setup_highlight();
config.setup_light_theme()?;
setup_logger()?;
Ok(config)
}
@ -611,7 +610,7 @@ impl Config {
let save_session = session.save_session();
if session.dirty && save_session != Some(false) {
if save_session.is_none() || session.is_temp() {
if !self.in_repl {
if self.working_mode != WorkingMode::Repl {
return Ok(());
}
let ans = Confirm::new("Save session?").with_default(false).prompt()?;
@ -998,6 +997,13 @@ impl Keybindings {
}
}
/// The way the program was invoked, selected once at startup and stored on `Config`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum WorkingMode {
/// One-shot run: input text and/or files were supplied on the command line.
Command,
/// Interactive session: chosen when neither text nor files were supplied.
Repl,
/// HTTP server: chosen when the `--serve` flag is present.
Serve,
}
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum State {
Normal,
@ -1145,25 +1151,3 @@ fn complete_option_bool(value: Option<bool>) -> Vec<String> {
None => vec!["true".to_string(), "false".to_string()],
}
}
#[cfg(debug_assertions)]
fn setup_logger() -> Result<()> {
use simplelog::{LevelFilter, WriteLogger};
let file = std::fs::File::create(Config::local_path("debug.log")?)?;
let log_filter = match std::env::var("AICHAT_LOG_FILTER") {
Ok(v) => v,
Err(_) => "aichat".into(),
};
let config = simplelog::ConfigBuilder::new()
.add_filter_allow(log_filter)
.set_thread_level(LevelFilter::Off)
.set_time_level(LevelFilter::Off)
.build();
WriteLogger::init(log::LevelFilter::Debug, config, file)?;
Ok(())
}
#[cfg(not(debug_assertions))]
fn setup_logger() -> Result<()> {
Ok(())
}

@ -0,0 +1,40 @@
use crate::config::WorkingMode;
use anyhow::Result;
use log::LevelFilter;
use simplelog::{format_description, Config as LogConfig, ConfigBuilder};
/// Initializes the global logger for debug builds.
///
/// In [`WorkingMode::Serve`] log records go to the terminal; in every other
/// mode they are written to a freshly truncated `debug.log` file located via
/// `Config::local_path`. Both sinks log at `Debug` level.
///
/// # Errors
///
/// Returns an error if the log file (or its parent directory) cannot be
/// created, or if a global logger has already been installed.
#[cfg(debug_assertions)]
pub fn setup_logger(working_mode: WorkingMode) -> Result<()> {
    let config = build_config();
    if working_mode == WorkingMode::Serve {
        simplelog::SimpleLogger::init(LevelFilter::Debug, config)?;
    } else {
        let log_path = crate::config::Config::local_path("debug.log")?;
        let log_path: &std::path::Path = log_path.as_ref();
        // The logger is installed before the configuration is initialized, so
        // the directory that should hold `debug.log` may not exist yet. Create
        // it up front instead of failing the whole program on `File::create`.
        if let Some(parent) = log_path.parent() {
            std::fs::create_dir_all(parent)?;
        }
        let file = std::fs::File::create(log_path)?;
        simplelog::WriteLogger::init(LevelFilter::Debug, config, file)?;
    }
    Ok(())
}
/// Initializes the global logger for release builds.
///
/// Only [`WorkingMode::Serve`] installs a logger (terminal output at `Info`
/// level); the other modes run without one.
#[cfg(not(debug_assertions))]
pub fn setup_logger(working_mode: WorkingMode) -> Result<()> {
    let log_config = build_config();
    if working_mode != WorkingMode::Serve {
        return Ok(());
    }
    simplelog::SimpleLogger::init(log::LevelFilter::Info, log_config)?;
    Ok(())
}
/// Builds the `simplelog` configuration shared by every logger variant.
///
/// Records are restricted to the target named by the `AICHAT_LOG_FILTER`
/// environment variable (falling back to `aichat`), timestamps use a custom
/// millisecond-precision format, and thread ids are never printed.
fn build_config() -> LogConfig {
    let allowed_target =
        std::env::var("AICHAT_LOG_FILTER").unwrap_or_else(|_| String::from("aichat"));
    let mut builder = ConfigBuilder::new();
    builder
        .add_filter_allow(allowed_target)
        .set_time_format_custom(format_description!(
            "[year]-[month]-[day]T[hour]:[minute]:[second].[subsecond digits:3]Z"
        ))
        .set_thread_level(LevelFilter::Off);
    builder.build()
}

@ -1,17 +1,21 @@
mod cli;
mod client;
mod config;
mod logger;
mod render;
mod repl;
mod serve;
#[macro_use]
mod utils;
#[macro_use]
extern crate log;
#[macro_use]
mod utils;
use crate::cli::Cli;
use crate::client::{ensure_model_capabilities, init_client, list_models, send_stream};
use crate::config::{Config, GlobalConfig, Input, CODE_ROLE, EXPLAIN_ROLE, SHELL_ROLE};
use crate::config::{
Config, GlobalConfig, Input, WorkingMode, CODE_ROLE, EXPLAIN_ROLE, SHELL_ROLE,
};
use crate::render::{render_error, MarkdownRender};
use crate::repl::Repl;
use crate::utils::{
@ -33,7 +37,21 @@ use tokio::sync::oneshot;
async fn main() -> Result<()> {
let cli = Cli::parse();
let text = cli.text();
let config = Arc::new(RwLock::new(Config::init(text.is_none())?));
let file = &cli.file;
let no_input = text.is_none() && file.is_empty();
let working_mode = if cli.serve.is_some() {
WorkingMode::Serve
} else if no_input {
WorkingMode::Repl
} else {
WorkingMode::Command
};
crate::logger::setup_logger(working_mode)?;
let config = Arc::new(RwLock::new(Config::init(working_mode)?));
if let Some(addr) = cli.serve {
return serve::run(config, addr).await;
}
if cli.list_roles {
config
.read()
@ -89,20 +107,21 @@ async fn main() -> Result<()> {
return Ok(());
}
let text = aggregate_text(text)?;
let input = create_input(&config, text, &cli.file)?;
if cli.execute {
match input {
Some(input) => {
execute(&config, input).await?;
return Ok(());
}
None => bail!("No input text"),
if no_input {
bail!("No input");
}
let input = create_input(&config, text, file)?;
execute(&config, input).await?;
return Ok(());
}
config.write().apply_prelude()?;
if let Err(err) = match input {
Some(input) => start_directive(&config, input, cli.no_stream, cli.code).await,
None => start_interactive(&config).await,
if let Err(err) = match no_input {
false => {
let input = create_input(&config, text, file)?;
start_directive(&config, input, cli.no_stream, cli.code).await
}
true => start_interactive(&config).await,
} {
let highlight = stderr().is_terminal() && config.read().highlight;
render_error(err, highlight)
@ -232,19 +251,15 @@ fn aggregate_text(text: Option<String>) -> Result<Option<String>> {
Ok(text)
}
fn create_input(
config: &GlobalConfig,
text: Option<String>,
file: &[String],
) -> Result<Option<Input>> {
if text.is_none() && file.is_empty() {
return Ok(None);
}
fn create_input(config: &GlobalConfig, text: Option<String>, file: &[String]) -> Result<Input> {
let input_context = config.read().input_context();
let input = if file.is_empty() {
Input::from_str(&text.unwrap_or_default(), input_context)
} else {
Input::new(&text.unwrap_or_default(), file.to_vec(), input_context)?
};
Ok(Some(input))
if input.is_empty() {
bail!("No input");
}
Ok(input)
}

@ -77,8 +77,6 @@ pub struct Repl {
impl Repl {
pub fn init(config: &GlobalConfig) -> Result<Self> {
config.write().in_repl = true;
let editor = Self::create_editor(config)?;
let prompt = ReplPrompt::new(config);

@ -0,0 +1,364 @@
use crate::{
client::{init_client, ClientConfig, Message, Model, ReplyEvent, ReplyHandler, SendData},
config::{Config, GlobalConfig},
utils::create_abort_signal,
};
use anyhow::{anyhow, bail, Result};
use bytes::Bytes;
use chrono::{Timelike, Utc};
use futures_util::StreamExt;
use http::{Method, Response, StatusCode};
use http_body_util::{combinators::BoxBody, BodyExt, Full, StreamBody};
use hyper::{
body::{Frame, Incoming},
service::service_fn,
};
use hyper_util::rt::{TokioExecutor, TokioIo};
use parking_lot::RwLock;
use serde::Deserialize;
use serde_json::{json, Value};
use std::{convert::Infallible, sync::Arc};
use tokio::{
net::TcpListener,
sync::{
mpsc::{unbounded_channel, UnboundedReceiver, UnboundedSender},
oneshot,
},
};
use tokio_graceful::Shutdown;
use tokio_stream::wrappers::UnboundedReceiverStream;
/// Address the server listens on when `run` is given no explicit address.
// NOTE(review): `0.0.0.0` accepts connections on every network interface and the
// request handler in this file performs no authentication — confirm that exposing
// the API beyond localhost by default is intended.
const DEFAULT_ADDRESS: &str = "0.0.0.0:8080";
/// Response type produced by every handler: a boxed body of `Bytes` chunks that cannot fail.
type AppResponse = Response<BoxBody<Bytes, Infallible>>;
/// Starts the OpenAI-compatible HTTP server and keeps it alive until CTRL+C.
///
/// `addr` may be a full socket address or just a port number; a bare port is
/// bound on `0.0.0.0`. When `addr` is `None`, `DEFAULT_ADDRESS` is used.
///
/// # Errors
///
/// Fails if the listening socket cannot be bound or the server task cannot
/// be started.
pub async fn run(config: GlobalConfig, addr: Option<String>) -> Result<()> {
    let addr = addr.map_or_else(
        || DEFAULT_ADDRESS.to_string(),
        |value| match value.parse::<u16>() {
            Ok(port) => format!("0.0.0.0:{port}"),
            Err(_) => value,
        },
    );
    // Snapshot what the server needs; the guard is released before any await.
    let (clients, model) = {
        let guard = config.read();
        (guard.clients.clone(), guard.model.clone())
    };
    let listener = TcpListener::bind(&addr).await?;
    let stop_server = Arc::new(Server { clients, model }).run(listener).await?;
    println!("Access the chat completion API at: http://{addr}/v1/chat/completions");
    shutdown_signal().await;
    // The receiving side may already be gone; a failed send is not an error.
    let _ = stop_server.send(());
    Ok(())
}
/// State shared by all connections, captured once from the global config when the server starts.
struct Server {
// Client configurations copied into the per-request `Config`.
clients: Vec<ClientConfig>,
// Model used when a request asks for "default" or for this model's own id.
model: Model,
}
impl Server {
/// Spawns the accept loop on a background task and returns the sender that stops it.
///
/// Each accepted connection is served on its own task spawned through the
/// `Shutdown` handle. Sending on (or dropping) the returned sender ends the
/// accept loop.
// NOTE(review): after the loop breaks the task simply ends; nothing here awaits the
// in-flight connection tasks — confirm whether a graceful drain is expected.
async fn run(self: Arc<Self>, listener: TcpListener) -> Result<oneshot::Sender<()>> {
let (tx, rx) = oneshot::channel();
tokio::spawn(async move {
// Resolves when the stop sender fires or is dropped: the receive error is
// mapped to `()` by `unwrap_or_default`.
let shutdown = Shutdown::new(async { rx.await.unwrap_or_default() });
let guard = shutdown.guard_weak();
loop {
tokio::select! {
res = listener.accept() => {
// A failed accept is skipped silently and the loop tries again.
let Ok((cnx, _)) = res else {
continue;
};
let stream = TokioIo::new(cnx);
let server = self.clone();
shutdown.spawn_task(async move {
// Every request on this connection gets its own `Arc<Server>` clone.
let hyper_service = service_fn(move |request: hyper::Request<Incoming>| {
server.clone().handle(request)
});
// Connection-level errors are deliberately ignored.
let _ = hyper_util::server::conn::auto::Builder::new(TokioExecutor::new())
.serve_connection_with_upgrades(stream, hyper_service)
.await;
});
}
_ = guard.cancelled() => {
break;
}
}
}
});
Ok(tx)
}
/// Routes one HTTP request and builds its response.
///
/// * `POST /v1/chat/completions` is forwarded to `chat_completion`.
/// * `OPTIONS /v1/chat/completions` answers `204 No Content` (CORS preflight).
/// * Anything else yields `404 Not Found`.
///
/// Failures are rendered as a JSON error body via `ret_err`. A failure of the
/// chat completion itself is reported as `400 Bad Request`; previously such
/// errors were sent (and logged) with `200 OK`, which made them look like
/// successful completions to clients. CORS headers are added to every response.
///
/// This function itself never returns `Err`; the `hyper::Error` in the
/// signature only satisfies the service contract.
async fn handle(
    self: Arc<Self>,
    req: hyper::Request<Incoming>,
) -> std::result::Result<AppResponse, hyper::Error> {
    let method = req.method().clone();
    let uri = req.uri().clone();
    let mut status = StatusCode::OK;
    let res = if method == Method::POST && uri == "/v1/chat/completions" {
        self.chat_completion(req).await
    } else if method == Method::OPTIONS && uri == "/v1/chat/completions" {
        status = StatusCode::NO_CONTENT;
        Ok(Response::default())
    } else {
        status = StatusCode::NOT_FOUND;
        Err(anyhow!("The requested endpoint was not found."))
    };
    let mut res = match res {
        Ok(res) => {
            info!("{method} {uri} {}", status.as_u16());
            res
        }
        Err(err) => {
            // An error with no explicit status came from the chat completion
            // path; never report it with a success code.
            if status == StatusCode::OK {
                status = StatusCode::BAD_REQUEST;
            }
            error!("{method} {uri} {} {err}", status.as_u16());
            ret_err(err)
        }
    };
    // The status chosen here always overrides whatever the body builder set.
    *res.status_mut() = status;
    set_cors_header(&mut res);
    Ok(res)
}
/// Handles a chat completion request, either streamed (server-sent events) or as a single JSON body.
///
/// The request body is read in full and parsed as `ChatCompletionReqBody`. A
/// fresh `Config` is built for each request from the server's clients and
/// model, so per-request model changes do not touch shared state.
///
/// # Errors
///
/// Fails on an unreadable or invalid body, when the requested model cannot be
/// selected, when the client cannot be built, or when the upstream request
/// fails (for streaming: only if it fails before the first event is produced).
async fn chat_completion(&self, req: hyper::Request<Incoming>) -> Result<AppResponse> {
let req_body = req.collect().await?.to_bytes();
let req_body: ChatCompletionReqBody = serde_json::from_slice(&req_body)
.map_err(|err| anyhow!("Invalid request body, {err}"))?;
let ChatCompletionReqBody {
model,
messages,
temperature,
max_tokens,
stream,
} = req_body;
// Per-request config: server clients + server model, everything else default.
let config = Config {
clients: self.clients.to_vec(),
model: self.model.clone(),
..Default::default()
};
let config = Arc::new(RwLock::new(config));
// "default" and the server's own model id both keep the preconfigured model.
if model != "default" && model != self.model.id() {
config.write().set_model(&model)?;
}
let mut client = init_client(&config)?;
if max_tokens.is_some() {
client.set_model(client.model().clone().set_max_output_tokens(max_tokens));
}
let abort = create_abort_signal();
let http_client = client.build_client()?;
let completion_id = generate_completion_id();
let created = Utc::now().timestamp();
let send_data: SendData = SendData {
messages,
temperature,
stream,
};
if stream {
// `tx`/`rx` carry response events from the worker task to this handler.
let (tx, mut rx) = unbounded_channel();
tokio::spawn(async move {
let mut is_first = true;
// `tx2`/`rx2` carry raw reply events from the client to `map_event`.
let (tx2, rx2) = unbounded_channel();
let mut handler = ReplyHandler::new(tx2, abort);
// Forwards reply events as response events, emitting `First(None)` once
// before the first forwarded event.
async fn map_event(
mut rx: UnboundedReceiver<ReplyEvent>,
tx: &UnboundedSender<ResEvent>,
is_first: &mut bool,
) {
while let Some(reply_event) = rx.recv().await {
if *is_first {
let _ = tx.send(ResEvent::First(None));
*is_first = false;
}
match reply_event {
ReplyEvent::Text(text) => {
let _ = tx.send(ResEvent::Text(text));
}
ReplyEvent::Done => {
let _ = tx.send(ResEvent::Done);
}
}
}
}
// Runs the forwarder and the upstream request concurrently; whichever
// finishes first ends the task.
// NOTE(review): if the upstream call completes first, the forwarder is
// dropped and events still queued in `rx2` would not be forwarded —
// confirm the client always yields long enough for them to drain.
tokio::select! {
_ = map_event(rx2, &tx, &mut is_first) => {}
ret = client.send_message_streaming_inner(&http_client, &mut handler, send_data) => {
if let Err(err) = ret {
// Only reported when nothing has been forwarded yet.
send_first_event(&tx, Some(format!("{err:?}")), &mut is_first)
}
}
}
});
// Wait for the first event so an early upstream failure becomes an error
// response instead of an empty stream.
let first_event = rx.recv().await;
if let Some(ResEvent::First(Some(err))) = first_event {
bail!("{err}");
}
let shared: Arc<(String, i64)> = Arc::new((completion_id, created));
let stream = UnboundedReceiverStream::new(rx);
// Text and Done become SSE frames; any further `First` events are dropped.
let stream = stream.filter_map(move |res_event| {
let shared = shared.clone();
async move {
match res_event {
ResEvent::Text(text) => {
Some(Ok(create_frame(&shared.0, shared.1, &text, false)))
}
ResEvent::Done => Some(Ok(create_frame(&shared.0, shared.1, "", true))),
_ => None,
}
}
});
let res = Response::builder()
.status(StatusCode::OK)
.header("Content-Type", "text/event-stream")
.header("Cache-Control", "no-cache")
.header("Connection", "keep-alive")
.body(BodyExt::boxed(StreamBody::new(stream)))?;
Ok(res)
} else {
// Non-streaming: wait for the full reply and wrap it in one JSON document.
let content = client.send_message_inner(&http_client, send_data).await?;
let res = Response::builder()
.header("Content-Type", "application/json")
.body(Full::new(ret_non_stream(&completion_id, created, &content)).boxed())?;
Ok(res)
}
}
}
/// JSON body accepted by `POST /v1/chat/completions`.
#[derive(Debug, Deserialize)]
struct ChatCompletionReqBody {
// Model id; "default" selects the server's preconfigured model.
model: String,
// Conversation to send upstream.
messages: Vec<Message>,
// Optional sampling temperature, passed through unchanged.
temperature: Option<f64>,
// Optional cap applied as the model's maximum output tokens.
max_tokens: Option<isize>,
// Whether to answer with a server-sent event stream; absent means `false`.
#[serde(default)]
stream: bool,
}
/// Events passed from the streaming worker task to the response stream.
#[derive(Debug)]
enum ResEvent {
/// Sent once before any other event; `Some(message)` reports an upstream failure.
First(Option<String>),
/// A piece of reply text.
Text(String),
/// The reply is complete.
Done,
}
/// Emits the one-time `ResEvent::First` marker carrying `data`.
///
/// Does nothing when the marker has already been sent; otherwise sends it
/// (ignoring a closed channel) and clears `is_first`.
fn send_first_event(tx: &UnboundedSender<ResEvent>, data: Option<String>, is_first: &mut bool) {
    if !*is_first {
        return;
    }
    let _ = tx.send(ResEvent::First(data));
    *is_first = false;
}
/// Completes once the process receives CTRL+C.
///
/// # Panics
///
/// Panics if the signal handler cannot be installed.
async fn shutdown_signal() {
    let outcome = tokio::signal::ctrl_c().await;
    outcome.expect("Failed to install CTRL+C signal handler")
}
fn generate_completion_id() -> String {
let random_id = chrono::Utc::now().nanosecond();
format!("chatcmpl-{}", random_id)
}
/// Adds permissive CORS headers to `res`, replacing any existing values.
fn set_cors_header(res: &mut AppResponse) {
    use hyper::header::{
        HeaderValue, ACCESS_CONTROL_ALLOW_HEADERS, ACCESS_CONTROL_ALLOW_METHODS,
        ACCESS_CONTROL_ALLOW_ORIGIN,
    };

    let headers = res.headers_mut();
    for (name, value) in [
        (ACCESS_CONTROL_ALLOW_ORIGIN, "*"),
        (ACCESS_CONTROL_ALLOW_METHODS, "GET,POST,PUT,PATCH,DELETE"),
        (ACCESS_CONTROL_ALLOW_HEADERS, "Content-Type,Authorization"),
    ] {
        headers.insert(name, HeaderValue::from_static(value));
    }
}
/// Encodes one server-sent-event frame of a streamed chat completion.
///
/// A regular frame carries `content` as the delta (an empty `content` also
/// includes the assistant role). The final frame (`done == true`) has an empty
/// delta, a `stop` finish reason, zeroed usage figures, and is followed by the
/// `[DONE]` sentinel.
fn create_frame(id: &str, created: i64, content: &str, done: bool) -> Frame<Bytes> {
    let delta = match (done, content.is_empty()) {
        (true, _) => json!({}),
        (false, true) => json!({ "role": "assistant", "content": content }),
        (false, false) => json!({ "content": content }),
    };
    let finish_reason = if done {
        Value::from("stop")
    } else {
        Value::Null
    };
    let mut chunk = json!({
        "id": id,
        "object": "chat.completion.chunk",
        "created": created,
        "model": "gpt-3.5-turbo",
        "choices": [
            {
                "index": 0,
                "delta": delta,
                "finish_reason": finish_reason,
            },
        ],
    });
    if done {
        chunk["usage"] = json!({
            "prompt_tokens": 0,
            "completion_tokens": 0,
            "total_tokens": 0,
        });
    }
    let mut payload = format!("data: {chunk}\n\n");
    if done {
        payload.push_str("data: [DONE]\n\n");
    }
    Frame::data(Bytes::from(payload))
}
/// Serializes a complete (non-streamed) chat completion response body.
///
/// The body contains a single assistant message holding `content`, a `stop`
/// finish reason, and zeroed usage figures.
fn ret_non_stream(id: &str, created: i64, content: &str) -> Bytes {
    let message = json!({
        "role": "assistant",
        "content": content,
    });
    let choice = json!({
        "index": 0,
        "message": message,
        "finish_reason": "stop",
    });
    let usage = json!({
        "prompt_tokens": 0,
        "completion_tokens": 0,
        "total_tokens": 0,
    });
    let body = json!({
        "id": id,
        "object": "chat.completion",
        "created": created,
        "model": "gpt-3.5-turbo",
        "choices": [choice],
        "usage": usage,
    });
    Bytes::from(body.to_string())
}
/// Wraps an error in an OpenAI-style JSON error body.
///
/// The response is built with status `200 OK`; callers are expected to set
/// the final status themselves.
fn ret_err<T: std::fmt::Display>(err: T) -> AppResponse {
    let message = err.to_string();
    let payload = json!({
        "error": {
            "message": message,
            "type": "invalid_request_error",
        },
    })
    .to_string();
    let body = Full::new(Bytes::from(payload)).boxed();
    Response::builder()
        .status(StatusCode::OK)
        .header("Content-Type", "application/json")
        .body(body)
        .unwrap()
}
Loading…
Cancel
Save