rusqlite/hooks/
mod.rs

1//! Commit, Data Change and Rollback Notification Callbacks
2#![expect(non_camel_case_types)]
3
4use std::ffi::{c_char, c_int, c_void, CStr};
5use std::panic::catch_unwind;
6use std::ptr;
7
8use crate::ffi;
9
10use crate::{error::decode_result_raw, Connection, InnerConnection, Result};
11
12#[cfg(feature = "preupdate_hook")]
13pub use preupdate_hook::*;
14
15#[cfg(feature = "preupdate_hook")]
16mod preupdate_hook;
17
18/// Action Codes
19#[derive(Clone, Copy, Debug, Eq, PartialEq)]
20#[repr(i32)]
21#[non_exhaustive]
22pub enum Action {
23    /// Unsupported / unexpected action
24    UNKNOWN = -1,
25    /// DELETE command
26    SQLITE_DELETE = ffi::SQLITE_DELETE,
27    /// INSERT command
28    SQLITE_INSERT = ffi::SQLITE_INSERT,
29    /// UPDATE command
30    SQLITE_UPDATE = ffi::SQLITE_UPDATE,
31}
32
33impl From<i32> for Action {
34    #[inline]
35    fn from(code: i32) -> Self {
36        match code {
37            ffi::SQLITE_DELETE => Self::SQLITE_DELETE,
38            ffi::SQLITE_INSERT => Self::SQLITE_INSERT,
39            ffi::SQLITE_UPDATE => Self::SQLITE_UPDATE,
40            _ => Self::UNKNOWN,
41        }
42    }
43}
44
45/// The context received by an authorizer hook.
46///
47/// See <https://sqlite.org/c3ref/set_authorizer.html> for more info.
48#[derive(Clone, Copy, Debug, Eq, PartialEq)]
49pub struct AuthContext<'c> {
50    /// The action to be authorized.
51    pub action: AuthAction<'c>,
52
53    /// The database name, if applicable.
54    pub database_name: Option<&'c str>,
55
56    /// The inner-most trigger or view responsible for the access attempt.
57    /// `None` if the access attempt was made by top-level SQL code.
58    pub accessor: Option<&'c str>,
59}
60
61/// Actions and arguments found within a statement during
62/// preparation.
63///
64/// See <https://sqlite.org/c3ref/c_alter_table.html> for more info.
65#[derive(Clone, Copy, Debug, Eq, PartialEq)]
66#[non_exhaustive]
67#[allow(missing_docs)]
68pub enum AuthAction<'c> {
69    /// This variant is not normally produced by SQLite. You may encounter it
70    // if you're using a different version than what's supported by this library.
71    Unknown {
72        /// The unknown authorization action code.
73        code: i32,
74        /// The third arg to the authorizer callback.
75        arg1: Option<&'c str>,
76        /// The fourth arg to the authorizer callback.
77        arg2: Option<&'c str>,
78    },
79    CreateIndex {
80        index_name: &'c str,
81        table_name: &'c str,
82    },
83    CreateTable {
84        table_name: &'c str,
85    },
86    CreateTempIndex {
87        index_name: &'c str,
88        table_name: &'c str,
89    },
90    CreateTempTable {
91        table_name: &'c str,
92    },
93    CreateTempTrigger {
94        trigger_name: &'c str,
95        table_name: &'c str,
96    },
97    CreateTempView {
98        view_name: &'c str,
99    },
100    CreateTrigger {
101        trigger_name: &'c str,
102        table_name: &'c str,
103    },
104    CreateView {
105        view_name: &'c str,
106    },
107    Delete {
108        table_name: &'c str,
109    },
110    DropIndex {
111        index_name: &'c str,
112        table_name: &'c str,
113    },
114    DropTable {
115        table_name: &'c str,
116    },
117    DropTempIndex {
118        index_name: &'c str,
119        table_name: &'c str,
120    },
121    DropTempTable {
122        table_name: &'c str,
123    },
124    DropTempTrigger {
125        trigger_name: &'c str,
126        table_name: &'c str,
127    },
128    DropTempView {
129        view_name: &'c str,
130    },
131    DropTrigger {
132        trigger_name: &'c str,
133        table_name: &'c str,
134    },
135    DropView {
136        view_name: &'c str,
137    },
138    Insert {
139        table_name: &'c str,
140    },
141    Pragma {
142        pragma_name: &'c str,
143        /// The pragma value, if present (e.g., `PRAGMA name = value;`).
144        pragma_value: Option<&'c str>,
145    },
146    Read {
147        table_name: &'c str,
148        column_name: &'c str,
149    },
150    Select,
151    Transaction {
152        operation: TransactionOperation,
153    },
154    Update {
155        table_name: &'c str,
156        column_name: &'c str,
157    },
158    Attach {
159        filename: &'c str,
160    },
161    Detach {
162        database_name: &'c str,
163    },
164    AlterTable {
165        database_name: &'c str,
166        table_name: &'c str,
167    },
168    Reindex {
169        index_name: &'c str,
170    },
171    Analyze {
172        table_name: &'c str,
173    },
174    CreateVtable {
175        table_name: &'c str,
176        module_name: &'c str,
177    },
178    DropVtable {
179        table_name: &'c str,
180        module_name: &'c str,
181    },
182    Function {
183        function_name: &'c str,
184    },
185    Savepoint {
186        operation: TransactionOperation,
187        savepoint_name: &'c str,
188    },
189    Recursive,
190}
191
impl<'c> AuthAction<'c> {
    /// Build an [`AuthAction`] from the raw arguments of an SQLite authorizer
    /// callback: the action `code` and its two (nullable) string arguments.
    ///
    /// Each arm only matches when the strings its variant needs are present.
    /// Any other combination, including an unrecognized `code`, falls through
    /// to [`AuthAction::Unknown`] with the raw values preserved.
    fn from_raw(code: i32, arg1: Option<&'c str>, arg2: Option<&'c str>) -> Self {
        match (code, arg1, arg2) {
            (ffi::SQLITE_CREATE_INDEX, Some(index_name), Some(table_name)) => Self::CreateIndex {
                index_name,
                table_name,
            },
            (ffi::SQLITE_CREATE_TABLE, Some(table_name), _) => Self::CreateTable { table_name },
            (ffi::SQLITE_CREATE_TEMP_INDEX, Some(index_name), Some(table_name)) => {
                Self::CreateTempIndex {
                    index_name,
                    table_name,
                }
            }
            (ffi::SQLITE_CREATE_TEMP_TABLE, Some(table_name), _) => {
                Self::CreateTempTable { table_name }
            }
            (ffi::SQLITE_CREATE_TEMP_TRIGGER, Some(trigger_name), Some(table_name)) => {
                Self::CreateTempTrigger {
                    trigger_name,
                    table_name,
                }
            }
            (ffi::SQLITE_CREATE_TEMP_VIEW, Some(view_name), _) => {
                Self::CreateTempView { view_name }
            }
            (ffi::SQLITE_CREATE_TRIGGER, Some(trigger_name), Some(table_name)) => {
                Self::CreateTrigger {
                    trigger_name,
                    table_name,
                }
            }
            (ffi::SQLITE_CREATE_VIEW, Some(view_name), _) => Self::CreateView { view_name },
            // Unlike the other single-argument arms, DELETE additionally requires
            // `arg2` to be absent; anything else becomes `Unknown`.
            (ffi::SQLITE_DELETE, Some(table_name), None) => Self::Delete { table_name },
            (ffi::SQLITE_DROP_INDEX, Some(index_name), Some(table_name)) => Self::DropIndex {
                index_name,
                table_name,
            },
            (ffi::SQLITE_DROP_TABLE, Some(table_name), _) => Self::DropTable { table_name },
            (ffi::SQLITE_DROP_TEMP_INDEX, Some(index_name), Some(table_name)) => {
                Self::DropTempIndex {
                    index_name,
                    table_name,
                }
            }
            (ffi::SQLITE_DROP_TEMP_TABLE, Some(table_name), _) => {
                Self::DropTempTable { table_name }
            }
            (ffi::SQLITE_DROP_TEMP_TRIGGER, Some(trigger_name), Some(table_name)) => {
                Self::DropTempTrigger {
                    trigger_name,
                    table_name,
                }
            }
            (ffi::SQLITE_DROP_TEMP_VIEW, Some(view_name), _) => Self::DropTempView { view_name },
            (ffi::SQLITE_DROP_TRIGGER, Some(trigger_name), Some(table_name)) => Self::DropTrigger {
                trigger_name,
                table_name,
            },
            (ffi::SQLITE_DROP_VIEW, Some(view_name), _) => Self::DropView { view_name },
            (ffi::SQLITE_INSERT, Some(table_name), _) => Self::Insert { table_name },
            // The second argument (the pragma value) is optional and passed through as-is.
            (ffi::SQLITE_PRAGMA, Some(pragma_name), pragma_value) => Self::Pragma {
                pragma_name,
                pragma_value,
            },
            (ffi::SQLITE_READ, Some(table_name), Some(column_name)) => Self::Read {
                table_name,
                column_name,
            },
            // SELECT carries no useful arguments, so they are ignored.
            (ffi::SQLITE_SELECT, ..) => Self::Select,
            // The operation keyword string is parsed into a `TransactionOperation`.
            (ffi::SQLITE_TRANSACTION, Some(operation_str), _) => Self::Transaction {
                operation: TransactionOperation::from_str(operation_str),
            },
            (ffi::SQLITE_UPDATE, Some(table_name), Some(column_name)) => Self::Update {
                table_name,
                column_name,
            },
            (ffi::SQLITE_ATTACH, Some(filename), _) => Self::Attach { filename },
            (ffi::SQLITE_DETACH, Some(database_name), _) => Self::Detach { database_name },
            (ffi::SQLITE_ALTER_TABLE, Some(database_name), Some(table_name)) => Self::AlterTable {
                database_name,
                table_name,
            },
            (ffi::SQLITE_REINDEX, Some(index_name), _) => Self::Reindex { index_name },
            (ffi::SQLITE_ANALYZE, Some(table_name), _) => Self::Analyze { table_name },
            (ffi::SQLITE_CREATE_VTABLE, Some(table_name), Some(module_name)) => {
                Self::CreateVtable {
                    table_name,
                    module_name,
                }
            }
            (ffi::SQLITE_DROP_VTABLE, Some(table_name), Some(module_name)) => Self::DropVtable {
                table_name,
                module_name,
            },
            // The function name is taken from the *second* argument; the first is ignored.
            (ffi::SQLITE_FUNCTION, _, Some(function_name)) => Self::Function { function_name },
            (ffi::SQLITE_SAVEPOINT, Some(operation_str), Some(savepoint_name)) => Self::Savepoint {
                operation: TransactionOperation::from_str(operation_str),
                savepoint_name,
            },
            (ffi::SQLITE_RECURSIVE, ..) => Self::Recursive,
            // Catch-all: unknown codes and unexpected argument layouts.
            (code, arg1, arg2) => Self::Unknown { code, arg1, arg2 },
        }
    }
}
297
298pub(crate) type BoxedAuthorizer =
299    Box<dyn for<'c> FnMut(AuthContext<'c>) -> Authorization + Send + 'static>;
300
/// A transaction operation.
///
/// Reported to authorizer callbacks for `SQLITE_TRANSACTION` and
/// `SQLITE_SAVEPOINT` actions, parsed from the operation keyword SQLite passes.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
#[non_exhaustive]
#[allow(missing_docs)]
pub enum TransactionOperation {
    /// An operation keyword this library does not recognize.
    Unknown,
    Begin,
    Release,
    Rollback,
    /// The `"COMMIT"` keyword; SQLite reports it for `SQLITE_TRANSACTION`
    /// when a transaction is committed (`COMMIT` / `END`).
    Commit,
}

