rusqlite/
collation.rs

1//! Add, remove, or modify a collation
2use std::cmp::Ordering;
3use std::ffi::{c_char, c_int, c_void, CStr};
4use std::panic::catch_unwind;
5use std::ptr;
6use std::slice;
7
8use crate::ffi;
9use crate::util::free_boxed_value;
10use crate::{Connection, InnerConnection, Name, Result};
11
12impl Connection {
13    /// Add or modify a collation.
14    #[inline]
15    pub fn create_collation<C, N: Name>(&self, collation_name: N, x_compare: C) -> Result<()>
16    where
17        C: Fn(&str, &str) -> Ordering + Send + 'static,
18    {
19        self.db
20            .borrow_mut()
21            .create_collation(collation_name, x_compare)
22    }
23
24    /// Collation needed callback
25    #[inline]
26    pub fn collation_needed(&self, x_coll_needed: fn(&Self, &str) -> Result<()>) -> Result<()> {
27        self.db.borrow_mut().collation_needed(x_coll_needed)
28    }
29
30    /// Remove collation.
31    #[inline]
32    pub fn remove_collation<N: Name>(&self, collation_name: N) -> Result<()> {
33        self.db.borrow_mut().remove_collation(collation_name)
34    }
35}
36
37impl InnerConnection {
38    /// ```compile_fail
39    /// use rusqlite::{Connection, Result};
40    /// fn main() -> Result<()> {
41    ///     let db = Connection::open_in_memory()?;
42    ///     {
43    ///         let mut called = std::sync::atomic::AtomicBool::new(false);
44    ///         db.create_collation("foo", |_, _| {
45    ///             called.store(true, std::sync::atomic::Ordering::Relaxed);
46    ///             std::cmp::Ordering::Equal
47    ///         })?;
48    ///     }
49    ///     let value: String = db.query_row(
50    ///         "WITH cte(bar) AS
51    ///        (VALUES ('v1'),('v2'),('v3'),('v4'),('v5'))
52    ///         SELECT DISTINCT bar COLLATE foo FROM cte;",
53    ///         [],
54    ///         |row| row.get(0),
55    ///     )?;
56    ///     assert_eq!(value, "v1");
57    ///     Ok(())
58    /// }
59    /// ```
60    fn create_collation<C, N: Name>(&mut self, collation_name: N, x_compare: C) -> Result<()>
61    where
62        C: Fn(&str, &str) -> Ordering + Send + 'static,
63    {
64        unsafe extern "C" fn call_boxed_closure<C>(
65            arg1: *mut c_void,
66            arg2: c_int,
67            arg3: *const c_void,
68            arg4: c_int,
69            arg5: *const c_void,
70        ) -> c_int
71        where
72            C: Fn(&str, &str) -> Ordering,
73        {
74            let r = catch_unwind(|| {
75                let boxed_f: *mut C = arg1.cast::<C>();
76                assert!(!boxed_f.is_null(), "Internal error - null function pointer");
77                let s1 = {
78                    let c_slice = slice::from_raw_parts(arg3.cast::<u8>(), arg2 as usize);
79                    String::from_utf8_lossy(c_slice)
80                };
81                let s2 = {
82                    let c_slice = slice::from_raw_parts(arg5.cast::<u8>(), arg4 as usize);
83                    String::from_utf8_lossy(c_slice)
84                };
85                (*boxed_f)(s1.as_ref(), s2.as_ref())
86            });
87            let t = match r {
88                Err(_) => {
89                    return -1; // FIXME How ?
90                }
91                Ok(r) => r,
92            };
93
94            match t {
95                Ordering::Less => -1,
96                Ordering::Equal => 0,
97                Ordering::Greater => 1,
98            }
99        }
100
101        let boxed_f: *mut C = Box::into_raw(Box::new(x_compare));
102        let c_name = collation_name.as_cstr()?;
103        let flags = ffi::SQLITE_UTF8;
104        let r = unsafe {
105            ffi::sqlite3_create_collation_v2(
106                self.db(),
107                c_name.as_ptr(),
108                flags,
109                boxed_f.cast::<c_void>(),
110                Some(call_boxed_closure::<C>),
111                Some(free_boxed_value::<C>),
112            )
113        };
114        let res = self.decode_result(r);
115        // The xDestroy callback is not called if the sqlite3_create_collation_v2()
116        // function fails.
117        if res.is_err() {
118            drop(unsafe { Box::from_raw(boxed_f) });
119        }
120        res
121    }
122
123    fn collation_needed(
124        &mut self,
125        x_coll_needed: fn(&Connection, &str) -> Result<()>,
126    ) -> Result<()> {
127        use std::mem;
128        #[expect(clippy::needless_return)]
129        unsafe extern "C" fn collation_needed_callback(
130            arg1: *mut c_void,
131            arg2: *mut ffi::sqlite3,
132            e_text_rep: c_int,
133            arg3: *const c_char,
134        ) {
135            use std::str;
136
137            if e_text_rep != ffi::SQLITE_UTF8 {
138                // TODO: validate
139                return;
140            }
141
142            let callback: fn(&Connection, &str) -> Result<()> = mem::transmute(arg1);
143            let res = catch_unwind(|| {
144                let conn = Connection::from_handle(arg2).unwrap();
145                let collation_name = CStr::from_ptr(arg3)
146                    .to_str()
147                    .expect("illegal collation sequence name");
148                callback(&conn, collation_name)
149            });
150            if res.is_err() {
151                return; // FIXME How ?
152            }
153        }
154
155        let r = unsafe {
156            ffi::sqlite3_collation_needed(
157                self.db(),
158                x_coll_needed as *mut c_void,
159                Some(collation_needed_callback),
160            )
161        };
162        self.decode_result(r)
163    }
164
165    #[inline]
166    fn remove_collation<N: Name>(&mut self, collation_name: N) -> Result<()> {
167        let c_name = collation_name.as_cstr()?;
168        let r = unsafe {
169            ffi::sqlite3_create_collation_v2(
170                self.db(),
171                c_name.as_ptr(),
172                ffi::SQLITE_UTF8,
173                ptr::null_mut(),
174                None,
175                None,
176            )
177        };
178        self.decode_result(r)
179    }
180}
181
#[cfg(test)]
mod test {
    use crate::{Connection, Result};
    use fallible_streaming_iterator::FallibleStreamingIterator;
    use std::cmp::Ordering;
    use unicase::UniCase;

