diff --git a/Cargo.lock b/Cargo.lock index 0c1c903e..1a2b534f 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2202,11 +2202,14 @@ dependencies = [ "serde_json", "sha2 0.10.9", "shared_child", + "size-parser", "smallvec", "tokio", "tokio-rustls", + "tokio-util", "tracing", "tracing-subscriber", + "yamux", ] [[package]] @@ -3071,6 +3074,12 @@ version = "0.3.31" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f90f7dce0722e95104fcb095585910c0977252f286e354b5e3bd38902cd99988" +[[package]] +name = "futures-timer" +version = "3.0.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f288b0a4f20f9a56b5d1da57e2227c661b7b16168e2f72365f57b63326e29b24" + [[package]] name = "futures-util" version = "0.3.31" @@ -4573,6 +4582,12 @@ dependencies = [ "dstack-sdk-types", ] +[[package]] +name = "nohash-hasher" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2bf50223579dc7cdcfb3bfcacf7069ff68243f8c363f62ffa99cf000a6b9c451" + [[package]] name = "nom" version = "7.1.3" @@ -7552,6 +7567,7 @@ checksum = "2efa149fe76073d6e8fd97ef4f4eca7b67f599660115591483572e406e165594" dependencies = [ "bytes", "futures-core", + "futures-io", "futures-sink", "pin-project-lite", "tokio", ] @@ -7785,6 +7801,20 @@ version = "0.1.7" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "2896d95c02a80c6d6a5d6e953d479f5ddf2dfdb6a244441010e373ac0fb88971" +[[package]] +name = "uds2yamux" +version = "0.5.6" +dependencies = [ + "anyhow", + "clap", + "futures", + "tokio", + "tokio-util", + "tracing", + "tracing-subscriber", + "yamux", +] + [[package]] name = "uint" version = "0.9.5" @@ -8653,6 +8683,22 @@ dependencies = [ "hashlink", ] +[[package]] +name = "yamux" +version = "0.13.8" +source = "git+https://github.com/kvinwang/rust-yamux?branch=feat%2Fping-timeout#096e208c29ae9b2c68c4ec2d46589855d8cbc3eb" +dependencies = [ + "futures", + "futures-timer", + "log", + "nohash-hasher", + "parking_lot 0.12.5", + "pin-project", + "rand 0.9.2", + "static_assertions", + "web-time", +] + [[package]] name = "yansi" version = "1.0.1" diff --git a/Cargo.toml b/Cargo.toml index 218c01a7..33c05913 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -52,6 +52,7 @@ members = [ "verifier", "no_std_check", "size-parser", + "uds2yamux", ] resolver = "2" @@ -159,7 +160,9 @@ rocket = { git = "https://github.com/rwf2/Rocket", branch = "master", features = [ ] } rocket-apitoken = { git = "https://github.com/kvinwang/rocket-apitoken", branch = "dev" } tokio = { version = "1.46.1" } +tokio-util = { version = "0.7", features = ["compat"] } tokio-vsock = "0.7.0" +yamux = { git = "https://github.com/kvinwang/rust-yamux", branch = "feat/ping-timeout" } sysinfo = "0.35.2" default-net = "0.22.0" url = "2.5" diff --git a/gateway/Cargo.toml b/gateway/Cargo.toml index 8a30c6da..32cb81a8 100644 --- a/gateway/Cargo.toml +++ b/gateway/Cargo.toml @@ -20,6 +20,7 @@ fs-err.workspace = true clap = { workspace = true, features = ["derive", "string"] } shared_child.workspace = true tokio = { workspace = true, features = ["full"] } +tokio-util = { workspace = true, features = ["compat"] } rustls.workspace = true tokio-rustls = { workspace = true, features = ["ring"] } rinja.workspace = true @@ -37,6 +38,7 @@ bytes.workspace = true safe-write.workspace = true smallvec.workspace = true futures.workspace = true +yamux.workspace = true
cmd_lib.workspace = true load_config.workspace = true dstack-kms-rpc.workspace = true @@ -46,6 +48,7 @@ http-client = { workspace = true, features = ["prpc"] } sha2.workspace = true dstack-types.workspace = true serde-duration.workspace = true +size-parser = { workspace = true, features = ["serde"] } reqwest = { workspace = true, features = ["json"] } hyper = { workspace = true, features = ["server", "http1"] } hyper-util = { version = "0.1", features = ["tokio"] } diff --git a/gateway/dstack-app/builder/entrypoint.sh b/gateway/dstack-app/builder/entrypoint.sh index 9cd46755..f2112cf1 100755 --- a/gateway/dstack-app/builder/entrypoint.sh +++ b/gateway/dstack-app/builder/entrypoint.sh @@ -150,6 +150,16 @@ write = "5s" shutdown = "5s" total = "5h" +[core.proxy.yamux] +listen_addr = "${YAMUX_LISTEN_ADDR:-0.0.0.0}" +listen_port = ${YAMUX_LISTEN_PORT:-0} +## Set to 0 to disable max connection receive window limit. +max_connection_receive_window = ${YAMUX_MAX_CONNECTION_RECEIVE_WINDOW:-1073741824} +max_num_streams = ${YAMUX_MAX_NUM_STREAMS:-512} +read_after_close = ${YAMUX_READ_AFTER_CLOSE:-true} +split_send_size = ${YAMUX_SPLIT_SEND_SIZE:-16384} +ping_timeout = "${YAMUX_PING_TIMEOUT:-15s}" + [core.recycle] enabled = true interval = "5m" diff --git a/gateway/gateway.toml b/gateway/gateway.toml index d3e5816b..fc703015 100644 --- a/gateway/gateway.toml +++ b/gateway/gateway.toml @@ -71,6 +71,16 @@ app_address_ns_compat = true workers = 32 external_port = 443 +[core.proxy.yamux] +listen_addr = "0.0.0.0" +listen_port = 4433 +# max_connection_receive_window = 0 to disable limit +max_connection_receive_window = "1G" +max_num_streams = 4096 +read_after_close = true +split_send_size = "16K" +ping_timeout = "15s" + [core.proxy.timeouts] # Timeout for establishing a connection to the target app. connect = "5s" diff --git a/gateway/src/config.rs b/gateway/src/config.rs index 9c81ca8e..b22412ee 100644 --- a/gateway/src/config.rs +++ b/gateway/src/config.rs @@ -85,6 +85,8 @@ pub struct ProxyConfig { pub workers: usize, pub app_address_ns_prefix: String, pub app_address_ns_compat: bool, + #[serde(default)] + pub yamux: YamuxConfig, } #[derive(Debug, Clone, Deserialize)] @@ -112,6 +114,35 @@ pub struct Timeouts { pub shutdown: Duration, } +#[derive(Debug, Clone, Deserialize)] +pub struct YamuxConfig { + pub listen_addr: Ipv4Addr, + pub listen_port: Option<u16>, + #[serde(with = "size_parser::human_size")] + pub max_connection_receive_window: usize, + pub max_num_streams: usize, + pub read_after_close: bool, + #[serde(with = "size_parser::human_size")] + pub split_send_size: usize, + /// Ping timeout for yamux connections. Set to 0s to disable.
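+ /// A non-zero value here is handed to yamux::Config::set_ping_timeout when the proxy sets up its yamux listener; a zero duration skips that call and leaves the ping timeout unset.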
+ #[serde(with = "serde_duration")] + pub ping_timeout: Duration, +} + +impl Default for YamuxConfig { + fn default() -> Self { + Self { + listen_addr: Ipv4Addr::new(0, 0, 0, 0), + listen_port: None, + max_connection_receive_window: 1024 * 1024 * 1024, + max_num_streams: 4096, + read_after_close: true, + split_send_size: 16 * 1024, + ping_timeout: Duration::from_secs(15), + } + } +} + #[derive(Debug, Clone, Deserialize, Serialize)] pub struct RecycleConfig { pub enabled: bool, diff --git a/gateway/src/proxy.rs b/gateway/src/proxy.rs index 73b947cc..ef7595f4 100644 --- a/gateway/src/proxy.rs +++ b/gateway/src/proxy.rs @@ -8,20 +8,91 @@ use std::{ atomic::{AtomicU64, AtomicUsize, Ordering}, Arc, }, + time::Duration, }; +use std::io; +use std::pin::Pin; +use std::task::{Context as TaskContext, Poll}; + use anyhow::{bail, Context, Result}; use sni::extract_sni; pub(crate) use tls_terminate::create_acceptor; use tokio::{ - io::AsyncReadExt, + io::{AsyncRead, AsyncReadExt, AsyncWrite, ReadBuf}, net::{TcpListener, TcpStream}, time::timeout, }; +use tokio_util::compat::{FuturesAsyncReadCompatExt, TokioAsyncReadCompatExt}; use tracing::{debug, error, info, info_span, Instrument}; use crate::{config::ProxyConfig, main_service::Proxy, models::EnteredCounter}; +/// Abstraction over inbound connection types (TCP or yamux stream). +#[pin_project::pin_project(project = InboundStreamProj)] +pub(crate) enum InboundStream { + Tcp(#[pin] TcpStream), + Yamux(#[pin] tokio_util::compat::Compat<yamux::Stream>), +} + +impl AsyncRead for InboundStream { + fn poll_read( + self: Pin<&mut Self>, + cx: &mut TaskContext<'_>, + buf: &mut ReadBuf<'_>, + ) -> Poll<io::Result<()>> { + match self.project() { + InboundStreamProj::Tcp(s) => s.poll_read(cx, buf), + InboundStreamProj::Yamux(s) => s.poll_read(cx, buf), + } + } +} + +impl AsyncWrite for InboundStream { + fn poll_write( + self: Pin<&mut Self>, + cx: &mut TaskContext<'_>, + buf: &[u8], + ) -> Poll<io::Result<usize>> { + match self.project() { + InboundStreamProj::Tcp(s) => s.poll_write(cx, buf), + InboundStreamProj::Yamux(s) => s.poll_write(cx, buf), + } + } + + fn poll_flush(self: Pin<&mut Self>, cx: &mut TaskContext<'_>) -> Poll<io::Result<()>> { + match self.project() { + InboundStreamProj::Tcp(s) => s.poll_flush(cx), + InboundStreamProj::Yamux(s) => s.poll_flush(cx), + } + } + + fn poll_shutdown(self: Pin<&mut Self>, cx: &mut TaskContext<'_>) -> Poll<io::Result<()>> { + match self.project() { + InboundStreamProj::Tcp(s) => s.poll_shutdown(cx), + InboundStreamProj::Yamux(s) => s.poll_shutdown(cx), + } + } + + fn poll_write_vectored( + self: Pin<&mut Self>, + cx: &mut TaskContext<'_>, + bufs: &[io::IoSlice<'_>], + ) -> Poll<io::Result<usize>> { + match self.project() { + InboundStreamProj::Tcp(s) => s.poll_write_vectored(cx, bufs), + InboundStreamProj::Yamux(s) => s.poll_write_vectored(cx, bufs), + } + } + + fn is_write_vectored(&self) -> bool { + match self { + InboundStream::Tcp(s) => s.is_write_vectored(), + InboundStream::Yamux(s) => s.is_write_vectored(), + } + } +} + #[derive(Debug, Clone)] pub(crate) struct AddressInfo { pub ip: Ipv4Addr, @@ -35,7 +106,7 @@ mod sni; mod tls_passthough; mod tls_terminate; -async fn take_sni(stream: &mut TcpStream) -> Result<(Option<String>, Vec<u8>)> { +async fn take_sni(stream: &mut (impl AsyncRead + Unpin)) -> Result<(Option<String>, Vec<u8>)> { let mut buffer = vec![0u8; 4096]; let mut data_len = 0; loop { @@ -132,7 +203,7 @@ fn parse_destination(sni: &str, dotted_base_domain: &str) -> Result { pub static NUM_CONNECTIONS: AtomicU64 = AtomicU64::new(0); async fn handle_connection( - mut inbound: TcpStream, + mut inbound: InboundStream,
state: Proxy, dotted_base_domain: &str, ) -> Result<()> { @@ -173,7 +244,7 @@ pub async fn proxy_main(config: &ProxyConfig, proxy: Proxy) -> Result<()> { let base_domain = base_domain.strip_prefix(".").unwrap_or(base_domain); Arc::new(format!(".{base_domain}")) }; - let listener = TcpListener::bind((config.listen_addr, config.listen_port)) + let tcp_listener = TcpListener::bind((config.listen_addr, config.listen_port)) .await .with_context(|| { format!( @@ -186,17 +257,147 @@ pub async fn proxy_main(config: &ProxyConfig, proxy: Proxy) -> Result<()> { config.listen_addr, config.listen_port ); + let yamux_cfg = &config.yamux; + let mut yamux_config = yamux::Config::default(); + let max_conn_window = if yamux_cfg.max_connection_receive_window == 0 { + None + } else { + Some(yamux_cfg.max_connection_receive_window) + }; + yamux_config.set_max_connection_receive_window(max_conn_window); + yamux_config.set_max_num_streams(yamux_cfg.max_num_streams); + yamux_config.set_read_after_close(yamux_cfg.read_after_close); + yamux_config.set_split_send_size(yamux_cfg.split_send_size); + if yamux_cfg.ping_timeout != Duration::ZERO { + yamux_config.set_ping_timeout(Some(yamux_cfg.ping_timeout)); + } + + let yamux_listener = if let Some(yamux_port) = yamux_cfg.listen_port.filter(|p| *p != 0) { + let listener = TcpListener::bind((yamux_cfg.listen_addr, yamux_port)) + .await + .with_context(|| format!("failed to bind yamux {}:{yamux_port}", yamux_cfg.listen_addr))?; + info!( + "yamux bridge listening on {}:{}", + yamux_cfg.listen_addr, yamux_port + ); + Some(listener) + } else { + None + }; + + loop { + if let Some(ref yamux_listener) = yamux_listener { + tokio::select! { + result = tcp_listener.accept() => { + match result { + Ok((stream, from)) => { + spawn_connection( + &workers_rt, + InboundStream::Tcp(stream), + from.to_string(), + proxy.clone(), + dotted_base_domain.clone(), + ); + } + Err(e) => error!("failed to accept tcp connection: {e:?}"), + } + } + result = yamux_listener.accept() => { + match result { + Ok((stream, from)) => { + let proxy = proxy.clone(); + let dotted_base_domain = dotted_base_domain.clone(); + let handle = workers_rt.handle().clone(); + let yamux_config = yamux_config.clone(); + tokio::spawn(async move { + let from = format!("yamux:{from}"); + info!(%from, "new yamux connection"); + handle_yamux_connection( + stream, yamux_config, &handle, proxy, dotted_base_domain, + ).await; + }); + } + Err(e) => error!("failed to accept yamux connection: {e:?}"), + } + } + } + } else { + match tcp_listener.accept().await { + Ok((stream, from)) => { + spawn_connection( + &workers_rt, + InboundStream::Tcp(stream), + from.to_string(), + proxy.clone(), + dotted_base_domain.clone(), + ); + } + Err(e) => error!("failed to accept tcp connection: {e:?}"), + } + }; + } +} + +/// Spawn a single inbound connection handler on the worker runtime. 
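+/// Raw TCP connections and individual yamux streams both land here, so each one gets its own conn span, its own NUM_CONNECTIONS counter entry, and the same timeouts.total deadline.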
+fn spawn_connection( + workers_rt: &tokio::runtime::Runtime, + inbound: InboundStream, + from: String, + proxy: Proxy, + dotted_base_domain: Arc<String>, +) { + let span = info_span!("conn", id = next_connection_id()); + let _enter = span.enter(); + let conn_entered = EnteredCounter::new(&NUM_CONNECTIONS); + + info!(%from, "new connection"); + workers_rt.spawn( + async move { + let _conn_entered = conn_entered; + let timeouts = &proxy.config.proxy.timeouts; + let result = timeout( + timeouts.total, + handle_connection(inbound, proxy, &dotted_base_domain), + ) + .await; + match result { + Ok(Ok(_)) => { + info!("connection closed"); + } + Ok(Err(e)) => { + error!("connection error: {e:?}"); + } + Err(_) => { + error!("connection kept too long, force closing"); + } + } + } + .in_current_span(), + ); +} + +/// Handle a yamux connection. +/// Each accepted yamux stream becomes an independent inbound connection. +async fn handle_yamux_connection( + tcp_stream: TcpStream, + yamux_config: yamux::Config, + workers_handle: &tokio::runtime::Handle, + proxy: Proxy, + dotted_base_domain: Arc<String>, +) { + let mut conn = yamux::Connection::new(tcp_stream.compat(), yamux_config, yamux::Mode::Server); loop { - match listener.accept().await { - Ok((inbound, from)) => { + match std::future::poll_fn(|cx| conn.poll_next_inbound(cx)).await { + Some(Ok(stream)) => { + let inbound = InboundStream::Yamux(stream.compat()); let span = info_span!("conn", id = next_connection_id()); let _enter = span.enter(); let conn_entered = EnteredCounter::new(&NUM_CONNECTIONS); - info!(%from, "new connection"); + debug!("new yamux stream"); let proxy = proxy.clone(); let dotted_base_domain = dotted_base_domain.clone(); - workers_rt.spawn( + workers_handle.spawn( async move { let _conn_entered = conn_entered; let timeouts = &proxy.config.proxy.timeouts; @@ -220,8 +421,13 @@ pub async fn proxy_main(config: &ProxyConfig, proxy: Proxy) -> Result<()> { .in_current_span(), ); } - Err(e) => { - error!("failed to accept connection: {e:?}"); + Some(Err(e)) => { + error!("yamux accept error: {e}"); + break; + } + None => { + info!("yamux connection closed by peer"); + break; } } } @@ -257,6 +463,7 @@ pub fn start(config: ProxyConfig, app_state: Proxy) -> Result<()> { #[cfg(test)] mod tests { use super::*; + use tokio::io::AsyncWriteExt as _; #[test] fn test_parse_destination() { @@ -320,4 +527,88 @@ mod tests { assert!(parse_destination("-8080.example.com", base_domain).is_err()); assert!(parse_destination("myapp-8080ss.example.com", base_domain).is_err()); } + + #[tokio::test] + async fn test_inbound_stream_tcp_read_write() { + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let addr = listener.local_addr().unwrap(); + + let client_handle = tokio::spawn(async move { + let stream = TcpStream::connect(addr).await.unwrap(); + let mut inbound = InboundStream::Tcp(stream); + inbound.write_all(b"hello vsock").await.unwrap(); + inbound.flush().await.unwrap(); + let mut buf = vec![0u8; 64]; + let n = inbound.read(&mut buf).await.unwrap(); + String::from_utf8(buf[..n].to_vec()).unwrap() + }); + + let (server_stream, _) = listener.accept().await.unwrap(); + let mut server = server_stream; + let mut buf = vec![0u8; 64]; + let n = server.read(&mut buf).await.unwrap(); + assert_eq!(&buf[..n], b"hello vsock"); + server.write_all(b"echo back").await.unwrap(); + server.shutdown().await.unwrap(); + + let response = client_handle.await.unwrap(); + assert_eq!(response, "echo back"); + } + + #[tokio::test]
+ async fn test_take_sni_with_inbound_stream() { + // Verify take_sni works with InboundStream (via impl AsyncRead + Unpin) + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let addr = listener.local_addr().unwrap(); + + // Send a TLS ClientHello with SNI "test.example.com" + let client_handle = tokio::spawn(async move { + let mut stream = TcpStream::connect(addr).await.unwrap(); + // Minimal TLS ClientHello with SNI extension + // Record header: ContentType=22 (Handshake), Version=0x0301, Length + // Handshake: Type=1 (ClientHello) + let sni_hostname = b"test.example.com"; + let _sni_entry_len = (sni_hostname.len() + 3) as u16; + let sni_list_len = (sni_hostname.len() + 5) as u16; + let ext_data_len = (sni_hostname.len() + 7) as u16; + let extensions_len = ext_data_len + 4; // type(2) + len(2) + data + let client_hello_body_len = 2 + 32 + 1 + 2 + 1 + 2 + extensions_len; + let handshake_len = client_hello_body_len + 4; // type(1) + len(3) + let mut hello = Vec::new(); + // TLS record header + hello.push(0x16); // ContentType: Handshake + hello.extend_from_slice(&[0x03, 0x01]); // Version: TLS 1.0 + hello.extend_from_slice(&(handshake_len as u16).to_be_bytes()); + // Handshake header + hello.push(0x01); // ClientHello + hello.push(0x00); + hello.extend_from_slice(&(client_hello_body_len as u16).to_be_bytes()); + // ClientHello body + hello.extend_from_slice(&[0x03, 0x03]); // Version: TLS 1.2 + hello.extend_from_slice(&[0u8; 32]); // Random + hello.push(0x00); // Session ID length + hello.extend_from_slice(&[0x00, 0x02]); // Cipher suites length + hello.extend_from_slice(&[0x00, 0x2f]); // TLS_RSA_WITH_AES_128_CBC_SHA + hello.push(0x01); // Compression methods length + hello.push(0x00); // null compression + hello.extend_from_slice(&extensions_len.to_be_bytes()); + // SNI extension + hello.extend_from_slice(&[0x00, 0x00]); // Extension type: SNI + hello.extend_from_slice(&ext_data_len.to_be_bytes()); + hello.extend_from_slice(&sni_list_len.to_be_bytes()); + hello.push(0x00); // Host name type + hello.extend_from_slice(&(sni_hostname.len() as u16).to_be_bytes()); + hello.extend_from_slice(sni_hostname); + + stream.write_all(&hello).await.unwrap(); + stream.shutdown().await.unwrap(); + }); + + let (server_stream, _) = listener.accept().await.unwrap(); + let mut inbound = InboundStream::Tcp(server_stream); + let (sni, _buffer) = take_sni(&mut inbound).await.unwrap(); + assert_eq!(sni.as_deref(), Some("test.example.com")); + + client_handle.await.unwrap(); + } } diff --git a/gateway/src/proxy/tls_passthough.rs b/gateway/src/proxy/tls_passthough.rs index 6184c1b5..73176b48 100644 --- a/gateway/src/proxy/tls_passthough.rs +++ b/gateway/src/proxy/tls_passthough.rs @@ -12,7 +12,7 @@ use crate::{ models::{Counting, EnteredCounter}, }; -use super::{io_bridge::bridge, AddressGroup}; +use super::{io_bridge::bridge, AddressGroup, InboundStream}; #[derive(Debug)] struct AppAddress { @@ -73,7 +73,7 @@ async fn resolve_app_address(prefix: &str, sni: &str, compat: bool) -> Result<AppAddress> pub(crate) async fn proxy_with_sni( state: Proxy, - inbound: TcpStream, + inbound: InboundStream, buffer: Vec<u8>, sni: &str, ) -> Result<()> { @@ -121,7 +121,7 @@ pub(crate) async fn connect_multiple_hosts( pub(crate) async fn proxy_to_app( state: Proxy, - inbound: TcpStream, + inbound: InboundStream, buffer: Vec<u8>, app_id: &str, port: u16, diff --git a/gateway/src/proxy/tls_terminate.rs b/gateway/src/proxy/tls_terminate.rs index ad19ebf4..b056bbea 100644 --- a/gateway/src/proxy/tls_terminate.rs +++ b/gateway/src/proxy/tls_terminate.rs @@ -7,6 +7,7 @@ use std::pin::Pin; use std::sync::Arc; use std::task::{Context, Poll}; +use super::InboundStream;
use anyhow::{anyhow, bail, Context as _, Result}; use fs_err as fs; use hyper::body::Incoming; @@ -19,7 +20,6 @@ use rustls::pki_types::{CertificateDer, PrivateKeyDer}; use rustls::version::{TLS12, TLS13}; use serde::Serialize; use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; -use tokio::net::TcpStream; use tokio::time::timeout; use tokio_rustls::{rustls, server::TlsStream, TlsAcceptor}; use tracing::{debug, info}; @@ -175,7 +175,7 @@ impl Proxy { pub(crate) async fn handle_this_node( &self, - inbound: TcpStream, + inbound: InboundStream, buffer: Vec<u8>, port: u16, h2: bool, @@ -231,7 +231,7 @@ impl Proxy { /// Deprecated legacy endpoint pub(crate) async fn handle_health_check( &self, - inbound: TcpStream, + inbound: InboundStream, buffer: Vec<u8>, port: u16, h2: bool, @@ -268,10 +268,10 @@ impl Proxy { async fn tls_accept( &self, - inbound: TcpStream, + inbound: InboundStream, buffer: Vec<u8>, h2: bool, - ) -> Result<TlsStream<MergedStream>> { + ) -> Result<TlsStream<MergedStream<InboundStream>>> { let stream = MergedStream { buffer, buffer_cursor: 0, @@ -300,7 +300,7 @@ impl Proxy { pub(crate) async fn proxy( &self, - inbound: TcpStream, + inbound: InboundStream, buffer: Vec<u8>, app_id: &str, port: u16, @@ -337,14 +337,14 @@ impl Proxy { } #[pin_project::pin_project] -struct MergedStream { +struct MergedStream<S> { buffer: Vec<u8>, buffer_cursor: usize, #[pin] - inbound: TcpStream, + inbound: S, } -impl AsyncRead for MergedStream { +impl<S: AsyncRead> AsyncRead for MergedStream<S> { fn poll_read( self: Pin<&mut Self>, cx: &mut Context<'_>, @@ -366,7 +366,7 @@ impl AsyncRead for MergedStream { this.inbound.poll_read(cx, buf) } } -impl AsyncWrite for MergedStream { +impl<S: AsyncWrite> AsyncWrite for MergedStream<S> { fn poll_write( self: std::pin::Pin<&mut Self>, cx: &mut Context<'_>, diff --git a/gateway/tests/stress_connections.rs b/gateway/tests/stress_connections.rs new file mode 100644 index 00000000..00c1a893 --- /dev/null +++ b/gateway/tests/stress_connections.rs @@ -0,0 +1,133 @@ +// SPDX-FileCopyrightText: © 2024-2025 Phala Network +// +// SPDX-License-Identifier: Apache-2.0 + +//! Stress test for concurrent TCP connections. +//! +//! Verifies that a TCP accept loop (like the gateway proxy's) can handle +//! thousands of concurrent connections with SNI-like payloads without dropping. +//! +//! Run: cargo test -p dstack-gateway --test stress_connections -- --nocapture + +use std::sync::atomic::{AtomicU64, Ordering}; +use std::sync::Arc; +use tokio::io::{AsyncReadExt, AsyncWriteExt}; +use tokio::net::{TcpListener, TcpStream}; +use tokio::task::JoinSet; + +/// Build a minimal TLS ClientHello with the given SNI hostname.
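+/// Wire layout, matching the byte pushes below: 5-byte record header (type 0x16, version 0x0301, u16 length), 4-byte handshake header (type 0x01, 3-byte length), then the ClientHello body: version(2) + random(32) + session_id_len(1) + cipher_suites(2+2) + compression(1+1) + extensions_len(2) + one SNI extension (ext type 2, ext data len 2, list len 2, name type 1, name len 2, hostname).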
+fn build_client_hello(hostname: &str) -> Vec<u8> { + let sni_hostname = hostname.as_bytes(); + let sni_list_len = (sni_hostname.len() + 5) as u16; + let ext_data_len = (sni_hostname.len() + 7) as u16; + let extensions_len = ext_data_len + 4; + let client_hello_body_len = 2 + 32 + 1 + 2 + 1 + 2 + extensions_len; + let handshake_len = client_hello_body_len + 4; + let mut hello = Vec::with_capacity(handshake_len as usize + 5); + hello.push(0x16); + hello.extend_from_slice(&[0x03, 0x01]); + hello.extend_from_slice(&(handshake_len as u16).to_be_bytes()); + hello.push(0x01); + hello.push(0x00); + hello.extend_from_slice(&(client_hello_body_len as u16).to_be_bytes()); + hello.extend_from_slice(&[0x03, 0x03]); + hello.extend_from_slice(&[0u8; 32]); + hello.push(0x00); + hello.extend_from_slice(&[0x00, 0x02]); + hello.extend_from_slice(&[0x00, 0x2f]); + hello.push(0x01); + hello.push(0x00); + hello.extend_from_slice(&extensions_len.to_be_bytes()); + hello.extend_from_slice(&[0x00, 0x00]); + hello.extend_from_slice(&ext_data_len.to_be_bytes()); + hello.extend_from_slice(&sni_list_len.to_be_bytes()); + hello.push(0x00); + hello.extend_from_slice(&(sni_hostname.len() as u16).to_be_bytes()); + hello.extend_from_slice(sni_hostname); + hello +} + +#[tokio::test] +async fn test_5000_concurrent_connections() { + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let addr = listener.local_addr().unwrap(); + + let total: u64 = 5000; + let accepted = Arc::new(AtomicU64::new(0)); + let accepted_clone = accepted.clone(); + + // Server: accept, read payload, echo back + let server = tokio::spawn(async move { + let mut handlers = JoinSet::new(); + loop { + let (mut stream, _) = match listener.accept().await { + Ok(v) => v, + Err(_) => break, + }; + let n = accepted_clone.fetch_add(1, Ordering::Relaxed); + handlers.spawn(async move { + let mut buf = vec![0u8; 4096]; + let n = stream.read(&mut buf).await.unwrap_or(0); + if n > 0 { + let _ = stream.write_all(&buf[..n]).await; + } + }); + if n + 1 >= total { + break; + } + } + while handlers.join_next().await.is_some() {} + }); + + // Clients + let success = Arc::new(AtomicU64::new(0)); + let failed = Arc::new(AtomicU64::new(0)); + let hello = Arc::new(build_client_hello("stress.example.com")); + let mut clients = JoinSet::new(); + + for _ in 0..total { + let s = success.clone(); + let f = failed.clone(); + let h = hello.clone(); + clients.spawn(async move { + match TcpStream::connect(addr).await { + Ok(mut stream) => { + if stream.write_all(&h).await.is_ok() { + let mut buf = vec![0u8; 4096]; + match tokio::time::timeout( + std::time::Duration::from_secs(5), + stream.read(&mut buf), + ) + .await + { + Ok(Ok(n)) if n > 0 => { + s.fetch_add(1, Ordering::Relaxed); + } + _ => { + f.fetch_add(1, Ordering::Relaxed); + } + } + } else { + f.fetch_add(1, Ordering::Relaxed); + } + } + Err(_) => { + f.fetch_add(1, Ordering::Relaxed); + } + } + }); + } + + while clients.join_next().await.is_some() {} + let _ = server.await; + + let s = success.load(Ordering::Relaxed); + let f = failed.load(Ordering::Relaxed); + let a = accepted.load(Ordering::Relaxed); + eprintln!("target={total} accepted={a} success={s} failed={f}"); + // Allow up to 5% failure due to timing + assert!( + s >= total * 95 / 100, + "success rate too low: {s}/{total} (failed={f})" + ); +} diff --git a/supervisor/src/main.rs b/supervisor/src/main.rs index 752b511c..ca9ebebb 100644 --- a/supervisor/src/main.rs +++ b/supervisor/src/main.rs @@ -14,7 +14,7 @@ use rocket::{ }; use supervisor::web_api; use tracing::error;
-use tracing_subscriber::EnvFilter; +use tracing_subscriber::{fmt::format::FmtSpan, EnvFilter}; pub const DEFAULT_CONFIG: &str = include_str!("../supervisor.toml"); @@ -89,11 +89,13 @@ fn main() -> Result<()> { .context("Failed to open log file")?; tracing_subscriber::fmt() .with_env_filter(EnvFilter::from_default_env()) + .with_span_events(FmtSpan::ENTER | FmtSpan::EXIT) .with_writer(file) .init(); } else { tracing_subscriber::fmt() .with_env_filter(EnvFilter::from_default_env()) + .with_span_events(FmtSpan::ENTER | FmtSpan::EXIT) .init(); } #[cfg(unix)] diff --git a/supervisor/src/supervisor.rs b/supervisor/src/supervisor.rs index 378013c0..676f2fd0 100644 --- a/supervisor/src/supervisor.rs +++ b/supervisor/src/supervisor.rs @@ -55,6 +55,11 @@ impl Supervisor { if self.freezed() { bail!("Supervisor is freezed"); } + if let Some(meta) = tracing::Span::current().metadata() { + info!(current_span = meta.name(), "deploy span"); + } else { + info!("deploy span: none"); + } let id = config.id.clone(); if id.is_empty() { return Err(anyhow::anyhow!("Process ID is empty")); } diff --git a/uds2yamux/Cargo.toml b/uds2yamux/Cargo.toml new file mode 100644 index 00000000..ddad689d --- /dev/null +++ b/uds2yamux/Cargo.toml @@ -0,0 +1,20 @@ +# SPDX-FileCopyrightText: © 2024-2025 Phala Network +# +# SPDX-License-Identifier: Apache-2.0 + +[package] +name = "uds2yamux" +version.workspace = true +authors.workspace = true +edition.workspace = true +license.workspace = true + +[dependencies] +tokio = { workspace = true, features = ["full"] } +clap = { workspace = true, features = ["derive"] } +anyhow.workspace = true +tracing.workspace = true +tracing-subscriber.workspace = true +yamux.workspace = true +tokio-util = { version = "0.7", features = ["compat"] } +futures = "0.3" diff --git a/uds2yamux/src/bin/stress-client.rs b/uds2yamux/src/bin/stress-client.rs new file mode 100644 index 00000000..882c3ccc --- /dev/null +++ b/uds2yamux/src/bin/stress-client.rs @@ -0,0 +1,405 @@ +// SPDX-FileCopyrightText: © 2024-2025 Phala Network +// +// SPDX-License-Identifier: Apache-2.0 + +//! Stress test client for UDS, TCP, or yamux connections. +//! +//! Opens N concurrent connections, sends a payload, reads echo, then holds. +//! Reports peak concurrent connections and statistics. +//! +//! Usage: +//! # Direct TCP test +//! stress-client --tcp localhost:8080 --concurrency 1000 --total 10000 +//! +//! # yamux multiplexed test (single connection, unlimited streams)
+//! stress-client --yamux localhost:4433 --concurrency 100000 --total 100000 + +use anyhow::{bail, Context, Result}; +use clap::Parser; +use std::sync::atomic::{AtomicU64, Ordering}; +use std::sync::Arc; +use std::time::{Duration, Instant}; +use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}; +use tokio::net::{TcpStream, UnixStream}; +use tokio::sync::Semaphore; + +#[derive(Parser)] +#[command(about = "Stress test client for connection throughput and concurrency limits")] +struct Args { + /// UDS path to connect to + #[arg(long)] + uds: Option<String>, + + /// TCP address to connect to (host:port) + #[arg(long)] + tcp: Option<String>, + + /// yamux-over-TCP address (host:port) — streams multiplexed over TCP connection(s) + #[arg(long)] + yamux: Option<String>, + + /// Number of TCP connections in yamux pool (default 1) + #[arg(long, default_value = "1")] + yamux_conns: usize, + + /// Maximum concurrent connections/streams + #[arg(long, default_value = "1000")] + concurrency: usize, + + /// Total number of connections/streams to make + #[arg(long, default_value = "10000")] + total: u64, + + /// Payload size in bytes + #[arg(long, default_value = "128")] + payload_size: usize, + + /// Hold connection open for this many ms (0 = close immediately after echo) + #[arg(long, default_value = "0")] + hold_ms: u64, + + /// Max new connections per second (0 = unlimited) + #[arg(long, default_value = "0")] + ramp_rate: u64, + + /// Print a status line every N seconds while connections are held (0 = no periodic status) + #[arg(long, default_value = "2")] + status_interval: u64, +} + +#[derive(Clone, Debug)] +enum Target { + Uds(String), + Tcp(String), + Yamux(String), +} + +trait AsyncStream: AsyncRead + AsyncWrite + Unpin + Send {} +impl<T: AsyncRead + AsyncWrite + Unpin + Send> AsyncStream for T {} + +async fn connect(target: &Target) -> Result<Box<dyn AsyncStream>> { + match target { + Target::Uds(path) => { + let stream = UnixStream::connect(path) + .await + .with_context(|| format!("failed to connect to UDS {path}"))?; + Ok(Box::new(stream)) + } + Target::Tcp(addr) => { + let stream = TcpStream::connect(addr) + .await + .with_context(|| format!("failed to connect to TCP {addr}"))?; + Ok(Box::new(stream)) + } + Target::Yamux(_) => { + bail!("use multiplexed connection, not connect()") + } + } +} + +type YamuxRequest = tokio::sync::oneshot::Sender<Result<yamux::Stream>>; + +/// Spawns a single yamux connection driver task.
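+/// The driver task owns the yamux::Connection: it queues oneshot senders arriving on the returned channel, answers each with the result of poll_new_outbound, and keeps poll_next_inbound polled in the same loop so the connection state machine makes progress.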
+fn spawn_yamux_driver( + addr: String, +) -> Result<( + tokio::sync::mpsc::Sender<YamuxRequest>, + tokio::task::JoinHandle<()>, +)> { + let (tx, rx) = tokio::sync::mpsc::channel::<YamuxRequest>(1024); + let handle = tokio::spawn(async move { + use tokio_util::compat::TokioAsyncReadCompatExt; + + let tcp = match TcpStream::connect(&addr).await { + Ok(t) => t, + Err(e) => { + eprintln!("yamux driver: failed to connect to {addr}: {e}"); + return; + } + }; + let cfg = yamux::Config::default(); + let mut conn = yamux::Connection::new(tcp.compat(), cfg, yamux::Mode::Client); + let mut rx = rx; + let mut pending: Vec<YamuxRequest> = Vec::new(); + + loop { + use std::task::Poll; + + let result = std::future::poll_fn(|cx| { + match conn.poll_next_inbound(cx) { + Poll::Ready(Some(Ok(_))) => {} + Poll::Ready(Some(Err(e))) => return Poll::Ready(Err::<(), _>(e)), + Poll::Ready(None) => return Poll::Ready(Err(yamux::ConnectionError::Closed)), + Poll::Pending => {} + } + + loop { + match rx.poll_recv(cx) { + Poll::Ready(Some(sender)) => pending.push(sender), + Poll::Ready(None) => { + return Poll::Ready(Err(yamux::ConnectionError::Closed)) + } + Poll::Pending => break, + } + } + + while !pending.is_empty() { + match conn.poll_new_outbound(cx) { + Poll::Ready(result) => { + if let Some(sender) = pending.pop() { + let _ = sender.send(result.map_err(Into::into)); + } + } + Poll::Pending => break, + } + } + + Poll::Pending + }) + .await; + + if let Err(e) = result { + if !matches!(e, yamux::ConnectionError::Closed) { + eprintln!("yamux connection error: {e}"); + } + break; + } + } + }); + Ok((tx, handle)) +} + +/// Pool of yamux connections with round-robin stream opening. +struct YamuxPool { + senders: Vec<tokio::sync::mpsc::Sender<YamuxRequest>>, + next: std::sync::atomic::AtomicUsize, +} + +impl YamuxPool { + async fn new(addr: &str, num_conns: usize) -> Result<(Self, Vec<tokio::task::JoinHandle<()>>)> { + let mut senders = Vec::with_capacity(num_conns); + let mut handles = Vec::with_capacity(num_conns); + for _ in 0..num_conns { + let (tx, handle) = spawn_yamux_driver(addr.to_string())?; + senders.push(tx); + handles.push(handle); + } + Ok(( + Self { + senders, + next: std::sync::atomic::AtomicUsize::new(0), + }, + handles, + )) + } + + async fn open_stream(&self) -> Result<Box<dyn AsyncStream>> { + use tokio_util::compat::FuturesAsyncReadCompatExt; + + let idx = self.next.fetch_add(1, Ordering::Relaxed) % self.senders.len(); + let (tx, rx) = tokio::sync::oneshot::channel(); + self.senders[idx] + .send(tx) + .await + .map_err(|_| anyhow::anyhow!("yamux driver closed"))?; + let stream = rx + .await + .map_err(|_| anyhow::anyhow!("yamux driver dropped"))??; + Ok(Box::new(stream.compat())) + } +} + +#[tokio::main] +async fn main() -> Result<()> { + tracing_subscriber::fmt::init(); + let args = Args::parse(); + + let target = match (&args.uds, &args.tcp, &args.yamux) { + (Some(path), None, None) => Target::Uds(path.clone()), + (None, Some(addr), None) => Target::Tcp(addr.clone()), + (None, None, Some(addr)) => Target::Yamux(addr.clone()), + _ => bail!("specify exactly one of --uds, --tcp, or --yamux"), + }; + + let yamux_conn = if let Target::Yamux(ref addr) = target { + let n = args.yamux_conns; + let (pool, _handles) = YamuxPool::new(addr, n).await?; + eprintln!("yamux pool established to {addr} ({n} TCP connections)"); + Some(Arc::new(pool)) + } else { + None + }; + + let payload = vec![0x42u8; args.payload_size]; + let payload = Arc::new(payload); + let semaphore = Arc::new(Semaphore::new(args.concurrency)); + + let success = Arc::new(AtomicU64::new(0)); + let failed = Arc::new(AtomicU64::new(0)); + let connect_err =
Arc::new(AtomicU64::new(0)); + let write_err = Arc::new(AtomicU64::new(0)); + let read_err = Arc::new(AtomicU64::new(0)); + let peak_concurrent = Arc::new(AtomicU64::new(0)); + let current_concurrent = Arc::new(AtomicU64::new(0)); + + // Periodic status reporter + if args.status_interval > 0 && args.hold_ms > 0 { + let current = current_concurrent.clone(); + let peak = peak_concurrent.clone(); + let success = success.clone(); + let failed = failed.clone(); + let interval = args.status_interval; + tokio::spawn(async move { + let start = Instant::now(); + loop { + tokio::time::sleep(Duration::from_secs(interval)).await; + let c = current.load(Ordering::Relaxed); + let p = peak.load(Ordering::Relaxed); + let s = success.load(Ordering::Relaxed); + let f = failed.load(Ordering::Relaxed); + eprintln!( + "[{:>6.1}s] active={c} peak={p} success={s} failed={f}", + start.elapsed().as_secs_f64() + ); + } + }); + } + + let mode = if yamux_conn.is_some() { + "yamux" + } else { + "direct" + }; + eprintln!( + "starting: target={target:?} mode={mode} total={} concurrency={} hold={}ms", + args.total, args.concurrency, args.hold_ms + ); + + let start = Instant::now(); + let mut handles = Vec::with_capacity(args.total as usize); + + let interval_per_conn = if args.ramp_rate > 0 { + Some(Duration::from_secs_f64(1.0 / args.ramp_rate as f64)) + } else { + None + }; + + for i in 0..args.total { + if let Some(interval) = interval_per_conn { + tokio::time::sleep(interval).await; + } + + let permit = semaphore.clone().acquire_owned().await.unwrap(); + let target = target.clone(); + let payload = payload.clone(); + let success = success.clone(); + let failed = failed.clone(); + let connect_err = connect_err.clone(); + let write_err = write_err.clone(); + let read_err = read_err.clone(); + let peak = peak_concurrent.clone(); + let current = current_concurrent.clone(); + let hold_ms = args.hold_ms; + let total = args.total; + let yamux_conn = yamux_conn.clone(); + + handles.push(tokio::spawn(async move { + let cur = current.fetch_add(1, Ordering::Relaxed) + 1; + peak.fetch_max(cur, Ordering::Relaxed); + + let result: Result<()> = async { + let mut stream = if let Some(ref conn) = yamux_conn { + conn.open_stream().await.inspect_err(|e| { + let prev = connect_err.fetch_add(1, Ordering::Relaxed); + if prev < 3 { + eprintln!("connect error #{}: {e:#}", prev + 1); + } + })? + } else { + connect(&target).await.inspect_err(|e| { + let prev = connect_err.fetch_add(1, Ordering::Relaxed); + if prev < 3 { + eprintln!("connect error #{}: {e:#}", prev + 1); + } + })? + }; + stream.write_all(&payload).await.map_err(|e| { + write_err.fetch_add(1, Ordering::Relaxed); + anyhow::anyhow!(e) + })?; + let mut buf = vec![0u8; payload.len()]; + tokio::time::timeout(Duration::from_secs(10), stream.read_exact(&mut buf)) + .await + .map_err(|_| { + read_err.fetch_add(1, Ordering::Relaxed); + anyhow::anyhow!("read timeout") + })? 
+ .map_err(|e| { + read_err.fetch_add(1, Ordering::Relaxed); + anyhow::anyhow!(e) + })?; + if hold_ms > 0 { + tokio::time::sleep(Duration::from_millis(hold_ms)).await; + } + Ok(()) + } + .await; + + current.fetch_sub(1, Ordering::Relaxed); + drop(permit); + + match result { + Ok(()) => { + success.fetch_add(1, Ordering::Relaxed); + } + Err(_) => { + failed.fetch_add(1, Ordering::Relaxed); + } + } + + if (i + 1).is_multiple_of(1000) { + let s = success.load(Ordering::Relaxed); + let f = failed.load(Ordering::Relaxed); + eprintln!("[{:>6}/{:>6}] success={s} failed={f}", i + 1, total); + } + })); + } + + for h in handles { + let _ = h.await; + } + + let elapsed = start.elapsed(); + let s = success.load(Ordering::Relaxed); + let f = failed.load(Ordering::Relaxed); + let ce = connect_err.load(Ordering::Relaxed); + let we = write_err.load(Ordering::Relaxed); + let re = read_err.load(Ordering::Relaxed); + let peak = peak_concurrent.load(Ordering::Relaxed); + + eprintln!(); + eprintln!("=== Results ==="); + eprintln!("target: {target:?}"); + eprintln!("mode: {mode}"); + eprintln!("total: {}", args.total); + eprintln!("concurrency: {}", args.concurrency); + eprintln!("success: {s}"); + eprintln!("failed: {f}"); + eprintln!(" connect_err: {ce}"); + eprintln!(" write_err: {we}"); + eprintln!(" read_err: {re}"); + eprintln!("peak_conns: {peak}"); + eprintln!("elapsed: {elapsed:.2?}"); + if elapsed.as_secs_f64() > 0.0 { + eprintln!( + "throughput: {:.0} conn/s", + s as f64 / elapsed.as_secs_f64() + ); + } + + if f > 0 { + eprintln!("\nWARNING: {f} connections failed!"); + std::process::exit(1); + } + Ok(()) +} diff --git a/uds2yamux/src/bin/vsock-echo.rs b/uds2yamux/src/bin/vsock-echo.rs new file mode 100644 index 00000000..2eb45a07 --- /dev/null +++ b/uds2yamux/src/bin/vsock-echo.rs @@ -0,0 +1,135 @@ +// SPDX-FileCopyrightText: © 2024-2025 Phala Network +// +// SPDX-License-Identifier: Apache-2.0 + +//! Echo server for testing TCP and yamux connectivity. +//! +//! Usage: +//! vsock-echo --tcp-port 8080 # TCP echo server +//! vsock-echo --yamux-port 9090 # yamux-over-TCP echo server +//! 
vsock-echo --tcp-port 8080 --yamux-port 9090 # both + +use anyhow::{Context, Result}; +use clap::Parser; +use std::sync::atomic::{AtomicU64, Ordering}; +use tokio::net::TcpListener; +use tracing::{error, info}; + +#[derive(Parser)] +#[command(about = "TCP/yamux echo server for testing")] +struct Args { + /// TCP port to listen on for TCP connections + #[arg(long)] + tcp_port: Option<u16>, + + /// TCP port to listen on for yamux-over-TCP connections + #[arg(long)] + yamux_port: Option<u16>, +} + +static ACTIVE: AtomicU64 = AtomicU64::new(0); +static TOTAL: AtomicU64 = AtomicU64::new(0); + +async fn run_tcp_echo(port: u16) -> Result<()> { + let listener = TcpListener::bind(format!("0.0.0.0:{port}")) + .await + .with_context(|| format!("failed to bind TCP port {port}"))?; + info!("TCP echo server listening on port {port}"); + + loop { + let (stream, addr) = listener.accept().await?; + let id = TOTAL.fetch_add(1, Ordering::Relaxed); + let active = ACTIVE.fetch_add(1, Ordering::Relaxed) + 1; + if id.is_multiple_of(100) { + info!("tcp #{id} from {addr}, active={active}"); + } + tokio::spawn(async move { + let (mut r, mut w) = tokio::io::split(stream); + let _ = tokio::io::copy(&mut r, &mut w).await; + let remaining = ACTIVE.fetch_sub(1, Ordering::Relaxed) - 1; + if id.is_multiple_of(100) { + info!("tcp #{id} closed, active={remaining}"); + } + }); + } +} + +async fn run_yamux_echo(port: u16) -> Result<()> { + use tokio_util::compat::TokioAsyncReadCompatExt; + + let listener = TcpListener::bind(format!("0.0.0.0:{port}")) + .await + .with_context(|| format!("failed to bind yamux TCP port {port}"))?; + info!("yamux echo server listening on TCP port {port}"); + + loop { + let (tcp_stream, addr) = listener.accept().await?; + info!("yamux: new TCP connection from {addr}"); + tokio::spawn(async move { + let cfg = yamux::Config::default(); + let mut conn = yamux::Connection::new(tcp_stream.compat(), cfg, yamux::Mode::Server); + + loop { + match std::future::poll_fn(|cx| conn.poll_next_inbound(cx)).await { + Some(Ok(stream)) => { + let id = TOTAL.fetch_add(1, Ordering::Relaxed); + let active = ACTIVE.fetch_add(1, Ordering::Relaxed) + 1; + if id.is_multiple_of(100) { + info!("yamux #{id}, active={active}"); + } + tokio::spawn(async move { + let (mut r, mut w) = futures::io::AsyncReadExt::split(stream); + let _ = futures::io::copy(&mut r, &mut w).await; + let remaining = ACTIVE.fetch_sub(1, Ordering::Relaxed) - 1; + if id.is_multiple_of(100) { + info!("yamux #{id} closed, active={remaining}"); + } + }); + } + None => { + info!("yamux connection from {addr} closed"); + break; + } + Some(Err(e)) => { + error!("yamux accept error from {addr}: {e}"); + break; + } + } + } + }); + } +} + +#[tokio::main] +async fn main() -> Result<()> { + tracing_subscriber::fmt::init(); + let args = Args::parse(); + + if args.tcp_port.is_none() && args.yamux_port.is_none() { + anyhow::bail!("specify at least one of --tcp-port or --yamux-port"); + } + + let mut tasks = Vec::new(); + + if let Some(port) = args.tcp_port { + tasks.push(tokio::spawn(async move { + if let Err(e) = run_tcp_echo(port).await { + error!("TCP echo error: {e}"); + } + })); + } + + if let Some(port) = args.yamux_port { + tasks.push(tokio::spawn(async move { + if let Err(e) = run_yamux_echo(port).await { + error!("yamux echo error: {e}"); + } + })); + } + + for t in tasks { + let _ = t.await; + } + + Ok(()) +} diff --git a/uds2yamux/src/main.rs b/uds2yamux/src/main.rs new file mode 100644 index 00000000..3b5bf08f --- /dev/null +++ b/uds2yamux/src/main.rs @@ -0,0 +1,215 @@
+// SPDX-FileCopyrightText: © 2024-2025 Phala Network +// +// SPDX-License-Identifier: Apache-2.0 + +//! Forward UDS connections to a remote yamux endpoint over TCP. +//! +//! Maintains a pool of TCP connections and multiplexes each incoming UDS +//! connection onto a yamux stream. + +use anyhow::{Context, Result}; +use clap::Parser; +use std::sync::atomic::{AtomicUsize, Ordering}; +use std::sync::Arc; +use tokio::net::{TcpStream, UnixListener}; +use tokio::sync::{mpsc, oneshot}; +use tokio_util::compat::{FuturesAsyncReadCompatExt, TokioAsyncReadCompatExt}; +use tracing::{error, info}; + +#[derive(Parser)] +#[command(about = "Forward UDS connections to a yamux endpoint")] +struct Args { + /// Path to the Unix domain socket to listen on + #[arg(long)] + uds: String, + + /// Yamux server address (host:port) + #[arg(long)] + yamux_addr: String, + + /// Number of TCP connections in the yamux pool + #[arg(long, default_value = "1")] + yamux_conns: usize, +} + +type YamuxRequest = oneshot::Sender<Result<yamux::Stream>>; + +async fn connect_yamux( + addr: &str, +) -> Result<yamux::Connection<tokio_util::compat::Compat<TcpStream>>> { + let stream = TcpStream::connect(addr) + .await + .with_context(|| format!("failed to connect TCP {addr}"))?; + stream + .set_nodelay(true) + .context("failed to set TCP_NODELAY")?; + + let cfg = yamux::Config::default(); + Ok(yamux::Connection::new( + stream.compat(), + cfg, + yamux::Mode::Client, + )) +} + +async fn drive_yamux(addr: String, mut requests: mpsc::Receiver<YamuxRequest>) { + use futures::FutureExt; + + let mut pending: Vec<YamuxRequest> = Vec::new(); + let mut backoff = std::time::Duration::from_secs(1); + let max_backoff = std::time::Duration::from_secs(60); + + loop { + let mut conn = loop { + match connect_yamux(&addr).await { + Ok(conn) => { + backoff = std::time::Duration::from_secs(1); + info!("yamux: connected to {addr}"); + break conn; + } + Err(e) => { + error!("yamux: connect error: {e}"); + let delay = tokio::time::sleep(backoff); + tokio::pin!(delay); + tokio::select! { + _ = &mut delay => { + backoff = std::cmp::min(backoff * 2, max_backoff); + } + request = requests.recv() => { + match request { + Some(reply) => pending.push(reply), + None => return, + } + } + } + } + } + }; + + loop { + let request_fut = if let Some(reply) = pending.pop() { + futures::future::ready(Some(reply)).boxed() + } else { + requests.recv().boxed() + }; + + tokio::select! {
+ inbound = std::future::poll_fn(|cx| conn.poll_next_inbound(cx)) => { + match inbound { + Some(Ok(stream)) => { + info!("yamux: inbound stream {stream}"); + } + Some(Err(e)) => { + error!("yamux: inbound error: {e}"); + break; + } + None => { + info!("yamux: connection closed by peer"); + break; + } + } + } + request = request_fut => { + let Some(reply) = request else { + return; + }; + let result = std::future::poll_fn(|cx| conn.poll_new_outbound(cx)) + .await + .map_err(|e| anyhow::anyhow!("failed to open yamux stream: {e}")); + let _ = reply.send(result); + } + } + } + } +} + +async fn spawn_yamux_driver(addr: &str) -> Result<mpsc::Sender<YamuxRequest>> { + let (request_tx, request_rx) = mpsc::channel(128); + tokio::spawn(drive_yamux(addr.to_string(), request_rx)); + Ok(request_tx) +} + +struct YamuxPool { + senders: Vec<mpsc::Sender<YamuxRequest>>, + next: AtomicUsize, +} + +impl YamuxPool { + async fn new(addr: &str, num_conns: usize) -> Result<Self> { + let mut senders = Vec::with_capacity(num_conns); + for _ in 0..num_conns { + senders.push(spawn_yamux_driver(addr).await?); + } + Ok(Self { + senders, + next: AtomicUsize::new(0), + }) + } + + async fn open_stream(&self) -> Result<yamux::Stream> { + let idx = self.next.fetch_add(1, Ordering::Relaxed) % self.senders.len(); + let (reply_tx, reply_rx) = oneshot::channel(); + self.senders[idx] + .send(reply_tx) + .await + .map_err(|_| anyhow::anyhow!("yamux driver closed"))?; + reply_rx + .await + .map_err(|_| anyhow::anyhow!("yamux driver dropped response channel"))? + } +} + +#[tokio::main] +async fn main() -> Result<()> { + tracing_subscriber::fmt::init(); + let args = Args::parse(); + if args.yamux_conns == 0 { + anyhow::bail!("yamux_conns must be >= 1"); + } + + let _ = std::fs::remove_file(&args.uds); + + let listener = UnixListener::bind(&args.uds) + .with_context(|| format!("failed to bind UDS at {}", args.uds))?; + + let pool = Arc::new( + YamuxPool::new(&args.yamux_addr, args.yamux_conns) + .await + .context("failed to connect yamux")?, + ); + + info!( + "listening on {}, forwarding to yamux {} (pool size {})", + args.uds, args.yamux_addr, args.yamux_conns + ); + + loop { + let (uds_stream, _) = listener + .accept() + .await + .context("failed to accept UDS connection")?; + + let pool = pool.clone(); + tokio::spawn(async move { + let stream = match pool.open_stream().await { + Ok(stream) => stream, + Err(e) => { + error!("failed to open yamux stream: {e}"); + return; + } + }; + + let stream = stream.compat(); + let (mut sr, mut sw) = tokio::io::split(stream); + let (mut ur, mut uw) = tokio::io::split(uds_stream); + + let r = tokio::select! { + r = tokio::io::copy(&mut ur, &mut sw) => r, + r = tokio::io::copy(&mut sr, &mut uw) => r, + }; + if let Err(e) = r { + error!("bridge error: {e}"); + } + }); + } +}