impl TransactionOperation {
    /// Parse the (upper-case) operation keyword supplied by SQLite.
    ///
    /// Matching is exact; any other string yields [`TransactionOperation::Unknown`].
    fn from_str(op_str: &str) -> Self {
        match op_str {
            "BEGIN" => Self::Begin,
            "COMMIT" => Self::Commit,
            "RELEASE" => Self::Release,
            "ROLLBACK" => Self::Rollback,
            _ => Self::Unknown,
        }
    }
}
322
/// Verdict returned from an [`authorizer`](Connection::authorizer) callback.
#[non_exhaustive]
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum Authorization {
    /// Permit the action.
    Allow,
    /// Refuse access silently, without raising an error.
    Ignore,
    /// Refuse access and raise an error.
    Deny,
}
334
335impl Authorization {
336    fn into_raw(self) -> c_int {
337        match self {
338            Self::Allow => ffi::SQLITE_OK,
339            Self::Ignore => ffi::SQLITE_IGNORE,
340            Self::Deny => ffi::SQLITE_DENY,
341        }
342    }
343}
344
345impl Connection {
346    /// Register a callback function to be invoked whenever
347    /// a transaction is committed.
348    ///
349    /// The callback returns `true` to rollback.
350    #[inline]
351    pub fn commit_hook<F>(&self, hook: Option<F>)
352    where
353        F: FnMut() -> bool + Send + 'static,
354    {
355        self.db.borrow_mut().commit_hook(hook);
356    }
357
358    /// Register a callback function to be invoked whenever
359    /// a transaction is committed.
360    #[inline]
361    pub fn rollback_hook<F>(&self, hook: Option<F>)
362    where
363        F: FnMut() + Send + 'static,
364    {
365        self.db.borrow_mut().rollback_hook(hook);
366    }
367
368    /// Register a callback function to be invoked whenever
369    /// a row is updated, inserted or deleted in a rowid table.
370    ///
371    /// The callback parameters are:
372    ///
373    /// - the type of database update (`SQLITE_INSERT`, `SQLITE_UPDATE` or
374    ///   `SQLITE_DELETE`),
375    /// - the name of the database ("main", "temp", ...),
376    /// - the name of the table that is updated,
377    /// - the ROWID of the row that is updated.
378    #[inline]
379    pub fn update_hook<F>(&self, hook: Option<F>)
380    where
381        F: FnMut(Action, &str, &str, i64) + Send + 'static,
382    {
383        self.db.borrow_mut().update_hook(hook);
384    }
385
386    /// Register a callback that is invoked each time data is committed to a database in wal mode.
387    ///
388    /// A single database handle may have at most a single write-ahead log callback registered at one time.
389    /// Calling `wal_hook` replaces any previously registered write-ahead log callback.
390    /// Note that the `sqlite3_wal_autocheckpoint()` interface and the `wal_autocheckpoint` pragma
391    /// both invoke `sqlite3_wal_hook()` and will overwrite any prior `sqlite3_wal_hook()` settings.
392    pub fn wal_hook(&self, hook: Option<fn(&Wal, c_int) -> Result<()>>) {
393        unsafe extern "C" fn wal_hook_callback(
394            client_data: *mut c_void,
395            db: *mut ffi::sqlite3,
396            db_name: *const c_char,
397            pages: c_int,
398        ) -> c_int {
399            let hook_fn: fn(&Wal, c_int) -> Result<()> = std::mem::transmute(client_data);
400            let wal = Wal { db, db_name };
401            catch_unwind(|| match hook_fn(&wal, pages) {
402                Ok(_) => ffi::SQLITE_OK,
403                Err(e) => e
404                    .sqlite_error()
405                    .map_or(ffi::SQLITE_ERROR, |x| x.extended_code),
406            })
407            .unwrap_or_default()
408        }
409        let c = self.db.borrow_mut();
410        match hook {
411            Some(f) => unsafe {
412                ffi::sqlite3_wal_hook(c.db(), Some(wal_hook_callback), f as *mut c_void)
413            },
414            None => unsafe { ffi::sqlite3_wal_hook(c.db(), None, ptr::null_mut()) },
415        };
416    }
417
418    /// Register a query progress callback.
419    ///
420    /// The parameter `num_ops` is the approximate number of virtual machine
421    /// instructions that are evaluated between successive invocations of the
422    /// `handler`. If `num_ops` is less than one then the progress handler
423    /// is disabled.
424    ///
425    /// If the progress callback returns `true`, the operation is interrupted.
426    pub fn progress_handler<F>(&self, num_ops: c_int, handler: Option<F>)
427    where
428        F: FnMut() -> bool + Send + 'static,
429    {
430        self.db.borrow_mut().progress_handler(num_ops, handler);
431    }
432
433    /// Register an authorizer callback that's invoked
434    /// as a statement is being prepared.
435    #[inline]
436    pub fn authorizer<'c, F>(&self, hook: Option<F>)
437    where
438        F: for<'r> FnMut(AuthContext<'r>) -> Authorization + Send + 'static,
439    {
440        self.db.borrow_mut().authorizer(hook);
441    }
442}
443
444/// Checkpoint mode
445#[derive(Clone, Copy)]
446#[repr(i32)]
447#[non_exhaustive]
448pub enum CheckpointMode {
449    /// Do as much as possible w/o blocking
450    PASSIVE = ffi::SQLITE_CHECKPOINT_PASSIVE,
451    /// Wait for writers, then checkpoint
452    FULL = ffi::SQLITE_CHECKPOINT_FULL,
453    /// Like FULL but wait for readers
454    RESTART = ffi::SQLITE_CHECKPOINT_RESTART,
455    /// Like RESTART but also truncate WAL
456    TRUNCATE = ffi::SQLITE_CHECKPOINT_TRUNCATE,
457}
458
/// Write-Ahead Log
///
/// Handle passed by reference to a [`Connection::wal_hook`] callback,
/// identifying the connection and database that were written to.
pub struct Wal {
    // Connection handle the WAL hook fired on.
    db: *mut ffi::sqlite3,
    // NUL-terminated database name supplied by SQLite to the hook.
    // NOTE(review): presumably only valid for the duration of the callback — confirm.
    db_name: *const c_char,
}
464
465impl Wal {
466    /// Checkpoint a database
467    pub fn checkpoint(&self) -> Result<()> {
468        unsafe { decode_result_raw(self.db, ffi::sqlite3_wal_checkpoint(self.db, self.db_name)) }
469    }
470    /// Checkpoint a database
471    pub fn checkpoint_v2(&self, mode: CheckpointMode) -> Result<(c_int, c_int)> {
472        let mut n_log = 0;
473        let mut n_ckpt = 0;
474        unsafe {
475            decode_result_raw(
476                self.db,
477                ffi::sqlite3_wal_checkpoint_v2(
478                    self.db,
479                    self.db_name,
480                    mode as c_int,
481                    &mut n_log,
482                    &mut n_ckpt,
483                ),
484            )?
485        };
486        Ok((n_log, n_ckpt))
487    }
488
489    /// Name of the database that was written to
490    pub fn name(&self) -> &CStr {
491        unsafe { CStr::from_ptr(self.db_name) }
492    }
493}
494
495impl InnerConnection {
496    #[inline]
497    pub fn remove_hooks(&mut self) {
498        self.update_hook(None::<fn(Action, &str, &str, i64)>);
499        self.commit_hook(None::<fn() -> bool>);
500        self.rollback_hook(None::<fn()>);
501        self.progress_handler(0, None::<fn() -> bool>);
502        self.authorizer(None::<fn(AuthContext<'_>) -> Authorization>);
503    }
504
505    /// ```compile_fail
506    /// use rusqlite::{Connection, Result};
507    /// fn main() -> Result<()> {
508    ///     let db = Connection::open_in_memory()?;
509    ///     {
510    ///         let mut called = std::sync::atomic::AtomicBool::new(false);
511    ///         db.commit_hook(Some(|| {
512    ///             called.store(true, std::sync::atomic::Ordering::Relaxed);
513    ///             true
514    ///         }));
515    ///     }
516    ///     assert!(db
517    ///         .execute_batch(
518    ///             "BEGIN;
519    ///         CREATE TABLE foo (t TEXT);
520    ///         COMMIT;",
521    ///         )
522    ///         .is_err());
523    ///     Ok(())
524    /// }
525    /// ```
526    fn commit_hook<F>(&mut self, hook: Option<F>)
527    where
528        F: FnMut() -> bool + Send + 'static,
529    {
530        unsafe extern "C" fn call_boxed_closure<F>(p_arg: *mut c_void) -> c_int
531        where
532            F: FnMut() -> bool,
533        {
534            let r = catch_unwind(|| {
535                let boxed_hook: *mut F = p_arg.cast::<F>();
536                (*boxed_hook)()
537            });
538            c_int::from(r.unwrap_or_default())
539        }
540
541        match hook {
542            Some(hook) => {
543                let boxed_hook = Box::new(hook);
544                unsafe {
545                    ffi::sqlite3_commit_hook(
546                        self.db(),
547                        Some(call_boxed_closure::<F>),
548                        &*boxed_hook as *const F as *mut _,
549                    )
550                };
551                self.commit_hook = Some(boxed_hook);
552            }
553            _ => {
554                unsafe { ffi::sqlite3_commit_hook(self.db(), None, ptr::null_mut()) };
555                self.commit_hook = None;
556            }
557        }
558    }
559
560    /// ```compile_fail
561    /// use rusqlite::{Connection, Result};
562    /// fn main() -> Result<()> {
563    ///     let db = Connection::open_in_memory()?;
564    ///     {
565    ///         let mut called = std::sync::atomic::AtomicBool::new(false);
566    ///         db.rollback_hook(Some(|| {
567    ///             called.store(true, std::sync::atomic::Ordering::Relaxed);
568    ///         }));
569    ///     }
570    ///     assert!(db
571    ///         .execute_batch(
572    ///             "BEGIN;
573    ///         CREATE TABLE foo (t TEXT);
574    ///         ROLLBACK;",
575    ///         )
576    ///         .is_err());
577    ///     Ok(())
578    /// }
579    /// ```
580    fn rollback_hook<F>(&mut self, hook: Option<F>)
581    where
582        F: FnMut() + Send + 'static,
583    {
584        unsafe extern "C" fn call_boxed_closure<F>(p_arg: *mut c_void)
585        where
586            F: FnMut(),
587        {
588            drop(catch_unwind(|| {
589                let boxed_hook: *mut F = p_arg.cast::<F>();
590                (*boxed_hook)();
591            }));
592        }
593
594        match hook {
595            Some(hook) => {
596                let boxed_hook = Box::new(hook);
597                unsafe {
598                    ffi::sqlite3_rollback_hook(
599                        self.db(),
600                        Some(call_boxed_closure::<F>),
601                        &*boxed_hook as *const F as *mut _,
602                    )
603                };
604                self.rollback_hook = Some(boxed_hook);
605            }
606            _ => {
607                unsafe { ffi::sqlite3_rollback_hook(self.db(), None, ptr::null_mut()) };
608                self.rollback_hook = None;
609            }
610        }
611    }
612
613    /// ```compile_fail
614    /// use rusqlite::{Connection, Result};
615    /// fn main() -> Result<()> {
616    ///     let db = Connection::open_in_memory()?;
617    ///     {
618    ///         let mut called = std::sync::atomic::AtomicBool::new(false);
619    ///         db.update_hook(Some(|_, _: &str, _: &str, _| {
620    ///             called.store(true, std::sync::atomic::Ordering::Relaxed);
621    ///         }));
622    ///     }
623    ///     db.execute_batch("CREATE TABLE foo AS SELECT 1 AS bar;")
624    /// }
625    /// ```
626    fn update_hook<F>(&mut self, hook: Option<F>)
627    where
628        F: FnMut(Action, &str, &str, i64) + Send + 'static,
629    {
630        unsafe extern "C" fn call_boxed_closure<F>(
631            p_arg: *mut c_void,
632            action_code: c_int,
633            p_db_name: *const c_char,
634            p_table_name: *const c_char,
635            row_id: i64,
636        ) where
637            F: FnMut(Action, &str, &str, i64),
638        {
639            let action = Action::from(action_code);
640            drop(catch_unwind(|| {
641                let boxed_hook: *mut F = p_arg.cast::<F>();
642                (*boxed_hook)(
643                    action,
644                    expect_utf8(p_db_name, "database name"),
645                    expect_utf8(p_table_name, "table name"),
646                    row_id,
647                );
648            }));
649        }
650
651        match hook {
652            Some(hook) => {
653                let boxed_hook = Box::new(hook);
654                unsafe {
655                    ffi::sqlite3_update_hook(
656                        self.db(),
657                        Some(call_boxed_closure::<F>),
658                        &*boxed_hook as *const F as *mut _,
659                    )
660                };
661                self.update_hook = Some(boxed_hook);
662            }
663            _ => {
664                unsafe { ffi::sqlite3_update_hook(self.db(), None, ptr::null_mut()) };
665                self.update_hook = None;
666            }
667        }
668    }
669
670    /// ```compile_fail
671    /// use rusqlite::{Connection, Result};
672    /// fn main() -> Result<()> {
673    ///     let db = Connection::open_in_memory()?;
674    ///     {
675    ///         let mut called = std::sync::atomic::AtomicBool::new(false);
676    ///         db.progress_handler(
677    ///             1,
678    ///             Some(|| {
679    ///                 called.store(true, std::sync::atomic::Ordering::Relaxed);
680    ///                 true
681    ///             }),
682    ///         );
683    ///     }
684    ///     assert!(db
685    ///         .execute_batch("BEGIN; CREATE TABLE foo (t TEXT); COMMIT;")
686    ///         .is_err());
687    ///     Ok(())
688    /// }
689    /// ```
690    fn progress_handler<F>(&mut self, num_ops: c_int, handler: Option<F>)
691    where
692        F: FnMut() -> bool + Send + 'static,
693    {
694        unsafe extern "C" fn call_boxed_closure<F>(p_arg: *mut c_void) -> c_int
695        where
696            F: FnMut() -> bool,
697        {
698            let r = catch_unwind(|| {
699                let boxed_handler: *mut F = p_arg.cast::<F>();
700                (*boxed_handler)()
701            });
702            c_int::from(r.unwrap_or_default())
703        }
704
705        if let Some(handler) = handler {
706            let boxed_handler = Box::new(handler);
707            unsafe {
708                ffi::sqlite3_progress_handler(
709                    self.db(),
710                    num_ops,
711                    Some(call_boxed_closure::<F>),
712                    &*boxed_handler as *const F as *mut _,
713                );
714            }
715            self.progress_handler = Some(boxed_handler);
716        } else {
717            unsafe { ffi::sqlite3_progress_handler(self.db(), num_ops, None, ptr::null_mut()) }
718            self.progress_handler = None;
719        };
720    }
721
722    /// ```compile_fail
723    /// use rusqlite::{Connection, Result};
724    /// fn main() -> Result<()> {
725    ///     let db = Connection::open_in_memory()?;
726    ///     {
727    ///         let mut called = std::sync::atomic::AtomicBool::new(false);
728    ///         db.authorizer(Some(|_: rusqlite::hooks::AuthContext<'_>| {
729    ///             called.store(true, std::sync::atomic::Ordering::Relaxed);
730    ///             rusqlite::hooks::Authorization::Deny
731    ///         }));
732    ///     }
733    ///     assert!(db
734    ///         .execute_batch("BEGIN; CREATE TABLE foo (t TEXT); COMMIT;")
735    ///         .is_err());
736    ///     Ok(())
737    /// }
738    /// ```
739    fn authorizer<'c, F>(&'c mut self, authorizer: Option<F>)
740    where
741        F: for<'r> FnMut(AuthContext<'r>) -> Authorization + Send + 'static,
742    {
743        unsafe extern "C" fn call_boxed_closure<'c, F>(
744            p_arg: *mut c_void,
745            action_code: c_int,
746            param1: *const c_char,
747            param2: *const c_char,
748            db_name: *const c_char,
749            trigger_or_view_name: *const c_char,
750        ) -> c_int
751        where
752            F: FnMut(AuthContext<'c>) -> Authorization + Send + 'static,
753        {
754            catch_unwind(|| {
755                let action = AuthAction::from_raw(
756                    action_code,
757                    expect_optional_utf8(param1, "authorizer param 1"),
758                    expect_optional_utf8(param2, "authorizer param 2"),
759                );
760                let auth_ctx = AuthContext {
761                    action,
762                    database_name: expect_optional_utf8(db_name, "database name"),
763                    accessor: expect_optional_utf8(
764                        trigger_or_view_name,
765                        "accessor (inner-most trigger or view)",
766                    ),
767                };
768                let boxed_hook: *mut F = p_arg.cast::<F>();
769                (*boxed_hook)(auth_ctx)
770            })
771            .map_or_else(|_| ffi::SQLITE_ERROR, Authorization::into_raw)
772        }
773
774        let callback_fn = authorizer
775            .as_ref()
776            .map(|_| call_boxed_closure::<'c, F> as unsafe extern "C" fn(_, _, _, _, _, _) -> _);
777        let boxed_authorizer = authorizer.map(Box::new);
778
779        match unsafe {
780            ffi::sqlite3_set_authorizer(
781                self.db(),
782                callback_fn,
783                boxed_authorizer
784                    .as_ref()
785                    .map_or_else(ptr::null_mut, |f| &**f as *const F as *mut _),
786            )
787        } {
788            ffi::SQLITE_OK => {
789                self.authorizer = boxed_authorizer.map(|ba| ba as _);
790            }
791            err_code => {
792                // The only error that `sqlite3_set_authorizer` returns is `SQLITE_MISUSE`
793                // when compiled with `ENABLE_API_ARMOR` and the db pointer is invalid.
794                // This library does not allow constructing a null db ptr, so if this branch
795                // is hit, something very bad has happened. Panicking instead of returning
796                // `Result` keeps this hook's API consistent with the others.
797                panic!("unexpectedly failed to set_authorizer: {}", unsafe {
798                    crate::error::error_from_handle(self.db(), err_code)
799                });
800            }
801        }
802    }
803}
804
805unsafe fn expect_utf8<'a>(p_str: *const c_char, description: &'static str) -> &'a str {
806    expect_optional_utf8(p_str, description)
807        .unwrap_or_else(|| panic!("received empty {description}"))
808}
809
/// Borrow a C string handed over by SQLite as UTF-8, mapping a null pointer
/// to `None`.
///
/// # Safety
///
/// `p_str` must be null or point to a NUL-terminated string that remains
/// valid for the caller-chosen lifetime `'a`.
///
/// # Panics
///
/// Panics if the bytes are not valid UTF-8; `description` is interpolated
/// into the panic message.
unsafe fn expect_optional_utf8<'a>(
    p_str: *const c_char,
    description: &'static str,
) -> Option<&'a str> {
    if p_str.is_null() {
        None
    } else {
        // SAFETY: non-null here; the caller guarantees NUL termination and
        // validity for `'a`.
        let c_str: &'a CStr = unsafe { CStr::from_ptr(p_str) };
        match c_str.to_str() {
            Ok(text) => Some(text),
            Err(_) => panic!("received non-utf8 string as {description}"),
        }
    }
}
822
// Unit tests for the hooks in this module. Tests that need to observe a
// callback record it in a function-local `static AtomicBool`.
#[cfg(test)]
mod test {
    use super::Action;
    use crate::{Connection, Result, MAIN_DB};
    use std::sync::atomic::{AtomicBool, Ordering};