    /// Case-insensitive comparator backing the `unicase` collation in these tests.
    fn unicase_compare(s1: &str, s2: &str) -> Ordering {
        let (left, right) = (UniCase::new(s1), UniCase::new(s2));
        left.cmp(&right)
    }

    #[test]
    fn test_unicase() -> Result<()> {
        let db = Connection::open_in_memory()?;
        db.create_collation(c"unicase", unicase_compare)?;
        collate(db)
    }

    /// Checks that `COLLATE unicase` treats 'Maße' and 'MASSE' as one distinct value.
    fn collate(db: Connection) -> Result<()> {
        db.execute_batch(
            "CREATE TABLE foo (bar);
             INSERT INTO foo (bar) VALUES ('Maße');
             INSERT INTO foo (bar) VALUES ('MASSE');",
        )?;
        let mut stmt = db.prepare("SELECT DISTINCT bar COLLATE unicase FROM foo ORDER BY 1")?;
        let distinct_rows = stmt.query([])?.count()?;
        assert_eq!(distinct_rows, 1);
        Ok(())
    }

    /// Collation-needed hook: lazily registers `unicase` and ignores any other name.
    fn collation_needed(db: &Connection, collation_name: &str) -> Result<()> {
        match collation_name {
            "unicase" => db.create_collation(collation_name, unicase_compare),
            _ => Ok(()),
        }
    }

    #[test]
    fn test_collation_needed() -> Result<()> {
        let db = Connection::open_in_memory()?;
        db.collation_needed(collation_needed)?;
        collate(db)
    }

    #[test]
    fn remove_collation() -> Result<()> {
        let name = c"unicase";
        let db = Connection::open_in_memory()?;
        db.create_collation(name, unicase_compare)?;
        db.remove_collation(name)
    }
}