Refactor to use distant manager (#112)

pull/118/head
Chip Senkbeil 2 years ago committed by GitHub
parent a2e17ba35b
commit ea2e128bc4
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -0,0 +1,4 @@
[profile.ci]
fail-fast = false
retries = 2
slow-timeout = { period = "60s", terminate-after = 3 }

@ -1,49 +0,0 @@
name: CI (All)
on:
push:
branches:
- master
pull_request:
branches:
- master
jobs:
clippy:
name: Lint with clippy
runs-on: ubuntu-latest
env:
RUSTFLAGS: -Dwarnings
steps:
- uses: actions/checkout@v2
- name: Install Rust (clippy)
uses: actions-rs/toolchain@v1
with:
profile: minimal
toolchain: stable
components: clippy
- uses: Swatinem/rust-cache@v1
- name: Check Cargo availability
run: cargo --version
- name: distant-core (all features)
run: cargo clippy -p distant-core --all-targets --verbose --all-features
- name: distant-ssh2 (all features)
run: cargo clippy -p distant-ssh2 --all-targets --verbose --all-features
- name: distant (all features)
run: cargo clippy --all-targets --verbose --all-features
rustfmt:
name: Verify code formatting
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v2
- name: Install Rust (rustfmt)
uses: actions-rs/toolchain@v1
with:
profile: minimal
toolchain: stable
components: rustfmt
- uses: Swatinem/rust-cache@v1
- name: Check Cargo availability
run: cargo --version
- run: cargo fmt --all -- --check

@ -1,46 +0,0 @@
name: CI (Linux)
on:
push:
branches:
- master
pull_request:
branches:
- master
jobs:
tests:
name: "Test Rust ${{ matrix.rust }} on ${{ matrix.os }}"
runs-on: ${{ matrix.os }}
strategy:
fail-fast: false
matrix:
include:
- { rust: stable, os: ubuntu-latest }
- { rust: 1.51.0, os: ubuntu-latest }
steps:
- uses: actions/checkout@v2
- name: Install Rust ${{ matrix.rust }}
uses: actions-rs/toolchain@v1
with:
profile: minimal
toolchain: ${{ matrix.rust }}
- uses: Swatinem/rust-cache@v1
- name: Check Cargo availability
run: cargo --version
- name: Run core tests (default features)
run: cargo test --release --verbose -p distant-core
- name: Run core tests (all features)
run: cargo test --release --verbose --all-features -p distant-core
- name: Ensure /run/sshd exists on Unix
run: mkdir -p /run/sshd
- name: Run ssh2 tests (default features)
run: cargo test --release --verbose -p distant-ssh2
- name: Run ssh2 tests (all features)
run: cargo test --release --verbose --all-features -p distant-ssh2
- name: Run CLI tests
run: cargo test --release --verbose
shell: bash
- name: Run CLI tests (no default features)
run: cargo test --release --verbose --no-default-features
shell: bash

@ -1,43 +0,0 @@
name: CI (MacOS)
on:
push:
branches:
- master
pull_request:
branches:
- master
jobs:
tests:
name: "Test Rust ${{ matrix.rust }} on ${{ matrix.os }}"
runs-on: ${{ matrix.os }}
strategy:
fail-fast: false
matrix:
include:
- { rust: stable, os: macos-latest }
steps:
- uses: actions/checkout@v2
- name: Install Rust ${{ matrix.rust }}
uses: actions-rs/toolchain@v1
with:
profile: minimal
toolchain: ${{ matrix.rust }}
- uses: Swatinem/rust-cache@v1
- name: Check Cargo availability
run: cargo --version
- name: Run core tests (default features)
run: cargo test --release --verbose -p distant-core
- name: Run core tests (all features)
run: cargo test --release --verbose --all-features -p distant-core
- name: Run ssh2 tests (default features)
run: cargo test --release --verbose -p distant-ssh2
- name: Run ssh2 tests (all features)
run: cargo test --release --verbose --all-features -p distant-ssh2
- name: Run CLI tests
run: cargo test --release --verbose
shell: bash
- name: Run CLI tests (no default features)
run: cargo test --release --verbose --no-default-features
shell: bash

@ -1,45 +0,0 @@
name: CI (Windows)
on:
push:
branches:
- master
pull_request:
branches:
- master
jobs:
tests:
name: "Test Rust ${{ matrix.rust }} on ${{ matrix.os }}"
runs-on: ${{ matrix.os }}
strategy:
fail-fast: false
matrix:
include:
- { rust: stable, os: windows-latest, target: x86_64-pc-windows-msvc }
steps:
- uses: actions/checkout@v2
- name: Install Rust ${{ matrix.rust }}
uses: actions-rs/toolchain@v1
with:
profile: minimal
toolchain: ${{ matrix.rust }}
target: ${{ matrix.target }}
- uses: Swatinem/rust-cache@v1
- name: Check Cargo availability
run: cargo --version
- uses: Vampire/setup-wsl@v1
- name: Run distant-core tests (default features)
run: cargo test --release --verbose -p distant-core
- name: Run distant-core tests (all features)
run: cargo test --release --verbose --all-features -p distant-core
- name: Build distant-ssh2 (default features)
run: cargo build --release --verbose -p distant-ssh2
- name: Build distant-ssh2 (all features)
run: cargo build --release --verbose --all-features -p distant-ssh2
- name: Build CLI
run: cargo build --release --verbose
shell: bash
- name: Build CLI (no default features)
run: cargo build --release --verbose --no-default-features
shell: bash

@ -0,0 +1,191 @@
name: CI
on:
push:
branches:
- master
pull_request:
branches:
- master
jobs:
clippy:
name: "Lint with clippy (${{ matrix.os }})"
runs-on: ${{ matrix.os }}
strategy:
fail-fast: false
matrix:
include:
- { os: windows-latest }
- { os: ubuntu-latest }
env:
RUSTFLAGS: -Dwarnings
steps:
- name: Ensure windows git checkout keeps \n line ending
run: |
git config --system core.autocrlf false
git config --system core.eol lf
if: matrix.os == 'windows-latest'
- uses: actions/checkout@v2
- name: Install Rust (clippy)
uses: actions-rs/toolchain@v1
with:
profile: minimal
toolchain: stable
components: clippy
- uses: Swatinem/rust-cache@v1
- name: Check Cargo availability
run: cargo --version
- name: distant-core (all features)
run: cargo clippy -p distant-core --all-targets --verbose --all-features
- name: distant-ssh2 (all features)
run: cargo clippy -p distant-ssh2 --all-targets --verbose --all-features
- name: distant (all features)
run: cargo clippy --all-targets --verbose --all-features
rustfmt:
name: "Verify code formatting (${{ matrix.os }})"
runs-on: ${{ matrix.os }}
strategy:
fail-fast: false
matrix:
include:
- { os: windows-latest }
- { os: ubuntu-latest }
steps:
- name: Ensure windows git checkout keeps \n line ending
run: |
git config --system core.autocrlf false
git config --system core.eol lf
if: matrix.os == 'windows-latest'
- uses: actions/checkout@v2
- name: Install Rust (rustfmt)
uses: actions-rs/toolchain@v1
with:
profile: minimal
toolchain: stable
components: rustfmt
- uses: Swatinem/rust-cache@v1
- name: Check Cargo availability
run: cargo --version
- run: cargo fmt --all -- --check
tests:
name: "Test Rust ${{ matrix.rust }} on ${{ matrix.os }}"
runs-on: ${{ matrix.os }}
strategy:
fail-fast: false
matrix:
include:
- { rust: stable, os: windows-latest, target: x86_64-pc-windows-msvc }
- { rust: stable, os: macos-latest }
- { rust: stable, os: ubuntu-latest }
- { rust: 1.61.0, os: ubuntu-latest }
steps:
- uses: actions/checkout@v2
- name: Install Rust ${{ matrix.rust }}
uses: actions-rs/toolchain@v1
with:
profile: minimal
toolchain: ${{ matrix.rust }}
target: ${{ matrix.target }}
- uses: taiki-e/install-action@v1
with:
tool: cargo-nextest
- uses: Swatinem/rust-cache@v1
- name: Check Cargo availability
run: cargo --version
- name: Install OpenSSH on Windows
run: |
# From https://gist.github.com/inevity/a0d7b9f1c5ba5a813917b92736122797
Add-Type -AssemblyName System.IO.Compression.FileSystem
function Unzip
{
param([string]$zipfile, [string]$outpath)
[System.IO.Compression.ZipFile]::ExtractToDirectory($zipfile, $outpath)
}
$url = 'https://github.com/PowerShell/Win32-OpenSSH/releases/latest/'
$request = [System.Net.WebRequest]::Create($url)
$request.AllowAutoRedirect=$false
$response=$request.GetResponse()
$file = $([String]$response.GetResponseHeader("Location")).Replace('tag','download') + '/OpenSSH-Win64.zip'
$client = new-object system.Net.Webclient;
$client.DownloadFile($file ,"c:\\OpenSSH-Win64.zip")
Unzip "c:\\OpenSSH-Win64.zip" "C:\Program Files\"
mv "c:\\Program Files\OpenSSH-Win64" "C:\Program Files\OpenSSH\"
powershell.exe -ExecutionPolicy Bypass -File "C:\Program Files\OpenSSH\install-sshd.ps1"
New-NetFirewallRule -Name sshd -DisplayName 'OpenSSH Server (sshd)' -Enabled True -Direction Inbound -Protocol TCP -Action Allow -LocalPort 22,49152-65535
net start sshd
Set-Service sshd -StartupType Automatic
Set-Service ssh-agent -StartupType Automatic
cd "C:\Program Files\OpenSSH\"
Powershell.exe -ExecutionPolicy Bypass -Command '. .\FixHostFilePermissions.ps1 -Confirm:$false'
$registryPath = "HKLM:\SOFTWARE\OpenSSH\"
$Name = "DefaultShell"
$value = "C:\windows\System32\WindowsPowerShell\v1.0\powershell.exe"
IF(!(Test-Path $registryPath))
{
New-Item -Path $registryPath -Force
New-ItemProperty -Path $registryPath -Name $name -Value $value -PropertyType String -Force
} ELSE {
New-ItemProperty -Path $registryPath -Name $name -Value $value -PropertyType String -Force
}
shell: pwsh
if: matrix.os == 'windows-latest'
- name: Run net tests (default features)
run: cargo nextest run --profile ci --release --verbose -p distant-net
- name: Run core tests (default features)
run: cargo nextest run --profile ci --release --verbose -p distant-core
- name: Run core tests (all features)
run: cargo nextest run --profile ci --release --verbose --all-features -p distant-core
- name: Ensure /run/sshd exists on Unix
run: mkdir -p /run/sshd
if: matrix.os == 'ubuntu-latest'
- name: Run ssh2 client tests (default features)
run: cargo nextest run --profile ci --release --verbose -p distant-ssh2 ssh2::client
- name: Run ssh2 client tests (all features)
run: cargo nextest run --profile ci --release --verbose --all-features -p distant-ssh2 ssh2::client
- name: Run CLI tests
run: cargo nextest run --profile ci --release --verbose
- name: Run CLI tests (no default features)
run: cargo nextest run --profile ci --release --verbose --no-default-features
ssh-launch-tests:
name: "Test ssh launch using Rust ${{ matrix.rust }} on ${{ matrix.os }}"
runs-on: ${{ matrix.os }}
strategy:
fail-fast: false
matrix:
include:
- { rust: stable, os: macos-latest }
- { rust: stable, os: ubuntu-latest }
- { rust: 1.61.0, os: ubuntu-latest }
steps:
- uses: actions/checkout@v2
- name: Install Rust ${{ matrix.rust }}
uses: actions-rs/toolchain@v1
with:
profile: minimal
toolchain: ${{ matrix.rust }}
- uses: taiki-e/install-action@v1
with:
tool: cargo-nextest
- uses: Swatinem/rust-cache@v1
- name: Check Cargo availability
run: cargo --version
- name: Install distant cli for use in launch tests
run: |
cargo install --path .
echo "DISTANT_PATH=$HOME/.cargo/bin/distant" >> $GITHUB_ENV
- name: Run ssh2 launch tests (default features)
run: cargo nextest run --profile ci --release --verbose -p distant-ssh2 ssh2::launched
- name: Run ssh2 launch tests (all features)
run: cargo nextest run --profile ci --release --verbose --all-features -p distant-ssh2 ssh2::launched

@ -5,6 +5,11 @@
* `make` - needed to build openssl via vendor feature
* `perl` - needed to build openssl via vendor feature
### FreeBSD
* `gmake` - needed to build openssl via vendor feature (`pkg install gmake`)
## Using Cargo
A debug build is straightforward:

@ -7,10 +7,39 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
## [Unreleased]
### Added
- `distant manager` subcommand
- `distant manager service` subcommand contains functionality to install,
start, stop, and uninstall the manager as a service on various operating
systems
- `distant manager info` will print information about an active connection
- `distant manager list` will print information about all connections
- `distant generate` subcommand
- `distant generate schema` will produce JSON schema for server
request/response
- `distant generate completion` will produce completion file for a specific
shell
### Changed
- `distant launch` is now `distant client launch`
- `distant action` is now `distant client action`
- `distant shell` is now `distant client shell`
- `distant listen` is now `distant server listen`
- Minimum supported rust version (MSRV) has been bumped to `1.61.0`
### Fixed
- Shell no longer has issues with fancier command prompts and other
terminal-oriented printing as `TERM=xterm-256color` is now set by default
### Removed
- Networking directly from distant client to distant server. All connections
are now facilitated by the manager interface with client -> manager -> server
- Environment variable output as part of launch is now gone as the connection
is now being managed, so there is no need to export session information
## [0.16.4] - 2022-06-01
### Added
- Dockerfile using Alpine linux with a basic install of distant, tagged as
@ -134,6 +163,12 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
pending upon full channel and no longer locks up
- stdout, stderr, and stdin of `RemoteProcess` no longer cause deadlock
[Unreleased]: https://github.com/chipsenkbeil/distant/compare/v0.15.1...HEAD
[Unreleased]: https://github.com/chipsenkbeil/distant/compare/v0.17.0...HEAD
[0.17.0]: https://github.com/chipsenkbeil/distant/compare/v0.16.4...v0.17.0
[0.16.4]: https://github.com/chipsenkbeil/distant/compare/v0.16.3...v0.16.4
[0.16.3]: https://github.com/chipsenkbeil/distant/compare/v0.16.2...v0.16.3
[0.16.2]: https://github.com/chipsenkbeil/distant/compare/v0.16.1...v0.16.2
[0.16.1]: https://github.com/chipsenkbeil/distant/compare/v0.16.0...v0.16.1
[0.16.0]: https://github.com/chipsenkbeil/distant/compare/v0.15.1...v0.16.0
[0.15.1]: https://github.com/chipsenkbeil/distant/compare/v0.15.0...v0.15.1
[0.15.0]: https://github.com/chipsenkbeil/distant/compare/v0.14.0...v0.15.0

816
Cargo.lock generated

File diff suppressed because it is too large Load Diff

@ -3,16 +3,16 @@ name = "distant"
description = "Operate on a remote computer through file and process manipulation"
categories = ["command-line-utilities"]
keywords = ["cli"]
version = "0.16.4"
version = "0.17.0"
authors = ["Chip Senkbeil <chip@senkbeil.org>"]
edition = "2018"
edition = "2021"
homepage = "https://github.com/chipsenkbeil/distant"
repository = "https://github.com/chipsenkbeil/distant"
readme = "README.md"
license = "MIT OR Apache-2.0"
[workspace]
members = ["distant-core", "distant-ssh2"]
members = ["distant-core", "distant-net", "distant-ssh2"]
[profile.release]
opt-level = 'z'
@ -25,31 +25,44 @@ libssh = ["distant-ssh2/libssh"]
ssh2 = ["distant-ssh2/ssh2"]
[dependencies]
anyhow = "1.0.58"
async-trait = "0.1.56"
clap = { version = "3.2.5", features = ["derive"] }
clap_complete = "3.2.3"
config = { version = "0.13.1", default-features = false, features = ["toml"] }
derive_more = { version = "0.99.17", default-features = false, features = ["display", "from", "error", "is_variant"] }
distant-core = { version = "=0.16.4", path = "distant-core", features = ["structopt"] }
dialoguer = { version = "0.10.1", default-features = false }
distant-core = { version = "=0.17.0", path = "distant-core", features = ["clap", "schemars"] }
directories = "4.0.1"
flexi_logger = "0.18.1"
indoc = "1.0.6"
log = "0.4.17"
once_cell = "1.12.0"
rand = { version = "0.8.5", features = ["getrandom"] }
rpassword = "6.0.1"
serde = { version = "1.0.137", features = ["derive"] }
serde_json = "1.0.81"
structopt = "0.3.26"
strum = { version = "0.21.0", features = ["derive"] }
sysinfo = "0.23.13"
shell-words = "1.0"
service-manager = { version = "0.1.2", features = ["clap", "serde"] }
tabled = "0.7.0"
tokio = { version = "1.19.0", features = ["full"] }
toml_edit = { version = "0.14.4", features = ["serde"] }
terminal_size = "0.1.17"
termwiz = "0.15.0"
uriparse = "0.6.4"
which = "4.2.5"
winsplit = "0.1"
whoami = "1.2.1"
# Optional native SSH functionality
distant-ssh2 = { version = "=0.16.4", path = "distant-ssh2", default-features = false, features = ["serde"], optional = true }
distant-ssh2 = { version = "=0.17.0", path = "distant-ssh2", default-features = false, features = ["serde"], optional = true }
[target.'cfg(unix)'.dependencies]
fork = "0.1.19"
# [target.'cfg(windows)'.dependencies]
# sysinfo = "0.23.2"
[target.'cfg(windows)'.dependencies]
sysinfo = "0.24.7"
windows-service = "0.5"
[dev-dependencies]
assert_cmd = "2.0.4"

@ -1,26 +1,15 @@
# distant - remotely edit files and run programs
[![Crates.io][distant_crates_img]][distant_crates_lnk] [![Docs.rs][distant_doc_img]][distant_doc_lnk] [![RustC 1.51+][distant_rustc_img]][distant_rustc_lnk]
| Operating System | Status |
| ---------------- | ------------------------------------------------------------------ |
| MacOS (x86, ARM) | [![MacOS CI][distant_ci_macos_img]][distant_ci_macos_lnk] |
| Linux (x86) | [![Linux CI][distant_ci_linux_img]][distant_ci_linux_lnk] |
| Windows (x86) | [![Windows CI][distant_ci_windows_img]][distant_ci_windows_lnk] |
[![Crates.io][distant_crates_img]][distant_crates_lnk] [![Docs.rs][distant_doc_img]][distant_doc_lnk] [![CI][distant_ci_img]][distant_ci_lnk] [![RustC 1.61+][distant_rustc_img]][distant_rustc_lnk]
[distant_crates_img]: https://img.shields.io/crates/v/distant.svg
[distant_crates_lnk]: https://crates.io/crates/distant
[distant_doc_img]: https://docs.rs/distant/badge.svg
[distant_doc_lnk]: https://docs.rs/distant
[distant_rustc_img]: https://img.shields.io/badge/distant-rustc_1.51+-lightgray.svg
[distant_rustc_lnk]: https://blog.rust-lang.org/2021/03/25/Rust-1.51.0.html
[distant_ci_macos_img]: https://github.com/chipsenkbeil/distant/actions/workflows/ci-macos.yml/badge.svg
[distant_ci_macos_lnk]: https://github.com/chipsenkbeil/distant/actions/workflows/ci-macos.yml
[distant_ci_linux_img]: https://github.com/chipsenkbeil/distant/actions/workflows/ci-linux.yml/badge.svg
[distant_ci_linux_lnk]: https://github.com/chipsenkbeil/distant/actions/workflows/ci-linux.yml
[distant_ci_windows_img]: https://github.com/chipsenkbeil/distant/actions/workflows/ci-windows.yml/badge.svg
[distant_ci_windows_lnk]: https://github.com/chipsenkbeil/distant/actions/workflows/ci-windows.yml
[distant_ci_img]: https://github.com/chipsenkbeil/distant/actions/workflows/ci.yml/badge.svg
[distant_ci_lnk]: https://github.com/chipsenkbeil/distant/actions/workflows/ci.yml
[distant_rustc_img]: https://img.shields.io/badge/distant-rustc_1.61+-lightgray.svg
[distant_rustc_lnk]: https://blog.rust-lang.org/2022/05/19/Rust-1.61.0.html
🚧 **(Alpha stage software) This program is in rapid development and may break or change frequently!** 🚧
@ -31,7 +20,7 @@ a command to start a server and configure the local client to be able to
talk to the server.
- Asynchronous in nature, powered by [`tokio`](https://tokio.rs/)
- Data is serialized to send across the wire via [`CBOR`](https://cbor.io/)
- Data is serialized to send across the wire via [`msgpack`](https://msgpack.org/)
- Encryption & authentication are handled via
[XChaCha20Poly1305](https://tools.ietf.org/html/rfc8439) for an authenticated
encryption scheme via
@ -59,29 +48,83 @@ cargo install distant
Alternatively, you can clone this repository and build from source following
the [build guide](./BUILDING.md).
## Examples
## Example
### Starting the manager
In order to facilitate communication between a client and server, you first
need to start the manager. This can be done in one of two ways:
1. Leverage the `service` functionality to spawn the manager using one of the
following supported service management platforms:
- [`sc.exe`](https://docs.microsoft.com/en-us/previous-versions/windows/it-pro/windows-server-2012-r2-and-2012/cc754599(v=ws.11)) for use with [Windows Service](https://en.wikipedia.org/wiki/Windows_service) (Windows)
- [Launchd](https://en.wikipedia.org/wiki/Launchd) (MacOS)
- [systemd](https://en.wikipedia.org/wiki/Systemd) (Linux)
- [OpenRC](https://en.wikipedia.org/wiki/OpenRC) (Linux)
- [rc.d](https://en.wikipedia.org/wiki/Init#Research_Unix-style/BSD-style) (FreeBSD)
2. Run the manager manually by using the `listen` subcommand
Launch a remote instance of `distant`. Calling `launch` will do the following:
#### Service management
1. Ssh into the specified host (in the below example, `my.example.com`)
2. Execute `distant listen --host ssh` on the remote machine
3. Receive on the local machine the credentials needed to connect to the server
4. Depending on the options specified, print/store/use the session settings so
future calls to `distant action` can connect
```bash
# If you want to install the manager as a service, you can use the service
# interface available directly from the CLI
#
# By default, this will install a system-level service, which means that you
# will need elevated permissions to both install AND communicate with the
# manager
distant manager service install
# If you want to maintain a user-level manager service, you can include the
# --user flag. Note that this is only supported on MacOS (via launchd) and
# Linux (via systemd)
distant manager service install --user
# ........
# Once you have installed the service, you will normally need to start it
# manually or restart your machine to trigger startup on boot
distant manager service start # --user if you are working with user-level
```
#### Manual start
```bash
# Connects to my.example.com on port 22 via SSH to start a new session
# and print out information to configure your system to talk to it
distant launch my.example.com
# NOTE: If you are using sh, bash, or zsh, you can automatically set the
# appropriate environment variables using the following
eval "$(distant launch my.example.com)"
# After the session is established, you can perform different operations
# on the remote machine via `distant action {command} [args]`
distant action copy path/to/file new/path/to/file
distant action spawn -- echo 'Hello, this is from the other side'
# If you choose to run the manager without a service management platform, you
# can either run the manager in the foreground or provide --daemon to spawn and
# detach the manager
# Run in the foreground
distant manager listen
# Detach the manager where it will not terminate even if the parent exits
distant manager listen --daemon
```
### Interacting with a remote machine
Once you have a manager listening for client requests, you can begin
interacting with the manager, spawn and/or connect to servers, and interact
with remote machines.
```bash
# Connect to my.example.com on port 22 via SSH and start a distant server
distant client launch ssh://my.example.com
# After the connection is established, you can perform different operations
# on the remote machine via `distant client action {command} [args]`
distant client action copy path/to/file new/path/to/file
distant client action spawn -- echo 'Hello, this is from the other side'
# Opening a shell to the remote machine is trivial
distant client shell
# If you have more than one connection open, you can switch between active
# connections by using the `select` subcommand
distant client select '<ID>'
# For programmatic use, a REPL following the JSON API is available
distant client repl --format json
```
## License

@ -3,20 +3,24 @@ name = "distant-core"
description = "Core library for distant, enabling operation on a remote computer through file and process manipulation"
categories = ["network-programming"]
keywords = ["api", "async"]
version = "0.16.4"
version = "0.17.0"
authors = ["Chip Senkbeil <chip@senkbeil.org>"]
edition = "2018"
edition = "2021"
homepage = "https://github.com/chipsenkbeil/distant"
repository = "https://github.com/chipsenkbeil/distant"
readme = "README.md"
license = "MIT OR Apache-2.0"
[features]
schemars = ["dep:schemars", "distant-net/schemars"]
[dependencies]
async-trait = "0.1.56"
bitflags = "1.3.2"
bytes = "1.1.0"
chacha20poly1305 = "0.9.0"
ciborium = "0.2.0"
derive_more = { version = "0.99.16", default-features = false, features = ["deref", "deref_mut", "display", "from", "error", "into_iterator", "is_variant"] }
derive_more = { version = "0.99.17", default-features = false, features = ["as_mut", "as_ref", "deref", "deref_mut", "display", "from", "error", "into", "into_iterator", "is_variant", "try_into"] }
distant-net = { version = "=0.17.0", path = "../distant-net" }
futures = "0.3.16"
hex = "0.4.3"
log = "0.4.14"
@ -26,14 +30,19 @@ once_cell = "1.8.0"
portable-pty = "0.7.0"
rand = { version = "0.8.4", features = ["getrandom"] }
serde = { version = "1.0.126", features = ["derive"] }
serde_bytes = "0.11.6"
serde_json = "1.0.64"
shell-words = "1.0"
strum = { version = "0.21.0", features = ["derive"] }
tokio = { version = "1.12.0", features = ["full"] }
tokio-util = { version = "0.6.7", features = ["codec"] }
uriparse = "0.6.4"
walkdir = "2.3.2"
winsplit = "0.1"
# Optional dependencies based on features
structopt = { version = "0.3.22", optional = true }
clap = { version = "3.2.5", features = ["derive"], optional = true }
schemars = { version = "0.8.10", optional = true }
[dev-dependencies]
assert_fs = "1.0.4"

@ -1,13 +1,13 @@
# distant core
[![Crates.io][distant_crates_img]][distant_crates_lnk] [![Docs.rs][distant_doc_img]][distant_doc_lnk] [![Rustc 1.51.0][distant_rustc_img]][distant_rustc_lnk]
[![Crates.io][distant_crates_img]][distant_crates_lnk] [![Docs.rs][distant_doc_img]][distant_doc_lnk] [![Rustc 1.61.0][distant_rustc_img]][distant_rustc_lnk]
[distant_crates_img]: https://img.shields.io/crates/v/distant-core.svg
[distant_crates_lnk]: https://crates.io/crates/distant-core
[distant_doc_img]: https://docs.rs/distant-core/badge.svg
[distant_doc_lnk]: https://docs.rs/distant-core
[distant_rustc_img]: https://img.shields.io/badge/distant_core-rustc_1.51+-lightgray.svg
[distant_rustc_lnk]: https://blog.rust-lang.org/2021/03/25/Rust-1.51.0.html
[distant_rustc_img]: https://img.shields.io/badge/distant_core-rustc_1.61+-lightgray.svg
[distant_rustc_lnk]: https://blog.rust-lang.org/2022/05/19/Rust-1.61.0.html
Library that powers the [`distant`](https://github.com/chipsenkbeil/distant)
binary.
@ -16,15 +16,11 @@ binary.
## Details
The `distant` library supplies a mixture of functionality and data to run
servers that operate on remote machines and clients that talk to them.
- Asynchronous in nature, powered by [`tokio`](https://tokio.rs/)
- Data is serialized to send across the wire via [`CBOR`](https://cbor.io/)
- Encryption & authentication are handled via
[XChaCha20Poly1305](https://tools.ietf.org/html/rfc8439) for an authenticated
encryption scheme via
[RustCrypto/ChaCha20Poly1305](https://github.com/RustCrypto/AEADs/tree/master/chacha20poly1305)
The `distant-core` library supplies the client, manager, and server
implementations for use with the distant API in order to communicate with
remote machines and perform actions. This library acts as the primary
implementation that powers the CLI, but is also available for other extensions
like `distant-ssh2`.
## Installation
@ -32,40 +28,43 @@ You can import the dependency by adding the following to your `Cargo.toml`:
```toml
[dependencies]
distant-core = "0.16"
distant-core = "0.17"
```
## Features
Currently, the library supports the following features:
- `structopt`: generates [`StructOpt`](https://github.com/TeXitoi/structopt)
bindings for `RequestData` (used by cli to expose request actions)
- `clap`: generates [`Clap`](https://github.com/clap-rs) bindings for
`DistantRequestData` (used by cli to expose request actions)
- `schemars`: derives the `schemars::JsonSchema` interface on
`DistantMsg`, `DistantRequestData`, and `DistantResponseData` data types
By default, no features are enabled on the library.
## Examples
Below is an example of connecting to a distant server over TCP:
Below is an example of connecting to a distant server over TCP without any
encryption or authentication:
```rust
use distant_core::{Session, SessionChannelExt, SecretKey32, XChaCha20Poly1305Codec};
use std::net::SocketAddr;
// 32-byte secret key parsed from hex, used for a specific codec
let key: SecretKey32 = "DEADBEEFDEADBEEFDEADBEEFDEADBEEFDEADBEEFDEADBEEFDEADBEEFDEADBEEF".parse().unwrap();
let codec = XChaCha20Poly1305Codec::from(key);
use distant_core::{
DistantClient,
DistantChannelExt,
net::{PlainCodec, TcpClientExt},
};
use std::{net::SocketAddr, path::Path};
// Connect to a server located at example.com on port 8080 that is using
// no encryption or authentication (PlainCodec)
let addr: SocketAddr = "example.com:8080".parse().unwrap();
let mut session = Session::tcp_connect(addr, codec).await.unwrap();
// Append text to a file, representing request as <tenant>
// NOTE: This method comes from SessionChannelExt
session.append_file_text(
"<tenant>",
"path/to/file.txt".to_string(),
"new contents"
).await.expect("Failed to append to file");
let mut client = DistantClient::connect(addr, PlainCodec).await
.expect("Failed to connect");
// Append text to a file
// NOTE: This method comes from DistantChannelExt
client.append_file_text(Path::new("path/to/file.txt"), "new contents").await
.expect("Failed to append to file");
```
## License

@ -0,0 +1,626 @@
use crate::{
data::{ChangeKind, DirEntry, Environment, Error, Metadata, ProcessId, PtySize, SystemInfo},
ConnectionId, DistantMsg, DistantRequestData, DistantResponseData,
};
use async_trait::async_trait;
use distant_net::{Reply, Server, ServerCtx};
use log::*;
use std::{io, path::PathBuf, sync::Arc};
mod local;
pub use local::LocalDistantApi;
mod reply;
use reply::DistantSingleReply;
/// Represents the context provided to the [`DistantApi`] for incoming requests
pub struct DistantCtx<T> {
    /// Id of the connection that originated the request
    pub connection_id: ConnectionId,
    /// Handle used to send one or more [`DistantResponseData`] back to the requester
    pub reply: Box<dyn Reply<Data = DistantResponseData>>,
    /// Connection-scoped data shared with the [`DistantApi`] implementation
    pub local_data: Arc<T>,
}
/// Represents a server that leverages an API compliant with `distant`
pub struct DistantApiServer<T, D>
where
    T: DistantApi<LocalData = D>,
{
    // Underlying API implementation that each incoming request is dispatched to
    api: T,
}
impl<T, D> DistantApiServer<T, D>
where
T: DistantApi<LocalData = D>,
{
pub fn new(api: T) -> Self {
Self { api }
}
}
impl DistantApiServer<LocalDistantApi, <LocalDistantApi as DistantApi>::LocalData> {
    /// Creates a new server using the [`LocalDistantApi`] implementation,
    /// forwarding any initialization error to the caller
    pub fn local() -> io::Result<Self> {
        let api = LocalDistantApi::initialize()?;
        Ok(Self { api })
    }
}
/// Builds the standard error returned by default [`DistantApi`] method
/// implementations that have not been overridden.
///
/// The generic return type lets this expression slot into any `io::Result<T>`.
#[inline]
fn unsupported<T>(label: &str) -> io::Result<T> {
    let msg = format!("{} is unsupported", label);
    Err(io::Error::new(io::ErrorKind::Unsupported, msg))
}
/// Interface to support the suite of functionality available with distant,
/// which can be used to build other servers that are compatible with distant
#[async_trait]
pub trait DistantApi {
type LocalData: Send + Sync;
/// Invoked whenever a new connection is established, providing a mutable reference to the
/// newly-created local data. This is a way to support modifying local data before it is used.
///
/// *The default implementation does nothing.*
#[allow(unused_variables)]
async fn on_accept(&self, local_data: &mut Self::LocalData) {}
/// Reads bytes from a file, returning its full contents.
///
/// * `path` - the path to the file
///
/// *Override this, otherwise the default implementation returns an error of
/// kind `Unsupported`.*
#[allow(unused_variables)]
async fn read_file(
    &self,
    ctx: DistantCtx<Self::LocalData>,
    path: PathBuf,
) -> io::Result<Vec<u8>> {
    unsupported("read_file")
}
/// Reads bytes from a file as text, returning its full contents as a string.
///
/// * `path` - the path to the file
///
/// *Override this, otherwise the default implementation returns an error of
/// kind `Unsupported`.*
#[allow(unused_variables)]
async fn read_file_text(
    &self,
    ctx: DistantCtx<Self::LocalData>,
    path: PathBuf,
) -> io::Result<String> {
    unsupported("read_file_text")
}
/// Writes bytes to a file, overwriting the file if it exists.
///
/// * `path` - the path to the file
/// * `data` - the data to write
///
/// *Override this, otherwise the default implementation returns an error of
/// kind `Unsupported`.*
#[allow(unused_variables)]
async fn write_file(
    &self,
    ctx: DistantCtx<Self::LocalData>,
    path: PathBuf,
    data: Vec<u8>,
) -> io::Result<()> {
    unsupported("write_file")
}
/// Writes text to a file, overwriting the file if it exists.
///
/// * `path` - the path to the file
/// * `data` - the text to write
///
/// *Override this, otherwise the default implementation returns an error of
/// kind `Unsupported`.*
#[allow(unused_variables)]
async fn write_file_text(
    &self,
    ctx: DistantCtx<Self::LocalData>,
    path: PathBuf,
    data: String,
) -> io::Result<()> {
    unsupported("write_file_text")
}
/// Writes bytes to the end of a file, creating it if it is missing.
///
/// * `path` - the path to the file
/// * `data` - the data to append
///
/// *Override this, otherwise the default implementation returns an error of
/// kind `Unsupported`.*
#[allow(unused_variables)]
async fn append_file(
    &self,
    ctx: DistantCtx<Self::LocalData>,
    path: PathBuf,
    data: Vec<u8>,
) -> io::Result<()> {
    unsupported("append_file")
}
/// Writes bytes to the end of a file, creating it if it is missing.
///
/// * `path` - the path to the file
/// * `data` - the data to append
///
/// *Override this, otherwise it will return "unsupported" as an error.*
#[allow(unused_variables)]
async fn append_file_text(
&self,
ctx: DistantCtx<Self::LocalData>,
path: PathBuf,
data: String,
) -> io::Result<()> {
unsupported("append_file_text")
}
/// Reads entries from a directory.
///
/// * `path` - the path to the directory
/// * `depth` - how far to traverse the directory, 0 being unlimited
/// * `absolute` - if true, will return absolute paths instead of relative paths
/// * `canonicalize` - if true, will canonicalize entry paths before returned
/// * `include_root` - if true, will include the directory specified in the entries
///
/// *Override this, otherwise it will return "unsupported" as an error.*
#[allow(unused_variables)]
async fn read_dir(
&self,
ctx: DistantCtx<Self::LocalData>,
path: PathBuf,
depth: usize,
absolute: bool,
canonicalize: bool,
include_root: bool,
) -> io::Result<(Vec<DirEntry>, Vec<io::Error>)> {
unsupported("read_dir")
}
/// Creates a directory.
///
/// * `path` - the path to the directory
/// * `all` - if true, will create all missing parent components
///
/// *Override this, otherwise it will return "unsupported" as an error.*
#[allow(unused_variables)]
async fn create_dir(
&self,
ctx: DistantCtx<Self::LocalData>,
path: PathBuf,
all: bool,
) -> io::Result<()> {
unsupported("create_dir")
}
/// Copies some file or directory.
///
/// * `src` - the path to the file or directory to copy
/// * `dst` - the path where the copy will be placed
///
/// *Override this, otherwise it will return "unsupported" as an error.*
#[allow(unused_variables)]
async fn copy(
&self,
ctx: DistantCtx<Self::LocalData>,
src: PathBuf,
dst: PathBuf,
) -> io::Result<()> {
unsupported("copy")
}
/// Removes some file or directory.
///
/// * `path` - the path to a file or directory
/// * `force` - if true, will remove non-empty directories
///
/// *Override this, otherwise it will return "unsupported" as an error.*
#[allow(unused_variables)]
async fn remove(
&self,
ctx: DistantCtx<Self::LocalData>,
path: PathBuf,
force: bool,
) -> io::Result<()> {
unsupported("remove")
}
/// Renames some file or directory.
///
/// * `src` - the path to the file or directory to rename
/// * `dst` - the new name for the file or directory
///
/// *Override this, otherwise it will return "unsupported" as an error.*
#[allow(unused_variables)]
async fn rename(
&self,
ctx: DistantCtx<Self::LocalData>,
src: PathBuf,
dst: PathBuf,
) -> io::Result<()> {
unsupported("rename")
}
/// Watches a file or directory for changes.
///
/// * `path` - the path to the file or directory
/// * `recursive` - if true, will watch for changes within subdirectories and beyond
/// * `only` - if non-empty, will limit reported changes to those included in this list
/// * `except` - if non-empty, will limit reported changes to those not included in this list
///
/// *Override this, otherwise it will return "unsupported" as an error.*
#[allow(unused_variables)]
async fn watch(
&self,
ctx: DistantCtx<Self::LocalData>,
path: PathBuf,
recursive: bool,
only: Vec<ChangeKind>,
except: Vec<ChangeKind>,
) -> io::Result<()> {
unsupported("watch")
}
/// Removes a file or directory from being watched.
///
/// * `path` - the path to the file or directory
///
/// *Override this, otherwise it will return "unsupported" as an error.*
#[allow(unused_variables)]
async fn unwatch(&self, ctx: DistantCtx<Self::LocalData>, path: PathBuf) -> io::Result<()> {
unsupported("unwatch")
}
/// Checks if the specified path exists.
///
/// * `path` - the path to the file or directory
///
/// *Override this, otherwise it will return "unsupported" as an error.*
#[allow(unused_variables)]
async fn exists(&self, ctx: DistantCtx<Self::LocalData>, path: PathBuf) -> io::Result<bool> {
unsupported("exists")
}
/// Reads metadata for a file or directory.
///
/// * `path` - the path to the file or directory
/// * `canonicalize` - if true, will include a canonicalized path in the metadata
/// * `resolve_file_type` - if true, will resolve symlinks to underlying type (file or dir)
///
/// *Override this, otherwise it will return "unsupported" as an error.*
#[allow(unused_variables)]
async fn metadata(
&self,
ctx: DistantCtx<Self::LocalData>,
path: PathBuf,
canonicalize: bool,
resolve_file_type: bool,
) -> io::Result<Metadata> {
unsupported("metadata")
}
/// Spawns a new process, returning its id.
///
/// * `cmd` - the full command to run as a new process (including arguments)
/// * `environment` - the environment variables to associate with the process
/// * `current_dir` - the alternative current directory to use with the process
/// * `persist` - if true, the process will continue running even after the connection that
/// spawned the process has terminated
/// * `pty` - if provided, will run the process within a PTY of the given size
///
/// *Override this, otherwise it will return "unsupported" as an error.*
#[allow(unused_variables)]
async fn proc_spawn(
&self,
ctx: DistantCtx<Self::LocalData>,
cmd: String,
environment: Environment,
current_dir: Option<PathBuf>,
persist: bool,
pty: Option<PtySize>,
) -> io::Result<ProcessId> {
unsupported("proc_spawn")
}
/// Kills a running process by its id.
///
/// * `id` - the unique id of the process
///
/// *Override this, otherwise it will return "unsupported" as an error.*
#[allow(unused_variables)]
async fn proc_kill(&self, ctx: DistantCtx<Self::LocalData>, id: ProcessId) -> io::Result<()> {
unsupported("proc_kill")
}
/// Sends data to the stdin of the process with the specified id.
///
/// * `id` - the unique id of the process
/// * `data` - the bytes to send to stdin
///
/// *Override this, otherwise it will return "unsupported" as an error.*
#[allow(unused_variables)]
async fn proc_stdin(
&self,
ctx: DistantCtx<Self::LocalData>,
id: ProcessId,
data: Vec<u8>,
) -> io::Result<()> {
unsupported("proc_stdin")
}
/// Resizes the PTY of the process with the specified id.
///
/// * `id` - the unique id of the process
/// * `size` - the new size of the pty
///
/// *Override this, otherwise it will return "unsupported" as an error.*
#[allow(unused_variables)]
async fn proc_resize_pty(
&self,
ctx: DistantCtx<Self::LocalData>,
id: ProcessId,
size: PtySize,
) -> io::Result<()> {
unsupported("proc_resize_pty")
}
/// Retrieves information about the system.
///
/// *Override this, otherwise it will return "unsupported" as an error.*
#[allow(unused_variables)]
async fn system_info(&self, ctx: DistantCtx<Self::LocalData>) -> io::Result<SystemInfo> {
unsupported("system_info")
}
}
#[async_trait]
impl<T, D> Server for DistantApiServer<T, D>
where
    T: DistantApi<LocalData = D> + Send + Sync,
    D: Send + Sync,
{
    type Request = DistantMsg<DistantRequestData>;
    type Response = DistantMsg<DistantResponseData>;
    type LocalData = D;

    /// Overridden to leverage [`DistantApi`] implementation of `on_accept`
    async fn on_accept(&self, local_data: &mut Self::LocalData) {
        T::on_accept(&self.api, local_data).await
    }

    /// Dispatches a single or batch request to the underlying [`DistantApi`],
    /// queueing replies so that the direct response to the request is always
    /// delivered before any responses the handlers produced as side effects
    async fn on_request(&self, ctx: ServerCtx<Self::Request, Self::Response, Self::LocalData>) {
        let ServerCtx {
            connection_id,
            request,
            reply,
            local_data,
        } = ctx;

        // Convert our reply to a queued reply so we can ensure that the result
        // of an API function is sent back before anything else
        let reply = reply.queue();

        // Process single vs batch requests
        let response = match request.payload {
            DistantMsg::Single(data) => {
                let ctx = DistantCtx {
                    connection_id,
                    reply: Box::new(DistantSingleReply::from(reply.clone_reply())),
                    local_data,
                };
                let data = handle_request(self, ctx, data).await;
                // Report outgoing errors in our debug logs
                if let DistantResponseData::Error(x) = &data {
                    debug!("[Conn {}] {}", connection_id, x);
                }
                DistantMsg::Single(data)
            }
            DistantMsg::Batch(list) => {
                let mut out = Vec::new();
                for data in list {
                    // Each batch item gets its own ctx, but shares the same
                    // connection-local data via the Arc
                    let ctx = DistantCtx {
                        connection_id,
                        reply: Box::new(DistantSingleReply::from(reply.clone_reply())),
                        local_data: Arc::clone(&local_data),
                    };
                    // TODO: This does not run in parallel, meaning that the next item in the
                    //       batch will not be queued until the previous item completes! This
                    //       would be useful if we wanted to chain requests where the previous
                    //       request feeds into the current request, but not if we just want
                    //       to run everything together. So we should instead rewrite this
                    //       to spawn a task per request and then await completion of all tasks
                    let data = handle_request(self, ctx, data).await;
                    // Report outgoing errors in our debug logs
                    if let DistantResponseData::Error(x) = &data {
                        debug!("[Conn {}] {}", connection_id, x);
                    }
                    out.push(data);
                }
                DistantMsg::Batch(out)
            }
        };

        // Queue up our result to go before ANY of the other messages that might be sent.
        // This is important to avoid situations such as when a process is started, but before
        // the confirmation can be sent some stdout or stderr is captured and sent first.
        if let Err(x) = reply.send_before(response).await {
            error!("[Conn {}] Failed to send response: {}", connection_id, x);
        }

        // Flush out all of our replies thus far and toggle to no longer hold submissions
        if let Err(x) = reply.flush(false).await {
            error!(
                "[Conn {}] Failed to flush response queue: {}",
                connection_id, x
            );
        }
    }
}
/// Processes an incoming request by routing it to the matching [`DistantApi`]
/// method on `server`, converting a successful result into the corresponding
/// [`DistantResponseData`] variant and any error into an error response via
/// `DistantResponseData::from`
async fn handle_request<T, D>(
    server: &DistantApiServer<T, D>,
    ctx: DistantCtx<D>,
    request: DistantRequestData,
) -> DistantResponseData
where
    T: DistantApi<LocalData = D> + Send + Sync,
    D: Send + Sync,
{
    // Every arm follows the same shape: invoke the API method, map Ok(..) into
    // the success response variant, and fold Err(..) into an error response
    match request {
        DistantRequestData::FileRead { path } => server
            .api
            .read_file(ctx, path)
            .await
            .map(|data| DistantResponseData::Blob { data })
            .unwrap_or_else(DistantResponseData::from),
        DistantRequestData::FileReadText { path } => server
            .api
            .read_file_text(ctx, path)
            .await
            .map(|data| DistantResponseData::Text { data })
            .unwrap_or_else(DistantResponseData::from),
        DistantRequestData::FileWrite { path, data } => server
            .api
            .write_file(ctx, path, data)
            .await
            .map(|_| DistantResponseData::Ok)
            .unwrap_or_else(DistantResponseData::from),
        DistantRequestData::FileWriteText { path, text } => server
            .api
            .write_file_text(ctx, path, text)
            .await
            .map(|_| DistantResponseData::Ok)
            .unwrap_or_else(DistantResponseData::from),
        DistantRequestData::FileAppend { path, data } => server
            .api
            .append_file(ctx, path, data)
            .await
            .map(|_| DistantResponseData::Ok)
            .unwrap_or_else(DistantResponseData::from),
        DistantRequestData::FileAppendText { path, text } => server
            .api
            .append_file_text(ctx, path, text)
            .await
            .map(|_| DistantResponseData::Ok)
            .unwrap_or_else(DistantResponseData::from),
        DistantRequestData::DirRead {
            path,
            depth,
            absolute,
            canonicalize,
            include_root,
        } => server
            .api
            .read_dir(ctx, path, depth, absolute, canonicalize, include_root)
            .await
            // Successful reads still carry per-entry errors alongside entries
            .map(|(entries, errors)| DistantResponseData::DirEntries {
                entries,
                errors: errors.into_iter().map(Error::from).collect(),
            })
            .unwrap_or_else(DistantResponseData::from),
        DistantRequestData::DirCreate { path, all } => server
            .api
            .create_dir(ctx, path, all)
            .await
            .map(|_| DistantResponseData::Ok)
            .unwrap_or_else(DistantResponseData::from),
        DistantRequestData::Remove { path, force } => server
            .api
            .remove(ctx, path, force)
            .await
            .map(|_| DistantResponseData::Ok)
            .unwrap_or_else(DistantResponseData::from),
        DistantRequestData::Copy { src, dst } => server
            .api
            .copy(ctx, src, dst)
            .await
            .map(|_| DistantResponseData::Ok)
            .unwrap_or_else(DistantResponseData::from),
        DistantRequestData::Rename { src, dst } => server
            .api
            .rename(ctx, src, dst)
            .await
            .map(|_| DistantResponseData::Ok)
            .unwrap_or_else(DistantResponseData::from),
        DistantRequestData::Watch {
            path,
            recursive,
            only,
            except,
        } => server
            .api
            .watch(ctx, path, recursive, only, except)
            .await
            .map(|_| DistantResponseData::Ok)
            .unwrap_or_else(DistantResponseData::from),
        DistantRequestData::Unwatch { path } => server
            .api
            .unwatch(ctx, path)
            .await
            .map(|_| DistantResponseData::Ok)
            .unwrap_or_else(DistantResponseData::from),
        DistantRequestData::Exists { path } => server
            .api
            .exists(ctx, path)
            .await
            .map(|value| DistantResponseData::Exists { value })
            .unwrap_or_else(DistantResponseData::from),
        DistantRequestData::Metadata {
            path,
            canonicalize,
            resolve_file_type,
        } => server
            .api
            .metadata(ctx, path, canonicalize, resolve_file_type)
            .await
            .map(DistantResponseData::Metadata)
            .unwrap_or_else(DistantResponseData::from),
        DistantRequestData::ProcSpawn {
            cmd,
            environment,
            current_dir,
            persist,
            pty,
        } => server
            .api
            .proc_spawn(ctx, cmd.into(), environment, current_dir, persist, pty)
            .await
            .map(|id| DistantResponseData::ProcSpawned { id })
            .unwrap_or_else(DistantResponseData::from),
        DistantRequestData::ProcKill { id } => server
            .api
            .proc_kill(ctx, id)
            .await
            .map(|_| DistantResponseData::Ok)
            .unwrap_or_else(DistantResponseData::from),
        DistantRequestData::ProcStdin { id, data } => server
            .api
            .proc_stdin(ctx, id, data)
            .await
            .map(|_| DistantResponseData::Ok)
            .unwrap_or_else(DistantResponseData::from),
        DistantRequestData::ProcResizePty { id, size } => server
            .api
            .proc_resize_pty(ctx, id, size)
            .await
            .map(|_| DistantResponseData::Ok)
            .unwrap_or_else(DistantResponseData::from),
        DistantRequestData::SystemInfo {} => server
            .api
            .system_info(ctx)
            .await
            .map(DistantResponseData::SystemInfo)
            .unwrap_or_else(DistantResponseData::from),
    }
}

File diff suppressed because it is too large Load Diff

@ -1,4 +1,4 @@
use crate::data::PtySize;
use crate::data::{ProcessId, PtySize};
use std::{future::Future, pin::Pin};
use tokio::{io, sync::mpsc};
@ -17,7 +17,7 @@ pub type FutureReturn<'a, T> = Pin<Box<dyn Future<Output = T> + Send + 'a>>;
/// Represents a process on the remote server
pub trait Process: ProcessKiller + ProcessPty {
/// Represents the id of the process
fn id(&self) -> usize;
fn id(&self) -> ProcessId;
/// Waits for the process to exit, returning the exit status
///

@ -1,19 +1,23 @@
use super::{
wait, ExitStatus, FutureReturn, InputChannel, OutputChannel, Process, ProcessKiller,
wait, ExitStatus, FutureReturn, InputChannel, OutputChannel, Process, ProcessId, ProcessKiller,
ProcessPty, PtySize, WaitRx,
};
use crate::constants::{MAX_PIPE_CHUNK_SIZE, READ_PAUSE_MILLIS};
use crate::{
constants::{MAX_PIPE_CHUNK_SIZE, READ_PAUSE_MILLIS},
data::Environment,
};
use portable_pty::{CommandBuilder, MasterPty, PtySize as PortablePtySize};
use std::{
ffi::OsStr,
io::{self, Read, Write},
path::PathBuf,
sync::{Arc, Mutex},
};
use tokio::{sync::mpsc, task::JoinHandle};
/// Represents a process that is associated with a pty
pub struct PtyProcess {
id: usize,
id: ProcessId,
pty_master: PtyProcessMaster,
stdin: Option<Box<dyn InputChannel>>,
stdout: Option<Box<dyn OutputChannel>>,
@ -25,7 +29,13 @@ pub struct PtyProcess {
impl PtyProcess {
/// Spawns a new simple process
pub fn spawn<S, I, S2>(program: S, args: I, size: PtySize) -> io::Result<Self>
pub fn spawn<S, I, S2>(
program: S,
args: I,
environment: Environment,
current_dir: Option<PathBuf>,
size: PtySize,
) -> io::Result<Self>
where
S: AsRef<OsStr>,
I: IntoIterator<Item = S2>,
@ -47,6 +57,12 @@ impl PtyProcess {
// Spawn our process within the pty
let mut cmd = CommandBuilder::new(program);
cmd.args(args);
if let Some(path) = current_dir {
cmd.cwd(path);
}
for (key, value) in environment {
cmd.env(key, value);
}
let mut child = pty_slave
.spawn_command(cmd)
.map_err(|x| io::Error::new(io::ErrorKind::Other, x))?;
@ -78,7 +94,7 @@ impl PtyProcess {
loop {
match stdout_reader.read(&mut buf) {
Ok(n) if n > 0 => {
let _ = stdout_tx.blocking_send(buf[..n].to_vec()).map_err(|_| {
stdout_tx.blocking_send(buf[..n].to_vec()).map_err(|_| {
io::Error::new(io::ErrorKind::BrokenPipe, "Output channel closed")
})?;
}
@ -137,7 +153,7 @@ impl PtyProcess {
}
impl Process for PtyProcess {
fn id(&self) -> usize {
fn id(&self) -> ProcessId {
self.id
}

@ -1,15 +1,16 @@
use super::{
wait, ExitStatus, FutureReturn, InputChannel, NoProcessPty, OutputChannel, Process,
wait, ExitStatus, FutureReturn, InputChannel, NoProcessPty, OutputChannel, Process, ProcessId,
ProcessKiller, WaitRx,
};
use std::{ffi::OsStr, process::Stdio};
use crate::data::Environment;
use std::{ffi::OsStr, path::PathBuf, process::Stdio};
use tokio::{io, process::Command, sync::mpsc, task::JoinHandle};
mod tasks;
/// Represents a simple process that does not have a pty
pub struct SimpleProcess {
id: usize,
id: ProcessId,
stdin: Option<Box<dyn InputChannel>>,
stdout: Option<Box<dyn OutputChannel>>,
stderr: Option<Box<dyn OutputChannel>>,
@ -22,18 +23,32 @@ pub struct SimpleProcess {
impl SimpleProcess {
/// Spawns a new simple process
pub fn spawn<S, I, S2>(program: S, args: I) -> io::Result<Self>
pub fn spawn<S, I, S2>(
program: S,
args: I,
environment: Environment,
current_dir: Option<PathBuf>,
) -> io::Result<Self>
where
S: AsRef<OsStr>,
I: IntoIterator<Item = S2>,
S2: AsRef<OsStr>,
{
let mut child = Command::new(program)
let mut child = {
let mut command = Command::new(program);
if let Some(path) = current_dir {
command.current_dir(path);
}
command
.envs(environment)
.args(args)
.stdin(Stdio::piped())
.stdout(Stdio::piped())
.stderr(Stdio::piped())
.spawn()?;
.spawn()?
};
let stdout = child.stdout.take().unwrap();
let (stdout_task, stdout_ch) = tasks::spawn_read_task(stdout, 1);
@ -80,7 +95,7 @@ impl SimpleProcess {
}
impl Process for SimpleProcess {
fn id(&self) -> usize {
fn id(&self) -> ProcessId {
self.id
}

@ -1,6 +1,7 @@
use crate::constants::{MAX_PIPE_CHUNK_SIZE, READ_PAUSE_MILLIS};
use std::io;
use tokio::{
io::{self, AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt},
io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt},
sync::mpsc,
task::JoinHandle,
};
@ -27,7 +28,7 @@ where
loop {
match reader.read(&mut buf).await {
Ok(n) if n > 0 => {
let _ = channel.send(buf[..n].to_vec()).await.map_err(|_| {
channel.send(buf[..n].to_vec()).await.map_err(|_| {
io::Error::new(io::ErrorKind::BrokenPipe, "Output channel closed")
})?;
@ -65,7 +66,7 @@ where
W: AsyncWrite + Unpin,
{
while let Some(data) = channel.recv().await {
let _ = writer.write_all(&data).await?;
writer.write_all(&data).await?;
}
Ok(())
}

@ -0,0 +1,68 @@
use crate::{data::ProcessId, ConnectionId};
use std::{io, path::PathBuf};
mod process;
pub use process::*;
mod watcher;
pub use watcher::*;
/// Holds global state managed by the server, shared across all connections
pub struct GlobalState {
    /// State that holds information about processes running on the server
    pub process: ProcessState,
    /// Watcher used for filesystem events
    pub watcher: WatcherState,
}

impl GlobalState {
    /// Builds the global state, failing if the filesystem watcher cannot be created
    pub fn initialize() -> io::Result<Self> {
        let process = ProcessState::new();
        let watcher = WatcherState::initialize()?;
        Ok(Self { process, watcher })
    }
}
/// Holds connection-specific state managed by the server; when dropped,
/// any processes and watched paths registered here are cleaned up
#[derive(Default)]
pub struct ConnectionState {
    /// Unique id associated with connection
    // NOTE(review): with derive(Default) this starts as ConnectionId::default();
    // confirm callers assign the real id before cleanup in Drop relies on it
    id: ConnectionId,
    /// Channel connected to global process state
    pub(crate) process_channel: ProcessChannel,
    /// Channel connected to global watcher state
    pub(crate) watcher_channel: WatcherChannel,
    /// Contains ids of processes that will be terminated when the connection is closed
    processes: Vec<ProcessId>,
    /// Contains paths being watched that will be unwatched when the connection is closed
    paths: Vec<PathBuf>,
}
impl Drop for ConnectionState {
fn drop(&mut self) {
let id = self.id;
let processes: Vec<ProcessId> = self.processes.drain(..).collect();
let paths: Vec<PathBuf> = self.paths.drain(..).collect();
let process_channel = self.process_channel.clone();
let watcher_channel = self.watcher_channel.clone();
// NOTE: We cannot (and should not) block during drop to perform cleanup,
// instead spawning a task that will do the cleanup async
tokio::spawn(async move {
for id in processes {
let _ = process_channel.kill(id).await;
}
for path in paths {
let _ = watcher_channel.unwatch(id, path).await;
}
});
}
}

@ -0,0 +1,231 @@
use crate::data::{DistantResponseData, Environment, ProcessId, PtySize};
use distant_net::Reply;
use std::{collections::HashMap, io, ops::Deref, path::PathBuf};
use tokio::{
sync::{mpsc, oneshot},
task::JoinHandle,
};
mod instance;
pub use instance::*;
/// Holds information related to spawned processes on the server; owns the
/// background task that services all process operations
pub struct ProcessState {
    /// Handle used to send messages to the background process task
    channel: ProcessChannel,
    /// Background task servicing spawn/resize/stdin/kill requests
    task: JoinHandle<()>,
}

impl Drop for ProcessState {
    /// Aborts the task that handles process operations and management
    fn drop(&mut self) {
        self.abort();
    }
}

impl ProcessState {
    /// Creates the state and spawns the background task that owns the
    /// process table; the task also receives a sender clone so it can
    /// message itself (e.g. to remove finished processes)
    pub fn new() -> Self {
        let (tx, rx) = mpsc::channel(1);
        let task = tokio::spawn(process_task(tx.clone(), rx));
        Self {
            channel: ProcessChannel { tx },
            task,
        }
    }

    /// Returns an independent channel to the same background task
    pub fn clone_channel(&self) -> ProcessChannel {
        self.channel.clone()
    }

    /// Aborts the process task
    pub fn abort(&self) {
        self.task.abort();
    }
}

/// Allows calling [`ProcessChannel`] methods directly on the state
impl Deref for ProcessState {
    type Target = ProcessChannel;

    fn deref(&self) -> &Self::Target {
        &self.channel
    }
}
/// Cloneable handle for sending process operations to the background task
#[derive(Clone)]
pub struct ProcessChannel {
    // Sender side of the background process task's message queue
    tx: mpsc::Sender<InnerProcessMsg>,
}

impl Default for ProcessChannel {
    /// Creates a new channel that is closed by default
    // The receiver is dropped immediately, so every send on a defaulted
    // channel fails — callers get the "task closed" error path
    fn default() -> Self {
        let (tx, _) = mpsc::channel(1);
        Self { tx }
    }
}
impl ProcessChannel {
    /// Spawns a new process, returning the id associated with it
    pub async fn spawn(
        &self,
        cmd: String,
        environment: Environment,
        current_dir: Option<PathBuf>,
        persist: bool,
        pty: Option<PtySize>,
        reply: Box<dyn Reply<Data = DistantResponseData>>,
    ) -> io::Result<ProcessId> {
        // One-shot channel used by the task to report the outcome
        let (cb, rx) = oneshot::channel();
        let msg = InnerProcessMsg::Spawn {
            cmd,
            environment,
            current_dir,
            persist,
            pty,
            reply,
            cb,
        };
        if self.tx.send(msg).await.is_err() {
            return Err(io::Error::new(
                io::ErrorKind::Other,
                "Internal process task closed",
            ));
        }
        match rx.await {
            Ok(result) => result,
            Err(_) => Err(io::Error::new(
                io::ErrorKind::Other,
                "Response to spawn dropped",
            )),
        }
    }

    /// Resizes the pty of a running process
    pub async fn resize_pty(&self, id: ProcessId, size: PtySize) -> io::Result<()> {
        let (cb, rx) = oneshot::channel();
        if self
            .tx
            .send(InnerProcessMsg::Resize { id, size, cb })
            .await
            .is_err()
        {
            return Err(io::Error::new(
                io::ErrorKind::Other,
                "Internal process task closed",
            ));
        }
        match rx.await {
            Ok(result) => result,
            Err(_) => Err(io::Error::new(
                io::ErrorKind::Other,
                "Response to resize dropped",
            )),
        }
    }

    /// Send stdin to a running process
    pub async fn send_stdin(&self, id: ProcessId, data: Vec<u8>) -> io::Result<()> {
        let (cb, rx) = oneshot::channel();
        if self
            .tx
            .send(InnerProcessMsg::Stdin { id, data, cb })
            .await
            .is_err()
        {
            return Err(io::Error::new(
                io::ErrorKind::Other,
                "Internal process task closed",
            ));
        }
        match rx.await {
            Ok(result) => result,
            Err(_) => Err(io::Error::new(
                io::ErrorKind::Other,
                "Response to stdin dropped",
            )),
        }
    }

    /// Kills a running process
    pub async fn kill(&self, id: ProcessId) -> io::Result<()> {
        let (cb, rx) = oneshot::channel();
        if self
            .tx
            .send(InnerProcessMsg::Kill { id, cb })
            .await
            .is_err()
        {
            return Err(io::Error::new(
                io::ErrorKind::Other,
                "Internal process task closed",
            ));
        }
        match rx.await {
            Ok(result) => result,
            Err(_) => Err(io::Error::new(
                io::ErrorKind::Other,
                "Response to kill dropped",
            )),
        }
    }
}
/// Internal message to pass to our task below to perform some action
enum InnerProcessMsg {
    /// Spawn a new process; the outcome (new id or error) is reported via `cb`
    Spawn {
        cmd: String,
        environment: Environment,
        current_dir: Option<PathBuf>,
        persist: bool,
        pty: Option<PtySize>,
        reply: Box<dyn Reply<Data = DistantResponseData>>,
        cb: oneshot::Sender<io::Result<ProcessId>>,
    },
    /// Resize the pty of the process with the given id
    Resize {
        id: ProcessId,
        size: PtySize,
        cb: oneshot::Sender<io::Result<()>>,
    },
    /// Forward bytes to stdin of the process with the given id
    Stdin {
        id: ProcessId,
        data: Vec<u8>,
        cb: oneshot::Sender<io::Result<()>>,
    },
    /// Kill the process with the given id
    Kill {
        id: ProcessId,
        cb: oneshot::Sender<io::Result<()>>,
    },
    /// Sent by the task to itself once a process finishes, so the entry can
    /// be removed from the process table (no callback needed)
    InternalRemove {
        id: ProcessId,
    },
}
/// Background task that owns the table of running processes and services all
/// [`InnerProcessMsg`] operations. `tx` is a sender back into this same task,
/// used to register self-removal callbacks for finished processes. The loop
/// ends when every sender (including `tx` clones) has been dropped.
async fn process_task(tx: mpsc::Sender<InnerProcessMsg>, mut rx: mpsc::Receiver<InnerProcessMsg>) {
    let mut processes: HashMap<ProcessId, ProcessInstance> = HashMap::new();
    while let Some(msg) = rx.recv().await {
        match msg {
            InnerProcessMsg::Spawn {
                cmd,
                environment,
                current_dir,
                persist,
                pty,
                reply,
                cb,
            } => {
                // Callback result send is best-effort: the requester may have
                // given up waiting, which is fine
                let _ = cb.send(
                    match ProcessInstance::spawn(cmd, environment, current_dir, persist, pty, reply)
                    {
                        Ok(mut process) => {
                            let id = process.id;
                            // Attach a callback for when the process is finished where
                            // we will remove it from our above list
                            let tx = tx.clone();
                            process.on_done(move |_| async move {
                                let _ = tx.send(InnerProcessMsg::InternalRemove { id }).await;
                            });
                            processes.insert(id, process);
                            Ok(id)
                        }
                        Err(x) => Err(x),
                    },
                );
            }
            InnerProcessMsg::Resize { id, size, cb } => {
                let _ = cb.send(match processes.get(&id) {
                    Some(process) => process.pty.resize_pty(size),
                    None => Err(io::Error::new(
                        io::ErrorKind::Other,
                        format!("No process found with id {}", id),
                    )),
                });
            }
            InnerProcessMsg::Stdin { id, data, cb } => {
                let _ = cb.send(match processes.get_mut(&id) {
                    Some(process) => match process.stdin.as_mut() {
                        // Forwarding stdin is awaited inline; a slow process
                        // stdin will delay servicing of subsequent messages
                        Some(stdin) => stdin.send(&data).await,
                        None => Err(io::Error::new(
                            io::ErrorKind::Other,
                            format!("Process {} stdin is closed", id),
                        )),
                    },
                    None => Err(io::Error::new(
                        io::ErrorKind::Other,
                        format!("No process found with id {}", id),
                    )),
                });
            }
            InnerProcessMsg::Kill { id, cb } => {
                let _ = cb.send(match processes.get_mut(&id) {
                    Some(process) => process.killer.kill().await,
                    None => Err(io::Error::new(
                        io::ErrorKind::Other,
                        format!("No process found with id {}", id),
                    )),
                });
            }
            InnerProcessMsg::InternalRemove { id } => {
                // Dropping the instance also triggers its cleanup logic
                processes.remove(&id);
            }
        }
    }
}

@ -0,0 +1,230 @@
use crate::{
api::local::process::{
InputChannel, OutputChannel, Process, ProcessKiller, ProcessPty, PtyProcess, SimpleProcess,
},
data::{DistantResponseData, Environment, ProcessId, PtySize},
};
use distant_net::Reply;
use log::*;
use std::{future::Future, io, path::PathBuf};
use tokio::task::JoinHandle;
/// Holds information related to a spawned process on the server, including
/// the background tasks that pump its stdout/stderr to the client and await
/// its exit
pub struct ProcessInstance {
    /// Program that was executed (first token of the original command string)
    pub cmd: String,
    /// Arguments passed to the program
    pub args: Vec<String>,
    /// If true, the process should outlive the connection that spawned it
    // NOTE(review): persist is stored but not acted on in this file — confirm
    // the connection-cleanup logic elsewhere consults it
    pub persist: bool,
    /// Unique id of the process
    pub id: ProcessId,
    /// Handle to the process stdin, if available
    pub stdin: Option<Box<dyn InputChannel>>,
    /// Handle used to kill the process
    pub killer: Box<dyn ProcessKiller>,
    /// Handle used to resize the process pty (no-op for non-pty processes)
    pub pty: Box<dyn ProcessPty>,
    // Tasks forwarding stdout/stderr and awaiting exit; taken on drop/on_done
    stdout_task: Option<JoinHandle<io::Result<()>>>,
    stderr_task: Option<JoinHandle<io::Result<()>>>,
    wait_task: Option<JoinHandle<io::Result<()>>>,
}
impl Drop for ProcessInstance {
    /// Closes stdin and attempts to kill the process when dropped
    fn drop(&mut self) {
        // Drop stdin first to close it
        self.stdin = None;

        // Clear out our tasks if we still have them
        let stdout_task = self.stdout_task.take();
        let stderr_task = self.stderr_task.take();
        let wait_task = self.wait_task.take();

        // Attempt to kill the process, which is an async operation that we
        // will spawn a task to handle
        let id = self.id;
        let mut killer = self.killer.clone_killer();
        tokio::spawn(async move {
            if let Err(x) = killer.kill().await {
                error!("Failed to kill process {} when dropped: {}", id, x);

                // The aborts only happen on kill failure: if the kill succeeds,
                // the tasks finish naturally as the process's pipes close; if it
                // fails, they could hang forever, so they are cut off here
                if let Some(task) = stdout_task.as_ref() {
                    task.abort();
                }
                if let Some(task) = stderr_task.as_ref() {
                    task.abort();
                }
                if let Some(task) = wait_task.as_ref() {
                    task.abort();
                }
            }
        });
    }
}
impl ProcessInstance {
    /// Parses `cmd` into a program and arguments, spawns it (within a pty if
    /// `pty` is provided, otherwise as a simple piped process), and wires up
    /// background tasks that stream stdout/stderr and the exit status back to
    /// the client through `reply`.
    ///
    /// Fails if `cmd` is empty, cannot be tokenized, or the spawn itself fails.
    pub fn spawn(
        cmd: String,
        environment: Environment,
        current_dir: Option<PathBuf>,
        persist: bool,
        pty: Option<PtySize>,
        reply: Box<dyn Reply<Data = DistantResponseData>>,
    ) -> io::Result<Self> {
        // Build out the command and args from our string, using a
        // platform-appropriate tokenizer (cmd.exe-style vs POSIX shell words)
        let mut cmd_and_args = if cfg!(windows) {
            winsplit::split(&cmd)
        } else {
            shell_words::split(&cmd).map_err(|x| io::Error::new(io::ErrorKind::InvalidInput, x))?
        };

        if cmd_and_args.is_empty() {
            return Err(io::Error::new(
                io::ErrorKind::InvalidInput,
                "Command was empty",
            ));
        }

        // Split command from arguments, where arguments could be empty
        let args = cmd_and_args.split_off(1);
        let cmd = cmd_and_args.into_iter().next().unwrap();

        let mut child: Box<dyn Process> = match pty {
            Some(size) => Box::new(PtyProcess::spawn(
                cmd.clone(),
                args.clone(),
                environment,
                current_dir,
                size,
            )?),
            None => Box::new(SimpleProcess::spawn(
                cmd.clone(),
                args.clone(),
                environment,
                current_dir,
            )?),
        };

        let id = child.id();
        let stdin = child.take_stdin();
        let stdout = child.take_stdout();
        let stderr = child.take_stderr();
        let killer = child.clone_killer();
        let pty = child.clone_pty();

        // Spawn a task that sends stdout as a response
        let stdout_task = match stdout {
            Some(stdout) => {
                let reply = reply.clone_reply();
                let task = tokio::spawn(stdout_task(id, stdout, reply));
                Some(task)
            }
            None => None,
        };

        // Spawn a task that sends stderr as a response
        let stderr_task = match stderr {
            Some(stderr) => {
                let reply = reply.clone_reply();
                let task = tokio::spawn(stderr_task(id, stderr, reply));
                Some(task)
            }
            None => None,
        };

        // Spawn a task that waits on the process to exit but can also
        // kill the process when triggered
        let wait_task = Some(tokio::spawn(wait_task(id, child, reply)));

        Ok(ProcessInstance {
            cmd,
            args,
            persist,
            id,
            stdin,
            killer,
            pty,
            stdout_task,
            stderr_task,
            wait_task,
        })
    }

    /// Invokes the function once the process has completed
    ///
    /// NOTE: Can only be used with one function. All future calls
    /// will do nothing
    // A JoinError (panic/abort of the wait task) is surfaced to `f` as an
    // io::Error rather than being swallowed
    pub fn on_done<F, R>(&mut self, f: F)
    where
        F: FnOnce(io::Result<()>) -> R + Send + 'static,
        R: Future<Output = ()> + Send,
    {
        if let Some(task) = self.wait_task.take() {
            tokio::spawn(async move {
                f(task
                    .await
                    .unwrap_or_else(|x| Err(io::Error::new(io::ErrorKind::Other, x))))
                .await
            });
        }
    }
}
/// Streams stdout of the process with the given `id` to the client as
/// `ProcStdout` responses until the channel closes cleanly or an error occurs.
///
/// Returns `Ok(())` when stdout reports end-of-stream (`None`); propagates any
/// read or reply-send error.
async fn stdout_task(
    id: ProcessId,
    mut stdout: Box<dyn OutputChannel>,
    reply: Box<dyn Reply<Data = DistantResponseData>>,
) -> io::Result<()> {
    // `?` propagates both recv and send failures, replacing the original
    // manual `match` + `if let Err { return Err }` plumbing
    while let Some(data) = stdout.recv().await? {
        reply
            .send(DistantResponseData::ProcStdout { id, data })
            .await?;
    }
    Ok(())
}
/// Streams stderr of the process with the given `id` to the client as
/// `ProcStderr` responses until the channel closes cleanly or an error occurs.
///
/// Returns `Ok(())` when stderr reports end-of-stream (`None`); propagates any
/// read or reply-send error.
async fn stderr_task(
    id: ProcessId,
    mut stderr: Box<dyn OutputChannel>,
    reply: Box<dyn Reply<Data = DistantResponseData>>,
) -> io::Result<()> {
    // `?` propagates both recv and send failures, replacing the original
    // manual `match` + `if let Err { return Err }` plumbing
    while let Some(data) = stderr.recv().await? {
        reply
            .send(DistantResponseData::ProcStderr { id, data })
            .await?;
    }
    Ok(())
}
/// Waits for the process with the given `id` to complete, then reports the
/// outcome to the client: a `ProcDone` response carrying the exit status on a
/// successful wait, or an error response otherwise.
async fn wait_task(
    id: ProcessId,
    mut child: Box<dyn Process>,
    reply: Box<dyn Reply<Data = DistantResponseData>>,
) -> io::Result<()> {
    match child.wait().await {
        Ok(exit_status) => {
            let response = DistantResponseData::ProcDone {
                id,
                success: exit_status.success,
                code: exit_status.code,
            };
            reply.send(response).await
        }
        Err(x) => reply.send(DistantResponseData::from(x)).await,
    }
}

@ -0,0 +1,286 @@
use crate::{constants::SERVER_WATCHER_CAPACITY, data::ChangeKind, ConnectionId};
use log::*;
use notify::{
Config as WatcherConfig, Error as WatcherError, Event as WatcherEvent, RecommendedWatcher,
RecursiveMode, Watcher,
};
use std::{
collections::HashMap,
io,
ops::Deref,
path::{Path, PathBuf},
};
use tokio::{
sync::{
mpsc::{self, error::TrySendError},
oneshot,
},
task::JoinHandle,
};
mod path;
pub use path::*;
/// Holds information related to watched paths on the server
pub struct WatcherState {
channel: WatcherChannel,
task: JoinHandle<()>,
}
impl Drop for WatcherState {
/// Aborts the task that handles watcher path operations and management
fn drop(&mut self) {
self.abort();
}
}
impl WatcherState {
    /// Will create a watcher and initialize watched paths to be empty
    pub fn initialize() -> io::Result<Self> {
        // NOTE: Cannot be something small like 1 as this seems to cause a deadlock sometimes
        //       with a large volume of watch requests
        let (tx, rx) = mpsc::channel(SERVER_WATCHER_CAPACITY);
        // Build the platform-recommended notify watcher; its callback forwards
        // each event/error into our channel without blocking (try_send), so a
        // flooded channel drops events rather than stalling the notify thread
        let mut watcher = {
            let tx = tx.clone();
            notify::recommended_watcher(move |res| {
                match tx.try_send(match res {
                    Ok(x) => InnerWatcherMsg::Event { ev: x },
                    Err(x) => InnerWatcherMsg::Error { err: x },
                }) {
                    Ok(_) => (),
                    Err(TrySendError::Full(_)) => {
                        warn!(
                            "Reached watcher capacity of {}! Dropping watcher event!",
                            SERVER_WATCHER_CAPACITY,
                        );
                    }
                    Err(TrySendError::Closed(_)) => {
                        warn!("Skipping watch event because watcher channel closed");
                    }
                }
            })
            .map_err(|x| io::Error::new(io::ErrorKind::Other, x))?
        };
        // Attempt to configure watcher, but don't fail if these configurations fail
        match watcher.configure(WatcherConfig::PreciseEvents(true)) {
            Ok(true) => debug!("Watcher configured for precise events"),
            Ok(false) => debug!("Watcher not configured for precise events",),
            Err(x) => error!("Watcher configuration for precise events failed: {}", x),
        }
        // Attempt to configure watcher, but don't fail if these configurations fail
        match watcher.configure(WatcherConfig::NoticeEvents(true)) {
            Ok(true) => debug!("Watcher configured for notice events"),
            Ok(false) => debug!("Watcher not configured for notice events",),
            Err(x) => error!("Watcher configuration for notice events failed: {}", x),
        }
        // Hand ownership of the watcher and the receiving end of the channel to
        // the background task; the state only keeps the sending side
        Ok(Self {
            channel: WatcherChannel { tx },
            task: tokio::spawn(watcher_task(watcher, rx)),
        })
    }
    /// Returns a clone of the channel used to submit requests to the watcher task
    pub fn clone_channel(&self) -> WatcherChannel {
        self.channel.clone()
    }
    /// Aborts the watcher task
    pub fn abort(&self) {
        self.task.abort();
    }
}
impl Deref for WatcherState {
    type Target = WatcherChannel;
    /// Exposes the channel's methods (watch/unwatch) directly on the state
    fn deref(&self) -> &Self::Target {
        &self.channel
    }
}
/// Cloneable handle used to communicate with the background watcher task
#[derive(Clone)]
pub struct WatcherChannel {
    tx: mpsc::Sender<InnerWatcherMsg>, // sending half consumed by `watcher_task`
}
impl Default for WatcherChannel {
    /// Creates a new channel that is closed by default
    fn default() -> Self {
        // Dropping the receiver right away leaves the sender permanently
        // closed, so any send through this channel fails immediately
        let (tx, rx) = mpsc::channel(1);
        drop(rx);
        Self { tx }
    }
}
impl WatcherChannel {
    /// Watch a path for a specific connection denoted by the id within the registered path
    pub async fn watch(&self, registered_path: RegisteredPath) -> io::Result<()> {
        let (cb, rx) = oneshot::channel();

        // Hand the registration off to the background watcher task
        let sent = self
            .tx
            .send(InnerWatcherMsg::Watch {
                registered_path,
                cb,
            })
            .await;
        if sent.is_err() {
            return Err(io::Error::new(
                io::ErrorKind::Other,
                "Internal watcher task closed",
            ));
        }

        // Wait for the task to report whether the watch actually succeeded
        match rx.await {
            Ok(res) => res,
            Err(_) => Err(io::Error::new(
                io::ErrorKind::Other,
                "Response to watch dropped",
            )),
        }
    }

    /// Unwatch a path for a specific connection denoted by the id
    pub async fn unwatch(&self, id: ConnectionId, path: impl AsRef<Path>) -> io::Result<()> {
        // Prefer the canonical form of the path, falling back to the raw
        // path when canonicalization fails
        let path = match tokio::fs::canonicalize(path.as_ref()).await {
            Ok(canonical) => canonical,
            Err(_) => path.as_ref().to_path_buf(),
        };

        let (cb, rx) = oneshot::channel();
        let sent = self.tx.send(InnerWatcherMsg::Unwatch { id, path, cb }).await;
        if sent.is_err() {
            return Err(io::Error::new(
                io::ErrorKind::Other,
                "Internal watcher task closed",
            ));
        }

        // Wait for the task to report whether the unwatch actually succeeded
        match rx.await {
            Ok(res) => res,
            Err(_) => Err(io::Error::new(
                io::ErrorKind::Other,
                "Response to unwatch dropped",
            )),
        }
    }
}
/// Internal message to pass to our task below to perform some action
enum InnerWatcherMsg {
    /// Request to start watching the given registered path; result is reported through `cb`
    Watch {
        registered_path: RegisteredPath,
        cb: oneshot::Sender<io::Result<()>>,
    },
    /// Request to stop watching `path` for the connection `id`; result is reported through `cb`
    Unwatch {
        id: ConnectionId,
        path: PathBuf,
        cb: oneshot::Sender<io::Result<()>>,
    },
    /// A filesystem change event forwarded from the notify watcher callback
    Event {
        ev: WatcherEvent,
    },
    /// An error forwarded from the notify watcher callback
    Error {
        err: WatcherError,
    },
}
/// Background loop that owns the underlying notify watcher and reconciles
/// watch/unwatch requests and filesystem events with the registered paths
///
/// Invariant: `path_cnt` maps each canonicalized path currently registered with
/// the underlying watcher to the number of entries in `registered_paths` that
/// rely on it; it must be kept in sync on both watch and unwatch
async fn watcher_task(mut watcher: RecommendedWatcher, mut rx: mpsc::Receiver<InnerWatcherMsg>) {
    // TODO: Optimize this in some way to be more performant than
    //       checking every path whenever an event comes in
    let mut registered_paths: Vec<RegisteredPath> = Vec::new();
    let mut path_cnt: HashMap<PathBuf, usize> = HashMap::new();
    while let Some(msg) = rx.recv().await {
        match msg {
            InnerWatcherMsg::Watch {
                registered_path,
                cb,
            } => {
                // Check if we are tracking the path across any connection
                if let Some(cnt) = path_cnt.get_mut(registered_path.path()) {
                    // Increment the count of times we are watching that path
                    *cnt += 1;
                    // Store the registered path in our collection without worry
                    // since we are already watching a path that impacts this one
                    registered_paths.push(registered_path);
                    // Send an okay because we always succeed in this case
                    let _ = cb.send(Ok(()));
                } else {
                    let res = watcher
                        .watch(
                            registered_path.path(),
                            if registered_path.is_recursive() {
                                RecursiveMode::Recursive
                            } else {
                                RecursiveMode::NonRecursive
                            },
                        )
                        .map_err(|x| io::Error::new(io::ErrorKind::Other, x));
                    // If we succeeded, store our registered path and set the tracking cnt to 1
                    if res.is_ok() {
                        path_cnt.insert(registered_path.path().to_path_buf(), 1);
                        registered_paths.push(registered_path);
                    }
                    // Send the result of the watch, but don't worry if the channel was closed
                    let _ = cb.send(res);
                }
            }
            InnerWatcherMsg::Unwatch { id, path, cb } => {
                // Check if we are tracking the path across any connection
                // (copy the count so we can mutate the map below)
                if let Some(cnt) = path_cnt.get(path.as_path()).copied() {
                    // Cycle through and remove all paths that match the given id and path,
                    // capturing how many paths we removed
                    let removed_cnt = {
                        let old_len = registered_paths.len();
                        registered_paths
                            .retain(|p| p.id() != id || (p.path() != path && p.raw_path() != path));
                        let new_len = registered_paths.len();
                        old_len - new_len
                    };
                    // 1. If we are now at zero cnt for our path, we want to actually unwatch the
                    //    path with our watcher
                    // 2. If we removed nothing from our path list, we want to return an error
                    // 3. Otherwise, we return okay because we succeeded
                    if cnt <= removed_cnt {
                        // No connection relies on this path anymore, so drop the tracking
                        // entry; otherwise a future watch of the same path would find a
                        // stale count and never re-register with the underlying watcher
                        path_cnt.remove(path.as_path());
                        let _ = cb.send(
                            watcher
                                .unwatch(&path)
                                .map_err(|x| io::Error::new(io::ErrorKind::Other, x)),
                        );
                    } else if removed_cnt == 0 {
                        // Send a failure as there was nothing to unwatch for this connection
                        let _ = cb.send(Err(io::Error::new(
                            io::ErrorKind::Other,
                            format!("{:?} is not being watched", path),
                        )));
                    } else {
                        // Decrement the count by the number of entries removed so the
                        // tally stays in sync with `registered_paths`
                        if let Some(cnt) = path_cnt.get_mut(path.as_path()) {
                            *cnt -= removed_cnt;
                        }
                        // Send a success as we removed some paths
                        let _ = cb.send(Ok(()));
                    }
                } else {
                    // Send a failure as there was nothing to unwatch
                    let _ = cb.send(Err(io::Error::new(
                        io::ErrorKind::Other,
                        format!("{:?} is not being watched", path),
                    )));
                }
            }
            InnerWatcherMsg::Event { ev } => {
                let kind = ChangeKind::from(ev.kind);
                // Let every registered path decide whether the event applies to it
                for registered_path in registered_paths.iter() {
                    match registered_path.filter_and_send(kind, &ev.paths).await {
                        Ok(_) => (),
                        Err(x) => error!(
                            "[Conn {}] Failed to forward changes to paths: {}",
                            registered_path.id(),
                            x
                        ),
                    }
                }
            }
            InnerWatcherMsg::Error { err } => {
                let msg = err.to_string();
                error!("Watcher encountered an error {} for {:?}", msg, err.paths);
                for registered_path in registered_paths.iter() {
                    match registered_path
                        .filter_and_send_error(&msg, &err.paths, !err.paths.is_empty())
                        .await
                    {
                        Ok(_) => (),
                        Err(x) => error!(
                            "[Conn {}] Failed to forward changes to paths: {}",
                            registered_path.id(),
                            x
                        ),
                    }
                }
            }
        }
    }
}

@ -0,0 +1,212 @@
use crate::{
data::{Change, ChangeKind, ChangeKindSet, DistantResponseData, Error},
ConnectionId,
};
use distant_net::Reply;
use std::{
fmt,
hash::{Hash, Hasher},
io,
path::{Path, PathBuf},
};
/// Represents a path registered with a watcher that includes relevant state including
/// the ability to reply to the connection that is watching it
pub struct RegisteredPath {
    /// Unique id tied to the path to distinguish it
    id: ConnectionId,
    /// The raw path provided to the watcher, which is not canonicalized
    raw_path: PathBuf,
    /// The canonicalized path at the time of providing to the watcher,
    /// as all paths must exist for a watcher, we use this to get the
    /// source of truth when watching
    path: PathBuf,
    /// Whether or not the path was set to be recursive
    recursive: bool,
    /// Specific filter for path (only the allowed change kinds are tracked)
    /// NOTE: This is a combination of only and except filters
    allowed: ChangeKindSet,
    /// Used to send a reply through the connection watching this path
    reply: Box<dyn Reply<Data = DistantResponseData>>,
}
impl fmt::Debug for RegisteredPath {
    /// Formats all fields except `reply`, which has no debug representation
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("RegisteredPath")
            .field("raw_path", &self.raw_path)
            .field("path", &self.path)
            .field("recursive", &self.recursive)
            .field("allowed", &self.allowed)
            .finish()
    }
}
impl PartialEq for RegisteredPath {
    /// Checks for equality using the id, canonicalized path, and allowed change kinds
    fn eq(&self, other: &Self) -> bool {
        (&self.id, &self.path, &self.allowed) == (&other.id, &other.path, &other.allowed)
    }
}

impl Eq for RegisteredPath {}
impl Hash for RegisteredPath {
    /// Hashes using the id, canonicalized path, and allowed change kinds —
    /// the same fields used for equality
    fn hash<H: Hasher>(&self, state: &mut H) {
        let Self {
            id, path, allowed, ..
        } = self;
        id.hash(state);
        path.hash(state);
        allowed.hash(state);
    }
}
impl RegisteredPath {
    /// Registers a new path to be watched (does not actually do any watching)
    ///
    /// Fails if the path cannot be canonicalized (e.g. it does not exist)
    pub async fn register(
        id: ConnectionId,
        path: impl Into<PathBuf>,
        recursive: bool,
        only: impl Into<ChangeKindSet>,
        except: impl Into<ChangeKindSet>,
        reply: Box<dyn Reply<Data = DistantResponseData>>,
    ) -> io::Result<Self> {
        let raw_path = path.into();
        let path = tokio::fs::canonicalize(raw_path.as_path()).await?;
        let only = only.into();
        let except = except.into();
        // Calculate the true list of kinds based on only and except filters:
        // an empty `only` means "everything", then `except` is subtracted
        let allowed = if only.is_empty() {
            ChangeKindSet::all() - except
        } else {
            only - except
        };
        Ok(Self {
            id,
            raw_path,
            path,
            recursive,
            allowed,
            reply,
        })
    }
    /// Represents a unique id to distinguish this path from other registrations
    /// of the same path
    pub fn id(&self) -> ConnectionId {
        self.id
    }
    /// Represents the path provided during registration before canonicalization
    pub fn raw_path(&self) -> &Path {
        self.raw_path.as_path()
    }
    /// Represents the canonicalized path used by watchers
    pub fn path(&self) -> &Path {
        self.path.as_path()
    }
    /// Returns true if this path represents a recursive watcher path
    pub fn is_recursive(&self) -> bool {
        self.recursive
    }
    /// Returns reference to set of [`ChangeKind`] that this path watches
    pub fn allowed(&self) -> &ChangeKindSet {
        &self.allowed
    }
    /// Sends a reply for a change tied to this registered path, filtering
    /// out any paths that are not applicable
    ///
    /// Returns true if message was sent, and false if not
    pub async fn filter_and_send<T>(&self, kind: ChangeKind, paths: T) -> io::Result<bool>
    where
        T: IntoIterator,
        T::Item: AsRef<Path>,
    {
        // Skip entirely if this registration does not care about this change kind
        if !self.allowed().contains(&kind) {
            return Ok(false);
        }
        // Keep only the event paths that fall under this registered path
        let paths: Vec<PathBuf> = paths
            .into_iter()
            .filter(|p| self.applies_to_path(p.as_ref()))
            .map(|p| p.as_ref().to_path_buf())
            .collect();
        if !paths.is_empty() {
            self.reply
                .send(DistantResponseData::Changed(Change { kind, paths }))
                .await
                .map(|_| true)
        } else {
            Ok(false)
        }
    }
    /// Sends an error message and includes paths if provided, skipping sending the message if
    /// no paths match and `skip_if_no_paths` is true
    ///
    /// Returns true if message was sent, and false if not
    pub async fn filter_and_send_error<T>(
        &self,
        msg: &str,
        paths: T,
        skip_if_no_paths: bool,
    ) -> io::Result<bool>
    where
        T: IntoIterator,
        T::Item: AsRef<Path>,
    {
        // Keep only the error paths that fall under this registered path
        let paths: Vec<PathBuf> = paths
            .into_iter()
            .filter(|p| self.applies_to_path(p.as_ref()))
            .map(|p| p.as_ref().to_path_buf())
            .collect();
        if !paths.is_empty() || !skip_if_no_paths {
            self.reply
                .send(if paths.is_empty() {
                    DistantResponseData::Error(Error::from(msg))
                } else {
                    DistantResponseData::Error(Error::from(format!("{} about {:?}", msg, paths)))
                })
                .await
                .map(|_| true)
        } else {
            Ok(false)
        }
    }
    /// Returns true if this path applies to the given path.
    /// This is accomplished by checking if the path is contained
    /// within either the raw or canonicalized path of the watcher
    /// and ensures that recursion rules are respected
    pub fn applies_to_path(&self, path: &Path) -> bool {
        // Given the remainder after stripping this watcher's path as a prefix,
        // decide if it is close enough: depth 0 or 1 is always accepted,
        // anything deeper requires a recursive registration
        let check_path = |path: &Path| -> bool {
            let cnt = path.components().count();
            // 0 means exact match from strip_prefix
            // 1 means that it was within immediate directory (fine for non-recursive)
            // 2+ means it needs to be recursive
            cnt < 2 || self.recursive
        };
        // Try both the canonicalized and raw forms since events may report either
        match (
            path.strip_prefix(self.path()),
            path.strip_prefix(self.raw_path()),
        ) {
            (Ok(p1), Ok(p2)) => check_path(p1) || check_path(p2),
            (Ok(p), Err(_)) => check_path(p),
            (Err(_), Ok(p)) => check_path(p),
            (Err(_), Err(_)) => false,
        }
    }
}

@ -0,0 +1,29 @@
use crate::{api::DistantMsg, data::DistantResponseData};
use distant_net::Reply;
use std::{future::Future, io, pin::Pin};
/// Wrapper around a reply that can be batch or single, converting
/// a single data into the wrapped type
pub struct DistantSingleReply(Box<dyn Reply<Data = DistantMsg<DistantResponseData>>>);
impl From<Box<dyn Reply<Data = DistantMsg<DistantResponseData>>>> for DistantSingleReply {
    /// Wraps a batch-or-single reply so it can be used where a single-data reply is expected
    fn from(reply: Box<dyn Reply<Data = DistantMsg<DistantResponseData>>>) -> Self {
        Self(reply)
    }
}
impl Reply for DistantSingleReply {
    type Data = DistantResponseData;
    /// Forwards the data to the wrapped reply, tagging it as a single (non-batch) message
    fn send(&self, data: Self::Data) -> Pin<Box<dyn Future<Output = io::Result<()>> + Send + '_>> {
        self.0.send(DistantMsg::Single(data))
    }
    /// Blocking variant of [`DistantSingleReply::send`], wrapping the data the same way
    fn blocking_send(&self, data: Self::Data) -> io::Result<()> {
        self.0.blocking_send(DistantMsg::Single(data))
    }
    /// Clones the underlying reply and wraps it again so the clone behaves identically
    fn clone_reply(&self) -> Box<dyn Reply<Data = Self::Data>> {
        Box::new(Self(self.0.clone_reply()))
    }
}

@ -0,0 +1,18 @@
use crate::{DistantMsg, DistantRequestData, DistantResponseData};
use distant_net::{Channel, Client};
mod ext;
mod lsp;
mod process;
mod watcher;
/// Represents a [`Client`] that communicates using the distant protocol
pub type DistantClient = Client<DistantMsg<DistantRequestData>, DistantMsg<DistantResponseData>>;
/// Represents a [`Channel`] that communicates using the distant protocol
pub type DistantChannel = Channel<DistantMsg<DistantRequestData>, DistantMsg<DistantResponseData>>;
pub use ext::*;
pub use lsp::*;
pub use process::*;
pub use watcher::*;

@ -0,0 +1,424 @@
use crate::{
client::{
RemoteCommand, RemoteLspCommand, RemoteLspProcess, RemoteOutput, RemoteProcess, Watcher,
},
data::{
ChangeKindSet, DirEntry, DistantRequestData, DistantResponseData, Environment,
Error as Failure, Metadata, PtySize, SystemInfo,
},
DistantMsg,
};
use distant_net::{Channel, Request};
use std::{future::Future, io, path::PathBuf, pin::Pin};
pub type AsyncReturn<'a, T, E = io::Error> =
Pin<Box<dyn Future<Output = Result<T, E>> + Send + 'a>>;
/// Creates the error returned when a response payload is not the variant the caller expected
fn mismatched_response() -> io::Error {
    io::Error::new(io::ErrorKind::Other, "Mismatched response")
}
/// Provides convenience functions on top of a [`Channel`]
pub trait DistantChannelExt {
    /// Appends to a remote file using the data from a collection of bytes
    fn append_file(
        &mut self,
        path: impl Into<PathBuf>,
        data: impl Into<Vec<u8>>,
    ) -> AsyncReturn<'_, ()>;
    /// Appends to a remote file using the data from a string
    fn append_file_text(
        &mut self,
        path: impl Into<PathBuf>,
        data: impl Into<String>,
    ) -> AsyncReturn<'_, ()>;
    /// Copies a remote file or directory from src to dst
    fn copy(&mut self, src: impl Into<PathBuf>, dst: impl Into<PathBuf>) -> AsyncReturn<'_, ()>;
    /// Creates a remote directory, optionally creating all parent components if specified
    fn create_dir(&mut self, path: impl Into<PathBuf>, all: bool) -> AsyncReturn<'_, ()>;
    /// Checks whether the specified path exists on the remote machine
    fn exists(&mut self, path: impl Into<PathBuf>) -> AsyncReturn<'_, bool>;
    /// Retrieves metadata about a path on a remote machine
    fn metadata(
        &mut self,
        path: impl Into<PathBuf>,
        canonicalize: bool,
        resolve_file_type: bool,
    ) -> AsyncReturn<'_, Metadata>;
    /// Reads entries from a directory, returning a tuple of directory entries and failures
    fn read_dir(
        &mut self,
        path: impl Into<PathBuf>,
        depth: usize,
        absolute: bool,
        canonicalize: bool,
        include_root: bool,
    ) -> AsyncReturn<'_, (Vec<DirEntry>, Vec<Failure>)>;
    /// Reads a remote file as a collection of bytes
    fn read_file(&mut self, path: impl Into<PathBuf>) -> AsyncReturn<'_, Vec<u8>>;
    /// Returns a remote file as a string
    fn read_file_text(&mut self, path: impl Into<PathBuf>) -> AsyncReturn<'_, String>;
    /// Removes a remote file or directory, supporting removal of non-empty directories if
    /// force is true
    fn remove(&mut self, path: impl Into<PathBuf>, force: bool) -> AsyncReturn<'_, ()>;
    /// Renames a remote file or directory from src to dst
    fn rename(&mut self, src: impl Into<PathBuf>, dst: impl Into<PathBuf>) -> AsyncReturn<'_, ()>;
    /// Watches a remote file or directory
    fn watch(
        &mut self,
        path: impl Into<PathBuf>,
        recursive: bool,
        only: impl Into<ChangeKindSet>,
        except: impl Into<ChangeKindSet>,
    ) -> AsyncReturn<'_, Watcher>;
    /// Unwatches a remote file or directory
    fn unwatch(&mut self, path: impl Into<PathBuf>) -> AsyncReturn<'_, ()>;
    /// Spawns a process on the remote machine
    fn spawn(
        &mut self,
        cmd: impl Into<String>,
        environment: Environment,
        current_dir: Option<PathBuf>,
        persist: bool,
        pty: Option<PtySize>,
    ) -> AsyncReturn<'_, RemoteProcess>;
    /// Spawns an LSP process on the remote machine
    fn spawn_lsp(
        &mut self,
        cmd: impl Into<String>,
        environment: Environment,
        current_dir: Option<PathBuf>,
        persist: bool,
        pty: Option<PtySize>,
    ) -> AsyncReturn<'_, RemoteLspProcess>;
    /// Spawns a process on the remote machine and wait for it to complete
    fn output(
        &mut self,
        cmd: impl Into<String>,
        environment: Environment,
        current_dir: Option<PathBuf>,
        pty: Option<PtySize>,
    ) -> AsyncReturn<'_, RemoteOutput>;
    /// Retrieves information about the remote system
    fn system_info(&mut self) -> AsyncReturn<'_, SystemInfo>;
    /// Writes a remote file with the data from a collection of bytes
    fn write_file(
        &mut self,
        path: impl Into<PathBuf>,
        data: impl Into<Vec<u8>>,
    ) -> AsyncReturn<'_, ()>;
    /// Writes a remote file with the data from a string
    fn write_file_text(
        &mut self,
        path: impl Into<PathBuf>,
        data: impl Into<String>,
    ) -> AsyncReturn<'_, ()>;
}
// Builds the boxed-future body for a `DistantChannelExt` method.
//
// The `@ok` form sends the request and maps an `Ok` response to `()`;
// the closure form applies `$and_then` to the single response payload.
// Both forms wrap `$data` in `DistantMsg::Single` and treat a non-single
// response as a mismatched response error.
macro_rules! make_body {
    ($self:expr, $data:expr, @ok) => {
        make_body!($self, $data, |data| {
            match data {
                DistantResponseData::Ok => Ok(()),
                DistantResponseData::Error(x) => Err(io::Error::from(x)),
                _ => Err(mismatched_response()),
            }
        })
    };
    ($self:expr, $data:expr, $and_then:expr) => {{
        let req = Request::new(DistantMsg::Single($data));
        Box::pin(async move {
            $self
                .send(req)
                .await
                .and_then(|res| match res.payload {
                    DistantMsg::Single(x) => Ok(x),
                    _ => Err(mismatched_response()),
                })
                .and_then($and_then)
        })
    }};
}
// Implementation of the convenience trait for a channel speaking the distant
// protocol; most methods are one request/one response pairs built via `make_body!`
impl DistantChannelExt
    for Channel<DistantMsg<DistantRequestData>, DistantMsg<DistantResponseData>>
{
    fn append_file(
        &mut self,
        path: impl Into<PathBuf>,
        data: impl Into<Vec<u8>>,
    ) -> AsyncReturn<'_, ()> {
        make_body!(
            self,
            DistantRequestData::FileAppend { path: path.into(), data: data.into() },
            @ok
        )
    }
    fn append_file_text(
        &mut self,
        path: impl Into<PathBuf>,
        data: impl Into<String>,
    ) -> AsyncReturn<'_, ()> {
        make_body!(
            self,
            DistantRequestData::FileAppendText { path: path.into(), text: data.into() },
            @ok
        )
    }
    fn copy(&mut self, src: impl Into<PathBuf>, dst: impl Into<PathBuf>) -> AsyncReturn<'_, ()> {
        make_body!(
            self,
            DistantRequestData::Copy { src: src.into(), dst: dst.into() },
            @ok
        )
    }
    fn create_dir(&mut self, path: impl Into<PathBuf>, all: bool) -> AsyncReturn<'_, ()> {
        make_body!(
            self,
            DistantRequestData::DirCreate { path: path.into(), all },
            @ok
        )
    }
    fn exists(&mut self, path: impl Into<PathBuf>) -> AsyncReturn<'_, bool> {
        make_body!(
            self,
            DistantRequestData::Exists { path: path.into() },
            |data| match data {
                DistantResponseData::Exists { value } => Ok(value),
                DistantResponseData::Error(x) => Err(io::Error::from(x)),
                _ => Err(mismatched_response()),
            }
        )
    }
    fn metadata(
        &mut self,
        path: impl Into<PathBuf>,
        canonicalize: bool,
        resolve_file_type: bool,
    ) -> AsyncReturn<'_, Metadata> {
        make_body!(
            self,
            DistantRequestData::Metadata {
                path: path.into(),
                canonicalize,
                resolve_file_type
            },
            |data| match data {
                DistantResponseData::Metadata(x) => Ok(x),
                DistantResponseData::Error(x) => Err(io::Error::from(x)),
                _ => Err(mismatched_response()),
            }
        )
    }
    fn read_dir(
        &mut self,
        path: impl Into<PathBuf>,
        depth: usize,
        absolute: bool,
        canonicalize: bool,
        include_root: bool,
    ) -> AsyncReturn<'_, (Vec<DirEntry>, Vec<Failure>)> {
        make_body!(
            self,
            DistantRequestData::DirRead {
                path: path.into(),
                depth,
                absolute,
                canonicalize,
                include_root
            },
            |data| match data {
                DistantResponseData::DirEntries { entries, errors } => Ok((entries, errors)),
                DistantResponseData::Error(x) => Err(io::Error::from(x)),
                _ => Err(mismatched_response()),
            }
        )
    }
    fn read_file(&mut self, path: impl Into<PathBuf>) -> AsyncReturn<'_, Vec<u8>> {
        make_body!(
            self,
            DistantRequestData::FileRead { path: path.into() },
            |data| match data {
                DistantResponseData::Blob { data } => Ok(data),
                DistantResponseData::Error(x) => Err(io::Error::from(x)),
                _ => Err(mismatched_response()),
            }
        )
    }
    fn read_file_text(&mut self, path: impl Into<PathBuf>) -> AsyncReturn<'_, String> {
        make_body!(
            self,
            DistantRequestData::FileReadText { path: path.into() },
            |data| match data {
                DistantResponseData::Text { data } => Ok(data),
                DistantResponseData::Error(x) => Err(io::Error::from(x)),
                _ => Err(mismatched_response()),
            }
        )
    }
    fn remove(&mut self, path: impl Into<PathBuf>, force: bool) -> AsyncReturn<'_, ()> {
        make_body!(
            self,
            DistantRequestData::Remove { path: path.into(), force },
            @ok
        )
    }
    fn rename(&mut self, src: impl Into<PathBuf>, dst: impl Into<PathBuf>) -> AsyncReturn<'_, ()> {
        make_body!(
            self,
            DistantRequestData::Rename { src: src.into(), dst: dst.into() },
            @ok
        )
    }
    fn watch(
        &mut self,
        path: impl Into<PathBuf>,
        recursive: bool,
        only: impl Into<ChangeKindSet>,
        except: impl Into<ChangeKindSet>,
    ) -> AsyncReturn<'_, Watcher> {
        let path = path.into();
        let only = only.into();
        let except = except.into();
        // Delegates to the Watcher type, which owns the long-lived subscription
        Box::pin(async move { Watcher::watch(self.clone(), path, recursive, only, except).await })
    }
    fn unwatch(&mut self, path: impl Into<PathBuf>) -> AsyncReturn<'_, ()> {
        // Inner fn so the path conversion happens before the channel is
        // borrowed inside the returned future
        fn inner_unwatch(
            channel: &mut Channel<DistantMsg<DistantRequestData>, DistantMsg<DistantResponseData>>,
            path: impl Into<PathBuf>,
        ) -> AsyncReturn<'_, ()> {
            make_body!(
                channel,
                DistantRequestData::Unwatch { path: path.into() },
                @ok
            )
        }
        let path = path.into();
        Box::pin(async move { inner_unwatch(self, path).await })
    }
    fn spawn(
        &mut self,
        cmd: impl Into<String>,
        environment: Environment,
        current_dir: Option<PathBuf>,
        persist: bool,
        pty: Option<PtySize>,
    ) -> AsyncReturn<'_, RemoteProcess> {
        let cmd = cmd.into();
        Box::pin(async move {
            RemoteCommand::new()
                .environment(environment)
                .current_dir(current_dir)
                .persist(persist)
                .pty(pty)
                .spawn(self.clone(), cmd)
                .await
        })
    }
    fn spawn_lsp(
        &mut self,
        cmd: impl Into<String>,
        environment: Environment,
        current_dir: Option<PathBuf>,
        persist: bool,
        pty: Option<PtySize>,
    ) -> AsyncReturn<'_, RemoteLspProcess> {
        let cmd = cmd.into();
        Box::pin(async move {
            RemoteLspCommand::new()
                .environment(environment)
                .current_dir(current_dir)
                .persist(persist)
                .pty(pty)
                .spawn(self.clone(), cmd)
                .await
        })
    }
    fn output(
        &mut self,
        cmd: impl Into<String>,
        environment: Environment,
        current_dir: Option<PathBuf>,
        pty: Option<PtySize>,
    ) -> AsyncReturn<'_, RemoteOutput> {
        let cmd = cmd.into();
        Box::pin(async move {
            // persist(false) so the process is tied to this connection's lifetime
            RemoteCommand::new()
                .environment(environment)
                .current_dir(current_dir)
                .persist(false)
                .pty(pty)
                .spawn(self.clone(), cmd)
                .await?
                .output()
                .await
        })
    }
    fn system_info(&mut self) -> AsyncReturn<'_, SystemInfo> {
        make_body!(self, DistantRequestData::SystemInfo {}, |data| match data {
            DistantResponseData::SystemInfo(x) => Ok(x),
            DistantResponseData::Error(x) => Err(io::Error::from(x)),
            _ => Err(mismatched_response()),
        })
    }
    fn write_file(
        &mut self,
        path: impl Into<PathBuf>,
        data: impl Into<Vec<u8>>,
    ) -> AsyncReturn<'_, ()> {
        make_body!(
            self,
            DistantRequestData::FileWrite { path: path.into(), data: data.into() },
            @ok
        )
    }
    fn write_file_text(
        &mut self,
        path: impl Into<PathBuf>,
        data: impl Into<String>,
    ) -> AsyncReturn<'_, ()> {
        make_body!(
            self,
            DistantRequestData::FileWriteText { path: path.into(), text: data.into() },
            @ok
        )
    }
}

@ -1,39 +1,90 @@
use super::{RemoteProcess, RemoteProcessError, RemoteStderr, RemoteStdin, RemoteStdout};
use crate::{client::SessionChannel, data::PtySize};
use crate::{
client::{
DistantChannel, RemoteCommand, RemoteProcess, RemoteStatus, RemoteStderr, RemoteStdin,
RemoteStdout,
},
data::{Environment, PtySize},
};
use futures::stream::{Stream, StreamExt};
use std::{
io::{self, Cursor, Read},
ops::{Deref, DerefMut},
path::PathBuf,
};
use tokio::{
sync::mpsc::{self, error::TryRecvError},
task::JoinHandle,
};
mod data;
pub use data::*;
mod msg;
pub use msg::*;
/// Represents an LSP server process on a remote machine
#[derive(Debug)]
pub struct RemoteLspProcess {
inner: RemoteProcess,
pub stdin: Option<RemoteLspStdin>,
pub stdout: Option<RemoteLspStdout>,
pub stderr: Option<RemoteLspStderr>,
/// A [`RemoteLspProcess`] builder providing support to configure
/// before spawning the process on a remote machine
pub struct RemoteLspCommand {
persist: bool,
pty: Option<PtySize>,
environment: Environment,
current_dir: Option<PathBuf>,
}
impl Default for RemoteLspCommand {
fn default() -> Self {
Self::new()
}
}
impl RemoteLspCommand {
/// Creates a new set of options for a remote LSP process
pub fn new() -> Self {
Self {
persist: false,
pty: None,
environment: Environment::new(),
current_dir: None,
}
}
/// Sets whether or not the process will be persistent,
/// meaning that it will not be terminated when the
/// connection to the remote machine is terminated
pub fn persist(&mut self, persist: bool) -> &mut Self {
self.persist = persist;
self
}
/// Configures the process to leverage a PTY with the specified size
pub fn pty(&mut self, pty: Option<PtySize>) -> &mut Self {
self.pty = pty;
self
}
/// Replaces the existing environment variables with the given collection
pub fn environment(&mut self, environment: Environment) -> &mut Self {
self.environment = environment;
self
}
/// Configures the process with an alternative current directory
pub fn current_dir(&mut self, current_dir: Option<PathBuf>) -> &mut Self {
self.current_dir = current_dir;
self
}
impl RemoteLspProcess {
/// Spawns the specified process on the remote machine using the given session, treating
/// the process like an LSP server
pub async fn spawn(
tenant: impl Into<String>,
channel: SessionChannel,
&mut self,
channel: DistantChannel,
cmd: impl Into<String>,
args: Vec<String>,
persist: bool,
pty: Option<PtySize>,
) -> Result<Self, RemoteProcessError> {
let mut inner = RemoteProcess::spawn(tenant, channel, cmd, args, persist, pty).await?;
) -> io::Result<RemoteLspProcess> {
let mut command = RemoteCommand::new();
command.environment(self.environment.clone());
command.current_dir(self.current_dir.clone());
command.persist(self.persist);
command.pty(self.pty);
let mut inner = command.spawn(channel, cmd).await?;
let stdin = inner.stdin.take().map(RemoteLspStdin::new);
let stdout = inner.stdout.take().map(RemoteLspStdout::new);
let stderr = inner.stderr.take().map(RemoteLspStderr::new);
@ -45,9 +96,20 @@ impl RemoteLspProcess {
stderr,
})
}
}
/// Represents an LSP server process on a remote machine
#[derive(Debug)]
pub struct RemoteLspProcess {
inner: RemoteProcess,
pub stdin: Option<RemoteLspStdin>,
pub stdout: Option<RemoteLspStdout>,
pub stderr: Option<RemoteLspStderr>,
}
impl RemoteLspProcess {
/// Waits for the process to terminate, returning the success status and an optional exit code
pub async fn wait(self) -> Result<(bool, Option<i32>), RemoteProcessError> {
pub async fn wait(self) -> io::Result<RemoteStatus> {
self.inner.wait().await
}
}
@ -116,7 +178,7 @@ impl RemoteLspStdin {
self.write(data.as_bytes()).await
}
fn update_and_read_messages(&mut self, data: &[u8]) -> io::Result<Vec<LspData>> {
fn update_and_read_messages(&mut self, data: &[u8]) -> io::Result<Vec<LspMsg>> {
// Create or insert into our buffer
match &mut self.buf {
Some(buf) => buf.extend(data),
@ -318,7 +380,7 @@ where
(read_task, rx)
}
fn read_lsp_messages(input: &[u8]) -> io::Result<(Option<Vec<u8>>, Vec<LspData>)> {
fn read_lsp_messages(input: &[u8]) -> io::Result<(Option<Vec<u8>>, Vec<LspMsg>)> {
let mut queue = Vec::new();
// Continue to read complete messages from the input until we either fail to parse or we reach
@ -326,7 +388,7 @@ fn read_lsp_messages(input: &[u8]) -> io::Result<(Option<Vec<u8>>, Vec<LspData>)
// cursor may have moved partially from lsp successfully reading the start of a message
let mut cursor = Cursor::new(input);
let mut pos = 0;
while let Ok(data) = LspData::from_buf_reader(&mut cursor) {
while let Ok(data) = LspMsg::from_buf_reader(&mut cursor) {
queue.push(data);
pos = cursor.position();
}
@ -347,10 +409,10 @@ fn read_lsp_messages(input: &[u8]) -> io::Result<(Option<Vec<u8>>, Vec<LspData>)
#[cfg(test)]
mod tests {
use super::*;
use crate::{
client::Session,
data::{Request, RequestData, Response, ResponseData},
net::{InmemoryStream, PlainCodec, Transport},
use crate::data::{DistantRequestData, DistantResponseData};
use distant_net::{
Client, FramedTransport, InmemoryTransport, IntoSplit, PlainCodec, Request, Response,
TypedAsyncRead, TypedAsyncWrite,
};
use std::{future::Future, time::Duration};
@ -358,29 +420,26 @@ mod tests {
const TIMEOUT: Duration = Duration::from_millis(50);
// Configures an lsp process with a means to send & receive data from outside
async fn spawn_lsp_process() -> (Transport<InmemoryStream, PlainCodec>, RemoteLspProcess) {
let (mut t1, t2) = Transport::make_pair();
let session = Session::initialize(t2).unwrap();
async fn spawn_lsp_process() -> (
FramedTransport<InmemoryTransport, PlainCodec>,
RemoteLspProcess,
) {
let (mut t1, t2) = FramedTransport::pair(100);
let (writer, reader) = t2.into_split();
let session = Client::new(writer, reader).unwrap();
let spawn_task = tokio::spawn(async move {
RemoteLspProcess::spawn(
String::from("test-tenant"),
session.clone_channel(),
String::from("cmd"),
vec![String::from("arg")],
false,
None,
)
RemoteLspCommand::new()
.spawn(session.clone_channel(), String::from("cmd arg"))
.await
});
// Wait until we get the request from the session
let req = t1.receive::<Request>().await.unwrap().unwrap();
let req: Request<DistantRequestData> = t1.read().await.unwrap().unwrap();
// Send back a response through the session
t1.send(Response::new(
"test-tenant",
t1.write(Response::new(
req.id,
vec![ResponseData::ProcSpawned { id: rand::random() }],
DistantResponseData::ProcSpawned { id: rand::random() },
))
.await
.unwrap();
@ -427,13 +486,12 @@ mod tests {
.unwrap();
// Validate that the outgoing req is a complete LSP message
let req = transport.receive::<Request>().await.unwrap().unwrap();
assert_eq!(req.payload.len(), 1, "Unexpected payload size");
match &req.payload[0] {
RequestData::ProcStdin { data, .. } => {
let req: Request<DistantRequestData> = transport.read().await.unwrap().unwrap();
match req.payload {
DistantRequestData::ProcStdin { data, .. } => {
assert_eq!(
data,
&make_lsp_msg(serde_json::json!({
make_lsp_msg(serde_json::json!({
"field1": "a",
"field2": "b",
}))
@ -460,20 +518,23 @@ mod tests {
// Verify that nothing has been sent out yet
// NOTE: Yield to ensure that data would be waiting at the transport if it was sent
tokio::task::yield_now().await;
let result = timeout(TIMEOUT, transport.receive::<Request>()).await;
let result = timeout(
TIMEOUT,
TypedAsyncRead::<Request<DistantRequestData>>::read(&mut transport),
)
.await;
assert!(result.is_err(), "Unexpectedly got data: {:?}", result);
// Write remainder of message
proc.stdin.as_mut().unwrap().write(msg_b).await.unwrap();
// Validate that the outgoing req is a complete LSP message
let req = transport.receive::<Request>().await.unwrap().unwrap();
assert_eq!(req.payload.len(), 1, "Unexpected payload size");
match &req.payload[0] {
RequestData::ProcStdin { data, .. } => {
let req: Request<DistantRequestData> = transport.read().await.unwrap().unwrap();
match req.payload {
DistantRequestData::ProcStdin { data, .. } => {
assert_eq!(
data,
&make_lsp_msg(serde_json::json!({
make_lsp_msg(serde_json::json!({
"field1": "a",
"field2": "b",
}))
@ -503,13 +564,12 @@ mod tests {
.unwrap();
// Validate that the outgoing req is a complete LSP message
let req = transport.receive::<Request>().await.unwrap().unwrap();
assert_eq!(req.payload.len(), 1, "Unexpected payload size");
match &req.payload[0] {
RequestData::ProcStdin { data, .. } => {
let req: Request<DistantRequestData> = transport.read().await.unwrap().unwrap();
match req.payload {
DistantRequestData::ProcStdin { data, .. } => {
assert_eq!(
data,
&make_lsp_msg(serde_json::json!({
make_lsp_msg(serde_json::json!({
"field1": "a",
"field2": "b",
}))
@ -553,13 +613,12 @@ mod tests {
.unwrap();
// Validate that the first outgoing req is a complete LSP message matching first
let req = transport.receive::<Request>().await.unwrap().unwrap();
assert_eq!(req.payload.len(), 1, "Unexpected payload size");
match &req.payload[0] {
RequestData::ProcStdin { data, .. } => {
let req: Request<DistantRequestData> = transport.read().await.unwrap().unwrap();
match req.payload {
DistantRequestData::ProcStdin { data, .. } => {
assert_eq!(
data,
&make_lsp_msg(serde_json::json!({
make_lsp_msg(serde_json::json!({
"field1": "a",
"field2": "b",
}))
@ -569,13 +628,12 @@ mod tests {
}
// Validate that the second outgoing req is a complete LSP message matching second
let req = transport.receive::<Request>().await.unwrap().unwrap();
assert_eq!(req.payload.len(), 1, "Unexpected payload size");
match &req.payload[0] {
RequestData::ProcStdin { data, .. } => {
let req: Request<DistantRequestData> = transport.read().await.unwrap().unwrap();
match req.payload {
DistantRequestData::ProcStdin { data, .. } => {
assert_eq!(
data,
&make_lsp_msg(serde_json::json!({
make_lsp_msg(serde_json::json!({
"field1": "c",
"field2": "d",
}))
@ -600,16 +658,15 @@ mod tests {
.unwrap();
// Validate that the outgoing req is a complete LSP message
let req = transport.receive::<Request>().await.unwrap().unwrap();
assert_eq!(req.payload.len(), 1, "Unexpected payload size");
match &req.payload[0] {
RequestData::ProcStdin { data, .. } => {
let req: Request<DistantRequestData> = transport.read().await.unwrap().unwrap();
match req.payload {
DistantRequestData::ProcStdin { data, .. } => {
// Verify the contents AND headers are as expected; in this case,
// this will also ensure that the Content-Length is adjusted
// when the distant scheme was changed to file
assert_eq!(
data,
&make_lsp_msg(serde_json::json!({
make_lsp_msg(serde_json::json!({
"field1": "file://some/path",
"field2": "file://other/path",
}))
@ -625,16 +682,15 @@ mod tests {
// Send complete LSP message as stdout to process
transport
.send(Response::new(
"test-tenant",
proc.origin_id(),
vec![ResponseData::ProcStdout {
.write(Response::new(
proc.origin_id().to_string(),
DistantResponseData::ProcStdout {
id: proc.id(),
data: make_lsp_msg(serde_json::json!({
"field1": "a",
"field2": "b",
})),
}],
},
))
.await
.unwrap();
@ -662,13 +718,12 @@ mod tests {
// Send half of LSP message over stdout
transport
.send(Response::new(
"test-tenant",
proc.origin_id(),
vec![ResponseData::ProcStdout {
.write(Response::new(
proc.origin_id().to_string(),
DistantResponseData::ProcStdout {
id: proc.id(),
data: msg_a.to_vec(),
}],
},
))
.await
.unwrap();
@ -681,13 +736,12 @@ mod tests {
// Send other half of LSP message over stdout
transport
.send(Response::new(
"test-tenant",
proc.origin_id(),
vec![ResponseData::ProcStdout {
.write(Response::new(
proc.origin_id().to_string(),
DistantResponseData::ProcStdout {
id: proc.id(),
data: msg_b.to_vec(),
}],
},
))
.await
.unwrap();
@ -716,13 +770,12 @@ mod tests {
// Send complete LSP message as stdout to process
transport
.send(Response::new(
"test-tenant",
proc.origin_id(),
vec![ResponseData::ProcStdout {
.write(Response::new(
proc.origin_id().to_string(),
DistantResponseData::ProcStdout {
id: proc.id(),
data: format!("{}{}", String::from_utf8(msg).unwrap(), extra).into_bytes(),
}],
},
))
.await
.unwrap();
@ -760,10 +813,9 @@ mod tests {
// Send complete LSP message as stdout to process
transport
.send(Response::new(
"test-tenant",
proc.origin_id(),
vec![ResponseData::ProcStdout {
.write(Response::new(
proc.origin_id().to_string(),
DistantResponseData::ProcStdout {
id: proc.id(),
data: format!(
"{}{}",
@ -771,7 +823,7 @@ mod tests {
String::from_utf8(msg_2).unwrap()
)
.into_bytes(),
}],
},
))
.await
.unwrap();
@ -803,16 +855,15 @@ mod tests {
// Send complete LSP message as stdout to process
transport
.send(Response::new(
"test-tenant",
proc.origin_id(),
vec![ResponseData::ProcStdout {
.write(Response::new(
proc.origin_id().to_string(),
DistantResponseData::ProcStdout {
id: proc.id(),
data: make_lsp_msg(serde_json::json!({
"field1": "distant://some/path",
"field2": "file://other/path",
})),
}],
},
))
.await
.unwrap();
@ -834,16 +885,15 @@ mod tests {
// Send complete LSP message as stderr to process
transport
.send(Response::new(
"test-tenant",
proc.origin_id(),
vec![ResponseData::ProcStderr {
.write(Response::new(
proc.origin_id().to_string(),
DistantResponseData::ProcStderr {
id: proc.id(),
data: make_lsp_msg(serde_json::json!({
"field1": "a",
"field2": "b",
})),
}],
},
))
.await
.unwrap();
@ -871,13 +921,12 @@ mod tests {
// Send half of LSP message over stderr
transport
.send(Response::new(
"test-tenant",
proc.origin_id(),
vec![ResponseData::ProcStderr {
.write(Response::new(
proc.origin_id().to_string(),
DistantResponseData::ProcStderr {
id: proc.id(),
data: msg_a.to_vec(),
}],
},
))
.await
.unwrap();
@ -890,13 +939,12 @@ mod tests {
// Send other half of LSP message over stderr
transport
.send(Response::new(
"test-tenant",
proc.origin_id(),
vec![ResponseData::ProcStderr {
.write(Response::new(
proc.origin_id().to_string(),
DistantResponseData::ProcStderr {
id: proc.id(),
data: msg_b.to_vec(),
}],
},
))
.await
.unwrap();
@ -925,13 +973,12 @@ mod tests {
// Send complete LSP message as stderr to process
transport
.send(Response::new(
"test-tenant",
proc.origin_id(),
vec![ResponseData::ProcStderr {
.write(Response::new(
proc.origin_id().to_string(),
DistantResponseData::ProcStderr {
id: proc.id(),
data: format!("{}{}", String::from_utf8(msg).unwrap(), extra).into_bytes(),
}],
},
))
.await
.unwrap();
@ -969,10 +1016,9 @@ mod tests {
// Send complete LSP message as stderr to process
transport
.send(Response::new(
"test-tenant",
proc.origin_id(),
vec![ResponseData::ProcStderr {
.write(Response::new(
proc.origin_id().to_string(),
DistantResponseData::ProcStderr {
id: proc.id(),
data: format!(
"{}{}",
@ -980,7 +1026,7 @@ mod tests {
String::from_utf8(msg_2).unwrap()
)
.into_bytes(),
}],
},
))
.await
.unwrap();
@ -1012,16 +1058,15 @@ mod tests {
// Send complete LSP message as stderr to process
transport
.send(Response::new(
"test-tenant",
proc.origin_id(),
vec![ResponseData::ProcStderr {
.write(Response::new(
proc.origin_id().to_string(),
DistantResponseData::ProcStderr {
id: proc.id(),
data: make_lsp_msg(serde_json::json!({
"field1": "distant://some/path",
"field2": "file://other/path",
})),
}],
},
))
.await
.unwrap();

@ -1,4 +1,3 @@
use crate::client::{SessionInfo, SessionInfoParseError};
use derive_more::{Display, Error, From};
use serde::{Deserialize, Serialize};
use serde_json::{Map, Value};
@ -10,33 +9,9 @@ use std::{
string::FromUtf8Error,
};
#[derive(Copy, Clone, Debug, PartialEq, Eq, Display, Error, From)]
pub enum LspSessionInfoError {
/// Encountered when attempting to create a session from a request that is not initialize
NotInitializeRequest,
/// Encountered if missing session parameters within an initialize request
MissingSessionInfoParams,
/// Encountered if session parameters are not expected types
InvalidSessionInfoParams,
/// Encountered when failing to parse session
SessionInfoParseError(SessionInfoParseError),
}
impl From<LspSessionInfoError> for io::Error {
fn from(x: LspSessionInfoError) -> Self {
match x {
LspSessionInfoError::SessionInfoParseError(x) => x.into(),
x => io::Error::new(io::ErrorKind::InvalidData, x),
}
}
}
/// Represents some data being communicated to/from an LSP consisting of a header and content part
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
pub struct LspData {
pub struct LspMsg {
/// Header-portion of some data related to LSP
header: LspHeader,
@ -45,7 +20,7 @@ pub struct LspData {
}
#[derive(Debug, Display, Error, From)]
pub enum LspDataParseError {
pub enum LspMsgParseError {
/// When the received content is malformed
BadContent(LspContentParseError),
@ -65,23 +40,23 @@ pub enum LspDataParseError {
UnexpectedEof,
}
impl From<LspDataParseError> for io::Error {
fn from(x: LspDataParseError) -> Self {
impl From<LspMsgParseError> for io::Error {
fn from(x: LspMsgParseError) -> Self {
match x {
LspDataParseError::BadContent(x) => x.into(),
LspDataParseError::BadHeader(x) => x.into(),
LspDataParseError::BadHeaderTermination => io::Error::new(
LspMsgParseError::BadContent(x) => x.into(),
LspMsgParseError::BadHeader(x) => x.into(),
LspMsgParseError::BadHeaderTermination => io::Error::new(
io::ErrorKind::InvalidData,
r"Received header line not terminated in \r\n",
),
LspDataParseError::BadInput(x) => io::Error::new(io::ErrorKind::InvalidData, x),
LspDataParseError::IoError(x) => x,
LspDataParseError::UnexpectedEof => io::Error::from(io::ErrorKind::UnexpectedEof),
LspMsgParseError::BadInput(x) => io::Error::new(io::ErrorKind::InvalidData, x),
LspMsgParseError::IoError(x) => x,
LspMsgParseError::UnexpectedEof => io::Error::from(io::ErrorKind::UnexpectedEof),
}
}
}
impl LspData {
impl LspMsg {
/// Returns a reference to the header part
pub fn header(&self) -> &LspHeader {
&self.header
@ -106,19 +81,6 @@ impl LspData {
self.header.content_length = self.content.to_string().len();
}
/// Creates a session's info by inspecting the content for session parameters, removing the
/// session parameters from the content. Will also adjust the content length header to match
/// the new size of the content.
pub fn take_session_info(&mut self) -> Result<SessionInfo, LspSessionInfoError> {
match self.content.take_session_info() {
Ok(session) => {
self.refresh_content_length();
Ok(session)
}
Err(x) => Err(x),
}
}
/// Attempts to read incoming lsp data from a buffered reader.
///
/// Note that this is **blocking** while it waits on the header information (or EOF)!
@ -132,7 +94,7 @@ impl LspData {
/// ...
/// }
/// ```
pub fn from_buf_reader<R: BufRead>(r: &mut R) -> Result<Self, LspDataParseError> {
pub fn from_buf_reader<R: BufRead>(r: &mut R) -> Result<Self, LspMsgParseError> {
// Read in our headers first so we can figure out how much more to read
let mut buf = String::new();
loop {
@ -145,14 +107,14 @@ impl LspData {
// We shouldn't be getting end of the reader yet
if len == 0 {
return Err(LspDataParseError::UnexpectedEof);
return Err(LspMsgParseError::UnexpectedEof);
}
let line = &buf[start..end];
// Check if we've gotten bad data
if !line.ends_with("\r\n") {
return Err(LspDataParseError::BadHeaderTermination);
return Err(LspMsgParseError::BadHeaderTermination);
// Check if we've received the header termination
} else if line == "\r\n" {
@ -168,9 +130,9 @@ impl LspData {
let mut buf = vec![0u8; header.content_length];
r.read_exact(&mut buf).map_err(|x| {
if x.kind() == io::ErrorKind::UnexpectedEof {
LspDataParseError::UnexpectedEof
LspMsgParseError::UnexpectedEof
} else {
LspDataParseError::IoError(x)
LspMsgParseError::IoError(x)
}
})?;
String::from_utf8(buf)?.parse::<LspContent>()?
@ -185,7 +147,7 @@ impl LspData {
}
}
impl fmt::Display for LspData {
impl fmt::Display for LspMsg {
/// Outputs header & content in form
///
/// ```text
@ -203,8 +165,8 @@ impl fmt::Display for LspData {
}
}
impl FromStr for LspData {
type Err = LspDataParseError;
impl FromStr for LspMsg {
type Err = LspMsgParseError;
/// Parses headers and content in the form of
///
@ -380,63 +342,6 @@ impl LspContent {
pub fn convert_distant_scheme_to_local(&mut self) {
swap_prefix(&mut self.0, "distant:", "file:");
}
/// Creates a session's info by inspecting the content for session parameters, removing the
/// session parameters from the content
pub fn take_session_info(&mut self) -> Result<SessionInfo, LspSessionInfoError> {
// Verify that we're dealing with an initialize request
match self.0.get("method") {
Some(value) if value == "initialize" => {}
_ => return Err(LspSessionInfoError::NotInitializeRequest),
}
// Attempt to grab the distant initialization options
match self.strip_session_params() {
Some((Some(host), Some(port), Some(key))) => {
let host = host
.as_str()
.ok_or(LspSessionInfoError::InvalidSessionInfoParams)?;
let port = port
.as_u64()
.ok_or(LspSessionInfoError::InvalidSessionInfoParams)?;
let key = key
.as_str()
.ok_or(LspSessionInfoError::InvalidSessionInfoParams)?;
Ok(format!("DISTANT CONNECT {} {} {}", host, port, key).parse()?)
}
_ => Err(LspSessionInfoError::MissingSessionInfoParams),
}
}
/// Strips the session params from the content, returning them if they exist
///
/// ```json
/// {
/// "params": {
/// "initializationOptions": {
/// "distant": {
/// "host": "...",
/// "port": ...,
/// "key": "..."
/// }
/// }
/// }
/// }
/// ```
fn strip_session_params(&mut self) -> Option<(Option<Value>, Option<Value>, Option<Value>)> {
self.0
.get_mut("params")
.and_then(|v| v.get_mut("initializationOptions"))
.and_then(|v| v.as_object_mut())
.and_then(|o| o.remove("distant"))
.map(|mut v| {
(
v.get_mut("host").map(Value::take),
v.get_mut("port").map(Value::take),
v.get_mut("key").map(Value::take),
)
})
}
}
impl AsRef<Map<String, Value>> for LspContent {
@ -491,10 +396,6 @@ impl FromStr for LspContent {
#[cfg(test)]
mod tests {
use super::*;
use crate::net::SecretKey;
// 32-byte test hex key (64 hex characters)
const TEST_HEX_KEY: &str = "ABCDEFABCDEFABCDEFABCDEFABCDEFABCDEFABCDEFABCDEFABCDEFABCDEFABCD";
macro_rules! make_obj {
($($tail:tt)*) => {
@ -506,8 +407,8 @@ mod tests {
}
#[test]
fn data_display_should_output_header_and_content() {
let data = LspData {
fn msg_display_should_output_header_and_content() {
let msg = LspMsg {
header: LspHeader {
content_length: 123,
content_type: Some(String::from("some content type")),
@ -515,7 +416,7 @@ mod tests {
content: LspContent(make_obj!({"hello": "world"})),
};
let output = data.to_string();
let output = msg.to_string();
assert_eq!(
output,
concat!(
@ -530,7 +431,7 @@ mod tests {
}
#[test]
fn data_from_buf_reader_should_be_successful_if_valid_data_received() {
fn msg_from_buf_reader_should_be_successful_if_valid_msg_received() {
let mut input = io::Cursor::new(concat!(
"Content-Length: 22\r\n",
"Content-Type: some content type\r\n",
@ -539,45 +440,45 @@ mod tests {
" \"hello\": \"world\"\n",
"}",
));
let data = LspData::from_buf_reader(&mut input).unwrap();
assert_eq!(data.header.content_length, 22);
let msg = LspMsg::from_buf_reader(&mut input).unwrap();
assert_eq!(msg.header.content_length, 22);
assert_eq!(
data.header.content_type.as_deref(),
msg.header.content_type.as_deref(),
Some("some content type")
);
assert_eq!(data.content.as_ref(), &make_obj!({ "hello": "world" }));
assert_eq!(msg.content.as_ref(), &make_obj!({ "hello": "world" }));
}
#[test]
fn data_from_buf_reader_should_fail_if_reach_eof_before_received_full_data() {
fn msg_from_buf_reader_should_fail_if_reach_eof_before_received_full_msg() {
// No line termination
let err = LspData::from_buf_reader(&mut io::Cursor::new("Content-Length: 22")).unwrap_err();
let err = LspMsg::from_buf_reader(&mut io::Cursor::new("Content-Length: 22")).unwrap_err();
assert!(
matches!(err, LspDataParseError::BadHeaderTermination),
matches!(err, LspMsgParseError::BadHeaderTermination),
"{:?}",
err
);
// Header doesn't finish
let err = LspData::from_buf_reader(&mut io::Cursor::new(concat!(
let err = LspMsg::from_buf_reader(&mut io::Cursor::new(concat!(
"Content-Length: 22\r\n",
"Content-Type: some content type\r\n",
)))
.unwrap_err();
assert!(matches!(err, LspDataParseError::UnexpectedEof), "{:?}", err);
assert!(matches!(err, LspMsgParseError::UnexpectedEof), "{:?}", err);
// No content after header
let err = LspData::from_buf_reader(&mut io::Cursor::new(concat!(
let err = LspMsg::from_buf_reader(&mut io::Cursor::new(concat!(
"Content-Length: 22\r\n",
"\r\n",
)))
.unwrap_err();
assert!(matches!(err, LspDataParseError::UnexpectedEof), "{:?}", err);
assert!(matches!(err, LspMsgParseError::UnexpectedEof), "{:?}", err);
}
#[test]
fn data_from_buf_reader_should_fail_if_missing_proper_line_termination_for_header_field() {
let err = LspData::from_buf_reader(&mut io::Cursor::new(concat!(
fn msg_from_buf_reader_should_fail_if_missing_proper_line_termination_for_header_field() {
let err = LspMsg::from_buf_reader(&mut io::Cursor::new(concat!(
"Content-Length: 22\n",
"\r\n",
"{\n",
@ -586,16 +487,16 @@ mod tests {
)))
.unwrap_err();
assert!(
matches!(err, LspDataParseError::BadHeaderTermination),
matches!(err, LspMsgParseError::BadHeaderTermination),
"{:?}",
err
);
}
#[test]
fn data_from_buf_reader_should_fail_if_bad_header_provided() {
fn msg_from_buf_reader_should_fail_if_bad_header_provided() {
// Invalid content length
let err = LspData::from_buf_reader(&mut io::Cursor::new(concat!(
let err = LspMsg::from_buf_reader(&mut io::Cursor::new(concat!(
"Content-Length: -1\r\n",
"\r\n",
"{\n",
@ -603,10 +504,10 @@ mod tests {
"}",
)))
.unwrap_err();
assert!(matches!(err, LspDataParseError::BadHeader(_)), "{:?}", err);
assert!(matches!(err, LspMsgParseError::BadHeader(_)), "{:?}", err);
// Missing content length
let err = LspData::from_buf_reader(&mut io::Cursor::new(concat!(
let err = LspMsg::from_buf_reader(&mut io::Cursor::new(concat!(
"Content-Type: some content type\r\n",
"\r\n",
"{\n",
@ -614,371 +515,30 @@ mod tests {
"}",
)))
.unwrap_err();
assert!(matches!(err, LspDataParseError::BadHeader(_)), "{:?}", err);
assert!(matches!(err, LspMsgParseError::BadHeader(_)), "{:?}", err);
}
#[test]
fn data_from_buf_reader_should_fail_if_bad_content_provided() {
fn msg_from_buf_reader_should_fail_if_bad_content_provided() {
// Not full content
let err = LspData::from_buf_reader(&mut io::Cursor::new(concat!(
let err = LspMsg::from_buf_reader(&mut io::Cursor::new(concat!(
"Content-Length: 21\r\n",
"\r\n",
"{\n",
" \"hello\": \"world\"\n",
)))
.unwrap_err();
assert!(matches!(err, LspDataParseError::BadContent(_)), "{:?}", err);
assert!(matches!(err, LspMsgParseError::BadContent(_)), "{:?}", err);
}
#[test]
fn data_from_buf_reader_should_fail_if_non_utf8_data_encountered_for_content() {
fn msg_from_buf_reader_should_fail_if_non_utf8_msg_encountered_for_content() {
// Not utf-8 content
let mut raw = b"Content-Length: 2\r\n\r\n".to_vec();
raw.extend(vec![0, 159]);
let err = LspData::from_buf_reader(&mut io::Cursor::new(raw)).unwrap_err();
assert!(matches!(err, LspDataParseError::BadInput(_)), "{:?}", err);
}
#[test]
fn data_take_session_info_should_succeed_if_valid_session_found_in_params() {
let mut data = LspData {
header: LspHeader {
content_length: 123,
content_type: Some(String::from("some content type")),
},
content: LspContent(make_obj!({
"method": "initialize",
"params": {
"initializationOptions": {
"distant": {
"host": "some.host",
"port": 22,
"key": TEST_HEX_KEY
}
}
}
})),
};
let info = data.take_session_info().unwrap();
assert_eq!(
info,
SessionInfo {
host: String::from("some.host"),
port: 22,
key: SecretKey::from_slice(&hex::decode(TEST_HEX_KEY).unwrap()).unwrap(),
}
);
}
#[test]
fn data_take_session_info_should_remove_session_parameters_if_successful() {
let mut data = LspData {
header: LspHeader {
content_length: 123,
content_type: Some(String::from("some content type")),
},
content: LspContent(make_obj!({
"method": "initialize",
"params": {
"initializationOptions": {
"distant": {
"host": "some.host",
"port": 22,
"key": TEST_HEX_KEY
}
}
}
})),
};
let _ = data.take_session_info().unwrap();
assert_eq!(
data.content.as_ref(),
&make_obj!({
"method": "initialize",
"params": {
"initializationOptions": {}
}
})
);
}
#[test]
fn data_take_session_info_should_adjust_content_length_based_on_new_content_byte_length() {
let mut data = LspData {
header: LspHeader {
content_length: 123456,
content_type: Some(String::from("some content type")),
},
content: LspContent(make_obj!({
"method": "initialize",
"params": {
"initializationOptions": {
"distant": {
"host": "some.host",
"port": 22,
"key": TEST_HEX_KEY
}
}
}
})),
};
let _ = data.take_session_info().unwrap();
assert_eq!(data.header.content_length, data.content.to_string().len());
}
#[test]
fn data_take_session_info_should_fail_if_path_incomplete_to_session_params() {
let mut data = LspData {
header: LspHeader {
content_length: 123456,
content_type: Some(String::from("some content type")),
},
content: LspContent(make_obj!({
"method": "initialize",
"params": {
"initializationOptions": {}
}
})),
};
let err = data.take_session_info().unwrap_err();
assert!(
matches!(err, LspSessionInfoError::MissingSessionInfoParams),
"{:?}",
err
);
}
#[test]
fn data_take_session_info_should_fail_if_missing_host_param() {
let mut data = LspData {
header: LspHeader {
content_length: 123456,
content_type: Some(String::from("some content type")),
},
content: LspContent(make_obj!({
"method": "initialize",
"params": {
"initializationOptions": {
"distant": {
"port": 22,
"key": TEST_HEX_KEY
}
}
}
})),
};
let err = data.take_session_info().unwrap_err();
assert!(
matches!(err, LspSessionInfoError::MissingSessionInfoParams),
"{:?}",
err
);
}
#[test]
fn data_take_session_info_should_fail_if_host_param_is_invalid() {
let mut data = LspData {
header: LspHeader {
content_length: 123456,
content_type: Some(String::from("some content type")),
},
content: LspContent(make_obj!({
"method": "initialize",
"params": {
"initializationOptions": {
"distant": {
"host": 1234,
"port": 22,
"key": TEST_HEX_KEY
}
}
}
})),
};
let err = data.take_session_info().unwrap_err();
assert!(
matches!(err, LspSessionInfoError::InvalidSessionInfoParams),
"{:?}",
err
);
}
#[test]
fn data_take_session_info_should_fail_if_missing_port_param() {
let mut data = LspData {
header: LspHeader {
content_length: 123456,
content_type: Some(String::from("some content type")),
},
content: LspContent(make_obj!({
"method": "initialize",
"params": {
"initializationOptions": {
"distant": {
"host": "some.host",
"key": TEST_HEX_KEY
}
}
}
})),
};
let err = data.take_session_info().unwrap_err();
assert!(
matches!(err, LspSessionInfoError::MissingSessionInfoParams),
"{:?}",
err
);
}
#[test]
fn data_take_session_info_should_fail_if_port_param_is_invalid() {
let mut data = LspData {
header: LspHeader {
content_length: 123456,
content_type: Some(String::from("some content type")),
},
content: LspContent(make_obj!({
"method": "initialize",
"params": {
"initializationOptions": {
"distant": {
"host": "some.host",
"port": "abcd",
"key": TEST_HEX_KEY
}
}
}
})),
};
let err = data.take_session_info().unwrap_err();
assert!(
matches!(err, LspSessionInfoError::InvalidSessionInfoParams),
"{:?}",
err
);
}
#[test]
fn data_take_session_info_should_fail_if_missing_key_param() {
let mut data = LspData {
header: LspHeader {
content_length: 123456,
content_type: Some(String::from("some content type")),
},
content: LspContent(make_obj!({
"method": "initialize",
"params": {
"initializationOptions": {
"distant": {
"host": "some.host",
"port": 22,
}
}
}
})),
};
let err = data.take_session_info().unwrap_err();
assert!(
matches!(err, LspSessionInfoError::MissingSessionInfoParams),
"{:?}",
err
);
}
#[test]
fn data_take_session_info_should_fail_if_key_param_is_invalid() {
let mut data = LspData {
header: LspHeader {
content_length: 123456,
content_type: Some(String::from("some content type")),
},
content: LspContent(make_obj!({
"method": "initialize",
"params": {
"initializationOptions": {
"distant": {
"host": "some.host",
"port": 22,
"key": 1234,
}
}
}
})),
};
let err = data.take_session_info().unwrap_err();
assert!(
matches!(err, LspSessionInfoError::InvalidSessionInfoParams),
"{:?}",
err
);
}
#[test]
fn data_take_session_info_should_fail_if_missing_method_field() {
let mut data = LspData {
header: LspHeader {
content_length: 123456,
content_type: Some(String::from("some content type")),
},
content: LspContent(make_obj!({
"params": {
"initializationOptions": {
"distant": {
"host": "some.host",
"port": 22,
"key": TEST_HEX_KEY,
}
}
}
})),
};
let err = data.take_session_info().unwrap_err();
assert!(
matches!(err, LspSessionInfoError::NotInitializeRequest),
"{:?}",
err
);
}
#[test]
fn data_take_session_info_should_fail_if_method_field_is_not_initialize() {
let mut data = LspData {
header: LspHeader {
content_length: 123456,
content_type: Some(String::from("some content type")),
},
content: LspContent(make_obj!({
"method": "not initialize",
"params": {
"initializationOptions": {
"distant": {
"host": "some.host",
"port": 22,
"key": TEST_HEX_KEY,
}
}
}
})),
};
let err = data.take_session_info().unwrap_err();
assert!(
matches!(err, LspSessionInfoError::NotInitializeRequest),
"{:?}",
err
);
let err = LspMsg::from_buf_reader(&mut io::Cursor::new(raw)).unwrap_err();
assert!(matches!(err, LspMsgParseError::BadInput(_)), "{:?}", err);
}
#[test]

@ -1,10 +0,0 @@
mod lsp;
mod process;
mod session;
mod utils;
mod watcher;
pub use lsp::*;
pub use process::*;
pub use session::*;
pub use watcher::*;

File diff suppressed because it is too large Load Diff

@ -1,514 +0,0 @@
use crate::{
client::{
RemoteLspProcess, RemoteProcess, RemoteProcessError, SessionChannel, UnwatchError,
WatchError, Watcher,
},
data::{
ChangeKindSet, DirEntry, Error as Failure, Metadata, PtySize, Request, RequestData,
ResponseData, SystemInfo,
},
net::TransportError,
};
use derive_more::{Display, Error, From};
use std::{future::Future, path::PathBuf, pin::Pin};
/// Represents an error that can occur related to convenience functions tied to a
/// [`SessionChannel`] through [`SessionChannelExt`]
#[derive(Debug, Display, Error, From)]
pub enum SessionChannelExtError {
/// Occurs when the remote action fails
Failure(#[error(not(source))] Failure),
/// Occurs when a transport error is encountered
TransportError(TransportError),
/// Occurs when receiving a response that was not expected
MismatchedResponse,
}
pub type AsyncReturn<'a, T, E = SessionChannelExtError> =
Pin<Box<dyn Future<Output = Result<T, E>> + Send + 'a>>;
/// Provides convenience functions on top of a [`SessionChannel`]
pub trait SessionChannelExt {
/// Appends to a remote file using the data from a collection of bytes
fn append_file(
&mut self,
tenant: impl Into<String>,
path: impl Into<PathBuf>,
data: impl Into<Vec<u8>>,
) -> AsyncReturn<'_, ()>;
/// Appends to a remote file using the data from a string
fn append_file_text(
&mut self,
tenant: impl Into<String>,
path: impl Into<PathBuf>,
data: impl Into<String>,
) -> AsyncReturn<'_, ()>;
/// Copies a remote file or directory from src to dst
fn copy(
&mut self,
tenant: impl Into<String>,
src: impl Into<PathBuf>,
dst: impl Into<PathBuf>,
) -> AsyncReturn<'_, ()>;
/// Creates a remote directory, optionally creating all parent components if specified
fn create_dir(
&mut self,
tenant: impl Into<String>,
path: impl Into<PathBuf>,
all: bool,
) -> AsyncReturn<'_, ()>;
/// Checks if a path exists on a remote machine
fn exists(
&mut self,
tenant: impl Into<String>,
path: impl Into<PathBuf>,
) -> AsyncReturn<'_, bool>;
/// Retrieves metadata about a path on a remote machine
fn metadata(
&mut self,
tenant: impl Into<String>,
path: impl Into<PathBuf>,
canonicalize: bool,
resolve_file_type: bool,
) -> AsyncReturn<'_, Metadata>;
/// Reads entries from a directory, returning a tuple of directory entries and failures
fn read_dir(
&mut self,
tenant: impl Into<String>,
path: impl Into<PathBuf>,
depth: usize,
absolute: bool,
canonicalize: bool,
include_root: bool,
) -> AsyncReturn<'_, (Vec<DirEntry>, Vec<Failure>)>;
/// Reads a remote file as a collection of bytes
fn read_file(
&mut self,
tenant: impl Into<String>,
path: impl Into<PathBuf>,
) -> AsyncReturn<'_, Vec<u8>>;
/// Returns a remote file as a string
fn read_file_text(
&mut self,
tenant: impl Into<String>,
path: impl Into<PathBuf>,
) -> AsyncReturn<'_, String>;
/// Removes a remote file or directory, supporting removal of non-empty directories if
/// force is true
fn remove(
&mut self,
tenant: impl Into<String>,
path: impl Into<PathBuf>,
force: bool,
) -> AsyncReturn<'_, ()>;
/// Renames a remote file or directory from src to dst
fn rename(
&mut self,
tenant: impl Into<String>,
src: impl Into<PathBuf>,
dst: impl Into<PathBuf>,
) -> AsyncReturn<'_, ()>;
/// Watches a remote file or directory
fn watch(
&mut self,
tenant: impl Into<String>,
path: impl Into<PathBuf>,
recursive: bool,
only: impl Into<ChangeKindSet>,
except: impl Into<ChangeKindSet>,
) -> AsyncReturn<'_, Watcher, WatchError>;
/// Unwatches a remote file or directory
fn unwatch(
&mut self,
tenant: impl Into<String>,
path: impl Into<PathBuf>,
) -> AsyncReturn<'_, (), UnwatchError>;
/// Spawns a process on the remote machine
fn spawn(
&mut self,
tenant: impl Into<String>,
cmd: impl Into<String>,
args: Vec<impl Into<String>>,
persist: bool,
pty: Option<PtySize>,
) -> AsyncReturn<'_, RemoteProcess, RemoteProcessError>;
/// Spawns an LSP process on the remote machine
fn spawn_lsp(
&mut self,
tenant: impl Into<String>,
cmd: impl Into<String>,
args: Vec<impl Into<String>>,
persist: bool,
pty: Option<PtySize>,
) -> AsyncReturn<'_, RemoteLspProcess, RemoteProcessError>;
/// Retrieves information about the remote system
fn system_info(&mut self, tenant: impl Into<String>) -> AsyncReturn<'_, SystemInfo>;
/// Writes a remote file with the data from a collection of bytes
fn write_file(
&mut self,
tenant: impl Into<String>,
path: impl Into<PathBuf>,
data: impl Into<Vec<u8>>,
) -> AsyncReturn<'_, ()>;
/// Writes a remote file with the data from a string
fn write_file_text(
&mut self,
tenant: impl Into<String>,
path: impl Into<PathBuf>,
data: impl Into<String>,
) -> AsyncReturn<'_, ()>;
}
macro_rules! make_body {
($self:expr, $tenant:expr, $data:expr, @ok) => {
make_body!($self, $tenant, $data, |data| {
match data {
ResponseData::Ok => Ok(()),
ResponseData::Error(x) => Err(SessionChannelExtError::Failure(x)),
_ => Err(SessionChannelExtError::MismatchedResponse),
}
})
};
($self:expr, $tenant:expr, $data:expr, $and_then:expr) => {{
let req = Request::new($tenant, vec![$data]);
Box::pin(async move {
$self
.send(req)
.await
.map_err(SessionChannelExtError::from)
.and_then(|res| {
if res.payload.len() == 1 {
Ok(res.payload.into_iter().next().unwrap())
} else {
Err(SessionChannelExtError::MismatchedResponse)
}
})
.and_then($and_then)
})
}};
}
impl SessionChannelExt for SessionChannel {
fn append_file(
&mut self,
tenant: impl Into<String>,
path: impl Into<PathBuf>,
data: impl Into<Vec<u8>>,
) -> AsyncReturn<'_, ()> {
make_body!(
self,
tenant,
RequestData::FileAppend { path: path.into(), data: data.into() },
@ok
)
}
fn append_file_text(
&mut self,
tenant: impl Into<String>,
path: impl Into<PathBuf>,
data: impl Into<String>,
) -> AsyncReturn<'_, ()> {
make_body!(
self,
tenant,
RequestData::FileAppendText { path: path.into(), text: data.into() },
@ok
)
}
fn copy(
&mut self,
tenant: impl Into<String>,
src: impl Into<PathBuf>,
dst: impl Into<PathBuf>,
) -> AsyncReturn<'_, ()> {
make_body!(
self,
tenant,
RequestData::Copy { src: src.into(), dst: dst.into() },
@ok
)
}
fn create_dir(
&mut self,
tenant: impl Into<String>,
path: impl Into<PathBuf>,
all: bool,
) -> AsyncReturn<'_, ()> {
make_body!(
self,
tenant,
RequestData::DirCreate { path: path.into(), all },
@ok
)
}
fn exists(
&mut self,
tenant: impl Into<String>,
path: impl Into<PathBuf>,
) -> AsyncReturn<'_, bool> {
make_body!(
self,
tenant,
RequestData::Exists { path: path.into() },
|data| match data {
ResponseData::Exists { value } => Ok(value),
ResponseData::Error(x) => Err(SessionChannelExtError::Failure(x)),
_ => Err(SessionChannelExtError::MismatchedResponse),
}
)
}
fn metadata(
&mut self,
tenant: impl Into<String>,
path: impl Into<PathBuf>,
canonicalize: bool,
resolve_file_type: bool,
) -> AsyncReturn<'_, Metadata> {
make_body!(
self,
tenant,
RequestData::Metadata {
path: path.into(),
canonicalize,
resolve_file_type
},
|data| match data {
ResponseData::Metadata(x) => Ok(x),
ResponseData::Error(x) => Err(SessionChannelExtError::Failure(x)),
_ => Err(SessionChannelExtError::MismatchedResponse),
}
)
}
fn read_dir(
&mut self,
tenant: impl Into<String>,
path: impl Into<PathBuf>,
depth: usize,
absolute: bool,
canonicalize: bool,
include_root: bool,
) -> AsyncReturn<'_, (Vec<DirEntry>, Vec<Failure>)> {
make_body!(
self,
tenant,
RequestData::DirRead {
path: path.into(),
depth,
absolute,
canonicalize,
include_root
},
|data| match data {
ResponseData::DirEntries { entries, errors } => Ok((entries, errors)),
ResponseData::Error(x) => Err(SessionChannelExtError::Failure(x)),
_ => Err(SessionChannelExtError::MismatchedResponse),
}
)
}
fn read_file(
&mut self,
tenant: impl Into<String>,
path: impl Into<PathBuf>,
) -> AsyncReturn<'_, Vec<u8>> {
make_body!(
self,
tenant,
RequestData::FileRead { path: path.into() },
|data| match data {
ResponseData::Blob { data } => Ok(data),
ResponseData::Error(x) => Err(SessionChannelExtError::Failure(x)),
_ => Err(SessionChannelExtError::MismatchedResponse),
}
)
}
fn read_file_text(
&mut self,
tenant: impl Into<String>,
path: impl Into<PathBuf>,
) -> AsyncReturn<'_, String> {
make_body!(
self,
tenant,
RequestData::FileReadText { path: path.into() },
|data| match data {
ResponseData::Text { data } => Ok(data),
ResponseData::Error(x) => Err(SessionChannelExtError::Failure(x)),
_ => Err(SessionChannelExtError::MismatchedResponse),
}
)
}
/// Removes the remote file or directory at `path`.
///
/// `force` is forwarded verbatim — presumably it permits removal of
/// non-empty directories; confirm against the server-side handler.
fn remove(
&mut self,
tenant: impl Into<String>,
path: impl Into<PathBuf>,
force: bool,
) -> AsyncReturn<'_, ()> {
make_body!(
self,
tenant,
RequestData::Remove { path: path.into(), force },
@ok
)
}
/// Renames (moves) the remote file or directory at `src` to `dst`
fn rename(
&mut self,
tenant: impl Into<String>,
src: impl Into<PathBuf>,
dst: impl Into<PathBuf>,
) -> AsyncReturn<'_, ()> {
make_body!(
self,
tenant,
RequestData::Rename { src: src.into(), dst: dst.into() },
@ok
)
}
/// Begins watching the remote `path` for changes, yielding a [`Watcher`]
/// through which change events can be received
fn watch(
&mut self,
tenant: impl Into<String>,
path: impl Into<PathBuf>,
recursive: bool,
only: impl Into<ChangeKindSet>,
except: impl Into<ChangeKindSet>,
) -> AsyncReturn<'_, Watcher, WatchError> {
// Perform the Into conversions before constructing the boxed future so
// the future captures concrete values rather than the generic arguments
let tenant = tenant.into();
let path = path.into();
let only = only.into();
let except = except.into();
Box::pin(async move {
Watcher::watch(tenant, self.clone(), path, recursive, only, except).await
})
}
/// Stops watching the remote `path`, undoing a prior `watch` registration
fn unwatch(
&mut self,
tenant: impl Into<String>,
path: impl Into<PathBuf>,
) -> AsyncReturn<'_, (), UnwatchError> {
// Inner helper lets make_body! expand with its default error type;
// the outer future then converts that error into UnwatchError
fn inner_unwatch(
channel: &mut SessionChannel,
tenant: impl Into<String>,
path: impl Into<PathBuf>,
) -> AsyncReturn<'_, ()> {
make_body!(
channel,
tenant,
RequestData::Unwatch { path: path.into() },
@ok
)
}
let tenant = tenant.into();
let path = path.into();
Box::pin(async move {
inner_unwatch(self, tenant, path)
.await
.map_err(UnwatchError::from)
})
}
/// Spawns `cmd` with `args` on the remote machine, returning a handle to
/// the remote process; `persist` and `pty` are forwarded to the spawn call
fn spawn(
&mut self,
tenant: impl Into<String>,
cmd: impl Into<String>,
args: Vec<impl Into<String>>,
persist: bool,
pty: Option<PtySize>,
) -> AsyncReturn<'_, RemoteProcess, RemoteProcessError> {
// Convert the generic arguments before building the boxed future
let tenant = tenant.into();
let cmd = cmd.into();
let args = args.into_iter().map(Into::into).collect();
Box::pin(async move {
RemoteProcess::spawn(tenant, self.clone(), cmd, args, persist, pty).await
})
}
/// Spawns `cmd` with `args` on the remote machine as an LSP server process,
/// mirroring `spawn` but returning an LSP-aware process handle
fn spawn_lsp(
&mut self,
tenant: impl Into<String>,
cmd: impl Into<String>,
args: Vec<impl Into<String>>,
persist: bool,
pty: Option<PtySize>,
) -> AsyncReturn<'_, RemoteLspProcess, RemoteProcessError> {
// Convert the generic arguments before building the boxed future
let tenant = tenant.into();
let cmd = cmd.into();
let args = args.into_iter().map(Into::into).collect();
Box::pin(async move {
RemoteLspProcess::spawn(tenant, self.clone(), cmd, args, persist, pty).await
})
}
/// Retrieves system information about the remote machine
fn system_info(&mut self, tenant: impl Into<String>) -> AsyncReturn<'_, SystemInfo> {
make_body!(
self,
tenant,
RequestData::SystemInfo {},
|data| match data {
ResponseData::SystemInfo(x) => Ok(x),
ResponseData::Error(x) => Err(SessionChannelExtError::Failure(x)),
_ => Err(SessionChannelExtError::MismatchedResponse),
}
)
}
/// Writes raw bytes `data` to the remote file at `path`
fn write_file(
&mut self,
tenant: impl Into<String>,
path: impl Into<PathBuf>,
data: impl Into<Vec<u8>>,
) -> AsyncReturn<'_, ()> {
make_body!(
self,
tenant,
RequestData::FileWrite { path: path.into(), data: data.into() },
@ok
)
}
/// Writes text `data` to the remote file at `path`
fn write_file_text(
&mut self,
tenant: impl Into<String>,
path: impl Into<PathBuf>,
data: impl Into<String>,
) -> AsyncReturn<'_, ()> {
make_body!(
self,
tenant,
RequestData::FileWriteText { path: path.into(), text: data.into() },
@ok
)
}
}

@ -1,235 +0,0 @@
use crate::net::{SecretKey32, UnprotectedToHexKey};
use derive_more::{Display, Error};
use std::{
env,
net::{IpAddr, SocketAddr},
ops::Deref,
path::{Path, PathBuf},
str::FromStr,
};
use tokio::{io, net::lookup_host};
/// Connection details for establishing a distant session: the remote host,
/// port, and shared secret key
#[derive(Debug, PartialEq, Eq)]
pub struct SessionInfo {
/// Hostname or IP-address string of the remote server
pub host: String,
/// Port the remote server listens on
pub port: u16,
/// Secret key used to secure the session
pub key: SecretKey32,
}
/// Errors that can occur when parsing a `DISTANT CONNECT <host> <port> <key>`
/// string into a [`SessionInfo`]
#[derive(Copy, Clone, Debug, Display, Error, PartialEq, Eq)]
pub enum SessionInfoParseError {
/// The string did not start with the `DISTANT CONNECT` prefix
#[display(fmt = "Prefix of string is invalid")]
BadPrefix,
/// The key portion was not valid hexadecimal
#[display(fmt = "Bad hex key for session")]
BadHexKey,
/// The decoded key bytes could not form a valid secret key
#[display(fmt = "Invalid key for session")]
InvalidKey,
/// The port portion was not a valid u16
#[display(fmt = "Invalid port for session")]
InvalidPort,
#[display(fmt = "Missing address for session")]
MissingAddr,
#[display(fmt = "Missing key for session")]
MissingKey,
#[display(fmt = "Missing port for session")]
MissingPort,
}
// Allows `?` to convert parse failures into io::Error (kind InvalidData)
impl From<SessionInfoParseError> for io::Error {
fn from(x: SessionInfoParseError) -> Self {
io::Error::new(io::ErrorKind::InvalidData, x)
}
}
impl FromStr for SessionInfo {
    type Err = SessionInfoParseError;

    /// Parses a string of the form `DISTANT CONNECT <host> <port> <key>`
    /// into a [`SessionInfo`]
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        let mut tokens = s.trim().split(' ').take(5);

        // Validate the two-word `DISTANT CONNECT` prefix
        match tokens.next() {
            Some("DISTANT") => {}
            _ => return Err(SessionInfoParseError::BadPrefix),
        }
        match tokens.next() {
            Some("CONNECT") => {}
            _ => return Err(SessionInfoParseError::BadPrefix),
        }

        // Host is carried through as-is without validation
        let host = match tokens.next() {
            Some(word) => word.trim().to_string(),
            None => return Err(SessionInfoParseError::MissingAddr),
        };

        // Port must parse as a u16
        let port = tokens
            .next()
            .ok_or(SessionInfoParseError::MissingPort)?
            .trim()
            .parse::<u16>()
            .map_err(|_| SessionInfoParseError::InvalidPort)?;

        // Key is hex-encoded; decode it and rebuild the secret key
        let raw_key = tokens.next().ok_or(SessionInfoParseError::MissingKey)?;
        let key_bytes = hex::decode(raw_key.trim()).map_err(|_| SessionInfoParseError::BadHexKey)?;
        let key =
            SecretKey32::from_slice(&key_bytes).map_err(|_| SessionInfoParseError::InvalidKey)?;

        Ok(SessionInfo { host, port, key })
    }
}
impl SessionInfo {
/// Loads session from environment variables
/// (`DISTANT_HOST`, `DISTANT_PORT`, and `DISTANT_KEY`)
pub fn from_environment() -> io::Result<Self> {
fn to_err(x: env::VarError) -> io::Error {
io::Error::new(io::ErrorKind::InvalidInput, x)
}
let host = env::var("DISTANT_HOST").map_err(to_err)?;
let port = env::var("DISTANT_PORT").map_err(to_err)?;
let key = env::var("DISTANT_KEY").map_err(to_err)?;
// Reuse the FromStr parser by reconstructing the canonical string form
Ok(format!("DISTANT CONNECT {} {} {}", host, port, key).parse()?)
}
/// Loads session from the next line available in this program's stdin
pub fn from_stdin() -> io::Result<Self> {
let mut line = String::new();
std::io::stdin().read_line(&mut line)?;
line.parse()
.map_err(|x| io::Error::new(io::ErrorKind::InvalidData, x))
}
/// Consumes the session and returns the key
pub fn into_key(self) -> SecretKey32 {
self.key
}
/// Returns the ip address associated with the session based on the host
///
/// If the host is not already an IP literal, a DNS lookup is performed
/// and the first resolved address is used
pub async fn to_ip_addr(&self) -> io::Result<IpAddr> {
let addr = match self.host.parse::<IpAddr>() {
Ok(addr) => addr,
Err(_) => lookup_host((self.host.as_str(), self.port))
.await?
.next()
.ok_or_else(|| io::Error::new(io::ErrorKind::NotFound, "Failed to lookup_host"))?
.ip(),
};
Ok(addr)
}
/// Returns socket address associated with the session
pub async fn to_socket_addr(&self) -> io::Result<SocketAddr> {
let addr = self.to_ip_addr().await?;
Ok(SocketAddr::from((addr, self.port)))
}
/// Converts the session's key to a hex string
///
/// The returned string exposes the raw secret; handle with care
pub fn key_to_unprotected_string(&self) -> String {
self.key.unprotected_to_hex_key()
}
/// Converts to unprotected string that exposes the key in the form of
/// `DISTANT CONNECT <host> <port> <key>`
pub fn to_unprotected_string(&self) -> String {
format!(
"DISTANT CONNECT {} {} {}",
self.host,
self.port,
self.key.unprotected_to_hex_key()
)
}
}
/// Provides operations related to working with a session that is disk-based
pub struct SessionInfoFile {
// Location on disk where the session is (or will be) persisted
path: PathBuf,
// In-memory copy of the session info
session: SessionInfo,
}
// Borrow the file's on-disk path
impl AsRef<Path> for SessionInfoFile {
fn as_ref(&self) -> &Path {
self.as_path()
}
}
// Borrow the in-memory session info
impl AsRef<SessionInfo> for SessionInfoFile {
fn as_ref(&self) -> &SessionInfo {
self.as_session()
}
}
// Allow SessionInfo methods to be called directly on the file wrapper
impl Deref for SessionInfoFile {
type Target = SessionInfo;
fn deref(&self) -> &Self::Target {
&self.session
}
}
// Consuming conversion that discards the path and keeps the session
impl From<SessionInfoFile> for SessionInfo {
fn from(sf: SessionInfoFile) -> Self {
sf.session
}
}
impl SessionInfoFile {
/// Creates a new inmemory pointer to a session and its file
///
/// Nothing is written to disk until `save` or `save_to` is called
pub fn new(path: impl Into<PathBuf>, session: SessionInfo) -> Self {
Self {
path: path.into(),
session,
}
}
/// Returns a reference to the path to the session file
pub fn as_path(&self) -> &Path {
self.path.as_path()
}
/// Returns a reference to the session
pub fn as_session(&self) -> &SessionInfo {
&self.session
}
/// Saves a session by overwriting its current
pub async fn save(&self) -> io::Result<()> {
self.save_to(self.as_path(), true).await
}
/// Saves a session to to a file at the specified path
///
/// If all is true, will create all directories leading up to file's location
///
/// NOTE(review): the file content includes the unprotected secret key
pub async fn save_to(&self, path: impl AsRef<Path>, all: bool) -> io::Result<()> {
if all {
if let Some(dir) = path.as_ref().parent() {
tokio::fs::create_dir_all(dir).await?;
}
}
tokio::fs::write(path.as_ref(), self.session.to_unprotected_string()).await
}
/// Loads a session from a file at the specified path
pub async fn load_from(path: impl AsRef<Path>) -> io::Result<Self> {
let text = tokio::fs::read_to_string(path.as_ref()).await?;
Ok(Self {
path: path.as_ref().to_path_buf(),
session: text.parse()?,
})
}
}

@ -1,84 +0,0 @@
use crate::{client::utils, data::Response};
use std::{collections::HashMap, time::Duration};
use tokio::{io, sync::mpsc};
/// Routes incoming responses to per-request mailboxes, keyed by request id
pub struct PostOffice {
// Maps a request id to the sender half of that request's mailbox channel
mailboxes: HashMap<usize, mpsc::Sender<Response>>,
}
impl PostOffice {
    /// Creates an empty post office with no registered mailboxes
    pub fn new() -> Self {
        Self {
            mailboxes: HashMap::new(),
        }
    }

    /// Creates a new mailbox using the given id and buffer size for maximum messages
    pub fn make_mailbox(&mut self, id: usize, buffer: usize) -> Mailbox {
        let (tx, rx) = mpsc::channel(buffer);
        self.mailboxes.insert(id, tx);

        Mailbox { id, rx }
    }

    /// Delivers a response to appropriate mailbox, returning false if no mailbox is found
    /// for the response or if the mailbox is no longer receiving responses
    pub async fn deliver(&mut self, res: Response) -> bool {
        let id = res.origin_id;
        let success = if let Some(tx) = self.mailboxes.get_mut(&id) {
            tx.send(res).await.is_ok()
        } else {
            false
        };

        // If failed, we want to remove the mailbox sender as it is no longer valid
        if !success {
            self.mailboxes.remove(&id);
        }

        success
    }

    /// Removes all mailboxes from post office that are closed
    pub fn prune_mailboxes(&mut self) {
        self.mailboxes.retain(|_, tx| !tx.is_closed())
    }

    /// Closes out all mailboxes by removing the mailboxes delivery trackers internally
    pub fn clear_mailboxes(&mut self) {
        self.mailboxes.clear();
    }
}

// A public type with `new()` and no `Default` trips clippy's
// `new_without_default` lint (CI runs clippy with -Dwarnings)
impl Default for PostOffice {
    fn default() -> Self {
        Self::new()
    }
}
/// Receiving end for responses to a single request, created by
/// `PostOffice::make_mailbox`
pub struct Mailbox {
/// Represents id associated with the mailbox
id: usize,
/// Underlying mailbox storage
rx: mpsc::Receiver<Response>,
}
impl Mailbox {
/// Represents id associated with the mailbox
pub fn id(&self) -> usize {
self.id
}
/// Receives next response in mailbox
///
/// Returns None once the mailbox is closed and drained
pub async fn next(&mut self) -> Option<Response> {
self.rx.recv().await
}
/// Receives next response in mailbox, waiting up to duration before timing out
pub async fn next_timeout(&mut self, duration: Duration) -> io::Result<Option<Response>> {
utils::timeout(duration, self.next()).await
}
/// Closes the mailbox such that it will not receive any more responses
///
/// Any responses already in the mailbox will still be returned via `next`
pub async fn close(&mut self) {
self.rx.close()
}
}

@ -1,504 +0,0 @@
use crate::{
client::utils,
constants::CLIENT_MAILBOX_CAPACITY,
data::{Request, Response},
net::{Codec, DataStream, Transport, TransportError},
};
use log::*;
use serde::{Deserialize, Serialize};
use std::{
convert,
net::SocketAddr,
ops::{Deref, DerefMut},
path::{Path, PathBuf},
sync::{Arc, Weak},
};
use tokio::{
io,
net::TcpStream,
sync::{mpsc, Mutex},
task::{JoinError, JoinHandle},
time::Duration,
};
mod ext;
pub use ext::{SessionChannelExt, SessionChannelExtError};
mod info;
pub use info::{SessionInfo, SessionInfoFile, SessionInfoParseError};
mod mailbox;
pub use mailbox::Mailbox;
use mailbox::PostOffice;
/// Details about the session
///
/// Every variant carries a `tag` identifying the underlying connection
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum SessionDetails {
/// Indicates session is a TCP type
Tcp { addr: SocketAddr, tag: String },
/// Indicates session is a Unix socket type
Socket { path: PathBuf, tag: String },
/// Indicates session type is inmemory
Inmemory { tag: String },
/// Indicates session type is a custom type (such as ssh)
Custom { tag: String },
}
impl SessionDetails {
    /// Represents the tag associated with the session
    pub fn tag(&self) -> &str {
        let tag = match self {
            Self::Tcp { tag, .. }
            | Self::Socket { tag, .. }
            | Self::Inmemory { tag }
            | Self::Custom { tag } => tag,
        };
        tag.as_str()
    }

    /// Represents the socket address associated with the session, if it has one
    pub fn addr(&self) -> Option<SocketAddr> {
        if let Self::Tcp { addr, .. } = self {
            Some(*addr)
        } else {
            None
        }
    }

    /// Represents the path associated with the session, if it has one
    pub fn path(&self) -> Option<&Path> {
        if let Self::Socket { path, .. } = self {
            Some(path.as_path())
        } else {
            None
        }
    }
}
/// Represents a session with a remote server that can be used to send requests & receive responses
///
/// Holds the JoinHandles of the background tasks spawned in
/// `initialize_with_details`; dropping or aborting the session stops them
pub struct Session {
/// Used to send requests to a server
channel: SessionChannel,
/// Details about the session
details: Option<SessionDetails>,
/// Contains the task that is running to send requests to a server
request_task: JoinHandle<()>,
/// Contains the task that is running to receive responses from a server
response_task: JoinHandle<()>,
/// Contains the task that runs on a timer to prune closed mailboxes
prune_task: JoinHandle<()>,
}
impl Session {
/// Connect to a remote TCP server using the provided information
pub async fn tcp_connect<U>(addr: SocketAddr, codec: U) -> io::Result<Self>
where
U: Codec + Send + 'static,
{
let transport = Transport::<TcpStream, U>::connect(addr, codec).await?;
// Capture TCP-specific details (addr + connection tag) for diagnostics
let details = SessionDetails::Tcp {
addr,
tag: transport.to_connection_tag(),
};
debug!(
"Session has been established with {}",
transport
.peer_addr()
.map(|x| x.to_string())
.unwrap_or_else(|_| String::from("???"))
);
Self::initialize_with_details(transport, Some(details))
}
/// Connect to a remote TCP server, timing out after duration has passed
pub async fn tcp_connect_timeout<U>(
addr: SocketAddr,
codec: U,
duration: Duration,
) -> io::Result<Self>
where
U: Codec + Send + 'static,
{
// timeout yields io::Result<io::Result<Self>>; identity flattens it
utils::timeout(duration, Self::tcp_connect(addr, codec))
.await
.and_then(convert::identity)
}
/// Convert into underlying channel
pub fn into_channel(self) -> SessionChannel {
self.channel
}
}
// Unix-only constructors for socket-based connections
#[cfg(unix)]
impl Session {
/// Connect to a proxy unix socket
pub async fn unix_connect<U>(path: impl AsRef<std::path::Path>, codec: U) -> io::Result<Self>
where
U: Codec + Send + 'static,
{
let p = path.as_ref();
let transport = Transport::<tokio::net::UnixStream, U>::connect(p, codec).await?;
// Capture socket-specific details (path + connection tag) for diagnostics
let details = SessionDetails::Socket {
path: p.to_path_buf(),
tag: transport.to_connection_tag(),
};
debug!(
"Session has been established with {}",
transport
.peer_addr()
.map(|x| format!("{:?}", x))
.unwrap_or_else(|_| String::from("???"))
);
Self::initialize_with_details(transport, Some(details))
}
/// Connect to a proxy unix socket, timing out after duration has passed
pub async fn unix_connect_timeout<U>(
path: impl AsRef<std::path::Path>,
codec: U,
duration: Duration,
) -> io::Result<Self>
where
U: Codec + Send + 'static,
{
// timeout yields io::Result<io::Result<Self>>; identity flattens it
utils::timeout(duration, Self::unix_connect(path, codec))
.await
.and_then(convert::identity)
}
}
impl Session {
/// Initializes a session using the provided transport and no extra details
pub fn initialize<T, U>(transport: Transport<T, U>) -> io::Result<Self>
where
T: DataStream,
U: Codec + Send + 'static,
{
Self::initialize_with_details(transport, None)
}
/// Initializes a session using the provided transport and extra details
///
/// Spawns three background tasks: one reading responses and delivering
/// them to mailboxes, one writing queued requests, and one pruning
/// closed mailboxes once a minute
pub fn initialize_with_details<T, U>(
transport: Transport<T, U>,
details: Option<SessionDetails>,
) -> io::Result<Self>
where
T: DataStream,
U: Codec + Send + 'static,
{
let (mut t_read, mut t_write) = transport.into_split();
let post_office = Arc::new(Mutex::new(PostOffice::new()));
// Weak handles let the channel and prune task observe the post office
// without keeping it alive on their own
let weak_post_office = Arc::downgrade(&post_office);
// Start a task that continually checks for responses and delivers them using the
// post office
let response_task = tokio::spawn(async move {
loop {
match t_read.receive::<Response>().await {
Ok(Some(res)) => {
trace!("Incoming response: {:?}", res);
let res_id = res.id;
let res_origin_id = res.origin_id;
// Try to send response to appropriate mailbox
// NOTE: We don't log failures as errors as using fire(...) for a
// session is valid and would not have a mailbox
if !post_office.lock().await.deliver(res).await {
trace!(
"Response {} has no mailbox for origin {}",
res_id,
res_origin_id
);
}
}
Ok(None) => {
debug!("Session closing response task as transport read-half closed!");
break;
}
Err(x) => {
error!("Failed to receive response from server: {}", x);
break;
}
}
}
// Clean up remaining mailbox senders
post_office.lock().await.clear_mailboxes();
});
// Requests funnel through a capacity-1 channel to the write half
let (tx, mut rx) = mpsc::channel::<Request>(1);
let request_task = tokio::spawn(async move {
while let Some(req) = rx.recv().await {
if let Err(x) = t_write.send(req).await {
error!("Failed to send request to server: {}", x);
break;
}
}
});
// Create a task that runs once a minute and prunes mailboxes
let post_office = Weak::clone(&weak_post_office);
let prune_task = tokio::spawn(async move {
loop {
tokio::time::sleep(Duration::from_secs(60)).await;
// Upgrade failure means the post office is gone, so stop pruning
if let Some(post_office) = Weak::upgrade(&post_office) {
post_office.lock().await.prune_mailboxes();
} else {
break;
}
}
});
let channel = SessionChannel {
tx,
post_office: weak_post_office,
};
Ok(Self {
channel,
details,
request_task,
response_task,
prune_task,
})
}
}
impl Session {
/// Returns details about the session, if it has any
pub fn details(&self) -> Option<&SessionDetails> {
self.details.as_ref()
}
/// Waits for the session to terminate, which results when the receiving end of the network
/// connection is closed (or the session is shutdown)
pub async fn wait(self) -> Result<(), JoinError> {
// Prune task loops forever, so abort it rather than joining on it
self.prune_task.abort();
tokio::try_join!(self.request_task, self.response_task).map(|_| ())
}
/// Abort the session's current connection by forcing its tasks to abort
pub fn abort(&self) {
self.request_task.abort();
self.response_task.abort();
self.prune_task.abort();
}
/// Clones the underlying channel for requests and returns the cloned instance
pub fn clone_channel(&self) -> SessionChannel {
self.channel.clone()
}
}
// Allow SessionChannel methods to be called directly on a Session
impl Deref for Session {
type Target = SessionChannel;
fn deref(&self) -> &Self::Target {
&self.channel
}
}
impl DerefMut for Session {
fn deref_mut(&mut self) -> &mut Self::Target {
&mut self.channel
}
}
// Consuming conversion; the session's task handles are dropped in the process
impl From<Session> for SessionChannel {
fn from(session: Session) -> Self {
session.channel
}
}
/// Represents a sender of requests tied to a session, holding onto a weak reference of
/// mailboxes to relay responses, meaning that once the [`Session`] is closed or dropped,
/// any sent request will no longer be able to receive responses
///
/// Cloning is cheap: it clones the request sender and the weak pointer
#[derive(Clone)]
pub struct SessionChannel {
/// Used to send requests to a server
tx: mpsc::Sender<Request>,
/// Collection of mailboxes for receiving responses to requests
post_office: Weak<Mutex<PostOffice>>,
}
impl SessionChannel {
/// Returns true if no more requests can be transferred
pub fn is_closed(&self) -> bool {
self.tx.is_closed()
}
/// Sends a request and returns a mailbox that can receive one or more responses, failing if
/// unable to send a request or if the session's receiving line to the remote server has
/// already been severed
pub async fn mail(&mut self, req: Request) -> Result<Mailbox, TransportError> {
trace!("Mailing request: {:?}", req);
// First, create a mailbox using the request's id
// (done before sending so no response can arrive without a mailbox)
let mailbox = Weak::upgrade(&self.post_office)
.ok_or_else(|| {
TransportError::IoError(io::Error::new(
io::ErrorKind::NotConnected,
"Session's post office is no longer available",
))
})?
.lock()
.await
.make_mailbox(req.id, CLIENT_MAILBOX_CAPACITY);
// Second, send the request
self.fire(req).await?;
// Third, return mailbox
Ok(mailbox)
}
/// Sends a request and waits for a response, failing if unable to send a request or if
/// the session's receiving line to the remote server has already been severed
pub async fn send(&mut self, req: Request) -> Result<Response, TransportError> {
trace!("Sending request: {:?}", req);
// Send mail and get back a mailbox
let mut mailbox = self.mail(req).await?;
// Wait for first response, and then drop the mailbox
mailbox.next().await.ok_or_else(|| {
TransportError::IoError(io::Error::from(io::ErrorKind::ConnectionAborted))
})
}
/// Sends a request and waits for a response, timing out after duration has passed
pub async fn send_timeout(
&mut self,
req: Request,
duration: Duration,
) -> Result<Response, TransportError> {
utils::timeout(duration, self.send(req))
.await
.map_err(TransportError::from)
.and_then(convert::identity)
}
/// Sends a request without waiting for a response; this method is able to be used even
/// if the session's receiving line to the remote server has been severed
pub async fn fire(&mut self, req: Request) -> Result<(), TransportError> {
trace!("Firing off request: {:?}", req);
self.tx
.send(req)
.await
.map_err(|x| TransportError::IoError(io::Error::new(io::ErrorKind::BrokenPipe, x)))
}
/// Sends a request without waiting for a response, timing out after duration has passed
pub async fn fire_timeout(
&mut self,
req: Request,
duration: Duration,
) -> Result<(), TransportError> {
utils::timeout(duration, self.fire(req))
.await
.map_err(TransportError::from)
.and_then(convert::identity)
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::{
constants::test::TENANT,
data::{RequestData, ResponseData},
};
use std::time::Duration;
// Uses an in-memory transport pair: t2 plays the role of the server side
#[tokio::test]
async fn mail_should_return_mailbox_that_receives_responses_until_transport_closes() {
let (t1, mut t2) = Transport::make_pair();
let mut session = Session::initialize(t1).unwrap();
let req = Request::new(TENANT, vec![RequestData::ProcList {}]);
let res = Response::new(TENANT, req.id, vec![ResponseData::Ok]);
let mut mailbox = session.mail(req).await.unwrap();
// Get first response
match tokio::join!(mailbox.next(), t2.send(res.clone())) {
(Some(actual), _) => assert_eq!(actual, res),
x => panic!("Unexpected response: {:?}", x),
}
// Get second response
match tokio::join!(mailbox.next(), t2.send(res.clone())) {
(Some(actual), _) => assert_eq!(actual, res),
x => panic!("Unexpected response: {:?}", x),
}
// Trigger the mailbox to wait BEFORE closing our transport to ensure that
// we don't get stuck if the mailbox was already waiting
let next_task = tokio::spawn(async move { mailbox.next().await });
tokio::task::yield_now().await;
drop(t2);
match next_task.await {
Ok(None) => {}
x => panic!("Unexpected response: {:?}", x),
}
}
#[tokio::test]
async fn send_should_wait_until_response_received() {
let (t1, mut t2) = Transport::make_pair();
let mut session = Session::initialize(t1).unwrap();
let req = Request::new(TENANT, vec![RequestData::ProcList {}]);
let res = Response::new(
TENANT,
req.id,
vec![ResponseData::ProcEntries {
entries: Vec::new(),
}],
);
let (actual, _) = tokio::join!(session.send(req), t2.send(res.clone()));
match actual {
Ok(actual) => assert_eq!(actual, res),
x => panic!("Unexpected response: {:?}", x),
}
}
#[tokio::test]
async fn send_timeout_should_fail_if_response_not_received_in_time() {
let (t1, mut t2) = Transport::make_pair();
let mut session = Session::initialize(t1).unwrap();
let req = Request::new(TENANT, vec![RequestData::ProcList {}]);
match session.send_timeout(req, Duration::from_millis(30)).await {
Err(TransportError::IoError(x)) => assert_eq!(x.kind(), io::ErrorKind::TimedOut),
x => panic!("Unexpected response: {:?}", x),
}
// The request should still have been delivered despite the timeout
let req = t2.receive::<Request>().await.unwrap().unwrap();
assert_eq!(req.tenant, TENANT);
}
#[tokio::test]
async fn fire_should_send_request_and_not_wait_for_response() {
let (t1, mut t2) = Transport::make_pair();
let mut session = Session::initialize(t1).unwrap();
let req = Request::new(TENANT, vec![RequestData::ProcList {}]);
match session.fire(req).await {
Ok(_) => {}
x => panic!("Unexpected response: {:?}", x),
}
let req = t2.receive::<Request>().await.unwrap().unwrap();
assert_eq!(req.tenant, TENANT);
}
}

@ -1,13 +0,0 @@
use std::{future::Future, time::Duration};
use tokio::{io, time};
// Wraps a future in a tokio timeout call, transforming the error into
// an io error
pub async fn timeout<T, F>(d: Duration, f: F) -> io::Result<T>
where
F: Future<Output = T>,
{
time::timeout(d, f)
.await
.map_err(|x| io::Error::new(io::ErrorKind::TimedOut, x))
}

@ -1,43 +1,20 @@
use crate::{
client::{SessionChannel, SessionChannelExt, SessionChannelExtError},
client::{DistantChannel, DistantChannelExt},
constants::CLIENT_WATCHER_CAPACITY,
data::{Change, ChangeKindSet, Error as DistantError, Request, RequestData, ResponseData},
net::TransportError,
data::{Change, ChangeKindSet, DistantRequestData, DistantResponseData},
DistantMsg,
};
use derive_more::{Display, Error, From};
use distant_net::Request;
use log::*;
use std::{
fmt,
fmt, io,
path::{Path, PathBuf},
};
use tokio::{sync::mpsc, task::JoinHandle};
#[derive(Debug, Display, Error)]
pub enum WatchError {
/// When no confirmation of watch is received
MissingConfirmation,
/// A server-side error occurred when attempting to watch
ServerError(DistantError),
/// When the communication over the wire has issues
TransportError(TransportError),
/// When a queued change is dropped because the response channel closed early
QueuedChangeDropped,
/// Some unexpected response was received when attempting to watch
#[display(fmt = "Unexpected response: {:?}", _0)]
UnexpectedResponse(#[error(not(source))] ResponseData),
}
#[derive(Debug, Display, From, Error)]
pub struct UnwatchError(SessionChannelExtError);
/// Represents a watcher of some path on a remote machine
pub struct Watcher {
tenant: String,
channel: SessionChannel,
channel: DistantChannel,
path: PathBuf,
task: JoinHandle<()>,
rx: mpsc::Receiver<Change>,
@ -46,24 +23,19 @@ pub struct Watcher {
impl fmt::Debug for Watcher {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("Watcher")
.field("tenant", &self.tenant)
.field("path", &self.path)
.finish()
f.debug_struct("Watcher").field("path", &self.path).finish()
}
}
impl Watcher {
/// Creates a watcher for some remote path
pub async fn watch(
tenant: impl Into<String>,
mut channel: SessionChannel,
mut channel: DistantChannel,
path: impl Into<PathBuf>,
recursive: bool,
only: impl Into<ChangeKindSet>,
except: impl Into<ChangeKindSet>,
) -> Result<Self, WatchError> {
let tenant = tenant.into();
) -> io::Result<Self> {
let path = path.into();
let only = only.into();
let except = except.into();
@ -85,17 +57,15 @@ impl Watcher {
// Submit our run request and get back a mailbox for responses
let mut mailbox = channel
.mail(Request::new(
tenant.as_str(),
vec![RequestData::Watch {
.mail(Request::new(DistantMsg::Single(
DistantRequestData::Watch {
path: path.to_path_buf(),
recursive,
only: only.into_vec(),
except: except.into_vec(),
}],
))
.await
.map_err(WatchError::TransportError)?;
},
)))
.await?;
let (tx, rx) = mpsc::channel(CLIENT_WATCHER_CAPACITY);
@ -103,14 +73,19 @@ impl Watcher {
let mut queue: Vec<Change> = Vec::new();
let mut confirmed = false;
while let Some(res) = mailbox.next().await {
for data in res.payload {
for data in res.payload.into_vec() {
match data {
ResponseData::Changed(change) => queue.push(change),
ResponseData::Ok => {
DistantResponseData::Changed(change) => queue.push(change),
DistantResponseData::Ok => {
confirmed = true;
}
ResponseData::Error(x) => return Err(WatchError::ServerError(x)),
x => return Err(WatchError::UnexpectedResponse(x)),
DistantResponseData::Error(x) => return Err(io::Error::from(x)),
x => {
return Err(io::Error::new(
io::ErrorKind::Other,
format!("Unexpected response: {:?}", x),
))
}
}
}
@ -126,14 +101,14 @@ impl Watcher {
trace!("Forwarding {} queued changes for {:?}", queue.len(), path);
for change in queue {
if tx.send(change).await.is_err() {
return Err(WatchError::QueuedChangeDropped);
return Err(io::Error::new(io::ErrorKind::Other, "Queue change dropped"));
}
}
// If we never received an acknowledgement of watch before the mailbox closed,
// fail with a missing confirmation error
if !confirmed {
return Err(WatchError::MissingConfirmation);
return Err(io::Error::new(io::ErrorKind::Other, "Missing confirmation"));
}
// Spawn a task that continues to look for change events, discarding anything
@ -142,9 +117,9 @@ impl Watcher {
let path = path.clone();
async move {
while let Some(res) = mailbox.next().await {
for data in res.payload {
for data in res.payload.into_vec() {
match data {
ResponseData::Changed(change) => {
DistantResponseData::Changed(change) => {
// If we can't queue up a change anymore, we've
// been closed and therefore want to quit
if tx.is_closed() {
@ -168,7 +143,6 @@ impl Watcher {
});
Ok(Self {
tenant,
path,
channel,
task,
@ -193,42 +167,37 @@ impl Watcher {
}
/// Unwatches the path being watched, closing out the watcher
pub async fn unwatch(&mut self) -> Result<(), UnwatchError> {
pub async fn unwatch(&mut self) -> io::Result<()> {
trace!("Unwatching {:?}", self.path);
let result = self
.channel
.unwatch(self.tenant.to_string(), self.path.to_path_buf())
.await
.map_err(UnwatchError::from);
self.channel.unwatch(self.path.to_path_buf()).await?;
match result {
Ok(_) => {
// Kill our task that processes inbound changes if we
// have successfully unwatched the path
// Kill our task that processes inbound changes if we have successfully unwatched the path
self.task.abort();
self.active = false;
Ok(())
}
Err(x) => Err(x),
}
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::{
client::Session,
data::{ChangeKind, Response},
net::{InmemoryStream, PlainCodec, Transport},
use crate::data::ChangeKind;
use crate::DistantClient;
use distant_net::{
Client, FramedTransport, InmemoryTransport, IntoSplit, PlainCodec, Response,
TypedAsyncRead, TypedAsyncWrite,
};
use std::sync::Arc;
use tokio::sync::Mutex;
fn make_session() -> (Transport<InmemoryStream, PlainCodec>, Session) {
let (t1, t2) = Transport::make_pair();
(t1, Session::initialize(t2).unwrap())
fn make_session() -> (
FramedTransport<InmemoryTransport, PlainCodec>,
DistantClient,
) {
let (t1, t2) = FramedTransport::pair(100);
let (writer, reader) = t2.into_split();
(t1, Client::new(writer, reader).unwrap())
}
#[tokio::test]
@ -240,7 +209,6 @@ mod tests {
// in a separate async block
let watch_task = tokio::spawn(async move {
Watcher::watch(
String::from("test-tenant"),
session.clone_channel(),
test_path,
true,
@ -251,11 +219,11 @@ mod tests {
});
// Wait until we get the request from the session
let req = transport.receive::<Request>().await.unwrap().unwrap();
let req: Request<DistantRequestData> = transport.read().await.unwrap().unwrap();
// Send back an acknowledgement that a watcher was created
transport
.send(Response::new("test-tenant", req.id, vec![ResponseData::Ok]))
.write(Response::new(req.id, DistantResponseData::Ok))
.await
.unwrap();
@ -273,7 +241,6 @@ mod tests {
// in a separate async block
let watch_task = tokio::spawn(async move {
Watcher::watch(
String::from("test-tenant"),
session.clone_channel(),
test_path,
true,
@ -284,11 +251,11 @@ mod tests {
});
// Wait until we get the request from the session
let req = transport.receive::<Request>().await.unwrap().unwrap();
let req: Request<DistantRequestData> = transport.read().await.unwrap().unwrap();
// Send back an acknowledgement that a watcher was created
transport
.send(Response::new("test-tenant", req.id, vec![ResponseData::Ok]))
.write(Response::new(req.id.clone(), DistantResponseData::Ok))
.await
.unwrap();
@ -297,15 +264,14 @@ mod tests {
// Send some changes related to the file
transport
.send(Response::new(
"test-tenant",
.write(Response::new(
req.id,
vec![
ResponseData::Changed(Change {
DistantResponseData::Changed(Change {
kind: ChangeKind::Access,
paths: vec![test_path.to_path_buf()],
}),
ResponseData::Changed(Change {
DistantResponseData::Changed(Change {
kind: ChangeKind::Content,
paths: vec![test_path.to_path_buf()],
}),
@ -343,7 +309,6 @@ mod tests {
// in a separate async block
let watch_task = tokio::spawn(async move {
Watcher::watch(
String::from("test-tenant"),
session.clone_channel(),
test_path,
true,
@ -354,11 +319,11 @@ mod tests {
});
// Wait until we get the request from the session
let req = transport.receive::<Request>().await.unwrap().unwrap();
let req: Request<DistantRequestData> = transport.read().await.unwrap().unwrap();
// Send back an acknowledgement that a watcher was created
transport
.send(Response::new("test-tenant", req.id, vec![ResponseData::Ok]))
.write(Response::new(req.id.clone(), DistantResponseData::Ok))
.await
.unwrap();
@ -367,39 +332,36 @@ mod tests {
// Send a change from the appropriate origin
transport
.send(Response::new(
"test-tenant",
req.id,
vec![ResponseData::Changed(Change {
.write(Response::new(
req.id.clone(),
DistantResponseData::Changed(Change {
kind: ChangeKind::Access,
paths: vec![test_path.to_path_buf()],
})],
}),
))
.await
.unwrap();
// Send a change from a different origin
transport
.send(Response::new(
"test-tenant",
req.id + 1,
vec![ResponseData::Changed(Change {
.write(Response::new(
req.id.clone() + "1",
DistantResponseData::Changed(Change {
kind: ChangeKind::Content,
paths: vec![test_path.to_path_buf()],
})],
}),
))
.await
.unwrap();
// Send a change from the appropriate origin
transport
.send(Response::new(
"test-tenant",
.write(Response::new(
req.id,
vec![ResponseData::Changed(Change {
DistantResponseData::Changed(Change {
kind: ChangeKind::Remove,
paths: vec![test_path.to_path_buf()],
})],
}),
))
.await
.unwrap();
@ -433,7 +395,6 @@ mod tests {
// in a separate async block
let watch_task = tokio::spawn(async move {
Watcher::watch(
String::from("test-tenant"),
session.clone_channel(),
test_path,
true,
@ -444,29 +405,28 @@ mod tests {
});
// Wait until we get the request from the session
let req = transport.receive::<Request>().await.unwrap().unwrap();
let req: Request<DistantRequestData> = transport.read().await.unwrap().unwrap();
// Send back an acknowledgement that a watcher was created
transport
.send(Response::new("test-tenant", req.id, vec![ResponseData::Ok]))
.write(Response::new(req.id.clone(), DistantResponseData::Ok))
.await
.unwrap();
// Send some changes from the appropriate origin
transport
.send(Response::new(
"test-tenant",
.write(Response::new(
req.id,
vec![
ResponseData::Changed(Change {
DistantResponseData::Changed(Change {
kind: ChangeKind::Access,
paths: vec![test_path.to_path_buf()],
}),
ResponseData::Changed(Change {
DistantResponseData::Changed(Change {
kind: ChangeKind::Content,
paths: vec![test_path.to_path_buf()],
}),
ResponseData::Changed(Change {
DistantResponseData::Changed(Change {
kind: ChangeKind::Remove,
paths: vec![test_path.to_path_buf()],
}),
@ -501,24 +461,23 @@ mod tests {
let watcher_2 = Arc::clone(&watcher);
let unwatch_task = tokio::spawn(async move { watcher_2.lock().await.unwatch().await });
let req = transport.receive::<Request>().await.unwrap().unwrap();
let req: Request<DistantRequestData> = transport.read().await.unwrap().unwrap();
transport
.send(Response::new("test-tenant", req.id, vec![ResponseData::Ok]))
.write(Response::new(req.id.clone(), DistantResponseData::Ok))
.await
.unwrap();
// Wait for the unwatch to complete
let _ = unwatch_task.await.unwrap().unwrap();
unwatch_task.await.unwrap().unwrap();
transport
.send(Response::new(
"test-tenant",
.write(Response::new(
req.id,
vec![ResponseData::Changed(Change {
DistantResponseData::Changed(Change {
kind: ChangeKind::Unknown,
paths: vec![test_path.to_path_buf()],
})],
}),
))
.await
.unwrap();

@ -1,6 +1,3 @@
/// Capacity associated with a client mailboxes for receiving multiple responses to a request
pub const CLIENT_MAILBOX_CAPACITY: usize = 10000;
/// Capacity associated stdin, stdout, and stderr pipes receiving data from remote server
pub const CLIENT_PIPE_CAPACITY: usize = 10000;
@ -19,13 +16,3 @@ pub const MAX_PIPE_CHUNK_SIZE: usize = 16384;
/// Duration in milliseconds to sleep between reading stdout/stderr chunks
/// to avoid sending many small messages to clients
pub const READ_PAUSE_MILLIS: u64 = 50;
/// Maximum message capacity per connection for the distant server
pub const MAX_MSG_CAPACITY: usize = 10000;
/// Test-only constants
#[cfg(test)]
pub mod test {
pub const BUFFER_SIZE: usize = 100;
pub const TENANT: &str = "test-tenant";
}

@ -0,0 +1,133 @@
use crate::{
serde_str::{deserialize_from_str, serialize_to_str},
Destination,
};
use distant_net::SecretKey32;
use serde::{de::Deserializer, ser::Serializer, Deserialize, Serialize};
use std::{
convert::{TryFrom, TryInto},
fmt, io,
str::FromStr,
};
use uriparse::{URIReference, URI};
/// Represents credentials used for a distant server that is maintaining a single key
/// across all connections
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct DistantSingleKeyCredentials {
pub host: String,
pub port: u16,
pub key: SecretKey32,
pub username: Option<String>,
}
impl fmt::Display for DistantSingleKeyCredentials {
    /// Renders credentials as `distant://[username]:{key}@{host}:{port}`.
    ///
    /// When no username is present, the portion before the `:` is left empty,
    /// yielding `distant://:{key}@{host}:{port}`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let username = self.username.as_deref().unwrap_or("");
        write!(
            f,
            "distant://{}:{}@{}:{}",
            username, self.key, self.host, self.port
        )
    }
}
impl FromStr for DistantSingleKeyCredentials {
type Err = io::Error;
/// Parse `distant://[username]:{key}@{host}` as credentials. Note that this requires the
/// `distant` scheme to be included. If parsing without scheme is desired, call the
/// [`DistantSingleKeyCredentials::try_from_uri_ref`] method instead with `require_scheme`
/// set to false
fn from_str(s: &str) -> Result<Self, Self::Err> {
Self::try_from_uri_ref(s, true)
}
}
impl Serialize for DistantSingleKeyCredentials {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
S: Serializer,
{
serialize_to_str(self, serializer)
}
}
impl<'de> Deserialize<'de> for DistantSingleKeyCredentials {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: Deserializer<'de>,
{
deserialize_from_str(deserializer)
}
}
impl DistantSingleKeyCredentials {
    /// Converts credentials into a [`Destination`] of the form `distant://[username]:{key}@{host}`,
    /// failing if the credentials would not produce a valid [`Destination`]
    pub fn try_to_destination(&self) -> io::Result<Destination> {
        let uri = self.try_to_uri()?;
        // Destination conversion operates on a borrowed URI reference, so step
        // the owned URI back down before attempting it
        Destination::try_from(uri.as_uri_reference().to_borrowed())
            .map_err(|x| io::Error::new(io::ErrorKind::InvalidData, x))
    }
    /// Converts credentials into a [`URI`] of the form `distant://[username]:{key}@{host}`,
    /// failing if the credentials would not produce a valid [`URI`]
    pub fn try_to_uri(&self) -> io::Result<URI<'static>> {
        // Reuses the Display impl, which renders the full
        // `distant://[username]:{key}@{host}:{port}` form
        let uri_string = self.to_string();
        URI::try_from(uri_string.as_str())
            .map(URI::into_owned)
            .map_err(|x| io::Error::new(io::ErrorKind::InvalidData, x))
    }
    /// Parses credentials from a [`URIReference`], failing if the input was not a valid
    /// [`URIReference`] or if required parameters like `host` or `password` are missing or bad
    /// format
    ///
    /// If `require_scheme` is true, will enforce that a scheme is provided. Regardless, if a
    /// scheme is provided that is not `distant`, this will also fail
    pub fn try_from_uri_ref<'a, E>(
        uri_ref: impl TryInto<URIReference<'a>, Error = E>,
        require_scheme: bool,
    ) -> io::Result<Self>
    where
        E: std::error::Error + Send + Sync + 'static,
    {
        let uri_ref = uri_ref
            .try_into()
            .map_err(|x| io::Error::new(io::ErrorKind::Other, x))?;
        // Check if the scheme is correct, and if missing if we require it
        if let Some(scheme) = uri_ref.scheme() {
            // Scheme comparison is case-insensitive, per URI conventions
            if !scheme.as_str().eq_ignore_ascii_case("distant") {
                return Err(io::Error::new(
                    io::ErrorKind::InvalidInput,
                    "Scheme is not distant",
                ));
            }
        } else if require_scheme {
            return Err(io::Error::new(
                io::ErrorKind::InvalidInput,
                "Missing scheme",
            ));
        }
        // Host, port, and password (the secret key) are mandatory; only the
        // username component is optional
        Ok(Self {
            host: uri_ref
                .host()
                .map(ToString::to_string)
                .ok_or_else(|| io::Error::new(io::ErrorKind::InvalidInput, "Missing host"))?,
            port: uri_ref
                .port()
                .ok_or_else(|| io::Error::new(io::ErrorKind::InvalidInput, "Missing port"))?,
            // The URI password component carries the secret key and must parse
            // into a SecretKey32, otherwise the credentials are rejected
            key: uri_ref
                .password()
                .ok_or_else(|| io::Error::new(io::ErrorKind::InvalidInput, "Missing password"))
                .and_then(|x| {
                    x.parse()
                        .map_err(|x| io::Error::new(io::ErrorKind::InvalidInput, x))
                })?,
            username: uri_ref.username().map(ToString::to_string),
        })
    }
}

File diff suppressed because it is too large Load Diff

@ -0,0 +1,506 @@
use derive_more::{Deref, DerefMut, IntoIterator};
use notify::{event::Event as NotifyEvent, EventKind as NotifyEventKind};
use serde::{Deserialize, Serialize};
use std::{
collections::HashSet,
fmt,
hash::{Hash, Hasher},
iter::FromIterator,
ops::{BitOr, Sub},
path::PathBuf,
str::FromStr,
};
use strum::{EnumString, EnumVariantNames};
/// Change to one or more paths on the filesystem
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
#[serde(rename_all = "snake_case", deny_unknown_fields)]
pub struct Change {
/// Label describing the kind of change
pub kind: ChangeKind,
/// Paths that were changed
pub paths: Vec<PathBuf>,
}
#[cfg(feature = "schemars")]
impl Change {
pub fn root_schema() -> schemars::schema::RootSchema {
schemars::schema_for!(Change)
}
}
impl From<NotifyEvent> for Change {
fn from(x: NotifyEvent) -> Self {
Self {
kind: x.kind.into(),
paths: x.paths,
}
}
}
#[derive(
Copy,
Clone,
Debug,
strum::Display,
EnumString,
EnumVariantNames,
Hash,
PartialEq,
Eq,
PartialOrd,
Ord,
Serialize,
Deserialize,
)]
#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
#[serde(rename_all = "snake_case", deny_unknown_fields)]
#[strum(serialize_all = "snake_case")]
#[cfg_attr(feature = "clap", derive(clap::ValueEnum))]
#[cfg_attr(feature = "clap", clap(rename_all = "snake_case"))]
pub enum ChangeKind {
/// Something about a file or directory was accessed, but
/// no specific details were known
Access,
/// A file was closed for executing
AccessCloseExecute,
/// A file was closed for reading
AccessCloseRead,
/// A file was closed for writing
AccessCloseWrite,
/// A file was opened for executing
AccessOpenExecute,
/// A file was opened for reading
AccessOpenRead,
/// A file was opened for writing
AccessOpenWrite,
/// A file or directory was read
AccessRead,
/// The access time of a file or directory was changed
AccessTime,
/// A file, directory, or something else was created
Create,
/// The content of a file or directory changed
Content,
/// The data of a file or directory was modified, but
/// no specific details were known
Data,
/// The metadata of a file or directory was modified, but
/// no specific details were known
Metadata,
/// Something about a file or directory was modified, but
/// no specific details were known
Modify,
/// A file, directory, or something else was removed
Remove,
/// A file or directory was renamed, but no specific details were known
Rename,
/// A file or directory was renamed, and the provided paths
/// are the source and target in that order (from, to)
RenameBoth,
/// A file or directory was renamed, and the provided path
/// is the origin of the rename (before being renamed)
RenameFrom,
/// A file or directory was renamed, and the provided path
/// is the result of the rename
RenameTo,
/// A file's size changed
Size,
/// The ownership of a file or directory was changed
Ownership,
/// The permissions of a file or directory was changed
Permissions,
/// The write or modify time of a file or directory was changed
WriteTime,
// Catchall in case we have no insight as to the type of change
Unknown,
}
impl ChangeKind {
    /// Returns true if the change is a kind of access
    pub fn is_access_kind(&self) -> bool {
        // Union of the open/close access subsets plus the general access kinds
        self.is_open_access_kind()
            || self.is_close_access_kind()
            || matches!(self, Self::Access | Self::AccessRead)
    }
    /// Returns true if the change is a kind of open access
    pub fn is_open_access_kind(&self) -> bool {
        matches!(
            self,
            Self::AccessOpenExecute | Self::AccessOpenRead | Self::AccessOpenWrite
        )
    }
    /// Returns true if the change is a kind of close access
    pub fn is_close_access_kind(&self) -> bool {
        matches!(
            self,
            Self::AccessCloseExecute | Self::AccessCloseRead | Self::AccessCloseWrite
        )
    }
    /// Returns true if the change is a kind of creation
    pub fn is_create_kind(&self) -> bool {
        matches!(self, Self::Create)
    }
    /// Returns true if the change is a kind of modification
    pub fn is_modify_kind(&self) -> bool {
        // Union of the data/metadata modify subsets plus the general modify kind
        self.is_data_modify_kind() || self.is_metadata_modify_kind() || matches!(self, Self::Modify)
    }
    /// Returns true if the change is a kind of data modification
    pub fn is_data_modify_kind(&self) -> bool {
        matches!(self, Self::Content | Self::Data | Self::Size)
    }
    /// Returns true if the change is a kind of metadata modification
    pub fn is_metadata_modify_kind(&self) -> bool {
        matches!(
            self,
            Self::AccessTime
                | Self::Metadata
                | Self::Ownership
                | Self::Permissions
                | Self::WriteTime
        )
    }
    /// Returns true if the change is a kind of rename
    pub fn is_rename_kind(&self) -> bool {
        matches!(
            self,
            Self::Rename | Self::RenameBoth | Self::RenameFrom | Self::RenameTo
        )
    }
    /// Returns true if the change is a kind of removal
    pub fn is_remove_kind(&self) -> bool {
        matches!(self, Self::Remove)
    }
    /// Returns true if the change kind is unknown
    pub fn is_unknown_kind(&self) -> bool {
        matches!(self, Self::Unknown)
    }
}
#[cfg(feature = "schemars")]
impl ChangeKind {
pub fn root_schema() -> schemars::schema::RootSchema {
schemars::schema_for!(ChangeKind)
}
}
impl BitOr for ChangeKind {
type Output = ChangeKindSet;
fn bitor(self, rhs: Self) -> Self::Output {
let mut set = ChangeKindSet::empty();
set.insert(self);
set.insert(rhs);
set
}
}
impl From<NotifyEventKind> for ChangeKind {
    /// Maps the fine-grained `notify` event kinds onto wire-level
    /// [`ChangeKind`] variants, collapsing anything without a dedicated
    /// variant into the closest general kind (or `Unknown`)
    fn from(x: NotifyEventKind) -> Self {
        use notify::event::{
            AccessKind, AccessMode, DataChange, MetadataKind, ModifyKind, RenameMode,
        };
        match x {
            // File/directory access events
            NotifyEventKind::Access(AccessKind::Read) => Self::AccessRead,
            NotifyEventKind::Access(AccessKind::Open(AccessMode::Execute)) => {
                Self::AccessOpenExecute
            }
            NotifyEventKind::Access(AccessKind::Open(AccessMode::Read)) => Self::AccessOpenRead,
            NotifyEventKind::Access(AccessKind::Open(AccessMode::Write)) => Self::AccessOpenWrite,
            NotifyEventKind::Access(AccessKind::Close(AccessMode::Execute)) => {
                Self::AccessCloseExecute
            }
            NotifyEventKind::Access(AccessKind::Close(AccessMode::Read)) => Self::AccessCloseRead,
            NotifyEventKind::Access(AccessKind::Close(AccessMode::Write)) => Self::AccessCloseWrite,
            NotifyEventKind::Access(_) => Self::Access,
            // File/directory creation events
            NotifyEventKind::Create(_) => Self::Create,
            // Rename-oriented events
            NotifyEventKind::Modify(ModifyKind::Name(RenameMode::Both)) => Self::RenameBoth,
            NotifyEventKind::Modify(ModifyKind::Name(RenameMode::From)) => Self::RenameFrom,
            NotifyEventKind::Modify(ModifyKind::Name(RenameMode::To)) => Self::RenameTo,
            NotifyEventKind::Modify(ModifyKind::Name(_)) => Self::Rename,
            // Data-modification events
            NotifyEventKind::Modify(ModifyKind::Data(DataChange::Content)) => Self::Content,
            NotifyEventKind::Modify(ModifyKind::Data(DataChange::Size)) => Self::Size,
            NotifyEventKind::Modify(ModifyKind::Data(_)) => Self::Data,
            // Metadata-modification events
            NotifyEventKind::Modify(ModifyKind::Metadata(MetadataKind::AccessTime)) => {
                Self::AccessTime
            }
            NotifyEventKind::Modify(ModifyKind::Metadata(MetadataKind::WriteTime)) => {
                Self::WriteTime
            }
            NotifyEventKind::Modify(ModifyKind::Metadata(MetadataKind::Permissions)) => {
                Self::Permissions
            }
            NotifyEventKind::Modify(ModifyKind::Metadata(MetadataKind::Ownership)) => {
                Self::Ownership
            }
            NotifyEventKind::Modify(ModifyKind::Metadata(_)) => Self::Metadata,
            // General modification events
            NotifyEventKind::Modify(_) => Self::Modify,
            // File/directory removal events
            NotifyEventKind::Remove(_) => Self::Remove,
            // Catch-all for other events
            NotifyEventKind::Any | NotifyEventKind::Other => Self::Unknown,
        }
    }
}
/// Represents a distinct set of different change kinds
#[derive(Clone, Debug, Deref, DerefMut, IntoIterator, Serialize, Deserialize)]
#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
pub struct ChangeKindSet(HashSet<ChangeKind>);
impl ChangeKindSet {
/// Produces an empty set of [`ChangeKind`]
pub fn empty() -> Self {
Self(HashSet::new())
}
/// Produces a set of all [`ChangeKind`]
pub fn all() -> Self {
vec![
ChangeKind::Access,
ChangeKind::AccessCloseExecute,
ChangeKind::AccessCloseRead,
ChangeKind::AccessCloseWrite,
ChangeKind::AccessOpenExecute,
ChangeKind::AccessOpenRead,
ChangeKind::AccessOpenWrite,
ChangeKind::AccessRead,
ChangeKind::AccessTime,
ChangeKind::Create,
ChangeKind::Content,
ChangeKind::Data,
ChangeKind::Metadata,
ChangeKind::Modify,
ChangeKind::Remove,
ChangeKind::Rename,
ChangeKind::RenameBoth,
ChangeKind::RenameFrom,
ChangeKind::RenameTo,
ChangeKind::Size,
ChangeKind::Ownership,
ChangeKind::Permissions,
ChangeKind::WriteTime,
ChangeKind::Unknown,
]
.into_iter()
.collect()
}
/// Produces a changeset containing all of the access kinds
pub fn access_set() -> Self {
Self::access_open_set()
| Self::access_close_set()
| ChangeKind::AccessRead
| ChangeKind::Access
}
/// Produces a changeset containing all of the open access kinds
pub fn access_open_set() -> Self {
ChangeKind::AccessOpenExecute | ChangeKind::AccessOpenRead | ChangeKind::AccessOpenWrite
}
/// Produces a changeset containing all of the close access kinds
pub fn access_close_set() -> Self {
ChangeKind::AccessCloseExecute | ChangeKind::AccessCloseRead | ChangeKind::AccessCloseWrite
}
// Produces a changeset containing all of the modification kinds
pub fn modify_set() -> Self {
Self::modify_data_set() | Self::modify_metadata_set() | ChangeKind::Modify
}
/// Produces a changeset containing all of the data modification kinds
pub fn modify_data_set() -> Self {
ChangeKind::Content | ChangeKind::Data | ChangeKind::Size
}
/// Produces a changeset containing all of the metadata modification kinds
pub fn modify_metadata_set() -> Self {
ChangeKind::AccessTime
| ChangeKind::Metadata
| ChangeKind::Ownership
| ChangeKind::Permissions
| ChangeKind::WriteTime
}
/// Produces a changeset containing all of the rename kinds
pub fn rename_set() -> Self {
ChangeKind::Rename | ChangeKind::RenameBoth | ChangeKind::RenameFrom | ChangeKind::RenameTo
}
/// Consumes set and returns a vec of the kinds of changes
pub fn into_vec(self) -> Vec<ChangeKind> {
self.0.into_iter().collect()
}
}
#[cfg(feature = "schemars")]
impl ChangeKindSet {
pub fn root_schema() -> schemars::schema::RootSchema {
schemars::schema_for!(ChangeKindSet)
}
}
impl fmt::Display for ChangeKindSet {
    /// Outputs a comma-separated series of [`ChangeKind`] as string that are sorted
    /// such that this will always be consistent output
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        // Sort the rendered names, not the kinds, so the order is alphabetical
        // regardless of the underlying HashSet's iteration order
        let mut names: Vec<String> = self.0.iter().map(|kind| kind.to_string()).collect();
        names.sort_unstable();
        f.write_str(&names.join(","))
    }
}
impl PartialEq for ChangeKindSet {
fn eq(&self, other: &Self) -> bool {
self.to_string() == other.to_string()
}
}
impl Eq for ChangeKindSet {}
impl Hash for ChangeKindSet {
/// Hashes based on the output of [`fmt::Display`]
fn hash<H: Hasher>(&self, state: &mut H) {
self.to_string().hash(state);
}
}
impl BitOr<ChangeKindSet> for ChangeKindSet {
type Output = Self;
fn bitor(mut self, rhs: ChangeKindSet) -> Self::Output {
self.extend(rhs.0);
self
}
}
impl BitOr<ChangeKind> for ChangeKindSet {
type Output = Self;
fn bitor(mut self, rhs: ChangeKind) -> Self::Output {
self.0.insert(rhs);
self
}
}
impl BitOr<ChangeKindSet> for ChangeKind {
type Output = ChangeKindSet;
fn bitor(self, rhs: ChangeKindSet) -> Self::Output {
rhs | self
}
}
impl Sub<ChangeKindSet> for ChangeKindSet {
type Output = Self;
fn sub(self, other: Self) -> Self::Output {
ChangeKindSet(&self.0 - &other.0)
}
}
impl Sub<&'_ ChangeKindSet> for &ChangeKindSet {
type Output = ChangeKindSet;
fn sub(self, other: &ChangeKindSet) -> Self::Output {
ChangeKindSet(&self.0 - &other.0)
}
}
impl FromStr for ChangeKindSet {
type Err = strum::ParseError;
fn from_str(s: &str) -> Result<Self, Self::Err> {
let mut change_set = HashSet::new();
for word in s.split(',') {
change_set.insert(ChangeKind::from_str(word.trim())?);
}
Ok(ChangeKindSet(change_set))
}
}
impl FromIterator<ChangeKind> for ChangeKindSet {
fn from_iter<I: IntoIterator<Item = ChangeKind>>(iter: I) -> Self {
let mut change_set = HashSet::new();
for i in iter {
change_set.insert(i);
}
ChangeKindSet(change_set)
}
}
impl From<ChangeKind> for ChangeKindSet {
fn from(change_kind: ChangeKind) -> Self {
let mut set = Self::empty();
set.insert(change_kind);
set
}
}
impl From<Vec<ChangeKind>> for ChangeKindSet {
fn from(changes: Vec<ChangeKind>) -> Self {
changes.into_iter().collect()
}
}
impl Default for ChangeKindSet {
fn default() -> Self {
Self::empty()
}
}

@ -0,0 +1,106 @@
use crate::{data::Cmd, DistantMsg, DistantRequestData};
use clap::{
error::{Error, ErrorKind},
Arg, ArgAction, ArgMatches, Args, Command, FromArgMatches, Subcommand,
};
impl FromArgMatches for Cmd {
    /// Builds a [`Cmd`] from parsed matches by cloning them and delegating to
    /// the mutable variant
    fn from_arg_matches(matches: &ArgMatches) -> Result<Self, Error> {
        let mut matches = matches.clone();
        Self::from_arg_matches_mut(&mut matches)
    }
    /// Builds a [`Cmd`] from the required `cmd` value plus any trailing `arg`
    /// values joined with single spaces
    fn from_arg_matches_mut(matches: &mut ArgMatches) -> Result<Self, Error> {
        let cmd = matches.get_one::<String>("cmd").ok_or_else(|| {
            Error::raw(
                ErrorKind::MissingRequiredArgument,
                "program must be specified",
            )
        })?;
        let args: Vec<String> = matches
            .get_many::<String>("arg")
            .unwrap_or_default()
            .map(ToString::to_string)
            .collect();
        // Only append the argument list when it is non-empty; unconditionally
        // formatting `"{cmd} {args}"` left a trailing space for a bare program
        let cmd = if args.is_empty() {
            cmd.to_string()
        } else {
            format!("{cmd} {}", args.join(" "))
        };
        Ok(Self::new(cmd))
    }
    /// Clones matches and delegates to the mutable update variant
    fn update_from_arg_matches(&mut self, matches: &ArgMatches) -> Result<(), Error> {
        let mut matches = matches.clone();
        self.update_from_arg_matches_mut(&mut matches)
    }
    /// Updating an existing [`Cmd`] from matches is intentionally a no-op
    fn update_from_arg_matches_mut(&mut self, _matches: &mut ArgMatches) -> Result<(), Error> {
        Ok(())
    }
}
impl Args for Cmd {
fn augment_args(cmd: Command<'_>) -> Command<'_> {
cmd.arg(
Arg::new("cmd")
.required(true)
.value_name("CMD")
.action(ArgAction::Set),
)
.trailing_var_arg(true)
.arg(
Arg::new("arg")
.value_name("ARGS")
.multiple_values(true)
.action(ArgAction::Append),
)
}
fn augment_args_for_update(cmd: Command<'_>) -> Command<'_> {
cmd
}
}
impl FromArgMatches for DistantMsg<DistantRequestData> {
    /// Builds a single-request message from the `single` subcommand; any other
    /// subcommand, or its absence, is rejected with a clap error
    fn from_arg_matches(matches: &ArgMatches) -> Result<Self, Error> {
        match matches.subcommand() {
            Some(("single", args)) => Ok(Self::Single(DistantRequestData::from_arg_matches(args)?)),
            Some((_, _)) => Err(Error::raw(
                ErrorKind::UnrecognizedSubcommand,
                "Valid subcommand is `single`",
            )),
            None => Err(Error::raw(
                ErrorKind::MissingSubcommand,
                "Valid subcommand is `single`",
            )),
        }
    }
    /// Replaces self when the `single` subcommand is present; errors on an
    /// unrecognized subcommand
    fn update_from_arg_matches(&mut self, matches: &ArgMatches) -> Result<(), Error> {
        match matches.subcommand() {
            Some(("single", args)) => {
                *self = Self::Single(DistantRequestData::from_arg_matches(args)?)
            }
            Some((_, _)) => {
                return Err(Error::raw(
                    ErrorKind::UnrecognizedSubcommand,
                    "Valid subcommand is `single`",
                ))
            }
            // NOTE: Unlike parsing, a missing subcommand during an update is
            //       not an error — self is simply left untouched
            None => (),
        };
        Ok(())
    }
}
impl Subcommand for DistantMsg<DistantRequestData> {
    /// Registers the required `single` subcommand carrying all request fields
    fn augment_subcommands(cmd: Command<'_>) -> Command<'_> {
        cmd.subcommand(DistantRequestData::augment_subcommands(Command::new(
            "single",
        )))
        .subcommand_required(true)
    }
    /// The update flow uses the exact same augmentation as the parse flow
    fn augment_subcommands_for_update(cmd: Command<'_>) -> Command<'_> {
        Self::augment_subcommands(cmd)
    }
    /// Only `single` is recognized as a subcommand
    fn has_subcommand(name: &str) -> bool {
        name == "single"
    }
}

@ -0,0 +1,52 @@
use derive_more::{Display, From, Into};
use serde::{Deserialize, Serialize};
use std::ops::{Deref, DerefMut};
/// Represents some command with arguments to execute
#[derive(Clone, Debug, Display, From, Into, Hash, PartialEq, Eq, Serialize, Deserialize)]
#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
pub struct Cmd(String);
impl Cmd {
    /// Creates a new command from the given `cmd`
    pub fn new(cmd: impl Into<String>) -> Self {
        Self(cmd.into())
    }
    /// Returns reference to the program portion of the command (everything
    /// before the first space, or the whole string if there is none), trimmed
    pub fn program(&self) -> &str {
        self.0
            .split_once(' ')
            .map_or_else(|| self.0.trim(), |(program, _)| program.trim())
    }
    /// Returns reference to the arguments portion of the command (everything
    /// after the first space, or `""` if there is none), trimmed
    pub fn arguments(&self) -> &str {
        self.0
            .split_once(' ')
            .map_or("", |(_, arguments)| arguments.trim())
    }
}
#[cfg(feature = "schemars")]
impl Cmd {
pub fn root_schema() -> schemars::schema::RootSchema {
schemars::schema_for!(Cmd)
}
}
impl Deref for Cmd {
type Target = String;
fn deref(&self) -> &Self::Target {
&self.0
}
}
impl DerefMut for Cmd {
fn deref_mut(&mut self) -> &mut Self::Target {
&mut self.0
}
}

@ -0,0 +1,269 @@
use derive_more::Display;
use notify::ErrorKind as NotifyErrorKind;
use serde::{Deserialize, Serialize};
use std::io;
/// General purpose error type that can be sent across the wire
#[derive(Clone, Debug, Display, PartialEq, Eq, Serialize, Deserialize)]
#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
#[display(fmt = "{}: {}", kind, description)]
#[serde(rename_all = "snake_case", deny_unknown_fields)]
pub struct Error {
/// Label describing the kind of error
pub kind: ErrorKind,
/// Description of the error itself
pub description: String,
}
impl std::error::Error for Error {}
#[cfg(feature = "schemars")]
impl Error {
pub fn root_schema() -> schemars::schema::RootSchema {
schemars::schema_for!(Error)
}
}
impl<'a> From<&'a str> for Error {
fn from(x: &'a str) -> Self {
Self::from(x.to_string())
}
}
impl From<String> for Error {
fn from(x: String) -> Self {
Self {
kind: ErrorKind::Other,
description: x,
}
}
}
impl From<io::Error> for Error {
fn from(x: io::Error) -> Self {
Self {
kind: ErrorKind::from(x.kind()),
description: x.to_string(),
}
}
}
impl From<Error> for io::Error {
fn from(x: Error) -> Self {
Self::new(x.kind.into(), x.description)
}
}
impl From<notify::Error> for Error {
    /// Converts a `notify` watcher error into a wire-friendly [`Error`],
    /// appending the list of affected paths to the description
    fn from(x: notify::Error) -> Self {
        // First translate the error kind itself, without any path information
        let err = match x.kind {
            NotifyErrorKind::Generic(x) => Self {
                kind: ErrorKind::Other,
                description: x,
            },
            NotifyErrorKind::Io(x) => Self::from(x),
            NotifyErrorKind::PathNotFound => Self {
                kind: ErrorKind::Other,
                description: String::from("Path not found"),
            },
            NotifyErrorKind::WatchNotFound => Self {
                kind: ErrorKind::Other,
                description: String::from("Watch not found"),
            },
            NotifyErrorKind::InvalidConfig(_) => Self {
                kind: ErrorKind::Other,
                description: String::from("Invalid config"),
            },
            NotifyErrorKind::MaxFilesWatch => Self {
                kind: ErrorKind::Other,
                description: String::from("Max files watched"),
            },
        };
        // Then rebuild with the original error's paths appended so the
        // receiver can tell which files were involved
        Self {
            kind: err.kind,
            description: format!(
                "{}\n\nPaths: {}",
                err.description,
                x.paths
                    .into_iter()
                    .map(|p| p.to_string_lossy().to_string())
                    .collect::<Vec<String>>()
                    .join(", ")
            ),
        }
    }
}
impl From<walkdir::Error> for Error {
    /// Converts a `walkdir` error: I/O-backed errors become their I/O
    /// equivalent, while the remaining case maps to [`ErrorKind::Loop`]
    fn from(x: walkdir::Error) -> Self {
        if x.io_error().is_none() {
            // Not backed by an I/O error, so report it as a directory loop
            return Self {
                kind: ErrorKind::Loop,
                description: x.to_string(),
            };
        }
        // Guaranteed Some by the check above
        x.into_io_error().map(Self::from).unwrap()
    }
}
impl From<tokio::task::JoinError> for Error {
    /// Converts a tokio join error, distinguishing task cancellation from a
    /// task panic
    fn from(x: tokio::task::JoinError) -> Self {
        let kind = if x.is_cancelled() {
            ErrorKind::TaskCancelled
        } else {
            ErrorKind::TaskPanicked
        };
        Self {
            kind,
            description: x.to_string(),
        }
    }
}
/// All possible kinds of errors that can be returned
#[derive(Copy, Clone, Debug, Display, PartialEq, Eq, Serialize, Deserialize)]
#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
#[serde(rename_all = "snake_case", deny_unknown_fields)]
pub enum ErrorKind {
/// An entity was not found, often a file
NotFound,
/// The operation lacked the necessary privileges to complete
PermissionDenied,
/// The connection was refused by the remote server
ConnectionRefused,
/// The connection was reset by the remote server
ConnectionReset,
/// The connection was aborted (terminated) by the remote server
ConnectionAborted,
/// The network operation failed because it was not connected yet
NotConnected,
/// A socket address could not be bound because the address is already in use elsewhere
AddrInUse,
/// A nonexistent interface was requested or the requested address was not local
AddrNotAvailable,
/// The operation failed because a pipe was closed
BrokenPipe,
/// An entity already exists, often a file
AlreadyExists,
/// The operation needs to block to complete, but the blocking operation was requested to not
/// occur
WouldBlock,
/// A parameter was incorrect
InvalidInput,
/// Data not valid for the operation were encountered
InvalidData,
/// The I/O operation's timeout expired, causing it to be cancelled
TimedOut,
/// An error returned when an operation could not be completed because a
/// call to `write` returned `Ok(0)`
WriteZero,
/// This operation was interrupted
Interrupted,
/// Any I/O error not part of this list
Other,
/// An error returned when an operation could not be completed because an "end of file" was
/// reached prematurely
UnexpectedEof,
/// This operation is unsupported on this platform
Unsupported,
/// An operation could not be completed, because it failed to allocate enough memory
OutOfMemory,
/// When a loop is encountered when walking a directory
Loop,
/// When a task is cancelled
TaskCancelled,
/// When a task panics
TaskPanicked,
/// Catchall for an error that has no specific type
Unknown,
}
#[cfg(feature = "schemars")]
impl ErrorKind {
pub fn root_schema() -> schemars::schema::RootSchema {
schemars::schema_for!(ErrorKind)
}
}
impl From<io::ErrorKind> for ErrorKind {
fn from(kind: io::ErrorKind) -> Self {
match kind {
io::ErrorKind::NotFound => Self::NotFound,
io::ErrorKind::PermissionDenied => Self::PermissionDenied,
io::ErrorKind::ConnectionRefused => Self::ConnectionRefused,
io::ErrorKind::ConnectionReset => Self::ConnectionReset,
io::ErrorKind::ConnectionAborted => Self::ConnectionAborted,
io::ErrorKind::NotConnected => Self::NotConnected,
io::ErrorKind::AddrInUse => Self::AddrInUse,
io::ErrorKind::AddrNotAvailable => Self::AddrNotAvailable,
io::ErrorKind::BrokenPipe => Self::BrokenPipe,
io::ErrorKind::AlreadyExists => Self::AlreadyExists,
io::ErrorKind::WouldBlock => Self::WouldBlock,
io::ErrorKind::InvalidInput => Self::InvalidInput,
io::ErrorKind::InvalidData => Self::InvalidData,
io::ErrorKind::TimedOut => Self::TimedOut,
io::ErrorKind::WriteZero => Self::WriteZero,
io::ErrorKind::Interrupted => Self::Interrupted,
io::ErrorKind::Other => Self::Other,
io::ErrorKind::OutOfMemory => Self::OutOfMemory,
io::ErrorKind::UnexpectedEof => Self::UnexpectedEof,
io::ErrorKind::Unsupported => Self::Unsupported,
// This exists because io::ErrorKind is non_exhaustive
_ => Self::Unknown,
}
}
}
impl From<ErrorKind> for io::ErrorKind {
    /// Converts back into the closest [`io::ErrorKind`]; variants with no I/O
    /// counterpart (`Loop`, `TaskCancelled`, `TaskPanicked`, `Unknown`) are
    /// lossily collapsed into `Other`
    fn from(kind: ErrorKind) -> Self {
        match kind {
            ErrorKind::NotFound => Self::NotFound,
            ErrorKind::PermissionDenied => Self::PermissionDenied,
            ErrorKind::ConnectionRefused => Self::ConnectionRefused,
            ErrorKind::ConnectionReset => Self::ConnectionReset,
            ErrorKind::ConnectionAborted => Self::ConnectionAborted,
            ErrorKind::NotConnected => Self::NotConnected,
            ErrorKind::AddrInUse => Self::AddrInUse,
            ErrorKind::AddrNotAvailable => Self::AddrNotAvailable,
            ErrorKind::BrokenPipe => Self::BrokenPipe,
            ErrorKind::AlreadyExists => Self::AlreadyExists,
            ErrorKind::WouldBlock => Self::WouldBlock,
            ErrorKind::InvalidInput => Self::InvalidInput,
            ErrorKind::InvalidData => Self::InvalidData,
            ErrorKind::TimedOut => Self::TimedOut,
            ErrorKind::WriteZero => Self::WriteZero,
            ErrorKind::Interrupted => Self::Interrupted,
            ErrorKind::Other => Self::Other,
            ErrorKind::OutOfMemory => Self::OutOfMemory,
            ErrorKind::UnexpectedEof => Self::UnexpectedEof,
            ErrorKind::Unsupported => Self::Unsupported,
            // Loop, TaskCancelled, TaskPanicked, and Unknown have no direct
            // io::ErrorKind equivalent
            _ => Self::Other,
        }
    }
}

@ -0,0 +1,45 @@
use derive_more::IsVariant;
use serde::{Deserialize, Serialize};
use std::path::PathBuf;
use strum::AsRefStr;
/// Represents information about a single entry within a directory
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
#[serde(rename_all = "snake_case", deny_unknown_fields)]
pub struct DirEntry {
/// Represents the full path to the entry
pub path: PathBuf,
/// Represents the type of the entry as a file/dir/symlink
pub file_type: FileType,
/// Depth at which this entry was created relative to the root (0 being immediately within
/// root)
pub depth: usize,
}
#[cfg(feature = "schemars")]
impl DirEntry {
pub fn root_schema() -> schemars::schema::RootSchema {
schemars::schema_for!(DirEntry)
}
}
/// Represents the type associated with a dir entry
#[derive(Copy, Clone, Debug, PartialEq, Eq, AsRefStr, IsVariant, Serialize, Deserialize)]
#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
#[serde(rename_all = "snake_case", deny_unknown_fields)]
#[strum(serialize_all = "snake_case")]
pub enum FileType {
    /// Entry is a directory
    Dir,
    /// Entry is a regular file
    File,
    /// Entry is a symbolic link
    Symlink,
}
#[cfg(feature = "schemars")]
impl FileType {
    /// Produces the JSON schema for [`FileType`] (available with the `schemars` feature)
    pub fn root_schema() -> schemars::schema::RootSchema {
        schemars::schema_for!(FileType)
    }
}

@ -0,0 +1,244 @@
use crate::serde_str::{deserialize_from_str, serialize_to_str};
use derive_more::{From, IntoIterator};
use serde::{de::Deserializer, ser::Serializer, Deserialize, Serialize};
use std::{
collections::HashMap,
fmt,
ops::{Deref, DerefMut},
str::FromStr,
};
/// Contains map information for connections and other use cases
///
/// Thin newtype wrapper over a `HashMap<String, String>`; serialized as a
/// `key="value",key2="value2"` string (see the `Display`/`FromStr` impls)
#[derive(Clone, Debug, From, IntoIterator, PartialEq, Eq)]
#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
pub struct Map(HashMap<String, String>);
impl Map {
    /// Creates a new, empty map
    pub fn new() -> Self {
        Self(HashMap::new())
    }

    /// Consumes the wrapper, returning the underlying `HashMap`
    pub fn into_map(self) -> HashMap<String, String> {
        self.0
    }
}
#[cfg(feature = "schemars")]
impl Map {
    /// Produces the JSON schema for [`Map`] (available with the `schemars` feature)
    pub fn root_schema() -> schemars::schema::RootSchema {
        schemars::schema_for!(Map)
    }
}
impl Default for Map {
fn default() -> Self {
Self::new()
}
}
impl Deref for Map {
    type Target = HashMap<String, String>;

    /// Provides read access to the wrapped `HashMap`
    fn deref(&self) -> &Self::Target {
        &self.0
    }
}
impl DerefMut for Map {
    /// Provides mutable access to the wrapped `HashMap`
    fn deref_mut(&mut self) -> &mut Self::Target {
        &mut self.0
    }
}
impl fmt::Display for Map {
    /// Outputs a `key=value` mapping in the form `key="value",key2="value2"`
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut remaining = self.0.len();
        for (key, value) in self.0.iter() {
            write!(f, "{}=\"{}\"", key, value)?;
            remaining -= 1;

            // Separate pairs with commas, without a trailing comma after the last
            if remaining > 0 {
                write!(f, ",")?;
            }
        }
        Ok(())
    }
}
impl FromStr for Map {
    type Err = &'static str;

    /// Parses a series of `key=value` pairs in the form `key="value",key2=value2` where
    /// the quotes around the value are optional
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        let mut map = HashMap::new();

        let mut s = s.trim();
        while !s.is_empty() {
            // Find {key}={tail...} where tail is everything after =
            let (key, tail) = s.split_once('=').ok_or("Missing = after key")?;

            // Remove whitespace around the key and ensure it starts with a proper character
            let key = key.trim();
            if !key.starts_with(char::is_alphabetic) {
                return Err("Key must start with alphabetic character");
            }

            // Remove any map whitespace at the front of the tail
            let tail = tail.trim_start();

            // Determine if we start with a quote " otherwise we will look for the next ,
            let (value, tail) = match tail.strip_prefix('"') {
                // If quoted, we maintain the whitespace within the quotes
                Some(tail) => {
                    // Skip the quote so we can look for the trailing quote
                    let (value, tail) =
                        tail.split_once('"').ok_or("Missing closing \" for value")?;

                    // Skip comma if we have one
                    // NOTE(review): whitespace between the closing quote and the comma is not
                    // skipped here, so `k="v" ,k2=v2` fails to parse -- confirm intended
                    let tail = tail.strip_prefix(',').unwrap_or(tail);

                    (value, tail)
                }

                // If not quoted, we remove all whitespace around the value
                None => match tail.split_once(',') {
                    Some((value, tail)) => (value.trim(), tail),
                    None => (tail.trim(), ""),
                },
            };

            // Insert our new pair and update the slice to be the tail (removing whitespace)
            map.insert(key.to_string(), value.to_string());
            s = tail.trim();
        }

        Ok(Self(map))
    }
}
impl Serialize for Map {
    /// Serializes the map as its `Display` string form (see `serde_str` helpers)
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: Serializer,
    {
        serialize_to_str(self, serializer)
    }
}
impl<'de> Deserialize<'de> for Map {
    /// Deserializes the map from its string form via the `FromStr` impl
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: Deserializer<'de>,
    {
        deserialize_from_str(deserializer)
    }
}
/// Builds a [`Map`] from zero or more `"key" -> "value"` literal pairs
#[macro_export]
macro_rules! map {
    ($($key:literal -> $value:literal),*) => {{
        // Leading underscore avoids an unused-mut warning for the empty case
        let mut _map = ::std::collections::HashMap::new();

        $(
            _map.insert($key.to_string(), $value.to_string());
        )*

        $crate::Map::from(_map)
    }};
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Exercises the `FromStr` parser across quoted, unquoted, and failing inputs
    #[test]
    fn should_support_being_parsed_from_str() {
        // Empty string (whitespace only) yields an empty map
        let map = " ".parse::<Map>().unwrap();
        assert_eq!(map, map!());

        // Simple key=value should succeed
        let map = "key=value".parse::<Map>().unwrap();
        assert_eq!(map, map!("key" -> "value"));

        // Key can be anything but =
        let map = "key.with-characters@=value".parse::<Map>().unwrap();
        assert_eq!(map, map!("key.with-characters@" -> "value"));

        // Value can be anything but ,
        let map = "key=value.has -@#$".parse::<Map>().unwrap();
        assert_eq!(map, map!("key" -> "value.has -@#$"));

        // Value can include comma if quoted
        let map = r#"key=",,,,""#.parse::<Map>().unwrap();
        assert_eq!(map, map!("key" -> ",,,,"));

        // Supports whitespace around key and value
        let map = " key = value ".parse::<Map>().unwrap();
        assert_eq!(map, map!("key" -> "value"));

        // Supports value capturing whitespace if quoted
        let map = r#" key = " value " "#.parse::<Map>().unwrap();
        assert_eq!(map, map!("key" -> " value "));

        // Multiple key=value should succeed
        let map = "key=value,key2=value2".parse::<Map>().unwrap();
        assert_eq!(map, map!("key" -> "value", "key2" -> "value2"));

        // Quoted key=value should succeed
        let map = r#"key="value one",key2=value2"#.parse::<Map>().unwrap();
        assert_eq!(map, map!("key" -> "value one", "key2" -> "value2"));

        let map = r#"key=value,key2="value two""#.parse::<Map>().unwrap();
        assert_eq!(map, map!("key" -> "value", "key2" -> "value two"));

        let map = r#"key="value one",key2="value two""#.parse::<Map>().unwrap();
        assert_eq!(map, map!("key" -> "value one", "key2" -> "value two"));

        let map = r#"key="1,2,3",key2="4,5,6""#.parse::<Map>().unwrap();
        assert_eq!(map, map!("key" -> "1,2,3", "key2" -> "4,5,6"));

        // Dangling comma is okay
        let map = "key=value,".parse::<Map>().unwrap();
        assert_eq!(map, map!("key" -> "value"));

        let map = r#"key=",value,","#.parse::<Map>().unwrap();
        assert_eq!(map, map!("key" -> ",value,"));

        // Demonstrating greedy: an unquoted value runs to the next comma (or end),
        // swallowing what might look like later pairs
        let map = "key=value key2=value2".parse::<Map>().unwrap();
        assert_eq!(map, map!("key" -> "value key2=value2"));

        // Variety of edge cases that should fail
        let _ = ",".parse::<Map>().unwrap_err();
        let _ = ",key=value".parse::<Map>().unwrap_err();
        let _ = "key=value,key2".parse::<Map>().unwrap_err();
    }

    /// Exercises the `Display` impl; ordering of pairs is not deterministic
    /// because the backing store is a `HashMap`
    #[test]
    fn should_support_being_displayed_as_a_string() {
        let map = map!().to_string();
        assert_eq!(map, "");

        let map = map!("key" -> "value").to_string();
        assert_eq!(map, r#"key="value""#);

        // Order of key=value output is not guaranteed
        let map = map!("key" -> "value", "key2" -> "value2").to_string();
        assert!(
            map == r#"key="value",key2="value2""# || map == r#"key2="value2",key="value""#,
            "{:?}",
            map
        );

        // Order of key=value output is not guaranteed
        let map = map!("key" -> ",", "key2" -> ",,").to_string();
        assert!(
            map == r#"key=",",key2=",,""# || map == r#"key2=",,",key=",""#,
            "{:?}",
            map
        );
    }
}

@ -0,0 +1,404 @@
use super::{deserialize_u128_option, serialize_u128_option, FileType};
use bitflags::bitflags;
use serde::{Deserialize, Serialize};
use std::{
io,
path::{Path, PathBuf},
time::SystemTime,
};
/// Represents metadata about some path on a remote machine
///
/// Timestamps are expressed as milliseconds since the unix epoch (see `Metadata::read`)
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
pub struct Metadata {
    /// Canonicalized path to the file or directory, resolving symlinks, only included
    /// if flagged during the request
    pub canonicalized_path: Option<PathBuf>,

    /// Represents the type of the entry as a file/dir/symlink
    pub file_type: FileType,

    /// Size of the file/directory/symlink in bytes
    pub len: u64,

    /// Whether or not the file/directory/symlink is marked as unwriteable
    pub readonly: bool,

    /// Represents the last time (in milliseconds) when the file/directory/symlink was accessed;
    /// can be optional as certain systems don't support this
    // Serialized as an optional string because not all formats handle 128-bit ints
    #[serde(serialize_with = "serialize_u128_option")]
    #[serde(deserialize_with = "deserialize_u128_option")]
    pub accessed: Option<u128>,

    /// Represents when (in milliseconds) the file/directory/symlink was created;
    /// can be optional as certain systems don't support this
    #[serde(serialize_with = "serialize_u128_option")]
    #[serde(deserialize_with = "deserialize_u128_option")]
    pub created: Option<u128>,

    /// Represents the last time (in milliseconds) when the file/directory/symlink was modified;
    /// can be optional as certain systems don't support this
    #[serde(serialize_with = "serialize_u128_option")]
    #[serde(deserialize_with = "deserialize_u128_option")]
    pub modified: Option<u128>,

    /// Represents metadata that is specific to a unix remote machine
    pub unix: Option<UnixMetadata>,

    /// Represents metadata that is specific to a windows remote machine
    pub windows: Option<WindowsMetadata>,
}
impl Metadata {
    /// Reads metadata for `path` from the local filesystem without following symlinks
    ///
    /// * `canonicalize` - if true, also resolves symlinks to populate `canonicalized_path`
    /// * `resolve_file_type` - if true and the entry is a symlink, reports the file type
    ///   of the symlink's target rather than `Symlink`
    pub async fn read(
        path: impl AsRef<Path>,
        canonicalize: bool,
        resolve_file_type: bool,
    ) -> io::Result<Self> {
        let metadata = tokio::fs::symlink_metadata(path.as_ref()).await?;
        let canonicalized_path = if canonicalize {
            Some(tokio::fs::canonicalize(path.as_ref()).await?)
        } else {
            None
        };

        // If asking for resolved file type and current type is symlink, then we want to refresh
        // our metadata to get the filetype for the resolved link
        let file_type = if resolve_file_type && metadata.file_type().is_symlink() {
            tokio::fs::metadata(path).await?.file_type()
        } else {
            metadata.file_type()
        };

        Ok(Self {
            canonicalized_path,
            // Timestamps become milliseconds since the unix epoch; any platform that
            // cannot supply a timestamp yields None
            accessed: metadata
                .accessed()
                .ok()
                .and_then(|t| t.duration_since(SystemTime::UNIX_EPOCH).ok())
                .map(|d| d.as_millis()),
            created: metadata
                .created()
                .ok()
                .and_then(|t| t.duration_since(SystemTime::UNIX_EPOCH).ok())
                .map(|d| d.as_millis()),
            modified: metadata
                .modified()
                .ok()
                .and_then(|t| t.duration_since(SystemTime::UNIX_EPOCH).ok())
                .map(|d| d.as_millis()),
            len: metadata.len(),
            readonly: metadata.permissions().readonly(),
            file_type: if file_type.is_dir() {
                FileType::Dir
            } else if file_type.is_file() {
                FileType::File
            } else {
                FileType::Symlink
            },

            // Platform-specific details are populated only when compiled for the
            // matching platform
            #[cfg(unix)]
            unix: Some({
                use std::os::unix::prelude::*;
                let mode = metadata.mode();
                crate::data::UnixMetadata::from(mode)
            }),
            #[cfg(not(unix))]
            unix: None,
            #[cfg(windows)]
            windows: Some({
                use std::os::windows::prelude::*;
                let attributes = metadata.file_attributes();
                crate::data::WindowsMetadata::from(attributes)
            }),
            #[cfg(not(windows))]
            windows: None,
        })
    }
}
#[cfg(feature = "schemars")]
impl Metadata {
    /// Produces the JSON schema for [`Metadata`] (available with the `schemars` feature)
    pub fn root_schema() -> schemars::schema::RootSchema {
        schemars::schema_for!(Metadata)
    }
}
/// Represents unix-specific metadata about some path on a remote machine
///
/// Mirrors the standard unix rwx permission bits (see `UnixFilePermissionFlags`)
#[derive(Copy, Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
pub struct UnixMetadata {
    /// Represents whether or not owner can read from the file
    pub owner_read: bool,

    /// Represents whether or not owner can write to the file
    pub owner_write: bool,

    /// Represents whether or not owner can execute the file
    pub owner_exec: bool,

    /// Represents whether or not associated group can read from the file
    pub group_read: bool,

    /// Represents whether or not associated group can write to the file
    pub group_write: bool,

    /// Represents whether or not associated group can execute the file
    pub group_exec: bool,

    /// Represents whether or not other can read from the file
    pub other_read: bool,

    /// Represents whether or not other can write to the file
    pub other_write: bool,

    /// Represents whether or not other can execute the file
    pub other_exec: bool,
}
#[cfg(feature = "schemars")]
impl UnixMetadata {
    /// Produces the JSON schema for [`UnixMetadata`] (available with the `schemars` feature)
    pub fn root_schema() -> schemars::schema::RootSchema {
        schemars::schema_for!(UnixMetadata)
    }
}
impl From<u32> for UnixMetadata {
    /// Create from a unix mode bitset
    fn from(mode: u32) -> Self {
        // Any bits outside the standard rwx permission set are dropped
        let flags = UnixFilePermissionFlags::from_bits_truncate(mode);
        let has = |flag| flags.contains(flag);

        Self {
            owner_read: has(UnixFilePermissionFlags::OWNER_READ),
            owner_write: has(UnixFilePermissionFlags::OWNER_WRITE),
            owner_exec: has(UnixFilePermissionFlags::OWNER_EXEC),
            group_read: has(UnixFilePermissionFlags::GROUP_READ),
            group_write: has(UnixFilePermissionFlags::GROUP_WRITE),
            group_exec: has(UnixFilePermissionFlags::GROUP_EXEC),
            other_read: has(UnixFilePermissionFlags::OTHER_READ),
            other_write: has(UnixFilePermissionFlags::OTHER_WRITE),
            other_exec: has(UnixFilePermissionFlags::OTHER_EXEC),
        }
    }
}
impl From<UnixMetadata> for u32 {
    /// Convert to a unix mode bitset
    fn from(metadata: UnixMetadata) -> Self {
        // Pair each boolean field with its permission flag so the conversion is a
        // single data-driven loop rather than nine if-blocks
        let pairs = [
            (metadata.owner_read, UnixFilePermissionFlags::OWNER_READ),
            (metadata.owner_write, UnixFilePermissionFlags::OWNER_WRITE),
            (metadata.owner_exec, UnixFilePermissionFlags::OWNER_EXEC),
            (metadata.group_read, UnixFilePermissionFlags::GROUP_READ),
            (metadata.group_write, UnixFilePermissionFlags::GROUP_WRITE),
            (metadata.group_exec, UnixFilePermissionFlags::GROUP_EXEC),
            (metadata.other_read, UnixFilePermissionFlags::OTHER_READ),
            (metadata.other_write, UnixFilePermissionFlags::OTHER_WRITE),
            (metadata.other_exec, UnixFilePermissionFlags::OTHER_EXEC),
        ];

        let mut flags = UnixFilePermissionFlags::empty();
        for &(enabled, flag) in pairs.iter() {
            if enabled {
                flags.insert(flag);
            }
        }
        flags.bits
    }
}
impl UnixMetadata {
    /// Returns true if the file is readonly, meaning that no write permission bit
    /// (owner, group, or other) is set
    pub fn is_readonly(self) -> bool {
        // BUGFIX: previously this checked the *read* bits, which answers "is the file
        // unreadable" rather than "is the file readonly". Readonly means unwriteable,
        // consistent with `Metadata.readonly` (std's `Permissions::readonly`, which on
        // unix reports true when no write bits are set), so check the write bits.
        !(self.owner_write || self.group_write || self.other_write)
    }
}
bitflags! {
    /// Standard unix rwx permission bits in octal (owner/group/other)
    struct UnixFilePermissionFlags: u32 {
        const OWNER_READ = 0o400;
        const OWNER_WRITE = 0o200;
        const OWNER_EXEC = 0o100;
        const GROUP_READ = 0o40;
        const GROUP_WRITE = 0o20;
        const GROUP_EXEC = 0o10;
        const OTHER_READ = 0o4;
        const OTHER_WRITE = 0o2;
        const OTHER_EXEC = 0o1;
    }
}
/// Represents windows-specific metadata about some path on a remote machine
///
/// Fields mirror windows file attribute flags (see `WindowsFileAttributeFlags`)
#[derive(Copy, Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
pub struct WindowsMetadata {
    /// Represents whether or not a file or directory is an archive
    pub archive: bool,

    /// Represents whether or not a file or directory is compressed
    pub compressed: bool,

    /// Represents whether or not the file or directory is encrypted
    pub encrypted: bool,

    /// Represents whether or not a file or directory is hidden
    pub hidden: bool,

    /// Represents whether or not a directory or user data stream is configured with integrity
    pub integrity_stream: bool,

    /// Represents whether or not a file does not have other attributes set
    pub normal: bool,

    /// Represents whether or not a file or directory is not to be indexed by content indexing
    /// service
    pub not_content_indexed: bool,

    /// Represents whether or not a user data stream is not to be read by the background data
    /// integrity scanner
    pub no_scrub_data: bool,

    /// Represents whether or not the data of a file is not available immediately
    pub offline: bool,

    /// Represents whether or not a file or directory is not fully present locally
    pub recall_on_data_access: bool,

    /// Represents whether or not a file or directory has no physical representation on the local
    /// system (is virtual)
    pub recall_on_open: bool,

    /// Represents whether or not a file or directory has an associated reparse point, or a file is
    /// a symbolic link
    pub reparse_point: bool,

    /// Represents whether or not a file is a sparse file
    pub sparse_file: bool,

    /// Represents whether or not a file or directory is used partially or exclusively by the
    /// operating system
    pub system: bool,

    /// Represents whether or not a file is being used for temporary storage
    pub temporary: bool,
}
#[cfg(feature = "schemars")]
impl WindowsMetadata {
    /// Produces the JSON schema for [`WindowsMetadata`] (available with the `schemars` feature)
    pub fn root_schema() -> schemars::schema::RootSchema {
        schemars::schema_for!(WindowsMetadata)
    }
}
impl From<u32> for WindowsMetadata {
    /// Create from a windows file attribute bitset
    fn from(file_attributes: u32) -> Self {
        // Bits without a corresponding flag constant are dropped
        let flags = WindowsFileAttributeFlags::from_bits_truncate(file_attributes);
        let has = |flag| flags.contains(flag);

        Self {
            archive: has(WindowsFileAttributeFlags::ARCHIVE),
            compressed: has(WindowsFileAttributeFlags::COMPRESSED),
            encrypted: has(WindowsFileAttributeFlags::ENCRYPTED),
            hidden: has(WindowsFileAttributeFlags::HIDDEN),
            integrity_stream: has(WindowsFileAttributeFlags::INTEGRITY_SYSTEM),
            normal: has(WindowsFileAttributeFlags::NORMAL),
            not_content_indexed: has(WindowsFileAttributeFlags::NOT_CONTENT_INDEXED),
            no_scrub_data: has(WindowsFileAttributeFlags::NO_SCRUB_DATA),
            offline: has(WindowsFileAttributeFlags::OFFLINE),
            recall_on_data_access: has(WindowsFileAttributeFlags::RECALL_ON_DATA_ACCESS),
            recall_on_open: has(WindowsFileAttributeFlags::RECALL_ON_OPEN),
            reparse_point: has(WindowsFileAttributeFlags::REPARSE_POINT),
            sparse_file: has(WindowsFileAttributeFlags::SPARSE_FILE),
            system: has(WindowsFileAttributeFlags::SYSTEM),
            temporary: has(WindowsFileAttributeFlags::TEMPORARY),
        }
    }
}
impl From<WindowsMetadata> for u32 {
    /// Convert to a windows file attribute bitset
    fn from(metadata: WindowsMetadata) -> Self {
        // Pair each boolean field with its attribute flag so the conversion is a
        // single data-driven loop rather than fifteen if-blocks
        let pairs = [
            (metadata.archive, WindowsFileAttributeFlags::ARCHIVE),
            (metadata.compressed, WindowsFileAttributeFlags::COMPRESSED),
            (metadata.encrypted, WindowsFileAttributeFlags::ENCRYPTED),
            (metadata.hidden, WindowsFileAttributeFlags::HIDDEN),
            (
                metadata.integrity_stream,
                WindowsFileAttributeFlags::INTEGRITY_SYSTEM,
            ),
            (metadata.normal, WindowsFileAttributeFlags::NORMAL),
            (
                metadata.not_content_indexed,
                WindowsFileAttributeFlags::NOT_CONTENT_INDEXED,
            ),
            (metadata.no_scrub_data, WindowsFileAttributeFlags::NO_SCRUB_DATA),
            (metadata.offline, WindowsFileAttributeFlags::OFFLINE),
            (
                metadata.recall_on_data_access,
                WindowsFileAttributeFlags::RECALL_ON_DATA_ACCESS,
            ),
            (metadata.recall_on_open, WindowsFileAttributeFlags::RECALL_ON_OPEN),
            (metadata.reparse_point, WindowsFileAttributeFlags::REPARSE_POINT),
            (metadata.sparse_file, WindowsFileAttributeFlags::SPARSE_FILE),
            (metadata.system, WindowsFileAttributeFlags::SYSTEM),
            (metadata.temporary, WindowsFileAttributeFlags::TEMPORARY),
        ];

        let mut flags = WindowsFileAttributeFlags::empty();
        for &(enabled, flag) in pairs.iter() {
            if enabled {
                flags.insert(flag);
            }
        }
        flags.bits
    }
}
bitflags! {
    /// Windows file attribute constants (FILE_ATTRIBUTE_* values)
    struct WindowsFileAttributeFlags: u32 {
        const ARCHIVE = 0x20;
        const COMPRESSED = 0x800;
        const ENCRYPTED = 0x4000;
        const HIDDEN = 0x2;
        const INTEGRITY_SYSTEM = 0x8000;
        const NORMAL = 0x80;
        const NOT_CONTENT_INDEXED = 0x2000;
        const NO_SCRUB_DATA = 0x20000;
        const OFFLINE = 0x1000;
        const RECALL_ON_DATA_ACCESS = 0x400000;
        const RECALL_ON_OPEN = 0x40000;
        const REPARSE_POINT = 0x400;
        const SPARSE_FILE = 0x200;
        const SYSTEM = 0x4;
        const TEMPORARY = 0x100;
        // NOTE(review): VIRTUAL has no corresponding WindowsMetadata field and is never
        // read by the From conversions above -- confirm whether it is intentionally unused
        const VIRTUAL = 0x10000;
    }
}

@ -0,0 +1,137 @@
use derive_more::{Display, Error};
use portable_pty::PtySize as PortablePtySize;
use serde::{Deserialize, Serialize};
use std::{fmt, num::ParseIntError, str::FromStr};
/// Represents the size associated with a remote PTY
#[derive(Copy, Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
pub struct PtySize {
    /// Number of lines of text
    pub rows: u16,

    /// Number of columns of text
    pub cols: u16,

    /// Width of a cell in pixels. Note that some systems never fill this value and ignore it.
    // Defaults to 0 when absent from serialized input
    #[serde(default)]
    pub pixel_width: u16,

    /// Height of a cell in pixels. Note that some systems never fill this value and ignore it.
    #[serde(default)]
    pub pixel_height: u16,
}
impl PtySize {
    /// Creates new size using just rows and columns
    pub fn from_rows_and_cols(rows: u16, cols: u16) -> Self {
        // Pixel dimensions are left unset (zero), matching the default
        Self {
            rows,
            cols,
            pixel_width: 0,
            pixel_height: 0,
        }
    }
}
#[cfg(feature = "schemars")]
impl PtySize {
    /// Produces the JSON schema for [`PtySize`] (available with the `schemars` feature)
    pub fn root_schema() -> schemars::schema::RootSchema {
        schemars::schema_for!(PtySize)
    }
}
impl From<PortablePtySize> for PtySize {
    /// Field-for-field conversion from the `portable_pty` crate's size type
    fn from(size: PortablePtySize) -> Self {
        Self {
            rows: size.rows,
            cols: size.cols,
            pixel_width: size.pixel_width,
            pixel_height: size.pixel_height,
        }
    }
}
impl From<PtySize> for PortablePtySize {
    /// Field-for-field conversion into the `portable_pty` crate's size type
    fn from(size: PtySize) -> Self {
        Self {
            rows: size.rows,
            cols: size.cols,
            pixel_width: size.pixel_width,
            pixel_height: size.pixel_height,
        }
    }
}
impl fmt::Display for PtySize {
    /// Prints out `rows,cols[,pixel_width,pixel_height]` where the
    /// pixel width and pixel height are only included if either
    /// one of them is not zero
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "{},{}", self.rows, self.cols)?;
        match (self.pixel_width, self.pixel_height) {
            // Both pixel dimensions are unset: omit them entirely
            (0, 0) => Ok(()),
            (width, height) => write!(f, ",{},{}", width, height),
        }
    }
}
impl Default for PtySize {
    /// Standard 80x24 terminal size with no pixel dimensions
    fn default() -> Self {
        PtySize {
            rows: 24,
            cols: 80,
            pixel_width: 0,
            pixel_height: 0,
        }
    }
}
/// Errors that can occur when parsing a [`PtySize`] from a string
#[derive(Clone, Debug, PartialEq, Eq, Display, Error)]
pub enum PtySizeParseError {
    /// The rows token was absent
    MissingRows,
    /// The columns token was absent
    MissingColumns,
    /// The rows token was not a valid u16
    InvalidRows(ParseIntError),
    /// The columns token was not a valid u16
    InvalidColumns(ParseIntError),
    /// The pixel width token was not a valid u16
    InvalidPixelWidth(ParseIntError),
    /// The pixel height token was not a valid u16
    InvalidPixelHeight(ParseIntError),
}
impl FromStr for PtySize {
    type Err = PtySizeParseError;

    /// Attempts to parse a str into PtySize using one of the following formats:
    ///
    /// * rows,cols (defaults to 0 for pixel_width & pixel_height)
    /// * rows,cols,pixel_width,pixel_height
    // Tokens beyond the fourth are silently ignored since only four are consumed
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        let mut tokens = s.split(',');

        Ok(Self {
            rows: tokens
                .next()
                .ok_or(PtySizeParseError::MissingRows)?
                .trim()
                .parse()
                .map_err(PtySizeParseError::InvalidRows)?,
            cols: tokens
                .next()
                .ok_or(PtySizeParseError::MissingColumns)?
                .trim()
                .parse()
                .map_err(PtySizeParseError::InvalidColumns)?,
            // Pixel dimensions are optional and fall back to 0 when absent
            pixel_width: tokens
                .next()
                .map(|s| s.trim().parse())
                .transpose()
                .map_err(PtySizeParseError::InvalidPixelWidth)?
                .unwrap_or(0),
            pixel_height: tokens
                .next()
                .map(|s| s.trim().parse())
                .transpose()
                .map_err(PtySizeParseError::InvalidPixelHeight)?
                .unwrap_or(0),
        })
    }
}

@ -0,0 +1,45 @@
use serde::{Deserialize, Serialize};
use std::{env, path::PathBuf};
/// Represents information about a system
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
pub struct SystemInfo {
    /// Family of the operating system as described in
    /// https://doc.rust-lang.org/std/env/consts/constant.FAMILY.html
    pub family: String,

    /// Name of the specific operating system as described in
    /// https://doc.rust-lang.org/std/env/consts/constant.OS.html
    pub os: String,

    /// Architecture of the CPU as described in
    /// https://doc.rust-lang.org/std/env/consts/constant.ARCH.html
    pub arch: String,

    /// Current working directory of the running server process
    pub current_dir: PathBuf,

    /// Primary separator for path components for the current platform
    /// as defined in https://doc.rust-lang.org/std/path/constant.MAIN_SEPARATOR.html
    pub main_separator: char,
}
#[cfg(feature = "schemars")]
impl SystemInfo {
    /// Produces the JSON schema for [`SystemInfo`] (available with the `schemars` feature)
    pub fn root_schema() -> schemars::schema::RootSchema {
        schemars::schema_for!(SystemInfo)
    }
}
impl Default for SystemInfo {
    /// Captures the compile-time platform constants and the current working
    /// directory of this process (empty path if it cannot be determined)
    fn default() -> Self {
        let current_dir = env::current_dir().unwrap_or_default();

        Self {
            family: String::from(env::consts::FAMILY),
            os: String::from(env::consts::OS),
            arch: String::from(env::consts::ARCH),
            current_dir,
            main_separator: std::path::MAIN_SEPARATOR,
        }
    }
}

@ -0,0 +1,27 @@
use serde::{Deserialize, Serialize};
pub(crate) fn deserialize_u128_option<'de, D>(deserializer: D) -> Result<Option<u128>, D::Error>
where
D: serde::Deserializer<'de>,
{
match Option::<String>::deserialize(deserializer)? {
Some(s) => match s.parse::<u128>() {
Ok(value) => Ok(Some(value)),
Err(error) => Err(serde::de::Error::custom(format!(
"Cannot convert to u128 with error: {:?}",
error
))),
},
None => Ok(None),
}
}
/// Serializes an optional u128 as its decimal string form when present,
/// or as a unit when absent
pub(crate) fn serialize_u128_option<S: serde::Serializer>(
    val: &Option<u128>,
    s: S,
) -> Result<S::Ok, S::Error> {
    if let Some(v) = val {
        v.to_string().serialize(s)
    } else {
        s.serialize_unit()
    }
}

@ -1,13 +1,20 @@
mod api;
pub use api::*;
mod client;
pub use client::*;
mod constants;
mod net;
pub use net::*;
mod credentials;
pub use credentials::*;
pub mod data;
pub use data::*;
pub use data::{DistantMsg, DistantRequestData, DistantResponseData, Map};
mod manager;
pub use manager::*;
mod constants;
mod serde_str;
mod server;
pub use server::*;
/// Re-export of `distant-net` as `net`
pub use distant_net as net;

@ -0,0 +1,7 @@
mod client;
mod data;
mod server;
pub use client::*;
pub use data::*;
pub use server::*;

@ -0,0 +1,761 @@
use super::data::{
ConnectionId, ConnectionInfo, ConnectionList, Destination, Extra, ManagerRequest,
ManagerResponse,
};
use crate::{DistantChannel, DistantClient, DistantMsg, DistantRequestData, DistantResponseData};
use distant_net::{
router, Auth, AuthServer, Client, IntoSplit, MpscTransport, OneshotListener, Request, Response,
ServerExt, ServerRef, UntypedTransportRead, UntypedTransportWrite,
};
use log::*;
use std::{
collections::HashMap,
io,
ops::{Deref, DerefMut},
};
use tokio::task::JoinHandle;
mod config;
pub use config::*;
mod ext;
pub use ext::*;
// Splits a single underlying transport into two logical transports: one for
// authentication traffic and one for manager requests/responses
router!(DistantManagerClientRouter {
    auth_transport: Request<Auth> => Response<Auth>,
    manager_transport: Response<ManagerResponse> => Request<ManagerRequest>,
});
/// Represents a client that can connect to a remote distant manager
pub struct DistantManagerClient {
    /// Locally-running server answering the manager's authentication requests
    auth: Box<dyn ServerRef>,
    /// Request/response client for manager traffic
    client: Client<ManagerRequest, ManagerResponse>,
    /// Cache of per-connection distant clients so repeated channel opens reuse
    /// the same underlying client
    distant_clients: HashMap<ConnectionId, ClientHandle>,
}
impl Drop for DistantManagerClient {
    /// Tears down the auth server and manager client when dropped
    fn drop(&mut self) {
        self.auth.abort();
        self.client.abort();
    }
}
/// Represents a raw channel between a manager client and some remote server
pub struct RawDistantChannel {
    /// In-memory transport carrying distant requests/responses for the channel
    pub transport: MpscTransport<
        Request<DistantMsg<DistantRequestData>>,
        Response<DistantMsg<DistantResponseData>>,
    >,
    /// Task forwarding outbound requests to the manager
    forward_task: JoinHandle<()>,
    /// Task delivering inbound responses from the manager's mailbox
    mailbox_task: JoinHandle<()>,
}
impl RawDistantChannel {
    /// Aborts the background forwarding and mailbox tasks backing this channel
    pub fn abort(&self) {
        self.forward_task.abort();
        self.mailbox_task.abort();
    }
}
impl Deref for RawDistantChannel {
    type Target = MpscTransport<
        Request<DistantMsg<DistantRequestData>>,
        Response<DistantMsg<DistantResponseData>>,
    >;

    /// Exposes the underlying transport directly
    fn deref(&self) -> &Self::Target {
        &self.transport
    }
}
impl DerefMut for RawDistantChannel {
    /// Exposes the underlying transport mutably
    fn deref_mut(&mut self) -> &mut Self::Target {
        &mut self.transport
    }
}
/// Internal pairing of a distant client with the background tasks that keep
/// its channel alive
struct ClientHandle {
    client: DistantClient,
    forward_task: JoinHandle<()>,
    mailbox_task: JoinHandle<()>,
}
impl Drop for ClientHandle {
    /// Stops the background tasks so the channel does not outlive its client
    fn drop(&mut self) {
        self.forward_task.abort();
        self.mailbox_task.abort();
    }
}
impl DistantManagerClient {
    /// Initializes a client using the provided [`UntypedTransport`]
    ///
    /// The transport is routed into an auth side (served locally by an [`AuthServer`]
    /// built from `config`'s callbacks) and a manager side used for requests/responses
    pub fn new<T>(config: DistantManagerClientConfig, transport: T) -> io::Result<Self>
    where
        T: IntoSplit + 'static,
        T::Read: UntypedTransportRead + 'static,
        T::Write: UntypedTransportWrite + 'static,
    {
        let DistantManagerClientRouter {
            auth_transport,
            manager_transport,
            ..
        } = DistantManagerClientRouter::new(transport);

        // Initialize our client with manager request/response transport
        let (writer, reader) = manager_transport.into_split();
        let client = Client::new(writer, reader)?;

        // Initialize our auth handler with auth/auth transport
        let auth = AuthServer {
            on_challenge: config.on_challenge,
            on_verify: config.on_verify,
            on_info: config.on_info,
            on_error: config.on_error,
        }
        .start(OneshotListener::from_value(auth_transport.into_split()))?;

        Ok(Self {
            auth,
            client,
            distant_clients: HashMap::new(),
        })
    }

    /// Request that the manager launches a new server at the given `destination`
    /// with `extra` being passed for destination-specific details, returning the new
    /// `destination` of the spawned server to connect to
    pub async fn launch(
        &mut self,
        destination: impl Into<Destination>,
        extra: impl Into<Extra>,
    ) -> io::Result<Destination> {
        let destination = Box::new(destination.into());
        let extra = extra.into();
        trace!("launch({}, {})", destination, extra);

        let res = self
            .client
            .send(ManagerRequest::Launch { destination, extra })
            .await?;
        // Any payload other than Launched/Error indicates a protocol mismatch
        match res.payload {
            ManagerResponse::Launched { destination } => Ok(destination),
            ManagerResponse::Error(x) => Err(x.into()),
            x => Err(io::Error::new(
                io::ErrorKind::InvalidData,
                format!("Got unexpected response: {:?}", x),
            )),
        }
    }

    /// Request that the manager establishes a new connection at the given `destination`
    /// with `extra` being passed for destination-specific details
    pub async fn connect(
        &mut self,
        destination: impl Into<Destination>,
        extra: impl Into<Extra>,
    ) -> io::Result<ConnectionId> {
        let destination = Box::new(destination.into());
        let extra = extra.into();
        trace!("connect({}, {})", destination, extra);

        let res = self
            .client
            .send(ManagerRequest::Connect { destination, extra })
            .await?;
        match res.payload {
            ManagerResponse::Connected { id } => Ok(id),
            ManagerResponse::Error(x) => Err(x.into()),
            x => Err(io::Error::new(
                io::ErrorKind::InvalidData,
                format!("Got unexpected response: {:?}", x),
            )),
        }
    }

    /// Establishes a channel with the server represented by the `connection_id`,
    /// returning a [`DistantChannel`] acting as the connection
    ///
    /// ### Note
    ///
    /// Multiple calls to open a channel against the same connection will result in
    /// clones of the same [`DistantChannel`] rather than establishing a duplicate
    /// remote connection to the same server
    pub async fn open_channel(
        &mut self,
        connection_id: ConnectionId,
    ) -> io::Result<DistantChannel> {
        trace!("open_channel({})", connection_id);
        // Reuse a cached client when we already have a channel to this connection
        if let Some(handle) = self.distant_clients.get(&connection_id) {
            Ok(handle.client.clone_channel())
        } else {
            let RawDistantChannel {
                transport,
                forward_task,
                mailbox_task,
            } = self.open_raw_channel(connection_id).await?;
            let (writer, reader) = transport.into_split();
            let client = DistantClient::new(writer, reader)?;
            let channel = client.clone_channel();
            // Cache the client (and its tasks) so later calls share this channel
            self.distant_clients.insert(
                connection_id,
                ClientHandle {
                    client,
                    forward_task,
                    mailbox_task,
                },
            );
            Ok(channel)
        }
    }

    /// Establishes a channel with the server represented by the `connection_id`,
    /// returning a [`Transport`] acting as the connection
    ///
    /// ### Note
    ///
    /// Multiple calls to open a channel against the same connection will result in establishing a
    /// duplicate remote connections to the same server, so take care when using this method
    pub async fn open_raw_channel(
        &mut self,
        connection_id: ConnectionId,
    ) -> io::Result<RawDistantChannel> {
        trace!("open_raw_channel({})", connection_id);
        let mut mailbox = self
            .client
            .mail(ManagerRequest::OpenChannel { id: connection_id })
            .await?;

        // Wait for the first response, which should be channel confirmation
        let channel_id = match mailbox.next().await {
            Some(response) => match response.payload {
                ManagerResponse::ChannelOpened { id } => Ok(id),
                ManagerResponse::Error(x) => Err(x.into()),
                x => Err(io::Error::new(
                    io::ErrorKind::InvalidData,
                    format!("Got unexpected response: {:?}", x),
                )),
            },
            None => Err(io::Error::new(
                io::ErrorKind::ConnectionAborted,
                "open_channel mailbox aborted",
            )),
        }?;

        // Spawn reader and writer tasks to forward requests and replies
        // using our opened channel
        let (t1, t2) = MpscTransport::pair(1);
        let (mut writer, mut reader) = t1.into_split();

        // Mailbox task: push responses from the manager's mailbox into our transport,
        // stopping when the manager reports the channel closed
        let mailbox_task = tokio::spawn(async move {
            use distant_net::TypedAsyncWrite;
            while let Some(response) = mailbox.next().await {
                match response.payload {
                    ManagerResponse::Channel { response, .. } => {
                        if let Err(x) = writer.write(response).await {
                            error!("[Conn {}] {}", connection_id, x);
                        }
                    }
                    ManagerResponse::ChannelClosed { .. } => break,
                    _ => continue,
                }
            }
        });

        // Forward task: wrap requests read from our transport into channel messages
        // and fire them at the manager
        let mut manager_channel = self.client.clone_channel();
        let forward_task = tokio::spawn(async move {
            use distant_net::TypedAsyncRead;
            loop {
                match reader.read().await {
                    Ok(Some(request)) => {
                        // NOTE: In this situation, we do not expect a response to this
                        // request (even if the server sends something back)
                        if let Err(x) = manager_channel
                            .fire(ManagerRequest::Channel {
                                id: channel_id,
                                request,
                            })
                            .await
                        {
                            error!("[Conn {}] {}", connection_id, x);
                        }
                    }
                    Ok(None) => break,
                    Err(x) => {
                        error!("[Conn {}] {}", connection_id, x);
                        continue;
                    }
                }
            }
        });

        Ok(RawDistantChannel {
            transport: t2,
            forward_task,
            mailbox_task,
        })
    }

    /// Retrieves information about a specific connection
    pub async fn info(&mut self, id: ConnectionId) -> io::Result<ConnectionInfo> {
        trace!("info({})", id);
        let res = self.client.send(ManagerRequest::Info { id }).await?;
        match res.payload {
            ManagerResponse::Info(info) => Ok(info),
            ManagerResponse::Error(x) => Err(x.into()),
            x => Err(io::Error::new(
                io::ErrorKind::InvalidData,
                format!("Got unexpected response: {:?}", x),
            )),
        }
    }

    /// Kills the specified connection
    pub async fn kill(&mut self, id: ConnectionId) -> io::Result<()> {
        trace!("kill({})", id);
        let res = self.client.send(ManagerRequest::Kill { id }).await?;
        match res.payload {
            ManagerResponse::Killed => Ok(()),
            ManagerResponse::Error(x) => Err(x.into()),
            x => Err(io::Error::new(
                io::ErrorKind::InvalidData,
                format!("Got unexpected response: {:?}", x),
            )),
        }
    }

    /// Retrieves a list of active connections
    pub async fn list(&mut self) -> io::Result<ConnectionList> {
        trace!("list()");
        let res = self.client.send(ManagerRequest::List).await?;
        match res.payload {
            ManagerResponse::List(list) => Ok(list),
            ManagerResponse::Error(x) => Err(x.into()),
            x => Err(io::Error::new(
                io::ErrorKind::InvalidData,
                format!("Got unexpected response: {:?}", x),
            )),
        }
    }

    /// Requests that the manager shuts down
    pub async fn shutdown(&mut self) -> io::Result<()> {
        trace!("shutdown()");
        let res = self.client.send(ManagerRequest::Shutdown).await?;
        match res.payload {
            ManagerResponse::Shutdown => Ok(()),
            ManagerResponse::Error(x) => Err(x.into()),
            x => Err(io::Error::new(
                io::ErrorKind::InvalidData,
                format!("Got unexpected response: {:?}", x),
            )),
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::data::{Error, ErrorKind};
    use distant_net::{
        FramedTransport, InmemoryTransport, PlainCodec, UntypedTransportRead, UntypedTransportWrite,
    };
    /// Creates a client along with the far side of its transport; each test
    /// drives the returned transport to play the role of the manager
    fn setup() -> (
        DistantManagerClient,
        FramedTransport<InmemoryTransport, PlainCodec>,
    ) {
        let (t1, t2) = FramedTransport::pair(100);
        let client =
            DistantManagerClient::new(DistantManagerClientConfig::with_empty_prompts(), t1)
                .unwrap();
        (client, t2)
    }
    /// Canonical error payload used by the failure tests
    #[inline]
    fn test_error() -> Error {
        Error {
            kind: ErrorKind::Interrupted,
            description: "test error".to_string(),
        }
    }
    /// The [`io::Error`] that [`test_error`] converts into
    #[inline]
    fn test_io_error() -> io::Error {
        test_error().into()
    }
    #[tokio::test]
    async fn connect_should_report_error_if_receives_error_response() {
        let (mut client, mut transport) = setup();
        // Fake manager: read the request and reply with an error
        tokio::spawn(async move {
            let request = transport
                .read::<Request<ManagerRequest>>()
                .await
                .unwrap()
                .unwrap();
            transport
                .write(Response::new(
                    request.id,
                    ManagerResponse::Error(test_error()),
                ))
                .await
                .unwrap();
        });
        let err = client
            .connect(
                "scheme://host".parse::<Destination>().unwrap(),
                "key=value".parse::<Extra>().unwrap(),
            )
            .await
            .unwrap_err();
        assert_eq!(err.kind(), test_io_error().kind());
        assert_eq!(err.to_string(), test_io_error().to_string());
    }
    #[tokio::test]
    async fn connect_should_report_error_if_receives_unexpected_response() {
        let (mut client, mut transport) = setup();
        // Fake manager: reply with a response variant connect() cannot handle
        tokio::spawn(async move {
            let request = transport
                .read::<Request<ManagerRequest>>()
                .await
                .unwrap()
                .unwrap();
            transport
                .write(Response::new(request.id, ManagerResponse::Shutdown))
                .await
                .unwrap();
        });
        let err = client
            .connect(
                "scheme://host".parse::<Destination>().unwrap(),
                "key=value".parse::<Extra>().unwrap(),
            )
            .await
            .unwrap_err();
        assert_eq!(err.kind(), io::ErrorKind::InvalidData);
    }
    #[tokio::test]
    async fn connect_should_return_id_from_successful_response() {
        let (mut client, mut transport) = setup();
        let expected_id = 999;
        tokio::spawn(async move {
            let request = transport
                .read::<Request<ManagerRequest>>()
                .await
                .unwrap()
                .unwrap();
            transport
                .write(Response::new(
                    request.id,
                    ManagerResponse::Connected { id: expected_id },
                ))
                .await
                .unwrap();
        });
        let id = client
            .connect(
                "scheme://host".parse::<Destination>().unwrap(),
                "key=value".parse::<Extra>().unwrap(),
            )
            .await
            .unwrap();
        assert_eq!(id, expected_id);
    }
    #[tokio::test]
    async fn info_should_report_error_if_receives_error_response() {
        let (mut client, mut transport) = setup();
        tokio::spawn(async move {
            let request = transport
                .read::<Request<ManagerRequest>>()
                .await
                .unwrap()
                .unwrap();
            transport
                .write(Response::new(
                    request.id,
                    ManagerResponse::Error(test_error()),
                ))
                .await
                .unwrap();
        });
        let err = client.info(123).await.unwrap_err();
        assert_eq!(err.kind(), test_io_error().kind());
        assert_eq!(err.to_string(), test_io_error().to_string());
    }
    #[tokio::test]
    async fn info_should_report_error_if_receives_unexpected_response() {
        let (mut client, mut transport) = setup();
        tokio::spawn(async move {
            let request = transport
                .read::<Request<ManagerRequest>>()
                .await
                .unwrap()
                .unwrap();
            transport
                .write(Response::new(request.id, ManagerResponse::Shutdown))
                .await
                .unwrap();
        });
        let err = client.info(123).await.unwrap_err();
        assert_eq!(err.kind(), io::ErrorKind::InvalidData);
    }
    #[tokio::test]
    async fn info_should_return_connection_info_from_successful_response() {
        let (mut client, mut transport) = setup();
        tokio::spawn(async move {
            let request = transport
                .read::<Request<ManagerRequest>>()
                .await
                .unwrap()
                .unwrap();
            let info = ConnectionInfo {
                id: 123,
                destination: "scheme://host".parse::<Destination>().unwrap(),
                extra: "key=value".parse::<Extra>().unwrap(),
            };
            transport
                .write(Response::new(request.id, ManagerResponse::Info(info)))
                .await
                .unwrap();
        });
        let info = client.info(123).await.unwrap();
        assert_eq!(info.id, 123);
        assert_eq!(
            info.destination,
            "scheme://host".parse::<Destination>().unwrap()
        );
        assert_eq!(info.extra, "key=value".parse::<Extra>().unwrap());
    }
    #[tokio::test]
    async fn list_should_report_error_if_receives_error_response() {
        let (mut client, mut transport) = setup();
        tokio::spawn(async move {
            let request = transport
                .read::<Request<ManagerRequest>>()
                .await
                .unwrap()
                .unwrap();
            transport
                .write(Response::new(
                    request.id,
                    ManagerResponse::Error(test_error()),
                ))
                .await
                .unwrap();
        });
        let err = client.list().await.unwrap_err();
        assert_eq!(err.kind(), test_io_error().kind());
        assert_eq!(err.to_string(), test_io_error().to_string());
    }
    #[tokio::test]
    async fn list_should_report_error_if_receives_unexpected_response() {
        let (mut client, mut transport) = setup();
        tokio::spawn(async move {
            let request = transport
                .read::<Request<ManagerRequest>>()
                .await
                .unwrap()
                .unwrap();
            transport
                .write(Response::new(request.id, ManagerResponse::Shutdown))
                .await
                .unwrap();
        });
        let err = client.list().await.unwrap_err();
        assert_eq!(err.kind(), io::ErrorKind::InvalidData);
    }
    #[tokio::test]
    async fn list_should_return_connection_list_from_successful_response() {
        let (mut client, mut transport) = setup();
        tokio::spawn(async move {
            let request = transport
                .read::<Request<ManagerRequest>>()
                .await
                .unwrap()
                .unwrap();
            let mut list = ConnectionList::new();
            list.insert(123, "scheme://host".parse::<Destination>().unwrap());
            transport
                .write(Response::new(request.id, ManagerResponse::List(list)))
                .await
                .unwrap();
        });
        let list = client.list().await.unwrap();
        assert_eq!(list.len(), 1);
        assert_eq!(
            list.get(&123).expect("Connection list missing item"),
            &"scheme://host".parse::<Destination>().unwrap()
        );
    }
    #[tokio::test]
    async fn kill_should_report_error_if_receives_error_response() {
        let (mut client, mut transport) = setup();
        tokio::spawn(async move {
            let request = transport
                .read::<Request<ManagerRequest>>()
                .await
                .unwrap()
                .unwrap();
            transport
                .write(Response::new(
                    request.id,
                    ManagerResponse::Error(test_error()),
                ))
                .await
                .unwrap();
        });
        let err = client.kill(123).await.unwrap_err();
        assert_eq!(err.kind(), test_io_error().kind());
        assert_eq!(err.to_string(), test_io_error().to_string());
    }
    #[tokio::test]
    async fn kill_should_report_error_if_receives_unexpected_response() {
        let (mut client, mut transport) = setup();
        tokio::spawn(async move {
            let request = transport
                .read::<Request<ManagerRequest>>()
                .await
                .unwrap()
                .unwrap();
            transport
                .write(Response::new(request.id, ManagerResponse::Shutdown))
                .await
                .unwrap();
        });
        let err = client.kill(123).await.unwrap_err();
        assert_eq!(err.kind(), io::ErrorKind::InvalidData);
    }
    #[tokio::test]
    async fn kill_should_return_success_from_successful_response() {
        let (mut client, mut transport) = setup();
        tokio::spawn(async move {
            let request = transport
                .read::<Request<ManagerRequest>>()
                .await
                .unwrap()
                .unwrap();
            transport
                .write(Response::new(request.id, ManagerResponse::Killed))
                .await
                .unwrap();
        });
        client.kill(123).await.unwrap();
    }
    // FIX: the two shutdown failure tests below had swapped names — the test
    // that replies with `Connected` (an unexpected response) was named
    // "...error_response" and vice versa. Names now match their scenarios.
    #[tokio::test]
    async fn shutdown_should_report_error_if_receives_unexpected_response() {
        let (mut client, mut transport) = setup();
        tokio::spawn(async move {
            let request = transport
                .read::<Request<ManagerRequest>>()
                .await
                .unwrap()
                .unwrap();
            transport
                .write(Response::new(
                    request.id,
                    ManagerResponse::Connected { id: 0 },
                ))
                .await
                .unwrap();
        });
        let err = client.shutdown().await.unwrap_err();
        assert_eq!(err.kind(), io::ErrorKind::InvalidData);
    }
    #[tokio::test]
    async fn shutdown_should_report_error_if_receives_error_response() {
        let (mut client, mut transport) = setup();
        tokio::spawn(async move {
            let request = transport
                .read::<Request<ManagerRequest>>()
                .await
                .unwrap()
                .unwrap();
            transport
                .write(Response::new(
                    request.id,
                    ManagerResponse::Error(test_error()),
                ))
                .await
                .unwrap();
        });
        let err = client.shutdown().await.unwrap_err();
        assert_eq!(err.kind(), test_io_error().kind());
        assert_eq!(err.to_string(), test_io_error().to_string());
    }
    #[tokio::test]
    async fn shutdown_should_return_success_from_successful_response() {
        let (mut client, mut transport) = setup();
        tokio::spawn(async move {
            let request = transport
                .read::<Request<ManagerRequest>>()
                .await
                .unwrap()
                .unwrap();
            transport
                .write(Response::new(request.id, ManagerResponse::Shutdown))
                .await
                .unwrap();
        });
        client.shutdown().await.unwrap();
    }
}

@ -0,0 +1,85 @@
use distant_net::{AuthChallengeFn, AuthErrorFn, AuthInfoFn, AuthVerifyFn, AuthVerifyKind};
use log::*;
use std::io;
/// Configuration to use when creating a new [`DistantManagerClient`](super::DistantManagerClient)
pub struct DistantManagerClientConfig {
    /// Invoked to answer authentication challenge questions (secret input)
    pub on_challenge: Box<AuthChallengeFn>,
    /// Invoked to confirm trust decisions such as host verification
    pub on_verify: Box<AuthVerifyFn>,
    /// Invoked with informational messages to show the user
    pub on_info: Box<AuthInfoFn>,
    /// Invoked with error details to report to the user
    pub on_error: Box<AuthErrorFn>,
}
impl DistantManagerClientConfig {
    /// Creates a new config with prompts that return empty strings
    pub fn with_empty_prompts() -> Self {
        Self::with_prompts(|_| Ok("".to_string()), |_| Ok("".to_string()))
    }
    /// Creates a new config with two prompts
    ///
    /// * `password_prompt` - used for prompting for a secret, and should not display what is typed
    /// * `text_prompt` - used for general text, and is okay to display what is typed
    pub fn with_prompts<PP, PT>(password_prompt: PP, text_prompt: PT) -> Self
    where
        PP: Fn(&str) -> io::Result<String> + Send + Sync + 'static,
        PT: Fn(&str) -> io::Result<String> + Send + Sync + 'static,
    {
        Self {
            on_challenge: Box::new(move |questions, _extra| {
                trace!("[manager client] on_challenge({questions:?}, {_extra:?})");
                let mut answers = Vec::new();
                for question in questions.iter() {
                    // Contains all prompt lines including same line
                    let mut lines = question.text.split('\n').collect::<Vec<_>>();
                    // Line that is prompt on same line as answer
                    // (safe to unwrap: split always yields at least one element)
                    let line = lines.pop().unwrap();
                    // Go ahead and display all other lines
                    for line in lines.into_iter() {
                        eprintln!("{}", line);
                    }
                    // Get an answer from user input, or use a blank string as an answer
                    // if we fail to get input from the user
                    let answer = password_prompt(line).unwrap_or_default();
                    answers.push(answer);
                }
                answers
            }),
            on_verify: Box::new(move |kind, text| {
                trace!("[manager client] on_verify({kind}, {text})");
                // Only host verification is supported; any other kind is rejected
                match kind {
                    AuthVerifyKind::Host => {
                        eprintln!("{}", text);
                        match text_prompt("Enter [y/N]> ") {
                            Ok(answer) => {
                                trace!("Verify? Answer = '{answer}'");
                                // Any answer other than an explicit yes is a refusal
                                matches!(answer.trim(), "y" | "Y" | "yes" | "YES")
                            }
                            Err(x) => {
                                error!("Failed verification: {x}");
                                false
                            }
                        }
                    }
                    x => {
                        error!("Unsupported verify kind: {x}");
                        false
                    }
                }
            }),
            on_info: Box::new(|text| {
                trace!("[manager client] on_info({text})");
                println!("{}", text);
            }),
            on_error: Box::new(|kind, text| {
                trace!("[manager client] on_error({kind}, {text})");
                eprintln!("{}: {}", kind, text);
            }),
        }
    }
}

@ -0,0 +1,14 @@
mod tcp;
pub use tcp::*;
#[cfg(unix)]
mod unix;
#[cfg(unix)]
pub use unix::*;
#[cfg(windows)]
mod windows;
#[cfg(windows)]
pub use windows::*;

@ -0,0 +1,50 @@
use crate::{DistantManagerClient, DistantManagerClientConfig};
use async_trait::async_trait;
use distant_net::{Codec, FramedTransport, TcpTransport};
use std::{convert, net::SocketAddr};
use tokio::{io, time::Duration};
#[async_trait]
pub trait TcpDistantManagerClientExt {
    /// Connect to a remote TCP server using the provided information
    async fn connect<C>(
        config: DistantManagerClientConfig,
        addr: SocketAddr,
        codec: C,
    ) -> io::Result<DistantManagerClient>
    where
        C: Codec + Send + 'static;
    /// Connect to a remote TCP server, timing out after duration has passed
    async fn connect_timeout<C>(
        config: DistantManagerClientConfig,
        addr: SocketAddr,
        codec: C,
        duration: Duration,
    ) -> io::Result<DistantManagerClient>
    where
        C: Codec + Send + 'static,
    {
        // Race connect() against the timer; an elapsed timer maps to
        // io::ErrorKind::TimedOut, while an in-time result (Ok(io::Result))
        // is flattened into io::Result via convert::identity
        tokio::time::timeout(duration, Self::connect(config, addr, codec))
            .await
            .map_err(|x| io::Error::new(io::ErrorKind::TimedOut, x))
            .and_then(convert::identity)
    }
}
#[async_trait]
impl TcpDistantManagerClientExt for DistantManagerClient {
    /// Connect to a remote TCP server using the provided information
    async fn connect<C>(
        config: DistantManagerClientConfig,
        addr: SocketAddr,
        codec: C,
    ) -> io::Result<DistantManagerClient>
    where
        C: Codec + Send + 'static,
    {
        // Establish the raw TCP connection, then layer on the codec framing
        let framed = FramedTransport::new(TcpTransport::connect(addr).await?, codec);
        Self::new(config, framed)
    }
}

@ -0,0 +1,54 @@
use crate::{DistantManagerClient, DistantManagerClientConfig};
use async_trait::async_trait;
use distant_net::{Codec, FramedTransport, UnixSocketTransport};
use std::{convert, path::Path};
use tokio::{io, time::Duration};
#[async_trait]
pub trait UnixSocketDistantManagerClientExt {
    /// Connect to a proxy unix socket
    async fn connect<P, C>(
        config: DistantManagerClientConfig,
        path: P,
        codec: C,
    ) -> io::Result<DistantManagerClient>
    where
        P: AsRef<Path> + Send,
        C: Codec + Send + 'static;
    /// Connect to a proxy unix socket, timing out after duration has passed
    async fn connect_timeout<P, C>(
        config: DistantManagerClientConfig,
        path: P,
        codec: C,
        duration: Duration,
    ) -> io::Result<DistantManagerClient>
    where
        P: AsRef<Path> + Send,
        C: Codec + Send + 'static,
    {
        // Race connect() against the timer; an elapsed timer maps to
        // io::ErrorKind::TimedOut, while an in-time result (Ok(io::Result))
        // is flattened into io::Result via convert::identity
        tokio::time::timeout(duration, Self::connect(config, path, codec))
            .await
            .map_err(|x| io::Error::new(io::ErrorKind::TimedOut, x))
            .and_then(convert::identity)
    }
}
#[async_trait]
impl UnixSocketDistantManagerClientExt for DistantManagerClient {
    /// Connect to a proxy unix socket
    async fn connect<P, C>(
        config: DistantManagerClientConfig,
        path: P,
        codec: C,
    ) -> io::Result<DistantManagerClient>
    where
        P: AsRef<Path> + Send,
        C: Codec + Send + 'static,
    {
        // Establish the raw socket connection, then layer on the codec framing
        let p = path.as_ref();
        let transport = UnixSocketTransport::connect(p).await?;
        let transport = FramedTransport::new(transport, codec);
        // `DistantManagerClient::new` already returns an io::Result, so return
        // it directly instead of the redundant `Ok(..?)`
        // (clippy::needless_question_mark); also matches the TCP ext impl
        DistantManagerClient::new(config, transport)
    }
}

@ -0,0 +1,91 @@
use crate::{DistantManagerClient, DistantManagerClientConfig};
use async_trait::async_trait;
use distant_net::{Codec, FramedTransport, WindowsPipeTransport};
use std::{
convert,
ffi::{OsStr, OsString},
};
use tokio::{io, time::Duration};
#[async_trait]
pub trait WindowsPipeDistantManagerClientExt {
    /// Connect to a server listening on a Windows pipe at the specified address
    /// using the given codec
    async fn connect<A, C>(
        config: DistantManagerClientConfig,
        addr: A,
        codec: C,
    ) -> io::Result<DistantManagerClient>
    where
        A: AsRef<OsStr> + Send,
        C: Codec + Send + 'static;
    /// Connect to a server listening on a Windows pipe at the specified address
    /// via `\\.\pipe\{name}` using the given codec
    async fn connect_local<N, C>(
        config: DistantManagerClientConfig,
        name: N,
        codec: C,
    ) -> io::Result<DistantManagerClient>
    where
        N: AsRef<OsStr> + Send,
        C: Codec + Send + 'static,
    {
        // Prefix the pipe name with the local pipe namespace
        let mut addr = OsString::from(r"\\.\pipe\");
        addr.push(name.as_ref());
        Self::connect(config, addr, codec).await
    }
    /// Connect to a server listening on a Windows pipe at the specified address
    /// using the given codec, timing out after duration has passed
    async fn connect_timeout<A, C>(
        config: DistantManagerClientConfig,
        addr: A,
        codec: C,
        duration: Duration,
    ) -> io::Result<DistantManagerClient>
    where
        A: AsRef<OsStr> + Send,
        C: Codec + Send + 'static,
    {
        // Race connect() against the timer; an elapsed timer maps to
        // io::ErrorKind::TimedOut, while an in-time result (Ok(io::Result))
        // is flattened into io::Result via convert::identity
        tokio::time::timeout(duration, Self::connect(config, addr, codec))
            .await
            .map_err(|x| io::Error::new(io::ErrorKind::TimedOut, x))
            .and_then(convert::identity)
    }
    /// Connect to a server listening on a Windows pipe at the specified address
    /// via `\\.\pipe\{name}` using the given codec, timing out after duration has passed
    async fn connect_local_timeout<N, C>(
        config: DistantManagerClientConfig,
        name: N,
        codec: C,
        duration: Duration,
    ) -> io::Result<DistantManagerClient>
    where
        N: AsRef<OsStr> + Send,
        C: Codec + Send + 'static,
    {
        // Prefix the pipe name with the local pipe namespace
        let mut addr = OsString::from(r"\\.\pipe\");
        addr.push(name.as_ref());
        Self::connect_timeout(config, addr, codec, duration).await
    }
}
#[async_trait]
impl WindowsPipeDistantManagerClientExt for DistantManagerClient {
    /// Connect to a server listening on a Windows pipe at the specified address
    /// using the given codec
    async fn connect<A, C>(
        config: DistantManagerClientConfig,
        addr: A,
        codec: C,
    ) -> io::Result<DistantManagerClient>
    where
        A: AsRef<OsStr> + Send,
        C: Codec + Send + 'static,
    {
        // Establish the raw pipe connection, then layer on the codec framing
        let a = addr.as_ref();
        let transport = WindowsPipeTransport::connect(a).await?;
        let transport = FramedTransport::new(transport, codec);
        // `DistantManagerClient::new` already returns an io::Result, so return
        // it directly instead of the redundant `Ok(..?)`
        // (clippy::needless_question_mark); also matches the TCP ext impl
        DistantManagerClient::new(config, transport)
    }
}

@ -0,0 +1,20 @@
mod destination;
pub use destination::*;
mod extra;
pub use extra::*;
mod id;
pub use id::*;
mod info;
pub use info::*;
mod list;
pub use list::*;
mod request;
pub use request::*;
mod response;
pub use response::*;

@ -0,0 +1,266 @@
use crate::serde_str::{deserialize_from_str, serialize_to_str};
use derive_more::{Display, Error, From};
use serde::{de::Deserializer, ser::Serializer, Deserialize, Serialize};
use std::{convert::TryFrom, fmt, hash::Hash, str::FromStr};
use uriparse::{
Authority, AuthorityError, Host, Password, Scheme, URIReference, URIReferenceError, Username,
URI,
};
/// Represents an error that occurs when trying to parse a destination from a str
#[derive(Copy, Clone, Debug, Display, Error, From, PartialEq, Eq)]
pub enum DestinationError {
    /// Input had no usable host (empty/whitespace input also maps here)
    MissingHost,
    /// Underlying URI-reference parsing failed
    URIReferenceError(URIReferenceError),
}
/// `distant` connects and logs into the specified destination, which may be specified as either
/// `hostname:port` where an attempt to connect to a **distant** server will be made, or a URI of
/// one of the following forms:
///
/// * `distant://hostname:port` - connect to a distant server
/// * `ssh://[user@]hostname[:port]` - connect to an SSH server
#[derive(Clone, Debug, Hash, PartialEq, Eq)]
// INVARIANT: the inner URI reference always has a host (enforced at construction)
pub struct Destination(URIReference<'static>);
impl Destination {
    /// Returns a reference to the scheme associated with the destination, if it has one
    pub fn scheme(&self) -> Option<&str> {
        self.0.scheme().map(Scheme::as_str)
    }
    /// Returns the host of the destination as a string
    pub fn to_host_string(&self) -> String {
        // NOTE: We guarantee that there is a host for a destination during construction
        self.0.host().unwrap().to_string()
    }
    /// Returns the port tied to the destination, if it has one
    pub fn port(&self) -> Option<u16> {
        self.0.port()
    }
    /// Returns the username tied with the destination if it has one
    pub fn username(&self) -> Option<&str> {
        self.0.username().map(Username::as_str)
    }
    /// Returns the password tied with the destination if it has one
    pub fn password(&self) -> Option<&str> {
        self.0.password().map(Password::as_str)
    }
    /// Replaces the host of the destination
    ///
    /// Existing username, password, and port are preserved; only the host
    /// changes. Fails if `host` cannot be parsed or the rebuilt authority
    /// is invalid.
    pub fn replace_host(&mut self, host: &str) -> Result<(), URIReferenceError> {
        let username = self
            .0
            .username()
            .map(Username::as_borrowed)
            .map(Username::into_owned);
        let password = self
            .0
            .password()
            .map(Password::as_borrowed)
            .map(Password::into_owned);
        let port = self.0.port();
        let _ = self.0.set_authority(Some(
            Authority::from_parts(
                username,
                password,
                Host::try_from(host)
                    .map(Host::into_owned)
                    .map_err(AuthorityError::from)
                    .map_err(URIReferenceError::from)?,
                port,
            )
            .map(Authority::into_owned)
            .map_err(URIReferenceError::from)?,
        ))?;
        Ok(())
    }
    /// Indicates whether the host destination is globally routable
    pub fn is_host_global(&self) -> bool {
        match self.0.host() {
            Some(Host::IPv4Address(x)) => {
                !(x.is_broadcast()
                    || x.is_documentation()
                    || x.is_link_local()
                    || x.is_loopback()
                    || x.is_private()
                    || x.is_unspecified())
            }
            Some(Host::IPv6Address(x)) => {
                // NOTE: 14 is the global flag
                // NOTE(review): only multicast addresses with global scope (0xE)
                // are treated as global here; non-multicast IPv6 always returns
                // false — confirm this is intended
                x.is_multicast() && (x.segments()[0] & 0x000f == 14)
            }
            Some(Host::RegisteredName(name)) => !name.trim().is_empty(),
            None => false,
        }
    }
    /// Returns true if destination represents a distant server
    pub fn is_distant(&self) -> bool {
        self.scheme_eq("distant")
    }
    /// Returns true if destination represents an ssh server
    pub fn is_ssh(&self) -> bool {
        self.scheme_eq("ssh")
    }
    // Case-insensitive comparison of the scheme with `s`; false when no scheme
    fn scheme_eq(&self, s: &str) -> bool {
        match self.0.scheme() {
            Some(scheme) => scheme.as_str().eq_ignore_ascii_case(s),
            None => false,
        }
    }
    /// Returns reference to inner [`URIReference`]
    pub fn as_uri_ref(&self) -> &URIReference<'static> {
        &self.0
    }
}
// Lets APIs taking `impl AsRef<Destination>` accept plain references
impl AsRef<Destination> for &Destination {
    fn as_ref(&self) -> &Destination {
        *self
    }
}
// Cheap view of the destination as its underlying URI reference
impl AsRef<URIReference<'static>> for Destination {
    fn as_ref(&self) -> &URIReference<'static> {
        self.as_uri_ref()
    }
}
impl fmt::Display for Destination {
    /// Formats the destination as the text of its underlying URI reference
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        fmt::Display::fmt(&self.0, f)
    }
}
impl FromStr for Destination {
    type Err = DestinationError;
    /// Parses a destination from a string, accepting both full URIs and bare
    /// host names (e.g. `"localhost"` becomes host `localhost` with path `/`)
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        // Disallow empty (whitespace-only) input as that passes our
        // parsing for a URI reference (relative with no scheme or anything)
        if s.trim().is_empty() {
            return Err(DestinationError::MissingHost);
        }
        let mut destination = URIReference::try_from(s)
            .map(URIReference::into_owned)
            .map(Destination)
            .map_err(DestinationError::URIReferenceError)?;
        // Only support relative reference if it is a path reference as
        // we convert that to a relative reference with a host
        if destination.0.is_relative_reference() {
            let path = destination.0.path().to_string();
            destination.replace_host(path.as_str())?;
            let _ = destination.0.set_path("/")?;
        }
        Ok(destination)
    }
}
impl<'a> TryFrom<URIReference<'a>> for Destination {
    type Error = DestinationError;
    // Enforces the invariant that every Destination has a host
    fn try_from(uri_ref: URIReference<'a>) -> Result<Self, Self::Error> {
        if uri_ref.host().is_none() {
            return Err(DestinationError::MissingHost);
        }
        Ok(Self(uri_ref.into_owned()))
    }
}
impl<'a> TryFrom<URI<'a>> for Destination {
type Error = DestinationError;
fn try_from(uri: URI<'a>) -> Result<Self, Self::Error> {
let uri_ref: URIReference<'a> = uri.into();
Self::try_from(uri_ref)
}
}
impl FromStr for Box<Destination> {
    type Err = DestinationError;
    /// Parses a destination and boxes the result
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        s.parse::<Destination>().map(Box::new)
    }
}
// Serialized as its string (Display) form rather than as a struct
impl Serialize for Destination {
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: Serializer,
    {
        serialize_to_str(self, serializer)
    }
}
// Deserialized from its string form via the FromStr implementation
impl<'de> Deserialize<'de> for Destination {
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: Deserializer<'de>,
    {
        deserialize_from_str(deserializer)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    // Whitespace-only input must not parse as a (relative) URI reference
    #[test]
    fn parse_should_fail_if_string_is_only_whitespace() {
        let err = "".parse::<Destination>().unwrap_err();
        assert_eq!(err, DestinationError::MissingHost);
        let err = " ".parse::<Destination>().unwrap_err();
        assert_eq!(err, DestinationError::MissingHost);
        let err = "\t".parse::<Destination>().unwrap_err();
        assert_eq!(err, DestinationError::MissingHost);
        let err = "\n".parse::<Destination>().unwrap_err();
        assert_eq!(err, DestinationError::MissingHost);
        let err = "\r".parse::<Destination>().unwrap_err();
        assert_eq!(err, DestinationError::MissingHost);
        let err = "\r\n".parse::<Destination>().unwrap_err();
        assert_eq!(err, DestinationError::MissingHost);
    }
    #[test]
    fn parse_should_succeed_with_valid_uri() {
        let destination = "distant://localhost".parse::<Destination>().unwrap();
        assert_eq!(destination.scheme().unwrap(), "distant");
        assert_eq!(destination.to_host_string(), "localhost");
        assert_eq!(destination.as_uri_ref().path().to_string(), "/");
    }
    #[test]
    fn parse_should_fail_if_relative_reference_that_is_not_valid_host() {
        let _ = "/".parse::<Destination>().unwrap_err();
        let _ = "/localhost".parse::<Destination>().unwrap_err();
        let _ = "my/path".parse::<Destination>().unwrap_err();
        let _ = "/my/path".parse::<Destination>().unwrap_err();
        let _ = "//localhost".parse::<Destination>().unwrap_err();
    }
    // Bare host names are promoted: the path becomes the host, path reset to "/"
    #[test]
    fn parse_should_succeed_with_nonempty_relative_reference_by_setting_host_to_path() {
        let destination = "localhost".parse::<Destination>().unwrap();
        assert_eq!(destination.to_host_string(), "localhost");
        assert_eq!(destination.as_uri_ref().path().to_string(), "/");
    }
}

@ -0,0 +1,2 @@
/// Represents extra data included for connections
/// (a key=value map parsed from user input; see [`crate::data::Map`])
pub type Extra = crate::data::Map;

@ -0,0 +1,5 @@
// Both ids share the u64 representation but identify different resources
/// Id associated with an active connection
pub type ConnectionId = u64;
/// Id associated with an open channel
pub type ChannelId = u64;

@ -0,0 +1,15 @@
use super::{ConnectionId, Destination, Extra};
use serde::{Deserialize, Serialize};
/// Information about a specific connection
/// (returned by the manager in response to an info request)
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
pub struct ConnectionInfo {
    /// Connection's id
    pub id: ConnectionId,
    /// Destination with which this connection is associated
    pub destination: Destination,
    /// Extra information associated with this connection
    pub extra: Extra,
}

@ -0,0 +1,58 @@
use super::{ConnectionId, Destination};
use derive_more::IntoIterator;
use serde::{Deserialize, Serialize};
use std::{
collections::HashMap,
ops::{Deref, DerefMut, Index, IndexMut},
};
/// Represents a list of information about active connections
/// (a map of connection id -> destination)
#[derive(Clone, Debug, PartialEq, Eq, IntoIterator, Serialize, Deserialize)]
pub struct ConnectionList(pub(crate) HashMap<ConnectionId, Destination>);
impl ConnectionList {
    /// Creates an empty list of connections
    pub fn new() -> Self {
        Self(HashMap::new())
    }
    /// Returns a reference to the destination associated with an active connection
    pub fn connection_destination(&self, id: ConnectionId) -> Option<&Destination> {
        self.0.get(&id)
    }
}
impl Default for ConnectionList {
fn default() -> Self {
Self::new()
}
}
// Exposes the full read-only HashMap API directly on the list
impl Deref for ConnectionList {
    type Target = HashMap<ConnectionId, Destination>;
    fn deref(&self) -> &Self::Target {
        &self.0
    }
}
// Exposes the mutable HashMap API directly on the list
impl DerefMut for ConnectionList {
    fn deref_mut(&mut self) -> &mut Self::Target {
        &mut self.0
    }
}
impl Index<u64> for ConnectionList {
    type Output = Destination;
    // Panics if no connection with the given id exists (HashMap indexing)
    fn index(&self, connection_id: u64) -> &Self::Output {
        &self.0[&connection_id]
    }
}
impl IndexMut<u64> for ConnectionList {
    // Panics with "No connection with id" if the id is absent
    fn index_mut(&mut self, connection_id: u64) -> &mut Self::Output {
        self.0
            .get_mut(&connection_id)
            .expect("No connection with id")
    }
}

@ -0,0 +1,72 @@
use super::{ChannelId, ConnectionId, Destination, Extra};
use crate::{DistantMsg, DistantRequestData};
use distant_net::Request;
use serde::{Deserialize, Serialize};
/// Requests that a client can send to a distant manager
#[derive(Clone, Debug, Serialize, Deserialize)]
#[cfg_attr(feature = "clap", derive(clap::Subcommand))]
#[serde(rename_all = "snake_case", deny_unknown_fields, tag = "type")]
pub enum ManagerRequest {
    /// Launch a server using the manager
    Launch {
        // NOTE: Boxed per clippy's large_enum_variant warning
        destination: Box<Destination>,
        /// Extra details specific to the connection
        #[cfg_attr(feature = "clap", clap(short, long, action = clap::ArgAction::Append))]
        extra: Extra,
    },
    /// Initiate a connection through the manager
    Connect {
        // NOTE: Boxed per clippy's large_enum_variant warning
        destination: Box<Destination>,
        /// Extra details specific to the connection
        #[cfg_attr(feature = "clap", clap(short, long, action = clap::ArgAction::Append))]
        extra: Extra,
    },
    /// Opens a channel for communication with a server
    #[cfg_attr(feature = "clap", clap(skip))]
    OpenChannel {
        /// Id of the connection
        id: ConnectionId,
    },
    /// Sends data through channel
    #[cfg_attr(feature = "clap", clap(skip))]
    Channel {
        /// Id of the channel
        id: ChannelId,
        /// Request to send to through the channel
        #[cfg_attr(feature = "clap", clap(skip = skipped_request()))]
        request: Request<DistantMsg<DistantRequestData>>,
    },
    /// Closes an open channel
    #[cfg_attr(feature = "clap", clap(skip))]
    CloseChannel {
        /// Id of the channel to close
        id: ChannelId,
    },
    /// Retrieve information about a specific connection
    Info { id: ConnectionId },
    /// Kill a specific connection
    Kill { id: ConnectionId },
    /// Retrieve list of connections being managed
    List,
    /// Signals the manager to shutdown
    Shutdown,
}
/// Produces some default request, purely to satisfy clap
/// (used as the `skip` value for the Channel variant's request field)
#[cfg(feature = "clap")]
fn skipped_request() -> Request<DistantMsg<DistantRequestData>> {
    Request::new(DistantMsg::Single(DistantRequestData::SystemInfo {}))
}

@ -0,0 +1,53 @@
use crate::{data::Error, ConnectionInfo, ConnectionList, Destination};
use crate::{ChannelId, ConnectionId, DistantMsg, DistantResponseData};
use distant_net::Response;
use serde::{Deserialize, Serialize};
/// Responses sent by a distant manager to its clients
#[derive(Clone, Debug, Serialize, Deserialize)]
#[serde(rename_all = "snake_case", deny_unknown_fields, tag = "type")]
pub enum ManagerResponse {
    /// Acknowledgement that a connection was killed
    Killed,
    /// Broadcast that the manager is shutting down (not guaranteed to be sent)
    Shutdown,
    /// Indicates that some error occurred during a request
    Error(Error),
    /// Confirmation of a distant server being launched
    Launched {
        /// Updated location of the spawned server
        destination: Destination,
    },
    /// Confirmation of a connection being established
    Connected { id: ConnectionId },
    /// Information about a specific connection
    Info(ConnectionInfo),
    /// List of connections in the form of id -> destination
    List(ConnectionList),
    /// Forward a response back to a specific channel that made a request
    Channel {
        /// Id of the channel
        id: ChannelId,
        /// Response to an earlier channel request
        response: Response<DistantMsg<DistantResponseData>>,
    },
    /// Indicates that a channel has been opened
    ChannelOpened {
        /// Id of the channel
        id: ChannelId,
    },
    /// Indicates that a channel has been closed
    ChannelClosed {
        /// Id of the channel
        id: ChannelId,
    },
}

@ -0,0 +1,698 @@
use crate::{
ChannelId, ConnectionId, ConnectionInfo, ConnectionList, Destination, Extra, ManagerRequest,
ManagerResponse,
};
use async_trait::async_trait;
use distant_net::{
router, Auth, AuthClient, Client, IntoSplit, Listener, MpscListener, Request, Response, Server,
ServerCtx, ServerExt, UntypedTransportRead, UntypedTransportWrite,
};
use log::*;
use std::{collections::HashMap, io, sync::Arc};
use tokio::{
sync::{mpsc, Mutex, RwLock},
task::JoinHandle,
};
mod config;
pub use config::*;
mod connection;
pub use connection::*;
mod ext;
pub use ext::*;
mod handler;
pub use handler::*;
mod r#ref;
pub use r#ref::*;
router!(DistantManagerRouter {
auth_transport: Response<Auth> => Request<Auth>,
manager_transport: Request<ManagerRequest> => Response<ManagerResponse>,
});
/// Represents a manager of multiple distant server connections
pub struct DistantManager {
    /// Receives authentication clients to feed into local data of server
    auth_client_rx: Mutex<mpsc::Receiver<AuthClient>>,
    /// Configuration settings for the server
    config: DistantManagerConfig,
    /// Mapping of connection id -> connection
    connections: RwLock<HashMap<ConnectionId, DistantManagerConnection>>,
    /// Handlers for launch requests
    launch_handlers: Arc<RwLock<HashMap<String, BoxedLaunchHandler>>>,
    /// Handlers for connect requests
    connect_handlers: Arc<RwLock<HashMap<String, BoxedConnectHandler>>>,
    /// Primary task of server (spawned in [`DistantManager::start`] to route
    /// accepted transports into auth clients and manager connections)
    task: JoinHandle<()>,
}
impl DistantManager {
    /// Initializes a new instance of [`DistantManager`], accepting transports
    /// produced by the provided listener, and returns a reference handle to it
    pub fn start<L, T>(
        mut config: DistantManagerConfig,
        mut listener: L,
    ) -> io::Result<DistantManagerRef>
    where
        L: Listener<Output = T> + 'static,
        T: IntoSplit + Send + 'static,
        T::Read: UntypedTransportRead + 'static,
        T::Write: UntypedTransportWrite + 'static,
    {
        let (conn_tx, mpsc_listener) = MpscListener::channel(config.connection_buffer_size);
        let (auth_client_tx, auth_client_rx) = mpsc::channel(1);
        // Spawn task that uses our input listener to get both auth and manager events,
        // forwarding manager events to the internal mpsc listener
        let task = tokio::spawn(async move {
            while let Ok(transport) = listener.accept().await {
                // Split each accepted transport into an auth half and a manager half
                let DistantManagerRouter {
                    auth_transport,
                    manager_transport,
                    ..
                } = DistantManagerRouter::new(transport);
                let (writer, reader) = auth_transport.into_split();
                let client = match Client::new(writer, reader) {
                    Ok(client) => client,
                    Err(x) => {
                        error!("Creating auth client failed: {}", x);
                        continue;
                    }
                };
                let auth_client = AuthClient::from(client);
                // Forward auth client for new connection in server
                if auth_client_tx.send(auth_client).await.is_err() {
                    break;
                }
                // Forward connected and routed transport to server
                if conn_tx.send(manager_transport.into_split()).await.is_err() {
                    break;
                }
            }
        });
        // Handlers are shared behind Arc and handed to the returned ref as Weak,
        // so the ref does not keep them alive past the manager itself
        let launch_handlers = Arc::new(RwLock::new(config.launch_handlers.drain().collect()));
        let weak_launch_handlers = Arc::downgrade(&launch_handlers);
        let connect_handlers = Arc::new(RwLock::new(config.connect_handlers.drain().collect()));
        let weak_connect_handlers = Arc::downgrade(&connect_handlers);
        // NOTE(review): the inner `.start` here is presumably `ServerExt::start`
        // running the manager against the routed mpsc listener — confirm
        let server_ref = Self {
            auth_client_rx: Mutex::new(auth_client_rx),
            config,
            launch_handlers,
            connect_handlers,
            connections: RwLock::new(HashMap::new()),
            task,
        }
        .start(mpsc_listener)?;
        Ok(DistantManagerRef {
            launch_handlers: weak_launch_handlers,
            connect_handlers: weak_connect_handlers,
            inner: server_ref,
        })
    }
/// Launches a new server at the specified `destination` using the given `extra` information
/// and authentication client (if needed) to retrieve additional information needed to
/// enter the destination prior to starting the server, returning the destination of the
/// launched server
async fn launch(
&self,
destination: Destination,
extra: Extra,
auth: Option<&mut AuthClient>,
) -> io::Result<Destination> {
let auth = auth.ok_or_else(|| {
io::Error::new(
io::ErrorKind::Other,
"Authentication client not initialized",
)
})?;
let scheme = match destination.scheme() {
Some(scheme) => {
trace!("Using scheme {}", scheme);
scheme
}
None => {
trace!(
"Using fallback scheme of {}",
self.config.fallback_scheme.as_str()
);
self.config.fallback_scheme.as_str()
}
}
.to_lowercase();
let credentials = {
let lock = self.launch_handlers.read().await;
let handler = lock.get(&scheme).ok_or_else(|| {
io::Error::new(
io::ErrorKind::InvalidInput,
format!("No launch handler registered for {}", scheme),
)
})?;
handler.launch(&destination, &extra, auth).await?
};
Ok(credentials)
}
/// Connects to a new server at the specified `destination` using the given `extra` information
/// and authentication client (if needed) to retrieve additional information needed to
/// establish the connection to the server
async fn connect(
&self,
destination: Destination,
extra: Extra,
auth: Option<&mut AuthClient>,
) -> io::Result<ConnectionId> {
let auth = auth.ok_or_else(|| {
io::Error::new(
io::ErrorKind::Other,
"Authentication client not initialized",
)
})?;
let scheme = match destination.scheme() {
Some(scheme) => {
trace!("Using scheme {}", scheme);
scheme
}
None => {
trace!(
"Using fallback scheme of {}",
self.config.fallback_scheme.as_str()
);
self.config.fallback_scheme.as_str()
}
}
.to_lowercase();
let (writer, reader) = {
let lock = self.connect_handlers.read().await;
let handler = lock.get(&scheme).ok_or_else(|| {
io::Error::new(
io::ErrorKind::InvalidInput,
format!("No connect handler registered for {}", scheme),
)
})?;
handler.connect(&destination, &extra, auth).await?
};
let connection = DistantManagerConnection::new(destination, extra, writer, reader);
let id = connection.id;
self.connections.write().await.insert(id, connection);
Ok(id)
}
/// Retrieves information about the connection to the server with the specified `id`
async fn info(&self, id: ConnectionId) -> io::Result<ConnectionInfo> {
match self.connections.read().await.get(&id) {
Some(connection) => Ok(ConnectionInfo {
id: connection.id,
destination: connection.destination.clone(),
extra: connection.extra.clone(),
}),
None => Err(io::Error::new(
io::ErrorKind::NotConnected,
"No connection found",
)),
}
}
/// Retrieves a list of connections to servers
async fn list(&self) -> io::Result<ConnectionList> {
Ok(ConnectionList(
self.connections
.read()
.await
.values()
.map(|conn| (conn.id, conn.destination.clone()))
.collect(),
))
}
/// Kills the connection to the server with the specified `id`
async fn kill(&self, id: ConnectionId) -> io::Result<()> {
match self.connections.write().await.remove(&id) {
Some(_) => Ok(()),
None => Err(io::Error::new(
io::ErrorKind::NotConnected,
"No connection found",
)),
}
}
}
/// Per-connection local data maintained by the manager server for each connected client
#[derive(Default)]
pub struct DistantManagerServerConnection {
    /// Authentication client that manager can use when establishing a new connection
    /// and needing to get authentication details from the client to move forward
    auth_client: Option<Mutex<AuthClient>>,
    /// Holds on to open channels feeding data back from a server to some connected client,
    /// enabling us to cancel the tasks on demand
    channels: RwLock<HashMap<ChannelId, DistantManagerChannel>>,
}
#[async_trait]
impl Server for DistantManager {
    type Request = ManagerRequest;
    type Response = ManagerResponse;
    type LocalData = DistantManagerServerConnection;
    /// Pairs the newly-accepted connection with the next auth client produced by
    /// the accept-loop task spawned in `DistantManager::start`
    async fn on_accept(&self, local_data: &mut Self::LocalData) {
        local_data.auth_client = self
            .auth_client_rx
            .lock()
            .await
            .recv()
            .await
            .map(Mutex::new);
        // Enable jit handshake
        if let Some(auth_client) = local_data.auth_client.as_ref() {
            auth_client.lock().await.set_jit_handshake(true);
        }
    }
    /// Dispatches a single manager request and replies with the matching response,
    /// translating any failure into a `ManagerResponse::Error`
    async fn on_request(&self, ctx: ServerCtx<Self::Request, Self::Response, Self::LocalData>) {
        let ServerCtx {
            connection_id,
            request,
            reply,
            local_data,
        } = ctx;
        let response = match request.payload {
            // Launch a brand-new server, reporting back the destination to reach it
            ManagerRequest::Launch { destination, extra } => {
                let mut auth = match local_data.auth_client.as_ref() {
                    Some(client) => Some(client.lock().await),
                    None => None,
                };
                match self.launch(*destination, extra, auth.as_deref_mut()).await {
                    Ok(destination) => ManagerResponse::Launched { destination },
                    Err(x) => ManagerResponse::Error(x.into()),
                }
            }
            // Establish a new connection to an existing server, reporting back its id
            ManagerRequest::Connect { destination, extra } => {
                let mut auth = match local_data.auth_client.as_ref() {
                    Some(client) => Some(client.lock().await),
                    None => None,
                };
                match self.connect(*destination, extra, auth.as_deref_mut()).await {
                    Ok(id) => ManagerResponse::Connected { id },
                    Err(x) => ManagerResponse::Error(x.into()),
                }
            }
            // Open a channel routing data between this client and an established connection
            ManagerRequest::OpenChannel { id } => match self.connections.read().await.get(&id) {
                Some(connection) => match connection.open_channel(reply.clone()).await {
                    Ok(channel) => {
                        let id = channel.id();
                        local_data.channels.write().await.insert(id, channel);
                        ManagerResponse::ChannelOpened { id }
                    }
                    Err(x) => ManagerResponse::Error(x.into()),
                },
                None => ManagerResponse::Error(
                    io::Error::new(io::ErrorKind::NotConnected, "Connection does not exist").into(),
                ),
            },
            // Forward a request through an already-open channel
            ManagerRequest::Channel { id, request } => {
                match local_data.channels.read().await.get(&id) {
                    // TODO: For now, we are NOT sending back a response to acknowledge
                    // a successful channel send. We could do this in order for
                    // the client to listen for a complete send, but is it worth it?
                    Some(channel) => match channel.send(request).await {
                        Ok(_) => return,
                        Err(x) => ManagerResponse::Error(x.into()),
                    },
                    None => ManagerResponse::Error(
                        io::Error::new(
                            io::ErrorKind::NotConnected,
                            "Channel is not open or does not exist",
                        )
                        .into(),
                    ),
                }
            }
            // Close and unregister an open channel
            ManagerRequest::CloseChannel { id } => {
                match local_data.channels.write().await.remove(&id) {
                    Some(channel) => match channel.close().await {
                        Ok(_) => ManagerResponse::ChannelClosed { id },
                        Err(x) => ManagerResponse::Error(x.into()),
                    },
                    None => ManagerResponse::Error(
                        io::Error::new(
                            io::ErrorKind::NotConnected,
                            "Channel is not open or does not exist",
                        )
                        .into(),
                    ),
                }
            }
            ManagerRequest::Info { id } => match self.info(id).await {
                Ok(info) => ManagerResponse::Info(info),
                Err(x) => ManagerResponse::Error(x.into()),
            },
            ManagerRequest::List => match self.list().await {
                Ok(list) => ManagerResponse::List(list),
                Err(x) => ManagerResponse::Error(x.into()),
            },
            ManagerRequest::Kill { id } => match self.kill(id).await {
                Ok(()) => ManagerResponse::Killed,
                Err(x) => ManagerResponse::Error(x.into()),
            },
            // Acknowledge shutdown and then terminate the entire process
            ManagerRequest::Shutdown => {
                if let Err(x) = reply.send(ManagerResponse::Shutdown).await {
                    error!("[Conn {}] {}", connection_id, x);
                }
                // Clear out handler state in order to trigger drops
                self.launch_handlers.write().await.clear();
                self.connect_handlers.write().await.clear();
                // Shutdown the primary server task
                self.task.abort();
                // TODO: Perform a graceful shutdown instead of this?
                // Review https://tokio.rs/tokio/topics/shutdown
                std::process::exit(0);
            }
        };
        if let Err(x) = reply.send(response).await {
            error!("[Conn {}] {}", connection_id, x);
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use distant_net::{
        AuthClient, FramedTransport, HeapAuthServer, InmemoryTransport, IntoSplit, MappedListener,
        OneshotListener, PlainCodec, ServerExt, ServerRef,
    };
    /// Create a new server, bypassing the start loop
    fn setup() -> DistantManager {
        let (_, rx) = mpsc::channel(1);
        DistantManager {
            auth_client_rx: Mutex::new(rx),
            config: Default::default(),
            connections: RwLock::new(HashMap::new()),
            launch_handlers: Arc::new(RwLock::new(HashMap::new())),
            connect_handlers: Arc::new(RwLock::new(HashMap::new())),
            task: tokio::spawn(async move {}),
        }
    }
    /// Creates a connected [`AuthClient`] with a launched auth server that blindly responds
    fn auth_client_server() -> (AuthClient, Box<dyn ServerRef>) {
        let (t1, t2) = FramedTransport::pair(1);
        let client = AuthClient::from(Client::from_framed_transport(t1).unwrap());
        // Create a server that does nothing, but will support
        let server = HeapAuthServer {
            on_challenge: Box::new(|_, _| Vec::new()),
            on_verify: Box::new(|_, _| false),
            on_info: Box::new(|_| ()),
            on_error: Box::new(|_, _| ()),
        }
        .start(MappedListener::new(OneshotListener::from_value(t2), |t| {
            t.into_split()
        }))
        .unwrap();
        (client, server)
    }
    /// Creates just the boxed writer/reader halves, discarding the opposite transport
    fn dummy_distant_writer_reader() -> (BoxedDistantWriter, BoxedDistantReader) {
        setup_distant_writer_reader().0
    }
    /// Creates a writer & reader with a connected transport
    fn setup_distant_writer_reader() -> (
        (BoxedDistantWriter, BoxedDistantReader),
        FramedTransport<InmemoryTransport, PlainCodec>,
    ) {
        let (t1, t2) = FramedTransport::pair(1);
        let (writer, reader) = t1.into_split();
        ((Box::new(writer), Box::new(reader)), t2)
    }
    #[tokio::test]
    async fn launch_should_fail_if_destination_scheme_is_unsupported() {
        let server = setup();
        let destination = "scheme://host".parse::<Destination>().unwrap();
        let extra = "".parse::<Extra>().unwrap();
        let (mut auth, _auth_server) = auth_client_server();
        let err = server
            .launch(destination, extra, Some(&mut auth))
            .await
            .unwrap_err();
        assert_eq!(err.kind(), io::ErrorKind::InvalidInput, "{:?}", err);
    }
    #[tokio::test]
    async fn launch_should_fail_if_handler_tied_to_scheme_fails() {
        let server = setup();
        // Register a handler that always fails for the "scheme" scheme
        let handler: Box<dyn LaunchHandler> = Box::new(|_: &_, _: &_, _: &mut _| async {
            Err(io::Error::new(io::ErrorKind::Other, "test failure"))
        });
        server
            .launch_handlers
            .write()
            .await
            .insert("scheme".to_string(), handler);
        let destination = "scheme://host".parse::<Destination>().unwrap();
        let extra = "".parse::<Extra>().unwrap();
        let (mut auth, _auth_server) = auth_client_server();
        let err = server
            .launch(destination, extra, Some(&mut auth))
            .await
            .unwrap_err();
        assert_eq!(err.kind(), io::ErrorKind::Other);
        assert_eq!(err.to_string(), "test failure");
    }
    #[tokio::test]
    async fn launch_should_return_new_destination_on_success() {
        let server = setup();
        let handler: Box<dyn LaunchHandler> = {
            Box::new(|_: &_, _: &_, _: &mut _| async {
                Ok("scheme2://host2".parse::<Destination>().unwrap())
            })
        };
        server
            .launch_handlers
            .write()
            .await
            .insert("scheme".to_string(), handler);
        let destination = "scheme://host".parse::<Destination>().unwrap();
        let extra = "key=value".parse::<Extra>().unwrap();
        let (mut auth, _auth_server) = auth_client_server();
        let destination = server
            .launch(destination, extra, Some(&mut auth))
            .await
            .unwrap();
        assert_eq!(
            destination,
            "scheme2://host2".parse::<Destination>().unwrap()
        );
    }
    #[tokio::test]
    async fn connect_should_fail_if_destination_scheme_is_unsupported() {
        let server = setup();
        let destination = "scheme://host".parse::<Destination>().unwrap();
        let extra = "".parse::<Extra>().unwrap();
        let (mut auth, _auth_server) = auth_client_server();
        let err = server
            .connect(destination, extra, Some(&mut auth))
            .await
            .unwrap_err();
        assert_eq!(err.kind(), io::ErrorKind::InvalidInput, "{:?}", err);
    }
    #[tokio::test]
    async fn connect_should_fail_if_handler_tied_to_scheme_fails() {
        let server = setup();
        let handler: Box<dyn ConnectHandler> = Box::new(|_: &_, _: &_, _: &mut _| async {
            Err(io::Error::new(io::ErrorKind::Other, "test failure"))
        });
        server
            .connect_handlers
            .write()
            .await
            .insert("scheme".to_string(), handler);
        let destination = "scheme://host".parse::<Destination>().unwrap();
        let extra = "".parse::<Extra>().unwrap();
        let (mut auth, _auth_server) = auth_client_server();
        let err = server
            .connect(destination, extra, Some(&mut auth))
            .await
            .unwrap_err();
        assert_eq!(err.kind(), io::ErrorKind::Other);
        assert_eq!(err.to_string(), "test failure");
    }
    #[tokio::test]
    async fn connect_should_return_id_of_new_connection_on_success() {
        let server = setup();
        let handler: Box<dyn ConnectHandler> =
            Box::new(|_: &_, _: &_, _: &mut _| async { Ok(dummy_distant_writer_reader()) });
        server
            .connect_handlers
            .write()
            .await
            .insert("scheme".to_string(), handler);
        let destination = "scheme://host".parse::<Destination>().unwrap();
        let extra = "key=value".parse::<Extra>().unwrap();
        let (mut auth, _auth_server) = auth_client_server();
        let id = server
            .connect(destination, extra, Some(&mut auth))
            .await
            .unwrap();
        // The connection should be tracked with the same destination/extra it was made with
        let lock = server.connections.read().await;
        let connection = lock.get(&id).unwrap();
        assert_eq!(connection.id, id);
        assert_eq!(connection.destination, "scheme://host".parse().unwrap());
        assert_eq!(connection.extra, "key=value".parse().unwrap());
    }
    #[tokio::test]
    async fn info_should_fail_if_no_connection_found_for_specified_id() {
        let server = setup();
        let err = server.info(999).await.unwrap_err();
        assert_eq!(err.kind(), io::ErrorKind::NotConnected, "{:?}", err);
    }
    #[tokio::test]
    async fn info_should_return_information_about_established_connection() {
        let server = setup();
        let (writer, reader) = dummy_distant_writer_reader();
        let connection = DistantManagerConnection::new(
            "scheme://host".parse().unwrap(),
            "key=value".parse().unwrap(),
            writer,
            reader,
        );
        let id = connection.id;
        server.connections.write().await.insert(id, connection);
        let info = server.info(id).await.unwrap();
        assert_eq!(
            info,
            ConnectionInfo {
                id,
                destination: "scheme://host".parse().unwrap(),
                extra: "key=value".parse().unwrap(),
            }
        );
    }
    #[tokio::test]
    async fn list_should_return_empty_connection_list_if_no_established_connections() {
        let server = setup();
        let list = server.list().await.unwrap();
        assert_eq!(list, ConnectionList(HashMap::new()));
    }
    #[tokio::test]
    async fn list_should_return_a_list_of_established_connections() {
        let server = setup();
        let (writer, reader) = dummy_distant_writer_reader();
        let connection = DistantManagerConnection::new(
            "scheme://host".parse().unwrap(),
            "key=value".parse().unwrap(),
            writer,
            reader,
        );
        let id_1 = connection.id;
        server.connections.write().await.insert(id_1, connection);
        let (writer, reader) = dummy_distant_writer_reader();
        let connection = DistantManagerConnection::new(
            "other://host2".parse().unwrap(),
            "key=value".parse().unwrap(),
            writer,
            reader,
        );
        let id_2 = connection.id;
        server.connections.write().await.insert(id_2, connection);
        let list = server.list().await.unwrap();
        assert_eq!(
            list.get(&id_1).unwrap(),
            &"scheme://host".parse::<Destination>().unwrap()
        );
        assert_eq!(
            list.get(&id_2).unwrap(),
            &"other://host2".parse::<Destination>().unwrap()
        );
    }
    #[tokio::test]
    async fn kill_should_fail_if_no_connection_found_for_specified_id() {
        let server = setup();
        let err = server.kill(999).await.unwrap_err();
        assert_eq!(err.kind(), io::ErrorKind::NotConnected, "{:?}", err);
    }
    #[tokio::test]
    async fn kill_should_terminate_established_connection_and_remove_it_from_the_list() {
        let server = setup();
        let (writer, reader) = dummy_distant_writer_reader();
        let connection = DistantManagerConnection::new(
            "scheme://host".parse().unwrap(),
            "key=value".parse().unwrap(),
            writer,
            reader,
        );
        let id = connection.id;
        server.connections.write().await.insert(id, connection);
        server.kill(id).await.unwrap();
        let lock = server.connections.read().await;
        assert!(!lock.contains_key(&id), "Connection still exists");
    }
}

@ -0,0 +1,31 @@
use crate::{BoxedConnectHandler, BoxedLaunchHandler};
use std::collections::HashMap;
/// Configuration settings used when constructing a [`crate::DistantManager`]
pub struct DistantManagerConfig {
    /// Scheme to use when none is provided in a destination
    pub fallback_scheme: String,
    /// Buffer size for queue of incoming connections before blocking
    pub connection_buffer_size: usize,
    /// If listening as local user
    pub user: bool,
    /// Handlers to use for launch requests, keyed by destination scheme
    pub launch_handlers: HashMap<String, BoxedLaunchHandler>,
    /// Handlers to use for connect requests, keyed by destination scheme
    pub connect_handlers: HashMap<String, BoxedConnectHandler>,
}
impl Default for DistantManagerConfig {
    /// Produces a configuration with the "distant" fallback scheme, room for 100
    /// queued connections, non-user listening, and no registered handlers
    fn default() -> Self {
        let fallback_scheme = String::from("distant");
        Self {
            fallback_scheme,
            connection_buffer_size: 100,
            user: false,
            launch_handlers: HashMap::new(),
            connect_handlers: HashMap::new(),
        }
    }
}

@ -0,0 +1,201 @@
use crate::{
manager::{
data::{ChannelId, ConnectionId, Destination, Extra},
BoxedDistantReader, BoxedDistantWriter,
},
DistantMsg, DistantRequestData, DistantResponseData, ManagerResponse,
};
use distant_net::{Request, Response, ServerReply};
use log::*;
use std::{collections::HashMap, io};
use tokio::{sync::mpsc, task::JoinHandle};
/// Represents a connection a distant manager has with some distant-compatible server
pub struct DistantManagerConnection {
    // Randomly-generated id for this connection (assigned in `new`)
    pub id: ConnectionId,
    // Destination of the server this connection reaches
    pub destination: Destination,
    // Extra key-value data supplied when the connection was established
    pub extra: Extra,
    // Feeds register/unregister/read/write events into the writer task
    tx: mpsc::Sender<StateMachine>,
    // Task forwarding responses read from the server into `tx`
    reader_task: JoinHandle<()>,
    // Task processing state machine events and writing requests to the server
    writer_task: JoinHandle<()>,
}
/// Cloneable handle to a channel registered with a connection's writer task,
/// used to forward requests and to close (unregister) the channel
#[derive(Clone)]
pub struct DistantManagerChannel {
    // Id identifying this channel within its connection
    channel_id: ChannelId,
    // Sender feeding events into the connection's writer task
    tx: mpsc::Sender<StateMachine>,
}
impl DistantManagerChannel {
    /// Returns the unique id assigned to this channel
    pub fn id(&self) -> ChannelId {
        self.channel_id
    }

    /// Forwards `request` through the connection's writer task, reporting a broken
    /// pipe if the task is no longer receiving events
    pub async fn send(&self, request: Request<DistantMsg<DistantRequestData>>) -> io::Result<()> {
        let id = self.channel_id;
        self.tx
            .send(StateMachine::Write { id, request })
            .await
            .map_err(|x| {
                io::Error::new(
                    io::ErrorKind::BrokenPipe,
                    format!("channel {} send failed: {}", id, x),
                )
            })
    }

    /// Unregisters this channel from the connection's writer task, reporting a broken
    /// pipe if the task is no longer receiving events
    pub async fn close(&self) -> io::Result<()> {
        let id = self.channel_id;
        self.tx
            .send(StateMachine::Unregister { id })
            .await
            .map_err(|x| {
                io::Error::new(
                    io::ErrorKind::BrokenPipe,
                    format!("channel {} close failed: {}", id, x),
                )
            })
    }
}
// Events processed by a connection's writer task
enum StateMachine {
    // Associate `reply` with channel `id` so inbound reads can be routed to it
    Register {
        id: ChannelId,
        reply: ServerReply<ManagerResponse>,
    },
    // Remove the reply associated with channel `id`
    Unregister {
        id: ChannelId,
    },
    // A response was read from the server and needs routing back to a channel
    Read {
        response: Response<DistantMsg<DistantResponseData>>,
    },
    // A request from channel `id` needs to be written out to the server
    Write {
        id: ChannelId,
        request: Request<DistantMsg<DistantRequestData>>,
    },
}
impl DistantManagerConnection {
    /// Builds connection state around an already-connected writer/reader pair,
    /// spawning a reader task that forwards inbound responses and a writer task
    /// that processes channel register/unregister/read/write events
    pub fn new(
        destination: Destination,
        extra: Extra,
        mut writer: BoxedDistantWriter,
        mut reader: BoxedDistantReader,
    ) -> Self {
        // Randomly-generated id for this connection
        let connection_id = rand::random();
        let (tx, mut rx) = mpsc::channel(1);
        let reader_task = {
            let tx = tx.clone();
            tokio::spawn(async move {
                loop {
                    match reader.read().await {
                        Ok(Some(response)) => {
                            if tx.send(StateMachine::Read { response }).await.is_err() {
                                break;
                            }
                        }
                        // Reader closed cleanly; stop forwarding
                        Ok(None) => break,
                        // Log the error but keep reading subsequent responses
                        Err(x) => {
                            error!("[Conn {}] {}", connection_id, x);
                            continue;
                        }
                    }
                }
            })
        };
        let writer_task = tokio::spawn(async move {
            // Maps channel id -> reply handle for routing inbound responses
            let mut registered = HashMap::new();
            while let Some(state_machine) = rx.recv().await {
                match state_machine {
                    StateMachine::Register { id, reply } => {
                        registered.insert(id, reply);
                    }
                    StateMachine::Unregister { id } => {
                        registered.remove(&id);
                    }
                    StateMachine::Read { mut response } => {
                        // Split {channel id}_{request id} back into pieces and
                        // update the origin id to match the request id only
                        let channel_id = match response.origin_id.split_once('_') {
                            Some((cid_str, oid_str)) => {
                                if let Ok(cid) = cid_str.parse::<ChannelId>() {
                                    response.origin_id = oid_str.to_string();
                                    cid
                                } else {
                                    continue;
                                }
                            }
                            // Response did not originate from one of our channels; drop it
                            None => continue,
                        };
                        if let Some(reply) = registered.get(&channel_id) {
                            let response = ManagerResponse::Channel {
                                id: channel_id,
                                response,
                            };
                            if let Err(x) = reply.send(response).await {
                                error!("[Conn {}] {}", connection_id, x);
                            }
                        }
                    }
                    StateMachine::Write { id, request } => {
                        // Combine channel id with request id so we can properly forward
                        // the response containing this in the origin id
                        let request = Request {
                            id: format!("{}_{}", id, request.id),
                            payload: request.payload,
                        };
                        if let Err(x) = writer.write(request).await {
                            error!("[Conn {}] {}", connection_id, x);
                        }
                    }
                }
            }
        });
        Self {
            id: connection_id,
            destination,
            extra,
            tx,
            reader_task,
            writer_task,
        }
    }
    /// Registers a new channel with the writer task, returning a handle that can be
    /// used to forward requests through this connection
    pub async fn open_channel(
        &self,
        reply: ServerReply<ManagerResponse>,
    ) -> io::Result<DistantManagerChannel> {
        // Randomly-generated id for the new channel
        let channel_id = rand::random();
        self.tx
            .send(StateMachine::Register {
                id: channel_id,
                reply,
            })
            .await
            .map_err(|x| {
                io::Error::new(
                    io::ErrorKind::BrokenPipe,
                    format!("open_channel failed: {}", x),
                )
            })?;
        Ok(DistantManagerChannel {
            channel_id,
            tx: self.tx.clone(),
        })
    }
}
impl Drop for DistantManagerConnection {
    /// Ensures the background reader/writer tasks do not outlive the connection
    fn drop(&mut self) {
        // The tasks are independent, so abort order does not matter
        self.writer_task.abort();
        self.reader_task.abort();
    }
}

@ -0,0 +1,14 @@
mod tcp;
pub use tcp::*;
#[cfg(unix)]
mod unix;
#[cfg(unix)]
pub use unix::*;
#[cfg(windows)]
mod windows;
#[cfg(windows)]
pub use windows::*;

@ -0,0 +1,30 @@
use crate::{DistantManager, DistantManagerConfig};
use distant_net::{
Codec, FramedTransport, IntoSplit, MappedListener, PortRange, TcpListener, TcpServerRef,
};
use std::{io, net::IpAddr};
impl DistantManager {
    /// Start a new server by binding to the given IP address and one of the ports in the
    /// specified range, mapping all connections to use the given codec
    pub async fn start_tcp<P, C>(
        config: DistantManagerConfig,
        addr: IpAddr,
        port: P,
        codec: C,
    ) -> io::Result<TcpServerRef>
    where
        P: Into<PortRange> + Send,
        C: Codec + Send + Sync + 'static,
    {
        let tcp_listener = TcpListener::bind(addr, port).await?;
        let bound_port = tcp_listener.port();

        // Wrap each accepted transport with the provided codec before splitting it
        let mapped = MappedListener::new(tcp_listener, move |transport| {
            FramedTransport::new(transport, codec.clone()).into_split()
        });

        let server = Self::start(config, mapped)?;
        Ok(TcpServerRef::new(addr, bound_port, Box::new(server)))
    }
}

@ -0,0 +1,50 @@
use crate::{DistantManager, DistantManagerConfig};
use distant_net::{
Codec, FramedTransport, IntoSplit, MappedListener, UnixSocketListener, UnixSocketServerRef,
};
use std::{io, path::Path};
impl DistantManager {
    /// Start a new server using the specified path as a unix socket using default unix socket file
    /// permissions
    pub async fn start_unix_socket<P, C>(
        config: DistantManagerConfig,
        path: P,
        codec: C,
    ) -> io::Result<UnixSocketServerRef>
    where
        P: AsRef<Path> + Send,
        C: Codec + Send + Sync + 'static,
    {
        // Delegate to the permission-aware variant using the default mode
        let mode = UnixSocketListener::default_unix_socket_file_permissions();
        Self::start_unix_socket_with_permissions(config, path, codec, mode).await
    }

    /// Start a new server using the specified path as a unix socket and `mode` as the unix socket
    /// file permissions
    pub async fn start_unix_socket_with_permissions<P, C>(
        config: DistantManagerConfig,
        path: P,
        codec: C,
        mode: u32,
    ) -> io::Result<UnixSocketServerRef>
    where
        P: AsRef<Path> + Send,
        C: Codec + Send + Sync + 'static,
    {
        let socket_listener = UnixSocketListener::bind_with_permissions(path, mode).await?;
        let socket_path = socket_listener.path().to_path_buf();

        // Wrap each accepted transport with the provided codec before splitting it
        let mapped = MappedListener::new(socket_listener, move |transport| {
            FramedTransport::new(transport, codec.clone()).into_split()
        });

        let server = Self::start(config, mapped)?;
        Ok(UnixSocketServerRef::new(socket_path, Box::new(server)))
    }
}

@ -0,0 +1,48 @@
use crate::{DistantManager, DistantManagerConfig};
use distant_net::{
Codec, FramedTransport, IntoSplit, MappedListener, WindowsPipeListener, WindowsPipeServerRef,
};
use std::{
ffi::{OsStr, OsString},
io,
};
impl DistantManager {
    /// Start a new server at the specified address via `\\.\pipe\{name}` using the given codec
    pub async fn start_local_named_pipe<N, C>(
        config: DistantManagerConfig,
        name: N,
        codec: C,
    ) -> io::Result<WindowsPipeServerRef>
    where
        Self: Sized,
        N: AsRef<OsStr> + Send,
        C: Codec + Send + Sync + 'static,
    {
        // Build the full local pipe address from the short name
        let mut full_addr = OsString::from(r"\\.\pipe\");
        full_addr.push(name.as_ref());
        Self::start_named_pipe(config, full_addr, codec).await
    }

    /// Start a new server at the specified pipe address using the given codec
    pub async fn start_named_pipe<A, C>(
        config: DistantManagerConfig,
        addr: A,
        codec: C,
    ) -> io::Result<WindowsPipeServerRef>
    where
        A: AsRef<OsStr> + Send,
        C: Codec + Send + Sync + 'static,
    {
        let pipe_listener = WindowsPipeListener::bind(addr.as_ref())?;
        let pipe_addr = pipe_listener.addr().to_os_string();

        // Wrap each accepted transport with the provided codec before splitting it
        let mapped = MappedListener::new(pipe_listener, move |transport| {
            FramedTransport::new(transport, codec.clone()).into_split()
        });

        let server = Self::start(config, mapped)?;
        Ok(WindowsPipeServerRef::new(pipe_addr, Box::new(server)))
    }
}

@ -0,0 +1,69 @@
use crate::{
manager::data::{Destination, Extra},
DistantMsg, DistantRequestData, DistantResponseData,
};
use async_trait::async_trait;
use distant_net::{AuthClient, Request, Response, TypedAsyncRead, TypedAsyncWrite};
use std::{future::Future, io};
/// Writer half used to send requests to a distant-compatible server
pub type BoxedDistantWriter =
    Box<dyn TypedAsyncWrite<Request<DistantMsg<DistantRequestData>>> + Send>;
/// Reader half used to receive responses from a distant-compatible server
pub type BoxedDistantReader =
    Box<dyn TypedAsyncRead<Response<DistantMsg<DistantResponseData>>> + Send>;
/// Connected writer/reader pair produced by a successful connect
pub type BoxedDistantWriterReader = (BoxedDistantWriter, BoxedDistantReader);
/// Boxed, type-erased [`LaunchHandler`]
pub type BoxedLaunchHandler = Box<dyn LaunchHandler>;
/// Boxed, type-erased [`ConnectHandler`]
pub type BoxedConnectHandler = Box<dyn ConnectHandler>;
/// Used to launch a server at the specified destination, returning the destination
/// of the newly-launched server on success
#[async_trait]
pub trait LaunchHandler: Send + Sync {
    /// Launches a server for `destination` using `extra` data, consulting
    /// `auth_client` for any additional information needed along the way
    async fn launch(
        &self,
        destination: &Destination,
        extra: &Extra,
        auth_client: &mut AuthClient,
    ) -> io::Result<Destination>;
}
// Blanket implementation allowing async functions/closures of the matching shape
// to be used directly as a LaunchHandler
#[async_trait]
impl<F, R> LaunchHandler for F
where
    F: for<'a> Fn(&'a Destination, &'a Extra, &'a mut AuthClient) -> R + Send + Sync + 'static,
    R: Future<Output = io::Result<Destination>> + Send + 'static,
{
    async fn launch(
        &self,
        destination: &Destination,
        extra: &Extra,
        auth_client: &mut AuthClient,
    ) -> io::Result<Destination> {
        self(destination, extra, auth_client).await
    }
}
/// Used to connect to a destination, returning a connected reader and writer pair
#[async_trait]
pub trait ConnectHandler: Send + Sync {
    /// Establishes a connection to `destination` using `extra` data, consulting
    /// `auth_client` for any additional information needed along the way
    async fn connect(
        &self,
        destination: &Destination,
        extra: &Extra,
        auth_client: &mut AuthClient,
    ) -> io::Result<BoxedDistantWriterReader>;
}
// Blanket implementation allowing async functions/closures of the matching shape
// to be used directly as a ConnectHandler
#[async_trait]
impl<F, R> ConnectHandler for F
where
    F: for<'a> Fn(&'a Destination, &'a Extra, &'a mut AuthClient) -> R + Send + Sync + 'static,
    R: Future<Output = io::Result<BoxedDistantWriterReader>> + Send + 'static,
{
    async fn connect(
        &self,
        destination: &Destination,
        extra: &Extra,
        auth_client: &mut AuthClient,
    ) -> io::Result<BoxedDistantWriterReader> {
        self(destination, extra, auth_client).await
    }
}

@ -0,0 +1,73 @@
use super::{BoxedConnectHandler, BoxedLaunchHandler, ConnectHandler, LaunchHandler};
use distant_net::{ServerRef, ServerState};
use std::{collections::HashMap, io, sync::Weak};
use tokio::sync::RwLock;
/// Reference to a distant manager's server instance
pub struct DistantManagerRef {
    /// Mapping of "scheme" -> handler; weak so this ref does not keep the maps alive
    pub(crate) launch_handlers: Weak<RwLock<HashMap<String, BoxedLaunchHandler>>>,
    /// Mapping of "scheme" -> handler; weak so this ref does not keep the maps alive
    pub(crate) connect_handlers: Weak<RwLock<HashMap<String, BoxedConnectHandler>>>,
    // Underlying server reference used to query state and abort the server
    pub(crate) inner: Box<dyn ServerRef>,
}
impl DistantManagerRef {
    /// Registers a new [`LaunchHandler`] for the specified scheme (e.g. "distant" or "ssh")
    pub async fn register_launch_handler(
        &self,
        scheme: impl Into<String>,
        handler: impl LaunchHandler + 'static,
    ) -> io::Result<()> {
        // The handler map is weakly held; fail if the server has already dropped it
        match Weak::upgrade(&self.launch_handlers) {
            Some(handlers) => {
                handlers
                    .write()
                    .await
                    .insert(scheme.into(), Box::new(handler));
                Ok(())
            }
            None => Err(io::Error::new(
                io::ErrorKind::Other,
                "Handler reference is no longer available",
            )),
        }
    }

    /// Registers a new [`ConnectHandler`] for the specified scheme (e.g. "distant" or "ssh")
    pub async fn register_connect_handler(
        &self,
        scheme: impl Into<String>,
        handler: impl ConnectHandler + 'static,
    ) -> io::Result<()> {
        // The handler map is weakly held; fail if the server has already dropped it
        match Weak::upgrade(&self.connect_handlers) {
            Some(handlers) => {
                handlers
                    .write()
                    .await
                    .insert(scheme.into(), Box::new(handler));
                Ok(())
            }
            None => Err(io::Error::new(
                io::ErrorKind::Other,
                "Handler reference is no longer available",
            )),
        }
    }
}
// Pure delegation to the underlying boxed server reference
impl ServerRef for DistantManagerRef {
    fn state(&self) -> &ServerState {
        self.inner.state()
    }
    fn is_finished(&self) -> bool {
        self.inner.is_finished()
    }
    fn abort(&self) {
        self.inner.abort();
    }
}

@ -1,162 +0,0 @@
use super::{Codec, DataStream, Transport};
use futures::stream::Stream;
use log::*;
use std::{future::Future, pin::Pin};
use tokio::{
io,
net::{TcpListener, TcpStream},
sync::mpsc,
task::JoinHandle,
};
/// Represents a [`Stream`] consisting of newly-connected [`DataStream`] instances that
/// have been wrapped in [`Transport`]
pub struct TransportListener<T, U>
where
    T: DataStream,
    U: Codec,
{
    // Task pulling raw streams from the inner listener
    listen_task: JoinHandle<()>,
    // Task wrapping each raw stream into a transport
    accept_task: JoinHandle<()>,
    // Receives the fully-wrapped transports produced by `accept_task`
    rx: mpsc::Receiver<Transport<T, U>>,
}
impl<T, U> TransportListener<T, U>
where
    T: DataStream + Send + 'static,
    U: Codec + Send + 'static,
{
    /// Spawns one task that pulls raw streams from `listener` and a second that wraps
    /// each stream via `make_transport`, exposing the results through [`Self::accept`]
    pub fn initialize<L, F>(listener: L, mut make_transport: F) -> Self
    where
        L: Listener<Output = T> + 'static,
        F: FnMut(T) -> Transport<T, U> + Send + 'static,
    {
        let (stream_tx, mut stream_rx) = mpsc::channel::<T>(1);
        let listen_task = tokio::spawn(async move {
            loop {
                match listener.accept().await {
                    Ok(stream) => {
                        if stream_tx.send(stream).await.is_err() {
                            error!("Listener failed to pass along stream");
                            break;
                        }
                    }
                    Err(x) => {
                        error!("Listener failed to accept stream: {}", x);
                        break;
                    }
                }
            }
        });
        let (tx, rx) = mpsc::channel::<Transport<T, U>>(1);
        let accept_task = tokio::spawn(async move {
            // Check if we have a new connection. If so, wrap it in a transport and forward
            // it along to
            while let Some(stream) = stream_rx.recv().await {
                let transport = make_transport(stream);
                if let Err(x) = tx.send(transport).await {
                    error!("Failed to forward transport: {}", x);
                }
            }
        });
        Self {
            listen_task,
            accept_task,
            rx,
        }
    }
    /// Aborts both background tasks, stopping any further accepts
    pub fn abort(&self) {
        self.listen_task.abort();
        self.accept_task.abort();
    }
    /// Waits for the next fully-initialized transport for an incoming stream to be available,
    /// returning none if no longer accepting new connections
    pub async fn accept(&mut self) -> Option<Transport<T, U>> {
        self.rx.recv().await
    }
    /// Converts into a stream of transport-wrapped connections
    pub fn into_stream(self) -> impl Stream<Item = Transport<T, U>> {
        futures::stream::unfold(self, |mut _self| async move {
            _self
                .accept()
                .await
                .map(move |transport| (transport, _self))
        })
    }
}
/// Boxed future that resolves to an accepted stream of type `T`
pub type AcceptFuture<'a, T> = Pin<Box<dyn Future<Output = io::Result<T>> + Send + 'a>>;
/// Represents a type that has a listen interface for receiving raw streams
pub trait Listener: Send + Sync {
    /// Type of stream produced by each successful accept
    type Output;
    /// Waits for the next stream to become available
    fn accept<'a>(&'a self) -> AcceptFuture<'a, Self::Output>
    where
        Self: Sync + 'a;
}
impl Listener for TcpListener {
type Output = TcpStream;
fn accept<'a>(&'a self) -> AcceptFuture<'a, Self::Output>
where
Self: Sync + 'a,
{
async fn accept(_self: &TcpListener) -> io::Result<TcpStream> {
_self.accept().await.map(|(stream, _)| stream)
}
Box::pin(accept(self))
}
}
#[cfg(unix)]
impl Listener for tokio::net::UnixListener {
    type Output = tokio::net::UnixStream;
    /// Accepts the next inbound unix socket connection, discarding the peer address
    fn accept<'a>(&'a self) -> AcceptFuture<'a, Self::Output>
    where
        Self: Sync + 'a,
    {
        Box::pin(async move {
            // Fully-qualified call resolves to the inherent accept, which yields
            // (stream, addr); only the stream is kept
            let result = tokio::net::UnixListener::accept(self).await;
            result.map(|(stream, _)| stream)
        })
    }
}
// Test-only listener that yields streams fed through an in-memory channel,
// failing with BrokenPipe once the channel has been closed
#[cfg(test)]
impl<T> Listener for tokio::sync::Mutex<tokio::sync::mpsc::Receiver<T>>
where
    T: DataStream + Send + Sync + 'static,
{
    type Output = T;
    fn accept<'a>(&'a self) -> AcceptFuture<'a, Self::Output>
    where
        Self: Sync + 'a,
    {
        async fn accept<T>(
            _self: &tokio::sync::Mutex<tokio::sync::mpsc::Receiver<T>>,
        ) -> io::Result<T>
        where
            T: DataStream + Send + Sync + 'static,
        {
            _self
                .lock()
                .await
                .recv()
                .await
                .ok_or_else(|| io::Error::from(io::ErrorKind::BrokenPipe))
        }
        Box::pin(accept(self))
    }
}

@ -1,583 +0,0 @@
use super::{DataStream, PlainCodec, Transport};
use futures::ready;
use std::{
fmt,
future::Future,
pin::Pin,
task::{Context, Poll},
};
use tokio::{
io::{self, AsyncRead, AsyncWrite, ReadBuf},
sync::mpsc,
};
/// Represents a data stream comprised of two inmemory channels
#[derive(Debug)]
pub struct InmemoryStream {
    // Bytes arriving from the remote side
    incoming: InmemoryStreamReadHalf,
    // Bytes sent to the remote side
    outgoing: InmemoryStreamWriteHalf,
}
impl InmemoryStream {
    /// Builds a stream from an explicit receiver (incoming) and sender (outgoing)
    pub fn new(incoming: mpsc::Receiver<Vec<u8>>, outgoing: mpsc::Sender<Vec<u8>>) -> Self {
        Self {
            incoming: InmemoryStreamReadHalf::new(incoming),
            outgoing: InmemoryStreamWriteHalf::new(outgoing),
        }
    }
    /// Returns (incoming_tx, outgoing_rx, stream)
    pub fn make(buffer: usize) -> (mpsc::Sender<Vec<u8>>, mpsc::Receiver<Vec<u8>>, Self) {
        let (incoming_tx, incoming_rx) = mpsc::channel(buffer);
        let (outgoing_tx, outgoing_rx) = mpsc::channel(buffer);
        (
            incoming_tx,
            outgoing_rx,
            Self::new(incoming_rx, outgoing_tx),
        )
    }
    /// Returns pair of streams that are connected such that one sends to the other and
    /// vice versa
    pub fn pair(buffer: usize) -> (Self, Self) {
        let (tx, rx, stream) = Self::make(buffer);
        (stream, Self::new(rx, tx))
    }
}
impl AsyncRead for InmemoryStream {
    // Delegates reads to the incoming half
    fn poll_read(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &mut ReadBuf<'_>,
    ) -> Poll<io::Result<()>> {
        Pin::new(&mut self.incoming).poll_read(cx, buf)
    }
}
impl AsyncWrite for InmemoryStream {
    // Delegates writes, flushes, and shutdown to the outgoing half
    fn poll_write(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &[u8],
    ) -> Poll<io::Result<usize>> {
        Pin::new(&mut self.outgoing).poll_write(cx, buf)
    }
    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        Pin::new(&mut self.outgoing).poll_flush(cx)
    }
    fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        Pin::new(&mut self.outgoing).poll_shutdown(cx)
    }
}
/// Read portion of an inmemory channel
#[derive(Debug)]
pub struct InmemoryStreamReadHalf {
    // Source of incoming byte batches
    rx: mpsc::Receiver<Vec<u8>>,
    // Bytes from a previous batch that did not fit in the caller's read buffer
    overflow: Vec<u8>,
}
impl InmemoryStreamReadHalf {
    /// Wraps a receiver of byte batches with an initially-empty overflow buffer
    pub fn new(rx: mpsc::Receiver<Vec<u8>>) -> Self {
        Self {
            rx,
            overflow: Vec::new(),
        }
    }
}
impl AsyncRead for InmemoryStreamReadHalf {
    /// Fills `buf` from the channel, stashing any surplus bytes in `overflow`
    /// so no data is lost when a received batch exceeds `buf.remaining()`
    fn poll_read(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &mut ReadBuf<'_>,
    ) -> Poll<io::Result<()>> {
        // If we cannot fit any more into the buffer at the moment, we wait
        if buf.remaining() == 0 {
            return Poll::Ready(Err(io::Error::new(
                io::ErrorKind::Other,
                "Cannot poll as buf.remaining() == 0",
            )));
        }
        // If we have overflow from the last poll, put that in the buffer
        if !self.overflow.is_empty() {
            if self.overflow.len() > buf.remaining() {
                // Only part of the overflow fits; keep the tail for the next poll
                let extra = self.overflow.split_off(buf.remaining());
                buf.put_slice(&self.overflow);
                self.overflow = extra;
            } else {
                buf.put_slice(&self.overflow);
                self.overflow.clear();
            }
            return Poll::Ready(Ok(()));
        }
        // Otherwise, we poll for the next batch to read in
        match ready!(self.rx.poll_recv(cx)) {
            Some(mut x) => {
                if x.len() > buf.remaining() {
                    // Stash whatever does not fit; it is served on subsequent polls
                    self.overflow = x.split_off(buf.remaining());
                }
                buf.put_slice(&x);
                Poll::Ready(Ok(()))
            }
            // Channel closed: leaving buf untouched signals EOF to the caller
            None => Poll::Ready(Ok(())),
        }
    }
}
/// Write portion of an inmemory channel
pub struct InmemoryStreamWriteHalf {
    // Sender for outgoing byte batches; None once shut down
    tx: Option<mpsc::Sender<Vec<u8>>>,
    // In-flight send operation, stored so poll_write can resume it across polls
    task: Option<Pin<Box<dyn Future<Output = io::Result<usize>> + Send + Sync + 'static>>>,
}
impl InmemoryStreamWriteHalf {
    /// Wraps a sender of byte batches with no in-flight send task
    pub fn new(tx: mpsc::Sender<Vec<u8>>) -> Self {
        Self {
            tx: Some(tx),
            task: None,
        }
    }
}
impl fmt::Debug for InmemoryStreamWriteHalf {
    /// Formats the write half, eliding the boxed task (futures are not `Debug`)
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("InmemoryStreamWriteHalf")
            .field("tx", &self.tx)
            .field(
                "task",
                // Fix: previously this inspected `self.tx`, so the "task" field
                // reported whether the sender (not the task) was present
                &if self.task.is_some() {
                    "Some(...)"
                } else {
                    "None"
                },
            )
            .finish()
    }
}
impl AsyncWrite for InmemoryStreamWriteHalf {
    /// Queues `buf` onto the channel. A pending send is stored as a boxed task
    /// so the write can resume on the next poll; Ok(0) is returned once shut
    /// down or when the receiving side has been dropped.
    fn poll_write(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &[u8],
    ) -> Poll<io::Result<usize>> {
        loop {
            match self.task.as_mut() {
                Some(task) => {
                    // Drive the in-flight send to completion before reporting
                    let res = ready!(task.as_mut().poll(cx));
                    self.task.take();
                    return Poll::Ready(res);
                }
                None => match self.tx.as_mut() {
                    Some(tx) => {
                        let n = buf.len();
                        let tx_2 = tx.clone();
                        let data = buf.to_vec();
                        // A failed send (receiver dropped) maps to 0 bytes written
                        let task =
                            Box::pin(async move { tx_2.send(data).await.map(|_| n).or(Ok(0)) });
                        self.task.replace(task);
                    }
                    // Already shut down, so report zero bytes written
                    None => return Poll::Ready(Ok(0)),
                },
            }
        }
    }
    // Writes complete within poll_write, so there is nothing buffered to flush
    fn poll_flush(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<io::Result<()>> {
        Poll::Ready(Ok(()))
    }
    // Dropping the sender closes the channel; later writes then return Ok(0)
    fn poll_shutdown(mut self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<io::Result<()>> {
        self.tx.take();
        self.task.take();
        Poll::Ready(Ok(()))
    }
}
impl DataStream for InmemoryStream {
    type Read = InmemoryStreamReadHalf;
    type Write = InmemoryStreamWriteHalf;
    // Inmemory streams have no peer address, so the tag is a fixed string
    fn to_connection_tag(&self) -> String {
        String::from("inmemory-stream")
    }
    fn into_split(self) -> (Self::Read, Self::Write) {
        (self.incoming, self.outgoing)
    }
}
impl Transport<InmemoryStream, PlainCodec> {
    /// Produces a pair of inmemory transports that are connected to each other using
    /// a standard codec
    ///
    /// Sets the buffer for message passing for each underlying stream to the given buffer size
    pub fn pair(
        buffer: usize,
    ) -> (
        Transport<InmemoryStream, PlainCodec>,
        Transport<InmemoryStream, PlainCodec>,
    ) {
        let (a, b) = InmemoryStream::pair(buffer);
        let a = Transport::new(a, PlainCodec::new());
        let b = Transport::new(b, PlainCodec::new());
        (a, b)
    }
}
#[cfg(test)]
mod tests {
use super::*;
use tokio::io::{AsyncReadExt, AsyncWriteExt};
#[test]
fn to_connection_tag_should_be_hardcoded_string() {
let (_, _, stream) = InmemoryStream::make(1);
assert_eq!(stream.to_connection_tag(), "inmemory-stream");
}
#[tokio::test]
async fn make_should_return_sender_that_sends_data_to_stream() {
let (tx, _, mut stream) = InmemoryStream::make(3);
tx.send(b"test msg 1".to_vec()).await.unwrap();
tx.send(b"test msg 2".to_vec()).await.unwrap();
tx.send(b"test msg 3".to_vec()).await.unwrap();
// Should get data matching a singular message
let mut buf = [0; 256];
let len = stream.read(&mut buf).await.unwrap();
assert_eq!(&buf[..len], b"test msg 1");
// Next call would get the second message
let len = stream.read(&mut buf).await.unwrap();
assert_eq!(&buf[..len], b"test msg 2");
// When the last of the senders is dropped, we should still get
// the rest of the data that was sent first before getting
// an indicator that there is no more data
drop(tx);
let len = stream.read(&mut buf).await.unwrap();
assert_eq!(&buf[..len], b"test msg 3");
let len = stream.read(&mut buf).await.unwrap();
assert_eq!(len, 0, "Unexpectedly got more data");
}
#[tokio::test]
async fn make_should_return_receiver_that_receives_data_from_stream() {
let (_, mut rx, mut stream) = InmemoryStream::make(3);
stream.write_all(b"test msg 1").await.unwrap();
stream.write_all(b"test msg 2").await.unwrap();
stream.write_all(b"test msg 3").await.unwrap();
// Should get data matching a singular message
assert_eq!(rx.recv().await, Some(b"test msg 1".to_vec()));
// Next call would get the second message
assert_eq!(rx.recv().await, Some(b"test msg 2".to_vec()));
// When the stream is dropped, we should still get
// the rest of the data that was sent first before getting
// an indicator that there is no more data
drop(stream);
assert_eq!(rx.recv().await, Some(b"test msg 3".to_vec()));
assert_eq!(rx.recv().await, None, "Unexpectedly got more data");
}
#[tokio::test]
async fn into_split_should_provide_a_read_half_that_receives_from_sender() {
let (tx, _, stream) = InmemoryStream::make(3);
let (mut read_half, _) = stream.into_split();
tx.send(b"test msg 1".to_vec()).await.unwrap();
tx.send(b"test msg 2".to_vec()).await.unwrap();
tx.send(b"test msg 3".to_vec()).await.unwrap();
// Should get data matching a singular message
let mut buf = [0; 256];
let len = read_half.read(&mut buf).await.unwrap();
assert_eq!(&buf[..len], b"test msg 1");
// Next call would get the second message
let len = read_half.read(&mut buf).await.unwrap();
assert_eq!(&buf[..len], b"test msg 2");
// When the last of the senders is dropped, we should still get
// the rest of the data that was sent first before getting
// an indicator that there is no more data
drop(tx);
let len = read_half.read(&mut buf).await.unwrap();
assert_eq!(&buf[..len], b"test msg 3");
let len = read_half.read(&mut buf).await.unwrap();
assert_eq!(len, 0, "Unexpectedly got more data");
}
#[tokio::test]
async fn into_split_should_provide_a_write_half_that_sends_to_receiver() {
let (_, mut rx, stream) = InmemoryStream::make(3);
let (_, mut write_half) = stream.into_split();
write_half.write_all(b"test msg 1").await.unwrap();
write_half.write_all(b"test msg 2").await.unwrap();
write_half.write_all(b"test msg 3").await.unwrap();
// Should get data matching a singular message
assert_eq!(rx.recv().await, Some(b"test msg 1".to_vec()));
// Next call would get the second message
assert_eq!(rx.recv().await, Some(b"test msg 2".to_vec()));
// When the stream is dropped, we should still get
// the rest of the data that was sent first before getting
// an indicator that there is no more data
drop(write_half);
assert_eq!(rx.recv().await, Some(b"test msg 3".to_vec()));
assert_eq!(rx.recv().await, None, "Unexpectedly got more data");
}
#[tokio::test]
async fn read_half_should_fail_if_buf_has_no_space_remaining() {
let (_tx, _rx, stream) = InmemoryStream::make(1);
let (mut t_read, _t_write) = stream.into_split();
let mut buf = [0u8; 0];
match t_read.read(&mut buf).await {
Err(x) if x.kind() == io::ErrorKind::Other => {}
x => panic!("Unexpected result: {:?}", x),
}
}
#[tokio::test]
async fn read_half_should_update_buf_with_all_overflow_from_last_read_if_it_all_fits() {
let (tx, _rx, stream) = InmemoryStream::make(1);
let (mut t_read, _t_write) = stream.into_split();
tx.send(vec![1, 2, 3]).await.expect("Failed to send");
let mut buf = [0u8; 2];
// First, read part of the data (first two bytes)
match t_read.read(&mut buf).await {
Ok(n) if n == 2 => assert_eq!(&buf[..n], &[1, 2]),
x => panic!("Unexpected result: {:?}", x),
}
// Second, we send more data because the last message was placed in overflow
tx.send(vec![4, 5, 6]).await.expect("Failed to send");
// Third, read remainder of the overflow from first message (third byte)
match t_read.read(&mut buf).await {
Ok(n) if n == 1 => assert_eq!(&buf[..n], &[3]),
x => panic!("Unexpected result: {:?}", x),
}
// Fourth, verify that we start to receive the next overflow
match t_read.read(&mut buf).await {
Ok(n) if n == 2 => assert_eq!(&buf[..n], &[4, 5]),
x => panic!("Unexpected result: {:?}", x),
}
// Fifth, verify that we get the last bit of overflow
match t_read.read(&mut buf).await {
Ok(n) if n == 1 => assert_eq!(&buf[..n], &[6]),
x => panic!("Unexpected result: {:?}", x),
}
}
#[tokio::test]
async fn read_half_should_update_buf_with_some_of_overflow_that_can_fit() {
let (tx, _rx, stream) = InmemoryStream::make(1);
let (mut t_read, _t_write) = stream.into_split();
tx.send(vec![1, 2, 3, 4, 5]).await.expect("Failed to send");
let mut buf = [0u8; 2];
// First, read part of the data (first two bytes)
match t_read.read(&mut buf).await {
Ok(n) if n == 2 => assert_eq!(&buf[..n], &[1, 2]),
x => panic!("Unexpected result: {:?}", x),
}
// Second, we send more data because the last message was placed in overflow
tx.send(vec![6]).await.expect("Failed to send");
// Third, read next chunk of the overflow from first message (next two byte)
match t_read.read(&mut buf).await {
Ok(n) if n == 2 => assert_eq!(&buf[..n], &[3, 4]),
x => panic!("Unexpected result: {:?}", x),
}
// Fourth, read last chunk of the overflow from first message (fifth byte)
match t_read.read(&mut buf).await {
Ok(n) if n == 1 => assert_eq!(&buf[..n], &[5]),
x => panic!("Unexpected result: {:?}", x),
}
}
#[tokio::test]
async fn read_half_should_update_buf_with_all_of_inner_channel_when_it_fits() {
let (tx, _rx, stream) = InmemoryStream::make(1);
let (mut t_read, _t_write) = stream.into_split();
let mut buf = [0u8; 5];
tx.send(vec![1, 2, 3, 4, 5]).await.expect("Failed to send");
// First, read all of data that fits exactly
match t_read.read(&mut buf).await {
Ok(n) if n == 5 => assert_eq!(&buf[..n], &[1, 2, 3, 4, 5]),
x => panic!("Unexpected result: {:?}", x),
}
tx.send(vec![6, 7, 8]).await.expect("Failed to send");
// Second, read data that fits within buf
match t_read.read(&mut buf).await {
Ok(n) if n == 3 => assert_eq!(&buf[..n], &[6, 7, 8]),
x => panic!("Unexpected result: {:?}", x),
}
}
#[tokio::test]
async fn read_half_should_update_buf_with_some_of_inner_channel_that_can_fit_and_add_rest_to_overflow(
) {
let (tx, _rx, stream) = InmemoryStream::make(1);
let (mut t_read, _t_write) = stream.into_split();
let mut buf = [0u8; 1];
tx.send(vec![1, 2, 3, 4, 5]).await.expect("Failed to send");
// Attempt a read that places more in overflow
match t_read.read(&mut buf).await {
Ok(n) if n == 1 => assert_eq!(&buf[..n], &[1]),
x => panic!("Unexpected result: {:?}", x),
}
// Verify overflow contains the rest
assert_eq!(&t_read.overflow, &[2, 3, 4, 5]);
// Queue up extra data that will not be read until overflow is finished
tx.send(vec![6, 7, 8]).await.expect("Failed to send");
// Read next data point
match t_read.read(&mut buf).await {
Ok(n) if n == 1 => assert_eq!(&buf[..n], &[2]),
x => panic!("Unexpected result: {:?}", x),
}
// Verify overflow contains the rest without having added extra data
assert_eq!(&t_read.overflow, &[3, 4, 5]);
}
#[tokio::test]
async fn read_half_should_yield_pending_if_no_data_available_on_inner_channel() {
let (_tx, _rx, stream) = InmemoryStream::make(1);
let (mut t_read, _t_write) = stream.into_split();
let mut buf = [0u8; 1];
// Attempt a read that should yield ok with no change, which is what should
// happen when nothing is read into buf
let f = t_read.read(&mut buf);
tokio::pin!(f);
match futures::poll!(f) {
Poll::Pending => {}
x => panic!("Unexpected poll result: {:?}", x),
}
}
#[tokio::test]
async fn read_half_should_not_update_buf_if_inner_channel_closed() {
let (tx, _rx, stream) = InmemoryStream::make(1);
let (mut t_read, _t_write) = stream.into_split();
let mut buf = [0u8; 1];
// Drop the channel that would be sending data to the transport
drop(tx);
// Attempt a read that should yield ok with no change, which is what should
// happen when nothing is read into buf
match t_read.read(&mut buf).await {
Ok(n) if n == 0 => assert_eq!(&buf, &[0]),
x => panic!("Unexpected result: {:?}", x),
}
}
#[tokio::test]
async fn write_half_should_return_buf_len_if_can_send_immediately() {
let (_tx, mut rx, stream) = InmemoryStream::make(1);
let (_t_read, mut t_write) = stream.into_split();
// Write that is not waiting should always succeed with full contents
let n = t_write.write(&[1, 2, 3]).await.expect("Failed to write");
assert_eq!(n, 3, "Unexpected byte count returned");
// Verify we actually had the data sent
let data = rx.try_recv().expect("Failed to recv data");
assert_eq!(data, &[1, 2, 3]);
}
#[tokio::test]
async fn write_half_should_return_support_eventually_sending_by_retrying_when_not_ready() {
let (_tx, mut rx, stream) = InmemoryStream::make(1);
let (_t_read, mut t_write) = stream.into_split();
// Queue a write already so that we block on the next one
let _ = t_write.write(&[1, 2, 3]).await.expect("Failed to write");
// Verify that the next write is pending
let f = t_write.write(&[4, 5]);
tokio::pin!(f);
match futures::poll!(&mut f) {
Poll::Pending => {}
x => panic!("Unexpected poll result: {:?}", x),
}
// Consume first batch of data so future of second can continue
let data = rx.try_recv().expect("Failed to recv data");
assert_eq!(data, &[1, 2, 3]);
// Verify that poll now returns success
match futures::poll!(f) {
Poll::Ready(Ok(n)) if n == 2 => {}
x => panic!("Unexpected poll result: {:?}", x),
}
// Consume second batch of data
let data = rx.try_recv().expect("Failed to recv data");
assert_eq!(data, &[4, 5]);
}
#[tokio::test]
async fn write_half_should_zero_if_inner_channel_closed() {
let (_tx, rx, stream) = InmemoryStream::make(1);
let (_t_read, mut t_write) = stream.into_split();
// Drop receiving end that transport would talk to
drop(rx);
// Channel is dropped, so return 0 to indicate no bytes sent
let n = t_write.write(&[1, 2, 3]).await.expect("Failed to write");
assert_eq!(n, 0, "Unexpected byte count returned");
}
}

@ -1,399 +0,0 @@
use crate::net::SecretKeyError;
use derive_more::{Display, Error, From};
use futures::{SinkExt, StreamExt};
use serde::{de::DeserializeOwned, Serialize};
use std::marker::Unpin;
use tokio::io::{self, AsyncRead, AsyncWrite};
use tokio_util::codec::{Framed, FramedRead, FramedWrite};
mod codec;
pub use codec::*;
mod inmemory;
pub use inmemory::*;
mod tcp;
pub use tcp::*;
#[cfg(unix)]
mod unix;
#[cfg(unix)]
pub use unix::*;
/// Error produced when serializing a value fails; wraps the underlying error message
#[derive(Debug, Display, Error, From)]
pub struct SerializeError(#[error(not(source))] String);
/// Error produced when deserializing a value fails; wraps the underlying error message
#[derive(Debug, Display, Error, From)]
pub struct DeserializeError(#[error(not(source))] String);
/// Serializes a value into a CBOR byte vector
fn serialize_to_vec<T: Serialize>(value: &T) -> Result<Vec<u8>, SerializeError> {
    let mut v = Vec::new();
    // Fix: into_writer yields Result<()>, so the former `let _ =` binding was
    // discarding a unit value already consumed by `?` — drop the redundant binding
    ciborium::ser::into_writer(value, &mut v).map_err(|x| SerializeError(x.to_string()))?;
    Ok(v)
}
/// Deserializes a CBOR byte slice into the expected type
fn deserialize_from_slice<T: DeserializeOwned>(slice: &[u8]) -> Result<T, DeserializeError> {
    ciborium::de::from_reader(slice).map_err(|x| DeserializeError(x.to_string()))
}
/// Errors that can occur while sending or receiving over a transport
#[derive(Debug, Display, Error, From)]
pub enum TransportError {
    CryptoError(SecretKeyError),
    IoError(io::Error),
    SerializeError(SerializeError),
    DeserializeError(DeserializeError),
}
/// Interface representing a two-way data stream
///
/// Enables splitting into separate, functioning halves that can read and write respectively
pub trait DataStream: AsyncRead + AsyncWrite + Unpin {
    /// Read half produced by [`DataStream::into_split`]
    type Read: AsyncRead + Send + Unpin + 'static;
    /// Write half produced by [`DataStream::into_split`]
    type Write: AsyncWrite + Send + Unpin + 'static;
    /// Returns a textual description of the connection associated with this stream
    fn to_connection_tag(&self) -> String;
    /// Splits this stream into read and write halves
    fn into_split(self) -> (Self::Read, Self::Write);
}
/// Represents a transport of data across the network
#[derive(Debug)]
pub struct Transport<T, U>(Framed<T, U>)
where
    T: DataStream,
    U: Codec;
impl<T, U> Transport<T, U>
where
    T: DataStream,
    U: Codec,
{
    /// Creates a new instance of the transport, wrapping the stream in a `Framed<T, U>`
    /// using the provided codec
    pub fn new(stream: T, codec: U) -> Self {
        Self(Framed::new(stream, codec))
    }
    /// Sends some data across the wire, waiting for it to completely send
    pub async fn send<D: Serialize>(&mut self, data: D) -> Result<(), TransportError> {
        // Serialize data into a byte stream
        // NOTE: Cannot used packed implementation for now due to issues with deserialization
        let data = serialize_to_vec(&data)?;
        // Use underlying codec to send data (may encrypt, sign, etc.)
        self.0.send(&data).await.map_err(TransportError::from)
    }
    /// Receives some data from out on the wire, waiting until it's available,
    /// returning none if the transport is now closed
    pub async fn receive<R: DeserializeOwned>(&mut self) -> Result<Option<R>, TransportError> {
        // Use underlying codec to receive data (may decrypt, validate, etc.)
        if let Some(data) = self.0.next().await {
            let data = data?;
            // Deserialize byte stream into our expected type
            let data = deserialize_from_slice(&data)?;
            Ok(Some(data))
        } else {
            Ok(None)
        }
    }
    /// Returns a textual description of the transport's underlying connection
    pub fn to_connection_tag(&self) -> String {
        self.0.get_ref().to_connection_tag()
    }
    /// Returns a reference to the underlying I/O stream
    ///
    /// Note that care should be taken to not tamper with the underlying stream of data coming in
    /// as it may corrupt the stream of frames otherwise being worked with
    pub fn get_ref(&self) -> &T {
        self.0.get_ref()
    }
    /// Returns a mutable reference to the underlying I/O stream
    ///
    /// Note that care should be taken to not tamper with the underlying stream of data coming in
    /// as it may corrupt the stream of frames otherwise being worked with
    pub fn get_mut(&mut self) -> &mut T {
        self.0.get_mut()
    }
    /// Consumes the transport, returning its underlying I/O stream
    ///
    /// Note that care should be taken to not tamper with the underlying stream of data coming in
    /// as it may corrupt the stream of frames otherwise being worked with.
    pub fn into_inner(self) -> T {
        self.0.into_inner()
    }
    /// Splits transport into read and write halves, carrying over any bytes
    /// already buffered by the framed codec so no data is lost
    pub fn into_split(
        self,
    ) -> (
        TransportReadHalf<T::Read, U>,
        TransportWriteHalf<T::Write, U>,
    ) {
        let parts = self.0.into_parts();
        let (read_half, write_half) = parts.io.into_split();
        // Create our split read half and populate its buffer with existing contents
        let mut f_read = FramedRead::new(read_half, parts.codec.clone());
        *f_read.read_buffer_mut() = parts.read_buf;
        // Create our split write half and populate its buffer with existing contents
        let mut f_write = FramedWrite::new(write_half, parts.codec);
        *f_write.write_buffer_mut() = parts.write_buf;
        let t_read = TransportReadHalf(f_read);
        let t_write = TransportWriteHalf(f_write);
        (t_read, t_write)
    }
}
/// Represents a transport of data out to the network
pub struct TransportWriteHalf<T, U>(FramedWrite<T, U>)
where
    T: AsyncWrite + Unpin,
    U: Codec;
impl<T, U> TransportWriteHalf<T, U>
where
    T: AsyncWrite + Unpin,
    U: Codec,
{
    /// Sends some data across the wire, waiting for it to completely send
    pub async fn send<D: Serialize>(&mut self, data: D) -> Result<(), TransportError> {
        // Serialize data into a byte stream
        // NOTE: Cannot used packed implementation for now due to issues with deserialization
        let data = serialize_to_vec(&data)?;
        // Use underlying codec to send data (may encrypt, sign, etc.)
        self.0.send(&data).await.map_err(TransportError::from)
    }
}
/// Represents a transport of data in from the network
pub struct TransportReadHalf<T, U>(FramedRead<T, U>)
where
    T: AsyncRead + Unpin,
    U: Codec;
impl<T, U> TransportReadHalf<T, U>
where
    T: AsyncRead + Unpin,
    U: Codec,
{
    /// Receives some data from out on the wire, waiting until it's available,
    /// returning none if the transport is now closed
    pub async fn receive<R: DeserializeOwned>(&mut self) -> Result<Option<R>, TransportError> {
        // Use underlying codec to receive data (may decrypt, validate, etc.)
        if let Some(data) = self.0.next().await {
            let data = data?;
            // Deserialize byte stream into our expected type
            let data = deserialize_from_slice(&data)?;
            Ok(Some(data))
        } else {
            Ok(None)
        }
    }
}
/// Test utilities
#[cfg(test)]
impl Transport<crate::net::InmemoryStream, crate::net::PlainCodec> {
    /// Makes a connected pair of inmemory transports using the crate's
    /// test buffer size
    pub fn make_pair() -> (
        Transport<crate::net::InmemoryStream, crate::net::PlainCodec>,
        Transport<crate::net::InmemoryStream, crate::net::PlainCodec>,
    ) {
        Self::pair(crate::constants::test::BUFFER_SIZE)
    }
}
#[cfg(test)]
mod tests {
use super::*;
use serde::{Deserialize, Serialize};
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
pub struct TestData {
name: String,
value: usize,
}
#[tokio::test]
async fn send_should_convert_data_into_byte_stream_and_send_through_stream() {
let (_tx, mut rx, stream) = InmemoryStream::make(1);
let mut transport = Transport::new(stream, PlainCodec::new());
let data = TestData {
name: String::from("test"),
value: 123,
};
let bytes = serialize_to_vec(&data).unwrap();
let len = (bytes.len() as u64).to_be_bytes();
let mut frame = Vec::new();
frame.extend(len.iter().copied());
frame.extend(bytes);
transport.send(data).await.unwrap();
let outgoing = rx.recv().await.unwrap();
assert_eq!(outgoing, frame);
}
#[tokio::test]
async fn receive_should_return_none_if_stream_is_closed() {
let (_, _, stream) = InmemoryStream::make(1);
let mut transport = Transport::new(stream, PlainCodec::new());
let result = transport.receive::<TestData>().await;
match result {
Ok(None) => {}
x => panic!("Unexpected result: {:?}", x),
}
}
#[tokio::test]
async fn receive_should_fail_if_unable_to_convert_to_type() {
let (tx, _rx, stream) = InmemoryStream::make(1);
let mut transport = Transport::new(stream, PlainCodec::new());
#[derive(Serialize, Deserialize)]
struct OtherTestData(usize);
let data = OtherTestData(123);
let bytes = serialize_to_vec(&data).unwrap();
let len = (bytes.len() as u64).to_be_bytes();
let mut frame = Vec::new();
frame.extend(len.iter().copied());
frame.extend(bytes);
tx.send(frame).await.unwrap();
let result = transport.receive::<TestData>().await;
match result {
Err(TransportError::DeserializeError(_)) => {}
x => panic!("Unexpected result: {:?}", x),
}
}
#[tokio::test]
async fn receive_should_return_some_instance_of_type_when_coming_into_stream() {
let (tx, _rx, stream) = InmemoryStream::make(1);
let mut transport = Transport::new(stream, PlainCodec::new());
let data = TestData {
name: String::from("test"),
value: 123,
};
let bytes = serialize_to_vec(&data).unwrap();
let len = (bytes.len() as u64).to_be_bytes();
let mut frame = Vec::new();
frame.extend(len.iter().copied());
frame.extend(bytes);
tx.send(frame).await.unwrap();
let received_data = transport.receive::<TestData>().await.unwrap().unwrap();
assert_eq!(received_data, data);
}
mod read_half {
use super::*;
#[tokio::test]
async fn receive_should_return_none_if_stream_is_closed() {
let (_, _, stream) = InmemoryStream::make(1);
let transport = Transport::new(stream, PlainCodec::new());
let (mut rh, _) = transport.into_split();
let result = rh.receive::<TestData>().await;
match result {
Ok(None) => {}
x => panic!("Unexpected result: {:?}", x),
}
}
#[tokio::test]
async fn receive_should_fail_if_unable_to_convert_to_type() {
let (tx, _rx, stream) = InmemoryStream::make(1);
let transport = Transport::new(stream, PlainCodec::new());
let (mut rh, _) = transport.into_split();
#[derive(Serialize, Deserialize)]
struct OtherTestData(usize);
let data = OtherTestData(123);
let bytes = serialize_to_vec(&data).unwrap();
let len = (bytes.len() as u64).to_be_bytes();
let mut frame = Vec::new();
frame.extend(len.iter().copied());
frame.extend(bytes);
tx.send(frame).await.unwrap();
let result = rh.receive::<TestData>().await;
match result {
Err(TransportError::DeserializeError(_)) => {}
x => panic!("Unexpected result: {:?}", x),
}
}
#[tokio::test]
async fn receive_should_return_some_instance_of_type_when_coming_into_stream() {
let (tx, _rx, stream) = InmemoryStream::make(1);
let transport = Transport::new(stream, PlainCodec::new());
let (mut rh, _) = transport.into_split();
let data = TestData {
name: String::from("test"),
value: 123,
};
let bytes = serialize_to_vec(&data).unwrap();
let len = (bytes.len() as u64).to_be_bytes();
let mut frame = Vec::new();
frame.extend(len.iter().copied());
frame.extend(bytes);
tx.send(frame).await.unwrap();
let received_data = rh.receive::<TestData>().await.unwrap().unwrap();
assert_eq!(received_data, data);
}
}
mod write_half {
use super::*;
#[tokio::test]
async fn send_should_convert_data_into_byte_stream_and_send_through_stream() {
let (_tx, mut rx, stream) = InmemoryStream::make(1);
let transport = Transport::new(stream, PlainCodec::new());
let (_, mut wh) = transport.into_split();
let data = TestData {
name: String::from("test"),
value: 123,
};
let bytes = serialize_to_vec(&data).unwrap();
let len = (bytes.len() as u64).to_be_bytes();
let mut frame = Vec::new();
frame.extend(len.iter().copied());
frame.extend(bytes);
wh.send(data).await.unwrap();
let outgoing = rx.recv().await.unwrap();
assert_eq!(outgoing, frame);
}
}
}

@ -1,38 +0,0 @@
use super::{Codec, DataStream, Transport};
use std::net::SocketAddr;
use tokio::{
io,
net::{
tcp::{OwnedReadHalf, OwnedWriteHalf},
TcpStream, ToSocketAddrs,
},
};
impl DataStream for TcpStream {
    type Read = OwnedReadHalf;
    type Write = OwnedWriteHalf;

    /// Uses the peer's socket address as the connection tag, falling back to
    /// "--" when the peer address cannot be determined
    fn to_connection_tag(&self) -> String {
        self.peer_addr()
            // Idiom fix: stringify a Display value with to_string()
            // instead of format!("{}", ..)
            .map(|addr| addr.to_string())
            .unwrap_or_else(|_| String::from("--"))
    }

    fn into_split(self) -> (Self::Read, Self::Write) {
        TcpStream::into_split(self)
    }
}

impl<U: Codec> Transport<TcpStream, U> {
    /// Establishes a connection to one of the specified addresses and uses the provided codec
    /// for transportation
    pub async fn connect(addrs: impl ToSocketAddrs, codec: U) -> io::Result<Self> {
        let stream = TcpStream::connect(addrs).await?;
        Ok(Transport::new(stream, codec))
    }

    /// Returns the address of the peer the transport is connected to
    pub fn peer_addr(&self) -> io::Result<SocketAddr> {
        self.0.get_ref().peer_addr()
    }
}

@ -1,37 +0,0 @@
use super::{Codec, DataStream, Transport};
use tokio::{
io,
net::{
unix::{OwnedReadHalf, OwnedWriteHalf, SocketAddr},
UnixStream,
},
};
impl DataStream for UnixStream {
    type Read = OwnedReadHalf;
    type Write = OwnedWriteHalf;
    /// Uses the Debug form of the peer's socket address as the connection tag,
    /// falling back to "--" when it cannot be determined (unix SocketAddr does
    /// not implement Display)
    fn to_connection_tag(&self) -> String {
        self.peer_addr()
            .map(|addr| format!("{:?}", addr))
            .unwrap_or_else(|_| String::from("--"))
    }
    fn into_split(self) -> (Self::Read, Self::Write) {
        UnixStream::into_split(self)
    }
}
impl<U: Codec> Transport<UnixStream, U> {
    /// Establishes a connection to the socket at the specified path and uses the provided codec
    /// for transportation
    pub async fn connect(path: impl AsRef<std::path::Path>, codec: U) -> io::Result<Self> {
        let stream = UnixStream::connect(path.as_ref()).await?;
        Ok(Transport::new(stream, codec))
    }
    /// Returns the address of the peer the transport is connected to
    pub fn peer_addr(&self) -> io::Result<SocketAddr> {
        self.0.get_ref().peer_addr()
    }
}

@ -0,0 +1,45 @@
use serde::{
de::{Deserializer, Error as SerdeError, Visitor},
ser::Serializer,
};
use std::{fmt, marker::PhantomData, str::FromStr};
/// From https://docs.rs/serde_with/1.14.0/src/serde_with/rust.rs.html#90-118
///
/// Deserializes a string field by parsing it via `FromStr`, surfacing parse
/// failures as custom serde errors
pub fn deserialize_from_str<'de, D, T>(deserializer: D) -> Result<T, D::Error>
where
    D: Deserializer<'de>,
    T: FromStr,
    T::Err: fmt::Display,
{
    // Visitor that parses the visited string into the target type
    struct Helper<S>(PhantomData<S>);

    impl<'de, S> Visitor<'de> for Helper<S>
    where
        S: FromStr,
        <S as FromStr>::Err: fmt::Display,
    {
        type Value = S;

        fn expecting(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
            write!(formatter, "a string")
        }

        fn visit_str<E>(self, value: &str) -> Result<Self::Value, E>
        where
            E: SerdeError,
        {
            value.parse::<Self::Value>().map_err(SerdeError::custom)
        }
    }

    deserializer.deserialize_str(Helper(PhantomData))
}

/// From https://docs.rs/serde_with/1.14.0/src/serde_with/rust.rs.html#121-127
///
/// Serializes any `Display` value as a string
pub fn serialize_to_str<T, S>(value: &T, serializer: S) -> Result<S::Ok, S::Error>
where
    T: fmt::Display,
    S: Serializer,
{
    // Fix: `value` is already `&T`; `collect_str(&value)` took a needless
    // extra borrow (&&T), which clippy flags as needless_borrow
    serializer.collect_str(value)
}

File diff suppressed because it is too large Load Diff

@ -1,362 +0,0 @@
mod handler;
mod process;
mod state;
pub(crate) use process::{InputChannel, ProcessKiller, ProcessPty};
use state::State;
use crate::{
constants::MAX_MSG_CAPACITY,
data::{Request, Response},
net::{Codec, DataStream, Transport, TransportListener, TransportReadHalf, TransportWriteHalf},
server::{
utils::{ConnTracker, ShutdownTask},
PortRange,
},
};
use futures::stream::{Stream, StreamExt};
use log::*;
use std::{net::IpAddr, sync::Arc};
use tokio::{
io::{self, AsyncRead, AsyncWrite},
net::TcpListener,
sync::{mpsc, Mutex},
task::{JoinError, JoinHandle},
time::Duration,
};
/// Represents a server that listens for requests, processes them, and sends responses
pub struct DistantServer {
    // Task driving the connection-accept loop; aborted to stop the server
    conn_task: JoinHandle<()>,
}
/// Options applied when starting a [`DistantServer`]
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct DistantServerOptions {
    // When set, the server shuts down after this duration without an active connection
    pub shutdown_after: Option<Duration>,
    // Message channel capacity; presumably caps buffered messages per connection
    // — confirm against on_new_conn's usage
    pub max_msg_capacity: usize,
}
impl Default for DistantServerOptions {
    // Defaults: no automatic shutdown, crate-level MAX_MSG_CAPACITY
    fn default() -> Self {
        Self {
            shutdown_after: None,
            max_msg_capacity: MAX_MSG_CAPACITY,
        }
    }
}
impl DistantServer {
    /// Bind to an IP address and port from the given range, taking an optional shutdown duration
    /// that will shutdown the server if there is no active connection after duration
    pub async fn bind<U>(
        addr: IpAddr,
        port: PortRange,
        codec: U,
        opts: DistantServerOptions,
    ) -> io::Result<(Self, u16)>
    where
        U: Codec + Send + 'static,
    {
        debug!("Binding to {} in range {}", addr, port);
        let listener = TcpListener::bind(port.make_socket_addrs(addr).as_slice()).await?;
        // Report the actual bound port since a range/ephemeral port may have been requested
        let port = listener.local_addr()?.port();
        debug!("Bound to port: {}", port);
        // Wrap each accepted TCP stream in a transport using a clone of the codec
        let stream = TransportListener::initialize(listener, move |stream| {
            Transport::new(stream, codec.clone())
        })
        .into_stream();
        Ok((Self::initialize(Box::pin(stream), opts), port))
    }
    /// Initialize a distant server using the provided listener
    pub fn initialize<T, U, S>(stream: S, opts: DistantServerOptions) -> Self
    where
        T: DataStream + Send + 'static,
        U: Codec + Send + 'static,
        S: Stream<Item = Transport<T, U>> + Send + Unpin + 'static,
    {
        // Build our state for the server
        let state: Arc<Mutex<State>> = Arc::new(Mutex::new(State::default()));
        let (shutdown, tracker) = ShutdownTask::maybe_initialize(opts.shutdown_after);
        // Spawn our connection task
        let conn_task = tokio::spawn(async move {
            connection_loop(stream, state, tracker, shutdown, opts.max_msg_capacity).await
        });
        Self { conn_task }
    }
    /// Waits for the server to terminate
    pub async fn wait(self) -> Result<(), JoinError> {
        self.conn_task.await
    }
    /// Aborts the server by aborting the internal task handling new connections
    pub fn abort(&self) {
        self.conn_task.abort();
    }
}
/// Accepts transports from `stream` until it ends, spawning handler tasks for
/// each connection; if a shutdown task is provided, the loop is raced against
/// it and terminates early when the shutdown timeout fires
async fn connection_loop<T, U, S>(
    mut stream: S,
    state: Arc<Mutex<State>>,
    tracker: Option<Arc<Mutex<ConnTracker>>>,
    shutdown: Option<ShutdownTask>,
    max_msg_capacity: usize,
) where
    T: DataStream + Send + 'static,
    U: Codec + Send + 'static,
    S: Stream<Item = Transport<T, U>> + Send + Unpin + 'static,
{
    let inner = async move {
        loop {
            match stream.next().await {
                Some(transport) => {
                    // Connections are identified by a random id rather than an address
                    let conn_id = rand::random();
                    debug!(
                        "<Conn @ {}> Established against {}",
                        conn_id,
                        transport.to_connection_tag()
                    );
                    if let Err(x) = on_new_conn(
                        transport,
                        conn_id,
                        Arc::clone(&state),
                        tracker.as_ref().map(Arc::clone),
                        max_msg_capacity,
                    )
                    .await
                    {
                        error!("<Conn @ {}> Failed handshake: {}", conn_id, x);
                    }
                }
                None => {
                    info!("Listener shutting down");
                    break;
                }
            };
        }
    };
    match shutdown {
        // Race the accept loop against the shutdown timeout
        Some(shutdown) => tokio::select! {
            _ = inner => {}
            _ = shutdown => {
                warn!("Reached shutdown timeout, so terminating");
            }
        },
        None => inner.await,
    }
}
/// Processes a new connection, performing a handshake, and then spawning two tasks to handle
/// input and output, returning join handles for the input and output tasks respectively
async fn on_new_conn<T, U>(
    transport: Transport<T, U>,
    conn_id: usize,
    state: Arc<Mutex<State>>,
    tracker: Option<Arc<Mutex<ConnTracker>>>,
    max_msg_capacity: usize,
) -> io::Result<JoinHandle<()>>
where
    T: DataStream,
    U: Codec + Send + 'static,
{
    // Update our tracker to reflect the new connection
    if let Some(ct) = tracker.as_ref() {
        ct.lock().await.increment();
    }
    // Split the transport into read and write halves so we can handle input
    // and output concurrently
    let (t_read, t_write) = transport.into_split();
    let (tx, rx) = mpsc::channel(max_msg_capacity);
    // Spawn a new task that loops to handle requests from the client
    let state_2 = Arc::clone(&state);
    let req_task = tokio::spawn(async move {
        request_loop(conn_id, state_2, t_read, tx).await;
    });
    // Spawn a new task that loops to handle responses to the client
    let res_task = tokio::spawn(async move { response_loop(conn_id, t_write, rx).await });
    // Spawn cleanup task that waits on our req & res tasks to complete
    let cleanup_task = tokio::spawn(async move {
        // Wait for both receiving and sending tasks to complete before marking
        // the connection as complete
        let _ = tokio::join!(req_task, res_task);
        state.lock().await.cleanup_connection(conn_id).await;
        // Decrement after cleanup so the shutdown timer only starts once
        // this connection's state is fully released
        if let Some(ct) = tracker.as_ref() {
            ct.lock().await.decrement();
        }
    });
    Ok(cleanup_task)
}
/// Repeatedly reads in new requests, processes them, and sends their responses to the
/// response loop
///
/// Exits when the transport closes (`Ok(None)`), on a receive error, or when
/// forwarding responses to the response loop fails
async fn request_loop<T, U>(
    conn_id: usize,
    state: Arc<Mutex<State>>,
    mut transport: TransportReadHalf<T, U>,
    tx: mpsc::Sender<Response>,
) where
    T: AsyncRead + Send + Unpin + 'static,
    U: Codec,
{
    loop {
        match transport.receive::<Request>().await {
            Ok(Some(req)) => {
                debug!(
                    "<Conn @ {}> Received request of type{} {}",
                    conn_id,
                    if req.payload.len() > 1 { "s" } else { "" },
                    req.to_payload_type_string()
                );
                if let Err(x) = handler::process(conn_id, Arc::clone(&state), req, tx.clone()).await
                {
                    error!("<Conn @ {}> {}", conn_id, x);
                    break;
                }
            }
            Ok(None) => {
                trace!("<Conn @ {}> Input from connection closed", conn_id);
                break;
            }
            Err(x) => {
                error!("<Conn @ {}> {}", conn_id, x);
                break;
            }
        }
    }
    // Properly close off any associated process' stdin given that we can't get new
    // requests to send more stdin to them
    state.lock().await.close_stdin_for_connection(conn_id);
}
/// Repeatedly sends responses out over the wire
///
/// Exits when the response channel closes or when a send over the transport fails
async fn response_loop<T, U>(
    conn_id: usize,
    mut transport: TransportWriteHalf<T, U>,
    mut rx: mpsc::Receiver<Response>,
) where
    T: AsyncWrite + Send + Unpin + 'static,
    U: Codec,
{
    while let Some(res) = rx.recv().await {
        debug!(
            "<Conn @ {}> Sending response of type{} {}",
            conn_id,
            if res.payload.len() > 1 { "s" } else { "" },
            res.to_payload_type_string()
        );
        if let Err(x) = transport.send(res).await {
            error!("<Conn @ {}> {}", conn_id, x);
            break;
        }
    }
    trace!("<Conn @ {}> Output to connection closed", conn_id);
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        data::{RequestData, ResponseData},
        net::{InmemoryStream, PlainCodec},
    };
    use std::pin::Pin;
    /// Produces a sender and a stream of transports, emulating a listener
    /// whose accepted connections are injected through the sender
    #[allow(clippy::type_complexity)]
    fn make_transport_stream() -> (
        mpsc::Sender<Transport<InmemoryStream, PlainCodec>>,
        Pin<Box<dyn Stream<Item = Transport<InmemoryStream, PlainCodec>> + Send>>,
    ) {
        let (tx, rx) = mpsc::channel::<Transport<InmemoryStream, PlainCodec>>(1);
        // unfold turns the receiver into a stream that ends when all senders drop
        let stream = futures::stream::unfold(rx, |mut rx| async move {
            rx.recv().await.map(move |transport| (transport, rx))
        });
        (tx, Box::pin(stream))
    }
    #[tokio::test]
    async fn wait_should_return_ok_when_all_inner_tasks_complete() {
        let (tx, stream) = make_transport_stream();
        let server = DistantServer::initialize(stream, Default::default());
        // Conclude all server tasks by closing out the listener
        drop(tx);
        let result = server.wait().await;
        assert!(result.is_ok(), "Unexpected result: {:?}", result);
    }
    #[tokio::test]
    async fn wait_should_return_error_when_server_aborted() {
        let (_tx, stream) = make_transport_stream();
        let server = DistantServer::initialize(stream, Default::default());
        server.abort();
        // Abort surfaces as a cancelled JoinError from wait
        match server.wait().await {
            Err(x) if x.is_cancelled() => {}
            x => panic!("Unexpected result: {:?}", x),
        }
    }
    #[tokio::test]
    async fn server_should_receive_requests_and_send_responses_to_appropriate_connections() {
        let (tx, stream) = make_transport_stream();
        let _server = DistantServer::initialize(stream, Default::default());
        // Send over a "connection"
        let (mut t1, t2) = Transport::make_pair();
        tx.send(t2).await.unwrap();
        // Send a request
        t1.send(Request::new(
            "test-tenant",
            vec![RequestData::SystemInfo {}],
        ))
        .await
        .unwrap();
        // Get a response
        let res = t1.receive::<Response>().await.unwrap().unwrap();
        assert!(res.payload.len() == 1, "Unexpected payload size");
        assert!(
            matches!(res.payload[0], ResponseData::SystemInfo { .. }),
            "Unexpected response: {:?}",
            res.payload[0]
        );
    }
    #[tokio::test]
    async fn server_should_shutdown_if_no_connections_after_shutdown_duration() {
        let (_tx, stream) = make_transport_stream();
        let server = DistantServer::initialize(
            stream,
            DistantServerOptions {
                shutdown_after: Some(Duration::from_millis(50)),
                max_msg_capacity: 1,
            },
        );
        // With no connections ever established, wait should resolve on its own
        let result = server.wait().await;
        assert!(result.is_ok(), "Unexpected result: {:?}", result);
    }
}

@ -1,232 +0,0 @@
use super::{InputChannel, ProcessKiller, ProcessPty};
use crate::data::{ChangeKindSet, ResponseData};
use log::*;
use notify::RecommendedWatcher;
use std::{
collections::HashMap,
future::Future,
hash::{Hash, Hasher},
io,
ops::{Deref, DerefMut},
path::{Path, PathBuf},
pin::Pin,
};
/// Callback used to deliver watcher notifications back to a connection;
/// returns a future resolving to whether the reply succeeded
pub type ReplyFn = Box<dyn FnMut(Vec<ResponseData>) -> ReplyRet + Send + 'static>;
/// Boxed future returned by a [`ReplyFn`]; `true` indicates the reply was delivered
pub type ReplyRet = Pin<Box<dyn Future<Output = bool> + Send + 'static>>;
/// Holds state related to multiple connections managed by a server
#[derive(Default)]
pub struct State {
    /// Map of all processes running on the server
    pub processes: HashMap<usize, ProcessState>,
    /// List of processes that will be killed when a connection drops
    /// (keyed by connection id, valued by process ids)
    client_processes: HashMap<usize, Vec<usize>>,
    /// Watcher used for filesystem events
    pub watcher: Option<RecommendedWatcher>,
    /// Mapping of Path -> (Reply Fn, recursive) for watcher notifications
    pub watcher_paths: HashMap<WatcherPath, ReplyFn>,
}
/// A path registered with the filesystem watcher, carrying both the raw and
/// canonicalized forms plus the recursion and change-kind filters
#[derive(Clone, Debug)]
pub struct WatcherPath {
    /// The raw path provided to the watcher, which is not canonicalized
    raw_path: PathBuf,
    /// The canonicalized path at the time of providing to the watcher,
    /// as all paths must exist for a watcher, we use this to get the
    /// source of truth when watching
    path: PathBuf,
    /// Whether or not the path was set to be recursive
    recursive: bool,
    /// Specific filter for path
    only: ChangeKindSet,
}
// Equality and hashing consider only the canonicalized path, so two watcher
// paths with different raw paths, recursion, or filters compare equal if they
// canonicalize to the same location; Eq and Hash stay consistent for map keys
impl PartialEq for WatcherPath {
    fn eq(&self, other: &Self) -> bool {
        self.path == other.path
    }
}
impl Eq for WatcherPath {}
impl Hash for WatcherPath {
    fn hash<H: Hasher>(&self, state: &mut H) {
        self.path.hash(state);
    }
}
// Deref to the canonicalized path so Path/PathBuf methods work directly
impl Deref for WatcherPath {
    type Target = PathBuf;
    fn deref(&self) -> &Self::Target {
        &self.path
    }
}
impl DerefMut for WatcherPath {
    fn deref_mut(&mut self) -> &mut Self::Target {
        &mut self.path
    }
}
impl WatcherPath {
    /// Create a new watcher path using the given path and canonicalizing it
    ///
    /// Fails if the path cannot be canonicalized (e.g. it does not exist)
    pub fn new(
        path: impl Into<PathBuf>,
        recursive: bool,
        only: impl Into<ChangeKindSet>,
    ) -> io::Result<Self> {
        let raw_path = path.into();
        let path = raw_path.canonicalize()?;
        let only = only.into();
        Ok(Self {
            raw_path,
            path,
            recursive,
            only,
        })
    }
    /// The path exactly as originally provided (not canonicalized)
    pub fn raw_path(&self) -> &Path {
        self.raw_path.as_path()
    }
    /// The canonicalized form of the watched path
    pub fn path(&self) -> &Path {
        self.path.as_path()
    }
    /// Returns true if this watcher path applies to the given path.
    /// This is accomplished by checking if the path is contained
    /// within either the raw or canonicalized path of the watcher
    /// and ensures that recursion rules are respected
    pub fn applies_to_path(&self, path: &Path) -> bool {
        let check_path = |path: &Path| -> bool {
            let cnt = path.components().count();
            // 0 means exact match from strip_prefix
            // 1 means that it was within immediate directory (fine for non-recursive)
            // 2+ means it needs to be recursive
            cnt < 2 || self.recursive
        };
        // Try stripping against both forms; a match against either counts
        match (
            path.strip_prefix(self.path()),
            path.strip_prefix(self.raw_path()),
        ) {
            (Ok(p1), Ok(p2)) => check_path(p1) || check_path(p2),
            (Ok(p), Err(_)) => check_path(p),
            (Err(_), Ok(p)) => check_path(p),
            (Err(_), Err(_)) => false,
        }
    }
}
/// Holds information related to a spawned process on the server
pub struct ProcessState {
    // Command that was executed
    pub cmd: String,
    // Arguments passed to the command
    pub args: Vec<String>,
    // When true, the process outlives the connection that spawned it
    pub persist: bool,
    // Server-side identifier for the process
    pub id: usize,
    // Channel to the process' stdin; dropping it closes stdin
    pub stdin: Option<Box<dyn InputChannel>>,
    // Handle used to request the process be killed
    pub killer: Box<dyn ProcessKiller>,
    // Handle used to interact with the process' pty, if any
    pub pty: Box<dyn ProcessPty>,
}
impl State {
pub fn map_paths_to_watcher_paths_and_replies<'a>(
&mut self,
paths: &'a [PathBuf],
) -> Vec<(Vec<&'a PathBuf>, &ChangeKindSet, &mut ReplyFn)> {
let mut results = Vec::new();
for (wp, reply) in self.watcher_paths.iter_mut() {
let mut wp_paths = Vec::new();
for path in paths {
if wp.applies_to_path(path) {
wp_paths.push(path);
}
}
if !wp_paths.is_empty() {
results.push((wp_paths, &wp.only, reply));
}
}
results
}
/// Pushes a new process associated with a connection
pub fn push_process_state(&mut self, conn_id: usize, process_state: ProcessState) {
self.client_processes
.entry(conn_id)
.or_insert_with(Vec::new)
.push(process_state.id);
self.processes.insert(process_state.id, process_state);
}
/// Removes a process associated with a connection
pub fn remove_process(&mut self, conn_id: usize, proc_id: usize) {
self.client_processes.entry(conn_id).and_modify(|v| {
if let Some(pos) = v.iter().position(|x| *x == proc_id) {
v.remove(pos);
}
});
self.processes.remove(&proc_id);
}
/// Closes stdin for all processes associated with the connection
pub fn close_stdin_for_connection(&mut self, conn_id: usize) {
debug!("<Conn @ {:?}> Closing stdin to all processes", conn_id);
if let Some(ids) = self.client_processes.get(&conn_id) {
for id in ids {
if let Some(process) = self.processes.get_mut(id) {
trace!(
"<Conn @ {:?}> Closing stdin for proc {}",
conn_id,
process.id
);
let _ = process.stdin.take();
}
}
}
}
/// Cleans up state associated with a particular connection
pub async fn cleanup_connection(&mut self, conn_id: usize) {
debug!("<Conn @ {:?}> Cleaning up state", conn_id);
if let Some(ids) = self.client_processes.remove(&conn_id) {
for id in ids {
if let Some(mut process) = self.processes.remove(&id) {
if !process.persist {
trace!(
"<Conn @ {:?}> Requesting proc {} be killed",
conn_id,
process.id
);
let pid = process.id;
if let Err(x) = process.killer.kill().await {
error!(
"Conn {} failed to send process {} kill signal: {}",
id, pid, x
);
}
} else {
trace!(
"<Conn @ {:?}> Proc {} is persistent and will not be killed",
conn_id,
process.id
);
}
}
}
}
}
}

@ -1,8 +0,0 @@
mod distant;
mod port;
mod relay;
mod utils;
pub use self::distant::{DistantServer, DistantServerOptions};
pub use port::PortRange;
pub use relay::RelayServer;

@ -1,382 +0,0 @@
use crate::{
client::{Session, SessionChannel},
data::{Request, RequestData, ResponseData},
net::{Codec, DataStream, Transport},
server::utils::{ConnTracker, ShutdownTask},
};
use futures::stream::{Stream, StreamExt};
use log::*;
use std::{collections::HashMap, marker::Unpin, sync::Arc};
use tokio::{
io,
sync::{oneshot, Mutex},
task::{JoinError, JoinHandle},
time::Duration,
};
/// Represents a server that relays requests & responses between connections and the
/// actual server
pub struct RelayServer {
    // Task that accepts new transports from the listener stream
    accept_task: JoinHandle<()>,
    // Active connections keyed by their randomly-generated ids
    conns: Arc<Mutex<HashMap<usize, Conn>>>,
}
impl RelayServer {
    /// Spawns the accept loop that wraps each incoming transport in a [`Conn`]
    /// relaying through `session`; if `shutdown_after` is provided, the loop
    /// terminates once no connection has existed for that duration
    pub fn initialize<T, U, S>(
        session: Session,
        mut stream: S,
        shutdown_after: Option<Duration>,
    ) -> io::Result<Self>
    where
        T: DataStream + Send + 'static,
        U: Codec + Send + 'static,
        S: Stream<Item = Transport<T, U>> + Send + Unpin + 'static,
    {
        let conns: Arc<Mutex<HashMap<usize, Conn>>> = Arc::new(Mutex::new(HashMap::new()));
        let (shutdown, tracker) = ShutdownTask::maybe_initialize(shutdown_after);
        let conns_2 = Arc::clone(&conns);
        let accept_task = tokio::spawn(async move {
            let inner = async move {
                loop {
                    // Each connection gets its own channel into the shared session
                    let channel = session.clone_channel();
                    match stream.next().await {
                        Some(transport) => {
                            let result = Conn::initialize(
                                transport,
                                channel,
                                tracker.as_ref().map(Arc::clone),
                            )
                            .await;
                            match result {
                                Ok(conn) => {
                                    conns_2.lock().await.insert(conn.id(), conn);
                                }
                                Err(x) => {
                                    error!("Failed to initialize connection: {}", x);
                                }
                            };
                        }
                        None => {
                            info!("Listener shutting down");
                            break;
                        }
                    };
                }
            };
            match shutdown {
                // Race the accept loop against the shutdown timeout
                Some(shutdown) => tokio::select! {
                    _ = inner => {}
                    _ = shutdown => {
                        warn!("Reached shutdown timeout, so terminating");
                    }
                },
                None => inner.await,
            }
        });
        Ok(Self { accept_task, conns })
    }
    /// Waits for the server to terminate
    pub async fn wait(self) -> Result<(), JoinError> {
        self.accept_task.await
    }
    /// Aborts the server by aborting the internal tasks and current connections
    pub async fn abort(&self) {
        self.accept_task.abort();
        self.conns
            .lock()
            .await
            .values()
            .for_each(|conn| conn.abort());
    }
}
/// A single relayed connection: its id plus the task shuttling its traffic
struct Conn {
    id: usize,
    conn_task: JoinHandle<()>,
}
impl Conn {
    /// Assigns the connection a random id, registers it with the connection
    /// tracker (if any), and spawns its handler task
    pub async fn initialize<T, U>(
        transport: Transport<T, U>,
        channel: SessionChannel,
        ct: Option<Arc<Mutex<ConnTracker>>>,
    ) -> io::Result<Self>
    where
        T: DataStream + 'static,
        U: Codec + Send + 'static,
    {
        // Create a unique id to associate with the connection since its address
        // is not guaranteed to have an identifiable string
        let id: usize = rand::random();
        // Mark that we have a new connection
        if let Some(ct) = ct.as_ref() {
            ct.lock().await.increment();
        }
        let conn_task = spawn_conn_handler(id, transport, channel, ct).await;
        Ok(Self { id, conn_task })
    }
    /// Id associated with the connection
    pub fn id(&self) -> usize {
        self.id
    }
    /// Aborts the connection from the server side
    pub fn abort(&self) {
        self.conn_task.abort();
    }
}
/// Spawns the tasks relaying a connection's traffic: one forwarding requests
/// from the transport into the session (spawning a mailbox-drain task per
/// request for its responses) and one that, once the request loop ends, kills
/// any processes spawned through this connection and decrements the tracker
async fn spawn_conn_handler<T, U>(
    conn_id: usize,
    transport: Transport<T, U>,
    mut channel: SessionChannel,
    ct: Option<Arc<Mutex<ConnTracker>>>,
) -> JoinHandle<()>
where
    T: DataStream,
    U: Codec + Send + 'static,
{
    let (mut t_reader, t_writer) = transport.into_split();
    // Ids of processes spawned via this connection, for cleanup on disconnect
    let processes = Arc::new(Mutex::new(Vec::new()));
    // Write half is shared by the per-request response-forwarding tasks
    let t_writer = Arc::new(Mutex::new(t_writer));
    let (done_tx, done_rx) = oneshot::channel();
    let mut channel_2 = channel.clone();
    let processes_2 = Arc::clone(&processes);
    let task = tokio::spawn(async move {
        loop {
            if channel_2.is_closed() {
                break;
            }
            // For each request, forward it through the session and monitor all responses
            match t_reader.receive::<Request>().await {
                Ok(Some(req)) => match channel_2.mail(req).await {
                    Ok(mut mailbox) => {
                        let processes = Arc::clone(&processes_2);
                        let t_writer = Arc::clone(&t_writer);
                        tokio::spawn(async move {
                            while let Some(res) = mailbox.next().await {
                                // Keep track of processes that are started so we can kill them
                                // when we're done
                                {
                                    let mut p_lock = processes.lock().await;
                                    for data in res.payload.iter() {
                                        if let ResponseData::ProcSpawned { id } = *data {
                                            p_lock.push(id);
                                        }
                                    }
                                }
                                if let Err(x) = t_writer.lock().await.send(res).await {
                                    error!(
                                        "<Conn @ {}> Failed to send response back: {}",
                                        conn_id, x
                                    );
                                }
                            }
                        });
                    }
                    Err(x) => error!(
                        "<Conn @ {}> Failed to pass along request received on unix socket: {:?}",
                        conn_id, x
                    ),
                },
                Ok(None) => break,
                Err(x) => {
                    error!(
                        "<Conn @ {}> Failed to receive request from unix stream: {:?}",
                        conn_id, x
                    );
                    break;
                }
            }
        }
        // Signal the cleanup task regardless of how the loop exited
        let _ = done_tx.send(());
    });
    // Perform cleanup if done by sending a request to kill each running process
    tokio::spawn(async move {
        let _ = done_rx.await;
        let p_lock = processes.lock().await;
        if !p_lock.is_empty() {
            trace!(
                "Cleaning conn {} :: killing {} process",
                conn_id,
                p_lock.len()
            );
            // Batch all kills into a single fire-and-forget request
            if let Err(x) = channel
                .fire(Request::new(
                    "relay",
                    p_lock
                        .iter()
                        .map(|id| RequestData::ProcKill { id: *id })
                        .collect(),
                ))
                .await
            {
                error!("<Conn @ {}> Failed to send kill signals: {}", conn_id, x);
            }
        }
        if let Some(ct) = ct.as_ref() {
            ct.lock().await.decrement();
        }
        debug!("<Conn @ {}> Disconnected", conn_id);
    });
    task
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        data::Response,
        net::{InmemoryStream, PlainCodec},
    };
    use std::{pin::Pin, time::Duration};
    use tokio::sync::mpsc;
    /// Builds an in-memory session, returning the far transport for the test
    /// to play the role of the remote server
    fn make_session() -> (Transport<InmemoryStream, PlainCodec>, Session) {
        let (t1, t2) = Transport::make_pair();
        (t1, Session::initialize(t2).unwrap())
    }
    /// Produces a sender and a stream of transports, emulating a listener
    /// whose accepted connections are injected through the sender
    #[allow(clippy::type_complexity)]
    fn make_transport_stream() -> (
        mpsc::Sender<Transport<InmemoryStream, PlainCodec>>,
        Pin<Box<dyn Stream<Item = Transport<InmemoryStream, PlainCodec>> + Send>>,
    ) {
        let (tx, rx) = mpsc::channel::<Transport<InmemoryStream, PlainCodec>>(1);
        let stream = futures::stream::unfold(rx, |mut rx| async move {
            rx.recv().await.map(move |transport| (transport, rx))
        });
        (tx, Box::pin(stream))
    }
    #[tokio::test]
    async fn wait_should_return_ok_when_all_inner_tasks_complete() {
        let (transport, session) = make_session();
        let (tx, stream) = make_transport_stream();
        let server = RelayServer::initialize(session, stream, None).unwrap();
        // Conclude all server tasks by closing out the listener & session
        drop(transport);
        drop(tx);
        let result = server.wait().await;
        assert!(result.is_ok(), "Unexpected result: {:?}", result);
    }
    #[tokio::test]
    async fn wait_should_return_error_when_server_aborted() {
        let (_transport, session) = make_session();
        let (_tx, stream) = make_transport_stream();
        let server = RelayServer::initialize(session, stream, None).unwrap();
        server.abort().await;
        // Abort surfaces as a cancelled JoinError from wait
        match server.wait().await {
            Err(x) if x.is_cancelled() => {}
            x => panic!("Unexpected result: {:?}", x),
        }
    }
    #[tokio::test]
    async fn server_should_forward_requests_using_session() {
        let (mut transport, session) = make_session();
        let (tx, stream) = make_transport_stream();
        let _server = RelayServer::initialize(session, stream, None).unwrap();
        // Send over a "connection"
        let (mut t1, t2) = Transport::make_pair();
        tx.send(t2).await.unwrap();
        // Send a request
        let req = Request::new("test-tenant", vec![RequestData::SystemInfo {}]);
        t1.send(req.clone()).await.unwrap();
        // Verify the request is forwarded out via session
        let outbound_req = transport.receive().await.unwrap().unwrap();
        assert_eq!(req, outbound_req);
    }
    #[tokio::test]
    async fn server_should_send_back_response_with_tenant_matching_connection() {
        let (mut transport, session) = make_session();
        let (tx, stream) = make_transport_stream();
        let _server = RelayServer::initialize(session, stream, None).unwrap();
        // Send over a "connection"
        let (mut t1, t2) = Transport::make_pair();
        tx.send(t2).await.unwrap();
        // Send over a second "connection"
        let (mut t2, t3) = Transport::make_pair();
        tx.send(t3).await.unwrap();
        // Send a request to mark the tenant of the first connection
        t1.send(Request::new(
            "test-tenant-1",
            vec![RequestData::SystemInfo {}],
        ))
        .await
        .unwrap();
        // Send a request to mark the tenant of the second connection
        t2.send(Request::new(
            "test-tenant-2",
            vec![RequestData::SystemInfo {}],
        ))
        .await
        .unwrap();
        // Clear out the transport channel (outbound of session)
        // NOTE: Because our test stream uses a buffer size of 1, we have to clear out the
        //       outbound data from the earlier requests before we can send back a response
        let req_1 = transport.receive::<Request>().await.unwrap().unwrap();
        let req_2 = transport.receive::<Request>().await.unwrap().unwrap();
        // Requests may arrive in either order, so pick the origin id belonging
        // to the second tenant
        let origin_id = if req_1.tenant == "test-tenant-2" {
            req_1.id
        } else {
            req_2.id
        };
        // Send a response back to a singular connection based on the tenant
        let res = Response::new("test-tenant-2", origin_id, vec![ResponseData::Ok]);
        transport.send(res.clone()).await.unwrap();
        // Verify that response is only received by a singular connection
        let inbound_res = t2.receive().await.unwrap().unwrap();
        assert_eq!(res, inbound_res);
        let no_inbound = tokio::select! {
            _ = t1.receive::<Response>() => {false}
            _ = tokio::time::sleep(Duration::from_millis(50)) => {true}
        };
        assert!(no_inbound, "Unexpectedly got response for wrong connection");
    }
    #[tokio::test]
    async fn server_should_shutdown_if_no_connections_after_shutdown_duration() {
        let (_transport, session) = make_session();
        let (_tx, stream) = make_transport_stream();
        let server =
            RelayServer::initialize(session, stream, Some(Duration::from_millis(50))).unwrap();
        let result = server.wait().await;
        assert!(result.is_ok(), "Unexpected result: {:?}", result);
    }
}

@ -1,289 +0,0 @@
use log::*;
use std::{
future::Future,
pin::Pin,
sync::Arc,
task::{Context, Poll},
time::Duration,
};
use tokio::{
sync::Mutex,
task::{JoinError, JoinHandle},
time::{self, Instant},
};
/// Task to keep track of a possible server shutdown based on connections
pub struct ShutdownTask {
    // Background task that resolves once the shutdown timeout is reached
    task: JoinHandle<()>,
    // Shared tracker that servers update as connections come and go
    tracker: Arc<Mutex<ConnTracker>>,
}
// Awaiting a ShutdownTask awaits its inner join handle, resolving when the
// shutdown timeout has been reached (or the task itself fails)
impl Future for ShutdownTask {
    type Output = Result<(), JoinError>;
    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        Pin::new(&mut self.task).poll(cx)
    }
}
impl ShutdownTask {
    /// Given an optional timeout, will either create the shutdown task or not,
    /// returning an optional shutdown task alongside an optional connection tracker
    pub fn maybe_initialize(
        duration: Option<Duration>,
    ) -> (Option<ShutdownTask>, Option<Arc<Mutex<ConnTracker>>>) {
        match duration {
            Some(duration) => {
                let task = Self::initialize(duration);
                let tracker = task.tracker();
                (Some(task), Some(tracker))
            }
            None => (None, None),
        }
    }
    /// Spawns a new task that continues to monitor the time since a
    /// connection on the server existed, reporting a shutdown to all listeners
    /// once the timeout is exceeded
    pub fn initialize(duration: Duration) -> Self {
        let tracker = Arc::new(Mutex::new(ConnTracker::new()));
        let tracker_2 = Arc::clone(&tracker);
        let task = tokio::spawn(async move {
            loop {
                // Get the time since the last connection joined/left
                let (base_time, cnt) = tracker_2.lock().await.time_and_cnt();
                // If we have no connections left, we want to wait
                // until the remaining period has passed and then
                // verify that we still have no connections
                if cnt == 0 {
                    // Get the time we should wait based on when the last connection
                    // was dropped; this closes the gap in the case where we start
                    // sometime later than exactly duration since the last check
                    let next_time = base_time + duration;
                    // The extra millisecond ensures we land strictly past the deadline
                    let wait_duration = next_time
                        .checked_duration_since(Instant::now())
                        .unwrap_or_default()
                        + Duration::from_millis(1);
                    // Wait until we've reached our desired duration since the
                    // last connection was dropped
                    time::sleep(wait_duration).await;
                    // If we do have a connection at this point, don't exit
                    if !tracker_2.lock().await.has_reached_timeout(duration) {
                        continue;
                    }
                    // Otherwise, we now should exit, which we do by reporting
                    debug!(
                        "Shutdown time of {}s has been reached!",
                        duration.as_secs_f32()
                    );
                    break;
                }
                // Otherwise, we just wait the full duration as worst case
                // we'll have waited just about the time desired if right
                // after waiting starts the last connection is closed
                time::sleep(duration).await;
            }
        });
        Self { task, tracker }
    }
    /// Produces a new copy of the connection tracker associated with the shutdown manager
    pub fn tracker(&self) -> Arc<Mutex<ConnTracker>> {
        Arc::clone(&self.tracker)
    }
}
/// Tracks the number of active connections alongside the instant of the most
/// recent change, which drives the server's shutdown-timeout logic
pub struct ConnTracker {
    // Time of the last change to `cnt` (or creation time if never changed)
    time: Instant,
    // Number of active connections
    cnt: usize,
}
// Satisfies clippy::new_without_default; equivalent to ConnTracker::new()
impl Default for ConnTracker {
    fn default() -> Self {
        Self::new()
    }
}
impl ConnTracker {
    /// Creates a tracker with zero connections, stamped with the current time
    pub fn new() -> Self {
        Self {
            time: Instant::now(),
            cnt: 0,
        }
    }
    /// Records a new connection and refreshes the timestamp
    pub fn increment(&mut self) {
        self.time = Instant::now();
        self.cnt += 1;
    }
    /// Records a dropped connection and refreshes the timestamp; a no-op
    /// (timestamp included) when the count is already zero
    pub fn decrement(&mut self) {
        if self.cnt > 0 {
            self.time = Instant::now();
            self.cnt -= 1;
        }
    }
    /// Returns the last-change time and the current connection count
    fn time_and_cnt(&self) -> (Instant, usize) {
        (self.time, self.cnt)
    }
    /// Returns true when there are no connections and at least `duration`
    /// has elapsed since the last change
    fn has_reached_timeout(&self, duration: Duration) -> bool {
        self.cnt == 0 && self.time.elapsed() >= duration
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::thread;
    #[tokio::test]
    async fn shutdown_task_should_not_resolve_if_has_connection_regardless_of_time() {
        let mut task = ShutdownTask::initialize(Duration::from_millis(10));
        // One active connection keeps the task pending indefinitely
        task.tracker().lock().await.increment();
        assert!(
            futures::poll!(&mut task).is_pending(),
            "Shutdown task unexpectedly completed"
        );
        // Well past the 10ms timeout, the task should still be pending
        time::sleep(Duration::from_millis(50)).await;
        assert!(
            futures::poll!(task).is_pending(),
            "Shutdown task unexpectedly completed"
        );
    }
    #[tokio::test]
    async fn shutdown_task_should_resolve_if_no_connection_for_minimum_duration() {
        let mut task = ShutdownTask::initialize(Duration::from_millis(10));
        assert!(
            futures::poll!(&mut task).is_pending(),
            "Shutdown task unexpectedly completed"
        );
        // The 1s sleep is an upper bound guarding against a hung task
        tokio::select! {
            _ = task => {}
            _ = time::sleep(Duration::from_secs(1)) => {
                panic!("Shutdown task unexpectedly pending");
            }
        }
    }
    #[tokio::test]
    async fn shutdown_task_should_resolve_if_no_connection_for_minimum_duration_after_connection_removed(
    ) {
        let mut task = ShutdownTask::initialize(Duration::from_millis(10));
        task.tracker().lock().await.increment();
        assert!(
            futures::poll!(&mut task).is_pending(),
            "Shutdown task unexpectedly completed"
        );
        time::sleep(Duration::from_millis(50)).await;
        assert!(
            futures::poll!(&mut task).is_pending(),
            "Shutdown task unexpectedly completed"
        );
        // Dropping the last connection starts the countdown fresh
        task.tracker().lock().await.decrement();
        tokio::select! {
            _ = task => {}
            _ = time::sleep(Duration::from_secs(1)) => {
                panic!("Shutdown task unexpectedly pending");
            }
        }
    }
    #[tokio::test]
    async fn shutdown_task_should_not_resolve_before_minimum_duration() {
        let mut task = ShutdownTask::initialize(Duration::from_millis(50));
        assert!(
            futures::poll!(&mut task).is_pending(),
            "Shutdown task unexpectedly completed"
        );
        // 5ms is well before the 50ms timeout
        time::sleep(Duration::from_millis(5)).await;
        assert!(
            futures::poll!(task).is_pending(),
            "Shutdown task unexpectedly completed"
        );
    }
    #[test]
    fn conn_tracker_should_update_time_when_incremented() {
        let mut tracker = ConnTracker::new();
        let (old_time, cnt) = tracker.time_and_cnt();
        assert_eq!(cnt, 0);
        // Wait to ensure that the new time will be different
        thread::sleep(Duration::from_millis(1));
        tracker.increment();
        let (new_time, cnt) = tracker.time_and_cnt();
        assert_eq!(cnt, 1);
        assert!(new_time > old_time);
    }
    #[test]
    fn conn_tracker_should_update_time_when_decremented() {
        let mut tracker = ConnTracker::new();
        tracker.increment();
        let (old_time, cnt) = tracker.time_and_cnt();
        assert_eq!(cnt, 1);
        // Wait to ensure that the new time will be different
        thread::sleep(Duration::from_millis(1));
        tracker.decrement();
        let (new_time, cnt) = tracker.time_and_cnt();
        assert_eq!(cnt, 0);
        assert!(new_time > old_time);
    }
    #[test]
    fn conn_tracker_should_not_update_time_when_decremented_if_at_zero_already() {
        let mut tracker = ConnTracker::new();
        let (old_time, cnt) = tracker.time_and_cnt();
        assert_eq!(cnt, 0);
        // Wait to ensure that the new time would be different if updated
        thread::sleep(Duration::from_millis(1));
        tracker.decrement();
        let (new_time, cnt) = tracker.time_and_cnt();
        assert_eq!(cnt, 0);
        assert!(new_time == old_time);
    }
    #[test]
    fn conn_tracker_should_report_timeout_reached_when_time_has_elapsed_and_no_connections() {
        let tracker = ConnTracker::new();
        let (_, cnt) = tracker.time_and_cnt();
        assert_eq!(cnt, 0);
        // Wait to ensure that the new time would be different if updated
        thread::sleep(Duration::from_millis(1));
        assert!(tracker.has_reached_timeout(Duration::from_millis(1)));
    }
    #[test]
    fn conn_tracker_should_not_report_timeout_reached_when_time_has_elapsed_but_has_connections() {
        let mut tracker = ConnTracker::new();
        tracker.increment();
        let (_, cnt) = tracker.time_and_cnt();
        assert_eq!(cnt, 1);
        // Wait to ensure that the new time would be different if updated
        thread::sleep(Duration::from_millis(1));
        assert!(!tracker.has_reached_timeout(Duration::from_millis(1)));
    }
}

@ -0,0 +1,96 @@
use distant_core::{
net::{FramedTransport, InmemoryTransport, IntoSplit, OneshotListener, PlainCodec},
BoxedDistantReader, BoxedDistantWriter, Destination, DistantApiServer, DistantChannelExt,
DistantManager, DistantManagerClient, DistantManagerClientConfig, DistantManagerConfig, Extra,
};
use std::io;
/// Creates a client transport and server listener for our tests
/// that are connected together
async fn setup() -> (
    FramedTransport<InmemoryTransport, PlainCodec>,
    OneshotListener<FramedTransport<InmemoryTransport, PlainCodec>>,
) {
    // Paired in-memory transports with a channel capacity of 100
    let (t1, t2) = InmemoryTransport::pair(100);
    // t2 side is wrapped as the single connection the listener will yield
    let listener = OneshotListener::from_value(FramedTransport::new(t2, PlainCodec));
    let transport = FramedTransport::new(t1, PlainCodec);
    (transport, listener)
}
// End-to-end exercise of the manager API: connect, list, info, open_channel,
// kill, and error serialization across a single in-memory connection
#[tokio::test]
async fn should_be_able_to_establish_a_single_connection_and_communicate() {
    let (transport, listener) = setup().await;
    let config = DistantManagerConfig::default();
    let manager_ref = DistantManager::start(config, listener).expect("Failed to start manager");
    // NOTE: To pass in a raw function, we HAVE to specify the types of the parameters manually,
    //       otherwise we get a compilation error about lifetime mismatches
    manager_ref
        .register_connect_handler("scheme", |_: &_, _: &_, _: &mut _| async {
            use distant_core::net::ServerExt;
            let (t1, t2) = FramedTransport::pair(100);
            // Spawn a server on one end
            let _ = DistantApiServer::local()
                .unwrap()
                .start(OneshotListener::from_value(t2.into_split()))?;
            // Create a reader/writer pair on the other end
            let (writer, reader) = t1.into_split();
            let writer: BoxedDistantWriter = Box::new(writer);
            let reader: BoxedDistantReader = Box::new(reader);
            Ok((writer, reader))
        })
        .await
        .expect("Failed to register handler");
    let config = DistantManagerClientConfig::with_empty_prompts();
    let mut client =
        DistantManagerClient::new(config, transport).expect("Failed to connect to manager");
    // Test establishing a connection to some remote server
    let id = client
        .connect(
            "scheme://host".parse::<Destination>().unwrap(),
            "key=value".parse::<Extra>().unwrap(),
        )
        .await
        .expect("Failed to connect to a remote server");
    // Test retrieving list of connections
    let list = client
        .list()
        .await
        .expect("Failed to get list of connections");
    assert_eq!(list.len(), 1);
    assert_eq!(list.get(&id).unwrap().to_string(), "scheme://host/");
    // Test retrieving information
    let info = client
        .info(id)
        .await
        .expect("Failed to get info about connection");
    assert_eq!(info.id, id);
    assert_eq!(info.destination.to_string(), "scheme://host/");
    assert_eq!(info.extra, "key=value".parse::<Extra>().unwrap());
    // Create a new channel and request some data
    let mut channel = client
        .open_channel(id)
        .await
        .expect("Failed to open channel");
    let _ = channel
        .system_info()
        .await
        .expect("Failed to get system information");
    // Test killing a connection
    client.kill(id).await.expect("Failed to kill connection");
    // Test getting an error to ensure that serialization of that data works,
    // which we do by trying to access a connection that no longer exists
    let err = client.info(id).await.unwrap_err();
    assert_eq!(err.kind(), io::ErrorKind::NotConnected);
}

@ -1,6 +1,6 @@
use crate::stress::fixtures::*;
use assert_fs::prelude::*;
use distant_core::{ChangeKindSet, SessionChannelExt};
use distant_core::{data::ChangeKindSet, DistantChannelExt};
use rstest::*;
const MAX_FILES: usize = 500;
@ -8,11 +8,10 @@ const MAX_FILES: usize = 500;
#[rstest]
#[tokio::test]
#[ignore]
async fn should_handle_large_volume_of_file_watching(#[future] ctx: DistantSessionCtx) {
async fn should_handle_large_volume_of_file_watching(#[future] ctx: DistantClientCtx) {
let ctx = ctx.await;
let mut channel = ctx.session.clone_channel();
let mut channel = ctx.client.clone_channel();
let tenant = "watch-stress-test";
let root = assert_fs::TempDir::new().unwrap();
let mut files_and_watchers = Vec::new();
@ -25,7 +24,6 @@ async fn should_handle_large_volume_of_file_watching(#[future] ctx: DistantSessi
eprintln!("Watching {:?}", file.path());
let watcher = channel
.watch(
tenant,
file.path(),
false,
ChangeKindSet::modify_set(),

@ -1,17 +1,20 @@
use crate::stress::utils;
use distant_core::{DistantServer, SecretKey, SecretKey32, Session, XChaCha20Poly1305Codec};
use distant_core::{DistantApiServer, DistantClient, LocalDistantApi};
use distant_net::{
PortRange, SecretKey, SecretKey32, TcpClientExt, TcpServerExt, XChaCha20Poly1305Codec,
};
use rstest::*;
use std::time::Duration;
use tokio::sync::mpsc;
const LOG_PATH: &str = "/tmp/test.distant.server.log";
pub struct DistantSessionCtx {
pub session: Session,
pub struct DistantClientCtx {
pub client: DistantClient,
_done_tx: mpsc::Sender<()>,
}
impl DistantSessionCtx {
impl DistantClientCtx {
pub async fn initialize() -> Self {
let ip_addr = "127.0.0.1".parse().unwrap();
let (done_tx, mut done_rx) = mpsc::channel::<()>(1);
@ -21,14 +24,21 @@ impl DistantSessionCtx {
let logger = utils::init_logging(LOG_PATH);
let key = SecretKey::default();
let codec = XChaCha20Poly1305Codec::from(key.clone());
let (_server, port) =
DistantServer::bind(ip_addr, "0".parse().unwrap(), codec, Default::default())
if let Ok(api) = LocalDistantApi::initialize() {
let port: PortRange = "0".parse().unwrap();
let port = {
let server_ref = DistantApiServer::new(api)
.start(ip_addr, port, codec)
.await
.unwrap();
server_ref.port()
};
started_tx.send((port, key)).await.unwrap();
let _ = done_rx.recv().await;
}
logger.flush();
logger.shutdown();
});
@ -36,8 +46,8 @@ impl DistantSessionCtx {
// Extract our server startup data if we succeeded
let (port, key) = started_rx.recv().await.unwrap();
// Now initialize our session
let session = Session::tcp_connect_timeout(
// Now initialize our client
let client = DistantClient::connect_timeout(
format!("{}:{}", ip_addr, port).parse().unwrap(),
XChaCha20Poly1305Codec::from(key),
Duration::from_secs(1),
@ -45,14 +55,14 @@ impl DistantSessionCtx {
.await
.unwrap();
DistantSessionCtx {
session,
DistantClientCtx {
client,
_done_tx: done_tx,
}
}
}
#[fixture]
pub async fn ctx() -> DistantSessionCtx {
DistantSessionCtx::initialize().await
pub async fn ctx() -> DistantClientCtx {
DistantClientCtx::initialize().await
}

@ -0,0 +1,37 @@
[package]
name = "distant-net"
description = "Network library for distant, providing implementations to support client/server architecture"
categories = ["network-programming"]
keywords = ["api", "async"]
version = "0.17.0"
authors = ["Chip Senkbeil <chip@senkbeil.org>"]
edition = "2021"
homepage = "https://github.com/chipsenkbeil/distant"
repository = "https://github.com/chipsenkbeil/distant"
readme = "README.md"
license = "MIT OR Apache-2.0"
[dependencies]
async-trait = "0.1.56"
bytes = "1.1.0"
chacha20poly1305 = "=0.10.0-pre"
derive_more = { version = "0.99.17", default-features = false, features = ["as_mut", "as_ref", "deref", "deref_mut", "display", "from", "error", "into", "into_iterator", "is_variant", "try_into"] }
futures = "0.3.21"
hex = "0.4.3"
hkdf = "0.12.3"
log = "0.4.17"
paste = "1.0.7"
p256 = { version = "0.11.1", features = ["ecdh", "pem"] }
rand = { version = "0.8.4", features = ["getrandom"] }
rmp-serde = "1.1.0"
sha2 = "0.10.2"
serde = { version = "1.0.126", features = ["derive"] }
serde_bytes = "0.11.6"
tokio = { version = "1.12.0", features = ["full"] }
tokio-util = { version = "0.6.7", features = ["codec"] }
# Optional dependencies based on features
schemars = { version = "0.8.10", optional = true }
[dev-dependencies]
tempfile = "3"

@ -0,0 +1,49 @@
# distant net
[![Crates.io][distant_crates_img]][distant_crates_lnk] [![Docs.rs][distant_doc_img]][distant_doc_lnk] [![Rustc 1.61.0][distant_rustc_img]][distant_rustc_lnk]
[distant_crates_img]: https://img.shields.io/crates/v/distant-net.svg
[distant_crates_lnk]: https://crates.io/crates/distant-net
[distant_doc_img]: https://docs.rs/distant-net/badge.svg
[distant_doc_lnk]: https://docs.rs/distant-net
[distant_rustc_img]: https://img.shields.io/badge/distant_net-rustc_1.61+-lightgray.svg
[distant_rustc_lnk]: https://blog.rust-lang.org/2022/05/19/Rust-1.61.0.html
Library that powers the [`distant`](https://github.com/chipsenkbeil/distant)
binary.
🚧 **(Alpha stage software) This library is in rapid development and may break or change frequently!** 🚧
## Details
The `distant-net` library supplies the foundational networking functionality
for the distant interfaces and distant cli.
## Installation
You can import the dependency by adding the following to your `Cargo.toml`:
```toml
[dependencies]
distant-net = "0.17"
```
## Features
Currently, the library supports the following features:
- `schemars`: derives the `schemars::JsonSchema` interface on `Request`
and `Response` data types
By default, no features are enabled on the library.
## License
This project is licensed under either of
Apache License, Version 2.0, (LICENSE-APACHE or
[apache-license][apache-license]) or the MIT license (LICENSE-MIT or
[mit-license][mit-license]) at your option.
[apache-license]: http://www.apache.org/licenses/LICENSE-2.0
[mit-license]: http://opensource.org/licenses/MIT

@ -0,0 +1,29 @@
use std::any::Any;
/// Trait used for casting support into the [`Any`] trait object
///
/// Useful for trait objects that need to be downcast back to a concrete
/// type, since a `&dyn SomeTrait` cannot be downcast directly.
pub trait AsAny: Any {
    /// Converts reference to [`Any`]
    fn as_any(&self) -> &dyn Any;

    /// Converts mutable reference to [`Any`]
    fn as_mut_any(&mut self) -> &mut dyn Any;

    /// Consumes and produces `Box<dyn Any>`
    fn into_any(self: Box<Self>) -> Box<dyn Any>;
}
/// Blanket implementation that enables any `'static` reference to convert
/// to the [`Any`] type
///
/// Each method body relies on the compiler's implicit unsizing coercion
/// from the concrete `T` to the `dyn Any` trait object.
impl<T: 'static> AsAny for T {
    fn as_any(&self) -> &dyn Any {
        self
    }

    fn as_mut_any(&mut self) -> &mut dyn Any {
        self
    }

    fn into_any(self: Box<Self>) -> Box<dyn Any> {
        self
    }
}

@ -0,0 +1,122 @@
use derive_more::Display;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
mod client;
pub use client::*;
mod handshake;
pub use handshake::*;
mod server;
pub use server::*;
/// Represents authentication messages that can be sent over the wire
///
/// NOTE: Must use serde's content attribute with the tag attribute. Just the tag attribute will
/// cause deserialization to fail
#[derive(Clone, Debug, Serialize, Deserialize)]
#[serde(rename_all = "snake_case", tag = "type", content = "data")]
pub enum Auth {
    /// Represents a request to perform an authentication handshake,
    /// providing the public key and salt from one side in order to
    /// derive the shared key
    #[serde(rename = "auth_handshake")]
    Handshake {
        /// Bytes of the public key
        #[serde(with = "serde_bytes")]
        public_key: PublicKeyBytes,

        /// Randomly generated salt
        #[serde(with = "serde_bytes")]
        salt: Salt,
    },

    /// Represents the bytes of an encrypted message
    ///
    /// Underneath, will be one of either [`AuthRequest`] or [`AuthResponse`]
    #[serde(rename = "auth_msg")]
    Msg {
        /// Serialized payload encrypted with the key derived from the handshake
        #[serde(with = "serde_bytes")]
        encrypted_payload: Vec<u8>,
    },
}
/// Represents authentication messages that act as initiators such as providing
/// a challenge, verifying information, presenting information, or highlighting an error
#[derive(Clone, Debug, Serialize, Deserialize)]
#[serde(rename_all = "snake_case", tag = "type")]
pub enum AuthRequest {
    /// Represents a challenge comprising a series of questions to be presented
    Challenge {
        /// Questions to present to the other side
        questions: Vec<AuthQuestion>,

        /// Additional key-value data associated with the overall challenge
        extra: HashMap<String, String>,
    },

    /// Represents an ask to verify some information
    Verify { kind: AuthVerifyKind, text: String },

    /// Represents some information to be presented
    Info { text: String },

    /// Represents some error that occurred
    Error { kind: AuthErrorKind, text: String },
}
/// Represents authentication messages that are responses to auth requests such
/// as answers to challenges or verifying information
#[derive(Clone, Debug, Serialize, Deserialize)]
#[serde(rename_all = "snake_case", tag = "type")]
pub enum AuthResponse {
    /// Represents the answers to a previously-asked challenge
    ///
    /// Answers are expected in the same order as the challenge's questions
    Challenge { answers: Vec<String> },

    /// Represents the answer to a previously-asked verify
    Verify { valid: bool },
}
/// Represents the type of verification being requested
///
/// Marked non-exhaustive so additional verification kinds can be added
/// without breaking downstream matches.
#[derive(Copy, Clone, Debug, Display, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
#[non_exhaustive]
pub enum AuthVerifyKind {
    /// An ask to verify the host such as with SSH
    #[display(fmt = "host")]
    Host,
}
/// Represents a single question in a challenge
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
pub struct AuthQuestion {
    /// The text of the question
    pub text: String,

    /// Any extra information specific to a particular auth domain
    /// such as including a username and instructions for SSH authentication
    pub extra: HashMap<String, String>,
}
impl AuthQuestion {
    /// Builds a question consisting solely of `text`, carrying no extra
    /// auth-domain metadata
    pub fn new(text: impl Into<String>) -> Self {
        let text = text.into();
        Self {
            text,
            extra: HashMap::default(),
        }
    }
}
/// Represents the type of error encountered during authentication
#[derive(Copy, Clone, Debug, Display, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum AuthErrorKind {
    /// When the answer(s) to a challenge do not pass authentication
    FailedChallenge,

    /// When verification during authentication fails
    /// (e.g. a host is not allowed or blocked)
    FailedVerification,

    /// When the error is unknown
    Unknown,
}

@ -0,0 +1,817 @@
use crate::{
utils, Auth, AuthErrorKind, AuthQuestion, AuthRequest, AuthResponse, AuthVerifyKind, Client,
Codec, Handshake, XChaCha20Poly1305Codec,
};
use bytes::BytesMut;
use log::*;
use std::{collections::HashMap, io};
/// Client-side wrapper that speaks the [`Auth`] protocol over an underlying
/// typed client, encrypting payloads once a handshake has installed a codec
pub struct AuthClient {
    // Underlying request/response client carrying `Auth` frames
    inner: Client<Auth, Auth>,
    // Encryption codec; `None` until a handshake succeeds
    codec: Option<XChaCha20Poly1305Codec>,
    // When true, a handshake is performed just-in-time before requests
    jit_handshake: bool,
}
impl From<Client<Auth, Auth>> for AuthClient {
    /// Wraps a typed client, starting without encryption (no codec yet)
    /// and with just-in-time handshaking disabled
    fn from(client: Client<Auth, Auth>) -> Self {
        Self {
            inner: client,
            codec: None,
            jit_handshake: false,
        }
    }
}
impl AuthClient {
    /// Sends a request to the server to establish an encrypted connection
    ///
    /// Exchanges public keys and salts, derives the shared key, and installs
    /// the resulting codec for use by all subsequent encrypted messages.
    ///
    /// # Errors
    ///
    /// Fails if the transport errors or if the server replies with anything
    /// other than its half of the handshake
    pub async fn handshake(&mut self) -> io::Result<()> {
        let handshake = Handshake::default();

        let response = self
            .inner
            .send(Auth::Handshake {
                public_key: handshake.pk_bytes(),
                salt: *handshake.salt(),
            })
            .await?;

        match response.payload {
            Auth::Handshake { public_key, salt } => {
                // Derive the shared key from the server's public key & salt
                // and store the codec so `is_ready` reports true
                let key = handshake.handshake(public_key, salt)?;
                self.codec.replace(XChaCha20Poly1305Codec::new(&key));
                Ok(())
            }
            Auth::Msg { .. } => Err(io::Error::new(
                io::ErrorKind::Other,
                "Got unexpected encrypted message during handshake",
            )),
        }
    }

    /// Perform a handshake only if jit is enabled and no handshake has succeeded yet
    async fn jit_handshake(&mut self) -> io::Result<()> {
        if self.will_jit_handshake() && !self.is_ready() {
            self.handshake().await
        } else {
            Ok(())
        }
    }

    /// Returns true if client has successfully performed a handshake
    /// and is ready to communicate with the server
    pub fn is_ready(&self) -> bool {
        self.codec.is_some()
    }

    /// Returns true if this client will perform a handshake just-in-time (JIT) prior to making a
    /// request in the scenario where the client has not already performed a handshake
    #[inline]
    pub fn will_jit_handshake(&self) -> bool {
        self.jit_handshake
    }

    /// Sets the jit flag on this client with `true` indicating that this client will perform a
    /// handshake just-in-time (JIT) prior to making a request in the scenario where the client has
    /// not already performed a handshake
    #[inline]
    pub fn set_jit_handshake(&mut self, flag: bool) {
        self.jit_handshake = flag;
    }

    /// Provides a challenge to the server and returns the answers to the questions
    /// asked by the client
    ///
    /// # Errors
    ///
    /// Fails if no handshake has been performed (and JIT is disabled), if the
    /// transport errors, or if the server sends back anything other than a
    /// challenge response
    pub async fn challenge(
        &mut self,
        questions: Vec<AuthQuestion>,
        extra: HashMap<String, String>,
    ) -> io::Result<Vec<String>> {
        trace!(
            "AuthClient::challenge(questions = {:?}, extra = {:?})",
            questions,
            extra
        );

        // Perform JIT handshake if enabled
        self.jit_handshake().await?;

        let payload = AuthRequest::Challenge { questions, extra };
        let encrypted_payload = self.serialize_and_encrypt(&payload)?;
        let response = self.inner.send(Auth::Msg { encrypted_payload }).await?;

        match response.payload {
            Auth::Msg { encrypted_payload } => {
                match self.decrypt_and_deserialize(&encrypted_payload)? {
                    AuthResponse::Challenge { answers } => Ok(answers),
                    AuthResponse::Verify { .. } => Err(io::Error::new(
                        io::ErrorKind::Other,
                        "Got unexpected verify response during challenge",
                    )),
                }
            }
            Auth::Handshake { .. } => Err(io::Error::new(
                io::ErrorKind::Other,
                "Got unexpected handshake during challenge",
            )),
        }
    }

    /// Provides a verification request to the server and returns whether or not
    /// the server approved
    ///
    /// # Errors
    ///
    /// Fails if no handshake has been performed (and JIT is disabled), if the
    /// transport errors, or if the server sends back anything other than a
    /// verify response
    pub async fn verify(&mut self, kind: AuthVerifyKind, text: String) -> io::Result<bool> {
        trace!("AuthClient::verify(kind = {:?}, text = {:?})", kind, text);

        // Perform JIT handshake if enabled
        self.jit_handshake().await?;

        let payload = AuthRequest::Verify { kind, text };
        let encrypted_payload = self.serialize_and_encrypt(&payload)?;
        let response = self.inner.send(Auth::Msg { encrypted_payload }).await?;

        match response.payload {
            Auth::Msg { encrypted_payload } => {
                match self.decrypt_and_deserialize(&encrypted_payload)? {
                    AuthResponse::Verify { valid } => Ok(valid),
                    AuthResponse::Challenge { .. } => Err(io::Error::new(
                        io::ErrorKind::Other,
                        "Got unexpected challenge response during verify",
                    )),
                }
            }
            Auth::Handshake { .. } => Err(io::Error::new(
                io::ErrorKind::Other,
                "Got unexpected handshake during verify",
            )),
        }
    }

    /// Provides information to the server to use as it pleases with no response expected
    pub async fn info(&mut self, text: String) -> io::Result<()> {
        trace!("AuthClient::info(text = {:?})", text);

        // Perform JIT handshake if enabled
        self.jit_handshake().await?;

        let payload = AuthRequest::Info { text };
        let encrypted_payload = self.serialize_and_encrypt(&payload)?;

        // `fire` sends without waiting for a response
        self.inner.fire(Auth::Msg { encrypted_payload }).await
    }

    /// Provides an error to the server to use as it pleases with no response expected
    pub async fn error(&mut self, kind: AuthErrorKind, text: String) -> io::Result<()> {
        trace!("AuthClient::error(kind = {:?}, text = {:?})", kind, text);

        // Perform JIT handshake if enabled
        self.jit_handshake().await?;

        let payload = AuthRequest::Error { kind, text };
        let encrypted_payload = self.serialize_and_encrypt(&payload)?;

        // `fire` sends without waiting for a response
        self.inner.fire(Auth::Msg { encrypted_payload }).await
    }

    /// Serializes `payload` and encrypts the bytes with the codec established
    /// during the handshake, failing if no handshake has been performed
    fn serialize_and_encrypt(&mut self, payload: &AuthRequest) -> io::Result<Vec<u8>> {
        let codec = self.codec.as_mut().ok_or_else(|| {
            io::Error::new(
                io::ErrorKind::Other,
                "Handshake must be performed first (client encrypt message)",
            )
        })?;

        let mut encrypted_payload = BytesMut::new();
        let payload = utils::serialize_to_vec(payload)?;
        codec.encode(&payload, &mut encrypted_payload)?;
        Ok(encrypted_payload.freeze().to_vec())
    }

    /// Decrypts `payload` with the codec established during the handshake and
    /// deserializes the result, failing if no handshake has been performed or
    /// the bytes do not form a complete [`AuthResponse`]
    fn decrypt_and_deserialize(&mut self, payload: &[u8]) -> io::Result<AuthResponse> {
        let codec = self.codec.as_mut().ok_or_else(|| {
            io::Error::new(
                io::ErrorKind::Other,
                "Handshake must be performed first (client decrypt message)",
            )
        })?;

        let mut payload = BytesMut::from(payload);
        match codec.decode(&mut payload)? {
            Some(payload) => utils::deserialize_from_slice::<AuthResponse>(&payload),
            None => Err(io::Error::new(
                io::ErrorKind::InvalidData,
                "Incomplete message received",
            )),
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{Client, FramedTransport, Request, Response, TypedAsyncRead, TypedAsyncWrite};
    use serde::{de::DeserializeOwned, Serialize};

    // How long a test waits before concluding no request is coming (milliseconds)
    const TIMEOUT_MILLIS: u64 = 100;

    // The client should reject a handshake reply that is not the server's
    // half of the handshake
    #[tokio::test]
    async fn handshake_should_fail_if_get_unexpected_response_from_server() {
        let (t, mut server) = FramedTransport::make_test_pair();
        let mut client = AuthClient::from(Client::from_framed_transport(t).unwrap());

        // We start a separate task for the client to avoid blocking since
        // we also need to receive the client's request and respond
        let task = tokio::spawn(async move { client.handshake().await });

        // Get the request, but send a bad response
        let request: Request<Auth> = server.read().await.unwrap().unwrap();
        match request.payload {
            Auth::Handshake { .. } => server
                .write(Response::new(
                    request.id,
                    Auth::Msg {
                        encrypted_payload: Vec::new(),
                    },
                ))
                .await
                .unwrap(),
            _ => panic!("Server received unexpected payload"),
        }

        let result = task.await.unwrap();
        assert!(result.is_err(), "Handshake succeeded unexpectedly")
    }

    // Without a handshake there is no codec, so a challenge should fail
    // before anything is written to the wire
    #[tokio::test]
    async fn challenge_should_fail_if_handshake_not_finished() {
        let (t, mut server) = FramedTransport::make_test_pair();
        let mut client = AuthClient::from(Client::from_framed_transport(t).unwrap());

        // We start a separate task for the client to avoid blocking since
        // we also need to receive the client's request and respond
        let task = tokio::spawn(async move { client.challenge(Vec::new(), HashMap::new()).await });

        // Wait for a request, failing if we get one as the failure
        // should have prevented sending anything, but we should
        // still see the transport close once the client task exits
        tokio::select! {
            x = TypedAsyncRead::<Request<Auth>>::read(&mut server) => {
                match x {
                    Ok(Some(x)) => panic!("Unexpectedly resolved: {:?}", x),
                    Ok(None) => {},
                    Err(x) => panic!("Unexpectedly failed on server side: {}", x),
                }
            },
            _ = wait_ms(TIMEOUT_MILLIS) => {
                panic!("Should have gotten server closure as part of client exit");
            }
        }

        // Verify that we got an error with the method
        let result = task.await.unwrap();
        assert!(result.is_err(), "Challenge succeeded unexpectedly")
    }

    // A verify response arriving where a challenge response was expected
    // should surface as an error
    #[tokio::test]
    async fn challenge_should_fail_if_receive_wrong_response() {
        let (t, mut server) = FramedTransport::make_test_pair();
        let mut client = AuthClient::from(Client::from_framed_transport(t).unwrap());

        // We start a separate task for the client to avoid blocking since
        // we also need to receive the client's request and respond
        let task = tokio::spawn(async move {
            client.handshake().await.unwrap();
            client
                .challenge(
                    vec![
                        AuthQuestion::new("question1".to_string()),
                        AuthQuestion {
                            text: "question2".to_string(),
                            extra: vec![("key2".to_string(), "value2".to_string())]
                                .into_iter()
                                .collect(),
                        },
                    ],
                    vec![("key".to_string(), "value".to_string())]
                        .into_iter()
                        .collect(),
                )
                .await
        });

        // Wait for a handshake request and set up our encryption codec
        let request: Request<Auth> = server.read().await.unwrap().unwrap();
        let mut codec = match request.payload {
            Auth::Handshake { public_key, salt } => {
                let handshake = Handshake::default();
                let key = handshake.handshake(public_key, salt).unwrap();
                server
                    .write(Response::new(
                        request.id,
                        Auth::Handshake {
                            public_key: handshake.pk_bytes(),
                            salt: *handshake.salt(),
                        },
                    ))
                    .await
                    .unwrap();
                XChaCha20Poly1305Codec::new(&key)
            }
            _ => panic!("Server received unexpected payload"),
        };

        // Wait for a challenge request and send back the wrong response
        // type (a verify response)
        let request: Request<Auth> = server.read().await.unwrap().unwrap();
        match request.payload {
            Auth::Msg { encrypted_payload } => {
                match decrypt_and_deserialize(&mut codec, &encrypted_payload).unwrap() {
                    AuthRequest::Challenge { .. } => {
                        server
                            .write(Response::new(
                                request.id,
                                Auth::Msg {
                                    encrypted_payload: serialize_and_encrypt(
                                        &mut codec,
                                        &AuthResponse::Verify { valid: true },
                                    )
                                    .unwrap(),
                                },
                            ))
                            .await
                            .unwrap();
                    }
                    _ => panic!("Server received wrong request type"),
                }
            }
            _ => panic!("Server received unexpected payload"),
        };

        // Verify that we got an error with the method
        let result = task.await.unwrap();
        assert!(result.is_err(), "Challenge succeeded unexpectedly")
    }

    // Happy path: the questions/extra sent by the client should arrive
    // intact and the server's answers should be returned verbatim
    #[tokio::test]
    async fn challenge_should_return_answers_received_from_server() {
        let (t, mut server) = FramedTransport::make_test_pair();
        let mut client = AuthClient::from(Client::from_framed_transport(t).unwrap());

        // We start a separate task for the client to avoid blocking since
        // we also need to receive the client's request and respond
        let task = tokio::spawn(async move {
            client.handshake().await.unwrap();
            client
                .challenge(
                    vec![
                        AuthQuestion::new("question1".to_string()),
                        AuthQuestion {
                            text: "question2".to_string(),
                            extra: vec![("key2".to_string(), "value2".to_string())]
                                .into_iter()
                                .collect(),
                        },
                    ],
                    vec![("key".to_string(), "value".to_string())]
                        .into_iter()
                        .collect(),
                )
                .await
        });

        // Wait for a handshake request and set up our encryption codec
        let request: Request<Auth> = server.read().await.unwrap().unwrap();
        let mut codec = match request.payload {
            Auth::Handshake { public_key, salt } => {
                let handshake = Handshake::default();
                let key = handshake.handshake(public_key, salt).unwrap();
                server
                    .write(Response::new(
                        request.id,
                        Auth::Handshake {
                            public_key: handshake.pk_bytes(),
                            salt: *handshake.salt(),
                        },
                    ))
                    .await
                    .unwrap();
                XChaCha20Poly1305Codec::new(&key)
            }
            _ => panic!("Server received unexpected payload"),
        };

        // Wait for a challenge request, verify its contents, and send back
        // a proper challenge response with answers
        let request: Request<Auth> = server.read().await.unwrap().unwrap();
        match request.payload {
            Auth::Msg { encrypted_payload } => {
                match decrypt_and_deserialize(&mut codec, &encrypted_payload).unwrap() {
                    AuthRequest::Challenge { questions, extra } => {
                        assert_eq!(
                            questions,
                            vec![
                                AuthQuestion::new("question1".to_string()),
                                AuthQuestion {
                                    text: "question2".to_string(),
                                    extra: vec![("key2".to_string(), "value2".to_string())]
                                        .into_iter()
                                        .collect(),
                                },
                            ],
                        );

                        assert_eq!(
                            extra,
                            vec![("key".to_string(), "value".to_string())]
                                .into_iter()
                                .collect(),
                        );

                        server
                            .write(Response::new(
                                request.id,
                                Auth::Msg {
                                    encrypted_payload: serialize_and_encrypt(
                                        &mut codec,
                                        &AuthResponse::Challenge {
                                            answers: vec![
                                                "answer1".to_string(),
                                                "answer2".to_string(),
                                            ],
                                        },
                                    )
                                    .unwrap(),
                                },
                            ))
                            .await
                            .unwrap();
                    }
                    _ => panic!("Server received wrong request type"),
                }
            }
            _ => panic!("Server received unexpected payload"),
        };

        // Verify that we got the right results
        let answers = task.await.unwrap().unwrap();
        assert_eq!(answers, vec!["answer1".to_string(), "answer2".to_string()]);
    }

    // Without a handshake there is no codec, so a verify should fail
    // before anything is written to the wire
    #[tokio::test]
    async fn verify_should_fail_if_handshake_not_finished() {
        let (t, mut server) = FramedTransport::make_test_pair();
        let mut client = AuthClient::from(Client::from_framed_transport(t).unwrap());

        // We start a separate task for the client to avoid blocking since
        // we also need to receive the client's request and respond
        let task = tokio::spawn(async move {
            client
                .verify(AuthVerifyKind::Host, "some text".to_string())
                .await
        });

        // Wait for a request, failing if we get one as the failure
        // should have prevented sending anything, but we should
        // still see the transport close once the client task exits
        tokio::select! {
            x = TypedAsyncRead::<Request<Auth>>::read(&mut server) => {
                match x {
                    Ok(Some(x)) => panic!("Unexpectedly resolved: {:?}", x),
                    Ok(None) => {},
                    Err(x) => panic!("Unexpectedly failed on server side: {}", x),
                }
            },
            _ = wait_ms(TIMEOUT_MILLIS) => {
                panic!("Should have gotten server closure as part of client exit");
            }
        }

        // Verify that we got an error with the method
        let result = task.await.unwrap();
        assert!(result.is_err(), "Verify succeeded unexpectedly")
    }

    // A challenge response arriving where a verify response was expected
    // should surface as an error
    #[tokio::test]
    async fn verify_should_fail_if_receive_wrong_response() {
        let (t, mut server) = FramedTransport::make_test_pair();
        let mut client = AuthClient::from(Client::from_framed_transport(t).unwrap());

        // We start a separate task for the client to avoid blocking since
        // we also need to receive the client's request and respond
        let task = tokio::spawn(async move {
            client.handshake().await.unwrap();
            client
                .verify(AuthVerifyKind::Host, "some text".to_string())
                .await
        });

        // Wait for a handshake request and set up our encryption codec
        let request: Request<Auth> = server.read().await.unwrap().unwrap();
        let mut codec = match request.payload {
            Auth::Handshake { public_key, salt } => {
                let handshake = Handshake::default();
                let key = handshake.handshake(public_key, salt).unwrap();
                server
                    .write(Response::new(
                        request.id,
                        Auth::Handshake {
                            public_key: handshake.pk_bytes(),
                            salt: *handshake.salt(),
                        },
                    ))
                    .await
                    .unwrap();
                XChaCha20Poly1305Codec::new(&key)
            }
            _ => panic!("Server received unexpected payload"),
        };

        // Wait for a verify request and send back the wrong response type
        // (a challenge response)
        let request: Request<Auth> = server.read().await.unwrap().unwrap();
        match request.payload {
            Auth::Msg { encrypted_payload } => {
                match decrypt_and_deserialize(&mut codec, &encrypted_payload).unwrap() {
                    AuthRequest::Verify { .. } => {
                        server
                            .write(Response::new(
                                request.id,
                                Auth::Msg {
                                    encrypted_payload: serialize_and_encrypt(
                                        &mut codec,
                                        &AuthResponse::Challenge {
                                            answers: Vec::new(),
                                        },
                                    )
                                    .unwrap(),
                                },
                            ))
                            .await
                            .unwrap();
                    }
                    _ => panic!("Server received wrong request type"),
                }
            }
            _ => panic!("Server received unexpected payload"),
        };

        // Verify that we got an error with the method
        let result = task.await.unwrap();
        assert!(result.is_err(), "Verify succeeded unexpectedly")
    }

    // Happy path: the kind/text sent by the client should arrive intact and
    // the server's `valid` flag should be returned
    #[tokio::test]
    async fn verify_should_return_valid_bool_received_from_server() {
        let (t, mut server) = FramedTransport::make_test_pair();
        let mut client = AuthClient::from(Client::from_framed_transport(t).unwrap());

        // We start a separate task for the client to avoid blocking since
        // we also need to receive the client's request and respond
        let task = tokio::spawn(async move {
            client.handshake().await.unwrap();
            client
                .verify(AuthVerifyKind::Host, "some text".to_string())
                .await
        });

        // Wait for a handshake request and set up our encryption codec
        let request: Request<Auth> = server.read().await.unwrap().unwrap();
        let mut codec = match request.payload {
            Auth::Handshake { public_key, salt } => {
                let handshake = Handshake::default();
                let key = handshake.handshake(public_key, salt).unwrap();
                server
                    .write(Response::new(
                        request.id,
                        Auth::Handshake {
                            public_key: handshake.pk_bytes(),
                            salt: *handshake.salt(),
                        },
                    ))
                    .await
                    .unwrap();
                XChaCha20Poly1305Codec::new(&key)
            }
            _ => panic!("Server received unexpected payload"),
        };

        // Wait for a verify request, check its contents, and send back a
        // proper verify response
        let request: Request<Auth> = server.read().await.unwrap().unwrap();
        match request.payload {
            Auth::Msg { encrypted_payload } => {
                match decrypt_and_deserialize(&mut codec, &encrypted_payload).unwrap() {
                    AuthRequest::Verify { kind, text } => {
                        assert_eq!(kind, AuthVerifyKind::Host);
                        assert_eq!(text, "some text");

                        server
                            .write(Response::new(
                                request.id,
                                Auth::Msg {
                                    encrypted_payload: serialize_and_encrypt(
                                        &mut codec,
                                        &AuthResponse::Verify { valid: true },
                                    )
                                    .unwrap(),
                                },
                            ))
                            .await
                            .unwrap();
                    }
                    _ => panic!("Server received wrong request type"),
                }
            }
            _ => panic!("Server received unexpected payload"),
        };

        // Verify that we got the right results
        let valid = task.await.unwrap().unwrap();
        assert!(valid, "Got verify response, but valid was set incorrectly");
    }

    // Without a handshake there is no codec, so info should fail before
    // anything is written to the wire
    #[tokio::test]
    async fn info_should_fail_if_handshake_not_finished() {
        let (t, mut server) = FramedTransport::make_test_pair();
        let mut client = AuthClient::from(Client::from_framed_transport(t).unwrap());

        // We start a separate task for the client to avoid blocking since
        // we also need to receive the client's request and respond
        let task = tokio::spawn(async move { client.info("some text".to_string()).await });

        // Wait for a request, failing if we get one as the failure
        // should have prevented sending anything, but we should
        // still see the transport close once the client task exits
        tokio::select! {
            x = TypedAsyncRead::<Request<Auth>>::read(&mut server) => {
                match x {
                    Ok(Some(x)) => panic!("Unexpectedly resolved: {:?}", x),
                    Ok(None) => {},
                    Err(x) => panic!("Unexpectedly failed on server side: {}", x),
                }
            },
            _ = wait_ms(TIMEOUT_MILLIS) => {
                panic!("Should have gotten server closure as part of client exit");
            }
        }

        // Verify that we got an error with the method
        let result = task.await.unwrap();
        assert!(result.is_err(), "Info succeeded unexpectedly")
    }

    // Info is fire-and-forget: the request must arrive, and the client call
    // must complete without any response being sent back
    #[tokio::test]
    async fn info_should_send_the_server_a_request_but_not_wait_for_a_response() {
        let (t, mut server) = FramedTransport::make_test_pair();
        let mut client = AuthClient::from(Client::from_framed_transport(t).unwrap());

        // We start a separate task for the client to avoid blocking since
        // we also need to receive the client's request and respond
        let task = tokio::spawn(async move {
            client.handshake().await.unwrap();
            client.info("some text".to_string()).await
        });

        // Wait for a handshake request and set up our encryption codec
        let request: Request<Auth> = server.read().await.unwrap().unwrap();
        let mut codec = match request.payload {
            Auth::Handshake { public_key, salt } => {
                let handshake = Handshake::default();
                let key = handshake.handshake(public_key, salt).unwrap();
                server
                    .write(Response::new(
                        request.id,
                        Auth::Handshake {
                            public_key: handshake.pk_bytes(),
                            salt: *handshake.salt(),
                        },
                    ))
                    .await
                    .unwrap();
                XChaCha20Poly1305Codec::new(&key)
            }
            _ => panic!("Server received unexpected payload"),
        };

        // Wait for a request
        let request: Request<Auth> = server.read().await.unwrap().unwrap();
        match request.payload {
            Auth::Msg { encrypted_payload } => {
                match decrypt_and_deserialize(&mut codec, &encrypted_payload).unwrap() {
                    AuthRequest::Info { text } => {
                        assert_eq!(text, "some text");
                    }
                    _ => panic!("Server received wrong request type"),
                }
            }
            _ => panic!("Server received unexpected payload"),
        };

        // Verify that we got the right results
        task.await.unwrap().unwrap();
    }

    // Without a handshake there is no codec, so error should fail before
    // anything is written to the wire
    #[tokio::test]
    async fn error_should_fail_if_handshake_not_finished() {
        let (t, mut server) = FramedTransport::make_test_pair();
        let mut client = AuthClient::from(Client::from_framed_transport(t).unwrap());

        // We start a separate task for the client to avoid blocking since
        // we also need to receive the client's request and respond
        let task = tokio::spawn(async move {
            client
                .error(AuthErrorKind::FailedChallenge, "some text".to_string())
                .await
        });

        // Wait for a request, failing if we get one as the failure
        // should have prevented sending anything, but we should
        // still see the transport close once the client task exits
        tokio::select! {
            x = TypedAsyncRead::<Request<Auth>>::read(&mut server) => {
                match x {
                    Ok(Some(x)) => panic!("Unexpectedly resolved: {:?}", x),
                    Ok(None) => {},
                    Err(x) => panic!("Unexpectedly failed on server side: {}", x),
                }
            },
            _ = wait_ms(TIMEOUT_MILLIS) => {
                panic!("Should have gotten server closure as part of client exit");
            }
        }

        // Verify that we got an error with the method
        let result = task.await.unwrap();
        assert!(result.is_err(), "Error succeeded unexpectedly")
    }

    // Error is fire-and-forget: the request must arrive, and the client call
    // must complete without any response being sent back
    #[tokio::test]
    async fn error_should_send_the_server_a_request_but_not_wait_for_a_response() {
        let (t, mut server) = FramedTransport::make_test_pair();
        let mut client = AuthClient::from(Client::from_framed_transport(t).unwrap());

        // We start a separate task for the client to avoid blocking since
        // we also need to receive the client's request and respond
        let task = tokio::spawn(async move {
            client.handshake().await.unwrap();
            client
                .error(AuthErrorKind::FailedChallenge, "some text".to_string())
                .await
        });

        // Wait for a handshake request and set up our encryption codec
        let request: Request<Auth> = server.read().await.unwrap().unwrap();
        let mut codec = match request.payload {
            Auth::Handshake { public_key, salt } => {
                let handshake = Handshake::default();
                let key = handshake.handshake(public_key, salt).unwrap();
                server
                    .write(Response::new(
                        request.id,
                        Auth::Handshake {
                            public_key: handshake.pk_bytes(),
                            salt: *handshake.salt(),
                        },
                    ))
                    .await
                    .unwrap();
                XChaCha20Poly1305Codec::new(&key)
            }
            _ => panic!("Server received unexpected payload"),
        };

        // Wait for a request
        let request: Request<Auth> = server.read().await.unwrap().unwrap();
        match request.payload {
            Auth::Msg { encrypted_payload } => {
                match decrypt_and_deserialize(&mut codec, &encrypted_payload).unwrap() {
                    AuthRequest::Error { kind, text } => {
                        assert_eq!(kind, AuthErrorKind::FailedChallenge);
                        assert_eq!(text, "some text");
                    }
                    _ => panic!("Server received wrong request type"),
                }
            }
            _ => panic!("Server received unexpected payload"),
        };

        // Verify that we got the right results
        task.await.unwrap().unwrap();
    }

    // Helper to sleep for the given number of milliseconds
    async fn wait_ms(ms: u64) {
        use std::time::Duration;
        tokio::time::sleep(Duration::from_millis(ms)).await;
    }

    // Test-side mirror of AuthClient::serialize_and_encrypt so the server
    // half of each test can produce encrypted payloads
    fn serialize_and_encrypt<T: Serialize>(
        codec: &mut XChaCha20Poly1305Codec,
        payload: &T,
    ) -> io::Result<Vec<u8>> {
        let mut encryped_payload = BytesMut::new();
        let payload = utils::serialize_to_vec(payload)?;
        codec.encode(&payload, &mut encryped_payload)?;
        Ok(encryped_payload.freeze().to_vec())
    }

    // Test-side mirror of AuthClient::decrypt_and_deserialize so the server
    // half of each test can decode the client's encrypted payloads
    fn decrypt_and_deserialize<T: DeserializeOwned>(
        codec: &mut XChaCha20Poly1305Codec,
        payload: &[u8],
    ) -> io::Result<T> {
        let mut payload = BytesMut::from(payload);
        match codec.decode(&mut payload)? {
            Some(payload) => utils::deserialize_from_slice::<T>(&payload),
            None => Err(io::Error::new(
                io::ErrorKind::InvalidData,
                "Incomplete message received",
            )),
        }
    }
}

@ -0,0 +1,62 @@
use p256::{ecdh::EphemeralSecret, PublicKey};
use rand::rngs::OsRng;
use sha2::Sha256;
use std::{convert::TryFrom, io};
mod pkb;
pub use pkb::PublicKeyBytes;
mod salt;
pub use salt::Salt;
/// 32-byte key shared by handshake
pub type SharedKey = [u8; 32];

/// Utility to perform a handshake
///
/// Holds the ephemeral secret and random salt for one side of the exchange;
/// see [`Handshake::handshake`] for combining them with the other side's values
pub struct Handshake {
    // Ephemeral ECDH secret used to derive the shared secret; never exposed
    // directly, only its public key via `pk_bytes`
    secret: EphemeralSecret,
    // Random salt XOR'd with the other side's salt during key derivation
    salt: Salt,
}
impl Default for Handshake {
    /// Creates a new handshake instance backed by a freshly generated
    /// ephemeral secret and a random salt
    fn default() -> Self {
        Self {
            secret: EphemeralSecret::random(&mut OsRng),
            salt: Salt::random(),
        }
    }
}
impl Handshake {
    /// Returns the public key of this handshake's ephemeral secret, encoded as bytes
    pub fn pk_bytes(&self) -> PublicKeyBytes {
        PublicKeyBytes::from(self.secret.public_key())
    }

    /// Returns a reference to the salt generated for this handshake
    pub fn salt(&self) -> &Salt {
        &self.salt
    }

    /// Completes the handshake against the other side's `public_key` and `salt`,
    /// producing the 32-byte shared key on success
    pub fn handshake(&self, public_key: PublicKeyBytes, salt: Salt) -> io::Result<SharedKey> {
        // Decode the other side's public key so we can perform ECDH with it
        let peer_public_key = PublicKey::try_from(public_key)?;

        // XOR the two salts; XOR is commutative, so both sides compute the
        // same combined salt regardless of ordering
        let combined_salt = self.salt ^ salt;

        // Perform the Diffie-Hellman exchange and extract entropy from the
        // shared secret, seasoned with the combined salt
        let hkdf = self
            .secret
            .diffie_hellman(&peer_public_key)
            .extract::<Sha256>(Some(combined_salt.as_ref()));

        // Expand the extracted entropy into our fixed-size shared key
        let mut shared_key: SharedKey = [0u8; 32];
        hkdf.expand(&[], &mut shared_key)
            .map(|_| shared_key)
            .map_err(|x| io::Error::new(io::ErrorKind::InvalidData, x.to_string()))
    }
}

@ -0,0 +1,60 @@
use p256::{EncodedPoint, PublicKey};
use serde::{Deserialize, Deserializer, Serialize, Serializer};
use std::{convert::TryFrom, io};
/// Represents a wrapper around [`EncodedPoint`], and exists to
/// fix an issue with [`serde`] deserialization failing when
/// directly serializing the [`EncodedPoint`] type
///
/// On the wire this is represented as a plain `Vec<u8>` via the
/// `into`/`try_from` serde attributes below
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
#[serde(into = "Vec<u8>", try_from = "Vec<u8>")]
pub struct PublicKeyBytes(EncodedPoint);
impl From<PublicKey> for PublicKeyBytes {
    /// Wraps the public key's encoded point representation
    fn from(pk: PublicKey) -> Self {
        let point = EncodedPoint::from(pk);
        Self(point)
    }
}

impl TryFrom<PublicKeyBytes> for PublicKey {
    type Error = io::Error;

    /// Attempts to decode the wrapped bytes back into a public key,
    /// mapping any decoding failure to `InvalidData`
    fn try_from(pkb: PublicKeyBytes) -> Result<Self, Self::Error> {
        match PublicKey::from_sec1_bytes(pkb.0.as_ref()) {
            Ok(pk) => Ok(pk),
            Err(x) => Err(io::Error::new(io::ErrorKind::InvalidData, x)),
        }
    }
}

impl From<PublicKeyBytes> for Vec<u8> {
    /// Copies the encoded point out as an owned byte vector
    fn from(pkb: PublicKeyBytes) -> Self {
        Vec::from(pkb.0.as_bytes())
    }
}

impl TryFrom<Vec<u8>> for PublicKeyBytes {
    type Error = io::Error;

    /// Attempts to parse raw bytes as an encoded point,
    /// mapping any parse failure to `InvalidData`
    fn try_from(bytes: Vec<u8>) -> Result<Self, Self::Error> {
        let point = EncodedPoint::from_bytes(bytes)
            .map_err(|x| io::Error::new(io::ErrorKind::InvalidData, x.to_string()))?;
        Ok(Self(point))
    }
}
impl serde_bytes::Serialize for PublicKeyBytes {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
S: Serializer,
{
serializer.serialize_bytes(self.0.as_ref())
}
}
impl<'de> serde_bytes::Deserialize<'de> for PublicKeyBytes {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: Deserializer<'de>,
{
let bytes = Deserialize::deserialize(deserializer).map(serde_bytes::ByteBuf::into_vec)?;
PublicKeyBytes::try_from(bytes).map_err(serde::de::Error::custom)
}
}

@ -0,0 +1,111 @@
use rand::{rngs::OsRng, RngCore};
use serde::{Deserialize, Deserializer, Serialize, Serializer};
use std::{
convert::{TryFrom, TryInto},
fmt, io,
ops::BitXor,
str::FromStr,
};
/// Friendly wrapper around a 32-byte array representing a salt
///
/// On the wire this is represented as a plain `Vec<u8>` via the
/// `into`/`try_from` serde attributes below
#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(into = "Vec<u8>", try_from = "Vec<u8>")]
pub struct Salt([u8; 32]);
impl Salt {
    /// Generates a new salt by filling 32 bytes from the OS random number generator
    pub fn random() -> Self {
        let mut bytes = [0u8; 32];
        OsRng.fill_bytes(&mut bytes);
        Self(bytes)
    }
}
impl fmt::Display for Salt {
    /// Formats the salt as a hex-encoded string
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let encoded = hex::encode(self.0);
        write!(f, "{}", encoded)
    }
}
impl serde_bytes::Serialize for Salt {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
S: Serializer,
{
serializer.serialize_bytes(self.as_ref())
}
}
impl<'de> serde_bytes::Deserialize<'de> for Salt {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: Deserializer<'de>,
{
let bytes = Deserialize::deserialize(deserializer).map(serde_bytes::ByteBuf::into_vec)?;
let bytes_len = bytes.len();
Salt::try_from(bytes)
.map_err(|_| serde::de::Error::invalid_length(bytes_len, &"expected 32-byte length"))
}
}
impl From<Salt> for String {
    /// Converts the salt into its hex string form
    fn from(salt: Salt) -> Self {
        salt.to_string()
    }
}

impl FromStr for Salt {
    type Err = io::Error;

    /// Parses a hex string into a 32-byte salt
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        match hex::decode(s) {
            Ok(bytes) => Self::try_from(bytes),
            Err(x) => Err(io::Error::new(io::ErrorKind::InvalidData, x)),
        }
    }
}

impl TryFrom<String> for Salt {
    type Error = io::Error;

    /// Parses an owned hex string into a salt by delegating to [`FromStr`]
    fn try_from(s: String) -> Result<Self, Self::Error> {
        s.as_str().parse()
    }
}
impl TryFrom<Vec<u8>> for Salt {
    type Error = io::Error;

    /// Converts a byte vector into a salt, failing unless it is exactly 32 bytes
    fn try_from(bytes: Vec<u8>) -> Result<Self, Self::Error> {
        match <[u8; 32]>::try_from(bytes) {
            Ok(arr) => Ok(Self(arr)),
            Err(x) => Err(io::Error::new(
                io::ErrorKind::InvalidData,
                format!("Vec<u8> len of {} != 32", x.len()),
            )),
        }
    }
}

impl From<Salt> for Vec<u8> {
    /// Copies the salt's bytes into an owned vector
    fn from(salt: Salt) -> Self {
        Vec::from(salt.0)
    }
}
impl AsRef<[u8]> for Salt {
    /// Exposes the salt's bytes as a slice
    fn as_ref(&self) -> &[u8] {
        &self.0
    }
}
impl BitXor for Salt {
type Output = Self;
fn bitxor(self, rhs: Self) -> Self::Output {
let shared_salt = self
.0
.iter()
.zip(rhs.0.iter())
.map(|(x, y)| x ^ y)
.collect::<Vec<u8>>();
Self::try_from(shared_salt).unwrap()
}
}

@ -0,0 +1,653 @@
use crate::{
utils, Auth, AuthErrorKind, AuthQuestion, AuthRequest, AuthResponse, AuthVerifyKind, Codec,
Handshake, Server, ServerCtx, XChaCha20Poly1305Codec,
};
use async_trait::async_trait;
use bytes::BytesMut;
use log::*;
use std::{collections::HashMap, io};
use tokio::sync::RwLock;
/// Type signature for a dynamic on_challenge function
///
/// Takes the list of questions to ask and extra key/value data, returning
/// the collected answers
pub type AuthChallengeFn =
    dyn Fn(Vec<AuthQuestion>, HashMap<String, String>) -> Vec<String> + Send + Sync;

/// Type signature for a dynamic on_verify function
///
/// Takes the kind of verification and associated text, returning whether or
/// not it is considered valid
pub type AuthVerifyFn = dyn Fn(AuthVerifyKind, String) -> bool + Send + Sync;

/// Type signature for a dynamic on_info function
///
/// Takes informational text; produces no result
pub type AuthInfoFn = dyn Fn(String) + Send + Sync;

/// Type signature for a dynamic on_error function
///
/// Takes the kind of error and descriptive text; produces no result
pub type AuthErrorFn = dyn Fn(AuthErrorKind, String) + Send + Sync;

/// Represents an [`AuthServer`] where all handlers are stored on the heap
pub type HeapAuthServer =
    AuthServer<Box<AuthChallengeFn>, Box<AuthVerifyFn>, Box<AuthInfoFn>, Box<AuthErrorFn>>;
/// Server that handles authentication
///
/// Each handler field is invoked when the corresponding [`AuthRequest`]
/// variant arrives in an encrypted message
pub struct AuthServer<ChallengeFn, VerifyFn, InfoFn, ErrorFn>
where
    ChallengeFn: Fn(Vec<AuthQuestion>, HashMap<String, String>) -> Vec<String> + Send + Sync,
    VerifyFn: Fn(AuthVerifyKind, String) -> bool + Send + Sync,
    InfoFn: Fn(String) + Send + Sync,
    ErrorFn: Fn(AuthErrorKind, String) + Send + Sync,
{
    /// Invoked with challenge questions and extra data; returns the answers
    pub on_challenge: ChallengeFn,
    /// Invoked with a verification kind and text; returns whether it is valid
    pub on_verify: VerifyFn,
    /// Invoked with informational text received in an info request
    pub on_info: InfoFn,
    /// Invoked with the kind and text received in an error request
    pub on_error: ErrorFn,
}
#[async_trait]
impl<ChallengeFn, VerifyFn, InfoFn, ErrorFn> Server
for AuthServer<ChallengeFn, VerifyFn, InfoFn, ErrorFn>
where
ChallengeFn: Fn(Vec<AuthQuestion>, HashMap<String, String>) -> Vec<String> + Send + Sync,
VerifyFn: Fn(AuthVerifyKind, String) -> bool + Send + Sync,
InfoFn: Fn(String) + Send + Sync,
ErrorFn: Fn(AuthErrorKind, String) + Send + Sync,
{
type Request = Auth;
type Response = Auth;
type LocalData = RwLock<Option<XChaCha20Poly1305Codec>>;
async fn on_request(&self, ctx: ServerCtx<Self::Request, Self::Response, Self::LocalData>) {
let reply = ctx.reply.clone();
match ctx.request.payload {
Auth::Handshake { public_key, salt } => {
trace!(
"Received handshake request from client, request id = {}",
ctx.request.id
);
let handshake = Handshake::default();
match handshake.handshake(public_key, salt) {
Ok(key) => {
ctx.local_data
.write()
.await
.replace(XChaCha20Poly1305Codec::new(&key));
trace!(
"Sending reciprocal handshake to client, response origin id = {}",
ctx.request.id
);
if let Err(x) = reply
.send(Auth::Handshake {
public_key: handshake.pk_bytes(),
salt: *handshake.salt(),
})
.await
{
error!("[Conn {}] {}", ctx.connection_id, x);
}
}
Err(x) => {
error!("[Conn {}] {}", ctx.connection_id, x);
return;
}
}
}
Auth::Msg {
ref encrypted_payload,
} => {
trace!(
"Received auth msg, encrypted payload size = {}",
encrypted_payload.len()
);
// Attempt to decrypt the message so we can understand what to do
let request = match ctx.local_data.write().await.as_mut() {
Some(codec) => {
let mut payload = BytesMut::from(encrypted_payload.as_slice());
match codec.decode(&mut payload) {
Ok(Some(payload)) => {
utils::deserialize_from_slice::<AuthRequest>(&payload)
}
Ok(None) => Err(io::Error::new(
io::ErrorKind::InvalidData,
"Incomplete message received",
)),
Err(x) => Err(x),
}
}
None => Err(io::Error::new(
io::ErrorKind::Other,
"Handshake must be performed first (server decrypt message)",
)),
};
let response = match request {
Ok(request) => match request {
AuthRequest::Challenge { questions, extra } => {
trace!("Received challenge request");
trace!("questions = {:?}", questions);
trace!("extra = {:?}", extra);
let answers = (self.on_challenge)(questions, extra);
AuthResponse::Challenge { answers }
}
AuthRequest::Verify { kind, text } => {
trace!("Received verify request");
trace!("kind = {:?}", kind);
trace!("text = {:?}", text);
let valid = (self.on_verify)(kind, text);
AuthResponse::Verify { valid }
}
AuthRequest::Info { text } => {
trace!("Received info request");
trace!("text = {:?}", text);
(self.on_info)(text);
return;
}
AuthRequest::Error { kind, text } => {
trace!("Received error request");
trace!("kind = {:?}", kind);
trace!("text = {:?}", text);
(self.on_error)(kind, text);
return;
}
},
Err(x) => {
error!("[Conn {}] {}", ctx.connection_id, x);
return;
}
};
// Serialize and encrypt the message before sending it back
let encrypted_payload = match ctx.local_data.write().await.as_mut() {
Some(codec) => {
let mut encrypted_payload = BytesMut::new();
// Convert the response into bytes for us to send back
match utils::serialize_to_vec(&response) {
Ok(bytes) => match codec.encode(&bytes, &mut encrypted_payload) {
Ok(_) => Ok(encrypted_payload.freeze().to_vec()),
Err(x) => Err(x),
},
Err(x) => Err(x),
}
}
None => Err(io::Error::new(
io::ErrorKind::Other,
"Handshake must be performed first (server encrypt messaage)",
)),
};
match encrypted_payload {
Ok(encrypted_payload) => {
if let Err(x) = reply.send(Auth::Msg { encrypted_payload }).await {
error!("[Conn {}] {}", ctx.connection_id, x);
return;
}
}
Err(x) => {
error!("[Conn {}] {}", ctx.connection_id, x);
return;
}
}
}
}
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        IntoSplit, MpscListener, MpscTransport, Request, Response, ServerExt, ServerRef,
        TypedAsyncRead, TypedAsyncWrite,
    };
    use tokio::sync::mpsc;

    // How long tests wait for a response that may never arrive before deciding
    // that none will come
    const TIMEOUT_MILLIS: u64 = 100;

    #[tokio::test]
    async fn should_not_reply_if_receive_encrypted_msg_without_handshake_first() {
        let (mut t, _) = spawn_auth_server(
            /* on_challenge */ |_, _| Vec::new(),
            /* on_verify */ |_, _| false,
            /* on_info */ |_| {},
            /* on_error */ |_, _| {},
        )
        .await
        .expect("Failed to spawn server");
        // Send an encrypted message before establishing a handshake
        t.write(Request::new(Auth::Msg {
            encrypted_payload: Vec::new(),
        }))
        .await
        .expect("Failed to send request to server");
        // Wait for a response, failing if we get one
        tokio::select! {
            x = t.read() => panic!("Unexpectedly resolved: {:?}", x),
            _ = wait_ms(TIMEOUT_MILLIS) => {}
        }
    }

    #[tokio::test]
    async fn should_reply_to_handshake_request_with_new_public_key_and_salt() {
        let (mut t, _) = spawn_auth_server(
            /* on_challenge */ |_, _| Vec::new(),
            /* on_verify */ |_, _| false,
            /* on_info */ |_| {},
            /* on_error */ |_, _| {},
        )
        .await
        .expect("Failed to spawn server");
        // Send a handshake
        let handshake = Handshake::default();
        t.write(Request::new(Auth::Handshake {
            public_key: handshake.pk_bytes(),
            salt: *handshake.salt(),
        }))
        .await
        .expect("Failed to send request to server");
        // Wait for a handshake response
        tokio::select! {
            x = t.read() => {
                let response = x.expect("Request failed").expect("Response missing");
                match response.payload {
                    Auth::Handshake { .. } => {},
                    Auth::Msg { .. } => panic!("Received unexpected encryped message during handshake"),
                }
            }
            _ = wait_ms(TIMEOUT_MILLIS) => panic!("Ran out of time waiting on response"),
        }
    }

    #[tokio::test]
    async fn should_not_reply_if_receive_invalid_encrypted_msg() {
        let (mut t, _) = spawn_auth_server(
            /* on_challenge */ |_, _| Vec::new(),
            /* on_verify */ |_, _| false,
            /* on_info */ |_| {},
            /* on_error */ |_, _| {},
        )
        .await
        .expect("Failed to spawn server");
        // Send a handshake
        let handshake = Handshake::default();
        t.write(Request::new(Auth::Handshake {
            public_key: handshake.pk_bytes(),
            salt: *handshake.salt(),
        }))
        .await
        .expect("Failed to send request to server");
        // Complete handshake
        let key = match t.read().await.unwrap().unwrap().payload {
            Auth::Handshake { public_key, salt } => handshake.handshake(public_key, salt).unwrap(),
            Auth::Msg { .. } => panic!("Received unexpected encryped message during handshake"),
        };
        // Send a bad chunk of data (bytes that were never produced by the codec)
        let _codec = XChaCha20Poly1305Codec::new(&key);
        t.write(Request::new(Auth::Msg {
            encrypted_payload: vec![1, 2, 3, 4],
        }))
        .await
        .unwrap();
        // Wait for a response, failing if we get one
        tokio::select! {
            x = t.read() => panic!("Unexpectedly resolved: {:?}", x),
            _ = wait_ms(TIMEOUT_MILLIS) => {}
        }
    }

    #[tokio::test]
    async fn should_invoke_appropriate_function_when_receive_challenge_request_and_reply() {
        let (tx, mut rx) = mpsc::channel(1);
        let (mut t, _) = spawn_auth_server(
            /* on_challenge */
            move |questions, extra| {
                tx.try_send((questions, extra)).unwrap();
                vec!["answer1".to_string(), "answer2".to_string()]
            },
            /* on_verify */ |_, _| false,
            /* on_info */ |_| {},
            /* on_error */ |_, _| {},
        )
        .await
        .expect("Failed to spawn server");
        // Send a handshake
        let handshake = Handshake::default();
        t.write(Request::new(Auth::Handshake {
            public_key: handshake.pk_bytes(),
            salt: *handshake.salt(),
        }))
        .await
        .expect("Failed to send request to server");
        // Complete handshake
        let key = match t.read().await.unwrap().unwrap().payload {
            Auth::Handshake { public_key, salt } => handshake.handshake(public_key, salt).unwrap(),
            Auth::Msg { .. } => panic!("Received unexpected encryped message during handshake"),
        };
        // Send a challenge request (comment fixed; was mislabeled "error request")
        let mut codec = XChaCha20Poly1305Codec::new(&key);
        t.write(Request::new(Auth::Msg {
            encrypted_payload: serialize_and_encrypt(
                &mut codec,
                &AuthRequest::Challenge {
                    questions: vec![
                        AuthQuestion::new("question1".to_string()),
                        AuthQuestion {
                            text: "question2".to_string(),
                            extra: vec![("key".to_string(), "value".to_string())]
                                .into_iter()
                                .collect(),
                        },
                    ],
                    extra: vec![("hello".to_string(), "world".to_string())]
                        .into_iter()
                        .collect(),
                },
            )
            .unwrap(),
        }))
        .await
        .unwrap();
        // Verify that the handler was triggered
        let (questions, extra) = rx.recv().await.expect("Channel closed unexpectedly");
        assert_eq!(
            questions,
            vec![
                AuthQuestion::new("question1".to_string()),
                AuthQuestion {
                    text: "question2".to_string(),
                    extra: vec![("key".to_string(), "value".to_string())]
                        .into_iter()
                        .collect(),
                }
            ]
        );
        assert_eq!(
            extra,
            vec![("hello".to_string(), "world".to_string())]
                .into_iter()
                .collect()
        );
        // Wait for a response and verify that it matches what we expect
        // NOTE(review): the timeout arm silently passes without asserting a
        // response arrived; consider panicking there instead
        tokio::select! {
            x = t.read() => {
                let response = x.expect("Request failed").expect("Response missing");
                match response.payload {
                    Auth::Handshake { .. } => panic!("Received unexpected handshake"),
                    Auth::Msg { encrypted_payload } => {
                        match decrypt_and_deserialize(&mut codec, &encrypted_payload).unwrap() {
                            AuthResponse::Challenge { answers } =>
                                assert_eq!(
                                    answers,
                                    vec!["answer1".to_string(), "answer2".to_string()]
                                ),
                            _ => panic!("Got wrong response for verify"),
                        }
                    },
                }
            }
            _ = wait_ms(TIMEOUT_MILLIS) => {}
        }
    }

    #[tokio::test]
    async fn should_invoke_appropriate_function_when_receive_verify_request_and_reply() {
        let (tx, mut rx) = mpsc::channel(1);
        let (mut t, _) = spawn_auth_server(
            /* on_challenge */ |_, _| Vec::new(),
            /* on_verify */
            move |kind, text| {
                tx.try_send((kind, text)).unwrap();
                true
            },
            /* on_info */ |_| {},
            /* on_error */ |_, _| {},
        )
        .await
        .expect("Failed to spawn server");
        // Send a handshake
        let handshake = Handshake::default();
        t.write(Request::new(Auth::Handshake {
            public_key: handshake.pk_bytes(),
            salt: *handshake.salt(),
        }))
        .await
        .expect("Failed to send request to server");
        // Complete handshake
        let key = match t.read().await.unwrap().unwrap().payload {
            Auth::Handshake { public_key, salt } => handshake.handshake(public_key, salt).unwrap(),
            Auth::Msg { .. } => panic!("Received unexpected encryped message during handshake"),
        };
        // Send a verify request (comment fixed; was mislabeled "error request")
        let mut codec = XChaCha20Poly1305Codec::new(&key);
        t.write(Request::new(Auth::Msg {
            encrypted_payload: serialize_and_encrypt(
                &mut codec,
                &AuthRequest::Verify {
                    kind: AuthVerifyKind::Host,
                    text: "some text".to_string(),
                },
            )
            .unwrap(),
        }))
        .await
        .unwrap();
        // Verify that the handler was triggered
        let (kind, text) = rx.recv().await.expect("Channel closed unexpectedly");
        assert_eq!(kind, AuthVerifyKind::Host);
        assert_eq!(text, "some text");
        // Wait for a response and verify that it matches what we expect
        // NOTE(review): the timeout arm silently passes without asserting a
        // response arrived; consider panicking there instead
        tokio::select! {
            x = t.read() => {
                let response = x.expect("Request failed").expect("Response missing");
                match response.payload {
                    Auth::Handshake { .. } => panic!("Received unexpected handshake"),
                    Auth::Msg { encrypted_payload } => {
                        match decrypt_and_deserialize(&mut codec, &encrypted_payload).unwrap() {
                            AuthResponse::Verify { valid } =>
                                assert!(valid, "Got verify, but valid was wrong"),
                            _ => panic!("Got wrong response for verify"),
                        }
                    },
                }
            }
            _ = wait_ms(TIMEOUT_MILLIS) => {}
        }
    }

    #[tokio::test]
    async fn should_invoke_appropriate_function_when_receive_info_request() {
        let (tx, mut rx) = mpsc::channel(1);
        let (mut t, _) = spawn_auth_server(
            /* on_challenge */ |_, _| Vec::new(),
            /* on_verify */ |_, _| false,
            /* on_info */
            move |text| {
                tx.try_send(text).unwrap();
            },
            /* on_error */ |_, _| {},
        )
        .await
        .expect("Failed to spawn server");
        // Send a handshake
        let handshake = Handshake::default();
        t.write(Request::new(Auth::Handshake {
            public_key: handshake.pk_bytes(),
            salt: *handshake.salt(),
        }))
        .await
        .expect("Failed to send request to server");
        // Complete handshake
        let key = match t.read().await.unwrap().unwrap().payload {
            Auth::Handshake { public_key, salt } => handshake.handshake(public_key, salt).unwrap(),
            Auth::Msg { .. } => panic!("Received unexpected encryped message during handshake"),
        };
        // Send an info request (comment fixed; was mislabeled "error request")
        let mut codec = XChaCha20Poly1305Codec::new(&key);
        t.write(Request::new(Auth::Msg {
            encrypted_payload: serialize_and_encrypt(
                &mut codec,
                &AuthRequest::Info {
                    text: "some text".to_string(),
                },
            )
            .unwrap(),
        }))
        .await
        .unwrap();
        // Verify that the handler was triggered
        let text = rx.recv().await.expect("Channel closed unexpectedly");
        assert_eq!(text, "some text");
        // Wait for a response, failing if we get one (info is fire-and-forget)
        tokio::select! {
            x = t.read() => panic!("Unexpectedly resolved: {:?}", x),
            _ = wait_ms(TIMEOUT_MILLIS) => {}
        }
    }

    #[tokio::test]
    async fn should_invoke_appropriate_function_when_receive_error_request() {
        let (tx, mut rx) = mpsc::channel(1);
        let (mut t, _) = spawn_auth_server(
            /* on_challenge */ |_, _| Vec::new(),
            /* on_verify */ |_, _| false,
            /* on_info */ |_| {},
            /* on_error */
            move |kind, text| {
                tx.try_send((kind, text)).unwrap();
            },
        )
        .await
        .expect("Failed to spawn server");
        // Send a handshake
        let handshake = Handshake::default();
        t.write(Request::new(Auth::Handshake {
            public_key: handshake.pk_bytes(),
            salt: *handshake.salt(),
        }))
        .await
        .expect("Failed to send request to server");
        // Complete handshake
        let key = match t.read().await.unwrap().unwrap().payload {
            Auth::Handshake { public_key, salt } => handshake.handshake(public_key, salt).unwrap(),
            Auth::Msg { .. } => panic!("Received unexpected encryped message during handshake"),
        };
        // Send an error request
        let mut codec = XChaCha20Poly1305Codec::new(&key);
        t.write(Request::new(Auth::Msg {
            encrypted_payload: serialize_and_encrypt(
                &mut codec,
                &AuthRequest::Error {
                    kind: AuthErrorKind::FailedChallenge,
                    text: "some text".to_string(),
                },
            )
            .unwrap(),
        }))
        .await
        .unwrap();
        // Verify that the handler was triggered
        let (kind, text) = rx.recv().await.expect("Channel closed unexpectedly");
        assert_eq!(kind, AuthErrorKind::FailedChallenge);
        assert_eq!(text, "some text");
        // Wait for a response, failing if we get one (errors are fire-and-forget)
        tokio::select! {
            x = t.read() => panic!("Unexpectedly resolved: {:?}", x),
            _ = wait_ms(TIMEOUT_MILLIS) => {}
        }
    }

    // Sleeps for the given number of milliseconds
    async fn wait_ms(ms: u64) {
        use std::time::Duration;
        tokio::time::sleep(Duration::from_millis(ms)).await;
    }

    // Serializes an AuthRequest and encrypts the bytes with the given codec
    fn serialize_and_encrypt(
        codec: &mut XChaCha20Poly1305Codec,
        payload: &AuthRequest,
    ) -> io::Result<Vec<u8>> {
        let mut encryped_payload = BytesMut::new();
        let payload = utils::serialize_to_vec(payload)?;
        codec.encode(&payload, &mut encryped_payload)?;
        Ok(encryped_payload.freeze().to_vec())
    }

    // Decrypts bytes with the given codec and deserializes them as an AuthResponse
    fn decrypt_and_deserialize(
        codec: &mut XChaCha20Poly1305Codec,
        payload: &[u8],
    ) -> io::Result<AuthResponse> {
        let mut payload = BytesMut::from(payload);
        match codec.decode(&mut payload)? {
            Some(payload) => utils::deserialize_from_slice::<AuthResponse>(&payload),
            None => Err(io::Error::new(
                io::ErrorKind::InvalidData,
                "Incomplete message received",
            )),
        }
    }

    // Spawns an AuthServer with the given handlers, returning a transport
    // connected to it and the server reference that keeps it alive
    async fn spawn_auth_server<ChallengeFn, VerifyFn, InfoFn, ErrorFn>(
        on_challenge: ChallengeFn,
        on_verify: VerifyFn,
        on_info: InfoFn,
        on_error: ErrorFn,
    ) -> io::Result<(
        MpscTransport<Request<Auth>, Response<Auth>>,
        Box<dyn ServerRef>,
    )>
    where
        ChallengeFn:
            Fn(Vec<AuthQuestion>, HashMap<String, String>) -> Vec<String> + Send + Sync + 'static,
        VerifyFn: Fn(AuthVerifyKind, String) -> bool + Send + Sync + 'static,
        InfoFn: Fn(String) + Send + Sync + 'static,
        ErrorFn: Fn(AuthErrorKind, String) + Send + Sync + 'static,
    {
        let server = AuthServer {
            on_challenge,
            on_verify,
            on_info,
            on_error,
        };
        // Create a test listener where we will forward a connection
        let (tx, listener) = MpscListener::channel(100);
        // Make bounded transport pair and send off one of them to act as our connection
        let (transport, connection) = MpscTransport::<Request<Auth>, Response<Auth>>::pair(100);
        tx.send(connection.into_split())
            .await
            .expect("Failed to feed listener a connection");
        let server = server.start(listener)?;
        Ok((transport, server))
    }
}

Some files were not shown because too many files have changed in this diff Show More

Loading…
Cancel
Save