Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Appearance settings

Support for prepared statements #474

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 18 commits into from
Jun 16, 2023
Prev Previous commit
Next Next commit
parse
  • Loading branch information
levkk committed Jun 15, 2023
commit 488df6c459b7aa82c10e6cd8c1011f574b421736
45 changes: 38 additions & 7 deletions 45 src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,9 @@ use crate::constants::*;
use crate::messages::*;
use crate::plugins::PluginOutput;
use crate::pool::{get_pool, ClientServerMap, ConnectionPool};
use crate::query_router::{Command, QueryRouter, PreparedStatement, PreparedStatementName};
use crate::query_router::{
Command, ParseResult, PreparedStatement, PreparedStatementName, QueryRouter,
};
use crate::server::Server;
use crate::stats::{ClientStats, ServerStats};
use crate::tls::Tls;
Expand Down Expand Up @@ -813,7 +815,8 @@ where
'Q' => {
if query_router.query_parser_enabled() {
if let Ok(parse_result) = QueryRouter::parse(&message) {
let plugin_result = query_router.execute_plugins(&parse_result.ast).await;
let plugin_result =
query_router.execute_plugins(&parse_result.ast).await;

match plugin_result {
Ok(PluginOutput::Deny(error)) => {
Expand All @@ -839,10 +842,14 @@ where

if query_router.query_parser_enabled() {
if let Ok(parse_result) = QueryRouter::parse(&message) {
if let Ok(output) = query_router.execute_plugins(&parse_result.ast).await {
if let Ok(output) =
query_router.execute_plugins(&parse_result.ast).await
{
plugin_output = Some(output);
}

self.handle_prepared_statement(&parse_result);

let _ = query_router.infer(&parse_result.ast);
}
}
Expand Down Expand Up @@ -1120,7 +1127,8 @@ where
'Q' => {
if query_router.query_parser_enabled() {
if let Ok(parse_result) = QueryRouter::parse(&message) {
let plugin_result = query_router.execute_plugins(&parse_result.ast).await;
let plugin_result =
query_router.execute_plugins(&parse_result.ast).await;

match plugin_result {
Ok(PluginOutput::Deny(error)) => {
Expand Down Expand Up @@ -1178,21 +1186,33 @@ where
// Parse
// The query with placeholders is here, e.g. `SELECT * FROM users WHERE email = $1 AND active = $2`.
'P' => {
let parse: Parse = (&message).try_into()?;
debug!("Parse: {:?}", parse);
let back: BytesMut = parse.try_into()?;

if query_router.query_parser_enabled() {
if let Ok(parse_result) = QueryRouter::parse(&message) {
if let Ok(output) = query_router.execute_plugins(&parse_result.ast).await {
if let Ok(output) =
query_router.execute_plugins(&parse_result.ast).await
{
plugin_output = Some(output);
}

self.handle_prepared_statement(&parse_result);
}
}

self.buffer.put(&message[..]);
self.buffer.put(&back[..]);
}

// Bind
// The placeholder's replacements are here, e.g. 'user@email.com' and 'true'
'B' => {
self.buffer.put(&message[..]);
let bind: Bind = (&message).try_into()?;
debug!("Bind: {:?}", bind);
let back: BytesMut = bind.try_into()?;

self.buffer.put(&back[..]);
}

// Describe
Expand Down Expand Up @@ -1368,6 +1388,17 @@ where
}
}

fn handle_prepared_statement(&mut self, parse_result: &ParseResult) {
if !self.prepared_statements.contains_key(&parse_result.name) {
debug!(
"Adding prepared statement `{}` to cache",
parse_result.name.0
);
self.prepared_statements
.insert(parse_result.name.clone(), parse_result.statement.clone());
}
}

/// Release the server from the client: it can't cancel its queries anymore.
pub fn release(&self) {
let mut guard = self.client_server_map.lock();
Expand Down
7 changes: 7 additions & 0 deletions 7 src/errors.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ pub enum Error {
AuthPassthroughError(String),
UnsupportedStatement,
QueryRouterParserError(String),
QueryRouterError(String),
}

#[derive(Clone, PartialEq, Debug)]
Expand Down Expand Up @@ -121,3 +122,9 @@ impl std::fmt::Display for Error {
}
}
}

impl From<std::ffi::NulError> for Error {
fn from(err: std::ffi::NulError) -> Self {
Error::QueryRouterError(err.to_string())
}
}
196 changes: 196 additions & 0 deletions 196 src/messages.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ use tokio::net::TcpStream;
use crate::config::get_config;
use crate::errors::Error;
use std::collections::HashMap;
use std::ffi::CString;
use std::io::{BufRead, Cursor};
use std::mem;
use std::time::Duration;
Expand Down Expand Up @@ -689,3 +690,198 @@ impl BytesMutReader for Cursor<&BytesMut> {
}
}
}

/// Parse (F) message.
/// See: <https://www.postgresql.org/docs/current/protocol-message-formats.html>
#[derive(Clone, Debug)]
pub struct Parse {
code: char,
len: i32,
name: String,
query: String,
num_params: i16,
param_types: Vec<i32>,
}

impl TryFrom<&BytesMut> for Parse {
type Error = Error;

fn try_from(buf: &BytesMut) -> Result<Parse, Error> {
let mut cursor = Cursor::new(buf);
let code = cursor.get_u8() as char;
let len = cursor.get_i32();
let name = cursor.read_string()?;
let query = cursor.read_string()?;
let num_params = cursor.get_i16();
let mut param_types = Vec::new();

for _ in 0..num_params {
param_types.push(cursor.get_i32());
}

Ok(Parse {
code,
len,
name,
query,
num_params,
param_types,
})
}
}

impl TryFrom<Parse> for BytesMut {
type Error = Error;

fn try_from(parse: Parse) -> Result<BytesMut, Error> {
let mut bytes = BytesMut::new();

let name_binding = CString::new(parse.name)?;
let name = name_binding.as_bytes_with_nul();

let query_binding = CString::new(parse.query)?;
let query = query_binding.as_bytes_with_nul();

// Recompute length of the message.
let len = 4 // self
+ name.len()
+ query.len()
+ 2
+ 4 * parse.num_params as usize;

bytes.put_u8(parse.code as u8);
bytes.put_i32(len as i32);
bytes.put_slice(name);
bytes.put_slice(query);
bytes.put_i16(parse.num_params);
for param in parse.param_types {
bytes.put_i32(param);
}

Ok(bytes)
}
}

impl Parse {
pub fn rename(&mut self, name: &str) {
self.name = name.to_string();
}
}

/// Bind (B) message.
/// See: <https://www.postgresql.org/docs/current/protocol-message-formats.html>
#[derive(Clone, Debug)]
pub struct Bind {
code: char,
len: i64,
portal: String,
prepared_statement: String,
num_param_format_codes: i16,
param_format_codes: Vec<i16>,
num_param_values: i16,
param_values: Vec<(i32, BytesMut)>,
num_result_column_format_codes: i16,
result_columns_format_codes: Vec<i16>,
}

impl TryFrom<&BytesMut> for Bind {
type Error = Error;

fn try_from(buf: &BytesMut) -> Result<Bind, Error> {
let mut cursor = Cursor::new(buf);
let code = cursor.get_u8() as char;
let len = cursor.get_i32();
let portal = cursor.read_string()?;
let prepared_statement = cursor.read_string()?;
let num_param_format_codes = cursor.get_i16();
let mut param_format_codes = Vec::new();

for _ in 0..num_param_format_codes {
param_format_codes.push(cursor.get_i16());
}

let num_param_values = cursor.get_i16();
let mut param_values = Vec::new();

for _ in 0..num_param_values {
let param_len = cursor.get_i32();
let mut param = BytesMut::with_capacity(param_len as usize);
param.resize(param_len as usize, b'0');
cursor.copy_to_slice(&mut param);
param_values.push((param_len, param));
}

let num_result_column_format_codes = cursor.get_i16();
let mut result_columns_format_codes = Vec::new();

for _ in 0..num_result_column_format_codes {
result_columns_format_codes.push(cursor.get_i16());
}

Ok(Bind {
code,
len: len as i64,
portal,
prepared_statement,
num_param_format_codes,
param_format_codes,
num_param_values,
param_values,
num_result_column_format_codes,
result_columns_format_codes,
})
}
}

impl TryFrom<Bind> for BytesMut {
type Error = Error;

fn try_from(bind: Bind) -> Result<BytesMut, Error> {
let mut bytes = BytesMut::new();

let portal_binding = CString::new(bind.portal)?;
let portal = portal_binding.as_bytes_with_nul();

let prepared_statement_binding = CString::new(bind.prepared_statement)?;
let prepared_statement = prepared_statement_binding.as_bytes_with_nul();

let mut len = 4 // self
+ portal.len()
+ prepared_statement.len()
+ 2 // num_param_format_codes
+ 2 * bind.num_param_format_codes as usize // num_param_format_codes
+ 2; // num_param_values

for (param_len, _) in &bind.param_values {
len += 4 + *param_len as usize;
}
len += 2; // num_result_column_format_codes
len += 2 * bind.num_result_column_format_codes as usize;

bytes.put_u8(bind.code as u8);
bytes.put_i32(len as i32);
bytes.put_slice(portal);
bytes.put_slice(prepared_statement);
bytes.put_i16(bind.num_param_format_codes);
for param_format_code in bind.param_format_codes {
bytes.put_i16(param_format_code);
}
bytes.put_i16(bind.num_param_values);
for (param_len, param) in bind.param_values {
bytes.put_i32(param_len);
bytes.put_slice(&param);
}
bytes.put_i16(bind.num_result_column_format_codes);
for result_column_format_code in bind.result_columns_format_codes {
bytes.put_i16(result_column_format_code);
}

Ok(bytes)
}
}

impl Bind {
pub fn reassign(&mut self, prepared_statement: &str) {
self.prepared_statement = prepared_statement.to_string();
}
}
2 changes: 2 additions & 0 deletions 2 src/query_router.rs
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,9 @@ enum ParameterFormat {
Specified(Vec<ParameterFormat>),
}

#[derive(Clone, PartialEq, Hash, Eq)]
pub struct PreparedStatementName(pub String);
#[derive(Clone)]
pub struct PreparedStatement(pub String);

pub struct ParseResult {
Expand Down
2 changes: 1 addition & 1 deletion 2 src/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,13 @@ use tokio_rustls::rustls::{OwnedTrustAnchor, RootCertStore};
use tokio_rustls::{client::TlsStream, TlsConnector};

use crate::config::{get_config, Address, User};
use crate::query_router::{PreparedStatement, PreparedStatementName};
use crate::constants::*;
use crate::dns_cache::{AddrSet, CACHED_RESOLVER};
use crate::errors::{Error, ServerIdentifier};
use crate::messages::*;
use crate::mirrors::MirroringManager;
use crate::pool::ClientServerMap;
use crate::query_router::{PreparedStatement, PreparedStatementName};
use crate::scram::ScramSha256;
use crate::stats::ServerStats;
use std::io::Write;
Expand Down
Morty Proxy This is a proxified and sanitized view of the page, visit original site.