Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 25 additions & 1 deletion packages/loro-websocket/src/server/simple-server.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import { WebSocketServer, WebSocket } from "ws";
import { randomBytes } from "node:crypto";
import type { RawData } from "ws";
import type { IncomingMessage } from "http";
// no direct CRDT imports here; handled by CrdtDoc implementations
import {
encode,
Expand Down Expand Up @@ -47,6 +48,13 @@ export interface SimpleServerConfig {
crdtType: CrdtType,
auth: Uint8Array
) => Promise<Permission | null>;
/**
* Optional handshake auth: called during WS HTTP upgrade.
* Return true to accept, false to reject.
*/
handshakeAuth?: (
req: IncomingMessage
) => boolean | Promise<boolean>;
}

interface RoomDocument {
Expand Down Expand Up @@ -86,12 +94,28 @@ export class SimpleServer {

start(): Promise<void> {
return new Promise(resolve => {
const options: { port: number; host?: string } = {
const options: { port: number; host?: string; verifyClient?: any } = {
port: this.config.port,
};
if (this.config.host) {
options.host = this.config.host;
}
if (this.config.handshakeAuth) {
options.verifyClient = (
info: { origin: string; secure: boolean; req: IncomingMessage },
cb: (res: boolean, code?: number, message?: string) => void
) => {
Promise.resolve(this.config.handshakeAuth!(info.req))
.then(allowed => {
if (allowed) cb(true);
else cb(false, 401, "Unauthorized");
})
.catch(err => {
console.error("Handshake auth error", err);
cb(false, 500, "Internal Server Error");
});
};
}
this.wss = new WebSocketServer(options);

this.wss.on("connection", ws => {
Expand Down
70 changes: 70 additions & 0 deletions packages/loro-websocket/tests/handshake-auth.test.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
import { describe, it, expect, beforeAll, afterAll } from "vitest";
import { WebSocket } from "ws";
import getPort from "get-port";
import { SimpleServer } from "../src/server/simple-server";

// Make WebSocket available globally for the client
Object.defineProperty(globalThis, "WebSocket", {
value: WebSocket,
configurable: true,
writable: true,
});

describe("Handshake Auth", () => {
let server: SimpleServer;
let port: number;

beforeAll(async () => {
port = await getPort();
server = new SimpleServer({
port,
handshakeAuth: req => {
const cookie = req.headers.cookie;
return cookie === "session=valid";
},
});
await server.start();
});

afterAll(async () => {
await server.stop();
}, 10000);

it("should accept connection with valid cookie", async () => {
const ws = new WebSocket(`ws://localhost:${port}`, {
headers: {
Cookie: "session=valid",
},
});

await new Promise<void>((resolve, reject) => {
ws.onopen = () => resolve();
ws.onerror = err => reject(err);
});
ws.close();
});

it("should reject connection with invalid cookie", async () => {
const ws = new WebSocket(`ws://localhost:${port}`, {
headers: {
Cookie: "session=invalid",
},
});

await new Promise<void>((resolve, reject) => {
ws.onopen = () => reject(new Error("Should have failed"));
ws.onerror = err => {
resolve();
};
});
});

it("should reject connection with missing cookie", async () => {
const ws = new WebSocket(`ws://localhost:${port}`);

await new Promise<void>((resolve, reject) => {
ws.onopen = () => reject(new Error("Should have failed"));
ws.onerror = () => resolve();
});
});
});
63 changes: 63 additions & 0 deletions rust/Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions rust/loro-websocket-server/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ tokio-tungstenite = "0.27"
futures-util = { version = "0.3", default-features = false, features = ["sink"] }
loro = "1"
tracing = "0.1"
cookie = "0.18.1"

[dev-dependencies]
loro-websocket-client = { version = "0.1.0", path = "../loro-websocket-client" }
Expand Down
79 changes: 74 additions & 5 deletions rust/loro-websocket-server/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -104,11 +104,40 @@ type LoadFuture<DocCtx> =
type SaveFuture = Pin<Box<dyn Future<Output = Result<(), String>> + Send + 'static>>;
type LoadFn<DocCtx> = Arc<dyn Fn(LoadDocArgs) -> LoadFuture<DocCtx> + Send + Sync>;
type SaveFn<DocCtx> = Arc<dyn Fn(SaveDocArgs<DocCtx>) -> SaveFuture + Send + Sync>;

/// Arguments provided to `authenticate`.
pub struct AuthArgs {
pub room: String,
pub crdt: CrdtType,
pub auth: Vec<u8>,
pub conn_id: u64,
}

type AuthFuture =
Pin<Box<dyn Future<Output = Result<Option<Permission>, String>> + Send + 'static>>;
type AuthFn = Arc<dyn Fn(String, CrdtType, Vec<u8>) -> AuthFuture + Send + Sync>;
type AuthFn = Arc<dyn Fn(AuthArgs) -> AuthFuture + Send + Sync>;

/// Arguments provided to `handshake_auth`.
pub struct HandshakeAuthArgs<'a> {
pub workspace: &'a str,
pub token: Option<&'a str>,
pub request: &'a tungstenite::handshake::server::Request,
pub conn_id: u64,
}

type HandshakeAuthFn = dyn Fn(&str, Option<&str>) -> bool + Send + Sync;
type HandshakeAuthFn = dyn Fn(HandshakeAuthArgs) -> bool + Send + Sync;

/// Arguments provided to `on_close_connection`.
pub struct CloseConnectionArgs {
pub workspace: String,
pub conn_id: u64,
pub rooms: Vec<(CrdtType, String)>,
}

type CloseConnectionFuture =
Pin<Box<dyn Future<Output = Result<(), String>> + Send + 'static>>;
type CloseConnectionFn =
Arc<dyn Fn(CloseConnectionArgs) -> CloseConnectionFuture + Send + Sync>;

#[derive(Clone)]
pub struct ServerConfig<DocCtx = ()> {
Expand All @@ -122,9 +151,14 @@ pub struct ServerConfig<DocCtx = ()> {
/// Parameters:
/// - `workspace_id`: extracted from request path `/{workspace}` (empty if missing)
/// - `token`: `token` query parameter if present
/// - `request`: the full HTTP request (headers, uri, etc)
/// - `conn_id`: the connection id
///
/// Return true to accept, false to reject with 401.
pub handshake_auth: Option<Arc<HandshakeAuthFn>>,
/// Optional hook invoked after a connection fully closes.
/// Receives the workspace id, connection id, and rooms the client had joined.
pub on_close_connection: Option<CloseConnectionFn>,
}

// CRDT document abstraction to reduce match-based branching
Expand Down Expand Up @@ -440,6 +474,7 @@ impl<DocCtx> Default for ServerConfig<DocCtx> {
default_permission: Permission::Write,
authenticate: None,
handshake_auth: None,
on_close_connection: None,
}
}
}
Expand Down Expand Up @@ -884,12 +919,18 @@ async fn handle_conn<DocCtx>(
where
DocCtx: Clone + Send + Sync + 'static,
{

// Generate a connection id
let conn_id = NEXT_ID.fetch_add(1, Ordering::Relaxed);

// Capture config outside of non-async closure
let handshake_auth = registry.config.handshake_auth.clone();
let close_connection = registry.config.on_close_connection.clone();
let workspace_holder: Arc<std::sync::Mutex<Option<String>>> =
Arc::new(std::sync::Mutex::new(None));
let workspace_holder_c = workspace_holder.clone();


let ws = accept_hdr_async(
stream,
move |req: &tungstenite::handshake::server::Request,
Expand Down Expand Up @@ -925,7 +966,12 @@ where
None
});

let allowed = (check)(workspace_id, token);
let allowed = (check)(HandshakeAuthArgs {
workspace: workspace_id,
token,
request: req,
conn_id,
});
if !allowed {
warn!(workspace=%workspace_id, token=?token, "handshake auth denied");
// Build a 401 Unauthorized response
Expand Down Expand Up @@ -971,7 +1017,6 @@ where
}
});

let conn_id = NEXT_ID.fetch_add(1, Ordering::Relaxed);
let mut joined_rooms: HashSet<RoomKey> = HashSet::new();

while let Some(msg) = stream.next().await {
Expand Down Expand Up @@ -1001,7 +1046,14 @@ where
let mut permission = h.config.default_permission;
if let Some(auth_fn) = &h.config.authenticate {
let room_str = room.room.clone();
match (auth_fn)(room_str, room.crdt, auth.clone()).await {
match (auth_fn)(AuthArgs {
room: room_str,
crdt: room.crdt,
auth: auth.clone(),
conn_id,
})
.await
{
Ok(Some(p)) => {
permission = p;
}
Expand Down Expand Up @@ -1387,6 +1439,11 @@ where
}
}

let rooms_for_hook: Vec<(CrdtType, String)> = joined_rooms
.into_iter()
.map(|RoomKey { crdt, room }| (crdt, room))
.collect();

// cleanup
{
let mut h = hub.lock().await;
Expand All @@ -1395,6 +1452,18 @@ where
// drop tx to stop writer
drop(tx);
let _ = sink_task.await;

if let Some(hook) = close_connection {
let args = CloseConnectionArgs {
workspace: workspace_id.clone(),
conn_id,
rooms: rooms_for_hook,
};
if let Err(e) = (hook)(args).await {
warn!(conn_id, %e, "on_close_connection hook failed");
}
}

debug!(conn_id, "connection closed and cleaned up");
Ok(())
}
Loading
Loading