mirror of https://github.com/chipsenkbeil/distant
Refactor to use distant manager (#112)
parent
a2e17ba35b
commit
ea2e128bc4
@ -0,0 +1,4 @@
|
||||
[profile.ci]
|
||||
fail-fast = false
|
||||
retries = 2
|
||||
slow-timeout = { period = "60s", terminate-after = 3 }
|
@ -1,49 +0,0 @@
|
||||
name: CI (All)
|
||||
|
||||
on:
|
||||
push:
|
||||
branches:
|
||||
- master
|
||||
pull_request:
|
||||
branches:
|
||||
- master
|
||||
|
||||
jobs:
|
||||
clippy:
|
||||
name: Lint with clippy
|
||||
runs-on: ubuntu-latest
|
||||
env:
|
||||
RUSTFLAGS: -Dwarnings
|
||||
steps:
|
||||
- uses: actions/checkout@v2
|
||||
- name: Install Rust (clippy)
|
||||
uses: actions-rs/toolchain@v1
|
||||
with:
|
||||
profile: minimal
|
||||
toolchain: stable
|
||||
components: clippy
|
||||
- uses: Swatinem/rust-cache@v1
|
||||
- name: Check Cargo availability
|
||||
run: cargo --version
|
||||
- name: distant-core (all features)
|
||||
run: cargo clippy -p distant-core --all-targets --verbose --all-features
|
||||
- name: distant-ssh2 (all features)
|
||||
run: cargo clippy -p distant-ssh2 --all-targets --verbose --all-features
|
||||
- name: distant (all features)
|
||||
run: cargo clippy --all-targets --verbose --all-features
|
||||
|
||||
rustfmt:
|
||||
name: Verify code formatting
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- uses: actions/checkout@v2
|
||||
- name: Install Rust (rustfmt)
|
||||
uses: actions-rs/toolchain@v1
|
||||
with:
|
||||
profile: minimal
|
||||
toolchain: stable
|
||||
components: rustfmt
|
||||
- uses: Swatinem/rust-cache@v1
|
||||
- name: Check Cargo availability
|
||||
run: cargo --version
|
||||
- run: cargo fmt --all -- --check
|
@ -1,46 +0,0 @@
|
||||
name: CI (Linux)
|
||||
|
||||
on:
|
||||
push:
|
||||
branches:
|
||||
- master
|
||||
pull_request:
|
||||
branches:
|
||||
- master
|
||||
|
||||
jobs:
|
||||
tests:
|
||||
name: "Test Rust ${{ matrix.rust }} on ${{ matrix.os }}"
|
||||
runs-on: ${{ matrix.os }}
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
include:
|
||||
- { rust: stable, os: ubuntu-latest }
|
||||
- { rust: 1.51.0, os: ubuntu-latest }
|
||||
steps:
|
||||
- uses: actions/checkout@v2
|
||||
- name: Install Rust ${{ matrix.rust }}
|
||||
uses: actions-rs/toolchain@v1
|
||||
with:
|
||||
profile: minimal
|
||||
toolchain: ${{ matrix.rust }}
|
||||
- uses: Swatinem/rust-cache@v1
|
||||
- name: Check Cargo availability
|
||||
run: cargo --version
|
||||
- name: Run core tests (default features)
|
||||
run: cargo test --release --verbose -p distant-core
|
||||
- name: Run core tests (all features)
|
||||
run: cargo test --release --verbose --all-features -p distant-core
|
||||
- name: Ensure /run/sshd exists on Unix
|
||||
run: mkdir -p /run/sshd
|
||||
- name: Run ssh2 tests (default features)
|
||||
run: cargo test --release --verbose -p distant-ssh2
|
||||
- name: Run ssh2 tests (all features)
|
||||
run: cargo test --release --verbose --all-features -p distant-ssh2
|
||||
- name: Run CLI tests
|
||||
run: cargo test --release --verbose
|
||||
shell: bash
|
||||
- name: Run CLI tests (no default features)
|
||||
run: cargo test --release --verbose --no-default-features
|
||||
shell: bash
|
@ -1,43 +0,0 @@
|
||||
name: CI (MacOS)
|
||||
|
||||
on:
|
||||
push:
|
||||
branches:
|
||||
- master
|
||||
pull_request:
|
||||
branches:
|
||||
- master
|
||||
|
||||
jobs:
|
||||
tests:
|
||||
name: "Test Rust ${{ matrix.rust }} on ${{ matrix.os }}"
|
||||
runs-on: ${{ matrix.os }}
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
include:
|
||||
- { rust: stable, os: macos-latest }
|
||||
steps:
|
||||
- uses: actions/checkout@v2
|
||||
- name: Install Rust ${{ matrix.rust }}
|
||||
uses: actions-rs/toolchain@v1
|
||||
with:
|
||||
profile: minimal
|
||||
toolchain: ${{ matrix.rust }}
|
||||
- uses: Swatinem/rust-cache@v1
|
||||
- name: Check Cargo availability
|
||||
run: cargo --version
|
||||
- name: Run core tests (default features)
|
||||
run: cargo test --release --verbose -p distant-core
|
||||
- name: Run core tests (all features)
|
||||
run: cargo test --release --verbose --all-features -p distant-core
|
||||
- name: Run ssh2 tests (default features)
|
||||
run: cargo test --release --verbose -p distant-ssh2
|
||||
- name: Run ssh2 tests (all features)
|
||||
run: cargo test --release --verbose --all-features -p distant-ssh2
|
||||
- name: Run CLI tests
|
||||
run: cargo test --release --verbose
|
||||
shell: bash
|
||||
- name: Run CLI tests (no default features)
|
||||
run: cargo test --release --verbose --no-default-features
|
||||
shell: bash
|
@ -1,45 +0,0 @@
|
||||
name: CI (Windows)
|
||||
|
||||
on:
|
||||
push:
|
||||
branches:
|
||||
- master
|
||||
pull_request:
|
||||
branches:
|
||||
- master
|
||||
|
||||
jobs:
|
||||
tests:
|
||||
name: "Test Rust ${{ matrix.rust }} on ${{ matrix.os }}"
|
||||
runs-on: ${{ matrix.os }}
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
include:
|
||||
- { rust: stable, os: windows-latest, target: x86_64-pc-windows-msvc }
|
||||
steps:
|
||||
- uses: actions/checkout@v2
|
||||
- name: Install Rust ${{ matrix.rust }}
|
||||
uses: actions-rs/toolchain@v1
|
||||
with:
|
||||
profile: minimal
|
||||
toolchain: ${{ matrix.rust }}
|
||||
target: ${{ matrix.target }}
|
||||
- uses: Swatinem/rust-cache@v1
|
||||
- name: Check Cargo availability
|
||||
run: cargo --version
|
||||
- uses: Vampire/setup-wsl@v1
|
||||
- name: Run distant-core tests (default features)
|
||||
run: cargo test --release --verbose -p distant-core
|
||||
- name: Run distant-core tests (all features)
|
||||
run: cargo test --release --verbose --all-features -p distant-core
|
||||
- name: Build distant-ssh2 (default features)
|
||||
run: cargo build --release --verbose -p distant-ssh2
|
||||
- name: Build distant-ssh2 (all features)
|
||||
run: cargo build --release --verbose --all-features -p distant-ssh2
|
||||
- name: Build CLI
|
||||
run: cargo build --release --verbose
|
||||
shell: bash
|
||||
- name: Build CLI (no default features)
|
||||
run: cargo build --release --verbose --no-default-features
|
||||
shell: bash
|
@ -0,0 +1,191 @@
|
||||
name: CI
|
||||
|
||||
on:
|
||||
push:
|
||||
branches:
|
||||
- master
|
||||
pull_request:
|
||||
branches:
|
||||
- master
|
||||
|
||||
jobs:
|
||||
clippy:
|
||||
name: "Lint with clippy (${{ matrix.os }})"
|
||||
runs-on: ${{ matrix.os }}
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
include:
|
||||
- { os: windows-latest }
|
||||
- { os: ubuntu-latest }
|
||||
env:
|
||||
RUSTFLAGS: -Dwarnings
|
||||
steps:
|
||||
- name: Ensure windows git checkout keeps \n line ending
|
||||
run: |
|
||||
git config --system core.autocrlf false
|
||||
git config --system core.eol lf
|
||||
if: matrix.os == 'windows-latest'
|
||||
- uses: actions/checkout@v2
|
||||
- name: Install Rust (clippy)
|
||||
uses: actions-rs/toolchain@v1
|
||||
with:
|
||||
profile: minimal
|
||||
toolchain: stable
|
||||
components: clippy
|
||||
- uses: Swatinem/rust-cache@v1
|
||||
- name: Check Cargo availability
|
||||
run: cargo --version
|
||||
- name: distant-core (all features)
|
||||
run: cargo clippy -p distant-core --all-targets --verbose --all-features
|
||||
- name: distant-ssh2 (all features)
|
||||
run: cargo clippy -p distant-ssh2 --all-targets --verbose --all-features
|
||||
- name: distant (all features)
|
||||
run: cargo clippy --all-targets --verbose --all-features
|
||||
rustfmt:
|
||||
name: "Verify code formatting (${{ matrix.os }})"
|
||||
runs-on: ${{ matrix.os }}
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
include:
|
||||
- { os: windows-latest }
|
||||
- { os: ubuntu-latest }
|
||||
steps:
|
||||
- name: Ensure windows git checkout keeps \n line ending
|
||||
run: |
|
||||
git config --system core.autocrlf false
|
||||
git config --system core.eol lf
|
||||
if: matrix.os == 'windows-latest'
|
||||
- uses: actions/checkout@v2
|
||||
- name: Install Rust (rustfmt)
|
||||
uses: actions-rs/toolchain@v1
|
||||
with:
|
||||
profile: minimal
|
||||
toolchain: stable
|
||||
components: rustfmt
|
||||
- uses: Swatinem/rust-cache@v1
|
||||
- name: Check Cargo availability
|
||||
run: cargo --version
|
||||
- run: cargo fmt --all -- --check
|
||||
tests:
|
||||
name: "Test Rust ${{ matrix.rust }} on ${{ matrix.os }}"
|
||||
runs-on: ${{ matrix.os }}
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
include:
|
||||
- { rust: stable, os: windows-latest, target: x86_64-pc-windows-msvc }
|
||||
- { rust: stable, os: macos-latest }
|
||||
- { rust: stable, os: ubuntu-latest }
|
||||
- { rust: 1.61.0, os: ubuntu-latest }
|
||||
steps:
|
||||
- uses: actions/checkout@v2
|
||||
- name: Install Rust ${{ matrix.rust }}
|
||||
uses: actions-rs/toolchain@v1
|
||||
with:
|
||||
profile: minimal
|
||||
toolchain: ${{ matrix.rust }}
|
||||
target: ${{ matrix.target }}
|
||||
- uses: taiki-e/install-action@v1
|
||||
with:
|
||||
tool: cargo-nextest
|
||||
- uses: Swatinem/rust-cache@v1
|
||||
- name: Check Cargo availability
|
||||
run: cargo --version
|
||||
- name: Install OpenSSH on Windows
|
||||
run: |
|
||||
# From https://gist.github.com/inevity/a0d7b9f1c5ba5a813917b92736122797
|
||||
Add-Type -AssemblyName System.IO.Compression.FileSystem
|
||||
function Unzip
|
||||
{
|
||||
param([string]$zipfile, [string]$outpath)
|
||||
|
||||
[System.IO.Compression.ZipFile]::ExtractToDirectory($zipfile, $outpath)
|
||||
}
|
||||
|
||||
$url = 'https://github.com/PowerShell/Win32-OpenSSH/releases/latest/'
|
||||
$request = [System.Net.WebRequest]::Create($url)
|
||||
$request.AllowAutoRedirect=$false
|
||||
$response=$request.GetResponse()
|
||||
$file = $([String]$response.GetResponseHeader("Location")).Replace('tag','download') + '/OpenSSH-Win64.zip'
|
||||
|
||||
$client = new-object system.Net.Webclient;
|
||||
$client.DownloadFile($file ,"c:\\OpenSSH-Win64.zip")
|
||||
|
||||
Unzip "c:\\OpenSSH-Win64.zip" "C:\Program Files\"
|
||||
mv "c:\\Program Files\OpenSSH-Win64" "C:\Program Files\OpenSSH\"
|
||||
|
||||
powershell.exe -ExecutionPolicy Bypass -File "C:\Program Files\OpenSSH\install-sshd.ps1"
|
||||
|
||||
New-NetFirewallRule -Name sshd -DisplayName 'OpenSSH Server (sshd)' -Enabled True -Direction Inbound -Protocol TCP -Action Allow -LocalPort 22,49152-65535
|
||||
|
||||
net start sshd
|
||||
|
||||
Set-Service sshd -StartupType Automatic
|
||||
Set-Service ssh-agent -StartupType Automatic
|
||||
|
||||
cd "C:\Program Files\OpenSSH\"
|
||||
Powershell.exe -ExecutionPolicy Bypass -Command '. .\FixHostFilePermissions.ps1 -Confirm:$false'
|
||||
|
||||
$registryPath = "HKLM:\SOFTWARE\OpenSSH\"
|
||||
$Name = "DefaultShell"
|
||||
$value = "C:\windows\System32\WindowsPowerShell\v1.0\powershell.exe"
|
||||
|
||||
IF(!(Test-Path $registryPath))
|
||||
{
|
||||
New-Item -Path $registryPath -Force
|
||||
New-ItemProperty -Path $registryPath -Name $name -Value $value -PropertyType String -Force
|
||||
} ELSE {
|
||||
New-ItemProperty -Path $registryPath -Name $name -Value $value -PropertyType String -Force
|
||||
}
|
||||
shell: pwsh
|
||||
if: matrix.os == 'windows-latest'
|
||||
- name: Run net tests (default features)
|
||||
run: cargo nextest run --profile ci --release --verbose -p distant-net
|
||||
- name: Run core tests (default features)
|
||||
run: cargo nextest run --profile ci --release --verbose -p distant-core
|
||||
- name: Run core tests (all features)
|
||||
run: cargo nextest run --profile ci --release --verbose --all-features -p distant-core
|
||||
- name: Ensure /run/sshd exists on Unix
|
||||
run: mkdir -p /run/sshd
|
||||
if: matrix.os == 'ubuntu-latest'
|
||||
- name: Run ssh2 client tests (default features)
|
||||
run: cargo nextest run --profile ci --release --verbose -p distant-ssh2 ssh2::client
|
||||
- name: Run ssh2 client tests (all features)
|
||||
run: cargo nextest run --profile ci --release --verbose --all-features -p distant-ssh2 ssh2::client
|
||||
- name: Run CLI tests
|
||||
run: cargo nextest run --profile ci --release --verbose
|
||||
- name: Run CLI tests (no default features)
|
||||
run: cargo nextest run --profile ci --release --verbose --no-default-features
|
||||
ssh-launch-tests:
|
||||
name: "Test ssh launch using Rust ${{ matrix.rust }} on ${{ matrix.os }}"
|
||||
runs-on: ${{ matrix.os }}
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
include:
|
||||
- { rust: stable, os: macos-latest }
|
||||
- { rust: stable, os: ubuntu-latest }
|
||||
- { rust: 1.61.0, os: ubuntu-latest }
|
||||
steps:
|
||||
- uses: actions/checkout@v2
|
||||
- name: Install Rust ${{ matrix.rust }}
|
||||
uses: actions-rs/toolchain@v1
|
||||
with:
|
||||
profile: minimal
|
||||
toolchain: ${{ matrix.rust }}
|
||||
- uses: taiki-e/install-action@v1
|
||||
with:
|
||||
tool: cargo-nextest
|
||||
- uses: Swatinem/rust-cache@v1
|
||||
- name: Check Cargo availability
|
||||
run: cargo --version
|
||||
- name: Install distant cli for use in launch tests
|
||||
run: |
|
||||
cargo install --path .
|
||||
echo "DISTANT_PATH=$HOME/.cargo/bin/distant" >> $GITHUB_ENV
|
||||
- name: Run ssh2 launch tests (default features)
|
||||
run: cargo nextest run --profile ci --release --verbose -p distant-ssh2 ssh2::launched
|
||||
- name: Run ssh2 launch tests (all features)
|
||||
run: cargo nextest run --profile ci --release --verbose --all-features -p distant-ssh2 ssh2::launched
|
File diff suppressed because it is too large
Load Diff
@ -0,0 +1,626 @@
|
||||
use crate::{
|
||||
data::{ChangeKind, DirEntry, Environment, Error, Metadata, ProcessId, PtySize, SystemInfo},
|
||||
ConnectionId, DistantMsg, DistantRequestData, DistantResponseData,
|
||||
};
|
||||
use async_trait::async_trait;
|
||||
use distant_net::{Reply, Server, ServerCtx};
|
||||
use log::*;
|
||||
use std::{io, path::PathBuf, sync::Arc};
|
||||
|
||||
mod local;
|
||||
pub use local::LocalDistantApi;
|
||||
|
||||
mod reply;
|
||||
use reply::DistantSingleReply;
|
||||
|
||||
/// Represents the context provided to the [`DistantApi`] for incoming requests
|
||||
pub struct DistantCtx<T> {
|
||||
pub connection_id: ConnectionId,
|
||||
pub reply: Box<dyn Reply<Data = DistantResponseData>>,
|
||||
pub local_data: Arc<T>,
|
||||
}
|
||||
|
||||
/// Represents a server that leverages an API compliant with `distant`
|
||||
pub struct DistantApiServer<T, D>
|
||||
where
|
||||
T: DistantApi<LocalData = D>,
|
||||
{
|
||||
api: T,
|
||||
}
|
||||
|
||||
impl<T, D> DistantApiServer<T, D>
|
||||
where
|
||||
T: DistantApi<LocalData = D>,
|
||||
{
|
||||
pub fn new(api: T) -> Self {
|
||||
Self { api }
|
||||
}
|
||||
}
|
||||
|
||||
impl DistantApiServer<LocalDistantApi, <LocalDistantApi as DistantApi>::LocalData> {
|
||||
/// Creates a new server using the [`LocalDistantApi`] implementation
|
||||
pub fn local() -> io::Result<Self> {
|
||||
Ok(Self {
|
||||
api: LocalDistantApi::initialize()?,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn unsupported<T>(label: &str) -> io::Result<T> {
|
||||
Err(io::Error::new(
|
||||
io::ErrorKind::Unsupported,
|
||||
format!("{} is unsupported", label),
|
||||
))
|
||||
}
|
||||
|
||||
/// Interface to support the suite of functionality available with distant,
|
||||
/// which can be used to build other servers that are compatible with distant
|
||||
#[async_trait]
|
||||
pub trait DistantApi {
|
||||
type LocalData: Send + Sync;
|
||||
|
||||
/// Invoked whenever a new connection is established, providing a mutable reference to the
|
||||
/// newly-created local data. This is a way to support modifying local data before it is used.
|
||||
#[allow(unused_variables)]
|
||||
async fn on_accept(&self, local_data: &mut Self::LocalData) {}
|
||||
|
||||
/// Reads bytes from a file.
|
||||
///
|
||||
/// * `path` - the path to the file
|
||||
///
|
||||
/// *Override this, otherwise it will return "unsupported" as an error.*
|
||||
#[allow(unused_variables)]
|
||||
async fn read_file(
|
||||
&self,
|
||||
ctx: DistantCtx<Self::LocalData>,
|
||||
path: PathBuf,
|
||||
) -> io::Result<Vec<u8>> {
|
||||
unsupported("read_file")
|
||||
}
|
||||
|
||||
/// Reads bytes from a file as text.
|
||||
///
|
||||
/// * `path` - the path to the file
|
||||
///
|
||||
/// *Override this, otherwise it will return "unsupported" as an error.*
|
||||
#[allow(unused_variables)]
|
||||
async fn read_file_text(
|
||||
&self,
|
||||
ctx: DistantCtx<Self::LocalData>,
|
||||
path: PathBuf,
|
||||
) -> io::Result<String> {
|
||||
unsupported("read_file_text")
|
||||
}
|
||||
|
||||
/// Writes bytes to a file, overwriting the file if it exists.
|
||||
///
|
||||
/// * `path` - the path to the file
|
||||
/// * `data` - the data to write
|
||||
///
|
||||
/// *Override this, otherwise it will return "unsupported" as an error.*
|
||||
#[allow(unused_variables)]
|
||||
async fn write_file(
|
||||
&self,
|
||||
ctx: DistantCtx<Self::LocalData>,
|
||||
path: PathBuf,
|
||||
data: Vec<u8>,
|
||||
) -> io::Result<()> {
|
||||
unsupported("write_file")
|
||||
}
|
||||
|
||||
/// Writes text to a file, overwriting the file if it exists.
|
||||
///
|
||||
/// * `path` - the path to the file
|
||||
/// * `data` - the data to write
|
||||
///
|
||||
/// *Override this, otherwise it will return "unsupported" as an error.*
|
||||
#[allow(unused_variables)]
|
||||
async fn write_file_text(
|
||||
&self,
|
||||
ctx: DistantCtx<Self::LocalData>,
|
||||
path: PathBuf,
|
||||
data: String,
|
||||
) -> io::Result<()> {
|
||||
unsupported("write_file_text")
|
||||
}
|
||||
|
||||
/// Writes bytes to the end of a file, creating it if it is missing.
|
||||
///
|
||||
/// * `path` - the path to the file
|
||||
/// * `data` - the data to append
|
||||
///
|
||||
/// *Override this, otherwise it will return "unsupported" as an error.*
|
||||
#[allow(unused_variables)]
|
||||
async fn append_file(
|
||||
&self,
|
||||
ctx: DistantCtx<Self::LocalData>,
|
||||
path: PathBuf,
|
||||
data: Vec<u8>,
|
||||
) -> io::Result<()> {
|
||||
unsupported("append_file")
|
||||
}
|
||||
|
||||
/// Writes bytes to the end of a file, creating it if it is missing.
|
||||
///
|
||||
/// * `path` - the path to the file
|
||||
/// * `data` - the data to append
|
||||
///
|
||||
/// *Override this, otherwise it will return "unsupported" as an error.*
|
||||
#[allow(unused_variables)]
|
||||
async fn append_file_text(
|
||||
&self,
|
||||
ctx: DistantCtx<Self::LocalData>,
|
||||
path: PathBuf,
|
||||
data: String,
|
||||
) -> io::Result<()> {
|
||||
unsupported("append_file_text")
|
||||
}
|
||||
|
||||
/// Reads entries from a directory.
|
||||
///
|
||||
/// * `path` - the path to the directory
|
||||
/// * `depth` - how far to traverse the directory, 0 being unlimited
|
||||
/// * `absolute` - if true, will return absolute paths instead of relative paths
|
||||
/// * `canonicalize` - if true, will canonicalize entry paths before returned
|
||||
/// * `include_root` - if true, will include the directory specified in the entries
|
||||
///
|
||||
/// *Override this, otherwise it will return "unsupported" as an error.*
|
||||
#[allow(unused_variables)]
|
||||
async fn read_dir(
|
||||
&self,
|
||||
ctx: DistantCtx<Self::LocalData>,
|
||||
path: PathBuf,
|
||||
depth: usize,
|
||||
absolute: bool,
|
||||
canonicalize: bool,
|
||||
include_root: bool,
|
||||
) -> io::Result<(Vec<DirEntry>, Vec<io::Error>)> {
|
||||
unsupported("read_dir")
|
||||
}
|
||||
|
||||
/// Creates a directory.
|
||||
///
|
||||
/// * `path` - the path to the directory
|
||||
/// * `all` - if true, will create all missing parent components
|
||||
///
|
||||
/// *Override this, otherwise it will return "unsupported" as an error.*
|
||||
#[allow(unused_variables)]
|
||||
async fn create_dir(
|
||||
&self,
|
||||
ctx: DistantCtx<Self::LocalData>,
|
||||
path: PathBuf,
|
||||
all: bool,
|
||||
) -> io::Result<()> {
|
||||
unsupported("create_dir")
|
||||
}
|
||||
|
||||
/// Copies some file or directory.
|
||||
///
|
||||
/// * `src` - the path to the file or directory to copy
|
||||
/// * `dst` - the path where the copy will be placed
|
||||
///
|
||||
/// *Override this, otherwise it will return "unsupported" as an error.*
|
||||
#[allow(unused_variables)]
|
||||
async fn copy(
|
||||
&self,
|
||||
ctx: DistantCtx<Self::LocalData>,
|
||||
src: PathBuf,
|
||||
dst: PathBuf,
|
||||
) -> io::Result<()> {
|
||||
unsupported("copy")
|
||||
}
|
||||
|
||||
/// Removes some file or directory.
|
||||
///
|
||||
/// * `path` - the path to a file or directory
|
||||
/// * `force` - if true, will remove non-empty directories
|
||||
///
|
||||
/// *Override this, otherwise it will return "unsupported" as an error.*
|
||||
#[allow(unused_variables)]
|
||||
async fn remove(
|
||||
&self,
|
||||
ctx: DistantCtx<Self::LocalData>,
|
||||
path: PathBuf,
|
||||
force: bool,
|
||||
) -> io::Result<()> {
|
||||
unsupported("remove")
|
||||
}
|
||||
|
||||
/// Renames some file or directory.
|
||||
///
|
||||
/// * `src` - the path to the file or directory to rename
|
||||
/// * `dst` - the new name for the file or directory
|
||||
///
|
||||
/// *Override this, otherwise it will return "unsupported" as an error.*
|
||||
#[allow(unused_variables)]
|
||||
async fn rename(
|
||||
&self,
|
||||
ctx: DistantCtx<Self::LocalData>,
|
||||
src: PathBuf,
|
||||
dst: PathBuf,
|
||||
) -> io::Result<()> {
|
||||
unsupported("rename")
|
||||
}
|
||||
|
||||
/// Watches a file or directory for changes.
|
||||
///
|
||||
/// * `path` - the path to the file or directory
|
||||
/// * `recursive` - if true, will watch for changes within subdirectories and beyond
|
||||
/// * `only` - if non-empty, will limit reported changes to those included in this list
|
||||
/// * `except` - if non-empty, will limit reported changes to those not included in this list
|
||||
///
|
||||
/// *Override this, otherwise it will return "unsupported" as an error.*
|
||||
#[allow(unused_variables)]
|
||||
async fn watch(
|
||||
&self,
|
||||
ctx: DistantCtx<Self::LocalData>,
|
||||
path: PathBuf,
|
||||
recursive: bool,
|
||||
only: Vec<ChangeKind>,
|
||||
except: Vec<ChangeKind>,
|
||||
) -> io::Result<()> {
|
||||
unsupported("watch")
|
||||
}
|
||||
|
||||
/// Removes a file or directory from being watched.
|
||||
///
|
||||
/// * `path` - the path to the file or directory
|
||||
///
|
||||
/// *Override this, otherwise it will return "unsupported" as an error.*
|
||||
#[allow(unused_variables)]
|
||||
async fn unwatch(&self, ctx: DistantCtx<Self::LocalData>, path: PathBuf) -> io::Result<()> {
|
||||
unsupported("unwatch")
|
||||
}
|
||||
|
||||
/// Checks if the specified path exists.
|
||||
///
|
||||
/// * `path` - the path to the file or directory
|
||||
///
|
||||
/// *Override this, otherwise it will return "unsupported" as an error.*
|
||||
#[allow(unused_variables)]
|
||||
async fn exists(&self, ctx: DistantCtx<Self::LocalData>, path: PathBuf) -> io::Result<bool> {
|
||||
unsupported("exists")
|
||||
}
|
||||
|
||||
/// Reads metadata for a file or directory.
|
||||
///
|
||||
/// * `path` - the path to the file or directory
|
||||
/// * `canonicalize` - if true, will include a canonicalized path in the metadata
|
||||
/// * `resolve_file_type` - if true, will resolve symlinks to underlying type (file or dir)
|
||||
///
|
||||
/// *Override this, otherwise it will return "unsupported" as an error.*
|
||||
#[allow(unused_variables)]
|
||||
async fn metadata(
|
||||
&self,
|
||||
ctx: DistantCtx<Self::LocalData>,
|
||||
path: PathBuf,
|
||||
canonicalize: bool,
|
||||
resolve_file_type: bool,
|
||||
) -> io::Result<Metadata> {
|
||||
unsupported("metadata")
|
||||
}
|
||||
|
||||
/// Spawns a new process, returning its id.
|
||||
///
|
||||
/// * `cmd` - the full command to run as a new process (including arguments)
|
||||
/// * `environment` - the environment variables to associate with the process
|
||||
/// * `current_dir` - the alternative current directory to use with the process
|
||||
/// * `persist` - if true, the process will continue running even after the connection that
|
||||
/// spawned the process has terminated
|
||||
/// * `pty` - if provided, will run the process within a PTY of the given size
|
||||
///
|
||||
/// *Override this, otherwise it will return "unsupported" as an error.*
|
||||
#[allow(unused_variables)]
|
||||
async fn proc_spawn(
|
||||
&self,
|
||||
ctx: DistantCtx<Self::LocalData>,
|
||||
cmd: String,
|
||||
environment: Environment,
|
||||
current_dir: Option<PathBuf>,
|
||||
persist: bool,
|
||||
pty: Option<PtySize>,
|
||||
) -> io::Result<ProcessId> {
|
||||
unsupported("proc_spawn")
|
||||
}
|
||||
|
||||
/// Kills a running process by its id.
|
||||
///
|
||||
/// * `id` - the unique id of the process
|
||||
///
|
||||
/// *Override this, otherwise it will return "unsupported" as an error.*
|
||||
#[allow(unused_variables)]
|
||||
async fn proc_kill(&self, ctx: DistantCtx<Self::LocalData>, id: ProcessId) -> io::Result<()> {
|
||||
unsupported("proc_kill")
|
||||
}
|
||||
|
||||
/// Sends data to the stdin of the process with the specified id.
|
||||
///
|
||||
/// * `id` - the unique id of the process
|
||||
/// * `data` - the bytes to send to stdin
|
||||
///
|
||||
/// *Override this, otherwise it will return "unsupported" as an error.*
|
||||
#[allow(unused_variables)]
|
||||
async fn proc_stdin(
|
||||
&self,
|
||||
ctx: DistantCtx<Self::LocalData>,
|
||||
id: ProcessId,
|
||||
data: Vec<u8>,
|
||||
) -> io::Result<()> {
|
||||
unsupported("proc_stdin")
|
||||
}
|
||||
|
||||
/// Resizes the PTY of the process with the specified id.
|
||||
///
|
||||
/// * `id` - the unique id of the process
|
||||
/// * `size` - the new size of the pty
|
||||
///
|
||||
/// *Override this, otherwise it will return "unsupported" as an error.*
|
||||
#[allow(unused_variables)]
|
||||
async fn proc_resize_pty(
|
||||
&self,
|
||||
ctx: DistantCtx<Self::LocalData>,
|
||||
id: ProcessId,
|
||||
size: PtySize,
|
||||
) -> io::Result<()> {
|
||||
unsupported("proc_resize_pty")
|
||||
}
|
||||
|
||||
/// Retrieves information about the system.
|
||||
///
|
||||
/// *Override this, otherwise it will return "unsupported" as an error.*
|
||||
#[allow(unused_variables)]
|
||||
async fn system_info(&self, ctx: DistantCtx<Self::LocalData>) -> io::Result<SystemInfo> {
|
||||
unsupported("system_info")
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl<T, D> Server for DistantApiServer<T, D>
|
||||
where
|
||||
T: DistantApi<LocalData = D> + Send + Sync,
|
||||
D: Send + Sync,
|
||||
{
|
||||
type Request = DistantMsg<DistantRequestData>;
|
||||
type Response = DistantMsg<DistantResponseData>;
|
||||
type LocalData = D;
|
||||
|
||||
/// Overridden to leverage [`DistantApi`] implementation of `on_accept`
|
||||
async fn on_accept(&self, local_data: &mut Self::LocalData) {
|
||||
T::on_accept(&self.api, local_data).await
|
||||
}
|
||||
|
||||
async fn on_request(&self, ctx: ServerCtx<Self::Request, Self::Response, Self::LocalData>) {
|
||||
let ServerCtx {
|
||||
connection_id,
|
||||
request,
|
||||
reply,
|
||||
local_data,
|
||||
} = ctx;
|
||||
|
||||
// Convert our reply to a queued reply so we can ensure that the result
|
||||
// of an API function is sent back before anything else
|
||||
let reply = reply.queue();
|
||||
|
||||
// Process single vs batch requests
|
||||
let response = match request.payload {
|
||||
DistantMsg::Single(data) => {
|
||||
let ctx = DistantCtx {
|
||||
connection_id,
|
||||
reply: Box::new(DistantSingleReply::from(reply.clone_reply())),
|
||||
local_data,
|
||||
};
|
||||
|
||||
let data = handle_request(self, ctx, data).await;
|
||||
|
||||
// Report outgoing errors in our debug logs
|
||||
if let DistantResponseData::Error(x) = &data {
|
||||
debug!("[Conn {}] {}", connection_id, x);
|
||||
}
|
||||
|
||||
DistantMsg::Single(data)
|
||||
}
|
||||
DistantMsg::Batch(list) => {
|
||||
let mut out = Vec::new();
|
||||
|
||||
for data in list {
|
||||
let ctx = DistantCtx {
|
||||
connection_id,
|
||||
reply: Box::new(DistantSingleReply::from(reply.clone_reply())),
|
||||
local_data: Arc::clone(&local_data),
|
||||
};
|
||||
|
||||
// TODO: This does not run in parallel, meaning that the next item in the
|
||||
// batch will not be queued until the previous item completes! This
|
||||
// would be useful if we wanted to chain requests where the previous
|
||||
// request feeds into the current request, but not if we just want
|
||||
// to run everything together. So we should instead rewrite this
|
||||
// to spawn a task per request and then await completion of all tasks
|
||||
let data = handle_request(self, ctx, data).await;
|
||||
|
||||
// Report outgoing errors in our debug logs
|
||||
if let DistantResponseData::Error(x) = &data {
|
||||
debug!("[Conn {}] {}", connection_id, x);
|
||||
}
|
||||
|
||||
out.push(data);
|
||||
}
|
||||
|
||||
DistantMsg::Batch(out)
|
||||
}
|
||||
};
|
||||
|
||||
// Queue up our result to go before ANY of the other messages that might be sent.
|
||||
// This is important to avoid situations such as when a process is started, but before
|
||||
// the confirmation can be sent some stdout or stderr is captured and sent first.
|
||||
if let Err(x) = reply.send_before(response).await {
|
||||
error!("[Conn {}] Failed to send response: {}", connection_id, x);
|
||||
}
|
||||
|
||||
// Flush out all of our replies thus far and toggle to no longer hold submissions
|
||||
if let Err(x) = reply.flush(false).await {
|
||||
error!(
|
||||
"[Conn {}] Failed to flush response queue: {}",
|
||||
connection_id, x
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Processes an incoming request
|
||||
async fn handle_request<T, D>(
|
||||
server: &DistantApiServer<T, D>,
|
||||
ctx: DistantCtx<D>,
|
||||
request: DistantRequestData,
|
||||
) -> DistantResponseData
|
||||
where
|
||||
T: DistantApi<LocalData = D> + Send + Sync,
|
||||
D: Send + Sync,
|
||||
{
|
||||
match request {
|
||||
DistantRequestData::FileRead { path } => server
|
||||
.api
|
||||
.read_file(ctx, path)
|
||||
.await
|
||||
.map(|data| DistantResponseData::Blob { data })
|
||||
.unwrap_or_else(DistantResponseData::from),
|
||||
DistantRequestData::FileReadText { path } => server
|
||||
.api
|
||||
.read_file_text(ctx, path)
|
||||
.await
|
||||
.map(|data| DistantResponseData::Text { data })
|
||||
.unwrap_or_else(DistantResponseData::from),
|
||||
DistantRequestData::FileWrite { path, data } => server
|
||||
.api
|
||||
.write_file(ctx, path, data)
|
||||
.await
|
||||
.map(|_| DistantResponseData::Ok)
|
||||
.unwrap_or_else(DistantResponseData::from),
|
||||
DistantRequestData::FileWriteText { path, text } => server
|
||||
.api
|
||||
.write_file_text(ctx, path, text)
|
||||
.await
|
||||
.map(|_| DistantResponseData::Ok)
|
||||
.unwrap_or_else(DistantResponseData::from),
|
||||
DistantRequestData::FileAppend { path, data } => server
|
||||
.api
|
||||
.append_file(ctx, path, data)
|
||||
.await
|
||||
.map(|_| DistantResponseData::Ok)
|
||||
.unwrap_or_else(DistantResponseData::from),
|
||||
DistantRequestData::FileAppendText { path, text } => server
|
||||
.api
|
||||
.append_file_text(ctx, path, text)
|
||||
.await
|
||||
.map(|_| DistantResponseData::Ok)
|
||||
.unwrap_or_else(DistantResponseData::from),
|
||||
DistantRequestData::DirRead {
|
||||
path,
|
||||
depth,
|
||||
absolute,
|
||||
canonicalize,
|
||||
include_root,
|
||||
} => server
|
||||
.api
|
||||
.read_dir(ctx, path, depth, absolute, canonicalize, include_root)
|
||||
.await
|
||||
.map(|(entries, errors)| DistantResponseData::DirEntries {
|
||||
entries,
|
||||
errors: errors.into_iter().map(Error::from).collect(),
|
||||
})
|
||||
.unwrap_or_else(DistantResponseData::from),
|
||||
DistantRequestData::DirCreate { path, all } => server
|
||||
.api
|
||||
.create_dir(ctx, path, all)
|
||||
.await
|
||||
.map(|_| DistantResponseData::Ok)
|
||||
.unwrap_or_else(DistantResponseData::from),
|
||||
DistantRequestData::Remove { path, force } => server
|
||||
.api
|
||||
.remove(ctx, path, force)
|
||||
.await
|
||||
.map(|_| DistantResponseData::Ok)
|
||||
.unwrap_or_else(DistantResponseData::from),
|
||||
DistantRequestData::Copy { src, dst } => server
|
||||
.api
|
||||
.copy(ctx, src, dst)
|
||||
.await
|
||||
.map(|_| DistantResponseData::Ok)
|
||||
.unwrap_or_else(DistantResponseData::from),
|
||||
DistantRequestData::Rename { src, dst } => server
|
||||
.api
|
||||
.rename(ctx, src, dst)
|
||||
.await
|
||||
.map(|_| DistantResponseData::Ok)
|
||||
.unwrap_or_else(DistantResponseData::from),
|
||||
DistantRequestData::Watch {
|
||||
path,
|
||||
recursive,
|
||||
only,
|
||||
except,
|
||||
} => server
|
||||
.api
|
||||
.watch(ctx, path, recursive, only, except)
|
||||
.await
|
||||
.map(|_| DistantResponseData::Ok)
|
||||
.unwrap_or_else(DistantResponseData::from),
|
||||
DistantRequestData::Unwatch { path } => server
|
||||
.api
|
||||
.unwatch(ctx, path)
|
||||
.await
|
||||
.map(|_| DistantResponseData::Ok)
|
||||
.unwrap_or_else(DistantResponseData::from),
|
||||
DistantRequestData::Exists { path } => server
|
||||
.api
|
||||
.exists(ctx, path)
|
||||
.await
|
||||
.map(|value| DistantResponseData::Exists { value })
|
||||
.unwrap_or_else(DistantResponseData::from),
|
||||
DistantRequestData::Metadata {
|
||||
path,
|
||||
canonicalize,
|
||||
resolve_file_type,
|
||||
} => server
|
||||
.api
|
||||
.metadata(ctx, path, canonicalize, resolve_file_type)
|
||||
.await
|
||||
.map(DistantResponseData::Metadata)
|
||||
.unwrap_or_else(DistantResponseData::from),
|
||||
DistantRequestData::ProcSpawn {
|
||||
cmd,
|
||||
environment,
|
||||
current_dir,
|
||||
persist,
|
||||
pty,
|
||||
} => server
|
||||
.api
|
||||
.proc_spawn(ctx, cmd.into(), environment, current_dir, persist, pty)
|
||||
.await
|
||||
.map(|id| DistantResponseData::ProcSpawned { id })
|
||||
.unwrap_or_else(DistantResponseData::from),
|
||||
DistantRequestData::ProcKill { id } => server
|
||||
.api
|
||||
.proc_kill(ctx, id)
|
||||
.await
|
||||
.map(|_| DistantResponseData::Ok)
|
||||
.unwrap_or_else(DistantResponseData::from),
|
||||
DistantRequestData::ProcStdin { id, data } => server
|
||||
.api
|
||||
.proc_stdin(ctx, id, data)
|
||||
.await
|
||||
.map(|_| DistantResponseData::Ok)
|
||||
.unwrap_or_else(DistantResponseData::from),
|
||||
DistantRequestData::ProcResizePty { id, size } => server
|
||||
.api
|
||||
.proc_resize_pty(ctx, id, size)
|
||||
.await
|
||||
.map(|_| DistantResponseData::Ok)
|
||||
.unwrap_or_else(DistantResponseData::from),
|
||||
DistantRequestData::SystemInfo {} => server
|
||||
.api
|
||||
.system_info(ctx)
|
||||
.await
|
||||
.map(DistantResponseData::SystemInfo)
|
||||
.unwrap_or_else(DistantResponseData::from),
|
||||
}
|
||||
}
|
File diff suppressed because it is too large
Load Diff
@ -0,0 +1,68 @@
|
||||
use crate::{data::ProcessId, ConnectionId};
|
||||
use std::{io, path::PathBuf};
|
||||
|
||||
mod process;
|
||||
pub use process::*;
|
||||
|
||||
mod watcher;
|
||||
pub use watcher::*;
|
||||
|
||||
/// Holds global state state managed by the server
|
||||
pub struct GlobalState {
|
||||
/// State that holds information about processes running on the server
|
||||
pub process: ProcessState,
|
||||
|
||||
/// Watcher used for filesystem events
|
||||
pub watcher: WatcherState,
|
||||
}
|
||||
|
||||
impl GlobalState {
|
||||
pub fn initialize() -> io::Result<Self> {
|
||||
Ok(Self {
|
||||
process: ProcessState::new(),
|
||||
watcher: WatcherState::initialize()?,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
/// Holds connection-specific state managed by the server
|
||||
#[derive(Default)]
|
||||
pub struct ConnectionState {
|
||||
/// Unique id associated with connection
|
||||
id: ConnectionId,
|
||||
|
||||
/// Channel connected to global process state
|
||||
pub(crate) process_channel: ProcessChannel,
|
||||
|
||||
/// Channel connected to global watcher state
|
||||
pub(crate) watcher_channel: WatcherChannel,
|
||||
|
||||
/// Contains ids of processes that will be terminated when the connection is closed
|
||||
processes: Vec<ProcessId>,
|
||||
|
||||
/// Contains paths being watched that will be unwatched when the connection is closed
|
||||
paths: Vec<PathBuf>,
|
||||
}
|
||||
|
||||
impl Drop for ConnectionState {
|
||||
fn drop(&mut self) {
|
||||
let id = self.id;
|
||||
let processes: Vec<ProcessId> = self.processes.drain(..).collect();
|
||||
let paths: Vec<PathBuf> = self.paths.drain(..).collect();
|
||||
|
||||
let process_channel = self.process_channel.clone();
|
||||
let watcher_channel = self.watcher_channel.clone();
|
||||
|
||||
// NOTE: We cannot (and should not) block during drop to perform cleanup,
|
||||
// instead spawning a task that will do the cleanup async
|
||||
tokio::spawn(async move {
|
||||
for id in processes {
|
||||
let _ = process_channel.kill(id).await;
|
||||
}
|
||||
|
||||
for path in paths {
|
||||
let _ = watcher_channel.unwatch(id, path).await;
|
||||
}
|
||||
});
|
||||
}
|
||||
}
|
@ -0,0 +1,231 @@
|
||||
use crate::data::{DistantResponseData, Environment, ProcessId, PtySize};
|
||||
use distant_net::Reply;
|
||||
use std::{collections::HashMap, io, ops::Deref, path::PathBuf};
|
||||
use tokio::{
|
||||
sync::{mpsc, oneshot},
|
||||
task::JoinHandle,
|
||||
};
|
||||
|
||||
mod instance;
|
||||
pub use instance::*;
|
||||
|
||||
/// Holds information related to spawned processes on the server
|
||||
pub struct ProcessState {
|
||||
channel: ProcessChannel,
|
||||
task: JoinHandle<()>,
|
||||
}
|
||||
|
||||
impl Drop for ProcessState {
|
||||
/// Aborts the task that handles process operations and management
|
||||
fn drop(&mut self) {
|
||||
self.abort();
|
||||
}
|
||||
}
|
||||
|
||||
impl ProcessState {
|
||||
pub fn new() -> Self {
|
||||
let (tx, rx) = mpsc::channel(1);
|
||||
let task = tokio::spawn(process_task(tx.clone(), rx));
|
||||
|
||||
Self {
|
||||
channel: ProcessChannel { tx },
|
||||
task,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn clone_channel(&self) -> ProcessChannel {
|
||||
self.channel.clone()
|
||||
}
|
||||
|
||||
/// Aborts the process task
|
||||
pub fn abort(&self) {
|
||||
self.task.abort();
|
||||
}
|
||||
}
|
||||
|
||||
impl Deref for ProcessState {
|
||||
type Target = ProcessChannel;
|
||||
|
||||
fn deref(&self) -> &Self::Target {
|
||||
&self.channel
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct ProcessChannel {
|
||||
tx: mpsc::Sender<InnerProcessMsg>,
|
||||
}
|
||||
|
||||
impl Default for ProcessChannel {
|
||||
/// Creates a new channel that is closed by default
|
||||
fn default() -> Self {
|
||||
let (tx, _) = mpsc::channel(1);
|
||||
Self { tx }
|
||||
}
|
||||
}
|
||||
|
||||
impl ProcessChannel {
|
||||
/// Spawns a new process, returning the id associated with it
|
||||
pub async fn spawn(
|
||||
&self,
|
||||
cmd: String,
|
||||
environment: Environment,
|
||||
current_dir: Option<PathBuf>,
|
||||
persist: bool,
|
||||
pty: Option<PtySize>,
|
||||
reply: Box<dyn Reply<Data = DistantResponseData>>,
|
||||
) -> io::Result<ProcessId> {
|
||||
let (cb, rx) = oneshot::channel();
|
||||
self.tx
|
||||
.send(InnerProcessMsg::Spawn {
|
||||
cmd,
|
||||
environment,
|
||||
current_dir,
|
||||
persist,
|
||||
pty,
|
||||
reply,
|
||||
cb,
|
||||
})
|
||||
.await
|
||||
.map_err(|_| io::Error::new(io::ErrorKind::Other, "Internal process task closed"))?;
|
||||
rx.await
|
||||
.map_err(|_| io::Error::new(io::ErrorKind::Other, "Response to spawn dropped"))?
|
||||
}
|
||||
|
||||
/// Resizes the pty of a running process
|
||||
pub async fn resize_pty(&self, id: ProcessId, size: PtySize) -> io::Result<()> {
|
||||
let (cb, rx) = oneshot::channel();
|
||||
self.tx
|
||||
.send(InnerProcessMsg::Resize { id, size, cb })
|
||||
.await
|
||||
.map_err(|_| io::Error::new(io::ErrorKind::Other, "Internal process task closed"))?;
|
||||
rx.await
|
||||
.map_err(|_| io::Error::new(io::ErrorKind::Other, "Response to resize dropped"))?
|
||||
}
|
||||
|
||||
/// Send stdin to a running process
|
||||
pub async fn send_stdin(&self, id: ProcessId, data: Vec<u8>) -> io::Result<()> {
|
||||
let (cb, rx) = oneshot::channel();
|
||||
self.tx
|
||||
.send(InnerProcessMsg::Stdin { id, data, cb })
|
||||
.await
|
||||
.map_err(|_| io::Error::new(io::ErrorKind::Other, "Internal process task closed"))?;
|
||||
rx.await
|
||||
.map_err(|_| io::Error::new(io::ErrorKind::Other, "Response to stdin dropped"))?
|
||||
}
|
||||
|
||||
/// Kills a running process
|
||||
pub async fn kill(&self, id: ProcessId) -> io::Result<()> {
|
||||
let (cb, rx) = oneshot::channel();
|
||||
self.tx
|
||||
.send(InnerProcessMsg::Kill { id, cb })
|
||||
.await
|
||||
.map_err(|_| io::Error::new(io::ErrorKind::Other, "Internal process task closed"))?;
|
||||
rx.await
|
||||
.map_err(|_| io::Error::new(io::ErrorKind::Other, "Response to kill dropped"))?
|
||||
}
|
||||
}
|
||||
|
||||
/// Internal message to pass to our task below to perform some action
|
||||
enum InnerProcessMsg {
|
||||
Spawn {
|
||||
cmd: String,
|
||||
environment: Environment,
|
||||
current_dir: Option<PathBuf>,
|
||||
persist: bool,
|
||||
pty: Option<PtySize>,
|
||||
reply: Box<dyn Reply<Data = DistantResponseData>>,
|
||||
cb: oneshot::Sender<io::Result<ProcessId>>,
|
||||
},
|
||||
Resize {
|
||||
id: ProcessId,
|
||||
size: PtySize,
|
||||
cb: oneshot::Sender<io::Result<()>>,
|
||||
},
|
||||
Stdin {
|
||||
id: ProcessId,
|
||||
data: Vec<u8>,
|
||||
cb: oneshot::Sender<io::Result<()>>,
|
||||
},
|
||||
Kill {
|
||||
id: ProcessId,
|
||||
cb: oneshot::Sender<io::Result<()>>,
|
||||
},
|
||||
InternalRemove {
|
||||
id: ProcessId,
|
||||
},
|
||||
}
|
||||
|
||||
async fn process_task(tx: mpsc::Sender<InnerProcessMsg>, mut rx: mpsc::Receiver<InnerProcessMsg>) {
|
||||
let mut processes: HashMap<ProcessId, ProcessInstance> = HashMap::new();
|
||||
|
||||
while let Some(msg) = rx.recv().await {
|
||||
match msg {
|
||||
InnerProcessMsg::Spawn {
|
||||
cmd,
|
||||
environment,
|
||||
current_dir,
|
||||
persist,
|
||||
pty,
|
||||
reply,
|
||||
cb,
|
||||
} => {
|
||||
let _ = cb.send(
|
||||
match ProcessInstance::spawn(cmd, environment, current_dir, persist, pty, reply)
|
||||
{
|
||||
Ok(mut process) => {
|
||||
let id = process.id;
|
||||
|
||||
// Attach a callback for when the process is finished where
|
||||
// we will remove it from our above list
|
||||
let tx = tx.clone();
|
||||
process.on_done(move |_| async move {
|
||||
let _ = tx.send(InnerProcessMsg::InternalRemove { id }).await;
|
||||
});
|
||||
|
||||
processes.insert(id, process);
|
||||
Ok(id)
|
||||
}
|
||||
Err(x) => Err(x),
|
||||
},
|
||||
);
|
||||
}
|
||||
InnerProcessMsg::Resize { id, size, cb } => {
|
||||
let _ = cb.send(match processes.get(&id) {
|
||||
Some(process) => process.pty.resize_pty(size),
|
||||
None => Err(io::Error::new(
|
||||
io::ErrorKind::Other,
|
||||
format!("No process found with id {}", id),
|
||||
)),
|
||||
});
|
||||
}
|
||||
InnerProcessMsg::Stdin { id, data, cb } => {
|
||||
let _ = cb.send(match processes.get_mut(&id) {
|
||||
Some(process) => match process.stdin.as_mut() {
|
||||
Some(stdin) => stdin.send(&data).await,
|
||||
None => Err(io::Error::new(
|
||||
io::ErrorKind::Other,
|
||||
format!("Process {} stdin is closed", id),
|
||||
)),
|
||||
},
|
||||
None => Err(io::Error::new(
|
||||
io::ErrorKind::Other,
|
||||
format!("No process found with id {}", id),
|
||||
)),
|
||||
});
|
||||
}
|
||||
InnerProcessMsg::Kill { id, cb } => {
|
||||
let _ = cb.send(match processes.get_mut(&id) {
|
||||
Some(process) => process.killer.kill().await,
|
||||
None => Err(io::Error::new(
|
||||
io::ErrorKind::Other,
|
||||
format!("No process found with id {}", id),
|
||||
)),
|
||||
});
|
||||
}
|
||||
InnerProcessMsg::InternalRemove { id } => {
|
||||
processes.remove(&id);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
@ -0,0 +1,230 @@
|
||||
use crate::{
|
||||
api::local::process::{
|
||||
InputChannel, OutputChannel, Process, ProcessKiller, ProcessPty, PtyProcess, SimpleProcess,
|
||||
},
|
||||
data::{DistantResponseData, Environment, ProcessId, PtySize},
|
||||
};
|
||||
use distant_net::Reply;
|
||||
use log::*;
|
||||
use std::{future::Future, io, path::PathBuf};
|
||||
use tokio::task::JoinHandle;
|
||||
|
||||
/// Holds information related to a spawned process on the server
|
||||
pub struct ProcessInstance {
|
||||
pub cmd: String,
|
||||
pub args: Vec<String>,
|
||||
pub persist: bool,
|
||||
|
||||
pub id: ProcessId,
|
||||
pub stdin: Option<Box<dyn InputChannel>>,
|
||||
pub killer: Box<dyn ProcessKiller>,
|
||||
pub pty: Box<dyn ProcessPty>,
|
||||
|
||||
stdout_task: Option<JoinHandle<io::Result<()>>>,
|
||||
stderr_task: Option<JoinHandle<io::Result<()>>>,
|
||||
wait_task: Option<JoinHandle<io::Result<()>>>,
|
||||
}
|
||||
|
||||
impl Drop for ProcessInstance {
|
||||
/// Closes stdin and attempts to kill the process when dropped
|
||||
fn drop(&mut self) {
|
||||
// Drop stdin first to close it
|
||||
self.stdin = None;
|
||||
|
||||
// Clear out our tasks if we still have them
|
||||
let stdout_task = self.stdout_task.take();
|
||||
let stderr_task = self.stderr_task.take();
|
||||
let wait_task = self.wait_task.take();
|
||||
|
||||
// Attempt to kill the process, which is an async operation that we
|
||||
// will spawn a task to handle
|
||||
let id = self.id;
|
||||
let mut killer = self.killer.clone_killer();
|
||||
tokio::spawn(async move {
|
||||
if let Err(x) = killer.kill().await {
|
||||
error!("Failed to kill process {} when dropped: {}", id, x);
|
||||
|
||||
if let Some(task) = stdout_task.as_ref() {
|
||||
task.abort();
|
||||
}
|
||||
if let Some(task) = stderr_task.as_ref() {
|
||||
task.abort();
|
||||
}
|
||||
if let Some(task) = wait_task.as_ref() {
|
||||
task.abort();
|
||||
}
|
||||
}
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
impl ProcessInstance {
|
||||
pub fn spawn(
|
||||
cmd: String,
|
||||
environment: Environment,
|
||||
current_dir: Option<PathBuf>,
|
||||
persist: bool,
|
||||
pty: Option<PtySize>,
|
||||
reply: Box<dyn Reply<Data = DistantResponseData>>,
|
||||
) -> io::Result<Self> {
|
||||
// Build out the command and args from our string
|
||||
let mut cmd_and_args = if cfg!(windows) {
|
||||
winsplit::split(&cmd)
|
||||
} else {
|
||||
shell_words::split(&cmd).map_err(|x| io::Error::new(io::ErrorKind::InvalidInput, x))?
|
||||
};
|
||||
|
||||
if cmd_and_args.is_empty() {
|
||||
return Err(io::Error::new(
|
||||
io::ErrorKind::InvalidInput,
|
||||
"Command was empty",
|
||||
));
|
||||
}
|
||||
|
||||
// Split command from arguments, where arguments could be empty
|
||||
let args = cmd_and_args.split_off(1);
|
||||
let cmd = cmd_and_args.into_iter().next().unwrap();
|
||||
|
||||
let mut child: Box<dyn Process> = match pty {
|
||||
Some(size) => Box::new(PtyProcess::spawn(
|
||||
cmd.clone(),
|
||||
args.clone(),
|
||||
environment,
|
||||
current_dir,
|
||||
size,
|
||||
)?),
|
||||
None => Box::new(SimpleProcess::spawn(
|
||||
cmd.clone(),
|
||||
args.clone(),
|
||||
environment,
|
||||
current_dir,
|
||||
)?),
|
||||
};
|
||||
|
||||
let id = child.id();
|
||||
let stdin = child.take_stdin();
|
||||
let stdout = child.take_stdout();
|
||||
let stderr = child.take_stderr();
|
||||
let killer = child.clone_killer();
|
||||
let pty = child.clone_pty();
|
||||
|
||||
// Spawn a task that sends stdout as a response
|
||||
let stdout_task = match stdout {
|
||||
Some(stdout) => {
|
||||
let reply = reply.clone_reply();
|
||||
let task = tokio::spawn(stdout_task(id, stdout, reply));
|
||||
Some(task)
|
||||
}
|
||||
None => None,
|
||||
};
|
||||
|
||||
// Spawn a task that sends stderr as a response
|
||||
let stderr_task = match stderr {
|
||||
Some(stderr) => {
|
||||
let reply = reply.clone_reply();
|
||||
let task = tokio::spawn(stderr_task(id, stderr, reply));
|
||||
Some(task)
|
||||
}
|
||||
None => None,
|
||||
};
|
||||
|
||||
// Spawn a task that waits on the process to exit but can also
|
||||
// kill the process when triggered
|
||||
let wait_task = Some(tokio::spawn(wait_task(id, child, reply)));
|
||||
|
||||
Ok(ProcessInstance {
|
||||
cmd,
|
||||
args,
|
||||
persist,
|
||||
id,
|
||||
stdin,
|
||||
killer,
|
||||
pty,
|
||||
stdout_task,
|
||||
stderr_task,
|
||||
wait_task,
|
||||
})
|
||||
}
|
||||
|
||||
/// Invokes the function once the process has completed
|
||||
///
|
||||
/// NOTE: Can only be used with one function. All future calls
|
||||
/// will do nothing
|
||||
pub fn on_done<F, R>(&mut self, f: F)
|
||||
where
|
||||
F: FnOnce(io::Result<()>) -> R + Send + 'static,
|
||||
R: Future<Output = ()> + Send,
|
||||
{
|
||||
if let Some(task) = self.wait_task.take() {
|
||||
tokio::spawn(async move {
|
||||
f(task
|
||||
.await
|
||||
.unwrap_or_else(|x| Err(io::Error::new(io::ErrorKind::Other, x))))
|
||||
.await
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async fn stdout_task(
|
||||
id: ProcessId,
|
||||
mut stdout: Box<dyn OutputChannel>,
|
||||
reply: Box<dyn Reply<Data = DistantResponseData>>,
|
||||
) -> io::Result<()> {
|
||||
loop {
|
||||
match stdout.recv().await {
|
||||
Ok(Some(data)) => {
|
||||
if let Err(x) = reply
|
||||
.send(DistantResponseData::ProcStdout { id, data })
|
||||
.await
|
||||
{
|
||||
return Err(x);
|
||||
}
|
||||
}
|
||||
Ok(None) => return Ok(()),
|
||||
Err(x) => return Err(x),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async fn stderr_task(
|
||||
id: ProcessId,
|
||||
mut stderr: Box<dyn OutputChannel>,
|
||||
reply: Box<dyn Reply<Data = DistantResponseData>>,
|
||||
) -> io::Result<()> {
|
||||
loop {
|
||||
match stderr.recv().await {
|
||||
Ok(Some(data)) => {
|
||||
if let Err(x) = reply
|
||||
.send(DistantResponseData::ProcStderr { id, data })
|
||||
.await
|
||||
{
|
||||
return Err(x);
|
||||
}
|
||||
}
|
||||
Ok(None) => return Ok(()),
|
||||
Err(x) => return Err(x),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async fn wait_task(
|
||||
id: ProcessId,
|
||||
mut child: Box<dyn Process>,
|
||||
reply: Box<dyn Reply<Data = DistantResponseData>>,
|
||||
) -> io::Result<()> {
|
||||
let status = child.wait().await;
|
||||
|
||||
match status {
|
||||
Ok(status) => {
|
||||
reply
|
||||
.send(DistantResponseData::ProcDone {
|
||||
id,
|
||||
success: status.success,
|
||||
code: status.code,
|
||||
})
|
||||
.await
|
||||
}
|
||||
Err(x) => reply.send(DistantResponseData::from(x)).await,
|
||||
}
|
||||
}
|
@ -0,0 +1,286 @@
|
||||
use crate::{constants::SERVER_WATCHER_CAPACITY, data::ChangeKind, ConnectionId};
|
||||
use log::*;
|
||||
use notify::{
|
||||
Config as WatcherConfig, Error as WatcherError, Event as WatcherEvent, RecommendedWatcher,
|
||||
RecursiveMode, Watcher,
|
||||
};
|
||||
use std::{
|
||||
collections::HashMap,
|
||||
io,
|
||||
ops::Deref,
|
||||
path::{Path, PathBuf},
|
||||
};
|
||||
use tokio::{
|
||||
sync::{
|
||||
mpsc::{self, error::TrySendError},
|
||||
oneshot,
|
||||
},
|
||||
task::JoinHandle,
|
||||
};
|
||||
|
||||
mod path;
|
||||
pub use path::*;
|
||||
|
||||
/// Holds information related to watched paths on the server
|
||||
pub struct WatcherState {
|
||||
channel: WatcherChannel,
|
||||
task: JoinHandle<()>,
|
||||
}
|
||||
|
||||
impl Drop for WatcherState {
|
||||
/// Aborts the task that handles watcher path operations and management
|
||||
fn drop(&mut self) {
|
||||
self.abort();
|
||||
}
|
||||
}
|
||||
|
||||
impl WatcherState {
|
||||
/// Will create a watcher and initialize watched paths to be empty
|
||||
pub fn initialize() -> io::Result<Self> {
|
||||
// NOTE: Cannot be something small like 1 as this seems to cause a deadlock sometimes
|
||||
// with a large volume of watch requests
|
||||
let (tx, rx) = mpsc::channel(SERVER_WATCHER_CAPACITY);
|
||||
|
||||
let mut watcher = {
|
||||
let tx = tx.clone();
|
||||
notify::recommended_watcher(move |res| {
|
||||
match tx.try_send(match res {
|
||||
Ok(x) => InnerWatcherMsg::Event { ev: x },
|
||||
Err(x) => InnerWatcherMsg::Error { err: x },
|
||||
}) {
|
||||
Ok(_) => (),
|
||||
Err(TrySendError::Full(_)) => {
|
||||
warn!(
|
||||
"Reached watcher capacity of {}! Dropping watcher event!",
|
||||
SERVER_WATCHER_CAPACITY,
|
||||
);
|
||||
}
|
||||
Err(TrySendError::Closed(_)) => {
|
||||
warn!("Skipping watch event because watcher channel closed");
|
||||
}
|
||||
}
|
||||
})
|
||||
.map_err(|x| io::Error::new(io::ErrorKind::Other, x))?
|
||||
};
|
||||
|
||||
// Attempt to configure watcher, but don't fail if these configurations fail
|
||||
match watcher.configure(WatcherConfig::PreciseEvents(true)) {
|
||||
Ok(true) => debug!("Watcher configured for precise events"),
|
||||
Ok(false) => debug!("Watcher not configured for precise events",),
|
||||
Err(x) => error!("Watcher configuration for precise events failed: {}", x),
|
||||
}
|
||||
|
||||
// Attempt to configure watcher, but don't fail if these configurations fail
|
||||
match watcher.configure(WatcherConfig::NoticeEvents(true)) {
|
||||
Ok(true) => debug!("Watcher configured for notice events"),
|
||||
Ok(false) => debug!("Watcher not configured for notice events",),
|
||||
Err(x) => error!("Watcher configuration for notice events failed: {}", x),
|
||||
}
|
||||
|
||||
Ok(Self {
|
||||
channel: WatcherChannel { tx },
|
||||
task: tokio::spawn(watcher_task(watcher, rx)),
|
||||
})
|
||||
}
|
||||
|
||||
pub fn clone_channel(&self) -> WatcherChannel {
|
||||
self.channel.clone()
|
||||
}
|
||||
|
||||
/// Aborts the watcher task
|
||||
pub fn abort(&self) {
|
||||
self.task.abort();
|
||||
}
|
||||
}
|
||||
|
||||
impl Deref for WatcherState {
|
||||
type Target = WatcherChannel;
|
||||
|
||||
fn deref(&self) -> &Self::Target {
|
||||
&self.channel
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct WatcherChannel {
|
||||
tx: mpsc::Sender<InnerWatcherMsg>,
|
||||
}
|
||||
|
||||
impl Default for WatcherChannel {
|
||||
/// Creates a new channel that is closed by default
|
||||
fn default() -> Self {
|
||||
let (tx, _) = mpsc::channel(1);
|
||||
Self { tx }
|
||||
}
|
||||
}
|
||||
|
||||
impl WatcherChannel {
|
||||
/// Watch a path for a specific connection denoted by the id within the registered path
|
||||
pub async fn watch(&self, registered_path: RegisteredPath) -> io::Result<()> {
|
||||
let (cb, rx) = oneshot::channel();
|
||||
self.tx
|
||||
.send(InnerWatcherMsg::Watch {
|
||||
registered_path,
|
||||
cb,
|
||||
})
|
||||
.await
|
||||
.map_err(|_| io::Error::new(io::ErrorKind::Other, "Internal watcher task closed"))?;
|
||||
rx.await
|
||||
.map_err(|_| io::Error::new(io::ErrorKind::Other, "Response to watch dropped"))?
|
||||
}
|
||||
|
||||
/// Unwatch a path for a specific connection denoted by the id
|
||||
pub async fn unwatch(&self, id: ConnectionId, path: impl AsRef<Path>) -> io::Result<()> {
|
||||
let (cb, rx) = oneshot::channel();
|
||||
let path = tokio::fs::canonicalize(path.as_ref())
|
||||
.await
|
||||
.unwrap_or_else(|_| path.as_ref().to_path_buf());
|
||||
self.tx
|
||||
.send(InnerWatcherMsg::Unwatch { id, path, cb })
|
||||
.await
|
||||
.map_err(|_| io::Error::new(io::ErrorKind::Other, "Internal watcher task closed"))?;
|
||||
rx.await
|
||||
.map_err(|_| io::Error::new(io::ErrorKind::Other, "Response to unwatch dropped"))?
|
||||
}
|
||||
}
|
||||
|
||||
/// Internal message to pass to our task below to perform some action
|
||||
enum InnerWatcherMsg {
|
||||
Watch {
|
||||
registered_path: RegisteredPath,
|
||||
cb: oneshot::Sender<io::Result<()>>,
|
||||
},
|
||||
Unwatch {
|
||||
id: ConnectionId,
|
||||
path: PathBuf,
|
||||
cb: oneshot::Sender<io::Result<()>>,
|
||||
},
|
||||
Event {
|
||||
ev: WatcherEvent,
|
||||
},
|
||||
Error {
|
||||
err: WatcherError,
|
||||
},
|
||||
}
|
||||
|
||||
async fn watcher_task(mut watcher: RecommendedWatcher, mut rx: mpsc::Receiver<InnerWatcherMsg>) {
|
||||
// TODO: Optimize this in some way to be more performant than
|
||||
// checking every path whenever an event comes in
|
||||
let mut registered_paths: Vec<RegisteredPath> = Vec::new();
|
||||
let mut path_cnt: HashMap<PathBuf, usize> = HashMap::new();
|
||||
|
||||
while let Some(msg) = rx.recv().await {
|
||||
match msg {
|
||||
InnerWatcherMsg::Watch {
|
||||
registered_path,
|
||||
cb,
|
||||
} => {
|
||||
// Check if we are tracking the path across any connection
|
||||
if let Some(cnt) = path_cnt.get_mut(registered_path.path()) {
|
||||
// Increment the count of times we are watching that path
|
||||
*cnt += 1;
|
||||
|
||||
// Store the registered path in our collection without worry
|
||||
// since we are already watching a path that impacts this one
|
||||
registered_paths.push(registered_path);
|
||||
|
||||
// Send an okay because we always succeed in this case
|
||||
let _ = cb.send(Ok(()));
|
||||
} else {
|
||||
let res = watcher
|
||||
.watch(
|
||||
registered_path.path(),
|
||||
if registered_path.is_recursive() {
|
||||
RecursiveMode::Recursive
|
||||
} else {
|
||||
RecursiveMode::NonRecursive
|
||||
},
|
||||
)
|
||||
.map_err(|x| io::Error::new(io::ErrorKind::Other, x));
|
||||
|
||||
// If we succeeded, store our registered path and set the tracking cnt to 1
|
||||
if res.is_ok() {
|
||||
path_cnt.insert(registered_path.path().to_path_buf(), 1);
|
||||
registered_paths.push(registered_path);
|
||||
}
|
||||
|
||||
// Send the result of the watch, but don't worry if the channel was closed
|
||||
let _ = cb.send(res);
|
||||
}
|
||||
}
|
||||
InnerWatcherMsg::Unwatch { id, path, cb } => {
|
||||
// Check if we are tracking the path across any connection
|
||||
if let Some(cnt) = path_cnt.get(path.as_path()) {
|
||||
// Cycle through and remove all paths that match the given id and path,
|
||||
// capturing how many paths we removed
|
||||
let removed_cnt = {
|
||||
let old_len = registered_paths.len();
|
||||
registered_paths
|
||||
.retain(|p| p.id() != id || (p.path() != path && p.raw_path() != path));
|
||||
let new_len = registered_paths.len();
|
||||
old_len - new_len
|
||||
};
|
||||
|
||||
// 1. If we are now at zero cnt for our path, we want to actually unwatch the
|
||||
// path with our watcher
|
||||
// 2. If we removed nothing from our path list, we want to return an error
|
||||
// 3. Otherwise, we return okay because we succeeded
|
||||
if *cnt <= removed_cnt {
|
||||
let _ = cb.send(
|
||||
watcher
|
||||
.unwatch(&path)
|
||||
.map_err(|x| io::Error::new(io::ErrorKind::Other, x)),
|
||||
);
|
||||
} else if removed_cnt == 0 {
|
||||
// Send a failure as there was nothing to unwatch for this connection
|
||||
let _ = cb.send(Err(io::Error::new(
|
||||
io::ErrorKind::Other,
|
||||
format!("{:?} is not being watched", path),
|
||||
)));
|
||||
} else {
|
||||
// Send a success as we removed some paths
|
||||
let _ = cb.send(Ok(()));
|
||||
}
|
||||
} else {
|
||||
// Send a failure as there was nothing to unwatch
|
||||
let _ = cb.send(Err(io::Error::new(
|
||||
io::ErrorKind::Other,
|
||||
format!("{:?} is not being watched", path),
|
||||
)));
|
||||
}
|
||||
}
|
||||
InnerWatcherMsg::Event { ev } => {
|
||||
let kind = ChangeKind::from(ev.kind);
|
||||
|
||||
for registered_path in registered_paths.iter() {
|
||||
match registered_path.filter_and_send(kind, &ev.paths).await {
|
||||
Ok(_) => (),
|
||||
Err(x) => error!(
|
||||
"[Conn {}] Failed to forward changes to paths: {}",
|
||||
registered_path.id(),
|
||||
x
|
||||
),
|
||||
}
|
||||
}
|
||||
}
|
||||
InnerWatcherMsg::Error { err } => {
|
||||
let msg = err.to_string();
|
||||
error!("Watcher encountered an error {} for {:?}", msg, err.paths);
|
||||
|
||||
for registered_path in registered_paths.iter() {
|
||||
match registered_path
|
||||
.filter_and_send_error(&msg, &err.paths, !err.paths.is_empty())
|
||||
.await
|
||||
{
|
||||
Ok(_) => (),
|
||||
Err(x) => error!(
|
||||
"[Conn {}] Failed to forward changes to paths: {}",
|
||||
registered_path.id(),
|
||||
x
|
||||
),
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
@ -0,0 +1,212 @@
|
||||
use crate::{
|
||||
data::{Change, ChangeKind, ChangeKindSet, DistantResponseData, Error},
|
||||
ConnectionId,
|
||||
};
|
||||
use distant_net::Reply;
|
||||
use std::{
|
||||
fmt,
|
||||
hash::{Hash, Hasher},
|
||||
io,
|
||||
path::{Path, PathBuf},
|
||||
};
|
||||
|
||||
/// Represents a path registered with a watcher that includes relevant state including
|
||||
/// the ability to reply with
|
||||
pub struct RegisteredPath {
|
||||
/// Unique id tied to the path to distinguish it
|
||||
id: ConnectionId,
|
||||
|
||||
/// The raw path provided to the watcher, which is not canonicalized
|
||||
raw_path: PathBuf,
|
||||
|
||||
/// The canonicalized path at the time of providing to the watcher,
|
||||
/// as all paths must exist for a watcher, we use this to get the
|
||||
/// source of truth when watching
|
||||
path: PathBuf,
|
||||
|
||||
/// Whether or not the path was set to be recursive
|
||||
recursive: bool,
|
||||
|
||||
/// Specific filter for path (only the allowed change kinds are tracked)
|
||||
/// NOTE: This is a combination of only and except filters
|
||||
allowed: ChangeKindSet,
|
||||
|
||||
/// Used to send a reply through the connection watching this path
|
||||
reply: Box<dyn Reply<Data = DistantResponseData>>,
|
||||
}
|
||||
|
||||
impl fmt::Debug for RegisteredPath {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
f.debug_struct("RegisteredPath")
|
||||
.field("raw_path", &self.raw_path)
|
||||
.field("path", &self.path)
|
||||
.field("recursive", &self.recursive)
|
||||
.field("allowed", &self.allowed)
|
||||
.finish()
|
||||
}
|
||||
}
|
||||
|
||||
impl PartialEq for RegisteredPath {
|
||||
/// Checks for equality using the id, canonicalized path, and allowed change kinds
|
||||
fn eq(&self, other: &Self) -> bool {
|
||||
self.id == other.id && self.path == other.path && self.allowed == other.allowed
|
||||
}
|
||||
}
|
||||
|
||||
impl Eq for RegisteredPath {}
|
||||
|
||||
impl Hash for RegisteredPath {
|
||||
/// Hashes using the id, canonicalized path, and allowed change kinds
|
||||
fn hash<H: Hasher>(&self, state: &mut H) {
|
||||
self.id.hash(state);
|
||||
self.path.hash(state);
|
||||
self.allowed.hash(state);
|
||||
}
|
||||
}
|
||||
|
||||
impl RegisteredPath {
|
||||
/// Registers a new path to be watched (does not actually do any watching)
|
||||
pub async fn register(
|
||||
id: ConnectionId,
|
||||
path: impl Into<PathBuf>,
|
||||
recursive: bool,
|
||||
only: impl Into<ChangeKindSet>,
|
||||
except: impl Into<ChangeKindSet>,
|
||||
reply: Box<dyn Reply<Data = DistantResponseData>>,
|
||||
) -> io::Result<Self> {
|
||||
let raw_path = path.into();
|
||||
let path = tokio::fs::canonicalize(raw_path.as_path()).await?;
|
||||
let only = only.into();
|
||||
let except = except.into();
|
||||
|
||||
// Calculate the true list of kinds based on only and except filters
|
||||
let allowed = if only.is_empty() {
|
||||
ChangeKindSet::all() - except
|
||||
} else {
|
||||
only - except
|
||||
};
|
||||
|
||||
Ok(Self {
|
||||
id,
|
||||
raw_path,
|
||||
path,
|
||||
recursive,
|
||||
allowed,
|
||||
reply,
|
||||
})
|
||||
}
|
||||
|
||||
/// Represents a unique id to distinguish this path from other registrations
|
||||
/// of the same path
|
||||
pub fn id(&self) -> ConnectionId {
|
||||
self.id
|
||||
}
|
||||
|
||||
/// Represents the path provided during registration before canonicalization
|
||||
pub fn raw_path(&self) -> &Path {
|
||||
self.raw_path.as_path()
|
||||
}
|
||||
|
||||
/// Represents the canonicalized path used by watchers
|
||||
pub fn path(&self) -> &Path {
|
||||
self.path.as_path()
|
||||
}
|
||||
|
||||
/// Returns true if this path represents a recursive watcher path
|
||||
pub fn is_recursive(&self) -> bool {
|
||||
self.recursive
|
||||
}
|
||||
|
||||
/// Returns reference to set of [`ChangeKind`] that this path watches
|
||||
pub fn allowed(&self) -> &ChangeKindSet {
|
||||
&self.allowed
|
||||
}
|
||||
|
||||
/// Sends a reply for a change tied to this registered path, filtering
|
||||
/// out any paths that are not applicable
|
||||
///
|
||||
/// Returns true if message was sent, and false if not
|
||||
pub async fn filter_and_send<T>(&self, kind: ChangeKind, paths: T) -> io::Result<bool>
|
||||
where
|
||||
T: IntoIterator,
|
||||
T::Item: AsRef<Path>,
|
||||
{
|
||||
if !self.allowed().contains(&kind) {
|
||||
return Ok(false);
|
||||
}
|
||||
|
||||
let paths: Vec<PathBuf> = paths
|
||||
.into_iter()
|
||||
.filter(|p| self.applies_to_path(p.as_ref()))
|
||||
.map(|p| p.as_ref().to_path_buf())
|
||||
.collect();
|
||||
|
||||
if !paths.is_empty() {
|
||||
self.reply
|
||||
.send(DistantResponseData::Changed(Change { kind, paths }))
|
||||
.await
|
||||
.map(|_| true)
|
||||
} else {
|
||||
Ok(false)
|
||||
}
|
||||
}
|
||||
|
||||
/// Sends an error message and includes paths if provided, skipping sending the message if
|
||||
/// no paths match and `skip_if_no_paths` is true
|
||||
///
|
||||
/// Returns true if message was sent, and false if not
|
||||
pub async fn filter_and_send_error<T>(
|
||||
&self,
|
||||
msg: &str,
|
||||
paths: T,
|
||||
skip_if_no_paths: bool,
|
||||
) -> io::Result<bool>
|
||||
where
|
||||
T: IntoIterator,
|
||||
T::Item: AsRef<Path>,
|
||||
{
|
||||
let paths: Vec<PathBuf> = paths
|
||||
.into_iter()
|
||||
.filter(|p| self.applies_to_path(p.as_ref()))
|
||||
.map(|p| p.as_ref().to_path_buf())
|
||||
.collect();
|
||||
|
||||
if !paths.is_empty() || !skip_if_no_paths {
|
||||
self.reply
|
||||
.send(if paths.is_empty() {
|
||||
DistantResponseData::Error(Error::from(msg))
|
||||
} else {
|
||||
DistantResponseData::Error(Error::from(format!("{} about {:?}", msg, paths)))
|
||||
})
|
||||
.await
|
||||
.map(|_| true)
|
||||
} else {
|
||||
Ok(false)
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns true if this path applies to the given path.
|
||||
/// This is accomplished by checking if the path is contained
|
||||
/// within either the raw or canonicalized path of the watcher
|
||||
/// and ensures that recursion rules are respected
|
||||
pub fn applies_to_path(&self, path: &Path) -> bool {
|
||||
let check_path = |path: &Path| -> bool {
|
||||
let cnt = path.components().count();
|
||||
|
||||
// 0 means exact match from strip_prefix
|
||||
// 1 means that it was within immediate directory (fine for non-recursive)
|
||||
// 2+ means it needs to be recursive
|
||||
cnt < 2 || self.recursive
|
||||
};
|
||||
|
||||
match (
|
||||
path.strip_prefix(self.path()),
|
||||
path.strip_prefix(self.raw_path()),
|
||||
) {
|
||||
(Ok(p1), Ok(p2)) => check_path(p1) || check_path(p2),
|
||||
(Ok(p), Err(_)) => check_path(p),
|
||||
(Err(_), Ok(p)) => check_path(p),
|
||||
(Err(_), Err(_)) => false,
|
||||
}
|
||||
}
|
||||
}
|
@ -0,0 +1,29 @@
|
||||
use crate::{api::DistantMsg, data::DistantResponseData};
|
||||
use distant_net::Reply;
|
||||
use std::{future::Future, io, pin::Pin};
|
||||
|
||||
/// Wrapper around a reply that can be batch or single, converting
|
||||
/// a single data into the wrapped type
|
||||
pub struct DistantSingleReply(Box<dyn Reply<Data = DistantMsg<DistantResponseData>>>);
|
||||
|
||||
impl From<Box<dyn Reply<Data = DistantMsg<DistantResponseData>>>> for DistantSingleReply {
|
||||
fn from(reply: Box<dyn Reply<Data = DistantMsg<DistantResponseData>>>) -> Self {
|
||||
Self(reply)
|
||||
}
|
||||
}
|
||||
|
||||
impl Reply for DistantSingleReply {
|
||||
type Data = DistantResponseData;
|
||||
|
||||
fn send(&self, data: Self::Data) -> Pin<Box<dyn Future<Output = io::Result<()>> + Send + '_>> {
|
||||
self.0.send(DistantMsg::Single(data))
|
||||
}
|
||||
|
||||
fn blocking_send(&self, data: Self::Data) -> io::Result<()> {
|
||||
self.0.blocking_send(DistantMsg::Single(data))
|
||||
}
|
||||
|
||||
fn clone_reply(&self) -> Box<dyn Reply<Data = Self::Data>> {
|
||||
Box::new(Self(self.0.clone_reply()))
|
||||
}
|
||||
}
|
@ -0,0 +1,18 @@
|
||||
use crate::{DistantMsg, DistantRequestData, DistantResponseData};
|
||||
use distant_net::{Channel, Client};
|
||||
|
||||
mod ext;
|
||||
mod lsp;
|
||||
mod process;
|
||||
mod watcher;
|
||||
|
||||
/// Represents a [`Client`] that communicates using the distant protocol
|
||||
pub type DistantClient = Client<DistantMsg<DistantRequestData>, DistantMsg<DistantResponseData>>;
|
||||
|
||||
/// Represents a [`Channel`] that communicates using the distant protocol
|
||||
pub type DistantChannel = Channel<DistantMsg<DistantRequestData>, DistantMsg<DistantResponseData>>;
|
||||
|
||||
pub use ext::*;
|
||||
pub use lsp::*;
|
||||
pub use process::*;
|
||||
pub use watcher::*;
|
@ -0,0 +1,424 @@
|
||||
use crate::{
|
||||
client::{
|
||||
RemoteCommand, RemoteLspCommand, RemoteLspProcess, RemoteOutput, RemoteProcess, Watcher,
|
||||
},
|
||||
data::{
|
||||
ChangeKindSet, DirEntry, DistantRequestData, DistantResponseData, Environment,
|
||||
Error as Failure, Metadata, PtySize, SystemInfo,
|
||||
},
|
||||
DistantMsg,
|
||||
};
|
||||
use distant_net::{Channel, Request};
|
||||
use std::{future::Future, io, path::PathBuf, pin::Pin};
|
||||
|
||||
pub type AsyncReturn<'a, T, E = io::Error> =
|
||||
Pin<Box<dyn Future<Output = Result<T, E>> + Send + 'a>>;
|
||||
|
||||
fn mismatched_response() -> io::Error {
|
||||
io::Error::new(io::ErrorKind::Other, "Mismatched response")
|
||||
}
|
||||
|
||||
/// Provides convenience functions on top of a [`SessionChannel`]
|
||||
pub trait DistantChannelExt {
|
||||
/// Appends to a remote file using the data from a collection of bytes
|
||||
fn append_file(
|
||||
&mut self,
|
||||
path: impl Into<PathBuf>,
|
||||
data: impl Into<Vec<u8>>,
|
||||
) -> AsyncReturn<'_, ()>;
|
||||
|
||||
/// Appends to a remote file using the data from a string
|
||||
fn append_file_text(
|
||||
&mut self,
|
||||
path: impl Into<PathBuf>,
|
||||
data: impl Into<String>,
|
||||
) -> AsyncReturn<'_, ()>;
|
||||
|
||||
/// Copies a remote file or directory from src to dst
|
||||
fn copy(&mut self, src: impl Into<PathBuf>, dst: impl Into<PathBuf>) -> AsyncReturn<'_, ()>;
|
||||
|
||||
/// Creates a remote directory, optionally creating all parent components if specified
|
||||
fn create_dir(&mut self, path: impl Into<PathBuf>, all: bool) -> AsyncReturn<'_, ()>;
|
||||
|
||||
fn exists(&mut self, path: impl Into<PathBuf>) -> AsyncReturn<'_, bool>;
|
||||
|
||||
/// Retrieves metadata about a path on a remote machine
|
||||
fn metadata(
|
||||
&mut self,
|
||||
path: impl Into<PathBuf>,
|
||||
canonicalize: bool,
|
||||
resolve_file_type: bool,
|
||||
) -> AsyncReturn<'_, Metadata>;
|
||||
|
||||
/// Reads entries from a directory, returning a tuple of directory entries and failures
|
||||
fn read_dir(
|
||||
&mut self,
|
||||
path: impl Into<PathBuf>,
|
||||
depth: usize,
|
||||
absolute: bool,
|
||||
canonicalize: bool,
|
||||
include_root: bool,
|
||||
) -> AsyncReturn<'_, (Vec<DirEntry>, Vec<Failure>)>;
|
||||
|
||||
/// Reads a remote file as a collection of bytes
|
||||
fn read_file(&mut self, path: impl Into<PathBuf>) -> AsyncReturn<'_, Vec<u8>>;
|
||||
|
||||
/// Returns a remote file as a string
|
||||
fn read_file_text(&mut self, path: impl Into<PathBuf>) -> AsyncReturn<'_, String>;
|
||||
|
||||
/// Removes a remote file or directory, supporting removal of non-empty directories if
|
||||
/// force is true
|
||||
fn remove(&mut self, path: impl Into<PathBuf>, force: bool) -> AsyncReturn<'_, ()>;
|
||||
|
||||
/// Renames a remote file or directory from src to dst
|
||||
fn rename(&mut self, src: impl Into<PathBuf>, dst: impl Into<PathBuf>) -> AsyncReturn<'_, ()>;
|
||||
|
||||
/// Watches a remote file or directory
|
||||
fn watch(
|
||||
&mut self,
|
||||
path: impl Into<PathBuf>,
|
||||
recursive: bool,
|
||||
only: impl Into<ChangeKindSet>,
|
||||
except: impl Into<ChangeKindSet>,
|
||||
) -> AsyncReturn<'_, Watcher>;
|
||||
|
||||
/// Unwatches a remote file or directory
|
||||
fn unwatch(&mut self, path: impl Into<PathBuf>) -> AsyncReturn<'_, ()>;
|
||||
|
||||
/// Spawns a process on the remote machine
|
||||
fn spawn(
|
||||
&mut self,
|
||||
cmd: impl Into<String>,
|
||||
environment: Environment,
|
||||
current_dir: Option<PathBuf>,
|
||||
persist: bool,
|
||||
pty: Option<PtySize>,
|
||||
) -> AsyncReturn<'_, RemoteProcess>;
|
||||
|
||||
/// Spawns an LSP process on the remote machine
|
||||
fn spawn_lsp(
|
||||
&mut self,
|
||||
cmd: impl Into<String>,
|
||||
environment: Environment,
|
||||
current_dir: Option<PathBuf>,
|
||||
persist: bool,
|
||||
pty: Option<PtySize>,
|
||||
) -> AsyncReturn<'_, RemoteLspProcess>;
|
||||
|
||||
/// Spawns a process on the remote machine and wait for it to complete
|
||||
fn output(
|
||||
&mut self,
|
||||
cmd: impl Into<String>,
|
||||
environment: Environment,
|
||||
current_dir: Option<PathBuf>,
|
||||
pty: Option<PtySize>,
|
||||
) -> AsyncReturn<'_, RemoteOutput>;
|
||||
|
||||
/// Retrieves information about the remote system
|
||||
fn system_info(&mut self) -> AsyncReturn<'_, SystemInfo>;
|
||||
|
||||
/// Writes a remote file with the data from a collection of bytes
|
||||
fn write_file(
|
||||
&mut self,
|
||||
path: impl Into<PathBuf>,
|
||||
data: impl Into<Vec<u8>>,
|
||||
) -> AsyncReturn<'_, ()>;
|
||||
|
||||
/// Writes a remote file with the data from a string
|
||||
fn write_file_text(
|
||||
&mut self,
|
||||
path: impl Into<PathBuf>,
|
||||
data: impl Into<String>,
|
||||
) -> AsyncReturn<'_, ()>;
|
||||
}
|
||||
|
||||
macro_rules! make_body {
|
||||
($self:expr, $data:expr, @ok) => {
|
||||
make_body!($self, $data, |data| {
|
||||
match data {
|
||||
DistantResponseData::Ok => Ok(()),
|
||||
DistantResponseData::Error(x) => Err(io::Error::from(x)),
|
||||
_ => Err(mismatched_response()),
|
||||
}
|
||||
})
|
||||
};
|
||||
|
||||
($self:expr, $data:expr, $and_then:expr) => {{
|
||||
let req = Request::new(DistantMsg::Single($data));
|
||||
Box::pin(async move {
|
||||
$self
|
||||
.send(req)
|
||||
.await
|
||||
.and_then(|res| match res.payload {
|
||||
DistantMsg::Single(x) => Ok(x),
|
||||
_ => Err(mismatched_response()),
|
||||
})
|
||||
.and_then($and_then)
|
||||
})
|
||||
}};
|
||||
}
|
||||
|
||||
impl DistantChannelExt
|
||||
for Channel<DistantMsg<DistantRequestData>, DistantMsg<DistantResponseData>>
|
||||
{
|
||||
fn append_file(
|
||||
&mut self,
|
||||
path: impl Into<PathBuf>,
|
||||
data: impl Into<Vec<u8>>,
|
||||
) -> AsyncReturn<'_, ()> {
|
||||
make_body!(
|
||||
self,
|
||||
DistantRequestData::FileAppend { path: path.into(), data: data.into() },
|
||||
@ok
|
||||
)
|
||||
}
|
||||
|
||||
fn append_file_text(
|
||||
&mut self,
|
||||
path: impl Into<PathBuf>,
|
||||
data: impl Into<String>,
|
||||
) -> AsyncReturn<'_, ()> {
|
||||
make_body!(
|
||||
self,
|
||||
DistantRequestData::FileAppendText { path: path.into(), text: data.into() },
|
||||
@ok
|
||||
)
|
||||
}
|
||||
|
||||
fn copy(&mut self, src: impl Into<PathBuf>, dst: impl Into<PathBuf>) -> AsyncReturn<'_, ()> {
|
||||
make_body!(
|
||||
self,
|
||||
DistantRequestData::Copy { src: src.into(), dst: dst.into() },
|
||||
@ok
|
||||
)
|
||||
}
|
||||
|
||||
fn create_dir(&mut self, path: impl Into<PathBuf>, all: bool) -> AsyncReturn<'_, ()> {
|
||||
make_body!(
|
||||
self,
|
||||
DistantRequestData::DirCreate { path: path.into(), all },
|
||||
@ok
|
||||
)
|
||||
}
|
||||
|
||||
fn exists(&mut self, path: impl Into<PathBuf>) -> AsyncReturn<'_, bool> {
|
||||
make_body!(
|
||||
self,
|
||||
DistantRequestData::Exists { path: path.into() },
|
||||
|data| match data {
|
||||
DistantResponseData::Exists { value } => Ok(value),
|
||||
DistantResponseData::Error(x) => Err(io::Error::from(x)),
|
||||
_ => Err(mismatched_response()),
|
||||
}
|
||||
)
|
||||
}
|
||||
|
||||
fn metadata(
|
||||
&mut self,
|
||||
path: impl Into<PathBuf>,
|
||||
canonicalize: bool,
|
||||
resolve_file_type: bool,
|
||||
) -> AsyncReturn<'_, Metadata> {
|
||||
make_body!(
|
||||
self,
|
||||
DistantRequestData::Metadata {
|
||||
path: path.into(),
|
||||
canonicalize,
|
||||
resolve_file_type
|
||||
},
|
||||
|data| match data {
|
||||
DistantResponseData::Metadata(x) => Ok(x),
|
||||
DistantResponseData::Error(x) => Err(io::Error::from(x)),
|
||||
_ => Err(mismatched_response()),
|
||||
}
|
||||
)
|
||||
}
|
||||
|
||||
fn read_dir(
|
||||
&mut self,
|
||||
path: impl Into<PathBuf>,
|
||||
depth: usize,
|
||||
absolute: bool,
|
||||
canonicalize: bool,
|
||||
include_root: bool,
|
||||
) -> AsyncReturn<'_, (Vec<DirEntry>, Vec<Failure>)> {
|
||||
make_body!(
|
||||
self,
|
||||
DistantRequestData::DirRead {
|
||||
path: path.into(),
|
||||
depth,
|
||||
absolute,
|
||||
canonicalize,
|
||||
include_root
|
||||
},
|
||||
|data| match data {
|
||||
DistantResponseData::DirEntries { entries, errors } => Ok((entries, errors)),
|
||||
DistantResponseData::Error(x) => Err(io::Error::from(x)),
|
||||
_ => Err(mismatched_response()),
|
||||
}
|
||||
)
|
||||
}
|
||||
|
||||
fn read_file(&mut self, path: impl Into<PathBuf>) -> AsyncReturn<'_, Vec<u8>> {
|
||||
make_body!(
|
||||
self,
|
||||
DistantRequestData::FileRead { path: path.into() },
|
||||
|data| match data {
|
||||
DistantResponseData::Blob { data } => Ok(data),
|
||||
DistantResponseData::Error(x) => Err(io::Error::from(x)),
|
||||
_ => Err(mismatched_response()),
|
||||
}
|
||||
)
|
||||
}
|
||||
|
||||
fn read_file_text(&mut self, path: impl Into<PathBuf>) -> AsyncReturn<'_, String> {
|
||||
make_body!(
|
||||
self,
|
||||
DistantRequestData::FileReadText { path: path.into() },
|
||||
|data| match data {
|
||||
DistantResponseData::Text { data } => Ok(data),
|
||||
DistantResponseData::Error(x) => Err(io::Error::from(x)),
|
||||
_ => Err(mismatched_response()),
|
||||
}
|
||||
)
|
||||
}
|
||||
|
||||
fn remove(&mut self, path: impl Into<PathBuf>, force: bool) -> AsyncReturn<'_, ()> {
|
||||
make_body!(
|
||||
self,
|
||||
DistantRequestData::Remove { path: path.into(), force },
|
||||
@ok
|
||||
)
|
||||
}
|
||||
|
||||
fn rename(&mut self, src: impl Into<PathBuf>, dst: impl Into<PathBuf>) -> AsyncReturn<'_, ()> {
|
||||
make_body!(
|
||||
self,
|
||||
DistantRequestData::Rename { src: src.into(), dst: dst.into() },
|
||||
@ok
|
||||
)
|
||||
}
|
||||
|
||||
fn watch(
|
||||
&mut self,
|
||||
path: impl Into<PathBuf>,
|
||||
recursive: bool,
|
||||
only: impl Into<ChangeKindSet>,
|
||||
except: impl Into<ChangeKindSet>,
|
||||
) -> AsyncReturn<'_, Watcher> {
|
||||
let path = path.into();
|
||||
let only = only.into();
|
||||
let except = except.into();
|
||||
Box::pin(async move { Watcher::watch(self.clone(), path, recursive, only, except).await })
|
||||
}
|
||||
|
||||
fn unwatch(&mut self, path: impl Into<PathBuf>) -> AsyncReturn<'_, ()> {
|
||||
fn inner_unwatch(
|
||||
channel: &mut Channel<DistantMsg<DistantRequestData>, DistantMsg<DistantResponseData>>,
|
||||
path: impl Into<PathBuf>,
|
||||
) -> AsyncReturn<'_, ()> {
|
||||
make_body!(
|
||||
channel,
|
||||
DistantRequestData::Unwatch { path: path.into() },
|
||||
@ok
|
||||
)
|
||||
}
|
||||
|
||||
let path = path.into();
|
||||
|
||||
Box::pin(async move { inner_unwatch(self, path).await })
|
||||
}
|
||||
|
||||
fn spawn(
|
||||
&mut self,
|
||||
cmd: impl Into<String>,
|
||||
environment: Environment,
|
||||
current_dir: Option<PathBuf>,
|
||||
persist: bool,
|
||||
pty: Option<PtySize>,
|
||||
) -> AsyncReturn<'_, RemoteProcess> {
|
||||
let cmd = cmd.into();
|
||||
Box::pin(async move {
|
||||
RemoteCommand::new()
|
||||
.environment(environment)
|
||||
.current_dir(current_dir)
|
||||
.persist(persist)
|
||||
.pty(pty)
|
||||
.spawn(self.clone(), cmd)
|
||||
.await
|
||||
})
|
||||
}
|
||||
|
||||
fn spawn_lsp(
|
||||
&mut self,
|
||||
cmd: impl Into<String>,
|
||||
environment: Environment,
|
||||
current_dir: Option<PathBuf>,
|
||||
persist: bool,
|
||||
pty: Option<PtySize>,
|
||||
) -> AsyncReturn<'_, RemoteLspProcess> {
|
||||
let cmd = cmd.into();
|
||||
Box::pin(async move {
|
||||
RemoteLspCommand::new()
|
||||
.environment(environment)
|
||||
.current_dir(current_dir)
|
||||
.persist(persist)
|
||||
.pty(pty)
|
||||
.spawn(self.clone(), cmd)
|
||||
.await
|
||||
})
|
||||
}
|
||||
|
||||
fn output(
|
||||
&mut self,
|
||||
cmd: impl Into<String>,
|
||||
environment: Environment,
|
||||
current_dir: Option<PathBuf>,
|
||||
pty: Option<PtySize>,
|
||||
) -> AsyncReturn<'_, RemoteOutput> {
|
||||
let cmd = cmd.into();
|
||||
Box::pin(async move {
|
||||
RemoteCommand::new()
|
||||
.environment(environment)
|
||||
.current_dir(current_dir)
|
||||
.persist(false)
|
||||
.pty(pty)
|
||||
.spawn(self.clone(), cmd)
|
||||
.await?
|
||||
.output()
|
||||
.await
|
||||
})
|
||||
}
|
||||
|
||||
fn system_info(&mut self) -> AsyncReturn<'_, SystemInfo> {
|
||||
make_body!(self, DistantRequestData::SystemInfo {}, |data| match data {
|
||||
DistantResponseData::SystemInfo(x) => Ok(x),
|
||||
DistantResponseData::Error(x) => Err(io::Error::from(x)),
|
||||
_ => Err(mismatched_response()),
|
||||
})
|
||||
}
|
||||
|
||||
fn write_file(
|
||||
&mut self,
|
||||
path: impl Into<PathBuf>,
|
||||
data: impl Into<Vec<u8>>,
|
||||
) -> AsyncReturn<'_, ()> {
|
||||
make_body!(
|
||||
self,
|
||||
DistantRequestData::FileWrite { path: path.into(), data: data.into() },
|
||||
@ok
|
||||
)
|
||||
}
|
||||
|
||||
fn write_file_text(
|
||||
&mut self,
|
||||
path: impl Into<PathBuf>,
|
||||
data: impl Into<String>,
|
||||
) -> AsyncReturn<'_, ()> {
|
||||
make_body!(
|
||||
self,
|
||||
DistantRequestData::FileWriteText { path: path.into(), text: data.into() },
|
||||
@ok
|
||||
)
|
||||
}
|
||||
}
|
@ -1,10 +0,0 @@
|
||||
mod lsp;
|
||||
mod process;
|
||||
mod session;
|
||||
mod utils;
|
||||
mod watcher;
|
||||
|
||||
pub use lsp::*;
|
||||
pub use process::*;
|
||||
pub use session::*;
|
||||
pub use watcher::*;
|
File diff suppressed because it is too large
Load Diff
@ -1,514 +0,0 @@
|
||||
use crate::{
|
||||
client::{
|
||||
RemoteLspProcess, RemoteProcess, RemoteProcessError, SessionChannel, UnwatchError,
|
||||
WatchError, Watcher,
|
||||
},
|
||||
data::{
|
||||
ChangeKindSet, DirEntry, Error as Failure, Metadata, PtySize, Request, RequestData,
|
||||
ResponseData, SystemInfo,
|
||||
},
|
||||
net::TransportError,
|
||||
};
|
||||
use derive_more::{Display, Error, From};
|
||||
use std::{future::Future, path::PathBuf, pin::Pin};
|
||||
|
||||
/// Represents an error that can occur related to convenience functions tied to a
|
||||
/// [`SessionChannel`] through [`SessionChannelExt`]
|
||||
#[derive(Debug, Display, Error, From)]
|
||||
pub enum SessionChannelExtError {
|
||||
/// Occurs when the remote action fails
|
||||
Failure(#[error(not(source))] Failure),
|
||||
|
||||
/// Occurs when a transport error is encountered
|
||||
TransportError(TransportError),
|
||||
|
||||
/// Occurs when receiving a response that was not expected
|
||||
MismatchedResponse,
|
||||
}
|
||||
|
||||
pub type AsyncReturn<'a, T, E = SessionChannelExtError> =
|
||||
Pin<Box<dyn Future<Output = Result<T, E>> + Send + 'a>>;
|
||||
|
||||
/// Provides convenience functions on top of a [`SessionChannel`]
|
||||
pub trait SessionChannelExt {
|
||||
/// Appends to a remote file using the data from a collection of bytes
|
||||
fn append_file(
|
||||
&mut self,
|
||||
tenant: impl Into<String>,
|
||||
path: impl Into<PathBuf>,
|
||||
data: impl Into<Vec<u8>>,
|
||||
) -> AsyncReturn<'_, ()>;
|
||||
|
||||
/// Appends to a remote file using the data from a string
|
||||
fn append_file_text(
|
||||
&mut self,
|
||||
tenant: impl Into<String>,
|
||||
path: impl Into<PathBuf>,
|
||||
data: impl Into<String>,
|
||||
) -> AsyncReturn<'_, ()>;
|
||||
|
||||
/// Copies a remote file or directory from src to dst
|
||||
fn copy(
|
||||
&mut self,
|
||||
tenant: impl Into<String>,
|
||||
src: impl Into<PathBuf>,
|
||||
dst: impl Into<PathBuf>,
|
||||
) -> AsyncReturn<'_, ()>;
|
||||
|
||||
/// Creates a remote directory, optionally creating all parent components if specified
|
||||
fn create_dir(
|
||||
&mut self,
|
||||
tenant: impl Into<String>,
|
||||
path: impl Into<PathBuf>,
|
||||
all: bool,
|
||||
) -> AsyncReturn<'_, ()>;
|
||||
|
||||
/// Checks if a path exists on a remote machine
|
||||
fn exists(
|
||||
&mut self,
|
||||
tenant: impl Into<String>,
|
||||
path: impl Into<PathBuf>,
|
||||
) -> AsyncReturn<'_, bool>;
|
||||
|
||||
/// Retrieves metadata about a path on a remote machine
|
||||
fn metadata(
|
||||
&mut self,
|
||||
tenant: impl Into<String>,
|
||||
path: impl Into<PathBuf>,
|
||||
canonicalize: bool,
|
||||
resolve_file_type: bool,
|
||||
) -> AsyncReturn<'_, Metadata>;
|
||||
|
||||
/// Reads entries from a directory, returning a tuple of directory entries and failures
|
||||
fn read_dir(
|
||||
&mut self,
|
||||
tenant: impl Into<String>,
|
||||
path: impl Into<PathBuf>,
|
||||
depth: usize,
|
||||
absolute: bool,
|
||||
canonicalize: bool,
|
||||
include_root: bool,
|
||||
) -> AsyncReturn<'_, (Vec<DirEntry>, Vec<Failure>)>;
|
||||
|
||||
/// Reads a remote file as a collection of bytes
|
||||
fn read_file(
|
||||
&mut self,
|
||||
tenant: impl Into<String>,
|
||||
path: impl Into<PathBuf>,
|
||||
) -> AsyncReturn<'_, Vec<u8>>;
|
||||
|
||||
/// Returns a remote file as a string
|
||||
fn read_file_text(
|
||||
&mut self,
|
||||
tenant: impl Into<String>,
|
||||
path: impl Into<PathBuf>,
|
||||
) -> AsyncReturn<'_, String>;
|
||||
|
||||
/// Removes a remote file or directory, supporting removal of non-empty directories if
|
||||
/// force is true
|
||||
fn remove(
|
||||
&mut self,
|
||||
tenant: impl Into<String>,
|
||||
path: impl Into<PathBuf>,
|
||||
force: bool,
|
||||
) -> AsyncReturn<'_, ()>;
|
||||
|
||||
/// Renames a remote file or directory from src to dst
|
||||
fn rename(
|
||||
&mut self,
|
||||
tenant: impl Into<String>,
|
||||
src: impl Into<PathBuf>,
|
||||
dst: impl Into<PathBuf>,
|
||||
) -> AsyncReturn<'_, ()>;
|
||||
|
||||
/// Watches a remote file or directory
|
||||
fn watch(
|
||||
&mut self,
|
||||
tenant: impl Into<String>,
|
||||
path: impl Into<PathBuf>,
|
||||
recursive: bool,
|
||||
only: impl Into<ChangeKindSet>,
|
||||
except: impl Into<ChangeKindSet>,
|
||||
) -> AsyncReturn<'_, Watcher, WatchError>;
|
||||
|
||||
/// Unwatches a remote file or directory
|
||||
fn unwatch(
|
||||
&mut self,
|
||||
tenant: impl Into<String>,
|
||||
path: impl Into<PathBuf>,
|
||||
) -> AsyncReturn<'_, (), UnwatchError>;
|
||||
|
||||
/// Spawns a process on the remote machine
|
||||
fn spawn(
|
||||
&mut self,
|
||||
tenant: impl Into<String>,
|
||||
cmd: impl Into<String>,
|
||||
args: Vec<impl Into<String>>,
|
||||
persist: bool,
|
||||
pty: Option<PtySize>,
|
||||
) -> AsyncReturn<'_, RemoteProcess, RemoteProcessError>;
|
||||
|
||||
/// Spawns an LSP process on the remote machine
|
||||
fn spawn_lsp(
|
||||
&mut self,
|
||||
tenant: impl Into<String>,
|
||||
cmd: impl Into<String>,
|
||||
args: Vec<impl Into<String>>,
|
||||
persist: bool,
|
||||
pty: Option<PtySize>,
|
||||
) -> AsyncReturn<'_, RemoteLspProcess, RemoteProcessError>;
|
||||
|
||||
/// Retrieves information about the remote system
|
||||
fn system_info(&mut self, tenant: impl Into<String>) -> AsyncReturn<'_, SystemInfo>;
|
||||
|
||||
/// Writes a remote file with the data from a collection of bytes
|
||||
fn write_file(
|
||||
&mut self,
|
||||
tenant: impl Into<String>,
|
||||
path: impl Into<PathBuf>,
|
||||
data: impl Into<Vec<u8>>,
|
||||
) -> AsyncReturn<'_, ()>;
|
||||
|
||||
/// Writes a remote file with the data from a string
|
||||
fn write_file_text(
|
||||
&mut self,
|
||||
tenant: impl Into<String>,
|
||||
path: impl Into<PathBuf>,
|
||||
data: impl Into<String>,
|
||||
) -> AsyncReturn<'_, ()>;
|
||||
}
|
||||
|
||||
macro_rules! make_body {
|
||||
($self:expr, $tenant:expr, $data:expr, @ok) => {
|
||||
make_body!($self, $tenant, $data, |data| {
|
||||
match data {
|
||||
ResponseData::Ok => Ok(()),
|
||||
ResponseData::Error(x) => Err(SessionChannelExtError::Failure(x)),
|
||||
_ => Err(SessionChannelExtError::MismatchedResponse),
|
||||
}
|
||||
})
|
||||
};
|
||||
|
||||
($self:expr, $tenant:expr, $data:expr, $and_then:expr) => {{
|
||||
let req = Request::new($tenant, vec![$data]);
|
||||
Box::pin(async move {
|
||||
$self
|
||||
.send(req)
|
||||
.await
|
||||
.map_err(SessionChannelExtError::from)
|
||||
.and_then(|res| {
|
||||
if res.payload.len() == 1 {
|
||||
Ok(res.payload.into_iter().next().unwrap())
|
||||
} else {
|
||||
Err(SessionChannelExtError::MismatchedResponse)
|
||||
}
|
||||
})
|
||||
.and_then($and_then)
|
||||
})
|
||||
}};
|
||||
}
|
||||
|
||||
impl SessionChannelExt for SessionChannel {
|
||||
fn append_file(
|
||||
&mut self,
|
||||
tenant: impl Into<String>,
|
||||
path: impl Into<PathBuf>,
|
||||
data: impl Into<Vec<u8>>,
|
||||
) -> AsyncReturn<'_, ()> {
|
||||
make_body!(
|
||||
self,
|
||||
tenant,
|
||||
RequestData::FileAppend { path: path.into(), data: data.into() },
|
||||
@ok
|
||||
)
|
||||
}
|
||||
|
||||
fn append_file_text(
|
||||
&mut self,
|
||||
tenant: impl Into<String>,
|
||||
path: impl Into<PathBuf>,
|
||||
data: impl Into<String>,
|
||||
) -> AsyncReturn<'_, ()> {
|
||||
make_body!(
|
||||
self,
|
||||
tenant,
|
||||
RequestData::FileAppendText { path: path.into(), text: data.into() },
|
||||
@ok
|
||||
)
|
||||
}
|
||||
|
||||
fn copy(
|
||||
&mut self,
|
||||
tenant: impl Into<String>,
|
||||
src: impl Into<PathBuf>,
|
||||
dst: impl Into<PathBuf>,
|
||||
) -> AsyncReturn<'_, ()> {
|
||||
make_body!(
|
||||
self,
|
||||
tenant,
|
||||
RequestData::Copy { src: src.into(), dst: dst.into() },
|
||||
@ok
|
||||
)
|
||||
}
|
||||
|
||||
fn create_dir(
|
||||
&mut self,
|
||||
tenant: impl Into<String>,
|
||||
path: impl Into<PathBuf>,
|
||||
all: bool,
|
||||
) -> AsyncReturn<'_, ()> {
|
||||
make_body!(
|
||||
self,
|
||||
tenant,
|
||||
RequestData::DirCreate { path: path.into(), all },
|
||||
@ok
|
||||
)
|
||||
}
|
||||
|
||||
fn exists(
|
||||
&mut self,
|
||||
tenant: impl Into<String>,
|
||||
path: impl Into<PathBuf>,
|
||||
) -> AsyncReturn<'_, bool> {
|
||||
make_body!(
|
||||
self,
|
||||
tenant,
|
||||
RequestData::Exists { path: path.into() },
|
||||
|data| match data {
|
||||
ResponseData::Exists { value } => Ok(value),
|
||||
ResponseData::Error(x) => Err(SessionChannelExtError::Failure(x)),
|
||||
_ => Err(SessionChannelExtError::MismatchedResponse),
|
||||
}
|
||||
)
|
||||
}
|
||||
|
||||
fn metadata(
|
||||
&mut self,
|
||||
tenant: impl Into<String>,
|
||||
path: impl Into<PathBuf>,
|
||||
canonicalize: bool,
|
||||
resolve_file_type: bool,
|
||||
) -> AsyncReturn<'_, Metadata> {
|
||||
make_body!(
|
||||
self,
|
||||
tenant,
|
||||
RequestData::Metadata {
|
||||
path: path.into(),
|
||||
canonicalize,
|
||||
resolve_file_type
|
||||
},
|
||||
|data| match data {
|
||||
ResponseData::Metadata(x) => Ok(x),
|
||||
ResponseData::Error(x) => Err(SessionChannelExtError::Failure(x)),
|
||||
_ => Err(SessionChannelExtError::MismatchedResponse),
|
||||
}
|
||||
)
|
||||
}
|
||||
|
||||
fn read_dir(
|
||||
&mut self,
|
||||
tenant: impl Into<String>,
|
||||
path: impl Into<PathBuf>,
|
||||
depth: usize,
|
||||
absolute: bool,
|
||||
canonicalize: bool,
|
||||
include_root: bool,
|
||||
) -> AsyncReturn<'_, (Vec<DirEntry>, Vec<Failure>)> {
|
||||
make_body!(
|
||||
self,
|
||||
tenant,
|
||||
RequestData::DirRead {
|
||||
path: path.into(),
|
||||
depth,
|
||||
absolute,
|
||||
canonicalize,
|
||||
include_root
|
||||
},
|
||||
|data| match data {
|
||||
ResponseData::DirEntries { entries, errors } => Ok((entries, errors)),
|
||||
ResponseData::Error(x) => Err(SessionChannelExtError::Failure(x)),
|
||||
_ => Err(SessionChannelExtError::MismatchedResponse),
|
||||
}
|
||||
)
|
||||
}
|
||||
|
||||
fn read_file(
|
||||
&mut self,
|
||||
tenant: impl Into<String>,
|
||||
path: impl Into<PathBuf>,
|
||||
) -> AsyncReturn<'_, Vec<u8>> {
|
||||
make_body!(
|
||||
self,
|
||||
tenant,
|
||||
RequestData::FileRead { path: path.into() },
|
||||
|data| match data {
|
||||
ResponseData::Blob { data } => Ok(data),
|
||||
ResponseData::Error(x) => Err(SessionChannelExtError::Failure(x)),
|
||||
_ => Err(SessionChannelExtError::MismatchedResponse),
|
||||
}
|
||||
)
|
||||
}
|
||||
|
||||
fn read_file_text(
|
||||
&mut self,
|
||||
tenant: impl Into<String>,
|
||||
path: impl Into<PathBuf>,
|
||||
) -> AsyncReturn<'_, String> {
|
||||
make_body!(
|
||||
self,
|
||||
tenant,
|
||||
RequestData::FileReadText { path: path.into() },
|
||||
|data| match data {
|
||||
ResponseData::Text { data } => Ok(data),
|
||||
ResponseData::Error(x) => Err(SessionChannelExtError::Failure(x)),
|
||||
_ => Err(SessionChannelExtError::MismatchedResponse),
|
||||
}
|
||||
)
|
||||
}
|
||||
|
||||
fn remove(
|
||||
&mut self,
|
||||
tenant: impl Into<String>,
|
||||
path: impl Into<PathBuf>,
|
||||
force: bool,
|
||||
) -> AsyncReturn<'_, ()> {
|
||||
make_body!(
|
||||
self,
|
||||
tenant,
|
||||
RequestData::Remove { path: path.into(), force },
|
||||
@ok
|
||||
)
|
||||
}
|
||||
|
||||
fn rename(
|
||||
&mut self,
|
||||
tenant: impl Into<String>,
|
||||
src: impl Into<PathBuf>,
|
||||
dst: impl Into<PathBuf>,
|
||||
) -> AsyncReturn<'_, ()> {
|
||||
make_body!(
|
||||
self,
|
||||
tenant,
|
||||
RequestData::Rename { src: src.into(), dst: dst.into() },
|
||||
@ok
|
||||
)
|
||||
}
|
||||
|
||||
fn watch(
|
||||
&mut self,
|
||||
tenant: impl Into<String>,
|
||||
path: impl Into<PathBuf>,
|
||||
recursive: bool,
|
||||
only: impl Into<ChangeKindSet>,
|
||||
except: impl Into<ChangeKindSet>,
|
||||
) -> AsyncReturn<'_, Watcher, WatchError> {
|
||||
let tenant = tenant.into();
|
||||
let path = path.into();
|
||||
let only = only.into();
|
||||
let except = except.into();
|
||||
Box::pin(async move {
|
||||
Watcher::watch(tenant, self.clone(), path, recursive, only, except).await
|
||||
})
|
||||
}
|
||||
|
||||
fn unwatch(
|
||||
&mut self,
|
||||
tenant: impl Into<String>,
|
||||
path: impl Into<PathBuf>,
|
||||
) -> AsyncReturn<'_, (), UnwatchError> {
|
||||
fn inner_unwatch(
|
||||
channel: &mut SessionChannel,
|
||||
tenant: impl Into<String>,
|
||||
path: impl Into<PathBuf>,
|
||||
) -> AsyncReturn<'_, ()> {
|
||||
make_body!(
|
||||
channel,
|
||||
tenant,
|
||||
RequestData::Unwatch { path: path.into() },
|
||||
@ok
|
||||
)
|
||||
}
|
||||
|
||||
let tenant = tenant.into();
|
||||
let path = path.into();
|
||||
|
||||
Box::pin(async move {
|
||||
inner_unwatch(self, tenant, path)
|
||||
.await
|
||||
.map_err(UnwatchError::from)
|
||||
})
|
||||
}
|
||||
|
||||
fn spawn(
|
||||
&mut self,
|
||||
tenant: impl Into<String>,
|
||||
cmd: impl Into<String>,
|
||||
args: Vec<impl Into<String>>,
|
||||
persist: bool,
|
||||
pty: Option<PtySize>,
|
||||
) -> AsyncReturn<'_, RemoteProcess, RemoteProcessError> {
|
||||
let tenant = tenant.into();
|
||||
let cmd = cmd.into();
|
||||
let args = args.into_iter().map(Into::into).collect();
|
||||
Box::pin(async move {
|
||||
RemoteProcess::spawn(tenant, self.clone(), cmd, args, persist, pty).await
|
||||
})
|
||||
}
|
||||
|
||||
fn spawn_lsp(
|
||||
&mut self,
|
||||
tenant: impl Into<String>,
|
||||
cmd: impl Into<String>,
|
||||
args: Vec<impl Into<String>>,
|
||||
persist: bool,
|
||||
pty: Option<PtySize>,
|
||||
) -> AsyncReturn<'_, RemoteLspProcess, RemoteProcessError> {
|
||||
let tenant = tenant.into();
|
||||
let cmd = cmd.into();
|
||||
let args = args.into_iter().map(Into::into).collect();
|
||||
Box::pin(async move {
|
||||
RemoteLspProcess::spawn(tenant, self.clone(), cmd, args, persist, pty).await
|
||||
})
|
||||
}
|
||||
|
||||
fn system_info(&mut self, tenant: impl Into<String>) -> AsyncReturn<'_, SystemInfo> {
|
||||
make_body!(
|
||||
self,
|
||||
tenant,
|
||||
RequestData::SystemInfo {},
|
||||
|data| match data {
|
||||
ResponseData::SystemInfo(x) => Ok(x),
|
||||
ResponseData::Error(x) => Err(SessionChannelExtError::Failure(x)),
|
||||
_ => Err(SessionChannelExtError::MismatchedResponse),
|
||||
}
|
||||
)
|
||||
}
|
||||
|
||||
fn write_file(
|
||||
&mut self,
|
||||
tenant: impl Into<String>,
|
||||
path: impl Into<PathBuf>,
|
||||
data: impl Into<Vec<u8>>,
|
||||
) -> AsyncReturn<'_, ()> {
|
||||
make_body!(
|
||||
self,
|
||||
tenant,
|
||||
RequestData::FileWrite { path: path.into(), data: data.into() },
|
||||
@ok
|
||||
)
|
||||
}
|
||||
|
||||
fn write_file_text(
|
||||
&mut self,
|
||||
tenant: impl Into<String>,
|
||||
path: impl Into<PathBuf>,
|
||||
data: impl Into<String>,
|
||||
) -> AsyncReturn<'_, ()> {
|
||||
make_body!(
|
||||
self,
|
||||
tenant,
|
||||
RequestData::FileWriteText { path: path.into(), text: data.into() },
|
||||
@ok
|
||||
)
|
||||
}
|
||||
}
|
@ -1,235 +0,0 @@
|
||||
use crate::net::{SecretKey32, UnprotectedToHexKey};
|
||||
use derive_more::{Display, Error};
|
||||
use std::{
|
||||
env,
|
||||
net::{IpAddr, SocketAddr},
|
||||
ops::Deref,
|
||||
path::{Path, PathBuf},
|
||||
str::FromStr,
|
||||
};
|
||||
use tokio::{io, net::lookup_host};
|
||||
|
||||
#[derive(Debug, PartialEq, Eq)]
|
||||
pub struct SessionInfo {
|
||||
pub host: String,
|
||||
pub port: u16,
|
||||
pub key: SecretKey32,
|
||||
}
|
||||
|
||||
#[derive(Copy, Clone, Debug, Display, Error, PartialEq, Eq)]
|
||||
pub enum SessionInfoParseError {
|
||||
#[display(fmt = "Prefix of string is invalid")]
|
||||
BadPrefix,
|
||||
|
||||
#[display(fmt = "Bad hex key for session")]
|
||||
BadHexKey,
|
||||
|
||||
#[display(fmt = "Invalid key for session")]
|
||||
InvalidKey,
|
||||
|
||||
#[display(fmt = "Invalid port for session")]
|
||||
InvalidPort,
|
||||
|
||||
#[display(fmt = "Missing address for session")]
|
||||
MissingAddr,
|
||||
|
||||
#[display(fmt = "Missing key for session")]
|
||||
MissingKey,
|
||||
|
||||
#[display(fmt = "Missing port for session")]
|
||||
MissingPort,
|
||||
}
|
||||
|
||||
impl From<SessionInfoParseError> for io::Error {
|
||||
fn from(x: SessionInfoParseError) -> Self {
|
||||
io::Error::new(io::ErrorKind::InvalidData, x)
|
||||
}
|
||||
}
|
||||
|
||||
impl FromStr for SessionInfo {
|
||||
type Err = SessionInfoParseError;
|
||||
|
||||
fn from_str(s: &str) -> Result<Self, Self::Err> {
|
||||
let mut tokens = s.trim().split(' ').take(5);
|
||||
|
||||
// First, validate that we have the appropriate prefix
|
||||
if tokens.next().ok_or(SessionInfoParseError::BadPrefix)? != "DISTANT" {
|
||||
return Err(SessionInfoParseError::BadPrefix);
|
||||
}
|
||||
if tokens.next().ok_or(SessionInfoParseError::BadPrefix)? != "CONNECT" {
|
||||
return Err(SessionInfoParseError::BadPrefix);
|
||||
}
|
||||
|
||||
// Second, load up the address without parsing it
|
||||
let host = tokens
|
||||
.next()
|
||||
.ok_or(SessionInfoParseError::MissingAddr)?
|
||||
.trim()
|
||||
.to_string();
|
||||
|
||||
// Third, load up the port and parse it into a number
|
||||
let port = tokens
|
||||
.next()
|
||||
.ok_or(SessionInfoParseError::MissingPort)?
|
||||
.trim()
|
||||
.parse::<u16>()
|
||||
.map_err(|_| SessionInfoParseError::InvalidPort)?;
|
||||
|
||||
// Fourth, load up the key and convert it back into a secret key from a hex slice
|
||||
let key = SecretKey32::from_slice(
|
||||
&hex::decode(
|
||||
tokens
|
||||
.next()
|
||||
.ok_or(SessionInfoParseError::MissingKey)?
|
||||
.trim(),
|
||||
)
|
||||
.map_err(|_| SessionInfoParseError::BadHexKey)?,
|
||||
)
|
||||
.map_err(|_| SessionInfoParseError::InvalidKey)?;
|
||||
|
||||
Ok(SessionInfo { host, port, key })
|
||||
}
|
||||
}
|
||||
|
||||
impl SessionInfo {
|
||||
/// Loads session from environment variables
|
||||
pub fn from_environment() -> io::Result<Self> {
|
||||
fn to_err(x: env::VarError) -> io::Error {
|
||||
io::Error::new(io::ErrorKind::InvalidInput, x)
|
||||
}
|
||||
|
||||
let host = env::var("DISTANT_HOST").map_err(to_err)?;
|
||||
let port = env::var("DISTANT_PORT").map_err(to_err)?;
|
||||
let key = env::var("DISTANT_KEY").map_err(to_err)?;
|
||||
Ok(format!("DISTANT CONNECT {} {} {}", host, port, key).parse()?)
|
||||
}
|
||||
|
||||
/// Loads session from the next line available in this program's stdin
|
||||
pub fn from_stdin() -> io::Result<Self> {
|
||||
let mut line = String::new();
|
||||
std::io::stdin().read_line(&mut line)?;
|
||||
line.parse()
|
||||
.map_err(|x| io::Error::new(io::ErrorKind::InvalidData, x))
|
||||
}
|
||||
|
||||
/// Consumes the session and returns the key
|
||||
pub fn into_key(self) -> SecretKey32 {
|
||||
self.key
|
||||
}
|
||||
|
||||
/// Returns the ip address associated with the session based on the host
|
||||
pub async fn to_ip_addr(&self) -> io::Result<IpAddr> {
|
||||
let addr = match self.host.parse::<IpAddr>() {
|
||||
Ok(addr) => addr,
|
||||
Err(_) => lookup_host((self.host.as_str(), self.port))
|
||||
.await?
|
||||
.next()
|
||||
.ok_or_else(|| io::Error::new(io::ErrorKind::NotFound, "Failed to lookup_host"))?
|
||||
.ip(),
|
||||
};
|
||||
|
||||
Ok(addr)
|
||||
}
|
||||
|
||||
/// Returns socket address associated with the session
|
||||
pub async fn to_socket_addr(&self) -> io::Result<SocketAddr> {
|
||||
let addr = self.to_ip_addr().await?;
|
||||
Ok(SocketAddr::from((addr, self.port)))
|
||||
}
|
||||
|
||||
/// Converts the session's key to a hex string
|
||||
pub fn key_to_unprotected_string(&self) -> String {
|
||||
self.key.unprotected_to_hex_key()
|
||||
}
|
||||
|
||||
/// Converts to unprotected string that exposes the key in the form of
|
||||
/// `DISTANT CONNECT <host> <port> <key>`
|
||||
pub fn to_unprotected_string(&self) -> String {
|
||||
format!(
|
||||
"DISTANT CONNECT {} {} {}",
|
||||
self.host,
|
||||
self.port,
|
||||
self.key.unprotected_to_hex_key()
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
/// Provides operations related to working with a session that is disk-based
|
||||
pub struct SessionInfoFile {
|
||||
path: PathBuf,
|
||||
session: SessionInfo,
|
||||
}
|
||||
|
||||
impl AsRef<Path> for SessionInfoFile {
|
||||
fn as_ref(&self) -> &Path {
|
||||
self.as_path()
|
||||
}
|
||||
}
|
||||
|
||||
impl AsRef<SessionInfo> for SessionInfoFile {
|
||||
fn as_ref(&self) -> &SessionInfo {
|
||||
self.as_session()
|
||||
}
|
||||
}
|
||||
|
||||
impl Deref for SessionInfoFile {
|
||||
type Target = SessionInfo;
|
||||
|
||||
fn deref(&self) -> &Self::Target {
|
||||
&self.session
|
||||
}
|
||||
}
|
||||
|
||||
impl From<SessionInfoFile> for SessionInfo {
|
||||
fn from(sf: SessionInfoFile) -> Self {
|
||||
sf.session
|
||||
}
|
||||
}
|
||||
|
||||
impl SessionInfoFile {
|
||||
/// Creates a new inmemory pointer to a session and its file
|
||||
pub fn new(path: impl Into<PathBuf>, session: SessionInfo) -> Self {
|
||||
Self {
|
||||
path: path.into(),
|
||||
session,
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns a reference to the path to the session file
|
||||
pub fn as_path(&self) -> &Path {
|
||||
self.path.as_path()
|
||||
}
|
||||
|
||||
/// Returns a reference to the session
|
||||
pub fn as_session(&self) -> &SessionInfo {
|
||||
&self.session
|
||||
}
|
||||
|
||||
/// Saves a session by overwriting its current
|
||||
pub async fn save(&self) -> io::Result<()> {
|
||||
self.save_to(self.as_path(), true).await
|
||||
}
|
||||
|
||||
/// Saves a session to to a file at the specified path
|
||||
///
|
||||
/// If all is true, will create all directories leading up to file's location
|
||||
pub async fn save_to(&self, path: impl AsRef<Path>, all: bool) -> io::Result<()> {
|
||||
if all {
|
||||
if let Some(dir) = path.as_ref().parent() {
|
||||
tokio::fs::create_dir_all(dir).await?;
|
||||
}
|
||||
}
|
||||
|
||||
tokio::fs::write(path.as_ref(), self.session.to_unprotected_string()).await
|
||||
}
|
||||
|
||||
/// Loads a session from a file at the specified path
|
||||
pub async fn load_from(path: impl AsRef<Path>) -> io::Result<Self> {
|
||||
let text = tokio::fs::read_to_string(path.as_ref()).await?;
|
||||
|
||||
Ok(Self {
|
||||
path: path.as_ref().to_path_buf(),
|
||||
session: text.parse()?,
|
||||
})
|
||||
}
|
||||
}
|
@ -1,84 +0,0 @@
|
||||
use crate::{client::utils, data::Response};
|
||||
use std::{collections::HashMap, time::Duration};
|
||||
use tokio::{io, sync::mpsc};
|
||||
|
||||
pub struct PostOffice {
|
||||
mailboxes: HashMap<usize, mpsc::Sender<Response>>,
|
||||
}
|
||||
|
||||
impl PostOffice {
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
mailboxes: HashMap::new(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Creates a new mailbox using the given id and buffer size for maximum messages
|
||||
pub fn make_mailbox(&mut self, id: usize, buffer: usize) -> Mailbox {
|
||||
let (tx, rx) = mpsc::channel(buffer);
|
||||
self.mailboxes.insert(id, tx);
|
||||
|
||||
Mailbox { id, rx }
|
||||
}
|
||||
|
||||
/// Delivers a response to appropriate mailbox, returning false if no mailbox is found
|
||||
/// for the response or if the mailbox is no longer receiving responses
|
||||
pub async fn deliver(&mut self, res: Response) -> bool {
|
||||
let id = res.origin_id;
|
||||
|
||||
let success = if let Some(tx) = self.mailboxes.get_mut(&id) {
|
||||
tx.send(res).await.is_ok()
|
||||
} else {
|
||||
false
|
||||
};
|
||||
|
||||
// If failed, we want to remove the mailbox sender as it is no longer valid
|
||||
if !success {
|
||||
self.mailboxes.remove(&id);
|
||||
}
|
||||
|
||||
success
|
||||
}
|
||||
|
||||
/// Removes all mailboxes from post office that are closed
|
||||
pub fn prune_mailboxes(&mut self) {
|
||||
self.mailboxes.retain(|_, tx| !tx.is_closed())
|
||||
}
|
||||
|
||||
/// Closes out all mailboxes by removing the mailboxes delivery trackers internally
|
||||
pub fn clear_mailboxes(&mut self) {
|
||||
self.mailboxes.clear();
|
||||
}
|
||||
}
|
||||
|
||||
pub struct Mailbox {
|
||||
/// Represents id associated with the mailbox
|
||||
id: usize,
|
||||
|
||||
/// Underlying mailbox storage
|
||||
rx: mpsc::Receiver<Response>,
|
||||
}
|
||||
|
||||
impl Mailbox {
|
||||
/// Represents id associated with the mailbox
|
||||
pub fn id(&self) -> usize {
|
||||
self.id
|
||||
}
|
||||
|
||||
/// Receives next response in mailbox
|
||||
pub async fn next(&mut self) -> Option<Response> {
|
||||
self.rx.recv().await
|
||||
}
|
||||
|
||||
/// Receives next response in mailbox, waiting up to duration before timing out
|
||||
pub async fn next_timeout(&mut self, duration: Duration) -> io::Result<Option<Response>> {
|
||||
utils::timeout(duration, self.next()).await
|
||||
}
|
||||
|
||||
/// Closes the mailbox such that it will not receive any more responses
|
||||
///
|
||||
/// Any responses already in the mailbox will still be returned via `next`
|
||||
pub async fn close(&mut self) {
|
||||
self.rx.close()
|
||||
}
|
||||
}
|
@ -1,504 +0,0 @@
|
||||
use crate::{
|
||||
client::utils,
|
||||
constants::CLIENT_MAILBOX_CAPACITY,
|
||||
data::{Request, Response},
|
||||
net::{Codec, DataStream, Transport, TransportError},
|
||||
};
|
||||
use log::*;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::{
|
||||
convert,
|
||||
net::SocketAddr,
|
||||
ops::{Deref, DerefMut},
|
||||
path::{Path, PathBuf},
|
||||
sync::{Arc, Weak},
|
||||
};
|
||||
use tokio::{
|
||||
io,
|
||||
net::TcpStream,
|
||||
sync::{mpsc, Mutex},
|
||||
task::{JoinError, JoinHandle},
|
||||
time::Duration,
|
||||
};
|
||||
|
||||
mod ext;
|
||||
pub use ext::{SessionChannelExt, SessionChannelExtError};
|
||||
|
||||
mod info;
|
||||
pub use info::{SessionInfo, SessionInfoFile, SessionInfoParseError};
|
||||
|
||||
mod mailbox;
|
||||
pub use mailbox::Mailbox;
|
||||
use mailbox::PostOffice;
|
||||
|
||||
/// Details about the session
|
||||
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
pub enum SessionDetails {
|
||||
/// Indicates session is a TCP type
|
||||
Tcp { addr: SocketAddr, tag: String },
|
||||
|
||||
/// Indicates session is a Unix socket type
|
||||
Socket { path: PathBuf, tag: String },
|
||||
|
||||
/// Indicates session type is inmemory
|
||||
Inmemory { tag: String },
|
||||
|
||||
/// Indicates session type is a custom type (such as ssh)
|
||||
Custom { tag: String },
|
||||
}
|
||||
|
||||
impl SessionDetails {
|
||||
/// Represents the tag associated with the session
|
||||
pub fn tag(&self) -> &str {
|
||||
match self {
|
||||
Self::Tcp { tag, .. } => tag.as_str(),
|
||||
Self::Socket { tag, .. } => tag.as_str(),
|
||||
Self::Inmemory { tag } => tag.as_str(),
|
||||
Self::Custom { tag } => tag.as_str(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Represents the socket address associated with the session, if it has one
|
||||
pub fn addr(&self) -> Option<SocketAddr> {
|
||||
match self {
|
||||
Self::Tcp { addr, .. } => Some(*addr),
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
|
||||
/// Represents the path associated with the session, if it has one
|
||||
pub fn path(&self) -> Option<&Path> {
|
||||
match self {
|
||||
Self::Socket { path, .. } => Some(path.as_path()),
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Represents a session with a remote server that can be used to send requests & receive responses
|
||||
pub struct Session {
|
||||
/// Used to send requests to a server
|
||||
channel: SessionChannel,
|
||||
|
||||
/// Details about the session
|
||||
details: Option<SessionDetails>,
|
||||
|
||||
/// Contains the task that is running to send requests to a server
|
||||
request_task: JoinHandle<()>,
|
||||
|
||||
/// Contains the task that is running to receive responses from a server
|
||||
response_task: JoinHandle<()>,
|
||||
|
||||
/// Contains the task that runs on a timer to prune closed mailboxes
|
||||
prune_task: JoinHandle<()>,
|
||||
}
|
||||
|
||||
impl Session {
|
||||
/// Connect to a remote TCP server using the provided information
|
||||
pub async fn tcp_connect<U>(addr: SocketAddr, codec: U) -> io::Result<Self>
|
||||
where
|
||||
U: Codec + Send + 'static,
|
||||
{
|
||||
let transport = Transport::<TcpStream, U>::connect(addr, codec).await?;
|
||||
let details = SessionDetails::Tcp {
|
||||
addr,
|
||||
tag: transport.to_connection_tag(),
|
||||
};
|
||||
debug!(
|
||||
"Session has been established with {}",
|
||||
transport
|
||||
.peer_addr()
|
||||
.map(|x| x.to_string())
|
||||
.unwrap_or_else(|_| String::from("???"))
|
||||
);
|
||||
Self::initialize_with_details(transport, Some(details))
|
||||
}
|
||||
|
||||
/// Connect to a remote TCP server, timing out after duration has passed
|
||||
pub async fn tcp_connect_timeout<U>(
|
||||
addr: SocketAddr,
|
||||
codec: U,
|
||||
duration: Duration,
|
||||
) -> io::Result<Self>
|
||||
where
|
||||
U: Codec + Send + 'static,
|
||||
{
|
||||
utils::timeout(duration, Self::tcp_connect(addr, codec))
|
||||
.await
|
||||
.and_then(convert::identity)
|
||||
}
|
||||
|
||||
/// Convert into underlying channel
|
||||
pub fn into_channel(self) -> SessionChannel {
|
||||
self.channel
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(unix)]
|
||||
impl Session {
|
||||
/// Connect to a proxy unix socket
|
||||
pub async fn unix_connect<U>(path: impl AsRef<std::path::Path>, codec: U) -> io::Result<Self>
|
||||
where
|
||||
U: Codec + Send + 'static,
|
||||
{
|
||||
let p = path.as_ref();
|
||||
let transport = Transport::<tokio::net::UnixStream, U>::connect(p, codec).await?;
|
||||
let details = SessionDetails::Socket {
|
||||
path: p.to_path_buf(),
|
||||
tag: transport.to_connection_tag(),
|
||||
};
|
||||
debug!(
|
||||
"Session has been established with {}",
|
||||
transport
|
||||
.peer_addr()
|
||||
.map(|x| format!("{:?}", x))
|
||||
.unwrap_or_else(|_| String::from("???"))
|
||||
);
|
||||
Self::initialize_with_details(transport, Some(details))
|
||||
}
|
||||
|
||||
/// Connect to a proxy unix socket, timing out after duration has passed
|
||||
pub async fn unix_connect_timeout<U>(
|
||||
path: impl AsRef<std::path::Path>,
|
||||
codec: U,
|
||||
duration: Duration,
|
||||
) -> io::Result<Self>
|
||||
where
|
||||
U: Codec + Send + 'static,
|
||||
{
|
||||
utils::timeout(duration, Self::unix_connect(path, codec))
|
||||
.await
|
||||
.and_then(convert::identity)
|
||||
}
|
||||
}
|
||||
|
||||
impl Session {
|
||||
/// Initializes a session using the provided transport and no extra details
|
||||
pub fn initialize<T, U>(transport: Transport<T, U>) -> io::Result<Self>
|
||||
where
|
||||
T: DataStream,
|
||||
U: Codec + Send + 'static,
|
||||
{
|
||||
Self::initialize_with_details(transport, None)
|
||||
}
|
||||
|
||||
/// Initializes a session using the provided transport and extra details
|
||||
pub fn initialize_with_details<T, U>(
|
||||
transport: Transport<T, U>,
|
||||
details: Option<SessionDetails>,
|
||||
) -> io::Result<Self>
|
||||
where
|
||||
T: DataStream,
|
||||
U: Codec + Send + 'static,
|
||||
{
|
||||
let (mut t_read, mut t_write) = transport.into_split();
|
||||
let post_office = Arc::new(Mutex::new(PostOffice::new()));
|
||||
let weak_post_office = Arc::downgrade(&post_office);
|
||||
|
||||
// Start a task that continually checks for responses and delivers them using the
|
||||
// post office
|
||||
let response_task = tokio::spawn(async move {
|
||||
loop {
|
||||
match t_read.receive::<Response>().await {
|
||||
Ok(Some(res)) => {
|
||||
trace!("Incoming response: {:?}", res);
|
||||
let res_id = res.id;
|
||||
let res_origin_id = res.origin_id;
|
||||
|
||||
// Try to send response to appropriate mailbox
|
||||
// NOTE: We don't log failures as errors as using fire(...) for a
|
||||
// session is valid and would not have a mailbox
|
||||
if !post_office.lock().await.deliver(res).await {
|
||||
trace!(
|
||||
"Response {} has no mailbox for origin {}",
|
||||
res_id,
|
||||
res_origin_id
|
||||
);
|
||||
}
|
||||
}
|
||||
Ok(None) => {
|
||||
debug!("Session closing response task as transport read-half closed!");
|
||||
break;
|
||||
}
|
||||
Err(x) => {
|
||||
error!("Failed to receive response from server: {}", x);
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Clean up remaining mailbox senders
|
||||
post_office.lock().await.clear_mailboxes();
|
||||
});
|
||||
|
||||
let (tx, mut rx) = mpsc::channel::<Request>(1);
|
||||
let request_task = tokio::spawn(async move {
|
||||
while let Some(req) = rx.recv().await {
|
||||
if let Err(x) = t_write.send(req).await {
|
||||
error!("Failed to send request to server: {}", x);
|
||||
break;
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
// Create a task that runs once a minute and prunes mailboxes
|
||||
let post_office = Weak::clone(&weak_post_office);
|
||||
let prune_task = tokio::spawn(async move {
|
||||
loop {
|
||||
tokio::time::sleep(Duration::from_secs(60)).await;
|
||||
if let Some(post_office) = Weak::upgrade(&post_office) {
|
||||
post_office.lock().await.prune_mailboxes();
|
||||
} else {
|
||||
break;
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
let channel = SessionChannel {
|
||||
tx,
|
||||
post_office: weak_post_office,
|
||||
};
|
||||
|
||||
Ok(Self {
|
||||
channel,
|
||||
details,
|
||||
request_task,
|
||||
response_task,
|
||||
prune_task,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl Session {
|
||||
/// Returns details about the session, if it has any
|
||||
pub fn details(&self) -> Option<&SessionDetails> {
|
||||
self.details.as_ref()
|
||||
}
|
||||
|
||||
/// Waits for the session to terminate, which results when the receiving end of the network
|
||||
/// connection is closed (or the session is shutdown)
|
||||
pub async fn wait(self) -> Result<(), JoinError> {
|
||||
self.prune_task.abort();
|
||||
tokio::try_join!(self.request_task, self.response_task).map(|_| ())
|
||||
}
|
||||
|
||||
/// Abort the session's current connection by forcing its tasks to abort
|
||||
pub fn abort(&self) {
|
||||
self.request_task.abort();
|
||||
self.response_task.abort();
|
||||
self.prune_task.abort();
|
||||
}
|
||||
|
||||
/// Clones the underlying channel for requests and returns the cloned instance
|
||||
pub fn clone_channel(&self) -> SessionChannel {
|
||||
self.channel.clone()
|
||||
}
|
||||
}
|
||||
|
||||
impl Deref for Session {
|
||||
type Target = SessionChannel;
|
||||
|
||||
fn deref(&self) -> &Self::Target {
|
||||
&self.channel
|
||||
}
|
||||
}
|
||||
|
||||
impl DerefMut for Session {
|
||||
fn deref_mut(&mut self) -> &mut Self::Target {
|
||||
&mut self.channel
|
||||
}
|
||||
}
|
||||
|
||||
impl From<Session> for SessionChannel {
|
||||
fn from(session: Session) -> Self {
|
||||
session.channel
|
||||
}
|
||||
}
|
||||
|
||||
/// Represents a sender of requests tied to a session, holding onto a weak reference of
|
||||
/// mailboxes to relay responses, meaning that once the [`Session`] is closed or dropped,
|
||||
/// any sent request will no longer be able to receive responses
|
||||
#[derive(Clone)]
|
||||
pub struct SessionChannel {
|
||||
/// Used to send requests to a server
|
||||
tx: mpsc::Sender<Request>,
|
||||
|
||||
/// Collection of mailboxes for receiving responses to requests
|
||||
post_office: Weak<Mutex<PostOffice>>,
|
||||
}
|
||||
|
||||
impl SessionChannel {
|
||||
/// Returns true if no more requests can be transferred
|
||||
pub fn is_closed(&self) -> bool {
|
||||
self.tx.is_closed()
|
||||
}
|
||||
|
||||
/// Sends a request and returns a mailbox that can receive one or more responses, failing if
|
||||
/// unable to send a request or if the session's receiving line to the remote server has
|
||||
/// already been severed
|
||||
pub async fn mail(&mut self, req: Request) -> Result<Mailbox, TransportError> {
|
||||
trace!("Mailing request: {:?}", req);
|
||||
|
||||
// First, create a mailbox using the request's id
|
||||
let mailbox = Weak::upgrade(&self.post_office)
|
||||
.ok_or_else(|| {
|
||||
TransportError::IoError(io::Error::new(
|
||||
io::ErrorKind::NotConnected,
|
||||
"Session's post office is no longer available",
|
||||
))
|
||||
})?
|
||||
.lock()
|
||||
.await
|
||||
.make_mailbox(req.id, CLIENT_MAILBOX_CAPACITY);
|
||||
|
||||
// Second, send the request
|
||||
self.fire(req).await?;
|
||||
|
||||
// Third, return mailbox
|
||||
Ok(mailbox)
|
||||
}
|
||||
|
||||
/// Sends a request and waits for a response, failing if unable to send a request or if
|
||||
/// the session's receiving line to the remote server has already been severed
|
||||
pub async fn send(&mut self, req: Request) -> Result<Response, TransportError> {
|
||||
trace!("Sending request: {:?}", req);
|
||||
|
||||
// Send mail and get back a mailbox
|
||||
let mut mailbox = self.mail(req).await?;
|
||||
|
||||
// Wait for first response, and then drop the mailbox
|
||||
mailbox.next().await.ok_or_else(|| {
|
||||
TransportError::IoError(io::Error::from(io::ErrorKind::ConnectionAborted))
|
||||
})
|
||||
}
|
||||
|
||||
/// Sends a request and waits for a response, timing out after duration has passed
|
||||
pub async fn send_timeout(
|
||||
&mut self,
|
||||
req: Request,
|
||||
duration: Duration,
|
||||
) -> Result<Response, TransportError> {
|
||||
utils::timeout(duration, self.send(req))
|
||||
.await
|
||||
.map_err(TransportError::from)
|
||||
.and_then(convert::identity)
|
||||
}
|
||||
|
||||
/// Sends a request without waiting for a response; this method is able to be used even
|
||||
/// if the session's receiving line to the remote server has been severed
|
||||
pub async fn fire(&mut self, req: Request) -> Result<(), TransportError> {
|
||||
trace!("Firing off request: {:?}", req);
|
||||
self.tx
|
||||
.send(req)
|
||||
.await
|
||||
.map_err(|x| TransportError::IoError(io::Error::new(io::ErrorKind::BrokenPipe, x)))
|
||||
}
|
||||
|
||||
/// Sends a request without waiting for a response, timing out after duration has passed
|
||||
pub async fn fire_timeout(
|
||||
&mut self,
|
||||
req: Request,
|
||||
duration: Duration,
|
||||
) -> Result<(), TransportError> {
|
||||
utils::timeout(duration, self.fire(req))
|
||||
.await
|
||||
.map_err(TransportError::from)
|
||||
.and_then(convert::identity)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::{
|
||||
constants::test::TENANT,
|
||||
data::{RequestData, ResponseData},
|
||||
};
|
||||
use std::time::Duration;
|
||||
|
||||
#[tokio::test]
|
||||
async fn mail_should_return_mailbox_that_receives_responses_until_transport_closes() {
|
||||
let (t1, mut t2) = Transport::make_pair();
|
||||
let mut session = Session::initialize(t1).unwrap();
|
||||
|
||||
let req = Request::new(TENANT, vec![RequestData::ProcList {}]);
|
||||
let res = Response::new(TENANT, req.id, vec![ResponseData::Ok]);
|
||||
|
||||
let mut mailbox = session.mail(req).await.unwrap();
|
||||
|
||||
// Get first response
|
||||
match tokio::join!(mailbox.next(), t2.send(res.clone())) {
|
||||
(Some(actual), _) => assert_eq!(actual, res),
|
||||
x => panic!("Unexpected response: {:?}", x),
|
||||
}
|
||||
|
||||
// Get second response
|
||||
match tokio::join!(mailbox.next(), t2.send(res.clone())) {
|
||||
(Some(actual), _) => assert_eq!(actual, res),
|
||||
x => panic!("Unexpected response: {:?}", x),
|
||||
}
|
||||
|
||||
// Trigger the mailbox to wait BEFORE closing our transport to ensure that
|
||||
// we don't get stuck if the mailbox was already waiting
|
||||
let next_task = tokio::spawn(async move { mailbox.next().await });
|
||||
tokio::task::yield_now().await;
|
||||
|
||||
drop(t2);
|
||||
match next_task.await {
|
||||
Ok(None) => {}
|
||||
x => panic!("Unexpected response: {:?}", x),
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn send_should_wait_until_response_received() {
|
||||
let (t1, mut t2) = Transport::make_pair();
|
||||
let mut session = Session::initialize(t1).unwrap();
|
||||
|
||||
let req = Request::new(TENANT, vec![RequestData::ProcList {}]);
|
||||
let res = Response::new(
|
||||
TENANT,
|
||||
req.id,
|
||||
vec![ResponseData::ProcEntries {
|
||||
entries: Vec::new(),
|
||||
}],
|
||||
);
|
||||
|
||||
let (actual, _) = tokio::join!(session.send(req), t2.send(res.clone()));
|
||||
match actual {
|
||||
Ok(actual) => assert_eq!(actual, res),
|
||||
x => panic!("Unexpected response: {:?}", x),
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn send_timeout_should_fail_if_response_not_received_in_time() {
|
||||
let (t1, mut t2) = Transport::make_pair();
|
||||
let mut session = Session::initialize(t1).unwrap();
|
||||
|
||||
let req = Request::new(TENANT, vec![RequestData::ProcList {}]);
|
||||
match session.send_timeout(req, Duration::from_millis(30)).await {
|
||||
Err(TransportError::IoError(x)) => assert_eq!(x.kind(), io::ErrorKind::TimedOut),
|
||||
x => panic!("Unexpected response: {:?}", x),
|
||||
}
|
||||
|
||||
let req = t2.receive::<Request>().await.unwrap().unwrap();
|
||||
assert_eq!(req.tenant, TENANT);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn fire_should_send_request_and_not_wait_for_response() {
|
||||
let (t1, mut t2) = Transport::make_pair();
|
||||
let mut session = Session::initialize(t1).unwrap();
|
||||
|
||||
let req = Request::new(TENANT, vec![RequestData::ProcList {}]);
|
||||
match session.fire(req).await {
|
||||
Ok(_) => {}
|
||||
x => panic!("Unexpected response: {:?}", x),
|
||||
}
|
||||
|
||||
let req = t2.receive::<Request>().await.unwrap().unwrap();
|
||||
assert_eq!(req.tenant, TENANT);
|
||||
}
|
||||
}
|
@ -1,13 +0,0 @@
|
||||
use std::{future::Future, time::Duration};
|
||||
use tokio::{io, time};
|
||||
|
||||
// Wraps a future in a tokio timeout call, transforming the error into
|
||||
// an io error
|
||||
pub async fn timeout<T, F>(d: Duration, f: F) -> io::Result<T>
|
||||
where
|
||||
F: Future<Output = T>,
|
||||
{
|
||||
time::timeout(d, f)
|
||||
.await
|
||||
.map_err(|x| io::Error::new(io::ErrorKind::TimedOut, x))
|
||||
}
|
@ -0,0 +1,133 @@
|
||||
use crate::{
|
||||
serde_str::{deserialize_from_str, serialize_to_str},
|
||||
Destination,
|
||||
};
|
||||
use distant_net::SecretKey32;
|
||||
use serde::{de::Deserializer, ser::Serializer, Deserialize, Serialize};
|
||||
use std::{
|
||||
convert::{TryFrom, TryInto},
|
||||
fmt, io,
|
||||
str::FromStr,
|
||||
};
|
||||
use uriparse::{URIReference, URI};
|
||||
|
||||
/// Represents credentials used for a distant server that is maintaining a single key
|
||||
/// across all connections
|
||||
#[derive(Clone, Debug, PartialEq, Eq)]
|
||||
pub struct DistantSingleKeyCredentials {
|
||||
pub host: String,
|
||||
pub port: u16,
|
||||
pub key: SecretKey32,
|
||||
pub username: Option<String>,
|
||||
}
|
||||
|
||||
impl fmt::Display for DistantSingleKeyCredentials {
|
||||
/// Converts credentials into string in the form of `distant://[username]:{key}@{host}:{port}`
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
write!(f, "distant://")?;
|
||||
if let Some(username) = self.username.as_ref() {
|
||||
write!(f, "{}", username)?;
|
||||
}
|
||||
write!(f, ":{}@{}:{}", self.key, self.host, self.port)
|
||||
}
|
||||
}
|
||||
|
||||
impl FromStr for DistantSingleKeyCredentials {
|
||||
type Err = io::Error;
|
||||
|
||||
/// Parse `distant://[username]:{key}@{host}` as credentials. Note that this requires the
|
||||
/// `distant` scheme to be included. If parsing without scheme is desired, call the
|
||||
/// [`DistantSingleKeyCredentials::try_from_uri_ref`] method instead with `require_scheme`
|
||||
/// set to false
|
||||
fn from_str(s: &str) -> Result<Self, Self::Err> {
|
||||
Self::try_from_uri_ref(s, true)
|
||||
}
|
||||
}
|
||||
|
||||
impl Serialize for DistantSingleKeyCredentials {
|
||||
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
|
||||
where
|
||||
S: Serializer,
|
||||
{
|
||||
serialize_to_str(self, serializer)
|
||||
}
|
||||
}
|
||||
|
||||
impl<'de> Deserialize<'de> for DistantSingleKeyCredentials {
|
||||
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
|
||||
where
|
||||
D: Deserializer<'de>,
|
||||
{
|
||||
deserialize_from_str(deserializer)
|
||||
}
|
||||
}
|
||||
|
||||
impl DistantSingleKeyCredentials {
|
||||
/// Converts credentials into a [`Destination`] of the form `distant://[username]:{key}@{host}`,
|
||||
/// failing if the credentials would not produce a valid [`Destination`]
|
||||
pub fn try_to_destination(&self) -> io::Result<Destination> {
|
||||
let uri = self.try_to_uri()?;
|
||||
Destination::try_from(uri.as_uri_reference().to_borrowed())
|
||||
.map_err(|x| io::Error::new(io::ErrorKind::InvalidData, x))
|
||||
}
|
||||
|
||||
/// Converts credentials into a [`URI`] of the form `distant://[username]:{key}@{host}`,
|
||||
/// failing if the credentials would not produce a valid [`URI`]
|
||||
pub fn try_to_uri(&self) -> io::Result<URI<'static>> {
|
||||
let uri_string = self.to_string();
|
||||
URI::try_from(uri_string.as_str())
|
||||
.map(URI::into_owned)
|
||||
.map_err(|x| io::Error::new(io::ErrorKind::InvalidData, x))
|
||||
}
|
||||
|
||||
/// Parses credentials from a [`URIReference`], failing if the input was not a valid
|
||||
/// [`URIReference`] or if required parameters like `host` or `password` are missing or bad
|
||||
/// format
|
||||
///
|
||||
/// If `require_scheme` is true, will enforce that a scheme is provided. Regardless, if a
|
||||
/// scheme is provided that is not `distant`, this will also fail
|
||||
pub fn try_from_uri_ref<'a, E>(
|
||||
uri_ref: impl TryInto<URIReference<'a>, Error = E>,
|
||||
require_scheme: bool,
|
||||
) -> io::Result<Self>
|
||||
where
|
||||
E: std::error::Error + Send + Sync + 'static,
|
||||
{
|
||||
let uri_ref = uri_ref
|
||||
.try_into()
|
||||
.map_err(|x| io::Error::new(io::ErrorKind::Other, x))?;
|
||||
|
||||
// Check if the scheme is correct, and if missing if we require it
|
||||
if let Some(scheme) = uri_ref.scheme() {
|
||||
if !scheme.as_str().eq_ignore_ascii_case("distant") {
|
||||
return Err(io::Error::new(
|
||||
io::ErrorKind::InvalidInput,
|
||||
"Scheme is not distant",
|
||||
));
|
||||
}
|
||||
} else if require_scheme {
|
||||
return Err(io::Error::new(
|
||||
io::ErrorKind::InvalidInput,
|
||||
"Missing scheme",
|
||||
));
|
||||
}
|
||||
|
||||
Ok(Self {
|
||||
host: uri_ref
|
||||
.host()
|
||||
.map(ToString::to_string)
|
||||
.ok_or_else(|| io::Error::new(io::ErrorKind::InvalidInput, "Missing host"))?,
|
||||
port: uri_ref
|
||||
.port()
|
||||
.ok_or_else(|| io::Error::new(io::ErrorKind::InvalidInput, "Missing port"))?,
|
||||
key: uri_ref
|
||||
.password()
|
||||
.ok_or_else(|| io::Error::new(io::ErrorKind::InvalidInput, "Missing password"))
|
||||
.and_then(|x| {
|
||||
x.parse()
|
||||
.map_err(|x| io::Error::new(io::ErrorKind::InvalidInput, x))
|
||||
})?,
|
||||
username: uri_ref.username().map(ToString::to_string),
|
||||
})
|
||||
}
|
||||
}
|
File diff suppressed because it is too large
Load Diff
@ -0,0 +1,506 @@
|
||||
use derive_more::{Deref, DerefMut, IntoIterator};
|
||||
use notify::{event::Event as NotifyEvent, EventKind as NotifyEventKind};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::{
|
||||
collections::HashSet,
|
||||
fmt,
|
||||
hash::{Hash, Hasher},
|
||||
iter::FromIterator,
|
||||
ops::{BitOr, Sub},
|
||||
path::PathBuf,
|
||||
str::FromStr,
|
||||
};
|
||||
use strum::{EnumString, EnumVariantNames};
|
||||
|
||||
/// Change to one or more paths on the filesystem
|
||||
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
|
||||
#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
|
||||
#[serde(rename_all = "snake_case", deny_unknown_fields)]
|
||||
pub struct Change {
|
||||
/// Label describing the kind of change
|
||||
pub kind: ChangeKind,
|
||||
|
||||
/// Paths that were changed
|
||||
pub paths: Vec<PathBuf>,
|
||||
}
|
||||
|
||||
#[cfg(feature = "schemars")]
|
||||
impl Change {
|
||||
pub fn root_schema() -> schemars::schema::RootSchema {
|
||||
schemars::schema_for!(Change)
|
||||
}
|
||||
}
|
||||
|
||||
impl From<NotifyEvent> for Change {
|
||||
fn from(x: NotifyEvent) -> Self {
|
||||
Self {
|
||||
kind: x.kind.into(),
|
||||
paths: x.paths,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(
|
||||
Copy,
|
||||
Clone,
|
||||
Debug,
|
||||
strum::Display,
|
||||
EnumString,
|
||||
EnumVariantNames,
|
||||
Hash,
|
||||
PartialEq,
|
||||
Eq,
|
||||
PartialOrd,
|
||||
Ord,
|
||||
Serialize,
|
||||
Deserialize,
|
||||
)]
|
||||
#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
|
||||
#[serde(rename_all = "snake_case", deny_unknown_fields)]
|
||||
#[strum(serialize_all = "snake_case")]
|
||||
#[cfg_attr(feature = "clap", derive(clap::ValueEnum))]
|
||||
#[cfg_attr(feature = "clap", clap(rename_all = "snake_case"))]
|
||||
pub enum ChangeKind {
|
||||
/// Something about a file or directory was accessed, but
|
||||
/// no specific details were known
|
||||
Access,
|
||||
|
||||
/// A file was closed for executing
|
||||
AccessCloseExecute,
|
||||
|
||||
/// A file was closed for reading
|
||||
AccessCloseRead,
|
||||
|
||||
/// A file was closed for writing
|
||||
AccessCloseWrite,
|
||||
|
||||
/// A file was opened for executing
|
||||
AccessOpenExecute,
|
||||
|
||||
/// A file was opened for reading
|
||||
AccessOpenRead,
|
||||
|
||||
/// A file was opened for writing
|
||||
AccessOpenWrite,
|
||||
|
||||
/// A file or directory was read
|
||||
AccessRead,
|
||||
|
||||
/// The access time of a file or directory was changed
|
||||
AccessTime,
|
||||
|
||||
/// A file, directory, or something else was created
|
||||
Create,
|
||||
|
||||
/// The content of a file or directory changed
|
||||
Content,
|
||||
|
||||
/// The data of a file or directory was modified, but
|
||||
/// no specific details were known
|
||||
Data,
|
||||
|
||||
/// The metadata of a file or directory was modified, but
|
||||
/// no specific details were known
|
||||
Metadata,
|
||||
|
||||
/// Something about a file or directory was modified, but
|
||||
/// no specific details were known
|
||||
Modify,
|
||||
|
||||
/// A file, directory, or something else was removed
|
||||
Remove,
|
||||
|
||||
/// A file or directory was renamed, but no specific details were known
|
||||
Rename,
|
||||
|
||||
/// A file or directory was renamed, and the provided paths
|
||||
/// are the source and target in that order (from, to)
|
||||
RenameBoth,
|
||||
|
||||
/// A file or directory was renamed, and the provided path
|
||||
/// is the origin of the rename (before being renamed)
|
||||
RenameFrom,
|
||||
|
||||
/// A file or directory was renamed, and the provided path
|
||||
/// is the result of the rename
|
||||
RenameTo,
|
||||
|
||||
/// A file's size changed
|
||||
Size,
|
||||
|
||||
/// The ownership of a file or directory was changed
|
||||
Ownership,
|
||||
|
||||
/// The permissions of a file or directory was changed
|
||||
Permissions,
|
||||
|
||||
/// The write or modify time of a file or directory was changed
|
||||
WriteTime,
|
||||
|
||||
// Catchall in case we have no insight as to the type of change
|
||||
Unknown,
|
||||
}
|
||||
|
||||
impl ChangeKind {
|
||||
/// Returns true if the change is a kind of access
|
||||
pub fn is_access_kind(&self) -> bool {
|
||||
self.is_open_access_kind()
|
||||
|| self.is_close_access_kind()
|
||||
|| matches!(self, Self::Access | Self::AccessRead)
|
||||
}
|
||||
|
||||
/// Returns true if the change is a kind of open access
|
||||
pub fn is_open_access_kind(&self) -> bool {
|
||||
matches!(
|
||||
self,
|
||||
Self::AccessOpenExecute | Self::AccessOpenRead | Self::AccessOpenWrite
|
||||
)
|
||||
}
|
||||
|
||||
/// Returns true if the change is a kind of close access
|
||||
pub fn is_close_access_kind(&self) -> bool {
|
||||
matches!(
|
||||
self,
|
||||
Self::AccessCloseExecute | Self::AccessCloseRead | Self::AccessCloseWrite
|
||||
)
|
||||
}
|
||||
|
||||
/// Returns true if the change is a kind of creation
|
||||
pub fn is_create_kind(&self) -> bool {
|
||||
matches!(self, Self::Create)
|
||||
}
|
||||
|
||||
/// Returns true if the change is a kind of modification
|
||||
pub fn is_modify_kind(&self) -> bool {
|
||||
self.is_data_modify_kind() || self.is_metadata_modify_kind() || matches!(self, Self::Modify)
|
||||
}
|
||||
|
||||
/// Returns true if the change is a kind of data modification
|
||||
pub fn is_data_modify_kind(&self) -> bool {
|
||||
matches!(self, Self::Content | Self::Data | Self::Size)
|
||||
}
|
||||
|
||||
/// Returns true if the change is a kind of metadata modification
|
||||
pub fn is_metadata_modify_kind(&self) -> bool {
|
||||
matches!(
|
||||
self,
|
||||
Self::AccessTime
|
||||
| Self::Metadata
|
||||
| Self::Ownership
|
||||
| Self::Permissions
|
||||
| Self::WriteTime
|
||||
)
|
||||
}
|
||||
|
||||
/// Returns true if the change is a kind of rename
|
||||
pub fn is_rename_kind(&self) -> bool {
|
||||
matches!(
|
||||
self,
|
||||
Self::Rename | Self::RenameBoth | Self::RenameFrom | Self::RenameTo
|
||||
)
|
||||
}
|
||||
|
||||
/// Returns true if the change is a kind of removal
|
||||
pub fn is_remove_kind(&self) -> bool {
|
||||
matches!(self, Self::Remove)
|
||||
}
|
||||
|
||||
/// Returns true if the change kind is unknown
|
||||
pub fn is_unknown_kind(&self) -> bool {
|
||||
matches!(self, Self::Unknown)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(feature = "schemars")]
|
||||
impl ChangeKind {
|
||||
pub fn root_schema() -> schemars::schema::RootSchema {
|
||||
schemars::schema_for!(ChangeKind)
|
||||
}
|
||||
}
|
||||
|
||||
impl BitOr for ChangeKind {
|
||||
type Output = ChangeKindSet;
|
||||
|
||||
fn bitor(self, rhs: Self) -> Self::Output {
|
||||
let mut set = ChangeKindSet::empty();
|
||||
set.insert(self);
|
||||
set.insert(rhs);
|
||||
set
|
||||
}
|
||||
}
|
||||
|
||||
impl From<NotifyEventKind> for ChangeKind {
|
||||
fn from(x: NotifyEventKind) -> Self {
|
||||
use notify::event::{
|
||||
AccessKind, AccessMode, DataChange, MetadataKind, ModifyKind, RenameMode,
|
||||
};
|
||||
match x {
|
||||
// File/directory access events
|
||||
NotifyEventKind::Access(AccessKind::Read) => Self::AccessRead,
|
||||
NotifyEventKind::Access(AccessKind::Open(AccessMode::Execute)) => {
|
||||
Self::AccessOpenExecute
|
||||
}
|
||||
NotifyEventKind::Access(AccessKind::Open(AccessMode::Read)) => Self::AccessOpenRead,
|
||||
NotifyEventKind::Access(AccessKind::Open(AccessMode::Write)) => Self::AccessOpenWrite,
|
||||
NotifyEventKind::Access(AccessKind::Close(AccessMode::Execute)) => {
|
||||
Self::AccessCloseExecute
|
||||
}
|
||||
NotifyEventKind::Access(AccessKind::Close(AccessMode::Read)) => Self::AccessCloseRead,
|
||||
NotifyEventKind::Access(AccessKind::Close(AccessMode::Write)) => Self::AccessCloseWrite,
|
||||
NotifyEventKind::Access(_) => Self::Access,
|
||||
|
||||
// File/directory creation events
|
||||
NotifyEventKind::Create(_) => Self::Create,
|
||||
|
||||
// Rename-oriented events
|
||||
NotifyEventKind::Modify(ModifyKind::Name(RenameMode::Both)) => Self::RenameBoth,
|
||||
NotifyEventKind::Modify(ModifyKind::Name(RenameMode::From)) => Self::RenameFrom,
|
||||
NotifyEventKind::Modify(ModifyKind::Name(RenameMode::To)) => Self::RenameTo,
|
||||
NotifyEventKind::Modify(ModifyKind::Name(_)) => Self::Rename,
|
||||
|
||||
// Data-modification events
|
||||
NotifyEventKind::Modify(ModifyKind::Data(DataChange::Content)) => Self::Content,
|
||||
NotifyEventKind::Modify(ModifyKind::Data(DataChange::Size)) => Self::Size,
|
||||
NotifyEventKind::Modify(ModifyKind::Data(_)) => Self::Data,
|
||||
|
||||
// Metadata-modification events
|
||||
NotifyEventKind::Modify(ModifyKind::Metadata(MetadataKind::AccessTime)) => {
|
||||
Self::AccessTime
|
||||
}
|
||||
NotifyEventKind::Modify(ModifyKind::Metadata(MetadataKind::WriteTime)) => {
|
||||
Self::WriteTime
|
||||
}
|
||||
NotifyEventKind::Modify(ModifyKind::Metadata(MetadataKind::Permissions)) => {
|
||||
Self::Permissions
|
||||
}
|
||||
NotifyEventKind::Modify(ModifyKind::Metadata(MetadataKind::Ownership)) => {
|
||||
Self::Ownership
|
||||
}
|
||||
NotifyEventKind::Modify(ModifyKind::Metadata(_)) => Self::Metadata,
|
||||
|
||||
// General modification events
|
||||
NotifyEventKind::Modify(_) => Self::Modify,
|
||||
|
||||
// File/directory removal events
|
||||
NotifyEventKind::Remove(_) => Self::Remove,
|
||||
|
||||
// Catch-all for other events
|
||||
NotifyEventKind::Any | NotifyEventKind::Other => Self::Unknown,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Represents a distinct set of different change kinds
|
||||
#[derive(Clone, Debug, Deref, DerefMut, IntoIterator, Serialize, Deserialize)]
|
||||
#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
|
||||
pub struct ChangeKindSet(HashSet<ChangeKind>);
|
||||
|
||||
impl ChangeKindSet {
|
||||
/// Produces an empty set of [`ChangeKind`]
|
||||
pub fn empty() -> Self {
|
||||
Self(HashSet::new())
|
||||
}
|
||||
|
||||
/// Produces a set of all [`ChangeKind`]
|
||||
pub fn all() -> Self {
|
||||
vec![
|
||||
ChangeKind::Access,
|
||||
ChangeKind::AccessCloseExecute,
|
||||
ChangeKind::AccessCloseRead,
|
||||
ChangeKind::AccessCloseWrite,
|
||||
ChangeKind::AccessOpenExecute,
|
||||
ChangeKind::AccessOpenRead,
|
||||
ChangeKind::AccessOpenWrite,
|
||||
ChangeKind::AccessRead,
|
||||
ChangeKind::AccessTime,
|
||||
ChangeKind::Create,
|
||||
ChangeKind::Content,
|
||||
ChangeKind::Data,
|
||||
ChangeKind::Metadata,
|
||||
ChangeKind::Modify,
|
||||
ChangeKind::Remove,
|
||||
ChangeKind::Rename,
|
||||
ChangeKind::RenameBoth,
|
||||
ChangeKind::RenameFrom,
|
||||
ChangeKind::RenameTo,
|
||||
ChangeKind::Size,
|
||||
ChangeKind::Ownership,
|
||||
ChangeKind::Permissions,
|
||||
ChangeKind::WriteTime,
|
||||
ChangeKind::Unknown,
|
||||
]
|
||||
.into_iter()
|
||||
.collect()
|
||||
}
|
||||
|
||||
/// Produces a changeset containing all of the access kinds
|
||||
pub fn access_set() -> Self {
|
||||
Self::access_open_set()
|
||||
| Self::access_close_set()
|
||||
| ChangeKind::AccessRead
|
||||
| ChangeKind::Access
|
||||
}
|
||||
|
||||
/// Produces a changeset containing all of the open access kinds
|
||||
pub fn access_open_set() -> Self {
|
||||
ChangeKind::AccessOpenExecute | ChangeKind::AccessOpenRead | ChangeKind::AccessOpenWrite
|
||||
}
|
||||
|
||||
/// Produces a changeset containing all of the close access kinds
|
||||
pub fn access_close_set() -> Self {
|
||||
ChangeKind::AccessCloseExecute | ChangeKind::AccessCloseRead | ChangeKind::AccessCloseWrite
|
||||
}
|
||||
|
||||
// Produces a changeset containing all of the modification kinds
|
||||
pub fn modify_set() -> Self {
|
||||
Self::modify_data_set() | Self::modify_metadata_set() | ChangeKind::Modify
|
||||
}
|
||||
|
||||
/// Produces a changeset containing all of the data modification kinds
|
||||
pub fn modify_data_set() -> Self {
|
||||
ChangeKind::Content | ChangeKind::Data | ChangeKind::Size
|
||||
}
|
||||
|
||||
/// Produces a changeset containing all of the metadata modification kinds
|
||||
pub fn modify_metadata_set() -> Self {
|
||||
ChangeKind::AccessTime
|
||||
| ChangeKind::Metadata
|
||||
| ChangeKind::Ownership
|
||||
| ChangeKind::Permissions
|
||||
| ChangeKind::WriteTime
|
||||
}
|
||||
|
||||
/// Produces a changeset containing all of the rename kinds
|
||||
pub fn rename_set() -> Self {
|
||||
ChangeKind::Rename | ChangeKind::RenameBoth | ChangeKind::RenameFrom | ChangeKind::RenameTo
|
||||
}
|
||||
|
||||
/// Consumes set and returns a vec of the kinds of changes
|
||||
pub fn into_vec(self) -> Vec<ChangeKind> {
|
||||
self.0.into_iter().collect()
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(feature = "schemars")]
|
||||
impl ChangeKindSet {
|
||||
pub fn root_schema() -> schemars::schema::RootSchema {
|
||||
schemars::schema_for!(ChangeKindSet)
|
||||
}
|
||||
}
|
||||
|
||||
impl fmt::Display for ChangeKindSet {
|
||||
/// Outputs a comma-separated series of [`ChangeKind`] as string that are sorted
|
||||
/// such that this will always be consistent output
|
||||
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
|
||||
let mut kinds = self
|
||||
.0
|
||||
.iter()
|
||||
.map(ToString::to_string)
|
||||
.collect::<Vec<String>>();
|
||||
kinds.sort_unstable();
|
||||
write!(f, "{}", kinds.join(","))
|
||||
}
|
||||
}
|
||||
|
||||
impl PartialEq for ChangeKindSet {
|
||||
fn eq(&self, other: &Self) -> bool {
|
||||
self.to_string() == other.to_string()
|
||||
}
|
||||
}
|
||||
|
||||
impl Eq for ChangeKindSet {}
|
||||
|
||||
impl Hash for ChangeKindSet {
|
||||
/// Hashes based on the output of [`fmt::Display`]
|
||||
fn hash<H: Hasher>(&self, state: &mut H) {
|
||||
self.to_string().hash(state);
|
||||
}
|
||||
}
|
||||
|
||||
impl BitOr<ChangeKindSet> for ChangeKindSet {
|
||||
type Output = Self;
|
||||
|
||||
fn bitor(mut self, rhs: ChangeKindSet) -> Self::Output {
|
||||
self.extend(rhs.0);
|
||||
self
|
||||
}
|
||||
}
|
||||
|
||||
impl BitOr<ChangeKind> for ChangeKindSet {
|
||||
type Output = Self;
|
||||
|
||||
fn bitor(mut self, rhs: ChangeKind) -> Self::Output {
|
||||
self.0.insert(rhs);
|
||||
self
|
||||
}
|
||||
}
|
||||
|
||||
impl BitOr<ChangeKindSet> for ChangeKind {
|
||||
type Output = ChangeKindSet;
|
||||
|
||||
fn bitor(self, rhs: ChangeKindSet) -> Self::Output {
|
||||
rhs | self
|
||||
}
|
||||
}
|
||||
|
||||
impl Sub<ChangeKindSet> for ChangeKindSet {
|
||||
type Output = Self;
|
||||
|
||||
fn sub(self, other: Self) -> Self::Output {
|
||||
ChangeKindSet(&self.0 - &other.0)
|
||||
}
|
||||
}
|
||||
|
||||
impl Sub<&'_ ChangeKindSet> for &ChangeKindSet {
|
||||
type Output = ChangeKindSet;
|
||||
|
||||
fn sub(self, other: &ChangeKindSet) -> Self::Output {
|
||||
ChangeKindSet(&self.0 - &other.0)
|
||||
}
|
||||
}
|
||||
|
||||
impl FromStr for ChangeKindSet {
|
||||
type Err = strum::ParseError;
|
||||
|
||||
fn from_str(s: &str) -> Result<Self, Self::Err> {
|
||||
let mut change_set = HashSet::new();
|
||||
|
||||
for word in s.split(',') {
|
||||
change_set.insert(ChangeKind::from_str(word.trim())?);
|
||||
}
|
||||
|
||||
Ok(ChangeKindSet(change_set))
|
||||
}
|
||||
}
|
||||
|
||||
impl FromIterator<ChangeKind> for ChangeKindSet {
|
||||
fn from_iter<I: IntoIterator<Item = ChangeKind>>(iter: I) -> Self {
|
||||
let mut change_set = HashSet::new();
|
||||
|
||||
for i in iter {
|
||||
change_set.insert(i);
|
||||
}
|
||||
|
||||
ChangeKindSet(change_set)
|
||||
}
|
||||
}
|
||||
|
||||
impl From<ChangeKind> for ChangeKindSet {
|
||||
fn from(change_kind: ChangeKind) -> Self {
|
||||
let mut set = Self::empty();
|
||||
set.insert(change_kind);
|
||||
set
|
||||
}
|
||||
}
|
||||
|
||||
impl From<Vec<ChangeKind>> for ChangeKindSet {
|
||||
fn from(changes: Vec<ChangeKind>) -> Self {
|
||||
changes.into_iter().collect()
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for ChangeKindSet {
|
||||
fn default() -> Self {
|
||||
Self::empty()
|
||||
}
|
||||
}
|
@ -0,0 +1,106 @@
|
||||
use crate::{data::Cmd, DistantMsg, DistantRequestData};
|
||||
use clap::{
|
||||
error::{Error, ErrorKind},
|
||||
Arg, ArgAction, ArgMatches, Args, Command, FromArgMatches, Subcommand,
|
||||
};
|
||||
|
||||
impl FromArgMatches for Cmd {
|
||||
fn from_arg_matches(matches: &ArgMatches) -> Result<Self, Error> {
|
||||
let mut matches = matches.clone();
|
||||
Self::from_arg_matches_mut(&mut matches)
|
||||
}
|
||||
fn from_arg_matches_mut(matches: &mut ArgMatches) -> Result<Self, Error> {
|
||||
let cmd = matches.get_one::<String>("cmd").ok_or_else(|| {
|
||||
Error::raw(
|
||||
ErrorKind::MissingRequiredArgument,
|
||||
"program must be specified",
|
||||
)
|
||||
})?;
|
||||
let args: Vec<String> = matches
|
||||
.get_many::<String>("arg")
|
||||
.unwrap_or_default()
|
||||
.map(ToString::to_string)
|
||||
.collect();
|
||||
Ok(Self::new(format!("{cmd} {}", args.join(" "))))
|
||||
}
|
||||
fn update_from_arg_matches(&mut self, matches: &ArgMatches) -> Result<(), Error> {
|
||||
let mut matches = matches.clone();
|
||||
self.update_from_arg_matches_mut(&mut matches)
|
||||
}
|
||||
fn update_from_arg_matches_mut(&mut self, _matches: &mut ArgMatches) -> Result<(), Error> {
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
impl Args for Cmd {
|
||||
fn augment_args(cmd: Command<'_>) -> Command<'_> {
|
||||
cmd.arg(
|
||||
Arg::new("cmd")
|
||||
.required(true)
|
||||
.value_name("CMD")
|
||||
.action(ArgAction::Set),
|
||||
)
|
||||
.trailing_var_arg(true)
|
||||
.arg(
|
||||
Arg::new("arg")
|
||||
.value_name("ARGS")
|
||||
.multiple_values(true)
|
||||
.action(ArgAction::Append),
|
||||
)
|
||||
}
|
||||
fn augment_args_for_update(cmd: Command<'_>) -> Command<'_> {
|
||||
cmd
|
||||
}
|
||||
}
|
||||
|
||||
impl FromArgMatches for DistantMsg<DistantRequestData> {
|
||||
fn from_arg_matches(matches: &ArgMatches) -> Result<Self, Error> {
|
||||
match matches.subcommand() {
|
||||
Some(("single", args)) => Ok(Self::Single(DistantRequestData::from_arg_matches(args)?)),
|
||||
Some((_, _)) => Err(Error::raw(
|
||||
ErrorKind::UnrecognizedSubcommand,
|
||||
"Valid subcommand is `single`",
|
||||
)),
|
||||
None => Err(Error::raw(
|
||||
ErrorKind::MissingSubcommand,
|
||||
"Valid subcommand is `single`",
|
||||
)),
|
||||
}
|
||||
}
|
||||
|
||||
fn update_from_arg_matches(&mut self, matches: &ArgMatches) -> Result<(), Error> {
|
||||
match matches.subcommand() {
|
||||
Some(("single", args)) => {
|
||||
*self = Self::Single(DistantRequestData::from_arg_matches(args)?)
|
||||
}
|
||||
Some((_, _)) => {
|
||||
return Err(Error::raw(
|
||||
ErrorKind::UnrecognizedSubcommand,
|
||||
"Valid subcommand is `single`",
|
||||
))
|
||||
}
|
||||
None => (),
|
||||
};
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
impl Subcommand for DistantMsg<DistantRequestData> {
|
||||
fn augment_subcommands(cmd: Command<'_>) -> Command<'_> {
|
||||
cmd.subcommand(DistantRequestData::augment_subcommands(Command::new(
|
||||
"single",
|
||||
)))
|
||||
.subcommand_required(true)
|
||||
}
|
||||
|
||||
fn augment_subcommands_for_update(cmd: Command<'_>) -> Command<'_> {
|
||||
cmd.subcommand(DistantRequestData::augment_subcommands(Command::new(
|
||||
"single",
|
||||
)))
|
||||
.subcommand_required(true)
|
||||
}
|
||||
|
||||
fn has_subcommand(name: &str) -> bool {
|
||||
matches!(name, "single")
|
||||
}
|
||||
}
|
@ -0,0 +1,52 @@
|
||||
use derive_more::{Display, From, Into};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::ops::{Deref, DerefMut};
|
||||
|
||||
/// Represents some command with arguments to execute
|
||||
#[derive(Clone, Debug, Display, From, Into, Hash, PartialEq, Eq, Serialize, Deserialize)]
|
||||
#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
|
||||
pub struct Cmd(String);
|
||||
|
||||
impl Cmd {
|
||||
/// Creates a new command from the given `cmd`
|
||||
pub fn new(cmd: impl Into<String>) -> Self {
|
||||
Self(cmd.into())
|
||||
}
|
||||
|
||||
/// Returns reference to the program portion of the command
|
||||
pub fn program(&self) -> &str {
|
||||
match self.0.split_once(' ') {
|
||||
Some((program, _)) => program.trim(),
|
||||
None => self.0.trim(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns reference to the arguments portion of the command
|
||||
pub fn arguments(&self) -> &str {
|
||||
match self.0.split_once(' ') {
|
||||
Some((_, arguments)) => arguments.trim(),
|
||||
None => "",
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(feature = "schemars")]
|
||||
impl Cmd {
|
||||
pub fn root_schema() -> schemars::schema::RootSchema {
|
||||
schemars::schema_for!(Cmd)
|
||||
}
|
||||
}
|
||||
|
||||
impl Deref for Cmd {
|
||||
type Target = String;
|
||||
|
||||
fn deref(&self) -> &Self::Target {
|
||||
&self.0
|
||||
}
|
||||
}
|
||||
|
||||
impl DerefMut for Cmd {
|
||||
fn deref_mut(&mut self) -> &mut Self::Target {
|
||||
&mut self.0
|
||||
}
|
||||
}
|
@ -0,0 +1,269 @@
|
||||
use derive_more::Display;
|
||||
use notify::ErrorKind as NotifyErrorKind;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::io;
|
||||
|
||||
/// General purpose error type that can be sent across the wire
|
||||
#[derive(Clone, Debug, Display, PartialEq, Eq, Serialize, Deserialize)]
|
||||
#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
|
||||
#[display(fmt = "{}: {}", kind, description)]
|
||||
#[serde(rename_all = "snake_case", deny_unknown_fields)]
|
||||
pub struct Error {
|
||||
/// Label describing the kind of error
|
||||
pub kind: ErrorKind,
|
||||
|
||||
/// Description of the error itself
|
||||
pub description: String,
|
||||
}
|
||||
|
||||
impl std::error::Error for Error {}
|
||||
|
||||
#[cfg(feature = "schemars")]
|
||||
impl Error {
|
||||
pub fn root_schema() -> schemars::schema::RootSchema {
|
||||
schemars::schema_for!(Error)
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a> From<&'a str> for Error {
|
||||
fn from(x: &'a str) -> Self {
|
||||
Self::from(x.to_string())
|
||||
}
|
||||
}
|
||||
|
||||
impl From<String> for Error {
|
||||
fn from(x: String) -> Self {
|
||||
Self {
|
||||
kind: ErrorKind::Other,
|
||||
description: x,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl From<io::Error> for Error {
|
||||
fn from(x: io::Error) -> Self {
|
||||
Self {
|
||||
kind: ErrorKind::from(x.kind()),
|
||||
description: x.to_string(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl From<Error> for io::Error {
|
||||
fn from(x: Error) -> Self {
|
||||
Self::new(x.kind.into(), x.description)
|
||||
}
|
||||
}
|
||||
|
||||
impl From<notify::Error> for Error {
|
||||
fn from(x: notify::Error) -> Self {
|
||||
let err = match x.kind {
|
||||
NotifyErrorKind::Generic(x) => Self {
|
||||
kind: ErrorKind::Other,
|
||||
description: x,
|
||||
},
|
||||
NotifyErrorKind::Io(x) => Self::from(x),
|
||||
NotifyErrorKind::PathNotFound => Self {
|
||||
kind: ErrorKind::Other,
|
||||
description: String::from("Path not found"),
|
||||
},
|
||||
NotifyErrorKind::WatchNotFound => Self {
|
||||
kind: ErrorKind::Other,
|
||||
description: String::from("Watch not found"),
|
||||
},
|
||||
NotifyErrorKind::InvalidConfig(_) => Self {
|
||||
kind: ErrorKind::Other,
|
||||
description: String::from("Invalid config"),
|
||||
},
|
||||
NotifyErrorKind::MaxFilesWatch => Self {
|
||||
kind: ErrorKind::Other,
|
||||
description: String::from("Max files watched"),
|
||||
},
|
||||
};
|
||||
|
||||
Self {
|
||||
kind: err.kind,
|
||||
description: format!(
|
||||
"{}\n\nPaths: {}",
|
||||
err.description,
|
||||
x.paths
|
||||
.into_iter()
|
||||
.map(|p| p.to_string_lossy().to_string())
|
||||
.collect::<Vec<String>>()
|
||||
.join(", ")
|
||||
),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl From<walkdir::Error> for Error {
|
||||
fn from(x: walkdir::Error) -> Self {
|
||||
if x.io_error().is_some() {
|
||||
x.into_io_error().map(Self::from).unwrap()
|
||||
} else {
|
||||
Self {
|
||||
kind: ErrorKind::Loop,
|
||||
description: format!("{}", x),
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl From<tokio::task::JoinError> for Error {
|
||||
fn from(x: tokio::task::JoinError) -> Self {
|
||||
Self {
|
||||
kind: if x.is_cancelled() {
|
||||
ErrorKind::TaskCancelled
|
||||
} else {
|
||||
ErrorKind::TaskPanicked
|
||||
},
|
||||
description: format!("{}", x),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// All possible kinds of errors that can be returned
|
||||
#[derive(Copy, Clone, Debug, Display, PartialEq, Eq, Serialize, Deserialize)]
|
||||
#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
|
||||
#[serde(rename_all = "snake_case", deny_unknown_fields)]
|
||||
pub enum ErrorKind {
|
||||
/// An entity was not found, often a file
|
||||
NotFound,
|
||||
|
||||
/// The operation lacked the necessary privileges to complete
|
||||
PermissionDenied,
|
||||
|
||||
/// The connection was refused by the remote server
|
||||
ConnectionRefused,
|
||||
|
||||
/// The connection was reset by the remote server
|
||||
ConnectionReset,
|
||||
|
||||
/// The connection was aborted (terminated) by the remote server
|
||||
ConnectionAborted,
|
||||
|
||||
/// The network operation failed because it was not connected yet
|
||||
NotConnected,
|
||||
|
||||
/// A socket address could not be bound because the address is already in use elsewhere
|
||||
AddrInUse,
|
||||
|
||||
/// A nonexistent interface was requested or the requested address was not local
|
||||
AddrNotAvailable,
|
||||
|
||||
/// The operation failed because a pipe was closed
|
||||
BrokenPipe,
|
||||
|
||||
/// An entity already exists, often a file
|
||||
AlreadyExists,
|
||||
|
||||
/// The operation needs to block to complete, but the blocking operation was requested to not
|
||||
/// occur
|
||||
WouldBlock,
|
||||
|
||||
/// A parameter was incorrect
|
||||
InvalidInput,
|
||||
|
||||
/// Data not valid for the operation were encountered
|
||||
InvalidData,
|
||||
|
||||
/// The I/O operation's timeout expired, causing it to be cancelled
|
||||
TimedOut,
|
||||
|
||||
/// An error returned when an operation could not be completed because a
|
||||
/// call to `write` returned `Ok(0)`
|
||||
WriteZero,
|
||||
|
||||
/// This operation was interrupted
|
||||
Interrupted,
|
||||
|
||||
/// Any I/O error not part of this list
|
||||
Other,
|
||||
|
||||
/// An error returned when an operation could not be completed because an "end of file" was
|
||||
/// reached prematurely
|
||||
UnexpectedEof,
|
||||
|
||||
/// This operation is unsupported on this platform
|
||||
Unsupported,
|
||||
|
||||
/// An operation could not be completed, because it failed to allocate enough memory
|
||||
OutOfMemory,
|
||||
|
||||
/// When a loop is encountered when walking a directory
|
||||
Loop,
|
||||
|
||||
/// When a task is cancelled
|
||||
TaskCancelled,
|
||||
|
||||
/// When a task panics
|
||||
TaskPanicked,
|
||||
|
||||
/// Catchall for an error that has no specific type
|
||||
Unknown,
|
||||
}
|
||||
|
||||
#[cfg(feature = "schemars")]
|
||||
impl ErrorKind {
|
||||
pub fn root_schema() -> schemars::schema::RootSchema {
|
||||
schemars::schema_for!(ErrorKind)
|
||||
}
|
||||
}
|
||||
|
||||
impl From<io::ErrorKind> for ErrorKind {
|
||||
fn from(kind: io::ErrorKind) -> Self {
|
||||
match kind {
|
||||
io::ErrorKind::NotFound => Self::NotFound,
|
||||
io::ErrorKind::PermissionDenied => Self::PermissionDenied,
|
||||
io::ErrorKind::ConnectionRefused => Self::ConnectionRefused,
|
||||
io::ErrorKind::ConnectionReset => Self::ConnectionReset,
|
||||
io::ErrorKind::ConnectionAborted => Self::ConnectionAborted,
|
||||
io::ErrorKind::NotConnected => Self::NotConnected,
|
||||
io::ErrorKind::AddrInUse => Self::AddrInUse,
|
||||
io::ErrorKind::AddrNotAvailable => Self::AddrNotAvailable,
|
||||
io::ErrorKind::BrokenPipe => Self::BrokenPipe,
|
||||
io::ErrorKind::AlreadyExists => Self::AlreadyExists,
|
||||
io::ErrorKind::WouldBlock => Self::WouldBlock,
|
||||
io::ErrorKind::InvalidInput => Self::InvalidInput,
|
||||
io::ErrorKind::InvalidData => Self::InvalidData,
|
||||
io::ErrorKind::TimedOut => Self::TimedOut,
|
||||
io::ErrorKind::WriteZero => Self::WriteZero,
|
||||
io::ErrorKind::Interrupted => Self::Interrupted,
|
||||
io::ErrorKind::Other => Self::Other,
|
||||
io::ErrorKind::OutOfMemory => Self::OutOfMemory,
|
||||
io::ErrorKind::UnexpectedEof => Self::UnexpectedEof,
|
||||
io::ErrorKind::Unsupported => Self::Unsupported,
|
||||
|
||||
// This exists because io::ErrorKind is non_exhaustive
|
||||
_ => Self::Unknown,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl From<ErrorKind> for io::ErrorKind {
|
||||
fn from(kind: ErrorKind) -> Self {
|
||||
match kind {
|
||||
ErrorKind::NotFound => Self::NotFound,
|
||||
ErrorKind::PermissionDenied => Self::PermissionDenied,
|
||||
ErrorKind::ConnectionRefused => Self::ConnectionRefused,
|
||||
ErrorKind::ConnectionReset => Self::ConnectionReset,
|
||||
ErrorKind::ConnectionAborted => Self::ConnectionAborted,
|
||||
ErrorKind::NotConnected => Self::NotConnected,
|
||||
ErrorKind::AddrInUse => Self::AddrInUse,
|
||||
ErrorKind::AddrNotAvailable => Self::AddrNotAvailable,
|
||||
ErrorKind::BrokenPipe => Self::BrokenPipe,
|
||||
ErrorKind::AlreadyExists => Self::AlreadyExists,
|
||||
ErrorKind::WouldBlock => Self::WouldBlock,
|
||||
ErrorKind::InvalidInput => Self::InvalidInput,
|
||||
ErrorKind::InvalidData => Self::InvalidData,
|
||||
ErrorKind::TimedOut => Self::TimedOut,
|
||||
ErrorKind::WriteZero => Self::WriteZero,
|
||||
ErrorKind::Interrupted => Self::Interrupted,
|
||||
ErrorKind::Other => Self::Other,
|
||||
ErrorKind::OutOfMemory => Self::OutOfMemory,
|
||||
ErrorKind::UnexpectedEof => Self::UnexpectedEof,
|
||||
ErrorKind::Unsupported => Self::Unsupported,
|
||||
_ => Self::Other,
|
||||
}
|
||||
}
|
||||
}
|
@ -0,0 +1,45 @@
|
||||
use derive_more::IsVariant;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::path::PathBuf;
|
||||
use strum::AsRefStr;
|
||||
|
||||
/// Represents information about a single entry within a directory
|
||||
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
|
||||
#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
|
||||
#[serde(rename_all = "snake_case", deny_unknown_fields)]
|
||||
pub struct DirEntry {
|
||||
/// Represents the full path to the entry
|
||||
pub path: PathBuf,
|
||||
|
||||
/// Represents the type of the entry as a file/dir/symlink
|
||||
pub file_type: FileType,
|
||||
|
||||
/// Depth at which this entry was created relative to the root (0 being immediately within
|
||||
/// root)
|
||||
pub depth: usize,
|
||||
}
|
||||
|
||||
#[cfg(feature = "schemars")]
|
||||
impl DirEntry {
|
||||
pub fn root_schema() -> schemars::schema::RootSchema {
|
||||
schemars::schema_for!(DirEntry)
|
||||
}
|
||||
}
|
||||
|
||||
/// Represents the type associated with a dir entry
|
||||
#[derive(Copy, Clone, Debug, PartialEq, Eq, AsRefStr, IsVariant, Serialize, Deserialize)]
|
||||
#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
|
||||
#[serde(rename_all = "snake_case", deny_unknown_fields)]
|
||||
#[strum(serialize_all = "snake_case")]
|
||||
pub enum FileType {
|
||||
Dir,
|
||||
File,
|
||||
Symlink,
|
||||
}
|
||||
|
||||
#[cfg(feature = "schemars")]
|
||||
impl FileType {
|
||||
pub fn root_schema() -> schemars::schema::RootSchema {
|
||||
schemars::schema_for!(FileType)
|
||||
}
|
||||
}
|
@ -0,0 +1,244 @@
|
||||
use crate::serde_str::{deserialize_from_str, serialize_to_str};
|
||||
use derive_more::{From, IntoIterator};
|
||||
use serde::{de::Deserializer, ser::Serializer, Deserialize, Serialize};
|
||||
use std::{
|
||||
collections::HashMap,
|
||||
fmt,
|
||||
ops::{Deref, DerefMut},
|
||||
str::FromStr,
|
||||
};
|
||||
|
||||
/// Contains map information for connections and other use cases
|
||||
#[derive(Clone, Debug, From, IntoIterator, PartialEq, Eq)]
|
||||
#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
|
||||
pub struct Map(HashMap<String, String>);
|
||||
|
||||
impl Map {
|
||||
pub fn new() -> Self {
|
||||
Self(HashMap::new())
|
||||
}
|
||||
|
||||
pub fn into_map(self) -> HashMap<String, String> {
|
||||
self.0
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(feature = "schemars")]
|
||||
impl Map {
|
||||
pub fn root_schema() -> schemars::schema::RootSchema {
|
||||
schemars::schema_for!(Map)
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for Map {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
impl Deref for Map {
|
||||
type Target = HashMap<String, String>;
|
||||
|
||||
fn deref(&self) -> &Self::Target {
|
||||
&self.0
|
||||
}
|
||||
}
|
||||
|
||||
impl DerefMut for Map {
|
||||
fn deref_mut(&mut self) -> &mut Self::Target {
|
||||
&mut self.0
|
||||
}
|
||||
}
|
||||
|
||||
impl fmt::Display for Map {
|
||||
/// Outputs a `key=value` mapping in the form `key="value",key2="value2"`
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
let len = self.0.len();
|
||||
for (i, (key, value)) in self.0.iter().enumerate() {
|
||||
write!(f, "{}=\"{}\"", key, value)?;
|
||||
|
||||
// Include a comma after each but the last pair
|
||||
if i + 1 < len {
|
||||
write!(f, ",")?;
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
impl FromStr for Map {
|
||||
type Err = &'static str;
|
||||
|
||||
/// Parses a series of `key=value` pairs in the form `key="value",key2=value2` where
|
||||
/// the quotes around the value are optional
|
||||
fn from_str(s: &str) -> Result<Self, Self::Err> {
|
||||
let mut map = HashMap::new();
|
||||
|
||||
let mut s = s.trim();
|
||||
while !s.is_empty() {
|
||||
// Find {key}={tail...} where tail is everything after =
|
||||
let (key, tail) = s.split_once('=').ok_or("Missing = after key")?;
|
||||
|
||||
// Remove whitespace around the key and ensure it starts with a proper character
|
||||
let key = key.trim();
|
||||
|
||||
if !key.starts_with(char::is_alphabetic) {
|
||||
return Err("Key must start with alphabetic character");
|
||||
}
|
||||
|
||||
// Remove any map whitespace at the front of the tail
|
||||
let tail = tail.trim_start();
|
||||
|
||||
// Determine if we start with a quote " otherwise we will look for the next ,
|
||||
let (value, tail) = match tail.strip_prefix('"') {
|
||||
// If quoted, we maintain the whitespace within the quotes
|
||||
Some(tail) => {
|
||||
// Skip the quote so we can look for the trailing quote
|
||||
let (value, tail) =
|
||||
tail.split_once('"').ok_or("Missing closing \" for value")?;
|
||||
|
||||
// Skip comma if we have one
|
||||
let tail = tail.strip_prefix(',').unwrap_or(tail);
|
||||
|
||||
(value, tail)
|
||||
}
|
||||
|
||||
// If not quoted, we remove all whitespace around the value
|
||||
None => match tail.split_once(',') {
|
||||
Some((value, tail)) => (value.trim(), tail),
|
||||
None => (tail.trim(), ""),
|
||||
},
|
||||
};
|
||||
|
||||
// Insert our new pair and update the slice to be the tail (removing whitespace)
|
||||
map.insert(key.to_string(), value.to_string());
|
||||
s = tail.trim();
|
||||
}
|
||||
|
||||
Ok(Self(map))
|
||||
}
|
||||
}
|
||||
|
||||
impl Serialize for Map {
|
||||
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
|
||||
where
|
||||
S: Serializer,
|
||||
{
|
||||
serialize_to_str(self, serializer)
|
||||
}
|
||||
}
|
||||
|
||||
impl<'de> Deserialize<'de> for Map {
|
||||
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
|
||||
where
|
||||
D: Deserializer<'de>,
|
||||
{
|
||||
deserialize_from_str(deserializer)
|
||||
}
|
||||
}
|
||||
|
||||
#[macro_export]
|
||||
macro_rules! map {
|
||||
($($key:literal -> $value:literal),*) => {{
|
||||
let mut _map = ::std::collections::HashMap::new();
|
||||
|
||||
$(
|
||||
_map.insert($key.to_string(), $value.to_string());
|
||||
)*
|
||||
|
||||
$crate::Map::from(_map)
|
||||
}};
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn should_support_being_parsed_from_str() {
|
||||
// Empty string (whitespace only) yields an empty map
|
||||
let map = " ".parse::<Map>().unwrap();
|
||||
assert_eq!(map, map!());
|
||||
|
||||
// Simple key=value should succeed
|
||||
let map = "key=value".parse::<Map>().unwrap();
|
||||
assert_eq!(map, map!("key" -> "value"));
|
||||
|
||||
// Key can be anything but =
|
||||
let map = "key.with-characters@=value".parse::<Map>().unwrap();
|
||||
assert_eq!(map, map!("key.with-characters@" -> "value"));
|
||||
|
||||
// Value can be anything but ,
|
||||
let map = "key=value.has -@#$".parse::<Map>().unwrap();
|
||||
assert_eq!(map, map!("key" -> "value.has -@#$"));
|
||||
|
||||
// Value can include comma if quoted
|
||||
let map = r#"key=",,,,""#.parse::<Map>().unwrap();
|
||||
assert_eq!(map, map!("key" -> ",,,,"));
|
||||
|
||||
// Supports whitespace around key and value
|
||||
let map = " key = value ".parse::<Map>().unwrap();
|
||||
assert_eq!(map, map!("key" -> "value"));
|
||||
|
||||
// Supports value capturing whitespace if quoted
|
||||
let map = r#" key = " value " "#.parse::<Map>().unwrap();
|
||||
assert_eq!(map, map!("key" -> " value "));
|
||||
|
||||
// Multiple key=value should succeed
|
||||
let map = "key=value,key2=value2".parse::<Map>().unwrap();
|
||||
assert_eq!(map, map!("key" -> "value", "key2" -> "value2"));
|
||||
|
||||
// Quoted key=value should succeed
|
||||
let map = r#"key="value one",key2=value2"#.parse::<Map>().unwrap();
|
||||
assert_eq!(map, map!("key" -> "value one", "key2" -> "value2"));
|
||||
|
||||
let map = r#"key=value,key2="value two""#.parse::<Map>().unwrap();
|
||||
assert_eq!(map, map!("key" -> "value", "key2" -> "value two"));
|
||||
|
||||
let map = r#"key="value one",key2="value two""#.parse::<Map>().unwrap();
|
||||
assert_eq!(map, map!("key" -> "value one", "key2" -> "value two"));
|
||||
|
||||
let map = r#"key="1,2,3",key2="4,5,6""#.parse::<Map>().unwrap();
|
||||
assert_eq!(map, map!("key" -> "1,2,3", "key2" -> "4,5,6"));
|
||||
|
||||
// Dangling comma is okay
|
||||
let map = "key=value,".parse::<Map>().unwrap();
|
||||
assert_eq!(map, map!("key" -> "value"));
|
||||
let map = r#"key=",value,","#.parse::<Map>().unwrap();
|
||||
assert_eq!(map, map!("key" -> ",value,"));
|
||||
|
||||
// Demonstrating greedy
|
||||
let map = "key=value key2=value2".parse::<Map>().unwrap();
|
||||
assert_eq!(map, map!("key" -> "value key2=value2"));
|
||||
|
||||
// Variety of edge cases that should fail
|
||||
let _ = ",".parse::<Map>().unwrap_err();
|
||||
let _ = ",key=value".parse::<Map>().unwrap_err();
|
||||
let _ = "key=value,key2".parse::<Map>().unwrap_err();
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn should_support_being_displayed_as_a_string() {
|
||||
let map = map!().to_string();
|
||||
assert_eq!(map, "");
|
||||
|
||||
let map = map!("key" -> "value").to_string();
|
||||
assert_eq!(map, r#"key="value""#);
|
||||
|
||||
// Order of key=value output is not guaranteed
|
||||
let map = map!("key" -> "value", "key2" -> "value2").to_string();
|
||||
assert!(
|
||||
map == r#"key="value",key2="value2""# || map == r#"key2="value2",key="value""#,
|
||||
"{:?}",
|
||||
map
|
||||
);
|
||||
|
||||
// Order of key=value output is not guaranteed
|
||||
let map = map!("key" -> ",", "key2" -> ",,").to_string();
|
||||
assert!(
|
||||
map == r#"key=",",key2=",,""# || map == r#"key2=",,",key=",""#,
|
||||
"{:?}",
|
||||
map
|
||||
);
|
||||
}
|
||||
}
|
@ -0,0 +1,404 @@
|
||||
use super::{deserialize_u128_option, serialize_u128_option, FileType};
|
||||
use bitflags::bitflags;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::{
|
||||
io,
|
||||
path::{Path, PathBuf},
|
||||
time::SystemTime,
|
||||
};
|
||||
|
||||
/// Represents metadata about some path on a remote machine
|
||||
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
|
||||
#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
|
||||
pub struct Metadata {
|
||||
/// Canonicalized path to the file or directory, resolving symlinks, only included
|
||||
/// if flagged during the request
|
||||
pub canonicalized_path: Option<PathBuf>,
|
||||
|
||||
/// Represents the type of the entry as a file/dir/symlink
|
||||
pub file_type: FileType,
|
||||
|
||||
/// Size of the file/directory/symlink in bytes
|
||||
pub len: u64,
|
||||
|
||||
/// Whether or not the file/directory/symlink is marked as unwriteable
|
||||
pub readonly: bool,
|
||||
|
||||
/// Represents the last time (in milliseconds) when the file/directory/symlink was accessed;
|
||||
/// can be optional as certain systems don't support this
|
||||
#[serde(serialize_with = "serialize_u128_option")]
|
||||
#[serde(deserialize_with = "deserialize_u128_option")]
|
||||
pub accessed: Option<u128>,
|
||||
|
||||
/// Represents when (in milliseconds) the file/directory/symlink was created;
|
||||
/// can be optional as certain systems don't support this
|
||||
#[serde(serialize_with = "serialize_u128_option")]
|
||||
#[serde(deserialize_with = "deserialize_u128_option")]
|
||||
pub created: Option<u128>,
|
||||
|
||||
/// Represents the last time (in milliseconds) when the file/directory/symlink was modified;
|
||||
/// can be optional as certain systems don't support this
|
||||
#[serde(serialize_with = "serialize_u128_option")]
|
||||
#[serde(deserialize_with = "deserialize_u128_option")]
|
||||
pub modified: Option<u128>,
|
||||
|
||||
/// Represents metadata that is specific to a unix remote machine
|
||||
pub unix: Option<UnixMetadata>,
|
||||
|
||||
/// Represents metadata that is specific to a windows remote machine
|
||||
pub windows: Option<WindowsMetadata>,
|
||||
}
|
||||
|
||||
impl Metadata {
|
||||
pub async fn read(
|
||||
path: impl AsRef<Path>,
|
||||
canonicalize: bool,
|
||||
resolve_file_type: bool,
|
||||
) -> io::Result<Self> {
|
||||
let metadata = tokio::fs::symlink_metadata(path.as_ref()).await?;
|
||||
let canonicalized_path = if canonicalize {
|
||||
Some(tokio::fs::canonicalize(path.as_ref()).await?)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
// If asking for resolved file type and current type is symlink, then we want to refresh
|
||||
// our metadata to get the filetype for the resolved link
|
||||
let file_type = if resolve_file_type && metadata.file_type().is_symlink() {
|
||||
tokio::fs::metadata(path).await?.file_type()
|
||||
} else {
|
||||
metadata.file_type()
|
||||
};
|
||||
|
||||
Ok(Self {
|
||||
canonicalized_path,
|
||||
accessed: metadata
|
||||
.accessed()
|
||||
.ok()
|
||||
.and_then(|t| t.duration_since(SystemTime::UNIX_EPOCH).ok())
|
||||
.map(|d| d.as_millis()),
|
||||
created: metadata
|
||||
.created()
|
||||
.ok()
|
||||
.and_then(|t| t.duration_since(SystemTime::UNIX_EPOCH).ok())
|
||||
.map(|d| d.as_millis()),
|
||||
modified: metadata
|
||||
.modified()
|
||||
.ok()
|
||||
.and_then(|t| t.duration_since(SystemTime::UNIX_EPOCH).ok())
|
||||
.map(|d| d.as_millis()),
|
||||
len: metadata.len(),
|
||||
readonly: metadata.permissions().readonly(),
|
||||
file_type: if file_type.is_dir() {
|
||||
FileType::Dir
|
||||
} else if file_type.is_file() {
|
||||
FileType::File
|
||||
} else {
|
||||
FileType::Symlink
|
||||
},
|
||||
|
||||
#[cfg(unix)]
|
||||
unix: Some({
|
||||
use std::os::unix::prelude::*;
|
||||
let mode = metadata.mode();
|
||||
crate::data::UnixMetadata::from(mode)
|
||||
}),
|
||||
#[cfg(not(unix))]
|
||||
unix: None,
|
||||
|
||||
#[cfg(windows)]
|
||||
windows: Some({
|
||||
use std::os::windows::prelude::*;
|
||||
let attributes = metadata.file_attributes();
|
||||
crate::data::WindowsMetadata::from(attributes)
|
||||
}),
|
||||
#[cfg(not(windows))]
|
||||
windows: None,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(feature = "schemars")]
|
||||
impl Metadata {
|
||||
pub fn root_schema() -> schemars::schema::RootSchema {
|
||||
schemars::schema_for!(Metadata)
|
||||
}
|
||||
}
|
||||
|
||||
/// Represents unix-specific metadata about some path on a remote machine
|
||||
#[derive(Copy, Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
|
||||
#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
|
||||
pub struct UnixMetadata {
|
||||
/// Represents whether or not owner can read from the file
|
||||
pub owner_read: bool,
|
||||
|
||||
/// Represents whether or not owner can write to the file
|
||||
pub owner_write: bool,
|
||||
|
||||
/// Represents whether or not owner can execute the file
|
||||
pub owner_exec: bool,
|
||||
|
||||
/// Represents whether or not associated group can read from the file
|
||||
pub group_read: bool,
|
||||
|
||||
/// Represents whether or not associated group can write to the file
|
||||
pub group_write: bool,
|
||||
|
||||
/// Represents whether or not associated group can execute the file
|
||||
pub group_exec: bool,
|
||||
|
||||
/// Represents whether or not other can read from the file
|
||||
pub other_read: bool,
|
||||
|
||||
/// Represents whether or not other can write to the file
|
||||
pub other_write: bool,
|
||||
|
||||
/// Represents whether or not other can execute the file
|
||||
pub other_exec: bool,
|
||||
}
|
||||
|
||||
#[cfg(feature = "schemars")]
|
||||
impl UnixMetadata {
|
||||
pub fn root_schema() -> schemars::schema::RootSchema {
|
||||
schemars::schema_for!(UnixMetadata)
|
||||
}
|
||||
}
|
||||
|
||||
impl From<u32> for UnixMetadata {
|
||||
/// Create from a unix mode bitset
|
||||
fn from(mode: u32) -> Self {
|
||||
let flags = UnixFilePermissionFlags::from_bits_truncate(mode);
|
||||
Self {
|
||||
owner_read: flags.contains(UnixFilePermissionFlags::OWNER_READ),
|
||||
owner_write: flags.contains(UnixFilePermissionFlags::OWNER_WRITE),
|
||||
owner_exec: flags.contains(UnixFilePermissionFlags::OWNER_EXEC),
|
||||
group_read: flags.contains(UnixFilePermissionFlags::GROUP_READ),
|
||||
group_write: flags.contains(UnixFilePermissionFlags::GROUP_WRITE),
|
||||
group_exec: flags.contains(UnixFilePermissionFlags::GROUP_EXEC),
|
||||
other_read: flags.contains(UnixFilePermissionFlags::OTHER_READ),
|
||||
other_write: flags.contains(UnixFilePermissionFlags::OTHER_WRITE),
|
||||
other_exec: flags.contains(UnixFilePermissionFlags::OTHER_EXEC),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl From<UnixMetadata> for u32 {
|
||||
/// Convert to a unix mode bitset
|
||||
fn from(metadata: UnixMetadata) -> Self {
|
||||
let mut flags = UnixFilePermissionFlags::empty();
|
||||
|
||||
if metadata.owner_read {
|
||||
flags.insert(UnixFilePermissionFlags::OWNER_READ);
|
||||
}
|
||||
if metadata.owner_write {
|
||||
flags.insert(UnixFilePermissionFlags::OWNER_WRITE);
|
||||
}
|
||||
if metadata.owner_exec {
|
||||
flags.insert(UnixFilePermissionFlags::OWNER_EXEC);
|
||||
}
|
||||
|
||||
if metadata.group_read {
|
||||
flags.insert(UnixFilePermissionFlags::GROUP_READ);
|
||||
}
|
||||
if metadata.group_write {
|
||||
flags.insert(UnixFilePermissionFlags::GROUP_WRITE);
|
||||
}
|
||||
if metadata.group_exec {
|
||||
flags.insert(UnixFilePermissionFlags::GROUP_EXEC);
|
||||
}
|
||||
|
||||
if metadata.other_read {
|
||||
flags.insert(UnixFilePermissionFlags::OTHER_READ);
|
||||
}
|
||||
if metadata.other_write {
|
||||
flags.insert(UnixFilePermissionFlags::OTHER_WRITE);
|
||||
}
|
||||
if metadata.other_exec {
|
||||
flags.insert(UnixFilePermissionFlags::OTHER_EXEC);
|
||||
}
|
||||
|
||||
flags.bits
|
||||
}
|
||||
}
|
||||
|
||||
impl UnixMetadata {
|
||||
pub fn is_readonly(self) -> bool {
|
||||
!(self.owner_read || self.group_read || self.other_read)
|
||||
}
|
||||
}
|
||||
|
||||
bitflags! {
|
||||
struct UnixFilePermissionFlags: u32 {
|
||||
const OWNER_READ = 0o400;
|
||||
const OWNER_WRITE = 0o200;
|
||||
const OWNER_EXEC = 0o100;
|
||||
const GROUP_READ = 0o40;
|
||||
const GROUP_WRITE = 0o20;
|
||||
const GROUP_EXEC = 0o10;
|
||||
const OTHER_READ = 0o4;
|
||||
const OTHER_WRITE = 0o2;
|
||||
const OTHER_EXEC = 0o1;
|
||||
}
|
||||
}
|
||||
|
||||
/// Represents windows-specific metadata about some path on a remote machine
|
||||
#[derive(Copy, Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
|
||||
#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
|
||||
pub struct WindowsMetadata {
|
||||
/// Represents whether or not a file or directory is an archive
|
||||
pub archive: bool,
|
||||
|
||||
/// Represents whether or not a file or directory is compressed
|
||||
pub compressed: bool,
|
||||
|
||||
/// Represents whether or not the file or directory is encrypted
|
||||
pub encrypted: bool,
|
||||
|
||||
/// Represents whether or not a file or directory is hidden
|
||||
pub hidden: bool,
|
||||
|
||||
/// Represents whether or not a directory or user data stream is configured with integrity
|
||||
pub integrity_stream: bool,
|
||||
|
||||
/// Represents whether or not a file does not have other attributes set
|
||||
pub normal: bool,
|
||||
|
||||
/// Represents whether or not a file or directory is not to be indexed by content indexing
|
||||
/// service
|
||||
pub not_content_indexed: bool,
|
||||
|
||||
/// Represents whether or not a user data stream is not to be read by the background data
|
||||
/// integrity scanner
|
||||
pub no_scrub_data: bool,
|
||||
|
||||
/// Represents whether or not the data of a file is not available immediately
|
||||
pub offline: bool,
|
||||
|
||||
/// Represents whether or not a file or directory is not fully present locally
|
||||
pub recall_on_data_access: bool,
|
||||
|
||||
/// Represents whether or not a file or directory has no physical representation on the local
|
||||
/// system (is virtual)
|
||||
pub recall_on_open: bool,
|
||||
|
||||
/// Represents whether or not a file or directory has an associated reparse point, or a file is
|
||||
/// a symbolic link
|
||||
pub reparse_point: bool,
|
||||
|
||||
/// Represents whether or not a file is a sparse file
|
||||
pub sparse_file: bool,
|
||||
|
||||
/// Represents whether or not a file or directory is used partially or exclusively by the
|
||||
/// operating system
|
||||
pub system: bool,
|
||||
|
||||
/// Represents whether or not a file is being used for temporary storage
|
||||
pub temporary: bool,
|
||||
}
|
||||
|
||||
#[cfg(feature = "schemars")]
|
||||
impl WindowsMetadata {
|
||||
pub fn root_schema() -> schemars::schema::RootSchema {
|
||||
schemars::schema_for!(WindowsMetadata)
|
||||
}
|
||||
}
|
||||
|
||||
impl From<u32> for WindowsMetadata {
|
||||
/// Create from a windows file attribute bitset
|
||||
fn from(file_attributes: u32) -> Self {
|
||||
let flags = WindowsFileAttributeFlags::from_bits_truncate(file_attributes);
|
||||
Self {
|
||||
archive: flags.contains(WindowsFileAttributeFlags::ARCHIVE),
|
||||
compressed: flags.contains(WindowsFileAttributeFlags::COMPRESSED),
|
||||
encrypted: flags.contains(WindowsFileAttributeFlags::ENCRYPTED),
|
||||
hidden: flags.contains(WindowsFileAttributeFlags::HIDDEN),
|
||||
integrity_stream: flags.contains(WindowsFileAttributeFlags::INTEGRITY_SYSTEM),
|
||||
normal: flags.contains(WindowsFileAttributeFlags::NORMAL),
|
||||
not_content_indexed: flags.contains(WindowsFileAttributeFlags::NOT_CONTENT_INDEXED),
|
||||
no_scrub_data: flags.contains(WindowsFileAttributeFlags::NO_SCRUB_DATA),
|
||||
offline: flags.contains(WindowsFileAttributeFlags::OFFLINE),
|
||||
recall_on_data_access: flags.contains(WindowsFileAttributeFlags::RECALL_ON_DATA_ACCESS),
|
||||
recall_on_open: flags.contains(WindowsFileAttributeFlags::RECALL_ON_OPEN),
|
||||
reparse_point: flags.contains(WindowsFileAttributeFlags::REPARSE_POINT),
|
||||
sparse_file: flags.contains(WindowsFileAttributeFlags::SPARSE_FILE),
|
||||
system: flags.contains(WindowsFileAttributeFlags::SYSTEM),
|
||||
temporary: flags.contains(WindowsFileAttributeFlags::TEMPORARY),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl From<WindowsMetadata> for u32 {
|
||||
/// Convert to a windows file attribute bitset
|
||||
fn from(metadata: WindowsMetadata) -> Self {
|
||||
let mut flags = WindowsFileAttributeFlags::empty();
|
||||
|
||||
if metadata.archive {
|
||||
flags.insert(WindowsFileAttributeFlags::ARCHIVE);
|
||||
}
|
||||
if metadata.compressed {
|
||||
flags.insert(WindowsFileAttributeFlags::COMPRESSED);
|
||||
}
|
||||
if metadata.encrypted {
|
||||
flags.insert(WindowsFileAttributeFlags::ENCRYPTED);
|
||||
}
|
||||
if metadata.hidden {
|
||||
flags.insert(WindowsFileAttributeFlags::HIDDEN);
|
||||
}
|
||||
if metadata.integrity_stream {
|
||||
flags.insert(WindowsFileAttributeFlags::INTEGRITY_SYSTEM);
|
||||
}
|
||||
if metadata.normal {
|
||||
flags.insert(WindowsFileAttributeFlags::NORMAL);
|
||||
}
|
||||
if metadata.not_content_indexed {
|
||||
flags.insert(WindowsFileAttributeFlags::NOT_CONTENT_INDEXED);
|
||||
}
|
||||
if metadata.no_scrub_data {
|
||||
flags.insert(WindowsFileAttributeFlags::NO_SCRUB_DATA);
|
||||
}
|
||||
if metadata.offline {
|
||||
flags.insert(WindowsFileAttributeFlags::OFFLINE);
|
||||
}
|
||||
if metadata.recall_on_data_access {
|
||||
flags.insert(WindowsFileAttributeFlags::RECALL_ON_DATA_ACCESS);
|
||||
}
|
||||
if metadata.recall_on_open {
|
||||
flags.insert(WindowsFileAttributeFlags::RECALL_ON_OPEN);
|
||||
}
|
||||
if metadata.reparse_point {
|
||||
flags.insert(WindowsFileAttributeFlags::REPARSE_POINT);
|
||||
}
|
||||
if metadata.sparse_file {
|
||||
flags.insert(WindowsFileAttributeFlags::SPARSE_FILE);
|
||||
}
|
||||
if metadata.system {
|
||||
flags.insert(WindowsFileAttributeFlags::SYSTEM);
|
||||
}
|
||||
if metadata.temporary {
|
||||
flags.insert(WindowsFileAttributeFlags::TEMPORARY);
|
||||
}
|
||||
|
||||
flags.bits
|
||||
}
|
||||
}
|
||||
|
||||
bitflags! {
|
||||
struct WindowsFileAttributeFlags: u32 {
|
||||
const ARCHIVE = 0x20;
|
||||
const COMPRESSED = 0x800;
|
||||
const ENCRYPTED = 0x4000;
|
||||
const HIDDEN = 0x2;
|
||||
const INTEGRITY_SYSTEM = 0x8000;
|
||||
const NORMAL = 0x80;
|
||||
const NOT_CONTENT_INDEXED = 0x2000;
|
||||
const NO_SCRUB_DATA = 0x20000;
|
||||
const OFFLINE = 0x1000;
|
||||
const RECALL_ON_DATA_ACCESS = 0x400000;
|
||||
const RECALL_ON_OPEN = 0x40000;
|
||||
const REPARSE_POINT = 0x400;
|
||||
const SPARSE_FILE = 0x200;
|
||||
const SYSTEM = 0x4;
|
||||
const TEMPORARY = 0x100;
|
||||
const VIRTUAL = 0x10000;
|
||||
}
|
||||
}
|
@ -0,0 +1,137 @@
|
||||
use derive_more::{Display, Error};
|
||||
use portable_pty::PtySize as PortablePtySize;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::{fmt, num::ParseIntError, str::FromStr};
|
||||
|
||||
/// Represents the size associated with a remote PTY
|
||||
#[derive(Copy, Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
|
||||
#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
|
||||
pub struct PtySize {
|
||||
/// Number of lines of text
|
||||
pub rows: u16,
|
||||
|
||||
/// Number of columns of text
|
||||
pub cols: u16,
|
||||
|
||||
/// Width of a cell in pixels. Note that some systems never fill this value and ignore it.
|
||||
#[serde(default)]
|
||||
pub pixel_width: u16,
|
||||
|
||||
/// Height of a cell in pixels. Note that some systems never fill this value and ignore it.
|
||||
#[serde(default)]
|
||||
pub pixel_height: u16,
|
||||
}
|
||||
|
||||
impl PtySize {
|
||||
/// Creates new size using just rows and columns
|
||||
pub fn from_rows_and_cols(rows: u16, cols: u16) -> Self {
|
||||
Self {
|
||||
rows,
|
||||
cols,
|
||||
..Default::default()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(feature = "schemars")]
|
||||
impl PtySize {
|
||||
pub fn root_schema() -> schemars::schema::RootSchema {
|
||||
schemars::schema_for!(PtySize)
|
||||
}
|
||||
}
|
||||
|
||||
impl From<PortablePtySize> for PtySize {
|
||||
fn from(size: PortablePtySize) -> Self {
|
||||
Self {
|
||||
rows: size.rows,
|
||||
cols: size.cols,
|
||||
pixel_width: size.pixel_width,
|
||||
pixel_height: size.pixel_height,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl From<PtySize> for PortablePtySize {
|
||||
fn from(size: PtySize) -> Self {
|
||||
Self {
|
||||
rows: size.rows,
|
||||
cols: size.cols,
|
||||
pixel_width: size.pixel_width,
|
||||
pixel_height: size.pixel_height,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl fmt::Display for PtySize {
|
||||
/// Prints out `rows,cols[,pixel_width,pixel_height]` where the
|
||||
/// pixel width and pixel height are only included if either
|
||||
/// one of them is not zero
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
write!(f, "{},{}", self.rows, self.cols)?;
|
||||
if self.pixel_width > 0 || self.pixel_height > 0 {
|
||||
write!(f, ",{},{}", self.pixel_width, self.pixel_height)?;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for PtySize {
|
||||
fn default() -> Self {
|
||||
PtySize {
|
||||
rows: 24,
|
||||
cols: 80,
|
||||
pixel_width: 0,
|
||||
pixel_height: 0,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, PartialEq, Eq, Display, Error)]
|
||||
pub enum PtySizeParseError {
|
||||
MissingRows,
|
||||
MissingColumns,
|
||||
InvalidRows(ParseIntError),
|
||||
InvalidColumns(ParseIntError),
|
||||
InvalidPixelWidth(ParseIntError),
|
||||
InvalidPixelHeight(ParseIntError),
|
||||
}
|
||||
|
||||
impl FromStr for PtySize {
|
||||
type Err = PtySizeParseError;
|
||||
|
||||
/// Attempts to parse a str into PtySize using one of the following formats:
|
||||
///
|
||||
/// * rows,cols (defaults to 0 for pixel_width & pixel_height)
|
||||
/// * rows,cols,pixel_width,pixel_height
|
||||
fn from_str(s: &str) -> Result<Self, Self::Err> {
|
||||
let mut tokens = s.split(',');
|
||||
|
||||
Ok(Self {
|
||||
rows: tokens
|
||||
.next()
|
||||
.ok_or(PtySizeParseError::MissingRows)?
|
||||
.trim()
|
||||
.parse()
|
||||
.map_err(PtySizeParseError::InvalidRows)?,
|
||||
cols: tokens
|
||||
.next()
|
||||
.ok_or(PtySizeParseError::MissingColumns)?
|
||||
.trim()
|
||||
.parse()
|
||||
.map_err(PtySizeParseError::InvalidColumns)?,
|
||||
pixel_width: tokens
|
||||
.next()
|
||||
.map(|s| s.trim().parse())
|
||||
.transpose()
|
||||
.map_err(PtySizeParseError::InvalidPixelWidth)?
|
||||
.unwrap_or(0),
|
||||
pixel_height: tokens
|
||||
.next()
|
||||
.map(|s| s.trim().parse())
|
||||
.transpose()
|
||||
.map_err(PtySizeParseError::InvalidPixelHeight)?
|
||||
.unwrap_or(0),
|
||||
})
|
||||
}
|
||||
}
|
@ -0,0 +1,45 @@
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::{env, path::PathBuf};
|
||||
|
||||
/// Represents information about a system
|
||||
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
|
||||
#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
|
||||
pub struct SystemInfo {
|
||||
/// Family of the operating system as described in
|
||||
/// https://doc.rust-lang.org/std/env/consts/constant.FAMILY.html
|
||||
pub family: String,
|
||||
|
||||
/// Name of the specific operating system as described in
|
||||
/// https://doc.rust-lang.org/std/env/consts/constant.OS.html
|
||||
pub os: String,
|
||||
|
||||
/// Architecture of the CPI as described in
|
||||
/// https://doc.rust-lang.org/std/env/consts/constant.ARCH.html
|
||||
pub arch: String,
|
||||
|
||||
/// Current working directory of the running server process
|
||||
pub current_dir: PathBuf,
|
||||
|
||||
/// Primary separator for path components for the current platform
|
||||
/// as defined in https://doc.rust-lang.org/std/path/constant.MAIN_SEPARATOR.html
|
||||
pub main_separator: char,
|
||||
}
|
||||
|
||||
#[cfg(feature = "schemars")]
|
||||
impl SystemInfo {
|
||||
pub fn root_schema() -> schemars::schema::RootSchema {
|
||||
schemars::schema_for!(SystemInfo)
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for SystemInfo {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
family: env::consts::FAMILY.to_string(),
|
||||
os: env::consts::OS.to_string(),
|
||||
arch: env::consts::ARCH.to_string(),
|
||||
current_dir: env::current_dir().unwrap_or_default(),
|
||||
main_separator: std::path::MAIN_SEPARATOR,
|
||||
}
|
||||
}
|
||||
}
|
@ -0,0 +1,27 @@
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
pub(crate) fn deserialize_u128_option<'de, D>(deserializer: D) -> Result<Option<u128>, D::Error>
|
||||
where
|
||||
D: serde::Deserializer<'de>,
|
||||
{
|
||||
match Option::<String>::deserialize(deserializer)? {
|
||||
Some(s) => match s.parse::<u128>() {
|
||||
Ok(value) => Ok(Some(value)),
|
||||
Err(error) => Err(serde::de::Error::custom(format!(
|
||||
"Cannot convert to u128 with error: {:?}",
|
||||
error
|
||||
))),
|
||||
},
|
||||
None => Ok(None),
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn serialize_u128_option<S: serde::Serializer>(
|
||||
val: &Option<u128>,
|
||||
s: S,
|
||||
) -> Result<S::Ok, S::Error> {
|
||||
match val {
|
||||
Some(v) => format!("{}", *v).serialize(s),
|
||||
None => s.serialize_unit(),
|
||||
}
|
||||
}
|
@ -1,13 +1,20 @@
|
||||
mod api;
|
||||
pub use api::*;
|
||||
|
||||
mod client;
|
||||
pub use client::*;
|
||||
|
||||
mod constants;
|
||||
|
||||
mod net;
|
||||
pub use net::*;
|
||||
mod credentials;
|
||||
pub use credentials::*;
|
||||
|
||||
pub mod data;
|
||||
pub use data::*;
|
||||
pub use data::{DistantMsg, DistantRequestData, DistantResponseData, Map};
|
||||
|
||||
mod manager;
|
||||
pub use manager::*;
|
||||
|
||||
mod constants;
|
||||
mod serde_str;
|
||||
|
||||
mod server;
|
||||
pub use server::*;
|
||||
/// Re-export of `distant-net` as `net`
|
||||
pub use distant_net as net;
|
||||
|
@ -0,0 +1,7 @@
|
||||
mod client;
|
||||
mod data;
|
||||
mod server;
|
||||
|
||||
pub use client::*;
|
||||
pub use data::*;
|
||||
pub use server::*;
|
@ -0,0 +1,761 @@
|
||||
use super::data::{
|
||||
ConnectionId, ConnectionInfo, ConnectionList, Destination, Extra, ManagerRequest,
|
||||
ManagerResponse,
|
||||
};
|
||||
use crate::{DistantChannel, DistantClient, DistantMsg, DistantRequestData, DistantResponseData};
|
||||
use distant_net::{
|
||||
router, Auth, AuthServer, Client, IntoSplit, MpscTransport, OneshotListener, Request, Response,
|
||||
ServerExt, ServerRef, UntypedTransportRead, UntypedTransportWrite,
|
||||
};
|
||||
use log::*;
|
||||
use std::{
|
||||
collections::HashMap,
|
||||
io,
|
||||
ops::{Deref, DerefMut},
|
||||
};
|
||||
use tokio::task::JoinHandle;
|
||||
|
||||
mod config;
|
||||
pub use config::*;
|
||||
|
||||
mod ext;
|
||||
pub use ext::*;
|
||||
|
||||
router!(DistantManagerClientRouter {
|
||||
auth_transport: Request<Auth> => Response<Auth>,
|
||||
manager_transport: Response<ManagerResponse> => Request<ManagerRequest>,
|
||||
});
|
||||
|
||||
/// Represents a client that can connect to a remote distant manager
|
||||
pub struct DistantManagerClient {
|
||||
auth: Box<dyn ServerRef>,
|
||||
client: Client<ManagerRequest, ManagerResponse>,
|
||||
distant_clients: HashMap<ConnectionId, ClientHandle>,
|
||||
}
|
||||
|
||||
impl Drop for DistantManagerClient {
|
||||
fn drop(&mut self) {
|
||||
self.auth.abort();
|
||||
self.client.abort();
|
||||
}
|
||||
}
|
||||
|
||||
/// Represents a raw channel between a manager client and some remote server
|
||||
pub struct RawDistantChannel {
|
||||
pub transport: MpscTransport<
|
||||
Request<DistantMsg<DistantRequestData>>,
|
||||
Response<DistantMsg<DistantResponseData>>,
|
||||
>,
|
||||
forward_task: JoinHandle<()>,
|
||||
mailbox_task: JoinHandle<()>,
|
||||
}
|
||||
|
||||
impl RawDistantChannel {
|
||||
pub fn abort(&self) {
|
||||
self.forward_task.abort();
|
||||
self.mailbox_task.abort();
|
||||
}
|
||||
}
|
||||
|
||||
impl Deref for RawDistantChannel {
|
||||
type Target = MpscTransport<
|
||||
Request<DistantMsg<DistantRequestData>>,
|
||||
Response<DistantMsg<DistantResponseData>>,
|
||||
>;
|
||||
|
||||
fn deref(&self) -> &Self::Target {
|
||||
&self.transport
|
||||
}
|
||||
}
|
||||
|
||||
impl DerefMut for RawDistantChannel {
|
||||
fn deref_mut(&mut self) -> &mut Self::Target {
|
||||
&mut self.transport
|
||||
}
|
||||
}
|
||||
|
||||
struct ClientHandle {
|
||||
client: DistantClient,
|
||||
forward_task: JoinHandle<()>,
|
||||
mailbox_task: JoinHandle<()>,
|
||||
}
|
||||
|
||||
impl Drop for ClientHandle {
|
||||
fn drop(&mut self) {
|
||||
self.forward_task.abort();
|
||||
self.mailbox_task.abort();
|
||||
}
|
||||
}
|
||||
|
||||
impl DistantManagerClient {
|
||||
/// Initializes a client using the provided [`UntypedTransport`]
|
||||
pub fn new<T>(config: DistantManagerClientConfig, transport: T) -> io::Result<Self>
|
||||
where
|
||||
T: IntoSplit + 'static,
|
||||
T::Read: UntypedTransportRead + 'static,
|
||||
T::Write: UntypedTransportWrite + 'static,
|
||||
{
|
||||
let DistantManagerClientRouter {
|
||||
auth_transport,
|
||||
manager_transport,
|
||||
..
|
||||
} = DistantManagerClientRouter::new(transport);
|
||||
|
||||
// Initialize our client with manager request/response transport
|
||||
let (writer, reader) = manager_transport.into_split();
|
||||
let client = Client::new(writer, reader)?;
|
||||
|
||||
// Initialize our auth handler with auth/auth transport
|
||||
let auth = AuthServer {
|
||||
on_challenge: config.on_challenge,
|
||||
on_verify: config.on_verify,
|
||||
on_info: config.on_info,
|
||||
on_error: config.on_error,
|
||||
}
|
||||
.start(OneshotListener::from_value(auth_transport.into_split()))?;
|
||||
|
||||
Ok(Self {
|
||||
auth,
|
||||
client,
|
||||
distant_clients: HashMap::new(),
|
||||
})
|
||||
}
|
||||
|
||||
/// Request that the manager launches a new server at the given `destination`
|
||||
/// with `extra` being passed for destination-specific details, returning the new
|
||||
/// `destination` of the spawned server to connect to
|
||||
pub async fn launch(
|
||||
&mut self,
|
||||
destination: impl Into<Destination>,
|
||||
extra: impl Into<Extra>,
|
||||
) -> io::Result<Destination> {
|
||||
let destination = Box::new(destination.into());
|
||||
let extra = extra.into();
|
||||
trace!("launch({}, {})", destination, extra);
|
||||
|
||||
let res = self
|
||||
.client
|
||||
.send(ManagerRequest::Launch { destination, extra })
|
||||
.await?;
|
||||
match res.payload {
|
||||
ManagerResponse::Launched { destination } => Ok(destination),
|
||||
ManagerResponse::Error(x) => Err(x.into()),
|
||||
x => Err(io::Error::new(
|
||||
io::ErrorKind::InvalidData,
|
||||
format!("Got unexpected response: {:?}", x),
|
||||
)),
|
||||
}
|
||||
}
|
||||
|
||||
/// Request that the manager establishes a new connection at the given `destination`
|
||||
/// with `extra` being passed for destination-specific details
|
||||
pub async fn connect(
|
||||
&mut self,
|
||||
destination: impl Into<Destination>,
|
||||
extra: impl Into<Extra>,
|
||||
) -> io::Result<ConnectionId> {
|
||||
let destination = Box::new(destination.into());
|
||||
let extra = extra.into();
|
||||
trace!("connect({}, {})", destination, extra);
|
||||
|
||||
let res = self
|
||||
.client
|
||||
.send(ManagerRequest::Connect { destination, extra })
|
||||
.await?;
|
||||
match res.payload {
|
||||
ManagerResponse::Connected { id } => Ok(id),
|
||||
ManagerResponse::Error(x) => Err(x.into()),
|
||||
x => Err(io::Error::new(
|
||||
io::ErrorKind::InvalidData,
|
||||
format!("Got unexpected response: {:?}", x),
|
||||
)),
|
||||
}
|
||||
}
|
||||
|
||||
/// Establishes a channel with the server represented by the `connection_id`,
|
||||
/// returning a [`DistantChannel`] acting as the connection
|
||||
///
|
||||
/// ### Note
|
||||
///
|
||||
/// Multiple calls to open a channel against the same connection will result in
|
||||
/// clones of the same [`DistantChannel`] rather than establishing a duplicate
|
||||
/// remote connection to the same server
|
||||
pub async fn open_channel(
|
||||
&mut self,
|
||||
connection_id: ConnectionId,
|
||||
) -> io::Result<DistantChannel> {
|
||||
trace!("open_channel({})", connection_id);
|
||||
if let Some(handle) = self.distant_clients.get(&connection_id) {
|
||||
Ok(handle.client.clone_channel())
|
||||
} else {
|
||||
let RawDistantChannel {
|
||||
transport,
|
||||
forward_task,
|
||||
mailbox_task,
|
||||
} = self.open_raw_channel(connection_id).await?;
|
||||
let (writer, reader) = transport.into_split();
|
||||
let client = DistantClient::new(writer, reader)?;
|
||||
let channel = client.clone_channel();
|
||||
self.distant_clients.insert(
|
||||
connection_id,
|
||||
ClientHandle {
|
||||
client,
|
||||
forward_task,
|
||||
mailbox_task,
|
||||
},
|
||||
);
|
||||
Ok(channel)
|
||||
}
|
||||
}
|
||||
|
||||
/// Establishes a channel with the server represented by the `connection_id`,
|
||||
/// returning a [`Transport`] acting as the connection
|
||||
///
|
||||
/// ### Note
|
||||
///
|
||||
/// Multiple calls to open a channel against the same connection will result in establishing a
|
||||
/// duplicate remote connections to the same server, so take care when using this method
|
||||
pub async fn open_raw_channel(
|
||||
&mut self,
|
||||
connection_id: ConnectionId,
|
||||
) -> io::Result<RawDistantChannel> {
|
||||
trace!("open_raw_channel({})", connection_id);
|
||||
let mut mailbox = self
|
||||
.client
|
||||
.mail(ManagerRequest::OpenChannel { id: connection_id })
|
||||
.await?;
|
||||
|
||||
// Wait for the first response, which should be channel confirmation
|
||||
let channel_id = match mailbox.next().await {
|
||||
Some(response) => match response.payload {
|
||||
ManagerResponse::ChannelOpened { id } => Ok(id),
|
||||
ManagerResponse::Error(x) => Err(x.into()),
|
||||
x => Err(io::Error::new(
|
||||
io::ErrorKind::InvalidData,
|
||||
format!("Got unexpected response: {:?}", x),
|
||||
)),
|
||||
},
|
||||
None => Err(io::Error::new(
|
||||
io::ErrorKind::ConnectionAborted,
|
||||
"open_channel mailbox aborted",
|
||||
)),
|
||||
}?;
|
||||
|
||||
// Spawn reader and writer tasks to forward requests and replies
|
||||
// using our opened channel
|
||||
let (t1, t2) = MpscTransport::pair(1);
|
||||
let (mut writer, mut reader) = t1.into_split();
|
||||
let mailbox_task = tokio::spawn(async move {
|
||||
use distant_net::TypedAsyncWrite;
|
||||
while let Some(response) = mailbox.next().await {
|
||||
match response.payload {
|
||||
ManagerResponse::Channel { response, .. } => {
|
||||
if let Err(x) = writer.write(response).await {
|
||||
error!("[Conn {}] {}", connection_id, x);
|
||||
}
|
||||
}
|
||||
ManagerResponse::ChannelClosed { .. } => break,
|
||||
_ => continue,
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
let mut manager_channel = self.client.clone_channel();
|
||||
let forward_task = tokio::spawn(async move {
|
||||
use distant_net::TypedAsyncRead;
|
||||
loop {
|
||||
match reader.read().await {
|
||||
Ok(Some(request)) => {
|
||||
// NOTE: In this situation, we do not expect a response to this
|
||||
// request (even if the server sends something back)
|
||||
if let Err(x) = manager_channel
|
||||
.fire(ManagerRequest::Channel {
|
||||
id: channel_id,
|
||||
request,
|
||||
})
|
||||
.await
|
||||
{
|
||||
error!("[Conn {}] {}", connection_id, x);
|
||||
}
|
||||
}
|
||||
Ok(None) => break,
|
||||
Err(x) => {
|
||||
error!("[Conn {}] {}", connection_id, x);
|
||||
continue;
|
||||
}
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
Ok(RawDistantChannel {
|
||||
transport: t2,
|
||||
forward_task,
|
||||
mailbox_task,
|
||||
})
|
||||
}
|
||||
|
||||
/// Retrieves information about a specific connection
|
||||
pub async fn info(&mut self, id: ConnectionId) -> io::Result<ConnectionInfo> {
|
||||
trace!("info({})", id);
|
||||
let res = self.client.send(ManagerRequest::Info { id }).await?;
|
||||
match res.payload {
|
||||
ManagerResponse::Info(info) => Ok(info),
|
||||
ManagerResponse::Error(x) => Err(x.into()),
|
||||
x => Err(io::Error::new(
|
||||
io::ErrorKind::InvalidData,
|
||||
format!("Got unexpected response: {:?}", x),
|
||||
)),
|
||||
}
|
||||
}
|
||||
|
||||
/// Kills the specified connection
|
||||
pub async fn kill(&mut self, id: ConnectionId) -> io::Result<()> {
|
||||
trace!("kill({})", id);
|
||||
let res = self.client.send(ManagerRequest::Kill { id }).await?;
|
||||
match res.payload {
|
||||
ManagerResponse::Killed => Ok(()),
|
||||
ManagerResponse::Error(x) => Err(x.into()),
|
||||
x => Err(io::Error::new(
|
||||
io::ErrorKind::InvalidData,
|
||||
format!("Got unexpected response: {:?}", x),
|
||||
)),
|
||||
}
|
||||
}
|
||||
|
||||
/// Retrieves a list of active connections
|
||||
pub async fn list(&mut self) -> io::Result<ConnectionList> {
|
||||
trace!("list()");
|
||||
let res = self.client.send(ManagerRequest::List).await?;
|
||||
match res.payload {
|
||||
ManagerResponse::List(list) => Ok(list),
|
||||
ManagerResponse::Error(x) => Err(x.into()),
|
||||
x => Err(io::Error::new(
|
||||
io::ErrorKind::InvalidData,
|
||||
format!("Got unexpected response: {:?}", x),
|
||||
)),
|
||||
}
|
||||
}
|
||||
|
||||
/// Requests that the manager shuts down
|
||||
pub async fn shutdown(&mut self) -> io::Result<()> {
|
||||
trace!("shutdown()");
|
||||
let res = self.client.send(ManagerRequest::Shutdown).await?;
|
||||
match res.payload {
|
||||
ManagerResponse::Shutdown => Ok(()),
|
||||
ManagerResponse::Error(x) => Err(x.into()),
|
||||
x => Err(io::Error::new(
|
||||
io::ErrorKind::InvalidData,
|
||||
format!("Got unexpected response: {:?}", x),
|
||||
)),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::data::{Error, ErrorKind};
|
||||
use distant_net::{
|
||||
FramedTransport, InmemoryTransport, PlainCodec, UntypedTransportRead, UntypedTransportWrite,
|
||||
};
|
||||
|
||||
fn setup() -> (
|
||||
DistantManagerClient,
|
||||
FramedTransport<InmemoryTransport, PlainCodec>,
|
||||
) {
|
||||
let (t1, t2) = FramedTransport::pair(100);
|
||||
let client =
|
||||
DistantManagerClient::new(DistantManagerClientConfig::with_empty_prompts(), t1)
|
||||
.unwrap();
|
||||
(client, t2)
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn test_error() -> Error {
|
||||
Error {
|
||||
kind: ErrorKind::Interrupted,
|
||||
description: "test error".to_string(),
|
||||
}
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn test_io_error() -> io::Error {
|
||||
test_error().into()
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn connect_should_report_error_if_receives_error_response() {
|
||||
let (mut client, mut transport) = setup();
|
||||
|
||||
tokio::spawn(async move {
|
||||
let request = transport
|
||||
.read::<Request<ManagerRequest>>()
|
||||
.await
|
||||
.unwrap()
|
||||
.unwrap();
|
||||
|
||||
transport
|
||||
.write(Response::new(
|
||||
request.id,
|
||||
ManagerResponse::Error(test_error()),
|
||||
))
|
||||
.await
|
||||
.unwrap();
|
||||
});
|
||||
|
||||
let err = client
|
||||
.connect(
|
||||
"scheme://host".parse::<Destination>().unwrap(),
|
||||
"key=value".parse::<Extra>().unwrap(),
|
||||
)
|
||||
.await
|
||||
.unwrap_err();
|
||||
assert_eq!(err.kind(), test_io_error().kind());
|
||||
assert_eq!(err.to_string(), test_io_error().to_string());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn connect_should_report_error_if_receives_unexpected_response() {
|
||||
let (mut client, mut transport) = setup();
|
||||
|
||||
tokio::spawn(async move {
|
||||
let request = transport
|
||||
.read::<Request<ManagerRequest>>()
|
||||
.await
|
||||
.unwrap()
|
||||
.unwrap();
|
||||
|
||||
transport
|
||||
.write(Response::new(request.id, ManagerResponse::Shutdown))
|
||||
.await
|
||||
.unwrap();
|
||||
});
|
||||
|
||||
let err = client
|
||||
.connect(
|
||||
"scheme://host".parse::<Destination>().unwrap(),
|
||||
"key=value".parse::<Extra>().unwrap(),
|
||||
)
|
||||
.await
|
||||
.unwrap_err();
|
||||
assert_eq!(err.kind(), io::ErrorKind::InvalidData);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn connect_should_return_id_from_successful_response() {
|
||||
let (mut client, mut transport) = setup();
|
||||
|
||||
let expected_id = 999;
|
||||
tokio::spawn(async move {
|
||||
let request = transport
|
||||
.read::<Request<ManagerRequest>>()
|
||||
.await
|
||||
.unwrap()
|
||||
.unwrap();
|
||||
|
||||
transport
|
||||
.write(Response::new(
|
||||
request.id,
|
||||
ManagerResponse::Connected { id: expected_id },
|
||||
))
|
||||
.await
|
||||
.unwrap();
|
||||
});
|
||||
|
||||
let id = client
|
||||
.connect(
|
||||
"scheme://host".parse::<Destination>().unwrap(),
|
||||
"key=value".parse::<Extra>().unwrap(),
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
assert_eq!(id, expected_id);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn info_should_report_error_if_receives_error_response() {
|
||||
let (mut client, mut transport) = setup();
|
||||
|
||||
tokio::spawn(async move {
|
||||
let request = transport
|
||||
.read::<Request<ManagerRequest>>()
|
||||
.await
|
||||
.unwrap()
|
||||
.unwrap();
|
||||
|
||||
transport
|
||||
.write(Response::new(
|
||||
request.id,
|
||||
ManagerResponse::Error(test_error()),
|
||||
))
|
||||
.await
|
||||
.unwrap();
|
||||
});
|
||||
|
||||
let err = client.info(123).await.unwrap_err();
|
||||
assert_eq!(err.kind(), test_io_error().kind());
|
||||
assert_eq!(err.to_string(), test_io_error().to_string());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn info_should_report_error_if_receives_unexpected_response() {
|
||||
let (mut client, mut transport) = setup();
|
||||
|
||||
tokio::spawn(async move {
|
||||
let request = transport
|
||||
.read::<Request<ManagerRequest>>()
|
||||
.await
|
||||
.unwrap()
|
||||
.unwrap();
|
||||
|
||||
transport
|
||||
.write(Response::new(request.id, ManagerResponse::Shutdown))
|
||||
.await
|
||||
.unwrap();
|
||||
});
|
||||
|
||||
let err = client.info(123).await.unwrap_err();
|
||||
assert_eq!(err.kind(), io::ErrorKind::InvalidData);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn info_should_return_connection_info_from_successful_response() {
|
||||
let (mut client, mut transport) = setup();
|
||||
|
||||
tokio::spawn(async move {
|
||||
let request = transport
|
||||
.read::<Request<ManagerRequest>>()
|
||||
.await
|
||||
.unwrap()
|
||||
.unwrap();
|
||||
|
||||
let info = ConnectionInfo {
|
||||
id: 123,
|
||||
destination: "scheme://host".parse::<Destination>().unwrap(),
|
||||
extra: "key=value".parse::<Extra>().unwrap(),
|
||||
};
|
||||
|
||||
transport
|
||||
.write(Response::new(request.id, ManagerResponse::Info(info)))
|
||||
.await
|
||||
.unwrap();
|
||||
});
|
||||
|
||||
let info = client.info(123).await.unwrap();
|
||||
assert_eq!(info.id, 123);
|
||||
assert_eq!(
|
||||
info.destination,
|
||||
"scheme://host".parse::<Destination>().unwrap()
|
||||
);
|
||||
assert_eq!(info.extra, "key=value".parse::<Extra>().unwrap());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn list_should_report_error_if_receives_error_response() {
|
||||
let (mut client, mut transport) = setup();
|
||||
|
||||
tokio::spawn(async move {
|
||||
let request = transport
|
||||
.read::<Request<ManagerRequest>>()
|
||||
.await
|
||||
.unwrap()
|
||||
.unwrap();
|
||||
|
||||
transport
|
||||
.write(Response::new(
|
||||
request.id,
|
||||
ManagerResponse::Error(test_error()),
|
||||
))
|
||||
.await
|
||||
.unwrap();
|
||||
});
|
||||
|
||||
let err = client.list().await.unwrap_err();
|
||||
assert_eq!(err.kind(), test_io_error().kind());
|
||||
assert_eq!(err.to_string(), test_io_error().to_string());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn list_should_report_error_if_receives_unexpected_response() {
|
||||
let (mut client, mut transport) = setup();
|
||||
|
||||
tokio::spawn(async move {
|
||||
let request = transport
|
||||
.read::<Request<ManagerRequest>>()
|
||||
.await
|
||||
.unwrap()
|
||||
.unwrap();
|
||||
|
||||
transport
|
||||
.write(Response::new(request.id, ManagerResponse::Shutdown))
|
||||
.await
|
||||
.unwrap();
|
||||
});
|
||||
|
||||
let err = client.list().await.unwrap_err();
|
||||
assert_eq!(err.kind(), io::ErrorKind::InvalidData);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn list_should_return_connection_list_from_successful_response() {
|
||||
let (mut client, mut transport) = setup();
|
||||
|
||||
tokio::spawn(async move {
|
||||
let request = transport
|
||||
.read::<Request<ManagerRequest>>()
|
||||
.await
|
||||
.unwrap()
|
||||
.unwrap();
|
||||
|
||||
let mut list = ConnectionList::new();
|
||||
list.insert(123, "scheme://host".parse::<Destination>().unwrap());
|
||||
|
||||
transport
|
||||
.write(Response::new(request.id, ManagerResponse::List(list)))
|
||||
.await
|
||||
.unwrap();
|
||||
});
|
||||
|
||||
let list = client.list().await.unwrap();
|
||||
assert_eq!(list.len(), 1);
|
||||
assert_eq!(
|
||||
list.get(&123).expect("Connection list missing item"),
|
||||
&"scheme://host".parse::<Destination>().unwrap()
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn kill_should_report_error_if_receives_error_response() {
|
||||
let (mut client, mut transport) = setup();
|
||||
|
||||
tokio::spawn(async move {
|
||||
let request = transport
|
||||
.read::<Request<ManagerRequest>>()
|
||||
.await
|
||||
.unwrap()
|
||||
.unwrap();
|
||||
|
||||
transport
|
||||
.write(Response::new(
|
||||
request.id,
|
||||
ManagerResponse::Error(test_error()),
|
||||
))
|
||||
.await
|
||||
.unwrap();
|
||||
});
|
||||
|
||||
let err = client.kill(123).await.unwrap_err();
|
||||
assert_eq!(err.kind(), test_io_error().kind());
|
||||
assert_eq!(err.to_string(), test_io_error().to_string());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn kill_should_report_error_if_receives_unexpected_response() {
|
||||
let (mut client, mut transport) = setup();
|
||||
|
||||
tokio::spawn(async move {
|
||||
let request = transport
|
||||
.read::<Request<ManagerRequest>>()
|
||||
.await
|
||||
.unwrap()
|
||||
.unwrap();
|
||||
|
||||
transport
|
||||
.write(Response::new(request.id, ManagerResponse::Shutdown))
|
||||
.await
|
||||
.unwrap();
|
||||
});
|
||||
|
||||
let err = client.kill(123).await.unwrap_err();
|
||||
assert_eq!(err.kind(), io::ErrorKind::InvalidData);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn kill_should_return_success_from_successful_response() {
|
||||
let (mut client, mut transport) = setup();
|
||||
|
||||
tokio::spawn(async move {
|
||||
let request = transport
|
||||
.read::<Request<ManagerRequest>>()
|
||||
.await
|
||||
.unwrap()
|
||||
.unwrap();
|
||||
|
||||
transport
|
||||
.write(Response::new(request.id, ManagerResponse::Killed))
|
||||
.await
|
||||
.unwrap();
|
||||
});
|
||||
|
||||
client.kill(123).await.unwrap();
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn shutdown_should_report_error_if_receives_error_response() {
|
||||
let (mut client, mut transport) = setup();
|
||||
|
||||
tokio::spawn(async move {
|
||||
let request = transport
|
||||
.read::<Request<ManagerRequest>>()
|
||||
.await
|
||||
.unwrap()
|
||||
.unwrap();
|
||||
|
||||
transport
|
||||
.write(Response::new(
|
||||
request.id,
|
||||
ManagerResponse::Connected { id: 0 },
|
||||
))
|
||||
.await
|
||||
.unwrap();
|
||||
});
|
||||
|
||||
let err = client.shutdown().await.unwrap_err();
|
||||
assert_eq!(err.kind(), io::ErrorKind::InvalidData);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn shutdown_should_report_error_if_receives_unexpected_response() {
|
||||
let (mut client, mut transport) = setup();
|
||||
|
||||
tokio::spawn(async move {
|
||||
let request = transport
|
||||
.read::<Request<ManagerRequest>>()
|
||||
.await
|
||||
.unwrap()
|
||||
.unwrap();
|
||||
|
||||
transport
|
||||
.write(Response::new(
|
||||
request.id,
|
||||
ManagerResponse::Error(test_error()),
|
||||
))
|
||||
.await
|
||||
.unwrap();
|
||||
});
|
||||
|
||||
let err = client.shutdown().await.unwrap_err();
|
||||
assert_eq!(err.kind(), test_io_error().kind());
|
||||
assert_eq!(err.to_string(), test_io_error().to_string());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn shutdown_should_return_success_from_successful_response() {
|
||||
let (mut client, mut transport) = setup();
|
||||
|
||||
tokio::spawn(async move {
|
||||
let request = transport
|
||||
.read::<Request<ManagerRequest>>()
|
||||
.await
|
||||
.unwrap()
|
||||
.unwrap();
|
||||
|
||||
transport
|
||||
.write(Response::new(request.id, ManagerResponse::Shutdown))
|
||||
.await
|
||||
.unwrap();
|
||||
});
|
||||
|
||||
client.shutdown().await.unwrap();
|
||||
}
|
||||
}
|
@ -0,0 +1,85 @@
|
||||
use distant_net::{AuthChallengeFn, AuthErrorFn, AuthInfoFn, AuthVerifyFn, AuthVerifyKind};
|
||||
use log::*;
|
||||
use std::io;
|
||||
|
||||
/// Configuration to use when creating a new [`DistantManagerClient`](super::DistantManagerClient)
|
||||
pub struct DistantManagerClientConfig {
|
||||
pub on_challenge: Box<AuthChallengeFn>,
|
||||
pub on_verify: Box<AuthVerifyFn>,
|
||||
pub on_info: Box<AuthInfoFn>,
|
||||
pub on_error: Box<AuthErrorFn>,
|
||||
}
|
||||
|
||||
impl DistantManagerClientConfig {
|
||||
/// Creates a new config with prompts that return empty strings
|
||||
pub fn with_empty_prompts() -> Self {
|
||||
Self::with_prompts(|_| Ok("".to_string()), |_| Ok("".to_string()))
|
||||
}
|
||||
|
||||
/// Creates a new config with two prompts
|
||||
///
|
||||
/// * `password_prompt` - used for prompting for a secret, and should not display what is typed
|
||||
/// * `text_prompt` - used for general text, and is okay to display what is typed
|
||||
pub fn with_prompts<PP, PT>(password_prompt: PP, text_prompt: PT) -> Self
|
||||
where
|
||||
PP: Fn(&str) -> io::Result<String> + Send + Sync + 'static,
|
||||
PT: Fn(&str) -> io::Result<String> + Send + Sync + 'static,
|
||||
{
|
||||
Self {
|
||||
on_challenge: Box::new(move |questions, _extra| {
|
||||
trace!("[manager client] on_challenge({questions:?}, {_extra:?})");
|
||||
let mut answers = Vec::new();
|
||||
for question in questions.iter() {
|
||||
// Contains all prompt lines including same line
|
||||
let mut lines = question.text.split('\n').collect::<Vec<_>>();
|
||||
|
||||
// Line that is prompt on same line as answer
|
||||
let line = lines.pop().unwrap();
|
||||
|
||||
// Go ahead and display all other lines
|
||||
for line in lines.into_iter() {
|
||||
eprintln!("{}", line);
|
||||
}
|
||||
|
||||
// Get an answer from user input, or use a blank string as an answer
|
||||
// if we fail to get input from the user
|
||||
let answer = password_prompt(line).unwrap_or_default();
|
||||
|
||||
answers.push(answer);
|
||||
}
|
||||
answers
|
||||
}),
|
||||
on_verify: Box::new(move |kind, text| {
|
||||
trace!("[manager client] on_verify({kind}, {text})");
|
||||
match kind {
|
||||
AuthVerifyKind::Host => {
|
||||
eprintln!("{}", text);
|
||||
|
||||
match text_prompt("Enter [y/N]> ") {
|
||||
Ok(answer) => {
|
||||
trace!("Verify? Answer = '{answer}'");
|
||||
matches!(answer.trim(), "y" | "Y" | "yes" | "YES")
|
||||
}
|
||||
Err(x) => {
|
||||
error!("Failed verification: {x}");
|
||||
false
|
||||
}
|
||||
}
|
||||
}
|
||||
x => {
|
||||
error!("Unsupported verify kind: {x}");
|
||||
false
|
||||
}
|
||||
}
|
||||
}),
|
||||
on_info: Box::new(|text| {
|
||||
trace!("[manager client] on_info({text})");
|
||||
println!("{}", text);
|
||||
}),
|
||||
on_error: Box::new(|kind, text| {
|
||||
trace!("[manager client] on_error({kind}, {text})");
|
||||
eprintln!("{}: {}", kind, text);
|
||||
}),
|
||||
}
|
||||
}
|
||||
}
|
@ -0,0 +1,14 @@
|
||||
mod tcp;
|
||||
pub use tcp::*;
|
||||
|
||||
#[cfg(unix)]
|
||||
mod unix;
|
||||
|
||||
#[cfg(unix)]
|
||||
pub use unix::*;
|
||||
|
||||
#[cfg(windows)]
|
||||
mod windows;
|
||||
|
||||
#[cfg(windows)]
|
||||
pub use windows::*;
|
@ -0,0 +1,50 @@
|
||||
use crate::{DistantManagerClient, DistantManagerClientConfig};
|
||||
use async_trait::async_trait;
|
||||
use distant_net::{Codec, FramedTransport, TcpTransport};
|
||||
use std::{convert, net::SocketAddr};
|
||||
use tokio::{io, time::Duration};
|
||||
|
||||
#[async_trait]
|
||||
pub trait TcpDistantManagerClientExt {
|
||||
/// Connect to a remote TCP server using the provided information
|
||||
async fn connect<C>(
|
||||
config: DistantManagerClientConfig,
|
||||
addr: SocketAddr,
|
||||
codec: C,
|
||||
) -> io::Result<DistantManagerClient>
|
||||
where
|
||||
C: Codec + Send + 'static;
|
||||
|
||||
/// Connect to a remote TCP server, timing out after duration has passed
|
||||
async fn connect_timeout<C>(
|
||||
config: DistantManagerClientConfig,
|
||||
addr: SocketAddr,
|
||||
codec: C,
|
||||
duration: Duration,
|
||||
) -> io::Result<DistantManagerClient>
|
||||
where
|
||||
C: Codec + Send + 'static,
|
||||
{
|
||||
tokio::time::timeout(duration, Self::connect(config, addr, codec))
|
||||
.await
|
||||
.map_err(|x| io::Error::new(io::ErrorKind::TimedOut, x))
|
||||
.and_then(convert::identity)
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl TcpDistantManagerClientExt for DistantManagerClient {
|
||||
/// Connect to a remote TCP server using the provided information
|
||||
async fn connect<C>(
|
||||
config: DistantManagerClientConfig,
|
||||
addr: SocketAddr,
|
||||
codec: C,
|
||||
) -> io::Result<DistantManagerClient>
|
||||
where
|
||||
C: Codec + Send + 'static,
|
||||
{
|
||||
let transport = TcpTransport::connect(addr).await?;
|
||||
let transport = FramedTransport::new(transport, codec);
|
||||
Self::new(config, transport)
|
||||
}
|
||||
}
|
@ -0,0 +1,54 @@
|
||||
use crate::{DistantManagerClient, DistantManagerClientConfig};
|
||||
use async_trait::async_trait;
|
||||
use distant_net::{Codec, FramedTransport, UnixSocketTransport};
|
||||
use std::{convert, path::Path};
|
||||
use tokio::{io, time::Duration};
|
||||
|
||||
#[async_trait]
|
||||
pub trait UnixSocketDistantManagerClientExt {
|
||||
/// Connect to a proxy unix socket
|
||||
async fn connect<P, C>(
|
||||
config: DistantManagerClientConfig,
|
||||
path: P,
|
||||
codec: C,
|
||||
) -> io::Result<DistantManagerClient>
|
||||
where
|
||||
P: AsRef<Path> + Send,
|
||||
C: Codec + Send + 'static;
|
||||
|
||||
/// Connect to a proxy unix socket, timing out after duration has passed
|
||||
async fn connect_timeout<P, C>(
|
||||
config: DistantManagerClientConfig,
|
||||
path: P,
|
||||
codec: C,
|
||||
duration: Duration,
|
||||
) -> io::Result<DistantManagerClient>
|
||||
where
|
||||
P: AsRef<Path> + Send,
|
||||
C: Codec + Send + 'static,
|
||||
{
|
||||
tokio::time::timeout(duration, Self::connect(config, path, codec))
|
||||
.await
|
||||
.map_err(|x| io::Error::new(io::ErrorKind::TimedOut, x))
|
||||
.and_then(convert::identity)
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl UnixSocketDistantManagerClientExt for DistantManagerClient {
|
||||
/// Connect to a proxy unix socket
|
||||
async fn connect<P, C>(
|
||||
config: DistantManagerClientConfig,
|
||||
path: P,
|
||||
codec: C,
|
||||
) -> io::Result<DistantManagerClient>
|
||||
where
|
||||
P: AsRef<Path> + Send,
|
||||
C: Codec + Send + 'static,
|
||||
{
|
||||
let p = path.as_ref();
|
||||
let transport = UnixSocketTransport::connect(p).await?;
|
||||
let transport = FramedTransport::new(transport, codec);
|
||||
Ok(DistantManagerClient::new(config, transport)?)
|
||||
}
|
||||
}
|
@ -0,0 +1,91 @@
|
||||
use crate::{DistantManagerClient, DistantManagerClientConfig};
|
||||
use async_trait::async_trait;
|
||||
use distant_net::{Codec, FramedTransport, WindowsPipeTransport};
|
||||
use std::{
|
||||
convert,
|
||||
ffi::{OsStr, OsString},
|
||||
};
|
||||
use tokio::{io, time::Duration};
|
||||
|
||||
#[async_trait]
|
||||
pub trait WindowsPipeDistantManagerClientExt {
|
||||
/// Connect to a server listening on a Windows pipe at the specified address
|
||||
/// using the given codec
|
||||
async fn connect<A, C>(
|
||||
config: DistantManagerClientConfig,
|
||||
addr: A,
|
||||
codec: C,
|
||||
) -> io::Result<DistantManagerClient>
|
||||
where
|
||||
A: AsRef<OsStr> + Send,
|
||||
C: Codec + Send + 'static;
|
||||
|
||||
/// Connect to a server listening on a Windows pipe at the specified address
|
||||
/// via `\\.\pipe\{name}` using the given codec
|
||||
async fn connect_local<N, C>(
|
||||
config: DistantManagerClientConfig,
|
||||
name: N,
|
||||
codec: C,
|
||||
) -> io::Result<DistantManagerClient>
|
||||
where
|
||||
N: AsRef<OsStr> + Send,
|
||||
C: Codec + Send + 'static,
|
||||
{
|
||||
let mut addr = OsString::from(r"\\.\pipe\");
|
||||
addr.push(name.as_ref());
|
||||
Self::connect(config, addr, codec).await
|
||||
}
|
||||
|
||||
/// Connect to a server listening on a Windows pipe at the specified address
|
||||
/// using the given codec, timing out after duration has passed
|
||||
async fn connect_timeout<A, C>(
|
||||
config: DistantManagerClientConfig,
|
||||
addr: A,
|
||||
codec: C,
|
||||
duration: Duration,
|
||||
) -> io::Result<DistantManagerClient>
|
||||
where
|
||||
A: AsRef<OsStr> + Send,
|
||||
C: Codec + Send + 'static,
|
||||
{
|
||||
tokio::time::timeout(duration, Self::connect(config, addr, codec))
|
||||
.await
|
||||
.map_err(|x| io::Error::new(io::ErrorKind::TimedOut, x))
|
||||
.and_then(convert::identity)
|
||||
}
|
||||
|
||||
/// Connect to a server listening on a Windows pipe at the specified address
|
||||
/// via `\\.\pipe\{name}` using the given codec, timing out after duration has passed
|
||||
async fn connect_local_timeout<N, C>(
|
||||
config: DistantManagerClientConfig,
|
||||
name: N,
|
||||
codec: C,
|
||||
duration: Duration,
|
||||
) -> io::Result<DistantManagerClient>
|
||||
where
|
||||
N: AsRef<OsStr> + Send,
|
||||
C: Codec + Send + 'static,
|
||||
{
|
||||
let mut addr = OsString::from(r"\\.\pipe\");
|
||||
addr.push(name.as_ref());
|
||||
Self::connect_timeout(config, addr, codec, duration).await
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl WindowsPipeDistantManagerClientExt for DistantManagerClient {
|
||||
async fn connect<A, C>(
|
||||
config: DistantManagerClientConfig,
|
||||
addr: A,
|
||||
codec: C,
|
||||
) -> io::Result<DistantManagerClient>
|
||||
where
|
||||
A: AsRef<OsStr> + Send,
|
||||
C: Codec + Send + 'static,
|
||||
{
|
||||
let a = addr.as_ref();
|
||||
let transport = WindowsPipeTransport::connect(a).await?;
|
||||
let transport = FramedTransport::new(transport, codec);
|
||||
Ok(DistantManagerClient::new(config, transport)?)
|
||||
}
|
||||
}
|
@ -0,0 +1,20 @@
|
||||
mod destination;
|
||||
pub use destination::*;
|
||||
|
||||
mod extra;
|
||||
pub use extra::*;
|
||||
|
||||
mod id;
|
||||
pub use id::*;
|
||||
|
||||
mod info;
|
||||
pub use info::*;
|
||||
|
||||
mod list;
|
||||
pub use list::*;
|
||||
|
||||
mod request;
|
||||
pub use request::*;
|
||||
|
||||
mod response;
|
||||
pub use response::*;
|
@ -0,0 +1,266 @@
|
||||
use crate::serde_str::{deserialize_from_str, serialize_to_str};
|
||||
use derive_more::{Display, Error, From};
|
||||
use serde::{de::Deserializer, ser::Serializer, Deserialize, Serialize};
|
||||
use std::{convert::TryFrom, fmt, hash::Hash, str::FromStr};
|
||||
use uriparse::{
|
||||
Authority, AuthorityError, Host, Password, Scheme, URIReference, URIReferenceError, Username,
|
||||
URI,
|
||||
};
|
||||
|
||||
/// Represents an error that occurs when trying to parse a destination from a str
|
||||
#[derive(Copy, Clone, Debug, Display, Error, From, PartialEq, Eq)]
|
||||
pub enum DestinationError {
|
||||
MissingHost,
|
||||
URIReferenceError(URIReferenceError),
|
||||
}
|
||||
|
||||
/// `distant` connects and logs into the specified destination, which may be specified as either
|
||||
/// `hostname:port` where an attempt to connect to a **distant** server will be made, or a URI of
|
||||
/// one of the following forms:
|
||||
///
|
||||
/// * `distant://hostname:port` - connect to a distant server
|
||||
/// * `ssh://[user@]hostname[:port]` - connect to an SSH server
|
||||
#[derive(Clone, Debug, Hash, PartialEq, Eq)]
|
||||
pub struct Destination(URIReference<'static>);
|
||||
|
||||
impl Destination {
|
||||
/// Returns a reference to the scheme associated with the destination, if it has one
|
||||
pub fn scheme(&self) -> Option<&str> {
|
||||
self.0.scheme().map(Scheme::as_str)
|
||||
}
|
||||
|
||||
/// Returns the host of the destination as a string
|
||||
pub fn to_host_string(&self) -> String {
|
||||
// NOTE: We guarantee that there is a host for a destination during construction
|
||||
self.0.host().unwrap().to_string()
|
||||
}
|
||||
|
||||
/// Returns the port tied to the destination, if it has one
|
||||
pub fn port(&self) -> Option<u16> {
|
||||
self.0.port()
|
||||
}
|
||||
|
||||
/// Returns the username tied with the destination if it has one
|
||||
pub fn username(&self) -> Option<&str> {
|
||||
self.0.username().map(Username::as_str)
|
||||
}
|
||||
|
||||
/// Returns the password tied with the destination if it has one
|
||||
pub fn password(&self) -> Option<&str> {
|
||||
self.0.password().map(Password::as_str)
|
||||
}
|
||||
|
||||
/// Replaces the host of the destination
|
||||
pub fn replace_host(&mut self, host: &str) -> Result<(), URIReferenceError> {
|
||||
let username = self
|
||||
.0
|
||||
.username()
|
||||
.map(Username::as_borrowed)
|
||||
.map(Username::into_owned);
|
||||
let password = self
|
||||
.0
|
||||
.password()
|
||||
.map(Password::as_borrowed)
|
||||
.map(Password::into_owned);
|
||||
let port = self.0.port();
|
||||
let _ = self.0.set_authority(Some(
|
||||
Authority::from_parts(
|
||||
username,
|
||||
password,
|
||||
Host::try_from(host)
|
||||
.map(Host::into_owned)
|
||||
.map_err(AuthorityError::from)
|
||||
.map_err(URIReferenceError::from)?,
|
||||
port,
|
||||
)
|
||||
.map(Authority::into_owned)
|
||||
.map_err(URIReferenceError::from)?,
|
||||
))?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Indicates whether the host destination is globally routable
|
||||
pub fn is_host_global(&self) -> bool {
|
||||
match self.0.host() {
|
||||
Some(Host::IPv4Address(x)) => {
|
||||
!(x.is_broadcast()
|
||||
|| x.is_documentation()
|
||||
|| x.is_link_local()
|
||||
|| x.is_loopback()
|
||||
|| x.is_private()
|
||||
|| x.is_unspecified())
|
||||
}
|
||||
Some(Host::IPv6Address(x)) => {
|
||||
// NOTE: 14 is the global flag
|
||||
x.is_multicast() && (x.segments()[0] & 0x000f == 14)
|
||||
}
|
||||
Some(Host::RegisteredName(name)) => !name.trim().is_empty(),
|
||||
None => false,
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns true if destination represents a distant server
|
||||
pub fn is_distant(&self) -> bool {
|
||||
self.scheme_eq("distant")
|
||||
}
|
||||
|
||||
/// Returns true if destination represents an ssh server
|
||||
pub fn is_ssh(&self) -> bool {
|
||||
self.scheme_eq("ssh")
|
||||
}
|
||||
|
||||
fn scheme_eq(&self, s: &str) -> bool {
|
||||
match self.0.scheme() {
|
||||
Some(scheme) => scheme.as_str().eq_ignore_ascii_case(s),
|
||||
None => false,
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns reference to inner [`URIReference`]
|
||||
pub fn as_uri_ref(&self) -> &URIReference<'static> {
|
||||
&self.0
|
||||
}
|
||||
}
|
||||
|
||||
impl AsRef<Destination> for &Destination {
|
||||
fn as_ref(&self) -> &Destination {
|
||||
*self
|
||||
}
|
||||
}
|
||||
|
||||
impl AsRef<URIReference<'static>> for Destination {
|
||||
fn as_ref(&self) -> &URIReference<'static> {
|
||||
self.as_uri_ref()
|
||||
}
|
||||
}
|
||||
|
||||
impl fmt::Display for Destination {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
write!(f, "{}", self.0)
|
||||
}
|
||||
}
|
||||
|
||||
impl FromStr for Destination {
|
||||
type Err = DestinationError;
|
||||
|
||||
fn from_str(s: &str) -> Result<Self, Self::Err> {
|
||||
// Disallow empty (whitespace-only) input as that passes our
|
||||
// parsing for a URI reference (relative with no scheme or anything)
|
||||
if s.trim().is_empty() {
|
||||
return Err(DestinationError::MissingHost);
|
||||
}
|
||||
|
||||
let mut destination = URIReference::try_from(s)
|
||||
.map(URIReference::into_owned)
|
||||
.map(Destination)
|
||||
.map_err(DestinationError::URIReferenceError)?;
|
||||
|
||||
// Only support relative reference if it is a path reference as
|
||||
// we convert that to a relative reference with a host
|
||||
if destination.0.is_relative_reference() {
|
||||
let path = destination.0.path().to_string();
|
||||
destination.replace_host(path.as_str())?;
|
||||
let _ = destination.0.set_path("/")?;
|
||||
}
|
||||
|
||||
Ok(destination)
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a> TryFrom<URIReference<'a>> for Destination {
|
||||
type Error = DestinationError;
|
||||
|
||||
fn try_from(uri_ref: URIReference<'a>) -> Result<Self, Self::Error> {
|
||||
if uri_ref.host().is_none() {
|
||||
return Err(DestinationError::MissingHost);
|
||||
}
|
||||
|
||||
Ok(Self(uri_ref.into_owned()))
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a> TryFrom<URI<'a>> for Destination {
|
||||
type Error = DestinationError;
|
||||
|
||||
fn try_from(uri: URI<'a>) -> Result<Self, Self::Error> {
|
||||
let uri_ref: URIReference<'a> = uri.into();
|
||||
Self::try_from(uri_ref)
|
||||
}
|
||||
}
|
||||
|
||||
impl FromStr for Box<Destination> {
|
||||
type Err = DestinationError;
|
||||
|
||||
fn from_str(s: &str) -> Result<Self, Self::Err> {
|
||||
let destination = s.parse::<Destination>()?;
|
||||
Ok(Box::new(destination))
|
||||
}
|
||||
}
|
||||
|
||||
impl Serialize for Destination {
|
||||
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
|
||||
where
|
||||
S: Serializer,
|
||||
{
|
||||
serialize_to_str(self, serializer)
|
||||
}
|
||||
}
|
||||
|
||||
impl<'de> Deserialize<'de> for Destination {
|
||||
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
|
||||
where
|
||||
D: Deserializer<'de>,
|
||||
{
|
||||
deserialize_from_str(deserializer)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn parse_should_fail_if_string_is_only_whitespace() {
|
||||
let err = "".parse::<Destination>().unwrap_err();
|
||||
assert_eq!(err, DestinationError::MissingHost);
|
||||
|
||||
let err = " ".parse::<Destination>().unwrap_err();
|
||||
assert_eq!(err, DestinationError::MissingHost);
|
||||
|
||||
let err = "\t".parse::<Destination>().unwrap_err();
|
||||
assert_eq!(err, DestinationError::MissingHost);
|
||||
|
||||
let err = "\n".parse::<Destination>().unwrap_err();
|
||||
assert_eq!(err, DestinationError::MissingHost);
|
||||
|
||||
let err = "\r".parse::<Destination>().unwrap_err();
|
||||
assert_eq!(err, DestinationError::MissingHost);
|
||||
|
||||
let err = "\r\n".parse::<Destination>().unwrap_err();
|
||||
assert_eq!(err, DestinationError::MissingHost);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_should_succeed_with_valid_uri() {
|
||||
let destination = "distant://localhost".parse::<Destination>().unwrap();
|
||||
assert_eq!(destination.scheme().unwrap(), "distant");
|
||||
assert_eq!(destination.to_host_string(), "localhost");
|
||||
assert_eq!(destination.as_uri_ref().path().to_string(), "/");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_should_fail_if_relative_reference_that_is_not_valid_host() {
|
||||
let _ = "/".parse::<Destination>().unwrap_err();
|
||||
let _ = "/localhost".parse::<Destination>().unwrap_err();
|
||||
let _ = "my/path".parse::<Destination>().unwrap_err();
|
||||
let _ = "/my/path".parse::<Destination>().unwrap_err();
|
||||
let _ = "//localhost".parse::<Destination>().unwrap_err();
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_should_succeed_with_nonempty_relative_reference_by_setting_host_to_path() {
|
||||
let destination = "localhost".parse::<Destination>().unwrap();
|
||||
assert_eq!(destination.to_host_string(), "localhost");
|
||||
assert_eq!(destination.as_uri_ref().path().to_string(), "/");
|
||||
}
|
||||
}
|
@ -0,0 +1,2 @@
|
||||
/// Represents extra data included for connections
|
||||
pub type Extra = crate::data::Map;
|
@ -0,0 +1,5 @@
|
||||
/// Id associated with an active connection
|
||||
pub type ConnectionId = u64;
|
||||
|
||||
/// Id associated with an open channel
|
||||
pub type ChannelId = u64;
|
@ -0,0 +1,15 @@
|
||||
use super::{ConnectionId, Destination, Extra};
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
/// Information about a specific connection
|
||||
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
|
||||
pub struct ConnectionInfo {
|
||||
/// Connection's id
|
||||
pub id: ConnectionId,
|
||||
|
||||
/// Destination with which this connection is associated
|
||||
pub destination: Destination,
|
||||
|
||||
/// Extra information associated with this connection
|
||||
pub extra: Extra,
|
||||
}
|
@ -0,0 +1,58 @@
|
||||
use super::{ConnectionId, Destination};
|
||||
use derive_more::IntoIterator;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::{
|
||||
collections::HashMap,
|
||||
ops::{Deref, DerefMut, Index, IndexMut},
|
||||
};
|
||||
|
||||
/// Represents a list of information about active connections
|
||||
#[derive(Clone, Debug, PartialEq, Eq, IntoIterator, Serialize, Deserialize)]
|
||||
pub struct ConnectionList(pub(crate) HashMap<ConnectionId, Destination>);
|
||||
|
||||
impl ConnectionList {
|
||||
pub fn new() -> Self {
|
||||
Self(HashMap::new())
|
||||
}
|
||||
|
||||
/// Returns a reference to the destination associated with an active connection
|
||||
pub fn connection_destination(&self, id: ConnectionId) -> Option<&Destination> {
|
||||
self.0.get(&id)
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for ConnectionList {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
impl Deref for ConnectionList {
|
||||
type Target = HashMap<ConnectionId, Destination>;
|
||||
|
||||
fn deref(&self) -> &Self::Target {
|
||||
&self.0
|
||||
}
|
||||
}
|
||||
|
||||
impl DerefMut for ConnectionList {
|
||||
fn deref_mut(&mut self) -> &mut Self::Target {
|
||||
&mut self.0
|
||||
}
|
||||
}
|
||||
|
||||
impl Index<u64> for ConnectionList {
|
||||
type Output = Destination;
|
||||
|
||||
fn index(&self, connection_id: u64) -> &Self::Output {
|
||||
&self.0[&connection_id]
|
||||
}
|
||||
}
|
||||
|
||||
impl IndexMut<u64> for ConnectionList {
|
||||
fn index_mut(&mut self, connection_id: u64) -> &mut Self::Output {
|
||||
self.0
|
||||
.get_mut(&connection_id)
|
||||
.expect("No connection with id")
|
||||
}
|
||||
}
|
@ -0,0 +1,72 @@
|
||||
use super::{ChannelId, ConnectionId, Destination, Extra};
|
||||
use crate::{DistantMsg, DistantRequestData};
|
||||
use distant_net::Request;
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||
#[cfg_attr(feature = "clap", derive(clap::Subcommand))]
|
||||
#[serde(rename_all = "snake_case", deny_unknown_fields, tag = "type")]
|
||||
pub enum ManagerRequest {
|
||||
/// Launch a server using the manager
|
||||
Launch {
|
||||
// NOTE: Boxed per clippy's large_enum_variant warning
|
||||
destination: Box<Destination>,
|
||||
|
||||
/// Extra details specific to the connection
|
||||
#[cfg_attr(feature = "clap", clap(short, long, action = clap::ArgAction::Append))]
|
||||
extra: Extra,
|
||||
},
|
||||
|
||||
/// Initiate a connection through the manager
|
||||
Connect {
|
||||
// NOTE: Boxed per clippy's large_enum_variant warning
|
||||
destination: Box<Destination>,
|
||||
|
||||
/// Extra details specific to the connection
|
||||
#[cfg_attr(feature = "clap", clap(short, long, action = clap::ArgAction::Append))]
|
||||
extra: Extra,
|
||||
},
|
||||
|
||||
/// Opens a channel for communication with a server
|
||||
#[cfg_attr(feature = "clap", clap(skip))]
|
||||
OpenChannel {
|
||||
/// Id of the connection
|
||||
id: ConnectionId,
|
||||
},
|
||||
|
||||
/// Sends data through channel
|
||||
#[cfg_attr(feature = "clap", clap(skip))]
|
||||
Channel {
|
||||
/// Id of the channel
|
||||
id: ChannelId,
|
||||
|
||||
/// Request to send to through the channel
|
||||
#[cfg_attr(feature = "clap", clap(skip = skipped_request()))]
|
||||
request: Request<DistantMsg<DistantRequestData>>,
|
||||
},
|
||||
|
||||
/// Closes an open channel
|
||||
#[cfg_attr(feature = "clap", clap(skip))]
|
||||
CloseChannel {
|
||||
/// Id of the channel to close
|
||||
id: ChannelId,
|
||||
},
|
||||
|
||||
/// Retrieve information about a specific connection
|
||||
Info { id: ConnectionId },
|
||||
|
||||
/// Kill a specific connection
|
||||
Kill { id: ConnectionId },
|
||||
|
||||
/// Retrieve list of connections being managed
|
||||
List,
|
||||
|
||||
/// Signals the manager to shutdown
|
||||
Shutdown,
|
||||
}
|
||||
|
||||
/// Produces some default request, purely to satisfy clap
|
||||
#[cfg(feature = "clap")]
|
||||
fn skipped_request() -> Request<DistantMsg<DistantRequestData>> {
|
||||
Request::new(DistantMsg::Single(DistantRequestData::SystemInfo {}))
|
||||
}
|
@ -0,0 +1,53 @@
|
||||
use crate::{data::Error, ConnectionInfo, ConnectionList, Destination};
|
||||
use crate::{ChannelId, ConnectionId, DistantMsg, DistantResponseData};
|
||||
use distant_net::Response;
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "snake_case", deny_unknown_fields, tag = "type")]
|
||||
pub enum ManagerResponse {
|
||||
/// Acknowledgement that a connection was killed
|
||||
Killed,
|
||||
|
||||
/// Broadcast that the manager is shutting down (not guaranteed to be sent)
|
||||
Shutdown,
|
||||
|
||||
/// Indicates that some error occurred during a request
|
||||
Error(Error),
|
||||
|
||||
/// Confirmation of a distant server being launched
|
||||
Launched {
|
||||
/// Updated location of the spawned server
|
||||
destination: Destination,
|
||||
},
|
||||
|
||||
/// Confirmation of a connection being established
|
||||
Connected { id: ConnectionId },
|
||||
|
||||
/// Information about a specific connection
|
||||
Info(ConnectionInfo),
|
||||
|
||||
/// List of connections in the form of id -> destination
|
||||
List(ConnectionList),
|
||||
|
||||
/// Forward a response back to a specific channel that made a request
|
||||
Channel {
|
||||
/// Id of the channel
|
||||
id: ChannelId,
|
||||
|
||||
/// Response to an earlier channel request
|
||||
response: Response<DistantMsg<DistantResponseData>>,
|
||||
},
|
||||
|
||||
/// Indicates that a channel has been opened
|
||||
ChannelOpened {
|
||||
/// Id of the channel
|
||||
id: ChannelId,
|
||||
},
|
||||
|
||||
/// Indicates that a channel has been closed
|
||||
ChannelClosed {
|
||||
/// Id of the channel
|
||||
id: ChannelId,
|
||||
},
|
||||
}
|
@ -0,0 +1,698 @@
|
||||
use crate::{
|
||||
ChannelId, ConnectionId, ConnectionInfo, ConnectionList, Destination, Extra, ManagerRequest,
|
||||
ManagerResponse,
|
||||
};
|
||||
use async_trait::async_trait;
|
||||
use distant_net::{
|
||||
router, Auth, AuthClient, Client, IntoSplit, Listener, MpscListener, Request, Response, Server,
|
||||
ServerCtx, ServerExt, UntypedTransportRead, UntypedTransportWrite,
|
||||
};
|
||||
use log::*;
|
||||
use std::{collections::HashMap, io, sync::Arc};
|
||||
use tokio::{
|
||||
sync::{mpsc, Mutex, RwLock},
|
||||
task::JoinHandle,
|
||||
};
|
||||
|
||||
mod config;
|
||||
pub use config::*;
|
||||
|
||||
mod connection;
|
||||
pub use connection::*;
|
||||
|
||||
mod ext;
|
||||
pub use ext::*;
|
||||
|
||||
mod handler;
|
||||
pub use handler::*;
|
||||
|
||||
mod r#ref;
|
||||
pub use r#ref::*;
|
||||
|
||||
router!(DistantManagerRouter {
|
||||
auth_transport: Response<Auth> => Request<Auth>,
|
||||
manager_transport: Request<ManagerRequest> => Response<ManagerResponse>,
|
||||
});
|
||||
|
||||
/// Represents a manager of multiple distant server connections
|
||||
pub struct DistantManager {
|
||||
/// Receives authentication clients to feed into local data of server
|
||||
auth_client_rx: Mutex<mpsc::Receiver<AuthClient>>,
|
||||
|
||||
/// Configuration settings for the server
|
||||
config: DistantManagerConfig,
|
||||
|
||||
/// Mapping of connection id -> connection
|
||||
connections: RwLock<HashMap<ConnectionId, DistantManagerConnection>>,
|
||||
|
||||
/// Handlers for launch requests
|
||||
launch_handlers: Arc<RwLock<HashMap<String, BoxedLaunchHandler>>>,
|
||||
|
||||
/// Handlers for connect requests
|
||||
connect_handlers: Arc<RwLock<HashMap<String, BoxedConnectHandler>>>,
|
||||
|
||||
/// Primary task of server
|
||||
task: JoinHandle<()>,
|
||||
}
|
||||
|
||||
impl DistantManager {
|
||||
/// Initializes a new instance of [`DistantManagerServer`] using the provided [`UntypedTransport`]
|
||||
pub fn start<L, T>(
|
||||
mut config: DistantManagerConfig,
|
||||
mut listener: L,
|
||||
) -> io::Result<DistantManagerRef>
|
||||
where
|
||||
L: Listener<Output = T> + 'static,
|
||||
T: IntoSplit + Send + 'static,
|
||||
T::Read: UntypedTransportRead + 'static,
|
||||
T::Write: UntypedTransportWrite + 'static,
|
||||
{
|
||||
let (conn_tx, mpsc_listener) = MpscListener::channel(config.connection_buffer_size);
|
||||
let (auth_client_tx, auth_client_rx) = mpsc::channel(1);
|
||||
|
||||
// Spawn task that uses our input listener to get both auth and manager events,
|
||||
// forwarding manager events to the internal mpsc listener
|
||||
let task = tokio::spawn(async move {
|
||||
while let Ok(transport) = listener.accept().await {
|
||||
let DistantManagerRouter {
|
||||
auth_transport,
|
||||
manager_transport,
|
||||
..
|
||||
} = DistantManagerRouter::new(transport);
|
||||
|
||||
let (writer, reader) = auth_transport.into_split();
|
||||
let client = match Client::new(writer, reader) {
|
||||
Ok(client) => client,
|
||||
Err(x) => {
|
||||
error!("Creating auth client failed: {}", x);
|
||||
continue;
|
||||
}
|
||||
};
|
||||
let auth_client = AuthClient::from(client);
|
||||
|
||||
// Forward auth client for new connection in server
|
||||
if auth_client_tx.send(auth_client).await.is_err() {
|
||||
break;
|
||||
}
|
||||
|
||||
// Forward connected and routed transport to server
|
||||
if conn_tx.send(manager_transport.into_split()).await.is_err() {
|
||||
break;
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
let launch_handlers = Arc::new(RwLock::new(config.launch_handlers.drain().collect()));
|
||||
let weak_launch_handlers = Arc::downgrade(&launch_handlers);
|
||||
let connect_handlers = Arc::new(RwLock::new(config.connect_handlers.drain().collect()));
|
||||
let weak_connect_handlers = Arc::downgrade(&connect_handlers);
|
||||
let server_ref = Self {
|
||||
auth_client_rx: Mutex::new(auth_client_rx),
|
||||
config,
|
||||
launch_handlers,
|
||||
connect_handlers,
|
||||
connections: RwLock::new(HashMap::new()),
|
||||
task,
|
||||
}
|
||||
.start(mpsc_listener)?;
|
||||
|
||||
Ok(DistantManagerRef {
|
||||
launch_handlers: weak_launch_handlers,
|
||||
connect_handlers: weak_connect_handlers,
|
||||
inner: server_ref,
|
||||
})
|
||||
}
|
||||
|
||||
/// Launches a new server at the specified `destination` using the given `extra` information
|
||||
/// and authentication client (if needed) to retrieve additional information needed to
|
||||
/// enter the destination prior to starting the server, returning the destination of the
|
||||
/// launched server
|
||||
async fn launch(
|
||||
&self,
|
||||
destination: Destination,
|
||||
extra: Extra,
|
||||
auth: Option<&mut AuthClient>,
|
||||
) -> io::Result<Destination> {
|
||||
let auth = auth.ok_or_else(|| {
|
||||
io::Error::new(
|
||||
io::ErrorKind::Other,
|
||||
"Authentication client not initialized",
|
||||
)
|
||||
})?;
|
||||
|
||||
let scheme = match destination.scheme() {
|
||||
Some(scheme) => {
|
||||
trace!("Using scheme {}", scheme);
|
||||
scheme
|
||||
}
|
||||
None => {
|
||||
trace!(
|
||||
"Using fallback scheme of {}",
|
||||
self.config.fallback_scheme.as_str()
|
||||
);
|
||||
self.config.fallback_scheme.as_str()
|
||||
}
|
||||
}
|
||||
.to_lowercase();
|
||||
|
||||
let credentials = {
|
||||
let lock = self.launch_handlers.read().await;
|
||||
let handler = lock.get(&scheme).ok_or_else(|| {
|
||||
io::Error::new(
|
||||
io::ErrorKind::InvalidInput,
|
||||
format!("No launch handler registered for {}", scheme),
|
||||
)
|
||||
})?;
|
||||
handler.launch(&destination, &extra, auth).await?
|
||||
};
|
||||
|
||||
Ok(credentials)
|
||||
}
|
||||
|
||||
/// Connects to a new server at the specified `destination` using the given `extra` information
|
||||
/// and authentication client (if needed) to retrieve additional information needed to
|
||||
/// establish the connection to the server
|
||||
async fn connect(
|
||||
&self,
|
||||
destination: Destination,
|
||||
extra: Extra,
|
||||
auth: Option<&mut AuthClient>,
|
||||
) -> io::Result<ConnectionId> {
|
||||
let auth = auth.ok_or_else(|| {
|
||||
io::Error::new(
|
||||
io::ErrorKind::Other,
|
||||
"Authentication client not initialized",
|
||||
)
|
||||
})?;
|
||||
|
||||
let scheme = match destination.scheme() {
|
||||
Some(scheme) => {
|
||||
trace!("Using scheme {}", scheme);
|
||||
scheme
|
||||
}
|
||||
None => {
|
||||
trace!(
|
||||
"Using fallback scheme of {}",
|
||||
self.config.fallback_scheme.as_str()
|
||||
);
|
||||
self.config.fallback_scheme.as_str()
|
||||
}
|
||||
}
|
||||
.to_lowercase();
|
||||
|
||||
let (writer, reader) = {
|
||||
let lock = self.connect_handlers.read().await;
|
||||
let handler = lock.get(&scheme).ok_or_else(|| {
|
||||
io::Error::new(
|
||||
io::ErrorKind::InvalidInput,
|
||||
format!("No connect handler registered for {}", scheme),
|
||||
)
|
||||
})?;
|
||||
handler.connect(&destination, &extra, auth).await?
|
||||
};
|
||||
|
||||
let connection = DistantManagerConnection::new(destination, extra, writer, reader);
|
||||
let id = connection.id;
|
||||
self.connections.write().await.insert(id, connection);
|
||||
Ok(id)
|
||||
}
|
||||
|
||||
/// Retrieves information about the connection to the server with the specified `id`
|
||||
async fn info(&self, id: ConnectionId) -> io::Result<ConnectionInfo> {
|
||||
match self.connections.read().await.get(&id) {
|
||||
Some(connection) => Ok(ConnectionInfo {
|
||||
id: connection.id,
|
||||
destination: connection.destination.clone(),
|
||||
extra: connection.extra.clone(),
|
||||
}),
|
||||
None => Err(io::Error::new(
|
||||
io::ErrorKind::NotConnected,
|
||||
"No connection found",
|
||||
)),
|
||||
}
|
||||
}
|
||||
|
||||
/// Retrieves a list of connections to servers
|
||||
async fn list(&self) -> io::Result<ConnectionList> {
|
||||
Ok(ConnectionList(
|
||||
self.connections
|
||||
.read()
|
||||
.await
|
||||
.values()
|
||||
.map(|conn| (conn.id, conn.destination.clone()))
|
||||
.collect(),
|
||||
))
|
||||
}
|
||||
|
||||
/// Kills the connection to the server with the specified `id`
|
||||
async fn kill(&self, id: ConnectionId) -> io::Result<()> {
|
||||
match self.connections.write().await.remove(&id) {
|
||||
Some(_) => Ok(()),
|
||||
None => Err(io::Error::new(
|
||||
io::ErrorKind::NotConnected,
|
||||
"No connection found",
|
||||
)),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Default)]
|
||||
pub struct DistantManagerServerConnection {
|
||||
/// Authentication client that manager can use when establishing a new connection
|
||||
/// and needing to get authentication details from the client to move forward
|
||||
auth_client: Option<Mutex<AuthClient>>,
|
||||
|
||||
/// Holds on to open channels feeding data back from a server to some connected client,
|
||||
/// enabling us to cancel the tasks on demand
|
||||
channels: RwLock<HashMap<ChannelId, DistantManagerChannel>>,
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl Server for DistantManager {
|
||||
type Request = ManagerRequest;
|
||||
type Response = ManagerResponse;
|
||||
type LocalData = DistantManagerServerConnection;
|
||||
|
||||
async fn on_accept(&self, local_data: &mut Self::LocalData) {
|
||||
local_data.auth_client = self
|
||||
.auth_client_rx
|
||||
.lock()
|
||||
.await
|
||||
.recv()
|
||||
.await
|
||||
.map(Mutex::new);
|
||||
|
||||
// Enable jit handshake
|
||||
if let Some(auth_client) = local_data.auth_client.as_ref() {
|
||||
auth_client.lock().await.set_jit_handshake(true);
|
||||
}
|
||||
}
|
||||
|
||||
async fn on_request(&self, ctx: ServerCtx<Self::Request, Self::Response, Self::LocalData>) {
|
||||
let ServerCtx {
|
||||
connection_id,
|
||||
request,
|
||||
reply,
|
||||
local_data,
|
||||
} = ctx;
|
||||
|
||||
let response = match request.payload {
|
||||
ManagerRequest::Launch { destination, extra } => {
|
||||
let mut auth = match local_data.auth_client.as_ref() {
|
||||
Some(client) => Some(client.lock().await),
|
||||
None => None,
|
||||
};
|
||||
|
||||
match self.launch(*destination, extra, auth.as_deref_mut()).await {
|
||||
Ok(destination) => ManagerResponse::Launched { destination },
|
||||
Err(x) => ManagerResponse::Error(x.into()),
|
||||
}
|
||||
}
|
||||
ManagerRequest::Connect { destination, extra } => {
|
||||
let mut auth = match local_data.auth_client.as_ref() {
|
||||
Some(client) => Some(client.lock().await),
|
||||
None => None,
|
||||
};
|
||||
|
||||
match self.connect(*destination, extra, auth.as_deref_mut()).await {
|
||||
Ok(id) => ManagerResponse::Connected { id },
|
||||
Err(x) => ManagerResponse::Error(x.into()),
|
||||
}
|
||||
}
|
||||
ManagerRequest::OpenChannel { id } => match self.connections.read().await.get(&id) {
|
||||
Some(connection) => match connection.open_channel(reply.clone()).await {
|
||||
Ok(channel) => {
|
||||
let id = channel.id();
|
||||
local_data.channels.write().await.insert(id, channel);
|
||||
ManagerResponse::ChannelOpened { id }
|
||||
}
|
||||
Err(x) => ManagerResponse::Error(x.into()),
|
||||
},
|
||||
None => ManagerResponse::Error(
|
||||
io::Error::new(io::ErrorKind::NotConnected, "Connection does not exist").into(),
|
||||
),
|
||||
},
|
||||
ManagerRequest::Channel { id, request } => {
|
||||
match local_data.channels.read().await.get(&id) {
|
||||
// TODO: For now, we are NOT sending back a response to acknowledge
|
||||
// a successful channel send. We could do this in order for
|
||||
// the client to listen for a complete send, but is it worth it?
|
||||
Some(channel) => match channel.send(request).await {
|
||||
Ok(_) => return,
|
||||
Err(x) => ManagerResponse::Error(x.into()),
|
||||
},
|
||||
None => ManagerResponse::Error(
|
||||
io::Error::new(
|
||||
io::ErrorKind::NotConnected,
|
||||
"Channel is not open or does not exist",
|
||||
)
|
||||
.into(),
|
||||
),
|
||||
}
|
||||
}
|
||||
ManagerRequest::CloseChannel { id } => {
|
||||
match local_data.channels.write().await.remove(&id) {
|
||||
Some(channel) => match channel.close().await {
|
||||
Ok(_) => ManagerResponse::ChannelClosed { id },
|
||||
Err(x) => ManagerResponse::Error(x.into()),
|
||||
},
|
||||
None => ManagerResponse::Error(
|
||||
io::Error::new(
|
||||
io::ErrorKind::NotConnected,
|
||||
"Channel is not open or does not exist",
|
||||
)
|
||||
.into(),
|
||||
),
|
||||
}
|
||||
}
|
||||
ManagerRequest::Info { id } => match self.info(id).await {
|
||||
Ok(info) => ManagerResponse::Info(info),
|
||||
Err(x) => ManagerResponse::Error(x.into()),
|
||||
},
|
||||
ManagerRequest::List => match self.list().await {
|
||||
Ok(list) => ManagerResponse::List(list),
|
||||
Err(x) => ManagerResponse::Error(x.into()),
|
||||
},
|
||||
ManagerRequest::Kill { id } => match self.kill(id).await {
|
||||
Ok(()) => ManagerResponse::Killed,
|
||||
Err(x) => ManagerResponse::Error(x.into()),
|
||||
},
|
||||
ManagerRequest::Shutdown => {
|
||||
if let Err(x) = reply.send(ManagerResponse::Shutdown).await {
|
||||
error!("[Conn {}] {}", connection_id, x);
|
||||
}
|
||||
|
||||
// Clear out handler state in order to trigger drops
|
||||
self.launch_handlers.write().await.clear();
|
||||
self.connect_handlers.write().await.clear();
|
||||
|
||||
// Shutdown the primary server task
|
||||
self.task.abort();
|
||||
|
||||
// TODO: Perform a graceful shutdown instead of this?
|
||||
// Review https://tokio.rs/tokio/topics/shutdown
|
||||
std::process::exit(0);
|
||||
}
|
||||
};
|
||||
|
||||
if let Err(x) = reply.send(response).await {
|
||||
error!("[Conn {}] {}", connection_id, x);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use distant_net::{
|
||||
AuthClient, FramedTransport, HeapAuthServer, InmemoryTransport, IntoSplit, MappedListener,
|
||||
OneshotListener, PlainCodec, ServerExt, ServerRef,
|
||||
};
|
||||
|
||||
/// Create a new server, bypassing the start loop
|
||||
fn setup() -> DistantManager {
|
||||
let (_, rx) = mpsc::channel(1);
|
||||
DistantManager {
|
||||
auth_client_rx: Mutex::new(rx),
|
||||
config: Default::default(),
|
||||
connections: RwLock::new(HashMap::new()),
|
||||
launch_handlers: Arc::new(RwLock::new(HashMap::new())),
|
||||
connect_handlers: Arc::new(RwLock::new(HashMap::new())),
|
||||
task: tokio::spawn(async move {}),
|
||||
}
|
||||
}
|
||||
|
||||
/// Creates a connected [`AuthClient`] with a launched auth server that blindly responds
|
||||
fn auth_client_server() -> (AuthClient, Box<dyn ServerRef>) {
|
||||
let (t1, t2) = FramedTransport::pair(1);
|
||||
let client = AuthClient::from(Client::from_framed_transport(t1).unwrap());
|
||||
|
||||
// Create a server that does nothing, but will support
|
||||
let server = HeapAuthServer {
|
||||
on_challenge: Box::new(|_, _| Vec::new()),
|
||||
on_verify: Box::new(|_, _| false),
|
||||
on_info: Box::new(|_| ()),
|
||||
on_error: Box::new(|_, _| ()),
|
||||
}
|
||||
.start(MappedListener::new(OneshotListener::from_value(t2), |t| {
|
||||
t.into_split()
|
||||
}))
|
||||
.unwrap();
|
||||
|
||||
(client, server)
|
||||
}
|
||||
|
||||
fn dummy_distant_writer_reader() -> (BoxedDistantWriter, BoxedDistantReader) {
|
||||
setup_distant_writer_reader().0
|
||||
}
|
||||
|
||||
/// Creates a writer & reader with a connected transport
|
||||
fn setup_distant_writer_reader() -> (
|
||||
(BoxedDistantWriter, BoxedDistantReader),
|
||||
FramedTransport<InmemoryTransport, PlainCodec>,
|
||||
) {
|
||||
let (t1, t2) = FramedTransport::pair(1);
|
||||
let (writer, reader) = t1.into_split();
|
||||
((Box::new(writer), Box::new(reader)), t2)
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn launch_should_fail_if_destination_scheme_is_unsupported() {
|
||||
let server = setup();
|
||||
|
||||
let destination = "scheme://host".parse::<Destination>().unwrap();
|
||||
let extra = "".parse::<Extra>().unwrap();
|
||||
let (mut auth, _auth_server) = auth_client_server();
|
||||
let err = server
|
||||
.launch(destination, extra, Some(&mut auth))
|
||||
.await
|
||||
.unwrap_err();
|
||||
assert_eq!(err.kind(), io::ErrorKind::InvalidInput, "{:?}", err);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn launch_should_fail_if_handler_tied_to_scheme_fails() {
|
||||
let server = setup();
|
||||
|
||||
let handler: Box<dyn LaunchHandler> = Box::new(|_: &_, _: &_, _: &mut _| async {
|
||||
Err(io::Error::new(io::ErrorKind::Other, "test failure"))
|
||||
});
|
||||
|
||||
server
|
||||
.launch_handlers
|
||||
.write()
|
||||
.await
|
||||
.insert("scheme".to_string(), handler);
|
||||
|
||||
let destination = "scheme://host".parse::<Destination>().unwrap();
|
||||
let extra = "".parse::<Extra>().unwrap();
|
||||
let (mut auth, _auth_server) = auth_client_server();
|
||||
let err = server
|
||||
.launch(destination, extra, Some(&mut auth))
|
||||
.await
|
||||
.unwrap_err();
|
||||
assert_eq!(err.kind(), io::ErrorKind::Other);
|
||||
assert_eq!(err.to_string(), "test failure");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn launch_should_return_new_destination_on_success() {
|
||||
let server = setup();
|
||||
|
||||
let handler: Box<dyn LaunchHandler> = {
|
||||
Box::new(|_: &_, _: &_, _: &mut _| async {
|
||||
Ok("scheme2://host2".parse::<Destination>().unwrap())
|
||||
})
|
||||
};
|
||||
|
||||
server
|
||||
.launch_handlers
|
||||
.write()
|
||||
.await
|
||||
.insert("scheme".to_string(), handler);
|
||||
|
||||
let destination = "scheme://host".parse::<Destination>().unwrap();
|
||||
let extra = "key=value".parse::<Extra>().unwrap();
|
||||
let (mut auth, _auth_server) = auth_client_server();
|
||||
let destination = server
|
||||
.launch(destination, extra, Some(&mut auth))
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(
|
||||
destination,
|
||||
"scheme2://host2".parse::<Destination>().unwrap()
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn connect_should_fail_if_destination_scheme_is_unsupported() {
|
||||
let server = setup();
|
||||
|
||||
let destination = "scheme://host".parse::<Destination>().unwrap();
|
||||
let extra = "".parse::<Extra>().unwrap();
|
||||
let (mut auth, _auth_server) = auth_client_server();
|
||||
let err = server
|
||||
.connect(destination, extra, Some(&mut auth))
|
||||
.await
|
||||
.unwrap_err();
|
||||
assert_eq!(err.kind(), io::ErrorKind::InvalidInput, "{:?}", err);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn connect_should_fail_if_handler_tied_to_scheme_fails() {
|
||||
let server = setup();
|
||||
|
||||
let handler: Box<dyn ConnectHandler> = Box::new(|_: &_, _: &_, _: &mut _| async {
|
||||
Err(io::Error::new(io::ErrorKind::Other, "test failure"))
|
||||
});
|
||||
|
||||
server
|
||||
.connect_handlers
|
||||
.write()
|
||||
.await
|
||||
.insert("scheme".to_string(), handler);
|
||||
|
||||
let destination = "scheme://host".parse::<Destination>().unwrap();
|
||||
let extra = "".parse::<Extra>().unwrap();
|
||||
let (mut auth, _auth_server) = auth_client_server();
|
||||
let err = server
|
||||
.connect(destination, extra, Some(&mut auth))
|
||||
.await
|
||||
.unwrap_err();
|
||||
assert_eq!(err.kind(), io::ErrorKind::Other);
|
||||
assert_eq!(err.to_string(), "test failure");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn connect_should_return_id_of_new_connection_on_success() {
|
||||
let server = setup();
|
||||
|
||||
let handler: Box<dyn ConnectHandler> =
|
||||
Box::new(|_: &_, _: &_, _: &mut _| async { Ok(dummy_distant_writer_reader()) });
|
||||
|
||||
server
|
||||
.connect_handlers
|
||||
.write()
|
||||
.await
|
||||
.insert("scheme".to_string(), handler);
|
||||
|
||||
let destination = "scheme://host".parse::<Destination>().unwrap();
|
||||
let extra = "key=value".parse::<Extra>().unwrap();
|
||||
let (mut auth, _auth_server) = auth_client_server();
|
||||
let id = server
|
||||
.connect(destination, extra, Some(&mut auth))
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let lock = server.connections.read().await;
|
||||
let connection = lock.get(&id).unwrap();
|
||||
assert_eq!(connection.id, id);
|
||||
assert_eq!(connection.destination, "scheme://host".parse().unwrap());
|
||||
assert_eq!(connection.extra, "key=value".parse().unwrap());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn info_should_fail_if_no_connection_found_for_specified_id() {
|
||||
let server = setup();
|
||||
|
||||
let err = server.info(999).await.unwrap_err();
|
||||
assert_eq!(err.kind(), io::ErrorKind::NotConnected, "{:?}", err);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn info_should_return_information_about_established_connection() {
|
||||
let server = setup();
|
||||
|
||||
let (writer, reader) = dummy_distant_writer_reader();
|
||||
let connection = DistantManagerConnection::new(
|
||||
"scheme://host".parse().unwrap(),
|
||||
"key=value".parse().unwrap(),
|
||||
writer,
|
||||
reader,
|
||||
);
|
||||
let id = connection.id;
|
||||
server.connections.write().await.insert(id, connection);
|
||||
|
||||
let info = server.info(id).await.unwrap();
|
||||
assert_eq!(
|
||||
info,
|
||||
ConnectionInfo {
|
||||
id,
|
||||
destination: "scheme://host".parse().unwrap(),
|
||||
extra: "key=value".parse().unwrap(),
|
||||
}
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn list_should_return_empty_connection_list_if_no_established_connections() {
|
||||
let server = setup();
|
||||
|
||||
let list = server.list().await.unwrap();
|
||||
assert_eq!(list, ConnectionList(HashMap::new()));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn list_should_return_a_list_of_established_connections() {
|
||||
let server = setup();
|
||||
|
||||
let (writer, reader) = dummy_distant_writer_reader();
|
||||
let connection = DistantManagerConnection::new(
|
||||
"scheme://host".parse().unwrap(),
|
||||
"key=value".parse().unwrap(),
|
||||
writer,
|
||||
reader,
|
||||
);
|
||||
let id_1 = connection.id;
|
||||
server.connections.write().await.insert(id_1, connection);
|
||||
|
||||
let (writer, reader) = dummy_distant_writer_reader();
|
||||
let connection = DistantManagerConnection::new(
|
||||
"other://host2".parse().unwrap(),
|
||||
"key=value".parse().unwrap(),
|
||||
writer,
|
||||
reader,
|
||||
);
|
||||
let id_2 = connection.id;
|
||||
server.connections.write().await.insert(id_2, connection);
|
||||
|
||||
let list = server.list().await.unwrap();
|
||||
assert_eq!(
|
||||
list.get(&id_1).unwrap(),
|
||||
&"scheme://host".parse::<Destination>().unwrap()
|
||||
);
|
||||
assert_eq!(
|
||||
list.get(&id_2).unwrap(),
|
||||
&"other://host2".parse::<Destination>().unwrap()
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn kill_should_fail_if_no_connection_found_for_specified_id() {
|
||||
let server = setup();
|
||||
|
||||
let err = server.kill(999).await.unwrap_err();
|
||||
assert_eq!(err.kind(), io::ErrorKind::NotConnected, "{:?}", err);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn kill_should_terminate_established_connection_and_remove_it_from_the_list() {
|
||||
let server = setup();
|
||||
|
||||
let (writer, reader) = dummy_distant_writer_reader();
|
||||
let connection = DistantManagerConnection::new(
|
||||
"scheme://host".parse().unwrap(),
|
||||
"key=value".parse().unwrap(),
|
||||
writer,
|
||||
reader,
|
||||
);
|
||||
let id = connection.id;
|
||||
server.connections.write().await.insert(id, connection);
|
||||
|
||||
server.kill(id).await.unwrap();
|
||||
|
||||
let lock = server.connections.read().await;
|
||||
assert!(!lock.contains_key(&id), "Connection still exists");
|
||||
}
|
||||
}
|
@ -0,0 +1,31 @@
|
||||
use crate::{BoxedConnectHandler, BoxedLaunchHandler};
|
||||
use std::collections::HashMap;
|
||||
|
||||
pub struct DistantManagerConfig {
|
||||
/// Scheme to use when none is provided in a destination
|
||||
pub fallback_scheme: String,
|
||||
|
||||
/// Buffer size for queue of incoming connections before blocking
|
||||
pub connection_buffer_size: usize,
|
||||
|
||||
/// If listening as local user
|
||||
pub user: bool,
|
||||
|
||||
/// Handlers to use for launch requests
|
||||
pub launch_handlers: HashMap<String, BoxedLaunchHandler>,
|
||||
|
||||
/// Handlers to use for connect requests
|
||||
pub connect_handlers: HashMap<String, BoxedConnectHandler>,
|
||||
}
|
||||
|
||||
impl Default for DistantManagerConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
fallback_scheme: "distant".to_string(),
|
||||
connection_buffer_size: 100,
|
||||
user: false,
|
||||
launch_handlers: HashMap::new(),
|
||||
connect_handlers: HashMap::new(),
|
||||
}
|
||||
}
|
||||
}
|
@ -0,0 +1,201 @@
|
||||
use crate::{
|
||||
manager::{
|
||||
data::{ChannelId, ConnectionId, Destination, Extra},
|
||||
BoxedDistantReader, BoxedDistantWriter,
|
||||
},
|
||||
DistantMsg, DistantRequestData, DistantResponseData, ManagerResponse,
|
||||
};
|
||||
use distant_net::{Request, Response, ServerReply};
|
||||
use log::*;
|
||||
use std::{collections::HashMap, io};
|
||||
use tokio::{sync::mpsc, task::JoinHandle};
|
||||
|
||||
/// Represents a connection a distant manager has with some distant-compatible server
|
||||
pub struct DistantManagerConnection {
|
||||
pub id: ConnectionId,
|
||||
pub destination: Destination,
|
||||
pub extra: Extra,
|
||||
tx: mpsc::Sender<StateMachine>,
|
||||
reader_task: JoinHandle<()>,
|
||||
writer_task: JoinHandle<()>,
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct DistantManagerChannel {
|
||||
channel_id: ChannelId,
|
||||
tx: mpsc::Sender<StateMachine>,
|
||||
}
|
||||
|
||||
impl DistantManagerChannel {
|
||||
pub fn id(&self) -> ChannelId {
|
||||
self.channel_id
|
||||
}
|
||||
|
||||
pub async fn send(&self, request: Request<DistantMsg<DistantRequestData>>) -> io::Result<()> {
|
||||
let channel_id = self.channel_id;
|
||||
self.tx
|
||||
.send(StateMachine::Write {
|
||||
id: channel_id,
|
||||
request,
|
||||
})
|
||||
.await
|
||||
.map_err(|x| {
|
||||
io::Error::new(
|
||||
io::ErrorKind::BrokenPipe,
|
||||
format!("channel {} send failed: {}", channel_id, x),
|
||||
)
|
||||
})
|
||||
}
|
||||
|
||||
pub async fn close(&self) -> io::Result<()> {
|
||||
let channel_id = self.channel_id;
|
||||
self.tx
|
||||
.send(StateMachine::Unregister { id: channel_id })
|
||||
.await
|
||||
.map_err(|x| {
|
||||
io::Error::new(
|
||||
io::ErrorKind::BrokenPipe,
|
||||
format!("channel {} close failed: {}", channel_id, x),
|
||||
)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
enum StateMachine {
|
||||
Register {
|
||||
id: ChannelId,
|
||||
reply: ServerReply<ManagerResponse>,
|
||||
},
|
||||
|
||||
Unregister {
|
||||
id: ChannelId,
|
||||
},
|
||||
|
||||
Read {
|
||||
response: Response<DistantMsg<DistantResponseData>>,
|
||||
},
|
||||
|
||||
Write {
|
||||
id: ChannelId,
|
||||
request: Request<DistantMsg<DistantRequestData>>,
|
||||
},
|
||||
}
|
||||
|
||||
impl DistantManagerConnection {
|
||||
pub fn new(
|
||||
destination: Destination,
|
||||
extra: Extra,
|
||||
mut writer: BoxedDistantWriter,
|
||||
mut reader: BoxedDistantReader,
|
||||
) -> Self {
|
||||
let connection_id = rand::random();
|
||||
let (tx, mut rx) = mpsc::channel(1);
|
||||
let reader_task = {
|
||||
let tx = tx.clone();
|
||||
tokio::spawn(async move {
|
||||
loop {
|
||||
match reader.read().await {
|
||||
Ok(Some(response)) => {
|
||||
if tx.send(StateMachine::Read { response }).await.is_err() {
|
||||
break;
|
||||
}
|
||||
}
|
||||
Ok(None) => break,
|
||||
Err(x) => {
|
||||
error!("[Conn {}] {}", connection_id, x);
|
||||
continue;
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
};
|
||||
let writer_task = tokio::spawn(async move {
|
||||
let mut registered = HashMap::new();
|
||||
while let Some(state_machine) = rx.recv().await {
|
||||
match state_machine {
|
||||
StateMachine::Register { id, reply } => {
|
||||
registered.insert(id, reply);
|
||||
}
|
||||
StateMachine::Unregister { id } => {
|
||||
registered.remove(&id);
|
||||
}
|
||||
StateMachine::Read { mut response } => {
|
||||
// Split {channel id}_{request id} back into pieces and
|
||||
// update the origin id to match the request id only
|
||||
let channel_id = match response.origin_id.split_once('_') {
|
||||
Some((cid_str, oid_str)) => {
|
||||
if let Ok(cid) = cid_str.parse::<ChannelId>() {
|
||||
response.origin_id = oid_str.to_string();
|
||||
cid
|
||||
} else {
|
||||
continue;
|
||||
}
|
||||
}
|
||||
None => continue,
|
||||
};
|
||||
|
||||
if let Some(reply) = registered.get(&channel_id) {
|
||||
let response = ManagerResponse::Channel {
|
||||
id: channel_id,
|
||||
response,
|
||||
};
|
||||
if let Err(x) = reply.send(response).await {
|
||||
error!("[Conn {}] {}", connection_id, x);
|
||||
}
|
||||
}
|
||||
}
|
||||
StateMachine::Write { id, request } => {
|
||||
// Combine channel id with request id so we can properly forward
|
||||
// the response containing this in the origin id
|
||||
let request = Request {
|
||||
id: format!("{}_{}", id, request.id),
|
||||
payload: request.payload,
|
||||
};
|
||||
if let Err(x) = writer.write(request).await {
|
||||
error!("[Conn {}] {}", connection_id, x);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
Self {
|
||||
id: connection_id,
|
||||
destination,
|
||||
extra,
|
||||
tx,
|
||||
reader_task,
|
||||
writer_task,
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn open_channel(
|
||||
&self,
|
||||
reply: ServerReply<ManagerResponse>,
|
||||
) -> io::Result<DistantManagerChannel> {
|
||||
let channel_id = rand::random();
|
||||
self.tx
|
||||
.send(StateMachine::Register {
|
||||
id: channel_id,
|
||||
reply,
|
||||
})
|
||||
.await
|
||||
.map_err(|x| {
|
||||
io::Error::new(
|
||||
io::ErrorKind::BrokenPipe,
|
||||
format!("open_channel failed: {}", x),
|
||||
)
|
||||
})?;
|
||||
Ok(DistantManagerChannel {
|
||||
channel_id,
|
||||
tx: self.tx.clone(),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl Drop for DistantManagerConnection {
|
||||
fn drop(&mut self) {
|
||||
self.reader_task.abort();
|
||||
self.writer_task.abort();
|
||||
}
|
||||
}
|
@ -0,0 +1,14 @@
|
||||
mod tcp;
|
||||
pub use tcp::*;
|
||||
|
||||
#[cfg(unix)]
|
||||
mod unix;
|
||||
|
||||
#[cfg(unix)]
|
||||
pub use unix::*;
|
||||
|
||||
#[cfg(windows)]
|
||||
mod windows;
|
||||
|
||||
#[cfg(windows)]
|
||||
pub use windows::*;
|
@ -0,0 +1,30 @@
|
||||
use crate::{DistantManager, DistantManagerConfig};
|
||||
use distant_net::{
|
||||
Codec, FramedTransport, IntoSplit, MappedListener, PortRange, TcpListener, TcpServerRef,
|
||||
};
|
||||
use std::{io, net::IpAddr};
|
||||
|
||||
impl DistantManager {
|
||||
/// Start a new server by binding to the given IP address and one of the ports in the
|
||||
/// specified range, mapping all connections to use the given codec
|
||||
pub async fn start_tcp<P, C>(
|
||||
config: DistantManagerConfig,
|
||||
addr: IpAddr,
|
||||
port: P,
|
||||
codec: C,
|
||||
) -> io::Result<TcpServerRef>
|
||||
where
|
||||
P: Into<PortRange> + Send,
|
||||
C: Codec + Send + Sync + 'static,
|
||||
{
|
||||
let listener = TcpListener::bind(addr, port).await?;
|
||||
let port = listener.port();
|
||||
|
||||
let listener = MappedListener::new(listener, move |transport| {
|
||||
let transport = FramedTransport::new(transport, codec.clone());
|
||||
transport.into_split()
|
||||
});
|
||||
let inner = DistantManager::start(config, listener)?;
|
||||
Ok(TcpServerRef::new(addr, port, Box::new(inner)))
|
||||
}
|
||||
}
|
@ -0,0 +1,50 @@
|
||||
use crate::{DistantManager, DistantManagerConfig};
|
||||
use distant_net::{
|
||||
Codec, FramedTransport, IntoSplit, MappedListener, UnixSocketListener, UnixSocketServerRef,
|
||||
};
|
||||
use std::{io, path::Path};
|
||||
|
||||
impl DistantManager {
|
||||
/// Start a new server using the specified path as a unix socket using default unix socket file
|
||||
/// permissions
|
||||
pub async fn start_unix_socket<P, C>(
|
||||
config: DistantManagerConfig,
|
||||
path: P,
|
||||
codec: C,
|
||||
) -> io::Result<UnixSocketServerRef>
|
||||
where
|
||||
P: AsRef<Path> + Send,
|
||||
C: Codec + Send + Sync + 'static,
|
||||
{
|
||||
Self::start_unix_socket_with_permissions(
|
||||
config,
|
||||
path,
|
||||
codec,
|
||||
UnixSocketListener::default_unix_socket_file_permissions(),
|
||||
)
|
||||
.await
|
||||
}
|
||||
|
||||
/// Start a new server using the specified path as a unix socket and `mode` as the unix socket
|
||||
/// file permissions
|
||||
pub async fn start_unix_socket_with_permissions<P, C>(
|
||||
config: DistantManagerConfig,
|
||||
path: P,
|
||||
codec: C,
|
||||
mode: u32,
|
||||
) -> io::Result<UnixSocketServerRef>
|
||||
where
|
||||
P: AsRef<Path> + Send,
|
||||
C: Codec + Send + Sync + 'static,
|
||||
{
|
||||
let listener = UnixSocketListener::bind_with_permissions(path, mode).await?;
|
||||
let path = listener.path().to_path_buf();
|
||||
|
||||
let listener = MappedListener::new(listener, move |transport| {
|
||||
let transport = FramedTransport::new(transport, codec.clone());
|
||||
transport.into_split()
|
||||
});
|
||||
let inner = DistantManager::start(config, listener)?;
|
||||
Ok(UnixSocketServerRef::new(path, Box::new(inner)))
|
||||
}
|
||||
}
|
@ -0,0 +1,48 @@
|
||||
use crate::{DistantManager, DistantManagerConfig};
|
||||
use distant_net::{
|
||||
Codec, FramedTransport, IntoSplit, MappedListener, WindowsPipeListener, WindowsPipeServerRef,
|
||||
};
|
||||
use std::{
|
||||
ffi::{OsStr, OsString},
|
||||
io,
|
||||
};
|
||||
|
||||
impl DistantManager {
|
||||
/// Start a new server at the specified address via `\\.\pipe\{name}` using the given codec
|
||||
pub async fn start_local_named_pipe<N, C>(
|
||||
config: DistantManagerConfig,
|
||||
name: N,
|
||||
codec: C,
|
||||
) -> io::Result<WindowsPipeServerRef>
|
||||
where
|
||||
Self: Sized,
|
||||
N: AsRef<OsStr> + Send,
|
||||
C: Codec + Send + Sync + 'static,
|
||||
{
|
||||
let mut addr = OsString::from(r"\\.\pipe\");
|
||||
addr.push(name.as_ref());
|
||||
Self::start_named_pipe(config, addr, codec).await
|
||||
}
|
||||
|
||||
/// Start a new server at the specified pipe address using the given codec
|
||||
pub async fn start_named_pipe<A, C>(
|
||||
config: DistantManagerConfig,
|
||||
addr: A,
|
||||
codec: C,
|
||||
) -> io::Result<WindowsPipeServerRef>
|
||||
where
|
||||
A: AsRef<OsStr> + Send,
|
||||
C: Codec + Send + Sync + 'static,
|
||||
{
|
||||
let a = addr.as_ref();
|
||||
let listener = WindowsPipeListener::bind(a)?;
|
||||
let addr = listener.addr().to_os_string();
|
||||
|
||||
let listener = MappedListener::new(listener, move |transport| {
|
||||
let transport = FramedTransport::new(transport, codec.clone());
|
||||
transport.into_split()
|
||||
});
|
||||
let inner = DistantManager::start(config, listener)?;
|
||||
Ok(WindowsPipeServerRef::new(addr, Box::new(inner)))
|
||||
}
|
||||
}
|
@ -0,0 +1,69 @@
|
||||
use crate::{
|
||||
manager::data::{Destination, Extra},
|
||||
DistantMsg, DistantRequestData, DistantResponseData,
|
||||
};
|
||||
use async_trait::async_trait;
|
||||
use distant_net::{AuthClient, Request, Response, TypedAsyncRead, TypedAsyncWrite};
|
||||
use std::{future::Future, io};
|
||||
|
||||
pub type BoxedDistantWriter =
|
||||
Box<dyn TypedAsyncWrite<Request<DistantMsg<DistantRequestData>>> + Send>;
|
||||
pub type BoxedDistantReader =
|
||||
Box<dyn TypedAsyncRead<Response<DistantMsg<DistantResponseData>>> + Send>;
|
||||
pub type BoxedDistantWriterReader = (BoxedDistantWriter, BoxedDistantReader);
|
||||
pub type BoxedLaunchHandler = Box<dyn LaunchHandler>;
|
||||
pub type BoxedConnectHandler = Box<dyn ConnectHandler>;
|
||||
|
||||
/// Used to launch a server at the specified destination, returning some result as a vec of bytes
|
||||
#[async_trait]
|
||||
pub trait LaunchHandler: Send + Sync {
|
||||
async fn launch(
|
||||
&self,
|
||||
destination: &Destination,
|
||||
extra: &Extra,
|
||||
auth_client: &mut AuthClient,
|
||||
) -> io::Result<Destination>;
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl<F, R> LaunchHandler for F
|
||||
where
|
||||
F: for<'a> Fn(&'a Destination, &'a Extra, &'a mut AuthClient) -> R + Send + Sync + 'static,
|
||||
R: Future<Output = io::Result<Destination>> + Send + 'static,
|
||||
{
|
||||
async fn launch(
|
||||
&self,
|
||||
destination: &Destination,
|
||||
extra: &Extra,
|
||||
auth_client: &mut AuthClient,
|
||||
) -> io::Result<Destination> {
|
||||
self(destination, extra, auth_client).await
|
||||
}
|
||||
}
|
||||
|
||||
/// Used to connect to a destination, returning a connected reader and writer pair
|
||||
#[async_trait]
|
||||
pub trait ConnectHandler: Send + Sync {
|
||||
async fn connect(
|
||||
&self,
|
||||
destination: &Destination,
|
||||
extra: &Extra,
|
||||
auth_client: &mut AuthClient,
|
||||
) -> io::Result<BoxedDistantWriterReader>;
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl<F, R> ConnectHandler for F
|
||||
where
|
||||
F: for<'a> Fn(&'a Destination, &'a Extra, &'a mut AuthClient) -> R + Send + Sync + 'static,
|
||||
R: Future<Output = io::Result<BoxedDistantWriterReader>> + Send + 'static,
|
||||
{
|
||||
async fn connect(
|
||||
&self,
|
||||
destination: &Destination,
|
||||
extra: &Extra,
|
||||
auth_client: &mut AuthClient,
|
||||
) -> io::Result<BoxedDistantWriterReader> {
|
||||
self(destination, extra, auth_client).await
|
||||
}
|
||||
}
|
@ -0,0 +1,73 @@
|
||||
use super::{BoxedConnectHandler, BoxedLaunchHandler, ConnectHandler, LaunchHandler};
|
||||
use distant_net::{ServerRef, ServerState};
|
||||
use std::{collections::HashMap, io, sync::Weak};
|
||||
use tokio::sync::RwLock;
|
||||
|
||||
/// Reference to a distant manager's server instance
|
||||
pub struct DistantManagerRef {
|
||||
/// Mapping of "scheme" -> handler
|
||||
pub(crate) launch_handlers: Weak<RwLock<HashMap<String, BoxedLaunchHandler>>>,
|
||||
|
||||
/// Mapping of "scheme" -> handler
|
||||
pub(crate) connect_handlers: Weak<RwLock<HashMap<String, BoxedConnectHandler>>>,
|
||||
|
||||
pub(crate) inner: Box<dyn ServerRef>,
|
||||
}
|
||||
|
||||
impl DistantManagerRef {
|
||||
/// Registers a new [`LaunchHandler`] for the specified scheme (e.g. "distant" or "ssh")
|
||||
pub async fn register_launch_handler(
|
||||
&self,
|
||||
scheme: impl Into<String>,
|
||||
handler: impl LaunchHandler + 'static,
|
||||
) -> io::Result<()> {
|
||||
let handlers = Weak::upgrade(&self.launch_handlers).ok_or_else(|| {
|
||||
io::Error::new(
|
||||
io::ErrorKind::Other,
|
||||
"Handler reference is no longer available",
|
||||
)
|
||||
})?;
|
||||
|
||||
handlers
|
||||
.write()
|
||||
.await
|
||||
.insert(scheme.into(), Box::new(handler));
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Registers a new [`ConnectHandler`] for the specified scheme (e.g. "distant" or "ssh")
|
||||
pub async fn register_connect_handler(
|
||||
&self,
|
||||
scheme: impl Into<String>,
|
||||
handler: impl ConnectHandler + 'static,
|
||||
) -> io::Result<()> {
|
||||
let handlers = Weak::upgrade(&self.connect_handlers).ok_or_else(|| {
|
||||
io::Error::new(
|
||||
io::ErrorKind::Other,
|
||||
"Handler reference is no longer available",
|
||||
)
|
||||
})?;
|
||||
|
||||
handlers
|
||||
.write()
|
||||
.await
|
||||
.insert(scheme.into(), Box::new(handler));
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
impl ServerRef for DistantManagerRef {
|
||||
fn state(&self) -> &ServerState {
|
||||
self.inner.state()
|
||||
}
|
||||
|
||||
fn is_finished(&self) -> bool {
|
||||
self.inner.is_finished()
|
||||
}
|
||||
|
||||
fn abort(&self) {
|
||||
self.inner.abort();
|
||||
}
|
||||
}
|
@ -1,162 +0,0 @@
|
||||
use super::{Codec, DataStream, Transport};
|
||||
use futures::stream::Stream;
|
||||
use log::*;
|
||||
use std::{future::Future, pin::Pin};
|
||||
use tokio::{
|
||||
io,
|
||||
net::{TcpListener, TcpStream},
|
||||
sync::mpsc,
|
||||
task::JoinHandle,
|
||||
};
|
||||
|
||||
/// Represents a [`Stream`] consisting of newly-connected [`DataStream`] instances that
|
||||
/// have been wrapped in [`Transport`]
|
||||
pub struct TransportListener<T, U>
|
||||
where
|
||||
T: DataStream,
|
||||
U: Codec,
|
||||
{
|
||||
listen_task: JoinHandle<()>,
|
||||
accept_task: JoinHandle<()>,
|
||||
rx: mpsc::Receiver<Transport<T, U>>,
|
||||
}
|
||||
|
||||
impl<T, U> TransportListener<T, U>
|
||||
where
|
||||
T: DataStream + Send + 'static,
|
||||
U: Codec + Send + 'static,
|
||||
{
|
||||
pub fn initialize<L, F>(listener: L, mut make_transport: F) -> Self
|
||||
where
|
||||
L: Listener<Output = T> + 'static,
|
||||
F: FnMut(T) -> Transport<T, U> + Send + 'static,
|
||||
{
|
||||
let (stream_tx, mut stream_rx) = mpsc::channel::<T>(1);
|
||||
let listen_task = tokio::spawn(async move {
|
||||
loop {
|
||||
match listener.accept().await {
|
||||
Ok(stream) => {
|
||||
if stream_tx.send(stream).await.is_err() {
|
||||
error!("Listener failed to pass along stream");
|
||||
break;
|
||||
}
|
||||
}
|
||||
Err(x) => {
|
||||
error!("Listener failed to accept stream: {}", x);
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
let (tx, rx) = mpsc::channel::<Transport<T, U>>(1);
|
||||
let accept_task = tokio::spawn(async move {
|
||||
// Check if we have a new connection. If so, wrap it in a transport and forward
|
||||
// it along to
|
||||
while let Some(stream) = stream_rx.recv().await {
|
||||
let transport = make_transport(stream);
|
||||
if let Err(x) = tx.send(transport).await {
|
||||
error!("Failed to forward transport: {}", x);
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
Self {
|
||||
listen_task,
|
||||
accept_task,
|
||||
rx,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn abort(&self) {
|
||||
self.listen_task.abort();
|
||||
self.accept_task.abort();
|
||||
}
|
||||
|
||||
/// Waits for the next fully-initialized transport for an incoming stream to be available,
|
||||
/// returning none if no longer accepting new connections
|
||||
pub async fn accept(&mut self) -> Option<Transport<T, U>> {
|
||||
self.rx.recv().await
|
||||
}
|
||||
|
||||
/// Converts into a stream of transport-wrapped connections
|
||||
pub fn into_stream(self) -> impl Stream<Item = Transport<T, U>> {
|
||||
futures::stream::unfold(self, |mut _self| async move {
|
||||
_self
|
||||
.accept()
|
||||
.await
|
||||
.map(move |transport| (transport, _self))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
pub type AcceptFuture<'a, T> = Pin<Box<dyn Future<Output = io::Result<T>> + Send + 'a>>;
|
||||
|
||||
/// Represents a type that has a listen interface for receiving raw streams
|
||||
pub trait Listener: Send + Sync {
|
||||
type Output;
|
||||
|
||||
fn accept<'a>(&'a self) -> AcceptFuture<'a, Self::Output>
|
||||
where
|
||||
Self: Sync + 'a;
|
||||
}
|
||||
|
||||
impl Listener for TcpListener {
|
||||
type Output = TcpStream;
|
||||
|
||||
fn accept<'a>(&'a self) -> AcceptFuture<'a, Self::Output>
|
||||
where
|
||||
Self: Sync + 'a,
|
||||
{
|
||||
async fn accept(_self: &TcpListener) -> io::Result<TcpStream> {
|
||||
_self.accept().await.map(|(stream, _)| stream)
|
||||
}
|
||||
|
||||
Box::pin(accept(self))
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(unix)]
|
||||
impl Listener for tokio::net::UnixListener {
|
||||
type Output = tokio::net::UnixStream;
|
||||
|
||||
fn accept<'a>(&'a self) -> AcceptFuture<'a, Self::Output>
|
||||
where
|
||||
Self: Sync + 'a,
|
||||
{
|
||||
async fn accept(_self: &tokio::net::UnixListener) -> io::Result<tokio::net::UnixStream> {
|
||||
_self.accept().await.map(|(stream, _)| stream)
|
||||
}
|
||||
|
||||
Box::pin(accept(self))
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
impl<T> Listener for tokio::sync::Mutex<tokio::sync::mpsc::Receiver<T>>
|
||||
where
|
||||
T: DataStream + Send + Sync + 'static,
|
||||
{
|
||||
type Output = T;
|
||||
|
||||
fn accept<'a>(&'a self) -> AcceptFuture<'a, Self::Output>
|
||||
where
|
||||
Self: Sync + 'a,
|
||||
{
|
||||
async fn accept<T>(
|
||||
_self: &tokio::sync::Mutex<tokio::sync::mpsc::Receiver<T>>,
|
||||
) -> io::Result<T>
|
||||
where
|
||||
T: DataStream + Send + Sync + 'static,
|
||||
{
|
||||
_self
|
||||
.lock()
|
||||
.await
|
||||
.recv()
|
||||
.await
|
||||
.ok_or_else(|| io::Error::from(io::ErrorKind::BrokenPipe))
|
||||
}
|
||||
|
||||
Box::pin(accept(self))
|
||||
}
|
||||
}
|
@ -1,583 +0,0 @@
|
||||
use super::{DataStream, PlainCodec, Transport};
|
||||
use futures::ready;
|
||||
use std::{
|
||||
fmt,
|
||||
future::Future,
|
||||
pin::Pin,
|
||||
task::{Context, Poll},
|
||||
};
|
||||
use tokio::{
|
||||
io::{self, AsyncRead, AsyncWrite, ReadBuf},
|
||||
sync::mpsc,
|
||||
};
|
||||
|
||||
/// Represents a data stream comprised of two inmemory channels
|
||||
#[derive(Debug)]
|
||||
pub struct InmemoryStream {
|
||||
incoming: InmemoryStreamReadHalf,
|
||||
outgoing: InmemoryStreamWriteHalf,
|
||||
}
|
||||
|
||||
impl InmemoryStream {
|
||||
pub fn new(incoming: mpsc::Receiver<Vec<u8>>, outgoing: mpsc::Sender<Vec<u8>>) -> Self {
|
||||
Self {
|
||||
incoming: InmemoryStreamReadHalf::new(incoming),
|
||||
outgoing: InmemoryStreamWriteHalf::new(outgoing),
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns (incoming_tx, outgoing_rx, stream)
|
||||
pub fn make(buffer: usize) -> (mpsc::Sender<Vec<u8>>, mpsc::Receiver<Vec<u8>>, Self) {
|
||||
let (incoming_tx, incoming_rx) = mpsc::channel(buffer);
|
||||
let (outgoing_tx, outgoing_rx) = mpsc::channel(buffer);
|
||||
|
||||
(
|
||||
incoming_tx,
|
||||
outgoing_rx,
|
||||
Self::new(incoming_rx, outgoing_tx),
|
||||
)
|
||||
}
|
||||
|
||||
/// Returns pair of streams that are connected such that one sends to the other and
|
||||
/// vice versa
|
||||
pub fn pair(buffer: usize) -> (Self, Self) {
|
||||
let (tx, rx, stream) = Self::make(buffer);
|
||||
(stream, Self::new(rx, tx))
|
||||
}
|
||||
}
|
||||
|
||||
impl AsyncRead for InmemoryStream {
|
||||
fn poll_read(
|
||||
mut self: Pin<&mut Self>,
|
||||
cx: &mut Context<'_>,
|
||||
buf: &mut ReadBuf<'_>,
|
||||
) -> Poll<io::Result<()>> {
|
||||
Pin::new(&mut self.incoming).poll_read(cx, buf)
|
||||
}
|
||||
}
|
||||
|
||||
impl AsyncWrite for InmemoryStream {
|
||||
fn poll_write(
|
||||
mut self: Pin<&mut Self>,
|
||||
cx: &mut Context<'_>,
|
||||
buf: &[u8],
|
||||
) -> Poll<io::Result<usize>> {
|
||||
Pin::new(&mut self.outgoing).poll_write(cx, buf)
|
||||
}
|
||||
|
||||
fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
|
||||
Pin::new(&mut self.outgoing).poll_flush(cx)
|
||||
}
|
||||
|
||||
fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
|
||||
Pin::new(&mut self.outgoing).poll_shutdown(cx)
|
||||
}
|
||||
}
|
||||
|
||||
/// Read portion of an inmemory channel
|
||||
#[derive(Debug)]
|
||||
pub struct InmemoryStreamReadHalf {
|
||||
rx: mpsc::Receiver<Vec<u8>>,
|
||||
overflow: Vec<u8>,
|
||||
}
|
||||
|
||||
impl InmemoryStreamReadHalf {
|
||||
pub fn new(rx: mpsc::Receiver<Vec<u8>>) -> Self {
|
||||
Self {
|
||||
rx,
|
||||
overflow: Vec::new(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl AsyncRead for InmemoryStreamReadHalf {
|
||||
fn poll_read(
|
||||
mut self: Pin<&mut Self>,
|
||||
cx: &mut Context<'_>,
|
||||
buf: &mut ReadBuf<'_>,
|
||||
) -> Poll<io::Result<()>> {
|
||||
// If we cannot fit any more into the buffer at the moment, we wait
|
||||
if buf.remaining() == 0 {
|
||||
return Poll::Ready(Err(io::Error::new(
|
||||
io::ErrorKind::Other,
|
||||
"Cannot poll as buf.remaining() == 0",
|
||||
)));
|
||||
}
|
||||
|
||||
// If we have overflow from the last poll, put that in the buffer
|
||||
if !self.overflow.is_empty() {
|
||||
if self.overflow.len() > buf.remaining() {
|
||||
let extra = self.overflow.split_off(buf.remaining());
|
||||
buf.put_slice(&self.overflow);
|
||||
self.overflow = extra;
|
||||
} else {
|
||||
buf.put_slice(&self.overflow);
|
||||
self.overflow.clear();
|
||||
}
|
||||
|
||||
return Poll::Ready(Ok(()));
|
||||
}
|
||||
|
||||
// Otherwise, we poll for the next batch to read in
|
||||
match ready!(self.rx.poll_recv(cx)) {
|
||||
Some(mut x) => {
|
||||
if x.len() > buf.remaining() {
|
||||
self.overflow = x.split_off(buf.remaining());
|
||||
}
|
||||
buf.put_slice(&x);
|
||||
Poll::Ready(Ok(()))
|
||||
}
|
||||
None => Poll::Ready(Ok(())),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Write portion of an inmemory channel
|
||||
pub struct InmemoryStreamWriteHalf {
|
||||
tx: Option<mpsc::Sender<Vec<u8>>>,
|
||||
task: Option<Pin<Box<dyn Future<Output = io::Result<usize>> + Send + Sync + 'static>>>,
|
||||
}
|
||||
|
||||
impl InmemoryStreamWriteHalf {
|
||||
pub fn new(tx: mpsc::Sender<Vec<u8>>) -> Self {
|
||||
Self {
|
||||
tx: Some(tx),
|
||||
task: None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl fmt::Debug for InmemoryStreamWriteHalf {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
f.debug_struct("InmemoryStreamWriteHalf")
|
||||
.field("tx", &self.tx)
|
||||
.field(
|
||||
"task",
|
||||
&if self.tx.is_some() {
|
||||
"Some(...)"
|
||||
} else {
|
||||
"None"
|
||||
},
|
||||
)
|
||||
.finish()
|
||||
}
|
||||
}
|
||||
|
||||
impl AsyncWrite for InmemoryStreamWriteHalf {
|
||||
fn poll_write(
|
||||
mut self: Pin<&mut Self>,
|
||||
cx: &mut Context<'_>,
|
||||
buf: &[u8],
|
||||
) -> Poll<io::Result<usize>> {
|
||||
loop {
|
||||
match self.task.as_mut() {
|
||||
Some(task) => {
|
||||
let res = ready!(task.as_mut().poll(cx));
|
||||
self.task.take();
|
||||
return Poll::Ready(res);
|
||||
}
|
||||
None => match self.tx.as_mut() {
|
||||
Some(tx) => {
|
||||
let n = buf.len();
|
||||
let tx_2 = tx.clone();
|
||||
let data = buf.to_vec();
|
||||
let task =
|
||||
Box::pin(async move { tx_2.send(data).await.map(|_| n).or(Ok(0)) });
|
||||
self.task.replace(task);
|
||||
}
|
||||
None => return Poll::Ready(Ok(0)),
|
||||
},
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn poll_flush(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<io::Result<()>> {
|
||||
Poll::Ready(Ok(()))
|
||||
}
|
||||
|
||||
fn poll_shutdown(mut self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<io::Result<()>> {
|
||||
self.tx.take();
|
||||
self.task.take();
|
||||
Poll::Ready(Ok(()))
|
||||
}
|
||||
}
|
||||
|
||||
impl DataStream for InmemoryStream {
|
||||
type Read = InmemoryStreamReadHalf;
|
||||
type Write = InmemoryStreamWriteHalf;
|
||||
|
||||
fn to_connection_tag(&self) -> String {
|
||||
String::from("inmemory-stream")
|
||||
}
|
||||
|
||||
fn into_split(self) -> (Self::Read, Self::Write) {
|
||||
(self.incoming, self.outgoing)
|
||||
}
|
||||
}
|
||||
|
||||
impl Transport<InmemoryStream, PlainCodec> {
|
||||
/// Produces a pair of inmemory transports that are connected to each other using
|
||||
/// a standard codec
|
||||
///
|
||||
/// Sets the buffer for message passing for each underlying stream to the given buffer size
|
||||
pub fn pair(
|
||||
buffer: usize,
|
||||
) -> (
|
||||
Transport<InmemoryStream, PlainCodec>,
|
||||
Transport<InmemoryStream, PlainCodec>,
|
||||
) {
|
||||
let (a, b) = InmemoryStream::pair(buffer);
|
||||
let a = Transport::new(a, PlainCodec::new());
|
||||
let b = Transport::new(b, PlainCodec::new());
|
||||
(a, b)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use tokio::io::{AsyncReadExt, AsyncWriteExt};
|
||||
|
||||
#[test]
|
||||
fn to_connection_tag_should_be_hardcoded_string() {
|
||||
let (_, _, stream) = InmemoryStream::make(1);
|
||||
assert_eq!(stream.to_connection_tag(), "inmemory-stream");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn make_should_return_sender_that_sends_data_to_stream() {
|
||||
let (tx, _, mut stream) = InmemoryStream::make(3);
|
||||
|
||||
tx.send(b"test msg 1".to_vec()).await.unwrap();
|
||||
tx.send(b"test msg 2".to_vec()).await.unwrap();
|
||||
tx.send(b"test msg 3".to_vec()).await.unwrap();
|
||||
|
||||
// Should get data matching a singular message
|
||||
let mut buf = [0; 256];
|
||||
let len = stream.read(&mut buf).await.unwrap();
|
||||
assert_eq!(&buf[..len], b"test msg 1");
|
||||
|
||||
// Next call would get the second message
|
||||
let len = stream.read(&mut buf).await.unwrap();
|
||||
assert_eq!(&buf[..len], b"test msg 2");
|
||||
|
||||
// When the last of the senders is dropped, we should still get
|
||||
// the rest of the data that was sent first before getting
|
||||
// an indicator that there is no more data
|
||||
drop(tx);
|
||||
|
||||
let len = stream.read(&mut buf).await.unwrap();
|
||||
assert_eq!(&buf[..len], b"test msg 3");
|
||||
|
||||
let len = stream.read(&mut buf).await.unwrap();
|
||||
assert_eq!(len, 0, "Unexpectedly got more data");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn make_should_return_receiver_that_receives_data_from_stream() {
|
||||
let (_, mut rx, mut stream) = InmemoryStream::make(3);
|
||||
|
||||
stream.write_all(b"test msg 1").await.unwrap();
|
||||
stream.write_all(b"test msg 2").await.unwrap();
|
||||
stream.write_all(b"test msg 3").await.unwrap();
|
||||
|
||||
// Should get data matching a singular message
|
||||
assert_eq!(rx.recv().await, Some(b"test msg 1".to_vec()));
|
||||
|
||||
// Next call would get the second message
|
||||
assert_eq!(rx.recv().await, Some(b"test msg 2".to_vec()));
|
||||
|
||||
// When the stream is dropped, we should still get
|
||||
// the rest of the data that was sent first before getting
|
||||
// an indicator that there is no more data
|
||||
drop(stream);
|
||||
|
||||
assert_eq!(rx.recv().await, Some(b"test msg 3".to_vec()));
|
||||
|
||||
assert_eq!(rx.recv().await, None, "Unexpectedly got more data");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn into_split_should_provide_a_read_half_that_receives_from_sender() {
|
||||
let (tx, _, stream) = InmemoryStream::make(3);
|
||||
let (mut read_half, _) = stream.into_split();
|
||||
|
||||
tx.send(b"test msg 1".to_vec()).await.unwrap();
|
||||
tx.send(b"test msg 2".to_vec()).await.unwrap();
|
||||
tx.send(b"test msg 3".to_vec()).await.unwrap();
|
||||
|
||||
// Should get data matching a singular message
|
||||
let mut buf = [0; 256];
|
||||
let len = read_half.read(&mut buf).await.unwrap();
|
||||
assert_eq!(&buf[..len], b"test msg 1");
|
||||
|
||||
// Next call would get the second message
|
||||
let len = read_half.read(&mut buf).await.unwrap();
|
||||
assert_eq!(&buf[..len], b"test msg 2");
|
||||
|
||||
// When the last of the senders is dropped, we should still get
|
||||
// the rest of the data that was sent first before getting
|
||||
// an indicator that there is no more data
|
||||
drop(tx);
|
||||
|
||||
let len = read_half.read(&mut buf).await.unwrap();
|
||||
assert_eq!(&buf[..len], b"test msg 3");
|
||||
|
||||
let len = read_half.read(&mut buf).await.unwrap();
|
||||
assert_eq!(len, 0, "Unexpectedly got more data");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn into_split_should_provide_a_write_half_that_sends_to_receiver() {
|
||||
let (_, mut rx, stream) = InmemoryStream::make(3);
|
||||
let (_, mut write_half) = stream.into_split();
|
||||
|
||||
write_half.write_all(b"test msg 1").await.unwrap();
|
||||
write_half.write_all(b"test msg 2").await.unwrap();
|
||||
write_half.write_all(b"test msg 3").await.unwrap();
|
||||
|
||||
// Should get data matching a singular message
|
||||
assert_eq!(rx.recv().await, Some(b"test msg 1".to_vec()));
|
||||
|
||||
// Next call would get the second message
|
||||
assert_eq!(rx.recv().await, Some(b"test msg 2".to_vec()));
|
||||
|
||||
// When the stream is dropped, we should still get
|
||||
// the rest of the data that was sent first before getting
|
||||
// an indicator that there is no more data
|
||||
drop(write_half);
|
||||
|
||||
assert_eq!(rx.recv().await, Some(b"test msg 3".to_vec()));
|
||||
|
||||
assert_eq!(rx.recv().await, None, "Unexpectedly got more data");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn read_half_should_fail_if_buf_has_no_space_remaining() {
|
||||
let (_tx, _rx, stream) = InmemoryStream::make(1);
|
||||
let (mut t_read, _t_write) = stream.into_split();
|
||||
|
||||
let mut buf = [0u8; 0];
|
||||
match t_read.read(&mut buf).await {
|
||||
Err(x) if x.kind() == io::ErrorKind::Other => {}
|
||||
x => panic!("Unexpected result: {:?}", x),
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn read_half_should_update_buf_with_all_overflow_from_last_read_if_it_all_fits() {
|
||||
let (tx, _rx, stream) = InmemoryStream::make(1);
|
||||
let (mut t_read, _t_write) = stream.into_split();
|
||||
|
||||
tx.send(vec![1, 2, 3]).await.expect("Failed to send");
|
||||
|
||||
let mut buf = [0u8; 2];
|
||||
|
||||
// First, read part of the data (first two bytes)
|
||||
match t_read.read(&mut buf).await {
|
||||
Ok(n) if n == 2 => assert_eq!(&buf[..n], &[1, 2]),
|
||||
x => panic!("Unexpected result: {:?}", x),
|
||||
}
|
||||
|
||||
// Second, we send more data because the last message was placed in overflow
|
||||
tx.send(vec![4, 5, 6]).await.expect("Failed to send");
|
||||
|
||||
// Third, read remainder of the overflow from first message (third byte)
|
||||
match t_read.read(&mut buf).await {
|
||||
Ok(n) if n == 1 => assert_eq!(&buf[..n], &[3]),
|
||||
x => panic!("Unexpected result: {:?}", x),
|
||||
}
|
||||
|
||||
// Fourth, verify that we start to receive the next overflow
|
||||
match t_read.read(&mut buf).await {
|
||||
Ok(n) if n == 2 => assert_eq!(&buf[..n], &[4, 5]),
|
||||
x => panic!("Unexpected result: {:?}", x),
|
||||
}
|
||||
|
||||
// Fifth, verify that we get the last bit of overflow
|
||||
match t_read.read(&mut buf).await {
|
||||
Ok(n) if n == 1 => assert_eq!(&buf[..n], &[6]),
|
||||
x => panic!("Unexpected result: {:?}", x),
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn read_half_should_update_buf_with_some_of_overflow_that_can_fit() {
|
||||
let (tx, _rx, stream) = InmemoryStream::make(1);
|
||||
let (mut t_read, _t_write) = stream.into_split();
|
||||
|
||||
tx.send(vec![1, 2, 3, 4, 5]).await.expect("Failed to send");
|
||||
|
||||
let mut buf = [0u8; 2];
|
||||
|
||||
// First, read part of the data (first two bytes)
|
||||
match t_read.read(&mut buf).await {
|
||||
Ok(n) if n == 2 => assert_eq!(&buf[..n], &[1, 2]),
|
||||
x => panic!("Unexpected result: {:?}", x),
|
||||
}
|
||||
|
||||
// Second, we send more data because the last message was placed in overflow
|
||||
tx.send(vec![6]).await.expect("Failed to send");
|
||||
|
||||
// Third, read next chunk of the overflow from first message (next two byte)
|
||||
match t_read.read(&mut buf).await {
|
||||
Ok(n) if n == 2 => assert_eq!(&buf[..n], &[3, 4]),
|
||||
x => panic!("Unexpected result: {:?}", x),
|
||||
}
|
||||
|
||||
// Fourth, read last chunk of the overflow from first message (fifth byte)
|
||||
match t_read.read(&mut buf).await {
|
||||
Ok(n) if n == 1 => assert_eq!(&buf[..n], &[5]),
|
||||
x => panic!("Unexpected result: {:?}", x),
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn read_half_should_update_buf_with_all_of_inner_channel_when_it_fits() {
|
||||
let (tx, _rx, stream) = InmemoryStream::make(1);
|
||||
let (mut t_read, _t_write) = stream.into_split();
|
||||
|
||||
let mut buf = [0u8; 5];
|
||||
|
||||
tx.send(vec![1, 2, 3, 4, 5]).await.expect("Failed to send");
|
||||
|
||||
// First, read all of data that fits exactly
|
||||
match t_read.read(&mut buf).await {
|
||||
Ok(n) if n == 5 => assert_eq!(&buf[..n], &[1, 2, 3, 4, 5]),
|
||||
x => panic!("Unexpected result: {:?}", x),
|
||||
}
|
||||
|
||||
tx.send(vec![6, 7, 8]).await.expect("Failed to send");
|
||||
|
||||
// Second, read data that fits within buf
|
||||
match t_read.read(&mut buf).await {
|
||||
Ok(n) if n == 3 => assert_eq!(&buf[..n], &[6, 7, 8]),
|
||||
x => panic!("Unexpected result: {:?}", x),
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn read_half_should_update_buf_with_some_of_inner_channel_that_can_fit_and_add_rest_to_overflow(
|
||||
) {
|
||||
let (tx, _rx, stream) = InmemoryStream::make(1);
|
||||
let (mut t_read, _t_write) = stream.into_split();
|
||||
|
||||
let mut buf = [0u8; 1];
|
||||
|
||||
tx.send(vec![1, 2, 3, 4, 5]).await.expect("Failed to send");
|
||||
|
||||
// Attempt a read that places more in overflow
|
||||
match t_read.read(&mut buf).await {
|
||||
Ok(n) if n == 1 => assert_eq!(&buf[..n], &[1]),
|
||||
x => panic!("Unexpected result: {:?}", x),
|
||||
}
|
||||
|
||||
// Verify overflow contains the rest
|
||||
assert_eq!(&t_read.overflow, &[2, 3, 4, 5]);
|
||||
|
||||
// Queue up extra data that will not be read until overflow is finished
|
||||
tx.send(vec![6, 7, 8]).await.expect("Failed to send");
|
||||
|
||||
// Read next data point
|
||||
match t_read.read(&mut buf).await {
|
||||
Ok(n) if n == 1 => assert_eq!(&buf[..n], &[2]),
|
||||
x => panic!("Unexpected result: {:?}", x),
|
||||
}
|
||||
|
||||
// Verify overflow contains the rest without having added extra data
|
||||
assert_eq!(&t_read.overflow, &[3, 4, 5]);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn read_half_should_yield_pending_if_no_data_available_on_inner_channel() {
|
||||
let (_tx, _rx, stream) = InmemoryStream::make(1);
|
||||
let (mut t_read, _t_write) = stream.into_split();
|
||||
|
||||
let mut buf = [0u8; 1];
|
||||
|
||||
// Attempt a read that should yield ok with no change, which is what should
|
||||
// happen when nothing is read into buf
|
||||
let f = t_read.read(&mut buf);
|
||||
tokio::pin!(f);
|
||||
match futures::poll!(f) {
|
||||
Poll::Pending => {}
|
||||
x => panic!("Unexpected poll result: {:?}", x),
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn read_half_should_not_update_buf_if_inner_channel_closed() {
|
||||
let (tx, _rx, stream) = InmemoryStream::make(1);
|
||||
let (mut t_read, _t_write) = stream.into_split();
|
||||
|
||||
let mut buf = [0u8; 1];
|
||||
|
||||
// Drop the channel that would be sending data to the transport
|
||||
drop(tx);
|
||||
|
||||
// Attempt a read that should yield ok with no change, which is what should
|
||||
// happen when nothing is read into buf
|
||||
match t_read.read(&mut buf).await {
|
||||
Ok(n) if n == 0 => assert_eq!(&buf, &[0]),
|
||||
x => panic!("Unexpected result: {:?}", x),
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn write_half_should_return_buf_len_if_can_send_immediately() {
|
||||
let (_tx, mut rx, stream) = InmemoryStream::make(1);
|
||||
let (_t_read, mut t_write) = stream.into_split();
|
||||
|
||||
// Write that is not waiting should always succeed with full contents
|
||||
let n = t_write.write(&[1, 2, 3]).await.expect("Failed to write");
|
||||
assert_eq!(n, 3, "Unexpected byte count returned");
|
||||
|
||||
// Verify we actually had the data sent
|
||||
let data = rx.try_recv().expect("Failed to recv data");
|
||||
assert_eq!(data, &[1, 2, 3]);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn write_half_should_return_support_eventually_sending_by_retrying_when_not_ready() {
|
||||
let (_tx, mut rx, stream) = InmemoryStream::make(1);
|
||||
let (_t_read, mut t_write) = stream.into_split();
|
||||
|
||||
// Queue a write already so that we block on the next one
|
||||
let _ = t_write.write(&[1, 2, 3]).await.expect("Failed to write");
|
||||
|
||||
// Verify that the next write is pending
|
||||
let f = t_write.write(&[4, 5]);
|
||||
tokio::pin!(f);
|
||||
match futures::poll!(&mut f) {
|
||||
Poll::Pending => {}
|
||||
x => panic!("Unexpected poll result: {:?}", x),
|
||||
}
|
||||
|
||||
// Consume first batch of data so future of second can continue
|
||||
let data = rx.try_recv().expect("Failed to recv data");
|
||||
assert_eq!(data, &[1, 2, 3]);
|
||||
|
||||
// Verify that poll now returns success
|
||||
match futures::poll!(f) {
|
||||
Poll::Ready(Ok(n)) if n == 2 => {}
|
||||
x => panic!("Unexpected poll result: {:?}", x),
|
||||
}
|
||||
|
||||
// Consume second batch of data
|
||||
let data = rx.try_recv().expect("Failed to recv data");
|
||||
assert_eq!(data, &[4, 5]);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn write_half_should_zero_if_inner_channel_closed() {
|
||||
let (_tx, rx, stream) = InmemoryStream::make(1);
|
||||
let (_t_read, mut t_write) = stream.into_split();
|
||||
|
||||
// Drop receiving end that transport would talk to
|
||||
drop(rx);
|
||||
|
||||
// Channel is dropped, so return 0 to indicate no bytes sent
|
||||
let n = t_write.write(&[1, 2, 3]).await.expect("Failed to write");
|
||||
assert_eq!(n, 0, "Unexpected byte count returned");
|
||||
}
|
||||
}
|
@ -1,399 +0,0 @@
|
||||
use crate::net::SecretKeyError;
|
||||
use derive_more::{Display, Error, From};
|
||||
use futures::{SinkExt, StreamExt};
|
||||
use serde::{de::DeserializeOwned, Serialize};
|
||||
use std::marker::Unpin;
|
||||
use tokio::io::{self, AsyncRead, AsyncWrite};
|
||||
use tokio_util::codec::{Framed, FramedRead, FramedWrite};
|
||||
|
||||
mod codec;
|
||||
pub use codec::*;
|
||||
|
||||
mod inmemory;
|
||||
pub use inmemory::*;
|
||||
|
||||
mod tcp;
|
||||
pub use tcp::*;
|
||||
|
||||
#[cfg(unix)]
|
||||
mod unix;
|
||||
|
||||
#[cfg(unix)]
|
||||
pub use unix::*;
|
||||
|
||||
#[derive(Debug, Display, Error, From)]
|
||||
pub struct SerializeError(#[error(not(source))] String);
|
||||
|
||||
#[derive(Debug, Display, Error, From)]
|
||||
pub struct DeserializeError(#[error(not(source))] String);
|
||||
|
||||
fn serialize_to_vec<T: Serialize>(value: &T) -> Result<Vec<u8>, SerializeError> {
|
||||
let mut v = Vec::new();
|
||||
|
||||
let _ = ciborium::ser::into_writer(value, &mut v).map_err(|x| SerializeError(x.to_string()))?;
|
||||
|
||||
Ok(v)
|
||||
}
|
||||
|
||||
fn deserialize_from_slice<T: DeserializeOwned>(slice: &[u8]) -> Result<T, DeserializeError> {
|
||||
ciborium::de::from_reader(slice).map_err(|x| DeserializeError(x.to_string()))
|
||||
}
|
||||
|
||||
#[derive(Debug, Display, Error, From)]
|
||||
pub enum TransportError {
|
||||
CryptoError(SecretKeyError),
|
||||
IoError(io::Error),
|
||||
SerializeError(SerializeError),
|
||||
DeserializeError(DeserializeError),
|
||||
}
|
||||
|
||||
/// Interface representing a two-way data stream
|
||||
///
|
||||
/// Enables splitting into separate, functioning halves that can read and write respectively
|
||||
pub trait DataStream: AsyncRead + AsyncWrite + Unpin {
|
||||
type Read: AsyncRead + Send + Unpin + 'static;
|
||||
type Write: AsyncWrite + Send + Unpin + 'static;
|
||||
|
||||
/// Returns a textual description of the connection associated with this stream
|
||||
fn to_connection_tag(&self) -> String;
|
||||
|
||||
/// Splits this stream into read and write halves
|
||||
fn into_split(self) -> (Self::Read, Self::Write);
|
||||
}
|
||||
|
||||
/// Represents a transport of data across the network
|
||||
#[derive(Debug)]
|
||||
pub struct Transport<T, U>(Framed<T, U>)
|
||||
where
|
||||
T: DataStream,
|
||||
U: Codec;
|
||||
|
||||
impl<T, U> Transport<T, U>
|
||||
where
|
||||
T: DataStream,
|
||||
U: Codec,
|
||||
{
|
||||
/// Creates a new instance of the transport, wrapping the stream in a `Framed<T, XChaCha20Poly1305Codec>`
|
||||
pub fn new(stream: T, codec: U) -> Self {
|
||||
Self(Framed::new(stream, codec))
|
||||
}
|
||||
|
||||
/// Sends some data across the wire, waiting for it to completely send
|
||||
pub async fn send<D: Serialize>(&mut self, data: D) -> Result<(), TransportError> {
|
||||
// Serialize data into a byte stream
|
||||
// NOTE: Cannot used packed implementation for now due to issues with deserialization
|
||||
let data = serialize_to_vec(&data)?;
|
||||
|
||||
// Use underlying codec to send data (may encrypt, sign, etc.)
|
||||
self.0.send(&data).await.map_err(TransportError::from)
|
||||
}
|
||||
|
||||
/// Receives some data from out on the wire, waiting until it's available,
|
||||
/// returning none if the transport is now closed
|
||||
pub async fn receive<R: DeserializeOwned>(&mut self) -> Result<Option<R>, TransportError> {
|
||||
// Use underlying codec to receive data (may decrypt, validate, etc.)
|
||||
if let Some(data) = self.0.next().await {
|
||||
let data = data?;
|
||||
|
||||
// Deserialize byte stream into our expected type
|
||||
let data = deserialize_from_slice(&data)?;
|
||||
Ok(Some(data))
|
||||
} else {
|
||||
Ok(None)
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns a textual description of the transport's underlying connection
|
||||
pub fn to_connection_tag(&self) -> String {
|
||||
self.0.get_ref().to_connection_tag()
|
||||
}
|
||||
|
||||
/// Returns a reference to the underlying I/O stream
|
||||
///
|
||||
/// Note that care should be taken to not tamper with the underlying stream of data coming in
|
||||
/// as it may corrupt the stream of frames otherwise being worked with
|
||||
pub fn get_ref(&self) -> &T {
|
||||
self.0.get_ref()
|
||||
}
|
||||
|
||||
/// Returns a reference to the underlying I/O stream
|
||||
///
|
||||
/// Note that care should be taken to not tamper with the underlying stream of data coming in
|
||||
/// as it may corrupt the stream of frames otherwise being worked with
|
||||
pub fn get_mut(&mut self) -> &mut T {
|
||||
self.0.get_mut()
|
||||
}
|
||||
|
||||
/// Consumes the transport, returning its underlying I/O stream
|
||||
///
|
||||
/// Note that care should be taken to not tamper with the underlying stream of data coming in
|
||||
/// as it may corrupt the stream of frames otherwise being worked with.
|
||||
pub fn into_inner(self) -> T {
|
||||
self.0.into_inner()
|
||||
}
|
||||
|
||||
/// Splits transport into read and write halves
|
||||
pub fn into_split(
|
||||
self,
|
||||
) -> (
|
||||
TransportReadHalf<T::Read, U>,
|
||||
TransportWriteHalf<T::Write, U>,
|
||||
) {
|
||||
let parts = self.0.into_parts();
|
||||
let (read_half, write_half) = parts.io.into_split();
|
||||
|
||||
// Create our split read half and populate its buffer with existing contents
|
||||
let mut f_read = FramedRead::new(read_half, parts.codec.clone());
|
||||
*f_read.read_buffer_mut() = parts.read_buf;
|
||||
|
||||
// Create our split write half and populate its buffer with existing contents
|
||||
let mut f_write = FramedWrite::new(write_half, parts.codec);
|
||||
*f_write.write_buffer_mut() = parts.write_buf;
|
||||
|
||||
let t_read = TransportReadHalf(f_read);
|
||||
let t_write = TransportWriteHalf(f_write);
|
||||
|
||||
(t_read, t_write)
|
||||
}
|
||||
}
|
||||
|
||||
/// Represents a transport of data out to the network
|
||||
pub struct TransportWriteHalf<T, U>(FramedWrite<T, U>)
|
||||
where
|
||||
T: AsyncWrite + Unpin,
|
||||
U: Codec;
|
||||
|
||||
impl<T, U> TransportWriteHalf<T, U>
|
||||
where
|
||||
T: AsyncWrite + Unpin,
|
||||
U: Codec,
|
||||
{
|
||||
/// Sends some data across the wire, waiting for it to completely send
|
||||
pub async fn send<D: Serialize>(&mut self, data: D) -> Result<(), TransportError> {
|
||||
// Serialize data into a byte stream
|
||||
// NOTE: Cannot used packed implementation for now due to issues with deserialization
|
||||
let data = serialize_to_vec(&data)?;
|
||||
|
||||
// Use underlying codec to send data (may encrypt, sign, etc.)
|
||||
self.0.send(&data).await.map_err(TransportError::from)
|
||||
}
|
||||
}
|
||||
|
||||
/// Represents a transport of data in from the network
|
||||
pub struct TransportReadHalf<T, U>(FramedRead<T, U>)
|
||||
where
|
||||
T: AsyncRead + Unpin,
|
||||
U: Codec;
|
||||
|
||||
impl<T, U> TransportReadHalf<T, U>
|
||||
where
|
||||
T: AsyncRead + Unpin,
|
||||
U: Codec,
|
||||
{
|
||||
/// Receives some data from out on the wire, waiting until it's available,
|
||||
/// returning none if the transport is now closed
|
||||
pub async fn receive<R: DeserializeOwned>(&mut self) -> Result<Option<R>, TransportError> {
|
||||
// Use underlying codec to receive data (may decrypt, validate, etc.)
|
||||
if let Some(data) = self.0.next().await {
|
||||
let data = data?;
|
||||
|
||||
// Deserialize byte stream into our expected type
|
||||
let data = deserialize_from_slice(&data)?;
|
||||
Ok(Some(data))
|
||||
} else {
|
||||
Ok(None)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Test utilities
|
||||
#[cfg(test)]
|
||||
impl Transport<crate::net::InmemoryStream, crate::net::PlainCodec> {
|
||||
/// Makes a connected pair of inmemory transports
|
||||
pub fn make_pair() -> (
|
||||
Transport<crate::net::InmemoryStream, crate::net::PlainCodec>,
|
||||
Transport<crate::net::InmemoryStream, crate::net::PlainCodec>,
|
||||
) {
|
||||
Self::pair(crate::constants::test::BUFFER_SIZE)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
|
||||
pub struct TestData {
|
||||
name: String,
|
||||
value: usize,
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn send_should_convert_data_into_byte_stream_and_send_through_stream() {
|
||||
let (_tx, mut rx, stream) = InmemoryStream::make(1);
|
||||
let mut transport = Transport::new(stream, PlainCodec::new());
|
||||
|
||||
let data = TestData {
|
||||
name: String::from("test"),
|
||||
value: 123,
|
||||
};
|
||||
|
||||
let bytes = serialize_to_vec(&data).unwrap();
|
||||
let len = (bytes.len() as u64).to_be_bytes();
|
||||
let mut frame = Vec::new();
|
||||
frame.extend(len.iter().copied());
|
||||
frame.extend(bytes);
|
||||
|
||||
transport.send(data).await.unwrap();
|
||||
|
||||
let outgoing = rx.recv().await.unwrap();
|
||||
assert_eq!(outgoing, frame);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn receive_should_return_none_if_stream_is_closed() {
|
||||
let (_, _, stream) = InmemoryStream::make(1);
|
||||
let mut transport = Transport::new(stream, PlainCodec::new());
|
||||
|
||||
let result = transport.receive::<TestData>().await;
|
||||
match result {
|
||||
Ok(None) => {}
|
||||
x => panic!("Unexpected result: {:?}", x),
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn receive_should_fail_if_unable_to_convert_to_type() {
|
||||
let (tx, _rx, stream) = InmemoryStream::make(1);
|
||||
let mut transport = Transport::new(stream, PlainCodec::new());
|
||||
|
||||
#[derive(Serialize, Deserialize)]
|
||||
struct OtherTestData(usize);
|
||||
|
||||
let data = OtherTestData(123);
|
||||
let bytes = serialize_to_vec(&data).unwrap();
|
||||
let len = (bytes.len() as u64).to_be_bytes();
|
||||
let mut frame = Vec::new();
|
||||
frame.extend(len.iter().copied());
|
||||
frame.extend(bytes);
|
||||
|
||||
tx.send(frame).await.unwrap();
|
||||
let result = transport.receive::<TestData>().await;
|
||||
match result {
|
||||
Err(TransportError::DeserializeError(_)) => {}
|
||||
x => panic!("Unexpected result: {:?}", x),
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn receive_should_return_some_instance_of_type_when_coming_into_stream() {
|
||||
let (tx, _rx, stream) = InmemoryStream::make(1);
|
||||
let mut transport = Transport::new(stream, PlainCodec::new());
|
||||
|
||||
let data = TestData {
|
||||
name: String::from("test"),
|
||||
value: 123,
|
||||
};
|
||||
|
||||
let bytes = serialize_to_vec(&data).unwrap();
|
||||
let len = (bytes.len() as u64).to_be_bytes();
|
||||
let mut frame = Vec::new();
|
||||
frame.extend(len.iter().copied());
|
||||
frame.extend(bytes);
|
||||
|
||||
tx.send(frame).await.unwrap();
|
||||
let received_data = transport.receive::<TestData>().await.unwrap().unwrap();
|
||||
assert_eq!(received_data, data);
|
||||
}
|
||||
|
||||
mod read_half {
|
||||
use super::*;
|
||||
|
||||
#[tokio::test]
|
||||
async fn receive_should_return_none_if_stream_is_closed() {
|
||||
let (_, _, stream) = InmemoryStream::make(1);
|
||||
let transport = Transport::new(stream, PlainCodec::new());
|
||||
let (mut rh, _) = transport.into_split();
|
||||
|
||||
let result = rh.receive::<TestData>().await;
|
||||
match result {
|
||||
Ok(None) => {}
|
||||
x => panic!("Unexpected result: {:?}", x),
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn receive_should_fail_if_unable_to_convert_to_type() {
|
||||
let (tx, _rx, stream) = InmemoryStream::make(1);
|
||||
let transport = Transport::new(stream, PlainCodec::new());
|
||||
let (mut rh, _) = transport.into_split();
|
||||
|
||||
#[derive(Serialize, Deserialize)]
|
||||
struct OtherTestData(usize);
|
||||
|
||||
let data = OtherTestData(123);
|
||||
let bytes = serialize_to_vec(&data).unwrap();
|
||||
let len = (bytes.len() as u64).to_be_bytes();
|
||||
let mut frame = Vec::new();
|
||||
frame.extend(len.iter().copied());
|
||||
frame.extend(bytes);
|
||||
|
||||
tx.send(frame).await.unwrap();
|
||||
let result = rh.receive::<TestData>().await;
|
||||
match result {
|
||||
Err(TransportError::DeserializeError(_)) => {}
|
||||
x => panic!("Unexpected result: {:?}", x),
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn receive_should_return_some_instance_of_type_when_coming_into_stream() {
|
||||
let (tx, _rx, stream) = InmemoryStream::make(1);
|
||||
let transport = Transport::new(stream, PlainCodec::new());
|
||||
let (mut rh, _) = transport.into_split();
|
||||
|
||||
let data = TestData {
|
||||
name: String::from("test"),
|
||||
value: 123,
|
||||
};
|
||||
|
||||
let bytes = serialize_to_vec(&data).unwrap();
|
||||
let len = (bytes.len() as u64).to_be_bytes();
|
||||
let mut frame = Vec::new();
|
||||
frame.extend(len.iter().copied());
|
||||
frame.extend(bytes);
|
||||
|
||||
tx.send(frame).await.unwrap();
|
||||
let received_data = rh.receive::<TestData>().await.unwrap().unwrap();
|
||||
assert_eq!(received_data, data);
|
||||
}
|
||||
}
|
||||
|
||||
mod write_half {
|
||||
use super::*;
|
||||
|
||||
#[tokio::test]
|
||||
async fn send_should_convert_data_into_byte_stream_and_send_through_stream() {
|
||||
let (_tx, mut rx, stream) = InmemoryStream::make(1);
|
||||
let transport = Transport::new(stream, PlainCodec::new());
|
||||
let (_, mut wh) = transport.into_split();
|
||||
|
||||
let data = TestData {
|
||||
name: String::from("test"),
|
||||
value: 123,
|
||||
};
|
||||
|
||||
let bytes = serialize_to_vec(&data).unwrap();
|
||||
let len = (bytes.len() as u64).to_be_bytes();
|
||||
let mut frame = Vec::new();
|
||||
frame.extend(len.iter().copied());
|
||||
frame.extend(bytes);
|
||||
|
||||
wh.send(data).await.unwrap();
|
||||
|
||||
let outgoing = rx.recv().await.unwrap();
|
||||
assert_eq!(outgoing, frame);
|
||||
}
|
||||
}
|
||||
}
|
@ -1,38 +0,0 @@
|
||||
use super::{Codec, DataStream, Transport};
|
||||
use std::net::SocketAddr;
|
||||
use tokio::{
|
||||
io,
|
||||
net::{
|
||||
tcp::{OwnedReadHalf, OwnedWriteHalf},
|
||||
TcpStream, ToSocketAddrs,
|
||||
},
|
||||
};
|
||||
|
||||
impl DataStream for TcpStream {
|
||||
type Read = OwnedReadHalf;
|
||||
type Write = OwnedWriteHalf;
|
||||
|
||||
fn to_connection_tag(&self) -> String {
|
||||
self.peer_addr()
|
||||
.map(|addr| format!("{}", addr))
|
||||
.unwrap_or_else(|_| String::from("--"))
|
||||
}
|
||||
|
||||
fn into_split(self) -> (Self::Read, Self::Write) {
|
||||
TcpStream::into_split(self)
|
||||
}
|
||||
}
|
||||
|
||||
impl<U: Codec> Transport<TcpStream, U> {
|
||||
/// Establishes a connection to one of the specified addresses and uses the provided codec
|
||||
/// for transportation
|
||||
pub async fn connect(addrs: impl ToSocketAddrs, codec: U) -> io::Result<Self> {
|
||||
let stream = TcpStream::connect(addrs).await?;
|
||||
Ok(Transport::new(stream, codec))
|
||||
}
|
||||
|
||||
/// Returns the address of the peer the transport is connected to
|
||||
pub fn peer_addr(&self) -> io::Result<SocketAddr> {
|
||||
self.0.get_ref().peer_addr()
|
||||
}
|
||||
}
|
@ -1,37 +0,0 @@
|
||||
use super::{Codec, DataStream, Transport};
|
||||
use tokio::{
|
||||
io,
|
||||
net::{
|
||||
unix::{OwnedReadHalf, OwnedWriteHalf, SocketAddr},
|
||||
UnixStream,
|
||||
},
|
||||
};
|
||||
|
||||
impl DataStream for UnixStream {
|
||||
type Read = OwnedReadHalf;
|
||||
type Write = OwnedWriteHalf;
|
||||
|
||||
fn to_connection_tag(&self) -> String {
|
||||
self.peer_addr()
|
||||
.map(|addr| format!("{:?}", addr))
|
||||
.unwrap_or_else(|_| String::from("--"))
|
||||
}
|
||||
|
||||
fn into_split(self) -> (Self::Read, Self::Write) {
|
||||
UnixStream::into_split(self)
|
||||
}
|
||||
}
|
||||
|
||||
impl<U: Codec> Transport<UnixStream, U> {
|
||||
/// Establishes a connection to the socket at the specified path and uses the provided codec
|
||||
/// for transportation
|
||||
pub async fn connect(path: impl AsRef<std::path::Path>, codec: U) -> io::Result<Self> {
|
||||
let stream = UnixStream::connect(path.as_ref()).await?;
|
||||
Ok(Transport::new(stream, codec))
|
||||
}
|
||||
|
||||
/// Returns the address of the peer the transport is connected to
|
||||
pub fn peer_addr(&self) -> io::Result<SocketAddr> {
|
||||
self.0.get_ref().peer_addr()
|
||||
}
|
||||
}
|
@ -0,0 +1,45 @@
|
||||
use serde::{
|
||||
de::{Deserializer, Error as SerdeError, Visitor},
|
||||
ser::Serializer,
|
||||
};
|
||||
use std::{fmt, marker::PhantomData, str::FromStr};
|
||||
|
||||
/// From https://docs.rs/serde_with/1.14.0/src/serde_with/rust.rs.html#90-118
|
||||
pub fn deserialize_from_str<'de, D, T>(deserializer: D) -> Result<T, D::Error>
|
||||
where
|
||||
D: Deserializer<'de>,
|
||||
T: FromStr,
|
||||
T::Err: fmt::Display,
|
||||
{
|
||||
struct Helper<S>(PhantomData<S>);
|
||||
|
||||
impl<'de, S> Visitor<'de> for Helper<S>
|
||||
where
|
||||
S: FromStr,
|
||||
<S as FromStr>::Err: fmt::Display,
|
||||
{
|
||||
type Value = S;
|
||||
|
||||
fn expecting(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
write!(formatter, "a string")
|
||||
}
|
||||
|
||||
fn visit_str<E>(self, value: &str) -> Result<Self::Value, E>
|
||||
where
|
||||
E: SerdeError,
|
||||
{
|
||||
value.parse::<Self::Value>().map_err(SerdeError::custom)
|
||||
}
|
||||
}
|
||||
|
||||
deserializer.deserialize_str(Helper(PhantomData))
|
||||
}
|
||||
|
||||
/// From https://docs.rs/serde_with/1.14.0/src/serde_with/rust.rs.html#121-127
|
||||
pub fn serialize_to_str<T, S>(value: &T, serializer: S) -> Result<S::Ok, S::Error>
|
||||
where
|
||||
T: fmt::Display,
|
||||
S: Serializer,
|
||||
{
|
||||
serializer.collect_str(&value)
|
||||
}
|
File diff suppressed because it is too large
Load Diff
@ -1,362 +0,0 @@
|
||||
mod handler;
|
||||
mod process;
|
||||
mod state;
|
||||
|
||||
pub(crate) use process::{InputChannel, ProcessKiller, ProcessPty};
|
||||
use state::State;
|
||||
|
||||
use crate::{
|
||||
constants::MAX_MSG_CAPACITY,
|
||||
data::{Request, Response},
|
||||
net::{Codec, DataStream, Transport, TransportListener, TransportReadHalf, TransportWriteHalf},
|
||||
server::{
|
||||
utils::{ConnTracker, ShutdownTask},
|
||||
PortRange,
|
||||
},
|
||||
};
|
||||
use futures::stream::{Stream, StreamExt};
|
||||
use log::*;
|
||||
use std::{net::IpAddr, sync::Arc};
|
||||
use tokio::{
|
||||
io::{self, AsyncRead, AsyncWrite},
|
||||
net::TcpListener,
|
||||
sync::{mpsc, Mutex},
|
||||
task::{JoinError, JoinHandle},
|
||||
time::Duration,
|
||||
};
|
||||
|
||||
/// Represents a server that listens for requests, processes them, and sends responses
|
||||
pub struct DistantServer {
|
||||
conn_task: JoinHandle<()>,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, PartialEq, Eq)]
|
||||
pub struct DistantServerOptions {
|
||||
pub shutdown_after: Option<Duration>,
|
||||
pub max_msg_capacity: usize,
|
||||
}
|
||||
|
||||
impl Default for DistantServerOptions {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
shutdown_after: None,
|
||||
max_msg_capacity: MAX_MSG_CAPACITY,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl DistantServer {
|
||||
/// Bind to an IP address and port from the given range, taking an optional shutdown duration
|
||||
/// that will shutdown the server if there is no active connection after duration
|
||||
pub async fn bind<U>(
|
||||
addr: IpAddr,
|
||||
port: PortRange,
|
||||
codec: U,
|
||||
opts: DistantServerOptions,
|
||||
) -> io::Result<(Self, u16)>
|
||||
where
|
||||
U: Codec + Send + 'static,
|
||||
{
|
||||
debug!("Binding to {} in range {}", addr, port);
|
||||
let listener = TcpListener::bind(port.make_socket_addrs(addr).as_slice()).await?;
|
||||
|
||||
let port = listener.local_addr()?.port();
|
||||
debug!("Bound to port: {}", port);
|
||||
|
||||
let stream = TransportListener::initialize(listener, move |stream| {
|
||||
Transport::new(stream, codec.clone())
|
||||
})
|
||||
.into_stream();
|
||||
|
||||
Ok((Self::initialize(Box::pin(stream), opts), port))
|
||||
}
|
||||
|
||||
/// Initialize a distant server using the provided listener
|
||||
pub fn initialize<T, U, S>(stream: S, opts: DistantServerOptions) -> Self
|
||||
where
|
||||
T: DataStream + Send + 'static,
|
||||
U: Codec + Send + 'static,
|
||||
S: Stream<Item = Transport<T, U>> + Send + Unpin + 'static,
|
||||
{
|
||||
// Build our state for the server
|
||||
let state: Arc<Mutex<State>> = Arc::new(Mutex::new(State::default()));
|
||||
let (shutdown, tracker) = ShutdownTask::maybe_initialize(opts.shutdown_after);
|
||||
|
||||
// Spawn our connection task
|
||||
let conn_task = tokio::spawn(async move {
|
||||
connection_loop(stream, state, tracker, shutdown, opts.max_msg_capacity).await
|
||||
});
|
||||
|
||||
Self { conn_task }
|
||||
}
|
||||
|
||||
/// Waits for the server to terminate
|
||||
pub async fn wait(self) -> Result<(), JoinError> {
|
||||
self.conn_task.await
|
||||
}
|
||||
|
||||
/// Aborts the server by aborting the internal task handling new connections
|
||||
pub fn abort(&self) {
|
||||
self.conn_task.abort();
|
||||
}
|
||||
}
|
||||
|
||||
async fn connection_loop<T, U, S>(
|
||||
mut stream: S,
|
||||
state: Arc<Mutex<State>>,
|
||||
tracker: Option<Arc<Mutex<ConnTracker>>>,
|
||||
shutdown: Option<ShutdownTask>,
|
||||
max_msg_capacity: usize,
|
||||
) where
|
||||
T: DataStream + Send + 'static,
|
||||
U: Codec + Send + 'static,
|
||||
S: Stream<Item = Transport<T, U>> + Send + Unpin + 'static,
|
||||
{
|
||||
let inner = async move {
|
||||
loop {
|
||||
match stream.next().await {
|
||||
Some(transport) => {
|
||||
let conn_id = rand::random();
|
||||
debug!(
|
||||
"<Conn @ {}> Established against {}",
|
||||
conn_id,
|
||||
transport.to_connection_tag()
|
||||
);
|
||||
if let Err(x) = on_new_conn(
|
||||
transport,
|
||||
conn_id,
|
||||
Arc::clone(&state),
|
||||
tracker.as_ref().map(Arc::clone),
|
||||
max_msg_capacity,
|
||||
)
|
||||
.await
|
||||
{
|
||||
error!("<Conn @ {}> Failed handshake: {}", conn_id, x);
|
||||
}
|
||||
}
|
||||
None => {
|
||||
info!("Listener shutting down");
|
||||
break;
|
||||
}
|
||||
};
|
||||
}
|
||||
};
|
||||
|
||||
match shutdown {
|
||||
Some(shutdown) => tokio::select! {
|
||||
_ = inner => {}
|
||||
_ = shutdown => {
|
||||
warn!("Reached shutdown timeout, so terminating");
|
||||
}
|
||||
},
|
||||
None => inner.await,
|
||||
}
|
||||
}
|
||||
|
||||
/// Processes a new connection, performing a handshake, and then spawning two tasks to handle
|
||||
/// input and output, returning join handles for the input and output tasks respectively
|
||||
async fn on_new_conn<T, U>(
|
||||
transport: Transport<T, U>,
|
||||
conn_id: usize,
|
||||
state: Arc<Mutex<State>>,
|
||||
tracker: Option<Arc<Mutex<ConnTracker>>>,
|
||||
max_msg_capacity: usize,
|
||||
) -> io::Result<JoinHandle<()>>
|
||||
where
|
||||
T: DataStream,
|
||||
U: Codec + Send + 'static,
|
||||
{
|
||||
// Update our tracker to reflect the new connection
|
||||
if let Some(ct) = tracker.as_ref() {
|
||||
ct.lock().await.increment();
|
||||
}
|
||||
|
||||
// Split the transport into read and write halves so we can handle input
|
||||
// and output concurrently
|
||||
let (t_read, t_write) = transport.into_split();
|
||||
let (tx, rx) = mpsc::channel(max_msg_capacity);
|
||||
|
||||
// Spawn a new task that loops to handle requests from the client
|
||||
let state_2 = Arc::clone(&state);
|
||||
let req_task = tokio::spawn(async move {
|
||||
request_loop(conn_id, state_2, t_read, tx).await;
|
||||
});
|
||||
|
||||
// Spawn a new task that loops to handle responses to the client
|
||||
let res_task = tokio::spawn(async move { response_loop(conn_id, t_write, rx).await });
|
||||
|
||||
// Spawn cleanup task that waits on our req & res tasks to complete
|
||||
let cleanup_task = tokio::spawn(async move {
|
||||
// Wait for both receiving and sending tasks to complete before marking
|
||||
// the connection as complete
|
||||
let _ = tokio::join!(req_task, res_task);
|
||||
|
||||
state.lock().await.cleanup_connection(conn_id).await;
|
||||
if let Some(ct) = tracker.as_ref() {
|
||||
ct.lock().await.decrement();
|
||||
}
|
||||
});
|
||||
|
||||
Ok(cleanup_task)
|
||||
}
|
||||
|
||||
/// Repeatedly reads in new requests, processes them, and sends their responses to the
|
||||
/// response loop
|
||||
async fn request_loop<T, U>(
|
||||
conn_id: usize,
|
||||
state: Arc<Mutex<State>>,
|
||||
mut transport: TransportReadHalf<T, U>,
|
||||
tx: mpsc::Sender<Response>,
|
||||
) where
|
||||
T: AsyncRead + Send + Unpin + 'static,
|
||||
U: Codec,
|
||||
{
|
||||
loop {
|
||||
match transport.receive::<Request>().await {
|
||||
Ok(Some(req)) => {
|
||||
debug!(
|
||||
"<Conn @ {}> Received request of type{} {}",
|
||||
conn_id,
|
||||
if req.payload.len() > 1 { "s" } else { "" },
|
||||
req.to_payload_type_string()
|
||||
);
|
||||
|
||||
if let Err(x) = handler::process(conn_id, Arc::clone(&state), req, tx.clone()).await
|
||||
{
|
||||
error!("<Conn @ {}> {}", conn_id, x);
|
||||
break;
|
||||
}
|
||||
}
|
||||
Ok(None) => {
|
||||
trace!("<Conn @ {}> Input from connection closed", conn_id);
|
||||
break;
|
||||
}
|
||||
Err(x) => {
|
||||
error!("<Conn @ {}> {}", conn_id, x);
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Properly close off any associated process' stdin given that we can't get new
|
||||
// requests to send more stdin to them
|
||||
state.lock().await.close_stdin_for_connection(conn_id);
|
||||
}
|
||||
|
||||
/// Repeatedly sends responses out over the wire
|
||||
async fn response_loop<T, U>(
|
||||
conn_id: usize,
|
||||
mut transport: TransportWriteHalf<T, U>,
|
||||
mut rx: mpsc::Receiver<Response>,
|
||||
) where
|
||||
T: AsyncWrite + Send + Unpin + 'static,
|
||||
U: Codec,
|
||||
{
|
||||
while let Some(res) = rx.recv().await {
|
||||
debug!(
|
||||
"<Conn @ {}> Sending response of type{} {}",
|
||||
conn_id,
|
||||
if res.payload.len() > 1 { "s" } else { "" },
|
||||
res.to_payload_type_string()
|
||||
);
|
||||
|
||||
if let Err(x) = transport.send(res).await {
|
||||
error!("<Conn @ {}> {}", conn_id, x);
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
trace!("<Conn @ {}> Output to connection closed", conn_id);
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::{
|
||||
data::{RequestData, ResponseData},
|
||||
net::{InmemoryStream, PlainCodec},
|
||||
};
|
||||
use std::pin::Pin;
|
||||
|
||||
#[allow(clippy::type_complexity)]
|
||||
fn make_transport_stream() -> (
|
||||
mpsc::Sender<Transport<InmemoryStream, PlainCodec>>,
|
||||
Pin<Box<dyn Stream<Item = Transport<InmemoryStream, PlainCodec>> + Send>>,
|
||||
) {
|
||||
let (tx, rx) = mpsc::channel::<Transport<InmemoryStream, PlainCodec>>(1);
|
||||
let stream = futures::stream::unfold(rx, |mut rx| async move {
|
||||
rx.recv().await.map(move |transport| (transport, rx))
|
||||
});
|
||||
(tx, Box::pin(stream))
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn wait_should_return_ok_when_all_inner_tasks_complete() {
|
||||
let (tx, stream) = make_transport_stream();
|
||||
|
||||
let server = DistantServer::initialize(stream, Default::default());
|
||||
|
||||
// Conclude all server tasks by closing out the listener
|
||||
drop(tx);
|
||||
|
||||
let result = server.wait().await;
|
||||
assert!(result.is_ok(), "Unexpected result: {:?}", result);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn wait_should_return_error_when_server_aborted() {
|
||||
let (_tx, stream) = make_transport_stream();
|
||||
|
||||
let server = DistantServer::initialize(stream, Default::default());
|
||||
server.abort();
|
||||
|
||||
match server.wait().await {
|
||||
Err(x) if x.is_cancelled() => {}
|
||||
x => panic!("Unexpected result: {:?}", x),
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn server_should_receive_requests_and_send_responses_to_appropriate_connections() {
|
||||
let (tx, stream) = make_transport_stream();
|
||||
|
||||
let _server = DistantServer::initialize(stream, Default::default());
|
||||
|
||||
// Send over a "connection"
|
||||
let (mut t1, t2) = Transport::make_pair();
|
||||
tx.send(t2).await.unwrap();
|
||||
|
||||
// Send a request
|
||||
t1.send(Request::new(
|
||||
"test-tenant",
|
||||
vec![RequestData::SystemInfo {}],
|
||||
))
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
// Get a response
|
||||
let res = t1.receive::<Response>().await.unwrap().unwrap();
|
||||
assert!(res.payload.len() == 1, "Unexpected payload size");
|
||||
assert!(
|
||||
matches!(res.payload[0], ResponseData::SystemInfo { .. }),
|
||||
"Unexpected response: {:?}",
|
||||
res.payload[0]
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn server_should_shutdown_if_no_connections_after_shutdown_duration() {
|
||||
let (_tx, stream) = make_transport_stream();
|
||||
|
||||
let server = DistantServer::initialize(
|
||||
stream,
|
||||
DistantServerOptions {
|
||||
shutdown_after: Some(Duration::from_millis(50)),
|
||||
max_msg_capacity: 1,
|
||||
},
|
||||
);
|
||||
|
||||
let result = server.wait().await;
|
||||
assert!(result.is_ok(), "Unexpected result: {:?}", result);
|
||||
}
|
||||
}
|
@ -1,232 +0,0 @@
|
||||
use super::{InputChannel, ProcessKiller, ProcessPty};
|
||||
use crate::data::{ChangeKindSet, ResponseData};
|
||||
use log::*;
|
||||
use notify::RecommendedWatcher;
|
||||
use std::{
|
||||
collections::HashMap,
|
||||
future::Future,
|
||||
hash::{Hash, Hasher},
|
||||
io,
|
||||
ops::{Deref, DerefMut},
|
||||
path::{Path, PathBuf},
|
||||
pin::Pin,
|
||||
};
|
||||
|
||||
pub type ReplyFn = Box<dyn FnMut(Vec<ResponseData>) -> ReplyRet + Send + 'static>;
|
||||
pub type ReplyRet = Pin<Box<dyn Future<Output = bool> + Send + 'static>>;
|
||||
|
||||
/// Holds state related to multiple connections managed by a server
|
||||
#[derive(Default)]
|
||||
pub struct State {
|
||||
/// Map of all processes running on the server
|
||||
pub processes: HashMap<usize, ProcessState>,
|
||||
|
||||
/// List of processes that will be killed when a connection drops
|
||||
client_processes: HashMap<usize, Vec<usize>>,
|
||||
|
||||
/// Watcher used for filesystem events
|
||||
pub watcher: Option<RecommendedWatcher>,
|
||||
|
||||
/// Mapping of Path -> (Reply Fn, recursive) for watcher notifications
|
||||
pub watcher_paths: HashMap<WatcherPath, ReplyFn>,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct WatcherPath {
|
||||
/// The raw path provided to the watcher, which is not canonicalized
|
||||
raw_path: PathBuf,
|
||||
|
||||
/// The canonicalized path at the time of providing to the watcher,
|
||||
/// as all paths must exist for a watcher, we use this to get the
|
||||
/// source of truth when watching
|
||||
path: PathBuf,
|
||||
|
||||
/// Whether or not the path was set to be recursive
|
||||
recursive: bool,
|
||||
|
||||
/// Specific filter for path
|
||||
only: ChangeKindSet,
|
||||
}
|
||||
|
||||
impl PartialEq for WatcherPath {
|
||||
fn eq(&self, other: &Self) -> bool {
|
||||
self.path == other.path
|
||||
}
|
||||
}
|
||||
|
||||
impl Eq for WatcherPath {}
|
||||
|
||||
impl Hash for WatcherPath {
|
||||
fn hash<H: Hasher>(&self, state: &mut H) {
|
||||
self.path.hash(state);
|
||||
}
|
||||
}
|
||||
|
||||
impl Deref for WatcherPath {
|
||||
type Target = PathBuf;
|
||||
|
||||
fn deref(&self) -> &Self::Target {
|
||||
&self.path
|
||||
}
|
||||
}
|
||||
|
||||
impl DerefMut for WatcherPath {
|
||||
fn deref_mut(&mut self) -> &mut Self::Target {
|
||||
&mut self.path
|
||||
}
|
||||
}
|
||||
|
||||
impl WatcherPath {
|
||||
/// Create a new watcher path using the given path and canonicalizing it
|
||||
pub fn new(
|
||||
path: impl Into<PathBuf>,
|
||||
recursive: bool,
|
||||
only: impl Into<ChangeKindSet>,
|
||||
) -> io::Result<Self> {
|
||||
let raw_path = path.into();
|
||||
let path = raw_path.canonicalize()?;
|
||||
let only = only.into();
|
||||
Ok(Self {
|
||||
raw_path,
|
||||
path,
|
||||
recursive,
|
||||
only,
|
||||
})
|
||||
}
|
||||
|
||||
pub fn raw_path(&self) -> &Path {
|
||||
self.raw_path.as_path()
|
||||
}
|
||||
|
||||
pub fn path(&self) -> &Path {
|
||||
self.path.as_path()
|
||||
}
|
||||
|
||||
/// Returns true if this watcher path applies to the given path.
|
||||
/// This is accomplished by checking if the path is contained
|
||||
/// within either the raw or canonicalized path of the watcher
|
||||
/// and ensures that recursion rules are respected
|
||||
pub fn applies_to_path(&self, path: &Path) -> bool {
|
||||
let check_path = |path: &Path| -> bool {
|
||||
let cnt = path.components().count();
|
||||
|
||||
// 0 means exact match from strip_prefix
|
||||
// 1 means that it was within immediate directory (fine for non-recursive)
|
||||
// 2+ means it needs to be recursive
|
||||
cnt < 2 || self.recursive
|
||||
};
|
||||
|
||||
match (
|
||||
path.strip_prefix(self.path()),
|
||||
path.strip_prefix(self.raw_path()),
|
||||
) {
|
||||
(Ok(p1), Ok(p2)) => check_path(p1) || check_path(p2),
|
||||
(Ok(p), Err(_)) => check_path(p),
|
||||
(Err(_), Ok(p)) => check_path(p),
|
||||
(Err(_), Err(_)) => false,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Holds information related to a spawned process on the server
|
||||
pub struct ProcessState {
|
||||
pub cmd: String,
|
||||
pub args: Vec<String>,
|
||||
pub persist: bool,
|
||||
|
||||
pub id: usize,
|
||||
pub stdin: Option<Box<dyn InputChannel>>,
|
||||
pub killer: Box<dyn ProcessKiller>,
|
||||
pub pty: Box<dyn ProcessPty>,
|
||||
}
|
||||
|
||||
impl State {
|
||||
pub fn map_paths_to_watcher_paths_and_replies<'a>(
|
||||
&mut self,
|
||||
paths: &'a [PathBuf],
|
||||
) -> Vec<(Vec<&'a PathBuf>, &ChangeKindSet, &mut ReplyFn)> {
|
||||
let mut results = Vec::new();
|
||||
|
||||
for (wp, reply) in self.watcher_paths.iter_mut() {
|
||||
let mut wp_paths = Vec::new();
|
||||
for path in paths {
|
||||
if wp.applies_to_path(path) {
|
||||
wp_paths.push(path);
|
||||
}
|
||||
}
|
||||
if !wp_paths.is_empty() {
|
||||
results.push((wp_paths, &wp.only, reply));
|
||||
}
|
||||
}
|
||||
|
||||
results
|
||||
}
|
||||
|
||||
/// Pushes a new process associated with a connection
|
||||
pub fn push_process_state(&mut self, conn_id: usize, process_state: ProcessState) {
|
||||
self.client_processes
|
||||
.entry(conn_id)
|
||||
.or_insert_with(Vec::new)
|
||||
.push(process_state.id);
|
||||
self.processes.insert(process_state.id, process_state);
|
||||
}
|
||||
|
||||
/// Removes a process associated with a connection
|
||||
pub fn remove_process(&mut self, conn_id: usize, proc_id: usize) {
|
||||
self.client_processes.entry(conn_id).and_modify(|v| {
|
||||
if let Some(pos) = v.iter().position(|x| *x == proc_id) {
|
||||
v.remove(pos);
|
||||
}
|
||||
});
|
||||
self.processes.remove(&proc_id);
|
||||
}
|
||||
|
||||
/// Closes stdin for all processes associated with the connection
|
||||
pub fn close_stdin_for_connection(&mut self, conn_id: usize) {
|
||||
debug!("<Conn @ {:?}> Closing stdin to all processes", conn_id);
|
||||
if let Some(ids) = self.client_processes.get(&conn_id) {
|
||||
for id in ids {
|
||||
if let Some(process) = self.processes.get_mut(id) {
|
||||
trace!(
|
||||
"<Conn @ {:?}> Closing stdin for proc {}",
|
||||
conn_id,
|
||||
process.id
|
||||
);
|
||||
|
||||
let _ = process.stdin.take();
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Cleans up state associated with a particular connection
|
||||
pub async fn cleanup_connection(&mut self, conn_id: usize) {
|
||||
debug!("<Conn @ {:?}> Cleaning up state", conn_id);
|
||||
if let Some(ids) = self.client_processes.remove(&conn_id) {
|
||||
for id in ids {
|
||||
if let Some(mut process) = self.processes.remove(&id) {
|
||||
if !process.persist {
|
||||
trace!(
|
||||
"<Conn @ {:?}> Requesting proc {} be killed",
|
||||
conn_id,
|
||||
process.id
|
||||
);
|
||||
let pid = process.id;
|
||||
if let Err(x) = process.killer.kill().await {
|
||||
error!(
|
||||
"Conn {} failed to send process {} kill signal: {}",
|
||||
id, pid, x
|
||||
);
|
||||
}
|
||||
} else {
|
||||
trace!(
|
||||
"<Conn @ {:?}> Proc {} is persistent and will not be killed",
|
||||
conn_id,
|
||||
process.id
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
@ -1,8 +0,0 @@
|
||||
mod distant;
|
||||
mod port;
|
||||
mod relay;
|
||||
mod utils;
|
||||
|
||||
pub use self::distant::{DistantServer, DistantServerOptions};
|
||||
pub use port::PortRange;
|
||||
pub use relay::RelayServer;
|
@ -1,382 +0,0 @@
|
||||
use crate::{
|
||||
client::{Session, SessionChannel},
|
||||
data::{Request, RequestData, ResponseData},
|
||||
net::{Codec, DataStream, Transport},
|
||||
server::utils::{ConnTracker, ShutdownTask},
|
||||
};
|
||||
use futures::stream::{Stream, StreamExt};
|
||||
use log::*;
|
||||
use std::{collections::HashMap, marker::Unpin, sync::Arc};
|
||||
use tokio::{
|
||||
io,
|
||||
sync::{oneshot, Mutex},
|
||||
task::{JoinError, JoinHandle},
|
||||
time::Duration,
|
||||
};
|
||||
|
||||
/// Represents a server that relays requests & responses between connections and the
|
||||
/// actual server
|
||||
pub struct RelayServer {
|
||||
accept_task: JoinHandle<()>,
|
||||
conns: Arc<Mutex<HashMap<usize, Conn>>>,
|
||||
}
|
||||
|
||||
impl RelayServer {
|
||||
pub fn initialize<T, U, S>(
|
||||
session: Session,
|
||||
mut stream: S,
|
||||
shutdown_after: Option<Duration>,
|
||||
) -> io::Result<Self>
|
||||
where
|
||||
T: DataStream + Send + 'static,
|
||||
U: Codec + Send + 'static,
|
||||
S: Stream<Item = Transport<T, U>> + Send + Unpin + 'static,
|
||||
{
|
||||
let conns: Arc<Mutex<HashMap<usize, Conn>>> = Arc::new(Mutex::new(HashMap::new()));
|
||||
|
||||
let (shutdown, tracker) = ShutdownTask::maybe_initialize(shutdown_after);
|
||||
let conns_2 = Arc::clone(&conns);
|
||||
let accept_task = tokio::spawn(async move {
|
||||
let inner = async move {
|
||||
loop {
|
||||
let channel = session.clone_channel();
|
||||
match stream.next().await {
|
||||
Some(transport) => {
|
||||
let result = Conn::initialize(
|
||||
transport,
|
||||
channel,
|
||||
tracker.as_ref().map(Arc::clone),
|
||||
)
|
||||
.await;
|
||||
|
||||
match result {
|
||||
Ok(conn) => {
|
||||
conns_2.lock().await.insert(conn.id(), conn);
|
||||
}
|
||||
Err(x) => {
|
||||
error!("Failed to initialize connection: {}", x);
|
||||
}
|
||||
};
|
||||
}
|
||||
None => {
|
||||
info!("Listener shutting down");
|
||||
break;
|
||||
}
|
||||
};
|
||||
}
|
||||
};
|
||||
|
||||
match shutdown {
|
||||
Some(shutdown) => tokio::select! {
|
||||
_ = inner => {}
|
||||
_ = shutdown => {
|
||||
warn!("Reached shutdown timeout, so terminating");
|
||||
}
|
||||
},
|
||||
None => inner.await,
|
||||
}
|
||||
});
|
||||
|
||||
Ok(Self { accept_task, conns })
|
||||
}
|
||||
|
||||
/// Waits for the server to terminate
|
||||
pub async fn wait(self) -> Result<(), JoinError> {
|
||||
self.accept_task.await
|
||||
}
|
||||
|
||||
/// Aborts the server by aborting the internal tasks and current connections
|
||||
pub async fn abort(&self) {
|
||||
self.accept_task.abort();
|
||||
self.conns
|
||||
.lock()
|
||||
.await
|
||||
.values()
|
||||
.for_each(|conn| conn.abort());
|
||||
}
|
||||
}
|
||||
|
||||
struct Conn {
|
||||
id: usize,
|
||||
conn_task: JoinHandle<()>,
|
||||
}
|
||||
|
||||
impl Conn {
|
||||
pub async fn initialize<T, U>(
|
||||
transport: Transport<T, U>,
|
||||
channel: SessionChannel,
|
||||
ct: Option<Arc<Mutex<ConnTracker>>>,
|
||||
) -> io::Result<Self>
|
||||
where
|
||||
T: DataStream + 'static,
|
||||
U: Codec + Send + 'static,
|
||||
{
|
||||
// Create a unique id to associate with the connection since its address
|
||||
// is not guaranteed to have an identifiable string
|
||||
let id: usize = rand::random();
|
||||
|
||||
// Mark that we have a new connection
|
||||
if let Some(ct) = ct.as_ref() {
|
||||
ct.lock().await.increment();
|
||||
}
|
||||
|
||||
let conn_task = spawn_conn_handler(id, transport, channel, ct).await;
|
||||
|
||||
Ok(Self { id, conn_task })
|
||||
}
|
||||
|
||||
/// Id associated with the connection
|
||||
pub fn id(&self) -> usize {
|
||||
self.id
|
||||
}
|
||||
|
||||
/// Aborts the connection from the server side
|
||||
pub fn abort(&self) {
|
||||
self.conn_task.abort();
|
||||
}
|
||||
}
|
||||
|
||||
async fn spawn_conn_handler<T, U>(
|
||||
conn_id: usize,
|
||||
transport: Transport<T, U>,
|
||||
mut channel: SessionChannel,
|
||||
ct: Option<Arc<Mutex<ConnTracker>>>,
|
||||
) -> JoinHandle<()>
|
||||
where
|
||||
T: DataStream,
|
||||
U: Codec + Send + 'static,
|
||||
{
|
||||
let (mut t_reader, t_writer) = transport.into_split();
|
||||
let processes = Arc::new(Mutex::new(Vec::new()));
|
||||
let t_writer = Arc::new(Mutex::new(t_writer));
|
||||
|
||||
let (done_tx, done_rx) = oneshot::channel();
|
||||
let mut channel_2 = channel.clone();
|
||||
let processes_2 = Arc::clone(&processes);
|
||||
let task = tokio::spawn(async move {
|
||||
loop {
|
||||
if channel_2.is_closed() {
|
||||
break;
|
||||
}
|
||||
|
||||
// For each request, forward it through the session and monitor all responses
|
||||
match t_reader.receive::<Request>().await {
|
||||
Ok(Some(req)) => match channel_2.mail(req).await {
|
||||
Ok(mut mailbox) => {
|
||||
let processes = Arc::clone(&processes_2);
|
||||
let t_writer = Arc::clone(&t_writer);
|
||||
tokio::spawn(async move {
|
||||
while let Some(res) = mailbox.next().await {
|
||||
// Keep track of processes that are started so we can kill them
|
||||
// when we're done
|
||||
{
|
||||
let mut p_lock = processes.lock().await;
|
||||
for data in res.payload.iter() {
|
||||
if let ResponseData::ProcSpawned { id } = *data {
|
||||
p_lock.push(id);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if let Err(x) = t_writer.lock().await.send(res).await {
|
||||
error!(
|
||||
"<Conn @ {}> Failed to send response back: {}",
|
||||
conn_id, x
|
||||
);
|
||||
}
|
||||
}
|
||||
});
|
||||
}
|
||||
Err(x) => error!(
|
||||
"<Conn @ {}> Failed to pass along request received on unix socket: {:?}",
|
||||
conn_id, x
|
||||
),
|
||||
},
|
||||
Ok(None) => break,
|
||||
Err(x) => {
|
||||
error!(
|
||||
"<Conn @ {}> Failed to receive request from unix stream: {:?}",
|
||||
conn_id, x
|
||||
);
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let _ = done_tx.send(());
|
||||
});
|
||||
|
||||
// Perform cleanup if done by sending a request to kill each running process
|
||||
tokio::spawn(async move {
|
||||
let _ = done_rx.await;
|
||||
|
||||
let p_lock = processes.lock().await;
|
||||
if !p_lock.is_empty() {
|
||||
trace!(
|
||||
"Cleaning conn {} :: killing {} process",
|
||||
conn_id,
|
||||
p_lock.len()
|
||||
);
|
||||
if let Err(x) = channel
|
||||
.fire(Request::new(
|
||||
"relay",
|
||||
p_lock
|
||||
.iter()
|
||||
.map(|id| RequestData::ProcKill { id: *id })
|
||||
.collect(),
|
||||
))
|
||||
.await
|
||||
{
|
||||
error!("<Conn @ {}> Failed to send kill signals: {}", conn_id, x);
|
||||
}
|
||||
}
|
||||
|
||||
if let Some(ct) = ct.as_ref() {
|
||||
ct.lock().await.decrement();
|
||||
}
|
||||
debug!("<Conn @ {}> Disconnected", conn_id);
|
||||
});
|
||||
|
||||
task
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::{
|
||||
data::Response,
|
||||
net::{InmemoryStream, PlainCodec},
|
||||
};
|
||||
use std::{pin::Pin, time::Duration};
|
||||
use tokio::sync::mpsc;
|
||||
|
||||
fn make_session() -> (Transport<InmemoryStream, PlainCodec>, Session) {
|
||||
let (t1, t2) = Transport::make_pair();
|
||||
(t1, Session::initialize(t2).unwrap())
|
||||
}
|
||||
|
||||
#[allow(clippy::type_complexity)]
|
||||
fn make_transport_stream() -> (
|
||||
mpsc::Sender<Transport<InmemoryStream, PlainCodec>>,
|
||||
Pin<Box<dyn Stream<Item = Transport<InmemoryStream, PlainCodec>> + Send>>,
|
||||
) {
|
||||
let (tx, rx) = mpsc::channel::<Transport<InmemoryStream, PlainCodec>>(1);
|
||||
let stream = futures::stream::unfold(rx, |mut rx| async move {
|
||||
rx.recv().await.map(move |transport| (transport, rx))
|
||||
});
|
||||
(tx, Box::pin(stream))
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn wait_should_return_ok_when_all_inner_tasks_complete() {
|
||||
let (transport, session) = make_session();
|
||||
let (tx, stream) = make_transport_stream();
|
||||
let server = RelayServer::initialize(session, stream, None).unwrap();
|
||||
|
||||
// Conclude all server tasks by closing out the listener & session
|
||||
drop(transport);
|
||||
drop(tx);
|
||||
|
||||
let result = server.wait().await;
|
||||
assert!(result.is_ok(), "Unexpected result: {:?}", result);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn wait_should_return_error_when_server_aborted() {
|
||||
let (_transport, session) = make_session();
|
||||
let (_tx, stream) = make_transport_stream();
|
||||
let server = RelayServer::initialize(session, stream, None).unwrap();
|
||||
server.abort().await;
|
||||
|
||||
match server.wait().await {
|
||||
Err(x) if x.is_cancelled() => {}
|
||||
x => panic!("Unexpected result: {:?}", x),
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn server_should_forward_requests_using_session() {
|
||||
let (mut transport, session) = make_session();
|
||||
let (tx, stream) = make_transport_stream();
|
||||
let _server = RelayServer::initialize(session, stream, None).unwrap();
|
||||
|
||||
// Send over a "connection"
|
||||
let (mut t1, t2) = Transport::make_pair();
|
||||
tx.send(t2).await.unwrap();
|
||||
|
||||
// Send a request
|
||||
let req = Request::new("test-tenant", vec![RequestData::SystemInfo {}]);
|
||||
t1.send(req.clone()).await.unwrap();
|
||||
|
||||
// Verify the request is forwarded out via session
|
||||
let outbound_req = transport.receive().await.unwrap().unwrap();
|
||||
assert_eq!(req, outbound_req);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn server_should_send_back_response_with_tenant_matching_connection() {
|
||||
let (mut transport, session) = make_session();
|
||||
let (tx, stream) = make_transport_stream();
|
||||
let _server = RelayServer::initialize(session, stream, None).unwrap();
|
||||
|
||||
// Send over a "connection"
|
||||
let (mut t1, t2) = Transport::make_pair();
|
||||
tx.send(t2).await.unwrap();
|
||||
|
||||
// Send over a second "connection"
|
||||
let (mut t2, t3) = Transport::make_pair();
|
||||
tx.send(t3).await.unwrap();
|
||||
|
||||
// Send a request to mark the tenant of the first connection
|
||||
t1.send(Request::new(
|
||||
"test-tenant-1",
|
||||
vec![RequestData::SystemInfo {}],
|
||||
))
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
// Send a request to mark the tenant of the second connection
|
||||
t2.send(Request::new(
|
||||
"test-tenant-2",
|
||||
vec![RequestData::SystemInfo {}],
|
||||
))
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
// Clear out the transport channel (outbound of session)
|
||||
// NOTE: Because our test stream uses a buffer size of 1, we have to clear out the
|
||||
// outbound data from the earlier requests before we can send back a response
|
||||
let req_1 = transport.receive::<Request>().await.unwrap().unwrap();
|
||||
let req_2 = transport.receive::<Request>().await.unwrap().unwrap();
|
||||
let origin_id = if req_1.tenant == "test-tenant-2" {
|
||||
req_1.id
|
||||
} else {
|
||||
req_2.id
|
||||
};
|
||||
|
||||
// Send a response back to a singular connection based on the tenant
|
||||
let res = Response::new("test-tenant-2", origin_id, vec![ResponseData::Ok]);
|
||||
transport.send(res.clone()).await.unwrap();
|
||||
|
||||
// Verify that response is only received by a singular connection
|
||||
let inbound_res = t2.receive().await.unwrap().unwrap();
|
||||
assert_eq!(res, inbound_res);
|
||||
|
||||
let no_inbound = tokio::select! {
|
||||
_ = t1.receive::<Response>() => {false}
|
||||
_ = tokio::time::sleep(Duration::from_millis(50)) => {true}
|
||||
};
|
||||
assert!(no_inbound, "Unexpectedly got response for wrong connection");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn server_should_shutdown_if_no_connections_after_shutdown_duration() {
|
||||
let (_transport, session) = make_session();
|
||||
let (_tx, stream) = make_transport_stream();
|
||||
let server =
|
||||
RelayServer::initialize(session, stream, Some(Duration::from_millis(50))).unwrap();
|
||||
|
||||
let result = server.wait().await;
|
||||
assert!(result.is_ok(), "Unexpected result: {:?}", result);
|
||||
}
|
||||
}
|
@ -1,289 +0,0 @@
|
||||
use log::*;
|
||||
use std::{
|
||||
future::Future,
|
||||
pin::Pin,
|
||||
sync::Arc,
|
||||
task::{Context, Poll},
|
||||
time::Duration,
|
||||
};
|
||||
use tokio::{
|
||||
sync::Mutex,
|
||||
task::{JoinError, JoinHandle},
|
||||
time::{self, Instant},
|
||||
};
|
||||
|
||||
/// Task to keep track of a possible server shutdown based on connections
|
||||
pub struct ShutdownTask {
|
||||
task: JoinHandle<()>,
|
||||
tracker: Arc<Mutex<ConnTracker>>,
|
||||
}
|
||||
|
||||
impl Future for ShutdownTask {
|
||||
type Output = Result<(), JoinError>;
|
||||
|
||||
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
|
||||
Pin::new(&mut self.task).poll(cx)
|
||||
}
|
||||
}
|
||||
|
||||
impl ShutdownTask {
|
||||
/// Given an optional timeout, will either create the shutdown task or not,
|
||||
/// returning an optional shutdown task alongside an optional connection tracker
|
||||
pub fn maybe_initialize(
|
||||
duration: Option<Duration>,
|
||||
) -> (Option<ShutdownTask>, Option<Arc<Mutex<ConnTracker>>>) {
|
||||
match duration {
|
||||
Some(duration) => {
|
||||
let task = Self::initialize(duration);
|
||||
let tracker = task.tracker();
|
||||
(Some(task), Some(tracker))
|
||||
}
|
||||
None => (None, None),
|
||||
}
|
||||
}
|
||||
|
||||
/// Spawns a new task that continues to monitor the time since a
|
||||
/// connection on the server existed, reporting a shutdown to all listeners
|
||||
/// once the timeout is exceeded
|
||||
pub fn initialize(duration: Duration) -> Self {
|
||||
let tracker = Arc::new(Mutex::new(ConnTracker::new()));
|
||||
|
||||
let tracker_2 = Arc::clone(&tracker);
|
||||
let task = tokio::spawn(async move {
|
||||
loop {
|
||||
// Get the time since the last connection joined/left
|
||||
let (base_time, cnt) = tracker_2.lock().await.time_and_cnt();
|
||||
|
||||
// If we have no connections left, we want to wait
|
||||
// until the remaining period has passed and then
|
||||
// verify that we still have no connections
|
||||
if cnt == 0 {
|
||||
// Get the time we should wait based on when the last connection
|
||||
// was dropped; this closes the gap in the case where we start
|
||||
// sometime later than exactly duration since the last check
|
||||
let next_time = base_time + duration;
|
||||
let wait_duration = next_time
|
||||
.checked_duration_since(Instant::now())
|
||||
.unwrap_or_default()
|
||||
+ Duration::from_millis(1);
|
||||
|
||||
// Wait until we've reached our desired duration since the
|
||||
// last connection was dropped
|
||||
time::sleep(wait_duration).await;
|
||||
|
||||
// If we do have a connection at this point, don't exit
|
||||
if !tracker_2.lock().await.has_reached_timeout(duration) {
|
||||
continue;
|
||||
}
|
||||
|
||||
// Otherwise, we now should exit, which we do by reporting
|
||||
debug!(
|
||||
"Shutdown time of {}s has been reached!",
|
||||
duration.as_secs_f32()
|
||||
);
|
||||
break;
|
||||
}
|
||||
|
||||
// Otherwise, we just wait the full duration as worst case
|
||||
// we'll have waited just about the time desired if right
|
||||
// after waiting starts the last connection is closed
|
||||
time::sleep(duration).await;
|
||||
}
|
||||
});
|
||||
|
||||
Self { task, tracker }
|
||||
}
|
||||
|
||||
/// Produces a new copy of the connection tracker associated with the shutdown manager
|
||||
pub fn tracker(&self) -> Arc<Mutex<ConnTracker>> {
|
||||
Arc::clone(&self.tracker)
|
||||
}
|
||||
}
|
||||
|
||||
pub struct ConnTracker {
|
||||
time: Instant,
|
||||
cnt: usize,
|
||||
}
|
||||
|
||||
impl ConnTracker {
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
time: Instant::now(),
|
||||
cnt: 0,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn increment(&mut self) {
|
||||
self.time = Instant::now();
|
||||
self.cnt += 1;
|
||||
}
|
||||
|
||||
pub fn decrement(&mut self) {
|
||||
if self.cnt > 0 {
|
||||
self.time = Instant::now();
|
||||
self.cnt -= 1;
|
||||
}
|
||||
}
|
||||
|
||||
fn time_and_cnt(&self) -> (Instant, usize) {
|
||||
(self.time, self.cnt)
|
||||
}
|
||||
|
||||
fn has_reached_timeout(&self, duration: Duration) -> bool {
|
||||
self.cnt == 0 && self.time.elapsed() >= duration
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use std::thread;
|
||||
|
||||
#[tokio::test]
|
||||
async fn shutdown_task_should_not_resolve_if_has_connection_regardless_of_time() {
|
||||
let mut task = ShutdownTask::initialize(Duration::from_millis(10));
|
||||
task.tracker().lock().await.increment();
|
||||
assert!(
|
||||
futures::poll!(&mut task).is_pending(),
|
||||
"Shutdown task unexpectedly completed"
|
||||
);
|
||||
|
||||
time::sleep(Duration::from_millis(50)).await;
|
||||
|
||||
assert!(
|
||||
futures::poll!(task).is_pending(),
|
||||
"Shutdown task unexpectedly completed"
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn shutdown_task_should_resolve_if_no_connection_for_minimum_duration() {
|
||||
let mut task = ShutdownTask::initialize(Duration::from_millis(10));
|
||||
assert!(
|
||||
futures::poll!(&mut task).is_pending(),
|
||||
"Shutdown task unexpectedly completed"
|
||||
);
|
||||
|
||||
tokio::select! {
|
||||
_ = task => {}
|
||||
_ = time::sleep(Duration::from_secs(1)) => {
|
||||
panic!("Shutdown task unexpectedly pending");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn shutdown_task_should_resolve_if_no_connection_for_minimum_duration_after_connection_removed(
|
||||
) {
|
||||
let mut task = ShutdownTask::initialize(Duration::from_millis(10));
|
||||
task.tracker().lock().await.increment();
|
||||
assert!(
|
||||
futures::poll!(&mut task).is_pending(),
|
||||
"Shutdown task unexpectedly completed"
|
||||
);
|
||||
|
||||
time::sleep(Duration::from_millis(50)).await;
|
||||
assert!(
|
||||
futures::poll!(&mut task).is_pending(),
|
||||
"Shutdown task unexpectedly completed"
|
||||
);
|
||||
|
||||
task.tracker().lock().await.decrement();
|
||||
|
||||
tokio::select! {
|
||||
_ = task => {}
|
||||
_ = time::sleep(Duration::from_secs(1)) => {
|
||||
panic!("Shutdown task unexpectedly pending");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn shutdown_task_should_not_resolve_before_minimum_duration() {
|
||||
let mut task = ShutdownTask::initialize(Duration::from_millis(50));
|
||||
assert!(
|
||||
futures::poll!(&mut task).is_pending(),
|
||||
"Shutdown task unexpectedly completed"
|
||||
);
|
||||
|
||||
time::sleep(Duration::from_millis(5)).await;
|
||||
|
||||
assert!(
|
||||
futures::poll!(task).is_pending(),
|
||||
"Shutdown task unexpectedly completed"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn conn_tracker_should_update_time_when_incremented() {
|
||||
let mut tracker = ConnTracker::new();
|
||||
let (old_time, cnt) = tracker.time_and_cnt();
|
||||
assert_eq!(cnt, 0);
|
||||
|
||||
// Wait to ensure that the new time will be different
|
||||
thread::sleep(Duration::from_millis(1));
|
||||
|
||||
tracker.increment();
|
||||
let (new_time, cnt) = tracker.time_and_cnt();
|
||||
assert_eq!(cnt, 1);
|
||||
assert!(new_time > old_time);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn conn_tracker_should_update_time_when_decremented() {
|
||||
let mut tracker = ConnTracker::new();
|
||||
tracker.increment();
|
||||
|
||||
let (old_time, cnt) = tracker.time_and_cnt();
|
||||
assert_eq!(cnt, 1);
|
||||
|
||||
// Wait to ensure that the new time will be different
|
||||
thread::sleep(Duration::from_millis(1));
|
||||
|
||||
tracker.decrement();
|
||||
let (new_time, cnt) = tracker.time_and_cnt();
|
||||
assert_eq!(cnt, 0);
|
||||
assert!(new_time > old_time);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn conn_tracker_should_not_update_time_when_decremented_if_at_zero_already() {
|
||||
let mut tracker = ConnTracker::new();
|
||||
let (old_time, cnt) = tracker.time_and_cnt();
|
||||
assert_eq!(cnt, 0);
|
||||
|
||||
// Wait to ensure that the new time would be different if updated
|
||||
thread::sleep(Duration::from_millis(1));
|
||||
|
||||
tracker.decrement();
|
||||
let (new_time, cnt) = tracker.time_and_cnt();
|
||||
assert_eq!(cnt, 0);
|
||||
assert!(new_time == old_time);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn conn_tracker_should_report_timeout_reached_when_time_has_elapsed_and_no_connections() {
|
||||
let tracker = ConnTracker::new();
|
||||
let (_, cnt) = tracker.time_and_cnt();
|
||||
assert_eq!(cnt, 0);
|
||||
|
||||
// Wait to ensure that the new time would be different if updated
|
||||
thread::sleep(Duration::from_millis(1));
|
||||
|
||||
assert!(tracker.has_reached_timeout(Duration::from_millis(1)));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn conn_tracker_should_not_report_timeout_reached_when_time_has_elapsed_but_has_connections() {
|
||||
let mut tracker = ConnTracker::new();
|
||||
tracker.increment();
|
||||
|
||||
let (_, cnt) = tracker.time_and_cnt();
|
||||
assert_eq!(cnt, 1);
|
||||
|
||||
// Wait to ensure that the new time would be different if updated
|
||||
thread::sleep(Duration::from_millis(1));
|
||||
|
||||
assert!(!tracker.has_reached_timeout(Duration::from_millis(1)));
|
||||
}
|
||||
}
|
@ -0,0 +1,96 @@
|
||||
use distant_core::{
|
||||
net::{FramedTransport, InmemoryTransport, IntoSplit, OneshotListener, PlainCodec},
|
||||
BoxedDistantReader, BoxedDistantWriter, Destination, DistantApiServer, DistantChannelExt,
|
||||
DistantManager, DistantManagerClient, DistantManagerClientConfig, DistantManagerConfig, Extra,
|
||||
};
|
||||
use std::io;
|
||||
|
||||
/// Creates a client transport and server listener for our tests
|
||||
/// that are connected together
|
||||
async fn setup() -> (
|
||||
FramedTransport<InmemoryTransport, PlainCodec>,
|
||||
OneshotListener<FramedTransport<InmemoryTransport, PlainCodec>>,
|
||||
) {
|
||||
let (t1, t2) = InmemoryTransport::pair(100);
|
||||
|
||||
let listener = OneshotListener::from_value(FramedTransport::new(t2, PlainCodec));
|
||||
let transport = FramedTransport::new(t1, PlainCodec);
|
||||
(transport, listener)
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn should_be_able_to_establish_a_single_connection_and_communicate() {
|
||||
let (transport, listener) = setup().await;
|
||||
|
||||
let config = DistantManagerConfig::default();
|
||||
let manager_ref = DistantManager::start(config, listener).expect("Failed to start manager");
|
||||
|
||||
// NOTE: To pass in a raw function, we HAVE to specify the types of the parameters manually,
|
||||
// otherwise we get a compilation error about lifetime mismatches
|
||||
manager_ref
|
||||
.register_connect_handler("scheme", |_: &_, _: &_, _: &mut _| async {
|
||||
use distant_core::net::ServerExt;
|
||||
let (t1, t2) = FramedTransport::pair(100);
|
||||
|
||||
// Spawn a server on one end
|
||||
let _ = DistantApiServer::local()
|
||||
.unwrap()
|
||||
.start(OneshotListener::from_value(t2.into_split()))?;
|
||||
|
||||
// Create a reader/writer pair on the other end
|
||||
let (writer, reader) = t1.into_split();
|
||||
let writer: BoxedDistantWriter = Box::new(writer);
|
||||
let reader: BoxedDistantReader = Box::new(reader);
|
||||
Ok((writer, reader))
|
||||
})
|
||||
.await
|
||||
.expect("Failed to register handler");
|
||||
|
||||
let config = DistantManagerClientConfig::with_empty_prompts();
|
||||
let mut client =
|
||||
DistantManagerClient::new(config, transport).expect("Failed to connect to manager");
|
||||
|
||||
// Test establishing a connection to some remote server
|
||||
let id = client
|
||||
.connect(
|
||||
"scheme://host".parse::<Destination>().unwrap(),
|
||||
"key=value".parse::<Extra>().unwrap(),
|
||||
)
|
||||
.await
|
||||
.expect("Failed to connect to a remote server");
|
||||
|
||||
// Test retrieving list of connections
|
||||
let list = client
|
||||
.list()
|
||||
.await
|
||||
.expect("Failed to get list of connections");
|
||||
assert_eq!(list.len(), 1);
|
||||
assert_eq!(list.get(&id).unwrap().to_string(), "scheme://host/");
|
||||
|
||||
// Test retrieving information
|
||||
let info = client
|
||||
.info(id)
|
||||
.await
|
||||
.expect("Failed to get info about connection");
|
||||
assert_eq!(info.id, id);
|
||||
assert_eq!(info.destination.to_string(), "scheme://host/");
|
||||
assert_eq!(info.extra, "key=value".parse::<Extra>().unwrap());
|
||||
|
||||
// Create a new channel and request some data
|
||||
let mut channel = client
|
||||
.open_channel(id)
|
||||
.await
|
||||
.expect("Failed to open channel");
|
||||
let _ = channel
|
||||
.system_info()
|
||||
.await
|
||||
.expect("Failed to get system information");
|
||||
|
||||
// Test killing a connection
|
||||
client.kill(id).await.expect("Failed to kill connection");
|
||||
|
||||
// Test getting an error to ensure that serialization of that data works,
|
||||
// which we do by trying to access a connection that no longer exists
|
||||
let err = client.info(id).await.unwrap_err();
|
||||
assert_eq!(err.kind(), io::ErrorKind::NotConnected);
|
||||
}
|
@ -0,0 +1,37 @@
|
||||
[package]
|
||||
name = "distant-net"
|
||||
description = "Network library for distant, providing implementations to support client/server architecture"
|
||||
categories = ["network-programming"]
|
||||
keywords = ["api", "async"]
|
||||
version = "0.17.0"
|
||||
authors = ["Chip Senkbeil <chip@senkbeil.org>"]
|
||||
edition = "2021"
|
||||
homepage = "https://github.com/chipsenkbeil/distant"
|
||||
repository = "https://github.com/chipsenkbeil/distant"
|
||||
readme = "README.md"
|
||||
license = "MIT OR Apache-2.0"
|
||||
|
||||
[dependencies]
|
||||
async-trait = "0.1.56"
|
||||
bytes = "1.1.0"
|
||||
chacha20poly1305 = "=0.10.0-pre"
|
||||
derive_more = { version = "0.99.17", default-features = false, features = ["as_mut", "as_ref", "deref", "deref_mut", "display", "from", "error", "into", "into_iterator", "is_variant", "try_into"] }
|
||||
futures = "0.3.21"
|
||||
hex = "0.4.3"
|
||||
hkdf = "0.12.3"
|
||||
log = "0.4.17"
|
||||
paste = "1.0.7"
|
||||
p256 = { version = "0.11.1", features = ["ecdh", "pem"] }
|
||||
rand = { version = "0.8.4", features = ["getrandom"] }
|
||||
rmp-serde = "1.1.0"
|
||||
sha2 = "0.10.2"
|
||||
serde = { version = "1.0.126", features = ["derive"] }
|
||||
serde_bytes = "0.11.6"
|
||||
tokio = { version = "1.12.0", features = ["full"] }
|
||||
tokio-util = { version = "0.6.7", features = ["codec"] }
|
||||
|
||||
# Optional dependencies based on features
|
||||
schemars = { version = "0.8.10", optional = true }
|
||||
|
||||
[dev-dependencies]
|
||||
tempfile = "3"
|
@ -0,0 +1,49 @@
|
||||
# distant net
|
||||
|
||||
[![Crates.io][distant_crates_img]][distant_crates_lnk] [![Docs.rs][distant_doc_img]][distant_doc_lnk] [![Rustc 1.61.0][distant_rustc_img]][distant_rustc_lnk]
|
||||
|
||||
[distant_crates_img]: https://img.shields.io/crates/v/distant-net.svg
|
||||
[distant_crates_lnk]: https://crates.io/crates/distant-net
|
||||
[distant_doc_img]: https://docs.rs/distant-net/badge.svg
|
||||
[distant_doc_lnk]: https://docs.rs/distant-net
|
||||
[distant_rustc_img]: https://img.shields.io/badge/distant_net-rustc_1.61+-lightgray.svg
|
||||
[distant_rustc_lnk]: https://blog.rust-lang.org/2022/05/19/Rust-1.61.0.html
|
||||
|
||||
Library that powers the [`distant`](https://github.com/chipsenkbeil/distant)
|
||||
binary.
|
||||
|
||||
🚧 **(Alpha stage software) This library is in rapid development and may break or change frequently!** 🚧
|
||||
|
||||
## Details
|
||||
|
||||
The `distant-net` library supplies the foundational networking functionality
|
||||
for the distant interfaces and distant cli.
|
||||
|
||||
## Installation
|
||||
|
||||
You can import the dependency by adding the following to your `Cargo.toml`:
|
||||
|
||||
```toml
|
||||
[dependencies]
|
||||
distant-net = "0.17"
|
||||
```
|
||||
|
||||
## Features
|
||||
|
||||
Currently, the library supports the following features:
|
||||
|
||||
- `schemars`: derives the `schemars::JsonSchema` interface on `Request`
|
||||
and `Response` data types
|
||||
|
||||
By default, no features are enabled on the library.
|
||||
|
||||
## License
|
||||
|
||||
This project is licensed under either of
|
||||
|
||||
Apache License, Version 2.0, (LICENSE-APACHE or
|
||||
[apache-license][apache-license]) MIT license (LICENSE-MIT or
|
||||
[mit-license][mit-license]) at your option.
|
||||
|
||||
[apache-license]: http://www.apache.org/licenses/LICENSE-2.0
|
||||
[mit-license]: http://opensource.org/licenses/MIT
|
@ -0,0 +1,29 @@
|
||||
use std::any::Any;
|
||||
|
||||
/// Trait used for casting support into the [`Any`] trait object
|
||||
pub trait AsAny: Any {
|
||||
/// Converts reference to [`Any`]
|
||||
fn as_any(&self) -> &dyn Any;
|
||||
|
||||
/// Converts mutable reference to [`Any`]
|
||||
fn as_mut_any(&mut self) -> &mut dyn Any;
|
||||
|
||||
/// Consumes and produces `Box<dyn Any>`
|
||||
fn into_any(self: Box<Self>) -> Box<dyn Any>;
|
||||
}
|
||||
|
||||
/// Blanket implementation that enables any `'static` reference to convert
|
||||
/// to the [`Any`] type
|
||||
impl<T: 'static> AsAny for T {
|
||||
fn as_any(&self) -> &dyn Any {
|
||||
self
|
||||
}
|
||||
|
||||
fn as_mut_any(&mut self) -> &mut dyn Any {
|
||||
self
|
||||
}
|
||||
|
||||
fn into_any(self: Box<Self>) -> Box<dyn Any> {
|
||||
self
|
||||
}
|
||||
}
|
@ -0,0 +1,122 @@
|
||||
use derive_more::Display;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::collections::HashMap;
|
||||
|
||||
mod client;
|
||||
pub use client::*;
|
||||
|
||||
mod handshake;
|
||||
pub use handshake::*;
|
||||
|
||||
mod server;
|
||||
pub use server::*;
|
||||
|
||||
/// Represents authentication messages that can be sent over the wire
|
||||
///
|
||||
/// NOTE: Must use serde's content attribute with the tag attribute. Just the tag attribute will
|
||||
/// cause deserialization to fail
|
||||
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "snake_case", tag = "type", content = "data")]
|
||||
pub enum Auth {
|
||||
/// Represents a request to perform an authentication handshake,
|
||||
/// providing the public key and salt from one side in order to
|
||||
/// derive the shared key
|
||||
#[serde(rename = "auth_handshake")]
|
||||
Handshake {
|
||||
/// Bytes of the public key
|
||||
#[serde(with = "serde_bytes")]
|
||||
public_key: PublicKeyBytes,
|
||||
|
||||
/// Randomly generated salt
|
||||
#[serde(with = "serde_bytes")]
|
||||
salt: Salt,
|
||||
},
|
||||
|
||||
/// Represents the bytes of an encrypted message
|
||||
///
|
||||
/// Underneath, will be one of either [`AuthRequest`] or [`AuthResponse`]
|
||||
#[serde(rename = "auth_msg")]
|
||||
Msg {
|
||||
#[serde(with = "serde_bytes")]
|
||||
encrypted_payload: Vec<u8>,
|
||||
},
|
||||
}
|
||||
|
||||
/// Represents authentication messages that act as initiators such as providing
|
||||
/// a challenge, verifying information, presenting information, or highlighting an error
|
||||
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "snake_case", tag = "type")]
|
||||
pub enum AuthRequest {
|
||||
/// Represents a challenge comprising a series of questions to be presented
|
||||
Challenge {
|
||||
questions: Vec<AuthQuestion>,
|
||||
extra: HashMap<String, String>,
|
||||
},
|
||||
|
||||
/// Represents an ask to verify some information
|
||||
Verify { kind: AuthVerifyKind, text: String },
|
||||
|
||||
/// Represents some information to be presented
|
||||
Info { text: String },
|
||||
|
||||
/// Represents some error that occurred
|
||||
Error { kind: AuthErrorKind, text: String },
|
||||
}
|
||||
|
||||
/// Represents authentication messages that are responses to auth requests such
|
||||
/// as answers to challenges or verifying information
|
||||
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "snake_case", tag = "type")]
|
||||
pub enum AuthResponse {
|
||||
/// Represents the answers to a previously-asked challenge
|
||||
Challenge { answers: Vec<String> },
|
||||
|
||||
/// Represents the answer to a previously-asked verify
|
||||
Verify { valid: bool },
|
||||
}
|
||||
|
||||
/// Represents the type of verification being requested
|
||||
#[derive(Copy, Clone, Debug, Display, PartialEq, Eq, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
#[non_exhaustive]
|
||||
pub enum AuthVerifyKind {
|
||||
/// An ask to verify the host such as with SSH
|
||||
#[display(fmt = "host")]
|
||||
Host,
|
||||
}
|
||||
|
||||
/// Represents a single question in a challenge
|
||||
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
|
||||
pub struct AuthQuestion {
|
||||
/// The text of the question
|
||||
pub text: String,
|
||||
|
||||
/// Any extra information specific to a particular auth domain
|
||||
/// such as including a username and instructions for SSH authentication
|
||||
pub extra: HashMap<String, String>,
|
||||
}
|
||||
|
||||
impl AuthQuestion {
|
||||
/// Creates a new question without any extra data
|
||||
pub fn new(text: impl Into<String>) -> Self {
|
||||
Self {
|
||||
text: text.into(),
|
||||
extra: HashMap::new(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Represents the type of error encountered during authentication
|
||||
#[derive(Copy, Clone, Debug, Display, PartialEq, Eq, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
pub enum AuthErrorKind {
|
||||
/// When the answer(s) to a challenge do not pass authentication
|
||||
FailedChallenge,
|
||||
|
||||
/// When verification during authentication fails
|
||||
/// (e.g. a host is not allowed or blocked)
|
||||
FailedVerification,
|
||||
|
||||
/// When the error is unknown
|
||||
Unknown,
|
||||
}
|
@ -0,0 +1,817 @@
|
||||
use crate::{
|
||||
utils, Auth, AuthErrorKind, AuthQuestion, AuthRequest, AuthResponse, AuthVerifyKind, Client,
|
||||
Codec, Handshake, XChaCha20Poly1305Codec,
|
||||
};
|
||||
use bytes::BytesMut;
|
||||
use log::*;
|
||||
use std::{collections::HashMap, io};
|
||||
|
||||
pub struct AuthClient {
|
||||
inner: Client<Auth, Auth>,
|
||||
codec: Option<XChaCha20Poly1305Codec>,
|
||||
jit_handshake: bool,
|
||||
}
|
||||
|
||||
impl From<Client<Auth, Auth>> for AuthClient {
|
||||
fn from(client: Client<Auth, Auth>) -> Self {
|
||||
Self {
|
||||
inner: client,
|
||||
codec: None,
|
||||
jit_handshake: false,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl AuthClient {
|
||||
/// Sends a request to the server to establish an encrypted connection
|
||||
pub async fn handshake(&mut self) -> io::Result<()> {
|
||||
let handshake = Handshake::default();
|
||||
|
||||
let response = self
|
||||
.inner
|
||||
.send(Auth::Handshake {
|
||||
public_key: handshake.pk_bytes(),
|
||||
salt: *handshake.salt(),
|
||||
})
|
||||
.await?;
|
||||
|
||||
match response.payload {
|
||||
Auth::Handshake { public_key, salt } => {
|
||||
let key = handshake.handshake(public_key, salt)?;
|
||||
self.codec.replace(XChaCha20Poly1305Codec::new(&key));
|
||||
Ok(())
|
||||
}
|
||||
Auth::Msg { .. } => Err(io::Error::new(
|
||||
io::ErrorKind::Other,
|
||||
"Got unexpected encrypted message during handshake",
|
||||
)),
|
||||
}
|
||||
}
|
||||
|
||||
/// Perform a handshake only if jit is enabled and no handshake has succeeded yet
|
||||
async fn jit_handshake(&mut self) -> io::Result<()> {
|
||||
if self.will_jit_handshake() && !self.is_ready() {
|
||||
self.handshake().await
|
||||
} else {
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns true if client has successfully performed a handshake
|
||||
/// and is ready to communicate with the server
|
||||
pub fn is_ready(&self) -> bool {
|
||||
self.codec.is_some()
|
||||
}
|
||||
|
||||
/// Returns true if this client will perform a handshake just-in-time (JIT) prior to making a
|
||||
/// request in the scenario where the client has not already performed a handshake
|
||||
#[inline]
|
||||
pub fn will_jit_handshake(&self) -> bool {
|
||||
self.jit_handshake
|
||||
}
|
||||
|
||||
/// Sets the jit flag on this client with `true` indicating that this client will perform a
|
||||
/// handshake just-in-time (JIT) prior to making a request in the scenario where the client has
|
||||
/// not already performed a handshake
|
||||
#[inline]
|
||||
pub fn set_jit_handshake(&mut self, flag: bool) {
|
||||
self.jit_handshake = flag;
|
||||
}
|
||||
|
||||
/// Provides a challenge to the server and returns the answers to the questions
|
||||
/// asked by the client
|
||||
pub async fn challenge(
|
||||
&mut self,
|
||||
questions: Vec<AuthQuestion>,
|
||||
extra: HashMap<String, String>,
|
||||
) -> io::Result<Vec<String>> {
|
||||
trace!(
|
||||
"AuthClient::challenge(questions = {:?}, extra = {:?})",
|
||||
questions,
|
||||
extra
|
||||
);
|
||||
|
||||
// Perform JIT handshake if enabled
|
||||
self.jit_handshake().await?;
|
||||
|
||||
let payload = AuthRequest::Challenge { questions, extra };
|
||||
let encrypted_payload = self.serialize_and_encrypt(&payload)?;
|
||||
let response = self.inner.send(Auth::Msg { encrypted_payload }).await?;
|
||||
|
||||
match response.payload {
|
||||
Auth::Msg { encrypted_payload } => {
|
||||
match self.decrypt_and_deserialize(&encrypted_payload)? {
|
||||
AuthResponse::Challenge { answers } => Ok(answers),
|
||||
AuthResponse::Verify { .. } => Err(io::Error::new(
|
||||
io::ErrorKind::Other,
|
||||
"Got unexpected verify response during challenge",
|
||||
)),
|
||||
}
|
||||
}
|
||||
Auth::Handshake { .. } => Err(io::Error::new(
|
||||
io::ErrorKind::Other,
|
||||
"Got unexpected handshake during challenge",
|
||||
)),
|
||||
}
|
||||
}
|
||||
|
||||
/// Provides a verification request to the server and returns whether or not
|
||||
/// the server approved
|
||||
pub async fn verify(&mut self, kind: AuthVerifyKind, text: String) -> io::Result<bool> {
|
||||
trace!("AuthClient::verify(kind = {:?}, text = {:?})", kind, text);
|
||||
|
||||
// Perform JIT handshake if enabled
|
||||
self.jit_handshake().await?;
|
||||
|
||||
let payload = AuthRequest::Verify { kind, text };
|
||||
let encrypted_payload = self.serialize_and_encrypt(&payload)?;
|
||||
let response = self.inner.send(Auth::Msg { encrypted_payload }).await?;
|
||||
|
||||
match response.payload {
|
||||
Auth::Msg { encrypted_payload } => {
|
||||
match self.decrypt_and_deserialize(&encrypted_payload)? {
|
||||
AuthResponse::Verify { valid } => Ok(valid),
|
||||
AuthResponse::Challenge { .. } => Err(io::Error::new(
|
||||
io::ErrorKind::Other,
|
||||
"Got unexpected challenge response during verify",
|
||||
)),
|
||||
}
|
||||
}
|
||||
Auth::Handshake { .. } => Err(io::Error::new(
|
||||
io::ErrorKind::Other,
|
||||
"Got unexpected handshake during verify",
|
||||
)),
|
||||
}
|
||||
}
|
||||
|
||||
/// Provides information to the server to use as it pleases with no response expected
|
||||
pub async fn info(&mut self, text: String) -> io::Result<()> {
|
||||
trace!("AuthClient::info(text = {:?})", text);
|
||||
|
||||
// Perform JIT handshake if enabled
|
||||
self.jit_handshake().await?;
|
||||
|
||||
let payload = AuthRequest::Info { text };
|
||||
let encrypted_payload = self.serialize_and_encrypt(&payload)?;
|
||||
self.inner.fire(Auth::Msg { encrypted_payload }).await
|
||||
}
|
||||
|
||||
/// Provides an error to the server to use as it pleases with no response expected
|
||||
pub async fn error(&mut self, kind: AuthErrorKind, text: String) -> io::Result<()> {
|
||||
trace!("AuthClient::error(kind = {:?}, text = {:?})", kind, text);
|
||||
|
||||
// Perform JIT handshake if enabled
|
||||
self.jit_handshake().await?;
|
||||
|
||||
let payload = AuthRequest::Error { kind, text };
|
||||
let encrypted_payload = self.serialize_and_encrypt(&payload)?;
|
||||
self.inner.fire(Auth::Msg { encrypted_payload }).await
|
||||
}
|
||||
|
||||
fn serialize_and_encrypt(&mut self, payload: &AuthRequest) -> io::Result<Vec<u8>> {
|
||||
let codec = self.codec.as_mut().ok_or_else(|| {
|
||||
io::Error::new(
|
||||
io::ErrorKind::Other,
|
||||
"Handshake must be performed first (client encrypt message)",
|
||||
)
|
||||
})?;
|
||||
|
||||
let mut encryped_payload = BytesMut::new();
|
||||
let payload = utils::serialize_to_vec(payload)?;
|
||||
codec.encode(&payload, &mut encryped_payload)?;
|
||||
Ok(encryped_payload.freeze().to_vec())
|
||||
}
|
||||
|
||||
fn decrypt_and_deserialize(&mut self, payload: &[u8]) -> io::Result<AuthResponse> {
|
||||
let codec = self.codec.as_mut().ok_or_else(|| {
|
||||
io::Error::new(
|
||||
io::ErrorKind::Other,
|
||||
"Handshake must be performed first (client decrypt message)",
|
||||
)
|
||||
})?;
|
||||
|
||||
let mut payload = BytesMut::from(payload);
|
||||
match codec.decode(&mut payload)? {
|
||||
Some(payload) => utils::deserialize_from_slice::<AuthResponse>(&payload),
|
||||
None => Err(io::Error::new(
|
||||
io::ErrorKind::InvalidData,
|
||||
"Incomplete message received",
|
||||
)),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::{Client, FramedTransport, Request, Response, TypedAsyncRead, TypedAsyncWrite};
|
||||
use serde::{de::DeserializeOwned, Serialize};
|
||||
|
||||
const TIMEOUT_MILLIS: u64 = 100;
|
||||
|
||||
#[tokio::test]
|
||||
async fn handshake_should_fail_if_get_unexpected_response_from_server() {
|
||||
let (t, mut server) = FramedTransport::make_test_pair();
|
||||
let mut client = AuthClient::from(Client::from_framed_transport(t).unwrap());
|
||||
|
||||
// We start a separate task for the client to avoid blocking since
|
||||
// we also need to receive the client's request and respond
|
||||
let task = tokio::spawn(async move { client.handshake().await });
|
||||
|
||||
// Get the request, but send a bad response
|
||||
let request: Request<Auth> = server.read().await.unwrap().unwrap();
|
||||
match request.payload {
|
||||
Auth::Handshake { .. } => server
|
||||
.write(Response::new(
|
||||
request.id,
|
||||
Auth::Msg {
|
||||
encrypted_payload: Vec::new(),
|
||||
},
|
||||
))
|
||||
.await
|
||||
.unwrap(),
|
||||
_ => panic!("Server received unexpected payload"),
|
||||
}
|
||||
|
||||
let result = task.await.unwrap();
|
||||
assert!(result.is_err(), "Handshake succeeded unexpectedly")
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn challenge_should_fail_if_handshake_not_finished() {
|
||||
let (t, mut server) = FramedTransport::make_test_pair();
|
||||
let mut client = AuthClient::from(Client::from_framed_transport(t).unwrap());
|
||||
|
||||
// We start a separate task for the client to avoid blocking since
|
||||
// we also need to receive the client's request and respond
|
||||
let task = tokio::spawn(async move { client.challenge(Vec::new(), HashMap::new()).await });
|
||||
|
||||
// Wait for a request, failing if we get one as the failure
|
||||
// should have prevented sending anything, but we should
|
||||
tokio::select! {
|
||||
x = TypedAsyncRead::<Request<Auth>>::read(&mut server) => {
|
||||
match x {
|
||||
Ok(Some(x)) => panic!("Unexpectedly resolved: {:?}", x),
|
||||
Ok(None) => {},
|
||||
Err(x) => panic!("Unexpectedly failed on server side: {}", x),
|
||||
}
|
||||
},
|
||||
_ = wait_ms(TIMEOUT_MILLIS) => {
|
||||
panic!("Should have gotten server closure as part of client exit");
|
||||
}
|
||||
}
|
||||
|
||||
// Verify that we got an error with the method
|
||||
let result = task.await.unwrap();
|
||||
assert!(result.is_err(), "Challenge succeeded unexpectedly")
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn challenge_should_fail_if_receive_wrong_response() {
|
||||
let (t, mut server) = FramedTransport::make_test_pair();
|
||||
let mut client = AuthClient::from(Client::from_framed_transport(t).unwrap());
|
||||
|
||||
// We start a separate task for the client to avoid blocking since
|
||||
// we also need to receive the client's request and respond
|
||||
let task = tokio::spawn(async move {
|
||||
client.handshake().await.unwrap();
|
||||
client
|
||||
.challenge(
|
||||
vec![
|
||||
AuthQuestion::new("question1".to_string()),
|
||||
AuthQuestion {
|
||||
text: "question2".to_string(),
|
||||
extra: vec![("key2".to_string(), "value2".to_string())]
|
||||
.into_iter()
|
||||
.collect(),
|
||||
},
|
||||
],
|
||||
vec![("key".to_string(), "value".to_string())]
|
||||
.into_iter()
|
||||
.collect(),
|
||||
)
|
||||
.await
|
||||
});
|
||||
|
||||
// Wait for a handshake request and set up our encryption codec
|
||||
let request: Request<Auth> = server.read().await.unwrap().unwrap();
|
||||
let mut codec = match request.payload {
|
||||
Auth::Handshake { public_key, salt } => {
|
||||
let handshake = Handshake::default();
|
||||
let key = handshake.handshake(public_key, salt).unwrap();
|
||||
server
|
||||
.write(Response::new(
|
||||
request.id,
|
||||
Auth::Handshake {
|
||||
public_key: handshake.pk_bytes(),
|
||||
salt: *handshake.salt(),
|
||||
},
|
||||
))
|
||||
.await
|
||||
.unwrap();
|
||||
XChaCha20Poly1305Codec::new(&key)
|
||||
}
|
||||
_ => panic!("Server received unexpected payload"),
|
||||
};
|
||||
|
||||
// Wait for a challenge request and send back wrong response
|
||||
let request: Request<Auth> = server.read().await.unwrap().unwrap();
|
||||
match request.payload {
|
||||
Auth::Msg { encrypted_payload } => {
|
||||
match decrypt_and_deserialize(&mut codec, &encrypted_payload).unwrap() {
|
||||
AuthRequest::Challenge { .. } => {
|
||||
server
|
||||
.write(Response::new(
|
||||
request.id,
|
||||
Auth::Msg {
|
||||
encrypted_payload: serialize_and_encrypt(
|
||||
&mut codec,
|
||||
&AuthResponse::Verify { valid: true },
|
||||
)
|
||||
.unwrap(),
|
||||
},
|
||||
))
|
||||
.await
|
||||
.unwrap();
|
||||
}
|
||||
_ => panic!("Server received wrong request type"),
|
||||
}
|
||||
}
|
||||
_ => panic!("Server received unexpected payload"),
|
||||
};
|
||||
|
||||
// Verify that we got an error with the method
|
||||
let result = task.await.unwrap();
|
||||
assert!(result.is_err(), "Challenge succeeded unexpectedly")
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn challenge_should_return_answers_received_from_server() {
|
||||
let (t, mut server) = FramedTransport::make_test_pair();
|
||||
let mut client = AuthClient::from(Client::from_framed_transport(t).unwrap());
|
||||
|
||||
// We start a separate task for the client to avoid blocking since
|
||||
// we also need to receive the client's request and respond
|
||||
let task = tokio::spawn(async move {
|
||||
client.handshake().await.unwrap();
|
||||
client
|
||||
.challenge(
|
||||
vec![
|
||||
AuthQuestion::new("question1".to_string()),
|
||||
AuthQuestion {
|
||||
text: "question2".to_string(),
|
||||
extra: vec![("key2".to_string(), "value2".to_string())]
|
||||
.into_iter()
|
||||
.collect(),
|
||||
},
|
||||
],
|
||||
vec![("key".to_string(), "value".to_string())]
|
||||
.into_iter()
|
||||
.collect(),
|
||||
)
|
||||
.await
|
||||
});
|
||||
|
||||
// Wait for a handshake request and set up our encryption codec
|
||||
let request: Request<Auth> = server.read().await.unwrap().unwrap();
|
||||
let mut codec = match request.payload {
|
||||
Auth::Handshake { public_key, salt } => {
|
||||
let handshake = Handshake::default();
|
||||
let key = handshake.handshake(public_key, salt).unwrap();
|
||||
server
|
||||
.write(Response::new(
|
||||
request.id,
|
||||
Auth::Handshake {
|
||||
public_key: handshake.pk_bytes(),
|
||||
salt: *handshake.salt(),
|
||||
},
|
||||
))
|
||||
.await
|
||||
.unwrap();
|
||||
XChaCha20Poly1305Codec::new(&key)
|
||||
}
|
||||
_ => panic!("Server received unexpected payload"),
|
||||
};
|
||||
|
||||
// Wait for a challenge request and send back wrong response
|
||||
let request: Request<Auth> = server.read().await.unwrap().unwrap();
|
||||
match request.payload {
|
||||
Auth::Msg { encrypted_payload } => {
|
||||
match decrypt_and_deserialize(&mut codec, &encrypted_payload).unwrap() {
|
||||
AuthRequest::Challenge { questions, extra } => {
|
||||
assert_eq!(
|
||||
questions,
|
||||
vec![
|
||||
AuthQuestion::new("question1".to_string()),
|
||||
AuthQuestion {
|
||||
text: "question2".to_string(),
|
||||
extra: vec![("key2".to_string(), "value2".to_string())]
|
||||
.into_iter()
|
||||
.collect(),
|
||||
},
|
||||
],
|
||||
);
|
||||
|
||||
assert_eq!(
|
||||
extra,
|
||||
vec![("key".to_string(), "value".to_string())]
|
||||
.into_iter()
|
||||
.collect(),
|
||||
);
|
||||
|
||||
server
|
||||
.write(Response::new(
|
||||
request.id,
|
||||
Auth::Msg {
|
||||
encrypted_payload: serialize_and_encrypt(
|
||||
&mut codec,
|
||||
&AuthResponse::Challenge {
|
||||
answers: vec![
|
||||
"answer1".to_string(),
|
||||
"answer2".to_string(),
|
||||
],
|
||||
},
|
||||
)
|
||||
.unwrap(),
|
||||
},
|
||||
))
|
||||
.await
|
||||
.unwrap();
|
||||
}
|
||||
_ => panic!("Server received wrong request type"),
|
||||
}
|
||||
}
|
||||
_ => panic!("Server received unexpected payload"),
|
||||
};
|
||||
|
||||
// Verify that we got the right results
|
||||
let answers = task.await.unwrap().unwrap();
|
||||
assert_eq!(answers, vec!["answer1".to_string(), "answer2".to_string()]);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn verify_should_fail_if_handshake_not_finished() {
|
||||
let (t, mut server) = FramedTransport::make_test_pair();
|
||||
let mut client = AuthClient::from(Client::from_framed_transport(t).unwrap());
|
||||
|
||||
// We start a separate task for the client to avoid blocking since
|
||||
// we also need to receive the client's request and respond
|
||||
let task = tokio::spawn(async move {
|
||||
client
|
||||
.verify(AuthVerifyKind::Host, "some text".to_string())
|
||||
.await
|
||||
});
|
||||
|
||||
// Wait for a request, failing if we get one as the failure
|
||||
// should have prevented sending anything, but we should
|
||||
tokio::select! {
|
||||
x = TypedAsyncRead::<Request<Auth>>::read(&mut server) => {
|
||||
match x {
|
||||
Ok(Some(x)) => panic!("Unexpectedly resolved: {:?}", x),
|
||||
Ok(None) => {},
|
||||
Err(x) => panic!("Unexpectedly failed on server side: {}", x),
|
||||
}
|
||||
},
|
||||
_ = wait_ms(TIMEOUT_MILLIS) => {
|
||||
panic!("Should have gotten server closure as part of client exit");
|
||||
}
|
||||
}
|
||||
|
||||
// Verify that we got an error with the method
|
||||
let result = task.await.unwrap();
|
||||
assert!(result.is_err(), "Verify succeeded unexpectedly")
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn verify_should_fail_if_receive_wrong_response() {
|
||||
let (t, mut server) = FramedTransport::make_test_pair();
|
||||
let mut client = AuthClient::from(Client::from_framed_transport(t).unwrap());
|
||||
|
||||
// We start a separate task for the client to avoid blocking since
|
||||
// we also need to receive the client's request and respond
|
||||
let task = tokio::spawn(async move {
|
||||
client.handshake().await.unwrap();
|
||||
client
|
||||
.verify(AuthVerifyKind::Host, "some text".to_string())
|
||||
.await
|
||||
});
|
||||
|
||||
// Wait for a handshake request and set up our encryption codec
|
||||
let request: Request<Auth> = server.read().await.unwrap().unwrap();
|
||||
let mut codec = match request.payload {
|
||||
Auth::Handshake { public_key, salt } => {
|
||||
let handshake = Handshake::default();
|
||||
let key = handshake.handshake(public_key, salt).unwrap();
|
||||
server
|
||||
.write(Response::new(
|
||||
request.id,
|
||||
Auth::Handshake {
|
||||
public_key: handshake.pk_bytes(),
|
||||
salt: *handshake.salt(),
|
||||
},
|
||||
))
|
||||
.await
|
||||
.unwrap();
|
||||
XChaCha20Poly1305Codec::new(&key)
|
||||
}
|
||||
_ => panic!("Server received unexpected payload"),
|
||||
};
|
||||
|
||||
// Wait for a verify request and send back wrong response
|
||||
let request: Request<Auth> = server.read().await.unwrap().unwrap();
|
||||
match request.payload {
|
||||
Auth::Msg { encrypted_payload } => {
|
||||
match decrypt_and_deserialize(&mut codec, &encrypted_payload).unwrap() {
|
||||
AuthRequest::Verify { .. } => {
|
||||
server
|
||||
.write(Response::new(
|
||||
request.id,
|
||||
Auth::Msg {
|
||||
encrypted_payload: serialize_and_encrypt(
|
||||
&mut codec,
|
||||
&AuthResponse::Challenge {
|
||||
answers: Vec::new(),
|
||||
},
|
||||
)
|
||||
.unwrap(),
|
||||
},
|
||||
))
|
||||
.await
|
||||
.unwrap();
|
||||
}
|
||||
_ => panic!("Server received wrong request type"),
|
||||
}
|
||||
}
|
||||
_ => panic!("Server received unexpected payload"),
|
||||
};
|
||||
|
||||
// Verify that we got an error with the method
|
||||
let result = task.await.unwrap();
|
||||
assert!(result.is_err(), "Verify succeeded unexpectedly")
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn verify_should_return_valid_bool_received_from_server() {
|
||||
let (t, mut server) = FramedTransport::make_test_pair();
|
||||
let mut client = AuthClient::from(Client::from_framed_transport(t).unwrap());
|
||||
|
||||
// We start a separate task for the client to avoid blocking since
|
||||
// we also need to receive the client's request and respond
|
||||
let task = tokio::spawn(async move {
|
||||
client.handshake().await.unwrap();
|
||||
client
|
||||
.verify(AuthVerifyKind::Host, "some text".to_string())
|
||||
.await
|
||||
});
|
||||
|
||||
// Wait for a handshake request and set up our encryption codec
|
||||
let request: Request<Auth> = server.read().await.unwrap().unwrap();
|
||||
let mut codec = match request.payload {
|
||||
Auth::Handshake { public_key, salt } => {
|
||||
let handshake = Handshake::default();
|
||||
let key = handshake.handshake(public_key, salt).unwrap();
|
||||
server
|
||||
.write(Response::new(
|
||||
request.id,
|
||||
Auth::Handshake {
|
||||
public_key: handshake.pk_bytes(),
|
||||
salt: *handshake.salt(),
|
||||
},
|
||||
))
|
||||
.await
|
||||
.unwrap();
|
||||
XChaCha20Poly1305Codec::new(&key)
|
||||
}
|
||||
_ => panic!("Server received unexpected payload"),
|
||||
};
|
||||
|
||||
// Wait for a challenge request and send back wrong response
|
||||
let request: Request<Auth> = server.read().await.unwrap().unwrap();
|
||||
match request.payload {
|
||||
Auth::Msg { encrypted_payload } => {
|
||||
match decrypt_and_deserialize(&mut codec, &encrypted_payload).unwrap() {
|
||||
AuthRequest::Verify { kind, text } => {
|
||||
assert_eq!(kind, AuthVerifyKind::Host);
|
||||
assert_eq!(text, "some text");
|
||||
|
||||
server
|
||||
.write(Response::new(
|
||||
request.id,
|
||||
Auth::Msg {
|
||||
encrypted_payload: serialize_and_encrypt(
|
||||
&mut codec,
|
||||
&AuthResponse::Verify { valid: true },
|
||||
)
|
||||
.unwrap(),
|
||||
},
|
||||
))
|
||||
.await
|
||||
.unwrap();
|
||||
}
|
||||
_ => panic!("Server received wrong request type"),
|
||||
}
|
||||
}
|
||||
_ => panic!("Server received unexpected payload"),
|
||||
};
|
||||
|
||||
// Verify that we got the right results
|
||||
let valid = task.await.unwrap().unwrap();
|
||||
assert!(valid, "Got verify response, but valid was set incorrectly");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn info_should_fail_if_handshake_not_finished() {
|
||||
let (t, mut server) = FramedTransport::make_test_pair();
|
||||
let mut client = AuthClient::from(Client::from_framed_transport(t).unwrap());
|
||||
|
||||
// We start a separate task for the client to avoid blocking since
|
||||
// we also need to receive the client's request and respond
|
||||
let task = tokio::spawn(async move { client.info("some text".to_string()).await });
|
||||
|
||||
// Wait for a request, failing if we get one as the failure
|
||||
// should have prevented sending anything, but we should
|
||||
tokio::select! {
|
||||
x = TypedAsyncRead::<Request<Auth>>::read(&mut server) => {
|
||||
match x {
|
||||
Ok(Some(x)) => panic!("Unexpectedly resolved: {:?}", x),
|
||||
Ok(None) => {},
|
||||
Err(x) => panic!("Unexpectedly failed on server side: {}", x),
|
||||
}
|
||||
},
|
||||
_ = wait_ms(TIMEOUT_MILLIS) => {
|
||||
panic!("Should have gotten server closure as part of client exit");
|
||||
}
|
||||
}
|
||||
|
||||
// Verify that we got an error with the method
|
||||
let result = task.await.unwrap();
|
||||
assert!(result.is_err(), "Info succeeded unexpectedly")
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn info_should_send_the_server_a_request_but_not_wait_for_a_response() {
|
||||
let (t, mut server) = FramedTransport::make_test_pair();
|
||||
let mut client = AuthClient::from(Client::from_framed_transport(t).unwrap());
|
||||
|
||||
// We start a separate task for the client to avoid blocking since
|
||||
// we also need to receive the client's request and respond
|
||||
let task = tokio::spawn(async move {
|
||||
client.handshake().await.unwrap();
|
||||
client.info("some text".to_string()).await
|
||||
});
|
||||
|
||||
// Wait for a handshake request and set up our encryption codec
|
||||
let request: Request<Auth> = server.read().await.unwrap().unwrap();
|
||||
let mut codec = match request.payload {
|
||||
Auth::Handshake { public_key, salt } => {
|
||||
let handshake = Handshake::default();
|
||||
let key = handshake.handshake(public_key, salt).unwrap();
|
||||
server
|
||||
.write(Response::new(
|
||||
request.id,
|
||||
Auth::Handshake {
|
||||
public_key: handshake.pk_bytes(),
|
||||
salt: *handshake.salt(),
|
||||
},
|
||||
))
|
||||
.await
|
||||
.unwrap();
|
||||
XChaCha20Poly1305Codec::new(&key)
|
||||
}
|
||||
_ => panic!("Server received unexpected payload"),
|
||||
};
|
||||
|
||||
// Wait for a request
|
||||
let request: Request<Auth> = server.read().await.unwrap().unwrap();
|
||||
match request.payload {
|
||||
Auth::Msg { encrypted_payload } => {
|
||||
match decrypt_and_deserialize(&mut codec, &encrypted_payload).unwrap() {
|
||||
AuthRequest::Info { text } => {
|
||||
assert_eq!(text, "some text");
|
||||
}
|
||||
_ => panic!("Server received wrong request type"),
|
||||
}
|
||||
}
|
||||
_ => panic!("Server received unexpected payload"),
|
||||
};
|
||||
|
||||
// Verify that we got the right results
|
||||
task.await.unwrap().unwrap();
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn error_should_fail_if_handshake_not_finished() {
|
||||
let (t, mut server) = FramedTransport::make_test_pair();
|
||||
let mut client = AuthClient::from(Client::from_framed_transport(t).unwrap());
|
||||
|
||||
// We start a separate task for the client to avoid blocking since
|
||||
// we also need to receive the client's request and respond
|
||||
let task = tokio::spawn(async move {
|
||||
client
|
||||
.error(AuthErrorKind::FailedChallenge, "some text".to_string())
|
||||
.await
|
||||
});
|
||||
|
||||
// Wait for a request, failing if we get one as the failure
|
||||
// should have prevented sending anything, but we should
|
||||
tokio::select! {
|
||||
x = TypedAsyncRead::<Request<Auth>>::read(&mut server) => {
|
||||
match x {
|
||||
Ok(Some(x)) => panic!("Unexpectedly resolved: {:?}", x),
|
||||
Ok(None) => {},
|
||||
Err(x) => panic!("Unexpectedly failed on server side: {}", x),
|
||||
}
|
||||
},
|
||||
_ = wait_ms(TIMEOUT_MILLIS) => {
|
||||
panic!("Should have gotten server closure as part of client exit");
|
||||
}
|
||||
}
|
||||
|
||||
// Verify that we got an error with the method
|
||||
let result = task.await.unwrap();
|
||||
assert!(result.is_err(), "Error succeeded unexpectedly")
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn error_should_send_the_server_a_request_but_not_wait_for_a_response() {
|
||||
let (t, mut server) = FramedTransport::make_test_pair();
|
||||
let mut client = AuthClient::from(Client::from_framed_transport(t).unwrap());
|
||||
|
||||
// We start a separate task for the client to avoid blocking since
|
||||
// we also need to receive the client's request and respond
|
||||
let task = tokio::spawn(async move {
|
||||
client.handshake().await.unwrap();
|
||||
client
|
||||
.error(AuthErrorKind::FailedChallenge, "some text".to_string())
|
||||
.await
|
||||
});
|
||||
|
||||
// Wait for a handshake request and set up our encryption codec
|
||||
let request: Request<Auth> = server.read().await.unwrap().unwrap();
|
||||
let mut codec = match request.payload {
|
||||
Auth::Handshake { public_key, salt } => {
|
||||
let handshake = Handshake::default();
|
||||
let key = handshake.handshake(public_key, salt).unwrap();
|
||||
server
|
||||
.write(Response::new(
|
||||
request.id,
|
||||
Auth::Handshake {
|
||||
public_key: handshake.pk_bytes(),
|
||||
salt: *handshake.salt(),
|
||||
},
|
||||
))
|
||||
.await
|
||||
.unwrap();
|
||||
XChaCha20Poly1305Codec::new(&key)
|
||||
}
|
||||
_ => panic!("Server received unexpected payload"),
|
||||
};
|
||||
|
||||
// Wait for a request
|
||||
let request: Request<Auth> = server.read().await.unwrap().unwrap();
|
||||
match request.payload {
|
||||
Auth::Msg { encrypted_payload } => {
|
||||
match decrypt_and_deserialize(&mut codec, &encrypted_payload).unwrap() {
|
||||
AuthRequest::Error { kind, text } => {
|
||||
assert_eq!(kind, AuthErrorKind::FailedChallenge);
|
||||
assert_eq!(text, "some text");
|
||||
}
|
||||
_ => panic!("Server received wrong request type"),
|
||||
}
|
||||
}
|
||||
_ => panic!("Server received unexpected payload"),
|
||||
};
|
||||
|
||||
// Verify that we got the right results
|
||||
task.await.unwrap().unwrap();
|
||||
}
|
||||
|
||||
async fn wait_ms(ms: u64) {
|
||||
use std::time::Duration;
|
||||
tokio::time::sleep(Duration::from_millis(ms)).await;
|
||||
}
|
||||
|
||||
fn serialize_and_encrypt<T: Serialize>(
|
||||
codec: &mut XChaCha20Poly1305Codec,
|
||||
payload: &T,
|
||||
) -> io::Result<Vec<u8>> {
|
||||
let mut encryped_payload = BytesMut::new();
|
||||
let payload = utils::serialize_to_vec(payload)?;
|
||||
codec.encode(&payload, &mut encryped_payload)?;
|
||||
Ok(encryped_payload.freeze().to_vec())
|
||||
}
|
||||
|
||||
fn decrypt_and_deserialize<T: DeserializeOwned>(
|
||||
codec: &mut XChaCha20Poly1305Codec,
|
||||
payload: &[u8],
|
||||
) -> io::Result<T> {
|
||||
let mut payload = BytesMut::from(payload);
|
||||
match codec.decode(&mut payload)? {
|
||||
Some(payload) => utils::deserialize_from_slice::<T>(&payload),
|
||||
None => Err(io::Error::new(
|
||||
io::ErrorKind::InvalidData,
|
||||
"Incomplete message received",
|
||||
)),
|
||||
}
|
||||
}
|
||||
}
|
@ -0,0 +1,62 @@
|
||||
use p256::{ecdh::EphemeralSecret, PublicKey};
|
||||
use rand::rngs::OsRng;
|
||||
use sha2::Sha256;
|
||||
use std::{convert::TryFrom, io};
|
||||
|
||||
mod pkb;
|
||||
pub use pkb::PublicKeyBytes;
|
||||
|
||||
mod salt;
|
||||
pub use salt::Salt;
|
||||
|
||||
/// 32-byte key shared by handshake
|
||||
pub type SharedKey = [u8; 32];
|
||||
|
||||
/// Utility to perform a handshake
|
||||
pub struct Handshake {
|
||||
secret: EphemeralSecret,
|
||||
salt: Salt,
|
||||
}
|
||||
|
||||
impl Default for Handshake {
|
||||
// Create a new handshake instance with a secret and salt
|
||||
fn default() -> Self {
|
||||
let secret = EphemeralSecret::random(&mut OsRng);
|
||||
let salt = Salt::random();
|
||||
|
||||
Self { secret, salt }
|
||||
}
|
||||
}
|
||||
|
||||
impl Handshake {
|
||||
// Return encoded bytes of public key
|
||||
pub fn pk_bytes(&self) -> PublicKeyBytes {
|
||||
PublicKeyBytes::from(self.secret.public_key())
|
||||
}
|
||||
|
||||
// Return the salt contained by this handshake
|
||||
pub fn salt(&self) -> &Salt {
|
||||
&self.salt
|
||||
}
|
||||
|
||||
pub fn handshake(&self, public_key: PublicKeyBytes, salt: Salt) -> io::Result<SharedKey> {
|
||||
// Decode the public key of the client
|
||||
let decoded_public_key = PublicKey::try_from(public_key)?;
|
||||
|
||||
// Produce a salt that is consistent with what the other side will do
|
||||
let shared_salt = self.salt ^ salt;
|
||||
|
||||
// Acquire the shared secret
|
||||
let shared_secret = self.secret.diffie_hellman(&decoded_public_key);
|
||||
|
||||
// Extract entropy from the shared secret for use in producing a key
|
||||
let hkdf = shared_secret.extract::<Sha256>(Some(shared_salt.as_ref()));
|
||||
|
||||
// Derive a shared key (32 bytes)
|
||||
let mut shared_key = [0u8; 32];
|
||||
match hkdf.expand(&[], &mut shared_key) {
|
||||
Ok(_) => Ok(shared_key),
|
||||
Err(x) => Err(io::Error::new(io::ErrorKind::InvalidData, x.to_string())),
|
||||
}
|
||||
}
|
||||
}
|
@ -0,0 +1,60 @@
|
||||
use p256::{EncodedPoint, PublicKey};
|
||||
use serde::{Deserialize, Deserializer, Serialize, Serializer};
|
||||
use std::{convert::TryFrom, io};
|
||||
|
||||
/// Represents a wrapper around [`EncodedPoint`], and exists to
|
||||
/// fix an issue with [`serde`] deserialization failing when
|
||||
/// directly serializing the [`EncodedPoint`] type
|
||||
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
|
||||
#[serde(into = "Vec<u8>", try_from = "Vec<u8>")]
|
||||
pub struct PublicKeyBytes(EncodedPoint);
|
||||
|
||||
impl From<PublicKey> for PublicKeyBytes {
|
||||
fn from(pk: PublicKey) -> Self {
|
||||
Self(EncodedPoint::from(pk))
|
||||
}
|
||||
}
|
||||
|
||||
impl TryFrom<PublicKeyBytes> for PublicKey {
|
||||
type Error = io::Error;
|
||||
|
||||
fn try_from(pkb: PublicKeyBytes) -> Result<Self, Self::Error> {
|
||||
PublicKey::from_sec1_bytes(pkb.0.as_ref())
|
||||
.map_err(|x| io::Error::new(io::ErrorKind::InvalidData, x))
|
||||
}
|
||||
}
|
||||
|
||||
impl From<PublicKeyBytes> for Vec<u8> {
|
||||
fn from(pkb: PublicKeyBytes) -> Self {
|
||||
pkb.0.as_bytes().to_vec()
|
||||
}
|
||||
}
|
||||
|
||||
impl TryFrom<Vec<u8>> for PublicKeyBytes {
|
||||
type Error = io::Error;
|
||||
|
||||
fn try_from(bytes: Vec<u8>) -> Result<Self, Self::Error> {
|
||||
Ok(Self(EncodedPoint::from_bytes(bytes).map_err(|x| {
|
||||
io::Error::new(io::ErrorKind::InvalidData, x.to_string())
|
||||
})?))
|
||||
}
|
||||
}
|
||||
|
||||
impl serde_bytes::Serialize for PublicKeyBytes {
|
||||
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
|
||||
where
|
||||
S: Serializer,
|
||||
{
|
||||
serializer.serialize_bytes(self.0.as_ref())
|
||||
}
|
||||
}
|
||||
|
||||
impl<'de> serde_bytes::Deserialize<'de> for PublicKeyBytes {
|
||||
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
|
||||
where
|
||||
D: Deserializer<'de>,
|
||||
{
|
||||
let bytes = Deserialize::deserialize(deserializer).map(serde_bytes::ByteBuf::into_vec)?;
|
||||
PublicKeyBytes::try_from(bytes).map_err(serde::de::Error::custom)
|
||||
}
|
||||
}
|
@ -0,0 +1,111 @@
|
||||
use rand::{rngs::OsRng, RngCore};
|
||||
use serde::{Deserialize, Deserializer, Serialize, Serializer};
|
||||
use std::{
|
||||
convert::{TryFrom, TryInto},
|
||||
fmt, io,
|
||||
ops::BitXor,
|
||||
str::FromStr,
|
||||
};
|
||||
|
||||
/// Friendly wrapper around a 32-byte array representing a salt
|
||||
#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)]
|
||||
#[serde(into = "Vec<u8>", try_from = "Vec<u8>")]
|
||||
pub struct Salt([u8; 32]);
|
||||
|
||||
impl Salt {
|
||||
/// Generates a salt via a uniform random
|
||||
pub fn random() -> Self {
|
||||
let mut salt = [0u8; 32];
|
||||
OsRng.fill_bytes(&mut salt);
|
||||
Self(salt)
|
||||
}
|
||||
}
|
||||
|
||||
impl fmt::Display for Salt {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
write!(f, "{}", hex::encode(self.0))
|
||||
}
|
||||
}
|
||||
|
||||
impl serde_bytes::Serialize for Salt {
|
||||
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
|
||||
where
|
||||
S: Serializer,
|
||||
{
|
||||
serializer.serialize_bytes(self.as_ref())
|
||||
}
|
||||
}
|
||||
|
||||
impl<'de> serde_bytes::Deserialize<'de> for Salt {
|
||||
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
|
||||
where
|
||||
D: Deserializer<'de>,
|
||||
{
|
||||
let bytes = Deserialize::deserialize(deserializer).map(serde_bytes::ByteBuf::into_vec)?;
|
||||
let bytes_len = bytes.len();
|
||||
Salt::try_from(bytes)
|
||||
.map_err(|_| serde::de::Error::invalid_length(bytes_len, &"expected 32-byte length"))
|
||||
}
|
||||
}
|
||||
|
||||
impl From<Salt> for String {
|
||||
fn from(salt: Salt) -> Self {
|
||||
salt.to_string()
|
||||
}
|
||||
}
|
||||
|
||||
impl FromStr for Salt {
|
||||
type Err = io::Error;
|
||||
|
||||
fn from_str(s: &str) -> Result<Self, Self::Err> {
|
||||
let bytes = hex::decode(s).map_err(|x| io::Error::new(io::ErrorKind::InvalidData, x))?;
|
||||
Self::try_from(bytes)
|
||||
}
|
||||
}
|
||||
|
||||
impl TryFrom<String> for Salt {
|
||||
type Error = io::Error;
|
||||
|
||||
fn try_from(s: String) -> Result<Self, Self::Error> {
|
||||
s.parse()
|
||||
}
|
||||
}
|
||||
|
||||
impl TryFrom<Vec<u8>> for Salt {
|
||||
type Error = io::Error;
|
||||
|
||||
fn try_from(bytes: Vec<u8>) -> Result<Self, Self::Error> {
|
||||
Ok(Self(bytes.try_into().map_err(|x: Vec<u8>| {
|
||||
io::Error::new(
|
||||
io::ErrorKind::InvalidData,
|
||||
format!("Vec<u8> len of {} != 32", x.len()),
|
||||
)
|
||||
})?))
|
||||
}
|
||||
}
|
||||
|
||||
impl From<Salt> for Vec<u8> {
|
||||
fn from(salt: Salt) -> Self {
|
||||
salt.0.to_vec()
|
||||
}
|
||||
}
|
||||
|
||||
impl AsRef<[u8]> for Salt {
|
||||
fn as_ref(&self) -> &[u8] {
|
||||
self.0.as_ref()
|
||||
}
|
||||
}
|
||||
|
||||
impl BitXor for Salt {
|
||||
type Output = Self;
|
||||
|
||||
fn bitxor(self, rhs: Self) -> Self::Output {
|
||||
let shared_salt = self
|
||||
.0
|
||||
.iter()
|
||||
.zip(rhs.0.iter())
|
||||
.map(|(x, y)| x ^ y)
|
||||
.collect::<Vec<u8>>();
|
||||
Self::try_from(shared_salt).unwrap()
|
||||
}
|
||||
}
|
@ -0,0 +1,653 @@
|
||||
use crate::{
|
||||
utils, Auth, AuthErrorKind, AuthQuestion, AuthRequest, AuthResponse, AuthVerifyKind, Codec,
|
||||
Handshake, Server, ServerCtx, XChaCha20Poly1305Codec,
|
||||
};
|
||||
use async_trait::async_trait;
|
||||
use bytes::BytesMut;
|
||||
use log::*;
|
||||
use std::{collections::HashMap, io};
|
||||
use tokio::sync::RwLock;
|
||||
|
||||
/// Type signature for a dynamic on_challenge function
|
||||
pub type AuthChallengeFn =
|
||||
dyn Fn(Vec<AuthQuestion>, HashMap<String, String>) -> Vec<String> + Send + Sync;
|
||||
|
||||
/// Type signature for a dynamic on_verify function
|
||||
pub type AuthVerifyFn = dyn Fn(AuthVerifyKind, String) -> bool + Send + Sync;
|
||||
|
||||
/// Type signature for a dynamic on_info function
|
||||
pub type AuthInfoFn = dyn Fn(String) + Send + Sync;
|
||||
|
||||
/// Type signature for a dynamic on_error function
|
||||
pub type AuthErrorFn = dyn Fn(AuthErrorKind, String) + Send + Sync;
|
||||
|
||||
/// Represents an [`AuthServer`] where all handlers are stored on the heap
|
||||
pub type HeapAuthServer =
|
||||
AuthServer<Box<AuthChallengeFn>, Box<AuthVerifyFn>, Box<AuthInfoFn>, Box<AuthErrorFn>>;
|
||||
|
||||
/// Server that handles authentication
|
||||
pub struct AuthServer<ChallengeFn, VerifyFn, InfoFn, ErrorFn>
|
||||
where
|
||||
ChallengeFn: Fn(Vec<AuthQuestion>, HashMap<String, String>) -> Vec<String> + Send + Sync,
|
||||
VerifyFn: Fn(AuthVerifyKind, String) -> bool + Send + Sync,
|
||||
InfoFn: Fn(String) + Send + Sync,
|
||||
ErrorFn: Fn(AuthErrorKind, String) + Send + Sync,
|
||||
{
|
||||
pub on_challenge: ChallengeFn,
|
||||
pub on_verify: VerifyFn,
|
||||
pub on_info: InfoFn,
|
||||
pub on_error: ErrorFn,
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl<ChallengeFn, VerifyFn, InfoFn, ErrorFn> Server
|
||||
for AuthServer<ChallengeFn, VerifyFn, InfoFn, ErrorFn>
|
||||
where
|
||||
ChallengeFn: Fn(Vec<AuthQuestion>, HashMap<String, String>) -> Vec<String> + Send + Sync,
|
||||
VerifyFn: Fn(AuthVerifyKind, String) -> bool + Send + Sync,
|
||||
InfoFn: Fn(String) + Send + Sync,
|
||||
ErrorFn: Fn(AuthErrorKind, String) + Send + Sync,
|
||||
{
|
||||
type Request = Auth;
|
||||
type Response = Auth;
|
||||
type LocalData = RwLock<Option<XChaCha20Poly1305Codec>>;
|
||||
|
||||
async fn on_request(&self, ctx: ServerCtx<Self::Request, Self::Response, Self::LocalData>) {
|
||||
let reply = ctx.reply.clone();
|
||||
|
||||
match ctx.request.payload {
|
||||
Auth::Handshake { public_key, salt } => {
|
||||
trace!(
|
||||
"Received handshake request from client, request id = {}",
|
||||
ctx.request.id
|
||||
);
|
||||
let handshake = Handshake::default();
|
||||
match handshake.handshake(public_key, salt) {
|
||||
Ok(key) => {
|
||||
ctx.local_data
|
||||
.write()
|
||||
.await
|
||||
.replace(XChaCha20Poly1305Codec::new(&key));
|
||||
|
||||
trace!(
|
||||
"Sending reciprocal handshake to client, response origin id = {}",
|
||||
ctx.request.id
|
||||
);
|
||||
if let Err(x) = reply
|
||||
.send(Auth::Handshake {
|
||||
public_key: handshake.pk_bytes(),
|
||||
salt: *handshake.salt(),
|
||||
})
|
||||
.await
|
||||
{
|
||||
error!("[Conn {}] {}", ctx.connection_id, x);
|
||||
}
|
||||
}
|
||||
Err(x) => {
|
||||
error!("[Conn {}] {}", ctx.connection_id, x);
|
||||
return;
|
||||
}
|
||||
}
|
||||
}
|
||||
Auth::Msg {
|
||||
ref encrypted_payload,
|
||||
} => {
|
||||
trace!(
|
||||
"Received auth msg, encrypted payload size = {}",
|
||||
encrypted_payload.len()
|
||||
);
|
||||
|
||||
// Attempt to decrypt the message so we can understand what to do
|
||||
let request = match ctx.local_data.write().await.as_mut() {
|
||||
Some(codec) => {
|
||||
let mut payload = BytesMut::from(encrypted_payload.as_slice());
|
||||
match codec.decode(&mut payload) {
|
||||
Ok(Some(payload)) => {
|
||||
utils::deserialize_from_slice::<AuthRequest>(&payload)
|
||||
}
|
||||
Ok(None) => Err(io::Error::new(
|
||||
io::ErrorKind::InvalidData,
|
||||
"Incomplete message received",
|
||||
)),
|
||||
Err(x) => Err(x),
|
||||
}
|
||||
}
|
||||
None => Err(io::Error::new(
|
||||
io::ErrorKind::Other,
|
||||
"Handshake must be performed first (server decrypt message)",
|
||||
)),
|
||||
};
|
||||
|
||||
let response = match request {
|
||||
Ok(request) => match request {
|
||||
AuthRequest::Challenge { questions, extra } => {
|
||||
trace!("Received challenge request");
|
||||
trace!("questions = {:?}", questions);
|
||||
trace!("extra = {:?}", extra);
|
||||
|
||||
let answers = (self.on_challenge)(questions, extra);
|
||||
AuthResponse::Challenge { answers }
|
||||
}
|
||||
AuthRequest::Verify { kind, text } => {
|
||||
trace!("Received verify request");
|
||||
trace!("kind = {:?}", kind);
|
||||
trace!("text = {:?}", text);
|
||||
|
||||
let valid = (self.on_verify)(kind, text);
|
||||
AuthResponse::Verify { valid }
|
||||
}
|
||||
AuthRequest::Info { text } => {
|
||||
trace!("Received info request");
|
||||
trace!("text = {:?}", text);
|
||||
|
||||
(self.on_info)(text);
|
||||
return;
|
||||
}
|
||||
AuthRequest::Error { kind, text } => {
|
||||
trace!("Received error request");
|
||||
trace!("kind = {:?}", kind);
|
||||
trace!("text = {:?}", text);
|
||||
|
||||
(self.on_error)(kind, text);
|
||||
return;
|
||||
}
|
||||
},
|
||||
Err(x) => {
|
||||
error!("[Conn {}] {}", ctx.connection_id, x);
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
// Serialize and encrypt the message before sending it back
|
||||
let encrypted_payload = match ctx.local_data.write().await.as_mut() {
|
||||
Some(codec) => {
|
||||
let mut encrypted_payload = BytesMut::new();
|
||||
|
||||
// Convert the response into bytes for us to send back
|
||||
match utils::serialize_to_vec(&response) {
|
||||
Ok(bytes) => match codec.encode(&bytes, &mut encrypted_payload) {
|
||||
Ok(_) => Ok(encrypted_payload.freeze().to_vec()),
|
||||
Err(x) => Err(x),
|
||||
},
|
||||
Err(x) => Err(x),
|
||||
}
|
||||
}
|
||||
None => Err(io::Error::new(
|
||||
io::ErrorKind::Other,
|
||||
"Handshake must be performed first (server encrypt messaage)",
|
||||
)),
|
||||
};
|
||||
|
||||
match encrypted_payload {
|
||||
Ok(encrypted_payload) => {
|
||||
if let Err(x) = reply.send(Auth::Msg { encrypted_payload }).await {
|
||||
error!("[Conn {}] {}", ctx.connection_id, x);
|
||||
return;
|
||||
}
|
||||
}
|
||||
Err(x) => {
|
||||
error!("[Conn {}] {}", ctx.connection_id, x);
|
||||
return;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::{
|
||||
IntoSplit, MpscListener, MpscTransport, Request, Response, ServerExt, ServerRef,
|
||||
TypedAsyncRead, TypedAsyncWrite,
|
||||
};
|
||||
use tokio::sync::mpsc;
|
||||
|
||||
const TIMEOUT_MILLIS: u64 = 100;
|
||||
|
||||
#[tokio::test]
|
||||
async fn should_not_reply_if_receive_encrypted_msg_without_handshake_first() {
|
||||
let (mut t, _) = spawn_auth_server(
|
||||
/* on_challenge */ |_, _| Vec::new(),
|
||||
/* on_verify */ |_, _| false,
|
||||
/* on_info */ |_| {},
|
||||
/* on_error */ |_, _| {},
|
||||
)
|
||||
.await
|
||||
.expect("Failed to spawn server");
|
||||
|
||||
// Send an encrypted message before establishing a handshake
|
||||
t.write(Request::new(Auth::Msg {
|
||||
encrypted_payload: Vec::new(),
|
||||
}))
|
||||
.await
|
||||
.expect("Failed to send request to server");
|
||||
|
||||
// Wait for a response, failing if we get one
|
||||
tokio::select! {
|
||||
x = t.read() => panic!("Unexpectedly resolved: {:?}", x),
|
||||
_ = wait_ms(TIMEOUT_MILLIS) => {}
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn should_reply_to_handshake_request_with_new_public_key_and_salt() {
|
||||
let (mut t, _) = spawn_auth_server(
|
||||
/* on_challenge */ |_, _| Vec::new(),
|
||||
/* on_verify */ |_, _| false,
|
||||
/* on_info */ |_| {},
|
||||
/* on_error */ |_, _| {},
|
||||
)
|
||||
.await
|
||||
.expect("Failed to spawn server");
|
||||
|
||||
// Send a handshake
|
||||
let handshake = Handshake::default();
|
||||
t.write(Request::new(Auth::Handshake {
|
||||
public_key: handshake.pk_bytes(),
|
||||
salt: *handshake.salt(),
|
||||
}))
|
||||
.await
|
||||
.expect("Failed to send request to server");
|
||||
|
||||
// Wait for a handshake response
|
||||
tokio::select! {
|
||||
x = t.read() => {
|
||||
let response = x.expect("Request failed").expect("Response missing");
|
||||
match response.payload {
|
||||
Auth::Handshake { .. } => {},
|
||||
Auth::Msg { .. } => panic!("Received unexpected encryped message during handshake"),
|
||||
}
|
||||
}
|
||||
_ = wait_ms(TIMEOUT_MILLIS) => panic!("Ran out of time waiting on response"),
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn should_not_reply_if_receive_invalid_encrypted_msg() {
|
||||
let (mut t, _) = spawn_auth_server(
|
||||
/* on_challenge */ |_, _| Vec::new(),
|
||||
/* on_verify */ |_, _| false,
|
||||
/* on_info */ |_| {},
|
||||
/* on_error */ |_, _| {},
|
||||
)
|
||||
.await
|
||||
.expect("Failed to spawn server");
|
||||
|
||||
// Send a handshake
|
||||
let handshake = Handshake::default();
|
||||
t.write(Request::new(Auth::Handshake {
|
||||
public_key: handshake.pk_bytes(),
|
||||
salt: *handshake.salt(),
|
||||
}))
|
||||
.await
|
||||
.expect("Failed to send request to server");
|
||||
|
||||
// Complete handshake
|
||||
let key = match t.read().await.unwrap().unwrap().payload {
|
||||
Auth::Handshake { public_key, salt } => handshake.handshake(public_key, salt).unwrap(),
|
||||
Auth::Msg { .. } => panic!("Received unexpected encryped message during handshake"),
|
||||
};
|
||||
|
||||
// Send a bad chunk of data
|
||||
let _codec = XChaCha20Poly1305Codec::new(&key);
|
||||
t.write(Request::new(Auth::Msg {
|
||||
encrypted_payload: vec![1, 2, 3, 4],
|
||||
}))
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
// Wait for a response, failing if we get one
|
||||
tokio::select! {
|
||||
x = t.read() => panic!("Unexpectedly resolved: {:?}", x),
|
||||
_ = wait_ms(TIMEOUT_MILLIS) => {}
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn should_invoke_appropriate_function_when_receive_challenge_request_and_reply() {
|
||||
let (tx, mut rx) = mpsc::channel(1);
|
||||
let (mut t, _) = spawn_auth_server(
|
||||
/* on_challenge */
|
||||
move |questions, extra| {
|
||||
tx.try_send((questions, extra)).unwrap();
|
||||
vec!["answer1".to_string(), "answer2".to_string()]
|
||||
},
|
||||
/* on_verify */ |_, _| false,
|
||||
/* on_info */ |_| {},
|
||||
/* on_error */ |_, _| {},
|
||||
)
|
||||
.await
|
||||
.expect("Failed to spawn server");
|
||||
|
||||
// Send a handshake
|
||||
let handshake = Handshake::default();
|
||||
t.write(Request::new(Auth::Handshake {
|
||||
public_key: handshake.pk_bytes(),
|
||||
salt: *handshake.salt(),
|
||||
}))
|
||||
.await
|
||||
.expect("Failed to send request to server");
|
||||
|
||||
// Complete handshake
|
||||
let key = match t.read().await.unwrap().unwrap().payload {
|
||||
Auth::Handshake { public_key, salt } => handshake.handshake(public_key, salt).unwrap(),
|
||||
Auth::Msg { .. } => panic!("Received unexpected encryped message during handshake"),
|
||||
};
|
||||
|
||||
// Send an error request
|
||||
let mut codec = XChaCha20Poly1305Codec::new(&key);
|
||||
t.write(Request::new(Auth::Msg {
|
||||
encrypted_payload: serialize_and_encrypt(
|
||||
&mut codec,
|
||||
&AuthRequest::Challenge {
|
||||
questions: vec![
|
||||
AuthQuestion::new("question1".to_string()),
|
||||
AuthQuestion {
|
||||
text: "question2".to_string(),
|
||||
extra: vec![("key".to_string(), "value".to_string())]
|
||||
.into_iter()
|
||||
.collect(),
|
||||
},
|
||||
],
|
||||
extra: vec![("hello".to_string(), "world".to_string())]
|
||||
.into_iter()
|
||||
.collect(),
|
||||
},
|
||||
)
|
||||
.unwrap(),
|
||||
}))
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
// Verify that the handler was triggered
|
||||
let (questions, extra) = rx.recv().await.expect("Channel closed unexpectedly");
|
||||
assert_eq!(
|
||||
questions,
|
||||
vec![
|
||||
AuthQuestion::new("question1".to_string()),
|
||||
AuthQuestion {
|
||||
text: "question2".to_string(),
|
||||
extra: vec![("key".to_string(), "value".to_string())]
|
||||
.into_iter()
|
||||
.collect(),
|
||||
}
|
||||
]
|
||||
);
|
||||
assert_eq!(
|
||||
extra,
|
||||
vec![("hello".to_string(), "world".to_string())]
|
||||
.into_iter()
|
||||
.collect()
|
||||
);
|
||||
|
||||
// Wait for a response and verify that it matches what we expect
|
||||
tokio::select! {
|
||||
x = t.read() => {
|
||||
let response = x.expect("Request failed").expect("Response missing");
|
||||
match response.payload {
|
||||
Auth::Handshake { .. } => panic!("Received unexpected handshake"),
|
||||
Auth::Msg { encrypted_payload } => {
|
||||
match decrypt_and_deserialize(&mut codec, &encrypted_payload).unwrap() {
|
||||
AuthResponse::Challenge { answers } =>
|
||||
assert_eq!(
|
||||
answers,
|
||||
vec!["answer1".to_string(), "answer2".to_string()]
|
||||
),
|
||||
_ => panic!("Got wrong response for verify"),
|
||||
}
|
||||
},
|
||||
}
|
||||
}
|
||||
_ = wait_ms(TIMEOUT_MILLIS) => {}
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn should_invoke_appropriate_function_when_receive_verify_request_and_reply() {
|
||||
let (tx, mut rx) = mpsc::channel(1);
|
||||
let (mut t, _) = spawn_auth_server(
|
||||
/* on_challenge */ |_, _| Vec::new(),
|
||||
/* on_verify */
|
||||
move |kind, text| {
|
||||
tx.try_send((kind, text)).unwrap();
|
||||
true
|
||||
},
|
||||
/* on_info */ |_| {},
|
||||
/* on_error */ |_, _| {},
|
||||
)
|
||||
.await
|
||||
.expect("Failed to spawn server");
|
||||
|
||||
// Send a handshake
|
||||
let handshake = Handshake::default();
|
||||
t.write(Request::new(Auth::Handshake {
|
||||
public_key: handshake.pk_bytes(),
|
||||
salt: *handshake.salt(),
|
||||
}))
|
||||
.await
|
||||
.expect("Failed to send request to server");
|
||||
|
||||
// Complete handshake
|
||||
let key = match t.read().await.unwrap().unwrap().payload {
|
||||
Auth::Handshake { public_key, salt } => handshake.handshake(public_key, salt).unwrap(),
|
||||
Auth::Msg { .. } => panic!("Received unexpected encryped message during handshake"),
|
||||
};
|
||||
|
||||
// Send an error request
|
||||
let mut codec = XChaCha20Poly1305Codec::new(&key);
|
||||
t.write(Request::new(Auth::Msg {
|
||||
encrypted_payload: serialize_and_encrypt(
|
||||
&mut codec,
|
||||
&AuthRequest::Verify {
|
||||
kind: AuthVerifyKind::Host,
|
||||
text: "some text".to_string(),
|
||||
},
|
||||
)
|
||||
.unwrap(),
|
||||
}))
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
// Verify that the handler was triggered
|
||||
let (kind, text) = rx.recv().await.expect("Channel closed unexpectedly");
|
||||
assert_eq!(kind, AuthVerifyKind::Host);
|
||||
assert_eq!(text, "some text");
|
||||
|
||||
// Wait for a response and verify that it matches what we expect
|
||||
tokio::select! {
|
||||
x = t.read() => {
|
||||
let response = x.expect("Request failed").expect("Response missing");
|
||||
match response.payload {
|
||||
Auth::Handshake { .. } => panic!("Received unexpected handshake"),
|
||||
Auth::Msg { encrypted_payload } => {
|
||||
match decrypt_and_deserialize(&mut codec, &encrypted_payload).unwrap() {
|
||||
AuthResponse::Verify { valid } =>
|
||||
assert!(valid, "Got verify, but valid was wrong"),
|
||||
_ => panic!("Got wrong response for verify"),
|
||||
}
|
||||
},
|
||||
}
|
||||
}
|
||||
_ = wait_ms(TIMEOUT_MILLIS) => {}
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn should_invoke_appropriate_function_when_receive_info_request() {
|
||||
let (tx, mut rx) = mpsc::channel(1);
|
||||
let (mut t, _) = spawn_auth_server(
|
||||
/* on_challenge */ |_, _| Vec::new(),
|
||||
/* on_verify */ |_, _| false,
|
||||
/* on_info */
|
||||
move |text| {
|
||||
tx.try_send(text).unwrap();
|
||||
},
|
||||
/* on_error */ |_, _| {},
|
||||
)
|
||||
.await
|
||||
.expect("Failed to spawn server");
|
||||
|
||||
// Send a handshake
|
||||
let handshake = Handshake::default();
|
||||
t.write(Request::new(Auth::Handshake {
|
||||
public_key: handshake.pk_bytes(),
|
||||
salt: *handshake.salt(),
|
||||
}))
|
||||
.await
|
||||
.expect("Failed to send request to server");
|
||||
|
||||
// Complete handshake
|
||||
let key = match t.read().await.unwrap().unwrap().payload {
|
||||
Auth::Handshake { public_key, salt } => handshake.handshake(public_key, salt).unwrap(),
|
||||
Auth::Msg { .. } => panic!("Received unexpected encryped message during handshake"),
|
||||
};
|
||||
|
||||
// Send an error request
|
||||
let mut codec = XChaCha20Poly1305Codec::new(&key);
|
||||
t.write(Request::new(Auth::Msg {
|
||||
encrypted_payload: serialize_and_encrypt(
|
||||
&mut codec,
|
||||
&AuthRequest::Info {
|
||||
text: "some text".to_string(),
|
||||
},
|
||||
)
|
||||
.unwrap(),
|
||||
}))
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
// Verify that the handler was triggered
|
||||
let text = rx.recv().await.expect("Channel closed unexpectedly");
|
||||
assert_eq!(text, "some text");
|
||||
|
||||
// Wait for a response, failing if we get one
|
||||
tokio::select! {
|
||||
x = t.read() => panic!("Unexpectedly resolved: {:?}", x),
|
||||
_ = wait_ms(TIMEOUT_MILLIS) => {}
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn should_invoke_appropriate_function_when_receive_error_request() {
|
||||
let (tx, mut rx) = mpsc::channel(1);
|
||||
let (mut t, _) = spawn_auth_server(
|
||||
/* on_challenge */ |_, _| Vec::new(),
|
||||
/* on_verify */ |_, _| false,
|
||||
/* on_info */ |_| {},
|
||||
/* on_error */
|
||||
move |kind, text| {
|
||||
tx.try_send((kind, text)).unwrap();
|
||||
},
|
||||
)
|
||||
.await
|
||||
.expect("Failed to spawn server");
|
||||
|
||||
// Send a handshake
|
||||
let handshake = Handshake::default();
|
||||
t.write(Request::new(Auth::Handshake {
|
||||
public_key: handshake.pk_bytes(),
|
||||
salt: *handshake.salt(),
|
||||
}))
|
||||
.await
|
||||
.expect("Failed to send request to server");
|
||||
|
||||
// Complete handshake
|
||||
let key = match t.read().await.unwrap().unwrap().payload {
|
||||
Auth::Handshake { public_key, salt } => handshake.handshake(public_key, salt).unwrap(),
|
||||
Auth::Msg { .. } => panic!("Received unexpected encryped message during handshake"),
|
||||
};
|
||||
|
||||
// Send an error request
|
||||
let mut codec = XChaCha20Poly1305Codec::new(&key);
|
||||
t.write(Request::new(Auth::Msg {
|
||||
encrypted_payload: serialize_and_encrypt(
|
||||
&mut codec,
|
||||
&AuthRequest::Error {
|
||||
kind: AuthErrorKind::FailedChallenge,
|
||||
text: "some text".to_string(),
|
||||
},
|
||||
)
|
||||
.unwrap(),
|
||||
}))
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
// Verify that the handler was triggered
|
||||
let (kind, text) = rx.recv().await.expect("Channel closed unexpectedly");
|
||||
assert_eq!(kind, AuthErrorKind::FailedChallenge);
|
||||
assert_eq!(text, "some text");
|
||||
|
||||
// Wait for a response, failing if we get one
|
||||
tokio::select! {
|
||||
x = t.read() => panic!("Unexpectedly resolved: {:?}", x),
|
||||
_ = wait_ms(TIMEOUT_MILLIS) => {}
|
||||
}
|
||||
}
|
||||
|
||||
async fn wait_ms(ms: u64) {
|
||||
use std::time::Duration;
|
||||
tokio::time::sleep(Duration::from_millis(ms)).await;
|
||||
}
|
||||
|
||||
fn serialize_and_encrypt(
|
||||
codec: &mut XChaCha20Poly1305Codec,
|
||||
payload: &AuthRequest,
|
||||
) -> io::Result<Vec<u8>> {
|
||||
let mut encryped_payload = BytesMut::new();
|
||||
let payload = utils::serialize_to_vec(payload)?;
|
||||
codec.encode(&payload, &mut encryped_payload)?;
|
||||
Ok(encryped_payload.freeze().to_vec())
|
||||
}
|
||||
|
||||
fn decrypt_and_deserialize(
|
||||
codec: &mut XChaCha20Poly1305Codec,
|
||||
payload: &[u8],
|
||||
) -> io::Result<AuthResponse> {
|
||||
let mut payload = BytesMut::from(payload);
|
||||
match codec.decode(&mut payload)? {
|
||||
Some(payload) => utils::deserialize_from_slice::<AuthResponse>(&payload),
|
||||
None => Err(io::Error::new(
|
||||
io::ErrorKind::InvalidData,
|
||||
"Incomplete message received",
|
||||
)),
|
||||
}
|
||||
}
|
||||
|
||||
async fn spawn_auth_server<ChallengeFn, VerifyFn, InfoFn, ErrorFn>(
|
||||
on_challenge: ChallengeFn,
|
||||
on_verify: VerifyFn,
|
||||
on_info: InfoFn,
|
||||
on_error: ErrorFn,
|
||||
) -> io::Result<(
|
||||
MpscTransport<Request<Auth>, Response<Auth>>,
|
||||
Box<dyn ServerRef>,
|
||||
)>
|
||||
where
|
||||
ChallengeFn:
|
||||
Fn(Vec<AuthQuestion>, HashMap<String, String>) -> Vec<String> + Send + Sync + 'static,
|
||||
VerifyFn: Fn(AuthVerifyKind, String) -> bool + Send + Sync + 'static,
|
||||
InfoFn: Fn(String) + Send + Sync + 'static,
|
||||
ErrorFn: Fn(AuthErrorKind, String) + Send + Sync + 'static,
|
||||
{
|
||||
let server = AuthServer {
|
||||
on_challenge,
|
||||
on_verify,
|
||||
on_info,
|
||||
on_error,
|
||||
};
|
||||
|
||||
// Create a test listener where we will forward a connection
|
||||
let (tx, listener) = MpscListener::channel(100);
|
||||
|
||||
// Make bounded transport pair and send off one of them to act as our connection
|
||||
let (transport, connection) = MpscTransport::<Request<Auth>, Response<Auth>>::pair(100);
|
||||
tx.send(connection.into_split())
|
||||
.await
|
||||
.expect("Failed to feed listener a connection");
|
||||
|
||||
let server = server.start(listener)?;
|
||||
Ok((transport, server))
|
||||
}
|
||||
}
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue