1use crate::raw_statement::RawStatement;
4use crate::{Connection, PrepFlags, Result, Statement};
5use hashlink::LruCache;
6use std::cell::RefCell;
7use std::ops::{Deref, DerefMut};
8use std::sync::Arc;
9
10impl Connection {
11 #[inline]
38 pub fn prepare_cached(&self, sql: &str) -> Result<CachedStatement<'_>> {
39 self.cache.get(self, sql)
40 }
41
42 #[inline]
48 pub fn set_prepared_statement_cache_capacity(&self, capacity: usize) {
49 self.cache.set_capacity(capacity);
50 }
51
52 #[inline]
54 pub fn flush_prepared_statement_cache(&self) {
55 self.cache.flush();
56 }
57}
58
/// LRU cache of prepared statements owned by a [`Connection`].
///
/// Entries are raw statements keyed by their SQL text (stored as
/// `Arc<str>`). The cache sits behind a `RefCell` so that it can be
/// updated through a shared reference.
#[derive(Debug)]
pub struct StatementCache(RefCell<LruCache<Arc<str>, RawStatement>>);

// NOTE(review): `Send` is asserted manually and no invariant is stated in
// this file. Presumably `RawStatement` is not automatically `Send` and the
// cache is only ever used together with its owning connection; confirm and
// record the justification as a `SAFETY:` comment.
unsafe impl Send for StatementCache {}
64
/// A prepared statement checked out of a [`StatementCache`].
///
/// Dereferences to the wrapped [`Statement`]. When the value is dropped the
/// statement is handed back to the cache it came from; call
/// [`discard()`](CachedStatement::discard) to drop it without caching it.
pub struct CachedStatement<'conn> {
    // `Some` for the whole usable lifetime of the value; set to `None` by
    // `discard` and taken in `Drop`.
    stmt: Option<Statement<'conn>>,
    // Cache that receives the statement back on drop.
    cache: &'conn StatementCache,
}
74
75impl<'conn> Deref for CachedStatement<'conn> {
76 type Target = Statement<'conn>;
77
78 #[inline]
79 fn deref(&self) -> &Statement<'conn> {
80 self.stmt.as_ref().unwrap()
81 }
82}
83
84impl<'conn> DerefMut for CachedStatement<'conn> {
85 #[inline]
86 fn deref_mut(&mut self) -> &mut Statement<'conn> {
87 self.stmt.as_mut().unwrap()
88 }
89}
90
91impl Drop for CachedStatement<'_> {
92 #[inline]
93 fn drop(&mut self) {
94 if let Some(stmt) = self.stmt.take() {
95 self.cache.cache_stmt(unsafe { stmt.into_raw() });
96 }
97 }
98}
99
100impl CachedStatement<'_> {
101 #[inline]
102 fn new<'conn>(stmt: Statement<'conn>, cache: &'conn StatementCache) -> CachedStatement<'conn> {
103 CachedStatement {
104 stmt: Some(stmt),
105 cache,
106 }
107 }
108
109 #[inline]
112 pub fn discard(mut self) {
113 self.stmt = None;
114 }
115}
116
117impl StatementCache {
118 #[inline]
120 pub fn with_capacity(capacity: usize) -> Self {
121 Self(RefCell::new(LruCache::new(capacity)))
122 }
123
124 #[inline]
125 fn set_capacity(&self, capacity: usize) {
126 self.0.borrow_mut().set_capacity(capacity);
127 }
128
129 fn get<'conn>(
137 &'conn self,
138 conn: &'conn Connection,
139 sql: &str,
140 ) -> Result<CachedStatement<'conn>> {
141 let trimmed = sql.trim();
142 let mut cache = self.0.borrow_mut();
143 let stmt = match cache.remove(trimmed) {
144 Some(raw_stmt) => Ok(Statement::new(conn, raw_stmt)),
145 None => conn.prepare_with_flags(trimmed, PrepFlags::SQLITE_PREPARE_PERSISTENT),
146 };
147 stmt.map(|mut stmt| {
148 stmt.stmt.set_statement_cache_key(trimmed);
149 CachedStatement::new(stmt, self)
150 })
151 }
152
153 fn cache_stmt(&self, mut stmt: RawStatement) {
155 if stmt.is_null() {
156 return;
157 }
158 let mut cache = self.0.borrow_mut();
159 stmt.clear_bindings();
160 if let Some(sql) = stmt.statement_cache_key() {
161 cache.insert(sql, stmt);
162 } else {
163 debug_assert!(
164 false,
165 "bug in statement cache code, statement returned to cache that without key"
166 );
167 }
168 }
169
170 #[inline]
171 fn flush(&self) {
172 let mut cache = self.0.borrow_mut();
173 cache.clear();
174 }
175}
176
#[cfg(test)]
mod test {
    use super::StatementCache;
    use crate::{Connection, Result};
    use fallible_iterator::FallibleIterator;

