diff --git a/Cargo.lock b/Cargo.lock index f0448ae..c7c634d 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -46,12 +46,16 @@ dependencies = [ "dirs", "fancy-regex", "futures-util", + "http", + "http-body-util", + "hyper", + "hyper-util", "inquire", "is-terminal", "lazy_static", "log", "mime_guess", - "nu-ansi-term", + "nu-ansi-term 0.50.0", "parking_lot", "reedline", "reqwest", @@ -65,7 +69,10 @@ dependencies = [ "simplelog", "syntect", "textwrap", + "time", "tokio", + "tokio-graceful", + "tokio-stream", "unicode-width", ] @@ -280,7 +287,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "05efc5cfd9110c8416e471df0e96702d58690178e206e61b7173706673c93706" dependencies = [ "memchr", - "regex-automata", + "regex-automata 0.4.6", "serde", ] @@ -338,7 +345,7 @@ dependencies = [ "num-traits", "serde", "wasm-bindgen", - "windows-targets 0.52.4", + "windows-targets 0.52.5", ] [[package]] @@ -601,8 +608,8 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "531e46835a22af56d1e3b66f04844bed63158bc094a628bec1d321d9b4c44bf2" dependencies = [ "bit-set", - "regex-automata", - "regex-syntax", + "regex-automata 0.4.6", + "regex-syntax 0.8.3", ] [[package]] @@ -749,6 +756,20 @@ dependencies = [ "byteorder", ] +[[package]] +name = "generator" +version = "0.8.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "186014d53bc231d0090ef8d6f03e0920c54d85a5ed22f4f2f74315ec56cf83fb" +dependencies = [ + "cc", + "cfg-if", + "libc", + "log", + "rustversion", + "windows", +] + [[package]] name = "generic-array" version = "0.14.7" @@ -786,6 +807,25 @@ version = "0.28.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "4271d37baee1b8c7e4b708028c57d816cf9d2434acb33a549475f78c181f6253" +[[package]] +name = "h2" +version = "0.4.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "816ec7294445779408f36fe57bc5b7fc1cf59664059096c65f905c1c61f58069" +dependencies = [ + 
"bytes", + "fnv", + "futures-core", + "futures-sink", + "futures-util", + "http", + "indexmap", + "slab", + "tokio", + "tokio-util", + "tracing", +] + [[package]] name = "hashbrown" version = "0.14.3" @@ -859,6 +899,12 @@ version = "1.8.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d897f394bad6a705d5f4104762e116a75639e470d80901eed05a860a95cb1904" +[[package]] +name = "httpdate" +version = "1.0.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "df3b46402a9d5adb4c86a0cf463f42e19994e3ee891101b1841f30a545cb49a9" + [[package]] name = "hyper" version = "1.2.0" @@ -868,9 +914,11 @@ dependencies = [ "bytes", "futures-channel", "futures-util", + "h2", "http", "http-body", "httparse", + "httpdate", "itoa", "pin-project-lite", "smallvec", @@ -926,7 +974,7 @@ dependencies = [ "iana-time-zone-haiku", "js-sys", "wasm-bindgen", - "windows-core", + "windows-core 0.52.0", ] [[package]] @@ -1035,7 +1083,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0c2a198fb6b0eada2a8df47933734e6d35d350665a33a3593d7164fa52c75c19" dependencies = [ "cfg-if", - "windows-targets 0.52.4", + "windows-targets 0.52.5", ] [[package]] @@ -1076,6 +1124,22 @@ version = "0.4.21" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "90ed8c1e510134f979dbc4f070f87d4313098b704861a105fe34231c70a3901c" +[[package]] +name = "loom" +version = "0.7.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "419e0dc8046cb947daa77eb95ae174acfbddb7673b4151f56d1eed8e93fbfaca" +dependencies = [ + "cfg-if", + "generator", + "pin-utils", + "scoped-tls", + "serde", + "serde_json", + "tracing", + "tracing-subscriber", +] + [[package]] name = "malloc_buf" version = "0.0.6" @@ -1085,6 +1149,15 @@ dependencies = [ "libc", ] +[[package]] +name = "matchers" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"8263075bb86c5a1b1427b5ae862e8889656f126e9f77c484496e8b47cf5c5558" +dependencies = [ + "regex-automata 0.1.10", +] + [[package]] name = "memchr" version = "2.7.2" @@ -1165,6 +1238,16 @@ dependencies = [ "minimal-lexical", ] +[[package]] +name = "nu-ansi-term" +version = "0.46.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "77a8165726e8236064dbb45459242600304b42a5ea24ee2948e18e023bf7ba84" +dependencies = [ + "overload", + "winapi", +] + [[package]] name = "nu-ansi-term" version = "0.50.0" @@ -1296,6 +1379,12 @@ dependencies = [ "windows-sys 0.52.0", ] +[[package]] +name = "overload" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b15813163c1d831bf4a13c3610c05c0d03b39feb07f7e09fa234dac9b15aaf39" + [[package]] name = "parking_lot" version = "0.12.1" @@ -1450,7 +1539,7 @@ dependencies = [ "crossterm 0.27.0", "fd-lock", "itertools", - "nu-ansi-term", + "nu-ansi-term 0.50.0", "serde", "strip-ansi-escapes", "strum", @@ -1460,6 +1549,27 @@ dependencies = [ "unicode-width", ] +[[package]] +name = "regex" +version = "1.10.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c117dbdfde9c8308975b6a18d71f3f385c89461f7b3fb054288ecf2a2058ba4c" +dependencies = [ + "aho-corasick", + "memchr", + "regex-automata 0.4.6", + "regex-syntax 0.8.3", +] + +[[package]] +name = "regex-automata" +version = "0.1.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6c230d73fb8d8c1b9c0b3135c5142a8acee3a0558fb8db5cf1cb65f8d7862132" +dependencies = [ + "regex-syntax 0.6.29", +] + [[package]] name = "regex-automata" version = "0.4.6" @@ -1468,9 +1578,15 @@ checksum = "86b83b8b9847f9bf95ef68afb0b8e6cdb80f498442f5179a29fad448fcc1eaea" dependencies = [ "aho-corasick", "memchr", - "regex-syntax", + "regex-syntax 0.8.3", ] +[[package]] +name = "regex-syntax" +version = "0.6.29" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"f162c6dd7b008981e4d40210aca20b4bd0f9b60ca9271061b07f78537722f2e1" + [[package]] name = "regex-syntax" version = "0.8.3" @@ -1775,6 +1891,15 @@ dependencies = [ "digest", ] +[[package]] +name = "sharded-slab" +version = "0.1.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f40ca3c46823713e0d4209592e8d6e826aa57e928f09752619fc696c499637f6" +dependencies = [ + "lazy_static", +] + [[package]] name = "shell-words" version = "1.1.0" @@ -1929,7 +2054,7 @@ dependencies = [ "once_cell", "onig", "plist", - "regex-syntax", + "regex-syntax 0.8.3", "serde", "serde_derive", "serde_json", @@ -2001,9 +2126,9 @@ dependencies = [ [[package]] name = "time" -version = "0.3.35" +version = "0.3.36" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ef89ece63debf11bc32d1ed8d078ac870cbeb44da02afb02a9ff135ae7ca0582" +checksum = "5dfd88e563464686c916c7e46e623e520ddc6d79fa6641390f2e3fa86e83e885" dependencies = [ "deranged", "itoa", @@ -2065,6 +2190,19 @@ dependencies = [ "windows-sys 0.48.0", ] +[[package]] +name = "tokio-graceful" +version = "0.1.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "627ba4daa4cbce14740603401c895e72d47ecd86690a18e3f0841266e9340de7" +dependencies = [ + "loom", + "pin-project-lite", + "slab", + "tokio", + "tracing", +] + [[package]] name = "tokio-macros" version = "2.2.0" @@ -2099,6 +2237,18 @@ dependencies = [ "tokio", ] +[[package]] +name = "tokio-stream" +version = "0.1.15" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "267ac89e0bec6e691e5813911606935d77c476ff49024f98abcea3e7b15e37af" +dependencies = [ + "futures-core", + "pin-project-lite", + "tokio", + "tokio-util", +] + [[package]] name = "tokio-util" version = "0.7.10" @@ -2149,9 +2299,21 @@ checksum = "c3523ab5a71916ccf420eebdf5521fcef02141234bbc0b8a49f2fdc4544364ef" dependencies = [ "log", "pin-project-lite", + "tracing-attributes", "tracing-core", ] +[[package]] +name = 
"tracing-attributes" +version = "0.1.27" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "34704c8d6ebcbc939824180af020566b01a7c01f80641264eba0999f6c2b6be7" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + [[package]] name = "tracing-core" version = "0.1.32" @@ -2159,6 +2321,36 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c06d3da6113f116aaee68e4d601191614c9053067f9ab7f6edbcb161237daa54" dependencies = [ "once_cell", + "valuable", +] + +[[package]] +name = "tracing-log" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ee855f1f400bd0e5c02d150ae5de3840039a3f54b025156404e34c23c03f47c3" +dependencies = [ + "log", + "once_cell", + "tracing-core", +] + +[[package]] +name = "tracing-subscriber" +version = "0.3.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ad0f048c97dbd9faa9b7df56362b8ebcaa52adb06b498c050d2f4e32f90a7a8b" +dependencies = [ + "matchers", + "nu-ansi-term 0.46.0", + "once_cell", + "regex", + "sharded-slab", + "smallvec", + "thread_local", + "tracing", + "tracing-core", + "tracing-log", ] [[package]] @@ -2264,6 +2456,12 @@ version = "0.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "711b9620af191e0cdc7468a8d14e709c3dcdb115b36f838e601583af800a370a" +[[package]] +name = "valuable" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "830b7e5d4d90034032940e4ace0d9a9a057e7a45cd94e6c007832e39edb82f6d" + [[package]] name = "version_check" version = "0.9.4" @@ -2517,13 +2715,42 @@ version = "0.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "712e227841d057c1ee1cd2fb22fa7e5a5461ae8e48fa2ca79ec42cfc1931183f" +[[package]] +name = "windows" +version = "0.54.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9252e5725dbed82865af151df558e754e4a3c2c30818359eb17465f1346a1b49" 
+dependencies = [ + "windows-core 0.54.0", + "windows-targets 0.52.5", +] + [[package]] name = "windows-core" version = "0.52.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "33ab640c8d7e35bf8ba19b884ba838ceb4fba93a4e8c65a9059d08afcfc683d9" dependencies = [ - "windows-targets 0.52.4", + "windows-targets 0.52.5", +] + +[[package]] +name = "windows-core" +version = "0.54.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "12661b9c89351d684a50a8a643ce5f608e20243b9fb84687800163429f161d65" +dependencies = [ + "windows-result", + "windows-targets 0.52.5", +] + +[[package]] +name = "windows-result" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "749f0da9cc72d82e600d8d2e44cadd0b9eedb9038f71a1c58556ac1c5791813b" +dependencies = [ + "windows-targets 0.52.5", ] [[package]] @@ -2541,7 +2768,7 @@ version = "0.52.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "282be5f36a8ce781fad8c8ae18fa3f9beff57ec1b52cb3de0789201425d9a33d" dependencies = [ - "windows-targets 0.52.4", + "windows-targets 0.52.5", ] [[package]] @@ -2561,17 +2788,18 @@ dependencies = [ [[package]] name = "windows-targets" -version = "0.52.4" +version = "0.52.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7dd37b7e5ab9018759f893a1952c9420d060016fc19a472b4bb20d1bdd694d1b" +checksum = "6f0713a46559409d202e70e28227288446bf7841d3211583a4b53e3f6d96e7eb" dependencies = [ - "windows_aarch64_gnullvm 0.52.4", - "windows_aarch64_msvc 0.52.4", - "windows_i686_gnu 0.52.4", - "windows_i686_msvc 0.52.4", - "windows_x86_64_gnu 0.52.4", - "windows_x86_64_gnullvm 0.52.4", - "windows_x86_64_msvc 0.52.4", + "windows_aarch64_gnullvm 0.52.5", + "windows_aarch64_msvc 0.52.5", + "windows_i686_gnu 0.52.5", + "windows_i686_gnullvm", + "windows_i686_msvc 0.52.5", + "windows_x86_64_gnu 0.52.5", + "windows_x86_64_gnullvm 0.52.5", + "windows_x86_64_msvc 0.52.5", ] 
[[package]] @@ -2582,9 +2810,9 @@ checksum = "2b38e32f0abccf9987a4e3079dfb67dcd799fb61361e53e2882c3cbaf0d905d8" [[package]] name = "windows_aarch64_gnullvm" -version = "0.52.4" +version = "0.52.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bcf46cf4c365c6f2d1cc93ce535f2c8b244591df96ceee75d8e83deb70a9cac9" +checksum = "7088eed71e8b8dda258ecc8bac5fb1153c5cffaf2578fc8ff5d61e23578d3263" [[package]] name = "windows_aarch64_msvc" @@ -2594,9 +2822,9 @@ checksum = "dc35310971f3b2dbbf3f0690a219f40e2d9afcf64f9ab7cc1be722937c26b4bc" [[package]] name = "windows_aarch64_msvc" -version = "0.52.4" +version = "0.52.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "da9f259dd3bcf6990b55bffd094c4f7235817ba4ceebde8e6d11cd0c5633b675" +checksum = "9985fd1504e250c615ca5f281c3f7a6da76213ebd5ccc9561496568a2752afb6" [[package]] name = "windows_i686_gnu" @@ -2606,9 +2834,15 @@ checksum = "a75915e7def60c94dcef72200b9a8e58e5091744960da64ec734a6c6e9b3743e" [[package]] name = "windows_i686_gnu" -version = "0.52.4" +version = "0.52.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "88ba073cf16d5372720ec942a8ccbf61626074c6d4dd2e745299726ce8b89670" + +[[package]] +name = "windows_i686_gnullvm" +version = "0.52.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b474d8268f99e0995f25b9f095bc7434632601028cf86590aea5c8a5cb7801d3" +checksum = "87f4261229030a858f36b459e748ae97545d6f1ec60e5e0d6a3d32e0dc232ee9" [[package]] name = "windows_i686_msvc" @@ -2618,9 +2852,9 @@ checksum = "8f55c233f70c4b27f66c523580f78f1004e8b5a8b659e05a4eb49d4166cca406" [[package]] name = "windows_i686_msvc" -version = "0.52.4" +version = "0.52.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1515e9a29e5bed743cb4415a9ecf5dfca648ce85ee42e15873c3cd8610ff8e02" +checksum = "db3c2bf3d13d5b658be73463284eaf12830ac9a26a90c717b7f771dfe97487bf" [[package]] name = "windows_x86_64_gnu" 
@@ -2630,9 +2864,9 @@ checksum = "53d40abd2583d23e4718fddf1ebec84dbff8381c07cae67ff7768bbf19c6718e" [[package]] name = "windows_x86_64_gnu" -version = "0.52.4" +version = "0.52.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5eee091590e89cc02ad514ffe3ead9eb6b660aedca2183455434b93546371a03" +checksum = "4e4246f76bdeff09eb48875a0fd3e2af6aada79d409d33011886d3e1581517d9" [[package]] name = "windows_x86_64_gnullvm" @@ -2642,9 +2876,9 @@ checksum = "0b7b52767868a23d5bab768e390dc5f5c55825b6d30b86c844ff2dc7414044cc" [[package]] name = "windows_x86_64_gnullvm" -version = "0.52.4" +version = "0.52.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "77ca79f2451b49fa9e2af39f0747fe999fcda4f5e241b2898624dca97a1f2177" +checksum = "852298e482cd67c356ddd9570386e2862b5673c85bd5f88df9ab6802b334c596" [[package]] name = "windows_x86_64_msvc" @@ -2654,9 +2888,9 @@ checksum = "ed94fce61571a4006852b7389a063ab983c02eb1bb37b47f8272ce92d06d9538" [[package]] name = "windows_x86_64_msvc" -version = "0.52.4" +version = "0.52.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "32b752e52a2da0ddfbdbcc6fceadfeede4c939ed16d13e648833a61dfb611ed8" +checksum = "bec47e5bfd1bff0eeaf6d8b485cc1074891a197ab4225d504cb7a1ab88b02bf0" [[package]] name = "winreg" diff --git a/Cargo.toml b/Cargo.toml index 5e792e8..92bdd59 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -23,6 +23,8 @@ serde = { version = "1.0.152", features = ["derive"] } serde_json = { version = "1.0.93", features = ["preserve_order"] } serde_yaml = "0.9.17" tokio = { version = "1.34.0", features = ["rt", "time", "macros", "signal", "rt-multi-thread"] } +tokio-graceful = "0.1.6" +tokio-stream = { version = "0.1.15", default-features = false, features = ["sync"] } crossterm = "0.27.0" chrono = "0.4.23" bincode = "1.3.3" @@ -45,6 +47,11 @@ sha2 = "0.10.8" bitflags = "2.4.1" unicode-width = "0.1.11" async-recursion = "1.1.0" +http = "1.1.0" +http-body-util = 
"0.1" +hyper = { version = "1.0", features = ["full"] } +hyper-util = { version = "0.1", features = ["server-auto", "client-legacy"] } +time = { version = "0.3.36", features = ["macros"] } [dependencies.reqwest] version = "0.12.0" diff --git a/scripts/completions/aichat.bash b/scripts/completions/aichat.bash index 0d72747..f0b9bc3 100644 --- a/scripts/completions/aichat.bash +++ b/scripts/completions/aichat.bash @@ -19,7 +19,7 @@ _aichat() { case "${cmd}" in aichat) - opts="-m -r -s -e -c -f -H -S -w -h -V --model --role --session --save-session --execute --code --file --no-highlight --no-stream --wrap --light-theme --dry-run --info --list-models --list-roles --list-sessions --help --version" + opts="-m -r -s -e -c -f -H -S -w -h -V --model --role --session --save-session --serve --execute --code --file --no-highlight --no-stream --wrap --light-theme --dry-run --info --list-models --list-roles --list-sessions --help --version" if [[ ${cur} == -* || ${COMP_CWORD} -eq 1 ]] ; then COMPREPLY=( $(compgen -W "${opts}" -- "${cur}") ) return 0 diff --git a/scripts/completions/aichat.fish b/scripts/completions/aichat.fish index 736ca36..586b745 100644 --- a/scripts/completions/aichat.fish +++ b/scripts/completions/aichat.fish @@ -4,6 +4,7 @@ complete -c aichat -s s -l session -x -a"(aichat --list-sessions)" -d 'Start or complete -c aichat -s f -l file -d 'Include files with the message' -r -F complete -c aichat -s w -l wrap -d 'Control text wrapping (no, auto, )' complete -c aichat -l save-session -d 'Forces the session to be saved' +complete -c aichat -l serve -d 'Serve all LLMs as OpenAI-compatible API' complete -c aichat -s e -l execute -d 'Execute commands in natural language' complete -c aichat -s c -l code -d 'Output code only' complete -c aichat -s H -l no-highlight -d 'Turn off syntax highlighting' diff --git a/scripts/completions/aichat.nu b/scripts/completions/aichat.nu index 03917a6..95cda81 100644 --- a/scripts/completions/aichat.nu +++ 
b/scripts/completions/aichat.nu @@ -28,6 +28,7 @@ module completions { --role(-r): string@"nu-complete aichat role" # Select a role --session(-s): string@"nu-complete aichat role" # Start or join a session --save-session # Forces the session to be saved + --serve # Serve all LLMs as OpenAI-compatible API --execute(-e) # Execute commands in natural language --code(-c) # Output code only --file(-f): string # Include files with the message diff --git a/scripts/completions/aichat.ps1 b/scripts/completions/aichat.ps1 index afd9c07..7be4bcf 100644 --- a/scripts/completions/aichat.ps1 +++ b/scripts/completions/aichat.ps1 @@ -31,6 +31,7 @@ Register-ArgumentCompleter -Native -CommandName 'aichat' -ScriptBlock { [CompletionResult]::new('-w', '-w', [CompletionResultType]::ParameterName, 'Control text wrapping (no, auto, )') [CompletionResult]::new('--wrap', '--wrap', [CompletionResultType]::ParameterName, 'Control text wrapping (no, auto, )') [CompletionResult]::new('--save-session', '--save-session', [CompletionResultType]::ParameterName, 'Forces the session to be saved') + [CompletionResult]::new('--serve', '--serve', [CompletionResultType]::ParameterName, 'Serve all LLMs as OpenAI-compatible API') [CompletionResult]::new('-e', '-e', [CompletionResultType]::ParameterName, 'Execute commands in natural language') [CompletionResult]::new('--execute', '--execute', [CompletionResultType]::ParameterName, 'Execute commands in natural language') [CompletionResult]::new('-c', '-c', [CompletionResultType]::ParameterName, 'Output code only') diff --git a/scripts/completions/aichat.zsh b/scripts/completions/aichat.zsh index 66f9bbb..5f7a74b 100644 --- a/scripts/completions/aichat.zsh +++ b/scripts/completions/aichat.zsh @@ -26,6 +26,7 @@ _aichat() { '-w+[Control text wrapping (no, auto, )]:WRAP: ' \ '--wrap=[Control text wrapping (no, auto, )]:WRAP: ' \ '--save-session[Forces the session to be saved]' \ +'--serve[Serve all LLMs as OpenAI-compatible API]' \ '-e[Execute commands in 
natural language]' \ '--execute[Execute commands in natural language]' \ '-c[Output code only]' \ diff --git a/src/cli.rs b/src/cli.rs index 7cfd09a..aff608a 100644 --- a/src/cli.rs +++ b/src/cli.rs @@ -15,6 +15,9 @@ pub struct Cli { /// Forces the session to be saved #[clap(long)] pub save_session: bool, + /// Serve all LLMs as OpenAI-compatible API + #[clap(long, value_name = "ADDRESS")] + pub serve: Option>, /// Execute commands in natural language #[clap(short = 'e', long)] pub execute: bool, diff --git a/src/client/reply_handler.rs b/src/client/reply_handler.rs index e11ea1d..e024685 100644 --- a/src/client/reply_handler.rs +++ b/src/client/reply_handler.rs @@ -19,7 +19,7 @@ impl ReplyHandler { } pub fn text(&mut self, text: &str) -> Result<()> { - debug!("ReplyText: {}", text); + // debug!("ReplyText: {}", text); if text.is_empty() { return Ok(()); } @@ -33,7 +33,7 @@ impl ReplyHandler { } pub fn done(&mut self) -> Result<()> { - debug!("ReplyDone"); + // debug!("ReplyDone"); let ret = self .sender .send(ReplyEvent::Done) diff --git a/src/config/mod.rs b/src/config/mod.rs index f86cda1..e1cec4d 100644 --- a/src/config/mod.rs +++ b/src/config/mod.rs @@ -79,9 +79,9 @@ pub struct Config { #[serde(skip)] pub model: Model, #[serde(skip)] - pub last_message: Option<(Input, String)>, + pub working_mode: WorkingMode, #[serde(skip)] - pub in_repl: bool, + pub last_message: Option<(Input, String)>, } impl Default for Config { @@ -110,8 +110,8 @@ impl Default for Config { role: None, session: None, model: Default::default(), + working_mode: WorkingMode::Command, last_message: None, - in_repl: false, } } } @@ -119,13 +119,13 @@ impl Default for Config { pub type GlobalConfig = Arc>; impl Config { - pub fn init(is_interactive: bool) -> Result { + pub fn init(working_mode: WorkingMode) -> Result { let config_path = Self::config_file()?; let api_key = env::var("OPENAI_API_KEY").ok(); let exist_config_path = config_path.exists(); - if is_interactive && api_key.is_none() && 
!exist_config_path { + if working_mode != WorkingMode::Command && api_key.is_none() && !exist_config_path { create_config_file(&config_path)?; } let mut config = if api_key.is_some() && !exist_config_path { @@ -143,14 +143,13 @@ impl Config { config.set_wrap(&wrap)?; } + config.working_mode = working_mode; config.load_roles()?; config.setup_model()?; config.setup_highlight(); config.setup_light_theme()?; - setup_logger()?; - Ok(config) } @@ -611,7 +610,7 @@ impl Config { let save_session = session.save_session(); if session.dirty && save_session != Some(false) { if save_session.is_none() || session.is_temp() { - if !self.in_repl { + if self.working_mode != WorkingMode::Repl { return Ok(()); } let ans = Confirm::new("Save session?").with_default(false).prompt()?; @@ -998,6 +997,13 @@ impl Keybindings { } } +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +pub enum WorkingMode { + Command, + Repl, + Serve, +} + #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] pub enum State { Normal, @@ -1145,25 +1151,3 @@ fn complete_option_bool(value: Option) -> Vec { None => vec!["true".to_string(), "false".to_string()], } } - -#[cfg(debug_assertions)] -fn setup_logger() -> Result<()> { - use simplelog::{LevelFilter, WriteLogger}; - let file = std::fs::File::create(Config::local_path("debug.log")?)?; - let log_filter = match std::env::var("AICHAT_LOG_FILTER") { - Ok(v) => v, - Err(_) => "aichat".into(), - }; - let config = simplelog::ConfigBuilder::new() - .add_filter_allow(log_filter) - .set_thread_level(LevelFilter::Off) - .set_time_level(LevelFilter::Off) - .build(); - WriteLogger::init(log::LevelFilter::Debug, config, file)?; - Ok(()) -} - -#[cfg(not(debug_assertions))] -fn setup_logger() -> Result<()> { - Ok(()) -} diff --git a/src/logger.rs b/src/logger.rs new file mode 100644 index 0000000..f7ef2f5 --- /dev/null +++ b/src/logger.rs @@ -0,0 +1,40 @@ +use crate::config::WorkingMode; + +use anyhow::Result; +use log::LevelFilter; +use simplelog::{format_description, Config 
as LogConfig, ConfigBuilder}; + +#[cfg(debug_assertions)] +pub fn setup_logger(working_mode: WorkingMode) -> Result<()> { + let config = build_config(); + if working_mode == WorkingMode::Serve { + simplelog::SimpleLogger::init(LevelFilter::Debug, config)?; + } else { + let file = std::fs::File::create(crate::config::Config::local_path("debug.log")?)?; + simplelog::WriteLogger::init(LevelFilter::Debug, config, file)?; + } + Ok(()) +} + +#[cfg(not(debug_assertions))] +pub fn setup_logger(working_mode: WorkingMode) -> Result<()> { + let config = build_config(); + if working_mode == WorkingMode::Serve { + simplelog::SimpleLogger::init(log::LevelFilter::Info, config)?; + } + Ok(()) +} + +fn build_config() -> LogConfig { + let log_filter = match std::env::var("AICHAT_LOG_FILTER") { + Ok(v) => v, + Err(_) => "aichat".into(), + }; + ConfigBuilder::new() + .add_filter_allow(log_filter) + .set_time_format_custom(format_description!( + "[year]-[month]-[day]T[hour]:[minute]:[second].[subsecond digits:3]Z" + )) + .set_thread_level(LevelFilter::Off) + .build() +} diff --git a/src/main.rs b/src/main.rs index bc6b8f1..b396917 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,17 +1,21 @@ mod cli; mod client; mod config; +mod logger; mod render; mod repl; +mod serve; +#[macro_use] +mod utils; #[macro_use] extern crate log; -#[macro_use] -mod utils; use crate::cli::Cli; use crate::client::{ensure_model_capabilities, init_client, list_models, send_stream}; -use crate::config::{Config, GlobalConfig, Input, CODE_ROLE, EXPLAIN_ROLE, SHELL_ROLE}; +use crate::config::{ + Config, GlobalConfig, Input, WorkingMode, CODE_ROLE, EXPLAIN_ROLE, SHELL_ROLE, +}; use crate::render::{render_error, MarkdownRender}; use crate::repl::Repl; use crate::utils::{ @@ -33,7 +37,21 @@ use tokio::sync::oneshot; async fn main() -> Result<()> { let cli = Cli::parse(); let text = cli.text(); - let config = Arc::new(RwLock::new(Config::init(text.is_none())?)); + let file = &cli.file; + let no_input = text.is_none() 
&& file.is_empty(); + let working_mode = if cli.serve.is_some() { + WorkingMode::Serve + } else if no_input { + WorkingMode::Repl + } else { + WorkingMode::Command + }; + crate::logger::setup_logger(working_mode)?; + let config = Arc::new(RwLock::new(Config::init(working_mode)?)); + + if let Some(addr) = cli.serve { + return serve::run(config, addr).await; + } if cli.list_roles { config .read() @@ -89,20 +107,21 @@ async fn main() -> Result<()> { return Ok(()); } let text = aggregate_text(text)?; - let input = create_input(&config, text, &cli.file)?; if cli.execute { - match input { - Some(input) => { - execute(&config, input).await?; - return Ok(()); - } - None => bail!("No input text"), + if no_input { + bail!("No input"); } + let input = create_input(&config, text, file)?; + execute(&config, input).await?; + return Ok(()); } config.write().apply_prelude()?; - if let Err(err) = match input { - Some(input) => start_directive(&config, input, cli.no_stream, cli.code).await, - None => start_interactive(&config).await, + if let Err(err) = match no_input { + false => { + let input = create_input(&config, text, file)?; + start_directive(&config, input, cli.no_stream, cli.code).await + } + true => start_interactive(&config).await, } { let highlight = stderr().is_terminal() && config.read().highlight; render_error(err, highlight) @@ -232,19 +251,15 @@ fn aggregate_text(text: Option) -> Result> { Ok(text) } -fn create_input( - config: &GlobalConfig, - text: Option, - file: &[String], -) -> Result> { - if text.is_none() && file.is_empty() { - return Ok(None); - } +fn create_input(config: &GlobalConfig, text: Option, file: &[String]) -> Result { let input_context = config.read().input_context(); let input = if file.is_empty() { Input::from_str(&text.unwrap_or_default(), input_context) } else { Input::new(&text.unwrap_or_default(), file.to_vec(), input_context)? 
}; - Ok(Some(input)) + if input.is_empty() { + bail!("No input"); + } + Ok(input) } diff --git a/src/repl/mod.rs b/src/repl/mod.rs index 3c0b98b..dd1c735 100644 --- a/src/repl/mod.rs +++ b/src/repl/mod.rs @@ -77,8 +77,6 @@ pub struct Repl { impl Repl { pub fn init(config: &GlobalConfig) -> Result { - config.write().in_repl = true; - let editor = Self::create_editor(config)?; let prompt = ReplPrompt::new(config); diff --git a/src/serve.rs b/src/serve.rs new file mode 100644 index 0000000..2286b61 --- /dev/null +++ b/src/serve.rs @@ -0,0 +1,364 @@ +use crate::{ + client::{init_client, ClientConfig, Message, Model, ReplyEvent, ReplyHandler, SendData}, + config::{Config, GlobalConfig}, + utils::create_abort_signal, +}; + +use anyhow::{anyhow, bail, Result}; +use bytes::Bytes; +use chrono::{Timelike, Utc}; +use futures_util::StreamExt; +use http::{Method, Response, StatusCode}; +use http_body_util::{combinators::BoxBody, BodyExt, Full, StreamBody}; +use hyper::{ + body::{Frame, Incoming}, + service::service_fn, +}; +use hyper_util::rt::{TokioExecutor, TokioIo}; +use parking_lot::RwLock; +use serde::Deserialize; +use serde_json::{json, Value}; +use std::{convert::Infallible, sync::Arc}; +use tokio::{ + net::TcpListener, + sync::{ + mpsc::{unbounded_channel, UnboundedReceiver, UnboundedSender}, + oneshot, + }, +}; +use tokio_graceful::Shutdown; +use tokio_stream::wrappers::UnboundedReceiverStream; + +const DEFAULT_ADDRESS: &str = "0.0.0.0:8080"; + +type AppResponse = Response>; + +pub async fn run(config: GlobalConfig, addr: Option) -> Result<()> { + let addr = match addr { + Some(addr) => { + if let Ok(port) = addr.parse::() { + format!("0.0.0.0:{port}") + } else { + addr + } + } + None => DEFAULT_ADDRESS.to_string(), + }; + let clients = config.read().clients.clone(); + let model = config.read().model.clone(); + let listener = TcpListener::bind(&addr).await?; + let server = Arc::new(Server { clients, model }); + let stop_server = server.run(listener).await?; + 
println!("Access the chat completion API at: http://{addr}/v1/chat/completions"); + shutdown_signal().await; + let _ = stop_server.send(()); + Ok(()) +} + +/* Server state handed to every connection task: the client configurations plus the model used when a request does not select another one. */struct Server { + clients: Vec, + model: Model, +} + +impl Server { + /* Spawns the accept loop on a background task and returns the oneshot sender created here. The shutdown future completes when that sender fires or is dropped (`unwrap_or_default`), and the loop exits once `guard.cancelled()` resolves. */async fn run(self: Arc, listener: TcpListener) -> Result> { + let (tx, rx) = oneshot::channel(); + tokio::spawn(async move { + let shutdown = Shutdown::new(async { rx.await.unwrap_or_default() }); + let guard = shutdown.guard_weak(); + + loop { + tokio::select! { + res = listener.accept() => { + /* A failed accept is skipped without logging; the loop keeps listening. */let Ok((cnx, _)) = res else { + continue; + }; + + let stream = TokioIo::new(cnx); + let server = self.clone(); + shutdown.spawn_task(async move { + let hyper_service = service_fn(move |request: hyper::Request| { + server.clone().handle(request) + }); + let _ = hyper_util::server::conn::auto::Builder::new(TokioExecutor::new()) + .serve_connection_with_upgrades(stream, hyper_service) + .await; + }); + } + _ = guard.cancelled() => { + break; + } + } + } + }); + Ok(tx) + } + + /* Routes one request: POST and OPTIONS on /v1/chat/completions are served, anything else is answered with 404, and CORS headers are added to every response. NOTE(review): when chat_completion returns Err, `status` is still StatusCode::OK, so the ret_err body goes out with HTTP 200 - confirm whether a 4xx/5xx status was intended. */async fn handle( + self: Arc, + req: hyper::Request, + ) -> std::result::Result { + let method = req.method().clone(); + let uri = req.uri().clone(); + let mut status = StatusCode::OK; + let res = if method == Method::POST && uri == "/v1/chat/completions" { + self.chat_completion(req).await + } else if method == Method::OPTIONS && uri == "/v1/chat/completions" { + status = StatusCode::NO_CONTENT; + Ok(Response::default()) + } else { + status = StatusCode::NOT_FOUND; + Err(anyhow!("The requested endpoint was not found.")) + }; + let mut res = match res { + Ok(res) => { + info!("{method} {uri} {}", status.as_u16()); + res + } + Err(err) => { + error!("{method} {uri} {} {err}", status.as_u16()); + ret_err(err) + } + }; + *res.status_mut() = status; + set_cors_header(&mut res); + Ok(res) + } + + /* Parses the JSON request body, builds a per-request Config from the server's clients and model (switching models when the request names one other than "default" or the server's own), then replies either as a text/event-stream body or as a single JSON document depending on `stream`. */async fn chat_completion(&self, req: hyper::Request) -> Result { + let req_body = req.collect().await?.to_bytes(); + let req_body: ChatCompletionReqBody = serde_json::from_slice(&req_body) + 
.map_err(|err| anyhow!("Invalid request body, {err}"))?; + + let ChatCompletionReqBody { + model, + messages, + temperature, + max_tokens, + stream, + } = req_body; + + let config = Config { + clients: self.clients.to_vec(), + model: self.model.clone(), + ..Default::default() + }; + let config = Arc::new(RwLock::new(config)); + if model != "default" && model != self.model.id() { + config.write().set_model(&model)?; + } + + let mut client = init_client(&config)?; + if max_tokens.is_some() { + client.set_model(client.model().clone().set_max_output_tokens(max_tokens)); + } + let abort = create_abort_signal(); + let http_client = client.build_client()?; + + let completion_id = generate_completion_id(); + let created = Utc::now().timestamp(); + + let send_data: SendData = SendData { + messages, + temperature, + stream, + }; + + if stream { + let (tx, mut rx) = unbounded_channel(); + tokio::spawn(async move { + let mut is_first = true; + let (tx2, rx2) = unbounded_channel(); + let mut handler = ReplyHandler::new(tx2, abort); + /* Forwards reply events to the response channel, sending ResEvent::First(None) ahead of the first one. Send errors are ignored. */async fn map_event( + mut rx: UnboundedReceiver, + tx: &UnboundedSender, + is_first: &mut bool, + ) { + while let Some(reply_event) = rx.recv().await { + if *is_first { + let _ = tx.send(ResEvent::First(None)); + *is_first = false; + } + match reply_event { + ReplyEvent::Text(text) => { + let _ = tx.send(ResEvent::Text(text)); + } + ReplyEvent::Done => { + let _ = tx.send(ResEvent::Done); + } + } + } + } + tokio::select! 
{ + _ = map_event(rx2, &tx, &mut is_first) => {} + ret = client.send_message_streaming_inner(&http_client, &mut handler, send_data) => { + if let Err(err) = ret { + send_first_event(&tx, Some(format!("{err:?}")), &mut is_first) + } + } + } + }); + + /* Wait for the first event before building the response: a failure reported before any reply event arrives as First(Some(..)) and is turned into an Err below instead of being streamed. */let first_event = rx.recv().await; + + if let Some(ResEvent::First(Some(err))) = first_event { + bail!("{err}"); + } + + let shared: Arc<(String, i64)> = Arc::new((completion_id, created)); + let stream = UnboundedReceiverStream::new(rx); + let stream = stream.filter_map(move |res_event| { + let shared = shared.clone(); + async move { + match res_event { + ResEvent::Text(text) => { + Some(Ok(create_frame(&shared.0, shared.1, &text, false))) + } + ResEvent::Done => Some(Ok(create_frame(&shared.0, shared.1, "", true))), + _ => None, + } + } + }); + let res = Response::builder() + .status(StatusCode::OK) + .header("Content-Type", "text/event-stream") + .header("Cache-Control", "no-cache") + .header("Connection", "keep-alive") + .body(BodyExt::boxed(StreamBody::new(stream)))?; + Ok(res) + } else { + let content = client.send_message_inner(&http_client, send_data).await?; + let res = Response::builder() + .header("Content-Type", "application/json") + .body(Full::new(ret_non_stream(&completion_id, created, &content)).boxed())?; + Ok(res) + } + } +} + +/* JSON request body accepted by the chat completion endpoint; `stream` falls back to false when absent (`serde(default)`). */#[derive(Debug, Deserialize)] +struct ChatCompletionReqBody { + model: String, + messages: Vec, + temperature: Option, + max_tokens: Option, + #[serde(default)] + stream: bool, +} + +/* Events passed from the streaming task to the response builder. `First` carries an optional error message and is consumed before the response is built. */#[derive(Debug)] +enum ResEvent { + First(Option), + Text(String), + Done, +} + +/* Sends ResEvent::First(data) only if no first event has been sent yet, then clears the flag. A send error is ignored. */fn send_first_event(tx: &UnboundedSender, data: Option, is_first: &mut bool) { + if *is_first { + let _ = tx.send(ResEvent::First(data)); + *is_first = false; + } +} + +/* Completes when Ctrl-C is received; panics if awaiting the signal returns an error. */async fn shutdown_signal() { + tokio::signal::ctrl_c() + .await + .expect("Failed to install CTRL+C signal handler") +} + +/* Builds an id of the form chatcmpl-N, where N is `nanosecond()` of the current UTC time. NOTE(review): the local is named random_id but nothing random is involved, so ids can repeat - confirm that is acceptable. */fn generate_completion_id() -> String { + let random_id = chrono::Utc::now().nanosecond(); + format!("chatcmpl-{}", random_id) +} + 
+fn set_cors_header(res: &mut AppResponse) { + res.headers_mut().insert( + hyper::header::ACCESS_CONTROL_ALLOW_ORIGIN, + hyper::header::HeaderValue::from_static("*"), + ); + res.headers_mut().insert( + hyper::header::ACCESS_CONTROL_ALLOW_METHODS, + hyper::header::HeaderValue::from_static("GET,POST,PUT,PATCH,DELETE"), + ); + res.headers_mut().insert( + hyper::header::ACCESS_CONTROL_ALLOW_HEADERS, + hyper::header::HeaderValue::from_static("Content-Type,Authorization"), + ); +} + +fn create_frame(id: &str, created: i64, content: &str, done: bool) -> Frame { + let (delta, finish_reason) = if done { + (json!({}), "stop".into()) + } else { + let delta = if content.is_empty() { + json!({ "role": "assistant", "content": content }) + } else { + json!({ "content": content }) + }; + (delta, Value::Null) + }; + let mut value = json!({ + "id": id, + "object": "chat.completion.chunk", + "created": created, + "model": "gpt-3.5-turbo", + "choices": [ + { + "index": 0, + "delta": delta, + "finish_reason": finish_reason, + }, + ], + }); + let output = if done { + value["usage"] = json!({ + "prompt_tokens": 0, + "completion_tokens": 0, + "total_tokens": 0, + }); + format!("data: {value}\n\ndata: [DONE]\n\n") + } else { + format!("data: {value}\n\n") + }; + Frame::data(Bytes::from(output)) +} + +fn ret_non_stream(id: &str, created: i64, content: &str) -> Bytes { + let res_body = json!({ + "id": id, + "object": "chat.completion", + "created": created, + "model": "gpt-3.5-turbo", + "choices": [ + { + "index": 0, + "message": { + "role": "assistant", + "content": content, + }, + "finish_reason": "stop", + }, + ], + "usage": { + "prompt_tokens": 0, + "completion_tokens": 0, + "total_tokens": 0, + }, + }); + Bytes::from(res_body.to_string()) +} + +fn ret_err(err: T) -> AppResponse { + let data = json!({ + "error": { + "message": err.to_string(), + "type": "invalid_request_error", + }, + }); + Response::builder() + .status(StatusCode::OK) + .header("Content-Type", "application/json") + 
.body(Full::new(Bytes::from(data.to_string())).boxed()) + .unwrap() +}