rusqlite/
functions.rs

1//! Create or redefine SQL functions.
2//!
3//! # Example
4//!
//! Adding a `regexp` function to a connection. The compiled regular
//! expression is cached with SQLite's
//! [Function Auxiliary Data](https://www.sqlite.org/c3ref/get_auxdata.html)
//! interface (via `Context::get_or_create_aux`), so it is not recompiled for
//! every row. The unit tests for this module contain a similar implementation.
10//!
11//! ```rust
12//! use regex::Regex;
13//! use rusqlite::functions::FunctionFlags;
14//! use rusqlite::{Connection, Error, Result};
15//! use std::sync::Arc;
16//! type BoxError = Box<dyn std::error::Error + Send + Sync + 'static>;
17//!
18//! fn add_regexp_function(db: &Connection) -> Result<()> {
19//!     db.create_scalar_function(
20//!         "regexp",
21//!         2,
22//!         FunctionFlags::SQLITE_UTF8 | FunctionFlags::SQLITE_DETERMINISTIC,
23//!         move |ctx| {
24//!             assert_eq!(ctx.len(), 2, "called with unexpected number of arguments");
25//!             let regexp: Arc<Regex> = ctx.get_or_create_aux(0, |vr| -> Result<_, BoxError> {
26//!                 Ok(Regex::new(vr.as_str()?)?)
27//!             })?;
28//!             let is_match = {
29//!                 let text = ctx
30//!                     .get_raw(1)
31//!                     .as_str()
32//!                     .map_err(|e| Error::UserFunctionError(e.into()))?;
33//!
34//!                 regexp.is_match(text)
35//!             };
36//!
37//!             Ok(is_match)
38//!         },
39//!     )
40//! }
41//!
42//! fn main() -> Result<()> {
43//!     let db = Connection::open_in_memory()?;
44//!     add_regexp_function(&db)?;
45//!
46//!     let is_match: bool =
47//!         db.query_row("SELECT regexp('[aeiou]*', 'aaaaeeeiii')", [], |row| {
48//!             row.get(0)
49//!         })?;
50//!
51//!     assert!(is_match);
52//!     Ok(())
53//! }
54//! ```
55use std::any::Any;
56use std::ffi::{c_int, c_uint, c_void};
57use std::marker::PhantomData;
58use std::ops::Deref;
59use std::panic::{catch_unwind, RefUnwindSafe, UnwindSafe};
60use std::ptr;
61use std::slice;
62use std::sync::Arc;
63
64use crate::ffi;
65use crate::ffi::sqlite3_context;
66use crate::ffi::sqlite3_value;
67
68use crate::context::set_result;
69use crate::types::{FromSql, FromSqlError, ToSql, ToSqlOutput, ValueRef};
70use crate::util::free_boxed_value;
71use crate::{str_to_cstring, Connection, Error, InnerConnection, Name, Result};
72
73unsafe fn report_error(ctx: *mut sqlite3_context, err: &Error) {
74    if let Error::SqliteFailure(ref err, ref s) = *err {
75        ffi::sqlite3_result_error_code(ctx, err.extended_code);
76        if let Some(Ok(cstr)) = s.as_ref().map(|s| str_to_cstring(s)) {
77            ffi::sqlite3_result_error(ctx, cstr.as_ptr(), -1);
78        }
79    } else {
80        ffi::sqlite3_result_error_code(ctx, ffi::SQLITE_CONSTRAINT_FUNCTION);
81        if let Ok(cstr) = str_to_cstring(&err.to_string()) {
82            ffi::sqlite3_result_error(ctx, cstr.as_ptr(), -1);
83        }
84    }
85}
86
87/// Context is a wrapper for the SQLite function
88/// evaluation context.
89pub struct Context<'a> {
90    ctx: *mut sqlite3_context,
91    args: &'a [*mut sqlite3_value],
92}
93
94impl Context<'_> {
95    /// Returns the number of arguments to the function.
96    #[inline]
97    #[must_use]
98    pub fn len(&self) -> usize {
99        self.args.len()
100    }
101
102    /// Returns `true` when there is no argument.
103    #[inline]
104    #[must_use]
105    pub fn is_empty(&self) -> bool {
106        self.args.is_empty()
107    }
108
109    /// Returns the `idx`th argument as a `T`.
110    ///
111    /// # Failure
112    ///
113    /// Will panic if `idx` is greater than or equal to
114    /// [`self.len()`](Context::len).
115    ///
116    /// Will return Err if the underlying SQLite type cannot be converted to a
117    /// `T`.
118    pub fn get<T: FromSql>(&self, idx: usize) -> Result<T> {
119        let arg = self.args[idx];
120        let value = unsafe { ValueRef::from_value(arg) };
121        FromSql::column_result(value).map_err(|err| match err {
122            FromSqlError::InvalidType => {
123                Error::InvalidFunctionParameterType(idx, value.data_type())
124            }
125            FromSqlError::OutOfRange(i) => Error::IntegralValueOutOfRange(idx, i),
126            FromSqlError::Other(err) => {
127                Error::FromSqlConversionFailure(idx, value.data_type(), err)
128            }
129            FromSqlError::InvalidBlobSize { .. } => {
130                Error::FromSqlConversionFailure(idx, value.data_type(), Box::new(err))
131            }
132        })
133    }
134
135    /// Returns the `idx`th argument as a `ValueRef`.
136    ///
137    /// # Failure
138    ///
139    /// Will panic if `idx` is greater than or equal to
140    /// [`self.len()`](Context::len).
141    #[inline]
142    #[must_use]
143    pub fn get_raw(&self, idx: usize) -> ValueRef<'_> {
144        let arg = self.args[idx];
145        unsafe { ValueRef::from_value(arg) }
146    }
147
148    /// Returns the `idx`th argument as a `SqlFnArg`.
149    /// To be used when the SQL function result is one of its arguments.
150    #[inline]
151    #[must_use]
152    pub fn get_arg(&self, idx: usize) -> SqlFnArg {
153        assert!(idx < self.len());
154        SqlFnArg { idx }
155    }
156
157    /// Returns the subtype of `idx`th argument.
158    ///
159    /// # Failure
160    ///
161    /// Will panic if `idx` is greater than or equal to
162    /// [`self.len()`](Context::len).
163    pub fn get_subtype(&self, idx: usize) -> c_uint {
164        let arg = self.args[idx];
165        unsafe { ffi::sqlite3_value_subtype(arg) }
166    }
167
168    /// Fetch or insert the auxiliary data associated with a particular
169    /// parameter. This is intended to be an easier-to-use way of fetching it
170    /// compared to calling [`get_aux`](Context::get_aux) and
171    /// [`set_aux`](Context::set_aux) separately.
172    ///
173    /// See `https://www.sqlite.org/c3ref/get_auxdata.html` for a discussion of
174    /// this feature, or the unit tests of this module for an example.
175    ///
176    /// # Failure
177    ///
178    /// Will panic if `arg` is greater than or equal to
179    /// [`self.len()`](Context::len).
180    pub fn get_or_create_aux<T, E, F>(&self, arg: c_int, func: F) -> Result<Arc<T>>
181    where
182        T: Send + Sync + 'static,
183        E: Into<Box<dyn std::error::Error + Send + Sync + 'static>>,
184        F: FnOnce(ValueRef<'_>) -> Result<T, E>,
185    {
186        if let Some(v) = self.get_aux(arg)? {
187            Ok(v)
188        } else {
189            let vr = self.get_raw(arg as usize);
190            self.set_aux(
191                arg,
192                func(vr).map_err(|e| Error::UserFunctionError(e.into()))?,
193            )
194        }
195    }
196
197    /// Sets the auxiliary data associated with a particular parameter. See
198    /// `https://www.sqlite.org/c3ref/get_auxdata.html` for a discussion of
199    /// this feature, or the unit tests of this module for an example.
200    ///
201    /// # Failure
202    ///
203    /// Will panic if `arg` is greater than or equal to
204    /// [`self.len()`](Context::len).
205    pub fn set_aux<T: Send + Sync + 'static>(&self, arg: c_int, value: T) -> Result<Arc<T>> {
206        assert!(arg < self.len() as i32);
207        let orig: Arc<T> = Arc::new(value);
208        let inner: AuxInner = orig.clone();
209        let outer = Box::new(inner);
210        let raw: *mut AuxInner = Box::into_raw(outer);
211        unsafe {
212            ffi::sqlite3_set_auxdata(
213                self.ctx,
214                arg,
215                raw.cast(),
216                Some(free_boxed_value::<AuxInner>),
217            );
218        };
219        Ok(orig)
220    }
221
222    /// Gets the auxiliary data that was associated with a given parameter via
223    /// [`set_aux`](Context::set_aux). Returns `Ok(None)` if no data has been
224    /// associated, and Ok(Some(v)) if it has. Returns an error if the
225    /// requested type does not match.
226    ///
227    /// # Failure
228    ///
229    /// Will panic if `arg` is greater than or equal to
230    /// [`self.len()`](Context::len).
231    pub fn get_aux<T: Send + Sync + 'static>(&self, arg: c_int) -> Result<Option<Arc<T>>> {
232        assert!(arg < self.len() as i32);
233        let p = unsafe { ffi::sqlite3_get_auxdata(self.ctx, arg) as *const AuxInner };
234        if p.is_null() {
235            Ok(None)
236        } else {
237            let v: AuxInner = AuxInner::clone(unsafe { &*p });
238            v.downcast::<T>()
239                .map(Some)
240                .map_err(|_| Error::GetAuxWrongType)
241        }
242    }
243
244    /// Get the db connection handle via [sqlite3_context_db_handle](https://www.sqlite.org/c3ref/context_db_handle.html)
245    ///
246    /// # Safety
247    ///
248    /// This function is marked unsafe because there is a potential for other
249    /// references to the connection to be sent across threads, [see this comment](https://github.com/rusqlite/rusqlite/issues/643#issuecomment-640181213).
250    pub unsafe fn get_connection(&self) -> Result<ConnectionRef<'_>> {
251        let handle = ffi::sqlite3_context_db_handle(self.ctx);
252        Ok(ConnectionRef {
253            conn: Connection::from_handle(handle)?,
254            phantom: PhantomData,
255        })
256    }
257}
258
259/// A reference to a connection handle with a lifetime bound to something.
260pub struct ConnectionRef<'ctx> {
261    // comes from Connection::from_handle(sqlite3_context_db_handle(...))
262    // and is non-owning
263    conn: Connection,
264    phantom: PhantomData<&'ctx Context<'ctx>>,
265}
266
267impl Deref for ConnectionRef<'_> {
268    type Target = Connection;
269
270    #[inline]
271    fn deref(&self) -> &Connection {
272        &self.conn
273    }
274}
275
276type AuxInner = Arc<dyn Any + Send + Sync + 'static>;
277
278/// Subtype of an SQL function
279pub type SubType = Option<c_uint>;
280
281/// Result of an SQL function
282pub trait SqlFnOutput {
283    /// Converts Rust value to SQLite value with an optional subtype
284    fn to_sql(&self) -> Result<(ToSqlOutput<'_>, SubType)>;
285}
286
287impl<T: ToSql> SqlFnOutput for T {
288    #[inline]
289    fn to_sql(&self) -> Result<(ToSqlOutput<'_>, SubType)> {
290        ToSql::to_sql(self).map(|o| (o, None))
291    }
292}
293
294impl<T: ToSql> SqlFnOutput for (T, SubType) {
295    fn to_sql(&self) -> Result<(ToSqlOutput<'_>, SubType)> {
296        ToSql::to_sql(&self.0).map(|o| (o, self.1))
297    }
298}
299
300/// n-th arg of an SQL scalar function
301pub struct SqlFnArg {
302    idx: usize,
303}
304impl ToSql for SqlFnArg {
305    fn to_sql(&self) -> Result<ToSqlOutput<'_>> {
306        Ok(ToSqlOutput::Arg(self.idx))
307    }
308}
309
310unsafe fn sql_result<T: SqlFnOutput>(
311    ctx: *mut sqlite3_context,
312    args: &[*mut sqlite3_value],
313    r: Result<T>,
314) {
315    let t = r.as_ref().map(SqlFnOutput::to_sql);
316
317    match t {
318        Ok(Ok((ref value, sub_type))) => {
319            set_result(ctx, args, value);
320            if let Some(sub_type) = sub_type {
321                ffi::sqlite3_result_subtype(ctx, sub_type);
322            }
323        }
324        Ok(Err(err)) => report_error(ctx, &err),
325        Err(err) => report_error(ctx, err),
326    };
327}
328
329/// Aggregate is the callback interface for user-defined
330/// aggregate function.
331///
332/// `A` is the type of the aggregation context and `T` is the type of the final
333/// result. Implementations should be stateless.
334pub trait Aggregate<A, T>
335where
336    A: RefUnwindSafe + UnwindSafe,
337    T: SqlFnOutput,
338{
339    /// Initializes the aggregation context. Will be called prior to the first
340    /// call to [`step()`](Aggregate::step) to set up the context for an
341    /// invocation of the function. (Note: `init()` will not be called if
342    /// there are no rows.)
343    fn init(&self, ctx: &mut Context<'_>) -> Result<A>;
344
345    /// "step" function called once for each row in an aggregate group. May be
346    /// called 0 times if there are no rows.
347    fn step(&self, ctx: &mut Context<'_>, acc: &mut A) -> Result<()>;
348
349    /// Computes and returns the final result. Will be called exactly once for
350    /// each invocation of the function. If [`step()`](Aggregate::step) was
351    /// called at least once, will be given `Some(A)` (the same `A` as was
352    /// created by [`init`](Aggregate::init) and given to
353    /// [`step`](Aggregate::step)); if [`step()`](Aggregate::step) was not
354    /// called (because the function is running against 0 rows), will be
355    /// given `None`.
356    ///
357    /// The passed context will have no arguments.
358    fn finalize(&self, ctx: &mut Context<'_>, acc: Option<A>) -> Result<T>;
359}
360
361/// `WindowAggregate` is the callback interface for
362/// user-defined aggregate window function.
363#[cfg(feature = "window")]
364pub trait WindowAggregate<A, T>: Aggregate<A, T>
365where
366    A: RefUnwindSafe + UnwindSafe,
367    T: SqlFnOutput,
368{
369    /// Returns the current value of the aggregate. Unlike xFinal, the
370    /// implementation should not delete any context.
371    fn value(&self, acc: Option<&mut A>) -> Result<T>;
372
373    /// Removes a row from the current window.
374    fn inverse(&self, ctx: &mut Context<'_>, acc: &mut A) -> Result<()>;
375}
376
377bitflags::bitflags! {
378    /// Function Flags.
379    /// See [sqlite3_create_function](https://sqlite.org/c3ref/create_function.html)
380    /// and [Function Flags](https://sqlite.org/c3ref/c_deterministic.html) for details.
381    #[derive(Clone, Copy, Debug)]
382    #[repr(C)]
383    pub struct FunctionFlags: c_int {
384        /// Specifies UTF-8 as the text encoding this SQL function prefers for its parameters.
385        const SQLITE_UTF8     = ffi::SQLITE_UTF8;
386        /// Specifies UTF-16 using little-endian byte order as the text encoding this SQL function prefers for its parameters.
387        const SQLITE_UTF16LE  = ffi::SQLITE_UTF16LE;
388        /// Specifies UTF-16 using big-endian byte order as the text encoding this SQL function prefers for its parameters.
389        const SQLITE_UTF16BE  = ffi::SQLITE_UTF16BE;
390        /// Specifies UTF-16 using native byte order as the text encoding this SQL function prefers for its parameters.
391        const SQLITE_UTF16    = ffi::SQLITE_UTF16;
392        /// Means that the function always gives the same output when the input parameters are the same.
393        const SQLITE_DETERMINISTIC = ffi::SQLITE_DETERMINISTIC; // 3.8.3
394        /// Means that the function may only be invoked from top-level SQL.
395        const SQLITE_DIRECTONLY    = 0x0000_0008_0000; // 3.30.0
396        /// Indicates to SQLite that a function may call `sqlite3_value_subtype()` to inspect the subtypes of its arguments.
397        const SQLITE_SUBTYPE       = 0x0000_0010_0000; // 3.30.0
398        /// Means that the function is unlikely to cause problems even if misused.
399        const SQLITE_INNOCUOUS     = 0x0000_0020_0000; // 3.31.0
400        /// Indicates to SQLite that a function might call `sqlite3_result_subtype()` to cause a subtype to be associated with its result.
401        const SQLITE_RESULT_SUBTYPE     = 0x0000_0100_0000; // 3.45.0
402        /// Indicates that the function is an aggregate that internally orders the values provided to the first argument.
403        const SQLITE_SELFORDER1 = 0x0000_0200_0000; // 3.47.0
404    }
405}
406
407impl Default for FunctionFlags {
408    #[inline]
409    fn default() -> Self {
410        Self::SQLITE_UTF8
411    }
412}
413
414impl Connection {
415    /// Attach a user-defined scalar function to
416    /// this database connection.
417    ///
418    /// `fn_name` is the name the function will be accessible from SQL.
419    /// `n_arg` is the number of arguments to the function. Use `-1` for a
420    /// variable number. If the function always returns the same value
421    /// given the same input, `deterministic` should be `true`.
422    ///
423    /// The function will remain available until the connection is closed or
424    /// until it is explicitly removed via
425    /// [`remove_function`](Connection::remove_function).
426    ///
427    /// # Example
428    ///
429    /// ```rust
430    /// # use rusqlite::{Connection, Result};
431    /// # use rusqlite::functions::FunctionFlags;
432    /// fn scalar_function_example(db: Connection) -> Result<()> {
433    ///     db.create_scalar_function(
434    ///         "halve",
435    ///         1,
436    ///         FunctionFlags::SQLITE_UTF8 | FunctionFlags::SQLITE_DETERMINISTIC,
437    ///         |ctx| {
438    ///             let value = ctx.get::<f64>(0)?;
439    ///             Ok(value / 2f64)
440    ///         },
441    ///     )?;
442    ///
443    ///     let six_halved: f64 = db.query_row("SELECT halve(6)", [], |r| r.get(0))?;
444    ///     assert_eq!(six_halved, 3f64);
445    ///     Ok(())
446    /// }
447    /// ```
448    ///
449    /// # Failure
450    ///
451    /// Will return Err if the function could not be attached to the connection.
452    #[inline]
453    pub fn create_scalar_function<F, N: Name, T>(
454        &self,
455        fn_name: N,
456        n_arg: c_int,
457        flags: FunctionFlags,
458        x_func: F,
459    ) -> Result<()>
460    where
461        F: Fn(&Context<'_>) -> Result<T> + Send + 'static,
462        T: SqlFnOutput,
463    {
464        self.db
465            .borrow_mut()
466            .create_scalar_function(fn_name, n_arg, flags, x_func)
467    }
468
469    /// Attach a user-defined aggregate function to this
470    /// database connection.
471    ///
472    /// # Failure
473    ///
474    /// Will return Err if the function could not be attached to the connection.
475    #[inline]
476    pub fn create_aggregate_function<A, D, N: Name, T>(
477        &self,
478        fn_name: N,
479        n_arg: c_int,
480        flags: FunctionFlags,
481        aggr: D,
482    ) -> Result<()>
483    where
484        A: RefUnwindSafe + UnwindSafe,
485        D: Aggregate<A, T> + 'static,
486        T: SqlFnOutput,
487    {
488        self.db
489            .borrow_mut()
490            .create_aggregate_function(fn_name, n_arg, flags, aggr)
491    }
492
493    /// Attach a user-defined aggregate window function to
494    /// this database connection.
495    ///
496    /// See `https://sqlite.org/windowfunctions.html#udfwinfunc` for more
497    /// information.
498    #[cfg(feature = "window")]
499    #[inline]
500    pub fn create_window_function<A, N: Name, W, T>(
501        &self,
502        fn_name: N,
503        n_arg: c_int,
504        flags: FunctionFlags,
505        aggr: W,
506    ) -> Result<()>
507    where
508        A: RefUnwindSafe + UnwindSafe,
509        W: WindowAggregate<A, T> + 'static,
510        T: SqlFnOutput,
511    {
512        self.db
513            .borrow_mut()
514            .create_window_function(fn_name, n_arg, flags, aggr)
515    }
516
517    /// Removes a user-defined function from this
518    /// database connection.
519    ///
520    /// `fn_name` and `n_arg` should match the name and number of arguments
521    /// given to [`create_scalar_function`](Connection::create_scalar_function)
522    /// or [`create_aggregate_function`](Connection::create_aggregate_function).
523    ///
524    /// # Failure
525    ///
526    /// Will return Err if the function could not be removed.
527    #[inline]
528    pub fn remove_function<N: Name>(&self, fn_name: N, n_arg: c_int) -> Result<()> {
529        self.db.borrow_mut().remove_function(fn_name, n_arg)
530    }
531}
532
533impl InnerConnection {
534    /// ```compile_fail
535    /// use rusqlite::{functions::FunctionFlags, Connection, Result};
536    /// fn main() -> Result<()> {
537    ///     let db = Connection::open_in_memory()?;
538    ///     {
539    ///         let mut called = std::sync::atomic::AtomicBool::new(false);
540    ///         db.create_scalar_function(
541    ///             "test",
542    ///             0,
543    ///             FunctionFlags::SQLITE_UTF8 | FunctionFlags::SQLITE_DETERMINISTIC,
544    ///             |_| {
545    ///                 called.store(true, std::sync::atomic::Ordering::Relaxed);
546    ///                 Ok(true)
547    ///             },
548    ///         );
549    ///     }
550    ///     let result: Result<bool> = db.query_row("SELECT test()", [], |r| r.get(0));
551    ///     assert!(result?);
552    ///     Ok(())
553    /// }
554    /// ```
555    fn create_scalar_function<F, N: Name, T>(
556        &mut self,
557        fn_name: N,
558        n_arg: c_int,
559        flags: FunctionFlags,
560        x_func: F,
561    ) -> Result<()>
562    where
563        F: Fn(&Context<'_>) -> Result<T> + Send + 'static,
564        T: SqlFnOutput,
565    {
566        unsafe extern "C" fn call_boxed_closure<F, T>(
567            ctx: *mut sqlite3_context,
568            argc: c_int,
569            argv: *mut *mut sqlite3_value,
570        ) where
571            F: Fn(&Context<'_>) -> Result<T>,
572            T: SqlFnOutput,
573        {
574            let args = slice::from_raw_parts(argv, argc as usize);
575            let r = catch_unwind(|| {
576                let boxed_f: *const F = ffi::sqlite3_user_data(ctx).cast::<F>();
577                assert!(!boxed_f.is_null(), "Internal error - null function pointer");
578                let ctx = Context { ctx, args };
579                (*boxed_f)(&ctx)
580            });
581            let t = match r {
582                Err(_) => {
583                    report_error(ctx, &Error::UnwindingPanic);
584                    return;
585                }
586                Ok(r) => r,
587            };
588            sql_result(ctx, args, t);
589        }
590
591        let boxed_f: *mut F = Box::into_raw(Box::new(x_func));
592        let c_name = fn_name.as_cstr()?;
593        let r = unsafe {
594            ffi::sqlite3_create_function_v2(
595                self.db(),
596                c_name.as_ptr(),
597                n_arg,
598                flags.bits(),
599                boxed_f.cast::<c_void>(),
600                Some(call_boxed_closure::<F, T>),
601                None,
602                None,
603                Some(free_boxed_value::<F>),
604            )
605        };
606        self.decode_result(r)
607    }
608
609    fn create_aggregate_function<A, D, N: Name, T>(
610        &mut self,
611        fn_name: N,
612        n_arg: c_int,
613        flags: FunctionFlags,
614        aggr: D,
615    ) -> Result<()>
616    where
617        A: RefUnwindSafe + UnwindSafe,
618        D: Aggregate<A, T> + 'static,
619        T: SqlFnOutput,
620    {
621        let boxed_aggr: *mut D = Box::into_raw(Box::new(aggr));
622        let c_name = fn_name.as_cstr()?;
623        let r = unsafe {
624            ffi::sqlite3_create_function_v2(
625                self.db(),
626                c_name.as_ptr(),
627                n_arg,
628                flags.bits(),
629                boxed_aggr.cast::<c_void>(),
630                None,
631                Some(call_boxed_step::<A, D, T>),
632                Some(call_boxed_final::<A, D, T>),
633                Some(free_boxed_value::<D>),
634            )
635        };
636        self.decode_result(r)
637    }
638
639    #[cfg(feature = "window")]
640    fn create_window_function<A, N: Name, W, T>(
641        &mut self,
642        fn_name: N,
643        n_arg: c_int,
644        flags: FunctionFlags,
645        aggr: W,
646    ) -> Result<()>
647    where
648        A: RefUnwindSafe + UnwindSafe,
649        W: WindowAggregate<A, T> + 'static,
650        T: SqlFnOutput,
651    {
652        let boxed_aggr: *mut W = Box::into_raw(Box::new(aggr));
653        let c_name = fn_name.as_cstr()?;
654        let r = unsafe {
655            ffi::sqlite3_create_window_function(
656                self.db(),
657                c_name.as_ptr(),
658                n_arg,
659                flags.bits(),
660                boxed_aggr.cast::<c_void>(),
661                Some(call_boxed_step::<A, W, T>),
662                Some(call_boxed_final::<A, W, T>),
663                Some(call_boxed_value::<A, W, T>),
664                Some(call_boxed_inverse::<A, W, T>),
665                Some(free_boxed_value::<W>),
666            )
667        };
668        self.decode_result(r)
669    }
670
671    fn remove_function<N: Name>(&mut self, fn_name: N, n_arg: c_int) -> Result<()> {
672        let c_name = fn_name.as_cstr()?;
673        let r = unsafe {
674            ffi::sqlite3_create_function_v2(
675                self.db(),
676                c_name.as_ptr(),
677                n_arg,
678                ffi::SQLITE_UTF8,
679                ptr::null_mut(),
680                None,
681                None,
682                None,
683                None,
684            )
685        };
686        self.decode_result(r)
687    }
688}
689
690unsafe fn aggregate_context<A>(ctx: *mut sqlite3_context, bytes: usize) -> Option<*mut *mut A> {
691    let pac = ffi::sqlite3_aggregate_context(ctx, bytes as c_int) as *mut *mut A;
692    if pac.is_null() {
693        return None;
694    }
695    Some(pac)
696}
697
698unsafe extern "C" fn call_boxed_step<A, D, T>(
699    ctx: *mut sqlite3_context,
700    argc: c_int,
701    argv: *mut *mut sqlite3_value,
702) where
703    A: RefUnwindSafe + UnwindSafe,
704    D: Aggregate<A, T>,
705    T: SqlFnOutput,
706{
707    let Some(pac) = aggregate_context(ctx, size_of::<*mut A>()) else {
708        ffi::sqlite3_result_error_nomem(ctx);
709        return;
710    };
711
712    let r = catch_unwind(|| {
713        let boxed_aggr: *mut D = ffi::sqlite3_user_data(ctx).cast::<D>();
714        assert!(
715            !boxed_aggr.is_null(),
716            "Internal error - null aggregate pointer"
717        );
718        let mut ctx = Context {
719            ctx,
720            args: slice::from_raw_parts(argv, argc as usize),
721        };
722
723        #[expect(clippy::unnecessary_cast)]
724        if (*pac as *mut A).is_null() {
725            *pac = Box::into_raw(Box::new((*boxed_aggr).init(&mut ctx)?));
726        }
727
728        (*boxed_aggr).step(&mut ctx, &mut **pac)
729    });
730    let r = match r {
731        Err(_) => {
732            report_error(ctx, &Error::UnwindingPanic);
733            return;
734        }
735        Ok(r) => r,
736    };
737    match r {
738        Ok(_) => {}
739        Err(err) => report_error(ctx, &err),
740    };
741}
742
743#[cfg(feature = "window")]
744unsafe extern "C" fn call_boxed_inverse<A, W, T>(
745    ctx: *mut sqlite3_context,
746    argc: c_int,
747    argv: *mut *mut sqlite3_value,
748) where
749    A: RefUnwindSafe + UnwindSafe,
750    W: WindowAggregate<A, T>,
751    T: SqlFnOutput,
752{
753    let Some(pac) = aggregate_context(ctx, size_of::<*mut A>()) else {
754        ffi::sqlite3_result_error_nomem(ctx);
755        return;
756    };
757
758    let r = catch_unwind(|| {
759        let boxed_aggr: *mut W = ffi::sqlite3_user_data(ctx).cast::<W>();
760        assert!(
761            !boxed_aggr.is_null(),
762            "Internal error - null aggregate pointer"
763        );
764        let mut ctx = Context {
765            ctx,
766            args: slice::from_raw_parts(argv, argc as usize),
767        };
768        (*boxed_aggr).inverse(&mut ctx, &mut **pac)
769    });
770    let r = match r {
771        Err(_) => {
772            report_error(ctx, &Error::UnwindingPanic);
773            return;
774        }
775        Ok(r) => r,
776    };
777    match r {
778        Ok(_) => {}
779        Err(err) => report_error(ctx, &err),
780    };
781}
782
783unsafe extern "C" fn call_boxed_final<A, D, T>(ctx: *mut sqlite3_context)
784where
785    A: RefUnwindSafe + UnwindSafe,
786    D: Aggregate<A, T>,
787    T: SqlFnOutput,
788{
789    // Within the xFinal callback, it is customary to set N=0 in calls to
790    // sqlite3_aggregate_context(C,N) so that no pointless memory allocations occur.
791    let a: Option<A> = match aggregate_context(ctx, 0) {
792        Some(pac) =>
793        {
794            #[expect(clippy::unnecessary_cast)]
795            if (*pac as *mut A).is_null() {
796                None
797            } else {
798                let a = Box::from_raw(*pac);
799                Some(*a)
800            }
801        }
802        None => None,
803    };
804
805    let r = catch_unwind(|| {
806        let boxed_aggr: *mut D = ffi::sqlite3_user_data(ctx).cast::<D>();
807        assert!(
808            !boxed_aggr.is_null(),
809            "Internal error - null aggregate pointer"
810        );
811        let mut ctx = Context { ctx, args: &mut [] };
812        (*boxed_aggr).finalize(&mut ctx, a)
813    });
814    let t = match r {
815        Err(_) => {
816            report_error(ctx, &Error::UnwindingPanic);
817            return;
818        }
819        Ok(r) => r,
820    };
821    sql_result(ctx, &[], t);
822}
823
824#[cfg(feature = "window")]
825unsafe extern "C" fn call_boxed_value<A, W, T>(ctx: *mut sqlite3_context)
826where
827    A: RefUnwindSafe + UnwindSafe,
828    W: WindowAggregate<A, T>,
829    T: SqlFnOutput,
830{
831    // Within the xValue callback, it is customary to set N=0 in calls to
832    // sqlite3_aggregate_context(C,N) so that no pointless memory allocations occur.
833    let pac = aggregate_context(ctx, 0).filter(|&pac| {
834        #[expect(clippy::unnecessary_cast)]
835        !(*pac as *mut A).is_null()
836    });
837
838    let r = catch_unwind(|| {
839        let boxed_aggr: *mut W = ffi::sqlite3_user_data(ctx).cast::<W>();
840        assert!(
841            !boxed_aggr.is_null(),
842            "Internal error - null aggregate pointer"
843        );
844        (*boxed_aggr).value(pac.map(|pac| &mut **pac))
845    });
846    let t = match r {
847        Err(_) => {
848            report_error(ctx, &Error::UnwindingPanic);
849            return;
850        }
851        Ok(r) => r,
852    };
853    sql_result(ctx, &[], t);
854}
855
#[cfg(test)]
mod test {
    // Tests for registering scalar, aggregate and window SQL functions.
    use regex::Regex;
    use std::ffi::{c_double, c_uint};

    #[cfg(feature = "window")]
    use crate::functions::WindowAggregate;
    use crate::functions::{Aggregate, Context, FunctionFlags, SqlFnArg, SubType};
    use crate::{Connection, Error, Result};

    /// Scalar function returning its single numeric argument divided by two.
    ///
    /// Also asserts that the context exposes a connection handle.
    fn half(ctx: &Context<'_>) -> Result<c_double> {
        assert!(!ctx.is_empty());
        assert_eq!(ctx.len(), 1, "called with unexpected number of arguments");
        // `get_connection` is `unsafe`; the handle is only inspected here,
        // inside the callback, and never stored.
        assert!(unsafe {
            ctx.get_connection()
                .as_ref()
                .map(::std::ops::Deref::deref)
                .is_ok()
        });
        let value = ctx.get::<c_double>(0)?;
        Ok(value / 2f64)
    }

    /// A registered scalar function can be called from SQL.
    #[test]
    fn test_function_half() -> Result<()> {
        let db = Connection::open_in_memory()?;
        db.create_scalar_function(
            c"half",
            1,
            FunctionFlags::SQLITE_UTF8 | FunctionFlags::SQLITE_DETERMINISTIC,
            half,
        )?;
        let result: f64 = db.one_column("SELECT half(6)", [])?;

        assert!((3f64 - result).abs() < f64::EPSILON);
        Ok(())
    }

    /// After `remove_function`, calling the function from SQL fails.
    #[test]
    fn test_remove_function() -> Result<()> {
        let db = Connection::open_in_memory()?;
        db.create_scalar_function(
            c"half",
            1,
            FunctionFlags::SQLITE_UTF8 | FunctionFlags::SQLITE_DETERMINISTIC,
            half,
        )?;
        assert!((3f64 - db.one_column::<f64, _>("SELECT half(6)", [])?).abs() < f64::EPSILON);

        db.remove_function(c"half", 1)?;
        db.one_column::<f64, _>("SELECT half(6)", []).unwrap_err();
        Ok(())
    }

    // This implementation of a regexp scalar function uses SQLite's auxiliary data
    // (https://www.sqlite.org/c3ref/get_auxdata.html) to avoid recompiling the regular
    // expression multiple times within one query.
    fn regexp_with_auxiliary(ctx: &Context<'_>) -> Result<bool> {
        assert_eq!(ctx.len(), 2, "called with unexpected number of arguments");
        type BoxError = Box<dyn std::error::Error + Send + Sync + 'static>;
        let regexp: std::sync::Arc<Regex> = ctx
            .get_or_create_aux(0, |vr| -> Result<_, BoxError> {
                Ok(Regex::new(vr.as_str()?)?)
            })?;

        let is_match = {
            let text = ctx
                .get_raw(1)
                .as_str()
                .map_err(|e| Error::UserFunctionError(e.into()))?;

            regexp.is_match(text)
        };

        Ok(is_match)
    }

    /// The aux-data backed regexp works for a single call and across table rows.
    #[test]
    fn test_function_regexp_with_auxiliary() -> Result<()> {
        let db = Connection::open_in_memory()?;
        db.execute_batch(
            "BEGIN;
             CREATE TABLE foo (x string);
             INSERT INTO foo VALUES ('lisa');
             INSERT INTO foo VALUES ('lXsi');
             INSERT INTO foo VALUES ('lisX');
             END;",
        )?;
        db.create_scalar_function(
            c"regexp",
            2,
            FunctionFlags::SQLITE_UTF8 | FunctionFlags::SQLITE_DETERMINISTIC,
            regexp_with_auxiliary,
        )?;

        assert!(db.one_column::<bool, _>("SELECT regexp('l.s[aeiouy]', 'lisa')", [])?);

        assert_eq!(
            2,
            db.one_column::<i64, _>(
                "SELECT COUNT(*) FROM foo WHERE regexp('l.s[aeiouy]', x) == 1",
                [],
            )?
        );
        Ok(())
    }

    /// A function registered with `n_arg = -1` accepts any number of arguments.
    #[test]
    fn test_varargs_function() -> Result<()> {
        let db = Connection::open_in_memory()?;
        db.create_scalar_function(
            c"my_concat",
            -1,
            FunctionFlags::SQLITE_UTF8 | FunctionFlags::SQLITE_DETERMINISTIC,
            |ctx| {
                let mut ret = String::new();

                for idx in 0..ctx.len() {
                    let s = ctx.get::<String>(idx)?;
                    ret.push_str(&s);
                }

                Ok(ret)
            },
        )?;

        for &(expected, query) in &[
            ("", "SELECT my_concat()"),
            ("onetwo", "SELECT my_concat('one', 'two')"),
            ("abc", "SELECT my_concat('a', 'b', 'c')"),
        ] {
            assert_eq!(expected, db.one_column::<String, _>(query, [])?);
        }
        Ok(())
    }

    /// `get_aux` reports `GetAuxWrongType` when asked for a type other than the
    /// one stored with `set_aux`, and returns the stored value for the right type.
    #[test]
    fn test_get_aux_type_checking() -> Result<()> {
        let db = Connection::open_in_memory()?;
        db.create_scalar_function(c"example", 2, FunctionFlags::default(), |ctx| {
            if !ctx.get::<bool>(1)? {
                ctx.set_aux::<i64>(0, 100)?;
            } else {
                assert_eq!(ctx.get_aux::<String>(0), Err(Error::GetAuxWrongType));
                assert_eq!(*ctx.get_aux::<i64>(0)?.unwrap(), 100);
            }
            Ok(true)
        })?;

        // Step through *every* row. `query_row` only evaluates the first row
        // (i = 0, the `set_aux` branch), so the `get_aux` assertions above would
        // never run. Like before, this relies on the UNION producing i = 0 first.
        let mut stmt = db.prepare("SELECT example(0, i) FROM (SELECT 0 as i UNION SELECT 1)")?;
        let results: Vec<bool> = stmt.query_map([], |r| r.get(0))?.collect::<Result<_>>()?;
        // Both rows must have been evaluated, i.e. both branches exercised.
        assert_eq!(results, [true, true]);
        Ok(())
    }

    /// Aggregate summing its integer argument; `NULL` when there are no rows.
    struct Sum;
    /// Aggregate counting rows; `0` when there are no rows.
    struct Count;

    impl Aggregate<i64, Option<i64>> for Sum {
        fn init(&self, _: &mut Context<'_>) -> Result<i64> {
            Ok(0)
        }

        fn step(&self, ctx: &mut Context<'_>, sum: &mut i64) -> Result<()> {
            *sum += ctx.get::<i64>(0)?;
            Ok(())
        }

        fn finalize(&self, _: &mut Context<'_>, sum: Option<i64>) -> Result<Option<i64>> {
            Ok(sum)
        }
    }

    impl Aggregate<i64, i64> for Count {
        fn init(&self, _: &mut Context<'_>) -> Result<i64> {
            Ok(0)
        }

        fn step(&self, _ctx: &mut Context<'_>, sum: &mut i64) -> Result<()> {
            *sum += 1;
            Ok(())
        }

        fn finalize(&self, _: &mut Context<'_>, sum: Option<i64>) -> Result<i64> {
            Ok(sum.unwrap_or(0))
        }
    }

    /// Aggregate `finalize` receives `None` when no rows were stepped.
    #[test]
    fn test_sum() -> Result<()> {
        let db = Connection::open_in_memory()?;
        db.create_aggregate_function(
            c"my_sum",
            1,
            FunctionFlags::SQLITE_UTF8 | FunctionFlags::SQLITE_DETERMINISTIC,
            Sum,
        )?;

        // sum should return NULL when given no columns (contrast with count below)
        let no_result = "SELECT my_sum(i) FROM (SELECT 2 AS i WHERE 1 <> 1)";
        assert!(db.one_column::<Option<i64>, _>(no_result, [])?.is_none());

        let single_sum = "SELECT my_sum(i) FROM (SELECT 2 AS i UNION ALL SELECT 2)";
        assert_eq!(4, db.one_column::<i64, _>(single_sum, [])?);

        let dual_sum = "SELECT my_sum(i), my_sum(j) FROM (SELECT 2 AS i, 1 AS j UNION ALL SELECT \
                        2, 1)";
        let result: (i64, i64) = db.query_row(dual_sum, [], |r| Ok((r.get(0)?, r.get(1)?)))?;
        assert_eq!((4, 2), result);
        Ok(())
    }

    /// A variadic aggregate can map "no rows" to a non-NULL value.
    #[test]
    fn test_count() -> Result<()> {
        let db = Connection::open_in_memory()?;
        db.create_aggregate_function(
            c"my_count",
            -1,
            FunctionFlags::SQLITE_UTF8 | FunctionFlags::SQLITE_DETERMINISTIC,
            Count,
        )?;

        // count should return 0 when given no columns (contrast with sum above)
        let no_result = "SELECT my_count(i) FROM (SELECT 2 AS i WHERE 1 <> 1)";
        assert_eq!(db.one_column::<i64, _>(no_result, [])?, 0);

        let single_sum = "SELECT my_count(i) FROM (SELECT 2 AS i UNION ALL SELECT 2)";
        assert_eq!(2, db.one_column::<i64, _>(single_sum, [])?);
        Ok(())
    }

    #[cfg(feature = "window")]
    impl WindowAggregate<i64, Option<i64>> for Sum {
        fn inverse(&self, ctx: &mut Context<'_>, sum: &mut i64) -> Result<()> {
            *sum -= ctx.get::<i64>(0)?;
            Ok(())
        }

        fn value(&self, sum: Option<&mut i64>) -> Result<Option<i64>> {
            Ok(sum.copied())
        }
    }

    /// A sliding window sum uses both `step` and `inverse`.
    #[test]
    #[cfg(feature = "window")]
    fn test_window() -> Result<()> {
        use fallible_iterator::FallibleIterator;

        let db = Connection::open_in_memory()?;
        db.create_window_function(
            c"sumint",
            1,
            FunctionFlags::SQLITE_UTF8 | FunctionFlags::SQLITE_DETERMINISTIC,
            Sum,
        )?;
        db.execute_batch(
            "CREATE TABLE t3(x, y);
             INSERT INTO t3 VALUES('a', 4),
                     ('b', 5),
                     ('c', 3),
                     ('d', 8),
                     ('e', 1);",
        )?;

        let mut stmt = db.prepare(
            "SELECT x, sumint(y) OVER (
                   ORDER BY x ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING
                 ) AS sum_y
                 FROM t3 ORDER BY x;",
        )?;

        let results: Vec<(String, i64)> = stmt
            .query([])?
            .map(|row| Ok((row.get("x")?, row.get("sum_y")?)))
            .collect()?;
        let expected = vec![
            ("a".to_owned(), 9),
            ("b".to_owned(), 12),
            ("c".to_owned(), 16),
            ("d".to_owned(), 12),
            ("e".to_owned(), 9),
        ];
        assert_eq!(expected, results);
        Ok(())
    }

    /// A subtype set by one function is visible to another function.
    #[test]
    fn test_sub_type() -> Result<()> {
        fn test_getsubtype(ctx: &Context<'_>) -> Result<i32> {
            Ok(ctx.get_subtype(0) as i32)
        }
        fn test_setsubtype(ctx: &Context<'_>) -> Result<(SqlFnArg, SubType)> {
            let value = ctx.get_arg(0);
            let sub_type = ctx.get::<c_uint>(1)?;
            Ok((value, Some(sub_type)))
        }
        let db = Connection::open_in_memory()?;
        db.create_scalar_function(
            c"test_getsubtype",
            1,
            FunctionFlags::SQLITE_UTF8,
            test_getsubtype,
        )?;
        db.create_scalar_function(
            c"test_setsubtype",
            2,
            FunctionFlags::SQLITE_UTF8 | FunctionFlags::SQLITE_RESULT_SUBTYPE,
            test_setsubtype,
        )?;
        let result: i32 = db.one_column("SELECT test_getsubtype('hello');", [])?;
        assert_eq!(0, result);

        let result: i32 =
            db.one_column("SELECT test_getsubtype(test_setsubtype('hello',123));", [])?;
        assert_eq!(123, result);

        Ok(())
    }

    /// Blob arguments, including empty blobs and NULL, are read as bytes.
    #[test]
    fn test_blob() -> Result<()> {
        fn test_len(ctx: &Context<'_>) -> Result<usize> {
            let blob = ctx.get_raw(0);
            Ok(blob.as_bytes_or_null()?.map_or(0, |b| b.len()))
        }
        let db = Connection::open_in_memory()?;
        db.create_scalar_function(
            c"test_len",
            1,
            FunctionFlags::SQLITE_UTF8 | FunctionFlags::SQLITE_DETERMINISTIC,
            test_len,
        )?;
        assert_eq!(
            6,
            db.one_column::<usize, _>("SELECT test_len(X'53514C697465');", [])?
        );
        assert_eq!(0, db.one_column::<usize, _>("SELECT test_len(X'');", [])?);
        assert_eq!(0, db.one_column::<usize, _>("SELECT test_len(NULL);", [])?);
        Ok(())
    }
}