1use std::cmp::Ordering;
3use std::ffi::{c_char, c_int, c_void, CStr};
4use std::panic::catch_unwind;
5use std::ptr;
6use std::slice;
7
8use crate::ffi;
9use crate::util::free_boxed_value;
10use crate::{Connection, InnerConnection, Name, Result};
11
12impl Connection {
13 #[inline]
15 pub fn create_collation<C, N: Name>(&self, collation_name: N, x_compare: C) -> Result<()>
16 where
17 C: Fn(&str, &str) -> Ordering + Send + 'static,
18 {
19 self.db
20 .borrow_mut()
21 .create_collation(collation_name, x_compare)
22 }
23
24 #[inline]
26 pub fn collation_needed(&self, x_coll_needed: fn(&Self, &str) -> Result<()>) -> Result<()> {
27 self.db.borrow_mut().collation_needed(x_coll_needed)
28 }
29
30 #[inline]
32 pub fn remove_collation<N: Name>(&self, collation_name: N) -> Result<()> {
33 self.db.borrow_mut().remove_collation(collation_name)
34 }
35}
36
37impl InnerConnection {
38 fn create_collation<C, N: Name>(&mut self, collation_name: N, x_compare: C) -> Result<()>
61 where
62 C: Fn(&str, &str) -> Ordering + Send + 'static,
63 {
64 unsafe extern "C" fn call_boxed_closure<C>(
65 arg1: *mut c_void,
66 arg2: c_int,
67 arg3: *const c_void,
68 arg4: c_int,
69 arg5: *const c_void,
70 ) -> c_int
71 where
72 C: Fn(&str, &str) -> Ordering,
73 {
74 let r = catch_unwind(|| {
75 let boxed_f: *mut C = arg1.cast::<C>();
76 assert!(!boxed_f.is_null(), "Internal error - null function pointer");
77 let s1 = {
78 let c_slice = slice::from_raw_parts(arg3.cast::<u8>(), arg2 as usize);
79 String::from_utf8_lossy(c_slice)
80 };
81 let s2 = {
82 let c_slice = slice::from_raw_parts(arg5.cast::<u8>(), arg4 as usize);
83 String::from_utf8_lossy(c_slice)
84 };
85 (*boxed_f)(s1.as_ref(), s2.as_ref())
86 });
87 let t = match r {
88 Err(_) => {
89 return -1; }
91 Ok(r) => r,
92 };
93
94 match t {
95 Ordering::Less => -1,
96 Ordering::Equal => 0,
97 Ordering::Greater => 1,
98 }
99 }
100
101 let boxed_f: *mut C = Box::into_raw(Box::new(x_compare));
102 let c_name = collation_name.as_cstr()?;
103 let flags = ffi::SQLITE_UTF8;
104 let r = unsafe {
105 ffi::sqlite3_create_collation_v2(
106 self.db(),
107 c_name.as_ptr(),
108 flags,
109 boxed_f.cast::<c_void>(),
110 Some(call_boxed_closure::<C>),
111 Some(free_boxed_value::<C>),
112 )
113 };
114 let res = self.decode_result(r);
115 if res.is_err() {
118 drop(unsafe { Box::from_raw(boxed_f) });
119 }
120 res
121 }
122
123 fn collation_needed(
124 &mut self,
125 x_coll_needed: fn(&Connection, &str) -> Result<()>,
126 ) -> Result<()> {
127 use std::mem;
128 #[expect(clippy::needless_return)]
129 unsafe extern "C" fn collation_needed_callback(
130 arg1: *mut c_void,
131 arg2: *mut ffi::sqlite3,
132 e_text_rep: c_int,
133 arg3: *const c_char,
134 ) {
135 use std::str;
136
137 if e_text_rep != ffi::SQLITE_UTF8 {
138 return;
140 }
141
142 let callback: fn(&Connection, &str) -> Result<()> = mem::transmute(arg1);
143 let res = catch_unwind(|| {
144 let conn = Connection::from_handle(arg2).unwrap();
145 let collation_name = CStr::from_ptr(arg3)
146 .to_str()
147 .expect("illegal collation sequence name");
148 callback(&conn, collation_name)
149 });
150 if res.is_err() {
151 return; }
153 }
154
155 let r = unsafe {
156 ffi::sqlite3_collation_needed(
157 self.db(),
158 x_coll_needed as *mut c_void,
159 Some(collation_needed_callback),
160 )
161 };
162 self.decode_result(r)
163 }
164
165 #[inline]
166 fn remove_collation<N: Name>(&mut self, collation_name: N) -> Result<()> {
167 let c_name = collation_name.as_cstr()?;
168 let r = unsafe {
169 ffi::sqlite3_create_collation_v2(
170 self.db(),
171 c_name.as_ptr(),
172 ffi::SQLITE_UTF8,
173 ptr::null_mut(),
174 None,
175 None,
176 )
177 };
178 self.decode_result(r)
179 }
180}
181
#[cfg(test)]
mod test {
    use crate::{Connection, Result};
    use fallible_streaming_iterator::FallibleStreamingIterator;
    use std::cmp::Ordering;
    use unicase::UniCase;

    /// Case-insensitive comparison (via `UniCase`) used as the test collation.
    fn unicase_compare(s1: &str, s2: &str) -> Ordering {
        UniCase::new(s1).cmp(&UniCase::new(s2))
    }

    #[test]
    fn test_unicase() -> Result<()> {
        let db = Connection::open_in_memory()?;
        db.create_collation(c"unicase", unicase_compare)?;
        collate(db)
    }

    /// Inserts two spellings that `unicase_compare` treats as equal and checks
    /// that `DISTINCT ... COLLATE unicase` collapses them into a single row.
    fn collate(db: Connection) -> Result<()> {
        db.execute_batch(
            "CREATE TABLE foo (bar);
             INSERT INTO foo (bar) VALUES ('Maße');
             INSERT INTO foo (bar) VALUES ('MASSE');",
        )?;
        let mut stmt = db.prepare("SELECT DISTINCT bar COLLATE unicase FROM foo ORDER BY 1")?;
        let rows = stmt.query([])?;
        assert_eq!(rows.count()?, 1);
        Ok(())
    }

    /// Collation-needed callback that lazily supplies the `unicase` collation.
    fn collation_needed(db: &Connection, collation_name: &str) -> Result<()> {
        if "unicase" == collation_name {
            db.create_collation(collation_name, unicase_compare)
        } else {
            Ok(())
        }
    }

    #[test]
    fn test_collation_needed() -> Result<()> {
        let db = Connection::open_in_memory()?;
        db.collation_needed(collation_needed)?;
        collate(db)
    }

    #[test]
    fn create_collation_with_nul_in_name() -> Result<()> {
        let db = Connection::open_in_memory()?;
        // An interior NUL cannot be converted to a C string, so registration
        // must fail cleanly (and must not leak the boxed comparator).
        assert!(db.create_collation("uni\0case", unicase_compare).is_err());
        Ok(())
    }

    #[test]
    fn remove_collation() -> Result<()> {
        let db = Connection::open_in_memory()?;
        db.create_collation(c"unicase", unicase_compare)?;
        db.remove_collation(c"unicase")?;
        // Once removed, a comparison that needs the collation must not compile.
        assert!(db.prepare("SELECT 'a' = 'A' COLLATE unicase").is_err());
        Ok(())
    }
}