diff --git a/pgcat.toml b/pgcat.toml index 41d0210a..3e8801b6 100644 --- a/pgcat.toml +++ b/pgcat.toml @@ -63,6 +63,9 @@ tcp_keepalives_interval = 5 # Handle prepared statements. prepared_statements = true +# Prepared statements server cache size. +prepared_statements_cache_size = 500 + # Path to TLS Certificate file to use for TLS connections # tls_certificate = ".circleci/server.cert" # Path to TLS private key file to use for TLS connections diff --git a/src/admin.rs b/src/admin.rs index bbca956f..03b984ae 100644 --- a/src/admin.rs +++ b/src/admin.rs @@ -701,6 +701,7 @@ where ("age_seconds", DataType::Numeric), ("prepare_cache_hit", DataType::Numeric), ("prepare_cache_miss", DataType::Numeric), + ("prepare_cache_size", DataType::Numeric), ]; let new_map = get_server_stats(); @@ -732,6 +733,10 @@ where .prepared_miss_count .load(Ordering::Relaxed) .to_string(), + server + .prepared_cache_size + .load(Ordering::Relaxed) + .to_string(), ]; res.put(data_row(&row)); diff --git a/src/client.rs b/src/client.rs index 608d838d..6c0d06fc 100644 --- a/src/client.rs +++ b/src/client.rs @@ -906,6 +906,19 @@ where return Ok(()); } + // Close (F) + 'C' => { + if prepared_statements_enabled { + let close: Close = (&message).try_into()?; + + if close.is_prepared_statement() && !close.anonymous() { + self.prepared_statements.remove(&close.name); + write_all_flush(&mut self.write, &close_complete()).await?; + continue; + } + } + } + _ => (), } @@ -1130,7 +1143,17 @@ where } else { // The statement is not prepared on the server, so we need to prepare it. if server.should_prepare(&statement.name) { - server.prepare(statement).await?; + match server.prepare(statement).await { + Ok(_) => (), + Err(err) => { + pool.ban( + &address, + BanReason::MessageSendFailed, + Some(&self.stats), + ); + return Err(err); + } + } } } @@ -1251,6 +1274,10 @@ where self.stats.disconnect(); self.release(); + if prepared_statements_enabled { + server.maintain_cache().await?; + } + return Ok(()); } @@ -1300,6 +1327,21 @@ where // Close the prepared statement. 'C' => { + if prepared_statements_enabled { + let close: Close = (&message).try_into()?; + + if close.is_prepared_statement() && !close.anonymous() { + match self.prepared_statements.get(&close.name) { + Some(parse) => { + server.will_close(&parse.generated_name); + } + + // A prepared statement slipped through? Not impossible, since we don't support PREPARE yet. + None => (), + }; + } + } + self.buffer.put(&message[..]); } @@ -1433,7 +1475,13 @@ where // The server is no longer bound to us, we can't cancel it's queries anymore. debug!("Releasing server back into the pool"); + server.checkin_cleanup().await?; + + if prepared_statements_enabled { + server.maintain_cache().await?; + } + server.stats().idle(); self.connected_to_server = false; diff --git a/src/config.rs b/src/config.rs index 66c20758..a2314fc1 100644 --- a/src/config.rs +++ b/src/config.rs @@ -323,6 +323,9 @@ pub struct General { #[serde(default)] pub prepared_statements: bool, + + #[serde(default = "General::default_prepared_statements_cache_size")] + pub prepared_statements_cache_size: usize, } impl General { @@ -400,6 +403,10 @@ impl General { pub fn default_server_round_robin() -> bool { true } + + pub fn default_prepared_statements_cache_size() -> usize { + 500 + } } impl Default for General { @@ -438,6 +445,7 @@ impl Default for General { server_round_robin: false, validate_config: true, prepared_statements: false, + prepared_statements_cache_size: 500, } } } @@ -1020,6 +1028,12 @@ impl Config { self.general.verify_server_certificate ); info!("Prepared statements: {}", self.general.prepared_statements); + if self.general.prepared_statements { + info!( + "Prepared statements server cache size: {}", + self.general.prepared_statements_cache_size + ); + } info!( "Plugins: {}", match self.plugins { @@ -1239,13 +1253,15 @@ pub fn get_config() -> Config { } pub fn get_idle_client_in_transaction_timeout() -> u64 { - (*(*CONFIG.load())) - .general - .idle_client_in_transaction_timeout + CONFIG.load().general.idle_client_in_transaction_timeout } pub fn get_prepared_statements() -> bool { - (*(*CONFIG.load())).general.prepared_statements + CONFIG.load().general.prepared_statements +} + +pub fn get_prepared_statements_cache_size() -> usize { + CONFIG.load().general.prepared_statements_cache_size } /// Parse the configuration file located at the path. diff --git a/src/messages.rs b/src/messages.rs index 552497f1..196abe83 100644 --- a/src/messages.rs +++ b/src/messages.rs @@ -1,7 +1,7 @@ /// Helper functions to send one-off protocol messages /// and handle TcpStream (TCP socket). use bytes::{Buf, BufMut, BytesMut}; -use log::error; +use log::{debug, error}; use md5::{Digest, Md5}; use socket2::{SockRef, TcpKeepalive}; use tokio::io::{AsyncReadExt, AsyncWriteExt}; @@ -976,6 +976,84 @@ impl Describe { } } +/// Close (F) message. +/// See: +#[derive(Clone, Debug)] +pub struct Close { + code: char, + #[allow(dead_code)] + len: i32, + close_type: char, + pub name: String, +} + +impl TryFrom<&BytesMut> for Close { + type Error = Error; + + fn try_from(bytes: &BytesMut) -> Result { + let mut cursor = Cursor::new(bytes); + let code = cursor.get_u8() as char; + let len = cursor.get_i32(); + let close_type = cursor.get_u8() as char; + let name = cursor.read_string()?; + + Ok(Close { + code, + len, + close_type, + name, + }) + } +} + +impl TryFrom for BytesMut { + type Error = Error; + + fn try_from(close: Close) -> Result { + debug!("Close: {:?}", close); + + let mut bytes = BytesMut::new(); + let name_binding = CString::new(close.name)?; + let name = name_binding.as_bytes_with_nul(); + let len = 4 + 1 + name.len(); + + bytes.put_u8(close.code as u8); + bytes.put_i32(len as i32); + bytes.put_u8(close.close_type as u8); + bytes.put_slice(name); + + Ok(bytes) + } +} + +impl Close { + pub fn new(name: &str) -> Close { + let name = name.to_string(); + + Close { + code: 'C', + len: 4 + 1 + name.len() as i32 + 1, // will be recalculated + close_type: 'S', + name, + } + } + + pub fn is_prepared_statement(&self) -> bool { + self.close_type == 'S' + } + + pub fn anonymous(&self) -> bool { + self.name.is_empty() + } +} + +pub fn close_complete() -> BytesMut { + let mut bytes = BytesMut::new(); + bytes.put_u8(b'3'); + bytes.put_i32(4); + bytes +} + pub fn prepared_statement_name() -> String { format!( "P_{}", diff --git a/src/server.rs b/src/server.rs index ab29db09..fa68b678 100644 --- a/src/server.rs +++ b/src/server.rs @@ -15,7 +15,7 @@ use tokio::net::TcpStream; use tokio_rustls::rustls::{OwnedTrustAnchor, RootCertStore}; use tokio_rustls::{client::TlsStream, TlsConnector}; -use crate::config::{get_config, Address, User}; +use crate::config::{get_config, get_prepared_statements_cache_size, Address, User}; use crate::constants::*; use crate::dns_cache::{AddrSet, CACHED_RESOLVER}; use crate::errors::{Error, ServerIdentifier}; @@ -914,12 +914,16 @@ impl Server { Ok(bytes) } + /// Add the prepared statement to being tracked by this server. + /// The client is processing data that will create a prepared statement on this server. pub fn will_prepare(&mut self, name: &str) { debug!("Will prepare `{}`", name); self.prepared_statements.insert(name.to_string()); + self.stats.prepared_cache_add(); } + /// Check if we should prepare a statement on the server. pub fn should_prepare(&self, name: &str) -> bool { let should_prepare = !self.prepared_statements.contains(name); @@ -934,6 +938,7 @@ impl Server { should_prepare } + /// Create a prepared statement on the server. pub async fn prepare(&mut self, parse: &Parse) -> Result<(), Error> { debug!("Preparing `{}`", parse.name); @@ -942,15 +947,82 @@ impl Server { self.send(&flush()).await?; // Read and discard ParseComplete (B) - let _ = read_message(&mut self.stream).await?; + match read_message(&mut self.stream).await { + Ok(_) => (), + Err(err) => { + self.bad = true; + return Err(err); + } + } self.prepared_statements.insert(parse.name.to_string()); + self.stats.prepared_cache_add(); debug!("Prepared `{}`", parse.name); Ok(()) } + /// Maintain adequate cache size on the server. + pub async fn maintain_cache(&mut self) -> Result<(), Error> { + debug!("Cache maintenance run"); + + let max_cache_size = get_prepared_statements_cache_size(); + let mut names = Vec::new(); + + while self.prepared_statements.len() >= max_cache_size { + // The prepared statmeents are alphanumerically sorted by the BTree. + // FIFO. + if let Some(name) = self.prepared_statements.pop_last() { + names.push(name); + } + } + + self.deallocate(names).await?; + + Ok(()) + } + + /// Remove the prepared statement from being tracked by this server. + /// The client is processing data that will cause the server to close the prepared statement. + pub fn will_close(&mut self, name: &str) { + debug!("Will close `{}`", name); + + self.prepared_statements.remove(name); + } + + /// Close a prepared statement on the server. + pub async fn deallocate(&mut self, names: Vec) -> Result<(), Error> { + for name in &names { + debug!("Deallocating prepared statement `{}`", name); + + let close = Close::new(name); + let bytes: BytesMut = close.try_into()?; + + self.send(&bytes).await?; + } + + self.send(&flush()).await?; + + // Read and discard CloseComplete (3) + for name in &names { + match read_message(&mut self.stream).await { + Ok(_) => { + self.prepared_statements.remove(name); + self.stats.prepared_cache_remove(); + debug!("Closed `{}`", name); + } + + Err(err) => { + self.bad = true; + return Err(err); + } + }; + } + + Ok(()) + } + /// If the server is still inside a transaction. /// If the client disconnects while the server is in a transaction, we will clean it up. pub fn in_transaction(&self) -> bool { diff --git a/src/stats/server.rs b/src/stats/server.rs index 6fb2dc97..443c0b6a 100644 --- a/src/stats/server.rs +++ b/src/stats/server.rs @@ -49,6 +49,7 @@ pub struct ServerStats { pub error_count: Arc, pub prepared_hit_count: Arc, pub prepared_miss_count: Arc, + pub prepared_cache_size: Arc, } impl Default for ServerStats { @@ -67,6 +68,7 @@ impl Default for ServerStats { reporter: get_reporter(), prepared_hit_count: Arc::new(AtomicU64::new(0)), prepared_miss_count: Arc::new(AtomicU64::new(0)), + prepared_cache_size: Arc::new(AtomicU64::new(0)), } } } @@ -213,4 +215,12 @@ impl ServerStats { pub fn prepared_cache_miss(&self) { self.prepared_miss_count.fetch_add(1, Ordering::Relaxed); } + + pub fn prepared_cache_add(&self) { + self.prepared_cache_size.fetch_add(1, Ordering::Relaxed); + } + + pub fn prepared_cache_remove(&self) { + self.prepared_cache_size.fetch_sub(1, Ordering::Relaxed); + } } diff --git a/tests/ruby/prepared_spec.rb b/tests/ruby/prepared_spec.rb new file mode 100644 index 00000000..58a30006 --- /dev/null +++ b/tests/ruby/prepared_spec.rb @@ -0,0 +1,29 @@ +require_relative 'spec_helper' + +describe 'Prepared statements' do + let(:processes) { Helpers::Pgcat.three_shard_setup('sharded_db', 5) } + + context 'enabled' do + it 'will work over the same connection' do + conn = PG.connect(processes.pgcat.connection_string('sharded_db', 'sharding_user')) + + 10.times do |i| + statement_name = "statement_#{i}" + conn.prepare(statement_name, 'SELECT $1::int') + conn.exec_prepared(statement_name, [1]) + conn.describe_prepared(statement_name) + end + end + + it 'will work with new connections' do + 10.times do + conn = PG.connect(processes.pgcat.connection_string('sharded_db', 'sharding_user')) + + statement_name = 'statement1' + conn.prepare('statement1', 'SELECT $1::int') + conn.exec_prepared('statement1', [1]) + conn.describe_prepared('statement1') + end + end + end +end