    // A commit hook returning `false` is invoked and lets the commit succeed.
    #[test]
    fn test_commit_hook() -> Result<()> {
        let db = Connection::open_in_memory()?;

        static CALLED: AtomicBool = AtomicBool::new(false);
        db.commit_hook(Some(|| {
            CALLED.store(true, Ordering::Relaxed);
            false
        }));
        db.execute_batch("BEGIN; CREATE TABLE foo (t TEXT); COMMIT;")?;
        assert!(CALLED.load(Ordering::Relaxed));
        Ok(())
    }

    // A plain `fn` returning `true` turns the COMMIT into a rollback, so the batch fails.
    #[test]
    fn test_fn_commit_hook() -> Result<()> {
        let db = Connection::open_in_memory()?;

        fn hook() -> bool {
            true
        }

        db.commit_hook(Some(hook));
        db.execute_batch("BEGIN; CREATE TABLE foo (t TEXT); COMMIT;")
            .unwrap_err();
        Ok(())
    }

    // An explicit ROLLBACK invokes the rollback hook.
    #[test]
    fn test_rollback_hook() -> Result<()> {
        let db = Connection::open_in_memory()?;

        static CALLED: AtomicBool = AtomicBool::new(false);
        db.rollback_hook(Some(|| {
            CALLED.store(true, Ordering::Relaxed);
        }));
        db.execute_batch("BEGIN; CREATE TABLE foo (t TEXT); ROLLBACK;")?;
        assert!(CALLED.load(Ordering::Relaxed));
        Ok(())
    }