    // Test-only accessors for inspecting the cache state.
    impl StatementCache {
        fn clear(&self) {
            let mut entries = self.0.borrow_mut();
            entries.clear();
        }

        fn len(&self) -> usize {
            let entries = self.0.borrow();
            entries.len()
        }

        fn capacity(&self) -> usize {
            let entries = self.0.borrow();
            entries.capacity()
        }
    }

    // Checks `sql` out of the cache, verifies that the cache is empty while
    // the statement is held and that the query yields 0, then drops the
    // statement (which hands it back to the cache).
    fn checkout_and_query(db: &Connection, sql: &str) -> Result<()> {
        let mut stmt = db.prepare_cached(sql)?;
        assert_eq!(0, db.cache.len());
        assert_eq!(0, stmt.query_row([], |row| row.get::<_, i64>(0))?);
        Ok(())
    }

    #[test]
    fn test_cache() -> Result<()> {
        let db = Connection::open_in_memory()?;
        let cache = &db.cache;
        let initial_capacity = cache.capacity();
        assert_eq!(0, cache.len());
        assert!(initial_capacity > 0);

        let sql = "PRAGMA schema_version";
        // First pass prepares the statement, second pass reuses it; in both
        // cases exactly one entry is cached afterwards.
        for _ in 0..2 {
            checkout_and_query(&db, sql)?;
            assert_eq!(1, cache.len());
        }

        cache.clear();
        assert_eq!(0, cache.len());
        assert_eq!(initial_capacity, cache.capacity());
        Ok(())
    }

    #[test]
    fn test_set_capacity() -> Result<()> {
        let db = Connection::open_in_memory()?;
        let cache = &db.cache;
        let sql = "PRAGMA schema_version";

        checkout_and_query(&db, sql)?;
        assert_eq!(1, cache.len());

        // Shrinking to zero evicts the cached entry ...
        db.set_prepared_statement_cache_capacity(0);
        assert_eq!(0, cache.len());

        // ... and nothing is retained while the capacity stays at zero.
        checkout_and_query(&db, sql)?;
        assert_eq!(0, cache.len());

        db.set_prepared_statement_cache_capacity(8);
        checkout_and_query(&db, sql)?;
        assert_eq!(1, cache.len());
        Ok(())
    }

    #[test]
    fn test_discard() -> Result<()> {
        let db = Connection::open_in_memory()?;
        let cache = &db.cache;
        let sql = "PRAGMA schema_version";

        {
            let mut stmt = db.prepare_cached(sql)?;
            assert_eq!(0, cache.len());
            assert_eq!(0, stmt.query_row([], |row| row.get::<_, i64>(0))?);
            stmt.discard();
        }
        // A discarded statement must not come back to the cache.
        assert_eq!(0, cache.len());
        Ok(())
    }

    #[test]
    fn test_ddl() -> Result<()> {
        let db = Connection::open_in_memory()?;
        db.execute_batch(
            r"
            CREATE TABLE foo (x INT);
            INSERT INTO foo VALUES (1);
        ",
        )?;

        let sql = "SELECT * FROM foo";

        {
            let mut stmt = db.prepare_cached(sql)?;
            assert_eq!(
                Ok(Some(1i32)),
                stmt.query([])?.map(|row| row.get(0)).next()
            );
        }

        db.execute_batch(
            r"
            ALTER TABLE foo ADD COLUMN y INT;
            UPDATE foo SET y = 2;
        ",
        )?;

        // The cached statement must see the column added after it was first
        // prepared.
        {
            let mut stmt = db.prepare_cached(sql)?;
            assert_eq!(
                Ok(Some((1i32, 2i32))),
                stmt.query([])?
                    .map(|row| Ok((row.get(0)?, row.get(1)?)))
                    .next()
            );
        }
        Ok(())
    }

    #[test]
    fn test_connection_close() -> Result<()> {
        let conn = Connection::open_in_memory()?;
        // Dropping the statement right away leaves it in the cache.
        drop(conn.prepare_cached("SELECT * FROM sqlite_master;")?);

        conn.close().expect("connection not closed");
        Ok(())
    }

    #[test]
    fn test_cache_key() -> Result<()> {
        let db = Connection::open_in_memory()?;
        let cache = &db.cache;
        assert_eq!(0, cache.len());

        // Trailing whitespace in the SQL text must not stop the second
        // lookup from reusing the cached entry.
        let sql = "PRAGMA schema_version; ";
        for _ in 0..2 {
            checkout_and_query(&db, sql)?;
            assert_eq!(1, cache.len());
        }
        Ok(())
    }

    #[test]
    fn test_empty_stmt() -> Result<()> {
        let conn = Connection::open_in_memory()?;
        conn.prepare_cached("")?;
        Ok(())
    }
}