|
| 1 | +use std::fs; |
| 2 | +use std::net::SocketAddr; |
| 3 | +use std::path::PathBuf; |
| 4 | +use std::sync::Arc; |
| 5 | + |
| 6 | +use tokio::net::TcpStream; |
| 7 | +use tokio::prelude::*; |
| 8 | + |
| 9 | +use anyhow::{Context, Result}; |
| 10 | +use futures::future::try_join; |
| 11 | +use futures::{StreamExt, TryFutureExt}; |
| 12 | +use structopt::{self, StructOpt}; |
| 13 | +use tracing::{error, info}; |
| 14 | + |
| 15 | +#[derive(StructOpt, Debug)] |
| 16 | +#[structopt(name = "qtun-server")] |
| 17 | +struct Opt { |
| 18 | + /// TLS private key in PEM format |
| 19 | + #[structopt( |
| 20 | + parse(from_os_str), |
| 21 | + short = "k", |
| 22 | + long = "key", |
| 23 | + requires = "cert", |
| 24 | + default_value = "key.der" |
| 25 | + )] |
| 26 | + key: PathBuf, |
| 27 | + /// TLS certificate in PEM format |
| 28 | + #[structopt( |
| 29 | + parse(from_os_str), |
| 30 | + short = "c", |
| 31 | + long = "cert", |
| 32 | + requires = "key", |
| 33 | + default_value = "cert.der" |
| 34 | + )] |
| 35 | + cert: PathBuf, |
| 36 | + /// Enable stateless retries |
| 37 | + #[structopt(long = "stateless-retry")] |
| 38 | + stateless_retry: bool, |
| 39 | + /// Address to listen on |
| 40 | + #[structopt(long = "local", default_value = "0.0.0.0:4433")] |
| 41 | + local: SocketAddr, |
| 42 | + /// Address to listen on |
| 43 | + #[structopt(long = "remote", default_value = "127.0.0.1:8138")] |
| 44 | + remote: SocketAddr, |
| 45 | +} |
| 46 | + |
| 47 | +#[tokio::main] |
| 48 | +async fn main() -> Result<()> { |
| 49 | + let options = Opt::from_args(); |
| 50 | + |
| 51 | + let mut transport_config = quinn::TransportConfig::default(); |
| 52 | + transport_config.stream_window_uni(0); |
| 53 | + let mut server_config = quinn::ServerConfig::default(); |
| 54 | + server_config.transport = Arc::new(transport_config); |
| 55 | + let mut server_config = quinn::ServerConfigBuilder::new(server_config); |
| 56 | + |
| 57 | + if options.stateless_retry { |
| 58 | + server_config.use_stateless_retry(true); |
| 59 | + } |
| 60 | + |
| 61 | + // load certificates |
| 62 | + let key_path = &options.key; |
| 63 | + let cert_path = &options.cert; |
| 64 | + let key = fs::read(key_path).context("failed to read private key")?; |
| 65 | + let key = if key_path.extension().map_or(false, |x| x == "der") { |
| 66 | + quinn::PrivateKey::from_der(&key)? |
| 67 | + } else { |
| 68 | + quinn::PrivateKey::from_pem(&key)? |
| 69 | + }; |
| 70 | + let cert_chain = fs::read(cert_path).context("failed to read certificate chain")?; |
| 71 | + let cert_chain = if cert_path.extension().map_or(false, |x| x == "der") { |
| 72 | + quinn::CertificateChain::from_certs(quinn::Certificate::from_der(&cert_chain)) |
| 73 | + } else { |
| 74 | + quinn::CertificateChain::from_pem(&cert_chain)? |
| 75 | + }; |
| 76 | + server_config.certificate(cert_chain, key)?; |
| 77 | + |
| 78 | + let mut endpoint = quinn::Endpoint::builder(); |
| 79 | + endpoint.listen(server_config.build()); |
| 80 | + |
| 81 | + let remote = Arc::<SocketAddr>::from(options.remote); |
| 82 | + |
| 83 | + let mut incoming = { |
| 84 | + let (endpoint, incoming) = endpoint.bind(&options.local)?; |
| 85 | + info!("listening on {}", endpoint.local_addr()?); |
| 86 | + incoming |
| 87 | + }; |
| 88 | + |
| 89 | + while let Some(conn) = incoming.next().await { |
| 90 | + info!("connection incoming"); |
| 91 | + tokio::spawn( |
| 92 | + handle_connection(remote.clone(), conn).unwrap_or_else(move |e| { |
| 93 | + error!("connection failed: {reason}", reason = e.to_string()) |
| 94 | + }), |
| 95 | + ); |
| 96 | + } |
| 97 | + |
| 98 | + Ok(()) |
| 99 | +} |
| 100 | + |
| 101 | +async fn handle_connection(remote: Arc<SocketAddr>, conn: quinn::Connecting) -> Result<()> { |
| 102 | + let quinn::NewConnection { |
| 103 | + connection: _, |
| 104 | + mut bi_streams, |
| 105 | + .. |
| 106 | + } = conn.await?; |
| 107 | + |
| 108 | + async { |
| 109 | + info!("established"); |
| 110 | + |
| 111 | + // Each stream initiated by the client constitutes a new request. |
| 112 | + while let Some(stream) = bi_streams.next().await { |
| 113 | + let stream = match stream { |
| 114 | + Err(quinn::ConnectionError::ApplicationClosed { .. }) => { |
| 115 | + info!("connection closed"); |
| 116 | + return Ok(()); |
| 117 | + } |
| 118 | + Err(e) => { |
| 119 | + return Err(e); |
| 120 | + } |
| 121 | + Ok(s) => s, |
| 122 | + }; |
| 123 | + tokio::spawn( |
| 124 | + transfer(remote.clone(), stream) |
| 125 | + .unwrap_or_else(move |e| error!("failed: {reason}", reason = e.to_string())), |
| 126 | + ); |
| 127 | + } |
| 128 | + Ok(()) |
| 129 | + } |
| 130 | + .await?; |
| 131 | + |
| 132 | + Ok(()) |
| 133 | +} |
| 134 | + |
| 135 | +async fn transfer( |
| 136 | + remote: Arc<SocketAddr>, |
| 137 | + inbound: (quinn::SendStream, quinn::RecvStream), |
| 138 | +) -> Result<()> { |
| 139 | + let mut outbound = TcpStream::connect(remote.as_ref()).await?; |
| 140 | + |
| 141 | + let (mut wi, mut ri) = inbound; |
| 142 | + let (mut ro, mut wo) = outbound.split(); |
| 143 | + |
| 144 | + let client_to_server = io::copy(&mut ri, &mut wo); |
| 145 | + let server_to_client = io::copy(&mut ro, &mut wi); |
| 146 | + |
| 147 | + try_join(client_to_server, server_to_client).await?; |
| 148 | + |
| 149 | + Ok(()) |
| 150 | +} |
0 commit comments