Files
quic_ecs_dt/simulator/src/client.rs
2026-05-12 11:21:40 -04:00

190 lines
7.1 KiB
Rust

use std::net::SocketAddr;
use std::path::Path;
use std::sync::Arc;
use anyhow::{anyhow, Context};
use quinn::{ClientConfig, Connection, Endpoint};
use rustls::client::danger::{HandshakeSignatureValid, ServerCertVerified, ServerCertVerifier};
use rustls::pki_types::{CertificateDer, ServerName, UnixTime};
use rustls::{DigitallySignedStruct, SignatureScheme};
use substrate::transport::QuicMessage;
/// QUIC client for driving the substrate from tests, smoke runs, and
/// (eventually) the full Bevy-driven sensor generator.
///
/// `connect` trusts the server's PEM cert by **exact byte match** — using a
/// custom `ServerCertVerifier` that compares the leaf against the cert at
/// `cert_path`. This sidesteps rustls' `CaUsedAsEndEntity` rejection of our
/// self-signed cert (which acts as both trust anchor and leaf). Handshake
/// signatures are still verified, but the byte pin *replaces* the usual
/// chain, server-name, and validity-period checks — e.g. an expired pinned
/// cert is still accepted.
pub struct SimulatorClient {
    /// Local QUIC endpoint the connection was opened from; retained so
    /// `close` can wait for it to become idle.
    pub endpoint: Endpoint,
    /// The established connection to the substrate.
    pub conn: Connection,
}
impl SimulatorClient {
pub async fn connect(
server_addr: SocketAddr,
server_name: &str,
cert_path: impl AsRef<Path>,
) -> anyhow::Result<Self> {
let cert_path = cert_path.as_ref();
let cert_pem = std::fs::read(cert_path)
.with_context(|| format!("read trust cert at {}", cert_path.display()))?;
let parsed: Vec<CertificateDer<'static>> = rustls_pemfile::certs(&mut cert_pem.as_slice())
.collect::<Result<_, _>>()
.with_context(|| format!("parse PEM certs at {}", cert_path.display()))?;
let expected = parsed
.into_iter()
.next()
.ok_or_else(|| anyhow!("no certificates found in {}", cert_path.display()))?;
// Reuse the process-wide rustls provider that `install_crypto_provider`
// (or substrate's main) already installed. Failing to find one here
// means nobody installed a default — caller error.
let provider = rustls::crypto::CryptoProvider::get_default()
.ok_or_else(|| anyhow!("no rustls default crypto provider installed"))?
.clone();
let verifier = Arc::new(TrustExactCert {
expected,
provider: provider.clone(),
});
let rustls_cfg = rustls::ClientConfig::builder_with_provider(provider)
.with_safe_default_protocol_versions()
.context("rustls client builder")?
.dangerous()
.with_custom_certificate_verifier(verifier)
.with_no_client_auth();
let quic_cfg = quinn::crypto::rustls::QuicClientConfig::try_from(rustls_cfg)
.context("wrap rustls config for QUIC")?;
let client_cfg = ClientConfig::new(Arc::new(quic_cfg));
let bind: SocketAddr = if server_addr.is_ipv6() {
"[::]:0".parse().unwrap()
} else {
"0.0.0.0:0".parse().unwrap()
};
let mut endpoint = Endpoint::client(bind).context("Endpoint::client bind")?;
endpoint.set_default_client_config(client_cfg);
let connecting = endpoint
.connect(server_addr, server_name)
.with_context(|| format!("client connect to {server_addr} as {server_name}"))?;
let conn = connecting.await.context("client TLS handshake")?;
tracing::info!(remote = %conn.remote_address(), "simulator client connected");
Ok(Self { endpoint, conn })
}
/// T1 — send one `QuicMessage` over a QUIC datagram (38 B fixed).
pub fn send_datagram(&self, msg: &QuicMessage) -> anyhow::Result<()> {
let bytes = bytes::Bytes::copy_from_slice(&msg.to_bytes());
self.conn.send_datagram(bytes).context("send_datagram")?;
Ok(())
}
/// T2 — open a unidirectional stream, write each message as 38 B back-to-back,
/// then `finish()` the stream. The substrate sees one or many events per
/// stream, ordered within the stream.
pub async fn send_uni_stream(&self, msgs: &[QuicMessage]) -> anyhow::Result<()> {
let mut send = self.conn.open_uni().await.context("open_uni")?;
for msg in msgs {
send.write_all(&msg.to_bytes())
.await
.context("write QuicMessage to uni stream")?;
}
send.finish().context("finish uni stream")?;
Ok(())
}
/// T3 — open a bidirectional stream, write the command (38 B), finish the
/// send half, then read the substrate's ack (38 B). Errors if the
/// substrate resets the stream (e.g. no handler installed yet) or if the
/// connection drops mid-exchange.
pub async fn request(&self, command: &QuicMessage) -> anyhow::Result<QuicMessage> {
let (mut send, mut recv) = self.conn.open_bi().await.context("open_bi")?;
send.write_all(&command.to_bytes())
.await
.context("write T3 command")?;
send.finish().context("finish T3 send half")?;
let mut buf = [0u8; QuicMessage::WIRE_SIZE];
recv.read_exact(&mut buf)
.await
.context("read T3 ack")?;
let ack = QuicMessage::decode(&buf).context("decode T3 ack")?;
Ok(ack)
}
/// Close the connection gracefully. Use before dropping in tests so the
/// peer's `conn.closed()` resolves cleanly instead of via timeout.
pub async fn close(&self) {
self.conn.close(0u32.into(), b"client done");
self.endpoint.wait_idle().await;
}
}
/// `ServerCertVerifier` that accepts exactly one specific cert by byte
/// equality. Handshake signatures (TLS 1.2 and 1.3) are still verified using
/// the signature algorithms of the captured `provider`. The pin replaces all
/// certificate validation, so chain, server-name, and validity-period checks
/// do not run.
#[derive(Debug)]
struct TrustExactCert {
    /// DER bytes of the only leaf certificate the server may present.
    expected: CertificateDer<'static>,
    /// Crypto provider whose `signature_verification_algorithms` check the
    /// handshake signatures.
    provider: Arc<rustls::crypto::CryptoProvider>,
}
impl ServerCertVerifier for TrustExactCert {
fn verify_server_cert(
&self,
end_entity: &CertificateDer<'_>,
_intermediates: &[CertificateDer<'_>],
_server_name: &ServerName<'_>,
_ocsp_response: &[u8],
_now: UnixTime,
) -> Result<ServerCertVerified, rustls::Error> {
if end_entity.as_ref() == self.expected.as_ref() {
Ok(ServerCertVerified::assertion())
} else {
Err(rustls::Error::General(
"server cert does not match trusted dev cert".into(),
))
}
}
fn verify_tls12_signature(
&self,
message: &[u8],
cert: &CertificateDer<'_>,
dss: &DigitallySignedStruct,
) -> Result<HandshakeSignatureValid, rustls::Error> {
rustls::crypto::verify_tls12_signature(
message,
cert,
dss,
&self.provider.signature_verification_algorithms,
)
}
fn verify_tls13_signature(
&self,
message: &[u8],
cert: &CertificateDer<'_>,
dss: &DigitallySignedStruct,
) -> Result<HandshakeSignatureValid, rustls::Error> {
rustls::crypto::verify_tls13_signature(
message,
cert,
dss,
&self.provider.signature_verification_algorithms,
)
}
fn supported_verify_schemes(&self) -> Vec<SignatureScheme> {
self.provider.signature_verification_algorithms.supported_schemes()
}
}