From ea2e128bc4b8f96193972b9dbe70149eb2ba0675 Mon Sep 17 00:00:00 2001 From: Chip Senkbeil Date: Mon, 8 Aug 2022 01:10:32 -0500 Subject: [PATCH] Refactor to use distant manager (#112) --- .config/nextest.toml | 4 + .github/workflows/ci-all.yml | 49 - .github/workflows/ci-linux.yml | 46 - .github/workflows/ci-macos.yml | 43 - .github/workflows/ci-windows.yml | 45 - .github/workflows/ci.yml | 191 + BUILDING.md | 5 + CHANGELOG.md | 37 +- Cargo.lock | 816 +++- Cargo.toml | 33 +- README.md | 115 +- distant-core/Cargo.toml | 19 +- distant-core/README.md | 63 +- distant-core/src/api.rs | 626 ++++ distant-core/src/api/local.rs | 2124 +++++++++++ .../process/mod.rs => api/local/process.rs} | 4 +- .../distant => api/local}/process/pty.rs | 28 +- .../distant => api/local}/process/simple.rs | 37 +- .../local}/process/simple/tasks.rs | 7 +- .../distant => api/local}/process/wait.rs | 0 distant-core/src/api/local/state.rs | 68 + distant-core/src/api/local/state/process.rs | 231 ++ .../src/api/local/state/process/instance.rs | 230 ++ distant-core/src/api/local/state/watcher.rs | 286 ++ .../src/api/local/state/watcher/path.rs | 212 ++ distant-core/src/api/reply.rs | 29 + distant-core/src/client.rs | 18 + distant-core/src/client/ext.rs | 424 +++ .../src/client/{lsp/mod.rs => lsp.rs} | 313 +- .../src/client/lsp/{data.rs => msg.rs} | 536 +-- distant-core/src/client/mod.rs | 10 - distant-core/src/client/process.rs | 843 +++-- distant-core/src/client/session/ext.rs | 514 --- distant-core/src/client/session/info.rs | 235 -- distant-core/src/client/session/mailbox.rs | 84 - distant-core/src/client/session/mod.rs | 504 --- distant-core/src/client/utils.rs | 13 - distant-core/src/client/watcher.rs | 197 +- distant-core/src/constants.rs | 13 - distant-core/src/credentials.rs | 133 + distant-core/src/data.rs | 1505 +------- distant-core/src/data/change.rs | 506 +++ distant-core/src/data/clap_impl.rs | 106 + distant-core/src/data/cmd.rs | 52 + distant-core/src/data/error.rs | 269 ++ 
distant-core/src/data/filesystem.rs | 45 + distant-core/src/data/map.rs | 244 ++ distant-core/src/data/metadata.rs | 404 ++ distant-core/src/data/pty.rs | 137 + distant-core/src/data/system.rs | 45 + distant-core/src/data/utils.rs | 27 + distant-core/src/lib.rs | 21 +- distant-core/src/manager.rs | 7 + distant-core/src/manager/client.rs | 761 ++++ distant-core/src/manager/client/config.rs | 85 + distant-core/src/manager/client/ext.rs | 14 + distant-core/src/manager/client/ext/tcp.rs | 50 + distant-core/src/manager/client/ext/unix.rs | 54 + .../src/manager/client/ext/windows.rs | 91 + distant-core/src/manager/data.rs | 20 + distant-core/src/manager/data/destination.rs | 266 ++ distant-core/src/manager/data/extra.rs | 2 + distant-core/src/manager/data/id.rs | 5 + distant-core/src/manager/data/info.rs | 15 + distant-core/src/manager/data/list.rs | 58 + distant-core/src/manager/data/request.rs | 72 + distant-core/src/manager/data/response.rs | 53 + distant-core/src/manager/server.rs | 698 ++++ distant-core/src/manager/server/config.rs | 31 + distant-core/src/manager/server/connection.rs | 201 + distant-core/src/manager/server/ext.rs | 14 + distant-core/src/manager/server/ext/tcp.rs | 30 + distant-core/src/manager/server/ext/unix.rs | 50 + .../src/manager/server/ext/windows.rs | 48 + distant-core/src/manager/server/handler.rs | 69 + distant-core/src/manager/server/ref.rs | 73 + distant-core/src/net/listener.rs | 162 - distant-core/src/net/transport/inmemory.rs | 583 --- distant-core/src/net/transport/mod.rs | 399 -- distant-core/src/net/transport/tcp.rs | 38 - distant-core/src/net/transport/unix.rs | 37 - distant-core/src/serde_str.rs | 45 + distant-core/src/server/distant/handler.rs | 3281 ----------------- distant-core/src/server/distant/mod.rs | 362 -- distant-core/src/server/distant/state.rs | 232 -- distant-core/src/server/mod.rs | 8 - distant-core/src/server/relay.rs | 382 -- distant-core/src/server/utils.rs | 289 -- distant-core/tests/manager_tests.rs | 96 + 
distant-core/tests/stress/distant/watch.rs | 8 +- distant-core/tests/stress/fixtures.rs | 42 +- distant-net/Cargo.toml | 37 + distant-net/README.md | 49 + distant-net/src/any.rs | 29 + distant-net/src/auth.rs | 122 + distant-net/src/auth/client.rs | 817 ++++ distant-net/src/auth/handshake.rs | 62 + distant-net/src/auth/handshake/pkb.rs | 60 + distant-net/src/auth/handshake/salt.rs | 111 + distant-net/src/auth/server.rs | 653 ++++ distant-net/src/client.rs | 163 + distant-net/src/client/channel.rs | 236 ++ distant-net/src/client/channel/mailbox.rs | 128 + distant-net/src/client/ext.rs | 14 + distant-net/src/client/ext/tcp.rs | 49 + distant-net/src/client/ext/unix.rs | 54 + distant-net/src/client/ext/windows.rs | 86 + .../codec/mod.rs => distant-net/src/codec.rs | 0 .../src}/codec/plain.rs | 2 +- .../src}/codec/xchacha20poly1305.rs | 15 +- distant-net/src/id.rs | 2 + .../src/net/mod.rs => distant-net/src/key.rs | 17 +- distant-net/src/lib.rs | 27 + distant-net/src/listener.rs | 34 + distant-net/src/listener/mapped.rs | 40 + distant-net/src/listener/mpsc.rs | 31 + distant-net/src/listener/oneshot.rs | 84 + distant-net/src/listener/tcp.rs | 167 + distant-net/src/listener/unix.rs | 212 ++ distant-net/src/listener/windows.rs | 162 + distant-net/src/packet.rs | 68 + .../src/server => distant-net/src}/port.rs | 12 +- distant-net/src/server.rs | 49 + distant-net/src/server/connection.rs | 51 + distant-net/src/server/context.rs | 17 + distant-net/src/server/ext.rs | 195 + distant-net/src/server/ext/tcp.rs | 94 + distant-net/src/server/ext/unix.rs | 97 + distant-net/src/server/ext/windows.rs | 109 + distant-net/src/server/ref.rs | 120 + distant-net/src/server/ref/tcp.rs | 39 + distant-net/src/server/ref/unix.rs | 38 + distant-net/src/server/ref/windows.rs | 38 + distant-net/src/server/reply.rs | 198 + distant-net/src/server/state.rs | 23 + distant-net/src/transport.rs | 112 + distant-net/src/transport/framed.rs | 209 ++ distant-net/src/transport/framed/read.rs | 109 + 
distant-net/src/transport/framed/test.rs | 12 + distant-net/src/transport/framed/write.rs | 72 + distant-net/src/transport/inmemory.rs | 225 ++ distant-net/src/transport/inmemory/read.rs | 249 ++ distant-net/src/transport/inmemory/write.rs | 147 + distant-net/src/transport/mpsc.rs | 66 + distant-net/src/transport/mpsc/read.rs | 22 + distant-net/src/transport/mpsc/write.rs | 25 + distant-net/src/transport/router.rs | 370 ++ distant-net/src/transport/tcp.rs | 196 + distant-net/src/transport/unix.rs | 187 + distant-net/src/transport/untyped.rs | 61 + distant-net/src/transport/windows.rs | 202 + distant-net/src/transport/windows/pipe.rs | 101 + distant-net/src/utils.rs | 20 + distant-net/tests/auth.rs | 169 + distant-net/tests/lib.rs | 1 + distant-ssh2/Cargo.toml | 19 +- distant-ssh2/README.md | 63 +- distant-ssh2/src/api.rs | 842 +++++ distant-ssh2/src/handler.rs | 787 ---- distant-ssh2/src/lib.rs | 443 ++- distant-ssh2/src/process.rs | 200 +- distant-ssh2/src/utils.rs | 293 ++ distant-ssh2/tests/lib.rs | 1 + distant-ssh2/tests/ssh2/client.rs | 1447 ++++++++ distant-ssh2/tests/ssh2/launched.rs | 1469 ++++++++ distant-ssh2/tests/ssh2/mod.rs | 4 +- distant-ssh2/tests/ssh2/session.rs | 1943 ---------- distant-ssh2/tests/ssh2/ssh.rs | 19 + distant-ssh2/tests/sshd.rs | 436 --- distant-ssh2/tests/sshd/mod.rs | 677 ++++ distant-ssh2/tests/utils/mod.rs | 36 + rustfmt.toml | 11 + src/cli.rs | 157 + src/cli/cache.rs | 102 + src/cli/cache/id.rs | 105 + src/cli/client.rs | 176 + src/{ => cli/client}/msg.rs | 25 +- src/cli/commands.rs | 25 + src/cli/commands/client.rs | 738 ++++ src/{ => cli/commands/client}/buf.rs | 0 .../commands/client/format.rs} | 195 +- src/{ => cli/commands/client}/link.rs | 5 +- src/cli/commands/client/lsp.rs | 51 + src/cli/commands/client/shell.rs | 113 + src/{ => cli/commands/client}/stdin.rs | 0 src/cli/commands/generate.rs | 93 + src/cli/commands/manager.rs | 431 +++ src/cli/commands/manager/handlers.rs | 404 ++ src/cli/commands/server.rs | 238 ++ 
src/cli/manager.rs | 77 + src/cli/spawner.rs | 211 ++ src/config.rs | 126 + src/config/client.rs | 24 + src/config/client/action.rs | 7 + src/config/client/launch.rs | 162 + src/config/client/repl.rs | 7 + src/config/common.rs | 52 + src/config/generate.rs | 9 + src/config/manager.rs | 45 + src/config/network.rs | 105 + src/config/server.rs | 14 + src/config/server/listen.rs | 149 + src/constants.rs | 25 - src/environment.rs | 142 - src/exit.rs | 148 - src/lib.rs | 111 +- src/main.rs | 52 +- src/opt.rs | 755 ---- src/paths.rs | 112 + src/session.rs | 163 - src/subcommand/action.rs | 218 -- src/subcommand/launch.rs | 374 -- src/subcommand/listen.rs | 145 - src/subcommand/lsp.rs | 118 - src/subcommand/mod.rs | 171 - src/subcommand/shell.rs | 165 - src/utils.rs | 4 - src/win_service.rs | 223 ++ tests/cli/action/copy.rs | 132 +- tests/cli/action/dir_create.rs | 119 +- tests/cli/action/dir_read.rs | 473 +-- tests/cli/action/exists.rs | 64 +- tests/cli/action/file_append.rs | 94 +- tests/cli/action/file_append_text.rs | 94 +- tests/cli/action/file_read.rs | 84 +- tests/cli/action/file_read_text.rs | 84 +- tests/cli/action/file_write.rs | 93 +- tests/cli/action/file_write_text.rs | 93 +- tests/cli/action/metadata.rs | 229 +- tests/cli/action/proc_spawn.rs | 482 +-- tests/cli/action/remove.rs | 169 +- tests/cli/action/rename.rs | 135 +- tests/cli/action/system_info.rs | 40 +- tests/cli/action/watch.rs | 390 +- tests/cli/fixtures.rs | 300 +- tests/cli/fixtures/repl.rs | 213 ++ tests/cli/mod.rs | 2 + tests/cli/repl/copy.rs | 111 + tests/cli/repl/dir_create.rs | 92 + tests/cli/repl/dir_read.rs | 264 ++ tests/cli/repl/exists.rs | 63 + tests/cli/repl/file_append.rs | 75 + tests/cli/repl/file_append_text.rs | 75 + tests/cli/repl/file_read.rs | 60 + tests/cli/repl/file_read_text.rs | 60 + tests/cli/repl/file_write.rs | 69 + tests/cli/repl/file_write_text.rs | 69 + tests/cli/repl/metadata.rs | 159 + tests/cli/repl/mod.rs | 16 + tests/cli/repl/proc_spawn.rs | 273 ++ 
tests/cli/repl/remove.rs | 135 + tests/cli/repl/rename.rs | 114 + tests/cli/repl/system_info.rs | 29 + tests/cli/repl/watch.rs | 256 ++ tests/cli/scripts.rs | 121 + tests/cli/utils.rs | 139 +- tests/cli/utils/reader.rs | 106 + 257 files changed, 30653 insertions(+), 18828 deletions(-) create mode 100644 .config/nextest.toml delete mode 100644 .github/workflows/ci-all.yml delete mode 100644 .github/workflows/ci-linux.yml delete mode 100644 .github/workflows/ci-macos.yml delete mode 100644 .github/workflows/ci-windows.yml create mode 100644 .github/workflows/ci.yml create mode 100644 distant-core/src/api.rs create mode 100644 distant-core/src/api/local.rs rename distant-core/src/{server/distant/process/mod.rs => api/local/process.rs} (98%) rename distant-core/src/{server/distant => api/local}/process/pty.rs (93%) rename distant-core/src/{server/distant => api/local}/process/simple.rs (86%) rename distant-core/src/{server/distant => api/local}/process/simple/tasks.rs (91%) rename distant-core/src/{server/distant => api/local}/process/wait.rs (100%) create mode 100644 distant-core/src/api/local/state.rs create mode 100644 distant-core/src/api/local/state/process.rs create mode 100644 distant-core/src/api/local/state/process/instance.rs create mode 100644 distant-core/src/api/local/state/watcher.rs create mode 100644 distant-core/src/api/local/state/watcher/path.rs create mode 100644 distant-core/src/api/reply.rs create mode 100644 distant-core/src/client.rs create mode 100644 distant-core/src/client/ext.rs rename distant-core/src/client/{lsp/mod.rs => lsp.rs} (81%) rename distant-core/src/client/lsp/{data.rs => msg.rs} (52%) delete mode 100644 distant-core/src/client/mod.rs delete mode 100644 distant-core/src/client/session/ext.rs delete mode 100644 distant-core/src/client/session/info.rs delete mode 100644 distant-core/src/client/session/mailbox.rs delete mode 100644 distant-core/src/client/session/mod.rs delete mode 100644 distant-core/src/client/utils.rs create mode 
100644 distant-core/src/credentials.rs create mode 100644 distant-core/src/data/change.rs create mode 100644 distant-core/src/data/clap_impl.rs create mode 100644 distant-core/src/data/cmd.rs create mode 100644 distant-core/src/data/error.rs create mode 100644 distant-core/src/data/filesystem.rs create mode 100644 distant-core/src/data/map.rs create mode 100644 distant-core/src/data/metadata.rs create mode 100644 distant-core/src/data/pty.rs create mode 100644 distant-core/src/data/system.rs create mode 100644 distant-core/src/data/utils.rs create mode 100644 distant-core/src/manager.rs create mode 100644 distant-core/src/manager/client.rs create mode 100644 distant-core/src/manager/client/config.rs create mode 100644 distant-core/src/manager/client/ext.rs create mode 100644 distant-core/src/manager/client/ext/tcp.rs create mode 100644 distant-core/src/manager/client/ext/unix.rs create mode 100644 distant-core/src/manager/client/ext/windows.rs create mode 100644 distant-core/src/manager/data.rs create mode 100644 distant-core/src/manager/data/destination.rs create mode 100644 distant-core/src/manager/data/extra.rs create mode 100644 distant-core/src/manager/data/id.rs create mode 100644 distant-core/src/manager/data/info.rs create mode 100644 distant-core/src/manager/data/list.rs create mode 100644 distant-core/src/manager/data/request.rs create mode 100644 distant-core/src/manager/data/response.rs create mode 100644 distant-core/src/manager/server.rs create mode 100644 distant-core/src/manager/server/config.rs create mode 100644 distant-core/src/manager/server/connection.rs create mode 100644 distant-core/src/manager/server/ext.rs create mode 100644 distant-core/src/manager/server/ext/tcp.rs create mode 100644 distant-core/src/manager/server/ext/unix.rs create mode 100644 distant-core/src/manager/server/ext/windows.rs create mode 100644 distant-core/src/manager/server/handler.rs create mode 100644 distant-core/src/manager/server/ref.rs delete mode 100644 
distant-core/src/net/listener.rs delete mode 100644 distant-core/src/net/transport/inmemory.rs delete mode 100644 distant-core/src/net/transport/mod.rs delete mode 100644 distant-core/src/net/transport/tcp.rs delete mode 100644 distant-core/src/net/transport/unix.rs create mode 100644 distant-core/src/serde_str.rs delete mode 100644 distant-core/src/server/distant/handler.rs delete mode 100644 distant-core/src/server/distant/mod.rs delete mode 100644 distant-core/src/server/distant/state.rs delete mode 100644 distant-core/src/server/mod.rs delete mode 100644 distant-core/src/server/relay.rs delete mode 100644 distant-core/src/server/utils.rs create mode 100644 distant-core/tests/manager_tests.rs create mode 100644 distant-net/Cargo.toml create mode 100644 distant-net/README.md create mode 100644 distant-net/src/any.rs create mode 100644 distant-net/src/auth.rs create mode 100644 distant-net/src/auth/client.rs create mode 100644 distant-net/src/auth/handshake.rs create mode 100644 distant-net/src/auth/handshake/pkb.rs create mode 100644 distant-net/src/auth/handshake/salt.rs create mode 100644 distant-net/src/auth/server.rs create mode 100644 distant-net/src/client.rs create mode 100644 distant-net/src/client/channel.rs create mode 100644 distant-net/src/client/channel/mailbox.rs create mode 100644 distant-net/src/client/ext.rs create mode 100644 distant-net/src/client/ext/tcp.rs create mode 100644 distant-net/src/client/ext/unix.rs create mode 100644 distant-net/src/client/ext/windows.rs rename distant-core/src/net/transport/codec/mod.rs => distant-net/src/codec.rs (100%) rename {distant-core/src/net/transport => distant-net/src}/codec/plain.rs (99%) rename {distant-core/src/net/transport => distant-net/src}/codec/xchacha20poly1305.rs (97%) create mode 100644 distant-net/src/id.rs rename distant-core/src/net/mod.rs => distant-net/src/key.rs (87%) create mode 100644 distant-net/src/lib.rs create mode 100644 distant-net/src/listener.rs create mode 100644 
distant-net/src/listener/mapped.rs create mode 100644 distant-net/src/listener/mpsc.rs create mode 100644 distant-net/src/listener/oneshot.rs create mode 100644 distant-net/src/listener/tcp.rs create mode 100644 distant-net/src/listener/unix.rs create mode 100644 distant-net/src/listener/windows.rs create mode 100644 distant-net/src/packet.rs rename {distant-core/src/server => distant-net/src}/port.rs (94%) create mode 100644 distant-net/src/server.rs create mode 100644 distant-net/src/server/connection.rs create mode 100644 distant-net/src/server/context.rs create mode 100644 distant-net/src/server/ext.rs create mode 100644 distant-net/src/server/ext/tcp.rs create mode 100644 distant-net/src/server/ext/unix.rs create mode 100644 distant-net/src/server/ext/windows.rs create mode 100644 distant-net/src/server/ref.rs create mode 100644 distant-net/src/server/ref/tcp.rs create mode 100644 distant-net/src/server/ref/unix.rs create mode 100644 distant-net/src/server/ref/windows.rs create mode 100644 distant-net/src/server/reply.rs create mode 100644 distant-net/src/server/state.rs create mode 100644 distant-net/src/transport.rs create mode 100644 distant-net/src/transport/framed.rs create mode 100644 distant-net/src/transport/framed/read.rs create mode 100644 distant-net/src/transport/framed/test.rs create mode 100644 distant-net/src/transport/framed/write.rs create mode 100644 distant-net/src/transport/inmemory.rs create mode 100644 distant-net/src/transport/inmemory/read.rs create mode 100644 distant-net/src/transport/inmemory/write.rs create mode 100644 distant-net/src/transport/mpsc.rs create mode 100644 distant-net/src/transport/mpsc/read.rs create mode 100644 distant-net/src/transport/mpsc/write.rs create mode 100644 distant-net/src/transport/router.rs create mode 100644 distant-net/src/transport/tcp.rs create mode 100644 distant-net/src/transport/unix.rs create mode 100644 distant-net/src/transport/untyped.rs create mode 100644 
distant-net/src/transport/windows.rs create mode 100644 distant-net/src/transport/windows/pipe.rs create mode 100644 distant-net/src/utils.rs create mode 100644 distant-net/tests/auth.rs create mode 100644 distant-net/tests/lib.rs create mode 100644 distant-ssh2/src/api.rs delete mode 100644 distant-ssh2/src/handler.rs create mode 100644 distant-ssh2/src/utils.rs create mode 100644 distant-ssh2/tests/ssh2/client.rs create mode 100644 distant-ssh2/tests/ssh2/launched.rs delete mode 100644 distant-ssh2/tests/ssh2/session.rs create mode 100644 distant-ssh2/tests/ssh2/ssh.rs delete mode 100644 distant-ssh2/tests/sshd.rs create mode 100644 distant-ssh2/tests/sshd/mod.rs create mode 100644 distant-ssh2/tests/utils/mod.rs create mode 100644 rustfmt.toml create mode 100644 src/cli.rs create mode 100644 src/cli/cache.rs create mode 100644 src/cli/cache/id.rs create mode 100644 src/cli/client.rs rename src/{ => cli/client}/msg.rs (81%) create mode 100644 src/cli/commands.rs create mode 100644 src/cli/commands/client.rs rename src/{ => cli/commands/client}/buf.rs (100%) rename src/{output.rs => cli/commands/client/format.rs} (67%) rename src/{ => cli/commands/client}/link.rs (95%) create mode 100644 src/cli/commands/client/lsp.rs create mode 100644 src/cli/commands/client/shell.rs rename src/{ => cli/commands/client}/stdin.rs (100%) create mode 100644 src/cli/commands/generate.rs create mode 100644 src/cli/commands/manager.rs create mode 100644 src/cli/commands/manager/handlers.rs create mode 100644 src/cli/commands/server.rs create mode 100644 src/cli/manager.rs create mode 100644 src/cli/spawner.rs create mode 100644 src/config.rs create mode 100644 src/config/client.rs create mode 100644 src/config/client/action.rs create mode 100644 src/config/client/launch.rs create mode 100644 src/config/client/repl.rs create mode 100644 src/config/common.rs create mode 100644 src/config/generate.rs create mode 100644 src/config/manager.rs create mode 100644 src/config/network.rs create 
mode 100644 src/config/server.rs create mode 100644 src/config/server/listen.rs delete mode 100644 src/environment.rs delete mode 100644 src/exit.rs delete mode 100644 src/opt.rs create mode 100644 src/paths.rs delete mode 100644 src/session.rs delete mode 100644 src/subcommand/action.rs delete mode 100644 src/subcommand/launch.rs delete mode 100644 src/subcommand/listen.rs delete mode 100644 src/subcommand/lsp.rs delete mode 100644 src/subcommand/mod.rs delete mode 100644 src/subcommand/shell.rs delete mode 100644 src/utils.rs create mode 100644 src/win_service.rs create mode 100644 tests/cli/fixtures/repl.rs create mode 100644 tests/cli/repl/copy.rs create mode 100644 tests/cli/repl/dir_create.rs create mode 100644 tests/cli/repl/dir_read.rs create mode 100644 tests/cli/repl/exists.rs create mode 100644 tests/cli/repl/file_append.rs create mode 100644 tests/cli/repl/file_append_text.rs create mode 100644 tests/cli/repl/file_read.rs create mode 100644 tests/cli/repl/file_read_text.rs create mode 100644 tests/cli/repl/file_write.rs create mode 100644 tests/cli/repl/file_write_text.rs create mode 100644 tests/cli/repl/metadata.rs create mode 100644 tests/cli/repl/mod.rs create mode 100644 tests/cli/repl/proc_spawn.rs create mode 100644 tests/cli/repl/remove.rs create mode 100644 tests/cli/repl/rename.rs create mode 100644 tests/cli/repl/system_info.rs create mode 100644 tests/cli/repl/watch.rs create mode 100644 tests/cli/scripts.rs create mode 100644 tests/cli/utils/reader.rs diff --git a/.config/nextest.toml b/.config/nextest.toml new file mode 100644 index 0000000..b3bdf3a --- /dev/null +++ b/.config/nextest.toml @@ -0,0 +1,4 @@ +[profile.ci] +fail-fast = false +retries = 2 +slow-timeout = { period = "60s", terminate-after = 3 } diff --git a/.github/workflows/ci-all.yml b/.github/workflows/ci-all.yml deleted file mode 100644 index 659e1ae..0000000 --- a/.github/workflows/ci-all.yml +++ /dev/null @@ -1,49 +0,0 @@ -name: CI (All) - -on: - push: - branches: - - 
master - pull_request: - branches: - - master - -jobs: - clippy: - name: Lint with clippy - runs-on: ubuntu-latest - env: - RUSTFLAGS: -Dwarnings - steps: - - uses: actions/checkout@v2 - - name: Install Rust (clippy) - uses: actions-rs/toolchain@v1 - with: - profile: minimal - toolchain: stable - components: clippy - - uses: Swatinem/rust-cache@v1 - - name: Check Cargo availability - run: cargo --version - - name: distant-core (all features) - run: cargo clippy -p distant-core --all-targets --verbose --all-features - - name: distant-ssh2 (all features) - run: cargo clippy -p distant-ssh2 --all-targets --verbose --all-features - - name: distant (all features) - run: cargo clippy --all-targets --verbose --all-features - - rustfmt: - name: Verify code formatting - runs-on: ubuntu-latest - steps: - - uses: actions/checkout@v2 - - name: Install Rust (rustfmt) - uses: actions-rs/toolchain@v1 - with: - profile: minimal - toolchain: stable - components: rustfmt - - uses: Swatinem/rust-cache@v1 - - name: Check Cargo availability - run: cargo --version - - run: cargo fmt --all -- --check diff --git a/.github/workflows/ci-linux.yml b/.github/workflows/ci-linux.yml deleted file mode 100644 index fdafc97..0000000 --- a/.github/workflows/ci-linux.yml +++ /dev/null @@ -1,46 +0,0 @@ -name: CI (Linux) - -on: - push: - branches: - - master - pull_request: - branches: - - master - -jobs: - tests: - name: "Test Rust ${{ matrix.rust }} on ${{ matrix.os }}" - runs-on: ${{ matrix.os }} - strategy: - fail-fast: false - matrix: - include: - - { rust: stable, os: ubuntu-latest } - - { rust: 1.51.0, os: ubuntu-latest } - steps: - - uses: actions/checkout@v2 - - name: Install Rust ${{ matrix.rust }} - uses: actions-rs/toolchain@v1 - with: - profile: minimal - toolchain: ${{ matrix.rust }} - - uses: Swatinem/rust-cache@v1 - - name: Check Cargo availability - run: cargo --version - - name: Run core tests (default features) - run: cargo test --release --verbose -p distant-core - - name: Run core 
tests (all features) - run: cargo test --release --verbose --all-features -p distant-core - - name: Ensure /run/sshd exists on Unix - run: mkdir -p /run/sshd - - name: Run ssh2 tests (default features) - run: cargo test --release --verbose -p distant-ssh2 - - name: Run ssh2 tests (all features) - run: cargo test --release --verbose --all-features -p distant-ssh2 - - name: Run CLI tests - run: cargo test --release --verbose - shell: bash - - name: Run CLI tests (no default features) - run: cargo test --release --verbose --no-default-features - shell: bash diff --git a/.github/workflows/ci-macos.yml b/.github/workflows/ci-macos.yml deleted file mode 100644 index 7ac59be..0000000 --- a/.github/workflows/ci-macos.yml +++ /dev/null @@ -1,43 +0,0 @@ -name: CI (MacOS) - -on: - push: - branches: - - master - pull_request: - branches: - - master - -jobs: - tests: - name: "Test Rust ${{ matrix.rust }} on ${{ matrix.os }}" - runs-on: ${{ matrix.os }} - strategy: - fail-fast: false - matrix: - include: - - { rust: stable, os: macos-latest } - steps: - - uses: actions/checkout@v2 - - name: Install Rust ${{ matrix.rust }} - uses: actions-rs/toolchain@v1 - with: - profile: minimal - toolchain: ${{ matrix.rust }} - - uses: Swatinem/rust-cache@v1 - - name: Check Cargo availability - run: cargo --version - - name: Run core tests (default features) - run: cargo test --release --verbose -p distant-core - - name: Run core tests (all features) - run: cargo test --release --verbose --all-features -p distant-core - - name: Run ssh2 tests (default features) - run: cargo test --release --verbose -p distant-ssh2 - - name: Run ssh2 tests (all features) - run: cargo test --release --verbose --all-features -p distant-ssh2 - - name: Run CLI tests - run: cargo test --release --verbose - shell: bash - - name: Run CLI tests (no default features) - run: cargo test --release --verbose --no-default-features - shell: bash diff --git a/.github/workflows/ci-windows.yml b/.github/workflows/ci-windows.yml 
deleted file mode 100644 index 7be36cc..0000000 --- a/.github/workflows/ci-windows.yml +++ /dev/null @@ -1,45 +0,0 @@ -name: CI (Windows) - -on: - push: - branches: - - master - pull_request: - branches: - - master - -jobs: - tests: - name: "Test Rust ${{ matrix.rust }} on ${{ matrix.os }}" - runs-on: ${{ matrix.os }} - strategy: - fail-fast: false - matrix: - include: - - { rust: stable, os: windows-latest, target: x86_64-pc-windows-msvc } - steps: - - uses: actions/checkout@v2 - - name: Install Rust ${{ matrix.rust }} - uses: actions-rs/toolchain@v1 - with: - profile: minimal - toolchain: ${{ matrix.rust }} - target: ${{ matrix.target }} - - uses: Swatinem/rust-cache@v1 - - name: Check Cargo availability - run: cargo --version - - uses: Vampire/setup-wsl@v1 - - name: Run distant-core tests (default features) - run: cargo test --release --verbose -p distant-core - - name: Run distant-core tests (all features) - run: cargo test --release --verbose --all-features -p distant-core - - name: Build distant-ssh2 (default features) - run: cargo build --release --verbose -p distant-ssh2 - - name: Build distant-ssh2 (all features) - run: cargo build --release --verbose --all-features -p distant-ssh2 - - name: Build CLI - run: cargo build --release --verbose - shell: bash - - name: Build CLI (no default features) - run: cargo build --release --verbose --no-default-features - shell: bash diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml new file mode 100644 index 0000000..f2d43ce --- /dev/null +++ b/.github/workflows/ci.yml @@ -0,0 +1,191 @@ +name: CI + +on: + push: + branches: + - master + pull_request: + branches: + - master + +jobs: + clippy: + name: "Lint with clippy (${{ matrix.os }})" + runs-on: ${{ matrix.os }} + strategy: + fail-fast: false + matrix: + include: + - { os: windows-latest } + - { os: ubuntu-latest } + env: + RUSTFLAGS: -Dwarnings + steps: + - name: Ensure windows git checkout keeps \n line ending + run: | + git config --system 
core.autocrlf false + git config --system core.eol lf + if: matrix.os == 'windows-latest' + - uses: actions/checkout@v2 + - name: Install Rust (clippy) + uses: actions-rs/toolchain@v1 + with: + profile: minimal + toolchain: stable + components: clippy + - uses: Swatinem/rust-cache@v1 + - name: Check Cargo availability + run: cargo --version + - name: distant-core (all features) + run: cargo clippy -p distant-core --all-targets --verbose --all-features + - name: distant-ssh2 (all features) + run: cargo clippy -p distant-ssh2 --all-targets --verbose --all-features + - name: distant (all features) + run: cargo clippy --all-targets --verbose --all-features + rustfmt: + name: "Verify code formatting (${{ matrix.os }})" + runs-on: ${{ matrix.os }} + strategy: + fail-fast: false + matrix: + include: + - { os: windows-latest } + - { os: ubuntu-latest } + steps: + - name: Ensure windows git checkout keeps \n line ending + run: | + git config --system core.autocrlf false + git config --system core.eol lf + if: matrix.os == 'windows-latest' + - uses: actions/checkout@v2 + - name: Install Rust (rustfmt) + uses: actions-rs/toolchain@v1 + with: + profile: minimal + toolchain: stable + components: rustfmt + - uses: Swatinem/rust-cache@v1 + - name: Check Cargo availability + run: cargo --version + - run: cargo fmt --all -- --check + tests: + name: "Test Rust ${{ matrix.rust }} on ${{ matrix.os }}" + runs-on: ${{ matrix.os }} + strategy: + fail-fast: false + matrix: + include: + - { rust: stable, os: windows-latest, target: x86_64-pc-windows-msvc } + - { rust: stable, os: macos-latest } + - { rust: stable, os: ubuntu-latest } + - { rust: 1.61.0, os: ubuntu-latest } + steps: + - uses: actions/checkout@v2 + - name: Install Rust ${{ matrix.rust }} + uses: actions-rs/toolchain@v1 + with: + profile: minimal + toolchain: ${{ matrix.rust }} + target: ${{ matrix.target }} + - uses: taiki-e/install-action@v1 + with: + tool: cargo-nextest + - uses: Swatinem/rust-cache@v1 + - name: Check 
Cargo availability + run: cargo --version + - name: Install OpenSSH on Windows + run: | + # From https://gist.github.com/inevity/a0d7b9f1c5ba5a813917b92736122797 + Add-Type -AssemblyName System.IO.Compression.FileSystem + function Unzip + { + param([string]$zipfile, [string]$outpath) + + [System.IO.Compression.ZipFile]::ExtractToDirectory($zipfile, $outpath) + } + + $url = 'https://github.com/PowerShell/Win32-OpenSSH/releases/latest/' + $request = [System.Net.WebRequest]::Create($url) + $request.AllowAutoRedirect=$false + $response=$request.GetResponse() + $file = $([String]$response.GetResponseHeader("Location")).Replace('tag','download') + '/OpenSSH-Win64.zip' + + $client = new-object system.Net.Webclient; + $client.DownloadFile($file ,"c:\\OpenSSH-Win64.zip") + + Unzip "c:\\OpenSSH-Win64.zip" "C:\Program Files\" + mv "c:\\Program Files\OpenSSH-Win64" "C:\Program Files\OpenSSH\" + + powershell.exe -ExecutionPolicy Bypass -File "C:\Program Files\OpenSSH\install-sshd.ps1" + + New-NetFirewallRule -Name sshd -DisplayName 'OpenSSH Server (sshd)' -Enabled True -Direction Inbound -Protocol TCP -Action Allow -LocalPort 22,49152-65535 + + net start sshd + + Set-Service sshd -StartupType Automatic + Set-Service ssh-agent -StartupType Automatic + + cd "C:\Program Files\OpenSSH\" + Powershell.exe -ExecutionPolicy Bypass -Command '. 
.\FixHostFilePermissions.ps1 -Confirm:$false' + + $registryPath = "HKLM:\SOFTWARE\OpenSSH\" + $Name = "DefaultShell" + $value = "C:\windows\System32\WindowsPowerShell\v1.0\powershell.exe" + + IF(!(Test-Path $registryPath)) + { + New-Item -Path $registryPath -Force + New-ItemProperty -Path $registryPath -Name $name -Value $value -PropertyType String -Force + } ELSE { + New-ItemProperty -Path $registryPath -Name $name -Value $value -PropertyType String -Force + } + shell: pwsh + if: matrix.os == 'windows-latest' + - name: Run net tests (default features) + run: cargo nextest run --profile ci --release --verbose -p distant-net + - name: Run core tests (default features) + run: cargo nextest run --profile ci --release --verbose -p distant-core + - name: Run core tests (all features) + run: cargo nextest run --profile ci --release --verbose --all-features -p distant-core + - name: Ensure /run/sshd exists on Unix + run: mkdir -p /run/sshd + if: matrix.os == 'ubuntu-latest' + - name: Run ssh2 client tests (default features) + run: cargo nextest run --profile ci --release --verbose -p distant-ssh2 ssh2::client + - name: Run ssh2 client tests (all features) + run: cargo nextest run --profile ci --release --verbose --all-features -p distant-ssh2 ssh2::client + - name: Run CLI tests + run: cargo nextest run --profile ci --release --verbose + - name: Run CLI tests (no default features) + run: cargo nextest run --profile ci --release --verbose --no-default-features + ssh-launch-tests: + name: "Test ssh launch using Rust ${{ matrix.rust }} on ${{ matrix.os }}" + runs-on: ${{ matrix.os }} + strategy: + fail-fast: false + matrix: + include: + - { rust: stable, os: macos-latest } + - { rust: stable, os: ubuntu-latest } + - { rust: 1.61.0, os: ubuntu-latest } + steps: + - uses: actions/checkout@v2 + - name: Install Rust ${{ matrix.rust }} + uses: actions-rs/toolchain@v1 + with: + profile: minimal + toolchain: ${{ matrix.rust }} + - uses: taiki-e/install-action@v1 + with: + tool: 
cargo-nextest + - uses: Swatinem/rust-cache@v1 + - name: Check Cargo availability + run: cargo --version + - name: Install distant cli for use in launch tests + run: | + cargo install --path . + echo "DISTANT_PATH=$HOME/.cargo/bin/distant" >> $GITHUB_ENV + - name: Run ssh2 launch tests (default features) + run: cargo nextest run --profile ci --release --verbose -p distant-ssh2 ssh2::launched + - name: Run ssh2 launch tests (all features) + run: cargo nextest run --profile ci --release --verbose --all-features -p distant-ssh2 ssh2::launched diff --git a/BUILDING.md b/BUILDING.md index fbc51bf..1467bd2 100644 --- a/BUILDING.md +++ b/BUILDING.md @@ -5,6 +5,11 @@ * `make` - needed to build openssl via vendor feature * `perl` - needed to build openssl via vendor feature +### FreeBSD + +* `gmake` - needed to build openssl via vendor feature (`pkg install gmake`) + + ## Using Cargo A debug build is straightforward: diff --git a/CHANGELOG.md b/CHANGELOG.md index 929ee7c..3b7011c 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,10 +7,39 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## [Unreleased] ### Added + +- `distant manager` subcommand + - `distant manager service` subcommand contains functionality to install, + start, stop, and uninstall the manager as a service on various operating + systems + - `distant manager info` will print information about an active connection + - `distant manager list` will print information about all connections +- `distant generate` subcommand + - `distant generate schema` will produce JSON schema for server + request/response + - `distant generate completion` will produce completion file for a specific + shell + ### Changed + +- `distant launch` is now `distant client launch` +- `distant action` is now `distant client action` +- `distant shell` is now `distant client shell` +- `distant listen` is now `distant server listen` +- Minimum supported rust version (MSRV) has been bumped to `1.61.0` + ### 
Fixed + +- Shell no longer has issues with fancier command prompts and other + terminal-oriented printing as `TERM=xterm-256color` is now set by default + ### Removed +- Networking directly from distant client to distant server. All connections + are now facilitated by the manager interface with client -> manager -> server +- Environment variable output as part of launch is now gone as the connection + is now being managed, so there is no need to export session information + ## [0.16.4] - 2022-06-01 ### Added - Dockerfile using Alpine linux with a basic install of distant, tagged as @@ -134,6 +163,12 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 pending upon full channel and no longer locks up - stdout, stderr, and stdin of `RemoteProcess` no longer cause deadlock -[Unreleased]: https://github.com/chipsenkbeil/distant/compare/v0.15.1...HEAD +[Unreleased]: https://github.com/chipsenkbeil/distant/compare/v0.17.0...HEAD +[0.17.0]: https://github.com/chipsenkbeil/distant/compare/v0.16.4...v0.17.0 +[0.16.4]: https://github.com/chipsenkbeil/distant/compare/v0.16.3...v0.16.4 +[0.16.3]: https://github.com/chipsenkbeil/distant/compare/v0.16.2...v0.16.3 +[0.16.2]: https://github.com/chipsenkbeil/distant/compare/v0.16.1...v0.16.2 +[0.16.1]: https://github.com/chipsenkbeil/distant/compare/v0.16.0...v0.16.1 +[0.16.0]: https://github.com/chipsenkbeil/distant/compare/v0.15.1...v0.16.0 [0.15.1]: https://github.com/chipsenkbeil/distant/compare/v0.15.0...v0.15.1 [0.15.0]: https://github.com/chipsenkbeil/distant/compare/v0.14.0...v0.15.0 diff --git a/Cargo.lock b/Cargo.lock index 94dcaf1..2b4a777 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -31,9 +31,9 @@ dependencies = [ [[package]] name = "anyhow" -version = "1.0.57" +version = "1.0.58" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "08f9b8508dccb7687a1d6c4ce66b2b0ecef467c94667de27d8d7fe1f8d2a9cdc" +checksum = 
"bb07d2053ccdbe10e2af2995a2f116c1330396493dc1269f6a91d0ae82e19704" [[package]] name = "assert_cmd" @@ -176,6 +176,12 @@ dependencies = [ "futures-lite", ] +[[package]] +name = "async-once-cell" +version = "0.4.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f61305cacf1d0c5c9d3ee283d22f8f1f8c743a18ceb44a1b102bd53476c141de" + [[package]] name = "async-process" version = "1.4.0" @@ -227,6 +233,17 @@ version = "4.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "30696a84d817107fc028e049980e09d5e140e8da8f1caeb17e8e950658a3cea9" +[[package]] +name = "async-trait" +version = "0.1.56" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "96cf8829f67d2eab0b2dfa42c5d0ef737e0724e4a82b01b3e292456202b19716" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + [[package]] name = "atomic-waker" version = "1.0.0" @@ -250,12 +267,24 @@ version = "1.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d468802bab17cbc0cc575e9b053f41e72aa36bfa6b7f55e3529ffa43161b97fa" +[[package]] +name = "base16ct" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "349a06037c7bf932dd7e7d1f653678b2038b9ad46a74102f1fc7bd7872678cce" + [[package]] name = "base64" version = "0.13.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "904dfeac50f3cdaba28fc6f57fdcddb75f49ed61346676a78c4ffe55877802fd" +[[package]] +name = "base64ct" +version = "1.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dea908e7347a8c64e378c17e30ef880ad73e3b4498346b055c2c00ea342f3179" + [[package]] name = "bitflags" version = "1.3.2" @@ -271,6 +300,15 @@ dependencies = [ "generic-array", ] +[[package]] +name = "block-buffer" +version = "0.10.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0bf7fe51849ea569fd452f37822f606a5cabb684dc918707a0193fd4664ff324" +dependencies = [ 
+ "generic-array", +] + [[package]] name = "blocking" version = "1.2.0" @@ -313,6 +351,12 @@ version = "3.10.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "37ccbd214614c6783386c1af30caf03192f17891059cecc394b4fb119e363de3" +[[package]] +name = "byteorder" +version = "1.4.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "14c189c53d098945499cdfa7ecc63567cf3886b3332b312a5b4585d8d3a6a610" + [[package]] name = "bytes" version = "1.1.0" @@ -351,21 +395,20 @@ checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd" [[package]] name = "chacha20" -version = "0.8.1" +version = "0.9.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "01b72a433d0cf2aef113ba70f62634c56fddb0f244e6377185c56a7cadbd8f91" +checksum = "c7fc89c7c5b9e7a02dfe45cd2367bae382f9ed31c61ca8debe5f827c420a2f08" dependencies = [ "cfg-if 1.0.0", "cipher", "cpufeatures", - "zeroize", ] [[package]] name = "chacha20poly1305" -version = "0.9.0" +version = "0.10.0-pre" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3b84ed6d1d5f7aa9bdde921a5090e0ca4d934d250ea3b402a5fab3a994e28a2a" +checksum = "746c430f71e66469abcf493c11484b1a86b957c84fc2d0ba664cd12ac23679ea" dependencies = [ "aead", "chacha20", @@ -416,26 +459,71 @@ dependencies = [ [[package]] name = "cipher" -version = "0.3.0" +version = "0.4.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7ee52072ec15386f770805afd189a01c8841be8696bed250fa2f13c4c0d6dfb7" +checksum = "d1873270f8f7942c191139cb8a40fd228da6c3fd2fc376d7e92d47aa14aeb59e" dependencies = [ - "generic-array", + "crypto-common", + "inout", + "zeroize", ] [[package]] name = "clap" -version = "2.34.0" +version = "3.2.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a0610544180c38b88101fecf2dd634b174a62eef6946f84dfc6a7127512b381c" +checksum = 
"d53da17d37dba964b9b3ecb5c5a1f193a2762c700e6829201e645b9381c99dc7" dependencies = [ - "ansi_term", "atty", "bitflags", + "clap_derive", + "clap_lex", + "indexmap", + "once_cell", "strsim", + "termcolor", "textwrap", - "unicode-width", - "vec_map", +] + +[[package]] +name = "clap_complete" +version = "3.2.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ead064480dfc4880a10764488415a97fdd36a4cf1bb022d372f02e8faf8386e1" +dependencies = [ + "clap", +] + +[[package]] +name = "clap_derive" +version = "3.2.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c11d40217d16aee8508cc8e5fde8b4ff24639758608e5374e731b53f85749fb9" +dependencies = [ + "heck 0.4.0", + "proc-macro-error", + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "clap_lex" +version = "0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5538cd660450ebeb4234cfecf8f2284b844ffc4c50531e66d584ad5b91293613" +dependencies = [ + "os_str_bytes", +] + +[[package]] +name = "combine" +version = "4.6.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2a604e93b79d1808327a6fca85a6f2d69de66461e7620f5a4cbf5fb4d1d7c948" +dependencies = [ + "bytes", + "memchr", ] [[package]] @@ -447,6 +535,41 @@ dependencies = [ "cache-padded", ] +[[package]] +name = "config" +version = "0.13.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3ea917b74b6edfb5024e3b55d3c8f710b5f4ed92646429601a42e96f0812b31b" +dependencies = [ + "async-trait", + "lazy_static", + "nom 7.1.1", + "pathdiff", + "serde", + "toml", +] + +[[package]] +name = "console" +version = "0.15.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a28b32d32ca44b70c3e4acd7db1babf555fa026e385fb95f18028f88848b3c31" +dependencies = [ + "encode_unicode", + "libc", + "once_cell", + "regex", + "terminal_size", + "unicode-width", + "winapi", +] + +[[package]] +name = "const-oid" +version = 
"0.9.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "722e23542a15cea1f65d4a1419c4cfd7a26706c70871a13a04238ca3f40f1661" + [[package]] name = "convert_case" version = "0.4.0" @@ -513,6 +636,28 @@ dependencies = [ "lazy_static", ] +[[package]] +name = "crypto-bigint" +version = "0.4.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ac961631d66e80ac7ac2ac01320628ce214ad2b5ef0a88ceb86eae459069e2b4" +dependencies = [ + "generic-array", + "rand_core 0.6.3", + "subtle", + "zeroize", +] + +[[package]] +name = "crypto-common" +version = "0.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "57952ca27b5e3606ff4dd79b0020231aaf9d6aa76dc05fd30137538c50bd3ce8" +dependencies = [ + "generic-array", + "typenum", +] + [[package]] name = "ctor" version = "0.1.22" @@ -523,6 +668,17 @@ dependencies = [ "syn", ] +[[package]] +name = "der" +version = "0.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "13dd2ae565c0a381dde7fade45fce95984c568bdcb4700a4fdbe3175e0380b2f" +dependencies = [ + "const-oid", + "pem-rfc7468", + "zeroize", +] + [[package]] name = "derive_more" version = "0.99.17" @@ -535,6 +691,15 @@ dependencies = [ "syn", ] +[[package]] +name = "dialoguer" +version = "0.10.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d8c8ae48e400addc32a8710c8d62d55cb84249a7d58ac4cd959daecfbaddc545" +dependencies = [ + "console", +] + [[package]] name = "difflib" version = "0.4.0" @@ -550,6 +715,26 @@ dependencies = [ "generic-array", ] +[[package]] +name = "digest" +version = "0.10.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f2fb860ca6fafa5552fb6d0e816a69c8e49f0908bf524e30a90d97c85892d506" +dependencies = [ + "block-buffer 0.10.2", + "crypto-common", + "subtle", +] + +[[package]] +name = "directories" +version = "4.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum 
= "f51c5d4ddabd36886dd3e1438cb358cdcb0d7c499cb99cb4ac2e38e18b5cb210" +dependencies = [ + "dirs-sys", +] + [[package]] name = "dirs" version = "2.0.2" @@ -560,6 +745,15 @@ dependencies = [ "dirs-sys", ] +[[package]] +name = "dirs" +version = "4.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ca3aa72a6f96ea37bbc5aa912f6788242832f75369bdfdadcb0e38423f100059" +dependencies = [ + "dirs-sys", +] + [[package]] name = "dirs-next" version = "2.0.0" @@ -594,11 +788,18 @@ dependencies = [ [[package]] name = "distant" -version = "0.16.4" +version = "0.17.0" dependencies = [ + "anyhow", "assert_cmd", "assert_fs", + "async-trait", + "clap", + "clap_complete", + "config", "derive_more", + "dialoguer", + "directories", "distant-core", "distant-ssh2", "flexi_logger 0.18.1", @@ -608,28 +809,37 @@ dependencies = [ "once_cell", "predicates", "rand 0.8.5", + "rpassword", "rstest 0.11.0", "serde", "serde_json", - "structopt", - "strum", + "service-manager", + "shell-words", "sysinfo", + "tabled", "terminal_size", "termwiz", "tokio", + "toml_edit", + "uriparse", + "which", "whoami", + "windows-service", + "winsplit", ] [[package]] name = "distant-core" -version = "0.16.4" +version = "0.17.0" dependencies = [ "assert_fs", + "async-trait", "bitflags", "bytes", - "chacha20poly1305", "ciborium", + "clap", "derive_more", + "distant-net", "flexi_logger 0.22.5", "futures", "hex", @@ -642,28 +852,64 @@ dependencies = [ "predicates", "rand 0.8.5", "rstest 0.13.0", + "schemars", "serde", + "serde_bytes", "serde_json", - "structopt", + "shell-words", "strum", "tokio", "tokio-util", + "uriparse", "walkdir", + "winsplit", +] + +[[package]] +name = "distant-net" +version = "0.17.0" +dependencies = [ + "async-trait", + "bytes", + "chacha20poly1305", + "derive_more", + "futures", + "hex", + "hkdf", + "log", + "p256", + "paste", + "rand 0.8.5", + "rmp-serde", + "schemars", + "serde", + "serde_bytes", + "sha2 0.10.2", + "tempfile", + "tokio", + "tokio-util", ] 
[[package]] name = "distant-ssh2" -version = "0.16.4" +version = "0.17.0" dependencies = [ + "anyhow", "assert_cmd", "assert_fs", "async-compat", + "async-once-cell", + "async-trait", + "derive_more", "distant-core", + "dunce", "flexi_logger 0.19.6", "futures", + "hex", "indoc", "log", "once_cell", + "openssl-src", "predicates", "rand 0.8.5", "rpassword", @@ -673,7 +919,9 @@ dependencies = [ "smol", "tokio", "wezterm-ssh", + "which", "whoami", + "winsplit", ] [[package]] @@ -682,12 +930,78 @@ version = "0.3.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "fea41bba32d969b513997752735605054bc0dfa92b4c56bf1189f2e174be7a10" +[[package]] +name = "dunce" +version = "1.0.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "453440c271cf5577fd2a40e4942540cb7d0d2f85e27c8d07dd0023c925a67541" + +[[package]] +name = "dyn-clone" +version = "1.0.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4f94fa09c2aeea5b8839e414b7b841bf429fd25b9c522116ac97ee87856d88b2" + +[[package]] +name = "ecdsa" +version = "0.14.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e1e737f9eebb44576f3ee654141a789464071eb369d02c4397b32b6a79790112" +dependencies = [ + "der", + "elliptic-curve", + "rfc6979", + "signature", +] + [[package]] name = "either" version = "1.6.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e78d4f1cc4ae33bbfc157ed5d5a5ef3bc29227303d595861deb238fcec4e9457" +[[package]] +name = "elliptic-curve" +version = "0.12.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4f6664c6a37892ed55da8dda26a99e6ccc783f0c72fa3c2eeaa00ed30d8f4d9a" +dependencies = [ + "base16ct", + "crypto-bigint", + "der", + "digest 0.10.3", + "ff", + "generic-array", + "group", + "hkdf", + "pem-rfc7468", + "pkcs8", + "rand_core 0.6.3", + "sec1", + "subtle", + "zeroize", +] + +[[package]] +name = "encode_unicode" +version = "0.3.6" +source 
= "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a357d28ed41a50f9c765dbfe56cbc04a64e53e5fc58ba79fbc34c10ef3df831f" + +[[package]] +name = "err-derive" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c34a887c8df3ed90498c1c437ce21f211c8e27672921a8ffa293cb8d6d4caa9e" +dependencies = [ + "proc-macro-error", + "proc-macro2", + "quote", + "rustversion", + "syn", + "synstructure", +] + [[package]] name = "event-listener" version = "2.5.2" @@ -703,6 +1017,16 @@ dependencies = [ "instant", ] +[[package]] +name = "ff" +version = "0.12.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "df689201f395c6b90dfe87127685f8dbfc083a5e779e613575d8bd7314300c3e" +dependencies = [ + "rand_core 0.6.3", + "subtle", +] + [[package]] name = "filedescriptor" version = "0.8.2" @@ -1005,12 +1329,29 @@ dependencies = [ "wasm-bindgen", ] +[[package]] +name = "group" +version = "0.12.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7391856def869c1c81063a03457c676fbcd419709c3dfb33d8d319de484b154d" +dependencies = [ + "ff", + "rand_core 0.6.3", + "subtle", +] + [[package]] name = "half" version = "1.8.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "eabb4a44450da02c90444cf74558da904edde8fb4e9035a9a6a4e15445af0bd7" +[[package]] +name = "hashbrown" +version = "0.12.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "db0d4cf898abf0081f964436dc980e96670a0f36863e4b83aaacdb65c9d7ccc3" + [[package]] name = "heck" version = "0.3.3" @@ -1020,6 +1361,12 @@ dependencies = [ "unicode-segmentation", ] +[[package]] +name = "heck" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2540771e65fc8cb83cd6e8a237f70c319bd5c29f78ed1084ba5d50eeac86f7f9" + [[package]] name = "hermit-abi" version = "0.1.19" @@ -1035,6 +1382,24 @@ version = "0.4.3" source = 
"registry+https://github.com/rust-lang/crates.io-index" checksum = "7f24254aa9a54b5c858eaee2f5bccdb46aaf0e486a595ed5fd8f86ba55232a70" +[[package]] +name = "hkdf" +version = "0.12.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "791a029f6b9fc27657f6f188ec6e5e43f6911f6f878e0dc5501396e09809d437" +dependencies = [ + "hmac", +] + +[[package]] +name = "hmac" +version = "0.12.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6c49c37c09c17a53d937dfbb742eb3a961d65a994e6bcdcf37e7399d0cc8ab5e" +dependencies = [ + "digest 0.10.3", +] + [[package]] name = "ignore" version = "0.4.18" @@ -1053,6 +1418,16 @@ dependencies = [ "winapi-util", ] +[[package]] +name = "indexmap" +version = "1.9.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6c6392766afd7964e2531940894cffe4bd8d7d17dbc3c1c4857040fd4b33bdb3" +dependencies = [ + "autocfg", + "hashbrown", +] + [[package]] name = "indoc" version = "1.0.6" @@ -1079,6 +1454,15 @@ dependencies = [ "libc", ] +[[package]] +name = "inout" +version = "0.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a0c10553d664a4d0bcff9f4215d0aac67a639cc68ef660840afe309b807bc9f5" +dependencies = [ + "generic-array", +] + [[package]] name = "instant" version = "0.1.12" @@ -1252,6 +1636,12 @@ dependencies = [ "autocfg", ] +[[package]] +name = "minimal-lexical" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "68354c5c6bd36d73ff3feceb05efa59b6acb7626617f4962be322a825e61f79a" + [[package]] name = "mio" version = "0.8.3" @@ -1274,6 +1664,16 @@ dependencies = [ "version_check", ] +[[package]] +name = "nom" +version = "7.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a8903e5a29a317527874d0402f867152a3d21c908bb0b933e416c65e301d4c36" +dependencies = [ + "memchr", + "minimal-lexical", +] + [[package]] name = "normalize-line-endings" version = "0.3.0" @@ 
-1410,6 +1810,32 @@ dependencies = [ "num-traits", ] +[[package]] +name = "os_str_bytes" +version = "6.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "21326818e99cfe6ce1e524c2a805c189a99b5ae555a35d19f9a284b427d86afa" + +[[package]] +name = "p256" +version = "0.11.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "51f44edd08f51e2ade572f141051021c5af22677e42b7dd28a88155151c33594" +dependencies = [ + "ecdsa", + "elliptic-curve", + "sha2 0.10.2", +] + +[[package]] +name = "papergrid" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "608b6444acf7f5ea39e8bd06dd6037e34a4b5ddfb29ae840edad49ea798e9e79" +dependencies = [ + "unicode-width", +] + [[package]] name = "parking" version = "2.0.0" @@ -1464,6 +1890,27 @@ dependencies = [ "windows-sys", ] +[[package]] +name = "paste" +version = "1.0.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0c520e05135d6e763148b6426a837e239041653ba7becd2e538c076c738025fc" + +[[package]] +name = "pathdiff" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8835116a5c179084a830efb3adc117ab007512b535bc1a21c991d3b32a6b44dd" + +[[package]] +name = "pem-rfc7468" +version = "0.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "24d159833a9105500e0398934e205e0773f0b27529557134ecfc51c27646adac" +dependencies = [ + "base64ct", +] + [[package]] name = "pest" version = "2.1.3" @@ -1523,6 +1970,16 @@ version = "0.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8b870d8c151b6f2fb93e84a13146138f05d02ed11c7e7c54f8826aaaf7c9f184" +[[package]] +name = "pkcs8" +version = "0.9.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9eca2c590a5f85da82668fa685c09ce2888b9430e83299debf1f34b65fd4a4ba" +dependencies = [ + "der", + "spki", +] + [[package]] name = "pkg-config" version = 
"0.3.25" @@ -1806,13 +2263,48 @@ dependencies = [ "winapi", ] +[[package]] +name = "rfc6979" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6c0788437d5ee113c49af91d3594ebc4fcdcc962f8b6df5aa1c3eeafd8ad95de" +dependencies = [ + "crypto-bigint", + "hmac", + "zeroize", +] + +[[package]] +name = "rmp" +version = "0.8.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "44519172358fd6d58656c86ab8e7fbc9e1490c3e8f14d35ed78ca0dd07403c9f" +dependencies = [ + "byteorder", + "num-traits", + "paste", +] + +[[package]] +name = "rmp-serde" +version = "1.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "25786b0d276110195fa3d6f3f31299900cf71dfbd6c28450f3f58a0e7f7a347e" +dependencies = [ + "byteorder", + "rmp", + "serde", +] + [[package]] name = "rpassword" -version = "5.0.1" +version = "6.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ffc936cf8a7ea60c58f030fd36a612a48f440610214dc54bc36431f9ea0c3efb" +checksum = "2bf099a1888612545b683d2661a1940089f6c2e5a8e38979b2159da876bfd956" dependencies = [ "libc", + "serde", + "serde_json", "winapi", ] @@ -1885,12 +2377,50 @@ dependencies = [ "winapi-util", ] +[[package]] +name = "schemars" +version = "0.8.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1847b767a3d62d95cbf3d8a9f0e421cf57a0d8aa4f411d4b16525afb0284d4ed" +dependencies = [ + "dyn-clone", + "schemars_derive", + "serde", + "serde_json", +] + +[[package]] +name = "schemars_derive" +version = "0.8.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "af4d7e1b012cb3d9129567661a63755ea4b8a7386d339dc945ae187e403c6743" +dependencies = [ + "proc-macro2", + "quote", + "serde_derive_internals", + "syn", +] + [[package]] name = "scopeguard" version = "1.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = 
"d29ab0c6d3fc0ee92fe66e2d99f700eab17a8d57d1c1d3b748380fb20baa78cd" +[[package]] +name = "sec1" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3be24c1842290c45df0a7bf069e0c268a747ad05a192f2fd7dcfdbc1cba40928" +dependencies = [ + "base16ct", + "der", + "generic-array", + "pkcs8", + "subtle", + "zeroize", +] + [[package]] name = "semver" version = "0.11.0" @@ -1924,6 +2454,15 @@ dependencies = [ "serde_derive", ] +[[package]] +name = "serde_bytes" +version = "0.11.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "212e73464ebcde48d723aa02eb270ba62eff38a9b732df31f33f1b4e145f3a54" +dependencies = [ + "serde", +] + [[package]] name = "serde_derive" version = "1.0.137" @@ -1935,6 +2474,17 @@ dependencies = [ "syn", ] +[[package]] +name = "serde_derive_internals" +version = "0.26.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "85bf8229e7920a9f636479437026331ce11aa132b4dde37d121944a44d6e5f3c" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + [[package]] name = "serde_json" version = "1.0.81" @@ -1988,19 +2538,42 @@ dependencies = [ "serial-core", ] +[[package]] +name = "service-manager" +version = "0.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "108229bdf3ea1da7a4ba32ad62c4101c4d324fcc009dd8ea001b211162cb18a8" +dependencies = [ + "clap", + "dirs 4.0.0", + "serde", + "which", +] + [[package]] name = "sha2" version = "0.9.9" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "4d58a1e1bf39749807d89cf2d98ac2dfa0ff1cb3faa38fbb64dd88ac8013d800" dependencies = [ - "block-buffer", + "block-buffer 0.9.0", "cfg-if 1.0.0", "cpufeatures", - "digest", + "digest 0.9.0", "opaque-debug", ] +[[package]] +name = "sha2" +version = "0.10.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "55deaec60f81eefe3cce0dc50bda92d6d8e88f2a27df7c5033b42afeb1ed2676" +dependencies = [ + 
"cfg-if 1.0.0", + "cpufeatures", + "digest 0.10.3", +] + [[package]] name = "shared_library" version = "0.1.9" @@ -2046,6 +2619,16 @@ dependencies = [ "libc", ] +[[package]] +name = "signature" +version = "1.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f054c6c1a6e95179d6f23ed974060dcefb2d9388bb7256900badad682c499de4" +dependencies = [ + "digest 0.10.3", + "rand_core 0.6.3", +] + [[package]] name = "siphasher" version = "0.3.10" @@ -2092,6 +2675,16 @@ dependencies = [ "winapi", ] +[[package]] +name = "spki" +version = "0.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "67cf02bbac7a337dc36e4f5a693db6c21e7863f45070f7064577eb4367a3212b" +dependencies = [ + "base64ct", + "der", +] + [[package]] name = "ssh2" version = "0.9.3" @@ -2106,33 +2699,9 @@ dependencies = [ [[package]] name = "strsim" -version = "0.8.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8ea5119cdb4c55b55d432abb513a0429384878c15dde60cc77b1c99de1a95a6a" - -[[package]] -name = "structopt" -version = "0.3.26" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0c6b5c64445ba8094a6ab0c3cd2ad323e07171012d9c98b0b15651daf1787a10" -dependencies = [ - "clap", - "lazy_static", - "structopt-derive", -] - -[[package]] -name = "structopt-derive" -version = "0.4.18" +version = "0.10.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "dcb5ae327f9cc13b68763b5749770cb9e048a99bd9dfdfa58d0cf05d5f64afe0" -dependencies = [ - "heck", - "proc-macro-error", - "proc-macro2", - "quote", - "syn", -] +checksum = "73473c0e59e6d5812c5dfe2a064a6444949f089e20eec9a2e5506596494e4623" [[package]] name = "strum" @@ -2149,7 +2718,7 @@ version = "0.21.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d06aaeeee809dbc59eb4556183dd927df67db1540de5be8d3ec0b6636358a5ec" dependencies = [ - "heck", + "heck 0.3.3", "proc-macro2", "quote", "syn", @@ -2172,11 
+2741,23 @@ dependencies = [ "unicode-ident", ] +[[package]] +name = "synstructure" +version = "0.12.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f36bdaa60a83aca3921b5259d5400cbf5e90fc51931376a9bd4a0eb79aa7210f" +dependencies = [ + "proc-macro2", + "quote", + "syn", + "unicode-xid", +] + [[package]] name = "sysinfo" -version = "0.23.13" +version = "0.24.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3977ec2e0520829be45c8a2df70db2bf364714d8a748316a10c3c35d4d2b01c9" +checksum = "54cb4ebf3d49308b99e6e9dc95e989e2fdbdc210e4f67c39db0bb89ba927001c" dependencies = [ "cfg-if 1.0.0", "core-foundation-sys", @@ -2187,6 +2768,27 @@ dependencies = [ "winapi", ] +[[package]] +name = "tabled" +version = "0.7.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2407502760ccfd538f2fb1f843dd87b6daf1a17848d57bc5a25617e408ef4c7a" +dependencies = [ + "papergrid", + "tabled_derive", +] + +[[package]] +name = "tabled_derive" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "278ea3921cee8c5a69e0542998a089f7a14fa43c9c4e4f9951295da89bd0c943" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + [[package]] name = "tempfile" version = "3.3.0" @@ -2201,6 +2803,15 @@ dependencies = [ "winapi", ] +[[package]] +name = "termcolor" +version = "1.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bab24d30b911b2376f3a13cc2cd443142f0c81dda04c118693e35b3835757755" +dependencies = [ + "winapi-util", +] + [[package]] name = "terminal_size" version = "0.1.17" @@ -2217,9 +2828,9 @@ version = "0.7.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "76971977e6121664ec1b960d1313aacfa75642adc93b9d4d53b247bd4cb1747e" dependencies = [ - "dirs", + "dirs 2.0.2", "fnv", - "nom", + "nom 5.1.2", "phf", "phf_codegen", ] @@ -2269,7 +2880,7 @@ dependencies = [ "ordered-float", "regex", "semver 0.11.0", - 
"sha2", + "sha2 0.9.9", "signal-hook 0.1.17", "terminfo", "termios 0.3.3", @@ -2282,12 +2893,9 @@ dependencies = [ [[package]] name = "textwrap" -version = "0.11.0" +version = "0.15.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d326610f408c7a4eb6f51c37c330e496b08506c9457c9d34287ecc38809fb060" -dependencies = [ - "unicode-width", -] +checksum = "b1141d4d61095b28419e22cb0bbf02755f5e54e0526f97f1e3d1d160e60885fb" [[package]] name = "thiserror" @@ -2391,6 +2999,27 @@ dependencies = [ "tokio", ] +[[package]] +name = "toml" +version = "0.5.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8d82e1a7758622a465f8cee077614c73484dac5b836c02ff6a40d5d1010324d7" +dependencies = [ + "serde", +] + +[[package]] +name = "toml_edit" +version = "0.14.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5376256e44f2443f8896ac012507c19a012df0fe8758b55246ae51a2279db51f" +dependencies = [ + "combine", + "indexmap", + "itertools", + "serde", +] + [[package]] name = "typenum" version = "1.15.0" @@ -2421,6 +3050,12 @@ version = "0.1.9" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3ed742d4ea2bd1176e236172c8429aaf54486e7ac098db29ffe6529e0ce50973" +[[package]] +name = "unicode-xid" +version = "0.2.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "957e51f3646910546462e67d5f7599b9e4fb8acdd304b087a6494730f9eebf04" + [[package]] name = "universal-hash" version = "0.4.1" @@ -2431,6 +3066,16 @@ dependencies = [ "subtle", ] +[[package]] +name = "uriparse" +version = "0.6.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0200d0fc04d809396c2ad43f3c95da3582a2556eba8d453c1087f4120ee352ff" +dependencies = [ + "fnv", + "lazy_static", +] + [[package]] name = "utf8parse" version = "0.2.0" @@ -2453,12 +3098,6 @@ version = "0.2.15" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = 
"accd4ea62f7bb7a82fe23066fb0957d48ef677f6eeb8215f372f52e48bb32426" -[[package]] -name = "vec_map" -version = "0.8.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f1bddf1187be692e79c5ffeab891132dfb0f236ed36a43c7ed39f1165ee20191" - [[package]] name = "version_check" version = "0.9.4" @@ -2626,6 +3265,17 @@ dependencies = [ "thiserror", ] +[[package]] +name = "which" +version = "4.2.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5c4fb54e6113b6a8772ee41c3404fb0301ac79604489467e0a9ce1f3e97c24ae" +dependencies = [ + "either", + "lazy_static", + "libc", +] + [[package]] name = "whoami" version = "1.2.1" @@ -2636,6 +3286,12 @@ dependencies = [ "web-sys", ] +[[package]] +name = "widestring" +version = "1.0.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "653f141f39ec16bba3c5abe400a0c60da7468261cc2cbf36805022876bc721a8" + [[package]] name = "winapi" version = "0.3.9" @@ -2667,6 +3323,18 @@ version = "0.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "712e227841d057c1ee1cd2fb22fa7e5a5461ae8e48fa2ca79ec42cfc1931183f" +[[package]] +name = "windows-service" +version = "0.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "917fdb865e7ff03af9dd86609f8767bc88fefba89e8efd569de8e208af8724b3" +dependencies = [ + "bitflags", + "err-derive", + "widestring", + "windows-sys", +] + [[package]] name = "windows-sys" version = "0.36.1" @@ -2710,6 +3378,12 @@ version = "0.36.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c811ca4a8c853ef420abd8592ba53ddbbac90410fab6903b3e79972a631f7680" +[[package]] +name = "winsplit" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3ab703352da6a72f35c39a533526393725640575bb211f61987a2748323ad956" + [[package]] name = "yansi" version = "0.5.1" @@ -2718,6 +3392,6 @@ checksum = 
"09041cd90cf85f7f8b2df60c646f853b7f535ce68f85244eb6731cf89fa498ec" [[package]] name = "zeroize" -version = "1.4.3" +version = "1.5.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d68d9dcec5f9b43a30d38c49f91dfedfaac384cb8f085faca366c26207dd1619" +checksum = "94693807d016b2f2d2e14420eb3bfcca689311ff775dcf113d74ea624b7cdf07" diff --git a/Cargo.toml b/Cargo.toml index 0e5f56e..7785795 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -3,16 +3,16 @@ name = "distant" description = "Operate on a remote computer through file and process manipulation" categories = ["command-line-utilities"] keywords = ["cli"] -version = "0.16.4" +version = "0.17.0" authors = ["Chip Senkbeil "] -edition = "2018" +edition = "2021" homepage = "https://github.com/chipsenkbeil/distant" repository = "https://github.com/chipsenkbeil/distant" readme = "README.md" license = "MIT OR Apache-2.0" [workspace] -members = ["distant-core", "distant-ssh2"] +members = ["distant-core", "distant-net", "distant-ssh2"] [profile.release] opt-level = 'z' @@ -25,31 +25,44 @@ libssh = ["distant-ssh2/libssh"] ssh2 = ["distant-ssh2/ssh2"] [dependencies] +anyhow = "1.0.58" +async-trait = "0.1.56" +clap = { version = "3.2.5", features = ["derive"] } +clap_complete = "3.2.3" +config = { version = "0.13.1", default-features = false, features = ["toml"] } derive_more = { version = "0.99.17", default-features = false, features = ["display", "from", "error", "is_variant"] } -distant-core = { version = "=0.16.4", path = "distant-core", features = ["structopt"] } +dialoguer = { version = "0.10.1", default-features = false } +distant-core = { version = "=0.17.0", path = "distant-core", features = ["clap", "schemars"] } +directories = "4.0.1" flexi_logger = "0.18.1" indoc = "1.0.6" log = "0.4.17" once_cell = "1.12.0" rand = { version = "0.8.5", features = ["getrandom"] } +rpassword = "6.0.1" serde = { version = "1.0.137", features = ["derive"] } serde_json = "1.0.81" -structopt = "0.3.26" -strum = { 
version = "0.21.0", features = ["derive"] } -sysinfo = "0.23.13" +shell-words = "1.0" +service-manager = { version = "0.1.2", features = ["clap", "serde"] } +tabled = "0.7.0" tokio = { version = "1.19.0", features = ["full"] } +toml_edit = { version = "0.14.4", features = ["serde"] } terminal_size = "0.1.17" termwiz = "0.15.0" +uriparse = "0.6.4" +which = "4.2.5" +winsplit = "0.1" whoami = "1.2.1" # Optional native SSH functionality -distant-ssh2 = { version = "=0.16.4", path = "distant-ssh2", default-features = false, features = ["serde"], optional = true } +distant-ssh2 = { version = "=0.17.0", path = "distant-ssh2", default-features = false, features = ["serde"], optional = true } [target.'cfg(unix)'.dependencies] fork = "0.1.19" -# [target.'cfg(windows)'.dependencies] -# sysinfo = "0.23.2" +[target.'cfg(windows)'.dependencies] +sysinfo = "0.24.7" +windows-service = "0.5" [dev-dependencies] assert_cmd = "2.0.4" diff --git a/README.md b/README.md index 64a9408..d31c573 100644 --- a/README.md +++ b/README.md @@ -1,26 +1,15 @@ # distant - remotely edit files and run programs -[![Crates.io][distant_crates_img]][distant_crates_lnk] [![Docs.rs][distant_doc_img]][distant_doc_lnk] [![RustC 1.51+][distant_rustc_img]][distant_rustc_lnk] - -| Operating System | Status | -| ---------------- | ------------------------------------------------------------------ | -| MacOS (x86, ARM) | [![MacOS CI][distant_ci_macos_img]][distant_ci_macos_lnk] | -| Linux (x86) | [![Linux CI][distant_ci_linux_img]][distant_ci_linux_lnk] | -| Windows (x86) | [![Windows CI][distant_ci_windows_img]][distant_ci_windows_lnk] | +[![Crates.io][distant_crates_img]][distant_crates_lnk] [![Docs.rs][distant_doc_img]][distant_doc_lnk] [![CI][distant_ci_img]][distant_ci_lnk] [![RustC 1.61+][distant_rustc_img]][distant_rustc_lnk] [distant_crates_img]: https://img.shields.io/crates/v/distant.svg [distant_crates_lnk]: https://crates.io/crates/distant [distant_doc_img]: https://docs.rs/distant/badge.svg 
[distant_doc_lnk]: https://docs.rs/distant -[distant_rustc_img]: https://img.shields.io/badge/distant-rustc_1.51+-lightgray.svg -[distant_rustc_lnk]: https://blog.rust-lang.org/2021/03/25/Rust-1.51.0.html - -[distant_ci_macos_img]: https://github.com/chipsenkbeil/distant/actions/workflows/ci-macos.yml/badge.svg -[distant_ci_macos_lnk]: https://github.com/chipsenkbeil/distant/actions/workflows/ci-macos.yml -[distant_ci_linux_img]: https://github.com/chipsenkbeil/distant/actions/workflows/ci-linux.yml/badge.svg -[distant_ci_linux_lnk]: https://github.com/chipsenkbeil/distant/actions/workflows/ci-linux.yml -[distant_ci_windows_img]: https://github.com/chipsenkbeil/distant/actions/workflows/ci-windows.yml/badge.svg -[distant_ci_windows_lnk]: https://github.com/chipsenkbeil/distant/actions/workflows/ci-windows.yml +[distant_ci_img]: https://github.com/chipsenkbeil/distant/actions/workflows/ci.yml/badge.svg +[distant_ci_lnk]: https://github.com/chipsenkbeil/distant/actions/workflows/ci.yml +[distant_rustc_img]: https://img.shields.io/badge/distant-rustc_1.61+-lightgray.svg +[distant_rustc_lnk]: https://blog.rust-lang.org/2022/05/19/Rust-1.61.0.html 🚧 **(Alpha stage software) This program is in rapid development and may break or change frequently!** 🚧 @@ -31,7 +20,7 @@ a command to start a server and configure the local client to be able to talk to the server. - Asynchronous in nature, powered by [`tokio`](https://tokio.rs/) -- Data is serialized to send across the wire via [`CBOR`](https://cbor.io/) +- Data is serialized to send across the wire via [`msgpack`](https://msgpack.org/) - Encryption & authentication are handled via [XChaCha20Poly1305](https://tools.ietf.org/html/rfc8439) for an authenticated encryption scheme via @@ -59,29 +48,83 @@ cargo install distant Alternatively, you can clone this repository and build from source following the [build guide](./BUILDING.md). 
-## Examples +## Example + +### Starting the manager + +In order to facilitate communication between a client and server, you first +need to start the manager. This can be done in one of two ways: + +1. Leverage the `service` functionality to spawn the manager using one of the + following supported service management platforms: + - [`sc.exe`](https://docs.microsoft.com/en-us/previous-versions/windows/it-pro/windows-server-2012-r2-and-2012/cc754599(v=ws.11)) for use with [Window Service](https://en.wikipedia.org/wiki/Windows_service) (Windows) + - [Launchd](https://en.wikipedia.org/wiki/Launchd) (MacOS) + - [systemd](https://en.wikipedia.org/wiki/Systemd) (Linux) + - [OpenRC](https://en.wikipedia.org/wiki/OpenRC) (Linux) + - [rc.d](https://en.wikipedia.org/wiki/Init#Research_Unix-style/BSD-style) (FreeBSD) +2. Run the manager manually by using the `listen` subcommand -Launch a remote instance of `distant`. Calling `launch` will do the following: +#### Service management -1. Ssh into the specified host (in the below example, `my.example.com`) -2. Execute `distant listen --host ssh` on the remote machine -3. Receive on the local machine the credentials needed to connect to the server -4. Depending on the options specified, print/store/use the session settings so - future calls to `distant action` can connect +```bash +# If you want to install the manager as a service, you can use the service +# interface available directly from the CLI +# +# By default, this will install a system-level service, which means that you +# will need elevated permissions to both install AND communicate with the +# manager +distant manager service install + +# If you want to maintain a user-level manager service, you can include the +# --user flag. Note that this is only supported on MacOS (via launchd) and +# Linux (via systemd) +distant manager service install --user + +# ........ 
+ +# Once you have installed the service, you will normally need to start it +# manually or restart your machine to trigger startup on boot +distant manager service start # --user if you are working with user-level +``` + +#### Manual start ```bash -# Connects to my.example.com on port 22 via SSH to start a new session -# and print out information to configure your system to talk to it -distant launch my.example.com - -# NOTE: If you are using sh, bash, or zsh, you can automatically set the -# appropriate environment variables using the following -eval "$(distant launch my.example.com)" - -# After the session is established, you can perform different operations -# on the remote machine via `distant action {command} [args]` -distant action copy path/to/file new/path/to/file -distant action spawn -- echo 'Hello, this is from the other side' +# If you choose to run the manager without a service management platform, you +# can either run the manager in the foreground or provide --daemon to spawn and +# detach the manager + +# Run in the foreground +distant manager listen + +# Detach the manager where it will not terminate even if the parent exits +distant manager listen --daemon +``` + +### Interacting with a remote machine + +Once you have a manager listening for client requests, you can begin +interacting with the manager, spawn and/or connect to servers, and interact +with remote machines. 
+ +```bash +# Connect to my.example.com on port 22 via SSH and start a distant server +distant client launch ssh://my.example.com + +# After the connection is established, you can perform different operations +# on the remote machine via `distant client action {command} [args]` +distant client action copy path/to/file new/path/to/file +distant client action spawn -- echo 'Hello, this is from the other side' + +# Opening a shell to the remote machine is trivial +distant client shell + +# If you have more than one connection open, you can switch between active +# connections by using the `select` subcommand +distant client select '' + +# For programmatic use, a REPL following the JSON API is available +distant client repl --format json ``` ## License diff --git a/distant-core/Cargo.toml b/distant-core/Cargo.toml index 91a4ab8..a4dab2f 100644 --- a/distant-core/Cargo.toml +++ b/distant-core/Cargo.toml @@ -3,20 +3,24 @@ name = "distant-core" description = "Core library for distant, enabling operation on a remote computer through file and process manipulation" categories = ["network-programming"] keywords = ["api", "async"] -version = "0.16.4" +version = "0.17.0" authors = ["Chip Senkbeil "] -edition = "2018" +edition = "2021" homepage = "https://github.com/chipsenkbeil/distant" repository = "https://github.com/chipsenkbeil/distant" readme = "README.md" license = "MIT OR Apache-2.0" +[features] +schemars = ["dep:schemars", "distant-net/schemars"] + [dependencies] +async-trait = "0.1.56" bitflags = "1.3.2" bytes = "1.1.0" -chacha20poly1305 = "0.9.0" ciborium = "0.2.0" -derive_more = { version = "0.99.16", default-features = false, features = ["deref", "deref_mut", "display", "from", "error", "into_iterator", "is_variant"] } +derive_more = { version = "0.99.17", default-features = false, features = ["as_mut", "as_ref", "deref", "deref_mut", "display", "from", "error", "into", "into_iterator", "is_variant", "try_into"] } +distant-net = { version = "=0.17.0", path = 
"../distant-net" } futures = "0.3.16" hex = "0.4.3" log = "0.4.14" @@ -26,14 +30,19 @@ once_cell = "1.8.0" portable-pty = "0.7.0" rand = { version = "0.8.4", features = ["getrandom"] } serde = { version = "1.0.126", features = ["derive"] } +serde_bytes = "0.11.6" serde_json = "1.0.64" +shell-words = "1.0" strum = { version = "0.21.0", features = ["derive"] } tokio = { version = "1.12.0", features = ["full"] } tokio-util = { version = "0.6.7", features = ["codec"] } +uriparse = "0.6.4" walkdir = "2.3.2" +winsplit = "0.1" # Optional dependencies based on features -structopt = { version = "0.3.22", optional = true } +clap = { version = "3.2.5", features = ["derive"], optional = true } +schemars = { version = "0.8.10", optional = true } [dev-dependencies] assert_fs = "1.0.4" diff --git a/distant-core/README.md b/distant-core/README.md index e4c5b5d..a0942bb 100644 --- a/distant-core/README.md +++ b/distant-core/README.md @@ -1,13 +1,13 @@ # distant core -[![Crates.io][distant_crates_img]][distant_crates_lnk] [![Docs.rs][distant_doc_img]][distant_doc_lnk] [![Rustc 1.51.0][distant_rustc_img]][distant_rustc_lnk] +[![Crates.io][distant_crates_img]][distant_crates_lnk] [![Docs.rs][distant_doc_img]][distant_doc_lnk] [![Rustc 1.61.0][distant_rustc_img]][distant_rustc_lnk] [distant_crates_img]: https://img.shields.io/crates/v/distant-core.svg [distant_crates_lnk]: https://crates.io/crates/distant-core [distant_doc_img]: https://docs.rs/distant-core/badge.svg [distant_doc_lnk]: https://docs.rs/distant-core -[distant_rustc_img]: https://img.shields.io/badge/distant_core-rustc_1.51+-lightgray.svg -[distant_rustc_lnk]: https://blog.rust-lang.org/2021/03/25/Rust-1.51.0.html +[distant_rustc_img]: https://img.shields.io/badge/distant_core-rustc_1.61+-lightgray.svg +[distant_rustc_lnk]: https://blog.rust-lang.org/2022/05/19/Rust-1.61.0.html Library that powers the [`distant`](https://github.com/chipsenkbeil/distant) binary. @@ -16,15 +16,11 @@ binary. 
## Details -The `distant` library supplies a mixture of functionality and data to run -servers that operate on remote machines and clients that talk to them. - -- Asynchronous in nature, powered by [`tokio`](https://tokio.rs/) -- Data is serialized to send across the wire via [`CBOR`](https://cbor.io/) -- Encryption & authentication are handled via - [XChaCha20Poly1305](https://tools.ietf.org/html/rfc8439) for an authenticated - encryption scheme via - [RustCrypto/ChaCha20Poly1305](https://github.com/RustCrypto/AEADs/tree/master/chacha20poly1305) +The `distant-core` library supplies the client, manager, and server +implementations for use with the distant API in order to communicate with +remote machines and perform actions. This library acts as the primary +implementation that powers the CLI, but is also available for other extensions +like `distant-ssh2`. ## Installation @@ -32,40 +28,43 @@ You can import the dependency by adding the following to your `Cargo.toml`: ```toml [dependencies] -distant-core = "0.16" +distant-core = "0.17" ``` ## Features Currently, the library supports the following features: -- `structopt`: generates [`StructOpt`](https://github.com/TeXitoi/structopt) - bindings for `RequestData` (used by cli to expose request actions) +- `clap`: generates [`Clap`](https://github.com/clap-rs) bindings for + `DistantRequestData` (used by cli to expose request actions) +- `schemars`: derives the `schemars::JsonSchema` interface on + `DistantMsg`, `DistantRequestData`, and `DistantResponseData` data types By default, no features are enabled on the library. 
## Examples -Below is an example of connecting to a distant server over TCP: +Below is an example of connecting to a distant server over TCP without any +encryption or authentication: ```rust -use distant_core::{Session, SessionChannelExt, SecretKey32, XChaCha20Poly1305Codec}; -use std::net::SocketAddr; - -// 32-byte secret key paresd from hex, used for a specific codec -let key: SecretKey32 = "DEADBEEFDEADBEEFDEADBEEFDEADBEEFDEADBEEFDEADBEEFDEADBEEFDEADBEEF".parse().unwrap(); -let codec = XChaCha20Poly1305Codec::from(key); - +use distant_core::{ + DistantClient, + DistantChannelExt, + net::{PlainCodec, TcpClientExt}, +}; +use std::{net::SocketAddr, path::Path}; + +// Connect to a server located at example.com on port 8080 that is using +// no encryption or authentication (PlainCodec) let addr: SocketAddr = "example.com:8080".parse().unwrap(); -let mut session = Session::tcp_connect(addr, codec).await.unwrap(); - -// Append text to a file, representing request as -// NOTE: This method comes from SessionChannelExt -session.append_file_text( - "", - "path/to/file.txt".to_string(), - "new contents" -).await.expect("Failed to append to file"); +let mut client = DistantClient::connect(addr, PlainCodec).await + .expect("Failed to connect"); + +// Append text to a file +// NOTE: This method comes from DistantChannelExt +client.append_file_text(Path::new("path/to/file.txt"), "new contents").await + .expect("Failed to append to file"); ``` ## License diff --git a/distant-core/src/api.rs b/distant-core/src/api.rs new file mode 100644 index 0000000..f1e5546 --- /dev/null +++ b/distant-core/src/api.rs @@ -0,0 +1,626 @@ +use crate::{ + data::{ChangeKind, DirEntry, Environment, Error, Metadata, ProcessId, PtySize, SystemInfo}, + ConnectionId, DistantMsg, DistantRequestData, DistantResponseData, +}; +use async_trait::async_trait; +use distant_net::{Reply, Server, ServerCtx}; +use log::*; +use std::{io, path::PathBuf, sync::Arc}; + +mod local; +pub use local::LocalDistantApi; + 
+mod reply; +use reply::DistantSingleReply; + +/// Represents the context provided to the [`DistantApi`] for incoming requests +pub struct DistantCtx { + pub connection_id: ConnectionId, + pub reply: Box>, + pub local_data: Arc, +} + +/// Represents a server that leverages an API compliant with `distant` +pub struct DistantApiServer +where + T: DistantApi, +{ + api: T, +} + +impl DistantApiServer +where + T: DistantApi, +{ + pub fn new(api: T) -> Self { + Self { api } + } +} + +impl DistantApiServer::LocalData> { + /// Creates a new server using the [`LocalDistantApi`] implementation + pub fn local() -> io::Result { + Ok(Self { + api: LocalDistantApi::initialize()?, + }) + } +} + +#[inline] +fn unsupported(label: &str) -> io::Result { + Err(io::Error::new( + io::ErrorKind::Unsupported, + format!("{} is unsupported", label), + )) +} + +/// Interface to support the suite of functionality available with distant, +/// which can be used to build other servers that are compatible with distant +#[async_trait] +pub trait DistantApi { + type LocalData: Send + Sync; + + /// Invoked whenever a new connection is established, providing a mutable reference to the + /// newly-created local data. This is a way to support modifying local data before it is used. + #[allow(unused_variables)] + async fn on_accept(&self, local_data: &mut Self::LocalData) {} + + /// Reads bytes from a file. + /// + /// * `path` - the path to the file + /// + /// *Override this, otherwise it will return "unsupported" as an error.* + #[allow(unused_variables)] + async fn read_file( + &self, + ctx: DistantCtx, + path: PathBuf, + ) -> io::Result> { + unsupported("read_file") + } + + /// Reads bytes from a file as text. 
+ /// + /// * `path` - the path to the file + /// + /// *Override this, otherwise it will return "unsupported" as an error.* + #[allow(unused_variables)] + async fn read_file_text( + &self, + ctx: DistantCtx, + path: PathBuf, + ) -> io::Result { + unsupported("read_file_text") + } + + /// Writes bytes to a file, overwriting the file if it exists. + /// + /// * `path` - the path to the file + /// * `data` - the data to write + /// + /// *Override this, otherwise it will return "unsupported" as an error.* + #[allow(unused_variables)] + async fn write_file( + &self, + ctx: DistantCtx, + path: PathBuf, + data: Vec, + ) -> io::Result<()> { + unsupported("write_file") + } + + /// Writes text to a file, overwriting the file if it exists. + /// + /// * `path` - the path to the file + /// * `data` - the data to write + /// + /// *Override this, otherwise it will return "unsupported" as an error.* + #[allow(unused_variables)] + async fn write_file_text( + &self, + ctx: DistantCtx, + path: PathBuf, + data: String, + ) -> io::Result<()> { + unsupported("write_file_text") + } + + /// Writes bytes to the end of a file, creating it if it is missing. + /// + /// * `path` - the path to the file + /// * `data` - the data to append + /// + /// *Override this, otherwise it will return "unsupported" as an error.* + #[allow(unused_variables)] + async fn append_file( + &self, + ctx: DistantCtx, + path: PathBuf, + data: Vec, + ) -> io::Result<()> { + unsupported("append_file") + } + + /// Writes bytes to the end of a file, creating it if it is missing. + /// + /// * `path` - the path to the file + /// * `data` - the data to append + /// + /// *Override this, otherwise it will return "unsupported" as an error.* + #[allow(unused_variables)] + async fn append_file_text( + &self, + ctx: DistantCtx, + path: PathBuf, + data: String, + ) -> io::Result<()> { + unsupported("append_file_text") + } + + /// Reads entries from a directory. 
+ /// + /// * `path` - the path to the directory + /// * `depth` - how far to traverse the directory, 0 being unlimited + /// * `absolute` - if true, will return absolute paths instead of relative paths + /// * `canonicalize` - if true, will canonicalize entry paths before returned + /// * `include_root` - if true, will include the directory specified in the entries + /// + /// *Override this, otherwise it will return "unsupported" as an error.* + #[allow(unused_variables)] + async fn read_dir( + &self, + ctx: DistantCtx, + path: PathBuf, + depth: usize, + absolute: bool, + canonicalize: bool, + include_root: bool, + ) -> io::Result<(Vec, Vec)> { + unsupported("read_dir") + } + + /// Creates a directory. + /// + /// * `path` - the path to the directory + /// * `all` - if true, will create all missing parent components + /// + /// *Override this, otherwise it will return "unsupported" as an error.* + #[allow(unused_variables)] + async fn create_dir( + &self, + ctx: DistantCtx, + path: PathBuf, + all: bool, + ) -> io::Result<()> { + unsupported("create_dir") + } + + /// Copies some file or directory. + /// + /// * `src` - the path to the file or directory to copy + /// * `dst` - the path where the copy will be placed + /// + /// *Override this, otherwise it will return "unsupported" as an error.* + #[allow(unused_variables)] + async fn copy( + &self, + ctx: DistantCtx, + src: PathBuf, + dst: PathBuf, + ) -> io::Result<()> { + unsupported("copy") + } + + /// Removes some file or directory. + /// + /// * `path` - the path to a file or directory + /// * `force` - if true, will remove non-empty directories + /// + /// *Override this, otherwise it will return "unsupported" as an error.* + #[allow(unused_variables)] + async fn remove( + &self, + ctx: DistantCtx, + path: PathBuf, + force: bool, + ) -> io::Result<()> { + unsupported("remove") + } + + /// Renames some file or directory. 
+ /// + /// * `src` - the path to the file or directory to rename + /// * `dst` - the new name for the file or directory + /// + /// *Override this, otherwise it will return "unsupported" as an error.* + #[allow(unused_variables)] + async fn rename( + &self, + ctx: DistantCtx, + src: PathBuf, + dst: PathBuf, + ) -> io::Result<()> { + unsupported("rename") + } + + /// Watches a file or directory for changes. + /// + /// * `path` - the path to the file or directory + /// * `recursive` - if true, will watch for changes within subdirectories and beyond + /// * `only` - if non-empty, will limit reported changes to those included in this list + /// * `except` - if non-empty, will limit reported changes to those not included in this list + /// + /// *Override this, otherwise it will return "unsupported" as an error.* + #[allow(unused_variables)] + async fn watch( + &self, + ctx: DistantCtx, + path: PathBuf, + recursive: bool, + only: Vec, + except: Vec, + ) -> io::Result<()> { + unsupported("watch") + } + + /// Removes a file or directory from being watched. + /// + /// * `path` - the path to the file or directory + /// + /// *Override this, otherwise it will return "unsupported" as an error.* + #[allow(unused_variables)] + async fn unwatch(&self, ctx: DistantCtx, path: PathBuf) -> io::Result<()> { + unsupported("unwatch") + } + + /// Checks if the specified path exists. + /// + /// * `path` - the path to the file or directory + /// + /// *Override this, otherwise it will return "unsupported" as an error.* + #[allow(unused_variables)] + async fn exists(&self, ctx: DistantCtx, path: PathBuf) -> io::Result { + unsupported("exists") + } + + /// Reads metadata for a file or directory. 
+ /// + /// * `path` - the path to the file or directory + /// * `canonicalize` - if true, will include a canonicalized path in the metadata + /// * `resolve_file_type` - if true, will resolve symlinks to underlying type (file or dir) + /// + /// *Override this, otherwise it will return "unsupported" as an error.* + #[allow(unused_variables)] + async fn metadata( + &self, + ctx: DistantCtx, + path: PathBuf, + canonicalize: bool, + resolve_file_type: bool, + ) -> io::Result { + unsupported("metadata") + } + + /// Spawns a new process, returning its id. + /// + /// * `cmd` - the full command to run as a new process (including arguments) + /// * `environment` - the environment variables to associate with the process + /// * `current_dir` - the alternative current directory to use with the process + /// * `persist` - if true, the process will continue running even after the connection that + /// spawned the process has terminated + /// * `pty` - if provided, will run the process within a PTY of the given size + /// + /// *Override this, otherwise it will return "unsupported" as an error.* + #[allow(unused_variables)] + async fn proc_spawn( + &self, + ctx: DistantCtx, + cmd: String, + environment: Environment, + current_dir: Option, + persist: bool, + pty: Option, + ) -> io::Result { + unsupported("proc_spawn") + } + + /// Kills a running process by its id. + /// + /// * `id` - the unique id of the process + /// + /// *Override this, otherwise it will return "unsupported" as an error.* + #[allow(unused_variables)] + async fn proc_kill(&self, ctx: DistantCtx, id: ProcessId) -> io::Result<()> { + unsupported("proc_kill") + } + + /// Sends data to the stdin of the process with the specified id. 
+ /// + /// * `id` - the unique id of the process + /// * `data` - the bytes to send to stdin + /// + /// *Override this, otherwise it will return "unsupported" as an error.* + #[allow(unused_variables)] + async fn proc_stdin( + &self, + ctx: DistantCtx, + id: ProcessId, + data: Vec, + ) -> io::Result<()> { + unsupported("proc_stdin") + } + + /// Resizes the PTY of the process with the specified id. + /// + /// * `id` - the unique id of the process + /// * `size` - the new size of the pty + /// + /// *Override this, otherwise it will return "unsupported" as an error.* + #[allow(unused_variables)] + async fn proc_resize_pty( + &self, + ctx: DistantCtx, + id: ProcessId, + size: PtySize, + ) -> io::Result<()> { + unsupported("proc_resize_pty") + } + + /// Retrieves information about the system. + /// + /// *Override this, otherwise it will return "unsupported" as an error.* + #[allow(unused_variables)] + async fn system_info(&self, ctx: DistantCtx) -> io::Result { + unsupported("system_info") + } +} + +#[async_trait] +impl Server for DistantApiServer +where + T: DistantApi + Send + Sync, + D: Send + Sync, +{ + type Request = DistantMsg; + type Response = DistantMsg; + type LocalData = D; + + /// Overridden to leverage [`DistantApi`] implementation of `on_accept` + async fn on_accept(&self, local_data: &mut Self::LocalData) { + T::on_accept(&self.api, local_data).await + } + + async fn on_request(&self, ctx: ServerCtx) { + let ServerCtx { + connection_id, + request, + reply, + local_data, + } = ctx; + + // Convert our reply to a queued reply so we can ensure that the result + // of an API function is sent back before anything else + let reply = reply.queue(); + + // Process single vs batch requests + let response = match request.payload { + DistantMsg::Single(data) => { + let ctx = DistantCtx { + connection_id, + reply: Box::new(DistantSingleReply::from(reply.clone_reply())), + local_data, + }; + + let data = handle_request(self, ctx, data).await; + + // Report 
outgoing errors in our debug logs + if let DistantResponseData::Error(x) = &data { + debug!("[Conn {}] {}", connection_id, x); + } + + DistantMsg::Single(data) + } + DistantMsg::Batch(list) => { + let mut out = Vec::new(); + + for data in list { + let ctx = DistantCtx { + connection_id, + reply: Box::new(DistantSingleReply::from(reply.clone_reply())), + local_data: Arc::clone(&local_data), + }; + + // TODO: This does not run in parallel, meaning that the next item in the + // batch will not be queued until the previous item completes! This + // would be useful if we wanted to chain requests where the previous + // request feeds into the current request, but not if we just want + // to run everything together. So we should instead rewrite this + // to spawn a task per request and then await completion of all tasks + let data = handle_request(self, ctx, data).await; + + // Report outgoing errors in our debug logs + if let DistantResponseData::Error(x) = &data { + debug!("[Conn {}] {}", connection_id, x); + } + + out.push(data); + } + + DistantMsg::Batch(out) + } + }; + + // Queue up our result to go before ANY of the other messages that might be sent. + // This is important to avoid situations such as when a process is started, but before + // the confirmation can be sent some stdout or stderr is captured and sent first. 
+ if let Err(x) = reply.send_before(response).await { + error!("[Conn {}] Failed to send response: {}", connection_id, x); + } + + // Flush out all of our replies thus far and toggle to no longer hold submissions + if let Err(x) = reply.flush(false).await { + error!( + "[Conn {}] Failed to flush response queue: {}", + connection_id, x + ); + } + } +} + +/// Processes an incoming request +async fn handle_request( + server: &DistantApiServer, + ctx: DistantCtx, + request: DistantRequestData, +) -> DistantResponseData +where + T: DistantApi + Send + Sync, + D: Send + Sync, +{ + match request { + DistantRequestData::FileRead { path } => server + .api + .read_file(ctx, path) + .await + .map(|data| DistantResponseData::Blob { data }) + .unwrap_or_else(DistantResponseData::from), + DistantRequestData::FileReadText { path } => server + .api + .read_file_text(ctx, path) + .await + .map(|data| DistantResponseData::Text { data }) + .unwrap_or_else(DistantResponseData::from), + DistantRequestData::FileWrite { path, data } => server + .api + .write_file(ctx, path, data) + .await + .map(|_| DistantResponseData::Ok) + .unwrap_or_else(DistantResponseData::from), + DistantRequestData::FileWriteText { path, text } => server + .api + .write_file_text(ctx, path, text) + .await + .map(|_| DistantResponseData::Ok) + .unwrap_or_else(DistantResponseData::from), + DistantRequestData::FileAppend { path, data } => server + .api + .append_file(ctx, path, data) + .await + .map(|_| DistantResponseData::Ok) + .unwrap_or_else(DistantResponseData::from), + DistantRequestData::FileAppendText { path, text } => server + .api + .append_file_text(ctx, path, text) + .await + .map(|_| DistantResponseData::Ok) + .unwrap_or_else(DistantResponseData::from), + DistantRequestData::DirRead { + path, + depth, + absolute, + canonicalize, + include_root, + } => server + .api + .read_dir(ctx, path, depth, absolute, canonicalize, include_root) + .await + .map(|(entries, errors)| DistantResponseData::DirEntries { + 
entries, + errors: errors.into_iter().map(Error::from).collect(), + }) + .unwrap_or_else(DistantResponseData::from), + DistantRequestData::DirCreate { path, all } => server + .api + .create_dir(ctx, path, all) + .await + .map(|_| DistantResponseData::Ok) + .unwrap_or_else(DistantResponseData::from), + DistantRequestData::Remove { path, force } => server + .api + .remove(ctx, path, force) + .await + .map(|_| DistantResponseData::Ok) + .unwrap_or_else(DistantResponseData::from), + DistantRequestData::Copy { src, dst } => server + .api + .copy(ctx, src, dst) + .await + .map(|_| DistantResponseData::Ok) + .unwrap_or_else(DistantResponseData::from), + DistantRequestData::Rename { src, dst } => server + .api + .rename(ctx, src, dst) + .await + .map(|_| DistantResponseData::Ok) + .unwrap_or_else(DistantResponseData::from), + DistantRequestData::Watch { + path, + recursive, + only, + except, + } => server + .api + .watch(ctx, path, recursive, only, except) + .await + .map(|_| DistantResponseData::Ok) + .unwrap_or_else(DistantResponseData::from), + DistantRequestData::Unwatch { path } => server + .api + .unwatch(ctx, path) + .await + .map(|_| DistantResponseData::Ok) + .unwrap_or_else(DistantResponseData::from), + DistantRequestData::Exists { path } => server + .api + .exists(ctx, path) + .await + .map(|value| DistantResponseData::Exists { value }) + .unwrap_or_else(DistantResponseData::from), + DistantRequestData::Metadata { + path, + canonicalize, + resolve_file_type, + } => server + .api + .metadata(ctx, path, canonicalize, resolve_file_type) + .await + .map(DistantResponseData::Metadata) + .unwrap_or_else(DistantResponseData::from), + DistantRequestData::ProcSpawn { + cmd, + environment, + current_dir, + persist, + pty, + } => server + .api + .proc_spawn(ctx, cmd.into(), environment, current_dir, persist, pty) + .await + .map(|id| DistantResponseData::ProcSpawned { id }) + .unwrap_or_else(DistantResponseData::from), + DistantRequestData::ProcKill { id } => server + .api 
+ .proc_kill(ctx, id) + .await + .map(|_| DistantResponseData::Ok) + .unwrap_or_else(DistantResponseData::from), + DistantRequestData::ProcStdin { id, data } => server + .api + .proc_stdin(ctx, id, data) + .await + .map(|_| DistantResponseData::Ok) + .unwrap_or_else(DistantResponseData::from), + DistantRequestData::ProcResizePty { id, size } => server + .api + .proc_resize_pty(ctx, id, size) + .await + .map(|_| DistantResponseData::Ok) + .unwrap_or_else(DistantResponseData::from), + DistantRequestData::SystemInfo {} => server + .api + .system_info(ctx) + .await + .map(DistantResponseData::SystemInfo) + .unwrap_or_else(DistantResponseData::from), + } +} diff --git a/distant-core/src/api/local.rs b/distant-core/src/api/local.rs new file mode 100644 index 0000000..2fdc740 --- /dev/null +++ b/distant-core/src/api/local.rs @@ -0,0 +1,2124 @@ +use crate::{ + data::{ + ChangeKind, ChangeKindSet, DirEntry, Environment, FileType, Metadata, ProcessId, PtySize, + SystemInfo, + }, + DistantApi, DistantCtx, +}; +use async_trait::async_trait; +use log::*; +use std::{ + io, + path::{Path, PathBuf}, +}; +use tokio::io::AsyncWriteExt; +use walkdir::WalkDir; + +mod process; + +mod state; +pub use state::ConnectionState; +use state::*; + +/// Represents an implementation of [`DistantApi`] that works with the local machine +/// where the server using this api is running. 
In other words, this is a direct +/// impementation of the API instead of a proxy to another machine as seen with +/// implementations on top of SSH and other protocol +pub struct LocalDistantApi { + state: GlobalState, +} + +impl LocalDistantApi { + /// Initialize the api instance + pub fn initialize() -> io::Result { + Ok(Self { + state: GlobalState::initialize()?, + }) + } +} + +#[async_trait] +impl DistantApi for LocalDistantApi { + type LocalData = ConnectionState; + + /// Injects the global channels into the local connection + async fn on_accept(&self, local_data: &mut Self::LocalData) { + local_data.process_channel = self.state.process.clone_channel(); + local_data.watcher_channel = self.state.watcher.clone_channel(); + } + + async fn read_file( + &self, + ctx: DistantCtx, + path: PathBuf, + ) -> io::Result> { + debug!( + "[Conn {}] Reading bytes from file {:?}", + ctx.connection_id, path + ); + + tokio::fs::read(path).await + } + + async fn read_file_text( + &self, + ctx: DistantCtx, + path: PathBuf, + ) -> io::Result { + debug!( + "[Conn {}] Reading text from file {:?}", + ctx.connection_id, path + ); + + tokio::fs::read_to_string(path).await + } + + async fn write_file( + &self, + ctx: DistantCtx, + path: PathBuf, + data: Vec, + ) -> io::Result<()> { + debug!( + "[Conn {}] Writing bytes to file {:?}", + ctx.connection_id, path + ); + + tokio::fs::write(path, data).await + } + + async fn write_file_text( + &self, + ctx: DistantCtx, + path: PathBuf, + data: String, + ) -> io::Result<()> { + debug!( + "[Conn {}] Writing text to file {:?}", + ctx.connection_id, path + ); + + tokio::fs::write(path, data).await + } + + async fn append_file( + &self, + ctx: DistantCtx, + path: PathBuf, + data: Vec, + ) -> io::Result<()> { + debug!( + "[Conn {}] Appending bytes to file {:?}", + ctx.connection_id, path + ); + + let mut file = tokio::fs::OpenOptions::new() + .create(true) + .append(true) + .open(path) + .await?; + file.write_all(data.as_ref()).await + } + + async 
fn append_file_text( + &self, + ctx: DistantCtx, + path: PathBuf, + data: String, + ) -> io::Result<()> { + debug!( + "[Conn {}] Appending text to file {:?}", + ctx.connection_id, path + ); + + let mut file = tokio::fs::OpenOptions::new() + .create(true) + .append(true) + .open(path) + .await?; + file.write_all(data.as_ref()).await + } + + async fn read_dir( + &self, + ctx: DistantCtx, + path: PathBuf, + depth: usize, + absolute: bool, + canonicalize: bool, + include_root: bool, + ) -> io::Result<(Vec, Vec)> { + debug!( + "[Conn {}] Reading directory {:?} {{depth: {}, absolute: {}, canonicalize: {}, include_root: {}}}", + ctx.connection_id, path, depth, absolute, canonicalize, include_root + ); + + // Canonicalize our provided path to ensure that it is exists, not a loop, and absolute + let root_path = tokio::fs::canonicalize(path).await?; + + // Traverse, but don't include root directory in entries (hence min depth 1), unless indicated + // to do so (min depth 0) + let dir = WalkDir::new(root_path.as_path()) + .min_depth(if include_root { 0 } else { 1 }) + .sort_by_file_name(); + + // If depth > 0, will recursively traverse to specified max depth, otherwise + // performs infinite traversal + let dir = if depth > 0 { dir.max_depth(depth) } else { dir }; + + // Determine our entries and errors + let mut entries = Vec::new(); + let mut errors = Vec::new(); + + #[inline] + fn map_file_type(ft: std::fs::FileType) -> FileType { + if ft.is_dir() { + FileType::Dir + } else if ft.is_file() { + FileType::File + } else { + FileType::Symlink + } + } + + for entry in dir { + match entry.map_err(io::Error::from) { + // For entries within the root, we want to transform the path based on flags + Ok(e) if e.depth() > 0 => { + // Canonicalize the path if specified, otherwise just return + // the path as is + let mut path = if canonicalize { + match tokio::fs::canonicalize(e.path()).await { + Ok(path) => path, + Err(x) => { + errors.push(x); + continue; + } + } + } else { + 
e.path().to_path_buf() + }; + + // Strip the path of its prefix based if not flagged as absolute + if !absolute { + // NOTE: In the situation where we canonicalized the path earlier, + // there is no guarantee that our root path is still the + // parent of the symlink's destination; so, in that case we MUST just + // return the path if the strip_prefix fails + path = path + .strip_prefix(root_path.as_path()) + .map(Path::to_path_buf) + .unwrap_or(path); + }; + + entries.push(DirEntry { + path, + file_type: map_file_type(e.file_type()), + depth: e.depth(), + }); + } + + // For the root, we just want to echo back the entry as is + Ok(e) => { + entries.push(DirEntry { + path: e.path().to_path_buf(), + file_type: map_file_type(e.file_type()), + depth: e.depth(), + }); + } + + Err(x) => errors.push(x), + } + } + + Ok((entries, errors)) + } + + async fn create_dir( + &self, + ctx: DistantCtx, + path: PathBuf, + all: bool, + ) -> io::Result<()> { + debug!( + "[Conn {}] Creating directory {:?} {{all: {}}}", + ctx.connection_id, path, all + ); + if all { + tokio::fs::create_dir_all(path).await + } else { + tokio::fs::create_dir(path).await + } + } + + async fn remove( + &self, + ctx: DistantCtx, + path: PathBuf, + force: bool, + ) -> io::Result<()> { + debug!( + "[Conn {}] Removing {:?} {{force: {}}}", + ctx.connection_id, path, force + ); + let path_metadata = tokio::fs::metadata(path.as_path()).await?; + if path_metadata.is_dir() { + if force { + tokio::fs::remove_dir_all(path).await + } else { + tokio::fs::remove_dir(path).await + } + } else { + tokio::fs::remove_file(path).await + } + } + + async fn copy( + &self, + ctx: DistantCtx, + src: PathBuf, + dst: PathBuf, + ) -> io::Result<()> { + debug!( + "[Conn {}] Copying {:?} to {:?}", + ctx.connection_id, src, dst + ); + let src_metadata = tokio::fs::metadata(src.as_path()).await?; + if src_metadata.is_dir() { + // Create the destination directory first, regardless of if anything + // is in the source directory + 
tokio::fs::create_dir_all(dst.as_path()).await?; + + for entry in WalkDir::new(src.as_path()) + .min_depth(1) + .follow_links(false) + .into_iter() + .filter_entry(|e| { + e.file_type().is_file() || e.file_type().is_dir() || e.path_is_symlink() + }) + { + let entry = entry?; + + // Get unique portion of path relative to src + // NOTE: Because we are traversing files that are all within src, this + // should always succeed + let local_src = entry.path().strip_prefix(src.as_path()).unwrap(); + + // Get the file without any directories + let local_src_file_name = local_src.file_name().unwrap(); + + // Get the directory housing the file + // NOTE: Because we enforce files/symlinks, there will always be a parent + let local_src_dir = local_src.parent().unwrap(); + + // Map out the path to the destination + let dst_parent_dir = dst.join(local_src_dir); + + // Create the destination directory for the file when copying + tokio::fs::create_dir_all(dst_parent_dir.as_path()).await?; + + let dst_path = dst_parent_dir.join(local_src_file_name); + + // Perform copying from entry to destination (if a file/symlink) + if !entry.file_type().is_dir() { + tokio::fs::copy(entry.path(), dst_path).await?; + + // Otherwise, if a directory, create it + } else { + tokio::fs::create_dir(dst_path).await?; + } + } + } else { + tokio::fs::copy(src, dst).await?; + } + + Ok(()) + } + + async fn rename( + &self, + ctx: DistantCtx, + src: PathBuf, + dst: PathBuf, + ) -> io::Result<()> { + debug!( + "[Conn {}] Renaming {:?} to {:?}", + ctx.connection_id, src, dst + ); + tokio::fs::rename(src, dst).await + } + + async fn watch( + &self, + ctx: DistantCtx, + path: PathBuf, + recursive: bool, + only: Vec, + except: Vec, + ) -> io::Result<()> { + let only = only.into_iter().collect::(); + let except = except.into_iter().collect::(); + debug!( + "[Conn {}] Watching {:?} {{recursive: {}, only: {}, except: {}}}", + ctx.connection_id, path, recursive, only, except + ); + + let path = 
RegisteredPath::register( + ctx.connection_id, + path.as_path(), + recursive, + only, + except, + ctx.reply, + ) + .await?; + + self.state.watcher.watch(path).await?; + + Ok(()) + } + + async fn unwatch(&self, ctx: DistantCtx, path: PathBuf) -> io::Result<()> { + debug!("[Conn {}] Unwatching {:?}", ctx.connection_id, path); + + self.state + .watcher + .unwatch(ctx.connection_id, path.as_path()) + .await?; + Ok(()) + } + + async fn exists(&self, ctx: DistantCtx, path: PathBuf) -> io::Result { + debug!("[Conn {}] Checking if {:?} exists", ctx.connection_id, path); + + // Following experimental `std::fs::try_exists`, which checks the error kind of the + // metadata lookup to see if it is not found and filters accordingly + match tokio::fs::metadata(path.as_path()).await { + Ok(_) => Ok(true), + Err(x) if x.kind() == io::ErrorKind::NotFound => Ok(false), + Err(x) => return Err(x), + } + } + + async fn metadata( + &self, + ctx: DistantCtx, + path: PathBuf, + canonicalize: bool, + resolve_file_type: bool, + ) -> io::Result { + debug!( + "[Conn {}] Reading metadata for {:?} {{canonicalize: {}, resolve_file_type: {}}}", + ctx.connection_id, path, canonicalize, resolve_file_type + ); + Metadata::read(path, canonicalize, resolve_file_type).await + } + + async fn proc_spawn( + &self, + ctx: DistantCtx, + cmd: String, + environment: Environment, + current_dir: Option, + persist: bool, + pty: Option, + ) -> io::Result { + debug!( + "[Conn {}] Spawning {} {{environment: {:?}, current_dir: {:?}, persist: {}, pty: {:?}}}", + ctx.connection_id, cmd, environment, current_dir, persist, pty + ); + self.state + .process + .spawn(cmd, environment, current_dir, persist, pty, ctx.reply) + .await + } + + async fn proc_kill(&self, ctx: DistantCtx, id: ProcessId) -> io::Result<()> { + debug!("[Conn {}] Killing process {}", ctx.connection_id, id); + self.state.process.kill(id).await + } + + async fn proc_stdin( + &self, + ctx: DistantCtx, + id: ProcessId, + data: Vec, + ) -> io::Result<()> { 
+ debug!( + "[Conn {}] Sending stdin to process {}", + ctx.connection_id, id + ); + self.state.process.send_stdin(id, data).await + } + + async fn proc_resize_pty( + &self, + ctx: DistantCtx, + id: ProcessId, + size: PtySize, + ) -> io::Result<()> { + debug!( + "[Conn {}] Resizing pty of process {} to {}", + ctx.connection_id, id, size + ); + self.state.process.resize_pty(id, size).await + } + + async fn system_info(&self, ctx: DistantCtx) -> io::Result { + debug!("[Conn {}] Reading system information", ctx.connection_id); + Ok(SystemInfo::default()) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::data::DistantResponseData; + use assert_fs::prelude::*; + use distant_net::Reply; + use once_cell::sync::Lazy; + use predicates::prelude::*; + use std::{sync::Arc, time::Duration}; + use tokio::sync::mpsc; + + static TEMP_SCRIPT_DIR: Lazy = + Lazy::new(|| assert_fs::TempDir::new().unwrap()); + static SCRIPT_RUNNER: Lazy = Lazy::new(|| String::from("bash")); + + static ECHO_ARGS_TO_STDOUT_SH: Lazy = Lazy::new(|| { + let script = TEMP_SCRIPT_DIR.child("echo_args_to_stdout.sh"); + script + .write_str(indoc::indoc!( + r#" + #/usr/bin/env bash + printf "%s" "$*" + "# + )) + .unwrap(); + script + }); + + static ECHO_ARGS_TO_STDERR_SH: Lazy = Lazy::new(|| { + let script = TEMP_SCRIPT_DIR.child("echo_args_to_stderr.sh"); + script + .write_str(indoc::indoc!( + r#" + #/usr/bin/env bash + printf "%s" "$*" 1>&2 + "# + )) + .unwrap(); + script + }); + + static ECHO_STDIN_TO_STDOUT_SH: Lazy = Lazy::new(|| { + let script = TEMP_SCRIPT_DIR.child("echo_stdin_to_stdout.sh"); + script + .write_str(indoc::indoc!( + r#" + #/usr/bin/env bash + while IFS= read; do echo "$REPLY"; done + "# + )) + .unwrap(); + script + }); + + static SLEEP_SH: Lazy = Lazy::new(|| { + let script = TEMP_SCRIPT_DIR.child("sleep.sh"); + script + .write_str(indoc::indoc!( + r#" + #!/usr/bin/env bash + sleep "$1" + "# + )) + .unwrap(); + script + }); + + static DOES_NOT_EXIST_BIN: Lazy = + Lazy::new(|| 
TEMP_SCRIPT_DIR.child("does_not_exist_bin")); + + async fn setup( + buffer: usize, + ) -> ( + LocalDistantApi, + DistantCtx, + mpsc::Receiver, + ) { + let api = LocalDistantApi::initialize().unwrap(); + let (reply, rx) = make_reply(buffer); + let mut local_data = ConnectionState::default(); + DistantApi::on_accept(&api, &mut local_data).await; + let ctx = DistantCtx { + connection_id: rand::random(), + reply, + local_data: Arc::new(local_data), + }; + (api, ctx, rx) + } + + fn make_reply( + buffer: usize, + ) -> ( + Box>, + mpsc::Receiver, + ) { + let (tx, rx) = mpsc::channel(buffer); + (Box::new(tx), rx) + } + + #[tokio::test] + async fn read_file_should_fail_if_file_missing() { + let (api, ctx, _rx) = setup(1).await; + let temp = assert_fs::TempDir::new().unwrap(); + let path = temp.child("missing-file").path().to_path_buf(); + + let _ = api.read_file(ctx, path).await.unwrap_err(); + } + + #[tokio::test] + async fn read_file_should_send_blob_with_file_contents() { + let (api, ctx, _rx) = setup(1).await; + + let temp = assert_fs::TempDir::new().unwrap(); + let file = temp.child("test-file"); + file.write_str("some file contents").unwrap(); + + let bytes = api.read_file(ctx, file.path().to_path_buf()).await.unwrap(); + assert_eq!(bytes, b"some file contents"); + } + + #[tokio::test] + async fn read_file_text_should_send_error_if_fails_to_read_file() { + let (api, ctx, _rx) = setup(1).await; + + let temp = assert_fs::TempDir::new().unwrap(); + let path = temp.child("missing-file").path().to_path_buf(); + + let _ = api.read_file_text(ctx, path).await.unwrap_err(); + } + + #[tokio::test] + async fn read_file_text_should_send_text_with_file_contents() { + let (api, ctx, _rx) = setup(1).await; + + let temp = assert_fs::TempDir::new().unwrap(); + let file = temp.child("test-file"); + file.write_str("some file contents").unwrap(); + + let text = api + .read_file_text(ctx, file.path().to_path_buf()) + .await + .unwrap(); + assert_eq!(text, "some file contents"); + } + + 
#[tokio::test] + async fn write_file_should_send_error_if_fails_to_write_file() { + let (api, ctx, _rx) = setup(1).await; + + // Create a temporary path and add to it to ensure that there are + // extra components that don't exist to cause writing to fail + let temp = assert_fs::TempDir::new().unwrap(); + let file = temp.child("dir").child("test-file"); + + let _ = api + .write_file(ctx, file.path().to_path_buf(), b"some text".to_vec()) + .await + .unwrap_err(); + + // Also verify that we didn't actually create the file + file.assert(predicate::path::missing()); + } + + #[tokio::test] + async fn write_file_should_send_ok_when_successful() { + let (api, ctx, _rx) = setup(1).await; + + // Path should point to a file that does not exist, but all + // other components leading up to it do + let temp = assert_fs::TempDir::new().unwrap(); + let file = temp.child("test-file"); + + api.write_file(ctx, file.path().to_path_buf(), b"some text".to_vec()) + .await + .unwrap(); + + // Also verify that we actually did create the file + // with the associated contents + file.assert("some text"); + } + + #[tokio::test] + async fn write_file_text_should_send_error_if_fails_to_write_file() { + let (api, ctx, _rx) = setup(1).await; + + // Create a temporary path and add to it to ensure that there are + // extra components that don't exist to cause writing to fail + let temp = assert_fs::TempDir::new().unwrap(); + let file = temp.child("dir").child("test-file"); + + api.write_file_text(ctx, file.path().to_path_buf(), "some text".to_string()) + .await + .unwrap_err(); + + // Also verify that we didn't actually create the file + file.assert(predicate::path::missing()); + } + + #[tokio::test] + async fn write_file_text_should_send_ok_when_successful() { + let (api, ctx, _rx) = setup(1).await; + + // Path should point to a file that does not exist, but all + // other components leading up to it do + let temp = assert_fs::TempDir::new().unwrap(); + let file = temp.child("test-file"); + + 
api.write_file_text(ctx, file.path().to_path_buf(), "some text".to_string()) + .await + .unwrap(); + + // Also verify that we actually did create the file + // with the associated contents + file.assert("some text"); + } + + #[tokio::test] + async fn append_file_should_send_error_if_fails_to_create_file() { + let (api, ctx, _rx) = setup(1).await; + + // Create a temporary path and add to it to ensure that there are + // extra components that don't exist to cause writing to fail + let temp = assert_fs::TempDir::new().unwrap(); + let file = temp.child("dir").child("test-file"); + + api.append_file( + ctx, + file.path().to_path_buf(), + b"some extra contents".to_vec(), + ) + .await + .unwrap_err(); + + // Also verify that we didn't actually create the file + file.assert(predicate::path::missing()); + } + + #[tokio::test] + async fn append_file_should_create_file_if_missing() { + let (api, ctx, _rx) = setup(1).await; + + // Don't create the file directly, but define path + // where the file should be + let temp = assert_fs::TempDir::new().unwrap(); + let file = temp.child("test-file"); + + api.append_file( + ctx, + file.path().to_path_buf(), + b"some extra contents".to_vec(), + ) + .await + .unwrap(); + + // Yield to allow chance to finish appending to file + tokio::time::sleep(Duration::from_millis(50)).await; + + // Also verify that we actually did create to the file + file.assert("some extra contents"); + } + + #[tokio::test] + async fn append_file_should_send_ok_when_successful() { + let (api, ctx, _rx) = setup(1).await; + + // Create a temporary file and fill it with some contents + let temp = assert_fs::TempDir::new().unwrap(); + let file = temp.child("test-file"); + file.write_str("some file contents").unwrap(); + + api.append_file( + ctx, + file.path().to_path_buf(), + b"some extra contents".to_vec(), + ) + .await + .unwrap(); + + // Yield to allow chance to finish appending to file + tokio::time::sleep(Duration::from_millis(50)).await; + + // Also verify that 
we actually did append to the file + file.assert("some file contentssome extra contents"); + } + + #[tokio::test] + async fn append_file_text_should_send_error_if_fails_to_create_file() { + let (api, ctx, _rx) = setup(1).await; + + // Create a temporary path and add to it to ensure that there are + // extra components that don't exist to cause writing to fail + let temp = assert_fs::TempDir::new().unwrap(); + let file = temp.child("dir").child("test-file"); + + let _ = api + .append_file_text( + ctx, + file.path().to_path_buf(), + "some extra contents".to_string(), + ) + .await + .unwrap_err(); + + // Also verify that we didn't actually create the file + file.assert(predicate::path::missing()); + } + + #[tokio::test] + async fn append_file_text_should_create_file_if_missing() { + let (api, ctx, _rx) = setup(1).await; + + // Don't create the file directly, but define path + // where the file should be + let temp = assert_fs::TempDir::new().unwrap(); + let file = temp.child("test-file"); + + api.append_file_text( + ctx, + file.path().to_path_buf(), + "some extra contents".to_string(), + ) + .await + .unwrap(); + + // Yield to allow chance to finish appending to file + tokio::time::sleep(Duration::from_millis(50)).await; + + // Also verify that we actually did create to the file + file.assert("some extra contents"); + } + + #[tokio::test] + async fn append_file_text_should_send_ok_when_successful() { + let (api, ctx, _rx) = setup(1).await; + + // Create a temporary file and fill it with some contents + let temp = assert_fs::TempDir::new().unwrap(); + let file = temp.child("test-file"); + file.write_str("some file contents").unwrap(); + + api.append_file_text( + ctx, + file.path().to_path_buf(), + "some extra contents".to_string(), + ) + .await + .unwrap(); + + // Yield to allow chance to finish appending to file + tokio::time::sleep(Duration::from_millis(50)).await; + + // Also verify that we actually did append to the file + file.assert("some file contentssome extra 
contents"); + } + + #[tokio::test] + async fn dir_read_should_send_error_if_directory_does_not_exist() { + let (api, ctx, _rx) = setup(1).await; + + let temp = assert_fs::TempDir::new().unwrap(); + let dir = temp.child("test-dir"); + + let _ = api + .read_dir( + ctx, + dir.path().to_path_buf(), + /* depth */ 0, + /* absolute */ false, + /* canonicalize */ false, + /* include_root */ false, + ) + .await + .unwrap_err(); + } + + // /root/ + // /root/file1 + // /root/link1 -> /root/sub1/file2 + // /root/sub1/ + // /root/sub1/file2 + async fn setup_dir() -> assert_fs::TempDir { + let root_dir = assert_fs::TempDir::new().unwrap(); + root_dir.child("file1").touch().unwrap(); + + let sub1 = root_dir.child("sub1"); + sub1.create_dir_all().unwrap(); + + let file2 = sub1.child("file2"); + file2.touch().unwrap(); + + let link1 = root_dir.child("link1"); + link1.symlink_to_file(file2.path()).unwrap(); + + root_dir + } + + #[tokio::test] + async fn dir_read_should_support_depth_limits() { + let (api, ctx, _rx) = setup(1).await; + + // Create directory with some nested items + let root_dir = setup_dir().await; + + let (entries, _) = api + .read_dir( + ctx, + root_dir.path().to_path_buf(), + /* depth */ 1, + /* absolute */ false, + /* canonicalize */ false, + /* include_root */ false, + ) + .await + .unwrap(); + + assert_eq!(entries.len(), 3, "Wrong number of entries found"); + + assert_eq!(entries[0].file_type, FileType::File); + assert_eq!(entries[0].path, Path::new("file1")); + assert_eq!(entries[0].depth, 1); + + assert_eq!(entries[1].file_type, FileType::Symlink); + assert_eq!(entries[1].path, Path::new("link1")); + assert_eq!(entries[1].depth, 1); + + assert_eq!(entries[2].file_type, FileType::Dir); + assert_eq!(entries[2].path, Path::new("sub1")); + assert_eq!(entries[2].depth, 1); + } + + #[tokio::test] + async fn dir_read_should_support_unlimited_depth_using_zero() { + let (api, ctx, _rx) = setup(1).await; + + // Create directory with some nested items + let root_dir = 
setup_dir().await; + + let (entries, _) = api + .read_dir( + ctx, + root_dir.path().to_path_buf(), + /* depth */ 0, + /* absolute */ false, + /* canonicalize */ false, + /* include_root */ false, + ) + .await + .unwrap(); + + assert_eq!(entries.len(), 4, "Wrong number of entries found"); + + assert_eq!(entries[0].file_type, FileType::File); + assert_eq!(entries[0].path, Path::new("file1")); + assert_eq!(entries[0].depth, 1); + + assert_eq!(entries[1].file_type, FileType::Symlink); + assert_eq!(entries[1].path, Path::new("link1")); + assert_eq!(entries[1].depth, 1); + + assert_eq!(entries[2].file_type, FileType::Dir); + assert_eq!(entries[2].path, Path::new("sub1")); + assert_eq!(entries[2].depth, 1); + + assert_eq!(entries[3].file_type, FileType::File); + assert_eq!(entries[3].path, Path::new("sub1").join("file2")); + assert_eq!(entries[3].depth, 2); + } + + #[tokio::test] + async fn dir_read_should_support_including_directory_in_returned_entries() { + let (api, ctx, _rx) = setup(1).await; + + // Create directory with some nested items + let root_dir = setup_dir().await; + + let (entries, _) = api + .read_dir( + ctx, + root_dir.path().to_path_buf(), + /* depth */ 1, + /* absolute */ false, + /* canonicalize */ false, + /* include_root */ true, + ) + .await + .unwrap(); + + assert_eq!(entries.len(), 4, "Wrong number of entries found"); + + // NOTE: Root entry is always absolute, resolved path + assert_eq!(entries[0].file_type, FileType::Dir); + assert_eq!(entries[0].path, root_dir.path().canonicalize().unwrap()); + assert_eq!(entries[0].depth, 0); + + assert_eq!(entries[1].file_type, FileType::File); + assert_eq!(entries[1].path, Path::new("file1")); + assert_eq!(entries[1].depth, 1); + + assert_eq!(entries[2].file_type, FileType::Symlink); + assert_eq!(entries[2].path, Path::new("link1")); + assert_eq!(entries[2].depth, 1); + + assert_eq!(entries[3].file_type, FileType::Dir); + assert_eq!(entries[3].path, Path::new("sub1")); + assert_eq!(entries[3].depth, 1); + } + 
+ #[tokio::test] + async fn dir_read_should_support_returning_absolute_paths() { + let (api, ctx, _rx) = setup(1).await; + + // Create directory with some nested items + let root_dir = setup_dir().await; + + let (entries, _) = api + .read_dir( + ctx, + root_dir.path().to_path_buf(), + /* depth */ 1, + /* absolute */ true, + /* canonicalize */ false, + /* include_root */ false, + ) + .await + .unwrap(); + + assert_eq!(entries.len(), 3, "Wrong number of entries found"); + let root_path = root_dir.path().canonicalize().unwrap(); + + assert_eq!(entries[0].file_type, FileType::File); + assert_eq!(entries[0].path, root_path.join("file1")); + assert_eq!(entries[0].depth, 1); + + assert_eq!(entries[1].file_type, FileType::Symlink); + assert_eq!(entries[1].path, root_path.join("link1")); + assert_eq!(entries[1].depth, 1); + + assert_eq!(entries[2].file_type, FileType::Dir); + assert_eq!(entries[2].path, root_path.join("sub1")); + assert_eq!(entries[2].depth, 1); + } + + #[tokio::test] + async fn dir_read_should_support_returning_canonicalized_paths() { + let (api, ctx, _rx) = setup(1).await; + + // Create directory with some nested items + let root_dir = setup_dir().await; + + let (entries, _) = api + .read_dir( + ctx, + root_dir.path().to_path_buf(), + /* depth */ 1, + /* absolute */ false, + /* canonicalize */ true, + /* include_root */ false, + ) + .await + .unwrap(); + + assert_eq!(entries.len(), 3, "Wrong number of entries found"); + + assert_eq!(entries[0].file_type, FileType::File); + assert_eq!(entries[0].path, Path::new("file1")); + assert_eq!(entries[0].depth, 1); + + // Symlink should be resolved from $ROOT/link1 -> $ROOT/sub1/file2 + assert_eq!(entries[1].file_type, FileType::Symlink); + assert_eq!(entries[1].path, Path::new("sub1").join("file2")); + assert_eq!(entries[1].depth, 1); + + assert_eq!(entries[2].file_type, FileType::Dir); + assert_eq!(entries[2].path, Path::new("sub1")); + assert_eq!(entries[2].depth, 1); + } + + #[tokio::test] + async fn 
create_dir_should_send_error_if_fails() { + let (api, ctx, _rx) = setup(1).await; + + // Make a path that has multiple non-existent components + // so the creation will fail + let root_dir = setup_dir().await; + let path = root_dir.path().join("nested").join("new-dir"); + + let _ = api + .create_dir(ctx, path.to_path_buf(), /* all */ false) + .await + .unwrap_err(); + + // Also verify that the directory was not actually created + assert!(!path.exists(), "Path unexpectedly exists"); + } + + #[tokio::test] + async fn create_dir_should_send_ok_when_successful() { + let (api, ctx, _rx) = setup(1).await; + let root_dir = setup_dir().await; + let path = root_dir.path().join("new-dir"); + + api.create_dir(ctx, path.to_path_buf(), /* all */ false) + .await + .unwrap(); + + // Also verify that the directory was actually created + assert!(path.exists(), "Directory not created"); + } + + #[tokio::test] + async fn create_dir_should_support_creating_multiple_dir_components() { + let (api, ctx, _rx) = setup(1).await; + let root_dir = setup_dir().await; + let path = root_dir.path().join("nested").join("new-dir"); + + api.create_dir(ctx, path.to_path_buf(), /* all */ true) + .await + .unwrap(); + + // Also verify that the directory was actually created + assert!(path.exists(), "Directory not created"); + } + + #[tokio::test] + async fn remove_should_send_error_on_failure() { + let (api, ctx, _rx) = setup(1).await; + let temp = assert_fs::TempDir::new().unwrap(); + let file = temp.child("missing-file"); + + let _ = api + .remove(ctx, file.path().to_path_buf(), /* false */ false) + .await + .unwrap_err(); + + // Also, verify that path does not exist + file.assert(predicate::path::missing()); + } + + #[tokio::test] + async fn remove_should_support_deleting_a_directory() { + let (api, ctx, _rx) = setup(1).await; + let temp = assert_fs::TempDir::new().unwrap(); + let dir = temp.child("dir"); + dir.create_dir_all().unwrap(); + + api.remove(ctx, dir.path().to_path_buf(), /* false */ 
false) + .await + .unwrap(); + + // Also, verify that path does not exist + dir.assert(predicate::path::missing()); + } + + #[tokio::test] + async fn remove_should_delete_nonempty_directory_if_force_is_true() { + let (api, ctx, _rx) = setup(1).await; + let temp = assert_fs::TempDir::new().unwrap(); + let dir = temp.child("dir"); + dir.create_dir_all().unwrap(); + dir.child("file").touch().unwrap(); + + api.remove(ctx, dir.path().to_path_buf(), /* false */ true) + .await + .unwrap(); + + // Also, verify that path does not exist + dir.assert(predicate::path::missing()); + } + + #[tokio::test] + async fn remove_should_support_deleting_a_single_file() { + let (api, ctx, _rx) = setup(1).await; + let temp = assert_fs::TempDir::new().unwrap(); + let file = temp.child("some-file"); + file.touch().unwrap(); + + api.remove(ctx, file.path().to_path_buf(), /* false */ false) + .await + .unwrap(); + + // Also, verify that path does not exist + file.assert(predicate::path::missing()); + } + + #[tokio::test] + async fn copy_should_send_error_on_failure() { + let (api, ctx, _rx) = setup(1).await; + let temp = assert_fs::TempDir::new().unwrap(); + let src = temp.child("src"); + let dst = temp.child("dst"); + + let _ = api + .copy(ctx, src.path().to_path_buf(), dst.path().to_path_buf()) + .await + .unwrap_err(); + + // Also, verify that destination does not exist + dst.assert(predicate::path::missing()); + } + + #[tokio::test] + async fn copy_should_support_copying_an_entire_directory() { + let (api, ctx, _rx) = setup(1).await; + let temp = assert_fs::TempDir::new().unwrap(); + + let src = temp.child("src"); + src.create_dir_all().unwrap(); + let src_file = src.child("file"); + src_file.write_str("some contents").unwrap(); + + let dst = temp.child("dst"); + let dst_file = dst.child("file"); + + api.copy(ctx, src.path().to_path_buf(), dst.path().to_path_buf()) + .await + .unwrap(); + + // Verify that we have source and destination directories and associated contents + 
src.assert(predicate::path::is_dir()); + src_file.assert(predicate::path::is_file()); + dst.assert(predicate::path::is_dir()); + dst_file.assert(predicate::path::eq_file(src_file.path())); + } + + #[tokio::test] + async fn copy_should_support_copying_an_empty_directory() { + let (api, ctx, _rx) = setup(1).await; + let temp = assert_fs::TempDir::new().unwrap(); + let src = temp.child("src"); + src.create_dir_all().unwrap(); + let dst = temp.child("dst"); + + api.copy(ctx, src.path().to_path_buf(), dst.path().to_path_buf()) + .await + .unwrap(); + + // Verify that we still have source and destination directories + src.assert(predicate::path::is_dir()); + dst.assert(predicate::path::is_dir()); + } + + #[tokio::test] + async fn copy_should_support_copying_a_directory_that_only_contains_directories() { + let (api, ctx, _rx) = setup(1).await; + let temp = assert_fs::TempDir::new().unwrap(); + + let src = temp.child("src"); + src.create_dir_all().unwrap(); + let src_dir = src.child("dir"); + src_dir.create_dir_all().unwrap(); + + let dst = temp.child("dst"); + let dst_dir = dst.child("dir"); + + api.copy(ctx, src.path().to_path_buf(), dst.path().to_path_buf()) + .await + .unwrap(); + + // Verify that we have source and destination directories and associated contents + src.assert(predicate::path::is_dir().name("src")); + src_dir.assert(predicate::path::is_dir().name("src/dir")); + dst.assert(predicate::path::is_dir().name("dst")); + dst_dir.assert(predicate::path::is_dir().name("dst/dir")); + } + + #[tokio::test] + async fn copy_should_support_copying_a_single_file() { + let (api, ctx, _rx) = setup(1).await; + let temp = assert_fs::TempDir::new().unwrap(); + let src = temp.child("src"); + src.write_str("some text").unwrap(); + let dst = temp.child("dst"); + + api.copy(ctx, src.path().to_path_buf(), dst.path().to_path_buf()) + .await + .unwrap(); + + // Verify that we still have source and that destination has source's contents + src.assert(predicate::path::is_file()); + 
dst.assert(predicate::path::eq_file(src.path())); + } + + #[tokio::test] + async fn rename_should_fail_if_path_missing() { + let (api, ctx, _rx) = setup(1).await; + let temp = assert_fs::TempDir::new().unwrap(); + let src = temp.child("src"); + let dst = temp.child("dst"); + + let _ = api + .rename(ctx, src.path().to_path_buf(), dst.path().to_path_buf()) + .await + .unwrap_err(); + + // Also, verify that destination does not exist + dst.assert(predicate::path::missing()); + } + + #[tokio::test] + async fn rename_should_support_renaming_an_entire_directory() { + let (api, ctx, _rx) = setup(1).await; + let temp = assert_fs::TempDir::new().unwrap(); + + let src = temp.child("src"); + src.create_dir_all().unwrap(); + let src_file = src.child("file"); + src_file.write_str("some contents").unwrap(); + + let dst = temp.child("dst"); + let dst_file = dst.child("file"); + + api.rename(ctx, src.path().to_path_buf(), dst.path().to_path_buf()) + .await + .unwrap(); + + // Verify that we moved the contents + src.assert(predicate::path::missing()); + src_file.assert(predicate::path::missing()); + dst.assert(predicate::path::is_dir()); + dst_file.assert("some contents"); + } + + #[tokio::test] + async fn rename_should_support_renaming_a_single_file() { + let (api, ctx, _rx) = setup(1).await; + let temp = assert_fs::TempDir::new().unwrap(); + let src = temp.child("src"); + src.write_str("some text").unwrap(); + let dst = temp.child("dst"); + + api.rename(ctx, src.path().to_path_buf(), dst.path().to_path_buf()) + .await + .unwrap(); + + // Verify that we moved the file + src.assert(predicate::path::missing()); + dst.assert("some text"); + } + + /// Validates a response as being a series of changes that include the provided paths + fn validate_changed_paths( + data: &DistantResponseData, + expected_paths: &[PathBuf], + should_panic: bool, + ) -> bool { + match data { + DistantResponseData::Changed(change) if should_panic => { + let paths: Vec = change + .paths + .iter() + .map(|x| 
x.canonicalize().unwrap()) + .collect(); + assert_eq!(paths, expected_paths, "Wrong paths reported: {:?}", change); + + true + } + DistantResponseData::Changed(change) => { + let paths: Vec = change + .paths + .iter() + .map(|x| x.canonicalize().unwrap()) + .collect(); + paths == expected_paths + } + x if should_panic => panic!("Unexpected response: {:?}", x), + _ => false, + } + } + + #[tokio::test] + async fn watch_should_support_watching_a_single_file() { + // NOTE: Supporting multiple replies being sent back as part of creating, modifying, etc. + let (api, ctx, mut rx) = setup(100).await; + let temp = assert_fs::TempDir::new().unwrap(); + + let file = temp.child("file"); + file.touch().unwrap(); + + api.watch( + ctx, + file.path().to_path_buf(), + /* recursive */ false, + /* only */ Default::default(), + /* except */ Default::default(), + ) + .await + .unwrap(); + + // Update the file and verify we get a notification + file.write_str("some text").unwrap(); + + let data = rx + .recv() + .await + .expect("Channel closed before we got change"); + validate_changed_paths( + &data, + &[file.path().to_path_buf().canonicalize().unwrap()], + /* should_panic */ true, + ); + } + + #[tokio::test] + async fn watch_should_support_watching_a_directory_recursively() { + // NOTE: Supporting multiple replies being sent back as part of creating, modifying, etc. 
+ let (api, ctx, mut rx) = setup(100).await; + let temp = assert_fs::TempDir::new().unwrap(); + + let file = temp.child("file"); + file.touch().unwrap(); + + let dir = temp.child("dir"); + dir.create_dir_all().unwrap(); + + api.watch( + ctx, + temp.path().to_path_buf(), + /* recursive */ true, + /* only */ Default::default(), + /* except */ Default::default(), + ) + .await + .unwrap(); + + // Update the file and verify we get a notification + file.write_str("some text").unwrap(); + + // Create a nested file and verify we get a notification + let nested_file = dir.child("nested-file"); + nested_file.write_str("some text").unwrap(); + + // Sleep a bit to give time to get all changes happening + // TODO: Can we slim down this sleep? Or redesign test in some other way? + tokio::time::sleep(Duration::from_millis(100)).await; + + // Collect all responses, as we may get multiple for interactions within a directory + let mut responses = Vec::new(); + while let Ok(res) = rx.try_recv() { + responses.push(res); + } + + // Validate that we have at least one change reported for each of our paths + assert!( + responses.len() >= 2, + "Less than expected total responses: {:?}", + responses + ); + + let path = file.path().to_path_buf(); + assert!( + responses.iter().any(|res| validate_changed_paths( + res, + &[file.path().to_path_buf().canonicalize().unwrap()], + /* should_panic */ false, + )), + "Missing {:?} in {:?}", + path, + responses + .iter() + .map(|x| format!("{:?}", x)) + .collect::<Vec<String>>(), + ); + + let path = nested_file.path().to_path_buf(); + assert!( + responses.iter().any(|res| validate_changed_paths( + res, + &[nested_file.path().to_path_buf().canonicalize().unwrap()], + /* should_panic */ false, + )), + "Missing {:?} in {:?}", + path, + responses + .iter() + .map(|x| format!("{:?}", x)) + .collect::<Vec<String>>(), + ); + } + + #[tokio::test] + async fn watch_should_report_changes_using_the_ctx_replies() { + // NOTE: Supporting multiple replies being sent back as part of creating, 
modifying, etc. + let (api, ctx_1, mut rx_1) = setup(100).await; + let (ctx_2, mut rx_2) = { + let (reply, rx) = make_reply(100); + let ctx = DistantCtx { + connection_id: ctx_1.connection_id, + reply, + local_data: Arc::clone(&ctx_1.local_data), + }; + (ctx, rx) + }; + + let temp = assert_fs::TempDir::new().unwrap(); + + let file_1 = temp.child("file_1"); + file_1.touch().unwrap(); + + let file_2 = temp.child("file_2"); + file_2.touch().unwrap(); + + // Sleep a bit to give time to get all changes happening + // TODO: Can we slim down this sleep? Or redesign test in some other way? + tokio::time::sleep(Duration::from_millis(100)).await; + + // Initialize watch on file 1 + api.watch( + ctx_1, + file_1.path().to_path_buf(), + /* recursive */ false, + /* only */ Default::default(), + /* except */ Default::default(), + ) + .await + .unwrap(); + + // Initialize watch on file 2 + api.watch( + ctx_2, + file_2.path().to_path_buf(), + /* recursive */ false, + /* only */ Default::default(), + /* except */ Default::default(), + ) + .await + .unwrap(); + + // Update the files and verify we get notifications from different origins + file_1.write_str("some text").unwrap(); + let data = rx_1 + .recv() + .await + .expect("Channel closed before we got change"); + validate_changed_paths( + &data, + &[file_1.path().to_path_buf().canonicalize().unwrap()], + /* should_panic */ true, + ); + + // Update the files and verify we get notifications from different origins + file_2.write_str("some text").unwrap(); + let data = rx_2 + .recv() + .await + .expect("Channel closed before we got change"); + validate_changed_paths( + &data, + &[file_2.path().to_path_buf().canonicalize().unwrap()], + /* should_panic */ true, + ); + } + + #[tokio::test] + async fn exists_should_send_true_if_path_exists() { + let (api, ctx, _rx) = setup(1).await; + let temp = assert_fs::TempDir::new().unwrap(); + let file = temp.child("file"); + file.touch().unwrap(); + + let exists = api.exists(ctx, 
file.path().to_path_buf()).await.unwrap(); + assert!(exists, "Expected exists to be true, but was false"); + } + + #[tokio::test] + async fn exists_should_send_false_if_path_does_not_exist() { + let (api, ctx, _rx) = setup(1).await; + let temp = assert_fs::TempDir::new().unwrap(); + let file = temp.child("file"); + + let exists = api.exists(ctx, file.path().to_path_buf()).await.unwrap(); + assert!(!exists, "Expected exists to be false, but was true"); + } + + #[tokio::test] + async fn metadata_should_send_error_on_failure() { + let (api, ctx, _rx) = setup(1).await; + let temp = assert_fs::TempDir::new().unwrap(); + let file = temp.child("file"); + + let _ = api + .metadata( + ctx, + file.path().to_path_buf(), + /* canonicalize */ false, + /* resolve_file_type */ false, + ) + .await + .unwrap_err(); + } + + #[tokio::test] + async fn metadata_should_send_back_metadata_on_file_if_exists() { + let (api, ctx, _rx) = setup(1).await; + let temp = assert_fs::TempDir::new().unwrap(); + let file = temp.child("file"); + file.write_str("some text").unwrap(); + + let metadata = api + .metadata( + ctx, + file.path().to_path_buf(), + /* canonicalize */ false, + /* resolve_file_type */ false, + ) + .await + .unwrap(); + + assert!( + matches!( + metadata, + Metadata { + canonicalized_path: None, + file_type: FileType::File, + len: 9, + readonly: false, + .. + } + ), + "{:?}", + metadata + ); + } + + #[cfg(unix)] + #[tokio::test] + async fn metadata_should_include_unix_specific_metadata_on_unix_platform() { + let (api, ctx, _rx) = setup(1).await; + let temp = assert_fs::TempDir::new().unwrap(); + let file = temp.child("file"); + file.write_str("some text").unwrap(); + + let metadata = api + .metadata( + ctx, + file.path().to_path_buf(), + /* canonicalize */ false, + /* resolve_file_type */ false, + ) + .await + .unwrap(); + + #[allow(clippy::match_single_binding)] + match metadata { + Metadata { unix, windows, .. 
} => { + assert!(unix.is_some(), "Unexpectedly missing unix metadata on unix"); + assert!( + windows.is_none(), + "Unexpectedly got windows metadata on unix" + ); + } + } + } + + #[cfg(windows)] + #[tokio::test] + async fn metadata_should_include_windows_specific_metadata_on_windows_platform() { + let (api, ctx, _rx) = setup(1).await; + let temp = assert_fs::TempDir::new().unwrap(); + let file = temp.child("file"); + file.write_str("some text").unwrap(); + + let metadata = api + .metadata( + ctx, + file.path().to_path_buf(), + /* canonicalize */ false, + /* resolve_file_type */ false, + ) + .await + .unwrap(); + + #[allow(clippy::match_single_binding)] + match metadata { + Metadata { unix, windows, .. } => { + assert!( + windows.is_some(), + "Unexpectedly missing windows metadata on windows" + ); + assert!(unix.is_none(), "Unexpectedly got unix metadata on windows"); + } + } + } + + #[tokio::test] + async fn metadata_should_send_back_metadata_on_dir_if_exists() { + let (api, ctx, _rx) = setup(1).await; + let temp = assert_fs::TempDir::new().unwrap(); + let dir = temp.child("dir"); + dir.create_dir_all().unwrap(); + + let metadata = api + .metadata( + ctx, + dir.path().to_path_buf(), + /* canonicalize */ false, + /* resolve_file_type */ false, + ) + .await + .unwrap(); + + assert!( + matches!( + metadata, + Metadata { + canonicalized_path: None, + file_type: FileType::Dir, + readonly: false, + .. 
+ } + ), + "{:?}", + metadata + ); + } + + #[tokio::test] + async fn metadata_should_send_back_metadata_on_symlink_if_exists() { + let (api, ctx, _rx) = setup(1).await; + let temp = assert_fs::TempDir::new().unwrap(); + let file = temp.child("file"); + file.write_str("some text").unwrap(); + + let symlink = temp.child("link"); + symlink.symlink_to_file(file.path()).unwrap(); + + let metadata = api + .metadata( + ctx, + symlink.path().to_path_buf(), + /* canonicalize */ false, + /* resolve_file_type */ false, + ) + .await + .unwrap(); + + assert!( + matches!( + metadata, + Metadata { + canonicalized_path: None, + file_type: FileType::Symlink, + readonly: false, + .. + } + ), + "{:?}", + metadata + ); + } + + #[tokio::test] + async fn metadata_should_include_canonicalized_path_if_flag_specified() { + let (api, ctx, _rx) = setup(1).await; + let temp = assert_fs::TempDir::new().unwrap(); + let file = temp.child("file"); + file.write_str("some text").unwrap(); + + let symlink = temp.child("link"); + symlink.symlink_to_file(file.path()).unwrap(); + + let metadata = api + .metadata( + ctx, + symlink.path().to_path_buf(), + /* canonicalize */ true, + /* resolve_file_type */ false, + ) + .await + .unwrap(); + + match metadata { + Metadata { + canonicalized_path: Some(path), + file_type: FileType::Symlink, + readonly: false, + .. 
+ } => assert_eq!( + path, + file.path().canonicalize().unwrap(), + "Symlink canonicalized path does not match referenced file" + ), + x => panic!("Unexpected response: {:?}", x), + } + } + + #[tokio::test] + async fn metadata_should_resolve_file_type_of_symlink_if_flag_specified() { + let (api, ctx, _rx) = setup(1).await; + let temp = assert_fs::TempDir::new().unwrap(); + let file = temp.child("file"); + file.write_str("some text").unwrap(); + + let symlink = temp.child("link"); + symlink.symlink_to_file(file.path()).unwrap(); + + let metadata = api + .metadata( + ctx, + symlink.path().to_path_buf(), + /* canonicalize */ false, + /* resolve_file_type */ true, + ) + .await + .unwrap(); + + assert!( + matches!( + metadata, + Metadata { + file_type: FileType::File, + .. + } + ), + "{:?}", + metadata + ); + } + + // NOTE: Ignoring on windows because it's using WSL which wants a Linux path + // with / but thinks it's on windows and is providing \ + #[tokio::test] + #[cfg_attr(windows, ignore)] + async fn proc_spawn_should_send_error_on_failure() { + let (api, ctx, _rx) = setup(1).await; + + let _ = api + .proc_spawn( + ctx, + /* cmd */ DOES_NOT_EXIST_BIN.to_str().unwrap().to_string(), + /* environment */ Environment::new(), + /* current_dir */ None, + /* persist */ false, + /* pty */ None, + ) + .await + .unwrap_err(); + } + + // NOTE: Ignoring on windows because it's using WSL which wants a Linux path + // with / but thinks it's on windows and is providing \ + #[tokio::test] + #[cfg_attr(windows, ignore)] + async fn proc_spawn_should_return_id_of_spawned_process() { + let (api, ctx, _rx) = setup(1).await; + + let id = api + .proc_spawn( + ctx, + /* cmd */ + format!( + "{} {}", + *SCRIPT_RUNNER, + ECHO_ARGS_TO_STDOUT_SH.to_str().unwrap() + ), + /* environment */ Environment::new(), + /* current_dir */ None, + /* persist */ false, + /* pty */ None, + ) + .await + .unwrap(); + assert!(id > 0); + } + + // NOTE: Ignoring on windows because it's using WSL which wants a 
Linux path + // with / but thinks it's on windows and is providing \ + #[tokio::test] + #[cfg_attr(windows, ignore)] + async fn proc_spawn_should_send_back_stdout_periodically_when_available() { + let (api, ctx, mut rx) = setup(1).await; + + let proc_id = api + .proc_spawn( + ctx, + /* cmd */ + format!( + "{} {} some stdout", + *SCRIPT_RUNNER, + ECHO_ARGS_TO_STDOUT_SH.to_str().unwrap() + ), + /* environment */ Environment::new(), + /* current_dir */ None, + /* persist */ false, + /* pty */ None, + ) + .await + .unwrap(); + + // Gather two additional responses: + // + // 1. An indirect response for stdout + // 2. An indirect response that is proc completing + // + // Note that order is not a guarantee, so we have to check that + // we get one of each type of response + let data_1 = rx.recv().await.expect("Missing first response"); + let data_2 = rx.recv().await.expect("Missing second response"); + + let mut got_stdout = false; + let mut got_done = false; + + let mut check_data = |data: &DistantResponseData| match data { + DistantResponseData::ProcStdout { id, data } => { + assert_eq!( + *id, proc_id, + "Got {}, but expected {} as process id", + id, proc_id + ); + assert_eq!(data, b"some stdout", "Got wrong stdout"); + got_stdout = true; + } + DistantResponseData::ProcDone { id, success, .. 
} => { + assert_eq!( + *id, proc_id, + "Got {}, but expected {} as process id", + id, proc_id + ); + assert!(success, "Process should have completed successfully"); + got_done = true; + } + x => panic!("Unexpected response: {:?}", x), + }; + + check_data(&data_1); + check_data(&data_2); + assert!(got_stdout, "Missing stdout response"); + assert!(got_done, "Missing done response"); + } + + // NOTE: Ignoring on windows because it's using WSL which wants a Linux path + // with / but thinks it's on windows and is providing \ + #[tokio::test] + #[cfg_attr(windows, ignore)] + async fn proc_spawn_should_send_back_stderr_periodically_when_available() { + let (api, ctx, mut rx) = setup(1).await; + + let proc_id = api + .proc_spawn( + ctx, + /* cmd */ + format!( + "{} {} some stderr", + *SCRIPT_RUNNER, + ECHO_ARGS_TO_STDERR_SH.to_str().unwrap() + ), + /* environment */ Environment::new(), + /* current_dir */ None, + /* persist */ false, + /* pty */ None, + ) + .await + .unwrap(); + + // Gather two additional responses: + // + // 1. An indirect response for stderr + // 2. An indirect response that is proc completing + // + // Note that order is not a guarantee, so we have to check that + // we get one of each type of response + let data_1 = rx.recv().await.expect("Missing first response"); + let data_2 = rx.recv().await.expect("Missing second response"); + + let mut got_stderr = false; + let mut got_done = false; + + let mut check_data = |data: &DistantResponseData| match data { + DistantResponseData::ProcStderr { id, data } => { + assert_eq!( + *id, proc_id, + "Got {}, but expected {} as process id", + id, proc_id + ); + assert_eq!(data, b"some stderr", "Got wrong stderr"); + got_stderr = true; + } + DistantResponseData::ProcDone { id, success, .. 
} => { + assert_eq!( + *id, proc_id, + "Got {}, but expected {} as process id", + id, proc_id + ); + assert!(success, "Process should have completed successfully"); + got_done = true; + } + x => panic!("Unexpected response: {:?}", x), + }; + + check_data(&data_1); + check_data(&data_2); + assert!(got_stderr, "Missing stderr response"); + assert!(got_done, "Missing done response"); + } + + // NOTE: Ignoring on windows because it's using WSL which wants a Linux path + // with / but thinks it's on windows and is providing \ + #[tokio::test] + #[cfg_attr(windows, ignore)] + async fn proc_spawn_should_send_done_signal_when_completed() { + let (api, ctx, mut rx) = setup(1).await; + + let proc_id = api + .proc_spawn( + ctx, + /* cmd */ + format!("{} {} 0.1", *SCRIPT_RUNNER, SLEEP_SH.to_str().unwrap()), + /* environment */ Environment::new(), + /* current_dir */ None, + /* persist */ false, + /* pty */ None, + ) + .await + .unwrap(); + + // Wait for process to finish + match rx.recv().await.unwrap() { + DistantResponseData::ProcDone { id, .. 
} => assert_eq!( + id, proc_id, + "Got {}, but expected {} as process id", + id, proc_id + ), + x => panic!("Unexpected response: {:?}", x), + } + } + + // NOTE: Ignoring on windows because it's using WSL which wants a Linux path + // with / but thinks it's on windows and is providing \ + #[tokio::test] + #[cfg_attr(windows, ignore)] + async fn proc_spawn_should_clear_process_from_state_when_killed() { + let (api, ctx_1, mut rx) = setup(1).await; + let (ctx_2, _rx) = { + let (reply, rx) = make_reply(1); + let ctx = DistantCtx { + connection_id: ctx_1.connection_id, + reply, + local_data: Arc::clone(&ctx_1.local_data), + }; + (ctx, rx) + }; + + let proc_id = api + .proc_spawn( + ctx_1, + /* cmd */ + format!("{} {} 1", *SCRIPT_RUNNER, SLEEP_SH.to_str().unwrap()), + /* environment */ Environment::new(), + /* current_dir */ None, + /* persist */ false, + /* pty */ None, + ) + .await + .unwrap(); + + // Send kill signal + api.proc_kill(ctx_2, proc_id).await.unwrap(); + + // Wait for the completion response to come in + match rx.recv().await.unwrap() { + DistantResponseData::ProcDone { id, .. 
} => assert_eq!( + id, proc_id, + "Got {}, but expected {} as process id", + id, proc_id + ), + x => panic!("Unexpected response: {:?}", x), + } + } + + #[tokio::test] + async fn proc_kill_should_fail_if_given_non_existent_process() { + let (api, ctx, _rx) = setup(1).await; + + // Send kill to a non-existent process + let _ = api.proc_kill(ctx, 0xDEADBEEF).await.unwrap_err(); + } + + #[tokio::test] + async fn proc_stdin_should_fail_if_given_non_existent_process() { + let (api, ctx, _rx) = setup(1).await; + + // Send stdin to a non-existent process + let _ = api + .proc_stdin(ctx, 0xDEADBEEF, b"some input".to_vec()) + .await + .unwrap_err(); + } + + // NOTE: Ignoring on windows because it's using WSL which wants a Linux path + // with / but thinks it's on windows and is providing \ + #[tokio::test] + #[cfg_attr(windows, ignore)] + async fn proc_stdin_should_send_stdin_to_process() { + let (api, ctx_1, mut rx) = setup(1).await; + let (ctx_2, _rx) = { + let (reply, rx) = make_reply(1); + let ctx = DistantCtx { + connection_id: ctx_1.connection_id, + reply, + local_data: Arc::clone(&ctx_1.local_data), + }; + (ctx, rx) + }; + + // First, run a program that listens for stdin + let id = api + .proc_spawn( + ctx_1, + /* cmd */ + format!( + "{} {}", + *SCRIPT_RUNNER, + ECHO_STDIN_TO_STDOUT_SH.to_str().unwrap() + ), + Environment::new(), + /* current_dir */ None, + /* persist */ false, + /* pty */ None, + ) + .await + .unwrap(); + + // Second, send stdin to the remote process + api.proc_stdin(ctx_2, id, b"hello world\n".to_vec()) + .await + .unwrap(); + + // Third, check the async response of stdout to verify we got stdin + match rx.recv().await.unwrap() { + DistantResponseData::ProcStdout { data, .. 
} => { + assert_eq!(data, b"hello world\n", "Mirrored data didn't match"); + } + x => panic!("Unexpected response: {:?}", x), + } + } + + #[tokio::test] + async fn system_info_should_return_system_info_based_on_binary() { + let (api, ctx, _rx) = setup(1).await; + + let system_info = api.system_info(ctx).await.unwrap(); + assert_eq!( + system_info, + SystemInfo { + family: std::env::consts::FAMILY.to_string(), + os: std::env::consts::OS.to_string(), + arch: std::env::consts::ARCH.to_string(), + current_dir: std::env::current_dir().unwrap_or_default(), + main_separator: std::path::MAIN_SEPARATOR, + } + ); + } +} diff --git a/distant-core/src/server/distant/process/mod.rs b/distant-core/src/api/local/process.rs similarity index 98% rename from distant-core/src/server/distant/process/mod.rs rename to distant-core/src/api/local/process.rs index 6a4b50e..562b6d4 100644 --- a/distant-core/src/server/distant/process/mod.rs +++ b/distant-core/src/api/local/process.rs @@ -1,4 +1,4 @@ -use crate::data::PtySize; +use crate::data::{ProcessId, PtySize}; use std::{future::Future, pin::Pin}; use tokio::{io, sync::mpsc}; @@ -17,7 +17,7 @@ pub type FutureReturn<'a, T> = Pin + Send + 'a>>; /// Represents a process on the remote server pub trait Process: ProcessKiller + ProcessPty { /// Represents the id of the process - fn id(&self) -> usize; + fn id(&self) -> ProcessId; /// Waits for the process to exit, returning the exit status /// diff --git a/distant-core/src/server/distant/process/pty.rs b/distant-core/src/api/local/process/pty.rs similarity index 93% rename from distant-core/src/server/distant/process/pty.rs rename to distant-core/src/api/local/process/pty.rs index 19ebe06..ff25f14 100644 --- a/distant-core/src/server/distant/process/pty.rs +++ b/distant-core/src/api/local/process/pty.rs @@ -1,19 +1,23 @@ use super::{ - wait, ExitStatus, FutureReturn, InputChannel, OutputChannel, Process, ProcessKiller, + wait, ExitStatus, FutureReturn, InputChannel, OutputChannel, Process, 
ProcessId, ProcessKiller, ProcessPty, PtySize, WaitRx, }; -use crate::constants::{MAX_PIPE_CHUNK_SIZE, READ_PAUSE_MILLIS}; +use crate::{ + constants::{MAX_PIPE_CHUNK_SIZE, READ_PAUSE_MILLIS}, + data::Environment, +}; use portable_pty::{CommandBuilder, MasterPty, PtySize as PortablePtySize}; use std::{ ffi::OsStr, io::{self, Read, Write}, + path::PathBuf, sync::{Arc, Mutex}, }; use tokio::{sync::mpsc, task::JoinHandle}; /// Represents a process that is associated with a pty pub struct PtyProcess { - id: usize, + id: ProcessId, pty_master: PtyProcessMaster, stdin: Option>, stdout: Option>, @@ -25,7 +29,13 @@ pub struct PtyProcess { impl PtyProcess { /// Spawns a new simple process - pub fn spawn(program: S, args: I, size: PtySize) -> io::Result + pub fn spawn( + program: S, + args: I, + environment: Environment, + current_dir: Option, + size: PtySize, + ) -> io::Result where S: AsRef, I: IntoIterator, @@ -47,6 +57,12 @@ impl PtyProcess { // Spawn our process within the pty let mut cmd = CommandBuilder::new(program); cmd.args(args); + if let Some(path) = current_dir { + cmd.cwd(path); + } + for (key, value) in environment { + cmd.env(key, value); + } let mut child = pty_slave .spawn_command(cmd) .map_err(|x| io::Error::new(io::ErrorKind::Other, x))?; @@ -78,7 +94,7 @@ impl PtyProcess { loop { match stdout_reader.read(&mut buf) { Ok(n) if n > 0 => { - let _ = stdout_tx.blocking_send(buf[..n].to_vec()).map_err(|_| { + stdout_tx.blocking_send(buf[..n].to_vec()).map_err(|_| { io::Error::new(io::ErrorKind::BrokenPipe, "Output channel closed") })?; } @@ -137,7 +153,7 @@ impl PtyProcess { } impl Process for PtyProcess { - fn id(&self) -> usize { + fn id(&self) -> ProcessId { self.id } diff --git a/distant-core/src/server/distant/process/simple.rs b/distant-core/src/api/local/process/simple.rs similarity index 86% rename from distant-core/src/server/distant/process/simple.rs rename to distant-core/src/api/local/process/simple.rs index afd61be..07a8f69 100644 --- 
a/distant-core/src/server/distant/process/simple.rs +++ b/distant-core/src/api/local/process/simple.rs @@ -1,15 +1,16 @@ use super::{ - wait, ExitStatus, FutureReturn, InputChannel, NoProcessPty, OutputChannel, Process, + wait, ExitStatus, FutureReturn, InputChannel, NoProcessPty, OutputChannel, Process, ProcessId, ProcessKiller, WaitRx, }; -use std::{ffi::OsStr, process::Stdio}; +use crate::data::Environment; +use std::{ffi::OsStr, path::PathBuf, process::Stdio}; use tokio::{io, process::Command, sync::mpsc, task::JoinHandle}; mod tasks; /// Represents a simple process that does not have a pty pub struct SimpleProcess { - id: usize, + id: ProcessId, stdin: Option>, stdout: Option>, stderr: Option>, @@ -22,18 +23,32 @@ pub struct SimpleProcess { impl SimpleProcess { /// Spawns a new simple process - pub fn spawn(program: S, args: I) -> io::Result + pub fn spawn( + program: S, + args: I, + environment: Environment, + current_dir: Option, + ) -> io::Result where S: AsRef, I: IntoIterator, S2: AsRef, { - let mut child = Command::new(program) - .args(args) - .stdin(Stdio::piped()) - .stdout(Stdio::piped()) - .stderr(Stdio::piped()) - .spawn()?; + let mut child = { + let mut command = Command::new(program); + + if let Some(path) = current_dir { + command.current_dir(path); + } + + command + .envs(environment) + .args(args) + .stdin(Stdio::piped()) + .stdout(Stdio::piped()) + .stderr(Stdio::piped()) + .spawn()? 
+ }; let stdout = child.stdout.take().unwrap(); let (stdout_task, stdout_ch) = tasks::spawn_read_task(stdout, 1); @@ -80,7 +95,7 @@ impl SimpleProcess { } impl Process for SimpleProcess { - fn id(&self) -> usize { + fn id(&self) -> ProcessId { self.id } diff --git a/distant-core/src/server/distant/process/simple/tasks.rs b/distant-core/src/api/local/process/simple/tasks.rs similarity index 91% rename from distant-core/src/server/distant/process/simple/tasks.rs rename to distant-core/src/api/local/process/simple/tasks.rs index 4d253af..7d8f8ae 100644 --- a/distant-core/src/server/distant/process/simple/tasks.rs +++ b/distant-core/src/api/local/process/simple/tasks.rs @@ -1,6 +1,7 @@ use crate::constants::{MAX_PIPE_CHUNK_SIZE, READ_PAUSE_MILLIS}; +use std::io; use tokio::{ - io::{self, AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}, + io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}, sync::mpsc, task::JoinHandle, }; @@ -27,7 +28,7 @@ where loop { match reader.read(&mut buf).await { Ok(n) if n > 0 => { - let _ = channel.send(buf[..n].to_vec()).await.map_err(|_| { + channel.send(buf[..n].to_vec()).await.map_err(|_| { io::Error::new(io::ErrorKind::BrokenPipe, "Output channel closed") })?; @@ -65,7 +66,7 @@ where W: AsyncWrite + Unpin, { while let Some(data) = channel.recv().await { - let _ = writer.write_all(&data).await?; + writer.write_all(&data).await?; } Ok(()) } diff --git a/distant-core/src/server/distant/process/wait.rs b/distant-core/src/api/local/process/wait.rs similarity index 100% rename from distant-core/src/server/distant/process/wait.rs rename to distant-core/src/api/local/process/wait.rs diff --git a/distant-core/src/api/local/state.rs b/distant-core/src/api/local/state.rs new file mode 100644 index 0000000..e48fb7f --- /dev/null +++ b/distant-core/src/api/local/state.rs @@ -0,0 +1,68 @@ +use crate::{data::ProcessId, ConnectionId}; +use std::{io, path::PathBuf}; + +mod process; +pub use process::*; + +mod watcher; +pub use watcher::*; + +/// 
Holds global state managed by the server +pub struct GlobalState { + /// State that holds information about processes running on the server + pub process: ProcessState, + + /// Watcher used for filesystem events + pub watcher: WatcherState, +} + +impl GlobalState { + pub fn initialize() -> io::Result<Self> { + Ok(Self { + process: ProcessState::new(), + watcher: WatcherState::initialize()?, + }) + } +} + +/// Holds connection-specific state managed by the server +#[derive(Default)] +pub struct ConnectionState { + /// Unique id associated with connection + id: ConnectionId, + + /// Channel connected to global process state + pub(crate) process_channel: ProcessChannel, + + /// Channel connected to global watcher state + pub(crate) watcher_channel: WatcherChannel, + + /// Contains ids of processes that will be terminated when the connection is closed + processes: Vec<ProcessId>, + + /// Contains paths being watched that will be unwatched when the connection is closed + paths: Vec<PathBuf>, +} + +impl Drop for ConnectionState { + fn drop(&mut self) { + let id = self.id; + let processes: Vec<ProcessId> = self.processes.drain(..).collect(); + let paths: Vec<PathBuf> = self.paths.drain(..).collect(); + + let process_channel = self.process_channel.clone(); + let watcher_channel = self.watcher_channel.clone(); + + // NOTE: We cannot (and should not) block during drop to perform cleanup, + // instead spawning a task that will do the cleanup async + tokio::spawn(async move { + for id in processes { + let _ = process_channel.kill(id).await; + } + + for path in paths { + let _ = watcher_channel.unwatch(id, path).await; + } + }); + } +} diff --git a/distant-core/src/api/local/state/process.rs b/distant-core/src/api/local/state/process.rs new file mode 100644 index 0000000..1fe3319 --- /dev/null +++ b/distant-core/src/api/local/state/process.rs @@ -0,0 +1,231 @@ +use crate::data::{DistantResponseData, Environment, ProcessId, PtySize}; +use distant_net::Reply; +use std::{collections::HashMap, io, ops::Deref, 
path::PathBuf}; +use tokio::{ + sync::{mpsc, oneshot}, + task::JoinHandle, +}; + +mod instance; +pub use instance::*; + +/// Holds information related to spawned processes on the server +pub struct ProcessState { + channel: ProcessChannel, + task: JoinHandle<()>, +} + +impl Drop for ProcessState { + /// Aborts the task that handles process operations and management + fn drop(&mut self) { + self.abort(); + } +} + +impl ProcessState { + pub fn new() -> Self { + let (tx, rx) = mpsc::channel(1); + let task = tokio::spawn(process_task(tx.clone(), rx)); + + Self { + channel: ProcessChannel { tx }, + task, + } + } + + pub fn clone_channel(&self) -> ProcessChannel { + self.channel.clone() + } + + /// Aborts the process task + pub fn abort(&self) { + self.task.abort(); + } +} + +impl Deref for ProcessState { + type Target = ProcessChannel; + + fn deref(&self) -> &Self::Target { + &self.channel + } +} + +#[derive(Clone)] +pub struct ProcessChannel { + tx: mpsc::Sender, +} + +impl Default for ProcessChannel { + /// Creates a new channel that is closed by default + fn default() -> Self { + let (tx, _) = mpsc::channel(1); + Self { tx } + } +} + +impl ProcessChannel { + /// Spawns a new process, returning the id associated with it + pub async fn spawn( + &self, + cmd: String, + environment: Environment, + current_dir: Option, + persist: bool, + pty: Option, + reply: Box>, + ) -> io::Result { + let (cb, rx) = oneshot::channel(); + self.tx + .send(InnerProcessMsg::Spawn { + cmd, + environment, + current_dir, + persist, + pty, + reply, + cb, + }) + .await + .map_err(|_| io::Error::new(io::ErrorKind::Other, "Internal process task closed"))?; + rx.await + .map_err(|_| io::Error::new(io::ErrorKind::Other, "Response to spawn dropped"))? 
+ } + + /// Resizes the pty of a running process + pub async fn resize_pty(&self, id: ProcessId, size: PtySize) -> io::Result<()> { + let (cb, rx) = oneshot::channel(); + self.tx + .send(InnerProcessMsg::Resize { id, size, cb }) + .await + .map_err(|_| io::Error::new(io::ErrorKind::Other, "Internal process task closed"))?; + rx.await + .map_err(|_| io::Error::new(io::ErrorKind::Other, "Response to resize dropped"))? + } + + /// Send stdin to a running process + pub async fn send_stdin(&self, id: ProcessId, data: Vec) -> io::Result<()> { + let (cb, rx) = oneshot::channel(); + self.tx + .send(InnerProcessMsg::Stdin { id, data, cb }) + .await + .map_err(|_| io::Error::new(io::ErrorKind::Other, "Internal process task closed"))?; + rx.await + .map_err(|_| io::Error::new(io::ErrorKind::Other, "Response to stdin dropped"))? + } + + /// Kills a running process + pub async fn kill(&self, id: ProcessId) -> io::Result<()> { + let (cb, rx) = oneshot::channel(); + self.tx + .send(InnerProcessMsg::Kill { id, cb }) + .await + .map_err(|_| io::Error::new(io::ErrorKind::Other, "Internal process task closed"))?; + rx.await + .map_err(|_| io::Error::new(io::ErrorKind::Other, "Response to kill dropped"))? 
+ } +} + +/// Internal message to pass to our task below to perform some action +enum InnerProcessMsg { + Spawn { + cmd: String, + environment: Environment, + current_dir: Option, + persist: bool, + pty: Option, + reply: Box>, + cb: oneshot::Sender>, + }, + Resize { + id: ProcessId, + size: PtySize, + cb: oneshot::Sender>, + }, + Stdin { + id: ProcessId, + data: Vec, + cb: oneshot::Sender>, + }, + Kill { + id: ProcessId, + cb: oneshot::Sender>, + }, + InternalRemove { + id: ProcessId, + }, +} + +async fn process_task(tx: mpsc::Sender, mut rx: mpsc::Receiver) { + let mut processes: HashMap = HashMap::new(); + + while let Some(msg) = rx.recv().await { + match msg { + InnerProcessMsg::Spawn { + cmd, + environment, + current_dir, + persist, + pty, + reply, + cb, + } => { + let _ = cb.send( + match ProcessInstance::spawn(cmd, environment, current_dir, persist, pty, reply) + { + Ok(mut process) => { + let id = process.id; + + // Attach a callback for when the process is finished where + // we will remove it from our above list + let tx = tx.clone(); + process.on_done(move |_| async move { + let _ = tx.send(InnerProcessMsg::InternalRemove { id }).await; + }); + + processes.insert(id, process); + Ok(id) + } + Err(x) => Err(x), + }, + ); + } + InnerProcessMsg::Resize { id, size, cb } => { + let _ = cb.send(match processes.get(&id) { + Some(process) => process.pty.resize_pty(size), + None => Err(io::Error::new( + io::ErrorKind::Other, + format!("No process found with id {}", id), + )), + }); + } + InnerProcessMsg::Stdin { id, data, cb } => { + let _ = cb.send(match processes.get_mut(&id) { + Some(process) => match process.stdin.as_mut() { + Some(stdin) => stdin.send(&data).await, + None => Err(io::Error::new( + io::ErrorKind::Other, + format!("Process {} stdin is closed", id), + )), + }, + None => Err(io::Error::new( + io::ErrorKind::Other, + format!("No process found with id {}", id), + )), + }); + } + InnerProcessMsg::Kill { id, cb } => { + let _ = cb.send(match 
processes.get_mut(&id) { + Some(process) => process.killer.kill().await, + None => Err(io::Error::new( + io::ErrorKind::Other, + format!("No process found with id {}", id), + )), + }); + } + InnerProcessMsg::InternalRemove { id } => { + processes.remove(&id); + } + } + } +} diff --git a/distant-core/src/api/local/state/process/instance.rs b/distant-core/src/api/local/state/process/instance.rs new file mode 100644 index 0000000..4b558cf --- /dev/null +++ b/distant-core/src/api/local/state/process/instance.rs @@ -0,0 +1,230 @@ +use crate::{ + api::local::process::{ + InputChannel, OutputChannel, Process, ProcessKiller, ProcessPty, PtyProcess, SimpleProcess, + }, + data::{DistantResponseData, Environment, ProcessId, PtySize}, +}; +use distant_net::Reply; +use log::*; +use std::{future::Future, io, path::PathBuf}; +use tokio::task::JoinHandle; + +/// Holds information related to a spawned process on the server +pub struct ProcessInstance { + pub cmd: String, + pub args: Vec, + pub persist: bool, + + pub id: ProcessId, + pub stdin: Option>, + pub killer: Box, + pub pty: Box, + + stdout_task: Option>>, + stderr_task: Option>>, + wait_task: Option>>, +} + +impl Drop for ProcessInstance { + /// Closes stdin and attempts to kill the process when dropped + fn drop(&mut self) { + // Drop stdin first to close it + self.stdin = None; + + // Clear out our tasks if we still have them + let stdout_task = self.stdout_task.take(); + let stderr_task = self.stderr_task.take(); + let wait_task = self.wait_task.take(); + + // Attempt to kill the process, which is an async operation that we + // will spawn a task to handle + let id = self.id; + let mut killer = self.killer.clone_killer(); + tokio::spawn(async move { + if let Err(x) = killer.kill().await { + error!("Failed to kill process {} when dropped: {}", id, x); + + if let Some(task) = stdout_task.as_ref() { + task.abort(); + } + if let Some(task) = stderr_task.as_ref() { + task.abort(); + } + if let Some(task) = wait_task.as_ref() 
{ + task.abort(); + } + } + }); + } +} + +impl ProcessInstance { + pub fn spawn( + cmd: String, + environment: Environment, + current_dir: Option, + persist: bool, + pty: Option, + reply: Box>, + ) -> io::Result { + // Build out the command and args from our string + let mut cmd_and_args = if cfg!(windows) { + winsplit::split(&cmd) + } else { + shell_words::split(&cmd).map_err(|x| io::Error::new(io::ErrorKind::InvalidInput, x))? + }; + + if cmd_and_args.is_empty() { + return Err(io::Error::new( + io::ErrorKind::InvalidInput, + "Command was empty", + )); + } + + // Split command from arguments, where arguments could be empty + let args = cmd_and_args.split_off(1); + let cmd = cmd_and_args.into_iter().next().unwrap(); + + let mut child: Box = match pty { + Some(size) => Box::new(PtyProcess::spawn( + cmd.clone(), + args.clone(), + environment, + current_dir, + size, + )?), + None => Box::new(SimpleProcess::spawn( + cmd.clone(), + args.clone(), + environment, + current_dir, + )?), + }; + + let id = child.id(); + let stdin = child.take_stdin(); + let stdout = child.take_stdout(); + let stderr = child.take_stderr(); + let killer = child.clone_killer(); + let pty = child.clone_pty(); + + // Spawn a task that sends stdout as a response + let stdout_task = match stdout { + Some(stdout) => { + let reply = reply.clone_reply(); + let task = tokio::spawn(stdout_task(id, stdout, reply)); + Some(task) + } + None => None, + }; + + // Spawn a task that sends stderr as a response + let stderr_task = match stderr { + Some(stderr) => { + let reply = reply.clone_reply(); + let task = tokio::spawn(stderr_task(id, stderr, reply)); + Some(task) + } + None => None, + }; + + // Spawn a task that waits on the process to exit but can also + // kill the process when triggered + let wait_task = Some(tokio::spawn(wait_task(id, child, reply))); + + Ok(ProcessInstance { + cmd, + args, + persist, + id, + stdin, + killer, + pty, + stdout_task, + stderr_task, + wait_task, + }) + } + + /// Invokes the 
function once the process has completed + /// + /// NOTE: Can only be used with one function. All future calls + /// will do nothing + pub fn on_done(&mut self, f: F) + where + F: FnOnce(io::Result<()>) -> R + Send + 'static, + R: Future + Send, + { + if let Some(task) = self.wait_task.take() { + tokio::spawn(async move { + f(task + .await + .unwrap_or_else(|x| Err(io::Error::new(io::ErrorKind::Other, x)))) + .await + }); + } + } +} + +async fn stdout_task( + id: ProcessId, + mut stdout: Box, + reply: Box>, +) -> io::Result<()> { + loop { + match stdout.recv().await { + Ok(Some(data)) => { + if let Err(x) = reply + .send(DistantResponseData::ProcStdout { id, data }) + .await + { + return Err(x); + } + } + Ok(None) => return Ok(()), + Err(x) => return Err(x), + } + } +} + +async fn stderr_task( + id: ProcessId, + mut stderr: Box, + reply: Box>, +) -> io::Result<()> { + loop { + match stderr.recv().await { + Ok(Some(data)) => { + if let Err(x) = reply + .send(DistantResponseData::ProcStderr { id, data }) + .await + { + return Err(x); + } + } + Ok(None) => return Ok(()), + Err(x) => return Err(x), + } + } +} + +async fn wait_task( + id: ProcessId, + mut child: Box, + reply: Box>, +) -> io::Result<()> { + let status = child.wait().await; + + match status { + Ok(status) => { + reply + .send(DistantResponseData::ProcDone { + id, + success: status.success, + code: status.code, + }) + .await + } + Err(x) => reply.send(DistantResponseData::from(x)).await, + } +} diff --git a/distant-core/src/api/local/state/watcher.rs b/distant-core/src/api/local/state/watcher.rs new file mode 100644 index 0000000..5f490c1 --- /dev/null +++ b/distant-core/src/api/local/state/watcher.rs @@ -0,0 +1,286 @@ +use crate::{constants::SERVER_WATCHER_CAPACITY, data::ChangeKind, ConnectionId}; +use log::*; +use notify::{ + Config as WatcherConfig, Error as WatcherError, Event as WatcherEvent, RecommendedWatcher, + RecursiveMode, Watcher, +}; +use std::{ + collections::HashMap, + io, + ops::Deref, + 
path::{Path, PathBuf}, +}; +use tokio::{ + sync::{ + mpsc::{self, error::TrySendError}, + oneshot, + }, + task::JoinHandle, +}; + +mod path; +pub use path::*; + +/// Holds information related to watched paths on the server +pub struct WatcherState { + channel: WatcherChannel, + task: JoinHandle<()>, +} + +impl Drop for WatcherState { + /// Aborts the task that handles watcher path operations and management + fn drop(&mut self) { + self.abort(); + } +} + +impl WatcherState { + /// Will create a watcher and initialize watched paths to be empty + pub fn initialize() -> io::Result { + // NOTE: Cannot be something small like 1 as this seems to cause a deadlock sometimes + // with a large volume of watch requests + let (tx, rx) = mpsc::channel(SERVER_WATCHER_CAPACITY); + + let mut watcher = { + let tx = tx.clone(); + notify::recommended_watcher(move |res| { + match tx.try_send(match res { + Ok(x) => InnerWatcherMsg::Event { ev: x }, + Err(x) => InnerWatcherMsg::Error { err: x }, + }) { + Ok(_) => (), + Err(TrySendError::Full(_)) => { + warn!( + "Reached watcher capacity of {}! Dropping watcher event!", + SERVER_WATCHER_CAPACITY, + ); + } + Err(TrySendError::Closed(_)) => { + warn!("Skipping watch event because watcher channel closed"); + } + } + }) + .map_err(|x| io::Error::new(io::ErrorKind::Other, x))? 
+ }; + + // Attempt to configure watcher, but don't fail if these configurations fail + match watcher.configure(WatcherConfig::PreciseEvents(true)) { + Ok(true) => debug!("Watcher configured for precise events"), + Ok(false) => debug!("Watcher not configured for precise events",), + Err(x) => error!("Watcher configuration for precise events failed: {}", x), + } + + // Attempt to configure watcher, but don't fail if these configurations fail + match watcher.configure(WatcherConfig::NoticeEvents(true)) { + Ok(true) => debug!("Watcher configured for notice events"), + Ok(false) => debug!("Watcher not configured for notice events",), + Err(x) => error!("Watcher configuration for notice events failed: {}", x), + } + + Ok(Self { + channel: WatcherChannel { tx }, + task: tokio::spawn(watcher_task(watcher, rx)), + }) + } + + pub fn clone_channel(&self) -> WatcherChannel { + self.channel.clone() + } + + /// Aborts the watcher task + pub fn abort(&self) { + self.task.abort(); + } +} + +impl Deref for WatcherState { + type Target = WatcherChannel; + + fn deref(&self) -> &Self::Target { + &self.channel + } +} + +#[derive(Clone)] +pub struct WatcherChannel { + tx: mpsc::Sender, +} + +impl Default for WatcherChannel { + /// Creates a new channel that is closed by default + fn default() -> Self { + let (tx, _) = mpsc::channel(1); + Self { tx } + } +} + +impl WatcherChannel { + /// Watch a path for a specific connection denoted by the id within the registered path + pub async fn watch(&self, registered_path: RegisteredPath) -> io::Result<()> { + let (cb, rx) = oneshot::channel(); + self.tx + .send(InnerWatcherMsg::Watch { + registered_path, + cb, + }) + .await + .map_err(|_| io::Error::new(io::ErrorKind::Other, "Internal watcher task closed"))?; + rx.await + .map_err(|_| io::Error::new(io::ErrorKind::Other, "Response to watch dropped"))? 
+ } + + /// Unwatch a path for a specific connection denoted by the id + pub async fn unwatch(&self, id: ConnectionId, path: impl AsRef) -> io::Result<()> { + let (cb, rx) = oneshot::channel(); + let path = tokio::fs::canonicalize(path.as_ref()) + .await + .unwrap_or_else(|_| path.as_ref().to_path_buf()); + self.tx + .send(InnerWatcherMsg::Unwatch { id, path, cb }) + .await + .map_err(|_| io::Error::new(io::ErrorKind::Other, "Internal watcher task closed"))?; + rx.await + .map_err(|_| io::Error::new(io::ErrorKind::Other, "Response to unwatch dropped"))? + } +} + +/// Internal message to pass to our task below to perform some action +enum InnerWatcherMsg { + Watch { + registered_path: RegisteredPath, + cb: oneshot::Sender>, + }, + Unwatch { + id: ConnectionId, + path: PathBuf, + cb: oneshot::Sender>, + }, + Event { + ev: WatcherEvent, + }, + Error { + err: WatcherError, + }, +} + +async fn watcher_task(mut watcher: RecommendedWatcher, mut rx: mpsc::Receiver) { + // TODO: Optimize this in some way to be more performant than + // checking every path whenever an event comes in + let mut registered_paths: Vec = Vec::new(); + let mut path_cnt: HashMap = HashMap::new(); + + while let Some(msg) = rx.recv().await { + match msg { + InnerWatcherMsg::Watch { + registered_path, + cb, + } => { + // Check if we are tracking the path across any connection + if let Some(cnt) = path_cnt.get_mut(registered_path.path()) { + // Increment the count of times we are watching that path + *cnt += 1; + + // Store the registered path in our collection without worry + // since we are already watching a path that impacts this one + registered_paths.push(registered_path); + + // Send an okay because we always succeed in this case + let _ = cb.send(Ok(())); + } else { + let res = watcher + .watch( + registered_path.path(), + if registered_path.is_recursive() { + RecursiveMode::Recursive + } else { + RecursiveMode::NonRecursive + }, + ) + .map_err(|x| io::Error::new(io::ErrorKind::Other, x)); + + 
// If we succeeded, store our registered path and set the tracking cnt to 1 + if res.is_ok() { + path_cnt.insert(registered_path.path().to_path_buf(), 1); + registered_paths.push(registered_path); + } + + // Send the result of the watch, but don't worry if the channel was closed + let _ = cb.send(res); + } + } + InnerWatcherMsg::Unwatch { id, path, cb } => { + // Check if we are tracking the path across any connection + if let Some(cnt) = path_cnt.get(path.as_path()) { + // Cycle through and remove all paths that match the given id and path, + // capturing how many paths we removed + let removed_cnt = { + let old_len = registered_paths.len(); + registered_paths + .retain(|p| p.id() != id || (p.path() != path && p.raw_path() != path)); + let new_len = registered_paths.len(); + old_len - new_len + }; + + // 1. If we are now at zero cnt for our path, we want to actually unwatch the + // path with our watcher + // 2. If we removed nothing from our path list, we want to return an error + // 3. 
Otherwise, we return okay because we succeeded + if *cnt <= removed_cnt { + let _ = cb.send( + watcher + .unwatch(&path) + .map_err(|x| io::Error::new(io::ErrorKind::Other, x)), + ); + } else if removed_cnt == 0 { + // Send a failure as there was nothing to unwatch for this connection + let _ = cb.send(Err(io::Error::new( + io::ErrorKind::Other, + format!("{:?} is not being watched", path), + ))); + } else { + // Send a success as we removed some paths + let _ = cb.send(Ok(())); + } + } else { + // Send a failure as there was nothing to unwatch + let _ = cb.send(Err(io::Error::new( + io::ErrorKind::Other, + format!("{:?} is not being watched", path), + ))); + } + } + InnerWatcherMsg::Event { ev } => { + let kind = ChangeKind::from(ev.kind); + + for registered_path in registered_paths.iter() { + match registered_path.filter_and_send(kind, &ev.paths).await { + Ok(_) => (), + Err(x) => error!( + "[Conn {}] Failed to forward changes to paths: {}", + registered_path.id(), + x + ), + } + } + } + InnerWatcherMsg::Error { err } => { + let msg = err.to_string(); + error!("Watcher encountered an error {} for {:?}", msg, err.paths); + + for registered_path in registered_paths.iter() { + match registered_path + .filter_and_send_error(&msg, &err.paths, !err.paths.is_empty()) + .await + { + Ok(_) => (), + Err(x) => error!( + "[Conn {}] Failed to forward changes to paths: {}", + registered_path.id(), + x + ), + } + } + } + } + } +} diff --git a/distant-core/src/api/local/state/watcher/path.rs b/distant-core/src/api/local/state/watcher/path.rs new file mode 100644 index 0000000..ab6b6ad --- /dev/null +++ b/distant-core/src/api/local/state/watcher/path.rs @@ -0,0 +1,212 @@ +use crate::{ + data::{Change, ChangeKind, ChangeKindSet, DistantResponseData, Error}, + ConnectionId, +}; +use distant_net::Reply; +use std::{ + fmt, + hash::{Hash, Hasher}, + io, + path::{Path, PathBuf}, +}; + +/// Represents a path registered with a watcher that includes relevant state including +/// the 
ability to reply with +pub struct RegisteredPath { + /// Unique id tied to the path to distinguish it + id: ConnectionId, + + /// The raw path provided to the watcher, which is not canonicalized + raw_path: PathBuf, + + /// The canonicalized path at the time of providing to the watcher, + /// as all paths must exist for a watcher, we use this to get the + /// source of truth when watching + path: PathBuf, + + /// Whether or not the path was set to be recursive + recursive: bool, + + /// Specific filter for path (only the allowed change kinds are tracked) + /// NOTE: This is a combination of only and except filters + allowed: ChangeKindSet, + + /// Used to send a reply through the connection watching this path + reply: Box>, +} + +impl fmt::Debug for RegisteredPath { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("RegisteredPath") + .field("raw_path", &self.raw_path) + .field("path", &self.path) + .field("recursive", &self.recursive) + .field("allowed", &self.allowed) + .finish() + } +} + +impl PartialEq for RegisteredPath { + /// Checks for equality using the id, canonicalized path, and allowed change kinds + fn eq(&self, other: &Self) -> bool { + self.id == other.id && self.path == other.path && self.allowed == other.allowed + } +} + +impl Eq for RegisteredPath {} + +impl Hash for RegisteredPath { + /// Hashes using the id, canonicalized path, and allowed change kinds + fn hash(&self, state: &mut H) { + self.id.hash(state); + self.path.hash(state); + self.allowed.hash(state); + } +} + +impl RegisteredPath { + /// Registers a new path to be watched (does not actually do any watching) + pub async fn register( + id: ConnectionId, + path: impl Into, + recursive: bool, + only: impl Into, + except: impl Into, + reply: Box>, + ) -> io::Result { + let raw_path = path.into(); + let path = tokio::fs::canonicalize(raw_path.as_path()).await?; + let only = only.into(); + let except = except.into(); + + // Calculate the true list of kinds based on 
only and except filters + let allowed = if only.is_empty() { + ChangeKindSet::all() - except + } else { + only - except + }; + + Ok(Self { + id, + raw_path, + path, + recursive, + allowed, + reply, + }) + } + + /// Represents a unique id to distinguish this path from other registrations + /// of the same path + pub fn id(&self) -> ConnectionId { + self.id + } + + /// Represents the path provided during registration before canonicalization + pub fn raw_path(&self) -> &Path { + self.raw_path.as_path() + } + + /// Represents the canonicalized path used by watchers + pub fn path(&self) -> &Path { + self.path.as_path() + } + + /// Returns true if this path represents a recursive watcher path + pub fn is_recursive(&self) -> bool { + self.recursive + } + + /// Returns reference to set of [`ChangeKind`] that this path watches + pub fn allowed(&self) -> &ChangeKindSet { + &self.allowed + } + + /// Sends a reply for a change tied to this registered path, filtering + /// out any paths that are not applicable + /// + /// Returns true if message was sent, and false if not + pub async fn filter_and_send(&self, kind: ChangeKind, paths: T) -> io::Result + where + T: IntoIterator, + T::Item: AsRef, + { + if !self.allowed().contains(&kind) { + return Ok(false); + } + + let paths: Vec = paths + .into_iter() + .filter(|p| self.applies_to_path(p.as_ref())) + .map(|p| p.as_ref().to_path_buf()) + .collect(); + + if !paths.is_empty() { + self.reply + .send(DistantResponseData::Changed(Change { kind, paths })) + .await + .map(|_| true) + } else { + Ok(false) + } + } + + /// Sends an error message and includes paths if provided, skipping sending the message if + /// no paths match and `skip_if_no_paths` is true + /// + /// Returns true if message was sent, and false if not + pub async fn filter_and_send_error( + &self, + msg: &str, + paths: T, + skip_if_no_paths: bool, + ) -> io::Result + where + T: IntoIterator, + T::Item: AsRef, + { + let paths: Vec = paths + .into_iter() + .filter(|p| 
self.applies_to_path(p.as_ref())) + .map(|p| p.as_ref().to_path_buf()) + .collect(); + + if !paths.is_empty() || !skip_if_no_paths { + self.reply + .send(if paths.is_empty() { + DistantResponseData::Error(Error::from(msg)) + } else { + DistantResponseData::Error(Error::from(format!("{} about {:?}", msg, paths))) + }) + .await + .map(|_| true) + } else { + Ok(false) + } + } + + /// Returns true if this path applies to the given path. + /// This is accomplished by checking if the path is contained + /// within either the raw or canonicalized path of the watcher + /// and ensures that recursion rules are respected + pub fn applies_to_path(&self, path: &Path) -> bool { + let check_path = |path: &Path| -> bool { + let cnt = path.components().count(); + + // 0 means exact match from strip_prefix + // 1 means that it was within immediate directory (fine for non-recursive) + // 2+ means it needs to be recursive + cnt < 2 || self.recursive + }; + + match ( + path.strip_prefix(self.path()), + path.strip_prefix(self.raw_path()), + ) { + (Ok(p1), Ok(p2)) => check_path(p1) || check_path(p2), + (Ok(p), Err(_)) => check_path(p), + (Err(_), Ok(p)) => check_path(p), + (Err(_), Err(_)) => false, + } + } +} diff --git a/distant-core/src/api/reply.rs b/distant-core/src/api/reply.rs new file mode 100644 index 0000000..c693c56 --- /dev/null +++ b/distant-core/src/api/reply.rs @@ -0,0 +1,29 @@ +use crate::{api::DistantMsg, data::DistantResponseData}; +use distant_net::Reply; +use std::{future::Future, io, pin::Pin}; + +/// Wrapper around a reply that can be batch or single, converting +/// a single data into the wrapped type +pub struct DistantSingleReply(Box>>); + +impl From>>> for DistantSingleReply { + fn from(reply: Box>>) -> Self { + Self(reply) + } +} + +impl Reply for DistantSingleReply { + type Data = DistantResponseData; + + fn send(&self, data: Self::Data) -> Pin> + Send + '_>> { + self.0.send(DistantMsg::Single(data)) + } + + fn blocking_send(&self, data: Self::Data) -> 
io::Result<()> { + self.0.blocking_send(DistantMsg::Single(data)) + } + + fn clone_reply(&self) -> Box> { + Box::new(Self(self.0.clone_reply())) + } +} diff --git a/distant-core/src/client.rs b/distant-core/src/client.rs new file mode 100644 index 0000000..987b401 --- /dev/null +++ b/distant-core/src/client.rs @@ -0,0 +1,18 @@ +use crate::{DistantMsg, DistantRequestData, DistantResponseData}; +use distant_net::{Channel, Client}; + +mod ext; +mod lsp; +mod process; +mod watcher; + +/// Represents a [`Client`] that communicates using the distant protocol +pub type DistantClient = Client, DistantMsg>; + +/// Represents a [`Channel`] that communicates using the distant protocol +pub type DistantChannel = Channel, DistantMsg>; + +pub use ext::*; +pub use lsp::*; +pub use process::*; +pub use watcher::*; diff --git a/distant-core/src/client/ext.rs b/distant-core/src/client/ext.rs new file mode 100644 index 0000000..f912982 --- /dev/null +++ b/distant-core/src/client/ext.rs @@ -0,0 +1,424 @@ +use crate::{ + client::{ + RemoteCommand, RemoteLspCommand, RemoteLspProcess, RemoteOutput, RemoteProcess, Watcher, + }, + data::{ + ChangeKindSet, DirEntry, DistantRequestData, DistantResponseData, Environment, + Error as Failure, Metadata, PtySize, SystemInfo, + }, + DistantMsg, +}; +use distant_net::{Channel, Request}; +use std::{future::Future, io, path::PathBuf, pin::Pin}; + +pub type AsyncReturn<'a, T, E = io::Error> = + Pin> + Send + 'a>>; + +fn mismatched_response() -> io::Error { + io::Error::new(io::ErrorKind::Other, "Mismatched response") +} + +/// Provides convenience functions on top of a [`SessionChannel`] +pub trait DistantChannelExt { + /// Appends to a remote file using the data from a collection of bytes + fn append_file( + &mut self, + path: impl Into, + data: impl Into>, + ) -> AsyncReturn<'_, ()>; + + /// Appends to a remote file using the data from a string + fn append_file_text( + &mut self, + path: impl Into, + data: impl Into, + ) -> AsyncReturn<'_, ()>; + + 
/// Copies a remote file or directory from src to dst + fn copy(&mut self, src: impl Into, dst: impl Into) -> AsyncReturn<'_, ()>; + + /// Creates a remote directory, optionally creating all parent components if specified + fn create_dir(&mut self, path: impl Into, all: bool) -> AsyncReturn<'_, ()>; + + fn exists(&mut self, path: impl Into) -> AsyncReturn<'_, bool>; + + /// Retrieves metadata about a path on a remote machine + fn metadata( + &mut self, + path: impl Into, + canonicalize: bool, + resolve_file_type: bool, + ) -> AsyncReturn<'_, Metadata>; + + /// Reads entries from a directory, returning a tuple of directory entries and failures + fn read_dir( + &mut self, + path: impl Into, + depth: usize, + absolute: bool, + canonicalize: bool, + include_root: bool, + ) -> AsyncReturn<'_, (Vec, Vec)>; + + /// Reads a remote file as a collection of bytes + fn read_file(&mut self, path: impl Into) -> AsyncReturn<'_, Vec>; + + /// Returns a remote file as a string + fn read_file_text(&mut self, path: impl Into) -> AsyncReturn<'_, String>; + + /// Removes a remote file or directory, supporting removal of non-empty directories if + /// force is true + fn remove(&mut self, path: impl Into, force: bool) -> AsyncReturn<'_, ()>; + + /// Renames a remote file or directory from src to dst + fn rename(&mut self, src: impl Into, dst: impl Into) -> AsyncReturn<'_, ()>; + + /// Watches a remote file or directory + fn watch( + &mut self, + path: impl Into, + recursive: bool, + only: impl Into, + except: impl Into, + ) -> AsyncReturn<'_, Watcher>; + + /// Unwatches a remote file or directory + fn unwatch(&mut self, path: impl Into) -> AsyncReturn<'_, ()>; + + /// Spawns a process on the remote machine + fn spawn( + &mut self, + cmd: impl Into, + environment: Environment, + current_dir: Option, + persist: bool, + pty: Option, + ) -> AsyncReturn<'_, RemoteProcess>; + + /// Spawns an LSP process on the remote machine + fn spawn_lsp( + &mut self, + cmd: impl Into, + environment: 
Environment, + current_dir: Option, + persist: bool, + pty: Option, + ) -> AsyncReturn<'_, RemoteLspProcess>; + + /// Spawns a process on the remote machine and wait for it to complete + fn output( + &mut self, + cmd: impl Into, + environment: Environment, + current_dir: Option, + pty: Option, + ) -> AsyncReturn<'_, RemoteOutput>; + + /// Retrieves information about the remote system + fn system_info(&mut self) -> AsyncReturn<'_, SystemInfo>; + + /// Writes a remote file with the data from a collection of bytes + fn write_file( + &mut self, + path: impl Into, + data: impl Into>, + ) -> AsyncReturn<'_, ()>; + + /// Writes a remote file with the data from a string + fn write_file_text( + &mut self, + path: impl Into, + data: impl Into, + ) -> AsyncReturn<'_, ()>; +} + +macro_rules! make_body { + ($self:expr, $data:expr, @ok) => { + make_body!($self, $data, |data| { + match data { + DistantResponseData::Ok => Ok(()), + DistantResponseData::Error(x) => Err(io::Error::from(x)), + _ => Err(mismatched_response()), + } + }) + }; + + ($self:expr, $data:expr, $and_then:expr) => {{ + let req = Request::new(DistantMsg::Single($data)); + Box::pin(async move { + $self + .send(req) + .await + .and_then(|res| match res.payload { + DistantMsg::Single(x) => Ok(x), + _ => Err(mismatched_response()), + }) + .and_then($and_then) + }) + }}; +} + +impl DistantChannelExt + for Channel, DistantMsg> +{ + fn append_file( + &mut self, + path: impl Into, + data: impl Into>, + ) -> AsyncReturn<'_, ()> { + make_body!( + self, + DistantRequestData::FileAppend { path: path.into(), data: data.into() }, + @ok + ) + } + + fn append_file_text( + &mut self, + path: impl Into, + data: impl Into, + ) -> AsyncReturn<'_, ()> { + make_body!( + self, + DistantRequestData::FileAppendText { path: path.into(), text: data.into() }, + @ok + ) + } + + fn copy(&mut self, src: impl Into, dst: impl Into) -> AsyncReturn<'_, ()> { + make_body!( + self, + DistantRequestData::Copy { src: src.into(), dst: dst.into() }, + 
@ok + ) + } + + fn create_dir(&mut self, path: impl Into, all: bool) -> AsyncReturn<'_, ()> { + make_body!( + self, + DistantRequestData::DirCreate { path: path.into(), all }, + @ok + ) + } + + fn exists(&mut self, path: impl Into) -> AsyncReturn<'_, bool> { + make_body!( + self, + DistantRequestData::Exists { path: path.into() }, + |data| match data { + DistantResponseData::Exists { value } => Ok(value), + DistantResponseData::Error(x) => Err(io::Error::from(x)), + _ => Err(mismatched_response()), + } + ) + } + + fn metadata( + &mut self, + path: impl Into, + canonicalize: bool, + resolve_file_type: bool, + ) -> AsyncReturn<'_, Metadata> { + make_body!( + self, + DistantRequestData::Metadata { + path: path.into(), + canonicalize, + resolve_file_type + }, + |data| match data { + DistantResponseData::Metadata(x) => Ok(x), + DistantResponseData::Error(x) => Err(io::Error::from(x)), + _ => Err(mismatched_response()), + } + ) + } + + fn read_dir( + &mut self, + path: impl Into, + depth: usize, + absolute: bool, + canonicalize: bool, + include_root: bool, + ) -> AsyncReturn<'_, (Vec, Vec)> { + make_body!( + self, + DistantRequestData::DirRead { + path: path.into(), + depth, + absolute, + canonicalize, + include_root + }, + |data| match data { + DistantResponseData::DirEntries { entries, errors } => Ok((entries, errors)), + DistantResponseData::Error(x) => Err(io::Error::from(x)), + _ => Err(mismatched_response()), + } + ) + } + + fn read_file(&mut self, path: impl Into) -> AsyncReturn<'_, Vec> { + make_body!( + self, + DistantRequestData::FileRead { path: path.into() }, + |data| match data { + DistantResponseData::Blob { data } => Ok(data), + DistantResponseData::Error(x) => Err(io::Error::from(x)), + _ => Err(mismatched_response()), + } + ) + } + + fn read_file_text(&mut self, path: impl Into) -> AsyncReturn<'_, String> { + make_body!( + self, + DistantRequestData::FileReadText { path: path.into() }, + |data| match data { + DistantResponseData::Text { data } => 
Ok(data), + DistantResponseData::Error(x) => Err(io::Error::from(x)), + _ => Err(mismatched_response()), + } + ) + } + + fn remove(&mut self, path: impl Into, force: bool) -> AsyncReturn<'_, ()> { + make_body!( + self, + DistantRequestData::Remove { path: path.into(), force }, + @ok + ) + } + + fn rename(&mut self, src: impl Into, dst: impl Into) -> AsyncReturn<'_, ()> { + make_body!( + self, + DistantRequestData::Rename { src: src.into(), dst: dst.into() }, + @ok + ) + } + + fn watch( + &mut self, + path: impl Into, + recursive: bool, + only: impl Into, + except: impl Into, + ) -> AsyncReturn<'_, Watcher> { + let path = path.into(); + let only = only.into(); + let except = except.into(); + Box::pin(async move { Watcher::watch(self.clone(), path, recursive, only, except).await }) + } + + fn unwatch(&mut self, path: impl Into) -> AsyncReturn<'_, ()> { + fn inner_unwatch( + channel: &mut Channel, DistantMsg>, + path: impl Into, + ) -> AsyncReturn<'_, ()> { + make_body!( + channel, + DistantRequestData::Unwatch { path: path.into() }, + @ok + ) + } + + let path = path.into(); + + Box::pin(async move { inner_unwatch(self, path).await }) + } + + fn spawn( + &mut self, + cmd: impl Into, + environment: Environment, + current_dir: Option, + persist: bool, + pty: Option, + ) -> AsyncReturn<'_, RemoteProcess> { + let cmd = cmd.into(); + Box::pin(async move { + RemoteCommand::new() + .environment(environment) + .current_dir(current_dir) + .persist(persist) + .pty(pty) + .spawn(self.clone(), cmd) + .await + }) + } + + fn spawn_lsp( + &mut self, + cmd: impl Into, + environment: Environment, + current_dir: Option, + persist: bool, + pty: Option, + ) -> AsyncReturn<'_, RemoteLspProcess> { + let cmd = cmd.into(); + Box::pin(async move { + RemoteLspCommand::new() + .environment(environment) + .current_dir(current_dir) + .persist(persist) + .pty(pty) + .spawn(self.clone(), cmd) + .await + }) + } + + fn output( + &mut self, + cmd: impl Into, + environment: Environment, + current_dir: 
Option, + pty: Option, + ) -> AsyncReturn<'_, RemoteOutput> { + let cmd = cmd.into(); + Box::pin(async move { + RemoteCommand::new() + .environment(environment) + .current_dir(current_dir) + .persist(false) + .pty(pty) + .spawn(self.clone(), cmd) + .await? + .output() + .await + }) + } + + fn system_info(&mut self) -> AsyncReturn<'_, SystemInfo> { + make_body!(self, DistantRequestData::SystemInfo {}, |data| match data { + DistantResponseData::SystemInfo(x) => Ok(x), + DistantResponseData::Error(x) => Err(io::Error::from(x)), + _ => Err(mismatched_response()), + }) + } + + fn write_file( + &mut self, + path: impl Into, + data: impl Into>, + ) -> AsyncReturn<'_, ()> { + make_body!( + self, + DistantRequestData::FileWrite { path: path.into(), data: data.into() }, + @ok + ) + } + + fn write_file_text( + &mut self, + path: impl Into, + data: impl Into, + ) -> AsyncReturn<'_, ()> { + make_body!( + self, + DistantRequestData::FileWriteText { path: path.into(), text: data.into() }, + @ok + ) + } +} diff --git a/distant-core/src/client/lsp/mod.rs b/distant-core/src/client/lsp.rs similarity index 81% rename from distant-core/src/client/lsp/mod.rs rename to distant-core/src/client/lsp.rs index 35b1e71..244d0d9 100644 --- a/distant-core/src/client/lsp/mod.rs +++ b/distant-core/src/client/lsp.rs @@ -1,39 +1,90 @@ -use super::{RemoteProcess, RemoteProcessError, RemoteStderr, RemoteStdin, RemoteStdout}; -use crate::{client::SessionChannel, data::PtySize}; +use crate::{ + client::{ + DistantChannel, RemoteCommand, RemoteProcess, RemoteStatus, RemoteStderr, RemoteStdin, + RemoteStdout, + }, + data::{Environment, PtySize}, +}; use futures::stream::{Stream, StreamExt}; use std::{ io::{self, Cursor, Read}, ops::{Deref, DerefMut}, + path::PathBuf, }; use tokio::{ sync::mpsc::{self, error::TryRecvError}, task::JoinHandle, }; -mod data; -pub use data::*; +mod msg; +pub use msg::*; -/// Represents an LSP server process on a remote machine -#[derive(Debug)] -pub struct RemoteLspProcess { - 
inner: RemoteProcess, - pub stdin: Option, - pub stdout: Option, - pub stderr: Option, +/// A [`RemoteLspProcess`] builder providing support to configure +/// before spawning the process on a remote machine +pub struct RemoteLspCommand { + persist: bool, + pty: Option, + environment: Environment, + current_dir: Option, } -impl RemoteLspProcess { +impl Default for RemoteLspCommand { + fn default() -> Self { + Self::new() + } +} + +impl RemoteLspCommand { + /// Creates a new set of options for a remote LSP process + pub fn new() -> Self { + Self { + persist: false, + pty: None, + environment: Environment::new(), + current_dir: None, + } + } + + /// Sets whether or not the process will be persistent, + /// meaning that it will not be terminated when the + /// connection to the remote machine is terminated + pub fn persist(&mut self, persist: bool) -> &mut Self { + self.persist = persist; + self + } + + /// Configures the process to leverage a PTY with the specified size + pub fn pty(&mut self, pty: Option) -> &mut Self { + self.pty = pty; + self + } + + /// Replaces the existing environment variables with the given collection + pub fn environment(&mut self, environment: Environment) -> &mut Self { + self.environment = environment; + self + } + + /// Configures the process with an alternative current directory + pub fn current_dir(&mut self, current_dir: Option) -> &mut Self { + self.current_dir = current_dir; + self + } + /// Spawns the specified process on the remote machine using the given session, treating /// the process like an LSP server pub async fn spawn( - tenant: impl Into, - channel: SessionChannel, + &mut self, + channel: DistantChannel, cmd: impl Into, - args: Vec, - persist: bool, - pty: Option, - ) -> Result { - let mut inner = RemoteProcess::spawn(tenant, channel, cmd, args, persist, pty).await?; + ) -> io::Result { + let mut command = RemoteCommand::new(); + command.environment(self.environment.clone()); + 
command.current_dir(self.current_dir.clone()); + command.persist(self.persist); + command.pty(self.pty); + + let mut inner = command.spawn(channel, cmd).await?; let stdin = inner.stdin.take().map(RemoteLspStdin::new); let stdout = inner.stdout.take().map(RemoteLspStdout::new); let stderr = inner.stderr.take().map(RemoteLspStderr::new); @@ -45,9 +96,20 @@ impl RemoteLspProcess { stderr, }) } +} + +/// Represents an LSP server process on a remote machine +#[derive(Debug)] +pub struct RemoteLspProcess { + inner: RemoteProcess, + pub stdin: Option, + pub stdout: Option, + pub stderr: Option, +} +impl RemoteLspProcess { /// Waits for the process to terminate, returning the success status and an optional exit code - pub async fn wait(self) -> Result<(bool, Option), RemoteProcessError> { + pub async fn wait(self) -> io::Result { self.inner.wait().await } } @@ -116,7 +178,7 @@ impl RemoteLspStdin { self.write(data.as_bytes()).await } - fn update_and_read_messages(&mut self, data: &[u8]) -> io::Result> { + fn update_and_read_messages(&mut self, data: &[u8]) -> io::Result> { // Create or insert into our buffer match &mut self.buf { Some(buf) => buf.extend(data), @@ -318,7 +380,7 @@ where (read_task, rx) } -fn read_lsp_messages(input: &[u8]) -> io::Result<(Option>, Vec)> { +fn read_lsp_messages(input: &[u8]) -> io::Result<(Option>, Vec)> { let mut queue = Vec::new(); // Continue to read complete messages from the input until we either fail to parse or we reach @@ -326,7 +388,7 @@ fn read_lsp_messages(input: &[u8]) -> io::Result<(Option>, Vec) // cursor may have moved partially from lsp successfully reading the start of a message let mut cursor = Cursor::new(input); let mut pos = 0; - while let Ok(data) = LspData::from_buf_reader(&mut cursor) { + while let Ok(data) = LspMsg::from_buf_reader(&mut cursor) { queue.push(data); pos = cursor.position(); } @@ -347,10 +409,10 @@ fn read_lsp_messages(input: &[u8]) -> io::Result<(Option>, Vec) #[cfg(test)] mod tests { use super::*; - 
use crate::{ - client::Session, - data::{Request, RequestData, Response, ResponseData}, - net::{InmemoryStream, PlainCodec, Transport}, + use crate::data::{DistantRequestData, DistantResponseData}; + use distant_net::{ + Client, FramedTransport, InmemoryTransport, IntoSplit, PlainCodec, Request, Response, + TypedAsyncRead, TypedAsyncWrite, }; use std::{future::Future, time::Duration}; @@ -358,29 +420,26 @@ mod tests { const TIMEOUT: Duration = Duration::from_millis(50); // Configures an lsp process with a means to send & receive data from outside - async fn spawn_lsp_process() -> (Transport, RemoteLspProcess) { - let (mut t1, t2) = Transport::make_pair(); - let session = Session::initialize(t2).unwrap(); + async fn spawn_lsp_process() -> ( + FramedTransport, + RemoteLspProcess, + ) { + let (mut t1, t2) = FramedTransport::pair(100); + let (writer, reader) = t2.into_split(); + let session = Client::new(writer, reader).unwrap(); let spawn_task = tokio::spawn(async move { - RemoteLspProcess::spawn( - String::from("test-tenant"), - session.clone_channel(), - String::from("cmd"), - vec![String::from("arg")], - false, - None, - ) - .await + RemoteLspCommand::new() + .spawn(session.clone_channel(), String::from("cmd arg")) + .await }); // Wait until we get the request from the session - let req = t1.receive::().await.unwrap().unwrap(); + let req: Request = t1.read().await.unwrap().unwrap(); // Send back a response through the session - t1.send(Response::new( - "test-tenant", + t1.write(Response::new( req.id, - vec![ResponseData::ProcSpawned { id: rand::random() }], + DistantResponseData::ProcSpawned { id: rand::random() }, )) .await .unwrap(); @@ -427,13 +486,12 @@ mod tests { .unwrap(); // Validate that the outgoing req is a complete LSP message - let req = transport.receive::().await.unwrap().unwrap(); - assert_eq!(req.payload.len(), 1, "Unexpected payload size"); - match &req.payload[0] { - RequestData::ProcStdin { data, .. 
} => { + let req: Request = transport.read().await.unwrap().unwrap(); + match req.payload { + DistantRequestData::ProcStdin { data, .. } => { assert_eq!( data, - &make_lsp_msg(serde_json::json!({ + make_lsp_msg(serde_json::json!({ "field1": "a", "field2": "b", })) @@ -460,20 +518,23 @@ mod tests { // Verify that nothing has been sent out yet // NOTE: Yield to ensure that data would be waiting at the transport if it was sent tokio::task::yield_now().await; - let result = timeout(TIMEOUT, transport.receive::()).await; + let result = timeout( + TIMEOUT, + TypedAsyncRead::>::read(&mut transport), + ) + .await; assert!(result.is_err(), "Unexpectedly got data: {:?}", result); // Write remainder of message proc.stdin.as_mut().unwrap().write(msg_b).await.unwrap(); // Validate that the outgoing req is a complete LSP message - let req = transport.receive::().await.unwrap().unwrap(); - assert_eq!(req.payload.len(), 1, "Unexpected payload size"); - match &req.payload[0] { - RequestData::ProcStdin { data, .. } => { + let req: Request = transport.read().await.unwrap().unwrap(); + match req.payload { + DistantRequestData::ProcStdin { data, .. } => { assert_eq!( data, - &make_lsp_msg(serde_json::json!({ + make_lsp_msg(serde_json::json!({ "field1": "a", "field2": "b", })) @@ -503,13 +564,12 @@ mod tests { .unwrap(); // Validate that the outgoing req is a complete LSP message - let req = transport.receive::().await.unwrap().unwrap(); - assert_eq!(req.payload.len(), 1, "Unexpected payload size"); - match &req.payload[0] { - RequestData::ProcStdin { data, .. } => { + let req: Request = transport.read().await.unwrap().unwrap(); + match req.payload { + DistantRequestData::ProcStdin { data, .. 
} => { assert_eq!( data, - &make_lsp_msg(serde_json::json!({ + make_lsp_msg(serde_json::json!({ "field1": "a", "field2": "b", })) @@ -553,13 +613,12 @@ mod tests { .unwrap(); // Validate that the first outgoing req is a complete LSP message matching first - let req = transport.receive::().await.unwrap().unwrap(); - assert_eq!(req.payload.len(), 1, "Unexpected payload size"); - match &req.payload[0] { - RequestData::ProcStdin { data, .. } => { + let req: Request = transport.read().await.unwrap().unwrap(); + match req.payload { + DistantRequestData::ProcStdin { data, .. } => { assert_eq!( data, - &make_lsp_msg(serde_json::json!({ + make_lsp_msg(serde_json::json!({ "field1": "a", "field2": "b", })) @@ -569,13 +628,12 @@ mod tests { } // Validate that the second outgoing req is a complete LSP message matching second - let req = transport.receive::().await.unwrap().unwrap(); - assert_eq!(req.payload.len(), 1, "Unexpected payload size"); - match &req.payload[0] { - RequestData::ProcStdin { data, .. } => { + let req: Request = transport.read().await.unwrap().unwrap(); + match req.payload { + DistantRequestData::ProcStdin { data, .. } => { assert_eq!( data, - &make_lsp_msg(serde_json::json!({ + make_lsp_msg(serde_json::json!({ "field1": "c", "field2": "d", })) @@ -600,16 +658,15 @@ mod tests { .unwrap(); // Validate that the outgoing req is a complete LSP message - let req = transport.receive::().await.unwrap().unwrap(); - assert_eq!(req.payload.len(), 1, "Unexpected payload size"); - match &req.payload[0] { - RequestData::ProcStdin { data, .. } => { + let req: Request = transport.read().await.unwrap().unwrap(); + match req.payload { + DistantRequestData::ProcStdin { data, .. 
} => { // Verify the contents AND headers are as expected; in this case, // this will also ensure that the Content-Length is adjusted // when the distant scheme was changed to file assert_eq!( data, - &make_lsp_msg(serde_json::json!({ + make_lsp_msg(serde_json::json!({ "field1": "file://some/path", "field2": "file://other/path", })) @@ -625,16 +682,15 @@ mod tests { // Send complete LSP message as stdout to process transport - .send(Response::new( - "test-tenant", - proc.origin_id(), - vec![ResponseData::ProcStdout { + .write(Response::new( + proc.origin_id().to_string(), + DistantResponseData::ProcStdout { id: proc.id(), data: make_lsp_msg(serde_json::json!({ "field1": "a", "field2": "b", })), - }], + }, )) .await .unwrap(); @@ -662,13 +718,12 @@ mod tests { // Send half of LSP message over stdout transport - .send(Response::new( - "test-tenant", - proc.origin_id(), - vec![ResponseData::ProcStdout { + .write(Response::new( + proc.origin_id().to_string(), + DistantResponseData::ProcStdout { id: proc.id(), data: msg_a.to_vec(), - }], + }, )) .await .unwrap(); @@ -681,13 +736,12 @@ mod tests { // Send other half of LSP message over stdout transport - .send(Response::new( - "test-tenant", - proc.origin_id(), - vec![ResponseData::ProcStdout { + .write(Response::new( + proc.origin_id().to_string(), + DistantResponseData::ProcStdout { id: proc.id(), data: msg_b.to_vec(), - }], + }, )) .await .unwrap(); @@ -716,13 +770,12 @@ mod tests { // Send complete LSP message as stdout to process transport - .send(Response::new( - "test-tenant", - proc.origin_id(), - vec![ResponseData::ProcStdout { + .write(Response::new( + proc.origin_id().to_string(), + DistantResponseData::ProcStdout { id: proc.id(), data: format!("{}{}", String::from_utf8(msg).unwrap(), extra).into_bytes(), - }], + }, )) .await .unwrap(); @@ -760,10 +813,9 @@ mod tests { // Send complete LSP message as stdout to process transport - .send(Response::new( - "test-tenant", - proc.origin_id(), - 
vec![ResponseData::ProcStdout { + .write(Response::new( + proc.origin_id().to_string(), + DistantResponseData::ProcStdout { id: proc.id(), data: format!( "{}{}", @@ -771,7 +823,7 @@ mod tests { String::from_utf8(msg_2).unwrap() ) .into_bytes(), - }], + }, )) .await .unwrap(); @@ -803,16 +855,15 @@ mod tests { // Send complete LSP message as stdout to process transport - .send(Response::new( - "test-tenant", - proc.origin_id(), - vec![ResponseData::ProcStdout { + .write(Response::new( + proc.origin_id().to_string(), + DistantResponseData::ProcStdout { id: proc.id(), data: make_lsp_msg(serde_json::json!({ "field1": "distant://some/path", "field2": "file://other/path", })), - }], + }, )) .await .unwrap(); @@ -834,16 +885,15 @@ mod tests { // Send complete LSP message as stderr to process transport - .send(Response::new( - "test-tenant", - proc.origin_id(), - vec![ResponseData::ProcStderr { + .write(Response::new( + proc.origin_id().to_string(), + DistantResponseData::ProcStderr { id: proc.id(), data: make_lsp_msg(serde_json::json!({ "field1": "a", "field2": "b", })), - }], + }, )) .await .unwrap(); @@ -871,13 +921,12 @@ mod tests { // Send half of LSP message over stderr transport - .send(Response::new( - "test-tenant", - proc.origin_id(), - vec![ResponseData::ProcStderr { + .write(Response::new( + proc.origin_id().to_string(), + DistantResponseData::ProcStderr { id: proc.id(), data: msg_a.to_vec(), - }], + }, )) .await .unwrap(); @@ -890,13 +939,12 @@ mod tests { // Send other half of LSP message over stderr transport - .send(Response::new( - "test-tenant", - proc.origin_id(), - vec![ResponseData::ProcStderr { + .write(Response::new( + proc.origin_id().to_string(), + DistantResponseData::ProcStderr { id: proc.id(), data: msg_b.to_vec(), - }], + }, )) .await .unwrap(); @@ -925,13 +973,12 @@ mod tests { // Send complete LSP message as stderr to process transport - .send(Response::new( - "test-tenant", - proc.origin_id(), - vec![ResponseData::ProcStderr { + 
.write(Response::new( + proc.origin_id().to_string(), + DistantResponseData::ProcStderr { id: proc.id(), data: format!("{}{}", String::from_utf8(msg).unwrap(), extra).into_bytes(), - }], + }, )) .await .unwrap(); @@ -969,10 +1016,9 @@ mod tests { // Send complete LSP message as stderr to process transport - .send(Response::new( - "test-tenant", - proc.origin_id(), - vec![ResponseData::ProcStderr { + .write(Response::new( + proc.origin_id().to_string(), + DistantResponseData::ProcStderr { id: proc.id(), data: format!( "{}{}", @@ -980,7 +1026,7 @@ mod tests { String::from_utf8(msg_2).unwrap() ) .into_bytes(), - }], + }, )) .await .unwrap(); @@ -1012,16 +1058,15 @@ mod tests { // Send complete LSP message as stderr to process transport - .send(Response::new( - "test-tenant", - proc.origin_id(), - vec![ResponseData::ProcStderr { + .write(Response::new( + proc.origin_id().to_string(), + DistantResponseData::ProcStderr { id: proc.id(), data: make_lsp_msg(serde_json::json!({ "field1": "distant://some/path", "field2": "file://other/path", })), - }], + }, )) .await .unwrap(); diff --git a/distant-core/src/client/lsp/data.rs b/distant-core/src/client/lsp/msg.rs similarity index 52% rename from distant-core/src/client/lsp/data.rs rename to distant-core/src/client/lsp/msg.rs index 1fa9936..0156d31 100644 --- a/distant-core/src/client/lsp/data.rs +++ b/distant-core/src/client/lsp/msg.rs @@ -1,4 +1,3 @@ -use crate::client::{SessionInfo, SessionInfoParseError}; use derive_more::{Display, Error, From}; use serde::{Deserialize, Serialize}; use serde_json::{Map, Value}; @@ -10,33 +9,9 @@ use std::{ string::FromUtf8Error, }; -#[derive(Copy, Clone, Debug, PartialEq, Eq, Display, Error, From)] -pub enum LspSessionInfoError { - /// Encountered when attempting to create a session from a request that is not initialize - NotInitializeRequest, - - /// Encountered if missing session parameters within an initialize request - MissingSessionInfoParams, - - /// Encountered if session parameters 
are not expected types - InvalidSessionInfoParams, - - /// Encountered when failing to parse session - SessionInfoParseError(SessionInfoParseError), -} - -impl From for io::Error { - fn from(x: LspSessionInfoError) -> Self { - match x { - LspSessionInfoError::SessionInfoParseError(x) => x.into(), - x => io::Error::new(io::ErrorKind::InvalidData, x), - } - } -} - /// Represents some data being communicated to/from an LSP consisting of a header and content part #[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)] -pub struct LspData { +pub struct LspMsg { /// Header-portion of some data related to LSP header: LspHeader, @@ -45,7 +20,7 @@ pub struct LspData { } #[derive(Debug, Display, Error, From)] -pub enum LspDataParseError { +pub enum LspMsgParseError { /// When the received content is malformed BadContent(LspContentParseError), @@ -65,23 +40,23 @@ pub enum LspDataParseError { UnexpectedEof, } -impl From for io::Error { - fn from(x: LspDataParseError) -> Self { +impl From for io::Error { + fn from(x: LspMsgParseError) -> Self { match x { - LspDataParseError::BadContent(x) => x.into(), - LspDataParseError::BadHeader(x) => x.into(), - LspDataParseError::BadHeaderTermination => io::Error::new( + LspMsgParseError::BadContent(x) => x.into(), + LspMsgParseError::BadHeader(x) => x.into(), + LspMsgParseError::BadHeaderTermination => io::Error::new( io::ErrorKind::InvalidData, r"Received header line not terminated in \r\n", ), - LspDataParseError::BadInput(x) => io::Error::new(io::ErrorKind::InvalidData, x), - LspDataParseError::IoError(x) => x, - LspDataParseError::UnexpectedEof => io::Error::from(io::ErrorKind::UnexpectedEof), + LspMsgParseError::BadInput(x) => io::Error::new(io::ErrorKind::InvalidData, x), + LspMsgParseError::IoError(x) => x, + LspMsgParseError::UnexpectedEof => io::Error::from(io::ErrorKind::UnexpectedEof), } } } -impl LspData { +impl LspMsg { /// Returns a reference to the header part pub fn header(&self) -> &LspHeader { &self.header @@ 
-106,19 +81,6 @@ impl LspData { self.header.content_length = self.content.to_string().len(); } - /// Creates a session's info by inspecting the content for session parameters, removing the - /// session parameters from the content. Will also adjust the content length header to match - /// the new size of the content. - pub fn take_session_info(&mut self) -> Result { - match self.content.take_session_info() { - Ok(session) => { - self.refresh_content_length(); - Ok(session) - } - Err(x) => Err(x), - } - } - /// Attempts to read incoming lsp data from a buffered reader. /// /// Note that this is **blocking** while it waits on the header information (or EOF)! @@ -132,7 +94,7 @@ impl LspData { /// ... /// } /// ``` - pub fn from_buf_reader(r: &mut R) -> Result { + pub fn from_buf_reader(r: &mut R) -> Result { // Read in our headers first so we can figure out how much more to read let mut buf = String::new(); loop { @@ -145,14 +107,14 @@ impl LspData { // We shouldn't be getting end of the reader yet if len == 0 { - return Err(LspDataParseError::UnexpectedEof); + return Err(LspMsgParseError::UnexpectedEof); } let line = &buf[start..end]; // Check if we've gotten bad data if !line.ends_with("\r\n") { - return Err(LspDataParseError::BadHeaderTermination); + return Err(LspMsgParseError::BadHeaderTermination); // Check if we've received the header termination } else if line == "\r\n" { @@ -168,9 +130,9 @@ impl LspData { let mut buf = vec![0u8; header.content_length]; r.read_exact(&mut buf).map_err(|x| { if x.kind() == io::ErrorKind::UnexpectedEof { - LspDataParseError::UnexpectedEof + LspMsgParseError::UnexpectedEof } else { - LspDataParseError::IoError(x) + LspMsgParseError::IoError(x) } })?; String::from_utf8(buf)?.parse::()? 
@@ -185,7 +147,7 @@ impl LspData { } } -impl fmt::Display for LspData { +impl fmt::Display for LspMsg { /// Outputs header & content in form /// /// ```text @@ -203,8 +165,8 @@ impl fmt::Display for LspData { } } -impl FromStr for LspData { - type Err = LspDataParseError; +impl FromStr for LspMsg { + type Err = LspMsgParseError; /// Parses headers and content in the form of /// @@ -380,63 +342,6 @@ impl LspContent { pub fn convert_distant_scheme_to_local(&mut self) { swap_prefix(&mut self.0, "distant:", "file:"); } - - /// Creates a session's info by inspecting the content for session parameters, removing the - /// session parameters from the content - pub fn take_session_info(&mut self) -> Result { - // Verify that we're dealing with an initialize request - match self.0.get("method") { - Some(value) if value == "initialize" => {} - _ => return Err(LspSessionInfoError::NotInitializeRequest), - } - - // Attempt to grab the distant initialization options - match self.strip_session_params() { - Some((Some(host), Some(port), Some(key))) => { - let host = host - .as_str() - .ok_or(LspSessionInfoError::InvalidSessionInfoParams)?; - let port = port - .as_u64() - .ok_or(LspSessionInfoError::InvalidSessionInfoParams)?; - let key = key - .as_str() - .ok_or(LspSessionInfoError::InvalidSessionInfoParams)?; - Ok(format!("DISTANT CONNECT {} {} {}", host, port, key).parse()?) - } - _ => Err(LspSessionInfoError::MissingSessionInfoParams), - } - } - - /// Strips the session params from the content, returning them if they exist - /// - /// ```json - /// { - /// "params": { - /// "initializationOptions": { - /// "distant": { - /// "host": "...", - /// "port": ..., - /// "key": "..." 
- /// } - /// } - /// } - /// } - /// ``` - fn strip_session_params(&mut self) -> Option<(Option, Option, Option)> { - self.0 - .get_mut("params") - .and_then(|v| v.get_mut("initializationOptions")) - .and_then(|v| v.as_object_mut()) - .and_then(|o| o.remove("distant")) - .map(|mut v| { - ( - v.get_mut("host").map(Value::take), - v.get_mut("port").map(Value::take), - v.get_mut("key").map(Value::take), - ) - }) - } } impl AsRef> for LspContent { @@ -491,10 +396,6 @@ impl FromStr for LspContent { #[cfg(test)] mod tests { use super::*; - use crate::net::SecretKey; - - // 32-byte test hex key (64 hex characters) - const TEST_HEX_KEY: &str = "ABCDEFABCDEFABCDEFABCDEFABCDEFABCDEFABCDEFABCDEFABCDEFABCDEFABCD"; macro_rules! make_obj { ($($tail:tt)*) => { @@ -506,8 +407,8 @@ mod tests { } #[test] - fn data_display_should_output_header_and_content() { - let data = LspData { + fn msg_display_should_output_header_and_content() { + let msg = LspMsg { header: LspHeader { content_length: 123, content_type: Some(String::from("some content type")), @@ -515,7 +416,7 @@ mod tests { content: LspContent(make_obj!({"hello": "world"})), }; - let output = data.to_string(); + let output = msg.to_string(); assert_eq!( output, concat!( @@ -530,7 +431,7 @@ mod tests { } #[test] - fn data_from_buf_reader_should_be_successful_if_valid_data_received() { + fn msg_from_buf_reader_should_be_successful_if_valid_msg_received() { let mut input = io::Cursor::new(concat!( "Content-Length: 22\r\n", "Content-Type: some content type\r\n", @@ -539,45 +440,45 @@ mod tests { " \"hello\": \"world\"\n", "}", )); - let data = LspData::from_buf_reader(&mut input).unwrap(); - assert_eq!(data.header.content_length, 22); + let msg = LspMsg::from_buf_reader(&mut input).unwrap(); + assert_eq!(msg.header.content_length, 22); assert_eq!( - data.header.content_type.as_deref(), + msg.header.content_type.as_deref(), Some("some content type") ); - assert_eq!(data.content.as_ref(), &make_obj!({ "hello": "world" })); + 
assert_eq!(msg.content.as_ref(), &make_obj!({ "hello": "world" })); } #[test] - fn data_from_buf_reader_should_fail_if_reach_eof_before_received_full_data() { + fn msg_from_buf_reader_should_fail_if_reach_eof_before_received_full_msg() { // No line termination - let err = LspData::from_buf_reader(&mut io::Cursor::new("Content-Length: 22")).unwrap_err(); + let err = LspMsg::from_buf_reader(&mut io::Cursor::new("Content-Length: 22")).unwrap_err(); assert!( - matches!(err, LspDataParseError::BadHeaderTermination), + matches!(err, LspMsgParseError::BadHeaderTermination), "{:?}", err ); // Header doesn't finish - let err = LspData::from_buf_reader(&mut io::Cursor::new(concat!( + let err = LspMsg::from_buf_reader(&mut io::Cursor::new(concat!( "Content-Length: 22\r\n", "Content-Type: some content type\r\n", ))) .unwrap_err(); - assert!(matches!(err, LspDataParseError::UnexpectedEof), "{:?}", err); + assert!(matches!(err, LspMsgParseError::UnexpectedEof), "{:?}", err); // No content after header - let err = LspData::from_buf_reader(&mut io::Cursor::new(concat!( + let err = LspMsg::from_buf_reader(&mut io::Cursor::new(concat!( "Content-Length: 22\r\n", "\r\n", ))) .unwrap_err(); - assert!(matches!(err, LspDataParseError::UnexpectedEof), "{:?}", err); + assert!(matches!(err, LspMsgParseError::UnexpectedEof), "{:?}", err); } #[test] - fn data_from_buf_reader_should_fail_if_missing_proper_line_termination_for_header_field() { - let err = LspData::from_buf_reader(&mut io::Cursor::new(concat!( + fn msg_from_buf_reader_should_fail_if_missing_proper_line_termination_for_header_field() { + let err = LspMsg::from_buf_reader(&mut io::Cursor::new(concat!( "Content-Length: 22\n", "\r\n", "{\n", @@ -586,16 +487,16 @@ mod tests { ))) .unwrap_err(); assert!( - matches!(err, LspDataParseError::BadHeaderTermination), + matches!(err, LspMsgParseError::BadHeaderTermination), "{:?}", err ); } #[test] - fn data_from_buf_reader_should_fail_if_bad_header_provided() { + fn 
msg_from_buf_reader_should_fail_if_bad_header_provided() { // Invalid content length - let err = LspData::from_buf_reader(&mut io::Cursor::new(concat!( + let err = LspMsg::from_buf_reader(&mut io::Cursor::new(concat!( "Content-Length: -1\r\n", "\r\n", "{\n", @@ -603,10 +504,10 @@ mod tests { "}", ))) .unwrap_err(); - assert!(matches!(err, LspDataParseError::BadHeader(_)), "{:?}", err); + assert!(matches!(err, LspMsgParseError::BadHeader(_)), "{:?}", err); // Missing content length - let err = LspData::from_buf_reader(&mut io::Cursor::new(concat!( + let err = LspMsg::from_buf_reader(&mut io::Cursor::new(concat!( "Content-Type: some content type\r\n", "\r\n", "{\n", @@ -614,371 +515,30 @@ mod tests { "}", ))) .unwrap_err(); - assert!(matches!(err, LspDataParseError::BadHeader(_)), "{:?}", err); + assert!(matches!(err, LspMsgParseError::BadHeader(_)), "{:?}", err); } #[test] - fn data_from_buf_reader_should_fail_if_bad_content_provided() { + fn msg_from_buf_reader_should_fail_if_bad_content_provided() { // Not full content - let err = LspData::from_buf_reader(&mut io::Cursor::new(concat!( + let err = LspMsg::from_buf_reader(&mut io::Cursor::new(concat!( "Content-Length: 21\r\n", "\r\n", "{\n", " \"hello\": \"world\"\n", ))) .unwrap_err(); - assert!(matches!(err, LspDataParseError::BadContent(_)), "{:?}", err); + assert!(matches!(err, LspMsgParseError::BadContent(_)), "{:?}", err); } #[test] - fn data_from_buf_reader_should_fail_if_non_utf8_data_encountered_for_content() { + fn msg_from_buf_reader_should_fail_if_non_utf8_msg_encountered_for_content() { // Not utf-8 content let mut raw = b"Content-Length: 2\r\n\r\n".to_vec(); raw.extend(vec![0, 159]); - let err = LspData::from_buf_reader(&mut io::Cursor::new(raw)).unwrap_err(); - assert!(matches!(err, LspDataParseError::BadInput(_)), "{:?}", err); - } - - #[test] - fn data_take_session_info_should_succeed_if_valid_session_found_in_params() { - let mut data = LspData { - header: LspHeader { - content_length: 123, - 
content_type: Some(String::from("some content type")), - }, - content: LspContent(make_obj!({ - "method": "initialize", - "params": { - "initializationOptions": { - "distant": { - "host": "some.host", - "port": 22, - "key": TEST_HEX_KEY - } - } - } - })), - }; - - let info = data.take_session_info().unwrap(); - assert_eq!( - info, - SessionInfo { - host: String::from("some.host"), - port: 22, - key: SecretKey::from_slice(&hex::decode(TEST_HEX_KEY).unwrap()).unwrap(), - } - ); - } - - #[test] - fn data_take_session_info_should_remove_session_parameters_if_successful() { - let mut data = LspData { - header: LspHeader { - content_length: 123, - content_type: Some(String::from("some content type")), - }, - content: LspContent(make_obj!({ - "method": "initialize", - "params": { - "initializationOptions": { - "distant": { - "host": "some.host", - "port": 22, - "key": TEST_HEX_KEY - } - } - } - })), - }; - - let _ = data.take_session_info().unwrap(); - assert_eq!( - data.content.as_ref(), - &make_obj!({ - "method": "initialize", - "params": { - "initializationOptions": {} - } - }) - ); - } - - #[test] - fn data_take_session_info_should_adjust_content_length_based_on_new_content_byte_length() { - let mut data = LspData { - header: LspHeader { - content_length: 123456, - content_type: Some(String::from("some content type")), - }, - content: LspContent(make_obj!({ - "method": "initialize", - "params": { - "initializationOptions": { - "distant": { - "host": "some.host", - "port": 22, - "key": TEST_HEX_KEY - } - } - } - })), - }; - - let _ = data.take_session_info().unwrap(); - assert_eq!(data.header.content_length, data.content.to_string().len()); - } - - #[test] - fn data_take_session_info_should_fail_if_path_incomplete_to_session_params() { - let mut data = LspData { - header: LspHeader { - content_length: 123456, - content_type: Some(String::from("some content type")), - }, - content: LspContent(make_obj!({ - "method": "initialize", - "params": { - "initializationOptions": 
{} - } - })), - }; - - let err = data.take_session_info().unwrap_err(); - assert!( - matches!(err, LspSessionInfoError::MissingSessionInfoParams), - "{:?}", - err - ); - } - - #[test] - fn data_take_session_info_should_fail_if_missing_host_param() { - let mut data = LspData { - header: LspHeader { - content_length: 123456, - content_type: Some(String::from("some content type")), - }, - content: LspContent(make_obj!({ - "method": "initialize", - "params": { - "initializationOptions": { - "distant": { - "port": 22, - "key": TEST_HEX_KEY - } - } - } - })), - }; - - let err = data.take_session_info().unwrap_err(); - assert!( - matches!(err, LspSessionInfoError::MissingSessionInfoParams), - "{:?}", - err - ); - } - - #[test] - fn data_take_session_info_should_fail_if_host_param_is_invalid() { - let mut data = LspData { - header: LspHeader { - content_length: 123456, - content_type: Some(String::from("some content type")), - }, - content: LspContent(make_obj!({ - "method": "initialize", - "params": { - "initializationOptions": { - "distant": { - "host": 1234, - "port": 22, - "key": TEST_HEX_KEY - } - } - } - })), - }; - - let err = data.take_session_info().unwrap_err(); - assert!( - matches!(err, LspSessionInfoError::InvalidSessionInfoParams), - "{:?}", - err - ); - } - - #[test] - fn data_take_session_info_should_fail_if_missing_port_param() { - let mut data = LspData { - header: LspHeader { - content_length: 123456, - content_type: Some(String::from("some content type")), - }, - content: LspContent(make_obj!({ - "method": "initialize", - "params": { - "initializationOptions": { - "distant": { - "host": "some.host", - "key": TEST_HEX_KEY - } - } - } - })), - }; - - let err = data.take_session_info().unwrap_err(); - assert!( - matches!(err, LspSessionInfoError::MissingSessionInfoParams), - "{:?}", - err - ); - } - - #[test] - fn data_take_session_info_should_fail_if_port_param_is_invalid() { - let mut data = LspData { - header: LspHeader { - content_length: 123456, - 
content_type: Some(String::from("some content type")), - }, - content: LspContent(make_obj!({ - "method": "initialize", - "params": { - "initializationOptions": { - "distant": { - "host": "some.host", - "port": "abcd", - "key": TEST_HEX_KEY - } - } - } - })), - }; - - let err = data.take_session_info().unwrap_err(); - assert!( - matches!(err, LspSessionInfoError::InvalidSessionInfoParams), - "{:?}", - err - ); - } - - #[test] - fn data_take_session_info_should_fail_if_missing_key_param() { - let mut data = LspData { - header: LspHeader { - content_length: 123456, - content_type: Some(String::from("some content type")), - }, - content: LspContent(make_obj!({ - "method": "initialize", - "params": { - "initializationOptions": { - "distant": { - "host": "some.host", - "port": 22, - } - } - } - })), - }; - - let err = data.take_session_info().unwrap_err(); - assert!( - matches!(err, LspSessionInfoError::MissingSessionInfoParams), - "{:?}", - err - ); - } - - #[test] - fn data_take_session_info_should_fail_if_key_param_is_invalid() { - let mut data = LspData { - header: LspHeader { - content_length: 123456, - content_type: Some(String::from("some content type")), - }, - content: LspContent(make_obj!({ - "method": "initialize", - "params": { - "initializationOptions": { - "distant": { - "host": "some.host", - "port": 22, - "key": 1234, - } - } - } - })), - }; - - let err = data.take_session_info().unwrap_err(); - assert!( - matches!(err, LspSessionInfoError::InvalidSessionInfoParams), - "{:?}", - err - ); - } - - #[test] - fn data_take_session_info_should_fail_if_missing_method_field() { - let mut data = LspData { - header: LspHeader { - content_length: 123456, - content_type: Some(String::from("some content type")), - }, - content: LspContent(make_obj!({ - "params": { - "initializationOptions": { - "distant": { - "host": "some.host", - "port": 22, - "key": TEST_HEX_KEY, - } - } - } - })), - }; - - let err = data.take_session_info().unwrap_err(); - assert!( - 
matches!(err, LspSessionInfoError::NotInitializeRequest), - "{:?}", - err - ); - } - - #[test] - fn data_take_session_info_should_fail_if_method_field_is_not_initialize() { - let mut data = LspData { - header: LspHeader { - content_length: 123456, - content_type: Some(String::from("some content type")), - }, - content: LspContent(make_obj!({ - "method": "not initialize", - "params": { - "initializationOptions": { - "distant": { - "host": "some.host", - "port": 22, - "key": TEST_HEX_KEY, - } - } - } - })), - }; - - let err = data.take_session_info().unwrap_err(); - assert!( - matches!(err, LspSessionInfoError::NotInitializeRequest), - "{:?}", - err - ); + let err = LspMsg::from_buf_reader(&mut io::Cursor::new(raw)).unwrap_err(); + assert!(matches!(err, LspMsgParseError::BadInput(_)), "{:?}", err); } #[test] diff --git a/distant-core/src/client/mod.rs b/distant-core/src/client/mod.rs deleted file mode 100644 index 1c9c7f2..0000000 --- a/distant-core/src/client/mod.rs +++ /dev/null @@ -1,10 +0,0 @@ -mod lsp; -mod process; -mod session; -mod utils; -mod watcher; - -pub use lsp::*; -pub use process::*; -pub use session::*; -pub use watcher::*; diff --git a/distant-core/src/client/process.rs b/distant-core/src/client/process.rs index 2835020..95c7e35 100644 --- a/distant-core/src/client/process.rs +++ b/distant-core/src/client/process.rs @@ -1,12 +1,12 @@ use crate::{ - client::{Mailbox, SessionChannel}, + client::DistantChannel, constants::CLIENT_PIPE_CAPACITY, - data::{PtySize, Request, RequestData, ResponseData}, - net::TransportError, + data::{Cmd, DistantRequestData, DistantResponseData, Environment, ProcessId, PtySize}, + DistantMsg, }; -use derive_more::{Display, Error, From}; +use distant_net::{Mailbox, Request, Response}; use log::*; -use std::sync::Arc; +use std::{path::PathBuf, sync::Arc}; use tokio::{ io, sync::{ @@ -16,121 +16,126 @@ use tokio::{ }, RwLock, }, - task::{JoinError, JoinHandle}, + task::JoinHandle, }; -type StatusResult = Result<(bool, Option), 
RemoteProcessError>; - -#[derive(Debug, Display, Error, From)] -pub enum RemoteProcessError { - /// When attempting to relay stdout/stderr over channels, but the channels fail - ChannelDead, - - /// When the communication over the wire has issues - TransportError(TransportError), - - /// When the stream of responses from the server closes without receiving - /// an indicator of the process' exit status - UnexpectedEof, - - /// When attempting to wait on the remote process, but the internal task joining failed - WaitFailed(JoinError), +#[derive(Clone, Debug, PartialEq, Eq)] +pub struct RemoteOutput { + pub success: bool, + pub code: Option, + pub stdout: Vec, + pub stderr: Vec, } -/// Represents a process on a remote machine -#[derive(Debug)] -pub struct RemoteProcess { - /// Id of the process - id: usize, - - /// Id used to map back to mailbox - origin_id: usize, +#[derive(Copy, Clone, Debug, PartialEq, Eq)] +pub struct RemoteStatus { + pub success: bool, + pub code: Option, +} - // Sender to abort req task - abort_req_task_tx: mpsc::Sender<()>, +impl From<(bool, Option)> for RemoteStatus { + fn from((success, code): (bool, Option)) -> Self { + Self { success, code } + } +} - // Sender to abort res task - abort_res_task_tx: mpsc::Sender<()>, +type StatusResult = io::Result; - /// Sender for stdin - pub stdin: Option, +/// A [`RemoteProcess`] builder providing support to configure +/// before spawning the process on a remote machine +pub struct RemoteCommand { + persist: bool, + pty: Option, + environment: Environment, + current_dir: Option, +} - /// Receiver for stdout - pub stdout: Option, +impl Default for RemoteCommand { + fn default() -> Self { + Self::new() + } +} - /// Receiver for stderr - pub stderr: Option, +impl RemoteCommand { + /// Creates a new set of options for a remote process + pub fn new() -> Self { + Self { + persist: false, + pty: None, + environment: Environment::new(), + current_dir: None, + } + } - /// Sender for resize events - resizer: 
RemoteProcessResizer, + /// Sets whether or not the process will be persistent, + /// meaning that it will not be terminated when the + /// connection to the remote machine is terminated + pub fn persist(&mut self, persist: bool) -> &mut Self { + self.persist = persist; + self + } - /// Sender for kill events - killer: RemoteProcessKiller, + /// Configures the process to leverage a PTY with the specified size + pub fn pty(&mut self, pty: Option) -> &mut Self { + self.pty = pty; + self + } - /// Task that waits for the process to complete - wait_task: JoinHandle<()>, + /// Replaces the existing environment variables with the given collection + pub fn environment(&mut self, environment: Environment) -> &mut Self { + self.environment = environment; + self + } - /// Handles the success and exit code for a completed process - status: Arc>>, -} + /// Configures the process with an alternative current directory + pub fn current_dir(&mut self, current_dir: Option) -> &mut Self { + self.current_dir = current_dir; + self + } -impl RemoteProcess { - /// Spawns the specified process on the remote machine using the given session + /// Spawns the specified process on the remote machine using the given `channel` and `cmd` pub async fn spawn( - tenant: impl Into, - mut channel: SessionChannel, + &mut self, + mut channel: DistantChannel, cmd: impl Into, - args: Vec, - persist: bool, - pty: Option, - ) -> Result { - let tenant = tenant.into(); + ) -> io::Result { let cmd = cmd.into(); // Submit our run request and get back a mailbox for responses let mut mailbox = channel - .mail(Request::new( - tenant.as_str(), - vec![RequestData::ProcSpawn { - cmd, - args, - persist, - pty, - }], - )) + .mail(Request::new(DistantMsg::Single( + DistantRequestData::ProcSpawn { + cmd: Cmd::from(cmd), + persist: self.persist, + pty: self.pty, + environment: self.environment.clone(), + current_dir: self.current_dir.clone(), + }, + ))) .await?; // Wait until we get the first response, and get id from 
proc started let (id, origin_id) = match mailbox.next().await { - Some(res) if res.payload.len() != 1 => { - return Err(RemoteProcessError::TransportError(TransportError::IoError( - io::Error::new(io::ErrorKind::InvalidData, "Got wrong payload size"), - ))); - } Some(res) => { let origin_id = res.origin_id; - match res.payload.into_iter().next().unwrap() { - ResponseData::ProcSpawned { id } => (id, origin_id), - ResponseData::Error(x) => { - return Err(RemoteProcessError::TransportError(TransportError::IoError( - x.into(), - ))) + match res.payload { + DistantMsg::Single(DistantResponseData::ProcSpawned { id }) => (id, origin_id), + DistantMsg::Single(DistantResponseData::Error(x)) => return Err(x.into()), + DistantMsg::Single(x) => { + return Err(io::Error::new( + io::ErrorKind::InvalidData, + format!("Got response type of {}", x.as_ref()), + )) } - x => { - return Err(RemoteProcessError::TransportError(TransportError::IoError( - io::Error::new( - io::ErrorKind::InvalidData, - format!("Got response type of {}", x.as_ref()), - ), - ))) + DistantMsg::Batch(_) => { + return Err(io::Error::new( + io::ErrorKind::InvalidData, + "Got batch instead of single response", + )); } } } - None => { - return Err(RemoteProcessError::TransportError(TransportError::IoError( - io::Error::from(io::ErrorKind::ConnectionAborted), - ))) - } + None => return Err(io::Error::from(io::ErrorKind::ConnectionAborted)), }; // Create channels for our stdin/stdout/stderr @@ -165,7 +170,7 @@ impl RemoteProcess { _ = abort_req_task_rx.recv() => { panic!("killed"); } - res = process_outgoing_requests(tenant, id, channel, stdin_rx, resize_rx, kill_rx) => { + res = process_outgoing_requests( id, channel, stdin_rx, resize_rx, kill_rx) => { res } } @@ -175,13 +180,13 @@ impl RemoteProcess { let status_2 = Arc::clone(&status); let wait_task = tokio::spawn(async move { let res = match tokio::try_join!(req_task, res_task) { - Ok((_, res)) => res, - Err(x) => Err(RemoteProcessError::from(x)), + Ok((_, res)) 
=> res.map(RemoteStatus::from), + Err(x) => Err(io::Error::new(io::ErrorKind::Interrupted, x)), }; status_2.write().await.replace(res); }); - Ok(Self { + Ok(RemoteProcess { id, origin_id, abort_req_task_tx, @@ -195,30 +200,72 @@ impl RemoteProcess { status, }) } +} + +/// Represents a process on a remote machine +#[derive(Debug)] +pub struct RemoteProcess { + /// Id of the process + id: ProcessId, + + /// Id used to map back to mailbox + origin_id: String, + + // Sender to abort req task + abort_req_task_tx: mpsc::Sender<()>, + + // Sender to abort res task + abort_res_task_tx: mpsc::Sender<()>, + + /// Sender for stdin + pub stdin: Option, + + /// Receiver for stdout + pub stdout: Option, + + /// Receiver for stderr + pub stderr: Option, + + /// Sender for resize events + resizer: RemoteProcessResizer, + + /// Sender for kill events + killer: RemoteProcessKiller, + + /// Task that waits for the process to complete + wait_task: JoinHandle<()>, + + /// Handles the success and exit code for a completed process + status: Arc>>, +} +impl RemoteProcess { /// Returns the id of the running process - pub fn id(&self) -> usize { + pub fn id(&self) -> ProcessId { self.id } /// Returns the id of the request that spawned this process - pub fn origin_id(&self) -> usize { - self.origin_id + pub fn origin_id(&self) -> &str { + &self.origin_id } /// Checks if the process has completed, returning the exit status if it has, without /// consuming the process itself. Note that this does not include join errors that can /// occur when aborting and instead converts any error to a status of false. 
To acquire /// the actual error, you must call `wait` - pub async fn status(&self) -> Option<(bool, Option)> { + pub async fn status(&self) -> Option { self.status.read().await.as_ref().map(|x| match x { - Ok((success, exit_code)) => (*success, *exit_code), - Err(_) => (false, None), + Ok(status) => *status, + Err(_) => RemoteStatus { + success: false, + code: None, + }, }) } /// Waits for the process to terminate, returning the success status and an optional exit code - pub async fn wait(self) -> Result<(bool, Option), RemoteProcessError> { + pub async fn wait(self) -> io::Result { // Wait for the process to complete before we try to get the status let _ = self.wait_task.await; @@ -227,11 +274,41 @@ impl RemoteProcess { .write() .await .take() - .unwrap_or(Err(RemoteProcessError::UnexpectedEof)) + .unwrap_or_else(|| Err(errors::unexpected_eof())) + } + + /// Waits for the process to terminate, returning the success status, an optional exit code, + /// and any remaining stdout and stderr (if still attached to the process) + pub async fn output(mut self) -> io::Result { + let maybe_stdout = self.stdout.take(); + let maybe_stderr = self.stderr.take(); + + let status = self.wait().await?; + + let mut stdout = Vec::new(); + if let Some(mut reader) = maybe_stdout { + while let Ok(data) = reader.read().await { + stdout.extend(&data); + } + } + + let mut stderr = Vec::new(); + if let Some(mut reader) = maybe_stderr { + while let Ok(data) = reader.read().await { + stderr.extend(&data); + } + } + + Ok(RemoteOutput { + success: status.success, + code: status.code, + stdout, + stderr, + }) } /// Resizes the pty of the remote process if it is attached to one - pub async fn resize(&self, size: PtySize) -> Result<(), RemoteProcessError> { + pub async fn resize(&self, size: PtySize) -> io::Result<()> { self.resizer.resize(size).await } @@ -241,7 +318,7 @@ impl RemoteProcess { } /// Submits a kill request for the running process - pub async fn kill(&mut self) -> Result<(), 
RemoteProcessError> { + pub async fn kill(&mut self) -> io::Result<()> { self.killer.kill().await } @@ -265,11 +342,11 @@ pub struct RemoteProcessResizer(mpsc::Sender); impl RemoteProcessResizer { /// Resizes the pty of the remote process if it is attached to one - pub async fn resize(&self, size: PtySize) -> Result<(), RemoteProcessError> { + pub async fn resize(&self, size: PtySize) -> io::Result<()> { self.0 .send(size) .await - .map_err(|_| RemoteProcessError::ChannelDead)?; + .map_err(|_| errors::dead_channel())?; Ok(()) } } @@ -280,11 +357,8 @@ pub struct RemoteProcessKiller(mpsc::Sender<()>); impl RemoteProcessKiller { /// Submits a kill request for the running process - pub async fn kill(&mut self) -> Result<(), RemoteProcessError> { - self.0 - .send(()) - .await - .map_err(|_| RemoteProcessError::ChannelDead)?; + pub async fn kill(&mut self) -> io::Result<()> { + self.0.send(()).await.map_err(|_| errors::dead_channel())?; Ok(()) } } @@ -421,46 +495,42 @@ impl RemoteStderr { /// Helper function that loops, processing outgoing stdin requests to a remote process as well as /// supporting a kill request to terminate the remote process async fn process_outgoing_requests( - tenant: String, - id: usize, - mut channel: SessionChannel, + id: ProcessId, + mut channel: DistantChannel, mut stdin_rx: mpsc::Receiver>, mut resize_rx: mpsc::Receiver, mut kill_rx: mpsc::Receiver<()>, -) -> Result<(), RemoteProcessError> { +) -> io::Result<()> { let result = loop { tokio::select! 
{ data = stdin_rx.recv() => { match data { Some(data) => channel.fire( Request::new( - tenant.as_str(), - vec![RequestData::ProcStdin { id, data }] + DistantMsg::Single(DistantRequestData::ProcStdin { id, data }) ) ).await?, - None => break Err(RemoteProcessError::ChannelDead), + None => break Err(errors::dead_channel()), } } size = resize_rx.recv() => { match size { Some(size) => channel.fire( Request::new( - tenant.as_str(), - vec![RequestData::ProcResizePty { id, size }] + DistantMsg::Single(DistantRequestData::ProcResizePty { id, size }) ) ).await?, - None => break Err(RemoteProcessError::ChannelDead), + None => break Err(errors::dead_channel()), } } msg = kill_rx.recv() => { if msg.is_some() { channel.fire(Request::new( - tenant.as_str(), - vec![RequestData::ProcKill { id }], + DistantMsg::Single(DistantRequestData::ProcKill { id }) )).await?; break Ok(()); } else { - break Err(RemoteProcessError::ChannelDead); + break Err(errors::dead_channel()); } } } @@ -472,16 +542,18 @@ async fn process_outgoing_requests( /// Helper function that loops, processing incoming stdout & stderr requests from a remote process async fn process_incoming_responses( - proc_id: usize, - mut mailbox: Mailbox, + proc_id: ProcessId, + mut mailbox: Mailbox>>, stdout_tx: mpsc::Sender>, stderr_tx: mpsc::Sender>, kill_tx: mpsc::Sender<()>, -) -> Result<(bool, Option), RemoteProcessError> { +) -> io::Result<(bool, Option)> { while let Some(res) = mailbox.next().await { + let payload = res.payload.into_vec(); + // Check if any of the payload data is the termination - let exit_status = res.payload.iter().find_map(|data| match data { - ResponseData::ProcDone { id, success, code } if *id == proc_id => { + let exit_status = payload.iter().find_map(|data| match data { + DistantResponseData::ProcDone { id, success, code } if *id == proc_id => { Some((*success, *code)) } _ => None, @@ -489,12 +561,12 @@ async fn process_incoming_responses( // Next, check for stdout/stderr and send them along our 
channels // TODO: What should we do about unexpected data? For now, just ignore - for data in res.payload { + for data in payload { match data { - ResponseData::ProcStdout { id, data } if id == proc_id => { + DistantResponseData::ProcStdout { id, data } if id == proc_id => { let _ = stdout_tx.send(data).await; } - ResponseData::ProcStderr { id, data } if id == proc_id => { + DistantResponseData::ProcStderr { id, data } if id == proc_id => { let _ = stderr_tx.send(data).await; } _ => {} @@ -514,62 +586,72 @@ async fn process_incoming_responses( let _ = kill_tx.try_send(()); trace!("Process incoming channel closed"); - Err(RemoteProcessError::UnexpectedEof) + Err(errors::unexpected_eof()) +} + +mod errors { + use std::io; + + pub fn dead_channel() -> io::Error { + io::Error::new(io::ErrorKind::BrokenPipe, "Channel is dead") + } + + pub fn unexpected_eof() -> io::Error { + io::Error::from(io::ErrorKind::UnexpectedEof) + } } #[cfg(test)] mod tests { use super::*; use crate::{ - client::Session, - data::{Error, ErrorKind, Response}, - net::{InmemoryStream, PlainCodec, Transport}, + client::DistantClient, + data::{Error, ErrorKind}, + }; + use distant_net::{ + Client, FramedTransport, InmemoryTransport, IntoSplit, PlainCodec, Response, + TypedAsyncRead, TypedAsyncWrite, }; use std::time::Duration; - fn make_session() -> (Transport, Session) { - let (t1, t2) = Transport::make_pair(); - (t1, Session::initialize(t2).unwrap()) + fn make_session() -> ( + FramedTransport, + DistantClient, + ) { + let (t1, t2) = FramedTransport::pair(100); + let (writer, reader) = t2.into_split(); + (t1, Client::new(writer, reader).unwrap()) } #[tokio::test] - async fn spawn_should_return_invalid_data_if_payload_size_unexpected() { + async fn spawn_should_return_invalid_data_if_received_batch_response() { let (mut transport, session) = make_session(); // Create a task for process spawning as we need to handle the request and a response // in a separate async block let spawn_task = 
tokio::spawn(async move { - RemoteProcess::spawn( - String::from("test-tenant"), - session.clone_channel(), - String::from("cmd"), - vec![String::from("arg")], - false, - None, - ) - .await + RemoteCommand::new() + .spawn(session.clone_channel(), String::from("cmd arg")) + .await }); // Wait until we get the request from the session - let req = transport.receive::().await.unwrap().unwrap(); + let req: Request> = transport.read().await.unwrap().unwrap(); // Send back a response through the session transport - .send(Response::new("test-tenant", req.id, Vec::new())) + .write(Response::new( + req.id, + DistantMsg::Batch(vec![DistantResponseData::ProcSpawned { id: 1 }]), + )) .await .unwrap(); // Get the spawn result and verify - let result = spawn_task.await.unwrap(); - assert!( - matches!( - &result, - Err(RemoteProcessError::TransportError(TransportError::IoError(x))) - if x.kind() == io::ErrorKind::InvalidData - ), - "Unexpected result: {:?}", - result - ); + match spawn_task.await.unwrap() { + Err(x) if x.kind() == io::ErrorKind::InvalidData => {} + x => panic!("Unexpected result: {:?}", x), + } } #[tokio::test] @@ -579,44 +661,31 @@ mod tests { // Create a task for process spawning as we need to handle the request and a response // in a separate async block let spawn_task = tokio::spawn(async move { - RemoteProcess::spawn( - String::from("test-tenant"), - session.clone_channel(), - String::from("cmd"), - vec![String::from("arg")], - false, - None, - ) - .await + RemoteCommand::new() + .spawn(session.clone_channel(), String::from("cmd arg")) + .await }); // Wait until we get the request from the session - let req = transport.receive::().await.unwrap().unwrap(); + let req: Request> = transport.read().await.unwrap().unwrap(); // Send back a response through the session transport - .send(Response::new( - "test-tenant", + .write(Response::new( req.id, - vec![ResponseData::Error(Error { + DistantMsg::Single(DistantResponseData::Error(Error { kind: ErrorKind::BrokenPipe, 
description: String::from("some error"), - })], + })), )) .await .unwrap(); // Get the spawn result and verify - let result = spawn_task.await.unwrap(); - assert!( - matches!( - &result, - Err(RemoteProcessError::TransportError(TransportError::IoError(x))) - if x.kind() == io::ErrorKind::BrokenPipe - ), - "Unexpected result: {:?}", - result - ); + match spawn_task.await.unwrap() { + Err(x) if x.kind() == io::ErrorKind::BrokenPipe => {} + x => panic!("Unexpected result: {:?}", x), + } } #[tokio::test] @@ -626,27 +695,20 @@ mod tests { // Create a task for process spawning as we need to handle the request and a response // in a separate async block let spawn_task = tokio::spawn(async move { - RemoteProcess::spawn( - String::from("test-tenant"), - session.clone_channel(), - String::from("cmd"), - vec![String::from("arg")], - false, - None, - ) - .await + RemoteCommand::new() + .spawn(session.clone_channel(), String::from("cmd arg")) + .await }); // Wait until we get the request from the session - let req = transport.receive::().await.unwrap().unwrap(); + let req: Request> = transport.read().await.unwrap().unwrap(); // Send back a response through the session let id = 12345; transport - .send(Response::new( - "test-tenant", + .write(Response::new( req.id, - vec![ResponseData::ProcSpawned { id }], + DistantMsg::Single(DistantResponseData::ProcSpawned { id }), )) .await .unwrap(); @@ -658,12 +720,10 @@ mod tests { // Ensure that the other tasks are aborted before continuing tokio::task::yield_now().await; - let result = proc.kill().await; - assert!( - matches!(result, Err(RemoteProcessError::ChannelDead)), - "Unexpected result: {:?}", - result - ); + match proc.kill().await { + Err(x) if x.kind() == io::ErrorKind::BrokenPipe => {} + x => panic!("Unexpected result: {:?}", x), + } } #[tokio::test] @@ -673,27 +733,20 @@ mod tests { // Create a task for process spawning as we need to handle the request and a response // in a separate async block let spawn_task = 
tokio::spawn(async move { - RemoteProcess::spawn( - String::from("test-tenant"), - session.clone_channel(), - String::from("cmd"), - vec![String::from("arg")], - false, - None, - ) - .await + RemoteCommand::new() + .spawn(session.clone_channel(), String::from("cmd arg")) + .await }); // Wait until we get the request from the session - let req = transport.receive::().await.unwrap().unwrap(); + let req: Request> = transport.read().await.unwrap().unwrap(); // Send back a response through the session let id = 12345; transport - .send(Response::new( - "test-tenant", + .write(Response::new( req.id, - vec![ResponseData::ProcSpawned { id }], + DistantMsg::Single(DistantResponseData::ProcSpawned { id }), )) .await .unwrap(); @@ -703,13 +756,13 @@ mod tests { assert!(proc.kill().await.is_ok(), "Failed to send kill request"); // Verify the kill request was sent - let req = transport.receive::().await.unwrap().unwrap(); - assert_eq!( - req.payload.len(), - 1, - "Unexpected payload length for kill request" - ); - assert_eq!(req.payload[0], RequestData::ProcKill { id }); + let req: Request> = transport.read().await.unwrap().unwrap(); + match req.payload { + DistantMsg::Single(DistantRequestData::ProcKill { id: proc_id }) => { + assert_eq!(proc_id, id) + } + x => panic!("Unexpected request: {:?}", x), + } // Verify we can no longer write to stdin anymore assert_eq!( @@ -731,27 +784,20 @@ mod tests { // Create a task for process spawning as we need to handle the request and a response // in a separate async block let spawn_task = tokio::spawn(async move { - RemoteProcess::spawn( - String::from("test-tenant"), - session.clone_channel(), - String::from("cmd"), - vec![String::from("arg")], - false, - None, - ) - .await + RemoteCommand::new() + .spawn(session.clone_channel(), String::from("cmd arg")) + .await }); // Wait until we get the request from the session - let req = transport.receive::().await.unwrap().unwrap(); + let req: Request> = transport.read().await.unwrap().unwrap(); 
// Send back a response through the session let id = 12345; transport - .send(Response::new( - "test-tenant", + .write(Response::new( req.id, - vec![ResponseData::ProcSpawned { id }], + DistantMsg::Single(DistantResponseData::ProcSpawned { id }), )) .await .unwrap(); @@ -766,15 +812,10 @@ mod tests { .unwrap(); // Verify that a request is made through the session - match &transport - .receive::() - .await - .unwrap() - .unwrap() - .payload[0] - { - RequestData::ProcStdin { id, data } => { - assert_eq!(*id, 12345); + let req: Request> = transport.read().await.unwrap().unwrap(); + match req.payload { + DistantMsg::Single(DistantRequestData::ProcStdin { id, data }) => { + assert_eq!(id, 12345); assert_eq!(data, b"some input"); } x => panic!("Unexpected request: {:?}", x), @@ -788,27 +829,20 @@ mod tests { // Create a task for process spawning as we need to handle the request and a response // in a separate async block let spawn_task = tokio::spawn(async move { - RemoteProcess::spawn( - String::from("test-tenant"), - session.clone_channel(), - String::from("cmd"), - vec![String::from("arg")], - false, - None, - ) - .await + RemoteCommand::new() + .spawn(session.clone_channel(), String::from("cmd arg")) + .await }); // Wait until we get the request from the session - let req = transport.receive::().await.unwrap().unwrap(); + let req: Request> = transport.read().await.unwrap().unwrap(); // Send back a response through the session let id = 12345; transport - .send(Response::new( - "test-tenant", - req.id, - vec![ResponseData::ProcSpawned { id }], + .write(Response::new( + req.id.clone(), + DistantMsg::Single(DistantResponseData::ProcSpawned { id }), )) .await .unwrap(); @@ -817,13 +851,12 @@ mod tests { let mut proc = spawn_task.await.unwrap().unwrap(); transport - .send(Response::new( - "test-tenant", + .write(Response::new( req.id, - vec![ResponseData::ProcStdout { + DistantMsg::Single(DistantResponseData::ProcStdout { id, data: b"some out".to_vec(), - }], + }), )) 
.await .unwrap(); @@ -839,27 +872,20 @@ mod tests { // Create a task for process spawning as we need to handle the request and a response // in a separate async block let spawn_task = tokio::spawn(async move { - RemoteProcess::spawn( - String::from("test-tenant"), - session.clone_channel(), - String::from("cmd"), - vec![String::from("arg")], - false, - None, - ) - .await + RemoteCommand::new() + .spawn(session.clone_channel(), String::from("cmd arg")) + .await }); // Wait until we get the request from the session - let req = transport.receive::().await.unwrap().unwrap(); + let req: Request> = transport.read().await.unwrap().unwrap(); // Send back a response through the session let id = 12345; transport - .send(Response::new( - "test-tenant", - req.id, - vec![ResponseData::ProcSpawned { id }], + .write(Response::new( + req.id.clone(), + DistantMsg::Single(DistantResponseData::ProcSpawned { id }), )) .await .unwrap(); @@ -868,13 +894,12 @@ mod tests { let mut proc = spawn_task.await.unwrap().unwrap(); transport - .send(Response::new( - "test-tenant", + .write(Response::new( req.id, - vec![ResponseData::ProcStderr { + DistantMsg::Single(DistantResponseData::ProcStderr { id, data: b"some err".to_vec(), - }], + }), )) .await .unwrap(); @@ -890,27 +915,20 @@ mod tests { // Create a task for process spawning as we need to handle the request and a response // in a separate async block let spawn_task = tokio::spawn(async move { - RemoteProcess::spawn( - String::from("test-tenant"), - session.clone_channel(), - String::from("cmd"), - vec![String::from("arg")], - false, - None, - ) - .await + RemoteCommand::new() + .spawn(session.clone_channel(), String::from("cmd arg")) + .await }); // Wait until we get the request from the session - let req = transport.receive::().await.unwrap().unwrap(); + let req: Request> = transport.read().await.unwrap().unwrap(); // Send back a response through the session let id = 12345; transport - .send(Response::new( - "test-tenant", + 
.write(Response::new( req.id, - vec![ResponseData::ProcSpawned { id }], + DistantMsg::Single(DistantResponseData::ProcSpawned { id }), )) .await .unwrap(); @@ -929,27 +947,20 @@ mod tests { // Create a task for process spawning as we need to handle the request and a response // in a separate async block let spawn_task = tokio::spawn(async move { - RemoteProcess::spawn( - String::from("test-tenant"), - session.clone_channel(), - String::from("cmd"), - vec![String::from("arg")], - false, - None, - ) - .await + RemoteCommand::new() + .spawn(session.clone_channel(), String::from("cmd arg")) + .await }); // Wait until we get the request from the session - let req = transport.receive::().await.unwrap().unwrap(); + let req: Request> = transport.read().await.unwrap().unwrap(); // Send back a response through the session let id = 12345; transport - .send(Response::new( - "test-tenant", + .write(Response::new( req.id, - vec![ResponseData::ProcSpawned { id }], + DistantMsg::Single(DistantResponseData::ProcSpawned { id }), )) .await .unwrap(); @@ -964,7 +975,13 @@ mod tests { // Peek at the status to confirm the result let result = proc.status().await; match result { - Some((false, None)) => {} + Some(status) => { + assert!(!status.success, "Status unexpectedly reported success"); + assert!( + status.code.is_none(), + "Status unexpectedly reported exit code" + ); + } x => panic!("Unexpected result: {:?}", x), } } @@ -976,27 +993,20 @@ mod tests { // Create a task for process spawning as we need to handle the request and a response // in a separate async block let spawn_task = tokio::spawn(async move { - RemoteProcess::spawn( - String::from("test-tenant"), - session.clone_channel(), - String::from("cmd"), - vec![String::from("arg")], - false, - None, - ) - .await + RemoteCommand::new() + .spawn(session.clone_channel(), String::from("cmd arg")) + .await }); // Wait until we get the request from the session - let req = transport.receive::().await.unwrap().unwrap(); + let req: 
Request> = transport.read().await.unwrap().unwrap(); // Send back a response through the session let id = 12345; transport - .send(Response::new( - "test-tenant", - req.id, - vec![ResponseData::ProcSpawned { id }], + .write(Response::new( + req.id.clone(), + DistantMsg::Single(DistantResponseData::ProcSpawned { id }), )) .await .unwrap(); @@ -1006,14 +1016,13 @@ mod tests { // Send a process completion response to pass along exit status and conclude wait transport - .send(Response::new( - "test-tenant", + .write(Response::new( req.id, - vec![ResponseData::ProcDone { + DistantMsg::Single(DistantResponseData::ProcDone { id, success: true, code: Some(123), - }], + }), )) .await .unwrap(); @@ -1022,7 +1031,13 @@ mod tests { tokio::time::sleep(Duration::from_millis(100)).await; // Finally, verify that we complete and get the expected results - assert_eq!(proc.status().await, Some((true, Some(123)))); + assert_eq!( + proc.status().await, + Some(RemoteStatus { + success: true, + code: Some(123) + }) + ); } #[tokio::test] @@ -1032,27 +1047,20 @@ mod tests { // Create a task for process spawning as we need to handle the request and a response // in a separate async block let spawn_task = tokio::spawn(async move { - RemoteProcess::spawn( - String::from("test-tenant"), - session.clone_channel(), - String::from("cmd"), - vec![String::from("arg")], - false, - None, - ) - .await + RemoteCommand::new() + .spawn(session.clone_channel(), String::from("cmd arg")) + .await }); // Wait until we get the request from the session - let req = transport.receive::().await.unwrap().unwrap(); + let req: Request> = transport.read().await.unwrap().unwrap(); // Send back a response through the session let id = 12345; transport - .send(Response::new( - "test-tenant", + .write(Response::new( req.id, - vec![ResponseData::ProcSpawned { id }], + DistantMsg::Single(DistantResponseData::ProcSpawned { id }), )) .await .unwrap(); @@ -1061,12 +1069,10 @@ mod tests { let proc = 
spawn_task.await.unwrap().unwrap(); proc.abort(); - let result = proc.wait().await; - assert!( - matches!(result, Err(RemoteProcessError::WaitFailed(_))), - "Unexpected result: {:?}", - result - ); + match proc.wait().await { + Err(x) if x.kind() == io::ErrorKind::Interrupted => {} + x => panic!("Unexpected result: {:?}", x), + } } #[tokio::test] @@ -1076,27 +1082,20 @@ mod tests { // Create a task for process spawning as we need to handle the request and a response // in a separate async block let spawn_task = tokio::spawn(async move { - RemoteProcess::spawn( - String::from("test-tenant"), - session.clone_channel(), - String::from("cmd"), - vec![String::from("arg")], - false, - None, - ) - .await + RemoteCommand::new() + .spawn(session.clone_channel(), String::from("cmd arg")) + .await }); // Wait until we get the request from the session - let req = transport.receive::().await.unwrap().unwrap(); + let req: Request> = transport.read().await.unwrap().unwrap(); // Send back a response through the session let id = 12345; transport - .send(Response::new( - "test-tenant", + .write(Response::new( req.id, - vec![ResponseData::ProcSpawned { id }], + DistantMsg::Single(DistantResponseData::ProcSpawned { id }), )) .await .unwrap(); @@ -1112,12 +1111,10 @@ mod tests { // Ensure that the other tasks are cancelled before continuing tokio::task::yield_now().await; - let result = proc.wait().await; - assert!( - matches!(result, Err(RemoteProcessError::UnexpectedEof)), - "Unexpected result: {:?}", - result - ); + match proc.wait().await { + Err(x) if x.kind() == io::ErrorKind::UnexpectedEof => {} + x => panic!("Unexpected result: {:?}", x), + } } #[tokio::test] @@ -1127,27 +1124,20 @@ mod tests { // Create a task for process spawning as we need to handle the request and a response // in a separate async block let spawn_task = tokio::spawn(async move { - RemoteProcess::spawn( - String::from("test-tenant"), - session.clone_channel(), - String::from("cmd"), - 
vec![String::from("arg")], - false, - None, - ) - .await + RemoteCommand::new() + .spawn(session.clone_channel(), String::from("cmd arg")) + .await }); // Wait until we get the request from the session - let req = transport.receive::().await.unwrap().unwrap(); + let req: Request> = transport.read().await.unwrap().unwrap(); // Send back a response through the session let id = 12345; transport - .send(Response::new( - "test-tenant", - req.id, - vec![ResponseData::ProcSpawned { id }], + .write(Response::new( + req.id.clone(), + DistantMsg::Single(DistantResponseData::ProcSpawned { id }), )) .await .unwrap(); @@ -1158,19 +1148,102 @@ mod tests { // Send a process completion response to pass along exit status and conclude wait transport - .send(Response::new( - "test-tenant", + .write(Response::new( req.id, - vec![ResponseData::ProcDone { + DistantMsg::Single(DistantResponseData::ProcDone { id, success: false, code: Some(123), - }], + }), )) .await .unwrap(); // Finally, verify that we complete and get the expected results - assert_eq!(proc_wait_task.await.unwrap().unwrap(), (false, Some(123))); + assert_eq!( + proc_wait_task.await.unwrap().unwrap(), + RemoteStatus { + success: false, + code: Some(123) + } + ); + } + + #[tokio::test] + async fn receiving_done_response_should_result_in_output_returning_exit_information() { + let (mut transport, session) = make_session(); + + // Create a task for process spawning as we need to handle the request and a response + // in a separate async block + let spawn_task = tokio::spawn(async move { + RemoteCommand::new() + .spawn(session.clone_channel(), String::from("cmd arg")) + .await + }); + + // Wait until we get the request from the session + let req: Request> = transport.read().await.unwrap().unwrap(); + + // Send back a response through the session + let id = 12345; + transport + .write(Response::new( + req.id.clone(), + DistantMsg::Single(DistantResponseData::ProcSpawned { id }), + )) + .await + .unwrap(); + + // Receive the 
process and then spawn a task for it to complete + let proc = spawn_task.await.unwrap().unwrap(); + let proc_output_task = tokio::spawn(proc.output()); + + // Send some stdout + transport + .write(Response::new( + req.id.clone(), + DistantMsg::Single(DistantResponseData::ProcStdout { + id, + data: b"some out".to_vec(), + }), + )) + .await + .unwrap(); + + // Send some stderr + transport + .write(Response::new( + req.id.clone(), + DistantMsg::Single(DistantResponseData::ProcStderr { + id, + data: b"some err".to_vec(), + }), + )) + .await + .unwrap(); + + // Send a process completion response to pass along exit status and conclude wait + transport + .write(Response::new( + req.id, + DistantMsg::Single(DistantResponseData::ProcDone { + id, + success: false, + code: Some(123), + }), + )) + .await + .unwrap(); + + // Finally, verify that we complete and get the expected results + assert_eq!( + proc_output_task.await.unwrap().unwrap(), + RemoteOutput { + success: false, + code: Some(123), + stdout: b"some out".to_vec(), + stderr: b"some err".to_vec(), + } + ); } } diff --git a/distant-core/src/client/session/ext.rs b/distant-core/src/client/session/ext.rs deleted file mode 100644 index b62bd67..0000000 --- a/distant-core/src/client/session/ext.rs +++ /dev/null @@ -1,514 +0,0 @@ -use crate::{ - client::{ - RemoteLspProcess, RemoteProcess, RemoteProcessError, SessionChannel, UnwatchError, - WatchError, Watcher, - }, - data::{ - ChangeKindSet, DirEntry, Error as Failure, Metadata, PtySize, Request, RequestData, - ResponseData, SystemInfo, - }, - net::TransportError, -}; -use derive_more::{Display, Error, From}; -use std::{future::Future, path::PathBuf, pin::Pin}; - -/// Represents an error that can occur related to convenience functions tied to a -/// [`SessionChannel`] through [`SessionChannelExt`] -#[derive(Debug, Display, Error, From)] -pub enum SessionChannelExtError { - /// Occurs when the remote action fails - Failure(#[error(not(source))] Failure), - - /// Occurs 
when a transport error is encountered - TransportError(TransportError), - - /// Occurs when receiving a response that was not expected - MismatchedResponse, -} - -pub type AsyncReturn<'a, T, E = SessionChannelExtError> = - Pin> + Send + 'a>>; - -/// Provides convenience functions on top of a [`SessionChannel`] -pub trait SessionChannelExt { - /// Appends to a remote file using the data from a collection of bytes - fn append_file( - &mut self, - tenant: impl Into, - path: impl Into, - data: impl Into>, - ) -> AsyncReturn<'_, ()>; - - /// Appends to a remote file using the data from a string - fn append_file_text( - &mut self, - tenant: impl Into, - path: impl Into, - data: impl Into, - ) -> AsyncReturn<'_, ()>; - - /// Copies a remote file or directory from src to dst - fn copy( - &mut self, - tenant: impl Into, - src: impl Into, - dst: impl Into, - ) -> AsyncReturn<'_, ()>; - - /// Creates a remote directory, optionally creating all parent components if specified - fn create_dir( - &mut self, - tenant: impl Into, - path: impl Into, - all: bool, - ) -> AsyncReturn<'_, ()>; - - /// Checks if a path exists on a remote machine - fn exists( - &mut self, - tenant: impl Into, - path: impl Into, - ) -> AsyncReturn<'_, bool>; - - /// Retrieves metadata about a path on a remote machine - fn metadata( - &mut self, - tenant: impl Into, - path: impl Into, - canonicalize: bool, - resolve_file_type: bool, - ) -> AsyncReturn<'_, Metadata>; - - /// Reads entries from a directory, returning a tuple of directory entries and failures - fn read_dir( - &mut self, - tenant: impl Into, - path: impl Into, - depth: usize, - absolute: bool, - canonicalize: bool, - include_root: bool, - ) -> AsyncReturn<'_, (Vec, Vec)>; - - /// Reads a remote file as a collection of bytes - fn read_file( - &mut self, - tenant: impl Into, - path: impl Into, - ) -> AsyncReturn<'_, Vec>; - - /// Returns a remote file as a string - fn read_file_text( - &mut self, - tenant: impl Into, - path: impl Into, - ) -> 
AsyncReturn<'_, String>; - - /// Removes a remote file or directory, supporting removal of non-empty directories if - /// force is true - fn remove( - &mut self, - tenant: impl Into, - path: impl Into, - force: bool, - ) -> AsyncReturn<'_, ()>; - - /// Renames a remote file or directory from src to dst - fn rename( - &mut self, - tenant: impl Into, - src: impl Into, - dst: impl Into, - ) -> AsyncReturn<'_, ()>; - - /// Watches a remote file or directory - fn watch( - &mut self, - tenant: impl Into, - path: impl Into, - recursive: bool, - only: impl Into, - except: impl Into, - ) -> AsyncReturn<'_, Watcher, WatchError>; - - /// Unwatches a remote file or directory - fn unwatch( - &mut self, - tenant: impl Into, - path: impl Into, - ) -> AsyncReturn<'_, (), UnwatchError>; - - /// Spawns a process on the remote machine - fn spawn( - &mut self, - tenant: impl Into, - cmd: impl Into, - args: Vec>, - persist: bool, - pty: Option, - ) -> AsyncReturn<'_, RemoteProcess, RemoteProcessError>; - - /// Spawns an LSP process on the remote machine - fn spawn_lsp( - &mut self, - tenant: impl Into, - cmd: impl Into, - args: Vec>, - persist: bool, - pty: Option, - ) -> AsyncReturn<'_, RemoteLspProcess, RemoteProcessError>; - - /// Retrieves information about the remote system - fn system_info(&mut self, tenant: impl Into) -> AsyncReturn<'_, SystemInfo>; - - /// Writes a remote file with the data from a collection of bytes - fn write_file( - &mut self, - tenant: impl Into, - path: impl Into, - data: impl Into>, - ) -> AsyncReturn<'_, ()>; - - /// Writes a remote file with the data from a string - fn write_file_text( - &mut self, - tenant: impl Into, - path: impl Into, - data: impl Into, - ) -> AsyncReturn<'_, ()>; -} - -macro_rules! 
make_body { - ($self:expr, $tenant:expr, $data:expr, @ok) => { - make_body!($self, $tenant, $data, |data| { - match data { - ResponseData::Ok => Ok(()), - ResponseData::Error(x) => Err(SessionChannelExtError::Failure(x)), - _ => Err(SessionChannelExtError::MismatchedResponse), - } - }) - }; - - ($self:expr, $tenant:expr, $data:expr, $and_then:expr) => {{ - let req = Request::new($tenant, vec![$data]); - Box::pin(async move { - $self - .send(req) - .await - .map_err(SessionChannelExtError::from) - .and_then(|res| { - if res.payload.len() == 1 { - Ok(res.payload.into_iter().next().unwrap()) - } else { - Err(SessionChannelExtError::MismatchedResponse) - } - }) - .and_then($and_then) - }) - }}; -} - -impl SessionChannelExt for SessionChannel { - fn append_file( - &mut self, - tenant: impl Into, - path: impl Into, - data: impl Into>, - ) -> AsyncReturn<'_, ()> { - make_body!( - self, - tenant, - RequestData::FileAppend { path: path.into(), data: data.into() }, - @ok - ) - } - - fn append_file_text( - &mut self, - tenant: impl Into, - path: impl Into, - data: impl Into, - ) -> AsyncReturn<'_, ()> { - make_body!( - self, - tenant, - RequestData::FileAppendText { path: path.into(), text: data.into() }, - @ok - ) - } - - fn copy( - &mut self, - tenant: impl Into, - src: impl Into, - dst: impl Into, - ) -> AsyncReturn<'_, ()> { - make_body!( - self, - tenant, - RequestData::Copy { src: src.into(), dst: dst.into() }, - @ok - ) - } - - fn create_dir( - &mut self, - tenant: impl Into, - path: impl Into, - all: bool, - ) -> AsyncReturn<'_, ()> { - make_body!( - self, - tenant, - RequestData::DirCreate { path: path.into(), all }, - @ok - ) - } - - fn exists( - &mut self, - tenant: impl Into, - path: impl Into, - ) -> AsyncReturn<'_, bool> { - make_body!( - self, - tenant, - RequestData::Exists { path: path.into() }, - |data| match data { - ResponseData::Exists { value } => Ok(value), - ResponseData::Error(x) => Err(SessionChannelExtError::Failure(x)), - _ => 
Err(SessionChannelExtError::MismatchedResponse), - } - ) - } - - fn metadata( - &mut self, - tenant: impl Into, - path: impl Into, - canonicalize: bool, - resolve_file_type: bool, - ) -> AsyncReturn<'_, Metadata> { - make_body!( - self, - tenant, - RequestData::Metadata { - path: path.into(), - canonicalize, - resolve_file_type - }, - |data| match data { - ResponseData::Metadata(x) => Ok(x), - ResponseData::Error(x) => Err(SessionChannelExtError::Failure(x)), - _ => Err(SessionChannelExtError::MismatchedResponse), - } - ) - } - - fn read_dir( - &mut self, - tenant: impl Into, - path: impl Into, - depth: usize, - absolute: bool, - canonicalize: bool, - include_root: bool, - ) -> AsyncReturn<'_, (Vec, Vec)> { - make_body!( - self, - tenant, - RequestData::DirRead { - path: path.into(), - depth, - absolute, - canonicalize, - include_root - }, - |data| match data { - ResponseData::DirEntries { entries, errors } => Ok((entries, errors)), - ResponseData::Error(x) => Err(SessionChannelExtError::Failure(x)), - _ => Err(SessionChannelExtError::MismatchedResponse), - } - ) - } - - fn read_file( - &mut self, - tenant: impl Into, - path: impl Into, - ) -> AsyncReturn<'_, Vec> { - make_body!( - self, - tenant, - RequestData::FileRead { path: path.into() }, - |data| match data { - ResponseData::Blob { data } => Ok(data), - ResponseData::Error(x) => Err(SessionChannelExtError::Failure(x)), - _ => Err(SessionChannelExtError::MismatchedResponse), - } - ) - } - - fn read_file_text( - &mut self, - tenant: impl Into, - path: impl Into, - ) -> AsyncReturn<'_, String> { - make_body!( - self, - tenant, - RequestData::FileReadText { path: path.into() }, - |data| match data { - ResponseData::Text { data } => Ok(data), - ResponseData::Error(x) => Err(SessionChannelExtError::Failure(x)), - _ => Err(SessionChannelExtError::MismatchedResponse), - } - ) - } - - fn remove( - &mut self, - tenant: impl Into, - path: impl Into, - force: bool, - ) -> AsyncReturn<'_, ()> { - make_body!( - self, - 
tenant, - RequestData::Remove { path: path.into(), force }, - @ok - ) - } - - fn rename( - &mut self, - tenant: impl Into, - src: impl Into, - dst: impl Into, - ) -> AsyncReturn<'_, ()> { - make_body!( - self, - tenant, - RequestData::Rename { src: src.into(), dst: dst.into() }, - @ok - ) - } - - fn watch( - &mut self, - tenant: impl Into, - path: impl Into, - recursive: bool, - only: impl Into, - except: impl Into, - ) -> AsyncReturn<'_, Watcher, WatchError> { - let tenant = tenant.into(); - let path = path.into(); - let only = only.into(); - let except = except.into(); - Box::pin(async move { - Watcher::watch(tenant, self.clone(), path, recursive, only, except).await - }) - } - - fn unwatch( - &mut self, - tenant: impl Into, - path: impl Into, - ) -> AsyncReturn<'_, (), UnwatchError> { - fn inner_unwatch( - channel: &mut SessionChannel, - tenant: impl Into, - path: impl Into, - ) -> AsyncReturn<'_, ()> { - make_body!( - channel, - tenant, - RequestData::Unwatch { path: path.into() }, - @ok - ) - } - - let tenant = tenant.into(); - let path = path.into(); - - Box::pin(async move { - inner_unwatch(self, tenant, path) - .await - .map_err(UnwatchError::from) - }) - } - - fn spawn( - &mut self, - tenant: impl Into, - cmd: impl Into, - args: Vec>, - persist: bool, - pty: Option, - ) -> AsyncReturn<'_, RemoteProcess, RemoteProcessError> { - let tenant = tenant.into(); - let cmd = cmd.into(); - let args = args.into_iter().map(Into::into).collect(); - Box::pin(async move { - RemoteProcess::spawn(tenant, self.clone(), cmd, args, persist, pty).await - }) - } - - fn spawn_lsp( - &mut self, - tenant: impl Into, - cmd: impl Into, - args: Vec>, - persist: bool, - pty: Option, - ) -> AsyncReturn<'_, RemoteLspProcess, RemoteProcessError> { - let tenant = tenant.into(); - let cmd = cmd.into(); - let args = args.into_iter().map(Into::into).collect(); - Box::pin(async move { - RemoteLspProcess::spawn(tenant, self.clone(), cmd, args, persist, pty).await - }) - } - - fn 
system_info(&mut self, tenant: impl Into) -> AsyncReturn<'_, SystemInfo> { - make_body!( - self, - tenant, - RequestData::SystemInfo {}, - |data| match data { - ResponseData::SystemInfo(x) => Ok(x), - ResponseData::Error(x) => Err(SessionChannelExtError::Failure(x)), - _ => Err(SessionChannelExtError::MismatchedResponse), - } - ) - } - - fn write_file( - &mut self, - tenant: impl Into, - path: impl Into, - data: impl Into>, - ) -> AsyncReturn<'_, ()> { - make_body!( - self, - tenant, - RequestData::FileWrite { path: path.into(), data: data.into() }, - @ok - ) - } - - fn write_file_text( - &mut self, - tenant: impl Into, - path: impl Into, - data: impl Into, - ) -> AsyncReturn<'_, ()> { - make_body!( - self, - tenant, - RequestData::FileWriteText { path: path.into(), text: data.into() }, - @ok - ) - } -} diff --git a/distant-core/src/client/session/info.rs b/distant-core/src/client/session/info.rs deleted file mode 100644 index a13d012..0000000 --- a/distant-core/src/client/session/info.rs +++ /dev/null @@ -1,235 +0,0 @@ -use crate::net::{SecretKey32, UnprotectedToHexKey}; -use derive_more::{Display, Error}; -use std::{ - env, - net::{IpAddr, SocketAddr}, - ops::Deref, - path::{Path, PathBuf}, - str::FromStr, -}; -use tokio::{io, net::lookup_host}; - -#[derive(Debug, PartialEq, Eq)] -pub struct SessionInfo { - pub host: String, - pub port: u16, - pub key: SecretKey32, -} - -#[derive(Copy, Clone, Debug, Display, Error, PartialEq, Eq)] -pub enum SessionInfoParseError { - #[display(fmt = "Prefix of string is invalid")] - BadPrefix, - - #[display(fmt = "Bad hex key for session")] - BadHexKey, - - #[display(fmt = "Invalid key for session")] - InvalidKey, - - #[display(fmt = "Invalid port for session")] - InvalidPort, - - #[display(fmt = "Missing address for session")] - MissingAddr, - - #[display(fmt = "Missing key for session")] - MissingKey, - - #[display(fmt = "Missing port for session")] - MissingPort, -} - -impl From for io::Error { - fn from(x: 
SessionInfoParseError) -> Self { - io::Error::new(io::ErrorKind::InvalidData, x) - } -} - -impl FromStr for SessionInfo { - type Err = SessionInfoParseError; - - fn from_str(s: &str) -> Result { - let mut tokens = s.trim().split(' ').take(5); - - // First, validate that we have the appropriate prefix - if tokens.next().ok_or(SessionInfoParseError::BadPrefix)? != "DISTANT" { - return Err(SessionInfoParseError::BadPrefix); - } - if tokens.next().ok_or(SessionInfoParseError::BadPrefix)? != "CONNECT" { - return Err(SessionInfoParseError::BadPrefix); - } - - // Second, load up the address without parsing it - let host = tokens - .next() - .ok_or(SessionInfoParseError::MissingAddr)? - .trim() - .to_string(); - - // Third, load up the port and parse it into a number - let port = tokens - .next() - .ok_or(SessionInfoParseError::MissingPort)? - .trim() - .parse::() - .map_err(|_| SessionInfoParseError::InvalidPort)?; - - // Fourth, load up the key and convert it back into a secret key from a hex slice - let key = SecretKey32::from_slice( - &hex::decode( - tokens - .next() - .ok_or(SessionInfoParseError::MissingKey)? - .trim(), - ) - .map_err(|_| SessionInfoParseError::BadHexKey)?, - ) - .map_err(|_| SessionInfoParseError::InvalidKey)?; - - Ok(SessionInfo { host, port, key }) - } -} - -impl SessionInfo { - /// Loads session from environment variables - pub fn from_environment() -> io::Result { - fn to_err(x: env::VarError) -> io::Error { - io::Error::new(io::ErrorKind::InvalidInput, x) - } - - let host = env::var("DISTANT_HOST").map_err(to_err)?; - let port = env::var("DISTANT_PORT").map_err(to_err)?; - let key = env::var("DISTANT_KEY").map_err(to_err)?; - Ok(format!("DISTANT CONNECT {} {} {}", host, port, key).parse()?) 
- } - - /// Loads session from the next line available in this program's stdin - pub fn from_stdin() -> io::Result { - let mut line = String::new(); - std::io::stdin().read_line(&mut line)?; - line.parse() - .map_err(|x| io::Error::new(io::ErrorKind::InvalidData, x)) - } - - /// Consumes the session and returns the key - pub fn into_key(self) -> SecretKey32 { - self.key - } - - /// Returns the ip address associated with the session based on the host - pub async fn to_ip_addr(&self) -> io::Result { - let addr = match self.host.parse::() { - Ok(addr) => addr, - Err(_) => lookup_host((self.host.as_str(), self.port)) - .await? - .next() - .ok_or_else(|| io::Error::new(io::ErrorKind::NotFound, "Failed to lookup_host"))? - .ip(), - }; - - Ok(addr) - } - - /// Returns socket address associated with the session - pub async fn to_socket_addr(&self) -> io::Result { - let addr = self.to_ip_addr().await?; - Ok(SocketAddr::from((addr, self.port))) - } - - /// Converts the session's key to a hex string - pub fn key_to_unprotected_string(&self) -> String { - self.key.unprotected_to_hex_key() - } - - /// Converts to unprotected string that exposes the key in the form of - /// `DISTANT CONNECT ` - pub fn to_unprotected_string(&self) -> String { - format!( - "DISTANT CONNECT {} {} {}", - self.host, - self.port, - self.key.unprotected_to_hex_key() - ) - } -} - -/// Provides operations related to working with a session that is disk-based -pub struct SessionInfoFile { - path: PathBuf, - session: SessionInfo, -} - -impl AsRef for SessionInfoFile { - fn as_ref(&self) -> &Path { - self.as_path() - } -} - -impl AsRef for SessionInfoFile { - fn as_ref(&self) -> &SessionInfo { - self.as_session() - } -} - -impl Deref for SessionInfoFile { - type Target = SessionInfo; - - fn deref(&self) -> &Self::Target { - &self.session - } -} - -impl From for SessionInfo { - fn from(sf: SessionInfoFile) -> Self { - sf.session - } -} - -impl SessionInfoFile { - /// Creates a new inmemory pointer to a 
session and its file - pub fn new(path: impl Into, session: SessionInfo) -> Self { - Self { - path: path.into(), - session, - } - } - - /// Returns a reference to the path to the session file - pub fn as_path(&self) -> &Path { - self.path.as_path() - } - - /// Returns a reference to the session - pub fn as_session(&self) -> &SessionInfo { - &self.session - } - - /// Saves a session by overwriting its current - pub async fn save(&self) -> io::Result<()> { - self.save_to(self.as_path(), true).await - } - - /// Saves a session to to a file at the specified path - /// - /// If all is true, will create all directories leading up to file's location - pub async fn save_to(&self, path: impl AsRef, all: bool) -> io::Result<()> { - if all { - if let Some(dir) = path.as_ref().parent() { - tokio::fs::create_dir_all(dir).await?; - } - } - - tokio::fs::write(path.as_ref(), self.session.to_unprotected_string()).await - } - - /// Loads a session from a file at the specified path - pub async fn load_from(path: impl AsRef) -> io::Result { - let text = tokio::fs::read_to_string(path.as_ref()).await?; - - Ok(Self { - path: path.as_ref().to_path_buf(), - session: text.parse()?, - }) - } -} diff --git a/distant-core/src/client/session/mailbox.rs b/distant-core/src/client/session/mailbox.rs deleted file mode 100644 index 20d44a1..0000000 --- a/distant-core/src/client/session/mailbox.rs +++ /dev/null @@ -1,84 +0,0 @@ -use crate::{client::utils, data::Response}; -use std::{collections::HashMap, time::Duration}; -use tokio::{io, sync::mpsc}; - -pub struct PostOffice { - mailboxes: HashMap>, -} - -impl PostOffice { - pub fn new() -> Self { - Self { - mailboxes: HashMap::new(), - } - } - - /// Creates a new mailbox using the given id and buffer size for maximum messages - pub fn make_mailbox(&mut self, id: usize, buffer: usize) -> Mailbox { - let (tx, rx) = mpsc::channel(buffer); - self.mailboxes.insert(id, tx); - - Mailbox { id, rx } - } - - /// Delivers a response to appropriate mailbox, 
returning false if no mailbox is found - /// for the response or if the mailbox is no longer receiving responses - pub async fn deliver(&mut self, res: Response) -> bool { - let id = res.origin_id; - - let success = if let Some(tx) = self.mailboxes.get_mut(&id) { - tx.send(res).await.is_ok() - } else { - false - }; - - // If failed, we want to remove the mailbox sender as it is no longer valid - if !success { - self.mailboxes.remove(&id); - } - - success - } - - /// Removes all mailboxes from post office that are closed - pub fn prune_mailboxes(&mut self) { - self.mailboxes.retain(|_, tx| !tx.is_closed()) - } - - /// Closes out all mailboxes by removing the mailboxes delivery trackers internally - pub fn clear_mailboxes(&mut self) { - self.mailboxes.clear(); - } -} - -pub struct Mailbox { - /// Represents id associated with the mailbox - id: usize, - - /// Underlying mailbox storage - rx: mpsc::Receiver, -} - -impl Mailbox { - /// Represents id associated with the mailbox - pub fn id(&self) -> usize { - self.id - } - - /// Receives next response in mailbox - pub async fn next(&mut self) -> Option { - self.rx.recv().await - } - - /// Receives next response in mailbox, waiting up to duration before timing out - pub async fn next_timeout(&mut self, duration: Duration) -> io::Result> { - utils::timeout(duration, self.next()).await - } - - /// Closes the mailbox such that it will not receive any more responses - /// - /// Any responses already in the mailbox will still be returned via `next` - pub async fn close(&mut self) { - self.rx.close() - } -} diff --git a/distant-core/src/client/session/mod.rs b/distant-core/src/client/session/mod.rs deleted file mode 100644 index bbe4221..0000000 --- a/distant-core/src/client/session/mod.rs +++ /dev/null @@ -1,504 +0,0 @@ -use crate::{ - client::utils, - constants::CLIENT_MAILBOX_CAPACITY, - data::{Request, Response}, - net::{Codec, DataStream, Transport, TransportError}, -}; -use log::*; -use serde::{Deserialize, Serialize}; 
-use std::{ - convert, - net::SocketAddr, - ops::{Deref, DerefMut}, - path::{Path, PathBuf}, - sync::{Arc, Weak}, -}; -use tokio::{ - io, - net::TcpStream, - sync::{mpsc, Mutex}, - task::{JoinError, JoinHandle}, - time::Duration, -}; - -mod ext; -pub use ext::{SessionChannelExt, SessionChannelExtError}; - -mod info; -pub use info::{SessionInfo, SessionInfoFile, SessionInfoParseError}; - -mod mailbox; -pub use mailbox::Mailbox; -use mailbox::PostOffice; - -/// Details about the session -#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)] -#[serde(rename_all = "snake_case")] -pub enum SessionDetails { - /// Indicates session is a TCP type - Tcp { addr: SocketAddr, tag: String }, - - /// Indicates session is a Unix socket type - Socket { path: PathBuf, tag: String }, - - /// Indicates session type is inmemory - Inmemory { tag: String }, - - /// Indicates session type is a custom type (such as ssh) - Custom { tag: String }, -} - -impl SessionDetails { - /// Represents the tag associated with the session - pub fn tag(&self) -> &str { - match self { - Self::Tcp { tag, .. } => tag.as_str(), - Self::Socket { tag, .. } => tag.as_str(), - Self::Inmemory { tag } => tag.as_str(), - Self::Custom { tag } => tag.as_str(), - } - } - - /// Represents the socket address associated with the session, if it has one - pub fn addr(&self) -> Option { - match self { - Self::Tcp { addr, .. } => Some(*addr), - _ => None, - } - } - - /// Represents the path associated with the session, if it has one - pub fn path(&self) -> Option<&Path> { - match self { - Self::Socket { path, .. 
} => Some(path.as_path()), - _ => None, - } - } -} - -/// Represents a session with a remote server that can be used to send requests & receive responses -pub struct Session { - /// Used to send requests to a server - channel: SessionChannel, - - /// Details about the session - details: Option, - - /// Contains the task that is running to send requests to a server - request_task: JoinHandle<()>, - - /// Contains the task that is running to receive responses from a server - response_task: JoinHandle<()>, - - /// Contains the task that runs on a timer to prune closed mailboxes - prune_task: JoinHandle<()>, -} - -impl Session { - /// Connect to a remote TCP server using the provided information - pub async fn tcp_connect(addr: SocketAddr, codec: U) -> io::Result - where - U: Codec + Send + 'static, - { - let transport = Transport::::connect(addr, codec).await?; - let details = SessionDetails::Tcp { - addr, - tag: transport.to_connection_tag(), - }; - debug!( - "Session has been established with {}", - transport - .peer_addr() - .map(|x| x.to_string()) - .unwrap_or_else(|_| String::from("???")) - ); - Self::initialize_with_details(transport, Some(details)) - } - - /// Connect to a remote TCP server, timing out after duration has passed - pub async fn tcp_connect_timeout( - addr: SocketAddr, - codec: U, - duration: Duration, - ) -> io::Result - where - U: Codec + Send + 'static, - { - utils::timeout(duration, Self::tcp_connect(addr, codec)) - .await - .and_then(convert::identity) - } - - /// Convert into underlying channel - pub fn into_channel(self) -> SessionChannel { - self.channel - } -} - -#[cfg(unix)] -impl Session { - /// Connect to a proxy unix socket - pub async fn unix_connect(path: impl AsRef, codec: U) -> io::Result - where - U: Codec + Send + 'static, - { - let p = path.as_ref(); - let transport = Transport::::connect(p, codec).await?; - let details = SessionDetails::Socket { - path: p.to_path_buf(), - tag: transport.to_connection_tag(), - }; - debug!( - 
"Session has been established with {}", - transport - .peer_addr() - .map(|x| format!("{:?}", x)) - .unwrap_or_else(|_| String::from("???")) - ); - Self::initialize_with_details(transport, Some(details)) - } - - /// Connect to a proxy unix socket, timing out after duration has passed - pub async fn unix_connect_timeout( - path: impl AsRef, - codec: U, - duration: Duration, - ) -> io::Result - where - U: Codec + Send + 'static, - { - utils::timeout(duration, Self::unix_connect(path, codec)) - .await - .and_then(convert::identity) - } -} - -impl Session { - /// Initializes a session using the provided transport and no extra details - pub fn initialize(transport: Transport) -> io::Result - where - T: DataStream, - U: Codec + Send + 'static, - { - Self::initialize_with_details(transport, None) - } - - /// Initializes a session using the provided transport and extra details - pub fn initialize_with_details( - transport: Transport, - details: Option, - ) -> io::Result - where - T: DataStream, - U: Codec + Send + 'static, - { - let (mut t_read, mut t_write) = transport.into_split(); - let post_office = Arc::new(Mutex::new(PostOffice::new())); - let weak_post_office = Arc::downgrade(&post_office); - - // Start a task that continually checks for responses and delivers them using the - // post office - let response_task = tokio::spawn(async move { - loop { - match t_read.receive::().await { - Ok(Some(res)) => { - trace!("Incoming response: {:?}", res); - let res_id = res.id; - let res_origin_id = res.origin_id; - - // Try to send response to appropriate mailbox - // NOTE: We don't log failures as errors as using fire(...) 
for a - // session is valid and would not have a mailbox - if !post_office.lock().await.deliver(res).await { - trace!( - "Response {} has no mailbox for origin {}", - res_id, - res_origin_id - ); - } - } - Ok(None) => { - debug!("Session closing response task as transport read-half closed!"); - break; - } - Err(x) => { - error!("Failed to receive response from server: {}", x); - break; - } - } - } - - // Clean up remaining mailbox senders - post_office.lock().await.clear_mailboxes(); - }); - - let (tx, mut rx) = mpsc::channel::(1); - let request_task = tokio::spawn(async move { - while let Some(req) = rx.recv().await { - if let Err(x) = t_write.send(req).await { - error!("Failed to send request to server: {}", x); - break; - } - } - }); - - // Create a task that runs once a minute and prunes mailboxes - let post_office = Weak::clone(&weak_post_office); - let prune_task = tokio::spawn(async move { - loop { - tokio::time::sleep(Duration::from_secs(60)).await; - if let Some(post_office) = Weak::upgrade(&post_office) { - post_office.lock().await.prune_mailboxes(); - } else { - break; - } - } - }); - - let channel = SessionChannel { - tx, - post_office: weak_post_office, - }; - - Ok(Self { - channel, - details, - request_task, - response_task, - prune_task, - }) - } -} - -impl Session { - /// Returns details about the session, if it has any - pub fn details(&self) -> Option<&SessionDetails> { - self.details.as_ref() - } - - /// Waits for the session to terminate, which results when the receiving end of the network - /// connection is closed (or the session is shutdown) - pub async fn wait(self) -> Result<(), JoinError> { - self.prune_task.abort(); - tokio::try_join!(self.request_task, self.response_task).map(|_| ()) - } - - /// Abort the session's current connection by forcing its tasks to abort - pub fn abort(&self) { - self.request_task.abort(); - self.response_task.abort(); - self.prune_task.abort(); - } - - /// Clones the underlying channel for requests and returns 
the cloned instance - pub fn clone_channel(&self) -> SessionChannel { - self.channel.clone() - } -} - -impl Deref for Session { - type Target = SessionChannel; - - fn deref(&self) -> &Self::Target { - &self.channel - } -} - -impl DerefMut for Session { - fn deref_mut(&mut self) -> &mut Self::Target { - &mut self.channel - } -} - -impl From for SessionChannel { - fn from(session: Session) -> Self { - session.channel - } -} - -/// Represents a sender of requests tied to a session, holding onto a weak reference of -/// mailboxes to relay responses, meaning that once the [`Session`] is closed or dropped, -/// any sent request will no longer be able to receive responses -#[derive(Clone)] -pub struct SessionChannel { - /// Used to send requests to a server - tx: mpsc::Sender, - - /// Collection of mailboxes for receiving responses to requests - post_office: Weak>, -} - -impl SessionChannel { - /// Returns true if no more requests can be transferred - pub fn is_closed(&self) -> bool { - self.tx.is_closed() - } - - /// Sends a request and returns a mailbox that can receive one or more responses, failing if - /// unable to send a request or if the session's receiving line to the remote server has - /// already been severed - pub async fn mail(&mut self, req: Request) -> Result { - trace!("Mailing request: {:?}", req); - - // First, create a mailbox using the request's id - let mailbox = Weak::upgrade(&self.post_office) - .ok_or_else(|| { - TransportError::IoError(io::Error::new( - io::ErrorKind::NotConnected, - "Session's post office is no longer available", - )) - })? 
- .lock() - .await - .make_mailbox(req.id, CLIENT_MAILBOX_CAPACITY); - - // Second, send the request - self.fire(req).await?; - - // Third, return mailbox - Ok(mailbox) - } - - /// Sends a request and waits for a response, failing if unable to send a request or if - /// the session's receiving line to the remote server has already been severed - pub async fn send(&mut self, req: Request) -> Result { - trace!("Sending request: {:?}", req); - - // Send mail and get back a mailbox - let mut mailbox = self.mail(req).await?; - - // Wait for first response, and then drop the mailbox - mailbox.next().await.ok_or_else(|| { - TransportError::IoError(io::Error::from(io::ErrorKind::ConnectionAborted)) - }) - } - - /// Sends a request and waits for a response, timing out after duration has passed - pub async fn send_timeout( - &mut self, - req: Request, - duration: Duration, - ) -> Result { - utils::timeout(duration, self.send(req)) - .await - .map_err(TransportError::from) - .and_then(convert::identity) - } - - /// Sends a request without waiting for a response; this method is able to be used even - /// if the session's receiving line to the remote server has been severed - pub async fn fire(&mut self, req: Request) -> Result<(), TransportError> { - trace!("Firing off request: {:?}", req); - self.tx - .send(req) - .await - .map_err(|x| TransportError::IoError(io::Error::new(io::ErrorKind::BrokenPipe, x))) - } - - /// Sends a request without waiting for a response, timing out after duration has passed - pub async fn fire_timeout( - &mut self, - req: Request, - duration: Duration, - ) -> Result<(), TransportError> { - utils::timeout(duration, self.fire(req)) - .await - .map_err(TransportError::from) - .and_then(convert::identity) - } -} - -#[cfg(test)] -mod tests { - use super::*; - use crate::{ - constants::test::TENANT, - data::{RequestData, ResponseData}, - }; - use std::time::Duration; - - #[tokio::test] - async fn 
mail_should_return_mailbox_that_receives_responses_until_transport_closes() { - let (t1, mut t2) = Transport::make_pair(); - let mut session = Session::initialize(t1).unwrap(); - - let req = Request::new(TENANT, vec![RequestData::ProcList {}]); - let res = Response::new(TENANT, req.id, vec![ResponseData::Ok]); - - let mut mailbox = session.mail(req).await.unwrap(); - - // Get first response - match tokio::join!(mailbox.next(), t2.send(res.clone())) { - (Some(actual), _) => assert_eq!(actual, res), - x => panic!("Unexpected response: {:?}", x), - } - - // Get second response - match tokio::join!(mailbox.next(), t2.send(res.clone())) { - (Some(actual), _) => assert_eq!(actual, res), - x => panic!("Unexpected response: {:?}", x), - } - - // Trigger the mailbox to wait BEFORE closing our transport to ensure that - // we don't get stuck if the mailbox was already waiting - let next_task = tokio::spawn(async move { mailbox.next().await }); - tokio::task::yield_now().await; - - drop(t2); - match next_task.await { - Ok(None) => {} - x => panic!("Unexpected response: {:?}", x), - } - } - - #[tokio::test] - async fn send_should_wait_until_response_received() { - let (t1, mut t2) = Transport::make_pair(); - let mut session = Session::initialize(t1).unwrap(); - - let req = Request::new(TENANT, vec![RequestData::ProcList {}]); - let res = Response::new( - TENANT, - req.id, - vec![ResponseData::ProcEntries { - entries: Vec::new(), - }], - ); - - let (actual, _) = tokio::join!(session.send(req), t2.send(res.clone())); - match actual { - Ok(actual) => assert_eq!(actual, res), - x => panic!("Unexpected response: {:?}", x), - } - } - - #[tokio::test] - async fn send_timeout_should_fail_if_response_not_received_in_time() { - let (t1, mut t2) = Transport::make_pair(); - let mut session = Session::initialize(t1).unwrap(); - - let req = Request::new(TENANT, vec![RequestData::ProcList {}]); - match session.send_timeout(req, Duration::from_millis(30)).await { - 
Err(TransportError::IoError(x)) => assert_eq!(x.kind(), io::ErrorKind::TimedOut), - x => panic!("Unexpected response: {:?}", x), - } - - let req = t2.receive::().await.unwrap().unwrap(); - assert_eq!(req.tenant, TENANT); - } - - #[tokio::test] - async fn fire_should_send_request_and_not_wait_for_response() { - let (t1, mut t2) = Transport::make_pair(); - let mut session = Session::initialize(t1).unwrap(); - - let req = Request::new(TENANT, vec![RequestData::ProcList {}]); - match session.fire(req).await { - Ok(_) => {} - x => panic!("Unexpected response: {:?}", x), - } - - let req = t2.receive::().await.unwrap().unwrap(); - assert_eq!(req.tenant, TENANT); - } -} diff --git a/distant-core/src/client/utils.rs b/distant-core/src/client/utils.rs deleted file mode 100644 index fcff08f..0000000 --- a/distant-core/src/client/utils.rs +++ /dev/null @@ -1,13 +0,0 @@ -use std::{future::Future, time::Duration}; -use tokio::{io, time}; - -// Wraps a future in a tokio timeout call, transforming the error into -// an io error -pub async fn timeout(d: Duration, f: F) -> io::Result -where - F: Future, -{ - time::timeout(d, f) - .await - .map_err(|x| io::Error::new(io::ErrorKind::TimedOut, x)) -} diff --git a/distant-core/src/client/watcher.rs b/distant-core/src/client/watcher.rs index f87fe50..072d6d0 100644 --- a/distant-core/src/client/watcher.rs +++ b/distant-core/src/client/watcher.rs @@ -1,43 +1,20 @@ use crate::{ - client::{SessionChannel, SessionChannelExt, SessionChannelExtError}, + client::{DistantChannel, DistantChannelExt}, constants::CLIENT_WATCHER_CAPACITY, - data::{Change, ChangeKindSet, Error as DistantError, Request, RequestData, ResponseData}, - net::TransportError, + data::{Change, ChangeKindSet, DistantRequestData, DistantResponseData}, + DistantMsg, }; -use derive_more::{Display, Error, From}; +use distant_net::Request; use log::*; use std::{ - fmt, + fmt, io, path::{Path, PathBuf}, }; use tokio::{sync::mpsc, task::JoinHandle}; -#[derive(Debug, Display, Error)] 
-pub enum WatchError { - /// When no confirmation of watch is received - MissingConfirmation, - - /// A server-side error occurred when attempting to watch - ServerError(DistantError), - - /// When the communication over the wire has issues - TransportError(TransportError), - - /// When a queued change is dropped because the response channel closed early - QueuedChangeDropped, - - /// Some unexpected response was received when attempting to watch - #[display(fmt = "Unexpected response: {:?}", _0)] - UnexpectedResponse(#[error(not(source))] ResponseData), -} - -#[derive(Debug, Display, From, Error)] -pub struct UnwatchError(SessionChannelExtError); - /// Represents a watcher of some path on a remote machine pub struct Watcher { - tenant: String, - channel: SessionChannel, + channel: DistantChannel, path: PathBuf, task: JoinHandle<()>, rx: mpsc::Receiver, @@ -46,24 +23,19 @@ pub struct Watcher { impl fmt::Debug for Watcher { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - f.debug_struct("Watcher") - .field("tenant", &self.tenant) - .field("path", &self.path) - .finish() + f.debug_struct("Watcher").field("path", &self.path).finish() } } impl Watcher { /// Creates a watcher for some remote path pub async fn watch( - tenant: impl Into, - mut channel: SessionChannel, + mut channel: DistantChannel, path: impl Into, recursive: bool, only: impl Into, except: impl Into, - ) -> Result { - let tenant = tenant.into(); + ) -> io::Result { let path = path.into(); let only = only.into(); let except = except.into(); @@ -85,17 +57,15 @@ impl Watcher { // Submit our run request and get back a mailbox for responses let mut mailbox = channel - .mail(Request::new( - tenant.as_str(), - vec![RequestData::Watch { + .mail(Request::new(DistantMsg::Single( + DistantRequestData::Watch { path: path.to_path_buf(), recursive, only: only.into_vec(), except: except.into_vec(), - }], - )) - .await - .map_err(WatchError::TransportError)?; + }, + ))) + .await?; let (tx, rx) = 
mpsc::channel(CLIENT_WATCHER_CAPACITY); @@ -103,14 +73,19 @@ impl Watcher { let mut queue: Vec = Vec::new(); let mut confirmed = false; while let Some(res) = mailbox.next().await { - for data in res.payload { + for data in res.payload.into_vec() { match data { - ResponseData::Changed(change) => queue.push(change), - ResponseData::Ok => { + DistantResponseData::Changed(change) => queue.push(change), + DistantResponseData::Ok => { confirmed = true; } - ResponseData::Error(x) => return Err(WatchError::ServerError(x)), - x => return Err(WatchError::UnexpectedResponse(x)), + DistantResponseData::Error(x) => return Err(io::Error::from(x)), + x => { + return Err(io::Error::new( + io::ErrorKind::Other, + format!("Unexpected response: {:?}", x), + )) + } } } @@ -126,14 +101,14 @@ impl Watcher { trace!("Forwarding {} queued changes for {:?}", queue.len(), path); for change in queue { if tx.send(change).await.is_err() { - return Err(WatchError::QueuedChangeDropped); + return Err(io::Error::new(io::ErrorKind::Other, "Queue change dropped")); } } // If we never received an acknowledgement of watch before the mailbox closed, // fail with a missing confirmation error if !confirmed { - return Err(WatchError::MissingConfirmation); + return Err(io::Error::new(io::ErrorKind::Other, "Missing confirmation")); } // Spawn a task that continues to look for change events, discarding anything @@ -142,9 +117,9 @@ impl Watcher { let path = path.clone(); async move { while let Some(res) = mailbox.next().await { - for data in res.payload { + for data in res.payload.into_vec() { match data { - ResponseData::Changed(change) => { + DistantResponseData::Changed(change) => { // If we can't queue up a change anymore, we've // been closed and therefore want to quit if tx.is_closed() { @@ -168,7 +143,6 @@ impl Watcher { }); Ok(Self { - tenant, path, channel, task, @@ -193,42 +167,37 @@ impl Watcher { } /// Unwatches the path being watched, closing out the watcher - pub async fn unwatch(&mut self) -> 
Result<(), UnwatchError> { + pub async fn unwatch(&mut self) -> io::Result<()> { trace!("Unwatching {:?}", self.path); - let result = self - .channel - .unwatch(self.tenant.to_string(), self.path.to_path_buf()) - .await - .map_err(UnwatchError::from); + self.channel.unwatch(self.path.to_path_buf()).await?; - match result { - Ok(_) => { - // Kill our task that processes inbound changes if we - // have successfully unwatched the path - self.task.abort(); - self.active = false; + // Kill our task that processes inbound changes if we have successfully unwatched the path + self.task.abort(); + self.active = false; - Ok(()) - } - Err(x) => Err(x), - } + Ok(()) } } #[cfg(test)] mod tests { use super::*; - use crate::{ - client::Session, - data::{ChangeKind, Response}, - net::{InmemoryStream, PlainCodec, Transport}, + use crate::data::ChangeKind; + use crate::DistantClient; + use distant_net::{ + Client, FramedTransport, InmemoryTransport, IntoSplit, PlainCodec, Response, + TypedAsyncRead, TypedAsyncWrite, }; use std::sync::Arc; use tokio::sync::Mutex; - fn make_session() -> (Transport, Session) { - let (t1, t2) = Transport::make_pair(); - (t1, Session::initialize(t2).unwrap()) + fn make_session() -> ( + FramedTransport, + DistantClient, + ) { + let (t1, t2) = FramedTransport::pair(100); + let (writer, reader) = t2.into_split(); + (t1, Client::new(writer, reader).unwrap()) } #[tokio::test] @@ -240,7 +209,6 @@ mod tests { // in a separate async block let watch_task = tokio::spawn(async move { Watcher::watch( - String::from("test-tenant"), session.clone_channel(), test_path, true, @@ -251,11 +219,11 @@ mod tests { }); // Wait until we get the request from the session - let req = transport.receive::().await.unwrap().unwrap(); + let req: Request = transport.read().await.unwrap().unwrap(); // Send back an acknowledgement that a watcher was created transport - .send(Response::new("test-tenant", req.id, vec![ResponseData::Ok])) + .write(Response::new(req.id, 
DistantResponseData::Ok)) .await .unwrap(); @@ -273,7 +241,6 @@ mod tests { // in a separate async block let watch_task = tokio::spawn(async move { Watcher::watch( - String::from("test-tenant"), session.clone_channel(), test_path, true, @@ -284,11 +251,11 @@ mod tests { }); // Wait until we get the request from the session - let req = transport.receive::().await.unwrap().unwrap(); + let req: Request = transport.read().await.unwrap().unwrap(); // Send back an acknowledgement that a watcher was created transport - .send(Response::new("test-tenant", req.id, vec![ResponseData::Ok])) + .write(Response::new(req.id.clone(), DistantResponseData::Ok)) .await .unwrap(); @@ -297,15 +264,14 @@ mod tests { // Send some changes related to the file transport - .send(Response::new( - "test-tenant", + .write(Response::new( req.id, vec![ - ResponseData::Changed(Change { + DistantResponseData::Changed(Change { kind: ChangeKind::Access, paths: vec![test_path.to_path_buf()], }), - ResponseData::Changed(Change { + DistantResponseData::Changed(Change { kind: ChangeKind::Content, paths: vec![test_path.to_path_buf()], }), @@ -343,7 +309,6 @@ mod tests { // in a separate async block let watch_task = tokio::spawn(async move { Watcher::watch( - String::from("test-tenant"), session.clone_channel(), test_path, true, @@ -354,11 +319,11 @@ mod tests { }); // Wait until we get the request from the session - let req = transport.receive::().await.unwrap().unwrap(); + let req: Request = transport.read().await.unwrap().unwrap(); // Send back an acknowledgement that a watcher was created transport - .send(Response::new("test-tenant", req.id, vec![ResponseData::Ok])) + .write(Response::new(req.id.clone(), DistantResponseData::Ok)) .await .unwrap(); @@ -367,39 +332,36 @@ mod tests { // Send a change from the appropriate origin transport - .send(Response::new( - "test-tenant", - req.id, - vec![ResponseData::Changed(Change { + .write(Response::new( + req.id.clone(), + DistantResponseData::Changed(Change { 
kind: ChangeKind::Access, paths: vec![test_path.to_path_buf()], - })], + }), )) .await .unwrap(); // Send a change from a different origin transport - .send(Response::new( - "test-tenant", - req.id + 1, - vec![ResponseData::Changed(Change { + .write(Response::new( + req.id.clone() + "1", + DistantResponseData::Changed(Change { kind: ChangeKind::Content, paths: vec![test_path.to_path_buf()], - })], + }), )) .await .unwrap(); // Send a change from the appropriate origin transport - .send(Response::new( - "test-tenant", + .write(Response::new( req.id, - vec![ResponseData::Changed(Change { + DistantResponseData::Changed(Change { kind: ChangeKind::Remove, paths: vec![test_path.to_path_buf()], - })], + }), )) .await .unwrap(); @@ -433,7 +395,6 @@ mod tests { // in a separate async block let watch_task = tokio::spawn(async move { Watcher::watch( - String::from("test-tenant"), session.clone_channel(), test_path, true, @@ -444,29 +405,28 @@ mod tests { }); // Wait until we get the request from the session - let req = transport.receive::().await.unwrap().unwrap(); + let req: Request = transport.read().await.unwrap().unwrap(); // Send back an acknowledgement that a watcher was created transport - .send(Response::new("test-tenant", req.id, vec![ResponseData::Ok])) + .write(Response::new(req.id.clone(), DistantResponseData::Ok)) .await .unwrap(); // Send some changes from the appropriate origin transport - .send(Response::new( - "test-tenant", + .write(Response::new( req.id, vec![ - ResponseData::Changed(Change { + DistantResponseData::Changed(Change { kind: ChangeKind::Access, paths: vec![test_path.to_path_buf()], }), - ResponseData::Changed(Change { + DistantResponseData::Changed(Change { kind: ChangeKind::Content, paths: vec![test_path.to_path_buf()], }), - ResponseData::Changed(Change { + DistantResponseData::Changed(Change { kind: ChangeKind::Remove, paths: vec![test_path.to_path_buf()], }), @@ -501,24 +461,23 @@ mod tests { let watcher_2 = Arc::clone(&watcher); let 
unwatch_task = tokio::spawn(async move { watcher_2.lock().await.unwatch().await }); - let req = transport.receive::().await.unwrap().unwrap(); + let req: Request = transport.read().await.unwrap().unwrap(); transport - .send(Response::new("test-tenant", req.id, vec![ResponseData::Ok])) + .write(Response::new(req.id.clone(), DistantResponseData::Ok)) .await .unwrap(); // Wait for the unwatch to complete - let _ = unwatch_task.await.unwrap().unwrap(); + unwatch_task.await.unwrap().unwrap(); transport - .send(Response::new( - "test-tenant", + .write(Response::new( req.id, - vec![ResponseData::Changed(Change { + DistantResponseData::Changed(Change { kind: ChangeKind::Unknown, paths: vec![test_path.to_path_buf()], - })], + }), )) .await .unwrap(); diff --git a/distant-core/src/constants.rs b/distant-core/src/constants.rs index 8aa3683..2a949ec 100644 --- a/distant-core/src/constants.rs +++ b/distant-core/src/constants.rs @@ -1,6 +1,3 @@ -/// Capacity associated with a client mailboxes for receiving multiple responses to a request -pub const CLIENT_MAILBOX_CAPACITY: usize = 10000; - /// Capacity associated stdin, stdout, and stderr pipes receiving data from remote server pub const CLIENT_PIPE_CAPACITY: usize = 10000; @@ -19,13 +16,3 @@ pub const MAX_PIPE_CHUNK_SIZE: usize = 16384; /// Duration in milliseconds to sleep between reading stdout/stderr chunks /// to avoid sending many small messages to clients pub const READ_PAUSE_MILLIS: u64 = 50; - -/// Maximum message capacity per connection for the distant server -pub const MAX_MSG_CAPACITY: usize = 10000; - -/// Test-only constants -#[cfg(test)] -pub mod test { - pub const BUFFER_SIZE: usize = 100; - pub const TENANT: &str = "test-tenant"; -} diff --git a/distant-core/src/credentials.rs b/distant-core/src/credentials.rs new file mode 100644 index 0000000..9599ede --- /dev/null +++ b/distant-core/src/credentials.rs @@ -0,0 +1,133 @@ +use crate::{ + serde_str::{deserialize_from_str, serialize_to_str}, + Destination, +}; 
+use distant_net::SecretKey32; +use serde::{de::Deserializer, ser::Serializer, Deserialize, Serialize}; +use std::{ + convert::{TryFrom, TryInto}, + fmt, io, + str::FromStr, +}; +use uriparse::{URIReference, URI}; + +/// Represents credentials used for a distant server that is maintaining a single key +/// across all connections +#[derive(Clone, Debug, PartialEq, Eq)] +pub struct DistantSingleKeyCredentials { + pub host: String, + pub port: u16, + pub key: SecretKey32, + pub username: Option, +} + +impl fmt::Display for DistantSingleKeyCredentials { + /// Converts credentials into string in the form of `distant://[username]:{key}@{host}:{port}` + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "distant://")?; + if let Some(username) = self.username.as_ref() { + write!(f, "{}", username)?; + } + write!(f, ":{}@{}:{}", self.key, self.host, self.port) + } +} + +impl FromStr for DistantSingleKeyCredentials { + type Err = io::Error; + + /// Parse `distant://[username]:{key}@{host}` as credentials. Note that this requires the + /// `distant` scheme to be included. 
If parsing without scheme is desired, call the + /// [`DistantSingleKeyCredentials::try_from_uri_ref`] method instead with `require_scheme` + /// set to false + fn from_str(s: &str) -> Result { + Self::try_from_uri_ref(s, true) + } +} + +impl Serialize for DistantSingleKeyCredentials { + fn serialize(&self, serializer: S) -> Result + where + S: Serializer, + { + serialize_to_str(self, serializer) + } +} + +impl<'de> Deserialize<'de> for DistantSingleKeyCredentials { + fn deserialize(deserializer: D) -> Result + where + D: Deserializer<'de>, + { + deserialize_from_str(deserializer) + } +} + +impl DistantSingleKeyCredentials { + /// Converts credentials into a [`Destination`] of the form `distant://[username]:{key}@{host}`, + /// failing if the credentials would not produce a valid [`Destination`] + pub fn try_to_destination(&self) -> io::Result { + let uri = self.try_to_uri()?; + Destination::try_from(uri.as_uri_reference().to_borrowed()) + .map_err(|x| io::Error::new(io::ErrorKind::InvalidData, x)) + } + + /// Converts credentials into a [`URI`] of the form `distant://[username]:{key}@{host}`, + /// failing if the credentials would not produce a valid [`URI`] + pub fn try_to_uri(&self) -> io::Result> { + let uri_string = self.to_string(); + URI::try_from(uri_string.as_str()) + .map(URI::into_owned) + .map_err(|x| io::Error::new(io::ErrorKind::InvalidData, x)) + } + + /// Parses credentials from a [`URIReference`], failing if the input was not a valid + /// [`URIReference`] or if required parameters like `host` or `password` are missing or bad + /// format + /// + /// If `require_scheme` is true, will enforce that a scheme is provided. 
Regardless, if a + /// scheme is provided that is not `distant`, this will also fail + pub fn try_from_uri_ref<'a, E>( + uri_ref: impl TryInto, Error = E>, + require_scheme: bool, + ) -> io::Result + where + E: std::error::Error + Send + Sync + 'static, + { + let uri_ref = uri_ref + .try_into() + .map_err(|x| io::Error::new(io::ErrorKind::Other, x))?; + + // Check if the scheme is correct, and if missing if we require it + if let Some(scheme) = uri_ref.scheme() { + if !scheme.as_str().eq_ignore_ascii_case("distant") { + return Err(io::Error::new( + io::ErrorKind::InvalidInput, + "Scheme is not distant", + )); + } + } else if require_scheme { + return Err(io::Error::new( + io::ErrorKind::InvalidInput, + "Missing scheme", + )); + } + + Ok(Self { + host: uri_ref + .host() + .map(ToString::to_string) + .ok_or_else(|| io::Error::new(io::ErrorKind::InvalidInput, "Missing host"))?, + port: uri_ref + .port() + .ok_or_else(|| io::Error::new(io::ErrorKind::InvalidInput, "Missing port"))?, + key: uri_ref + .password() + .ok_or_else(|| io::Error::new(io::ErrorKind::InvalidInput, "Missing password")) + .and_then(|x| { + x.parse() + .map_err(|x| io::Error::new(io::ErrorKind::InvalidInput, x)) + })?, + username: uri_ref.username().map(ToString::to_string), + }) + } +} diff --git a/distant-core/src/data.rs b/distant-core/src/data.rs index e872397..debc7ec 100644 --- a/distant-core/src/data.rs +++ b/distant-core/src/data.rs @@ -1,70 +1,151 @@ -use bitflags::bitflags; -use derive_more::{Deref, DerefMut, Display, Error, IntoIterator, IsVariant}; -use notify::{ - event::Event as NotifyEvent, ErrorKind as NotifyErrorKind, EventKind as NotifyEventKind, -}; -use portable_pty::PtySize as PortablePtySize; +use derive_more::{From, IsVariant}; use serde::{Deserialize, Serialize}; -use std::{ - collections::HashSet, io, iter::FromIterator, num::ParseIntError, ops::BitOr, path::PathBuf, - str::FromStr, -}; -use strum::{AsRefStr, EnumString, EnumVariantNames, VariantNames}; +use std::{io, 
path::PathBuf}; +use strum::AsRefStr; + +#[cfg(feature = "clap")] +use strum::VariantNames; + +mod change; +pub use change::*; + +mod cmd; +pub use cmd::*; + +#[cfg(feature = "clap")] +mod clap_impl; + +mod error; +pub use error::*; + +mod filesystem; +pub use filesystem::*; + +mod map; +pub use map::Map; + +mod metadata; +pub use metadata::*; + +mod pty; +pub use pty::*; + +mod system; +pub use system::*; + +mod utils; +pub(crate) use utils::*; + +/// Id for a remote process +pub type ProcessId = u32; + +/// Mapping of environment variables +pub type Environment = Map; /// Type alias for a vec of bytes /// /// NOTE: This only exists to support properly parsing a Vec from an entire string -/// with structopt rather than trying to parse a string as a singular u8 +/// with clap rather than trying to parse a string as a singular u8 pub type ByteVec = Vec; -#[cfg(feature = "structopt")] +#[cfg(feature = "clap")] fn parse_byte_vec(src: &str) -> ByteVec { src.as_bytes().to_vec() } -/// Represents the request to be performed on the remote machine -#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)] -#[serde(rename_all = "snake_case", deny_unknown_fields)] -pub struct Request { - /// A name tied to the requester (tenant) - pub tenant: String, +/// Represents a wrapper around a distant message, supporting single and batch requests +#[derive(Clone, Debug, From, PartialEq, Eq, Serialize, Deserialize)] +#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))] +#[serde(untagged)] +pub enum DistantMsg { + Single(T), + Batch(Vec), +} - /// A unique id associated with the request - pub id: usize, +impl DistantMsg { + /// Returns true if msg has a single payload + pub fn is_single(&self) -> bool { + matches!(self, Self::Single(_)) + } - /// The main payload containing a collection of data comprising one or more actions - pub payload: Vec, -} + /// Returns reference to single value if msg is single variant + pub fn as_single(&self) -> Option<&T> { + match self { + 
Self::Single(x) => Some(x), + _ => None, + } + } -impl Request { - /// Creates a new request, generating a unique id for it - pub fn new(tenant: impl Into, payload: Vec) -> Self { - let id = rand::random(); - Self { - tenant: tenant.into(), - id, - payload, + /// Returns mutable reference to single value if msg is single variant + pub fn as_mut_single(&mut self) -> Option<&T> { + match self { + Self::Single(x) => Some(x), + _ => None, } } - /// Converts to a string representing the type (or types) contained in the payload - pub fn to_payload_type_string(&self) -> String { - self.payload - .iter() - .map(AsRef::as_ref) - .collect::>() - .join(",") + /// Returns the single value if msg is single variant + pub fn into_single(self) -> Option { + match self { + Self::Single(x) => Some(x), + _ => None, + } + } + + /// Returns true if msg has a batch of payloads + pub fn is_batch(&self) -> bool { + matches!(self, Self::Batch(_)) + } + + /// Returns reference to batch value if msg is batch variant + pub fn as_batch(&self) -> Option<&[T]> { + match self { + Self::Batch(x) => Some(x), + _ => None, + } + } + + /// Returns mutable reference to batch value if msg is batch variant + pub fn as_mut_batch(&mut self) -> Option<&mut [T]> { + match self { + Self::Batch(x) => Some(x), + _ => None, + } + } + + /// Returns the batch value if msg is batch variant + pub fn into_batch(self) -> Option> { + match self { + Self::Batch(x) => Some(x), + _ => None, + } + } + + /// Convert into a collection of payload data + pub fn into_vec(self) -> Vec { + match self { + Self::Single(x) => vec![x], + Self::Batch(x) => x, + } + } +} + +#[cfg(feature = "schemars")] +impl DistantMsg { + pub fn root_schema() -> schemars::schema::RootSchema { + schemars::schema_for!(DistantMsg) } } /// Represents the payload of a request to be performed on the remote machine -#[derive(Clone, Debug, PartialEq, Eq, AsRefStr, IsVariant, Serialize, Deserialize)] -#[cfg_attr(feature = "structopt", 
derive(structopt::StructOpt))] +#[derive(Clone, Debug, PartialEq, Eq, IsVariant, Serialize, Deserialize)] +#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))] +#[cfg_attr(feature = "clap", derive(clap::Subcommand))] #[serde(rename_all = "snake_case", deny_unknown_fields, tag = "type")] -#[strum(serialize_all = "snake_case")] -pub enum RequestData { +#[cfg_attr(feature = "clap", clap(rename_all = "kebab-case"))] +pub enum DistantRequestData { /// Reads a file from the specified path on the remote machine - #[cfg_attr(feature = "structopt", structopt(visible_aliases = &["cat"]))] + #[cfg_attr(feature = "clap", clap(visible_aliases = &["cat"]))] FileRead { /// The path to the file on the remote machine path: PathBuf, @@ -84,7 +165,7 @@ pub enum RequestData { path: PathBuf, /// Data for server-side writing of content - #[cfg_attr(feature = "structopt", structopt(parse(from_str = parse_byte_vec)))] + #[cfg_attr(feature = "clap", clap(parse(from_str = parse_byte_vec)))] data: ByteVec, }, @@ -104,7 +185,7 @@ pub enum RequestData { path: PathBuf, /// Data for server-side writing of content - #[cfg_attr(feature = "structopt", structopt(parse(from_str = parse_byte_vec)))] + #[cfg_attr(feature = "clap", clap(parse(from_str = parse_byte_vec)))] data: ByteVec, }, @@ -118,7 +199,7 @@ pub enum RequestData { }, /// Reads a directory from the specified path on the remote machine - #[cfg_attr(feature = "structopt", structopt(visible_aliases = &["ls"]))] + #[cfg_attr(feature = "clap", clap(visible_aliases = &["ls"]))] DirRead { /// The path to the directory on the remote machine path: PathBuf, @@ -127,12 +208,12 @@ pub enum RequestData { /// depth and 1 indicating the most immediate children within the /// directory #[serde(default = "one")] - #[cfg_attr(feature = "structopt", structopt(short, long, default_value = "1"))] + #[cfg_attr(feature = "clap", clap(long, default_value = "1"))] depth: usize, /// Whether or not to return absolute or relative paths #[serde(default)] 
- #[cfg_attr(feature = "structopt", structopt(short, long))] + #[cfg_attr(feature = "clap", clap(long))] absolute: bool, /// Whether or not to canonicalize the resulting paths, meaning @@ -142,7 +223,7 @@ pub enum RequestData { /// Note that the flag absolute must be true to have absolute paths /// returned, even if canonicalize is flagged as true #[serde(default)] - #[cfg_attr(feature = "structopt", structopt(short, long))] + #[cfg_attr(feature = "clap", clap(long))] canonicalize: bool, /// Whether or not to include the root directory in the retrieved @@ -151,24 +232,24 @@ pub enum RequestData { /// If included, the root directory will also be a canonicalized, /// absolute path and will not follow any of the other flags #[serde(default)] - #[cfg_attr(feature = "structopt", structopt(long))] + #[cfg_attr(feature = "clap", clap(long))] include_root: bool, }, /// Creates a directory on the remote machine - #[cfg_attr(feature = "structopt", structopt(visible_aliases = &["mkdir"]))] + #[cfg_attr(feature = "clap", clap(visible_aliases = &["mkdir"]))] DirCreate { /// The path to the directory on the remote machine path: PathBuf, /// Whether or not to create all parent directories #[serde(default)] - #[cfg_attr(feature = "structopt", structopt(short, long))] + #[cfg_attr(feature = "clap", clap(long))] all: bool, }, /// Removes a file or directory on the remote machine - #[cfg_attr(feature = "structopt", structopt(visible_aliases = &["rm"]))] + #[cfg_attr(feature = "clap", clap(visible_aliases = &["rm"]))] Remove { /// The path to the file or directory on the remote machine path: PathBuf, @@ -176,12 +257,12 @@ pub enum RequestData { /// Whether or not to remove all contents within directory if is a directory. 
/// Does nothing different for files #[serde(default)] - #[cfg_attr(feature = "structopt", structopt(short, long))] + #[cfg_attr(feature = "clap", clap(long))] force: bool, }, /// Copies a file or directory on the remote machine - #[cfg_attr(feature = "structopt", structopt(visible_aliases = &["cp"]))] + #[cfg_attr(feature = "clap", clap(visible_aliases = &["cp"]))] Copy { /// The path to the file or directory on the remote machine src: PathBuf, @@ -191,7 +272,7 @@ pub enum RequestData { }, /// Moves/renames a file or directory on the remote machine - #[cfg_attr(feature = "structopt", structopt(visible_aliases = &["mv"]))] + #[cfg_attr(feature = "clap", clap(visible_aliases = &["mv"]))] Rename { /// The path to the file or directory on the remote machine src: PathBuf, @@ -208,22 +289,22 @@ pub enum RequestData { /// If true, will recursively watch for changes within directories, othewise /// will only watch for changes immediately within directories #[serde(default)] - #[cfg_attr(feature = "structopt", structopt(short, long))] + #[cfg_attr(feature = "clap", clap(long))] recursive: bool, /// Filter to only report back specified changes #[serde(default)] #[cfg_attr( - feature = "structopt", - structopt(short, long, possible_values = &ChangeKind::VARIANTS) + feature = "clap", + clap(long, possible_values = ChangeKind::VARIANTS) )] only: Vec, /// Filter to report back changes except these specified changes #[serde(default)] #[cfg_attr( - feature = "structopt", - structopt(short, long, possible_values = &ChangeKind::VARIANTS) + feature = "clap", + clap(long, possible_values = ChangeKind::VARIANTS) )] except: Vec, }, @@ -249,114 +330,88 @@ pub enum RequestData { /// returning the canonical, absolute form of a path with all /// intermediate components normalized and symbolic links resolved #[serde(default)] - #[cfg_attr(feature = "structopt", structopt(short, long))] + #[cfg_attr(feature = "clap", clap(long))] canonicalize: bool, /// Whether or not to follow symlinks to 
determine absolute file type (dir/file) #[serde(default)] - #[cfg_attr(feature = "structopt", structopt(long))] + #[cfg_attr(feature = "clap", clap(long))] resolve_file_type: bool, }, /// Spawns a new process on the remote machine - #[cfg_attr(feature = "structopt", structopt(visible_aliases = &["spawn", "run"]))] + #[cfg_attr(feature = "clap", clap(visible_aliases = &["spawn", "run"]))] ProcSpawn { - /// Name of the command to run - cmd: String, + /// The full command to run including arguments + #[cfg_attr(feature = "clap", clap(flatten))] + cmd: Cmd, + + /// Environment to provide to the remote process + #[serde(default)] + #[cfg_attr(feature = "clap", clap(long, default_value_t = Environment::default()))] + environment: Environment, - /// Arguments for the command + /// Alternative current directory for the remote process #[serde(default)] - args: Vec, + #[cfg_attr(feature = "clap", clap(long))] + current_dir: Option, /// Whether or not the process should be persistent, meaning that the process will not be /// killed when the associated client disconnects #[serde(default)] - #[cfg_attr(feature = "structopt", structopt(long))] + #[cfg_attr(feature = "clap", clap(long))] persist: bool, /// If provided, will spawn process in a pty, otherwise spawns directly #[serde(default)] - #[cfg_attr(feature = "structopt", structopt(long))] + #[cfg_attr(feature = "clap", clap(long))] pty: Option, }, /// Kills a process running on the remote machine - #[cfg_attr(feature = "structopt", structopt(visible_aliases = &["kill"]))] + #[cfg_attr(feature = "clap", clap(visible_aliases = &["kill"]))] ProcKill { /// Id of the actively-running process - id: usize, + id: ProcessId, }, /// Sends additional data to stdin of running process ProcStdin { /// Id of the actively-running process to send stdin data - id: usize, + id: ProcessId, /// Data to send to a process's stdin pipe + #[serde(with = "serde_bytes")] + #[cfg_attr(feature = "schemars", schemars(with = "Vec"))] data: Vec, }, /// 
Resize pty of remote process ProcResizePty { /// Id of the actively-running process whose pty to resize - id: usize, + id: ProcessId, /// The new pty dimensions size: PtySize, }, - /// Retrieve a list of all processes being managed by the remote server - ProcList {}, - /// Retrieve information about the server and the system it is on SystemInfo {}, } -/// Represents an response to a request performed on the remote machine -#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)] -#[serde(rename_all = "snake_case", deny_unknown_fields)] -pub struct Response { - /// A name tied to the requester (tenant) - pub tenant: String, - - /// A unique id associated with the response - pub id: usize, - - /// The id of the originating request that yielded this response - /// (more than one response may have same origin) - pub origin_id: usize, - - /// The main payload containing a collection of data comprising one or more results - pub payload: Vec, -} - -impl Response { - /// Creates a new response, generating a unique id for it - pub fn new(tenant: impl Into, origin_id: usize, payload: Vec) -> Self { - let id = rand::random(); - Self { - tenant: tenant.into(), - id, - origin_id, - payload, - } - } - - /// Converts to a string representing the type (or types) contained in the payload - pub fn to_payload_type_string(&self) -> String { - self.payload - .iter() - .map(AsRef::as_ref) - .collect::>() - .join(",") +#[cfg(feature = "schemars")] +impl DistantRequestData { + pub fn root_schema() -> schemars::schema::RootSchema { + schemars::schema_for!(DistantRequestData) } } /// Represents the payload of a successful response #[derive(Clone, Debug, PartialEq, Eq, AsRefStr, IsVariant, Serialize, Deserialize)] +#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))] #[serde(rename_all = "snake_case", deny_unknown_fields, tag = "type")] #[strum(serialize_all = "snake_case")] -pub enum ResponseData { +pub enum DistantResponseData { /// General okay with no extra data, 
returned in cases like /// creating or removing a directory, copying a file, or renaming /// a file @@ -368,6 +423,8 @@ pub enum ResponseData { /// Response containing some arbitrary, binary data Blob { /// Binary data associated with the response + #[serde(with = "serde_bytes")] + #[cfg_attr(feature = "schemars", schemars(with = "Vec"))] data: Vec, }, @@ -398,31 +455,35 @@ pub enum ResponseData { /// Response to starting a new process ProcSpawned { /// Arbitrary id associated with running process - id: usize, + id: ProcessId, }, /// Actively-transmitted stdout as part of running process ProcStdout { /// Arbitrary id associated with running process - id: usize, + id: ProcessId, /// Data read from a process' stdout pipe + #[serde(with = "serde_bytes")] + #[cfg_attr(feature = "schemars", schemars(with = "Vec"))] data: Vec, }, /// Actively-transmitted stderr as part of running process ProcStderr { /// Arbitrary id associated with running process - id: usize, + id: ProcessId, /// Data read from a process' stderr pipe + #[serde(with = "serde_bytes")] + #[cfg_attr(feature = "schemars", schemars(with = "Vec"))] data: Vec, }, /// Response to a process finishing ProcDone { /// Arbitrary id associated with running process - id: usize, + id: ProcessId, /// Whether or not termination was successful success: bool, @@ -431,1217 +492,23 @@ pub enum ResponseData { code: Option, }, - /// Response to retrieving a list of managed processes - ProcEntries { - /// List of managed processes - entries: Vec, - }, - /// Response to retrieving information about the server and the system it is on SystemInfo(SystemInfo), } -/// Represents the size associated with a remote PTY -#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)] -pub struct PtySize { - /// Number of lines of text - pub rows: u16, - - /// Number of columns of text - pub cols: u16, - - /// Width of a cell in pixels. Note that some systems never fill this value and ignore it. 
- #[serde(default)] - pub pixel_width: u16, - - /// Height of a cell in pixels. Note that some systems never fill this value and ignore it. - #[serde(default)] - pub pixel_height: u16, -} - -impl PtySize { - /// Creates new size using just rows and columns - pub fn from_rows_and_cols(rows: u16, cols: u16) -> Self { - Self { - rows, - cols, - ..Default::default() - } - } -} - -impl From for PtySize { - fn from(size: PortablePtySize) -> Self { - Self { - rows: size.rows, - cols: size.cols, - pixel_width: size.pixel_width, - pixel_height: size.pixel_height, - } - } -} - -impl From for PortablePtySize { - fn from(size: PtySize) -> Self { - Self { - rows: size.rows, - cols: size.cols, - pixel_width: size.pixel_width, - pixel_height: size.pixel_height, - } - } -} - -impl Default for PtySize { - fn default() -> Self { - PtySize { - rows: 24, - cols: 80, - pixel_width: 0, - pixel_height: 0, - } - } -} - -#[derive(Clone, Debug, PartialEq, Eq, Display, Error)] -pub enum PtySizeParseError { - MissingRows, - MissingColumns, - InvalidRows(ParseIntError), - InvalidColumns(ParseIntError), - InvalidPixelWidth(ParseIntError), - InvalidPixelHeight(ParseIntError), -} - -impl FromStr for PtySize { - type Err = PtySizeParseError; - - /// Attempts to parse a str into PtySize using one of the following formats: - /// - /// * rows,cols (defaults to 0 for pixel_width & pixel_height) - /// * rows,cols,pixel_width,pixel_height - fn from_str(s: &str) -> Result { - let mut tokens = s.split(','); - - Ok(Self { - rows: tokens - .next() - .ok_or(PtySizeParseError::MissingRows)? - .trim() - .parse() - .map_err(PtySizeParseError::InvalidRows)?, - cols: tokens - .next() - .ok_or(PtySizeParseError::MissingColumns)? - .trim() - .parse() - .map_err(PtySizeParseError::InvalidColumns)?, - pixel_width: tokens - .next() - .map(|s| s.trim().parse()) - .transpose() - .map_err(PtySizeParseError::InvalidPixelWidth)? 
- .unwrap_or(0), - pixel_height: tokens - .next() - .map(|s| s.trim().parse()) - .transpose() - .map_err(PtySizeParseError::InvalidPixelHeight)? - .unwrap_or(0), - }) - } -} - -/// Represents metadata about some path on a remote machine -#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)] -pub struct Metadata { - /// Canonicalized path to the file or directory, resolving symlinks, only included - /// if flagged during the request - pub canonicalized_path: Option, - - /// Represents the type of the entry as a file/dir/symlink - pub file_type: FileType, - - /// Size of the file/directory/symlink in bytes - pub len: u64, - - /// Whether or not the file/directory/symlink is marked as unwriteable - pub readonly: bool, - - /// Represents the last time (in milliseconds) when the file/directory/symlink was accessed; - /// can be optional as certain systems don't support this - #[serde(serialize_with = "serialize_u128_option")] - #[serde(deserialize_with = "deserialize_u128_option")] - pub accessed: Option, - - /// Represents when (in milliseconds) the file/directory/symlink was created; - /// can be optional as certain systems don't support this - #[serde(serialize_with = "serialize_u128_option")] - #[serde(deserialize_with = "deserialize_u128_option")] - pub created: Option, - - /// Represents the last time (in milliseconds) when the file/directory/symlink was modified; - /// can be optional as certain systems don't support this - #[serde(serialize_with = "serialize_u128_option")] - #[serde(deserialize_with = "deserialize_u128_option")] - pub modified: Option, - - /// Represents metadata that is specific to a unix remote machine - pub unix: Option, - - /// Represents metadata that is specific to a windows remote machine - pub windows: Option, -} - -/// Represents unix-specific metadata about some path on a remote machine -#[derive(Copy, Clone, Debug, PartialEq, Eq, Serialize, Deserialize)] -pub struct UnixMetadata { - /// Represents whether or not owner can 
read from the file - pub owner_read: bool, - - /// Represents whether or not owner can write to the file - pub owner_write: bool, - - /// Represents whether or not owner can execute the file - pub owner_exec: bool, - - /// Represents whether or not associated group can read from the file - pub group_read: bool, - - /// Represents whether or not associated group can write to the file - pub group_write: bool, - - /// Represents whether or not associated group can execute the file - pub group_exec: bool, - - /// Represents whether or not other can read from the file - pub other_read: bool, - - /// Represents whether or not other can write to the file - pub other_write: bool, - - /// Represents whether or not other can execute the file - pub other_exec: bool, -} - -impl From for UnixMetadata { - /// Create from a unix mode bitset - fn from(mode: u32) -> Self { - let flags = UnixFilePermissionFlags::from_bits_truncate(mode); - Self { - owner_read: flags.contains(UnixFilePermissionFlags::OWNER_READ), - owner_write: flags.contains(UnixFilePermissionFlags::OWNER_WRITE), - owner_exec: flags.contains(UnixFilePermissionFlags::OWNER_EXEC), - group_read: flags.contains(UnixFilePermissionFlags::GROUP_READ), - group_write: flags.contains(UnixFilePermissionFlags::GROUP_WRITE), - group_exec: flags.contains(UnixFilePermissionFlags::GROUP_EXEC), - other_read: flags.contains(UnixFilePermissionFlags::OTHER_READ), - other_write: flags.contains(UnixFilePermissionFlags::OTHER_WRITE), - other_exec: flags.contains(UnixFilePermissionFlags::OTHER_EXEC), - } - } -} - -impl From for u32 { - /// Convert to a unix mode bitset - fn from(metadata: UnixMetadata) -> Self { - let mut flags = UnixFilePermissionFlags::empty(); - - if metadata.owner_read { - flags.insert(UnixFilePermissionFlags::OWNER_READ); - } - if metadata.owner_write { - flags.insert(UnixFilePermissionFlags::OWNER_WRITE); - } - if metadata.owner_exec { - flags.insert(UnixFilePermissionFlags::OWNER_EXEC); - } - - if 
metadata.group_read { - flags.insert(UnixFilePermissionFlags::GROUP_READ); - } - if metadata.group_write { - flags.insert(UnixFilePermissionFlags::GROUP_WRITE); - } - if metadata.group_exec { - flags.insert(UnixFilePermissionFlags::GROUP_EXEC); - } - - if metadata.other_read { - flags.insert(UnixFilePermissionFlags::OTHER_READ); - } - if metadata.other_write { - flags.insert(UnixFilePermissionFlags::OTHER_WRITE); - } - if metadata.other_exec { - flags.insert(UnixFilePermissionFlags::OTHER_EXEC); - } - - flags.bits - } -} - -impl UnixMetadata { - pub fn is_readonly(self) -> bool { - !(self.owner_read || self.group_read || self.other_read) - } -} - -bitflags! { - struct UnixFilePermissionFlags: u32 { - const OWNER_READ = 0o400; - const OWNER_WRITE = 0o200; - const OWNER_EXEC = 0o100; - const GROUP_READ = 0o40; - const GROUP_WRITE = 0o20; - const GROUP_EXEC = 0o10; - const OTHER_READ = 0o4; - const OTHER_WRITE = 0o2; - const OTHER_EXEC = 0o1; +#[cfg(feature = "schemars")] +impl DistantResponseData { + pub fn root_schema() -> schemars::schema::RootSchema { + schemars::schema_for!(DistantResponseData) } } -/// Represents windows-specific metadata about some path on a remote machine -#[derive(Copy, Clone, Debug, PartialEq, Eq, Serialize, Deserialize)] -pub struct WindowsMetadata { - /// Represents whether or not a file or directory is an archive - pub archive: bool, - - /// Represents whether or not a file or directory is compressed - pub compressed: bool, - - /// Represents whether or not the file or directory is encrypted - pub encrypted: bool, - - /// Represents whether or not a file or directory is hidden - pub hidden: bool, - - /// Represents whether or not a directory or user data stream is configured with integrity - pub integrity_stream: bool, - - /// Represents whether or not a file does not have other attributes set - pub normal: bool, - - /// Represents whether or not a file or directory is not to be indexed by content indexing - /// service - pub 
not_content_indexed: bool, - - /// Represents whether or not a user data stream is not to be read by the background data - /// integrity scanner - pub no_scrub_data: bool, - - /// Represents whether or not the data of a file is not available immediately - pub offline: bool, - - /// Represents whether or not a file or directory is not fully present locally - pub recall_on_data_access: bool, - - /// Represents whether or not a file or directory has no physical representation on the local - /// system (is virtual) - pub recall_on_open: bool, - - /// Represents whether or not a file or directory has an associated reparse point, or a file is - /// a symbolic link - pub reparse_point: bool, - - /// Represents whether or not a file is a sparse file - pub sparse_file: bool, - - /// Represents whether or not a file or directory is used partially or exclusively by the - /// operating system - pub system: bool, - - /// Represents whether or not a file is being used for temporary storage - pub temporary: bool, -} - -impl From for WindowsMetadata { - /// Create from a windows file attribute bitset - fn from(file_attributes: u32) -> Self { - let flags = WindowsFileAttributeFlags::from_bits_truncate(file_attributes); - Self { - archive: flags.contains(WindowsFileAttributeFlags::ARCHIVE), - compressed: flags.contains(WindowsFileAttributeFlags::COMPRESSED), - encrypted: flags.contains(WindowsFileAttributeFlags::ENCRYPTED), - hidden: flags.contains(WindowsFileAttributeFlags::HIDDEN), - integrity_stream: flags.contains(WindowsFileAttributeFlags::INTEGRITY_SYSTEM), - normal: flags.contains(WindowsFileAttributeFlags::NORMAL), - not_content_indexed: flags.contains(WindowsFileAttributeFlags::NOT_CONTENT_INDEXED), - no_scrub_data: flags.contains(WindowsFileAttributeFlags::NO_SCRUB_DATA), - offline: flags.contains(WindowsFileAttributeFlags::OFFLINE), - recall_on_data_access: flags.contains(WindowsFileAttributeFlags::RECALL_ON_DATA_ACCESS), - recall_on_open: 
flags.contains(WindowsFileAttributeFlags::RECALL_ON_OPEN), - reparse_point: flags.contains(WindowsFileAttributeFlags::REPARSE_POINT), - sparse_file: flags.contains(WindowsFileAttributeFlags::SPARSE_FILE), - system: flags.contains(WindowsFileAttributeFlags::SYSTEM), - temporary: flags.contains(WindowsFileAttributeFlags::TEMPORARY), - } - } -} - -impl From for u32 { - /// Convert to a windows file attribute bitset - fn from(metadata: WindowsMetadata) -> Self { - let mut flags = WindowsFileAttributeFlags::empty(); - - if metadata.archive { - flags.insert(WindowsFileAttributeFlags::ARCHIVE); - } - if metadata.compressed { - flags.insert(WindowsFileAttributeFlags::COMPRESSED); - } - if metadata.encrypted { - flags.insert(WindowsFileAttributeFlags::ENCRYPTED); - } - if metadata.hidden { - flags.insert(WindowsFileAttributeFlags::HIDDEN); - } - if metadata.integrity_stream { - flags.insert(WindowsFileAttributeFlags::INTEGRITY_SYSTEM); - } - if metadata.normal { - flags.insert(WindowsFileAttributeFlags::NORMAL); - } - if metadata.not_content_indexed { - flags.insert(WindowsFileAttributeFlags::NOT_CONTENT_INDEXED); - } - if metadata.no_scrub_data { - flags.insert(WindowsFileAttributeFlags::NO_SCRUB_DATA); - } - if metadata.offline { - flags.insert(WindowsFileAttributeFlags::OFFLINE); - } - if metadata.recall_on_data_access { - flags.insert(WindowsFileAttributeFlags::RECALL_ON_DATA_ACCESS); - } - if metadata.recall_on_open { - flags.insert(WindowsFileAttributeFlags::RECALL_ON_OPEN); - } - if metadata.reparse_point { - flags.insert(WindowsFileAttributeFlags::REPARSE_POINT); - } - if metadata.sparse_file { - flags.insert(WindowsFileAttributeFlags::SPARSE_FILE); - } - if metadata.system { - flags.insert(WindowsFileAttributeFlags::SYSTEM); - } - if metadata.temporary { - flags.insert(WindowsFileAttributeFlags::TEMPORARY); - } - - flags.bits - } -} - -bitflags! 
{ - struct WindowsFileAttributeFlags: u32 { - const ARCHIVE = 0x20; - const COMPRESSED = 0x800; - const ENCRYPTED = 0x4000; - const HIDDEN = 0x2; - const INTEGRITY_SYSTEM = 0x8000; - const NORMAL = 0x80; - const NOT_CONTENT_INDEXED = 0x2000; - const NO_SCRUB_DATA = 0x20000; - const OFFLINE = 0x1000; - const RECALL_ON_DATA_ACCESS = 0x400000; - const RECALL_ON_OPEN = 0x40000; - const REPARSE_POINT = 0x400; - const SPARSE_FILE = 0x200; - const SYSTEM = 0x4; - const TEMPORARY = 0x100; - const VIRTUAL = 0x10000; - } -} - -pub(crate) fn deserialize_u128_option<'de, D>(deserializer: D) -> Result, D::Error> -where - D: serde::Deserializer<'de>, -{ - match Option::::deserialize(deserializer)? { - Some(s) => match s.parse::() { - Ok(value) => Ok(Some(value)), - Err(error) => Err(serde::de::Error::custom(format!( - "Cannot convert to u128 with error: {:?}", - error - ))), - }, - None => Ok(None), - } -} - -pub(crate) fn serialize_u128_option( - val: &Option, - s: S, -) -> Result { - match val { - Some(v) => format!("{}", *v).serialize(s), - None => s.serialize_unit(), - } -} - -/// Represents information about a system -#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)] -pub struct SystemInfo { - /// Family of the operating system as described in - /// https://doc.rust-lang.org/std/env/consts/constant.FAMILY.html - pub family: String, - - /// Name of the specific operating system as described in - /// https://doc.rust-lang.org/std/env/consts/constant.OS.html - pub os: String, - - /// Architecture of the CPI as described in - /// https://doc.rust-lang.org/std/env/consts/constant.ARCH.html - pub arch: String, - - /// Current working directory of the running server process - pub current_dir: PathBuf, - - /// Primary separator for path components for the current platform - /// as defined in https://doc.rust-lang.org/std/path/constant.MAIN_SEPARATOR.html - pub main_separator: char, -} - -/// Represents information about a single entry within a directory 
-#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)] -#[serde(rename_all = "snake_case", deny_unknown_fields)] -pub struct DirEntry { - /// Represents the full path to the entry - pub path: PathBuf, - - /// Represents the type of the entry as a file/dir/symlink - pub file_type: FileType, - - /// Depth at which this entry was created relative to the root (0 being immediately within - /// root) - pub depth: usize, -} - -/// Represents the type associated with a dir entry -#[derive(Copy, Clone, Debug, PartialEq, Eq, AsRefStr, IsVariant, Serialize, Deserialize)] -#[serde(rename_all = "snake_case", deny_unknown_fields)] -#[strum(serialize_all = "snake_case")] -pub enum FileType { - Dir, - File, - Symlink, -} - -/// Represents information about a running process -#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)] -#[serde(rename_all = "snake_case", deny_unknown_fields)] -pub struct RunningProcess { - /// Name of the command being run - pub cmd: String, - - /// Arguments for the command - pub args: Vec, - - /// Whether or not the process was run in persist mode - pub persist: bool, - - /// Pty associated with running process if it has one - pub pty: Option, - - /// Arbitrary id associated with running process - /// - /// Not the same as the process' pid! 
- pub id: usize, -} - -impl From for ResponseData { +impl From for DistantResponseData { fn from(x: io::Error) -> Self { Self::Error(Error::from(x)) } } -impl From for ResponseData { - fn from(x: walkdir::Error) -> Self { - Self::Error(Error::from(x)) - } -} - -impl From for ResponseData { - fn from(x: notify::Error) -> Self { - Self::Error(Error::from(x)) - } -} - -impl From for ResponseData { - fn from(x: tokio::task::JoinError) -> Self { - Self::Error(Error::from(x)) - } -} - -/// Change to one or more paths on the filesystem -#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)] -#[serde(rename_all = "snake_case", deny_unknown_fields)] -pub struct Change { - /// Label describing the kind of change - pub kind: ChangeKind, - - /// Paths that were changed - pub paths: Vec, -} - -impl From for Change { - fn from(x: NotifyEvent) -> Self { - Self { - kind: x.kind.into(), - paths: x.paths, - } - } -} - -#[derive( - Copy, - Clone, - Debug, - strum::Display, - EnumString, - EnumVariantNames, - Hash, - PartialEq, - Eq, - PartialOrd, - Ord, - Serialize, - Deserialize, -)] -#[serde(rename_all = "snake_case", deny_unknown_fields)] -#[strum(serialize_all = "snake_case")] -pub enum ChangeKind { - /// Something about a file or directory was accessed, but - /// no specific details were known - Access, - - /// A file was closed for executing - AccessCloseExecute, - - /// A file was closed for reading - AccessCloseRead, - - /// A file was closed for writing - AccessCloseWrite, - - /// A file was opened for executing - AccessOpenExecute, - - /// A file was opened for reading - AccessOpenRead, - - /// A file was opened for writing - AccessOpenWrite, - - /// A file or directory was read - AccessRead, - - /// The access time of a file or directory was changed - AccessTime, - - /// A file, directory, or something else was created - Create, - - /// The content of a file or directory changed - Content, - - /// The data of a file or directory was modified, but - /// no specific 
details were known - Data, - - /// The metadata of a file or directory was modified, but - /// no specific details were known - Metadata, - - /// Something about a file or directory was modified, but - /// no specific details were known - Modify, - - /// A file, directory, or something else was removed - Remove, - - /// A file or directory was renamed, but no specific details were known - Rename, - - /// A file or directory was renamed, and the provided paths - /// are the source and target in that order (from, to) - RenameBoth, - - /// A file or directory was renamed, and the provided path - /// is the origin of the rename (before being renamed) - RenameFrom, - - /// A file or directory was renamed, and the provided path - /// is the result of the rename - RenameTo, - - /// A file's size changed - Size, - - /// The ownership of a file or directory was changed - Ownership, - - /// The permissions of a file or directory was changed - Permissions, - - /// The write or modify time of a file or directory was changed - WriteTime, - - // Catchall in case we have no insight as to the type of change - Unknown, -} - -impl ChangeKind { - /// Returns true if the change is a kind of access - pub fn is_access_kind(&self) -> bool { - self.is_open_access_kind() - || self.is_close_access_kind() - || matches!(self, Self::Access | Self::AccessRead) - } - - /// Returns true if the change is a kind of open access - pub fn is_open_access_kind(&self) -> bool { - matches!( - self, - Self::AccessOpenExecute | Self::AccessOpenRead | Self::AccessOpenWrite - ) - } - - /// Returns true if the change is a kind of close access - pub fn is_close_access_kind(&self) -> bool { - matches!( - self, - Self::AccessCloseExecute | Self::AccessCloseRead | Self::AccessCloseWrite - ) - } - - /// Returns true if the change is a kind of creation - pub fn is_create_kind(&self) -> bool { - matches!(self, Self::Create) - } - - /// Returns true if the change is a kind of modification - pub fn 
is_modify_kind(&self) -> bool { - self.is_data_modify_kind() || self.is_metadata_modify_kind() || matches!(self, Self::Modify) - } - - /// Returns true if the change is a kind of data modification - pub fn is_data_modify_kind(&self) -> bool { - matches!(self, Self::Content | Self::Data | Self::Size) - } - - /// Returns true if the change is a kind of metadata modification - pub fn is_metadata_modify_kind(&self) -> bool { - matches!( - self, - Self::AccessTime - | Self::Metadata - | Self::Ownership - | Self::Permissions - | Self::WriteTime - ) - } - - /// Returns true if the change is a kind of rename - pub fn is_rename_kind(&self) -> bool { - matches!( - self, - Self::Rename | Self::RenameBoth | Self::RenameFrom | Self::RenameTo - ) - } - - /// Returns true if the change is a kind of removal - pub fn is_remove_kind(&self) -> bool { - matches!(self, Self::Remove) - } - - /// Returns true if the change kind is unknown - pub fn is_unknown_kind(&self) -> bool { - matches!(self, Self::Unknown) - } -} - -impl BitOr for ChangeKind { - type Output = ChangeKindSet; - - fn bitor(self, rhs: Self) -> Self::Output { - let mut set = ChangeKindSet::empty(); - set.insert(self); - set.insert(rhs); - set - } -} - -impl From for ChangeKind { - fn from(x: NotifyEventKind) -> Self { - use notify::event::{ - AccessKind, AccessMode, DataChange, MetadataKind, ModifyKind, RenameMode, - }; - match x { - // File/directory access events - NotifyEventKind::Access(AccessKind::Read) => Self::AccessRead, - NotifyEventKind::Access(AccessKind::Open(AccessMode::Execute)) => { - Self::AccessOpenExecute - } - NotifyEventKind::Access(AccessKind::Open(AccessMode::Read)) => Self::AccessOpenRead, - NotifyEventKind::Access(AccessKind::Open(AccessMode::Write)) => Self::AccessOpenWrite, - NotifyEventKind::Access(AccessKind::Close(AccessMode::Execute)) => { - Self::AccessCloseExecute - } - NotifyEventKind::Access(AccessKind::Close(AccessMode::Read)) => Self::AccessCloseRead, - 
NotifyEventKind::Access(AccessKind::Close(AccessMode::Write)) => Self::AccessCloseWrite, - NotifyEventKind::Access(_) => Self::Access, - - // File/directory creation events - NotifyEventKind::Create(_) => Self::Create, - - // Rename-oriented events - NotifyEventKind::Modify(ModifyKind::Name(RenameMode::Both)) => Self::RenameBoth, - NotifyEventKind::Modify(ModifyKind::Name(RenameMode::From)) => Self::RenameFrom, - NotifyEventKind::Modify(ModifyKind::Name(RenameMode::To)) => Self::RenameTo, - NotifyEventKind::Modify(ModifyKind::Name(_)) => Self::Rename, - - // Data-modification events - NotifyEventKind::Modify(ModifyKind::Data(DataChange::Content)) => Self::Content, - NotifyEventKind::Modify(ModifyKind::Data(DataChange::Size)) => Self::Size, - NotifyEventKind::Modify(ModifyKind::Data(_)) => Self::Data, - - // Metadata-modification events - NotifyEventKind::Modify(ModifyKind::Metadata(MetadataKind::AccessTime)) => { - Self::AccessTime - } - NotifyEventKind::Modify(ModifyKind::Metadata(MetadataKind::WriteTime)) => { - Self::WriteTime - } - NotifyEventKind::Modify(ModifyKind::Metadata(MetadataKind::Permissions)) => { - Self::Permissions - } - NotifyEventKind::Modify(ModifyKind::Metadata(MetadataKind::Ownership)) => { - Self::Ownership - } - NotifyEventKind::Modify(ModifyKind::Metadata(_)) => Self::Metadata, - - // General modification events - NotifyEventKind::Modify(_) => Self::Modify, - - // File/directory removal events - NotifyEventKind::Remove(_) => Self::Remove, - - // Catch-all for other events - NotifyEventKind::Any | NotifyEventKind::Other => Self::Unknown, - } - } -} - -/// Represents a distinct set of different change kinds -#[derive( - Clone, Debug, Deref, DerefMut, Display, IntoIterator, PartialEq, Eq, Serialize, Deserialize, -)] -#[display( - fmt = "{}", - "_0.iter().map(ToString::to_string).collect::>().join(\",\")" -)] -pub struct ChangeKindSet(HashSet); - -impl ChangeKindSet { - /// Produces an empty set of [`ChangeKind`] - pub fn empty() -> Self { - 
Self(HashSet::new()) - } - - /// Produces a set of all [`ChangeKind`] - pub fn all() -> Self { - vec![ - ChangeKind::Access, - ChangeKind::AccessCloseExecute, - ChangeKind::AccessCloseRead, - ChangeKind::AccessCloseWrite, - ChangeKind::AccessOpenExecute, - ChangeKind::AccessOpenRead, - ChangeKind::AccessOpenWrite, - ChangeKind::AccessRead, - ChangeKind::AccessTime, - ChangeKind::Create, - ChangeKind::Content, - ChangeKind::Data, - ChangeKind::Metadata, - ChangeKind::Modify, - ChangeKind::Remove, - ChangeKind::Rename, - ChangeKind::RenameBoth, - ChangeKind::RenameFrom, - ChangeKind::RenameTo, - ChangeKind::Size, - ChangeKind::Ownership, - ChangeKind::Permissions, - ChangeKind::WriteTime, - ChangeKind::Unknown, - ] - .into_iter() - .collect() - } - - /// Produces a changeset containing all of the access kinds - pub fn access_set() -> Self { - Self::access_open_set() - | Self::access_close_set() - | ChangeKind::AccessRead - | ChangeKind::Access - } - - /// Produces a changeset containing all of the open access kinds - pub fn access_open_set() -> Self { - ChangeKind::AccessOpenExecute | ChangeKind::AccessOpenRead | ChangeKind::AccessOpenWrite - } - - /// Produces a changeset containing all of the close access kinds - pub fn access_close_set() -> Self { - ChangeKind::AccessCloseExecute | ChangeKind::AccessCloseRead | ChangeKind::AccessCloseWrite - } - - // Produces a changeset containing all of the modification kinds - pub fn modify_set() -> Self { - Self::modify_data_set() | Self::modify_metadata_set() | ChangeKind::Modify - } - - /// Produces a changeset containing all of the data modification kinds - pub fn modify_data_set() -> Self { - ChangeKind::Content | ChangeKind::Data | ChangeKind::Size - } - - /// Produces a changeset containing all of the metadata modification kinds - pub fn modify_metadata_set() -> Self { - ChangeKind::AccessTime - | ChangeKind::Metadata - | ChangeKind::Ownership - | ChangeKind::Permissions - | ChangeKind::WriteTime - } - - /// Produces a 
changeset containing all of the rename kinds - pub fn rename_set() -> Self { - ChangeKind::Rename | ChangeKind::RenameBoth | ChangeKind::RenameFrom | ChangeKind::RenameTo - } - - /// Consumes set and returns a vec of the kinds of changes - pub fn into_vec(self) -> Vec { - self.0.into_iter().collect() - } -} - -impl BitOr for ChangeKindSet { - type Output = Self; - - fn bitor(mut self, rhs: ChangeKindSet) -> Self::Output { - self.extend(rhs.0); - self - } -} - -impl BitOr for ChangeKindSet { - type Output = Self; - - fn bitor(mut self, rhs: ChangeKind) -> Self::Output { - self.0.insert(rhs); - self - } -} - -impl BitOr for ChangeKind { - type Output = ChangeKindSet; - - fn bitor(self, rhs: ChangeKindSet) -> Self::Output { - rhs | self - } -} - -impl FromStr for ChangeKindSet { - type Err = strum::ParseError; - - fn from_str(s: &str) -> Result { - let mut change_set = HashSet::new(); - - for word in s.split(',') { - change_set.insert(ChangeKind::from_str(word.trim())?); - } - - Ok(ChangeKindSet(change_set)) - } -} - -impl FromIterator for ChangeKindSet { - fn from_iter>(iter: I) -> Self { - let mut change_set = HashSet::new(); - - for i in iter { - change_set.insert(i); - } - - ChangeKindSet(change_set) - } -} - -impl From for ChangeKindSet { - fn from(change_kind: ChangeKind) -> Self { - let mut set = Self::empty(); - set.insert(change_kind); - set - } -} - -impl From> for ChangeKindSet { - fn from(changes: Vec) -> Self { - changes.into_iter().collect() - } -} - -impl Default for ChangeKindSet { - fn default() -> Self { - Self::empty() - } -} - -/// General purpose error type that can be sent across the wire -#[derive(Clone, Debug, Display, PartialEq, Eq, Serialize, Deserialize)] -#[display(fmt = "{}: {}", kind, description)] -#[serde(rename_all = "snake_case", deny_unknown_fields)] -pub struct Error { - /// Label describing the kind of error - pub kind: ErrorKind, - - /// Description of the error itself - pub description: String, -} - -impl std::error::Error for 
Error {} - -impl<'a> From<&'a str> for Error { - fn from(x: &'a str) -> Self { - Self::from(x.to_string()) - } -} - -impl From for Error { - fn from(x: String) -> Self { - Self { - kind: ErrorKind::Other, - description: x, - } - } -} - -impl From for Error { - fn from(x: io::Error) -> Self { - Self { - kind: ErrorKind::from(x.kind()), - description: format!("{}", x), - } - } -} - -impl From for io::Error { - fn from(x: Error) -> Self { - Self::new(x.kind.into(), x.description) - } -} - -impl From for Error { - fn from(x: notify::Error) -> Self { - let err = match x.kind { - NotifyErrorKind::Generic(x) => Self { - kind: ErrorKind::Other, - description: x, - }, - NotifyErrorKind::Io(x) => Self::from(x), - NotifyErrorKind::PathNotFound => Self { - kind: ErrorKind::Other, - description: String::from("Path not found"), - }, - NotifyErrorKind::WatchNotFound => Self { - kind: ErrorKind::Other, - description: String::from("Watch not found"), - }, - NotifyErrorKind::InvalidConfig(_) => Self { - kind: ErrorKind::Other, - description: String::from("Invalid config"), - }, - NotifyErrorKind::MaxFilesWatch => Self { - kind: ErrorKind::Other, - description: String::from("Max files watched"), - }, - }; - - Self { - kind: err.kind, - description: format!( - "{}\n\nPaths: {}", - err.description, - x.paths - .into_iter() - .map(|p| p.to_string_lossy().to_string()) - .collect::>() - .join(", ") - ), - } - } -} - -impl From for Error { - fn from(x: walkdir::Error) -> Self { - if x.io_error().is_some() { - x.into_io_error().map(Self::from).unwrap() - } else { - Self { - kind: ErrorKind::Loop, - description: format!("{}", x), - } - } - } -} - -impl From for Error { - fn from(x: tokio::task::JoinError) -> Self { - Self { - kind: if x.is_cancelled() { - ErrorKind::TaskCancelled - } else { - ErrorKind::TaskPanicked - }, - description: format!("{}", x), - } - } -} - -/// All possible kinds of errors that can be returned -#[derive(Copy, Clone, Debug, Display, PartialEq, Eq, Serialize, 
Deserialize)] -#[serde(rename_all = "snake_case", deny_unknown_fields)] -pub enum ErrorKind { - /// An entity was not found, often a file - NotFound, - - /// The operation lacked the necessary privileges to complete - PermissionDenied, - - /// The connection was refused by the remote server - ConnectionRefused, - - /// The connection was reset by the remote server - ConnectionReset, - - /// The connection was aborted (terminated) by the remote server - ConnectionAborted, - - /// The network operation failed because it was not connected yet - NotConnected, - - /// A socket address could not be bound because the address is already in use elsewhere - AddrInUse, - - /// A nonexistent interface was requested or the requested address was not local - AddrNotAvailable, - - /// The operation failed because a pipe was closed - BrokenPipe, - - /// An entity already exists, often a file - AlreadyExists, - - /// The operation needs to block to complete, but the blocking operation was requested to not - /// occur - WouldBlock, - - /// A parameter was incorrect - InvalidInput, - - /// Data not valid for the operation were encountered - InvalidData, - - /// The I/O operation's timeout expired, causing it to be cancelled - TimedOut, - - /// An error returned when an operation could not be completed because a - /// call to `write` returned `Ok(0)` - WriteZero, - - /// This operation was interrupted - Interrupted, - - /// Any I/O error not part of this list - Other, - - /// An error returned when an operation could not be completed because an "end of file" was - /// reached prematurely - UnexpectedEof, - - /// When a loop is encountered when walking a directory - Loop, - - /// When a task is cancelled - TaskCancelled, - - /// When a task panics - TaskPanicked, - - /// Catchall for an error that has no specific type - Unknown, -} - -impl From for ErrorKind { - fn from(kind: io::ErrorKind) -> Self { - match kind { - io::ErrorKind::NotFound => Self::NotFound, - 
io::ErrorKind::PermissionDenied => Self::PermissionDenied, - io::ErrorKind::ConnectionRefused => Self::ConnectionRefused, - io::ErrorKind::ConnectionReset => Self::ConnectionReset, - io::ErrorKind::ConnectionAborted => Self::ConnectionAborted, - io::ErrorKind::NotConnected => Self::NotConnected, - io::ErrorKind::AddrInUse => Self::AddrInUse, - io::ErrorKind::AddrNotAvailable => Self::AddrNotAvailable, - io::ErrorKind::BrokenPipe => Self::BrokenPipe, - io::ErrorKind::AlreadyExists => Self::AlreadyExists, - io::ErrorKind::WouldBlock => Self::WouldBlock, - io::ErrorKind::InvalidInput => Self::InvalidInput, - io::ErrorKind::InvalidData => Self::InvalidData, - io::ErrorKind::TimedOut => Self::TimedOut, - io::ErrorKind::WriteZero => Self::WriteZero, - io::ErrorKind::Interrupted => Self::Interrupted, - io::ErrorKind::Other => Self::Other, - io::ErrorKind::UnexpectedEof => Self::UnexpectedEof, - - // This exists because io::ErrorKind is non_exhaustive - _ => Self::Unknown, - } - } -} - -impl From for io::ErrorKind { - fn from(kind: ErrorKind) -> Self { - match kind { - ErrorKind::NotFound => Self::NotFound, - ErrorKind::PermissionDenied => Self::PermissionDenied, - ErrorKind::ConnectionRefused => Self::ConnectionRefused, - ErrorKind::ConnectionReset => Self::ConnectionReset, - ErrorKind::ConnectionAborted => Self::ConnectionAborted, - ErrorKind::NotConnected => Self::NotConnected, - ErrorKind::AddrInUse => Self::AddrInUse, - ErrorKind::AddrNotAvailable => Self::AddrNotAvailable, - ErrorKind::BrokenPipe => Self::BrokenPipe, - ErrorKind::AlreadyExists => Self::AlreadyExists, - ErrorKind::WouldBlock => Self::WouldBlock, - ErrorKind::InvalidInput => Self::InvalidInput, - ErrorKind::InvalidData => Self::InvalidData, - ErrorKind::TimedOut => Self::TimedOut, - ErrorKind::WriteZero => Self::WriteZero, - ErrorKind::Interrupted => Self::Interrupted, - ErrorKind::Other => Self::Other, - ErrorKind::UnexpectedEof => Self::UnexpectedEof, - _ => Self::Other, - } - } -} - /// Used to 
provide a default serde value of 1 const fn one() -> usize { 1 diff --git a/distant-core/src/data/change.rs b/distant-core/src/data/change.rs new file mode 100644 index 0000000..593f5ae --- /dev/null +++ b/distant-core/src/data/change.rs @@ -0,0 +1,506 @@ +use derive_more::{Deref, DerefMut, IntoIterator}; +use notify::{event::Event as NotifyEvent, EventKind as NotifyEventKind}; +use serde::{Deserialize, Serialize}; +use std::{ + collections::HashSet, + fmt, + hash::{Hash, Hasher}, + iter::FromIterator, + ops::{BitOr, Sub}, + path::PathBuf, + str::FromStr, +}; +use strum::{EnumString, EnumVariantNames}; + +/// Change to one or more paths on the filesystem +#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)] +#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))] +#[serde(rename_all = "snake_case", deny_unknown_fields)] +pub struct Change { + /// Label describing the kind of change + pub kind: ChangeKind, + + /// Paths that were changed + pub paths: Vec, +} + +#[cfg(feature = "schemars")] +impl Change { + pub fn root_schema() -> schemars::schema::RootSchema { + schemars::schema_for!(Change) + } +} + +impl From for Change { + fn from(x: NotifyEvent) -> Self { + Self { + kind: x.kind.into(), + paths: x.paths, + } + } +} + +#[derive( + Copy, + Clone, + Debug, + strum::Display, + EnumString, + EnumVariantNames, + Hash, + PartialEq, + Eq, + PartialOrd, + Ord, + Serialize, + Deserialize, +)] +#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))] +#[serde(rename_all = "snake_case", deny_unknown_fields)] +#[strum(serialize_all = "snake_case")] +#[cfg_attr(feature = "clap", derive(clap::ValueEnum))] +#[cfg_attr(feature = "clap", clap(rename_all = "snake_case"))] +pub enum ChangeKind { + /// Something about a file or directory was accessed, but + /// no specific details were known + Access, + + /// A file was closed for executing + AccessCloseExecute, + + /// A file was closed for reading + AccessCloseRead, + + /// A file was closed for writing + 
AccessCloseWrite, + + /// A file was opened for executing + AccessOpenExecute, + + /// A file was opened for reading + AccessOpenRead, + + /// A file was opened for writing + AccessOpenWrite, + + /// A file or directory was read + AccessRead, + + /// The access time of a file or directory was changed + AccessTime, + + /// A file, directory, or something else was created + Create, + + /// The content of a file or directory changed + Content, + + /// The data of a file or directory was modified, but + /// no specific details were known + Data, + + /// The metadata of a file or directory was modified, but + /// no specific details were known + Metadata, + + /// Something about a file or directory was modified, but + /// no specific details were known + Modify, + + /// A file, directory, or something else was removed + Remove, + + /// A file or directory was renamed, but no specific details were known + Rename, + + /// A file or directory was renamed, and the provided paths + /// are the source and target in that order (from, to) + RenameBoth, + + /// A file or directory was renamed, and the provided path + /// is the origin of the rename (before being renamed) + RenameFrom, + + /// A file or directory was renamed, and the provided path + /// is the result of the rename + RenameTo, + + /// A file's size changed + Size, + + /// The ownership of a file or directory was changed + Ownership, + + /// The permissions of a file or directory was changed + Permissions, + + /// The write or modify time of a file or directory was changed + WriteTime, + + // Catchall in case we have no insight as to the type of change + Unknown, +} + +impl ChangeKind { + /// Returns true if the change is a kind of access + pub fn is_access_kind(&self) -> bool { + self.is_open_access_kind() + || self.is_close_access_kind() + || matches!(self, Self::Access | Self::AccessRead) + } + + /// Returns true if the change is a kind of open access + pub fn is_open_access_kind(&self) -> bool { + matches!( + 
self, + Self::AccessOpenExecute | Self::AccessOpenRead | Self::AccessOpenWrite + ) + } + + /// Returns true if the change is a kind of close access + pub fn is_close_access_kind(&self) -> bool { + matches!( + self, + Self::AccessCloseExecute | Self::AccessCloseRead | Self::AccessCloseWrite + ) + } + + /// Returns true if the change is a kind of creation + pub fn is_create_kind(&self) -> bool { + matches!(self, Self::Create) + } + + /// Returns true if the change is a kind of modification + pub fn is_modify_kind(&self) -> bool { + self.is_data_modify_kind() || self.is_metadata_modify_kind() || matches!(self, Self::Modify) + } + + /// Returns true if the change is a kind of data modification + pub fn is_data_modify_kind(&self) -> bool { + matches!(self, Self::Content | Self::Data | Self::Size) + } + + /// Returns true if the change is a kind of metadata modification + pub fn is_metadata_modify_kind(&self) -> bool { + matches!( + self, + Self::AccessTime + | Self::Metadata + | Self::Ownership + | Self::Permissions + | Self::WriteTime + ) + } + + /// Returns true if the change is a kind of rename + pub fn is_rename_kind(&self) -> bool { + matches!( + self, + Self::Rename | Self::RenameBoth | Self::RenameFrom | Self::RenameTo + ) + } + + /// Returns true if the change is a kind of removal + pub fn is_remove_kind(&self) -> bool { + matches!(self, Self::Remove) + } + + /// Returns true if the change kind is unknown + pub fn is_unknown_kind(&self) -> bool { + matches!(self, Self::Unknown) + } +} + +#[cfg(feature = "schemars")] +impl ChangeKind { + pub fn root_schema() -> schemars::schema::RootSchema { + schemars::schema_for!(ChangeKind) + } +} + +impl BitOr for ChangeKind { + type Output = ChangeKindSet; + + fn bitor(self, rhs: Self) -> Self::Output { + let mut set = ChangeKindSet::empty(); + set.insert(self); + set.insert(rhs); + set + } +} + +impl From for ChangeKind { + fn from(x: NotifyEventKind) -> Self { + use notify::event::{ + AccessKind, AccessMode, DataChange, 
MetadataKind, ModifyKind, RenameMode, + }; + match x { + // File/directory access events + NotifyEventKind::Access(AccessKind::Read) => Self::AccessRead, + NotifyEventKind::Access(AccessKind::Open(AccessMode::Execute)) => { + Self::AccessOpenExecute + } + NotifyEventKind::Access(AccessKind::Open(AccessMode::Read)) => Self::AccessOpenRead, + NotifyEventKind::Access(AccessKind::Open(AccessMode::Write)) => Self::AccessOpenWrite, + NotifyEventKind::Access(AccessKind::Close(AccessMode::Execute)) => { + Self::AccessCloseExecute + } + NotifyEventKind::Access(AccessKind::Close(AccessMode::Read)) => Self::AccessCloseRead, + NotifyEventKind::Access(AccessKind::Close(AccessMode::Write)) => Self::AccessCloseWrite, + NotifyEventKind::Access(_) => Self::Access, + + // File/directory creation events + NotifyEventKind::Create(_) => Self::Create, + + // Rename-oriented events + NotifyEventKind::Modify(ModifyKind::Name(RenameMode::Both)) => Self::RenameBoth, + NotifyEventKind::Modify(ModifyKind::Name(RenameMode::From)) => Self::RenameFrom, + NotifyEventKind::Modify(ModifyKind::Name(RenameMode::To)) => Self::RenameTo, + NotifyEventKind::Modify(ModifyKind::Name(_)) => Self::Rename, + + // Data-modification events + NotifyEventKind::Modify(ModifyKind::Data(DataChange::Content)) => Self::Content, + NotifyEventKind::Modify(ModifyKind::Data(DataChange::Size)) => Self::Size, + NotifyEventKind::Modify(ModifyKind::Data(_)) => Self::Data, + + // Metadata-modification events + NotifyEventKind::Modify(ModifyKind::Metadata(MetadataKind::AccessTime)) => { + Self::AccessTime + } + NotifyEventKind::Modify(ModifyKind::Metadata(MetadataKind::WriteTime)) => { + Self::WriteTime + } + NotifyEventKind::Modify(ModifyKind::Metadata(MetadataKind::Permissions)) => { + Self::Permissions + } + NotifyEventKind::Modify(ModifyKind::Metadata(MetadataKind::Ownership)) => { + Self::Ownership + } + NotifyEventKind::Modify(ModifyKind::Metadata(_)) => Self::Metadata, + + // General modification events + 
NotifyEventKind::Modify(_) => Self::Modify, + + // File/directory removal events + NotifyEventKind::Remove(_) => Self::Remove, + + // Catch-all for other events + NotifyEventKind::Any | NotifyEventKind::Other => Self::Unknown, + } + } +} + +/// Represents a distinct set of different change kinds +#[derive(Clone, Debug, Deref, DerefMut, IntoIterator, Serialize, Deserialize)] +#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))] +pub struct ChangeKindSet(HashSet); + +impl ChangeKindSet { + /// Produces an empty set of [`ChangeKind`] + pub fn empty() -> Self { + Self(HashSet::new()) + } + + /// Produces a set of all [`ChangeKind`] + pub fn all() -> Self { + vec![ + ChangeKind::Access, + ChangeKind::AccessCloseExecute, + ChangeKind::AccessCloseRead, + ChangeKind::AccessCloseWrite, + ChangeKind::AccessOpenExecute, + ChangeKind::AccessOpenRead, + ChangeKind::AccessOpenWrite, + ChangeKind::AccessRead, + ChangeKind::AccessTime, + ChangeKind::Create, + ChangeKind::Content, + ChangeKind::Data, + ChangeKind::Metadata, + ChangeKind::Modify, + ChangeKind::Remove, + ChangeKind::Rename, + ChangeKind::RenameBoth, + ChangeKind::RenameFrom, + ChangeKind::RenameTo, + ChangeKind::Size, + ChangeKind::Ownership, + ChangeKind::Permissions, + ChangeKind::WriteTime, + ChangeKind::Unknown, + ] + .into_iter() + .collect() + } + + /// Produces a changeset containing all of the access kinds + pub fn access_set() -> Self { + Self::access_open_set() + | Self::access_close_set() + | ChangeKind::AccessRead + | ChangeKind::Access + } + + /// Produces a changeset containing all of the open access kinds + pub fn access_open_set() -> Self { + ChangeKind::AccessOpenExecute | ChangeKind::AccessOpenRead | ChangeKind::AccessOpenWrite + } + + /// Produces a changeset containing all of the close access kinds + pub fn access_close_set() -> Self { + ChangeKind::AccessCloseExecute | ChangeKind::AccessCloseRead | ChangeKind::AccessCloseWrite + } + + // Produces a changeset containing all of the 
modification kinds + pub fn modify_set() -> Self { + Self::modify_data_set() | Self::modify_metadata_set() | ChangeKind::Modify + } + + /// Produces a changeset containing all of the data modification kinds + pub fn modify_data_set() -> Self { + ChangeKind::Content | ChangeKind::Data | ChangeKind::Size + } + + /// Produces a changeset containing all of the metadata modification kinds + pub fn modify_metadata_set() -> Self { + ChangeKind::AccessTime + | ChangeKind::Metadata + | ChangeKind::Ownership + | ChangeKind::Permissions + | ChangeKind::WriteTime + } + + /// Produces a changeset containing all of the rename kinds + pub fn rename_set() -> Self { + ChangeKind::Rename | ChangeKind::RenameBoth | ChangeKind::RenameFrom | ChangeKind::RenameTo + } + + /// Consumes set and returns a vec of the kinds of changes + pub fn into_vec(self) -> Vec { + self.0.into_iter().collect() + } +} + +#[cfg(feature = "schemars")] +impl ChangeKindSet { + pub fn root_schema() -> schemars::schema::RootSchema { + schemars::schema_for!(ChangeKindSet) + } +} + +impl fmt::Display for ChangeKindSet { + /// Outputs a comma-separated series of [`ChangeKind`] as string that are sorted + /// such that this will always be consistent output + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + let mut kinds = self + .0 + .iter() + .map(ToString::to_string) + .collect::>(); + kinds.sort_unstable(); + write!(f, "{}", kinds.join(",")) + } +} + +impl PartialEq for ChangeKindSet { + fn eq(&self, other: &Self) -> bool { + self.to_string() == other.to_string() + } +} + +impl Eq for ChangeKindSet {} + +impl Hash for ChangeKindSet { + /// Hashes based on the output of [`fmt::Display`] + fn hash(&self, state: &mut H) { + self.to_string().hash(state); + } +} + +impl BitOr for ChangeKindSet { + type Output = Self; + + fn bitor(mut self, rhs: ChangeKindSet) -> Self::Output { + self.extend(rhs.0); + self + } +} + +impl BitOr for ChangeKindSet { + type Output = Self; + + fn bitor(mut self, rhs: ChangeKind) -> 
Self::Output { + self.0.insert(rhs); + self + } +} + +impl BitOr for ChangeKind { + type Output = ChangeKindSet; + + fn bitor(self, rhs: ChangeKindSet) -> Self::Output { + rhs | self + } +} + +impl Sub for ChangeKindSet { + type Output = Self; + + fn sub(self, other: Self) -> Self::Output { + ChangeKindSet(&self.0 - &other.0) + } +} + +impl Sub<&'_ ChangeKindSet> for &ChangeKindSet { + type Output = ChangeKindSet; + + fn sub(self, other: &ChangeKindSet) -> Self::Output { + ChangeKindSet(&self.0 - &other.0) + } +} + +impl FromStr for ChangeKindSet { + type Err = strum::ParseError; + + fn from_str(s: &str) -> Result { + let mut change_set = HashSet::new(); + + for word in s.split(',') { + change_set.insert(ChangeKind::from_str(word.trim())?); + } + + Ok(ChangeKindSet(change_set)) + } +} + +impl FromIterator for ChangeKindSet { + fn from_iter>(iter: I) -> Self { + let mut change_set = HashSet::new(); + + for i in iter { + change_set.insert(i); + } + + ChangeKindSet(change_set) + } +} + +impl From for ChangeKindSet { + fn from(change_kind: ChangeKind) -> Self { + let mut set = Self::empty(); + set.insert(change_kind); + set + } +} + +impl From> for ChangeKindSet { + fn from(changes: Vec) -> Self { + changes.into_iter().collect() + } +} + +impl Default for ChangeKindSet { + fn default() -> Self { + Self::empty() + } +} diff --git a/distant-core/src/data/clap_impl.rs b/distant-core/src/data/clap_impl.rs new file mode 100644 index 0000000..8985215 --- /dev/null +++ b/distant-core/src/data/clap_impl.rs @@ -0,0 +1,106 @@ +use crate::{data::Cmd, DistantMsg, DistantRequestData}; +use clap::{ + error::{Error, ErrorKind}, + Arg, ArgAction, ArgMatches, Args, Command, FromArgMatches, Subcommand, +}; + +impl FromArgMatches for Cmd { + fn from_arg_matches(matches: &ArgMatches) -> Result { + let mut matches = matches.clone(); + Self::from_arg_matches_mut(&mut matches) + } + fn from_arg_matches_mut(matches: &mut ArgMatches) -> Result { + let cmd = 
matches.get_one::("cmd").ok_or_else(|| { + Error::raw( + ErrorKind::MissingRequiredArgument, + "program must be specified", + ) + })?; + let args: Vec = matches + .get_many::("arg") + .unwrap_or_default() + .map(ToString::to_string) + .collect(); + Ok(Self::new(format!("{cmd} {}", args.join(" ")))) + } + fn update_from_arg_matches(&mut self, matches: &ArgMatches) -> Result<(), Error> { + let mut matches = matches.clone(); + self.update_from_arg_matches_mut(&mut matches) + } + fn update_from_arg_matches_mut(&mut self, _matches: &mut ArgMatches) -> Result<(), Error> { + Ok(()) + } +} + +impl Args for Cmd { + fn augment_args(cmd: Command<'_>) -> Command<'_> { + cmd.arg( + Arg::new("cmd") + .required(true) + .value_name("CMD") + .action(ArgAction::Set), + ) + .trailing_var_arg(true) + .arg( + Arg::new("arg") + .value_name("ARGS") + .multiple_values(true) + .action(ArgAction::Append), + ) + } + fn augment_args_for_update(cmd: Command<'_>) -> Command<'_> { + cmd + } +} + +impl FromArgMatches for DistantMsg { + fn from_arg_matches(matches: &ArgMatches) -> Result { + match matches.subcommand() { + Some(("single", args)) => Ok(Self::Single(DistantRequestData::from_arg_matches(args)?)), + Some((_, _)) => Err(Error::raw( + ErrorKind::UnrecognizedSubcommand, + "Valid subcommand is `single`", + )), + None => Err(Error::raw( + ErrorKind::MissingSubcommand, + "Valid subcommand is `single`", + )), + } + } + + fn update_from_arg_matches(&mut self, matches: &ArgMatches) -> Result<(), Error> { + match matches.subcommand() { + Some(("single", args)) => { + *self = Self::Single(DistantRequestData::from_arg_matches(args)?) 
+ } + Some((_, _)) => { + return Err(Error::raw( + ErrorKind::UnrecognizedSubcommand, + "Valid subcommand is `single`", + )) + } + None => (), + }; + Ok(()) + } +} + +impl Subcommand for DistantMsg { + fn augment_subcommands(cmd: Command<'_>) -> Command<'_> { + cmd.subcommand(DistantRequestData::augment_subcommands(Command::new( + "single", + ))) + .subcommand_required(true) + } + + fn augment_subcommands_for_update(cmd: Command<'_>) -> Command<'_> { + cmd.subcommand(DistantRequestData::augment_subcommands(Command::new( + "single", + ))) + .subcommand_required(true) + } + + fn has_subcommand(name: &str) -> bool { + matches!(name, "single") + } +} diff --git a/distant-core/src/data/cmd.rs b/distant-core/src/data/cmd.rs new file mode 100644 index 0000000..11c7942 --- /dev/null +++ b/distant-core/src/data/cmd.rs @@ -0,0 +1,52 @@ +use derive_more::{Display, From, Into}; +use serde::{Deserialize, Serialize}; +use std::ops::{Deref, DerefMut}; + +/// Represents some command with arguments to execute +#[derive(Clone, Debug, Display, From, Into, Hash, PartialEq, Eq, Serialize, Deserialize)] +#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))] +pub struct Cmd(String); + +impl Cmd { + /// Creates a new command from the given `cmd` + pub fn new(cmd: impl Into) -> Self { + Self(cmd.into()) + } + + /// Returns reference to the program portion of the command + pub fn program(&self) -> &str { + match self.0.split_once(' ') { + Some((program, _)) => program.trim(), + None => self.0.trim(), + } + } + + /// Returns reference to the arguments portion of the command + pub fn arguments(&self) -> &str { + match self.0.split_once(' ') { + Some((_, arguments)) => arguments.trim(), + None => "", + } + } +} + +#[cfg(feature = "schemars")] +impl Cmd { + pub fn root_schema() -> schemars::schema::RootSchema { + schemars::schema_for!(Cmd) + } +} + +impl Deref for Cmd { + type Target = String; + + fn deref(&self) -> &Self::Target { + &self.0 + } +} + +impl DerefMut for Cmd { + fn 
deref_mut(&mut self) -> &mut Self::Target { + &mut self.0 + } +} diff --git a/distant-core/src/data/error.rs b/distant-core/src/data/error.rs new file mode 100644 index 0000000..406d02a --- /dev/null +++ b/distant-core/src/data/error.rs @@ -0,0 +1,269 @@ +use derive_more::Display; +use notify::ErrorKind as NotifyErrorKind; +use serde::{Deserialize, Serialize}; +use std::io; + +/// General purpose error type that can be sent across the wire +#[derive(Clone, Debug, Display, PartialEq, Eq, Serialize, Deserialize)] +#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))] +#[display(fmt = "{}: {}", kind, description)] +#[serde(rename_all = "snake_case", deny_unknown_fields)] +pub struct Error { + /// Label describing the kind of error + pub kind: ErrorKind, + + /// Description of the error itself + pub description: String, +} + +impl std::error::Error for Error {} + +#[cfg(feature = "schemars")] +impl Error { + pub fn root_schema() -> schemars::schema::RootSchema { + schemars::schema_for!(Error) + } +} + +impl<'a> From<&'a str> for Error { + fn from(x: &'a str) -> Self { + Self::from(x.to_string()) + } +} + +impl From for Error { + fn from(x: String) -> Self { + Self { + kind: ErrorKind::Other, + description: x, + } + } +} + +impl From for Error { + fn from(x: io::Error) -> Self { + Self { + kind: ErrorKind::from(x.kind()), + description: x.to_string(), + } + } +} + +impl From for io::Error { + fn from(x: Error) -> Self { + Self::new(x.kind.into(), x.description) + } +} + +impl From for Error { + fn from(x: notify::Error) -> Self { + let err = match x.kind { + NotifyErrorKind::Generic(x) => Self { + kind: ErrorKind::Other, + description: x, + }, + NotifyErrorKind::Io(x) => Self::from(x), + NotifyErrorKind::PathNotFound => Self { + kind: ErrorKind::Other, + description: String::from("Path not found"), + }, + NotifyErrorKind::WatchNotFound => Self { + kind: ErrorKind::Other, + description: String::from("Watch not found"), + }, + NotifyErrorKind::InvalidConfig(_) 
=> Self { + kind: ErrorKind::Other, + description: String::from("Invalid config"), + }, + NotifyErrorKind::MaxFilesWatch => Self { + kind: ErrorKind::Other, + description: String::from("Max files watched"), + }, + }; + + Self { + kind: err.kind, + description: format!( + "{}\n\nPaths: {}", + err.description, + x.paths + .into_iter() + .map(|p| p.to_string_lossy().to_string()) + .collect::>() + .join(", ") + ), + } + } +} + +impl From for Error { + fn from(x: walkdir::Error) -> Self { + if x.io_error().is_some() { + x.into_io_error().map(Self::from).unwrap() + } else { + Self { + kind: ErrorKind::Loop, + description: format!("{}", x), + } + } + } +} + +impl From for Error { + fn from(x: tokio::task::JoinError) -> Self { + Self { + kind: if x.is_cancelled() { + ErrorKind::TaskCancelled + } else { + ErrorKind::TaskPanicked + }, + description: format!("{}", x), + } + } +} + +/// All possible kinds of errors that can be returned +#[derive(Copy, Clone, Debug, Display, PartialEq, Eq, Serialize, Deserialize)] +#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))] +#[serde(rename_all = "snake_case", deny_unknown_fields)] +pub enum ErrorKind { + /// An entity was not found, often a file + NotFound, + + /// The operation lacked the necessary privileges to complete + PermissionDenied, + + /// The connection was refused by the remote server + ConnectionRefused, + + /// The connection was reset by the remote server + ConnectionReset, + + /// The connection was aborted (terminated) by the remote server + ConnectionAborted, + + /// The network operation failed because it was not connected yet + NotConnected, + + /// A socket address could not be bound because the address is already in use elsewhere + AddrInUse, + + /// A nonexistent interface was requested or the requested address was not local + AddrNotAvailable, + + /// The operation failed because a pipe was closed + BrokenPipe, + + /// An entity already exists, often a file + AlreadyExists, + + /// The operation 
needs to block to complete, but the blocking operation was requested to not + /// occur + WouldBlock, + + /// A parameter was incorrect + InvalidInput, + + /// Data not valid for the operation were encountered + InvalidData, + + /// The I/O operation's timeout expired, causing it to be cancelled + TimedOut, + + /// An error returned when an operation could not be completed because a + /// call to `write` returned `Ok(0)` + WriteZero, + + /// This operation was interrupted + Interrupted, + + /// Any I/O error not part of this list + Other, + + /// An error returned when an operation could not be completed because an "end of file" was + /// reached prematurely + UnexpectedEof, + + /// This operation is unsupported on this platform + Unsupported, + + /// An operation could not be completed, because it failed to allocate enough memory + OutOfMemory, + + /// When a loop is encountered when walking a directory + Loop, + + /// When a task is cancelled + TaskCancelled, + + /// When a task panics + TaskPanicked, + + /// Catchall for an error that has no specific type + Unknown, +} + +#[cfg(feature = "schemars")] +impl ErrorKind { + pub fn root_schema() -> schemars::schema::RootSchema { + schemars::schema_for!(ErrorKind) + } +} + +impl From for ErrorKind { + fn from(kind: io::ErrorKind) -> Self { + match kind { + io::ErrorKind::NotFound => Self::NotFound, + io::ErrorKind::PermissionDenied => Self::PermissionDenied, + io::ErrorKind::ConnectionRefused => Self::ConnectionRefused, + io::ErrorKind::ConnectionReset => Self::ConnectionReset, + io::ErrorKind::ConnectionAborted => Self::ConnectionAborted, + io::ErrorKind::NotConnected => Self::NotConnected, + io::ErrorKind::AddrInUse => Self::AddrInUse, + io::ErrorKind::AddrNotAvailable => Self::AddrNotAvailable, + io::ErrorKind::BrokenPipe => Self::BrokenPipe, + io::ErrorKind::AlreadyExists => Self::AlreadyExists, + io::ErrorKind::WouldBlock => Self::WouldBlock, + io::ErrorKind::InvalidInput => Self::InvalidInput, + 
io::ErrorKind::InvalidData => Self::InvalidData, + io::ErrorKind::TimedOut => Self::TimedOut, + io::ErrorKind::WriteZero => Self::WriteZero, + io::ErrorKind::Interrupted => Self::Interrupted, + io::ErrorKind::Other => Self::Other, + io::ErrorKind::OutOfMemory => Self::OutOfMemory, + io::ErrorKind::UnexpectedEof => Self::UnexpectedEof, + io::ErrorKind::Unsupported => Self::Unsupported, + + // This exists because io::ErrorKind is non_exhaustive + _ => Self::Unknown, + } + } +} + +impl From for io::ErrorKind { + fn from(kind: ErrorKind) -> Self { + match kind { + ErrorKind::NotFound => Self::NotFound, + ErrorKind::PermissionDenied => Self::PermissionDenied, + ErrorKind::ConnectionRefused => Self::ConnectionRefused, + ErrorKind::ConnectionReset => Self::ConnectionReset, + ErrorKind::ConnectionAborted => Self::ConnectionAborted, + ErrorKind::NotConnected => Self::NotConnected, + ErrorKind::AddrInUse => Self::AddrInUse, + ErrorKind::AddrNotAvailable => Self::AddrNotAvailable, + ErrorKind::BrokenPipe => Self::BrokenPipe, + ErrorKind::AlreadyExists => Self::AlreadyExists, + ErrorKind::WouldBlock => Self::WouldBlock, + ErrorKind::InvalidInput => Self::InvalidInput, + ErrorKind::InvalidData => Self::InvalidData, + ErrorKind::TimedOut => Self::TimedOut, + ErrorKind::WriteZero => Self::WriteZero, + ErrorKind::Interrupted => Self::Interrupted, + ErrorKind::Other => Self::Other, + ErrorKind::OutOfMemory => Self::OutOfMemory, + ErrorKind::UnexpectedEof => Self::UnexpectedEof, + ErrorKind::Unsupported => Self::Unsupported, + _ => Self::Other, + } + } +} diff --git a/distant-core/src/data/filesystem.rs b/distant-core/src/data/filesystem.rs new file mode 100644 index 0000000..5ca9f53 --- /dev/null +++ b/distant-core/src/data/filesystem.rs @@ -0,0 +1,45 @@ +use derive_more::IsVariant; +use serde::{Deserialize, Serialize}; +use std::path::PathBuf; +use strum::AsRefStr; + +/// Represents information about a single entry within a directory +#[derive(Clone, Debug, PartialEq, Eq, 
Serialize, Deserialize)] +#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))] +#[serde(rename_all = "snake_case", deny_unknown_fields)] +pub struct DirEntry { + /// Represents the full path to the entry + pub path: PathBuf, + + /// Represents the type of the entry as a file/dir/symlink + pub file_type: FileType, + + /// Depth at which this entry was created relative to the root (0 being immediately within + /// root) + pub depth: usize, +} + +#[cfg(feature = "schemars")] +impl DirEntry { + pub fn root_schema() -> schemars::schema::RootSchema { + schemars::schema_for!(DirEntry) + } +} + +/// Represents the type associated with a dir entry +#[derive(Copy, Clone, Debug, PartialEq, Eq, AsRefStr, IsVariant, Serialize, Deserialize)] +#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))] +#[serde(rename_all = "snake_case", deny_unknown_fields)] +#[strum(serialize_all = "snake_case")] +pub enum FileType { + Dir, + File, + Symlink, +} + +#[cfg(feature = "schemars")] +impl FileType { + pub fn root_schema() -> schemars::schema::RootSchema { + schemars::schema_for!(FileType) + } +} diff --git a/distant-core/src/data/map.rs b/distant-core/src/data/map.rs new file mode 100644 index 0000000..41fc9ed --- /dev/null +++ b/distant-core/src/data/map.rs @@ -0,0 +1,244 @@ +use crate::serde_str::{deserialize_from_str, serialize_to_str}; +use derive_more::{From, IntoIterator}; +use serde::{de::Deserializer, ser::Serializer, Deserialize, Serialize}; +use std::{ + collections::HashMap, + fmt, + ops::{Deref, DerefMut}, + str::FromStr, +}; + +/// Contains map information for connections and other use cases +#[derive(Clone, Debug, From, IntoIterator, PartialEq, Eq)] +#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))] +pub struct Map(HashMap); + +impl Map { + pub fn new() -> Self { + Self(HashMap::new()) + } + + pub fn into_map(self) -> HashMap { + self.0 + } +} + +#[cfg(feature = "schemars")] +impl Map { + pub fn root_schema() -> schemars::schema::RootSchema 
{ + schemars::schema_for!(Map) + } +} + +impl Default for Map { + fn default() -> Self { + Self::new() + } +} + +impl Deref for Map { + type Target = HashMap; + + fn deref(&self) -> &Self::Target { + &self.0 + } +} + +impl DerefMut for Map { + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.0 + } +} + +impl fmt::Display for Map { + /// Outputs a `key=value` mapping in the form `key="value",key2="value2"` + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + let len = self.0.len(); + for (i, (key, value)) in self.0.iter().enumerate() { + write!(f, "{}=\"{}\"", key, value)?; + + // Include a comma after each but the last pair + if i + 1 < len { + write!(f, ",")?; + } + } + Ok(()) + } +} + +impl FromStr for Map { + type Err = &'static str; + + /// Parses a series of `key=value` pairs in the form `key="value",key2=value2` where + /// the quotes around the value are optional + fn from_str(s: &str) -> Result { + let mut map = HashMap::new(); + + let mut s = s.trim(); + while !s.is_empty() { + // Find {key}={tail...} where tail is everything after = + let (key, tail) = s.split_once('=').ok_or("Missing = after key")?; + + // Remove whitespace around the key and ensure it starts with a proper character + let key = key.trim(); + + if !key.starts_with(char::is_alphabetic) { + return Err("Key must start with alphabetic character"); + } + + // Remove any map whitespace at the front of the tail + let tail = tail.trim_start(); + + // Determine if we start with a quote " otherwise we will look for the next , + let (value, tail) = match tail.strip_prefix('"') { + // If quoted, we maintain the whitespace within the quotes + Some(tail) => { + // Skip the quote so we can look for the trailing quote + let (value, tail) = + tail.split_once('"').ok_or("Missing closing \" for value")?; + + // Skip comma if we have one + let tail = tail.strip_prefix(',').unwrap_or(tail); + + (value, tail) + } + + // If not quoted, we remove all whitespace around the value + None => 
match tail.split_once(',') { + Some((value, tail)) => (value.trim(), tail), + None => (tail.trim(), ""), + }, + }; + + // Insert our new pair and update the slice to be the tail (removing whitespace) + map.insert(key.to_string(), value.to_string()); + s = tail.trim(); + } + + Ok(Self(map)) + } +} + +impl Serialize for Map { + fn serialize(&self, serializer: S) -> Result + where + S: Serializer, + { + serialize_to_str(self, serializer) + } +} + +impl<'de> Deserialize<'de> for Map { + fn deserialize(deserializer: D) -> Result + where + D: Deserializer<'de>, + { + deserialize_from_str(deserializer) + } +} + +#[macro_export] +macro_rules! map { + ($($key:literal -> $value:literal),*) => {{ + let mut _map = ::std::collections::HashMap::new(); + + $( + _map.insert($key.to_string(), $value.to_string()); + )* + + $crate::Map::from(_map) + }}; +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn should_support_being_parsed_from_str() { + // Empty string (whitespace only) yields an empty map + let map = " ".parse::().unwrap(); + assert_eq!(map, map!()); + + // Simple key=value should succeed + let map = "key=value".parse::().unwrap(); + assert_eq!(map, map!("key" -> "value")); + + // Key can be anything but = + let map = "key.with-characters@=value".parse::().unwrap(); + assert_eq!(map, map!("key.with-characters@" -> "value")); + + // Value can be anything but , + let map = "key=value.has -@#$".parse::().unwrap(); + assert_eq!(map, map!("key" -> "value.has -@#$")); + + // Value can include comma if quoted + let map = r#"key=",,,,""#.parse::().unwrap(); + assert_eq!(map, map!("key" -> ",,,,")); + + // Supports whitespace around key and value + let map = " key = value ".parse::().unwrap(); + assert_eq!(map, map!("key" -> "value")); + + // Supports value capturing whitespace if quoted + let map = r#" key = " value " "#.parse::().unwrap(); + assert_eq!(map, map!("key" -> " value ")); + + // Multiple key=value should succeed + let map = 
"key=value,key2=value2".parse::().unwrap(); + assert_eq!(map, map!("key" -> "value", "key2" -> "value2")); + + // Quoted key=value should succeed + let map = r#"key="value one",key2=value2"#.parse::().unwrap(); + assert_eq!(map, map!("key" -> "value one", "key2" -> "value2")); + + let map = r#"key=value,key2="value two""#.parse::().unwrap(); + assert_eq!(map, map!("key" -> "value", "key2" -> "value two")); + + let map = r#"key="value one",key2="value two""#.parse::().unwrap(); + assert_eq!(map, map!("key" -> "value one", "key2" -> "value two")); + + let map = r#"key="1,2,3",key2="4,5,6""#.parse::().unwrap(); + assert_eq!(map, map!("key" -> "1,2,3", "key2" -> "4,5,6")); + + // Dangling comma is okay + let map = "key=value,".parse::().unwrap(); + assert_eq!(map, map!("key" -> "value")); + let map = r#"key=",value,","#.parse::().unwrap(); + assert_eq!(map, map!("key" -> ",value,")); + + // Demonstrating greedy + let map = "key=value key2=value2".parse::().unwrap(); + assert_eq!(map, map!("key" -> "value key2=value2")); + + // Variety of edge cases that should fail + let _ = ",".parse::().unwrap_err(); + let _ = ",key=value".parse::().unwrap_err(); + let _ = "key=value,key2".parse::().unwrap_err(); + } + + #[test] + fn should_support_being_displayed_as_a_string() { + let map = map!().to_string(); + assert_eq!(map, ""); + + let map = map!("key" -> "value").to_string(); + assert_eq!(map, r#"key="value""#); + + // Order of key=value output is not guaranteed + let map = map!("key" -> "value", "key2" -> "value2").to_string(); + assert!( + map == r#"key="value",key2="value2""# || map == r#"key2="value2",key="value""#, + "{:?}", + map + ); + + // Order of key=value output is not guaranteed + let map = map!("key" -> ",", "key2" -> ",,").to_string(); + assert!( + map == r#"key=",",key2=",,""# || map == r#"key2=",,",key=",""#, + "{:?}", + map + ); + } +} diff --git a/distant-core/src/data/metadata.rs b/distant-core/src/data/metadata.rs new file mode 100644 index 0000000..6a25a9e 
--- /dev/null +++ b/distant-core/src/data/metadata.rs @@ -0,0 +1,404 @@ +use super::{deserialize_u128_option, serialize_u128_option, FileType}; +use bitflags::bitflags; +use serde::{Deserialize, Serialize}; +use std::{ + io, + path::{Path, PathBuf}, + time::SystemTime, +}; + +/// Represents metadata about some path on a remote machine +#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)] +#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))] +pub struct Metadata { + /// Canonicalized path to the file or directory, resolving symlinks, only included + /// if flagged during the request + pub canonicalized_path: Option, + + /// Represents the type of the entry as a file/dir/symlink + pub file_type: FileType, + + /// Size of the file/directory/symlink in bytes + pub len: u64, + + /// Whether or not the file/directory/symlink is marked as unwriteable + pub readonly: bool, + + /// Represents the last time (in milliseconds) when the file/directory/symlink was accessed; + /// can be optional as certain systems don't support this + #[serde(serialize_with = "serialize_u128_option")] + #[serde(deserialize_with = "deserialize_u128_option")] + pub accessed: Option, + + /// Represents when (in milliseconds) the file/directory/symlink was created; + /// can be optional as certain systems don't support this + #[serde(serialize_with = "serialize_u128_option")] + #[serde(deserialize_with = "deserialize_u128_option")] + pub created: Option, + + /// Represents the last time (in milliseconds) when the file/directory/symlink was modified; + /// can be optional as certain systems don't support this + #[serde(serialize_with = "serialize_u128_option")] + #[serde(deserialize_with = "deserialize_u128_option")] + pub modified: Option, + + /// Represents metadata that is specific to a unix remote machine + pub unix: Option, + + /// Represents metadata that is specific to a windows remote machine + pub windows: Option, +} + +impl Metadata { + pub async fn read( + path: impl 
AsRef, + canonicalize: bool, + resolve_file_type: bool, + ) -> io::Result { + let metadata = tokio::fs::symlink_metadata(path.as_ref()).await?; + let canonicalized_path = if canonicalize { + Some(tokio::fs::canonicalize(path.as_ref()).await?) + } else { + None + }; + + // If asking for resolved file type and current type is symlink, then we want to refresh + // our metadata to get the filetype for the resolved link + let file_type = if resolve_file_type && metadata.file_type().is_symlink() { + tokio::fs::metadata(path).await?.file_type() + } else { + metadata.file_type() + }; + + Ok(Self { + canonicalized_path, + accessed: metadata + .accessed() + .ok() + .and_then(|t| t.duration_since(SystemTime::UNIX_EPOCH).ok()) + .map(|d| d.as_millis()), + created: metadata + .created() + .ok() + .and_then(|t| t.duration_since(SystemTime::UNIX_EPOCH).ok()) + .map(|d| d.as_millis()), + modified: metadata + .modified() + .ok() + .and_then(|t| t.duration_since(SystemTime::UNIX_EPOCH).ok()) + .map(|d| d.as_millis()), + len: metadata.len(), + readonly: metadata.permissions().readonly(), + file_type: if file_type.is_dir() { + FileType::Dir + } else if file_type.is_file() { + FileType::File + } else { + FileType::Symlink + }, + + #[cfg(unix)] + unix: Some({ + use std::os::unix::prelude::*; + let mode = metadata.mode(); + crate::data::UnixMetadata::from(mode) + }), + #[cfg(not(unix))] + unix: None, + + #[cfg(windows)] + windows: Some({ + use std::os::windows::prelude::*; + let attributes = metadata.file_attributes(); + crate::data::WindowsMetadata::from(attributes) + }), + #[cfg(not(windows))] + windows: None, + }) + } +} + +#[cfg(feature = "schemars")] +impl Metadata { + pub fn root_schema() -> schemars::schema::RootSchema { + schemars::schema_for!(Metadata) + } +} + +/// Represents unix-specific metadata about some path on a remote machine +#[derive(Copy, Clone, Debug, PartialEq, Eq, Serialize, Deserialize)] +#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))] +pub struct 
UnixMetadata { + /// Represents whether or not owner can read from the file + pub owner_read: bool, + + /// Represents whether or not owner can write to the file + pub owner_write: bool, + + /// Represents whether or not owner can execute the file + pub owner_exec: bool, + + /// Represents whether or not associated group can read from the file + pub group_read: bool, + + /// Represents whether or not associated group can write to the file + pub group_write: bool, + + /// Represents whether or not associated group can execute the file + pub group_exec: bool, + + /// Represents whether or not other can read from the file + pub other_read: bool, + + /// Represents whether or not other can write to the file + pub other_write: bool, + + /// Represents whether or not other can execute the file + pub other_exec: bool, +} + +#[cfg(feature = "schemars")] +impl UnixMetadata { + pub fn root_schema() -> schemars::schema::RootSchema { + schemars::schema_for!(UnixMetadata) + } +} + +impl From for UnixMetadata { + /// Create from a unix mode bitset + fn from(mode: u32) -> Self { + let flags = UnixFilePermissionFlags::from_bits_truncate(mode); + Self { + owner_read: flags.contains(UnixFilePermissionFlags::OWNER_READ), + owner_write: flags.contains(UnixFilePermissionFlags::OWNER_WRITE), + owner_exec: flags.contains(UnixFilePermissionFlags::OWNER_EXEC), + group_read: flags.contains(UnixFilePermissionFlags::GROUP_READ), + group_write: flags.contains(UnixFilePermissionFlags::GROUP_WRITE), + group_exec: flags.contains(UnixFilePermissionFlags::GROUP_EXEC), + other_read: flags.contains(UnixFilePermissionFlags::OTHER_READ), + other_write: flags.contains(UnixFilePermissionFlags::OTHER_WRITE), + other_exec: flags.contains(UnixFilePermissionFlags::OTHER_EXEC), + } + } +} + +impl From for u32 { + /// Convert to a unix mode bitset + fn from(metadata: UnixMetadata) -> Self { + let mut flags = UnixFilePermissionFlags::empty(); + + if metadata.owner_read { + 
flags.insert(UnixFilePermissionFlags::OWNER_READ); + } + if metadata.owner_write { + flags.insert(UnixFilePermissionFlags::OWNER_WRITE); + } + if metadata.owner_exec { + flags.insert(UnixFilePermissionFlags::OWNER_EXEC); + } + + if metadata.group_read { + flags.insert(UnixFilePermissionFlags::GROUP_READ); + } + if metadata.group_write { + flags.insert(UnixFilePermissionFlags::GROUP_WRITE); + } + if metadata.group_exec { + flags.insert(UnixFilePermissionFlags::GROUP_EXEC); + } + + if metadata.other_read { + flags.insert(UnixFilePermissionFlags::OTHER_READ); + } + if metadata.other_write { + flags.insert(UnixFilePermissionFlags::OTHER_WRITE); + } + if metadata.other_exec { + flags.insert(UnixFilePermissionFlags::OTHER_EXEC); + } + + flags.bits + } +} + +impl UnixMetadata { + pub fn is_readonly(self) -> bool { + !(self.owner_read || self.group_read || self.other_read) + } +} + +bitflags! { + struct UnixFilePermissionFlags: u32 { + const OWNER_READ = 0o400; + const OWNER_WRITE = 0o200; + const OWNER_EXEC = 0o100; + const GROUP_READ = 0o40; + const GROUP_WRITE = 0o20; + const GROUP_EXEC = 0o10; + const OTHER_READ = 0o4; + const OTHER_WRITE = 0o2; + const OTHER_EXEC = 0o1; + } +} + +/// Represents windows-specific metadata about some path on a remote machine +#[derive(Copy, Clone, Debug, PartialEq, Eq, Serialize, Deserialize)] +#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))] +pub struct WindowsMetadata { + /// Represents whether or not a file or directory is an archive + pub archive: bool, + + /// Represents whether or not a file or directory is compressed + pub compressed: bool, + + /// Represents whether or not the file or directory is encrypted + pub encrypted: bool, + + /// Represents whether or not a file or directory is hidden + pub hidden: bool, + + /// Represents whether or not a directory or user data stream is configured with integrity + pub integrity_stream: bool, + + /// Represents whether or not a file does not have other attributes set + pub 
normal: bool, + + /// Represents whether or not a file or directory is not to be indexed by content indexing + /// service + pub not_content_indexed: bool, + + /// Represents whether or not a user data stream is not to be read by the background data + /// integrity scanner + pub no_scrub_data: bool, + + /// Represents whether or not the data of a file is not available immediately + pub offline: bool, + + /// Represents whether or not a file or directory is not fully present locally + pub recall_on_data_access: bool, + + /// Represents whether or not a file or directory has no physical representation on the local + /// system (is virtual) + pub recall_on_open: bool, + + /// Represents whether or not a file or directory has an associated reparse point, or a file is + /// a symbolic link + pub reparse_point: bool, + + /// Represents whether or not a file is a sparse file + pub sparse_file: bool, + + /// Represents whether or not a file or directory is used partially or exclusively by the + /// operating system + pub system: bool, + + /// Represents whether or not a file is being used for temporary storage + pub temporary: bool, +} + +#[cfg(feature = "schemars")] +impl WindowsMetadata { + pub fn root_schema() -> schemars::schema::RootSchema { + schemars::schema_for!(WindowsMetadata) + } +} + +impl From for WindowsMetadata { + /// Create from a windows file attribute bitset + fn from(file_attributes: u32) -> Self { + let flags = WindowsFileAttributeFlags::from_bits_truncate(file_attributes); + Self { + archive: flags.contains(WindowsFileAttributeFlags::ARCHIVE), + compressed: flags.contains(WindowsFileAttributeFlags::COMPRESSED), + encrypted: flags.contains(WindowsFileAttributeFlags::ENCRYPTED), + hidden: flags.contains(WindowsFileAttributeFlags::HIDDEN), + integrity_stream: flags.contains(WindowsFileAttributeFlags::INTEGRITY_SYSTEM), + normal: flags.contains(WindowsFileAttributeFlags::NORMAL), + not_content_indexed: 
flags.contains(WindowsFileAttributeFlags::NOT_CONTENT_INDEXED), + no_scrub_data: flags.contains(WindowsFileAttributeFlags::NO_SCRUB_DATA), + offline: flags.contains(WindowsFileAttributeFlags::OFFLINE), + recall_on_data_access: flags.contains(WindowsFileAttributeFlags::RECALL_ON_DATA_ACCESS), + recall_on_open: flags.contains(WindowsFileAttributeFlags::RECALL_ON_OPEN), + reparse_point: flags.contains(WindowsFileAttributeFlags::REPARSE_POINT), + sparse_file: flags.contains(WindowsFileAttributeFlags::SPARSE_FILE), + system: flags.contains(WindowsFileAttributeFlags::SYSTEM), + temporary: flags.contains(WindowsFileAttributeFlags::TEMPORARY), + } + } +} + +impl From for u32 { + /// Convert to a windows file attribute bitset + fn from(metadata: WindowsMetadata) -> Self { + let mut flags = WindowsFileAttributeFlags::empty(); + + if metadata.archive { + flags.insert(WindowsFileAttributeFlags::ARCHIVE); + } + if metadata.compressed { + flags.insert(WindowsFileAttributeFlags::COMPRESSED); + } + if metadata.encrypted { + flags.insert(WindowsFileAttributeFlags::ENCRYPTED); + } + if metadata.hidden { + flags.insert(WindowsFileAttributeFlags::HIDDEN); + } + if metadata.integrity_stream { + flags.insert(WindowsFileAttributeFlags::INTEGRITY_SYSTEM); + } + if metadata.normal { + flags.insert(WindowsFileAttributeFlags::NORMAL); + } + if metadata.not_content_indexed { + flags.insert(WindowsFileAttributeFlags::NOT_CONTENT_INDEXED); + } + if metadata.no_scrub_data { + flags.insert(WindowsFileAttributeFlags::NO_SCRUB_DATA); + } + if metadata.offline { + flags.insert(WindowsFileAttributeFlags::OFFLINE); + } + if metadata.recall_on_data_access { + flags.insert(WindowsFileAttributeFlags::RECALL_ON_DATA_ACCESS); + } + if metadata.recall_on_open { + flags.insert(WindowsFileAttributeFlags::RECALL_ON_OPEN); + } + if metadata.reparse_point { + flags.insert(WindowsFileAttributeFlags::REPARSE_POINT); + } + if metadata.sparse_file { + flags.insert(WindowsFileAttributeFlags::SPARSE_FILE); + } + if 
metadata.system { + flags.insert(WindowsFileAttributeFlags::SYSTEM); + } + if metadata.temporary { + flags.insert(WindowsFileAttributeFlags::TEMPORARY); + } + + flags.bits + } +} + +bitflags! { + struct WindowsFileAttributeFlags: u32 { + const ARCHIVE = 0x20; + const COMPRESSED = 0x800; + const ENCRYPTED = 0x4000; + const HIDDEN = 0x2; + const INTEGRITY_SYSTEM = 0x8000; + const NORMAL = 0x80; + const NOT_CONTENT_INDEXED = 0x2000; + const NO_SCRUB_DATA = 0x20000; + const OFFLINE = 0x1000; + const RECALL_ON_DATA_ACCESS = 0x400000; + const RECALL_ON_OPEN = 0x40000; + const REPARSE_POINT = 0x400; + const SPARSE_FILE = 0x200; + const SYSTEM = 0x4; + const TEMPORARY = 0x100; + const VIRTUAL = 0x10000; + } +} diff --git a/distant-core/src/data/pty.rs b/distant-core/src/data/pty.rs new file mode 100644 index 0000000..2fd57c3 --- /dev/null +++ b/distant-core/src/data/pty.rs @@ -0,0 +1,137 @@ +use derive_more::{Display, Error}; +use portable_pty::PtySize as PortablePtySize; +use serde::{Deserialize, Serialize}; +use std::{fmt, num::ParseIntError, str::FromStr}; + +/// Represents the size associated with a remote PTY +#[derive(Copy, Clone, Debug, PartialEq, Eq, Serialize, Deserialize)] +#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))] +pub struct PtySize { + /// Number of lines of text + pub rows: u16, + + /// Number of columns of text + pub cols: u16, + + /// Width of a cell in pixels. Note that some systems never fill this value and ignore it. + #[serde(default)] + pub pixel_width: u16, + + /// Height of a cell in pixels. Note that some systems never fill this value and ignore it. 
+ #[serde(default)] + pub pixel_height: u16, +} + +impl PtySize { + /// Creates new size using just rows and columns + pub fn from_rows_and_cols(rows: u16, cols: u16) -> Self { + Self { + rows, + cols, + ..Default::default() + } + } +} + +#[cfg(feature = "schemars")] +impl PtySize { + pub fn root_schema() -> schemars::schema::RootSchema { + schemars::schema_for!(PtySize) + } +} + +impl From for PtySize { + fn from(size: PortablePtySize) -> Self { + Self { + rows: size.rows, + cols: size.cols, + pixel_width: size.pixel_width, + pixel_height: size.pixel_height, + } + } +} + +impl From for PortablePtySize { + fn from(size: PtySize) -> Self { + Self { + rows: size.rows, + cols: size.cols, + pixel_width: size.pixel_width, + pixel_height: size.pixel_height, + } + } +} + +impl fmt::Display for PtySize { + /// Prints out `rows,cols[,pixel_width,pixel_height]` where the + /// pixel width and pixel height are only included if either + /// one of them is not zero + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "{},{}", self.rows, self.cols)?; + if self.pixel_width > 0 || self.pixel_height > 0 { + write!(f, ",{},{}", self.pixel_width, self.pixel_height)?; + } + + Ok(()) + } +} + +impl Default for PtySize { + fn default() -> Self { + PtySize { + rows: 24, + cols: 80, + pixel_width: 0, + pixel_height: 0, + } + } +} + +#[derive(Clone, Debug, PartialEq, Eq, Display, Error)] +pub enum PtySizeParseError { + MissingRows, + MissingColumns, + InvalidRows(ParseIntError), + InvalidColumns(ParseIntError), + InvalidPixelWidth(ParseIntError), + InvalidPixelHeight(ParseIntError), +} + +impl FromStr for PtySize { + type Err = PtySizeParseError; + + /// Attempts to parse a str into PtySize using one of the following formats: + /// + /// * rows,cols (defaults to 0 for pixel_width & pixel_height) + /// * rows,cols,pixel_width,pixel_height + fn from_str(s: &str) -> Result { + let mut tokens = s.split(','); + + Ok(Self { + rows: tokens + .next() + 
.ok_or(PtySizeParseError::MissingRows)? + .trim() + .parse() + .map_err(PtySizeParseError::InvalidRows)?, + cols: tokens + .next() + .ok_or(PtySizeParseError::MissingColumns)? + .trim() + .parse() + .map_err(PtySizeParseError::InvalidColumns)?, + pixel_width: tokens + .next() + .map(|s| s.trim().parse()) + .transpose() + .map_err(PtySizeParseError::InvalidPixelWidth)? + .unwrap_or(0), + pixel_height: tokens + .next() + .map(|s| s.trim().parse()) + .transpose() + .map_err(PtySizeParseError::InvalidPixelHeight)? + .unwrap_or(0), + }) + } +} diff --git a/distant-core/src/data/system.rs b/distant-core/src/data/system.rs new file mode 100644 index 0000000..fb3d4f0 --- /dev/null +++ b/distant-core/src/data/system.rs @@ -0,0 +1,45 @@ +use serde::{Deserialize, Serialize}; +use std::{env, path::PathBuf}; + +/// Represents information about a system +#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)] +#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))] +pub struct SystemInfo { + /// Family of the operating system as described in + /// https://doc.rust-lang.org/std/env/consts/constant.FAMILY.html + pub family: String, + + /// Name of the specific operating system as described in + /// https://doc.rust-lang.org/std/env/consts/constant.OS.html + pub os: String, + + /// Architecture of the CPI as described in + /// https://doc.rust-lang.org/std/env/consts/constant.ARCH.html + pub arch: String, + + /// Current working directory of the running server process + pub current_dir: PathBuf, + + /// Primary separator for path components for the current platform + /// as defined in https://doc.rust-lang.org/std/path/constant.MAIN_SEPARATOR.html + pub main_separator: char, +} + +#[cfg(feature = "schemars")] +impl SystemInfo { + pub fn root_schema() -> schemars::schema::RootSchema { + schemars::schema_for!(SystemInfo) + } +} + +impl Default for SystemInfo { + fn default() -> Self { + Self { + family: env::consts::FAMILY.to_string(), + os: 
env::consts::OS.to_string(), + arch: env::consts::ARCH.to_string(), + current_dir: env::current_dir().unwrap_or_default(), + main_separator: std::path::MAIN_SEPARATOR, + } + } +} diff --git a/distant-core/src/data/utils.rs b/distant-core/src/data/utils.rs new file mode 100644 index 0000000..ceec131 --- /dev/null +++ b/distant-core/src/data/utils.rs @@ -0,0 +1,27 @@ +use serde::{Deserialize, Serialize}; + +pub(crate) fn deserialize_u128_option<'de, D>(deserializer: D) -> Result, D::Error> +where + D: serde::Deserializer<'de>, +{ + match Option::::deserialize(deserializer)? { + Some(s) => match s.parse::() { + Ok(value) => Ok(Some(value)), + Err(error) => Err(serde::de::Error::custom(format!( + "Cannot convert to u128 with error: {:?}", + error + ))), + }, + None => Ok(None), + } +} + +pub(crate) fn serialize_u128_option( + val: &Option, + s: S, +) -> Result { + match val { + Some(v) => format!("{}", *v).serialize(s), + None => s.serialize_unit(), + } +} diff --git a/distant-core/src/lib.rs b/distant-core/src/lib.rs index 7754de9..a457712 100644 --- a/distant-core/src/lib.rs +++ b/distant-core/src/lib.rs @@ -1,13 +1,20 @@ +mod api; +pub use api::*; + mod client; pub use client::*; -mod constants; - -mod net; -pub use net::*; +mod credentials; +pub use credentials::*; pub mod data; -pub use data::*; +pub use data::{DistantMsg, DistantRequestData, DistantResponseData, Map}; + +mod manager; +pub use manager::*; + +mod constants; +mod serde_str; -mod server; -pub use server::*; +/// Re-export of `distant-net` as `net` +pub use distant_net as net; diff --git a/distant-core/src/manager.rs b/distant-core/src/manager.rs new file mode 100644 index 0000000..6a9bf19 --- /dev/null +++ b/distant-core/src/manager.rs @@ -0,0 +1,7 @@ +mod client; +mod data; +mod server; + +pub use client::*; +pub use data::*; +pub use server::*; diff --git a/distant-core/src/manager/client.rs b/distant-core/src/manager/client.rs new file mode 100644 index 0000000..1acdcc0 --- /dev/null +++ 
b/distant-core/src/manager/client.rs @@ -0,0 +1,761 @@ +use super::data::{ + ConnectionId, ConnectionInfo, ConnectionList, Destination, Extra, ManagerRequest, + ManagerResponse, +}; +use crate::{DistantChannel, DistantClient, DistantMsg, DistantRequestData, DistantResponseData}; +use distant_net::{ + router, Auth, AuthServer, Client, IntoSplit, MpscTransport, OneshotListener, Request, Response, + ServerExt, ServerRef, UntypedTransportRead, UntypedTransportWrite, +}; +use log::*; +use std::{ + collections::HashMap, + io, + ops::{Deref, DerefMut}, +}; +use tokio::task::JoinHandle; + +mod config; +pub use config::*; + +mod ext; +pub use ext::*; + +router!(DistantManagerClientRouter { + auth_transport: Request => Response, + manager_transport: Response => Request, +}); + +/// Represents a client that can connect to a remote distant manager +pub struct DistantManagerClient { + auth: Box, + client: Client, + distant_clients: HashMap, +} + +impl Drop for DistantManagerClient { + fn drop(&mut self) { + self.auth.abort(); + self.client.abort(); + } +} + +/// Represents a raw channel between a manager client and some remote server +pub struct RawDistantChannel { + pub transport: MpscTransport< + Request>, + Response>, + >, + forward_task: JoinHandle<()>, + mailbox_task: JoinHandle<()>, +} + +impl RawDistantChannel { + pub fn abort(&self) { + self.forward_task.abort(); + self.mailbox_task.abort(); + } +} + +impl Deref for RawDistantChannel { + type Target = MpscTransport< + Request>, + Response>, + >; + + fn deref(&self) -> &Self::Target { + &self.transport + } +} + +impl DerefMut for RawDistantChannel { + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.transport + } +} + +struct ClientHandle { + client: DistantClient, + forward_task: JoinHandle<()>, + mailbox_task: JoinHandle<()>, +} + +impl Drop for ClientHandle { + fn drop(&mut self) { + self.forward_task.abort(); + self.mailbox_task.abort(); + } +} + +impl DistantManagerClient { + /// Initializes a client using 
the provided [`UntypedTransport`] + pub fn new(config: DistantManagerClientConfig, transport: T) -> io::Result + where + T: IntoSplit + 'static, + T::Read: UntypedTransportRead + 'static, + T::Write: UntypedTransportWrite + 'static, + { + let DistantManagerClientRouter { + auth_transport, + manager_transport, + .. + } = DistantManagerClientRouter::new(transport); + + // Initialize our client with manager request/response transport + let (writer, reader) = manager_transport.into_split(); + let client = Client::new(writer, reader)?; + + // Initialize our auth handler with auth/auth transport + let auth = AuthServer { + on_challenge: config.on_challenge, + on_verify: config.on_verify, + on_info: config.on_info, + on_error: config.on_error, + } + .start(OneshotListener::from_value(auth_transport.into_split()))?; + + Ok(Self { + auth, + client, + distant_clients: HashMap::new(), + }) + } + + /// Request that the manager launches a new server at the given `destination` + /// with `extra` being passed for destination-specific details, returning the new + /// `destination` of the spawned server to connect to + pub async fn launch( + &mut self, + destination: impl Into, + extra: impl Into, + ) -> io::Result { + let destination = Box::new(destination.into()); + let extra = extra.into(); + trace!("launch({}, {})", destination, extra); + + let res = self + .client + .send(ManagerRequest::Launch { destination, extra }) + .await?; + match res.payload { + ManagerResponse::Launched { destination } => Ok(destination), + ManagerResponse::Error(x) => Err(x.into()), + x => Err(io::Error::new( + io::ErrorKind::InvalidData, + format!("Got unexpected response: {:?}", x), + )), + } + } + + /// Request that the manager establishes a new connection at the given `destination` + /// with `extra` being passed for destination-specific details + pub async fn connect( + &mut self, + destination: impl Into, + extra: impl Into, + ) -> io::Result { + let destination = Box::new(destination.into()); + 
let extra = extra.into(); + trace!("connect({}, {})", destination, extra); + + let res = self + .client + .send(ManagerRequest::Connect { destination, extra }) + .await?; + match res.payload { + ManagerResponse::Connected { id } => Ok(id), + ManagerResponse::Error(x) => Err(x.into()), + x => Err(io::Error::new( + io::ErrorKind::InvalidData, + format!("Got unexpected response: {:?}", x), + )), + } + } + + /// Establishes a channel with the server represented by the `connection_id`, + /// returning a [`DistantChannel`] acting as the connection + /// + /// ### Note + /// + /// Multiple calls to open a channel against the same connection will result in + /// clones of the same [`DistantChannel`] rather than establishing a duplicate + /// remote connection to the same server + pub async fn open_channel( + &mut self, + connection_id: ConnectionId, + ) -> io::Result { + trace!("open_channel({})", connection_id); + if let Some(handle) = self.distant_clients.get(&connection_id) { + Ok(handle.client.clone_channel()) + } else { + let RawDistantChannel { + transport, + forward_task, + mailbox_task, + } = self.open_raw_channel(connection_id).await?; + let (writer, reader) = transport.into_split(); + let client = DistantClient::new(writer, reader)?; + let channel = client.clone_channel(); + self.distant_clients.insert( + connection_id, + ClientHandle { + client, + forward_task, + mailbox_task, + }, + ); + Ok(channel) + } + } + + /// Establishes a channel with the server represented by the `connection_id`, + /// returning a [`Transport`] acting as the connection + /// + /// ### Note + /// + /// Multiple calls to open a channel against the same connection will result in establishing a + /// duplicate remote connections to the same server, so take care when using this method + pub async fn open_raw_channel( + &mut self, + connection_id: ConnectionId, + ) -> io::Result { + trace!("open_raw_channel({})", connection_id); + let mut mailbox = self + .client + 
.mail(ManagerRequest::OpenChannel { id: connection_id }) + .await?; + + // Wait for the first response, which should be channel confirmation + let channel_id = match mailbox.next().await { + Some(response) => match response.payload { + ManagerResponse::ChannelOpened { id } => Ok(id), + ManagerResponse::Error(x) => Err(x.into()), + x => Err(io::Error::new( + io::ErrorKind::InvalidData, + format!("Got unexpected response: {:?}", x), + )), + }, + None => Err(io::Error::new( + io::ErrorKind::ConnectionAborted, + "open_channel mailbox aborted", + )), + }?; + + // Spawn reader and writer tasks to forward requests and replies + // using our opened channel + let (t1, t2) = MpscTransport::pair(1); + let (mut writer, mut reader) = t1.into_split(); + let mailbox_task = tokio::spawn(async move { + use distant_net::TypedAsyncWrite; + while let Some(response) = mailbox.next().await { + match response.payload { + ManagerResponse::Channel { response, .. } => { + if let Err(x) = writer.write(response).await { + error!("[Conn {}] {}", connection_id, x); + } + } + ManagerResponse::ChannelClosed { .. 
} => break, + _ => continue, + } + } + }); + + let mut manager_channel = self.client.clone_channel(); + let forward_task = tokio::spawn(async move { + use distant_net::TypedAsyncRead; + loop { + match reader.read().await { + Ok(Some(request)) => { + // NOTE: In this situation, we do not expect a response to this + // request (even if the server sends something back) + if let Err(x) = manager_channel + .fire(ManagerRequest::Channel { + id: channel_id, + request, + }) + .await + { + error!("[Conn {}] {}", connection_id, x); + } + } + Ok(None) => break, + Err(x) => { + error!("[Conn {}] {}", connection_id, x); + continue; + } + } + } + }); + + Ok(RawDistantChannel { + transport: t2, + forward_task, + mailbox_task, + }) + } + + /// Retrieves information about a specific connection + pub async fn info(&mut self, id: ConnectionId) -> io::Result { + trace!("info({})", id); + let res = self.client.send(ManagerRequest::Info { id }).await?; + match res.payload { + ManagerResponse::Info(info) => Ok(info), + ManagerResponse::Error(x) => Err(x.into()), + x => Err(io::Error::new( + io::ErrorKind::InvalidData, + format!("Got unexpected response: {:?}", x), + )), + } + } + + /// Kills the specified connection + pub async fn kill(&mut self, id: ConnectionId) -> io::Result<()> { + trace!("kill({})", id); + let res = self.client.send(ManagerRequest::Kill { id }).await?; + match res.payload { + ManagerResponse::Killed => Ok(()), + ManagerResponse::Error(x) => Err(x.into()), + x => Err(io::Error::new( + io::ErrorKind::InvalidData, + format!("Got unexpected response: {:?}", x), + )), + } + } + + /// Retrieves a list of active connections + pub async fn list(&mut self) -> io::Result { + trace!("list()"); + let res = self.client.send(ManagerRequest::List).await?; + match res.payload { + ManagerResponse::List(list) => Ok(list), + ManagerResponse::Error(x) => Err(x.into()), + x => Err(io::Error::new( + io::ErrorKind::InvalidData, + format!("Got unexpected response: {:?}", x), + )), + } + } 
+ + /// Requests that the manager shuts down + pub async fn shutdown(&mut self) -> io::Result<()> { + trace!("shutdown()"); + let res = self.client.send(ManagerRequest::Shutdown).await?; + match res.payload { + ManagerResponse::Shutdown => Ok(()), + ManagerResponse::Error(x) => Err(x.into()), + x => Err(io::Error::new( + io::ErrorKind::InvalidData, + format!("Got unexpected response: {:?}", x), + )), + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::data::{Error, ErrorKind}; + use distant_net::{ + FramedTransport, InmemoryTransport, PlainCodec, UntypedTransportRead, UntypedTransportWrite, + }; + + fn setup() -> ( + DistantManagerClient, + FramedTransport, + ) { + let (t1, t2) = FramedTransport::pair(100); + let client = + DistantManagerClient::new(DistantManagerClientConfig::with_empty_prompts(), t1) + .unwrap(); + (client, t2) + } + + #[inline] + fn test_error() -> Error { + Error { + kind: ErrorKind::Interrupted, + description: "test error".to_string(), + } + } + + #[inline] + fn test_io_error() -> io::Error { + test_error().into() + } + + #[tokio::test] + async fn connect_should_report_error_if_receives_error_response() { + let (mut client, mut transport) = setup(); + + tokio::spawn(async move { + let request = transport + .read::>() + .await + .unwrap() + .unwrap(); + + transport + .write(Response::new( + request.id, + ManagerResponse::Error(test_error()), + )) + .await + .unwrap(); + }); + + let err = client + .connect( + "scheme://host".parse::().unwrap(), + "key=value".parse::().unwrap(), + ) + .await + .unwrap_err(); + assert_eq!(err.kind(), test_io_error().kind()); + assert_eq!(err.to_string(), test_io_error().to_string()); + } + + #[tokio::test] + async fn connect_should_report_error_if_receives_unexpected_response() { + let (mut client, mut transport) = setup(); + + tokio::spawn(async move { + let request = transport + .read::>() + .await + .unwrap() + .unwrap(); + + transport + .write(Response::new(request.id, 
ManagerResponse::Shutdown)) + .await + .unwrap(); + }); + + let err = client + .connect( + "scheme://host".parse::().unwrap(), + "key=value".parse::().unwrap(), + ) + .await + .unwrap_err(); + assert_eq!(err.kind(), io::ErrorKind::InvalidData); + } + + #[tokio::test] + async fn connect_should_return_id_from_successful_response() { + let (mut client, mut transport) = setup(); + + let expected_id = 999; + tokio::spawn(async move { + let request = transport + .read::>() + .await + .unwrap() + .unwrap(); + + transport + .write(Response::new( + request.id, + ManagerResponse::Connected { id: expected_id }, + )) + .await + .unwrap(); + }); + + let id = client + .connect( + "scheme://host".parse::().unwrap(), + "key=value".parse::().unwrap(), + ) + .await + .unwrap(); + assert_eq!(id, expected_id); + } + + #[tokio::test] + async fn info_should_report_error_if_receives_error_response() { + let (mut client, mut transport) = setup(); + + tokio::spawn(async move { + let request = transport + .read::>() + .await + .unwrap() + .unwrap(); + + transport + .write(Response::new( + request.id, + ManagerResponse::Error(test_error()), + )) + .await + .unwrap(); + }); + + let err = client.info(123).await.unwrap_err(); + assert_eq!(err.kind(), test_io_error().kind()); + assert_eq!(err.to_string(), test_io_error().to_string()); + } + + #[tokio::test] + async fn info_should_report_error_if_receives_unexpected_response() { + let (mut client, mut transport) = setup(); + + tokio::spawn(async move { + let request = transport + .read::>() + .await + .unwrap() + .unwrap(); + + transport + .write(Response::new(request.id, ManagerResponse::Shutdown)) + .await + .unwrap(); + }); + + let err = client.info(123).await.unwrap_err(); + assert_eq!(err.kind(), io::ErrorKind::InvalidData); + } + + #[tokio::test] + async fn info_should_return_connection_info_from_successful_response() { + let (mut client, mut transport) = setup(); + + tokio::spawn(async move { + let request = transport + .read::>() + .await 
+ .unwrap() + .unwrap(); + + let info = ConnectionInfo { + id: 123, + destination: "scheme://host".parse::().unwrap(), + extra: "key=value".parse::().unwrap(), + }; + + transport + .write(Response::new(request.id, ManagerResponse::Info(info))) + .await + .unwrap(); + }); + + let info = client.info(123).await.unwrap(); + assert_eq!(info.id, 123); + assert_eq!( + info.destination, + "scheme://host".parse::().unwrap() + ); + assert_eq!(info.extra, "key=value".parse::().unwrap()); + } + + #[tokio::test] + async fn list_should_report_error_if_receives_error_response() { + let (mut client, mut transport) = setup(); + + tokio::spawn(async move { + let request = transport + .read::>() + .await + .unwrap() + .unwrap(); + + transport + .write(Response::new( + request.id, + ManagerResponse::Error(test_error()), + )) + .await + .unwrap(); + }); + + let err = client.list().await.unwrap_err(); + assert_eq!(err.kind(), test_io_error().kind()); + assert_eq!(err.to_string(), test_io_error().to_string()); + } + + #[tokio::test] + async fn list_should_report_error_if_receives_unexpected_response() { + let (mut client, mut transport) = setup(); + + tokio::spawn(async move { + let request = transport + .read::>() + .await + .unwrap() + .unwrap(); + + transport + .write(Response::new(request.id, ManagerResponse::Shutdown)) + .await + .unwrap(); + }); + + let err = client.list().await.unwrap_err(); + assert_eq!(err.kind(), io::ErrorKind::InvalidData); + } + + #[tokio::test] + async fn list_should_return_connection_list_from_successful_response() { + let (mut client, mut transport) = setup(); + + tokio::spawn(async move { + let request = transport + .read::>() + .await + .unwrap() + .unwrap(); + + let mut list = ConnectionList::new(); + list.insert(123, "scheme://host".parse::().unwrap()); + + transport + .write(Response::new(request.id, ManagerResponse::List(list))) + .await + .unwrap(); + }); + + let list = client.list().await.unwrap(); + assert_eq!(list.len(), 1); + assert_eq!( + 
list.get(&123).expect("Connection list missing item"), + &"scheme://host".parse::().unwrap() + ); + } + + #[tokio::test] + async fn kill_should_report_error_if_receives_error_response() { + let (mut client, mut transport) = setup(); + + tokio::spawn(async move { + let request = transport + .read::>() + .await + .unwrap() + .unwrap(); + + transport + .write(Response::new( + request.id, + ManagerResponse::Error(test_error()), + )) + .await + .unwrap(); + }); + + let err = client.kill(123).await.unwrap_err(); + assert_eq!(err.kind(), test_io_error().kind()); + assert_eq!(err.to_string(), test_io_error().to_string()); + } + + #[tokio::test] + async fn kill_should_report_error_if_receives_unexpected_response() { + let (mut client, mut transport) = setup(); + + tokio::spawn(async move { + let request = transport + .read::>() + .await + .unwrap() + .unwrap(); + + transport + .write(Response::new(request.id, ManagerResponse::Shutdown)) + .await + .unwrap(); + }); + + let err = client.kill(123).await.unwrap_err(); + assert_eq!(err.kind(), io::ErrorKind::InvalidData); + } + + #[tokio::test] + async fn kill_should_return_success_from_successful_response() { + let (mut client, mut transport) = setup(); + + tokio::spawn(async move { + let request = transport + .read::>() + .await + .unwrap() + .unwrap(); + + transport + .write(Response::new(request.id, ManagerResponse::Killed)) + .await + .unwrap(); + }); + + client.kill(123).await.unwrap(); + } + + #[tokio::test] + async fn shutdown_should_report_error_if_receives_error_response() { + let (mut client, mut transport) = setup(); + + tokio::spawn(async move { + let request = transport + .read::>() + .await + .unwrap() + .unwrap(); + + transport + .write(Response::new( + request.id, + ManagerResponse::Connected { id: 0 }, + )) + .await + .unwrap(); + }); + + let err = client.shutdown().await.unwrap_err(); + assert_eq!(err.kind(), io::ErrorKind::InvalidData); + } + + #[tokio::test] + async fn 
shutdown_should_report_error_if_receives_unexpected_response() { + let (mut client, mut transport) = setup(); + + tokio::spawn(async move { + let request = transport + .read::>() + .await + .unwrap() + .unwrap(); + + transport + .write(Response::new( + request.id, + ManagerResponse::Error(test_error()), + )) + .await + .unwrap(); + }); + + let err = client.shutdown().await.unwrap_err(); + assert_eq!(err.kind(), test_io_error().kind()); + assert_eq!(err.to_string(), test_io_error().to_string()); + } + + #[tokio::test] + async fn shutdown_should_return_success_from_successful_response() { + let (mut client, mut transport) = setup(); + + tokio::spawn(async move { + let request = transport + .read::>() + .await + .unwrap() + .unwrap(); + + transport + .write(Response::new(request.id, ManagerResponse::Shutdown)) + .await + .unwrap(); + }); + + client.shutdown().await.unwrap(); + } +} diff --git a/distant-core/src/manager/client/config.rs b/distant-core/src/manager/client/config.rs new file mode 100644 index 0000000..cc0e03e --- /dev/null +++ b/distant-core/src/manager/client/config.rs @@ -0,0 +1,85 @@ +use distant_net::{AuthChallengeFn, AuthErrorFn, AuthInfoFn, AuthVerifyFn, AuthVerifyKind}; +use log::*; +use std::io; + +/// Configuration to use when creating a new [`DistantManagerClient`](super::DistantManagerClient) +pub struct DistantManagerClientConfig { + pub on_challenge: Box, + pub on_verify: Box, + pub on_info: Box, + pub on_error: Box, +} + +impl DistantManagerClientConfig { + /// Creates a new config with prompts that return empty strings + pub fn with_empty_prompts() -> Self { + Self::with_prompts(|_| Ok("".to_string()), |_| Ok("".to_string())) + } + + /// Creates a new config with two prompts + /// + /// * `password_prompt` - used for prompting for a secret, and should not display what is typed + /// * `text_prompt` - used for general text, and is okay to display what is typed + pub fn with_prompts(password_prompt: PP, text_prompt: PT) -> Self + where + PP: 
Fn(&str) -> io::Result + Send + Sync + 'static, + PT: Fn(&str) -> io::Result + Send + Sync + 'static, + { + Self { + on_challenge: Box::new(move |questions, _extra| { + trace!("[manager client] on_challenge({questions:?}, {_extra:?})"); + let mut answers = Vec::new(); + for question in questions.iter() { + // Contains all prompt lines including same line + let mut lines = question.text.split('\n').collect::>(); + + // Line that is prompt on same line as answer + let line = lines.pop().unwrap(); + + // Go ahead and display all other lines + for line in lines.into_iter() { + eprintln!("{}", line); + } + + // Get an answer from user input, or use a blank string as an answer + // if we fail to get input from the user + let answer = password_prompt(line).unwrap_or_default(); + + answers.push(answer); + } + answers + }), + on_verify: Box::new(move |kind, text| { + trace!("[manager client] on_verify({kind}, {text})"); + match kind { + AuthVerifyKind::Host => { + eprintln!("{}", text); + + match text_prompt("Enter [y/N]> ") { + Ok(answer) => { + trace!("Verify? 
Answer = '{answer}'"); + matches!(answer.trim(), "y" | "Y" | "yes" | "YES") + } + Err(x) => { + error!("Failed verification: {x}"); + false + } + } + } + x => { + error!("Unsupported verify kind: {x}"); + false + } + } + }), + on_info: Box::new(|text| { + trace!("[manager client] on_info({text})"); + println!("{}", text); + }), + on_error: Box::new(|kind, text| { + trace!("[manager client] on_error({kind}, {text})"); + eprintln!("{}: {}", kind, text); + }), + } + } +} diff --git a/distant-core/src/manager/client/ext.rs b/distant-core/src/manager/client/ext.rs new file mode 100644 index 0000000..d23a3d2 --- /dev/null +++ b/distant-core/src/manager/client/ext.rs @@ -0,0 +1,14 @@ +mod tcp; +pub use tcp::*; + +#[cfg(unix)] +mod unix; + +#[cfg(unix)] +pub use unix::*; + +#[cfg(windows)] +mod windows; + +#[cfg(windows)] +pub use windows::*; diff --git a/distant-core/src/manager/client/ext/tcp.rs b/distant-core/src/manager/client/ext/tcp.rs new file mode 100644 index 0000000..e31ffc1 --- /dev/null +++ b/distant-core/src/manager/client/ext/tcp.rs @@ -0,0 +1,50 @@ +use crate::{DistantManagerClient, DistantManagerClientConfig}; +use async_trait::async_trait; +use distant_net::{Codec, FramedTransport, TcpTransport}; +use std::{convert, net::SocketAddr}; +use tokio::{io, time::Duration}; + +#[async_trait] +pub trait TcpDistantManagerClientExt { + /// Connect to a remote TCP server using the provided information + async fn connect( + config: DistantManagerClientConfig, + addr: SocketAddr, + codec: C, + ) -> io::Result + where + C: Codec + Send + 'static; + + /// Connect to a remote TCP server, timing out after duration has passed + async fn connect_timeout( + config: DistantManagerClientConfig, + addr: SocketAddr, + codec: C, + duration: Duration, + ) -> io::Result + where + C: Codec + Send + 'static, + { + tokio::time::timeout(duration, Self::connect(config, addr, codec)) + .await + .map_err(|x| io::Error::new(io::ErrorKind::TimedOut, x)) + .and_then(convert::identity) + } +} 
+ +#[async_trait] +impl TcpDistantManagerClientExt for DistantManagerClient { + /// Connect to a remote TCP server using the provided information + async fn connect( + config: DistantManagerClientConfig, + addr: SocketAddr, + codec: C, + ) -> io::Result + where + C: Codec + Send + 'static, + { + let transport = TcpTransport::connect(addr).await?; + let transport = FramedTransport::new(transport, codec); + Self::new(config, transport) + } +} diff --git a/distant-core/src/manager/client/ext/unix.rs b/distant-core/src/manager/client/ext/unix.rs new file mode 100644 index 0000000..18df8c8 --- /dev/null +++ b/distant-core/src/manager/client/ext/unix.rs @@ -0,0 +1,54 @@ +use crate::{DistantManagerClient, DistantManagerClientConfig}; +use async_trait::async_trait; +use distant_net::{Codec, FramedTransport, UnixSocketTransport}; +use std::{convert, path::Path}; +use tokio::{io, time::Duration}; + +#[async_trait] +pub trait UnixSocketDistantManagerClientExt { + /// Connect to a proxy unix socket + async fn connect( + config: DistantManagerClientConfig, + path: P, + codec: C, + ) -> io::Result + where + P: AsRef + Send, + C: Codec + Send + 'static; + + /// Connect to a proxy unix socket, timing out after duration has passed + async fn connect_timeout( + config: DistantManagerClientConfig, + path: P, + codec: C, + duration: Duration, + ) -> io::Result + where + P: AsRef + Send, + C: Codec + Send + 'static, + { + tokio::time::timeout(duration, Self::connect(config, path, codec)) + .await + .map_err(|x| io::Error::new(io::ErrorKind::TimedOut, x)) + .and_then(convert::identity) + } +} + +#[async_trait] +impl UnixSocketDistantManagerClientExt for DistantManagerClient { + /// Connect to a proxy unix socket + async fn connect( + config: DistantManagerClientConfig, + path: P, + codec: C, + ) -> io::Result + where + P: AsRef + Send, + C: Codec + Send + 'static, + { + let p = path.as_ref(); + let transport = UnixSocketTransport::connect(p).await?; + let transport = 
FramedTransport::new(transport, codec); + Ok(DistantManagerClient::new(config, transport)?) + } +} diff --git a/distant-core/src/manager/client/ext/windows.rs b/distant-core/src/manager/client/ext/windows.rs new file mode 100644 index 0000000..c13f10a --- /dev/null +++ b/distant-core/src/manager/client/ext/windows.rs @@ -0,0 +1,91 @@ +use crate::{DistantManagerClient, DistantManagerClientConfig}; +use async_trait::async_trait; +use distant_net::{Codec, FramedTransport, WindowsPipeTransport}; +use std::{ + convert, + ffi::{OsStr, OsString}, +}; +use tokio::{io, time::Duration}; + +#[async_trait] +pub trait WindowsPipeDistantManagerClientExt { + /// Connect to a server listening on a Windows pipe at the specified address + /// using the given codec + async fn connect( + config: DistantManagerClientConfig, + addr: A, + codec: C, + ) -> io::Result + where + A: AsRef + Send, + C: Codec + Send + 'static; + + /// Connect to a server listening on a Windows pipe at the specified address + /// via `\\.\pipe\{name}` using the given codec + async fn connect_local( + config: DistantManagerClientConfig, + name: N, + codec: C, + ) -> io::Result + where + N: AsRef + Send, + C: Codec + Send + 'static, + { + let mut addr = OsString::from(r"\\.\pipe\"); + addr.push(name.as_ref()); + Self::connect(config, addr, codec).await + } + + /// Connect to a server listening on a Windows pipe at the specified address + /// using the given codec, timing out after duration has passed + async fn connect_timeout( + config: DistantManagerClientConfig, + addr: A, + codec: C, + duration: Duration, + ) -> io::Result + where + A: AsRef + Send, + C: Codec + Send + 'static, + { + tokio::time::timeout(duration, Self::connect(config, addr, codec)) + .await + .map_err(|x| io::Error::new(io::ErrorKind::TimedOut, x)) + .and_then(convert::identity) + } + + /// Connect to a server listening on a Windows pipe at the specified address + /// via `\\.\pipe\{name}` using the given codec, timing out after duration has 
passed + async fn connect_local_timeout( + config: DistantManagerClientConfig, + name: N, + codec: C, + duration: Duration, + ) -> io::Result + where + N: AsRef + Send, + C: Codec + Send + 'static, + { + let mut addr = OsString::from(r"\\.\pipe\"); + addr.push(name.as_ref()); + Self::connect_timeout(config, addr, codec, duration).await + } +} + +#[async_trait] +impl WindowsPipeDistantManagerClientExt for DistantManagerClient { + async fn connect( + config: DistantManagerClientConfig, + addr: A, + codec: C, + ) -> io::Result + where + A: AsRef + Send, + C: Codec + Send + 'static, + { + let a = addr.as_ref(); + let transport = WindowsPipeTransport::connect(a).await?; + let transport = FramedTransport::new(transport, codec); + Ok(DistantManagerClient::new(config, transport)?) + } +} diff --git a/distant-core/src/manager/data.rs b/distant-core/src/manager/data.rs new file mode 100644 index 0000000..0be4b2a --- /dev/null +++ b/distant-core/src/manager/data.rs @@ -0,0 +1,20 @@ +mod destination; +pub use destination::*; + +mod extra; +pub use extra::*; + +mod id; +pub use id::*; + +mod info; +pub use info::*; + +mod list; +pub use list::*; + +mod request; +pub use request::*; + +mod response; +pub use response::*; diff --git a/distant-core/src/manager/data/destination.rs b/distant-core/src/manager/data/destination.rs new file mode 100644 index 0000000..3dbbebe --- /dev/null +++ b/distant-core/src/manager/data/destination.rs @@ -0,0 +1,266 @@ +use crate::serde_str::{deserialize_from_str, serialize_to_str}; +use derive_more::{Display, Error, From}; +use serde::{de::Deserializer, ser::Serializer, Deserialize, Serialize}; +use std::{convert::TryFrom, fmt, hash::Hash, str::FromStr}; +use uriparse::{ + Authority, AuthorityError, Host, Password, Scheme, URIReference, URIReferenceError, Username, + URI, +}; + +/// Represents an error that occurs when trying to parse a destination from a str +#[derive(Copy, Clone, Debug, Display, Error, From, PartialEq, Eq)] +pub enum 
DestinationError { + MissingHost, + URIReferenceError(URIReferenceError), +} + +/// `distant` connects and logs into the specified destination, which may be specified as either +/// `hostname:port` where an attempt to connect to a **distant** server will be made, or a URI of +/// one of the following forms: +/// +/// * `distant://hostname:port` - connect to a distant server +/// * `ssh://[user@]hostname[:port]` - connect to an SSH server +#[derive(Clone, Debug, Hash, PartialEq, Eq)] +pub struct Destination(URIReference<'static>); + +impl Destination { + /// Returns a reference to the scheme associated with the destination, if it has one + pub fn scheme(&self) -> Option<&str> { + self.0.scheme().map(Scheme::as_str) + } + + /// Returns the host of the destination as a string + pub fn to_host_string(&self) -> String { + // NOTE: We guarantee that there is a host for a destination during construction + self.0.host().unwrap().to_string() + } + + /// Returns the port tied to the destination, if it has one + pub fn port(&self) -> Option { + self.0.port() + } + + /// Returns the username tied with the destination if it has one + pub fn username(&self) -> Option<&str> { + self.0.username().map(Username::as_str) + } + + /// Returns the password tied with the destination if it has one + pub fn password(&self) -> Option<&str> { + self.0.password().map(Password::as_str) + } + + /// Replaces the host of the destination + pub fn replace_host(&mut self, host: &str) -> Result<(), URIReferenceError> { + let username = self + .0 + .username() + .map(Username::as_borrowed) + .map(Username::into_owned); + let password = self + .0 + .password() + .map(Password::as_borrowed) + .map(Password::into_owned); + let port = self.0.port(); + let _ = self.0.set_authority(Some( + Authority::from_parts( + username, + password, + Host::try_from(host) + .map(Host::into_owned) + .map_err(AuthorityError::from) + .map_err(URIReferenceError::from)?, + port, + ) + .map(Authority::into_owned) + 
.map_err(URIReferenceError::from)?, + ))?; + Ok(()) + } + + /// Indicates whether the host destination is globally routable + pub fn is_host_global(&self) -> bool { + match self.0.host() { + Some(Host::IPv4Address(x)) => { + !(x.is_broadcast() + || x.is_documentation() + || x.is_link_local() + || x.is_loopback() + || x.is_private() + || x.is_unspecified()) + } + Some(Host::IPv6Address(x)) => { + // NOTE: 14 is the global flag + x.is_multicast() && (x.segments()[0] & 0x000f == 14) + } + Some(Host::RegisteredName(name)) => !name.trim().is_empty(), + None => false, + } + } + + /// Returns true if destination represents a distant server + pub fn is_distant(&self) -> bool { + self.scheme_eq("distant") + } + + /// Returns true if destination represents an ssh server + pub fn is_ssh(&self) -> bool { + self.scheme_eq("ssh") + } + + fn scheme_eq(&self, s: &str) -> bool { + match self.0.scheme() { + Some(scheme) => scheme.as_str().eq_ignore_ascii_case(s), + None => false, + } + } + + /// Returns reference to inner [`URIReference`] + pub fn as_uri_ref(&self) -> &URIReference<'static> { + &self.0 + } +} + +impl AsRef for &Destination { + fn as_ref(&self) -> &Destination { + *self + } +} + +impl AsRef> for Destination { + fn as_ref(&self) -> &URIReference<'static> { + self.as_uri_ref() + } +} + +impl fmt::Display for Destination { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "{}", self.0) + } +} + +impl FromStr for Destination { + type Err = DestinationError; + + fn from_str(s: &str) -> Result { + // Disallow empty (whitespace-only) input as that passes our + // parsing for a URI reference (relative with no scheme or anything) + if s.trim().is_empty() { + return Err(DestinationError::MissingHost); + } + + let mut destination = URIReference::try_from(s) + .map(URIReference::into_owned) + .map(Destination) + .map_err(DestinationError::URIReferenceError)?; + + // Only support relative reference if it is a path reference as + // we convert that to a 
relative reference with a host + if destination.0.is_relative_reference() { + let path = destination.0.path().to_string(); + destination.replace_host(path.as_str())?; + let _ = destination.0.set_path("/")?; + } + + Ok(destination) + } +} + +impl<'a> TryFrom> for Destination { + type Error = DestinationError; + + fn try_from(uri_ref: URIReference<'a>) -> Result { + if uri_ref.host().is_none() { + return Err(DestinationError::MissingHost); + } + + Ok(Self(uri_ref.into_owned())) + } +} + +impl<'a> TryFrom> for Destination { + type Error = DestinationError; + + fn try_from(uri: URI<'a>) -> Result { + let uri_ref: URIReference<'a> = uri.into(); + Self::try_from(uri_ref) + } +} + +impl FromStr for Box { + type Err = DestinationError; + + fn from_str(s: &str) -> Result { + let destination = s.parse::()?; + Ok(Box::new(destination)) + } +} + +impl Serialize for Destination { + fn serialize(&self, serializer: S) -> Result + where + S: Serializer, + { + serialize_to_str(self, serializer) + } +} + +impl<'de> Deserialize<'de> for Destination { + fn deserialize(deserializer: D) -> Result + where + D: Deserializer<'de>, + { + deserialize_from_str(deserializer) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn parse_should_fail_if_string_is_only_whitespace() { + let err = "".parse::().unwrap_err(); + assert_eq!(err, DestinationError::MissingHost); + + let err = " ".parse::().unwrap_err(); + assert_eq!(err, DestinationError::MissingHost); + + let err = "\t".parse::().unwrap_err(); + assert_eq!(err, DestinationError::MissingHost); + + let err = "\n".parse::().unwrap_err(); + assert_eq!(err, DestinationError::MissingHost); + + let err = "\r".parse::().unwrap_err(); + assert_eq!(err, DestinationError::MissingHost); + + let err = "\r\n".parse::().unwrap_err(); + assert_eq!(err, DestinationError::MissingHost); + } + + #[test] + fn parse_should_succeed_with_valid_uri() { + let destination = "distant://localhost".parse::().unwrap(); + 
assert_eq!(destination.scheme().unwrap(), "distant"); + assert_eq!(destination.to_host_string(), "localhost"); + assert_eq!(destination.as_uri_ref().path().to_string(), "/"); + } + + #[test] + fn parse_should_fail_if_relative_reference_that_is_not_valid_host() { + let _ = "/".parse::().unwrap_err(); + let _ = "/localhost".parse::().unwrap_err(); + let _ = "my/path".parse::().unwrap_err(); + let _ = "/my/path".parse::().unwrap_err(); + let _ = "//localhost".parse::().unwrap_err(); + } + + #[test] + fn parse_should_succeed_with_nonempty_relative_reference_by_setting_host_to_path() { + let destination = "localhost".parse::().unwrap(); + assert_eq!(destination.to_host_string(), "localhost"); + assert_eq!(destination.as_uri_ref().path().to_string(), "/"); + } +} diff --git a/distant-core/src/manager/data/extra.rs b/distant-core/src/manager/data/extra.rs new file mode 100644 index 0000000..1e3e89a --- /dev/null +++ b/distant-core/src/manager/data/extra.rs @@ -0,0 +1,2 @@ +/// Represents extra data included for connections +pub type Extra = crate::data::Map; diff --git a/distant-core/src/manager/data/id.rs b/distant-core/src/manager/data/id.rs new file mode 100644 index 0000000..34abc0d --- /dev/null +++ b/distant-core/src/manager/data/id.rs @@ -0,0 +1,5 @@ +/// Id associated with an active connection +pub type ConnectionId = u64; + +/// Id associated with an open channel +pub type ChannelId = u64; diff --git a/distant-core/src/manager/data/info.rs b/distant-core/src/manager/data/info.rs new file mode 100644 index 0000000..517b0cc --- /dev/null +++ b/distant-core/src/manager/data/info.rs @@ -0,0 +1,15 @@ +use super::{ConnectionId, Destination, Extra}; +use serde::{Deserialize, Serialize}; + +/// Information about a specific connection +#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)] +pub struct ConnectionInfo { + /// Connection's id + pub id: ConnectionId, + + /// Destination with which this connection is associated + pub destination: Destination, + + /// 
Extra information associated with this connection + pub extra: Extra, +} diff --git a/distant-core/src/manager/data/list.rs b/distant-core/src/manager/data/list.rs new file mode 100644 index 0000000..dfbed3d --- /dev/null +++ b/distant-core/src/manager/data/list.rs @@ -0,0 +1,58 @@ +use super::{ConnectionId, Destination}; +use derive_more::IntoIterator; +use serde::{Deserialize, Serialize}; +use std::{ + collections::HashMap, + ops::{Deref, DerefMut, Index, IndexMut}, +}; + +/// Represents a list of information about active connections +#[derive(Clone, Debug, PartialEq, Eq, IntoIterator, Serialize, Deserialize)] +pub struct ConnectionList(pub(crate) HashMap); + +impl ConnectionList { + pub fn new() -> Self { + Self(HashMap::new()) + } + + /// Returns a reference to the destination associated with an active connection + pub fn connection_destination(&self, id: ConnectionId) -> Option<&Destination> { + self.0.get(&id) + } +} + +impl Default for ConnectionList { + fn default() -> Self { + Self::new() + } +} + +impl Deref for ConnectionList { + type Target = HashMap; + + fn deref(&self) -> &Self::Target { + &self.0 + } +} + +impl DerefMut for ConnectionList { + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.0 + } +} + +impl Index for ConnectionList { + type Output = Destination; + + fn index(&self, connection_id: u64) -> &Self::Output { + &self.0[&connection_id] + } +} + +impl IndexMut for ConnectionList { + fn index_mut(&mut self, connection_id: u64) -> &mut Self::Output { + self.0 + .get_mut(&connection_id) + .expect("No connection with id") + } +} diff --git a/distant-core/src/manager/data/request.rs b/distant-core/src/manager/data/request.rs new file mode 100644 index 0000000..2d29c35 --- /dev/null +++ b/distant-core/src/manager/data/request.rs @@ -0,0 +1,72 @@ +use super::{ChannelId, ConnectionId, Destination, Extra}; +use crate::{DistantMsg, DistantRequestData}; +use distant_net::Request; +use serde::{Deserialize, Serialize}; + +#[derive(Clone, Debug, 
Serialize, Deserialize)] +#[cfg_attr(feature = "clap", derive(clap::Subcommand))] +#[serde(rename_all = "snake_case", deny_unknown_fields, tag = "type")] +pub enum ManagerRequest { + /// Launch a server using the manager + Launch { + // NOTE: Boxed per clippy's large_enum_variant warning + destination: Box, + + /// Extra details specific to the connection + #[cfg_attr(feature = "clap", clap(short, long, action = clap::ArgAction::Append))] + extra: Extra, + }, + + /// Initiate a connection through the manager + Connect { + // NOTE: Boxed per clippy's large_enum_variant warning + destination: Box, + + /// Extra details specific to the connection + #[cfg_attr(feature = "clap", clap(short, long, action = clap::ArgAction::Append))] + extra: Extra, + }, + + /// Opens a channel for communication with a server + #[cfg_attr(feature = "clap", clap(skip))] + OpenChannel { + /// Id of the connection + id: ConnectionId, + }, + + /// Sends data through channel + #[cfg_attr(feature = "clap", clap(skip))] + Channel { + /// Id of the channel + id: ChannelId, + + /// Request to send to through the channel + #[cfg_attr(feature = "clap", clap(skip = skipped_request()))] + request: Request>, + }, + + /// Closes an open channel + #[cfg_attr(feature = "clap", clap(skip))] + CloseChannel { + /// Id of the channel to close + id: ChannelId, + }, + + /// Retrieve information about a specific connection + Info { id: ConnectionId }, + + /// Kill a specific connection + Kill { id: ConnectionId }, + + /// Retrieve list of connections being managed + List, + + /// Signals the manager to shutdown + Shutdown, +} + +/// Produces some default request, purely to satisfy clap +#[cfg(feature = "clap")] +fn skipped_request() -> Request> { + Request::new(DistantMsg::Single(DistantRequestData::SystemInfo {})) +} diff --git a/distant-core/src/manager/data/response.rs b/distant-core/src/manager/data/response.rs new file mode 100644 index 0000000..7a53260 --- /dev/null +++ 
b/distant-core/src/manager/data/response.rs @@ -0,0 +1,53 @@ +use crate::{data::Error, ConnectionInfo, ConnectionList, Destination}; +use crate::{ChannelId, ConnectionId, DistantMsg, DistantResponseData}; +use distant_net::Response; +use serde::{Deserialize, Serialize}; + +#[derive(Clone, Debug, Serialize, Deserialize)] +#[serde(rename_all = "snake_case", deny_unknown_fields, tag = "type")] +pub enum ManagerResponse { + /// Acknowledgement that a connection was killed + Killed, + + /// Broadcast that the manager is shutting down (not guaranteed to be sent) + Shutdown, + + /// Indicates that some error occurred during a request + Error(Error), + + /// Confirmation of a distant server being launched + Launched { + /// Updated location of the spawned server + destination: Destination, + }, + + /// Confirmation of a connection being established + Connected { id: ConnectionId }, + + /// Information about a specific connection + Info(ConnectionInfo), + + /// List of connections in the form of id -> destination + List(ConnectionList), + + /// Forward a response back to a specific channel that made a request + Channel { + /// Id of the channel + id: ChannelId, + + /// Response to an earlier channel request + response: Response>, + }, + + /// Indicates that a channel has been opened + ChannelOpened { + /// Id of the channel + id: ChannelId, + }, + + /// Indicates that a channel has been closed + ChannelClosed { + /// Id of the channel + id: ChannelId, + }, +} diff --git a/distant-core/src/manager/server.rs b/distant-core/src/manager/server.rs new file mode 100644 index 0000000..e1c5da2 --- /dev/null +++ b/distant-core/src/manager/server.rs @@ -0,0 +1,698 @@ +use crate::{ + ChannelId, ConnectionId, ConnectionInfo, ConnectionList, Destination, Extra, ManagerRequest, + ManagerResponse, +}; +use async_trait::async_trait; +use distant_net::{ + router, Auth, AuthClient, Client, IntoSplit, Listener, MpscListener, Request, Response, Server, + ServerCtx, ServerExt, 
UntypedTransportRead, UntypedTransportWrite, +}; +use log::*; +use std::{collections::HashMap, io, sync::Arc}; +use tokio::{ + sync::{mpsc, Mutex, RwLock}, + task::JoinHandle, +}; + +mod config; +pub use config::*; + +mod connection; +pub use connection::*; + +mod ext; +pub use ext::*; + +mod handler; +pub use handler::*; + +mod r#ref; +pub use r#ref::*; + +router!(DistantManagerRouter { + auth_transport: Response => Request, + manager_transport: Request => Response, +}); + +/// Represents a manager of multiple distant server connections +pub struct DistantManager { + /// Receives authentication clients to feed into local data of server + auth_client_rx: Mutex>, + + /// Configuration settings for the server + config: DistantManagerConfig, + + /// Mapping of connection id -> connection + connections: RwLock>, + + /// Handlers for launch requests + launch_handlers: Arc>>, + + /// Handlers for connect requests + connect_handlers: Arc>>, + + /// Primary task of server + task: JoinHandle<()>, +} + +impl DistantManager { + /// Initializes a new instance of [`DistantManagerServer`] using the provided [`UntypedTransport`] + pub fn start( + mut config: DistantManagerConfig, + mut listener: L, + ) -> io::Result + where + L: Listener + 'static, + T: IntoSplit + Send + 'static, + T::Read: UntypedTransportRead + 'static, + T::Write: UntypedTransportWrite + 'static, + { + let (conn_tx, mpsc_listener) = MpscListener::channel(config.connection_buffer_size); + let (auth_client_tx, auth_client_rx) = mpsc::channel(1); + + // Spawn task that uses our input listener to get both auth and manager events, + // forwarding manager events to the internal mpsc listener + let task = tokio::spawn(async move { + while let Ok(transport) = listener.accept().await { + let DistantManagerRouter { + auth_transport, + manager_transport, + .. 
+ } = DistantManagerRouter::new(transport); + + let (writer, reader) = auth_transport.into_split(); + let client = match Client::new(writer, reader) { + Ok(client) => client, + Err(x) => { + error!("Creating auth client failed: {}", x); + continue; + } + }; + let auth_client = AuthClient::from(client); + + // Forward auth client for new connection in server + if auth_client_tx.send(auth_client).await.is_err() { + break; + } + + // Forward connected and routed transport to server + if conn_tx.send(manager_transport.into_split()).await.is_err() { + break; + } + } + }); + + let launch_handlers = Arc::new(RwLock::new(config.launch_handlers.drain().collect())); + let weak_launch_handlers = Arc::downgrade(&launch_handlers); + let connect_handlers = Arc::new(RwLock::new(config.connect_handlers.drain().collect())); + let weak_connect_handlers = Arc::downgrade(&connect_handlers); + let server_ref = Self { + auth_client_rx: Mutex::new(auth_client_rx), + config, + launch_handlers, + connect_handlers, + connections: RwLock::new(HashMap::new()), + task, + } + .start(mpsc_listener)?; + + Ok(DistantManagerRef { + launch_handlers: weak_launch_handlers, + connect_handlers: weak_connect_handlers, + inner: server_ref, + }) + } + + /// Launches a new server at the specified `destination` using the given `extra` information + /// and authentication client (if needed) to retrieve additional information needed to + /// enter the destination prior to starting the server, returning the destination of the + /// launched server + async fn launch( + &self, + destination: Destination, + extra: Extra, + auth: Option<&mut AuthClient>, + ) -> io::Result { + let auth = auth.ok_or_else(|| { + io::Error::new( + io::ErrorKind::Other, + "Authentication client not initialized", + ) + })?; + + let scheme = match destination.scheme() { + Some(scheme) => { + trace!("Using scheme {}", scheme); + scheme + } + None => { + trace!( + "Using fallback scheme of {}", + self.config.fallback_scheme.as_str() + ); + 
self.config.fallback_scheme.as_str() + } + } + .to_lowercase(); + + let credentials = { + let lock = self.launch_handlers.read().await; + let handler = lock.get(&scheme).ok_or_else(|| { + io::Error::new( + io::ErrorKind::InvalidInput, + format!("No launch handler registered for {}", scheme), + ) + })?; + handler.launch(&destination, &extra, auth).await? + }; + + Ok(credentials) + } + + /// Connects to a new server at the specified `destination` using the given `extra` information + /// and authentication client (if needed) to retrieve additional information needed to + /// establish the connection to the server + async fn connect( + &self, + destination: Destination, + extra: Extra, + auth: Option<&mut AuthClient>, + ) -> io::Result { + let auth = auth.ok_or_else(|| { + io::Error::new( + io::ErrorKind::Other, + "Authentication client not initialized", + ) + })?; + + let scheme = match destination.scheme() { + Some(scheme) => { + trace!("Using scheme {}", scheme); + scheme + } + None => { + trace!( + "Using fallback scheme of {}", + self.config.fallback_scheme.as_str() + ); + self.config.fallback_scheme.as_str() + } + } + .to_lowercase(); + + let (writer, reader) = { + let lock = self.connect_handlers.read().await; + let handler = lock.get(&scheme).ok_or_else(|| { + io::Error::new( + io::ErrorKind::InvalidInput, + format!("No connect handler registered for {}", scheme), + ) + })?; + handler.connect(&destination, &extra, auth).await? 
+ }; + + let connection = DistantManagerConnection::new(destination, extra, writer, reader); + let id = connection.id; + self.connections.write().await.insert(id, connection); + Ok(id) + } + + /// Retrieves information about the connection to the server with the specified `id` + async fn info(&self, id: ConnectionId) -> io::Result { + match self.connections.read().await.get(&id) { + Some(connection) => Ok(ConnectionInfo { + id: connection.id, + destination: connection.destination.clone(), + extra: connection.extra.clone(), + }), + None => Err(io::Error::new( + io::ErrorKind::NotConnected, + "No connection found", + )), + } + } + + /// Retrieves a list of connections to servers + async fn list(&self) -> io::Result { + Ok(ConnectionList( + self.connections + .read() + .await + .values() + .map(|conn| (conn.id, conn.destination.clone())) + .collect(), + )) + } + + /// Kills the connection to the server with the specified `id` + async fn kill(&self, id: ConnectionId) -> io::Result<()> { + match self.connections.write().await.remove(&id) { + Some(_) => Ok(()), + None => Err(io::Error::new( + io::ErrorKind::NotConnected, + "No connection found", + )), + } + } +} + +#[derive(Default)] +pub struct DistantManagerServerConnection { + /// Authentication client that manager can use when establishing a new connection + /// and needing to get authentication details from the client to move forward + auth_client: Option>, + + /// Holds on to open channels feeding data back from a server to some connected client, + /// enabling us to cancel the tasks on demand + channels: RwLock>, +} + +#[async_trait] +impl Server for DistantManager { + type Request = ManagerRequest; + type Response = ManagerResponse; + type LocalData = DistantManagerServerConnection; + + async fn on_accept(&self, local_data: &mut Self::LocalData) { + local_data.auth_client = self + .auth_client_rx + .lock() + .await + .recv() + .await + .map(Mutex::new); + + // Enable jit handshake + if let Some(auth_client) = 
local_data.auth_client.as_ref() { + auth_client.lock().await.set_jit_handshake(true); + } + } + + async fn on_request(&self, ctx: ServerCtx) { + let ServerCtx { + connection_id, + request, + reply, + local_data, + } = ctx; + + let response = match request.payload { + ManagerRequest::Launch { destination, extra } => { + let mut auth = match local_data.auth_client.as_ref() { + Some(client) => Some(client.lock().await), + None => None, + }; + + match self.launch(*destination, extra, auth.as_deref_mut()).await { + Ok(destination) => ManagerResponse::Launched { destination }, + Err(x) => ManagerResponse::Error(x.into()), + } + } + ManagerRequest::Connect { destination, extra } => { + let mut auth = match local_data.auth_client.as_ref() { + Some(client) => Some(client.lock().await), + None => None, + }; + + match self.connect(*destination, extra, auth.as_deref_mut()).await { + Ok(id) => ManagerResponse::Connected { id }, + Err(x) => ManagerResponse::Error(x.into()), + } + } + ManagerRequest::OpenChannel { id } => match self.connections.read().await.get(&id) { + Some(connection) => match connection.open_channel(reply.clone()).await { + Ok(channel) => { + let id = channel.id(); + local_data.channels.write().await.insert(id, channel); + ManagerResponse::ChannelOpened { id } + } + Err(x) => ManagerResponse::Error(x.into()), + }, + None => ManagerResponse::Error( + io::Error::new(io::ErrorKind::NotConnected, "Connection does not exist").into(), + ), + }, + ManagerRequest::Channel { id, request } => { + match local_data.channels.read().await.get(&id) { + // TODO: For now, we are NOT sending back a response to acknowledge + // a successful channel send. We could do this in order for + // the client to listen for a complete send, but is it worth it? 
+ Some(channel) => match channel.send(request).await { + Ok(_) => return, + Err(x) => ManagerResponse::Error(x.into()), + }, + None => ManagerResponse::Error( + io::Error::new( + io::ErrorKind::NotConnected, + "Channel is not open or does not exist", + ) + .into(), + ), + } + } + ManagerRequest::CloseChannel { id } => { + match local_data.channels.write().await.remove(&id) { + Some(channel) => match channel.close().await { + Ok(_) => ManagerResponse::ChannelClosed { id }, + Err(x) => ManagerResponse::Error(x.into()), + }, + None => ManagerResponse::Error( + io::Error::new( + io::ErrorKind::NotConnected, + "Channel is not open or does not exist", + ) + .into(), + ), + } + } + ManagerRequest::Info { id } => match self.info(id).await { + Ok(info) => ManagerResponse::Info(info), + Err(x) => ManagerResponse::Error(x.into()), + }, + ManagerRequest::List => match self.list().await { + Ok(list) => ManagerResponse::List(list), + Err(x) => ManagerResponse::Error(x.into()), + }, + ManagerRequest::Kill { id } => match self.kill(id).await { + Ok(()) => ManagerResponse::Killed, + Err(x) => ManagerResponse::Error(x.into()), + }, + ManagerRequest::Shutdown => { + if let Err(x) = reply.send(ManagerResponse::Shutdown).await { + error!("[Conn {}] {}", connection_id, x); + } + + // Clear out handler state in order to trigger drops + self.launch_handlers.write().await.clear(); + self.connect_handlers.write().await.clear(); + + // Shutdown the primary server task + self.task.abort(); + + // TODO: Perform a graceful shutdown instead of this? 
+ // Review https://tokio.rs/tokio/topics/shutdown + std::process::exit(0); + } + }; + + if let Err(x) = reply.send(response).await { + error!("[Conn {}] {}", connection_id, x); + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use distant_net::{ + AuthClient, FramedTransport, HeapAuthServer, InmemoryTransport, IntoSplit, MappedListener, + OneshotListener, PlainCodec, ServerExt, ServerRef, + }; + + /// Create a new server, bypassing the start loop + fn setup() -> DistantManager { + let (_, rx) = mpsc::channel(1); + DistantManager { + auth_client_rx: Mutex::new(rx), + config: Default::default(), + connections: RwLock::new(HashMap::new()), + launch_handlers: Arc::new(RwLock::new(HashMap::new())), + connect_handlers: Arc::new(RwLock::new(HashMap::new())), + task: tokio::spawn(async move {}), + } + } + + /// Creates a connected [`AuthClient`] with a launched auth server that blindly responds + fn auth_client_server() -> (AuthClient, Box) { + let (t1, t2) = FramedTransport::pair(1); + let client = AuthClient::from(Client::from_framed_transport(t1).unwrap()); + + // Create a server that does nothing, but will support + let server = HeapAuthServer { + on_challenge: Box::new(|_, _| Vec::new()), + on_verify: Box::new(|_, _| false), + on_info: Box::new(|_| ()), + on_error: Box::new(|_, _| ()), + } + .start(MappedListener::new(OneshotListener::from_value(t2), |t| { + t.into_split() + })) + .unwrap(); + + (client, server) + } + + fn dummy_distant_writer_reader() -> (BoxedDistantWriter, BoxedDistantReader) { + setup_distant_writer_reader().0 + } + + /// Creates a writer & reader with a connected transport + fn setup_distant_writer_reader() -> ( + (BoxedDistantWriter, BoxedDistantReader), + FramedTransport, + ) { + let (t1, t2) = FramedTransport::pair(1); + let (writer, reader) = t1.into_split(); + ((Box::new(writer), Box::new(reader)), t2) + } + + #[tokio::test] + async fn launch_should_fail_if_destination_scheme_is_unsupported() { + let server = setup(); + + let 
destination = "scheme://host".parse::().unwrap(); + let extra = "".parse::().unwrap(); + let (mut auth, _auth_server) = auth_client_server(); + let err = server + .launch(destination, extra, Some(&mut auth)) + .await + .unwrap_err(); + assert_eq!(err.kind(), io::ErrorKind::InvalidInput, "{:?}", err); + } + + #[tokio::test] + async fn launch_should_fail_if_handler_tied_to_scheme_fails() { + let server = setup(); + + let handler: Box = Box::new(|_: &_, _: &_, _: &mut _| async { + Err(io::Error::new(io::ErrorKind::Other, "test failure")) + }); + + server + .launch_handlers + .write() + .await + .insert("scheme".to_string(), handler); + + let destination = "scheme://host".parse::().unwrap(); + let extra = "".parse::().unwrap(); + let (mut auth, _auth_server) = auth_client_server(); + let err = server + .launch(destination, extra, Some(&mut auth)) + .await + .unwrap_err(); + assert_eq!(err.kind(), io::ErrorKind::Other); + assert_eq!(err.to_string(), "test failure"); + } + + #[tokio::test] + async fn launch_should_return_new_destination_on_success() { + let server = setup(); + + let handler: Box = { + Box::new(|_: &_, _: &_, _: &mut _| async { + Ok("scheme2://host2".parse::().unwrap()) + }) + }; + + server + .launch_handlers + .write() + .await + .insert("scheme".to_string(), handler); + + let destination = "scheme://host".parse::().unwrap(); + let extra = "key=value".parse::().unwrap(); + let (mut auth, _auth_server) = auth_client_server(); + let destination = server + .launch(destination, extra, Some(&mut auth)) + .await + .unwrap(); + + assert_eq!( + destination, + "scheme2://host2".parse::().unwrap() + ); + } + + #[tokio::test] + async fn connect_should_fail_if_destination_scheme_is_unsupported() { + let server = setup(); + + let destination = "scheme://host".parse::().unwrap(); + let extra = "".parse::().unwrap(); + let (mut auth, _auth_server) = auth_client_server(); + let err = server + .connect(destination, extra, Some(&mut auth)) + .await + .unwrap_err(); + 
assert_eq!(err.kind(), io::ErrorKind::InvalidInput, "{:?}", err); + } + + #[tokio::test] + async fn connect_should_fail_if_handler_tied_to_scheme_fails() { + let server = setup(); + + let handler: Box = Box::new(|_: &_, _: &_, _: &mut _| async { + Err(io::Error::new(io::ErrorKind::Other, "test failure")) + }); + + server + .connect_handlers + .write() + .await + .insert("scheme".to_string(), handler); + + let destination = "scheme://host".parse::().unwrap(); + let extra = "".parse::().unwrap(); + let (mut auth, _auth_server) = auth_client_server(); + let err = server + .connect(destination, extra, Some(&mut auth)) + .await + .unwrap_err(); + assert_eq!(err.kind(), io::ErrorKind::Other); + assert_eq!(err.to_string(), "test failure"); + } + + #[tokio::test] + async fn connect_should_return_id_of_new_connection_on_success() { + let server = setup(); + + let handler: Box = + Box::new(|_: &_, _: &_, _: &mut _| async { Ok(dummy_distant_writer_reader()) }); + + server + .connect_handlers + .write() + .await + .insert("scheme".to_string(), handler); + + let destination = "scheme://host".parse::().unwrap(); + let extra = "key=value".parse::().unwrap(); + let (mut auth, _auth_server) = auth_client_server(); + let id = server + .connect(destination, extra, Some(&mut auth)) + .await + .unwrap(); + + let lock = server.connections.read().await; + let connection = lock.get(&id).unwrap(); + assert_eq!(connection.id, id); + assert_eq!(connection.destination, "scheme://host".parse().unwrap()); + assert_eq!(connection.extra, "key=value".parse().unwrap()); + } + + #[tokio::test] + async fn info_should_fail_if_no_connection_found_for_specified_id() { + let server = setup(); + + let err = server.info(999).await.unwrap_err(); + assert_eq!(err.kind(), io::ErrorKind::NotConnected, "{:?}", err); + } + + #[tokio::test] + async fn info_should_return_information_about_established_connection() { + let server = setup(); + + let (writer, reader) = dummy_distant_writer_reader(); + let connection = 
DistantManagerConnection::new( + "scheme://host".parse().unwrap(), + "key=value".parse().unwrap(), + writer, + reader, + ); + let id = connection.id; + server.connections.write().await.insert(id, connection); + + let info = server.info(id).await.unwrap(); + assert_eq!( + info, + ConnectionInfo { + id, + destination: "scheme://host".parse().unwrap(), + extra: "key=value".parse().unwrap(), + } + ); + } + + #[tokio::test] + async fn list_should_return_empty_connection_list_if_no_established_connections() { + let server = setup(); + + let list = server.list().await.unwrap(); + assert_eq!(list, ConnectionList(HashMap::new())); + } + + #[tokio::test] + async fn list_should_return_a_list_of_established_connections() { + let server = setup(); + + let (writer, reader) = dummy_distant_writer_reader(); + let connection = DistantManagerConnection::new( + "scheme://host".parse().unwrap(), + "key=value".parse().unwrap(), + writer, + reader, + ); + let id_1 = connection.id; + server.connections.write().await.insert(id_1, connection); + + let (writer, reader) = dummy_distant_writer_reader(); + let connection = DistantManagerConnection::new( + "other://host2".parse().unwrap(), + "key=value".parse().unwrap(), + writer, + reader, + ); + let id_2 = connection.id; + server.connections.write().await.insert(id_2, connection); + + let list = server.list().await.unwrap(); + assert_eq!( + list.get(&id_1).unwrap(), + &"scheme://host".parse::().unwrap() + ); + assert_eq!( + list.get(&id_2).unwrap(), + &"other://host2".parse::().unwrap() + ); + } + + #[tokio::test] + async fn kill_should_fail_if_no_connection_found_for_specified_id() { + let server = setup(); + + let err = server.kill(999).await.unwrap_err(); + assert_eq!(err.kind(), io::ErrorKind::NotConnected, "{:?}", err); + } + + #[tokio::test] + async fn kill_should_terminate_established_connection_and_remove_it_from_the_list() { + let server = setup(); + + let (writer, reader) = dummy_distant_writer_reader(); + let connection = 
DistantManagerConnection::new( + "scheme://host".parse().unwrap(), + "key=value".parse().unwrap(), + writer, + reader, + ); + let id = connection.id; + server.connections.write().await.insert(id, connection); + + server.kill(id).await.unwrap(); + + let lock = server.connections.read().await; + assert!(!lock.contains_key(&id), "Connection still exists"); + } +} diff --git a/distant-core/src/manager/server/config.rs b/distant-core/src/manager/server/config.rs new file mode 100644 index 0000000..207cb27 --- /dev/null +++ b/distant-core/src/manager/server/config.rs @@ -0,0 +1,31 @@ +use crate::{BoxedConnectHandler, BoxedLaunchHandler}; +use std::collections::HashMap; + +pub struct DistantManagerConfig { + /// Scheme to use when none is provided in a destination + pub fallback_scheme: String, + + /// Buffer size for queue of incoming connections before blocking + pub connection_buffer_size: usize, + + /// If listening as local user + pub user: bool, + + /// Handlers to use for launch requests + pub launch_handlers: HashMap, + + /// Handlers to use for connect requests + pub connect_handlers: HashMap, +} + +impl Default for DistantManagerConfig { + fn default() -> Self { + Self { + fallback_scheme: "distant".to_string(), + connection_buffer_size: 100, + user: false, + launch_handlers: HashMap::new(), + connect_handlers: HashMap::new(), + } + } +} diff --git a/distant-core/src/manager/server/connection.rs b/distant-core/src/manager/server/connection.rs new file mode 100644 index 0000000..112f1cf --- /dev/null +++ b/distant-core/src/manager/server/connection.rs @@ -0,0 +1,201 @@ +use crate::{ + manager::{ + data::{ChannelId, ConnectionId, Destination, Extra}, + BoxedDistantReader, BoxedDistantWriter, + }, + DistantMsg, DistantRequestData, DistantResponseData, ManagerResponse, +}; +use distant_net::{Request, Response, ServerReply}; +use log::*; +use std::{collections::HashMap, io}; +use tokio::{sync::mpsc, task::JoinHandle}; + +/// Represents a connection a distant manager 
has with some distant-compatible server +pub struct DistantManagerConnection { + pub id: ConnectionId, + pub destination: Destination, + pub extra: Extra, + tx: mpsc::Sender, + reader_task: JoinHandle<()>, + writer_task: JoinHandle<()>, +} + +#[derive(Clone)] +pub struct DistantManagerChannel { + channel_id: ChannelId, + tx: mpsc::Sender, +} + +impl DistantManagerChannel { + pub fn id(&self) -> ChannelId { + self.channel_id + } + + pub async fn send(&self, request: Request>) -> io::Result<()> { + let channel_id = self.channel_id; + self.tx + .send(StateMachine::Write { + id: channel_id, + request, + }) + .await + .map_err(|x| { + io::Error::new( + io::ErrorKind::BrokenPipe, + format!("channel {} send failed: {}", channel_id, x), + ) + }) + } + + pub async fn close(&self) -> io::Result<()> { + let channel_id = self.channel_id; + self.tx + .send(StateMachine::Unregister { id: channel_id }) + .await + .map_err(|x| { + io::Error::new( + io::ErrorKind::BrokenPipe, + format!("channel {} close failed: {}", channel_id, x), + ) + }) + } +} + +enum StateMachine { + Register { + id: ChannelId, + reply: ServerReply, + }, + + Unregister { + id: ChannelId, + }, + + Read { + response: Response>, + }, + + Write { + id: ChannelId, + request: Request>, + }, +} + +impl DistantManagerConnection { + pub fn new( + destination: Destination, + extra: Extra, + mut writer: BoxedDistantWriter, + mut reader: BoxedDistantReader, + ) -> Self { + let connection_id = rand::random(); + let (tx, mut rx) = mpsc::channel(1); + let reader_task = { + let tx = tx.clone(); + tokio::spawn(async move { + loop { + match reader.read().await { + Ok(Some(response)) => { + if tx.send(StateMachine::Read { response }).await.is_err() { + break; + } + } + Ok(None) => break, + Err(x) => { + error!("[Conn {}] {}", connection_id, x); + continue; + } + } + } + }) + }; + let writer_task = tokio::spawn(async move { + let mut registered = HashMap::new(); + while let Some(state_machine) = rx.recv().await { + match 
state_machine { + StateMachine::Register { id, reply } => { + registered.insert(id, reply); + } + StateMachine::Unregister { id } => { + registered.remove(&id); + } + StateMachine::Read { mut response } => { + // Split {channel id}_{request id} back into pieces and + // update the origin id to match the request id only + let channel_id = match response.origin_id.split_once('_') { + Some((cid_str, oid_str)) => { + if let Ok(cid) = cid_str.parse::() { + response.origin_id = oid_str.to_string(); + cid + } else { + continue; + } + } + None => continue, + }; + + if let Some(reply) = registered.get(&channel_id) { + let response = ManagerResponse::Channel { + id: channel_id, + response, + }; + if let Err(x) = reply.send(response).await { + error!("[Conn {}] {}", connection_id, x); + } + } + } + StateMachine::Write { id, request } => { + // Combine channel id with request id so we can properly forward + // the response containing this in the origin id + let request = Request { + id: format!("{}_{}", id, request.id), + payload: request.payload, + }; + if let Err(x) = writer.write(request).await { + error!("[Conn {}] {}", connection_id, x); + } + } + } + } + }); + + Self { + id: connection_id, + destination, + extra, + tx, + reader_task, + writer_task, + } + } + + pub async fn open_channel( + &self, + reply: ServerReply, + ) -> io::Result { + let channel_id = rand::random(); + self.tx + .send(StateMachine::Register { + id: channel_id, + reply, + }) + .await + .map_err(|x| { + io::Error::new( + io::ErrorKind::BrokenPipe, + format!("open_channel failed: {}", x), + ) + })?; + Ok(DistantManagerChannel { + channel_id, + tx: self.tx.clone(), + }) + } +} + +impl Drop for DistantManagerConnection { + fn drop(&mut self) { + self.reader_task.abort(); + self.writer_task.abort(); + } +} diff --git a/distant-core/src/manager/server/ext.rs b/distant-core/src/manager/server/ext.rs new file mode 100644 index 0000000..d23a3d2 --- /dev/null +++ b/distant-core/src/manager/server/ext.rs @@ -0,0 
+1,14 @@ +mod tcp; +pub use tcp::*; + +#[cfg(unix)] +mod unix; + +#[cfg(unix)] +pub use unix::*; + +#[cfg(windows)] +mod windows; + +#[cfg(windows)] +pub use windows::*; diff --git a/distant-core/src/manager/server/ext/tcp.rs b/distant-core/src/manager/server/ext/tcp.rs new file mode 100644 index 0000000..f9a2f6d --- /dev/null +++ b/distant-core/src/manager/server/ext/tcp.rs @@ -0,0 +1,30 @@ +use crate::{DistantManager, DistantManagerConfig}; +use distant_net::{ + Codec, FramedTransport, IntoSplit, MappedListener, PortRange, TcpListener, TcpServerRef, +}; +use std::{io, net::IpAddr}; + +impl DistantManager { + /// Start a new server by binding to the given IP address and one of the ports in the + /// specified range, mapping all connections to use the given codec + pub async fn start_tcp( + config: DistantManagerConfig, + addr: IpAddr, + port: P, + codec: C, + ) -> io::Result + where + P: Into + Send, + C: Codec + Send + Sync + 'static, + { + let listener = TcpListener::bind(addr, port).await?; + let port = listener.port(); + + let listener = MappedListener::new(listener, move |transport| { + let transport = FramedTransport::new(transport, codec.clone()); + transport.into_split() + }); + let inner = DistantManager::start(config, listener)?; + Ok(TcpServerRef::new(addr, port, Box::new(inner))) + } +} diff --git a/distant-core/src/manager/server/ext/unix.rs b/distant-core/src/manager/server/ext/unix.rs new file mode 100644 index 0000000..fec9743 --- /dev/null +++ b/distant-core/src/manager/server/ext/unix.rs @@ -0,0 +1,50 @@ +use crate::{DistantManager, DistantManagerConfig}; +use distant_net::{ + Codec, FramedTransport, IntoSplit, MappedListener, UnixSocketListener, UnixSocketServerRef, +}; +use std::{io, path::Path}; + +impl DistantManager { + /// Start a new server using the specified path as a unix socket using default unix socket file + /// permissions + pub async fn start_unix_socket( + config: DistantManagerConfig, + path: P, + codec: C, + ) -> io::Result + 
where + P: AsRef + Send, + C: Codec + Send + Sync + 'static, + { + Self::start_unix_socket_with_permissions( + config, + path, + codec, + UnixSocketListener::default_unix_socket_file_permissions(), + ) + .await + } + + /// Start a new server using the specified path as a unix socket and `mode` as the unix socket + /// file permissions + pub async fn start_unix_socket_with_permissions( + config: DistantManagerConfig, + path: P, + codec: C, + mode: u32, + ) -> io::Result + where + P: AsRef + Send, + C: Codec + Send + Sync + 'static, + { + let listener = UnixSocketListener::bind_with_permissions(path, mode).await?; + let path = listener.path().to_path_buf(); + + let listener = MappedListener::new(listener, move |transport| { + let transport = FramedTransport::new(transport, codec.clone()); + transport.into_split() + }); + let inner = DistantManager::start(config, listener)?; + Ok(UnixSocketServerRef::new(path, Box::new(inner))) + } +} diff --git a/distant-core/src/manager/server/ext/windows.rs b/distant-core/src/manager/server/ext/windows.rs new file mode 100644 index 0000000..537bbfe --- /dev/null +++ b/distant-core/src/manager/server/ext/windows.rs @@ -0,0 +1,48 @@ +use crate::{DistantManager, DistantManagerConfig}; +use distant_net::{ + Codec, FramedTransport, IntoSplit, MappedListener, WindowsPipeListener, WindowsPipeServerRef, +}; +use std::{ + ffi::{OsStr, OsString}, + io, +}; + +impl DistantManager { + /// Start a new server at the specified address via `\\.\pipe\{name}` using the given codec + pub async fn start_local_named_pipe( + config: DistantManagerConfig, + name: N, + codec: C, + ) -> io::Result + where + Self: Sized, + N: AsRef + Send, + C: Codec + Send + Sync + 'static, + { + let mut addr = OsString::from(r"\\.\pipe\"); + addr.push(name.as_ref()); + Self::start_named_pipe(config, addr, codec).await + } + + /// Start a new server at the specified pipe address using the given codec + pub async fn start_named_pipe( + config: DistantManagerConfig, + addr: 
A, + codec: C, + ) -> io::Result + where + A: AsRef + Send, + C: Codec + Send + Sync + 'static, + { + let a = addr.as_ref(); + let listener = WindowsPipeListener::bind(a)?; + let addr = listener.addr().to_os_string(); + + let listener = MappedListener::new(listener, move |transport| { + let transport = FramedTransport::new(transport, codec.clone()); + transport.into_split() + }); + let inner = DistantManager::start(config, listener)?; + Ok(WindowsPipeServerRef::new(addr, Box::new(inner))) + } +} diff --git a/distant-core/src/manager/server/handler.rs b/distant-core/src/manager/server/handler.rs new file mode 100644 index 0000000..a396f2e --- /dev/null +++ b/distant-core/src/manager/server/handler.rs @@ -0,0 +1,69 @@ +use crate::{ + manager::data::{Destination, Extra}, + DistantMsg, DistantRequestData, DistantResponseData, +}; +use async_trait::async_trait; +use distant_net::{AuthClient, Request, Response, TypedAsyncRead, TypedAsyncWrite}; +use std::{future::Future, io}; + +pub type BoxedDistantWriter = + Box>> + Send>; +pub type BoxedDistantReader = + Box>> + Send>; +pub type BoxedDistantWriterReader = (BoxedDistantWriter, BoxedDistantReader); +pub type BoxedLaunchHandler = Box; +pub type BoxedConnectHandler = Box; + +/// Used to launch a server at the specified destination, returning some result as a vec of bytes +#[async_trait] +pub trait LaunchHandler: Send + Sync { + async fn launch( + &self, + destination: &Destination, + extra: &Extra, + auth_client: &mut AuthClient, + ) -> io::Result; +} + +#[async_trait] +impl LaunchHandler for F +where + F: for<'a> Fn(&'a Destination, &'a Extra, &'a mut AuthClient) -> R + Send + Sync + 'static, + R: Future> + Send + 'static, +{ + async fn launch( + &self, + destination: &Destination, + extra: &Extra, + auth_client: &mut AuthClient, + ) -> io::Result { + self(destination, extra, auth_client).await + } +} + +/// Used to connect to a destination, returning a connected reader and writer pair +#[async_trait] +pub trait 
ConnectHandler: Send + Sync { + async fn connect( + &self, + destination: &Destination, + extra: &Extra, + auth_client: &mut AuthClient, + ) -> io::Result; +} + +#[async_trait] +impl ConnectHandler for F +where + F: for<'a> Fn(&'a Destination, &'a Extra, &'a mut AuthClient) -> R + Send + Sync + 'static, + R: Future> + Send + 'static, +{ + async fn connect( + &self, + destination: &Destination, + extra: &Extra, + auth_client: &mut AuthClient, + ) -> io::Result { + self(destination, extra, auth_client).await + } +} diff --git a/distant-core/src/manager/server/ref.rs b/distant-core/src/manager/server/ref.rs new file mode 100644 index 0000000..360a00f --- /dev/null +++ b/distant-core/src/manager/server/ref.rs @@ -0,0 +1,73 @@ +use super::{BoxedConnectHandler, BoxedLaunchHandler, ConnectHandler, LaunchHandler}; +use distant_net::{ServerRef, ServerState}; +use std::{collections::HashMap, io, sync::Weak}; +use tokio::sync::RwLock; + +/// Reference to a distant manager's server instance +pub struct DistantManagerRef { + /// Mapping of "scheme" -> handler + pub(crate) launch_handlers: Weak>>, + + /// Mapping of "scheme" -> handler + pub(crate) connect_handlers: Weak>>, + + pub(crate) inner: Box, +} + +impl DistantManagerRef { + /// Registers a new [`LaunchHandler`] for the specified scheme (e.g. "distant" or "ssh") + pub async fn register_launch_handler( + &self, + scheme: impl Into, + handler: impl LaunchHandler + 'static, + ) -> io::Result<()> { + let handlers = Weak::upgrade(&self.launch_handlers).ok_or_else(|| { + io::Error::new( + io::ErrorKind::Other, + "Handler reference is no longer available", + ) + })?; + + handlers + .write() + .await + .insert(scheme.into(), Box::new(handler)); + + Ok(()) + } + + /// Registers a new [`ConnectHandler`] for the specified scheme (e.g. 
"distant" or "ssh") + pub async fn register_connect_handler( + &self, + scheme: impl Into, + handler: impl ConnectHandler + 'static, + ) -> io::Result<()> { + let handlers = Weak::upgrade(&self.connect_handlers).ok_or_else(|| { + io::Error::new( + io::ErrorKind::Other, + "Handler reference is no longer available", + ) + })?; + + handlers + .write() + .await + .insert(scheme.into(), Box::new(handler)); + + Ok(()) + } +} + +impl ServerRef for DistantManagerRef { + fn state(&self) -> &ServerState { + self.inner.state() + } + + fn is_finished(&self) -> bool { + self.inner.is_finished() + } + + fn abort(&self) { + self.inner.abort(); + } +} diff --git a/distant-core/src/net/listener.rs b/distant-core/src/net/listener.rs deleted file mode 100644 index b10f986..0000000 --- a/distant-core/src/net/listener.rs +++ /dev/null @@ -1,162 +0,0 @@ -use super::{Codec, DataStream, Transport}; -use futures::stream::Stream; -use log::*; -use std::{future::Future, pin::Pin}; -use tokio::{ - io, - net::{TcpListener, TcpStream}, - sync::mpsc, - task::JoinHandle, -}; - -/// Represents a [`Stream`] consisting of newly-connected [`DataStream`] instances that -/// have been wrapped in [`Transport`] -pub struct TransportListener -where - T: DataStream, - U: Codec, -{ - listen_task: JoinHandle<()>, - accept_task: JoinHandle<()>, - rx: mpsc::Receiver>, -} - -impl TransportListener -where - T: DataStream + Send + 'static, - U: Codec + Send + 'static, -{ - pub fn initialize(listener: L, mut make_transport: F) -> Self - where - L: Listener + 'static, - F: FnMut(T) -> Transport + Send + 'static, - { - let (stream_tx, mut stream_rx) = mpsc::channel::(1); - let listen_task = tokio::spawn(async move { - loop { - match listener.accept().await { - Ok(stream) => { - if stream_tx.send(stream).await.is_err() { - error!("Listener failed to pass along stream"); - break; - } - } - Err(x) => { - error!("Listener failed to accept stream: {}", x); - break; - } - } - } - }); - - let (tx, rx) = 
mpsc::channel::>(1); - let accept_task = tokio::spawn(async move { - // Check if we have a new connection. If so, wrap it in a transport and forward - // it along to - while let Some(stream) = stream_rx.recv().await { - let transport = make_transport(stream); - if let Err(x) = tx.send(transport).await { - error!("Failed to forward transport: {}", x); - } - } - }); - - Self { - listen_task, - accept_task, - rx, - } - } - - pub fn abort(&self) { - self.listen_task.abort(); - self.accept_task.abort(); - } - - /// Waits for the next fully-initialized transport for an incoming stream to be available, - /// returning none if no longer accepting new connections - pub async fn accept(&mut self) -> Option> { - self.rx.recv().await - } - - /// Converts into a stream of transport-wrapped connections - pub fn into_stream(self) -> impl Stream> { - futures::stream::unfold(self, |mut _self| async move { - _self - .accept() - .await - .map(move |transport| (transport, _self)) - }) - } -} - -pub type AcceptFuture<'a, T> = Pin> + Send + 'a>>; - -/// Represents a type that has a listen interface for receiving raw streams -pub trait Listener: Send + Sync { - type Output; - - fn accept<'a>(&'a self) -> AcceptFuture<'a, Self::Output> - where - Self: Sync + 'a; -} - -impl Listener for TcpListener { - type Output = TcpStream; - - fn accept<'a>(&'a self) -> AcceptFuture<'a, Self::Output> - where - Self: Sync + 'a, - { - async fn accept(_self: &TcpListener) -> io::Result { - _self.accept().await.map(|(stream, _)| stream) - } - - Box::pin(accept(self)) - } -} - -#[cfg(unix)] -impl Listener for tokio::net::UnixListener { - type Output = tokio::net::UnixStream; - - fn accept<'a>(&'a self) -> AcceptFuture<'a, Self::Output> - where - Self: Sync + 'a, - { - async fn accept(_self: &tokio::net::UnixListener) -> io::Result { - _self.accept().await.map(|(stream, _)| stream) - } - - Box::pin(accept(self)) - } -} - -#[cfg(test)] -impl Listener for tokio::sync::Mutex> -where - T: DataStream + Send + 
Sync + 'static, -{ - type Output = T; - - fn accept<'a>(&'a self) -> AcceptFuture<'a, Self::Output> - where - Self: Sync + 'a, - { - async fn accept( - _self: &tokio::sync::Mutex>, - ) -> io::Result - where - T: DataStream + Send + Sync + 'static, - { - _self - .lock() - .await - .recv() - .await - .ok_or_else(|| io::Error::from(io::ErrorKind::BrokenPipe)) - } - - Box::pin(accept(self)) - } -} diff --git a/distant-core/src/net/transport/inmemory.rs b/distant-core/src/net/transport/inmemory.rs deleted file mode 100644 index a1dd397..0000000 --- a/distant-core/src/net/transport/inmemory.rs +++ /dev/null @@ -1,583 +0,0 @@ -use super::{DataStream, PlainCodec, Transport}; -use futures::ready; -use std::{ - fmt, - future::Future, - pin::Pin, - task::{Context, Poll}, -}; -use tokio::{ - io::{self, AsyncRead, AsyncWrite, ReadBuf}, - sync::mpsc, -}; - -/// Represents a data stream comprised of two inmemory channels -#[derive(Debug)] -pub struct InmemoryStream { - incoming: InmemoryStreamReadHalf, - outgoing: InmemoryStreamWriteHalf, -} - -impl InmemoryStream { - pub fn new(incoming: mpsc::Receiver>, outgoing: mpsc::Sender>) -> Self { - Self { - incoming: InmemoryStreamReadHalf::new(incoming), - outgoing: InmemoryStreamWriteHalf::new(outgoing), - } - } - - /// Returns (incoming_tx, outgoing_rx, stream) - pub fn make(buffer: usize) -> (mpsc::Sender>, mpsc::Receiver>, Self) { - let (incoming_tx, incoming_rx) = mpsc::channel(buffer); - let (outgoing_tx, outgoing_rx) = mpsc::channel(buffer); - - ( - incoming_tx, - outgoing_rx, - Self::new(incoming_rx, outgoing_tx), - ) - } - - /// Returns pair of streams that are connected such that one sends to the other and - /// vice versa - pub fn pair(buffer: usize) -> (Self, Self) { - let (tx, rx, stream) = Self::make(buffer); - (stream, Self::new(rx, tx)) - } -} - -impl AsyncRead for InmemoryStream { - fn poll_read( - mut self: Pin<&mut Self>, - cx: &mut Context<'_>, - buf: &mut ReadBuf<'_>, - ) -> Poll> { - Pin::new(&mut 
self.incoming).poll_read(cx, buf) - } -} - -impl AsyncWrite for InmemoryStream { - fn poll_write( - mut self: Pin<&mut Self>, - cx: &mut Context<'_>, - buf: &[u8], - ) -> Poll> { - Pin::new(&mut self.outgoing).poll_write(cx, buf) - } - - fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - Pin::new(&mut self.outgoing).poll_flush(cx) - } - - fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - Pin::new(&mut self.outgoing).poll_shutdown(cx) - } -} - -/// Read portion of an inmemory channel -#[derive(Debug)] -pub struct InmemoryStreamReadHalf { - rx: mpsc::Receiver>, - overflow: Vec, -} - -impl InmemoryStreamReadHalf { - pub fn new(rx: mpsc::Receiver>) -> Self { - Self { - rx, - overflow: Vec::new(), - } - } -} - -impl AsyncRead for InmemoryStreamReadHalf { - fn poll_read( - mut self: Pin<&mut Self>, - cx: &mut Context<'_>, - buf: &mut ReadBuf<'_>, - ) -> Poll> { - // If we cannot fit any more into the buffer at the moment, we wait - if buf.remaining() == 0 { - return Poll::Ready(Err(io::Error::new( - io::ErrorKind::Other, - "Cannot poll as buf.remaining() == 0", - ))); - } - - // If we have overflow from the last poll, put that in the buffer - if !self.overflow.is_empty() { - if self.overflow.len() > buf.remaining() { - let extra = self.overflow.split_off(buf.remaining()); - buf.put_slice(&self.overflow); - self.overflow = extra; - } else { - buf.put_slice(&self.overflow); - self.overflow.clear(); - } - - return Poll::Ready(Ok(())); - } - - // Otherwise, we poll for the next batch to read in - match ready!(self.rx.poll_recv(cx)) { - Some(mut x) => { - if x.len() > buf.remaining() { - self.overflow = x.split_off(buf.remaining()); - } - buf.put_slice(&x); - Poll::Ready(Ok(())) - } - None => Poll::Ready(Ok(())), - } - } -} - -/// Write portion of an inmemory channel -pub struct InmemoryStreamWriteHalf { - tx: Option>>, - task: Option> + Send + Sync + 'static>>>, -} - -impl InmemoryStreamWriteHalf { - pub fn new(tx: 
mpsc::Sender>) -> Self { - Self { - tx: Some(tx), - task: None, - } - } -} - -impl fmt::Debug for InmemoryStreamWriteHalf { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - f.debug_struct("InmemoryStreamWriteHalf") - .field("tx", &self.tx) - .field( - "task", - &if self.tx.is_some() { - "Some(...)" - } else { - "None" - }, - ) - .finish() - } -} - -impl AsyncWrite for InmemoryStreamWriteHalf { - fn poll_write( - mut self: Pin<&mut Self>, - cx: &mut Context<'_>, - buf: &[u8], - ) -> Poll> { - loop { - match self.task.as_mut() { - Some(task) => { - let res = ready!(task.as_mut().poll(cx)); - self.task.take(); - return Poll::Ready(res); - } - None => match self.tx.as_mut() { - Some(tx) => { - let n = buf.len(); - let tx_2 = tx.clone(); - let data = buf.to_vec(); - let task = - Box::pin(async move { tx_2.send(data).await.map(|_| n).or(Ok(0)) }); - self.task.replace(task); - } - None => return Poll::Ready(Ok(0)), - }, - } - } - } - - fn poll_flush(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll> { - Poll::Ready(Ok(())) - } - - fn poll_shutdown(mut self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll> { - self.tx.take(); - self.task.take(); - Poll::Ready(Ok(())) - } -} - -impl DataStream for InmemoryStream { - type Read = InmemoryStreamReadHalf; - type Write = InmemoryStreamWriteHalf; - - fn to_connection_tag(&self) -> String { - String::from("inmemory-stream") - } - - fn into_split(self) -> (Self::Read, Self::Write) { - (self.incoming, self.outgoing) - } -} - -impl Transport { - /// Produces a pair of inmemory transports that are connected to each other using - /// a standard codec - /// - /// Sets the buffer for message passing for each underlying stream to the given buffer size - pub fn pair( - buffer: usize, - ) -> ( - Transport, - Transport, - ) { - let (a, b) = InmemoryStream::pair(buffer); - let a = Transport::new(a, PlainCodec::new()); - let b = Transport::new(b, PlainCodec::new()); - (a, b) - } -} - -#[cfg(test)] -mod tests { - use super::*; - use 
tokio::io::{AsyncReadExt, AsyncWriteExt}; - - #[test] - fn to_connection_tag_should_be_hardcoded_string() { - let (_, _, stream) = InmemoryStream::make(1); - assert_eq!(stream.to_connection_tag(), "inmemory-stream"); - } - - #[tokio::test] - async fn make_should_return_sender_that_sends_data_to_stream() { - let (tx, _, mut stream) = InmemoryStream::make(3); - - tx.send(b"test msg 1".to_vec()).await.unwrap(); - tx.send(b"test msg 2".to_vec()).await.unwrap(); - tx.send(b"test msg 3".to_vec()).await.unwrap(); - - // Should get data matching a singular message - let mut buf = [0; 256]; - let len = stream.read(&mut buf).await.unwrap(); - assert_eq!(&buf[..len], b"test msg 1"); - - // Next call would get the second message - let len = stream.read(&mut buf).await.unwrap(); - assert_eq!(&buf[..len], b"test msg 2"); - - // When the last of the senders is dropped, we should still get - // the rest of the data that was sent first before getting - // an indicator that there is no more data - drop(tx); - - let len = stream.read(&mut buf).await.unwrap(); - assert_eq!(&buf[..len], b"test msg 3"); - - let len = stream.read(&mut buf).await.unwrap(); - assert_eq!(len, 0, "Unexpectedly got more data"); - } - - #[tokio::test] - async fn make_should_return_receiver_that_receives_data_from_stream() { - let (_, mut rx, mut stream) = InmemoryStream::make(3); - - stream.write_all(b"test msg 1").await.unwrap(); - stream.write_all(b"test msg 2").await.unwrap(); - stream.write_all(b"test msg 3").await.unwrap(); - - // Should get data matching a singular message - assert_eq!(rx.recv().await, Some(b"test msg 1".to_vec())); - - // Next call would get the second message - assert_eq!(rx.recv().await, Some(b"test msg 2".to_vec())); - - // When the stream is dropped, we should still get - // the rest of the data that was sent first before getting - // an indicator that there is no more data - drop(stream); - - assert_eq!(rx.recv().await, Some(b"test msg 3".to_vec())); - - assert_eq!(rx.recv().await, 
None, "Unexpectedly got more data"); - } - - #[tokio::test] - async fn into_split_should_provide_a_read_half_that_receives_from_sender() { - let (tx, _, stream) = InmemoryStream::make(3); - let (mut read_half, _) = stream.into_split(); - - tx.send(b"test msg 1".to_vec()).await.unwrap(); - tx.send(b"test msg 2".to_vec()).await.unwrap(); - tx.send(b"test msg 3".to_vec()).await.unwrap(); - - // Should get data matching a singular message - let mut buf = [0; 256]; - let len = read_half.read(&mut buf).await.unwrap(); - assert_eq!(&buf[..len], b"test msg 1"); - - // Next call would get the second message - let len = read_half.read(&mut buf).await.unwrap(); - assert_eq!(&buf[..len], b"test msg 2"); - - // When the last of the senders is dropped, we should still get - // the rest of the data that was sent first before getting - // an indicator that there is no more data - drop(tx); - - let len = read_half.read(&mut buf).await.unwrap(); - assert_eq!(&buf[..len], b"test msg 3"); - - let len = read_half.read(&mut buf).await.unwrap(); - assert_eq!(len, 0, "Unexpectedly got more data"); - } - - #[tokio::test] - async fn into_split_should_provide_a_write_half_that_sends_to_receiver() { - let (_, mut rx, stream) = InmemoryStream::make(3); - let (_, mut write_half) = stream.into_split(); - - write_half.write_all(b"test msg 1").await.unwrap(); - write_half.write_all(b"test msg 2").await.unwrap(); - write_half.write_all(b"test msg 3").await.unwrap(); - - // Should get data matching a singular message - assert_eq!(rx.recv().await, Some(b"test msg 1".to_vec())); - - // Next call would get the second message - assert_eq!(rx.recv().await, Some(b"test msg 2".to_vec())); - - // When the stream is dropped, we should still get - // the rest of the data that was sent first before getting - // an indicator that there is no more data - drop(write_half); - - assert_eq!(rx.recv().await, Some(b"test msg 3".to_vec())); - - assert_eq!(rx.recv().await, None, "Unexpectedly got more data"); - } - - 
#[tokio::test] - async fn read_half_should_fail_if_buf_has_no_space_remaining() { - let (_tx, _rx, stream) = InmemoryStream::make(1); - let (mut t_read, _t_write) = stream.into_split(); - - let mut buf = [0u8; 0]; - match t_read.read(&mut buf).await { - Err(x) if x.kind() == io::ErrorKind::Other => {} - x => panic!("Unexpected result: {:?}", x), - } - } - - #[tokio::test] - async fn read_half_should_update_buf_with_all_overflow_from_last_read_if_it_all_fits() { - let (tx, _rx, stream) = InmemoryStream::make(1); - let (mut t_read, _t_write) = stream.into_split(); - - tx.send(vec![1, 2, 3]).await.expect("Failed to send"); - - let mut buf = [0u8; 2]; - - // First, read part of the data (first two bytes) - match t_read.read(&mut buf).await { - Ok(n) if n == 2 => assert_eq!(&buf[..n], &[1, 2]), - x => panic!("Unexpected result: {:?}", x), - } - - // Second, we send more data because the last message was placed in overflow - tx.send(vec![4, 5, 6]).await.expect("Failed to send"); - - // Third, read remainder of the overflow from first message (third byte) - match t_read.read(&mut buf).await { - Ok(n) if n == 1 => assert_eq!(&buf[..n], &[3]), - x => panic!("Unexpected result: {:?}", x), - } - - // Fourth, verify that we start to receive the next overflow - match t_read.read(&mut buf).await { - Ok(n) if n == 2 => assert_eq!(&buf[..n], &[4, 5]), - x => panic!("Unexpected result: {:?}", x), - } - - // Fifth, verify that we get the last bit of overflow - match t_read.read(&mut buf).await { - Ok(n) if n == 1 => assert_eq!(&buf[..n], &[6]), - x => panic!("Unexpected result: {:?}", x), - } - } - - #[tokio::test] - async fn read_half_should_update_buf_with_some_of_overflow_that_can_fit() { - let (tx, _rx, stream) = InmemoryStream::make(1); - let (mut t_read, _t_write) = stream.into_split(); - - tx.send(vec![1, 2, 3, 4, 5]).await.expect("Failed to send"); - - let mut buf = [0u8; 2]; - - // First, read part of the data (first two bytes) - match t_read.read(&mut buf).await { - Ok(n) 
if n == 2 => assert_eq!(&buf[..n], &[1, 2]), - x => panic!("Unexpected result: {:?}", x), - } - - // Second, we send more data because the last message was placed in overflow - tx.send(vec![6]).await.expect("Failed to send"); - - // Third, read next chunk of the overflow from first message (next two byte) - match t_read.read(&mut buf).await { - Ok(n) if n == 2 => assert_eq!(&buf[..n], &[3, 4]), - x => panic!("Unexpected result: {:?}", x), - } - - // Fourth, read last chunk of the overflow from first message (fifth byte) - match t_read.read(&mut buf).await { - Ok(n) if n == 1 => assert_eq!(&buf[..n], &[5]), - x => panic!("Unexpected result: {:?}", x), - } - } - - #[tokio::test] - async fn read_half_should_update_buf_with_all_of_inner_channel_when_it_fits() { - let (tx, _rx, stream) = InmemoryStream::make(1); - let (mut t_read, _t_write) = stream.into_split(); - - let mut buf = [0u8; 5]; - - tx.send(vec![1, 2, 3, 4, 5]).await.expect("Failed to send"); - - // First, read all of data that fits exactly - match t_read.read(&mut buf).await { - Ok(n) if n == 5 => assert_eq!(&buf[..n], &[1, 2, 3, 4, 5]), - x => panic!("Unexpected result: {:?}", x), - } - - tx.send(vec![6, 7, 8]).await.expect("Failed to send"); - - // Second, read data that fits within buf - match t_read.read(&mut buf).await { - Ok(n) if n == 3 => assert_eq!(&buf[..n], &[6, 7, 8]), - x => panic!("Unexpected result: {:?}", x), - } - } - - #[tokio::test] - async fn read_half_should_update_buf_with_some_of_inner_channel_that_can_fit_and_add_rest_to_overflow( - ) { - let (tx, _rx, stream) = InmemoryStream::make(1); - let (mut t_read, _t_write) = stream.into_split(); - - let mut buf = [0u8; 1]; - - tx.send(vec![1, 2, 3, 4, 5]).await.expect("Failed to send"); - - // Attempt a read that places more in overflow - match t_read.read(&mut buf).await { - Ok(n) if n == 1 => assert_eq!(&buf[..n], &[1]), - x => panic!("Unexpected result: {:?}", x), - } - - // Verify overflow contains the rest - assert_eq!(&t_read.overflow, 
&[2, 3, 4, 5]); - - // Queue up extra data that will not be read until overflow is finished - tx.send(vec![6, 7, 8]).await.expect("Failed to send"); - - // Read next data point - match t_read.read(&mut buf).await { - Ok(n) if n == 1 => assert_eq!(&buf[..n], &[2]), - x => panic!("Unexpected result: {:?}", x), - } - - // Verify overflow contains the rest without having added extra data - assert_eq!(&t_read.overflow, &[3, 4, 5]); - } - - #[tokio::test] - async fn read_half_should_yield_pending_if_no_data_available_on_inner_channel() { - let (_tx, _rx, stream) = InmemoryStream::make(1); - let (mut t_read, _t_write) = stream.into_split(); - - let mut buf = [0u8; 1]; - - // Attempt a read that should yield ok with no change, which is what should - // happen when nothing is read into buf - let f = t_read.read(&mut buf); - tokio::pin!(f); - match futures::poll!(f) { - Poll::Pending => {} - x => panic!("Unexpected poll result: {:?}", x), - } - } - - #[tokio::test] - async fn read_half_should_not_update_buf_if_inner_channel_closed() { - let (tx, _rx, stream) = InmemoryStream::make(1); - let (mut t_read, _t_write) = stream.into_split(); - - let mut buf = [0u8; 1]; - - // Drop the channel that would be sending data to the transport - drop(tx); - - // Attempt a read that should yield ok with no change, which is what should - // happen when nothing is read into buf - match t_read.read(&mut buf).await { - Ok(n) if n == 0 => assert_eq!(&buf, &[0]), - x => panic!("Unexpected result: {:?}", x), - } - } - - #[tokio::test] - async fn write_half_should_return_buf_len_if_can_send_immediately() { - let (_tx, mut rx, stream) = InmemoryStream::make(1); - let (_t_read, mut t_write) = stream.into_split(); - - // Write that is not waiting should always succeed with full contents - let n = t_write.write(&[1, 2, 3]).await.expect("Failed to write"); - assert_eq!(n, 3, "Unexpected byte count returned"); - - // Verify we actually had the data sent - let data = rx.try_recv().expect("Failed to recv 
data"); - assert_eq!(data, &[1, 2, 3]); - } - - #[tokio::test] - async fn write_half_should_return_support_eventually_sending_by_retrying_when_not_ready() { - let (_tx, mut rx, stream) = InmemoryStream::make(1); - let (_t_read, mut t_write) = stream.into_split(); - - // Queue a write already so that we block on the next one - let _ = t_write.write(&[1, 2, 3]).await.expect("Failed to write"); - - // Verify that the next write is pending - let f = t_write.write(&[4, 5]); - tokio::pin!(f); - match futures::poll!(&mut f) { - Poll::Pending => {} - x => panic!("Unexpected poll result: {:?}", x), - } - - // Consume first batch of data so future of second can continue - let data = rx.try_recv().expect("Failed to recv data"); - assert_eq!(data, &[1, 2, 3]); - - // Verify that poll now returns success - match futures::poll!(f) { - Poll::Ready(Ok(n)) if n == 2 => {} - x => panic!("Unexpected poll result: {:?}", x), - } - - // Consume second batch of data - let data = rx.try_recv().expect("Failed to recv data"); - assert_eq!(data, &[4, 5]); - } - - #[tokio::test] - async fn write_half_should_zero_if_inner_channel_closed() { - let (_tx, rx, stream) = InmemoryStream::make(1); - let (_t_read, mut t_write) = stream.into_split(); - - // Drop receiving end that transport would talk to - drop(rx); - - // Channel is dropped, so return 0 to indicate no bytes sent - let n = t_write.write(&[1, 2, 3]).await.expect("Failed to write"); - assert_eq!(n, 0, "Unexpected byte count returned"); - } -} diff --git a/distant-core/src/net/transport/mod.rs b/distant-core/src/net/transport/mod.rs deleted file mode 100644 index 300d8b3..0000000 --- a/distant-core/src/net/transport/mod.rs +++ /dev/null @@ -1,399 +0,0 @@ -use crate::net::SecretKeyError; -use derive_more::{Display, Error, From}; -use futures::{SinkExt, StreamExt}; -use serde::{de::DeserializeOwned, Serialize}; -use std::marker::Unpin; -use tokio::io::{self, AsyncRead, AsyncWrite}; -use tokio_util::codec::{Framed, FramedRead, FramedWrite}; 
- -mod codec; -pub use codec::*; - -mod inmemory; -pub use inmemory::*; - -mod tcp; -pub use tcp::*; - -#[cfg(unix)] -mod unix; - -#[cfg(unix)] -pub use unix::*; - -#[derive(Debug, Display, Error, From)] -pub struct SerializeError(#[error(not(source))] String); - -#[derive(Debug, Display, Error, From)] -pub struct DeserializeError(#[error(not(source))] String); - -fn serialize_to_vec(value: &T) -> Result, SerializeError> { - let mut v = Vec::new(); - - let _ = ciborium::ser::into_writer(value, &mut v).map_err(|x| SerializeError(x.to_string()))?; - - Ok(v) -} - -fn deserialize_from_slice(slice: &[u8]) -> Result { - ciborium::de::from_reader(slice).map_err(|x| DeserializeError(x.to_string())) -} - -#[derive(Debug, Display, Error, From)] -pub enum TransportError { - CryptoError(SecretKeyError), - IoError(io::Error), - SerializeError(SerializeError), - DeserializeError(DeserializeError), -} - -/// Interface representing a two-way data stream -/// -/// Enables splitting into separate, functioning halves that can read and write respectively -pub trait DataStream: AsyncRead + AsyncWrite + Unpin { - type Read: AsyncRead + Send + Unpin + 'static; - type Write: AsyncWrite + Send + Unpin + 'static; - - /// Returns a textual description of the connection associated with this stream - fn to_connection_tag(&self) -> String; - - /// Splits this stream into read and write halves - fn into_split(self) -> (Self::Read, Self::Write); -} - -/// Represents a transport of data across the network -#[derive(Debug)] -pub struct Transport(Framed) -where - T: DataStream, - U: Codec; - -impl Transport -where - T: DataStream, - U: Codec, -{ - /// Creates a new instance of the transport, wrapping the stream in a `Framed` - pub fn new(stream: T, codec: U) -> Self { - Self(Framed::new(stream, codec)) - } - - /// Sends some data across the wire, waiting for it to completely send - pub async fn send(&mut self, data: D) -> Result<(), TransportError> { - // Serialize data into a byte stream - // NOTE: 
Cannot used packed implementation for now due to issues with deserialization - let data = serialize_to_vec(&data)?; - - // Use underlying codec to send data (may encrypt, sign, etc.) - self.0.send(&data).await.map_err(TransportError::from) - } - - /// Receives some data from out on the wire, waiting until it's available, - /// returning none if the transport is now closed - pub async fn receive(&mut self) -> Result, TransportError> { - // Use underlying codec to receive data (may decrypt, validate, etc.) - if let Some(data) = self.0.next().await { - let data = data?; - - // Deserialize byte stream into our expected type - let data = deserialize_from_slice(&data)?; - Ok(Some(data)) - } else { - Ok(None) - } - } - - /// Returns a textual description of the transport's underlying connection - pub fn to_connection_tag(&self) -> String { - self.0.get_ref().to_connection_tag() - } - - /// Returns a reference to the underlying I/O stream - /// - /// Note that care should be taken to not tamper with the underlying stream of data coming in - /// as it may corrupt the stream of frames otherwise being worked with - pub fn get_ref(&self) -> &T { - self.0.get_ref() - } - - /// Returns a reference to the underlying I/O stream - /// - /// Note that care should be taken to not tamper with the underlying stream of data coming in - /// as it may corrupt the stream of frames otherwise being worked with - pub fn get_mut(&mut self) -> &mut T { - self.0.get_mut() - } - - /// Consumes the transport, returning its underlying I/O stream - /// - /// Note that care should be taken to not tamper with the underlying stream of data coming in - /// as it may corrupt the stream of frames otherwise being worked with. 
- pub fn into_inner(self) -> T { - self.0.into_inner() - } - - /// Splits transport into read and write halves - pub fn into_split( - self, - ) -> ( - TransportReadHalf, - TransportWriteHalf, - ) { - let parts = self.0.into_parts(); - let (read_half, write_half) = parts.io.into_split(); - - // Create our split read half and populate its buffer with existing contents - let mut f_read = FramedRead::new(read_half, parts.codec.clone()); - *f_read.read_buffer_mut() = parts.read_buf; - - // Create our split write half and populate its buffer with existing contents - let mut f_write = FramedWrite::new(write_half, parts.codec); - *f_write.write_buffer_mut() = parts.write_buf; - - let t_read = TransportReadHalf(f_read); - let t_write = TransportWriteHalf(f_write); - - (t_read, t_write) - } -} - -/// Represents a transport of data out to the network -pub struct TransportWriteHalf(FramedWrite) -where - T: AsyncWrite + Unpin, - U: Codec; - -impl TransportWriteHalf -where - T: AsyncWrite + Unpin, - U: Codec, -{ - /// Sends some data across the wire, waiting for it to completely send - pub async fn send(&mut self, data: D) -> Result<(), TransportError> { - // Serialize data into a byte stream - // NOTE: Cannot used packed implementation for now due to issues with deserialization - let data = serialize_to_vec(&data)?; - - // Use underlying codec to send data (may encrypt, sign, etc.) - self.0.send(&data).await.map_err(TransportError::from) - } -} - -/// Represents a transport of data in from the network -pub struct TransportReadHalf(FramedRead) -where - T: AsyncRead + Unpin, - U: Codec; - -impl TransportReadHalf -where - T: AsyncRead + Unpin, - U: Codec, -{ - /// Receives some data from out on the wire, waiting until it's available, - /// returning none if the transport is now closed - pub async fn receive(&mut self) -> Result, TransportError> { - // Use underlying codec to receive data (may decrypt, validate, etc.) 
- if let Some(data) = self.0.next().await { - let data = data?; - - // Deserialize byte stream into our expected type - let data = deserialize_from_slice(&data)?; - Ok(Some(data)) - } else { - Ok(None) - } - } -} - -/// Test utilities -#[cfg(test)] -impl Transport { - /// Makes a connected pair of inmemory transports - pub fn make_pair() -> ( - Transport, - Transport, - ) { - Self::pair(crate::constants::test::BUFFER_SIZE) - } -} - -#[cfg(test)] -mod tests { - use super::*; - use serde::{Deserialize, Serialize}; - - #[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)] - pub struct TestData { - name: String, - value: usize, - } - - #[tokio::test] - async fn send_should_convert_data_into_byte_stream_and_send_through_stream() { - let (_tx, mut rx, stream) = InmemoryStream::make(1); - let mut transport = Transport::new(stream, PlainCodec::new()); - - let data = TestData { - name: String::from("test"), - value: 123, - }; - - let bytes = serialize_to_vec(&data).unwrap(); - let len = (bytes.len() as u64).to_be_bytes(); - let mut frame = Vec::new(); - frame.extend(len.iter().copied()); - frame.extend(bytes); - - transport.send(data).await.unwrap(); - - let outgoing = rx.recv().await.unwrap(); - assert_eq!(outgoing, frame); - } - - #[tokio::test] - async fn receive_should_return_none_if_stream_is_closed() { - let (_, _, stream) = InmemoryStream::make(1); - let mut transport = Transport::new(stream, PlainCodec::new()); - - let result = transport.receive::().await; - match result { - Ok(None) => {} - x => panic!("Unexpected result: {:?}", x), - } - } - - #[tokio::test] - async fn receive_should_fail_if_unable_to_convert_to_type() { - let (tx, _rx, stream) = InmemoryStream::make(1); - let mut transport = Transport::new(stream, PlainCodec::new()); - - #[derive(Serialize, Deserialize)] - struct OtherTestData(usize); - - let data = OtherTestData(123); - let bytes = serialize_to_vec(&data).unwrap(); - let len = (bytes.len() as u64).to_be_bytes(); - let mut frame = 
Vec::new(); - frame.extend(len.iter().copied()); - frame.extend(bytes); - - tx.send(frame).await.unwrap(); - let result = transport.receive::().await; - match result { - Err(TransportError::DeserializeError(_)) => {} - x => panic!("Unexpected result: {:?}", x), - } - } - - #[tokio::test] - async fn receive_should_return_some_instance_of_type_when_coming_into_stream() { - let (tx, _rx, stream) = InmemoryStream::make(1); - let mut transport = Transport::new(stream, PlainCodec::new()); - - let data = TestData { - name: String::from("test"), - value: 123, - }; - - let bytes = serialize_to_vec(&data).unwrap(); - let len = (bytes.len() as u64).to_be_bytes(); - let mut frame = Vec::new(); - frame.extend(len.iter().copied()); - frame.extend(bytes); - - tx.send(frame).await.unwrap(); - let received_data = transport.receive::().await.unwrap().unwrap(); - assert_eq!(received_data, data); - } - - mod read_half { - use super::*; - - #[tokio::test] - async fn receive_should_return_none_if_stream_is_closed() { - let (_, _, stream) = InmemoryStream::make(1); - let transport = Transport::new(stream, PlainCodec::new()); - let (mut rh, _) = transport.into_split(); - - let result = rh.receive::().await; - match result { - Ok(None) => {} - x => panic!("Unexpected result: {:?}", x), - } - } - - #[tokio::test] - async fn receive_should_fail_if_unable_to_convert_to_type() { - let (tx, _rx, stream) = InmemoryStream::make(1); - let transport = Transport::new(stream, PlainCodec::new()); - let (mut rh, _) = transport.into_split(); - - #[derive(Serialize, Deserialize)] - struct OtherTestData(usize); - - let data = OtherTestData(123); - let bytes = serialize_to_vec(&data).unwrap(); - let len = (bytes.len() as u64).to_be_bytes(); - let mut frame = Vec::new(); - frame.extend(len.iter().copied()); - frame.extend(bytes); - - tx.send(frame).await.unwrap(); - let result = rh.receive::().await; - match result { - Err(TransportError::DeserializeError(_)) => {} - x => panic!("Unexpected result: {:?}", 
x), - } - } - - #[tokio::test] - async fn receive_should_return_some_instance_of_type_when_coming_into_stream() { - let (tx, _rx, stream) = InmemoryStream::make(1); - let transport = Transport::new(stream, PlainCodec::new()); - let (mut rh, _) = transport.into_split(); - - let data = TestData { - name: String::from("test"), - value: 123, - }; - - let bytes = serialize_to_vec(&data).unwrap(); - let len = (bytes.len() as u64).to_be_bytes(); - let mut frame = Vec::new(); - frame.extend(len.iter().copied()); - frame.extend(bytes); - - tx.send(frame).await.unwrap(); - let received_data = rh.receive::().await.unwrap().unwrap(); - assert_eq!(received_data, data); - } - } - - mod write_half { - use super::*; - - #[tokio::test] - async fn send_should_convert_data_into_byte_stream_and_send_through_stream() { - let (_tx, mut rx, stream) = InmemoryStream::make(1); - let transport = Transport::new(stream, PlainCodec::new()); - let (_, mut wh) = transport.into_split(); - - let data = TestData { - name: String::from("test"), - value: 123, - }; - - let bytes = serialize_to_vec(&data).unwrap(); - let len = (bytes.len() as u64).to_be_bytes(); - let mut frame = Vec::new(); - frame.extend(len.iter().copied()); - frame.extend(bytes); - - wh.send(data).await.unwrap(); - - let outgoing = rx.recv().await.unwrap(); - assert_eq!(outgoing, frame); - } - } -} diff --git a/distant-core/src/net/transport/tcp.rs b/distant-core/src/net/transport/tcp.rs deleted file mode 100644 index d3766e4..0000000 --- a/distant-core/src/net/transport/tcp.rs +++ /dev/null @@ -1,38 +0,0 @@ -use super::{Codec, DataStream, Transport}; -use std::net::SocketAddr; -use tokio::{ - io, - net::{ - tcp::{OwnedReadHalf, OwnedWriteHalf}, - TcpStream, ToSocketAddrs, - }, -}; - -impl DataStream for TcpStream { - type Read = OwnedReadHalf; - type Write = OwnedWriteHalf; - - fn to_connection_tag(&self) -> String { - self.peer_addr() - .map(|addr| format!("{}", addr)) - .unwrap_or_else(|_| String::from("--")) - } - - fn 
into_split(self) -> (Self::Read, Self::Write) { - TcpStream::into_split(self) - } -} - -impl Transport { - /// Establishes a connection to one of the specified addresses and uses the provided codec - /// for transportation - pub async fn connect(addrs: impl ToSocketAddrs, codec: U) -> io::Result { - let stream = TcpStream::connect(addrs).await?; - Ok(Transport::new(stream, codec)) - } - - /// Returns the address of the peer the transport is connected to - pub fn peer_addr(&self) -> io::Result { - self.0.get_ref().peer_addr() - } -} diff --git a/distant-core/src/net/transport/unix.rs b/distant-core/src/net/transport/unix.rs deleted file mode 100644 index 5eba515..0000000 --- a/distant-core/src/net/transport/unix.rs +++ /dev/null @@ -1,37 +0,0 @@ -use super::{Codec, DataStream, Transport}; -use tokio::{ - io, - net::{ - unix::{OwnedReadHalf, OwnedWriteHalf, SocketAddr}, - UnixStream, - }, -}; - -impl DataStream for UnixStream { - type Read = OwnedReadHalf; - type Write = OwnedWriteHalf; - - fn to_connection_tag(&self) -> String { - self.peer_addr() - .map(|addr| format!("{:?}", addr)) - .unwrap_or_else(|_| String::from("--")) - } - - fn into_split(self) -> (Self::Read, Self::Write) { - UnixStream::into_split(self) - } -} - -impl Transport { - /// Establishes a connection to the socket at the specified path and uses the provided codec - /// for transportation - pub async fn connect(path: impl AsRef, codec: U) -> io::Result { - let stream = UnixStream::connect(path.as_ref()).await?; - Ok(Transport::new(stream, codec)) - } - - /// Returns the address of the peer the transport is connected to - pub fn peer_addr(&self) -> io::Result { - self.0.get_ref().peer_addr() - } -} diff --git a/distant-core/src/serde_str.rs b/distant-core/src/serde_str.rs new file mode 100644 index 0000000..8898b74 --- /dev/null +++ b/distant-core/src/serde_str.rs @@ -0,0 +1,45 @@ +use serde::{ + de::{Deserializer, Error as SerdeError, Visitor}, + ser::Serializer, +}; +use std::{fmt, 
marker::PhantomData, str::FromStr}; + +/// From https://docs.rs/serde_with/1.14.0/src/serde_with/rust.rs.html#90-118 +pub fn deserialize_from_str<'de, D, T>(deserializer: D) -> Result +where + D: Deserializer<'de>, + T: FromStr, + T::Err: fmt::Display, +{ + struct Helper(PhantomData); + + impl<'de, S> Visitor<'de> for Helper + where + S: FromStr, + ::Err: fmt::Display, + { + type Value = S; + + fn expecting(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(formatter, "a string") + } + + fn visit_str(self, value: &str) -> Result + where + E: SerdeError, + { + value.parse::().map_err(SerdeError::custom) + } + } + + deserializer.deserialize_str(Helper(PhantomData)) +} + +/// From https://docs.rs/serde_with/1.14.0/src/serde_with/rust.rs.html#121-127 +pub fn serialize_to_str(value: &T, serializer: S) -> Result +where + T: fmt::Display, + S: Serializer, +{ + serializer.collect_str(&value) +} diff --git a/distant-core/src/server/distant/handler.rs b/distant-core/src/server/distant/handler.rs deleted file mode 100644 index b2bb836..0000000 --- a/distant-core/src/server/distant/handler.rs +++ /dev/null @@ -1,3281 +0,0 @@ -use crate::{ - constants::SERVER_WATCHER_CAPACITY, - data::{ - self, Change, ChangeKind, ChangeKindSet, DirEntry, FileType, Metadata, PtySize, Request, - RequestData, Response, ResponseData, RunningProcess, SystemInfo, - }, - server::distant::{ - process::{Process, PtyProcess, SimpleProcess}, - state::{ProcessState, State, WatcherPath}, - }, -}; -use derive_more::{Display, Error, From}; -use futures::future; -use log::*; -use notify::{Config as WatcherConfig, RecursiveMode, Watcher}; -use std::{ - env, - future::Future, - path::{Path, PathBuf}, - pin::Pin, - sync::Arc, - time::SystemTime, -}; -use tokio::{ - io::{self, AsyncWriteExt}, - sync::{ - mpsc::{self, error::TrySendError}, - Mutex, - }, -}; -use walkdir::WalkDir; - -type HState = Arc>; -type ReplyRet = Pin + Send + 'static>>; - -#[derive(Debug, Display, Error, From)] -pub enum 
ServerError { - Io(io::Error), - Notify(notify::Error), - WalkDir(walkdir::Error), -} - -impl From for ResponseData { - fn from(x: ServerError) -> Self { - match x { - ServerError::Io(x) => Self::from(x), - ServerError::Notify(x) => Self::from(x), - ServerError::WalkDir(x) => Self::from(x), - } - } -} - -type PostHook = Box; -struct Outgoing { - data: ResponseData, - post_hook: Option, -} - -impl From for Outgoing { - fn from(data: ResponseData) -> Self { - Self { - data, - post_hook: None, - } - } -} - -/// Processes the provided request, sending replies using the given sender -pub(super) async fn process( - conn_id: usize, - state: HState, - req: Request, - tx: mpsc::Sender, -) -> Result<(), mpsc::error::SendError> { - async fn inner( - conn_id: usize, - state: HState, - data: RequestData, - reply: F, - ) -> Result - where - F: FnMut(Vec) -> ReplyRet + Clone + Send + 'static, - { - match data { - RequestData::FileRead { path } => file_read(path).await, - RequestData::FileReadText { path } => file_read_text(path).await, - RequestData::FileWrite { path, data } => file_write(path, data).await, - RequestData::FileWriteText { path, text } => file_write(path, text).await, - RequestData::FileAppend { path, data } => file_append(path, data).await, - RequestData::FileAppendText { path, text } => file_append(path, text).await, - RequestData::DirRead { - path, - depth, - absolute, - canonicalize, - include_root, - } => dir_read(path, depth, absolute, canonicalize, include_root).await, - RequestData::DirCreate { path, all } => dir_create(path, all).await, - RequestData::Remove { path, force } => remove(path, force).await, - RequestData::Copy { src, dst } => copy(src, dst).await, - RequestData::Rename { src, dst } => rename(src, dst).await, - RequestData::Watch { - path, - recursive, - only, - except, - } => watch(conn_id, state, reply, path, recursive, only, except).await, - RequestData::Unwatch { path } => unwatch(conn_id, state, path).await, - RequestData::Exists { path } 
=> exists(path).await, - RequestData::Metadata { - path, - canonicalize, - resolve_file_type, - } => metadata(path, canonicalize, resolve_file_type).await, - RequestData::ProcSpawn { - cmd, - args, - persist, - pty, - } => proc_spawn(conn_id, state, reply, cmd, args, persist, pty).await, - RequestData::ProcKill { id } => proc_kill(conn_id, state, id).await, - RequestData::ProcStdin { id, data } => proc_stdin(conn_id, state, id, data).await, - RequestData::ProcResizePty { id, size } => { - proc_resize_pty(conn_id, state, id, size).await - } - RequestData::ProcList {} => proc_list(state).await, - RequestData::SystemInfo {} => system_info().await, - } - } - - let reply = { - let origin_id = req.id; - let tenant = req.tenant.clone(); - let tx_2 = tx.clone(); - move |payload: Vec| -> ReplyRet { - let tx = tx_2.clone(); - let res = Response::new(tenant.to_string(), origin_id, payload); - Box::pin(async move { tx.send(res).await.is_ok() }) - } - }; - - // Build up a collection of tasks to run independently - let mut payload_tasks = Vec::new(); - for data in req.payload { - let state_2 = Arc::clone(&state); - let reply_2 = reply.clone(); - payload_tasks.push(tokio::spawn(async move { - match inner(conn_id, state_2, data, reply_2).await { - Ok(outgoing) => outgoing, - Err(x) => Outgoing::from(ResponseData::from(x)), - } - })); - } - - // Collect the results of our tasks into the payload entries - let mut outgoing: Vec = future::join_all(payload_tasks) - .await - .into_iter() - .map(|x| match x { - Ok(outgoing) => outgoing, - Err(x) => Outgoing::from(ResponseData::from(x)), - }) - .collect(); - - let post_hooks: Vec = outgoing - .iter_mut() - .filter_map(|x| x.post_hook.take()) - .collect(); - - let payload = outgoing.into_iter().map(|x| x.data).collect(); - let res = Response::new(req.tenant, req.id, payload); - - // Send out our primary response from processing the request - let result = tx.send(res).await; - - // Invoke all post hooks - for hook in post_hooks { - hook(); 
- } - - result -} - -async fn file_read(path: PathBuf) -> Result { - Ok(Outgoing::from(ResponseData::Blob { - data: tokio::fs::read(path).await?, - })) -} - -async fn file_read_text(path: PathBuf) -> Result { - Ok(Outgoing::from(ResponseData::Text { - data: tokio::fs::read_to_string(path).await?, - })) -} - -async fn file_write(path: PathBuf, data: impl AsRef<[u8]>) -> Result { - tokio::fs::write(path, data).await?; - Ok(Outgoing::from(ResponseData::Ok)) -} - -async fn file_append(path: PathBuf, data: impl AsRef<[u8]>) -> Result { - let mut file = tokio::fs::OpenOptions::new() - .create(true) - .append(true) - .open(path) - .await?; - file.write_all(data.as_ref()).await?; - Ok(Outgoing::from(ResponseData::Ok)) -} - -async fn dir_read( - path: PathBuf, - depth: usize, - absolute: bool, - canonicalize: bool, - include_root: bool, -) -> Result { - // Canonicalize our provided path to ensure that it is exists, not a loop, and absolute - let root_path = tokio::fs::canonicalize(path).await?; - - // Traverse, but don't include root directory in entries (hence min depth 1), unless indicated - // to do so (min depth 0) - let dir = WalkDir::new(root_path.as_path()) - .min_depth(if include_root { 0 } else { 1 }) - .sort_by_file_name(); - - // If depth > 0, will recursively traverse to specified max depth, otherwise - // performs infinite traversal - let dir = if depth > 0 { dir.max_depth(depth) } else { dir }; - - // Determine our entries and errors - let mut entries = Vec::new(); - let mut errors = Vec::new(); - - #[inline] - fn map_file_type(ft: std::fs::FileType) -> FileType { - if ft.is_dir() { - FileType::Dir - } else if ft.is_file() { - FileType::File - } else { - FileType::Symlink - } - } - - for entry in dir { - match entry.map_err(data::Error::from) { - // For entries within the root, we want to transform the path based on flags - Ok(e) if e.depth() > 0 => { - // Canonicalize the path if specified, otherwise just return - // the path as is - let mut path = if 
canonicalize { - match tokio::fs::canonicalize(e.path()).await { - Ok(path) => path, - Err(x) => { - errors.push(data::Error::from(x)); - continue; - } - } - } else { - e.path().to_path_buf() - }; - - // Strip the path of its prefix based if not flagged as absolute - if !absolute { - // NOTE: In the situation where we canonicalized the path earlier, - // there is no guarantee that our root path is still the - // parent of the symlink's destination; so, in that case we MUST just - // return the path if the strip_prefix fails - path = path - .strip_prefix(root_path.as_path()) - .map(Path::to_path_buf) - .unwrap_or(path); - }; - - entries.push(DirEntry { - path, - file_type: map_file_type(e.file_type()), - depth: e.depth(), - }); - } - - // For the root, we just want to echo back the entry as is - Ok(e) => { - entries.push(DirEntry { - path: e.path().to_path_buf(), - file_type: map_file_type(e.file_type()), - depth: e.depth(), - }); - } - - Err(x) => errors.push(x), - } - } - - Ok(Outgoing::from(ResponseData::DirEntries { entries, errors })) -} - -async fn dir_create(path: PathBuf, all: bool) -> Result { - if all { - tokio::fs::create_dir_all(path).await?; - } else { - tokio::fs::create_dir(path).await?; - } - - Ok(Outgoing::from(ResponseData::Ok)) -} - -async fn remove(path: PathBuf, force: bool) -> Result { - let path_metadata = tokio::fs::metadata(path.as_path()).await?; - if path_metadata.is_dir() { - if force { - tokio::fs::remove_dir_all(path).await?; - } else { - tokio::fs::remove_dir(path).await?; - } - } else { - tokio::fs::remove_file(path).await?; - } - - Ok(Outgoing::from(ResponseData::Ok)) -} - -async fn copy(src: PathBuf, dst: PathBuf) -> Result { - let src_metadata = tokio::fs::metadata(src.as_path()).await?; - if src_metadata.is_dir() { - // Create the destination directory first, regardless of if anything - // is in the source directory - tokio::fs::create_dir_all(dst.as_path()).await?; - - for entry in WalkDir::new(src.as_path()) - .min_depth(1) - 
.follow_links(false) - .into_iter() - .filter_entry(|e| { - e.file_type().is_file() || e.file_type().is_dir() || e.path_is_symlink() - }) - { - let entry = entry?; - - // Get unique portion of path relative to src - // NOTE: Because we are traversing files that are all within src, this - // should always succeed - let local_src = entry.path().strip_prefix(src.as_path()).unwrap(); - - // Get the file without any directories - let local_src_file_name = local_src.file_name().unwrap(); - - // Get the directory housing the file - // NOTE: Because we enforce files/symlinks, there will always be a parent - let local_src_dir = local_src.parent().unwrap(); - - // Map out the path to the destination - let dst_parent_dir = dst.join(local_src_dir); - - // Create the destination directory for the file when copying - tokio::fs::create_dir_all(dst_parent_dir.as_path()).await?; - - let dst_path = dst_parent_dir.join(local_src_file_name); - - // Perform copying from entry to destination (if a file/symlink) - if !entry.file_type().is_dir() { - tokio::fs::copy(entry.path(), dst_path).await?; - - // Otherwise, if a directory, create it - } else { - tokio::fs::create_dir(dst_path).await?; - } - } - } else { - tokio::fs::copy(src, dst).await?; - } - - Ok(Outgoing::from(ResponseData::Ok)) -} - -async fn rename(src: PathBuf, dst: PathBuf) -> Result { - tokio::fs::rename(src, dst).await?; - - Ok(Outgoing::from(ResponseData::Ok)) -} - -async fn watch( - conn_id: usize, - state: HState, - reply: F, - path: PathBuf, - recursive: bool, - only: Vec, - except: Vec, -) -> Result -where - F: FnMut(Vec) -> ReplyRet + Clone + Send + 'static, -{ - let only = only.into_iter().collect::(); - let except = except.into_iter().collect::(); - let state_2 = Arc::clone(&state); - let mut state = state.lock().await; - - // NOTE: Do not use get_or_insert_with since notify::recommended_watcher returns a result - // and we cannot unpack the result within the above function. 
Since we are locking - // our state, we can be confident that no one else is modifying the watcher option - // concurrently; so, we do a naive check for option being populated - if state.watcher.is_none() { - // NOTE: Cannot be something small like 1 as this seems to cause a deadlock sometimes - // with a large volume of watch requests - let (tx, mut rx) = mpsc::channel(SERVER_WATCHER_CAPACITY); - - let mut watcher = notify::recommended_watcher(move |res| match tx.try_send(res) { - Ok(_) => {} - Err(TrySendError::Full(_)) => { - warn!( - "Reached watcher capacity of {}! Dropping watcher event!", - SERVER_WATCHER_CAPACITY, - ); - } - Err(TrySendError::Closed(_)) => { - warn!("Skipping watch event because watcher channel closed"); - } - })?; - - // Attempt to configure watcher, but don't fail if these configurations fail - match watcher.configure(WatcherConfig::PreciseEvents(true)) { - Ok(true) => debug!(" Watcher configured for precise events", conn_id), - Ok(false) => debug!( - " Watcher not configured for precise events", - conn_id, - ), - Err(x) => error!( - " Watcher configuration for precise events failed: {}", - conn_id, x - ), - } - - // Attempt to configure watcher, but don't fail if these configurations fail - match watcher.configure(WatcherConfig::NoticeEvents(true)) { - Ok(true) => debug!(" Watcher configured for notice events", conn_id), - Ok(false) => debug!( - " Watcher not configured for notice events", - conn_id, - ), - Err(x) => error!( - " Watcher configuration for notice events failed: {}", - conn_id, x - ), - } - - let _ = state.watcher.insert(watcher); - - tokio::spawn(async move { - while let Some(res) = rx.recv().await { - let is_ok = match res { - Ok(mut x) => { - let mut state = state_2.lock().await; - let paths: Vec<_> = x.paths.drain(..).collect(); - let kind = ChangeKind::from(x.kind); - - trace!( - " Watcher detected '{}' change for {:?}", - conn_id, - kind, - paths - ); - - fn make_res_data(kind: ChangeKind, paths: &[&PathBuf]) -> 
ResponseData { - ResponseData::Changed(Change { - kind, - paths: paths.iter().map(|p| p.to_path_buf()).collect(), - }) - } - - let results = state.map_paths_to_watcher_paths_and_replies(&paths); - let mut is_ok = true; - - for (paths, only, reply) in results { - // Skip sending this change if we are not watching it - if (!only.is_empty() && !only.contains(&kind)) - || (!except.is_empty() && except.contains(&kind)) - { - trace!( - " Skipping change '{}' for {:?}", - conn_id, - kind, - paths - ); - continue; - } - - if !reply(vec![make_res_data(kind, &paths)]).await { - is_ok = false; - break; - } - } - is_ok - } - Err(mut x) => { - let mut state = state_2.lock().await; - let paths: Vec<_> = x.paths.drain(..).collect(); - let msg = x.to_string(); - - error!( - " Watcher encountered an error {} for {:?}", - conn_id, msg, paths - ); - - fn make_res_data(msg: &str, paths: &[&PathBuf]) -> ResponseData { - if paths.is_empty() { - ResponseData::Error(msg.into()) - } else { - ResponseData::Error(format!("{} about {:?}", msg, paths).into()) - } - } - - let mut is_ok = true; - - // If we have no paths for the errors, then we send the error to everyone - if paths.is_empty() { - trace!(" Relaying error to all watchers", conn_id); - for reply in state.watcher_paths.values_mut() { - if !reply(vec![make_res_data(&msg, &[])]).await { - is_ok = false; - break; - } - } - // Otherwise, figure out the relevant watchers from our paths and - // send the error to them - } else { - let results = state.map_paths_to_watcher_paths_and_replies(&paths); - - trace!( - " Relaying error to {} watchers", - conn_id, - results.len() - ); - for (paths, _, reply) in results { - if !reply(vec![make_res_data(&msg, &paths)]).await { - is_ok = false; - break; - } - } - } - - is_ok - } - }; - - if !is_ok { - error!(" Watcher channel closed", conn_id); - break; - } - } - }); - } - - match state.watcher.as_mut() { - Some(watcher) => { - let wp = WatcherPath::new(&path, recursive, only)?; - watcher.watch( - 
path.as_path(), - if recursive { - RecursiveMode::Recursive - } else { - RecursiveMode::NonRecursive - }, - )?; - debug!(" Now watching {:?}", conn_id, wp.path()); - state.watcher_paths.insert(wp, Box::new(reply)); - Ok(Outgoing::from(ResponseData::Ok)) - } - None => Err(ServerError::Io(io::Error::new( - io::ErrorKind::BrokenPipe, - format!(" Unable to initialize watcher", conn_id), - ))), - } -} - -async fn unwatch(conn_id: usize, state: HState, path: PathBuf) -> Result { - if let Some(watcher) = state.lock().await.watcher.as_mut() { - watcher.unwatch(path.as_path())?; - // TODO: This also needs to remove any path that matches in either raw form - // or canonicalized form from the map of PathBuf -> ReplyFn - return Ok(Outgoing::from(ResponseData::Ok)); - } - - Err(ServerError::Io(io::Error::new( - io::ErrorKind::BrokenPipe, - format!( - " Unable to unwatch as watcher not initialized", - conn_id, - ), - ))) -} - -async fn exists(path: PathBuf) -> Result { - // Following experimental `std::fs::try_exists`, which checks the error kind of the - // metadata lookup to see if it is not found and filters accordingly - Ok(match tokio::fs::metadata(path.as_path()).await { - Ok(_) => Outgoing::from(ResponseData::Exists { value: true }), - Err(x) if x.kind() == io::ErrorKind::NotFound => { - Outgoing::from(ResponseData::Exists { value: false }) - } - Err(x) => return Err(ServerError::from(x)), - }) -} - -async fn metadata( - path: PathBuf, - canonicalize: bool, - resolve_file_type: bool, -) -> Result { - let metadata = tokio::fs::symlink_metadata(path.as_path()).await?; - let canonicalized_path = if canonicalize { - Some(tokio::fs::canonicalize(path.as_path()).await?) 
- } else { - None - }; - - // If asking for resolved file type and current type is symlink, then we want to refresh - // our metadata to get the filetype for the resolved link - let file_type = if resolve_file_type && metadata.file_type().is_symlink() { - tokio::fs::metadata(path).await?.file_type() - } else { - metadata.file_type() - }; - - Ok(Outgoing::from(ResponseData::Metadata(Metadata { - canonicalized_path, - accessed: metadata - .accessed() - .ok() - .and_then(|t| t.duration_since(SystemTime::UNIX_EPOCH).ok()) - .map(|d| d.as_millis()), - created: metadata - .created() - .ok() - .and_then(|t| t.duration_since(SystemTime::UNIX_EPOCH).ok()) - .map(|d| d.as_millis()), - modified: metadata - .modified() - .ok() - .and_then(|t| t.duration_since(SystemTime::UNIX_EPOCH).ok()) - .map(|d| d.as_millis()), - len: metadata.len(), - readonly: metadata.permissions().readonly(), - file_type: if file_type.is_dir() { - FileType::Dir - } else if file_type.is_file() { - FileType::File - } else { - FileType::Symlink - }, - - #[cfg(unix)] - unix: Some({ - use std::os::unix::prelude::*; - let mode = metadata.mode(); - crate::data::UnixMetadata::from(mode) - }), - #[cfg(not(unix))] - unix: None, - - #[cfg(windows)] - windows: Some({ - use std::os::windows::prelude::*; - let attributes = metadata.file_attributes(); - crate::data::WindowsMetadata::from(attributes) - }), - #[cfg(not(windows))] - windows: None, - }))) -} - -async fn proc_spawn( - conn_id: usize, - state: HState, - reply: F, - cmd: String, - args: Vec, - persist: bool, - pty: Option, -) -> Result -where - F: FnMut(Vec) -> ReplyRet + Clone + Send + 'static, -{ - debug!(" Spawning {} {}", conn_id, cmd, args.join(" ")); - let mut child: Box = match pty { - Some(size) => Box::new(PtyProcess::spawn(cmd.clone(), args.clone(), size)?), - None => Box::new(SimpleProcess::spawn(cmd.clone(), args.clone())?), - }; - - let id = child.id(); - let stdin = child.take_stdin(); - let stdout = child.take_stdout(); - let stderr = 
child.take_stderr(); - let killer = child.clone_killer(); - let pty = child.clone_pty(); - - let state_2 = Arc::clone(&state); - let post_hook = Box::new(move || { - // Spawn a task that sends stdout as a response - if let Some(mut stdout) = stdout { - let mut reply_2 = reply.clone(); - let _ = tokio::spawn(async move { - loop { - match stdout.recv().await { - Ok(Some(data)) => { - let payload = vec![ResponseData::ProcStdout { id, data }]; - if !reply_2(payload).await { - error!(" Stdout channel closed", conn_id, id); - break; - } - } - Ok(None) => break, - Err(x) => { - error!( - " Reading stdout failed: {}", - conn_id, id, x - ); - break; - } - } - } - }); - } - - // Spawn a task that sends stderr as a response - if let Some(mut stderr) = stderr { - let mut reply_2 = reply.clone(); - let _ = tokio::spawn(async move { - loop { - match stderr.recv().await { - Ok(Some(data)) => { - let payload = vec![ResponseData::ProcStderr { id, data }]; - if !reply_2(payload).await { - error!(" Stderr channel closed", conn_id, id); - break; - } - } - Ok(None) => break, - Err(x) => { - error!( - " Reading stderr failed: {}", - conn_id, id, x - ); - break; - } - } - } - }); - } - - // Spawn a task that waits on the process to exit but can also - // kill the process when triggered - let mut reply_2 = reply.clone(); - let _ = tokio::spawn(async move { - let status = child.wait().await; - debug!(" Completed {:?}", conn_id, id, status); - - state_2.lock().await.remove_process(conn_id, id); - - match status { - Ok(status) => { - let payload = vec![ResponseData::ProcDone { - id, - success: status.success, - code: status.code, - }]; - if !reply_2(payload).await { - error!(" Failed to send done", conn_id, id,); - } - } - Err(x) => { - let payload = vec![ResponseData::from(x)]; - if !reply_2(payload).await { - error!( - " Failed to send error for waiting", - conn_id, id, - ); - } - } - } - }); - }); - - state.lock().await.push_process_state( - conn_id, - ProcessState { - cmd, - args, - 
persist, - id, - stdin, - killer, - pty, - }, - ); - - debug!( - " Spawned successfully! Will enter post hook later", - conn_id, id - ); - Ok(Outgoing { - data: ResponseData::ProcSpawned { id }, - post_hook: Some(post_hook), - }) -} - -async fn proc_kill(conn_id: usize, state: HState, id: usize) -> Result { - if let Some(mut process) = state.lock().await.processes.remove(&id) { - if process.killer.kill().await.is_ok() { - return Ok(Outgoing::from(ResponseData::Ok)); - } - } - - Err(ServerError::Io(io::Error::new( - io::ErrorKind::BrokenPipe, - format!( - " Unable to send kill signal to process", - conn_id, id - ), - ))) -} - -async fn proc_stdin( - conn_id: usize, - state: HState, - id: usize, - data: Vec, -) -> Result { - if let Some(process) = state.lock().await.processes.get_mut(&id) { - if let Some(stdin) = process.stdin.as_mut() { - if stdin.send(&data).await.is_ok() { - return Ok(Outgoing::from(ResponseData::Ok)); - } - } - } - - Err(ServerError::Io(io::Error::new( - io::ErrorKind::BrokenPipe, - format!( - " Unable to send stdin to process", - conn_id, id, - ), - ))) -} - -async fn proc_resize_pty( - conn_id: usize, - state: HState, - id: usize, - size: PtySize, -) -> Result { - if let Some(process) = state.lock().await.processes.get(&id) { - let _ = process.pty.resize_pty(size)?; - - return Ok(Outgoing::from(ResponseData::Ok)); - } - - Err(ServerError::Io(io::Error::new( - io::ErrorKind::BrokenPipe, - format!( - " Unable to resize pty to {:?}", - conn_id, id, size, - ), - ))) -} - -async fn proc_list(state: HState) -> Result { - Ok(Outgoing::from(ResponseData::ProcEntries { - entries: state - .lock() - .await - .processes - .values() - .map(|p| RunningProcess { - cmd: p.cmd.to_string(), - args: p.args.clone(), - persist: p.persist, - // TODO: Support retrieving current pty size - pty: None, - id: p.id, - }) - .collect(), - })) -} - -async fn system_info() -> Result { - Ok(Outgoing::from(ResponseData::SystemInfo(SystemInfo { - family: 
env::consts::FAMILY.to_string(), - os: env::consts::OS.to_string(), - arch: env::consts::ARCH.to_string(), - current_dir: env::current_dir().unwrap_or_default(), - main_separator: std::path::MAIN_SEPARATOR, - }))) -} - -#[cfg(test)] -mod tests { - use super::*; - use assert_fs::prelude::*; - use once_cell::sync::Lazy; - use predicates::prelude::*; - use std::time::Duration; - - static TEMP_SCRIPT_DIR: Lazy = - Lazy::new(|| assert_fs::TempDir::new().unwrap()); - static SCRIPT_RUNNER: Lazy = Lazy::new(|| String::from("bash")); - - static ECHO_ARGS_TO_STDOUT_SH: Lazy = Lazy::new(|| { - let script = TEMP_SCRIPT_DIR.child("echo_args_to_stdout.sh"); - script - .write_str(indoc::indoc!( - r#" - #/usr/bin/env bash - printf "%s" "$*" - "# - )) - .unwrap(); - script - }); - - static ECHO_ARGS_TO_STDERR_SH: Lazy = Lazy::new(|| { - let script = TEMP_SCRIPT_DIR.child("echo_args_to_stderr.sh"); - script - .write_str(indoc::indoc!( - r#" - #/usr/bin/env bash - printf "%s" "$*" 1>&2 - "# - )) - .unwrap(); - script - }); - - static ECHO_STDIN_TO_STDOUT_SH: Lazy = Lazy::new(|| { - let script = TEMP_SCRIPT_DIR.child("echo_stdin_to_stdout.sh"); - script - .write_str(indoc::indoc!( - r#" - #/usr/bin/env bash - while IFS= read; do echo "$REPLY"; done - "# - )) - .unwrap(); - script - }); - - static SLEEP_SH: Lazy = Lazy::new(|| { - let script = TEMP_SCRIPT_DIR.child("sleep.sh"); - script - .write_str(indoc::indoc!( - r#" - #!/usr/bin/env bash - sleep "$1" - "# - )) - .unwrap(); - script - }); - - static DOES_NOT_EXIST_BIN: Lazy = - Lazy::new(|| TEMP_SCRIPT_DIR.child("does_not_exist_bin")); - - fn setup( - buffer: usize, - ) -> ( - usize, - Arc>, - mpsc::Sender, - mpsc::Receiver, - ) { - let (tx, rx) = mpsc::channel(buffer); - ( - rand::random(), - Arc::new(Mutex::new(State::default())), - tx, - rx, - ) - } - - #[tokio::test] - async fn file_read_should_send_error_if_fails_to_read_file() { - let (conn_id, state, tx, mut rx) = setup(1); - - let temp = assert_fs::TempDir::new().unwrap(); - 
let path = temp.child("missing-file").path().to_path_buf(); - - let req = Request::new("test-tenant", vec![RequestData::FileRead { path }]); - - process(conn_id, state, req, tx).await.unwrap(); - - let res = rx.recv().await.unwrap(); - assert_eq!(res.payload.len(), 1, "Wrong payload size"); - assert!( - matches!(res.payload[0], ResponseData::Error(_)), - "Unexpected response: {:?}", - res.payload[0] - ); - } - - #[tokio::test] - async fn file_read_should_send_blob_with_file_contents() { - let (conn_id, state, tx, mut rx) = setup(1); - - let temp = assert_fs::TempDir::new().unwrap(); - let file = temp.child("test-file"); - file.write_str("some file contents").unwrap(); - - let req = Request::new( - "test-tenant", - vec![RequestData::FileRead { - path: file.path().to_path_buf(), - }], - ); - - process(conn_id, state, req, tx).await.unwrap(); - - let res = rx.recv().await.unwrap(); - assert_eq!(res.payload.len(), 1, "Wrong payload size"); - match &res.payload[0] { - ResponseData::Blob { data } => assert_eq!(data, b"some file contents"), - x => panic!("Unexpected response: {:?}", x), - } - } - - #[tokio::test] - async fn file_read_text_should_send_error_if_fails_to_read_file() { - let (conn_id, state, tx, mut rx) = setup(1); - - let temp = assert_fs::TempDir::new().unwrap(); - let path = temp.child("missing-file").path().to_path_buf(); - - let req = Request::new("test-tenant", vec![RequestData::FileReadText { path }]); - - process(conn_id, state, req, tx).await.unwrap(); - - let res = rx.recv().await.unwrap(); - assert_eq!(res.payload.len(), 1, "Wrong payload size"); - assert!( - matches!(res.payload[0], ResponseData::Error(_)), - "Unexpected response: {:?}", - res.payload[0] - ); - } - - #[tokio::test] - async fn file_read_text_should_send_text_with_file_contents() { - let (conn_id, state, tx, mut rx) = setup(1); - - let temp = assert_fs::TempDir::new().unwrap(); - let file = temp.child("test-file"); - file.write_str("some file contents").unwrap(); - - let req = 
Request::new( - "test-tenant", - vec![RequestData::FileReadText { - path: file.path().to_path_buf(), - }], - ); - - process(conn_id, state, req, tx).await.unwrap(); - - let res = rx.recv().await.unwrap(); - assert_eq!(res.payload.len(), 1, "Wrong payload size"); - match &res.payload[0] { - ResponseData::Text { data } => assert_eq!(data, "some file contents"), - x => panic!("Unexpected response: {:?}", x), - } - } - - #[tokio::test] - async fn file_write_should_send_error_if_fails_to_write_file() { - let (conn_id, state, tx, mut rx) = setup(1); - - // Create a temporary path and add to it to ensure that there are - // extra components that don't exist to cause writing to fail - let temp = assert_fs::TempDir::new().unwrap(); - let file = temp.child("dir").child("test-file"); - - let req = Request::new( - "test-tenant", - vec![RequestData::FileWrite { - path: file.path().to_path_buf(), - data: b"some text".to_vec(), - }], - ); - - process(conn_id, state, req, tx).await.unwrap(); - - let res = rx.recv().await.unwrap(); - assert_eq!(res.payload.len(), 1, "Wrong payload size"); - assert!( - matches!(res.payload[0], ResponseData::Error(_)), - "Unexpected response: {:?}", - res.payload[0] - ); - - // Also verify that we didn't actually create the file - file.assert(predicate::path::missing()); - } - - #[tokio::test] - async fn file_write_should_send_ok_when_successful() { - let (conn_id, state, tx, mut rx) = setup(1); - - // Path should point to a file that does not exist, but all - // other components leading up to it do - let temp = assert_fs::TempDir::new().unwrap(); - let file = temp.child("test-file"); - - let req = Request::new( - "test-tenant", - vec![RequestData::FileWrite { - path: file.path().to_path_buf(), - data: b"some text".to_vec(), - }], - ); - - process(conn_id, state, req, tx).await.unwrap(); - - let res = rx.recv().await.unwrap(); - assert_eq!(res.payload.len(), 1, "Wrong payload size"); - assert!( - matches!(res.payload[0], ResponseData::Ok), - 
"Unexpected response: {:?}", - res.payload[0] - ); - - // Also verify that we actually did create the file - // with the associated contents - file.assert("some text"); - } - - #[tokio::test] - async fn file_write_text_should_send_error_if_fails_to_write_file() { - let (conn_id, state, tx, mut rx) = setup(1); - - // Create a temporary path and add to it to ensure that there are - // extra components that don't exist to cause writing to fail - let temp = assert_fs::TempDir::new().unwrap(); - let file = temp.child("dir").child("test-file"); - - let req = Request::new( - "test-tenant", - vec![RequestData::FileWriteText { - path: file.path().to_path_buf(), - text: String::from("some text"), - }], - ); - - process(conn_id, state, req, tx).await.unwrap(); - - let res = rx.recv().await.unwrap(); - assert_eq!(res.payload.len(), 1, "Wrong payload size"); - assert!( - matches!(res.payload[0], ResponseData::Error(_)), - "Unexpected response: {:?}", - res.payload[0] - ); - - // Also verify that we didn't actually create the file - file.assert(predicate::path::missing()); - } - - #[tokio::test] - async fn file_write_text_should_send_ok_when_successful() { - let (conn_id, state, tx, mut rx) = setup(1); - - // Path should point to a file that does not exist, but all - // other components leading up to it do - let temp = assert_fs::TempDir::new().unwrap(); - let file = temp.child("test-file"); - - let req = Request::new( - "test-tenant", - vec![RequestData::FileWriteText { - path: file.path().to_path_buf(), - text: String::from("some text"), - }], - ); - - process(conn_id, state, req, tx).await.unwrap(); - - let res = rx.recv().await.unwrap(); - assert_eq!(res.payload.len(), 1, "Wrong payload size"); - assert!( - matches!(res.payload[0], ResponseData::Ok), - "Unexpected response: {:?}", - res.payload[0] - ); - - // Also verify that we actually did create the file - // with the associated contents - file.assert("some text"); - } - - #[tokio::test] - async fn 
file_append_should_send_error_if_fails_to_create_file() { - let (conn_id, state, tx, mut rx) = setup(1); - - // Create a temporary path and add to it to ensure that there are - // extra components that don't exist to cause writing to fail - let temp = assert_fs::TempDir::new().unwrap(); - let file = temp.child("dir").child("test-file"); - - let req = Request::new( - "test-tenant", - vec![RequestData::FileAppend { - path: file.path().to_path_buf(), - data: b"some extra contents".to_vec(), - }], - ); - - process(conn_id, state, req, tx).await.unwrap(); - - let res = rx.recv().await.unwrap(); - assert_eq!(res.payload.len(), 1, "Wrong payload size"); - assert!( - matches!(res.payload[0], ResponseData::Error(_)), - "Unexpected response: {:?}", - res.payload[0] - ); - - // Also verify that we didn't actually create the file - file.assert(predicate::path::missing()); - } - - #[tokio::test] - async fn file_append_should_create_file_if_missing() { - let (conn_id, state, tx, mut rx) = setup(1); - - // Don't create the file directly, but define path - // where the file should be - let temp = assert_fs::TempDir::new().unwrap(); - let file = temp.child("test-file"); - - let req = Request::new( - "test-tenant", - vec![RequestData::FileAppend { - path: file.path().to_path_buf(), - data: b"some extra contents".to_vec(), - }], - ); - - process(conn_id, state, req, tx).await.unwrap(); - - let res = rx.recv().await.unwrap(); - assert_eq!(res.payload.len(), 1, "Wrong payload size"); - assert!( - matches!(res.payload[0], ResponseData::Ok), - "Unexpected response: {:?}", - res.payload[0] - ); - - // Yield to allow chance to finish appending to file - tokio::time::sleep(Duration::from_millis(50)).await; - - // Also verify that we actually did create to the file - file.assert("some extra contents"); - } - - #[tokio::test] - async fn file_append_should_send_ok_when_successful() { - let (conn_id, state, tx, mut rx) = setup(1); - - // Create a temporary file and fill it with some contents - 
let temp = assert_fs::TempDir::new().unwrap(); - let file = temp.child("test-file"); - file.write_str("some file contents").unwrap(); - - let req = Request::new( - "test-tenant", - vec![RequestData::FileAppend { - path: file.path().to_path_buf(), - data: b"some extra contents".to_vec(), - }], - ); - - process(conn_id, state, req, tx).await.unwrap(); - - let res = rx.recv().await.unwrap(); - assert_eq!(res.payload.len(), 1, "Wrong payload size"); - assert!( - matches!(res.payload[0], ResponseData::Ok), - "Unexpected response: {:?}", - res.payload[0] - ); - - // Yield to allow chance to finish appending to file - tokio::time::sleep(Duration::from_millis(50)).await; - - // Also verify that we actually did append to the file - file.assert("some file contentssome extra contents"); - } - - #[tokio::test] - async fn file_append_text_should_send_error_if_fails_to_create_file() { - let (conn_id, state, tx, mut rx) = setup(1); - - // Create a temporary path and add to it to ensure that there are - // extra components that don't exist to cause writing to fail - let temp = assert_fs::TempDir::new().unwrap(); - let file = temp.child("dir").child("test-file"); - - let req = Request::new( - "test-tenant", - vec![RequestData::FileAppendText { - path: file.path().to_path_buf(), - text: String::from("some extra contents"), - }], - ); - - process(conn_id, state, req, tx).await.unwrap(); - - let res = rx.recv().await.unwrap(); - assert_eq!(res.payload.len(), 1, "Wrong payload size"); - assert!( - matches!(res.payload[0], ResponseData::Error(_)), - "Unexpected response: {:?}", - res.payload[0] - ); - - // Also verify that we didn't actually create the file - file.assert(predicate::path::missing()); - } - - #[tokio::test] - async fn file_append_text_should_create_file_if_missing() { - let (conn_id, state, tx, mut rx) = setup(1); - - // Don't create the file directly, but define path - // where the file should be - let temp = assert_fs::TempDir::new().unwrap(); - let file = 
temp.child("test-file"); - - let req = Request::new( - "test-tenant", - vec![RequestData::FileAppendText { - path: file.path().to_path_buf(), - text: "some extra contents".to_string(), - }], - ); - - process(conn_id, state, req, tx).await.unwrap(); - - let res = rx.recv().await.unwrap(); - assert_eq!(res.payload.len(), 1, "Wrong payload size"); - assert!( - matches!(res.payload[0], ResponseData::Ok), - "Unexpected response: {:?}", - res.payload[0] - ); - - // Yield to allow chance to finish appending to file - tokio::time::sleep(Duration::from_millis(50)).await; - - // Also verify that we actually did create to the file - file.assert("some extra contents"); - } - - #[tokio::test] - async fn file_append_text_should_send_ok_when_successful() { - let (conn_id, state, tx, mut rx) = setup(1); - - // Create a temporary file and fill it with some contents - let temp = assert_fs::TempDir::new().unwrap(); - let file = temp.child("test-file"); - file.write_str("some file contents").unwrap(); - - let req = Request::new( - "test-tenant", - vec![RequestData::FileAppendText { - path: file.path().to_path_buf(), - text: String::from("some extra contents"), - }], - ); - - process(conn_id, state, req, tx).await.unwrap(); - - let res = rx.recv().await.unwrap(); - assert_eq!(res.payload.len(), 1, "Wrong payload size"); - assert!( - matches!(res.payload[0], ResponseData::Ok), - "Unexpected response: {:?}", - res.payload[0] - ); - - // Yield to allow chance to finish appending to file - tokio::time::sleep(Duration::from_millis(50)).await; - - // Also verify that we actually did append to the file - file.assert("some file contentssome extra contents"); - } - - #[tokio::test] - async fn dir_read_should_send_error_if_directory_does_not_exist() { - let (conn_id, state, tx, mut rx) = setup(1); - - let temp = assert_fs::TempDir::new().unwrap(); - let dir = temp.child("test-dir"); - - let req = Request::new( - "test-tenant", - vec![RequestData::DirRead { - path: dir.path().to_path_buf(), - 
depth: 0, - absolute: false, - canonicalize: false, - include_root: false, - }], - ); - - process(conn_id, state, req, tx).await.unwrap(); - - let res = rx.recv().await.unwrap(); - assert_eq!(res.payload.len(), 1, "Wrong payload size"); - assert!( - matches!(res.payload[0], ResponseData::Error(_)), - "Unexpected response: {:?}", - res.payload[0] - ); - } - - // /root/ - // /root/file1 - // /root/link1 -> /root/sub1/file2 - // /root/sub1/ - // /root/sub1/file2 - async fn setup_dir() -> assert_fs::TempDir { - let root_dir = assert_fs::TempDir::new().unwrap(); - root_dir.child("file1").touch().unwrap(); - - let sub1 = root_dir.child("sub1"); - sub1.create_dir_all().unwrap(); - - let file2 = sub1.child("file2"); - file2.touch().unwrap(); - - let link1 = root_dir.child("link1"); - link1.symlink_to_file(file2.path()).unwrap(); - - root_dir - } - - #[tokio::test] - async fn dir_read_should_support_depth_limits() { - let (conn_id, state, tx, mut rx) = setup(1); - - // Create directory with some nested items - let root_dir = setup_dir().await; - - let req = Request::new( - "test-tenant", - vec![RequestData::DirRead { - path: root_dir.path().to_path_buf(), - depth: 1, - absolute: false, - canonicalize: false, - include_root: false, - }], - ); - - process(conn_id, state, req, tx).await.unwrap(); - - let res = rx.recv().await.unwrap(); - assert_eq!(res.payload.len(), 1, "Wrong payload size"); - match &res.payload[0] { - ResponseData::DirEntries { entries, .. 
} => { - assert_eq!(entries.len(), 3, "Wrong number of entries found"); - - assert_eq!(entries[0].file_type, FileType::File); - assert_eq!(entries[0].path, Path::new("file1")); - assert_eq!(entries[0].depth, 1); - - assert_eq!(entries[1].file_type, FileType::Symlink); - assert_eq!(entries[1].path, Path::new("link1")); - assert_eq!(entries[1].depth, 1); - - assert_eq!(entries[2].file_type, FileType::Dir); - assert_eq!(entries[2].path, Path::new("sub1")); - assert_eq!(entries[2].depth, 1); - } - x => panic!("Unexpected response: {:?}", x), - } - } - - #[tokio::test] - async fn dir_read_should_support_unlimited_depth_using_zero() { - let (conn_id, state, tx, mut rx) = setup(1); - - // Create directory with some nested items - let root_dir = setup_dir().await; - - let req = Request::new( - "test-tenant", - vec![RequestData::DirRead { - path: root_dir.path().to_path_buf(), - depth: 0, - absolute: false, - canonicalize: false, - include_root: false, - }], - ); - - process(conn_id, state, req, tx).await.unwrap(); - - let res = rx.recv().await.unwrap(); - assert_eq!(res.payload.len(), 1, "Wrong payload size"); - match &res.payload[0] { - ResponseData::DirEntries { entries, .. 
} => { - assert_eq!(entries.len(), 4, "Wrong number of entries found"); - - assert_eq!(entries[0].file_type, FileType::File); - assert_eq!(entries[0].path, Path::new("file1")); - assert_eq!(entries[0].depth, 1); - - assert_eq!(entries[1].file_type, FileType::Symlink); - assert_eq!(entries[1].path, Path::new("link1")); - assert_eq!(entries[1].depth, 1); - - assert_eq!(entries[2].file_type, FileType::Dir); - assert_eq!(entries[2].path, Path::new("sub1")); - assert_eq!(entries[2].depth, 1); - - assert_eq!(entries[3].file_type, FileType::File); - assert_eq!(entries[3].path, Path::new("sub1").join("file2")); - assert_eq!(entries[3].depth, 2); - } - x => panic!("Unexpected response: {:?}", x), - } - } - - #[tokio::test] - async fn dir_read_should_support_including_directory_in_returned_entries() { - let (conn_id, state, tx, mut rx) = setup(1); - - // Create directory with some nested items - let root_dir = setup_dir().await; - - let req = Request::new( - "test-tenant", - vec![RequestData::DirRead { - path: root_dir.path().to_path_buf(), - depth: 1, - absolute: false, - canonicalize: false, - include_root: true, - }], - ); - - process(conn_id, state, req, tx).await.unwrap(); - - let res = rx.recv().await.unwrap(); - assert_eq!(res.payload.len(), 1, "Wrong payload size"); - match &res.payload[0] { - ResponseData::DirEntries { entries, .. 
} => { - assert_eq!(entries.len(), 4, "Wrong number of entries found"); - - // NOTE: Root entry is always absolute, resolved path - assert_eq!(entries[0].file_type, FileType::Dir); - assert_eq!(entries[0].path, root_dir.path().canonicalize().unwrap()); - assert_eq!(entries[0].depth, 0); - - assert_eq!(entries[1].file_type, FileType::File); - assert_eq!(entries[1].path, Path::new("file1")); - assert_eq!(entries[1].depth, 1); - - assert_eq!(entries[2].file_type, FileType::Symlink); - assert_eq!(entries[2].path, Path::new("link1")); - assert_eq!(entries[2].depth, 1); - - assert_eq!(entries[3].file_type, FileType::Dir); - assert_eq!(entries[3].path, Path::new("sub1")); - assert_eq!(entries[3].depth, 1); - } - x => panic!("Unexpected response: {:?}", x), - } - } - - #[tokio::test] - async fn dir_read_should_support_returning_absolute_paths() { - let (conn_id, state, tx, mut rx) = setup(1); - - // Create directory with some nested items - let root_dir = setup_dir().await; - - let req = Request::new( - "test-tenant", - vec![RequestData::DirRead { - path: root_dir.path().to_path_buf(), - depth: 1, - absolute: true, - canonicalize: false, - include_root: false, - }], - ); - - process(conn_id, state, req, tx).await.unwrap(); - - let res = rx.recv().await.unwrap(); - assert_eq!(res.payload.len(), 1, "Wrong payload size"); - match &res.payload[0] { - ResponseData::DirEntries { entries, .. 
} => { - assert_eq!(entries.len(), 3, "Wrong number of entries found"); - let root_path = root_dir.path().canonicalize().unwrap(); - - assert_eq!(entries[0].file_type, FileType::File); - assert_eq!(entries[0].path, root_path.join("file1")); - assert_eq!(entries[0].depth, 1); - - assert_eq!(entries[1].file_type, FileType::Symlink); - assert_eq!(entries[1].path, root_path.join("link1")); - assert_eq!(entries[1].depth, 1); - - assert_eq!(entries[2].file_type, FileType::Dir); - assert_eq!(entries[2].path, root_path.join("sub1")); - assert_eq!(entries[2].depth, 1); - } - x => panic!("Unexpected response: {:?}", x), - } - } - - #[tokio::test] - async fn dir_read_should_support_returning_canonicalized_paths() { - let (conn_id, state, tx, mut rx) = setup(1); - - // Create directory with some nested items - let root_dir = setup_dir().await; - - let req = Request::new( - "test-tenant", - vec![RequestData::DirRead { - path: root_dir.path().to_path_buf(), - depth: 1, - absolute: false, - canonicalize: true, - include_root: false, - }], - ); - - process(conn_id, state, req, tx).await.unwrap(); - - let res = rx.recv().await.unwrap(); - assert_eq!(res.payload.len(), 1, "Wrong payload size"); - match &res.payload[0] { - ResponseData::DirEntries { entries, .. 
} => { - assert_eq!(entries.len(), 3, "Wrong number of entries found"); - - assert_eq!(entries[0].file_type, FileType::File); - assert_eq!(entries[0].path, Path::new("file1")); - assert_eq!(entries[0].depth, 1); - - // Symlink should be resolved from $ROOT/link1 -> $ROOT/sub1/file2 - assert_eq!(entries[1].file_type, FileType::Symlink); - assert_eq!(entries[1].path, Path::new("sub1").join("file2")); - assert_eq!(entries[1].depth, 1); - - assert_eq!(entries[2].file_type, FileType::Dir); - assert_eq!(entries[2].path, Path::new("sub1")); - assert_eq!(entries[2].depth, 1); - } - x => panic!("Unexpected response: {:?}", x), - } - } - - #[tokio::test] - async fn dir_create_should_send_error_if_fails() { - let (conn_id, state, tx, mut rx) = setup(1); - - // Make a path that has multiple non-existent components - // so the creation will fail - let root_dir = setup_dir().await; - let path = root_dir.path().join("nested").join("new-dir"); - - let req = Request::new( - "test-tenant", - vec![RequestData::DirCreate { - path: path.to_path_buf(), - all: false, - }], - ); - - process(conn_id, state, req, tx).await.unwrap(); - - let res = rx.recv().await.unwrap(); - assert_eq!(res.payload.len(), 1, "Wrong payload size"); - assert!( - matches!(res.payload[0], ResponseData::Error(_)), - "Unexpected response: {:?}", - res.payload[0] - ); - - // Also verify that the directory was not actually created - assert!(!path.exists(), "Path unexpectedly exists"); - } - - #[tokio::test] - async fn dir_create_should_send_ok_when_successful() { - let (conn_id, state, tx, mut rx) = setup(1); - let root_dir = setup_dir().await; - let path = root_dir.path().join("new-dir"); - - let req = Request::new( - "test-tenant", - vec![RequestData::DirCreate { - path: path.to_path_buf(), - all: false, - }], - ); - - process(conn_id, state, req, tx).await.unwrap(); - - let res = rx.recv().await.unwrap(); - assert_eq!(res.payload.len(), 1, "Wrong payload size"); - assert!( - matches!(res.payload[0], 
ResponseData::Ok), - "Unexpected response: {:?}", - res.payload[0] - ); - - // Also verify that the directory was actually created - assert!(path.exists(), "Directory not created"); - } - - #[tokio::test] - async fn dir_create_should_support_creating_multiple_dir_components() { - let (conn_id, state, tx, mut rx) = setup(1); - let root_dir = setup_dir().await; - let path = root_dir.path().join("nested").join("new-dir"); - - let req = Request::new( - "test-tenant", - vec![RequestData::DirCreate { - path: path.to_path_buf(), - all: true, - }], - ); - - process(conn_id, state, req, tx).await.unwrap(); - - let res = rx.recv().await.unwrap(); - assert_eq!(res.payload.len(), 1, "Wrong payload size"); - assert!( - matches!(res.payload[0], ResponseData::Ok), - "Unexpected response: {:?}", - res.payload[0] - ); - - // Also verify that the directory was actually created - assert!(path.exists(), "Directory not created"); - } - - #[tokio::test] - async fn remove_should_send_error_on_failure() { - let (conn_id, state, tx, mut rx) = setup(1); - let temp = assert_fs::TempDir::new().unwrap(); - let file = temp.child("missing-file"); - - let req = Request::new( - "test-tenant", - vec![RequestData::Remove { - path: file.path().to_path_buf(), - force: false, - }], - ); - - process(conn_id, state, req, tx).await.unwrap(); - - let res = rx.recv().await.unwrap(); - assert_eq!(res.payload.len(), 1, "Wrong payload size"); - assert!( - matches!(res.payload[0], ResponseData::Error(_)), - "Unexpected response: {:?}", - res.payload[0] - ); - - // Also, verify that path does not exist - file.assert(predicate::path::missing()); - } - - #[tokio::test] - async fn remove_should_support_deleting_a_directory() { - let (conn_id, state, tx, mut rx) = setup(1); - let temp = assert_fs::TempDir::new().unwrap(); - let dir = temp.child("dir"); - dir.create_dir_all().unwrap(); - - let req = Request::new( - "test-tenant", - vec![RequestData::Remove { - path: dir.path().to_path_buf(), - force: false, - }], - 
); - - process(conn_id, state, req, tx).await.unwrap(); - - let res = rx.recv().await.unwrap(); - assert_eq!(res.payload.len(), 1, "Wrong payload size"); - assert!( - matches!(res.payload[0], ResponseData::Ok), - "Unexpected response: {:?}", - res.payload[0] - ); - - // Also, verify that path does not exist - dir.assert(predicate::path::missing()); - } - - #[tokio::test] - async fn remove_should_delete_nonempty_directory_if_force_is_true() { - let (conn_id, state, tx, mut rx) = setup(1); - let temp = assert_fs::TempDir::new().unwrap(); - let dir = temp.child("dir"); - dir.create_dir_all().unwrap(); - dir.child("file").touch().unwrap(); - - let req = Request::new( - "test-tenant", - vec![RequestData::Remove { - path: dir.path().to_path_buf(), - force: true, - }], - ); - - process(conn_id, state, req, tx).await.unwrap(); - - let res = rx.recv().await.unwrap(); - assert_eq!(res.payload.len(), 1, "Wrong payload size"); - assert!( - matches!(res.payload[0], ResponseData::Ok), - "Unexpected response: {:?}", - res.payload[0] - ); - - // Also, verify that path does not exist - dir.assert(predicate::path::missing()); - } - - #[tokio::test] - async fn remove_should_support_deleting_a_single_file() { - let (conn_id, state, tx, mut rx) = setup(1); - let temp = assert_fs::TempDir::new().unwrap(); - let file = temp.child("some-file"); - file.touch().unwrap(); - - let req = Request::new( - "test-tenant", - vec![RequestData::Remove { - path: file.path().to_path_buf(), - force: false, - }], - ); - - process(conn_id, state, req, tx).await.unwrap(); - - let res = rx.recv().await.unwrap(); - assert_eq!(res.payload.len(), 1, "Wrong payload size"); - assert!( - matches!(res.payload[0], ResponseData::Ok), - "Unexpected response: {:?}", - res.payload[0] - ); - - // Also, verify that path does not exist - file.assert(predicate::path::missing()); - } - - #[tokio::test] - async fn copy_should_send_error_on_failure() { - let (conn_id, state, tx, mut rx) = setup(1); - let temp = 
assert_fs::TempDir::new().unwrap(); - let src = temp.child("src"); - let dst = temp.child("dst"); - - let req = Request::new( - "test-tenant", - vec![RequestData::Copy { - src: src.path().to_path_buf(), - dst: dst.path().to_path_buf(), - }], - ); - - process(conn_id, state, req, tx).await.unwrap(); - - let res = rx.recv().await.unwrap(); - assert_eq!(res.payload.len(), 1, "Wrong payload size"); - assert!( - matches!(res.payload[0], ResponseData::Error(_)), - "Unexpected response: {:?}", - res.payload[0] - ); - - // Also, verify that destination does not exist - dst.assert(predicate::path::missing()); - } - - #[tokio::test] - async fn copy_should_support_copying_an_entire_directory() { - let (conn_id, state, tx, mut rx) = setup(1); - let temp = assert_fs::TempDir::new().unwrap(); - - let src = temp.child("src"); - src.create_dir_all().unwrap(); - let src_file = src.child("file"); - src_file.write_str("some contents").unwrap(); - - let dst = temp.child("dst"); - let dst_file = dst.child("file"); - - let req = Request::new( - "test-tenant", - vec![RequestData::Copy { - src: src.path().to_path_buf(), - dst: dst.path().to_path_buf(), - }], - ); - - process(conn_id, state, req, tx).await.unwrap(); - - let res = rx.recv().await.unwrap(); - assert_eq!(res.payload.len(), 1, "Wrong payload size"); - assert!( - matches!(res.payload[0], ResponseData::Ok), - "Unexpected response: {:?}", - res.payload[0] - ); - - // Verify that we have source and destination directories and associated contents - src.assert(predicate::path::is_dir()); - src_file.assert(predicate::path::is_file()); - dst.assert(predicate::path::is_dir()); - dst_file.assert(predicate::path::eq_file(src_file.path())); - } - - #[tokio::test] - async fn copy_should_support_copying_an_empty_directory() { - let (conn_id, state, tx, mut rx) = setup(1); - let temp = assert_fs::TempDir::new().unwrap(); - let src = temp.child("src"); - src.create_dir_all().unwrap(); - let dst = temp.child("dst"); - - let req = Request::new( 
- "test-tenant", - vec![RequestData::Copy { - src: src.path().to_path_buf(), - dst: dst.path().to_path_buf(), - }], - ); - - process(conn_id, state, req, tx).await.unwrap(); - - let res = rx.recv().await.unwrap(); - assert_eq!(res.payload.len(), 1, "Wrong payload size"); - assert!( - matches!(res.payload[0], ResponseData::Ok), - "Unexpected response: {:?}", - res.payload[0] - ); - - // Verify that we still have source and destination directories - src.assert(predicate::path::is_dir()); - dst.assert(predicate::path::is_dir()); - } - - #[tokio::test] - async fn copy_should_support_copying_a_directory_that_only_contains_directories() { - let (conn_id, state, tx, mut rx) = setup(1); - let temp = assert_fs::TempDir::new().unwrap(); - - let src = temp.child("src"); - src.create_dir_all().unwrap(); - let src_dir = src.child("dir"); - src_dir.create_dir_all().unwrap(); - - let dst = temp.child("dst"); - let dst_dir = dst.child("dir"); - - let req = Request::new( - "test-tenant", - vec![RequestData::Copy { - src: src.path().to_path_buf(), - dst: dst.path().to_path_buf(), - }], - ); - - process(conn_id, state, req, tx).await.unwrap(); - - let res = rx.recv().await.unwrap(); - assert_eq!(res.payload.len(), 1, "Wrong payload size"); - assert!( - matches!(res.payload[0], ResponseData::Ok), - "Unexpected response: {:?}", - res.payload[0] - ); - - // Verify that we have source and destination directories and associated contents - src.assert(predicate::path::is_dir().name("src")); - src_dir.assert(predicate::path::is_dir().name("src/dir")); - dst.assert(predicate::path::is_dir().name("dst")); - dst_dir.assert(predicate::path::is_dir().name("dst/dir")); - } - - #[tokio::test] - async fn copy_should_support_copying_a_single_file() { - let (conn_id, state, tx, mut rx) = setup(1); - let temp = assert_fs::TempDir::new().unwrap(); - let src = temp.child("src"); - src.write_str("some text").unwrap(); - let dst = temp.child("dst"); - - let req = Request::new( - "test-tenant", - 
vec![RequestData::Copy { - src: src.path().to_path_buf(), - dst: dst.path().to_path_buf(), - }], - ); - - process(conn_id, state, req, tx).await.unwrap(); - - let res = rx.recv().await.unwrap(); - assert_eq!(res.payload.len(), 1, "Wrong payload size"); - assert!( - matches!(res.payload[0], ResponseData::Ok), - "Unexpected response: {:?}", - res.payload[0] - ); - - // Verify that we still have source and that destination has source's contents - src.assert(predicate::path::is_file()); - dst.assert(predicate::path::eq_file(src.path())); - } - - #[tokio::test] - async fn rename_should_send_error_on_failure() { - let (conn_id, state, tx, mut rx) = setup(1); - let temp = assert_fs::TempDir::new().unwrap(); - let src = temp.child("src"); - let dst = temp.child("dst"); - - let req = Request::new( - "test-tenant", - vec![RequestData::Rename { - src: src.path().to_path_buf(), - dst: dst.path().to_path_buf(), - }], - ); - - process(conn_id, state, req, tx).await.unwrap(); - - let res = rx.recv().await.unwrap(); - assert_eq!(res.payload.len(), 1, "Wrong payload size"); - assert!( - matches!(res.payload[0], ResponseData::Error(_)), - "Unexpected response: {:?}", - res.payload[0] - ); - - // Also, verify that destination does not exist - dst.assert(predicate::path::missing()); - } - - #[tokio::test] - async fn rename_should_support_renaming_an_entire_directory() { - let (conn_id, state, tx, mut rx) = setup(1); - let temp = assert_fs::TempDir::new().unwrap(); - - let src = temp.child("src"); - src.create_dir_all().unwrap(); - let src_file = src.child("file"); - src_file.write_str("some contents").unwrap(); - - let dst = temp.child("dst"); - let dst_file = dst.child("file"); - - let req = Request::new( - "test-tenant", - vec![RequestData::Rename { - src: src.path().to_path_buf(), - dst: dst.path().to_path_buf(), - }], - ); - - process(conn_id, state, req, tx).await.unwrap(); - - let res = rx.recv().await.unwrap(); - assert_eq!(res.payload.len(), 1, "Wrong payload size"); - 
assert!( - matches!(res.payload[0], ResponseData::Ok), - "Unexpected response: {:?}", - res.payload[0] - ); - - // Verify that we moved the contents - src.assert(predicate::path::missing()); - src_file.assert(predicate::path::missing()); - dst.assert(predicate::path::is_dir()); - dst_file.assert("some contents"); - } - - #[tokio::test] - async fn rename_should_support_renaming_a_single_file() { - let (conn_id, state, tx, mut rx) = setup(1); - let temp = assert_fs::TempDir::new().unwrap(); - let src = temp.child("src"); - src.write_str("some text").unwrap(); - let dst = temp.child("dst"); - - let req = Request::new( - "test-tenant", - vec![RequestData::Rename { - src: src.path().to_path_buf(), - dst: dst.path().to_path_buf(), - }], - ); - - process(conn_id, state, req, tx).await.unwrap(); - - let res = rx.recv().await.unwrap(); - assert_eq!(res.payload.len(), 1, "Wrong payload size"); - assert!( - matches!(res.payload[0], ResponseData::Ok), - "Unexpected response: {:?}", - res.payload[0] - ); - - // Verify that we moved the file - src.assert(predicate::path::missing()); - dst.assert("some text"); - } - - /// Validates a response as being a series of changes that include the provided paths - fn validate_changed_paths( - res: &Response, - expected_paths: &[PathBuf], - should_panic: bool, - ) -> bool { - match &res.payload[0] { - ResponseData::Changed(change) if should_panic => { - let paths: Vec = change - .paths - .iter() - .map(|x| x.canonicalize().unwrap()) - .collect(); - assert_eq!(paths, expected_paths, "Wrong paths reported: {:?}", change); - - true - } - ResponseData::Changed(change) => { - let paths: Vec = change - .paths - .iter() - .map(|x| x.canonicalize().unwrap()) - .collect(); - paths == expected_paths - } - x if should_panic => panic!("Unexpected response: {:?}", x), - _ => false, - } - } - - #[tokio::test] - async fn watch_should_support_watching_a_single_file() { - // NOTE: Supporting multiple replies being sent back as part of creating, modifying, 
etc. - let (conn_id, state, tx, mut rx) = setup(100); - let temp = assert_fs::TempDir::new().unwrap(); - - let file = temp.child("file"); - file.touch().unwrap(); - - let req = Request::new( - "test-tenant", - vec![RequestData::Watch { - path: file.path().to_path_buf(), - recursive: false, - only: Default::default(), - except: Default::default(), - }], - ); - - // NOTE: We need to clone state so we don't drop the watcher - // as part of dropping the state - process(conn_id, Arc::clone(&state), req, tx).await.unwrap(); - - let res = rx.recv().await.unwrap(); - assert_eq!(res.payload.len(), 1, "Wrong payload size"); - assert!( - matches!(res.payload[0], ResponseData::Ok), - "Unexpected response: {:?}", - res.payload[0] - ); - - // Update the file and verify we get a notification - file.write_str("some text").unwrap(); - - let res = rx - .recv() - .await - .expect("Channel closed before we got change"); - assert_eq!(res.payload.len(), 1, "Wrong payload size"); - validate_changed_paths( - &res, - &[file.path().to_path_buf().canonicalize().unwrap()], - /* should_panic */ true, - ); - } - - #[tokio::test] - async fn watch_should_support_watching_a_directory_recursively() { - // NOTE: Supporting multiple replies being sent back as part of creating, modifying, etc. 
- let (conn_id, state, tx, mut rx) = setup(100); - let temp = assert_fs::TempDir::new().unwrap(); - - let file = temp.child("file"); - file.touch().unwrap(); - - let dir = temp.child("dir"); - dir.create_dir_all().unwrap(); - - let req = Request::new( - "test-tenant", - vec![RequestData::Watch { - path: temp.path().to_path_buf(), - recursive: true, - only: Default::default(), - except: Default::default(), - }], - ); - - // NOTE: We need to clone state so we don't drop the watcher - // as part of dropping the state - process(conn_id, Arc::clone(&state), req, tx).await.unwrap(); - - let res = rx.recv().await.unwrap(); - assert_eq!(res.payload.len(), 1, "Wrong payload size"); - assert!( - matches!(res.payload[0], ResponseData::Ok), - "Unexpected response: {:?}", - res.payload[0] - ); - - // Update the file and verify we get a notification - file.write_str("some text").unwrap(); - - // Create a nested file and verify we get a notification - let nested_file = dir.child("nested-file"); - nested_file.write_str("some text").unwrap(); - - // Sleep a bit to give time to get all changes happening - // TODO: Can we slim down this sleep? Or redesign test in some other way? 
- tokio::time::sleep(Duration::from_millis(100)).await; - - // Collect all responses, as we may get multiple for interactions within a directory - let mut responses = Vec::new(); - while let Ok(res) = rx.try_recv() { - responses.push(res); - } - - // Validate that we have at least one change reported for each of our paths - assert!( - responses.len() >= 2, - "Less than expected total responses: {:?}", - responses - ); - - let path = file.path().to_path_buf(); - assert!( - responses.iter().any(|res| validate_changed_paths( - res, - &[file.path().to_path_buf().canonicalize().unwrap()], - /* should_panic */ false, - )), - "Missing {:?} in {:?}", - path, - responses - .iter() - .map(|x| format!("{:?}", x)) - .collect::>(), - ); - - let path = nested_file.path().to_path_buf(); - assert!( - responses.iter().any(|res| validate_changed_paths( - res, - &[file.path().to_path_buf().canonicalize().unwrap()], - /* should_panic */ false, - )), - "Missing {:?} in {:?}", - path, - responses - .iter() - .map(|x| format!("{:?}", x)) - .collect::>(), - ); - } - - #[tokio::test] - async fn watch_should_report_changes_using_the_request_id() { - // NOTE: Supporting multiple replies being sent back as part of creating, modifying, etc. - let (conn_id, state, tx, mut rx) = setup(100); - let temp = assert_fs::TempDir::new().unwrap(); - - let file_1 = temp.child("file_1"); - file_1.touch().unwrap(); - - let file_2 = temp.child("file_2"); - file_2.touch().unwrap(); - - // Sleep a bit to give time to get all changes happening - // TODO: Can we slim down this sleep? Or redesign test in some other way? 
- tokio::time::sleep(Duration::from_millis(100)).await; - - // Initialize watch on file 1 - let file_1_origin_id = { - let req = Request::new( - "test-tenant", - vec![RequestData::Watch { - path: file_1.path().to_path_buf(), - recursive: false, - only: Default::default(), - except: Default::default(), - }], - ); - let origin_id = req.id; - - // NOTE: We need to clone state so we don't drop the watcher - // as part of dropping the state - process(conn_id, Arc::clone(&state), req, tx.clone()) - .await - .unwrap(); - - let res = rx.recv().await.unwrap(); - assert_eq!(res.payload.len(), 1, "Wrong payload size"); - assert!( - matches!(res.payload[0], ResponseData::Ok), - "Unexpected response: {:?}", - res.payload[0] - ); - - origin_id - }; - - // Initialize watch on file 2 - let file_2_origin_id = { - let req = Request::new( - "test-tenant", - vec![RequestData::Watch { - path: file_2.path().to_path_buf(), - recursive: false, - only: Default::default(), - except: Default::default(), - }], - ); - let origin_id = req.id; - - // NOTE: We need to clone state so we don't drop the watcher - // as part of dropping the state - process(conn_id, Arc::clone(&state), req, tx).await.unwrap(); - - let res = rx.recv().await.unwrap(); - assert_eq!(res.payload.len(), 1, "Wrong payload size"); - assert!( - matches!(res.payload[0], ResponseData::Ok), - "Unexpected response: {:?}", - res.payload[0] - ); - - origin_id - }; - - // Update the files and verify we get notifications from different origins - { - file_1.write_str("some text").unwrap(); - let res = rx - .recv() - .await - .expect("Channel closed before we got change"); - assert_eq!(res.payload.len(), 1, "Wrong payload size"); - validate_changed_paths( - &res, - &[file_1.path().to_path_buf().canonicalize().unwrap()], - /* should_panic */ true, - ); - assert_eq!(res.origin_id, file_1_origin_id, "Wrong origin id (file 1)"); - - // Process any extra messages (we might get create, content, and more) - loop { - // Sleep a bit to give time 
to get all changes happening - // TODO: Can we slim down this sleep? Or redesign test in some other way? - tokio::time::sleep(Duration::from_millis(100)).await; - - if rx.try_recv().is_err() { - break; - } - } - } - - // Update the files and verify we get notifications from different origins - { - file_2.write_str("some text").unwrap(); - let res = rx - .recv() - .await - .expect("Channel closed before we got change"); - assert_eq!(res.payload.len(), 1, "Wrong payload size"); - validate_changed_paths( - &res, - &[file_2.path().to_path_buf().canonicalize().unwrap()], - /* should_panic */ true, - ); - assert_eq!(res.origin_id, file_2_origin_id, "Wrong origin id (file 2)"); - } - } - - #[tokio::test] - async fn exists_should_send_true_if_path_exists() { - let (conn_id, state, tx, mut rx) = setup(1); - let temp = assert_fs::TempDir::new().unwrap(); - let file = temp.child("file"); - file.touch().unwrap(); - - let req = Request::new( - "test-tenant", - vec![RequestData::Exists { - path: file.path().to_path_buf(), - }], - ); - - process(conn_id, state, req, tx).await.unwrap(); - - let res = rx.recv().await.unwrap(); - assert_eq!(res.payload.len(), 1, "Wrong payload size"); - assert_eq!(res.payload[0], ResponseData::Exists { value: true }); - } - - #[tokio::test] - async fn exists_should_send_false_if_path_does_not_exist() { - let (conn_id, state, tx, mut rx) = setup(1); - let temp = assert_fs::TempDir::new().unwrap(); - let file = temp.child("file"); - - let req = Request::new( - "test-tenant", - vec![RequestData::Exists { - path: file.path().to_path_buf(), - }], - ); - - process(conn_id, state, req, tx).await.unwrap(); - - let res = rx.recv().await.unwrap(); - assert_eq!(res.payload.len(), 1, "Wrong payload size"); - assert_eq!(res.payload[0], ResponseData::Exists { value: false }); - } - - #[tokio::test] - async fn metadata_should_send_error_on_failure() { - let (conn_id, state, tx, mut rx) = setup(1); - let temp = assert_fs::TempDir::new().unwrap(); - let file = 
temp.child("file"); - - let req = Request::new( - "test-tenant", - vec![RequestData::Metadata { - path: file.path().to_path_buf(), - canonicalize: false, - resolve_file_type: false, - }], - ); - - process(conn_id, state, req, tx).await.unwrap(); - - let res = rx.recv().await.unwrap(); - assert_eq!(res.payload.len(), 1, "Wrong payload size"); - assert!( - matches!(res.payload[0], ResponseData::Error(_)), - "Unexpected response: {:?}", - res.payload[0] - ); - } - - #[tokio::test] - async fn metadata_should_send_back_metadata_on_file_if_exists() { - let (conn_id, state, tx, mut rx) = setup(1); - let temp = assert_fs::TempDir::new().unwrap(); - let file = temp.child("file"); - file.write_str("some text").unwrap(); - - let req = Request::new( - "test-tenant", - vec![RequestData::Metadata { - path: file.path().to_path_buf(), - canonicalize: false, - resolve_file_type: false, - }], - ); - - process(conn_id, state, req, tx).await.unwrap(); - - let res = rx.recv().await.unwrap(); - assert_eq!(res.payload.len(), 1, "Wrong payload size"); - assert!( - matches!( - res.payload[0], - ResponseData::Metadata(Metadata { - canonicalized_path: None, - file_type: FileType::File, - len: 9, - readonly: false, - .. - }) - ), - "Unexpected response: {:?}", - res.payload[0] - ); - } - - #[cfg(unix)] - #[tokio::test] - async fn metadata_should_include_unix_specific_metadata_on_unix_platform() { - let (conn_id, state, tx, mut rx) = setup(1); - let temp = assert_fs::TempDir::new().unwrap(); - let file = temp.child("file"); - file.write_str("some text").unwrap(); - - let req = Request::new( - "test-tenant", - vec![RequestData::Metadata { - path: file.path().to_path_buf(), - canonicalize: false, - resolve_file_type: false, - }], - ); - - process(conn_id, state, req, tx).await.unwrap(); - - let res = rx.recv().await.unwrap(); - assert_eq!(res.payload.len(), 1, "Wrong payload size"); - - match &res.payload[0] { - ResponseData::Metadata(Metadata { unix, windows, .. 
}) => { - assert!(unix.is_some(), "Unexpectedly missing unix metadata on unix"); - assert!( - windows.is_none(), - "Unexpectedly got windows metadata on unix" - ); - } - x => panic!("Unexpected response: {:?}", x), - } - } - - #[cfg(windows)] - #[tokio::test] - async fn metadata_should_include_unix_specific_metadata_on_windows_platform() { - let (conn_id, state, tx, mut rx) = setup(1); - let temp = assert_fs::TempDir::new().unwrap(); - let file = temp.child("file"); - file.write_str("some text").unwrap(); - - let req = Request::new( - "test-tenant", - vec![RequestData::Metadata { - path: file.path().to_path_buf(), - canonicalize: false, - resolve_file_type: false, - }], - ); - - process(conn_id, state, req, tx).await.unwrap(); - - let res = rx.recv().await.unwrap(); - assert_eq!(res.payload.len(), 1, "Wrong payload size"); - - match &res.payload[0] { - ResponseData::Metadata(Metadata { unix, windows, .. }) => { - assert!( - windows.is_some(), - "Unexpectedly missing windows metadata on windows" - ); - assert!(unix.is_none(), "Unexpectedly got unix metadata on windows"); - } - x => panic!("Unexpected response: {:?}", x), - } - } - - #[tokio::test] - async fn metadata_should_send_back_metadata_on_dir_if_exists() { - let (conn_id, state, tx, mut rx) = setup(1); - let temp = assert_fs::TempDir::new().unwrap(); - let dir = temp.child("dir"); - dir.create_dir_all().unwrap(); - - let req = Request::new( - "test-tenant", - vec![RequestData::Metadata { - path: dir.path().to_path_buf(), - canonicalize: false, - resolve_file_type: false, - }], - ); - - process(conn_id, state, req, tx).await.unwrap(); - - let res = rx.recv().await.unwrap(); - assert_eq!(res.payload.len(), 1, "Wrong payload size"); - assert!( - matches!( - res.payload[0], - ResponseData::Metadata(Metadata { - canonicalized_path: None, - file_type: FileType::Dir, - readonly: false, - .. 
- }) - ), - "Unexpected response: {:?}", - res.payload[0] - ); - } - - #[tokio::test] - async fn metadata_should_send_back_metadata_on_symlink_if_exists() { - let (conn_id, state, tx, mut rx) = setup(1); - let temp = assert_fs::TempDir::new().unwrap(); - let file = temp.child("file"); - file.write_str("some text").unwrap(); - - let symlink = temp.child("link"); - symlink.symlink_to_file(file.path()).unwrap(); - - let req = Request::new( - "test-tenant", - vec![RequestData::Metadata { - path: symlink.path().to_path_buf(), - canonicalize: false, - resolve_file_type: false, - }], - ); - - process(conn_id, state, req, tx).await.unwrap(); - - let res = rx.recv().await.unwrap(); - assert_eq!(res.payload.len(), 1, "Wrong payload size"); - assert!( - matches!( - res.payload[0], - ResponseData::Metadata(Metadata { - canonicalized_path: None, - file_type: FileType::Symlink, - readonly: false, - .. - }) - ), - "Unexpected response: {:?}", - res.payload[0] - ); - } - - #[tokio::test] - async fn metadata_should_include_canonicalized_path_if_flag_specified() { - let (conn_id, state, tx, mut rx) = setup(1); - let temp = assert_fs::TempDir::new().unwrap(); - let file = temp.child("file"); - file.write_str("some text").unwrap(); - - let symlink = temp.child("link"); - symlink.symlink_to_file(file.path()).unwrap(); - - let req = Request::new( - "test-tenant", - vec![RequestData::Metadata { - path: symlink.path().to_path_buf(), - canonicalize: true, - resolve_file_type: false, - }], - ); - - process(conn_id, state, req, tx).await.unwrap(); - - let res = rx.recv().await.unwrap(); - assert_eq!(res.payload.len(), 1, "Wrong payload size"); - match &res.payload[0] { - ResponseData::Metadata(Metadata { - canonicalized_path: Some(path), - file_type: FileType::Symlink, - readonly: false, - .. 
- }) => assert_eq!( - path, - &file.path().canonicalize().unwrap(), - "Symlink canonicalized path does not match referenced file" - ), - x => panic!("Unexpected response: {:?}", x), - } - } - - #[tokio::test] - async fn metadata_should_resolve_file_type_of_symlink_if_flag_specified() { - let (conn_id, state, tx, mut rx) = setup(1); - let temp = assert_fs::TempDir::new().unwrap(); - let file = temp.child("file"); - file.write_str("some text").unwrap(); - - let symlink = temp.child("link"); - symlink.symlink_to_file(file.path()).unwrap(); - - let req = Request::new( - "test-tenant", - vec![RequestData::Metadata { - path: symlink.path().to_path_buf(), - canonicalize: false, - resolve_file_type: true, - }], - ); - - process(conn_id, state, req, tx).await.unwrap(); - - let res = rx.recv().await.unwrap(); - assert_eq!(res.payload.len(), 1, "Wrong payload size"); - match &res.payload[0] { - ResponseData::Metadata(Metadata { - file_type: FileType::File, - .. - }) => {} - x => panic!("Unexpected response: {:?}", x), - } - } - - #[tokio::test] - async fn proc_spawn_should_send_error_on_failure() { - let (conn_id, state, tx, mut rx) = setup(1); - - let req = Request::new( - "test-tenant", - vec![RequestData::ProcSpawn { - cmd: DOES_NOT_EXIST_BIN.to_str().unwrap().to_string(), - args: Vec::new(), - persist: false, - pty: None, - }], - ); - - process(conn_id, Arc::clone(&state), req, tx.clone()) - .await - .unwrap(); - - let res = rx.recv().await.unwrap(); - assert_eq!(res.payload.len(), 1, "Wrong payload size"); - assert!( - matches!(&res.payload[0], ResponseData::Error(_)), - "Unexpected response: {:?}", - res.payload[0] - ); - } - - #[tokio::test] - async fn proc_spawn_should_send_back_proc_start_on_success() { - let (conn_id, state, tx, mut rx) = setup(1); - - let req = Request::new( - "test-tenant", - vec![RequestData::ProcSpawn { - cmd: SCRIPT_RUNNER.to_string(), - args: vec![ECHO_ARGS_TO_STDOUT_SH.to_str().unwrap().to_string()], - persist: false, - pty: None, - }], - ); 
- - process(conn_id, Arc::clone(&state), req, tx.clone()) - .await - .unwrap(); - - let res = rx.recv().await.unwrap(); - assert_eq!(res.payload.len(), 1, "Wrong payload size"); - assert!( - matches!(&res.payload[0], ResponseData::ProcSpawned { .. }), - "Unexpected response: {:?}", - res.payload[0] - ); - } - - // NOTE: Ignoring on windows because it's using WSL which wants a Linux path - // with / but thinks it's on windows and is providing \ - #[tokio::test] - #[cfg_attr(windows, ignore)] - async fn proc_spawn_should_send_back_stdout_periodically_when_available() { - let (conn_id, state, tx, mut rx) = setup(1); - - // Run a program that echoes to stdout - let req = Request::new( - "test-tenant", - vec![RequestData::ProcSpawn { - cmd: SCRIPT_RUNNER.to_string(), - args: vec![ - ECHO_ARGS_TO_STDOUT_SH.to_str().unwrap().to_string(), - String::from("some stdout"), - ], - persist: false, - pty: None, - }], - ); - - process(conn_id, Arc::clone(&state), req, tx.clone()) - .await - .unwrap(); - - let res = rx.recv().await.unwrap(); - assert_eq!(res.payload.len(), 1, "Wrong payload size"); - assert!( - matches!(&res.payload[0], ResponseData::ProcSpawned { .. }), - "Unexpected response: {:?}", - res.payload[0] - ); - - // Gather two additional responses: - // - // 1. An indirect response for stdout - // 2. An indirect response that is proc completing - // - // Note that order is not a guarantee, so we have to check that - // we get one of each type of response - let res1 = rx.recv().await.expect("Missing first response"); - let res2 = rx.recv().await.expect("Missing second response"); - - let mut got_stdout = false; - let mut got_done = false; - - let mut check_res = |res: &Response| { - assert_eq!(res.payload.len(), 1, "Wrong payload size"); - match &res.payload[0] { - ResponseData::ProcStdout { data, .. } => { - assert_eq!(data, b"some stdout", "Got wrong stdout"); - got_stdout = true; - } - ResponseData::ProcDone { success, .. 
} => { - assert!(success, "Process should have completed successfully"); - got_done = true; - } - x => panic!("Unexpected response: {:?}", x), - } - }; - - check_res(&res1); - check_res(&res2); - assert!(got_stdout, "Missing stdout response"); - assert!(got_done, "Missing done response"); - } - - // NOTE: Ignoring on windows because it's using WSL which wants a Linux path - // with / but thinks it's on windows and is providing \ - #[tokio::test] - #[cfg_attr(windows, ignore)] - async fn proc_spawn_should_send_back_stderr_periodically_when_available() { - let (conn_id, state, tx, mut rx) = setup(1); - - // Run a program that echoes to stderr - let req = Request::new( - "test-tenant", - vec![RequestData::ProcSpawn { - cmd: SCRIPT_RUNNER.to_string(), - args: vec![ - ECHO_ARGS_TO_STDERR_SH.to_str().unwrap().to_string(), - String::from("some stderr"), - ], - persist: false, - pty: None, - }], - ); - - process(conn_id, Arc::clone(&state), req, tx.clone()) - .await - .unwrap(); - - let res = rx.recv().await.unwrap(); - assert_eq!(res.payload.len(), 1, "Wrong payload size"); - assert!( - matches!(&res.payload[0], ResponseData::ProcSpawned { .. }), - "Unexpected response: {:?}", - res.payload[0] - ); - - // Gather two additional responses: - // - // 1. An indirect response for stderr - // 2. An indirect response that is proc completing - // - // Note that order is not a guarantee, so we have to check that - // we get one of each type of response - let res1 = rx.recv().await.expect("Missing first response"); - let res2 = rx.recv().await.expect("Missing second response"); - - let mut got_stderr = false; - let mut got_done = false; - - let mut check_res = |res: &Response| { - assert_eq!(res.payload.len(), 1, "Wrong payload size"); - match &res.payload[0] { - ResponseData::ProcStderr { data, .. } => { - assert_eq!(data, b"some stderr", "Got wrong stderr"); - got_stderr = true; - } - ResponseData::ProcDone { success, .. 
} => { - assert!(success, "Process should have completed successfully"); - got_done = true; - } - x => panic!("Unexpected response: {:?}", x), - } - }; - - check_res(&res1); - check_res(&res2); - assert!(got_stderr, "Missing stderr response"); - assert!(got_done, "Missing done response"); - } - - // NOTE: Ignoring on windows because it's using WSL which wants a Linux path - // with / but thinks it's on windows and is providing \ - #[tokio::test] - #[cfg_attr(windows, ignore)] - async fn proc_spawn_should_clear_process_from_state_when_done() { - let (conn_id, state, tx, mut rx) = setup(1); - - // Run a program that ends after a little bit - let req = Request::new( - "test-tenant", - vec![RequestData::ProcSpawn { - cmd: SCRIPT_RUNNER.to_string(), - args: vec![SLEEP_SH.to_str().unwrap().to_string(), String::from("0.1")], - persist: false, - pty: None, - }], - ); - - process(conn_id, Arc::clone(&state), req, tx.clone()) - .await - .unwrap(); - - let res = rx.recv().await.unwrap(); - assert_eq!(res.payload.len(), 1, "Wrong payload size"); - let id = match &res.payload[0] { - ResponseData::ProcSpawned { id } => *id, - x => panic!("Unexpected response: {:?}", x), - }; - - // Verify that the state has the process - assert!( - state.lock().await.processes.contains_key(&id), - "Process {} not in state", - id - ); - - // Wait for process to finish - let _ = rx.recv().await.unwrap(); - - // Verify that the state was cleared - assert!( - !state.lock().await.processes.contains_key(&id), - "Process {} still in state", - id - ); - } - - #[tokio::test] - async fn proc_spawn_should_clear_process_from_state_when_killed() { - let (conn_id, state, tx, mut rx) = setup(1); - - // Run a program that ends slowly - let req = Request::new( - "test-tenant", - vec![RequestData::ProcSpawn { - cmd: SCRIPT_RUNNER.to_string(), - args: vec![SLEEP_SH.to_str().unwrap().to_string(), String::from("1")], - persist: false, - pty: None, - }], - ); - - process(conn_id, Arc::clone(&state), req, tx.clone()) 
- .await - .unwrap(); - - let res = rx.recv().await.unwrap(); - assert_eq!(res.payload.len(), 1, "Wrong payload size"); - let id = match &res.payload[0] { - ResponseData::ProcSpawned { id } => *id, - x => panic!("Unexpected response: {:?}", x), - }; - - // Verify that the state has the process - assert!( - state.lock().await.processes.contains_key(&id), - "Process {} not in state", - id - ); - - // Send kill signal - let req = Request::new("test-tenant", vec![RequestData::ProcKill { id }]); - process(conn_id, Arc::clone(&state), req, tx.clone()) - .await - .unwrap(); - - // Wait for two responses, a kill confirmation and the done - let _ = rx.recv().await.unwrap(); - let _ = rx.recv().await.unwrap(); - - // Verify that the state was cleared - assert!( - !state.lock().await.processes.contains_key(&id), - "Process {} still in state", - id - ); - } - - #[tokio::test] - async fn proc_kill_should_send_error_on_failure() { - let (conn_id, state, tx, mut rx) = setup(1); - - // Send kill to a non-existent process - let req = Request::new( - "test-tenant", - vec![RequestData::ProcKill { id: 0xDEADBEEF }], - ); - - process(conn_id, Arc::clone(&state), req, tx.clone()) - .await - .unwrap(); - - let res = rx.recv().await.unwrap(); - assert_eq!(res.payload.len(), 1, "Wrong payload size"); - - // Verify that we get an error - assert!( - matches!(res.payload[0], ResponseData::Error(_)), - "Unexpected response: {:?}", - res.payload[0] - ); - } - - #[tokio::test] - async fn proc_kill_should_send_ok_and_done_responses_on_success() { - let (conn_id, state, tx, mut rx) = setup(1); - - // First, run a program that sits around (sleep for 1 second) - let req = Request::new( - "test-tenant", - vec![RequestData::ProcSpawn { - cmd: SCRIPT_RUNNER.to_string(), - args: vec![SLEEP_SH.to_str().unwrap().to_string(), String::from("1")], - persist: false, - pty: None, - }], - ); - - process(conn_id, Arc::clone(&state), req, tx.clone()) - .await - .unwrap(); - - let res = rx.recv().await.unwrap(); - 
assert_eq!(res.payload.len(), 1, "Wrong payload size"); - - // Second, grab the id of the started process - let id = match &res.payload[0] { - ResponseData::ProcSpawned { id } => *id, - x => panic!("Unexpected response: {:?}", x), - }; - - // Third, send kill for process - let req = Request::new("test-tenant", vec![RequestData::ProcKill { id }]); - - // NOTE: We cannot let the state get dropped as it results in killing - // the child process automatically; so, we clone another reference here - process(conn_id, Arc::clone(&state), req, tx).await.unwrap(); - - // Fourth, gather two responses: - // - // 1. A direct response saying that received (ok) - // 2. An indirect response that is proc completing - // - // Note that order is not a guarantee, so we have to check that - // we get one of each type of response - let res1 = rx.recv().await.expect("Missing first response"); - let res2 = rx.recv().await.expect("Missing second response"); - - let mut got_ok = false; - let mut got_done = false; - - let mut check_res = |res: &Response| { - assert_eq!(res.payload.len(), 1, "Wrong payload size"); - match &res.payload[0] { - ResponseData::Ok => got_ok = true, - ResponseData::ProcDone { success, .. 
} => { - assert!(!success, "Process should not have completed successfully"); - got_done = true; - } - x => panic!("Unexpected response: {:?}", x), - } - }; - - check_res(&res1); - check_res(&res2); - assert!(got_ok, "Missing ok response"); - assert!(got_done, "Missing done response"); - } - - #[tokio::test] - async fn proc_stdin_should_send_error_on_failure() { - let (conn_id, state, tx, mut rx) = setup(1); - - // Send stdin to a non-existent process - let req = Request::new( - "test-tenant", - vec![RequestData::ProcStdin { - id: 0xDEADBEEF, - data: b"some input".to_vec(), - }], - ); - - process(conn_id, Arc::clone(&state), req, tx.clone()) - .await - .unwrap(); - - let res = rx.recv().await.unwrap(); - assert_eq!(res.payload.len(), 1, "Wrong payload size"); - - // Verify that we get an error - assert!( - matches!(res.payload[0], ResponseData::Error(_)), - "Unexpected response: {:?}", - res.payload[0] - ); - } - - // NOTE: Ignoring on windows because it's using WSL which wants a Linux path - // with / but thinks it's on windows and is providing \ - #[tokio::test] - #[cfg_attr(windows, ignore)] - async fn proc_stdin_should_send_ok_on_success_and_properly_send_stdin_to_process() { - let (conn_id, state, tx, mut rx) = setup(1); - - // First, run a program that listens for stdin - let req = Request::new( - "test-tenant", - vec![RequestData::ProcSpawn { - cmd: SCRIPT_RUNNER.to_string(), - args: vec![ECHO_STDIN_TO_STDOUT_SH.to_str().unwrap().to_string()], - persist: false, - pty: None, - }], - ); - - process(conn_id, Arc::clone(&state), req, tx.clone()) - .await - .unwrap(); - - let res = rx.recv().await.unwrap(); - assert_eq!(res.payload.len(), 1, "Wrong payload size"); - - // Second, grab the id of the started process - let id = match &res.payload[0] { - ResponseData::ProcSpawned { id } => *id, - x => panic!("Unexpected response: {:?}", x), - }; - - // Third, send stdin to the remote process - let req = Request::new( - "test-tenant", - vec![RequestData::ProcStdin { - 
id, - data: b"hello world\n".to_vec(), - }], - ); - - // NOTE: We cannot let the state get dropped as it results in killing - // the child process; so, we clone another reference here - process(conn_id, Arc::clone(&state), req, tx).await.unwrap(); - - // Fourth, gather two responses: - // - // 1. A direct response to processing the stdin - // 2. An indirect response that is stdout from echoing our stdin - // - // Note that order is not a guarantee, so we have to check that - // we get one of each type of response - let res1 = rx.recv().await.expect("Missing first response"); - let res2 = rx.recv().await.expect("Missing second response"); - - let mut got_ok = false; - let mut got_stdout = false; - - let mut check_res = |res: &Response| { - assert_eq!(res.payload.len(), 1, "Wrong payload size"); - match &res.payload[0] { - ResponseData::Ok => got_ok = true, - ResponseData::ProcStdout { data, .. } => { - assert_eq!(data, b"hello world\n", "Mirrored data didn't match"); - got_stdout = true; - } - x => panic!("Unexpected response: {:?}", x), - } - }; - - check_res(&res1); - check_res(&res2); - assert!(got_ok, "Missing ok response"); - assert!(got_stdout, "Missing mirrored stdin response"); - } - - #[tokio::test] - async fn proc_list_should_send_proc_entry_list() { - let (conn_id, state, tx, mut rx) = setup(1); - - // Run a process and get the list that includes that process - // at the same time (using sleep of 1 second) - let req = Request::new( - "test-tenant", - vec![ - RequestData::ProcSpawn { - cmd: SCRIPT_RUNNER.to_string(), - args: vec![SLEEP_SH.to_str().unwrap().to_string(), String::from("1")], - persist: false, - pty: None, - }, - RequestData::ProcList {}, - ], - ); - - process(conn_id, state, req, tx).await.unwrap(); - - let res = rx.recv().await.unwrap(); - assert_eq!(res.payload.len(), 2, "Wrong payload size"); - - // Grab the id of the started process - let id = match &res.payload[0] { - ResponseData::ProcSpawned { id } => *id, - x => panic!("Unexpected 
response: {:?}", x), - }; - - // Verify our process shows up in our entry list - assert_eq!( - res.payload[1], - ResponseData::ProcEntries { - entries: vec![RunningProcess { - cmd: SCRIPT_RUNNER.to_string(), - args: vec![SLEEP_SH.to_str().unwrap().to_string(), String::from("1")], - persist: false, - pty: None, - id, - }], - }, - "Unexpected response: {:?}", - res.payload[0] - ); - } - - #[tokio::test] - async fn system_info_should_send_system_info_based_on_binary() { - let (conn_id, state, tx, mut rx) = setup(1); - - let req = Request::new("test-tenant", vec![RequestData::SystemInfo {}]); - - process(conn_id, state, req, tx).await.unwrap(); - - let res = rx.recv().await.unwrap(); - assert_eq!(res.payload.len(), 1, "Wrong payload size"); - assert_eq!( - res.payload[0], - ResponseData::SystemInfo(SystemInfo { - family: env::consts::FAMILY.to_string(), - os: env::consts::OS.to_string(), - arch: env::consts::ARCH.to_string(), - current_dir: env::current_dir().unwrap_or_default(), - main_separator: std::path::MAIN_SEPARATOR, - }), - "Unexpected response: {:?}", - res.payload[0] - ); - } -} diff --git a/distant-core/src/server/distant/mod.rs b/distant-core/src/server/distant/mod.rs deleted file mode 100644 index f855a07..0000000 --- a/distant-core/src/server/distant/mod.rs +++ /dev/null @@ -1,362 +0,0 @@ -mod handler; -mod process; -mod state; - -pub(crate) use process::{InputChannel, ProcessKiller, ProcessPty}; -use state::State; - -use crate::{ - constants::MAX_MSG_CAPACITY, - data::{Request, Response}, - net::{Codec, DataStream, Transport, TransportListener, TransportReadHalf, TransportWriteHalf}, - server::{ - utils::{ConnTracker, ShutdownTask}, - PortRange, - }, -}; -use futures::stream::{Stream, StreamExt}; -use log::*; -use std::{net::IpAddr, sync::Arc}; -use tokio::{ - io::{self, AsyncRead, AsyncWrite}, - net::TcpListener, - sync::{mpsc, Mutex}, - task::{JoinError, JoinHandle}, - time::Duration, -}; - -/// Represents a server that listens for requests, processes 
them, and sends responses -pub struct DistantServer { - conn_task: JoinHandle<()>, -} - -#[derive(Clone, Debug, PartialEq, Eq)] -pub struct DistantServerOptions { - pub shutdown_after: Option, - pub max_msg_capacity: usize, -} - -impl Default for DistantServerOptions { - fn default() -> Self { - Self { - shutdown_after: None, - max_msg_capacity: MAX_MSG_CAPACITY, - } - } -} - -impl DistantServer { - /// Bind to an IP address and port from the given range, taking an optional shutdown duration - /// that will shutdown the server if there is no active connection after duration - pub async fn bind( - addr: IpAddr, - port: PortRange, - codec: U, - opts: DistantServerOptions, - ) -> io::Result<(Self, u16)> - where - U: Codec + Send + 'static, - { - debug!("Binding to {} in range {}", addr, port); - let listener = TcpListener::bind(port.make_socket_addrs(addr).as_slice()).await?; - - let port = listener.local_addr()?.port(); - debug!("Bound to port: {}", port); - - let stream = TransportListener::initialize(listener, move |stream| { - Transport::new(stream, codec.clone()) - }) - .into_stream(); - - Ok((Self::initialize(Box::pin(stream), opts), port)) - } - - /// Initialize a distant server using the provided listener - pub fn initialize(stream: S, opts: DistantServerOptions) -> Self - where - T: DataStream + Send + 'static, - U: Codec + Send + 'static, - S: Stream> + Send + Unpin + 'static, - { - // Build our state for the server - let state: Arc> = Arc::new(Mutex::new(State::default())); - let (shutdown, tracker) = ShutdownTask::maybe_initialize(opts.shutdown_after); - - // Spawn our connection task - let conn_task = tokio::spawn(async move { - connection_loop(stream, state, tracker, shutdown, opts.max_msg_capacity).await - }); - - Self { conn_task } - } - - /// Waits for the server to terminate - pub async fn wait(self) -> Result<(), JoinError> { - self.conn_task.await - } - - /// Aborts the server by aborting the internal task handling new connections - pub fn 
abort(&self) { - self.conn_task.abort(); - } -} - -async fn connection_loop( - mut stream: S, - state: Arc>, - tracker: Option>>, - shutdown: Option, - max_msg_capacity: usize, -) where - T: DataStream + Send + 'static, - U: Codec + Send + 'static, - S: Stream> + Send + Unpin + 'static, -{ - let inner = async move { - loop { - match stream.next().await { - Some(transport) => { - let conn_id = rand::random(); - debug!( - " Established against {}", - conn_id, - transport.to_connection_tag() - ); - if let Err(x) = on_new_conn( - transport, - conn_id, - Arc::clone(&state), - tracker.as_ref().map(Arc::clone), - max_msg_capacity, - ) - .await - { - error!(" Failed handshake: {}", conn_id, x); - } - } - None => { - info!("Listener shutting down"); - break; - } - }; - } - }; - - match shutdown { - Some(shutdown) => tokio::select! { - _ = inner => {} - _ = shutdown => { - warn!("Reached shutdown timeout, so terminating"); - } - }, - None => inner.await, - } -} - -/// Processes a new connection, performing a handshake, and then spawning two tasks to handle -/// input and output, returning join handles for the input and output tasks respectively -async fn on_new_conn( - transport: Transport, - conn_id: usize, - state: Arc>, - tracker: Option>>, - max_msg_capacity: usize, -) -> io::Result> -where - T: DataStream, - U: Codec + Send + 'static, -{ - // Update our tracker to reflect the new connection - if let Some(ct) = tracker.as_ref() { - ct.lock().await.increment(); - } - - // Split the transport into read and write halves so we can handle input - // and output concurrently - let (t_read, t_write) = transport.into_split(); - let (tx, rx) = mpsc::channel(max_msg_capacity); - - // Spawn a new task that loops to handle requests from the client - let state_2 = Arc::clone(&state); - let req_task = tokio::spawn(async move { - request_loop(conn_id, state_2, t_read, tx).await; - }); - - // Spawn a new task that loops to handle responses to the client - let res_task = 
tokio::spawn(async move { response_loop(conn_id, t_write, rx).await }); - - // Spawn cleanup task that waits on our req & res tasks to complete - let cleanup_task = tokio::spawn(async move { - // Wait for both receiving and sending tasks to complete before marking - // the connection as complete - let _ = tokio::join!(req_task, res_task); - - state.lock().await.cleanup_connection(conn_id).await; - if let Some(ct) = tracker.as_ref() { - ct.lock().await.decrement(); - } - }); - - Ok(cleanup_task) -} - -/// Repeatedly reads in new requests, processes them, and sends their responses to the -/// response loop -async fn request_loop( - conn_id: usize, - state: Arc>, - mut transport: TransportReadHalf, - tx: mpsc::Sender, -) where - T: AsyncRead + Send + Unpin + 'static, - U: Codec, -{ - loop { - match transport.receive::().await { - Ok(Some(req)) => { - debug!( - " Received request of type{} {}", - conn_id, - if req.payload.len() > 1 { "s" } else { "" }, - req.to_payload_type_string() - ); - - if let Err(x) = handler::process(conn_id, Arc::clone(&state), req, tx.clone()).await - { - error!(" {}", conn_id, x); - break; - } - } - Ok(None) => { - trace!(" Input from connection closed", conn_id); - break; - } - Err(x) => { - error!(" {}", conn_id, x); - break; - } - } - } - - // Properly close off any associated process' stdin given that we can't get new - // requests to send more stdin to them - state.lock().await.close_stdin_for_connection(conn_id); -} - -/// Repeatedly sends responses out over the wire -async fn response_loop( - conn_id: usize, - mut transport: TransportWriteHalf, - mut rx: mpsc::Receiver, -) where - T: AsyncWrite + Send + Unpin + 'static, - U: Codec, -{ - while let Some(res) = rx.recv().await { - debug!( - " Sending response of type{} {}", - conn_id, - if res.payload.len() > 1 { "s" } else { "" }, - res.to_payload_type_string() - ); - - if let Err(x) = transport.send(res).await { - error!(" {}", conn_id, x); - break; - } - } - - trace!(" Output to 
connection closed", conn_id); -} - -#[cfg(test)] -mod tests { - use super::*; - use crate::{ - data::{RequestData, ResponseData}, - net::{InmemoryStream, PlainCodec}, - }; - use std::pin::Pin; - - #[allow(clippy::type_complexity)] - fn make_transport_stream() -> ( - mpsc::Sender>, - Pin> + Send>>, - ) { - let (tx, rx) = mpsc::channel::>(1); - let stream = futures::stream::unfold(rx, |mut rx| async move { - rx.recv().await.map(move |transport| (transport, rx)) - }); - (tx, Box::pin(stream)) - } - - #[tokio::test] - async fn wait_should_return_ok_when_all_inner_tasks_complete() { - let (tx, stream) = make_transport_stream(); - - let server = DistantServer::initialize(stream, Default::default()); - - // Conclude all server tasks by closing out the listener - drop(tx); - - let result = server.wait().await; - assert!(result.is_ok(), "Unexpected result: {:?}", result); - } - - #[tokio::test] - async fn wait_should_return_error_when_server_aborted() { - let (_tx, stream) = make_transport_stream(); - - let server = DistantServer::initialize(stream, Default::default()); - server.abort(); - - match server.wait().await { - Err(x) if x.is_cancelled() => {} - x => panic!("Unexpected result: {:?}", x), - } - } - - #[tokio::test] - async fn server_should_receive_requests_and_send_responses_to_appropriate_connections() { - let (tx, stream) = make_transport_stream(); - - let _server = DistantServer::initialize(stream, Default::default()); - - // Send over a "connection" - let (mut t1, t2) = Transport::make_pair(); - tx.send(t2).await.unwrap(); - - // Send a request - t1.send(Request::new( - "test-tenant", - vec![RequestData::SystemInfo {}], - )) - .await - .unwrap(); - - // Get a response - let res = t1.receive::().await.unwrap().unwrap(); - assert!(res.payload.len() == 1, "Unexpected payload size"); - assert!( - matches!(res.payload[0], ResponseData::SystemInfo { .. 
}), - "Unexpected response: {:?}", - res.payload[0] - ); - } - - #[tokio::test] - async fn server_should_shutdown_if_no_connections_after_shutdown_duration() { - let (_tx, stream) = make_transport_stream(); - - let server = DistantServer::initialize( - stream, - DistantServerOptions { - shutdown_after: Some(Duration::from_millis(50)), - max_msg_capacity: 1, - }, - ); - - let result = server.wait().await; - assert!(result.is_ok(), "Unexpected result: {:?}", result); - } -} diff --git a/distant-core/src/server/distant/state.rs b/distant-core/src/server/distant/state.rs deleted file mode 100644 index a0a45c8..0000000 --- a/distant-core/src/server/distant/state.rs +++ /dev/null @@ -1,232 +0,0 @@ -use super::{InputChannel, ProcessKiller, ProcessPty}; -use crate::data::{ChangeKindSet, ResponseData}; -use log::*; -use notify::RecommendedWatcher; -use std::{ - collections::HashMap, - future::Future, - hash::{Hash, Hasher}, - io, - ops::{Deref, DerefMut}, - path::{Path, PathBuf}, - pin::Pin, -}; - -pub type ReplyFn = Box) -> ReplyRet + Send + 'static>; -pub type ReplyRet = Pin + Send + 'static>>; - -/// Holds state related to multiple connections managed by a server -#[derive(Default)] -pub struct State { - /// Map of all processes running on the server - pub processes: HashMap, - - /// List of processes that will be killed when a connection drops - client_processes: HashMap>, - - /// Watcher used for filesystem events - pub watcher: Option, - - /// Mapping of Path -> (Reply Fn, recursive) for watcher notifications - pub watcher_paths: HashMap, -} - -#[derive(Clone, Debug)] -pub struct WatcherPath { - /// The raw path provided to the watcher, which is not canonicalized - raw_path: PathBuf, - - /// The canonicalized path at the time of providing to the watcher, - /// as all paths must exist for a watcher, we use this to get the - /// source of truth when watching - path: PathBuf, - - /// Whether or not the path was set to be recursive - recursive: bool, - - /// Specific 
filter for path - only: ChangeKindSet, -} - -impl PartialEq for WatcherPath { - fn eq(&self, other: &Self) -> bool { - self.path == other.path - } -} - -impl Eq for WatcherPath {} - -impl Hash for WatcherPath { - fn hash(&self, state: &mut H) { - self.path.hash(state); - } -} - -impl Deref for WatcherPath { - type Target = PathBuf; - - fn deref(&self) -> &Self::Target { - &self.path - } -} - -impl DerefMut for WatcherPath { - fn deref_mut(&mut self) -> &mut Self::Target { - &mut self.path - } -} - -impl WatcherPath { - /// Create a new watcher path using the given path and canonicalizing it - pub fn new( - path: impl Into, - recursive: bool, - only: impl Into, - ) -> io::Result { - let raw_path = path.into(); - let path = raw_path.canonicalize()?; - let only = only.into(); - Ok(Self { - raw_path, - path, - recursive, - only, - }) - } - - pub fn raw_path(&self) -> &Path { - self.raw_path.as_path() - } - - pub fn path(&self) -> &Path { - self.path.as_path() - } - - /// Returns true if this watcher path applies to the given path. 
- /// This is accomplished by checking if the path is contained - /// within either the raw or canonicalized path of the watcher - /// and ensures that recursion rules are respected - pub fn applies_to_path(&self, path: &Path) -> bool { - let check_path = |path: &Path| -> bool { - let cnt = path.components().count(); - - // 0 means exact match from strip_prefix - // 1 means that it was within immediate directory (fine for non-recursive) - // 2+ means it needs to be recursive - cnt < 2 || self.recursive - }; - - match ( - path.strip_prefix(self.path()), - path.strip_prefix(self.raw_path()), - ) { - (Ok(p1), Ok(p2)) => check_path(p1) || check_path(p2), - (Ok(p), Err(_)) => check_path(p), - (Err(_), Ok(p)) => check_path(p), - (Err(_), Err(_)) => false, - } - } -} - -/// Holds information related to a spawned process on the server -pub struct ProcessState { - pub cmd: String, - pub args: Vec, - pub persist: bool, - - pub id: usize, - pub stdin: Option>, - pub killer: Box, - pub pty: Box, -} - -impl State { - pub fn map_paths_to_watcher_paths_and_replies<'a>( - &mut self, - paths: &'a [PathBuf], - ) -> Vec<(Vec<&'a PathBuf>, &ChangeKindSet, &mut ReplyFn)> { - let mut results = Vec::new(); - - for (wp, reply) in self.watcher_paths.iter_mut() { - let mut wp_paths = Vec::new(); - for path in paths { - if wp.applies_to_path(path) { - wp_paths.push(path); - } - } - if !wp_paths.is_empty() { - results.push((wp_paths, &wp.only, reply)); - } - } - - results - } - - /// Pushes a new process associated with a connection - pub fn push_process_state(&mut self, conn_id: usize, process_state: ProcessState) { - self.client_processes - .entry(conn_id) - .or_insert_with(Vec::new) - .push(process_state.id); - self.processes.insert(process_state.id, process_state); - } - - /// Removes a process associated with a connection - pub fn remove_process(&mut self, conn_id: usize, proc_id: usize) { - self.client_processes.entry(conn_id).and_modify(|v| { - if let Some(pos) = v.iter().position(|x| 
*x == proc_id) { - v.remove(pos); - } - }); - self.processes.remove(&proc_id); - } - - /// Closes stdin for all processes associated with the connection - pub fn close_stdin_for_connection(&mut self, conn_id: usize) { - debug!(" Closing stdin to all processes", conn_id); - if let Some(ids) = self.client_processes.get(&conn_id) { - for id in ids { - if let Some(process) = self.processes.get_mut(id) { - trace!( - " Closing stdin for proc {}", - conn_id, - process.id - ); - - let _ = process.stdin.take(); - } - } - } - } - - /// Cleans up state associated with a particular connection - pub async fn cleanup_connection(&mut self, conn_id: usize) { - debug!(" Cleaning up state", conn_id); - if let Some(ids) = self.client_processes.remove(&conn_id) { - for id in ids { - if let Some(mut process) = self.processes.remove(&id) { - if !process.persist { - trace!( - " Requesting proc {} be killed", - conn_id, - process.id - ); - let pid = process.id; - if let Err(x) = process.killer.kill().await { - error!( - "Conn {} failed to send process {} kill signal: {}", - id, pid, x - ); - } - } else { - trace!( - " Proc {} is persistent and will not be killed", - conn_id, - process.id - ); - } - } - } - } - } -} diff --git a/distant-core/src/server/mod.rs b/distant-core/src/server/mod.rs deleted file mode 100644 index 6cce802..0000000 --- a/distant-core/src/server/mod.rs +++ /dev/null @@ -1,8 +0,0 @@ -mod distant; -mod port; -mod relay; -mod utils; - -pub use self::distant::{DistantServer, DistantServerOptions}; -pub use port::PortRange; -pub use relay::RelayServer; diff --git a/distant-core/src/server/relay.rs b/distant-core/src/server/relay.rs deleted file mode 100644 index 770160f..0000000 --- a/distant-core/src/server/relay.rs +++ /dev/null @@ -1,382 +0,0 @@ -use crate::{ - client::{Session, SessionChannel}, - data::{Request, RequestData, ResponseData}, - net::{Codec, DataStream, Transport}, - server::utils::{ConnTracker, ShutdownTask}, -}; -use futures::stream::{Stream, 
StreamExt}; -use log::*; -use std::{collections::HashMap, marker::Unpin, sync::Arc}; -use tokio::{ - io, - sync::{oneshot, Mutex}, - task::{JoinError, JoinHandle}, - time::Duration, -}; - -/// Represents a server that relays requests & responses between connections and the -/// actual server -pub struct RelayServer { - accept_task: JoinHandle<()>, - conns: Arc>>, -} - -impl RelayServer { - pub fn initialize( - session: Session, - mut stream: S, - shutdown_after: Option, - ) -> io::Result - where - T: DataStream + Send + 'static, - U: Codec + Send + 'static, - S: Stream> + Send + Unpin + 'static, - { - let conns: Arc>> = Arc::new(Mutex::new(HashMap::new())); - - let (shutdown, tracker) = ShutdownTask::maybe_initialize(shutdown_after); - let conns_2 = Arc::clone(&conns); - let accept_task = tokio::spawn(async move { - let inner = async move { - loop { - let channel = session.clone_channel(); - match stream.next().await { - Some(transport) => { - let result = Conn::initialize( - transport, - channel, - tracker.as_ref().map(Arc::clone), - ) - .await; - - match result { - Ok(conn) => { - conns_2.lock().await.insert(conn.id(), conn); - } - Err(x) => { - error!("Failed to initialize connection: {}", x); - } - }; - } - None => { - info!("Listener shutting down"); - break; - } - }; - } - }; - - match shutdown { - Some(shutdown) => tokio::select! 
{ - _ = inner => {} - _ = shutdown => { - warn!("Reached shutdown timeout, so terminating"); - } - }, - None => inner.await, - } - }); - - Ok(Self { accept_task, conns }) - } - - /// Waits for the server to terminate - pub async fn wait(self) -> Result<(), JoinError> { - self.accept_task.await - } - - /// Aborts the server by aborting the internal tasks and current connections - pub async fn abort(&self) { - self.accept_task.abort(); - self.conns - .lock() - .await - .values() - .for_each(|conn| conn.abort()); - } -} - -struct Conn { - id: usize, - conn_task: JoinHandle<()>, -} - -impl Conn { - pub async fn initialize( - transport: Transport, - channel: SessionChannel, - ct: Option>>, - ) -> io::Result - where - T: DataStream + 'static, - U: Codec + Send + 'static, - { - // Create a unique id to associate with the connection since its address - // is not guaranteed to have an identifiable string - let id: usize = rand::random(); - - // Mark that we have a new connection - if let Some(ct) = ct.as_ref() { - ct.lock().await.increment(); - } - - let conn_task = spawn_conn_handler(id, transport, channel, ct).await; - - Ok(Self { id, conn_task }) - } - - /// Id associated with the connection - pub fn id(&self) -> usize { - self.id - } - - /// Aborts the connection from the server side - pub fn abort(&self) { - self.conn_task.abort(); - } -} - -async fn spawn_conn_handler( - conn_id: usize, - transport: Transport, - mut channel: SessionChannel, - ct: Option>>, -) -> JoinHandle<()> -where - T: DataStream, - U: Codec + Send + 'static, -{ - let (mut t_reader, t_writer) = transport.into_split(); - let processes = Arc::new(Mutex::new(Vec::new())); - let t_writer = Arc::new(Mutex::new(t_writer)); - - let (done_tx, done_rx) = oneshot::channel(); - let mut channel_2 = channel.clone(); - let processes_2 = Arc::clone(&processes); - let task = tokio::spawn(async move { - loop { - if channel_2.is_closed() { - break; - } - - // For each request, forward it through the session and 
monitor all responses - match t_reader.receive::().await { - Ok(Some(req)) => match channel_2.mail(req).await { - Ok(mut mailbox) => { - let processes = Arc::clone(&processes_2); - let t_writer = Arc::clone(&t_writer); - tokio::spawn(async move { - while let Some(res) = mailbox.next().await { - // Keep track of processes that are started so we can kill them - // when we're done - { - let mut p_lock = processes.lock().await; - for data in res.payload.iter() { - if let ResponseData::ProcSpawned { id } = *data { - p_lock.push(id); - } - } - } - - if let Err(x) = t_writer.lock().await.send(res).await { - error!( - " Failed to send response back: {}", - conn_id, x - ); - } - } - }); - } - Err(x) => error!( - " Failed to pass along request received on unix socket: {:?}", - conn_id, x - ), - }, - Ok(None) => break, - Err(x) => { - error!( - " Failed to receive request from unix stream: {:?}", - conn_id, x - ); - break; - } - } - } - - let _ = done_tx.send(()); - }); - - // Perform cleanup if done by sending a request to kill each running process - tokio::spawn(async move { - let _ = done_rx.await; - - let p_lock = processes.lock().await; - if !p_lock.is_empty() { - trace!( - "Cleaning conn {} :: killing {} process", - conn_id, - p_lock.len() - ); - if let Err(x) = channel - .fire(Request::new( - "relay", - p_lock - .iter() - .map(|id| RequestData::ProcKill { id: *id }) - .collect(), - )) - .await - { - error!(" Failed to send kill signals: {}", conn_id, x); - } - } - - if let Some(ct) = ct.as_ref() { - ct.lock().await.decrement(); - } - debug!(" Disconnected", conn_id); - }); - - task -} - -#[cfg(test)] -mod tests { - use super::*; - use crate::{ - data::Response, - net::{InmemoryStream, PlainCodec}, - }; - use std::{pin::Pin, time::Duration}; - use tokio::sync::mpsc; - - fn make_session() -> (Transport, Session) { - let (t1, t2) = Transport::make_pair(); - (t1, Session::initialize(t2).unwrap()) - } - - #[allow(clippy::type_complexity)] - fn make_transport_stream() -> ( - 
mpsc::Sender>, - Pin> + Send>>, - ) { - let (tx, rx) = mpsc::channel::>(1); - let stream = futures::stream::unfold(rx, |mut rx| async move { - rx.recv().await.map(move |transport| (transport, rx)) - }); - (tx, Box::pin(stream)) - } - - #[tokio::test] - async fn wait_should_return_ok_when_all_inner_tasks_complete() { - let (transport, session) = make_session(); - let (tx, stream) = make_transport_stream(); - let server = RelayServer::initialize(session, stream, None).unwrap(); - - // Conclude all server tasks by closing out the listener & session - drop(transport); - drop(tx); - - let result = server.wait().await; - assert!(result.is_ok(), "Unexpected result: {:?}", result); - } - - #[tokio::test] - async fn wait_should_return_error_when_server_aborted() { - let (_transport, session) = make_session(); - let (_tx, stream) = make_transport_stream(); - let server = RelayServer::initialize(session, stream, None).unwrap(); - server.abort().await; - - match server.wait().await { - Err(x) if x.is_cancelled() => {} - x => panic!("Unexpected result: {:?}", x), - } - } - - #[tokio::test] - async fn server_should_forward_requests_using_session() { - let (mut transport, session) = make_session(); - let (tx, stream) = make_transport_stream(); - let _server = RelayServer::initialize(session, stream, None).unwrap(); - - // Send over a "connection" - let (mut t1, t2) = Transport::make_pair(); - tx.send(t2).await.unwrap(); - - // Send a request - let req = Request::new("test-tenant", vec![RequestData::SystemInfo {}]); - t1.send(req.clone()).await.unwrap(); - - // Verify the request is forwarded out via session - let outbound_req = transport.receive().await.unwrap().unwrap(); - assert_eq!(req, outbound_req); - } - - #[tokio::test] - async fn server_should_send_back_response_with_tenant_matching_connection() { - let (mut transport, session) = make_session(); - let (tx, stream) = make_transport_stream(); - let _server = RelayServer::initialize(session, stream, None).unwrap(); - - // 
Send over a "connection" - let (mut t1, t2) = Transport::make_pair(); - tx.send(t2).await.unwrap(); - - // Send over a second "connection" - let (mut t2, t3) = Transport::make_pair(); - tx.send(t3).await.unwrap(); - - // Send a request to mark the tenant of the first connection - t1.send(Request::new( - "test-tenant-1", - vec![RequestData::SystemInfo {}], - )) - .await - .unwrap(); - - // Send a request to mark the tenant of the second connection - t2.send(Request::new( - "test-tenant-2", - vec![RequestData::SystemInfo {}], - )) - .await - .unwrap(); - - // Clear out the transport channel (outbound of session) - // NOTE: Because our test stream uses a buffer size of 1, we have to clear out the - // outbound data from the earlier requests before we can send back a response - let req_1 = transport.receive::().await.unwrap().unwrap(); - let req_2 = transport.receive::().await.unwrap().unwrap(); - let origin_id = if req_1.tenant == "test-tenant-2" { - req_1.id - } else { - req_2.id - }; - - // Send a response back to a singular connection based on the tenant - let res = Response::new("test-tenant-2", origin_id, vec![ResponseData::Ok]); - transport.send(res.clone()).await.unwrap(); - - // Verify that response is only received by a singular connection - let inbound_res = t2.receive().await.unwrap().unwrap(); - assert_eq!(res, inbound_res); - - let no_inbound = tokio::select! 
{ - _ = t1.receive::() => {false} - _ = tokio::time::sleep(Duration::from_millis(50)) => {true} - }; - assert!(no_inbound, "Unexpectedly got response for wrong connection"); - } - - #[tokio::test] - async fn server_should_shutdown_if_no_connections_after_shutdown_duration() { - let (_transport, session) = make_session(); - let (_tx, stream) = make_transport_stream(); - let server = - RelayServer::initialize(session, stream, Some(Duration::from_millis(50))).unwrap(); - - let result = server.wait().await; - assert!(result.is_ok(), "Unexpected result: {:?}", result); - } -} diff --git a/distant-core/src/server/utils.rs b/distant-core/src/server/utils.rs deleted file mode 100644 index 76fcc19..0000000 --- a/distant-core/src/server/utils.rs +++ /dev/null @@ -1,289 +0,0 @@ -use log::*; -use std::{ - future::Future, - pin::Pin, - sync::Arc, - task::{Context, Poll}, - time::Duration, -}; -use tokio::{ - sync::Mutex, - task::{JoinError, JoinHandle}, - time::{self, Instant}, -}; - -/// Task to keep track of a possible server shutdown based on connections -pub struct ShutdownTask { - task: JoinHandle<()>, - tracker: Arc>, -} - -impl Future for ShutdownTask { - type Output = Result<(), JoinError>; - - fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - Pin::new(&mut self.task).poll(cx) - } -} - -impl ShutdownTask { - /// Given an optional timeout, will either create the shutdown task or not, - /// returning an optional shutdown task alongside an optional connection tracker - pub fn maybe_initialize( - duration: Option, - ) -> (Option, Option>>) { - match duration { - Some(duration) => { - let task = Self::initialize(duration); - let tracker = task.tracker(); - (Some(task), Some(tracker)) - } - None => (None, None), - } - } - - /// Spawns a new task that continues to monitor the time since a - /// connection on the server existed, reporting a shutdown to all listeners - /// once the timeout is exceeded - pub fn initialize(duration: Duration) -> Self { - let 
tracker = Arc::new(Mutex::new(ConnTracker::new())); - - let tracker_2 = Arc::clone(&tracker); - let task = tokio::spawn(async move { - loop { - // Get the time since the last connection joined/left - let (base_time, cnt) = tracker_2.lock().await.time_and_cnt(); - - // If we have no connections left, we want to wait - // until the remaining period has passed and then - // verify that we still have no connections - if cnt == 0 { - // Get the time we should wait based on when the last connection - // was dropped; this closes the gap in the case where we start - // sometime later than exactly duration since the last check - let next_time = base_time + duration; - let wait_duration = next_time - .checked_duration_since(Instant::now()) - .unwrap_or_default() - + Duration::from_millis(1); - - // Wait until we've reached our desired duration since the - // last connection was dropped - time::sleep(wait_duration).await; - - // If we do have a connection at this point, don't exit - if !tracker_2.lock().await.has_reached_timeout(duration) { - continue; - } - - // Otherwise, we now should exit, which we do by reporting - debug!( - "Shutdown time of {}s has been reached!", - duration.as_secs_f32() - ); - break; - } - - // Otherwise, we just wait the full duration as worst case - // we'll have waited just about the time desired if right - // after waiting starts the last connection is closed - time::sleep(duration).await; - } - }); - - Self { task, tracker } - } - - /// Produces a new copy of the connection tracker associated with the shutdown manager - pub fn tracker(&self) -> Arc> { - Arc::clone(&self.tracker) - } -} - -pub struct ConnTracker { - time: Instant, - cnt: usize, -} - -impl ConnTracker { - pub fn new() -> Self { - Self { - time: Instant::now(), - cnt: 0, - } - } - - pub fn increment(&mut self) { - self.time = Instant::now(); - self.cnt += 1; - } - - pub fn decrement(&mut self) { - if self.cnt > 0 { - self.time = Instant::now(); - self.cnt -= 1; - } - } - - fn 
time_and_cnt(&self) -> (Instant, usize) { - (self.time, self.cnt) - } - - fn has_reached_timeout(&self, duration: Duration) -> bool { - self.cnt == 0 && self.time.elapsed() >= duration - } -} - -#[cfg(test)] -mod tests { - use super::*; - use std::thread; - - #[tokio::test] - async fn shutdown_task_should_not_resolve_if_has_connection_regardless_of_time() { - let mut task = ShutdownTask::initialize(Duration::from_millis(10)); - task.tracker().lock().await.increment(); - assert!( - futures::poll!(&mut task).is_pending(), - "Shutdown task unexpectedly completed" - ); - - time::sleep(Duration::from_millis(50)).await; - - assert!( - futures::poll!(task).is_pending(), - "Shutdown task unexpectedly completed" - ); - } - - #[tokio::test] - async fn shutdown_task_should_resolve_if_no_connection_for_minimum_duration() { - let mut task = ShutdownTask::initialize(Duration::from_millis(10)); - assert!( - futures::poll!(&mut task).is_pending(), - "Shutdown task unexpectedly completed" - ); - - tokio::select! { - _ = task => {} - _ = time::sleep(Duration::from_secs(1)) => { - panic!("Shutdown task unexpectedly pending"); - } - } - } - - #[tokio::test] - async fn shutdown_task_should_resolve_if_no_connection_for_minimum_duration_after_connection_removed( - ) { - let mut task = ShutdownTask::initialize(Duration::from_millis(10)); - task.tracker().lock().await.increment(); - assert!( - futures::poll!(&mut task).is_pending(), - "Shutdown task unexpectedly completed" - ); - - time::sleep(Duration::from_millis(50)).await; - assert!( - futures::poll!(&mut task).is_pending(), - "Shutdown task unexpectedly completed" - ); - - task.tracker().lock().await.decrement(); - - tokio::select! 
{ - _ = task => {} - _ = time::sleep(Duration::from_secs(1)) => { - panic!("Shutdown task unexpectedly pending"); - } - } - } - - #[tokio::test] - async fn shutdown_task_should_not_resolve_before_minimum_duration() { - let mut task = ShutdownTask::initialize(Duration::from_millis(50)); - assert!( - futures::poll!(&mut task).is_pending(), - "Shutdown task unexpectedly completed" - ); - - time::sleep(Duration::from_millis(5)).await; - - assert!( - futures::poll!(task).is_pending(), - "Shutdown task unexpectedly completed" - ); - } - - #[test] - fn conn_tracker_should_update_time_when_incremented() { - let mut tracker = ConnTracker::new(); - let (old_time, cnt) = tracker.time_and_cnt(); - assert_eq!(cnt, 0); - - // Wait to ensure that the new time will be different - thread::sleep(Duration::from_millis(1)); - - tracker.increment(); - let (new_time, cnt) = tracker.time_and_cnt(); - assert_eq!(cnt, 1); - assert!(new_time > old_time); - } - - #[test] - fn conn_tracker_should_update_time_when_decremented() { - let mut tracker = ConnTracker::new(); - tracker.increment(); - - let (old_time, cnt) = tracker.time_and_cnt(); - assert_eq!(cnt, 1); - - // Wait to ensure that the new time will be different - thread::sleep(Duration::from_millis(1)); - - tracker.decrement(); - let (new_time, cnt) = tracker.time_and_cnt(); - assert_eq!(cnt, 0); - assert!(new_time > old_time); - } - - #[test] - fn conn_tracker_should_not_update_time_when_decremented_if_at_zero_already() { - let mut tracker = ConnTracker::new(); - let (old_time, cnt) = tracker.time_and_cnt(); - assert_eq!(cnt, 0); - - // Wait to ensure that the new time would be different if updated - thread::sleep(Duration::from_millis(1)); - - tracker.decrement(); - let (new_time, cnt) = tracker.time_and_cnt(); - assert_eq!(cnt, 0); - assert!(new_time == old_time); - } - - #[test] - fn conn_tracker_should_report_timeout_reached_when_time_has_elapsed_and_no_connections() { - let tracker = ConnTracker::new(); - let (_, cnt) = 
tracker.time_and_cnt(); - assert_eq!(cnt, 0); - - // Wait to ensure that the new time would be different if updated - thread::sleep(Duration::from_millis(1)); - - assert!(tracker.has_reached_timeout(Duration::from_millis(1))); - } - - #[test] - fn conn_tracker_should_not_report_timeout_reached_when_time_has_elapsed_but_has_connections() { - let mut tracker = ConnTracker::new(); - tracker.increment(); - - let (_, cnt) = tracker.time_and_cnt(); - assert_eq!(cnt, 1); - - // Wait to ensure that the new time would be different if updated - thread::sleep(Duration::from_millis(1)); - - assert!(!tracker.has_reached_timeout(Duration::from_millis(1))); - } -} diff --git a/distant-core/tests/manager_tests.rs b/distant-core/tests/manager_tests.rs new file mode 100644 index 0000000..ee8abc0 --- /dev/null +++ b/distant-core/tests/manager_tests.rs @@ -0,0 +1,96 @@ +use distant_core::{ + net::{FramedTransport, InmemoryTransport, IntoSplit, OneshotListener, PlainCodec}, + BoxedDistantReader, BoxedDistantWriter, Destination, DistantApiServer, DistantChannelExt, + DistantManager, DistantManagerClient, DistantManagerClientConfig, DistantManagerConfig, Extra, +}; +use std::io; + +/// Creates a client transport and server listener for our tests +/// that are connected together +async fn setup() -> ( + FramedTransport, + OneshotListener>, +) { + let (t1, t2) = InmemoryTransport::pair(100); + + let listener = OneshotListener::from_value(FramedTransport::new(t2, PlainCodec)); + let transport = FramedTransport::new(t1, PlainCodec); + (transport, listener) +} + +#[tokio::test] +async fn should_be_able_to_establish_a_single_connection_and_communicate() { + let (transport, listener) = setup().await; + + let config = DistantManagerConfig::default(); + let manager_ref = DistantManager::start(config, listener).expect("Failed to start manager"); + + // NOTE: To pass in a raw function, we HAVE to specify the types of the parameters manually, + // otherwise we get a compilation error about lifetime 
mismatches + manager_ref + .register_connect_handler("scheme", |_: &_, _: &_, _: &mut _| async { + use distant_core::net::ServerExt; + let (t1, t2) = FramedTransport::pair(100); + + // Spawn a server on one end + let _ = DistantApiServer::local() + .unwrap() + .start(OneshotListener::from_value(t2.into_split()))?; + + // Create a reader/writer pair on the other end + let (writer, reader) = t1.into_split(); + let writer: BoxedDistantWriter = Box::new(writer); + let reader: BoxedDistantReader = Box::new(reader); + Ok((writer, reader)) + }) + .await + .expect("Failed to register handler"); + + let config = DistantManagerClientConfig::with_empty_prompts(); + let mut client = + DistantManagerClient::new(config, transport).expect("Failed to connect to manager"); + + // Test establishing a connection to some remote server + let id = client + .connect( + "scheme://host".parse::().unwrap(), + "key=value".parse::().unwrap(), + ) + .await + .expect("Failed to connect to a remote server"); + + // Test retrieving list of connections + let list = client + .list() + .await + .expect("Failed to get list of connections"); + assert_eq!(list.len(), 1); + assert_eq!(list.get(&id).unwrap().to_string(), "scheme://host/"); + + // Test retrieving information + let info = client + .info(id) + .await + .expect("Failed to get info about connection"); + assert_eq!(info.id, id); + assert_eq!(info.destination.to_string(), "scheme://host/"); + assert_eq!(info.extra, "key=value".parse::().unwrap()); + + // Create a new channel and request some data + let mut channel = client + .open_channel(id) + .await + .expect("Failed to open channel"); + let _ = channel + .system_info() + .await + .expect("Failed to get system information"); + + // Test killing a connection + client.kill(id).await.expect("Failed to kill connection"); + + // Test getting an error to ensure that serialization of that data works, + // which we do by trying to access a connection that no longer exists + let err = 
client.info(id).await.unwrap_err(); + assert_eq!(err.kind(), io::ErrorKind::NotConnected); +} diff --git a/distant-core/tests/stress/distant/watch.rs b/distant-core/tests/stress/distant/watch.rs index f33a97a..a107204 100644 --- a/distant-core/tests/stress/distant/watch.rs +++ b/distant-core/tests/stress/distant/watch.rs @@ -1,6 +1,6 @@ use crate::stress::fixtures::*; use assert_fs::prelude::*; -use distant_core::{ChangeKindSet, SessionChannelExt}; +use distant_core::{data::ChangeKindSet, DistantChannelExt}; use rstest::*; const MAX_FILES: usize = 500; @@ -8,11 +8,10 @@ const MAX_FILES: usize = 500; #[rstest] #[tokio::test] #[ignore] -async fn should_handle_large_volume_of_file_watching(#[future] ctx: DistantSessionCtx) { +async fn should_handle_large_volume_of_file_watching(#[future] ctx: DistantClientCtx) { let ctx = ctx.await; - let mut channel = ctx.session.clone_channel(); + let mut channel = ctx.client.clone_channel(); - let tenant = "watch-stress-test"; let root = assert_fs::TempDir::new().unwrap(); let mut files_and_watchers = Vec::new(); @@ -25,7 +24,6 @@ async fn should_handle_large_volume_of_file_watching(#[future] ctx: DistantSessi eprintln!("Watching {:?}", file.path()); let watcher = channel .watch( - tenant, file.path(), false, ChangeKindSet::modify_set(), diff --git a/distant-core/tests/stress/fixtures.rs b/distant-core/tests/stress/fixtures.rs index e64c415..5e8884a 100644 --- a/distant-core/tests/stress/fixtures.rs +++ b/distant-core/tests/stress/fixtures.rs @@ -1,17 +1,20 @@ use crate::stress::utils; -use distant_core::{DistantServer, SecretKey, SecretKey32, Session, XChaCha20Poly1305Codec}; +use distant_core::{DistantApiServer, DistantClient, LocalDistantApi}; +use distant_net::{ + PortRange, SecretKey, SecretKey32, TcpClientExt, TcpServerExt, XChaCha20Poly1305Codec, +}; use rstest::*; use std::time::Duration; use tokio::sync::mpsc; const LOG_PATH: &str = "/tmp/test.distant.server.log"; -pub struct DistantSessionCtx { - pub session: Session, 
+pub struct DistantClientCtx { + pub client: DistantClient, _done_tx: mpsc::Sender<()>, } -impl DistantSessionCtx { +impl DistantClientCtx { pub async fn initialize() -> Self { let ip_addr = "127.0.0.1".parse().unwrap(); let (done_tx, mut done_rx) = mpsc::channel::<()>(1); @@ -21,14 +24,21 @@ impl DistantSessionCtx { let logger = utils::init_logging(LOG_PATH); let key = SecretKey::default(); let codec = XChaCha20Poly1305Codec::from(key.clone()); - let (_server, port) = - DistantServer::bind(ip_addr, "0".parse().unwrap(), codec, Default::default()) - .await - .unwrap(); - started_tx.send((port, key)).await.unwrap(); + if let Ok(api) = LocalDistantApi::initialize() { + let port: PortRange = "0".parse().unwrap(); + let port = { + let server_ref = DistantApiServer::new(api) + .start(ip_addr, port, codec) + .await + .unwrap(); + server_ref.port() + }; + + started_tx.send((port, key)).await.unwrap(); + let _ = done_rx.recv().await; + } - let _ = done_rx.recv().await; logger.flush(); logger.shutdown(); }); @@ -36,8 +46,8 @@ impl DistantSessionCtx { // Extract our server startup data if we succeeded let (port, key) = started_rx.recv().await.unwrap(); - // Now initialize our session - let session = Session::tcp_connect_timeout( + // Now initialize our client + let client = DistantClient::connect_timeout( format!("{}:{}", ip_addr, port).parse().unwrap(), XChaCha20Poly1305Codec::from(key), Duration::from_secs(1), @@ -45,14 +55,14 @@ impl DistantSessionCtx { .await .unwrap(); - DistantSessionCtx { - session, + DistantClientCtx { + client, _done_tx: done_tx, } } } #[fixture] -pub async fn ctx() -> DistantSessionCtx { - DistantSessionCtx::initialize().await +pub async fn ctx() -> DistantClientCtx { + DistantClientCtx::initialize().await } diff --git a/distant-net/Cargo.toml b/distant-net/Cargo.toml new file mode 100644 index 0000000..3abefb7 --- /dev/null +++ b/distant-net/Cargo.toml @@ -0,0 +1,37 @@ +[package] +name = "distant-net" +description = "Network library for distant, 
providing implementations to support client/server architecture" +categories = ["network-programming"] +keywords = ["api", "async"] +version = "0.17.0" +authors = ["Chip Senkbeil "] +edition = "2021" +homepage = "https://github.com/chipsenkbeil/distant" +repository = "https://github.com/chipsenkbeil/distant" +readme = "README.md" +license = "MIT OR Apache-2.0" + +[dependencies] +async-trait = "0.1.56" +bytes = "1.1.0" +chacha20poly1305 = "=0.10.0-pre" +derive_more = { version = "0.99.17", default-features = false, features = ["as_mut", "as_ref", "deref", "deref_mut", "display", "from", "error", "into", "into_iterator", "is_variant", "try_into"] } +futures = "0.3.21" +hex = "0.4.3" +hkdf = "0.12.3" +log = "0.4.17" +paste = "1.0.7" +p256 = { version = "0.11.1", features = ["ecdh", "pem"] } +rand = { version = "0.8.4", features = ["getrandom"] } +rmp-serde = "1.1.0" +sha2 = "0.10.2" +serde = { version = "1.0.126", features = ["derive"] } +serde_bytes = "0.11.6" +tokio = { version = "1.12.0", features = ["full"] } +tokio-util = { version = "0.6.7", features = ["codec"] } + +# Optional dependencies based on features +schemars = { version = "0.8.10", optional = true } + +[dev-dependencies] +tempfile = "3" diff --git a/distant-net/README.md b/distant-net/README.md new file mode 100644 index 0000000..a25b73b --- /dev/null +++ b/distant-net/README.md @@ -0,0 +1,49 @@ +# distant net + +[![Crates.io][distant_crates_img]][distant_crates_lnk] [![Docs.rs][distant_doc_img]][distant_doc_lnk] [![Rustc 1.61.0][distant_rustc_img]][distant_rustc_lnk] + +[distant_crates_img]: https://img.shields.io/crates/v/distant-net.svg +[distant_crates_lnk]: https://crates.io/crates/distant-net +[distant_doc_img]: https://docs.rs/distant-net/badge.svg +[distant_doc_lnk]: https://docs.rs/distant-net +[distant_rustc_img]: https://img.shields.io/badge/distant_net-rustc_1.61+-lightgray.svg +[distant_rustc_lnk]: https://blog.rust-lang.org/2022/05/19/Rust-1.61.0.html + +Library that powers the 
[`distant`](https://github.com/chipsenkbeil/distant) +binary. + +🚧 **(Alpha stage software) This library is in rapid development and may break or change frequently!** 🚧 + +## Details + +The `distant-net` library supplies the foundational networking functionality +for the distant interfaces and distant cli. + +## Installation + +You can import the dependency by adding the following to your `Cargo.toml`: + +```toml +[dependencies] +distant-net = "0.17" +``` + +## Features + +Currently, the library supports the following features: + +- `schemars`: derives the `schemars::JsonSchema` interface on `Request` + and `Response` data types + +By default, no features are enabled on the library. + +## License + +This project is licensed under either of + +Apache License, Version 2.0, (LICENSE-APACHE or +[apache-license][apache-license]) MIT license (LICENSE-MIT or +[mit-license][mit-license]) at your option. + +[apache-license]: http://www.apache.org/licenses/LICENSE-2.0 +[mit-license]: http://opensource.org/licenses/MIT diff --git a/distant-net/src/any.rs b/distant-net/src/any.rs new file mode 100644 index 0000000..a03bc61 --- /dev/null +++ b/distant-net/src/any.rs @@ -0,0 +1,29 @@ +use std::any::Any; + +/// Trait used for casting support into the [`Any`] trait object +pub trait AsAny: Any { + /// Converts reference to [`Any`] + fn as_any(&self) -> &dyn Any; + + /// Converts mutable reference to [`Any`] + fn as_mut_any(&mut self) -> &mut dyn Any; + + /// Consumes and produces `Box` + fn into_any(self: Box) -> Box; +} + +/// Blanket implementation that enables any `'static` reference to convert +/// to the [`Any`] type +impl AsAny for T { + fn as_any(&self) -> &dyn Any { + self + } + + fn as_mut_any(&mut self) -> &mut dyn Any { + self + } + + fn into_any(self: Box) -> Box { + self + } +} diff --git a/distant-net/src/auth.rs b/distant-net/src/auth.rs new file mode 100644 index 0000000..89f9b40 --- /dev/null +++ b/distant-net/src/auth.rs @@ -0,0 +1,122 @@ +use 
derive_more::Display; +use serde::{Deserialize, Serialize}; +use std::collections::HashMap; + +mod client; +pub use client::*; + +mod handshake; +pub use handshake::*; + +mod server; +pub use server::*; + +/// Represents authentication messages that can be sent over the wire +/// +/// NOTE: Must use serde's content attribute with the tag attribute. Just the tag attribute will +/// cause deserialization to fail +#[derive(Clone, Debug, Serialize, Deserialize)] +#[serde(rename_all = "snake_case", tag = "type", content = "data")] +pub enum Auth { + /// Represents a request to perform an authentication handshake, + /// providing the public key and salt from one side in order to + /// derive the shared key + #[serde(rename = "auth_handshake")] + Handshake { + /// Bytes of the public key + #[serde(with = "serde_bytes")] + public_key: PublicKeyBytes, + + /// Randomly generated salt + #[serde(with = "serde_bytes")] + salt: Salt, + }, + + /// Represents the bytes of an encrypted message + /// + /// Underneath, will be one of either [`AuthRequest`] or [`AuthResponse`] + #[serde(rename = "auth_msg")] + Msg { + #[serde(with = "serde_bytes")] + encrypted_payload: Vec, + }, +} + +/// Represents authentication messages that act as initiators such as providing +/// a challenge, verifying information, presenting information, or highlighting an error +#[derive(Clone, Debug, Serialize, Deserialize)] +#[serde(rename_all = "snake_case", tag = "type")] +pub enum AuthRequest { + /// Represents a challenge comprising a series of questions to be presented + Challenge { + questions: Vec, + extra: HashMap, + }, + + /// Represents an ask to verify some information + Verify { kind: AuthVerifyKind, text: String }, + + /// Represents some information to be presented + Info { text: String }, + + /// Represents some error that occurred + Error { kind: AuthErrorKind, text: String }, +} + +/// Represents authentication messages that are responses to auth requests such +/// as answers to challenges or 
verifying information +#[derive(Clone, Debug, Serialize, Deserialize)] +#[serde(rename_all = "snake_case", tag = "type")] +pub enum AuthResponse { + /// Represents the answers to a previously-asked challenge + Challenge { answers: Vec }, + + /// Represents the answer to a previously-asked verify + Verify { valid: bool }, +} + +/// Represents the type of verification being requested +#[derive(Copy, Clone, Debug, Display, PartialEq, Eq, Serialize, Deserialize)] +#[serde(rename_all = "snake_case")] +#[non_exhaustive] +pub enum AuthVerifyKind { + /// An ask to verify the host such as with SSH + #[display(fmt = "host")] + Host, +} + +/// Represents a single question in a challenge +#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)] +pub struct AuthQuestion { + /// The text of the question + pub text: String, + + /// Any extra information specific to a particular auth domain + /// such as including a username and instructions for SSH authentication + pub extra: HashMap, +} + +impl AuthQuestion { + /// Creates a new question without any extra data + pub fn new(text: impl Into) -> Self { + Self { + text: text.into(), + extra: HashMap::new(), + } + } +} + +/// Represents the type of error encountered during authentication +#[derive(Copy, Clone, Debug, Display, PartialEq, Eq, Serialize, Deserialize)] +#[serde(rename_all = "snake_case")] +pub enum AuthErrorKind { + /// When the answer(s) to a challenge do not pass authentication + FailedChallenge, + + /// When verification during authentication fails + /// (e.g. 
a host is not allowed or blocked) + FailedVerification, + + /// When the error is unknown + Unknown, +} diff --git a/distant-net/src/auth/client.rs b/distant-net/src/auth/client.rs new file mode 100644 index 0000000..3761d7c --- /dev/null +++ b/distant-net/src/auth/client.rs @@ -0,0 +1,817 @@ +use crate::{ + utils, Auth, AuthErrorKind, AuthQuestion, AuthRequest, AuthResponse, AuthVerifyKind, Client, + Codec, Handshake, XChaCha20Poly1305Codec, +}; +use bytes::BytesMut; +use log::*; +use std::{collections::HashMap, io}; + +pub struct AuthClient { + inner: Client, + codec: Option, + jit_handshake: bool, +} + +impl From> for AuthClient { + fn from(client: Client) -> Self { + Self { + inner: client, + codec: None, + jit_handshake: false, + } + } +} + +impl AuthClient { + /// Sends a request to the server to establish an encrypted connection + pub async fn handshake(&mut self) -> io::Result<()> { + let handshake = Handshake::default(); + + let response = self + .inner + .send(Auth::Handshake { + public_key: handshake.pk_bytes(), + salt: *handshake.salt(), + }) + .await?; + + match response.payload { + Auth::Handshake { public_key, salt } => { + let key = handshake.handshake(public_key, salt)?; + self.codec.replace(XChaCha20Poly1305Codec::new(&key)); + Ok(()) + } + Auth::Msg { .. 
} => Err(io::Error::new( + io::ErrorKind::Other, + "Got unexpected encrypted message during handshake", + )), + } + } + + /// Perform a handshake only if jit is enabled and no handshake has succeeded yet + async fn jit_handshake(&mut self) -> io::Result<()> { + if self.will_jit_handshake() && !self.is_ready() { + self.handshake().await + } else { + Ok(()) + } + } + + /// Returns true if client has successfully performed a handshake + /// and is ready to communicate with the server + pub fn is_ready(&self) -> bool { + self.codec.is_some() + } + + /// Returns true if this client will perform a handshake just-in-time (JIT) prior to making a + /// request in the scenario where the client has not already performed a handshake + #[inline] + pub fn will_jit_handshake(&self) -> bool { + self.jit_handshake + } + + /// Sets the jit flag on this client with `true` indicating that this client will perform a + /// handshake just-in-time (JIT) prior to making a request in the scenario where the client has + /// not already performed a handshake + #[inline] + pub fn set_jit_handshake(&mut self, flag: bool) { + self.jit_handshake = flag; + } + + /// Provides a challenge to the server and returns the answers to the questions + /// asked by the client + pub async fn challenge( + &mut self, + questions: Vec, + extra: HashMap, + ) -> io::Result> { + trace!( + "AuthClient::challenge(questions = {:?}, extra = {:?})", + questions, + extra + ); + + // Perform JIT handshake if enabled + self.jit_handshake().await?; + + let payload = AuthRequest::Challenge { questions, extra }; + let encrypted_payload = self.serialize_and_encrypt(&payload)?; + let response = self.inner.send(Auth::Msg { encrypted_payload }).await?; + + match response.payload { + Auth::Msg { encrypted_payload } => { + match self.decrypt_and_deserialize(&encrypted_payload)? { + AuthResponse::Challenge { answers } => Ok(answers), + AuthResponse::Verify { .. 
} => Err(io::Error::new( + io::ErrorKind::Other, + "Got unexpected verify response during challenge", + )), + } + } + Auth::Handshake { .. } => Err(io::Error::new( + io::ErrorKind::Other, + "Got unexpected handshake during challenge", + )), + } + } + + /// Provides a verification request to the server and returns whether or not + /// the server approved + pub async fn verify(&mut self, kind: AuthVerifyKind, text: String) -> io::Result { + trace!("AuthClient::verify(kind = {:?}, text = {:?})", kind, text); + + // Perform JIT handshake if enabled + self.jit_handshake().await?; + + let payload = AuthRequest::Verify { kind, text }; + let encrypted_payload = self.serialize_and_encrypt(&payload)?; + let response = self.inner.send(Auth::Msg { encrypted_payload }).await?; + + match response.payload { + Auth::Msg { encrypted_payload } => { + match self.decrypt_and_deserialize(&encrypted_payload)? { + AuthResponse::Verify { valid } => Ok(valid), + AuthResponse::Challenge { .. } => Err(io::Error::new( + io::ErrorKind::Other, + "Got unexpected challenge response during verify", + )), + } + } + Auth::Handshake { .. 
} => Err(io::Error::new( + io::ErrorKind::Other, + "Got unexpected handshake during verify", + )), + } + } + + /// Provides information to the server to use as it pleases with no response expected + pub async fn info(&mut self, text: String) -> io::Result<()> { + trace!("AuthClient::info(text = {:?})", text); + + // Perform JIT handshake if enabled + self.jit_handshake().await?; + + let payload = AuthRequest::Info { text }; + let encrypted_payload = self.serialize_and_encrypt(&payload)?; + self.inner.fire(Auth::Msg { encrypted_payload }).await + } + + /// Provides an error to the server to use as it pleases with no response expected + pub async fn error(&mut self, kind: AuthErrorKind, text: String) -> io::Result<()> { + trace!("AuthClient::error(kind = {:?}, text = {:?})", kind, text); + + // Perform JIT handshake if enabled + self.jit_handshake().await?; + + let payload = AuthRequest::Error { kind, text }; + let encrypted_payload = self.serialize_and_encrypt(&payload)?; + self.inner.fire(Auth::Msg { encrypted_payload }).await + } + + fn serialize_and_encrypt(&mut self, payload: &AuthRequest) -> io::Result> { + let codec = self.codec.as_mut().ok_or_else(|| { + io::Error::new( + io::ErrorKind::Other, + "Handshake must be performed first (client encrypt message)", + ) + })?; + + let mut encryped_payload = BytesMut::new(); + let payload = utils::serialize_to_vec(payload)?; + codec.encode(&payload, &mut encryped_payload)?; + Ok(encryped_payload.freeze().to_vec()) + } + + fn decrypt_and_deserialize(&mut self, payload: &[u8]) -> io::Result { + let codec = self.codec.as_mut().ok_or_else(|| { + io::Error::new( + io::ErrorKind::Other, + "Handshake must be performed first (client decrypt message)", + ) + })?; + + let mut payload = BytesMut::from(payload); + match codec.decode(&mut payload)? 
{ + Some(payload) => utils::deserialize_from_slice::(&payload), + None => Err(io::Error::new( + io::ErrorKind::InvalidData, + "Incomplete message received", + )), + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::{Client, FramedTransport, Request, Response, TypedAsyncRead, TypedAsyncWrite}; + use serde::{de::DeserializeOwned, Serialize}; + + const TIMEOUT_MILLIS: u64 = 100; + + #[tokio::test] + async fn handshake_should_fail_if_get_unexpected_response_from_server() { + let (t, mut server) = FramedTransport::make_test_pair(); + let mut client = AuthClient::from(Client::from_framed_transport(t).unwrap()); + + // We start a separate task for the client to avoid blocking since + // we also need to receive the client's request and respond + let task = tokio::spawn(async move { client.handshake().await }); + + // Get the request, but send a bad response + let request: Request = server.read().await.unwrap().unwrap(); + match request.payload { + Auth::Handshake { .. } => server + .write(Response::new( + request.id, + Auth::Msg { + encrypted_payload: Vec::new(), + }, + )) + .await + .unwrap(), + _ => panic!("Server received unexpected payload"), + } + + let result = task.await.unwrap(); + assert!(result.is_err(), "Handshake succeeded unexpectedly") + } + + #[tokio::test] + async fn challenge_should_fail_if_handshake_not_finished() { + let (t, mut server) = FramedTransport::make_test_pair(); + let mut client = AuthClient::from(Client::from_framed_transport(t).unwrap()); + + // We start a separate task for the client to avoid blocking since + // we also need to receive the client's request and respond + let task = tokio::spawn(async move { client.challenge(Vec::new(), HashMap::new()).await }); + + // Wait for a request, failing if we get one as the failure + // should have prevented sending anything, but we should + tokio::select! 
{ + x = TypedAsyncRead::>::read(&mut server) => { + match x { + Ok(Some(x)) => panic!("Unexpectedly resolved: {:?}", x), + Ok(None) => {}, + Err(x) => panic!("Unexpectedly failed on server side: {}", x), + } + }, + _ = wait_ms(TIMEOUT_MILLIS) => { + panic!("Should have gotten server closure as part of client exit"); + } + } + + // Verify that we got an error with the method + let result = task.await.unwrap(); + assert!(result.is_err(), "Challenge succeeded unexpectedly") + } + + #[tokio::test] + async fn challenge_should_fail_if_receive_wrong_response() { + let (t, mut server) = FramedTransport::make_test_pair(); + let mut client = AuthClient::from(Client::from_framed_transport(t).unwrap()); + + // We start a separate task for the client to avoid blocking since + // we also need to receive the client's request and respond + let task = tokio::spawn(async move { + client.handshake().await.unwrap(); + client + .challenge( + vec![ + AuthQuestion::new("question1".to_string()), + AuthQuestion { + text: "question2".to_string(), + extra: vec![("key2".to_string(), "value2".to_string())] + .into_iter() + .collect(), + }, + ], + vec![("key".to_string(), "value".to_string())] + .into_iter() + .collect(), + ) + .await + }); + + // Wait for a handshake request and set up our encryption codec + let request: Request = server.read().await.unwrap().unwrap(); + let mut codec = match request.payload { + Auth::Handshake { public_key, salt } => { + let handshake = Handshake::default(); + let key = handshake.handshake(public_key, salt).unwrap(); + server + .write(Response::new( + request.id, + Auth::Handshake { + public_key: handshake.pk_bytes(), + salt: *handshake.salt(), + }, + )) + .await + .unwrap(); + XChaCha20Poly1305Codec::new(&key) + } + _ => panic!("Server received unexpected payload"), + }; + + // Wait for a challenge request and send back wrong response + let request: Request = server.read().await.unwrap().unwrap(); + match request.payload { + Auth::Msg { encrypted_payload } 
=> { + match decrypt_and_deserialize(&mut codec, &encrypted_payload).unwrap() { + AuthRequest::Challenge { .. } => { + server + .write(Response::new( + request.id, + Auth::Msg { + encrypted_payload: serialize_and_encrypt( + &mut codec, + &AuthResponse::Verify { valid: true }, + ) + .unwrap(), + }, + )) + .await + .unwrap(); + } + _ => panic!("Server received wrong request type"), + } + } + _ => panic!("Server received unexpected payload"), + }; + + // Verify that we got an error with the method + let result = task.await.unwrap(); + assert!(result.is_err(), "Challenge succeeded unexpectedly") + } + + #[tokio::test] + async fn challenge_should_return_answers_received_from_server() { + let (t, mut server) = FramedTransport::make_test_pair(); + let mut client = AuthClient::from(Client::from_framed_transport(t).unwrap()); + + // We start a separate task for the client to avoid blocking since + // we also need to receive the client's request and respond + let task = tokio::spawn(async move { + client.handshake().await.unwrap(); + client + .challenge( + vec![ + AuthQuestion::new("question1".to_string()), + AuthQuestion { + text: "question2".to_string(), + extra: vec![("key2".to_string(), "value2".to_string())] + .into_iter() + .collect(), + }, + ], + vec![("key".to_string(), "value".to_string())] + .into_iter() + .collect(), + ) + .await + }); + + // Wait for a handshake request and set up our encryption codec + let request: Request = server.read().await.unwrap().unwrap(); + let mut codec = match request.payload { + Auth::Handshake { public_key, salt } => { + let handshake = Handshake::default(); + let key = handshake.handshake(public_key, salt).unwrap(); + server + .write(Response::new( + request.id, + Auth::Handshake { + public_key: handshake.pk_bytes(), + salt: *handshake.salt(), + }, + )) + .await + .unwrap(); + XChaCha20Poly1305Codec::new(&key) + } + _ => panic!("Server received unexpected payload"), + }; + + // Wait for a challenge request and send back wrong 
response + let request: Request = server.read().await.unwrap().unwrap(); + match request.payload { + Auth::Msg { encrypted_payload } => { + match decrypt_and_deserialize(&mut codec, &encrypted_payload).unwrap() { + AuthRequest::Challenge { questions, extra } => { + assert_eq!( + questions, + vec![ + AuthQuestion::new("question1".to_string()), + AuthQuestion { + text: "question2".to_string(), + extra: vec![("key2".to_string(), "value2".to_string())] + .into_iter() + .collect(), + }, + ], + ); + + assert_eq!( + extra, + vec![("key".to_string(), "value".to_string())] + .into_iter() + .collect(), + ); + + server + .write(Response::new( + request.id, + Auth::Msg { + encrypted_payload: serialize_and_encrypt( + &mut codec, + &AuthResponse::Challenge { + answers: vec![ + "answer1".to_string(), + "answer2".to_string(), + ], + }, + ) + .unwrap(), + }, + )) + .await + .unwrap(); + } + _ => panic!("Server received wrong request type"), + } + } + _ => panic!("Server received unexpected payload"), + }; + + // Verify that we got the right results + let answers = task.await.unwrap().unwrap(); + assert_eq!(answers, vec!["answer1".to_string(), "answer2".to_string()]); + } + + #[tokio::test] + async fn verify_should_fail_if_handshake_not_finished() { + let (t, mut server) = FramedTransport::make_test_pair(); + let mut client = AuthClient::from(Client::from_framed_transport(t).unwrap()); + + // We start a separate task for the client to avoid blocking since + // we also need to receive the client's request and respond + let task = tokio::spawn(async move { + client + .verify(AuthVerifyKind::Host, "some text".to_string()) + .await + }); + + // Wait for a request, failing if we get one as the failure + // should have prevented sending anything, but we should + tokio::select! 
{ + x = TypedAsyncRead::>::read(&mut server) => { + match x { + Ok(Some(x)) => panic!("Unexpectedly resolved: {:?}", x), + Ok(None) => {}, + Err(x) => panic!("Unexpectedly failed on server side: {}", x), + } + }, + _ = wait_ms(TIMEOUT_MILLIS) => { + panic!("Should have gotten server closure as part of client exit"); + } + } + + // Verify that we got an error with the method + let result = task.await.unwrap(); + assert!(result.is_err(), "Verify succeeded unexpectedly") + } + + #[tokio::test] + async fn verify_should_fail_if_receive_wrong_response() { + let (t, mut server) = FramedTransport::make_test_pair(); + let mut client = AuthClient::from(Client::from_framed_transport(t).unwrap()); + + // We start a separate task for the client to avoid blocking since + // we also need to receive the client's request and respond + let task = tokio::spawn(async move { + client.handshake().await.unwrap(); + client + .verify(AuthVerifyKind::Host, "some text".to_string()) + .await + }); + + // Wait for a handshake request and set up our encryption codec + let request: Request = server.read().await.unwrap().unwrap(); + let mut codec = match request.payload { + Auth::Handshake { public_key, salt } => { + let handshake = Handshake::default(); + let key = handshake.handshake(public_key, salt).unwrap(); + server + .write(Response::new( + request.id, + Auth::Handshake { + public_key: handshake.pk_bytes(), + salt: *handshake.salt(), + }, + )) + .await + .unwrap(); + XChaCha20Poly1305Codec::new(&key) + } + _ => panic!("Server received unexpected payload"), + }; + + // Wait for a verify request and send back wrong response + let request: Request = server.read().await.unwrap().unwrap(); + match request.payload { + Auth::Msg { encrypted_payload } => { + match decrypt_and_deserialize(&mut codec, &encrypted_payload).unwrap() { + AuthRequest::Verify { .. 
} => { + server + .write(Response::new( + request.id, + Auth::Msg { + encrypted_payload: serialize_and_encrypt( + &mut codec, + &AuthResponse::Challenge { + answers: Vec::new(), + }, + ) + .unwrap(), + }, + )) + .await + .unwrap(); + } + _ => panic!("Server received wrong request type"), + } + } + _ => panic!("Server received unexpected payload"), + }; + + // Verify that we got an error with the method + let result = task.await.unwrap(); + assert!(result.is_err(), "Verify succeeded unexpectedly") + } + + #[tokio::test] + async fn verify_should_return_valid_bool_received_from_server() { + let (t, mut server) = FramedTransport::make_test_pair(); + let mut client = AuthClient::from(Client::from_framed_transport(t).unwrap()); + + // We start a separate task for the client to avoid blocking since + // we also need to receive the client's request and respond + let task = tokio::spawn(async move { + client.handshake().await.unwrap(); + client + .verify(AuthVerifyKind::Host, "some text".to_string()) + .await + }); + + // Wait for a handshake request and set up our encryption codec + let request: Request = server.read().await.unwrap().unwrap(); + let mut codec = match request.payload { + Auth::Handshake { public_key, salt } => { + let handshake = Handshake::default(); + let key = handshake.handshake(public_key, salt).unwrap(); + server + .write(Response::new( + request.id, + Auth::Handshake { + public_key: handshake.pk_bytes(), + salt: *handshake.salt(), + }, + )) + .await + .unwrap(); + XChaCha20Poly1305Codec::new(&key) + } + _ => panic!("Server received unexpected payload"), + }; + + // Wait for a challenge request and send back wrong response + let request: Request = server.read().await.unwrap().unwrap(); + match request.payload { + Auth::Msg { encrypted_payload } => { + match decrypt_and_deserialize(&mut codec, &encrypted_payload).unwrap() { + AuthRequest::Verify { kind, text } => { + assert_eq!(kind, AuthVerifyKind::Host); + assert_eq!(text, "some text"); + + server + 
.write(Response::new( + request.id, + Auth::Msg { + encrypted_payload: serialize_and_encrypt( + &mut codec, + &AuthResponse::Verify { valid: true }, + ) + .unwrap(), + }, + )) + .await + .unwrap(); + } + _ => panic!("Server received wrong request type"), + } + } + _ => panic!("Server received unexpected payload"), + }; + + // Verify that we got the right results + let valid = task.await.unwrap().unwrap(); + assert!(valid, "Got verify response, but valid was set incorrectly"); + } + + #[tokio::test] + async fn info_should_fail_if_handshake_not_finished() { + let (t, mut server) = FramedTransport::make_test_pair(); + let mut client = AuthClient::from(Client::from_framed_transport(t).unwrap()); + + // We start a separate task for the client to avoid blocking since + // we also need to receive the client's request and respond + let task = tokio::spawn(async move { client.info("some text".to_string()).await }); + + // Wait for a request, failing if we get one as the failure + // should have prevented sending anything, but we should + tokio::select! 
{ + x = TypedAsyncRead::>::read(&mut server) => { + match x { + Ok(Some(x)) => panic!("Unexpectedly resolved: {:?}", x), + Ok(None) => {}, + Err(x) => panic!("Unexpectedly failed on server side: {}", x), + } + }, + _ = wait_ms(TIMEOUT_MILLIS) => { + panic!("Should have gotten server closure as part of client exit"); + } + } + + // Verify that we got an error with the method + let result = task.await.unwrap(); + assert!(result.is_err(), "Info succeeded unexpectedly") + } + + #[tokio::test] + async fn info_should_send_the_server_a_request_but_not_wait_for_a_response() { + let (t, mut server) = FramedTransport::make_test_pair(); + let mut client = AuthClient::from(Client::from_framed_transport(t).unwrap()); + + // We start a separate task for the client to avoid blocking since + // we also need to receive the client's request and respond + let task = tokio::spawn(async move { + client.handshake().await.unwrap(); + client.info("some text".to_string()).await + }); + + // Wait for a handshake request and set up our encryption codec + let request: Request = server.read().await.unwrap().unwrap(); + let mut codec = match request.payload { + Auth::Handshake { public_key, salt } => { + let handshake = Handshake::default(); + let key = handshake.handshake(public_key, salt).unwrap(); + server + .write(Response::new( + request.id, + Auth::Handshake { + public_key: handshake.pk_bytes(), + salt: *handshake.salt(), + }, + )) + .await + .unwrap(); + XChaCha20Poly1305Codec::new(&key) + } + _ => panic!("Server received unexpected payload"), + }; + + // Wait for a request + let request: Request = server.read().await.unwrap().unwrap(); + match request.payload { + Auth::Msg { encrypted_payload } => { + match decrypt_and_deserialize(&mut codec, &encrypted_payload).unwrap() { + AuthRequest::Info { text } => { + assert_eq!(text, "some text"); + } + _ => panic!("Server received wrong request type"), + } + } + _ => panic!("Server received unexpected payload"), + }; + + // Verify that we got 
the right results + task.await.unwrap().unwrap(); + } + + #[tokio::test] + async fn error_should_fail_if_handshake_not_finished() { + let (t, mut server) = FramedTransport::make_test_pair(); + let mut client = AuthClient::from(Client::from_framed_transport(t).unwrap()); + + // We start a separate task for the client to avoid blocking since + // we also need to receive the client's request and respond + let task = tokio::spawn(async move { + client + .error(AuthErrorKind::FailedChallenge, "some text".to_string()) + .await + }); + + // Wait for a request, failing if we get one as the failure + // should have prevented sending anything, but we should + tokio::select! { + x = TypedAsyncRead::>::read(&mut server) => { + match x { + Ok(Some(x)) => panic!("Unexpectedly resolved: {:?}", x), + Ok(None) => {}, + Err(x) => panic!("Unexpectedly failed on server side: {}", x), + } + }, + _ = wait_ms(TIMEOUT_MILLIS) => { + panic!("Should have gotten server closure as part of client exit"); + } + } + + // Verify that we got an error with the method + let result = task.await.unwrap(); + assert!(result.is_err(), "Error succeeded unexpectedly") + } + + #[tokio::test] + async fn error_should_send_the_server_a_request_but_not_wait_for_a_response() { + let (t, mut server) = FramedTransport::make_test_pair(); + let mut client = AuthClient::from(Client::from_framed_transport(t).unwrap()); + + // We start a separate task for the client to avoid blocking since + // we also need to receive the client's request and respond + let task = tokio::spawn(async move { + client.handshake().await.unwrap(); + client + .error(AuthErrorKind::FailedChallenge, "some text".to_string()) + .await + }); + + // Wait for a handshake request and set up our encryption codec + let request: Request = server.read().await.unwrap().unwrap(); + let mut codec = match request.payload { + Auth::Handshake { public_key, salt } => { + let handshake = Handshake::default(); + let key = handshake.handshake(public_key, 
salt).unwrap(); + server + .write(Response::new( + request.id, + Auth::Handshake { + public_key: handshake.pk_bytes(), + salt: *handshake.salt(), + }, + )) + .await + .unwrap(); + XChaCha20Poly1305Codec::new(&key) + } + _ => panic!("Server received unexpected payload"), + }; + + // Wait for a request + let request: Request = server.read().await.unwrap().unwrap(); + match request.payload { + Auth::Msg { encrypted_payload } => { + match decrypt_and_deserialize(&mut codec, &encrypted_payload).unwrap() { + AuthRequest::Error { kind, text } => { + assert_eq!(kind, AuthErrorKind::FailedChallenge); + assert_eq!(text, "some text"); + } + _ => panic!("Server received wrong request type"), + } + } + _ => panic!("Server received unexpected payload"), + }; + + // Verify that we got the right results + task.await.unwrap().unwrap(); + } + + async fn wait_ms(ms: u64) { + use std::time::Duration; + tokio::time::sleep(Duration::from_millis(ms)).await; + } + + fn serialize_and_encrypt( + codec: &mut XChaCha20Poly1305Codec, + payload: &T, + ) -> io::Result> { + let mut encryped_payload = BytesMut::new(); + let payload = utils::serialize_to_vec(payload)?; + codec.encode(&payload, &mut encryped_payload)?; + Ok(encryped_payload.freeze().to_vec()) + } + + fn decrypt_and_deserialize( + codec: &mut XChaCha20Poly1305Codec, + payload: &[u8], + ) -> io::Result { + let mut payload = BytesMut::from(payload); + match codec.decode(&mut payload)? 
{ + Some(payload) => utils::deserialize_from_slice::(&payload), + None => Err(io::Error::new( + io::ErrorKind::InvalidData, + "Incomplete message received", + )), + } + } +} diff --git a/distant-net/src/auth/handshake.rs b/distant-net/src/auth/handshake.rs new file mode 100644 index 0000000..b342720 --- /dev/null +++ b/distant-net/src/auth/handshake.rs @@ -0,0 +1,62 @@ +use p256::{ecdh::EphemeralSecret, PublicKey}; +use rand::rngs::OsRng; +use sha2::Sha256; +use std::{convert::TryFrom, io}; + +mod pkb; +pub use pkb::PublicKeyBytes; + +mod salt; +pub use salt::Salt; + +/// 32-byte key shared by handshake +pub type SharedKey = [u8; 32]; + +/// Utility to perform a handshake +pub struct Handshake { + secret: EphemeralSecret, + salt: Salt, +} + +impl Default for Handshake { + // Create a new handshake instance with a secret and salt + fn default() -> Self { + let secret = EphemeralSecret::random(&mut OsRng); + let salt = Salt::random(); + + Self { secret, salt } + } +} + +impl Handshake { + // Return encoded bytes of public key + pub fn pk_bytes(&self) -> PublicKeyBytes { + PublicKeyBytes::from(self.secret.public_key()) + } + + // Return the salt contained by this handshake + pub fn salt(&self) -> &Salt { + &self.salt + } + + pub fn handshake(&self, public_key: PublicKeyBytes, salt: Salt) -> io::Result { + // Decode the public key of the client + let decoded_public_key = PublicKey::try_from(public_key)?; + + // Produce a salt that is consistent with what the other side will do + let shared_salt = self.salt ^ salt; + + // Acquire the shared secret + let shared_secret = self.secret.diffie_hellman(&decoded_public_key); + + // Extract entropy from the shared secret for use in producing a key + let hkdf = shared_secret.extract::(Some(shared_salt.as_ref())); + + // Derive a shared key (32 bytes) + let mut shared_key = [0u8; 32]; + match hkdf.expand(&[], &mut shared_key) { + Ok(_) => Ok(shared_key), + Err(x) => Err(io::Error::new(io::ErrorKind::InvalidData, x.to_string())), + 
} + } +} diff --git a/distant-net/src/auth/handshake/pkb.rs b/distant-net/src/auth/handshake/pkb.rs new file mode 100644 index 0000000..f896dc3 --- /dev/null +++ b/distant-net/src/auth/handshake/pkb.rs @@ -0,0 +1,60 @@ +use p256::{EncodedPoint, PublicKey}; +use serde::{Deserialize, Deserializer, Serialize, Serializer}; +use std::{convert::TryFrom, io}; + +/// Represents a wrapper around [`EncodedPoint`], and exists to +/// fix an issue with [`serde`] deserialization failing when +/// directly serializing the [`EncodedPoint`] type +#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)] +#[serde(into = "Vec", try_from = "Vec")] +pub struct PublicKeyBytes(EncodedPoint); + +impl From for PublicKeyBytes { + fn from(pk: PublicKey) -> Self { + Self(EncodedPoint::from(pk)) + } +} + +impl TryFrom for PublicKey { + type Error = io::Error; + + fn try_from(pkb: PublicKeyBytes) -> Result { + PublicKey::from_sec1_bytes(pkb.0.as_ref()) + .map_err(|x| io::Error::new(io::ErrorKind::InvalidData, x)) + } +} + +impl From for Vec { + fn from(pkb: PublicKeyBytes) -> Self { + pkb.0.as_bytes().to_vec() + } +} + +impl TryFrom> for PublicKeyBytes { + type Error = io::Error; + + fn try_from(bytes: Vec) -> Result { + Ok(Self(EncodedPoint::from_bytes(bytes).map_err(|x| { + io::Error::new(io::ErrorKind::InvalidData, x.to_string()) + })?)) + } +} + +impl serde_bytes::Serialize for PublicKeyBytes { + fn serialize(&self, serializer: S) -> Result + where + S: Serializer, + { + serializer.serialize_bytes(self.0.as_ref()) + } +} + +impl<'de> serde_bytes::Deserialize<'de> for PublicKeyBytes { + fn deserialize(deserializer: D) -> Result + where + D: Deserializer<'de>, + { + let bytes = Deserialize::deserialize(deserializer).map(serde_bytes::ByteBuf::into_vec)?; + PublicKeyBytes::try_from(bytes).map_err(serde::de::Error::custom) + } +} diff --git a/distant-net/src/auth/handshake/salt.rs b/distant-net/src/auth/handshake/salt.rs new file mode 100644 index 0000000..11f9826 --- /dev/null +++ 
b/distant-net/src/auth/handshake/salt.rs @@ -0,0 +1,111 @@ +use rand::{rngs::OsRng, RngCore}; +use serde::{Deserialize, Deserializer, Serialize, Serializer}; +use std::{ + convert::{TryFrom, TryInto}, + fmt, io, + ops::BitXor, + str::FromStr, +}; + +/// Friendly wrapper around a 32-byte array representing a salt +#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)] +#[serde(into = "Vec", try_from = "Vec")] +pub struct Salt([u8; 32]); + +impl Salt { + /// Generates a salt via a uniform random + pub fn random() -> Self { + let mut salt = [0u8; 32]; + OsRng.fill_bytes(&mut salt); + Self(salt) + } +} + +impl fmt::Display for Salt { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "{}", hex::encode(self.0)) + } +} + +impl serde_bytes::Serialize for Salt { + fn serialize(&self, serializer: S) -> Result + where + S: Serializer, + { + serializer.serialize_bytes(self.as_ref()) + } +} + +impl<'de> serde_bytes::Deserialize<'de> for Salt { + fn deserialize(deserializer: D) -> Result + where + D: Deserializer<'de>, + { + let bytes = Deserialize::deserialize(deserializer).map(serde_bytes::ByteBuf::into_vec)?; + let bytes_len = bytes.len(); + Salt::try_from(bytes) + .map_err(|_| serde::de::Error::invalid_length(bytes_len, &"expected 32-byte length")) + } +} + +impl From for String { + fn from(salt: Salt) -> Self { + salt.to_string() + } +} + +impl FromStr for Salt { + type Err = io::Error; + + fn from_str(s: &str) -> Result { + let bytes = hex::decode(s).map_err(|x| io::Error::new(io::ErrorKind::InvalidData, x))?; + Self::try_from(bytes) + } +} + +impl TryFrom for Salt { + type Error = io::Error; + + fn try_from(s: String) -> Result { + s.parse() + } +} + +impl TryFrom> for Salt { + type Error = io::Error; + + fn try_from(bytes: Vec) -> Result { + Ok(Self(bytes.try_into().map_err(|x: Vec| { + io::Error::new( + io::ErrorKind::InvalidData, + format!("Vec len of {} != 32", x.len()), + ) + })?)) + } +} + +impl From for Vec { + fn 
from(salt: Salt) -> Self { + salt.0.to_vec() + } +} + +impl AsRef<[u8]> for Salt { + fn as_ref(&self) -> &[u8] { + self.0.as_ref() + } +} + +impl BitXor for Salt { + type Output = Self; + + fn bitxor(self, rhs: Self) -> Self::Output { + let shared_salt = self + .0 + .iter() + .zip(rhs.0.iter()) + .map(|(x, y)| x ^ y) + .collect::>(); + Self::try_from(shared_salt).unwrap() + } +} diff --git a/distant-net/src/auth/server.rs b/distant-net/src/auth/server.rs new file mode 100644 index 0000000..7e01e99 --- /dev/null +++ b/distant-net/src/auth/server.rs @@ -0,0 +1,653 @@ +use crate::{ + utils, Auth, AuthErrorKind, AuthQuestion, AuthRequest, AuthResponse, AuthVerifyKind, Codec, + Handshake, Server, ServerCtx, XChaCha20Poly1305Codec, +}; +use async_trait::async_trait; +use bytes::BytesMut; +use log::*; +use std::{collections::HashMap, io}; +use tokio::sync::RwLock; + +/// Type signature for a dynamic on_challenge function +pub type AuthChallengeFn = + dyn Fn(Vec, HashMap) -> Vec + Send + Sync; + +/// Type signature for a dynamic on_verify function +pub type AuthVerifyFn = dyn Fn(AuthVerifyKind, String) -> bool + Send + Sync; + +/// Type signature for a dynamic on_info function +pub type AuthInfoFn = dyn Fn(String) + Send + Sync; + +/// Type signature for a dynamic on_error function +pub type AuthErrorFn = dyn Fn(AuthErrorKind, String) + Send + Sync; + +/// Represents an [`AuthServer`] where all handlers are stored on the heap +pub type HeapAuthServer = + AuthServer, Box, Box, Box>; + +/// Server that handles authentication +pub struct AuthServer +where + ChallengeFn: Fn(Vec, HashMap) -> Vec + Send + Sync, + VerifyFn: Fn(AuthVerifyKind, String) -> bool + Send + Sync, + InfoFn: Fn(String) + Send + Sync, + ErrorFn: Fn(AuthErrorKind, String) + Send + Sync, +{ + pub on_challenge: ChallengeFn, + pub on_verify: VerifyFn, + pub on_info: InfoFn, + pub on_error: ErrorFn, +} + +#[async_trait] +impl Server + for AuthServer +where + ChallengeFn: Fn(Vec, HashMap) -> Vec + Send + Sync, + 
VerifyFn: Fn(AuthVerifyKind, String) -> bool + Send + Sync, + InfoFn: Fn(String) + Send + Sync, + ErrorFn: Fn(AuthErrorKind, String) + Send + Sync, +{ + type Request = Auth; + type Response = Auth; + type LocalData = RwLock>; + + async fn on_request(&self, ctx: ServerCtx) { + let reply = ctx.reply.clone(); + + match ctx.request.payload { + Auth::Handshake { public_key, salt } => { + trace!( + "Received handshake request from client, request id = {}", + ctx.request.id + ); + let handshake = Handshake::default(); + match handshake.handshake(public_key, salt) { + Ok(key) => { + ctx.local_data + .write() + .await + .replace(XChaCha20Poly1305Codec::new(&key)); + + trace!( + "Sending reciprocal handshake to client, response origin id = {}", + ctx.request.id + ); + if let Err(x) = reply + .send(Auth::Handshake { + public_key: handshake.pk_bytes(), + salt: *handshake.salt(), + }) + .await + { + error!("[Conn {}] {}", ctx.connection_id, x); + } + } + Err(x) => { + error!("[Conn {}] {}", ctx.connection_id, x); + return; + } + } + } + Auth::Msg { + ref encrypted_payload, + } => { + trace!( + "Received auth msg, encrypted payload size = {}", + encrypted_payload.len() + ); + + // Attempt to decrypt the message so we can understand what to do + let request = match ctx.local_data.write().await.as_mut() { + Some(codec) => { + let mut payload = BytesMut::from(encrypted_payload.as_slice()); + match codec.decode(&mut payload) { + Ok(Some(payload)) => { + utils::deserialize_from_slice::(&payload) + } + Ok(None) => Err(io::Error::new( + io::ErrorKind::InvalidData, + "Incomplete message received", + )), + Err(x) => Err(x), + } + } + None => Err(io::Error::new( + io::ErrorKind::Other, + "Handshake must be performed first (server decrypt message)", + )), + }; + + let response = match request { + Ok(request) => match request { + AuthRequest::Challenge { questions, extra } => { + trace!("Received challenge request"); + trace!("questions = {:?}", questions); + trace!("extra = {:?}", extra); 
+ + let answers = (self.on_challenge)(questions, extra); + AuthResponse::Challenge { answers } + } + AuthRequest::Verify { kind, text } => { + trace!("Received verify request"); + trace!("kind = {:?}", kind); + trace!("text = {:?}", text); + + let valid = (self.on_verify)(kind, text); + AuthResponse::Verify { valid } + } + AuthRequest::Info { text } => { + trace!("Received info request"); + trace!("text = {:?}", text); + + (self.on_info)(text); + return; + } + AuthRequest::Error { kind, text } => { + trace!("Received error request"); + trace!("kind = {:?}", kind); + trace!("text = {:?}", text); + + (self.on_error)(kind, text); + return; + } + }, + Err(x) => { + error!("[Conn {}] {}", ctx.connection_id, x); + return; + } + }; + + // Serialize and encrypt the message before sending it back + let encrypted_payload = match ctx.local_data.write().await.as_mut() { + Some(codec) => { + let mut encrypted_payload = BytesMut::new(); + + // Convert the response into bytes for us to send back + match utils::serialize_to_vec(&response) { + Ok(bytes) => match codec.encode(&bytes, &mut encrypted_payload) { + Ok(_) => Ok(encrypted_payload.freeze().to_vec()), + Err(x) => Err(x), + }, + Err(x) => Err(x), + } + } + None => Err(io::Error::new( + io::ErrorKind::Other, + "Handshake must be performed first (server encrypt messaage)", + )), + }; + + match encrypted_payload { + Ok(encrypted_payload) => { + if let Err(x) = reply.send(Auth::Msg { encrypted_payload }).await { + error!("[Conn {}] {}", ctx.connection_id, x); + return; + } + } + Err(x) => { + error!("[Conn {}] {}", ctx.connection_id, x); + return; + } + } + } + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::{ + IntoSplit, MpscListener, MpscTransport, Request, Response, ServerExt, ServerRef, + TypedAsyncRead, TypedAsyncWrite, + }; + use tokio::sync::mpsc; + + const TIMEOUT_MILLIS: u64 = 100; + + #[tokio::test] + async fn should_not_reply_if_receive_encrypted_msg_without_handshake_first() { + let (mut t, _) = 
spawn_auth_server( + /* on_challenge */ |_, _| Vec::new(), + /* on_verify */ |_, _| false, + /* on_info */ |_| {}, + /* on_error */ |_, _| {}, + ) + .await + .expect("Failed to spawn server"); + + // Send an encrypted message before establishing a handshake + t.write(Request::new(Auth::Msg { + encrypted_payload: Vec::new(), + })) + .await + .expect("Failed to send request to server"); + + // Wait for a response, failing if we get one + tokio::select! { + x = t.read() => panic!("Unexpectedly resolved: {:?}", x), + _ = wait_ms(TIMEOUT_MILLIS) => {} + } + } + + #[tokio::test] + async fn should_reply_to_handshake_request_with_new_public_key_and_salt() { + let (mut t, _) = spawn_auth_server( + /* on_challenge */ |_, _| Vec::new(), + /* on_verify */ |_, _| false, + /* on_info */ |_| {}, + /* on_error */ |_, _| {}, + ) + .await + .expect("Failed to spawn server"); + + // Send a handshake + let handshake = Handshake::default(); + t.write(Request::new(Auth::Handshake { + public_key: handshake.pk_bytes(), + salt: *handshake.salt(), + })) + .await + .expect("Failed to send request to server"); + + // Wait for a handshake response + tokio::select! { + x = t.read() => { + let response = x.expect("Request failed").expect("Response missing"); + match response.payload { + Auth::Handshake { .. } => {}, + Auth::Msg { .. 
} => panic!("Received unexpected encryped message during handshake"), + } + } + _ = wait_ms(TIMEOUT_MILLIS) => panic!("Ran out of time waiting on response"), + } + } + + #[tokio::test] + async fn should_not_reply_if_receive_invalid_encrypted_msg() { + let (mut t, _) = spawn_auth_server( + /* on_challenge */ |_, _| Vec::new(), + /* on_verify */ |_, _| false, + /* on_info */ |_| {}, + /* on_error */ |_, _| {}, + ) + .await + .expect("Failed to spawn server"); + + // Send a handshake + let handshake = Handshake::default(); + t.write(Request::new(Auth::Handshake { + public_key: handshake.pk_bytes(), + salt: *handshake.salt(), + })) + .await + .expect("Failed to send request to server"); + + // Complete handshake + let key = match t.read().await.unwrap().unwrap().payload { + Auth::Handshake { public_key, salt } => handshake.handshake(public_key, salt).unwrap(), + Auth::Msg { .. } => panic!("Received unexpected encryped message during handshake"), + }; + + // Send a bad chunk of data + let _codec = XChaCha20Poly1305Codec::new(&key); + t.write(Request::new(Auth::Msg { + encrypted_payload: vec![1, 2, 3, 4], + })) + .await + .unwrap(); + + // Wait for a response, failing if we get one + tokio::select! 
{ + x = t.read() => panic!("Unexpectedly resolved: {:?}", x), + _ = wait_ms(TIMEOUT_MILLIS) => {} + } + } + + #[tokio::test] + async fn should_invoke_appropriate_function_when_receive_challenge_request_and_reply() { + let (tx, mut rx) = mpsc::channel(1); + let (mut t, _) = spawn_auth_server( + /* on_challenge */ + move |questions, extra| { + tx.try_send((questions, extra)).unwrap(); + vec!["answer1".to_string(), "answer2".to_string()] + }, + /* on_verify */ |_, _| false, + /* on_info */ |_| {}, + /* on_error */ |_, _| {}, + ) + .await + .expect("Failed to spawn server"); + + // Send a handshake + let handshake = Handshake::default(); + t.write(Request::new(Auth::Handshake { + public_key: handshake.pk_bytes(), + salt: *handshake.salt(), + })) + .await + .expect("Failed to send request to server"); + + // Complete handshake + let key = match t.read().await.unwrap().unwrap().payload { + Auth::Handshake { public_key, salt } => handshake.handshake(public_key, salt).unwrap(), + Auth::Msg { .. 
} => panic!("Received unexpected encryped message during handshake"), + }; + + // Send an error request + let mut codec = XChaCha20Poly1305Codec::new(&key); + t.write(Request::new(Auth::Msg { + encrypted_payload: serialize_and_encrypt( + &mut codec, + &AuthRequest::Challenge { + questions: vec![ + AuthQuestion::new("question1".to_string()), + AuthQuestion { + text: "question2".to_string(), + extra: vec![("key".to_string(), "value".to_string())] + .into_iter() + .collect(), + }, + ], + extra: vec![("hello".to_string(), "world".to_string())] + .into_iter() + .collect(), + }, + ) + .unwrap(), + })) + .await + .unwrap(); + + // Verify that the handler was triggered + let (questions, extra) = rx.recv().await.expect("Channel closed unexpectedly"); + assert_eq!( + questions, + vec![ + AuthQuestion::new("question1".to_string()), + AuthQuestion { + text: "question2".to_string(), + extra: vec![("key".to_string(), "value".to_string())] + .into_iter() + .collect(), + } + ] + ); + assert_eq!( + extra, + vec![("hello".to_string(), "world".to_string())] + .into_iter() + .collect() + ); + + // Wait for a response and verify that it matches what we expect + tokio::select! { + x = t.read() => { + let response = x.expect("Request failed").expect("Response missing"); + match response.payload { + Auth::Handshake { .. 
} => panic!("Received unexpected handshake"), + Auth::Msg { encrypted_payload } => { + match decrypt_and_deserialize(&mut codec, &encrypted_payload).unwrap() { + AuthResponse::Challenge { answers } => + assert_eq!( + answers, + vec!["answer1".to_string(), "answer2".to_string()] + ), + _ => panic!("Got wrong response for verify"), + } + }, + } + } + _ = wait_ms(TIMEOUT_MILLIS) => {} + } + } + + #[tokio::test] + async fn should_invoke_appropriate_function_when_receive_verify_request_and_reply() { + let (tx, mut rx) = mpsc::channel(1); + let (mut t, _) = spawn_auth_server( + /* on_challenge */ |_, _| Vec::new(), + /* on_verify */ + move |kind, text| { + tx.try_send((kind, text)).unwrap(); + true + }, + /* on_info */ |_| {}, + /* on_error */ |_, _| {}, + ) + .await + .expect("Failed to spawn server"); + + // Send a handshake + let handshake = Handshake::default(); + t.write(Request::new(Auth::Handshake { + public_key: handshake.pk_bytes(), + salt: *handshake.salt(), + })) + .await + .expect("Failed to send request to server"); + + // Complete handshake + let key = match t.read().await.unwrap().unwrap().payload { + Auth::Handshake { public_key, salt } => handshake.handshake(public_key, salt).unwrap(), + Auth::Msg { .. } => panic!("Received unexpected encryped message during handshake"), + }; + + // Send an error request + let mut codec = XChaCha20Poly1305Codec::new(&key); + t.write(Request::new(Auth::Msg { + encrypted_payload: serialize_and_encrypt( + &mut codec, + &AuthRequest::Verify { + kind: AuthVerifyKind::Host, + text: "some text".to_string(), + }, + ) + .unwrap(), + })) + .await + .unwrap(); + + // Verify that the handler was triggered + let (kind, text) = rx.recv().await.expect("Channel closed unexpectedly"); + assert_eq!(kind, AuthVerifyKind::Host); + assert_eq!(text, "some text"); + + // Wait for a response and verify that it matches what we expect + tokio::select! 
{ + x = t.read() => { + let response = x.expect("Request failed").expect("Response missing"); + match response.payload { + Auth::Handshake { .. } => panic!("Received unexpected handshake"), + Auth::Msg { encrypted_payload } => { + match decrypt_and_deserialize(&mut codec, &encrypted_payload).unwrap() { + AuthResponse::Verify { valid } => + assert!(valid, "Got verify, but valid was wrong"), + _ => panic!("Got wrong response for verify"), + } + }, + } + } + _ = wait_ms(TIMEOUT_MILLIS) => {} + } + } + + #[tokio::test] + async fn should_invoke_appropriate_function_when_receive_info_request() { + let (tx, mut rx) = mpsc::channel(1); + let (mut t, _) = spawn_auth_server( + /* on_challenge */ |_, _| Vec::new(), + /* on_verify */ |_, _| false, + /* on_info */ + move |text| { + tx.try_send(text).unwrap(); + }, + /* on_error */ |_, _| {}, + ) + .await + .expect("Failed to spawn server"); + + // Send a handshake + let handshake = Handshake::default(); + t.write(Request::new(Auth::Handshake { + public_key: handshake.pk_bytes(), + salt: *handshake.salt(), + })) + .await + .expect("Failed to send request to server"); + + // Complete handshake + let key = match t.read().await.unwrap().unwrap().payload { + Auth::Handshake { public_key, salt } => handshake.handshake(public_key, salt).unwrap(), + Auth::Msg { .. } => panic!("Received unexpected encryped message during handshake"), + }; + + // Send an error request + let mut codec = XChaCha20Poly1305Codec::new(&key); + t.write(Request::new(Auth::Msg { + encrypted_payload: serialize_and_encrypt( + &mut codec, + &AuthRequest::Info { + text: "some text".to_string(), + }, + ) + .unwrap(), + })) + .await + .unwrap(); + + // Verify that the handler was triggered + let text = rx.recv().await.expect("Channel closed unexpectedly"); + assert_eq!(text, "some text"); + + // Wait for a response, failing if we get one + tokio::select! 
{ + x = t.read() => panic!("Unexpectedly resolved: {:?}", x), + _ = wait_ms(TIMEOUT_MILLIS) => {} + } + } + + #[tokio::test] + async fn should_invoke_appropriate_function_when_receive_error_request() { + let (tx, mut rx) = mpsc::channel(1); + let (mut t, _) = spawn_auth_server( + /* on_challenge */ |_, _| Vec::new(), + /* on_verify */ |_, _| false, + /* on_info */ |_| {}, + /* on_error */ + move |kind, text| { + tx.try_send((kind, text)).unwrap(); + }, + ) + .await + .expect("Failed to spawn server"); + + // Send a handshake + let handshake = Handshake::default(); + t.write(Request::new(Auth::Handshake { + public_key: handshake.pk_bytes(), + salt: *handshake.salt(), + })) + .await + .expect("Failed to send request to server"); + + // Complete handshake + let key = match t.read().await.unwrap().unwrap().payload { + Auth::Handshake { public_key, salt } => handshake.handshake(public_key, salt).unwrap(), + Auth::Msg { .. } => panic!("Received unexpected encryped message during handshake"), + }; + + // Send an error request + let mut codec = XChaCha20Poly1305Codec::new(&key); + t.write(Request::new(Auth::Msg { + encrypted_payload: serialize_and_encrypt( + &mut codec, + &AuthRequest::Error { + kind: AuthErrorKind::FailedChallenge, + text: "some text".to_string(), + }, + ) + .unwrap(), + })) + .await + .unwrap(); + + // Verify that the handler was triggered + let (kind, text) = rx.recv().await.expect("Channel closed unexpectedly"); + assert_eq!(kind, AuthErrorKind::FailedChallenge); + assert_eq!(text, "some text"); + + // Wait for a response, failing if we get one + tokio::select! 
{ + x = t.read() => panic!("Unexpectedly resolved: {:?}", x), + _ = wait_ms(TIMEOUT_MILLIS) => {} + } + } + + async fn wait_ms(ms: u64) { + use std::time::Duration; + tokio::time::sleep(Duration::from_millis(ms)).await; + } + + fn serialize_and_encrypt( + codec: &mut XChaCha20Poly1305Codec, + payload: &AuthRequest, + ) -> io::Result> { + let mut encryped_payload = BytesMut::new(); + let payload = utils::serialize_to_vec(payload)?; + codec.encode(&payload, &mut encryped_payload)?; + Ok(encryped_payload.freeze().to_vec()) + } + + fn decrypt_and_deserialize( + codec: &mut XChaCha20Poly1305Codec, + payload: &[u8], + ) -> io::Result { + let mut payload = BytesMut::from(payload); + match codec.decode(&mut payload)? { + Some(payload) => utils::deserialize_from_slice::(&payload), + None => Err(io::Error::new( + io::ErrorKind::InvalidData, + "Incomplete message received", + )), + } + } + + async fn spawn_auth_server( + on_challenge: ChallengeFn, + on_verify: VerifyFn, + on_info: InfoFn, + on_error: ErrorFn, + ) -> io::Result<( + MpscTransport, Response>, + Box, + )> + where + ChallengeFn: + Fn(Vec, HashMap) -> Vec + Send + Sync + 'static, + VerifyFn: Fn(AuthVerifyKind, String) -> bool + Send + Sync + 'static, + InfoFn: Fn(String) + Send + Sync + 'static, + ErrorFn: Fn(AuthErrorKind, String) + Send + Sync + 'static, + { + let server = AuthServer { + on_challenge, + on_verify, + on_info, + on_error, + }; + + // Create a test listener where we will forward a connection + let (tx, listener) = MpscListener::channel(100); + + // Make bounded transport pair and send off one of them to act as our connection + let (transport, connection) = MpscTransport::, Response>::pair(100); + tx.send(connection.into_split()) + .await + .expect("Failed to feed listener a connection"); + + let server = server.start(listener)?; + Ok((transport, server)) + } +} diff --git a/distant-net/src/client.rs b/distant-net/src/client.rs new file mode 100644 index 0000000..8d1babb --- /dev/null +++ 
b/distant-net/src/client.rs @@ -0,0 +1,163 @@ +use crate::{ + Codec, FramedTransport, IntoSplit, RawTransport, RawTransportRead, RawTransportWrite, Request, + Response, TypedAsyncRead, TypedAsyncWrite, +}; +use serde::{de::DeserializeOwned, Serialize}; +use std::{ + ops::{Deref, DerefMut}, + sync::Arc, +}; +use tokio::{ + io, + sync::mpsc, + task::{JoinError, JoinHandle}, +}; + +mod channel; +pub use channel::*; + +mod ext; +pub use ext::*; + +/// Represents a client that can be used to send requests & receive responses from a server +pub struct Client +where + T: Send + Sync + Serialize + 'static, + U: Send + Sync + DeserializeOwned + 'static, +{ + /// Used to send requests to a server + channel: Channel, + + /// Contains the task that is running to send requests to a server + request_task: JoinHandle<()>, + + /// Contains the task that is running to receive responses from a server + response_task: JoinHandle<()>, +} + +impl Client +where + T: Send + Sync + Serialize, + U: Send + Sync + DeserializeOwned, +{ + /// Initializes a client using the provided reader and writer + pub fn new(mut writer: W, mut reader: R) -> io::Result + where + R: TypedAsyncRead> + Send + 'static, + W: TypedAsyncWrite> + Send + 'static, + { + let post_office = Arc::new(PostOffice::default()); + let weak_post_office = Arc::downgrade(&post_office); + + // Start a task that continually checks for responses and delivers them using the + // post office + let response_task = tokio::spawn(async move { + loop { + match reader.read().await { + Ok(Some(res)) => { + // Try to send response to appropriate mailbox + // TODO: How should we handle false response? 
Did logging in past + post_office.deliver_response(res).await; + } + Ok(None) => { + break; + } + Err(_) => { + break; + } + } + } + }); + + let (tx, mut rx) = mpsc::channel::>(1); + let request_task = tokio::spawn(async move { + while let Some(req) = rx.recv().await { + if writer.write(req).await.is_err() { + break; + } + } + }); + + let channel = Channel { + tx, + post_office: weak_post_office, + }; + + Ok(Self { + channel, + request_task, + response_task, + }) + } + + /// Initializes a client using the provided framed transport + pub fn from_framed_transport(transport: FramedTransport) -> io::Result + where + TR: RawTransport + IntoSplit + 'static, + ::Read: RawTransportRead, + ::Write: RawTransportWrite, + C: Codec + Send + 'static, + { + let (writer, reader) = transport.into_split(); + Self::new(writer, reader) + } + + /// Convert into underlying channel + pub fn into_channel(self) -> Channel { + self.channel + } + + /// Clones the underlying channel for requests and returns the cloned instance + pub fn clone_channel(&self) -> Channel { + self.channel.clone() + } + + /// Waits for the client to terminate, which results when the receiving end of the network + /// connection is closed (or the client is shutdown) + pub async fn wait(self) -> Result<(), JoinError> { + tokio::try_join!(self.request_task, self.response_task).map(|_| ()) + } + + /// Abort the client's current connection by forcing its tasks to abort + pub fn abort(&self) { + self.request_task.abort(); + self.response_task.abort(); + } + + /// Returns true if client's underlying event processing has finished/terminated + pub fn is_finished(&self) -> bool { + self.request_task.is_finished() && self.response_task.is_finished() + } +} + +impl Deref for Client +where + T: Send + Sync + Serialize + 'static, + U: Send + Sync + DeserializeOwned + 'static, +{ + type Target = Channel; + + fn deref(&self) -> &Self::Target { + &self.channel + } +} + +impl DerefMut for Client +where + T: Send + Sync + Serialize + 
'static, + U: Send + Sync + DeserializeOwned + 'static, +{ + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.channel + } +} + +impl From> for Channel +where + T: Send + Sync + Serialize + 'static, + U: Send + Sync + DeserializeOwned + 'static, +{ + fn from(client: Client) -> Self { + client.channel + } +} diff --git a/distant-net/src/client/channel.rs b/distant-net/src/client/channel.rs new file mode 100644 index 0000000..02827e2 --- /dev/null +++ b/distant-net/src/client/channel.rs @@ -0,0 +1,236 @@ +use crate::{Request, Response}; +use std::{convert, io, sync::Weak}; +use tokio::{sync::mpsc, time::Duration}; + +mod mailbox; +pub use mailbox::*; + +/// Capacity associated with a channel's mailboxes for receiving multiple responses to a request +const CHANNEL_MAILBOX_CAPACITY: usize = 10000; + +/// Represents a sender of requests tied to a session, holding onto a weak reference to +/// mailboxes to relay responses, meaning that once the [`Client`] is closed or dropped, +/// any sent request will no longer be able to receive responses +pub struct Channel +where + T: Send + Sync, + U: Send + Sync, +{ + /// Used to send requests to a server + pub(crate) tx: mpsc::Sender>, + + /// Collection of mailboxes for receiving responses to requests + pub(crate) post_office: Weak>>, +} + +// NOTE: Implemented manually to avoid needing clone to be defined on generic types +impl Clone for Channel +where + T: Send + Sync, + U: Send + Sync, +{ + fn clone(&self) -> Self { + Self { + tx: self.tx.clone(), + post_office: Weak::clone(&self.post_office), + } + } +} + +impl Channel +where + T: Send + Sync, + U: Send + Sync + 'static, +{ + /// Returns true if no more requests can be transferred + pub fn is_closed(&self) -> bool { + self.tx.is_closed() + } + + /// Sends a request and returns a mailbox that can receive one or more responses, failing if + /// unable to send a request or if the session's receiving line to the remote server has + /// already been severed + pub async
fn mail(&mut self, req: impl Into>) -> io::Result>> { + let req = req.into(); + + // First, create a mailbox using the request's id + let mailbox = Weak::upgrade(&self.post_office) + .ok_or_else(|| { + io::Error::new( + io::ErrorKind::NotConnected, + "Session's post office is no longer available", + ) + })? + .make_mailbox(req.id.clone(), CHANNEL_MAILBOX_CAPACITY) + .await; + + // Second, send the request + self.fire(req).await?; + + // Third, return mailbox + Ok(mailbox) + } + + /// Sends a request and returns a mailbox, timing out after duration has passed + pub async fn mail_timeout( + &mut self, + req: impl Into>, + duration: impl Into>, + ) -> io::Result>> { + match duration.into() { + Some(duration) => tokio::time::timeout(duration, self.mail(req)) + .await + .map_err(|x| io::Error::new(io::ErrorKind::TimedOut, x)) + .and_then(convert::identity), + None => self.mail(req).await, + } + } + + /// Sends a request and waits for a response, failing if unable to send a request or if + /// the session's receiving line to the remote server has already been severed + pub async fn send(&mut self, req: impl Into>) -> io::Result> { + // Send mail and get back a mailbox + let mut mailbox = self.mail(req).await?; + + // Wait for first response, and then drop the mailbox + mailbox + .next() + .await + .ok_or_else(|| io::Error::from(io::ErrorKind::ConnectionAborted)) + } + + /// Sends a request and waits for a response, timing out after duration has passed + pub async fn send_timeout( + &mut self, + req: impl Into>, + duration: impl Into>, + ) -> io::Result> { + match duration.into() { + Some(duration) => tokio::time::timeout(duration, self.send(req)) + .await + .map_err(|x| io::Error::new(io::ErrorKind::TimedOut, x)) + .and_then(convert::identity), + None => self.send(req).await, + } + } + + /// Sends a request without waiting for a response; this method is able to be used even + /// if the session's receiving line to the remote server has been severed + pub async fn 
fire(&mut self, req: impl Into>) -> io::Result<()> { + self.tx + .send(req.into()) + .await + .map_err(|x| io::Error::new(io::ErrorKind::BrokenPipe, x.to_string())) + } + + /// Sends a request without waiting for a response, timing out after duration has passed + pub async fn fire_timeout( + &mut self, + req: impl Into>, + duration: impl Into>, + ) -> io::Result<()> { + match duration.into() { + Some(duration) => tokio::time::timeout(duration, self.fire(req)) + .await + .map_err(|x| io::Error::new(io::ErrorKind::TimedOut, x)) + .and_then(convert::identity), + None => self.fire(req).await, + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::{Client, FramedTransport, TypedAsyncRead, TypedAsyncWrite}; + use std::time::Duration; + + type TestClient = Client; + + #[tokio::test] + async fn mail_should_return_mailbox_that_receives_responses_until_transport_closes() { + let (t1, mut t2) = FramedTransport::make_test_pair(); + let session: TestClient = Client::from_framed_transport(t1).unwrap(); + let mut channel = session.clone_channel(); + + let req = Request::new(0); + let res = Response::new(req.id.clone(), 1); + + let mut mailbox = channel.mail(req).await.unwrap(); + + // Get first response + match tokio::join!(mailbox.next(), t2.write(res.clone())) { + (Some(actual), _) => assert_eq!(actual, res), + x => panic!("Unexpected response: {:?}", x), + } + + // Get second response + match tokio::join!(mailbox.next(), t2.write(res.clone())) { + (Some(actual), _) => assert_eq!(actual, res), + x => panic!("Unexpected response: {:?}", x), + } + + // Trigger the mailbox to wait BEFORE closing our transport to ensure that + // we don't get stuck if the mailbox was already waiting + let next_task = tokio::spawn(async move { mailbox.next().await }); + tokio::task::yield_now().await; + + drop(t2); + match next_task.await { + Ok(None) => {} + x => panic!("Unexpected response: {:?}", x), + } + } + + #[tokio::test] + async fn send_should_wait_until_response_received() { 
+ let (t1, mut t2) = FramedTransport::make_test_pair(); + let session: TestClient = Client::from_framed_transport(t1).unwrap(); + let mut channel = session.clone_channel(); + + let req = Request::new(0); + let res = Response::new(req.id.clone(), 1); + + let (actual, _) = tokio::join!(channel.send(req), t2.write(res.clone())); + match actual { + Ok(actual) => assert_eq!(actual, res), + x => panic!("Unexpected response: {:?}", x), + } + } + + #[tokio::test] + async fn send_timeout_should_fail_if_response_not_received_in_time() { + let (t1, mut t2) = FramedTransport::make_test_pair(); + let session: TestClient = Client::from_framed_transport(t1).unwrap(); + let mut channel = session.clone_channel(); + + let req = Request::new(0); + match channel.send_timeout(req, Duration::from_millis(30)).await { + Err(x) => assert_eq!(x.kind(), io::ErrorKind::TimedOut), + x => panic!("Unexpected response: {:?}", x), + } + + let _req = TypedAsyncRead::>::read(&mut t2) + .await + .unwrap() + .unwrap(); + } + + #[tokio::test] + async fn fire_should_send_request_and_not_wait_for_response() { + let (t1, mut t2) = FramedTransport::make_test_pair(); + let session: TestClient = Client::from_framed_transport(t1).unwrap(); + let mut channel = session.clone_channel(); + + let req = Request::new(0); + match channel.fire(req).await { + Ok(_) => {} + x => panic!("Unexpected response: {:?}", x), + } + + let _req = TypedAsyncRead::>::read(&mut t2) + .await + .unwrap() + .unwrap(); + } +} diff --git a/distant-net/src/client/channel/mailbox.rs b/distant-net/src/client/channel/mailbox.rs new file mode 100644 index 0000000..ca3a7c4 --- /dev/null +++ b/distant-net/src/client/channel/mailbox.rs @@ -0,0 +1,128 @@ +use crate::{Id, Response}; +use std::{ + collections::HashMap, + sync::{Arc, Weak}, + time::Duration, +}; +use tokio::{ + io, + sync::{mpsc, Mutex}, + time, +}; + +#[derive(Clone)] +pub struct PostOffice { + mailboxes: Arc>>>, +} + +impl Default for PostOffice +where + T: Send + 'static, +{ + 
/// Creates a new post office with a cleanup interval of 60s + fn default() -> Self { + Self::new(Duration::from_secs(60)) + } +} + +impl PostOffice +where + T: Send + 'static, +{ + /// Creates a new post office that delivers to mailboxes, cleaning up orphaned mailboxes + /// waiting `cleanup` time in between attempts + pub fn new(cleanup: Duration) -> Self { + let mailboxes = Arc::new(Mutex::new(HashMap::new())); + let mref = Arc::downgrade(&mailboxes); + + // Spawn a task that will clean up orphaned mailboxes at the `cleanup` interval + tokio::spawn(async move { + while let Some(m) = Weak::upgrade(&mref) { + m.lock() + .await + .retain(|_id, tx: &mut mpsc::Sender| !tx.is_closed()); + + // NOTE: Must drop the reference before sleeping, otherwise we block + // access to the mailbox map elsewhere and deadlock! + drop(m); + + // Wait the `cleanup` duration before trying again + time::sleep(cleanup).await; + } + }); + + Self { mailboxes } + } + + /// Creates a new mailbox using the given id and buffer size for maximum values that + /// can be queued in the mailbox + pub async fn make_mailbox(&self, id: Id, buffer: usize) -> Mailbox { + let (tx, rx) = mpsc::channel(buffer); + self.mailboxes.lock().await.insert(id.clone(), tx); + + Mailbox { id, rx } + } + + /// Delivers some value to appropriate mailbox, returning false if no mailbox is found + /// for the specified id or if the mailbox is no longer receiving values + pub async fn deliver(&self, id: &Id, value: T) -> bool { + if let Some(tx) = self.mailboxes.lock().await.get_mut(id) { + let success = tx.send(value).await.is_ok(); + + // If failed, we want to remove the mailbox sender as it is no longer valid + if !success { + self.mailboxes.lock().await.remove(id); + } + + success + } else { + false + } + } +} + +impl PostOffice> +where + T: Send + 'static, +{ + /// Delivers some response to appropriate mailbox, returning false if no mailbox is found + /// for the response's origin or if the mailbox is no longer receiving values + pub async fn
deliver_response(&self, res: Response) -> bool { + self.deliver(&res.origin_id.clone(), res).await + } +} + +/// Represents a destination for responses +pub struct Mailbox { + /// Represents id associated with the mailbox + id: Id, + + /// Underlying mailbox storage + rx: mpsc::Receiver, +} + +impl Mailbox { + /// Represents id associated with the mailbox + pub fn id(&self) -> &Id { + &self.id + } + + /// Receives next value in mailbox + pub async fn next(&mut self) -> Option { + self.rx.recv().await + } + + /// Receives next value in mailbox, waiting up to duration before timing out + pub async fn next_timeout(&mut self, duration: Duration) -> io::Result> { + time::timeout(duration, self.next()) + .await + .map_err(|x| io::Error::new(io::ErrorKind::TimedOut, x)) + } + + /// Closes the mailbox such that it will not receive any more values + /// + /// Any values already in the mailbox will still be returned via `next` + pub fn close(&mut self) { + self.rx.close() + } +} diff --git a/distant-net/src/client/ext.rs b/distant-net/src/client/ext.rs new file mode 100644 index 0000000..d23a3d2 --- /dev/null +++ b/distant-net/src/client/ext.rs @@ -0,0 +1,14 @@ +mod tcp; +pub use tcp::*; + +#[cfg(unix)] +mod unix; + +#[cfg(unix)] +pub use unix::*; + +#[cfg(windows)] +mod windows; + +#[cfg(windows)] +pub use windows::*; diff --git a/distant-net/src/client/ext/tcp.rs b/distant-net/src/client/ext/tcp.rs new file mode 100644 index 0000000..e58a345 --- /dev/null +++ b/distant-net/src/client/ext/tcp.rs @@ -0,0 +1,49 @@ +use crate::{Client, Codec, FramedTransport, TcpTransport}; +use async_trait::async_trait; +use serde::{de::DeserializeOwned, Serialize}; +use std::{convert, net::SocketAddr}; +use tokio::{io, time::Duration}; + +#[async_trait] +pub trait TcpClientExt +where + T: Serialize + Send + Sync, + U: DeserializeOwned + Send + Sync, +{ + /// Connect to a remote TCP server using the provided information + async fn connect(addr: SocketAddr, codec: C) -> io::Result> + where + 
C: Codec + Send + 'static; + + /// Connect to a remote TCP server, timing out after duration has passed + async fn connect_timeout( + addr: SocketAddr, + codec: C, + duration: Duration, + ) -> io::Result> + where + C: Codec + Send + 'static, + { + tokio::time::timeout(duration, Self::connect(addr, codec)) + .await + .map_err(|x| io::Error::new(io::ErrorKind::TimedOut, x)) + .and_then(convert::identity) + } +} + +#[async_trait] +impl TcpClientExt for Client +where + T: Send + Sync + Serialize + 'static, + U: Send + Sync + DeserializeOwned + 'static, +{ + /// Connect to a remote TCP server using the provided information + async fn connect(addr: SocketAddr, codec: C) -> io::Result> + where + C: Codec + Send + 'static, + { + let transport = TcpTransport::connect(addr).await?; + let transport = FramedTransport::new(transport, codec); + Self::from_framed_transport(transport) + } +} diff --git a/distant-net/src/client/ext/unix.rs b/distant-net/src/client/ext/unix.rs new file mode 100644 index 0000000..9188f53 --- /dev/null +++ b/distant-net/src/client/ext/unix.rs @@ -0,0 +1,54 @@ +use crate::{Client, Codec, FramedTransport, IntoSplit, UnixSocketTransport}; +use async_trait::async_trait; +use serde::{de::DeserializeOwned, Serialize}; +use std::{convert, path::Path}; +use tokio::{io, time::Duration}; + +#[async_trait] +pub trait UnixSocketClientExt +where + T: Serialize + Send + Sync, + U: DeserializeOwned + Send + Sync, +{ + /// Connect to a proxy unix socket + async fn connect(path: P, codec: C) -> io::Result> + where + P: AsRef + Send, + C: Codec + Send + 'static; + + /// Connect to a proxy unix socket, timing out after duration has passed + async fn connect_timeout( + path: P, + codec: C, + duration: Duration, + ) -> io::Result> + where + P: AsRef + Send, + C: Codec + Send + 'static, + { + tokio::time::timeout(duration, Self::connect(path, codec)) + .await + .map_err(|x| io::Error::new(io::ErrorKind::TimedOut, x)) + .and_then(convert::identity) + } +} + +#[async_trait] 
+impl UnixSocketClientExt for Client +where + T: Send + Sync + Serialize + 'static, + U: Send + Sync + DeserializeOwned + 'static, +{ + /// Connect to a proxy unix socket + async fn connect(path: P, codec: C) -> io::Result> + where + P: AsRef + Send, + C: Codec + Send + 'static, + { + let p = path.as_ref(); + let transport = UnixSocketTransport::connect(p).await?; + let transport = FramedTransport::new(transport, codec); + let (writer, reader) = transport.into_split(); + Ok(Client::new(writer, reader)?) + } +} diff --git a/distant-net/src/client/ext/windows.rs b/distant-net/src/client/ext/windows.rs new file mode 100644 index 0000000..5186caa --- /dev/null +++ b/distant-net/src/client/ext/windows.rs @@ -0,0 +1,86 @@ +use crate::{Client, Codec, FramedTransport, IntoSplit, WindowsPipeTransport}; +use async_trait::async_trait; +use serde::{de::DeserializeOwned, Serialize}; +use std::{ + convert, + ffi::{OsStr, OsString}, +}; +use tokio::{io, time::Duration}; + +#[async_trait] +pub trait WindowsPipeClientExt +where + T: Serialize + Send + Sync, + U: DeserializeOwned + Send + Sync, +{ + /// Connect to a server listening on a Windows pipe at the specified address + /// using the given codec + async fn connect(addr: A, codec: C) -> io::Result> + where + A: AsRef + Send, + C: Codec + Send + 'static; + + /// Connect to a server listening on a Windows pipe at the specified address + /// via `\\.\pipe\{name}` using the given codec + async fn connect_local(name: N, codec: C) -> io::Result> + where + N: AsRef + Send, + C: Codec + Send + 'static, + { + let mut addr = OsString::from(r"\\.\pipe\"); + addr.push(name.as_ref()); + Self::connect(addr, codec).await + } + + /// Connect to a server listening on a Windows pipe at the specified address + /// using the given codec, timing out after duration has passed + async fn connect_timeout( + addr: A, + codec: C, + duration: Duration, + ) -> io::Result> + where + A: AsRef + Send, + C: Codec + Send + 'static, + { + 
tokio::time::timeout(duration, Self::connect(addr, codec)) + .await + .map_err(|x| io::Error::new(io::ErrorKind::TimedOut, x)) + .and_then(convert::identity) + } + + /// Connect to a server listening on a Windows pipe at the specified address + /// via `\\.\pipe\{name}` using the given codec, timing out after duration has passed + async fn connect_local_timeout( + name: N, + codec: C, + duration: Duration, + ) -> io::Result> + where + N: AsRef + Send, + C: Codec + Send + 'static, + { + let mut addr = OsString::from(r"\\.\pipe\"); + addr.push(name.as_ref()); + Self::connect_timeout(addr, codec, duration).await + } +} + +#[async_trait] +impl WindowsPipeClientExt for Client +where + T: Send + Sync + Serialize + 'static, + U: Send + Sync + DeserializeOwned + 'static, +{ + async fn connect(addr: A, codec: C) -> io::Result> + where + A: AsRef + Send, + C: Codec + Send + 'static, + { + let a = addr.as_ref(); + let transport = WindowsPipeTransport::connect(a).await?; + let transport = FramedTransport::new(transport, codec); + let (writer, reader) = transport.into_split(); + Ok(Client::new(writer, reader)?) 
+ } +} diff --git a/distant-core/src/net/transport/codec/mod.rs b/distant-net/src/codec.rs similarity index 100% rename from distant-core/src/net/transport/codec/mod.rs rename to distant-net/src/codec.rs diff --git a/distant-core/src/net/transport/codec/plain.rs b/distant-net/src/codec/plain.rs similarity index 99% rename from distant-core/src/net/transport/codec/plain.rs rename to distant-net/src/codec/plain.rs index 526d0dd..d0e6697 100644 --- a/distant-core/src/net/transport/codec/plain.rs +++ b/distant-net/src/codec/plain.rs @@ -1,4 +1,4 @@ -use super::Codec; +use crate::Codec; use bytes::{Buf, BufMut, BytesMut}; use std::convert::TryInto; use tokio::io; diff --git a/distant-core/src/net/transport/codec/xchacha20poly1305.rs b/distant-net/src/codec/xchacha20poly1305.rs similarity index 97% rename from distant-core/src/net/transport/codec/xchacha20poly1305.rs rename to distant-net/src/codec/xchacha20poly1305.rs index d899347..5d897e9 100644 --- a/distant-core/src/net/transport/codec/xchacha20poly1305.rs +++ b/distant-net/src/codec/xchacha20poly1305.rs @@ -1,5 +1,4 @@ -use super::Codec; -use crate::net::{SecretKey, SecretKey32}; +use crate::{Codec, SecretKey, SecretKey32}; use bytes::{Buf, BufMut, BytesMut}; use chacha20poly1305::{ aead::{Aead, NewAead}, @@ -23,12 +22,18 @@ pub struct XChaCha20Poly1305Codec { } impl_traits_for_codec!(XChaCha20Poly1305Codec); +impl XChaCha20Poly1305Codec { + pub fn new(key: &[u8]) -> Self { + let key = Key::from_slice(key); + let cipher = XChaCha20Poly1305::new(key); + Self { cipher } + } +} + impl From for XChaCha20Poly1305Codec { /// Create a new XChaCha20Poly1305 codec with a 32-byte key fn from(secret_key: SecretKey32) -> Self { - let key = Key::from_slice(secret_key.unprotected_as_bytes()); - let cipher = XChaCha20Poly1305::new(key); - Self { cipher } + Self::new(secret_key.unprotected_as_bytes()) } } diff --git a/distant-net/src/id.rs b/distant-net/src/id.rs new file mode 100644 index 0000000..b2ccda8 --- /dev/null +++ 
b/distant-net/src/id.rs @@ -0,0 +1,2 @@ +/// Id associated with an active connection +pub type ConnectionId = u64; diff --git a/distant-core/src/net/mod.rs b/distant-net/src/key.rs similarity index 87% rename from distant-core/src/net/mod.rs rename to distant-net/src/key.rs index 47d0556..4217a92 100644 --- a/distant-core/src/net/mod.rs +++ b/distant-net/src/key.rs @@ -1,11 +1,6 @@ -mod listener; -mod transport; - use derive_more::{Display, Error}; -pub use listener::{AcceptFuture, Listener, TransportListener}; use rand::{rngs::OsRng, RngCore}; use std::{fmt, str::FromStr}; -pub use transport::*; #[derive(Debug, Display, Error)] pub struct SecretKeyError; @@ -90,18 +85,16 @@ impl From<[u8; N]> for SecretKey { impl FromStr for SecretKey { type Err = SecretKeyError; + /// Parse a str of hex as an N-byte secret key fn from_str(s: &str) -> Result { let bytes = hex::decode(s).map_err(|_| SecretKeyError)?; Self::from_slice(&bytes) } } -pub trait UnprotectedToHexKey { - fn unprotected_to_hex_key(&self) -> String; -} - -impl UnprotectedToHexKey for SecretKey { - fn unprotected_to_hex_key(&self) -> String { - hex::encode(self.unprotected_as_bytes()) +impl fmt::Display for SecretKey { + /// Display an N-byte secret key as a hex string + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "{}", hex::encode(self.unprotected_as_bytes())) } } diff --git a/distant-net/src/lib.rs b/distant-net/src/lib.rs new file mode 100644 index 0000000..b85de6a --- /dev/null +++ b/distant-net/src/lib.rs @@ -0,0 +1,27 @@ +mod any; +mod auth; +mod client; +mod codec; +mod id; +mod key; +mod listener; +mod packet; +mod port; +mod server; +mod transport; +mod utils; + +pub use any::*; +pub use auth::*; +pub use client::*; +pub use codec::*; +pub use id::*; +pub use key::*; +pub use listener::*; +pub use packet::*; +pub use port::*; +pub use server::*; +pub use transport::*; + +pub use log; +pub use paste; diff --git a/distant-net/src/listener.rs b/distant-net/src/listener.rs new 
file mode 100644 index 0000000..dd2ed66 --- /dev/null +++ b/distant-net/src/listener.rs @@ -0,0 +1,34 @@ +use async_trait::async_trait; +use std::io; + +mod mapped; +pub use mapped::*; + +mod mpsc; +pub use mpsc::*; + +mod oneshot; +pub use oneshot::*; + +mod tcp; +pub use tcp::*; + +#[cfg(unix)] +mod unix; + +#[cfg(unix)] +pub use unix::*; + +#[cfg(windows)] +mod windows; + +#[cfg(windows)] +pub use windows::*; + +/// Represents a type that has a listen interface for receiving raw streams +#[async_trait] +pub trait Listener: Send + Sync { + type Output; + + async fn accept(&mut self) -> io::Result; +} diff --git a/distant-net/src/listener/mapped.rs b/distant-net/src/listener/mapped.rs new file mode 100644 index 0000000..82e0cfc --- /dev/null +++ b/distant-net/src/listener/mapped.rs @@ -0,0 +1,40 @@ +use crate::Listener; +use async_trait::async_trait; +use std::io; + +/// Represents a [`Listener`] that wraps a different [`Listener`], +/// mapping the received connection to something else using the map function +pub struct MappedListener +where + L: Listener, + F: FnMut(T) -> U + Send + Sync, +{ + listener: L, + f: F, +} + +impl MappedListener +where + L: Listener, + F: FnMut(T) -> U + Send + Sync, +{ + pub fn new(listener: L, f: F) -> Self { + Self { listener, f } + } +} + +#[async_trait] +impl Listener for MappedListener +where + L: Listener, + F: FnMut(T) -> U + Send + Sync, +{ + type Output = U; + + /// Waits for the next fully-initialized transport for an incoming stream to be available, + /// returning an error if no longer accepting new connections + async fn accept(&mut self) -> io::Result { + let output = self.listener.accept().await?; + Ok((self.f)(output)) + } +} diff --git a/distant-net/src/listener/mpsc.rs b/distant-net/src/listener/mpsc.rs new file mode 100644 index 0000000..fe70779 --- /dev/null +++ b/distant-net/src/listener/mpsc.rs @@ -0,0 +1,31 @@ +use crate::Listener; +use async_trait::async_trait; +use derive_more::From; +use std::io; +use 
tokio::sync::mpsc; + +/// Represents a [`Listener`] that uses an [`mpsc::Receiver`] to +/// accept new connections +#[derive(From)] +pub struct MpscListener { + inner: mpsc::Receiver, +} + +impl MpscListener { + pub fn channel(buffer: usize) -> (mpsc::Sender, Self) { + let (tx, rx) = mpsc::channel(buffer); + (tx, Self { inner: rx }) + } +} + +#[async_trait] +impl Listener for MpscListener { + type Output = T; + + async fn accept(&mut self) -> io::Result { + self.inner + .recv() + .await + .ok_or_else(|| io::Error::from(io::ErrorKind::BrokenPipe)) + } +} diff --git a/distant-net/src/listener/oneshot.rs b/distant-net/src/listener/oneshot.rs new file mode 100644 index 0000000..98d8e01 --- /dev/null +++ b/distant-net/src/listener/oneshot.rs @@ -0,0 +1,84 @@ +use crate::Listener; +use async_trait::async_trait; +use derive_more::From; +use std::io; +use tokio::sync::oneshot; + +/// Represents a [`Listener`] that only has a single connection +#[derive(From)] +pub struct OneshotListener { + inner: Option>, +} + +impl OneshotListener { + pub fn from_value(value: T) -> Self { + let (tx, listener) = Self::channel(); + + // NOTE: Impossible to fail as the receiver has not been dropped at this point + let _ = tx.send(value); + + listener + } + + pub fn channel() -> (oneshot::Sender, Self) { + let (tx, rx) = oneshot::channel(); + (tx, Self { inner: Some(rx) }) + } +} + +#[async_trait] +impl Listener for OneshotListener { + type Output = T; + + /// First call to accept will return listener tied to [`OneshotListener`] while future + /// calls will yield an error of `io::ErrorKind::ConnectionAborted` + async fn accept(&mut self) -> io::Result { + match self.inner.take() { + Some(rx) => rx + .await + .map_err(|x| io::Error::new(io::ErrorKind::BrokenPipe, x)), + None => Err(io::Error::new( + io::ErrorKind::ConnectionAborted, + "Oneshot listener has concluded", + )), + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use tokio::task::JoinHandle; + + #[tokio::test] + async fn 
from_value_should_return_value_on_first_call_to_accept() { + let mut listener = OneshotListener::from_value("hello world"); + assert_eq!(listener.accept().await.unwrap(), "hello world"); + assert_eq!( + listener.accept().await.unwrap_err().kind(), + io::ErrorKind::ConnectionAborted + ); + } + + #[tokio::test] + async fn channel_should_return_a_oneshot_sender_to_feed_first_call_to_accept() { + let (tx, mut listener) = OneshotListener::channel(); + let accept_task: JoinHandle<(io::Result<&str>, io::Result<&str>)> = + tokio::spawn(async move { + let result_1 = listener.accept().await; + let result_2 = listener.accept().await; + (result_1, result_2) + }); + tokio::spawn(async move { + tx.send("hello world").unwrap(); + }); + + let (result_1, result_2) = accept_task.await.unwrap(); + + assert_eq!(result_1.unwrap(), "hello world"); + assert_eq!( + result_2.unwrap_err().kind(), + io::ErrorKind::ConnectionAborted + ); + } +} diff --git a/distant-net/src/listener/tcp.rs b/distant-net/src/listener/tcp.rs new file mode 100644 index 0000000..dc681f4 --- /dev/null +++ b/distant-net/src/listener/tcp.rs @@ -0,0 +1,167 @@ +use crate::{Listener, PortRange, TcpTransport}; +use async_trait::async_trait; +use std::{fmt, io, net::IpAddr}; +use tokio::net::TcpListener as TokioTcpListener; + +/// Represents a [`Listener`] for incoming connections over TCP +pub struct TcpListener { + addr: IpAddr, + port: u16, + inner: TokioTcpListener, +} + +impl TcpListener { + /// Creates a new listener by binding to the specified IP address and port + /// in the given port range + pub async fn bind(addr: IpAddr, port: impl Into) -> io::Result { + let listener = + TokioTcpListener::bind(port.into().make_socket_addrs(addr).as_slice()).await?; + + // Get the port that we bound to + let port = listener.local_addr()?.port(); + + Ok(Self { + addr, + port, + inner: listener, + }) + } + + /// Returns the IP address that the listener is bound to + pub fn ip_addr(&self) -> IpAddr { + self.addr + } + + /// 
Returns the port that the listener is bound to + pub fn port(&self) -> u16 { + self.port + } +} + +impl fmt::Debug for TcpListener { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("TcpListener") + .field("addr", &self.addr) + .field("port", &self.port) + .finish() + } +} + +#[async_trait] +impl Listener for TcpListener { + type Output = TcpTransport; + + async fn accept(&mut self) -> io::Result { + let (stream, peer_addr) = TokioTcpListener::accept(&self.inner).await?; + Ok(TcpTransport { + addr: peer_addr.ip(), + port: peer_addr.port(), + inner: stream, + }) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use std::net::{Ipv6Addr, SocketAddr}; + use tokio::{ + io::{AsyncReadExt, AsyncWriteExt}, + sync::oneshot, + task::JoinHandle, + }; + + #[tokio::test] + async fn should_fail_to_bind_if_port_already_bound() { + let addr = IpAddr::V6(Ipv6Addr::LOCALHOST); + let port = 0; // Ephemeral port + + // Listen at some port + let listener = TcpListener::bind(addr, port) + .await + .expect("Unexpectedly failed to bind first time"); + + // Get the actual port we bound to + let port = listener.port(); + + // Now this should fail as we're already bound to the address and port + TcpListener::bind(addr, port).await.expect_err(&format!( + "Unexpectedly succeeded in binding a second time to {}:{}", + addr, port, + )); + } + + #[tokio::test] + async fn should_be_able_to_receive_connections_and_send_and_receive_data_with_them() { + let (tx, rx) = oneshot::channel(); + + // Spawn a task that will wait for two connections and then + // return the success or failure + let task: JoinHandle> = tokio::spawn(async move { + let addr = IpAddr::V6(Ipv6Addr::LOCALHOST); + let port = 0; // Ephemeral port + + // Listen at the address and port + let mut listener = TcpListener::bind(addr, port).await?; + + // Send the name back to our main test thread + tx.send(SocketAddr::from((addr, listener.port()))) + .map_err(|x| io::Error::new(io::ErrorKind::Other, 
x.to_string()))?; + + // Get first connection + let mut conn_1 = listener.accept().await?; + + // Send some data to the first connection (12 bytes) + conn_1.write_all(b"hello conn 1").await?; + + // Get some data from the first connection (14 bytes) + let mut buf: [u8; 14] = [0; 14]; + let _ = conn_1.read_exact(&mut buf).await?; + assert_eq!(&buf, b"hello server 1"); + + // Get second connection + let mut conn_2 = listener.accept().await?; + + // Send some data on to second connection (12 bytes) + conn_2.write_all(b"hello conn 2").await?; + + // Get some data from the second connection (14 bytes) + let mut buf: [u8; 14] = [0; 14]; + let _ = conn_2.read_exact(&mut buf).await?; + assert_eq!(&buf, b"hello server 2"); + + Ok(()) + }); + + // Wait for the server to be ready + let address = rx.await.expect("Failed to get server address"); + + // Connect to the listener twice, sending some bytes and receiving some bytes from each + let mut buf: [u8; 12] = [0; 12]; + + let mut conn = TcpTransport::connect(&address) + .await + .expect("Conn 1 failed to connect"); + conn.write_all(b"hello server 1") + .await + .expect("Conn 1 failed to write"); + conn.read_exact(&mut buf) + .await + .expect("Conn 1 failed to read"); + assert_eq!(&buf, b"hello conn 1"); + + let mut conn = TcpTransport::connect(&address) + .await + .expect("Conn 2 failed to connect"); + conn.write_all(b"hello server 2") + .await + .expect("Conn 2 failed to write"); + conn.read_exact(&mut buf) + .await + .expect("Conn 2 failed to read"); + assert_eq!(&buf, b"hello conn 2"); + + // Verify that the task has completed by waiting on it + let _ = task.await.expect("Listener task failed unexpectedly"); + } +} diff --git a/distant-net/src/listener/unix.rs b/distant-net/src/listener/unix.rs new file mode 100644 index 0000000..21f8bac --- /dev/null +++ b/distant-net/src/listener/unix.rs @@ -0,0 +1,212 @@ +use crate::{Listener, UnixSocketTransport}; +use async_trait::async_trait; +use std::{ + fmt, io, + 
os::unix::fs::PermissionsExt, + path::{Path, PathBuf}, +}; +use tokio::net::{UnixListener, UnixStream}; + +/// Represents a [`Listener`] for incoming connections over a Unix socket +pub struct UnixSocketListener { + path: PathBuf, + inner: tokio::net::UnixListener, +} + +impl UnixSocketListener { + /// Creates a new listener by binding to the specified path, failing if the path already + /// exists. Sets permission of unix socket to `0o600` where only the owner can read from and + /// write to the socket. + pub async fn bind(path: impl AsRef) -> io::Result { + Self::bind_with_permissions(path, Self::default_unix_socket_file_permissions()).await + } + + /// Creates a new listener by binding to the specified path, failing if the path already + /// exists. Sets the unix socket file permissions to `mode`. + pub async fn bind_with_permissions(path: impl AsRef, mode: u32) -> io::Result { + // Attempt to bind to the path, and if we fail, we see if we can connect + // to the path -- if not, we can try to delete the path and start again + let listener = match UnixListener::bind(path.as_ref()) { + Ok(listener) => listener, + Err(_) => { + // If we can connect to the path, then it's already in use + if UnixStream::connect(path.as_ref()).await.is_ok() { + return Err(io::Error::from(io::ErrorKind::AddrInUse)); + } + + // Otherwise, remove the file and try again + tokio::fs::remove_file(path.as_ref()).await?; + + UnixListener::bind(path.as_ref())? + } + }; + + // TODO: We should be setting this permission during bind, but neither std library nor + // tokio have support for this. 
We would need to create our own raw socket and + // use libc to change the permissions via the raw file descriptor + // + // See https://github.com/chipsenkbeil/distant/issues/111 + let mut permissions = tokio::fs::metadata(path.as_ref()).await?.permissions(); + permissions.set_mode(mode); + tokio::fs::set_permissions(path.as_ref(), permissions).await?; + + Ok(Self { + path: path.as_ref().to_path_buf(), + inner: listener, + }) + } + + /// Returns the path to the socket + pub fn path(&self) -> &Path { + &self.path + } + + /// Returns the default unix socket file permissions as an octal (e.g. `0o600`) + pub const fn default_unix_socket_file_permissions() -> u32 { + 0o600 + } +} + +impl fmt::Debug for UnixSocketListener { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("UnixSocketListener") + .field("path", &self.path) + .finish() + } +} + +#[async_trait] +impl Listener for UnixSocketListener { + type Output = UnixSocketTransport; + + async fn accept(&mut self) -> io::Result { + // NOTE: Address provided is unnamed, or at least the `as_pathname()` method is + // returning none, so we use our listener's path, which is the same as + // what is being connected, anyway + let (stream, _) = tokio::net::UnixListener::accept(&self.inner).await?; + Ok(UnixSocketTransport { + path: self.path.to_path_buf(), + inner: stream, + }) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use tempfile::NamedTempFile; + use tokio::{ + io::{AsyncReadExt, AsyncWriteExt}, + sync::oneshot, + task::JoinHandle, + }; + + #[tokio::test] + async fn should_succeed_to_bind_if_file_exists_at_path_but_nothing_listening() { + // Generate a socket path + let path = NamedTempFile::new() + .expect("Failed to create file") + .into_temp_path(); + + // This should fail as we're already got a file at the path + UnixSocketListener::bind(&path) + .await + .expect("Unexpectedly failed to bind to existing file"); + } + + #[tokio::test] + async fn 
should_fail_to_bind_if_socket_already_bound() { + // Generate a socket path and delete the file after + let path = NamedTempFile::new() + .expect("Failed to create socket file") + .path() + .to_path_buf(); + + // Listen at the socket + let _listener = UnixSocketListener::bind(&path) + .await + .expect("Unexpectedly failed to bind first time"); + + // Now this should fail as we're already bound to the path + UnixSocketListener::bind(&path) + .await + .expect_err("Unexpectedly succeeded in binding to same socket"); + } + + #[tokio::test] + async fn should_be_able_to_receive_connections_and_send_and_receive_data_with_them() { + let (tx, rx) = oneshot::channel(); + + // Spawn a task that will wait for two connections and then + // return the success or failure + let task: JoinHandle> = tokio::spawn(async move { + // Generate a socket path and delete the file after + let path = NamedTempFile::new() + .expect("Failed to create socket file") + .path() + .to_path_buf(); + + // Listen at the socket + let mut listener = UnixSocketListener::bind(&path).await?; + + // Send the name path to our main test thread + tx.send(path) + .map_err(|x| io::Error::new(io::ErrorKind::Other, x.display().to_string()))?; + + // Get first connection + let mut conn_1 = listener.accept().await?; + + // Send some data to the first connection (12 bytes) + conn_1.write_all(b"hello conn 1").await?; + + // Get some data from the first connection (14 bytes) + let mut buf: [u8; 14] = [0; 14]; + let _ = conn_1.read_exact(&mut buf).await?; + assert_eq!(&buf, b"hello server 1"); + + // Get second connection + let mut conn_2 = listener.accept().await?; + + // Send some data on to second connection (12 bytes) + conn_2.write_all(b"hello conn 2").await?; + + // Get some data from the second connection (14 bytes) + let mut buf: [u8; 14] = [0; 14]; + let _ = conn_2.read_exact(&mut buf).await?; + assert_eq!(&buf, b"hello server 2"); + + Ok(()) + }); + + // Wait for the server to be ready + let path = 
rx.await.expect("Failed to get server socket path"); + + // Connect to the listener twice, sending some bytes and receiving some bytes from each + let mut buf: [u8; 12] = [0; 12]; + + let mut conn = UnixSocketTransport::connect(&path) + .await + .expect("Conn 1 failed to connect"); + conn.write_all(b"hello server 1") + .await + .expect("Conn 1 failed to write"); + conn.read_exact(&mut buf) + .await + .expect("Conn 1 failed to read"); + assert_eq!(&buf, b"hello conn 1"); + + let mut conn = UnixSocketTransport::connect(&path) + .await + .expect("Conn 2 failed to connect"); + conn.write_all(b"hello server 2") + .await + .expect("Conn 2 failed to write"); + conn.read_exact(&mut buf) + .await + .expect("Conn 2 failed to read"); + assert_eq!(&buf, b"hello conn 2"); + + // Verify that the task has completed by waiting on it + let _ = task.await.expect("Listener task failed unexpectedly"); + } +} diff --git a/distant-net/src/listener/windows.rs b/distant-net/src/listener/windows.rs new file mode 100644 index 0000000..ef30f4e --- /dev/null +++ b/distant-net/src/listener/windows.rs @@ -0,0 +1,162 @@ +use crate::{Listener, NamedPipe, WindowsPipeTransport}; +use async_trait::async_trait; +use std::{ + ffi::{OsStr, OsString}, + fmt, io, mem, +}; +use tokio::net::windows::named_pipe::{NamedPipeServer, ServerOptions}; + +/// Represents a [`Listener`] for incoming connections over a named windows pipe +pub struct WindowsPipeListener { + addr: OsString, + inner: NamedPipeServer, +} + +impl WindowsPipeListener { + /// Creates a new listener by binding to the specified local address + /// using the given name, which translates to `\\.\pipe\{name}` + pub fn bind_local(name: impl AsRef) -> io::Result { + let mut addr = OsString::from(r"\\.\pipe\"); + addr.push(name.as_ref()); + Self::bind(addr) + } + + /// Creates a new listener by binding to the specified address + pub fn bind(addr: impl Into) -> io::Result { + let addr = addr.into(); + let pipe = ServerOptions::new() + 
.first_pipe_instance(true) + .create(addr.as_os_str())?; + Ok(Self { addr, inner: pipe }) + } + + /// Returns the addr that the listener is bound to + pub fn addr(&self) -> &OsStr { + &self.addr + } +} + +impl fmt::Debug for WindowsPipeListener { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("WindowsPipeListener") + .field("addr", &self.addr) + .finish() + } +} + +#[async_trait] +impl Listener for WindowsPipeListener { + type Output = WindowsPipeTransport; + + async fn accept(&mut self) -> io::Result { + // Wait for a new connection on the current server pipe + self.inner.connect().await?; + + // Create a new server pipe to use for the next connection + // as the current pipe is now taken with our existing connection + let pipe = mem::replace(&mut self.inner, ServerOptions::new().create(&self.addr)?); + + Ok(WindowsPipeTransport { + addr: self.addr.clone(), + inner: NamedPipe::from(pipe), + }) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use tokio::{ + io::{AsyncReadExt, AsyncWriteExt}, + sync::oneshot, + task::JoinHandle, + }; + + #[tokio::test] + async fn should_fail_to_bind_if_pipe_already_bound() { + // Generate a pipe name + let name = format!("test_pipe_{}", rand::random::()); + + // Listen at the pipe + let _listener = + WindowsPipeListener::bind_local(&name).expect("Unexpectedly failed to bind first time"); + + // Now this should fail as we're already bound to the name + WindowsPipeListener::bind_local(&name) + .expect_err("Unexpectedly succeeded in binding to same pipe"); + } + + #[tokio::test] + async fn should_be_able_to_receive_connections_and_send_and_receive_data_with_them() { + let (tx, rx) = oneshot::channel(); + + // Spawn a task that will wait for two connections and then + // return the success or failure + let task: JoinHandle> = tokio::spawn(async move { + // Generate a pipe name + let name = format!("test_pipe_{}", rand::random::()); + + // Listen at the pipe + let mut listener = 
WindowsPipeListener::bind_local(&name)?; + + // Send the name back to our main test thread + tx.send(name) + .map_err(|x| io::Error::new(io::ErrorKind::Other, x))?; + + // Get first connection + let mut conn_1 = listener.accept().await?; + + // Send some data to the first connection (12 bytes) + conn_1.write_all(b"hello conn 1").await?; + + // Get some data from the first connection (14 bytes) + let mut buf: [u8; 14] = [0; 14]; + let _ = conn_1.read_exact(&mut buf).await?; + assert_eq!(&buf, b"hello server 1"); + + // Get second connection + let mut conn_2 = listener.accept().await?; + + // Send some data on to second connection (12 bytes) + conn_2.write_all(b"hello conn 2").await?; + + // Get some data from the second connection (14 bytes) + let mut buf: [u8; 14] = [0; 14]; + let _ = conn_2.read_exact(&mut buf).await?; + assert_eq!(&buf, b"hello server 2"); + + Ok(()) + }); + + // Wait for the server to be ready + let name = rx.await.expect("Failed to get server name"); + + // Connect to the listener twice, sending some bytes and receiving some bytes from each + let mut buf: [u8; 12] = [0; 12]; + + let mut conn = WindowsPipeTransport::connect_local(&name) + .await + .expect("Conn 1 failed to connect"); + conn.write_all(b"hello server 1") + .await + .expect("Conn 1 failed to write"); + conn.read_exact(&mut buf) + .await + .expect("Conn 1 failed to read"); + assert_eq!(&buf, b"hello conn 1"); + + let mut conn = WindowsPipeTransport::connect_local(&name) + .await + .expect("Conn 2 failed to connect"); + conn.write_all(b"hello server 2") + .await + .expect("Conn 2 failed to write"); + conn.read_exact(&mut buf) + .await + .expect("Conn 2 failed to read"); + assert_eq!(&buf, b"hello conn 2"); + + // Verify that the task has completed by waiting on it + let _ = task.await.expect("Listener task failed unexpectedly"); + } +} diff --git a/distant-net/src/packet.rs b/distant-net/src/packet.rs new file mode 100644 index 0000000..1cccf7c --- /dev/null +++ 
b/distant-net/src/packet.rs @@ -0,0 +1,68 @@ +/// Represents a generic id type +pub type Id = String; + +/// Represents a request to send +#[derive(Clone, Debug, PartialEq, Eq, Hash, serde::Serialize, serde::Deserialize)] +#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))] +pub struct Request { + /// Unique id associated with the request + pub id: Id, + + /// Payload associated with the request + pub payload: T, +} + +impl Request { + /// Creates a new request with a random, unique id + pub fn new(payload: T) -> Self { + Self { + id: rand::random::().to_string(), + payload, + } + } +} + +#[cfg(feature = "schemars")] +impl Request { + pub fn root_schema() -> schemars::schema::RootSchema { + schemars::schema_for!(Request) + } +} + +impl From for Request { + fn from(payload: T) -> Self { + Self::new(payload) + } +} + +/// Represents a response received related to some request +#[derive(Clone, Debug, PartialEq, Eq, Hash, serde::Serialize, serde::Deserialize)] +#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))] +pub struct Response { + /// Unique id associated with the response + pub id: Id, + + /// Unique id associated with the request that triggered the response + pub origin_id: Id, + + /// Payload associated with the response + pub payload: T, +} + +impl Response { + /// Creates a new response with a random, unique id + pub fn new(origin_id: Id, payload: T) -> Self { + Self { + id: rand::random::().to_string(), + origin_id, + payload, + } + } +} + +#[cfg(feature = "schemars")] +impl Response { + pub fn root_schema() -> schemars::schema::RootSchema { + schemars::schema_for!(Response) + } +} diff --git a/distant-core/src/server/port.rs b/distant-net/src/port.rs similarity index 94% rename from distant-core/src/server/port.rs rename to distant-net/src/port.rs index 524a911..ad73972 100644 --- a/distant-core/src/server/port.rs +++ b/distant-net/src/port.rs @@ -1,4 +1,5 @@ use derive_more::Display; +use serde::{Deserialize, Serialize}; use std::{ 
net::{IpAddr, SocketAddr}, ops::RangeInclusive, @@ -6,7 +7,7 @@ use std::{ }; /// Represents some range of ports -#[derive(Clone, Debug, Display, PartialEq, Eq)] +#[derive(Copy, Clone, Debug, Display, PartialEq, Eq, Serialize, Deserialize)] #[display( fmt = "{}{}", start, @@ -31,6 +32,15 @@ impl PortRange { } } +impl From for PortRange { + fn from(port: u16) -> Self { + Self { + start: port, + end: None, + } + } +} + impl From> for PortRange { fn from(r: RangeInclusive) -> Self { let (start, end) = r.into_inner(); diff --git a/distant-net/src/server.rs b/distant-net/src/server.rs new file mode 100644 index 0000000..448d184 --- /dev/null +++ b/distant-net/src/server.rs @@ -0,0 +1,49 @@ +use async_trait::async_trait; +use serde::{de::DeserializeOwned, Serialize}; + +mod connection; +pub use connection::*; + +mod context; +pub use context::*; + +mod ext; +pub use ext::*; + +mod r#ref; +pub use r#ref::*; + +mod reply; +pub use reply::*; + +mod state; +pub use state::*; + +/// Interface for a general-purpose server that receives requests to handle +#[async_trait] +pub trait Server: Send { + /// Type of data received by the server + type Request: DeserializeOwned + Send + Sync; + + /// Type of data sent back by the server + type Response: Serialize + Send; + + /// Type of data to store locally tied to the specific connection + type LocalData: Send + Sync; + + /// Invoked immediately on server start, being provided the raw listener to use (untyped + /// transport), and returning the listener when ready to start (enabling servers that need to + /// tweak a listener to do so) + /* async fn on_start(&mut self, listener: L) -> Box> { + } */ + + /// Invoked upon a new connection becoming established, which provides a mutable reference to + /// the data created for the connection. This can be useful in performing some additional + /// initialization on the data prior to it being used anywhere else. 
+ #[allow(unused_variables)] + async fn on_accept(&self, local_data: &mut Self::LocalData) {} + + /// Invoked upon receiving a request from a client. The server should process this + /// request, which can be found in `ctx`, and send one or more replies in response. + async fn on_request(&self, ctx: ServerCtx); +} diff --git a/distant-net/src/server/connection.rs b/distant-net/src/server/connection.rs new file mode 100644 index 0000000..f4e17cf --- /dev/null +++ b/distant-net/src/server/connection.rs @@ -0,0 +1,51 @@ +use crate::ConnectionId; +use tokio::task::JoinHandle; + +/// Represents an individual connection on the server +pub struct ServerConnection { + /// Unique identifier tied to the connection + pub id: ConnectionId, + + /// Task that is processing incoming requests from the connection + pub(crate) reader_task: Option>, + + /// Task that is processing outgoing responses to the connection + pub(crate) writer_task: Option>, +} + +impl Default for ServerConnection { + fn default() -> Self { + Self::new() + } +} + +impl ServerConnection { + /// Creates a new connection, generating a unique id to represent the connection + pub fn new() -> Self { + Self { + id: rand::random(), + reader_task: None, + writer_task: None, + } + } + + /// Returns true if connection is still processing incoming or outgoing messages + pub fn is_active(&self) -> bool { + let reader_active = + self.reader_task.is_some() && !self.reader_task.as_ref().unwrap().is_finished(); + let writer_active = + self.writer_task.is_some() && !self.writer_task.as_ref().unwrap().is_finished(); + reader_active || writer_active + } + + /// Aborts the connection + pub fn abort(&self) { + if let Some(task) = self.reader_task.as_ref() { + task.abort(); + } + + if let Some(task) = self.writer_task.as_ref() { + task.abort(); + } + } +} diff --git a/distant-net/src/server/context.rs b/distant-net/src/server/context.rs new file mode 100644 index 0000000..3f3363d --- /dev/null +++ 
b/distant-net/src/server/context.rs @@ -0,0 +1,17 @@ +use crate::{ConnectionId, Request, ServerReply}; +use std::sync::Arc; + +/// Represents contextual information for working with an inbound request +pub struct ServerCtx { + /// Unique identifer associated with the connection that sent the request + pub connection_id: ConnectionId, + + /// The request being handled + pub request: Request, + + /// Used to send replies back to be sent out by the server + pub reply: ServerReply, + + /// Reference to the connection's local data + pub local_data: Arc, +} diff --git a/distant-net/src/server/ext.rs b/distant-net/src/server/ext.rs new file mode 100644 index 0000000..39a240c --- /dev/null +++ b/distant-net/src/server/ext.rs @@ -0,0 +1,195 @@ +use crate::{ + GenericServerRef, Listener, Request, Response, Server, ServerConnection, ServerCtx, ServerRef, + ServerReply, ServerState, TypedAsyncRead, TypedAsyncWrite, +}; +use log::*; +use serde::{de::DeserializeOwned, Serialize}; +use std::{io, sync::Arc}; +use tokio::sync::mpsc; + +mod tcp; +pub use tcp::*; + +#[cfg(unix)] +mod unix; + +#[cfg(unix)] +pub use unix::*; + +#[cfg(windows)] +mod windows; + +#[cfg(windows)] +pub use windows::*; + +/// Extension trait to provide a reference implementation of starting a server +/// that will listen for new connections (exposed as [`TypedAsyncWrite`] and [`TypedAsyncRead`]) +/// and process them using the [`Server`] implementation +pub trait ServerExt { + type Request; + type Response; + + /// Start a new server using the provided listener + fn start(self, listener: L) -> io::Result> + where + L: Listener + 'static, + R: TypedAsyncRead> + Send + 'static, + W: TypedAsyncWrite> + Send + 'static; +} + +impl ServerExt for S +where + S: Server + Sync + 'static, + Req: DeserializeOwned + Send + Sync + 'static, + Res: Serialize + Send + 'static, + Data: Default + Send + Sync + 'static, +{ + type Request = Req; + type Response = Res; + + fn start(self, listener: L) -> io::Result> + where + L: 
Listener + 'static, + R: TypedAsyncRead> + Send + 'static, + W: TypedAsyncWrite> + Send + 'static, + { + let server = Arc::new(self); + let state = Arc::new(ServerState::new()); + + let task = tokio::spawn(task(server, Arc::clone(&state), listener)); + + Ok(Box::new(GenericServerRef { state, task })) + } +} + +async fn task(server: Arc, state: Arc, mut listener: L) +where + S: Server + Sync + 'static, + Req: DeserializeOwned + Send + Sync + 'static, + Res: Serialize + Send + 'static, + Data: Default + Send + Sync + 'static, + L: Listener + 'static, + R: TypedAsyncRead> + Send + 'static, + W: TypedAsyncWrite> + Send + 'static, +{ + loop { + let server = Arc::clone(&server); + match listener.accept().await { + Ok((mut writer, mut reader)) => { + let mut connection = ServerConnection::new(); + let connection_id = connection.id; + let state = Arc::clone(&state); + + // Create some default data for the new connection and pass it + // to the callback prior to processing new requests + let local_data = { + let mut data = Data::default(); + server.on_accept(&mut data).await; + Arc::new(data) + }; + + // Start a writer task that reads from a channel and forwards all + // data through the writer + let (tx, mut rx) = mpsc::channel::>(1); + connection.writer_task = Some(tokio::spawn(async move { + while let Some(data) = rx.recv().await { + // trace!("[Conn {}] Sending {:?}", connection_id, data.payload); + if let Err(x) = writer.write(data).await { + error!("[Conn {}] Failed to send {:?}", connection_id, x); + break; + } + } + })); + + // Start a reader task that reads requests and processes them + // using the provided handler + connection.reader_task = Some(tokio::spawn(async move { + loop { + match reader.read().await { + Ok(Some(request)) => { + let reply = ServerReply { + origin_id: request.id.clone(), + tx: tx.clone(), + }; + + let ctx = ServerCtx { + connection_id, + request, + reply: reply.clone(), + local_data: Arc::clone(&local_data), + }; + + 
server.on_request(ctx).await; + } + Ok(None) => { + debug!("[Conn {}] Connection closed", connection_id); + break; + } + Err(x) => { + // NOTE: We do NOT break out of the loop, as this could happen + // if someone sends bad data at any point, but does not + // mean that the reader itself has failed. This can + // happen from getting non-compliant typed data + error!("[Conn {}] {}", connection_id, x); + } + } + } + })); + + state + .connections + .write() + .await + .insert(connection_id, connection); + } + Err(x) => { + error!("Server no longer accepting connections: {}", x); + break; + } + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::{IntoSplit, MpscListener, MpscTransport}; + use async_trait::async_trait; + + pub struct TestServer; + + #[async_trait] + impl Server for TestServer { + type Request = u16; + type Response = String; + type LocalData = (); + + async fn on_request(&self, ctx: ServerCtx) { + // Always send back "hello" + ctx.reply.send("hello".to_string()).await.unwrap(); + } + } + + #[tokio::test] + async fn should_invoke_handler_upon_receiving_a_request() { + // Create a test listener where we will forward a connection + let (tx, listener) = MpscListener::channel(100); + + // Make bounded transport pair and send off one of them to act as our connection + let (mut transport, connection) = + MpscTransport::, Response>::pair(100); + tx.send(connection.into_split()) + .await + .expect("Failed to feed listener a connection"); + + let _server = ServerExt::start(TestServer, listener).expect("Failed to start server"); + + transport + .write(Request::new(123)) + .await + .expect("Failed to send request"); + + let response: Response = transport.read().await.unwrap().unwrap(); + assert_eq!(response.payload, "hello"); + } +} diff --git a/distant-net/src/server/ext/tcp.rs b/distant-net/src/server/ext/tcp.rs new file mode 100644 index 0000000..ff764cd --- /dev/null +++ b/distant-net/src/server/ext/tcp.rs @@ -0,0 +1,94 @@ +use crate::{ + Codec, 
FramedTransport, IntoSplit, MappedListener, PortRange, Server, ServerExt, TcpListener, + TcpServerRef, +}; +use async_trait::async_trait; +use serde::{de::DeserializeOwned, Serialize}; +use std::{io, net::IpAddr}; + +/// Extension trait to provide a reference implementation of starting a TCP server +/// that will listen for new connections and process them using the [`Server`] implementation +#[async_trait] +pub trait TcpServerExt { + type Request; + type Response; + + /// Start a new server using the provided listener + async fn start(self, addr: IpAddr, port: P, codec: C) -> io::Result + where + P: Into + Send, + C: Codec + Send + Sync + 'static; +} + +#[async_trait] +impl TcpServerExt for S +where + S: Server + Sync + 'static, + Req: DeserializeOwned + Send + Sync + 'static, + Res: Serialize + Send + 'static, + Data: Default + Send + Sync + 'static, +{ + type Request = Req; + type Response = Res; + + async fn start(self, addr: IpAddr, port: P, codec: C) -> io::Result + where + P: Into + Send, + C: Codec + Send + Sync + 'static, + { + let listener = TcpListener::bind(addr, port).await?; + let port = listener.port(); + + let listener = MappedListener::new(listener, move |transport| { + let transport = FramedTransport::new(transport, codec.clone()); + transport.into_split() + }); + let inner = ServerExt::start(self, listener)?; + Ok(TcpServerRef { addr, port, inner }) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::{Client, PlainCodec, Request, ServerCtx, TcpClientExt}; + use std::net::{Ipv6Addr, SocketAddr}; + + pub struct TestServer; + + #[async_trait] + impl Server for TestServer { + type Request = String; + type Response = String; + type LocalData = (); + + async fn on_request(&self, ctx: ServerCtx) { + // Echo back what we received + ctx.reply + .send(ctx.request.payload.to_string()) + .await + .unwrap(); + } + } + + #[tokio::test] + async fn should_invoke_handler_upon_receiving_a_request() { + let server = + TcpServerExt::start(TestServer, 
IpAddr::V6(Ipv6Addr::LOCALHOST), 0, PlainCodec) + .await + .expect("Failed to start TCP server"); + + let mut client: Client = Client::connect( + SocketAddr::from((server.ip_addr(), server.port())), + PlainCodec, + ) + .await + .expect("Client failed to connect"); + + let response = client + .send(Request::new("hello".to_string())) + .await + .expect("Failed to send message"); + assert_eq!(response.payload, "hello"); + } +} diff --git a/distant-net/src/server/ext/unix.rs b/distant-net/src/server/ext/unix.rs new file mode 100644 index 0000000..1a2838f --- /dev/null +++ b/distant-net/src/server/ext/unix.rs @@ -0,0 +1,97 @@ +use crate::{ + Codec, FramedTransport, IntoSplit, MappedListener, Server, ServerExt, UnixSocketListener, + UnixSocketServerRef, +}; +use async_trait::async_trait; +use serde::{de::DeserializeOwned, Serialize}; +use std::{io, path::Path}; + +/// Extension trait to provide a reference implementation of starting a Unix socket server +/// that will listen for new connections and process them using the [`Server`] implementation +#[async_trait] +pub trait UnixSocketServerExt { + type Request; + type Response; + + /// Start a new server using the provided listener + async fn start(self, path: P, codec: C) -> io::Result + where + P: AsRef + Send, + C: Codec + Send + Sync + 'static; +} + +#[async_trait] +impl UnixSocketServerExt for S +where + S: Server + Sync + 'static, + Req: DeserializeOwned + Send + Sync + 'static, + Res: Serialize + Send + 'static, + Data: Default + Send + Sync + 'static, +{ + type Request = Req; + type Response = Res; + + async fn start(self, path: P, codec: C) -> io::Result + where + P: AsRef + Send, + C: Codec + Send + Sync + 'static, + { + let path = path.as_ref(); + let listener = UnixSocketListener::bind(path).await?; + let path = listener.path().to_path_buf(); + + let listener = MappedListener::new(listener, move |transport| { + let transport = FramedTransport::new(transport, codec.clone()); + transport.into_split() + }); + let 
inner = ServerExt::start(self, listener)?; + Ok(UnixSocketServerRef { path, inner }) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::{Client, PlainCodec, Request, ServerCtx, UnixSocketClientExt}; + use tempfile::NamedTempFile; + + pub struct TestServer; + + #[async_trait] + impl Server for TestServer { + type Request = String; + type Response = String; + type LocalData = (); + + async fn on_request(&self, ctx: ServerCtx) { + // Echo back what we received + ctx.reply + .send(ctx.request.payload.to_string()) + .await + .unwrap(); + } + } + + #[tokio::test] + async fn should_invoke_handler_upon_receiving_a_request() { + // Generate a socket path and delete the file after so there is nothing there + let path = NamedTempFile::new() + .expect("Failed to create socket file") + .path() + .to_path_buf(); + + let server = UnixSocketServerExt::start(TestServer, path, PlainCodec) + .await + .expect("Failed to start Unix socket server"); + + let mut client: Client = Client::connect(server.path(), PlainCodec) + .await + .expect("Client failed to connect"); + + let response = client + .send(Request::new("hello".to_string())) + .await + .expect("Failed to send message"); + assert_eq!(response.payload, "hello"); + } +} diff --git a/distant-net/src/server/ext/windows.rs b/distant-net/src/server/ext/windows.rs new file mode 100644 index 0000000..d2e8715 --- /dev/null +++ b/distant-net/src/server/ext/windows.rs @@ -0,0 +1,109 @@ +use crate::{ + Codec, FramedTransport, IntoSplit, MappedListener, Server, ServerExt, WindowsPipeListener, + WindowsPipeServerRef, +}; +use async_trait::async_trait; +use serde::{de::DeserializeOwned, Serialize}; +use std::{ + ffi::{OsStr, OsString}, + io, +}; + +/// Extension trait to provide a reference implementation of starting a Windows pipe server +/// that will listen for new connections and process them using the [`Server`] implementation +#[async_trait] +pub trait WindowsPipeServerExt { + type Request; + type Response; + + /// Start a 
new server at the specified address using the given codec + async fn start(self, addr: A, codec: C) -> io::Result + where + A: AsRef + Send, + C: Codec + Send + Sync + 'static; + + /// Start a new server at the specified address via `\\.\pipe\{name}` using the given codec + async fn start_local(self, name: N, codec: C) -> io::Result + where + Self: Sized, + N: AsRef + Send, + C: Codec + Send + Sync + 'static, + { + let mut addr = OsString::from(r"\\.\pipe\"); + addr.push(name.as_ref()); + self.start(addr, codec).await + } +} + +#[async_trait] +impl WindowsPipeServerExt for S +where + S: Server + Sync + 'static, + Req: DeserializeOwned + Send + Sync + 'static, + Res: Serialize + Send + 'static, + Data: Default + Send + Sync + 'static, +{ + type Request = Req; + type Response = Res; + + async fn start(self, addr: A, codec: C) -> io::Result + where + A: AsRef + Send, + C: Codec + Send + Sync + 'static, + { + let a = addr.as_ref(); + let listener = WindowsPipeListener::bind(a)?; + let addr = listener.addr().to_os_string(); + + let listener = MappedListener::new(listener, move |transport| { + let transport = FramedTransport::new(transport, codec.clone()); + transport.into_split() + }); + let inner = ServerExt::start(self, listener)?; + Ok(WindowsPipeServerRef { addr, inner }) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::{Client, PlainCodec, Request, ServerCtx, WindowsPipeClientExt}; + + pub struct TestServer; + + #[async_trait] + impl Server for TestServer { + type Request = String; + type Response = String; + type LocalData = (); + + async fn on_request(&self, ctx: ServerCtx) { + // Echo back what we received + ctx.reply + .send(ctx.request.payload.to_string()) + .await + .unwrap(); + } + } + + #[tokio::test] + async fn should_invoke_handler_upon_receiving_a_request() { + let server = WindowsPipeServerExt::start_local( + TestServer, + format!("test_pip_{}", rand::random::()), + PlainCodec, + ) + .await + .expect("Failed to start Windows pipe 
server"); + + let mut client: Client = Client::connect(server.addr(), PlainCodec) + .await + .expect("Client failed to connect"); + + let response = client + .send(Request::new("hello".to_string())) + .await + .expect("Failed to send message"); + assert_eq!(response.payload, "hello"); + } +} diff --git a/distant-net/src/server/ref.rs b/distant-net/src/server/ref.rs new file mode 100644 index 0000000..5359ece --- /dev/null +++ b/distant-net/src/server/ref.rs @@ -0,0 +1,120 @@ +use crate::{AsAny, ServerState}; +use log::*; +use std::{ + future::Future, + io, + pin::Pin, + sync::Arc, + task::{Context, Poll}, + time::Duration, +}; +use tokio::task::{JoinError, JoinHandle}; + +/// Interface to engage with a server instance +pub trait ServerRef: AsAny + Send { + /// Returns a reference to the state of the server + fn state(&self) -> &ServerState; + + /// Returns true if the server is no longer running + fn is_finished(&self) -> bool; + + /// Kills the internal task processing new inbound requests + fn abort(&self); + + fn wait(self) -> Pin>>> + where + Self: Sized + 'static, + { + Box::pin(async { + let task = tokio::spawn(async move { + while !self.is_finished() { + tokio::time::sleep(Duration::from_millis(100)).await; + } + }); + task.await + .map_err(|x| io::Error::new(io::ErrorKind::Other, x)) + }) + } +} + +impl dyn ServerRef { + /// Attempts to convert this ref into a concrete ref by downcasting + pub fn as_server_ref(&self) -> Option<&R> { + self.as_any().downcast_ref::() + } + + /// Attempts to convert this mutable ref into a concrete mutable ref by downcasting + pub fn as_mut_server_ref(&mut self) -> Option<&mut R> { + self.as_mut_any().downcast_mut::() + } + + /// Attempts to convert this into a concrete, boxed ref by downcasting + pub fn into_boxed_server_ref( + self: Box, + ) -> Result, Box> { + self.into_any().downcast::() + } +} + +/// Represents a generic reference to a server +pub struct GenericServerRef { + pub(crate) state: Arc, + pub(crate) task: 
JoinHandle<()>, +} + +/// Runtime-specific implementation of [`ServerRef`] for a [`tokio::task::JoinHandle`] +impl ServerRef for GenericServerRef { + fn state(&self) -> &ServerState { + &self.state + } + + fn is_finished(&self) -> bool { + self.task.is_finished() + } + + fn abort(&self) { + self.task.abort(); + + let state = Arc::clone(&self.state); + tokio::spawn(async move { + for (id, connection) in state.connections.read().await.iter() { + debug!("Aborting connection {}", id); + connection.abort(); + } + }); + } + + fn wait(self) -> Pin>>> + where + Self: Sized + 'static, + { + Box::pin(async { + self.task + .await + .map_err(|x| io::Error::new(io::ErrorKind::Other, x)) + }) + } +} + +impl Future for GenericServerRef { + type Output = Result<(), JoinError>; + + fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + Pin::new(&mut self.task).poll(cx) + } +} + +mod tcp; +pub use tcp::*; + +#[cfg(unix)] +mod unix; + +#[cfg(unix)] +pub use unix::*; + +#[cfg(windows)] +mod windows; + +#[cfg(windows)] +pub use windows::*; diff --git a/distant-net/src/server/ref/tcp.rs b/distant-net/src/server/ref/tcp.rs new file mode 100644 index 0000000..f042d92 --- /dev/null +++ b/distant-net/src/server/ref/tcp.rs @@ -0,0 +1,39 @@ +use crate::{ServerRef, ServerState}; +use std::net::IpAddr; + +/// Reference to a TCP server instance +pub struct TcpServerRef { + pub(crate) addr: IpAddr, + pub(crate) port: u16, + pub(crate) inner: Box, +} + +impl TcpServerRef { + pub fn new(addr: IpAddr, port: u16, inner: Box) -> Self { + Self { addr, port, inner } + } + + /// Returns the IP address that the listener is bound to + pub fn ip_addr(&self) -> IpAddr { + self.addr + } + + /// Returns the port that the listener is bound to + pub fn port(&self) -> u16 { + self.port + } +} + +impl ServerRef for TcpServerRef { + fn state(&self) -> &ServerState { + self.inner.state() + } + + fn is_finished(&self) -> bool { + self.inner.is_finished() + } + + fn abort(&self) { + self.inner.abort(); + 
} +} diff --git a/distant-net/src/server/ref/unix.rs b/distant-net/src/server/ref/unix.rs new file mode 100644 index 0000000..3e762a6 --- /dev/null +++ b/distant-net/src/server/ref/unix.rs @@ -0,0 +1,38 @@ +use crate::{ServerRef, ServerState}; +use std::path::{Path, PathBuf}; + +/// Reference to a unix socket server instance +pub struct UnixSocketServerRef { + pub(crate) path: PathBuf, + pub(crate) inner: Box, +} + +impl UnixSocketServerRef { + pub fn new(path: PathBuf, inner: Box) -> Self { + Self { path, inner } + } + + /// Returns the path to the socket + pub fn path(&self) -> &Path { + &self.path + } + + /// Consumes ref, returning inner ref + pub fn into_inner(self) -> Box { + self.inner + } +} + +impl ServerRef for UnixSocketServerRef { + fn state(&self) -> &ServerState { + self.inner.state() + } + + fn is_finished(&self) -> bool { + self.inner.is_finished() + } + + fn abort(&self) { + self.inner.abort(); + } +} diff --git a/distant-net/src/server/ref/windows.rs b/distant-net/src/server/ref/windows.rs new file mode 100644 index 0000000..6d0ee77 --- /dev/null +++ b/distant-net/src/server/ref/windows.rs @@ -0,0 +1,38 @@ +use crate::{ServerRef, ServerState}; +use std::ffi::{OsStr, OsString}; + +/// Reference to a unix socket server instance +pub struct WindowsPipeServerRef { + pub(crate) addr: OsString, + pub(crate) inner: Box, +} + +impl WindowsPipeServerRef { + pub fn new(addr: OsString, inner: Box) -> Self { + Self { addr, inner } + } + + /// Returns the addr that the listener is bound to + pub fn addr(&self) -> &OsStr { + &self.addr + } + + /// Consumes ref, returning inner ref + pub fn into_inner(self) -> Box { + self.inner + } +} + +impl ServerRef for WindowsPipeServerRef { + fn state(&self) -> &ServerState { + self.inner.state() + } + + fn is_finished(&self) -> bool { + self.inner.is_finished() + } + + fn abort(&self) { + self.inner.abort(); + } +} diff --git a/distant-net/src/server/reply.rs b/distant-net/src/server/reply.rs new file mode 100644 index 
0000000..0756ab0 --- /dev/null +++ b/distant-net/src/server/reply.rs @@ -0,0 +1,198 @@ +use crate::{Id, Response}; +use std::{future::Future, io, pin::Pin, sync::Arc}; +use tokio::sync::{mpsc, Mutex}; + +/// Interface to send a reply to some request +pub trait Reply: Send + Sync { + type Data; + + /// Sends a reply out from the server + fn send(&self, data: Self::Data) -> Pin> + Send + '_>>; + + /// Blocking version of sending a reply out from the server + fn blocking_send(&self, data: Self::Data) -> io::Result<()>; + + /// Clones this reply + fn clone_reply(&self) -> Box>; +} + +impl Reply for mpsc::Sender { + type Data = T; + + fn send(&self, data: Self::Data) -> Pin> + Send + '_>> { + Box::pin(async move { + mpsc::Sender::send(self, data) + .await + .map_err(|x| io::Error::new(io::ErrorKind::Other, x.to_string())) + }) + } + + fn blocking_send(&self, data: Self::Data) -> io::Result<()> { + mpsc::Sender::blocking_send(self, data) + .map_err(|x| io::Error::new(io::ErrorKind::Other, x.to_string())) + } + + fn clone_reply(&self) -> Box> { + Box::new(self.clone()) + } +} + +/// Utility to send ad-hoc replies from the server back through the connection +pub struct ServerReply { + pub(crate) origin_id: Id, + pub(crate) tx: mpsc::Sender>, +} + +impl Clone for ServerReply { + fn clone(&self) -> Self { + Self { + origin_id: self.origin_id.clone(), + tx: self.tx.clone(), + } + } +} + +impl ServerReply { + pub async fn send(&self, data: T) -> io::Result<()> { + self.tx + .send(Response::new(self.origin_id.clone(), data)) + .await + .map_err(|_| io::Error::new(io::ErrorKind::BrokenPipe, "Connection reply closed")) + } + + pub fn blocking_send(&self, data: T) -> io::Result<()> { + self.tx + .blocking_send(Response::new(self.origin_id.clone(), data)) + .map_err(|_| io::Error::new(io::ErrorKind::BrokenPipe, "Connection reply closed")) + } + + pub fn is_closed(&self) -> bool { + self.tx.is_closed() + } + + pub fn queue(self) -> QueuedServerReply { + QueuedServerReply { + inner: 
self, + queue: Arc::new(Mutex::new(Vec::new())), + hold: Arc::new(Mutex::new(true)), + } + } +} + +impl Reply for ServerReply { + type Data = T; + + fn send(&self, data: Self::Data) -> Pin> + Send + '_>> { + Box::pin(ServerReply::send(self, data)) + } + + fn blocking_send(&self, data: Self::Data) -> io::Result<()> { + ServerReply::blocking_send(self, data) + } + + fn clone_reply(&self) -> Box> { + Box::new(self.clone()) + } +} + +/// Represents a reply where all sends are queued up but not sent until +/// after the flush method is called. This reply supports injecting +/// at the front of the queue in order to support sending messages +/// but ensuring that some specific message is sent out first +pub struct QueuedServerReply { + inner: ServerReply, + queue: Arc>>, + hold: Arc>, +} + +impl Clone for QueuedServerReply { + fn clone(&self) -> Self { + Self { + inner: self.inner.clone(), + queue: Arc::clone(&self.queue), + hold: Arc::clone(&self.hold), + } + } +} + +impl QueuedServerReply { + /// Updates the hold status for the queue + /// + /// * If true, all messages are held until the queue is flushed + /// * If false, messages are sent directly as they come in + pub async fn hold(&self, hold: bool) { + *self.hold.lock().await = hold; + } + + /// Send this message, adding it to a queue if holding messages + pub async fn send(&self, data: T) -> io::Result<()> { + if *self.hold.lock().await { + self.queue.lock().await.push(data); + Ok(()) + } else { + self.inner.send(data).await + } + } + + /// Send this message, adding it to a queue if holding messages, blocking + /// for access to locks and other internals + pub fn blocking_send(&self, data: T) -> io::Result<()> { + if *self.hold.blocking_lock() { + self.queue.blocking_lock().push(data); + Ok(()) + } else { + self.inner.blocking_send(data) + } + } + + /// Send this message before anything else in the queue + pub async fn send_before(&self, data: T) -> io::Result<()> { + if *self.hold.lock().await { + 
self.queue.lock().await.insert(0, data); + Ok(()) + } else { + self.inner.send(data).await + } + } + + /// Sends all pending msgs queued up and clears the queue + /// + /// Additionally, takes `hold` to indicate whether or not new msgs + /// after the flush should continue to be held within the queue + /// or if all future msgs will be sent immediately + pub async fn flush(&self, hold: bool) -> io::Result<()> { + // Lock hold so we can ensure that nothing gets sent + // to the queue after we clear it + let mut hold_lock = self.hold.lock().await; + + // Clear the queue by sending everything + for data in self.queue.lock().await.drain(..) { + self.inner.send(data).await?; + } + + // Update hold to + *hold_lock = hold; + + Ok(()) + } + + pub fn is_closed(&self) -> bool { + self.inner.is_closed() + } +} + +impl Reply for QueuedServerReply { + type Data = T; + + fn send(&self, data: Self::Data) -> Pin> + Send + '_>> { + Box::pin(QueuedServerReply::send(self, data)) + } + + fn blocking_send(&self, data: Self::Data) -> io::Result<()> { + QueuedServerReply::blocking_send(self, data) + } + + fn clone_reply(&self) -> Box> { + Box::new(self.clone()) + } +} diff --git a/distant-net/src/server/state.rs b/distant-net/src/server/state.rs new file mode 100644 index 0000000..54553ea --- /dev/null +++ b/distant-net/src/server/state.rs @@ -0,0 +1,23 @@ +use crate::{ConnectionId, ServerConnection}; +use std::collections::HashMap; +use tokio::sync::RwLock; + +/// Contains all top-level state for the server +pub struct ServerState { + /// Mapping of connection ids to their transports + pub connections: RwLock>, +} + +impl ServerState { + pub fn new() -> Self { + Self { + connections: RwLock::new(HashMap::new()), + } + } +} + +impl Default for ServerState { + fn default() -> Self { + Self::new() + } +} diff --git a/distant-net/src/transport.rs b/distant-net/src/transport.rs new file mode 100644 index 0000000..acd7109 --- /dev/null +++ b/distant-net/src/transport.rs @@ -0,0 +1,112 @@ +use 
async_trait::async_trait; +use std::{io, marker::Unpin}; +use tokio::io::{AsyncRead, AsyncWrite}; + +/// Interface to split something into writing and reading halves +pub trait IntoSplit { + type Write; + type Read; + + fn into_split(self) -> (Self::Write, Self::Read); +} + +impl IntoSplit for (W, R) { + type Write = W; + type Read = R; + + fn into_split(self) -> (Self::Write, Self::Read) { + (self.0, self.1) + } +} + +/// Interface representing a transport of raw bytes into and out of the system +pub trait RawTransport: RawTransportRead + RawTransportWrite {} + +/// Interface representing a transport of raw bytes into the system +pub trait RawTransportRead: AsyncRead + Send + Unpin {} + +/// Interface representing a transport of raw bytes out of the system +pub trait RawTransportWrite: AsyncWrite + Send + Unpin {} + +/// Interface representing a transport of typed data into and out of the system +pub trait TypedTransport: TypedAsyncRead + TypedAsyncWrite {} + +/// Interface to read some structured data asynchronously +#[async_trait] +pub trait TypedAsyncRead { + /// Reads some data, returning `Some(T)` if available or `None` if the reader + /// has closed and no longer is providing data + async fn read(&mut self) -> io::Result>; +} + +#[async_trait] +impl TypedAsyncRead for (W, R) +where + W: Send, + R: TypedAsyncRead + Send, +{ + async fn read(&mut self) -> io::Result> { + self.1.read().await + } +} + +#[async_trait] +impl TypedAsyncRead for Box + Send> { + async fn read(&mut self) -> io::Result> { + (**self).read().await + } +} + +/// Interface to write some structured data asynchronously +#[async_trait] +pub trait TypedAsyncWrite { + async fn write(&mut self, data: T) -> io::Result<()>; +} + +#[async_trait] +impl TypedAsyncWrite for (W, R) +where + W: TypedAsyncWrite + Send, + R: Send, + T: Send + 'static, +{ + async fn write(&mut self, data: T) -> io::Result<()> { + self.0.write(data).await + } +} + +#[async_trait] +impl TypedAsyncWrite for Box + Send> { + 
async fn write(&mut self, data: T) -> io::Result<()> { + (**self).write(data).await + } +} + +mod router; + +mod framed; +pub use framed::*; + +mod inmemory; +pub use inmemory::*; + +mod mpsc; +pub use mpsc::*; + +mod tcp; +pub use tcp::*; + +#[cfg(unix)] +mod unix; + +#[cfg(unix)] +pub use unix::*; + +mod untyped; +pub use untyped::*; + +#[cfg(windows)] +mod windows; + +#[cfg(windows)] +pub use windows::*; diff --git a/distant-net/src/transport/framed.rs b/distant-net/src/transport/framed.rs new file mode 100644 index 0000000..ec259c4 --- /dev/null +++ b/distant-net/src/transport/framed.rs @@ -0,0 +1,209 @@ +use crate::{ + utils, Codec, IntoSplit, RawTransport, RawTransportRead, RawTransportWrite, UntypedTransport, + UntypedTransportRead, UntypedTransportWrite, +}; +use async_trait::async_trait; +use futures::{SinkExt, StreamExt}; +use serde::{de::DeserializeOwned, Serialize}; +use std::io; +use tokio_util::codec::{Framed, FramedRead, FramedWrite}; + +#[cfg(test)] +mod test; + +#[cfg(test)] +pub use test::*; + +mod read; +pub use read::*; + +mod write; +pub use write::*; + +/// Represents [`TypedTransport`] of data across the network using frames in order to support +/// typed messages instead of arbitrary bytes being sent across the wire. 
+/// +/// Note that this type does **not** implement [`RawTransport`] and instead acts as a wrapper +/// around a transport to provide a higher-level interface +#[derive(Debug)] +pub struct FramedTransport(Framed) +where + T: RawTransport, + C: Codec; + +impl FramedTransport +where + T: RawTransport, + C: Codec, +{ + /// Creates a new instance of the transport, wrapping the stream in a `Framed` + pub fn new(transport: T, codec: C) -> Self { + Self(Framed::new(transport, codec)) + } +} + +impl UntypedTransport for FramedTransport +where + T: RawTransport, + C: Codec + Send, +{ +} + +impl IntoSplit for FramedTransport +where + T: RawTransport + IntoSplit, + ::Read: RawTransportRead, + ::Write: RawTransportWrite, + C: Codec + Send, +{ + type Read = FramedTransportReadHalf<::Read, C>; + type Write = FramedTransportWriteHalf<::Write, C>; + + fn into_split(self) -> (Self::Write, Self::Read) { + let parts = self.0.into_parts(); + let (write_half, read_half) = parts.io.into_split(); + + // Create our split read half and populate its buffer with existing contents + let mut f_read = FramedRead::new(read_half, parts.codec.clone()); + *f_read.read_buffer_mut() = parts.read_buf; + + // Create our split write half and populate its buffer with existing contents + let mut f_write = FramedWrite::new(write_half, parts.codec); + *f_write.write_buffer_mut() = parts.write_buf; + + let read_half = FramedTransportReadHalf(f_read); + let write_half = FramedTransportWriteHalf(f_write); + + (write_half, read_half) + } +} + +#[async_trait] +impl UntypedTransportWrite for FramedTransport +where + T: RawTransport + Send, + C: Codec + Send, +{ + async fn write(&mut self, data: D) -> io::Result<()> + where + D: Serialize + Send + 'static, + { + // Serialize data into a byte stream + // NOTE: Cannot used packed implementation for now due to issues with deserialization + let data = utils::serialize_to_vec(&data)?; + + // Use underlying codec to send data (may encrypt, sign, etc.) 
+ self.0.send(&data).await + } +} + +#[async_trait] +impl UntypedTransportRead for FramedTransport +where + T: RawTransport + Send, + C: Codec + Send, +{ + async fn read(&mut self) -> io::Result> + where + D: DeserializeOwned, + { + // Use underlying codec to receive data (may decrypt, validate, etc.) + if let Some(data) = self.0.next().await { + let data = data?; + + // Deserialize byte stream into our expected type + let data = utils::deserialize_from_slice(&data)?; + Ok(Some(data)) + } else { + Ok(None) + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::{InmemoryTransport, PlainCodec}; + use serde::{Deserialize, Serialize}; + + #[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)] + pub struct TestData { + name: String, + value: usize, + } + + #[tokio::test] + async fn send_should_convert_data_into_byte_stream_and_send_through_stream() { + let (_tx, mut rx, stream) = InmemoryTransport::make(1); + let mut transport = FramedTransport::new(stream, PlainCodec::new()); + + let data = TestData { + name: String::from("test"), + value: 123, + }; + + let bytes = utils::serialize_to_vec(&data).unwrap(); + let len = (bytes.len() as u64).to_be_bytes(); + let mut frame = Vec::new(); + frame.extend(len.iter().copied()); + frame.extend(bytes); + + transport.write(data).await.unwrap(); + + let outgoing = rx.recv().await.unwrap(); + assert_eq!(outgoing, frame); + } + + #[tokio::test] + async fn receive_should_return_none_if_stream_is_closed() { + let (_, _, stream) = InmemoryTransport::make(1); + let mut transport = FramedTransport::new(stream, PlainCodec::new()); + + let result = transport.read::().await; + match result { + Ok(None) => {} + x => panic!("Unexpected result: {:?}", x), + } + } + + #[tokio::test] + async fn receive_should_fail_if_unable_to_convert_to_type() { + let (tx, _rx, stream) = InmemoryTransport::make(1); + let mut transport = FramedTransport::new(stream, PlainCodec::new()); + + #[derive(Serialize, Deserialize)] + struct 
OtherTestData(usize); + + let data = OtherTestData(123); + let bytes = utils::serialize_to_vec(&data).unwrap(); + let len = (bytes.len() as u64).to_be_bytes(); + let mut frame = Vec::new(); + frame.extend(len.iter().copied()); + frame.extend(bytes); + + tx.send(frame).await.unwrap(); + let result = transport.read::().await; + assert!(result.is_err(), "Unexpectedly succeeded") + } + + #[tokio::test] + async fn receive_should_return_some_instance_of_type_when_coming_into_stream() { + let (tx, _rx, stream) = InmemoryTransport::make(1); + let mut transport = FramedTransport::new(stream, PlainCodec::new()); + + let data = TestData { + name: String::from("test"), + value: 123, + }; + + let bytes = utils::serialize_to_vec(&data).unwrap(); + let len = (bytes.len() as u64).to_be_bytes(); + let mut frame = Vec::new(); + frame.extend(len.iter().copied()); + frame.extend(bytes); + + tx.send(frame).await.unwrap(); + let received_data = transport.read::().await.unwrap().unwrap(); + assert_eq!(received_data, data); + } +} diff --git a/distant-net/src/transport/framed/read.rs b/distant-net/src/transport/framed/read.rs new file mode 100644 index 0000000..ff1bb41 --- /dev/null +++ b/distant-net/src/transport/framed/read.rs @@ -0,0 +1,109 @@ +use crate::{transport::framed::utils, Codec, UntypedTransportRead}; +use async_trait::async_trait; +use futures::StreamExt; +use serde::de::DeserializeOwned; +use std::io; +use tokio::io::AsyncRead; +use tokio_util::codec::FramedRead; + +/// Represents a transport of inbound data from the network using frames in order to support +/// typed messages instead of arbitrary bytes being sent across the wire. 
+/// +/// Note that this type does **not** implement [`AsyncRead`] and instead acts as a +/// wrapper to provide a higher-level interface +pub struct FramedTransportReadHalf(pub(super) FramedRead) +where + R: AsyncRead, + C: Codec; + +#[async_trait] +impl UntypedTransportRead for FramedTransportReadHalf +where + R: AsyncRead + Send + Unpin, + C: Codec + Send, +{ + async fn read(&mut self) -> io::Result> + where + D: DeserializeOwned, + { + // Use underlying codec to receive data (may decrypt, validate, etc.) + if let Some(data) = self.0.next().await { + let data = data?; + + // Deserialize byte stream into our expected type + let data = utils::deserialize_from_slice(&data)?; + Ok(Some(data)) + } else { + Ok(None) + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::{FramedTransport, InmemoryTransport, IntoSplit, PlainCodec}; + use serde::{Deserialize, Serialize}; + + #[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)] + pub struct TestData { + name: String, + value: usize, + } + + #[tokio::test] + async fn receive_should_return_none_if_stream_is_closed() { + let (_, _, stream) = InmemoryTransport::make(1); + let transport = FramedTransport::new(stream, PlainCodec::new()); + let (_, mut reader) = transport.into_split(); + + let result = reader.read::().await; + match result { + Ok(None) => {} + x => panic!("Unexpected result: {:?}", x), + } + } + + #[tokio::test] + async fn receive_should_fail_if_unable_to_convert_to_type() { + let (tx, _rx, stream) = InmemoryTransport::make(1); + let transport = FramedTransport::new(stream, PlainCodec::new()); + let (_, mut reader) = transport.into_split(); + + #[derive(Serialize, Deserialize)] + struct OtherTestData(usize); + + let data = OtherTestData(123); + let bytes = utils::serialize_to_vec(&data).unwrap(); + let len = (bytes.len() as u64).to_be_bytes(); + let mut frame = Vec::new(); + frame.extend(len.iter().copied()); + frame.extend(bytes); + + tx.send(frame).await.unwrap(); + let result = 
reader.read::().await; + assert!(result.is_err(), "Unexpectedly succeeded"); + } + + #[tokio::test] + async fn receive_should_return_some_instance_of_type_when_coming_into_stream() { + let (tx, _rx, stream) = InmemoryTransport::make(1); + let transport = FramedTransport::new(stream, PlainCodec::new()); + let (_, mut reader) = transport.into_split(); + + let data = TestData { + name: String::from("test"), + value: 123, + }; + + let bytes = utils::serialize_to_vec(&data).unwrap(); + let len = (bytes.len() as u64).to_be_bytes(); + let mut frame = Vec::new(); + frame.extend(len.iter().copied()); + frame.extend(bytes); + + tx.send(frame).await.unwrap(); + let received_data = reader.read::().await.unwrap().unwrap(); + assert_eq!(received_data, data); + } +} diff --git a/distant-net/src/transport/framed/test.rs b/distant-net/src/transport/framed/test.rs new file mode 100644 index 0000000..cfae093 --- /dev/null +++ b/distant-net/src/transport/framed/test.rs @@ -0,0 +1,12 @@ +use crate::{FramedTransport, InmemoryTransport, PlainCodec}; + +#[cfg(test)] +impl FramedTransport { + /// Makes a connected pair of framed inmemory transports with plain codec for testing purposes + pub fn make_test_pair() -> ( + FramedTransport, + FramedTransport, + ) { + Self::pair(100) + } +} diff --git a/distant-net/src/transport/framed/write.rs b/distant-net/src/transport/framed/write.rs new file mode 100644 index 0000000..56b3ab3 --- /dev/null +++ b/distant-net/src/transport/framed/write.rs @@ -0,0 +1,72 @@ +use crate::{transport::framed::utils, Codec, UntypedTransportWrite}; +use async_trait::async_trait; +use futures::SinkExt; +use serde::Serialize; +use std::io; +use tokio::io::AsyncWrite; +use tokio_util::codec::FramedWrite; + +/// Represents a transport of outbound data to the network using frames in order to support +/// typed messages instead of arbitrary bytes being sent across the wire. 
+/// +/// Note that this type does **not** implement [`AsyncWrite`] and instead acts as a +/// wrapper to provide a higher-level interface +pub struct FramedTransportWriteHalf(pub(super) FramedWrite) +where + W: AsyncWrite, + C: Codec; + +#[async_trait] +impl UntypedTransportWrite for FramedTransportWriteHalf +where + W: AsyncWrite + Send + Unpin, + C: Codec + Send, +{ + async fn write(&mut self, data: D) -> io::Result<()> + where + D: Serialize + Send + 'static, + { + // Serialize data into a byte stream + // NOTE: Cannot used packed implementation for now due to issues with deserialization + let data = utils::serialize_to_vec(&data)?; + + // Use underlying codec to send data (may encrypt, sign, etc.) + self.0.send(&data).await + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::{FramedTransport, InmemoryTransport, IntoSplit, PlainCodec}; + use serde::{Deserialize, Serialize}; + + #[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)] + pub struct TestData { + name: String, + value: usize, + } + + #[tokio::test] + async fn send_should_convert_data_into_byte_stream_and_send_through_stream() { + let (_tx, mut rx, stream) = InmemoryTransport::make(1); + let transport = FramedTransport::new(stream, PlainCodec::new()); + let (mut wh, _) = transport.into_split(); + + let data = TestData { + name: String::from("test"), + value: 123, + }; + + let bytes = utils::serialize_to_vec(&data).unwrap(); + let len = (bytes.len() as u64).to_be_bytes(); + let mut frame = Vec::new(); + frame.extend(len.iter().copied()); + frame.extend(bytes); + + wh.write(data).await.unwrap(); + + let outgoing = rx.recv().await.unwrap(); + assert_eq!(outgoing, frame); + } +} diff --git a/distant-net/src/transport/inmemory.rs b/distant-net/src/transport/inmemory.rs new file mode 100644 index 0000000..81169b6 --- /dev/null +++ b/distant-net/src/transport/inmemory.rs @@ -0,0 +1,225 @@ +use crate::{ + FramedTransport, IntoSplit, PlainCodec, RawTransport, RawTransportRead, 
RawTransportWrite, +}; +use std::{ + io, + pin::Pin, + task::{Context, Poll}, +}; +use tokio::{ + io::{AsyncRead, AsyncWrite, ReadBuf}, + sync::mpsc, +}; + +mod read; +pub use read::*; + +mod write; +pub use write::*; + +/// Represents a [`RawTransport`] comprised of two inmemory channels +#[derive(Debug)] +pub struct InmemoryTransport { + incoming: InmemoryTransportReadHalf, + outgoing: InmemoryTransportWriteHalf, +} + +impl InmemoryTransport { + pub fn new(incoming: mpsc::Receiver>, outgoing: mpsc::Sender>) -> Self { + Self { + incoming: InmemoryTransportReadHalf::new(incoming), + outgoing: InmemoryTransportWriteHalf::new(outgoing), + } + } + + /// Returns (incoming_tx, outgoing_rx, transport) + pub fn make(buffer: usize) -> (mpsc::Sender>, mpsc::Receiver>, Self) { + let (incoming_tx, incoming_rx) = mpsc::channel(buffer); + let (outgoing_tx, outgoing_rx) = mpsc::channel(buffer); + + ( + incoming_tx, + outgoing_rx, + Self::new(incoming_rx, outgoing_tx), + ) + } + + /// Returns pair of transports that are connected such that one sends to the other and + /// vice versa + pub fn pair(buffer: usize) -> (Self, Self) { + let (tx, rx, transport) = Self::make(buffer); + (transport, Self::new(rx, tx)) + } +} + +impl RawTransport for InmemoryTransport {} +impl RawTransportRead for InmemoryTransport {} +impl RawTransportWrite for InmemoryTransport {} +impl IntoSplit for InmemoryTransport { + type Read = InmemoryTransportReadHalf; + type Write = InmemoryTransportWriteHalf; + + fn into_split(self) -> (Self::Write, Self::Read) { + (self.outgoing, self.incoming) + } +} + +impl AsyncRead for InmemoryTransport { + fn poll_read( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut ReadBuf<'_>, + ) -> Poll> { + Pin::new(&mut self.incoming).poll_read(cx, buf) + } +} + +impl AsyncWrite for InmemoryTransport { + fn poll_write( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll> { + Pin::new(&mut self.outgoing).poll_write(cx, buf) + } + + fn 
poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + Pin::new(&mut self.outgoing).poll_flush(cx) + } + + fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + Pin::new(&mut self.outgoing).poll_shutdown(cx) + } +} + +impl FramedTransport { + /// Produces a pair of inmemory transports that are connected to each other using + /// a standard codec + /// + /// Sets the buffer for message passing for each underlying transport to the given buffer size + pub fn pair( + buffer: usize, + ) -> ( + FramedTransport, + FramedTransport, + ) { + let (a, b) = InmemoryTransport::pair(buffer); + let a = FramedTransport::new(a, PlainCodec::new()); + let b = FramedTransport::new(b, PlainCodec::new()); + (a, b) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use tokio::io::{AsyncReadExt, AsyncWriteExt}; + + #[tokio::test] + async fn make_should_return_sender_that_sends_data_to_transport() { + let (tx, _, mut transport) = InmemoryTransport::make(3); + + tx.send(b"test msg 1".to_vec()).await.unwrap(); + tx.send(b"test msg 2".to_vec()).await.unwrap(); + tx.send(b"test msg 3".to_vec()).await.unwrap(); + + // Should get data matching a singular message + let mut buf = [0; 256]; + let len = transport.read(&mut buf).await.unwrap(); + assert_eq!(&buf[..len], b"test msg 1"); + + // Next call would get the second message + let len = transport.read(&mut buf).await.unwrap(); + assert_eq!(&buf[..len], b"test msg 2"); + + // When the last of the senders is dropped, we should still get + // the rest of the data that was sent first before getting + // an indicator that there is no more data + drop(tx); + + let len = transport.read(&mut buf).await.unwrap(); + assert_eq!(&buf[..len], b"test msg 3"); + + let len = transport.read(&mut buf).await.unwrap(); + assert_eq!(len, 0, "Unexpectedly got more data"); + } + + #[tokio::test] + async fn make_should_return_receiver_that_receives_data_from_transport() { + let (_, mut rx, mut transport) = 
InmemoryTransport::make(3); + + transport.write_all(b"test msg 1").await.unwrap(); + transport.write_all(b"test msg 2").await.unwrap(); + transport.write_all(b"test msg 3").await.unwrap(); + + // Should get data matching a singular message + assert_eq!(rx.recv().await, Some(b"test msg 1".to_vec())); + + // Next call would get the second message + assert_eq!(rx.recv().await, Some(b"test msg 2".to_vec())); + + // When the transport is dropped, we should still get + // the rest of the data that was sent first before getting + // an indicator that there is no more data + drop(transport); + + assert_eq!(rx.recv().await, Some(b"test msg 3".to_vec())); + + assert_eq!(rx.recv().await, None, "Unexpectedly got more data"); + } + + #[tokio::test] + async fn into_split_should_provide_a_read_half_that_receives_from_sender() { + let (tx, _, transport) = InmemoryTransport::make(3); + let (_, mut read_half) = transport.into_split(); + + tx.send(b"test msg 1".to_vec()).await.unwrap(); + tx.send(b"test msg 2".to_vec()).await.unwrap(); + tx.send(b"test msg 3".to_vec()).await.unwrap(); + + // Should get data matching a singular message + let mut buf = [0; 256]; + let len = read_half.read(&mut buf).await.unwrap(); + assert_eq!(&buf[..len], b"test msg 1"); + + // Next call would get the second message + let len = read_half.read(&mut buf).await.unwrap(); + assert_eq!(&buf[..len], b"test msg 2"); + + // When the last of the senders is dropped, we should still get + // the rest of the data that was sent first before getting + // an indicator that there is no more data + drop(tx); + + let len = read_half.read(&mut buf).await.unwrap(); + assert_eq!(&buf[..len], b"test msg 3"); + + let len = read_half.read(&mut buf).await.unwrap(); + assert_eq!(len, 0, "Unexpectedly got more data"); + } + + #[tokio::test] + async fn into_split_should_provide_a_write_half_that_sends_to_receiver() { + let (_, mut rx, transport) = InmemoryTransport::make(3); + let (mut write_half, _) = transport.into_split(); + 
+ write_half.write_all(b"test msg 1").await.unwrap(); + write_half.write_all(b"test msg 2").await.unwrap(); + write_half.write_all(b"test msg 3").await.unwrap(); + + // Should get data matching a singular message + assert_eq!(rx.recv().await, Some(b"test msg 1".to_vec())); + + // Next call would get the second message + assert_eq!(rx.recv().await, Some(b"test msg 2".to_vec())); + + // When the transport is dropped, we should still get + // the rest of the data that was sent first before getting + // an indicator that there is no more data + drop(write_half); + + assert_eq!(rx.recv().await, Some(b"test msg 3".to_vec())); + + assert_eq!(rx.recv().await, None, "Unexpectedly got more data"); + } +} diff --git a/distant-net/src/transport/inmemory/read.rs b/distant-net/src/transport/inmemory/read.rs new file mode 100644 index 0000000..a05e0f9 --- /dev/null +++ b/distant-net/src/transport/inmemory/read.rs @@ -0,0 +1,249 @@ +use crate::RawTransportRead; +use futures::ready; +use std::{ + io, + pin::Pin, + task::{Context, Poll}, +}; +use tokio::{ + io::{AsyncRead, ReadBuf}, + sync::mpsc, +}; + +/// Read portion of an inmemory channel +#[derive(Debug)] +pub struct InmemoryTransportReadHalf { + rx: mpsc::Receiver>, + overflow: Vec, +} + +impl InmemoryTransportReadHalf { + pub fn new(rx: mpsc::Receiver>) -> Self { + Self { + rx, + overflow: Vec::new(), + } + } +} + +impl RawTransportRead for InmemoryTransportReadHalf {} + +impl AsyncRead for InmemoryTransportReadHalf { + fn poll_read( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut ReadBuf<'_>, + ) -> Poll> { + // If we cannot fit any more into the buffer at the moment, we wait + if buf.remaining() == 0 { + return Poll::Ready(Err(io::Error::new( + io::ErrorKind::Other, + "Cannot poll as buf.remaining() == 0", + ))); + } + + // If we have overflow from the last poll, put that in the buffer + if !self.overflow.is_empty() { + if self.overflow.len() > buf.remaining() { + let extra = 
self.overflow.split_off(buf.remaining()); + buf.put_slice(&self.overflow); + self.overflow = extra; + } else { + buf.put_slice(&self.overflow); + self.overflow.clear(); + } + + return Poll::Ready(Ok(())); + } + + // Otherwise, we poll for the next batch to read in + match ready!(self.rx.poll_recv(cx)) { + Some(mut x) => { + if x.len() > buf.remaining() { + self.overflow = x.split_off(buf.remaining()); + } + buf.put_slice(&x); + Poll::Ready(Ok(())) + } + None => Poll::Ready(Ok(())), + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::{InmemoryTransport, IntoSplit}; + use tokio::io::AsyncReadExt; + + #[tokio::test] + async fn read_half_should_fail_if_buf_has_no_space_remaining() { + let (_tx, _rx, transport) = InmemoryTransport::make(1); + let (_t_write, mut t_read) = transport.into_split(); + + let mut buf = [0u8; 0]; + match t_read.read(&mut buf).await { + Err(x) if x.kind() == io::ErrorKind::Other => {} + x => panic!("Unexpected result: {:?}", x), + } + } + + #[tokio::test] + async fn read_half_should_update_buf_with_all_overflow_from_last_read_if_it_all_fits() { + let (tx, _rx, transport) = InmemoryTransport::make(1); + let (_t_write, mut t_read) = transport.into_split(); + + tx.send(vec![1, 2, 3]).await.expect("Failed to send"); + + let mut buf = [0u8; 2]; + + // First, read part of the data (first two bytes) + match t_read.read(&mut buf).await { + Ok(n) if n == 2 => assert_eq!(&buf[..n], &[1, 2]), + x => panic!("Unexpected result: {:?}", x), + } + + // Second, we send more data because the last message was placed in overflow + tx.send(vec![4, 5, 6]).await.expect("Failed to send"); + + // Third, read remainder of the overflow from first message (third byte) + match t_read.read(&mut buf).await { + Ok(n) if n == 1 => assert_eq!(&buf[..n], &[3]), + x => panic!("Unexpected result: {:?}", x), + } + + // Fourth, verify that we start to receive the next overflow + match t_read.read(&mut buf).await { + Ok(n) if n == 2 => assert_eq!(&buf[..n], &[4, 5]), 
+ x => panic!("Unexpected result: {:?}", x), + } + + // Fifth, verify that we get the last bit of overflow + match t_read.read(&mut buf).await { + Ok(n) if n == 1 => assert_eq!(&buf[..n], &[6]), + x => panic!("Unexpected result: {:?}", x), + } + } + + #[tokio::test] + async fn read_half_should_update_buf_with_some_of_overflow_that_can_fit() { + let (tx, _rx, transport) = InmemoryTransport::make(1); + let (_t_write, mut t_read) = transport.into_split(); + + tx.send(vec![1, 2, 3, 4, 5]).await.expect("Failed to send"); + + let mut buf = [0u8; 2]; + + // First, read part of the data (first two bytes) + match t_read.read(&mut buf).await { + Ok(n) if n == 2 => assert_eq!(&buf[..n], &[1, 2]), + x => panic!("Unexpected result: {:?}", x), + } + + // Second, we send more data because the last message was placed in overflow + tx.send(vec![6]).await.expect("Failed to send"); + + // Third, read next chunk of the overflow from first message (next two byte) + match t_read.read(&mut buf).await { + Ok(n) if n == 2 => assert_eq!(&buf[..n], &[3, 4]), + x => panic!("Unexpected result: {:?}", x), + } + + // Fourth, read last chunk of the overflow from first message (fifth byte) + match t_read.read(&mut buf).await { + Ok(n) if n == 1 => assert_eq!(&buf[..n], &[5]), + x => panic!("Unexpected result: {:?}", x), + } + } + + #[tokio::test] + async fn read_half_should_update_buf_with_all_of_inner_channel_when_it_fits() { + let (tx, _rx, transport) = InmemoryTransport::make(1); + let (_t_write, mut t_read) = transport.into_split(); + + let mut buf = [0u8; 5]; + + tx.send(vec![1, 2, 3, 4, 5]).await.expect("Failed to send"); + + // First, read all of data that fits exactly + match t_read.read(&mut buf).await { + Ok(n) if n == 5 => assert_eq!(&buf[..n], &[1, 2, 3, 4, 5]), + x => panic!("Unexpected result: {:?}", x), + } + + tx.send(vec![6, 7, 8]).await.expect("Failed to send"); + + // Second, read data that fits within buf + match t_read.read(&mut buf).await { + Ok(n) if n == 3 => 
assert_eq!(&buf[..n], &[6, 7, 8]), + x => panic!("Unexpected result: {:?}", x), + } + } + + #[tokio::test] + async fn read_half_should_update_buf_with_some_of_inner_channel_that_can_fit_and_add_rest_to_overflow( + ) { + let (tx, _rx, transport) = InmemoryTransport::make(1); + let (_t_write, mut t_read) = transport.into_split(); + + let mut buf = [0u8; 1]; + + tx.send(vec![1, 2, 3, 4, 5]).await.expect("Failed to send"); + + // Attempt a read that places more in overflow + match t_read.read(&mut buf).await { + Ok(n) if n == 1 => assert_eq!(&buf[..n], &[1]), + x => panic!("Unexpected result: {:?}", x), + } + + // Verify overflow contains the rest + assert_eq!(&t_read.overflow, &[2, 3, 4, 5]); + + // Queue up extra data that will not be read until overflow is finished + tx.send(vec![6, 7, 8]).await.expect("Failed to send"); + + // Read next data point + match t_read.read(&mut buf).await { + Ok(n) if n == 1 => assert_eq!(&buf[..n], &[2]), + x => panic!("Unexpected result: {:?}", x), + } + + // Verify overflow contains the rest without having added extra data + assert_eq!(&t_read.overflow, &[3, 4, 5]); + } + + #[tokio::test] + async fn read_half_should_yield_pending_if_no_data_available_on_inner_channel() { + let (_tx, _rx, transport) = InmemoryTransport::make(1); + let (_t_write, mut t_read) = transport.into_split(); + + let mut buf = [0u8; 1]; + + // Attempt a read that should yield ok with no change, which is what should + // happen when nothing is read into buf + let f = t_read.read(&mut buf); + tokio::pin!(f); + match futures::poll!(f) { + Poll::Pending => {} + x => panic!("Unexpected poll result: {:?}", x), + } + } + + #[tokio::test] + async fn read_half_should_not_update_buf_if_inner_channel_closed() { + let (tx, _rx, transport) = InmemoryTransport::make(1); + let (_t_write, mut t_read) = transport.into_split(); + + let mut buf = [0u8; 1]; + + // Drop the channel that would be sending data to the transport + drop(tx); + + // Attempt a read that should yield ok 
with no change, which is what should + // happen when nothing is read into buf + match t_read.read(&mut buf).await { + Ok(n) if n == 0 => assert_eq!(&buf, &[0]), + x => panic!("Unexpected result: {:?}", x), + } + } +} diff --git a/distant-net/src/transport/inmemory/write.rs b/distant-net/src/transport/inmemory/write.rs new file mode 100644 index 0000000..ae96e74 --- /dev/null +++ b/distant-net/src/transport/inmemory/write.rs @@ -0,0 +1,147 @@ +use crate::RawTransportWrite; +use futures::ready; +use std::{ + fmt, + future::Future, + io, + pin::Pin, + task::{Context, Poll}, +}; +use tokio::{io::AsyncWrite, sync::mpsc}; + +/// Write portion of an inmemory channel +pub struct InmemoryTransportWriteHalf { + tx: Option>>, + task: Option> + Send + Sync + 'static>>>, +} + +impl InmemoryTransportWriteHalf { + pub fn new(tx: mpsc::Sender>) -> Self { + Self { + tx: Some(tx), + task: None, + } + } +} + +impl fmt::Debug for InmemoryTransportWriteHalf { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("InmemoryTransportWrite") + .field("tx", &self.tx) + .field( + "task", + &if self.tx.is_some() { + "Some(...)" + } else { + "None" + }, + ) + .finish() + } +} + +impl RawTransportWrite for InmemoryTransportWriteHalf {} + +impl AsyncWrite for InmemoryTransportWriteHalf { + fn poll_write( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll> { + loop { + match self.task.as_mut() { + Some(task) => { + let res = ready!(task.as_mut().poll(cx)); + self.task.take(); + return Poll::Ready(res); + } + None => match self.tx.as_mut() { + Some(tx) => { + let n = buf.len(); + let tx_2 = tx.clone(); + let data = buf.to_vec(); + let task = + Box::pin(async move { tx_2.send(data).await.map(|_| n).or(Ok(0)) }); + self.task.replace(task); + } + None => return Poll::Ready(Ok(0)), + }, + } + } + } + + fn poll_flush(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll> { + Poll::Ready(Ok(())) + } + + fn poll_shutdown(mut self: Pin<&mut Self>, _: &mut 
Context<'_>) -> Poll> { + self.tx.take(); + self.task.take(); + Poll::Ready(Ok(())) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::{InmemoryTransport, IntoSplit}; + use tokio::io::AsyncWriteExt; + + #[tokio::test] + async fn write_half_should_return_buf_len_if_can_send_immediately() { + let (_tx, mut rx, transport) = InmemoryTransport::make(1); + let (mut t_write, _t_read) = transport.into_split(); + + // Write that is not waiting should always succeed with full contents + let n = t_write.write(&[1, 2, 3]).await.expect("Failed to write"); + assert_eq!(n, 3, "Unexpected byte count returned"); + + // Verify we actually had the data sent + let data = rx.try_recv().expect("Failed to recv data"); + assert_eq!(data, &[1, 2, 3]); + } + + #[tokio::test] + async fn write_half_should_return_support_eventually_sending_by_retrying_when_not_ready() { + let (_tx, mut rx, transport) = InmemoryTransport::make(1); + let (mut t_write, _t_read) = transport.into_split(); + + // Queue a write already so that we block on the next one + let _ = t_write.write(&[1, 2, 3]).await.expect("Failed to write"); + + // Verify that the next write is pending + let f = t_write.write(&[4, 5]); + tokio::pin!(f); + match futures::poll!(&mut f) { + Poll::Pending => {} + x => panic!("Unexpected poll result: {:?}", x), + } + + // Consume first batch of data so future of second can continue + let data = rx.try_recv().expect("Failed to recv data"); + assert_eq!(data, &[1, 2, 3]); + + // Verify that poll now returns success + match futures::poll!(f) { + Poll::Ready(Ok(n)) if n == 2 => {} + x => panic!("Unexpected poll result: {:?}", x), + } + + // Consume second batch of data + let data = rx.try_recv().expect("Failed to recv data"); + assert_eq!(data, &[4, 5]); + } + + #[tokio::test] + async fn write_half_should_zero_if_inner_channel_closed() { + let (_tx, rx, transport) = InmemoryTransport::make(1); + let (mut t_write, _t_read) = transport.into_split(); + + // Drop receiving end that 
transport would talk to + drop(rx); + + // Channel is dropped, so return 0 to indicate no bytes sent + let n = t_write.write(&[1, 2, 3]).await.expect("Failed to write"); + assert_eq!(n, 0, "Unexpected byte count returned"); + } +} diff --git a/distant-net/src/transport/mpsc.rs b/distant-net/src/transport/mpsc.rs new file mode 100644 index 0000000..bd50473 --- /dev/null +++ b/distant-net/src/transport/mpsc.rs @@ -0,0 +1,66 @@ +use crate::{IntoSplit, TypedAsyncRead, TypedAsyncWrite, TypedTransport}; +use async_trait::async_trait; +use std::io; +use tokio::sync::mpsc; + +mod read; +pub use read::*; + +mod write; +pub use write::*; + +/// Represents a [`TypedTransport`] of data across the network that uses [`mpsc::Sender`] and +/// [`mpsc::Receiver`] underneath. +#[derive(Debug)] +pub struct MpscTransport { + outbound: MpscTransportWriteHalf, + inbound: MpscTransportReadHalf, +} + +impl MpscTransport { + pub fn new(outbound: mpsc::Sender, inbound: mpsc::Receiver) -> Self { + Self { + outbound: MpscTransportWriteHalf::new(outbound), + inbound: MpscTransportReadHalf::new(inbound), + } + } + + /// Creates a pair of connected transports using `buffer` as maximum + /// channel capacity for each + pub fn pair(buffer: usize) -> (MpscTransport, MpscTransport) { + let (t_tx, t_rx) = mpsc::channel(buffer); + let (u_tx, u_rx) = mpsc::channel(buffer); + ( + MpscTransport::new(t_tx, u_rx), + MpscTransport::new(u_tx, t_rx), + ) + } +} + +impl TypedTransport for MpscTransport {} + +#[async_trait] +impl TypedAsyncWrite for MpscTransport { + async fn write(&mut self, data: T) -> io::Result<()> { + self.outbound + .write(data) + .await + .map_err(|x| io::Error::new(io::ErrorKind::Other, x)) + } +} + +#[async_trait] +impl TypedAsyncRead for MpscTransport { + async fn read(&mut self) -> io::Result> { + self.inbound.read().await + } +} + +impl IntoSplit for MpscTransport { + type Read = MpscTransportReadHalf; + type Write = MpscTransportWriteHalf; + + fn into_split(self) -> (Self::Write, 
Self::Read) { + (self.outbound, self.inbound) + } +} diff --git a/distant-net/src/transport/mpsc/read.rs b/distant-net/src/transport/mpsc/read.rs new file mode 100644 index 0000000..da7e41f --- /dev/null +++ b/distant-net/src/transport/mpsc/read.rs @@ -0,0 +1,22 @@ +use crate::TypedAsyncRead; +use async_trait::async_trait; +use std::io; +use tokio::sync::mpsc; + +#[derive(Debug)] +pub struct MpscTransportReadHalf { + rx: mpsc::Receiver, +} + +impl MpscTransportReadHalf { + pub fn new(rx: mpsc::Receiver) -> Self { + Self { rx } + } +} + +#[async_trait] +impl TypedAsyncRead for MpscTransportReadHalf { + async fn read(&mut self) -> io::Result> { + Ok(self.rx.recv().await) + } +} diff --git a/distant-net/src/transport/mpsc/write.rs b/distant-net/src/transport/mpsc/write.rs new file mode 100644 index 0000000..7801268 --- /dev/null +++ b/distant-net/src/transport/mpsc/write.rs @@ -0,0 +1,25 @@ +use crate::TypedAsyncWrite; +use async_trait::async_trait; +use std::io; +use tokio::sync::mpsc; + +#[derive(Debug)] +pub struct MpscTransportWriteHalf { + tx: mpsc::Sender, +} + +impl MpscTransportWriteHalf { + pub fn new(tx: mpsc::Sender) -> Self { + Self { tx } + } +} + +#[async_trait] +impl TypedAsyncWrite for MpscTransportWriteHalf { + async fn write(&mut self, data: T) -> io::Result<()> { + self.tx + .send(data) + .await + .map_err(|x| io::Error::new(io::ErrorKind::Other, x.to_string())) + } +} diff --git a/distant-net/src/transport/router.rs b/distant-net/src/transport/router.rs new file mode 100644 index 0000000..4f1e4c4 --- /dev/null +++ b/distant-net/src/transport/router.rs @@ -0,0 +1,370 @@ +/// Creates a new struct around a [`UntypedTransport`](crate::UntypedTransport) that routes incoming +/// and outgoing messages to different transports, enabling the ability to transform a singular +/// transport into multiple typed transports that can be combined with [`Client`](crate::Client) +/// and [`Server`](crate::Server) to mix having a variety of clients and servers 
available on the +/// same underlying [`UntypedTransport`](crate::UntypedTransport). +/// +/// ```no_run +/// use distant_net::router; +/// +/// # // To send, the data needs to be serializable +/// # // To receive, the data needs to be deserializable +/// # #[derive(serde::Serialize, serde::Deserialize)] +/// # struct CustomData(u8, u8); +/// +/// // Create a router that produces three transports from one: +/// // 1. `Transport` - receives `String` and sends `u8` +/// // 2. `Transport` - receives `CustomData` and sends `bool` +/// // 3. `Transport, u8>` - receives `u8` and sends `Option` +/// router!(TestRouter { +/// one: String => u8, +/// two: CustomData => bool, +/// three: u8 => Option, +/// }); +/// +/// router!( +/// #[router(inbound = 10, outbound = 20)] +/// TestRouterWithCustomBounds { +/// one: String => u8, +/// two: CustomData => bool, +/// three: u8 => Option, +/// } +/// ); +/// +/// # let (transport, _) = distant_net::FramedTransport::pair(1); +/// +/// let router = TestRouter::new(transport); +/// +/// let one = router.one; // MpscTransport +/// let two = router.two; // MpscTransport +/// let three = router.three; // MpscTransport, u8> +/// ``` +#[macro_export] +macro_rules! router { + ( + $(#[router($($mname:ident = $mvalue:literal),*)])? + $vis:vis $name:ident { + $($transport:ident : $res_ty:ty => $req_ty:ty),+ $(,)? + } + ) => { + $crate::paste::paste! 
{ + #[doc = "Implements a message router for splitting out transport messages"] + #[allow(dead_code)] + $vis struct $name { + reader_task: tokio::task::JoinHandle<()>, + writer_task: tokio::task::JoinHandle<()>, + $( + pub $transport: $crate::MpscTransport<$req_ty, $res_ty>, + )+ + } + + #[allow(dead_code)] + impl $name { + /// Returns the size of the inbound buffer used by this router + pub const fn inbound_buffer_size() -> usize { + Self::buffer_sizes().0 + } + + /// Returns the size of the outbound buffer used by this router + pub const fn outbound_buffer_size() -> usize { + Self::buffer_sizes().1 + } + + /// Returns the size of the inbound and outbound buffers used by this router + /// in the form of `(inbound, outbound)` + pub const fn buffer_sizes() -> (usize, usize) { + // Set defaults for inbound and outbound buffer sizes + let _inbound = 10000; + let _outbound = 10000; + + $($( + let [<_ $mname:snake>] = $mvalue; + )*)? + + (_inbound, _outbound) + } + + #[doc = "Creates a new instance of [`" $name "`]"] + pub fn new(split: T) -> Self + where + T: $crate::IntoSplit, + W: $crate::UntypedTransportWrite + 'static, + R: $crate::UntypedTransportRead + 'static, + { + let (writer, reader) = split.into_split(); + Self::from_writer_and_reader(writer, reader) + } + + #[doc = "Creates a new instance of [`" $name "`] from the given writer and reader"] + pub fn from_writer_and_reader(mut writer: W, mut reader: R) -> Self + where + W: $crate::UntypedTransportWrite + 'static, + R: $crate::UntypedTransportRead + 'static, + { + + $( + let ( + [<$transport:snake _inbound_tx>], + [<$transport:snake _inbound_rx>] + ) = tokio::sync::mpsc::channel(Self::inbound_buffer_size()); + let ( + [<$transport:snake _outbound_tx>], + mut [<$transport:snake _outbound_rx>] + ) = tokio::sync::mpsc::channel(Self::outbound_buffer_size()); + let [<$transport:snake>]: $crate::MpscTransport<$req_ty, $res_ty> = + $crate::MpscTransport::new( + [<$transport:snake _outbound_tx>], + [<$transport:snake 
_inbound_rx>] + ); + )+ + + #[derive(serde::Deserialize)] + #[serde(untagged)] + enum [<$name:camel In>] { + $([<$transport:camel>]($res_ty)),+ + } + + let reader_task = tokio::spawn(async move { + loop { + match $crate::UntypedTransportRead::read(&mut reader).await { + $( + Ok(Some([<$name:camel In>]::[<$transport:camel>](x))) => { + if let Err(x) = [<$transport:snake _inbound_tx>].send(x).await { + $crate::log::error!( + "Failed to forward received data from {} of {}: {}", + std::stringify!($transport), + std::stringify!($name), + x + ); + } + } + )+ + + // Quit if the reader no longer has data + // NOTE: Compiler says this is unreachable, but it is? + #[allow(unreachable_patterns)] + Ok(None) => { + $crate::log::trace!( + "Router {} has closed", + std::stringify!($name), + ); + break; + } + + // Drop any received data that does not map to something + // NOTE: Compiler says this is unreachable, but it is? + #[allow(unreachable_patterns)] + Err(x) => { + $crate::log::error!( + "Failed to read from any transport of {}: {}", + std::stringify!($name), + x + ); + continue; + } + } + } + }); + + let writer_task = tokio::spawn(async move { + loop { + tokio::select! 
{ + $( + Some(x) = [<$transport:snake _outbound_rx>].recv() => { + if let Err(x) = $crate::UntypedTransportWrite::write( + &mut writer, + x, + ).await { + $crate::log::error!( + "Failed to write to {} of {}: {}", + std::stringify!($transport), + std::stringify!($name), + x + ); + } + } + )+ + else => break, + } + } + }); + + Self { + reader_task, + writer_task, + $([<$transport:snake>]),+ + } + } + + pub fn abort(&self) { + self.reader_task.abort(); + self.writer_task.abort(); + } + + pub fn is_finished(&self) -> bool { + self.reader_task.is_finished() && self.writer_task.is_finished() + } + } + } + }; +} + +#[cfg(test)] +mod tests { + use crate::{FramedTransport, TypedAsyncRead, TypedAsyncWrite}; + use serde::{Deserialize, Serialize}; + + // NOTE: Must implement deserialize for our router, + // but we also need serialize to send for our test + #[derive(Debug, PartialEq, Eq, Serialize, Deserialize)] + struct CustomData(u8, String); + + // Creates a private `TestRouter` implementation + // + // 1. Transport receiving `CustomData` and sending `String` + // 2. Transport receiving `String` and sending `u8` + // 3. Transport receiving `bool` and sending `bool` + // 4. 
Transport receiving `Result` and sending `Option` + router!(TestRouter { + one: CustomData => String, + two: String => u8, + three: bool => bool, + should_compile: Result => Option, + }); + + #[test] + fn router_buffer_sizes_should_support_being_overridden() { + router!(DefaultSizes { data: u8 => u8 }); + router!(#[router(inbound = 5)] CustomInboundSize { data: u8 => u8 }); + router!(#[router(outbound = 5)] CustomOutboundSize { data: u8 => u8 }); + router!(#[router(inbound = 5, outbound = 6)] CustomSizes { data: u8 => u8 }); + + assert_eq!(DefaultSizes::buffer_sizes(), (10000, 10000)); + assert_eq!(DefaultSizes::inbound_buffer_size(), 10000); + assert_eq!(DefaultSizes::outbound_buffer_size(), 10000); + + assert_eq!(CustomInboundSize::buffer_sizes(), (5, 10000)); + assert_eq!(CustomInboundSize::inbound_buffer_size(), 5); + assert_eq!(CustomInboundSize::outbound_buffer_size(), 10000); + + assert_eq!(CustomOutboundSize::buffer_sizes(), (10000, 5)); + assert_eq!(CustomOutboundSize::inbound_buffer_size(), 10000); + assert_eq!(CustomOutboundSize::outbound_buffer_size(), 5); + + assert_eq!(CustomSizes::buffer_sizes(), (5, 6)); + assert_eq!(CustomSizes::inbound_buffer_size(), 5); + assert_eq!(CustomSizes::outbound_buffer_size(), 6); + } + + #[tokio::test] + async fn router_should_wire_transports_to_distinguish_incoming_data() { + let (t1, mut t2) = FramedTransport::make_test_pair(); + let TestRouter { + mut one, + mut two, + mut three, + .. 
+ } = TestRouter::new(t1); + + // Send some data of different types that these transports expect + t2.write(false).await.unwrap(); + t2.write("hello world".to_string()).await.unwrap(); + t2.write(CustomData(123, "goodbye world".to_string())) + .await + .unwrap(); + + // Get that data through the appropriate transport + let data = one.read().await.unwrap().unwrap(); + assert_eq!( + data, + CustomData(123, "goodbye world".to_string()), + "string_custom_data_transport got unexpected result" + ); + + let data = two.read().await.unwrap().unwrap(); + assert_eq!( + data, "hello world", + "u8_string_transport got unexpected result" + ); + + let data = three.read().await.unwrap().unwrap(); + assert!(!data, "bool_bool_transport got unexpected result"); + } + + #[tokio::test] + async fn router_should_wire_transports_to_ignore_unknown_incoming_data() { + let (t1, mut t2) = FramedTransport::make_test_pair(); + let TestRouter { + mut one, mut two, .. + } = TestRouter::new(t1); + + #[derive(Serialize, Deserialize)] + struct UnknownData(char, u8); + + // Send some known and unknown data + t2.write("hello world".to_string()).await.unwrap(); + t2.write(UnknownData('a', 99)).await.unwrap(); + t2.write(CustomData(123, "goodbye world".to_string())) + .await + .unwrap(); + + // Get that data through the appropriate transport + let data = one.read().await.unwrap().unwrap(); + assert_eq!( + data, + CustomData(123, "goodbye world".to_string()), + "string_custom_data_transport got unexpected result" + ); + + let data = two.read().await.unwrap().unwrap(); + assert_eq!( + data, "hello world", + "u8_string_transport got unexpected result" + ); + } + + #[tokio::test] + async fn router_should_wire_transports_to_relay_outgoing_data() { + let (t1, mut t2) = FramedTransport::make_test_pair(); + let TestRouter { + mut one, + mut two, + mut three, + .. 
+ } = TestRouter::new(t1); + + // NOTE: Introduce a sleep between each send, otherwise we are + // resolving futures in a way where the ordering may + // get mixed up on the way out + async fn wait() { + tokio::time::sleep(std::time::Duration::from_millis(50)).await; + } + + // Send some data of different types that these transports expect + three.write(true).await.unwrap(); + wait().await; + two.write(123).await.unwrap(); + wait().await; + one.write("hello world".to_string()).await.unwrap(); + + // All of that data should funnel through our primary transport, + // but the order is NOT guaranteed! So we need to store + let data: bool = t2.read().await.unwrap().unwrap(); + assert!( + data, + "Unexpected data received from bool_bool_transport output" + ); + + let data: u8 = t2.read().await.unwrap().unwrap(); + assert_eq!( + data, 123, + "Unexpected data received from u8_string_transport output" + ); + + let data: String = t2.read().await.unwrap().unwrap(); + assert_eq!( + data, "hello world", + "Unexpected data received from string_custom_data_transport output" + ); + } +} diff --git a/distant-net/src/transport/tcp.rs b/distant-net/src/transport/tcp.rs new file mode 100644 index 0000000..909e1e2 --- /dev/null +++ b/distant-net/src/transport/tcp.rs @@ -0,0 +1,196 @@ +use crate::{IntoSplit, RawTransport, RawTransportRead, RawTransportWrite}; +use std::{ + fmt, io, + net::IpAddr, + pin::Pin, + task::{Context, Poll}, +}; +use tokio::{ + io::{AsyncRead, AsyncWrite, ReadBuf}, + net::{ + tcp::{OwnedReadHalf, OwnedWriteHalf}, + TcpStream, ToSocketAddrs, + }, +}; + +/// Represents a [`RawTransport`] that leverages a TCP stream +pub struct TcpTransport { + pub(crate) addr: IpAddr, + pub(crate) port: u16, + pub(crate) inner: TcpStream, +} + +impl TcpTransport { + /// Creates a new stream by connecting to a remote machine at the specified + /// IP address and port + pub async fn connect(addrs: impl ToSocketAddrs) -> io::Result { + let stream = TcpStream::connect(addrs).await?; + 
let addr = stream.peer_addr()?; + Ok(Self { + addr: addr.ip(), + port: addr.port(), + inner: stream, + }) + } + + /// Returns the IP address that the stream is connected to + pub fn ip_addr(&self) -> IpAddr { + self.addr + } + + /// Returns the port that the stream is connected to + pub fn port(&self) -> u16 { + self.port + } +} + +impl fmt::Debug for TcpTransport { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("TcpTransport") + .field("addr", &self.addr) + .field("port", &self.port) + .finish() + } +} + +impl RawTransport for TcpTransport {} +impl RawTransportRead for TcpTransport {} +impl RawTransportWrite for TcpTransport {} + +impl RawTransportRead for OwnedReadHalf {} +impl RawTransportWrite for OwnedWriteHalf {} + +impl IntoSplit for TcpTransport { + type Read = OwnedReadHalf; + type Write = OwnedWriteHalf; + + fn into_split(self) -> (Self::Write, Self::Read) { + let (r, w) = TcpStream::into_split(self.inner); + (w, r) + } +} + +impl AsyncRead for TcpTransport { + fn poll_read( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut ReadBuf<'_>, + ) -> Poll> { + Pin::new(&mut self.inner).poll_read(cx, buf) + } +} + +impl AsyncWrite for TcpTransport { + fn poll_write( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll> { + Pin::new(&mut self.inner).poll_write(cx, buf) + } + + fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + Pin::new(&mut self.inner).poll_flush(cx) + } + + fn poll_shutdown( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll> { + Pin::new(&mut self.inner).poll_shutdown(cx) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use std::net::{Ipv6Addr, SocketAddr}; + use tokio::{ + io::{AsyncReadExt, AsyncWriteExt}, + net::TcpListener, + sync::oneshot, + task::JoinHandle, + }; + + async fn find_ephemeral_addr() -> SocketAddr { + // Start a listener on a distinct port, get its port, and kill it + // NOTE: This is a race condition as something 
else could bind to + // this port inbetween us killing it and us attempting to + // connect to it. We're willing to take that chance + let addr = IpAddr::V6(Ipv6Addr::LOCALHOST); + + let listener = TcpListener::bind((addr, 0)) + .await + .expect("Failed to bind on an ephemeral port"); + + let port = listener + .local_addr() + .expect("Failed to look up ephemeral port") + .port(); + + SocketAddr::from((addr, port)) + } + + #[tokio::test] + async fn should_fail_to_connect_if_nothing_listening() { + let addr = find_ephemeral_addr().await; + + // Now this should fail as we've stopped what was listening + TcpTransport::connect(addr).await.expect_err(&format!( + "Unexpectedly succeeded in connecting to ghost address: {}", + addr + )); + } + + #[tokio::test] + async fn should_be_able_to_send_and_receive_data() { + let (tx, rx) = oneshot::channel(); + + // Spawn a task that will wait for a connection, send data, + // and receive data that it will return in the task + let task: JoinHandle> = tokio::spawn(async move { + let addr = find_ephemeral_addr().await; + + // Start listening at the distinct address + let listener = TcpListener::bind(addr).await?; + + // Send the address back to our main test thread + tx.send(addr) + .map_err(|x| io::Error::new(io::ErrorKind::Other, x.to_string()))?; + + // Get the connection + let (mut conn, _) = listener.accept().await?; + + // Send some data to the connection (10 bytes) + conn.write_all(b"hello conn").await?; + + // Receive some data from the connection (12 bytes) + let mut buf: [u8; 12] = [0; 12]; + let _ = conn.read_exact(&mut buf).await?; + assert_eq!(&buf, b"hello server"); + + Ok(()) + }); + + // Wait for the server to be ready + let addr = rx.await.expect("Failed to get server server address"); + + // Connect to the socket, send some bytes, and get some bytes + let mut buf: [u8; 10] = [0; 10]; + + let mut conn = TcpTransport::connect(&addr) + .await + .expect("Conn failed to connect"); + conn.read_exact(&mut buf) + .await + 
.expect("Conn failed to read"); + assert_eq!(&buf, b"hello conn"); + + conn.write_all(b"hello server") + .await + .expect("Conn failed to write"); + + // Verify that the task has completed by waiting on it + let _ = task.await.expect("Server task failed unexpectedly"); + } +} diff --git a/distant-net/src/transport/unix.rs b/distant-net/src/transport/unix.rs new file mode 100644 index 0000000..a20142d --- /dev/null +++ b/distant-net/src/transport/unix.rs @@ -0,0 +1,187 @@ +use crate::{IntoSplit, RawTransport, RawTransportRead, RawTransportWrite}; +use std::{ + fmt, io, + path::{Path, PathBuf}, + pin::Pin, + task::{Context, Poll}, +}; +use tokio::{ + io::{AsyncRead, AsyncWrite, ReadBuf}, + net::{ + unix::{OwnedReadHalf, OwnedWriteHalf}, + UnixStream, + }, +}; + +/// Represents a [`RawTransport`] that leverages a Unix socket +pub struct UnixSocketTransport { + pub(crate) path: PathBuf, + pub(crate) inner: UnixStream, +} + +impl UnixSocketTransport { + /// Creates a new stream by connecting to the specified path + pub async fn connect(path: impl AsRef) -> io::Result { + let stream = UnixStream::connect(path.as_ref()).await?; + Ok(Self { + path: path.as_ref().to_path_buf(), + inner: stream, + }) + } + + /// Returns the path to the socket + pub fn path(&self) -> &Path { + &self.path + } +} + +impl fmt::Debug for UnixSocketTransport { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("UnixSocketTransport") + .field("path", &self.path) + .finish() + } +} + +impl RawTransport for UnixSocketTransport {} +impl RawTransportRead for UnixSocketTransport {} +impl RawTransportWrite for UnixSocketTransport {} + +impl RawTransportRead for OwnedReadHalf {} +impl RawTransportWrite for OwnedWriteHalf {} + +impl IntoSplit for UnixSocketTransport { + type Read = OwnedReadHalf; + type Write = OwnedWriteHalf; + + fn into_split(self) -> (Self::Write, Self::Read) { + let (r, w) = UnixStream::into_split(self.inner); + (w, r) + } +} + +impl AsyncRead for 
UnixSocketTransport { + fn poll_read( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut ReadBuf<'_>, + ) -> Poll> { + Pin::new(&mut self.inner).poll_read(cx, buf) + } +} + +impl AsyncWrite for UnixSocketTransport { + fn poll_write( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll> { + Pin::new(&mut self.inner).poll_write(cx, buf) + } + + fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + Pin::new(&mut self.inner).poll_flush(cx) + } + + fn poll_shutdown( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll> { + Pin::new(&mut self.inner).poll_shutdown(cx) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use tempfile::NamedTempFile; + use tokio::{ + io::{AsyncReadExt, AsyncWriteExt}, + net::UnixListener, + sync::oneshot, + task::JoinHandle, + }; + + #[tokio::test] + async fn should_fail_to_connect_if_socket_does_not_exist() { + // Generate a socket path and delete the file after so there is nothing there + let path = NamedTempFile::new() + .expect("Failed to create socket file") + .path() + .to_path_buf(); + + // Now this should fail as we're already bound to the name + UnixSocketTransport::connect(&path) + .await + .expect_err("Unexpectedly succeeded in connecting to missing socket"); + } + + #[tokio::test] + async fn should_fail_to_connect_if_path_is_not_a_socket() { + // Generate a regular file + let path = NamedTempFile::new() + .expect("Failed to create socket file") + .into_temp_path(); + + // Now this should fail as this file is not a socket + UnixSocketTransport::connect(&path) + .await + .expect_err("Unexpectedly succeeded in connecting to regular file"); + } + + #[tokio::test] + async fn should_be_able_to_send_and_receive_data() { + let (tx, rx) = oneshot::channel(); + + // Spawn a task that will wait for a connection, send data, + // and receive data that it will return in the task + let task: JoinHandle> = tokio::spawn(async move { + // Generate a socket path and delete the 
file after so there is nothing there + let path = NamedTempFile::new() + .expect("Failed to create socket file") + .path() + .to_path_buf(); + + // Start listening at the socket path + let socket = UnixListener::bind(&path)?; + + // Send the path back to our main test thread + tx.send(path) + .map_err(|x| io::Error::new(io::ErrorKind::Other, x.display().to_string()))?; + + // Get the connection + let (mut conn, _) = socket.accept().await?; + + // Send some data to the connection (10 bytes) + conn.write_all(b"hello conn").await?; + + // Receive some data from the connection (12 bytes) + let mut buf: [u8; 12] = [0; 12]; + let _ = conn.read_exact(&mut buf).await?; + assert_eq!(&buf, b"hello server"); + + Ok(()) + }); + + // Wait for the server to be ready + let path = rx.await.expect("Failed to get server socket path"); + + // Connect to the socket, send some bytes, and get some bytes + let mut buf: [u8; 10] = [0; 10]; + + let mut conn = UnixSocketTransport::connect(&path) + .await + .expect("Conn failed to connect"); + conn.read_exact(&mut buf) + .await + .expect("Conn failed to read"); + assert_eq!(&buf, b"hello conn"); + + conn.write_all(b"hello server") + .await + .expect("Conn failed to write"); + + // Verify that the task has completed by waiting on it + let _ = task.await.expect("Server task failed unexpectedly"); + } +} diff --git a/distant-net/src/transport/untyped.rs b/distant-net/src/transport/untyped.rs new file mode 100644 index 0000000..dfe871a --- /dev/null +++ b/distant-net/src/transport/untyped.rs @@ -0,0 +1,61 @@ +use crate::{TypedAsyncRead, TypedAsyncWrite, TypedTransport}; +use async_trait::async_trait; +use serde::{de::DeserializeOwned, Serialize}; +use std::io; + +/// Interface representing a transport that uses [`serde`] to serialize and deserialize data +/// as it is sent and received +pub trait UntypedTransport: UntypedTransportRead + UntypedTransportWrite {} + +/// Interface representing a transport's read half that uses [`serde`] to 
deserialize data as it is +/// received +#[async_trait] +pub trait UntypedTransportRead: Send + Unpin { + /// Attempts to read some data as `T`, returning [`io::Error`] if unable to deserialize + /// or some other error occurs. `Some(T)` is returned if successful. `None` is + /// returned if no more data is available. + async fn read(&mut self) -> io::Result> + where + T: DeserializeOwned; +} + +/// Interface representing a transport's write half that uses [`serde`] to serialize data as it is +/// sent +#[async_trait] +pub trait UntypedTransportWrite: Send + Unpin { + /// Attempts to write some data of type `T`, returning [`io::Error`] if unable to serialize + /// or some other error occurs. + async fn write(&mut self, data: T) -> io::Result<()> + where + T: Serialize + Send + 'static; +} + +impl TypedTransport for T +where + T: UntypedTransport + Send, + W: Serialize + Send + 'static, + R: DeserializeOwned, +{ +} + +#[async_trait] +impl TypedAsyncWrite for W +where + W: UntypedTransportWrite + Send, + T: Serialize + Send + 'static, +{ + async fn write(&mut self, data: T) -> io::Result<()> { + W::write(self, data).await + } +} + +#[async_trait] +impl TypedAsyncRead for R +where + R: UntypedTransportRead + Send, + T: DeserializeOwned, +{ + async fn read(&mut self) -> io::Result> { + R::read(self).await + } +} diff --git a/distant-net/src/transport/windows.rs b/distant-net/src/transport/windows.rs new file mode 100644 index 0000000..e4d87f8 --- /dev/null +++ b/distant-net/src/transport/windows.rs @@ -0,0 +1,202 @@ +use crate::{IntoSplit, RawTransport, RawTransportRead, RawTransportWrite}; +use std::{ + ffi::{OsStr, OsString}, + fmt, io, + pin::Pin, + task::{Context, Poll}, + time::Duration, +}; +use tokio::{ + io::{AsyncRead, AsyncWrite, ReadBuf, ReadHalf, WriteHalf}, + net::windows::named_pipe::ClientOptions, +}; + +// Equivalent to winapi::shared::winerror::ERROR_PIPE_BUSY +// DWORD -> c_uLong -> u32 +const ERROR_PIPE_BUSY: u32 = 231; + +// Time between attempts to 
connect to a busy pipe +const BUSY_PIPE_SLEEP_MILLIS: u64 = 50; + +mod pipe; +pub use pipe::NamedPipe; + +/// Represents a [`RawTransport`] that leverages a named Windows pipe (client or server) +pub struct WindowsPipeTransport { + pub(crate) addr: OsString, + pub(crate) inner: NamedPipe, +} + +impl WindowsPipeTransport { + /// Establishes a connection to the pipe with the specified name, using the + /// name for a local pipe address in the form of `\\.\pipe\my_pipe_name` where + /// `my_pipe_name` is provided to this function + pub async fn connect_local(name: impl AsRef) -> io::Result { + let mut addr = OsString::from(r"\\.\pipe\"); + addr.push(name.as_ref()); + Self::connect(addr).await + } + + /// Establishes a connection to the pipe at the specified address + /// + /// Address may be something like `\.\pipe\my_pipe_name` + pub async fn connect(addr: impl Into) -> io::Result { + let addr = addr.into(); + + let pipe = loop { + match ClientOptions::new().open(&addr) { + Ok(client) => break client, + Err(e) if e.raw_os_error() == Some(ERROR_PIPE_BUSY as i32) => (), + Err(e) => return Err(e), + } + + tokio::time::sleep(Duration::from_millis(BUSY_PIPE_SLEEP_MILLIS)).await; + }; + + Ok(Self { + addr, + inner: NamedPipe::from(pipe), + }) + } + + /// Returns the addr that the listener is bound to + pub fn addr(&self) -> &OsStr { + &self.addr + } +} + +impl fmt::Debug for WindowsPipeTransport { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("WindowsPipeTransport") + .field("addr", &self.addr) + .finish() + } +} + +impl RawTransport for WindowsPipeTransport {} +impl RawTransportRead for WindowsPipeTransport {} +impl RawTransportWrite for WindowsPipeTransport {} + +impl RawTransportRead for ReadHalf {} +impl RawTransportWrite for WriteHalf {} + +impl IntoSplit for WindowsPipeTransport { + type Read = ReadHalf; + type Write = WriteHalf; + + fn into_split(self) -> (Self::Write, Self::Read) { + let (reader, writer) = tokio::io::split(self); + 
(writer, reader) + } +} + +impl AsyncRead for WindowsPipeTransport { + fn poll_read( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut ReadBuf<'_>, + ) -> Poll> { + Pin::new(&mut self.inner).poll_read(cx, buf) + } +} + +impl AsyncWrite for WindowsPipeTransport { + fn poll_write( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll> { + Pin::new(&mut self.inner).poll_write(cx, buf) + } + + fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + Pin::new(&mut self.inner).poll_flush(cx) + } + + fn poll_shutdown( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll> { + Pin::new(&mut self.inner).poll_shutdown(cx) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use tokio::{ + io::{AsyncReadExt, AsyncWriteExt}, + net::windows::named_pipe::ServerOptions, + sync::oneshot, + task::JoinHandle, + }; + + #[tokio::test] + async fn should_fail_to_connect_if_pipe_does_not_exist() { + // Generate a pipe name + let name = format!("test_pipe_{}", rand::random::()); + + // Now this should fail as we're already bound to the name + WindowsPipeTransport::connect_local(&name) + .await + .expect_err("Unexpectedly succeeded in connecting to missing pipe"); + } + + #[tokio::test] + async fn should_be_able_to_send_and_receive_data() { + let (tx, rx) = oneshot::channel(); + + // Spawn a task that will wait for a connection, send data, + // and receive data that it will return in the task + let task: JoinHandle> = tokio::spawn(async move { + // Generate a pipe address (not just a name) + let addr = format!(r"\\.\pipe\test_pipe_{}", rand::random::()); + + // Listen at the pipe + let pipe = ServerOptions::new() + .first_pipe_instance(true) + .create(&addr)?; + + // Send the address back to our main test thread + tx.send(addr) + .map_err(|x| io::Error::new(io::ErrorKind::Other, x))?; + + // Get the connection + let mut conn = { + pipe.connect().await?; + pipe + }; + + // Send some data to the connection (10 bytes) + 
conn.write_all(b"hello conn").await?; + + // Receive some data from the connection (12 bytes) + let mut buf: [u8; 12] = [0; 12]; + let _ = conn.read_exact(&mut buf).await?; + assert_eq!(&buf, b"hello server"); + + Ok(()) + }); + + // Wait for the server to be ready + let address = rx.await.expect("Failed to get server address"); + + // Connect to the pipe, send some bytes, and get some bytes + let mut buf: [u8; 10] = [0; 10]; + + let mut conn = WindowsPipeTransport::connect(&address) + .await + .expect("Conn failed to connect"); + conn.read_exact(&mut buf) + .await + .expect("Conn failed to read"); + assert_eq!(&buf, b"hello conn"); + + conn.write_all(b"hello server") + .await + .expect("Conn failed to write"); + + // Verify that the task has completed by waiting on it + let _ = task.await.expect("Server task failed unexpectedly"); + } +} diff --git a/distant-net/src/transport/windows/pipe.rs b/distant-net/src/transport/windows/pipe.rs new file mode 100644 index 0000000..532d0ed --- /dev/null +++ b/distant-net/src/transport/windows/pipe.rs @@ -0,0 +1,101 @@ +use derive_more::{From, TryInto}; +use std::{ + pin::Pin, + task::{Context, Poll}, +}; +use tokio::{ + io::{self, AsyncRead, AsyncWrite, ReadBuf}, + net::windows::named_pipe::{NamedPipeClient, NamedPipeServer}, +}; + +#[derive(From, TryInto)] +pub enum NamedPipe { + Client(NamedPipeClient), + Server(NamedPipeServer), +} + +impl NamedPipe { + pub fn as_client(&self) -> Option<&NamedPipeClient> { + match self { + Self::Client(x) => Some(x), + _ => None, + } + } + + pub fn as_mut_client(&mut self) -> Option<&mut NamedPipeClient> { + match self { + Self::Client(x) => Some(x), + _ => None, + } + } + + pub fn into_client(self) -> Option { + match self { + Self::Client(x) => Some(x), + _ => None, + } + } + + pub fn as_server(&self) -> Option<&NamedPipeServer> { + match self { + Self::Server(x) => Some(x), + _ => None, + } + } + + pub fn as_mut_server(&mut self) -> Option<&mut NamedPipeServer> { + match self { + 
Self::Server(x) => Some(x), + _ => None, + } + } + + pub fn into_server(self) -> Option { + match self { + Self::Server(x) => Some(x), + _ => None, + } + } +} +impl AsyncRead for NamedPipe { + fn poll_read( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut ReadBuf<'_>, + ) -> Poll> { + match Pin::get_mut(self) { + Self::Client(x) => Pin::new(x).poll_read(cx, buf), + Self::Server(x) => Pin::new(x).poll_read(cx, buf), + } + } +} + +impl AsyncWrite for NamedPipe { + fn poll_write( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll> { + match Pin::get_mut(self) { + Self::Client(x) => Pin::new(x).poll_write(cx, buf), + Self::Server(x) => Pin::new(x).poll_write(cx, buf), + } + } + + fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + match Pin::get_mut(self) { + Self::Client(x) => Pin::new(x).poll_flush(cx), + Self::Server(x) => Pin::new(x).poll_flush(cx), + } + } + + fn poll_shutdown( + self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> Poll> { + match Pin::get_mut(self) { + Self::Client(x) => Pin::new(x).poll_shutdown(cx), + Self::Server(x) => Pin::new(x).poll_shutdown(cx), + } + } +} diff --git a/distant-net/src/utils.rs b/distant-net/src/utils.rs new file mode 100644 index 0000000..34b9074 --- /dev/null +++ b/distant-net/src/utils.rs @@ -0,0 +1,20 @@ +use serde::{de::DeserializeOwned, Serialize}; +use std::io; + +pub fn serialize_to_vec(value: &T) -> io::Result> { + rmp_serde::encode::to_vec_named(value).map_err(|x| { + io::Error::new( + io::ErrorKind::InvalidData, + format!("Serialize failed: {}", x), + ) + }) +} + +pub fn deserialize_from_slice(slice: &[u8]) -> io::Result { + rmp_serde::decode::from_slice(slice).map_err(|x| { + io::Error::new( + io::ErrorKind::InvalidData, + format!("Deserialize failed: {}", x), + ) + }) +} diff --git a/distant-net/tests/auth.rs b/distant-net/tests/auth.rs new file mode 100644 index 0000000..b666a66 --- /dev/null +++ b/distant-net/tests/auth.rs @@ -0,0 
+1,169 @@ +use distant_net::{ + AuthClient, AuthErrorKind, AuthQuestion, AuthRequest, AuthServer, AuthVerifyKind, Client, + IntoSplit, MpscListener, MpscTransport, ServerExt, +}; +use std::collections::HashMap; +use tokio::sync::mpsc; + +/// Spawns a server and client connected together, returning the client +fn setup() -> (AuthClient, mpsc::Receiver) { + // Make a pair of inmemory transports that we can use to test client and server connected + let (t1, t2) = MpscTransport::pair(100); + + // Create the client + let (writer, reader) = t1.into_split(); + let client = AuthClient::from(Client::new(writer, reader).unwrap()); + + // Prepare a channel where we can pass back out whatever request we get + let (tx, rx) = mpsc::channel(100); + + let tx_2 = tx.clone(); + let tx_3 = tx.clone(); + let tx_4 = tx.clone(); + + // Make a server that echos questions back as answers and only verifies the text "yes" + let server = AuthServer { + on_challenge: move |questions, extra| { + let questions_2 = questions.clone(); + tx.try_send(AuthRequest::Challenge { questions, extra }) + .unwrap(); + questions_2.into_iter().map(|x| x.text).collect() + }, + on_verify: move |kind, text| { + let valid = text == "yes"; + tx_2.try_send(AuthRequest::Verify { kind, text }).unwrap(); + valid + }, + on_info: move |text| { + tx_3.try_send(AuthRequest::Info { text }).unwrap(); + }, + on_error: move |kind, text| { + tx_4.try_send(AuthRequest::Error { kind, text }).unwrap(); + }, + }; + + // Spawn the server to listen for our client to connect + tokio::spawn(async move { + let (writer, reader) = t2.into_split(); + let (tx, listener) = MpscListener::channel(1); + tx.send((writer, reader)).await.unwrap(); + let _server = server.start(listener).unwrap(); + }); + + (client, rx) +} + +#[tokio::test] +async fn client_should_be_able_to_challenge_against_server() { + let (mut client, mut rx) = setup(); + + // Gotta start with the handshake first + client.handshake().await.unwrap(); + + // Now do the challenge 
+ assert_eq!( + client + .challenge( + vec![AuthQuestion::new("hello".to_string())], + Default::default() + ) + .await + .unwrap(), + vec!["hello".to_string()] + ); + + // Verify that the server received the request + let request = rx.recv().await.unwrap(); + match request { + AuthRequest::Challenge { questions, extra } => { + assert_eq!(questions.len(), 1); + assert_eq!(questions[0].text, "hello"); + assert_eq!(questions[0].extra, HashMap::new()); + + assert_eq!(extra, HashMap::new()); + } + x => panic!("Unexpected request received by server: {:?}", x), + } +} + +#[tokio::test] +async fn client_should_be_able_to_verify_against_server() { + let (mut client, mut rx) = setup(); + + // Gotta start with the handshake first + client.handshake().await.unwrap(); + + // "no" will yield false + assert!(!client + .verify(AuthVerifyKind::Host, "no".to_string()) + .await + .unwrap()); + + // Verify that the server received the request + let request = rx.recv().await.unwrap(); + match request { + AuthRequest::Verify { kind, text } => { + assert_eq!(kind, AuthVerifyKind::Host); + assert_eq!(text, "no"); + } + x => panic!("Unexpected request received by server: {:?}", x), + } + + // "yes" will yield true + assert!(client + .verify(AuthVerifyKind::Host, "yes".to_string()) + .await + .unwrap()); + + // Verify that the server received the request + let request = rx.recv().await.unwrap(); + match request { + AuthRequest::Verify { kind, text } => { + assert_eq!(kind, AuthVerifyKind::Host); + assert_eq!(text, "yes"); + } + x => panic!("Unexpected request received by server: {:?}", x), + } +} + +#[tokio::test] +async fn client_should_be_able_to_send_info_to_server() { + let (mut client, mut rx) = setup(); + + // Gotta start with the handshake first + client.handshake().await.unwrap(); + + // Send some information + client.info(String::from("hello, world")).await.unwrap(); + + // Verify that the server received the request + let request = rx.recv().await.unwrap(); + match request { + 
AuthRequest::Info { text } => assert_eq!(text, "hello, world"), + x => panic!("Unexpected request received by server: {:?}", x), + } +} + +#[tokio::test] +async fn client_should_be_able_to_send_error_to_server() { + let (mut client, mut rx) = setup(); + + // Gotta start with the handshake first + client.handshake().await.unwrap(); + + // Send some error + client + .error(AuthErrorKind::Unknown, String::from("hello, world")) + .await + .unwrap(); + + // Verify that the server received the request + let request = rx.recv().await.unwrap(); + match request { + AuthRequest::Error { kind, text } => { + assert_eq!(kind, AuthErrorKind::Unknown); + assert_eq!(text, "hello, world"); + } + x => panic!("Unexpected request received by server: {:?}", x), + } +} diff --git a/distant-net/tests/lib.rs b/distant-net/tests/lib.rs new file mode 100644 index 0000000..12bc9de --- /dev/null +++ b/distant-net/tests/lib.rs @@ -0,0 +1 @@ +mod auth; diff --git a/distant-ssh2/Cargo.toml b/distant-ssh2/Cargo.toml index c49c87a..7d1db0d 100644 --- a/distant-ssh2/Cargo.toml +++ b/distant-ssh2/Cargo.toml @@ -2,9 +2,9 @@ name = "distant-ssh2" description = "Library to enable native ssh-2 protocol for use with distant sessions" categories = ["network-programming"] -version = "0.16.4" +version = "0.17.0" authors = ["Chip Senkbeil "] -edition = "2018" +edition = "2021" homepage = "https://github.com/chipsenkbeil/distant" repository = "https://github.com/chipsenkbeil/distant" readme = "README.md" @@ -17,25 +17,34 @@ ssh2 = ["wezterm-ssh/ssh2", "wezterm-ssh/vendored-openssl-ssh2"] [dependencies] async-compat = "0.2.1" -distant-core = { version = "=0.16.4", path = "../distant-core" } +async-once-cell = "0.4.2" +async-trait = "0.1.56" +derive_more = { version = "0.99.17", default-features = false, features = ["display", "error"] } +distant-core = { version = "=0.17.0", path = "../distant-core" } futures = "0.3.16" +hex = "0.4.3" log = "0.4.14" +openssl-src = "=300.0.7" rand = { version = "0.8.4", 
features = ["getrandom"] } -rpassword = "5.0.1" +rpassword = "6.0.1" shell-words = "1.0" smol = "1.2" tokio = { version = "1.12.0", features = ["full"] } -wezterm-ssh = { version = "0.4.0", default-features = false } +wezterm-ssh = { version = "=0.4.0", default-features = false } +winsplit = "0.1" # Optional serde support for data structures serde = { version = "1.0.126", features = ["derive"], optional = true } [dev-dependencies] +anyhow = "1.0.58" assert_cmd = "2.0.0" assert_fs = "1.0.4" +dunce = "1.0.2" flexi_logger = "0.19.4" indoc = "1.0.3" once_cell = "1.8.0" predicates = "2.0.2" rstest = "0.11.0" +which = "4.2.5" whoami = "1.1.4" diff --git a/distant-ssh2/README.md b/distant-ssh2/README.md index 249bc9f..d2ab8f8 100644 --- a/distant-ssh2/README.md +++ b/distant-ssh2/README.md @@ -1,13 +1,13 @@ # distant ssh2 -[![Crates.io][distant_crates_img]][distant_crates_lnk] [![Docs.rs][distant_doc_img]][distant_doc_lnk] [![Rustc 1.51.0][distant_rustc_img]][distant_rustc_lnk] +[![Crates.io][distant_crates_img]][distant_crates_lnk] [![Docs.rs][distant_doc_img]][distant_doc_lnk] [![Rustc 1.61.0][distant_rustc_img]][distant_rustc_lnk] [distant_crates_img]: https://img.shields.io/crates/v/distant-ssh2.svg [distant_crates_lnk]: https://crates.io/crates/distant-ssh2 [distant_doc_img]: https://docs.rs/distant-ssh2/badge.svg [distant_doc_lnk]: https://docs.rs/distant-ssh2 -[distant_rustc_img]: https://img.shields.io/badge/distant_ssh2-rustc_1.51+-lightgray.svg -[distant_rustc_lnk]: https://blog.rust-lang.org/2021/03/25/Rust-1.51.0.html +[distant_rustc_img]: https://img.shields.io/badge/distant_ssh2-rustc_1.61+-lightgray.svg +[distant_rustc_lnk]: https://blog.rust-lang.org/2022/05/19/Rust-1.61.0.html Library provides native ssh integration into the [`distant`](https://github.com/chipsenkbeil/distant) binary. 
@@ -16,14 +16,9 @@ Library provides native ssh integration into the ## Details -The `distant-ssh2` library supplies functionality to - -- Asynchronous in nature, powered by [`tokio`](https://tokio.rs/) -- Data is serialized to send across the wire via [`CBOR`](https://cbor.io/) -- Encryption & authentication are handled via - [XChaCha20Poly1305](https://tools.ietf.org/html/rfc8439) for an authenticated - encryption scheme via - [RustCrypto/ChaCha20Poly1305](https://github.com/RustCrypto/AEADs/tree/master/chacha20poly1305) +The `distant-ssh2` library supplies functionality to speak over the `ssh` +protocol using `distant` and spawn `distant` servers on remote machines using +`ssh`. ## Installation @@ -31,52 +26,58 @@ You can import the dependency by adding the following to your `Cargo.toml`: ```toml [dependencies] -distant-ssh2 = "0.16" +distant-ssh2 = "0.17" ``` ## Examples -Below is an example of connecting to an ssh server and producing a distant -session that uses ssh without a distant server binary: +Below is an example of connecting to an ssh server and translating between ssh +protocol and distant protocol: ```rust -use distant_ssh2::Ssh2Session; +use distant_ssh2::{LocalSshAuthHandler, Ssh, SshOpts}; -// Using default ssh session arguments to establish a connection -let mut ssh_session = Ssh2Session::connect("example.com", Default::default()).expect("Failed to connect"); +// Using default ssh arguments to establish a connection +let mut ssh = Ssh::connect("example.com", SshOpts::default()) + .expect("Failed to connect"); // Authenticating with the server is a separate step -// 1. You can pass defaults and authentication and host verification will -// be done over stderr +// 1. You can pass the local handler and authentication and host verification +// will be done over stderr // 2. 
You can provide your own handlers for programmatic engagement -ssh_session.authenticate(Default::default()).await.expect("Failed to authenticate"); +ssh.authenticate(LocalSshAuthHandler).await + .expect("Failed to authenticate"); // Convert into an ssh client session (no distant server required) -let session = ssh_session.into_ssh_client_session().await.expect("Failed to convert session"); +let client = ssh.into_distant_client().await + .expect("Failed to convert into distant client"); ``` -Below is an example of connecting to an ssh server and producing a distant -session that spawns a distant server binary and then connects to it: +Below is an example of connecting to an ssh server, spawning a distant server +on the remote machine, and connecting to the distant server: ```rust -use distant_ssh2::Ssh2Session; +use distant_ssh2::{DistantLaunchOpts, LocalSshAuthHandler, Ssh, SshOpts}; -// Using default ssh session arguments to establish a connection -let mut ssh_session = Ssh2Session::connect("example.com", Default::default()).expect("Failed to connect"); +// Using default ssh arguments to establish a connection +let mut ssh = Ssh::connect("example.com", SshOpts::default()) + .expect("Failed to connect"); // Authenticating with the server is a separate step -// 1. You can pass defaults and authentication and host verification will -// be done over stderr +// 1. You can pass the local handler and authentication and host verification +// will be done over stderr // 2. 
You can provide your own handlers for programmatic engagement -ssh_session.authenticate(Default::default()).await.expect("Failed to authenticate"); +ssh.authenticate(LocalSshAuthHandler).await + .expect("Failed to authenticate"); // Convert into a distant session, which involves spawning a distant server // using the current ssh connection and then establishing a new connection // to the distant server // -// This takes in `IntoDistantSessionOpts` to specify the server's bin path, +// This takes in `DistantLaunchOpts` to specify the server's bin path, // arguments, timeout, and whether or not to spawn using a login shell -let session = ssh_session.into_distant_session(Default::default()).await.expect("Failed to convert session"); +let client = ssh.launch_and_connect(DistantLaunchOpts::default()).await + .expect("Failed to spawn server or connect to it"); ``` ## License diff --git a/distant-ssh2/src/api.rs b/distant-ssh2/src/api.rs new file mode 100644 index 0000000..0b5942c --- /dev/null +++ b/distant-ssh2/src/api.rs @@ -0,0 +1,842 @@ +use crate::{ + process::{spawn_pty, spawn_simple, SpawnResult}, + utils::{self, to_other_error}, +}; +use async_compat::CompatExt; +use async_trait::async_trait; +use distant_core::{ + data::{ + DirEntry, Environment, FileType, Metadata, ProcessId, PtySize, SystemInfo, UnixMetadata, + }, + DistantApi, DistantCtx, +}; +use log::*; +use std::{ + collections::{HashMap, HashSet}, + io, + path::{Component, PathBuf}, + sync::{Arc, Weak}, +}; +use tokio::sync::{mpsc, RwLock}; +use wezterm_ssh::{FilePermissions, OpenFileType, OpenOptions, Session as WezSession, WriteMode}; + +#[derive(Default)] +pub struct ConnectionState { + /// List of process ids that will be killed when the connection terminates + processes: Arc>>, + + /// Internal reference to global process list for removals + /// NOTE: Initialized during `on_accept` of [`DistantApi`] + global_processes: Weak>>, +} + +struct Process { + stdin_tx: mpsc::Sender>, + kill_tx: 
mpsc::Sender<()>, + resize_tx: mpsc::Sender, +} + +/// Represents implementation of [`DistantApi`] for SSH +pub struct SshDistantApi { + /// Internal ssh session + session: WezSession, + + /// Global tracking of running processes by id + processes: Arc>>, +} + +impl SshDistantApi { + pub fn new(session: WezSession) -> Self { + Self { + session, + processes: Arc::new(RwLock::new(HashMap::new())), + } + } +} + +#[async_trait] +impl DistantApi for SshDistantApi { + type LocalData = ConnectionState; + + async fn on_accept(&self, local_data: &mut Self::LocalData) { + local_data.global_processes = Arc::downgrade(&self.processes); + } + + async fn read_file( + &self, + ctx: DistantCtx, + path: PathBuf, + ) -> io::Result> { + debug!( + "[Conn {}] Reading bytes from file {:?}", + ctx.connection_id, path + ); + + use smol::io::AsyncReadExt; + let mut file = self + .session + .sftp() + .open(path) + .compat() + .await + .map_err(to_other_error)?; + + let mut contents = String::new(); + file.read_to_string(&mut contents).compat().await?; + Ok(contents.into_bytes()) + } + + async fn read_file_text( + &self, + ctx: DistantCtx, + path: PathBuf, + ) -> io::Result { + debug!( + "[Conn {}] Reading text from file {:?}", + ctx.connection_id, path + ); + + use smol::io::AsyncReadExt; + let mut file = self + .session + .sftp() + .open(path) + .compat() + .await + .map_err(to_other_error)?; + + let mut contents = String::new(); + file.read_to_string(&mut contents).compat().await?; + Ok(contents) + } + + async fn write_file( + &self, + ctx: DistantCtx, + path: PathBuf, + data: Vec, + ) -> io::Result<()> { + debug!( + "[Conn {}] Writing bytes to file {:?}", + ctx.connection_id, path + ); + + use smol::io::AsyncWriteExt; + let mut file = self + .session + .sftp() + .create(path) + .compat() + .await + .map_err(to_other_error)?; + + file.write_all(data.as_ref()).compat().await?; + + Ok(()) + } + + async fn write_file_text( + &self, + ctx: DistantCtx, + path: PathBuf, + data: String, + ) -> 
io::Result<()> { + debug!( + "[Conn {}] Writing text to file {:?}", + ctx.connection_id, path + ); + + use smol::io::AsyncWriteExt; + let mut file = self + .session + .sftp() + .create(path) + .compat() + .await + .map_err(to_other_error)?; + + file.write_all(data.as_ref()).compat().await?; + + Ok(()) + } + + async fn append_file( + &self, + ctx: DistantCtx, + path: PathBuf, + data: Vec, + ) -> io::Result<()> { + debug!( + "[Conn {}] Appending bytes to file {:?}", + ctx.connection_id, path + ); + + use smol::io::AsyncWriteExt; + let mut file = self + .session + .sftp() + .open_with_mode( + path, + OpenOptions { + read: false, + write: Some(WriteMode::Append), + // Using 644 as this mirrors "ssh touch ..." + // 644: rw-r--r-- + mode: 0o644, + ty: OpenFileType::File, + }, + ) + .compat() + .await + .map_err(to_other_error)?; + + file.write_all(data.as_ref()).compat().await?; + Ok(()) + } + + async fn append_file_text( + &self, + ctx: DistantCtx, + path: PathBuf, + data: String, + ) -> io::Result<()> { + debug!( + "[Conn {}] Appending text to file {:?}", + ctx.connection_id, path + ); + + use smol::io::AsyncWriteExt; + let mut file = self + .session + .sftp() + .open_with_mode( + path, + OpenOptions { + read: false, + write: Some(WriteMode::Append), + // Using 644 as this mirrors "ssh touch ..." 
+ // 644: rw-r--r-- + mode: 0o644, + ty: OpenFileType::File, + }, + ) + .compat() + .await + .map_err(to_other_error)?; + + file.write_all(data.as_ref()).compat().await?; + Ok(()) + } + + async fn read_dir( + &self, + ctx: DistantCtx, + path: PathBuf, + depth: usize, + absolute: bool, + canonicalize: bool, + include_root: bool, + ) -> io::Result<(Vec, Vec)> { + debug!( + "[Conn {}] Reading directory {:?} {{depth: {}, absolute: {}, canonicalize: {}, include_root: {}}}", + ctx.connection_id, path, depth, absolute, canonicalize, include_root + ); + + let sftp = self.session.sftp(); + + // Canonicalize our provided path to ensure that it is exists, not a loop, and absolute + let root_path = utils::canonicalize(&sftp, path).await?; + + // Build up our entry list + let mut entries = Vec::new(); + let mut errors: Vec = Vec::new(); + + let mut to_traverse = vec![DirEntry { + path: root_path.to_path_buf(), + file_type: FileType::Dir, + depth: 0, + }]; + + while let Some(entry) = to_traverse.pop() { + let is_root = entry.depth == 0; + let next_depth = entry.depth + 1; + let ft = entry.file_type; + let path = if entry.path.is_relative() { + root_path.join(&entry.path) + } else { + entry.path.to_path_buf() + }; + + // Always include any non-root in our traverse list, but only include the + // root directory if flagged to do so + if !is_root || include_root { + entries.push(entry); + } + + let is_dir = match ft { + FileType::Dir => true, + FileType::File => false, + FileType::Symlink => match sftp.metadata(path.to_path_buf()).await { + Ok(metadata) => metadata.is_dir(), + Err(x) => { + errors.push(to_other_error(x)); + continue; + } + }, + }; + + // Determine if we continue traversing or stop + if is_dir && (depth == 0 || next_depth <= depth) { + match sftp + .read_dir(path.to_path_buf()) + .compat() + .await + .map_err(to_other_error) + { + Ok(entries) => { + for (path, metadata) in entries { + // Canonicalize the path if specified, otherwise just return + // the path as is + 
let mut path = if canonicalize { + match utils::canonicalize(&sftp, path.as_std_path()).await { + Ok(path) => path, + Err(x) => { + errors.push(to_other_error(x)); + continue; + } + } + } else { + path.into_std_path_buf() + }; + + // Strip the path of its prefix based if not flagged as absolute + if !absolute { + // NOTE: In the situation where we canonicalized the path earlier, + // there is no guarantee that our root path is still the parent of + // the symlink's destination; so, in that case we MUST just return + // the path if the strip_prefix fails + path = path + .strip_prefix(root_path.as_path()) + .map(|p| p.to_path_buf()) + .unwrap_or(path); + }; + + // If we canonicalized the path, we also want to refresh our metadata + // on windows since it doesn't reflect the real file type from read_dir + let metadata = if canonicalize { + sftp.metadata(path.to_path_buf()) + .compat() + .await + .unwrap_or(metadata) + } else { + metadata + }; + + let ft = metadata.ty; + to_traverse.push(DirEntry { + path, + file_type: if ft.is_dir() { + FileType::Dir + } else if ft.is_file() { + FileType::File + } else { + FileType::Symlink + }, + depth: next_depth, + }); + } + } + Err(x) if is_root => return Err(io::Error::new(io::ErrorKind::Other, x)), + Err(x) => errors.push(x), + } + } + } + + // Sort entries by filename + entries.sort_unstable_by_key(|e| e.path.to_path_buf()); + + Ok((entries, errors)) + } + + async fn create_dir( + &self, + ctx: DistantCtx, + path: PathBuf, + all: bool, + ) -> io::Result<()> { + debug!( + "[Conn {}] Creating directory {:?} {{all: {}}}", + ctx.connection_id, path, all + ); + + let sftp = self.session.sftp(); + + // Makes the immediate directory, failing if given a path with missing components + async fn mkdir(sftp: &wezterm_ssh::Sftp, path: PathBuf) -> io::Result<()> { + // Using 755 as this mirrors "ssh mkdir ..." 
+ // 755: rwxr-xr-x + sftp.create_dir(path, 0o755) + .compat() + .await + .map_err(to_other_error) + } + + if all { + // Keep trying to create a directory, moving up to parent each time a failure happens + let mut failed_paths = Vec::new(); + let mut cur_path = path.as_path(); + let mut first_err = None; + loop { + match mkdir(&sftp, cur_path.to_path_buf()).await { + Ok(_) => break, + Err(x) => { + failed_paths.push(cur_path); + if let Some(path) = cur_path.parent() { + cur_path = path; + + if first_err.is_none() { + first_err = Some(x); + } + } else { + return Err(io::Error::new( + io::ErrorKind::PermissionDenied, + first_err.unwrap_or(x), + )); + } + } + } + } + + // Now that we've successfully created a parent component (or the directory), proceed + // to attempt to create each failed directory + while let Some(path) = failed_paths.pop() { + mkdir(&sftp, path.to_path_buf()).await?; + } + } else { + mkdir(&sftp, path).await?; + } + + Ok(()) + } + + async fn remove( + &self, + ctx: DistantCtx, + path: PathBuf, + force: bool, + ) -> io::Result<()> { + debug!( + "[Conn {}] Removing {:?} {{force: {}}}", + ctx.connection_id, path, force + ); + + let sftp = self.session.sftp(); + + // Determine if we are dealing with a file or directory + let stat = sftp + .metadata(path.to_path_buf()) + .compat() + .await + .map_err(to_other_error)?; + + // If a file or symlink, we just unlink (easy) + if stat.is_file() || stat.is_symlink() { + sftp.remove_file(path) + .compat() + .await + .map_err(|x| io::Error::new(io::ErrorKind::PermissionDenied, x))?; + // If directory and not forcing, we just rmdir (easy) + } else if !force { + sftp.remove_dir(path) + .compat() + .await + .map_err(|x| io::Error::new(io::ErrorKind::PermissionDenied, x))?; + // Otherwise, we need to find all files and directories, keep track of their depth, and + // then attempt to remove them all + } else { + let mut entries = Vec::new(); + let mut to_traverse = vec![DirEntry { + path, + file_type: FileType::Dir, 
+ depth: 0, + }]; + + // Collect all entries within directory + while let Some(entry) = to_traverse.pop() { + if entry.file_type == FileType::Dir { + let path = entry.path.to_path_buf(); + let depth = entry.depth; + + entries.push(entry); + + for (path, stat) in sftp.read_dir(path).await.map_err(to_other_error)? { + to_traverse.push(DirEntry { + path: path.into_std_path_buf(), + file_type: if stat.is_dir() { + FileType::Dir + } else if stat.is_file() { + FileType::File + } else { + FileType::Symlink + }, + depth: depth + 1, + }); + } + } else { + entries.push(entry); + } + } + + // Sort by depth such that deepest are last as we will be popping + // off entries from end to remove first + entries.sort_unstable_by_key(|e| e.depth); + + while let Some(entry) = entries.pop() { + if entry.file_type == FileType::Dir { + sftp.remove_dir(entry.path) + .compat() + .await + .map_err(|x| io::Error::new(io::ErrorKind::PermissionDenied, x))?; + } else { + sftp.remove_file(entry.path) + .compat() + .await + .map_err(|x| io::Error::new(io::ErrorKind::PermissionDenied, x))?; + } + } + } + + Ok(()) + } + + async fn copy( + &self, + ctx: DistantCtx, + src: PathBuf, + dst: PathBuf, + ) -> io::Result<()> { + debug!( + "[Conn {}] Copying {:?} to {:?}", + ctx.connection_id, src, dst + ); + + // NOTE: SFTP does not provide a remote-to-remote copy method, so we instead execute + // a program and hope that it applies, starting with the Unix/BSD/GNU cp method + // and switch to Window's xcopy if the former fails + + // Unix cp -R + let unix_result = self + .session + .exec(&format!("cp -R {:?} {:?}", src, dst), None) + .compat() + .await; + + let failed = unix_result.is_err() || { + let exit_status = unix_result.unwrap().child.async_wait().compat().await; + exit_status.is_err() || !exit_status.unwrap().success() + }; + + // Windows xcopy /s /e + if failed { + let exit_status = self + .session + .exec(&format!("xcopy {:?} {:?} /s /e", src, dst), None) + .compat() + .await + 
.map_err(to_other_error)? + .child + .async_wait() + .compat() + .await + .map_err(to_other_error)?; + + if !exit_status.success() { + return Err(io::Error::new( + io::ErrorKind::Other, + "Unix and windows copy commands failed", + )); + } + } + + Ok(()) + } + + async fn rename( + &self, + ctx: DistantCtx, + src: PathBuf, + dst: PathBuf, + ) -> io::Result<()> { + debug!( + "[Conn {}] Renaming {:?} to {:?}", + ctx.connection_id, src, dst + ); + + self.session + .sftp() + .rename(src, dst, Default::default()) + .compat() + .await + .map_err(to_other_error)?; + + Ok(()) + } + + async fn exists(&self, ctx: DistantCtx, path: PathBuf) -> io::Result { + debug!("[Conn {}] Checking if {:?} exists", ctx.connection_id, path); + + // NOTE: SFTP does not provide a means to check if a path exists that can be performed + // separately from getting permission errors; so, we just assume any error means that the path + // does not exist + let exists = self + .session + .sftp() + .symlink_metadata(path) + .compat() + .await + .is_ok(); + Ok(exists) + } + + async fn metadata( + &self, + ctx: DistantCtx, + path: PathBuf, + canonicalize: bool, + resolve_file_type: bool, + ) -> io::Result { + debug!( + "[Conn {}] Reading metadata for {:?} {{canonicalize: {}, resolve_file_type: {}}}", + ctx.connection_id, path, canonicalize, resolve_file_type + ); + + let sftp = self.session.sftp(); + let canonicalized_path = if canonicalize { + Some(utils::canonicalize(&sftp, path.as_path()).await?) + } else { + None + }; + + let metadata = if resolve_file_type { + sftp.metadata(path).compat().await.map_err(to_other_error)? + } else { + sftp.symlink_metadata(path) + .compat() + .await + .map_err(to_other_error)? 
+ }; + + let file_type = if metadata.is_dir() { + FileType::Dir + } else if metadata.is_file() { + FileType::File + } else { + FileType::Symlink + }; + + Ok(Metadata { + canonicalized_path, + file_type, + len: metadata.size.unwrap_or(0), + // Check that owner, group, or other has write permission (if not, then readonly) + readonly: metadata + .permissions + .map(FilePermissions::is_readonly) + .unwrap_or(true), + accessed: metadata.accessed.map(u128::from), + modified: metadata.modified.map(u128::from), + created: None, + unix: metadata.permissions.as_ref().map(|p| UnixMetadata { + owner_read: p.owner_read, + owner_write: p.owner_write, + owner_exec: p.owner_exec, + group_read: p.group_read, + group_write: p.group_write, + group_exec: p.group_exec, + other_read: p.other_read, + other_write: p.other_write, + other_exec: p.other_exec, + }), + windows: None, + }) + } + + async fn proc_spawn( + &self, + ctx: DistantCtx, + cmd: String, + environment: Environment, + current_dir: Option, + persist: bool, + pty: Option, + ) -> io::Result { + debug!( + "[Conn {}] Spawning {} {{environment: {:?}, current_dir: {:?}, persist: {}, pty: {:?}}}", + ctx.connection_id, cmd, environment, current_dir, persist, pty + ); + + let global_processes = Arc::downgrade(&self.processes); + let local_processes = Arc::downgrade(&ctx.local_data.processes); + let cleanup = |id: ProcessId| async move { + if let Some(processes) = Weak::upgrade(&global_processes) { + processes.write().await.remove(&id); + } + if let Some(processes) = Weak::upgrade(&local_processes) { + processes.write().await.remove(&id); + } + }; + + let SpawnResult { + id, + stdin, + killer, + resizer, + } = match pty { + None => { + spawn_simple( + &self.session, + &cmd, + environment, + current_dir, + ctx.reply.clone_reply(), + cleanup, + ) + .await? + } + Some(size) => { + spawn_pty( + &self.session, + &cmd, + environment, + current_dir, + size, + ctx.reply.clone_reply(), + cleanup, + ) + .await? 
+ } + }; + + // If the process will be killed when the connection ends, we want to add it + // to our local data + if !persist { + ctx.local_data.processes.write().await.insert(id); + } + + self.processes.write().await.insert( + id, + Process { + stdin_tx: stdin, + kill_tx: killer, + resize_tx: resizer, + }, + ); + + debug!( + "[Conn {}] Spawned process {} successfully!", + ctx.connection_id, id + ); + Ok(id) + } + + async fn proc_kill(&self, ctx: DistantCtx, id: ProcessId) -> io::Result<()> { + debug!("[Conn {}] Killing process {}", ctx.connection_id, id); + + if let Some(process) = self.processes.read().await.get(&id) { + if process.kill_tx.send(()).await.is_ok() { + return Ok(()); + } + } + + Err(io::Error::new( + io::ErrorKind::BrokenPipe, + format!( + "[Conn {}] Unable to send kill signal to process {}", + ctx.connection_id, id + ), + )) + } + + async fn proc_stdin( + &self, + ctx: DistantCtx, + id: ProcessId, + data: Vec, + ) -> io::Result<()> { + debug!( + "[Conn {}] Sending stdin to process {}", + ctx.connection_id, id + ); + + if let Some(process) = self.processes.read().await.get(&id) { + if process.stdin_tx.send(data).await.is_ok() { + return Ok(()); + } + } + + Err(io::Error::new( + io::ErrorKind::BrokenPipe, + format!( + "[Conn {}] Unable to send stdin to process {}", + ctx.connection_id, id + ), + )) + } + + async fn proc_resize_pty( + &self, + ctx: DistantCtx, + id: ProcessId, + size: PtySize, + ) -> io::Result<()> { + debug!( + "[Conn {}] Resizing pty of process {} to {}", + ctx.connection_id, id, size + ); + + if let Some(process) = self.processes.read().await.get(&id) { + if process.resize_tx.send(size).await.is_ok() { + return Ok(()); + } + } + + Err(io::Error::new( + io::ErrorKind::BrokenPipe, + format!( + "[Conn {}] Unable to resize process {}", + ctx.connection_id, id + ), + )) + } + + async fn system_info(&self, ctx: DistantCtx) -> io::Result { + debug!("[Conn {}] Reading system information", ctx.connection_id); + + // Look up the current 
directory + let current_dir = utils::canonicalize(&self.session.sftp(), ".").await?; + + // TODO: Ideally, we would determine the family using something like the following: + // + // cmd.exe /C echo %OS% + // + // Determine OS by printing OS variable (works with Windows 2000+) + // If it matches Windows_NT, then we are on windows + // + // However, the above is not working for whatever reason (always has success == false); so, + // we're purely using a check if we have a drive letter on the canonicalized path to + // determine if on windows for now. + let is_windows = current_dir + .components() + .any(|c| matches!(c, Component::Prefix(_))); + + let family = if is_windows { "windows" } else { "unix" }.to_string(); + + Ok(SystemInfo { + family, + os: "".to_string(), + arch: "".to_string(), + current_dir, + main_separator: if is_windows { '\\' } else { '/' }, + }) + } +} diff --git a/distant-ssh2/src/handler.rs b/distant-ssh2/src/handler.rs deleted file mode 100644 index 2d7fc6c..0000000 --- a/distant-ssh2/src/handler.rs +++ /dev/null @@ -1,787 +0,0 @@ -use crate::process::{self, SpawnResult}; -use async_compat::CompatExt; -use distant_core::{ - data::{ - DirEntry, Error as DistantError, FileType, Metadata, PtySize, RunningProcess, SystemInfo, - }, - Request, RequestData, Response, ResponseData, UnixMetadata, -}; -use futures::future; -use log::*; -use std::{ - collections::HashMap, - future::Future, - io, - path::{Component, PathBuf}, - pin::Pin, - sync::Arc, -}; -use tokio::sync::{mpsc, Mutex}; -use wezterm_ssh::{FilePermissions, OpenFileType, OpenOptions, Session as WezSession, WriteMode}; - -fn to_other_error(err: E) -> io::Error -where - E: Into>, -{ - io::Error::new(io::ErrorKind::Other, err) -} - -#[derive(Default)] -pub(crate) struct State { - processes: HashMap, -} - -struct Process { - id: usize, - cmd: String, - args: Vec, - persist: bool, - stdin_tx: mpsc::Sender>, - kill_tx: mpsc::Sender<()>, - resize_tx: mpsc::Sender, -} - -type ReplyRet = Pin + Send + 
'static>>; - -type PostHook = Box>) + Send>; -struct Outgoing { - data: ResponseData, - post_hook: Option, -} - -impl Outgoing { - pub fn unsupported() -> Self { - Self::from(ResponseData::from(io::Error::new( - io::ErrorKind::Other, - "Unsupported", - ))) - } -} - -impl From for Outgoing { - fn from(data: ResponseData) -> Self { - Self { - data, - post_hook: None, - } - } -} - -/// Processes the provided request, sending replies using the given sender -pub(super) async fn process( - session: WezSession, - state: Arc>, - req: Request, - tx: mpsc::Sender, -) -> Result<(), mpsc::error::SendError> { - async fn inner( - session: WezSession, - state: Arc>, - data: RequestData, - ) -> io::Result { - match data { - RequestData::FileRead { path } => file_read(session, path).await, - RequestData::FileReadText { path } => file_read_text(session, path).await, - RequestData::FileWrite { path, data } => file_write(session, path, data).await, - RequestData::FileWriteText { path, text } => file_write(session, path, text).await, - RequestData::FileAppend { path, data } => file_append(session, path, data).await, - RequestData::FileAppendText { path, text } => file_append(session, path, text).await, - RequestData::DirRead { - path, - depth, - absolute, - canonicalize, - include_root, - } => dir_read(session, path, depth, absolute, canonicalize, include_root).await, - RequestData::DirCreate { path, all } => dir_create(session, path, all).await, - RequestData::Remove { path, force } => remove(session, path, force).await, - RequestData::Copy { src, dst } => copy(session, src, dst).await, - RequestData::Rename { src, dst } => rename(session, src, dst).await, - RequestData::Watch { .. } => Ok(Outgoing::unsupported()), - RequestData::Unwatch { .. 
} => Ok(Outgoing::unsupported()), - RequestData::Exists { path } => exists(session, path).await, - RequestData::Metadata { - path, - canonicalize, - resolve_file_type, - } => metadata(session, path, canonicalize, resolve_file_type).await, - RequestData::ProcSpawn { - cmd, - args, - persist, - pty, - } => proc_spawn(session, state, cmd, args, persist, pty).await, - RequestData::ProcResizePty { id, size } => { - proc_resize_pty(session, state, id, size).await - } - RequestData::ProcKill { id } => proc_kill(session, state, id).await, - RequestData::ProcStdin { id, data } => proc_stdin(session, state, id, data).await, - RequestData::ProcList {} => proc_list(session, state).await, - RequestData::SystemInfo {} => system_info(session).await, - } - } - - let reply = { - let origin_id = req.id; - let tenant = req.tenant.clone(); - let tx_2 = tx.clone(); - move |payload: Vec| -> ReplyRet { - let tx = tx_2.clone(); - let res = Response::new(tenant.to_string(), origin_id, payload); - Box::pin(async move { tx.send(res).await.is_ok() }) - } - }; - - // Build up a collection of tasks to run independently - let mut payload_tasks = Vec::new(); - for data in req.payload { - let state_2 = Arc::clone(&state); - let session = session.clone(); - payload_tasks.push(tokio::spawn(async move { - match inner(session, state_2, data).await { - Ok(outgoing) => outgoing, - Err(x) => Outgoing::from(ResponseData::from(x)), - } - })); - } - - // Collect the results of our tasks into the payload entries - let mut outgoing: Vec = future::join_all(payload_tasks) - .await - .into_iter() - .map(|x| match x { - Ok(outgoing) => outgoing, - Err(x) => Outgoing::from(ResponseData::from(x)), - }) - .collect(); - - let post_hooks: Vec = outgoing - .iter_mut() - .filter_map(|x| x.post_hook.take()) - .collect(); - - let payload = outgoing.into_iter().map(|x| x.data).collect(); - let res = Response::new(req.tenant, req.id, payload); - // Send out our primary response from processing the request - let result = 
tx.send(res).await; - - let (tx, mut rx) = mpsc::channel(1); - tokio::spawn(async move { - while let Some(payload) = rx.recv().await { - if !reply(payload).await { - break; - } - } - }); - - // Invoke all post hooks - for hook in post_hooks { - hook(tx.clone()); - } - - result -} - -async fn file_read(session: WezSession, path: PathBuf) -> io::Result { - use smol::io::AsyncReadExt; - let mut file = session - .sftp() - .open(path) - .compat() - .await - .map_err(to_other_error)?; - - let mut contents = String::new(); - file.read_to_string(&mut contents).compat().await?; - - Ok(Outgoing::from(ResponseData::Blob { - data: contents.into_bytes(), - })) -} - -async fn file_read_text(session: WezSession, path: PathBuf) -> io::Result { - use smol::io::AsyncReadExt; - let mut file = session - .sftp() - .open(path) - .compat() - .await - .map_err(to_other_error)?; - - let mut contents = String::new(); - file.read_to_string(&mut contents).compat().await?; - - Ok(Outgoing::from(ResponseData::Text { data: contents })) -} - -async fn file_write( - session: WezSession, - path: PathBuf, - data: impl AsRef<[u8]>, -) -> io::Result { - use smol::io::AsyncWriteExt; - let mut file = session - .sftp() - .create(path) - .compat() - .await - .map_err(to_other_error)?; - - file.write_all(data.as_ref()).compat().await?; - - Ok(Outgoing::from(ResponseData::Ok)) -} - -async fn file_append( - session: WezSession, - path: PathBuf, - data: impl AsRef<[u8]>, -) -> io::Result { - use smol::io::AsyncWriteExt; - let mut file = session - .sftp() - .open_with_mode( - path, - OpenOptions { - read: false, - write: Some(WriteMode::Append), - // Using 644 as this mirrors "ssh touch ..." 
- // 644: rw-r--r-- - mode: 0o644, - ty: OpenFileType::File, - }, - ) - .compat() - .await - .map_err(to_other_error)?; - - file.write_all(data.as_ref()).compat().await?; - - Ok(Outgoing::from(ResponseData::Ok)) -} - -async fn dir_read( - session: WezSession, - path: PathBuf, - depth: usize, - absolute: bool, - canonicalize: bool, - include_root: bool, -) -> io::Result { - let sftp = session.sftp(); - - // Canonicalize our provided path to ensure that it is exists, not a loop, and absolute - let root_path = sftp - .canonicalize(path) - .compat() - .await - .map_err(to_other_error)? - .into_std_path_buf(); - - // Build up our entry list - let mut entries = Vec::new(); - let mut errors = Vec::new(); - - let mut to_traverse = vec![DirEntry { - path: root_path.to_path_buf(), - file_type: FileType::Dir, - depth: 0, - }]; - - while let Some(entry) = to_traverse.pop() { - let is_root = entry.depth == 0; - let next_depth = entry.depth + 1; - let ft = entry.file_type; - let path = if entry.path.is_relative() { - root_path.join(&entry.path) - } else { - entry.path.to_path_buf() - }; - - // Always include any non-root in our traverse list, but only include the - // root directory if flagged to do so - if !is_root || include_root { - entries.push(entry); - } - - let is_dir = match ft { - FileType::Dir => true, - FileType::File => false, - FileType::Symlink => match sftp.metadata(path.to_path_buf()).await { - Ok(metadata) => metadata.is_dir(), - Err(x) => { - errors.push(DistantError::from(to_other_error(x))); - continue; - } - }, - }; - - // Determine if we continue traversing or stop - if is_dir && (depth == 0 || next_depth <= depth) { - match sftp - .read_dir(path.to_path_buf()) - .compat() - .await - .map_err(to_other_error) - { - Ok(entries) => { - for (mut path, metadata) in entries { - // Canonicalize the path if specified, otherwise just return - // the path as is - path = if canonicalize { - match sftp.canonicalize(path).compat().await { - Ok(path) => path, - Err(x) => 
{ - errors.push(DistantError::from(to_other_error(x))); - continue; - } - } - } else { - path - }; - - // Strip the path of its prefix based if not flagged as absolute - if !absolute { - // NOTE: In the situation where we canonicalized the path earlier, - // there is no guarantee that our root path is still the parent of - // the symlink's destination; so, in that case we MUST just return - // the path if the strip_prefix fails - path = path - .strip_prefix(root_path.as_path()) - .map(|p| p.to_path_buf()) - .unwrap_or(path); - }; - - let ft = metadata.ty; - to_traverse.push(DirEntry { - path: path.into_std_path_buf(), - file_type: if ft.is_dir() { - FileType::Dir - } else if ft.is_file() { - FileType::File - } else { - FileType::Symlink - }, - depth: next_depth, - }); - } - } - Err(x) if is_root => return Err(io::Error::new(io::ErrorKind::Other, x)), - Err(x) => errors.push(DistantError::from(x)), - } - } - } - - // Sort entries by filename - entries.sort_unstable_by_key(|e| e.path.to_path_buf()); - - Ok(Outgoing::from(ResponseData::DirEntries { entries, errors })) -} - -async fn dir_create(session: WezSession, path: PathBuf, all: bool) -> io::Result { - let sftp = session.sftp(); - - // Makes the immediate directory, failing if given a path with missing components - async fn mkdir(sftp: &wezterm_ssh::Sftp, path: PathBuf) -> io::Result<()> { - // Using 755 as this mirrors "ssh mkdir ..." 
- // 755: rwxr-xr-x - sftp.create_dir(path, 0o755) - .compat() - .await - .map_err(to_other_error) - } - - if all { - // Keep trying to create a directory, moving up to parent each time a failure happens - let mut failed_paths = Vec::new(); - let mut cur_path = path.as_path(); - let mut first_err = None; - loop { - match mkdir(&sftp, cur_path.to_path_buf()).await { - Ok(_) => break, - Err(x) => { - failed_paths.push(cur_path); - if let Some(path) = cur_path.parent() { - cur_path = path; - - if first_err.is_none() { - first_err = Some(x); - } - } else { - return Err(io::Error::new( - io::ErrorKind::PermissionDenied, - first_err.unwrap_or(x), - )); - } - } - } - } - - // Now that we've successfully created a parent component (or the directory), proceed - // to attempt to create each failed directory - while let Some(path) = failed_paths.pop() { - mkdir(&sftp, path.to_path_buf()).await?; - } - } else { - mkdir(&sftp, path).await?; - } - - Ok(Outgoing::from(ResponseData::Ok)) -} - -async fn remove(session: WezSession, path: PathBuf, force: bool) -> io::Result { - let sftp = session.sftp(); - - // Determine if we are dealing with a file or directory - let stat = sftp - .metadata(path.to_path_buf()) - .compat() - .await - .map_err(to_other_error)?; - - // If a file or symlink, we just unlink (easy) - if stat.is_file() || stat.is_symlink() { - sftp.remove_file(path) - .compat() - .await - .map_err(|x| io::Error::new(io::ErrorKind::PermissionDenied, x))?; - // If directory and not forcing, we just rmdir (easy) - } else if !force { - sftp.remove_dir(path) - .compat() - .await - .map_err(|x| io::Error::new(io::ErrorKind::PermissionDenied, x))?; - // Otherwise, we need to find all files and directories, keep track of their depth, and - // then attempt to remove them all - } else { - let mut entries = Vec::new(); - let mut to_traverse = vec![DirEntry { - path, - file_type: FileType::Dir, - depth: 0, - }]; - - // Collect all entries within directory - while let Some(entry) = 
to_traverse.pop() { - if entry.file_type == FileType::Dir { - let path = entry.path.to_path_buf(); - let depth = entry.depth; - - entries.push(entry); - - for (path, stat) in sftp.read_dir(path).await.map_err(to_other_error)? { - to_traverse.push(DirEntry { - path: path.into_std_path_buf(), - file_type: if stat.is_dir() { - FileType::Dir - } else if stat.is_file() { - FileType::File - } else { - FileType::Symlink - }, - depth: depth + 1, - }); - } - } else { - entries.push(entry); - } - } - - // Sort by depth such that deepest are last as we will be popping - // off entries from end to remove first - entries.sort_unstable_by_key(|e| e.depth); - - while let Some(entry) = entries.pop() { - if entry.file_type == FileType::Dir { - sftp.remove_dir(entry.path) - .compat() - .await - .map_err(|x| io::Error::new(io::ErrorKind::PermissionDenied, x))?; - } else { - sftp.remove_file(entry.path) - .compat() - .await - .map_err(|x| io::Error::new(io::ErrorKind::PermissionDenied, x))?; - } - } - } - - Ok(Outgoing::from(ResponseData::Ok)) -} - -async fn copy(session: WezSession, src: PathBuf, dst: PathBuf) -> io::Result { - // NOTE: SFTP does not provide a remote-to-remote copy method, so we instead execute - // a program and hope that it applies, starting with the Unix/BSD/GNU cp method - // and switch to Window's xcopy if the former fails - - // Unix cp -R - let unix_result = session - .exec(&format!("cp -R {:?} {:?}", src, dst), None) - .compat() - .await; - - let failed = unix_result.is_err() || { - let exit_status = unix_result.unwrap().child.async_wait().compat().await; - exit_status.is_err() || !exit_status.unwrap().success() - }; - - // Windows xcopy /s /e - if failed { - let exit_status = session - .exec(&format!("xcopy {:?} {:?} /s /e", src, dst), None) - .compat() - .await - .map_err(to_other_error)? 
- .child - .async_wait() - .compat() - .await - .map_err(to_other_error)?; - - if !exit_status.success() { - return Err(io::Error::new( - io::ErrorKind::Other, - "Unix and windows copy commands failed", - )); - } - } - - Ok(Outgoing::from(ResponseData::Ok)) -} - -async fn rename(session: WezSession, src: PathBuf, dst: PathBuf) -> io::Result { - session - .sftp() - .rename(src, dst, Default::default()) - .compat() - .await - .map_err(to_other_error)?; - - Ok(Outgoing::from(ResponseData::Ok)) -} - -async fn exists(session: WezSession, path: PathBuf) -> io::Result { - // NOTE: SFTP does not provide a means to check if a path exists that can be performed - // separately from getting permission errors; so, we just assume any error means that the path - // does not exist - let exists = session.sftp().symlink_metadata(path).compat().await.is_ok(); - - Ok(Outgoing::from(ResponseData::Exists { value: exists })) -} - -async fn metadata( - session: WezSession, - path: PathBuf, - canonicalize: bool, - resolve_file_type: bool, -) -> io::Result { - let sftp = session.sftp(); - let canonicalized_path = if canonicalize { - Some( - sftp.canonicalize(path.to_path_buf()) - .compat() - .await - .map_err(to_other_error)? - .into_std_path_buf(), - ) - } else { - None - }; - - let metadata = if resolve_file_type { - sftp.metadata(path).compat().await.map_err(to_other_error)? - } else { - sftp.symlink_metadata(path) - .compat() - .await - .map_err(to_other_error)? 
- }; - - let file_type = if metadata.is_dir() { - FileType::Dir - } else if metadata.is_file() { - FileType::File - } else { - FileType::Symlink - }; - - Ok(Outgoing::from(ResponseData::Metadata(Metadata { - canonicalized_path, - file_type, - len: metadata.size.unwrap_or(0), - // Check that owner, group, or other has write permission (if not, then readonly) - readonly: metadata - .permissions - .map(FilePermissions::is_readonly) - .unwrap_or(true), - accessed: metadata.accessed.map(u128::from), - modified: metadata.modified.map(u128::from), - created: None, - unix: metadata.permissions.as_ref().map(|p| UnixMetadata { - owner_read: p.owner_read, - owner_write: p.owner_write, - owner_exec: p.owner_exec, - group_read: p.group_read, - group_write: p.group_write, - group_exec: p.group_exec, - other_read: p.other_read, - other_write: p.other_write, - other_exec: p.other_exec, - }), - windows: None, - }))) -} - -async fn proc_spawn( - session: WezSession, - state: Arc>, - cmd: String, - args: Vec, - persist: bool, - pty: Option, -) -> io::Result { - let cmd_string = format!("{} {}", cmd, args.join(" ")); - debug!(" Spawning {} (pty: {:?})", cmd_string, pty); - - let state_2 = Arc::clone(&state); - let cleanup = |id: usize| async move { - state_2.lock().await.processes.remove(&id); - }; - - let SpawnResult { - id, - stdin, - killer, - resizer, - initialize, - } = match pty { - None => process::spawn_simple(&session, &cmd_string, cleanup).await?, - Some(size) => process::spawn_pty(&session, &cmd_string, size, cleanup).await?, - }; - - state.lock().await.processes.insert( - id, - Process { - id, - cmd, - args, - persist, - stdin_tx: stdin, - kill_tx: killer, - resize_tx: resizer, - }, - ); - - debug!( - " Spawned successfully! 
Will enter post hook later", - id - ); - Ok(Outgoing { - data: ResponseData::ProcSpawned { id }, - post_hook: Some(initialize), - }) -} - -async fn proc_resize_pty( - _session: WezSession, - state: Arc>, - id: usize, - size: PtySize, -) -> io::Result { - if let Some(process) = state.lock().await.processes.get(&id) { - if process.resize_tx.send(size).await.is_ok() { - return Ok(Outgoing::from(ResponseData::Ok)); - } - } - - Err(io::Error::new( - io::ErrorKind::BrokenPipe, - format!(" Unable to resize process", id), - )) -} - -async fn proc_kill( - _session: WezSession, - state: Arc>, - id: usize, -) -> io::Result { - if let Some(process) = state.lock().await.processes.remove(&id) { - if process.kill_tx.send(()).await.is_ok() { - return Ok(Outgoing::from(ResponseData::Ok)); - } - } - - Err(io::Error::new( - io::ErrorKind::BrokenPipe, - format!(" Unable to send kill signal to process", id), - )) -} - -async fn proc_stdin( - _session: WezSession, - state: Arc>, - id: usize, - data: Vec, -) -> io::Result { - if let Some(process) = state.lock().await.processes.get_mut(&id) { - if process.stdin_tx.send(data).await.is_ok() { - return Ok(Outgoing::from(ResponseData::Ok)); - } - } - - Err(io::Error::new( - io::ErrorKind::BrokenPipe, - format!(" Unable to send stdin to process", id), - )) -} - -async fn proc_list(_session: WezSession, state: Arc>) -> io::Result { - Ok(Outgoing::from(ResponseData::ProcEntries { - entries: state - .lock() - .await - .processes - .values() - .map(|p| RunningProcess { - cmd: p.cmd.to_string(), - args: p.args.clone(), - persist: p.persist, - // TODO: Support pty size from ssh - pty: None, - id: p.id, - }) - .collect(), - })) -} - -async fn system_info(session: WezSession) -> io::Result { - let current_dir = session - .sftp() - .canonicalize(".") - .compat() - .await - .map_err(to_other_error)? 
- .into_std_path_buf(); - - let first_component = current_dir.components().next(); - let is_windows = - first_component.is_some() && matches!(first_component.unwrap(), Component::Prefix(_)); - let is_unix = current_dir.as_os_str().to_string_lossy().starts_with('/'); - - let family = if is_windows { - "windows" - } else if is_unix { - "unix" - } else { - "" - } - .to_string(); - - Ok(Outgoing::from(ResponseData::SystemInfo(SystemInfo { - family, - os: "".to_string(), - arch: "".to_string(), - current_dir, - main_separator: if is_windows { '\\' } else { '/' }, - }))) -} diff --git a/distant-ssh2/src/lib.rs b/distant-ssh2/src/lib.rs index 629a386..24b1dab 100644 --- a/distant-ssh2/src/lib.rs +++ b/distant-ssh2/src/lib.rs @@ -2,9 +2,16 @@ compile_error!("Either feature \"libssh\" or \"ssh2\" must be enabled for this crate."); use async_compat::CompatExt; +use async_once_cell::OnceCell; +use async_trait::async_trait; use distant_core::{ - Request, Session, SessionChannelExt, SessionDetails, SessionInfo, Transport, - XChaCha20Poly1305Codec, + data::Environment, + net::{ + FramedTransport, IntoSplit, OneshotListener, ServerExt, ServerRef, TcpClientExt, + XChaCha20Poly1305Codec, + }, + BoxedDistantReader, BoxedDistantWriter, BoxedDistantWriterReader, DistantApiServer, + DistantChannelExt, DistantClient, DistantSingleKeyCredentials, }; use log::*; use smol::channel::Receiver as SmolReceiver; @@ -15,17 +22,30 @@ use std::{ net::{IpAddr, SocketAddr}, path::PathBuf, str::FromStr, - sync::Arc, time::Duration, }; -use tokio::sync::{mpsc, Mutex}; use wezterm_ssh::{Config as WezConfig, Session as WezSession, SessionEvent as WezSessionEvent}; -mod handler; +mod api; mod process; +mod utils; + +use api::SshDistantApi; + +/// Represents the family of the remote machine connected over SSH +#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)] +#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] +#[cfg_attr(feature = "serde", serde(rename_all = "lowercase"))] +pub 
enum SshFamily { + /// Operating system belongs to unix family + Unix, + + /// Operating system belongs to windows family + Windows, +} /// Represents the backend to use for ssh operations -#[derive(Copy, Clone, Debug, PartialEq, Eq)] +#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)] #[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] #[cfg_attr(feature = "serde", serde(rename_all = "lowercase"))] pub enum SshBackend { @@ -96,10 +116,10 @@ impl fmt::Display for SshBackend { } } -/// Represents a singular authentication prompt for a new ssh session +/// Represents a singular authentication prompt for a new ssh client #[derive(Debug)] #[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] -pub struct Ssh2AuthPrompt { +pub struct SshAuthPrompt { /// The label to show when prompting the user pub prompt: String, @@ -108,11 +128,11 @@ pub struct Ssh2AuthPrompt { pub echo: bool, } -/// Represents an authentication request that needs to be handled before an ssh session can be +/// Represents an authentication request that needs to be handled before an ssh client can be /// established #[derive(Debug)] #[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] -pub struct Ssh2AuthEvent { +pub struct SshAuthEvent { /// Represents the name of the user to be authenticated. This may be empty! 
pub username: String, @@ -120,14 +140,14 @@ pub struct Ssh2AuthEvent { pub instructions: String, /// Prompts to be conveyed to the user, each representing a single answer needed - pub prompts: Vec, + pub prompts: Vec, } -/// Represents options to be provided when establishing an ssh session +/// Represents options to be provided when establishing an ssh client #[derive(Clone, Debug, Default)] #[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] #[cfg_attr(feature = "serde", serde(default))] -pub struct Ssh2SessionOpts { +pub struct SshOpts { /// Represents the backend to use for ssh operations pub backend: SshBackend, @@ -168,9 +188,9 @@ pub struct Ssh2SessionOpts { pub other: BTreeMap, } -/// Represents options to be provided when converting an ssh session into a distant session +/// Represents options to be provided when converting an ssh client into a distant client #[derive(Clone, Debug)] -pub struct IntoDistantSessionOpts { +pub struct DistantLaunchOpts { /// Binary to use for distant server pub binary: String, @@ -185,7 +205,7 @@ pub struct IntoDistantSessionOpts { pub timeout: Duration, } -impl Default for IntoDistantSessionOpts { +impl Default for DistantLaunchOpts { fn default() -> Self { Self { binary: String::from("distant"), @@ -196,81 +216,109 @@ impl Default for IntoDistantSessionOpts { } } -/// Represents callback functions to be invoked during authentication of an ssh session -pub struct Ssh2AuthHandler<'a> { +/// Interface to handle various events during ssh authentication +#[async_trait] +pub trait SshAuthHandler { /// Invoked whenever a series of authentication prompts need to be displayed and responded to, /// receiving one event at a time and returning a collection of answers matching the total /// prompts provided in the event - pub on_authenticate: Box io::Result> + 'a>, - - /// Invoked when receiving a banner from the ssh server, receiving the banner as a str, useful - /// to display to the user - pub on_banner: Box, 
+ async fn on_authenticate(&self, event: SshAuthEvent) -> io::Result>; /// Invoked when the host is unknown for a new ssh connection, receiving the host as a str and - /// returning true if the host is acceptable or false if the host (and thereby ssh session) + /// returning true if the host is acceptable or false if the host (and thereby ssh client) /// should be declined - pub on_host_verify: Box io::Result + 'a>, + async fn on_verify_host(&self, host: &str) -> io::Result; + + /// Invoked when receiving a banner from the ssh server, receiving the banner as a str, useful + /// to display to the user + async fn on_banner(&self, text: &str); /// Invoked when an error is encountered, receiving the error as a str - pub on_error: Box, + async fn on_error(&self, text: &str); } -impl Default for Ssh2AuthHandler<'static> { - fn default() -> Self { - Self { - on_authenticate: Box::new(|ev| { - if !ev.username.is_empty() { - eprintln!("Authentication for {}", ev.username); - } +/// Implementation of [`SshAuthHandler`] that prompts locally for authentication and verification +/// events +pub struct LocalSshAuthHandler; + +#[async_trait] +impl SshAuthHandler for LocalSshAuthHandler { + async fn on_authenticate(&self, event: SshAuthEvent) -> io::Result> { + trace!("[local] on_authenticate({event:?})"); + let task = tokio::task::spawn_blocking(move || { + if !event.username.is_empty() { + eprintln!("Authentication for {}", event.username); + } + + if !event.instructions.is_empty() { + eprintln!("{}", event.instructions); + } + + let mut answers = Vec::new(); + for prompt in &event.prompts { + // Contains all prompt lines including same line + let mut prompt_lines = prompt.prompt.split('\n').collect::>(); + + // Line that is prompt on same line as answer + let prompt_line = prompt_lines.pop().unwrap(); - if !ev.instructions.is_empty() { - eprintln!("{}", ev.instructions); + // Go ahead and display all other lines + for line in prompt_lines.into_iter() { + eprintln!("{}", line); 
} - let mut answers = Vec::new(); - for prompt in &ev.prompts { - // Contains all prompt lines including same line - let mut prompt_lines = prompt.prompt.split('\n').collect::>(); + let answer = if prompt.echo { + eprint!("{}", prompt_line); + std::io::stderr().lock().flush()?; - // Line that is prompt on same line as answer - let prompt_line = prompt_lines.pop().unwrap(); + let mut answer = String::new(); + std::io::stdin().read_line(&mut answer)?; + answer + } else { + rpassword::prompt_password(prompt_line)? + }; - // Go ahead and display all other lines - for line in prompt_lines.into_iter() { - eprintln!("{}", line); - } + answers.push(answer); + } + Ok(answers) + }); - let answer = if prompt.echo { - eprint!("{}", prompt_line); - std::io::stderr().lock().flush()?; + task.await + .map_err(|x| io::Error::new(io::ErrorKind::Other, x))? + } - let mut answer = String::new(); - std::io::stdin().read_line(&mut answer)?; - answer - } else { - rpassword::prompt_password_stderr(prompt_line)? - }; + async fn on_verify_host(&self, host: &str) -> io::Result { + trace!("[local] on_verify_host({host})"); + eprintln!("{}", host); + let task = tokio::task::spawn_blocking(|| { + eprint!("Enter [y/N]> "); + std::io::stderr().lock().flush()?; - answers.push(answer); - } - Ok(answers) - }), - on_banner: Box::new(|_| {}), - on_host_verify: Box::new(|message| { - eprintln!("{}", message); - match rpassword::prompt_password_stderr("Enter [y/N]> ")?.as_str() { - "y" | "Y" | "yes" | "YES" => Ok(true), - _ => Ok(false), - } - }), - on_error: Box::new(|_| {}), - } + let mut answer = String::new(); + std::io::stdin().read_line(&mut answer)?; + + trace!("Verify? Answer = '{answer}'"); + match answer.as_str().trim() { + "y" | "Y" | "yes" | "YES" => Ok(true), + _ => Ok(false), + } + }); + + task.await + .map_err(|x| io::Error::new(io::ErrorKind::Other, x))? 
+ } + + async fn on_banner(&self, _text: &str) { + trace!("[local] on_banner({_text})"); + } + + async fn on_error(&self, _text: &str) { + trace!("[local] on_error({_text})"); } } -/// Represents an ssh2 session -pub struct Ssh2Session { +/// Represents an ssh2 client +pub struct Ssh { session: WezSession, events: SmolReceiver, host: String, @@ -278,9 +326,9 @@ pub struct Ssh2Session { authenticated: bool, } -impl Ssh2Session { +impl Ssh { /// Connect to a remote TCP server using SSH - pub fn connect(host: impl AsRef, opts: Ssh2SessionOpts) -> io::Result { + pub fn connect(host: impl AsRef, opts: SshOpts) -> io::Result { debug!( "Establishing ssh connection to {} using {:?}", host.as_ref(), @@ -292,7 +340,7 @@ impl Ssh2Session { // Grab the config for the specific host let mut config = config.for_host(host.as_ref()); - // Override config with any settings provided by session opts + // Override config with any settings provided by client opts if let Some(port) = opts.port.as_ref() { config.insert("port".to_string(), port.to_string()); } @@ -363,12 +411,12 @@ impl Ssh2Session { }) } - /// Host this session is connected to + /// Host this client is connected to pub fn host(&self) -> &str { &self.host } - /// Port this session is connected to on remote host + /// Port this client is connected to on remote host pub fn port(&self) -> u16 { self.port } @@ -378,8 +426,8 @@ impl Ssh2Session { self.authenticated } - /// Authenticates the [`Ssh2Session`] if not already authenticated - pub async fn authenticate(&mut self, mut handler: Ssh2AuthHandler<'_>) -> io::Result<()> { + /// Authenticates the [`Ssh`] if not already authenticated + pub async fn authenticate(&mut self, handler: impl SshAuthHandler) -> io::Result<()> { // If already authenticated, exit if self.authenticated { return Ok(()); @@ -391,11 +439,11 @@ impl Ssh2Session { match event { WezSessionEvent::Banner(banner) => { if let Some(banner) = banner { - (handler.on_banner)(banner.as_ref()); + 
handler.on_banner(banner.as_ref()).await; } } WezSessionEvent::HostVerify(verify) => { - let verified = (handler.on_host_verify)(verify.message.as_str())?; + let verified = handler.on_verify_host(verify.message.as_str()).await?; verify .answer(verified) .compat() @@ -403,27 +451,27 @@ impl Ssh2Session { .map_err(|x| io::Error::new(io::ErrorKind::Other, x))?; } WezSessionEvent::Authenticate(mut auth) => { - let ev = Ssh2AuthEvent { + let ev = SshAuthEvent { username: auth.username.clone(), instructions: auth.instructions.clone(), prompts: auth .prompts .drain(..) - .map(|p| Ssh2AuthPrompt { + .map(|p| SshAuthPrompt { prompt: p.prompt, echo: p.echo, }) .collect(), }; - let answers = (handler.on_authenticate)(ev)?; + let answers = handler.on_authenticate(ev).await?; auth.answer(answers) .compat() .await .map_err(|x| io::Error::new(io::ErrorKind::Other, x))?; } WezSessionEvent::Error(err) => { - (handler.on_error)(&err); + handler.on_error(&err).await; return Err(io::Error::new(io::ErrorKind::PermissionDenied, err)); } WezSessionEvent::Authenticated => break, @@ -436,9 +484,37 @@ impl Ssh2Session { Ok(()) } - /// Consume [`Ssh2Session`] and produce a distant [`Session`] that is connected to a remote - /// distant server that is spawned using the ssh session - pub async fn into_distant_session(self, opts: IntoDistantSessionOpts) -> io::Result { + /// Detects the family of operating system on the remote machine + pub async fn detect_family(&self) -> io::Result { + static INSTANCE: OnceCell = OnceCell::new(); + + // Exit early if not authenticated as this is a requirement + if !self.authenticated { + return Err(io::Error::new( + io::ErrorKind::PermissionDenied, + "Not authenticated", + )); + } + + INSTANCE + .get_or_try_init(async move { + let is_windows = utils::is_windows(&self.session.sftp()).await?; + + Ok(if is_windows { + SshFamily::Windows + } else { + SshFamily::Unix + }) + }) + .await + .copied() + } + + /// Consume [`Ssh`] and produce a [`DistantClient`] that is 
connected to a remote + /// distant server that is spawned using the ssh client + pub async fn launch_and_connect(self, opts: DistantLaunchOpts) -> io::Result { + trace!("ssh::launch_and_connect({:?})", opts); + // Exit early if not authenticated as this is a requirement if !self.authenticated { return Err(io::Error::new( @@ -477,17 +553,17 @@ impl Ssh2Session { )); } - let info = self.into_distant_session_info(opts).await?; - let key = info.key; + let credentials = self.launch(opts).await?; + let key = credentials.key; let codec = XChaCha20Poly1305Codec::from(key); // Try each IP address with the same port to see if one works let mut err = None; for ip in candidate_ips { - let addr = SocketAddr::new(ip, info.port); + let addr = SocketAddr::new(ip, credentials.port); debug!("Attempting to connect to distant server @ {}", addr); - match Session::tcp_connect_timeout(addr, codec.clone(), timeout).await { - Ok(session) => return Ok(session), + match DistantClient::connect_timeout(addr, codec.clone(), timeout).await { + Ok(client) => return Ok(client), Err(x) => err = Some(x), } } @@ -496,12 +572,11 @@ impl Ssh2Session { Err(err.expect("Err set above")) } - /// Consume [`Ssh2Session`] and produce a distant [`SessionInfo`] representing a remote - /// distant server that is spawned using the ssh session - pub async fn into_distant_session_info( - self, - opts: IntoDistantSessionOpts, - ) -> io::Result { + /// Consume [`Ssh`] and launch a distant server, returning a [`DistantSingleKeyCredentials`] + /// tied to the launched server that includes credentials + pub async fn launch(self, opts: DistantLaunchOpts) -> io::Result { + trace!("ssh::launch({:?})", opts); + // Exit early if not authenticated as this is a requirement if !self.authenticated { return Err(io::Error::new( @@ -510,73 +585,66 @@ impl Ssh2Session { )); } + let family = self.detect_family().await?; + let host = self.host().to_string(); - // Turn our ssh connection into a client session so we can use it to 
spawn our server - let (mut session, cleanup_session) = self.into_ssh_client_session_impl().await?; + // Turn our ssh connection into a client/server pair so we can use it to spawn our server + let (mut client, server) = self.into_distant_pair().await?; // Build arguments for distant to execute listen subcommand let mut args = vec![ + String::from("server"), String::from("listen"), + String::from("--daemon"), String::from("--host"), String::from("ssh"), ]; - args.extend( - shell_words::split(&opts.args) + args.extend(match family { + SshFamily::Windows => winsplit::split(&opts.args), + SshFamily::Unix => shell_words::split(&opts.args) .map_err(|x| io::Error::new(io::ErrorKind::InvalidInput, x))?, - ); + }); // If we are using a login shell, we need to make the binary be sh // so we can appropriately pipe into the login shell - let (bin, args) = if opts.use_login_shell { - ( - String::from("sh"), - vec![ - String::from("-c"), - shell_words::quote(&format!( - "echo {} {} | $SHELL -l", - opts.binary, - args.join(" ") - )) - .to_string(), - ], + let cmd = if opts.use_login_shell { + format!( + "sh -c {}", + shell_words::quote(&format!( + "echo {} {} | $SHELL -l", + opts.binary, + args.join(" ") + )) ) } else { - (opts.binary, args) + format!("{} {}", opts.binary, args.join(" ")) }; // Spawn distant server and detach it so that we don't kill it when the - // ssh session is closed - debug!("Executing {} {}", bin, args.join(" ")); - let mut proc = session - .spawn("", bin, args, true, None) - .await - .map_err(|x| io::Error::new(io::ErrorKind::Other, x))?; - let mut stdout = proc.stdout.take().unwrap(); - let mut stderr = proc.stderr.take().unwrap(); - let (success, code) = proc - .wait() - .await - .map_err(|x| io::Error::new(io::ErrorKind::BrokenPipe, x))?; + // ssh client is closed + debug!("Executing {}", cmd); + let output = client.output(cmd, Environment::new(), None, None).await?; + debug!( + "Completed with success = {}, code = {:?}", + output.success, output.code 
+ ); - // Close out ssh session - cleanup_session(); - session.abort(); - let _ = session.wait().await; - let mut output = Vec::new(); + // Close out ssh client by killing the internal server and client + server.abort(); + client.abort(); + let _ = client.wait().await; - // If successful, grab the session information and establish a connection + // If successful, grab the client information and establish a connection // with the distant server - if success { - while let Ok(data) = stdout.read().await { - output.extend(&data); - } - - // Iterate over output as individual lines, looking for session info + if output.success { + // Iterate over output as individual lines, looking for client info + trace!("Searching for credentials"); let maybe_info = output + .stdout .split(|&b| b == b'\n') .map(String::from_utf8_lossy) - .find_map(|line| line.parse::().ok()); + .find_map(|line| line.parse::().ok()); match maybe_info { Some(mut info) => { info.host = host; @@ -584,21 +652,19 @@ impl Ssh2Session { } None => Err(io::Error::new( io::ErrorKind::InvalidData, - "Missing session data", + "Missing launch information", )), } } else { - while let Ok(data) = stderr.read().await { - output.extend(&data); - } - Err(io::Error::new( io::ErrorKind::Other, format!( "Spawning distant failed [{}]: {}", - code.map(|x| x.to_string()) + output + .code + .map(|x| x.to_string()) .unwrap_or_else(|| String::from("???")), - match String::from_utf8(output) { + match String::from_utf8(output.stderr) { Ok(output) => output, Err(x) => x.to_string(), } @@ -607,13 +673,29 @@ impl Ssh2Session { } } - /// Consume [`Ssh2Session`] and produce a distant [`Session`] that is powered by an ssh client + /// Consume [`Ssh`] and produce a [`DistantClient`] that is powered by an ssh client + /// underneath + pub async fn into_distant_client(self) -> io::Result { + Ok(self.into_distant_pair().await?.0) + } + + /// Consume [`Ssh`] and produce a [`BoxedDistantWriterReader`] that is powered by an ssh client /// 
underneath - pub async fn into_ssh_client_session(self) -> io::Result { - self.into_ssh_client_session_impl().await.map(|x| x.0) + pub async fn into_distant_writer_reader(self) -> io::Result { + Ok(self.into_writer_reader_and_server().await?.0) } - async fn into_ssh_client_session_impl(self) -> io::Result<(Session, Box)> { + /// Consumes [`Ssh`] and produces a [`DistantClient`] and [`DistantApiServer`] pair + pub async fn into_distant_pair(self) -> io::Result<(DistantClient, Box)> { + let ((writer, reader), server) = self.into_writer_reader_and_server().await?; + let client = DistantClient::new(writer, reader)?; + Ok((client, server)) + } + + /// Consumes [`Ssh`] and produces a [`DistantClient`] and [`DistantApiServer`] pair + async fn into_writer_reader_and_server( + self, + ) -> io::Result<(BoxedDistantWriterReader, Box)> { // Exit early if not authenticated as this is a requirement if !self.authenticated { return Err(io::Error::new( @@ -622,47 +704,24 @@ impl Ssh2Session { )); } - let (t1, t2) = Transport::pair(1); - let tag = format!("ssh {}:{}", self.host, self.port); - let session = Session::initialize_with_details(t1, Some(SessionDetails::Custom { tag }))?; - - // Spawn tasks that forward requests to the ssh session - // and send back responses from the ssh session - let (mut t_read, mut t_write) = t2.into_split(); - let Self { - session: wez_session, - .. 
- } = self; - - let (tx, mut rx) = mpsc::channel(1); - let request_task = tokio::spawn(async move { - let state = Arc::new(Mutex::new(handler::State::default())); - while let Ok(Some(req)) = t_read.receive::().await { - if let Err(x) = - handler::process(wez_session.clone(), Arc::clone(&state), req, tx.clone()).await - { - error!("Ssh session receiver handler failed: {}", x); - } - } - debug!("Ssh receiver task is now closed"); - }); - - let send_task = tokio::spawn(async move { - while let Some(res) = rx.recv().await { - if let Err(x) = t_write.send(res).await { - error!("Ssh session sender failed: {}", x); - break; - } - } - debug!("Ssh sender task is now closed"); - }); + let (t1, t2) = FramedTransport::pair(1); + + // Spawn a bridge client that is directly connected to our server + let (writer, reader) = t1.into_split(); + let writer: BoxedDistantWriter = Box::new(writer); + let reader: BoxedDistantReader = Box::new(reader); + + // Spawn a bridge server that is directly connected to our client + let server = { + let Self { + session: wez_session, + .. + } = self; + let (writer, reader) = t2.into_split(); + DistantApiServer::new(SshDistantApi::new(wez_session)) + .start(OneshotListener::from_value((writer, reader)))? 
+ }; - Ok(( - session, - Box::new(move || { - send_task.abort(); - request_task.abort(); - }), - )) + Ok(((writer, reader), server)) } } diff --git a/distant-ssh2/src/process.rs b/distant-ssh2/src/process.rs index 05ccd49..29882f1 100644 --- a/distant-ssh2/src/process.rs +++ b/distant-ssh2/src/process.rs @@ -1,9 +1,13 @@ use async_compat::CompatExt; -use distant_core::{PtySize, ResponseData}; +use distant_core::{ + data::{DistantResponseData, Environment, ProcessId, PtySize}, + net::Reply, +}; use log::*; use std::{ future::Future, io::{self, Read, Write}, + path::PathBuf, time::Duration, }; use tokio::{sync::mpsc, task::JoinHandle}; @@ -17,27 +21,47 @@ const THREAD_PAUSE_MILLIS: u64 = 50; /// Result of spawning a process, containing means to send stdin, means to kill the process, /// and the initialization function to use to start processing stdin, stdout, and stderr pub struct SpawnResult { - pub id: usize, + pub id: ProcessId, pub stdin: mpsc::Sender>, pub killer: mpsc::Sender<()>, pub resizer: mpsc::Sender, - pub initialize: Box>) + Send>, } /// Spawns a non-pty process, returning a function that initializes processing /// stdin, stdout, and stderr once called (for lazy processing) -pub async fn spawn_simple(session: &Session, cmd: &str, cleanup: F) -> io::Result +pub async fn spawn_simple( + session: &Session, + cmd: &str, + environment: Environment, + current_dir: Option, + reply: Box>, + cleanup: F, +) -> io::Result where - F: FnOnce(usize) -> R + Send + 'static, + F: FnOnce(ProcessId) -> R + Send + 'static, R: Future + Send + 'static, { + if current_dir.is_some() { + return Err(io::Error::new( + io::ErrorKind::Unsupported, + "current_dir is not supported", + )); + } + let ExecResult { mut stdin, mut stdout, mut stderr, mut child, } = session - .exec(cmd, None) + .exec( + cmd, + if environment.is_empty() { + None + } else { + Some(environment.into_map()) + }, + ) .compat() .await .map_err(to_other_error)?; @@ -61,22 +85,20 @@ where let id = rand::random(); 
let session = session.clone(); - let initialize = Box::new(move |reply: mpsc::Sender>| { - let stdout_task = spawn_nonblocking_stdout_task(id, stdout, reply.clone()); - let stderr_task = spawn_nonblocking_stderr_task(id, stderr, reply.clone()); - let stdin_task = spawn_nonblocking_stdin_task(id, stdin, stdin_rx); - let _ = spawn_cleanup_task( - session, - id, - child, - kill_rx, - stdin_task, - stdout_task, - Some(stderr_task), - reply, - cleanup, - ); - }); + let stdout_task = spawn_nonblocking_stdout_task(id, stdout, reply.clone_reply()); + let stderr_task = spawn_nonblocking_stderr_task(id, stderr, reply.clone_reply()); + let stdin_task = spawn_nonblocking_stdin_task(id, stdin, stdin_rx); + let _ = spawn_cleanup_task( + session, + id, + child, + kill_rx, + stdin_task, + stdout_task, + Some(stderr_task), + reply, + cleanup, + ); // Create a resizer that is already closed since a simple process does not resize let resizer = mpsc::channel(1).0; @@ -86,7 +108,6 @@ where stdin: stdin_tx, killer: kill_tx, resizer, - initialize, }) } @@ -95,16 +116,38 @@ where pub async fn spawn_pty( session: &Session, cmd: &str, + environment: Environment, + current_dir: Option, size: PtySize, + reply: Box>, cleanup: F, ) -> io::Result where - F: FnOnce(usize) -> R + Send + 'static, + F: FnOnce(ProcessId) -> R + Send + 'static, R: Future + Send + 'static, { - // TODO: Do we need to support other terminal types for TERM? 
+ if current_dir.is_some() { + return Err(io::Error::new( + io::ErrorKind::Unsupported, + "current_dir is not supported", + )); + } + + let term = environment + .get("TERM") + .map(ToString::to_string) + .unwrap_or_else(|| String::from("xterm-256color")); let (pty, mut child) = session - .request_pty("xterm-256color", to_portable_size(size), Some(cmd), None) + .request_pty( + &term, + to_portable_size(size), + Some(cmd), + if environment.is_empty() { + None + } else { + Some(environment.into_map()) + }, + ) .compat() .await .map_err(to_other_error)?; @@ -130,21 +173,19 @@ where let id = rand::random(); let session = session.clone(); - let initialize = Box::new(move |reply: mpsc::Sender>| { - let stdout_task = spawn_blocking_stdout_task(id, reader, reply.clone()); - let stdin_task = spawn_blocking_stdin_task(id, writer, stdin_rx); - let _ = spawn_cleanup_task( - session, - id, - child, - kill_rx, - stdin_task, - stdout_task, - None, - reply, - cleanup, - ); - }); + let stdout_task = spawn_blocking_stdout_task(id, reader, reply.clone_reply()); + let stdin_task = spawn_blocking_stdin_task(id, writer, stdin_rx); + let _ = spawn_cleanup_task( + session, + id, + child, + kill_rx, + stdin_task, + stdout_task, + None, + reply, + cleanup, + ); let (resize_tx, mut resize_rx) = mpsc::channel::(1); tokio::spawn(async move { @@ -160,26 +201,25 @@ where stdin: stdin_tx, killer: kill_tx, resizer: resize_tx, - initialize, }) } fn spawn_blocking_stdout_task( - id: usize, + id: ProcessId, mut reader: impl Read + Send + 'static, - tx: mpsc::Sender>, + reply: Box>, ) -> JoinHandle<()> { tokio::task::spawn_blocking(move || { let mut buf: [u8; MAX_PIPE_CHUNK_SIZE] = [0; MAX_PIPE_CHUNK_SIZE]; loop { match reader.read(&mut buf) { Ok(n) if n > 0 => { - let payload = vec![ResponseData::ProcStdout { + let payload = DistantResponseData::ProcStdout { id, data: buf[..n].to_vec(), - }]; - if tx.blocking_send(payload).is_err() { - error!(" Stdout channel closed", id); + }; + if 
reply.blocking_send(payload).is_err() { + error!("[Ssh | Proc {}] Stdout channel closed", id); break; } @@ -187,7 +227,7 @@ fn spawn_blocking_stdout_task( } Ok(_) => break, Err(x) => { - error!(" Stdout unexpectedly closed: {}", id, x); + error!("[Ssh | Proc {}] Stdout unexpectedly closed: {}", id, x); break; } } @@ -196,21 +236,21 @@ fn spawn_blocking_stdout_task( } fn spawn_nonblocking_stdout_task( - id: usize, + id: ProcessId, mut reader: impl Read + Send + 'static, - tx: mpsc::Sender>, + reply: Box>, ) -> JoinHandle<()> { tokio::spawn(async move { let mut buf: [u8; MAX_PIPE_CHUNK_SIZE] = [0; MAX_PIPE_CHUNK_SIZE]; loop { match reader.read(&mut buf) { Ok(n) if n > 0 => { - let payload = vec![ResponseData::ProcStdout { + let payload = DistantResponseData::ProcStdout { id, data: buf[..n].to_vec(), - }]; - if tx.send(payload).await.is_err() { - error!(" Stdout channel closed", id); + }; + if reply.send(payload).await.is_err() { + error!("[Ssh | Proc {}] Stdout channel closed", id); break; } @@ -221,7 +261,7 @@ fn spawn_nonblocking_stdout_task( tokio::time::sleep(Duration::from_millis(THREAD_PAUSE_MILLIS)).await; } Err(x) => { - error!(" Stdout unexpectedly closed: {}", id, x); + error!("[Ssh | Proc {}] Stdout unexpectedly closed: {}", id, x); break; } } @@ -230,21 +270,21 @@ fn spawn_nonblocking_stdout_task( } fn spawn_nonblocking_stderr_task( - id: usize, + id: ProcessId, mut reader: impl Read + Send + 'static, - tx: mpsc::Sender>, + reply: Box>, ) -> JoinHandle<()> { tokio::spawn(async move { let mut buf: [u8; MAX_PIPE_CHUNK_SIZE] = [0; MAX_PIPE_CHUNK_SIZE]; loop { match reader.read(&mut buf) { Ok(n) if n > 0 => { - let payload = vec![ResponseData::ProcStderr { + let payload = DistantResponseData::ProcStderr { id, data: buf[..n].to_vec(), - }]; - if tx.send(payload).await.is_err() { - error!(" Stderr channel closed", id); + }; + if reply.send(payload).await.is_err() { + error!("[Ssh | Proc {}] Stderr channel closed", id); break; } @@ -255,7 +295,7 @@ fn 
spawn_nonblocking_stderr_task( tokio::time::sleep(Duration::from_millis(THREAD_PAUSE_MILLIS)).await; } Err(x) => { - error!(" Stderr unexpectedly closed: {}", id, x); + error!("[Ssh | Proc {}] Stderr unexpectedly closed: {}", id, x); break; } } @@ -264,14 +304,14 @@ fn spawn_nonblocking_stderr_task( } fn spawn_blocking_stdin_task( - id: usize, + id: ProcessId, mut writer: impl Write + Send + 'static, mut rx: mpsc::Receiver>, ) -> JoinHandle<()> { tokio::task::spawn_blocking(move || { while let Some(data) = rx.blocking_recv() { if let Err(x) = writer.write_all(&data) { - error!(" Failed to send stdin: {}", id, x); + error!("[Ssh | Proc {}] Failed to send stdin: {}", id, x); break; } @@ -281,7 +321,7 @@ fn spawn_blocking_stdin_task( } fn spawn_nonblocking_stdin_task( - id: usize, + id: ProcessId, mut writer: impl Write + Send + 'static, mut rx: mpsc::Receiver>, ) -> JoinHandle<()> { @@ -291,7 +331,7 @@ fn spawn_nonblocking_stdin_task( // In non-blocking mode, we'll just pause and try again if // the IO would block here; otherwise, stop the task if x.kind() != io::ErrorKind::WouldBlock { - error!(" Failed to send stdin: {}", id, x); + error!("[Ssh | Proc {}] Failed to send stdin: {}", id, x); break; } } @@ -304,17 +344,17 @@ fn spawn_nonblocking_stdin_task( #[allow(clippy::too_many_arguments)] fn spawn_cleanup_task( session: Session, - id: usize, + id: ProcessId, mut child: SshChildProcess, mut kill_rx: mpsc::Receiver<()>, stdin_task: JoinHandle<()>, stdout_task: JoinHandle<()>, stderr_task: Option>, - tx: mpsc::Sender>, + reply: Box>, cleanup: F, ) -> JoinHandle<()> where - F: FnOnce(usize) -> R + Send + 'static, + F: FnOnce(ProcessId) -> R + Send + 'static, R: Future + Send + 'static, { tokio::spawn(async move { @@ -330,7 +370,7 @@ where success = status.success(); } Err(x) => { - error!(" Waiting on process failed: {}", id, x); + error!("[Ssh | Proc {}] Waiting on process failed: {}", id, x); } } } @@ -341,10 +381,10 @@ where stdin_task.abort(); if should_kill { - 
debug!(" Killing", id); + debug!("[Ssh | Proc {}] Killing", id); if let Err(x) = child.kill() { - error!(" Unable to kill process: {}", id, x); + error!("[Ssh | Proc {}] Unable to kill process: {}", id, x); } // NOTE: At the moment, child.kill does nothing for wezterm_ssh::SshChildProcess; @@ -362,7 +402,7 @@ where } } else { debug!( - " Completed and waiting on stdout & stderr tasks", + "[Ssh | Proc {}] Completed and waiting on stdout & stderr tasks", id ); } @@ -372,24 +412,24 @@ where if let Some(task) = stderr_task { if let Err(x) = task.await { - error!(" Join on stderr task failed: {}", id, x); + error!("[Ssh | Proc {}] Join on stderr task failed: {}", id, x); } } if let Err(x) = stdout_task.await { - error!(" Join on stdout task failed: {}", id, x); + error!("[Ssh | Proc {}] Join on stdout task failed: {}", id, x); } cleanup(id).await; - let payload = vec![ResponseData::ProcDone { + let payload = DistantResponseData::ProcDone { id, success: !should_kill && success, code: if success { Some(0) } else { None }, - }]; + }; - if tx.send(payload).await.is_err() { - error!(" Failed to send done", id,); + if reply.send(payload).await.is_err() { + error!("[Ssh | Proc {}] Failed to send done", id,); } }) } diff --git a/distant-ssh2/src/utils.rs b/distant-ssh2/src/utils.rs new file mode 100644 index 0000000..b6ea836 --- /dev/null +++ b/distant-ssh2/src/utils.rs @@ -0,0 +1,293 @@ +use async_compat::CompatExt; +use std::{ + fmt, io, + path::{Component, Path, PathBuf, Prefix}, + time::Duration, +}; +use wezterm_ssh::{ExecResult, Session, Sftp}; + +#[allow(dead_code)] +const READER_PAUSE_MILLIS: u64 = 100; + +#[derive(Clone, PartialEq, Eq)] +pub struct ExecOutput { + pub success: bool, + pub stdout: Vec, + pub stderr: Vec, +} + +impl fmt::Debug for ExecOutput { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + let is_alternate = f.alternate(); + + let mut s = f.debug_struct("ExecOutput"); + s.field("success", &self.success); + + if is_alternate { + 
s.field("stdout", &String::from_utf8_lossy(&self.stdout)) + .field("stderr", &String::from_utf8_lossy(&self.stderr)); + } else { + s.field("stdout", &self.stdout) + .field("stderr", &self.stderr); + } + + s.finish() + } +} + +#[allow(dead_code)] +pub async fn execute_output(session: &Session, cmd: &str) -> io::Result { + let ExecResult { + mut child, + mut stdout, + mut stderr, + .. + } = session + .exec(cmd, None) + .compat() + .await + .map_err(to_other_error)?; + + macro_rules! spawn_reader { + ($reader:ident) => {{ + $reader.set_non_blocking(true).map_err(to_other_error)?; + tokio::spawn(async move { + use std::io::Read; + let mut bytes = Vec::new(); + let mut buf = [0u8; 1024]; + loop { + match $reader.read(&mut buf) { + Ok(n) if n > 0 => bytes.extend(&buf[..n]), + Ok(_) => break Ok(bytes), + Err(x) if x.kind() == io::ErrorKind::WouldBlock => { + tokio::time::sleep(Duration::from_millis(READER_PAUSE_MILLIS)).await; + } + Err(x) => break Err(x), + } + } + }) + }}; + } + + // Spawn async readers for stdout and stderr from process + let stdout_handle = spawn_reader!(stdout); + let stderr_handle = spawn_reader!(stderr); + + // Wait for our handles to conclude + let stdout = stdout_handle.await.map_err(to_other_error)??; + let stderr = stderr_handle.await.map_err(to_other_error)??; + + // Wait for process to conclude + let status = child.async_wait().compat().await.map_err(to_other_error)?; + + Ok(ExecOutput { + success: status.success(), + stdout, + stderr, + }) +} + +/// Performs canonicalization of the given path using SFTP with various handling of Windows paths +pub async fn canonicalize(sftp: &Sftp, path: impl AsRef) -> io::Result { + // Determine if we are supplying a Windows path + let mut is_windows_path = path + .as_ref() + .components() + .any(|c| matches!(c, Component::Prefix(_))); + + // Try to canonicalize original path first + let result = sftp + .canonicalize(path.as_ref().to_path_buf()) + .compat() + .await; + + // If we don't see the path initially 
as a Windows path, but we can find a drive letter after + // canonicalization, still treat it as a windows path + // + // NOTE: This is for situations where we are given a relative path like '.' where we cannot + // infer the path is for Windows out of the box + if !is_windows_path { + if let Ok(path) = result.as_ref() { + is_windows_path = drive_letter(path.as_std_path()).is_some(); + } + } + + // If result is a failure, we want to try again with a unix path in case we were using + // a windows path and sshd had a problem with canonicalizing it + let unix_path = if result.is_err() && is_windows_path { + Some(to_unix_path(path.as_ref())) + } else { + None + }; + + // 1. If we succeeded on first try, return that path + // a. If the canonicalized path was for a Windows path, sftp may return something odd + // like C:\Users\example -> /c:/Users/example and we need to transform it back + // b. Otherwise, if the input path was a unix path, we return canonicalized as is + // 2. If we failed on first try and have a clear Windows path, try the unix version + // and then convert result back to windows version, return our original error if we fail + // 3. If we failed and there is no valid unix path for a Windows path, return the + // original error + match (result, unix_path) { + (Ok(path), _) if is_windows_path => Ok(to_windows_path(path.as_std_path())), + (Ok(path), _) => Ok(path.into_std_path_buf()), + (Err(x), Some(path)) => Ok(to_windows_path( + &sftp + .canonicalize(path.to_path_buf()) + .compat() + .await + .map_err(|_| to_other_error(x))? + .into_std_path_buf(), + )), + (Err(x), None) => Err(to_other_error(x)), + } +} + +/// Convert a path into unix-oriented path +/// +/// E.g. 
C:\Users\example\Documents\file.txt -> /c/Users/example/Documents/file.txt +pub fn to_unix_path(path: &Path) -> PathBuf { + let is_windows_path = path.components().any(|c| matches!(c, Component::Prefix(_))); + + if !is_windows_path { + return path.to_path_buf(); + } + + let mut p = PathBuf::new(); + for component in path.components() { + match component { + Component::Prefix(x) => match x.kind() { + Prefix::Verbatim(path) => p.push(path), + Prefix::VerbatimUNC(hostname, share) => { + p.push(hostname); + p.push(share); + } + Prefix::VerbatimDisk(letter) => { + p.push(format!("/{}", letter as char)); + } + Prefix::DeviceNS(device_name) => p.push(device_name), + Prefix::UNC(hostname, share) => { + p.push(hostname); + p.push(share); + } + Prefix::Disk(letter) => { + p.push(format!("/{}", letter as char)); + } + }, + + // If we have a prefix, then we are dropping it and converting into + // a root and normal component, so we will now skip this root + Component::RootDir => continue, + + x => p.push(x), + } + } + + p +} + +/// Convert a path into windows-oriented path +/// +/// E.g. 
/c/Users/example/Documents/file.txt -> C:\Users\example\Documents\file.txt +pub fn to_windows_path(path: &Path) -> PathBuf { + let is_windows_path = path.components().any(|c| matches!(c, Component::Prefix(_))); + + if is_windows_path { + return path.to_path_buf(); + } + + // See if we have a drive letter at the beginning, otherwise default to C:\ + let drive_letter = drive_letter(path); + + let mut p = PathBuf::new(); + + // Start with a drive prefix + p.push(format!("{}:", drive_letter.unwrap_or('C'))); + + let mut components = path.components(); + + // If we start with a root portion of the regular path, we want to drop + // it and the drive letter since we've added that separately + if path.has_root() { + p.push(Component::RootDir); + components.next(); + + if drive_letter.is_some() { + components.next(); + } + } + + for component in components { + p.push(component); + } + + p +} + +/// Looks for a drive letter in the given path +pub fn drive_letter(path: &Path) -> Option { + // See if we are a windows path, and if so grab the letter from the components + let maybe_letter = path.components().find_map(|c| match c { + Component::Prefix(x) => match x.kind() { + Prefix::Disk(letter) | Prefix::VerbatimDisk(letter) => Some(letter as char), + _ => None, + }, + _ => None, + }); + + if let Some(letter) = maybe_letter { + return Some(letter); + } + + // If there was no drive letter and we are not a root, there is nothing left to find + if !path.has_root() { + return None; + } + + // Otherwise, scan just after root for a drive letter + path.components().nth(1).and_then(|c| match c { + Component::Normal(s) => s.to_str().and_then(|s| { + let mut chars = s.chars(); + let first = chars.next(); + let second = chars.next(); + let has_more = chars.next().is_some(); + + if has_more { + return None; + } + + match (first, second) { + (letter, Some(':') | None) => letter, + _ => None, + } + }), + _ => None, + }) +} + +/// Determines if using windows by checking the canonicalized path 
of '.' +pub async fn is_windows(sftp: &Sftp) -> io::Result { + // Look up the current directory + let current_dir = canonicalize(sftp, ".").await?; + + // TODO: Ideally, we would determine the family using something like the following: + // + // cmd.exe /C echo %OS% + // + // Determine OS by printing OS variable (works with Windows 2000+) + // If it matches Windows_NT, then we are on windows + // + // However, the above is not working for whatever reason (always has success == false); so, + // we're purely using a check if we have a drive letter on the canonicalized path to + // determine if on windows for now. Some sort of failure with SIGPIPE + Ok(current_dir + .components() + .any(|c| matches!(c, Component::Prefix(_)))) +} + +pub fn to_other_error(err: E) -> io::Error +where + E: Into>, +{ + io::Error::new(io::ErrorKind::Other, err) +} diff --git a/distant-ssh2/tests/lib.rs b/distant-ssh2/tests/lib.rs index 0456b85..1599a36 100644 --- a/distant-ssh2/tests/lib.rs +++ b/distant-ssh2/tests/lib.rs @@ -1,2 +1,3 @@ mod ssh2; mod sshd; +mod utils; diff --git a/distant-ssh2/tests/ssh2/client.rs b/distant-ssh2/tests/ssh2/client.rs new file mode 100644 index 0000000..1f109fd --- /dev/null +++ b/distant-ssh2/tests/ssh2/client.rs @@ -0,0 +1,1447 @@ +use crate::sshd::*; +use assert_fs::{prelude::*, TempDir}; +use distant_core::{ + data::{ChangeKindSet, Environment, FileType, Metadata}, + DistantChannelExt, DistantClient, +}; +use once_cell::sync::Lazy; +use predicates::prelude::*; +use rstest::*; +use std::{io, path::Path, time::Duration}; + +static TEMP_SCRIPT_DIR: Lazy = Lazy::new(|| TempDir::new().unwrap()); +static SCRIPT_RUNNER: Lazy = Lazy::new(|| String::from("bash")); + +static ECHO_ARGS_TO_STDOUT_SH: Lazy = Lazy::new(|| { + let script = TEMP_SCRIPT_DIR.child("echo_args_to_stdout.sh"); + script + .write_str(indoc::indoc!( + r#" + #/usr/bin/env bash + printf "%s" "$*" + "# + )) + .unwrap(); + script +}); + +static ECHO_ARGS_TO_STDERR_SH: Lazy = Lazy::new(|| { + let 
script = TEMP_SCRIPT_DIR.child("echo_args_to_stderr.sh"); + script + .write_str(indoc::indoc!( + r#" + #/usr/bin/env bash + printf "%s" "$*" 1>&2 + "# + )) + .unwrap(); + script +}); + +static ECHO_STDIN_TO_STDOUT_SH: Lazy = Lazy::new(|| { + let script = TEMP_SCRIPT_DIR.child("echo_stdin_to_stdout.sh"); + script + .write_str(indoc::indoc!( + r#" + #/usr/bin/env bash + while IFS= read; do echo "$REPLY"; done + "# + )) + .unwrap(); + script +}); + +static SLEEP_SH: Lazy = Lazy::new(|| { + let script = TEMP_SCRIPT_DIR.child("sleep.sh"); + script + .write_str(indoc::indoc!( + r#" + #!/usr/bin/env bash + sleep "$1" + "# + )) + .unwrap(); + script +}); + +static DOES_NOT_EXIST_BIN: Lazy = + Lazy::new(|| TEMP_SCRIPT_DIR.child("does_not_exist_bin")); + +#[rstest] +#[tokio::test] +async fn read_file_should_fail_if_file_missing(#[future] client: Ctx) { + let mut client = client.await; + let temp = assert_fs::TempDir::new().unwrap(); + let path = temp.child("missing-file").path().to_path_buf(); + + let _ = client.read_file(path).await.unwrap_err(); +} + +#[rstest] +#[tokio::test] +async fn read_file_should_send_blob_with_file_contents(#[future] client: Ctx) { + let mut client = client.await; + + let temp = assert_fs::TempDir::new().unwrap(); + let file = temp.child("test-file"); + file.write_str("some file contents").unwrap(); + + let bytes = client.read_file(file.path().to_path_buf()).await.unwrap(); + assert_eq!(bytes, b"some file contents"); +} + +#[rstest] +#[tokio::test] +async fn read_file_text_should_send_error_if_fails_to_read_file( + #[future] client: Ctx, +) { + let mut client = client.await; + + let temp = assert_fs::TempDir::new().unwrap(); + let path = temp.child("missing-file").path().to_path_buf(); + + let _ = client.read_file_text(path).await.unwrap_err(); +} + +#[rstest] +#[tokio::test] +async fn read_file_text_should_send_text_with_file_contents(#[future] client: Ctx) { + let mut client = client.await; + + let temp = assert_fs::TempDir::new().unwrap(); + let 
file = temp.child("test-file"); + file.write_str("some file contents").unwrap(); + + let text = client + .read_file_text(file.path().to_path_buf()) + .await + .unwrap(); + assert_eq!(text, "some file contents"); +} + +#[rstest] +#[tokio::test] +async fn write_file_should_send_error_if_fails_to_write_file(#[future] client: Ctx) { + let mut client = client.await; + + // Create a temporary path and add to it to ensure that there are + // extra components that don't exist to cause writing to fail + let temp = assert_fs::TempDir::new().unwrap(); + let file = temp.child("dir").child("test-file"); + + let _ = client + .write_file(file.path().to_path_buf(), b"some text".to_vec()) + .await + .unwrap_err(); + + // Also verify that we didn't actually create the file + file.assert(predicate::path::missing()); +} + +#[rstest] +#[tokio::test] +async fn write_file_should_send_ok_when_successful(#[future] client: Ctx) { + let mut client = client.await; + + // Path should point to a file that does not exist, but all + // other components leading up to it do + let temp = assert_fs::TempDir::new().unwrap(); + let file = temp.child("test-file"); + + client + .write_file(file.path().to_path_buf(), b"some text".to_vec()) + .await + .unwrap(); + + // Also verify that we actually did create the file + // with the associated contents + file.assert("some text"); +} + +#[rstest] +#[tokio::test] +async fn write_file_text_should_send_error_if_fails_to_write_file( + #[future] client: Ctx, +) { + let mut client = client.await; + + // Create a temporary path and add to it to ensure that there are + // extra components that don't exist to cause writing to fail + let temp = assert_fs::TempDir::new().unwrap(); + let file = temp.child("dir").child("test-file"); + + let _ = client + .write_file_text(file.path().to_path_buf(), "some text".to_string()) + .await + .unwrap_err(); + + // Also verify that we didn't actually create the file + file.assert(predicate::path::missing()); +} + +#[rstest] 
+#[tokio::test] +async fn write_file_text_should_send_ok_when_successful(#[future] client: Ctx) { + let mut client = client.await; + + // Path should point to a file that does not exist, but all + // other components leading up to it do + let temp = assert_fs::TempDir::new().unwrap(); + let file = temp.child("test-file"); + + client + .write_file_text(file.path().to_path_buf(), "some text".to_string()) + .await + .unwrap(); + + // Also verify that we actually did create the file + // with the associated contents + file.assert("some text"); +} + +#[rstest] +#[tokio::test] +async fn append_file_should_send_error_if_fails_to_create_file( + #[future] client: Ctx, +) { + let mut client = client.await; + + // Create a temporary path and add to it to ensure that there are + // extra components that don't exist to cause writing to fail + let temp = assert_fs::TempDir::new().unwrap(); + let file = temp.child("dir").child("test-file"); + + let _ = client + .append_file(file.path().to_path_buf(), b"some extra contents".to_vec()) + .await + .unwrap_err(); + + // Also verify that we didn't actually create the file + file.assert(predicate::path::missing()); +} + +#[rstest] +#[tokio::test] +async fn append_file_should_create_file_if_missing(#[future] client: Ctx) { + let mut client = client.await; + + // Don't create the file directly, but define path + // where the file should be + let temp = assert_fs::TempDir::new().unwrap(); + let file = temp.child("test-file"); + + client + .append_file(file.path().to_path_buf(), b"some extra contents".to_vec()) + .await + .unwrap(); + + // Yield to allow chance to finish appending to file + tokio::time::sleep(Duration::from_millis(50)).await; + + // Also verify that we actually did create to the file + file.assert("some extra contents"); +} + +#[rstest] +#[tokio::test] +async fn append_file_should_send_ok_when_successful(#[future] client: Ctx) { + let mut client = client.await; + + // Create a temporary file and fill it with some contents + 
let temp = assert_fs::TempDir::new().unwrap(); + let file = temp.child("test-file"); + file.write_str("some file contents").unwrap(); + + client + .append_file(file.path().to_path_buf(), b"some extra contents".to_vec()) + .await + .unwrap(); + + // Yield to allow chance to finish appending to file + tokio::time::sleep(Duration::from_millis(50)).await; + + // Also verify that we actually did append to the file + file.assert("some file contentssome extra contents"); +} + +#[rstest] +#[tokio::test] +async fn append_file_text_should_send_error_if_fails_to_create_file( + #[future] client: Ctx, +) { + let mut client = client.await; + + // Create a temporary path and add to it to ensure that there are + // extra components that don't exist to cause writing to fail + let temp = assert_fs::TempDir::new().unwrap(); + let file = temp.child("dir").child("test-file"); + + let _ = client + .append_file_text(file.path().to_path_buf(), "some extra contents".to_string()) + .await + .unwrap_err(); + + // Also verify that we didn't actually create the file + file.assert(predicate::path::missing()); +} + +#[rstest] +#[tokio::test] +async fn append_file_text_should_create_file_if_missing(#[future] client: Ctx) { + let mut client = client.await; + + // Don't create the file directly, but define path + // where the file should be + let temp = assert_fs::TempDir::new().unwrap(); + let file = temp.child("test-file"); + + client + .append_file_text(file.path().to_path_buf(), "some extra contents".to_string()) + .await + .unwrap(); + + // Yield to allow chance to finish appending to file + tokio::time::sleep(Duration::from_millis(50)).await; + + // Also verify that we actually did create to the file + file.assert("some extra contents"); +} + +#[rstest] +#[tokio::test] +async fn append_file_text_should_send_ok_when_successful(#[future] client: Ctx) { + let mut client = client.await; + + // Create a temporary file and fill it with some contents + let temp = assert_fs::TempDir::new().unwrap(); 
+ let file = temp.child("test-file"); + file.write_str("some file contents").unwrap(); + + client + .append_file_text(file.path().to_path_buf(), "some extra contents".to_string()) + .await + .unwrap(); + + // Yield to allow chance to finish appending to file + tokio::time::sleep(Duration::from_millis(50)).await; + + // Also verify that we actually did append to the file + file.assert("some file contentssome extra contents"); +} + +#[rstest] +#[tokio::test] +async fn dir_read_should_send_error_if_directory_does_not_exist( + #[future] client: Ctx, +) { + let mut client = client.await; + + let temp = assert_fs::TempDir::new().unwrap(); + let dir = temp.child("test-dir"); + + let _ = client + .read_dir( + dir.path().to_path_buf(), + /* depth */ 0, + /* absolute */ false, + /* canonicalize */ false, + /* include_root */ false, + ) + .await + .unwrap_err(); +} + +// /root/ +// /root/file1 +// /root/link1 -> /root/sub1/file2 +// /root/sub1/ +// /root/sub1/file2 +async fn setup_dir() -> assert_fs::TempDir { + let root_dir = assert_fs::TempDir::new().unwrap(); + root_dir.child("file1").touch().unwrap(); + + let sub1 = root_dir.child("sub1"); + sub1.create_dir_all().unwrap(); + + let file2 = sub1.child("file2"); + file2.touch().unwrap(); + + let link1 = root_dir.child("link1"); + link1.symlink_to_file(file2.path()).unwrap(); + + root_dir +} + +#[rstest] +#[tokio::test] +async fn dir_read_should_support_depth_limits(#[future] client: Ctx) { + let mut client = client.await; + + // Create directory with some nested items + let root_dir = setup_dir().await; + + let (entries, _) = client + .read_dir( + root_dir.path().to_path_buf(), + /* depth */ 1, + /* absolute */ false, + /* canonicalize */ false, + /* include_root */ false, + ) + .await + .unwrap(); + + assert_eq!(entries.len(), 3, "Wrong number of entries found"); + + assert_eq!(entries[0].file_type, FileType::File); + assert_eq!(entries[0].path, Path::new("file1")); + assert_eq!(entries[0].depth, 1); + + 
assert_eq!(entries[1].file_type, FileType::Symlink); + assert_eq!(entries[1].path, Path::new("link1")); + assert_eq!(entries[1].depth, 1); + + assert_eq!(entries[2].file_type, FileType::Dir); + assert_eq!(entries[2].path, Path::new("sub1")); + assert_eq!(entries[2].depth, 1); +} + +#[rstest] +#[tokio::test] +async fn dir_read_should_support_unlimited_depth_using_zero(#[future] client: Ctx) { + let mut client = client.await; + + // Create directory with some nested items + let root_dir = setup_dir().await; + + let (entries, _) = client + .read_dir( + root_dir.path().to_path_buf(), + /* depth */ 0, + /* absolute */ false, + /* canonicalize */ false, + /* include_root */ false, + ) + .await + .unwrap(); + + assert_eq!(entries.len(), 4, "Wrong number of entries found"); + + assert_eq!(entries[0].file_type, FileType::File); + assert_eq!(entries[0].path, Path::new("file1")); + assert_eq!(entries[0].depth, 1); + + assert_eq!(entries[1].file_type, FileType::Symlink); + assert_eq!(entries[1].path, Path::new("link1")); + assert_eq!(entries[1].depth, 1); + + assert_eq!(entries[2].file_type, FileType::Dir); + assert_eq!(entries[2].path, Path::new("sub1")); + assert_eq!(entries[2].depth, 1); + + assert_eq!(entries[3].file_type, FileType::File); + assert_eq!(entries[3].path, Path::new("sub1").join("file2")); + assert_eq!(entries[3].depth, 2); +} + +// NOTE: This is failing on windows as canonicalization of root path is not correct! 
+#[rstest] +#[tokio::test] +#[cfg_attr(windows, ignore)] +async fn dir_read_should_support_including_directory_in_returned_entries( + #[future] client: Ctx, +) { + let mut client = client.await; + + // Create directory with some nested items + let root_dir = setup_dir().await; + + let (entries, _) = client + .read_dir( + root_dir.path().to_path_buf(), + /* depth */ 1, + /* absolute */ false, + /* canonicalize */ false, + /* include_root */ true, + ) + .await + .unwrap(); + + assert_eq!(entries.len(), 4, "Wrong number of entries found"); + + // NOTE: Root entry is always absolute, resolved path + assert_eq!(entries[0].file_type, FileType::Dir); + assert_eq!( + entries[0].path, + dunce::canonicalize(root_dir.path()).unwrap() + ); + assert_eq!(entries[0].depth, 0); + + assert_eq!(entries[1].file_type, FileType::File); + assert_eq!(entries[1].path, Path::new("file1")); + assert_eq!(entries[1].depth, 1); + + assert_eq!(entries[2].file_type, FileType::Symlink); + assert_eq!(entries[2].path, Path::new("link1")); + assert_eq!(entries[2].depth, 1); + + assert_eq!(entries[3].file_type, FileType::Dir); + assert_eq!(entries[3].path, Path::new("sub1")); + assert_eq!(entries[3].depth, 1); +} + +// NOTE: This is failing on windows as canonicalization of root path is not correct! 
+#[rstest] +#[tokio::test] +#[cfg_attr(windows, ignore)] +async fn dir_read_should_support_returning_absolute_paths(#[future] client: Ctx) { + let mut client = client.await; + + // Create directory with some nested items + let root_dir = setup_dir().await; + + let (entries, _) = client + .read_dir( + root_dir.path().to_path_buf(), + /* depth */ 1, + /* absolute */ true, + /* canonicalize */ false, + /* include_root */ false, + ) + .await + .unwrap(); + + assert_eq!(entries.len(), 3, "Wrong number of entries found"); + let root_path = dunce::canonicalize(root_dir.path()).unwrap(); + + assert_eq!(entries[0].file_type, FileType::File); + assert_eq!(entries[0].path, root_path.join("file1")); + assert_eq!(entries[0].depth, 1); + + assert_eq!(entries[1].file_type, FileType::Symlink); + assert_eq!(entries[1].path, root_path.join("link1")); + assert_eq!(entries[1].depth, 1); + + assert_eq!(entries[2].file_type, FileType::Dir); + assert_eq!(entries[2].path, root_path.join("sub1")); + assert_eq!(entries[2].depth, 1); +} + +// NOTE: This is failing on windows as the symlink does not get resolved! 
+#[rstest] +#[tokio::test] +#[cfg_attr(windows, ignore)] +async fn dir_read_should_support_returning_canonicalized_paths( + #[future] client: Ctx, +) { + let mut client = client.await; + + // Create directory with some nested items + let root_dir = setup_dir().await; + + let (entries, _) = client + .read_dir( + root_dir.path().to_path_buf(), + /* depth */ 1, + /* absolute */ false, + /* canonicalize */ true, + /* include_root */ false, + ) + .await + .unwrap(); + + assert_eq!(entries.len(), 3, "Wrong number of entries found"); + println!("{:?}", entries); + + assert_eq!(entries[0].file_type, FileType::File); + assert_eq!(entries[0].path, Path::new("file1")); + assert_eq!(entries[0].depth, 1); + + assert_eq!(entries[1].file_type, FileType::Dir); + assert_eq!(entries[1].path, Path::new("sub1")); + assert_eq!(entries[1].depth, 1); + + // Symlink should be resolved from $ROOT/link1 -> $ROOT/sub1/file2 + assert_eq!(entries[2].file_type, FileType::Symlink); + assert_eq!(entries[2].path, Path::new("sub1").join("file2")); + assert_eq!(entries[2].depth, 1); +} + +#[rstest] +#[tokio::test] +async fn create_dir_should_send_error_if_fails(#[future] client: Ctx) { + let mut client = client.await; + + // Make a path that has multiple non-existent components + // so the creation will fail + let root_dir = setup_dir().await; + let path = root_dir.path().join("nested").join("new-dir"); + + let _ = client + .create_dir(path.to_path_buf(), /* all */ false) + .await + .unwrap_err(); + + // Also verify that the directory was not actually created + assert!(!path.exists(), "Path unexpectedly exists"); +} + +#[rstest] +#[tokio::test] +async fn create_dir_should_send_ok_when_successful(#[future] client: Ctx) { + let mut client = client.await; + let root_dir = setup_dir().await; + let path = root_dir.path().join("new-dir"); + + client + .create_dir(path.to_path_buf(), /* all */ false) + .await + .unwrap(); + + // Also verify that the directory was actually created + assert!(path.exists(), 
"Directory not created"); +} + +#[rstest] +#[tokio::test] +async fn create_dir_should_support_creating_multiple_dir_components( + #[future] client: Ctx, +) { + let mut client = client.await; + let root_dir = setup_dir().await; + let path = root_dir.path().join("nested").join("new-dir"); + + client + .create_dir(path.to_path_buf(), /* all */ true) + .await + .unwrap(); + + // Also verify that the directory was actually created + assert!(path.exists(), "Directory not created"); +} + +#[rstest] +#[tokio::test] +async fn remove_should_send_error_on_failure(#[future] client: Ctx) { + let mut client = client.await; + let temp = assert_fs::TempDir::new().unwrap(); + let file = temp.child("missing-file"); + + let _ = client + .remove(file.path().to_path_buf(), /* false */ false) + .await + .unwrap_err(); + + // Also, verify that path does not exist + file.assert(predicate::path::missing()); +} + +#[rstest] +#[tokio::test] +async fn remove_should_support_deleting_a_directory(#[future] client: Ctx) { + let mut client = client.await; + let temp = assert_fs::TempDir::new().unwrap(); + let dir = temp.child("dir"); + dir.create_dir_all().unwrap(); + + client + .remove(dir.path().to_path_buf(), /* false */ false) + .await + .unwrap(); + + // Also, verify that path does not exist + dir.assert(predicate::path::missing()); +} + +#[rstest] +#[tokio::test] +async fn remove_should_delete_nonempty_directory_if_force_is_true( + #[future] client: Ctx, +) { + let mut client = client.await; + let temp = assert_fs::TempDir::new().unwrap(); + let dir = temp.child("dir"); + dir.create_dir_all().unwrap(); + dir.child("file").touch().unwrap(); + + client + .remove(dir.path().to_path_buf(), /* false */ true) + .await + .unwrap(); + + // Also, verify that path does not exist + dir.assert(predicate::path::missing()); +} + +#[rstest] +#[tokio::test] +async fn remove_should_support_deleting_a_single_file(#[future] client: Ctx) { + let mut client = client.await; + let temp = 
assert_fs::TempDir::new().unwrap(); + let file = temp.child("some-file"); + file.touch().unwrap(); + + client + .remove(file.path().to_path_buf(), /* false */ false) + .await + .unwrap(); + + // Also, verify that path does not exist + file.assert(predicate::path::missing()); +} + +#[rstest] +#[tokio::test] +async fn copy_should_send_error_on_failure(#[future] client: Ctx) { + let mut client = client.await; + let temp = assert_fs::TempDir::new().unwrap(); + let src = temp.child("src"); + let dst = temp.child("dst"); + + let _ = client + .copy(src.path().to_path_buf(), dst.path().to_path_buf()) + .await + .unwrap_err(); + + // Also, verify that destination does not exist + dst.assert(predicate::path::missing()); +} + +#[rstest] +#[tokio::test] +async fn copy_should_support_copying_an_entire_directory(#[future] client: Ctx) { + let mut client = client.await; + let temp = assert_fs::TempDir::new().unwrap(); + + let src = temp.child("src"); + src.create_dir_all().unwrap(); + let src_file = src.child("file"); + src_file.write_str("some contents").unwrap(); + + let dst = temp.child("dst"); + let dst_file = dst.child("file"); + + client + .copy(src.path().to_path_buf(), dst.path().to_path_buf()) + .await + .unwrap(); + + // Verify that we have source and destination directories and associated contents + src.assert(predicate::path::is_dir()); + src_file.assert(predicate::path::is_file()); + dst.assert(predicate::path::is_dir()); + dst_file.assert(predicate::path::eq_file(src_file.path())); +} + +#[rstest] +#[tokio::test] +async fn copy_should_support_copying_an_empty_directory(#[future] client: Ctx) { + let mut client = client.await; + let temp = assert_fs::TempDir::new().unwrap(); + let src = temp.child("src"); + src.create_dir_all().unwrap(); + let dst = temp.child("dst"); + + client + .copy(src.path().to_path_buf(), dst.path().to_path_buf()) + .await + .unwrap(); + + // Verify that we still have source and destination directories + src.assert(predicate::path::is_dir()); 
+ dst.assert(predicate::path::is_dir()); +} + +#[rstest] +#[tokio::test] +async fn copy_should_support_copying_a_directory_that_only_contains_directories( + #[future] client: Ctx, +) { + let mut client = client.await; + let temp = assert_fs::TempDir::new().unwrap(); + + let src = temp.child("src"); + src.create_dir_all().unwrap(); + let src_dir = src.child("dir"); + src_dir.create_dir_all().unwrap(); + + let dst = temp.child("dst"); + let dst_dir = dst.child("dir"); + + client + .copy(src.path().to_path_buf(), dst.path().to_path_buf()) + .await + .unwrap(); + + // Verify that we have source and destination directories and associated contents + src.assert(predicate::path::is_dir().name("src")); + src_dir.assert(predicate::path::is_dir().name("src/dir")); + dst.assert(predicate::path::is_dir().name("dst")); + dst_dir.assert(predicate::path::is_dir().name("dst/dir")); +} + +#[rstest] +#[tokio::test] +async fn copy_should_support_copying_a_single_file(#[future] client: Ctx) { + let mut client = client.await; + let temp = assert_fs::TempDir::new().unwrap(); + let src = temp.child("src"); + src.write_str("some text").unwrap(); + let dst = temp.child("dst"); + + client + .copy(src.path().to_path_buf(), dst.path().to_path_buf()) + .await + .unwrap(); + + // Verify that we still have source and that destination has source's contents + src.assert(predicate::path::is_file()); + dst.assert(predicate::path::eq_file(src.path())); +} + +#[rstest] +#[tokio::test] +async fn rename_should_fail_if_path_missing(#[future] client: Ctx) { + let mut client = client.await; + let temp = assert_fs::TempDir::new().unwrap(); + let src = temp.child("src"); + let dst = temp.child("dst"); + + let _ = client + .rename(src.path().to_path_buf(), dst.path().to_path_buf()) + .await + .unwrap_err(); + + // Also, verify that destination does not exist + dst.assert(predicate::path::missing()); +} + +#[rstest] +#[tokio::test] +async fn rename_should_support_renaming_an_entire_directory(#[future] client: 
Ctx) { + let mut client = client.await; + let temp = assert_fs::TempDir::new().unwrap(); + + let src = temp.child("src"); + src.create_dir_all().unwrap(); + let src_file = src.child("file"); + src_file.write_str("some contents").unwrap(); + + let dst = temp.child("dst"); + let dst_file = dst.child("file"); + + client + .rename(src.path().to_path_buf(), dst.path().to_path_buf()) + .await + .unwrap(); + + // Verify that we moved the contents + src.assert(predicate::path::missing()); + src_file.assert(predicate::path::missing()); + dst.assert(predicate::path::is_dir()); + dst_file.assert("some contents"); +} + +#[rstest] +#[tokio::test] +async fn rename_should_support_renaming_a_single_file(#[future] client: Ctx) { + let mut client = client.await; + let temp = assert_fs::TempDir::new().unwrap(); + let src = temp.child("src"); + src.write_str("some text").unwrap(); + let dst = temp.child("dst"); + + client + .rename(src.path().to_path_buf(), dst.path().to_path_buf()) + .await + .unwrap(); + + // Verify that we moved the file + src.assert(predicate::path::missing()); + dst.assert("some text"); +} + +#[rstest] +#[tokio::test] +async fn watch_should_fail_as_unsupported(#[future] client: Ctx) { + // NOTE: Supporting multiple replies being sent back as part of creating, modifying, etc. 
+ let mut client = client.await; + let temp = assert_fs::TempDir::new().unwrap(); + + let file = temp.child("file"); + file.touch().unwrap(); + + let err = client + .watch( + file.path().to_path_buf(), + /* recursive */ false, + /* only */ ChangeKindSet::default(), + /* except */ ChangeKindSet::default(), + ) + .await + .unwrap_err(); + + assert_eq!(err.kind(), io::ErrorKind::Unsupported, "{:?}", err); +} + +#[rstest] +#[tokio::test] +async fn exists_should_send_true_if_path_exists(#[future] client: Ctx) { + let mut client = client.await; + let temp = assert_fs::TempDir::new().unwrap(); + let file = temp.child("file"); + file.touch().unwrap(); + + let exists = client.exists(file.path().to_path_buf()).await.unwrap(); + assert!(exists, "Expected exists to be true, but was false"); +} + +#[rstest] +#[tokio::test] +async fn exists_should_send_false_if_path_does_not_exist(#[future] client: Ctx) { + let mut client = client.await; + let temp = assert_fs::TempDir::new().unwrap(); + let file = temp.child("file"); + + let exists = client.exists(file.path().to_path_buf()).await.unwrap(); + assert!(!exists, "Expected exists to be false, but was true"); +} + +#[rstest] +#[tokio::test] +async fn metadata_should_send_error_on_failure(#[future] client: Ctx) { + let mut client = client.await; + let temp = assert_fs::TempDir::new().unwrap(); + let file = temp.child("file"); + + let _ = client + .metadata( + file.path().to_path_buf(), + /* canonicalize */ false, + /* resolve_file_type */ false, + ) + .await + .unwrap_err(); +} + +#[rstest] +#[tokio::test] +async fn metadata_should_send_back_metadata_on_file_if_exists( + #[future] client: Ctx, +) { + let mut client = client.await; + let temp = assert_fs::TempDir::new().unwrap(); + let file = temp.child("file"); + file.write_str("some text").unwrap(); + + let metadata = client + .metadata( + file.path().to_path_buf(), + /* canonicalize */ false, + /* resolve_file_type */ false, + ) + .await + .unwrap(); + + assert!( + matches!( + 
metadata, + Metadata { + canonicalized_path: None, + file_type: FileType::File, + len: 9, + readonly: false, + .. + } + ), + "{:?}", + metadata + ); +} + +#[cfg(unix)] +#[rstest] +#[tokio::test] +async fn metadata_should_include_unix_specific_metadata_on_unix_platform( + #[future] client: Ctx, +) { + let mut client = client.await; + let temp = assert_fs::TempDir::new().unwrap(); + let file = temp.child("file"); + file.write_str("some text").unwrap(); + + let metadata = client + .metadata( + file.path().to_path_buf(), + /* canonicalize */ false, + /* resolve_file_type */ false, + ) + .await + .unwrap(); + + #[allow(clippy::match_single_binding)] + match metadata { + Metadata { unix, windows, .. } => { + assert!(unix.is_some(), "Unexpectedly missing unix metadata on unix"); + assert!( + windows.is_none(), + "Unexpectedly got windows metadata on unix" + ); + } + } +} + +#[cfg(windows)] +#[rstest] +#[tokio::test] +async fn metadata_should_not_include_windows_as_ssh_cannot_retrieve_that_information( + #[future] client: Ctx, +) { + let mut client = client.await; + let temp = assert_fs::TempDir::new().unwrap(); + let file = temp.child("file"); + file.write_str("some text").unwrap(); + + let metadata = client + .metadata( + file.path().to_path_buf(), + /* canonicalize */ false, + /* resolve_file_type */ false, + ) + .await + .unwrap(); + + #[allow(clippy::match_single_binding)] + match metadata { + Metadata { unix, windows, .. 
} => { + assert!( + windows.is_none(), + "Unexpectedly got windows metadata on windows (support added?)" + ); + + // NOTE: Still includes unix metadata + assert!( + unix.is_some(), + "Unexpectedly missing unix metadata from sshd (even on windows)" + ); + } + } +} + +#[rstest] +#[tokio::test] +async fn metadata_should_send_back_metadata_on_dir_if_exists(#[future] client: Ctx) { + let mut client = client.await; + let temp = assert_fs::TempDir::new().unwrap(); + let dir = temp.child("dir"); + dir.create_dir_all().unwrap(); + + let metadata = client + .metadata( + dir.path().to_path_buf(), + /* canonicalize */ false, + /* resolve_file_type */ false, + ) + .await + .unwrap(); + + assert!( + matches!( + metadata, + Metadata { + canonicalized_path: None, + file_type: FileType::Dir, + readonly: false, + .. + } + ), + "{:?}", + metadata + ); +} + +#[rstest] +#[tokio::test] +async fn metadata_should_send_back_metadata_on_symlink_if_exists( + #[future] client: Ctx, +) { + let mut client = client.await; + let temp = assert_fs::TempDir::new().unwrap(); + let file = temp.child("file"); + file.write_str("some text").unwrap(); + + let symlink = temp.child("link"); + symlink.symlink_to_file(file.path()).unwrap(); + + let metadata = client + .metadata( + symlink.path().to_path_buf(), + /* canonicalize */ false, + /* resolve_file_type */ false, + ) + .await + .unwrap(); + + assert!( + matches!( + metadata, + Metadata { + canonicalized_path: None, + file_type: FileType::Symlink, + readonly: false, + .. 
+ } + ), + "{:?}", + metadata + ); +} + +#[rstest] +#[tokio::test] +#[cfg_attr(windows, ignore)] +async fn metadata_should_include_canonicalized_path_if_flag_specified( + #[future] client: Ctx, +) { + let mut client = client.await; + let temp = assert_fs::TempDir::new().unwrap(); + let file = temp.child("file"); + file.write_str("some text").unwrap(); + + let symlink = temp.child("link"); + symlink.symlink_to_file(file.path()).unwrap(); + + let metadata = client + .metadata( + symlink.path().to_path_buf(), + /* canonicalize */ true, + /* resolve_file_type */ false, + ) + .await + .unwrap(); + + // NOTE: This is failing on windows as the symlink does not get resolved! + match metadata { + Metadata { + canonicalized_path: Some(path), + file_type: FileType::Symlink, + readonly: false, + .. + } => assert_eq!( + path, + dunce::canonicalize(file.path()).unwrap(), + "Symlink canonicalized path does not match referenced file" + ), + x => panic!("Unexpected response: {:?}", x), + } +} + +#[rstest] +#[tokio::test] +async fn metadata_should_resolve_file_type_of_symlink_if_flag_specified( + #[future] client: Ctx, +) { + let mut client = client.await; + let temp = assert_fs::TempDir::new().unwrap(); + let file = temp.child("file"); + file.write_str("some text").unwrap(); + + let symlink = temp.child("link"); + symlink.symlink_to_file(file.path()).unwrap(); + + let metadata = client + .metadata( + symlink.path().to_path_buf(), + /* canonicalize */ false, + /* resolve_file_type */ true, + ) + .await + .unwrap(); + + assert!( + matches!( + metadata, + Metadata { + file_type: FileType::File, + .. 
+ } + ), + "{:?}", + metadata + ); +} + +#[rstest] +#[tokio::test] +async fn proc_spawn_should_not_fail_even_if_process_not_found( + #[future] client: Ctx, +) { + let mut client = client.await; + + // NOTE: This is a distinction from standard distant and ssh distant + let _ = client + .spawn( + /* cmd */ DOES_NOT_EXIST_BIN.to_str().unwrap().to_string(), + /* environment */ Environment::new(), + /* current_dir */ None, + /* persist */ false, + /* pty */ None, + ) + .await + .unwrap(); +} + +#[rstest] +#[tokio::test] +async fn proc_spawn_should_return_id_of_spawned_process(#[future] client: Ctx) { + let mut client = client.await; + + let proc = client + .spawn( + /* cmd */ + format!( + "{} {}", + *SCRIPT_RUNNER, + ECHO_ARGS_TO_STDOUT_SH.to_str().unwrap() + ), + /* environment */ Environment::new(), + /* current_dir */ None, + /* persist */ false, + /* pty */ None, + ) + .await + .unwrap(); + assert!(proc.id() > 0); +} + +// NOTE: Ignoring on windows because it's using WSL which wants a Linux path +// with / but thinks it's on windows and is providing \ +#[rstest] +#[tokio::test] +#[cfg_attr(windows, ignore)] +async fn proc_spawn_should_send_back_stdout_periodically_when_available( + #[future] client: Ctx, +) { + let mut client = client.await; + + let mut proc = client + .spawn( + /* cmd */ + format!( + "{} {} some stdout", + *SCRIPT_RUNNER, + ECHO_ARGS_TO_STDOUT_SH.to_str().unwrap() + ), + /* environment */ Environment::new(), + /* current_dir */ None, + /* persist */ false, + /* pty */ None, + ) + .await + .unwrap(); + + assert_eq!( + proc.stdout.as_mut().unwrap().read().await.unwrap(), + b"some stdout" + ); + assert!( + proc.wait().await.unwrap().success, + "Process should have completed successfully" + ); +} + +// NOTE: Ignoring on windows because it's using WSL which wants a Linux path +// with / but thinks it's on windows and is providing \ +#[rstest] +#[tokio::test] +#[cfg_attr(windows, ignore)] +async fn 
proc_spawn_should_send_back_stderr_periodically_when_available( + #[future] client: Ctx, +) { + let mut client = client.await; + + let mut proc = client + .spawn( + /* cmd */ + format!( + "{} {} some stderr", + *SCRIPT_RUNNER, + ECHO_ARGS_TO_STDERR_SH.to_str().unwrap() + ), + /* environment */ Environment::new(), + /* current_dir */ None, + /* persist */ false, + /* pty */ None, + ) + .await + .unwrap(); + + assert_eq!( + proc.stderr.as_mut().unwrap().read().await.unwrap(), + b"some stderr" + ); + assert!( + proc.wait().await.unwrap().success, + "Process should have completed successfully" + ); +} + +// NOTE: Ignoring on windows because it's using WSL which wants a Linux path +// with / but thinks it's on windows and is providing \ +#[rstest] +#[tokio::test] +#[cfg_attr(windows, ignore)] +async fn proc_spawn_should_send_done_signal_when_completed(#[future] client: Ctx) { + let mut client = client.await; + + let proc = client + .spawn( + /* cmd */ + format!("{} {} 0.1", *SCRIPT_RUNNER, SLEEP_SH.to_str().unwrap()), + /* environment */ Environment::new(), + /* current_dir */ None, + /* persist */ false, + /* pty */ None, + ) + .await + .unwrap(); + + let _ = proc.wait().await.unwrap(); +} + +#[rstest] +#[tokio::test] +async fn proc_spawn_should_clear_process_from_state_when_killed( + #[future] client: Ctx, +) { + let mut client = client.await; + + let mut proc = client + .spawn( + /* cmd */ + format!("{} {} 1", *SCRIPT_RUNNER, SLEEP_SH.to_str().unwrap()), + /* environment */ Environment::new(), + /* current_dir */ None, + /* persist */ false, + /* pty */ None, + ) + .await + .unwrap(); + + // Send kill signal + proc.kill().await.unwrap(); + + // Verify killed, which should be success false + let status = proc.wait().await.unwrap(); + assert!(!status.success, "Process succeeded when killed") +} + +#[rstest] +#[tokio::test] +async fn proc_kill_should_fail_if_process_not_running(#[future] client: Ctx) { + let mut client = client.await; + + let mut proc = client + .spawn( 
+ /* cmd */ + format!("{} {} 1", *SCRIPT_RUNNER, SLEEP_SH.to_str().unwrap()), + /* environment */ Environment::new(), + /* current_dir */ None, + /* persist */ false, + /* pty */ None, + ) + .await + .unwrap(); + + // Send kill signal + proc.kill().await.unwrap(); + + // Wait for process to be dead + let mut killer = proc.clone_killer(); + let _ = proc.wait().await.unwrap(); + + // Now send it again, which should fail + let _ = killer.kill().await.unwrap_err(); +} + +#[rstest] +#[tokio::test] +async fn proc_stdin_should_fail_if_process_not_running(#[future] client: Ctx) { + let mut client = client.await; + + let mut proc = client + .spawn( + /* cmd */ + format!("{} {} 1", *SCRIPT_RUNNER, SLEEP_SH.to_str().unwrap()), + /* environment */ Environment::new(), + /* current_dir */ None, + /* persist */ false, + /* pty */ None, + ) + .await + .unwrap(); + + // Send kill signal + proc.kill().await.unwrap(); + + // Wait for process to be dead + let mut stdin = proc.stdin.take().unwrap(); + let _ = proc.wait().await.unwrap(); + + // Now send stdin, which should fail + let _ = stdin.write_str("some data").await.unwrap_err(); +} + +// NOTE: Ignoring on windows because it's using WSL which wants a Linux path +// with / but thinks it's on windows and is providing \ +#[rstest] +#[tokio::test] +#[cfg_attr(windows, ignore)] +async fn proc_stdin_should_send_stdin_to_process(#[future] client: Ctx) { + let mut client = client.await; + + // First, run a program that listens for stdin + let mut proc = client + .spawn( + /* cmd */ + format!( + "{} {}", + *SCRIPT_RUNNER, + ECHO_STDIN_TO_STDOUT_SH.to_str().unwrap() + ), + /* environment */ Environment::new(), + /* current_dir */ None, + /* persist */ false, + /* pty */ None, + ) + .await + .unwrap(); + + // Second, send stdin to the remote process + proc.stdin + .as_mut() + .unwrap() + .write_str("hello world\n") + .await + .unwrap(); + + // Third, check the async response of stdout to verify we got stdin + assert_eq!( + 
proc.stdout.as_mut().unwrap().read_string().await.unwrap(), + "hello world\n" + ); +} + +#[rstest] +#[tokio::test] +async fn system_info_should_return_system_info_based_on_binary( + #[future] client: Ctx, +) { + let mut client = client.await; + + let system_info = client.system_info().await.unwrap(); + assert_eq!(system_info.family, std::env::consts::FAMILY.to_string()); + assert_eq!(system_info.os, ""); + assert_eq!(system_info.arch, ""); + assert_eq!(system_info.main_separator, std::path::MAIN_SEPARATOR); +} diff --git a/distant-ssh2/tests/ssh2/launched.rs b/distant-ssh2/tests/ssh2/launched.rs new file mode 100644 index 0000000..0cb7a89 --- /dev/null +++ b/distant-ssh2/tests/ssh2/launched.rs @@ -0,0 +1,1469 @@ +use crate::sshd::*; +use assert_fs::{prelude::*, TempDir}; +use distant_core::{ + data::{ChangeKindSet, Environment, FileType, Metadata}, + DistantChannelExt, DistantClient, +}; +use once_cell::sync::Lazy; +use predicates::prelude::*; +use rstest::*; +use std::{path::Path, time::Duration}; + +static TEMP_SCRIPT_DIR: Lazy = Lazy::new(|| TempDir::new().unwrap()); +static SCRIPT_RUNNER: Lazy = Lazy::new(|| String::from("bash")); + +static ECHO_ARGS_TO_STDOUT_SH: Lazy = Lazy::new(|| { + let script = TEMP_SCRIPT_DIR.child("echo_args_to_stdout.sh"); + script + .write_str(indoc::indoc!( + r#" + #!/usr/bin/env bash + printf "%s" "$*" + "# + )) + .unwrap(); + script +}); + +static ECHO_ARGS_TO_STDERR_SH: Lazy = Lazy::new(|| { + let script = TEMP_SCRIPT_DIR.child("echo_args_to_stderr.sh"); + script + .write_str(indoc::indoc!( + r#" + #!/usr/bin/env bash + printf "%s" "$*" 1>&2 + "# + )) + .unwrap(); + script +}); + +static ECHO_STDIN_TO_STDOUT_SH: Lazy = Lazy::new(|| { + let script = TEMP_SCRIPT_DIR.child("echo_stdin_to_stdout.sh"); + script + .write_str(indoc::indoc!( + r#" + #!/usr/bin/env bash + while IFS= read; do echo "$REPLY"; done + "# + )) + .unwrap(); + script +}); + +static SLEEP_SH: Lazy = Lazy::new(|| { + let script = TEMP_SCRIPT_DIR.child("sleep.sh"); + 
script + .write_str(indoc::indoc!( + r#" + #!/usr/bin/env bash + sleep "$1" + "# + )) + .unwrap(); + script +}); + +static DOES_NOT_EXIST_BIN: Lazy = + Lazy::new(|| TEMP_SCRIPT_DIR.child("does_not_exist_bin")); + +#[rstest] +#[tokio::test] +async fn read_file_should_fail_if_file_missing(#[future] launched_client: Ctx) { + let mut client = launched_client.await; + let temp = assert_fs::TempDir::new().unwrap(); + let path = temp.child("missing-file").path().to_path_buf(); + + let _ = client.read_file(path).await.unwrap_err(); +} + +#[rstest] +#[tokio::test] +async fn read_file_should_send_blob_with_file_contents( + #[future] launched_client: Ctx, +) { + let mut client = launched_client.await; + + let temp = assert_fs::TempDir::new().unwrap(); + let file = temp.child("test-file"); + file.write_str("some file contents").unwrap(); + + let bytes = client.read_file(file.path().to_path_buf()).await.unwrap(); + assert_eq!(bytes, b"some file contents"); +} + +#[rstest] +#[tokio::test] +async fn read_file_text_should_send_error_if_fails_to_read_file( + #[future] launched_client: Ctx, +) { + let mut client = launched_client.await; + + let temp = assert_fs::TempDir::new().unwrap(); + let path = temp.child("missing-file").path().to_path_buf(); + + let _ = client.read_file_text(path).await.unwrap_err(); +} + +#[rstest] +#[tokio::test] +async fn read_file_text_should_send_text_with_file_contents( + #[future] launched_client: Ctx, +) { + let mut client = launched_client.await; + + let temp = assert_fs::TempDir::new().unwrap(); + let file = temp.child("test-file"); + file.write_str("some file contents").unwrap(); + + let text = client + .read_file_text(file.path().to_path_buf()) + .await + .unwrap(); + assert_eq!(text, "some file contents"); +} + +#[rstest] +#[tokio::test] +async fn write_file_should_send_error_if_fails_to_write_file( + #[future] launched_client: Ctx, +) { + let mut client = launched_client.await; + + // Create a temporary path and add to it to ensure that there are 
+ // extra components that don't exist to cause writing to fail + let temp = assert_fs::TempDir::new().unwrap(); + let file = temp.child("dir").child("test-file"); + + let _ = client + .write_file(file.path().to_path_buf(), b"some text".to_vec()) + .await + .unwrap_err(); + + // Also verify that we didn't actually create the file + file.assert(predicate::path::missing()); +} + +#[rstest] +#[tokio::test] +async fn write_file_should_send_ok_when_successful(#[future] launched_client: Ctx) { + let mut client = launched_client.await; + + // Path should point to a file that does not exist, but all + // other components leading up to it do + let temp = assert_fs::TempDir::new().unwrap(); + let file = temp.child("test-file"); + + client + .write_file(file.path().to_path_buf(), b"some text".to_vec()) + .await + .unwrap(); + + // Also verify that we actually did create the file + // with the associated contents + file.assert("some text"); +} + +#[rstest] +#[tokio::test] +async fn write_file_text_should_send_error_if_fails_to_write_file( + #[future] launched_client: Ctx, +) { + let mut client = launched_client.await; + + // Create a temporary path and add to it to ensure that there are + // extra components that don't exist to cause writing to fail + let temp = assert_fs::TempDir::new().unwrap(); + let file = temp.child("dir").child("test-file"); + + let _ = client + .write_file_text(file.path().to_path_buf(), "some text".to_string()) + .await + .unwrap_err(); + + // Also verify that we didn't actually create the file + file.assert(predicate::path::missing()); +} + +#[rstest] +#[tokio::test] +async fn write_file_text_should_send_ok_when_successful( + #[future] launched_client: Ctx, +) { + let mut client = launched_client.await; + + // Path should point to a file that does not exist, but all + // other components leading up to it do + let temp = assert_fs::TempDir::new().unwrap(); + let file = temp.child("test-file"); + + client + .write_file_text(file.path().to_path_buf(), 
"some text".to_string()) + .await + .unwrap(); + + // Also verify that we actually did create the file + // with the associated contents + file.assert("some text"); +} + +#[rstest] +#[tokio::test] +async fn append_file_should_send_error_if_fails_to_create_file( + #[future] launched_client: Ctx, +) { + let mut client = launched_client.await; + + // Create a temporary path and add to it to ensure that there are + // extra components that don't exist to cause writing to fail + let temp = assert_fs::TempDir::new().unwrap(); + let file = temp.child("dir").child("test-file"); + + let _ = client + .append_file(file.path().to_path_buf(), b"some extra contents".to_vec()) + .await + .unwrap_err(); + + // Also verify that we didn't actually create the file + file.assert(predicate::path::missing()); +} + +#[rstest] +#[tokio::test] +async fn append_file_should_create_file_if_missing(#[future] launched_client: Ctx) { + let mut client = launched_client.await; + + // Don't create the file directly, but define path + // where the file should be + let temp = assert_fs::TempDir::new().unwrap(); + let file = temp.child("test-file"); + + client + .append_file(file.path().to_path_buf(), b"some extra contents".to_vec()) + .await + .unwrap(); + + // Yield to allow chance to finish appending to file + tokio::time::sleep(Duration::from_millis(50)).await; + + // Also verify that we actually did create to the file + file.assert("some extra contents"); +} + +#[rstest] +#[tokio::test] +async fn append_file_should_send_ok_when_successful(#[future] launched_client: Ctx) { + let mut client = launched_client.await; + + // Create a temporary file and fill it with some contents + let temp = assert_fs::TempDir::new().unwrap(); + let file = temp.child("test-file"); + file.write_str("some file contents").unwrap(); + + client + .append_file(file.path().to_path_buf(), b"some extra contents".to_vec()) + .await + .unwrap(); + + // Yield to allow chance to finish appending to file + 
tokio::time::sleep(Duration::from_millis(50)).await; + + // Also verify that we actually did append to the file + file.assert("some file contentssome extra contents"); +} + +#[rstest] +#[tokio::test] +async fn append_file_text_should_send_error_if_fails_to_create_file( + #[future] launched_client: Ctx, +) { + let mut client = launched_client.await; + + // Create a temporary path and add to it to ensure that there are + // extra components that don't exist to cause writing to fail + let temp = assert_fs::TempDir::new().unwrap(); + let file = temp.child("dir").child("test-file"); + + let _ = client + .append_file_text(file.path().to_path_buf(), "some extra contents".to_string()) + .await + .unwrap_err(); + + // Also verify that we didn't actually create the file + file.assert(predicate::path::missing()); +} + +#[rstest] +#[tokio::test] +async fn append_file_text_should_create_file_if_missing( + #[future] launched_client: Ctx, +) { + let mut client = launched_client.await; + + // Don't create the file directly, but define path + // where the file should be + let temp = assert_fs::TempDir::new().unwrap(); + let file = temp.child("test-file"); + + client + .append_file_text(file.path().to_path_buf(), "some extra contents".to_string()) + .await + .unwrap(); + + // Yield to allow chance to finish appending to file + tokio::time::sleep(Duration::from_millis(50)).await; + + // Also verify that we actually did create to the file + file.assert("some extra contents"); +} + +#[rstest] +#[tokio::test] +async fn append_file_text_should_send_ok_when_successful( + #[future] launched_client: Ctx, +) { + let mut client = launched_client.await; + + // Create a temporary file and fill it with some contents + let temp = assert_fs::TempDir::new().unwrap(); + let file = temp.child("test-file"); + file.write_str("some file contents").unwrap(); + + client + .append_file_text(file.path().to_path_buf(), "some extra contents".to_string()) + .await + .unwrap(); + + // Yield to allow chance to 
finish appending to file + tokio::time::sleep(Duration::from_millis(50)).await; + + // Also verify that we actually did append to the file + file.assert("some file contentssome extra contents"); +} + +#[rstest] +#[tokio::test] +async fn dir_read_should_send_error_if_directory_does_not_exist( + #[future] launched_client: Ctx, +) { + let mut client = launched_client.await; + + let temp = assert_fs::TempDir::new().unwrap(); + let dir = temp.child("test-dir"); + + let _ = client + .read_dir( + dir.path().to_path_buf(), + /* depth */ 0, + /* absolute */ false, + /* canonicalize */ false, + /* include_root */ false, + ) + .await + .unwrap_err(); +} + +// /root/ +// /root/file1 +// /root/link1 -> /root/sub1/file2 +// /root/sub1/ +// /root/sub1/file2 +async fn setup_dir() -> assert_fs::TempDir { + let root_dir = assert_fs::TempDir::new().unwrap(); + root_dir.child("file1").touch().unwrap(); + + let sub1 = root_dir.child("sub1"); + sub1.create_dir_all().unwrap(); + + let file2 = sub1.child("file2"); + file2.touch().unwrap(); + + let link1 = root_dir.child("link1"); + link1.symlink_to_file(file2.path()).unwrap(); + + root_dir +} + +#[rstest] +#[tokio::test] +async fn dir_read_should_support_depth_limits(#[future] launched_client: Ctx) { + let mut client = launched_client.await; + + // Create directory with some nested items + let root_dir = setup_dir().await; + + let (entries, _) = client + .read_dir( + root_dir.path().to_path_buf(), + /* depth */ 1, + /* absolute */ false, + /* canonicalize */ false, + /* include_root */ false, + ) + .await + .unwrap(); + + assert_eq!(entries.len(), 3, "Wrong number of entries found"); + + assert_eq!(entries[0].file_type, FileType::File); + assert_eq!(entries[0].path, Path::new("file1")); + assert_eq!(entries[0].depth, 1); + + assert_eq!(entries[1].file_type, FileType::Symlink); + assert_eq!(entries[1].path, Path::new("link1")); + assert_eq!(entries[1].depth, 1); + + assert_eq!(entries[2].file_type, FileType::Dir); + 
assert_eq!(entries[2].path, Path::new("sub1")); + assert_eq!(entries[2].depth, 1); +} + +#[rstest] +#[tokio::test] +async fn dir_read_should_support_unlimited_depth_using_zero( + #[future] launched_client: Ctx, +) { + let mut client = launched_client.await; + + // Create directory with some nested items + let root_dir = setup_dir().await; + + let (entries, _) = client + .read_dir( + root_dir.path().to_path_buf(), + /* depth */ 0, + /* absolute */ false, + /* canonicalize */ false, + /* include_root */ false, + ) + .await + .unwrap(); + + assert_eq!(entries.len(), 4, "Wrong number of entries found"); + + assert_eq!(entries[0].file_type, FileType::File); + assert_eq!(entries[0].path, Path::new("file1")); + assert_eq!(entries[0].depth, 1); + + assert_eq!(entries[1].file_type, FileType::Symlink); + assert_eq!(entries[1].path, Path::new("link1")); + assert_eq!(entries[1].depth, 1); + + assert_eq!(entries[2].file_type, FileType::Dir); + assert_eq!(entries[2].path, Path::new("sub1")); + assert_eq!(entries[2].depth, 1); + + assert_eq!(entries[3].file_type, FileType::File); + assert_eq!(entries[3].path, Path::new("sub1").join("file2")); + assert_eq!(entries[3].depth, 2); +} + +#[rstest] +#[tokio::test] +async fn dir_read_should_support_including_directory_in_returned_entries( + #[future] launched_client: Ctx, +) { + let mut client = launched_client.await; + + // Create directory with some nested items + let root_dir = setup_dir().await; + + let (entries, _) = client + .read_dir( + root_dir.path().to_path_buf(), + /* depth */ 1, + /* absolute */ false, + /* canonicalize */ false, + /* include_root */ true, + ) + .await + .unwrap(); + + assert_eq!(entries.len(), 4, "Wrong number of entries found"); + + // NOTE: Root entry is always absolute, resolved path + assert_eq!(entries[0].file_type, FileType::Dir); + assert_eq!( + entries[0].path, + dunce::canonicalize(root_dir.path()).unwrap() + ); + assert_eq!(entries[0].depth, 0); + + assert_eq!(entries[1].file_type, 
FileType::File); + assert_eq!(entries[1].path, Path::new("file1")); + assert_eq!(entries[1].depth, 1); + + assert_eq!(entries[2].file_type, FileType::Symlink); + assert_eq!(entries[2].path, Path::new("link1")); + assert_eq!(entries[2].depth, 1); + + assert_eq!(entries[3].file_type, FileType::Dir); + assert_eq!(entries[3].path, Path::new("sub1")); + assert_eq!(entries[3].depth, 1); +} + +#[rstest] +#[tokio::test] +async fn dir_read_should_support_returning_absolute_paths( + #[future] launched_client: Ctx, +) { + let mut client = launched_client.await; + + // Create directory with some nested items + let root_dir = setup_dir().await; + + let (entries, _) = client + .read_dir( + root_dir.path().to_path_buf(), + /* depth */ 1, + /* absolute */ true, + /* canonicalize */ false, + /* include_root */ false, + ) + .await + .unwrap(); + + assert_eq!(entries.len(), 3, "Wrong number of entries found"); + let root_path = dunce::canonicalize(root_dir.path()).unwrap(); + + assert_eq!(entries[0].file_type, FileType::File); + assert_eq!(entries[0].path, root_path.join("file1")); + assert_eq!(entries[0].depth, 1); + + assert_eq!(entries[1].file_type, FileType::Symlink); + assert_eq!(entries[1].path, root_path.join("link1")); + assert_eq!(entries[1].depth, 1); + + assert_eq!(entries[2].file_type, FileType::Dir); + assert_eq!(entries[2].path, root_path.join("sub1")); + assert_eq!(entries[2].depth, 1); +} + +#[rstest] +#[tokio::test] +async fn dir_read_should_support_returning_canonicalized_paths( + #[future] launched_client: Ctx, +) { + let mut client = launched_client.await; + + // Create directory with some nested items + let root_dir = setup_dir().await; + + let (entries, _) = client + .read_dir( + root_dir.path().to_path_buf(), + /* depth */ 1, + /* absolute */ false, + /* canonicalize */ true, + /* include_root */ false, + ) + .await + .unwrap(); + + assert_eq!(entries.len(), 3, "Wrong number of entries found"); + println!("{:?}", entries); + + assert_eq!(entries[0].file_type, 
FileType::File); + assert_eq!(entries[0].path, Path::new("file1")); + assert_eq!(entries[0].depth, 1); + + // Symlink should be resolved from $ROOT/link1 -> $ROOT/sub1/file2 + assert_eq!(entries[1].file_type, FileType::Symlink); + assert_eq!(entries[1].path, Path::new("sub1").join("file2")); + assert_eq!(entries[1].depth, 1); + + assert_eq!(entries[2].file_type, FileType::Dir); + assert_eq!(entries[2].path, Path::new("sub1")); + assert_eq!(entries[2].depth, 1); +} + +#[rstest] +#[tokio::test] +async fn create_dir_should_send_error_if_fails(#[future] launched_client: Ctx) { + let mut client = launched_client.await; + + // Make a path that has multiple non-existent components + // so the creation will fail + let root_dir = setup_dir().await; + let path = root_dir.path().join("nested").join("new-dir"); + + let _ = client + .create_dir(path.to_path_buf(), /* all */ false) + .await + .unwrap_err(); + + // Also verify that the directory was not actually created + assert!(!path.exists(), "Path unexpectedly exists"); +} + +#[rstest] +#[tokio::test] +async fn create_dir_should_send_ok_when_successful(#[future] launched_client: Ctx) { + let mut client = launched_client.await; + let root_dir = setup_dir().await; + let path = root_dir.path().join("new-dir"); + + client + .create_dir(path.to_path_buf(), /* all */ false) + .await + .unwrap(); + + // Also verify that the directory was actually created + assert!(path.exists(), "Directory not created"); +} + +#[rstest] +#[tokio::test] +async fn create_dir_should_support_creating_multiple_dir_components( + #[future] launched_client: Ctx, +) { + let mut client = launched_client.await; + let root_dir = setup_dir().await; + let path = root_dir.path().join("nested").join("new-dir"); + + client + .create_dir(path.to_path_buf(), /* all */ true) + .await + .unwrap(); + + // Also verify that the directory was actually created + assert!(path.exists(), "Directory not created"); +} + +#[rstest] +#[tokio::test] +async fn 
remove_should_send_error_on_failure(#[future] launched_client: Ctx) { + let mut client = launched_client.await; + let temp = assert_fs::TempDir::new().unwrap(); + let file = temp.child("missing-file"); + + let _ = client + .remove(file.path().to_path_buf(), /* force */ false) + .await + .unwrap_err(); + + // Also, verify that path does not exist + file.assert(predicate::path::missing()); +} + +#[rstest] +#[tokio::test] +async fn remove_should_support_deleting_a_directory(#[future] launched_client: Ctx) { + let mut client = launched_client.await; + let temp = assert_fs::TempDir::new().unwrap(); + let dir = temp.child("dir"); + dir.create_dir_all().unwrap(); + + client + .remove(dir.path().to_path_buf(), /* force */ false) + .await + .unwrap(); + + // Also, verify that path does not exist + dir.assert(predicate::path::missing()); +} + +#[rstest] +#[tokio::test] +async fn remove_should_delete_nonempty_directory_if_force_is_true( + #[future] launched_client: Ctx, +) { + let mut client = launched_client.await; + let temp = assert_fs::TempDir::new().unwrap(); + let dir = temp.child("dir"); + dir.create_dir_all().unwrap(); + dir.child("file").touch().unwrap(); + + client + .remove(dir.path().to_path_buf(), /* force */ true) + .await + .unwrap(); + + // Also, verify that path does not exist + dir.assert(predicate::path::missing()); +} + +#[rstest] +#[tokio::test] +async fn remove_should_support_deleting_a_single_file( + #[future] launched_client: Ctx, +) { + let mut client = launched_client.await; + let temp = assert_fs::TempDir::new().unwrap(); + let file = temp.child("some-file"); + file.touch().unwrap(); + + client + .remove(file.path().to_path_buf(), /* force */ false) + .await + .unwrap(); + + // Also, verify that path does not exist + file.assert(predicate::path::missing()); +} + +#[rstest] +#[tokio::test] +async fn copy_should_send_error_on_failure(#[future] launched_client: Ctx) { + let mut client = launched_client.await; + let temp = 
assert_fs::TempDir::new().unwrap(); + let src = temp.child("src"); + let dst = temp.child("dst"); + + let _ = client + .copy(src.path().to_path_buf(), dst.path().to_path_buf()) + .await + .unwrap_err(); + + // Also, verify that destination does not exist + dst.assert(predicate::path::missing()); +} + +#[rstest] +#[tokio::test] +async fn copy_should_support_copying_an_entire_directory( + #[future] launched_client: Ctx, +) { + let mut client = launched_client.await; + let temp = assert_fs::TempDir::new().unwrap(); + + let src = temp.child("src"); + src.create_dir_all().unwrap(); + let src_file = src.child("file"); + src_file.write_str("some contents").unwrap(); + + let dst = temp.child("dst"); + let dst_file = dst.child("file"); + + client + .copy(src.path().to_path_buf(), dst.path().to_path_buf()) + .await + .unwrap(); + + // Verify that we have source and destination directories and associated contents + src.assert(predicate::path::is_dir()); + src_file.assert(predicate::path::is_file()); + dst.assert(predicate::path::is_dir()); + dst_file.assert(predicate::path::eq_file(src_file.path())); +} + +#[rstest] +#[tokio::test] +async fn copy_should_support_copying_an_empty_directory( + #[future] launched_client: Ctx, +) { + let mut client = launched_client.await; + let temp = assert_fs::TempDir::new().unwrap(); + let src = temp.child("src"); + src.create_dir_all().unwrap(); + let dst = temp.child("dst"); + + client + .copy(src.path().to_path_buf(), dst.path().to_path_buf()) + .await + .unwrap(); + + // Verify that we still have source and destination directories + src.assert(predicate::path::is_dir()); + dst.assert(predicate::path::is_dir()); +} + +#[rstest] +#[tokio::test] +async fn copy_should_support_copying_a_directory_that_only_contains_directories( + #[future] launched_client: Ctx, +) { + let mut client = launched_client.await; + let temp = assert_fs::TempDir::new().unwrap(); + + let src = temp.child("src"); + src.create_dir_all().unwrap(); + let src_dir = 
src.child("dir"); + src_dir.create_dir_all().unwrap(); + + let dst = temp.child("dst"); + let dst_dir = dst.child("dir"); + + client + .copy(src.path().to_path_buf(), dst.path().to_path_buf()) + .await + .unwrap(); + + // Verify that we have source and destination directories and associated contents + src.assert(predicate::path::is_dir().name("src")); + src_dir.assert(predicate::path::is_dir().name("src/dir")); + dst.assert(predicate::path::is_dir().name("dst")); + dst_dir.assert(predicate::path::is_dir().name("dst/dir")); +} + +#[rstest] +#[tokio::test] +async fn copy_should_support_copying_a_single_file(#[future] launched_client: Ctx) { + let mut client = launched_client.await; + let temp = assert_fs::TempDir::new().unwrap(); + let src = temp.child("src"); + src.write_str("some text").unwrap(); + let dst = temp.child("dst"); + + client + .copy(src.path().to_path_buf(), dst.path().to_path_buf()) + .await + .unwrap(); + + // Verify that we still have source and that destination has source's contents + src.assert(predicate::path::is_file()); + dst.assert(predicate::path::eq_file(src.path())); +} + +#[rstest] +#[tokio::test] +async fn rename_should_fail_if_path_missing(#[future] launched_client: Ctx) { + let mut client = launched_client.await; + let temp = assert_fs::TempDir::new().unwrap(); + let src = temp.child("src"); + let dst = temp.child("dst"); + + let _ = client + .rename(src.path().to_path_buf(), dst.path().to_path_buf()) + .await + .unwrap_err(); + + // Also, verify that destination does not exist + dst.assert(predicate::path::missing()); +} + +#[rstest] +#[tokio::test] +async fn rename_should_support_renaming_an_entire_directory( + #[future] launched_client: Ctx, +) { + let mut client = launched_client.await; + let temp = assert_fs::TempDir::new().unwrap(); + + let src = temp.child("src"); + src.create_dir_all().unwrap(); + let src_file = src.child("file"); + src_file.write_str("some contents").unwrap(); + + let dst = temp.child("dst"); + let dst_file = 
dst.child("file"); + + client + .rename(src.path().to_path_buf(), dst.path().to_path_buf()) + .await + .unwrap(); + + // Verify that we moved the contents + src.assert(predicate::path::missing()); + src_file.assert(predicate::path::missing()); + dst.assert(predicate::path::is_dir()); + dst_file.assert("some contents"); +} + +#[rstest] +#[tokio::test] +async fn rename_should_support_renaming_a_single_file( + #[future] launched_client: Ctx, +) { + let mut client = launched_client.await; + let temp = assert_fs::TempDir::new().unwrap(); + let src = temp.child("src"); + src.write_str("some text").unwrap(); + let dst = temp.child("dst"); + + client + .rename(src.path().to_path_buf(), dst.path().to_path_buf()) + .await + .unwrap(); + + // Verify that we moved the file + src.assert(predicate::path::missing()); + dst.assert("some text"); +} + +#[rstest] +#[tokio::test] +async fn watch_should_succeed(#[future] launched_client: Ctx) { + // NOTE: Supporting multiple replies being sent back as part of creating, modifying, etc. 
+ let mut client = launched_client.await; + let temp = assert_fs::TempDir::new().unwrap(); + + let file = temp.child("file"); + file.touch().unwrap(); + + let _ = client + .watch( + file.path().to_path_buf(), + /* recursive */ false, + /* only */ ChangeKindSet::default(), + /* except */ ChangeKindSet::default(), + ) + .await + .unwrap(); +} + +#[rstest] +#[tokio::test] +async fn exists_should_send_true_if_path_exists(#[future] launched_client: Ctx) { + let mut client = launched_client.await; + let temp = assert_fs::TempDir::new().unwrap(); + let file = temp.child("file"); + file.touch().unwrap(); + + let exists = client.exists(file.path().to_path_buf()).await.unwrap(); + assert!(exists, "Expected exists to be true, but was false"); +} + +#[rstest] +#[tokio::test] +async fn exists_should_send_false_if_path_does_not_exist( + #[future] launched_client: Ctx, +) { + let mut client = launched_client.await; + let temp = assert_fs::TempDir::new().unwrap(); + let file = temp.child("file"); + + let exists = client.exists(file.path().to_path_buf()).await.unwrap(); + assert!(!exists, "Expected exists to be false, but was true"); +} + +#[rstest] +#[tokio::test] +async fn metadata_should_send_error_on_failure(#[future] launched_client: Ctx) { + let mut client = launched_client.await; + let temp = assert_fs::TempDir::new().unwrap(); + let file = temp.child("file"); + + let _ = client + .metadata( + file.path().to_path_buf(), + /* canonicalize */ false, + /* resolve_file_type */ false, + ) + .await + .unwrap_err(); +} + +#[rstest] +#[tokio::test] +async fn metadata_should_send_back_metadata_on_file_if_exists( + #[future] launched_client: Ctx, +) { + let mut client = launched_client.await; + let temp = assert_fs::TempDir::new().unwrap(); + let file = temp.child("file"); + file.write_str("some text").unwrap(); + + let metadata = client + .metadata( + file.path().to_path_buf(), + /* canonicalize */ false, + /* resolve_file_type */ false, + ) + .await + .unwrap(); + + assert!( + 
matches!( + metadata, + Metadata { + canonicalized_path: None, + file_type: FileType::File, + len: 9, + readonly: false, + .. + } + ), + "{:?}", + metadata + ); +} + +#[cfg(unix)] +#[rstest] +#[tokio::test] +async fn metadata_should_include_unix_specific_metadata_on_unix_platform( + #[future] launched_client: Ctx, +) { + let mut client = launched_client.await; + let temp = assert_fs::TempDir::new().unwrap(); + let file = temp.child("file"); + file.write_str("some text").unwrap(); + + let metadata = client + .metadata( + file.path().to_path_buf(), + /* canonicalize */ false, + /* resolve_file_type */ false, + ) + .await + .unwrap(); + + #[allow(clippy::match_single_binding)] + match metadata { + Metadata { unix, windows, .. } => { + assert!(unix.is_some(), "Unexpectedly missing unix metadata on unix"); + assert!( + windows.is_none(), + "Unexpectedly got windows metadata on unix" + ); + } + } +} + +#[cfg(windows)] +#[rstest] +#[tokio::test] +async fn metadata_should_include_windows_specific_metadata_on_windows_platform( + #[future] launched_client: Ctx, +) { + let mut client = launched_client.await; + let temp = assert_fs::TempDir::new().unwrap(); + let file = temp.child("file"); + file.write_str("some text").unwrap(); + + let metadata = client + .metadata( + file.path().to_path_buf(), + /* canonicalize */ false, + /* resolve_file_type */ false, + ) + .await + .unwrap(); + + #[allow(clippy::match_single_binding)] + match metadata { + Metadata { unix, windows, .. 
} => { + assert!( + windows.is_some(), + "Unexpectedly missing windows metadata on windows" + ); + assert!(unix.is_none(), "Unexpectedly got unix metadata on windows"); + } + } +} + +#[rstest] +#[tokio::test] +async fn metadata_should_send_back_metadata_on_dir_if_exists( + #[future] launched_client: Ctx, +) { + let mut client = launched_client.await; + let temp = assert_fs::TempDir::new().unwrap(); + let dir = temp.child("dir"); + dir.create_dir_all().unwrap(); + + let metadata = client + .metadata( + dir.path().to_path_buf(), + /* canonicalize */ false, + /* resolve_file_type */ false, + ) + .await + .unwrap(); + + assert!( + matches!( + metadata, + Metadata { + canonicalized_path: None, + file_type: FileType::Dir, + readonly: false, + .. + } + ), + "{:?}", + metadata + ); +} + +#[rstest] +#[tokio::test] +async fn metadata_should_send_back_metadata_on_symlink_if_exists( + #[future] launched_client: Ctx, +) { + let mut client = launched_client.await; + let temp = assert_fs::TempDir::new().unwrap(); + let file = temp.child("file"); + file.write_str("some text").unwrap(); + + let symlink = temp.child("link"); + symlink.symlink_to_file(file.path()).unwrap(); + + let metadata = client + .metadata( + symlink.path().to_path_buf(), + /* canonicalize */ false, + /* resolve_file_type */ false, + ) + .await + .unwrap(); + + assert!( + matches!( + metadata, + Metadata { + canonicalized_path: None, + file_type: FileType::Symlink, + readonly: false, + .. 
+ } + ), + "{:?}", + metadata + ); +} + +#[rstest] +#[tokio::test] +async fn metadata_should_include_canonicalized_path_if_flag_specified( + #[future] launched_client: Ctx, +) { + let mut client = launched_client.await; + let temp = assert_fs::TempDir::new().unwrap(); + let file = temp.child("file"); + file.write_str("some text").unwrap(); + + let symlink = temp.child("link"); + symlink.symlink_to_file(file.path()).unwrap(); + + let metadata = client + .metadata( + symlink.path().to_path_buf(), + /* canonicalize */ true, + /* resolve_file_type */ false, + ) + .await + .unwrap(); + + match metadata { + Metadata { + canonicalized_path: Some(path), + file_type: FileType::Symlink, + readonly: false, + .. + } => assert_eq!( + path, + dunce::canonicalize(file.path()).unwrap(), + "Symlink canonicalized path does not match referenced file" + ), + x => panic!("Unexpected response: {:?}", x), + } +} + +#[rstest] +#[tokio::test] +async fn metadata_should_resolve_file_type_of_symlink_if_flag_specified( + #[future] launched_client: Ctx, +) { + let mut client = launched_client.await; + let temp = assert_fs::TempDir::new().unwrap(); + let file = temp.child("file"); + file.write_str("some text").unwrap(); + + let symlink = temp.child("link"); + symlink.symlink_to_file(file.path()).unwrap(); + + let metadata = client + .metadata( + symlink.path().to_path_buf(), + /* canonicalize */ false, + /* resolve_file_type */ true, + ) + .await + .unwrap(); + + assert!( + matches!( + metadata, + Metadata { + file_type: FileType::File, + .. 
+ } + ), + "{:?}", + metadata + ); +} + +#[rstest] +#[tokio::test] +async fn proc_spawn_should_fail_if_process_not_found( + #[future] launched_client: Ctx, +) { + let mut client = launched_client.await; + + let _ = client + .spawn( + /* cmd */ DOES_NOT_EXIST_BIN.to_str().unwrap().to_string(), + /* environment */ Environment::new(), + /* current_dir */ None, + /* persist */ false, + /* pty */ None, + ) + .await + .unwrap_err(); +} + +#[rstest] +#[tokio::test] +async fn proc_spawn_should_return_id_of_spawned_process( + #[future] launched_client: Ctx, +) { + let mut client = launched_client.await; + + let proc = client + .spawn( + /* cmd */ + format!( + "{} {}", + *SCRIPT_RUNNER, + ECHO_ARGS_TO_STDOUT_SH.to_str().unwrap() + ), + /* environment */ Environment::new(), + /* current_dir */ None, + /* persist */ false, + /* pty */ None, + ) + .await + .unwrap(); + assert!(proc.id() > 0); +} + +// NOTE: Ignoring on windows because it's using WSL which wants a Linux path +// with / but thinks it's on windows and is providing \ +#[rstest] +#[tokio::test] +#[cfg_attr(windows, ignore)] +async fn proc_spawn_should_send_back_stdout_periodically_when_available( + #[future] launched_client: Ctx, +) { + let mut client = launched_client.await; + + let mut proc = client + .spawn( + /* cmd */ + format!( + "{} {} some stdout", + *SCRIPT_RUNNER, + ECHO_ARGS_TO_STDOUT_SH.to_str().unwrap() + ), + /* environment */ Environment::new(), + /* current_dir */ None, + /* persist */ false, + /* pty */ None, + ) + .await + .unwrap(); + + assert_eq!( + proc.stdout.as_mut().unwrap().read().await.unwrap(), + b"some stdout" + ); + assert!( + proc.wait().await.unwrap().success, + "Process should have completed successfully" + ); +} + +// NOTE: Ignoring on windows because it's using WSL which wants a Linux path +// with / but thinks it's on windows and is providing \ +#[rstest] +#[tokio::test] +#[cfg_attr(windows, ignore)] +async fn proc_spawn_should_send_back_stderr_periodically_when_available( + 
#[future] launched_client: Ctx, +) { + let mut client = launched_client.await; + + let mut proc = client + .spawn( + /* cmd */ + format!( + "{} {} some stderr", + *SCRIPT_RUNNER, + ECHO_ARGS_TO_STDERR_SH.to_str().unwrap() + ), + /* environment */ Environment::new(), + /* current_dir */ None, + /* persist */ false, + /* pty */ None, + ) + .await + .unwrap(); + + assert_eq!( + proc.stderr.as_mut().unwrap().read().await.unwrap(), + b"some stderr" + ); + assert!( + proc.wait().await.unwrap().success, + "Process should have completed successfully" + ); +} + +// NOTE: Ignoring on windows because it's using WSL which wants a Linux path +// with / but thinks it's on windows and is providing \ +#[rstest] +#[tokio::test] +#[cfg_attr(windows, ignore)] +async fn proc_spawn_should_send_done_signal_when_completed( + #[future] launched_client: Ctx, +) { + let mut client = launched_client.await; + + let proc = client + .spawn( + /* cmd */ + format!("{} {} 0.1", *SCRIPT_RUNNER, SLEEP_SH.to_str().unwrap()), + /* environment */ Environment::new(), + /* current_dir */ None, + /* persist */ false, + /* pty */ None, + ) + .await + .unwrap(); + + let _ = proc.wait().await.unwrap(); +} + +#[rstest] +#[tokio::test] +async fn proc_spawn_should_clear_process_from_state_when_killed( + #[future] launched_client: Ctx, +) { + let mut client = launched_client.await; + + let mut proc = client + .spawn( + /* cmd */ + format!("{} {} 1", *SCRIPT_RUNNER, SLEEP_SH.to_str().unwrap()), + /* environment */ Environment::new(), + /* current_dir */ None, + /* persist */ false, + /* pty */ None, + ) + .await + .unwrap(); + + // Send kill signal + proc.kill().await.unwrap(); + + // Verify killed, which should be success false + let status = proc.wait().await.unwrap(); + assert!(!status.success, "Process succeeded when killed") +} + +#[rstest] +#[tokio::test] +async fn proc_kill_should_fail_if_process_not_running( + #[future] launched_client: Ctx, +) { + let mut client = launched_client.await; + + let mut proc 
= client + .spawn( + /* cmd */ + format!("{} {} 1", *SCRIPT_RUNNER, SLEEP_SH.to_str().unwrap()), + /* environment */ Environment::new(), + /* current_dir */ None, + /* persist */ false, + /* pty */ None, + ) + .await + .unwrap(); + + // Send kill signal + proc.kill().await.unwrap(); + + // Wait for process to be dead + let mut killer = proc.clone_killer(); + let _ = proc.wait().await.unwrap(); + + // Now send it again, which should fail + let _ = killer.kill().await.unwrap_err(); +} + +#[rstest] +#[tokio::test] +async fn proc_stdin_should_fail_if_process_not_running( + #[future] launched_client: Ctx, +) { + let mut client = launched_client.await; + + let mut proc = client + .spawn( + /* cmd */ + format!("{} {} 1", *SCRIPT_RUNNER, SLEEP_SH.to_str().unwrap()), + /* environment */ Environment::new(), + /* current_dir */ None, + /* persist */ false, + /* pty */ None, + ) + .await + .unwrap(); + + // Send kill signal + proc.kill().await.unwrap(); + + // Wait for process to be dead + let mut stdin = proc.stdin.take().unwrap(); + let _ = proc.wait().await.unwrap(); + + // Now send stdin, which should fail + let _ = stdin.write_str("some data").await.unwrap_err(); +} + +// NOTE: Ignoring on windows because it's using WSL which wants a Linux path +// with / but thinks it's on windows and is providing \ +#[rstest] +#[tokio::test] +#[cfg_attr(windows, ignore)] +async fn proc_stdin_should_send_stdin_to_process(#[future] launched_client: Ctx) { + let mut client = launched_client.await; + + // First, run a program that listens for stdin + let mut proc = client + .spawn( + /* cmd */ + format!( + "{} {}", + *SCRIPT_RUNNER, + ECHO_STDIN_TO_STDOUT_SH.to_str().unwrap() + ), + /* environment */ Environment::new(), + /* current_dir */ None, + /* persist */ false, + /* pty */ None, + ) + .await + .unwrap(); + + // Second, send stdin to the remote process + proc.stdin + .as_mut() + .unwrap() + .write_str("hello world\n") + .await + .unwrap(); + + // Third, check the async response of 
stdout to verify we got stdin + assert_eq!( + proc.stdout.as_mut().unwrap().read_string().await.unwrap(), + "hello world\n" + ); +} + +#[rstest] +#[tokio::test] +async fn system_info_should_return_system_info_based_on_binary( + #[future] launched_client: Ctx, +) { + let mut client = launched_client.await; + + let system_info = client.system_info().await.unwrap(); + assert_eq!(system_info.family, std::env::consts::FAMILY.to_string()); + assert_eq!(system_info.os, std::env::consts::OS.to_string()); + assert_eq!(system_info.arch, std::env::consts::ARCH.to_string()); + assert_eq!(system_info.main_separator, std::path::MAIN_SEPARATOR); +} diff --git a/distant-ssh2/tests/ssh2/mod.rs b/distant-ssh2/tests/ssh2/mod.rs index 7f33ec2..85a10d8 100644 --- a/distant-ssh2/tests/ssh2/mod.rs +++ b/distant-ssh2/tests/ssh2/mod.rs @@ -1 +1,3 @@ -mod session; +mod client; +mod launched; +mod ssh; diff --git a/distant-ssh2/tests/ssh2/session.rs b/distant-ssh2/tests/ssh2/session.rs deleted file mode 100644 index a4f2873..0000000 --- a/distant-ssh2/tests/ssh2/session.rs +++ /dev/null @@ -1,1943 +0,0 @@ -use crate::sshd::*; -use assert_fs::{prelude::*, TempDir}; -use distant_core::{ - FileType, Metadata, Request, RequestData, Response, ResponseData, RunningProcess, Session, - SystemInfo, -}; -use once_cell::sync::Lazy; -use predicates::prelude::*; -use rstest::*; -use std::{ - env, - path::{Path, PathBuf}, - time::Duration, -}; - -static TEMP_SCRIPT_DIR: Lazy = Lazy::new(|| TempDir::new().unwrap()); -static SCRIPT_RUNNER: Lazy = Lazy::new(|| String::from("bash")); - -static ECHO_ARGS_TO_STDOUT_SH: Lazy = Lazy::new(|| { - let script = TEMP_SCRIPT_DIR.child("echo_args_to_stdout.sh"); - script - .write_str(indoc::indoc!( - r#" - #/usr/bin/env bash - printf "%s" "$*" - "# - )) - .unwrap(); - script -}); - -static ECHO_ARGS_TO_STDERR_SH: Lazy = Lazy::new(|| { - let script = TEMP_SCRIPT_DIR.child("echo_args_to_stderr.sh"); - script - .write_str(indoc::indoc!( - r#" - #/usr/bin/env bash - printf 
"%s" "$*" 1>&2 - "# - )) - .unwrap(); - script -}); - -static ECHO_STDIN_TO_STDOUT_SH: Lazy = Lazy::new(|| { - let script = TEMP_SCRIPT_DIR.child("echo_stdin_to_stdout.sh"); - script - .write_str(indoc::indoc!( - r#" - #/usr/bin/env bash - while IFS= read; do echo "$REPLY"; done - "# - )) - .unwrap(); - script -}); - -static SLEEP_SH: Lazy = Lazy::new(|| { - let script = TEMP_SCRIPT_DIR.child("sleep.sh"); - script - .write_str(indoc::indoc!( - r#" - #!/usr/bin/env bash - sleep "$1" - "# - )) - .unwrap(); - script -}); - -static DOES_NOT_EXIST_BIN: Lazy = - Lazy::new(|| TEMP_SCRIPT_DIR.child("does_not_exist_bin")); - -#[rstest] -#[tokio::test] -async fn file_read_should_send_error_if_fails_to_read_file(#[future] session: Session) { - let mut session = session.await; - let temp = TempDir::new().unwrap(); - let path = temp.child("missing-file").path().to_path_buf(); - let req = Request::new("test-tenant", vec![RequestData::FileRead { path }]); - let res = session.send(req).await.unwrap(); - - assert_eq!(res.payload.len(), 1, "Wrong payload size"); - assert!( - matches!(res.payload[0], ResponseData::Error(_)), - "Unexpected response: {:?}", - res.payload[0] - ); -} - -#[rstest] -#[tokio::test] -async fn file_read_should_send_blob_with_file_contents(#[future] session: Session) { - let mut session = session.await; - let temp = TempDir::new().unwrap(); - let file = temp.child("test-file"); - file.write_str("some file contents").unwrap(); - - let req = Request::new( - "test-tenant", - vec![RequestData::FileRead { - path: file.path().to_path_buf(), - }], - ); - let res = session.send(req).await.unwrap(); - - assert_eq!(res.payload.len(), 1, "Wrong payload size"); - match &res.payload[0] { - ResponseData::Blob { data } => assert_eq!(data, b"some file contents"), - x => panic!("Unexpected response: {:?}", x), - } -} - -#[rstest] -#[tokio::test] -async fn file_read_text_should_send_error_if_fails_to_read_file(#[future] session: Session) { - let mut session = session.await; - 
let temp = TempDir::new().unwrap(); - let path = temp.child("missing-file").path().to_path_buf(); - let req = Request::new("test-tenant", vec![RequestData::FileReadText { path }]); - let res = session.send(req).await.unwrap(); - - assert_eq!(res.payload.len(), 1, "Wrong payload size"); - assert!( - matches!(res.payload[0], ResponseData::Error(_)), - "Unexpected response: {:?}", - res.payload[0] - ); -} - -#[rstest] -#[tokio::test] -async fn file_read_text_should_send_text_with_file_contents(#[future] session: Session) { - let mut session = session.await; - let temp = TempDir::new().unwrap(); - let file = temp.child("test-file"); - file.write_str("some file contents").unwrap(); - - let req = Request::new( - "test-tenant", - vec![RequestData::FileReadText { - path: file.path().to_path_buf(), - }], - ); - let res = session.send(req).await.unwrap(); - - assert_eq!(res.payload.len(), 1, "Wrong payload size"); - match &res.payload[0] { - ResponseData::Text { data } => assert_eq!(data, "some file contents"), - x => panic!("Unexpected response: {:?}", x), - } -} - -#[rstest] -#[tokio::test] -async fn file_write_should_send_error_if_fails_to_write_file(#[future] session: Session) { - let mut session = session.await; - // Create a temporary path and add to it to ensure that there are - // extra components that don't exist to cause writing to fail - let temp = TempDir::new().unwrap(); - let file = temp.child("dir").child("test-file"); - - let req = Request::new( - "test-tenant", - vec![RequestData::FileWrite { - path: file.path().to_path_buf(), - data: b"some text".to_vec(), - }], - ); - - let res = session.send(req).await.unwrap(); - assert_eq!(res.payload.len(), 1, "Wrong payload size"); - assert!( - matches!(res.payload[0], ResponseData::Error(_)), - "Unexpected response: {:?}", - res.payload[0] - ); - - // Also verify that we didn't actually create the file - file.assert(predicate::path::missing()); -} - -#[rstest] -#[tokio::test] -async fn 
file_write_should_send_ok_when_successful(#[future] session: Session) { - let mut session = session.await; - // Path should point to a file that does not exist, but all - // other components leading up to it do - let temp = TempDir::new().unwrap(); - let file = temp.child("test-file"); - - let req = Request::new( - "test-tenant", - vec![RequestData::FileWrite { - path: file.path().to_path_buf(), - data: b"some text".to_vec(), - }], - ); - - let res = session.send(req).await.unwrap(); - assert_eq!(res.payload.len(), 1, "Wrong payload size"); - assert!( - matches!(res.payload[0], ResponseData::Ok), - "Unexpected response: {:?}", - res.payload[0] - ); - - // Also verify that we actually did create the file - // with the associated contents - file.assert("some text"); -} - -#[rstest] -#[tokio::test] -async fn file_write_text_should_send_error_if_fails_to_write_file(#[future] session: Session) { - let mut session = session.await; - // Create a temporary path and add to it to ensure that there are - // extra components that don't exist to cause writing to fail - let temp = TempDir::new().unwrap(); - let file = temp.child("dir").child("test-file"); - - let req = Request::new( - "test-tenant", - vec![RequestData::FileWriteText { - path: file.path().to_path_buf(), - text: String::from("some text"), - }], - ); - - let res = session.send(req).await.unwrap(); - assert_eq!(res.payload.len(), 1, "Wrong payload size"); - assert!( - matches!(res.payload[0], ResponseData::Error(_)), - "Unexpected response: {:?}", - res.payload[0] - ); - - // Also verify that we didn't actually create the file - file.assert(predicate::path::missing()); -} - -#[rstest] -#[tokio::test] -async fn file_write_text_should_send_ok_when_successful(#[future] session: Session) { - let mut session = session.await; - // Path should point to a file that does not exist, but all - // other components leading up to it do - let temp = TempDir::new().unwrap(); - let file = temp.child("test-file"); - - let req = 
Request::new( - "test-tenant", - vec![RequestData::FileWriteText { - path: file.path().to_path_buf(), - text: String::from("some text"), - }], - ); - - let res = session.send(req).await.unwrap(); - assert_eq!(res.payload.len(), 1, "Wrong payload size"); - assert!( - matches!(res.payload[0], ResponseData::Ok), - "Unexpected response: {:?}", - res.payload[0] - ); - - // Also verify that we actually did create the file - // with the associated contents - file.assert("some text"); -} - -#[rstest] -#[tokio::test] -async fn file_append_should_send_error_if_fails_to_create_file(#[future] session: Session) { - let mut session = session.await; - // Create a temporary path and add to it to ensure that there are - // extra components that don't exist to cause writing to fail - let temp = TempDir::new().unwrap(); - let file = temp.child("dir").child("test-file"); - - let req = Request::new( - "test-tenant", - vec![RequestData::FileAppend { - path: file.path().to_path_buf(), - data: b"some extra contents".to_vec(), - }], - ); - - let res = session.send(req).await.unwrap(); - assert_eq!(res.payload.len(), 1, "Wrong payload size"); - assert!( - matches!(res.payload[0], ResponseData::Error(_)), - "Unexpected response: {:?}", - res.payload[0] - ); - - // Also verify that we didn't actually create the file - file.assert(predicate::path::missing()); -} - -#[rstest] -#[tokio::test] -async fn file_append_should_send_ok_when_successful(#[future] session: Session) { - let mut session = session.await; - // Create a temporary file and fill it with some contents - let temp = TempDir::new().unwrap(); - let file = temp.child("test-file"); - file.write_str("some file contents").unwrap(); - - let req = Request::new( - "test-tenant", - vec![RequestData::FileAppend { - path: file.path().to_path_buf(), - data: b"some extra contents".to_vec(), - }], - ); - - let res = session.send(req).await.unwrap(); - assert_eq!(res.payload.len(), 1, "Wrong payload size"); - assert!( - matches!(res.payload[0], 
ResponseData::Ok), - "Unexpected response: {:?}", - res.payload[0] - ); - - // Yield to allow chance to finish appending to file - tokio::time::sleep(Duration::from_millis(50)).await; - - // Also verify that we actually did append to the file - file.assert("some file contentssome extra contents"); -} - -#[rstest] -#[tokio::test] -async fn file_append_text_should_send_error_if_fails_to_create_file(#[future] session: Session) { - let mut session = session.await; - // Create a temporary path and add to it to ensure that there are - // extra components that don't exist to cause writing to fail - let temp = TempDir::new().unwrap(); - let file = temp.child("dir").child("test-file"); - - let req = Request::new( - "test-tenant", - vec![RequestData::FileAppendText { - path: file.path().to_path_buf(), - text: String::from("some extra contents"), - }], - ); - - let res = session.send(req).await.unwrap(); - assert_eq!(res.payload.len(), 1, "Wrong payload size"); - assert!( - matches!(res.payload[0], ResponseData::Error(_)), - "Unexpected response: {:?}", - res.payload[0] - ); - - // Also verify that we didn't actually create the file - file.assert(predicate::path::missing()); -} - -#[rstest] -#[tokio::test] -async fn file_append_text_should_send_ok_when_successful(#[future] session: Session) { - let mut session = session.await; - // Create a temporary file and fill it with some contents - let temp = TempDir::new().unwrap(); - let file = temp.child("test-file"); - file.write_str("some file contents").unwrap(); - - let req = Request::new( - "test-tenant", - vec![RequestData::FileAppendText { - path: file.path().to_path_buf(), - text: String::from("some extra contents"), - }], - ); - - let res = session.send(req).await.unwrap(); - assert_eq!(res.payload.len(), 1, "Wrong payload size"); - assert!( - matches!(res.payload[0], ResponseData::Ok), - "Unexpected response: {:?}", - res.payload[0] - ); - - // Yield to allow chance to finish appending to file - 
tokio::time::sleep(Duration::from_millis(50)).await; - - // Also verify that we actually did append to the file - file.assert("some file contentssome extra contents"); -} - -#[rstest] -#[tokio::test] -async fn dir_read_should_send_error_if_directory_does_not_exist(#[future] session: Session) { - let mut session = session.await; - let temp = TempDir::new().unwrap(); - let dir = temp.child("test-dir"); - - let req = Request::new( - "test-tenant", - vec![RequestData::DirRead { - path: dir.path().to_path_buf(), - depth: 0, - absolute: false, - canonicalize: false, - include_root: false, - }], - ); - - let res = session.send(req).await.unwrap(); - assert_eq!(res.payload.len(), 1, "Wrong payload size"); - assert!( - matches!(res.payload[0], ResponseData::Error(_)), - "Unexpected response: {:?}", - res.payload[0] - ); -} - -// /root/ -// /root/file1 -// /root/link1 -> /root/sub1/file2 -// /root/sub1/ -// /root/sub1/file2 -async fn setup_dir() -> TempDir { - let root_dir = TempDir::new().unwrap(); - root_dir.child("file1").touch().unwrap(); - - let sub1 = root_dir.child("sub1"); - sub1.create_dir_all().unwrap(); - - let file2 = sub1.child("file2"); - file2.touch().unwrap(); - - let link1 = root_dir.child("link1"); - link1.symlink_to_file(file2.path()).unwrap(); - - root_dir -} - -#[rstest] -#[tokio::test] -async fn dir_read_should_support_depth_limits(#[future] session: Session) { - let mut session = session.await; - // Create directory with some nested items - let root_dir = setup_dir().await; - - let req = Request::new( - "test-tenant", - vec![RequestData::DirRead { - path: root_dir.path().to_path_buf(), - depth: 1, - absolute: false, - canonicalize: false, - include_root: false, - }], - ); - - let res = session.send(req).await.unwrap(); - assert_eq!(res.payload.len(), 1, "Wrong payload size"); - match &res.payload[0] { - ResponseData::DirEntries { entries, .. 
} => { - assert_eq!(entries.len(), 3, "Wrong number of entries found"); - - assert_eq!(entries[0].file_type, FileType::File); - assert_eq!(entries[0].path, Path::new("file1")); - assert_eq!(entries[0].depth, 1); - - assert_eq!(entries[1].file_type, FileType::Symlink); - assert_eq!(entries[1].path, Path::new("link1")); - assert_eq!(entries[1].depth, 1); - - assert_eq!(entries[2].file_type, FileType::Dir); - assert_eq!(entries[2].path, Path::new("sub1")); - assert_eq!(entries[2].depth, 1); - } - x => panic!("Unexpected response: {:?}", x), - } -} - -#[rstest] -#[tokio::test] -async fn dir_read_should_support_unlimited_depth_using_zero(#[future] session: Session) { - let mut session = session.await; - // Create directory with some nested items - let root_dir = setup_dir().await; - - let req = Request::new( - "test-tenant", - vec![RequestData::DirRead { - path: root_dir.path().to_path_buf(), - depth: 0, - absolute: false, - canonicalize: false, - include_root: false, - }], - ); - - let res = session.send(req).await.unwrap(); - assert_eq!(res.payload.len(), 1, "Wrong payload size"); - match &res.payload[0] { - ResponseData::DirEntries { entries, .. 
} => { - assert_eq!(entries.len(), 4, "Wrong number of entries found"); - - assert_eq!(entries[0].file_type, FileType::File); - assert_eq!(entries[0].path, Path::new("file1")); - assert_eq!(entries[0].depth, 1); - - assert_eq!(entries[1].file_type, FileType::Symlink); - assert_eq!(entries[1].path, Path::new("link1")); - assert_eq!(entries[1].depth, 1); - - assert_eq!(entries[2].file_type, FileType::Dir); - assert_eq!(entries[2].path, Path::new("sub1")); - assert_eq!(entries[2].depth, 1); - - assert_eq!(entries[3].file_type, FileType::File); - assert_eq!(entries[3].path, Path::new("sub1").join("file2")); - assert_eq!(entries[3].depth, 2); - } - x => panic!("Unexpected response: {:?}", x), - } -} - -#[rstest] -#[tokio::test] -async fn dir_read_should_support_including_directory_in_returned_entries( - #[future] session: Session, -) { - let mut session = session.await; - // Create directory with some nested items - let root_dir = setup_dir().await; - - let req = Request::new( - "test-tenant", - vec![RequestData::DirRead { - path: root_dir.path().to_path_buf(), - depth: 1, - absolute: false, - canonicalize: false, - include_root: true, - }], - ); - - let res = session.send(req).await.unwrap(); - assert_eq!(res.payload.len(), 1, "Wrong payload size"); - match &res.payload[0] { - ResponseData::DirEntries { entries, .. 
} => { - assert_eq!(entries.len(), 4, "Wrong number of entries found"); - - // NOTE: Root entry is always absolute, resolved path - assert_eq!(entries[0].file_type, FileType::Dir); - assert_eq!(entries[0].path, root_dir.path().canonicalize().unwrap()); - assert_eq!(entries[0].depth, 0); - - assert_eq!(entries[1].file_type, FileType::File); - assert_eq!(entries[1].path, Path::new("file1")); - assert_eq!(entries[1].depth, 1); - - assert_eq!(entries[2].file_type, FileType::Symlink); - assert_eq!(entries[2].path, Path::new("link1")); - assert_eq!(entries[2].depth, 1); - - assert_eq!(entries[3].file_type, FileType::Dir); - assert_eq!(entries[3].path, Path::new("sub1")); - assert_eq!(entries[3].depth, 1); - } - x => panic!("Unexpected response: {:?}", x), - } -} - -#[rstest] -#[tokio::test] -async fn dir_read_should_support_returning_absolute_paths(#[future] session: Session) { - let mut session = session.await; - // Create directory with some nested items - let root_dir = setup_dir().await; - - let req = Request::new( - "test-tenant", - vec![RequestData::DirRead { - path: root_dir.path().to_path_buf(), - depth: 1, - absolute: true, - canonicalize: false, - include_root: false, - }], - ); - - let res = session.send(req).await.unwrap(); - assert_eq!(res.payload.len(), 1, "Wrong payload size"); - match &res.payload[0] { - ResponseData::DirEntries { entries, .. 
} => { - assert_eq!(entries.len(), 3, "Wrong number of entries found"); - let root_path = root_dir.path().canonicalize().unwrap(); - - assert_eq!(entries[0].file_type, FileType::File); - assert_eq!(entries[0].path, root_path.join("file1")); - assert_eq!(entries[0].depth, 1); - - assert_eq!(entries[1].file_type, FileType::Symlink); - assert_eq!(entries[1].path, root_path.join("link1")); - assert_eq!(entries[1].depth, 1); - - assert_eq!(entries[2].file_type, FileType::Dir); - assert_eq!(entries[2].path, root_path.join("sub1")); - assert_eq!(entries[2].depth, 1); - } - x => panic!("Unexpected response: {:?}", x), - } -} - -#[rstest] -#[tokio::test] -async fn dir_read_should_support_returning_canonicalized_paths(#[future] session: Session) { - let mut session = session.await; - // Create directory with some nested items - let root_dir = setup_dir().await; - - let req = Request::new( - "test-tenant", - vec![RequestData::DirRead { - path: root_dir.path().to_path_buf(), - depth: 1, - absolute: false, - canonicalize: true, - include_root: false, - }], - ); - - let res = session.send(req).await.unwrap(); - assert_eq!(res.payload.len(), 1, "Wrong payload size"); - match &res.payload[0] { - ResponseData::DirEntries { entries, .. 
} => { - assert_eq!(entries.len(), 3, "Wrong number of entries found"); - - assert_eq!(entries[0].file_type, FileType::File); - assert_eq!(entries[0].path, Path::new("file1")); - assert_eq!(entries[0].depth, 1); - - assert_eq!(entries[1].file_type, FileType::Dir); - assert_eq!(entries[1].path, Path::new("sub1")); - assert_eq!(entries[1].depth, 1); - - // Symlink should be resolved from $ROOT/link1 -> $ROOT/sub1/file2 - assert_eq!(entries[2].file_type, FileType::Symlink); - assert_eq!(entries[2].path, Path::new("sub1").join("file2")); - assert_eq!(entries[2].depth, 1); - } - x => panic!("Unexpected response: {:?}", x), - } -} - -#[rstest] -#[tokio::test] -async fn dir_create_should_send_error_if_fails(#[future] session: Session) { - let mut session = session.await; - // Make a path that has multiple non-existent components - // so the creation will fail - let root_dir = setup_dir().await; - let path = root_dir.path().join("nested").join("new-dir"); - - let req = Request::new( - "test-tenant", - vec![RequestData::DirCreate { - path: path.to_path_buf(), - all: false, - }], - ); - - let res = session.send(req).await.unwrap(); - assert_eq!(res.payload.len(), 1, "Wrong payload size"); - assert!( - matches!(res.payload[0], ResponseData::Error(_)), - "Unexpected response: {:?}", - res.payload[0] - ); - - // Also verify that the directory was not actually created - assert!(!path.exists(), "Path unexpectedly exists"); -} - -#[rstest] -#[tokio::test] -async fn dir_create_should_send_ok_when_successful(#[future] session: Session) { - let mut session = session.await; - let root_dir = setup_dir().await; - let path = root_dir.path().join("new-dir"); - - let req = Request::new( - "test-tenant", - vec![RequestData::DirCreate { - path: path.to_path_buf(), - all: false, - }], - ); - - let res = session.send(req).await.unwrap(); - assert_eq!(res.payload.len(), 1, "Wrong payload size"); - assert!( - matches!(res.payload[0], ResponseData::Ok), - "Unexpected response: {:?}", - 
res.payload[0] - ); - - // Also verify that the directory was actually created - assert!(path.exists(), "Directory not created"); -} - -#[rstest] -#[tokio::test] -async fn dir_create_should_support_creating_multiple_dir_components(#[future] session: Session) { - let mut session = session.await; - let root_dir = setup_dir().await; - let path = root_dir.path().join("nested").join("new").join("dir"); - - let req = Request::new( - "test-tenant", - vec![RequestData::DirCreate { - path: path.to_path_buf(), - all: true, - }], - ); - - let res = session.send(req).await.unwrap(); - assert_eq!(res.payload.len(), 1, "Wrong payload size"); - assert!( - matches!(res.payload[0], ResponseData::Ok), - "Unexpected response: {:?}", - res.payload[0] - ); - - // Also verify that the directory was actually created - assert!(path.exists(), "Directory not created"); -} - -#[rstest] -#[tokio::test] -async fn remove_should_send_error_on_failure(#[future] session: Session) { - let mut session = session.await; - let temp = TempDir::new().unwrap(); - let file = temp.child("missing-file"); - - let req = Request::new( - "test-tenant", - vec![RequestData::Remove { - path: file.path().to_path_buf(), - force: false, - }], - ); - - let res = session.send(req).await.unwrap(); - assert_eq!(res.payload.len(), 1, "Wrong payload size"); - assert!( - matches!(res.payload[0], ResponseData::Error(_)), - "Unexpected response: {:?}", - res.payload[0] - ); - - // Also, verify that path does not exist - file.assert(predicate::path::missing()); -} - -#[rstest] -#[tokio::test] -async fn remove_should_support_deleting_a_directory(#[future] session: Session) { - let mut session = session.await; - let temp = TempDir::new().unwrap(); - let dir = temp.child("dir"); - dir.create_dir_all().unwrap(); - - let req = Request::new( - "test-tenant", - vec![RequestData::Remove { - path: dir.path().to_path_buf(), - force: false, - }], - ); - - let res = session.send(req).await.unwrap(); - assert_eq!(res.payload.len(), 1, 
"Wrong payload size"); - assert!( - matches!(res.payload[0], ResponseData::Ok), - "Unexpected response: {:?}", - res.payload[0] - ); - - // Also, verify that path does not exist - dir.assert(predicate::path::missing()); -} - -#[rstest] -#[tokio::test] -async fn remove_should_delete_nonempty_directory_if_force_is_true(#[future] session: Session) { - let mut session = session.await; - let temp = TempDir::new().unwrap(); - let dir = temp.child("dir"); - dir.create_dir_all().unwrap(); - dir.child("file").touch().unwrap(); - - let req = Request::new( - "test-tenant", - vec![RequestData::Remove { - path: dir.path().to_path_buf(), - force: true, - }], - ); - - let res = session.send(req).await.unwrap(); - assert_eq!(res.payload.len(), 1, "Wrong payload size"); - assert!( - matches!(res.payload[0], ResponseData::Ok), - "Unexpected response: {:?}", - res.payload[0] - ); - - // Also, verify that path does not exist - dir.assert(predicate::path::missing()); -} - -#[rstest] -#[tokio::test] -async fn remove_should_support_deleting_a_single_file(#[future] session: Session) { - let mut session = session.await; - let temp = TempDir::new().unwrap(); - let file = temp.child("some-file"); - file.touch().unwrap(); - - let req = Request::new( - "test-tenant", - vec![RequestData::Remove { - path: file.path().to_path_buf(), - force: false, - }], - ); - - let res = session.send(req).await.unwrap(); - assert_eq!(res.payload.len(), 1, "Wrong payload size"); - assert!( - matches!(res.payload[0], ResponseData::Ok), - "Unexpected response: {:?}", - res.payload[0] - ); - - // Also, verify that path does not exist - file.assert(predicate::path::missing()); -} - -#[rstest] -#[tokio::test] -async fn copy_should_send_error_on_failure(#[future] session: Session) { - let mut session = session.await; - let temp = TempDir::new().unwrap(); - let src = temp.child("src"); - let dst = temp.child("dst"); - - let req = Request::new( - "test-tenant", - vec![RequestData::Copy { - src: src.path().to_path_buf(), 
- dst: dst.path().to_path_buf(), - }], - ); - - let res = session.send(req).await.unwrap(); - assert_eq!(res.payload.len(), 1, "Wrong payload size"); - assert!( - matches!(res.payload[0], ResponseData::Error(_)), - "Unexpected response: {:?}", - res.payload[0] - ); - - // Also, verify that destination does not exist - dst.assert(predicate::path::missing()); -} - -#[rstest] -#[tokio::test] -async fn copy_should_support_copying_an_entire_directory(#[future] session: Session) { - let mut session = session.await; - let temp = TempDir::new().unwrap(); - - let src = temp.child("src"); - src.create_dir_all().unwrap(); - let src_file = src.child("file"); - src_file.write_str("some contents").unwrap(); - - let dst = temp.child("dst"); - let dst_file = dst.child("file"); - - let req = Request::new( - "test-tenant", - vec![RequestData::Copy { - src: src.path().to_path_buf(), - dst: dst.path().to_path_buf(), - }], - ); - - let res = session.send(req).await.unwrap(); - assert_eq!(res.payload.len(), 1, "Wrong payload size"); - assert!( - matches!(res.payload[0], ResponseData::Ok), - "Unexpected response: {:?}", - res.payload[0] - ); - - // Verify that we have source and destination directories and associated contents - src.assert(predicate::path::is_dir()); - src_file.assert(predicate::path::is_file()); - dst.assert(predicate::path::is_dir()); - dst_file.assert(predicate::path::eq_file(src_file.path())); -} - -#[rstest] -#[tokio::test] -async fn copy_should_support_copying_an_empty_directory(#[future] session: Session) { - let mut session = session.await; - let temp = TempDir::new().unwrap(); - let src = temp.child("src"); - src.create_dir_all().unwrap(); - let dst = temp.child("dst"); - - let req = Request::new( - "test-tenant", - vec![RequestData::Copy { - src: src.path().to_path_buf(), - dst: dst.path().to_path_buf(), - }], - ); - - let res = session.send(req).await.unwrap(); - assert_eq!(res.payload.len(), 1, "Wrong payload size"); - assert!( - matches!(res.payload[0], 
ResponseData::Ok), - "Unexpected response: {:?}", - res.payload[0] - ); - - // Verify that we still have source and destination directories - src.assert(predicate::path::is_dir()); - dst.assert(predicate::path::is_dir()); -} - -#[rstest] -#[tokio::test] -async fn copy_should_support_copying_a_directory_that_only_contains_directories( - #[future] session: Session, -) { - let mut session = session.await; - let temp = TempDir::new().unwrap(); - - let src = temp.child("src"); - src.create_dir_all().unwrap(); - let src_dir = src.child("dir"); - src_dir.create_dir_all().unwrap(); - - let dst = temp.child("dst"); - let dst_dir = dst.child("dir"); - - let req = Request::new( - "test-tenant", - vec![RequestData::Copy { - src: src.path().to_path_buf(), - dst: dst.path().to_path_buf(), - }], - ); - - let res = session.send(req).await.unwrap(); - assert_eq!(res.payload.len(), 1, "Wrong payload size"); - assert!( - matches!(res.payload[0], ResponseData::Ok), - "Unexpected response: {:?}", - res.payload[0] - ); - - // Verify that we have source and destination directories and associated contents - src.assert(predicate::path::is_dir().name("src")); - src_dir.assert(predicate::path::is_dir().name("src/dir")); - dst.assert(predicate::path::is_dir().name("dst")); - dst_dir.assert(predicate::path::is_dir().name("dst/dir")); -} - -#[rstest] -#[tokio::test] -async fn copy_should_support_copying_a_single_file(#[future] session: Session) { - let mut session = session.await; - let temp = TempDir::new().unwrap(); - let src = temp.child("src"); - src.write_str("some text").unwrap(); - let dst = temp.child("dst"); - - let req = Request::new( - "test-tenant", - vec![RequestData::Copy { - src: src.path().to_path_buf(), - dst: dst.path().to_path_buf(), - }], - ); - - let res = session.send(req).await.unwrap(); - assert_eq!(res.payload.len(), 1, "Wrong payload size"); - assert!( - matches!(res.payload[0], ResponseData::Ok), - "Unexpected response: {:?}", - res.payload[0] - ); - - // Verify that 
we still have source and that destination has source's contents - src.assert(predicate::path::is_file()); - dst.assert(predicate::path::eq_file(src.path())); -} - -#[rstest] -#[tokio::test] -async fn rename_should_send_error_on_failure(#[future] session: Session) { - let mut session = session.await; - let temp = TempDir::new().unwrap(); - let src = temp.child("src"); - let dst = temp.child("dst"); - - let req = Request::new( - "test-tenant", - vec![RequestData::Rename { - src: src.path().to_path_buf(), - dst: dst.path().to_path_buf(), - }], - ); - - let res = session.send(req).await.unwrap(); - assert_eq!(res.payload.len(), 1, "Wrong payload size"); - assert!( - matches!(res.payload[0], ResponseData::Error(_)), - "Unexpected response: {:?}", - res.payload[0] - ); - - // Also, verify that destination does not exist - dst.assert(predicate::path::missing()); -} - -#[rstest] -#[tokio::test] -async fn rename_should_support_renaming_an_entire_directory(#[future] session: Session) { - let mut session = session.await; - let temp = TempDir::new().unwrap(); - - let src = temp.child("src"); - src.create_dir_all().unwrap(); - let src_file = src.child("file"); - src_file.write_str("some contents").unwrap(); - - let dst = temp.child("dst"); - let dst_file = dst.child("file"); - - let req = Request::new( - "test-tenant", - vec![RequestData::Rename { - src: src.path().to_path_buf(), - dst: dst.path().to_path_buf(), - }], - ); - - let res = session.send(req).await.unwrap(); - assert_eq!(res.payload.len(), 1, "Wrong payload size"); - assert!( - matches!(res.payload[0], ResponseData::Ok), - "Unexpected response: {:?}", - res.payload[0] - ); - - // Verify that we moved the contents - src.assert(predicate::path::missing()); - src_file.assert(predicate::path::missing()); - dst.assert(predicate::path::is_dir()); - dst_file.assert("some contents"); -} - -#[rstest] -#[tokio::test] -async fn rename_should_support_renaming_a_single_file(#[future] session: Session) { - let mut session = 
session.await; - let temp = TempDir::new().unwrap(); - let src = temp.child("src"); - src.write_str("some text").unwrap(); - let dst = temp.child("dst"); - - let req = Request::new( - "test-tenant", - vec![RequestData::Rename { - src: src.path().to_path_buf(), - dst: dst.path().to_path_buf(), - }], - ); - - let res = session.send(req).await.unwrap(); - assert_eq!(res.payload.len(), 1, "Wrong payload size"); - assert!( - matches!(res.payload[0], ResponseData::Ok), - "Unexpected response: {:?}", - res.payload[0] - ); - - // Verify that we moved the file - src.assert(predicate::path::missing()); - dst.assert("some text"); -} - -#[rstest] -#[tokio::test] -async fn exists_should_send_true_if_path_exists(#[future] session: Session) { - let mut session = session.await; - let temp = TempDir::new().unwrap(); - let file = temp.child("file"); - file.touch().unwrap(); - - let req = Request::new( - "test-tenant", - vec![RequestData::Exists { - path: file.path().to_path_buf(), - }], - ); - - let res = session.send(req).await.unwrap(); - assert_eq!(res.payload.len(), 1, "Wrong payload size"); - assert_eq!(res.payload[0], ResponseData::Exists { value: true }); -} - -#[rstest] -#[tokio::test] -async fn exists_should_send_false_if_path_does_not_exist(#[future] session: Session) { - let mut session = session.await; - let temp = TempDir::new().unwrap(); - let file = temp.child("file"); - - let req = Request::new( - "test-tenant", - vec![RequestData::Exists { - path: file.path().to_path_buf(), - }], - ); - - let res = session.send(req).await.unwrap(); - assert_eq!(res.payload.len(), 1, "Wrong payload size"); - assert_eq!(res.payload[0], ResponseData::Exists { value: false }); -} - -#[rstest] -#[tokio::test] -async fn metadata_should_send_error_on_failure(#[future] session: Session) { - let mut session = session.await; - let temp = TempDir::new().unwrap(); - let file = temp.child("file"); - - let req = Request::new( - "test-tenant", - vec![RequestData::Metadata { - path: 
file.path().to_path_buf(), - canonicalize: false, - resolve_file_type: false, - }], - ); - - let res = session.send(req).await.unwrap(); - assert_eq!(res.payload.len(), 1, "Wrong payload size"); - assert!( - matches!(res.payload[0], ResponseData::Error(_)), - "Unexpected response: {:?}", - res.payload[0] - ); -} - -#[rstest] -#[tokio::test] -async fn metadata_should_send_back_metadata_on_file_if_exists(#[future] session: Session) { - let mut session = session.await; - let temp = TempDir::new().unwrap(); - let file = temp.child("file"); - file.write_str("some text").unwrap(); - - let req = Request::new( - "test-tenant", - vec![RequestData::Metadata { - path: file.path().to_path_buf(), - canonicalize: false, - resolve_file_type: false, - }], - ); - - let res = session.send(req).await.unwrap(); - assert_eq!(res.payload.len(), 1, "Wrong payload size"); - assert!( - matches!( - res.payload[0], - ResponseData::Metadata(Metadata { - canonicalized_path: None, - file_type: FileType::File, - len: 9, - readonly: false, - .. - }) - ), - "Unexpected response: {:?}", - res.payload[0] - ); -} - -#[rstest] -#[tokio::test] -async fn metadata_should_send_back_metadata_on_dir_if_exists(#[future] session: Session) { - let mut session = session.await; - let temp = TempDir::new().unwrap(); - let dir = temp.child("dir"); - dir.create_dir_all().unwrap(); - - let req = Request::new( - "test-tenant", - vec![RequestData::Metadata { - path: dir.path().to_path_buf(), - canonicalize: false, - resolve_file_type: false, - }], - ); - - let res = session.send(req).await.unwrap(); - assert_eq!(res.payload.len(), 1, "Wrong payload size"); - assert!( - matches!( - res.payload[0], - ResponseData::Metadata(Metadata { - canonicalized_path: None, - file_type: FileType::Dir, - readonly: false, - .. 
- }) - ), - "Unexpected response: {:?}", - res.payload[0] - ); -} - -#[rstest] -#[tokio::test] -async fn metadata_should_send_back_metadata_on_symlink_if_exists(#[future] session: Session) { - let mut session = session.await; - let temp = TempDir::new().unwrap(); - let file = temp.child("file"); - file.write_str("some text").unwrap(); - - let symlink = temp.child("link"); - symlink.symlink_to_file(file.path()).unwrap(); - - let req = Request::new( - "test-tenant", - vec![RequestData::Metadata { - path: symlink.path().to_path_buf(), - canonicalize: false, - resolve_file_type: false, - }], - ); - - let res = session.send(req).await.unwrap(); - assert_eq!(res.payload.len(), 1, "Wrong payload size"); - assert!( - matches!( - res.payload[0], - ResponseData::Metadata(Metadata { - canonicalized_path: None, - file_type: FileType::Symlink, - readonly: false, - .. - }) - ), - "Unexpected response: {:?}", - res.payload[0] - ); -} - -#[rstest] -#[tokio::test] -async fn metadata_should_include_canonicalized_path_if_flag_specified(#[future] session: Session) { - let mut session = session.await; - let temp = TempDir::new().unwrap(); - let file = temp.child("file"); - file.write_str("some text").unwrap(); - - let symlink = temp.child("link"); - symlink.symlink_to_file(file.path()).unwrap(); - - let req = Request::new( - "test-tenant", - vec![RequestData::Metadata { - path: symlink.path().to_path_buf(), - canonicalize: true, - resolve_file_type: false, - }], - ); - - let res = session.send(req).await.unwrap(); - assert_eq!(res.payload.len(), 1, "Wrong payload size"); - match &res.payload[0] { - ResponseData::Metadata(Metadata { - canonicalized_path: Some(path), - file_type: FileType::Symlink, - readonly: false, - .. 
- }) => assert_eq!( - path, - &file.path().canonicalize().unwrap(), - "Symlink canonicalized path does not match referenced file" - ), - x => panic!("Unexpected response: {:?}", x), - } -} - -#[rstest] -#[tokio::test] -async fn metadata_should_resolve_file_type_of_symlink_if_flag_specified( - #[future] session: Session, -) { - let mut session = session.await; - let temp = TempDir::new().unwrap(); - let file = temp.child("file"); - file.write_str("some text").unwrap(); - - let symlink = temp.child("link"); - symlink.symlink_to_file(file.path()).unwrap(); - - let req = Request::new( - "test-tenant", - vec![RequestData::Metadata { - path: symlink.path().to_path_buf(), - canonicalize: false, - resolve_file_type: true, - }], - ); - - let res = session.send(req).await.unwrap(); - assert_eq!(res.payload.len(), 1, "Wrong payload size"); - match &res.payload[0] { - ResponseData::Metadata(Metadata { - file_type: FileType::File, - .. - }) => {} - x => panic!("Unexpected response: {:?}", x), - } -} - -#[rstest] -#[tokio::test] -async fn proc_spawn_should_send_error_over_stderr_on_failure(#[future] session: Session) { - let mut session = session.await; - let req = Request::new( - "test-tenant", - vec![RequestData::ProcSpawn { - cmd: DOES_NOT_EXIST_BIN.to_str().unwrap().to_string(), - args: Vec::new(), - persist: false, - pty: None, - }], - ); - - // NOTE: This diverges from distant in that we don't get an error message and instead - // will always get stderr as ssh runs every command in some kind of shell - let mut mailbox = session.mail(req).await.unwrap(); - - // Get proc start message - let res = mailbox.next().await.unwrap(); - assert_eq!(res.payload.len(), 1, "Wrong payload size"); - let proc_id = match &res.payload[0] { - ResponseData::ProcSpawned { id } => *id, - x => panic!("Unexpected response: {:?}", x), - }; - - // Get proc stderr message - let res = mailbox.next().await.unwrap(); - assert_eq!(res.payload.len(), 1, "Wrong payload size"); - match &res.payload[0] { - 
ResponseData::ProcStderr { id, .. } => { - assert_eq!(proc_id, *id, "Wrong process stderr received"); - } - x => panic!("Unexpected response: {:?}", x), - } - - // Get proc done message - let res = mailbox.next().await.unwrap(); - assert_eq!(res.payload.len(), 1, "Wrong payload size"); - match &res.payload[0] { - ResponseData::ProcDone { id, .. } => { - assert_eq!(proc_id, *id, "Wrong process done received"); - } - x => panic!("Unexpected response: {:?}", x), - } -} - -#[rstest] -#[tokio::test] -async fn proc_spawn_should_send_back_proc_start_on_success(#[future] session: Session) { - let mut session = session.await; - let req = Request::new( - "test-tenant", - vec![RequestData::ProcSpawn { - cmd: SCRIPT_RUNNER.to_string(), - args: vec![ECHO_ARGS_TO_STDOUT_SH.to_str().unwrap().to_string()], - persist: false, - pty: None, - }], - ); - - let res = session.send(req).await.unwrap(); - assert_eq!(res.payload.len(), 1, "Wrong payload size"); - assert!( - matches!(&res.payload[0], ResponseData::ProcSpawned { .. }), - "Unexpected response: {:?}", - res.payload[0] - ); -} - -// NOTE: Ignoring on windows because it's using WSL which wants a Linux path -// with / but thinks it's on windows and is providing \ -#[rstest] -#[tokio::test] -#[cfg_attr(windows, ignore)] -async fn proc_spawn_should_send_back_stdout_periodically_when_available( - #[future] session: Session, -) { - let mut session = session.await; - // Run a program that echoes to stdout - let req = Request::new( - "test-tenant", - vec![RequestData::ProcSpawn { - cmd: SCRIPT_RUNNER.to_string(), - args: vec![ - ECHO_ARGS_TO_STDOUT_SH.to_str().unwrap().to_string(), - String::from("'some stdout'"), - ], - persist: false, - pty: None, - }], - ); - - let mut mailbox = session.mail(req).await.unwrap(); - - let res = mailbox.next().await.unwrap(); - assert_eq!(res.payload.len(), 1, "Wrong payload size"); - assert!( - matches!(&res.payload[0], ResponseData::ProcSpawned { .. 
}), - "Unexpected response: {:?}", - res.payload[0] - ); - - // Gather two additional responses: - // - // 1. An indirect response for stdout - // 2. An indirect response that is proc completing - // - // Note that order is not a guarantee, so we have to check that - // we get one of each type of response - let res1 = mailbox.next().await.expect("Missing first response"); - let res2 = mailbox.next().await.expect("Missing second response"); - - let mut got_stdout = false; - let mut got_done = false; - - let mut check_res = |res: &Response| { - assert_eq!(res.payload.len(), 1, "Wrong payload size"); - match &res.payload[0] { - ResponseData::ProcStdout { data, .. } => { - assert_eq!(data, b"some stdout", "Got wrong stdout"); - got_stdout = true; - } - ResponseData::ProcDone { success, .. } => { - assert!(success, "Process should have completed successfully"); - got_done = true; - } - x => panic!("Unexpected response: {:?}", x), - } - }; - - check_res(&res1); - check_res(&res2); - assert!(got_stdout, "Missing stdout response"); - assert!(got_done, "Missing done response"); -} - -// NOTE: Ignoring on windows because it's using WSL which wants a Linux path -// with / but thinks it's on windows and is providing \ -#[rstest] -#[tokio::test] -#[cfg_attr(windows, ignore)] -async fn proc_spawn_should_send_back_stderr_periodically_when_available( - #[future] session: Session, -) { - let mut session = session.await; - // Run a program that echoes to stderr - let req = Request::new( - "test-tenant", - vec![RequestData::ProcSpawn { - cmd: SCRIPT_RUNNER.to_string(), - args: vec![ - ECHO_ARGS_TO_STDERR_SH.to_str().unwrap().to_string(), - String::from("'some stderr'"), - ], - persist: false, - pty: None, - }], - ); - - let mut mailbox = session.mail(req).await.unwrap(); - - let res = mailbox.next().await.unwrap(); - assert_eq!(res.payload.len(), 1, "Wrong payload size"); - assert!( - matches!(&res.payload[0], ResponseData::ProcSpawned { .. 
}), - "Unexpected response: {:?}", - res.payload[0] - ); - - // Gather two additional responses: - // - // 1. An indirect response for stderr - // 2. An indirect response that is proc completing - // - // Note that order is not a guarantee, so we have to check that - // we get one of each type of response - let res1 = mailbox.next().await.expect("Missing first response"); - let res2 = mailbox.next().await.expect("Missing second response"); - - let mut got_stderr = false; - let mut got_done = false; - - let mut check_res = |res: &Response| { - assert_eq!(res.payload.len(), 1, "Wrong payload size"); - match &res.payload[0] { - ResponseData::ProcStderr { data, .. } => { - assert_eq!(data, b"some stderr", "Got wrong stderr"); - got_stderr = true; - } - ResponseData::ProcDone { success, .. } => { - assert!(success, "Process should have completed successfully"); - got_done = true; - } - x => panic!("Unexpected response: {:?}", x), - } - }; - - check_res(&res1); - check_res(&res2); - assert!(got_stderr, "Missing stderr response"); - assert!(got_done, "Missing done response"); -} - -// NOTE: Ignoring on windows because it's using WSL which wants a Linux path -// with / but thinks it's on windows and is providing \ -#[rstest] -#[tokio::test] -#[cfg_attr(windows, ignore)] -async fn proc_spawn_should_clear_process_from_state_when_done(#[future] session: Session) { - let mut session = session.await; - // Run a program that ends after a little bit - let req = Request::new( - "test-tenant", - vec![RequestData::ProcSpawn { - cmd: SCRIPT_RUNNER.to_string(), - args: vec![SLEEP_SH.to_str().unwrap().to_string(), String::from("0.1")], - persist: false, - pty: None, - }], - ); - let mut mailbox = session.mail(req).await.unwrap(); - - let res = mailbox.next().await.unwrap(); - assert_eq!(res.payload.len(), 1, "Wrong payload size"); - let id = match &res.payload[0] { - ResponseData::ProcSpawned { id } => *id, - x => panic!("Unexpected response: {:?}", x), - }; - - // Verify that the 
state has the process - let res = session - .send(Request::new("test-tenant", vec![RequestData::ProcList {}])) - .await - .unwrap(); - assert_eq!(res.payload.len(), 1, "Wrong payload size"); - match &res.payload[0] { - ResponseData::ProcEntries { entries } => assert_eq!(entries[0].id, id), - x => panic!("Unexpected response: {:?}", x), - } - - // Wait for process to finish - let _ = mailbox.next().await.unwrap(); - - // Verify that the state was cleared - let res = session - .send(Request::new("test-tenant", vec![RequestData::ProcList {}])) - .await - .unwrap(); - assert_eq!(res.payload.len(), 1, "Wrong payload size"); - match &res.payload[0] { - ResponseData::ProcEntries { entries } => assert!(entries.is_empty(), "Proc not cleared"), - x => panic!("Unexpected response: {:?}", x), - } -} - -#[rstest] -#[tokio::test] -async fn proc_spawn_should_clear_process_from_state_when_killed(#[future] session: Session) { - let mut session = session.await; - // Run a program that ends slowly - let req = Request::new( - "test-tenant", - vec![RequestData::ProcSpawn { - cmd: SCRIPT_RUNNER.to_string(), - args: vec![SLEEP_SH.to_str().unwrap().to_string(), String::from("1")], - persist: false, - pty: None, - }], - ); - - let mut mailbox = session.mail(req).await.unwrap(); - - let res = mailbox.next().await.unwrap(); - assert_eq!(res.payload.len(), 1, "Wrong payload size"); - let id = match &res.payload[0] { - ResponseData::ProcSpawned { id } => *id, - x => panic!("Unexpected response: {:?}", x), - }; - - // Verify that the state has the process - let res = session - .send(Request::new("test-tenant", vec![RequestData::ProcList {}])) - .await - .unwrap(); - assert_eq!(res.payload.len(), 1, "Wrong payload size"); - match &res.payload[0] { - ResponseData::ProcEntries { entries } => assert_eq!(entries[0].id, id), - x => panic!("Unexpected response: {:?}", x), - } - - // Send kill signal - let req = Request::new("test-tenant", vec![RequestData::ProcKill { id }]); - let _ = 
session.send(req).await.unwrap(); - - // Wait for the proc done - let _ = mailbox.next().await.unwrap(); - - // Verify that the state was cleared - let res = session - .send(Request::new("test-tenant", vec![RequestData::ProcList {}])) - .await - .unwrap(); - assert_eq!(res.payload.len(), 1, "Wrong payload size"); - match &res.payload[0] { - ResponseData::ProcEntries { entries } => assert!(entries.is_empty(), "Proc not cleared"), - x => panic!("Unexpected response: {:?}", x), - } -} - -#[rstest] -#[tokio::test] -async fn proc_kill_should_send_error_on_failure(#[future] session: Session) { - let mut session = session.await; - // Send kill to a non-existent process - let req = Request::new( - "test-tenant", - vec![RequestData::ProcKill { id: 0xDEADBEEF }], - ); - - let res = session.send(req).await.unwrap(); - assert_eq!(res.payload.len(), 1, "Wrong payload size"); - - // Verify that we get an error - assert!( - matches!(res.payload[0], ResponseData::Error(_)), - "Unexpected response: {:?}", - res.payload[0] - ); -} - -#[rstest] -#[tokio::test] -async fn proc_kill_should_send_ok_and_done_responses_on_success(#[future] session: Session) { - let mut session = session.await; - // First, run a program that sits around (sleep for 1 second) - let req = Request::new( - "test-tenant", - vec![RequestData::ProcSpawn { - cmd: SCRIPT_RUNNER.to_string(), - args: vec![SLEEP_SH.to_str().unwrap().to_string(), String::from("1")], - persist: false, - pty: None, - }], - ); - - let mut mailbox = session.mail(req).await.unwrap(); - - let res = mailbox.next().await.unwrap(); - assert_eq!(res.payload.len(), 1, "Wrong payload size"); - - // Second, grab the id of the started process - let id = match &res.payload[0] { - ResponseData::ProcSpawned { id } => *id, - x => panic!("Unexpected response: {:?}", x), - }; - - // Third, send kill for process - // NOTE: We cannot let the state get dropped as it results in killing - // the child process automatically; so, we clone another reference here - 
let req = Request::new("test-tenant", vec![RequestData::ProcKill { id }]); - let res = session.send(req).await.unwrap(); - match &res.payload[0] { - ResponseData::Ok => {} - x => panic!("Unexpected response: {:?}", x), - } - - // Fourth, verify that the process completes - let res = mailbox.next().await.unwrap(); - match &res.payload[0] { - ResponseData::ProcDone { success, .. } => { - assert!(!success, "Process should not have completed successfully"); - } - x => panic!("Unexpected response: {:?}", x), - } -} - -#[rstest] -#[tokio::test] -async fn proc_stdin_should_send_error_on_failure(#[future] session: Session) { - let mut session = session.await; - // Send stdin to a non-existent process - let req = Request::new( - "test-tenant", - vec![RequestData::ProcStdin { - id: 0xDEADBEEF, - data: b"some input".to_vec(), - }], - ); - - let res = session.send(req).await.unwrap(); - assert_eq!(res.payload.len(), 1, "Wrong payload size"); - - // Verify that we get an error - assert!( - matches!(res.payload[0], ResponseData::Error(_)), - "Unexpected response: {:?}", - res.payload[0] - ); -} - -// NOTE: Ignoring on windows because it's using WSL which wants a Linux path -// with / but thinks it's on windows and is providing \ -#[rstest] -#[tokio::test] -#[cfg_attr(windows, ignore)] -async fn proc_stdin_should_send_ok_on_success_and_properly_send_stdin_to_process( - #[future] session: Session, -) { - let mut session = session.await; - - // First, run a program that listens for stdin - let req = Request::new( - "test-tenant", - vec![RequestData::ProcSpawn { - cmd: SCRIPT_RUNNER.to_string(), - args: vec![ECHO_STDIN_TO_STDOUT_SH.to_str().unwrap().to_string()], - persist: false, - pty: None, - }], - ); - let mut mailbox = session.mail(req).await.unwrap(); - - let res = mailbox.next().await.unwrap(); - assert_eq!(res.payload.len(), 1, "Wrong payload size"); - - // Second, grab the id of the started process - let id = match &res.payload[0] { - ResponseData::ProcSpawned { id } => 
*id, - x => panic!("Unexpected response: {:?}", x), - }; - - // Third, send stdin to the remote process - // NOTE: We cannot let the state get dropped as it results in killing - // the child process; so, we clone another reference here - let req = Request::new( - "test-tenant", - vec![RequestData::ProcStdin { - id, - data: b"hello world\n".to_vec(), - }], - ); - let res = session.send(req).await.unwrap(); - match &res.payload[0] { - ResponseData::Ok => {} - x => panic!("Unexpected response: {:?}", x), - } - - // Fourth, gather an indirect response that is stdout from echoing our stdin - let res = mailbox.next().await.unwrap(); - match &res.payload[0] { - ResponseData::ProcStdout { data, .. } => { - assert_eq!(data, b"hello world\n", "Mirrored data didn't match"); - } - x => panic!("Unexpected response: {:?}", x), - } -} - -#[rstest] -#[tokio::test] -async fn proc_list_should_send_proc_entry_list(#[future] session: Session) { - let mut session = session.await; - let req = Request::new( - "test-tenant", - vec![RequestData::ProcSpawn { - cmd: SCRIPT_RUNNER.to_string(), - args: vec![SLEEP_SH.to_str().unwrap().to_string(), String::from("10")], - persist: false, - pty: None, - }], - ); - - let res = session.send(req).await.unwrap(); - assert_eq!(res.payload.len(), 1, "Wrong payload size"); - - // Grab the id of the started process - let id = match &res.payload[0] { - ResponseData::ProcSpawned { id } => *id, - x => panic!("Unexpected response: {:?}", x), - }; - - let req = Request::new("test-tenant", vec![RequestData::ProcList {}]); - - let res = session.send(req).await.unwrap(); - assert_eq!(res.payload.len(), 1, "Wrong payload size"); - - // Verify our process shows up in our entry list - assert_eq!( - res.payload[0], - ResponseData::ProcEntries { - entries: vec![RunningProcess { - cmd: SCRIPT_RUNNER.to_string(), - args: vec![SLEEP_SH.to_str().unwrap().to_string(), String::from("10")], - persist: false, - pty: None, - id, - }], - }, - "Unexpected response: {:?}", - 
res.payload[0] - ); -} - -#[rstest] -#[tokio::test] -async fn system_info_should_send_system_info_based_on_binary(#[future] session: Session) { - let mut session = session.await; - - // Figure out what SFTP's realpath(.) would resolve to - let res = session - .send(Request::new( - "test-tenant", - vec![RequestData::Metadata { - path: PathBuf::from("."), - canonicalize: true, - resolve_file_type: false, - }], - )) - .await - .unwrap(); - let current_dir = if let ResponseData::Metadata(Metadata { - canonicalized_path, .. - }) = &res.payload[0] - { - canonicalized_path - .as_deref() - .expect("Missing canonicalized path") - .to_path_buf() - } else { - panic!("Failed to get metadata for '.'") - }; - - let req = Request::new("test-tenant", vec![RequestData::SystemInfo {}]); - let res = session.send(req).await.unwrap(); - assert_eq!(res.payload.len(), 1, "Wrong payload size"); - assert_eq!( - res.payload[0], - ResponseData::SystemInfo(SystemInfo { - family: env::consts::FAMILY.to_string(), - os: "".to_string(), - arch: "".to_string(), - current_dir, - main_separator: std::path::MAIN_SEPARATOR, - }), - "Unexpected response: {:?}", - res.payload[0] - ); -} - -#[rstest] -#[tokio::test] -async fn watch_should_fail_as_unsupported(#[future] session: Session) { - let mut session = session.await; - - let req = Request::new( - "test-tenant", - vec![RequestData::Watch { - path: PathBuf::from("/some/path"), - recursive: true, - only: Default::default(), - except: Default::default(), - }], - ); - let res = session.send(req).await.unwrap(); - - assert_eq!(res.payload.len(), 1, "Wrong payload size"); - match &res.payload[0] { - ResponseData::Error(x) => { - assert_eq!(x.to_string(), "Other: Unsupported"); - } - x => panic!("Unexpected response: {:?}", x), - } -} - -#[rstest] -#[tokio::test] -async fn unwatch_should_fail_as_unsupported(#[future] session: Session) { - let mut session = session.await; - - let req = Request::new( - "test-tenant", - vec![RequestData::Unwatch { - path: 
PathBuf::from("/some/path"), - }], - ); - let res = session.send(req).await.unwrap(); - - assert_eq!(res.payload.len(), 1, "Wrong payload size"); - match &res.payload[0] { - ResponseData::Error(x) => { - assert_eq!(x.to_string(), "Other: Unsupported"); - } - x => panic!("Unexpected response: {:?}", x), - } -} diff --git a/distant-ssh2/tests/ssh2/ssh.rs b/distant-ssh2/tests/ssh2/ssh.rs new file mode 100644 index 0000000..463b678 --- /dev/null +++ b/distant-ssh2/tests/ssh2/ssh.rs @@ -0,0 +1,19 @@ +use crate::sshd::*; +use distant_ssh2::{Ssh, SshFamily}; +use rstest::*; + +#[rstest] +#[tokio::test] +async fn detect_family_should_return_windows_if_sshd_on_windows(#[future] ssh: Ctx) { + let ssh = ssh.await; + let family = ssh.detect_family().await.expect("Failed to detect family"); + assert_eq!( + family, + if cfg!(windows) { + SshFamily::Windows + } else { + SshFamily::Unix + }, + "Got wrong family" + ); +} diff --git a/distant-ssh2/tests/sshd.rs b/distant-ssh2/tests/sshd.rs deleted file mode 100644 index d1f3ea0..0000000 --- a/distant-ssh2/tests/sshd.rs +++ /dev/null @@ -1,436 +0,0 @@ -use assert_fs::{prelude::*, TempDir}; -use distant_core::Session; -use distant_ssh2::{Ssh2AuthHandler, Ssh2Session, Ssh2SessionOpts}; -use once_cell::sync::{Lazy, OnceCell}; -use rstest::*; -use std::{ - collections::HashMap, - fmt, io, - path::Path, - process::{Child, Command}, - sync::atomic::{AtomicU16, Ordering}, - thread, - time::Duration, -}; - -#[cfg(unix)] -use std::os::unix::fs::PermissionsExt; - -/// NOTE: OpenSSH's sshd requires absolute path -const BIN_PATH_STR: &str = "/usr/sbin/sshd"; - -/// Port range to use when finding a port to bind to (using IANA guidance) -const PORT_RANGE: (u16, u16) = (49152, 65535); - -static USERNAME: Lazy = Lazy::new(whoami::username); - -pub struct SshKeygen; - -impl SshKeygen { - // ssh-keygen -t rsa -f $ROOT/id_rsa -N "" -q - pub fn generate_rsa(path: impl AsRef, passphrase: impl AsRef) -> io::Result { - let res = Command::new("ssh-keygen") 
- .args(&["-m", "PEM"]) - .args(&["-t", "rsa"]) - .arg("-f") - .arg(path.as_ref()) - .arg("-N") - .arg(passphrase.as_ref()) - .arg("-q") - .status() - .map(|status| status.success())?; - - #[cfg(unix)] - if res { - // chmod 600 id_rsa* -> ida_rsa + ida_rsa.pub - std::fs::metadata(path.as_ref().with_extension("pub"))? - .permissions() - .set_mode(0o600); - std::fs::metadata(path)?.permissions().set_mode(0o600); - } - - Ok(res) - } -} - -pub struct SshAgent; - -impl SshAgent { - pub fn generate_shell_env() -> io::Result> { - let output = Command::new("ssh-agent").arg("-s").output()?; - let stdout = String::from_utf8(output.stdout) - .map_err(|x| io::Error::new(io::ErrorKind::InvalidData, x))?; - Ok(stdout - .split(';') - .map(str::trim) - .filter(|s| s.contains('=')) - .map(|s| { - let mut tokens = s.split('='); - let key = tokens.next().unwrap().trim().to_string(); - let rest = tokens - .map(str::trim) - .map(ToString::to_string) - .collect::>() - .join("="); - (key, rest) - }) - .collect::>()) - } - - pub fn update_tests_with_shell_env() -> io::Result<()> { - let env_map = Self::generate_shell_env()?; - for (key, value) in env_map { - std::env::set_var(key, value); - } - - Ok(()) - } -} - -#[derive(Debug)] -pub struct SshdConfig(HashMap>); - -impl Default for SshdConfig { - fn default() -> Self { - let mut config = Self::new(); - - config.set_authentication_methods(vec!["publickey".to_string()]); - config.set_use_privilege_separation(false); - config.set_subsystem(true, true); - config.set_use_pam(false); - config.set_x11_forwarding(true); - config.set_print_motd(true); - config.set_permit_tunnel(true); - config.set_kbd_interactive_authentication(true); - config.set_allow_tcp_forwarding(true); - config.set_max_startups(500, None); - config.set_strict_modes(false); - - config - } -} - -impl SshdConfig { - pub fn new() -> Self { - Self(HashMap::new()) - } - - pub fn set_authentication_methods(&mut self, methods: Vec) { - 
self.0.insert("AuthenticationMethods".to_string(), methods); - } - - pub fn set_authorized_keys_file(&mut self, path: impl AsRef) { - self.0.insert( - "AuthorizedKeysFile".to_string(), - vec![path.as_ref().to_string_lossy().to_string()], - ); - } - - pub fn set_host_key(&mut self, path: impl AsRef) { - self.0.insert( - "HostKey".to_string(), - vec![path.as_ref().to_string_lossy().to_string()], - ); - } - - pub fn set_pid_file(&mut self, path: impl AsRef) { - self.0.insert( - "PidFile".to_string(), - vec![path.as_ref().to_string_lossy().to_string()], - ); - } - - pub fn set_subsystem(&mut self, sftp: bool, internal_sftp: bool) { - let mut values = Vec::new(); - if sftp { - values.push("sftp".to_string()); - } - if internal_sftp { - values.push("internal-sftp".to_string()); - } - - self.0.insert("Subsystem".to_string(), values); - } - - pub fn set_use_pam(&mut self, yes: bool) { - self.0.insert("UsePAM".to_string(), Self::yes_value(yes)); - } - - pub fn set_x11_forwarding(&mut self, yes: bool) { - self.0 - .insert("X11Forwarding".to_string(), Self::yes_value(yes)); - } - - pub fn set_use_privilege_separation(&mut self, yes: bool) { - self.0 - .insert("UsePrivilegeSeparation".to_string(), Self::yes_value(yes)); - } - - pub fn set_print_motd(&mut self, yes: bool) { - self.0.insert("PrintMotd".to_string(), Self::yes_value(yes)); - } - - pub fn set_permit_tunnel(&mut self, yes: bool) { - self.0 - .insert("PermitTunnel".to_string(), Self::yes_value(yes)); - } - - pub fn set_kbd_interactive_authentication(&mut self, yes: bool) { - self.0.insert( - "KbdInteractiveAuthentication".to_string(), - Self::yes_value(yes), - ); - } - - pub fn set_allow_tcp_forwarding(&mut self, yes: bool) { - self.0 - .insert("AllowTcpForwarding".to_string(), Self::yes_value(yes)); - } - - pub fn set_max_startups(&mut self, start: u16, rate_full: Option<(u16, u16)>) { - let value = format!( - "{}{}", - start, - rate_full - .map(|(r, f)| format!(":{}:{}", r, f)) - .unwrap_or_default(), - ); - - 
self.0.insert("MaxStartups".to_string(), vec![value]); - } - - pub fn set_strict_modes(&mut self, yes: bool) { - self.0 - .insert("StrictModes".to_string(), Self::yes_value(yes)); - } - - fn yes_value(yes: bool) -> Vec { - vec![Self::yes_string(yes)] - } - - fn yes_string(yes: bool) -> String { - Self::yes_str(yes).to_string() - } - - const fn yes_str(yes: bool) -> &'static str { - if yes { - "yes" - } else { - "no" - } - } -} - -impl fmt::Display for SshdConfig { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - for (keyword, values) in self.0.iter() { - writeln!( - f, - "{} {}", - keyword, - values - .iter() - .map(|v| { - let v = v.trim(); - if v.contains(|c: char| c.is_whitespace()) { - format!("\"{}\"", v) - } else { - v.to_string() - } - }) - .collect::>() - .join(" ") - )?; - } - Ok(()) - } -} - -/// Context for some sshd instance -pub struct Sshd { - child: Child, - - /// Port that sshd is listening on - pub port: u16, - - /// Temporary directory used to hold resources for sshd such as its config, keys, and log - pub tmp: TempDir, -} - -impl Sshd { - pub fn spawn(mut config: SshdConfig) -> Result> { - let tmp = TempDir::new()?; - - // Ensure that everything needed for interacting with ssh-agent is set - SshAgent::update_tests_with_shell_env()?; - - // ssh-keygen -t rsa -f $ROOT/id_rsa -N "" -q - let id_rsa_file = tmp.child("id_rsa"); - assert!( - SshKeygen::generate_rsa(id_rsa_file.path(), "")?, - "Failed to ssh-keygen id_rsa" - ); - - // cp $ROOT/id_rsa.pub $ROOT/authorized_keys - let authorized_keys_file = tmp.child("authorized_keys"); - std::fs::copy( - id_rsa_file.path().with_extension("pub"), - authorized_keys_file.path(), - )?; - - // ssh-keygen -t rsa -f $ROOT/ssh_host_rsa_key -N "" -q - let ssh_host_rsa_key_file = tmp.child("ssh_host_rsa_key"); - assert!( - SshKeygen::generate_rsa(ssh_host_rsa_key_file.path(), "")?, - "Failed to ssh-keygen ssh_host_rsa_key" - ); - - 
config.set_authorized_keys_file(id_rsa_file.path().with_extension("pub")); - config.set_host_key(ssh_host_rsa_key_file.path()); - - let sshd_pid_file = tmp.child("sshd.pid"); - config.set_pid_file(sshd_pid_file.path()); - - // Generate $ROOT/sshd_config based on config - let sshd_config_file = tmp.child("sshd_config"); - sshd_config_file.write_str(&config.to_string())?; - - let sshd_log_file = tmp.child("sshd.log"); - - let (child, port) = Self::try_spawn_next(sshd_config_file.path(), sshd_log_file.path()) - .expect("No open port available for sshd"); - - Ok(Self { child, port, tmp }) - } - - fn try_spawn_next( - config_path: impl AsRef, - log_path: impl AsRef, - ) -> io::Result<(Child, u16)> { - static PORT: AtomicU16 = AtomicU16::new(PORT_RANGE.0); - - loop { - let port = PORT.fetch_add(1, Ordering::Relaxed); - - match Self::try_spawn(port, config_path.as_ref(), log_path.as_ref()) { - // If successful, return our spawned server child process - Ok(Ok(child)) => break Ok((child, port)), - - // If the server died when spawned and we reached the final port, we want to exit - Ok(Err((code, msg))) if port == PORT_RANGE.1 => { - break Err(io::Error::new( - io::ErrorKind::Other, - format!( - "{} failed [{}]: {}", - BIN_PATH_STR, - code.map(|x| x.to_string()) - .unwrap_or_else(|| String::from("???")), - msg - ), - )) - } - - // If we've reached the final port in our range to try, we want to exit - Err(x) if port == PORT_RANGE.1 => break Err(x), - - // Otherwise, try next port - Err(_) | Ok(Err(_)) => continue, - } - } - } - - fn try_spawn( - port: u16, - config_path: impl AsRef, - log_path: impl AsRef, - ) -> io::Result, String)>> { - let mut child = Command::new(BIN_PATH_STR) - .arg("-D") - .arg("-p") - .arg(port.to_string()) - .arg("-f") - .arg(config_path.as_ref()) - .arg("-E") - .arg(log_path.as_ref()) - .spawn()?; - - // Pause for a little bit to make sure that the server didn't die due to an error - thread::sleep(Duration::from_millis(100)); - - if let 
Some(exit_status) = child.try_wait()? { - let output = child.wait_with_output()?; - Ok(Err(( - exit_status.code(), - format!( - "{}\n{}", - String::from_utf8(output.stdout).unwrap(), - String::from_utf8(output.stderr).unwrap(), - ), - ))) - } else { - Ok(Ok(child)) - } - } -} - -impl Drop for Sshd { - /// Kills server upon drop - fn drop(&mut self) { - let _ = self.child.kill(); - } -} - -#[fixture] -pub fn logger() -> &'static flexi_logger::LoggerHandle { - static LOGGER: OnceCell = OnceCell::new(); - - LOGGER.get_or_init(|| { - // flexi_logger::Logger::try_with_str("off, distant_core=trace, distant_ssh2=trace") - flexi_logger::Logger::try_with_str("off, distant_core=warn, distant_ssh2=warn") - .expect("Failed to load env") - .start() - .expect("Failed to start logger") - }) -} - -#[fixture] -pub fn sshd() -> &'static Sshd { - static SSHD: OnceCell = OnceCell::new(); - - SSHD.get_or_init(|| Sshd::spawn(Default::default()).unwrap()) -} - -#[fixture] -pub async fn session(sshd: &'_ Sshd, _logger: &'_ flexi_logger::LoggerHandle) -> Session { - let port = sshd.port; - - let mut ssh2_session = Ssh2Session::connect( - "127.0.0.1", - Ssh2SessionOpts { - port: Some(port), - identity_files: vec![sshd.tmp.child("id_rsa").path().to_path_buf()], - identities_only: Some(true), - user: Some(USERNAME.to_string()), - user_known_hosts_files: vec![sshd.tmp.child("known_hosts").path().to_path_buf()], - ..Default::default() - }, - ) - .unwrap(); - - ssh2_session - .authenticate(Ssh2AuthHandler { - on_authenticate: Box::new(|ev| { - println!("on_authenticate: {:?}", ev); - Ok(vec![String::new(); ev.prompts.len()]) - }), - on_host_verify: Box::new(|host| { - println!("on_host_verify: {}", host); - Ok(true) - }), - ..Default::default() - }) - .await - .unwrap(); - - ssh2_session.into_ssh_client_session().await.unwrap() -} diff --git a/distant-ssh2/tests/sshd/mod.rs b/distant-ssh2/tests/sshd/mod.rs new file mode 100644 index 0000000..ee5ba39 --- /dev/null +++ 
b/distant-ssh2/tests/sshd/mod.rs @@ -0,0 +1,677 @@ +use crate::utils::ci_path_to_string; +use anyhow::Context; +use assert_fs::{prelude::*, TempDir}; +use async_trait::async_trait; +use derive_more::Display; +use derive_more::{Deref, DerefMut}; +use distant_core::DistantClient; +use distant_ssh2::{DistantLaunchOpts, Ssh, SshAuthEvent, SshAuthHandler, SshOpts}; +use once_cell::sync::{Lazy, OnceCell}; +use rstest::*; +use std::{ + collections::HashMap, + fmt, io, + net::{IpAddr, Ipv4Addr, Ipv6Addr}, + path::{Path, PathBuf}, + process::{Child, Command}, + sync::{ + atomic::{AtomicU16, Ordering}, + Mutex, + }, + thread, + time::Duration, +}; + +#[cfg(unix)] +use std::os::unix::fs::PermissionsExt; + +#[derive(Deref, DerefMut)] +pub struct Ctx { + pub sshd: Sshd, + + #[deref] + #[deref_mut] + pub value: T, +} + +// NOTE: Should find path +// +// Unix should be something like /usr/sbin/sshd +// Windows should be something like C:\Windows\System32\OpenSSH\sshd.exe +static BIN_PATH: Lazy = + Lazy::new(|| which::which(if cfg!(windows) { "sshd.exe" } else { "sshd" }).unwrap()); + +/// Port range to use when finding a port to bind to (using IANA guidance) +const PORT_RANGE: (u16, u16) = (49152, 65535); + +static USERNAME: Lazy = Lazy::new(whoami::username); + +pub struct SshKeygen; + +impl SshKeygen { + // ssh-keygen -t ed25519 -f $ROOT/id_ed25519 -N "" -q + pub fn generate_ed25519( + path: impl AsRef, + passphrase: impl AsRef, + ) -> anyhow::Result { + let res = Command::new("ssh-keygen") + .args(&["-m", "PEM"]) + .args(&["-t", "ed25519"]) + .arg("-f") + .arg(path.as_ref()) + .arg("-N") + .arg(passphrase.as_ref()) + .arg("-q") + .status() + .map(|status| status.success()) + .context("Failed to generate ed25519 key")?; + + #[cfg(unix)] + if res { + // chmod 600 id_ed25519* -> ida_ed25519 + ida_ed25519.pub + std::fs::metadata(path.as_ref().with_extension("pub")) + .context("Failed to load metadata of ed25519 pub key")? 
+ .permissions() + .set_mode(0o600); + std::fs::metadata(path) + .context("Failed to load metadata of ed25519 key")? + .permissions() + .set_mode(0o600); + } + + Ok(res) + } +} + +pub struct SshAgent; + +impl SshAgent { + pub fn generate_shell_env() -> anyhow::Result> { + let output = Command::new("ssh-agent") + .arg("-s") + .output() + .context("Failed to generate Bourne shell commands from ssh-agent")?; + let stdout = + String::from_utf8(output.stdout).context("Failed to parse stdout as utf8 string")?; + Ok(stdout + .split(';') + .map(str::trim) + .filter(|s| s.contains('=')) + .map(|s| { + let mut tokens = s.split('='); + let key = tokens.next().unwrap().trim().to_string(); + let rest = tokens + .map(str::trim) + .map(ToString::to_string) + .collect::>() + .join("="); + (key, rest) + }) + .collect::>()) + } + + pub fn update_tests_with_shell_env() -> anyhow::Result<()> { + let env_map = + Self::generate_shell_env().context("Failed to generate ssh agent shell env")?; + for (key, value) in env_map { + std::env::set_var(key, value); + } + + Ok(()) + } +} + +/// Log level for sshd config +#[allow(dead_code)] +#[derive(Copy, Clone, Debug, Display, PartialEq, Eq, Hash)] +pub enum SshdLogLevel { + #[display(fmt = "QUIET")] + Quiet, + #[display(fmt = "FATAL")] + Fatal, + #[display(fmt = "ERROR")] + Error, + #[display(fmt = "INFO")] + Info, + #[display(fmt = "VERBOSE")] + Verbose, + #[display(fmt = "DEBUG")] + Debug, + #[display(fmt = "DEBUG1")] + Debug1, + #[display(fmt = "DEBUG2")] + Debug2, + #[display(fmt = "DEBUG3")] + Debug3, +} + +#[derive(Debug)] +pub struct SshdConfig(HashMap>); + +impl Default for SshdConfig { + fn default() -> Self { + let mut config = Self::new(); + + config.set_authentication_methods(vec!["publickey".to_string()]); + config.set_use_privilege_separation(false); + config.set_subsystem(true, true); + config.set_use_pam(false); + config.set_x11_forwarding(true); + config.set_print_motd(true); + config.set_permit_tunnel(true); + 
config.set_kbd_interactive_authentication(true); + config.set_allow_tcp_forwarding(true); + config.set_max_startups(500, None); + config.set_strict_modes(false); + config.set_log_level(SshdLogLevel::Debug3); + + config + } +} + +impl SshdConfig { + pub fn new() -> Self { + Self(HashMap::new()) + } + + pub fn set_authentication_methods(&mut self, methods: Vec) { + self.0.insert("AuthenticationMethods".to_string(), methods); + } + + pub fn set_authorized_keys_file(&mut self, path: impl AsRef) { + let path = ci_path_to_string(path.as_ref()); + + self.0.insert("AuthorizedKeysFile".to_string(), vec![path]); + } + + pub fn set_host_key(&mut self, path: impl AsRef) { + let path = ci_path_to_string(path.as_ref()); + + self.0.insert("HostKey".to_string(), vec![path]); + } + + pub fn set_pid_file(&mut self, path: impl AsRef) { + let path = ci_path_to_string(path.as_ref()); + + self.0.insert("PidFile".to_string(), vec![path]); + } + + pub fn set_subsystem(&mut self, sftp: bool, internal_sftp: bool) { + let mut values = Vec::new(); + if sftp { + values.push("sftp".to_string()); + } + if internal_sftp { + values.push("internal-sftp".to_string()); + } + + self.0.insert("Subsystem".to_string(), values); + } + + pub fn set_use_pam(&mut self, yes: bool) { + self.0.insert("UsePAM".to_string(), Self::yes_value(yes)); + } + + pub fn set_x11_forwarding(&mut self, yes: bool) { + self.0 + .insert("X11Forwarding".to_string(), Self::yes_value(yes)); + } + + pub fn set_use_privilege_separation(&mut self, yes: bool) { + self.0 + .insert("UsePrivilegeSeparation".to_string(), Self::yes_value(yes)); + } + + pub fn set_print_motd(&mut self, yes: bool) { + self.0.insert("PrintMotd".to_string(), Self::yes_value(yes)); + } + + pub fn set_permit_tunnel(&mut self, yes: bool) { + self.0 + .insert("PermitTunnel".to_string(), Self::yes_value(yes)); + } + + pub fn set_kbd_interactive_authentication(&mut self, yes: bool) { + self.0.insert( + "KbdInteractiveAuthentication".to_string(), + 
Self::yes_value(yes), + ); + } + + pub fn set_allow_tcp_forwarding(&mut self, yes: bool) { + self.0 + .insert("AllowTcpForwarding".to_string(), Self::yes_value(yes)); + } + + pub fn set_max_startups(&mut self, start: u16, rate_full: Option<(u16, u16)>) { + let value = format!( + "{}{}", + start, + rate_full + .map(|(r, f)| format!(":{}:{}", r, f)) + .unwrap_or_default(), + ); + + self.0.insert("MaxStartups".to_string(), vec![value]); + } + + pub fn set_strict_modes(&mut self, yes: bool) { + self.0 + .insert("StrictModes".to_string(), Self::yes_value(yes)); + } + + pub fn set_log_level(&mut self, log_level: SshdLogLevel) { + self.0 + .insert("LogLevel".to_string(), vec![log_level.to_string()]); + } + + fn yes_value(yes: bool) -> Vec { + vec![Self::yes_string(yes)] + } + + fn yes_string(yes: bool) -> String { + Self::yes_str(yes).to_string() + } + + const fn yes_str(yes: bool) -> &'static str { + if yes { + "yes" + } else { + "no" + } + } +} + +impl fmt::Display for SshdConfig { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + for (keyword, values) in self.0.iter() { + writeln!( + f, + "{} {}", + keyword, + values + .iter() + .map(|v| { + let v = v.trim(); + if v.contains(|c: char| c.is_whitespace()) { + format!("\"{}\"", v) + } else { + v.to_string() + } + }) + .collect::>() + .join(" ") + )?; + } + Ok(()) + } +} + +/// Context for some sshd instance +pub struct Sshd { + child: Mutex>, + + /// Port that sshd is listening on + pub port: u16, + + /// Temporary directory used to hold resources for sshd such as its config, keys, and log + pub tmp: TempDir, + + /// Path to config file to print out when failures happen + pub config_file: PathBuf, + + /// Path to log file to print out when failures happen + pub log_file: PathBuf, +} + +impl Sshd { + /// Cached check if dead, does not actually do the check itself + pub fn is_dead(&self) -> bool { + self.child.lock().unwrap().is_none() + } + + pub fn spawn(mut config: SshdConfig) -> anyhow::Result { + let tmp = 
TempDir::new().context("Failed to create temporary directory")?; + + // Ensure that everything needed for interacting with ssh-agent is set + SshAgent::update_tests_with_shell_env() + .context("Failed to update tests with ssh agent shell env")?; + + // ssh-keygen -t ed25519 -f $ROOT/id_ed25519 -N "" -q + let id_ed25519_file = tmp.child("id_ed25519"); + assert!( + SshKeygen::generate_ed25519(id_ed25519_file.path(), "") + .context("Failed to generate ed25519 key for self")?, + "Failed to ssh-keygen id_ed25519" + ); + + // cp $ROOT/id_ed25519.pub $ROOT/authorized_keys + let authorized_keys_file = tmp.child("authorized_keys"); + std::fs::copy( + id_ed25519_file.path().with_extension("pub"), + authorized_keys_file.path(), + ) + .context("Failed to copy ed25519 pub key to authorized keys file")?; + + // ssh-keygen -t ed25519 -f $ROOT/ssh_host_ed25519_key -N "" -q + let ssh_host_ed25519_key_file = tmp.child("ssh_host_ed25519_key"); + assert!( + SshKeygen::generate_ed25519(ssh_host_ed25519_key_file.path(), "") + .context("Failed to generate ed25519 key for host")?, + "Failed to ssh-keygen ssh_host_ed25519_key" + ); + + config.set_authorized_keys_file(id_ed25519_file.path().with_extension("pub")); + config.set_host_key(ssh_host_ed25519_key_file.path()); + + let sshd_pid_file = tmp.child("sshd.pid"); + config.set_pid_file(sshd_pid_file.path()); + + // Generate $ROOT/sshd_config based on config + let sshd_config_file = tmp.child("sshd_config"); + sshd_config_file + .write_str(&config.to_string()) + .context("Failed to write sshd config to file")?; + + let sshd_log_file = tmp.child("sshd.log"); + + let (child, port) = Self::try_spawn_next(sshd_config_file.path(), sshd_log_file.path()) + .context("Failed to find open port for sshd")?; + + Ok(Self { + child: Mutex::new(Some(child)), + port, + tmp, + config_file: sshd_config_file.to_path_buf(), + log_file: sshd_log_file.to_path_buf(), + }) + } + + fn try_spawn_next( + config_path: impl AsRef, + log_path: impl AsRef, + ) -> 
anyhow::Result<(Child, u16)> { + static PORT: AtomicU16 = AtomicU16::new(PORT_RANGE.0); + + loop { + let port = PORT.fetch_add(1, Ordering::Relaxed); + + match Self::try_spawn(port, config_path.as_ref(), log_path.as_ref()) { + // If successful, return our spawned server child process + Ok(Ok(child)) => return Ok((child, port)), + + // If the server died when spawned and we reached the final port, we want to exit + Ok(Err((code, msg))) if port == PORT_RANGE.1 => { + anyhow::bail!( + "{BIN_PATH:?} failed [{}]: {}", + code.map(|x| x.to_string()) + .unwrap_or_else(|| String::from("???")), + msg + ) + } + + // If we've reached the final port in our range to try, we want to exit + Err(x) if port == PORT_RANGE.1 => anyhow::bail!(x), + + // Otherwise, try next port + Err(_) | Ok(Err(_)) => continue, + } + } + } + + fn try_spawn( + port: u16, + config_path: impl AsRef, + log_path: impl AsRef, + ) -> anyhow::Result, String)>> { + let child = Command::new(BIN_PATH.as_path()) + .arg("-D") + .arg("-p") + .arg(port.to_string()) + .arg("-f") + .arg(config_path.as_ref()) + .arg("-E") + .arg(log_path.as_ref()) + .spawn() + .with_context(|| format!("Failed to spawn {:?}", BIN_PATH.as_path()))?; + + // Pause for a little bit to make sure that the server didn't die due to an error + thread::sleep(Duration::from_millis(100)); + + let child = match check(child).context("Sshd encountered problems (after 100ms)")? 
{ + Ok(child) => child, + Err(x) => return Ok(Err(x)), + }; + + // Pause for a little bit to make sure that the server didn't die due to an error + thread::sleep(Duration::from_millis(100)); + + let result = check(child).context("Sshd encountered problems (after 200ms)")?; + Ok(result) + } +} + +impl Drop for Sshd { + /// Kills server upon drop + fn drop(&mut self) { + if let Some(mut child) = self.child.lock().unwrap().take() { + let _ = child.kill(); + let _ = child.wait(); + } + } +} + +#[fixture] +pub fn logger() -> &'static flexi_logger::LoggerHandle { + static LOGGER: OnceCell = OnceCell::new(); + + LOGGER.get_or_init(|| { + // flexi_logger::Logger::try_with_str("off, distant_core=trace, distant_ssh2=trace") + flexi_logger::Logger::try_with_str("off, distant_core=warn, distant_ssh2=warn") + .expect("Failed to load env") + .start() + .expect("Failed to start logger") + }) +} + +/// Mocked version of [`SshAuthHandler`] +pub struct MockSshAuthHandler; + +#[async_trait] +impl SshAuthHandler for MockSshAuthHandler { + async fn on_authenticate(&self, event: SshAuthEvent) -> io::Result> { + println!("on_authenticate: {:?}", event); + Ok(vec![String::new(); event.prompts.len()]) + } + + async fn on_verify_host(&self, host: &str) -> io::Result { + println!("on_host_verify: {}", host); + Ok(true) + } + + async fn on_banner(&self, _text: &str) {} + + async fn on_error(&self, _text: &str) {} +} + +#[fixture] +pub fn sshd() -> Sshd { + Sshd::spawn(Default::default()).expect("Failed to spawn sshd") +} + +/// Fixture to establish a client to an SSH server +#[fixture] +pub async fn client(sshd: Sshd, _logger: &'_ flexi_logger::LoggerHandle) -> Ctx { + let ssh_client = load_ssh_client(&sshd).await; + let client = ssh_client + .into_distant_client() + .await + .context("Failed to convert into distant client") + .unwrap(); + Ctx { + sshd, + value: client, + } +} + +/// Fixture to establish a client to a launched server +#[fixture] +pub async fn launched_client( + sshd: Sshd, + 
_logger: &'_ flexi_logger::LoggerHandle, +) -> Ctx { + let binary = std::env::var("DISTANT_PATH").unwrap_or_else(|_| String::from("distant")); + eprintln!("Setting path to distant binary as {binary}"); + + // Attempt to launch the server and connect to it, using $DISTANT_PATH as the path to the + // binary if provided, defaulting to assuming the binary is on our ssh path otherwise + let ssh_client = load_ssh_client(&sshd).await; + let client = ssh_client + .launch_and_connect(DistantLaunchOpts { + binary, + ..Default::default() + }) + .await + .context("Failed to launch and connect to distant server") + .unwrap(); + + // TODO: Wrapping in ctx does not fully clean up the test as the launched distant server + // is not cleaned up during drop. We don't know what the server's pid is, so our + // only option would be to look up all running distant servers and kill them on drop, + // but that would cause other tests to fail. + // + // Setting an expiration of 1s would clean up running servers and possibly be good enough + Ctx { + sshd, + value: client, + } +} + +/// Access to raw [`Ssh`] client +#[fixture] +pub async fn ssh(sshd: Sshd) -> Ctx { + let ssh = load_ssh_client(&sshd).await; + Ctx { sshd, value: ssh } +} + +async fn load_ssh_client(sshd: &Sshd) -> Ssh { + if sshd.is_dead() { + panic!("sshd is dead!"); + } + + let port = sshd.port; + let opts = SshOpts { + port: Some(port), + identity_files: vec![sshd.tmp.child("id_ed25519").path().to_path_buf()], + identities_only: Some(true), + user: Some(USERNAME.to_string()), + user_known_hosts_files: vec![sshd.tmp.child("known_hosts").path().to_path_buf()], + // verbose: true, + ..Default::default() + }; + + let addrs = vec![ + IpAddr::V4(Ipv4Addr::LOCALHOST), + IpAddr::V6(Ipv6Addr::LOCALHOST), + ]; + let mut errors = Vec::new(); + let msg = format!("Failed to connect to any of these hosts: {addrs:?}"); + + for addr in addrs { + let addr_string = addr.to_string(); + match Ssh::connect(&addr_string, opts.clone()) { + Ok(mut 
ssh_client) => { + let res = ssh_client.authenticate(MockSshAuthHandler).await; + + match res { + Ok(_) => return ssh_client, + Err(x) => { + errors.push( + anyhow::Error::new(x).context(format!( + "Failed to authenticate with sshd @ {addr_string}" + )), + ); + } + } + } + Err(x) => { + errors.push( + anyhow::Error::new(x) + .context(format!("Failed to connect to sshd @ {addr_string}")), + ); + } + } + } + + // We want to print out the log file from sshd in case it sheds clues on problem + if let Ok(log) = std::fs::read_to_string(&sshd.log_file) { + eprintln!(); + eprintln!("===================="); + eprintln!("= SSHD LOG FILE "); + eprintln!("===================="); + eprintln!(); + eprintln!("{log}"); + eprintln!(); + eprintln!("===================="); + eprintln!(); + } + + // We want to print out the config file from sshd in case it sheds clues on problem + if let Ok(contents) = std::fs::read_to_string(&sshd.config_file) { + eprintln!(); + eprintln!("===================="); + eprintln!("= SSHD CONFIG FILE "); + eprintln!("===================="); + eprintln!(); + eprintln!("{contents}"); + eprintln!(); + eprintln!("===================="); + eprintln!(); + } + + // Check if our sshd process is still running, or if it died and we can report about it + let mut child_lock = sshd.child.lock().unwrap(); + if let Some(child) = child_lock.take() { + match check(child) { + Ok(Ok(child)) => { + eprintln!("sshd is still alive, so something else is going on"); + child_lock.replace(child); + } + Ok(Err((code, msg))) => eprintln!( + "sshd died w/ exit code {}: {msg}", + if let Some(code) = code { + code.to_string() + } else { + "[missing]".to_string() + } + ), + Err(x) => eprintln!("Failed to check status of sshd: {x}"), + } + } else { + eprintln!("sshd is dead!"); + } + drop(child_lock); + + let error = match errors.into_iter().reduce(|x, y| x.context(y)) { + Some(x) => x.context(msg), + None => anyhow::anyhow!(msg), + }; + + panic!("{error:?}"); +} + +fn check(mut child: 
Child) -> anyhow::Result, String)>> { + if let Some(exit_status) = child.try_wait().context("Failed to check status of sshd")? { + let output = child.wait_with_output().context("Failed to wait on sshd")?; + Ok(Err(( + exit_status.code(), + format!( + "{}\n{}", + String::from_utf8(output.stdout).unwrap(), + String::from_utf8(output.stderr).unwrap(), + ), + ))) + } else { + Ok(Ok(child)) + } +} diff --git a/distant-ssh2/tests/utils/mod.rs b/distant-ssh2/tests/utils/mod.rs new file mode 100644 index 0000000..24515a3 --- /dev/null +++ b/distant-ssh2/tests/utils/mod.rs @@ -0,0 +1,36 @@ +use once_cell::sync::Lazy; +use std::path::{Component, Path, Prefix}; + +// Returns true if running test in Github CI +pub static IS_CI: Lazy = Lazy::new(|| std::env::var("CI").as_deref() == Ok("true")); + +pub fn ci_path_to_string(path: &Path) -> String { + if cfg!(windows) && *IS_CI { + convert_path_to_unix_string(path) + } else { + path.to_string_lossy().to_string() + } +} + +pub fn convert_path_to_unix_string(path: &Path) -> String { + let mut s = String::new(); + for component in path.components() { + s.push('/'); + + match component { + Component::Prefix(x) => match x.kind() { + Prefix::Verbatim(x) => s.push_str(&x.to_string_lossy()), + Prefix::VerbatimUNC(_, _) => unimplemented!(), + Prefix::VerbatimDisk(x) => s.push(x as char), + Prefix::DeviceNS(_) => unimplemented!(), + Prefix::UNC(_, _) => unimplemented!(), + Prefix::Disk(x) => s.push(x as char), + }, + Component::RootDir => continue, + Component::CurDir => s.push('.'), + Component::ParentDir => s.push_str(".."), + Component::Normal(x) => s.push_str(&x.to_string_lossy()), + } + } + s +} diff --git a/rustfmt.toml b/rustfmt.toml new file mode 100644 index 0000000..2b493f7 --- /dev/null +++ b/rustfmt.toml @@ -0,0 +1,11 @@ +max_width = 100 +newline_style = "unix" +indent_style = "Block" +use_field_init_shorthand = true + +# Unstable features +# unstable_features = true +# imports_granularity = "Crate" +# group_imports = 
"StdExternalCrate" +# reorder_impl_items = true +# normalize_doc_attributes = true diff --git a/src/cli.rs b/src/cli.rs new file mode 100644 index 0000000..bc54480 --- /dev/null +++ b/src/cli.rs @@ -0,0 +1,157 @@ +use crate::{ + config::{CommonConfig, Config}, + paths, CliResult, +}; +use clap::Parser; +use std::{ffi::OsString, path::PathBuf}; + +mod cache; +mod client; +mod commands; +mod manager; +mod spawner; + +pub(crate) use cache::Cache; +pub(crate) use client::Client; +use commands::DistantSubcommand; +pub(crate) use manager::Manager; + +#[cfg_attr(unix, allow(unused_imports))] +pub(crate) use spawner::Spawner; + +/// Represents the primary CLI entrypoint +pub struct Cli { + common: CommonConfig, + command: DistantSubcommand, + config: Config, +} + +#[derive(Debug, Parser)] +#[clap(author, version, about)] +#[clap(name = "distant")] +struct Opt { + #[clap(flatten)] + common: CommonConfig, + + /// Configuration file to load instead of the default paths + #[clap(short = 'c', long = "config", global = true, value_parser)] + config_path: Option, + + #[clap(subcommand)] + command: DistantSubcommand, +} + +impl Cli { + /// Creates a new CLI instance by parsing command-line arguments + pub fn initialize() -> anyhow::Result { + Self::initialize_from(std::env::args_os()) + } + + /// Creates a new CLI instance by parsing providing arguments + pub fn initialize_from(args: I) -> anyhow::Result + where + I: IntoIterator, + T: Into + Clone, + { + // NOTE: We should NOT provide context here as printing help and version are both + // reported this way and providing context puts them under the "caused by" section + let Opt { + mut common, + config_path, + command, + } = Opt::try_parse_from(args)?; + + // Try to load a configuration file, defaulting if no config file is found + let config = Config::load_multi(config_path)?; + + // Extract the common config from our config file + let config_common = match &command { + DistantSubcommand::Client(_) => 
config.client.common.clone(), + DistantSubcommand::Generate(_) => config.generate.common.clone(), + DistantSubcommand::Manager(_) => config.manager.common.clone(), + DistantSubcommand::Server(_) => config.server.common.clone(), + }; + + // Blend common configs together + common.log_file = common.log_file.or(config_common.log_file); + common.log_level = common.log_level.or(config_common.log_level); + + // Assign the appropriate log file based on client/manager/server + if common.log_file.is_none() { + // NOTE: We assume that any of these commands will log to the user-specific path + // and that services that run manager will explicitly override the + // log file path + common.log_file = Some(match &command { + DistantSubcommand::Client(_) => paths::user::CLIENT_LOG_FILE_PATH.to_path_buf(), + DistantSubcommand::Server(_) => paths::user::SERVER_LOG_FILE_PATH.to_path_buf(), + DistantSubcommand::Generate(_) => paths::user::GENERATE_LOG_FILE_PATH.to_path_buf(), + + // If we are listening as a manager, then we want to log to a manager-specific file + DistantSubcommand::Manager(cmd) if cmd.is_listen() => { + paths::user::MANAGER_LOG_FILE_PATH.to_path_buf() + } + + // Otherwise, if we are performing some operation as a client talking to the + // manager, then we want to log to the client file + DistantSubcommand::Manager(_) => paths::user::CLIENT_LOG_FILE_PATH.to_path_buf(), + }); + } + + Ok(Cli { + common, + command, + config, + }) + } + + /// Initializes a logger for the CLI, returning a handle to the logger + pub fn init_logger(&self) -> flexi_logger::LoggerHandle { + use flexi_logger::{FileSpec, LevelFilter, LogSpecification, Logger}; + let modules = &["distant", "distant_core", "distant_net", "distant_ssh2"]; + + // Disable logging for everything but our binary, which is based on verbosity + let mut builder = LogSpecification::builder(); + builder.default(LevelFilter::Off); + + // For each module, configure logging + for module in modules { + builder.module( + module, 
+ self.common + .log_level + .unwrap_or_default() + .to_log_level_filter(), + ); + } + + // Create our logger, but don't initialize yet + let logger = Logger::with(builder.build()).format_for_files(flexi_logger::opt_format); + + // Assign our log output to a file + // NOTE: We can unwrap here as we assign the log file earlier + let logger = logger.log_to_file( + FileSpec::try_from(self.common.log_file.as_ref().unwrap()) + .expect("Failed to create log file spec"), + ); + + logger.start().expect("Failed to initialize logger") + } + + #[cfg(windows)] + pub fn is_manager_listen_command(&self) -> bool { + match &self.command { + DistantSubcommand::Manager(cmd) => cmd.is_listen(), + _ => false, + } + } + + /// Runs the CLI + pub fn run(self) -> CliResult { + match self.command { + DistantSubcommand::Client(cmd) => cmd.run(self.config.client), + DistantSubcommand::Generate(cmd) => cmd.run(self.config.generate), + DistantSubcommand::Manager(cmd) => cmd.run(self.config.manager), + DistantSubcommand::Server(cmd) => cmd.run(self.config.server), + } + } +} diff --git a/src/cli/cache.rs b/src/cli/cache.rs new file mode 100644 index 0000000..7ff1259 --- /dev/null +++ b/src/cli/cache.rs @@ -0,0 +1,102 @@ +use crate::paths::user::CACHE_FILE_PATH; +use anyhow::Context; +use distant_core::ConnectionId; +use serde::{Deserialize, Serialize}; +use std::{ + io, + path::{Path, PathBuf}, +}; + +mod id; +pub use id::CacheId; + +/// Represents a disk-backed cache of data +#[derive(Clone, Debug)] +pub struct Cache { + file: CacheFile, + pub data: CacheData, +} + +impl Cache { + /// Loads the cache from the specified file path, or default user-local cache path, + /// constructing data from the default cache if not found + pub async fn read_from_disk_or_default( + custom_path: impl Into>, + ) -> anyhow::Result { + let file = CacheFile::new(custom_path); + let data = file.read_or_default().await?; + Ok(Self { file, data }) + } + + /// Writes the cache back to disk + pub async fn 
write_to_disk(&self) -> anyhow::Result<()> { + self.file.write(&self.data).await + } +} + +/// Points to a cache file to support reading, writing, and editing the data +#[derive(Clone, Debug)] +pub struct CacheFile { + path: PathBuf, +} + +impl CacheFile { + /// Creates a new [`CacheFile`] from the given path, defaulting to a user-local cache path + /// if none is provided + pub fn new(custom_path: impl Into>) -> Self { + Self { + path: custom_path + .into() + .unwrap_or_else(|| CACHE_FILE_PATH.to_path_buf()), + } + } + + async fn read_or_default(&self) -> anyhow::Result { + CacheData::read_or_default(self.path.as_path()) + .await + .with_context(|| format!("Failed to read cache from {:?}", self.path.as_path())) + } + + async fn write(&self, data: &CacheData) -> anyhow::Result<()> { + data.write(self.path.as_path()) + .await + .with_context(|| format!("Failed to write cache to {:?}", self.path.as_path())) + } +} + +/// Provides quick access to cli-specific cache for a user +#[derive(Clone, Debug, Default, Serialize, Deserialize)] +pub struct CacheData { + /// Connection id of selected connection (or 0 if nothing selected) + pub selected: CacheId, +} + +impl CacheData { + /// Reads the cache data from disk + async fn read(path: impl AsRef) -> io::Result { + let bytes = tokio::fs::read(path).await?; + toml_edit::de::from_slice(&bytes).map_err(|x| io::Error::new(io::ErrorKind::InvalidData, x)) + } + + /// Reads the cache data if the file exists, otherwise returning a default cache instance + async fn read_or_default(path: impl AsRef) -> io::Result { + match Self::read(path).await { + Ok(cache) => Ok(cache), + Err(x) if x.kind() == io::ErrorKind::NotFound => Ok(Self::default()), + Err(x) => Err(x), + } + } + + /// Writes the cache data to disk + async fn write(&self, path: impl AsRef) -> io::Result<()> { + let bytes = toml_edit::ser::to_vec(self) + .map_err(|x| io::Error::new(io::ErrorKind::InvalidData, x))?; + + // Ensure the parent directory of the cache exists + if 
let Some(parent) = path.as_ref().parent() { + tokio::fs::create_dir_all(parent).await?; + } + + tokio::fs::write(path, bytes).await + } +} diff --git a/src/cli/cache/id.rs b/src/cli/cache/id.rs new file mode 100644 index 0000000..b0a4d6f --- /dev/null +++ b/src/cli/cache/id.rs @@ -0,0 +1,105 @@ +use serde::{Deserialize, Serialize}; +use std::{ + convert::TryFrom, + fmt, + ops::{Deref, DerefMut}, + str::FromStr, +}; + +/// NOTE: This type only exists due to a bug with toml-rs where a u64 cannot be stored if its +/// value is greater than i64's max as it gets written as a negative number and then +/// fails to get read back out. To avoid this, we have a wrapper type that serializes +/// and deserializes using a string +/// +/// https://github.com/alexcrichton/toml-rs/issues/256 +#[derive(Copy, Clone, Debug, Default, Hash, Serialize, Deserialize)] +#[serde(into = "String", try_from = "String")] +pub struct CacheId(T) +where + T: fmt::Display + FromStr + Clone, + T::Err: fmt::Display; + +impl CacheId +where + T: fmt::Display + FromStr + Clone, + T::Err: fmt::Display, +{ + /// Returns the value of this storage id container + pub fn value(self) -> T { + self.0 + } +} + +impl AsRef for CacheId +where + T: fmt::Display + FromStr + Clone, + T::Err: fmt::Display, +{ + fn as_ref(&self) -> &T { + &self.0 + } +} + +impl AsMut for CacheId +where + T: fmt::Display + FromStr + Clone, + T::Err: fmt::Display, +{ + fn as_mut(&mut self) -> &mut T { + &mut self.0 + } +} + +impl Deref for CacheId +where + T: fmt::Display + FromStr + Clone, + T::Err: fmt::Display, +{ + type Target = T; + + fn deref(&self) -> &Self::Target { + &self.0 + } +} + +impl DerefMut for CacheId +where + T: fmt::Display + FromStr + Clone, + T::Err: fmt::Display, +{ + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.0 + } +} + +impl fmt::Display for CacheId +where + T: fmt::Display + FromStr + Clone, + T::Err: fmt::Display, +{ + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "{}", 
self.0) + } +} + +impl From> for String +where + T: fmt::Display + FromStr + Clone, + T::Err: fmt::Display, +{ + fn from(id: CacheId) -> Self { + id.to_string() + } +} + +impl TryFrom for CacheId +where + T: fmt::Display + FromStr + Clone, + T::Err: fmt::Display, +{ + type Error = T::Err; + + fn try_from(s: String) -> Result { + Ok(CacheId(s.parse()?)) + } +} diff --git a/src/cli/client.rs b/src/cli/client.rs new file mode 100644 index 0000000..7a93abf --- /dev/null +++ b/src/cli/client.rs @@ -0,0 +1,176 @@ +use crate::config::NetworkConfig; +use anyhow::Context; +use distant_core::{ + net::{AuthRequest, AuthResponse, FramedTransport, PlainCodec}, + DistantManagerClient, DistantManagerClientConfig, +}; +use log::*; + +mod msg; +pub use msg::*; + +pub struct Client { + config: DistantManagerClientConfig, + network: NetworkConfig, +} + +impl Client { + pub fn new(network: NetworkConfig) -> Self { + let config = DistantManagerClientConfig::with_prompts( + |prompt| rpassword::prompt_password(prompt), + |prompt| { + use std::io::Write; + eprint!("{}", prompt); + std::io::stderr().lock().flush()?; + + let mut answer = String::new(); + std::io::stdin().read_line(&mut answer)?; + Ok(answer) + }, + ); + Self { config, network } + } + + /// Configure client to talk over stdin and stdout using messages + pub fn using_msg_stdin_stdout(self) -> Self { + self.using_msg(MsgSender::from_stdout(), MsgReceiver::from_stdin()) + } + + /// Configure client to use a pair of msg sender and receiver + pub fn using_msg(mut self, tx: MsgSender, rx: MsgReceiver) -> Self { + self.config = DistantManagerClientConfig { + on_challenge: { + let tx = tx.clone(); + let rx = rx.clone(); + Box::new(move |questions, extra| { + let question_cnt = questions.len(); + + if let Err(x) = tx.send_blocking(&AuthRequest::Challenge { questions, extra }) { + error!("{}", x); + return (0..question_cnt) + .into_iter() + .map(|_| "".to_string()) + .collect(); + } + + match rx.recv_blocking() { + 
Ok(AuthResponse::Challenge { answers }) => answers, + Ok(x) => { + error!("Invalid response received: {:?}", x); + (0..question_cnt) + .into_iter() + .map(|_| "".to_string()) + .collect() + } + Err(x) => { + error!("{}", x); + (0..question_cnt) + .into_iter() + .map(|_| "".to_string()) + .collect() + } + } + }) + }, + on_info: { + let tx = tx.clone(); + Box::new(move |text| { + let _ = tx.send_blocking(&AuthRequest::Info { text }); + }) + }, + on_verify: { + let tx = tx.clone(); + Box::new(move |kind, text| { + if let Err(x) = tx.send_blocking(&AuthRequest::Verify { kind, text }) { + error!("{}", x); + return false; + } + + match rx.recv_blocking() { + Ok(AuthResponse::Verify { valid }) => valid, + Ok(x) => { + error!("Invalid response received: {:?}", x); + false + } + Err(x) => { + error!("{}", x); + false + } + } + }) + }, + on_error: { + Box::new(move |kind, text| { + let _ = tx.send_blocking(&AuthRequest::Error { kind, text }); + }) + }, + }; + self + } + + /// Connect to the manager listening on the socket or windows pipe based on + /// the [`NetworkConfig`] provided to the client earlier. 
Will return a new instance + /// of the [`DistantManagerClient`] upon successful connection + pub async fn connect(self) -> anyhow::Result { + #[cfg(unix)] + let transport = { + use distant_core::net::UnixSocketTransport; + let mut maybe_transport = None; + let mut error: Option = None; + for path in self.network.to_unix_socket_path_candidates() { + match UnixSocketTransport::connect(path).await { + Ok(transport) => { + info!("Connected to unix socket @ {:?}", path); + maybe_transport = Some(FramedTransport::new(transport, PlainCodec)); + break; + } + Err(x) => { + let err = anyhow::Error::new(x) + .context(format!("Failed to connect to unix socket {:?}", path)); + if let Some(x) = error { + error = Some(x.context(err)); + } else { + error = Some(err); + } + } + } + } + + maybe_transport.ok_or_else(|| { + error.unwrap_or_else(|| anyhow::anyhow!("No unix socket candidate available")) + })? + }; + + #[cfg(windows)] + let transport = { + use distant_core::net::WindowsPipeTransport; + let mut maybe_transport = None; + let mut error: Option = None; + for name in self.network.to_windows_pipe_name_candidates() { + match WindowsPipeTransport::connect_local(name).await { + Ok(transport) => { + info!("Connected to named windows socket @ {:?}", name); + maybe_transport = Some(FramedTransport::new(transport, PlainCodec)); + break; + } + Err(x) => { + let err = anyhow::Error::new(x) + .context(format!("Failed to connect to windows pipe {:?}", name)); + if let Some(x) = error { + error = Some(x.context(err)); + } else { + error = Some(err); + } + } + } + } + + maybe_transport.ok_or_else(|| { + error.unwrap_or_else(|| anyhow::anyhow!("No windows pipe candidate available")) + })? 
+ }; + + DistantManagerClient::new(self.config, transport) + .context("Failed to create client for manager") + } +} diff --git a/src/msg.rs b/src/cli/client/msg.rs similarity index 81% rename from src/msg.rs rename to src/cli/client/msg.rs index 942e39a..bb3776f 100644 --- a/src/msg.rs +++ b/src/cli/client/msg.rs @@ -1,3 +1,4 @@ +use log::*; use serde::{de::DeserializeOwned, Serialize}; use std::{ io::{self, Write}, @@ -6,7 +7,7 @@ use std::{ }; use tokio::sync::mpsc; -type SendFn = Arc io::Result<()>>>>; +type SendFn = Arc io::Result<()> + Send>>>; type RecvFn = Arc io::Result<()> + Send>>>; /// Sends JSON messages over stdout @@ -17,7 +18,7 @@ pub struct MsgSender { impl From for MsgSender where - F: FnMut(&[u8]) -> io::Result<()> + 'static, + F: FnMut(&[u8]) -> io::Result<()> + Send + 'static, { fn from(f: F) -> Self { Self { @@ -30,8 +31,8 @@ impl MsgSender { pub fn from_stdout() -> Self { let mut writer = std::io::stdout(); Self::from(Box::new(move |output: &'_ [u8]| { - let _ = writer.write_all(output)?; - let _ = writer.flush()?; + writer.write_all(output)?; + writer.flush()?; Ok(()) })) } @@ -72,7 +73,6 @@ impl MsgReceiver { } /// Spawns a thread to continually poll receiver for new input of the given type - #[allow(dead_code)] pub fn into_rx(self) -> mpsc::Receiver> where T: DeserializeOwned + Send + 'static, @@ -112,14 +112,25 @@ impl MsgReceiver { // is a partial match let data: T = loop { // Read in another line of input - let _ = self.recv.lock().unwrap()(&mut input)?; + self.recv.lock().unwrap()(&mut input)?; // Attempt to parse current input as type, yielding it on success, continuing to read // more input if error is unexpected EOF (meaning we are partially reading json), and // failing if we get any other error + trace!( + "Parsing into {} for {:?}", + std::any::type_name::(), + input, + ); match serde_json::from_str(&input) { Ok(data) => break data, - Err(x) if x.is_eof() => continue, + Err(x) if x.is_eof() => { + trace!( + "Not ready to parse as 
{}, so trying again with next update", + std::any::type_name::(), + ); + continue; + } Err(x) => return Err(x.into()), } }; diff --git a/src/cli/commands.rs b/src/cli/commands.rs new file mode 100644 index 0000000..c34e049 --- /dev/null +++ b/src/cli/commands.rs @@ -0,0 +1,25 @@ +use clap::Subcommand; + +mod client; +mod generate; +mod manager; +mod server; + +#[derive(Debug, Subcommand)] +pub enum DistantSubcommand { + /// Perform client commands + #[clap(subcommand)] + Client(client::ClientSubcommand), + + /// Perform manager commands + #[clap(subcommand)] + Manager(manager::ManagerSubcommand), + + /// Perform server commands + #[clap(subcommand)] + Server(server::ServerSubcommand), + + /// Perform generation commands + #[clap(subcommand)] + Generate(generate::GenerateSubcommand), +} diff --git a/src/cli/commands/client.rs b/src/cli/commands/client.rs new file mode 100644 index 0000000..5eed150 --- /dev/null +++ b/src/cli/commands/client.rs @@ -0,0 +1,738 @@ +use crate::{ + cli::{ + client::{MsgReceiver, MsgSender}, + Cache, Client, + }, + config::{ClientConfig, ClientLaunchConfig, NetworkConfig}, + paths::user::CACHE_FILE_PATH_STR, + CliError, CliResult, +}; +use anyhow::Context; +use clap::{Subcommand, ValueHint}; +use dialoguer::{console::Term, theme::ColorfulTheme, Select}; +use distant_core::{ + data::{ChangeKindSet, Environment}, + net::{IntoSplit, Request, Response, TypedAsyncRead, TypedAsyncWrite}, + ConnectionId, Destination, DistantManagerClient, DistantMsg, DistantRequestData, + DistantResponseData, Extra, RemoteCommand, Watcher, +}; +use log::*; +use std::{ + io, + path::{Path, PathBuf}, + time::Duration, +}; + +mod buf; +mod format; +mod link; +mod lsp; +mod shell; +mod stdin; + +pub use format::Format; +use format::Formatter; +use link::RemoteProcessLink; +use lsp::Lsp; +use shell::Shell; + +#[derive(Debug, Subcommand)] +pub enum ClientSubcommand { + /// Performs some action on a remote machine + Action { + /// Location to store cached data + 
#[clap( + long, + value_hint = ValueHint::FilePath, + value_parser, + default_value = CACHE_FILE_PATH_STR.as_str() + )] + cache: PathBuf, + + /// Specify a connection being managed + #[clap(long)] + connection: Option, + + #[clap(flatten)] + network: NetworkConfig, + + /// Represents the maximum time (in seconds) to wait for a network request before timing out + #[clap(short, long)] + timeout: Option, + + #[clap(subcommand)] + request: DistantRequestData, + }, + + /// Requests that active manager connects to the server at the specified destination + Connect { + /// Location to store cached data + #[clap( + long, + value_hint = ValueHint::FilePath, + value_parser, + default_value = CACHE_FILE_PATH_STR.as_str() + )] + cache: PathBuf, + + #[clap(flatten)] + network: NetworkConfig, + + #[clap(short, long, default_value_t, value_enum)] + format: Format, + + destination: Box, + }, + + /// Launches the server-portion of the binary on a remote machine + Launch { + /// Location to store cached data + #[clap( + long, + value_hint = ValueHint::FilePath, + value_parser, + default_value = CACHE_FILE_PATH_STR.as_str() + )] + cache: PathBuf, + + #[clap(flatten)] + config: ClientLaunchConfig, + + #[clap(flatten)] + network: NetworkConfig, + + #[clap(short, long, default_value_t, value_enum)] + format: Format, + + destination: Box, + }, + + /// Specialized treatment of running a remote LSP process + Lsp { + /// Location to store cached data + #[clap( + long, + value_hint = ValueHint::FilePath, + value_parser, + default_value = CACHE_FILE_PATH_STR.as_str() + )] + cache: PathBuf, + + /// Specify a connection being managed + #[clap(long)] + connection: Option, + + #[clap(flatten)] + network: NetworkConfig, + + /// If provided, will run in persist mode, meaning that the process will not be killed if the + /// client disconnects from the server + #[clap(long)] + persist: bool, + + /// If provided, will run LSP in a pty + #[clap(long)] + pty: bool, + + cmd: String, + }, + + /// Runs 
actions in a read-eval-print loop + Repl { + /// Location to store cached data + #[clap( + long, + value_hint = ValueHint::FilePath, + value_parser, + default_value = CACHE_FILE_PATH_STR.as_str() + )] + cache: PathBuf, + + /// Specify a connection being managed + #[clap(long)] + connection: Option, + + #[clap(flatten)] + network: NetworkConfig, + + /// Format used for input into and output from the repl + #[clap(short, long, default_value_t, value_enum)] + format: Format, + + /// Represents the maximum time (in seconds) to wait for a network request before timing out + #[clap(short, long)] + timeout: Option, + }, + + /// Select the active connection + Select { + /// Location to store cached data + #[clap( + long, + value_hint = ValueHint::FilePath, + value_parser, + default_value = CACHE_FILE_PATH_STR.as_str() + )] + cache: PathBuf, + + /// Connection to use, otherwise will prompt to select + connection: Option, + + #[clap(flatten)] + network: NetworkConfig, + }, + + /// Specialized treatment of running a remote shell process + Shell { + /// Location to store cached data + #[clap( + long, + value_hint = ValueHint::FilePath, + value_parser, + default_value = CACHE_FILE_PATH_STR.as_str() + )] + cache: PathBuf, + + /// Specify a connection being managed + #[clap(long)] + connection: Option, + + #[clap(flatten)] + network: NetworkConfig, + + /// Environment variables to provide to the shell + #[clap(long, default_value_t)] + environment: Environment, + + /// If provided, will run in persist mode, meaning that the process will not be killed if the + /// client disconnects from the server + #[clap(long)] + persist: bool, + + /// Optional command to run instead of $SHELL + cmd: Option, + }, +} + +impl ClientSubcommand { + pub fn run(self, config: ClientConfig) -> CliResult { + let rt = tokio::runtime::Runtime::new().context("Failed to start up runtime")?; + rt.block_on(Self::async_run(self, config)) + } + + fn cache_path(&self) -> &Path { + match self { + Self::Action { 
cache, .. } => cache.as_path(), + Self::Connect { cache, .. } => cache.as_path(), + Self::Launch { cache, .. } => cache.as_path(), + Self::Lsp { cache, .. } => cache.as_path(), + Self::Repl { cache, .. } => cache.as_path(), + Self::Select { cache, .. } => cache.as_path(), + Self::Shell { cache, .. } => cache.as_path(), + } + } + + async fn async_run(self, config: ClientConfig) -> CliResult { + let mut cache = Cache::read_from_disk_or_default(self.cache_path().to_path_buf()).await?; + + match self { + Self::Action { + connection, + network, + request, + timeout, + .. + } => { + let network = network.merge(config.network); + debug!("Connecting to manager"); + let mut client = Client::new(network) + .connect() + .await + .context("Failed to connect to manager")?; + + let connection_id = + use_or_lookup_connection_id(&mut cache, connection, &mut client).await?; + + debug!("Opening channel to connection {}", connection_id); + let mut channel = client.open_channel(connection_id).await.with_context(|| { + format!("Failed to open channel to connection {connection_id}") + })?; + + debug!( + "Timeout configured to be {}", + match timeout { + Some(secs) => format!("{}s", secs), + None => "none".to_string(), + } + ); + + let formatter = Formatter::shell(); + + debug!("Sending request {:?}", request); + match request { + DistantRequestData::ProcSpawn { + cmd, + environment, + current_dir, + persist, + pty, + } => { + debug!("Special request spawning {:?}", cmd); + let mut proc = RemoteCommand::new() + .environment(environment) + .current_dir(current_dir) + .persist(persist) + .pty(pty) + .spawn(channel, cmd.as_str()) + .await + .with_context(|| format!("Failed to spawn {cmd}"))?; + + // Now, map the remote process' stdin/stdout/stderr to our own process + let link = RemoteProcessLink::from_remote_pipes( + proc.stdin.take(), + proc.stdout.take().unwrap(), + proc.stderr.take().unwrap(), + ); + + let status = proc.wait().await.context("Failed to wait for process")?; + + // Shut 
down our link + link.shutdown().await; + + if !status.success { + if let Some(code) = status.code { + return Err(CliError::Exit(code as u8)); + } else { + return Err(CliError::FAILURE); + } + } + } + DistantRequestData::Watch { + path, + recursive, + only, + except, + } => { + debug!("Special request creating watcher for {:?}", path); + let mut watcher = Watcher::watch( + channel, + path.as_path(), + recursive, + only.into_iter().collect::(), + except.into_iter().collect::(), + ) + .await + .with_context(|| format!("Failed to watch {path:?}"))?; + + // Continue to receive and process changes + while let Some(change) = watcher.next().await { + // TODO: Provide a cleaner way to print just a change + let res = Response::new( + "".to_string(), + DistantMsg::Single(DistantResponseData::Changed(change)), + ); + + formatter.print(res).context("Failed to print change")?; + } + } + request => { + let response = channel + .send_timeout( + DistantMsg::Single(request), + timeout + .or(config.action.timeout) + .map(Duration::from_secs_f32), + ) + .await + .context("Failed to send request")?; + + debug!("Got response {:?}", response); + + // NOTE: We expect a single response, and if that is an error then + // we want to pass that error up the stack + let id = response.id; + let origin_id = response.origin_id; + match response.payload { + DistantMsg::Single(DistantResponseData::Error(x)) => { + return Err(CliError::Error(anyhow::anyhow!(x))); + } + payload => formatter + .print(Response { + id, + origin_id, + payload, + }) + .context("Failed to print response")?, + } + } + } + } + Self::Connect { + network, + format, + destination, + .. + } => { + let network = network.merge(config.network); + debug!("Connecting to manager"); + let mut client = { + let client = match format { + Format::Shell => Client::new(network), + Format::Json => Client::new(network).using_msg_stdin_stdout(), + }; + client + .connect() + .await + .context("Failed to connect to manager")? 
+ }; + + // Trigger our manager to connect to the launched server + debug!("Connecting to server at {}", destination); + let id = client + .connect(*destination, Extra::new()) + .await + .context("Failed to connect to server")?; + + // Mark the server's id as the new default + debug!("Updating selected connection id in cache to {}", id); + *cache.data.selected = id; + cache.write_to_disk().await?; + + println!("{}", id); + } + Self::Launch { + config: launcher_config, + network, + format, + destination, + .. + } => { + let network = network.merge(config.network); + debug!("Connecting to manager"); + let mut client = { + let client = match format { + Format::Shell => Client::new(network), + Format::Json => Client::new(network).using_msg_stdin_stdout(), + }; + client + .connect() + .await + .context("Failed to connect to manager")? + }; + + // Merge our launch configs, overwriting anything in the config file + // with our cli arguments + let mut extra = Extra::from(config.launch); + extra.extend(Extra::from(launcher_config).into_map()); + + // Grab the host we are connecting to for later use + let host = destination.to_host_string(); + + // Start the server using our manager + debug!("Launching server at {} with {}", destination, extra); + let mut new_destination = client + .launch(*destination, extra) + .await + .context("Failed to launch server")?; + + // Update the new destination with our previously-used host if the + // new host is not globally-accessible + if !new_destination.is_host_global() { + trace!( + "Updating host to {:?} from non-global {:?}", + host, + new_destination.to_host_string() + ); + new_destination + .replace_host(host.as_str()) + .context("Failed to replace host")?; + } else { + trace!("Host {:?} is global", new_destination.to_host_string()); + } + + // Trigger our manager to connect to the launched server + debug!("Connecting to server at {}", new_destination); + let id = client + .connect(new_destination, Extra::new()) + .await + 
.context("Failed to connect to server")?; + + // Mark the server's id as the new default + debug!("Updating selected connection id in cache to {}", id); + *cache.data.selected = id; + cache.write_to_disk().await?; + + println!("{}", id); + } + Self::Lsp { + connection, + network, + persist, + pty, + cmd, + .. + } => { + let network = network.merge(config.network); + debug!("Connecting to manager"); + let mut client = Client::new(network) + .connect() + .await + .context("Failed to connect to manager")?; + + let connection_id = + use_or_lookup_connection_id(&mut cache, connection, &mut client).await?; + + debug!("Opening channel to connection {}", connection_id); + let channel = client.open_channel(connection_id).await.with_context(|| { + format!("Failed to open channel to connection {connection_id}") + })?; + + debug!( + "Spawning LSP server (persist = {}, pty = {}): {}", + persist, pty, cmd + ); + Lsp::new(channel).spawn(cmd, persist, pty).await?; + } + Self::Repl { + connection, + network, + format, + timeout, + .. + } => { + let network = network.merge(config.network); + debug!("Connecting to manager"); + let mut client = Client::new(network) + .using_msg_stdin_stdout() + .connect() + .await + .context("Failed to connect to manager")?; + + let connection_id = + use_or_lookup_connection_id(&mut cache, connection, &mut client).await?; + + debug!("Opening raw channel to connection {}", connection_id); + let channel = client + .open_raw_channel(connection_id) + .await + .with_context(|| { + format!("Failed to open raw channel to connection {connection_id}") + })?; + + debug!( + "Timeout configured to be {}", + match timeout { + Some(secs) => format!("{}s", secs), + None => "none".to_string(), + } + ); + + // TODO: Support shell format? 
+ if !format.is_json() { + return Err(CliError::Error(anyhow::anyhow!( + "Only JSON format is supported" + ))); + } + + debug!("Starting repl using format {:?}", format); + let (mut writer, mut reader) = channel.transport.into_split(); + let response_task = tokio::task::spawn(async move { + let tx = MsgSender::from_stdout(); + while let Some(response) = reader.read().await? { + debug!("Received response {:?}", response); + tx.send_blocking(&response)?; + } + io::Result::Ok(()) + }); + + let request_task = tokio::spawn(async move { + let mut rx = MsgReceiver::from_stdin() + .into_rx::>>(); + loop { + match rx.recv().await { + Some(Ok(request)) => { + debug!("Sending request {:?}", request); + writer.write(request).await?; + } + Some(Err(x)) => error!("{}", x), + None => { + debug!("Shutting down repl"); + break; + } + } + } + io::Result::Ok(()) + }); + + let (r1, r2) = tokio::join!(request_task, response_task); + match r1 { + Err(x) => error!("{}", x), + Ok(Err(x)) => error!("{}", x), + _ => (), + } + match r2 { + Err(x) => error!("{}", x), + Ok(Err(x)) => error!("{}", x), + _ => (), + } + + debug!("Shutting down repl"); + } + Self::Select { + connection, + network, + .. 
+ } => match connection { + Some(id) => { + *cache.data.selected = id; + cache.write_to_disk().await?; + } + None => { + let network = network.merge(config.network); + debug!("Connecting to manager"); + let mut client = Client::new(network) + .connect() + .await + .context("Failed to connect to manager")?; + let list = client + .list() + .await + .context("Failed to get a list of managed connections")?; + + if list.is_empty() { + return Err(CliError::Error(anyhow::anyhow!( + "No connection available in manager" + ))); + } + + trace!("Building selection prompt of {} choices", list.len()); + let selected = list + .iter() + .enumerate() + .find_map(|(i, (id, _))| { + if *cache.data.selected == *id { + Some(i) + } else { + None + } + }) + .unwrap_or_default(); + + let items: Vec = list + .iter() + .map(|(_, destination)| { + format!( + "{}{}{}", + destination + .scheme() + .map(|x| format!(r"{}://", x)) + .unwrap_or_default(), + destination.to_host_string(), + destination + .port() + .map(|x| format!(":{}", x)) + .unwrap_or_default() + ) + }) + .collect(); + + trace!("Rendering prompt"); + let selected = Select::with_theme(&ColorfulTheme::default()) + .items(&items) + .default(selected) + .interact_on_opt(&Term::stderr()) + .context("Failed to render prompt")?; + + match selected { + Some(index) => { + trace!("Selected choice {}", index); + if let Some((id, _)) = list.iter().nth(index) { + debug!("Updating selected connection id in cache to {}", id); + *cache.data.selected = *id; + cache.write_to_disk().await?; + } + } + None => { + debug!("No change in selection of default connection id"); + } + } + } + }, + Self::Shell { + connection, + network, + environment, + persist, + cmd, + .. 
+ } => { + let network = network.merge(config.network); + debug!("Connecting to manager"); + let mut client = Client::new(network) + .connect() + .await + .context("Failed to connect to manager")?; + + let connection_id = + use_or_lookup_connection_id(&mut cache, connection, &mut client).await?; + + debug!("Opening channel to connection {}", connection_id); + let channel = client.open_channel(connection_id).await.with_context(|| { + format!("Failed to open channel to connection {connection_id}") + })?; + + debug!( + "Spawning shell (environment = {:?}, persist = {}): {}", + environment, + persist, + cmd.as_deref().unwrap_or(r"$SHELL") + ); + Shell::new(channel).spawn(cmd, environment, persist).await?; + } + } + + Ok(()) + } +} + +async fn use_or_lookup_connection_id( + cache: &mut Cache, + connection: Option, + client: &mut DistantManagerClient, +) -> anyhow::Result { + match connection { + Some(id) => { + trace!("Using specified connection id: {}", id); + Ok(id) + } + None => { + trace!("Looking up connection id"); + let list = client + .list() + .await + .context("Failed to retrieve list of available connections")?; + + if list.contains_key(&cache.data.selected) { + trace!("Using cached connection id: {}", cache.data.selected); + Ok(*cache.data.selected) + } else if list.is_empty() { + trace!("Cached connection id is invalid as there are no connections"); + anyhow::bail!("There are no connections being managed! You need to start one!"); + } else if list.len() > 1 { + trace!("Cached connection id is invalid and there are multiple connections"); + anyhow::bail!( + "There are multiple connections being managed! You need to pick one!" 
+ ); + } else { + trace!("Cached connection id is invalid"); + *cache.data.selected = *list.keys().next().unwrap(); + trace!( + "Detected singular connection id, so updating cache: {}", + cache.data.selected + ); + cache.write_to_disk().await?; + Ok(*cache.data.selected) + } + } + } +} diff --git a/src/buf.rs b/src/cli/commands/client/buf.rs similarity index 100% rename from src/buf.rs rename to src/cli/commands/client/buf.rs diff --git a/src/output.rs b/src/cli/commands/client/format.rs similarity index 67% rename from src/output.rs rename to src/cli/commands/client/format.rs index c6ca24c..e596619 100644 --- a/src/output.rs +++ b/src/cli/commands/client/format.rs @@ -1,50 +1,71 @@ -use crate::opt::Format; +use clap::ValueEnum; use distant_core::{ - data::{ChangeKind, Error, Metadata, SystemInfo}, - Response, ResponseData, + data::{ChangeKind, DistantMsg, DistantResponseData, Error, FileType, Metadata, SystemInfo}, + net::Response, }; use log::*; -use std::io; -use std::io::Write; +use std::io::{self, Write}; +use tabled::{object::Rows, style::Style, Alignment, Disable, Modify, Table, Tabled}; -/// Represents the output content and destination -pub enum ResponseOut { - Stdout(Vec), - StdoutLine(Vec), - Stderr(Vec), - StderrLine(Vec), - None, +#[derive(Copy, Clone, Debug, PartialEq, Eq, ValueEnum)] +#[clap(rename_all = "snake_case")] +pub enum Format { + /// Sends and receives data in JSON format + Json, + + /// Commands are traditional shell commands and output responses are + /// inline with what is expected of a program's output in a shell + Shell, +} + +impl Format { + /// Returns true if json format + pub fn is_json(self) -> bool { + matches!(self, Self::Json) + } } -impl ResponseOut { +impl Default for Format { + fn default() -> Self { + Self::Shell + } +} + +pub struct Formatter { + format: Format, +} + +impl Formatter { /// Create a new output message for the given response based on the specified format - pub fn new(format: Format, res: Response) -> 
io::Result { - let payload_cnt = res.payload.len(); + pub fn new(format: Format) -> Self { + Self { format } + } + + /// Creates a new [`Formatter`] using [`Format`] of `Format::Shell` + pub fn shell() -> Self { + Self::new(Format::Shell) + } - Ok(match format { - Format::Json => ResponseOut::StdoutLine( + /// Consumes the output message, printing it based on its configuration + pub fn print(&self, res: Response>) -> io::Result<()> { + let output = match self.format { + Format::Json => Output::StdoutLine( serde_json::to_vec(&res) .map_err(|x| io::Error::new(io::ErrorKind::InvalidData, x))?, ), // NOTE: For shell, we assume a singular entry in the response's payload - Format::Shell if payload_cnt != 1 => { + Format::Shell if res.payload.is_batch() => { return Err(io::Error::new( io::ErrorKind::InvalidData, - format!( - "Got {} entries in payload data, but shell expects exactly 1", - payload_cnt - ), + "Shell does not support batch responses", )) } - Format::Shell => format_shell(res.payload.into_iter().next().unwrap()), - }) - } + Format::Shell => format_shell(res.payload.into_single().unwrap()), + }; - /// Consumes the output message, printing it based on its configuration - pub fn print(self) { - match self { - Self::Stdout(x) => { + match output { + Output::Stdout(x) => { // NOTE: Because we are not including a newline in the output, // it is not guaranteed to be written out. In the case of // LSP protocol, the JSON content is not followed by a @@ -58,7 +79,7 @@ impl ResponseOut { error!("Failed to flush stdout: {}", x); } } - Self::StdoutLine(x) => { + Output::StdoutLine(x) => { if let Err(x) = io::stdout().lock().write_all(&x) { error!("Failed to write stdout: {}", x); } @@ -67,7 +88,7 @@ impl ResponseOut { error!("Failed to write stdout newline: {}", x); } } - Self::Stderr(x) => { + Output::Stderr(x) => { // NOTE: Because we are not including a newline in the output, // it is not guaranteed to be written out. 
In the case of // LSP protocol, the JSON content is not followed by a @@ -81,7 +102,7 @@ impl ResponseOut { error!("Failed to flush stderr: {}", x); } } - Self::StderrLine(x) => { + Output::StderrLine(x) => { if let Err(x) = io::stderr().lock().write_all(&x) { error!("Failed to write stderr: {}", x); } @@ -90,44 +111,52 @@ impl ResponseOut { error!("Failed to write stderr newline: {}", x); } } - Self::None => {} + Output::None => {} } + + Ok(()) } } -fn format_shell(data: ResponseData) -> ResponseOut { +/// Represents the output content and destination +enum Output { + Stdout(Vec), + StdoutLine(Vec), + Stderr(Vec), + StderrLine(Vec), + None, +} + +fn format_shell(data: DistantResponseData) -> Output { match data { - ResponseData::Ok => ResponseOut::None, - ResponseData::Error(Error { kind, description }) => { - ResponseOut::StderrLine(format!("Failed ({}): '{}'.", kind, description).into_bytes()) + DistantResponseData::Ok => Output::None, + DistantResponseData::Error(Error { description, .. }) => { + Output::StderrLine(description.into_bytes()) } - ResponseData::Blob { data } => ResponseOut::StdoutLine(data), - ResponseData::Text { data } => ResponseOut::StdoutLine(data.into_bytes()), - ResponseData::DirEntries { entries, .. 
} => ResponseOut::StdoutLine( - entries - .into_iter() - .map(|entry| { - format!( - "{}{}", - entry.path.as_os_str().to_string_lossy(), - if entry.file_type.is_dir() { - // NOTE: This can be different from the server if - // the server OS is unix and the client is - // not or vice versa; for now, this doesn't - // matter as we only support unix-based - // operating systems, but something to keep - // in mind - std::path::MAIN_SEPARATOR.to_string() - } else { - String::new() - }, - ) - }) - .collect::>() - .join("\n") - .into_bytes(), - ), - ResponseData::Changed(change) => ResponseOut::StdoutLine( + DistantResponseData::Blob { data } => Output::StdoutLine(data), + DistantResponseData::Text { data } => Output::StdoutLine(data.into_bytes()), + DistantResponseData::DirEntries { entries, .. } => { + #[derive(Tabled)] + struct EntryRow { + ty: String, + path: String, + } + + let table = Table::new(entries.into_iter().map(|entry| EntryRow { + ty: String::from(match entry.file_type { + FileType::Dir => "", + FileType::File => "", + FileType::Symlink => "", + }), + path: entry.path.to_string_lossy().to_string(), + })) + .with(Style::blank()) + .with(Disable::Row(..1)) + .with(Modify::new(Rows::new(..)).with(Alignment::left())); + + Output::Stdout(table.to_string().into_bytes()) + } + DistantResponseData::Changed(change) => Output::StdoutLine( format!( "{}{}", match change.kind { @@ -147,14 +176,14 @@ fn format_shell(data: ResponseData) -> ResponseOut { ) .into_bytes(), ), - ResponseData::Exists { value: exists } => { + DistantResponseData::Exists { value: exists } => { if exists { - ResponseOut::StdoutLine(b"true".to_vec()) + Output::StdoutLine(b"true".to_vec()) } else { - ResponseOut::StdoutLine(b"false".to_vec()) + Output::StdoutLine(b"false".to_vec()) } } - ResponseData::Metadata(Metadata { + DistantResponseData::Metadata(Metadata { canonicalized_path, file_type, len, @@ -164,7 +193,7 @@ fn format_shell(data: ResponseData) -> ResponseOut { modified, unix, windows, - }) 
=> ResponseOut::StdoutLine( + }) => Output::StdoutLine( format!( concat!( "{}", @@ -254,35 +283,25 @@ fn format_shell(data: ResponseData) -> ResponseOut { ) .into_bytes(), ), - ResponseData::ProcEntries { entries } => ResponseOut::StdoutLine( - entries - .into_iter() - .map(|entry| format!("{}: {} {}", entry.id, entry.cmd, entry.args.join(" "))) - .collect::>() - .join("\n") - .into_bytes(), - ), - ResponseData::ProcSpawned { .. } => ResponseOut::None, - ResponseData::ProcStdout { data, .. } => ResponseOut::Stdout(data), - ResponseData::ProcStderr { data, .. } => ResponseOut::Stderr(data), - ResponseData::ProcDone { id, success, code } => { + DistantResponseData::ProcSpawned { .. } => Output::None, + DistantResponseData::ProcStdout { data, .. } => Output::Stdout(data), + DistantResponseData::ProcStderr { data, .. } => Output::Stderr(data), + DistantResponseData::ProcDone { id, success, code } => { if success { - ResponseOut::None + Output::None } else if let Some(code) = code { - ResponseOut::StderrLine( - format!("Proc {} failed with code {}", id, code).into_bytes(), - ) + Output::StderrLine(format!("Proc {} failed with code {}", id, code).into_bytes()) } else { - ResponseOut::StderrLine(format!("Proc {} failed", id).into_bytes()) + Output::StderrLine(format!("Proc {} failed", id).into_bytes()) } } - ResponseData::SystemInfo(SystemInfo { + DistantResponseData::SystemInfo(SystemInfo { family, os, arch, current_dir, main_separator, - }) => ResponseOut::StdoutLine( + }) => Output::StdoutLine( format!( concat!( "Family: {:?}\n", diff --git a/src/link.rs b/src/cli/commands/client/link.rs similarity index 95% rename from src/link.rs rename to src/cli/commands/client/link.rs index 6b20eb7..d1e3dd8 100644 --- a/src/link.rs +++ b/src/cli/commands/client/link.rs @@ -1,7 +1,9 @@ -use crate::{constants::MAX_PIPE_CHUNK_SIZE, stdin}; +use super::stdin; +use crate::constants::MAX_PIPE_CHUNK_SIZE; use distant_core::{ RemoteLspStderr, RemoteLspStdin, RemoteLspStdout, RemoteStderr, 
RemoteStdin, RemoteStdout, }; +use log::*; use std::{ io::{self, Write}, thread, @@ -26,6 +28,7 @@ macro_rules! from_pipes { let task = tokio::spawn(async move { loop { if let Some(input) = rx.recv().await { + trace!("Forwarding stdin: {:?}", String::from_utf8_lossy(&input)); if let Err(x) = stdin_handle.write(&*input).await { break Err(x); } diff --git a/src/cli/commands/client/lsp.rs b/src/cli/commands/client/lsp.rs new file mode 100644 index 0000000..9dea374 --- /dev/null +++ b/src/cli/commands/client/lsp.rs @@ -0,0 +1,51 @@ +use super::{link::RemoteProcessLink, CliError, CliResult}; +use anyhow::Context; +use distant_core::{data::PtySize, DistantChannel, RemoteLspCommand}; +use terminal_size::{terminal_size, Height, Width}; + +#[derive(Clone)] +pub struct Lsp(DistantChannel); + +impl Lsp { + pub fn new(channel: DistantChannel) -> Self { + Self(channel) + } + + pub async fn spawn(self, cmd: impl Into, persist: bool, pty: bool) -> CliResult { + let cmd = cmd.into(); + let mut proc = RemoteLspCommand::new() + .persist(persist) + .pty(if pty { + terminal_size().map(|(Width(width), Height(height))| { + PtySize::from_rows_and_cols(height, width) + }) + } else { + None + }) + .spawn(self.0, &cmd) + .await + .with_context(|| format!("Failed to spawn {cmd}"))?; + + // Now, map the remote LSP server's stdin/stdout/stderr to our own process + let link = RemoteProcessLink::from_remote_lsp_pipes( + proc.stdin.take(), + proc.stdout.take().unwrap(), + proc.stderr.take().unwrap(), + ); + + let status = proc.wait().await.context("Failed to wait for process")?; + + // Shut down our link + link.shutdown().await; + + if !status.success { + if let Some(code) = status.code { + return Err(CliError::Exit(code as u8)); + } else { + return Err(CliError::FAILURE); + } + } + + Ok(()) + } +} diff --git a/src/cli/commands/client/shell.rs b/src/cli/commands/client/shell.rs new file mode 100644 index 0000000..f2ddb48 --- /dev/null +++ b/src/cli/commands/client/shell.rs @@ -0,0 +1,113 @@ +use 
super::{link::RemoteProcessLink, CliError, CliResult}; +use anyhow::Context; +use distant_core::{ + data::{Environment, PtySize}, + DistantChannel, RemoteCommand, +}; +use log::*; +use std::time::Duration; +use terminal_size::{terminal_size, Height, Width}; +use termwiz::{ + caps::Capabilities, + input::{InputEvent, KeyCodeEncodeModes}, + terminal::{new_terminal, Terminal}, +}; + +#[derive(Clone)] +pub struct Shell(DistantChannel); + +impl Shell { + pub fn new(channel: DistantChannel) -> Self { + Self(channel) + } + + pub async fn spawn( + self, + cmd: impl Into>, + mut environment: Environment, + persist: bool, + ) -> CliResult { + // Automatically add TERM=xterm-256color if not specified + if !environment.contains_key("TERM") { + environment.insert("TERM".to_string(), "xterm-256color".to_string()); + } + + let cmd = cmd.into().unwrap_or_else(|| "/bin/sh".to_string()); + let mut proc = RemoteCommand::new() + .persist(persist) + .environment(environment) + .pty( + terminal_size() + .map(|(Width(cols), Height(rows))| PtySize::from_rows_and_cols(rows, cols)), + ) + .spawn(self.0, &cmd) + .await + .with_context(|| format!("Failed to spawn {cmd}"))?; + + // Create a new terminal in raw mode + let mut terminal = new_terminal( + Capabilities::new_from_env().context("Failed to load terminal capabilities")?, + ) + .context("Failed to create terminal")?; + terminal.set_raw_mode().context("Failed to set raw mode")?; + + let mut stdin = proc.stdin.take().unwrap(); + let resizer = proc.clone_resizer(); + tokio::spawn(async move { + while let Ok(input) = terminal.poll_input(Some(Duration::new(0, 0))) { + match input { + Some(InputEvent::Key(ev)) => { + if let Ok(input) = ev.key.encode( + ev.modifiers, + KeyCodeEncodeModes { + enable_csi_u_key_encoding: false, + application_cursor_keys: false, + newline_mode: false, + }, + ) { + if let Err(x) = stdin.write_str(input).await { + error!("Failed to write to stdin of remote process: {}", x); + break; + } + } + } + 
Some(InputEvent::Resized { cols, rows }) => { + if let Err(x) = resizer + .resize(PtySize::from_rows_and_cols(rows as u16, cols as u16)) + .await + { + error!("Failed to resize remote process: {}", x); + break; + } + } + Some(_) => continue, + None => tokio::time::sleep(Duration::from_millis(1)).await, + } + } + }); + + // Now, map the remote shell's stdout/stderr to our own process, + // while stdin is handled by the task above + let link = RemoteProcessLink::from_remote_pipes( + None, + proc.stdout.take().unwrap(), + proc.stderr.take().unwrap(), + ); + + // Continually loop to check for terminal resize changes while the process is still running + let status = proc.wait().await.context("Failed to wait for process")?; + + // Shut down our link + link.shutdown().await; + + if !status.success { + if let Some(code) = status.code { + return Err(CliError::Exit(code as u8)); + } else { + return Err(CliError::FAILURE); + } + } + + Ok(()) + } +} diff --git a/src/stdin.rs b/src/cli/commands/client/stdin.rs similarity index 100% rename from src/stdin.rs rename to src/cli/commands/client/stdin.rs diff --git a/src/cli/commands/generate.rs b/src/cli/commands/generate.rs new file mode 100644 index 0000000..5dbda26 --- /dev/null +++ b/src/cli/commands/generate.rs @@ -0,0 +1,93 @@ +use crate::{cli::Opt, config::GenerateConfig, CliResult}; +use anyhow::Context; +use clap::{CommandFactory, Subcommand}; +use clap_complete::{generate as clap_generate, Shell}; +use distant_core::{ + net::{Request, Response}, + DistantMsg, DistantRequestData, DistantResponseData, +}; +use std::{fs, io, path::PathBuf}; + +#[derive(Debug, Subcommand)] +pub enum GenerateSubcommand { + /// Generate JSON schema for server request/response + Schema { + /// If specified, will output to the file at the given path instead of stdout + #[clap(long)] + file: Option, + }, + + // Generate completion info for CLI + Completion { + /// If specified, will output to the file at the given path instead of stdout + 
#[clap(long)] + file: Option, + + /// Specific shell to target for the generated output + #[clap(arg_enum, value_parser)] + shell: Shell, + }, +} + +impl GenerateSubcommand { + pub fn run(self, _config: GenerateConfig) -> CliResult { + let rt = tokio::runtime::Runtime::new().context("Failed to start up runtime")?; + rt.block_on(Self::async_run(self)) + } + + async fn async_run(self) -> CliResult { + match self { + Self::Schema { file } => { + let request_schema = + serde_json::to_value(&Request::>::root_schema()) + .context("Failed to serialize request schema")?; + let response_schema = serde_json::to_value(&Response::< + DistantMsg, + >::root_schema()) + .context("Failed to serialize response schema")?; + + let schema = serde_json::json!({ + "request": request_schema, + "response": response_schema, + }); + + if let Some(path) = file { + serde_json::to_writer_pretty( + &mut fs::OpenOptions::new() + .create(true) + .write(true) + .open(&path) + .with_context(|| format!("Failed to open {path:?}"))?, + &schema, + ) + .context("Failed to write to {path:?}")?; + } else { + serde_json::to_writer_pretty(&mut io::stdout(), &schema) + .context("Failed to print to stdout")?; + } + } + + Self::Completion { file, shell } => { + let name = "distant"; + let mut cmd = Opt::command(); + + if let Some(path) = file { + clap_generate( + shell, + &mut cmd, + name, + &mut fs::OpenOptions::new() + .create(true) + .write(true) + .open(&path) + .with_context(|| format!("Failed to open {path:?}"))?, + ) + } else { + clap_generate(shell, &mut cmd, name, &mut io::stdout()) + } + } + } + + Ok(()) + } +} diff --git a/src/cli/commands/manager.rs b/src/cli/commands/manager.rs new file mode 100644 index 0000000..f474f57 --- /dev/null +++ b/src/cli/commands/manager.rs @@ -0,0 +1,431 @@ +use crate::{ + cli::{Cache, Client, Manager}, + config::{ManagerConfig, NetworkConfig}, + paths::user::CACHE_FILE_PATH_STR, + CliResult, +}; +use anyhow::Context; +use clap::{Subcommand, ValueHint}; +use 
distant_core::{net::ServerRef, ConnectionId, DistantManagerConfig}; +use log::*; +use once_cell::sync::Lazy; +use service_manager::{ + ServiceInstallCtx, ServiceLabel, ServiceLevel, ServiceManager, ServiceManagerKind, + ServiceStartCtx, ServiceStopCtx, ServiceUninstallCtx, +}; +use std::{ffi::OsString, path::PathBuf}; +use tabled::{Table, Tabled}; + +/// [`ServiceLabel`] for our manager in the form `rocks.distant.manager` +static SERVICE_LABEL: Lazy = Lazy::new(|| ServiceLabel { + qualifier: String::from("rocks"), + organization: String::from("distant"), + application: String::from("manager"), +}); + +mod handlers; + +#[derive(Debug, Subcommand)] +pub enum ManagerSubcommand { + /// Interact with a manager being run by a service management platform + #[clap(subcommand)] + Service(ManagerServiceSubcommand), + + /// Listen for incoming requests as a manager + Listen { + /// If specified, will fork the process to run as a standalone daemon + #[clap(long)] + daemon: bool, + + /// If specified, will listen on a user-local unix socket or local windows named pipe + #[clap(long)] + user: bool, + + #[clap(flatten)] + network: NetworkConfig, + }, + + /// Retrieve information about a specific connection + Info { + id: ConnectionId, + #[clap(flatten)] + network: NetworkConfig, + }, + + /// List information about all connections + List { + #[clap(flatten)] + network: NetworkConfig, + + /// Location to store cached data + #[clap( + long, + value_hint = ValueHint::FilePath, + value_parser, + default_value = CACHE_FILE_PATH_STR.as_str() + )] + cache: PathBuf, + }, + + /// Kill a specific connection + Kill { + #[clap(flatten)] + network: NetworkConfig, + id: ConnectionId, + }, + + /// Send a shutdown request to the manager + Shutdown { + #[clap(flatten)] + network: NetworkConfig, + }, +} + +#[derive(Debug, Subcommand)] +pub enum ManagerServiceSubcommand { + /// Start the manager as a service + Start { + /// Type of service manager used to run this service, defaulting to platform 
native + #[clap(long, value_enum)] + kind: Option, + + /// If specified, starts as a user-level service + #[clap(long)] + user: bool, + }, + + /// Stop the manager as a service + Stop { + #[clap(long, value_enum)] + kind: Option, + + /// If specified, stops a user-level service + #[clap(long)] + user: bool, + }, + + /// Install the manager as a service + Install { + #[clap(long, value_enum)] + kind: Option, + + /// If specified, installs as a user-level service + #[clap(long)] + user: bool, + }, + + /// Uninstall the manager as a service + Uninstall { + #[clap(long, value_enum)] + kind: Option, + + /// If specified, uninstalls a user-level service + #[clap(long)] + user: bool, + }, +} + +impl ManagerSubcommand { + /// Returns true if the manager subcommand is listen + pub fn is_listen(&self) -> bool { + matches!(self, Self::Listen { .. }) + } + + pub fn run(self, config: ManagerConfig) -> CliResult { + match &self { + Self::Listen { daemon, .. } if *daemon => Self::run_daemon(self, config), + _ => { + let rt = tokio::runtime::Runtime::new().context("Failed to start up runtime")?; + rt.block_on(Self::async_run(self, config)) + } + } + } + + #[cfg(windows)] + fn run_daemon(self, _config: ManagerConfig) -> CliResult { + use crate::cli::Spawner; + let pid = Spawner::spawn_running_background(Vec::new()) + .context("Failed to spawn background process")?; + println!("[distant manager detached, pid = {}]", pid); + Ok(()) + } + + #[cfg(unix)] + fn run_daemon(self, config: ManagerConfig) -> CliResult { + use crate::CliError; + use fork::{daemon, Fork}; + + debug!("Forking process"); + match daemon(true, true) { + Ok(Fork::Child) => { + let rt = tokio::runtime::Runtime::new().context("Failed to start up runtime")?; + rt.block_on(async { Self::async_run(self, config).await })?; + Ok(()) + } + Ok(Fork::Parent(pid)) => { + println!("[distant manager detached, pid = {}]", pid); + if fork::close_fd().is_err() { + Err(CliError::Error(anyhow::anyhow!("Fork failed to close fd"))) + } 
else { + Ok(()) + } + } + Err(_) => Err(CliError::Error(anyhow::anyhow!("Fork failed"))), + } + } + + async fn async_run(self, config: ManagerConfig) -> CliResult { + match self { + Self::Service(ManagerServiceSubcommand::Start { kind, user }) => { + debug!("Starting manager service via {:?}", kind); + let mut manager = ::target_or_native(kind) + .context("Failed to detect native service manager")?; + + if user { + manager + .set_level(ServiceLevel::User) + .context("Failed to set service manager to user level")?; + } + + manager + .start(ServiceStartCtx { + label: SERVICE_LABEL.clone(), + }) + .context("Failed to start service")?; + Ok(()) + } + Self::Service(ManagerServiceSubcommand::Stop { kind, user }) => { + debug!("Stopping manager service via {:?}", kind); + let mut manager = ::target_or_native(kind) + .context("Failed to detect native service manager")?; + + if user { + manager + .set_level(ServiceLevel::User) + .context("Failed to set service manager to user level")?; + } + + manager + .stop(ServiceStopCtx { + label: SERVICE_LABEL.clone(), + }) + .context("Failed to stop service")?; + Ok(()) + } + Self::Service(ManagerServiceSubcommand::Install { kind, user }) => { + debug!("Installing manager service via {:?}", kind); + let mut manager = ::target_or_native(kind) + .context("Failed to detect native service manager")?; + let mut args = vec![OsString::from("manager"), OsString::from("listen")]; + + if user { + args.push(OsString::from("--user")); + manager + .set_level(ServiceLevel::User) + .context("Failed to set service manager to user level")?; + } + + manager + .install(ServiceInstallCtx { + label: SERVICE_LABEL.clone(), + + // distant manager listen + program: std::env::current_exe() + .ok() + .unwrap_or_else(|| PathBuf::from("distant")), + args, + }) + .context("Failed to install service")?; + + Ok(()) + } + Self::Service(ManagerServiceSubcommand::Uninstall { kind, user }) => { + debug!("Uninstalling manager service via {:?}", kind); + let mut manager 
= ::target_or_native(kind) + .context("Failed to detect native service manager")?; + if user { + manager + .set_level(ServiceLevel::User) + .context("Failed to set service manager to user level")?; + } + manager + .uninstall(ServiceUninstallCtx { + label: SERVICE_LABEL.clone(), + }) + .context("Failed to uninstall service")?; + + Ok(()) + } + Self::Listen { network, user, .. } => { + let network = network.merge(config.network); + + info!( + "Starting manager (network = {})", + if network.as_opt().is_some() { + "custom" + } else if user { + "user" + } else { + "global" + } + ); + let manager_ref = Manager::new( + DistantManagerConfig { + user, + ..Default::default() + }, + network, + ) + .listen() + .await + .context("Failed to start manager")?; + + // Register our handlers for different schemes + debug!("Registering handlers with manager"); + manager_ref + .register_launch_handler("manager", handlers::ManagerLaunchHandler::new()) + .await + .context("Failed to register launch handler for \"manager://\"")?; + manager_ref + .register_connect_handler("distant", handlers::DistantConnectHandler) + .await + .context("Failed to register connect handler for \"distant://\"")?; + + #[cfg(any(feature = "libssh", feature = "ssh2"))] + // Register ssh-specific handlers if either feature flag is enabled + { + manager_ref + .register_launch_handler("ssh", handlers::SshLaunchHandler) + .await + .context("Failed to register launch handler for \"ssh://\"")?; + manager_ref + .register_connect_handler("ssh", handlers::SshConnectHandler) + .await + .context("Failed to register connect handler for \"ssh://\"")?; + } + + // Let our server run to completion + manager_ref + .wait() + .await + .context("Failed to wait on manager")?; + info!("Manager is shutting down"); + + Ok(()) + } + Self::Info { network, id } => { + let network = network.merge(config.network); + debug!("Getting info about connection {}", id); + let info = Client::new(network) + .connect() + .await + .context("Failed to 
connect to manager")? + .info(id) + .await + .context("Failed to get info about connection")?; + + #[derive(Tabled)] + struct InfoRow { + id: ConnectionId, + scheme: String, + host: String, + port: String, + extra: String, + } + + println!( + "{}", + Table::new(vec![InfoRow { + id: info.id, + scheme: info + .destination + .scheme() + .map(ToString::to_string) + .unwrap_or_default(), + host: info.destination.to_host_string(), + port: info + .destination + .port() + .map(|x| x.to_string()) + .unwrap_or_default(), + extra: info.extra.to_string() + }]) + ); + + Ok(()) + } + Self::List { network, cache } => { + let network = network.merge(config.network); + debug!("Getting list of connections"); + let list = Client::new(network) + .connect() + .await + .context("Failed to connect to manager")? + .list() + .await + .context("Failed to get list of connections")?; + + debug!("Looking up selected connection"); + let selected = Cache::read_from_disk_or_default(cache) + .await + .context("Failed to look up selected connection")? + .data + .selected; + + #[derive(Tabled)] + struct ListRow { + selected: bool, + id: ConnectionId, + scheme: String, + host: String, + port: String, + } + + println!( + "{}", + Table::new(list.into_iter().map(|(id, destination)| { + ListRow { + selected: *selected == id, + id, + scheme: destination + .scheme() + .map(ToString::to_string) + .unwrap_or_default(), + host: destination.to_host_string(), + port: destination + .port() + .map(|x| x.to_string()) + .unwrap_or_default(), + } + })) + ); + + Ok(()) + } + Self::Kill { network, id } => { + let network = network.merge(config.network); + debug!("Killing connection {}", id); + Client::new(network) + .connect() + .await + .context("Failed to connect to manager")? 
+ .kill(id) + .await + .with_context(|| format!("Failed to kill connection to server {id}"))?; + Ok(()) + } + Self::Shutdown { network } => { + let network = network.merge(config.network); + debug!("Shutting down manager"); + Client::new(network) + .connect() + .await + .context("Failed to connect to manager")? + .shutdown() + .await + .context("Failed to shutdown manager")?; + Ok(()) + } + } + } +} diff --git a/src/cli/commands/manager/handlers.rs b/src/cli/commands/manager/handlers.rs new file mode 100644 index 0000000..a4849c6 --- /dev/null +++ b/src/cli/commands/manager/handlers.rs @@ -0,0 +1,404 @@ +use crate::config::ClientLaunchConfig; +use async_trait::async_trait; +use distant_core::{ + net::{ + AuthClient, AuthQuestion, FramedTransport, IntoSplit, SecretKey32, TcpTransport, + XChaCha20Poly1305Codec, + }, + BoxedDistantReader, BoxedDistantWriter, BoxedDistantWriterReader, ConnectHandler, Destination, + Extra, LaunchHandler, +}; +use log::*; +use std::{ + io, + net::{IpAddr, SocketAddr}, + path::PathBuf, + process::Stdio, +}; +use tokio::{ + io::{AsyncBufReadExt, BufReader}, + process::{Child, Command}, + sync::Mutex, +}; + +#[inline] +fn missing(label: &str) -> io::Error { + io::Error::new(io::ErrorKind::InvalidInput, format!("Missing {}", label)) +} + +#[inline] +fn invalid(label: &str) -> io::Error { + io::Error::new(io::ErrorKind::InvalidInput, format!("Invalid {}", label)) +} + +/// Supports launching locally through the manager as defined by `manager://...` +pub struct ManagerLaunchHandler { + servers: Mutex>, +} + +impl ManagerLaunchHandler { + pub fn new() -> Self { + Self { + servers: Mutex::new(Vec::new()), + } + } +} + +#[async_trait] +impl LaunchHandler for ManagerLaunchHandler { + async fn launch( + &self, + destination: &Destination, + extra: &Extra, + _auth_client: &mut AuthClient, + ) -> io::Result { + trace!("Handling launch of {destination} with {extra}"); + let config = ClientLaunchConfig::from(extra.clone()); + + // Get the path to the 
distant binary, ensuring it exists and is executable + let program = which::which(match config.distant.bin { + Some(bin) => PathBuf::from(bin), + None => std::env::current_exe().unwrap_or_else(|_| { + PathBuf::from(if cfg!(windows) { + "distant.exe" + } else { + "distant" + }) + }), + }) + .map_err(|x| io::Error::new(io::ErrorKind::NotFound, x))?; + + // Build our command to run + let mut args = vec![ + String::from("server"), + String::from("listen"), + String::from("--host"), + config + .distant + .bind_server + .as_ref() + .map(ToString::to_string) + .unwrap_or_else(|| String::from("any")), + ]; + + if let Some(port) = destination.port() { + args.push("--port".to_string()); + args.push(port.to_string()); + } + + // Add any extra arguments to the command + if let Some(extra_args) = config.distant.args { + // NOTE: Split arguments based on whether we are running on windows or unix + args.extend(if cfg!(windows) { + winsplit::split(&extra_args) + } else { + shell_words::split(&extra_args) + .map_err(|x| io::Error::new(io::ErrorKind::InvalidInput, x))? 
+ }); + } + + // Spawn it and wait to get the communicated destination + // NOTE: Server will persist until this handler is dropped + let mut command = Command::new(program); + command + .kill_on_drop(true) + .args(args) + .stdin(Stdio::null()) + .stdout(Stdio::piped()) + .stderr(Stdio::piped()); + + debug!("Launching local to manager by spawning command: {command:?}"); + let mut child = command.spawn()?; + + let mut stdout = BufReader::new(child.stdout.take().unwrap()); + + let mut line = String::new(); + loop { + match stdout.read_line(&mut line).await { + Ok(n) if n > 0 => { + if let Ok(destination) = line[..n].trim().parse::() { + // Store a reference to the server so we can terminate them + // when this handler is dropped + self.servers.lock().await.push(child); + + break Ok(destination); + } else { + line.clear(); + } + } + + // If we reach the point of no more data, then fail with EOF + Ok(_) => { + // Ensure that the server is terminated + child.kill().await?; + + break Err(io::Error::new( + io::ErrorKind::UnexpectedEof, + "Missing output destination", + )); + } + + // If we fail to read a line, we assume that the child has completed + // and we missed it, so capture the stderr to report issues + Err(x) => { + let output = child.wait_with_output().await?; + break Err(io::Error::new( + io::ErrorKind::Other, + String::from_utf8(output.stderr).unwrap_or_else(|_| x.to_string()), + )); + } + } + } + } +} + +/// Supports launching remotely via SSH as defined by `ssh://...` +#[cfg(any(feature = "libssh", feature = "ssh2"))] +pub struct SshLaunchHandler; + +#[cfg(any(feature = "libssh", feature = "ssh2"))] +#[async_trait] +impl LaunchHandler for SshLaunchHandler { + async fn launch( + &self, + destination: &Destination, + extra: &Extra, + auth_client: &mut AuthClient, + ) -> io::Result { + trace!("Handling launch of {destination} with {extra}"); + let config = ClientLaunchConfig::from(extra.clone()); + + use distant_ssh2::DistantLaunchOpts; + let mut ssh = 
load_ssh(destination, extra)?; + let handler = AuthClientSshAuthHandler::new(auth_client); + let _ = ssh.authenticate(handler).await?; + let opts = { + let opts = DistantLaunchOpts::default(); + DistantLaunchOpts { + binary: config.distant.bin.unwrap_or(opts.binary), + args: config.distant.args.unwrap_or(opts.args), + use_login_shell: !config.distant.no_shell, + timeout: match extra.get("timeout") { + Some(s) => std::time::Duration::from_millis( + s.parse::().map_err(|_| invalid("timeout"))?, + ), + None => opts.timeout, + }, + } + }; + + debug!("Launching via ssh: {opts:?}"); + ssh.launch(opts).await?.try_to_destination() + } +} + +/// Supports connecting to a remote distant TCP server as defined by `distant://...` +pub struct DistantConnectHandler; + +impl DistantConnectHandler { + pub async fn try_connect(ips: Vec, port: u16) -> io::Result { + // Try each IP address with the same port to see if one works + let mut err = None; + for ip in ips { + let addr = SocketAddr::new(ip, port); + debug!("Attempting to connect to distant server @ {}", addr); + match TcpTransport::connect(addr).await { + Ok(transport) => return Ok(transport), + Err(x) => err = Some(x), + } + } + + // If all failed, return the last error we got + Err(err.expect("Err set above")) + } +} + +#[async_trait] +impl ConnectHandler for DistantConnectHandler { + async fn connect( + &self, + destination: &Destination, + extra: &Extra, + auth_client: &mut AuthClient, + ) -> io::Result { + trace!("Handling connect of {destination} with {extra}"); + let host = destination.to_host_string(); + let port = destination.port().ok_or_else(|| missing("port"))?; + let mut candidate_ips = tokio::net::lookup_host(format!("{}:{}", host, port)) + .await + .map_err(|x| { + io::Error::new( + x.kind(), + format!("{} needs to be resolvable outside of ssh: {}", host, x), + ) + })? 
+ .into_iter() + .map(|addr| addr.ip()) + .collect::>(); + candidate_ips.sort_unstable(); + candidate_ips.dedup(); + if candidate_ips.is_empty() { + return Err(io::Error::new( + io::ErrorKind::AddrNotAvailable, + format!("Unable to resolve {}:{}", host, port), + )); + } + + // Use provided password or extra key if available, otherwise ask for it, and produce a + // codec using the key + let codec = { + let key = destination + .password() + .or_else(|| extra.get("key").map(|s| s.as_str())); + + let key = match key { + Some(key) => key.parse::().map_err(|_| invalid("key"))?, + None => { + let answers = auth_client + .challenge(vec![AuthQuestion::new("key")], Default::default()) + .await?; + answers + .first() + .ok_or_else(|| missing("key"))? + .parse::() + .map_err(|_| invalid("key"))? + } + }; + XChaCha20Poly1305Codec::from(key) + }; + + // Establish a TCP connection, wrap it, and split it out into a writer and reader + let transport = Self::try_connect(candidate_ips, port).await?; + let transport = FramedTransport::new(transport, codec); + let (writer, reader) = transport.into_split(); + let writer: BoxedDistantWriter = Box::new(writer); + let reader: BoxedDistantReader = Box::new(reader); + Ok((writer, reader)) + } +} + +/// Supports connecting to a remote SSH server as defined by `ssh://...` +#[cfg(any(feature = "libssh", feature = "ssh2"))] +pub struct SshConnectHandler; + +#[cfg(any(feature = "libssh", feature = "ssh2"))] +#[async_trait] +impl ConnectHandler for SshConnectHandler { + async fn connect( + &self, + destination: &Destination, + extra: &Extra, + auth_client: &mut AuthClient, + ) -> io::Result { + trace!("Handling connect of {destination} with {extra}"); + let mut ssh = load_ssh(destination, extra)?; + let handler = AuthClientSshAuthHandler::new(auth_client); + let _ = ssh.authenticate(handler).await?; + ssh.into_distant_writer_reader().await + } +} + +#[cfg(any(feature = "libssh", feature = "ssh2"))] +struct AuthClientSshAuthHandler<'a>(Mutex<&'a 
mut AuthClient>); + +#[cfg(any(feature = "libssh", feature = "ssh2"))] +impl<'a> AuthClientSshAuthHandler<'a> { + pub fn new(auth_client: &'a mut AuthClient) -> Self { + Self(Mutex::new(auth_client)) + } +} + +#[cfg(any(feature = "libssh", feature = "ssh2"))] +#[async_trait] +impl<'a> distant_ssh2::SshAuthHandler for AuthClientSshAuthHandler<'a> { + async fn on_authenticate(&self, event: distant_ssh2::SshAuthEvent) -> io::Result> { + use std::collections::HashMap; + let mut extra = HashMap::new(); + let mut questions = Vec::new(); + + for prompt in event.prompts { + let mut extra = HashMap::new(); + extra.insert("echo".to_string(), prompt.echo.to_string()); + questions.push(AuthQuestion { + text: prompt.prompt, + extra, + }); + } + + extra.insert("instructions".to_string(), event.instructions); + extra.insert("username".to_string(), event.username); + + self.0.lock().await.challenge(questions, extra).await + } + + async fn on_verify_host(&self, host: &str) -> io::Result { + use distant_core::net::AuthVerifyKind; + self.0 + .lock() + .await + .verify(AuthVerifyKind::Host, host.to_string()) + .await + } + + async fn on_banner(&self, text: &str) { + if let Err(x) = self.0.lock().await.info(text.to_string()).await { + error!("ssh on_banner failed: {}", x); + } + } + + async fn on_error(&self, text: &str) { + use distant_core::net::AuthErrorKind; + if let Err(x) = self + .0 + .lock() + .await + .error(AuthErrorKind::Unknown, text.to_string()) + .await + { + error!("ssh on_error failed: {}", x); + } + } +} + +#[cfg(any(feature = "libssh", feature = "ssh2"))] +fn load_ssh(destination: &Destination, extra: &Extra) -> io::Result { + use distant_ssh2::{Ssh, SshOpts}; + + let host = destination.to_host_string(); + + let opts = SshOpts { + backend: match extra.get("backend") { + Some(s) => s.parse().map_err(|_| invalid("backend"))?, + None => Default::default(), + }, + + identity_files: extra + .get("identity_files") + .map(|s| s.split(',').map(|s| 
PathBuf::from(s.trim())).collect()) + .unwrap_or_default(), + + identities_only: match extra.get("identities_only") { + Some(s) => Some(s.parse().map_err(|_| invalid("identities_only"))?), + None => None, + }, + + port: destination.port(), + + proxy_command: extra.get("proxy_command").cloned(), + + user: destination.username().map(ToString::to_string), + + user_known_hosts_files: extra + .get("user_known_hosts_files") + .map(|s| s.split(',').map(|s| PathBuf::from(s.trim())).collect()) + .unwrap_or_default(), + + verbose: match extra.get("verbose") { + Some(s) => s.parse().map_err(|_| invalid("verbose"))?, + None => false, + }, + + ..Default::default() + }; + Ssh::connect(host, opts) +} diff --git a/src/cli/commands/server.rs b/src/cli/commands/server.rs new file mode 100644 index 0000000..781d161 --- /dev/null +++ b/src/cli/commands/server.rs @@ -0,0 +1,238 @@ +use crate::{ + config::{BindAddress, ServerConfig, ServerListenConfig}, + CliError, CliResult, +}; +use anyhow::Context; +use clap::Subcommand; +use distant_core::{ + net::{SecretKey32, ServerRef, TcpServerExt, XChaCha20Poly1305Codec}, + DistantApiServer, DistantSingleKeyCredentials, +}; +use log::*; +use std::io::{self, Read, Write}; + +#[derive(Debug, Subcommand)] +pub enum ServerSubcommand { + /// Listen for incoming requests as a server + Listen { + #[clap(flatten)] + config: ServerListenConfig, + + /// If specified, will fork the process to run as a standalone daemon + #[clap(long)] + daemon: bool, + + /// If specified, the server will not generate a key but instead listen on stdin for the next + /// 32 bytes that it will use as the key instead. 
Receiving less than 32 bytes before stdin + /// is closed is considered an error and any bytes after the first 32 are not used for the key + #[clap(long)] + key_from_stdin: bool, + + /// If specified, will send output to the specified named pipe (internal usage) + #[cfg(windows)] + #[clap(long, help = None, long_help = None)] + output_to_local_pipe: Option, + }, +} + +impl ServerSubcommand { + pub fn run(self, _config: ServerConfig) -> CliResult { + match &self { + Self::Listen { daemon, .. } if *daemon => Self::run_daemon(self), + Self::Listen { .. } => { + let rt = tokio::runtime::Runtime::new().context("Failed to start up runtime")?; + rt.block_on(Self::async_run(self, false)) + } + } + } + + #[cfg(windows)] + fn run_daemon(self) -> CliResult { + use crate::cli::Spawner; + use distant_core::net::{Listener, WindowsPipeListener}; + use std::ffi::OsString; + use tokio::io::AsyncReadExt; + let rt = tokio::runtime::Runtime::new().context("Failed to start up runtime")?; + rt.block_on(async { + let name = format!("distant_{}_{}", std::process::id(), rand::random::()); + let mut listener = WindowsPipeListener::bind_local(name.as_str()) + .with_context(|| "Failed to bind to local named pipe {name:?}")?; + + let pid = Spawner::spawn_running_background(vec![ + OsString::from("--output-to-local-pipe"), + OsString::from(name), + ]) + .context("Failed to spawn background process")?; + println!("[distant server detached, pid = {}]", pid); + + // Wait to receive a connection from the above process + let mut transport = listener.accept().await.context( + "Failed to receive connection from background process to send credentials", + )?; + + // Get the credentials and print them + let mut s = String::new(); + let n = transport + .read_to_string(&mut s) + .await + .context("Failed to receive credentials")?; + if n == 0 { + anyhow::bail!("No credentials received from spawned server"); + } + let credentials = s[..n] + .trim() + .parse::() + .context("Failed to parse server 
credentials")?; + + println!("\r"); + println!("{}", credentials); + println!("\r"); + io::stdout() + .flush() + .context("Failed to print server credentials")?; + Ok(()) + }) + .map_err(CliError::Error) + } + + #[cfg(unix)] + fn run_daemon(self) -> CliResult { + use fork::{daemon, Fork}; + + // NOTE: We keep the stdin, stdout, stderr open so we can print out the pid with the parent + debug!("Forking process"); + match daemon(true, true) { + Ok(Fork::Child) => { + let rt = tokio::runtime::Runtime::new().context("Failed to start up runtime")?; + rt.block_on(async { Self::async_run(self, true).await })?; + Ok(()) + } + Ok(Fork::Parent(pid)) => { + println!("[distant server detached, pid = {}]", pid); + if fork::close_fd().is_err() { + Err(CliError::Error(anyhow::anyhow!("Fork failed to close fd"))) + } else { + Ok(()) + } + } + Err(_) => Err(CliError::Error(anyhow::anyhow!("Fork failed"))), + } + } + + async fn async_run(self, _is_forked: bool) -> CliResult { + match self { + Self::Listen { + config, + key_from_stdin, + #[cfg(windows)] + output_to_local_pipe, + .. 
+ } => { + let host = config.host.unwrap_or(BindAddress::Any); + trace!("Starting server using unresolved host '{}'", host); + let addr = host.resolve(config.use_ipv6)?; + + // If specified, change the current working directory of this program + if let Some(path) = config.current_dir.as_ref() { + debug!("Setting current directory to {:?}", path); + std::env::set_current_dir(path) + .context("Failed to set new current directory")?; + } + + // Bind & start our server + let key = if key_from_stdin { + debug!("Reading secret key from stdin"); + let mut buf = [0u8; 32]; + io::stdin() + .read_exact(&mut buf) + .context("Failed to read secret key from stdin")?; + SecretKey32::from(buf) + } else { + SecretKey32::default() + }; + + let codec = XChaCha20Poly1305Codec::new(key.unprotected_as_bytes()); + + debug!( + "Starting local API server, binding to {} {}", + addr, + match config.port { + Some(range) => format!("with port in range {}", range), + None => "using an ephemeral port".to_string(), + } + ); + let server = DistantApiServer::local() + .context("Failed to create local distant api")? + .start(addr, config.port.unwrap_or_else(|| 0.into()), codec) + .await + .with_context(|| { + format!( + "Failed to start server @ {} with {}", + addr, + config + .port + .map(|p| format!("port in range {p}")) + .unwrap_or_else(|| String::from("ephemeral port")) + ) + })?; + + let credentials = DistantSingleKeyCredentials { + host: addr.to_string(), + port: server.port(), + key, + username: None, + }; + info!( + "Server listening at {}:{}", + credentials.host, credentials.port + ); + + // Print information about port, key, etc. 
+ // NOTE: Following mosh approach of printing to make sure there's no garbage floating around + #[cfg(not(windows))] + { + println!("\r"); + println!("{}", credentials); + println!("\r"); + io::stdout() + .flush() + .context("Failed to print credentials")?; + } + + #[cfg(windows)] + if let Some(name) = output_to_local_pipe { + use distant_core::net::WindowsPipeTransport; + use tokio::io::AsyncWriteExt; + let mut transport = WindowsPipeTransport::connect_local(&name) + .await + .with_context(|| { + format!("Failed to connect to local pipe named {name:?}") + })?; + transport + .write_all(credentials.to_string().as_bytes()) + .await + .context("Failed to send credentials through pipe")?; + } else { + println!("\r"); + println!("{}", credentials); + println!("\r"); + io::stdout() + .flush() + .context("Failed to print credentials")?; + } + + // For the child, we want to fully disconnect it from pipes, which we do now + #[cfg(unix)] + if _is_forked && fork::close_fd().is_err() { + return Err(CliError::Error(anyhow::anyhow!("Fork failed to close fd"))); + } + + // Let our server run to completion + server.wait().await.context("Failed to wait on server")?; + info!("Server is shutting down"); + } + } + + Ok(()) + } +} diff --git a/src/cli/manager.rs b/src/cli/manager.rs new file mode 100644 index 0000000..24a8fd2 --- /dev/null +++ b/src/cli/manager.rs @@ -0,0 +1,77 @@ +use crate::{ + config::NetworkConfig, + paths::{global as global_paths, user as user_paths}, +}; +use anyhow::Context; +use distant_core::{net::PlainCodec, DistantManager, DistantManagerConfig, DistantManagerRef}; +use log::*; + +pub struct Manager { + config: DistantManagerConfig, + network: NetworkConfig, +} + +impl Manager { + pub fn new(config: DistantManagerConfig, network: NetworkConfig) -> Self { + Self { config, network } + } + + /// Begin listening on the network interface specified within [`NetworkConfig`] + pub async fn listen(self) -> anyhow::Result { + let user = self.config.user; + + 
#[cfg(unix)] + { + let socket_path = self.network.unix_socket.as_deref().unwrap_or({ + if user { + user_paths::UNIX_SOCKET_PATH.as_path() + } else { + global_paths::UNIX_SOCKET_PATH.as_path() + } + }); + + // Ensure that the path to the socket exists + if let Some(parent) = socket_path.parent() { + tokio::fs::create_dir_all(parent) + .await + .with_context(|| format!("Failed to create socket directory {parent:?}"))?; + } + + let boxed_ref = DistantManager::start_unix_socket_with_permissions( + self.config, + socket_path, + PlainCodec, + self.network.access.unwrap_or_default().into_mode(), + ) + .await + .with_context(|| format!("Failed to start manager at socket {socket_path:?}"))? + .into_inner() + .into_boxed_server_ref() + .map_err(|_| anyhow::anyhow!("Got wrong server ref"))?; + + info!("Manager listening using unix socket @ {:?}", socket_path); + Ok(*boxed_ref) + } + + #[cfg(windows)] + { + let pipe_name = self.network.windows_pipe.as_deref().unwrap_or(if user { + user_paths::WINDOWS_PIPE_NAME.as_str() + } else { + global_paths::WINDOWS_PIPE_NAME.as_str() + }); + let boxed_ref = + DistantManager::start_local_named_pipe(self.config, pipe_name, PlainCodec) + .await + .with_context(|| { + format!("Failed to start manager with pipe named '{pipe_name}'") + })? 
+ .into_inner() + .into_boxed_server_ref() + .map_err(|_| anyhow::anyhow!("Got wrong server ref"))?; + + info!("Manager listening using local named pipe @ {:?}", pipe_name); + Ok(*boxed_ref) + } + } +} diff --git a/src/cli/spawner.rs b/src/cli/spawner.rs new file mode 100644 index 0000000..5816746 --- /dev/null +++ b/src/cli/spawner.rs @@ -0,0 +1,211 @@ +use anyhow::Context; +use log::*; +use std::{ + ffi::{OsStr, OsString}, + path::PathBuf, + process::{Command, Stdio}, +}; + +/// Utility functions to spawn a process in the background +#[allow(dead_code)] +pub struct Spawner; + +#[allow(dead_code)] +impl Spawner { + /// Spawns a new instance of this running process without a `--daemon` flag, + /// returning the id of the spawned process + pub fn spawn_running_background(extra_args: Vec) -> anyhow::Result { + let cmd = Self::make_current_cmd(extra_args, "--daemon")?; + + #[cfg(windows)] + let cmd = { + let mut s = OsString::new(); + s.push("'"); + s.push(&cmd); + s.push("'"); + s + }; + + Self::spawn_background(cmd) + } + + #[inline] + fn make_current_cmd(extra_args: Vec, exclude: &str) -> anyhow::Result { + // Get absolute path to our binary + let program = which::which(std::env::current_exe().unwrap_or_else(|_| { + PathBuf::from(if cfg!(windows) { + "distant.exe" + } else { + "distant" + }) + })) + .context("Failed to locate distant binary")?; + + // Remove --daemon argument to to ensure runs in foreground, + // otherwise we would fork bomb ourselves + // + // Also, remove first argument (program) since we determined it above + let mut cmd = OsString::new(); + cmd.push(program.as_os_str()); + + let it = std::env::args_os() + .skip(1) + .filter(|arg| { + !arg.to_str() + .map(|s| s.trim().eq_ignore_ascii_case(exclude)) + .unwrap_or_default() + }) + .chain(extra_args.into_iter()); + for arg in it { + cmd.push(" "); + cmd.push(&arg); + } + + Ok(cmd) + } +} + +#[cfg(unix)] +#[allow(dead_code)] +impl Spawner { + /// Spawns a process on Unix that runs in the background 
and won't be terminated when the + /// parent process exits + pub fn spawn_background(cmd: impl AsRef) -> anyhow::Result { + let cmd = cmd + .as_ref() + .to_str() + .ok_or_else(|| anyhow::anyhow!("cmd is not a UTF-8 str"))?; + + // Build out the command and args from our string + let (cmd, args) = match cmd.split_once(' ') { + Some((cmd_str, args_str)) => ( + cmd_str, + shell_words::split(args_str).context("Failed to split process arguments")?, + ), + None => (cmd, Vec::new()), + }; + + debug!("Spawning background process: {}", cmd); + let child = Command::new(cmd) + .args(args) + .stdin(Stdio::null()) + .stdout(Stdio::null()) + .stderr(Stdio::null()) + .spawn() + .context("Failed to spawn background process")?; + Ok(child.id()) + } +} + +#[cfg(windows)] +impl Spawner { + /// Spawns a process on Windows that runs in the background without a console and does not get + /// terminated when the parent or other ancestors terminate (such as openssh session) + pub fn spawn_background(cmd: impl AsRef) -> anyhow::Result { + use std::{ + io::{BufRead, Cursor}, + os::windows::process::CommandExt, + }; + + // Get absolute path to powershell + let powershell = which::which("powershell.exe").context("Failed to find powershell.exe")?; + + // Pass along our environment variables + let env = { + let mut s = OsString::new(); + s.push(r#"$startup.Properties['EnvironmentVariables'].value=@("#); + let mut first = true; + for (key, value) in std::env::vars_os() { + if !first { + s.push(","); + } else { + first = false; + } + + s.push("'"); + s.push(key); + s.push("="); + s.push(value); + s.push("'"); + } + s.push(")"); + s + }; + + let args = vec![ + OsString::from(r#"$startup=[wmiclass]"Win32_ProcessStartup""#), + OsString::from(";"), + OsString::from(r#"$startup.Properties['ShowWindow'].value=$False"#), + OsString::from(";"), + env, + OsString::from(";"), + OsString::from("Invoke-WmiMethod"), + OsString::from("-Class"), + OsString::from("Win32_Process"), + OsString::from("-Name"), + 
OsString::from("Create"), + OsString::from("-ArgumentList"), + { + let mut arg_list = OsString::new(); + arg_list.push(cmd.as_ref()); + arg_list.push(",$null,$startup"); + arg_list + }, + ]; + + // const DETACHED_PROCESS: u32 = 0x00000008; + const CREATE_NEW_PROCESS_GROUP: u32 = 0x00000200; + const CREATE_NO_WINDOW: u32 = 0x08000000; + let flags = CREATE_NEW_PROCESS_GROUP | CREATE_NO_WINDOW; + + debug!( + "Spawning background process: {} {:?}", + powershell.to_string_lossy(), + args + ); + let output = Command::new(powershell.into_os_string()) + .creation_flags(flags) + .args(args) + .stdin(Stdio::null()) + .stdout(Stdio::piped()) + .stderr(Stdio::piped()) + .output() + .context("Failed to spawn background process")?; + + if !output.status.success() { + anyhow::bail!( + "Program failed [{}]: {}", + output.status.code().unwrap_or(1), + String::from_utf8_lossy(&output.stderr) + ); + } + + let stdout = Cursor::new(output.stdout); + + let mut process_id = None; + let mut return_value = None; + for line in stdout.lines().filter_map(|l| l.ok()) { + let line = line.trim(); + if line.starts_with("ProcessId") { + if let Some((_, id)) = line.split_once(':') { + process_id = id.trim().parse::().ok(); + } + } else if line.starts_with("ReturnValue") { + if let Some((_, value)) = line.split_once(':') { + return_value = value.trim().parse::().ok(); + } + } + } + + match (return_value, process_id) { + (Some(0), Some(pid)) => Ok(pid), + (Some(0), None) => anyhow::bail!("Program succeeded, but missing process pid"), + (Some(code), _) => anyhow::bail!( + "Program failed [{}]: {}", + code, + String::from_utf8_lossy(&output.stderr) + ), + (None, _) => anyhow::bail!("Missing return value"), + } + } +} diff --git a/src/config.rs b/src/config.rs new file mode 100644 index 0000000..e0835ef --- /dev/null +++ b/src/config.rs @@ -0,0 +1,126 @@ +use crate::paths; +use anyhow::Context; +use serde::{Deserialize, Serialize}; +use std::{ + io, + path::{Path, PathBuf}, +}; +use toml_edit::Document; 
+ +mod client; +mod common; +mod generate; +mod manager; +mod network; +mod server; + +pub use client::*; +pub use common::*; +pub use generate::*; +pub use manager::*; +pub use network::*; +pub use server::*; + +/// Represents configuration settings for all of distant +#[derive(Debug, Default, Serialize, Deserialize)] +pub struct Config { + pub client: ClientConfig, + pub generate: GenerateConfig, + pub manager: ManagerConfig, + pub server: ServerConfig, +} + +impl Config { + /// Loads the configuration from multiple sources in a blocking fashion + /// + /// 1. If `custom` is provided, it is used by itself as the source for configuration + /// 2. Otherwise, if `custom` is not provided, will attempt to load from global and user + /// config files, merging together if they both exist + /// 3. Otherwise if no `custom` path and none of the standard configuration paths exist, + /// then the default configuration is returned instead + pub fn load_multi(custom: Option) -> anyhow::Result { + match custom { + Some(path) => { + toml_edit::de::from_slice(&std::fs::read(path)?).context("Failed to parse config") + } + None => { + let paths = vec![ + paths::global::CONFIG_FILE_PATH.as_path(), + paths::user::CONFIG_FILE_PATH.as_path(), + ]; + + match (paths[0].exists(), paths[1].exists()) { + // At least one standard path exists, so load it + (exists_1, exists_2) if exists_1 || exists_2 => { + use config::{Config, File}; + let config = Config::builder() + .add_source(File::from(paths[0]).required(exists_1)) + .add_source(File::from(paths[1]).required(exists_2)) + .build() + .context("Failed to build config from paths")?; + config.try_deserialize().context("Failed to parse config") + } + + // None of our standard paths exist, so use the default value instead + _ => Ok(Self::default()), + } + } + } + } + + /// Loads the specified `path` as a [`Config`] + pub async fn load(path: impl AsRef) -> anyhow::Result { + let bytes = tokio::fs::read(path.as_ref()) + .await + .with_context(|| 
format!("Failed to read config file {:?}", path.as_ref()))?; + toml_edit::de::from_slice(&bytes).context("Failed to parse config") + } + + /// Like `edit` but will succeed without invoking `f` if the path is not found + pub async fn edit_if_exists( + path: impl AsRef, + f: impl FnOnce(&mut Document) -> io::Result<()>, + ) -> io::Result<()> { + Self::edit(path, f).await.or_else(|x| { + if x.kind() == io::ErrorKind::NotFound { + Ok(()) + } else { + Err(x) + } + }) + } + + /// Loads the specified `path` as a [`Document`], performs changes to the document using `f`, + /// and overwrites the `path` with the updated [`Document`] + pub async fn edit( + path: impl AsRef, + f: impl FnOnce(&mut Document) -> io::Result<()>, + ) -> io::Result<()> { + let mut document = tokio::fs::read_to_string(path.as_ref()) + .await? + .parse::() + .map_err(|x| io::Error::new(io::ErrorKind::InvalidData, x))?; + f(&mut document)?; + tokio::fs::write(path, document.to_string()).await + } + + /// Saves the [`Config`] to the specified `path` only if the path points to no file + pub async fn save_if_not_found(&self, path: impl AsRef) -> io::Result<()> { + use tokio::io::AsyncWriteExt; + let text = toml_edit::ser::to_string_pretty(self) + .map_err(|x| io::Error::new(io::ErrorKind::InvalidData, x))?; + tokio::fs::OpenOptions::new() + .create_new(true) + .open(path) + .await? 
+ .write_all(text.as_bytes()) + .await + } + + /// Saves the [`Config`] to the specified `path`, overwriting the file if it exists + pub async fn save(&self, path: impl AsRef) -> io::Result<()> { + let text = toml_edit::ser::to_string_pretty(self) + .map_err(|x| io::Error::new(io::ErrorKind::InvalidData, x))?; + tokio::fs::write(path, text).await + } +} diff --git a/src/config/client.rs b/src/config/client.rs new file mode 100644 index 0000000..3b568b2 --- /dev/null +++ b/src/config/client.rs @@ -0,0 +1,24 @@ +use super::{CommonConfig, NetworkConfig}; +use serde::{Deserialize, Serialize}; + +mod action; +mod launch; +mod repl; + +pub use action::*; +pub use launch::*; +pub use repl::*; + +/// Represents configuration settings for the distant client +#[derive(Debug, Default, Serialize, Deserialize)] +pub struct ClientConfig { + #[serde(flatten)] + pub common: CommonConfig, + + pub action: ClientActionConfig, + pub launch: ClientLaunchConfig, + pub repl: ClientReplConfig, + + #[serde(flatten)] + pub network: NetworkConfig, +} diff --git a/src/config/client/action.rs b/src/config/client/action.rs new file mode 100644 index 0000000..f4f7831 --- /dev/null +++ b/src/config/client/action.rs @@ -0,0 +1,7 @@ +use serde::{Deserialize, Serialize}; + +#[derive(Debug, Default, Serialize, Deserialize)] +pub struct ClientActionConfig { + /// Represents the maximum time (in seconds) to wait for a network request before timing out + pub timeout: Option, +} diff --git a/src/config/client/launch.rs b/src/config/client/launch.rs new file mode 100644 index 0000000..7c3cb83 --- /dev/null +++ b/src/config/client/launch.rs @@ -0,0 +1,162 @@ +use crate::config::BindAddress; +use clap::Args; +use distant_core::Map; +use serde::{Deserialize, Serialize}; +use std::path::PathBuf; + +#[derive(Args, Debug, Default, Serialize, Deserialize)] +pub struct ClientLaunchConfig { + #[clap(flatten)] + #[serde(flatten)] + pub distant: ClientLaunchDistantConfig, + + #[clap(flatten)] + #[serde(flatten)] + 
pub ssh: ClientLaunchSshConfig, +} + +impl From for ClientLaunchConfig { + fn from(mut map: Map) -> Self { + Self { + distant: ClientLaunchDistantConfig { + bin: map.remove("distant.bin"), + bind_server: map + .remove("distant.bind_server") + .and_then(|x| x.parse::().ok()), + args: map.remove("distant.args"), + no_shell: map + .remove("distant.no_shell") + .and_then(|x| x.parse::().ok()) + .unwrap_or_default(), + }, + ssh: ClientLaunchSshConfig { + bin: map.remove("ssh.bind"), + #[cfg(any(feature = "libssh", feature = "ssh2"))] + backend: map + .remove("ssh.backend") + .and_then(|x| x.parse::().ok()), + external: map + .remove("ssh.external") + .and_then(|x| x.parse::().ok()) + .unwrap_or_default(), + username: map.remove("ssh.username"), + identity_file: map + .remove("ssh.identity_file") + .and_then(|x| x.parse::().ok()), + port: map.remove("ssh.port").and_then(|x| x.parse::().ok()), + }, + } + } +} + +impl From for Map { + fn from(config: ClientLaunchConfig) -> Self { + let mut this = Self::new(); + + if let Some(x) = config.distant.bin { + this.insert("distant.bin".to_string(), x); + } + + if let Some(x) = config.distant.bind_server { + this.insert("distant.bind_server".to_string(), x.to_string()); + } + + if let Some(x) = config.distant.args { + this.insert("distant.args".to_string(), x); + } + + this.insert( + "distant.no_shell".to_string(), + config.distant.no_shell.to_string(), + ); + + if let Some(x) = config.ssh.bin { + this.insert("ssh.bin".to_string(), x); + } + + #[cfg(any(feature = "libssh", feature = "ssh2"))] + if let Some(x) = config.ssh.backend { + this.insert("ssh.backend".to_string(), x.to_string()); + } + + this.insert("ssh.external".to_string(), config.ssh.external.to_string()); + + if let Some(x) = config.ssh.username { + this.insert("ssh.username".to_string(), x); + } + + if let Some(x) = config.ssh.identity_file { + this.insert( + "ssh.identity_file".to_string(), + x.to_string_lossy().to_string(), + ); + } + + if let Some(x) = 
config.ssh.port { + this.insert("ssh.port".to_string(), x.to_string()); + } + + this + } +} + +#[derive(Args, Debug, Default, Serialize, Deserialize)] +pub struct ClientLaunchDistantConfig { + /// Path to distant program on remote machine to execute via ssh; + /// by default, this program needs to be available within PATH as + /// specified when compiling ssh (not your login shell) + #[clap(name = "distant", long)] + pub bin: Option, + + /// Control the IP address that the server binds to. + /// + /// The default is `ssh', in which case the server will reply from the IP address that the SSH + /// connection came from (as found in the SSH_CONNECTION environment variable). This is + /// useful for multihomed servers. + /// + /// With --bind-server=any, the server will reply on the default interface and will not bind to + /// a particular IP address. This can be useful if the connection is made through sslh or + /// another tool that makes the SSH connection appear to come from localhost. + /// + /// With --bind-server=IP, the server will attempt to bind to the specified IP address. 
+ #[clap(name = "distant-bind-server", long, value_name = "ssh|any|IP")] + pub bind_server: Option, + + /// Additional arguments to provide to the server + #[clap(name = "distant-args", long, allow_hyphen_values(true))] + pub args: Option, + + /// If specified, will not launch distant using a login shell but instead execute it directly + #[clap(long)] + pub no_shell: bool, +} + +#[derive(Args, Debug, Default, Serialize, Deserialize)] +pub struct ClientLaunchSshConfig { + /// Path to ssh program on local machine to execute when using external ssh + #[clap(name = "ssh", long)] + pub bin: Option, + + /// If using native ssh integration, represents the backend + #[cfg(any(feature = "libssh", feature = "ssh2"))] + #[clap(name = "ssh-backend", long)] + pub backend: Option, + + /// If specified, will use the external ssh program to launch the server + /// instead of the native integration; does nothing if the ssh2 feature is + /// not enabled as there is no other option than external ssh + #[clap(name = "ssh-external", long)] + pub external: bool, + + /// Username to use when sshing into remote machine + #[clap(name = "ssh-username", short = 'u', long)] + pub username: Option, + + /// Explicit identity file to use with ssh + #[clap(name = "ssh-identity-file", short = 'i', long)] + pub identity_file: Option, + + /// Port to use for sshing into the remote machine + #[clap(name = "ssh-port", short = 'p', long)] + pub port: Option, +} diff --git a/src/config/client/repl.rs b/src/config/client/repl.rs new file mode 100644 index 0000000..f2f0482 --- /dev/null +++ b/src/config/client/repl.rs @@ -0,0 +1,7 @@ +use serde::{Deserialize, Serialize}; + +#[derive(Debug, Default, Serialize, Deserialize)] +pub struct ClientReplConfig { + /// Represents the maximum time (in seconds) to wait for a network request before timing out + pub timeout: Option, +} diff --git a/src/config/common.rs b/src/config/common.rs new file mode 100644 index 0000000..786eccb --- /dev/null +++ 
b/src/config/common.rs @@ -0,0 +1,52 @@ +use clap::{Args, ValueEnum}; +use serde::{Deserialize, Serialize}; +use std::path::PathBuf; + +#[derive(Copy, Clone, Debug, PartialEq, Eq, Serialize, Deserialize, ValueEnum)] +#[clap(rename_all = "snake_case")] +#[serde(rename_all = "snake_case")] +pub enum LogLevel { + Off, + Error, + Warn, + Info, + Debug, + Trace, +} + +impl LogLevel { + pub fn to_log_level_filter(self) -> log::LevelFilter { + match self { + Self::Off => log::LevelFilter::Off, + Self::Error => log::LevelFilter::Error, + Self::Warn => log::LevelFilter::Warn, + Self::Info => log::LevelFilter::Info, + Self::Debug => log::LevelFilter::Debug, + Self::Trace => log::LevelFilter::Trace, + } + } +} + +impl Default for LogLevel { + fn default() -> Self { + Self::Info + } +} + +/// Contains options that are common across subcommands +#[derive(Args, Clone, Debug, Default, Serialize, Deserialize)] +pub struct CommonConfig { + /// Log level to use throughout the application + #[clap(long, global = true, case_insensitive = true, value_enum)] + pub log_level: Option, + + /// Path to file to use for logging + #[clap(long, global = true)] + pub log_file: Option, +} + +impl CommonConfig { + pub fn log_level_or_default(&self) -> LogLevel { + self.log_level.as_ref().copied().unwrap_or_default() + } +} diff --git a/src/config/generate.rs b/src/config/generate.rs new file mode 100644 index 0000000..a5f4863 --- /dev/null +++ b/src/config/generate.rs @@ -0,0 +1,9 @@ +use super::CommonConfig; +use serde::{Deserialize, Serialize}; + +/// Represents configuration settings for the distant generate +#[derive(Debug, Default, Serialize, Deserialize)] +pub struct GenerateConfig { + #[serde(flatten)] + pub common: CommonConfig, +} diff --git a/src/config/manager.rs b/src/config/manager.rs new file mode 100644 index 0000000..7d87060 --- /dev/null +++ b/src/config/manager.rs @@ -0,0 +1,45 @@ +use super::{CommonConfig, NetworkConfig}; +use clap::Args; +use distant_core::Destination; +use 
serde::{Deserialize, Serialize}; +use service_manager::ServiceManagerKind; + +/// Represents configuration settings for the distant manager +#[derive(Args, Debug, Default, Serialize, Deserialize)] +pub struct ManagerConfig { + #[clap(flatten)] + #[serde(flatten)] + pub common: CommonConfig, + + #[clap(skip)] + pub connections: Vec, + + #[clap(flatten)] + #[serde(flatten)] + pub network: NetworkConfig, + + #[clap(value_enum)] + pub service: Option, +} + +/// Represents configuration for some managed connection +#[derive(Debug, Serialize, Deserialize)] +pub enum ManagerConnectionConfig { + Distant(ManagerDistantConnectionConfig), + Ssh(ManagerSshConnectionConfig), +} + +/// Represents configuration for a distant connection +#[derive(Debug, Serialize, Deserialize)] +pub struct ManagerDistantConnectionConfig { + pub name: String, + pub destination: Destination, + pub key_cmd: Option, +} + +/// Represents configuration for an SSH connection +#[derive(Debug, Serialize, Deserialize)] +pub struct ManagerSshConnectionConfig { + pub name: String, + pub destination: Destination, +} diff --git a/src/config/network.rs b/src/config/network.rs new file mode 100644 index 0000000..2dda804 --- /dev/null +++ b/src/config/network.rs @@ -0,0 +1,105 @@ +use clap::Args; +use serde::{Deserialize, Serialize}; + +/// Level of access control to the unix socket or windows pipe +#[derive(Copy, Clone, Debug, PartialEq, Eq, Serialize, Deserialize, clap::ValueEnum)] +#[clap(rename_all = "snake_case")] +#[serde(rename_all = "snake_case")] +pub enum AccessControl { + /// Equates to `0o600` on Unix (read & write for owner) + Owner, + + /// Equates to `0o660` on Unix (read & write for owner and group) + Group, + + /// Equates to `0o666` on Unix (read & write for owner, group, and other) + Anyone, +} + +impl AccessControl { + /// Converts into a Unix file permission octal + pub fn into_mode(self) -> u32 { + match self { + Self::Owner => 0o600, + Self::Group => 0o660, + Self::Anyone => 0o666, + } + } 
+} + +impl Default for AccessControl { + /// Defaults to owner-only permissions + fn default() -> Self { + Self::Owner + } +} + +/// Represents common networking configuration +#[derive(Args, Clone, Debug, Default, Serialize, Deserialize)] +pub struct NetworkConfig { + /// Type of access to apply to created unix socket or windows pipe + #[clap(long, value_enum)] + pub access: Option, + + /// Override the path to the Unix socket used by the manager + #[cfg(unix)] + #[clap(long)] + pub unix_socket: Option, + + /// Override the name of the local named Windows pipe used by the manager + #[cfg(windows)] + #[clap(long)] + pub windows_pipe: Option, +} + +impl NetworkConfig { + pub fn merge(self, other: Self) -> Self { + Self { + access: self.access.or(other.access), + + #[cfg(unix)] + unix_socket: self.unix_socket.or(other.unix_socket), + + #[cfg(windows)] + windows_pipe: self.windows_pipe.or(other.windows_pipe), + } + } + + /// Returns option containing reference to unix path if configured + #[cfg(unix)] + pub fn as_opt(&self) -> Option<&std::path::Path> { + self.unix_socket.as_deref() + } + + /// Returns option containing reference to windows pipe name if configured + #[cfg(windows)] + pub fn as_opt(&self) -> Option<&str> { + self.windows_pipe.as_deref() + } + + /// Returns a collection of candidate unix socket paths, which will either be + /// the config-provided unix socket path or the default user and global socket paths + #[cfg(unix)] + pub fn to_unix_socket_path_candidates(&self) -> Vec<&std::path::Path> { + match self.unix_socket.as_deref() { + Some(path) => vec![path], + None => vec![ + crate::paths::user::UNIX_SOCKET_PATH.as_path(), + crate::paths::global::UNIX_SOCKET_PATH.as_path(), + ], + } + } + + /// Returns a collection of candidate windows pipe names, which will either be + /// the config-provided windows pipe name or the default user and global pipe names + #[cfg(windows)] + pub fn to_windows_pipe_name_candidates(&self) -> Vec<&str> { + match 
self.windows_pipe.as_deref() { + Some(name) => vec![name], + None => vec![ + crate::paths::user::WINDOWS_PIPE_NAME.as_str(), + crate::paths::global::WINDOWS_PIPE_NAME.as_str(), + ], + } + } +} diff --git a/src/config/server.rs b/src/config/server.rs new file mode 100644 index 0000000..fc2515f --- /dev/null +++ b/src/config/server.rs @@ -0,0 +1,14 @@ +use super::CommonConfig; +use serde::{Deserialize, Serialize}; + +mod listen; +pub use listen::*; + +/// Represents configuration settings for the distant server +#[derive(Debug, Default, Serialize, Deserialize)] +pub struct ServerConfig { + #[serde(flatten)] + pub common: CommonConfig, + + pub listen: ServerListenConfig, +} diff --git a/src/config/server/listen.rs b/src/config/server/listen.rs new file mode 100644 index 0000000..023a64f --- /dev/null +++ b/src/config/server/listen.rs @@ -0,0 +1,149 @@ +use anyhow::Context; +use clap::Args; +use derive_more::Display; +use distant_core::{net::PortRange, Map}; +use serde::{Deserialize, Serialize}; +use std::{ + env, + net::{AddrParseError, IpAddr, Ipv4Addr, Ipv6Addr}, + path::PathBuf, + str::FromStr, +}; + +#[derive(Args, Debug, Default, Serialize, Deserialize)] +pub struct ServerListenConfig { + /// Control the IP address that the distant binds to + /// + /// There are three options here: + /// + /// 1. `ssh`: the server will reply from the IP address that the SSH + /// connection came from (as found in the SSH_CONNECTION environment variable). This is + /// useful for multihomed servers. + /// + /// 2. `any`: the server will reply on the default interface and will not bind to + /// a particular IP address. This can be useful if the connection is made through ssh or + /// another tool that makes the SSH connection appear to come from localhost. + /// + /// 3. `IP`: the server will attempt to bind to the specified IP address. 
+ #[clap(long, value_name = "ssh|any|IP")] + pub host: Option, + + /// Set the port(s) that the server will attempt to bind to + /// + /// This can be in the form of PORT1 or PORT1:PORTN to provide a range of ports. + /// With `--port 0`, the server will let the operating system pick an available TCP port. + /// + /// Please note that this option does not affect the server-side port used by SSH + #[clap(long, value_name = "PORT[:PORT2]")] + pub port: Option, + + /// If specified, will bind to the ipv6 interface if host is "any" instead of ipv4 + #[clap(short = '6', long)] + pub use_ipv6: bool, + + /// The time in seconds before shutting down the server if there are no active + /// connections. The countdown begins once all connections have closed and + /// stops when a new connection is made. In not specified, the server will not + /// shutdown at any point when there are no active connections. + #[clap(long)] + pub shutdown_after: Option, + + /// Changes the current working directory (cwd) to the specified directory + #[clap(long)] + pub current_dir: Option, +} + +impl From for ServerListenConfig { + fn from(mut map: Map) -> Self { + Self { + host: map + .remove("host") + .and_then(|x| x.parse::().ok()), + port: map.remove("port").and_then(|x| x.parse::().ok()), + use_ipv6: map + .remove("use_ipv6") + .and_then(|x| x.parse::().ok()) + .unwrap_or_default(), + shutdown_after: map + .remove("shutdown_after") + .and_then(|x| x.parse::().ok()), + current_dir: map + .remove("current_dir") + .and_then(|x| x.parse::().ok()), + } + } +} + +impl From for Map { + fn from(config: ServerListenConfig) -> Self { + let mut this = Self::new(); + + if let Some(x) = config.host { + this.insert("host".to_string(), x.to_string()); + } + + if let Some(x) = config.port { + this.insert("port".to_string(), x.to_string()); + } + + this.insert("use_ipv6".to_string(), config.use_ipv6.to_string()); + + if let Some(x) = config.shutdown_after { + this.insert("shutdown_after".to_string(), 
x.to_string()); + } + + if let Some(x) = config.current_dir { + this.insert("current_dir".to_string(), x.to_string_lossy().to_string()); + } + + this + } +} + +/// Represents options for binding a server to an IP address +#[derive(Copy, Clone, Debug, Display, PartialEq, Eq, Serialize, Deserialize)] +pub enum BindAddress { + #[display = "ssh"] + Ssh, + #[display = "any"] + Any, + Ip(IpAddr), +} + +impl FromStr for BindAddress { + type Err = AddrParseError; + + fn from_str(s: &str) -> Result { + let s = s.trim(); + Ok(if s.eq_ignore_ascii_case("ssh") { + Self::Ssh + } else if s.eq_ignore_ascii_case("any") { + Self::Any + } else { + s.parse()? + }) + } +} + +impl BindAddress { + /// Resolves address into valid IP; in the case of "any", will leverage the + /// `use_ipv6` flag to determine if binding should use ipv4 or ipv6 + pub fn resolve(self, use_ipv6: bool) -> anyhow::Result { + match self { + Self::Ssh => { + let ssh_connection = + env::var("SSH_CONNECTION").context("Failed to read SSH_CONNECTION")?; + let ip_str = ssh_connection.split(' ').nth(2).ok_or_else(|| { + anyhow::anyhow!("SSH_CONNECTION missing 3rd argument (host ip)") + })?; + let ip = ip_str + .parse::() + .context("Failed to parse IP address")?; + Ok(ip) + } + Self::Any if use_ipv6 => Ok(IpAddr::V6(Ipv6Addr::UNSPECIFIED)), + Self::Any => Ok(IpAddr::V4(Ipv4Addr::UNSPECIFIED)), + Self::Ip(addr) => Ok(addr), + } + } +} diff --git a/src/constants.rs b/src/constants.rs index 1c690e8..efa51af 100644 --- a/src/constants.rs +++ b/src/constants.rs @@ -1,30 +1,5 @@ -use once_cell::sync::Lazy; -use std::{env, path::PathBuf}; - /// Represents the maximum size (in bytes) that data will be read from pipes /// per individual `read` call /// /// Current setting is 16k size pub const MAX_PIPE_CHUNK_SIZE: usize = 16384; - -/// Capacity associated with a server receiving messages from a connection -/// with a client -pub const SERVER_CONN_MSG_CAPACITY: usize = 10000; - -/// Represents maximum time (in milliseconds) to 
wait on a network request -/// before failing (0 meaning indefinitely) -pub const TIMEOUT: usize = 15000; - -pub static TIMEOUT_STR: Lazy = Lazy::new(|| TIMEOUT.to_string()); -pub static SERVER_CONN_MSG_CAPACITY_STR: Lazy = - Lazy::new(|| SERVER_CONN_MSG_CAPACITY.to_string()); - -/// Represents the path to the global session file -pub static SESSION_FILE_PATH: Lazy = Lazy::new(|| env::temp_dir().join("distant.session")); -pub static SESSION_FILE_PATH_STR: Lazy = - Lazy::new(|| SESSION_FILE_PATH.to_string_lossy().to_string()); - -/// Represents the path to a socket to communicate instead of a session file -pub static SESSION_SOCKET_PATH: Lazy = Lazy::new(|| env::temp_dir().join("distant.sock")); -pub static SESSION_SOCKET_PATH_STR: Lazy = - Lazy::new(|| SESSION_SOCKET_PATH.to_string_lossy().to_string()); diff --git a/src/environment.rs b/src/environment.rs deleted file mode 100644 index e59404a..0000000 --- a/src/environment.rs +++ /dev/null @@ -1,142 +0,0 @@ -use distant_core::SessionInfo; -use std::{ffi::OsStr, path::Path}; - -/// Prints out shell-specific environment information -pub fn print_environment(info: &SessionInfo) { - inner_print_environment(&info.host, info.port, &info.key_to_unprotected_string()) -} - -/// Prints out shell-specific environment information -#[cfg(unix)] -fn inner_print_environment(host: &str, port: u16, key: &str) { - match parent_exe_name() { - // If shell is csh or tcsh, we want to print differently - Some(s) if s.eq_ignore_ascii_case("csh") || s.eq_ignore_ascii_case("tcsh") => { - formatter::print_csh_string(host, port, key) - } - - // If shell is fish, we want to print differently - Some(s) if s.eq_ignore_ascii_case("fish") => formatter::print_fish_string(host, port, key), - - // Otherwise, we assume that the shell is compatible with sh (e.g. 
bash, dash, zsh) - _ => formatter::print_sh_string(host, port, key), - } -} - -/// Prints out shell-specific environment information -#[cfg(windows)] -fn inner_print_environment(host: &str, port: u16, key: &str) { - match parent_exe_name() { - // If shell is powershell, we want to print differently - Some(s) if s.eq_ignore_ascii_case("powershell") => { - formatter::print_powershell_string(host, port, key) - } - - // Otherwise, we assume that the shell was cmd.exe - _ => formatter::print_cmd_exe_string(host, port, key), - } -} - -/// Retrieve the name of the parent process that spawned us -fn parent_exe_name() -> Option { - use sysinfo::{Pid, PidExt, Process, ProcessExt, System, SystemExt}; - - let mut system = System::new(); - - // Get our own process pid - let pid = Pid::from_u32(std::process::id()); - - // Update our system's knowledge about our process - system.refresh_process(pid); - - // Get our parent process' pid and update sustem's knowledge about parent process - let maybe_parent_pid = system.process(pid).and_then(Process::parent); - if let Some(pid) = maybe_parent_pid { - system.refresh_process(pid); - } - - maybe_parent_pid - .and_then(|pid| system.process(pid)) - .map(Process::exe) - .and_then(Path::file_name) - .map(OsStr::to_string_lossy) - .map(|s| s.to_string()) -} - -mod formatter { - use indoc::printdoc; - - /// Prints out a {csh,tcsh}-specific example of setting environment variables - #[cfg(unix)] - pub fn print_csh_string(host: &str, port: u16, key: &str) { - printdoc! {r#" - setenv DISTANT_HOST "{host}" - setenv DISTANT_PORT "{port}" - setenv DISTANT_KEY "{key}" - "#, - host = host, - port = port, - key = key, - } - } - - /// Prints out a fish-specific example of setting environment variables - #[cfg(unix)] - pub fn print_fish_string(host: &str, port: u16, key: &str) { - printdoc! 
{r#" - # Please export the following variables to use with actions - set -gx DISTANT_HOST {host} - set -gx DISTANT_PORT {port} - set -gx DISTANT_KEY {key} - "#, - host = host, - port = port, - key = key, - } - } - - /// Prints out an sh-compliant example of setting environment variables - #[cfg(unix)] - pub fn print_sh_string(host: &str, port: u16, key: &str) { - printdoc! {r#" - # Please export the following variables to use with actions - export DISTANT_HOST="{host}" - export DISTANT_PORT="{port}" - export DISTANT_KEY="{key}" - "#, - host = host, - port = port, - key = key, - } - } - - /// Prints out a powershell example of setting environment variables - #[cfg(windows)] - pub fn print_powershell_string(host: &str, port: u16, key: &str) { - printdoc! {r#" - # Please export the following variables to use with actions - $Env:DISTANT_HOST = "{host}" - $Env:DISTANT_PORT = "{port}" - $Env:DISTANT_KEY = "{key}" - "#, - host = host, - port = port, - key = key, - } - } - - /// Prints out a command prompt example of setting environment variables - #[cfg(windows)] - pub fn print_cmd_exe_string(host: &str, port: u16, key: &str) { - printdoc! 
{r#" - REM Please export the following variables to use with actions - SET DISTANT_HOST="{host}" - SET DISTANT_PORT="{port}" - SET DISTANT_KEY="{key}" - "#, - host = host, - port = port, - key = key, - } - } -} diff --git a/src/exit.rs b/src/exit.rs deleted file mode 100644 index 05f4373..0000000 --- a/src/exit.rs +++ /dev/null @@ -1,148 +0,0 @@ -use distant_core::{RemoteProcessError, SessionChannelExtError, TransportError, WatchError}; - -/// Exit codes following https://www.freebsd.org/cgi/man.cgi?query=sysexits&sektion=3 -#[derive(Copy, Clone, PartialEq, Eq, Hash)] -pub enum ExitCode { - /// EX_USAGE (64) - being used when arguments missing or bad arguments provided to CLI - Usage, - - /// EX_DATAERR (65) - being used when bad data received not in UTF-8 format or transport data - /// is bad - DataErr, - - /// EX_NOINPUT (66) - being used when not getting expected data from launch - NoInput, - - /// EX_NOHOST (68) - being used when failed to resolve a host - NoHost, - - /// EX_UNAVAILABLE (69) - being used when IO error encountered where connection is problem - Unavailable, - - /// EX_SOFTWARE (70) - being used for when an action fails as well as for internal errors that - /// can occur like joining a task - Software, - - /// EX_OSERR (71) - being used when fork failed - OsErr, - - /// EX_IOERR (74) - being used as catchall for IO errors - IoError, - - /// EX_TEMPFAIL (75) - being used when we get a timeout - TempFail, - - /// EX_PROTOCOL (76) - being used as catchall for transport errors - Protocol, - - /// Custom exit code to pass back verbatim - Custom(i32), -} - -impl ExitCode { - /// Convert into numeric exit code - pub fn to_i32(self) -> i32 { - match self { - Self::Usage => 64, - Self::DataErr => 65, - Self::NoInput => 66, - Self::NoHost => 68, - Self::Unavailable => 69, - Self::Software => 70, - Self::OsErr => 71, - Self::IoError => 74, - Self::TempFail => 75, - Self::Protocol => 76, - Self::Custom(x) => x, - } - } -} - -impl From for i32 { - fn 
from(code: ExitCode) -> Self { - code.to_i32() - } -} - -/// Represents an error that can be converted into an exit code -pub trait ExitCodeError: std::error::Error { - fn to_exit_code(&self) -> ExitCode; - - /// Indicates if the error message associated with this exit code error - /// should be printed, or if this is just used to reflect the exit code - /// when the process exits - fn is_silent(&self) -> bool { - false - } - - fn to_i32(&self) -> i32 { - self.to_exit_code().to_i32() - } -} - -impl ExitCodeError for std::io::Error { - fn to_exit_code(&self) -> ExitCode { - use std::io::ErrorKind; - match self.kind() { - ErrorKind::ConnectionAborted - | ErrorKind::ConnectionRefused - | ErrorKind::ConnectionReset - | ErrorKind::NotConnected => ExitCode::Unavailable, - ErrorKind::InvalidData => ExitCode::DataErr, - ErrorKind::TimedOut => ExitCode::TempFail, - _ => ExitCode::IoError, - } - } -} - -impl ExitCodeError for TransportError { - fn to_exit_code(&self) -> ExitCode { - match self { - TransportError::IoError(x) => x.to_exit_code(), - _ => ExitCode::Protocol, - } - } -} - -impl ExitCodeError for RemoteProcessError { - fn is_silent(&self) -> bool { - true - } - - fn to_exit_code(&self) -> ExitCode { - match self { - Self::ChannelDead => ExitCode::Unavailable, - Self::TransportError(x) => x.to_exit_code(), - Self::UnexpectedEof => ExitCode::IoError, - Self::WaitFailed(_) => ExitCode::Software, - } - } -} - -impl ExitCodeError for SessionChannelExtError { - fn to_exit_code(&self) -> ExitCode { - match self { - Self::Failure(_) => ExitCode::Software, - Self::TransportError(x) => x.to_exit_code(), - Self::MismatchedResponse => ExitCode::Protocol, - } - } -} - -impl ExitCodeError for WatchError { - fn to_exit_code(&self) -> ExitCode { - match self { - Self::MissingConfirmation => ExitCode::Protocol, - Self::ServerError(_) => ExitCode::Software, - Self::TransportError(x) => x.to_exit_code(), - Self::QueuedChangeDropped => ExitCode::Software, - 
Self::UnexpectedResponse(_) => ExitCode::Protocol, - } - } -} - -impl From for Box { - fn from(x: T) -> Self { - Box::new(x) - } -} diff --git a/src/lib.rs b/src/lib.rs index a43b096..e19ce92 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,69 +1,72 @@ -mod buf; +use derive_more::{Display, Error, From}; +use std::process::{ExitCode, Termination}; + +mod cli; +pub mod config; mod constants; -mod environment; -mod exit; -mod link; -mod msg; -mod opt; -mod output; -mod session; -mod stdin; -mod subcommand; -mod utils; +mod paths; -use log::error; +#[cfg(windows)] +pub mod win_service; -pub use exit::{ExitCode, ExitCodeError}; +pub use self::config::Config; +pub use cli::Cli; -/// Main entrypoint into the program -pub fn run() { - let opt = opt::Opt::load(); - let logger = init_logging(&opt.common, opt.subcommand.is_remote_process()); - if let Err(x) = opt.subcommand.run(opt.common) { - if !x.is_silent() { - error!("Exiting due to error: {}", x); - } - logger.flush(); - logger.shutdown(); +/// Wrapper around a [`CliResult`] that provides [`Termination`] support +pub struct MainResult(CliResult); + +impl MainResult { + pub const OK: MainResult = MainResult(Ok(())); +} - std::process::exit(x.to_i32()); +impl From for MainResult { + fn from(res: CliResult) -> Self { + Self(res) } } -fn init_logging(opt: &opt::CommonOpt, is_remote_process: bool) -> flexi_logger::LoggerHandle { - use flexi_logger::{FileSpec, LevelFilter, LogSpecification, Logger}; - let modules = &["distant", "distant_core"]; +impl From for MainResult { + fn from(x: anyhow::Error) -> Self { + Self(Err(CliError::Error(x))) + } +} - // Disable logging for everything but our binary, which is based on verbosity - let mut builder = LogSpecification::builder(); - builder.default(LevelFilter::Off); +impl From> for MainResult { + fn from(res: anyhow::Result<()>) -> Self { + Self(res.map_err(CliError::Error)) + } +} - // For each module, configure logging - for module in modules { - builder.module(module, 
opt.log_level.to_log_level_filter()); +pub type CliResult = Result<(), CliError>; - // If quiet, we suppress all logging output - // - // NOTE: For a process request, unless logging to a file, we also suppress logging output - // to avoid unexpected results when being treated like a process - // - // Without this, CI tests can sporadically fail when getting the exit code of a - // process because an error log is provided about failing to broadcast a response - // on the client side - if opt.quiet || (is_remote_process && opt.log_file.is_none()) { - builder.module(module, LevelFilter::Off); - } - } +/// Represents an error associated with the CLI +#[derive(Debug, Display, Error, From)] +pub enum CliError { + /// CLI should return a specific error code + Exit(#[error(not(source))] u8), - // Create our logger, but don't initialize yet - let logger = Logger::with(builder.build()).format_for_files(flexi_logger::opt_format); + /// CLI encountered some unexpected error + Error(#[error(not(source))] anyhow::Error), +} - // If provided, log to file instead of stderr - let logger = if let Some(path) = opt.log_file.as_ref() { - logger.log_to_file(FileSpec::try_from(path).expect("Failed to create log file spec")) - } else { - logger - }; +impl CliError { + /// Represents a generic failure with exit code = 1 + pub const FAILURE: CliError = CliError::Exit(1); +} - logger.start().expect("Failed to initialize logger") +impl Termination for MainResult { + fn report(self) -> ExitCode { + match self.0 { + Ok(_) => ExitCode::SUCCESS, + Err(x) => match x { + CliError::Exit(code) => ExitCode::from(code), + CliError::Error(x) => { + eprintln!("{x:?}"); + ::log::error!("{x:?}"); + ::log::logger().flush(); + ExitCode::FAILURE + } + }, + } + } } diff --git a/src/main.rs b/src/main.rs index 7b33e34..9176a18 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,17 +1,39 @@ -//! # distant -//! -//! ### Exit codes -//! -//! 
* EX_USAGE (64) - being used when arguments missing or bad arguments provided to CLI -//! * EX_DATAERR (65) - being used when bad data received not in UTF-8 format or transport data is bad -//! * EX_NOINPUT (66) - being used when not getting expected data from launch -//! * EX_NOHOST (68) - being used when failed to resolve a host -//! * EX_UNAVAILABLE (69) - being used when IO error encountered where connection is problem -//! * EX_OSERR (71) - being used when fork failed -//! * EX_IOERR (74) - being used as catchall for IO errors -//! * EX_TEMPFAIL (75) - being used when we get a timeout -//! * EX_PROTOCOL (76) - being used as catchall for transport errors +use distant::{Cli, MainResult}; -fn main() { - distant::run(); +#[cfg(unix)] +fn main() -> MainResult { + let cli = match Cli::initialize() { + Ok(cli) => cli, + Err(x) => return MainResult::from(x), + }; + let _logger = cli.init_logger(); + MainResult::from(cli.run()) +} + +#[cfg(windows)] +fn main() -> MainResult { + let cli = match Cli::initialize() { + Ok(cli) => cli, + Err(x) => return MainResult::from(x), + }; + let _logger = cli.init_logger(); + + // If we are trying to listen as a manager, try as a service first + if cli.is_manager_listen_command() { + match distant::win_service::run() { + // Success! 
So we don't need to run again + Ok(_) => return MainResult::OK, + + // In this case, we know there was a service error, and we're assuming it + // means that we were trying to dispatch a service when we were not started + // as a service, so we will move forward as a console application + Err(distant::win_service::ServiceError::Service(_)) => (), + + // Otherwise, we got a raw error that we want to return + Err(distant::win_service::ServiceError::Anyhow(x)) => return MainResult::from(x), + } + } + + // Otherwise, execute as a non-service CLI + MainResult::from(cli.run()) } diff --git a/src/opt.rs b/src/opt.rs deleted file mode 100644 index 9537287..0000000 --- a/src/opt.rs +++ /dev/null @@ -1,755 +0,0 @@ -use crate::{ - constants::{ - SERVER_CONN_MSG_CAPACITY_STR, SESSION_FILE_PATH_STR, SESSION_SOCKET_PATH_STR, TIMEOUT_STR, - }, - exit::ExitCodeError, - subcommand, -}; -use derive_more::{Display, Error, From, IsVariant}; -use distant_core::{PortRange, RequestData}; -use once_cell::sync::Lazy; -use std::{ - env, - net::{AddrParseError, IpAddr, Ipv4Addr, Ipv6Addr}, - path::PathBuf, - str::FromStr, - time::Duration, -}; -use structopt::StructOpt; -use strum::{EnumString, EnumVariantNames, IntoStaticStr, VariantNames}; - -static USERNAME: Lazy = Lazy::new(whoami::username); - -/// Options and commands to apply to binary -#[derive(Clone, Debug, StructOpt)] -#[structopt(name = "distant")] -pub struct Opt { - #[structopt(flatten)] - pub common: CommonOpt, - - #[structopt(subcommand)] - pub subcommand: Subcommand, -} - -impl Opt { - /// Loads options from CLI arguments - pub fn load() -> Self { - Self::from_args() - } -} - -#[derive( - Copy, - Clone, - Debug, - Display, - PartialEq, - Eq, - IsVariant, - IntoStaticStr, - EnumString, - EnumVariantNames, -)] -#[strum(serialize_all = "snake_case")] -pub enum LogLevel { - Off, - Error, - Warn, - Info, - Debug, - Trace, -} - -impl LogLevel { - pub fn to_log_level_filter(self) -> log::LevelFilter { - match self { - Self::Off => 
log::LevelFilter::Off, - Self::Error => log::LevelFilter::Error, - Self::Warn => log::LevelFilter::Warn, - Self::Info => log::LevelFilter::Info, - Self::Debug => log::LevelFilter::Debug, - Self::Trace => log::LevelFilter::Trace, - } - } -} - -/// Contains options that are common across subcommands -#[derive(Clone, Debug, StructOpt)] -pub struct CommonOpt { - /// Quiet mode, suppresses all logging (shortcut for log level off) - #[structopt(short, long, global = true)] - pub quiet: bool, - - /// Log level to use throughout the application - #[structopt( - long, - global = true, - case_insensitive = true, - default_value = LogLevel::Info.into(), - possible_values = LogLevel::VARIANTS - )] - pub log_level: LogLevel, - - /// Log output to disk instead of stderr - #[structopt(long, global = true)] - pub log_file: Option, - - /// Represents the maximum time (in seconds) to wait for a network - /// request before timing out; a timeout of 0 implies waiting indefinitely - #[structopt(short, long, global = true, default_value = &TIMEOUT_STR)] - pub timeout: f32, -} - -impl CommonOpt { - /// Creates a new duration representing the timeout in seconds - pub fn to_timeout_duration(&self) -> Duration { - Duration::from_secs_f32(self.timeout) - } -} - -/// Contains options related sessions -#[derive(Clone, Debug, StructOpt)] -pub struct SessionOpt { - /// Represents the location of the file containing session information, - /// only useful when the session is set to "file" - #[structopt(long, default_value = &SESSION_FILE_PATH_STR)] - pub session_file: PathBuf, - - /// Represents the location of the session's socket to communicate across, - /// only useful when the session is set to "socket" - #[structopt(long, default_value = &SESSION_SOCKET_PATH_STR)] - pub session_socket: PathBuf, -} - -/// Contains options related ssh -#[derive(Clone, Debug, StructOpt)] -pub struct SshConnectionOpts { - /// Host to use for connection to when using SSH method - #[structopt(name = "ssh-host", 
long, default_value = "localhost")] - pub host: String, - - /// Port to use for connection when using SSH method - #[structopt(name = "ssh-port", long, default_value = "22")] - pub port: u16, - - /// Alternative user for connection when using SSH method - #[structopt(name = "ssh-user", long)] - pub user: Option, -} - -#[derive(Clone, Debug, StructOpt)] -pub enum Subcommand { - /// Performs some action on a remote machine - Action(ActionSubcommand), - - /// Launches the server-portion of the binary on a remote machine - Launch(LaunchSubcommand), - - /// Begins listening for incoming requests - Listen(ListenSubcommand), - - /// Specialized treatment of running a remote LSP process - Lsp(LspSubcommand), - - /// Specialized treatment of running a remote shell process - Shell(ShellSubcommand), -} - -impl Subcommand { - /// Runs the subcommand, returning the result - pub fn run(self, opt: CommonOpt) -> Result<(), Box> { - match self { - Self::Action(cmd) => subcommand::action::run(cmd, opt)?, - Self::Launch(cmd) => subcommand::launch::run(cmd, opt)?, - Self::Listen(cmd) => subcommand::listen::run(cmd, opt)?, - Self::Lsp(cmd) => subcommand::lsp::run(cmd, opt)?, - Self::Shell(cmd) => subcommand::shell::run(cmd, opt)?, - } - - Ok(()) - } - - /// Returns true if subcommand simplifies to acting as a proxy for a remote process - pub fn is_remote_process(&self) -> bool { - match self { - Self::Action(cmd) => cmd - .operation - .as_ref() - .map(|req| req.is_proc_spawn()) - .unwrap_or_default(), - Self::Lsp(_) => true, - _ => false, - } - } -} - -/// Represents the method to use in communicating with a remote machine -#[derive( - Copy, - Clone, - Debug, - Display, - PartialEq, - Eq, - IsVariant, - IntoStaticStr, - EnumString, - EnumVariantNames, -)] -#[strum(serialize_all = "snake_case")] -pub enum Method { - /// Launch/connect to a distant server running on a remote machine - Distant, - - /// Connect to an SSH server running on a remote machine - #[cfg(any(feature = "libssh", 
feature = "ssh2"))] - Ssh, -} - -impl Default for Method { - fn default() -> Self { - Self::Distant - } -} - -/// Represents the format for data communicated to & from the client -#[derive( - Copy, - Clone, - Debug, - Display, - PartialEq, - Eq, - IsVariant, - IntoStaticStr, - EnumString, - EnumVariantNames, -)] -#[strum(serialize_all = "snake_case")] -pub enum Format { - /// Sends and receives data in JSON format - Json, - - /// Commands are traditional shell commands and output responses are - /// inline with what is expected of a program's output in a shell - Shell, -} - -/// Represents subcommand to execute some operation remotely -#[derive(Clone, Debug, StructOpt)] -#[structopt(verbatim_doc_comment)] -pub struct ActionSubcommand { - /// Represents the format that results should be returned - /// - /// Currently, there are two possible formats: - /// - /// 1. "json": printing out JSON for external program usage - /// - /// 2. "shell": printing out human-readable results for interactive shell usage - #[structopt( - short, - long, - case_insensitive = true, - default_value = Format::Shell.into(), - possible_values = Format::VARIANTS - )] - pub format: Format, - - /// Method to communicate with a remote machine - #[structopt( - short, - long, - case_insensitive = true, - default_value = Method::default().into(), - possible_values = Method::VARIANTS - )] - pub method: Method, - - /// Represents the medium for retrieving a session for use in performing the action - #[structopt( - long, - case_insensitive = true, - default_value = SessionInput::default().into(), - possible_values = SessionInput::VARIANTS - )] - pub session: SessionInput, - - /// Contains additional information related to sessions - #[structopt(flatten)] - pub session_data: SessionOpt, - - /// SSH connection settings when method is ssh - #[structopt(flatten)] - pub ssh_connection: SshConnectionOpts, - - /// If specified, commands to send are sent over stdin and responses are received - /// over stdout 
(and stderr if mode is shell) - #[structopt(short, long)] - pub interactive: bool, - - /// Operation to send over the wire if not in interactive mode - #[structopt(subcommand)] - pub operation: Option, -} - -/// Represents options for binding a server to an IP address -#[derive(Copy, Clone, Debug, Display, PartialEq, Eq, IsVariant)] -pub enum BindAddress { - #[display(fmt = "ssh")] - Ssh, - #[display(fmt = "any")] - Any, - Ip(IpAddr), -} - -#[derive(Clone, Debug, Display, From, Error, PartialEq, Eq)] -pub enum ConvertToIpAddrError { - AddrParseError(AddrParseError), - #[display(fmt = "SSH_CONNECTION missing 3rd argument (host ip)")] - MissingSshAddr, - VarError(env::VarError), -} - -impl BindAddress { - /// Converts address into valid IP; in the case of "any", will leverage the - /// `use_ipv6` flag to determine if binding should use ipv4 or ipv6 - pub fn to_ip_addr(self, use_ipv6: bool) -> Result { - match self { - Self::Ssh => { - let ssh_connection = env::var("SSH_CONNECTION")?; - let ip_str = ssh_connection - .split(' ') - .nth(2) - .ok_or(ConvertToIpAddrError::MissingSshAddr)?; - let ip = ip_str.parse::()?; - Ok(ip) - } - Self::Any if use_ipv6 => Ok(IpAddr::V6(Ipv6Addr::UNSPECIFIED)), - Self::Any => Ok(IpAddr::V4(Ipv4Addr::UNSPECIFIED)), - Self::Ip(addr) => Ok(addr), - } - } -} - -impl FromStr for BindAddress { - type Err = AddrParseError; - - fn from_str(s: &str) -> Result { - match s.trim() { - "ssh" => Ok(Self::Ssh), - "any" => Ok(Self::Any), - "localhost" => Ok(Self::Ip(IpAddr::V4(Ipv4Addr::LOCALHOST))), - x => Ok(Self::Ip(x.parse::()?)), - } - } -} - -/// Represents the means by which to share the session from launching on a remote machine -#[derive( - Copy, - Clone, - Debug, - Display, - PartialEq, - Eq, - IntoStaticStr, - IsVariant, - EnumString, - EnumVariantNames, -)] -#[strum(serialize_all = "snake_case")] -pub enum SessionOutput { - /// Session will be exposed as a series of environment variables - /// - /// * `DISTANT_HOST=` - /// * `DISTANT_PORT=` 
- /// * `DISTANT_KEY=` - /// - /// Note that this does not actually create the environment variables, - /// but instead prints out a message detailing how to set the environment - /// variables, which can be evaluated to set them - Environment, - - /// Session is in a file in the form of `DISTANT CONNECT ` - File, - - /// Special scenario where the session is not shared but is instead kept within the - /// launch program, causing the program itself to listen on stdin for input rather - /// than terminating - Keep, - - /// Session is stored and retrieved over anonymous pipes (stdout/stdin) - /// in form of `DISTANT CONNECT ` - Pipe, - - /// Special scenario where the session is not shared but is instead kept within the - /// launch program, where the program now listens on a unix socket for input - #[cfg(unix)] - Socket, -} - -impl Default for SessionOutput { - /// Default to environment output - fn default() -> Self { - Self::Environment - } -} - -/// Represents the means by which to consume a session when performing an action -#[derive( - Copy, - Clone, - Debug, - Display, - PartialEq, - Eq, - IntoStaticStr, - IsVariant, - EnumString, - EnumVariantNames, -)] -#[strum(serialize_all = "snake_case")] -pub enum SessionInput { - /// Session is in a environment variables - /// - /// * `DISTANT_HOST=` - /// * `DISTANT_PORT=` - /// * `DISTANT_KEY=` - Environment, - - /// Session is in a file in the form of `DISTANT CONNECT ` - File, - - /// Session is stored and retrieved over anonymous pipes (stdout/stdin) - /// in form of `DISTANT CONNECT ` - Pipe, - - /// Session is stored and retrieved from the initializeOptions of the initialize request - /// that is first sent for an LSP through - Lsp, - - /// Session isn't directly available but instead there is a process listening - /// on a unix socket that will forward requests and responses - #[cfg(unix)] - Socket, -} - -impl Default for SessionInput { - /// Default to environment output - fn default() -> Self { - 
Self::Environment - } -} - -/// Represents subcommand to launch a remote server -#[derive(Clone, Debug, StructOpt)] -pub struct LaunchSubcommand { - /// Represents the medium for sharing the session upon launching on a remote machine - #[structopt( - long, - default_value = SessionOutput::default().into(), - possible_values = SessionOutput::VARIANTS - )] - pub session: SessionOutput, - - /// Contains additional information related to sessions - #[structopt(flatten)] - pub session_data: SessionOpt, - - /// If specified, launch will fail when attempting to bind to a unix socket that - /// already exists, rather than removing the old socket - #[structopt(long)] - pub fail_if_socket_exists: bool, - - /// The time in seconds before shutting down the server if there are no active - /// connections. The countdown begins once all connections have closed and - /// stops when a new connection is made. In not specified, the server will not - /// shutdown at any point when there are no active connections. - /// - /// In the case of launch, this is only applicable when it is set to socket session - /// as this controls when the unix socket listener would shutdown, not when the - /// remote server it is connected to will shutdown. 
- /// - /// To configure the remote server's shutdown time, provide it as an argument - /// via `--extra-server-args` - #[structopt(long)] - pub shutdown_after: Option, - - /// When session is socket, runs in foreground instead of spawning a background process - #[structopt(long)] - pub foreground: bool, - - /// Represents the format that results should be returned when session is "keep", - /// causing the launcher to enter an interactive loop to handle input and output - /// itself rather than enabling other clients to connect - /// - /// Additionally, for all session types, dictates how authentication questions - /// and answers should be communicated (over shell, or using json if ssh2 feature enabled) - #[structopt( - short, - long, - case_insensitive = true, - default_value = Format::Shell.into(), - possible_values = Format::VARIANTS - )] - pub format: Format, - - /// Path to distant program on remote machine to execute via ssh; - /// by default, this program needs to be available within PATH as - /// specified when compiling ssh (not your login shell) - #[structopt(long, default_value = "distant")] - pub distant: String, - - /// Path to ssh program on local machine to execute when using external ssh - #[structopt(long, default_value = "ssh")] - pub ssh: String, - - /// If using native ssh integration, represents the backend - #[cfg(any(feature = "libssh", feature = "ssh2"))] - #[structopt(long, default_value = distant_ssh2::SshBackend::default().as_static_str())] - pub ssh_backend: distant_ssh2::SshBackend, - - /// If specified, will use the external ssh program to launch the server - /// instead of the native integration; does nothing if the ssh2 feature is - /// not enabled as there is no other option than external ssh - #[structopt(long)] - pub external_ssh: bool, - - /// Control the IP address that the server binds to. 
- /// - /// The default is `ssh', in which case the server will reply from the IP address that the SSH - /// connection came from (as found in the SSH_CONNECTION environment variable). This is - /// useful for multihomed servers. - /// - /// With --bind-server=any, the server will reply on the default interface and will not bind to - /// a particular IP address. This can be useful if the connection is made through sslh or - /// another tool that makes the SSH connection appear to come from localhost. - /// - /// With --bind-server=IP, the server will attempt to bind to the specified IP address. - #[structopt(long, value_name = "ssh|any|IP", default_value = "ssh")] - pub bind_server: BindAddress, - - /// Additional arguments to provide to the server - #[structopt(long, allow_hyphen_values(true))] - pub extra_server_args: Option, - - /// Username to use when sshing into remote machine - #[structopt(short, long, default_value = &USERNAME)] - pub username: String, - - /// Explicit identity file to use with ssh - #[structopt(short, long)] - pub identity_file: Option, - - /// If specified, will not launch distant using a login shell but instead execute it directly - #[structopt(long)] - pub no_shell: bool, - - /// Port to use for sshing into the remote machine - #[structopt(short, long, default_value = "22")] - pub port: u16, - - /// Host to use for sshing into the remote machine - #[structopt(name = "HOST")] - pub host: String, -} - -impl LaunchSubcommand { - /// Creates a new duration representing the shutdown period in seconds - pub fn to_shutdown_after_duration(&self) -> Option { - self.shutdown_after - .as_ref() - .copied() - .map(Duration::from_secs_f32) - } -} - -/// Represents subcommand to operate in listen mode for incoming requests -#[derive(Clone, Debug, StructOpt)] -pub struct ListenSubcommand { - /// Runs in foreground instead of spawning a background process - #[structopt(long)] - pub foreground: bool, - - /// Control the IP address that the distant binds 
to - /// - /// There are three options here: - /// - /// 1. `ssh`: the server will reply from the IP address that the SSH - /// connection came from (as found in the SSH_CONNECTION environment variable). This is - /// useful for multihomed servers. - /// - /// 2. `any`: the server will reply on the default interface and will not bind to - /// a particular IP address. This can be useful if the connection is made through sslh or - /// another tool that makes the SSH connection appear to come from localhost. - /// - /// 3. `IP`: the server will attempt to bind to the specified IP address. - #[structopt(short, long, value_name = "ssh|any|IP", default_value = "localhost")] - pub host: BindAddress, - - /// If specified, will bind to the ipv6 interface if host is "any" instead of ipv4 - #[structopt(short = "6", long)] - pub use_ipv6: bool, - - /// Maximum capacity for concurrent message handled by the server - #[structopt(long, default_value = &SERVER_CONN_MSG_CAPACITY_STR)] - pub max_msg_capacity: u16, - - /// If specified, the server will not generate a key but instead listen on stdin for the next - /// 32 bytes that it will use as the key instead. Receiving less than 32 bytes before stdin - /// is closed is considered an error and any bytes after the first 32 are not used for the key - #[structopt(long)] - pub key_from_stdin: bool, - - /// The time in seconds before shutting down the server if there are no active - /// connections. The countdown begins once all connections have closed and - /// stops when a new connection is made. In not specified, the server will not - /// shutdown at any point when there are no active connections. - #[structopt(long)] - pub shutdown_after: Option, - - /// Changes the current working directory (cwd) to the specified directory - #[structopt(long)] - pub current_dir: Option, - - /// Set the port(s) that the server will attempt to bind to - /// - /// This can be in the form of PORT1 or PORT1:PORTN to provide a range of ports. 
- /// With -p 0, the server will let the operating system pick an available TCP port. - /// - /// Please note that this option does not affect the server-side port used by SSH - #[structopt(short, long, value_name = "PORT[:PORT2]", default_value = "8080:8099")] - pub port: PortRange, -} - -impl ListenSubcommand { - /// Creates a new duration representing the shutdown period in seconds - pub fn to_shutdown_after_duration(&self) -> Option { - self.shutdown_after - .as_ref() - .copied() - .map(Duration::from_secs_f32) - } -} - -/// Represents subcommand to execute some LSP server on a remote machine -#[derive(Clone, Debug, StructOpt)] -#[structopt(verbatim_doc_comment)] -pub struct LspSubcommand { - /// Represents the format that results should be returned - /// - /// Currently, there are two possible formats: - /// - /// 1. "json": printing out JSON for external program usage - /// - /// 2. "shell": printing out human-readable results for interactive shell usage - #[structopt( - short, - long, - case_insensitive = true, - default_value = Format::Shell.into(), - possible_values = Format::VARIANTS - )] - pub format: Format, - - /// Method to communicate with a remote machine - #[structopt( - short, - long, - case_insensitive = true, - default_value = Method::default().into(), - possible_values = Method::VARIANTS - )] - pub method: Method, - - /// Represents the medium for retrieving a session to use when running a remote LSP server - #[structopt( - long, - case_insensitive = true, - default_value = SessionInput::default().into(), - possible_values = SessionInput::VARIANTS - )] - pub session: SessionInput, - - /// Contains additional information related to sessions - #[structopt(flatten)] - pub session_data: SessionOpt, - - /// SSH connection settings when method is ssh - #[structopt(flatten)] - pub ssh_connection: SshConnectionOpts, - - /// If provided, will run in persist mode, meaning that the process will not be killed if the - /// client disconnects from the server 
- #[structopt(long)] - pub persist: bool, - - /// If provided, will run LSP in a pty - #[structopt(long)] - pub pty: bool, - - /// Command to run on the remote machine that represents an LSP server - pub cmd: String, - - /// Additional arguments to supply to the remote machine - pub args: Vec, -} - -/// Represents subcommand to execute some shell on a remote machine -#[derive(Clone, Debug, StructOpt)] -#[structopt(verbatim_doc_comment)] -pub struct ShellSubcommand { - /// Represents the format that results should be returned - /// - /// Currently, there are two possible formats: - /// - /// 1. "json": printing out JSON for external program usage - /// - /// 2. "shell": printing out human-readable results for interactive shell usage - #[structopt( - short, - long, - case_insensitive = true, - default_value = Format::Shell.into(), - possible_values = Format::VARIANTS - )] - pub format: Format, - - /// Method to communicate with a remote machine - #[structopt( - short, - long, - case_insensitive = true, - default_value = Method::default().into(), - possible_values = Method::VARIANTS - )] - pub method: Method, - - /// Represents the medium for retrieving a session to use when running a remote LSP server - #[structopt( - long, - case_insensitive = true, - default_value = SessionInput::default().into(), - possible_values = SessionInput::VARIANTS - )] - pub session: SessionInput, - - /// Contains additional information related to sessions - #[structopt(flatten)] - pub session_data: SessionOpt, - - /// SSH connection settings when method is ssh - #[structopt(flatten)] - pub ssh_connection: SshConnectionOpts, - - /// If provided, will run in persist mode, meaning that the process will not be killed if the - /// client disconnects from the server - #[structopt(long)] - pub persist: bool, - - /// Command to run on the remote machine as the shell (defaults to $TERM) - pub cmd: Option, - - /// Additional arguments to supply to the shell (defaults to nothing) - pub args: Vec, -} 
diff --git a/src/paths.rs b/src/paths.rs new file mode 100644 index 0000000..7a4e0b5 --- /dev/null +++ b/src/paths.rs @@ -0,0 +1,112 @@ +use directories::ProjectDirs; +use once_cell::sync::Lazy; +use std::path::PathBuf; + +#[cfg(unix)] +const SOCKET_FILE_STR: &str = "distant.sock"; + +/// User-oriented paths +pub mod user { + use super::*; + + /// Root project directory used to calculate other paths + static PROJECT_DIR: Lazy = Lazy::new(|| { + ProjectDirs::from("", "", "distant").expect("Could not determine valid $HOME path") + }); + + /// Path to configuration settings for distant client/manager/server + pub static CONFIG_FILE_PATH: Lazy = + Lazy::new(|| PROJECT_DIR.config_dir().join("config.toml")); + + /// Path to cache file used for arbitrary CLI data + pub static CACHE_FILE_PATH: Lazy = + Lazy::new(|| PROJECT_DIR.cache_dir().join("cache.toml")); + + pub static CACHE_FILE_PATH_STR: Lazy = + Lazy::new(|| CACHE_FILE_PATH.to_string_lossy().to_string()); + + /// Path to log file for distant client + pub static CLIENT_LOG_FILE_PATH: Lazy = + Lazy::new(|| PROJECT_DIR.cache_dir().join("client.log")); + + /// Path to log file for distant manager + pub static MANAGER_LOG_FILE_PATH: Lazy = + Lazy::new(|| PROJECT_DIR.cache_dir().join("manager.log")); + + /// Path to log file for distant server + pub static SERVER_LOG_FILE_PATH: Lazy = + Lazy::new(|| PROJECT_DIR.cache_dir().join("server.log")); + + /// Path to log file for distant generate + pub static GENERATE_LOG_FILE_PATH: Lazy = + Lazy::new(|| PROJECT_DIR.cache_dir().join("generate.log")); + + /// For Linux & BSD, this uses the runtime path. 
For Mac, this uses the tmp path + /// + /// * `/run/user/1001/distant/{user}.distant.sock` on Linux + /// * `/var/run/{user}.distant.sock` on BSD + /// * `/tmp/{user}.distant.dock` on MacOS + #[cfg(unix)] + pub static UNIX_SOCKET_PATH: Lazy = Lazy::new(|| { + // Form of {user}.distant.sock + let mut file_name = whoami::username_os(); + file_name.push("."); + file_name.push(SOCKET_FILE_STR); + + PROJECT_DIR + .runtime_dir() + .map(std::path::Path::to_path_buf) + .unwrap_or_else(std::env::temp_dir) + .join(file_name) + }); + + /// Name of the pipe used by Windows in the form of `{user}.distant` + #[cfg(windows)] + pub static WINDOWS_PIPE_NAME: Lazy = + Lazy::new(|| format!("{}.distant", whoami::username())); +} + +/// Global paths +pub mod global { + use super::*; + + /// Windows ProgramData directory from from the %ProgramData% environment variable + #[cfg(windows)] + static PROGRAM_DATA_DIR: Lazy = Lazy::new(|| { + PathBuf::from(std::env::var("ProgramData").expect("Could not determine %ProgramData%")) + }); + + #[cfg(windows)] + static CONFIG_DIR: Lazy = Lazy::new(|| PROGRAM_DATA_DIR.join("distant")); + + #[cfg(unix)] + static CONFIG_DIR: Lazy = Lazy::new(|| PathBuf::from("/etc").join("distant")); + + /// Path to configuration settings for distant client/manager/server + pub static CONFIG_FILE_PATH: Lazy = Lazy::new(|| CONFIG_DIR.join("config.toml")); + + /// For Linux & BSD, this uses the runtime path. 
For Mac, this uses the tmp path + /// + /// * `/run/distant.sock` on Linux + /// * `/var/run/distant.sock` on BSD + /// * `/tmp/distant.dock` on MacOS + #[cfg(unix)] + pub static UNIX_SOCKET_PATH: Lazy = Lazy::new(|| { + if cfg!(target_os = "macos") { + std::env::temp_dir().join(SOCKET_FILE_STR) + } else if cfg!(any( + target_os = "freebsd", + target_os = "dragonfly", + target_os = "openbsd", + target_os = "netbsd" + )) { + PathBuf::from("/var").join("run").join(SOCKET_FILE_STR) + } else { + PathBuf::from("/run").join(SOCKET_FILE_STR) + } + }); + + /// Name of the pipe used by Windows + #[cfg(windows)] + pub static WINDOWS_PIPE_NAME: Lazy = Lazy::new(|| "distant".to_string()); +} diff --git a/src/session.rs b/src/session.rs deleted file mode 100644 index 3d6de63..0000000 --- a/src/session.rs +++ /dev/null @@ -1,163 +0,0 @@ -use crate::{ - buf::StringBuf, constants::MAX_PIPE_CHUNK_SIZE, opt::Format, output::ResponseOut, stdin, -}; -use distant_core::{Mailbox, Request, RequestData, Session}; -use log::*; -use std::io; -use structopt::StructOpt; -use tokio::{ - sync::{mpsc, oneshot}, - task::JoinHandle, -}; - -/// Represents a wrapper around a session that provides CLI functionality such as reading from -/// stdin and piping results back out to stdout -pub struct CliSession { - req_task: JoinHandle<()>, -} - -impl CliSession { - /// Creates a new instance of a session for use in CLI interactions being fed input using - /// the program's stdin - pub fn new_for_stdin(tenant: String, session: Session, format: Format) -> Self { - let (_stdin_thread, stdin_rx) = stdin::spawn_channel(MAX_PIPE_CHUNK_SIZE); - - Self::new(tenant, session, format, stdin_rx) - } - - /// Creates a new instance of a session for use in CLI interactions being fed input using - /// the provided receiver - pub fn new( - tenant: String, - session: Session, - format: Format, - stdin_rx: mpsc::Receiver>, - ) -> Self { - let map_line = move |line: &str| match format { - Format::Json => 
serde_json::from_str(line) - .map_err(|x| io::Error::new(io::ErrorKind::InvalidInput, x)), - Format::Shell => { - let data = RequestData::from_iter_safe( - std::iter::once("distant") - .chain(line.trim().split(' ').filter(|s| !s.trim().is_empty())), - ) - .map_err(|x| io::Error::new(io::ErrorKind::InvalidInput, x)); - - data.map(|x| Request::new(tenant.to_string(), vec![x])) - } - }; - let req_task = tokio::spawn(async move { - process_outgoing_requests(session, stdin_rx, format, map_line).await - }); - - Self { req_task } - } - - /// Wait for the cli session to terminate - pub async fn wait(self) -> io::Result<()> { - match self.req_task.await { - Ok(res) => Ok(res), - Err(x) => Err(io::Error::new(io::ErrorKind::BrokenPipe, x)), - } - } -} - -/// Helper function that loops, processing incoming responses to a mailbox -async fn process_mailbox(mut mailbox: Mailbox, format: Format, exit: oneshot::Receiver<()>) { - let inner = async move { - while let Some(res) = mailbox.next().await { - match ResponseOut::new(format, res) { - Ok(out) => out.print(), - Err(x) => { - error!("Repsonse out failed: {}", x); - break; - } - } - } - }; - - tokio::select! { - _ = inner => {} - _ = exit => {} - } -} - -/// Helper function that loops, processing outgoing requests created from stdin, and printing out -/// responses -async fn process_outgoing_requests( - mut session: Session, - mut stdin_rx: mpsc::Receiver>, - format: Format, - map_line: F, -) where - F: Fn(&str) -> io::Result, -{ - let mut buf = StringBuf::new(); - let mut mailbox_exits = Vec::new(); - - while let Some(data) = stdin_rx.recv().await { - // TODO: Should we support raw bytes? 
If so, we need to rewrite map_line to take Vec - let data = match String::from_utf8(data) { - Ok(data) => data, - Err(x) => { - error!("Bad stdin: {}", x); - continue; - } - }; - - // Update our buffer with the new data and split it into concrete lines and remainder - buf.push_str(&data); - let (lines, new_buf) = buf.into_full_lines(); - buf = new_buf; - - // For each complete line, parse into a request - if let Some(lines) = lines { - for line in lines.lines().map(|line| line.trim()) { - trace!("Processing line: {:?}", line); - if line.is_empty() { - continue; - } - - // TODO: We need to consolidate MsgReceiver and this logic as this only - // allows messages sent completely on a single line rather than - // MsgReceiver's ability to get a multi-line message - match map_line(line) { - Ok(req) => match session.mail(req).await { - Ok(mut mailbox) => { - // Wait to get our first response before moving on to the next line - // of input - if let Some(res) = mailbox.next().await { - // Convert to response to output, and when successful launch - // a handler for continued responses to the same request - // such as with processes - match ResponseOut::new(format, res) { - Ok(out) => { - out.print(); - - let (tx, rx) = oneshot::channel(); - mailbox_exits.push(tx); - tokio::spawn(process_mailbox(mailbox, format, rx)); - } - Err(x) => { - error!("Map line response out failed: {}", x); - } - } - } - } - Err(x) => { - error!("Failed to send request: {}", x) - } - }, - Err(x) => { - error!("Failed to parse line: {}", x); - } - } - } - } - } - - // Close out any dangling mailbox handlers - for tx in mailbox_exits { - let _ = tx.send(()); - } -} diff --git a/src/subcommand/action.rs b/src/subcommand/action.rs deleted file mode 100644 index 54b3e0b..0000000 --- a/src/subcommand/action.rs +++ /dev/null @@ -1,218 +0,0 @@ -use crate::{ - exit::{ExitCode, ExitCodeError}, - link::RemoteProcessLink, - opt::{ActionSubcommand, CommonOpt, Format}, - output::ResponseOut, - 
session::CliSession, - subcommand::CommandRunner, - utils, -}; -use derive_more::{Display, Error, From}; -use distant_core::{ - ChangeKindSet, LspData, RemoteProcess, RemoteProcessError, Request, RequestData, Response, - ResponseData, Session, TransportError, WatchError, Watcher, -}; -use tokio::{io, time::Duration}; - -#[derive(Debug, Display, Error, From)] -pub enum Error { - #[display(fmt = "Process failed with exit code: {}", _0)] - BadProcessExit(#[error(not(source))] i32), - Io(io::Error), - #[display(fmt = "Non-interactive but no operation supplied")] - MissingOperation, - OperationFailed, - RemoteProcess(RemoteProcessError), - Transport(TransportError), - Watch(WatchError), -} - -impl ExitCodeError for Error { - fn is_silent(&self) -> bool { - match self { - Self::BadProcessExit(_) | Self::OperationFailed => true, - Self::RemoteProcess(x) => x.is_silent(), - _ => false, - } - } - - fn to_exit_code(&self) -> ExitCode { - match self { - Self::BadProcessExit(x) => ExitCode::Custom(*x), - Self::Io(x) => x.to_exit_code(), - Self::MissingOperation => ExitCode::Usage, - Self::OperationFailed => ExitCode::Software, - Self::RemoteProcess(x) => x.to_exit_code(), - Self::Transport(x) => x.to_exit_code(), - Self::Watch(x) => x.to_exit_code(), - } - } -} - -pub fn run(cmd: ActionSubcommand, opt: CommonOpt) -> Result<(), Error> { - let rt = tokio::runtime::Runtime::new()?; - - rt.block_on(async { run_async(cmd, opt).await }) -} - -async fn run_async(cmd: ActionSubcommand, opt: CommonOpt) -> Result<(), Error> { - let method = cmd.method; - let ssh_connection = cmd.ssh_connection.clone(); - let session_input = cmd.session; - let timeout = opt.to_timeout_duration(); - let session_file = cmd.session_data.session_file.clone(); - let session_socket = cmd.session_data.session_socket.clone(); - - CommandRunner { - method, - ssh_connection, - session_input, - session_file, - session_socket, - timeout, - } - .run( - |session, timeout, lsp_data| Box::pin(start(cmd, session, 
timeout, lsp_data)), - Error::Io, - ) - .await -} - -async fn start( - cmd: ActionSubcommand, - mut session: Session, - timeout: Duration, - lsp_data: Option, -) -> Result<(), Error> { - let is_shell_format = matches!(cmd.format, Format::Shell); - - match (cmd.interactive, cmd.operation) { - // Watch request w/ shell format is specially handled and we ignore interactive as - // watch will run and wait - ( - _, - Some(RequestData::Watch { - path, - recursive, - only, - except, - }), - ) if is_shell_format => { - let mut watcher = Watcher::watch( - utils::new_tenant(), - session.into_channel(), - path, - recursive, - only.into_iter().collect::(), - except.into_iter().collect::(), - ) - .await?; - - // Continue to receive and process changes - while let Some(change) = watcher.next().await { - // TODO: Provide a cleaner way to print just a change - let res = Response::new("", 0, vec![ResponseData::Changed(change)]); - ResponseOut::new(cmd.format, res)?.print() - } - - Ok(()) - } - - // ProcSpawn request w/ shell format is specially handled and we ignore interactive as - // the stdin will be used for sending ProcStdin to remote process - ( - _, - Some(RequestData::ProcSpawn { - cmd, - args, - persist, - pty, - }), - ) if is_shell_format => { - let mut proc = RemoteProcess::spawn( - utils::new_tenant(), - session.clone_channel(), - cmd, - args, - persist, - pty, - ) - .await?; - - // If we also parsed an LSP's initialize request for its session, we want to forward - // it along in the case of a process call - if let Some(data) = lsp_data { - proc.stdin.as_mut().unwrap().write(data.to_string()).await?; - } - - // Now, map the remote process' stdin/stdout/stderr to our own process - let link = RemoteProcessLink::from_remote_pipes( - proc.stdin.take(), - proc.stdout.take().unwrap(), - proc.stderr.take().unwrap(), - ); - - // Drop main session as the singular remote process will now manage stdin/stdout/stderr - // NOTE: Without this, severing stdin when from this side would 
not occur as we would - // continue to maintain a second reference to the remote connection's input - // through the primary session - drop(session); - - let (success, exit_code) = proc.wait().await?; - - // Shut down our link - link.shutdown().await; - - if !success { - if let Some(code) = exit_code { - return Err(Error::BadProcessExit(code)); - } else { - return Err(Error::BadProcessExit(1)); - } - } - - Ok(()) - } - - // All other requests without interactive are oneoffs - (false, Some(data)) => { - let res = session - .send_timeout(Request::new(utils::new_tenant(), vec![data]), timeout) - .await?; - - // If we have an error as our payload, then we want to reflect that in our - // exit code - let is_err = res.payload.iter().any(|d| d.is_error()); - - ResponseOut::new(cmd.format, res)?.print(); - - if is_err { - Err(Error::OperationFailed) - } else { - Ok(()) - } - } - - // Interactive mode will send an optional first request and then continue - // to read stdin to send more - (true, maybe_req) => { - // Send our first request if provided - if let Some(data) = maybe_req { - let res = session - .send_timeout(Request::new(utils::new_tenant(), vec![data]), timeout) - .await?; - ResponseOut::new(cmd.format, res)?.print(); - } - - // Enter into CLI session where we receive requests on stdin and send out - // over stdout/stderr - let cli_session = CliSession::new_for_stdin(utils::new_tenant(), session, cmd.format); - cli_session.wait().await?; - - Ok(()) - } - - // Not interactive and no operation given - (false, None) => Err(Error::MissingOperation), - } -} diff --git a/src/subcommand/launch.rs b/src/subcommand/launch.rs deleted file mode 100644 index 0f1436d..0000000 --- a/src/subcommand/launch.rs +++ /dev/null @@ -1,374 +0,0 @@ -use crate::{ - environment, - exit::{ExitCode, ExitCodeError}, - msg::{MsgReceiver, MsgSender}, - opt::{CommonOpt, Format, LaunchSubcommand, SessionOutput}, - session::CliSession, - utils, -}; -use derive_more::{Display, Error, From}; -use 
distant_core::{ - PlainCodec, RelayServer, Session, SessionInfo, SessionInfoFile, Transport, TransportListener, - XChaCha20Poly1305Codec, -}; -use log::*; -use std::{path::Path, string::FromUtf8Error}; -use tokio::{io, process::Command, runtime::Runtime, time::Duration}; - -#[derive(Debug, Display, Error, From)] -pub enum Error { - #[display(fmt = "Missing data for session")] - MissingSessionData, - - Fork(#[error(not(source))] i32), - Io(io::Error), - Utf8(FromUtf8Error), -} - -impl ExitCodeError for Error { - fn to_exit_code(&self) -> ExitCode { - match self { - Self::MissingSessionData => ExitCode::NoInput, - Self::Fork(_) => ExitCode::OsErr, - Self::Io(x) => x.to_exit_code(), - Self::Utf8(_) => ExitCode::DataErr, - } - } -} - -pub fn run(cmd: LaunchSubcommand, opt: CommonOpt) -> Result<(), Error> { - let rt = Runtime::new()?; - let session_output = cmd.session; - let format = cmd.format; - let is_daemon = !cmd.foreground; - - let session_file = cmd.session_data.session_file.clone(); - let session_socket = cmd.session_data.session_socket.clone(); - let fail_if_socket_exists = cmd.fail_if_socket_exists; - let timeout = opt.to_timeout_duration(); - let shutdown_after = cmd.to_shutdown_after_duration(); - - let session = rt.block_on(async { spawn_remote_server(cmd, opt).await })?; - - // Handle sharing resulting session in different ways - match session_output { - SessionOutput::Environment => { - debug!("Outputting session to environment"); - environment::print_environment(&session) - } - SessionOutput::File => { - debug!("Outputting session to {:?}", session_file); - rt.block_on(async { SessionInfoFile::new(session_file, session).save().await })? - } - SessionOutput::Keep => { - debug!("Entering interactive loop over stdin"); - rt.block_on(async { keep_loop(session, format, timeout).await })? 
- } - SessionOutput::Pipe => { - debug!("Piping session to stdout"); - println!("{}", session.to_unprotected_string()) - } - #[cfg(unix)] - SessionOutput::Socket if is_daemon => { - debug!( - "Forking and entering interactive loop over unix socket {:?}", - session_socket - ); - - // Force runtime shutdown by dropping it BEFORE forking as otherwise - // this produces a garbage process that won't die - drop(rt); - - run_daemon_socket( - session_socket, - session, - timeout, - fail_if_socket_exists, - shutdown_after, - )?; - } - #[cfg(unix)] - SessionOutput::Socket => { - debug!( - "Entering interactive loop over unix socket {:?}", - session_socket - ); - rt.block_on(async { - socket_loop( - session_socket, - session, - timeout, - fail_if_socket_exists, - shutdown_after, - ) - .await - })? - } - } - - Ok(()) -} - -#[cfg(unix)] -fn run_daemon_socket( - session_socket: impl AsRef, - session: SessionInfo, - timeout: Duration, - fail_if_socket_exists: bool, - shutdown_after: Option, -) -> Result<(), Error> { - use fork::{daemon, Fork}; - match daemon(false, false) { - Ok(Fork::Child) => { - // NOTE: We need to create a runtime within the forked process as - // tokio's runtime doesn't support being transferred from - // parent to child in a fork - let rt = Runtime::new()?; - rt.block_on(async { - socket_loop( - session_socket, - session, - timeout, - fail_if_socket_exists, - shutdown_after, - ) - .await - })? 
- } - Ok(_) => {} - Err(x) => return Err(Error::Fork(x)), - } - - Ok(()) -} - -async fn keep_loop(info: SessionInfo, format: Format, duration: Duration) -> io::Result<()> { - let addr = info.to_socket_addr().await?; - let codec = XChaCha20Poly1305Codec::from(info.key); - match Session::tcp_connect_timeout(addr, codec, duration).await { - Ok(session) => { - let cli_session = CliSession::new_for_stdin(utils::new_tenant(), session, format); - cli_session.wait().await - } - Err(x) => Err(x), - } -} - -#[cfg(unix)] -async fn socket_loop( - socket_path: impl AsRef, - info: SessionInfo, - duration: Duration, - fail_if_socket_exists: bool, - shutdown_after: Option, -) -> io::Result<()> { - // We need to form a connection with the actual server to forward requests - // and responses between connections - debug!("Connecting to {} {}", info.host, info.port); - let addr = info.to_socket_addr().await?; - let codec = XChaCha20Poly1305Codec::from(info.key); - let session = Session::tcp_connect_timeout(addr, codec, duration).await?; - - // Remove the socket file if it already exists - if !fail_if_socket_exists && socket_path.as_ref().exists() { - debug!("Removing old unix socket instance"); - tokio::fs::remove_file(socket_path.as_ref()).await?; - } - - // Continue to receive connections over the unix socket, store them in our - // connection mapping - debug!("Binding to unix socket: {:?}", socket_path.as_ref()); - let listener = tokio::net::UnixListener::bind(socket_path)?; - - let stream = - TransportListener::initialize(listener, |stream| Transport::new(stream, PlainCodec::new())) - .into_stream(); - - let server = RelayServer::initialize(session, Box::pin(stream), shutdown_after)?; - server - .wait() - .await - .map_err(|x| io::Error::new(io::ErrorKind::Other, x)) -} - -async fn spawn_remote_server(cmd: LaunchSubcommand, opt: CommonOpt) -> Result { - #[cfg(any(feature = "libssh", feature = "ssh2"))] - if cmd.external_ssh { - external_spawn_remote_server(cmd, opt).await - } else 
{ - native_spawn_remote_server(cmd, opt).await - } - - #[cfg(not(any(feature = "libssh", feature = "ssh2")))] - external_spawn_remote_server(cmd, opt).await -} - -/// Spawns a remote server using native ssh library that listens for requests -/// -/// Returns the session associated with the server -#[cfg(any(feature = "libssh", feature = "ssh2"))] -async fn native_spawn_remote_server( - cmd: LaunchSubcommand, - _opt: CommonOpt, -) -> Result { - trace!("native_spawn_remote_server({:?})", cmd); - use distant_ssh2::{ - IntoDistantSessionOpts, Ssh2AuthEvent, Ssh2AuthHandler, Ssh2Session, Ssh2SessionOpts, - }; - - let host = cmd.host; - - // Build our options based on cli input - let mut opts = Ssh2SessionOpts::default(); - if let Some(path) = cmd.identity_file { - opts.identity_files.push(path); - } - opts.backend = cmd.ssh_backend; - opts.port = Some(cmd.port); - opts.user = Some(cmd.username); - - debug!("Connecting to {} {:#?}", host, opts); - let mut ssh_session = Ssh2Session::connect(host.as_str(), opts)?; - - #[derive(Debug, serde::Serialize, serde::Deserialize)] - #[serde(tag = "type")] - enum SshMsg { - #[serde(rename = "ssh_authenticate")] - Authenticate(Ssh2AuthEvent), - #[serde(rename = "ssh_authenticate_answer")] - AuthenticateAnswer { answers: Vec }, - #[serde(rename = "ssh_banner")] - Banner { text: String }, - #[serde(rename = "ssh_host_verify")] - HostVerify { host: String }, - #[serde(rename = "ssh_host_verify_answer")] - HostVerifyAnswer { answer: bool }, - #[serde(rename = "ssh_error")] - Error { msg: String }, - } - - debug!("Authenticating against {}", host); - ssh_session - .authenticate(match cmd.format { - Format::Shell => Ssh2AuthHandler::default(), - Format::Json => { - let tx = MsgSender::from_stdout(); - let tx_2 = tx.clone(); - let tx_3 = tx.clone(); - let tx_4 = tx.clone(); - let rx = MsgReceiver::from_stdin(); - let rx_2 = rx.clone(); - - Ssh2AuthHandler { - on_authenticate: Box::new(move |ev| { - let _ = 
tx.send_blocking(&SshMsg::Authenticate(ev)); - - let msg: SshMsg = rx.recv_blocking()?; - match msg { - SshMsg::AuthenticateAnswer { answers } => Ok(answers), - x => { - return Err(io::Error::new( - io::ErrorKind::InvalidInput, - format!("Invalid response received: {:?}", x), - )) - } - } - }), - on_banner: Box::new(move |banner| { - let _ = tx_2.send_blocking(&SshMsg::Banner { - text: banner.to_string(), - }); - }), - on_host_verify: Box::new(move |host| { - let _ = tx_3.send_blocking(&SshMsg::HostVerify { - host: host.to_string(), - })?; - - let msg: SshMsg = rx_2.recv_blocking()?; - match msg { - SshMsg::HostVerifyAnswer { answer } => Ok(answer), - x => { - return Err(io::Error::new( - io::ErrorKind::InvalidInput, - format!("Invalid response received: {:?}", x), - )) - } - } - }), - on_error: Box::new(move |err| { - let _ = tx_4.send_blocking(&SshMsg::Error { - msg: err.to_string(), - }); - }), - } - } - }) - .await?; - - debug!("Mapping session for {}", host); - let session_info = ssh_session - .into_distant_session_info(IntoDistantSessionOpts { - binary: cmd.distant, - args: cmd.extra_server_args.unwrap_or_default(), - ..Default::default() - }) - .await?; - - Ok(session_info) -} - -/// Spawns a remote server using external ssh command that listens for requests -/// -/// Returns the session associated with the server -async fn external_spawn_remote_server( - cmd: LaunchSubcommand, - _opt: CommonOpt, -) -> Result { - let distant_command = format!( - "{} listen --host {} {}", - cmd.distant, - cmd.bind_server, - cmd.extra_server_args.unwrap_or_default(), - ); - let ssh_command = format!( - "{} -o StrictHostKeyChecking=no ssh://{}@{}:{} {} '{}'", - cmd.ssh, - cmd.username, - cmd.host.as_str(), - cmd.port, - cmd.identity_file - .map(|f| format!("-i {}", f.as_path().display())) - .unwrap_or_default(), - if cmd.no_shell { - distant_command.trim().to_string() - } else { - // TODO: Do we need to try to escape single quotes here because of extra_server_args? 
- // TODO: Replace this with the ssh2 library shell exec once we integrate that - format!("echo {} | $SHELL -l", distant_command.trim()) - }, - ); - let out = Command::new("sh") - .arg("-c") - .arg(ssh_command) - .output() - .await?; - - // If our attempt to run the program via ssh failed, report it - if !out.status.success() { - return Err(Error::from(io::Error::new( - io::ErrorKind::Other, - String::from_utf8(out.stderr)?.trim().to_string(), - ))); - } - - // Parse our output for the specific session line - // NOTE: The host provided on this line isn't valid, so we fill it in with our actual host - let out = String::from_utf8(out.stdout)?.trim().to_string(); - let mut info = out - .lines() - .find_map(|line| line.parse::().ok()) - .ok_or(Error::MissingSessionData)?; - info.host = cmd.host; - - Ok(info) -} diff --git a/src/subcommand/listen.rs b/src/subcommand/listen.rs deleted file mode 100644 index 7d373b7..0000000 --- a/src/subcommand/listen.rs +++ /dev/null @@ -1,145 +0,0 @@ -use crate::{ - exit::{ExitCode, ExitCodeError}, - opt::{CommonOpt, ConvertToIpAddrError, ListenSubcommand}, -}; -use derive_more::{Display, Error, From}; -use distant_core::{ - DistantServer, DistantServerOptions, SecretKey32, UnprotectedToHexKey, XChaCha20Poly1305Codec, -}; -use log::*; -use tokio::{ - io::{self, AsyncReadExt, AsyncWriteExt}, - task::JoinError, -}; - -#[derive(Debug, Display, Error, From)] -pub enum Error { - BadKey, - ConverToIpAddr(ConvertToIpAddrError), - Fork, - Io(io::Error), - Join(JoinError), -} - -impl ExitCodeError for Error { - fn to_exit_code(&self) -> ExitCode { - match self { - Self::BadKey => ExitCode::Usage, - Self::ConverToIpAddr(_) => ExitCode::NoHost, - Self::Fork => ExitCode::OsErr, - Self::Io(x) => x.to_exit_code(), - Self::Join(_) => ExitCode::Software, - } - } -} - -pub fn run(cmd: ListenSubcommand, opt: CommonOpt) -> Result<(), Error> { - if cmd.foreground { - let rt = tokio::runtime::Runtime::new()?; - rt.block_on(async { run_async(cmd, opt, 
false).await })?; - } else { - run_daemon(cmd, opt)?; - } - - Ok(()) -} - -#[cfg(windows)] -fn run_daemon(_cmd: ListenSubcommand, _opt: CommonOpt) -> Result<(), Error> { - use std::{ - ffi::OsString, - iter, - process::{Command, Stdio}, - }; - let mut args = std::env::args_os(); - let program = args.next().ok_or(Error::Fork)?; - - // Ensure that forked server runs in foreground, otherwise we would fork bomb ourselves - let args = args.chain(iter::once(OsString::from("--foreground"))); - - let child = Command::new(program) - .args(args) - .stdin(Stdio::null()) - .stdout(Stdio::inherit()) - .stderr(Stdio::inherit()) - .spawn()?; - info!("[distant detached, pid = {}]", child.id()); - Ok(()) -} - -#[cfg(unix)] -fn run_daemon(cmd: ListenSubcommand, opt: CommonOpt) -> Result<(), Error> { - use fork::{daemon, Fork}; - - // NOTE: We keep the stdin, stdout, stderr open so we can print out the pid with the parent - match daemon(false, true) { - Ok(Fork::Child) => { - let rt = tokio::runtime::Runtime::new()?; - rt.block_on(async { run_async(cmd, opt, true).await })?; - Ok(()) - } - Ok(Fork::Parent(pid)) => { - info!("[distant detached, pid = {}]", pid); - if fork::close_fd().is_err() { - Err(Error::Fork) - } else { - Ok(()) - } - } - Err(_) => Err(Error::Fork), - } -} - -async fn run_async(cmd: ListenSubcommand, _opt: CommonOpt, is_forked: bool) -> Result<(), Error> { - let addr = cmd.host.to_ip_addr(cmd.use_ipv6)?; - let shutdown_after = cmd.to_shutdown_after_duration(); - - // If specified, change the current working directory of this program - if let Some(path) = cmd.current_dir.as_ref() { - debug!("Setting current directory to {:?}", path); - std::env::set_current_dir(path)?; - } - - // Bind & start our server - let key = if cmd.key_from_stdin { - let mut buf = [0u8; 32]; - let n = io::stdin().read_exact(&mut buf).await?; - if n < buf.len() { - return Err(Error::BadKey); - } - SecretKey32::from(buf) - } else { - SecretKey32::default() - }; - let key_hex_string = 
key.unprotected_to_hex_key(); - let codec = XChaCha20Poly1305Codec::from(key); - - let (server, port) = DistantServer::bind( - addr, - cmd.port, - codec, - DistantServerOptions { - shutdown_after, - max_msg_capacity: cmd.max_msg_capacity as usize, - }, - ) - .await?; - - // Print information about port, key, etc. - // NOTE: Following mosh approach of printing to make sure there's no garbage floating around - println!("\r"); - println!("DISTANT CONNECT -- {} {}", port, key_hex_string); - println!("\r"); - io::stdout().flush().await?; - - // For the child, we want to fully disconnect it from pipes, which we do now - #[cfg(unix)] - if is_forked && fork::close_fd().is_err() { - return Err(Error::Fork); - } - - // Let our server run to completion - server.wait().await?; - - Ok(()) -} diff --git a/src/subcommand/lsp.rs b/src/subcommand/lsp.rs deleted file mode 100644 index 2ea54a7..0000000 --- a/src/subcommand/lsp.rs +++ /dev/null @@ -1,118 +0,0 @@ -use crate::{ - exit::{ExitCode, ExitCodeError}, - link::RemoteProcessLink, - opt::{CommonOpt, LspSubcommand}, - subcommand::CommandRunner, - utils, -}; -use derive_more::{Display, Error, From}; -use distant_core::{LspData, PtySize, RemoteLspProcess, RemoteProcessError, Session}; -use terminal_size::{terminal_size, Height, Width}; -use tokio::io; - -#[derive(Debug, Display, Error, From)] -pub enum Error { - #[display(fmt = "Process failed with exit code: {}", _0)] - BadProcessExit(#[error(not(source))] i32), - Io(io::Error), - RemoteProcess(RemoteProcessError), -} - -impl ExitCodeError for Error { - fn is_silent(&self) -> bool { - match self { - Self::RemoteProcess(x) => x.is_silent(), - _ => false, - } - } - - fn to_exit_code(&self) -> ExitCode { - match self { - Self::BadProcessExit(x) => ExitCode::Custom(*x), - Self::Io(x) => x.to_exit_code(), - Self::RemoteProcess(x) => x.to_exit_code(), - } - } -} - -pub fn run(cmd: LspSubcommand, opt: CommonOpt) -> Result<(), Error> { - let rt = tokio::runtime::Runtime::new()?; - - 
rt.block_on(async { run_async(cmd, opt).await }) -} - -async fn run_async(cmd: LspSubcommand, opt: CommonOpt) -> Result<(), Error> { - let method = cmd.method; - let timeout = opt.to_timeout_duration(); - let ssh_connection = cmd.ssh_connection.clone(); - let session_input = cmd.session; - let session_file = cmd.session_data.session_file.clone(); - let session_socket = cmd.session_data.session_socket.clone(); - - CommandRunner { - method, - ssh_connection, - session_input, - session_file, - session_socket, - timeout, - } - .run( - |session, _, lsp_data| Box::pin(start(cmd, session, lsp_data)), - Error::Io, - ) - .await -} - -async fn start( - cmd: LspSubcommand, - session: Session, - lsp_data: Option, -) -> Result<(), Error> { - let mut proc = RemoteLspProcess::spawn( - utils::new_tenant(), - session.clone_channel(), - cmd.cmd, - cmd.args, - cmd.persist, - if cmd.pty { - terminal_size() - .map(|(Width(width), Height(height))| PtySize::from_rows_and_cols(height, width)) - } else { - None - }, - ) - .await?; - - // If we also parsed an LSP's initialize request for its session, we want to forward - // it along in the case of a process call - if let Some(data) = lsp_data { - proc.stdin - .as_mut() - .unwrap() - .write(data.to_string().as_bytes()) - .await?; - } - - // Now, map the remote LSP server's stdin/stdout/stderr to our own process - let link = RemoteProcessLink::from_remote_lsp_pipes( - proc.stdin.take(), - proc.stdout.take().unwrap(), - proc.stderr.take().unwrap(), - ); - - let (success, exit_code) = proc.wait().await?; - - // Shut down our link - link.shutdown().await; - - if !success { - if let Some(code) = exit_code { - return Err(Error::BadProcessExit(code)); - } else { - return Err(Error::BadProcessExit(1)); - } - } - - Ok(()) -} diff --git a/src/subcommand/mod.rs b/src/subcommand/mod.rs deleted file mode 100644 index c9c970a..0000000 --- a/src/subcommand/mod.rs +++ /dev/null @@ -1,171 +0,0 @@ -use crate::opt::{Method, SessionInput, SshConnectionOpts}; 
-use distant_core::{ - LspData, PlainCodec, Session, SessionInfo, SessionInfoFile, XChaCha20Poly1305Codec, -}; -use std::{ - future::Future, - io, - net::SocketAddr, - path::{Path, PathBuf}, - pin::Pin, - time::Duration, -}; - -pub mod action; -pub mod launch; -pub mod listen; -pub mod lsp; -pub mod shell; - -struct CommandRunner { - method: Method, - ssh_connection: SshConnectionOpts, - session_input: SessionInput, - session_file: PathBuf, - session_socket: PathBuf, - timeout: Duration, -} - -impl CommandRunner { - async fn run(self, start: F1, wrap_err: F2) -> Result<(), E> - where - F1: FnOnce( - Session, - Duration, - Option, - ) -> Pin>>>, - F2: Fn(io::Error) -> E + Copy, - E: std::error::Error, - { - let CommandRunner { - method, - ssh_connection, - session_input, - session_file, - session_socket, - timeout, - } = self; - - let (session, lsp_data) = match method { - #[cfg(any(feature = "libssh", feature = "ssh2"))] - Method::Ssh => { - use distant_ssh2::{Ssh2Session, Ssh2SessionOpts}; - let SshConnectionOpts { host, port, user } = ssh_connection; - - let mut session = Ssh2Session::connect( - host, - Ssh2SessionOpts { - port: Some(port), - user, - ..Default::default() - }, - ) - .map_err(wrap_err)?; - - session - .authenticate(Default::default()) - .await - .map_err(wrap_err)?; - - ( - session.into_ssh_client_session().await.map_err(wrap_err)?, - None, - ) - } - - Method::Distant => { - let params = retrieve_session_params(session_input, session_file, session_socket) - .await - .map_err(wrap_err)?; - match params { - SessionParams::Tcp { - addr, - codec, - lsp_data, - } => { - let session = Session::tcp_connect_timeout(addr, codec, timeout) - .await - .map_err(wrap_err)?; - (session, lsp_data) - } - #[cfg(unix)] - SessionParams::Socket { path, codec } => { - let session = Session::unix_connect_timeout(path, codec, timeout) - .await - .map_err(wrap_err)?; - (session, None) - } - } - } - }; - - start(session, timeout, lsp_data).await - } -} - -enum SessionParams 
{ - Tcp { - addr: SocketAddr, - codec: XChaCha20Poly1305Codec, - lsp_data: Option, - }, - #[cfg(unix)] - Socket { path: PathBuf, codec: PlainCodec }, -} - -async fn retrieve_session_params( - session_input: SessionInput, - session_file: impl AsRef, - session_socket: impl AsRef, -) -> io::Result { - Ok(match session_input { - SessionInput::Environment => { - let info = SessionInfo::from_environment()?; - let addr = info.to_socket_addr().await?; - let codec = XChaCha20Poly1305Codec::from(info.key); - SessionParams::Tcp { - addr, - codec, - lsp_data: None, - } - } - SessionInput::File => { - let info: SessionInfo = SessionInfoFile::load_from(session_file).await?.into(); - let addr = info.to_socket_addr().await?; - let codec = XChaCha20Poly1305Codec::from(info.key); - SessionParams::Tcp { - addr, - codec, - lsp_data: None, - } - } - SessionInput::Pipe => { - let info = SessionInfo::from_stdin()?; - let addr = info.to_socket_addr().await?; - let codec = XChaCha20Poly1305Codec::from(info.key); - SessionParams::Tcp { - addr, - codec, - lsp_data: None, - } - } - SessionInput::Lsp => { - let mut data = - LspData::from_buf_reader(&mut io::stdin().lock()).map_err(io::Error::from)?; - let info = data.take_session_info().map_err(io::Error::from)?; - let addr = info.to_socket_addr().await?; - let codec = XChaCha20Poly1305Codec::from(info.key); - SessionParams::Tcp { - addr, - codec, - lsp_data: Some(data), - } - } - #[cfg(unix)] - SessionInput::Socket => { - let path = session_socket.as_ref().to_path_buf(); - let codec = PlainCodec::new(); - SessionParams::Socket { path, codec } - } - }) -} diff --git a/src/subcommand/shell.rs b/src/subcommand/shell.rs deleted file mode 100644 index e2aaa8e..0000000 --- a/src/subcommand/shell.rs +++ /dev/null @@ -1,165 +0,0 @@ -use crate::{ - exit::{ExitCode, ExitCodeError}, - link::RemoteProcessLink, - opt::{CommonOpt, ShellSubcommand}, - subcommand::CommandRunner, - utils, -}; -use derive_more::{Display, Error, From}; -use 
distant_core::{LspData, PtySize, RemoteProcess, RemoteProcessError, Session}; -use log::*; -use terminal_size::{terminal_size, Height, Width}; -use termwiz::{ - caps::Capabilities, - input::{InputEvent, KeyCodeEncodeModes}, - terminal::{new_terminal, Terminal}, -}; -use tokio::{io, time::Duration}; - -#[derive(Debug, Display, Error, From)] -pub enum Error { - #[display(fmt = "Process failed with exit code: {}", _0)] - BadProcessExit(#[error(not(source))] i32), - Io(io::Error), - RemoteProcess(RemoteProcessError), -} - -impl ExitCodeError for Error { - fn is_silent(&self) -> bool { - match self { - Self::RemoteProcess(x) => x.is_silent(), - _ => false, - } - } - - fn to_exit_code(&self) -> ExitCode { - match self { - Self::BadProcessExit(x) => ExitCode::Custom(*x), - Self::Io(x) => x.to_exit_code(), - Self::RemoteProcess(x) => x.to_exit_code(), - } - } -} - -pub fn run(cmd: ShellSubcommand, opt: CommonOpt) -> Result<(), Error> { - let rt = tokio::runtime::Runtime::new()?; - - rt.block_on(async { run_async(cmd, opt).await }) -} - -async fn run_async(cmd: ShellSubcommand, opt: CommonOpt) -> Result<(), Error> { - let method = cmd.method; - let timeout = opt.to_timeout_duration(); - let ssh_connection = cmd.ssh_connection.clone(); - let session_input = cmd.session; - let session_file = cmd.session_data.session_file.clone(); - let session_socket = cmd.session_data.session_socket.clone(); - - CommandRunner { - method, - ssh_connection, - session_input, - session_file, - session_socket, - timeout, - } - .run( - |session, _, lsp_data| Box::pin(start(cmd, session, lsp_data)), - Error::Io, - ) - .await -} - -async fn start( - cmd: ShellSubcommand, - session: Session, - lsp_data: Option, -) -> Result<(), Error> { - let mut proc = RemoteProcess::spawn( - utils::new_tenant(), - session.clone_channel(), - cmd.cmd.unwrap_or_else(|| "/bin/sh".to_string()), - cmd.args, - cmd.persist, - terminal_size().map(|(Width(cols), Height(rows))| PtySize::from_rows_and_cols(rows, cols)), - ) - 
.await?; - - // If we also parsed an LSP's initialize request for its session, we want to forward - // it along in the case of a process call - if let Some(data) = lsp_data { - proc.stdin - .as_mut() - .unwrap() - .write(data.to_string().as_bytes()) - .await?; - } - - // Create a new terminal in raw mode - let mut terminal = new_terminal( - Capabilities::new_from_env().map_err(|x| io::Error::new(io::ErrorKind::Other, x))?, - ) - .map_err(|x| io::Error::new(io::ErrorKind::Other, x))?; - terminal - .set_raw_mode() - .map_err(|x| io::Error::new(io::ErrorKind::Other, x))?; - - let mut stdin = proc.stdin.take().unwrap(); - let resizer = proc.clone_resizer(); - tokio::spawn(async move { - while let Ok(input) = terminal.poll_input(Some(Duration::new(0, 0))) { - match input { - Some(InputEvent::Key(ev)) => { - if let Ok(input) = ev.key.encode( - ev.modifiers, - KeyCodeEncodeModes { - enable_csi_u_key_encoding: false, - application_cursor_keys: false, - newline_mode: false, - }, - ) { - if let Err(x) = stdin.write_str(input).await { - error!("Failed to write to stdin of remote process: {}", x); - break; - } - } - } - Some(InputEvent::Resized { cols, rows }) => { - if let Err(x) = resizer - .resize(PtySize::from_rows_and_cols(rows as u16, cols as u16)) - .await - { - error!("Failed to resize remote process: {}", x); - break; - } - } - Some(_) => continue, - None => tokio::time::sleep(Duration::from_millis(1)).await, - } - } - }); - - // Now, map the remote shell's stdout/stderr to our own process, - // while stdin is handled by the task above - let link = RemoteProcessLink::from_remote_pipes( - None, - proc.stdout.take().unwrap(), - proc.stderr.take().unwrap(), - ); - - // Continually loop to check for terminal resize changes while the process is still running - let (success, exit_code) = proc.wait().await?; - - // Shut down our link - link.shutdown().await; - - if !success { - if let Some(code) = exit_code { - return Err(Error::BadProcessExit(code)); - } else { - return 
Err(Error::BadProcessExit(1)); - } - } - - Ok(()) -} diff --git a/src/utils.rs b/src/utils.rs deleted file mode 100644 index a6fb470..0000000 --- a/src/utils.rs +++ /dev/null @@ -1,4 +0,0 @@ -// Generates a new tenant name -pub fn new_tenant() -> String { - format!("tenant_{}{}", rand::random::(), rand::random::()) -} diff --git a/src/win_service.rs b/src/win_service.rs new file mode 100644 index 0000000..b131497 --- /dev/null +++ b/src/win_service.rs @@ -0,0 +1,223 @@ +use super::Cli; +use anyhow::Context; +use derive_more::From; +use log::*; +use std::{ + ffi::{OsStr, OsString}, + path::Path, + sync::mpsc, + thread, + time::Duration, +}; +use windows_service::{ + define_windows_service, + service::{ + ServiceControl, ServiceControlAccept, ServiceExitCode, ServiceState, ServiceStatus, + ServiceType, + }, + service_control_handler::{self, ServiceControlHandlerResult}, + service_dispatcher, +}; + +const SERVICE_NAME: &str = "distant_manager"; +const SERVICE_TYPE: ServiceType = ServiceType::OWN_PROCESS; + +#[derive(serde::Serialize, serde::Deserialize)] +struct Config { + pub args: Vec, +} + +impl Config { + pub fn save(&self) -> anyhow::Result<()> { + let mut bytes = Vec::new(); + serde_json::to_writer(&mut bytes, self).context("Could not convert config into json")?; + std::fs::write(Self::config_file(), bytes).context("Could not write config to file") + } + + pub fn load() -> anyhow::Result { + let bytes = std::fs::read(Self::config_file()).context("Could not read config file")?; + serde_json::from_slice(&bytes).context("Could not convert json into config") + } + + pub fn delete() -> anyhow::Result<()> { + std::fs::remove_file(Self::config_file()).context("Could not delete config file") + } + + /// Stored next to the service exe + fn config_file() -> std::path::PathBuf { + let mut path = std::env::current_exe().unwrap(); + path.set_extension("exe.config"); + path + } +} + +#[derive(From)] +pub enum ServiceError { + /// Any other error type + Anyhow(anyhow::Error), 
+ + /// Represents a service-specific error that we use to known that we are not running as a + /// service + Service(windows_service::Error), +} + +pub fn run() -> Result<(), ServiceError> { + // Save our CLI arguments to pass on to the service + let config = Config { + args: std::env::args_os().collect(), + }; + config.save()?; + + // Attempt to run as a service, deleting our config when completed + // regardless of success + let result = service_dispatcher::start(SERVICE_NAME, ffi_service_main); + let config_result = Config::delete(); + + // Swallow the config error if we have a service error, otherwise display + // the config error + match (result, config_result) { + (Ok(_), Ok(_)) => Ok(()), + (Err(x), _) => Err(ServiceError::Service(x)), + (_, Err(x)) => Err(ServiceError::Anyhow(x)), + } +} + +/// Returns true if running as a windows service +pub fn is_windows_service() -> bool { + use sysinfo::{Pid, PidExt, Process, ProcessExt, System, SystemExt}; + + let mut system = System::new(); + + // Get our own process pid + let pid = Pid::from_u32(std::process::id()); + + // Update our system's knowledge about our process + system.refresh_process(pid); + + // Get our parent process' pid and update sustem's knowledge about parent process + let maybe_parent_pid = system.process(pid).and_then(Process::parent); + if let Some(pid) = maybe_parent_pid { + system.refresh_process(pid); + } + + // Check modeled after https://github.com/dotnet/extensions/blob/9069ee83c6ff1e4471cfbc07215c715c5ce157e1/src/Hosting/WindowsServices/src/WindowsServiceHelpers.cs#L31 + maybe_parent_pid + .and_then(|pid| system.process(pid)) + .map(Process::exe) + .and_then(Path::file_name) + .map(OsStr::to_string_lossy) + .map(|s| s.eq_ignore_ascii_case("services")) + .unwrap_or_default() +} + +define_windows_service!(ffi_service_main, service_main); + +fn service_main(_arguments: Vec) { + if let Err(_e) = run_service() { + // Handle the error, by logging or something. 
+ } +} + +fn run_service() -> windows_service::Result<()> { + debug!("Starting windows service for {SERVICE_NAME}"); + + // Create a channel to be able to poll a stop event from the service worker loop. + let (shutdown_tx, shutdown_rx) = std::sync::mpsc::channel(); + + // Define system service event handler that will be receiving service events. + let event_handler = { + move |control_event| -> ServiceControlHandlerResult { + match control_event { + // Notifies a service to report its current status information to the service + // control manager. Always return NoError even if not implemented. + ServiceControl::Interrogate => ServiceControlHandlerResult::NoError, + + // Handle stop + ServiceControl::Stop => { + shutdown_tx.send(true).unwrap(); + ServiceControlHandlerResult::NoError + } + + _ => ServiceControlHandlerResult::NotImplemented, + } + } + }; + + // Register system service event handler. + // The returned status handle should be used to report service status changes to the system. 
+ debug!("Registering service control handler for {SERVICE_NAME}"); + let status_handle = service_control_handler::register(SERVICE_NAME, event_handler)?; + + // Tell the system that service is running + debug!("Setting service status as running for {SERVICE_NAME}"); + status_handle.set_service_status(ServiceStatus { + service_type: SERVICE_TYPE, + current_state: ServiceState::Running, + controls_accepted: ServiceControlAccept::STOP, + exit_code: ServiceExitCode::Win32(0), + checkpoint: 0, + wait_hint: Duration::default(), + process_id: None, + })?; + + // Kick off thread to run our cli + debug!("Spawning CLI thread for {SERVICE_NAME}"); + let handle = thread::spawn({ + move || { + debug!("Loading CLI using args from disk for {SERVICE_NAME}"); + let config = Config::load()?; + + debug!("Parsing CLI args from disk for {SERVICE_NAME}"); + let cli = Cli::initialize_from(config.args)?; + + debug!("Running CLI for {SERVICE_NAME}"); + cli.run() + } + }); + + // Continually check for a shutdown trigger, catching completion of the thread + // running our CLI as well and reporting errors if they occurred + let success = loop { + if handle.is_finished() { + match handle.join() { + Ok(result) => match result { + Ok(_) => break true, + Err(x) => { + error!("{x:?}"); + break false; + } + }, + Err(x) => { + error!("{x:?}"); + break false; + } + } + } + + match shutdown_rx.try_recv() { + // Break the loop either upon stop or channel disconnect as a success + Ok(_) | Err(mpsc::TryRecvError::Disconnected) => break true, + + // Continue work if no events were received within the timeout + Err(mpsc::TryRecvError::Empty) => thread::sleep(Duration::from_millis(100)), + } + }; + + // Tell the system that service has stopped. 
+ debug!("Setting service status as stopped for {SERVICE_NAME}"); + status_handle.set_service_status(ServiceStatus { + service_type: SERVICE_TYPE, + current_state: ServiceState::Stopped, + controls_accepted: ServiceControlAccept::empty(), + exit_code: if success { + ServiceExitCode::NO_ERROR + } else { + ServiceExitCode::ServiceSpecific(1u32) + }, + checkpoint: 0, + wait_hint: Duration::default(), + process_id: None, + })?; + + Ok(()) +} diff --git a/tests/cli/action/copy.rs b/tests/cli/action/copy.rs index 6aff037..0aa02fc 100644 --- a/tests/cli/action/copy.rs +++ b/tests/cli/action/copy.rs @@ -1,14 +1,6 @@ -use crate::cli::{ - fixtures::*, - utils::{random_tenant, FAILURE_LINE}, -}; +use crate::cli::{fixtures::*, utils::FAILURE_LINE}; use assert_cmd::Command; use assert_fs::prelude::*; -use distant::ExitCode; -use distant_core::{ - data::{Error, ErrorKind}, - Request, RequestData, Response, ResponseData, -}; use predicates::prelude::*; use rstest::*; @@ -19,7 +11,7 @@ that is a file's contents "#; #[rstest] -fn should_support_copying_file(mut action_cmd: Command) { +fn should_support_copying_file(mut action_cmd: CtxCommand) { let temp = assert_fs::TempDir::new().unwrap(); let src = temp.child("file"); @@ -40,7 +32,7 @@ fn should_support_copying_file(mut action_cmd: Command) { } #[rstest] -fn should_support_copying_nonempty_directory(mut action_cmd: Command) { +fn should_support_copying_nonempty_directory(mut action_cmd: CtxCommand) { let temp = assert_fs::TempDir::new().unwrap(); // Make a non-empty directory @@ -65,7 +57,7 @@ fn should_support_copying_nonempty_directory(mut action_cmd: Command) { } #[rstest] -fn yield_an_error_when_fails(mut action_cmd: Command) { +fn yield_an_error_when_fails(mut action_cmd: CtxCommand) { let temp = assert_fs::TempDir::new().unwrap(); let src = temp.child("dir"); @@ -75,124 +67,10 @@ fn yield_an_error_when_fails(mut action_cmd: Command) { action_cmd .args(&["copy", src.to_str().unwrap(), dst.to_str().unwrap()]) .assert() - 
.code(ExitCode::Software.to_i32()) + .code(1) .stdout("") .stderr(FAILURE_LINE.clone()); src.assert(predicate::path::missing()); dst.assert(predicate::path::missing()); } - -#[rstest] -fn should_support_json_copying_file(mut action_cmd: Command) { - let temp = assert_fs::TempDir::new().unwrap(); - - let src = temp.child("file"); - src.write_str(FILE_CONTENTS).unwrap(); - - let dst = temp.child("file2"); - - let req = Request { - id: rand::random(), - tenant: random_tenant(), - payload: vec![RequestData::Copy { - src: src.to_path_buf(), - dst: dst.to_path_buf(), - }], - }; - - // distant action --format json --interactive - let cmd = action_cmd - .args(&["--format", "json"]) - .arg("--interactive") - .write_stdin(format!("{}\n", serde_json::to_string(&req).unwrap())) - .assert() - .success() - .stderr(""); - - let res: Response = serde_json::from_slice(&cmd.get_output().stdout).unwrap(); - assert_eq!(res.payload[0], ResponseData::Ok); - - src.assert(predicate::path::exists()); - dst.assert(predicate::path::eq_file(src.path())); -} - -#[rstest] -fn should_support_json_copying_nonempty_directory(mut action_cmd: Command) { - let temp = assert_fs::TempDir::new().unwrap(); - - // Make a non-empty directory - let src = temp.child("dir"); - src.create_dir_all().unwrap(); - let src_file = src.child("file"); - src_file.write_str(FILE_CONTENTS).unwrap(); - - let dst = temp.child("dir2"); - let dst_file = dst.child("file"); - - let req = Request { - id: rand::random(), - tenant: random_tenant(), - payload: vec![RequestData::Copy { - src: src.to_path_buf(), - dst: dst.to_path_buf(), - }], - }; - - // distant action --format json --interactive - let cmd = action_cmd - .args(&["--format", "json"]) - .arg("--interactive") - .write_stdin(format!("{}\n", serde_json::to_string(&req).unwrap())) - .assert() - .success() - .stderr(""); - - let res: Response = serde_json::from_slice(&cmd.get_output().stdout).unwrap(); - assert_eq!(res.payload[0], ResponseData::Ok); - - 
src_file.assert(predicate::path::exists()); - dst_file.assert(predicate::path::eq_file(src_file.path())); -} - -#[rstest] -fn should_support_json_output_for_error(mut action_cmd: Command) { - let temp = assert_fs::TempDir::new().unwrap(); - - let src = temp.child("dir"); - let dst = temp.child("dir2"); - - let req = Request { - id: rand::random(), - tenant: random_tenant(), - payload: vec![RequestData::Copy { - src: src.to_path_buf(), - dst: dst.to_path_buf(), - }], - }; - - // distant action --format json --interactive - let cmd = action_cmd - .args(&["--format", "json"]) - .arg("--interactive") - .write_stdin(format!("{}\n", serde_json::to_string(&req).unwrap())) - .assert() - .success() - .stderr(""); - - let res: Response = serde_json::from_slice(&cmd.get_output().stdout).unwrap(); - assert!( - matches!( - res.payload[0], - ResponseData::Error(Error { - kind: ErrorKind::NotFound, - .. - }) - ), - "Unexpected response: {:?}", - res.payload[0] - ); - - src.assert(predicate::path::missing()); - dst.assert(predicate::path::missing()); -} diff --git a/tests/cli/action/dir_create.rs b/tests/cli/action/dir_create.rs index 313e6a4..c6cf228 100644 --- a/tests/cli/action/dir_create.rs +++ b/tests/cli/action/dir_create.rs @@ -1,19 +1,11 @@ -use crate::cli::{ - fixtures::*, - utils::{random_tenant, FAILURE_LINE}, -}; +use crate::cli::{fixtures::*, utils::FAILURE_LINE}; use assert_cmd::Command; use assert_fs::prelude::*; -use distant::ExitCode; -use distant_core::{ - data::{Error, ErrorKind}, - Request, RequestData, Response, ResponseData, -}; use predicates::prelude::*; use rstest::*; #[rstest] -fn should_report_ok_when_done(mut action_cmd: Command) { +fn should_report_ok_when_done(mut action_cmd: CtxCommand) { let temp = assert_fs::TempDir::new().unwrap(); let dir = temp.child("dir"); @@ -30,7 +22,9 @@ fn should_report_ok_when_done(mut action_cmd: Command) { } #[rstest] -fn should_support_creating_missing_parent_directories_if_specified(mut action_cmd: Command) { +fn 
should_support_creating_missing_parent_directories_if_specified( + mut action_cmd: CtxCommand, +) { let temp = assert_fs::TempDir::new().unwrap(); let dir = temp.child("dir1").child("dir2"); @@ -47,7 +41,7 @@ fn should_support_creating_missing_parent_directories_if_specified(mut action_cm } #[rstest] -fn yield_an_error_when_fails(mut action_cmd: Command) { +fn yield_an_error_when_fails(mut action_cmd: CtxCommand) { let temp = assert_fs::TempDir::new().unwrap(); let dir = temp.child("missing-dir").child("dir"); @@ -55,108 +49,9 @@ fn yield_an_error_when_fails(mut action_cmd: Command) { action_cmd .args(&["dir-create", dir.to_str().unwrap()]) .assert() - .code(ExitCode::Software.to_i32()) + .code(1) .stdout("") .stderr(FAILURE_LINE.clone()); dir.assert(predicate::path::missing()); } - -#[rstest] -fn should_support_json_output(mut action_cmd: Command) { - let temp = assert_fs::TempDir::new().unwrap(); - let dir = temp.child("dir"); - - let req = Request { - id: rand::random(), - tenant: random_tenant(), - payload: vec![RequestData::DirCreate { - path: dir.to_path_buf(), - all: false, - }], - }; - - // distant action --format json --interactive - let cmd = action_cmd - .args(&["--format", "json"]) - .arg("--interactive") - .write_stdin(format!("{}\n", serde_json::to_string(&req).unwrap())) - .assert() - .success() - .stderr(""); - - let res: Response = serde_json::from_slice(&cmd.get_output().stdout).unwrap(); - assert_eq!(res.payload[0], ResponseData::Ok); - - dir.assert(predicate::path::exists()); - dir.assert(predicate::path::is_dir()); -} - -#[rstest] -fn should_support_json_creating_missing_parent_directories_if_specified(mut action_cmd: Command) { - let temp = assert_fs::TempDir::new().unwrap(); - let dir = temp.child("dir1").child("dir2"); - - let req = Request { - id: rand::random(), - tenant: random_tenant(), - payload: vec![RequestData::DirCreate { - path: dir.to_path_buf(), - all: true, - }], - }; - - // distant action --format json --interactive - let cmd = 
action_cmd - .args(&["--format", "json"]) - .arg("--interactive") - .write_stdin(format!("{}\n", serde_json::to_string(&req).unwrap())) - .assert() - .success() - .stderr(""); - - let res: Response = serde_json::from_slice(&cmd.get_output().stdout).unwrap(); - assert_eq!(res.payload[0], ResponseData::Ok); - - dir.assert(predicate::path::exists()); - dir.assert(predicate::path::is_dir()); -} - -#[rstest] -fn should_support_json_output_for_error(mut action_cmd: Command) { - let temp = assert_fs::TempDir::new().unwrap(); - let dir = temp.child("missing-dir").child("dir"); - - let req = Request { - id: rand::random(), - tenant: random_tenant(), - payload: vec![RequestData::DirCreate { - path: dir.to_path_buf(), - all: false, - }], - }; - - // distant action --format json --interactive - let cmd = action_cmd - .args(&["--format", "json"]) - .arg("--interactive") - .write_stdin(format!("{}\n", serde_json::to_string(&req).unwrap())) - .assert() - .success() - .stderr(""); - - let res: Response = serde_json::from_slice(&cmd.get_output().stdout).unwrap(); - assert!( - matches!( - res.payload[0], - ResponseData::Error(Error { - kind: ErrorKind::NotFound, - .. 
- }) - ), - "Unexpected response: {:?}", - res.payload[0] - ); - - dir.assert(predicate::path::missing()); -} diff --git a/tests/cli/action/dir_read.rs b/tests/cli/action/dir_read.rs index 0cb3488..acd5ffa 100644 --- a/tests/cli/action/dir_read.rs +++ b/tests/cli/action/dir_read.rs @@ -1,16 +1,11 @@ use crate::cli::{ fixtures::*, - utils::{random_tenant, FAILURE_LINE}, + utils::{regex_pred, FAILURE_LINE}, }; use assert_cmd::Command; use assert_fs::prelude::*; -use distant::ExitCode; -use distant_core::{ - data::{DirEntry, Error, ErrorKind, FileType}, - Request, RequestData, Response, ResponseData, -}; use rstest::*; -use std::path::PathBuf; +use std::path::Path; /// Creates a directory in the form /// @@ -76,96 +71,143 @@ fn make_directory() -> assert_fs::TempDir { temp } +fn regex_stdout<'a>(lines: impl IntoIterator) -> String { + let mut s = String::new(); + + s.push('^'); + for (ty, path) in lines { + s.push_str(®ex_line(ty, path)); + } + s.push('$'); + + s +} + +fn regex_line(ty: &str, path: &str) -> String { + format!(r"\s*{ty}\s+{path}\s*\n") +} + #[rstest] -fn should_print_immediate_files_and_directories_by_default(mut action_cmd: Command) { +fn should_print_immediate_files_and_directories_by_default(mut action_cmd: CtxCommand) { let temp = make_directory(); + let expected = regex_pred(®ex_stdout(vec![ + ("", "dir1"), + ("", "dir2"), + ("", "file1"), + ("", "file2"), + ])); + // distant action dir-read {path} action_cmd .args(&["dir-read", temp.to_str().unwrap()]) .assert() .success() - .stdout(concat!("dir1/\n", "dir2/\n", "file1\n", "file2\n")) + .stdout(expected) .stderr(""); } +// NOTE: Ignoring on windows because ssh2 doesn't properly canonicalize paths to resolve symlinks! 
#[rstest] -fn should_use_absolute_paths_if_specified(mut action_cmd: Command) { +#[cfg_attr(windows, ignore)] +fn should_use_absolute_paths_if_specified(mut action_cmd: CtxCommand) { let temp = make_directory(); // NOTE: Our root path is always canonicalized, so the absolute path // provided is our canonicalized root path prepended let root_path = temp.to_path_buf().canonicalize().unwrap(); + let expected = regex_pred(®ex_stdout(vec![ + ("", root_path.join("dir1").to_str().unwrap()), + ("", root_path.join("dir2").to_str().unwrap()), + ("", root_path.join("file1").to_str().unwrap()), + ("", root_path.join("file2").to_str().unwrap()), + ])); + // distant action dir-read --absolute {path} action_cmd .args(&["dir-read", "--absolute", temp.to_str().unwrap()]) .assert() .success() - .stdout(format!( - "{}\n", - vec![ - format!("{}/{}", root_path.to_str().unwrap(), "dir1/"), - format!("{}/{}", root_path.to_str().unwrap(), "dir2/"), - format!("{}/{}", root_path.to_str().unwrap(), "file1"), - format!("{}/{}", root_path.to_str().unwrap(), "file2"), - ] - .join("\n") - )) + .stdout(expected) .stderr(""); } +// NOTE: Ignoring on windows because ssh2 doesn't properly canonicalize paths to resolve symlinks! 
#[rstest] -fn should_print_all_files_and_directories_if_depth_is_0(mut action_cmd: Command) { +#[cfg_attr(windows, ignore)] +fn should_print_all_files_and_directories_if_depth_is_0(mut action_cmd: CtxCommand) { let temp = make_directory(); + let expected = regex_pred(®ex_stdout(vec![ + ("", Path::new("dir1").to_str().unwrap()), + ("", Path::new("dir1").join("dira").to_str().unwrap()), + ("", Path::new("dir1").join("dirb").to_str().unwrap()), + ( + "", + Path::new("dir1") + .join("dirb") + .join("file1") + .to_str() + .unwrap(), + ), + ("", Path::new("dir1").join("file1").to_str().unwrap()), + ("", Path::new("dir1").join("file2").to_str().unwrap()), + ("", Path::new("dir2").to_str().unwrap()), + ("", Path::new("dir2").join("dira").to_str().unwrap()), + ("", Path::new("dir2").join("dirb").to_str().unwrap()), + ( + "", + Path::new("dir2") + .join("dirb") + .join("file1") + .to_str() + .unwrap(), + ), + ("", Path::new("dir2").join("file1").to_str().unwrap()), + ("", Path::new("dir2").join("file2").to_str().unwrap()), + ("", Path::new("file1").to_str().unwrap()), + ("", Path::new("file2").to_str().unwrap()), + ])); + // distant action dir-read --depth 0 {path} action_cmd .args(&["dir-read", "--depth", "0", temp.to_str().unwrap()]) .assert() .success() - .stdout(concat!( - "dir1/\n", - "dir1/dira/\n", - "dir1/dirb/\n", - "dir1/dirb/file1\n", - "dir1/file1\n", - "dir1/file2\n", - "dir2/\n", - "dir2/dira/\n", - "dir2/dirb/\n", - "dir2/dirb/file1\n", - "dir2/file1\n", - "dir2/file2\n", - "file1\n", - "file2\n", - )) + .stdout(expected) .stderr(""); } +// NOTE: Ignoring on windows because ssh2 doesn't properly canonicalize paths to resolve symlinks! 
#[rstest] -fn should_include_root_directory_if_specified(mut action_cmd: Command) { +#[cfg_attr(windows, ignore)] +fn should_include_root_directory_if_specified(mut action_cmd: CtxCommand) { let temp = make_directory(); // NOTE: Our root path is always canonicalized, so yielded entry // is the canonicalized version let root_path = temp.to_path_buf().canonicalize().unwrap(); + let expected = regex_pred(®ex_stdout(vec![ + ("", root_path.to_str().unwrap()), + ("", "dir1"), + ("", "dir2"), + ("", "file1"), + ("", "file2"), + ])); + // distant action dir-read --include-root {path} action_cmd .args(&["dir-read", "--include-root", temp.to_str().unwrap()]) .assert() .success() - .stdout(format!( - "{}/\n{}", - root_path.to_str().unwrap(), - concat!("dir1/\n", "dir2/\n", "file1\n", "file2\n") - )) + .stdout(expected) .stderr(""); } #[rstest] -fn yield_an_error_when_fails(mut action_cmd: Command) { +fn yield_an_error_when_fails(mut action_cmd: CtxCommand) { let temp = make_directory(); let dir = temp.child("missing-dir"); @@ -173,348 +215,7 @@ fn yield_an_error_when_fails(mut action_cmd: Command) { action_cmd .args(&["dir-read", dir.to_str().unwrap()]) .assert() - .code(ExitCode::Software.to_i32()) + .code(1) .stdout("") .stderr(FAILURE_LINE.clone()); } - -#[rstest] -fn should_support_json_output(mut action_cmd: Command) { - let temp = make_directory(); - - let req = Request { - id: rand::random(), - tenant: random_tenant(), - payload: vec![RequestData::DirRead { - path: temp.to_path_buf(), - depth: 1, - absolute: false, - canonicalize: false, - include_root: false, - }], - }; - - // distant action --format json --interactive - let cmd = action_cmd - .args(&["--format", "json"]) - .arg("--interactive") - .write_stdin(format!("{}\n", serde_json::to_string(&req).unwrap())) - .assert() - .success() - .stderr(""); - - let res: Response = serde_json::from_slice(&cmd.get_output().stdout).unwrap(); - assert_eq!( - res.payload[0], - ResponseData::DirEntries { - entries: vec![ - 
DirEntry { - path: PathBuf::from("dir1"), - file_type: FileType::Dir, - depth: 1 - }, - DirEntry { - path: PathBuf::from("dir2"), - file_type: FileType::Dir, - depth: 1 - }, - DirEntry { - path: PathBuf::from("file1"), - file_type: FileType::File, - depth: 1 - }, - DirEntry { - path: PathBuf::from("file2"), - file_type: FileType::File, - depth: 1 - }, - ], - errors: Vec::new(), - } - ); -} - -#[rstest] -fn should_support_json_returning_absolute_paths_if_specified(mut action_cmd: Command) { - let temp = make_directory(); - - // NOTE: Our root path is always canonicalized, so the absolute path - // provided is our canonicalized root path prepended - let root_path = temp.to_path_buf().canonicalize().unwrap(); - - let req = Request { - id: rand::random(), - tenant: random_tenant(), - payload: vec![RequestData::DirRead { - path: temp.to_path_buf(), - depth: 1, - absolute: true, - canonicalize: false, - include_root: false, - }], - }; - - // distant action --format json --interactive - let cmd = action_cmd - .args(&["--format", "json"]) - .arg("--interactive") - .write_stdin(format!("{}\n", serde_json::to_string(&req).unwrap())) - .assert() - .success() - .stderr(""); - - let res: Response = serde_json::from_slice(&cmd.get_output().stdout).unwrap(); - assert_eq!( - res.payload[0], - ResponseData::DirEntries { - entries: vec![ - DirEntry { - path: root_path.join("dir1"), - file_type: FileType::Dir, - depth: 1 - }, - DirEntry { - path: root_path.join("dir2"), - file_type: FileType::Dir, - depth: 1 - }, - DirEntry { - path: root_path.join("file1"), - file_type: FileType::File, - depth: 1 - }, - DirEntry { - path: root_path.join("file2"), - file_type: FileType::File, - depth: 1 - }, - ], - errors: Vec::new(), - } - ); -} - -#[rstest] -fn should_support_json_returning_all_files_and_directories_if_depth_is_0(mut action_cmd: Command) { - let temp = make_directory(); - - let req = Request { - id: rand::random(), - tenant: random_tenant(), - payload: vec![RequestData::DirRead { - 
path: temp.to_path_buf(), - depth: 0, - absolute: false, - canonicalize: false, - include_root: false, - }], - }; - - // distant action --format json --interactive - let cmd = action_cmd - .args(&["--format", "json"]) - .arg("--interactive") - .write_stdin(format!("{}\n", serde_json::to_string(&req).unwrap())) - .assert() - .success() - .stderr(""); - - let res: Response = serde_json::from_slice(&cmd.get_output().stdout).unwrap(); - assert_eq!( - res.payload[0], - ResponseData::DirEntries { - /* "dir1/\n", - "dir1/dira/\n", - "dir1/dirb/\n", - "dir1/dirb/file1\n", - "dir1/file1\n", - "dir1/file2\n", - "dir2/\n", - "dir2/dira/\n", - "dir2/dirb/\n", - "dir2/dirb/file1\n", - "dir2/file1\n", - "dir2/file2\n", - "file1\n", - "file2\n", */ - entries: vec![ - DirEntry { - path: PathBuf::from("dir1"), - file_type: FileType::Dir, - depth: 1 - }, - DirEntry { - path: PathBuf::from("dir1").join("dira"), - file_type: FileType::Dir, - depth: 2 - }, - DirEntry { - path: PathBuf::from("dir1").join("dirb"), - file_type: FileType::Dir, - depth: 2 - }, - DirEntry { - path: PathBuf::from("dir1").join("dirb").join("file1"), - file_type: FileType::File, - depth: 3 - }, - DirEntry { - path: PathBuf::from("dir1").join("file1"), - file_type: FileType::File, - depth: 2 - }, - DirEntry { - path: PathBuf::from("dir1").join("file2"), - file_type: FileType::File, - depth: 2 - }, - DirEntry { - path: PathBuf::from("dir2"), - file_type: FileType::Dir, - depth: 1 - }, - DirEntry { - path: PathBuf::from("dir2").join("dira"), - file_type: FileType::Dir, - depth: 2 - }, - DirEntry { - path: PathBuf::from("dir2").join("dirb"), - file_type: FileType::Dir, - depth: 2 - }, - DirEntry { - path: PathBuf::from("dir2").join("dirb").join("file1"), - file_type: FileType::File, - depth: 3 - }, - DirEntry { - path: PathBuf::from("dir2").join("file1"), - file_type: FileType::File, - depth: 2 - }, - DirEntry { - path: PathBuf::from("dir2").join("file2"), - file_type: FileType::File, - depth: 2 - }, - DirEntry { - 
path: PathBuf::from("file1"), - file_type: FileType::File, - depth: 1 - }, - DirEntry { - path: PathBuf::from("file2"), - file_type: FileType::File, - depth: 1 - }, - ], - errors: Vec::new(), - } - ); -} - -#[rstest] -fn should_support_json_including_root_directory_if_specified(mut action_cmd: Command) { - let temp = make_directory(); - - // NOTE: Our root path is always canonicalized, so yielded entry - // is the canonicalized version - let root_path = temp.to_path_buf().canonicalize().unwrap(); - - let req = Request { - id: rand::random(), - tenant: random_tenant(), - payload: vec![RequestData::DirRead { - path: temp.to_path_buf(), - depth: 1, - absolute: false, - canonicalize: false, - include_root: true, - }], - }; - - // distant action --format json --interactive - let cmd = action_cmd - .args(&["--format", "json"]) - .arg("--interactive") - .write_stdin(format!("{}\n", serde_json::to_string(&req).unwrap())) - .assert() - .success() - .stderr(""); - - let res: Response = serde_json::from_slice(&cmd.get_output().stdout).unwrap(); - assert_eq!( - res.payload[0], - ResponseData::DirEntries { - entries: vec![ - DirEntry { - path: root_path, - file_type: FileType::Dir, - depth: 0 - }, - DirEntry { - path: PathBuf::from("dir1"), - file_type: FileType::Dir, - depth: 1 - }, - DirEntry { - path: PathBuf::from("dir2"), - file_type: FileType::Dir, - depth: 1 - }, - DirEntry { - path: PathBuf::from("file1"), - file_type: FileType::File, - depth: 1 - }, - DirEntry { - path: PathBuf::from("file2"), - file_type: FileType::File, - depth: 1 - }, - ], - errors: Vec::new(), - } - ); -} - -#[rstest] -fn should_support_json_output_for_error(mut action_cmd: Command) { - let temp = make_directory(); - let dir = temp.child("missing-dir"); - - let req = Request { - id: rand::random(), - tenant: random_tenant(), - payload: vec![RequestData::DirRead { - path: dir.to_path_buf(), - depth: 1, - absolute: false, - canonicalize: false, - include_root: false, - }], - }; - - // distant action 
--format json --interactive - let cmd = action_cmd - .args(&["--format", "json"]) - .arg("--interactive") - .write_stdin(format!("{}\n", serde_json::to_string(&req).unwrap())) - .assert() - .success() - .stderr(""); - - let res: Response = serde_json::from_slice(&cmd.get_output().stdout).unwrap(); - assert!( - matches!( - res.payload[0], - ResponseData::Error(Error { - kind: ErrorKind::NotFound, - .. - }) - ), - "Unexpected response: {:?}", - res.payload[0] - ); -} diff --git a/tests/cli/action/exists.rs b/tests/cli/action/exists.rs index 5f95b5c..e5acaa3 100644 --- a/tests/cli/action/exists.rs +++ b/tests/cli/action/exists.rs @@ -1,11 +1,10 @@ -use crate::cli::{fixtures::*, utils::random_tenant}; +use crate::cli::fixtures::*; use assert_cmd::Command; use assert_fs::prelude::*; -use distant_core::{Request, RequestData, Response, ResponseData}; use rstest::*; #[rstest] -fn should_output_true_if_exists(mut action_cmd: Command) { +fn should_output_true_if_exists(mut action_cmd: CtxCommand) { let temp = assert_fs::TempDir::new().unwrap(); // Create file @@ -22,7 +21,7 @@ fn should_output_true_if_exists(mut action_cmd: Command) { } #[rstest] -fn should_output_false_if_not_exists(mut action_cmd: Command) { +fn should_output_false_if_not_exists(mut action_cmd: CtxCommand) { let temp = assert_fs::TempDir::new().unwrap(); // Don't create file @@ -36,60 +35,3 @@ fn should_output_false_if_not_exists(mut action_cmd: Command) { .stdout("false\n") .stderr(""); } - -#[rstest] -fn should_support_json_true_if_exists(mut action_cmd: Command) { - let temp = assert_fs::TempDir::new().unwrap(); - - // Create file - let file = temp.child("file"); - file.touch().unwrap(); - - let req = Request { - id: rand::random(), - tenant: random_tenant(), - payload: vec![RequestData::Exists { - path: file.to_path_buf(), - }], - }; - - // distant action --format json --interactive - let cmd = action_cmd - .args(&["--format", "json"]) - .arg("--interactive") - .write_stdin(format!("{}\n", 
serde_json::to_string(&req).unwrap())) - .assert() - .success() - .stderr(""); - - let res: Response = serde_json::from_slice(&cmd.get_output().stdout).unwrap(); - assert_eq!(res.payload[0], ResponseData::Exists { value: true }); -} - -#[rstest] -fn should_support_json_false_if_not_exists(mut action_cmd: Command) { - let temp = assert_fs::TempDir::new().unwrap(); - - // Don't create file - let file = temp.child("file"); - - let req = Request { - id: rand::random(), - tenant: random_tenant(), - payload: vec![RequestData::Exists { - path: file.to_path_buf(), - }], - }; - - // distant action --format json --interactive - let cmd = action_cmd - .args(&["--format", "json"]) - .arg("--interactive") - .write_stdin(format!("{}\n", serde_json::to_string(&req).unwrap())) - .assert() - .success() - .stderr(""); - - let res: Response = serde_json::from_slice(&cmd.get_output().stdout).unwrap(); - assert_eq!(res.payload[0], ResponseData::Exists { value: false }); -} diff --git a/tests/cli/action/file_append.rs b/tests/cli/action/file_append.rs index 36783e6..3bb0a6c 100644 --- a/tests/cli/action/file_append.rs +++ b/tests/cli/action/file_append.rs @@ -1,14 +1,6 @@ -use crate::cli::{ - fixtures::*, - utils::{random_tenant, FAILURE_LINE}, -}; +use crate::cli::{fixtures::*, utils::FAILURE_LINE}; use assert_cmd::Command; use assert_fs::prelude::*; -use distant::ExitCode; -use distant_core::{ - data::{Error, ErrorKind}, - Request, RequestData, Response, ResponseData, -}; use rstest::*; const FILE_CONTENTS: &str = r#" @@ -23,7 +15,7 @@ file contents "#; #[rstest] -fn should_report_ok_when_done(mut action_cmd: Command) { +fn should_report_ok_when_done(mut action_cmd: CtxCommand) { let temp = assert_fs::TempDir::new().unwrap(); let file = temp.child("test-file"); file.write_str(FILE_CONTENTS).unwrap(); @@ -49,7 +41,7 @@ fn should_report_ok_when_done(mut action_cmd: Command) { } #[rstest] -fn yield_an_error_when_fails(mut action_cmd: Command) { +fn yield_an_error_when_fails(mut 
action_cmd: CtxCommand) { let temp = assert_fs::TempDir::new().unwrap(); let file = temp.child("missing-dir").child("missing-file"); @@ -62,88 +54,10 @@ fn yield_an_error_when_fails(mut action_cmd: Command) { APPENDED_FILE_CONTENTS, ]) .assert() - .code(ExitCode::Software.to_i32()) + .code(1) .stdout("") .stderr(FAILURE_LINE.clone()); // Because we're talking to a local server, we can verify locally file.assert(predicates::path::missing()); } - -#[rstest] -fn should_support_json_output(mut action_cmd: Command) { - let temp = assert_fs::TempDir::new().unwrap(); - let file = temp.child("test-file"); - file.write_str(FILE_CONTENTS).unwrap(); - - let req = Request { - id: rand::random(), - tenant: random_tenant(), - payload: vec![RequestData::FileAppend { - path: file.to_path_buf(), - data: APPENDED_FILE_CONTENTS.as_bytes().to_vec(), - }], - }; - - // distant action --format json --interactive - let cmd = action_cmd - .args(&["--format", "json"]) - .arg("--interactive") - .write_stdin(format!("{}\n", serde_json::to_string(&req).unwrap())) - .assert() - .success() - .stderr(""); - - let res: Response = serde_json::from_slice(&cmd.get_output().stdout).unwrap(); - assert!( - matches!(res.payload[0], ResponseData::Ok), - "Unexpected response: {:?}", - res.payload[0] - ); - - // NOTE: We wait a little bit to give the OS time to fully write to file - std::thread::sleep(std::time::Duration::from_millis(100)); - - // Because we're talking to a local server, we can verify locally - file.assert(format!("{}{}", FILE_CONTENTS, APPENDED_FILE_CONTENTS)); -} - -#[rstest] -fn should_support_json_output_for_error(mut action_cmd: Command) { - let temp = assert_fs::TempDir::new().unwrap(); - let file = temp.child("missing-dir").child("missing-file"); - - let req = Request { - id: rand::random(), - tenant: random_tenant(), - payload: vec![RequestData::FileAppend { - path: file.to_path_buf(), - data: APPENDED_FILE_CONTENTS.as_bytes().to_vec(), - }], - }; - - // distant action --format json 
--interactive - let cmd = action_cmd - .args(&["--format", "json"]) - .arg("--interactive") - .write_stdin(format!("{}\n", serde_json::to_string(&req).unwrap())) - .assert() - .success() - .stderr(""); - - let res: Response = serde_json::from_slice(&cmd.get_output().stdout).unwrap(); - assert!( - matches!( - res.payload[0], - ResponseData::Error(Error { - kind: ErrorKind::NotFound, - .. - }) - ), - "Unexpected response: {:?}", - res.payload[0] - ); - - // Because we're talking to a local server, we can verify locally - file.assert(predicates::path::missing()); -} diff --git a/tests/cli/action/file_append_text.rs b/tests/cli/action/file_append_text.rs index 2e7bc02..a603a97 100644 --- a/tests/cli/action/file_append_text.rs +++ b/tests/cli/action/file_append_text.rs @@ -1,14 +1,6 @@ -use crate::cli::{ - fixtures::*, - utils::{random_tenant, FAILURE_LINE}, -}; +use crate::cli::{fixtures::*, utils::FAILURE_LINE}; use assert_cmd::Command; use assert_fs::prelude::*; -use distant::ExitCode; -use distant_core::{ - data::{Error, ErrorKind}, - Request, RequestData, Response, ResponseData, -}; use rstest::*; const FILE_CONTENTS: &str = r#" @@ -23,7 +15,7 @@ file contents "#; #[rstest] -fn should_report_ok_when_done(mut action_cmd: Command) { +fn should_report_ok_when_done(mut action_cmd: CtxCommand) { let temp = assert_fs::TempDir::new().unwrap(); let file = temp.child("test-file"); file.write_str(FILE_CONTENTS).unwrap(); @@ -49,7 +41,7 @@ fn should_report_ok_when_done(mut action_cmd: Command) { } #[rstest] -fn yield_an_error_when_fails(mut action_cmd: Command) { +fn yield_an_error_when_fails(mut action_cmd: CtxCommand) { let temp = assert_fs::TempDir::new().unwrap(); let file = temp.child("missing-dir").child("missing-file"); @@ -62,88 +54,10 @@ fn yield_an_error_when_fails(mut action_cmd: Command) { APPENDED_FILE_CONTENTS, ]) .assert() - .code(ExitCode::Software.to_i32()) + .code(1) .stdout("") .stderr(FAILURE_LINE.clone()); // Because we're talking to a local server, we 
can verify locally file.assert(predicates::path::missing()); } - -#[rstest] -fn should_support_json_output(mut action_cmd: Command) { - let temp = assert_fs::TempDir::new().unwrap(); - let file = temp.child("test-file"); - file.write_str(FILE_CONTENTS).unwrap(); - - let req = Request { - id: rand::random(), - tenant: random_tenant(), - payload: vec![RequestData::FileAppendText { - path: file.to_path_buf(), - text: APPENDED_FILE_CONTENTS.to_string(), - }], - }; - - // distant action --format json --interactive - let cmd = action_cmd - .args(&["--format", "json"]) - .arg("--interactive") - .write_stdin(format!("{}\n", serde_json::to_string(&req).unwrap())) - .assert() - .success() - .stderr(""); - - let res: Response = serde_json::from_slice(&cmd.get_output().stdout).unwrap(); - assert!( - matches!(res.payload[0], ResponseData::Ok), - "Unexpected response: {:?}", - res.payload[0] - ); - - // NOTE: We wait a little bit to give the OS time to fully write to file - std::thread::sleep(std::time::Duration::from_millis(100)); - - // Because we're talking to a local server, we can verify locally - file.assert(format!("{}{}", FILE_CONTENTS, APPENDED_FILE_CONTENTS)); -} - -#[rstest] -fn should_support_json_output_for_error(mut action_cmd: Command) { - let temp = assert_fs::TempDir::new().unwrap(); - let file = temp.child("missing-dir").child("missing-file"); - - let req = Request { - id: rand::random(), - tenant: random_tenant(), - payload: vec![RequestData::FileAppendText { - path: file.to_path_buf(), - text: APPENDED_FILE_CONTENTS.to_string(), - }], - }; - - // distant action --format json --interactive - let cmd = action_cmd - .args(&["--format", "json"]) - .arg("--interactive") - .write_stdin(format!("{}\n", serde_json::to_string(&req).unwrap())) - .assert() - .success() - .stderr(""); - - let res: Response = serde_json::from_slice(&cmd.get_output().stdout).unwrap(); - assert!( - matches!( - res.payload[0], - ResponseData::Error(Error { - kind: ErrorKind::NotFound, - .. 
- }) - ), - "Unexpected response: {:?}", - res.payload[0] - ); - - // Because we're talking to a local server, we can verify locally - file.assert(predicates::path::missing()); -} diff --git a/tests/cli/action/file_read.rs b/tests/cli/action/file_read.rs index 3d4a5ff..0d48da9 100644 --- a/tests/cli/action/file_read.rs +++ b/tests/cli/action/file_read.rs @@ -1,14 +1,6 @@ -use crate::cli::{ - fixtures::*, - utils::{random_tenant, FAILURE_LINE}, -}; +use crate::cli::{fixtures::*, utils::FAILURE_LINE}; use assert_cmd::Command; use assert_fs::prelude::*; -use distant::ExitCode; -use distant_core::{ - data::{Error, ErrorKind}, - Request, RequestData, Response, ResponseData, -}; use rstest::*; const FILE_CONTENTS: &str = r#" @@ -18,7 +10,7 @@ that is a file's contents "#; #[rstest] -fn should_print_out_file_contents(mut action_cmd: Command) { +fn should_print_out_file_contents(mut action_cmd: CtxCommand) { let temp = assert_fs::TempDir::new().unwrap(); let file = temp.child("test-file"); file.write_str(FILE_CONTENTS).unwrap(); @@ -33,7 +25,7 @@ fn should_print_out_file_contents(mut action_cmd: Command) { } #[rstest] -fn yield_an_error_when_fails(mut action_cmd: Command) { +fn yield_an_error_when_fails(mut action_cmd: CtxCommand) { let temp = assert_fs::TempDir::new().unwrap(); let file = temp.child("missing-file"); @@ -41,75 +33,7 @@ fn yield_an_error_when_fails(mut action_cmd: Command) { action_cmd .args(&["file-read", file.to_str().unwrap()]) .assert() - .code(ExitCode::Software.to_i32()) + .code(1) .stdout("") .stderr(FAILURE_LINE.clone()); } - -#[rstest] -fn should_support_json_output(mut action_cmd: Command) { - let temp = assert_fs::TempDir::new().unwrap(); - let file = temp.child("test-file"); - file.write_str(FILE_CONTENTS).unwrap(); - - let req = Request { - id: rand::random(), - tenant: random_tenant(), - payload: vec![RequestData::FileRead { - path: file.to_path_buf(), - }], - }; - - // distant action --format json --interactive - let cmd = action_cmd - 
.args(&["--format", "json"]) - .arg("--interactive") - .write_stdin(format!("{}\n", serde_json::to_string(&req).unwrap())) - .assert() - .success() - .stderr(""); - - let res: Response = serde_json::from_slice(&cmd.get_output().stdout).unwrap(); - assert_eq!( - res.payload[0], - ResponseData::Blob { - data: FILE_CONTENTS.as_bytes().to_vec() - } - ); -} - -#[rstest] -fn should_support_json_output_for_error(mut action_cmd: Command) { - let temp = assert_fs::TempDir::new().unwrap(); - let file = temp.child("missing-file"); - - let req = Request { - id: rand::random(), - tenant: random_tenant(), - payload: vec![RequestData::FileRead { - path: file.to_path_buf(), - }], - }; - - // distant action --format json --interactive - let cmd = action_cmd - .args(&["--format", "json"]) - .arg("--interactive") - .write_stdin(format!("{}\n", serde_json::to_string(&req).unwrap())) - .assert() - .success() - .stderr(""); - - let res: Response = serde_json::from_slice(&cmd.get_output().stdout).unwrap(); - assert!( - matches!( - res.payload[0], - ResponseData::Error(Error { - kind: ErrorKind::NotFound, - .. 
- }) - ), - "Unexpected response: {:?}", - res.payload[0] - ); -} diff --git a/tests/cli/action/file_read_text.rs b/tests/cli/action/file_read_text.rs index a4bc9cf..141f69b 100644 --- a/tests/cli/action/file_read_text.rs +++ b/tests/cli/action/file_read_text.rs @@ -1,14 +1,6 @@ -use crate::cli::{ - fixtures::*, - utils::{random_tenant, FAILURE_LINE}, -}; +use crate::cli::{fixtures::*, utils::FAILURE_LINE}; use assert_cmd::Command; use assert_fs::prelude::*; -use distant::ExitCode; -use distant_core::{ - data::{Error, ErrorKind}, - Request, RequestData, Response, ResponseData, -}; use rstest::*; const FILE_CONTENTS: &str = r#" @@ -18,7 +10,7 @@ that is a file's contents "#; #[rstest] -fn should_print_out_file_contents(mut action_cmd: Command) { +fn should_print_out_file_contents(mut action_cmd: CtxCommand) { let temp = assert_fs::TempDir::new().unwrap(); let file = temp.child("test-file"); file.write_str(FILE_CONTENTS).unwrap(); @@ -33,7 +25,7 @@ fn should_print_out_file_contents(mut action_cmd: Command) { } #[rstest] -fn yield_an_error_when_fails(mut action_cmd: Command) { +fn yield_an_error_when_fails(mut action_cmd: CtxCommand) { let temp = assert_fs::TempDir::new().unwrap(); let file = temp.child("missing-file"); @@ -41,75 +33,7 @@ fn yield_an_error_when_fails(mut action_cmd: Command) { action_cmd .args(&["file-read-text", file.to_str().unwrap()]) .assert() - .code(ExitCode::Software.to_i32()) + .code(1) .stdout("") .stderr(FAILURE_LINE.clone()); } - -#[rstest] -fn should_support_json_output(mut action_cmd: Command) { - let temp = assert_fs::TempDir::new().unwrap(); - let file = temp.child("test-file"); - file.write_str(FILE_CONTENTS).unwrap(); - - let req = Request { - id: rand::random(), - tenant: random_tenant(), - payload: vec![RequestData::FileReadText { - path: file.to_path_buf(), - }], - }; - - // distant action --format json --interactive - let cmd = action_cmd - .args(&["--format", "json"]) - .arg("--interactive") - .write_stdin(format!("{}\n", 
serde_json::to_string(&req).unwrap())) - .assert() - .success() - .stderr(""); - - let res: Response = serde_json::from_slice(&cmd.get_output().stdout).unwrap(); - assert_eq!( - res.payload[0], - ResponseData::Text { - data: FILE_CONTENTS.to_string() - } - ); -} - -#[rstest] -fn should_support_json_output_for_error(mut action_cmd: Command) { - let temp = assert_fs::TempDir::new().unwrap(); - let file = temp.child("missing-file"); - - let req = Request { - id: rand::random(), - tenant: random_tenant(), - payload: vec![RequestData::FileReadText { - path: file.to_path_buf(), - }], - }; - - // distant action --format json --interactive - let cmd = action_cmd - .args(&["--format", "json"]) - .arg("--interactive") - .write_stdin(format!("{}\n", serde_json::to_string(&req).unwrap())) - .assert() - .success() - .stderr(""); - - let res: Response = serde_json::from_slice(&cmd.get_output().stdout).unwrap(); - assert!( - matches!( - res.payload[0], - ResponseData::Error(Error { - kind: ErrorKind::NotFound, - .. 
- }) - ), - "Unexpected response: {:?}", - res.payload[0] - ); -} diff --git a/tests/cli/action/file_write.rs b/tests/cli/action/file_write.rs index add4fbe..6b005c3 100644 --- a/tests/cli/action/file_write.rs +++ b/tests/cli/action/file_write.rs @@ -1,14 +1,6 @@ -use crate::cli::{ - fixtures::*, - utils::{random_tenant, FAILURE_LINE}, -}; +use crate::cli::{fixtures::*, utils::FAILURE_LINE}; use assert_cmd::Command; use assert_fs::prelude::*; -use distant::ExitCode; -use distant_core::{ - data::{Error, ErrorKind}, - Request, RequestData, Response, ResponseData, -}; use rstest::*; const FILE_CONTENTS: &str = r#" @@ -18,7 +10,7 @@ that is a file's contents "#; #[rstest] -fn should_report_ok_when_done(mut action_cmd: Command) { +fn should_report_ok_when_done(mut action_cmd: CtxCommand) { let temp = assert_fs::TempDir::new().unwrap(); let file = temp.child("test-file"); @@ -38,7 +30,7 @@ fn should_report_ok_when_done(mut action_cmd: Command) { } #[rstest] -fn yield_an_error_when_fails(mut action_cmd: Command) { +fn yield_an_error_when_fails(mut action_cmd: CtxCommand) { let temp = assert_fs::TempDir::new().unwrap(); let file = temp.child("missing-dir").child("missing-file"); @@ -46,87 +38,10 @@ fn yield_an_error_when_fails(mut action_cmd: Command) { action_cmd .args(&["file-write", file.to_str().unwrap(), "--", FILE_CONTENTS]) .assert() - .code(ExitCode::Software.to_i32()) + .code(1) .stdout("") .stderr(FAILURE_LINE.clone()); // Because we're talking to a local server, we can verify locally file.assert(predicates::path::missing()); } - -#[rstest] -fn should_support_json_output(mut action_cmd: Command) { - let temp = assert_fs::TempDir::new().unwrap(); - let file = temp.child("test-file"); - - let req = Request { - id: rand::random(), - tenant: random_tenant(), - payload: vec![RequestData::FileWrite { - path: file.to_path_buf(), - data: FILE_CONTENTS.as_bytes().to_vec(), - }], - }; - - // distant action --format json --interactive - let cmd = action_cmd - 
.args(&["--format", "json"]) - .arg("--interactive") - .write_stdin(format!("{}\n", serde_json::to_string(&req).unwrap())) - .assert() - .success() - .stderr(""); - - let res: Response = serde_json::from_slice(&cmd.get_output().stdout).unwrap(); - assert!( - matches!(res.payload[0], ResponseData::Ok), - "Unexpected response: {:?}", - res.payload[0] - ); - - // NOTE: We wait a little bit to give the OS time to fully write to file - std::thread::sleep(std::time::Duration::from_millis(100)); - - // Because we're talking to a local server, we can verify locally - file.assert(FILE_CONTENTS); -} - -#[rstest] -fn should_support_json_output_for_error(mut action_cmd: Command) { - let temp = assert_fs::TempDir::new().unwrap(); - let file = temp.child("missing-dir").child("missing-file"); - - let req = Request { - id: rand::random(), - tenant: random_tenant(), - payload: vec![RequestData::FileWrite { - path: file.to_path_buf(), - data: FILE_CONTENTS.as_bytes().to_vec(), - }], - }; - - // distant action --format json --interactive - let cmd = action_cmd - .args(&["--format", "json"]) - .arg("--interactive") - .write_stdin(format!("{}\n", serde_json::to_string(&req).unwrap())) - .assert() - .success() - .stderr(""); - - let res: Response = serde_json::from_slice(&cmd.get_output().stdout).unwrap(); - assert!( - matches!( - res.payload[0], - ResponseData::Error(Error { - kind: ErrorKind::NotFound, - .. 
- }) - ), - "Unexpected response: {:?}", - res.payload[0] - ); - - // Because we're talking to a local server, we can verify locally - file.assert(predicates::path::missing()); -} diff --git a/tests/cli/action/file_write_text.rs b/tests/cli/action/file_write_text.rs index b238ae8..0bc546f 100644 --- a/tests/cli/action/file_write_text.rs +++ b/tests/cli/action/file_write_text.rs @@ -1,14 +1,6 @@ -use crate::cli::{ - fixtures::*, - utils::{random_tenant, FAILURE_LINE}, -}; +use crate::cli::{fixtures::*, utils::FAILURE_LINE}; use assert_cmd::Command; use assert_fs::prelude::*; -use distant::ExitCode; -use distant_core::{ - data::{Error, ErrorKind}, - Request, RequestData, Response, ResponseData, -}; use rstest::*; const FILE_CONTENTS: &str = r#" @@ -18,7 +10,7 @@ that is a file's contents "#; #[rstest] -fn should_report_ok_when_done(mut action_cmd: Command) { +fn should_report_ok_when_done(mut action_cmd: CtxCommand) { let temp = assert_fs::TempDir::new().unwrap(); let file = temp.child("test-file"); @@ -43,7 +35,7 @@ fn should_report_ok_when_done(mut action_cmd: Command) { } #[rstest] -fn yield_an_error_when_fails(mut action_cmd: Command) { +fn yield_an_error_when_fails(mut action_cmd: CtxCommand) { let temp = assert_fs::TempDir::new().unwrap(); let file = temp.child("missing-dir").child("missing-file"); @@ -56,87 +48,10 @@ fn yield_an_error_when_fails(mut action_cmd: Command) { FILE_CONTENTS, ]) .assert() - .code(ExitCode::Software.to_i32()) + .code(1) .stdout("") .stderr(FAILURE_LINE.clone()); // Because we're talking to a local server, we can verify locally file.assert(predicates::path::missing()); } - -#[rstest] -fn should_support_json_output(mut action_cmd: Command) { - let temp = assert_fs::TempDir::new().unwrap(); - let file = temp.child("test-file"); - - let req = Request { - id: rand::random(), - tenant: random_tenant(), - payload: vec![RequestData::FileWriteText { - path: file.to_path_buf(), - text: FILE_CONTENTS.to_string(), - }], - }; - - // distant 
action --format json --interactive - let cmd = action_cmd - .args(&["--format", "json"]) - .arg("--interactive") - .write_stdin(format!("{}\n", serde_json::to_string(&req).unwrap())) - .assert() - .success() - .stderr(""); - - let res: Response = serde_json::from_slice(&cmd.get_output().stdout).unwrap(); - assert!( - matches!(res.payload[0], ResponseData::Ok), - "Unexpected response: {:?}", - res.payload[0] - ); - - // NOTE: We wait a little bit to give the OS time to fully write to file - std::thread::sleep(std::time::Duration::from_millis(100)); - - // Because we're talking to a local server, we can verify locally - file.assert(FILE_CONTENTS); -} - -#[rstest] -fn should_support_json_output_for_error(mut action_cmd: Command) { - let temp = assert_fs::TempDir::new().unwrap(); - let file = temp.child("missing-dir").child("missing-file"); - - let req = Request { - id: rand::random(), - tenant: random_tenant(), - payload: vec![RequestData::FileWriteText { - path: file.to_path_buf(), - text: FILE_CONTENTS.to_string(), - }], - }; - - // distant action --format json --interactive - let cmd = action_cmd - .args(&["--format", "json"]) - .arg("--interactive") - .write_stdin(format!("{}\n", serde_json::to_string(&req).unwrap())) - .assert() - .success() - .stderr(""); - - let res: Response = serde_json::from_slice(&cmd.get_output().stdout).unwrap(); - assert!( - matches!( - res.payload[0], - ResponseData::Error(Error { - kind: ErrorKind::NotFound, - .. 
- }) - ), - "Unexpected response: {:?}", - res.payload[0] - ); - - // Because we're talking to a local server, we can verify locally - file.assert(predicates::path::missing()); -} diff --git a/tests/cli/action/metadata.rs b/tests/cli/action/metadata.rs index 945c7f9..77f7581 100644 --- a/tests/cli/action/metadata.rs +++ b/tests/cli/action/metadata.rs @@ -1,14 +1,9 @@ use crate::cli::{ fixtures::*, - utils::{random_tenant, regex_pred, FAILURE_LINE}, + utils::{regex_pred, FAILURE_LINE}, }; use assert_cmd::Command; use assert_fs::prelude::*; -use distant::ExitCode; -use distant_core::{ - data::{Error, ErrorKind, FileType, Metadata}, - Request, RequestData, Response, ResponseData, -}; use rstest::*; const FILE_CONTENTS: &str = r#" @@ -18,7 +13,7 @@ that is a file's contents "#; #[rstest] -fn should_output_metadata_for_file(mut action_cmd: Command) { +fn should_output_metadata_for_file(mut action_cmd: CtxCommand) { let temp = assert_fs::TempDir::new().unwrap(); let file = temp.child("file"); @@ -41,7 +36,7 @@ fn should_output_metadata_for_file(mut action_cmd: Command) { } #[rstest] -fn should_output_metadata_for_directory(mut action_cmd: Command) { +fn should_output_metadata_for_directory(mut action_cmd: CtxCommand) { let temp = assert_fs::TempDir::new().unwrap(); let dir = temp.child("dir"); @@ -63,8 +58,10 @@ fn should_output_metadata_for_directory(mut action_cmd: Command) { .stderr(""); } +// NOTE: Ignoring on windows because ssh2 doesn't properly canonicalize paths to resolve symlinks! 
#[rstest] -fn should_support_including_a_canonicalized_path(mut action_cmd: Command) { +#[cfg_attr(windows, ignore)] +fn should_support_including_a_canonicalized_path(mut action_cmd: CtxCommand) { let temp = assert_fs::TempDir::new().unwrap(); let file = temp.child("file"); @@ -94,7 +91,7 @@ fn should_support_including_a_canonicalized_path(mut action_cmd: Command) { } #[rstest] -fn should_support_resolving_file_type_of_symlink(mut action_cmd: Command) { +fn should_support_resolving_file_type_of_symlink(mut action_cmd: CtxCommand) { let temp = assert_fs::TempDir::new().unwrap(); let file = temp.child("file"); @@ -120,7 +117,7 @@ fn should_support_resolving_file_type_of_symlink(mut action_cmd: Command) { } #[rstest] -fn yield_an_error_when_fails(mut action_cmd: Command) { +fn yield_an_error_when_fails(mut action_cmd: CtxCommand) { let temp = assert_fs::TempDir::new().unwrap(); // Don't create file @@ -130,215 +127,7 @@ fn yield_an_error_when_fails(mut action_cmd: Command) { action_cmd .args(&["metadata", file.to_str().unwrap()]) .assert() - .code(ExitCode::Software.to_i32()) + .code(1) .stdout("") .stderr(FAILURE_LINE.clone()); } - -#[rstest] -fn should_support_json_metadata_for_file(mut action_cmd: Command) { - let temp = assert_fs::TempDir::new().unwrap(); - - let file = temp.child("file"); - file.write_str(FILE_CONTENTS).unwrap(); - - let req = Request { - id: rand::random(), - tenant: random_tenant(), - payload: vec![RequestData::Metadata { - path: file.to_path_buf(), - canonicalize: false, - resolve_file_type: false, - }], - }; - - // distant action --format json --interactive - let cmd = action_cmd - .args(&["--format", "json"]) - .arg("--interactive") - .write_stdin(format!("{}\n", serde_json::to_string(&req).unwrap())) - .assert() - .success() - .stderr(""); - - let res: Response = serde_json::from_slice(&cmd.get_output().stdout).unwrap(); - assert!( - matches!( - res.payload[0], - ResponseData::Metadata(Metadata { - canonicalized_path: None, - file_type: 
FileType::File, - readonly: false, - .. - }), - ), - "Unexpected response: {:?}", - res.payload[0], - ); -} - -#[rstest] -fn should_support_json_metadata_for_directory(mut action_cmd: Command) { - let temp = assert_fs::TempDir::new().unwrap(); - - let dir = temp.child("dir"); - dir.create_dir_all().unwrap(); - - let req = Request { - id: rand::random(), - tenant: random_tenant(), - payload: vec![RequestData::Metadata { - path: dir.to_path_buf(), - canonicalize: false, - resolve_file_type: false, - }], - }; - - // distant action --format json --interactive - let cmd = action_cmd - .args(&["--format", "json"]) - .arg("--interactive") - .write_stdin(format!("{}\n", serde_json::to_string(&req).unwrap())) - .assert() - .success() - .stderr(""); - - let res: Response = serde_json::from_slice(&cmd.get_output().stdout).unwrap(); - assert!( - matches!( - res.payload[0], - ResponseData::Metadata(Metadata { - canonicalized_path: None, - file_type: FileType::Dir, - readonly: false, - .. - }), - ), - "Unexpected response: {:?}", - res.payload[0], - ); -} - -#[rstest] -fn should_support_json_metadata_for_including_a_canonicalized_path(mut action_cmd: Command) { - let temp = assert_fs::TempDir::new().unwrap(); - - let file = temp.child("file"); - file.touch().unwrap(); - - let link = temp.child("link"); - link.symlink_to_file(file.path()).unwrap(); - - let req = Request { - id: rand::random(), - tenant: random_tenant(), - payload: vec![RequestData::Metadata { - path: link.to_path_buf(), - canonicalize: true, - resolve_file_type: false, - }], - }; - - // distant action --format json --interactive - let cmd = action_cmd - .args(&["--format", "json"]) - .arg("--interactive") - .write_stdin(format!("{}\n", serde_json::to_string(&req).unwrap())) - .assert() - .success() - .stderr(""); - - let res: Response = serde_json::from_slice(&cmd.get_output().stdout).unwrap(); - match &res.payload[0] { - ResponseData::Metadata(Metadata { - canonicalized_path: Some(path), - file_type: 
FileType::Symlink, - readonly: false, - .. - }) => assert_eq!(path, &file.path().canonicalize().unwrap()), - x => panic!("Unexpected response: {:?}", x), - } -} - -#[rstest] -fn should_support_json_metadata_for_resolving_file_type_of_symlink(mut action_cmd: Command) { - let temp = assert_fs::TempDir::new().unwrap(); - - let file = temp.child("file"); - file.touch().unwrap(); - - let link = temp.child("link"); - link.symlink_to_file(file.path()).unwrap(); - - let req = Request { - id: rand::random(), - tenant: random_tenant(), - payload: vec![RequestData::Metadata { - path: link.to_path_buf(), - canonicalize: true, - resolve_file_type: true, - }], - }; - - // distant action --format json --interactive - let cmd = action_cmd - .args(&["--format", "json"]) - .arg("--interactive") - .write_stdin(format!("{}\n", serde_json::to_string(&req).unwrap())) - .assert() - .success() - .stderr(""); - - let res: Response = serde_json::from_slice(&cmd.get_output().stdout).unwrap(); - assert!( - matches!( - res.payload[0], - ResponseData::Metadata(Metadata { - file_type: FileType::File, - .. - }), - ), - "Unexpected response: {:?}", - res.payload[0], - ); -} - -#[rstest] -fn should_support_json_output_for_error(mut action_cmd: Command) { - let temp = assert_fs::TempDir::new().unwrap(); - - // Don't create file - let file = temp.child("file"); - - let req = Request { - id: rand::random(), - tenant: random_tenant(), - payload: vec![RequestData::Metadata { - path: file.to_path_buf(), - canonicalize: false, - resolve_file_type: false, - }], - }; - - // distant action --format json --interactive - let cmd = action_cmd - .args(&["--format", "json"]) - .arg("--interactive") - .write_stdin(format!("{}\n", serde_json::to_string(&req).unwrap())) - .assert() - .success() - .stderr(""); - - let res: Response = serde_json::from_slice(&cmd.get_output().stdout).unwrap(); - assert!( - matches!( - res.payload[0], - ResponseData::Error(Error { - kind: ErrorKind::NotFound, - .. 
- }) - ), - "Unexpected response: {:?}", - res.payload[0] - ); -} diff --git a/tests/cli/action/proc_spawn.rs b/tests/cli/action/proc_spawn.rs index 2e6672b..b87e9d3 100644 --- a/tests/cli/action/proc_spawn.rs +++ b/tests/cli/action/proc_spawn.rs @@ -1,465 +1,139 @@ -use crate::cli::{ - fixtures::*, - utils::{distant_subcommand, friendly_recv_line, random_tenant, spawn_line_reader}, -}; +use crate::cli::{fixtures::*, scripts::*, utils::regex_pred}; use assert_cmd::Command; -use assert_fs::prelude::*; -use distant::ExitCode; -use distant_core::{ - data::{Error, ErrorKind}, - Request, RequestData, Response, ResponseData, -}; -use once_cell::sync::Lazy; use rstest::*; -use std::{io::Write, time::Duration}; - -static TEMP_SCRIPT_DIR: Lazy = Lazy::new(|| assert_fs::TempDir::new().unwrap()); -static SCRIPT_RUNNER: Lazy = Lazy::new(|| String::from("bash")); - -static ECHO_ARGS_TO_STDOUT_SH: Lazy = Lazy::new(|| { - let script = TEMP_SCRIPT_DIR.child("echo_args_to_stdout.sh"); - script - .write_str(indoc::indoc!( - r#" - #/usr/bin/env bash - printf "%s" "$*" - "# - )) - .unwrap(); - script -}); - -static ECHO_ARGS_TO_STDERR_SH: Lazy = Lazy::new(|| { - let script = TEMP_SCRIPT_DIR.child("echo_args_to_stderr.sh"); - script - .write_str(indoc::indoc!( - r#" - #/usr/bin/env bash - printf "%s" "$*" 1>&2 - "# - )) - .unwrap(); - script -}); - -static ECHO_STDIN_TO_STDOUT_SH: Lazy = Lazy::new(|| { - let script = TEMP_SCRIPT_DIR.child("echo_stdin_to_stdout.sh"); - script - .write_str(indoc::indoc!( - r#" - #/usr/bin/env bash - while IFS= read; do echo "$REPLY"; done - "# - )) - .unwrap(); - script -}); - -static EXIT_CODE_SH: Lazy = Lazy::new(|| { - let script = TEMP_SCRIPT_DIR.child("exit_code.sh"); - script - .write_str(indoc::indoc!( - r#" - #!/usr/bin/env bash - exit "$1" - "# - )) - .unwrap(); - script -}); - -static DOES_NOT_EXIST_BIN: Lazy = - Lazy::new(|| TEMP_SCRIPT_DIR.child("does_not_exist_bin")); - -macro_rules! 
next_two_msgs { - ($rx:expr) => {{ - let out = friendly_recv_line($rx, Duration::from_secs(1)).unwrap(); - let res1: Response = serde_json::from_str(&out).unwrap(); - let out = friendly_recv_line($rx, Duration::from_secs(1)).unwrap(); - let res2: Response = serde_json::from_str(&out).unwrap(); - (res1, res2) - }}; -} +use std::process::Command as StdCommand; #[rstest] -fn should_execute_program_and_return_exit_status(mut action_cmd: Command) { +fn should_execute_program_and_return_exit_status(mut action_cmd: CtxCommand) { + // Windows prints out a message whereas unix prints nothing + #[cfg(windows)] + let stdout = regex_pred(".+"); + #[cfg(unix)] + let stdout = ""; + // distant action proc-spawn -- {cmd} [args] action_cmd .args(&["proc-spawn", "--"]) .arg(SCRIPT_RUNNER.as_str()) - .arg(EXIT_CODE_SH.to_str().unwrap()) + .arg(SCRIPT_RUNNER_ARG.as_str()) + .arg(EXIT_CODE.to_str().unwrap()) .arg("0") .assert() .success() - .stdout("") + .stdout(stdout) .stderr(""); } #[rstest] -fn should_capture_and_print_stdout(mut action_cmd: Command) { +fn should_capture_and_print_stdout(mut action_cmd: CtxCommand) { // distant action proc-spawn {cmd} [args] action_cmd .args(&["proc-spawn", "--"]) .arg(SCRIPT_RUNNER.as_str()) - .arg(ECHO_ARGS_TO_STDOUT_SH.to_str().unwrap()) + .arg(SCRIPT_RUNNER_ARG.as_str()) + .arg(ECHO_ARGS_TO_STDOUT.to_str().unwrap()) .arg("hello world") .assert() .success() - .stdout("hello world") + .stdout(if cfg!(windows) { + "hello world\r\n" + } else { + "hello world" + }) .stderr(""); } #[rstest] -fn should_capture_and_print_stderr(mut action_cmd: Command) { +fn should_capture_and_print_stderr(mut action_cmd: CtxCommand) { // distant action proc-spawn {cmd} [args] action_cmd .args(&["proc-spawn", "--"]) .arg(SCRIPT_RUNNER.as_str()) - .arg(ECHO_ARGS_TO_STDERR_SH.to_str().unwrap()) + .arg(SCRIPT_RUNNER_ARG.as_str()) + .arg(ECHO_ARGS_TO_STDERR.to_str().unwrap()) .arg("hello world") .assert() .success() .stdout("") - .stderr("hello world"); + .stderr(if 
cfg!(windows) { + "hello world \r\n" + } else { + "hello world" + }); } +// TODO: This used to work fine with the assert_cmd where stdin would close from our +// process, which would in turn lead to the remote process stdin being closed +// and then the process exiting. This may be a bug we've introduced with the +// refactor and should be revisited some day. #[rstest] -fn should_forward_stdin_to_remote_process(mut action_cmd: Command) { +fn should_forward_stdin_to_remote_process(mut action_std_cmd: CtxCommand) { + use std::io::{BufRead, BufReader, Write}; + // distant action proc-spawn {cmd} [args] - action_cmd + let mut child = action_std_cmd .args(&["proc-spawn", "--"]) .arg(SCRIPT_RUNNER.as_str()) - .arg(ECHO_STDIN_TO_STDOUT_SH.to_str().unwrap()) - .write_stdin("hello world\n") - .assert() - .success() - .stdout("hello world\n") - .stderr(""); + .arg(SCRIPT_RUNNER_ARG.as_str()) + .arg(ECHO_STDIN_TO_STDOUT.to_str().unwrap()) + .spawn() + .expect("Failed to spawn process"); + + child + .stdin + .as_mut() + .unwrap() + .write_all(if cfg!(windows) { + b"hello world\r\n" + } else { + b"hello world\n" + }) + .expect("Failed to write to stdin of process"); + + let mut stdout = BufReader::new(child.stdout.take().unwrap()); + let mut line = String::new(); + stdout.read_line(&mut line).expect("Failed to read line"); + assert_eq!( + line, + if cfg!(windows) { + "hello world\r\n" + } else { + "hello world\n" + } + ); + + child.kill().expect("Failed to kill spawned process"); } #[rstest] -fn reflect_the_exit_code_of_the_process(mut action_cmd: Command) { +fn reflect_the_exit_code_of_the_process(mut action_cmd: CtxCommand) { + // Windows prints out a message whereas unix prints nothing + #[cfg(windows)] + let stdout = regex_pred(".+"); + #[cfg(unix)] + let stdout = ""; + // distant action proc-spawn {cmd} [args] action_cmd .args(&["proc-spawn", "--"]) .arg(SCRIPT_RUNNER.as_str()) - .arg(EXIT_CODE_SH.to_str().unwrap()) + .arg(SCRIPT_RUNNER_ARG.as_str()) + 
.arg(EXIT_CODE.to_str().unwrap()) .arg("99") .assert() .code(99) - .stdout("") + .stdout(stdout) .stderr(""); } #[rstest] -fn yield_an_error_when_fails(mut action_cmd: Command) { +fn yield_an_error_when_fails(mut action_cmd: CtxCommand) { // distant action proc-spawn {cmd} [args] action_cmd .args(&["proc-spawn", "--"]) .arg(DOES_NOT_EXIST_BIN.to_str().unwrap()) .assert() - .code(ExitCode::IoError.to_i32()) + .code(1) .stdout("") - .stderr(""); -} - -#[rstest] -fn should_support_json_to_execute_program_and_return_exit_status(mut action_cmd: Command) { - let req = Request { - id: rand::random(), - tenant: random_tenant(), - payload: vec![RequestData::ProcSpawn { - cmd: SCRIPT_RUNNER.to_string(), - args: vec![ECHO_ARGS_TO_STDOUT_SH.to_str().unwrap().to_string()], - persist: false, - pty: None, - }], - }; - - // distant action --format json --interactive - let cmd = action_cmd - .args(&["--format", "json"]) - .arg("--interactive") - .write_stdin(format!("{}\n", serde_json::to_string(&req).unwrap())) - .assert() - .success() - .stderr(""); - - let res: Response = serde_json::from_slice(&cmd.get_output().stdout).unwrap(); - assert!( - matches!(res.payload[0], ResponseData::ProcSpawned { .. 
}), - "Unexpected response: {:?}", - res.payload[0], - ); -} - -#[rstest] -fn should_support_json_to_capture_and_print_stdout(ctx: &'_ DistantServerCtx) { - let output = String::from("some output"); - let req = Request { - id: rand::random(), - tenant: random_tenant(), - payload: vec![RequestData::ProcSpawn { - cmd: SCRIPT_RUNNER.to_string(), - args: vec![ - ECHO_ARGS_TO_STDOUT_SH.to_str().unwrap().to_string(), - output.to_string(), - ], - persist: false, - pty: None, - }], - }; - - // distant action --format json --interactive - let mut child = distant_subcommand(ctx, "action") - .args(&["--format", "json"]) - .arg("--interactive") - .spawn() - .unwrap(); - - let mut stdin = child.stdin.take().unwrap(); - let stdout = spawn_line_reader(child.stdout.take().unwrap()); - let stderr = spawn_line_reader(child.stderr.take().unwrap()); - - // Send our request as json - let req_string = format!("{}\n", serde_json::to_string(&req).unwrap()); - stdin.write_all(req_string.as_bytes()).unwrap(); - stdin.flush().unwrap(); - - // Get the indicator of a process started (first line returned can take ~7 seconds due to the - // handshake cost) - let out = - friendly_recv_line(&stdout, Duration::from_secs(30)).expect("Failed to get proc start"); - let res: Response = serde_json::from_str(&out).unwrap(); - assert!( - matches!(res.payload[0], ResponseData::ProcSpawned { .. }), - "Unexpected response: {:?}", - res.payload[0] - ); - - // Get stdout from process and verify it - let out = - friendly_recv_line(&stdout, Duration::from_secs(1)).expect("Failed to get proc stdout"); - let res: Response = serde_json::from_str(&out).unwrap(); - match &res.payload[0] { - ResponseData::ProcStdout { data, .. 
} => assert_eq!(data, output.as_bytes()), - x => panic!("Unexpected response: {:?}", x), - }; - - // Get the indicator of a process completion - let out = friendly_recv_line(&stdout, Duration::from_secs(1)).expect("Failed to get proc done"); - let res: Response = serde_json::from_str(&out).unwrap(); - match &res.payload[0] { - ResponseData::ProcDone { success, .. } => { - assert!(success, "Process failed unexpectedly"); - } - x => panic!("Unexpected response: {:?}", x), - }; - - // Verify that we received nothing on stderr channel - assert!( - stderr.try_recv().is_err(), - "Unexpectedly got result on stderr channel" - ); -} - -#[rstest] -fn should_support_json_to_capture_and_print_stderr(ctx: &'_ DistantServerCtx) { - let output = String::from("some output"); - let req = Request { - id: rand::random(), - tenant: random_tenant(), - payload: vec![RequestData::ProcSpawn { - cmd: SCRIPT_RUNNER.to_string(), - args: vec![ - ECHO_ARGS_TO_STDERR_SH.to_str().unwrap().to_string(), - output.to_string(), - ], - persist: false, - pty: None, - }], - }; - - // distant action --format json --interactive - let mut child = distant_subcommand(ctx, "action") - .args(&["--format", "json"]) - .arg("--interactive") - .spawn() - .unwrap(); - - let mut stdin = child.stdin.take().unwrap(); - let stdout = spawn_line_reader(child.stdout.take().unwrap()); - let stderr = spawn_line_reader(child.stderr.take().unwrap()); - - // Send our request as json - let req_string = format!("{}\n", serde_json::to_string(&req).unwrap()); - stdin.write_all(req_string.as_bytes()).unwrap(); - stdin.flush().unwrap(); - - // Get the indicator of a process started (first line returned can take ~7 seconds due to the - // handshake cost) - let out = - friendly_recv_line(&stdout, Duration::from_secs(30)).expect("Failed to get proc start"); - let res: Response = serde_json::from_str(&out).unwrap(); - assert!( - matches!(res.payload[0], ResponseData::ProcSpawned { .. 
}), - "Unexpected response: {:?}", - res.payload[0] - ); - - // Get stderr from process and verify it - let out = - friendly_recv_line(&stdout, Duration::from_secs(1)).expect("Failed to get proc stderr"); - let res: Response = serde_json::from_str(&out).unwrap(); - match &res.payload[0] { - ResponseData::ProcStderr { data, .. } => assert_eq!(data, output.as_bytes()), - x => panic!("Unexpected response: {:?}", x), - }; - - // Get the indicator of a process completion - let out = friendly_recv_line(&stdout, Duration::from_secs(1)).expect("Failed to get proc done"); - let res: Response = serde_json::from_str(&out).unwrap(); - match &res.payload[0] { - ResponseData::ProcDone { success, .. } => { - assert!(success, "Process failed unexpectedly"); - } - x => panic!("Unexpected response: {:?}", x), - }; - - // Verify that we received nothing on stderr channel - assert!( - stderr.try_recv().is_err(), - "Unexpectedly got result on stderr channel" - ); -} - -#[rstest] -fn should_support_json_to_forward_stdin_to_remote_process(ctx: &'_ DistantServerCtx) { - let req = Request { - id: rand::random(), - tenant: random_tenant(), - payload: vec![RequestData::ProcSpawn { - cmd: SCRIPT_RUNNER.to_string(), - args: vec![ECHO_STDIN_TO_STDOUT_SH.to_str().unwrap().to_string()], - persist: false, - pty: None, - }], - }; - - // distant action --format json --interactive - let mut child = distant_subcommand(ctx, "action") - .args(&["--format", "json"]) - .arg("--interactive") - .spawn() - .unwrap(); - - let mut stdin = child.stdin.take().unwrap(); - let stdout = spawn_line_reader(child.stdout.take().unwrap()); - let stderr = spawn_line_reader(child.stderr.take().unwrap()); - - // Send our request as json - let req_string = format!("{}\n", serde_json::to_string(&req).unwrap()); - stdin.write_all(req_string.as_bytes()).unwrap(); - stdin.flush().unwrap(); - - // Get the indicator of a process started (first line returned can take ~7 seconds due to the - // handshake cost) - let out = - 
friendly_recv_line(&stdout, Duration::from_secs(30)).expect("Failed to get proc start"); - let res: Response = serde_json::from_str(&out).unwrap(); - let id = match &res.payload[0] { - ResponseData::ProcSpawned { id } => *id, - x => panic!("Unexpected response: {:?}", x), - }; - - // Send stdin to remote process - let req = Request { - id: rand::random(), - tenant: random_tenant(), - payload: vec![RequestData::ProcStdin { - id, - data: b"hello world\n".to_vec(), - }], - }; - let req_string = format!("{}\n", serde_json::to_string(&req).unwrap()); - stdin.write_all(req_string.as_bytes()).unwrap(); - stdin.flush().unwrap(); - - // Should receive ok message & stdout message, although these may be in different order - let (res1, res2) = next_two_msgs!(&stdout); - match (&res1.payload[0], &res2.payload[0]) { - (ResponseData::Ok, ResponseData::ProcStdout { data, .. }) => { - assert_eq!(data, b"hello world\n") - } - (ResponseData::ProcStdout { data, .. }, ResponseData::Ok) => { - assert_eq!(data, b"hello world\n") - } - x => panic!("Unexpected responses: {:?}", x), - }; - - // Kill the remote process since it only terminates when stdin closes, but we - // want to verify that we get a proc done is some manner, which won't happen - // if stdin closes as our interactive process will also close - let req = Request { - id: rand::random(), - tenant: random_tenant(), - payload: vec![RequestData::ProcKill { id }], - }; - let req_string = format!("{}\n", serde_json::to_string(&req).unwrap()); - stdin.write_all(req_string.as_bytes()).unwrap(); - stdin.flush().unwrap(); - - // Should receive ok message & process completion - let (res1, res2) = next_two_msgs!(&stdout); - match (&res1.payload[0], &res2.payload[0]) { - (ResponseData::Ok, ResponseData::ProcDone { success, .. }) => { - assert!(!success, "Process succeeded unexpectedly"); - } - (ResponseData::ProcDone { success, .. 
}, ResponseData::Ok) => { - assert!(!success, "Process succeeded unexpectedly"); - } - x => panic!("Unexpected responses: {:?}", x), - }; - - // Verify that we received nothing on stderr channel - assert!( - stderr.try_recv().is_err(), - "Unexpectedly got result on stderr channel" - ); -} - -#[rstest] -fn should_support_json_output_for_error(mut action_cmd: Command) { - let req = Request { - id: rand::random(), - tenant: random_tenant(), - payload: vec![RequestData::ProcSpawn { - cmd: DOES_NOT_EXIST_BIN.to_str().unwrap().to_string(), - args: Vec::new(), - persist: false, - pty: None, - }], - }; - - // distant action --format json --interactive - let cmd = action_cmd - .args(&["--format", "json"]) - .arg("--interactive") - .write_stdin(format!("{}\n", serde_json::to_string(&req).unwrap())) - .assert() - .success() - .stderr(""); - - let res: Response = serde_json::from_slice(&cmd.get_output().stdout).unwrap(); - assert!( - matches!( - res.payload[0], - ResponseData::Error(Error { - kind: ErrorKind::NotFound, - .. 
- }) - ), - "Unexpected response: {:?}", - res.payload[0] - ); + .stderr(regex_pred(".+")); } diff --git a/tests/cli/action/remove.rs b/tests/cli/action/remove.rs index 8ef3a08..681498c 100644 --- a/tests/cli/action/remove.rs +++ b/tests/cli/action/remove.rs @@ -1,19 +1,11 @@ -use crate::cli::{ - fixtures::*, - utils::{random_tenant, FAILURE_LINE}, -}; +use crate::cli::{fixtures::*, utils::FAILURE_LINE}; use assert_cmd::Command; use assert_fs::prelude::*; -use distant::ExitCode; -use distant_core::{ - data::{Error, ErrorKind}, - Request, RequestData, Response, ResponseData, -}; use predicates::prelude::*; use rstest::*; #[rstest] -fn should_support_removing_file(mut action_cmd: Command) { +fn should_support_removing_file(mut action_cmd: CtxCommand) { let temp = assert_fs::TempDir::new().unwrap(); let file = temp.child("file"); file.touch().unwrap(); @@ -30,7 +22,7 @@ fn should_support_removing_file(mut action_cmd: Command) { } #[rstest] -fn should_support_removing_empty_directory(mut action_cmd: Command) { +fn should_support_removing_empty_directory(mut action_cmd: CtxCommand) { let temp = assert_fs::TempDir::new().unwrap(); // Make an empty directory @@ -49,7 +41,9 @@ fn should_support_removing_empty_directory(mut action_cmd: Command) { } #[rstest] -fn should_support_removing_nonempty_directory_if_force_specified(mut action_cmd: Command) { +fn should_support_removing_nonempty_directory_if_force_specified( + mut action_cmd: CtxCommand, +) { let temp = assert_fs::TempDir::new().unwrap(); // Make a non-empty directory @@ -69,7 +63,7 @@ fn should_support_removing_nonempty_directory_if_force_specified(mut action_cmd: } #[rstest] -fn yield_an_error_when_fails(mut action_cmd: Command) { +fn yield_an_error_when_fails(mut action_cmd: CtxCommand) { let temp = assert_fs::TempDir::new().unwrap(); // Make a non-empty directory @@ -81,157 +75,10 @@ fn yield_an_error_when_fails(mut action_cmd: Command) { action_cmd .args(&["remove", dir.to_str().unwrap()]) .assert() - 
.code(ExitCode::Software.to_i32()) + .code(1) .stdout("") .stderr(FAILURE_LINE.clone()); dir.assert(predicate::path::exists()); dir.assert(predicate::path::is_dir()); } - -#[rstest] -fn should_support_json_removing_file(mut action_cmd: Command) { - let temp = assert_fs::TempDir::new().unwrap(); - - let file = temp.child("file"); - file.touch().unwrap(); - - let req = Request { - id: rand::random(), - tenant: random_tenant(), - payload: vec![RequestData::Remove { - path: file.to_path_buf(), - force: false, - }], - }; - - // distant action --format json --interactive - let cmd = action_cmd - .args(&["--format", "json"]) - .arg("--interactive") - .write_stdin(format!("{}\n", serde_json::to_string(&req).unwrap())) - .assert() - .success() - .stderr(""); - - let res: Response = serde_json::from_slice(&cmd.get_output().stdout).unwrap(); - assert_eq!(res.payload[0], ResponseData::Ok); - - file.assert(predicate::path::missing()); -} - -#[rstest] -fn should_support_json_removing_empty_directory(mut action_cmd: Command) { - let temp = assert_fs::TempDir::new().unwrap(); - - // Make an empty directory - let dir = temp.child("dir"); - dir.create_dir_all().unwrap(); - - let req = Request { - id: rand::random(), - tenant: random_tenant(), - payload: vec![RequestData::Remove { - path: dir.to_path_buf(), - force: false, - }], - }; - - // distant action --format json --interactive - let cmd = action_cmd - .args(&["--format", "json"]) - .arg("--interactive") - .write_stdin(format!("{}\n", serde_json::to_string(&req).unwrap())) - .assert() - .success() - .stderr(""); - - let res: Response = serde_json::from_slice(&cmd.get_output().stdout).unwrap(); - assert_eq!(res.payload[0], ResponseData::Ok); - - dir.assert(predicate::path::missing()); -} - -#[rstest] -fn should_support_json_removing_nonempty_directory_if_force_specified(mut action_cmd: Command) { - let temp = assert_fs::TempDir::new().unwrap(); - - // Make an empty directory - let dir = temp.child("dir"); - 
dir.create_dir_all().unwrap(); - - let req = Request { - id: rand::random(), - tenant: random_tenant(), - payload: vec![RequestData::Remove { - path: dir.to_path_buf(), - force: true, - }], - }; - - // distant action --format json --interactive - let cmd = action_cmd - .args(&["--format", "json"]) - .arg("--interactive") - .write_stdin(format!("{}\n", serde_json::to_string(&req).unwrap())) - .assert() - .success() - .stderr(""); - - let res: Response = serde_json::from_slice(&cmd.get_output().stdout).unwrap(); - assert_eq!(res.payload[0], ResponseData::Ok); - - dir.assert(predicate::path::missing()); -} - -#[rstest] -fn should_support_json_output_for_error(mut action_cmd: Command) { - let temp = assert_fs::TempDir::new().unwrap(); - - // Make a non-empty directory so we fail to remove it - let dir = temp.child("dir"); - dir.create_dir_all().unwrap(); - dir.child("file").touch().unwrap(); - - let req = Request { - id: rand::random(), - tenant: random_tenant(), - payload: vec![RequestData::Remove { - path: dir.to_path_buf(), - force: false, - }], - }; - - // distant action --format json --interactive - let cmd = action_cmd - .args(&["--format", "json"]) - .arg("--interactive") - .write_stdin(format!("{}\n", serde_json::to_string(&req).unwrap())) - .assert() - .success() - .stderr(""); - - let res: Response = serde_json::from_slice(&cmd.get_output().stdout).unwrap(); - assert!( - matches!( - res.payload[0], - // NOTE: After some refactoring, unknown error type shows up in - // our CI but not on my local machine. I can't pin it down. - // The description matches what we'd expect regarding the - // directory not being empty, so for now going to support - // either of these error kinds. - ResponseData::Error(Error { - kind: ErrorKind::Other, - .. - }) | ResponseData::Error(Error { - kind: ErrorKind::Unknown, - .. 
- }) - ), - "Unexpected response: {:?}", - res.payload[0] - ); - - dir.assert(predicate::path::exists()); - dir.assert(predicate::path::is_dir()); -} diff --git a/tests/cli/action/rename.rs b/tests/cli/action/rename.rs index 8103ff1..2f57abb 100644 --- a/tests/cli/action/rename.rs +++ b/tests/cli/action/rename.rs @@ -1,14 +1,6 @@ -use crate::cli::{ - fixtures::*, - utils::{random_tenant, FAILURE_LINE}, -}; +use crate::cli::{fixtures::*, utils::FAILURE_LINE}; use assert_cmd::Command; use assert_fs::prelude::*; -use distant::ExitCode; -use distant_core::{ - data::{Error, ErrorKind}, - Request, RequestData, Response, ResponseData, -}; use predicates::prelude::*; use rstest::*; @@ -19,7 +11,7 @@ that is a file's contents "#; #[rstest] -fn should_support_renaming_file(mut action_cmd: Command) { +fn should_support_renaming_file(mut action_cmd: CtxCommand) { let temp = assert_fs::TempDir::new().unwrap(); let src = temp.child("file"); @@ -40,7 +32,7 @@ fn should_support_renaming_file(mut action_cmd: Command) { } #[rstest] -fn should_support_renaming_nonempty_directory(mut action_cmd: Command) { +fn should_support_renaming_nonempty_directory(mut action_cmd: CtxCommand) { let temp = assert_fs::TempDir::new().unwrap(); // Make a non-empty directory @@ -68,7 +60,7 @@ fn should_support_renaming_nonempty_directory(mut action_cmd: Command) { } #[rstest] -fn yield_an_error_when_fails(mut action_cmd: Command) { +fn yield_an_error_when_fails(mut action_cmd: CtxCommand) { let temp = assert_fs::TempDir::new().unwrap(); let src = temp.child("dir"); @@ -78,127 +70,10 @@ fn yield_an_error_when_fails(mut action_cmd: Command) { action_cmd .args(&["rename", src.to_str().unwrap(), dst.to_str().unwrap()]) .assert() - .code(ExitCode::Software.to_i32()) + .code(1) .stdout("") .stderr(FAILURE_LINE.clone()); src.assert(predicate::path::missing()); dst.assert(predicate::path::missing()); } - -#[rstest] -fn should_support_json_renaming_file(mut action_cmd: Command) { - let temp = 
assert_fs::TempDir::new().unwrap(); - - let src = temp.child("file"); - src.write_str(FILE_CONTENTS).unwrap(); - - let dst = temp.child("file2"); - - let req = Request { - id: rand::random(), - tenant: random_tenant(), - payload: vec![RequestData::Rename { - src: src.to_path_buf(), - dst: dst.to_path_buf(), - }], - }; - - // distant action --format json --interactive - let cmd = action_cmd - .args(&["--format", "json"]) - .arg("--interactive") - .write_stdin(format!("{}\n", serde_json::to_string(&req).unwrap())) - .assert() - .success() - .stderr(""); - - let res: Response = serde_json::from_slice(&cmd.get_output().stdout).unwrap(); - assert_eq!(res.payload[0], ResponseData::Ok); - - src.assert(predicate::path::missing()); - dst.assert(FILE_CONTENTS); -} - -#[rstest] -fn should_support_json_renaming_nonempty_directory(mut action_cmd: Command) { - let temp = assert_fs::TempDir::new().unwrap(); - - // Make a non-empty directory - let src = temp.child("dir"); - src.create_dir_all().unwrap(); - let src_file = src.child("file"); - src_file.write_str(FILE_CONTENTS).unwrap(); - - let dst = temp.child("dir2"); - let dst_file = dst.child("file"); - - let req = Request { - id: rand::random(), - tenant: random_tenant(), - payload: vec![RequestData::Rename { - src: src.to_path_buf(), - dst: dst.to_path_buf(), - }], - }; - - // distant action --format json --interactive - let cmd = action_cmd - .args(&["--format", "json"]) - .arg("--interactive") - .write_stdin(format!("{}\n", serde_json::to_string(&req).unwrap())) - .assert() - .success() - .stderr(""); - - let res: Response = serde_json::from_slice(&cmd.get_output().stdout).unwrap(); - assert_eq!(res.payload[0], ResponseData::Ok); - - src.assert(predicate::path::missing()); - src_file.assert(predicate::path::missing()); - - dst.assert(predicate::path::is_dir()); - dst_file.assert(FILE_CONTENTS); -} - -#[rstest] -fn should_support_json_output_for_error(mut action_cmd: Command) { - let temp = assert_fs::TempDir::new().unwrap(); 
- - let src = temp.child("dir"); - let dst = temp.child("dir2"); - - let req = Request { - id: rand::random(), - tenant: random_tenant(), - payload: vec![RequestData::Rename { - src: src.to_path_buf(), - dst: dst.to_path_buf(), - }], - }; - - // distant action --format json --interactive - let cmd = action_cmd - .args(&["--format", "json"]) - .arg("--interactive") - .write_stdin(format!("{}\n", serde_json::to_string(&req).unwrap())) - .assert() - .success() - .stderr(""); - - let res: Response = serde_json::from_slice(&cmd.get_output().stdout).unwrap(); - assert!( - matches!( - res.payload[0], - ResponseData::Error(Error { - kind: ErrorKind::NotFound, - .. - }) - ), - "Unexpected response: {:?}", - res.payload[0] - ); - - src.assert(predicate::path::missing()); - dst.assert(predicate::path::missing()); -} diff --git a/tests/cli/action/system_info.rs b/tests/cli/action/system_info.rs index 2f26002..40847a0 100644 --- a/tests/cli/action/system_info.rs +++ b/tests/cli/action/system_info.rs @@ -1,11 +1,10 @@ -use crate::cli::{fixtures::*, utils::random_tenant}; +use crate::cli::fixtures::*; use assert_cmd::Command; -use distant_core::{data::SystemInfo, Request, RequestData, Response, ResponseData}; use rstest::*; use std::env; #[rstest] -fn should_output_system_info(mut action_cmd: Command) { +fn should_output_system_info(mut action_cmd: CtxCommand) { // distant action system-info action_cmd .arg("system-info") @@ -27,38 +26,3 @@ fn should_output_system_info(mut action_cmd: Command) { )) .stderr(""); } - -#[rstest] -fn should_support_json_system_info(mut action_cmd: Command) { - let req = Request { - id: rand::random(), - tenant: random_tenant(), - payload: vec![RequestData::SystemInfo {}], - }; - - // distant action --format json --interactive - let cmd = action_cmd - .args(&["--format", "json"]) - .arg("--interactive") - .write_stdin(format!("{}\n", serde_json::to_string(&req).unwrap())) - .assert() - .success() - .stderr(""); - - let res: Response = 
serde_json::from_slice(&cmd.get_output().stdout).unwrap(); - match &res.payload[0] { - ResponseData::SystemInfo(info) => { - assert_eq!( - info, - &SystemInfo { - family: env::consts::FAMILY.to_string(), - os: env::consts::OS.to_string(), - arch: env::consts::ARCH.to_string(), - current_dir: env::current_dir().unwrap_or_default(), - main_separator: std::path::MAIN_SEPARATOR, - } - ); - } - x => panic!("Unexpected response: {:?}", x), - } -} diff --git a/tests/cli/action/watch.rs b/tests/cli/action/watch.rs index e072dda..d4df696 100644 --- a/tests/cli/action/watch.rs +++ b/tests/cli/action/watch.rs @@ -1,16 +1,7 @@ -use crate::cli::{fixtures::*, utils::random_tenant}; +use crate::cli::{fixtures::*, utils::ThreadedReader}; use assert_fs::prelude::*; -use distant_core::{data::ErrorKind, Request, RequestData, Response, ResponseData}; use rstest::*; -use std::{ - io, - io::{BufRead, BufReader, Read, Write}, - path::PathBuf, - process::Command, - sync::mpsc, - thread, - time::{Duration, Instant}, -}; +use std::{process::Command, thread, time::Duration}; fn wait_a_bit() { wait_millis(250); @@ -24,165 +15,8 @@ fn wait_millis(millis: u64) { thread::sleep(Duration::from_millis(millis)); } -struct ThreadedReader { - #[allow(dead_code)] - handle: thread::JoinHandle>, - rx: mpsc::Receiver, -} - -impl ThreadedReader { - pub fn new(reader: R) -> Self - where - R: Read + Send + 'static, - { - let (tx, rx) = mpsc::channel(); - let handle = thread::spawn(move || { - let mut reader = BufReader::new(reader); - let mut line = String::new(); - loop { - match reader.read_line(&mut line) { - Ok(0) => break Ok(()), - Ok(_) => { - // Consume the line and create an empty line to - // be filled in next time - let line2 = line; - line = String::new(); - - if let Err(line) = tx.send(line2) { - return Err(io::Error::new( - io::ErrorKind::Other, - format!( - "Failed to pass along line because channel closed! 
Line: '{}'", - line.0 - ), - )); - } - } - Err(x) => return Err(x), - } - } - }); - Self { handle, rx } - } - - /// Tries to read the next line if available - pub fn try_read_line(&mut self) -> Option { - self.rx.try_recv().ok() - } - - /// Reads the next line, waiting for at minimum "timeout" - pub fn try_read_line_timeout(&mut self, timeout: Duration) -> Option { - let start_time = Instant::now(); - let mut checked_at_least_once = false; - - while !checked_at_least_once || start_time.elapsed() < timeout { - if let Some(line) = self.try_read_line() { - return Some(line); - } - - checked_at_least_once = true; - } - - None - } - - /// Reads the next line, waiting for at minimum "timeout" before panicking - pub fn read_line_timeout(&mut self, timeout: Duration) -> String { - let start_time = Instant::now(); - let mut checked_at_least_once = false; - - while !checked_at_least_once || start_time.elapsed() < timeout { - if let Some(line) = self.try_read_line() { - return line; - } - - checked_at_least_once = true; - } - - panic!("Reached timeout of {:?}", timeout); - } - - /// Reads the next line, waiting for at minimum default timeout before panicking - #[allow(dead_code)] - pub fn read_line_default_timeout(&mut self) -> String { - self.read_line_timeout(Self::default_timeout()) - } - - /// Tries to read the next response if available - /// - /// Will panic if next line is not a valid response - #[allow(dead_code)] - pub fn try_read_response(&mut self) -> Option { - self.try_read_line().map(|line| { - serde_json::from_str(&line) - .unwrap_or_else(|_| panic!("Invalid response format for {}", line)) - }) - } - - /// Reads the next response, waiting for at minimum "timeout" before panicking - pub fn read_response_timeout(&mut self, timeout: Duration) -> Response { - let line = self.read_line_timeout(timeout); - serde_json::from_str(&line) - .unwrap_or_else(|_| panic!("Invalid response format for {}", line)) - } - - /// Reads the next response, waiting for at minimum 
default timeout before panicking - pub fn read_response_default_timeout(&mut self) -> Response { - self.read_response_timeout(Self::default_timeout()) - } - - /// Creates a new duration representing a default timeout for the threaded reader - pub fn default_timeout() -> Duration { - Duration::from_millis(250) - } - - /// Waits for reader to shut down, returning the result - #[allow(dead_code)] - pub fn wait(self) -> io::Result<()> { - match self.handle.join() { - Ok(x) => x, - Err(x) => std::panic::resume_unwind(x), - } - } -} - -fn send_watch_request( - writer: &mut W, - reader: &mut ThreadedReader, - path: impl Into, - recursive: bool, -) -> Response -where - W: Write, -{ - let req = Request { - id: rand::random(), - tenant: random_tenant(), - payload: vec![RequestData::Watch { - path: path.into(), - recursive, - only: Vec::new(), - except: Vec::new(), - }], - }; - - // Send our request to the process - let msg = format!("{}\n", serde_json::to_string(&req).unwrap()); - writer - .write_all(msg.as_bytes()) - .expect("Failed to write to process"); - - // Pause a bit to ensure that the process started and processed our request - wait_a_bit(); - - // Ensure we got an acknowledgement of watching - let res = reader.read_response_default_timeout(); - assert_eq!(res.payload[0], ResponseData::Ok); - res -} - #[rstest] -fn should_support_watching_a_single_file(mut action_std_cmd: Command) { +fn should_support_watching_a_single_file(mut action_std_cmd: CtxCommand) { let temp = assert_fs::TempDir::new().unwrap(); let file = temp.child("file"); file.touch().unwrap(); @@ -209,7 +43,7 @@ fn should_support_watching_a_single_file(mut action_std_cmd: Command) { } // Close out the process and collect the output - let _ = child.kill().expect("Failed to terminate process"); + child.kill().expect("Failed to terminate process"); let output = child.wait_with_output().expect("Failed to wait for output"); let stderr_data = String::from_utf8_lossy(&output.stderr).to_string(); @@ -232,7 
+66,7 @@ fn should_support_watching_a_single_file(mut action_std_cmd: Command) { } #[rstest] -fn should_support_watching_a_directory_recursively(mut action_std_cmd: Command) { +fn should_support_watching_a_directory_recursively(mut action_std_cmd: CtxCommand) { let temp = assert_fs::TempDir::new().unwrap(); let dir = temp.child("dir"); @@ -263,7 +97,7 @@ fn should_support_watching_a_directory_recursively(mut action_std_cmd: Command) } // Close out the process and collect the output - let _ = child.kill().expect("Failed to terminate process"); + child.kill().expect("Failed to terminate process"); let output = child.wait_with_output().expect("Failed to wait for output"); let stderr_data = String::from_utf8_lossy(&output.stderr).to_string(); @@ -286,7 +120,7 @@ fn should_support_watching_a_directory_recursively(mut action_std_cmd: Command) } #[rstest] -fn yield_an_error_when_fails(mut action_std_cmd: Command) { +fn yield_an_error_when_fails(mut action_std_cmd: CtxCommand) { let temp = assert_fs::TempDir::new().unwrap(); let invalid_path = temp.to_path_buf().join("missing"); @@ -308,213 +142,3 @@ fn yield_an_error_when_fails(mut action_std_cmd: Command) { assert!(output.stdout.is_empty(), "Unexpectedly got stdout"); assert!(!output.stderr.is_empty(), "Missing stderr output"); } - -#[rstest] -fn should_support_json_watching_single_file(mut action_std_cmd: Command) { - let temp = assert_fs::TempDir::new().unwrap(); - - let file = temp.child("file"); - file.touch().unwrap(); - - // distant action --format json --interactive - let mut cmd = action_std_cmd - .args(&["--format", "json"]) - .arg("--interactive") - .spawn() - .expect("Failed to execute"); - let mut stdin = cmd.stdin.take().unwrap(); - let mut stdout = ThreadedReader::new(cmd.stdout.take().unwrap()); - - let _ = send_watch_request(&mut stdin, &mut stdout, file.to_path_buf(), false); - - // Make a change to some file - file.write_str("some text").unwrap(); - - // Pause a bit to ensure that the process detected 
the change and reported it - wait_even_longer(); - - // Get the response and verify the change - // NOTE: Don't bother checking the kind as it can vary by platform - let res = stdout.read_response_default_timeout(); - match &res.payload[0] { - ResponseData::Changed(change) => { - assert_eq!(&change.paths, &[file.to_path_buf().canonicalize().unwrap()]); - } - x => panic!("Unexpected response: {:?}", x), - } -} - -#[rstest] -fn should_support_json_watching_directory_recursively(mut action_std_cmd: Command) { - let temp = assert_fs::TempDir::new().unwrap(); - - let dir = temp.child("dir"); - dir.create_dir_all().unwrap(); - - let file = dir.child("file"); - file.touch().unwrap(); - - // distant action --format json --interactive - let mut cmd = action_std_cmd - .args(&["--format", "json"]) - .arg("--interactive") - .spawn() - .expect("Failed to execute"); - let mut stdin = cmd.stdin.take().unwrap(); - let mut stdout = ThreadedReader::new(cmd.stdout.take().unwrap()); - - let _ = send_watch_request(&mut stdin, &mut stdout, temp.to_path_buf(), true); - - // Make a change to some file - file.write_str("some text").unwrap(); - - // Pause a bit to ensure that the process detected the change and reported it - wait_even_longer(); - - // Get the response and verify the change - // NOTE: Don't bother checking the kind as it can vary by platform - let res = stdout.read_response_default_timeout(); - match &res.payload[0] { - ResponseData::Changed(change) => { - assert_eq!(&change.paths, &[file.to_path_buf().canonicalize().unwrap()]); - } - x => panic!("Unexpected response: {:?}", x), - } -} - -#[rstest] -fn should_support_json_reporting_changes_using_correct_request_id(mut action_std_cmd: Command) { - let temp = assert_fs::TempDir::new().unwrap(); - - let file1 = temp.child("file1"); - file1.touch().unwrap(); - - let file2 = temp.child("file2"); - file2.touch().unwrap(); - - // distant action --format json --interactive - let mut cmd = action_std_cmd - .args(&["--format", 
"json"]) - .arg("--interactive") - .spawn() - .expect("Failed to execute"); - let mut stdin = cmd.stdin.take().unwrap(); - let mut stdout = ThreadedReader::new(cmd.stdout.take().unwrap()); - - // Create a request to watch file1 - let file1_res = send_watch_request(&mut stdin, &mut stdout, file1.to_path_buf(), true); - - // Create a request to watch file2 - let file2_res = send_watch_request(&mut stdin, &mut stdout, file2.to_path_buf(), true); - - assert_ne!( - file1_res.origin_id, file2_res.origin_id, - "Two separate watch responses have same origin id" - ); - - // Make a change to file1 - file1.write_str("some text").unwrap(); - - // Pause a bit to ensure that the process detected the change and reported it - wait_even_longer(); - - // Get the response and verify the change - // NOTE: Don't bother checking the kind as it can vary by platform - let file1_change_res = stdout.read_response_default_timeout(); - match &file1_change_res.payload[0] { - ResponseData::Changed(change) => { - assert_eq!( - &change.paths, - &[file1.to_path_buf().canonicalize().unwrap()] - ); - } - x => panic!("Unexpected response: {:?}", x), - } - - // Process any extra messages (we might get create, content, and more) - loop { - // Sleep a bit to give time to get all changes happening - wait_a_bit(); - - if stdout.try_read_line().is_none() { - break; - } - } - - // Make a change to file2 - file2.write_str("some text").unwrap(); - - // Pause a bit to ensure that the process detected the change and reported it - wait_even_longer(); - - // Get the response and verify the change - // NOTE: Don't bother checking the kind as it can vary by platform - let file2_change_res = stdout.read_response_default_timeout(); - match &file2_change_res.payload[0] { - ResponseData::Changed(change) => { - assert_eq!( - &change.paths, - &[file2.to_path_buf().canonicalize().unwrap()] - ); - } - x => panic!("Unexpected response: {:?}", x), - } - - // Verify that the response origin ids match and are separate - 
assert_eq!( - file1_res.origin_id, file1_change_res.origin_id, - "File 1 watch origin and change origin are different" - ); - assert_eq!( - file2_res.origin_id, file2_change_res.origin_id, - "File 1 watch origin and change origin are different" - ); - assert_ne!( - file1_change_res.origin_id, file2_change_res.origin_id, - "Two separate watch change responses have same origin id" - ); -} - -#[rstest] -fn should_support_json_output_for_error(mut action_std_cmd: Command) { - let temp = assert_fs::TempDir::new().unwrap(); - let path = temp.to_path_buf().join("missing"); - - // distant action --format json --interactive - let mut cmd = action_std_cmd - .args(&["--format", "json"]) - .arg("--interactive") - .spawn() - .expect("Failed to execute"); - let mut stdin = cmd.stdin.take().unwrap(); - let mut stdout = ThreadedReader::new(cmd.stdout.take().unwrap()); - - let req = Request { - id: rand::random(), - tenant: random_tenant(), - payload: vec![RequestData::Watch { - path, - recursive: false, - only: Vec::new(), - except: Vec::new(), - }], - }; - - // Send our request to the process - let msg = format!("{}\n", serde_json::to_string(&req).unwrap()); - stdin - .write_all(msg.as_bytes()) - .expect("Failed to write to process"); - - // Pause a bit to ensure that the process started and processed our request - wait_even_longer(); - - // Ensure we got an acknowledgement of watching - let res = stdout.read_response_default_timeout(); - match &res.payload[0] { - ResponseData::Error(x) => { - assert_eq!(x.kind, ErrorKind::NotFound); - } - x => panic!("Unexpected response: {:?}", x), - } -} diff --git a/tests/cli/fixtures.rs b/tests/cli/fixtures.rs index c4e3e37..97f447c 100644 --- a/tests/cli/fixtures.rs +++ b/tests/cli/fixtures.rs @@ -1,99 +1,194 @@ -use crate::cli::utils; use assert_cmd::Command; -use distant_core::*; -use once_cell::sync::OnceCell; +use derive_more::{Deref, DerefMut}; +use once_cell::sync::Lazy; use rstest::*; use std::{ - ffi::OsStr, - net::SocketAddr, - 
process::{Command as StdCommand, Stdio}, - thread, + path::PathBuf, + process::{Child, Command as StdCommand, Stdio}, + time::Duration, }; -use tokio::{runtime::Runtime, sync::mpsc}; -const LOG_PATH: &str = "/tmp/test.distant.server.log"; +mod repl; +pub use repl::Repl; + +static ROOT_LOG_DIR: Lazy = Lazy::new(|| std::env::temp_dir().join("distant")); +static SESSION_RANDOM: Lazy = Lazy::new(rand::random); +const TIMEOUT: Duration = Duration::from_secs(3); + +#[derive(Deref, DerefMut)] +pub struct CtxCommand { + pub ctx: DistantManagerCtx, + + #[deref] + #[deref_mut] + pub cmd: T, +} /// Context for some listening distant server -pub struct DistantServerCtx { - pub addr: SocketAddr, - pub key: String, - done_tx: mpsc::Sender<()>, +pub struct DistantManagerCtx { + manager: Child, + socket_or_pipe: String, } -impl DistantServerCtx { - pub fn initialize() -> Self { - let ip_addr = "127.0.0.1".parse().unwrap(); - let (done_tx, mut done_rx) = mpsc::channel(1); - let (started_tx, mut started_rx) = mpsc::channel(1); - - // NOTE: We spawn a dedicated thread that runs our tokio runtime separately - // from our test itself because using assert_cmd blocks the thread - // and prevents our runtime from working unless we make the tokio - // test multi-threaded using `tokio::test(flavor = "multi_thread", worker_threads = 1)` - // which isn't great because we're only using async tests for our - // server itself; so, we hide that away since our test logic doesn't need to be async - thread::spawn(move || match Runtime::new() { - Ok(rt) => { - rt.block_on(async move { - let logger = utils::init_logging(LOG_PATH); - let opts = DistantServerOptions { - shutdown_after: None, - max_msg_capacity: 100, - }; - let key = SecretKey::default(); - let key_hex_string = key.unprotected_to_hex_key(); - let codec = XChaCha20Poly1305Codec::from(key); - let (_server, port) = - DistantServer::bind(ip_addr, "0".parse().unwrap(), codec, opts) - .await - .unwrap(); - - started_tx.send(Ok((port, 
key_hex_string))).await.unwrap(); - - let _ = done_rx.recv().await; - logger.flush(); - logger.shutdown(); - }); - } - Err(x) => { - started_tx.blocking_send(Err(x)).unwrap(); - } - }); - - // Extract our server startup data if we succeeded - let (port, key) = started_rx.blocking_recv().unwrap().unwrap(); +impl DistantManagerCtx { + /// Starts a manager and server so that clients can connect + pub fn start() -> Self { + eprintln!("Logging to {:?}", ROOT_LOG_DIR.as_path()); + std::fs::create_dir_all(ROOT_LOG_DIR.as_path()).expect("Failed to create root log dir"); + + // Start the manager + let mut manager_cmd = StdCommand::new(bin_path()); + manager_cmd + .arg("manager") + .arg("listen") + .arg("--log-file") + .arg(random_log_file("manager")) + .arg("--log-level") + .arg("trace"); + + let socket_or_pipe = if cfg!(windows) { + format!("distant_test_{}", rand::random::()) + } else { + std::env::temp_dir() + .join(format!("distant_test_{}.sock", rand::random::())) + .to_string_lossy() + .to_string() + }; + + if cfg!(windows) { + manager_cmd + .arg("--windows-pipe") + .arg(socket_or_pipe.as_str()); + } else { + manager_cmd + .arg("--unix-socket") + .arg(socket_or_pipe.as_str()); + } + + eprintln!("Spawning manager cmd: {manager_cmd:?}"); + let mut manager = manager_cmd.spawn().expect("Failed to spawn manager"); + std::thread::sleep(Duration::from_millis(50)); + if let Ok(Some(status)) = manager.try_wait() { + panic!("Manager exited ({}): {:?}", status.success(), status.code()); + } + + // Spawn a server locally by launching it through the manager + let mut launch_cmd = StdCommand::new(bin_path()); + launch_cmd + .arg("client") + .arg("launch") + .arg("--log-file") + .arg(random_log_file("launch")) + .arg("--log-level") + .arg("trace") + .arg("--distant") + .arg(bin_path()) + .arg("--distant-args") + .arg(format!( + "--log-file {} --log-level trace", + random_log_file("server").to_string_lossy() + )); + + if cfg!(windows) { + launch_cmd + .arg("--windows-pipe") + 
.arg(socket_or_pipe.as_str()); + } else { + launch_cmd.arg("--unix-socket").arg(socket_or_pipe.as_str()); + } + + launch_cmd.arg("manager://localhost"); + + eprintln!("Spawning launch cmd: {launch_cmd:?}"); + let output = launch_cmd.output().expect("Failed to launch server"); + if !output.status.success() { + let _ = manager.kill(); + panic!( + "Failed to launch: {}", + String::from_utf8_lossy(&output.stderr) + ); + } Self { - addr: SocketAddr::new(ip_addr, port), - key, - done_tx, + manager, + socket_or_pipe, + } + } + + pub fn shutdown(&self) { + // Send a shutdown request to the manager + let mut shutdown_cmd = StdCommand::new(bin_path()); + shutdown_cmd + .arg("manager") + .arg("shutdown") + .arg("--log-file") + .arg(random_log_file("shutdown")) + .arg("--log-level") + .arg("trace"); + + if cfg!(windows) { + shutdown_cmd + .arg("--windows-pipe") + .arg(self.socket_or_pipe.as_str()); + } else { + shutdown_cmd + .arg("--unix-socket") + .arg(self.socket_or_pipe.as_str()); + } + + eprintln!("Spawning shutdown cmd: {shutdown_cmd:?}"); + let output = shutdown_cmd.output().expect("Failed to shutdown server"); + if !output.status.success() { + panic!( + "Failed to shutdown: {}", + String::from_utf8_lossy(&output.stderr) + ); } } /// Produces a new test command that configures some distant command /// configured with an environment that can talk to a remote distant server - pub fn new_assert_cmd(&self, subcommand: impl AsRef) -> Command { - let mut cmd = Command::cargo_bin(env!("CARGO_PKG_NAME")).unwrap(); - cmd.arg(subcommand) - .args(&["--session", "environment"]) - .env("DISTANT_HOST", self.addr.ip().to_string()) - .env("DISTANT_PORT", self.addr.port().to_string()) - .env("DISTANT_KEY", self.key.as_str()); + pub fn new_assert_cmd(&self, subcommands: impl IntoIterator) -> Command { + let mut cmd = Command::cargo_bin(env!("CARGO_PKG_NAME")).expect("Failed to create cmd"); + for subcommand in subcommands { + cmd.arg(subcommand); + } + + cmd.arg("--log-file") + 
.arg(random_log_file("client")) + .arg("--log-level") + .arg("trace"); + + if cfg!(windows) { + cmd.arg("--windows-pipe").arg(self.socket_or_pipe.as_str()); + } else { + cmd.arg("--unix-socket").arg(self.socket_or_pipe.as_str()); + } + cmd } /// Configures some distant command with an environment that can talk to a /// remote distant server, spawning it as a child process - pub fn new_std_cmd(&self, subcommand: impl AsRef) -> StdCommand { - let cmd_path = assert_cmd::cargo::cargo_bin(env!("CARGO_PKG_NAME")); - let mut cmd = StdCommand::new(cmd_path); - - cmd.arg(subcommand) - .args(&["--session", "environment"]) - .env("DISTANT_HOST", self.addr.ip().to_string()) - .env("DISTANT_PORT", self.addr.port().to_string()) - .env("DISTANT_KEY", self.key.as_str()) - .stdin(Stdio::piped()) + pub fn new_std_cmd(&self, subcommands: impl IntoIterator) -> StdCommand { + let mut cmd = StdCommand::new(bin_path()); + + for subcommand in subcommands { + cmd.arg(subcommand); + } + + cmd.arg("--log-file") + .arg(random_log_file("client")) + .arg("--log-level") + .arg("trace"); + + if cfg!(windows) { + cmd.arg("--windows-pipe").arg(self.socket_or_pipe.as_str()); + } else { + cmd.arg("--unix-socket").arg(self.socket_or_pipe.as_str()); + } + + cmd.stdin(Stdio::piped()) .stdout(Stdio::piped()) .stderr(Stdio::piped()); @@ -101,31 +196,62 @@ impl DistantServerCtx { } } -impl Drop for DistantServerCtx { - /// Kills server upon drop +/// Path to distant binary +fn bin_path() -> PathBuf { + assert_cmd::cargo::cargo_bin(env!("CARGO_PKG_NAME")) +} + +fn random_log_file(prefix: &str) -> PathBuf { + ROOT_LOG_DIR.join(format!( + "{}.{}.{}.log", + prefix, + *SESSION_RANDOM, + rand::random::() + )) +} + +impl Drop for DistantManagerCtx { + /// Kills manager upon drop fn drop(&mut self) { - let _ = self.done_tx.send(()); + // Attempt to shutdown gracefully, forcing a kill otherwise + if std::panic::catch_unwind(|| self.shutdown()).is_err() { + let _ = self.manager.kill(); + let _ = self.manager.wait(); 
+ } } } #[fixture] -pub fn ctx() -> &'static DistantServerCtx { - static CTX: OnceCell = OnceCell::new(); +pub fn ctx() -> DistantManagerCtx { + DistantManagerCtx::start() +} - CTX.get_or_init(DistantServerCtx::initialize) +#[fixture] +pub fn lsp_cmd(ctx: DistantManagerCtx) -> CtxCommand { + let cmd = ctx.new_assert_cmd(vec!["client", "lsp"]); + CtxCommand { ctx, cmd } } #[fixture] -pub fn action_cmd(ctx: &'_ DistantServerCtx) -> Command { - ctx.new_assert_cmd("action") +pub fn action_cmd(ctx: DistantManagerCtx) -> CtxCommand { + let cmd = ctx.new_assert_cmd(vec!["client", "action"]); + CtxCommand { ctx, cmd } } #[fixture] -pub fn lsp_cmd(ctx: &'_ DistantServerCtx) -> Command { - ctx.new_assert_cmd("lsp") +pub fn action_std_cmd(ctx: DistantManagerCtx) -> CtxCommand { + let cmd = ctx.new_std_cmd(vec!["client", "action"]); + CtxCommand { ctx, cmd } } #[fixture] -pub fn action_std_cmd(ctx: &'_ DistantServerCtx) -> StdCommand { - ctx.new_std_cmd("action") +pub fn json_repl(ctx: DistantManagerCtx) -> CtxCommand { + let child = ctx + .new_std_cmd(vec!["client", "repl"]) + .arg("--format") + .arg("json") + .spawn() + .expect("Failed to start distant repl with json format"); + let cmd = Repl::new(child, TIMEOUT); + CtxCommand { ctx, cmd } } diff --git a/tests/cli/fixtures/repl.rs b/tests/cli/fixtures/repl.rs new file mode 100644 index 0000000..ad4b18f --- /dev/null +++ b/tests/cli/fixtures/repl.rs @@ -0,0 +1,213 @@ +use serde_json::Value; +use std::{ + io::{self, BufRead, BufReader, BufWriter, Write}, + process::Child, + thread, + time::Duration, +}; +use tokio::sync::mpsc; + +const CHANNEL_BUFFER: usize = 100; + +pub struct Repl { + child: Child, + stdin: mpsc::Sender, + stdout: mpsc::Receiver, + stderr: mpsc::Receiver, + timeout: Option, +} + +impl Repl { + /// Create a new [`Repl`] wrapping around a [`Child`] + pub fn new(mut child: Child, timeout: impl Into>) -> Self { + let mut stdin = BufWriter::new(child.stdin.take().expect("Child missing stdin")); + let mut stdout 
= BufReader::new(child.stdout.take().expect("Child missing stdout")); + let mut stderr = BufReader::new(child.stderr.take().expect("Child missing stderr")); + + let (stdin_tx, mut rx) = mpsc::channel::(CHANNEL_BUFFER); + thread::spawn(move || { + while let Some(data) = rx.blocking_recv() { + if stdin.write_all(data.as_bytes()).is_err() { + break; + } + + // NOTE: If we don't do this, the data doesn't appear to get sent even + // with a newline at the end. At least in testing thus far! + if stdin.flush().is_err() { + break; + } + } + }); + + let (tx, stdout_rx) = mpsc::channel::(CHANNEL_BUFFER); + thread::spawn(move || { + let mut line = String::new(); + while let Ok(n) = stdout.read_line(&mut line) { + if n == 0 { + break; + } + + if tx.blocking_send(line).is_err() { + break; + } + + line = String::new(); + } + }); + + let (tx, stderr_rx) = mpsc::channel::(CHANNEL_BUFFER); + thread::spawn(move || { + let mut line = String::new(); + while let Ok(n) = stderr.read_line(&mut line) { + if n == 0 { + break; + } + + if tx.blocking_send(line).is_err() { + break; + } + + line = String::new(); + } + }); + + Self { + child, + stdin: stdin_tx, + stdout: stdout_rx, + stderr: stderr_rx, + timeout: timeout.into(), + } + } + + /// Writes json to the repl over stdin and then waits for json to be received over stdout, + /// failing if either operation exceeds timeout if set or if the output to stdout is not json, + /// and returns none if stdout channel has closed + pub async fn write_and_read_json( + &mut self, + value: impl Into, + ) -> io::Result> { + self.write_json_to_stdin(value).await?; + self.read_json_from_stdout().await + } + + /// Writes a line of input to stdin, failing if exceeds timeout if set or if the stdin channel + /// has been closed. Will append a newline character (`\n`) if line does not end with one. 
+ pub async fn write_line_to_stdin(&mut self, line: impl Into) -> io::Result<()> { + let mut line = line.into(); + if !line.ends_with('\n') { + line.push('\n'); + } + + match self.timeout { + Some(timeout) => match tokio::time::timeout(timeout, self.stdin.send(line)).await { + Ok(Ok(_)) => Ok(()), + Ok(Err(x)) => Err(io::Error::new(io::ErrorKind::BrokenPipe, x)), + Err(_) => Err(io::Error::new( + io::ErrorKind::TimedOut, + self.collect_stderr(), + )), + }, + None => self + .stdin + .send(line) + .await + .map_err(|x| io::Error::new(io::ErrorKind::BrokenPipe, x)), + } + } + + /// Writes json value as a line of input to stdin, failing if exceeds timeout if set or if the + /// stdin channel has been closed. Will append a newline character (`\n`) to JSON string. + pub async fn write_json_to_stdin(&mut self, value: impl Into) -> io::Result<()> { + self.write_line_to_stdin(value.into().to_string()).await + } + + /// Tries to read a line from stdout, returning none if no stdout is available right now + /// + /// Will fail if no more stdout is available + pub fn try_read_line_from_stdout(&mut self) -> io::Result> { + match self.stdout.try_recv() { + Ok(line) => Ok(Some(line)), + Err(mpsc::error::TryRecvError::Empty) => Ok(None), + Err(mpsc::error::TryRecvError::Disconnected) => { + Err(io::Error::from(io::ErrorKind::UnexpectedEof)) + } + } + } + + /// Reads a line from stdout, failing if exceeds timeout if set, returning none if the stdout + /// channel has been closed + pub async fn read_line_from_stdout(&mut self) -> io::Result> { + match self.timeout { + Some(timeout) => match tokio::time::timeout(timeout, self.stdout.recv()).await { + Ok(x) => Ok(x), + Err(_) => Err(io::Error::new( + io::ErrorKind::TimedOut, + self.collect_stderr(), + )), + }, + None => Ok(self.stdout.recv().await), + } + } + + /// Reads a line from stdout and parses it as json, failing if unable to parse as json or the + /// timeout is reached if set, returning none if the stdout channel has been 
closed + pub async fn read_json_from_stdout(&mut self) -> io::Result> { + match self.read_line_from_stdout().await? { + Some(line) => { + let value = serde_json::from_str(&line) + .map_err(|x| io::Error::new(io::ErrorKind::InvalidData, x))?; + Ok(Some(value)) + } + None => Ok(None), + } + } + + /// Reads a line from stderr, failing if exceeds timeout if set, returning none if the stderr + /// channel has been closed + #[allow(dead_code)] + pub async fn read_line_from_stderr(&mut self) -> io::Result> { + match self.timeout { + Some(timeout) => match tokio::time::timeout(timeout, self.stderr.recv()).await { + Ok(x) => Ok(x), + Err(x) => Err(io::Error::new(io::ErrorKind::TimedOut, x)), + }, + None => Ok(self.stderr.recv().await), + } + } + + /// Tries to read a line from stderr, returning none if no stderr is available right now + /// + /// Will fail if no more stderr is available + pub fn try_read_line_from_stderr(&mut self) -> io::Result> { + match self.stderr.try_recv() { + Ok(line) => Ok(Some(line)), + Err(mpsc::error::TryRecvError::Empty) => Ok(None), + Err(mpsc::error::TryRecvError::Disconnected) => { + Err(io::Error::from(io::ErrorKind::UnexpectedEof)) + } + } + } + + /// Collects stderr into a singular string (failures will stop the collection) + pub fn collect_stderr(&mut self) -> String { + let mut stderr = String::new(); + + while let Ok(Some(line)) = self.try_read_line_from_stderr() { + stderr.push_str(&line); + } + + stderr + } + + /// Kills the repl by sending a signal to the process + pub fn kill(&mut self) -> io::Result<()> { + self.child.kill() + } +} + +impl Drop for Repl { + fn drop(&mut self) { + let _ = self.kill(); + } +} diff --git a/tests/cli/mod.rs b/tests/cli/mod.rs index 8361c26..5ec345a 100644 --- a/tests/cli/mod.rs +++ b/tests/cli/mod.rs @@ -1,3 +1,5 @@ mod action; mod fixtures; +mod repl; +mod scripts; mod utils; diff --git a/tests/cli/repl/copy.rs b/tests/cli/repl/copy.rs new file mode 100644 index 0000000..a07de97 --- /dev/null +++ 
b/tests/cli/repl/copy.rs @@ -0,0 +1,111 @@ +use crate::cli::fixtures::*; +use assert_fs::prelude::*; +use predicates::prelude::*; +use rstest::*; +use serde_json::json; + +const FILE_CONTENTS: &str = r#" +some text +on multiple lines +that is a file's contents +"#; + +#[rstest] +#[tokio::test] +async fn should_support_json_copying_file(mut json_repl: CtxCommand) { + let temp = assert_fs::TempDir::new().unwrap(); + + let src = temp.child("file"); + src.write_str(FILE_CONTENTS).unwrap(); + + let dst = temp.child("file2"); + + let id = rand::random::().to_string(); + let req = json!({ + "id": id, + "payload": { + "type": "copy", + "src": src.to_path_buf(), + "dst": dst.to_path_buf(), + }, + }); + + let res = json_repl.write_and_read_json(req).await.unwrap().unwrap(); + + assert_eq!(res["origin_id"], id); + assert_eq!( + res["payload"], + json!({ + "type": "ok" + }) + ); + + src.assert(predicate::path::exists()); + dst.assert(predicate::path::eq_file(src.path())); +} + +#[rstest] +#[tokio::test] +async fn should_support_json_copying_nonempty_directory(mut json_repl: CtxCommand) { + let temp = assert_fs::TempDir::new().unwrap(); + + // Make a non-empty directory + let src = temp.child("dir"); + src.create_dir_all().unwrap(); + let src_file = src.child("file"); + src_file.write_str(FILE_CONTENTS).unwrap(); + + let dst = temp.child("dir2"); + let dst_file = dst.child("file"); + + let id = rand::random::().to_string(); + let req = json!({ + "id": id, + "payload": { + "type": "copy", + "src": src.to_path_buf(), + "dst": dst.to_path_buf(), + }, + }); + + let res = json_repl.write_and_read_json(req).await.unwrap().unwrap(); + + assert_eq!(res["origin_id"], id); + assert_eq!( + res["payload"], + json!({ + "type": "ok" + }) + ); + + src_file.assert(predicate::path::exists()); + dst_file.assert(predicate::path::eq_file(src_file.path())); +} + +#[rstest] +#[tokio::test] +async fn should_support_json_output_for_error(mut json_repl: CtxCommand) { + let temp = 
assert_fs::TempDir::new().unwrap(); + + let src = temp.child("dir"); + let dst = temp.child("dir2"); + + let id = rand::random::().to_string(); + let req = json!({ + "id": id, + "payload": { + "type": "copy", + "src": src.to_path_buf(), + "dst": dst.to_path_buf(), + }, + }); + + let res = json_repl.write_and_read_json(req).await.unwrap().unwrap(); + + assert_eq!(res["origin_id"], id); + assert_eq!(res["payload"]["type"], "error"); + assert_eq!(res["payload"]["kind"], "not_found"); + + src.assert(predicate::path::missing()); + dst.assert(predicate::path::missing()); +} diff --git a/tests/cli/repl/dir_create.rs b/tests/cli/repl/dir_create.rs new file mode 100644 index 0000000..4432281 --- /dev/null +++ b/tests/cli/repl/dir_create.rs @@ -0,0 +1,92 @@ +use crate::cli::fixtures::*; +use assert_fs::prelude::*; +use predicates::prelude::*; +use rstest::*; +use serde_json::json; + +#[rstest] +#[tokio::test] +async fn should_support_json_output(mut json_repl: CtxCommand) { + let temp = assert_fs::TempDir::new().unwrap(); + let dir = temp.child("dir"); + + let id = rand::random::().to_string(); + let req = json!({ + "id": id, + "payload": { + "type": "dir_create", + "path": dir.to_path_buf(), + "all": false, + }, + }); + + let res = json_repl.write_and_read_json(req).await.unwrap().unwrap(); + + assert_eq!(res["origin_id"], id); + assert_eq!( + res["payload"], + json!({ + "type": "ok" + }) + ); + + dir.assert(predicate::path::exists()); + dir.assert(predicate::path::is_dir()); +} + +#[rstest] +#[tokio::test] +async fn should_support_json_creating_missing_parent_directories_if_specified( + mut json_repl: CtxCommand, +) { + let temp = assert_fs::TempDir::new().unwrap(); + let dir = temp.child("dir1").child("dir2"); + + let id = rand::random::().to_string(); + let req = json!({ + "id": id, + "payload": { + "type": "dir_create", + "path": dir.to_path_buf(), + "all": true, + }, + }); + + let res = json_repl.write_and_read_json(req).await.unwrap().unwrap(); + + 
assert_eq!(res["origin_id"], id); + assert_eq!( + res["payload"], + json!({ + "type": "ok" + }) + ); + + dir.assert(predicate::path::exists()); + dir.assert(predicate::path::is_dir()); +} + +#[rstest] +#[tokio::test] +async fn should_support_json_output_for_error(mut json_repl: CtxCommand) { + let temp = assert_fs::TempDir::new().unwrap(); + let dir = temp.child("missing-dir").child("dir"); + + let id = rand::random::().to_string(); + let req = json!({ + "id": id, + "payload": { + "type": "dir_create", + "path": dir.to_path_buf(), + "all": false, + }, + }); + + let res = json_repl.write_and_read_json(req).await.unwrap().unwrap(); + + assert_eq!(res["origin_id"], id); + assert_eq!(res["payload"]["type"], "error"); + assert_eq!(res["payload"]["kind"], "not_found"); + + dir.assert(predicate::path::missing()); +} diff --git a/tests/cli/repl/dir_read.rs b/tests/cli/repl/dir_read.rs new file mode 100644 index 0000000..bfb80ec --- /dev/null +++ b/tests/cli/repl/dir_read.rs @@ -0,0 +1,264 @@ +use crate::cli::fixtures::*; +use assert_fs::prelude::*; +use rstest::*; +use serde_json::json; +use std::path::PathBuf; + +/// Creates a directory in the form +/// +/// $TEMP/ +/// $TEMP/dir1/ +/// $TEMP/dir1/dira/ +/// $TEMP/dir1/dirb/ +/// $TEMP/dir1/dirb/file1 +/// $TEMP/dir1/file1 +/// $TEMP/dir1/file2 +/// $TEMP/dir2/ +/// $TEMP/dir2/dira/ +/// $TEMP/dir2/dirb/ +/// $TEMP/dir2/dirb/file1 +/// $TEMP/dir2/file1 +/// $TEMP/dir2/file2 +/// $TEMP/file1 +/// $TEMP/file2 +fn make_directory() -> assert_fs::TempDir { + let temp = assert_fs::TempDir::new().unwrap(); + + // $TEMP/file1 + // $TEMP/file2 + temp.child("file1").touch().unwrap(); + temp.child("file2").touch().unwrap(); + + // $TEMP/dir1/ + // $TEMP/dir1/file1 + // $TEMP/dir1/file2 + let dir1 = temp.child("dir1"); + dir1.create_dir_all().unwrap(); + dir1.child("file1").touch().unwrap(); + dir1.child("file2").touch().unwrap(); + + // $TEMP/dir1/dira/ + let dir1_dira = dir1.child("dira"); + dir1_dira.create_dir_all().unwrap(); + + 
// $TEMP/dir1/dirb/ + // $TEMP/dir1/dirb/file1 + let dir1_dirb = dir1.child("dirb"); + dir1_dirb.create_dir_all().unwrap(); + dir1_dirb.child("file1").touch().unwrap(); + + // $TEMP/dir2/ + // $TEMP/dir2/file1 + // $TEMP/dir2/file2 + let dir2 = temp.child("dir2"); + dir2.create_dir_all().unwrap(); + dir2.child("file1").touch().unwrap(); + dir2.child("file2").touch().unwrap(); + + // $TEMP/dir2/dira/ + let dir2_dira = dir2.child("dira"); + dir2_dira.create_dir_all().unwrap(); + + // $TEMP/dir2/dirb/ + // $TEMP/dir2/dirb/file1 + let dir2_dirb = dir2.child("dirb"); + dir2_dirb.create_dir_all().unwrap(); + dir2_dirb.child("file1").touch().unwrap(); + + temp +} + +#[rstest] +#[tokio::test] +async fn should_support_json_output(mut json_repl: CtxCommand) { + let temp = make_directory(); + + let id = rand::random::().to_string(); + let req = json!({ + "id": id, + "payload": { + "type": "dir_read", + "path": temp.to_path_buf(), + "depth": 1, + "absolute": false, + "canonicalize": false, + "include_root": false, + }, + }); + + let res = json_repl.write_and_read_json(req).await.unwrap().unwrap(); + + assert_eq!(res["origin_id"], id); + assert_eq!( + res["payload"], + json!({ + "type": "dir_entries", + "entries": [ + {"path": PathBuf::from("dir1"), "file_type": "dir", "depth": 1}, + {"path": PathBuf::from("dir2"), "file_type": "dir", "depth": 1}, + {"path": PathBuf::from("file1"), "file_type": "file", "depth": 1}, + {"path": PathBuf::from("file2"), "file_type": "file", "depth": 1}, + ], + "errors": [], + }) + ); +} + +#[rstest] +#[tokio::test] +async fn should_support_json_returning_absolute_paths_if_specified( + mut json_repl: CtxCommand, +) { + let temp = make_directory(); + + // NOTE: Our root path is always canonicalized, so the absolute path + // provided is our canonicalized root path prepended + let root_path = temp.to_path_buf().canonicalize().unwrap(); + + let id = rand::random::().to_string(); + let req = json!({ + "id": id, + "payload": { + "type": "dir_read", + 
"path": temp.to_path_buf(), + "depth": 1, + "absolute": true, + "canonicalize": false, + "include_root": false, + }, + }); + + let res = json_repl.write_and_read_json(req).await.unwrap().unwrap(); + + assert_eq!(res["origin_id"], id); + assert_eq!( + res["payload"], + json!({ + "type": "dir_entries", + "entries": [ + {"path": root_path.join("dir1"), "file_type": "dir", "depth": 1}, + {"path": root_path.join("dir2"), "file_type": "dir", "depth": 1}, + {"path": root_path.join("file1"), "file_type": "file", "depth": 1}, + {"path": root_path.join("file2"), "file_type": "file", "depth": 1}, + ], + "errors": [], + }) + ); +} + +#[rstest] +#[tokio::test] +async fn should_support_json_returning_all_files_and_directories_if_depth_is_0( + mut json_repl: CtxCommand, +) { + let temp = make_directory(); + + let id = rand::random::().to_string(); + let req = json!({ + "id": id, + "payload": { + "type": "dir_read", + "path": temp.to_path_buf(), + "depth": 0, + "absolute": false, + "canonicalize": false, + "include_root": false, + }, + }); + + let res = json_repl.write_and_read_json(req).await.unwrap().unwrap(); + + assert_eq!(res["origin_id"], id); + assert_eq!( + res["payload"], + json!({ + "type": "dir_entries", + "entries": [ + {"path": PathBuf::from("dir1"), "file_type": "dir", "depth": 1}, + {"path": PathBuf::from("dir1").join("dira"), "file_type": "dir", "depth": 2}, + {"path": PathBuf::from("dir1").join("dirb"), "file_type": "dir", "depth": 2}, + {"path": PathBuf::from("dir1").join("dirb").join("file1"), "file_type": "file", "depth": 3}, + {"path": PathBuf::from("dir1").join("file1"), "file_type": "file", "depth": 2}, + {"path": PathBuf::from("dir1").join("file2"), "file_type": "file", "depth": 2}, + {"path": PathBuf::from("dir2"), "file_type": "dir", "depth": 1}, + {"path": PathBuf::from("dir2").join("dira"), "file_type": "dir", "depth": 2}, + {"path": PathBuf::from("dir2").join("dirb"), "file_type": "dir", "depth": 2}, + {"path": 
PathBuf::from("dir2").join("dirb").join("file1"), "file_type": "file", "depth": 3}, + {"path": PathBuf::from("dir2").join("file1"), "file_type": "file", "depth": 2}, + {"path": PathBuf::from("dir2").join("file2"), "file_type": "file", "depth": 2}, + {"path": PathBuf::from("file1"), "file_type": "file", "depth": 1}, + {"path": PathBuf::from("file2"), "file_type": "file", "depth": 1}, + ], + "errors": [], + }) + ); +} + +#[rstest] +#[tokio::test] +async fn should_support_json_including_root_directory_if_specified( + mut json_repl: CtxCommand, +) { + let temp = make_directory(); + + // NOTE: Our root path is always canonicalized, so yielded entry + // is the canonicalized version + let root_path = temp.to_path_buf().canonicalize().unwrap(); + + let id = rand::random::().to_string(); + let req = json!({ + "id": id, + "payload": { + "type": "dir_read", + "path": temp.to_path_buf(), + "depth": 1, + "absolute": false, + "canonicalize": false, + "include_root": true, + }, + }); + + let res = json_repl.write_and_read_json(req).await.unwrap().unwrap(); + + assert_eq!(res["origin_id"], id); + assert_eq!( + res["payload"], + json!({ + "type": "dir_entries", + "entries": [ + {"path": root_path, "file_type": "dir", "depth": 0}, + {"path": PathBuf::from("dir1"), "file_type": "dir", "depth": 1}, + {"path": PathBuf::from("dir2"), "file_type": "dir", "depth": 1}, + {"path": PathBuf::from("file1"), "file_type": "file", "depth": 1}, + {"path": PathBuf::from("file2"), "file_type": "file", "depth": 1}, + ], + "errors": [], + }) + ); +} + +#[rstest] +#[tokio::test] +async fn should_support_json_output_for_error(mut json_repl: CtxCommand) { + let temp = make_directory(); + let dir = temp.child("missing-dir"); + + let id = rand::random::().to_string(); + let req = json!({ + "id": id, + "payload": { + "type": "dir_read", + "path": dir.to_path_buf(), + "depth": 1, + "absolute": false, + "canonicalize": false, + "include_root": false, + }, + }); + + let res = 
json_repl.write_and_read_json(req).await.unwrap().unwrap(); + + assert_eq!(res["origin_id"], id); + assert_eq!(res["payload"]["type"], "error"); + assert_eq!(res["payload"]["kind"], "not_found"); +} diff --git a/tests/cli/repl/exists.rs b/tests/cli/repl/exists.rs new file mode 100644 index 0000000..7c4a434 --- /dev/null +++ b/tests/cli/repl/exists.rs @@ -0,0 +1,63 @@ +use crate::cli::fixtures::*; +use assert_fs::prelude::*; +use rstest::*; +use serde_json::json; + +#[rstest] +#[tokio::test] +async fn should_support_json_true_if_exists(mut json_repl: CtxCommand) { + let temp = assert_fs::TempDir::new().unwrap(); + + // Create file + let file = temp.child("file"); + file.touch().unwrap(); + + let id = rand::random::().to_string(); + let req = json!({ + "id": id, + "payload": { + "type": "exists", + "path": file.to_path_buf(), + }, + }); + + let res = json_repl.write_and_read_json(req).await.unwrap().unwrap(); + + assert_eq!(res["origin_id"], id); + assert_eq!( + res["payload"], + json!({ + "type": "exists", + "value": true, + }) + ); +} + +#[rstest] +#[tokio::test] +async fn should_support_json_false_if_not_exists(mut json_repl: CtxCommand) { + let temp = assert_fs::TempDir::new().unwrap(); + + // Don't create file + let file = temp.child("file"); + + let id = rand::random::().to_string(); + let req = json!({ + "id": id, + "payload": { + "type": "exists", + "path": file.to_path_buf(), + }, + }); + + let res = json_repl.write_and_read_json(req).await.unwrap().unwrap(); + + assert_eq!(res["origin_id"], id); + assert_eq!( + res["payload"], + json!({ + "type": "exists", + "value": false, + }) + ); +} diff --git a/tests/cli/repl/file_append.rs b/tests/cli/repl/file_append.rs new file mode 100644 index 0000000..e73fbd5 --- /dev/null +++ b/tests/cli/repl/file_append.rs @@ -0,0 +1,75 @@ +use crate::cli::fixtures::*; +use assert_fs::prelude::*; +use rstest::*; +use serde_json::json; + +const FILE_CONTENTS: &str = r#" +some text +on multiple lines +that is a file's contents 
+"#; + +const APPENDED_FILE_CONTENTS: &str = r#" +even more +file contents +"#; + +#[rstest] +#[tokio::test] +async fn should_support_json_output(mut json_repl: CtxCommand) { + let temp = assert_fs::TempDir::new().unwrap(); + let file = temp.child("test-file"); + file.write_str(FILE_CONTENTS).unwrap(); + + let id = rand::random::().to_string(); + let req = json!({ + "id": id, + "payload": { + "type": "file_append", + "path": file.to_path_buf(), + "data": APPENDED_FILE_CONTENTS.as_bytes().to_vec(), + }, + }); + + let res = json_repl.write_and_read_json(req).await.unwrap().unwrap(); + + assert_eq!(res["origin_id"], id); + assert_eq!( + res["payload"], + json!({ + "type": "ok" + }) + ); + + // NOTE: We wait a little bit to give the OS time to fully write to file + std::thread::sleep(std::time::Duration::from_millis(100)); + + // Because we're talking to a local server, we can verify locally + file.assert(format!("{}{}", FILE_CONTENTS, APPENDED_FILE_CONTENTS)); +} + +#[rstest] +#[tokio::test] +async fn should_support_json_output_for_error(mut json_repl: CtxCommand) { + let temp = assert_fs::TempDir::new().unwrap(); + let file = temp.child("missing-dir").child("missing-file"); + + let id = rand::random::().to_string(); + let req = json!({ + "id": id, + "payload": { + "type": "file_append", + "path": file.to_path_buf(), + "data": APPENDED_FILE_CONTENTS.as_bytes().to_vec(), + }, + }); + + let res = json_repl.write_and_read_json(req).await.unwrap().unwrap(); + + assert_eq!(res["origin_id"], id); + assert_eq!(res["payload"]["type"], "error"); + assert_eq!(res["payload"]["kind"], "not_found"); + + // Because we're talking to a local server, we can verify locally + file.assert(predicates::path::missing()); +} diff --git a/tests/cli/repl/file_append_text.rs b/tests/cli/repl/file_append_text.rs new file mode 100644 index 0000000..061c989 --- /dev/null +++ b/tests/cli/repl/file_append_text.rs @@ -0,0 +1,75 @@ +use crate::cli::fixtures::*; +use assert_fs::prelude::*; +use 
rstest::*; +use serde_json::json; + +const FILE_CONTENTS: &str = r#" +some text +on multiple lines +that is a file's contents +"#; + +const APPENDED_FILE_CONTENTS: &str = r#" +even more +file contents +"#; + +#[rstest] +#[tokio::test] +async fn should_support_json_output(mut json_repl: CtxCommand) { + let temp = assert_fs::TempDir::new().unwrap(); + let file = temp.child("test-file"); + file.write_str(FILE_CONTENTS).unwrap(); + + let id = rand::random::().to_string(); + let req = json!({ + "id": id, + "payload": { + "type": "file_append_text", + "path": file.to_path_buf(), + "text": APPENDED_FILE_CONTENTS.to_string(), + }, + }); + + let res = json_repl.write_and_read_json(req).await.unwrap().unwrap(); + + assert_eq!(res["origin_id"], id); + assert_eq!( + res["payload"], + json!({ + "type": "ok" + }) + ); + + // NOTE: We wait a little bit to give the OS time to fully write to file + std::thread::sleep(std::time::Duration::from_millis(100)); + + // Because we're talking to a local server, we can verify locally + file.assert(format!("{}{}", FILE_CONTENTS, APPENDED_FILE_CONTENTS)); +} + +#[rstest] +#[tokio::test] +async fn should_support_json_output_for_error(mut json_repl: CtxCommand) { + let temp = assert_fs::TempDir::new().unwrap(); + let file = temp.child("missing-dir").child("missing-file"); + + let id = rand::random::().to_string(); + let req = json!({ + "id": id, + "payload": { + "type": "file_append_text", + "path": file.to_path_buf(), + "text": APPENDED_FILE_CONTENTS.to_string(), + }, + }); + + let res = json_repl.write_and_read_json(req).await.unwrap().unwrap(); + + assert_eq!(res["origin_id"], id); + assert_eq!(res["payload"]["type"], "error"); + assert_eq!(res["payload"]["kind"], "not_found"); + + // Because we're talking to a local server, we can verify locally + file.assert(predicates::path::missing()); +} diff --git a/tests/cli/repl/file_read.rs b/tests/cli/repl/file_read.rs new file mode 100644 index 0000000..5938998 --- /dev/null +++ 
b/tests/cli/repl/file_read.rs @@ -0,0 +1,60 @@ +use crate::cli::fixtures::*; +use assert_fs::prelude::*; +use rstest::*; +use serde_json::json; + +const FILE_CONTENTS: &str = r#" +some text +on multiple lines +that is a file's contents +"#; + +#[rstest] +#[tokio::test] +async fn should_support_json_output(mut json_repl: CtxCommand) { + let temp = assert_fs::TempDir::new().unwrap(); + let file = temp.child("test-file"); + file.write_str(FILE_CONTENTS).unwrap(); + + let id = rand::random::().to_string(); + let req = json!({ + "id": id, + "payload": { + "type": "file_read", + "path": file.to_path_buf(), + }, + }); + + let res = json_repl.write_and_read_json(req).await.unwrap().unwrap(); + + assert_eq!(res["origin_id"], id); + assert_eq!( + res["payload"], + json!({ + "type": "blob", + "data": FILE_CONTENTS.as_bytes().to_vec() + }) + ); +} + +#[rstest] +#[tokio::test] +async fn should_support_json_output_for_error(mut json_repl: CtxCommand) { + let temp = assert_fs::TempDir::new().unwrap(); + let file = temp.child("missing-file"); + + let id = rand::random::().to_string(); + let req = json!({ + "id": id, + "payload": { + "type": "file_read", + "path": file.to_path_buf(), + }, + }); + + let res = json_repl.write_and_read_json(req).await.unwrap().unwrap(); + + assert_eq!(res["origin_id"], id); + assert_eq!(res["payload"]["type"], "error"); + assert_eq!(res["payload"]["kind"], "not_found"); +} diff --git a/tests/cli/repl/file_read_text.rs b/tests/cli/repl/file_read_text.rs new file mode 100644 index 0000000..d1b8379 --- /dev/null +++ b/tests/cli/repl/file_read_text.rs @@ -0,0 +1,60 @@ +use crate::cli::fixtures::*; +use assert_fs::prelude::*; +use rstest::*; +use serde_json::json; + +const FILE_CONTENTS: &str = r#" +some text +on multiple lines +that is a file's contents +"#; + +#[rstest] +#[tokio::test] +async fn should_support_json_output(mut json_repl: CtxCommand) { + let temp = assert_fs::TempDir::new().unwrap(); + let file = temp.child("test-file"); + 
file.write_str(FILE_CONTENTS).unwrap(); + + let id = rand::random::().to_string(); + let req = json!({ + "id": id, + "payload": { + "type": "file_read_text", + "path": file.to_path_buf(), + }, + }); + + let res = json_repl.write_and_read_json(req).await.unwrap().unwrap(); + + assert_eq!(res["origin_id"], id); + assert_eq!( + res["payload"], + json!({ + "type": "text", + "data": FILE_CONTENTS.to_string() + }) + ); +} + +#[rstest] +#[tokio::test] +async fn should_support_json_output_for_error(mut json_repl: CtxCommand) { + let temp = assert_fs::TempDir::new().unwrap(); + let file = temp.child("missing-file"); + + let id = rand::random::().to_string(); + let req = json!({ + "id": id, + "payload": { + "type": "file_read_text", + "path": file.to_path_buf(), + }, + }); + + let res = json_repl.write_and_read_json(req).await.unwrap().unwrap(); + + assert_eq!(res["origin_id"], id); + assert_eq!(res["payload"]["type"], "error"); + assert_eq!(res["payload"]["kind"], "not_found"); +} diff --git a/tests/cli/repl/file_write.rs b/tests/cli/repl/file_write.rs new file mode 100644 index 0000000..21ef333 --- /dev/null +++ b/tests/cli/repl/file_write.rs @@ -0,0 +1,69 @@ +use crate::cli::fixtures::*; +use assert_fs::prelude::*; +use rstest::*; +use serde_json::json; + +const FILE_CONTENTS: &str = r#" +some text +on multiple lines +that is a file's contents +"#; + +#[rstest] +#[tokio::test] +async fn should_support_json_output(mut json_repl: CtxCommand) { + let temp = assert_fs::TempDir::new().unwrap(); + let file = temp.child("test-file"); + + let id = rand::random::().to_string(); + let req = json!({ + "id": id, + "payload": { + "type": "file_write", + "path": file.to_path_buf(), + "data": FILE_CONTENTS.as_bytes().to_vec(), + }, + }); + + let res = json_repl.write_and_read_json(req).await.unwrap().unwrap(); + + assert_eq!(res["origin_id"], id); + assert_eq!( + res["payload"], + json!({ + "type": "ok" + }) + ); + + // NOTE: We wait a little bit to give the OS time to fully write to 
file + std::thread::sleep(std::time::Duration::from_millis(100)); + + // Because we're talking to a local server, we can verify locally + file.assert(FILE_CONTENTS); +} + +#[rstest] +#[tokio::test] +async fn should_support_json_output_for_error(mut json_repl: CtxCommand) { + let temp = assert_fs::TempDir::new().unwrap(); + let file = temp.child("missing-dir").child("missing-file"); + + let id = rand::random::().to_string(); + let req = json!({ + "id": id, + "payload": { + "type": "file_write", + "path": file.to_path_buf(), + "data": FILE_CONTENTS.as_bytes().to_vec(), + }, + }); + + let res = json_repl.write_and_read_json(req).await.unwrap().unwrap(); + + assert_eq!(res["origin_id"], id); + assert_eq!(res["payload"]["type"], "error"); + assert_eq!(res["payload"]["kind"], "not_found"); + + // Because we're talking to a local server, we can verify locally + file.assert(predicates::path::missing()); +} diff --git a/tests/cli/repl/file_write_text.rs b/tests/cli/repl/file_write_text.rs new file mode 100644 index 0000000..00c50cc --- /dev/null +++ b/tests/cli/repl/file_write_text.rs @@ -0,0 +1,69 @@ +use crate::cli::fixtures::*; +use assert_fs::prelude::*; +use rstest::*; +use serde_json::json; + +const FILE_CONTENTS: &str = r#" +some text +on multiple lines +that is a file's contents +"#; + +#[rstest] +#[tokio::test] +async fn should_support_json_output(mut json_repl: CtxCommand) { + let temp = assert_fs::TempDir::new().unwrap(); + let file = temp.child("test-file"); + + let id = rand::random::().to_string(); + let req = json!({ + "id": id, + "payload": { + "type": "file_write_text", + "path": file.to_path_buf(), + "text": FILE_CONTENTS.to_string(), + }, + }); + + let res = json_repl.write_and_read_json(req).await.unwrap().unwrap(); + + assert_eq!(res["origin_id"], id); + assert_eq!( + res["payload"], + json!({ + "type": "ok" + }) + ); + + // NOTE: We wait a little bit to give the OS time to fully write to file + std::thread::sleep(std::time::Duration::from_millis(100)); 
+ + // Because we're talking to a local server, we can verify locally + file.assert(FILE_CONTENTS); +} + +#[rstest] +#[tokio::test] +async fn should_support_json_output_for_error(mut json_repl: CtxCommand) { + let temp = assert_fs::TempDir::new().unwrap(); + let file = temp.child("missing-dir").child("missing-file"); + + let id = rand::random::().to_string(); + let req = json!({ + "id": id, + "payload": { + "type": "file_write_text", + "path": file.to_path_buf(), + "text": FILE_CONTENTS.to_string(), + }, + }); + + let res = json_repl.write_and_read_json(req).await.unwrap().unwrap(); + + assert_eq!(res["origin_id"], id); + assert_eq!(res["payload"]["type"], "error"); + assert_eq!(res["payload"]["kind"], "not_found"); + + // Because we're talking to a local server, we can verify locally + file.assert(predicates::path::missing()); +} diff --git a/tests/cli/repl/metadata.rs b/tests/cli/repl/metadata.rs new file mode 100644 index 0000000..81680bc --- /dev/null +++ b/tests/cli/repl/metadata.rs @@ -0,0 +1,159 @@ +use crate::cli::fixtures::*; +use assert_fs::prelude::*; +use rstest::*; +use serde_json::{json, Value}; + +const FILE_CONTENTS: &str = r#" +some text +on multiple lines +that is a file's contents +"#; + +#[rstest] +#[tokio::test] +async fn should_support_json_metadata_for_file(mut json_repl: CtxCommand) { + let temp = assert_fs::TempDir::new().unwrap(); + + let file = temp.child("file"); + file.write_str(FILE_CONTENTS).unwrap(); + + let id = rand::random::().to_string(); + let req = json!({ + "id": id, + "payload": { + "type": "metadata", + "path": file.to_path_buf(), + "canonicalize": false, + "resolve_file_type": false, + }, + }); + + let res = json_repl.write_and_read_json(req).await.unwrap().unwrap(); + + assert_eq!(res["origin_id"], id); + assert_eq!(res["payload"]["type"], "metadata"); + assert_eq!(res["payload"]["canonicalized_path"], Value::Null); + assert_eq!(res["payload"]["file_type"], "file"); + assert_eq!(res["payload"]["readonly"], false); +} + 
+#[rstest] +#[tokio::test] +async fn should_support_json_metadata_for_directory(mut json_repl: CtxCommand) { + let temp = assert_fs::TempDir::new().unwrap(); + + let dir = temp.child("dir"); + dir.create_dir_all().unwrap(); + + let id = rand::random::().to_string(); + let req = json!({ + "id": id, + "payload": { + "type": "metadata", + "path": dir.to_path_buf(), + "canonicalize": false, + "resolve_file_type": false, + }, + }); + + let res = json_repl.write_and_read_json(req).await.unwrap().unwrap(); + + assert_eq!(res["origin_id"], id); + assert_eq!(res["payload"]["type"], "metadata"); + assert_eq!(res["payload"]["canonicalized_path"], Value::Null); + assert_eq!(res["payload"]["file_type"], "dir"); + assert_eq!(res["payload"]["readonly"], false); +} + +#[rstest] +#[tokio::test] +async fn should_support_json_metadata_for_including_a_canonicalized_path( + mut json_repl: CtxCommand, +) { + let temp = assert_fs::TempDir::new().unwrap(); + + let file = temp.child("file"); + file.touch().unwrap(); + + let link = temp.child("link"); + link.symlink_to_file(file.path()).unwrap(); + + let id = rand::random::().to_string(); + let req = json!({ + "id": id, + "payload": { + "type": "metadata", + "path": link.to_path_buf(), + "canonicalize": true, + "resolve_file_type": false, + }, + }); + + let res = json_repl.write_and_read_json(req).await.unwrap().unwrap(); + + assert_eq!(res["origin_id"], id); + assert_eq!(res["payload"]["type"], "metadata"); + assert_eq!( + res["payload"]["canonicalized_path"], + json!(file.path().canonicalize().unwrap()) + ); + assert_eq!(res["payload"]["file_type"], "symlink"); + assert_eq!(res["payload"]["readonly"], false); +} + +#[rstest] +#[tokio::test] +async fn should_support_json_metadata_for_resolving_file_type_of_symlink( + mut json_repl: CtxCommand, +) { + let temp = assert_fs::TempDir::new().unwrap(); + + let file = temp.child("file"); + file.touch().unwrap(); + + let link = temp.child("link"); + link.symlink_to_file(file.path()).unwrap(); + + 
let id = rand::random::().to_string(); + let req = json!({ + "id": id, + "payload": { + "type": "metadata", + "path": link.to_path_buf(), + "canonicalize": true, + "resolve_file_type": true, + }, + }); + + let res = json_repl.write_and_read_json(req).await.unwrap().unwrap(); + + assert_eq!(res["origin_id"], id); + assert_eq!(res["payload"]["type"], "metadata"); + assert_eq!(res["payload"]["file_type"], "file"); +} + +#[rstest] +#[tokio::test] +async fn should_support_json_output_for_error(mut json_repl: CtxCommand) { + let temp = assert_fs::TempDir::new().unwrap(); + + // Don't create file + let file = temp.child("file"); + + let id = rand::random::().to_string(); + let req = json!({ + "id": id, + "payload": { + "type": "metadata", + "path": file.to_path_buf(), + "canonicalize": false, + "resolve_file_type": false, + }, + }); + + let res = json_repl.write_and_read_json(req).await.unwrap().unwrap(); + + assert_eq!(res["origin_id"], id); + assert_eq!(res["payload"]["type"], "error"); + assert_eq!(res["payload"]["kind"], "not_found"); +} diff --git a/tests/cli/repl/mod.rs b/tests/cli/repl/mod.rs new file mode 100644 index 0000000..da3800f --- /dev/null +++ b/tests/cli/repl/mod.rs @@ -0,0 +1,16 @@ +mod copy; +mod dir_create; +mod dir_read; +mod exists; +mod file_append; +mod file_append_text; +mod file_read; +mod file_read_text; +mod file_write; +mod file_write_text; +mod metadata; +mod proc_spawn; +mod remove; +mod rename; +mod system_info; +mod watch; diff --git a/tests/cli/repl/proc_spawn.rs b/tests/cli/repl/proc_spawn.rs new file mode 100644 index 0000000..7d7a240 --- /dev/null +++ b/tests/cli/repl/proc_spawn.rs @@ -0,0 +1,273 @@ +use crate::cli::{fixtures::*, scripts::*}; +use rstest::*; +use serde_json::json; + +fn make_cmd(args: Vec<&str>) -> String { + format!( + r#"{} {} {}"#, + *SCRIPT_RUNNER, + *SCRIPT_RUNNER_ARG, + args.join(" ") + ) +} + +fn trim(arr: &Vec) -> &[serde_json::Value] { + let arr = arr.as_slice(); + + if arr.is_empty() { + return arr; + } + + 
let mut start = 0; + let mut end = arr.len() - 1; + let mut i = start; + + fn is_whitespace(value: &serde_json::Value) -> bool { + value == b' ' || value == b'\t' || value == b'\r' || value == b'\n' + } + + // Trim from front + while start < end { + if is_whitespace(&arr[i]) { + start = i + 1; + i += 1; + } else { + break; + } + } + + i = end; + + // Trim from back + while end > start { + if is_whitespace(&arr[i]) { + end = i - 1; + i -= 1; + } else { + break; + } + } + + &arr[start..=end] +} + +// Trim and compare value to string +fn check_value_as_str(value: &serde_json::Value, other: &str) { + let arr = trim(value.as_array().expect("value should be a byte array")); + + if arr != other.as_bytes() { + let s = arr + .iter() + .map(|value| { + (value + .as_u64() + .expect("Invalid array value, expected number") as u8) as char + }) + .collect::(); + panic!("Expected '{other}', but got '{s}'"); + } +} + +#[rstest] +#[tokio::test] +async fn should_support_json_to_execute_program_and_return_exit_status( + mut json_repl: CtxCommand, +) { + let cmd = make_cmd(vec![ECHO_ARGS_TO_STDOUT.to_str().unwrap()]); + + let id = rand::random::().to_string(); + let req = json!({ + "id": id, + "payload": { + "type": "proc_spawn", + "cmd": cmd, + "persist": false, + "pty": null, + }, + }); + + let res = json_repl.write_and_read_json(req).await.unwrap().unwrap(); + + assert_eq!(res["origin_id"], id); + assert_eq!(res["payload"]["type"], "proc_spawned"); +} + +#[rstest] +#[tokio::test] +async fn should_support_json_to_capture_and_print_stdout(mut json_repl: CtxCommand) { + let cmd = make_cmd(vec![ECHO_ARGS_TO_STDOUT.to_str().unwrap(), "some output"]); + + // Spawn the process + let origin_id = rand::random::().to_string(); + let req = json!({ + "id": origin_id, + "payload": { + "type": "proc_spawn", + "cmd": cmd, + "persist": false, + "pty": null, + }, + }); + + let res = json_repl.write_and_read_json(req).await.unwrap().unwrap(); + + assert_eq!(res["origin_id"], origin_id); + 
assert_eq!(res["payload"]["type"], "proc_spawned"); + + // Wait for output to show up (for stderr) + let res = json_repl.read_json_from_stdout().await.unwrap().unwrap(); + + assert_eq!(res["origin_id"], origin_id); + assert_eq!(res["payload"]["type"], "proc_stdout"); + check_value_as_str(&res["payload"]["data"], "some output"); + + // Now we wait for the process to complete + let res = json_repl.read_json_from_stdout().await.unwrap().unwrap(); + + assert_eq!(res["origin_id"], origin_id); + assert_eq!(res["payload"]["type"], "proc_done"); + assert_eq!(res["payload"]["success"], true); +} + +#[rstest] +#[tokio::test] +async fn should_support_json_to_capture_and_print_stderr(mut json_repl: CtxCommand) { + let cmd = make_cmd(vec![ECHO_ARGS_TO_STDERR.to_str().unwrap(), "some output"]); + + // Spawn the process + let origin_id = rand::random::().to_string(); + let req = json!({ + "id": origin_id, + "payload": { + "type": "proc_spawn", + "cmd": cmd, + "persist": false, + "pty": null, + }, + }); + + let res = json_repl.write_and_read_json(req).await.unwrap().unwrap(); + + assert_eq!(res["origin_id"], origin_id); + assert_eq!(res["payload"]["type"], "proc_spawned"); + + // Wait for output to show up (for stderr) + let res = json_repl.read_json_from_stdout().await.unwrap().unwrap(); + + assert_eq!(res["origin_id"], origin_id); + assert_eq!(res["payload"]["type"], "proc_stderr"); + check_value_as_str(&res["payload"]["data"], "some output"); + + // Now we wait for the process to complete + let res = json_repl.read_json_from_stdout().await.unwrap().unwrap(); + + assert_eq!(res["origin_id"], origin_id); + assert_eq!(res["payload"]["type"], "proc_done"); + assert_eq!(res["payload"]["success"], true); +} + +#[rstest] +#[tokio::test] +async fn should_support_json_to_forward_stdin_to_remote_process(mut json_repl: CtxCommand) { + let cmd = make_cmd(vec![ECHO_STDIN_TO_STDOUT.to_str().unwrap()]); + + // Spawn the process + let origin_id = rand::random::().to_string(); + let req = 
json!({ + "id": origin_id, + "payload": { + "type": "proc_spawn", + "cmd": cmd, + "persist": false, + "pty": null, + }, + }); + + let res = json_repl.write_and_read_json(req).await.unwrap().unwrap(); + + assert_eq!(res["origin_id"], origin_id); + assert_eq!(res["payload"]["type"], "proc_spawned"); + + // Write output to stdin of process to trigger getting it back as stdout + let proc_id = res["payload"]["id"] + .as_u64() + .expect("Invalid proc id value"); + let id = rand::random::().to_string(); + let req = json!({ + "id": id, + "payload": { + "type": "proc_stdin", + "id": proc_id, + "data": b"some output\n", + }, + }); + + let res = json_repl.write_and_read_json(req).await.unwrap().unwrap(); + + assert_eq!(res["origin_id"], id); + assert_eq!(res["payload"]["type"], "ok"); + + let res = json_repl.read_json_from_stdout().await.unwrap().unwrap(); + + assert_eq!(res["origin_id"], origin_id); + assert_eq!(res["payload"]["type"], "proc_stdout"); + check_value_as_str(&res["payload"]["data"], "some output"); + + // Now kill the process and wait for it to complete + let id = rand::random::().to_string(); + let res_1 = json_repl + .write_and_read_json(json!({ + "id": id, + "payload": { + "type": "proc_kill", + "id": proc_id, + }, + + })) + .await + .unwrap() + .unwrap(); + let res_2 = json_repl.read_json_from_stdout().await.unwrap().unwrap(); + + // The order of responses may be different (kill could come before ok), so we need + // to check that we get one of each type + let got_ok = res_1["payload"]["type"] == "ok" || res_2["payload"]["type"] == "ok"; + let got_done = + res_1["payload"]["type"] == "proc_done" || res_2["payload"]["type"] == "proc_done"; + + if res_1["payload"]["type"] == "ok" { + assert_eq!(res_1["origin_id"], id); + } else if res_1["payload"]["type"] == "proc_done" { + assert_eq!(res_1["origin_id"], origin_id); + } + + if res_2["payload"]["type"] == "ok" { + assert_eq!(res_2["origin_id"], id); + } else if res_2["payload"]["type"] == "proc_done" { + 
assert_eq!(res_2["origin_id"], origin_id); + } + + assert!(got_ok, "Did not receive ok from proc_kill"); + assert!(got_done, "Did not receive proc_done from killed process"); +} + +#[rstest] +#[tokio::test] +async fn should_support_json_output_for_error(mut json_repl: CtxCommand) { + let id = rand::random::().to_string(); + let req = json!({ + "id": id, + "payload": { + "type": "proc_spawn", + "cmd": DOES_NOT_EXIST_BIN.to_str().unwrap().to_string(), + "persist": false, + "pty": null, + }, + }); + + let res = json_repl.write_and_read_json(req).await.unwrap().unwrap(); + + assert_eq!(res["origin_id"], id); + assert_eq!(res["payload"]["type"], "error"); + assert_eq!(res["payload"]["kind"], "not_found"); +} diff --git a/tests/cli/repl/remove.rs b/tests/cli/repl/remove.rs new file mode 100644 index 0000000..9989622 --- /dev/null +++ b/tests/cli/repl/remove.rs @@ -0,0 +1,135 @@ +use crate::cli::fixtures::*; +use assert_fs::prelude::*; +use predicates::prelude::*; +use rstest::*; +use serde_json::json; + +#[rstest] +#[tokio::test] +async fn should_support_json_removing_file(mut json_repl: CtxCommand) { + let temp = assert_fs::TempDir::new().unwrap(); + + let file = temp.child("file"); + file.touch().unwrap(); + + let id = rand::random::().to_string(); + let req = json!({ + "id": id, + "payload": { + "type": "remove", + "path": file.to_path_buf(), + "force": false, + }, + }); + + let res = json_repl.write_and_read_json(req).await.unwrap().unwrap(); + + assert_eq!(res["origin_id"], id); + assert_eq!( + res["payload"], + json!({ + "type": "ok" + }) + ); + + file.assert(predicate::path::missing()); +} + +#[rstest] +#[tokio::test] +async fn should_support_json_removing_empty_directory(mut json_repl: CtxCommand) { + let temp = assert_fs::TempDir::new().unwrap(); + + // Make an empty directory + let dir = temp.child("dir"); + dir.create_dir_all().unwrap(); + + let id = rand::random::().to_string(); + let req = json!({ + "id": id, + "payload": { + "type": "remove", + "path": 
dir.to_path_buf(), + "force": false, + }, + }); + + let res = json_repl.write_and_read_json(req).await.unwrap().unwrap(); + + assert_eq!(res["origin_id"], id); + assert_eq!( + res["payload"], + json!({ + "type": "ok" + }) + ); + + dir.assert(predicate::path::missing()); +} + +#[rstest] +#[tokio::test] +async fn should_support_json_removing_nonempty_directory_if_force_specified( + mut json_repl: CtxCommand, +) { + let temp = assert_fs::TempDir::new().unwrap(); + + // Make an empty directory + let dir = temp.child("dir"); + dir.create_dir_all().unwrap(); + + let id = rand::random::().to_string(); + let req = json!({ + "id": id, + "payload": { + "type": "remove", + "path": dir.to_path_buf(), + "force": true, + }, + }); + + let res = json_repl.write_and_read_json(req).await.unwrap().unwrap(); + + assert_eq!(res["origin_id"], id); + assert_eq!( + res["payload"], + json!({ + "type": "ok" + }) + ); + + dir.assert(predicate::path::missing()); +} + +#[rstest] +#[tokio::test] +async fn should_support_json_output_for_error(mut json_repl: CtxCommand) { + let temp = assert_fs::TempDir::new().unwrap(); + + // Make a non-empty directory so we fail to remove it + let dir = temp.child("dir"); + dir.create_dir_all().unwrap(); + dir.child("file").touch().unwrap(); + + let id = rand::random::().to_string(); + let req = json!({ + "id": id, + "payload": { + "type": "remove", + "path": dir.to_path_buf(), + "force": false, + }, + }); + + let res = json_repl.write_and_read_json(req).await.unwrap().unwrap(); + + assert_eq!(res["origin_id"], id); + assert_eq!(res["payload"]["type"], "error"); + assert!( + res["payload"]["kind"] == "other" || res["payload"]["kind"] == "unknown", + "error kind was neither other or unknown" + ); + + dir.assert(predicate::path::exists()); + dir.assert(predicate::path::is_dir()); +} diff --git a/tests/cli/repl/rename.rs b/tests/cli/repl/rename.rs new file mode 100644 index 0000000..e1c50f9 --- /dev/null +++ b/tests/cli/repl/rename.rs @@ -0,0 +1,114 @@ +use 
crate::cli::fixtures::*; +use assert_fs::prelude::*; +use predicates::prelude::*; +use rstest::*; +use serde_json::json; + +const FILE_CONTENTS: &str = r#" +some text +on multiple lines +that is a file's contents +"#; + +#[rstest] +#[tokio::test] +async fn should_support_json_renaming_file(mut json_repl: CtxCommand) { + let temp = assert_fs::TempDir::new().unwrap(); + + let src = temp.child("file"); + src.write_str(FILE_CONTENTS).unwrap(); + + let dst = temp.child("file2"); + + let id = rand::random::().to_string(); + let req = json!({ + "id": id, + "payload": { + "type": "rename", + "src": src.to_path_buf(), + "dst": dst.to_path_buf(), + }, + }); + + let res = json_repl.write_and_read_json(req).await.unwrap().unwrap(); + + assert_eq!(res["origin_id"], id); + assert_eq!( + res["payload"], + json!({ + "type": "ok" + }) + ); + + src.assert(predicate::path::missing()); + dst.assert(FILE_CONTENTS); +} + +#[rstest] +#[tokio::test] +async fn should_support_json_renaming_nonempty_directory(mut json_repl: CtxCommand) { + let temp = assert_fs::TempDir::new().unwrap(); + + // Make a non-empty directory + let src = temp.child("dir"); + src.create_dir_all().unwrap(); + let src_file = src.child("file"); + src_file.write_str(FILE_CONTENTS).unwrap(); + + let dst = temp.child("dir2"); + let dst_file = dst.child("file"); + + let id = rand::random::().to_string(); + let req = json!({ + "id": id, + "payload": { + "type": "rename", + "src": src.to_path_buf(), + "dst": dst.to_path_buf(), + }, + }); + + let res = json_repl.write_and_read_json(req).await.unwrap().unwrap(); + + assert_eq!(res["origin_id"], id); + assert_eq!( + res["payload"], + json!({ + "type": "ok" + }) + ); + + src.assert(predicate::path::missing()); + src_file.assert(predicate::path::missing()); + + dst.assert(predicate::path::is_dir()); + dst_file.assert(FILE_CONTENTS); +} + +#[rstest] +#[tokio::test] +async fn should_support_json_output_for_error(mut json_repl: CtxCommand) { + let temp = 
assert_fs::TempDir::new().unwrap(); + + let src = temp.child("dir"); + let dst = temp.child("dir2"); + + let id = rand::random::().to_string(); + let req = json!({ + "id": id, + "payload": { + "type": "rename", + "src": src.to_path_buf(), + "dst": dst.to_path_buf(), + }, + }); + + let res = json_repl.write_and_read_json(req).await.unwrap().unwrap(); + + assert_eq!(res["origin_id"], id); + assert_eq!(res["payload"]["type"], "error"); + assert_eq!(res["payload"]["kind"], "not_found"); + + src.assert(predicate::path::missing()); + dst.assert(predicate::path::missing()); +} diff --git a/tests/cli/repl/system_info.rs b/tests/cli/repl/system_info.rs new file mode 100644 index 0000000..9a4a1e2 --- /dev/null +++ b/tests/cli/repl/system_info.rs @@ -0,0 +1,29 @@ +use crate::cli::fixtures::*; +use rstest::*; +use serde_json::json; +use std::env; + +#[rstest] +#[tokio::test] +async fn should_support_json_system_info(mut json_repl: CtxCommand) { + let id = rand::random::().to_string(); + let req = json!({ + "id": id, + "payload": { "type": "system_info" }, + }); + + let res = json_repl.write_and_read_json(req).await.unwrap().unwrap(); + + assert_eq!(res["origin_id"], id); + assert_eq!( + res["payload"], + json!({ + "type": "system_info", + "family": env::consts::FAMILY.to_string(), + "os": env::consts::OS.to_string(), + "arch": env::consts::ARCH.to_string(), + "current_dir": env::current_dir().unwrap_or_default(), + "main_separator": std::path::MAIN_SEPARATOR, + }) + ); +} diff --git a/tests/cli/repl/watch.rs b/tests/cli/repl/watch.rs new file mode 100644 index 0000000..ce4e7df --- /dev/null +++ b/tests/cli/repl/watch.rs @@ -0,0 +1,256 @@ +use crate::cli::fixtures::*; +use assert_fs::prelude::*; +use rstest::*; +use serde_json::json; +use std::time::Duration; + +async fn wait_a_bit() { + wait_millis(250).await; +} + +async fn wait_even_longer() { + wait_millis(500).await; +} + +async fn wait_millis(millis: u64) { + tokio::time::sleep(Duration::from_millis(millis)).await; +} + 
+#[rstest] +#[tokio::test] +async fn should_support_json_watching_single_file(mut json_repl: CtxCommand) { + let temp = assert_fs::TempDir::new().unwrap(); + + let file = temp.child("file"); + file.touch().unwrap(); + + // Watch single file for changes + let id = rand::random::().to_string(); + let req = json!({ + "id": id, + "payload": { + "type": "watch", + "path": file.to_path_buf(), + }, + }); + + let res = json_repl.write_and_read_json(req).await.unwrap().unwrap(); + + assert_eq!(res["origin_id"], id); + assert_eq!( + res["payload"], + json!({ + "type": "ok" + }) + ); + + // Make a change to some file + file.write_str("some text").unwrap(); + + // Pause a bit to ensure that the process detected the change and reported it + wait_even_longer().await; + + // Get the response and verify the change + // NOTE: Don't bother checking the kind as it can vary by platform + let res = json_repl.read_json_from_stdout().await.unwrap().unwrap(); + + assert_eq!(res["origin_id"], id); + assert_eq!(res["payload"]["type"], "changed"); + assert_eq!( + res["payload"]["paths"], + json!([file.to_path_buf().canonicalize().unwrap()]) + ); +} + +#[rstest] +#[tokio::test] +async fn should_support_json_watching_directory_recursively(mut json_repl: CtxCommand) { + let temp = assert_fs::TempDir::new().unwrap(); + + let dir = temp.child("dir"); + dir.create_dir_all().unwrap(); + + let file = dir.child("file"); + file.touch().unwrap(); + + // Watch a directory recursively for changes + let id = rand::random::().to_string(); + let req = json!({ + "id": id, + "payload": { + "type": "watch", + "path": temp.to_path_buf(), + "recursive": true, + }, + }); + + let res = json_repl.write_and_read_json(req).await.unwrap().unwrap(); + + assert_eq!(res["origin_id"], id); + assert_eq!( + res["payload"], + json!({ + "type": "ok" + }) + ); + + // Make a change to some file + file.write_str("some text").unwrap(); + + // Windows reports a directory change first + if cfg!(windows) { + // Pause a bit to ensure 
that the process detected the change and reported it + wait_even_longer().await; + + // Get the response and verify the change + // NOTE: Don't bother checking the kind as it can vary by platform + let res = json_repl.read_json_from_stdout().await.unwrap().unwrap(); + + assert_eq!(res["origin_id"], id); + assert_eq!(res["payload"]["type"], "changed"); + assert_eq!( + res["payload"]["paths"], + json!([dir.to_path_buf().canonicalize().unwrap()]) + ); + } + + // Pause a bit to ensure that the process detected the change and reported it + wait_even_longer().await; + + // Get the response and verify the change + // NOTE: Don't bother checking the kind as it can vary by platform + let res = json_repl.read_json_from_stdout().await.unwrap().unwrap(); + + assert_eq!(res["origin_id"], id); + assert_eq!(res["payload"]["type"], "changed"); + assert_eq!( + res["payload"]["paths"], + json!([file.to_path_buf().canonicalize().unwrap()]) + ); +} + +#[rstest] +#[tokio::test] +async fn should_support_json_reporting_changes_using_correct_request_id( + mut json_repl: CtxCommand, +) { + let temp = assert_fs::TempDir::new().unwrap(); + + let file1 = temp.child("file1"); + file1.touch().unwrap(); + + let file2 = temp.child("file2"); + file2.touch().unwrap(); + + // Watch file1 for changes + let id_1 = rand::random::().to_string(); + let req = json!({ + "id": id_1, + "payload": { + "type": "watch", + "path": file1.to_path_buf(), + }, + }); + + let res = json_repl.write_and_read_json(req).await.unwrap().unwrap(); + + assert_eq!(res["origin_id"], id_1); + assert_eq!( + res["payload"], + json!({ + "type": "ok" + }) + ); + + // Watch file2 for changes + let id_2 = rand::random::().to_string(); + let req = json!({ + "id": id_2, + "payload": { + "type": "watch", + "path": file2.to_path_buf(), + }, + }); + + let res = json_repl.write_and_read_json(req).await.unwrap().unwrap(); + + assert_eq!(res["origin_id"], id_2); + assert_eq!( + res["payload"], + json!({ + "type": "ok" + }) + ); + + // Make a 
change to file1 + file1.write_str("some text").unwrap(); + + // Pause a bit to ensure that the process detected the change and reported it + wait_even_longer().await; + + // Get the response and verify the change + // NOTE: Don't bother checking the kind as it can vary by platform + let res = json_repl.read_json_from_stdout().await.unwrap().unwrap(); + + assert_eq!(res["origin_id"], id_1); + assert_eq!(res["payload"]["type"], "changed"); + assert_eq!( + res["payload"]["paths"], + json!([file1.to_path_buf().canonicalize().unwrap()]) + ); + + // Process any extra messages (we might get create, content, and more) + loop { + // Sleep a bit to give time to get all changes happening + wait_a_bit().await; + + if json_repl + .try_read_line_from_stdout() + .expect("stdout closed unexpectedly") + .is_none() + { + break; + } + } + + // Make a change to file2 + file2.write_str("some text").unwrap(); + + // Pause a bit to ensure that the process detected the change and reported it + wait_even_longer().await; + + // Get the response and verify the change + // NOTE: Don't bother checking the kind as it can vary by platform + let res = json_repl.read_json_from_stdout().await.unwrap().unwrap(); + + assert_eq!(res["origin_id"], id_2); + assert_eq!(res["payload"]["type"], "changed"); + assert_eq!( + res["payload"]["paths"], + json!([file2.to_path_buf().canonicalize().unwrap()]) + ); +} + +#[rstest] +#[tokio::test] +async fn should_support_json_output_for_error(mut json_repl: CtxCommand) { + let temp = assert_fs::TempDir::new().unwrap(); + let path = temp.to_path_buf().join("missing"); + + // Watch a missing path for changes + let id = rand::random::().to_string(); + let req = json!({ + "id": id, + "payload": { + "type": "watch", + "path": path, + }, + }); + + let res = json_repl.write_and_read_json(req).await.unwrap().unwrap(); + + // Ensure we got an acknowledgement of watching that failed + assert_eq!(res["origin_id"], id); + assert_eq!(res["payload"]["type"], "error"); + 
assert_eq!(res["payload"]["kind"], "not_found"); +} diff --git a/tests/cli/scripts.rs b/tests/cli/scripts.rs new file mode 100644 index 0000000..3108f79 --- /dev/null +++ b/tests/cli/scripts.rs @@ -0,0 +1,121 @@ +use assert_fs::prelude::*; +use once_cell::sync::Lazy; + +static TEMP_SCRIPT_DIR: Lazy = Lazy::new(|| assert_fs::TempDir::new().unwrap()); + +pub static SCRIPT_RUNNER: Lazy = + Lazy::new(|| String::from(if cfg!(windows) { "cmd.exe" } else { "bash" })); + +pub static SCRIPT_RUNNER_ARG: Lazy = + Lazy::new(|| String::from(if cfg!(windows) { "/c" } else { "" })); + +#[cfg(unix)] +pub static ECHO_ARGS_TO_STDOUT: Lazy = Lazy::new(|| { + let script = TEMP_SCRIPT_DIR.child("echo_args_to_stdout.sh"); + script + .write_str(indoc::indoc!( + r#" + #/usr/bin/env bash + printf "%s" "$*" + "# + )) + .unwrap(); + script +}); + +#[cfg(windows)] +pub static ECHO_ARGS_TO_STDOUT: Lazy = Lazy::new(|| { + let script = TEMP_SCRIPT_DIR.child("echo_args_to_stdout.cmd"); + script + .write_str(indoc::indoc!( + r#" + @echo off + echo %* + "# + )) + .unwrap(); + script +}); + +#[cfg(unix)] +pub static ECHO_ARGS_TO_STDERR: Lazy = Lazy::new(|| { + let script = TEMP_SCRIPT_DIR.child("echo_args_to_stderr.sh"); + script + .write_str(indoc::indoc!( + r#" + #/usr/bin/env bash + printf "%s" "$*" 1>&2 + "# + )) + .unwrap(); + script +}); + +#[cfg(windows)] +pub static ECHO_ARGS_TO_STDERR: Lazy = Lazy::new(|| { + let script = TEMP_SCRIPT_DIR.child("echo_args_to_stderr.cmd"); + script + .write_str(indoc::indoc!( + r#" + @echo off + echo %* 1>&2 + "# + )) + .unwrap(); + script +}); + +#[cfg(unix)] +pub static ECHO_STDIN_TO_STDOUT: Lazy = Lazy::new(|| { + let script = TEMP_SCRIPT_DIR.child("echo_stdin_to_stdout.sh"); + script + .write_str(indoc::indoc!( + r#" + #/usr/bin/env bash + while IFS= read; do echo "$REPLY"; done + "# + )) + .unwrap(); + script +}); + +#[cfg(windows)] +pub static ECHO_STDIN_TO_STDOUT: Lazy = Lazy::new(|| { + let script = TEMP_SCRIPT_DIR.child("echo_stdin_to_stdout.cmd"); + 
script + .write_str(indoc::indoc!( + r#" + @echo off + setlocal DisableDelayedExpansion + + set /p input= + echo %input% + "# + )) + .unwrap(); + script +}); + +#[cfg(unix)] +pub static EXIT_CODE: Lazy = Lazy::new(|| { + let script = TEMP_SCRIPT_DIR.child("exit_code.sh"); + script + .write_str(indoc::indoc!( + r#" + #!/usr/bin/env bash + exit "$1" + "# + )) + .unwrap(); + script +}); + +#[cfg(windows)] +pub static EXIT_CODE: Lazy = Lazy::new(|| { + let script = TEMP_SCRIPT_DIR.child("exit_code.cmd"); + script.write_str(r"EXIT /B %1").unwrap(); + script +}); + +pub static DOES_NOT_EXIST_BIN: Lazy = + Lazy::new(|| TEMP_SCRIPT_DIR.child("does_not_exist_bin")); diff --git a/tests/cli/utils.rs b/tests/cli/utils.rs index 99f9c92..dbe038c 100644 --- a/tests/cli/utils.rs +++ b/tests/cli/utils.rs @@ -1,145 +1,14 @@ -use crate::cli::fixtures::DistantServerCtx; use once_cell::sync::Lazy; use predicates::prelude::*; -use std::{ - env, io, - path::PathBuf, - process::{Command, Stdio}, - sync::mpsc, - time::{Duration, Instant}, -}; + +mod reader; +pub use reader::ThreadedReader; /// Predicate that checks for a single line that is a failure pub static FAILURE_LINE: Lazy = - Lazy::new(|| regex_pred(r"^Failed \(.*\): '.*'\.\n$")); + Lazy::new(|| regex_pred(r"^.*\n$")); /// Produces a regex predicate using the given string pub fn regex_pred(s: &str) -> predicates::str::RegexPredicate { predicate::str::is_match(s).unwrap() } - -/// Creates a random tenant name -pub fn random_tenant() -> String { - format!("test-tenant-{}", rand::random::()) -} - -/// Initializes logging (should only call once) -pub fn init_logging(path: impl Into) -> flexi_logger::LoggerHandle { - use flexi_logger::{FileSpec, LevelFilter, LogSpecification, Logger}; - let modules = &["distant", "distant_core", "distant_ssh2"]; - - // Disable logging for everything but our binary, which is based on verbosity - let mut builder = LogSpecification::builder(); - builder.default(LevelFilter::Off); - - // For each module, 
configure logging - for module in modules { - builder.module(module, LevelFilter::Trace); - } - - // Create our logger, but don't initialize yet - let logger = Logger::with(builder.build()) - .format_for_files(flexi_logger::opt_format) - .log_to_file(FileSpec::try_from(path).expect("Failed to create log file spec")); - - logger.start().expect("Failed to initialize logger") -} - -pub fn friendly_recv_line( - receiver: &mpsc::Receiver, - duration: Duration, -) -> io::Result { - let start = Instant::now(); - loop { - if let Ok(line) = receiver.try_recv() { - break Ok(line); - } - - if start.elapsed() > duration { - return Err(io::Error::new( - io::ErrorKind::TimedOut, - format!("Failed to receive line after {}s", duration.as_secs_f32()), - )); - } - - std::thread::yield_now(); - } -} - -pub fn spawn_line_reader(mut reader: T) -> mpsc::Receiver -where - T: std::io::Read + Send + 'static, -{ - let (tx, rx) = mpsc::channel(); - std::thread::spawn(move || { - let mut buf = String::new(); - let mut tmp = [0; 1024]; - while let Ok(n) = reader.read(&mut tmp) { - if n == 0 { - break; - } - - let data = String::from_utf8_lossy(&tmp[..n]); - buf.push_str(data.as_ref()); - - // Send all complete lines - if let Some(idx) = buf.rfind('\n') { - let remaining = buf.split_off(idx + 1); - for line in buf.lines() { - tx.send(line.to_string()).unwrap(); - } - buf = remaining; - } - } - - // If something is remaining at end, also send it - if !buf.is_empty() { - tx.send(buf).unwrap(); - } - }); - - rx -} - -/// Produces a new command for distant using the given subcommand -pub fn distant_subcommand(ctx: &DistantServerCtx, subcommand: &str) -> Command { - let mut cmd = Command::new(cargo_bin(env!("CARGO_PKG_NAME"))); - cmd.arg(subcommand) - .args(&["--session", "environment"]) - .env("DISTANT_HOST", ctx.addr.ip().to_string()) - .env("DISTANT_PORT", ctx.addr.port().to_string()) - .env("DISTANT_KEY", ctx.key.as_str()) - .stdin(Stdio::piped()) - .stdout(Stdio::piped()) - 
.stderr(Stdio::piped()); - cmd -} - -/// Look up the path to a cargo-built binary within an integration test -/// -/// Taken from https://github.com/assert-rs/assert_cmd/blob/036ef47b8ad170dcaf4eaf4412c0b48fd5b6ef6e/src/cargo.rs#L199 -fn cargo_bin>(name: S) -> PathBuf { - cargo_bin_str(name.as_ref()) -} - -fn cargo_bin_str(name: &str) -> PathBuf { - let env_var = format!("CARGO_BIN_EXE_{}", name); - std::env::var_os(&env_var) - .map(|p| p.into()) - .unwrap_or_else(|| target_dir().join(format!("{}{}", name, env::consts::EXE_SUFFIX))) -} - -// Adapted from -// https://github.com/rust-lang/cargo/blob/485670b3983b52289a2f353d589c57fae2f60f82/tests/testsuite/support/mod.rs#L507 -fn target_dir() -> PathBuf { - env::current_exe() - .ok() - .map(|mut path| { - path.pop(); - if path.ends_with("deps") { - path.pop(); - } - path - }) - .unwrap() -} diff --git a/tests/cli/utils/reader.rs b/tests/cli/utils/reader.rs new file mode 100644 index 0000000..b7ab94c --- /dev/null +++ b/tests/cli/utils/reader.rs @@ -0,0 +1,106 @@ +use std::{ + io, + io::{BufRead, BufReader, Read}, + sync::mpsc, + thread, + time::{Duration, Instant}, +}; + +pub struct ThreadedReader { + #[allow(dead_code)] + handle: thread::JoinHandle>, + rx: mpsc::Receiver, +} + +impl ThreadedReader { + pub fn new(reader: R) -> Self + where + R: Read + Send + 'static, + { + let (tx, rx) = mpsc::channel(); + let handle = thread::spawn(move || { + let mut reader = BufReader::new(reader); + let mut line = String::new(); + loop { + match reader.read_line(&mut line) { + Ok(0) => break Ok(()), + Ok(_) => { + // Consume the line and create an empty line to + // be filled in next time + let line2 = line; + line = String::new(); + + if let Err(line) = tx.send(line2) { + return Err(io::Error::new( + io::ErrorKind::Other, + format!( + "Failed to pass along line because channel closed! 
Line: '{}'", + line.0 + ), + )); + } + } + Err(x) => return Err(x), + } + } + }); + Self { handle, rx } + } + + /// Tries to read the next line if available + pub fn try_read_line(&mut self) -> Option { + self.rx.try_recv().ok() + } + + /// Reads the next line, waiting for at minimum "timeout" + pub fn try_read_line_timeout(&mut self, timeout: Duration) -> Option { + let start_time = Instant::now(); + let mut checked_at_least_once = false; + + while !checked_at_least_once || start_time.elapsed() < timeout { + if let Some(line) = self.try_read_line() { + return Some(line); + } + + checked_at_least_once = true; + } + + None + } + + /// Reads the next line, waiting for at minimum "timeout" before panicking + pub fn read_line_timeout(&mut self, timeout: Duration) -> String { + let start_time = Instant::now(); + let mut checked_at_least_once = false; + + while !checked_at_least_once || start_time.elapsed() < timeout { + if let Some(line) = self.try_read_line() { + return line; + } + + checked_at_least_once = true; + } + + panic!("Reached timeout of {:?}", timeout); + } + + /// Reads the next line, waiting for at minimum default timeout before panicking + #[allow(dead_code)] + pub fn read_line_default_timeout(&mut self) -> String { + self.read_line_timeout(Self::default_timeout()) + } + + /// Creates a new duration representing a default timeout for the threaded reader + pub fn default_timeout() -> Duration { + Duration::from_millis(250) + } + + /// Waits for reader to shut down, returning the result + #[allow(dead_code)] + pub fn wait(self) -> io::Result<()> { + match self.handle.join() { + Ok(x) => x, + Err(x) => std::panic::resume_unwind(x), + } + } +}