    // The update hook sees the action, database, table and rowid of an INSERT.
    #[test]
    fn test_update_hook() -> Result<()> {
        let db = Connection::open_in_memory()?;

        static CALLED: AtomicBool = AtomicBool::new(false);
        db.update_hook(Some(|action, db: &str, tbl: &str, row_id| {
            assert_eq!(Action::SQLITE_INSERT, action);
            assert_eq!("main", db);
            assert_eq!("foo", tbl);
            assert_eq!(1, row_id);
            CALLED.store(true, Ordering::Relaxed);
        }));
        db.execute_batch("CREATE TABLE foo (t TEXT)")?;
        db.execute_batch("INSERT INTO foo VALUES ('lisa')")?;
        assert!(CALLED.load(Ordering::Relaxed));
        Ok(())
    }

    // A progress handler returning `false` is called but does not interrupt.
    #[test]
    fn test_progress_handler() -> Result<()> {
        let db = Connection::open_in_memory()?;

        static CALLED: AtomicBool = AtomicBool::new(false);
        db.progress_handler(
            1,
            Some(|| {
                CALLED.store(true, Ordering::Relaxed);
                false
            }),
        );
        db.execute_batch("BEGIN; CREATE TABLE foo (t TEXT); COMMIT;")?;
        assert!(CALLED.load(Ordering::Relaxed));
        Ok(())
    }

