diff --git a/libsql/src/connection.rs b/libsql/src/connection.rs index 2bca312500..396f2e1f18 100644 --- a/libsql/src/connection.rs +++ b/libsql/src/connection.rs @@ -13,6 +13,15 @@ use crate::{Result, TransactionBehavior}; pub type AuthHook = Arc Authorization>; +pub type UpdateHook = dyn Fn(Op, &str, &str, i64) + Send + Sync; + +#[derive(Clone, Copy, Debug, PartialEq)] +pub enum Op { + Insert = 0, + Delete = 1, + Update = 2, +} + #[async_trait::async_trait] pub(crate) trait Conn { async fn execute(&self, sql: &str, params: Params) -> Result; @@ -58,6 +67,10 @@ pub(crate) trait Conn { fn authorizer(&self, _hook: Option) -> Result<()> { Err(crate::Error::AuthorizerNotSupported) } + + fn add_update_hook(&self, _cb: Box) -> Result<()> { + Err(crate::Error::UpdateHookNotSupported) + } } /// A set of rows returned from `execute_batch`/`execute_transactional_batch`. It is essentially @@ -285,6 +298,10 @@ impl Connection { pub fn authorizer(&self, hook: Option) -> Result<()> { self.conn.authorizer(hook) } + + pub fn add_update_hook(&self, cb: Box) -> Result<()> { + self.conn.add_update_hook(cb) + } } impl fmt::Debug for Connection { diff --git a/libsql/src/errors.rs b/libsql/src/errors.rs index 069e5fd5cd..f612aa62f1 100644 --- a/libsql/src/errors.rs +++ b/libsql/src/errors.rs @@ -23,6 +23,8 @@ pub enum Error { LoadExtensionNotSupported, // Not in rusqlite #[error("Authorizer is only supported in local databases.")] AuthorizerNotSupported, // Not in rusqlite + #[error("Update hooks are only supported in local databases.")] + UpdateHookNotSupported, // Not in rusqlite #[error("Column not found: {0}")] ColumnNotFound(i32), // Not in rusqlite #[error("Hrana: `{0}`")] diff --git a/libsql/src/lib.rs b/libsql/src/lib.rs index a42b0a4940..13cf082eb7 100644 --- a/libsql/src/lib.rs +++ b/libsql/src/lib.rs @@ -180,7 +180,7 @@ cfg_hrana! { pub use self::{ auth::{AuthAction, AuthContext, Authorization}, - connection::{AuthHook, BatchRows, Connection}, + connection::{AuthHook, BatchRows, Connection, Op}, database::{Builder, Database}, load_extension_guard::LoadExtensionGuard, rows::{Column, Row, Rows}, diff --git a/libsql/src/local/connection.rs b/libsql/src/local/connection.rs index ba10da63e4..7012651699 100644 --- a/libsql/src/local/connection.rs +++ b/libsql/src/local/connection.rs @@ -4,7 +4,10 @@ use crate::auth::{AuthAction, AuthContext, Authorization}; use crate::connection::AuthHook; use crate::local::rows::BatchedRows; use crate::params::Params; -use crate::{connection::BatchRows, errors}; +use crate::{ + connection::{BatchRows, Op, UpdateHook}, + errors, +}; use std::time::Duration; use super::{Database, Error, Result, Rows, RowsFuture, Statement, Transaction}; @@ -15,6 +18,10 @@ use libsql_sys::ffi; use parking_lot::RwLock; use std::{ffi::c_int, fmt, path::Path, sync::Arc}; +struct Container { + cb: Box, +} + /// A connection to a libSQL database. #[derive(Clone)] pub struct Connection { @@ -400,6 +407,24 @@ impl Connection { }) } + /// Installs update hook + pub fn add_update_hook(&self, cb: Box) { + let c = Box::new(Container { cb }); + let ptr: *mut Container = std::ptr::from_mut(Box::leak(c)); + + let old_data = unsafe { + ffi::sqlite3_update_hook( + self.raw, + Some(update_hook_cb), + ptr as *mut ::std::os::raw::c_void, + ) + }; + + if !old_data.is_null() { + let _ = unsafe { Box::from_raw(old_data as *mut Container) }; + } + } + pub fn enable_load_extension(&self, onoff: bool) -> Result<()> { // SQLITE_DBCONFIG_ENABLE_LOAD_EXTENSION configration verb accepts 2 additional parameters: an on/off flag and a pointer to an c_int where new state of the parameter will be written (or NULL if reporting back the setting is not needed) // See: https://sqlite.org/c3ref/c_dbconfig_defensive.html#sqlitedbconfigenableloadextension @@ -464,7 +489,8 @@ impl Connection { pub fn authorizer(&self, hook: Option) -> Result<()> { unsafe { - let rc = libsql_sys::ffi::sqlite3_set_authorizer(self.handle(), None, std::ptr::null_mut()); + let rc = + libsql_sys::ffi::sqlite3_set_authorizer(self.handle(), None, std::ptr::null_mut()); if rc != ffi::SQLITE_OK { return Err(crate::errors::Error::SqliteFailure( rc as std::ffi::c_int, @@ -484,7 +510,8 @@ impl Connection { None => (None, std::ptr::null_mut()), }; - let rc = unsafe { libsql_sys::ffi::sqlite3_set_authorizer(self.handle(), callback, user_data) }; + let rc = + unsafe { libsql_sys::ffi::sqlite3_set_authorizer(self.handle(), callback, user_data) }; if rc != ffi::SQLITE_OK { return Err(crate::errors::Error::SqliteFailure( rc as std::ffi::c_int, @@ -716,7 +743,7 @@ unsafe extern "C" fn authorizer_callback( pub(crate) struct WalInsertHandle<'a> { conn: &'a Connection, - in_session: RwLock + in_session: RwLock, } impl WalInsertHandle<'_> { @@ -761,6 +788,28 @@ impl fmt::Debug for Connection { } } +#[no_mangle] +extern "C" fn update_hook_cb( + data: *mut ::std::os::raw::c_void, + op: ::std::os::raw::c_int, + db_name: *const ::std::os::raw::c_char, + table_name: *const ::std::os::raw::c_char, + row_id: i64, +) { + let db = unsafe { std::ffi::CStr::from_ptr(db_name).to_string_lossy() }; + let table = unsafe { std::ffi::CStr::from_ptr(table_name).to_string_lossy() }; + + let c = unsafe { &mut *(data as *mut Container) }; + let o = match op { + 9 => Op::Delete, + 18 => Op::Insert, + 23 => Op::Update, + _ => unreachable!("Unknown operation {op}"), + }; + + (*c.cb)(o, &db, &table, row_id); +} + #[cfg(test)] mod tests { use crate::{ diff --git a/libsql/src/local/impls.rs b/libsql/src/local/impls.rs index 26b8cd0575..b86405610e 100644 --- a/libsql/src/local/impls.rs +++ b/libsql/src/local/impls.rs @@ -1,10 +1,9 @@ use std::sync::Arc; -use std::{fmt, path::Path}; use std::time::Duration; +use std::{fmt, path::Path}; -use crate::connection::BatchRows; use crate::{ - connection::{AuthHook, Conn}, + connection::{AuthHook, BatchRows, Conn, UpdateHook}, params::Params, rows::{ColumnsInner, RowInner, RowsInner}, statement::Stmt, @@ -100,6 +99,10 @@ impl Conn for LibsqlConnection { fn authorizer(&self, hook: Option) -> Result<()> { self.conn.authorizer(hook) } + + fn add_update_hook(&self, cb: Box) -> Result<()> { + Ok(self.conn.add_update_hook(cb)) + } } impl Drop for LibsqlConnection { diff --git a/libsql/tests/integration_tests.rs b/libsql/tests/integration_tests.rs index 57addab948..697ac220ef 100644 --- a/libsql/tests/integration_tests.rs +++ b/libsql/tests/integration_tests.rs @@ -4,12 +4,12 @@ use futures::{StreamExt, TryStreamExt}; use libsql::{ named_params, params, params::{IntoParams, IntoValue}, - AuthAction, Authorization, Connection, Database, Result, Value, + AuthAction, Authorization, Connection, Database, Op, Result, Value, }; use rand::distributions::Uniform; use rand::prelude::*; use std::collections::HashSet; -use std::sync::Arc; +use std::sync::{Arc, Mutex}; async fn setup() -> Connection { let db = Database::open(":memory:").unwrap(); @@ -28,6 +28,77 @@ async fn enable_disable_extension() { conn.load_extension_disable().unwrap(); } +#[tokio::test] +async fn add_update_hook() { + let conn = setup().await; + + #[derive(PartialEq, Debug)] + struct Data { + op: Op, + db: String, + table: String, + row_id: i64, + } + + let d = Arc::new(Mutex::new(None::)); + + let d_clone = d.clone(); + conn.add_update_hook(Box::new(move |op, db, table, row_id| { + *d_clone.lock().unwrap() = Some(Data { + op, + db: db.to_string(), + table: table.to_string(), + row_id, + }); + })) + .unwrap(); + + let _ = conn + .execute("INSERT INTO users (id, name) VALUES (2, 'Alice')", ()) + .await + .unwrap(); + + assert_eq!( + *d.lock().unwrap().as_ref().unwrap(), + Data { + op: Op::Insert, + db: "main".to_string(), + table: "users".to_string(), + row_id: 1, + } + ); + + let _ = conn + .execute("UPDATE users SET name = 'Bob' WHERE id = 2", ()) + .await + .unwrap(); + + assert_eq!( + *d.lock().unwrap().as_ref().unwrap(), + Data { + op: Op::Update, + db: "main".to_string(), + table: "users".to_string(), + row_id: 1, + } + ); + + let _ = conn + .execute("DELETE FROM users WHERE id = 2", ()) + .await + .unwrap(); + + assert_eq!( + *d.lock().unwrap().as_ref().unwrap(), + Data { + op: Op::Delete, + db: "main".to_string(), + table: "users".to_string(), + row_id: 1, + } + ); +} + #[tokio::test] async fn connection_drops_before_statements() { let db = Database::open(":memory:").unwrap();