1use std::any::Any;
56use std::ffi::{c_int, c_uint, c_void};
57use std::marker::PhantomData;
58use std::ops::Deref;
59use std::panic::{catch_unwind, RefUnwindSafe, UnwindSafe};
60use std::ptr;
61use std::slice;
62use std::sync::Arc;
63
64use crate::ffi;
65use crate::ffi::sqlite3_context;
66use crate::ffi::sqlite3_value;
67
68use crate::context::set_result;
69use crate::types::{FromSql, FromSqlError, ToSql, ToSqlOutput, ValueRef};
70use crate::util::free_boxed_value;
71use crate::{str_to_cstring, Connection, Error, InnerConnection, Name, Result};
72
73unsafe fn report_error(ctx: *mut sqlite3_context, err: &Error) {
74 if let Error::SqliteFailure(ref err, ref s) = *err {
75 ffi::sqlite3_result_error_code(ctx, err.extended_code);
76 if let Some(Ok(cstr)) = s.as_ref().map(|s| str_to_cstring(s)) {
77 ffi::sqlite3_result_error(ctx, cstr.as_ptr(), -1);
78 }
79 } else {
80 ffi::sqlite3_result_error_code(ctx, ffi::SQLITE_CONSTRAINT_FUNCTION);
81 if let Ok(cstr) = str_to_cstring(&err.to_string()) {
82 ffi::sqlite3_result_error(ctx, cstr.as_ptr(), -1);
83 }
84 }
85}
86
87pub struct Context<'a> {
90 ctx: *mut sqlite3_context,
91 args: &'a [*mut sqlite3_value],
92}
93
94impl Context<'_> {
95 #[inline]
97 #[must_use]
98 pub fn len(&self) -> usize {
99 self.args.len()
100 }
101
102 #[inline]
104 #[must_use]
105 pub fn is_empty(&self) -> bool {
106 self.args.is_empty()
107 }
108
109 pub fn get<T: FromSql>(&self, idx: usize) -> Result<T> {
119 let arg = self.args[idx];
120 let value = unsafe { ValueRef::from_value(arg) };
121 FromSql::column_result(value).map_err(|err| match err {
122 FromSqlError::InvalidType => {
123 Error::InvalidFunctionParameterType(idx, value.data_type())
124 }
125 FromSqlError::OutOfRange(i) => Error::IntegralValueOutOfRange(idx, i),
126 FromSqlError::Other(err) => {
127 Error::FromSqlConversionFailure(idx, value.data_type(), err)
128 }
129 FromSqlError::InvalidBlobSize { .. } => {
130 Error::FromSqlConversionFailure(idx, value.data_type(), Box::new(err))
131 }
132 })
133 }
134
135 #[inline]
142 #[must_use]
143 pub fn get_raw(&self, idx: usize) -> ValueRef<'_> {
144 let arg = self.args[idx];
145 unsafe { ValueRef::from_value(arg) }
146 }
147
148 #[inline]
151 #[must_use]
152 pub fn get_arg(&self, idx: usize) -> SqlFnArg {
153 assert!(idx < self.len());
154 SqlFnArg { idx }
155 }
156
157 pub fn get_subtype(&self, idx: usize) -> c_uint {
164 let arg = self.args[idx];
165 unsafe { ffi::sqlite3_value_subtype(arg) }
166 }
167
168 pub fn get_or_create_aux<T, E, F>(&self, arg: c_int, func: F) -> Result<Arc<T>>
181 where
182 T: Send + Sync + 'static,
183 E: Into<Box<dyn std::error::Error + Send + Sync + 'static>>,
184 F: FnOnce(ValueRef<'_>) -> Result<T, E>,
185 {
186 if let Some(v) = self.get_aux(arg)? {
187 Ok(v)
188 } else {
189 let vr = self.get_raw(arg as usize);
190 self.set_aux(
191 arg,
192 func(vr).map_err(|e| Error::UserFunctionError(e.into()))?,
193 )
194 }
195 }
196
197 pub fn set_aux<T: Send + Sync + 'static>(&self, arg: c_int, value: T) -> Result<Arc<T>> {
206 assert!(arg < self.len() as i32);
207 let orig: Arc<T> = Arc::new(value);
208 let inner: AuxInner = orig.clone();
209 let outer = Box::new(inner);
210 let raw: *mut AuxInner = Box::into_raw(outer);
211 unsafe {
212 ffi::sqlite3_set_auxdata(
213 self.ctx,
214 arg,
215 raw.cast(),
216 Some(free_boxed_value::<AuxInner>),
217 );
218 };
219 Ok(orig)
220 }
221
222 pub fn get_aux<T: Send + Sync + 'static>(&self, arg: c_int) -> Result<Option<Arc<T>>> {
232 assert!(arg < self.len() as i32);
233 let p = unsafe { ffi::sqlite3_get_auxdata(self.ctx, arg) as *const AuxInner };
234 if p.is_null() {
235 Ok(None)
236 } else {
237 let v: AuxInner = AuxInner::clone(unsafe { &*p });
238 v.downcast::<T>()
239 .map(Some)
240 .map_err(|_| Error::GetAuxWrongType)
241 }
242 }
243
244 pub unsafe fn get_connection(&self) -> Result<ConnectionRef<'_>> {
251 let handle = ffi::sqlite3_context_db_handle(self.ctx);
252 Ok(ConnectionRef {
253 conn: Connection::from_handle(handle)?,
254 phantom: PhantomData,
255 })
256 }
257}
258
259pub struct ConnectionRef<'ctx> {
261 conn: Connection,
264 phantom: PhantomData<&'ctx Context<'ctx>>,
265}
266
267impl Deref for ConnectionRef<'_> {
268 type Target = Connection;
269
270 #[inline]
271 fn deref(&self) -> &Connection {
272 &self.conn
273 }
274}
275
/// Type-erased, reference-counted form in which auxiliary data is stored.
type AuxInner = Arc<dyn Any + Send + Sync + 'static>;
277
278pub type SubType = Option<c_uint>;
280
281pub trait SqlFnOutput {
283 fn to_sql(&self) -> Result<(ToSqlOutput<'_>, SubType)>;
285}
286
287impl<T: ToSql> SqlFnOutput for T {
288 #[inline]
289 fn to_sql(&self) -> Result<(ToSqlOutput<'_>, SubType)> {
290 ToSql::to_sql(self).map(|o| (o, None))
291 }
292}
293
294impl<T: ToSql> SqlFnOutput for (T, SubType) {
295 fn to_sql(&self) -> Result<(ToSqlOutput<'_>, SubType)> {
296 ToSql::to_sql(&self.0).map(|o| (o, self.1))
297 }
298}
299
300pub struct SqlFnArg {
302 idx: usize,
303}
304impl ToSql for SqlFnArg {
305 fn to_sql(&self) -> Result<ToSqlOutput<'_>> {
306 Ok(ToSqlOutput::Arg(self.idx))
307 }
308}
309
310unsafe fn sql_result<T: SqlFnOutput>(
311 ctx: *mut sqlite3_context,
312 args: &[*mut sqlite3_value],
313 r: Result<T>,
314) {
315 let t = r.as_ref().map(SqlFnOutput::to_sql);
316
317 match t {
318 Ok(Ok((ref value, sub_type))) => {
319 set_result(ctx, args, value);
320 if let Some(sub_type) = sub_type {
321 ffi::sqlite3_result_subtype(ctx, sub_type);
322 }
323 }
324 Ok(Err(err)) => report_error(ctx, &err),
325 Err(err) => report_error(ctx, err),
326 };
327}
328
329pub trait Aggregate<A, T>
335where
336 A: RefUnwindSafe + UnwindSafe,
337 T: SqlFnOutput,
338{
339 fn init(&self, ctx: &mut Context<'_>) -> Result<A>;
344
345 fn step(&self, ctx: &mut Context<'_>, acc: &mut A) -> Result<()>;
348
349 fn finalize(&self, ctx: &mut Context<'_>, acc: Option<A>) -> Result<T>;
359}
360
/// Additional callbacks an [`Aggregate`] needs to be usable as a window
/// function.
#[cfg(feature = "window")]
pub trait WindowAggregate<A, T>: Aggregate<A, T>
where
    A: RefUnwindSafe + UnwindSafe,
    T: SqlFnOutput,
{
    /// Returns the current value of the aggregate without consuming the
    /// accumulator; `acc` is `None` when no accumulator exists.
    fn value(&self, acc: Option<&mut A>) -> Result<T>;

    /// Removes the arguments available through `ctx` from the accumulator
    /// `acc`.
    fn inverse(&self, ctx: &mut Context<'_>, acc: &mut A) -> Result<()>;
}
376
377bitflags::bitflags! {
378 #[derive(Clone, Copy, Debug)]
382 #[repr(C)]
383 pub struct FunctionFlags: c_int {
384 const SQLITE_UTF8 = ffi::SQLITE_UTF8;
386 const SQLITE_UTF16LE = ffi::SQLITE_UTF16LE;
388 const SQLITE_UTF16BE = ffi::SQLITE_UTF16BE;
390 const SQLITE_UTF16 = ffi::SQLITE_UTF16;
392 const SQLITE_DETERMINISTIC = ffi::SQLITE_DETERMINISTIC; const SQLITE_DIRECTONLY = 0x0000_0008_0000; const SQLITE_SUBTYPE = 0x0000_0010_0000; const SQLITE_INNOCUOUS = 0x0000_0020_0000; const SQLITE_RESULT_SUBTYPE = 0x0000_0100_0000; const SQLITE_SELFORDER1 = 0x0000_0200_0000; }
405}
406
407impl Default for FunctionFlags {
408 #[inline]
409 fn default() -> Self {
410 Self::SQLITE_UTF8
411 }
412}
413
414impl Connection {
415 #[inline]
453 pub fn create_scalar_function<F, N: Name, T>(
454 &self,
455 fn_name: N,
456 n_arg: c_int,
457 flags: FunctionFlags,
458 x_func: F,
459 ) -> Result<()>
460 where
461 F: Fn(&Context<'_>) -> Result<T> + Send + 'static,
462 T: SqlFnOutput,
463 {
464 self.db
465 .borrow_mut()
466 .create_scalar_function(fn_name, n_arg, flags, x_func)
467 }
468
469 #[inline]
476 pub fn create_aggregate_function<A, D, N: Name, T>(
477 &self,
478 fn_name: N,
479 n_arg: c_int,
480 flags: FunctionFlags,
481 aggr: D,
482 ) -> Result<()>
483 where
484 A: RefUnwindSafe + UnwindSafe,
485 D: Aggregate<A, T> + 'static,
486 T: SqlFnOutput,
487 {
488 self.db
489 .borrow_mut()
490 .create_aggregate_function(fn_name, n_arg, flags, aggr)
491 }
492
493 #[cfg(feature = "window")]
499 #[inline]
500 pub fn create_window_function<A, N: Name, W, T>(
501 &self,
502 fn_name: N,
503 n_arg: c_int,
504 flags: FunctionFlags,
505 aggr: W,
506 ) -> Result<()>
507 where
508 A: RefUnwindSafe + UnwindSafe,
509 W: WindowAggregate<A, T> + 'static,
510 T: SqlFnOutput,
511 {
512 self.db
513 .borrow_mut()
514 .create_window_function(fn_name, n_arg, flags, aggr)
515 }
516
517 #[inline]
528 pub fn remove_function<N: Name>(&self, fn_name: N, n_arg: c_int) -> Result<()> {
529 self.db.borrow_mut().remove_function(fn_name, n_arg)
530 }
531}
532
533impl InnerConnection {
534 fn create_scalar_function<F, N: Name, T>(
556 &mut self,
557 fn_name: N,
558 n_arg: c_int,
559 flags: FunctionFlags,
560 x_func: F,
561 ) -> Result<()>
562 where
563 F: Fn(&Context<'_>) -> Result<T> + Send + 'static,
564 T: SqlFnOutput,
565 {
566 unsafe extern "C" fn call_boxed_closure<F, T>(
567 ctx: *mut sqlite3_context,
568 argc: c_int,
569 argv: *mut *mut sqlite3_value,
570 ) where
571 F: Fn(&Context<'_>) -> Result<T>,
572 T: SqlFnOutput,
573 {
574 let args = slice::from_raw_parts(argv, argc as usize);
575 let r = catch_unwind(|| {
576 let boxed_f: *const F = ffi::sqlite3_user_data(ctx).cast::<F>();
577 assert!(!boxed_f.is_null(), "Internal error - null function pointer");
578 let ctx = Context { ctx, args };
579 (*boxed_f)(&ctx)
580 });
581 let t = match r {
582 Err(_) => {
583 report_error(ctx, &Error::UnwindingPanic);
584 return;
585 }
586 Ok(r) => r,
587 };
588 sql_result(ctx, args, t);
589 }
590
591 let boxed_f: *mut F = Box::into_raw(Box::new(x_func));
592 let c_name = fn_name.as_cstr()?;
593 let r = unsafe {
594 ffi::sqlite3_create_function_v2(
595 self.db(),
596 c_name.as_ptr(),
597 n_arg,
598 flags.bits(),
599 boxed_f.cast::<c_void>(),
600 Some(call_boxed_closure::<F, T>),
601 None,
602 None,
603 Some(free_boxed_value::<F>),
604 )
605 };
606 self.decode_result(r)
607 }
608
609 fn create_aggregate_function<A, D, N: Name, T>(
610 &mut self,
611 fn_name: N,
612 n_arg: c_int,
613 flags: FunctionFlags,
614 aggr: D,
615 ) -> Result<()>
616 where
617 A: RefUnwindSafe + UnwindSafe,
618 D: Aggregate<A, T> + 'static,
619 T: SqlFnOutput,
620 {
621 let boxed_aggr: *mut D = Box::into_raw(Box::new(aggr));
622 let c_name = fn_name.as_cstr()?;
623 let r = unsafe {
624 ffi::sqlite3_create_function_v2(
625 self.db(),
626 c_name.as_ptr(),
627 n_arg,
628 flags.bits(),
629 boxed_aggr.cast::<c_void>(),
630 None,
631 Some(call_boxed_step::<A, D, T>),
632 Some(call_boxed_final::<A, D, T>),
633 Some(free_boxed_value::<D>),
634 )
635 };
636 self.decode_result(r)
637 }
638
639 #[cfg(feature = "window")]
640 fn create_window_function<A, N: Name, W, T>(
641 &mut self,
642 fn_name: N,
643 n_arg: c_int,
644 flags: FunctionFlags,
645 aggr: W,
646 ) -> Result<()>
647 where
648 A: RefUnwindSafe + UnwindSafe,
649 W: WindowAggregate<A, T> + 'static,
650 T: SqlFnOutput,
651 {
652 let boxed_aggr: *mut W = Box::into_raw(Box::new(aggr));
653 let c_name = fn_name.as_cstr()?;
654 let r = unsafe {
655 ffi::sqlite3_create_window_function(
656 self.db(),
657 c_name.as_ptr(),
658 n_arg,
659 flags.bits(),
660 boxed_aggr.cast::<c_void>(),
661 Some(call_boxed_step::<A, W, T>),
662 Some(call_boxed_final::<A, W, T>),
663 Some(call_boxed_value::<A, W, T>),
664 Some(call_boxed_inverse::<A, W, T>),
665 Some(free_boxed_value::<W>),
666 )
667 };
668 self.decode_result(r)
669 }
670
671 fn remove_function<N: Name>(&mut self, fn_name: N, n_arg: c_int) -> Result<()> {
672 let c_name = fn_name.as_cstr()?;
673 let r = unsafe {
674 ffi::sqlite3_create_function_v2(
675 self.db(),
676 c_name.as_ptr(),
677 n_arg,
678 ffi::SQLITE_UTF8,
679 ptr::null_mut(),
680 None,
681 None,
682 None,
683 None,
684 )
685 };
686 self.decode_result(r)
687 }
688}
689
690unsafe fn aggregate_context<A>(ctx: *mut sqlite3_context, bytes: usize) -> Option<*mut *mut A> {
691 let pac = ffi::sqlite3_aggregate_context(ctx, bytes as c_int) as *mut *mut A;
692 if pac.is_null() {
693 return None;
694 }
695 Some(pac)
696}
697
698unsafe extern "C" fn call_boxed_step<A, D, T>(
699 ctx: *mut sqlite3_context,
700 argc: c_int,
701 argv: *mut *mut sqlite3_value,
702) where
703 A: RefUnwindSafe + UnwindSafe,
704 D: Aggregate<A, T>,
705 T: SqlFnOutput,
706{
707 let Some(pac) = aggregate_context(ctx, size_of::<*mut A>()) else {
708 ffi::sqlite3_result_error_nomem(ctx);
709 return;
710 };
711
712 let r = catch_unwind(|| {
713 let boxed_aggr: *mut D = ffi::sqlite3_user_data(ctx).cast::<D>();
714 assert!(
715 !boxed_aggr.is_null(),
716 "Internal error - null aggregate pointer"
717 );
718 let mut ctx = Context {
719 ctx,
720 args: slice::from_raw_parts(argv, argc as usize),
721 };
722
723 #[expect(clippy::unnecessary_cast)]
724 if (*pac as *mut A).is_null() {
725 *pac = Box::into_raw(Box::new((*boxed_aggr).init(&mut ctx)?));
726 }
727
728 (*boxed_aggr).step(&mut ctx, &mut **pac)
729 });
730 let r = match r {
731 Err(_) => {
732 report_error(ctx, &Error::UnwindingPanic);
733 return;
734 }
735 Ok(r) => r,
736 };
737 match r {
738 Ok(_) => {}
739 Err(err) => report_error(ctx, &err),
740 };
741}
742
/// Inverse callback registered with SQLite for window functions.
///
/// Fetches the accumulator slot of the current aggregation and runs
/// [`WindowAggregate::inverse`] on the accumulator it points to. A missing
/// slot is reported as out of memory; errors and panics from the user code
/// are reported through `report_error`.
///
/// # Safety
///
/// Must only be invoked by SQLite as the inverse callback of a function whose
/// user data is a boxed `W`; `argv` must point to `argc` values.
#[cfg(feature = "window")]
unsafe extern "C" fn call_boxed_inverse<A, W, T>(
    ctx: *mut sqlite3_context,
    argc: c_int,
    argv: *mut *mut sqlite3_value,
) where
    A: RefUnwindSafe + UnwindSafe,
    W: WindowAggregate<A, T>,
    T: SqlFnOutput,
{
    let Some(slot) = aggregate_context::<A>(ctx, size_of::<*mut A>()) else {
        ffi::sqlite3_result_error_nomem(ctx);
        return;
    };

    // A panic must not unwind into SQLite's C frames.
    let outcome = catch_unwind(|| {
        let aggr: *mut W = ffi::sqlite3_user_data(ctx).cast::<W>();
        assert!(!aggr.is_null(), "Internal error - null aggregate pointer");
        let mut fn_ctx = Context {
            ctx,
            args: slice::from_raw_parts(argv, argc as usize),
        };
        (*aggr).inverse(&mut fn_ctx, &mut **slot)
    });

    match outcome {
        Ok(Ok(())) => {}
        Ok(Err(err)) => report_error(ctx, &err),
        Err(_) => report_error(ctx, &Error::UnwindingPanic),
    }
}
782
783unsafe extern "C" fn call_boxed_final<A, D, T>(ctx: *mut sqlite3_context)
784where
785 A: RefUnwindSafe + UnwindSafe,
786 D: Aggregate<A, T>,
787 T: SqlFnOutput,
788{
789 let a: Option<A> = match aggregate_context(ctx, 0) {
792 Some(pac) =>
793 {
794 #[expect(clippy::unnecessary_cast)]
795 if (*pac as *mut A).is_null() {
796 None
797 } else {
798 let a = Box::from_raw(*pac);
799 Some(*a)
800 }
801 }
802 None => None,
803 };
804
805 let r = catch_unwind(|| {
806 let boxed_aggr: *mut D = ffi::sqlite3_user_data(ctx).cast::<D>();
807 assert!(
808 !boxed_aggr.is_null(),
809 "Internal error - null aggregate pointer"
810 );
811 let mut ctx = Context { ctx, args: &mut [] };
812 (*boxed_aggr).finalize(&mut ctx, a)
813 });
814 let t = match r {
815 Err(_) => {
816 report_error(ctx, &Error::UnwindingPanic);
817 return;
818 }
819 Ok(r) => r,
820 };
821 sql_result(ctx, &[], t);
822}
823
/// Value callback registered with SQLite for window functions.
///
/// Looks up the accumulator (if one exists) and hands the outcome of
/// [`WindowAggregate::value`] back to SQLite, leaving the accumulator in
/// place. A panic in the user code is reported as [`Error::UnwindingPanic`].
///
/// # Safety
///
/// Must only be invoked by SQLite as the value callback of a function whose
/// user data is a boxed `W`.
#[cfg(feature = "window")]
unsafe extern "C" fn call_boxed_value<A, W, T>(ctx: *mut sqlite3_context)
where
    A: RefUnwindSafe + UnwindSafe,
    W: WindowAggregate<A, T>,
    T: SqlFnOutput,
{
    // Zero bytes are requested, so this lookup never asks for a new context.
    let slot = aggregate_context::<A>(ctx, 0).filter(|&slot| !(*slot).is_null());

    // A panic must not unwind into SQLite's C frames.
    let outcome = catch_unwind(|| {
        let aggr: *mut W = ffi::sqlite3_user_data(ctx).cast::<W>();
        assert!(!aggr.is_null(), "Internal error - null aggregate pointer");
        (*aggr).value(slot.map(|slot| &mut **slot))
    });

    match outcome {
        Ok(result) => sql_result(ctx, &[], result),
        Err(_) => report_error(ctx, &Error::UnwindingPanic),
    }
}
855
#[cfg(test)]
mod test {
    use regex::Regex;
    use std::ffi::c_double;

    #[cfg(feature = "window")]
    use crate::functions::WindowAggregate;
    use crate::functions::{Aggregate, Context, FunctionFlags, SqlFnArg, SubType};
    use crate::{Connection, Error, Result};

    /// Scalar test function: halves its single floating point argument.
    fn half(ctx: &Context<'_>) -> Result<c_double> {
        assert!(!ctx.is_empty());
        assert_eq!(ctx.len(), 1, "called with unexpected number of arguments");
        assert!(unsafe {
            ctx.get_connection()
                .as_ref()
                .map(::std::ops::Deref::deref)
                .is_ok()
        });
        let value: c_double = ctx.get(0)?;
        Ok(value / 2f64)
    }

    /// A registered scalar function is callable from SQL.
    #[test]
    fn test_function_half() -> Result<()> {
        let db = Connection::open_in_memory()?;
        let flags = FunctionFlags::SQLITE_UTF8 | FunctionFlags::SQLITE_DETERMINISTIC;
        db.create_scalar_function(c"half", 1, flags, half)?;
        let result: f64 = db.one_column("SELECT half(6)", [])?;

        assert!((3f64 - result).abs() < f64::EPSILON);
        Ok(())
    }

    /// A removed function can no longer be called.
    #[test]
    fn test_remove_function() -> Result<()> {
        let db = Connection::open_in_memory()?;
        let flags = FunctionFlags::SQLITE_UTF8 | FunctionFlags::SQLITE_DETERMINISTIC;
        db.create_scalar_function(c"half", 1, flags, half)?;
        assert!((3f64 - db.one_column::<f64, _>("SELECT half(6)", [])?).abs() < f64::EPSILON);

        db.remove_function(c"half", 1)?;
        db.one_column::<f64, _>("SELECT half(6)", []).unwrap_err();
        Ok(())
    }

    /// `regexp(pattern, text)` that keeps the compiled pattern as auxiliary
    /// data of its first argument.
    fn regexp_with_auxiliary(ctx: &Context<'_>) -> Result<bool> {
        assert_eq!(ctx.len(), 2, "called with unexpected number of arguments");
        type BoxError = Box<dyn std::error::Error + Send + Sync + 'static>;
        let regexp: std::sync::Arc<Regex> = ctx
            .get_or_create_aux(0, |vr| -> Result<_, BoxError> {
                Ok(Regex::new(vr.as_str()?)?)
            })?;

        let text = ctx
            .get_raw(1)
            .as_str()
            .map_err(|e| Error::UserFunctionError(e.into()))?;

        Ok(regexp.is_match(text))
    }

    /// The auxiliary-data based `regexp` works for constants and for rows.
    #[test]
    fn test_function_regexp_with_auxiliary() -> Result<()> {
        let db = Connection::open_in_memory()?;
        db.execute_batch(
            "BEGIN;
             CREATE TABLE foo (x string);
             INSERT INTO foo VALUES ('lisa');
             INSERT INTO foo VALUES ('lXsi');
             INSERT INTO foo VALUES ('lisX');
             END;",
        )?;
        let flags = FunctionFlags::SQLITE_UTF8 | FunctionFlags::SQLITE_DETERMINISTIC;
        db.create_scalar_function(c"regexp", 2, flags, regexp_with_auxiliary)?;

        assert!(db.one_column::<bool, _>("SELECT regexp('l.s[aeiouy]', 'lisa')", [])?);

        assert_eq!(
            2,
            db.one_column::<i64, _>(
                "SELECT COUNT(*) FROM foo WHERE regexp('l.s[aeiouy]', x) == 1",
                [],
            )?
        );
        Ok(())
    }

    /// A function registered with `n_arg == -1` accepts any argument count.
    #[test]
    fn test_varargs_function() -> Result<()> {
        let db = Connection::open_in_memory()?;
        let flags = FunctionFlags::SQLITE_UTF8 | FunctionFlags::SQLITE_DETERMINISTIC;
        db.create_scalar_function(c"my_concat", -1, flags, |ctx| {
            // Concatenate all arguments, stopping at the first failure.
            (0..ctx.len())
                .map(|idx| ctx.get::<String>(idx))
                .collect::<Result<String>>()
        })?;

        let cases = [
            ("", "SELECT my_concat()"),
            ("onetwo", "SELECT my_concat('one', 'two')"),
            ("abc", "SELECT my_concat('a', 'b', 'c')"),
        ];
        for (expected, query) in cases {
            assert_eq!(expected, db.one_column::<String, _>(query, [])?);
        }
        Ok(())
    }

    /// Reading auxiliary data with the wrong type is an error, the right
    /// type yields the stored value.
    #[test]
    fn test_get_aux_type_checking() -> Result<()> {
        let db = Connection::open_in_memory()?;
        db.create_scalar_function(c"example", 2, FunctionFlags::default(), |ctx| {
            if ctx.get::<bool>(1)? {
                assert_eq!(ctx.get_aux::<String>(0), Err(Error::GetAuxWrongType));
                assert_eq!(*ctx.get_aux::<i64>(0)?.unwrap(), 100);
            } else {
                ctx.set_aux::<i64>(0, 100)?;
            }
            Ok(true)
        })?;

        let res: bool = db.query_row(
            "SELECT example(0, i) FROM (SELECT 0 as i UNION SELECT 1)",
            [],
            |r| r.get(0),
        )?;
        assert!(res);
        Ok(())
    }

    /// Aggregate summing its integer argument; `NULL` for no rows.
    struct Sum;
    /// Aggregate counting its invocations; `0` for no rows.
    struct Count;

    impl Aggregate<i64, Option<i64>> for Sum {
        fn init(&self, _: &mut Context<'_>) -> Result<i64> {
            Ok(0)
        }

        fn step(&self, ctx: &mut Context<'_>, sum: &mut i64) -> Result<()> {
            let term: i64 = ctx.get(0)?;
            *sum += term;
            Ok(())
        }

        fn finalize(&self, _: &mut Context<'_>, sum: Option<i64>) -> Result<Option<i64>> {
            Ok(sum)
        }
    }

    impl Aggregate<i64, i64> for Count {
        fn init(&self, _: &mut Context<'_>) -> Result<i64> {
            Ok(0)
        }

        fn step(&self, _ctx: &mut Context<'_>, sum: &mut i64) -> Result<()> {
            *sum += 1;
            Ok(())
        }

        fn finalize(&self, _: &mut Context<'_>, sum: Option<i64>) -> Result<i64> {
            Ok(sum.unwrap_or(0))
        }
    }

    #[test]
    fn test_sum() -> Result<()> {
        let db = Connection::open_in_memory()?;
        let flags = FunctionFlags::SQLITE_UTF8 | FunctionFlags::SQLITE_DETERMINISTIC;
        db.create_aggregate_function(c"my_sum", 1, flags, Sum)?;

        // No rows: `step` never runs, so `finalize` sees no accumulator.
        let no_result = "SELECT my_sum(i) FROM (SELECT 2 AS i WHERE 1 <> 1)";
        assert!(db.one_column::<Option<i64>, _>(no_result, [])?.is_none());

        let single_sum = "SELECT my_sum(i) FROM (SELECT 2 AS i UNION ALL SELECT 2)";
        assert_eq!(4, db.one_column::<i64, _>(single_sum, [])?);

        let dual_sum = "SELECT my_sum(i), my_sum(j) FROM (SELECT 2 AS i, 1 AS j UNION ALL SELECT \
                        2, 1)";
        let result: (i64, i64) = db.query_row(dual_sum, [], |r| Ok((r.get(0)?, r.get(1)?)))?;
        assert_eq!((4, 2), result);
        Ok(())
    }

    #[test]
    fn test_count() -> Result<()> {
        let db = Connection::open_in_memory()?;
        let flags = FunctionFlags::SQLITE_UTF8 | FunctionFlags::SQLITE_DETERMINISTIC;
        db.create_aggregate_function(c"my_count", -1, flags, Count)?;

        // No rows: `finalize` falls back to zero.
        let no_result = "SELECT my_count(i) FROM (SELECT 2 AS i WHERE 1 <> 1)";
        assert_eq!(db.one_column::<i64, _>(no_result, [])?, 0);

        let single_sum = "SELECT my_count(i) FROM (SELECT 2 AS i UNION ALL SELECT 2)";
        assert_eq!(2, db.one_column::<i64, _>(single_sum, [])?);
        Ok(())
    }

    #[cfg(feature = "window")]
    impl WindowAggregate<i64, Option<i64>> for Sum {
        fn value(&self, sum: Option<&mut i64>) -> Result<Option<i64>> {
            Ok(sum.copied())
        }

        fn inverse(&self, ctx: &mut Context<'_>, sum: &mut i64) -> Result<()> {
            let term: i64 = ctx.get(0)?;
            *sum -= term;
            Ok(())
        }
    }

    #[test]
    #[cfg(feature = "window")]
    fn test_window() -> Result<()> {
        use fallible_iterator::FallibleIterator;

        let db = Connection::open_in_memory()?;
        let flags = FunctionFlags::SQLITE_UTF8 | FunctionFlags::SQLITE_DETERMINISTIC;
        db.create_window_function(c"sumint", 1, flags, Sum)?;
        db.execute_batch(
            "CREATE TABLE t3(x, y);
             INSERT INTO t3 VALUES('a', 4),
                     ('b', 5),
                     ('c', 3),
                     ('d', 8),
                     ('e', 1);",
        )?;

        let mut stmt = db.prepare(
            "SELECT x, sumint(y) OVER (
                   ORDER BY x ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING
                 ) AS sum_y
                 FROM t3 ORDER BY x;",
        )?;

        let results: Vec<(String, i64)> = stmt
            .query([])?
            .map(|row| Ok((row.get("x")?, row.get("sum_y")?)))
            .collect()?;
        let expected = vec![
            ("a".to_owned(), 9),
            ("b".to_owned(), 12),
            ("c".to_owned(), 16),
            ("d".to_owned(), 12),
            ("e".to_owned(), 9),
        ];
        assert_eq!(expected, results);
        Ok(())
    }

    /// A subtype set on a function result is visible to the consuming
    /// function.
    #[test]
    fn test_sub_type() -> Result<()> {
        fn test_getsubtype(ctx: &Context<'_>) -> Result<i32> {
            Ok(ctx.get_subtype(0) as i32)
        }
        fn test_setsubtype(ctx: &Context<'_>) -> Result<(SqlFnArg, SubType)> {
            use std::ffi::c_uint;
            let value = ctx.get_arg(0);
            let sub_type: c_uint = ctx.get(1)?;
            Ok((value, Some(sub_type)))
        }

        let db = Connection::open_in_memory()?;
        db.create_scalar_function(
            c"test_getsubtype",
            1,
            FunctionFlags::SQLITE_UTF8,
            test_getsubtype,
        )?;
        db.create_scalar_function(
            c"test_setsubtype",
            2,
            FunctionFlags::SQLITE_UTF8 | FunctionFlags::SQLITE_RESULT_SUBTYPE,
            test_setsubtype,
        )?;

        let result: i32 = db.one_column("SELECT test_getsubtype('hello');", [])?;
        assert_eq!(0, result);

        let result: i32 =
            db.one_column("SELECT test_getsubtype(test_setsubtype('hello',123));", [])?;
        assert_eq!(123, result);

        Ok(())
    }

    /// Blob arguments are readable as bytes; empty and `NULL` count as zero.
    #[test]
    fn test_blob() -> Result<()> {
        fn test_len(ctx: &Context<'_>) -> Result<usize> {
            let blob = ctx.get_raw(0);
            Ok(blob.as_bytes_or_null()?.map_or(0, |b| b.len()))
        }

        let db = Connection::open_in_memory()?;
        db.create_scalar_function("test_len", 1, FunctionFlags::SQLITE_DETERMINISTIC, test_len)?;
        assert_eq!(
            6,
            db.one_column::<usize, _>("SELECT test_len(X'53514C697465');", [])?
        );
        assert_eq!(0, db.one_column::<usize, _>("SELECT test_len(X'');", [])?);
        assert_eq!(0, db.one_column::<usize, _>("SELECT test_len(NULL);", [])?);
        Ok(())
    }
}