    // A progress handler returning `true` interrupts the batch.
    #[test]
    fn test_progress_handler_interrupt() -> Result<()> {
        let db = Connection::open_in_memory()?;

        fn handler() -> bool {
            true
        }

        db.progress_handler(1, Some(handler));
        db.execute_batch("BEGIN; CREATE TABLE foo (t TEXT); COMMIT;")
            .unwrap_err();
        Ok(())
    }

    // `Ignore` hides a column (reads as NULL), `Deny` fails the statement, and
    // removing the authorizer lifts all restrictions.
    #[test]
    fn test_authorizer() -> Result<()> {
        use super::{AuthAction, AuthContext, Authorization};

        let db = Connection::open_in_memory()?;
        db.execute_batch("CREATE TABLE foo (public TEXT, private TEXT)")?;

        let authorizer = move |ctx: AuthContext<'_>| match ctx.action {
            AuthAction::Read {
                column_name: "private",
                ..
            } => Authorization::Ignore,
            AuthAction::DropTable { .. } => Authorization::Deny,
            AuthAction::Pragma { .. } => panic!("shouldn't be called"),
            _ => Authorization::Allow,
        };

        db.authorizer(Some(authorizer));
        db.execute_batch(
            "BEGIN TRANSACTION; INSERT INTO foo VALUES ('pub txt', 'priv txt'); COMMIT;",
        )?;
        db.query_row_and_then("SELECT * FROM foo", [], |row| -> Result<()> {
            assert_eq!(row.get::<_, String>("public")?, "pub txt");
            assert!(row.get::<_, Option<String>>("private")?.is_none());
            Ok(())
        })?;
        db.execute_batch("DROP TABLE foo").unwrap_err();

