//! Make a basic WebSocket server in Rust.
use axum::extract::{
ws::{Message, WebSocket, WebSocketUpgrade},
State,
};
use axum::response::IntoResponse;
use axum::routing::get;
use axum::Router;
use futures::sink::SinkExt;
use futures::stream::StreamExt;
use std::sync::Arc;
use tokio::sync::broadcast;
use tower_http::services::ServeDir;
struct AppState {
tx: broadcast::Sender<String>,
}
/// Entry point.
///
/// Serves the WebSocket endpoint on `/ws`. Every other path is served as a
/// static file from the `assets` directory. Listens on 127.0.0.1:3000.
///
/// # Panics
///
/// Panics if the listening socket cannot be bound or if the server exits
/// with an error.
#[tokio::main]
async fn main() {
    // Capacity bounds how far a subscriber may fall behind before it starts
    // missing messages. `_rx` is never read; each connection subscribes its
    // own receiver via `tx.subscribe()`.
    let (tx, _rx) = broadcast::channel(100);
    let app_state = Arc::new(AppState { tx });

    let app = Router::new()
        .route("/ws", get(websocket_handler))
        // Anything not matched by a route above is looked up in `assets/`.
        // `fallback_service` is the supported way to do this; nesting a
        // service at the root is rejected by newer axum releases.
        .fallback_service(ServeDir::new("assets"))
        .with_state(app_state);

    let addr = "127.0.0.1:3000";
    let listener = tokio::net::TcpListener::bind(addr)
        .await
        .unwrap_or_else(|e| panic!("failed to bind {addr}: {e}"));
    axum::serve(listener, app)
        .await
        .unwrap_or_else(|e| panic!("server on {addr} exited with an error: {e}"));
}
async fn websocket_handler(
ws: WebSocketUpgrade,
State(state): State<Arc<AppState>>,
) -> impl IntoResponse {
ws.on_upgrade(|socket| websocket(socket, state))
}
async fn websocket(stream: WebSocket, state: Arc<AppState>) {
let (mut sender, mut receiver) = stream.split();
let mut rx = state.tx.subscribe();
let tx = state.tx.clone();
let mut send_task = tokio::spawn(async move {
while let Ok(msg) = rx.recv().await {
// In any websocket error, break loop.
if sender.send(Message::Text(msg)).await.is_err() {
break;
}
}
});
let mut recv_task = tokio::spawn(async move {
while let Some(Ok(Message::Text(_text))) = receiver.next().await {
let _ = tx.send(format!("ping received"));
}
});
tokio::select! {
_ = &mut send_task => recv_task.abort(),
_ = &mut recv_task => send_task.abort(),
};
let _ = state.tx.send(format!("ping start"));
}
-- end of line --