From 471f94c35099c5aecbf5cfbd64c81357de056ba6 Mon Sep 17 00:00:00 2001 From: totodore Date: Tue, 6 May 2025 00:23:09 +0200 Subject: [PATCH 01/31] wip --- Cargo.lock | 446 +++++++++++ crates/socketioxide-postgres/Cargo.toml | 58 ++ crates/socketioxide-postgres/README.md | 0 .../socketioxide-postgres/src/drivers/mod.rs | 15 + .../src/drivers/postgres.rs | 34 + .../socketioxide-postgres/src/drivers/sqlx.rs | 43 ++ crates/socketioxide-postgres/src/lib.rs | 694 ++++++++++++++++++ 7 files changed, 1290 insertions(+) create mode 100644 crates/socketioxide-postgres/Cargo.toml create mode 100644 crates/socketioxide-postgres/README.md create mode 100644 crates/socketioxide-postgres/src/drivers/mod.rs create mode 100644 crates/socketioxide-postgres/src/drivers/postgres.rs create mode 100644 crates/socketioxide-postgres/src/drivers/sqlx.rs create mode 100644 crates/socketioxide-postgres/src/lib.rs diff --git a/Cargo.lock b/Cargo.lock index fcbc8e71..20d29952 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -54,6 +54,12 @@ dependencies = [ "memchr", ] +[[package]] +name = "allocator-api2" +version = "0.2.21" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "683d7910e743518b0e34f1186f92494becacb047c7b6bf616c96772180fef923" + [[package]] name = "android-tzdata" version = "0.1.1" @@ -107,6 +113,15 @@ dependencies = [ "syn 2.0.100", ] +[[package]] +name = "atoi" +version = "2.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f28d99ec8bfea296261ca1af174f24225171fea9664ba9003cbebee704810528" +dependencies = [ + "num-traits", +] + [[package]] name = "autocfg" version = "1.4.0" @@ -387,6 +402,15 @@ dependencies = [ "tokio-util", ] +[[package]] +name = "concurrent-queue" +version = "2.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4ca0197aee26d1ae37445ee532fefce43251d24cc7c166799f4d46817f1d3973" +dependencies = [ + "crossbeam-utils", +] + [[package]] name = "const-random" version = "0.1.18" @@ 
-434,6 +458,21 @@ dependencies = [ "libc", ] +[[package]] +name = "crc" +version = "3.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "69e6e4d7b33a94f0991c26729976b10ebde1d34c3ee82408fb536164fa10d636" +dependencies = [ + "crc-catalog", +] + +[[package]] +name = "crc-catalog" +version = "2.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "19d374276b40fb8bbdee95aef7c7fa6b5316ec764510eb64b8dd0e2ed0d7e7f5" + [[package]] name = "crc16" version = "0.4.0" @@ -494,6 +533,15 @@ dependencies = [ "crossbeam-utils", ] +[[package]] +name = "crossbeam-queue" +version = "0.3.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0f58bbc28f91df819d0aa2a2c00cd19754769c2fad90579b3592b1c9ba7a3115" +dependencies = [ + "crossbeam-utils", +] + [[package]] name = "crossbeam-utils" version = "0.8.21" @@ -635,11 +683,20 @@ dependencies = [ "syn 2.0.100", ] +[[package]] +name = "dotenvy" +version = "0.15.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1aaf95b3e5c8f23aa320147307562d361db0ae0d51242340f558153b4eb2439b" + [[package]] name = "either" version = "1.15.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "48c757948c5ede0e46177b7add2e67155f70e33c07fea8284df6576da70b3719" +dependencies = [ + "serde", +] [[package]] name = "engineioxide" @@ -705,6 +762,50 @@ version = "1.0.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "877a4ace8713b0bcf2a4e7eec82529c029f1d0619886d18145fea96c3ffe5c0f" +[[package]] +name = "errno" +version = "0.3.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "976dd42dc7e85965fe702eb8164f21f450704bdde31faefd6471dba214cb594e" +dependencies = [ + "libc", + "windows-sys 0.59.0", +] + +[[package]] +name = "etcetera" +version = "0.8.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"136d1b5283a1ab77bd9257427ffd09d8667ced0570b6f938942bc7568ed5b943" +dependencies = [ + "cfg-if", + "home", + "windows-sys 0.48.0", +] + +[[package]] +name = "event-listener" +version = "5.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3492acde4c3fc54c845eaab3eed8bd00c7a7d881f78bfc801e43a93dec1331ae" +dependencies = [ + "concurrent-queue", + "parking", + "pin-project-lite", +] + +[[package]] +name = "fallible-iterator" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4443176a9f2c162692bd3d352d745ef9413eec5782a80d8fd6f8a1ac692a07f7" + +[[package]] +name = "fastrand" +version = "2.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "37909eebbb50d72f9059c3b6d82c0463f2ff062c9e95845c43a6c9c0355411be" + [[package]] name = "float-cmp" version = "0.10.0" @@ -720,6 +821,12 @@ version = "1.0.7" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3f9eec918d3f24069decb9af1554cad7c880e2da24a9afd88aca000531ab82c1" +[[package]] +name = "foldhash" +version = "0.1.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d9c4f5dac5e15c24eb999c26181a6ca40b39fe946cbe4c263c7209467bc83af2" + [[package]] name = "form_urlencoded" version = "1.2.1" @@ -814,6 +921,17 @@ dependencies = [ "futures-util", ] +[[package]] +name = "futures-intrusive" +version = "0.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1d930c203dd0b6ff06e0201a4a2fe9149b43c684fd4420555b26d21b1a02956f" +dependencies = [ + "futures-core", + "lock_api", + "parking_lot", +] + [[package]] name = "futures-io" version = "0.3.31" @@ -938,6 +1056,20 @@ name = "hashbrown" version = "0.15.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "bf151400ff0baff5465007dd2f3e717f3fe502074ca563069ce3a6629d07b289" +dependencies = [ + "allocator-api2", + "equivalent", + "foldhash", +] + +[[package]] +name = "hashlink" 
+version = "0.10.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7382cf6263419f2d8df38c55d7da83da5c18aef87fc7a7fc1fb1e344edfe14c1" +dependencies = [ + "hashbrown 0.15.2", +] [[package]] name = "heaptrack" @@ -952,6 +1084,12 @@ dependencies = [ "tracing-subscriber", ] +[[package]] +name = "heck" +version = "0.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2304e00983f87ffb38b55b444b5e3b60a884b5d30c0fca7d82fe33449bbe55ea" + [[package]] name = "hermit-abi" version = "0.5.0" @@ -964,6 +1102,15 @@ version = "0.4.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7f24254aa9a54b5c858eaee2f5bccdb46aaf0e486a595ed5fd8f86ba55232a70" +[[package]] +name = "hkdf" +version = "0.12.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7b5f8eb2ad728638ea2c7d47a21db23b7b58a72ed6a38256b8a1849f15fbbdf7" +dependencies = [ + "hmac", +] + [[package]] name = "hmac" version = "0.12.1" @@ -973,6 +1120,15 @@ dependencies = [ "digest", ] +[[package]] +name = "home" +version = "0.5.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "589533453244b0995c858700322199b2becb13b627df2851f64a2775d024abcf" +dependencies = [ + "windows-sys 0.59.0", +] + [[package]] name = "http" version = "1.3.1" @@ -1293,6 +1449,12 @@ version = "0.2.172" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d750af042f7ef4f724306de029d18836c26c1765a54a6a3f094cbd23a7267ffa" +[[package]] +name = "linux-raw-sys" +version = "0.9.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cd945864f07fe9f5371a27ad7b52a172b4b499999f1d97574c9fa68373937e12" + [[package]] name = "litemap" version = "0.7.5" @@ -1581,6 +1743,12 @@ version = "0.1.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b15813163c1d831bf4a13c3610c05c0d03b39feb07f7e09fa234dac9b15aaf39" +[[package]] +name = "parking" +version = 
"2.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f38d5652c16fde515bb1ecef450ab0f6a219d619a7274976324d5e377f7dceba" + [[package]] name = "parking_lot" version = "0.12.3" @@ -1625,6 +1793,24 @@ version = "2.3.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e3148f5046208a5d56bcfc03053e3ca6334e51da8dfb19b6cdc8b306fae3283e" +[[package]] +name = "phf" +version = "0.11.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1fd6780a80ae0c52cc120a26a1a42c1ae51b247a253e4e06113d23d2c2edd078" +dependencies = [ + "phf_shared", +] + +[[package]] +name = "phf_shared" +version = "0.11.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "67eabc2ef2a60eb7faa00097bd1ffdb5bd28e62bf39990626a582201b7a754e5" +dependencies = [ + "siphasher", +] + [[package]] name = "pin-project-lite" version = "0.2.16" @@ -1637,6 +1823,35 @@ version = "0.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8b870d8c151b6f2fb93e84a13146138f05d02ed11c7e7c54f8826aaaf7c9f184" +[[package]] +name = "postgres-protocol" +version = "0.6.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "76ff0abab4a9b844b93ef7b81f1efc0a366062aaef2cd702c76256b5dc075c54" +dependencies = [ + "base64 0.22.1", + "byteorder", + "bytes", + "fallible-iterator", + "hmac", + "md-5", + "memchr", + "rand 0.9.1", + "sha2", + "stringprep", +] + +[[package]] +name = "postgres-types" +version = "0.2.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "613283563cd90e1dfc3518d548caee47e0e725455ed619881f5cf21f36de4b48" +dependencies = [ + "bytes", + "fallible-iterator", + "postgres-protocol", +] + [[package]] name = "powerfmt" version = "0.2.0" @@ -1927,6 +2142,19 @@ dependencies = [ "semver", ] +[[package]] +name = "rustix" +version = "1.0.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"c71e83d6afe7ff64890ec6b71d6a69bb8a610ab78ce364b3352876bb4c801266" +dependencies = [ + "bitflags 2.9.0", + "errno", + "libc", + "linux-raw-sys", + "windows-sys 0.59.0", +] + [[package]] name = "rustls" version = "0.21.12" @@ -2158,6 +2386,12 @@ dependencies = [ "libc", ] +[[package]] +name = "siphasher" +version = "1.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "56199f7ddabf13fe5074ce809e7d3f42b42ae711800501b5b16ea82ad029c39d" + [[package]] name = "slab" version = "0.4.9" @@ -2296,6 +2530,27 @@ dependencies = [ "socketioxide-core", ] +[[package]] +name = "socketioxide-postgres" +version = "0.1.0" +dependencies = [ + "bytes", + "futures-core", + "futures-util", + "pin-project-lite", + "rmp-serde", + "serde", + "smallvec", + "socketioxide", + "socketioxide-core", + "sqlx", + "thiserror 2.0.12", + "tokio", + "tokio-postgres", + "tracing", + "tracing-subscriber", +] + [[package]] name = "socketioxide-redis" version = "0.2.2" @@ -2318,6 +2573,122 @@ dependencies = [ "tracing-subscriber", ] +[[package]] +name = "sqlx" +version = "0.8.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f3c3a85280daca669cfd3bcb68a337882a8bc57ec882f72c5d13a430613a738e" +dependencies = [ + "sqlx-core", + "sqlx-macros", + "sqlx-postgres", +] + +[[package]] +name = "sqlx-core" +version = "0.8.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f743f2a3cea30a58cd479013f75550e879009e3a02f616f18ca699335aa248c3" +dependencies = [ + "base64 0.22.1", + "bytes", + "crc", + "crossbeam-queue", + "either", + "event-listener", + "futures-core", + "futures-intrusive", + "futures-io", + "futures-util", + "hashbrown 0.15.2", + "hashlink", + "indexmap 2.9.0", + "log", + "memchr", + "once_cell", + "percent-encoding", + "serde", + "serde_json", + "sha2", + "smallvec", + "thiserror 2.0.12", + "tracing", + "url", +] + +[[package]] +name = "sqlx-macros" +version = "0.8.5" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "7f4200e0fde19834956d4252347c12a083bdcb237d7a1a1446bffd8768417dce" +dependencies = [ + "proc-macro2", + "quote", + "sqlx-core", + "sqlx-macros-core", + "syn 2.0.100", +] + +[[package]] +name = "sqlx-macros-core" +version = "0.8.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "882ceaa29cade31beca7129b6beeb05737f44f82dbe2a9806ecea5a7093d00b7" +dependencies = [ + "dotenvy", + "either", + "heck", + "hex", + "once_cell", + "proc-macro2", + "quote", + "serde", + "serde_json", + "sha2", + "sqlx-core", + "sqlx-postgres", + "syn 2.0.100", + "tempfile", + "url", +] + +[[package]] +name = "sqlx-postgres" +version = "0.8.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a0bedbe1bbb5e2615ef347a5e9d8cd7680fb63e77d9dafc0f29be15e53f1ebe6" +dependencies = [ + "atoi", + "base64 0.22.1", + "bitflags 2.9.0", + "byteorder", + "crc", + "dotenvy", + "etcetera", + "futures-channel", + "futures-core", + "futures-util", + "hex", + "hkdf", + "hmac", + "home", + "itoa", + "log", + "md-5", + "memchr", + "once_cell", + "rand 0.8.5", + "serde", + "serde_json", + "sha2", + "smallvec", + "sqlx-core", + "stringprep", + "thiserror 2.0.12", + "tracing", + "whoami", +] + [[package]] name = "stable_deref_trait" version = "1.2.0" @@ -2407,6 +2778,19 @@ version = "1.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "55937e1799185b12863d447f42597ed69d9928686b8d88a1df17376a097d8369" +[[package]] +name = "tempfile" +version = "3.19.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7437ac7763b9b123ccf33c338a5cc1bac6f69b45a136c19bdd8a65e3916435bf" +dependencies = [ + "fastrand", + "getrandom 0.3.2", + "once_cell", + "rustix", + "windows-sys 0.59.0", +] + [[package]] name = "thiserror" version = "1.0.69" @@ -2561,6 +2945,32 @@ dependencies = [ "syn 2.0.100", ] +[[package]] +name = "tokio-postgres" +version = "0.7.13" 
+source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6c95d533c83082bb6490e0189acaa0bbeef9084e60471b696ca6988cd0541fb0" +dependencies = [ + "async-trait", + "byteorder", + "bytes", + "fallible-iterator", + "futures-channel", + "futures-util", + "log", + "parking_lot", + "percent-encoding", + "phf", + "pin-project-lite", + "postgres-protocol", + "postgres-types", + "rand 0.9.1", + "socket2", + "tokio", + "tokio-util", + "whoami", +] + [[package]] name = "tokio-rustls" version = "0.24.1" @@ -2855,6 +3265,12 @@ dependencies = [ "wit-bindgen-rt", ] +[[package]] +name = "wasite" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b8dad83b4f25e74f184f64c43b150b91efe7647395b42289f38e50566d82855b" + [[package]] name = "wasm-bindgen" version = "0.2.100" @@ -2913,12 +3329,33 @@ dependencies = [ "unicode-ident", ] +[[package]] +name = "web-sys" +version = "0.3.77" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "33b6dd2ef9186f1f2072e409e99cd22a975331a6b3591b12c764e0e55c60d5d2" +dependencies = [ + "js-sys", + "wasm-bindgen", +] + [[package]] name = "webpki-roots" version = "0.25.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "5f20c57d8d7db6d3b86154206ae5d8fba62dd39573114de97c2cb0578251f8e1" +[[package]] +name = "whoami" +version = "1.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6994d13118ab492c3c80c1f81928718159254c53c472bf9ce36f8dae4add02a7" +dependencies = [ + "redox_syscall", + "wasite", + "web-sys", +] + [[package]] name = "winapi" version = "0.3.9" @@ -3018,6 +3455,15 @@ dependencies = [ "windows-link", ] +[[package]] +name = "windows-sys" +version = "0.48.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "677d2418bec65e3338edb076e806bc1ec15693c5d0104683f2efe857f61056a9" +dependencies = [ + "windows-targets 0.48.5", +] + [[package]] name = "windows-sys" version = "0.52.0" 
diff --git a/crates/socketioxide-postgres/Cargo.toml b/crates/socketioxide-postgres/Cargo.toml new file mode 100644 index 00000000..df86ade0 --- /dev/null +++ b/crates/socketioxide-postgres/Cargo.toml @@ -0,0 +1,58 @@ +[package] +name = "socketioxide-postgres" +description = "PostgreSQL adapter for socketioxide" +version = "0.1.0" +edition.workspace = true +rust-version.workspace = true +authors.workspace = true +repository.workspace = true +homepage.workspace = true +keywords.workspace = true +categories.workspace = true +license.workspace = true +readme = "README.md" + +[features] +sqlx = ["dep:sqlx"] +postgres = ["dep:tokio-postgres"] +default = ["postgres"] + +[dependencies] +socketioxide-core = { version = "0.17", path = "../socketioxide-core", features = [ + "remote-adapter", +] } +futures-core.workspace = true +futures-util.workspace = true +pin-project-lite.workspace = true +serde.workspace = true +smallvec = { workspace = true, features = ["serde"] } +tokio = { workspace = true, features = ["time", "rt", "sync"] } +rmp-serde.workspace = true +tracing.workspace = true +thiserror.workspace = true + +# PostgreSQL implementations +tokio-postgres = { version = "0.7", default-features = false, optional = true, features = [ + "runtime", +] } +sqlx = { version = "0.8", default-features = false, optional = true, features = [ + "postgres", +] } + +[dev-dependencies] +tokio = { workspace = true, features = [ + "macros", + "parking_lot", + "rt-multi-thread", +] } +socketioxide = { path = "../socketioxide", features = [ + "tracing", + "__test_harness", +] } +tracing-subscriber.workspace = true +bytes.workspace = true + +# docs.rs-specific configuration +[package.metadata.docs.rs] +all-features = true +rustdoc-args = ["--cfg", "docsrs"] diff --git a/crates/socketioxide-postgres/README.md b/crates/socketioxide-postgres/README.md new file mode 100644 index 00000000..e69de29b diff --git a/crates/socketioxide-postgres/src/drivers/mod.rs 
b/crates/socketioxide-postgres/src/drivers/mod.rs new file mode 100644 index 00000000..976f9821 --- /dev/null +++ b/crates/socketioxide-postgres/src/drivers/mod.rs @@ -0,0 +1,15 @@ +mod postgres; +mod sqlx; + +pub type ChanItem = (String, String); + +/// The driver trait can be used to support different LISTEN/NOTIFY backends. +/// It must share handlers/connection between its clones. +pub trait Driver: Clone + Send + Sync + 'static { + type Error: std::error::Error + Send + 'static; + + fn init(&self, table: &str, channels: &[&str]) + -> impl Future>; + fn notify(&self, channel: &str, message: &str) + -> impl Future>; +} diff --git a/crates/socketioxide-postgres/src/drivers/postgres.rs b/crates/socketioxide-postgres/src/drivers/postgres.rs new file mode 100644 index 00000000..8e5b6408 --- /dev/null +++ b/crates/socketioxide-postgres/src/drivers/postgres.rs @@ -0,0 +1,34 @@ +use std::sync::Arc; + +use tokio_postgres::{Client, Connection}; + +use crate::PostgresAdapterConfig; + +use super::Driver; + +#[derive(Debug, Clone)] +pub struct PostgresDriver { + client: Arc, +} + +impl PostgresDriver { + pub fn new(client: Client, connection: Connection) -> Self { + PostgresDriver { + client: Arc::new(client), + } + } +} + +impl Driver for PostgresDriver { + type Error = tokio_postgres::Error; + async fn init(&self, table: &str, channels: &[&str]) -> Result<(), Self::Error> { + self.client + .execute("CREATE TABLE $1 IF NOT EXISTS", &[&table]) + .await?; + Ok(()) + } + + async fn notify(&self, channel: &str, msg: &str) -> Result<(), Self::Error> { + todo!() + } +} diff --git a/crates/socketioxide-postgres/src/drivers/sqlx.rs b/crates/socketioxide-postgres/src/drivers/sqlx.rs new file mode 100644 index 00000000..fa63e56b --- /dev/null +++ b/crates/socketioxide-postgres/src/drivers/sqlx.rs @@ -0,0 +1,43 @@ +use std::{collections::HashMap, sync::Arc}; + +use sqlx::{PgPool, postgres::PgListener}; +use tokio::sync::mpsc; + +use super::{ChanItem, Driver}; + +#[derive(Debug, 
Clone)] +pub struct SqlxDriver { + client: PgPool, +} +impl SqlxDriver { + pub fn new(client: PgPool) -> Self { + Self { client } + } + + async fn spawn_listener(&self, mut listener: PgListener, tx: mpsc::Sender) { + while let Ok(notif) = listener + .recv() + .await + .inspect_err(|e| tracing::warn!(?e, "sqlx listener error")) + {} + } +} + +impl Driver for SqlxDriver { + type Error = sqlx::Error; + + async fn init(&self, table: &str, channels: &[&str]) -> Result<(), Self::Error> { + sqlx::query("CREATE TABLE $1 IF NOT EXISTS") + .bind(&table) + .execute(&self.client) + .await?; + let mut listener = PgListener::connect_with(&self.client).await?; + listener.listen_all(channels.iter().copied()).await?; + + Ok(()) + } + + async fn notify(&self, channel: &str, msg: &str) -> Result<(), Self::Error> { + todo!() + } +} diff --git a/crates/socketioxide-postgres/src/lib.rs b/crates/socketioxide-postgres/src/lib.rs new file mode 100644 index 00000000..a3441158 --- /dev/null +++ b/crates/socketioxide-postgres/src/lib.rs @@ -0,0 +1,694 @@ +#![cfg_attr(docsrs, feature(doc_auto_cfg))] +#![warn( + clippy::all, + clippy::todo, + clippy::empty_enum, + clippy::mem_forget, + clippy::unused_self, + clippy::filter_map_next, + clippy::needless_continue, + clippy::needless_borrow, + clippy::match_wildcard_for_single_variants, + clippy::if_let_mutex, + clippy::await_holding_lock, + clippy::match_on_vec_items, + clippy::imprecise_flops, + clippy::suboptimal_flops, + clippy::lossy_float_literal, + clippy::rest_pat_in_fully_bound_structs, + clippy::fn_params_excessive_bools, + clippy::exit, + clippy::inefficient_to_string, + clippy::linkedlist, + clippy::macro_use_imports, + clippy::option_option, + clippy::verbose_file_reads, + clippy::unnested_or_patterns, + rust_2018_idioms, + future_incompatible, + nonstandard_style, + missing_docs +)] +//! 
+ +use drivers::Driver; + +use futures_core::Stream; +use serde::{Serialize, de::DeserializeOwned}; +use socketioxide_core::{ + Sid, Uid, + adapter::{ + BroadcastOptions, CoreAdapter, CoreLocalAdapter, DefinedAdapter, RemoteSocketData, Room, + RoomParam, SocketEmitter, Spawnable, + errors::AdapterError, + remote_packet::{ + RequestIn, RequestOut, RequestTypeIn, RequestTypeOut, Response, ResponseType, + ResponseTypeId, + }, + }, + packet::Packet, +}; +use std::{ + borrow::Cow, + collections::HashMap, + fmt, future, + pin::Pin, + sync::{Arc, Mutex}, + task::{Context, Poll}, + time::{Duration, Instant}, +}; +use tokio::sync::mpsc; + +mod drivers; + +/// The configuration of the [`MongoDbAdapter`]. +#[derive(Debug, Clone)] +pub struct PostgresAdapterConfig { + /// The heartbeat timeout duration. If a remote node does not respond within this duration, + /// it will be considered disconnected. Default is 60 seconds. + pub hb_timeout: Duration, + /// The heartbeat interval duration. The current node will broadcast a heartbeat to the + /// remote nodes at this interval. Default is 10 seconds. + pub hb_interval: Duration, + /// The request timeout. When expecting a response from remote nodes, if they do not respond within + /// this duration, the request will be considered failed. Default is 5 seconds. + pub request_timeout: Duration, + /// The channel size used to receive ack responses. Default is 255. + /// + /// If you have a lot of servers/sockets and that you may miss acknowledgement because they arrive faster + /// than you poll them with the returned stream, you might want to increase this value. + pub ack_response_buffer: usize, + /// The table name used to store socket.io attachments. Default is "socket_io_attachments". + pub table_name: Cow<'static, str>, + /// The prefix used for the channels. Default is "socket.io". + pub prefix: Cow<'static, str>, + /// The treshold to the payload size in bytes. 
It should match the configured value on your PostgreSQL instance: + /// + pub payload_treshold: usize, + /// The duration between cleanup queries on the + pub cleanup_intervals: Duration, +} + +/// Represent any error that might happen when using this adapter. +#[derive(thiserror::Error)] +pub enum Error { + /// Mongo driver error + #[error("driver error: {0}")] + Driver(D::Error), + /// Packet encoding error + #[error("packet encoding error: {0}")] + Encode(#[from] rmp_serde::encode::Error), + /// Packet decoding error + #[error("packet decoding error: {0}")] + Decode(#[from] rmp_serde::decode::Error), +} + +impl Error { + fn from_driver(err: R::Error) -> Self { + Self::Driver(err) + } +} +impl fmt::Debug for Error { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + Self::Driver(err) => write!(f, "Driver error: {:?}", err), + Self::Decode(err) => write!(f, "Decode error: {:?}", err), + Self::Encode(err) => write!(f, "Encode error: {:?}", err), + } + } +} + +impl From> for AdapterError { + fn from(err: Error) -> Self { + AdapterError::from(Box::new(err) as Box) + } +} + +pub(crate) type ResponseHandlers = HashMap>; + +/// The postgres adapter implementation. +/// It is generic over the [`Driver`] used to communicate with the postgres server. +/// And over the [`SocketEmitter`] used to communicate with the local server. This allows to +/// avoid cyclic dependencies between the adapter, `socketioxide-core` and `socketioxide` crates. +pub struct CustomPostgresAdapter { + /// The driver used by the adapter. This is used to communicate with the postgres server. + /// All the postgres adapter instances share the same driver. + driver: D, + /// The configuration of the adapter. + config: PostgresAdapterConfig, + /// A unique identifier for the adapter to identify itself in the postgres server. + uid: Uid, + /// The local adapter, used to manage local rooms and socket stores. 
+ local: CoreLocalAdapter, + /// A map of nodes liveness, with the last time remote nodes were seen alive. + nodes_liveness: Mutex>, + /// A map of response handlers used to await for responses from the remote servers. + responses: Arc>, +} + +impl DefinedAdapter for CustomPostgresAdapter {} +impl CoreAdapter for CustomPostgresAdapter { + type Error = Error; + type State = PostgresAdapterCtr; + type AckStream = AckStream; + type InitRes = InitRes; + + fn new(state: &Self::State, local: CoreLocalAdapter) -> Self { + let uid = local.server_id(); + Self { + local, + uid, + driver: state.driver.clone(), + config: state.config.clone(), + nodes_liveness: Mutex::new(Vec::new()), + responses: Arc::new(Mutex::new(HashMap::new())), + } + } + + fn init(self: Arc, on_success: impl FnOnce() + Send + 'static) -> Self::InitRes { + let fut = async move { + let stream = self.driver.watch(self.uid, self.local.path()).await?; + tokio::spawn(self.clone().handle_ev_stream(stream)); + tokio::spawn(self.clone().heartbeat_job()); + + // Send initial heartbeat when starting. + self.emit_init_heartbeat().await.map_err(|e| match e { + Error::Driver(e) => e, + Error::Encode(_) | Error::Decode(_) => unreachable!(), + })?; + + on_success(); + Ok(()) + }; + InitRes(Box::pin(fut)) + } + + async fn close(&self) -> Result<(), Self::Error> { + Ok(()) + } + + /// Get the number of servers by iterating over the node liveness heartbeats. + async fn server_count(&self) -> Result { + let treshold = std::time::Instant::now() - self.config.hb_timeout; + let mut nodes_liveness = self.nodes_liveness.lock().unwrap(); + nodes_liveness.retain(|(_, v)| v > &treshold); + Ok((nodes_liveness.len() + 1) as u16) + } + + /// Broadcast a packet to all the servers to send them through their sockets. 
+ async fn broadcast( + &self, + packet: Packet, + opts: BroadcastOptions, + ) -> Result<(), BroadcastError> { + if !opts.is_local(self.uid) { + let req = RequestOut::new(self.uid, RequestTypeOut::Broadcast(&packet), &opts); + self.send_req(req, None).await.map_err(AdapterError::from)?; + } + + self.local.broadcast(packet, opts)?; + Ok(()) + } + + /// Broadcast a packet to all the servers to send them through their sockets. + /// + /// Returns a Stream that is a combination of the local ack stream and a remote ack stream. + /// Here is a specific protocol in order to know how many message the server expect to close + /// the stream at the right time: + /// * Get the number `n` of remote servers. + /// * Send the broadcast request. + /// * Expect `n` `BroadcastAckCount` response in the stream to know the number `m` of expected ack responses. + /// * Expect `sum(m)` broadcast counts sent by the servers. + /// + /// Example with 3 remote servers (n = 3): + /// ```text + /// +---+ +---+ +---+ + /// | A | | B | | C | + /// +---+ +---+ +---+ + /// | | | + /// |---BroadcastWithAck--->| | + /// |---BroadcastWithAck--------------------------->| + /// | | | + /// |<-BroadcastAckCount(2)-| (n = 2; m = 2) | + /// |<-BroadcastAckCount(2)-------(n = 2; m = 4)----| + /// | | | + /// |<----------------Ack---------------------------| + /// |<----------------Ack---| | + /// | | | + /// |<----------------Ack---------------------------| + /// |<----------------Ack---| | + async fn broadcast_with_ack( + &self, + packet: Packet, + opts: BroadcastOptions, + timeout: Option, + ) -> Result { + if opts.is_local(self.uid) { + tracing::debug!(?opts, "broadcast with ack is local"); + let (local, _) = self.local.broadcast_with_ack(packet, opts, timeout); + let stream = AckStream::new_local(local); + return Ok(stream); + } + let req = RequestOut::new(self.uid, RequestTypeOut::BroadcastWithAck(&packet), &opts); + let req_id = req.id; + + let remote_serv_cnt = 
self.server_count().await?.saturating_sub(1); + tracing::trace!(?remote_serv_cnt, "expecting acks from remote servers"); + + let (tx, rx) = mpsc::channel(self.config.ack_response_buffer + remote_serv_cnt as usize); + self.responses.lock().unwrap().insert(req_id, tx); + self.send_req(req, None).await?; + let (local, _) = self.local.broadcast_with_ack(packet, opts, timeout); + + Ok(AckStream::new( + local, + rx, + self.config.request_timeout, + remote_serv_cnt, + req_id, + self.responses.clone(), + )) + } + + async fn disconnect_socket(&self, opts: BroadcastOptions) -> Result<(), BroadcastError> { + if !opts.is_local(self.uid) { + let req = RequestOut::new(self.uid, RequestTypeOut::DisconnectSockets, &opts); + self.send_req(req, None).await.map_err(AdapterError::from)?; + } + self.local + .disconnect_socket(opts) + .map_err(BroadcastError::Socket)?; + + Ok(()) + } + + async fn rooms(&self, opts: BroadcastOptions) -> Result, Self::Error> { + if opts.is_local(self.uid) { + return Ok(self.local.rooms(opts).into_iter().collect()); + } + let req = RequestOut::new(self.uid, RequestTypeOut::AllRooms, &opts); + let req_id = req.id; + + // First get the remote stream because mongodb might send + // the responses before subscription is done. 
+ let stream = self + .get_res::<()>(req_id, ResponseTypeId::AllRooms, opts.server_id) + .await?; + self.send_req(req, opts.server_id).await?; + let local = self.local.rooms(opts); + let rooms = stream + .filter_map(|item| std::future::ready(item.into_rooms())) + .fold(local, |mut acc, item| async move { + acc.extend(item); + acc + }) + .await; + Ok(Vec::from_iter(rooms)) + } + + async fn add_sockets( + &self, + opts: BroadcastOptions, + rooms: impl RoomParam, + ) -> Result<(), Self::Error> { + let rooms: Vec = rooms.into_room_iter().collect(); + if !opts.is_local(self.uid) { + let req = RequestOut::new(self.uid, RequestTypeOut::AddSockets(&rooms), &opts); + self.send_req(req, opts.server_id).await?; + } + self.local.add_sockets(opts, rooms); + Ok(()) + } + + async fn del_sockets( + &self, + opts: BroadcastOptions, + rooms: impl RoomParam, + ) -> Result<(), Self::Error> { + let rooms: Vec = rooms.into_room_iter().collect(); + if !opts.is_local(self.uid) { + let req = RequestOut::new(self.uid, RequestTypeOut::DelSockets(&rooms), &opts); + self.send_req(req, opts.server_id).await?; + } + self.local.del_sockets(opts, rooms); + Ok(()) + } + + async fn fetch_sockets( + &self, + opts: BroadcastOptions, + ) -> Result, Self::Error> { + if opts.is_local(self.uid) { + return Ok(self.local.fetch_sockets(opts)); + } + let req = RequestOut::new(self.uid, RequestTypeOut::FetchSockets, &opts); + // First get the remote stream because mongodb might send + // the responses before subscription is done. 
+ let remote = self + .get_res::(req.id, ResponseTypeId::FetchSockets, opts.server_id) + .await?; + + self.send_req(req, opts.server_id).await?; + let local = self.local.fetch_sockets(opts); + let sockets = remote + .filter_map(|item| future::ready(item.into_fetch_sockets())) + .fold(local, |mut acc, item| async move { + acc.extend(item); + acc + }) + .await; + Ok(sockets) + } + + fn get_local(&self) -> &CoreLocalAdapter { + &self.local + } +} + +impl CustomPostgresAdapter { + async fn heartbeat_job(self: Arc) -> Result<(), Error> { + let mut interval = tokio::time::interval(self.config.hb_interval); + interval.tick().await; // first tick yields immediately + loop { + interval.tick().await; + self.emit_heartbeat(None).await?; + } + } + + async fn handle_ev_stream( + self: Arc, + mut stream: impl Stream> + Unpin, + ) { + while let Some(item) = stream.next().await { + match item { + Ok(Item { + header: ItemHeader::Req { target, .. }, + data, + .. + }) if target.is_none_or(|id| id == self.uid) => { + tracing::debug!(?target, "request header"); + if let Err(e) = self.recv_req(data).await { + tracing::warn!("error receiving request from driver: {e}"); + } + } + Ok(Item { + header: ItemHeader::Req { target, .. }, + .. + }) => { + tracing::debug!( + ?target, + "receiving request which is not for us, skipping..." + ); + } + Ok( + item @ Item { + header: ItemHeader::Res { request, .. }, + .. 
+ }, + ) => { + tracing::trace!(?request, "received response"); + let handlers = self.responses.lock().unwrap(); + if let Some(tx) = handlers.get(&request) { + if let Err(e) = tx.try_send(item) { + tracing::warn!("error sending response to handler: {e}"); + } + } else { + tracing::warn!(?request, ?handlers, "could not find req handler"); + } + } + Err(e) => { + tracing::warn!("error receiving event from driver: {e}"); + } + } + } + } + + async fn recv_req(self: &Arc, req: Vec) -> Result<(), Error> { + let req = rmp_serde::from_slice::(&req)?; + tracing::trace!(?req, "incoming request"); + match (req.r#type, req.opts) { + (RequestTypeIn::Broadcast(p), Some(opts)) => self.recv_broadcast(opts, p), + (RequestTypeIn::BroadcastWithAck(p), Some(opts)) => self + .clone() + .recv_broadcast_with_ack(req.node_id, req.id, p, opts), + (RequestTypeIn::DisconnectSockets, Some(opts)) => self.recv_disconnect_sockets(opts), + (RequestTypeIn::AllRooms, Some(opts)) => self.recv_rooms(req.node_id, req.id, opts), + (RequestTypeIn::AddSockets(rooms), Some(opts)) => self.recv_add_sockets(opts, rooms), + (RequestTypeIn::DelSockets(rooms), Some(opts)) => self.recv_del_sockets(opts, rooms), + (RequestTypeIn::FetchSockets, Some(opts)) => { + self.recv_fetch_sockets(req.node_id, req.id, opts) + } + req_type @ (RequestTypeIn::Heartbeat | RequestTypeIn::InitHeartbeat, _) => { + self.recv_heartbeat(req_type.0, req.node_id) + } + _ => (), + } + Ok(()) + } + + fn recv_broadcast(&self, opts: BroadcastOptions, packet: Packet) { + tracing::trace!(?opts, "incoming broadcast"); + if let Err(e) = self.local.broadcast(packet, opts) { + let ns = self.local.path(); + tracing::warn!(?self.uid, ?ns, "remote request broadcast handler: {:?}", e); + } + } + + fn recv_disconnect_sockets(&self, opts: BroadcastOptions) { + if let Err(e) = self.local.disconnect_socket(opts) { + let ns = self.local.path(); + tracing::warn!( + ?self.uid, + ?ns, + "remote request disconnect sockets handler: {:?}", + e + ); + } + } + + 
fn recv_broadcast_with_ack( + self: Arc, + origin: Uid, + req_id: Sid, + packet: Packet, + opts: BroadcastOptions, + ) { + let (stream, count) = self.local.broadcast_with_ack(packet, opts, None); + tokio::spawn(async move { + let on_err = |err| { + let ns = self.local.path(); + tracing::warn!( + ?self.uid, + ?ns, + "remote request broadcast with ack handler errors: {:?}", + err + ); + }; + // First send the count of expected acks to the server that sent the request. + // This is used to keep track of the number of expected acks. + let res = Response { + r#type: ResponseType::<()>::BroadcastAckCount(count), + node_id: self.uid, + }; + if let Err(err) = self.send_res(req_id, origin, res).await { + on_err(err); + return; + } + + // Then send the acks as they are received. + futures_util::pin_mut!(stream); + while let Some(ack) = stream.next().await { + let res = Response { + r#type: ResponseType::BroadcastAck(ack), + node_id: self.uid, + }; + if let Err(err) = self.send_res(req_id, origin, res).await { + on_err(err); + return; + } + } + }); + } + + fn recv_rooms(&self, origin: Uid, req_id: Sid, opts: BroadcastOptions) { + let rooms = self.local.rooms(opts); + let res = Response { + r#type: ResponseType::<()>::AllRooms(rooms), + node_id: self.uid, + }; + let fut = self.send_res(req_id, origin, res); + let ns = self.local.path().clone(); + let uid = self.uid; + tokio::spawn(async move { + if let Err(err) = fut.await { + tracing::warn!(?uid, ?ns, "remote request rooms handler: {:?}", err); + } + }); + } + + fn recv_add_sockets(&self, opts: BroadcastOptions, rooms: Vec) { + self.local.add_sockets(opts, rooms); + } + + fn recv_del_sockets(&self, opts: BroadcastOptions, rooms: Vec) { + self.local.del_sockets(opts, rooms); + } + fn recv_fetch_sockets(&self, origin: Uid, req_id: Sid, opts: BroadcastOptions) { + let sockets = self.local.fetch_sockets(opts); + let res = Response { + node_id: self.uid, + r#type: ResponseType::FetchSockets(sockets), + }; + let fut = 
self.send_res(req_id, origin, res); + let ns = self.local.path().clone(); + let uid = self.uid; + tokio::spawn(async move { + if let Err(err) = fut.await { + tracing::warn!(?uid, ?ns, "remote request fetch sockets handler: {:?}", err); + } + }); + } + + /// Receive a heartbeat from a remote node. + /// It might be a FirstHeartbeat packet, in which case we are re-emitting a heartbeat to the remote node. + fn recv_heartbeat(self: &Arc, req_type: RequestTypeIn, origin: Uid) { + tracing::debug!(?req_type, "{:?} received", req_type); + let mut node_liveness = self.nodes_liveness.lock().unwrap(); + // Even with a FirstHeartbeat packet we first consume the node liveness to + // ensure that the node is not already in the list. + for (id, liveness) in node_liveness.iter_mut() { + if *id == origin { + *liveness = Instant::now(); + return; + } + } + + node_liveness.push((origin, Instant::now())); + + if matches!(req_type, RequestTypeIn::InitHeartbeat) { + tracing::debug!( + ?origin, + "initial heartbeat detected, saying hello to the new node" + ); + + let this = self.clone(); + tokio::spawn(async move { + if let Err(err) = this.emit_heartbeat(Some(origin)).await { + tracing::warn!( + "could not re-emit heartbeat after new node detection: {:?}", + err + ); + } + }); + } + } + + /// Send a request to a specific target node or broadcast it to all nodes if no target is specified. + async fn send_req(&self, req: RequestOut<'_>, target: Option) -> Result<(), Error> { + tracing::trace!(?req, "sending request"); + let head = ItemHeader::Req { target }; + let req = self.new_packet(head, &req)?; + self.driver.emit(&req).await.map_err(Error::from_driver)?; + Ok(()) + } + + /// Send a response to the node that sent the request. 
+ fn send_res( + &self, + req_id: Sid, + req_origin: Uid, + res: Response, + ) -> impl Future>> + Send + 'static { + tracing::trace!(?res, "sending response for {req_id} req to {req_origin}"); + let driver = self.driver.clone(); + let head = ItemHeader::Res { + request: req_id, + target: req_origin, + }; + let res = self.new_packet(head, &res); + + async move { + driver.emit(&res?).await.map_err(Error::from_driver)?; + Ok(()) + } + } + + /// Await for all the responses from the remote servers. + /// If the target node is specified, only await for the response from that node. + async fn get_res( + &self, + req_id: Sid, + response_type: ResponseTypeId, + target: Option, + ) -> Result>, Error> { + // Check for specific target node + let remote_serv_cnt = if target.is_none() { + self.server_count().await?.saturating_sub(1) as usize + } else { + 1 + }; + let (tx, rx) = mpsc::channel(std::cmp::max(remote_serv_cnt, 1)); + self.responses.lock().unwrap().insert(req_id, tx); + let stream = ChanStream::new(rx) + .filter_map(|Item { header, data, .. }| { + let data = match rmp_serde::from_slice::>(&data) { + Ok(data) => Some(data), + Err(e) => { + tracing::warn!(header = ?header, "error decoding response: {e}"); + None + } + }; + future::ready(data) + }) + .filter(move |item| future::ready(ResponseTypeId::from(&item.r#type) == response_type)) + .take(remote_serv_cnt) + .take_until(tokio::time::sleep(self.config.request_timeout)); + let stream = DropStream::new(stream, self.responses.clone(), req_id); + Ok(stream) + } + + /// Emit a heartbeat to the specified target node or broadcast to all nodes. + async fn emit_heartbeat(&self, target: Option) -> Result<(), Error> { + // Send heartbeat when starting. + self.send_req( + RequestOut::new_empty(self.uid, RequestTypeOut::Heartbeat), + target, + ) + .await + } + + /// Emit an initial heartbeat to all nodes. + async fn emit_init_heartbeat(&self) -> Result<(), Error> { + // Send initial heartbeat when starting. 
+ self.send_req( + RequestOut::new_empty(self.uid, RequestTypeOut::InitHeartbeat), + None, + ) + .await + } + fn new_packet(&self, head: ItemHeader, data: &impl Serialize) -> Result> { + let ns = &self.local.path(); + let uid = self.uid; + } +} + +/// The result of the init future. +#[must_use = "futures do nothing unless you `.await` or poll them"] +pub struct InitRes(futures_core::future::BoxFuture<'static, Result<(), D::Error>>); + +impl Future for InitRes { + type Output = Result<(), D::Error>; + + fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + self.0.as_mut().poll(cx) + } +} +impl Spawnable for InitRes { + fn spawn(self) { + tokio::spawn(async move { + if let Err(e) = self.0.await { + tracing::error!("error initializing adapter: {e}"); + } + }); + } +} From da03dc8b4013d3136bb2f15efdbdd7e6c6cf24f7 Mon Sep 17 00:00:00 2001 From: totodore Date: Wed, 7 May 2025 16:35:04 +0200 Subject: [PATCH 02/31] feat(adapter/postgres): `Driver` wip --- .../socketioxide-postgres/src/drivers/mod.rs | 9 ++ .../src/drivers/postgres.rs | 3 +- .../socketioxide-postgres/src/drivers/sqlx.rs | 87 ++++++++++++++++--- 3 files changed, 85 insertions(+), 14 deletions(-) diff --git a/crates/socketioxide-postgres/src/drivers/mod.rs b/crates/socketioxide-postgres/src/drivers/mod.rs index 976f9821..10c9333a 100644 --- a/crates/socketioxide-postgres/src/drivers/mod.rs +++ b/crates/socketioxide-postgres/src/drivers/mod.rs @@ -1,3 +1,5 @@ +use serde::de::DeserializeOwned; + mod postgres; mod sqlx; @@ -7,9 +9,16 @@ pub type ChanItem = (String, String); /// It must share handlers/connection between its clones. 
pub trait Driver: Clone + Send + Sync + 'static { type Error: std::error::Error + Send + 'static; + type NotifStream: futures_core::Stream + Send + 'static; fn init(&self, table: &str, channels: &[&str]) -> impl Future>; + + fn listen( + &self, + channel: &str, + ) -> impl Future, Self::Error>>; + fn notify(&self, channel: &str, message: &str) -> impl Future>; } diff --git a/crates/socketioxide-postgres/src/drivers/postgres.rs b/crates/socketioxide-postgres/src/drivers/postgres.rs index 8e5b6408..86dda23f 100644 --- a/crates/socketioxide-postgres/src/drivers/postgres.rs +++ b/crates/socketioxide-postgres/src/drivers/postgres.rs @@ -2,8 +2,6 @@ use std::sync::Arc; use tokio_postgres::{Client, Connection}; -use crate::PostgresAdapterConfig; - use super::Driver; #[derive(Debug, Clone)] @@ -21,6 +19,7 @@ impl PostgresDriver { impl Driver for PostgresDriver { type Error = tokio_postgres::Error; + async fn init(&self, table: &str, channels: &[&str]) -> Result<(), Self::Error> { self.client .execute("CREATE TABLE $1 IF NOT EXISTS", &[&table]) diff --git a/crates/socketioxide-postgres/src/drivers/sqlx.rs b/crates/socketioxide-postgres/src/drivers/sqlx.rs index fa63e56b..439f0ed3 100644 --- a/crates/socketioxide-postgres/src/drivers/sqlx.rs +++ b/crates/socketioxide-postgres/src/drivers/sqlx.rs @@ -1,31 +1,66 @@ -use std::{collections::HashMap, sync::Arc}; +use std::{ + collections::HashMap, + marker::PhantomData, + sync::{Arc, RwLock}, +}; -use sqlx::{PgPool, postgres::PgListener}; +use futures_core::Stream; +use serde::de::DeserializeOwned; +use sqlx::{ + PgPool, + postgres::{PgListener, PgNotification}, +}; use tokio::sync::mpsc; -use super::{ChanItem, Driver}; +use super::Driver; +type HandlerMap = HashMap>; #[derive(Debug, Clone)] pub struct SqlxDriver { client: PgPool, + handlers: Arc>, } impl SqlxDriver { pub fn new(client: PgPool) -> Self { - Self { client } + Self { + client, + handlers: Arc::new(RwLock::new(HashMap::new())), + } } +} - async fn 
spawn_listener(&self, mut listener: PgListener, tx: mpsc::Sender) { - while let Ok(notif) = listener - .recv() - .await - .inspect_err(|e| tracing::warn!(?e, "sqlx listener error")) - {} +pin_project_lite::pin_project! { + pub struct NotifStream { + #[pin] + rx: tokio::sync::mpsc::Receiver, + _phantom: std::marker::PhantomData T> + } +} +impl Stream for NotifStream { + type Item = T; + + fn poll_next( + mut self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> std::task::Poll> { + match self.rx.poll_recv(cx) { + std::task::Poll::Ready(_) => todo!(), + std::task::Poll::Pending => todo!(), + } + } +} +impl NotifStream { + pub fn new(rx: mpsc::Receiver) -> Self { + NotifStream { + rx, + _phantom: PhantomData::default(), + } } } impl Driver for SqlxDriver { type Error = sqlx::Error; - + type NotifStream = NotifStream; async fn init(&self, table: &str, channels: &[&str]) -> Result<(), Self::Error> { sqlx::query("CREATE TABLE $1 IF NOT EXISTS") .bind(&table) @@ -33,11 +68,39 @@ impl Driver for SqlxDriver { .await?; let mut listener = PgListener::connect_with(&self.client).await?; listener.listen_all(channels.iter().copied()).await?; + tokio::spawn(spawn_listener(self.handlers.clone(), listener)); Ok(()) } + async fn listen( + &self, + channel: &str, + ) -> Result, Self::Error> { + let (tx, rx) = mpsc::channel(255); + self.handlers.write().unwrap().insert(channel.into(), tx); + Ok(NotifStream::new(rx)) + } async fn notify(&self, channel: &str, msg: &str) -> Result<(), Self::Error> { - todo!() + sqlx::query("NOTIFY $1 $2") + .bind(channel) + .bind(msg) + .execute(&self.client) + .await?; + Ok(()) + } +} + +async fn spawn_listener(handlers: Arc>, mut listener: PgListener) { + while let Ok(notif) = listener + .recv() + .await + .inspect_err(|e| tracing::warn!(?e, "sqlx listener error")) + { + if let Some(tx) = handlers.read().unwrap().get(notif.channel()) { + tx.try_send(notif); + } else { + tracing::warn!("handler not found for channel {}", 
notif.channel()); + } } } From 092e4474d14bbd9a50e5a64c086a57ae9db26da9 Mon Sep 17 00:00:00 2001 From: totodore Date: Sun, 22 Jun 2025 20:44:38 +0200 Subject: [PATCH 03/31] wip --- Cargo.lock | 1 + crates/socketioxide-postgres/Cargo.toml | 1 + .../socketioxide-postgres/src/drivers/mod.rs | 17 +- .../socketioxide-postgres/src/drivers/sqlx.rs | 41 ++- crates/socketioxide-postgres/src/lib.rs | 109 ++----- crates/socketioxide-postgres/src/stream.rs | 265 ++++++++++++++++++ 6 files changed, 333 insertions(+), 101 deletions(-) create mode 100644 crates/socketioxide-postgres/src/stream.rs diff --git a/Cargo.lock b/Cargo.lock index 20d29952..013253bf 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2540,6 +2540,7 @@ dependencies = [ "pin-project-lite", "rmp-serde", "serde", + "serde_json", "smallvec", "socketioxide", "socketioxide-core", diff --git a/crates/socketioxide-postgres/Cargo.toml b/crates/socketioxide-postgres/Cargo.toml index df86ade0..537fab4b 100644 --- a/crates/socketioxide-postgres/Cargo.toml +++ b/crates/socketioxide-postgres/Cargo.toml @@ -25,6 +25,7 @@ futures-core.workspace = true futures-util.workspace = true pin-project-lite.workspace = true serde.workspace = true +serde_json.workspace = true smallvec = { workspace = true, features = ["serde"] } tokio = { workspace = true, features = ["time", "rt", "sync"] } rmp-serde.workspace = true diff --git a/crates/socketioxide-postgres/src/drivers/mod.rs b/crates/socketioxide-postgres/src/drivers/mod.rs index 10c9333a..d038e341 100644 --- a/crates/socketioxide-postgres/src/drivers/mod.rs +++ b/crates/socketioxide-postgres/src/drivers/mod.rs @@ -1,15 +1,19 @@ -use serde::de::DeserializeOwned; +use futures_core::Stream; +use serde::{Deserialize, Serialize, de::DeserializeOwned}; mod postgres; mod sqlx; pub type ChanItem = (String, String); +#[derive(Deserialize)] +pub struct Item {} + /// The driver trait can be used to support different LISTEN/NOTIFY backends. /// It must share handlers/connection between its clones. 
pub trait Driver: Clone + Send + Sync + 'static { type Error: std::error::Error + Send + 'static; - type NotifStream: futures_core::Stream + Send + 'static; + type NotifStream: Stream + Send + 'static; fn init(&self, table: &str, channels: &[&str]) -> impl Future>; @@ -17,8 +21,11 @@ pub trait Driver: Clone + Send + Sync + 'static { fn listen( &self, channel: &str, - ) -> impl Future, Self::Error>>; + ) -> impl Future, Self::Error>> + Send; - fn notify(&self, channel: &str, message: &str) - -> impl Future>; + fn notify( + &self, + channel: &str, + message: &T, + ) -> impl Future> + Send; } diff --git a/crates/socketioxide-postgres/src/drivers/sqlx.rs b/crates/socketioxide-postgres/src/drivers/sqlx.rs index 439f0ed3..c798fa0b 100644 --- a/crates/socketioxide-postgres/src/drivers/sqlx.rs +++ b/crates/socketioxide-postgres/src/drivers/sqlx.rs @@ -5,26 +5,30 @@ use std::{ }; use futures_core::Stream; -use serde::de::DeserializeOwned; +use serde::{Serialize, de::DeserializeOwned}; use sqlx::{ PgPool, postgres::{PgListener, PgNotification}, }; use tokio::sync::mpsc; +use crate::PostgresAdapterConfig; + use super::Driver; -type HandlerMap = HashMap>; +type HandlerMap = HashMap>; #[derive(Debug, Clone)] pub struct SqlxDriver { client: PgPool, handlers: Arc>, + config: PostgresAdapterConfig, } impl SqlxDriver { - pub fn new(client: PgPool) -> Self { + pub fn new(client: PgPool, config: PostgresAdapterConfig) -> Self { Self { client, handlers: Arc::new(RwLock::new(HashMap::new())), + config, } } } @@ -32,7 +36,7 @@ impl SqlxDriver { pin_project_lite::pin_project! 
{ pub struct NotifStream { #[pin] - rx: tokio::sync::mpsc::Receiver, + rx: mpsc::UnboundedReceiver, _phantom: std::marker::PhantomData T> } } @@ -50,7 +54,7 @@ impl Stream for NotifStream { } } impl NotifStream { - pub fn new(rx: mpsc::Receiver) -> Self { + pub fn new(rx: mpsc::UnboundedReceiver) -> Self { NotifStream { rx, _phantom: PhantomData::default(), @@ -76,18 +80,27 @@ impl Driver for SqlxDriver { &self, channel: &str, ) -> Result, Self::Error> { - let (tx, rx) = mpsc::channel(255); + let (tx, rx) = mpsc::unbounded_channel(); self.handlers.write().unwrap().insert(channel.into(), tx); Ok(NotifStream::new(rx)) } - async fn notify(&self, channel: &str, msg: &str) -> Result<(), Self::Error> { - sqlx::query("NOTIFY $1 $2") - .bind(channel) - .bind(msg) - .execute(&self.client) - .await?; - Ok(()) + fn notify( + &self, + channel: &str, + req: &T, + ) -> impl Future> + Send { + let client = self.client.clone(); + //TODO: handle error + let msg = serde_json::to_string(req).unwrap(); + async move { + sqlx::query("NOTIFY $1 $2") + .bind(channel) + .bind(msg) + .execute(&client) + .await?; + Ok(()) + } } } @@ -98,7 +111,7 @@ async fn spawn_listener(handlers: Arc>, mut listener: PgListe .inspect_err(|e| tracing::warn!(?e, "sqlx listener error")) { if let Some(tx) = handlers.read().unwrap().get(notif.channel()) { - tx.try_send(notif); + tx.send(notif); } else { tracing::warn!("handler not found for channel {}", notif.channel()); } diff --git a/crates/socketioxide-postgres/src/lib.rs b/crates/socketioxide-postgres/src/lib.rs index a3441158..91196003 100644 --- a/crates/socketioxide-postgres/src/lib.rs +++ b/crates/socketioxide-postgres/src/lib.rs @@ -32,15 +32,15 @@ //! 
use drivers::Driver; - use futures_core::Stream; -use serde::{Serialize, de::DeserializeOwned}; +use futures_util::StreamExt; +use serde::{Deserialize, Serialize, de::DeserializeOwned}; use socketioxide_core::{ Sid, Uid, adapter::{ BroadcastOptions, CoreAdapter, CoreLocalAdapter, DefinedAdapter, RemoteSocketData, Room, RoomParam, SocketEmitter, Spawnable, - errors::AdapterError, + errors::{AdapterError, BroadcastError}, remote_packet::{ RequestIn, RequestOut, RequestTypeIn, RequestTypeOut, Response, ResponseType, ResponseTypeId, @@ -50,7 +50,6 @@ use socketioxide_core::{ }; use std::{ borrow::Cow, - collections::HashMap, fmt, future, pin::Pin, sync::{Arc, Mutex}, @@ -60,6 +59,7 @@ use std::{ use tokio::sync::mpsc; mod drivers; +mod stream; /// The configuration of the [`MongoDbAdapter`]. #[derive(Debug, Clone)] @@ -103,11 +103,6 @@ pub enum Error { Decode(#[from] rmp_serde::decode::Error), } -impl Error { - fn from_driver(err: R::Error) -> Self { - Self::Driver(err) - } -} impl fmt::Debug for Error { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { match self { @@ -124,7 +119,9 @@ impl From> for AdapterError { } } -pub(crate) type ResponseHandlers = HashMap>; +/// An event we should answer to +#[derive(Debug, Deserialize)] +struct Event {} /// The postgres adapter implementation. /// It is generic over the [`Driver`] used to communicate with the postgres server. @@ -142,8 +139,6 @@ pub struct CustomPostgresAdapter { local: CoreLocalAdapter, /// A map of nodes liveness, with the last time remote nodes were seen alive. nodes_liveness: Mutex>, - /// A map of response handlers used to await for responses from the remote servers. 
- responses: Arc>, } impl DefinedAdapter for CustomPostgresAdapter {} @@ -161,13 +156,12 @@ impl CoreAdapter for CustomPostgresAdapter driver: state.driver.clone(), config: state.config.clone(), nodes_liveness: Mutex::new(Vec::new()), - responses: Arc::new(Mutex::new(HashMap::new())), } } fn init(self: Arc, on_success: impl FnOnce() + Send + 'static) -> Self::InitRes { let fut = async move { - let stream = self.driver.watch(self.uid, self.local.path()).await?; + let stream = self.driver.listen("event").await?; tokio::spawn(self.clone().handle_ev_stream(stream)); tokio::spawn(self.clone().heartbeat_job()); @@ -254,9 +248,9 @@ impl CoreAdapter for CustomPostgresAdapter let remote_serv_cnt = self.server_count().await?.saturating_sub(1); tracing::trace!(?remote_serv_cnt, "expecting acks from remote servers"); + let res = self.driver.listen("").await?; let (tx, rx) = mpsc::channel(self.config.ack_response_buffer + remote_serv_cnt as usize); - self.responses.lock().unwrap().insert(req_id, tx); self.send_req(req, None).await?; let (local, _) = self.local.broadcast_with_ack(packet, opts, timeout); @@ -375,56 +369,14 @@ impl CustomPostgresAdapter { } } - async fn handle_ev_stream( - self: Arc, - mut stream: impl Stream> + Unpin, - ) { - while let Some(item) = stream.next().await { - match item { - Ok(Item { - header: ItemHeader::Req { target, .. }, - data, - .. - }) if target.is_none_or(|id| id == self.uid) => { - tracing::debug!(?target, "request header"); - if let Err(e) = self.recv_req(data).await { - tracing::warn!("error receiving request from driver: {e}"); - } - } - Ok(Item { - header: ItemHeader::Req { target, .. }, - .. - }) => { - tracing::debug!( - ?target, - "receiving request which is not for us, skipping..." - ); - } - Ok( - item @ Item { - header: ItemHeader::Res { request, .. }, - .. 
- }, - ) => { - tracing::trace!(?request, "received response"); - let handlers = self.responses.lock().unwrap(); - if let Some(tx) = handlers.get(&request) { - if let Err(e) = tx.try_send(item) { - tracing::warn!("error sending response to handler: {e}"); - } - } else { - tracing::warn!(?request, ?handlers, "could not find req handler"); - } - } - Err(e) => { - tracing::warn!("error receiving event from driver: {e}"); - } - } + async fn handle_ev_stream(self: Arc, stream: impl Stream) { + futures_util::pin_mut!(stream); + while let Some(req) = stream.next().await { + self.recv_req(req); } } - async fn recv_req(self: &Arc, req: Vec) -> Result<(), Error> { - let req = rmp_serde::from_slice::(&req)?; + fn recv_req(self: &Arc, req: RequestIn) { tracing::trace!(?req, "incoming request"); match (req.r#type, req.opts) { (RequestTypeIn::Broadcast(p), Some(opts)) => self.recv_broadcast(opts, p), @@ -443,7 +395,6 @@ impl CustomPostgresAdapter { } _ => (), } - Ok(()) } fn recv_broadcast(&self, opts: BroadcastOptions, packet: Packet) { @@ -586,31 +537,29 @@ impl CustomPostgresAdapter { /// Send a request to a specific target node or broadcast it to all nodes if no target is specified. async fn send_req(&self, req: RequestOut<'_>, target: Option) -> Result<(), Error> { tracing::trace!(?req, "sending request"); - let head = ItemHeader::Req { target }; - let req = self.new_packet(head, &req)?; - self.driver.emit(&req).await.map_err(Error::from_driver)?; + // let head = ItemHeader::Req { target }; + // let req = self.new_packet(head, &req)?; + self.driver + .notify("yolo", &req) + .await + .map_err(Error::Driver)?; Ok(()) } /// Send a response to the node that sent the request. 
- fn send_res( + async fn send_res( &self, req_id: Sid, req_origin: Uid, res: Response, - ) -> impl Future>> + Send + 'static { + ) -> Result<(), Error> { tracing::trace!(?res, "sending response for {req_id} req to {req_origin}"); - let driver = self.driver.clone(); - let head = ItemHeader::Res { - request: req_id, - target: req_origin, - }; - let res = self.new_packet(head, &res); - async move { - driver.emit(&res?).await.map_err(Error::from_driver)?; - Ok(()) - } + self.driver + .notify("response", &res) + .await + .map_err(Error::Driver)?; + Ok(()) } /// Await for all the responses from the remote servers. @@ -666,10 +615,6 @@ impl CustomPostgresAdapter { ) .await } - fn new_packet(&self, head: ItemHeader, data: &impl Serialize) -> Result> { - let ns = &self.local.path(); - let uid = self.uid; - } } /// The result of the init future. diff --git a/crates/socketioxide-postgres/src/stream.rs b/crates/socketioxide-postgres/src/stream.rs new file mode 100644 index 00000000..e27960e5 --- /dev/null +++ b/crates/socketioxide-postgres/src/stream.rs @@ -0,0 +1,265 @@ +use std::{ + fmt, + pin::Pin, + sync::{Arc, Mutex}, + task::{self, Poll}, + time::Duration, +}; + +use futures_core::{FusedStream, Stream}; +use futures_util::{StreamExt, stream::TakeUntil}; +use pin_project_lite::pin_project; +use serde::de::DeserializeOwned; +use socketioxide_core::{ + Sid, + adapter::AckStreamItem, + adapter::remote_packet::{Response, ResponseType}, +}; +use tokio::{sync::mpsc, time}; + +pin_project! { + /// A stream of acknowledgement messages received from the local and remote servers. + /// It merges the local ack stream with the remote ack stream from all the servers. + // The server_cnt is the number of servers that are expected to send a AckCount message. + // It is decremented each time a AckCount message is received. + // + // The ack_cnt is the number of acks that are expected to be received. It is the sum of all the the ack counts. 
+ // And it is decremented each time an ack is received. + // + // Therefore an exhausted stream correspond to `ack_cnt == 0` and `server_cnt == 0`. + pub struct AckStream { + #[pin] + local: S, + #[pin] + remote: DropStream>, + ack_cnt: u32, + total_ack_cnt: usize, + serv_cnt: u16, + } +} + +impl AckStream { + pub fn new( + local: S, + rx: mpsc::Receiver, + timeout: Duration, + serv_cnt: u16, + req_id: Sid, + ) -> Self { + let remote = ChanStream::new(rx).take_until(time::sleep(timeout)); + let remote = DropStream::new(remote, handlers, req_id); + Self { + local, + ack_cnt: 0, + total_ack_cnt: 0, + serv_cnt, + } + } + pub fn new_local(local: S) -> Self { + let handlers = Arc::new(Mutex::new(ResponseHandlers::new())); + let rx = mpsc::channel(1).1; + let remote = ChanStream::new(rx).take_until(time::sleep(Duration::ZERO)); + let remote = DropStream::new(remote, handlers, Sid::ZERO); + Self { + local, + remote, + ack_cnt: 0, + total_ack_cnt: 0, + serv_cnt: 0, + } + } +} +impl AckStream +where + Err: DeserializeOwned + fmt::Debug, + S: Stream> + FusedStream, +{ + /// Poll the remote stream. First the count of acks is received, then the acks are received. + /// We expect `serv_cnt` of `BroadcastAckCount` messages to be received, then we expect + /// `ack_cnt` of `BroadcastAck` messages. + fn poll_remote( + self: Pin<&mut Self>, + cx: &mut task::Context<'_>, + ) -> Poll>> { + // remote stream is not fused, so we need to check if it is terminated + if FusedStream::is_terminated(&self) { + return Poll::Ready(None); + } + let mut projection = self.project(); + loop { + match projection.remote.as_mut().poll_next(cx) { + Poll::Pending => return Poll::Pending, + Poll::Ready(None) => return Poll::Ready(None), + Poll::Ready(Some(Item { header, data, .. 
})) => { + let res = rmp_serde::from_slice::>(&data); + match res { + Ok(Response { + node_id: uid, + r#type: ResponseType::BroadcastAckCount(count), + }) if *projection.serv_cnt > 0 => { + tracing::trace!(?uid, ?header, "receiving broadcast ack count {count}"); + *projection.ack_cnt += count; + *projection.total_ack_cnt += count as usize; + *projection.serv_cnt -= 1; + } + Ok(Response { + node_id: uid, + r#type: ResponseType::BroadcastAck((sid, res)), + }) if *projection.ack_cnt > 0 => { + tracing::trace!( + ?uid, + ?header, + "receiving broadcast ack {sid} {:?}", + res + ); + *projection.ack_cnt -= 1; + return Poll::Ready(Some((sid, res))); + } + Ok(Response { node_id: uid, .. }) => { + tracing::warn!(?uid, ?header, "unexpected response type"); + } + Err(e) => { + tracing::warn!("error decoding ack response: {e}"); + } + } + } + } + } + } +} +impl Stream for AckStream +where + E: DeserializeOwned + fmt::Debug, + S: Stream> + FusedStream, +{ + type Item = AckStreamItem; + fn poll_next(mut self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll> { + match self.as_mut().project().local.poll_next(cx) { + Poll::Pending => match self.poll_remote(cx) { + Poll::Pending => Poll::Pending, + Poll::Ready(Some(item)) => Poll::Ready(Some(item)), + Poll::Ready(None) => Poll::Pending, + }, + Poll::Ready(Some(item)) => Poll::Ready(Some(item)), + Poll::Ready(None) => self.poll_remote(cx), + } + } + + fn size_hint(&self) -> (usize, Option) { + let (lower, upper) = self.local.size_hint(); + (lower, upper.map(|upper| upper + self.total_ack_cnt)) + } +} + +impl FusedStream for AckStream +where + Err: DeserializeOwned + fmt::Debug, + S: Stream> + FusedStream, +{ + /// The stream is terminated if: + /// * The local stream is terminated. + /// * All the servers have sent the expected ack count. + /// * We have received all the expected acks. 
+ fn is_terminated(&self) -> bool { + // remote stream is terminated if the timeout is reached + let remote_term = (self.ack_cnt == 0 && self.serv_cnt == 0) || self.remote.is_terminated(); + self.local.is_terminated() && remote_term + } +} +impl fmt::Debug for AckStream { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("AckStream") + .field("ack_cnt", &self.ack_cnt) + .field("total_ack_cnt", &self.total_ack_cnt) + .field("serv_cnt", &self.serv_cnt) + .finish() + } +} + +pin_project! { + /// A stream of messages received from a channel. + pub struct ChanStream { + #[pin] + rx: mpsc::Receiver + } +} +impl ChanStream { + pub fn new(rx: mpsc::Receiver) -> Self { + Self { rx } + } +} +impl Stream for ChanStream { + type Item = Item; + + fn poll_next(self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll> { + self.project().rx.poll_recv(cx) + } +} +pin_project! { + /// A stream that unsubscribes from its source channel when dropped. + pub struct DropStream { + #[pin] + stream: S, + req_id: Sid, + handlers: Arc> + } + impl PinnedDrop for DropStream { + fn drop(this: Pin<&mut Self>) { + let stream = this.project(); + let chan = stream.req_id; + tracing::debug!(?chan, "dropping stream"); + stream.handlers.lock().unwrap().remove(chan); + } + } +} +impl DropStream { + pub fn new(stream: S, handlers: Arc>, req_id: Sid) -> Self { + Self { + stream, + handlers, + req_id, + } + } +} +impl Stream for DropStream { + type Item = S::Item; + + fn poll_next(self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll> { + self.project().stream.poll_next(cx) + } +} +impl FusedStream for DropStream { + fn is_terminated(&self) -> bool { + self.stream.is_terminated() + } +} + +#[cfg(test)] +mod tests { + use futures_core::FusedStream; + use futures_util::StreamExt; + use socketioxide_core::{Sid, Value}; + + use super::AckStream; + + #[tokio::test] + async fn local_ack_stream_should_have_a_closed_remote() { + let sid = Sid::new(); + let local = 
futures_util::stream::once(async move { + (sid, Ok::<_, ()>(Value::Str("local".into(), None))) + }); + let stream = AckStream::new_local(local); + futures_util::pin_mut!(stream); + assert_eq!(stream.ack_cnt, 0); + assert_eq!(stream.total_ack_cnt, 0); + assert_eq!(stream.serv_cnt, 0); + assert!(!stream.local.is_terminated()); + assert!(!stream.is_terminated()); + let data = stream.next().await; + assert!( + matches!(data, Some((id, Ok(Value::Str(msg, None)))) if id == sid && msg == "local") + ); + assert_eq!(stream.next().await, None); + assert!(stream.is_terminated()); + } +} From 818a3066a1bfea961d47fecc2f3d9f85b9ec9aca Mon Sep 17 00:00:00 2001 From: totodore Date: Sun, 22 Feb 2026 22:00:53 +0100 Subject: [PATCH 04/31] wip --- .../socketioxide-postgres/src/drivers/mod.rs | 38 +++++- .../src/drivers/postgres.rs | 15 ++- .../socketioxide-postgres/src/drivers/sqlx.rs | 57 ++++----- crates/socketioxide-postgres/src/lib.rs | 66 ++++++----- crates/socketioxide-postgres/src/stream.rs | 108 ++++-------------- 5 files changed, 133 insertions(+), 151 deletions(-) diff --git a/crates/socketioxide-postgres/src/drivers/mod.rs b/crates/socketioxide-postgres/src/drivers/mod.rs index d038e341..a0e92ed9 100644 --- a/crates/socketioxide-postgres/src/drivers/mod.rs +++ b/crates/socketioxide-postgres/src/drivers/mod.rs @@ -1,7 +1,8 @@ use futures_core::Stream; use serde::{Deserialize, Serialize, de::DeserializeOwned}; +use tokio::sync::mpsc; -mod postgres; +// mod postgres; mod sqlx; pub type ChanItem = (String, String); @@ -13,7 +14,8 @@ pub struct Item {} /// It must share handlers/connection between its clones. 
pub trait Driver: Clone + Send + Sync + 'static { type Error: std::error::Error + Send + 'static; - type NotifStream: Stream + Send + 'static; + type NotifStream: Stream + Send + 'static; + type Notification: Notification; fn init(&self, table: &str, channels: &[&str]) -> impl Future>; @@ -21,7 +23,7 @@ pub trait Driver: Clone + Send + Sync + 'static { fn listen( &self, channel: &str, - ) -> impl Future, Self::Error>> + Send; + ) -> impl Future> + Send; fn notify( &self, @@ -29,3 +31,33 @@ pub trait Driver: Clone + Send + Sync + 'static { message: &T, ) -> impl Future> + Send; } + +pub trait Notification: Send + 'static { + fn channel(&self) -> &str; + fn payload(&self) -> &str; +} + +pin_project_lite::pin_project! { + pub struct NotifStream { + #[pin] + rx: mpsc::UnboundedReceiver, + } +} +impl Stream for NotifStream { + type Item = T; + + fn poll_next( + mut self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> std::task::Poll> { + match self.rx.poll_recv(cx) { + std::task::Poll::Ready(notif) => std::task::Poll::Ready(notif), + std::task::Poll::Pending => std::task::Poll::Pending, + } + } +} +impl NotifStream { + pub fn new(rx: mpsc::UnboundedReceiver) -> Self { + NotifStream { rx } + } +} diff --git a/crates/socketioxide-postgres/src/drivers/postgres.rs b/crates/socketioxide-postgres/src/drivers/postgres.rs index 86dda23f..0d6d408a 100644 --- a/crates/socketioxide-postgres/src/drivers/postgres.rs +++ b/crates/socketioxide-postgres/src/drivers/postgres.rs @@ -19,15 +19,28 @@ impl PostgresDriver { impl Driver for PostgresDriver { type Error = tokio_postgres::Error; + type NotifStream; async fn init(&self, table: &str, channels: &[&str]) -> Result<(), Self::Error> { self.client .execute("CREATE TABLE $1 IF NOT EXISTS", &[&table]) .await?; + Ok(()) } - async fn notify(&self, channel: &str, msg: &str) -> Result<(), Self::Error> { + fn listen( + &self, + channel: &str, + ) -> impl Future, Self::Error>> + Send { + todo!() + } + + fn notify( + &self, 
+ channel: &str, + message: &T, + ) -> impl Future> + Send { todo!() } } diff --git a/crates/socketioxide-postgres/src/drivers/sqlx.rs b/crates/socketioxide-postgres/src/drivers/sqlx.rs index c798fa0b..9a0c6d86 100644 --- a/crates/socketioxide-postgres/src/drivers/sqlx.rs +++ b/crates/socketioxide-postgres/src/drivers/sqlx.rs @@ -1,10 +1,8 @@ use std::{ collections::HashMap, - marker::PhantomData, sync::{Arc, RwLock}, }; -use futures_core::Stream; use serde::{Serialize, de::DeserializeOwned}; use sqlx::{ PgPool, @@ -12,9 +10,10 @@ use sqlx::{ }; use tokio::sync::mpsc; -use crate::PostgresAdapterConfig; +use crate::{PostgresAdapterConfig, drivers::NotifStream}; use super::Driver; + type HandlerMap = HashMap>; #[derive(Debug, Clone)] @@ -23,6 +22,7 @@ pub struct SqlxDriver { handlers: Arc>, config: PostgresAdapterConfig, } + impl SqlxDriver { pub fn new(client: PgPool, config: PostgresAdapterConfig) -> Self { Self { @@ -33,43 +33,17 @@ impl SqlxDriver { } } -pin_project_lite::pin_project! 
{ - pub struct NotifStream { - #[pin] - rx: mpsc::UnboundedReceiver, - _phantom: std::marker::PhantomData T> - } -} -impl Stream for NotifStream { - type Item = T; - - fn poll_next( - mut self: std::pin::Pin<&mut Self>, - cx: &mut std::task::Context<'_>, - ) -> std::task::Poll> { - match self.rx.poll_recv(cx) { - std::task::Poll::Ready(_) => todo!(), - std::task::Poll::Pending => todo!(), - } - } -} -impl NotifStream { - pub fn new(rx: mpsc::UnboundedReceiver) -> Self { - NotifStream { - rx, - _phantom: PhantomData::default(), - } - } -} - impl Driver for SqlxDriver { type Error = sqlx::Error; - type NotifStream = NotifStream; + type NotifStream = NotifStream; + type Notification = PgNotification; + async fn init(&self, table: &str, channels: &[&str]) -> Result<(), Self::Error> { sqlx::query("CREATE TABLE $1 IF NOT EXISTS") .bind(&table) .execute(&self.client) .await?; + let mut listener = PgListener::connect_with(&self.client).await?; listener.listen_all(channels.iter().copied()).await?; tokio::spawn(spawn_listener(self.handlers.clone(), listener)); @@ -79,9 +53,12 @@ impl Driver for SqlxDriver { async fn listen( &self, channel: &str, - ) -> Result, Self::Error> { + ) -> Result { let (tx, rx) = mpsc::unbounded_channel(); - self.handlers.write().unwrap().insert(channel.into(), tx); + self.handlers + .write() + .unwrap() + .insert(channel.to_string(), tx); Ok(NotifStream::new(rx)) } @@ -117,3 +94,13 @@ async fn spawn_listener(handlers: Arc>, mut listener: PgListe } } } + +impl super::Notification for PgNotification { + fn channel(&self) -> &str { + PgNotification::channel(self) + } + + fn payload(&self) -> &str { + PgNotification::payload(self) + } +} diff --git a/crates/socketioxide-postgres/src/lib.rs b/crates/socketioxide-postgres/src/lib.rs index 91196003..92e18675 100644 --- a/crates/socketioxide-postgres/src/lib.rs +++ b/crates/socketioxide-postgres/src/lib.rs @@ -2,7 +2,7 @@ #![warn( clippy::all, clippy::todo, - clippy::empty_enum, + clippy::empty_enums, 
clippy::mem_forget, clippy::unused_self, clippy::filter_map_next, @@ -11,7 +11,7 @@ clippy::match_wildcard_for_single_variants, clippy::if_let_mutex, clippy::await_holding_lock, - clippy::match_on_vec_items, + clippy::indexing_slicing, clippy::imprecise_flops, clippy::suboptimal_flops, clippy::lossy_float_literal, @@ -50,13 +50,15 @@ use socketioxide_core::{ }; use std::{ borrow::Cow, + collections::HashMap, fmt, future, pin::Pin, sync::{Arc, Mutex}, task::{Context, Poll}, time::{Duration, Instant}, }; -use tokio::sync::mpsc; + +use crate::{drivers::Notification, stream::AckStream}; mod drivers; mod stream; @@ -123,11 +125,16 @@ impl From> for AdapterError { #[derive(Debug, Deserialize)] struct Event {} +pub struct PostgresAdapterCtr { + driver: D, + config: PostgresAdapterConfig, +} + /// The postgres adapter implementation. /// It is generic over the [`Driver`] used to communicate with the postgres server. /// And over the [`SocketEmitter`] used to communicate with the local server. This allows to /// avoid cyclic dependencies between the adapter, `socketioxide-core` and `socketioxide` crates. -pub struct CustomPostgresAdapter { +pub struct CustomPostgresAdapter { /// The driver used by the adapter. This is used to communicate with the postgres server. /// All the postgres adapter instances share the same driver. driver: D, @@ -139,13 +146,15 @@ pub struct CustomPostgresAdapter { local: CoreLocalAdapter, /// A map of nodes liveness, with the last time remote nodes were seen alive. nodes_liveness: Mutex>, + /// A map of response handlers used to await for responses from the remote servers. 
+ responses: Arc>>, } -impl DefinedAdapter for CustomPostgresAdapter {} +impl DefinedAdapter for CustomPostgresAdapter {} impl CoreAdapter for CustomPostgresAdapter { type Error = Error; type State = PostgresAdapterCtr; - type AckStream = AckStream; + type AckStream = AckStream; type InitRes = InitRes; fn new(state: &Self::State, local: CoreLocalAdapter) -> Self { @@ -156,6 +165,7 @@ impl CoreAdapter for CustomPostgresAdapter driver: state.driver.clone(), config: state.config.clone(), nodes_liveness: Mutex::new(Vec::new()), + responses: Arc::new(Mutex::new(HashMap::new())), } } @@ -248,19 +258,16 @@ impl CoreAdapter for CustomPostgresAdapter let remote_serv_cnt = self.server_count().await?.saturating_sub(1); tracing::trace!(?remote_serv_cnt, "expecting acks from remote servers"); - let res = self.driver.listen("").await?; + let remote = self.driver.listen("").await?; - let (tx, rx) = mpsc::channel(self.config.ack_response_buffer + remote_serv_cnt as usize); self.send_req(req, None).await?; let (local, _) = self.local.broadcast_with_ack(packet, opts, timeout); Ok(AckStream::new( local, - rx, + remote, self.config.request_timeout, remote_serv_cnt, - req_id, - self.responses.clone(), )) } @@ -283,7 +290,7 @@ impl CoreAdapter for CustomPostgresAdapter let req = RequestOut::new(self.uid, RequestTypeOut::AllRooms, &opts); let req_id = req.id; - // First get the remote stream because mongodb might send + // First get the remote stream because postgres might send // the responses before subscription is done. let stream = self .get_res::<()>(req_id, ResponseTypeId::AllRooms, opts.server_id) @@ -547,19 +554,22 @@ impl CustomPostgresAdapter { } /// Send a response to the node that sent the request. 
- async fn send_res( + fn send_res( &self, req_id: Sid, req_origin: Uid, res: Response, - ) -> Result<(), Error> { + ) -> impl Future>> + 'static { tracing::trace!(?res, "sending response for {req_id} req to {req_origin}"); - - self.driver - .notify("response", &res) - .await - .map_err(Error::Driver)?; - Ok(()) + let driver = self.driver.clone(); + //TODO: is this the right way? + async move { + driver + .notify("response", &res) + .await + .map_err(Error::Driver)?; + Ok(()) + } } /// Await for all the responses from the remote servers. @@ -576,14 +586,16 @@ impl CustomPostgresAdapter { } else { 1 }; - let (tx, rx) = mpsc::channel(std::cmp::max(remote_serv_cnt, 1)); - self.responses.lock().unwrap().insert(req_id, tx); - let stream = ChanStream::new(rx) - .filter_map(|Item { header, data, .. }| { - let data = match rmp_serde::from_slice::>(&data) { + + let stream = self.driver.listen("test").await.unwrap(); + self.responses.lock().unwrap().insert(req_id, stream); + + let stream = stream + .filter_map(|notif| { + let data = match serde_json::from_str::>(notif.payload()) { Ok(data) => Some(data), Err(e) => { - tracing::warn!(header = ?header, "error decoding response: {e}"); + tracing::warn!(channel = %notif.channel(), "error decoding response: {e}"); None } }; @@ -592,13 +604,13 @@ impl CustomPostgresAdapter { .filter(move |item| future::ready(ResponseTypeId::from(&item.r#type) == response_type)) .take(remote_serv_cnt) .take_until(tokio::time::sleep(self.config.request_timeout)); + let stream = DropStream::new(stream, self.responses.clone(), req_id); Ok(stream) } /// Emit a heartbeat to the specified target node or broadcast to all nodes. async fn emit_heartbeat(&self, target: Option) -> Result<(), Error> { - // Send heartbeat when starting. 
self.send_req( RequestOut::new_empty(self.uid, RequestTypeOut::Heartbeat), target, diff --git a/crates/socketioxide-postgres/src/stream.rs b/crates/socketioxide-postgres/src/stream.rs index e27960e5..5e7a945f 100644 --- a/crates/socketioxide-postgres/src/stream.rs +++ b/crates/socketioxide-postgres/src/stream.rs @@ -1,7 +1,6 @@ use std::{ fmt, pin::Pin, - sync::{Arc, Mutex}, task::{self, Poll}, time::Duration, }; @@ -11,12 +10,13 @@ use futures_util::{StreamExt, stream::TakeUntil}; use pin_project_lite::pin_project; use serde::de::DeserializeOwned; use socketioxide_core::{ - Sid, adapter::AckStreamItem, adapter::remote_packet::{Response, ResponseType}, }; use tokio::{sync::mpsc, time}; +use crate::drivers::{NotifStream, Notification}; + pin_project! { /// A stream of acknowledgement messages received from the local and remote servers. /// It merges the local ack stream with the remote ack stream from all the servers. @@ -27,39 +27,32 @@ pin_project! { // And it is decremented each time an ack is received. // // Therefore an exhausted stream correspond to `ack_cnt == 0` and `server_cnt == 0`. 
- pub struct AckStream { + pub struct AckStream { #[pin] local: S, #[pin] - remote: DropStream>, + remote: TakeUntil, time::Sleep>, ack_cnt: u32, total_ack_cnt: usize, serv_cnt: u16, } } -impl AckStream { - pub fn new( - local: S, - rx: mpsc::Receiver, - timeout: Duration, - serv_cnt: u16, - req_id: Sid, - ) -> Self { - let remote = ChanStream::new(rx).take_until(time::sleep(timeout)); - let remote = DropStream::new(remote, handlers, req_id); +impl AckStream { + pub fn new(local: S, remote: NotifStream, timeout: Duration, serv_cnt: u16) -> Self { + let remote = remote.take_until(time::sleep(timeout)); Self { local, ack_cnt: 0, total_ack_cnt: 0, serv_cnt, + remote, } } + pub fn new_local(local: S) -> Self { - let handlers = Arc::new(Mutex::new(ResponseHandlers::new())); - let rx = mpsc::channel(1).1; - let remote = ChanStream::new(rx).take_until(time::sleep(Duration::ZERO)); - let remote = DropStream::new(remote, handlers, Sid::ZERO); + let rx = mpsc::unbounded_channel().1; + let remote = NotifStream::new(rx).take_until(time::sleep(Duration::ZERO)); Self { local, remote, @@ -69,7 +62,7 @@ impl AckStream { } } } -impl AckStream +impl AckStream where Err: DeserializeOwned + fmt::Debug, S: Stream> + FusedStream, @@ -90,14 +83,15 @@ where match projection.remote.as_mut().poll_next(cx) { Poll::Pending => return Poll::Pending, Poll::Ready(None) => return Poll::Ready(None), - Poll::Ready(Some(Item { header, data, .. 
})) => { - let res = rmp_serde::from_slice::>(&data); + Poll::Ready(Some(notif)) => { + let channel = notif.channel(); + let res = serde_json::from_str::>(notif.payload()); match res { Ok(Response { node_id: uid, r#type: ResponseType::BroadcastAckCount(count), }) if *projection.serv_cnt > 0 => { - tracing::trace!(?uid, ?header, "receiving broadcast ack count {count}"); + tracing::trace!(?uid, channel, "receiving broadcast ack count {count}"); *projection.ack_cnt += count; *projection.total_ack_cnt += count as usize; *projection.serv_cnt -= 1; @@ -108,7 +102,7 @@ where }) if *projection.ack_cnt > 0 => { tracing::trace!( ?uid, - ?header, + channel, "receiving broadcast ack {sid} {:?}", res ); @@ -116,7 +110,7 @@ where return Poll::Ready(Some((sid, res))); } Ok(Response { node_id: uid, .. }) => { - tracing::warn!(?uid, ?header, "unexpected response type"); + tracing::warn!(?uid, channel, "unexpected response type"); } Err(e) => { tracing::warn!("error decoding ack response: {e}"); @@ -127,10 +121,11 @@ where } } } -impl Stream for AckStream +impl Stream for AckStream where E: DeserializeOwned + fmt::Debug, S: Stream> + FusedStream, + T: Notification, { type Item = AckStreamItem; fn poll_next(mut self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll> { @@ -151,10 +146,11 @@ where } } -impl FusedStream for AckStream +impl FusedStream for AckStream where Err: DeserializeOwned + fmt::Debug, S: Stream> + FusedStream, + T: Notification, { /// The stream is terminated if: /// * The local stream is terminated. @@ -166,7 +162,7 @@ where self.local.is_terminated() && remote_term } } -impl fmt::Debug for AckStream { +impl fmt::Debug for AckStream { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { f.debug_struct("AckStream") .field("ack_cnt", &self.ack_cnt) @@ -176,64 +172,6 @@ impl fmt::Debug for AckStream { } } -pin_project! { - /// A stream of messages received from a channel. 
- pub struct ChanStream { - #[pin] - rx: mpsc::Receiver - } -} -impl ChanStream { - pub fn new(rx: mpsc::Receiver) -> Self { - Self { rx } - } -} -impl Stream for ChanStream { - type Item = Item; - - fn poll_next(self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll> { - self.project().rx.poll_recv(cx) - } -} -pin_project! { - /// A stream that unsubscribes from its source channel when dropped. - pub struct DropStream { - #[pin] - stream: S, - req_id: Sid, - handlers: Arc> - } - impl PinnedDrop for DropStream { - fn drop(this: Pin<&mut Self>) { - let stream = this.project(); - let chan = stream.req_id; - tracing::debug!(?chan, "dropping stream"); - stream.handlers.lock().unwrap().remove(chan); - } - } -} -impl DropStream { - pub fn new(stream: S, handlers: Arc>, req_id: Sid) -> Self { - Self { - stream, - handlers, - req_id, - } - } -} -impl Stream for DropStream { - type Item = S::Item; - - fn poll_next(self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll> { - self.project().stream.poll_next(cx) - } -} -impl FusedStream for DropStream { - fn is_terminated(&self) -> bool { - self.stream.is_terminated() - } -} - #[cfg(test)] mod tests { use futures_core::FusedStream; From 001740372a61fdba6fe273e0103aa02aa7677263 Mon Sep 17 00:00:00 2001 From: totodore Date: Sun, 29 Mar 2026 14:33:02 +0200 Subject: [PATCH 05/31] feat(adapter/postgre): wip --- .../workflows/adapter-ci/docker-compose.yml | 14 + Cargo.lock | 4 + crates/socketioxide-postgres/Cargo.toml | 3 +- .../socketioxide-postgres/src/drivers/mod.rs | 47 +-- .../socketioxide-postgres/src/drivers/sqlx.rs | 84 +++--- crates/socketioxide-postgres/src/lib.rs | 271 +++++++++++++----- crates/socketioxide-postgres/src/stream.rs | 123 ++++++-- e2e/adapter/Cargo.toml | 13 + e2e/adapter/main.rs | 40 ++- e2e/adapter/src/bins/sqlx.rs | 64 +++++ e2e/adapter/src/bins/sqlx_msgpack.rs | 65 +++++ 11 files changed, 518 insertions(+), 210 deletions(-) create mode 100644 e2e/adapter/src/bins/sqlx.rs create mode 100644 
e2e/adapter/src/bins/sqlx_msgpack.rs diff --git a/.github/workflows/adapter-ci/docker-compose.yml b/.github/workflows/adapter-ci/docker-compose.yml index 3965afd3..7c0403d2 100644 --- a/.github/workflows/adapter-ci/docker-compose.yml +++ b/.github/workflows/adapter-ci/docker-compose.yml @@ -140,3 +140,17 @@ services: '; wait " + + postgres: + image: postgres:18-alpine + ports: + - 5432:5432 + environment: + POSTGRES_DB: socketio + POSTGRES_PASSWORD: socketio + POSTGRES_USER: socketio + healthcheck: + test: "pg_isready -U socketio" + interval: 2s + timeout: 5s + retries: 5 diff --git a/Cargo.lock b/Cargo.lock index 7cbc7442..597e0aa6 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -11,6 +11,7 @@ dependencies = [ "hyper-util", "socketioxide", "socketioxide-mongodb", + "socketioxide-postgres", "socketioxide-redis", "tokio", "tracing", @@ -2665,6 +2666,8 @@ dependencies = [ "sha2", "smallvec", "thiserror 2.0.17", + "tokio", + "tokio-stream", "tracing", "url", ] @@ -2701,6 +2704,7 @@ dependencies = [ "sqlx-core", "sqlx-postgres", "syn", + "tokio", "url", ] diff --git a/crates/socketioxide-postgres/Cargo.toml b/crates/socketioxide-postgres/Cargo.toml index 537fab4b..4029a391 100644 --- a/crates/socketioxide-postgres/Cargo.toml +++ b/crates/socketioxide-postgres/Cargo.toml @@ -25,7 +25,7 @@ futures-core.workspace = true futures-util.workspace = true pin-project-lite.workspace = true serde.workspace = true -serde_json.workspace = true +serde_json = { workspace = true, features = ["raw_value"] } smallvec = { workspace = true, features = ["serde"] } tokio = { workspace = true, features = ["time", "rt", "sync"] } rmp-serde.workspace = true @@ -38,6 +38,7 @@ tokio-postgres = { version = "0.7", default-features = false, optional = true, f ] } sqlx = { version = "0.8", default-features = false, optional = true, features = [ "postgres", + "runtime-tokio", ] } [dev-dependencies] diff --git a/crates/socketioxide-postgres/src/drivers/mod.rs 
b/crates/socketioxide-postgres/src/drivers/mod.rs index a0e92ed9..94de7f17 100644 --- a/crates/socketioxide-postgres/src/drivers/mod.rs +++ b/crates/socketioxide-postgres/src/drivers/mod.rs @@ -1,29 +1,21 @@ use futures_core::Stream; -use serde::{Deserialize, Serialize, de::DeserializeOwned}; -use tokio::sync::mpsc; +use serde::Serialize; -// mod postgres; -mod sqlx; - -pub type ChanItem = (String, String); - -#[derive(Deserialize)] -pub struct Item {} +pub mod sqlx; /// The driver trait can be used to support different LISTEN/NOTIFY backends. /// It must share handlers/connection between its clones. pub trait Driver: Clone + Send + Sync + 'static { type Error: std::error::Error + Send + 'static; - type NotifStream: Stream + Send + 'static; type Notification: Notification; + type NotificationStream: Stream + Send; - fn init(&self, table: &str, channels: &[&str]) - -> impl Future>; + fn init(&self, table: &str) -> impl Future> + Send; - fn listen( + fn listen( &self, - channel: &str, - ) -> impl Future> + Send; + channels: &[&str], + ) -> impl Future> + Send; fn notify( &self, @@ -36,28 +28,3 @@ pub trait Notification: Send + 'static { fn channel(&self) -> &str; fn payload(&self) -> &str; } - -pin_project_lite::pin_project! 
{ - pub struct NotifStream { - #[pin] - rx: mpsc::UnboundedReceiver, - } -} -impl Stream for NotifStream { - type Item = T; - - fn poll_next( - mut self: std::pin::Pin<&mut Self>, - cx: &mut std::task::Context<'_>, - ) -> std::task::Poll> { - match self.rx.poll_recv(cx) { - std::task::Poll::Ready(notif) => std::task::Poll::Ready(notif), - std::task::Poll::Pending => std::task::Poll::Pending, - } - } -} -impl NotifStream { - pub fn new(rx: mpsc::UnboundedReceiver) -> Self { - NotifStream { rx } - } -} diff --git a/crates/socketioxide-postgres/src/drivers/sqlx.rs b/crates/socketioxide-postgres/src/drivers/sqlx.rs index 9a0c6d86..a47b5169 100644 --- a/crates/socketioxide-postgres/src/drivers/sqlx.rs +++ b/crates/socketioxide-postgres/src/drivers/sqlx.rs @@ -1,65 +1,59 @@ -use std::{ - collections::HashMap, - sync::{Arc, RwLock}, -}; - -use serde::{Serialize, de::DeserializeOwned}; +use futures_core::stream::BoxStream; +use futures_util::StreamExt; +use serde::Serialize; use sqlx::{ PgPool, postgres::{PgListener, PgNotification}, }; -use tokio::sync::mpsc; - -use crate::{PostgresAdapterConfig, drivers::NotifStream}; use super::Driver; -type HandlerMap = HashMap>; +pub use sqlx as sqlx_client; #[derive(Debug, Clone)] pub struct SqlxDriver { client: PgPool, - handlers: Arc>, - config: PostgresAdapterConfig, } impl SqlxDriver { - pub fn new(client: PgPool, config: PostgresAdapterConfig) -> Self { - Self { - client, - handlers: Arc::new(RwLock::new(HashMap::new())), - config, - } + /// Create a new SqlxDriver instance. 
+ pub fn new(client: PgPool) -> Self { + Self { client } } } impl Driver for SqlxDriver { type Error = sqlx::Error; - type NotifStream = NotifStream; type Notification = PgNotification; + type NotificationStream = BoxStream<'static, Self::Notification>; - async fn init(&self, table: &str, channels: &[&str]) -> Result<(), Self::Error> { - sqlx::query("CREATE TABLE $1 IF NOT EXISTS") - .bind(&table) - .execute(&self.client) - .await?; + async fn init(&self, table: &str) -> Result<(), Self::Error> { + sqlx::query(&format!( + r#"CREATE TABLE IF NOT EXISTS "{table}" ( + id BIGSERIAL UNIQUE, + created_at TIMESTAMPTZ DEFAULT NOW(), + payload BYTEA + )"#, + )) + .execute(&self.client) + .await?; + Ok(()) + } + + async fn listen(&self, channels: &[&str]) -> Result { let mut listener = PgListener::connect_with(&self.client).await?; listener.listen_all(channels.iter().copied()).await?; - tokio::spawn(spawn_listener(self.handlers.clone(), listener)); - Ok(()) - } - async fn listen( - &self, - channel: &str, - ) -> Result { - let (tx, rx) = mpsc::unbounded_channel(); - self.handlers - .write() - .unwrap() - .insert(channel.to_string(), tx); - Ok(NotifStream::new(rx)) + let stream = listener.into_stream(); + let stream = stream.filter_map(async |res| { + res.inspect_err(|err| { + tracing::warn!("failed to pull sqlx notification from stream: {err}") + }) + .ok() + }); + + Ok(Box::pin(stream)) } fn notify( @@ -71,7 +65,7 @@ impl Driver for SqlxDriver { //TODO: handle error let msg = serde_json::to_string(req).unwrap(); async move { - sqlx::query("NOTIFY $1 $2") + sqlx::query("SELECT pg_notify($1, $2)") .bind(channel) .bind(msg) .execute(&client) @@ -81,20 +75,6 @@ impl Driver for SqlxDriver { } } -async fn spawn_listener(handlers: Arc>, mut listener: PgListener) { - while let Ok(notif) = listener - .recv() - .await - .inspect_err(|e| tracing::warn!(?e, "sqlx listener error")) - { - if let Some(tx) = handlers.read().unwrap().get(notif.channel()) { - tx.send(notif); - } else { - 
tracing::warn!("handler not found for channel {}", notif.channel()); - } - } -} - impl super::Notification for PgNotification { fn channel(&self) -> &str { PgNotification::channel(self) diff --git a/crates/socketioxide-postgres/src/lib.rs b/crates/socketioxide-postgres/src/lib.rs index 92e18675..f4ee9a23 100644 --- a/crates/socketioxide-postgres/src/lib.rs +++ b/crates/socketioxide-postgres/src/lib.rs @@ -29,12 +29,13 @@ nonstandard_style, missing_docs )] -//! +//! test use drivers::Driver; use futures_core::Stream; -use futures_util::StreamExt; +use futures_util::{StreamExt, pin_mut}; use serde::{Deserialize, Serialize, de::DeserializeOwned}; +use serde_json::value::RawValue; use socketioxide_core::{ Sid, Uid, adapter::{ @@ -57,10 +58,14 @@ use std::{ task::{Context, Poll}, time::{Duration, Instant}, }; +use tokio::sync::mpsc; -use crate::{drivers::Notification, stream::AckStream}; +use crate::{ + drivers::Notification, + stream::{AckStream, ChanStream}, +}; -mod drivers; +pub mod drivers; mod stream; /// The configuration of the [`MongoDbAdapter`]. @@ -81,36 +86,49 @@ pub struct PostgresAdapterConfig { /// than you poll them with the returned stream, you might want to increase this value. pub ack_response_buffer: usize, /// The table name used to store socket.io attachments. Default is "socket_io_attachments". + /// + /// > The table name must be a sanitized string. Do not use special characters or spaces. pub table_name: Cow<'static, str>, /// The prefix used for the channels. Default is "socket.io". pub prefix: Cow<'static, str>, /// The treshold to the payload size in bytes. It should match the configured value on your PostgreSQL instance: - /// + /// . By default it is 8KB (8000 bytes). pub payload_treshold: usize, - /// The duration between cleanup queries on the + /// The duration between cleanup queries on the attachment table. 
pub cleanup_intervals: Duration, } +impl Default for PostgresAdapterConfig { + fn default() -> Self { + Self { + hb_timeout: Duration::from_secs(60), + hb_interval: Duration::from_secs(10), + request_timeout: Duration::from_secs(5), + ack_response_buffer: 255, + table_name: "socket_io_attachments".into(), + prefix: "socket.io".into(), + payload_treshold: 8_000, + cleanup_intervals: Duration::from_secs(60), + } + } +} + /// Represent any error that might happen when using this adapter. #[derive(thiserror::Error)] pub enum Error { /// Mongo driver error #[error("driver error: {0}")] Driver(D::Error), - /// Packet encoding error - #[error("packet encoding error: {0}")] - Encode(#[from] rmp_serde::encode::Error), - /// Packet decoding error + /// Packet encoding/decoding error #[error("packet decoding error: {0}")] - Decode(#[from] rmp_serde::decode::Error), + Serde(#[from] serde_json::Error), } impl fmt::Debug for Error { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { match self { Self::Driver(err) => write!(f, "Driver error: {:?}", err), - Self::Decode(err) => write!(f, "Decode error: {:?}", err), - Self::Encode(err) => write!(f, "Encode error: {:?}", err), + Self::Serde(err) => write!(f, "Encode/Decode error: {:?}", err), } } } @@ -121,15 +139,24 @@ impl From> for AdapterError { } } -/// An event we should answer to -#[derive(Debug, Deserialize)] -struct Event {} - +/// Constructor for the PostgresAdapterCtr struct. pub struct PostgresAdapterCtr { driver: D, config: PostgresAdapterConfig, } +impl PostgresAdapterCtr { + /// Create a new adapter constructor with a custom postgres driver and a config. + /// + /// You can implement your own driver by implementing the [`Driver`] trait with any postgres client. + /// Check the [`drivers`] module for more information. + pub fn new_with_driver(driver: D, config: PostgresAdapterConfig) -> Self { + Self { driver, config } + } +} + +type ResponseHandlers = HashMap>>; + /// The postgres adapter implementation. 
/// It is generic over the [`Driver`] used to communicate with the postgres server. /// And over the [`SocketEmitter`] used to communicate with the local server. This allows to @@ -140,28 +167,24 @@ pub struct CustomPostgresAdapter { driver: D, /// The configuration of the adapter. config: PostgresAdapterConfig, - /// A unique identifier for the adapter to identify itself in the postgres server. - uid: Uid, /// The local adapter, used to manage local rooms and socket stores. local: CoreLocalAdapter, /// A map of nodes liveness, with the last time remote nodes were seen alive. nodes_liveness: Mutex>, /// A map of response handlers used to await for responses from the remote servers. - responses: Arc>>, + responses: Arc>, } impl DefinedAdapter for CustomPostgresAdapter {} impl CoreAdapter for CustomPostgresAdapter { type Error = Error; type State = PostgresAdapterCtr; - type AckStream = AckStream; + type AckStream = AckStream; type InitRes = InitRes; fn new(state: &Self::State, local: CoreLocalAdapter) -> Self { - let uid = local.server_id(); Self { local, - uid, driver: state.driver.clone(), config: state.config.clone(), nodes_liveness: Mutex::new(Vec::new()), @@ -171,14 +194,26 @@ impl CoreAdapter for CustomPostgresAdapter fn init(self: Arc, on_success: impl FnOnce() + Send + 'static) -> Self::InitRes { let fut = async move { - let stream = self.driver.listen("event").await?; + self.driver.init(&self.config.table_name).await?; + + let global_chan = self.get_global_chan(); + let node_chan = self.get_node_chan(self.local.server_id()); + let response_chan = self.get_response_chan(self.local.server_id()); + + let channels = [ + global_chan.as_str(), + node_chan.as_str(), + response_chan.as_str(), + ]; + + let stream = self.driver.listen(&channels).await?; tokio::spawn(self.clone().handle_ev_stream(stream)); tokio::spawn(self.clone().heartbeat_job()); // Send initial heartbeat when starting. 
self.emit_init_heartbeat().await.map_err(|e| match e { Error::Driver(e) => e, - Error::Encode(_) | Error::Decode(_) => unreachable!(), + Error::Serde(_) => unreachable!(), })?; on_success(); @@ -205,8 +240,9 @@ impl CoreAdapter for CustomPostgresAdapter packet: Packet, opts: BroadcastOptions, ) -> Result<(), BroadcastError> { - if !opts.is_local(self.uid) { - let req = RequestOut::new(self.uid, RequestTypeOut::Broadcast(&packet), &opts); + let node_id = self.local.server_id(); + if !opts.is_local(node_id) { + let req = RequestOut::new(node_id, RequestTypeOut::Broadcast(&packet), &opts); self.send_req(req, None).await.map_err(AdapterError::from)?; } @@ -247,33 +283,45 @@ impl CoreAdapter for CustomPostgresAdapter opts: BroadcastOptions, timeout: Option, ) -> Result { - if opts.is_local(self.uid) { + if opts.is_local(self.local.server_id()) { tracing::debug!(?opts, "broadcast with ack is local"); let (local, _) = self.local.broadcast_with_ack(packet, opts, timeout); let stream = AckStream::new_local(local); return Ok(stream); } - let req = RequestOut::new(self.uid, RequestTypeOut::BroadcastWithAck(&packet), &opts); + let req = RequestOut::new( + self.local.server_id(), + RequestTypeOut::BroadcastWithAck(&packet), + &opts, + ); let req_id = req.id; let remote_serv_cnt = self.server_count().await?.saturating_sub(1); tracing::trace!(?remote_serv_cnt, "expecting acks from remote servers"); - let remote = self.driver.listen("").await?; + + let (tx, rx) = mpsc::channel(self.config.ack_response_buffer + remote_serv_cnt as usize); + self.responses.lock().unwrap().insert(req_id, tx); self.send_req(req, None).await?; let (local, _) = self.local.broadcast_with_ack(packet, opts, timeout); Ok(AckStream::new( local, - remote, + rx, self.config.request_timeout, remote_serv_cnt, + req_id, + self.responses.clone(), )) } async fn disconnect_socket(&self, opts: BroadcastOptions) -> Result<(), BroadcastError> { - if !opts.is_local(self.uid) { - let req = RequestOut::new(self.uid, 
RequestTypeOut::DisconnectSockets, &opts); + if !opts.is_local(self.local.server_id()) { + let req = RequestOut::new( + self.local.server_id(), + RequestTypeOut::DisconnectSockets, + &opts, + ); self.send_req(req, None).await.map_err(AdapterError::from)?; } self.local @@ -284,10 +332,10 @@ impl CoreAdapter for CustomPostgresAdapter } async fn rooms(&self, opts: BroadcastOptions) -> Result, Self::Error> { - if opts.is_local(self.uid) { + if opts.is_local(self.local.server_id()) { return Ok(self.local.rooms(opts).into_iter().collect()); } - let req = RequestOut::new(self.uid, RequestTypeOut::AllRooms, &opts); + let req = RequestOut::new(self.local.server_id(), RequestTypeOut::AllRooms, &opts); let req_id = req.id; // First get the remote stream because postgres might send @@ -313,8 +361,12 @@ impl CoreAdapter for CustomPostgresAdapter rooms: impl RoomParam, ) -> Result<(), Self::Error> { let rooms: Vec = rooms.into_room_iter().collect(); - if !opts.is_local(self.uid) { - let req = RequestOut::new(self.uid, RequestTypeOut::AddSockets(&rooms), &opts); + if !opts.is_local(self.local.server_id()) { + let req = RequestOut::new( + self.local.server_id(), + RequestTypeOut::AddSockets(&rooms), + &opts, + ); self.send_req(req, opts.server_id).await?; } self.local.add_sockets(opts, rooms); @@ -327,8 +379,12 @@ impl CoreAdapter for CustomPostgresAdapter rooms: impl RoomParam, ) -> Result<(), Self::Error> { let rooms: Vec = rooms.into_room_iter().collect(); - if !opts.is_local(self.uid) { - let req = RequestOut::new(self.uid, RequestTypeOut::DelSockets(&rooms), &opts); + if !opts.is_local(self.local.server_id()) { + let req = RequestOut::new( + self.local.server_id(), + RequestTypeOut::DelSockets(&rooms), + &opts, + ); self.send_req(req, opts.server_id).await?; } self.local.del_sockets(opts, rooms); @@ -339,11 +395,11 @@ impl CoreAdapter for CustomPostgresAdapter &self, opts: BroadcastOptions, ) -> Result, Self::Error> { - if opts.is_local(self.uid) { + if 
opts.is_local(self.local.server_id()) { return Ok(self.local.fetch_sockets(opts)); } - let req = RequestOut::new(self.uid, RequestTypeOut::FetchSockets, &opts); - // First get the remote stream because mongodb might send + let req = RequestOut::new(self.local.server_id(), RequestTypeOut::FetchSockets, &opts); + // First get the remote stream because postgres might send // the responses before subscription is done. let remote = self .get_res::(req.id, ResponseTypeId::FetchSockets, opts.server_id) @@ -376,10 +432,46 @@ impl CustomPostgresAdapter { } } - async fn handle_ev_stream(self: Arc, stream: impl Stream) { - futures_util::pin_mut!(stream); - while let Some(req) = stream.next().await { - self.recv_req(req); + async fn handle_ev_stream(self: Arc, stream: impl Stream) { + pin_mut!(stream); + while let Some(notif) = stream.next().await { + let chan = notif.channel(); + let resp_chan = self.get_response_chan(self.local.server_id()); + tracing::info!(chan, resp_chan, notif = notif.payload(), ""); + if chan == resp_chan { + match serde_json::from_str(notif.payload()) { + Ok(ResponsePacket { + req_id, + node_id, + payload, + }) if node_id != self.local.server_id() => { + let handlers = self.responses.lock().unwrap(); + if let Some(handler) = handlers.get(&req_id) { + if let Err(e) = handler.try_send(payload) { + tracing::warn!(channel = resp_chan, req_id = %req_id, "error sending response: {e}"); + } + } else { + tracing::warn!(channel = resp_chan, req_id = %req_id, "response handler not found"); + } + } + Ok(_) => { + tracing::trace!("skipping loopback packets"); + } + Err(e) => { + tracing::warn!(channel = %notif.channel(), "error handling response: {e}") + } + }; + } else { + match serde_json::from_str::(notif.payload()) { + Ok(req) if req.node_id != self.local.server_id() => self.recv_req(req), + Ok(_) => { + tracing::trace!("skipping loopback packets") + } + Err(e) => { + tracing::warn!(channel = %notif.channel(), "error decoding request: {e}") + } + }; + } } } @@ 
-408,7 +500,7 @@ impl CustomPostgresAdapter { tracing::trace!(?opts, "incoming broadcast"); if let Err(e) = self.local.broadcast(packet, opts) { let ns = self.local.path(); - tracing::warn!(?self.uid, ?ns, "remote request broadcast handler: {:?}", e); + tracing::warn!(node_id = %self.local.server_id(), ?ns, "remote request broadcast handler: {:?}", e); } } @@ -416,8 +508,8 @@ impl CustomPostgresAdapter { if let Err(e) = self.local.disconnect_socket(opts) { let ns = self.local.path(); tracing::warn!( - ?self.uid, - ?ns, + node_id = %self.local.server_id(), + %ns, "remote request disconnect sockets handler: {:?}", e ); @@ -436,8 +528,8 @@ impl CustomPostgresAdapter { let on_err = |err| { let ns = self.local.path(); tracing::warn!( - ?self.uid, - ?ns, + node_id = %self.local.server_id(), + %ns, "remote request broadcast with ack handler errors: {:?}", err ); @@ -446,7 +538,7 @@ impl CustomPostgresAdapter { // This is used to keep track of the number of expected acks. let res = Response { r#type: ResponseType::<()>::BroadcastAckCount(count), - node_id: self.uid, + node_id: self.local.server_id(), }; if let Err(err) = self.send_res(req_id, origin, res).await { on_err(err); @@ -458,7 +550,7 @@ impl CustomPostgresAdapter { while let Some(ack) = stream.next().await { let res = Response { r#type: ResponseType::BroadcastAck(ack), - node_id: self.uid, + node_id: self.local.server_id(), }; if let Err(err) = self.send_res(req_id, origin, res).await { on_err(err); @@ -472,11 +564,11 @@ impl CustomPostgresAdapter { let rooms = self.local.rooms(opts); let res = Response { r#type: ResponseType::<()>::AllRooms(rooms), - node_id: self.uid, + node_id: self.local.server_id(), }; let fut = self.send_res(req_id, origin, res); let ns = self.local.path().clone(); - let uid = self.uid; + let uid = self.local.server_id(); tokio::spawn(async move { if let Err(err) = fut.await { tracing::warn!(?uid, ?ns, "remote request rooms handler: {:?}", err); @@ -494,12 +586,12 @@ impl 
CustomPostgresAdapter { fn recv_fetch_sockets(&self, origin: Uid, req_id: Sid, opts: BroadcastOptions) { let sockets = self.local.fetch_sockets(opts); let res = Response { - node_id: self.uid, + node_id: self.local.server_id(), r#type: ResponseType::FetchSockets(sockets), }; let fut = self.send_res(req_id, origin, res); let ns = self.local.path().clone(); - let uid = self.uid; + let uid = self.local.server_id(); tokio::spawn(async move { if let Err(err) = fut.await { tracing::warn!(?uid, ?ns, "remote request fetch sockets handler: {:?}", err); @@ -544,10 +636,12 @@ impl CustomPostgresAdapter { /// Send a request to a specific target node or broadcast it to all nodes if no target is specified. async fn send_req(&self, req: RequestOut<'_>, target: Option) -> Result<(), Error> { tracing::trace!(?req, "sending request"); - // let head = ItemHeader::Req { target }; - // let req = self.new_packet(head, &req)?; + let chan = match target { + Some(target) => self.get_node_chan(target), + None => self.get_global_chan(), + }; self.driver - .notify("yolo", &req) + .notify(&chan, &req) .await .map_err(Error::Driver)?; Ok(()) @@ -558,16 +652,23 @@ impl CustomPostgresAdapter { &self, req_id: Sid, req_origin: Uid, - res: Response, + payload: Response, ) -> impl Future>> + 'static { - tracing::trace!(?res, "sending response for {req_id} req to {req_origin}"); + tracing::trace!( + ?payload, + "sending response for {req_id} req to {req_origin}" + ); let driver = self.driver.clone(); + let chan = self.get_response_chan(req_origin); + let payload = RawValue::from_string(serde_json::to_string(&payload).unwrap()).unwrap(); + let res = ResponsePacket { + req_id, + node_id: self.local.server_id(), + payload, + }; //TODO: is this the right way? 
async move { - driver - .notify("response", &res) - .await - .map_err(Error::Driver)?; + driver.notify(&chan, &res).await.map_err(Error::Driver)?; Ok(()) } } @@ -587,15 +688,16 @@ impl CustomPostgresAdapter { 1 }; - let stream = self.driver.listen("test").await.unwrap(); - self.responses.lock().unwrap().insert(req_id, stream); + let (tx, rx) = mpsc::channel(std::cmp::max(remote_serv_cnt, 1)); + self.responses.lock().unwrap().insert(req_id, tx); + let stream = ChanStream::new(rx); let stream = stream - .filter_map(|notif| { - let data = match serde_json::from_str::>(notif.payload()) { + .filter_map(|payload| { + let data = match serde_json::from_str::>(payload.get()) { Ok(data) => Some(data), Err(e) => { - tracing::warn!(channel = %notif.channel(), "error decoding response: {e}"); + tracing::warn!("error decoding response: {e}"); None } }; @@ -605,14 +707,13 @@ impl CustomPostgresAdapter { .take(remote_serv_cnt) .take_until(tokio::time::sleep(self.config.request_timeout)); - let stream = DropStream::new(stream, self.responses.clone(), req_id); Ok(stream) } /// Emit a heartbeat to the specified target node or broadcast to all nodes. async fn emit_heartbeat(&self, target: Option) -> Result<(), Error> { self.send_req( - RequestOut::new_empty(self.uid, RequestTypeOut::Heartbeat), + RequestOut::new_empty(self.local.server_id(), RequestTypeOut::Heartbeat), target, ) .await @@ -622,11 +723,26 @@ impl CustomPostgresAdapter { async fn emit_init_heartbeat(&self) -> Result<(), Error> { // Send initial heartbeat when starting. 
self.send_req( - RequestOut::new_empty(self.uid, RequestTypeOut::InitHeartbeat), + RequestOut::new_empty(self.local.server_id(), RequestTypeOut::InitHeartbeat), None, ) .await } + + fn get_global_chan(&self) -> String { + format!("{}#{}", self.config.prefix, self.local.path()) + } + fn get_node_chan(&self, uid: Uid) -> String { + format!("{}#{}", self.get_global_chan(), uid) + } + fn get_response_chan(&self, uid: Uid) -> String { + format!( + "{}-response#{}#{}", + &self.config.prefix, + self.local.path(), + uid + ) + } } /// The result of the init future. @@ -649,3 +765,10 @@ impl Spawnable for InitRes { }); } } + +#[derive(Deserialize, Serialize)] +struct ResponsePacket { + req_id: Sid, + node_id: Uid, + payload: Box, +} diff --git a/crates/socketioxide-postgres/src/stream.rs b/crates/socketioxide-postgres/src/stream.rs index 5e7a945f..3fec6df7 100644 --- a/crates/socketioxide-postgres/src/stream.rs +++ b/crates/socketioxide-postgres/src/stream.rs @@ -1,6 +1,7 @@ use std::{ fmt, pin::Pin, + sync::{Arc, Mutex}, task::{self, Poll}, time::Duration, }; @@ -9,13 +10,17 @@ use futures_core::{FusedStream, Stream}; use futures_util::{StreamExt, stream::TakeUntil}; use pin_project_lite::pin_project; use serde::de::DeserializeOwned; +use serde_json::value::RawValue; use socketioxide_core::{ - adapter::AckStreamItem, - adapter::remote_packet::{Response, ResponseType}, + Sid, + adapter::{ + AckStreamItem, + remote_packet::{Response, ResponseType}, + }, }; use tokio::{sync::mpsc, time}; -use crate::drivers::{NotifStream, Notification}; +use crate::{ResponseHandlers, drivers::Notification}; pin_project! { /// A stream of acknowledgement messages received from the local and remote servers. @@ -27,20 +32,28 @@ pin_project! { // And it is decremented each time an ack is received. // // Therefore an exhausted stream correspond to `ack_cnt == 0` and `server_cnt == 0`. 
- pub struct AckStream { + pub struct AckStream { #[pin] local: S, #[pin] - remote: TakeUntil, time::Sleep>, + remote: DropStream>, time::Sleep>>, ack_cnt: u32, total_ack_cnt: usize, serv_cnt: u16, } } -impl AckStream { - pub fn new(local: S, remote: NotifStream, timeout: Duration, serv_cnt: u16) -> Self { - let remote = remote.take_until(time::sleep(timeout)); +impl AckStream { + pub fn new( + local: S, + remote: mpsc::Receiver>, + timeout: Duration, + serv_cnt: u16, + req_sid: Sid, + handlers: Arc>, + ) -> Self { + let remote = ChanStream::new(remote).take_until(time::sleep(timeout)); + let remote = DropStream::new(remote, handlers, req_sid); Self { local, ack_cnt: 0, @@ -51,8 +64,10 @@ impl AckStream { } pub fn new_local(local: S) -> Self { - let rx = mpsc::unbounded_channel().1; - let remote = NotifStream::new(rx).take_until(time::sleep(Duration::ZERO)); + let handlers = Arc::new(Mutex::new(ResponseHandlers::new())); + let rx = mpsc::channel(1).1; + let remote = ChanStream::new(rx).take_until(time::sleep(Duration::ZERO)); + let remote = DropStream::new(remote, handlers, Sid::ZERO); Self { local, remote, @@ -62,13 +77,13 @@ impl AckStream { } } } -impl AckStream +impl AckStream where Err: DeserializeOwned + fmt::Debug, S: Stream> + FusedStream, { - /// Poll the remote stream. First the count of acks is received, then the acks are received. - /// We expect `serv_cnt` of `BroadcastAckCount` messages to be received, then we expect + /// Poll the remote stream. First the count of acks is received, then the acks are received. + /// We expect `serv_cnt` of `BroadcastAckCount` messages to be received, then we expect /// `ack_cnt` of `BroadcastAck` messages.
fn poll_remote( self: Pin<&mut Self>, @@ -84,14 +99,13 @@ where Poll::Pending => return Poll::Pending, Poll::Ready(None) => return Poll::Ready(None), Poll::Ready(Some(notif)) => { - let channel = notif.channel(); - let res = serde_json::from_str::>(notif.payload()); + let res = serde_json::from_str::>(notif.get()); match res { Ok(Response { node_id: uid, r#type: ResponseType::BroadcastAckCount(count), }) if *projection.serv_cnt > 0 => { - tracing::trace!(?uid, channel, "receiving broadcast ack count {count}"); + tracing::trace!(?uid, "receiving broadcast ack count {count}"); *projection.ack_cnt += count; *projection.total_ack_cnt += count as usize; *projection.serv_cnt -= 1; @@ -100,17 +114,12 @@ where node_id: uid, r#type: ResponseType::BroadcastAck((sid, res)), }) if *projection.ack_cnt > 0 => { - tracing::trace!( - ?uid, - channel, - "receiving broadcast ack {sid} {:?}", - res - ); + tracing::trace!(?uid, "receiving broadcast ack {sid} {:?}", res); *projection.ack_cnt -= 1; return Poll::Ready(Some((sid, res))); } Ok(Response { node_id: uid, .. }) => { - tracing::warn!(?uid, channel, "unexpected response type"); + tracing::warn!(?uid, "unexpected response type"); } Err(e) => { tracing::warn!("error decoding ack response: {e}"); @@ -121,11 +130,10 @@ where } } } -impl Stream for AckStream +impl Stream for AckStream where E: DeserializeOwned + fmt::Debug, S: Stream> + FusedStream, - T: Notification, { type Item = AckStreamItem; fn poll_next(mut self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll> { @@ -146,11 +154,10 @@ where } } -impl FusedStream for AckStream +impl FusedStream for AckStream where Err: DeserializeOwned + fmt::Debug, S: Stream> + FusedStream, - T: Notification, { /// The stream is terminated if: /// * The local stream is terminated. 
@@ -162,7 +169,7 @@ where self.local.is_terminated() && remote_term } } -impl fmt::Debug for AckStream { +impl fmt::Debug for AckStream { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { f.debug_struct("AckStream") .field("ack_cnt", &self.ack_cnt) @@ -171,6 +178,64 @@ impl fmt::Debug for AckStream { .finish() } } +pin_project! { + /// A stream of messages received from a channel. + pub struct ChanStream { + #[pin] + rx: mpsc::Receiver + } +} +impl ChanStream { + pub fn new(rx: mpsc::Receiver) -> Self { + Self { rx } + } +} +impl Stream for ChanStream { + type Item = T; + + fn poll_next(self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll> { + self.project().rx.poll_recv(cx) + } +} + +pin_project! { + /// A stream that unsubscribes from its source channel when dropped. + pub struct DropStream { + #[pin] + stream: S, + req_id: Sid, + handlers: Arc> + } + impl PinnedDrop for DropStream { + fn drop(this: Pin<&mut Self>) { + let stream = this.project(); + let chan = stream.req_id; + tracing::debug!(?chan, "dropping stream"); + stream.handlers.lock().unwrap().remove(chan); + } + } +} +impl DropStream { + pub fn new(stream: S, handlers: Arc>, req_id: Sid) -> Self { + Self { + stream, + handlers, + req_id, + } + } +} +impl Stream for DropStream { + type Item = S::Item; + + fn poll_next(self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll> { + self.project().stream.poll_next(cx) + } +} +impl FusedStream for DropStream { + fn is_terminated(&self) -> bool { + self.stream.is_terminated() + } +} #[cfg(test)] mod tests { @@ -186,7 +251,7 @@ mod tests { let local = futures_util::stream::once(async move { (sid, Ok::<_, ()>(Value::Str("local".into(), None))) }); - let stream = AckStream::new_local(local); + let stream = AckStream::<_>::new_local(local); futures_util::pin_mut!(stream); assert_eq!(stream.ack_cnt, 0); assert_eq!(stream.total_ack_cnt, 0); diff --git a/e2e/adapter/Cargo.toml b/e2e/adapter/Cargo.toml index 1e94a487..cc7ea709 100644 --- 
a/e2e/adapter/Cargo.toml +++ b/e2e/adapter/Cargo.toml @@ -22,6 +22,11 @@ socketioxide-redis = { path = "../../crates/socketioxide-redis", features = [ "fred", ] } socketioxide-mongodb = { path = "../../crates/socketioxide-mongodb" } +socketioxide-postgres = { path = "../../crates/socketioxide-postgres", features = [ + "sqlx", + "postgres", +] } + hyper-util = { workspace = true, features = ["tokio"] } hyper = { workspace = true, features = ["server", "http1"] } tokio = { workspace = true, features = ["rt-multi-thread", "macros"] } @@ -80,3 +85,11 @@ path = "src/bins/mongodb_ttl.rs" [[bin]] name = "mongodb-ttl-e2e-msgpack" path = "src/bins/mongodb_ttl_msgpack.rs" + +[[bin]] +name = "sqlx-e2e" +path = "src/bins/sqlx.rs" + +[[bin]] +name = "sqlx-e2e-msgpack" +path = "src/bins/sqlx_msgpack.rs" diff --git a/e2e/adapter/main.rs b/e2e/adapter/main.rs index f174b849..cd2f0bf9 100644 --- a/e2e/adapter/main.rs +++ b/e2e/adapter/main.rs @@ -3,7 +3,7 @@ use std::fs; use std::process::{Child, Command}; use std::time::Duration; -const BINS: [&str; 12] = [ +const BINS: &[&str] = &[ "fred-e2e", "fred-e2e-msgpack", "redis-e2e", @@ -16,27 +16,34 @@ const BINS: [&str; 12] = [ "mongodb-ttl-e2e-msgpack", "mongodb-capped-e2e", "mongodb-capped-e2e-msgpack", + "sqlx-e2e", + "sqlx-e2e-msgpack", ]; const EXEC_SUFFIX: &str = if cfg!(windows) { ".exe" } else { "" }; const LOG_DIR: &str = "e2e/adapter/logs"; -fn main() { - let filter = args().skip(1).next().unwrap_or("".to_string()); - println!("filter: {}", filter); +fn main() -> Result<(), Box> { + let bin_filter = args().nth(1).unwrap_or("".to_string()); + println!("binary target filter: {}", bin_filter); - if fs::exists(LOG_DIR).unwrap() { - fs::remove_dir_all(LOG_DIR).unwrap(); + let test_filter = args().nth(2); + println!("test filter: {}", test_filter.as_deref().unwrap_or("*")); + + if fs::exists(LOG_DIR)? 
{ + fs::remove_dir_all(LOG_DIR)?; } - fs::create_dir_all(LOG_DIR).unwrap(); + fs::create_dir_all(LOG_DIR)?; // run everything - for target in BINS.into_iter().filter(|name| name.contains(&filter)) { - run(target); + for target in BINS.iter().filter(|name| name.contains(&bin_filter)) { + run(target, test_filter.as_deref()); } println!("All tests passed!"); + + Ok(()) } -fn run(target: &'static str) { +fn run(target: &'static str, test_filter: Option<&str>) { let parser = if target.ends_with("msgpack") { "msgpack" } else { @@ -50,10 +57,15 @@ fn run(target: &'static str) { std::thread::sleep(Duration::from_millis(200)); - let child = Command::new("node") - .arg("--experimental-strip-types") - .arg("--test-reporter=spec") - .arg("--test") + let mut cmd = Command::new("node"); + + cmd.arg("--test-reporter=spec").arg("--test"); + + if let Some(filter) = test_filter { + cmd.arg(format!("--test-name-pattern=\"{filter}\"")); + } + + let child = cmd .arg("e2e/adapter/client.ts") .env("PORTS", "3000,3001,3002") .env("PARSER", parser) diff --git a/e2e/adapter/src/bins/sqlx.rs b/e2e/adapter/src/bins/sqlx.rs new file mode 100644 index 00000000..fb260072 --- /dev/null +++ b/e2e/adapter/src/bins/sqlx.rs @@ -0,0 +1,64 @@ +use hyper::server::conn::http1; +use hyper_util::rt::TokioIo; +use socketioxide::SocketIo; + +use socketioxide_postgres::{ + CustomPostgresAdapter, PostgresAdapterConfig, PostgresAdapterCtr, + drivers::sqlx::{SqlxDriver, sqlx_client::PgPool}, +}; +use tokio::net::TcpListener; +use tracing::{Level, info}; +use tracing_subscriber::FmtSubscriber; + +#[tokio::main] +async fn main() -> Result<(), Box> { + let subscriber = FmtSubscriber::builder() + .with_line_number(true) + .with_max_level(Level::TRACE) + .finish(); + tracing::subscriber::set_global_default(subscriber)?; + let variant = std::env::args().next().unwrap(); + let variant = variant.split("/").last().unwrap(); + + let config = PostgresAdapterConfig { + prefix: format!("socket.io-{variant}").into(), + 
..Default::default() + }; + + let pg_pool = PgPool::connect("postgres://socketio:socketio@localhost:5432/socketio").await?; + let adapter = PostgresAdapterCtr::new_with_driver(SqlxDriver::new(pg_pool), config); + let (svc, io) = SocketIo::builder() + .with_adapter::>(adapter) + .build_svc(); + + io.ns("/", adapter_e2e::handler).await?; + + info!("Starting server with v5 protocol"); + let port: u16 = std::env::var("PORT") + .expect("a PORT env var should be set") + .parse()?; + + let listener = TcpListener::bind(("127.0.0.1", port)).await?; + + // We start a loop to continuously accept incoming connections + loop { + let (stream, _) = listener.accept().await?; + + // Use an adapter to access something implementing `tokio::io` traits as if they implement + // `hyper::rt` IO traits. + let io = TokioIo::new(stream); + let svc = svc.clone(); + + // Spawn a tokio task to serve multiple connections concurrently + tokio::task::spawn(async move { + // Finally, we bind the incoming connection to our `hello` service + if let Err(err) = http1::Builder::new() + .serve_connection(io, svc) + .with_upgrades() + .await + { + println!("Error serving connection: {:?}", err); + } + }); + } +} diff --git a/e2e/adapter/src/bins/sqlx_msgpack.rs b/e2e/adapter/src/bins/sqlx_msgpack.rs new file mode 100644 index 00000000..d7f420f3 --- /dev/null +++ b/e2e/adapter/src/bins/sqlx_msgpack.rs @@ -0,0 +1,65 @@ +use hyper::server::conn::http1; +use hyper_util::rt::TokioIo; +use socketioxide::{ParserConfig, SocketIo}; + +use socketioxide_postgres::{ + CustomPostgresAdapter, PostgresAdapterConfig, PostgresAdapterCtr, + drivers::sqlx::{SqlxDriver, sqlx_client::PgPool}, +}; +use tokio::net::TcpListener; +use tracing::{Level, info}; +use tracing_subscriber::FmtSubscriber; + +#[tokio::main] +async fn main() -> Result<(), Box> { + let subscriber = FmtSubscriber::builder() + .with_line_number(true) + .with_max_level(Level::TRACE) + .finish(); + tracing::subscriber::set_global_default(subscriber)?; + let 
variant = std::env::args().next().unwrap(); + let variant = variant.split("/").last().unwrap(); + + let config = PostgresAdapterConfig { + prefix: format!("socket.io-{variant}").into(), + ..Default::default() + }; + + let pg_pool = PgPool::connect("postgres://socketio:socketio@localhost:5432/socketio").await?; + let adapter = PostgresAdapterCtr::new_with_driver(SqlxDriver::new(pg_pool), config); + let (svc, io) = SocketIo::builder() + .with_parser(ParserConfig::msgpack()) + .with_adapter::>(adapter) + .build_svc(); + + io.ns("/", adapter_e2e::handler).await?; + + info!("Starting server with v5 protocol"); + let port: u16 = std::env::var("PORT") + .expect("a PORT env var should be set") + .parse()?; + + let listener = TcpListener::bind(("127.0.0.1", port)).await?; + + // We start a loop to continuously accept incoming connections + loop { + let (stream, _) = listener.accept().await?; + + // Use an adapter to access something implementing `tokio::io` traits as if they implement + // `hyper::rt` IO traits. 
+ let io = TokioIo::new(stream); + let svc = svc.clone(); + + // Spawn a tokio task to serve multiple connections concurrently + tokio::task::spawn(async move { + // Finally, we bind the incoming connection to our `hello` service + if let Err(err) = http1::Builder::new() + .serve_connection(io, svc) + .with_upgrades() + .await + { + println!("Error serving connection: {:?}", err); + } + }); + } +} From 7acdde2ff040ee0400422d95981353a5e1dd2208 Mon Sep 17 00:00:00 2001 From: totodore Date: Sun, 29 Mar 2026 14:38:38 +0200 Subject: [PATCH 06/31] feat(adapter/postgre): wip --- e2e/adapter/main.rs | 12 +++--------- 1 file changed, 3 insertions(+), 9 deletions(-) diff --git a/e2e/adapter/main.rs b/e2e/adapter/main.rs index cd2f0bf9..f18af801 100644 --- a/e2e/adapter/main.rs +++ b/e2e/adapter/main.rs @@ -57,15 +57,9 @@ fn run(target: &'static str, test_filter: Option<&str>) { std::thread::sleep(Duration::from_millis(200)); - let mut cmd = Command::new("node"); - - cmd.arg("--test-reporter=spec").arg("--test"); - - if let Some(filter) = test_filter { - cmd.arg(format!("--test-name-pattern=\"{filter}\"")); - } - - let child = cmd + let child = Command::new("node") + .arg("--test-reporter=spec") + .arg("--test") .arg("e2e/adapter/client.ts") .env("PORTS", "3000,3001,3002") .env("PARSER", parser) From db72426427ae3fee59763ab6bd77333481c9f4e2 Mon Sep 17 00:00:00 2001 From: totodore Date: Sun, 29 Mar 2026 14:39:45 +0200 Subject: [PATCH 07/31] feat(adapter/postgre): wip --- Cargo.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Cargo.toml b/Cargo.toml index af7063a2..5485323b 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [workspace.package] edition = "2024" -rust-version = "1.86.0" +rust-version = "1.88.0" authors = ["Théodore Prévot <"] repository = "https://github.com/totodore/socketioxide" homepage = "https://github.com/totodore/socketioxide" From 3453a8680d0613f664fb4edaed1c2439d2a9f6ba Mon Sep 17 00:00:00 2001 From: totodore Date: Sun, 29 Mar 
2026 14:48:45 +0200 Subject: [PATCH 08/31] feat(adapter/postgre): wip --- crates/socketioxide-postgres/src/drivers/mod.rs | 1 + crates/socketioxide-postgres/src/drivers/sqlx.rs | 4 ++-- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/crates/socketioxide-postgres/src/drivers/mod.rs b/crates/socketioxide-postgres/src/drivers/mod.rs index 94de7f17..57fa558c 100644 --- a/crates/socketioxide-postgres/src/drivers/mod.rs +++ b/crates/socketioxide-postgres/src/drivers/mod.rs @@ -1,6 +1,7 @@ use futures_core::Stream; use serde::Serialize; +#[cfg(feature = "sqlx")] pub mod sqlx; /// The driver trait can be used to support different LISTEN/NOTIFY backends. diff --git a/crates/socketioxide-postgres/src/drivers/sqlx.rs b/crates/socketioxide-postgres/src/drivers/sqlx.rs index a47b5169..f8977284 100644 --- a/crates/socketioxide-postgres/src/drivers/sqlx.rs +++ b/crates/socketioxide-postgres/src/drivers/sqlx.rs @@ -63,11 +63,11 @@ impl Driver for SqlxDriver { ) -> impl Future> + Send { let client = self.client.clone(); //TODO: handle error - let msg = serde_json::to_string(req).unwrap(); + let msg = serde_json::to_string(req).map_err(|err| sqlx::Error::Decode(Box::new(err))); async move { sqlx::query("SELECT pg_notify($1, $2)") .bind(channel) - .bind(msg) + .bind(msg?) 
.execute(&client) .await?; Ok(()) From aa5430526c38b6013b63a940bbba392e0b0ab6c0 Mon Sep 17 00:00:00 2001 From: totodore Date: Sun, 29 Mar 2026 15:23:29 +0200 Subject: [PATCH 09/31] feat(adapter/postgre): wip --- Cargo.lock | 1 - crates/socketioxide-postgres/Cargo.toml | 1 - .../socketioxide-postgres/src/drivers/mod.rs | 8 ++- .../src/drivers/postgres.rs | 59 ++++++++++++------- .../socketioxide-postgres/src/drivers/sqlx.rs | 24 +++----- crates/socketioxide-postgres/src/lib.rs | 9 ++- 6 files changed, 57 insertions(+), 45 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 1dae8524..845b0369 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2614,7 +2614,6 @@ dependencies = [ "futures-core", "futures-util", "pin-project-lite", - "rmp-serde", "serde", "serde_json", "smallvec", diff --git a/crates/socketioxide-postgres/Cargo.toml b/crates/socketioxide-postgres/Cargo.toml index 4029a391..3aa2604e 100644 --- a/crates/socketioxide-postgres/Cargo.toml +++ b/crates/socketioxide-postgres/Cargo.toml @@ -28,7 +28,6 @@ serde.workspace = true serde_json = { workspace = true, features = ["raw_value"] } smallvec = { workspace = true, features = ["serde"] } tokio = { workspace = true, features = ["time", "rt", "sync"] } -rmp-serde.workspace = true tracing.workspace = true thiserror.workspace = true diff --git a/crates/socketioxide-postgres/src/drivers/mod.rs b/crates/socketioxide-postgres/src/drivers/mod.rs index 57fa558c..6b64b59a 100644 --- a/crates/socketioxide-postgres/src/drivers/mod.rs +++ b/crates/socketioxide-postgres/src/drivers/mod.rs @@ -1,9 +1,11 @@ use futures_core::Stream; -use serde::Serialize; #[cfg(feature = "sqlx")] pub mod sqlx; +// #[cfg(feature = "postgres")] +// pub mod postgres; + /// The driver trait can be used to support different LISTEN/NOTIFY backends. /// It must share handlers/connection between its clones. 
pub trait Driver: Clone + Send + Sync + 'static { @@ -18,10 +20,10 @@ pub trait Driver: Clone + Send + Sync + 'static { channels: &[&str], ) -> impl Future> + Send; - fn notify( + fn notify( &self, channel: &str, - message: &T, + message: &str, ) -> impl Future> + Send; } diff --git a/crates/socketioxide-postgres/src/drivers/postgres.rs b/crates/socketioxide-postgres/src/drivers/postgres.rs index 0d6d408a..fea516a5 100644 --- a/crates/socketioxide-postgres/src/drivers/postgres.rs +++ b/crates/socketioxide-postgres/src/drivers/postgres.rs @@ -1,46 +1,63 @@ -use std::sync::Arc; +use std::{pin::Pin, sync::Arc}; -use tokio_postgres::{Client, Connection}; +use futures_core::{Stream, stream::BoxStream}; +use tokio::io::{AsyncRead, AsyncWrite}; +use tokio_postgres::{AsyncMessage, Client, Config, Connection, Socket}; use super::Driver; #[derive(Debug, Clone)] pub struct PostgresDriver { client: Arc, + config: Config, } impl PostgresDriver { - pub fn new(client: Client, connection: Connection) -> Self { - PostgresDriver { - client: Arc::new(client), - } + pub fn new(config: Config) -> Self + where + T: AsyncRead + AsyncWrite + Unpin, + { + PostgresDriver { config } } } impl Driver for PostgresDriver { type Error = tokio_postgres::Error; - type NotifStream; + type Notification = tokio_postgres::Notification; + type NotificationStream = BoxStream<'static, Self::Notification>; - async fn init(&self, table: &str, channels: &[&str]) -> Result<(), Self::Error> { + async fn init( + &self, + table: &str, + channels: &[&str], + ) -> Result { + let st = &format!( + r#"CREATE TABLE IF NOT EXISTS "{table}" ( + id BIGSERIAL UNIQUE, + created_at TIMESTAMPTZ DEFAULT NOW(), + payload BYTEA + )"# + ); + + self.client.execute(st, &[]).await?; + + Ok(()) + } + + async fn notify(&self, channel: &str, message: &str) -> Result<(), Self::Error> { self.client - .execute("CREATE TABLE $1 IF NOT EXISTS", &[&table]) + .execute("SELECT pg_notify($1, $2)", &[&channel, &message]) .await?; - Ok(()) } +} 
- fn listen( - &self, - channel: &str, - ) -> impl Future, Self::Error>> + Send { - todo!() +impl super::Notification for tokio_postgres::Notification { + fn channel(&self) -> &str { + tokio_postgres::Notification::channel(self) } - fn notify( - &self, - channel: &str, - message: &T, - ) -> impl Future> + Send { - todo!() + fn payload(&self) -> &str { + tokio_postgres::Notification::payload(self) } } diff --git a/crates/socketioxide-postgres/src/drivers/sqlx.rs b/crates/socketioxide-postgres/src/drivers/sqlx.rs index f8977284..d23a05f8 100644 --- a/crates/socketioxide-postgres/src/drivers/sqlx.rs +++ b/crates/socketioxide-postgres/src/drivers/sqlx.rs @@ -1,6 +1,5 @@ use futures_core::stream::BoxStream; use futures_util::StreamExt; -use serde::Serialize; use sqlx::{ PgPool, postgres::{PgListener, PgNotification}, @@ -56,22 +55,13 @@ impl Driver for SqlxDriver { Ok(Box::pin(stream)) } - fn notify( - &self, - channel: &str, - req: &T, - ) -> impl Future> + Send { - let client = self.client.clone(); - //TODO: handle error - let msg = serde_json::to_string(req).map_err(|err| sqlx::Error::Decode(Box::new(err))); - async move { - sqlx::query("SELECT pg_notify($1, $2)") - .bind(channel) - .bind(msg?) 
- .execute(&client) - .await?; - Ok(()) - } + async fn notify(&self, channel: &str, message: &str) -> Result<(), Self::Error> { + sqlx::query("SELECT pg_notify($1, $2)") + .bind(channel) + .bind(message) + .execute(&self.client) + .await?; + Ok(()) } } diff --git a/crates/socketioxide-postgres/src/lib.rs b/crates/socketioxide-postgres/src/lib.rs index f4ee9a23..6a142d90 100644 --- a/crates/socketioxide-postgres/src/lib.rs +++ b/crates/socketioxide-postgres/src/lib.rs @@ -640,8 +640,9 @@ impl CustomPostgresAdapter { Some(target) => self.get_node_chan(target), None => self.get_global_chan(), }; + let payload = serde_json::to_string(&req)?; self.driver - .notify(&chan, &req) + .notify(&chan, &payload) .await .map_err(Error::Driver)?; Ok(()) @@ -666,9 +667,13 @@ impl CustomPostgresAdapter { node_id: self.local.server_id(), payload, }; + let message = serde_json::to_string(&res); //TODO: is this the right way? async move { - driver.notify(&chan, &res).await.map_err(Error::Driver)?; + driver + .notify(&chan, &message?) 
+ .await + .map_err(Error::Driver)?; Ok(()) } } From 874799247b99f200bdbd8ac792ca275ef235dd45 Mon Sep 17 00:00:00 2001 From: totodore Date: Sun, 29 Mar 2026 18:28:52 +0200 Subject: [PATCH 10/31] feat(adapter/postgre): wip --- crates/socketioxide-postgres/README.md | 137 ++++++++++++++ .../socketioxide-postgres/src/drivers/mod.rs | 14 ++ .../socketioxide-postgres/src/drivers/sqlx.rs | 3 + crates/socketioxide-postgres/src/lib.rs | 172 ++++++++++++++++-- e2e/adapter/src/bins/sqlx.rs | 12 +- e2e/adapter/src/bins/sqlx_msgpack.rs | 13 +- 6 files changed, 317 insertions(+), 34 deletions(-) diff --git a/crates/socketioxide-postgres/README.md b/crates/socketioxide-postgres/README.md index e69de29b..dd9781db 100644 --- a/crates/socketioxide-postgres/README.md +++ b/crates/socketioxide-postgres/README.md @@ -0,0 +1,137 @@ +# [`Socketioxide-Postgres`](https://github.com/totodore/socketioxide) 🚀🦀 + +A [***`socket.io`***](https://socket.io) adapter for [***`Socketioxide`***](https://github.com/totodore/socketioxide), using [PostgreSQL LISTEN/NOTIFY](https://www.postgresql.org/docs/current/sql-notify.html) for event broadcasting. This adapter enables **horizontal scaling** of your Socketioxide servers across distributed deployments by leveraging PostgreSQL as a message bus. 
+ +[![Crates.io](https://img.shields.io/crates/v/socketioxide-postgres.svg)](https://crates.io/crates/socketioxide-postgres) +[![Documentation](https://docs.rs/socketioxide-postgres/badge.svg)](https://docs.rs/socketioxide-postgres) +[![CI](https://github.com/Totodore/socketioxide/actions/workflows/github-ci.yml/badge.svg)](https://github.com/Totodore/socketioxide/actions/workflows/github-ci.yml) + + + +## Features + +- **PostgreSQL LISTEN/NOTIFY-based adapter** +- **Support for any PostgreSQL client** via the [`Driver`] abstraction +- Built-in driver for the [sqlx](https://docs.rs/sqlx) crate: [`SqlxDriver`](https://docs.rs/socketioxide-postgres/latest/socketioxide_postgres/drivers/sqlx/struct.SqlxDriver.html) +- **Heartbeat-based liveness detection** for tracking active server nodes +- Fully compatible with the asynchronous Rust ecosystem +- Implement your own custom driver by implementing the `Driver` trait + +> [!WARNING] +> This adapter is **not compatible** with [`@socket.io/postgres-adapter`](https://github.com/socketio/socket.io-postgres-adapter). +> These projects use entirely different protocols and cannot interoperate. +> **Do not mix Socket.IO JavaScript servers with Socketioxide Rust servers**. 
+ + + +## Example: Using the PostgreSQL Adapter with Axum + +```rust +use serde::{Deserialize, Serialize}; +use socketioxide::{ + adapter::Adapter, + extract::{Data, Extension, SocketRef}, + SocketIo, +}; +use socketioxide_postgres::{ + drivers::sqlx::sqlx_client::{self as sqlx, PgPool}, + SqlxAdapter, PostgresAdapterCtr, PostgresAdapterConfig, +}; +use tower::ServiceBuilder; +use tower_http::{cors::CorsLayer, services::ServeDir}; +use tracing::info; +use tracing_subscriber::FmtSubscriber; + +#[derive(Deserialize, Serialize, Debug, Clone)] +#[serde(transparent)] +struct Username(String); + +#[derive(Deserialize, Serialize, Debug, Clone)] +#[serde(rename_all = "camelCase", untagged)] +enum Res { + Message { + username: Username, + message: String, + }, + Username { + username: Username, + }, +} + +#[tokio::main] +async fn main() -> Result<(), Box> { + let subscriber = FmtSubscriber::new(); + + tracing::subscriber::set_global_default(subscriber)?; + + info!("Starting server"); + + let pool = PgPool::connect("postgres://user:password@localhost/socketio").await?; + let adapter = PostgresAdapterCtr::new_with_sqlx(pool); + + let (layer, io) = SocketIo::builder() + .with_adapter::>(adapter) + .build_layer(); + io.ns("/", on_connect).await?; + + let app = axum::Router::new() + .fallback_service(ServeDir::new("dist")) + .layer( + ServiceBuilder::new() + .layer(CorsLayer::permissive()) // Enable CORS policy + .layer(layer), + ); + + let port = std::env::var("PORT") + .map(|s| s.parse().unwrap()) + .unwrap_or(3000); + let listener = tokio::net::TcpListener::bind(("0.0.0.0", port)) + .await + .unwrap(); + axum::serve(listener, app).await.unwrap(); + + Ok(()) +} + +async fn on_connect(socket: SocketRef) { + socket.on("new message", on_msg); + socket.on("typing", on_typing); + socket.on("stop typing", on_stop_typing); +} +async fn on_msg( + s: SocketRef, + Data(msg): Data, + Extension(username): Extension, +) { + let msg = &Res::Message { + username, + message: msg, + }; + 
s.broadcast().emit("new message", msg).await.ok(); +} +async fn on_typing(s: SocketRef, Extension(username): Extension) { + s.broadcast() + .emit("typing", &Res::Username { username }) + .await + .ok(); +} +async fn on_stop_typing(s: SocketRef, Extension(username): Extension) { + s.broadcast() + .emit("stop typing", &Res::Username { username }) + .await + .ok(); +} + +``` + + + +## Contributions and Feedback / Questions + +Contributions are very welcome! Feel free to open an issue or a PR. If you're unsure where to start, check the [issues](https://github.com/totodore/socketioxide/issues). + +For feedback or questions, join the discussion on the [discussions](https://github.com/totodore/socketioxide/discussions) page. + +## License 🔐 + +This project is licensed under the [MIT license](./LICENSE). diff --git a/crates/socketioxide-postgres/src/drivers/mod.rs b/crates/socketioxide-postgres/src/drivers/mod.rs index 6b64b59a..10e1f5cd 100644 --- a/crates/socketioxide-postgres/src/drivers/mod.rs +++ b/crates/socketioxide-postgres/src/drivers/mod.rs @@ -1,5 +1,9 @@ +//! Drivers are an abstraction over the PostgreSQL LISTEN/NOTIFY backend used by the adapter. +//! You can use the provided implementation or implement your own. + use futures_core::Stream; +/// A driver implementation for the [`sqlx`](https://docs.rs/sqlx) PostgreSQL backend. #[cfg(feature = "sqlx")] pub mod sqlx; @@ -9,17 +13,24 @@ pub mod sqlx; /// The driver trait can be used to support different LISTEN/NOTIFY backends. /// It must share handlers/connection between its clones. pub trait Driver: Clone + Send + Sync + 'static { + /// The error type returned by the driver. type Error: std::error::Error + Send + 'static; + /// The notification type yielded by the notification stream. type Notification: Notification; + /// The stream of notifications returned by [`Driver::listen`]. type NotificationStream: Stream + Send; + /// Initialize the driver. This is called once when the adapter is created. 
+ /// It should create the necessary tables or schema if needed. fn init(&self, table: &str) -> impl Future> + Send; + /// Subscribe to the given NOTIFY channels and return a stream of notifications. fn listen( &self, channels: &[&str], ) -> impl Future> + Send; + /// Send a NOTIFY message on the given channel with the given payload. fn notify( &self, channel: &str, @@ -27,7 +38,10 @@ pub trait Driver: Clone + Send + Sync + 'static { ) -> impl Future> + Send; } +/// A trait representing a PostgreSQL NOTIFY notification. pub trait Notification: Send + 'static { + /// The channel name on which the notification was received. fn channel(&self) -> &str; + /// The payload of the notification. fn payload(&self) -> &str; } diff --git a/crates/socketioxide-postgres/src/drivers/sqlx.rs b/crates/socketioxide-postgres/src/drivers/sqlx.rs index d23a05f8..3241dfac 100644 --- a/crates/socketioxide-postgres/src/drivers/sqlx.rs +++ b/crates/socketioxide-postgres/src/drivers/sqlx.rs @@ -9,6 +9,9 @@ use super::Driver; pub use sqlx as sqlx_client; +/// A [`Driver`] implementation using the [`sqlx`] PostgreSQL client. +/// +/// It uses [`PgListener`] for LISTEN/NOTIFY and [`PgPool`] for queries. #[derive(Debug, Clone)] pub struct SqlxDriver { client: PgPool, diff --git a/crates/socketioxide-postgres/src/lib.rs b/crates/socketioxide-postgres/src/lib.rs index 6a142d90..ec09d9e7 100644 --- a/crates/socketioxide-postgres/src/lib.rs +++ b/crates/socketioxide-postgres/src/lib.rs @@ -29,7 +29,55 @@ nonstandard_style, missing_docs )] -//! test +//! # A PostgreSQL adapter implementation for the socketioxide crate. +//! The adapter is used to communicate with other nodes of the same application. +//! This allows to broadcast messages to sockets connected on other servers, +//! to get the list of rooms, to add or remove sockets from rooms, etc. +//! +//! To achieve this, the adapter uses [LISTEN/NOTIFY](https://www.postgresql.org/docs/current/sql-notify.html) +//! 
through PostgreSQL to communicate with other servers. +//! +//! The [`Driver`] abstraction allows the use of any PostgreSQL client. +//! One implementation is provided: +//! * [`SqlxDriver`](crate::drivers::sqlx::SqlxDriver) for the [`sqlx`] crate. +//! +//! You can also implement your own driver by implementing the [`Driver`] trait. +//! +//!
+//! Socketioxide-postgres is not compatible with @socketio/postgres-adapter. +//! They use completely different protocols and cannot be used together. +//! Do not mix socket.io JS servers with socketioxide rust servers. +//!
+//!
+//! ## How does it work?
+//!
+//! The [`PostgresAdapterCtr`] is a constructor for the [`SqlxAdapter`] which is an implementation of
+//! the [`Adapter`](https://docs.rs/socketioxide/latest/socketioxide/adapter/trait.Adapter.html) trait.
+//!
+//! Then, for each namespace, an adapter is created and it takes a corresponding [`CoreLocalAdapter`].
+//! The [`CoreLocalAdapter`] allows managing the local rooms and local sockets. The default `LocalAdapter`
+//! is simply a wrapper around this [`CoreLocalAdapter`].
+//!
+//! Once it is created, the adapter is initialized with the [`CustomPostgresAdapter::init`] method.
+//! It will subscribe to three PostgreSQL NOTIFY channels and emit heartbeats.
+//! All messages are encoded with JSON.
+//!
+//! There are 9 types of requests:
+//! * Broadcast a packet to all the matching sockets.
+//! * Broadcast a packet to all the matching sockets and wait for a stream of acks.
+//! * Disconnect matching sockets.
+//! * Get all the rooms.
+//! * Add matching sockets to rooms.
+//! * Remove matching sockets from rooms.
+//! * Fetch all the remote sockets matching the options.
+//! * Heartbeat
+//! * Initial heartbeat. When receiving an initial heartbeat all other servers reply with a heartbeat immediately.
+//!
+//! For ack streams, the adapter will first send a `BroadcastAckCount` response to the server that sent the request,
+//! and then send the acks as they are received (more details in [`CustomPostgresAdapter::broadcast_with_ack`] fn).
+//!
+//! On the other side, each time an action has to be performed on the local server, the adapter will
+//! first broadcast a request to all the servers and then perform the action locally.
 
 use drivers::Driver;
 use futures_core::Stream;
@@ -68,7 +116,7 @@ use crate::{
 pub mod drivers;
 mod stream;
 
-/// The configuration of the [`MongoDbAdapter`].
+/// The configuration of the [`CustomPostgresAdapter`].
#[derive(Debug, Clone)] pub struct PostgresAdapterConfig { /// The heartbeat timeout duration. If a remote node does not respond within this duration, @@ -91,11 +139,75 @@ pub struct PostgresAdapterConfig { pub table_name: Cow<'static, str>, /// The prefix used for the channels. Default is "socket.io". pub prefix: Cow<'static, str>, - /// The treshold to the payload size in bytes. It should match the configured value on your PostgreSQL instance: + /// The threshold to the payload size in bytes. It should match the configured value on your PostgreSQL instance: /// . By default it is 8KB (8000 bytes). - pub payload_treshold: usize, + pub payload_threshold: usize, /// The duration between cleanup queries on the attachment table. - pub cleanup_intervals: Duration, + pub cleanup_interval: Duration, +} + +impl PostgresAdapterConfig { + /// Create a new [`PostgresAdapterConfig`] with default values. + pub fn new() -> Self { + Self::default() + } + + /// The heartbeat timeout duration. If a remote node does not respond within this duration, + /// it will be considered disconnected. Default is 60 seconds. + pub fn with_hb_timeout(mut self, hb_timeout: Duration) -> Self { + self.hb_timeout = hb_timeout; + self + } + + /// The heartbeat interval duration. The current node will broadcast a heartbeat to the + /// remote nodes at this interval. Default is 10 seconds. + pub fn with_hb_interval(mut self, hb_interval: Duration) -> Self { + self.hb_interval = hb_interval; + self + } + + /// The request timeout. When expecting a response from remote nodes, if they do not respond within + /// this duration, the request will be considered failed. Default is 5 seconds. + pub fn with_request_timeout(mut self, request_timeout: Duration) -> Self { + self.request_timeout = request_timeout; + self + } + + /// The channel size used to receive ack responses. Default is 255. 
+    ///
+    /// If you have a lot of servers/sockets and you may miss acknowledgements because they arrive faster
+    /// than you can poll them from the returned stream, you might want to increase this value.
+    pub fn with_ack_response_buffer(mut self, ack_response_buffer: usize) -> Self {
+        self.ack_response_buffer = ack_response_buffer;
+        self
+    }
+
+    /// The table name used to store socket.io attachments. Default is "socket_io_attachments".
+    ///
+    /// > The table name must be a sanitized string. Do not use special characters or spaces.
+    pub fn with_table_name(mut self, table_name: impl Into>) -> Self {
+        self.table_name = table_name.into();
+        self
+    }
+
+    /// The prefix used for the channels. Default is "socket.io".
+    pub fn with_prefix(mut self, prefix: impl Into>) -> Self {
+        self.prefix = prefix.into();
+        self
+    }
+
+    /// The threshold to the payload size in bytes. It should match the configured value on your PostgreSQL instance:
+    /// . By default it is 8KB (8000 bytes).
+    pub fn with_payload_threshold(mut self, payload_threshold: usize) -> Self {
+        self.payload_threshold = payload_threshold;
+        self
+    }
+
+    /// The duration between cleanup queries on the attachment table. Default is 60 seconds.
+    pub fn with_cleanup_interval(mut self, cleanup_interval: Duration) -> Self {
+        self.cleanup_interval = cleanup_interval;
+        self
+    }
 }
 
 impl Default for PostgresAdapterConfig {
@@ -107,8 +219,8 @@
             ack_response_buffer: 255,
             table_name: "socket_io_attachments".into(),
             prefix: "socket.io".into(),
-            payload_treshold: 8_000,
-            cleanup_intervals: Duration::from_secs(60),
+            payload_threshold: 8_000,
+            cleanup_interval: Duration::from_secs(60),
         }
     }
 }
@@ -116,7 +228,7 @@ impl Default for PostgresAdapterConfig {
 /// Represent any error that might happen when using this adapter.
#[derive(thiserror::Error)] pub enum Error { - /// Mongo driver error + /// Postgres driver error #[error("driver error: {0}")] Driver(D::Error), /// Packet encoding/decoding error @@ -139,12 +251,33 @@ impl From> for AdapterError { } } -/// Constructor for the PostgresAdapterCtr struct. +/// The adapter constructor. For each namespace you define, a new adapter instance is created +/// from this constructor. +#[derive(Debug, Clone)] pub struct PostgresAdapterCtr { driver: D, config: PostgresAdapterConfig, } +#[cfg(feature = "sqlx")] +impl PostgresAdapterCtr { + /// Create a new adapter constructor with the [`sqlx`](drivers::sqlx) driver + /// and a default config. + pub fn new_with_sqlx(pool: drivers::sqlx::sqlx_client::PgPool) -> Self { + Self::new_with_sqlx_config(pool, PostgresAdapterConfig::default()) + } + + /// Create a new adapter constructor with the [`sqlx`](drivers::sqlx) driver + /// and a custom config. + pub fn new_with_sqlx_config( + pool: drivers::sqlx::sqlx_client::PgPool, + config: PostgresAdapterConfig, + ) -> Self { + let driver = drivers::sqlx::SqlxDriver::new(pool); + Self { driver, config } + } +} + impl PostgresAdapterCtr { /// Create a new adapter constructor with a custom postgres driver and a config. /// @@ -155,6 +288,10 @@ impl PostgresAdapterCtr { } } +/// The postgres adapter with the [`sqlx`](drivers::sqlx) driver. +#[cfg(feature = "sqlx")] +pub type SqlxAdapter = CustomPostgresAdapter; + type ResponseHandlers = HashMap>>; /// The postgres adapter implementation. @@ -661,14 +798,15 @@ impl CustomPostgresAdapter { ); let driver = self.driver.clone(); let chan = self.get_response_chan(req_origin); - let payload = RawValue::from_string(serde_json::to_string(&payload).unwrap()).unwrap(); - let res = ResponsePacket { - req_id, - node_id: self.local.server_id(), - payload, - }; - let message = serde_json::to_string(&res); - //TODO: is this the right way? 
+ let message = serde_json::to_string(&payload) + .and_then(RawValue::from_string) + .map(|payload| ResponsePacket { + req_id, + node_id: self.local.server_id(), + payload, + }) + .and_then(|res| serde_json::to_string(&res)); + async move { driver .notify(&chan, &message?) diff --git a/e2e/adapter/src/bins/sqlx.rs b/e2e/adapter/src/bins/sqlx.rs index fb260072..37c457ba 100644 --- a/e2e/adapter/src/bins/sqlx.rs +++ b/e2e/adapter/src/bins/sqlx.rs @@ -3,8 +3,7 @@ use hyper_util::rt::TokioIo; use socketioxide::SocketIo; use socketioxide_postgres::{ - CustomPostgresAdapter, PostgresAdapterConfig, PostgresAdapterCtr, - drivers::sqlx::{SqlxDriver, sqlx_client::PgPool}, + PostgresAdapterConfig, PostgresAdapterCtr, SqlxAdapter, drivers::sqlx::sqlx_client::PgPool, }; use tokio::net::TcpListener; use tracing::{Level, info}; @@ -20,15 +19,12 @@ async fn main() -> Result<(), Box> { let variant = std::env::args().next().unwrap(); let variant = variant.split("/").last().unwrap(); - let config = PostgresAdapterConfig { - prefix: format!("socket.io-{variant}").into(), - ..Default::default() - }; + let config = PostgresAdapterConfig::new().with_prefix(format!("socket.io-{variant}")); let pg_pool = PgPool::connect("postgres://socketio:socketio@localhost:5432/socketio").await?; - let adapter = PostgresAdapterCtr::new_with_driver(SqlxDriver::new(pg_pool), config); + let adapter = PostgresAdapterCtr::new_with_sqlx_config(pg_pool, config); let (svc, io) = SocketIo::builder() - .with_adapter::>(adapter) + .with_adapter::>(adapter) .build_svc(); io.ns("/", adapter_e2e::handler).await?; diff --git a/e2e/adapter/src/bins/sqlx_msgpack.rs b/e2e/adapter/src/bins/sqlx_msgpack.rs index d7f420f3..c5a9482d 100644 --- a/e2e/adapter/src/bins/sqlx_msgpack.rs +++ b/e2e/adapter/src/bins/sqlx_msgpack.rs @@ -3,8 +3,7 @@ use hyper_util::rt::TokioIo; use socketioxide::{ParserConfig, SocketIo}; use socketioxide_postgres::{ - CustomPostgresAdapter, PostgresAdapterConfig, PostgresAdapterCtr, - 
drivers::sqlx::{SqlxDriver, sqlx_client::PgPool}, + PostgresAdapterConfig, PostgresAdapterCtr, SqlxAdapter, drivers::sqlx::sqlx_client::PgPool, }; use tokio::net::TcpListener; use tracing::{Level, info}; @@ -19,17 +18,13 @@ async fn main() -> Result<(), Box> { tracing::subscriber::set_global_default(subscriber)?; let variant = std::env::args().next().unwrap(); let variant = variant.split("/").last().unwrap(); - - let config = PostgresAdapterConfig { - prefix: format!("socket.io-{variant}").into(), - ..Default::default() - }; + let config = PostgresAdapterConfig::new().with_prefix(format!("socket.io-{variant}")); let pg_pool = PgPool::connect("postgres://socketio:socketio@localhost:5432/socketio").await?; - let adapter = PostgresAdapterCtr::new_with_driver(SqlxDriver::new(pg_pool), config); + let adapter = PostgresAdapterCtr::new_with_sqlx_config(pg_pool, config); let (svc, io) = SocketIo::builder() .with_parser(ParserConfig::msgpack()) - .with_adapter::>(adapter) + .with_adapter::>(adapter) .build_svc(); io.ns("/", adapter_e2e::handler).await?; From 0a36aded7378ffee50b9ace2a819904392811a17 Mon Sep 17 00:00:00 2001 From: totodore Date: Sun, 29 Mar 2026 18:32:38 +0200 Subject: [PATCH 11/31] feat(adapter/postgre): wip --- e2e/adapter/main.rs | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/e2e/adapter/main.rs b/e2e/adapter/main.rs index f18af801..35125745 100644 --- a/e2e/adapter/main.rs +++ b/e2e/adapter/main.rs @@ -26,9 +26,6 @@ fn main() -> Result<(), Box> { let bin_filter = args().nth(1).unwrap_or("".to_string()); println!("binary target filter: {}", bin_filter); - let test_filter = args().nth(2); - println!("test filter: {}", test_filter.as_deref().unwrap_or("*")); - if fs::exists(LOG_DIR)? 
{ fs::remove_dir_all(LOG_DIR)?; } @@ -36,14 +33,14 @@ fn main() -> Result<(), Box> { // run everything for target in BINS.iter().filter(|name| name.contains(&bin_filter)) { - run(target, test_filter.as_deref()); + run(target); } println!("All tests passed!"); Ok(()) } -fn run(target: &'static str, test_filter: Option<&str>) { +fn run(target: &'static str) { let parser = if target.ends_with("msgpack") { "msgpack" } else { From d6fbb14deff22439da0aabb5cb50fd0203ee29c7 Mon Sep 17 00:00:00 2001 From: totodore Date: Sun, 29 Mar 2026 18:43:11 +0200 Subject: [PATCH 12/31] feat(adapter/postgre): add tests --- crates/socketioxide-postgres/Cargo.toml | 1 + .../socketioxide-postgres/tests/broadcast.rs | 149 +++++++++++ crates/socketioxide-postgres/tests/fixture.rs | 247 ++++++++++++++++++ crates/socketioxide-postgres/tests/local.rs | 32 +++ crates/socketioxide-postgres/tests/rooms.rs | 119 +++++++++ crates/socketioxide-postgres/tests/sockets.rs | 170 ++++++++++++ 6 files changed, 718 insertions(+) create mode 100644 crates/socketioxide-postgres/tests/broadcast.rs create mode 100644 crates/socketioxide-postgres/tests/fixture.rs create mode 100644 crates/socketioxide-postgres/tests/local.rs create mode 100644 crates/socketioxide-postgres/tests/rooms.rs create mode 100644 crates/socketioxide-postgres/tests/sockets.rs diff --git a/crates/socketioxide-postgres/Cargo.toml b/crates/socketioxide-postgres/Cargo.toml index 3aa2604e..21724c6f 100644 --- a/crates/socketioxide-postgres/Cargo.toml +++ b/crates/socketioxide-postgres/Cargo.toml @@ -52,6 +52,7 @@ socketioxide = { path = "../socketioxide", features = [ ] } tracing-subscriber.workspace = true bytes.workspace = true +futures-util.workspace = true # docs.rs-specific configuration [package.metadata.docs.rs] diff --git a/crates/socketioxide-postgres/tests/broadcast.rs b/crates/socketioxide-postgres/tests/broadcast.rs new file mode 100644 index 00000000..c7ba71a9 --- /dev/null +++ b/crates/socketioxide-postgres/tests/broadcast.rs @@ 
-0,0 +1,149 @@ +use std::time::Duration; + +use socketioxide::{adapter::Adapter, extract::SocketRef}; +mod fixture; + +#[tokio::test] +pub async fn broadcast() { + async fn handler(socket: SocketRef
) { + // delay to ensure all socket/servers are connected + tokio::time::sleep(Duration::from_millis(1)).await; + socket.broadcast().emit("test", &2).await.unwrap(); + } + + let [io1, io2] = fixture::spawn_servers(); + + io1.ns("/", handler).await.unwrap(); + io2.ns("/", handler).await.unwrap(); + + let ((_tx1, mut rx1), (_tx2, mut rx2)) = + tokio::join!(io1.new_dummy_sock("/", ()), io2.new_dummy_sock("/", ())); + + timeout_rcv!(&mut rx1); // Connect "/" packet + timeout_rcv!(&mut rx2); // Connect "/" packet + assert_eq!(timeout_rcv!(&mut rx1), r#"42["test",2]"#); + assert_eq!(timeout_rcv!(&mut rx2), r#"42["test",2]"#); + + timeout_rcv_err!(&mut rx1); + timeout_rcv_err!(&mut rx2); +} + +#[tokio::test] +pub async fn broadcast_rooms() { + let [io1, io2, io3] = fixture::spawn_servers(); + let handler = |room: &'static str, to: &'static str| { + move |socket: SocketRef<_>| async move { + // delay to ensure all socket/servers are connected + socket.join(room); + tokio::time::sleep(Duration::from_millis(5)).await; + socket.to(to).emit("test", room).await.unwrap(); + } + }; + + io1.ns("/", handler("room1", "room2")).await.unwrap(); + io2.ns("/", handler("room2", "room3")).await.unwrap(); + io3.ns("/", handler("room3", "room1")).await.unwrap(); + + let ((_tx1, mut rx1), (_tx2, mut rx2), (_tx3, mut rx3)) = tokio::join!( + io1.new_dummy_sock("/", ()), + io2.new_dummy_sock("/", ()), + io3.new_dummy_sock("/", ()) + ); + + timeout_rcv!(&mut rx1); // Connect "/" packet + timeout_rcv!(&mut rx2); // Connect "/" packet + timeout_rcv!(&mut rx3); // Connect "/" packet + + // socket 1 is receiving a packet from io3 + assert_eq!(timeout_rcv!(&mut rx1), r#"42["test","room3"]"#); + // socket 2 is receiving a packet from io2 + assert_eq!(timeout_rcv!(&mut rx2), r#"42["test","room1"]"#); + // socket 3 is receiving a packet from io1 + assert_eq!(timeout_rcv!(&mut rx3), r#"42["test","room2"]"#); + + timeout_rcv_err!(&mut rx1); + timeout_rcv_err!(&mut rx2); + timeout_rcv_err!(&mut rx3); +} + 
+#[tokio::test] +pub async fn broadcast_with_ack() { + use futures_util::stream::StreamExt; + + async fn handler(socket: SocketRef) { + // delay to ensure all socket/servers are connected + tokio::time::sleep(Duration::from_millis(1)).await; + socket + .broadcast() + .emit_with_ack::<_, String>("test", "bar") + .await + .unwrap() + .for_each(|(_, res)| { + socket.emit("ack_res", &res).unwrap(); + async move {} + }) + .await; + } + + let [io1, io2] = fixture::spawn_servers(); + + io1.ns("/", handler).await.unwrap(); + io2.ns("/", async || ()).await.unwrap(); + + let ((_tx1, mut rx1), (tx2, mut rx2)) = + tokio::join!(io1.new_dummy_sock("/", ()), io2.new_dummy_sock("/", ())); + + timeout_rcv!(&mut rx1); // Connect "/" packet + timeout_rcv!(&mut rx2); // Connect "/" packet + + assert_eq!(timeout_rcv!(&mut rx2), r#"421["test","bar"]"#); + let packet_res = r#"431["foo"]"#.to_string().try_into().unwrap(); + tx2.try_send(packet_res).unwrap(); + assert_eq!(timeout_rcv!(&mut rx1), r#"42["ack_res",{"Ok":"foo"}]"#); + + timeout_rcv_err!(&mut rx1); + timeout_rcv_err!(&mut rx2); +} + +#[tokio::test] +pub async fn broadcast_with_ack_timeout() { + use futures_util::StreamExt; + const TIMEOUT: Duration = Duration::from_millis(50); + + async fn handler(socket: SocketRef) { + socket + .broadcast() + .emit_with_ack::<_, String>("test", "bar") + .await + .unwrap() + .for_each(|(_, res)| { + socket.emit("ack_res", &res).unwrap(); + async move {} + }) + .await; + socket.emit("ack_res", "timeout").unwrap(); + } + + let [io1, io2] = fixture::spawn_buggy_servers(TIMEOUT); + + io1.ns("/", handler).await.unwrap(); + io2.ns("/", async || ()).await.unwrap(); + + let now = std::time::Instant::now(); + let ((_tx1, mut rx1), (_tx2, mut rx2)) = + tokio::join!(io1.new_dummy_sock("/", ()), io2.new_dummy_sock("/", ())); + + timeout_rcv!(&mut rx1); // Connect "/" packet + timeout_rcv!(&mut rx2); // Connect "/" packet + + assert_eq!(timeout_rcv!(&mut rx2), r#"421["test","bar"]"#); // emit with ack 
message + // We do not answer + assert_eq!( + timeout_rcv!(&mut rx1, TIMEOUT.as_millis() as u64 + 100), + r#"42["ack_res","timeout"]"# + ); + assert!(now.elapsed() >= TIMEOUT); + + timeout_rcv_err!(&mut rx1); + timeout_rcv_err!(&mut rx2); +} diff --git a/crates/socketioxide-postgres/tests/fixture.rs b/crates/socketioxide-postgres/tests/fixture.rs new file mode 100644 index 00000000..070d9b97 --- /dev/null +++ b/crates/socketioxide-postgres/tests/fixture.rs @@ -0,0 +1,247 @@ +#![allow(dead_code)] + +use futures_core::Stream; +use socketioxide_core::Uid; +use socketioxide_postgres::{ + CustomPostgresAdapter, PostgresAdapterConfig, PostgresAdapterCtr, + drivers::{Driver, Notification}, +}; +use std::{ + convert::Infallible, + pin::Pin, + str::FromStr, + sync::{Arc, RwLock}, + task, + time::Duration, +}; +use tokio::sync::mpsc; + +use socketioxide::{SocketIo, SocketIoConfig, adapter::Emitter}; + +/// Spawns a number of servers with a stub driver for testing. +/// Every server will be connected to every other server. +pub fn spawn_servers() -> [SocketIo>; N] +{ + let sync_buff = Arc::new(RwLock::new(Vec::with_capacity(N))); + spawn_inner(sync_buff, PostgresAdapterConfig::default()) +} + +pub fn spawn_buggy_servers( + timeout: Duration, +) -> [SocketIo>; N] { + let sync_buff = Arc::new(RwLock::new(Vec::with_capacity(N))); + let config = PostgresAdapterConfig::default().with_request_timeout(timeout); + let res = spawn_inner(sync_buff.clone(), config); + + // Reinject a false heartbeat request to simulate a bad number of servers. + // This will trigger timeouts when expecting responses from all servers. + // The heartbeat type is 20 (RequestTypeOut::Heartbeat) in the wire format. 
+ let uid: Uid = Uid::from_str("PHHq01ObWy7Godqx").unwrap(); + let heartbeat_json = serde_json::json!({ + "node_id": uid.to_string(), + "id": "ZG9K1r7xSLBiJYWD", + "type": 20, + "opts": null, + }); + let payload = serde_json::to_string(&heartbeat_json).unwrap(); + + for (_, tx) in sync_buff.read().unwrap().iter() { + // Send the heartbeat to the global channel of the "/" namespace + tx.try_send(StubNotification { + channel: "socket.io#/".to_string(), + payload: payload.clone(), + }) + .unwrap(); + } + + res +} + +fn spawn_inner( + sync_buff: Arc>, + config: PostgresAdapterConfig, +) -> [SocketIo>; N] { + [0; N].map(|_| { + let server_id = Uid::new(); + let (driver, mut rx, tx) = StubDriver::new(server_id); + + // pipe messages to all other servers + sync_buff.write().unwrap().push((server_id, tx)); + let sync_buff = sync_buff.clone(); + tokio::spawn(async move { + while let Some(notif) = rx.recv().await { + tracing::debug!("received notify on channel {:?}", notif.channel); + for (sid, tx) in sync_buff.read().unwrap().iter() { + if *sid != server_id { + tracing::debug!("forwarding notify to server {:?}", sid); + tx.try_send(notif.clone()).unwrap(); + } + } + } + }); + + let adapter = PostgresAdapterCtr::new_with_driver(driver, config.clone()); + let mut config = SocketIoConfig::default(); + config.server_id = server_id; + let (_svc, io) = SocketIo::builder() + .with_config(config) + .with_adapter::>(adapter) + .build_svc(); + io + }) +} + +type NotifyHandlers = Vec<(Uid, mpsc::Sender)>; + +#[derive(Debug, Clone)] +pub struct StubNotification { + channel: String, + payload: String, +} + +impl Notification for StubNotification { + fn channel(&self) -> &str { + &self.channel + } + + fn payload(&self) -> &str { + &self.payload + } +} + +#[derive(Debug, Clone)] +pub struct StubDriver { + server_id: Uid, + /// Sender to emit outgoing NOTIFY messages (to be broadcast to other servers). + tx: mpsc::Sender, + /// Handlers for incoming notifications per listened channel. 
+ handlers: Arc)>>>, +} + +impl StubDriver { + pub fn new( + server_id: Uid, + ) -> (Self, mpsc::Receiver, mpsc::Sender) { + let (tx, rx) = mpsc::channel(255); // outgoing notifies + let (tx1, rx1) = mpsc::channel(255); // incoming notifies + let handlers: Arc)>>> = + Arc::new(RwLock::new(Vec::new())); + + tokio::spawn(pipe_handlers(rx1, handlers.clone())); + + let driver = Self { + server_id, + tx, + handlers, + }; + (driver, rx, tx1) + } +} + +/// Pipe incoming notifications to the matching channel handlers. +async fn pipe_handlers( + mut rx: mpsc::Receiver, + handlers: Arc)>>>, +) { + while let Some(notif) = rx.recv().await { + let handlers = handlers.read().unwrap(); + for (chan, handler) in &*handlers { + if *chan == notif.channel { + handler.try_send(notif.clone()).unwrap(); + } + } + } +} + +pin_project_lite::pin_project! { + pub struct NotificationStream { + #[pin] + rx: mpsc::Receiver, + } +} + +impl Stream for NotificationStream { + type Item = StubNotification; + + fn poll_next( + self: Pin<&mut Self>, + cx: &mut task::Context<'_>, + ) -> task::Poll> { + self.project().rx.poll_recv(cx) + } +} + +impl Driver for StubDriver { + type Error = Infallible; + type Notification = StubNotification; + type NotificationStream = NotificationStream; + + async fn init(&self, _table: &str) -> Result<(), Self::Error> { + Ok(()) + } + + async fn listen(&self, channels: &[&str]) -> Result { + let (tx, rx) = mpsc::channel(255); + let mut handlers = self.handlers.write().unwrap(); + for chan in channels { + handlers.push((chan.to_string(), tx.clone())); + } + Ok(NotificationStream { rx }) + } + + async fn notify(&self, channel: &str, message: &str) -> Result<(), Self::Error> { + // Also deliver to local handlers (self-delivery, like real PG NOTIFY). 
+ { + let handlers = self.handlers.read().unwrap(); + for (chan, handler) in &*handlers { + if *chan == channel { + handler + .try_send(StubNotification { + channel: channel.to_string(), + payload: message.to_string(), + }) + .unwrap(); + } + } + } + // Send to the broadcast pipe for delivery to other servers. + self.tx + .try_send(StubNotification { + channel: channel.to_string(), + payload: message.to_string(), + }) + .unwrap(); + Ok(()) + } +} + +#[macro_export] +macro_rules! timeout_rcv_err { + ($srx:expr) => { + tokio::time::timeout(std::time::Duration::from_millis(10), $srx.recv()) + .await + .unwrap_err(); + }; +} + +#[macro_export] +macro_rules! timeout_rcv { + ($srx:expr) => { + TryInto::::try_into( + tokio::time::timeout(std::time::Duration::from_millis(10), $srx.recv()) + .await + .unwrap() + .unwrap(), + ) + .unwrap() + }; + ($srx:expr, $t:expr) => { + TryInto::::try_into( + tokio::time::timeout(std::time::Duration::from_millis($t), $srx.recv()) + .await + .unwrap() + .unwrap(), + ) + .unwrap() + }; +} diff --git a/crates/socketioxide-postgres/tests/local.rs b/crates/socketioxide-postgres/tests/local.rs new file mode 100644 index 00000000..49972933 --- /dev/null +++ b/crates/socketioxide-postgres/tests/local.rs @@ -0,0 +1,32 @@ +//! Check that each adapter function with a broadcast options that is [`Local`] returns an immediate future +mod fixture; + +macro_rules! 
assert_now { + ($fut:expr) => { + #[allow(unused_must_use)] + futures_util::FutureExt::now_or_never($fut) + .expect("Returned future should be sync") + .unwrap() + }; +} + +#[tokio::test] +async fn test_local_fns() { + let [io1, io2] = fixture::spawn_servers(); + + io1.ns("/", async || ()).await.unwrap(); + io2.ns("/", async || ()).await.unwrap(); + + let (_, mut rx1) = io1.new_dummy_sock("/", ()).await; + let (_, mut rx2) = io2.new_dummy_sock("/", ()).await; + + timeout_rcv!(&mut rx1); // connect packet + timeout_rcv!(&mut rx2); // connect packet + + assert_now!(io1.local().emit("test", "test")); + assert_now!(io1.local().emit_with_ack::<_, ()>("test", "test")); + assert_now!(io1.local().join("test")); + assert_now!(io1.local().leave("test")); + assert_now!(io1.local().disconnect()); + assert_now!(io1.local().fetch_sockets()); +} diff --git a/crates/socketioxide-postgres/tests/rooms.rs b/crates/socketioxide-postgres/tests/rooms.rs new file mode 100644 index 00000000..343d400f --- /dev/null +++ b/crates/socketioxide-postgres/tests/rooms.rs @@ -0,0 +1,119 @@ +use std::time::Duration; + +use socketioxide::extract::SocketRef; + +mod fixture; + +#[tokio::test] +pub async fn all_rooms() { + let [io1, io2, io3] = fixture::spawn_servers(); + let handler = + |rooms: &'static [&'static str]| async move |socket: SocketRef<_>| socket.join(rooms); + + io1.ns("/", handler(&["room1", "room2"])).await.unwrap(); + io2.ns("/", handler(&["room2", "room3"])).await.unwrap(); + io3.ns("/", handler(&["room3", "room1"])).await.unwrap(); + + let ((_tx1, mut rx1), (_tx2, mut rx2), (_tx3, mut rx3)) = tokio::join!( + io1.new_dummy_sock("/", ()), + io2.new_dummy_sock("/", ()), + io3.new_dummy_sock("/", ()) + ); + + timeout_rcv!(&mut rx1); // Connect "/" packet + timeout_rcv!(&mut rx2); // Connect "/" packet + timeout_rcv!(&mut rx3); // Connect "/" packet + + const ROOMS: [&str; 3] = ["room1", "room2", "room3"]; + for io in [io1, io2, io3] { + let mut rooms = io.rooms().await.unwrap(); + 
rooms.sort(); + assert_eq!(rooms, ROOMS); + } + + timeout_rcv_err!(&mut rx1); + timeout_rcv_err!(&mut rx2); + timeout_rcv_err!(&mut rx3); +} + +#[tokio::test] +pub async fn all_rooms_timeout() { + const TIMEOUT: Duration = Duration::from_millis(50); + let [io1, io2, io3] = fixture::spawn_buggy_servers(TIMEOUT); + let handler = + |rooms: &'static [&'static str]| async move |socket: SocketRef<_>| socket.join(rooms); + + io1.ns("/", handler(&["room1", "room2"])).await.unwrap(); + io2.ns("/", handler(&["room2", "room3"])).await.unwrap(); + io3.ns("/", handler(&["room3", "room1"])).await.unwrap(); + + let ((_tx1, mut rx1), (_tx2, mut rx2), (_tx3, mut rx3)) = tokio::join!( + io1.new_dummy_sock("/", ()), + io2.new_dummy_sock("/", ()), + io3.new_dummy_sock("/", ()) + ); + + timeout_rcv!(&mut rx1); // Connect "/" packet + timeout_rcv!(&mut rx2); // Connect "/" packet + timeout_rcv!(&mut rx3); // Connect "/" packet + + const ROOMS: [&str; 3] = ["room1", "room2", "room3"]; + for io in [io1, io3, io2] { + let now = std::time::Instant::now(); + let mut rooms = io.rooms().await.unwrap(); + dbg!(&rooms); + assert!(dbg!(now.elapsed()) >= TIMEOUT); // timeout time + rooms.sort(); + assert_eq!(rooms, ROOMS); + } + + timeout_rcv_err!(&mut rx1); + timeout_rcv_err!(&mut rx2); + timeout_rcv_err!(&mut rx3); +} +#[tokio::test] +pub async fn add_sockets() { + let handler = |room: &'static str| async move |socket: SocketRef<_>| socket.join(room); + let [io1, io2] = fixture::spawn_servers(); + + io1.ns("/", handler("room1")).await.unwrap(); + io2.ns("/", handler("room3")).await.unwrap(); + + let ((_tx1, mut rx1), (_tx2, mut rx2)) = + tokio::join!(io1.new_dummy_sock("/", ()), io2.new_dummy_sock("/", ())); + + timeout_rcv!(&mut rx1); // Connect "/" packet + timeout_rcv!(&mut rx2); // Connect "/" packet + io1.broadcast().join("room2").await.unwrap(); + let mut rooms = io1.rooms().await.unwrap(); + rooms.sort(); + assert_eq!(rooms, ["room1", "room2", "room3"]); + + timeout_rcv_err!(&mut rx1); + 
timeout_rcv_err!(&mut rx2); +} + +#[tokio::test] +pub async fn del_sockets() { + let handler = + |rooms: &'static [&'static str]| async move |socket: SocketRef<_>| socket.join(rooms); + let [io1, io2] = fixture::spawn_servers(); + + io1.ns("/", handler(&["room1", "room2"])).await.unwrap(); + io2.ns("/", handler(&["room3", "room2"])).await.unwrap(); + + let ((_tx1, mut rx1), (_tx2, mut rx2)) = + tokio::join!(io1.new_dummy_sock("/", ()), io2.new_dummy_sock("/", ())); + + timeout_rcv!(&mut rx1); // Connect "/" packet + timeout_rcv!(&mut rx2); // Connect "/" packet + + io1.broadcast().leave("room2").await.unwrap(); + + let mut rooms = io1.rooms().await.unwrap(); + rooms.sort(); + assert_eq!(rooms, ["room1", "room3"]); + + timeout_rcv_err!(&mut rx1); + timeout_rcv_err!(&mut rx2); +} diff --git a/crates/socketioxide-postgres/tests/sockets.rs b/crates/socketioxide-postgres/tests/sockets.rs new file mode 100644 index 00000000..947151ff --- /dev/null +++ b/crates/socketioxide-postgres/tests/sockets.rs @@ -0,0 +1,170 @@ +use std::{str::FromStr, time::Duration}; + +use socketioxide::{ + SocketIo, adapter::Adapter, extract::SocketRef, operators::BroadcastOperators, + socket::RemoteSocket, +}; +use socketioxide_core::{Sid, Str, adapter::RemoteSocketData}; +use tokio::time::Instant; + +mod fixture; +fn extract_sid(data: &str) -> Sid { + let data = data + .split("\"sid\":\"") + .nth(1) + .and_then(|s| s.split('"').next()) + .unwrap(); + Sid::from_str(data).unwrap() +} +async fn fetch_sockets_data(op: BroadcastOperators) -> Vec { + let mut sockets = op + .fetch_sockets() + .await + .unwrap() + .into_iter() + .map(RemoteSocket::into_data) + .collect::>(); + sockets.sort_by(|a, b| a.id.cmp(&b.id)); + sockets +} +fn create_expected_sockets( + ids: [Sid; N], + ios: [&SocketIo; N], +) -> [RemoteSocketData; N] { + let mut i = 0; + let mut sockets = ios.map(|io| { + let id = ids[i]; + i += 1; + RemoteSocketData { + id, + server_id: io.config().server_id, + ns: Str::from("/"), + } + }); + 
sockets.sort_by(|a, b| a.id.cmp(&b.id)); + sockets +} + +#[tokio::test] +pub async fn fetch_sockets() { + let [io1, io2, io3] = fixture::spawn_servers::<3>(); + + io1.ns("/", async || ()).await.unwrap(); + io2.ns("/", async || ()).await.unwrap(); + io3.ns("/", async || ()).await.unwrap(); + + let (_, mut rx1) = io1.new_dummy_sock("/", ()).await; + let (_, mut rx2) = io2.new_dummy_sock("/", ()).await; + let (_, mut rx3) = io3.new_dummy_sock("/", ()).await; + + let id1 = extract_sid(&timeout_rcv!(&mut rx1)); + let id2 = extract_sid(&timeout_rcv!(&mut rx2)); + let id3 = extract_sid(&timeout_rcv!(&mut rx3)); + + let mut expected_sockets = create_expected_sockets([id1, id2, id3], [&io1, &io2, &io3]); + expected_sockets.sort_by(|a, b| a.id.cmp(&b.id)); + + let sockets = fetch_sockets_data(io1.broadcast()).await; + assert_eq!(sockets, expected_sockets); + + let sockets = fetch_sockets_data(io2.broadcast()).await; + assert_eq!(sockets, expected_sockets); + + let sockets = fetch_sockets_data(io3.broadcast()).await; + assert_eq!(sockets, expected_sockets); +} + +#[tokio::test] +pub async fn fetch_sockets_with_rooms() { + let [io1, io2, io3] = fixture::spawn_servers::<3>(); + let handler = + |rooms: &'static [&'static str]| async move |socket: SocketRef<_>| socket.join(rooms); + + io1.ns("/", handler(&["room1", "room2"])).await.unwrap(); + io2.ns("/", handler(&["room2", "room3"])).await.unwrap(); + io3.ns("/", handler(&["room3", "room1"])).await.unwrap(); + + let (_, mut rx1) = io1.new_dummy_sock("/", ()).await; + let (_, mut rx2) = io2.new_dummy_sock("/", ()).await; + let (_, mut rx3) = io3.new_dummy_sock("/", ()).await; + + let id1 = extract_sid(&timeout_rcv!(&mut rx1)); + let id2 = extract_sid(&timeout_rcv!(&mut rx2)); + let id3 = extract_sid(&timeout_rcv!(&mut rx3)); + + let sockets = fetch_sockets_data(io1.to("room1")).await; + assert_eq!(sockets, create_expected_sockets([id1, id3], [&io1, &io3])); + + let sockets = fetch_sockets_data(io1.to("room2")).await; + 
assert_eq!(sockets, create_expected_sockets([id1, id2], [&io1, &io2])); + + let sockets = fetch_sockets_data(io1.to("room3")).await; + assert_eq!(sockets, create_expected_sockets([id2, id3], [&io2, &io3])); +} + +#[tokio::test] +pub async fn fetch_sockets_timeout() { + const TIMEOUT: Duration = Duration::from_millis(50); + let [io1, io2] = fixture::spawn_buggy_servers(TIMEOUT); + + io1.ns("/", async || ()).await.unwrap(); + io2.ns("/", async || ()).await.unwrap(); + + let (_, mut rx1) = io1.new_dummy_sock("/", ()).await; + let (_, mut rx2) = io2.new_dummy_sock("/", ()).await; + + timeout_rcv!(&mut rx1); // connect packet + timeout_rcv!(&mut rx2); // connect packet + + let now = Instant::now(); + io1.fetch_sockets().await.unwrap(); + assert!(now.elapsed() >= TIMEOUT); +} + +#[tokio::test] +pub async fn remote_socket_emit() { + let [io1, io2] = fixture::spawn_servers(); + + io1.ns("/", async || ()).await.unwrap(); + io2.ns("/", async || ()).await.unwrap(); + + let (_, mut rx1) = io1.new_dummy_sock("/", ()).await; + let (_, mut rx2) = io2.new_dummy_sock("/", ()).await; + + timeout_rcv!(&mut rx1); // connect packet + timeout_rcv!(&mut rx2); // connect packet + + let sockets = io1.fetch_sockets().await.unwrap(); + for socket in sockets { + socket.emit("test", "hello").await.unwrap(); + } + + assert_eq!(timeout_rcv!(&mut rx1), r#"42["test","hello"]"#); + assert_eq!(timeout_rcv!(&mut rx2), r#"42["test","hello"]"#); +} + +#[tokio::test] +pub async fn remote_socket_emit_with_ack() { + let [io1, io2] = fixture::spawn_servers(); + + io1.ns("/", async || ()).await.unwrap(); + io2.ns("/", async || ()).await.unwrap(); + + let (_, mut rx1) = io1.new_dummy_sock("/", ()).await; + let (_, mut rx2) = io2.new_dummy_sock("/", ()).await; + + timeout_rcv!(&mut rx1); // connect packet + timeout_rcv!(&mut rx2); // connect packet + + let sockets = io1.fetch_sockets().await.unwrap(); + for socket in sockets { + #[allow(unused_must_use)] + socket + .emit_with_ack::<_, ()>("test", "hello") + 
.await + .unwrap(); + } + + assert_eq!(timeout_rcv!(&mut rx1), r#"421["test","hello"]"#); + assert_eq!(timeout_rcv!(&mut rx2), r#"421["test","hello"]"#); +} From 68bea2cbb8dea1840c1215b225860bd81add43a9 Mon Sep 17 00:00:00 2001 From: totodore Date: Sun, 12 Apr 2026 15:22:10 +0200 Subject: [PATCH 13/31] fix: fmt --- crates/socketioxide-postgres/tests/fixture.rs | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/crates/socketioxide-postgres/tests/fixture.rs b/crates/socketioxide-postgres/tests/fixture.rs index 070d9b97..69fb8415 100644 --- a/crates/socketioxide-postgres/tests/fixture.rs +++ b/crates/socketioxide-postgres/tests/fixture.rs @@ -121,7 +121,11 @@ pub struct StubDriver { impl StubDriver { pub fn new( server_id: Uid, - ) -> (Self, mpsc::Receiver, mpsc::Sender) { + ) -> ( + Self, + mpsc::Receiver, + mpsc::Sender, + ) { let (tx, rx) = mpsc::channel(255); // outgoing notifies let (tx1, rx1) = mpsc::channel(255); // incoming notifies let handlers: Arc)>>> = From 128102c42d5b7b0d19fc9c743e3aedc1789b9823 Mon Sep 17 00:00:00 2001 From: totodore Date: Sun, 12 Apr 2026 15:26:03 +0200 Subject: [PATCH 14/31] fix: tests --- crates/socketioxide-postgres/tests/broadcast.rs | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/crates/socketioxide-postgres/tests/broadcast.rs b/crates/socketioxide-postgres/tests/broadcast.rs index c7ba71a9..f1486cc1 100644 --- a/crates/socketioxide-postgres/tests/broadcast.rs +++ b/crates/socketioxide-postgres/tests/broadcast.rs @@ -108,11 +108,14 @@ pub async fn broadcast_with_ack() { #[tokio::test] pub async fn broadcast_with_ack_timeout() { use futures_util::StreamExt; - const TIMEOUT: Duration = Duration::from_millis(50); + const REQ_TIMEOUT: Duration = Duration::from_millis(50); + const ACK_TIMEOUT: Duration = Duration::from_millis(50); + const TIMEOUT: Duration = Duration::from_millis(100); async fn handler(socket: SocketRef) { socket .broadcast() + .timeout(ACK_TIMEOUT) .emit_with_ack::<_, 
String>("test", "bar") .await .unwrap() @@ -124,7 +127,7 @@ pub async fn broadcast_with_ack_timeout() { socket.emit("ack_res", "timeout").unwrap(); } - let [io1, io2] = fixture::spawn_buggy_servers(TIMEOUT); + let [io1, io2] = fixture::spawn_buggy_servers(REQ_TIMEOUT); io1.ns("/", handler).await.unwrap(); io2.ns("/", async || ()).await.unwrap(); From e4001f2ca8991cc2224eb69508528a97b1bba604 Mon Sep 17 00:00:00 2001 From: totodore Date: Sun, 12 Apr 2026 16:21:41 +0200 Subject: [PATCH 15/31] feat: tokio postgres driver --- crates/socketioxide-postgres/Cargo.toml | 4 +- .../socketioxide-postgres/src/drivers/mod.rs | 6 +- .../src/drivers/postgres.rs | 63 ---------- .../src/drivers/tokio_postgres.rs | 119 ++++++++++++++++++ crates/socketioxide-postgres/src/lib.rs | 36 ++++++ crates/socketioxide-postgres/tests/fixture.rs | 12 +- e2e/adapter/Cargo.toml | 10 +- e2e/adapter/src/bins/fred_cluster.rs | 10 +- e2e/adapter/src/bins/fred_cluster_msgpack.rs | 8 +- e2e/adapter/src/bins/tokio_postgres.rs | 65 ++++++++++ .../src/bins/tokio_postgres_msgpack.rs | 65 ++++++++++ 11 files changed, 317 insertions(+), 81 deletions(-) delete mode 100644 crates/socketioxide-postgres/src/drivers/postgres.rs create mode 100644 crates/socketioxide-postgres/src/drivers/tokio_postgres.rs create mode 100644 e2e/adapter/src/bins/tokio_postgres.rs create mode 100644 e2e/adapter/src/bins/tokio_postgres_msgpack.rs diff --git a/crates/socketioxide-postgres/Cargo.toml b/crates/socketioxide-postgres/Cargo.toml index 42d7ff35..4961a06b 100644 --- a/crates/socketioxide-postgres/Cargo.toml +++ b/crates/socketioxide-postgres/Cargo.toml @@ -14,8 +14,8 @@ readme = "README.md" [features] sqlx = ["dep:sqlx"] -postgres = ["dep:tokio-postgres"] -default = ["postgres"] +tokio-postgres = ["dep:tokio-postgres"] +default = ["sqlx"] [dependencies] socketioxide-core = { version = "0.18", path = "../socketioxide-core", features = [ diff --git a/crates/socketioxide-postgres/src/drivers/mod.rs 
b/crates/socketioxide-postgres/src/drivers/mod.rs index 10e1f5cd..77c75cd8 100644 --- a/crates/socketioxide-postgres/src/drivers/mod.rs +++ b/crates/socketioxide-postgres/src/drivers/mod.rs @@ -7,8 +7,10 @@ use futures_core::Stream; #[cfg(feature = "sqlx")] pub mod sqlx; -// #[cfg(feature = "postgres")] -// pub mod postgres; +/// A driver implementation for the [`tokio-postgres`](https://docs.rs/tokio-postgres) +/// PostgreSQL backend. +#[cfg(feature = "tokio-postgres")] +pub mod tokio_postgres; /// The driver trait can be used to support different LISTEN/NOTIFY backends. /// It must share handlers/connection between its clones. diff --git a/crates/socketioxide-postgres/src/drivers/postgres.rs b/crates/socketioxide-postgres/src/drivers/postgres.rs deleted file mode 100644 index fea516a5..00000000 --- a/crates/socketioxide-postgres/src/drivers/postgres.rs +++ /dev/null @@ -1,63 +0,0 @@ -use std::{pin::Pin, sync::Arc}; - -use futures_core::{Stream, stream::BoxStream}; -use tokio::io::{AsyncRead, AsyncWrite}; -use tokio_postgres::{AsyncMessage, Client, Config, Connection, Socket}; - -use super::Driver; - -#[derive(Debug, Clone)] -pub struct PostgresDriver { - client: Arc, - config: Config, -} - -impl PostgresDriver { - pub fn new(config: Config) -> Self - where - T: AsyncRead + AsyncWrite + Unpin, - { - PostgresDriver { config } - } -} - -impl Driver for PostgresDriver { - type Error = tokio_postgres::Error; - type Notification = tokio_postgres::Notification; - type NotificationStream = BoxStream<'static, Self::Notification>; - - async fn init( - &self, - table: &str, - channels: &[&str], - ) -> Result { - let st = &format!( - r#"CREATE TABLE IF NOT EXISTS "{table}" ( - id BIGSERIAL UNIQUE, - created_at TIMESTAMPTZ DEFAULT NOW(), - payload BYTEA - )"# - ); - - self.client.execute(st, &[]).await?; - - Ok(()) - } - - async fn notify(&self, channel: &str, message: &str) -> Result<(), Self::Error> { - self.client - .execute("SELECT pg_notify($1, $2)", &[&channel, 
&message]) - .await?; - Ok(()) - } -} - -impl super::Notification for tokio_postgres::Notification { - fn channel(&self) -> &str { - tokio_postgres::Notification::channel(self) - } - - fn payload(&self) -> &str { - tokio_postgres::Notification::payload(self) - } -} diff --git a/crates/socketioxide-postgres/src/drivers/tokio_postgres.rs b/crates/socketioxide-postgres/src/drivers/tokio_postgres.rs new file mode 100644 index 00000000..7b1f6eb6 --- /dev/null +++ b/crates/socketioxide-postgres/src/drivers/tokio_postgres.rs @@ -0,0 +1,119 @@ +use std::{ + collections::HashMap, + sync::{Arc, RwLock}, +}; + +use futures_util::{StreamExt, sink, stream}; +use tokio::sync::mpsc; +use tokio_postgres::{AsyncMessage, Client, Config, Socket, tls::MakeTlsConnect}; + +use crate::stream::ChanStream; + +use super::Driver; + +pub use tokio_postgres as tokio_postgres_client; + +type Demux = HashMap>; + +const LISTENER_QUEUE_SIZE: usize = 255; + +/// A [`Driver`] implementation using the [`tokio_postgres`] PostgreSQL client. +/// +/// It drives the client connection to extract notifications from the PostgreSQL server. +#[derive(Debug, Clone)] +pub struct TokioPostgresDriver { + client: Arc, + demux: Arc>, +} + +async fn demux_notif( + demux: Arc>, + msg: AsyncMessage, +) -> Result>, tokio_postgres::Error> { + let AsyncMessage::Notification(notif) = msg else { + return Ok(demux); + }; + + if let Some(tx) = demux.read().unwrap().get(notif.channel()) { + if let Err(e) = tx.try_send(notif) { + tracing::warn!("failed to send notification: {}", e); + } + } else { + tracing::debug!("no listener for channel {}", notif.channel()); + } + + Ok(demux) +} + +impl TokioPostgresDriver { + /// Connects to the PostgreSQL server using the provided configuration and TLS settings + /// with [`Config::connect`]. + /// + /// The resulting connection is driven inside the driver to be + /// able to receive notifications and dispatch them to the appropriate listeners. 
+ pub async fn new(config: Config, tls: T) -> Result + where + T: MakeTlsConnect + Send + Sync + 'static, + >::Stream: Send, + { + let (client, mut conn) = config.connect(tls).await?; + + let demux = Arc::new(RwLock::new(HashMap::new())); + let stream = stream::poll_fn(move |cx| conn.poll_message(cx)); + tokio::spawn(stream.forward(sink::unfold(demux.clone(), demux_notif))); + + let driver = TokioPostgresDriver { + client: Arc::new(client), + demux, + }; + + Ok(driver) + } +} + +impl Driver for TokioPostgresDriver { + type Error = tokio_postgres::Error; + type Notification = tokio_postgres::Notification; + type NotificationStream = ChanStream; + + async fn init(&self, table: &str) -> Result<(), Self::Error> { + let st = &format!( + r#"CREATE TABLE IF NOT EXISTS "{table}" ( + id BIGSERIAL UNIQUE, + created_at TIMESTAMPTZ DEFAULT NOW(), + payload BYTEA + )"# + ); + + self.client.execute(st, &[]).await?; + + Ok(()) + } + + async fn listen(&self, channels: &[&str]) -> Result { + let (tx, rx) = mpsc::channel(LISTENER_QUEUE_SIZE); + let mut demux = self.demux.write().unwrap(); + for channel in channels { + demux.insert(channel.to_string(), tx.clone()); + } + + Ok(ChanStream::new(rx)) + } + + async fn notify(&self, channel: &str, message: &str) -> Result<(), Self::Error> { + self.client + .execute("SELECT pg_notify($1, $2)", &[&channel, &message]) + .await?; + Ok(()) + } +} + +impl super::Notification for tokio_postgres::Notification { + fn channel(&self) -> &str { + tokio_postgres::Notification::channel(self) + } + + fn payload(&self) -> &str { + tokio_postgres::Notification::payload(self) + } +} diff --git a/crates/socketioxide-postgres/src/lib.rs b/crates/socketioxide-postgres/src/lib.rs index 7e52a618..4d457495 100644 --- a/crates/socketioxide-postgres/src/lib.rs +++ b/crates/socketioxide-postgres/src/lib.rs @@ -250,6 +250,37 @@ impl PostgresAdapterCtr { } } +#[cfg(feature = "tokio-postgres")] +impl PostgresAdapterCtr { + /// Create a new adapter constructor with the 
[`tokio-postgres`](drivers::tokio_postgres) driver + /// and a default config. + pub async fn new_with_tokio_postgres( + pg_config: tokio_postgres::Config, + tls: T, + ) -> Result + where + T: tokio_postgres::tls::MakeTlsConnect + Send + Sync + 'static, + >::Stream: Send, + { + Self::new_with_tokio_postgres_config(pg_config, tls, PostgresAdapterConfig::default()).await + } + + /// Create a new adapter constructor with the [`sqlx`](drivers::sqlx) driver + /// and a custom config. + pub async fn new_with_tokio_postgres_config( + pg_config: tokio_postgres::Config, + tls: T, + config: PostgresAdapterConfig, + ) -> Result + where + T: tokio_postgres::tls::MakeTlsConnect + Send + Sync + 'static, + >::Stream: Send, + { + let driver = drivers::tokio_postgres::TokioPostgresDriver::new(pg_config, tls).await?; + Ok(Self { driver, config }) + } +} + impl PostgresAdapterCtr { /// Create a new adapter constructor with a custom postgres driver and a config. /// @@ -264,6 +295,11 @@ impl PostgresAdapterCtr { #[cfg(feature = "sqlx")] pub type SqlxAdapter = CustomPostgresAdapter; +/// The postgres adapter with the [`tokio_postgres`](drivers::tokio_postgres) driver. +#[cfg(feature = "tokio-postgres")] +pub type TokioPostgresAdapter = + CustomPostgresAdapter; + type ResponseHandlers = HashMap>>; /// The postgres adapter implementation. diff --git a/crates/socketioxide-postgres/tests/fixture.rs b/crates/socketioxide-postgres/tests/fixture.rs index 69fb8415..b63410a1 100644 --- a/crates/socketioxide-postgres/tests/fixture.rs +++ b/crates/socketioxide-postgres/tests/fixture.rs @@ -109,13 +109,15 @@ impl Notification for StubNotification { } } +type Handlers = Vec<(String, mpsc::Sender)>; + #[derive(Debug, Clone)] pub struct StubDriver { server_id: Uid, /// Sender to emit outgoing NOTIFY messages (to be broadcast to other servers). tx: mpsc::Sender, /// Handlers for incoming notifications per listened channel. 
- handlers: Arc)>>>, + handlers: Arc>, } impl StubDriver { @@ -128,8 +130,7 @@ impl StubDriver { ) { let (tx, rx) = mpsc::channel(255); // outgoing notifies let (tx1, rx1) = mpsc::channel(255); // incoming notifies - let handlers: Arc)>>> = - Arc::new(RwLock::new(Vec::new())); + let handlers: Arc> = Arc::new(RwLock::new(Vec::new())); tokio::spawn(pipe_handlers(rx1, handlers.clone())); @@ -143,10 +144,7 @@ impl StubDriver { } /// Pipe incoming notifications to the matching channel handlers. -async fn pipe_handlers( - mut rx: mpsc::Receiver, - handlers: Arc)>>>, -) { +async fn pipe_handlers(mut rx: mpsc::Receiver, handlers: Arc>) { while let Some(notif) = rx.recv().await { let handlers = handlers.read().unwrap(); for (chan, handler) in &*handlers { diff --git a/e2e/adapter/Cargo.toml b/e2e/adapter/Cargo.toml index cc7ea709..f455d08c 100644 --- a/e2e/adapter/Cargo.toml +++ b/e2e/adapter/Cargo.toml @@ -24,7 +24,7 @@ socketioxide-redis = { path = "../../crates/socketioxide-redis", features = [ socketioxide-mongodb = { path = "../../crates/socketioxide-mongodb" } socketioxide-postgres = { path = "../../crates/socketioxide-postgres", features = [ "sqlx", - "postgres", + "tokio-postgres", ] } hyper-util = { workspace = true, features = ["tokio"] } @@ -93,3 +93,11 @@ path = "src/bins/sqlx.rs" [[bin]] name = "sqlx-e2e-msgpack" path = "src/bins/sqlx_msgpack.rs" + +[[bin]] +name = "tokio-postgres-e2e" +path = "src/bins/tokio_postgres.rs" + +[[bin]] +name = "tokio-postgres-e2e-msgpack" +path = "src/bins/tokio_postgres_msgpack.rs" diff --git a/e2e/adapter/src/bins/fred_cluster.rs b/e2e/adapter/src/bins/fred_cluster.rs index c8c38b0e..d347e884 100644 --- a/e2e/adapter/src/bins/fred_cluster.rs +++ b/e2e/adapter/src/bins/fred_cluster.rs @@ -23,9 +23,13 @@ async fn main() -> Result<(), Box> { ("127.0.0.1", 7004), ("127.0.0.1", 7005), ]); - let mut config = fred::prelude::Config::default(); - config.server = server_config; - config.version = RespVersion::RESP3; + + let config = 
fred::prelude::Config { + server: server_config, + version: RespVersion::RESP3, + ..Default::default() + }; + let client = fred::prelude::Builder::from_config(config).build_subscriber_client()?; let variant = std::env::args().next().unwrap(); let variant = variant.split("/").last().unwrap(); diff --git a/e2e/adapter/src/bins/fred_cluster_msgpack.rs b/e2e/adapter/src/bins/fred_cluster_msgpack.rs index 4f58f152..1e213e3d 100644 --- a/e2e/adapter/src/bins/fred_cluster_msgpack.rs +++ b/e2e/adapter/src/bins/fred_cluster_msgpack.rs @@ -23,9 +23,11 @@ async fn main() -> Result<(), Box> { ("127.0.0.1", 7004), ("127.0.0.1", 7005), ]); - let mut config = fred::prelude::Config::default(); - config.server = server_config; - config.version = RespVersion::RESP3; + let config = fred::prelude::Config { + server: server_config, + version: RespVersion::RESP3, + ..Default::default() + }; let client = fred::prelude::Builder::from_config(config).build_subscriber_client()?; let variant = std::env::args().next().unwrap(); let variant = variant.split("/").last().unwrap(); diff --git a/e2e/adapter/src/bins/tokio_postgres.rs b/e2e/adapter/src/bins/tokio_postgres.rs new file mode 100644 index 00000000..f45e96cc --- /dev/null +++ b/e2e/adapter/src/bins/tokio_postgres.rs @@ -0,0 +1,65 @@ +use std::str::FromStr; + +use hyper::server::conn::http1; +use hyper_util::rt::TokioIo; +use socketioxide::SocketIo; + +use socketioxide_postgres::{ + PostgresAdapterConfig, PostgresAdapterCtr, TokioPostgresAdapter, + drivers::tokio_postgres::tokio_postgres_client::{Config, NoTls}, +}; +use tokio::net::TcpListener; +use tracing::{Level, info}; +use tracing_subscriber::FmtSubscriber; + +#[tokio::main] +async fn main() -> Result<(), Box> { + let subscriber = FmtSubscriber::builder() + .with_line_number(true) + .with_max_level(Level::TRACE) + .finish(); + tracing::subscriber::set_global_default(subscriber)?; + let variant = std::env::args().next().unwrap(); + let variant = variant.split("/").last().unwrap(); + + 
let config = PostgresAdapterConfig::new().with_prefix(format!("socket.io-{variant}")); + + let pg_config = Config::from_str("postgres://socketio:socketio@localhost:5432/socketio")?; + + let adapter = + PostgresAdapterCtr::new_with_tokio_postgres_config(pg_config, NoTls, config).await?; + let (svc, io) = SocketIo::builder() + .with_adapter::>(adapter) + .build_svc(); + + io.ns("/", adapter_e2e::handler).await?; + + info!("Starting server with v5 protocol"); + let port: u16 = std::env::var("PORT") + .expect("a PORT env var should be set") + .parse()?; + + let listener = TcpListener::bind(("127.0.0.1", port)).await?; + + // We start a loop to continuously accept incoming connections + loop { + let (stream, _) = listener.accept().await?; + + // Use an adapter to access something implementing `tokio::io` traits as if they implement + // `hyper::rt` IO traits. + let io = TokioIo::new(stream); + let svc = svc.clone(); + + // Spawn a tokio task to serve multiple connections concurrently + tokio::task::spawn(async move { + // Finally, we bind the incoming connection to our `hello` service + if let Err(err) = http1::Builder::new() + .serve_connection(io, svc) + .with_upgrades() + .await + { + println!("Error serving connection: {:?}", err); + } + }); + } +} diff --git a/e2e/adapter/src/bins/tokio_postgres_msgpack.rs b/e2e/adapter/src/bins/tokio_postgres_msgpack.rs new file mode 100644 index 00000000..d72dfc59 --- /dev/null +++ b/e2e/adapter/src/bins/tokio_postgres_msgpack.rs @@ -0,0 +1,65 @@ +use std::str::FromStr; + +use hyper::server::conn::http1; +use hyper_util::rt::TokioIo; +use socketioxide::{ParserConfig, SocketIo}; + +use socketioxide_postgres::{ + PostgresAdapterConfig, PostgresAdapterCtr, TokioPostgresAdapter, + drivers::tokio_postgres::tokio_postgres_client::{Config, NoTls}, +}; +use tokio::net::TcpListener; +use tracing::{Level, info}; +use tracing_subscriber::FmtSubscriber; + +#[tokio::main] +async fn main() -> Result<(), Box> { + let subscriber = 
FmtSubscriber::builder() + .with_line_number(true) + .with_max_level(Level::TRACE) + .finish(); + tracing::subscriber::set_global_default(subscriber)?; + let variant = std::env::args().next().unwrap(); + let variant = variant.split("/").last().unwrap(); + let config = PostgresAdapterConfig::new().with_prefix(format!("socket.io-{variant}")); + + let pg_config = Config::from_str("postgres://socketio:socketio@localhost:5432/socketio")?; + + let adapter = + PostgresAdapterCtr::new_with_tokio_postgres_config(pg_config, NoTls, config).await?; + let (svc, io) = SocketIo::builder() + .with_parser(ParserConfig::msgpack()) + .with_adapter::>(adapter) + .build_svc(); + + io.ns("/", adapter_e2e::handler).await?; + + info!("Starting server with v5 protocol"); + let port: u16 = std::env::var("PORT") + .expect("a PORT env var should be set") + .parse()?; + + let listener = TcpListener::bind(("127.0.0.1", port)).await?; + + // We start a loop to continuously accept incoming connections + loop { + let (stream, _) = listener.accept().await?; + + // Use an adapter to access something implementing `tokio::io` traits as if they implement + // `hyper::rt` IO traits. 
+ let io = TokioIo::new(stream); + let svc = svc.clone(); + + // Spawn a tokio task to serve multiple connections concurrently + tokio::task::spawn(async move { + // Finally, we bind the incoming connection to our `hello` service + if let Err(err) = http1::Builder::new() + .serve_connection(io, svc) + .with_upgrades() + .await + { + println!("Error serving connection: {:?}", err); + } + }); + } +} From 117f1036973f9896a63c39433e864645cfa79563 Mon Sep 17 00:00:00 2001 From: totodore Date: Sun, 12 Apr 2026 16:26:46 +0200 Subject: [PATCH 16/31] fix: ci check featuree flag --- .github/workflows/github-ci.yml | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/.github/workflows/github-ci.yml b/.github/workflows/github-ci.yml index 7d0e29ac..ef17012e 100644 --- a/.github/workflows/github-ci.yml +++ b/.github/workflows/github-ci.yml @@ -115,7 +115,8 @@ jobs: -p socketioxide \ -p engineioxide \ -p socketioxide-redis \ - -p socketioxide-mongodb + -p socketioxide-mongodb \ + -p socketioxide-postgres examples: runs-on: ubuntu-latest From ef327b84cc5bb0c0a249ea1db4eafb46e1d727d0 Mon Sep 17 00:00:00 2001 From: totodore Date: Sun, 12 Apr 2026 16:33:43 +0200 Subject: [PATCH 17/31] fix: ci check featuree flag --- e2e/adapter/main.rs | 2 ++ 1 file changed, 2 insertions(+) diff --git a/e2e/adapter/main.rs b/e2e/adapter/main.rs index 35125745..0c1590f4 100644 --- a/e2e/adapter/main.rs +++ b/e2e/adapter/main.rs @@ -18,6 +18,8 @@ const BINS: &[&str] = &[ "mongodb-capped-e2e-msgpack", "sqlx-e2e", "sqlx-e2e-msgpack", + "tokio-postgres-e2e", + "tokio-postgres-e2e-msgpack", ]; const EXEC_SUFFIX: &str = if cfg!(windows) { ".exe" } else { "" }; const LOG_DIR: &str = "e2e/adapter/logs"; From 2f0c33c5404880f85ec366ece8382959e642a56f Mon Sep 17 00:00:00 2001 From: totodore Date: Sat, 18 Apr 2026 13:10:16 +0200 Subject: [PATCH 18/31] fix(adapter/postgres): minor fixes --- .../src/drivers/tokio_postgres.rs | 36 ++++++++++--------- 1 file changed, 19 insertions(+), 17 deletions(-) 
diff --git a/crates/socketioxide-postgres/src/drivers/tokio_postgres.rs b/crates/socketioxide-postgres/src/drivers/tokio_postgres.rs index 7b1f6eb6..608f61d4 100644 --- a/crates/socketioxide-postgres/src/drivers/tokio_postgres.rs +++ b/crates/socketioxide-postgres/src/drivers/tokio_postgres.rs @@ -1,7 +1,4 @@ -use std::{ - collections::HashMap, - sync::{Arc, RwLock}, -}; +use std::sync::{Arc, RwLock}; use futures_util::{StreamExt, sink, stream}; use tokio::sync::mpsc; @@ -13,7 +10,7 @@ use super::Driver; pub use tokio_postgres as tokio_postgres_client; -type Demux = HashMap>; +type Listeners = Vec<(String, mpsc::Sender)>; const LISTENER_QUEUE_SIZE: usize = 255; @@ -23,18 +20,23 @@ const LISTENER_QUEUE_SIZE: usize = 255; #[derive(Debug, Clone)] pub struct TokioPostgresDriver { client: Arc, - demux: Arc>, + listeners: Arc>, } -async fn demux_notif( - demux: Arc>, +async fn dispatch_notifs( + listeners: Arc>, msg: AsyncMessage, -) -> Result>, tokio_postgres::Error> { +) -> Result>, tokio_postgres::Error> { let AsyncMessage::Notification(notif) = msg else { - return Ok(demux); + return Ok(listeners); }; - if let Some(tx) = demux.read().unwrap().get(notif.channel()) { + if let Some((_, tx)) = listeners + .read() + .unwrap() + .iter() + .find(|(chan, _)| chan == notif.channel()) + { if let Err(e) = tx.try_send(notif) { tracing::warn!("failed to send notification: {}", e); } @@ -42,7 +44,7 @@ async fn demux_notif( tracing::debug!("no listener for channel {}", notif.channel()); } - Ok(demux) + Ok(listeners) } impl TokioPostgresDriver { @@ -58,13 +60,13 @@ impl TokioPostgresDriver { { let (client, mut conn) = config.connect(tls).await?; - let demux = Arc::new(RwLock::new(HashMap::new())); + let listeners = Arc::new(RwLock::new(Vec::new())); let stream = stream::poll_fn(move |cx| conn.poll_message(cx)); - tokio::spawn(stream.forward(sink::unfold(demux.clone(), demux_notif))); + tokio::spawn(stream.forward(sink::unfold(listeners.clone(), dispatch_notifs))); let driver = 
TokioPostgresDriver { client: Arc::new(client), - demux, + listeners, }; Ok(driver) @@ -92,9 +94,9 @@ impl Driver for TokioPostgresDriver { async fn listen(&self, channels: &[&str]) -> Result { let (tx, rx) = mpsc::channel(LISTENER_QUEUE_SIZE); - let mut demux = self.demux.write().unwrap(); + let mut listeners = self.listeners.write().unwrap(); for channel in channels { - demux.insert(channel.to_string(), tx.clone()); + listeners.push((channel.to_string(), tx.clone())); } Ok(ChanStream::new(rx)) From a78ac923957c256103ec73c0868ba4d6feb41725 Mon Sep 17 00:00:00 2001 From: totodore Date: Sat, 18 Apr 2026 14:15:02 +0200 Subject: [PATCH 19/31] feat(adapter/postgre): fix tokio postgre driver --- Cargo.lock | 1 + crates/socketioxide-postgres/Cargo.toml | 2 ++ .../socketioxide-postgres/src/drivers/mod.rs | 3 ++ .../socketioxide-postgres/src/drivers/sqlx.rs | 5 ++++ .../src/drivers/tokio_postgres.rs | 29 ++++++++++++++++--- crates/socketioxide-postgres/src/lib.rs | 23 +++++++++++---- crates/socketioxide-postgres/tests/fixture.rs | 4 +++ 7 files changed, 57 insertions(+), 10 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 8518d71f..94fd2834 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2709,6 +2709,7 @@ dependencies = [ "tokio-postgres", "tracing", "tracing-subscriber", + "xxhash-rust", ] [[package]] diff --git a/crates/socketioxide-postgres/Cargo.toml b/crates/socketioxide-postgres/Cargo.toml index 4961a06b..ced9a368 100644 --- a/crates/socketioxide-postgres/Cargo.toml +++ b/crates/socketioxide-postgres/Cargo.toml @@ -40,6 +40,8 @@ sqlx = { version = "0.8", default-features = false, optional = true, features = "runtime-tokio", ] } +xxhash-rust = { version = "0.8", features = ["xxh3"] } + [dev-dependencies] tokio = { workspace = true, features = [ "macros", diff --git a/crates/socketioxide-postgres/src/drivers/mod.rs b/crates/socketioxide-postgres/src/drivers/mod.rs index 77c75cd8..eea95575 100644 --- a/crates/socketioxide-postgres/src/drivers/mod.rs +++ 
b/crates/socketioxide-postgres/src/drivers/mod.rs @@ -38,6 +38,9 @@ pub trait Driver: Clone + Send + Sync + 'static { channel: &str, message: &str, ) -> impl Future> + Send; + + /// UNLISTEN from every channel. + fn close(&self) -> impl Future> + Send; } /// A trait representing a PostgreSQL NOTIFY notification. diff --git a/crates/socketioxide-postgres/src/drivers/sqlx.rs b/crates/socketioxide-postgres/src/drivers/sqlx.rs index 3241dfac..f8fac832 100644 --- a/crates/socketioxide-postgres/src/drivers/sqlx.rs +++ b/crates/socketioxide-postgres/src/drivers/sqlx.rs @@ -66,6 +66,11 @@ impl Driver for SqlxDriver { .await?; Ok(()) } + + async fn close(&self) -> Result<(), Self::Error> { + // PgListener will automatically unlisten channel when being dropped + Ok(()) + } } impl super::Notification for PgNotification { diff --git a/crates/socketioxide-postgres/src/drivers/tokio_postgres.rs b/crates/socketioxide-postgres/src/drivers/tokio_postgres.rs index 608f61d4..4f142ede 100644 --- a/crates/socketioxide-postgres/src/drivers/tokio_postgres.rs +++ b/crates/socketioxide-postgres/src/drivers/tokio_postgres.rs @@ -31,6 +31,11 @@ async fn dispatch_notifs( return Ok(listeners); }; + tracing::trace!( + chan = notif.channel(), + "dispatching postgre notif to listeners" + ); + if let Some((_, tx)) = listeners .read() .unwrap() @@ -61,7 +66,7 @@ impl TokioPostgresDriver { let (client, mut conn) = config.connect(tls).await?; let listeners = Arc::new(RwLock::new(Vec::new())); - let stream = stream::poll_fn(move |cx| conn.poll_message(cx)); + let stream = stream::poll_fn(move |cx| dbg!(conn.poll_message(cx))); tokio::spawn(stream.forward(sink::unfold(listeners.clone(), dispatch_notifs))); let driver = TokioPostgresDriver { @@ -94,11 +99,21 @@ impl Driver for TokioPostgresDriver { async fn listen(&self, channels: &[&str]) -> Result { let (tx, rx) = mpsc::channel(LISTENER_QUEUE_SIZE); - let mut listeners = self.listeners.write().unwrap(); - for channel in channels { - 
listeners.push((channel.to_string(), tx.clone())); + + { + let mut listeners = self.listeners.write().unwrap(); + for channel in channels { + listeners.push((channel.to_string(), tx.clone())); + } } + let query: String = channels + .iter() + .map(|c| format!(r#"LISTEN "{c}"; "#)) + .collect(); + + self.client.batch_execute(&query).await?; + Ok(ChanStream::new(rx)) } @@ -106,6 +121,12 @@ impl Driver for TokioPostgresDriver { self.client .execute("SELECT pg_notify($1, $2)", &[&channel, &message]) .await?; + + Ok(()) + } + + async fn close(&self) -> Result<(), Self::Error> { + self.client.execute("UNLISTEN *", &[]).await?; Ok(()) } } diff --git a/crates/socketioxide-postgres/src/lib.rs b/crates/socketioxide-postgres/src/lib.rs index 4d457495..0d3e0c13 100644 --- a/crates/socketioxide-postgres/src/lib.rs +++ b/crates/socketioxide-postgres/src/lib.rs @@ -368,7 +368,7 @@ impl CoreAdapter for CustomPostgresAdapter } async fn close(&self) -> Result<(), Self::Error> { - Ok(()) + self.driver.close().await.map_err(Error::Driver) } /// Get the number of servers by iterating over the node liveness heartbeats. @@ -886,22 +886,33 @@ impl CustomPostgresAdapter { .await } + // == All channels are hashed to avoid exceeding the 63 bytes limit on postgres channel == + // We cannot constrain the length of the channel name because it is generated dynamically.
+ fn get_global_chan(&self) -> String { - format!("{}#{}", self.config.prefix, self.local.path()) + let chan = format!("{}#{}", self.config.prefix, self.local.path()); + hash_chan(&chan) } fn get_node_chan(&self, uid: Uid) -> String { - format!("{}#{}", self.get_global_chan(), uid) + let chan = format!("{}{}{}", self.config.prefix, self.local.path(), uid); + hash_chan(&chan) } fn get_response_chan(&self, uid: Uid) -> String { - format!( - "{}-response#{}#{}", + let chan = format!( + "response-{}{}{}", &self.config.prefix, self.local.path(), uid - ) + ); + hash_chan(&chan) } } +fn hash_chan(chan: &str) -> String { + let hash = xxhash_rust::xxh3::xxh3_64(chan.as_bytes()); + format!("ch_{:x}", hash) +} + /// The result of the init future. #[must_use = "futures do nothing unless you `.await` or poll them"] pub struct InitRes(futures_core::future::BoxFuture<'static, Result<(), D::Error>>); diff --git a/crates/socketioxide-postgres/tests/fixture.rs b/crates/socketioxide-postgres/tests/fixture.rs index b63410a1..4974c177 100644 --- a/crates/socketioxide-postgres/tests/fixture.rs +++ b/crates/socketioxide-postgres/tests/fixture.rs @@ -215,6 +215,10 @@ impl Driver for StubDriver { .unwrap(); Ok(()) } + + async fn close(&self) -> Result<(), Self::Error> { + Ok(()) + } } #[macro_export] From bdc8bf0555b3fd2453927db33a8c06e9a3cbe28d Mon Sep 17 00:00:00 2001 From: totodore Date: Sat, 18 Apr 2026 14:40:34 +0200 Subject: [PATCH 20/31] docs(example): add postgres adapter examples --- examples/Cargo.lock | 585 +++++++++++++++++- examples/postgres-whiteboard/Cargo.toml | 27 + examples/postgres-whiteboard/Readme.md | 6 + examples/postgres-whiteboard/dist/index.html | 23 + examples/postgres-whiteboard/dist/main.js | 124 ++++ examples/postgres-whiteboard/dist/style.css | 44 ++ examples/postgres-whiteboard/src/sqlx.rs | 60 ++ .../postgres-whiteboard/src/tokio_postgres.rs | 66 ++ 8 files changed, 907 insertions(+), 28 deletions(-) create mode 100644 
examples/postgres-whiteboard/Cargo.toml create mode 100644 examples/postgres-whiteboard/Readme.md create mode 100644 examples/postgres-whiteboard/dist/index.html create mode 100644 examples/postgres-whiteboard/dist/main.js create mode 100644 examples/postgres-whiteboard/dist/style.css create mode 100644 examples/postgres-whiteboard/src/sqlx.rs create mode 100644 examples/postgres-whiteboard/src/tokio_postgres.rs diff --git a/examples/Cargo.lock b/examples/Cargo.lock index a2ee5dd9..6c05d952 100644 --- a/examples/Cargo.lock +++ b/examples/Cargo.lock @@ -14,7 +14,7 @@ version = "0.5.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d122413f284cf2d62fb1b7db97e02edb8cda96d769b16e443a4f6195e35662b0" dependencies = [ - "crypto-common", + "crypto-common 0.1.7", "generic-array", ] @@ -80,6 +80,12 @@ dependencies = [ "alloc-no-stdlib", ] +[[package]] +name = "allocator-api2" +version = "0.2.21" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "683d7910e743518b0e34f1186f92494becacb047c7b6bf616c96772180fef923" + [[package]] name = "android_system_properties" version = "0.1.5" @@ -194,6 +200,15 @@ dependencies = [ "syn", ] +[[package]] +name = "atoi" +version = "2.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f28d99ec8bfea296261ca1af174f24225171fea9664ba9003cbebee704810528" +dependencies = [ + "num-traits", +] + [[package]] name = "atomic-waker" version = "1.1.2" @@ -390,6 +405,15 @@ dependencies = [ "generic-array", ] +[[package]] +name = "block-buffer" +version = "0.12.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cdd35008169921d80bc60d3d0ab416eecb028c4cd653352907921d95084790be" +dependencies = [ + "hybrid-array", +] + [[package]] name = "brotli" version = "8.0.2" @@ -574,7 +598,7 @@ version = "0.4.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "773f3b9af64447d2ce9850330c473515014aa235e6a783b02db81ff39e4a3dad" 
dependencies = [ - "crypto-common", + "crypto-common 0.1.7", "inout", ] @@ -587,6 +611,12 @@ dependencies = [ "cc", ] +[[package]] +name = "cmov" +version = "0.5.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3f88a43d011fc4a6876cb7344703e297c71dda42494fee094d5f7c76bf13f746" + [[package]] name = "combine" version = "4.6.7" @@ -610,6 +640,12 @@ dependencies = [ "crossbeam-utils", ] +[[package]] +name = "const-oid" +version = "0.10.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a6ef517f0926dd24a1582492c791b6a4818a4d94e789a334894aa15b0d12f55c" + [[package]] name = "const-random" version = "0.1.18" @@ -656,10 +692,10 @@ checksum = "4ddef33a339a91ea89fb53151bd0a4689cfce27055c291dfa69945475d22c747" dependencies = [ "aes-gcm", "base64", - "hmac", + "hmac 0.12.1", "percent-encoding", "rand 0.8.5", - "sha2", + "sha2 0.10.9", "subtle", "time", "version_check", @@ -705,6 +741,21 @@ dependencies = [ "libc", ] +[[package]] +name = "crc" +version = "3.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5eb8a2a1cd12ab0d987a5d5e825195d372001a4094a0376319d5a0ad71c1ba0d" +dependencies = [ + "crc-catalog", +] + +[[package]] +name = "crc-catalog" +version = "2.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "19d374276b40fb8bbdee95aef7c7fa6b5316ec764510eb64b8dd0e2ed0d7e7f5" + [[package]] name = "crc16" version = "0.4.0" @@ -720,6 +771,15 @@ dependencies = [ "cfg-if", ] +[[package]] +name = "crossbeam-queue" +version = "0.3.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0f58bbc28f91df819d0aa2a2c00cd19754769c2fad90579b3592b1c9ba7a3115" +dependencies = [ + "crossbeam-utils", +] + [[package]] name = "crossbeam-utils" version = "0.8.21" @@ -743,6 +803,15 @@ dependencies = [ "typenum", ] +[[package]] +name = "crypto-common" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"77727bb15fa921304124b128af125e7e3b968275d1b108b379190264f4423710" +dependencies = [ + "hybrid-array", +] + [[package]] name = "ctr" version = "0.9.2" @@ -752,6 +821,15 @@ dependencies = [ "cipher", ] +[[package]] +name = "ctutils" +version = "0.4.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7d5515a3834141de9eafb9717ad39eea8247b5674e6066c404e8c4b365d2a29e" +dependencies = [ + "cmov", +] + [[package]] name = "darling" version = "0.20.11" @@ -943,11 +1021,23 @@ version = "0.10.7" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9ed9a281f7bc9b7576e61468ba615a66a5c8cfdff42420a70aa82701a3b1e292" dependencies = [ - "block-buffer", - "crypto-common", + "block-buffer 0.10.4", + "crypto-common 0.1.7", "subtle", ] +[[package]] +name = "digest" +version = "0.11.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4850db49bf08e663084f7fb5c87d202ef91a3907271aff24a94eb97ff039153c" +dependencies = [ + "block-buffer 0.12.0", + "const-oid", + "crypto-common 0.2.1", + "ctutils", +] + [[package]] name = "displaydoc" version = "0.2.5" @@ -959,6 +1049,12 @@ dependencies = [ "syn", ] +[[package]] +name = "dotenvy" +version = "0.15.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1aaf95b3e5c8f23aa320147307562d361db0ae0d51242340f558153b4eb2439b" + [[package]] name = "dunce" version = "1.0.5" @@ -970,6 +1066,9 @@ name = "either" version = "1.15.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "48c757948c5ede0e46177b7add2e67155f70e33c07fea8284df6576da70b3719" +dependencies = [ + "serde", +] [[package]] name = "encoding_rs" @@ -1063,6 +1162,17 @@ dependencies = [ "xxhash-rust", ] +[[package]] +name = "etcetera" +version = "0.8.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "136d1b5283a1ab77bd9257427ffd09d8667ced0570b6f938942bc7568ed5b943" +dependencies = [ + "cfg-if", + "home", + "windows-sys 0.48.0", +] 
+ [[package]] name = "event-listener" version = "5.4.1" @@ -1084,6 +1194,12 @@ dependencies = [ "pin-project-lite", ] +[[package]] +name = "fallible-iterator" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4443176a9f2c162692bd3d352d745ef9413eec5782a80d8fd6f8a1ac692a07f7" + [[package]] name = "fastrand" version = "2.3.0" @@ -1254,6 +1370,17 @@ dependencies = [ "futures-util", ] +[[package]] +name = "futures-intrusive" +version = "0.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1d930c203dd0b6ff06e0201a4a2fe9149b43c684fd4420555b26d21b1a02956f" +dependencies = [ + "futures-core", + "lock_api", + "parking_lot", +] + [[package]] name = "futures-io" version = "0.3.32" @@ -1332,7 +1459,7 @@ dependencies = [ "cfg-if", "js-sys", "libc", - "wasi", + "wasi 0.11.1+wasi-snapshot-preview1", "wasm-bindgen", ] @@ -1451,6 +1578,8 @@ version = "0.15.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9229cfe53dfd69f0609a49f65461bd93001ea1ef889cd5529dd176593f5338a1" dependencies = [ + "allocator-api2", + "equivalent", "foldhash", ] @@ -1460,6 +1589,15 @@ version = "0.16.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "841d1cc9bed7f9236f321df977030373f4a4163ae1a7dbfe1a51a2c1a51d9100" +[[package]] +name = "hashlink" +version = "0.10.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7382cf6263419f2d8df38c55d7da83da5c18aef87fc7a7fc1fb1e344edfe14c1" +dependencies = [ + "hashbrown 0.15.5", +] + [[package]] name = "headers" version = "0.4.1" @@ -1496,13 +1634,40 @@ version = "0.4.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7f24254aa9a54b5c858eaee2f5bccdb46aaf0e486a595ed5fd8f86ba55232a70" +[[package]] +name = "hkdf" +version = "0.12.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7b5f8eb2ad728638ea2c7d47a21db23b7b58a72ed6a38256b8a1849f15fbbdf7" 
+dependencies = [ + "hmac 0.12.1", +] + [[package]] name = "hmac" version = "0.12.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6c49c37c09c17a53d937dfbb742eb3a961d65a994e6bcdcf37e7399d0cc8ab5e" dependencies = [ - "digest", + "digest 0.10.7", +] + +[[package]] +name = "hmac" +version = "0.13.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6303bc9732ae41b04cb554b844a762b4115a61bfaa81e3e83050991eeb56863f" +dependencies = [ + "digest 0.11.2", +] + +[[package]] +name = "home" +version = "0.5.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cc627f471c528ff0c4a49e1d5e60450c8f6461dd6d10ba9dcd3a61d3dff7728d" +dependencies = [ + "windows-sys 0.61.2", ] [[package]] @@ -1556,6 +1721,15 @@ version = "1.0.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "df3b46402a9d5adb4c86a0cf463f42e19994e3ee891101b1841f30a545cb49a9" +[[package]] +name = "hybrid-array" +version = "0.4.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3944cf8cf766b40e2a1a333ee5e9b563f854d5fa49d6a8ca2764e97c6eddb214" +dependencies = [ + "typenum", +] + [[package]] name = "hyper" version = "1.9.0" @@ -1898,6 +2072,18 @@ version = "0.2.184" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "48f5d2a454e16a5ea0f4ced81bd44e4cfc7bd3a507b61887c99fd3538b28e4af" +[[package]] +name = "libredox" +version = "0.1.16" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e02f3bb43d335493c96bf3fd3a321600bf6bd07ed34bc64118e9293bdffea46c" +dependencies = [ + "bitflags", + "libc", + "plain", + "redox_syscall 0.7.4", +] + [[package]] name = "linux-raw-sys" version = "0.12.1" @@ -2033,7 +2219,17 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d89e7ee0cfbedfc4da3340218492196241d89eefb6dab27de5df917a6d2e78cf" dependencies = [ "cfg-if", - "digest", + "digest 0.10.7", +] + +[[package]] +name 
= "md-5" +version = "0.11.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "69b6441f590336821bb897fb28fc622898ccceb1d6cea3fde5ea86b090c4de98" +dependencies = [ + "cfg-if", + "digest 0.11.2", ] [[package]] @@ -2091,7 +2287,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "50b7e5b27aa02a74bac8c3f23f448f8d87ff11f92d3aac1a6ed369ee08cc56c1" dependencies = [ "libc", - "wasi", + "wasi 0.11.1+wasi-snapshot-preview1", "windows-sys 0.61.2", ] @@ -2128,9 +2324,9 @@ dependencies = [ "futures-io", "futures-util", "hex", - "hmac", + "hmac 0.12.1", "macro_magic", - "md-5", + "md-5 0.10.6", "mongocrypt", "mongodb-internal-macros", "pbkdf2", @@ -2143,7 +2339,7 @@ dependencies = [ "serde_bytes", "serde_with", "sha1", - "sha2", + "sha2 0.10.9", "socket2 0.6.3", "stringprep", "strsim", @@ -2293,6 +2489,24 @@ dependencies = [ "autocfg", ] +[[package]] +name = "objc2-core-foundation" +version = "0.3.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2a180dd8642fa45cdb7dd721cd4c11b1cadd4929ce112ebd8b9f5803cc79d536" +dependencies = [ + "bitflags", +] + +[[package]] +name = "objc2-system-configuration" +version = "0.3.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7216bd11cbda54ccabcab84d523dc93b858ec75ecfb3a7d89513fa22464da396" +dependencies = [ + "objc2-core-foundation", +] + [[package]] name = "oid-registry" version = "0.8.1" @@ -2344,7 +2558,7 @@ checksum = "2621685985a2ebf1c516881c026032ac7deafcda1a2c9b7850dc81e3dfcb64c1" dependencies = [ "cfg-if", "libc", - "redox_syscall", + "redox_syscall 0.5.18", "smallvec", "windows-link", ] @@ -2364,7 +2578,7 @@ version = "0.12.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f8ed6a7761f76e3b9f92dfb0a60a6a6477c61024b775147ff0973a02653abaf2" dependencies = [ - "digest", + "digest 0.10.7", ] [[package]] @@ -2383,6 +2597,25 @@ version = "2.3.2" source = 
"registry+https://github.com/rust-lang/crates.io-index" checksum = "9b4f627cb1b25917193a259e49bdad08f671f8d9708acfd5fe0a8c1455d87220" +[[package]] +name = "phf" +version = "0.13.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c1562dc717473dbaa4c1f85a36410e03c047b2e7df7f45ee938fbef64ae7fadf" +dependencies = [ + "phf_shared", + "serde", +] + +[[package]] +name = "phf_shared" +version = "0.13.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e57fef6bc5981e38c2ce2d63bfa546861309f875b8a75f092d1d54ae2d64f266" +dependencies = [ + "siphasher", +] + [[package]] name = "pin-project" version = "1.1.11" @@ -2415,6 +2648,12 @@ version = "0.3.32" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7edddbd0b52d732b21ad9a5fab5c704c14cd949e5e9a1ec5929a24fded1b904c" +[[package]] +name = "plain" +version = "0.2.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b4596b6d070b27117e987119b4dac604f3c58cfb0b191112e24771b2faeac1a6" + [[package]] name = "polyval" version = "0.6.2" @@ -2427,6 +2666,51 @@ dependencies = [ "universal-hash", ] +[[package]] +name = "postgres-protocol" +version = "0.6.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "56201207dac53e2f38e848e31b4b91616a6bb6e0c7205b77718994a7f49e70fc" +dependencies = [ + "base64", + "byteorder", + "bytes", + "fallible-iterator", + "hmac 0.13.0", + "md-5 0.11.0", + "memchr", + "rand 0.10.0", + "sha2 0.11.0", + "stringprep", +] + +[[package]] +name = "postgres-types" +version = "0.2.13" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8dc729a129e682e8d24170cd30ae1aa01b336b096cbb56df6d534ffec133d186" +dependencies = [ + "bytes", + "fallible-iterator", + "postgres-protocol", +] + +[[package]] +name = "postgres-whiteboard" +version = "0.1.0" +dependencies = [ + "axum", + "rmpv", + "serde", + "socketioxide", + "socketioxide-postgres", + "tokio", + "tower", 
+ "tower-http", + "tracing", + "tracing-subscriber", +] + [[package]] name = "potential_utf" version = "0.1.5" @@ -2756,6 +3040,15 @@ dependencies = [ "bitflags", ] +[[package]] +name = "redox_syscall" +version = "0.7.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f450ad9c3b1da563fb6948a8e0fb0fb9269711c9c73d9ea1de5058c79c8d643a" +dependencies = [ + "bitflags", +] + [[package]] name = "regex" version = "1.12.3" @@ -3392,7 +3685,7 @@ checksum = "e3bf829a2d51ab4a5ddf1352d8470c140cadc8301b2ae1789db023f01cedd6ba" dependencies = [ "cfg-if", "cpufeatures 0.2.17", - "digest", + "digest 0.10.7", ] [[package]] @@ -3403,7 +3696,18 @@ checksum = "a7507d819769d01a365ab707794a4084392c824f54a7a6a7862f8c3d0892b283" dependencies = [ "cfg-if", "cpufeatures 0.2.17", - "digest", + "digest 0.10.7", +] + +[[package]] +name = "sha2" +version = "0.11.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "446ba717509524cb3f22f17ecc096f10f4822d76ab5c0b9822c5f9c284e825f4" +dependencies = [ + "cfg-if", + "cpufeatures 0.3.0", + "digest 0.11.2", ] [[package]] @@ -3443,6 +3747,12 @@ version = "0.1.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e3a9fe34e3e7a50316060351f37187a3f546bce95496156754b601a5fa71b76e" +[[package]] +name = "siphasher" +version = "1.0.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b2aa850e253778c88a04c3d7323b043aeda9d3e30d5971937c1855769763678e" + [[package]] name = "slab" version = "0.4.12" @@ -3497,7 +3807,7 @@ dependencies = [ [[package]] name = "socketioxide" -version = "0.18.2" +version = "0.18.3" dependencies = [ "bytes", "engineioxide", @@ -3522,7 +3832,7 @@ dependencies = [ [[package]] name = "socketioxide-core" -version = "0.17.0" +version = "0.18.0" dependencies = [ "arbitrary", "bytes", @@ -3535,7 +3845,7 @@ dependencies = [ [[package]] name = "socketioxide-mongodb" -version = "0.1.0" +version = "0.1.3" dependencies = [ "bson 3.1.0", 
"futures-core", @@ -3553,7 +3863,7 @@ dependencies = [ [[package]] name = "socketioxide-parser-common" -version = "0.17.0" +version = "0.17.1" dependencies = [ "bytes", "itoa", @@ -3564,7 +3874,7 @@ dependencies = [ [[package]] name = "socketioxide-parser-msgpack" -version = "0.17.0" +version = "0.17.1" dependencies = [ "bytes", "rmp", @@ -3573,9 +3883,28 @@ dependencies = [ "socketioxide-core", ] +[[package]] +name = "socketioxide-postgres" +version = "0.1.0" +dependencies = [ + "futures-core", + "futures-util", + "pin-project-lite", + "serde", + "serde_json", + "smallvec", + "socketioxide-core", + "sqlx", + "thiserror 2.0.18", + "tokio", + "tokio-postgres", + "tracing", + "xxhash-rust", +] + [[package]] name = "socketioxide-redis" -version = "0.4.0" +version = "0.4.1" dependencies = [ "bytes", "fred", @@ -3599,6 +3928,124 @@ version = "0.9.8" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6980e8d7511241f8acf4aebddbb1ff938df5eebe98691418c4468d0b72a96a67" +[[package]] +name = "sqlx" +version = "0.8.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1fefb893899429669dcdd979aff487bd78f4064e5e7907e4269081e0ef7d97dc" +dependencies = [ + "sqlx-core", + "sqlx-macros", + "sqlx-postgres", +] + +[[package]] +name = "sqlx-core" +version = "0.8.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ee6798b1838b6a0f69c007c133b8df5866302197e404e8b6ee8ed3e3a5e68dc6" +dependencies = [ + "base64", + "bytes", + "crc", + "crossbeam-queue", + "either", + "event-listener", + "futures-core", + "futures-intrusive", + "futures-io", + "futures-util", + "hashbrown 0.15.5", + "hashlink", + "indexmap", + "log", + "memchr", + "once_cell", + "percent-encoding", + "serde", + "serde_json", + "sha2 0.10.9", + "smallvec", + "thiserror 2.0.18", + "tokio", + "tokio-stream", + "tracing", + "url", +] + +[[package]] +name = "sqlx-macros" +version = "0.8.6" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "a2d452988ccaacfbf5e0bdbc348fb91d7c8af5bee192173ac3636b5fb6e6715d" +dependencies = [ + "proc-macro2", + "quote", + "sqlx-core", + "sqlx-macros-core", + "syn", +] + +[[package]] +name = "sqlx-macros-core" +version = "0.8.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "19a9c1841124ac5a61741f96e1d9e2ec77424bf323962dd894bdb93f37d5219b" +dependencies = [ + "dotenvy", + "either", + "heck", + "hex", + "once_cell", + "proc-macro2", + "quote", + "serde", + "serde_json", + "sha2 0.10.9", + "sqlx-core", + "sqlx-postgres", + "syn", + "tokio", + "url", +] + +[[package]] +name = "sqlx-postgres" +version = "0.8.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "db58fcd5a53cf07c184b154801ff91347e4c30d17a3562a635ff028ad5deda46" +dependencies = [ + "atoi", + "base64", + "bitflags", + "byteorder", + "crc", + "dotenvy", + "etcetera", + "futures-channel", + "futures-core", + "futures-util", + "hex", + "hkdf", + "hmac 0.12.1", + "home", + "itoa", + "log", + "md-5 0.10.6", + "memchr", + "once_cell", + "rand 0.8.5", + "serde", + "serde_json", + "sha2 0.10.9", + "smallvec", + "sqlx-core", + "stringprep", + "thiserror 2.0.18", + "tracing", + "whoami 1.6.1", +] + [[package]] name = "stable_deref_trait" version = "1.2.1" @@ -3815,9 +4262,9 @@ checksum = "1f3ccbac311fea05f86f61904b462b55fb3df8837a366dfc601a0161d0532f20" [[package]] name = "tokio" -version = "1.50.0" +version = "1.52.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "27ad5e34374e03cfffefc301becb44e9dc3c17584f414349ebe29ed26661822d" +checksum = "b67dee974fe86fd92cc45b7a95fdd2f99a36a6d7b0d431a231178d3d670bbcc6" dependencies = [ "bytes", "libc", @@ -3832,15 +4279,41 @@ dependencies = [ [[package]] name = "tokio-macros" -version = "2.6.1" +version = "2.7.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = 
"5c55a2eff8b69ce66c84f85e1da1c233edc36ceb85a2058d11b0d6a3c7e7569c" +checksum = "385a6cb71ab9ab790c5fe8d67f1645e6c450a7ce006a33de03daa956cf70a496" dependencies = [ "proc-macro2", "quote", "syn", ] +[[package]] +name = "tokio-postgres" +version = "0.7.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4dd8df5ef180f6364759a6f00f7aadda4fbbac86cdee37480826a6ff9f3574ce" +dependencies = [ + "async-trait", + "byteorder", + "bytes", + "fallible-iterator", + "futures-channel", + "futures-util", + "log", + "parking_lot", + "percent-encoding", + "phf", + "pin-project-lite", + "postgres-protocol", + "postgres-types", + "rand 0.10.0", + "socket2 0.6.3", + "tokio", + "tokio-util", + "whoami 2.1.1", +] + [[package]] name = "tokio-rustls" version = "0.26.4" @@ -4155,7 +4628,7 @@ version = "0.5.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "fc1de2c688dc15305988b563c3854064043356019f97a4b46276fe734c4f07ea" dependencies = [ - "crypto-common", + "crypto-common 0.1.7", "subtle", ] @@ -4351,6 +4824,15 @@ version = "0.11.1+wasi-snapshot-preview1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ccf3ec651a847eb01de73ccad15eb7d99f80485de043efb2f370cd654f4ea44b" +[[package]] +name = "wasi" +version = "0.14.7+wasi-0.2.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "883478de20367e224c0090af9cf5f9fa85bed63a95c1abf3afc5c083ebc06e8c" +dependencies = [ + "wasip2", +] + [[package]] name = "wasip2" version = "1.0.2+wasi-0.2.9" @@ -4369,6 +4851,21 @@ dependencies = [ "wit-bindgen", ] +[[package]] +name = "wasite" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b8dad83b4f25e74f184f64c43b150b91efe7647395b42289f38e50566d82855b" + +[[package]] +name = "wasite" +version = "1.0.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "66fe902b4a6b8028a753d5424909b764ccf79b7a209eac9bf97e59cda9f71a42" 
+dependencies = [ + "wasi 0.14.7+wasi-0.2.4", +] + [[package]] name = "wasm-bindgen" version = "0.2.117" @@ -4539,6 +5036,29 @@ dependencies = [ "tracing-subscriber", ] +[[package]] +name = "whoami" +version = "1.6.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5d4a4db5077702ca3015d3d02d74974948aba2ad9e12ab7df718ee64ccd7e97d" +dependencies = [ + "libredox", + "wasite 0.1.0", +] + +[[package]] +name = "whoami" +version = "2.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d6a5b12f9df4f978d2cfdb1bd3bac52433f44393342d7ee9c25f5a1c14c0f45d" +dependencies = [ + "libc", + "libredox", + "objc2-system-configuration", + "wasite 1.0.2", + "web-sys", +] + [[package]] name = "winapi-util" version = "0.1.11" @@ -4625,6 +5145,15 @@ dependencies = [ "windows-targets 0.42.2", ] +[[package]] +name = "windows-sys" +version = "0.48.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "677d2418bec65e3338edb076e806bc1ec15693c5d0104683f2efe857f61056a9" +dependencies = [ + "windows-targets 0.48.5", +] + [[package]] name = "windows-sys" version = "0.52.0" diff --git a/examples/postgres-whiteboard/Cargo.toml b/examples/postgres-whiteboard/Cargo.toml new file mode 100644 index 00000000..bec473b2 --- /dev/null +++ b/examples/postgres-whiteboard/Cargo.toml @@ -0,0 +1,27 @@ +[package] +name = "postgres-whiteboard" +version = "0.1.0" +edition = "2021" + +[dependencies] +socketioxide-postgres = { path = "../../crates/socketioxide-postgres", features = [ + "sqlx", + "tokio-postgres", +] } +socketioxide = { workspace = true, features = ["tracing", "msgpack"] } +axum.workspace = true +tokio = { workspace = true, features = ["rt-multi-thread", "macros"] } +tower-http.workspace = true +tower.workspace = true +tracing-subscriber.workspace = true +tracing.workspace = true +serde.workspace = true +rmpv.workspace = true + +[[bin]] +name = "sqlx" +path = "src/sqlx.rs" + +[[bin]] +name = "tokio_postgres" +path 
= "src/tokio_postgres.rs" diff --git a/examples/postgres-whiteboard/Readme.md b/examples/postgres-whiteboard/Readme.md new file mode 100644 index 00000000..a472ebc5 --- /dev/null +++ b/examples/postgres-whiteboard/Readme.md @@ -0,0 +1,6 @@ +# Same example as the whiteboard but with a postgres adapter + +You can spawn as many servers as you want with different ports (env PORT) and then join with clients +connected on these different ports. + +The parser is set to msgpack in the example, but you can use any socket.io parser you want. diff --git a/examples/postgres-whiteboard/dist/index.html b/examples/postgres-whiteboard/dist/index.html new file mode 100644 index 00000000..e8a0f2aa --- /dev/null +++ b/examples/postgres-whiteboard/dist/index.html @@ -0,0 +1,23 @@ + + + + + Socket.IO whiteboard + + + + + + +
+
+
+
+
+
+
+ + + + + diff --git a/examples/postgres-whiteboard/dist/main.js b/examples/postgres-whiteboard/dist/main.js new file mode 100644 index 00000000..e41c37d4 --- /dev/null +++ b/examples/postgres-whiteboard/dist/main.js @@ -0,0 +1,124 @@ +"use strict"; + +(function () { + const params = new URLSearchParams(window.location.search); + var socket = io(); + var canvas = document.getElementsByClassName("whiteboard")[0]; + var colors = document.getElementsByClassName("color"); + var context = canvas.getContext("2d"); + + var current = { + color: "black", + }; + var drawing = false; + + canvas.addEventListener("mousedown", onMouseDown, false); + canvas.addEventListener("mouseup", onMouseUp, false); + canvas.addEventListener("mouseout", onMouseUp, false); + canvas.addEventListener("mousemove", throttle(onMouseMove, 10), false); + + //Touch support for mobile devices + canvas.addEventListener("touchstart", onMouseDown, false); + canvas.addEventListener("touchend", onMouseUp, false); + canvas.addEventListener("touchcancel", onMouseUp, false); + canvas.addEventListener("touchmove", throttle(onMouseMove, 10), false); + + for (var i = 0; i < colors.length; i++) { + colors[i].addEventListener("click", onColorUpdate, false); + } + + socket.on("drawing", onDrawingEvent); + + window.addEventListener("resize", onResize, false); + onResize(); + + function drawLine(x0, y0, x1, y1, color, emit) { + context.beginPath(); + context.moveTo(x0, y0); + context.lineTo(x1, y1); + context.strokeStyle = color; + context.lineWidth = 2; + context.stroke(); + context.closePath(); + + if (!emit) { + return; + } + var w = canvas.width; + var h = canvas.height; + + socket.emit("drawing", { + x0: x0 / w, + y0: y0 / h, + x1: x1 / w, + y1: y1 / h, + color: color, + }); + } + + function onMouseDown(e) { + drawing = true; + current.x = e.clientX || e.touches[0].clientX; + current.y = e.clientY || e.touches[0].clientY; + } + + function onMouseUp(e) { + if (!drawing) { + return; + } + drawing = false; + 
drawLine( + current.x, + current.y, + e.clientX || e.touches[0].clientX, + e.clientY || e.touches[0].clientY, + current.color, + true, + ); + } + + function onMouseMove(e) { + if (!drawing) { + return; + } + drawLine( + current.x, + current.y, + e.clientX || e.touches[0].clientX, + e.clientY || e.touches[0].clientY, + current.color, + true, + ); + current.x = e.clientX || e.touches[0].clientX; + current.y = e.clientY || e.touches[0].clientY; + } + + function onColorUpdate(e) { + current.color = e.target.className.split(" ")[1]; + } + + // limit the number of events per second + function throttle(callback, delay) { + var previousCall = new Date().getTime(); + return function () { + var time = new Date().getTime(); + + if (time - previousCall >= delay) { + previousCall = time; + callback.apply(null, arguments); + } + }; + } + + function onDrawingEvent(data) { + var w = canvas.width; + var h = canvas.height; + drawLine(data.x0 * w, data.y0 * h, data.x1 * w, data.y1 * h, data.color); + } + + // make the canvas fill its parent + function onResize() { + canvas.width = window.innerWidth; + canvas.height = window.innerHeight; + } +})(); diff --git a/examples/postgres-whiteboard/dist/style.css b/examples/postgres-whiteboard/dist/style.css new file mode 100644 index 00000000..437a29cf --- /dev/null +++ b/examples/postgres-whiteboard/dist/style.css @@ -0,0 +1,44 @@ + +/** + * Fix user-agent + */ + +* { + box-sizing: border-box; +} + +html, body { + height: 100%; + margin: 0; + padding: 0; +} + +/** + * Canvas + */ + +.whiteboard { + height: 100%; + width: 100%; + position: absolute; + left: 0; + right: 0; + bottom: 0; + top: 0; +} + +.colors { + position: fixed; +} + +.color { + display: inline-block; + height: 48px; + width: 48px; +} + +.color.black { background-color: black; } +.color.red { background-color: red; } +.color.green { background-color: green; } +.color.blue { background-color: blue; } +.color.yellow { background-color: yellow; } diff --git 
a/examples/postgres-whiteboard/src/sqlx.rs b/examples/postgres-whiteboard/src/sqlx.rs new file mode 100644 index 00000000..81ac166e --- /dev/null +++ b/examples/postgres-whiteboard/src/sqlx.rs @@ -0,0 +1,60 @@ +//! A simple whiteboard example using PostgreSQL as the adapter. +//! It uses the sqlx crate to connect to a PostgreSQL server. +use rmpv::Value; +use socketioxide::{ + adapter::Adapter, + extract::{Data, SocketRef}, + ParserConfig, SocketIo, +}; +use socketioxide_postgres::{drivers::sqlx::sqlx_client as sqlx, PostgresAdapterCtr, SqlxAdapter}; +use std::str::FromStr; +use tower::ServiceBuilder; +use tower_http::{cors::CorsLayer, services::ServeDir}; +use tracing::info; +use tracing_subscriber::{fmt, layer::SubscriberExt, util::SubscriberInitExt, EnvFilter}; + +#[tokio::main] +async fn main() -> Result<(), Box> { + tracing_subscriber::registry() + .with(fmt::layer()) + .with(EnvFilter::from_default_env()) + .init(); + + info!("connecting to postgres"); + let client = sqlx::Pool::connect("postgres://socketio:socketio@localhost/socketio").await?; + let adapter = PostgresAdapterCtr::new_with_sqlx(client); + info!("starting server"); + + let (layer, io) = SocketIo::builder() + .with_parser(ParserConfig::msgpack()) + .with_adapter::>(adapter) + .build_layer(); + + // It is heavily recommended to use generic fns instead of closures for handlers. + // This allows to be generic over the adapter you want to use. + async fn on_drawing(s: SocketRef
, Data(data): Data) { + s.broadcast().emit("drawing", &data).await.ok(); + } + async fn on_connect(s: SocketRef) { + s.on("drawing", on_drawing); + } + io.ns("/", on_connect).await?; + + let app = axum::Router::new() + .fallback_service(ServeDir::new("dist")) + .layer( + ServiceBuilder::new() + .layer(CorsLayer::permissive()) // Enable CORS policy + .layer(layer), + ); + + let port: u16 = std::env::var("PORT") + .map(|s| u16::from_str(&s).unwrap()) + .unwrap_or(3000); + let listener = tokio::net::TcpListener::bind(("0.0.0.0", port)) + .await + .unwrap(); + axum::serve(listener, app).await?; + + Ok(()) +} diff --git a/examples/postgres-whiteboard/src/tokio_postgres.rs b/examples/postgres-whiteboard/src/tokio_postgres.rs new file mode 100644 index 00000000..5f9222de --- /dev/null +++ b/examples/postgres-whiteboard/src/tokio_postgres.rs @@ -0,0 +1,66 @@ +//! A simple whiteboard example using Redis as the adapter. +//! It uses the redis crate to connect to a Redis server. +use std::str::FromStr; + +use rmpv::Value; +use socketioxide::{ + adapter::Adapter, + extract::{Data, SocketRef}, + ParserConfig, SocketIo, +}; +use socketioxide_postgres::{ + drivers::tokio_postgres::tokio_postgres_client as tokio_postgres, PostgresAdapterCtr, + TokioPostgresAdapter, +}; +use tower::ServiceBuilder; +use tower_http::{cors::CorsLayer, services::ServeDir}; +use tracing::info; +use tracing_subscriber::{fmt, layer::SubscriberExt, util::SubscriberInitExt, EnvFilter}; + +use tokio_postgres::{Config, NoTls}; + +#[tokio::main] +async fn main() -> Result<(), Box> { + tracing_subscriber::registry() + .with(fmt::layer()) + .with(EnvFilter::from_default_env()) + .init(); + + info!("connecting to redis"); + let config: Config = "postgres://socketio:socketio@localhost/socketio".parse()?; + let adapter = PostgresAdapterCtr::new_with_tokio_postgres(config, NoTls).await?; + info!("starting server"); + + let (layer, io) = SocketIo::builder() + .with_parser(ParserConfig::msgpack()) + 
.with_adapter::>(adapter) + .build_layer(); + + // It is heavily recommended to use generic fns instead of closures for handlers. + // This allows to be generic over the adapter you want to use. + async fn on_drawing(s: SocketRef, Data(data): Data) { + s.broadcast().emit("drawing", &data).await.ok(); + } + async fn on_connect(s: SocketRef) { + s.on("drawing", on_drawing); + } + io.ns("/", on_connect).await?; + + let app = axum::Router::new() + .fallback_service(ServeDir::new("dist")) + .layer( + ServiceBuilder::new() + .layer(CorsLayer::permissive()) // Enable CORS policy + .layer(layer), + ); + + let port: u16 = std::env::var("PORT") + .map(|s| u16::from_str(&s).unwrap()) + .unwrap_or(3000); + let listener = tokio::net::TcpListener::bind(("0.0.0.0", port)) + .await + .unwrap(); + axum::serve(listener, app).await?; + + Ok(()) +} From 16bfddce1ba40022ff2f203f10ad4686b73ff817 Mon Sep 17 00:00:00 2001 From: totodore Date: Sat, 18 Apr 2026 14:46:34 +0200 Subject: [PATCH 21/31] fix: remove dbg logging --- crates/socketioxide-postgres/src/drivers/tokio_postgres.rs | 2 +- crates/socketioxide-postgres/src/lib.rs | 1 - 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/crates/socketioxide-postgres/src/drivers/tokio_postgres.rs b/crates/socketioxide-postgres/src/drivers/tokio_postgres.rs index 4f142ede..159b2f52 100644 --- a/crates/socketioxide-postgres/src/drivers/tokio_postgres.rs +++ b/crates/socketioxide-postgres/src/drivers/tokio_postgres.rs @@ -66,7 +66,7 @@ impl TokioPostgresDriver { let (client, mut conn) = config.connect(tls).await?; let listeners = Arc::new(RwLock::new(Vec::new())); - let stream = stream::poll_fn(move |cx| dbg!(conn.poll_message(cx))); + let stream = stream::poll_fn(move |cx| conn.poll_message(cx)); tokio::spawn(stream.forward(sink::unfold(listeners.clone(), dispatch_notifs))); let driver = TokioPostgresDriver { diff --git a/crates/socketioxide-postgres/src/lib.rs b/crates/socketioxide-postgres/src/lib.rs index 0d3e0c13..04ac18f0 100644 
--- a/crates/socketioxide-postgres/src/lib.rs +++ b/crates/socketioxide-postgres/src/lib.rs @@ -588,7 +588,6 @@ impl CustomPostgresAdapter { while let Some(notif) = stream.next().await { let chan = notif.channel(); let resp_chan = self.get_response_chan(self.local.server_id()); - tracing::info!(chan, resp_chan, notif = notif.payload(), ""); if chan == resp_chan { match serde_json::from_str(notif.payload()) { Ok(ResponsePacket { From a4167061e0cf190a6665cf5b2b9c9c1a88896476 Mon Sep 17 00:00:00 2001 From: totodore Date: Sat, 18 Apr 2026 17:21:55 +0200 Subject: [PATCH 22/31] fix(adapter/postgres): minor fixes --- crates/socketioxide-postgres/src/drivers/tokio_postgres.rs | 2 +- crates/socketioxide-postgres/tests/fixture.rs | 4 +++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/crates/socketioxide-postgres/src/drivers/tokio_postgres.rs b/crates/socketioxide-postgres/src/drivers/tokio_postgres.rs index 159b2f52..a86d21a5 100644 --- a/crates/socketioxide-postgres/src/drivers/tokio_postgres.rs +++ b/crates/socketioxide-postgres/src/drivers/tokio_postgres.rs @@ -109,7 +109,7 @@ impl Driver for TokioPostgresDriver { let query: String = channels .iter() - .map(|c| format!(r#"LISTEN "{c}"; "#)) + .map(|chan| format!(r#"LISTEN "{chan}";"#)) .collect(); self.client.batch_execute(&query).await?; diff --git a/crates/socketioxide-postgres/tests/fixture.rs b/crates/socketioxide-postgres/tests/fixture.rs index 4974c177..fb17f2f1 100644 --- a/crates/socketioxide-postgres/tests/fixture.rs +++ b/crates/socketioxide-postgres/tests/fixture.rs @@ -46,9 +46,11 @@ pub fn spawn_buggy_servers( let payload = serde_json::to_string(&heartbeat_json).unwrap(); for (_, tx) in sync_buff.read().unwrap().iter() { + let hash = xxhash_rust::xxh3::xxh3_64("socket.io#/".as_bytes()); + let channel = format!("ch_{:x}", hash); // Send the heartbeat to the global channel of the "/" namespace tx.try_send(StubNotification { - channel: "socket.io#/".to_string(), + channel, payload: 
payload.clone(), }) .unwrap(); From 5aac87a5466cf9da8c0bee44f0157d361bf29b69 Mon Sep 17 00:00:00 2001 From: totodore Date: Sat, 18 Apr 2026 17:28:39 +0200 Subject: [PATCH 23/31] docs(example): add postgres adapter examples --- examples/Cargo.lock | 1 + examples/Cargo.toml | 4 + examples/chat/Cargo.toml | 9 + examples/chat/README.md | 4 + examples/chat/src/postgres/sqlx.rs | 182 ++++++++++++++++ examples/chat/src/postgres/tokio_postgres.rs | 194 ++++++++++++++++++ examples/postgres-whiteboard/Cargo.toml | 5 +- examples/postgres-whiteboard/src/sqlx.rs | 4 +- .../postgres-whiteboard/src/tokio_postgres.rs | 6 +- 9 files changed, 400 insertions(+), 9 deletions(-) create mode 100644 examples/chat/src/postgres/sqlx.rs create mode 100644 examples/chat/src/postgres/tokio_postgres.rs diff --git a/examples/Cargo.lock b/examples/Cargo.lock index 6c05d952..69c876b9 100644 --- a/examples/Cargo.lock +++ b/examples/Cargo.lock @@ -3797,6 +3797,7 @@ dependencies = [ "serde", "socketioxide", "socketioxide-mongodb", + "socketioxide-postgres", "socketioxide-redis", "tokio", "tower", diff --git a/examples/Cargo.toml b/examples/Cargo.toml index b39a5831..de6a0596 100644 --- a/examples/Cargo.toml +++ b/examples/Cargo.toml @@ -22,4 +22,8 @@ socketioxide-redis = { path = "../crates/socketioxide-redis", features = [ "fred", ] } socketioxide-mongodb = { path = "../crates/socketioxide-mongodb" } +socketioxide-postgres = { path = "../crates/socketioxide-postgres", features = [ + "sqlx", + "tokio-postgres", +] } serde_json = "1" diff --git a/examples/chat/Cargo.toml b/examples/chat/Cargo.toml index ce595c78..8b380c15 100644 --- a/examples/chat/Cargo.toml +++ b/examples/chat/Cargo.toml @@ -9,6 +9,7 @@ edition = "2021" socketioxide = { workspace = true, features = ["extensions", "state"] } socketioxide-redis.workspace = true socketioxide-mongodb.workspace = true +socketioxide-postgres.workspace = true axum.workspace = true tokio = { workspace = true, features = ["rt-multi-thread", "macros"] } 
tower-http.workspace = true @@ -39,3 +40,11 @@ path = "src/mongodb/mongodb_capped.rs" [[bin]] name = "mongodb-adapter-ttl" path = "src/mongodb/mongodb_ttl.rs" + +[[bin]] +name = "postgres-adapter-sqlx" +path = "src/postgres/sqlx.rs" + +[[bin]] +name = "postgres-adapter-tokio-postgres" +path = "src/postgres/tokio_postgres.rs" diff --git a/examples/chat/README.md b/examples/chat/README.md index 548f2ef2..aff7167a 100644 --- a/examples/chat/README.md +++ b/examples/chat/README.md @@ -9,3 +9,7 @@ This example also include all the available adapters examples: * [Standalone redis](./src/redis/redis.rs) * [Clustered redis](./src/redis/redis_cluster.rs) * [Redis with the fred crate](./src/redis/redis_fred.rs) +* [MongoDB with a capped collection](./src/mongodb/mongodb_capped.rs) +* [MongoDB with a TTL index](./src/mongodb/mongodb_ttl.rs) +* [Postgres with the sqlx crate](./src/postgres/sqlx.rs) +* [Postgres with the tokio-postgres crate](./src/postgres/tokio_postgres.rs) diff --git a/examples/chat/src/postgres/sqlx.rs b/examples/chat/src/postgres/sqlx.rs new file mode 100644 index 00000000..a8a153a1 --- /dev/null +++ b/examples/chat/src/postgres/sqlx.rs @@ -0,0 +1,182 @@ +use serde::{Deserialize, Serialize}; +use socketioxide::{ + adapter::Adapter, + extract::{Data, Extension, SocketRef, State}, + SocketIo, +}; +use socketioxide_postgres::{drivers::sqlx::sqlx_client as sqlx, PostgresAdapterCtr, SqlxAdapter}; +use sqlx::PgPool; +use tower::ServiceBuilder; +use tower_http::{cors::CorsLayer, services::ServeDir}; +use tracing::info; +use tracing_subscriber::FmtSubscriber; + +#[derive(Deserialize, Serialize, Debug, Clone)] +#[serde(transparent)] +struct Username(String); + +#[derive(Deserialize, Serialize, Debug, Clone)] +#[serde(rename_all = "camelCase", untagged)] +enum Res { + Login { + #[serde(rename = "numUsers")] + num_users: usize, + }, + UserEvent { + #[serde(rename = "numUsers")] + num_users: usize, + username: Username, + }, + Message { + username: Username, + 
message: String, + }, + Username { + username: Username, + }, +} +#[derive(Clone)] +struct RemoteUserCnt(PgPool); +impl RemoteUserCnt { + fn new(pool: PgPool) -> Self { + Self(pool) + } + + async fn init(&self) -> Result<(), sqlx::Error> { + sqlx::query( + r#"CREATE TABLE IF NOT EXISTS socket_io_chat_users ( + id INT PRIMARY KEY, + count BIGINT NOT NULL DEFAULT 0 + )"#, + ) + .execute(&self.0) + .await?; + + sqlx::query( + "INSERT INTO socket_io_chat_users (id, count) VALUES (1, 0) ON CONFLICT DO NOTHING", + ) + .execute(&self.0) + .await?; + + Ok(()) + } + + async fn add_user(&self) -> Result { + let count: i64 = sqlx::query_scalar( + "UPDATE socket_io_chat_users SET count = count + 1 WHERE id = 1 RETURNING count", + ) + .fetch_one(&self.0) + .await?; + Ok(count as usize) + } + + async fn remove_user(&self) -> Result { + let count: i64 = sqlx::query_scalar( + "UPDATE socket_io_chat_users SET count = count - 1 WHERE id = 1 RETURNING count", + ) + .fetch_one(&self.0) + .await?; + Ok(count as usize) + } +} + +#[tokio::main] +async fn main() -> Result<(), Box> { + let subscriber = FmtSubscriber::new(); + + tracing::subscriber::set_global_default(subscriber)?; + + info!("Starting server"); + + let pool = sqlx::Pool::connect("postgres://socketio:socketio@localhost/socketio").await?; + let user_cnt = RemoteUserCnt::new(pool.clone()); + user_cnt.init().await?; + let adapter = PostgresAdapterCtr::new_with_sqlx(pool); + + let (layer, io) = SocketIo::builder() + .with_state(user_cnt) + .with_adapter::>(adapter) + .build_layer(); + + io.ns("/", on_connect).await?; + + let app = axum::Router::new() + .fallback_service(ServeDir::new("dist")) + .layer( + ServiceBuilder::new() + .layer(CorsLayer::permissive()) // Enable CORS policy + .layer(layer), + ); + + let port = std::env::var("PORT") + .map(|s| s.parse().unwrap()) + .unwrap_or(3000); + + let listener = tokio::net::TcpListener::bind(("0.0.0.0", port)) + .await + .unwrap(); + axum::serve(listener, app).await.unwrap(); + + 
Ok(()) +} + +async fn on_connect(socket: SocketRef) { + socket.on("new message", on_msg); + socket.on("add user", on_add_user); + socket.on("typing", on_typing); + socket.on("stop typing", on_stop_typing); + socket.on_disconnect(on_disconnect); +} +async fn on_msg( + s: SocketRef, + Data(msg): Data, + Extension(username): Extension, +) { + let msg = &Res::Message { + username, + message: msg, + }; + s.broadcast().emit("new message", msg).await.ok(); +} +async fn on_add_user( + s: SocketRef, + Data(username): Data, + user_cnt: State, +) { + if s.extensions.get::().is_some() { + return; + } + let num_users = user_cnt.add_user().await.unwrap_or(0); + s.extensions.insert(Username(username.clone())); + s.emit("login", &Res::Login { num_users }).ok(); + + let res = &Res::UserEvent { + num_users, + username: Username(username), + }; + s.broadcast().emit("user joined", res).await.ok(); +} +async fn on_typing(s: SocketRef, Extension(username): Extension) { + s.broadcast() + .emit("typing", &Res::Username { username }) + .await + .ok(); +} +async fn on_stop_typing(s: SocketRef, Extension(username): Extension) { + s.broadcast() + .emit("stop typing", &Res::Username { username }) + .await + .ok(); +} +async fn on_disconnect( + s: SocketRef, + user_cnt: State, + Extension(username): Extension, +) { + let num_users = user_cnt.remove_user().await.unwrap_or(0); + let res = &Res::UserEvent { + num_users, + username, + }; + s.broadcast().emit("user left", res).await.ok(); +} diff --git a/examples/chat/src/postgres/tokio_postgres.rs b/examples/chat/src/postgres/tokio_postgres.rs new file mode 100644 index 00000000..15811c9d --- /dev/null +++ b/examples/chat/src/postgres/tokio_postgres.rs @@ -0,0 +1,194 @@ +use std::sync::Arc; + +use serde::{Deserialize, Serialize}; +use socketioxide::{ + adapter::Adapter, + extract::{Data, Extension, SocketRef, State}, + SocketIo, +}; +use socketioxide_postgres::{ + drivers::tokio_postgres::tokio_postgres_client as tokio_postgres, PostgresAdapterCtr, 
+ TokioPostgresAdapter, +}; +use tokio_postgres::{Client, Config, NoTls}; +use tower::ServiceBuilder; +use tower_http::{cors::CorsLayer, services::ServeDir}; +use tracing::info; +use tracing_subscriber::FmtSubscriber; + +#[derive(Deserialize, Serialize, Debug, Clone)] +#[serde(transparent)] +struct Username(String); + +#[derive(Deserialize, Serialize, Debug, Clone)] +#[serde(rename_all = "camelCase", untagged)] +enum Res { + Login { + #[serde(rename = "numUsers")] + num_users: usize, + }, + UserEvent { + #[serde(rename = "numUsers")] + num_users: usize, + username: Username, + }, + Message { + username: Username, + message: String, + }, + Username { + username: Username, + }, +} +#[derive(Clone)] +struct RemoteUserCnt(Arc); +impl RemoteUserCnt { + fn new(client: Arc) -> Self { + Self(client) + } + async fn init(&self) -> Result<(), tokio_postgres::Error> { + self.0 + .batch_execute( + r#"CREATE TABLE IF NOT EXISTS socket_io_chat_users ( + id INT PRIMARY KEY, + count BIGINT NOT NULL DEFAULT 0 + ); + INSERT INTO socket_io_chat_users (id, count) VALUES (1, 0) ON CONFLICT DO NOTHING;"#, + ) + .await?; + + Ok(()) + } + + async fn add_user(&self) -> Result { + let row = self + .0 + .query_one( + "UPDATE socket_io_chat_users SET count = count + 1 WHERE id = 1 RETURNING count", + &[], + ) + .await?; + + let count: i64 = row.get(0); + Ok(count as usize) + } + async fn remove_user(&self) -> Result { + let row = self + .0 + .query_one( + "UPDATE socket_io_chat_users SET count = count - 1 WHERE id = 1 RETURNING count", + &[], + ) + .await?; + + let count: i64 = row.get(0); + Ok(count as usize) + } +} + +#[tokio::main] +async fn main() -> Result<(), Box> { + let subscriber = FmtSubscriber::new(); + + tracing::subscriber::set_global_default(subscriber)?; + + info!("Starting server"); + + let config: Config = "postgres://socketio:socketio@localhost/socketio".parse()?; + let (client, conn) = config.connect(NoTls).await?; + tokio::spawn(async move { + if let Err(e) = conn.await { + 
tracing::error!("postgres connection error: {e}"); + } + }); + let user_cnt = RemoteUserCnt::new(Arc::new(client)); + user_cnt.init().await?; + + let adapter = PostgresAdapterCtr::new_with_tokio_postgres(config, NoTls).await?; + + let (layer, io) = SocketIo::builder() + .with_state(user_cnt) + .with_adapter::>(adapter) + .build_layer(); + + io.ns("/", on_connect).await?; + + let app = axum::Router::new() + .fallback_service(ServeDir::new("dist")) + .layer( + ServiceBuilder::new() + .layer(CorsLayer::permissive()) // Enable CORS policy + .layer(layer), + ); + + let port = std::env::var("PORT") + .map(|s| s.parse().unwrap()) + .unwrap_or(3000); + let listener = tokio::net::TcpListener::bind(("0.0.0.0", port)) + .await + .unwrap(); + axum::serve(listener, app).await.unwrap(); + + Ok(()) +} + +async fn on_connect(socket: SocketRef) { + socket.on("new message", on_msg); + socket.on("add user", on_add_user); + socket.on("typing", on_typing); + socket.on("stop typing", on_stop_typing); + socket.on_disconnect(on_disconnect); +} +async fn on_msg( + s: SocketRef, + Data(msg): Data, + Extension(username): Extension, +) { + let msg = &Res::Message { + username, + message: msg, + }; + s.broadcast().emit("new message", msg).await.ok(); +} +async fn on_add_user( + s: SocketRef, + Data(username): Data, + user_cnt: State, +) { + if s.extensions.get::().is_some() { + return; + } + let num_users = user_cnt.add_user().await.unwrap_or(0); + s.extensions.insert(Username(username.clone())); + s.emit("login", &Res::Login { num_users }).ok(); + + let res = &Res::UserEvent { + num_users, + username: Username(username), + }; + s.broadcast().emit("user joined", res).await.ok(); +} +async fn on_typing(s: SocketRef, Extension(username): Extension) { + s.broadcast() + .emit("typing", &Res::Username { username }) + .await + .ok(); +} +async fn on_stop_typing(s: SocketRef, Extension(username): Extension) { + s.broadcast() + .emit("stop typing", &Res::Username { username }) + .await + .ok(); +} 
+async fn on_disconnect( + s: SocketRef, + user_cnt: State, + Extension(username): Extension, +) { + let num_users = user_cnt.remove_user().await.unwrap_or(0); + let res = &Res::UserEvent { + num_users, + username, + }; + s.broadcast().emit("user left", res).await.ok(); +} diff --git a/examples/postgres-whiteboard/Cargo.toml b/examples/postgres-whiteboard/Cargo.toml index bec473b2..83fddbc3 100644 --- a/examples/postgres-whiteboard/Cargo.toml +++ b/examples/postgres-whiteboard/Cargo.toml @@ -4,10 +4,7 @@ version = "0.1.0" edition = "2021" [dependencies] -socketioxide-postgres = { path = "../../crates/socketioxide-postgres", features = [ - "sqlx", - "tokio-postgres", -] } +socketioxide-postgres.workspace = true socketioxide = { workspace = true, features = ["tracing", "msgpack"] } axum.workspace = true tokio = { workspace = true, features = ["rt-multi-thread", "macros"] } diff --git a/examples/postgres-whiteboard/src/sqlx.rs b/examples/postgres-whiteboard/src/sqlx.rs index 81ac166e..86100c77 100644 --- a/examples/postgres-whiteboard/src/sqlx.rs +++ b/examples/postgres-whiteboard/src/sqlx.rs @@ -1,5 +1,5 @@ -//! A simple whiteboard example using Redis as the adapter. -//! It uses the fred crate to connect to a Redis server. +//! A simple whiteboard example using Postgres as the adapter. +//! It uses the sqlx crate to connect to a Postgres server. use rmpv::Value; use socketioxide::{ adapter::Adapter, diff --git a/examples/postgres-whiteboard/src/tokio_postgres.rs b/examples/postgres-whiteboard/src/tokio_postgres.rs index 5f9222de..3d8d21fc 100644 --- a/examples/postgres-whiteboard/src/tokio_postgres.rs +++ b/examples/postgres-whiteboard/src/tokio_postgres.rs @@ -1,5 +1,5 @@ -//! A simple whiteboard example using Redis as the adapter. -//! It uses the redis crate to connect to a Redis server. +//! A simple whiteboard example using postgres as the adapter. +//! It uses the tokio_postgres crate to connect to a Postgtes server. 
use std::str::FromStr; use rmpv::Value; @@ -26,7 +26,7 @@ async fn main() -> Result<(), Box> { .with(EnvFilter::from_default_env()) .init(); - info!("connecting to redis"); + info!("connecting to postgres"); let config: Config = "postgres://socketio:socketio@localhost/socketio".parse()?; let adapter = PostgresAdapterCtr::new_with_tokio_postgres(config, NoTls).await?; info!("starting server"); From 3afb059753a43fda4547af4007b246a7fd5ff671 Mon Sep 17 00:00:00 2001 From: totodore Date: Sat, 18 Apr 2026 18:58:55 +0200 Subject: [PATCH 24/31] wip --- .../src/adapter/remote_packet.rs | 19 ++++++ .../socketioxide-postgres/src/drivers/mod.rs | 13 ++++ .../socketioxide-postgres/src/drivers/sqlx.rs | 23 +++++++ .../src/drivers/tokio_postgres.rs | 11 ++++ crates/socketioxide-postgres/src/lib.rs | 55 ++++++++++++---- crates/socketioxide-postgres/src/stream.rs | 62 ++++++++++++++++--- crates/socketioxide-postgres/tests/fixture.rs | 20 +++++- 7 files changed, 182 insertions(+), 21 deletions(-) diff --git a/crates/socketioxide-core/src/adapter/remote_packet.rs b/crates/socketioxide-core/src/adapter/remote_packet.rs index b936ea35..d1af0b48 100644 --- a/crates/socketioxide-core/src/adapter/remote_packet.rs +++ b/crates/socketioxide-core/src/adapter/remote_packet.rs @@ -104,6 +104,17 @@ impl<'a> RequestOut<'a> { opts: None, } } + + /// The request is binary if it is a [`RequestTypeOut::Broadcast`] or [`RequestTypeOut::BroadcastWithAck`] + /// with a binary payload. + pub fn is_binary(&self) -> bool { + match self.r#type { + RequestTypeOut::Broadcast(p) | RequestTypeOut::BroadcastWithAck(p) => { + p.inner.is_binary() + } + _ => false, + } + } } /// Custom implementation to serialize enum variant as u8. @@ -281,6 +292,14 @@ impl Response { _ => None, } } + + /// The response is binary if it is a [`ResponseType::BroadcastAck`] with a binary payload. 
+ pub fn is_binary(&self) -> bool { + matches!( + self.r#type, + ResponseType::BroadcastAck((_, Ok(Value::Bytes(_) | Value::Str(_, Some(_))))) + ) + } } #[cfg(test)] diff --git a/crates/socketioxide-postgres/src/drivers/mod.rs b/crates/socketioxide-postgres/src/drivers/mod.rs index eea95575..adb1cff8 100644 --- a/crates/socketioxide-postgres/src/drivers/mod.rs +++ b/crates/socketioxide-postgres/src/drivers/mod.rs @@ -39,6 +39,19 @@ pub trait Driver: Clone + Send + Sync + 'static { message: &str, ) -> impl Future> + Send; + /// Push an attachment when deferring a NOTIFY message to the attachment table. + fn push_attachment( + &self, + table: &str, + attachment: &[u8], + ) -> impl Future> + Send; + + fn get_attachment( + &self, + table: &str, + id: i32, + ) -> impl Future, Self::Error>> + Send; + /// UNLISTEN from every channel. fn close(&self) -> impl Future> + Send; } diff --git a/crates/socketioxide-postgres/src/drivers/sqlx.rs b/crates/socketioxide-postgres/src/drivers/sqlx.rs index f8fac832..5fc1e522 100644 --- a/crates/socketioxide-postgres/src/drivers/sqlx.rs +++ b/crates/socketioxide-postgres/src/drivers/sqlx.rs @@ -67,6 +67,29 @@ impl Driver for SqlxDriver { Ok(()) } + async fn push_attachment(&self, table: &str, attachment: &[u8]) -> Result { + let query = + format!("INSERT INTO {table} (id, created_at, payload) VALUES (?, ?, $2) RETURNING id"); + + let id: i32 = sqlx::query_scalar(&query) + .bind(attachment) + .fetch_one(&self.client) + .await?; + + Ok(id) + } + + async fn get_attachment(&self, table: &str, id: i32) -> Result, Self::Error> { + let query = format!("SELECT payload FROM {table} WHERE id = $1"); + + let attachment: Vec = sqlx::query_scalar(&query) + .bind(id) + .fetch_one(&self.client) + .await?; + + Ok(attachment) + } + async fn close(&self) -> Result<(), Self::Error> { // PgListener will automatically unlisten channel when being dropped Ok(()) diff --git a/crates/socketioxide-postgres/src/drivers/tokio_postgres.rs 
b/crates/socketioxide-postgres/src/drivers/tokio_postgres.rs index a86d21a5..93e78c2c 100644 --- a/crates/socketioxide-postgres/src/drivers/tokio_postgres.rs +++ b/crates/socketioxide-postgres/src/drivers/tokio_postgres.rs @@ -125,6 +125,17 @@ impl Driver for TokioPostgresDriver { Ok(()) } + async fn push_attachment(&self, table: &str, attachment: &[u8]) -> Result { + let query = + format!("INSERT INTO {table} (id, created_at, payload) VALUES (?, ?, $2) RETURNING id"); + self.client.query_one_scalar(&query, &[&attachment]).await + } + + async fn get_attachment(&self, table: &str, id: i32) -> Result, Self::Error> { + let query = format!("SELECT payload FROM {table} WHERE id = $1"); + self.client.query_one_scalar(&query, &[&id]).await + } + async fn close(&self) -> Result<(), Self::Error> { self.client.execute("UNLISTEN *", &[]).await?; Ok(()) diff --git a/crates/socketioxide-postgres/src/lib.rs b/crates/socketioxide-postgres/src/lib.rs index 04ac18f0..23bdad12 100644 --- a/crates/socketioxide-postgres/src/lib.rs +++ b/crates/socketioxide-postgres/src/lib.rs @@ -300,7 +300,7 @@ pub type SqlxAdapter = CustomPostgresAdapter; pub type TokioPostgresAdapter = CustomPostgresAdapter; -type ResponseHandlers = HashMap>>; +type ResponseHandlers = HashMap>; /// The postgres adapter implementation. /// It is generic over the [`Driver`] used to communicate with the postgres server. 
@@ -790,7 +790,18 @@ impl CustomPostgresAdapter { Some(target) => self.get_node_chan(target), None => self.get_global_chan(), }; - let payload = serde_json::to_string(&req)?; + + let mut payload = serde_json::to_string(&req)?; + if payload.len() > self.config.payload_threshold { + let attachment_id = self + .driver + .push_attachment(&self.config.table_name, payload.as_bytes()) + .await + .map_err(Error::Driver)?; + + payload = attachment_id.to_string(); + } + self.driver .notify(&chan, &payload) .await @@ -816,7 +827,7 @@ impl CustomPostgresAdapter { .map(|payload| ResponsePacket { req_id, node_id: self.local.server_id(), - payload, + payload: ResponsePayload::Data(payload), //TODO: defer to attachments }) .and_then(|res| serde_json::to_string(&res)); @@ -849,15 +860,22 @@ impl CustomPostgresAdapter { let stream = ChanStream::new(rx); let stream = stream - .filter_map(|payload| { - let data = match serde_json::from_str::>(payload.get()) { - Ok(data) => Some(data), - Err(e) => { - tracing::warn!("error decoding response: {e}"); - None - } - }; - future::ready(data) + .filter_map(async move |payload| match payload { + //TODO: response chan stream + ResponsePayload::Data(data) => serde_json::from_str::>(data.get()) + .inspect_err(|err| tracing::warn!("error decoding response: {err}")) + .ok(), + ResponsePayload::Attachment(id) => self + .driver + .get_attachment(&self.config.table_name, id) + .await + .inspect_err(|err| tracing::warn!("error fetching attachment: {err}")) + .ok() + .and_then(|data| { + serde_json::from_slice::>(&data) + .inspect_err(|err| tracing::warn!("error decoding response: {err}")) + .ok() + }), }) .filter(move |item| future::ready(ResponseTypeId::from(&item.r#type) == response_type)) .take(remote_serv_cnt) @@ -933,9 +951,20 @@ impl Spawnable for InitRes { } } +#[derive(Debug, Serialize, Deserialize)] +enum RequestPacket { + Request(T), + RequestWithAttachment(i32), +} + #[derive(Deserialize, Serialize)] struct ResponsePacket { req_id: Sid, 
node_id: Uid, - payload: Box, + payload: ResponsePayload, +} +#[derive(Deserialize, Serialize)] +enum ResponsePayload { + Data(Box), + Attachment(i32), } diff --git a/crates/socketioxide-postgres/src/stream.rs b/crates/socketioxide-postgres/src/stream.rs index 3fec6df7..73a0797c 100644 --- a/crates/socketioxide-postgres/src/stream.rs +++ b/crates/socketioxide-postgres/src/stream.rs @@ -1,12 +1,8 @@ use std::{ - fmt, - pin::Pin, - sync::{Arc, Mutex}, - task::{self, Poll}, - time::Duration, + borrow::Cow, fmt, pin::Pin, sync::{Arc, Mutex}, task::{self, Poll}, time::Duration }; -use futures_core::{FusedStream, Stream}; +use futures_core::{FusedStream, Stream, future::BoxFuture, ready}; use futures_util::{StreamExt, stream::TakeUntil}; use pin_project_lite::pin_project; use serde::de::DeserializeOwned; @@ -20,7 +16,10 @@ use socketioxide_core::{ }; use tokio::{sync::mpsc, time}; -use crate::{ResponseHandlers, drivers::Notification}; +use crate::{ + ResponseHandlers, ResponsePayload, + drivers::{Driver, Notification}, +}; pin_project! { /// A stream of acknowledgement messages received from the local and remote servers. @@ -237,6 +236,55 @@ impl FusedStream for DropStream { } } +pin_project! { + struct RemoteAckStream { + driver: D, + table: Cow<'static, str>, + #[pin] + inner: ChanStream, + #[pin] + state: RemoteAckStreamState> + 'static>>, + } +} + +pin_project! 
{ + #[project = RemoteAckStreamStateProj] + enum RemoteAckStreamState { + Pending{ #[pin] fut: Box, Error>> + 'static> }, + Done, + } +} + +impl Stream for RemoteAckStream { + type Item = Box; + + fn poll_next(self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll> { + let proj = self.project(); + match proj.state.project() { + RemoteAckStreamStateProj::Pending { fut } => match ready!(fut.poll(cx)) { + Ok(value) => { + proj.state.set(RemoteAckStreamState::Done); + cx.waker().wake_by_ref(); + return Poll::Ready(Some(value)) + }, + Err(err) => Poll::Ready(Some(Box::new(Value::String(err.to_string())))), + }, + RemoteAckStreamStateProj::Done => (), + }; + + match ready!(self.project().inner.poll_next(cx)) { + Some(ResponsePayload::Data(data)) => Poll::Ready(Some(data)), + Some(ResponsePayload::Attachment(id)) => self.driver.get_attachment(&self.table, id) + None => Poll::Ready(None), + } + } +} +impl FusedStream for RemoteAckStream { + fn is_terminated(&self) -> bool { + self.inner.is_terminated() + } +} + #[cfg(test)] mod tests { use futures_core::FusedStream; diff --git a/crates/socketioxide-postgres/tests/fixture.rs b/crates/socketioxide-postgres/tests/fixture.rs index fb17f2f1..3595f89f 100644 --- a/crates/socketioxide-postgres/tests/fixture.rs +++ b/crates/socketioxide-postgres/tests/fixture.rs @@ -7,10 +7,11 @@ use socketioxide_postgres::{ drivers::{Driver, Notification}, }; use std::{ + collections::HashMap, convert::Infallible, pin::Pin, str::FromStr, - sync::{Arc, RwLock}, + sync::{Arc, RwLock, atomic::AtomicI32}, task, time::Duration, }; @@ -120,6 +121,8 @@ pub struct StubDriver { tx: mpsc::Sender, /// Handlers for incoming notifications per listened channel. 
handlers: Arc>, + attachments: Arc>>>, + attachment_idx: Arc, } impl StubDriver { @@ -140,6 +143,8 @@ impl StubDriver { server_id, tx, handlers, + attachments: Arc::new(RwLock::new(HashMap::new())), + attachment_idx: Arc::new(AtomicI32::new(0)), }; (driver, rx, tx1) } @@ -218,6 +223,19 @@ impl Driver for StubDriver { Ok(()) } + async fn push_attachment(&self, _table: &str, attachment: &[u8]) -> Result { + let id = self + .attachment_idx + .fetch_add(1, std::sync::atomic::Ordering::SeqCst); + + self.attachments + .write() + .unwrap() + .insert(id, attachment.to_vec()); + + Ok(id) + } + async fn close(&self) -> Result<(), Self::Error> { Ok(()) } From 9617bc989d8f5aaac0ea23b82d1ff678b9125b61 Mon Sep 17 00:00:00 2001 From: totodore Date: Sun, 19 Apr 2026 11:20:16 +0200 Subject: [PATCH 25/31] wip --- .../socketioxide-postgres/src/drivers/mod.rs | 1 + .../socketioxide-postgres/src/drivers/sqlx.rs | 3 +- .../src/drivers/tokio_postgres.rs | 3 +- crates/socketioxide-postgres/src/lib.rs | 228 +++++++++++++----- crates/socketioxide-postgres/src/stream.rs | 79 ++---- crates/socketioxide-postgres/tests/fixture.rs | 15 +- 6 files changed, 203 insertions(+), 126 deletions(-) diff --git a/crates/socketioxide-postgres/src/drivers/mod.rs b/crates/socketioxide-postgres/src/drivers/mod.rs index adb1cff8..4ec4f378 100644 --- a/crates/socketioxide-postgres/src/drivers/mod.rs +++ b/crates/socketioxide-postgres/src/drivers/mod.rs @@ -46,6 +46,7 @@ pub trait Driver: Clone + Send + Sync + 'static { attachment: &[u8], ) -> impl Future> + Send; + /// Get an attachment from the attachment table. 
fn get_attachment( &self, table: &str, diff --git a/crates/socketioxide-postgres/src/drivers/sqlx.rs b/crates/socketioxide-postgres/src/drivers/sqlx.rs index 5fc1e522..191f5b55 100644 --- a/crates/socketioxide-postgres/src/drivers/sqlx.rs +++ b/crates/socketioxide-postgres/src/drivers/sqlx.rs @@ -68,8 +68,7 @@ impl Driver for SqlxDriver { } async fn push_attachment(&self, table: &str, attachment: &[u8]) -> Result { - let query = - format!("INSERT INTO {table} (id, created_at, payload) VALUES (?, ?, $2) RETURNING id"); + let query = format!("INSERT INTO {table} (payload) VALUES ($1) RETURNING id"); let id: i32 = sqlx::query_scalar(&query) .bind(attachment) diff --git a/crates/socketioxide-postgres/src/drivers/tokio_postgres.rs b/crates/socketioxide-postgres/src/drivers/tokio_postgres.rs index 93e78c2c..29398344 100644 --- a/crates/socketioxide-postgres/src/drivers/tokio_postgres.rs +++ b/crates/socketioxide-postgres/src/drivers/tokio_postgres.rs @@ -126,8 +126,7 @@ impl Driver for TokioPostgresDriver { } async fn push_attachment(&self, table: &str, attachment: &[u8]) -> Result { - let query = - format!("INSERT INTO {table} (id, created_at, payload) VALUES (?, ?, $2) RETURNING id"); + let query = format!("INSERT INTO {table} (payload) VALUES ($1) RETURNING id"); self.client.query_one_scalar(&query, &[&attachment]).await } diff --git a/crates/socketioxide-postgres/src/lib.rs b/crates/socketioxide-postgres/src/lib.rs index 23bdad12..abc8a604 100644 --- a/crates/socketioxide-postgres/src/lib.rs +++ b/crates/socketioxide-postgres/src/lib.rs @@ -52,7 +52,7 @@ //! first broadcast a request to all the servers and then perform the action locally. 
use drivers::Driver; -use futures_core::Stream; +use futures_core::{Stream, stream::BoxStream}; use futures_util::{StreamExt, pin_mut}; use serde::{Deserialize, Serialize, de::DeserializeOwned}; use serde_json::value::RawValue; @@ -206,14 +206,17 @@ pub enum Error { /// Packet encoding/decoding error #[error("packet decoding error: {0}")] Serde(#[from] serde_json::Error), + /// Response handler not found + #[error("response handler not found/closed for request: {req_id}")] + ResponseHandlerNotFound { + /// The request this response is for + req_id: Sid, + }, } impl fmt::Debug for Error { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - match self { - Self::Driver(err) => write!(f, "Driver error: {:?}", err), - Self::Serde(err) => write!(f, "Encode/Decode error: {:?}", err), - } + fmt::Display::fmt(self, f) } } @@ -358,7 +361,7 @@ impl CoreAdapter for CustomPostgresAdapter // Send initial heartbeat when starting. self.emit_init_heartbeat().await.map_err(|e| match e { Error::Driver(e) => e, - Error::Serde(_) => unreachable!(), + _ => unreachable!(), })?; on_success(); @@ -456,9 +459,17 @@ impl CoreAdapter for CustomPostgresAdapter .request_timeout .saturating_add(timeout.unwrap_or(self.local.ack_timeout())); + let table_name = self.config.table_name.clone(); + let driver = self.driver.clone(); + let remote: BoxStream<'static, Box> = ChanStream::new(rx) + .filter_map(move |payload| { + resolve_resp_payload(payload, driver.clone(), table_name.clone()) + }) + .boxed(); + Ok(AckStream::new( local, - rx, + remote, timeout, remote_serv_cnt, req_id, @@ -588,43 +599,73 @@ impl CustomPostgresAdapter { while let Some(notif) = stream.next().await { let chan = notif.channel(); let resp_chan = self.get_response_chan(self.local.server_id()); - if chan == resp_chan { - match serde_json::from_str(notif.payload()) { - Ok(ResponsePacket { - req_id, - node_id, - payload, - }) if node_id != self.local.server_id() => { - let handlers = self.responses.lock().unwrap(); - if let 
Some(handler) = handlers.get(&req_id) { - if let Err(e) = handler.try_send(payload) { - tracing::warn!(channel = resp_chan, req_id = %req_id, "error sending response: {e}"); - } - } else { - tracing::warn!(channel = resp_chan, req_id = %req_id, "response handler not found"); - } - } - Ok(_) => { - tracing::trace!("skipping loopback packets"); - } - Err(e) => { - tracing::warn!(channel = %notif.channel(), "error handling response: {e}") - } - }; + + let result = if chan == resp_chan { + self.handle_res_notif(notif).await } else { - match serde_json::from_str::(notif.payload()) { - Ok(req) if req.node_id != self.local.server_id() => self.recv_req(req), - Ok(_) => { - tracing::trace!("skipping loopback packets") - } - Err(e) => { - tracing::warn!(channel = %notif.channel(), "error decoding request: {e}") - } - }; + self.handle_req_notif(notif) + }; + + if let Err(err) = result { + tracing::warn!(%err, "Error handling notification, skipping it"); } } } + /// Deserialize a response notification and trigger the corresponding handler. + /// If the request handler queue is full, it will wait, slowing down the notification event pipeline. + async fn handle_res_notif(&self, notif: D::Notification) -> Result<(), Error> { + match serde_json::from_str(notif.payload())? { + p if p.is_loopback(self.local.server_id()) => { + tracing::trace!("skipping loopback packets") + } + ResponsePacket { + req_id, payload, .. + } => { + let tx = self + .responses + .lock() + .unwrap() + .get(&req_id) + .ok_or(Error::ResponseHandlerNotFound { req_id })? + .clone(); + + tx.send(payload) + .await + .map_err(|_| Error::ResponseHandlerNotFound { req_id })?; + } + }; + + Ok(()) + } + + /// Deserialize a request notification and propagate it. In case of attachment, + /// the resolution is done asynchronously. + fn handle_req_notif(self: &Arc, notif: D::Notification) -> Result<(), Error> { + match serde_json::from_str::>(notif.payload())? 
{ + p if p.is_loopback(self.local.server_id()) => { + tracing::trace!("skipping loopback packets"); + } + RequestPacket::Request { payload, .. } => { + let request = serde_json::from_str::(payload.get())?; + self.recv_req(request) + } + RequestPacket::RequestWithAttachment { id, .. } => { + let this = self.clone(); + tokio::spawn(async move { + resolve_attachment(&this.driver, &this.config.table_name, id) + .await + .map(|req| this.recv_req(req)) + .inspect_err( + |err| tracing::warn!(%err, "failed to handle request with attachment"), + ) + .ok(); + }); + } + }; + Ok(()) + } + fn recv_req(self: &Arc, req: RequestIn) { tracing::trace!(?req, "incoming request"); match (req.r#type, req.opts) { @@ -784,6 +825,10 @@ impl CustomPostgresAdapter { } /// Send a request to a specific target node or broadcast it to all nodes if no target is specified. + /// + /// The request body is serialized as JSON and wrapped in a [`RequestPacket`]. When the + /// serialized body exceeds [`PostgresAdapterConfig::payload_threshold`], it is stored in the + /// attachment table and only the row id travels over NOTIFY. async fn send_req(&self, req: RequestOut<'_>, target: Option) -> Result<(), Error> { tracing::trace!(?req, "sending request"); let chan = match target { @@ -791,16 +836,21 @@ impl CustomPostgresAdapter { None => self.get_global_chan(), }; - let mut payload = serde_json::to_string(&req)?; - if payload.len() > self.config.payload_threshold { - let attachment_id = self + let node_id = self.local.server_id(); + let body = serde_json::to_string(&req)?; + + let payload = if body.len() > self.config.payload_threshold { + let id = self .driver - .push_attachment(&self.config.table_name, payload.as_bytes()) + .push_attachment(&self.config.table_name, body.as_bytes()) .await .map_err(Error::Driver)?; - payload = attachment_id.to_string(); - } + serde_json::to_string(&RequestPacket::<()>::RequestWithAttachment { node_id, id })? 
+ } else { + let payload = RawValue::from_string(body)?; + serde_json::to_string(&RequestPacket::Request { node_id, payload })? + }; self.driver .notify(&chan, &payload) @@ -810,7 +860,11 @@ impl CustomPostgresAdapter { } /// Send a response to the node that sent the request. - fn send_res( + /// + /// When the serialized response exceeds [`PostgresAdapterConfig::payload_threshold`], it is + /// stored in the attachment table and only the row id travels over NOTIFY as a + /// [`ResponsePayload::Attachment`]. + fn send_res( &self, req_id: Sid, req_origin: Uid, @@ -822,18 +876,30 @@ impl CustomPostgresAdapter { ); let driver = self.driver.clone(); let chan = self.get_response_chan(req_origin); - let message = serde_json::to_string(&payload) - .and_then(RawValue::from_string) - .map(|payload| ResponsePacket { - req_id, - node_id: self.local.server_id(), - payload: ResponsePayload::Data(payload), //TODO: defer to attachments - }) - .and_then(|res| serde_json::to_string(&res)); + let table = self.config.table_name.clone(); + let threshold = self.config.payload_threshold; + let node_id = self.local.server_id(); async move { + let body = serde_json::to_string(&payload)?; + let payload = if body.len() > threshold { + let id = driver + .push_attachment(&table, body.as_bytes()) + .await + .map_err(Error::Driver)?; + ResponsePayload::Attachment(id) + } else { + ResponsePayload::Data(RawValue::from_string(body)?) + }; + + let message = serde_json::to_string(&ResponsePacket { + req_id, + node_id, + payload, + })?; + driver - .notify(&chan, &message?) 
+ .notify(&chan, &message) .await .map_err(Error::Driver)?; Ok(()) @@ -951,10 +1017,49 @@ impl Spawnable for InitRes { } } +async fn resolve_resp_payload( + payload: ResponsePayload, + driver: D, + table: Cow<'static, str>, +) -> Option> { + match payload { + ResponsePayload::Data(data) => Some(data), + ResponsePayload::Attachment(id) => resolve_attachment(&driver, &table, id) + .await + .inspect_err(|err| tracing::warn!(%err, id, "failed to resolve payload attachment")) + .ok(), + } +} + +async fn resolve_attachment( + driver: &D, + table_name: &str, + id: i32, +) -> Result> { + let bytes = driver + .get_attachment(table_name, id) + .await + .map_err(Error::Driver)?; + Ok(serde_json::from_slice(&bytes)?) +} + +/// Wire-level wrapper for request NOTIFY payloads. +/// +/// A request may either be inline (serialized request JSON) or deferred to the +/// attachment table. The `node_id` in the deferred variant lets the receiver +/// filter out loopback notifications before hitting the database. #[derive(Debug, Serialize, Deserialize)] enum RequestPacket { - Request(T), - RequestWithAttachment(i32), + Request { node_id: Uid, payload: T }, + RequestWithAttachment { node_id: Uid, id: i32 }, +} +impl RequestPacket { + fn is_loopback(&self, node_id: Uid) -> bool { + match self { + RequestPacket::Request { node_id: id, .. } => *id == node_id, + RequestPacket::RequestWithAttachment { node_id: id, .. 
} => *id == node_id, + } + } } #[derive(Deserialize, Serialize)] @@ -963,8 +1068,13 @@ struct ResponsePacket { node_id: Uid, payload: ResponsePayload, } -#[derive(Deserialize, Serialize)] -enum ResponsePayload { +impl ResponsePacket { + fn is_loopback(&self, node_id: Uid) -> bool { + self.node_id == node_id + } +} +#[derive(Debug, Deserialize, Serialize)] +pub(crate) enum ResponsePayload { Data(Box), Attachment(i32), } diff --git a/crates/socketioxide-postgres/src/stream.rs b/crates/socketioxide-postgres/src/stream.rs index 73a0797c..4dedaced 100644 --- a/crates/socketioxide-postgres/src/stream.rs +++ b/crates/socketioxide-postgres/src/stream.rs @@ -1,8 +1,12 @@ use std::{ - borrow::Cow, fmt, pin::Pin, sync::{Arc, Mutex}, task::{self, Poll}, time::Duration + fmt, + pin::Pin, + sync::{Arc, Mutex}, + task::{self, Poll}, + time::Duration, }; -use futures_core::{FusedStream, Stream, future::BoxFuture, ready}; +use futures_core::{FusedStream, Stream, stream::BoxStream}; use futures_util::{StreamExt, stream::TakeUntil}; use pin_project_lite::pin_project; use serde::de::DeserializeOwned; @@ -16,10 +20,7 @@ use socketioxide_core::{ }; use tokio::{sync::mpsc, time}; -use crate::{ - ResponseHandlers, ResponsePayload, - drivers::{Driver, Notification}, -}; +use crate::{ResponseHandlers, drivers::Notification}; pin_project! { /// A stream of acknowledgement messages received from the local and remote servers. @@ -35,7 +36,7 @@ pin_project! { #[pin] local: S, #[pin] - remote: DropStream>, time::Sleep>>, + remote: DropStream>, time::Sleep>>, ack_cnt: u32, total_ack_cnt: usize, serv_cnt: u16, @@ -43,15 +44,18 @@ pin_project! { } impl AckStream { - pub fn new( + /// Build an ack stream backed by a remote `ResponsePayload` channel. Attachment payloads are + /// resolved lazily, driven by polling — not by a background task — so dropping the stream + /// cancels any still-pending fetch. 
+ pub(crate) fn new( local: S, - remote: mpsc::Receiver>, + remote: BoxStream<'static, Box>, timeout: Duration, serv_cnt: u16, req_sid: Sid, handlers: Arc>, ) -> Self { - let remote = ChanStream::new(remote).take_until(time::sleep(timeout)); + let remote = remote.take_until(time::sleep(timeout)); let remote = DropStream::new(remote, handlers, req_sid); Self { local, @@ -62,10 +66,10 @@ impl AckStream { } } - pub fn new_local(local: S) -> Self { + pub(crate) fn new_local(local: S) -> Self { let handlers = Arc::new(Mutex::new(ResponseHandlers::new())); - let rx = mpsc::channel(1).1; - let remote = ChanStream::new(rx).take_until(time::sleep(Duration::ZERO)); + let empty: BoxStream<'static, Box> = futures_util::stream::empty().boxed(); + let remote = empty.take_until(time::sleep(Duration::ZERO)); let remote = DropStream::new(remote, handlers, Sid::ZERO); Self { local, @@ -236,55 +240,6 @@ impl FusedStream for DropStream { } } -pin_project! { - struct RemoteAckStream { - driver: D, - table: Cow<'static, str>, - #[pin] - inner: ChanStream, - #[pin] - state: RemoteAckStreamState> + 'static>>, - } -} - -pin_project! 
{ - #[project = RemoteAckStreamStateProj] - enum RemoteAckStreamState { - Pending{ #[pin] fut: Box, Error>> + 'static> }, - Done, - } -} - -impl Stream for RemoteAckStream { - type Item = Box; - - fn poll_next(self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll> { - let proj = self.project(); - match proj.state.project() { - RemoteAckStreamStateProj::Pending { fut } => match ready!(fut.poll(cx)) { - Ok(value) => { - proj.state.set(RemoteAckStreamState::Done); - cx.waker().wake_by_ref(); - return Poll::Ready(Some(value)) - }, - Err(err) => Poll::Ready(Some(Box::new(Value::String(err.to_string())))), - }, - RemoteAckStreamStateProj::Done => (), - }; - - match ready!(self.project().inner.poll_next(cx)) { - Some(ResponsePayload::Data(data)) => Poll::Ready(Some(data)), - Some(ResponsePayload::Attachment(id)) => self.driver.get_attachment(&self.table, id) - None => Poll::Ready(None), - } - } -} -impl FusedStream for RemoteAckStream { - fn is_terminated(&self) -> bool { - self.inner.is_terminated() - } -} - #[cfg(test)] mod tests { use futures_core::FusedStream; diff --git a/crates/socketioxide-postgres/tests/fixture.rs b/crates/socketioxide-postgres/tests/fixture.rs index 3595f89f..1d383b38 100644 --- a/crates/socketioxide-postgres/tests/fixture.rs +++ b/crates/socketioxide-postgres/tests/fixture.rs @@ -44,7 +44,10 @@ pub fn spawn_buggy_servers( "type": 20, "opts": null, }); - let payload = serde_json::to_string(&heartbeat_json).unwrap(); + let payload = serde_json::to_string(&serde_json::json!({ + "Request": heartbeat_json, + })) + .unwrap(); for (_, tx) in sync_buff.read().unwrap().iter() { let hash = xxhash_rust::xxh3::xxh3_64("socket.io#/".as_bytes()); @@ -236,6 +239,16 @@ impl Driver for StubDriver { Ok(id) } + async fn get_attachment(&self, _table: &str, id: i32) -> Result, Self::Error> { + Ok(self + .attachments + .read() + .unwrap() + .get(&id) + .cloned() + .unwrap_or_default()) + } + async fn close(&self) -> Result<(), Self::Error> { Ok(()) } From 
cf0a996087d6e194fe3e09146107f1d3dd55bce2 Mon Sep 17 00:00:00 2001 From: totodore Date: Sun, 19 Apr 2026 11:37:01 +0200 Subject: [PATCH 26/31] wip --- crates/socketioxide-postgres/src/lib.rs | 43 +++++++++---------- .../socketioxide-postgres/tests/broadcast.rs | 4 ++ crates/socketioxide-postgres/tests/fixture.rs | 32 ++++++++------ 3 files changed, 45 insertions(+), 34 deletions(-) diff --git a/crates/socketioxide-postgres/src/lib.rs b/crates/socketioxide-postgres/src/lib.rs index abc8a604..fb59f5ca 100644 --- a/crates/socketioxide-postgres/src/lib.rs +++ b/crates/socketioxide-postgres/src/lib.rs @@ -94,26 +94,34 @@ pub struct PostgresAdapterConfig { /// The heartbeat timeout duration. If a remote node does not respond within this duration, /// it will be considered disconnected. Default is 60 seconds. pub hb_timeout: Duration, + /// The heartbeat interval duration. The current node will broadcast a heartbeat to the /// remote nodes at this interval. Default is 10 seconds. pub hb_interval: Duration, + /// The request timeout. When expecting a response from remote nodes, if they do not respond within /// this duration, the request will be considered failed. Default is 5 seconds. pub request_timeout: Duration, + /// The channel size used to receive ack responses. Default is 255. /// /// If you have a lot of servers/sockets and that you may miss acknowledgement because they arrive faster /// than you poll them with the returned stream, you might want to increase this value. pub ack_response_buffer: usize, + /// The table name used to store socket.io attachments. Default is "socket_io_attachments". /// /// > The table name must be a sanitized string. Do not use special characters or spaces. pub table_name: Cow<'static, str>, + /// The prefix used for the channels. Default is "socket.io". pub prefix: Cow<'static, str>, - /// The threshold to the payload size in bytes. 
It should match the configured value on your PostgreSQL instance: + + /// The threshold from which the payload size in bytes is considered large and should be passed through the + /// attachment table. It should match the configured value on your PostgreSQL instance: /// . By default it is 8KB (8000 bytes). pub payload_threshold: usize, + /// The duration between cleanup queries on the attachment table. pub cleanup_interval: Duration, } @@ -206,8 +214,8 @@ pub enum Error { /// Packet encoding/decoding error #[error("packet decoding error: {0}")] Serde(#[from] serde_json::Error), - /// Response handler not found - #[error("response handler not found/closed for request: {req_id}")] + /// Response handler not found/full/closed for request + #[error("response handler not found/full/closed for request: {req_id}")] ResponseHandlerNotFound { /// The request this response is for req_id: Sid, @@ -598,10 +606,8 @@ impl CustomPostgresAdapter { pin_mut!(stream); while let Some(notif) = stream.next().await { let chan = notif.channel(); - let resp_chan = self.get_response_chan(self.local.server_id()); - - let result = if chan == resp_chan { - self.handle_res_notif(notif).await + let result = if chan == self.get_response_chan(self.local.server_id()) { + self.handle_res_notif(notif) } else { self.handle_req_notif(notif) }; @@ -614,7 +620,7 @@ impl CustomPostgresAdapter { /// Deserialize a response notification and trigger the corresponding handler. /// If the request handler queue is full, it will wait, slowing down the notification event pipeline. - async fn handle_res_notif(&self, notif: D::Notification) -> Result<(), Error> { + fn handle_res_notif(&self, notif: D::Notification) -> Result<(), Error> { match serde_json::from_str(notif.payload())? { p if p.is_loopback(self.local.server_id()) => { tracing::trace!("skipping loopback packets") @@ -630,8 +636,7 @@ impl CustomPostgresAdapter { .ok_or(Error::ResponseHandlerNotFound { req_id })? 
.clone(); - tx.send(payload) - .await + tx.try_send(payload) .map_err(|_| Error::ResponseHandlerNotFound { req_id })?; } }; @@ -927,21 +932,15 @@ impl CustomPostgresAdapter { let stream = stream .filter_map(async move |payload| match payload { - //TODO: response chan stream ResponsePayload::Data(data) => serde_json::from_str::>(data.get()) .inspect_err(|err| tracing::warn!("error decoding response: {err}")) .ok(), - ResponsePayload::Attachment(id) => self - .driver - .get_attachment(&self.config.table_name, id) - .await - .inspect_err(|err| tracing::warn!("error fetching attachment: {err}")) - .ok() - .and_then(|data| { - serde_json::from_slice::>(&data) - .inspect_err(|err| tracing::warn!("error decoding response: {err}")) - .ok() - }), + ResponsePayload::Attachment(id) => { + resolve_attachment(&self.driver, &self.config.table_name, id) + .await + .inspect_err(|err| tracing::warn!("error fetching attachment: {err}")) + .ok() + } }) .filter(move |item| future::ready(ResponseTypeId::from(&item.r#type) == response_type)) .take(remote_serv_cnt) diff --git a/crates/socketioxide-postgres/tests/broadcast.rs b/crates/socketioxide-postgres/tests/broadcast.rs index f1486cc1..da70d746 100644 --- a/crates/socketioxide-postgres/tests/broadcast.rs +++ b/crates/socketioxide-postgres/tests/broadcast.rs @@ -1,6 +1,7 @@ use std::time::Duration; use socketioxide::{adapter::Adapter, extract::SocketRef}; +use tracing_subscriber::EnvFilter; mod fixture; #[tokio::test] @@ -107,6 +108,9 @@ pub async fn broadcast_with_ack() { #[tokio::test] pub async fn broadcast_with_ack_timeout() { + tracing_subscriber::fmt() + .with_env_filter(EnvFilter::from_default_env()) + .init(); use futures_util::StreamExt; const REQ_TIMEOUT: Duration = Duration::from_millis(50); const ACK_TIMEOUT: Duration = Duration::from_millis(50); diff --git a/crates/socketioxide-postgres/tests/fixture.rs b/crates/socketioxide-postgres/tests/fixture.rs index 1d383b38..6c0ebd17 100644 --- 
a/crates/socketioxide-postgres/tests/fixture.rs +++ b/crates/socketioxide-postgres/tests/fixture.rs @@ -1,7 +1,10 @@ #![allow(dead_code)] use futures_core::Stream; -use socketioxide_core::Uid; +use socketioxide_core::{ + Uid, + adapter::remote_packet::{RequestOut, RequestTypeOut}, +}; use socketioxide_postgres::{ CustomPostgresAdapter, PostgresAdapterConfig, PostgresAdapterCtr, drivers::{Driver, Notification}, @@ -27,6 +30,20 @@ pub fn spawn_servers() -> [SocketIo, "payload": }}`. +/// +/// `node_id` is the emitter id used by the receiver's loopback filter — pass an id distinct +/// from every real server spawned in the test, otherwise the packet will be dropped as a +/// loopback. +pub fn wrap_request(node_id: Uid, req: &RequestOut<'_>) -> String { + let payload = serde_json::to_value(req).unwrap(); + serde_json::to_string(&serde_json::json!({ + "Request": { "node_id": node_id, "payload": payload }, + })) + .unwrap() +} + pub fn spawn_buggy_servers( timeout: Duration, ) -> [SocketIo>; N] { @@ -36,18 +53,9 @@ pub fn spawn_buggy_servers( // Reinject a false heartbeat request to simulate a bad number of servers. // This will trigger timeouts when expecting responses from all servers. - // The heartbeat type is 20 (RequestTypeOut::Heartbeat) in the wire format. 
let uid: Uid = Uid::from_str("PHHq01ObWy7Godqx").unwrap(); - let heartbeat_json = serde_json::json!({ - "node_id": uid.to_string(), - "id": "ZG9K1r7xSLBiJYWD", - "type": 20, - "opts": null, - }); - let payload = serde_json::to_string(&serde_json::json!({ - "Request": heartbeat_json, - })) - .unwrap(); + let heartbeat = RequestOut::new_empty(uid, RequestTypeOut::Heartbeat); + let payload = wrap_request(uid, &heartbeat); for (_, tx) in sync_buff.read().unwrap().iter() { let hash = xxhash_rust::xxh3::xxh3_64("socket.io#/".as_bytes()); From b1b2f6f1e059a9cfadbe998349a3b0cd10f3efab Mon Sep 17 00:00:00 2001 From: totodore Date: Sun, 19 Apr 2026 12:12:02 +0200 Subject: [PATCH 27/31] wip --- Cargo.lock | 1 + crates/socketioxide-postgres/Cargo.toml | 1 + .../socketioxide-postgres/src/drivers/mod.rs | 4 +- .../socketioxide-postgres/src/drivers/sqlx.rs | 10 ++-- .../src/drivers/tokio_postgres.rs | 8 +-- crates/socketioxide-postgres/src/lib.rs | 52 +++++++++++++++---- crates/socketioxide-postgres/tests/fixture.rs | 12 ++--- 7 files changed, 60 insertions(+), 28 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 94fd2834..dba881d0 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2698,6 +2698,7 @@ dependencies = [ "futures-core", "futures-util", "pin-project-lite", + "rmp-serde", "serde", "serde_json", "smallvec", diff --git a/crates/socketioxide-postgres/Cargo.toml b/crates/socketioxide-postgres/Cargo.toml index ced9a368..2c600496 100644 --- a/crates/socketioxide-postgres/Cargo.toml +++ b/crates/socketioxide-postgres/Cargo.toml @@ -26,6 +26,7 @@ futures-util.workspace = true pin-project-lite.workspace = true serde.workspace = true serde_json = { workspace = true, features = ["raw_value"] } +rmp-serde.workspace = true smallvec = { workspace = true, features = ["serde"] } tokio = { workspace = true, features = ["time", "rt", "sync"] } tracing.workspace = true diff --git a/crates/socketioxide-postgres/src/drivers/mod.rs b/crates/socketioxide-postgres/src/drivers/mod.rs index 
4ec4f378..09f63ead 100644 --- a/crates/socketioxide-postgres/src/drivers/mod.rs +++ b/crates/socketioxide-postgres/src/drivers/mod.rs @@ -44,13 +44,13 @@ pub trait Driver: Clone + Send + Sync + 'static { &self, table: &str, attachment: &[u8], - ) -> impl Future> + Send; + ) -> impl Future> + Send; /// Get an attachment from the attachment table. fn get_attachment( &self, table: &str, - id: i32, + id: i64, ) -> impl Future, Self::Error>> + Send; /// UNLISTEN from every channel. diff --git a/crates/socketioxide-postgres/src/drivers/sqlx.rs b/crates/socketioxide-postgres/src/drivers/sqlx.rs index 191f5b55..e5680079 100644 --- a/crates/socketioxide-postgres/src/drivers/sqlx.rs +++ b/crates/socketioxide-postgres/src/drivers/sqlx.rs @@ -67,10 +67,10 @@ impl Driver for SqlxDriver { Ok(()) } - async fn push_attachment(&self, table: &str, attachment: &[u8]) -> Result { - let query = format!("INSERT INTO {table} (payload) VALUES ($1) RETURNING id"); + async fn push_attachment(&self, table: &str, attachment: &[u8]) -> Result { + let query = format!("INSERT INTO \"{table}\" (payload) VALUES ($1) RETURNING id"); - let id: i32 = sqlx::query_scalar(&query) + let id: i64 = sqlx::query_scalar(&query) .bind(attachment) .fetch_one(&self.client) .await?; @@ -78,8 +78,8 @@ impl Driver for SqlxDriver { Ok(id) } - async fn get_attachment(&self, table: &str, id: i32) -> Result, Self::Error> { - let query = format!("SELECT payload FROM {table} WHERE id = $1"); + async fn get_attachment(&self, table: &str, id: i64) -> Result, Self::Error> { + let query = format!("SELECT payload FROM \"{table}\" WHERE id = $1"); let attachment: Vec = sqlx::query_scalar(&query) .bind(id) diff --git a/crates/socketioxide-postgres/src/drivers/tokio_postgres.rs b/crates/socketioxide-postgres/src/drivers/tokio_postgres.rs index 29398344..a89308ac 100644 --- a/crates/socketioxide-postgres/src/drivers/tokio_postgres.rs +++ b/crates/socketioxide-postgres/src/drivers/tokio_postgres.rs @@ -125,13 +125,13 @@ impl 
Driver for TokioPostgresDriver { Ok(()) } - async fn push_attachment(&self, table: &str, attachment: &[u8]) -> Result { - let query = format!("INSERT INTO {table} (payload) VALUES ($1) RETURNING id"); + async fn push_attachment(&self, table: &str, attachment: &[u8]) -> Result { + let query = format!("INSERT INTO \"{table}\" (payload) VALUES ($1) RETURNING id"); self.client.query_one_scalar(&query, &[&attachment]).await } - async fn get_attachment(&self, table: &str, id: i32) -> Result, Self::Error> { - let query = format!("SELECT payload FROM {table} WHERE id = $1"); + async fn get_attachment(&self, table: &str, id: i64) -> Result, Self::Error> { + let query = format!("SELECT payload FROM \"{table}\" WHERE id = $1"); self.client.query_one_scalar(&query, &[&id]).await } diff --git a/crates/socketioxide-postgres/src/lib.rs b/crates/socketioxide-postgres/src/lib.rs index fb59f5ca..3e0aee22 100644 --- a/crates/socketioxide-postgres/src/lib.rs +++ b/crates/socketioxide-postgres/src/lib.rs @@ -74,11 +74,11 @@ use std::{ collections::HashMap, fmt, future, pin::Pin, - sync::{Arc, Mutex}, + sync::{Arc, Mutex, OnceLock}, task::{Context, Poll}, time::{Duration, Instant}, }; -use tokio::sync::mpsc; +use tokio::{sync::mpsc, task::AbortHandle}; use crate::{ drivers::Notification, @@ -329,6 +329,9 @@ pub struct CustomPostgresAdapter { nodes_liveness: Mutex>, /// A map of response handlers used to await for responses from the remote servers. responses: Arc>, + /// A task that listens for events from the remote servers. 
+ ev_stream_task: OnceLock, + hb_task: OnceLock, } impl DefinedAdapter for CustomPostgresAdapter {} @@ -345,6 +348,8 @@ impl CoreAdapter for CustomPostgresAdapter config: state.config.clone(), nodes_liveness: Mutex::new(Vec::new()), responses: Arc::new(Mutex::new(HashMap::new())), + ev_stream_task: OnceLock::new(), + hb_task: OnceLock::new(), } } @@ -363,8 +368,16 @@ impl CoreAdapter for CustomPostgresAdapter ]; let stream = self.driver.listen(&channels).await?; - tokio::spawn(self.clone().handle_ev_stream(stream)); - tokio::spawn(self.clone().heartbeat_job()); + let ev_stream_task = tokio::spawn(self.clone().handle_ev_stream(stream)).abort_handle(); + assert!( + self.ev_stream_task.set(ev_stream_task).is_ok(), + "Adapter::init should be called only once" + ); + let hb_task = tokio::spawn(self.clone().heartbeat_job()).abort_handle(); + assert!( + self.hb_task.set(hb_task).is_ok(), + "Adapter::init should be called only once" + ); // Send initial heartbeat when starting. self.emit_init_heartbeat().await.map_err(|e| match e { @@ -379,6 +392,13 @@ impl CoreAdapter for CustomPostgresAdapter } async fn close(&self) -> Result<(), Self::Error> { + if let Some(hb_task) = self.hb_task.get() { + hb_task.abort(); + } + if let Some(ev_stream_task) = self.ev_stream_task.get() { + ev_stream_task.abort(); + } + self.driver.close().await.map_err(Error::Driver) } @@ -469,10 +489,15 @@ impl CoreAdapter for CustomPostgresAdapter let table_name = self.config.table_name.clone(); let driver = self.driver.clone(); + + // Resolve attachment payloads concurrently while preserving the order the notifications + // arrived in the channel. `buffered` keeps wire order, which is required so that each + // server's `BroadcastAckCount` is observed before its individual acks. 
+ let concurrency = std::cmp::max(self.config.ack_response_buffer, 1); let remote: BoxStream<'static, Box> = ChanStream::new(rx) - .filter_map(move |payload| { - resolve_resp_payload(payload, driver.clone(), table_name.clone()) - }) + .map(move |payload| resolve_resp_payload(payload, driver.clone(), table_name.clone())) + .buffered(concurrency) + .filter_map(future::ready) .boxed(); Ok(AckStream::new( @@ -930,8 +955,11 @@ impl CustomPostgresAdapter { self.responses.lock().unwrap().insert(req_id, tx); let stream = ChanStream::new(rx); + // Overlap attachment fetches across servers while preserving arrival order so that + // `take(remote_serv_cnt)` still closes the stream after exactly one response per server. + let concurrency = std::cmp::max(remote_serv_cnt, 1); let stream = stream - .filter_map(async move |payload| match payload { + .map(async move |payload| match payload { ResponsePayload::Data(data) => serde_json::from_str::>(data.get()) .inspect_err(|err| tracing::warn!("error decoding response: {err}")) .ok(), @@ -942,6 +970,8 @@ impl CustomPostgresAdapter { .ok() } }) + .buffered(concurrency) + .filter_map(future::ready) .filter(move |item| future::ready(ResponseTypeId::from(&item.r#type) == response_type)) .take(remote_serv_cnt) .take_until(tokio::time::sleep(self.config.request_timeout)); @@ -1033,7 +1063,7 @@ async fn resolve_resp_payload( async fn resolve_attachment( driver: &D, table_name: &str, - id: i32, + id: i64, ) -> Result> { let bytes = driver .get_attachment(table_name, id) @@ -1050,7 +1080,7 @@ async fn resolve_attachment( #[derive(Debug, Serialize, Deserialize)] enum RequestPacket { Request { node_id: Uid, payload: T }, - RequestWithAttachment { node_id: Uid, id: i32 }, + RequestWithAttachment { node_id: Uid, id: i64 }, } impl RequestPacket { fn is_loopback(&self, node_id: Uid) -> bool { @@ -1075,5 +1105,5 @@ impl ResponsePacket { #[derive(Debug, Deserialize, Serialize)] pub(crate) enum ResponsePayload { Data(Box), - Attachment(i32), + 
Attachment(i64), } diff --git a/crates/socketioxide-postgres/tests/fixture.rs b/crates/socketioxide-postgres/tests/fixture.rs index 6c0ebd17..669eba74 100644 --- a/crates/socketioxide-postgres/tests/fixture.rs +++ b/crates/socketioxide-postgres/tests/fixture.rs @@ -14,7 +14,7 @@ use std::{ convert::Infallible, pin::Pin, str::FromStr, - sync::{Arc, RwLock, atomic::AtomicI32}, + sync::{Arc, RwLock, atomic::AtomicI64}, task, time::Duration, }; @@ -132,8 +132,8 @@ pub struct StubDriver { tx: mpsc::Sender, /// Handlers for incoming notifications per listened channel. handlers: Arc>, - attachments: Arc>>>, - attachment_idx: Arc, + attachments: Arc>>>, + attachment_idx: Arc, } impl StubDriver { @@ -155,7 +155,7 @@ impl StubDriver { tx, handlers, attachments: Arc::new(RwLock::new(HashMap::new())), - attachment_idx: Arc::new(AtomicI32::new(0)), + attachment_idx: Arc::new(AtomicI64::new(0)), }; (driver, rx, tx1) } @@ -234,7 +234,7 @@ impl Driver for StubDriver { Ok(()) } - async fn push_attachment(&self, _table: &str, attachment: &[u8]) -> Result { + async fn push_attachment(&self, _table: &str, attachment: &[u8]) -> Result { let id = self .attachment_idx .fetch_add(1, std::sync::atomic::Ordering::SeqCst); @@ -247,7 +247,7 @@ impl Driver for StubDriver { Ok(id) } - async fn get_attachment(&self, _table: &str, id: i32) -> Result, Self::Error> { + async fn get_attachment(&self, _table: &str, id: i64) -> Result, Self::Error> { Ok(self .attachments .read() From ca0da8f5f1a89028945a99b0e6a96a8559f2fc2e Mon Sep 17 00:00:00 2001 From: totodore Date: Sun, 19 Apr 2026 12:26:28 +0200 Subject: [PATCH 28/31] wip --- crates/socketioxide-postgres/src/lib.rs | 117 ++++++++++++++++-------- 1 file changed, 80 insertions(+), 37 deletions(-) diff --git a/crates/socketioxide-postgres/src/lib.rs b/crates/socketioxide-postgres/src/lib.rs index 3e0aee22..82efd07b 100644 --- a/crates/socketioxide-postgres/src/lib.rs +++ b/crates/socketioxide-postgres/src/lib.rs @@ -124,6 +124,18 @@ pub struct 
PostgresAdapterConfig { /// The duration between cleanup queries on the attachment table. pub cleanup_interval: Duration, + + /// The maximum number of concurrent attachment fetches in-flight on the notification + /// event pipeline. Default is 64. + /// + /// Incoming NOTIFY messages are processed through `.map().buffered(n).for_each()` so that + /// [`recv_req`](CustomPostgresAdapter) is always called in wire order, while attachment DB + /// round-trips overlap up to this bound. Raising it improves throughput under bursts of + /// large payloads at the cost of more in-flight memory; lowering it tightens back-pressure + /// on the LISTEN/NOTIFY pipeline. Because `buffered` preserves input order, a single slow + /// attachment stalls every subsequent request until it resolves — so keep this comfortably + /// above the typical burst size. + pub ev_buffer_size: usize, } impl PostgresAdapterConfig { @@ -188,6 +200,13 @@ impl PostgresAdapterConfig { self.cleanup_interval = cleanup_interval; self } + + /// The maximum number of concurrent attachment fetches in-flight on the notification + /// event pipeline. Default is 64. See [`PostgresAdapterConfig::ev_buffer_size`] for tradeoffs. + pub fn with_ev_buffer_size(mut self, ev_buffer_size: usize) -> Self { + self.ev_buffer_size = ev_buffer_size; + self + } } impl Default for PostgresAdapterConfig { @@ -201,6 +220,7 @@ impl Default for PostgresAdapterConfig { prefix: "socket.io".into(), payload_threshold: 8_000, cleanup_interval: Duration::from_secs(60), + ev_buffer_size: 64, } } } @@ -627,24 +647,51 @@ impl CustomPostgresAdapter { } } + /// Drive the notification stream. + /// + /// Because `buffered` preserves input order, [`recv_req`](Self::recv_req) is always called + /// in the order notifications were received on the NOTIFY channel — which is the same order + /// the producing node issued them. This matters for causal sequences from a single producer, + /// e.g. 
`AddSockets(room)` followed by `Broadcast(room)`: without ordering, a deferred + /// attachment fetch for the first request could be overtaken by an inline second request, + /// making the broadcast miss its target. + /// + /// The tradeoff of global ordering is head-of-line blocking: a single slow attachment + /// fetch delays every subsequent request on this namespace, including requests from + /// unrelated producer nodes. + /// + /// Response notifications (`handle_res_notif`) are handled synchronously inside the `map` + /// closure and produce `None` so they do not participate in the buffered pipeline. async fn handle_ev_stream(self: Arc, stream: impl Stream) { + let concurrency = std::cmp::max(self.config.ev_buffer_size, 1); + let response_chan: Arc = Arc::from(self.get_response_chan(self.local.server_id())); pin_mut!(stream); - while let Some(notif) = stream.next().await { - let chan = notif.channel(); - let result = if chan == self.get_response_chan(self.local.server_id()) { - self.handle_res_notif(notif) - } else { - self.handle_req_notif(notif) - }; - - if let Err(err) = result { - tracing::warn!(%err, "Error handling notification, skipping it"); - } - } + stream + .map(|notif| { + let this = self.clone(); + let response_chan = response_chan.clone(); + async move { + let result = if notif.channel() == &*response_chan { + this.handle_res_notif(notif).map(|_| None) + } else { + this.resolve_req_notif(notif).await + }; + + result + .inspect_err(|err| tracing::warn!(%err, "error handling notification")) + .ok() + .flatten() + } + }) + .buffered(concurrency) + .filter_map(future::ready) + .for_each(async |req| self.recv_req(req)) + .await; } - /// Deserialize a response notification and trigger the corresponding handler. - /// If the request handler queue is full, it will wait, slowing down the notification event pipeline. + /// Deserialize a response notification and forward it to the waiting ack-stream handler. 
+ /// Synchronous: no DB round-trip, attachment resolution happens on the consumer side + /// (see `get_res`/`broadcast_with_ack`). fn handle_res_notif(&self, notif: D::Notification) -> Result<(), Error> { match serde_json::from_str(notif.payload())? { p if p.is_loopback(self.local.server_id()) => { @@ -669,31 +716,27 @@ impl CustomPostgresAdapter { Ok(()) } - /// Deserialize a request notification and propagate it. In case of attachment, - /// the resolution is done asynchronously. - fn handle_req_notif(self: &Arc, notif: D::Notification) -> Result<(), Error> { - match serde_json::from_str::>(notif.payload())? { - p if p.is_loopback(self.local.server_id()) => { - tracing::trace!("skipping loopback packets"); - } + /// Parse a request notification into a [`RequestIn`]. Inline requests are decoded + /// directly; attachment-deferred requests are fetched from the attachment table. Returns + /// `None` for loopback packets or any recoverable error (decode, DB) — errors are logged + /// and the packet is skipped so a single bad notification cannot derail the pipeline. + async fn resolve_req_notif( + self: &Arc, + notif: D::Notification, + ) -> Result, Error> { + let packet = serde_json::from_str::>(notif.payload())?; + if packet.is_loopback(self.local.server_id()) { + tracing::trace!("skipping loopback packets"); + return Ok(None); + } + match packet { RequestPacket::Request { payload, .. } => { - let request = serde_json::from_str::(payload.get())?; - self.recv_req(request) - } - RequestPacket::RequestWithAttachment { id, .. } => { - let this = self.clone(); - tokio::spawn(async move { - resolve_attachment(&this.driver, &this.config.table_name, id) - .await - .map(|req| this.recv_req(req)) - .inspect_err( - |err| tracing::warn!(%err, "failed to handle request with attachment"), - ) - .ok(); - }); + Ok(Some(serde_json::from_str::(payload.get())?)) } - }; - Ok(()) + RequestPacket::RequestWithAttachment { id, .. 
} => Ok(Some( + resolve_attachment(&self.driver, &self.config.table_name, id).await?, + )), + } } fn recv_req(self: &Arc, req: RequestIn) { From adb0985b6871cd8fc6e40db90902292309c7aafd Mon Sep 17 00:00:00 2001 From: totodore Date: Sun, 19 Apr 2026 12:56:36 +0200 Subject: [PATCH 29/31] wip --- crates/socketioxide-postgres/src/lib.rs | 3 + .../tests/attachments.rs | 296 ++++++++++++++++++ crates/socketioxide-postgres/tests/fixture.rs | 168 +++++++--- 3 files changed, 430 insertions(+), 37 deletions(-) create mode 100644 crates/socketioxide-postgres/tests/attachments.rs diff --git a/crates/socketioxide-postgres/src/lib.rs b/crates/socketioxide-postgres/src/lib.rs index 82efd07b..126e954d 100644 --- a/crates/socketioxide-postgres/src/lib.rs +++ b/crates/socketioxide-postgres/src/lib.rs @@ -919,6 +919,8 @@ impl CustomPostgresAdapter { .await .map_err(Error::Driver)?; + tracing::debug!("pushed attachment {id} for req {}", req.id); + serde_json::to_string(&RequestPacket::<()>::RequestWithAttachment { node_id, id })? } else { let payload = RawValue::from_string(body)?; @@ -1112,6 +1114,7 @@ async fn resolve_attachment( .get_attachment(table_name, id) .await .map_err(Error::Driver)?; + tracing::debug!("resolving attachment {id}"); Ok(serde_json::from_slice(&bytes)?) } diff --git a/crates/socketioxide-postgres/tests/attachments.rs b/crates/socketioxide-postgres/tests/attachments.rs new file mode 100644 index 00000000..125e6e54 --- /dev/null +++ b/crates/socketioxide-postgres/tests/attachments.rs @@ -0,0 +1,296 @@ +//! Tests exercising the attachment path: small payloads stay inline, payloads above +//! `payload_threshold` are pushed to the attachment table and fetched by the receiver. +//! +//! The stub driver in `fixture.rs` tracks push/fetch activity per server, so every test +//! asserts both a functional outcome (the message arrived) and a pipeline outcome (an +//! attachment row was actually used). 
+ +use std::time::Duration; + +use socketioxide::{adapter::Adapter, extract::SocketRef}; +use socketioxide_postgres::PostgresAdapterConfig; + +mod fixture; + +const LOW_THRESHOLD: usize = 128; +const HIGH_THRESHOLD: usize = 10_000_000; + +fn filler(len: usize) -> String { + "x".repeat(len) +} + +fn low_threshold_config() -> PostgresAdapterConfig { + PostgresAdapterConfig::default().with_payload_threshold(LOW_THRESHOLD) +} + +/// With a very high threshold, no payload should ever land in the attachment table. +/// Guards against a regression that routes everything through `push_attachment`. +#[tokio::test] +async fn broadcast_below_threshold_is_inline() { + async fn handler(socket: SocketRef) { + tokio::time::sleep(Duration::from_millis(20)).await; + let msg = "a".repeat(512); + socket.broadcast().emit("test", &msg).await.unwrap(); + } + + let config = PostgresAdapterConfig::default().with_payload_threshold(HIGH_THRESHOLD); + let ([io1, io2], [h1, h2]) = fixture::spawn_servers_with_handles::<2>(config); + + io1.ns("/", handler).await.unwrap(); + io2.ns("/", handler).await.unwrap(); + + let ((_tx1, mut rx1), (_tx2, mut rx2)) = + tokio::join!(io1.new_dummy_sock("/", ()), io2.new_dummy_sock("/", ())); + + timeout_rcv!(&mut rx1); // Connect "/" packet + timeout_rcv!(&mut rx2); // Connect "/" packet + + // Both servers broadcast on connect; each remote socket should receive exactly one message. 
+ let expected = format!(r#"42["test","{}"]"#, "a".repeat(512)); + assert_eq!(timeout_rcv!(&mut rx1, 500), expected); + assert_eq!(timeout_rcv!(&mut rx2, 500), expected); + + assert_eq!(h1.push_count(), 0, "server 1 must not push attachments"); + assert_eq!(h2.push_count(), 0, "server 2 must not push attachments"); + assert_eq!(h1.fetch_count(), 0, "server 1 must not fetch attachments"); + assert_eq!(h2.fetch_count(), 0, "server 2 must not fetch attachments"); +} + +/// Golden path: a large broadcast from server 1 is stored in the attachment table, server 2 +/// fetches it and delivers it to its socket. Payload must round-trip byte-for-byte. +#[tokio::test] +async fn broadcast_above_threshold_uses_attachment() { + let _ = tracing_subscriber::fmt() + .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) + .with_test_writer() + .try_init(); + + async fn handler(socket: SocketRef) { + tokio::time::sleep(Duration::from_millis(20)).await; + let msg = "y".repeat(4096); + socket.broadcast().emit("test", &msg).await.unwrap(); + } + + let ([io1, io2], [h1, h2]) = fixture::spawn_servers_with_handles::<2>(low_threshold_config()); + + io1.ns("/", handler).await.unwrap(); + io2.ns("/", handler).await.unwrap(); + + let ((_tx1, mut rx1), (_tx2, mut rx2)) = + tokio::join!(io1.new_dummy_sock("/", ()), io2.new_dummy_sock("/", ())); + + timeout_rcv!(&mut rx1); // Connect "/" packet + timeout_rcv!(&mut rx2); // Connect "/" packet + + let expected = format!(r#"42["test","{}"]"#, "y".repeat(4096)); + assert_eq!(timeout_rcv!(&mut rx1, 500), expected); + assert_eq!(timeout_rcv!(&mut rx2, 500), expected); + + assert!(h1.push_count() >= 1, "server 1 should push an attachment"); + assert!(h2.push_count() >= 1, "server 2 should push an attachment"); + assert!( + h1.fetch_count() >= 1, + "server 1 should fetch server 2's attachment" + ); + assert!( + h2.fetch_count() >= 1, + "server 2 should fetch server 1's attachment" + ); +} + +/// A large NOTIFY the sender self-delivers (PG 
delivers NOTIFY to the emitter too) must be +/// filtered by `is_loopback` BEFORE hitting the attachment table — otherwise each server +/// does a pointless DB round-trip for every large message it itself produced. +#[tokio::test] +async fn request_with_attachment_loopback_is_filtered() { + async fn handler(socket: SocketRef) { + tokio::time::sleep(Duration::from_millis(5)).await; + let msg = "z".repeat(2048); + socket.broadcast().emit("test", &msg).await.unwrap(); + } + + let ([io1], [h1]) = fixture::spawn_servers_with_handles::<1>(low_threshold_config()); + io1.ns("/", handler).await.unwrap(); + + let (_tx1, mut rx1) = io1.new_dummy_sock("/", ()).await; + timeout_rcv!(&mut rx1); // Connect "/" packet + + // Give the handle_ev_stream loop time to consume the self-delivered NOTIFY. + tokio::time::sleep(Duration::from_millis(50)).await; + + assert!(h1.push_count() >= 1, "sender must push the attachment"); + assert_eq!( + h1.fetch_count(), + 0, + "sender must not fetch its own attachment (loopback filter)" + ); +} + +/// Large ack responses must round-trip through the attachment table too. This exercises +/// `send_res` → `ResponsePayload::Attachment` → `resolve_resp_payload` on the ack stream. 
+#[tokio::test] +async fn ack_stream_with_large_responses() { + use futures_util::StreamExt; + + async fn emitter(socket: SocketRef) { + tokio::time::sleep(Duration::from_millis(5)).await; + socket + .broadcast() + .emit_with_ack::<_, String>("test", "ping") + .await + .unwrap() + .for_each(|(_, res)| { + socket.emit("ack_res", &res.unwrap()).unwrap(); + async move {} + }) + .await; + } + + let ([io1, io2], [h1, h2]) = fixture::spawn_servers_with_handles::<2>(low_threshold_config()); + + io1.ns("/", emitter).await.unwrap(); + io2.ns("/", async || ()).await.unwrap(); + + let ((_tx1, mut rx1), (tx2, mut rx2)) = + tokio::join!(io1.new_dummy_sock("/", ()), io2.new_dummy_sock("/", ())); + + timeout_rcv!(&mut rx1); // Connect "/" packet + timeout_rcv!(&mut rx2); // Connect "/" packet + + // Server 2's socket sees the emit-with-ack request and answers with a large payload. + timeout_rcv!(&mut rx2, 100); + let big_ack = filler(4096); + let packet_res = format!(r#"431["{big_ack}"]"#).try_into().unwrap(); + tx2.try_send(packet_res).unwrap(); + + // Server 1's socket receives the forwarded ack payload. + let received = timeout_rcv!(&mut rx1, 500); + assert!( + received.contains(&big_ack), + "ack payload was truncated or lost: {received}" + ); + + assert!( + h2.push_count() >= 1, + "server 2 should push an ack attachment" + ); + assert!( + h1.fetch_count() >= 1, + "server 1 should fetch the ack attachment" + ); +} + +/// Cross-server `fetch_sockets` response payloads become large when many sockets with many +/// rooms live on the remote. Exercises the read path through `get_res`. 
+#[tokio::test] +async fn fetch_sockets_response_uses_attachment() { + let ([io1, io2], [_h1, h2]) = fixture::spawn_servers_with_handles::<2>(low_threshold_config()); + + let handler = |rooms: Vec| async move |socket: SocketRef<_>| socket.join(rooms); + + let many_rooms: Vec = (0..200).map(|i| format!("room-{i}")).collect(); + io1.ns("/", async || ()).await.unwrap(); + io2.ns("/", handler(many_rooms)).await.unwrap(); + + let (_, mut rx1) = io1.new_dummy_sock("/", ()).await; + let (_, mut rx2) = io2.new_dummy_sock("/", ()).await; + timeout_rcv!(&mut rx1); + timeout_rcv!(&mut rx2); + + let sockets = io1.fetch_sockets().await.unwrap(); + assert_eq!(sockets.len(), 2); + + assert!( + h2.push_count() >= 1, + "server 2 should push its FetchSockets response via the attachment table" + ); +} + +/// `rooms()` response path: many rooms reported by the remote exceed the threshold. +#[tokio::test] +async fn rooms_response_uses_attachment() { + let ([io1, io2], [_h1, h2]) = fixture::spawn_servers_with_handles::<2>(low_threshold_config()); + + let handler = |rooms: Vec| async move |socket: SocketRef<_>| socket.join(rooms); + + let many_rooms: Vec = (0..200).map(|i| format!("room-{i}")).collect(); + io1.ns("/", async || ()).await.unwrap(); + io2.ns("/", handler(many_rooms.clone())).await.unwrap(); + + let (_, mut rx1) = io1.new_dummy_sock("/", ()).await; + let (_, mut rx2) = io2.new_dummy_sock("/", ()).await; + timeout_rcv!(&mut rx1); + timeout_rcv!(&mut rx2); + + let rooms = io1.rooms().await.unwrap(); + assert!(rooms.len() >= many_rooms.len()); + + assert!( + h2.push_count() >= 1, + "server 2 should push its AllRooms response via the attachment table" + ); +} + +/// A corrupted attachment (decode fails) must not derail the pipeline: the poisoned packet +/// is dropped with a warning, and the next well-formed request is still processed. 
+#[tokio::test] +async fn attachment_resolution_failure_drops_packet() { + async fn on_test(socket: SocketRef) { + let _ = socket; + } + + let ([io1, io2], [h1, h2]) = fixture::spawn_servers_with_handles::<2>(low_threshold_config()); + io1.ns("/", on_test).await.unwrap(); + io2.ns("/", on_test).await.unwrap(); + + let (_tx1, mut rx1) = io1.new_dummy_sock("/", ()).await; + let (_tx2, mut rx2) = io2.new_dummy_sock("/", ()).await; + timeout_rcv!(&mut rx1); + timeout_rcv!(&mut rx2); + + // Arm server 2 to fail the next get_attachment call (the upcoming large broadcast). + h2.fail_once(); + let msg1 = "a".repeat(4096); + io1.emit("test", &msg1).await.unwrap(); + + // Server 2's socket should not receive the dropped packet within a reasonable window. + timeout_rcv_err!(&mut rx2); + + // Next broadcast must still get through — the pipeline is not poisoned. + let msg2 = "b".repeat(4096); + io1.emit("test", &msg2).await.unwrap(); + let expected = format!(r#"42["test","{}"]"#, "b".repeat(4096)); + assert_eq!(timeout_rcv!(&mut rx2, 500), expected); + + assert!(h1.push_count() >= 2, "two large broadcasts were emitted"); + assert!( + h2.fetch_count() >= 2, + "server 2 should have attempted to fetch both" + ); +} + +/// Many large requests fired back-to-back must all reach the target. Catches regressions in +/// the buffered pipeline (e.g. stream closing early, head-of-line deadlocks). 
+#[tokio::test] +async fn concurrent_large_requests_reach_target() { + const N: usize = 20; + + let ([io1, io2], _) = fixture::spawn_servers_with_handles::<2>(low_threshold_config()); + io1.ns("/", async || ()).await.unwrap(); + io2.ns("/", async || ()).await.unwrap(); + + let (_tx1, mut rx1) = io1.new_dummy_sock("/", ()).await; + let (_tx2, mut rx2) = io2.new_dummy_sock("/", ()).await; + timeout_rcv!(&mut rx1); + timeout_rcv!(&mut rx2); + + for i in 0..N { + let msg = format!("{i}:{}", "q".repeat(4096)); + io1.emit("test", &msg).await.unwrap(); + } + + for _ in 0..N { + let _ = timeout_rcv!(&mut rx2, 500); + } + timeout_rcv_err!(&mut rx2); +} diff --git a/crates/socketioxide-postgres/tests/fixture.rs b/crates/socketioxide-postgres/tests/fixture.rs index 669eba74..11448117 100644 --- a/crates/socketioxide-postgres/tests/fixture.rs +++ b/crates/socketioxide-postgres/tests/fixture.rs @@ -11,10 +11,13 @@ use socketioxide_postgres::{ }; use std::{ collections::HashMap, - convert::Infallible, + collections::HashSet, pin::Pin, str::FromStr, - sync::{Arc, RwLock, atomic::AtomicI64}, + sync::{ + Arc, RwLock, + atomic::{AtomicBool, AtomicI64}, + }, task, time::Duration, }; @@ -27,7 +30,29 @@ use socketioxide::{SocketIo, SocketIoConfig, adapter::Emitter}; pub fn spawn_servers() -> [SocketIo>; N] { let sync_buff = Arc::new(RwLock::new(Vec::with_capacity(N))); - spawn_inner(sync_buff, PostgresAdapterConfig::default()) + let (ios, _handles) = spawn_inner(sync_buff, PostgresAdapterConfig::default()); + ios +} + +/// Spawns `N` servers with a custom [`PostgresAdapterConfig`]. +pub fn spawn_servers_with_config( + config: PostgresAdapterConfig, +) -> [SocketIo>; N] { + let sync_buff = Arc::new(RwLock::new(Vec::with_capacity(N))); + let (ios, _handles) = spawn_inner(sync_buff, config); + ios +} + +/// Spawns `N` servers with a custom config AND returns handles to each stub driver so tests +/// can assert on attachment-store state (how many rows pushed, which ids were fetched, etc.). 
+pub fn spawn_servers_with_handles( + config: PostgresAdapterConfig, +) -> ( + [SocketIo>; N], + [StubDriver; N], +) { + let sync_buff = Arc::new(RwLock::new(Vec::with_capacity(N))); + spawn_inner(sync_buff, config) } /// Serialize a [`RequestOut`] in the same wire envelope the adapter emits for inline requests: @@ -49,7 +74,7 @@ pub fn spawn_buggy_servers( ) -> [SocketIo>; N] { let sync_buff = Arc::new(RwLock::new(Vec::with_capacity(N))); let config = PostgresAdapterConfig::default().with_request_timeout(timeout); - let res = spawn_inner(sync_buff.clone(), config); + let (res, _handles) = spawn_inner(sync_buff.clone(), config); // Reinject a false heartbeat request to simulate a bad number of servers. // This will trigger timeouts when expecting responses from all servers. @@ -74,35 +99,44 @@ pub fn spawn_buggy_servers( fn spawn_inner( sync_buff: Arc>, config: PostgresAdapterConfig, -) -> [SocketIo>; N] { - [0; N].map(|_| { - let server_id = Uid::new(); - let (driver, mut rx, tx) = StubDriver::new(server_id); - - // pipe messages to all other servers - sync_buff.write().unwrap().push((server_id, tx)); - let sync_buff = sync_buff.clone(); - tokio::spawn(async move { - while let Some(notif) = rx.recv().await { - tracing::debug!("received notify on channel {:?}", notif.channel); - for (sid, tx) in sync_buff.read().unwrap().iter() { - if *sid != server_id { - tracing::debug!("forwarding notify to server {:?}", sid); - tx.try_send(notif.clone()).unwrap(); +) -> ( + [SocketIo>; N], + [StubDriver; N], +) { + let attachments = Arc::new(RwLock::new(RemoteTable::default())); + let (ios, handles) = [0; N] + .map(|_| { + let server_id = Uid::new(); + let (driver, mut rx, tx) = StubDriver::new(server_id, attachments.clone()); + + // pipe messages to all other servers + sync_buff.write().unwrap().push((server_id, tx)); + let sync_buff = sync_buff.clone(); + tokio::spawn(async move { + while let Some(notif) = rx.recv().await { + tracing::debug!("received notify on channel {:?}", 
notif.channel); + for (sid, tx) in sync_buff.read().unwrap().iter() { + if *sid != server_id { + tracing::debug!("forwarding notify to server {:?}", sid); + tx.try_send(notif.clone()).unwrap(); + } } } - } - }); - - let adapter = PostgresAdapterCtr::new_with_driver(driver, config.clone()); - let mut config = SocketIoConfig::default(); - config.server_id = server_id; - let (_svc, io) = SocketIo::builder() - .with_config(config) - .with_adapter::>(adapter) - .build_svc(); - io - }) + }); + + let adapter = PostgresAdapterCtr::new_with_driver(driver.clone(), config.clone()); + let mut config = SocketIoConfig::default(); + config.server_id = server_id; + let (_svc, io) = SocketIo::builder() + .with_config(config) + .with_adapter::>(adapter) + .build_svc(); + (io, driver) + }) + .into_iter() + .collect::<(Vec<_>, Vec<_>)>(); + + (ios.try_into().unwrap(), handles.try_into().unwrap()) } type NotifyHandlers = Vec<(Uid, mpsc::Sender)>; @@ -125,6 +159,12 @@ impl Notification for StubNotification { type Handlers = Vec<(String, mpsc::Sender)>; +#[derive(Debug, Default)] +pub struct RemoteTable { + table: HashMap>, + idx: AtomicI64, +} + #[derive(Debug, Clone)] pub struct StubDriver { server_id: Uid, @@ -132,13 +172,26 @@ pub struct StubDriver { tx: mpsc::Sender, /// Handlers for incoming notifications per listened channel. handlers: Arc>, - attachments: Arc>>>, - attachment_idx: Arc, + attachments: Arc>, + /// Ids passed to `get_attachment`, in call order. Used by tests to assert on fetch + /// activity (e.g. that a loopback large NOTIFY does not trigger a fetch). + fetched_ids: Arc>>, + /// Ids that, when fetched, return garbage bytes — triggers decode failure in the adapter. + corrupt_ids: Arc>>, + /// One-shot flag: if true, the next `get_attachment` returns a driver error. 
+ fail_next_get: Arc, +} + +#[derive(Debug, thiserror::Error)] +pub enum StubError { + #[error("injected stub driver failure")] + Injected, } impl StubDriver { pub fn new( server_id: Uid, + attachments: Arc>, ) -> ( Self, mpsc::Receiver, @@ -154,11 +207,37 @@ impl StubDriver { server_id, tx, handlers, - attachments: Arc::new(RwLock::new(HashMap::new())), - attachment_idx: Arc::new(AtomicI64::new(0)), + attachments, + fetched_ids: Arc::new(RwLock::new(Vec::new())), + corrupt_ids: Arc::new(RwLock::new(HashSet::new())), + fail_next_get: Arc::new(AtomicBool::new(false)), }; (driver, rx, tx1) } + + /// Number of attachment rows this server has written (i.e. sent via NOTIFY). + pub fn push_count(&self) -> usize { + self.attachments.read().unwrap().table.len() + } + /// Number of `get_attachment` calls this server has served. + pub fn fetch_count(&self) -> usize { + self.fetched_ids.read().unwrap().len() + } + /// All ids `get_attachment` was called with, in call order. + pub fn fetched(&self) -> Vec { + self.fetched_ids.read().unwrap().clone() + } + /// Mark an id as corrupt: next `get_attachment(id)` returns a garbage payload so the + /// decoder fails. Exercises the silent-drop path. + pub fn corrupt(&self, id: i64) { + self.corrupt_ids.write().unwrap().insert(id); + } + /// One-shot: the next `get_attachment` call returns a driver error (triggers the + /// attachment-resolution-failure path). + pub fn fail_once(&self) { + self.fail_next_get + .store(true, std::sync::atomic::Ordering::SeqCst); + } } /// Pipe incoming notifications to the matching channel handlers. 
@@ -192,7 +271,7 @@ impl Stream for NotificationStream { } impl Driver for StubDriver { - type Error = Infallible; + type Error = StubError; type Notification = StubNotification; type NotificationStream = NotificationStream; @@ -236,22 +315,37 @@ impl Driver for StubDriver { async fn push_attachment(&self, _table: &str, attachment: &[u8]) -> Result { let id = self - .attachment_idx + .attachments + .read() + .unwrap() + .idx .fetch_add(1, std::sync::atomic::Ordering::SeqCst); self.attachments .write() .unwrap() + .table .insert(id, attachment.to_vec()); Ok(id) } async fn get_attachment(&self, _table: &str, id: i64) -> Result, Self::Error> { + self.fetched_ids.write().unwrap().push(id); + if self + .fail_next_get + .swap(false, std::sync::atomic::Ordering::SeqCst) + { + return Err(StubError::Injected); + } + if self.corrupt_ids.read().unwrap().contains(&id) { + return Ok(b"not a valid request".to_vec()); + } Ok(self .attachments .read() .unwrap() + .table .get(&id) .cloned() .unwrap_or_default()) From 82beea63c62c9b33db2df562e07170096ed5c481 Mon Sep 17 00:00:00 2001 From: totodore Date: Sun, 19 Apr 2026 13:31:38 +0200 Subject: [PATCH 30/31] wip --- crates/socketioxide-postgres/src/lib.rs | 102 +++++++++++++++++------- 1 file changed, 73 insertions(+), 29 deletions(-) diff --git a/crates/socketioxide-postgres/src/lib.rs b/crates/socketioxide-postgres/src/lib.rs index 126e954d..153805eb 100644 --- a/crates/socketioxide-postgres/src/lib.rs +++ b/crates/socketioxide-postgres/src/lib.rs @@ -231,9 +231,17 @@ pub enum Error { /// Postgres driver error #[error("driver error: {0}")] Driver(D::Error), - /// Packet encoding/decoding error - #[error("packet decoding error: {0}")] - Serde(#[from] serde_json::Error), + /// Json Packet encoding/decoding error + #[error("json packet encoding/decoding error: {0}")] + Json(#[from] serde_json::Error), + + /// Binary packet encoding error + #[error("binary packet encoding error: {0}")] + MessagePackEncode(#[from] 
rmp_serde::encode::Error), + /// Binary packet decoding error + #[error("binary packet decoding error: {0}")] + MessagePackDecode(#[from] rmp_serde::decode::Error), + /// Response handler not found/full/closed for request #[error("response handler not found/full/closed for request: {req_id}")] ResponseHandlerNotFound { @@ -733,8 +741,8 @@ impl CustomPostgresAdapter { RequestPacket::Request { payload, .. } => { Ok(Some(serde_json::from_str::(payload.get())?)) } - RequestPacket::RequestWithAttachment { id, .. } => Ok(Some( - resolve_attachment(&self.driver, &self.config.table_name, id).await?, + RequestPacket::RequestWithAttachment { id, is_binary, .. } => Ok(Some( + resolve_attachment(&self.driver, &self.config.table_name, id, is_binary).await?, )), } } @@ -910,19 +918,34 @@ impl CustomPostgresAdapter { }; let node_id = self.local.server_id(); - let body = serde_json::to_string(&req)?; + let is_binary = req.is_binary(); + let body = if is_binary { + rmp_serde::to_vec(&req)? + } else { + serde_json::to_vec(&req)? + }; - let payload = if body.len() > self.config.payload_threshold { + let payload = if body.len() > self.config.payload_threshold || is_binary { let id = self .driver - .push_attachment(&self.config.table_name, body.as_bytes()) + .push_attachment(&self.config.table_name, &body) .await .map_err(Error::Driver)?; tracing::debug!("pushed attachment {id} for req {}", req.id); - serde_json::to_string(&RequestPacket::<()>::RequestWithAttachment { node_id, id })? + serde_json::to_string(&RequestPacket::<()>::RequestWithAttachment { + node_id, + is_binary, + id, + })? } else { + assert!( + !is_binary, + "binary packets should be stored in attachment table and serialized in msgpack" + ); + + let body = unsafe { String::from_utf8_unchecked(body) }; let payload = RawValue::from_string(body)?; serde_json::to_string(&RequestPacket::Request { node_id, payload })? 
}; @@ -943,27 +966,34 @@ impl CustomPostgresAdapter { &self, req_id: Sid, req_origin: Uid, - payload: Response, + res: Response, ) -> impl Future>> + 'static { - tracing::trace!( - ?payload, - "sending response for {req_id} req to {req_origin}" - ); + tracing::trace!(?res, "sending response for {req_id} req to {req_origin}"); let driver = self.driver.clone(); let chan = self.get_response_chan(req_origin); let table = self.config.table_name.clone(); let threshold = self.config.payload_threshold; let node_id = self.local.server_id(); - + let is_binary = res.is_binary(); async move { - let body = serde_json::to_string(&payload)?; - let payload = if body.len() > threshold { + let body = if is_binary { + rmp_serde::to_vec(&res)? + } else { + serde_json::to_vec(&res)? + }; + + let payload = if body.len() > threshold || is_binary { let id = driver - .push_attachment(&table, body.as_bytes()) + .push_attachment(&table, &body) .await .map_err(Error::Driver)?; - ResponsePayload::Attachment(id) + ResponsePayload::Attachment { id, is_binary } } else { + assert!( + !is_binary, + "binary packets should be stored in attachment table and serialized in msgpack" + ); + let body = unsafe { String::from_utf8_unchecked(body) }; ResponsePayload::Data(RawValue::from_string(body)?) 
}; @@ -1008,8 +1038,8 @@ impl CustomPostgresAdapter { ResponsePayload::Data(data) => serde_json::from_str::>(data.get()) .inspect_err(|err| tracing::warn!("error decoding response: {err}")) .ok(), - ResponsePayload::Attachment(id) => { - resolve_attachment(&self.driver, &self.config.table_name, id) + ResponsePayload::Attachment { id, is_binary } => { + resolve_attachment(&self.driver, &self.config.table_name, id, is_binary) .await .inspect_err(|err| tracing::warn!("error fetching attachment: {err}")) .ok() @@ -1098,10 +1128,12 @@ async fn resolve_resp_payload( ) -> Option> { match payload { ResponsePayload::Data(data) => Some(data), - ResponsePayload::Attachment(id) => resolve_attachment(&driver, &table, id) - .await - .inspect_err(|err| tracing::warn!(%err, id, "failed to resolve payload attachment")) - .ok(), + ResponsePayload::Attachment { id, is_binary } => { + resolve_attachment(&driver, &table, id, is_binary) + .await + .inspect_err(|err| tracing::warn!(%err, id, "failed to resolve payload attachment")) + .ok() + } } } @@ -1109,13 +1141,18 @@ async fn resolve_attachment( driver: &D, table_name: &str, id: i64, + is_binary: bool, ) -> Result> { let bytes = driver .get_attachment(table_name, id) .await .map_err(Error::Driver)?; tracing::debug!("resolving attachment {id}"); - Ok(serde_json::from_slice(&bytes)?) + if is_binary { + Ok(rmp_serde::from_slice(&bytes)?) + } else { + Ok(serde_json::from_slice(&bytes)?) + } } /// Wire-level wrapper for request NOTIFY payloads. @@ -1125,8 +1162,15 @@ async fn resolve_attachment( /// filter out loopback notifications before hitting the database. 
#[derive(Debug, Serialize, Deserialize)] enum RequestPacket { - Request { node_id: Uid, payload: T }, - RequestWithAttachment { node_id: Uid, id: i64 }, + Request { + node_id: Uid, + payload: T, + }, + RequestWithAttachment { + node_id: Uid, + is_binary: bool, + id: i64, + }, } impl RequestPacket { fn is_loopback(&self, node_id: Uid) -> bool { @@ -1151,5 +1195,5 @@ impl ResponsePacket { #[derive(Debug, Deserialize, Serialize)] pub(crate) enum ResponsePayload { Data(Box), - Attachment(i64), + Attachment { id: i64, is_binary: bool }, } From 98b9218bae216215ca74a1a7005f8a9cd0185dd8 Mon Sep 17 00:00:00 2001 From: totodore Date: Sun, 19 Apr 2026 13:46:49 +0200 Subject: [PATCH 31/31] wip --- .../socketioxide-postgres/src/drivers/mod.rs | 9 +++++ .../socketioxide-postgres/src/drivers/sqlx.rs | 22 +++++++++++ .../src/drivers/tokio_postgres.rs | 21 +++++++++- crates/socketioxide-postgres/src/lib.rs | 25 +++++++++++- crates/socketioxide-postgres/tests/fixture.rs | 39 ++++++++++++++++--- 5 files changed, 109 insertions(+), 7 deletions(-) diff --git a/crates/socketioxide-postgres/src/drivers/mod.rs b/crates/socketioxide-postgres/src/drivers/mod.rs index 09f63ead..8533fffc 100644 --- a/crates/socketioxide-postgres/src/drivers/mod.rs +++ b/crates/socketioxide-postgres/src/drivers/mod.rs @@ -1,6 +1,8 @@ //! Drivers are an abstraction over the PostgreSQL LISTEN/NOTIFY backend used by the adapter. //! You can use the provided implementation or implement your own. +use std::time::Duration; + use futures_core::Stream; /// A driver implementation for the [`sqlx`](https://docs.rs/sqlx) PostgreSQL backend. @@ -53,6 +55,13 @@ pub trait Driver: Clone + Send + Sync + 'static { id: i64, ) -> impl Future, Self::Error>> + Send; + /// Cleanup attachments older than a certain timestamp. + fn cleanup_attachments( + &self, + table: &str, + interval: Duration, + ) -> impl Future> + Send; + /// UNLISTEN from every channel. 
fn close(&self) -> impl Future> + Send; } diff --git a/crates/socketioxide-postgres/src/drivers/sqlx.rs b/crates/socketioxide-postgres/src/drivers/sqlx.rs index e5680079..9833aea7 100644 --- a/crates/socketioxide-postgres/src/drivers/sqlx.rs +++ b/crates/socketioxide-postgres/src/drivers/sqlx.rs @@ -1,3 +1,5 @@ +use std::time::Duration; + use futures_core::stream::BoxStream; use futures_util::StreamExt; use sqlx::{ @@ -89,6 +91,26 @@ impl Driver for SqlxDriver { Ok(attachment) } + async fn cleanup_attachments( + &self, + table: &str, + interval: Duration, + ) -> Result<(), Self::Error> { + let query = format!( + "DELETE FROM \"{table}\" WHERE created_at < now() - interval '{} milliseconds'", + interval.as_millis() + ); + + let affected = sqlx::query(&query) + .execute(&self.client) + .await? + .rows_affected(); + + tracing::debug!(affected, "pruned attachments"); + + Ok(()) + } + async fn close(&self) -> Result<(), Self::Error> { // PgListener will automatically unlisten channel when being dropped Ok(()) diff --git a/crates/socketioxide-postgres/src/drivers/tokio_postgres.rs b/crates/socketioxide-postgres/src/drivers/tokio_postgres.rs index a89308ac..92f86578 100644 --- a/crates/socketioxide-postgres/src/drivers/tokio_postgres.rs +++ b/crates/socketioxide-postgres/src/drivers/tokio_postgres.rs @@ -1,4 +1,7 @@ -use std::sync::{Arc, RwLock}; +use std::{ + sync::{Arc, RwLock}, + time::Duration, +}; use futures_util::{StreamExt, sink, stream}; use tokio::sync::mpsc; @@ -135,6 +138,22 @@ impl Driver for TokioPostgresDriver { self.client.query_one_scalar(&query, &[&id]).await } + async fn cleanup_attachments( + &self, + table: &str, + interval: Duration, + ) -> Result<(), Self::Error> { + let query = format!( + "DELETE FROM \"{table}\" WHERE created_at < now() - interval '{} milliseconds'", + interval.as_millis() + ); + + let affected = self.client.execute(&query, &[]).await?; + tracing::debug!(affected, "pruned attachments"); + + Ok(()) + } + async fn close(&self) -> 
Result<(), Self::Error> { self.client.execute("UNLISTEN *", &[]).await?; Ok(()) diff --git a/crates/socketioxide-postgres/src/lib.rs b/crates/socketioxide-postgres/src/lib.rs index 153805eb..a3abfc0d 100644 --- a/crates/socketioxide-postgres/src/lib.rs +++ b/crates/socketioxide-postgres/src/lib.rs @@ -357,9 +357,10 @@ pub struct CustomPostgresAdapter { nodes_liveness: Mutex>, /// A map of response handlers used to await for responses from the remote servers. responses: Arc>, - /// A task that listens for events from the remote servers. + ev_stream_task: OnceLock, hb_task: OnceLock, + cleanup_task: OnceLock, } impl DefinedAdapter for CustomPostgresAdapter {} @@ -378,6 +379,7 @@ impl CoreAdapter for CustomPostgresAdapter responses: Arc::new(Mutex::new(HashMap::new())), ev_stream_task: OnceLock::new(), hb_task: OnceLock::new(), + cleanup_task: OnceLock::new(), } } @@ -406,6 +408,11 @@ impl CoreAdapter for CustomPostgresAdapter self.hb_task.set(hb_task).is_ok(), "Adapter::init should be called only once" ); + let cleanup_task = tokio::spawn(self.clone().cleanup_job()).abort_handle(); + assert!( + self.cleanup_task.set(cleanup_task).is_ok(), + "Adapter::init should be called only once" + ); // Send initial heartbeat when starting. self.emit_init_heartbeat().await.map_err(|e| match e { @@ -655,6 +662,22 @@ impl CustomPostgresAdapter { } } + async fn cleanup_job(self: Arc) { + let mut interval = tokio::time::interval(self.config.cleanup_interval); + interval.tick().await; // first tick yields immediately + + loop { + interval.tick().await; + if let Err(err) = self + .driver + .cleanup_attachments(&self.config.table_name, self.config.cleanup_interval) + .await + { + tracing::warn!(%err, "failed to cleanup attachment table"); + } + } + } + /// Drive the notification stream. 
/// /// Because `buffered` preserves input order, [`recv_req`](Self::recv_req) is always called diff --git a/crates/socketioxide-postgres/tests/fixture.rs b/crates/socketioxide-postgres/tests/fixture.rs index 11448117..f5135a1c 100644 --- a/crates/socketioxide-postgres/tests/fixture.rs +++ b/crates/socketioxide-postgres/tests/fixture.rs @@ -10,8 +10,7 @@ use socketioxide_postgres::{ drivers::{Driver, Notification}, }; use std::{ - collections::HashMap, - collections::HashSet, + collections::{HashMap, HashSet}, pin::Pin, str::FromStr, sync::{ @@ -19,7 +18,7 @@ use std::{ atomic::{AtomicBool, AtomicI64}, }, task, - time::Duration, + time::{Duration, Instant}, }; use tokio::sync::mpsc; @@ -161,9 +160,24 @@ type Handlers = Vec<(String, mpsc::Sender)>; #[derive(Debug, Default)] pub struct RemoteTable { - table: HashMap>, + table: HashMap, idx: AtomicI64, } +#[derive(Debug, Clone)] +pub struct Row { + id: i64, + data: Vec, + created_at: Instant, +} +impl Row { + fn new(id: i64, data: Vec) -> Self { + Self { + id, + data, + created_at: Instant::now(), + } + } +} #[derive(Debug, Clone)] pub struct StubDriver { @@ -325,7 +339,7 @@ impl Driver for StubDriver { .write() .unwrap() .table - .insert(id, attachment.to_vec()); + .insert(id, Row::new(id, attachment.to_vec())); Ok(id) } @@ -348,9 +362,24 @@ impl Driver for StubDriver { .table .get(&id) .cloned() + .map(|v| v.data) .unwrap_or_default()) } + async fn cleanup_attachments( + &self, + _table: &str, + interval: Duration, + ) -> Result<(), Self::Error> { + self.attachments + .write() + .unwrap() + .table + .retain(|_, v| v.created_at.elapsed() < interval); + + Ok(()) + } + async fn close(&self) -> Result<(), Self::Error> { Ok(()) }