        db.authorizer(None::<fn(AuthContext<'_>) -> Authorization>);
        // The first authorizer would have panicked on this PRAGMA (reported to
        // SQLite as an error), but it has now been removed.
        db.execute_batch("PRAGMA user_version=1")?;

        Ok(())
    }

    // Exercises both `Wal::checkpoint` and `Wal::checkpoint_v2` from inside a
    // WAL hook on an on-disk database, then unregisters the hook.
    #[test]
    fn wal_hook() -> Result<()> {
        let temp_dir = tempfile::tempdir().unwrap();
        let path = temp_dir.path().join("wal-hook.db3");

        let db = Connection::open(&path)?;
        let journal_mode: String =
            db.pragma_update_and_check(None, "journal_mode", "wal", |row| row.get(0))?;
        assert_eq!(journal_mode, "wal");

        static CALLED: AtomicBool = AtomicBool::new(false);
        db.wal_hook(Some(|wal, pages| {
            assert_eq!(wal.name(), MAIN_DB);
            assert!(pages > 0);
            CALLED.swap(true, Ordering::Relaxed);
            wal.checkpoint()
        }));
        db.execute_batch("CREATE TABLE x(c);")?;
        assert!(CALLED.load(Ordering::Relaxed));

        // A TRUNCATE checkpoint leaves an empty log, so both frame counts are 0.
        db.wal_hook(Some(|wal, pages| {
            assert!(pages > 0);
            let (log, ckpt) = wal.checkpoint_v2(super::CheckpointMode::TRUNCATE)?;
            assert_eq!(log, 0);
            assert_eq!(ckpt, 0);
            Ok(())
        }));
        db.execute_batch("CREATE TABLE y(c);")?;

        db.wal_hook(None);
        Ok(())
    }
}