diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 8d17f4d6b..1ca030d26 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -3,9 +3,11 @@ name: CI on: pull_request: branches: + - neon - master push: branches: + - neon - master env: @@ -55,7 +57,7 @@ jobs: - run: docker compose up -d - uses: sfackler/actions/rustup@master with: - version: 1.62.0 + version: 1.65.0 - run: echo "::set-output name=version::$(rustc --version)" id: rust-version - uses: actions/cache@v1 diff --git a/docker/sql_setup.sh b/docker/sql_setup.sh index 0315ac805..051a12000 100755 --- a/docker/sql_setup.sh +++ b/docker/sql_setup.sh @@ -64,6 +64,7 @@ port = 5433 ssl = on ssl_cert_file = 'server.crt' ssl_key_file = 'server.key' +wal_level = logical EOCONF cat > "$PGDATA/pg_hba.conf" <<-EOCONF @@ -82,6 +83,7 @@ host all ssl_user ::0/0 reject # IPv4 local connections: host all postgres 0.0.0.0/0 trust +host replication postgres 0.0.0.0/0 trust # IPv6 local connections: host all postgres ::0/0 trust # Unix socket connections: diff --git a/postgres-derive-test/src/lib.rs b/postgres-derive-test/src/lib.rs index d1478ac4c..f0534f32c 100644 --- a/postgres-derive-test/src/lib.rs +++ b/postgres-derive-test/src/lib.rs @@ -14,7 +14,7 @@ where T: PartialEq + FromSqlOwned + ToSql + Sync, S: fmt::Display, { - for &(ref val, ref repr) in checks.iter() { + for (val, repr) in checks.iter() { let stmt = conn .prepare(&format!("SELECT {}::{}", *repr, sql_type)) .unwrap(); @@ -38,7 +38,7 @@ pub fn test_type_asymmetric<T, F, S, C>( S: fmt::Display, C: Fn(&T, &F) -> bool, { - for &(ref val, ref repr) in checks.iter() { + for (val, repr) in checks.iter() { let stmt = conn .prepare(&format!("SELECT {}::{}", *repr, sql_type)) .unwrap(); diff --git a/postgres-protocol/Cargo.toml b/postgres-protocol/Cargo.toml index 2a72cc60c..2fc5ecc28 100644 --- a/postgres-protocol/Cargo.toml +++ b/postgres-protocol/Cargo.toml @@ -12,8 +12,9 @@ readme = "../README.md" base64 = "0.20" byteorder = "1.0" bytes = "1.0" -fallible-iterator = "0.2" +fallible-iterator = "0.3" hmac = "0.12" +lazy_static = "1.4" md-5 = "0.10" memchr = "2.0" rand = "0.8" diff --git a/postgres-protocol/src/authentication/sasl.rs b/postgres-protocol/src/authentication/sasl.rs index ea2f55cad..41d0e41b0 100644 --- a/postgres-protocol/src/authentication/sasl.rs +++ b/postgres-protocol/src/authentication/sasl.rs @@ -96,14 +96,32 @@ impl ChannelBinding { } } +/// A pair of keys for the SCRAM-SHA-256 mechanism. +/// See <https://datatracker.ietf.org/doc/html/rfc5802#section-3> for details. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub struct ScramKeys<const N: usize> { + /// Used by server to authenticate client. + pub client_key: [u8; N], + /// Used by client to verify server's signature. + pub server_key: [u8; N], +} + +/// Password or keys which were derived from it. +enum Credentials<const N: usize> { + /// A regular password as a vector of bytes. + Password(Vec<u8>), + /// A precomputed pair of keys. 
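+    /// The same pair that `update` below derives from a password via
+    /// `HMAC(salted_password, "Client Key" / "Server Key")` (RFC 5802).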
+ Keys(Box<ScramKeys<N>>), +} + enum State { Update { nonce: String, - password: Vec<u8>, + password: Credentials<32>, channel_binding: ChannelBinding, }, Finish { - salted_password: [u8; 32], + server_key: [u8; 32], auth_message: String, }, Done, @@ -129,30 +147,43 @@ pub struct ScramSha256 { state: State, } +fn nonce() -> String { + // rand 0.5's ThreadRng is cryptographically secure + let mut rng = rand::thread_rng(); + (0..NONCE_LENGTH) + .map(|_| { + let mut v = rng.gen_range(0x21u8..0x7e); + if v == 0x2c { + v = 0x7e + } + v as char + }) + .collect() +} + impl ScramSha256 { /// Constructs a new instance which will use the provided password for authentication. pub fn new(password: &[u8], channel_binding: ChannelBinding) -> ScramSha256 { - // rand 0.5's ThreadRng is cryptographically secure - let mut rng = rand::thread_rng(); - let nonce = (0..NONCE_LENGTH) - .map(|_| { - let mut v = rng.gen_range(0x21u8..0x7e); - if v == 0x2c { - v = 0x7e - } - v as char - }) - .collect::<String>(); + let password = Credentials::Password(normalize(password)); + ScramSha256::new_inner(password, channel_binding, nonce()) + } - ScramSha256::new_inner(password, channel_binding, nonce) + /// Constructs a new instance which will use the provided key pair for authentication. + pub fn new_with_keys(keys: ScramKeys<32>, channel_binding: ChannelBinding) -> ScramSha256 { + let password = Credentials::Keys(keys.into()); + ScramSha256::new_inner(password, channel_binding, nonce()) } - fn new_inner(password: &[u8], channel_binding: ChannelBinding, nonce: String) -> ScramSha256 { + fn new_inner( + password: Credentials<32>, + channel_binding: ChannelBinding, + nonce: String, + ) -> ScramSha256 { ScramSha256 { message: format!("{}n=,r={}", channel_binding.gs2_header(), nonce), state: State::Update { nonce, - password: normalize(password), + password, channel_binding, }, } @@ -189,20 +220,32 @@ impl ScramSha256 { return Err(io::Error::new(io::ErrorKind::InvalidInput, "invalid nonce")); } - let salt = match base64::decode(parsed.salt) { - Ok(salt) => salt, - Err(e) => return Err(io::Error::new(io::ErrorKind::InvalidInput, e)), - }; + let (client_key, server_key) = match password { + Credentials::Password(password) => { + let salt = match base64::decode(parsed.salt) { + Ok(salt) => salt, + Err(e) => return Err(io::Error::new(io::ErrorKind::InvalidInput, e)), + }; - let salted_password = hi(&password, &salt, parsed.iteration_count); + let salted_password = hi(&password, &salt, parsed.iteration_count); - let mut hmac = Hmac::<Sha256>::new_from_slice(&salted_password) - .expect("HMAC is able to accept all key sizes"); - hmac.update(b"Client Key"); - let client_key = hmac.finalize().into_bytes(); + let make_key = |name| { + let mut hmac = Hmac::<Sha256>::new_from_slice(&salted_password) + .expect("HMAC is able to accept all key sizes"); + hmac.update(name); + + let mut key = [0u8; 32]; + key.copy_from_slice(hmac.finalize().into_bytes().as_slice()); + key + }; + + (make_key(b"Client Key"), make_key(b"Server Key")) + } + Credentials::Keys(keys) => (keys.client_key, keys.server_key), + }; let mut hash = Sha256::default(); - hash.update(client_key.as_slice()); + hash.update(client_key); let stored_key = hash.finalize_fixed(); let mut cbind_input = vec![]; @@ -225,10 +268,10 @@ impl ScramSha256 { *proof ^= signature; } - write!(&mut self.message, ",p={}", base64::encode(&*client_proof)).unwrap(); + write!(&mut self.message, ",p={}", base64::encode(client_proof)).unwrap(); self.state = State::Finish { - salted_password, + 
server_key, auth_message, }; Ok(()) @@ -239,11 +282,11 @@ impl ScramSha256 { /// This should be called when the backend sends an `AuthenticationSASLFinal` message. /// Authentication has only succeeded if this method returns `Ok(())`. pub fn finish(&mut self, message: &[u8]) -> io::Result<()> { - let (salted_password, auth_message) = match mem::replace(&mut self.state, State::Done) { + let (server_key, auth_message) = match mem::replace(&mut self.state, State::Done) { State::Finish { - salted_password, + server_key, auth_message, - } => (salted_password, auth_message), + } => (server_key, auth_message), _ => return Err(io::Error::new(io::ErrorKind::Other, "invalid SCRAM state")), }; @@ -267,11 +310,6 @@ impl ScramSha256 { Err(e) => return Err(io::Error::new(io::ErrorKind::InvalidInput, e)), }; - let mut hmac = Hmac::<Sha256>::new_from_slice(&salted_password) - .expect("HMAC is able to accept all key sizes"); - hmac.update(b"Server Key"); - let server_key = hmac.finalize().into_bytes(); - let mut hmac = Hmac::<Sha256>::new_from_slice(&server_key) .expect("HMAC is able to accept all key sizes"); hmac.update(auth_message.as_bytes()); @@ -351,7 +389,7 @@ impl<'a> Parser<'a> { } fn posit_number(&mut self) -> io::Result<u32> { - let n = self.take_while(|c| matches!(c, '0'..='9'))?; + let n = self.take_while(|c| c.is_ascii_digit())?; n.parse() .map_err(|e| io::Error::new(io::ErrorKind::InvalidInput, e)) } @@ -458,7 +496,7 @@ mod test { let server_final = "v=U+ppxD5XUKtradnv8e2MkeupiA8FU87Sg8CXzXHDAzw="; let mut scram = ScramSha256::new_inner( - password.as_bytes(), + Credentials::Password(normalize(password.as_bytes())), ChannelBinding::unsupported(), nonce.to_string(), ); diff --git a/postgres-protocol/src/lib.rs b/postgres-protocol/src/lib.rs index 8b6ff508d..1f7aa7923 100644 --- a/postgres-protocol/src/lib.rs +++ b/postgres-protocol/src/lib.rs @@ -14,7 +14,9 @@ use byteorder::{BigEndian, ByteOrder}; use bytes::{BufMut, BytesMut}; +use lazy_static::lazy_static; use std::io; +use std::time::{Duration, SystemTime, UNIX_EPOCH}; pub mod authentication; pub mod escape; @@ -28,6 +30,11 @@ pub type Oid = u32; /// A Postgres Log Sequence Number (LSN). pub type Lsn = u64; +lazy_static! { + /// Postgres epoch is 2000-01-01T00:00:00Z + pub static ref PG_EPOCH: SystemTime = UNIX_EPOCH + Duration::from_secs(946_684_800); +} + /// An enum indicating if a value is `NULL` or not. pub enum IsNull { /// The value is `NULL`. 
diff --git a/postgres-protocol/src/message/backend.rs b/postgres-protocol/src/message/backend.rs index 3f5374d64..b6883cc3c 100644 --- a/postgres-protocol/src/message/backend.rs +++ b/postgres-protocol/src/message/backend.rs @@ -8,9 +8,11 @@ use std::cmp; use std::io::{self, Read}; use std::ops::Range; use std::str; +use std::time::{Duration, SystemTime}; -use crate::Oid; +use crate::{Lsn, Oid, PG_EPOCH}; +// top-level message tags pub const PARSE_COMPLETE_TAG: u8 = b'1'; pub const BIND_COMPLETE_TAG: u8 = b'2'; pub const CLOSE_COMPLETE_TAG: u8 = b'3'; @@ -22,6 +24,7 @@ pub const DATA_ROW_TAG: u8 = b'D'; pub const ERROR_RESPONSE_TAG: u8 = b'E'; pub const COPY_IN_RESPONSE_TAG: u8 = b'G'; pub const COPY_OUT_RESPONSE_TAG: u8 = b'H'; +pub const COPY_BOTH_RESPONSE_TAG: u8 = b'W'; pub const EMPTY_QUERY_RESPONSE_TAG: u8 = b'I'; pub const BACKEND_KEY_DATA_TAG: u8 = b'K'; pub const NO_DATA_TAG: u8 = b'n'; @@ -33,6 +36,33 @@ pub const PARAMETER_DESCRIPTION_TAG: u8 = b't'; pub const ROW_DESCRIPTION_TAG: u8 = b'T'; pub const READY_FOR_QUERY_TAG: u8 = b'Z'; +// replication message tags +pub const XLOG_DATA_TAG: u8 = b'w'; +pub const PRIMARY_KEEPALIVE_TAG: u8 = b'k'; + +// logical replication message tags +const BEGIN_TAG: u8 = b'B'; +const COMMIT_TAG: u8 = b'C'; +const ORIGIN_TAG: u8 = b'O'; +const RELATION_TAG: u8 = b'R'; +const TYPE_TAG: u8 = b'Y'; +const INSERT_TAG: u8 = b'I'; +const UPDATE_TAG: u8 = b'U'; +const DELETE_TAG: u8 = b'D'; +const TRUNCATE_TAG: u8 = b'T'; +const TUPLE_NEW_TAG: u8 = b'N'; +const TUPLE_KEY_TAG: u8 = b'K'; +const TUPLE_OLD_TAG: u8 = b'O'; +const TUPLE_DATA_NULL_TAG: u8 = b'n'; +const TUPLE_DATA_TOAST_TAG: u8 = b'u'; +const TUPLE_DATA_TEXT_TAG: u8 = b't'; + +// replica identity tags +const REPLICA_IDENTITY_DEFAULT_TAG: u8 = b'd'; +const REPLICA_IDENTITY_NOTHING_TAG: u8 = b'n'; +const REPLICA_IDENTITY_FULL_TAG: u8 = b'f'; +const REPLICA_IDENTITY_INDEX_TAG: u8 = b'i'; + #[derive(Debug, Copy, Clone)] pub struct Header { tag: u8, @@ -93,6 +123,7 @@ pub enum Message { CopyDone, CopyInResponse(CopyInResponseBody), CopyOutResponse(CopyOutResponseBody), + CopyBothResponse(CopyBothResponseBody), DataRow(DataRowBody), EmptyQueryResponse, ErrorResponse(ErrorResponseBody), @@ -190,6 +221,16 @@ impl Message { storage, }) } + COPY_BOTH_RESPONSE_TAG => { + let format = buf.read_u8()?; + let len = buf.read_u16::<BigEndian>()?; + let storage = buf.read_all(); + Message::CopyBothResponse(CopyBothResponseBody { + format, + len, + storage, + }) + } EMPTY_QUERY_RESPONSE_TAG => Message::EmptyQueryResponse, BACKEND_KEY_DATA_TAG => { let process_id = buf.read_i32::<BigEndian>()?; @@ -278,6 +319,69 @@ impl Message { } } +/// An enum representing Postgres backend replication messages. 
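+///
+/// # Example
+///
+/// A minimal sketch (not from this crate's docs) of dispatching on a parsed
+/// message; `payload` is assumed to be the payload of a `CopyData` frame
+/// received while streaming replication:
+///
+/// ```ignore
+/// use bytes::Bytes;
+/// use postgres_protocol::message::backend::ReplicationMessage;
+///
+/// fn handle(payload: &Bytes) -> std::io::Result<()> {
+///     match ReplicationMessage::parse(payload)? {
+///         ReplicationMessage::XLogData(body) => {
+///             println!("WAL [{:X}..{:X}]", body.wal_start(), body.wal_end());
+///         }
+///         ReplicationMessage::PrimaryKeepAlive(body) => {
+///             if body.reply() == 1 {
+///                 // the server asked for a standby status update
+///             }
+///         }
+///         _ => {} // the enum is non-exhaustive
+///     }
+///     Ok(())
+/// }
+/// ```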
+#[non_exhaustive]
+#[derive(Debug)]
+pub enum ReplicationMessage<D> {
+    XLogData(XLogDataBody<D>),
+    PrimaryKeepAlive(PrimaryKeepAliveBody),
+}
+
+impl ReplicationMessage<Bytes> {
+    #[inline]
+    pub fn parse(buf: &Bytes) -> io::Result<Self> {
+        let mut buf = Buffer {
+            bytes: buf.clone(),
+            idx: 0,
+        };
+
+        let tag = buf.read_u8()?;
+
+        let replication_message = match tag {
+            XLOG_DATA_TAG => {
+                let wal_start = buf.read_u64::<BigEndian>()?;
+                let wal_end = buf.read_u64::<BigEndian>()?;
+                let ts = buf.read_i64::<BigEndian>()?;
+                let timestamp = if ts > 0 {
+                    *PG_EPOCH + Duration::from_micros(ts as u64)
+                } else {
+                    *PG_EPOCH - Duration::from_micros(-ts as u64)
+                };
+                let data = buf.read_all();
+                ReplicationMessage::XLogData(XLogDataBody {
+                    wal_start,
+                    wal_end,
+                    timestamp,
+                    data,
+                })
+            }
+            PRIMARY_KEEPALIVE_TAG => {
+                let wal_end = buf.read_u64::<BigEndian>()?;
+                let ts = buf.read_i64::<BigEndian>()?;
+                let timestamp = if ts > 0 {
+                    *PG_EPOCH + Duration::from_micros(ts as u64)
+                } else {
+                    *PG_EPOCH - Duration::from_micros(-ts as u64)
+                };
+                let reply = buf.read_u8()?;
+                ReplicationMessage::PrimaryKeepAlive(PrimaryKeepAliveBody {
+                    wal_end,
+                    timestamp,
+                    reply,
+                })
+            }
+            tag => {
+                return Err(io::Error::new(
+                    io::ErrorKind::InvalidInput,
+                    format!("unknown replication message tag `{}`", tag),
+                ));
+            }
+        };
+
+        Ok(replication_message)
+    }
+}
+
 struct Buffer {
     bytes: Bytes,
     idx: usize,
 }
@@ -524,6 +628,27 @@ impl CopyOutResponseBody {
     }
 }
 
+pub struct CopyBothResponseBody {
+    storage: Bytes,
+    len: u16,
+    format: u8,
+}
+
+impl CopyBothResponseBody {
+    #[inline]
+    pub fn format(&self) -> u8 {
+        self.format
+    }
+
+    #[inline]
+    pub fn column_formats(&self) -> ColumnFormats<'_> {
+        ColumnFormats {
+            remaining: self.len,
+            buf: &self.storage,
+        }
+    }
+}
+
 #[derive(Debug)]
 pub struct DataRowBody {
     storage: Bytes,
@@ -582,7 +707,7 @@ impl<'a> FallibleIterator for DataRowRanges<'a> {
             ));
         }
         let base = self.len - self.buf.len();
-        self.buf = &self.buf[len as usize..];
+        self.buf = &self.buf[len..];
         Ok(Some(Some(base..base + len)))
     }
 }
@@ -777,6 +902,655 @@ impl RowDescriptionBody {
     }
 }
 
+#[derive(Debug)]
+pub struct XLogDataBody<D> {
+    wal_start: u64,
+    wal_end: u64,
+    timestamp: SystemTime,
+    data: D,
+}
+
+impl<D> XLogDataBody<D> {
+    #[inline]
+    pub fn wal_start(&self) -> u64 {
+        self.wal_start
+    }
+
+    #[inline]
+    pub fn wal_end(&self) -> u64 {
+        self.wal_end
+    }
+
+    #[inline]
+    pub fn timestamp(&self) -> SystemTime {
+        self.timestamp
+    }
+
+    #[inline]
+    pub fn data(&self) -> &D {
+        &self.data
+    }
+
+    #[inline]
+    pub fn into_data(self) -> D {
+        self.data
+    }
+
+    pub fn map_data<F, D2, E>(self, f: F) -> Result<XLogDataBody<D2>, E>
+    where
+        F: Fn(D) -> Result<D2, E>,
+    {
+        let data = f(self.data)?;
+        Ok(XLogDataBody {
+            wal_start: self.wal_start,
+            wal_end: self.wal_end,
+            timestamp: self.timestamp,
+            data,
+        })
+    }
+}
+
+#[derive(Debug)]
+pub struct PrimaryKeepAliveBody {
+    wal_end: u64,
+    timestamp: SystemTime,
+    reply: u8,
+}
+
+impl PrimaryKeepAliveBody {
+    #[inline]
+    pub fn wal_end(&self) -> u64 {
+        self.wal_end
+    }
+
+    #[inline]
+    pub fn timestamp(&self) -> SystemTime {
+        self.timestamp
+    }
+
+    #[inline]
+    pub fn reply(&self) -> u8 {
+        self.reply
+    }
+}
+
+#[non_exhaustive]
+/// A message of the logical replication stream
+#[derive(Debug)]
+pub enum LogicalReplicationMessage {
+    /// A BEGIN statement
+    Begin(BeginBody),
+    /// A COMMIT statement
+    Commit(CommitBody),
+    /// An Origin replication message
+    /// Note that there can be multiple Origin messages inside a single transaction.
+ Origin(OriginBody), + /// A Relation replication message + Relation(RelationBody), + /// A Type replication message + Type(TypeBody), + /// An INSERT statement + Insert(InsertBody), + /// An UPDATE statement + Update(UpdateBody), + /// A DELETE statement + Delete(DeleteBody), + /// A TRUNCATE statement + Truncate(TruncateBody), +} + +impl LogicalReplicationMessage { + pub fn parse(buf: &Bytes) -> io::Result<Self> { + let mut buf = Buffer { + bytes: buf.clone(), + idx: 0, + }; + + let tag = buf.read_u8()?; + + let logical_replication_message = match tag { + BEGIN_TAG => Self::Begin(BeginBody { + final_lsn: buf.read_u64::<BigEndian>()?, + timestamp: buf.read_i64::<BigEndian>()?, + xid: buf.read_u32::<BigEndian>()?, + }), + COMMIT_TAG => Self::Commit(CommitBody { + flags: buf.read_i8()?, + commit_lsn: buf.read_u64::<BigEndian>()?, + end_lsn: buf.read_u64::<BigEndian>()?, + timestamp: buf.read_i64::<BigEndian>()?, + }), + ORIGIN_TAG => Self::Origin(OriginBody { + commit_lsn: buf.read_u64::<BigEndian>()?, + name: buf.read_cstr()?, + }), + RELATION_TAG => { + let rel_id = buf.read_u32::<BigEndian>()?; + let namespace = buf.read_cstr()?; + let name = buf.read_cstr()?; + let replica_identity = match buf.read_u8()? { + REPLICA_IDENTITY_DEFAULT_TAG => ReplicaIdentity::Default, + REPLICA_IDENTITY_NOTHING_TAG => ReplicaIdentity::Nothing, + REPLICA_IDENTITY_FULL_TAG => ReplicaIdentity::Full, + REPLICA_IDENTITY_INDEX_TAG => ReplicaIdentity::Index, + tag => { + return Err(io::Error::new( + io::ErrorKind::InvalidInput, + format!("unknown replica identity tag `{}`", tag), + )); + } + }; + let column_len = buf.read_i16::<BigEndian>()?; + + let mut columns = Vec::with_capacity(column_len as usize); + for _ in 0..column_len { + columns.push(Column::parse(&mut buf)?); + } + + Self::Relation(RelationBody { + rel_id, + namespace, + name, + replica_identity, + columns, + }) + } + TYPE_TAG => Self::Type(TypeBody { + id: buf.read_u32::<BigEndian>()?, + namespace: buf.read_cstr()?, + name: buf.read_cstr()?, + }), + INSERT_TAG => { + let rel_id = buf.read_u32::<BigEndian>()?; + let tag = buf.read_u8()?; + + let tuple = match tag { + TUPLE_NEW_TAG => Tuple::parse(&mut buf)?, + tag => { + return Err(io::Error::new( + io::ErrorKind::InvalidInput, + format!("unexpected tuple tag `{}`", tag), + )); + } + }; + + Self::Insert(InsertBody { rel_id, tuple }) + } + UPDATE_TAG => { + let rel_id = buf.read_u32::<BigEndian>()?; + let tag = buf.read_u8()?; + + let mut key_tuple = None; + let mut old_tuple = None; + + let new_tuple = match tag { + TUPLE_NEW_TAG => Tuple::parse(&mut buf)?, + TUPLE_OLD_TAG | TUPLE_KEY_TAG => { + if tag == TUPLE_OLD_TAG { + old_tuple = Some(Tuple::parse(&mut buf)?); + } else { + key_tuple = Some(Tuple::parse(&mut buf)?); + } + + match buf.read_u8()? 
{ + TUPLE_NEW_TAG => Tuple::parse(&mut buf)?, + tag => { + return Err(io::Error::new( + io::ErrorKind::InvalidInput, + format!("unexpected tuple tag `{}`", tag), + )); + } + } + } + tag => { + return Err(io::Error::new( + io::ErrorKind::InvalidInput, + format!("unknown tuple tag `{}`", tag), + )); + } + }; + + Self::Update(UpdateBody { + rel_id, + key_tuple, + old_tuple, + new_tuple, + }) + } + DELETE_TAG => { + let rel_id = buf.read_u32::<BigEndian>()?; + let tag = buf.read_u8()?; + + let mut key_tuple = None; + let mut old_tuple = None; + + match tag { + TUPLE_OLD_TAG => old_tuple = Some(Tuple::parse(&mut buf)?), + TUPLE_KEY_TAG => key_tuple = Some(Tuple::parse(&mut buf)?), + tag => { + return Err(io::Error::new( + io::ErrorKind::InvalidInput, + format!("unknown tuple tag `{}`", tag), + )); + } + } + + Self::Delete(DeleteBody { + rel_id, + key_tuple, + old_tuple, + }) + } + TRUNCATE_TAG => { + let relation_len = buf.read_i32::<BigEndian>()?; + let options = buf.read_i8()?; + + let mut rel_ids = Vec::with_capacity(relation_len as usize); + for _ in 0..relation_len { + rel_ids.push(buf.read_u32::<BigEndian>()?); + } + + Self::Truncate(TruncateBody { options, rel_ids }) + } + tag => { + return Err(io::Error::new( + io::ErrorKind::InvalidInput, + format!("unknown replication message tag `{}`", tag), + )); + } + }; + + Ok(logical_replication_message) + } +} + +/// A row as it appears in the replication stream +#[derive(Debug)] +pub struct Tuple(Vec<TupleData>); + +impl Tuple { + #[inline] + /// The tuple data of this tuple + pub fn tuple_data(&self) -> &[TupleData] { + &self.0 + } +} + +impl Tuple { + fn parse(buf: &mut Buffer) -> io::Result<Self> { + let col_len = buf.read_i16::<BigEndian>()?; + let mut tuple = Vec::with_capacity(col_len as usize); + for _ in 0..col_len { + tuple.push(TupleData::parse(buf)?); + } + + Ok(Tuple(tuple)) + } +} + +/// A column as it appears in the replication stream +#[derive(Debug)] +pub struct Column { + flags: i8, + name: Bytes, + type_id: i32, + type_modifier: i32, +} + +impl Column { + #[inline] + /// Flags for the column. Currently can be either 0 for no flags or 1 which marks the column as + /// part of the key. + pub fn flags(&self) -> i8 { + self.flags + } + + #[inline] + /// Name of the column. + pub fn name(&self) -> io::Result<&str> { + get_str(&self.name) + } + + #[inline] + /// ID of the column's data type. + pub fn type_id(&self) -> i32 { + self.type_id + } + + #[inline] + /// Type modifier of the column (`atttypmod`). + pub fn type_modifier(&self) -> i32 { + self.type_modifier + } +} + +impl Column { + fn parse(buf: &mut Buffer) -> io::Result<Self> { + Ok(Self { + flags: buf.read_i8()?, + name: buf.read_cstr()?, + type_id: buf.read_i32::<BigEndian>()?, + type_modifier: buf.read_i32::<BigEndian>()?, + }) + } +} + +/// The data of an individual column as it appears in the replication stream +#[derive(Debug)] +pub enum TupleData { + /// Represents a NULL value + Null, + /// Represents an unchanged TOASTed value (the actual value is not sent). + UnchangedToast, + /// Column data as text formatted value. 
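+    /// The value is length-prefixed on the wire; see `TupleData::parse` below.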
+ Text(Bytes), +} + +impl TupleData { + fn parse(buf: &mut Buffer) -> io::Result<Self> { + let type_tag = buf.read_u8()?; + + let tuple = match type_tag { + TUPLE_DATA_NULL_TAG => TupleData::Null, + TUPLE_DATA_TOAST_TAG => TupleData::UnchangedToast, + TUPLE_DATA_TEXT_TAG => { + let len = buf.read_i32::<BigEndian>()?; + let mut data = vec![0; len as usize]; + buf.read_exact(&mut data)?; + TupleData::Text(data.into()) + } + tag => { + return Err(io::Error::new( + io::ErrorKind::InvalidInput, + format!("unknown replication message tag `{}`", tag), + )); + } + }; + + Ok(tuple) + } +} + +/// A BEGIN statement +#[derive(Debug)] +pub struct BeginBody { + final_lsn: u64, + timestamp: i64, + xid: u32, +} + +impl BeginBody { + #[inline] + /// Gets the final lsn of the transaction + pub fn final_lsn(&self) -> Lsn { + self.final_lsn + } + + #[inline] + /// Commit timestamp of the transaction. The value is in number of microseconds since PostgreSQL epoch (2000-01-01). + pub fn timestamp(&self) -> i64 { + self.timestamp + } + + #[inline] + /// Xid of the transaction. + pub fn xid(&self) -> u32 { + self.xid + } +} + +/// A COMMIT statement +#[derive(Debug)] +pub struct CommitBody { + flags: i8, + commit_lsn: u64, + end_lsn: u64, + timestamp: i64, +} + +impl CommitBody { + #[inline] + /// The LSN of the commit. + pub fn commit_lsn(&self) -> Lsn { + self.commit_lsn + } + + #[inline] + /// The end LSN of the transaction. + pub fn end_lsn(&self) -> Lsn { + self.end_lsn + } + + #[inline] + /// Commit timestamp of the transaction. The value is in number of microseconds since PostgreSQL epoch (2000-01-01). + pub fn timestamp(&self) -> i64 { + self.timestamp + } + + #[inline] + /// Flags; currently unused (will be 0). + pub fn flags(&self) -> i8 { + self.flags + } +} + +/// An Origin replication message +/// +/// Note that there can be multiple Origin messages inside a single transaction. +#[derive(Debug)] +pub struct OriginBody { + commit_lsn: u64, + name: Bytes, +} + +impl OriginBody { + #[inline] + /// The LSN of the commit on the origin server. + pub fn commit_lsn(&self) -> Lsn { + self.commit_lsn + } + + #[inline] + /// Name of the origin. + pub fn name(&self) -> io::Result<&str> { + get_str(&self.name) + } +} + +/// Describes the REPLICA IDENTITY setting of a table +#[derive(Debug)] +pub enum ReplicaIdentity { + /// default selection for replica identity (primary key or nothing) + Default, + /// no replica identity is logged for this relation + Nothing, + /// all columns are logged as replica identity + Full, + /// An explicitly chosen candidate key's columns are used as replica identity. + /// Note this will still be set if the index has been dropped; in that case it + /// has the same meaning as 'd'. + Index, +} + +/// A Relation replication message +#[derive(Debug)] +pub struct RelationBody { + rel_id: u32, + namespace: Bytes, + name: Bytes, + replica_identity: ReplicaIdentity, + columns: Vec<Column>, +} + +impl RelationBody { + #[inline] + /// ID of the relation. + pub fn rel_id(&self) -> u32 { + self.rel_id + } + + #[inline] + /// Namespace (empty string for pg_catalog). + pub fn namespace(&self) -> io::Result<&str> { + get_str(&self.namespace) + } + + #[inline] + /// Relation name. 
+ pub fn name(&self) -> io::Result<&str> { + get_str(&self.name) + } + + #[inline] + /// Replica identity setting for the relation + pub fn replica_identity(&self) -> &ReplicaIdentity { + &self.replica_identity + } + + #[inline] + /// The column definitions of this relation + pub fn columns(&self) -> &[Column] { + &self.columns + } +} + +/// A Type replication message +#[derive(Debug)] +pub struct TypeBody { + id: u32, + namespace: Bytes, + name: Bytes, +} + +impl TypeBody { + #[inline] + /// ID of the data type. + pub fn id(&self) -> Oid { + self.id + } + + #[inline] + /// Namespace (empty string for pg_catalog). + pub fn namespace(&self) -> io::Result<&str> { + get_str(&self.namespace) + } + + #[inline] + /// Name of the data type. + pub fn name(&self) -> io::Result<&str> { + get_str(&self.name) + } +} + +/// An INSERT statement +#[derive(Debug)] +pub struct InsertBody { + rel_id: u32, + tuple: Tuple, +} + +impl InsertBody { + #[inline] + /// ID of the relation corresponding to the ID in the relation message. + pub fn rel_id(&self) -> u32 { + self.rel_id + } + + #[inline] + /// The inserted tuple + pub fn tuple(&self) -> &Tuple { + &self.tuple + } +} + +/// An UPDATE statement +#[derive(Debug)] +pub struct UpdateBody { + rel_id: u32, + old_tuple: Option<Tuple>, + key_tuple: Option<Tuple>, + new_tuple: Tuple, +} + +impl UpdateBody { + #[inline] + /// ID of the relation corresponding to the ID in the relation message. + pub fn rel_id(&self) -> u32 { + self.rel_id + } + + #[inline] + /// This field is optional and is only present if the update changed data in any of the + /// column(s) that are part of the REPLICA IDENTITY index. + pub fn key_tuple(&self) -> Option<&Tuple> { + self.key_tuple.as_ref() + } + + #[inline] + /// This field is optional and is only present if table in which the update happened has + /// REPLICA IDENTITY set to FULL. + pub fn old_tuple(&self) -> Option<&Tuple> { + self.old_tuple.as_ref() + } + + #[inline] + /// The new tuple + pub fn new_tuple(&self) -> &Tuple { + &self.new_tuple + } +} + +/// A DELETE statement +#[derive(Debug)] +pub struct DeleteBody { + rel_id: u32, + old_tuple: Option<Tuple>, + key_tuple: Option<Tuple>, +} + +impl DeleteBody { + #[inline] + /// ID of the relation corresponding to the ID in the relation message. + pub fn rel_id(&self) -> u32 { + self.rel_id + } + + #[inline] + /// This field is present if the table in which the delete has happened uses an index as + /// REPLICA IDENTITY. + pub fn key_tuple(&self) -> Option<&Tuple> { + self.key_tuple.as_ref() + } + + #[inline] + /// This field is present if the table in which the delete has happened has REPLICA IDENTITY + /// set to FULL. 
+    pub fn old_tuple(&self) -> Option<&Tuple> {
+        self.old_tuple.as_ref()
+    }
+}
+
+/// A TRUNCATE statement
+#[derive(Debug)]
+pub struct TruncateBody {
+    options: i8,
+    rel_ids: Vec<u32>,
+}
+
+impl TruncateBody {
+    #[inline]
+    /// The IDs of the relations corresponding to the ID in the relation messages
+    pub fn rel_ids(&self) -> &[u32] {
+        &self.rel_ids
+    }
+
+    #[inline]
+    /// Option bits for TRUNCATE: 1 for CASCADE, 2 for RESTART IDENTITY
+    pub fn options(&self) -> i8 {
+        self.options
+    }
+}
+
 pub struct Fields<'a> {
     buf: &'a [u8],
     remaining: u16,
diff --git a/postgres-types/Cargo.toml b/postgres-types/Cargo.toml
index 70f1ed54a..73e006ba0 100644
--- a/postgres-types/Cargo.toml
+++ b/postgres-types/Cargo.toml
@@ -29,7 +29,7 @@ with-time-0_3 = ["time-03"]
 
 [dependencies]
 bytes = "1.0"
-fallible-iterator = "0.2"
+fallible-iterator = "0.3"
 postgres-protocol = { version = "0.6.4", path = "../postgres-protocol" }
 postgres-derive = { version = "0.4.2", optional = true, path = "../postgres-derive" }
diff --git a/postgres-types/src/lib.rs b/postgres-types/src/lib.rs
index fa49d99eb..f4caa892f 100644
--- a/postgres-types/src/lib.rs
+++ b/postgres-types/src/lib.rs
@@ -395,6 +395,22 @@ impl WrongType {
     }
 }
 
+/// An error indicating that an `as_text` conversion was attempted on a binary
+/// result.
+#[derive(Debug)]
+pub struct WrongFormat {}
+
+impl Error for WrongFormat {}
+
+impl fmt::Display for WrongFormat {
+    fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
+        write!(
+            fmt,
+            "cannot read column as text while it is in binary format"
+        )
+    }
+}
+
 /// A trait for types that can be created from a Postgres value.
 ///
 /// # Types
@@ -846,7 +862,7 @@ pub trait ToSql: fmt::Debug {
 /// Supported Postgres message format types
 ///
 /// Using Text format in a message assumes a Postgres `SERVER_ENCODING` of `UTF8`
-#[derive(Clone, Copy, Debug)]
+#[derive(Clone, Copy, Debug, PartialEq)]
 pub enum Format {
     /// Text format (UTF-8)
     Text,
diff --git a/postgres/Cargo.toml b/postgres/Cargo.toml
index bd7c297f3..c4665361e 100644
--- a/postgres/Cargo.toml
+++ b/postgres/Cargo.toml
@@ -37,7 +37,7 @@ with-time-0_3 = ["tokio-postgres/with-time-0_3"]
 
 [dependencies]
 bytes = "1.0"
-fallible-iterator = "0.2"
+fallible-iterator = "0.3"
 futures-util = { version = "0.3", features = ["sink"] }
 tokio-postgres = { version = "0.7.7", path = "../tokio-postgres" }
diff --git a/postgres/src/config.rs b/postgres/src/config.rs
index b541ec846..44e4bec3a 100644
--- a/postgres/src/config.rs
+++ b/postgres/src/config.rs
@@ -12,7 +12,9 @@ use std::sync::Arc;
 use std::time::Duration;
 use tokio::runtime;
 #[doc(inline)]
-pub use tokio_postgres::config::{ChannelBinding, Host, SslMode, TargetSessionAttrs};
+pub use tokio_postgres::config::{
+    AuthKeys, ChannelBinding, Host, ScramKeys, SslMode, TargetSessionAttrs,
+};
 use tokio_postgres::error::DbError;
 use tokio_postgres::tls::{MakeTlsConnect, TlsConnect};
 use tokio_postgres::{Error, Socket};
@@ -149,6 +151,20 @@ impl Config {
         self.config.get_password()
     }
 
+    /// Sets precomputed protocol-specific keys to authenticate with.
+    /// When set, this option will override `password`.
+    /// See [`AuthKeys`] for more information.
+    pub fn auth_keys(&mut self, keys: AuthKeys) -> &mut Config {
+        self.config.auth_keys(keys);
+        self
+    }
+
+    /// Gets precomputed protocol-specific keys to authenticate with,
+    /// if one has been configured with the `auth_keys` method.
+    pub fn get_auth_keys(&self) -> Option<AuthKeys> {
+        self.config.get_auth_keys()
+    }
+
     /// Sets the name of the database to connect to.
     ///
     /// Defaults to the user.
diff --git a/tokio-postgres/Cargo.toml b/tokio-postgres/Cargo.toml
index 68737f738..72c420023 100644
--- a/tokio-postgres/Cargo.toml
+++ b/tokio-postgres/Cargo.toml
@@ -45,7 +45,7 @@ with-time-0_3 = ["postgres-types/with-time-0_3"]
 async-trait = "0.1"
 bytes = "1.0"
 byteorder = "1.0"
-fallible-iterator = "0.2"
+fallible-iterator = "0.3"
 futures-channel = { version = "0.3", features = ["sink"] }
 futures-util = { version = "0.3", features = ["sink"] }
 log = "0.4"
diff --git a/tokio-postgres/src/client.rs b/tokio-postgres/src/client.rs
index ad5aa2866..37cdd6827 100644
--- a/tokio-postgres/src/client.rs
+++ b/tokio-postgres/src/client.rs
@@ -3,11 +3,14 @@ use crate::codec::{BackendMessages, FrontendMessage};
 use crate::config::Host;
 use crate::config::SslMode;
 use crate::connection::{Request, RequestMessages};
+use crate::copy_both::CopyBothDuplex;
 use crate::copy_out::CopyOutStream;
 #[cfg(feature = "runtime")]
 use crate::keepalive::KeepaliveConfig;
+use crate::prepare::get_type;
 use crate::query::RowStream;
 use crate::simple_query::SimpleQueryStream;
+use crate::statement::Column;
 #[cfg(feature = "runtime")]
 use crate::tls::MakeTlsConnect;
 use crate::tls::TlsConnect;
@@ -15,10 +18,11 @@ use crate::types::{Oid, ToSql, Type};
 #[cfg(feature = "runtime")]
 use crate::Socket;
 use crate::{
-    copy_in, copy_out, prepare, query, simple_query, slice_iter, CancelToken, CopyInSink, Error,
-    Row, SimpleQueryMessage, Statement, ToStatement, Transaction, TransactionBuilder,
+    copy_both, copy_in, copy_out, prepare, query, simple_query, slice_iter, CancelToken,
+    CopyInSink, Error, Row, SimpleQueryMessage, Statement, ToStatement, Transaction,
+    TransactionBuilder,
 };
-use bytes::{Buf, BytesMut};
+use bytes::{Buf, BufMut, BytesMut};
 use fallible_iterator::FallibleIterator;
 use futures_channel::mpsc;
 use futures_util::{future, pin_mut, ready, StreamExt, TryStreamExt};
@@ -372,6 +376,87 @@ impl Client {
         query::query(&self.inner, statement, params).await
     }
 
+    /// Pass text directly to the Postgres backend to allow it to sort out typing itself and
+    /// to save a roundtrip
+    pub async fn query_raw_txt<'a, S, I>(&self, query: S, params: I) -> Result<RowStream, Error>
+    where
+        S: AsRef<str>,
+        I: IntoIterator<Item = S>,
+        I::IntoIter: ExactSizeIterator,
+    {
+        let params = params.into_iter();
+        let params_len = params.len();
+
+        let buf = self.inner.with_buf(|buf| {
+            // Parse, anonymous portal
+            frontend::parse("", query.as_ref(), std::iter::empty(), buf).map_err(Error::encode)?;
+            // Bind, pass params as text, retrieve as text
+            match frontend::bind(
+                "",                 // empty string selects the unnamed portal
+                "",                 // empty string selects the unnamed prepared statement
+                std::iter::empty(), // all parameters use the default format (text)
+                params,
+                |param, buf| {
+                    buf.put_slice(param.as_ref().as_bytes());
+                    Ok(postgres_protocol::IsNull::No)
+                },
+                Some(0), // all text
+                buf,
+            ) {
+                Ok(()) => Ok(()),
+                Err(frontend::BindError::Conversion(e)) => Err(Error::to_sql(e, 0)),
+                Err(frontend::BindError::Serialization(e)) => Err(Error::encode(e)),
+            }?;
+
+            // Describe portal to typecast results
+            frontend::describe(b'P', "", buf).map_err(Error::encode)?;
+            // Execute
+            frontend::execute("", 0, buf).map_err(Error::encode)?;
+            // Sync
+            frontend::sync(buf);
+
+            Ok(buf.split().freeze())
+        })?;
+
+        let mut responses = self
+            .inner
.send(RequestMessages::Single(FrontendMessage::Raw(buf)))?; + + // now read the responses + + match responses.next().await? { + Message::ParseComplete => {} + _ => return Err(Error::unexpected_message()), + } + match responses.next().await? { + Message::BindComplete => {} + _ => return Err(Error::unexpected_message()), + } + let row_description = match responses.next().await? { + Message::RowDescription(body) => Some(body), + Message::NoData => None, + _ => return Err(Error::unexpected_message()), + }; + + // construct statement object + + let parameters = vec![Type::UNKNOWN; params_len]; + + let mut columns = vec![]; + if let Some(row_description) = row_description { + let mut it = row_description.fields(); + while let Some(field) = it.next().map_err(Error::parse)? { + let type_ = get_type(&self.inner, field.type_oid()).await?; + let column = Column::new(field.name().to_string(), type_); + columns.push(column); + } + } + + let statement = Statement::new_text(&self.inner, "".to_owned(), parameters, columns); + + Ok(RowStream::new(statement, responses)) + } + /// Executes a statement, returning the number of rows modified. /// /// A statement may contain parameters, specified by `$n`, where `n` is the index of the parameter of the list @@ -439,6 +524,14 @@ impl Client { copy_in::copy_in(self.inner(), statement).await } + /// Executes a `COPY FROM STDIN` query, returning a sink used to write the copy data. + pub async fn copy_in_simple<U>(&self, query: &str) -> Result<CopyInSink<U>, Error> + where + U: Buf + 'static + Send, + { + copy_in::copy_in_simple(self.inner(), query).await + } + /// Executes a `COPY TO STDOUT` statement, returning a stream of the resulting data. /// /// PostgreSQL does not support parameters in `COPY` statements, so this method does not take any. @@ -454,6 +547,20 @@ impl Client { copy_out::copy_out(self.inner(), statement).await } + /// Executes a `COPY TO STDOUT` query, returning a stream of the resulting data. + pub async fn copy_out_simple(&self, query: &str) -> Result<CopyOutStream, Error> { + copy_out::copy_out_simple(self.inner(), query).await + } + + /// Executes a CopyBoth query, returning a combined Stream+Sink type to read and write copy + /// data. + pub async fn copy_both_simple<T>(&self, query: &str) -> Result<CopyBothDuplex<T>, Error> + where + T: Buf + 'static + Send, + { + copy_both::copy_both_simple(self.inner(), query).await + } + /// Executes a sequence of SQL statements using the simple query protocol, returning the resulting rows. /// /// Statements should be separated by semicolons. 
If an error occurs, execution of the sequence will stop at that
diff --git a/tokio-postgres/src/codec.rs b/tokio-postgres/src/codec.rs
index 9d078044b..23c371542 100644
--- a/tokio-postgres/src/codec.rs
+++ b/tokio-postgres/src/codec.rs
@@ -35,7 +35,9 @@ impl FallibleIterator for BackendMessages {
     }
 }
 
-pub struct PostgresCodec;
+pub struct PostgresCodec {
+    pub max_message_size: Option<usize>,
+}
 
 impl Encoder<FrontendMessage> for PostgresCodec {
     type Error = io::Error;
@@ -64,6 +66,15 @@ impl Decoder for PostgresCodec {
                 break;
             }
 
+            if let Some(max) = self.max_message_size {
+                if len > max {
+                    return Err(io::Error::new(
+                        io::ErrorKind::InvalidInput,
+                        "message too large",
+                    ));
+                }
+            }
+
             match header.tag() {
                 backend::NOTICE_RESPONSE_TAG
                 | backend::NOTIFICATION_RESPONSE_TAG
diff --git a/tokio-postgres/src/config.rs b/tokio-postgres/src/config.rs
index 5b364ec06..fdb5e6359 100644
--- a/tokio-postgres/src/config.rs
+++ b/tokio-postgres/src/config.rs
@@ -23,6 +23,8 @@ use std::time::Duration;
 use std::{error, fmt, iter, mem};
 use tokio::io::{AsyncRead, AsyncWrite};
 
+pub use postgres_protocol::authentication::sasl::ScramKeys;
+
 /// Properties required of a session.
 #[derive(Debug, Copy, Clone, PartialEq, Eq)]
 #[non_exhaustive]
@@ -57,6 +59,16 @@ pub enum ChannelBinding {
     Require,
 }
 
+/// Replication mode configuration.
+#[derive(Debug, Copy, Clone, PartialEq, Eq)]
+#[non_exhaustive]
+pub enum ReplicationMode {
+    /// Physical replication.
+    Physical,
+    /// Logical replication.
+    Logical,
+}
+
 /// A host specification.
 #[derive(Debug, Clone, PartialEq, Eq)]
 pub enum Host {
@@ -69,6 +81,13 @@ pub enum Host {
     Unix(PathBuf),
 }
 
+/// Precomputed keys which may override password during auth.
+#[derive(Debug, Clone, Copy, PartialEq, Eq)]
+pub enum AuthKeys {
+    /// A `ClientKey` & `ServerKey` pair for `SCRAM-SHA-256`.
+    ScramSha256(ScramKeys<32>),
+}
+
 /// Connection configuration.
 ///
 /// Configuration can be parsed from libpq-style connection strings. These strings come in two formats:
@@ -153,6 +172,7 @@ pub enum Host {
 pub struct Config {
     pub(crate) user: Option<String>,
     pub(crate) password: Option<Vec<u8>>,
+    pub(crate) auth_keys: Option<Box<AuthKeys>>,
     pub(crate) dbname: Option<String>,
     pub(crate) options: Option<String>,
     pub(crate) application_name: Option<String>,
@@ -164,6 +184,8 @@ pub struct Config {
     pub(crate) keepalive_config: KeepaliveConfig,
     pub(crate) target_session_attrs: TargetSessionAttrs,
     pub(crate) channel_binding: ChannelBinding,
+    pub(crate) replication_mode: Option<ReplicationMode>,
+    pub(crate) max_backend_message_size: Option<usize>,
 }
 
 impl Default for Config {
@@ -183,6 +205,7 @@ impl Config {
         Config {
             user: None,
             password: None,
+            auth_keys: None,
             dbname: None,
             options: None,
             application_name: None,
@@ -194,6 +217,8 @@ pub struct Config {
             keepalive_config,
             target_session_attrs: TargetSessionAttrs::Any,
             channel_binding: ChannelBinding::Prefer,
+            replication_mode: None,
+            max_backend_message_size: None,
         }
     }
 
@@ -226,6 +251,20 @@ impl Config {
         self.password.as_deref()
     }
 
+    /// Sets precomputed protocol-specific keys to authenticate with.
+    /// When set, this option will override `password`.
+    /// See [`AuthKeys`] for more information.
+    pub fn auth_keys(&mut self, keys: AuthKeys) -> &mut Config {
+        self.auth_keys = Some(Box::new(keys));
+        self
+    }
+
+    /// Gets precomputed protocol-specific keys to authenticate with,
+    /// if one has been configured with the `auth_keys` method.
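+    ///
+    /// # Example
+    ///
+    /// A minimal sketch using dummy all-zero keys; real keys would come from a
+    /// prior SCRAM exchange or an external credential store:
+    ///
+    /// ```ignore
+    /// use tokio_postgres::config::{AuthKeys, Config, ScramKeys};
+    ///
+    /// let keys = ScramKeys {
+    ///     client_key: [0u8; 32],
+    ///     server_key: [0u8; 32],
+    /// };
+    /// let mut config = Config::new();
+    /// config.auth_keys(AuthKeys::ScramSha256(keys));
+    /// assert!(config.get_auth_keys().is_some());
+    /// ```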
+    pub fn get_auth_keys(&self) -> Option<AuthKeys> {
+        self.auth_keys.as_deref().copied()
+    }
+
     /// Sets the name of the database to connect to.
     ///
     /// Defaults to the user.
@@ -424,6 +463,28 @@ impl Config {
         self.channel_binding
     }
 
+    /// Set replication mode.
+    pub fn replication_mode(&mut self, replication_mode: ReplicationMode) -> &mut Config {
+        self.replication_mode = Some(replication_mode);
+        self
+    }
+
+    /// Get replication mode.
+    pub fn get_replication_mode(&self) -> Option<ReplicationMode> {
+        self.replication_mode
+    }
+
+    /// Set the limit for backend message size.
+    pub fn max_backend_message_size(&mut self, max_backend_message_size: usize) -> &mut Config {
+        self.max_backend_message_size = Some(max_backend_message_size);
+        self
+    }
+
+    /// Get the limit for backend message size.
+    pub fn get_max_backend_message_size(&self) -> Option<usize> {
+        self.max_backend_message_size
+    }
+
     fn param(&mut self, key: &str, value: &str) -> Result<(), Error> {
         match key {
             "user" => {
@@ -527,6 +588,25 @@ impl Config {
                 };
                 self.channel_binding(channel_binding);
             }
+            "replication" => {
+                let mode = match value {
+                    "off" => None,
+                    "true" => Some(ReplicationMode::Physical),
+                    "database" => Some(ReplicationMode::Logical),
+                    _ => return Err(Error::config_parse(Box::new(InvalidValue("replication")))),
+                };
+                if let Some(mode) = mode {
+                    self.replication_mode(mode);
+                }
+            }
+            "max_backend_message_size" => {
+                let limit = value.parse::<usize>().map_err(|_| {
+                    Error::config_parse(Box::new(InvalidValue("max_backend_message_size")))
+                })?;
+                if limit > 0 {
+                    self.max_backend_message_size(limit);
+                }
+            }
             key => {
                 return Err(Error::config_parse(Box::new(UnknownOption(
                     key.to_string(),
@@ -601,6 +681,7 @@ impl fmt::Debug for Config {
             .field("keepalives_retries", &self.keepalive_config.retries)
             .field("target_session_attrs", &self.target_session_attrs)
             .field("channel_binding", &self.channel_binding)
+            .field("replication", &self.replication_mode)
             .finish()
     }
 }
diff --git a/tokio-postgres/src/connect_raw.rs b/tokio-postgres/src/connect_raw.rs
index d97636221..0beead11f 100644
--- a/tokio-postgres/src/connect_raw.rs
+++ b/tokio-postgres/src/connect_raw.rs
@@ -1,5 +1,5 @@
 use crate::codec::{BackendMessage, BackendMessages, FrontendMessage, PostgresCodec};
-use crate::config::{self, Config};
+use crate::config::{self, AuthKeys, Config, ReplicationMode};
 use crate::connect_tls::connect_tls;
 use crate::maybe_tls_stream::MaybeTlsStream;
 use crate::tls::{TlsConnect, TlsStream};
@@ -90,7 +90,12 @@ where
     let stream = connect_tls(stream, config.ssl_mode, tls).await?;
 
     let mut stream = StartupStream {
-        inner: Framed::new(stream, PostgresCodec),
+        inner: Framed::new(
+            stream,
+            PostgresCodec {
+                max_message_size: config.max_backend_message_size,
+            },
+        ),
         buf: BackendMessages::empty(),
         delayed: VecDeque::new(),
     };
@@ -124,6 +129,12 @@ where
     if let Some(application_name) = &config.application_name {
         params.push(("application_name", &**application_name));
     }
+    if let Some(replication_mode) = &config.replication_mode {
+        match replication_mode {
+            ReplicationMode::Physical => params.push(("replication", "true")),
+            ReplicationMode::Logical => params.push(("replication", "database")),
+        }
+    }
 
     let mut buf = BytesMut::new();
     frontend::startup_message(params, &mut buf).map_err(Error::encode)?;
@@ -228,11 +239,6 @@ where
     S: AsyncRead + AsyncWrite + Unpin,
     T: TlsStream + Unpin,
 {
-    let password = config
-        .password
-        .as_ref()
-        .ok_or_else(|| Error::config("password missing".into()))?;
-
     let mut has_scram = false;
     let mut
has_scram_plus = false; let mut mechanisms = body.mechanisms(); @@ -270,7 +276,13 @@ where can_skip_channel_binding(config)?; } - let mut scram = ScramSha256::new(password, channel_binding); + let mut scram = if let Some(AuthKeys::ScramSha256(keys)) = config.get_auth_keys() { + ScramSha256::new_with_keys(keys, channel_binding) + } else if let Some(password) = config.get_password() { + ScramSha256::new(password, channel_binding) + } else { + return Err(Error::config("password or auth keys missing".into())); + }; let mut buf = BytesMut::new(); frontend::sasl_initial_response(mechanism, scram.message(), &mut buf).map_err(Error::encode)?; diff --git a/tokio-postgres/src/connection.rs b/tokio-postgres/src/connection.rs index 30be4e834..8b8c7b6fa 100644 --- a/tokio-postgres/src/connection.rs +++ b/tokio-postgres/src/connection.rs @@ -1,4 +1,5 @@ use crate::codec::{BackendMessage, BackendMessages, FrontendMessage, PostgresCodec}; +use crate::copy_both::CopyBothReceiver; use crate::copy_in::CopyInReceiver; use crate::error::DbError; use crate::maybe_tls_stream::MaybeTlsStream; @@ -20,6 +21,7 @@ use tokio_util::codec::Framed; pub enum RequestMessages { Single(FrontendMessage), CopyIn(CopyInReceiver), + CopyBoth(CopyBothReceiver), } pub struct Request { @@ -47,8 +49,10 @@ enum State { /// occurred, or because its associated `Client` has dropped and all outstanding work has completed. #[must_use = "futures do nothing unless polled"] pub struct Connection<S, T> { - stream: Framed<MaybeTlsStream<S, T>, PostgresCodec>, - parameters: HashMap<String, String>, + /// HACK: we need this in the Neon Proxy. + pub stream: Framed<MaybeTlsStream<S, T>, PostgresCodec>, + /// HACK: we need this in the Neon Proxy to forward params. + pub parameters: HashMap<String, String>, receiver: mpsc::UnboundedReceiver<Request>, pending_request: Option<RequestMessages>, pending_responses: VecDeque<BackendMessage>, @@ -258,6 +262,24 @@ where .map_err(Error::io)?; self.pending_request = Some(RequestMessages::CopyIn(receiver)); } + RequestMessages::CopyBoth(mut receiver) => { + let message = match receiver.poll_next_unpin(cx) { + Poll::Ready(Some(message)) => message, + Poll::Ready(None) => { + trace!("poll_write: finished copy_both request"); + continue; + } + Poll::Pending => { + trace!("poll_write: waiting on copy_both stream"); + self.pending_request = Some(RequestMessages::CopyBoth(receiver)); + return Ok(true); + } + }; + Pin::new(&mut self.stream) + .start_send(message) + .map_err(Error::io)?; + self.pending_request = Some(RequestMessages::CopyBoth(receiver)); + } } } } diff --git a/tokio-postgres/src/copy_both.rs b/tokio-postgres/src/copy_both.rs new file mode 100644 index 000000000..79a7be34a --- /dev/null +++ b/tokio-postgres/src/copy_both.rs @@ -0,0 +1,248 @@ +use crate::client::{InnerClient, Responses}; +use crate::codec::FrontendMessage; +use crate::connection::RequestMessages; +use crate::{simple_query, Error}; +use bytes::{Buf, BufMut, Bytes, BytesMut}; +use futures_channel::mpsc; +use futures_util::{future, ready, Sink, SinkExt, Stream, StreamExt}; +use log::debug; +use pin_project_lite::pin_project; +use postgres_protocol::message::backend::Message; +use postgres_protocol::message::frontend; +use postgres_protocol::message::frontend::CopyData; +use std::marker::{PhantomData, PhantomPinned}; +use std::pin::Pin; +use std::task::{Context, Poll}; + +pub(crate) enum CopyBothMessage { + Message(FrontendMessage), + Done, +} + +pub struct CopyBothReceiver { + receiver: mpsc::Receiver<CopyBothMessage>, + done: bool, +} + 
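+// CopyBothReceiver adapts the channel of `CopyBothMessage`s fed by
+// `CopyBothDuplex` into the plain stream of `FrontendMessage`s written out by
+// the connection task. Once the channel yields `Done` (or the sender side is
+// dropped without finishing), the copy is terminated with a CopyDone (or
+// CopyFail) message followed by a Sync.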
+impl CopyBothReceiver {
+    pub(crate) fn new(receiver: mpsc::Receiver<CopyBothMessage>) -> CopyBothReceiver {
+        CopyBothReceiver {
+            receiver,
+            done: false,
+        }
+    }
+}
+
+impl Stream for CopyBothReceiver {
+    type Item = FrontendMessage;
+
+    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<FrontendMessage>> {
+        if self.done {
+            return Poll::Ready(None);
+        }
+
+        match ready!(self.receiver.poll_next_unpin(cx)) {
+            Some(CopyBothMessage::Message(message)) => Poll::Ready(Some(message)),
+            Some(CopyBothMessage::Done) => {
+                self.done = true;
+                let mut buf = BytesMut::new();
+                frontend::copy_done(&mut buf);
+                frontend::sync(&mut buf);
+                Poll::Ready(Some(FrontendMessage::Raw(buf.freeze())))
+            }
+            None => {
+                self.done = true;
+                let mut buf = BytesMut::new();
+                frontend::copy_fail("", &mut buf).unwrap();
+                frontend::sync(&mut buf);
+                Poll::Ready(Some(FrontendMessage::Raw(buf.freeze())))
+            }
+        }
+    }
+}
+
+enum SinkState {
+    Active,
+    Closing,
+    Reading,
+}
+
+pin_project! {
+    /// A duplex for `CopyBoth` data: a `Sink` for the copy data sent to the
+    /// server and a `Stream` of the copy data received from it.
+    ///
+    /// The copy *must* be explicitly completed via the `Sink::close` or `finish` methods. If it is
+    /// not, the copy will be aborted.
+    pub struct CopyBothDuplex<T> {
+        #[pin]
+        sender: mpsc::Sender<CopyBothMessage>,
+        responses: Responses,
+        buf: BytesMut,
+        state: SinkState,
+        #[pin]
+        _p: PhantomPinned,
+        _p2: PhantomData<T>,
+    }
+}
+
+impl<T> CopyBothDuplex<T>
+where
+    T: Buf + 'static + Send,
+{
+    pub(crate) fn new(sender: mpsc::Sender<CopyBothMessage>, responses: Responses) -> Self {
+        Self {
+            sender,
+            responses,
+            buf: BytesMut::new(),
+            state: SinkState::Active,
+            _p: PhantomPinned,
+            _p2: PhantomData,
+        }
+    }
+
+    /// A poll-based version of `finish`.
+    pub fn poll_finish(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<u64, Error>> {
+        loop {
+            match self.state {
+                SinkState::Active => {
+                    ready!(self.as_mut().poll_flush(cx))?;
+                    let mut this = self.as_mut().project();
+                    ready!(this.sender.as_mut().poll_ready(cx)).map_err(|_| Error::closed())?;
+                    this.sender
+                        .start_send(CopyBothMessage::Done)
+                        .map_err(|_| Error::closed())?;
+                    *this.state = SinkState::Closing;
+                }
+                SinkState::Closing => {
+                    let this = self.as_mut().project();
+                    ready!(this.sender.poll_close(cx)).map_err(|_| Error::closed())?;
+                    *this.state = SinkState::Reading;
+                }
+                SinkState::Reading => {
+                    let this = self.as_mut().project();
+                    match ready!(this.responses.poll_next(cx))? {
+                        Message::CommandComplete(body) => {
+                            let rows = body
+                                .tag()
+                                .map_err(Error::parse)?
+                                .rsplit(' ')
+                                .next()
+                                .unwrap()
+                                .parse()
+                                .unwrap_or(0);
+                            return Poll::Ready(Ok(rows));
+                        }
+                        _ => return Poll::Ready(Err(Error::unexpected_message())),
+                    }
+                }
+            }
+        }
+    }
+
+    /// Completes the copy, returning the number of rows inserted.
+    ///
+    /// The `Sink::close` method is equivalent to `finish`, except that it does not return the
+    /// number of rows.
+    pub async fn finish(mut self: Pin<&mut Self>) -> Result<u64, Error> {
+        future::poll_fn(|cx| self.as_mut().poll_finish(cx)).await
+    }
+}
+
+impl<T> Stream for CopyBothDuplex<T> {
+    type Item = Result<Bytes, Error>;
+
+    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
+        let this = self.project();
+
+        match ready!(this.responses.poll_next(cx)?)
{ + Message::CopyData(body) => Poll::Ready(Some(Ok(body.into_bytes()))), + Message::CopyDone => Poll::Ready(None), + _ => Poll::Ready(Some(Err(Error::unexpected_message()))), + } + } +} + +impl<T> Sink<T> for CopyBothDuplex<T> +where + T: Buf + 'static + Send, +{ + type Error = Error; + + fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Error>> { + self.project() + .sender + .poll_ready(cx) + .map_err(|_| Error::closed()) + } + + fn start_send(self: Pin<&mut Self>, item: T) -> Result<(), Error> { + let this = self.project(); + + let data: Box<dyn Buf + Send> = if item.remaining() > 4096 { + if this.buf.is_empty() { + Box::new(item) + } else { + Box::new(this.buf.split().freeze().chain(item)) + } + } else { + this.buf.put(item); + if this.buf.len() > 4096 { + Box::new(this.buf.split().freeze()) + } else { + return Ok(()); + } + }; + + let data = CopyData::new(data).map_err(Error::encode)?; + this.sender + .start_send(CopyBothMessage::Message(FrontendMessage::CopyData(data))) + .map_err(|_| Error::closed()) + } + + fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Error>> { + let mut this = self.project(); + + if !this.buf.is_empty() { + ready!(this.sender.as_mut().poll_ready(cx)).map_err(|_| Error::closed())?; + let data: Box<dyn Buf + Send> = Box::new(this.buf.split().freeze()); + let data = CopyData::new(data).map_err(Error::encode)?; + this.sender + .as_mut() + .start_send(CopyBothMessage::Message(FrontendMessage::CopyData(data))) + .map_err(|_| Error::closed())?; + } + + this.sender.poll_flush(cx).map_err(|_| Error::closed()) + } + + fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Error>> { + self.poll_finish(cx).map_ok(|_| ()) + } +} + +pub async fn copy_both_simple<T>( + client: &InnerClient, + query: &str, +) -> Result<CopyBothDuplex<T>, Error> +where + T: Buf + 'static + Send, +{ + debug!("executing copy both query {}", query); + + let buf = simple_query::encode(client, query)?; + + let (mut sender, receiver) = mpsc::channel(1); + let receiver = CopyBothReceiver::new(receiver); + let mut responses = client.send(RequestMessages::CopyBoth(receiver))?; + + sender + .send(CopyBothMessage::Message(FrontendMessage::Raw(buf))) + .await + .map_err(|_| Error::closed())?; + + match responses.next().await? 
{ + Message::CopyBothResponse(_) => {} + _ => return Err(Error::unexpected_message()), + } + + Ok(CopyBothDuplex::new(sender, responses)) +} diff --git a/tokio-postgres/src/copy_in.rs b/tokio-postgres/src/copy_in.rs index de1da933b..cd3c4f8db 100644 --- a/tokio-postgres/src/copy_in.rs +++ b/tokio-postgres/src/copy_in.rs @@ -1,8 +1,8 @@ use crate::client::{InnerClient, Responses}; use crate::codec::FrontendMessage; use crate::connection::RequestMessages; -use crate::{query, slice_iter, Error, Statement}; -use bytes::{Buf, BufMut, BytesMut}; +use crate::{query, simple_query, slice_iter, Error, Statement}; +use bytes::{Buf, BufMut, Bytes, BytesMut}; use futures_channel::mpsc; use futures_util::{future, ready, Sink, SinkExt, Stream, StreamExt}; use log::debug; @@ -194,14 +194,10 @@ where } } -pub async fn copy_in<T>(client: &InnerClient, statement: Statement) -> Result<CopyInSink<T>, Error> +async fn start<T>(client: &InnerClient, buf: Bytes, simple: bool) -> Result<CopyInSink<T>, Error> where T: Buf + 'static + Send, { - debug!("executing copy in statement {}", statement.name()); - - let buf = query::encode(client, &statement, slice_iter(&[]))?; - let (mut sender, receiver) = mpsc::channel(1); let receiver = CopyInReceiver::new(receiver); let mut responses = client.send(RequestMessages::CopyIn(receiver))?; @@ -211,9 +207,11 @@ where .await .map_err(|_| Error::closed())?; - match responses.next().await? { - Message::BindComplete => {} - _ => return Err(Error::unexpected_message()), + if !simple { + match responses.next().await? { + Message::BindComplete => {} + _ => return Err(Error::unexpected_message()), + } } match responses.next().await? { @@ -230,3 +228,23 @@ where _p2: PhantomData, }) } + +pub async fn copy_in<T>(client: &InnerClient, statement: Statement) -> Result<CopyInSink<T>, Error> +where + T: Buf + 'static + Send, +{ + debug!("executing copy in statement {}", statement.name()); + + let buf = query::encode(client, &statement, slice_iter(&[]))?; + start(client, buf, false).await +} + +pub async fn copy_in_simple<T>(client: &InnerClient, query: &str) -> Result<CopyInSink<T>, Error> +where + T: Buf + 'static + Send, +{ + debug!("executing copy in query {}", query); + + let buf = simple_query::encode(client, query)?; + start(client, buf, true).await +} diff --git a/tokio-postgres/src/copy_out.rs b/tokio-postgres/src/copy_out.rs index 1e6949252..981f9365e 100644 --- a/tokio-postgres/src/copy_out.rs +++ b/tokio-postgres/src/copy_out.rs @@ -1,7 +1,7 @@ use crate::client::{InnerClient, Responses}; use crate::codec::FrontendMessage; use crate::connection::RequestMessages; -use crate::{query, slice_iter, Error, Statement}; +use crate::{query, simple_query, slice_iter, Error, Statement}; use bytes::Bytes; use futures_util::{ready, Stream}; use log::debug; @@ -11,23 +11,36 @@ use std::marker::PhantomPinned; use std::pin::Pin; use std::task::{Context, Poll}; +pub async fn copy_out_simple(client: &InnerClient, query: &str) -> Result<CopyOutStream, Error> { + debug!("executing copy out query {}", query); + + let buf = simple_query::encode(client, query)?; + let responses = start(client, buf, true).await?; + Ok(CopyOutStream { + responses, + _p: PhantomPinned, + }) +} + pub async fn copy_out(client: &InnerClient, statement: Statement) -> Result<CopyOutStream, Error> { debug!("executing copy out statement {}", statement.name()); let buf = query::encode(client, &statement, slice_iter(&[]))?; - let responses = start(client, buf).await?; + let responses = start(client, buf, false).await?; 
Ok(CopyOutStream { responses, _p: PhantomPinned, }) } -async fn start(client: &InnerClient, buf: Bytes) -> Result<Responses, Error> { +async fn start(client: &InnerClient, buf: Bytes, simple: bool) -> Result<Responses, Error> { let mut responses = client.send(RequestMessages::Single(FrontendMessage::Raw(buf)))?; - match responses.next().await? { - Message::BindComplete => {} - _ => return Err(Error::unexpected_message()), + if !simple { + match responses.next().await? { + Message::BindComplete => {} + _ => return Err(Error::unexpected_message()), + } } match responses.next().await? { diff --git a/tokio-postgres/src/lib.rs b/tokio-postgres/src/lib.rs index a9ecba4f1..17bb28409 100644 --- a/tokio-postgres/src/lib.rs +++ b/tokio-postgres/src/lib.rs @@ -159,15 +159,17 @@ mod connect_raw; mod connect_socket; mod connect_tls; mod connection; +mod copy_both; mod copy_in; mod copy_out; pub mod error; mod generic_client; mod keepalive; -mod maybe_tls_stream; +pub mod maybe_tls_stream; mod portal; mod prepare; mod query; +pub mod replication; pub mod row; mod simple_query; #[cfg(feature = "runtime")] diff --git a/tokio-postgres/src/maybe_tls_stream.rs b/tokio-postgres/src/maybe_tls_stream.rs index 73b0c4721..9a7e24899 100644 --- a/tokio-postgres/src/maybe_tls_stream.rs +++ b/tokio-postgres/src/maybe_tls_stream.rs @@ -1,11 +1,17 @@ +//! MaybeTlsStream. +//! +//! Represents a stream that may or may not be encrypted with TLS. use crate::tls::{ChannelBinding, TlsStream}; use std::io; use std::pin::Pin; use std::task::{Context, Poll}; use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; +/// A stream that may or may not be encrypted with TLS. pub enum MaybeTlsStream<S, T> { + /// An unencrypted stream. Raw(S), + /// An encrypted stream. Tls(T), } diff --git a/tokio-postgres/src/prepare.rs b/tokio-postgres/src/prepare.rs index e3f09a7c2..ba8d5a43e 100644 --- a/tokio-postgres/src/prepare.rs +++ b/tokio-postgres/src/prepare.rs @@ -126,7 +126,7 @@ fn encode(client: &InnerClient, name: &str, query: &str, types: &[Type]) -> Resu }) } -async fn get_type(client: &Arc<InnerClient>, oid: Oid) -> Result<Type, Error> { +pub async fn get_type(client: &Arc<InnerClient>, oid: Oid) -> Result<Type, Error> { if let Some(type_) = Type::from_oid(oid) { return Ok(type_); } diff --git a/tokio-postgres/src/query.rs b/tokio-postgres/src/query.rs index 71db8769a..a486b4f88 100644 --- a/tokio-postgres/src/query.rs +++ b/tokio-postgres/src/query.rs @@ -52,6 +52,7 @@ where Ok(RowStream { statement, responses, + command_tag: None, _p: PhantomPinned, }) } @@ -72,6 +73,7 @@ pub async fn query_portal( Ok(RowStream { statement: portal.statement().clone(), responses, + command_tag: None, _p: PhantomPinned, }) } @@ -202,11 +204,24 @@ pin_project! { pub struct RowStream { statement: Statement, responses: Responses, + command_tag: Option<String>, #[pin] _p: PhantomPinned, } } +impl RowStream { + /// Creates a new `RowStream`. 
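+    ///
+    /// Used by `Client::query_raw_txt` to wrap the responses of an ad-hoc
+    /// text-format query.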
+ pub fn new(statement: Statement, responses: Responses) -> Self { + RowStream { + statement, + responses, + command_tag: None, + _p: PhantomPinned, + } + } +} + impl Stream for RowStream { type Item = Result<Row, Error>; @@ -217,12 +232,24 @@ impl Stream for RowStream { Message::DataRow(body) => { return Poll::Ready(Some(Ok(Row::new(this.statement.clone(), body)?))) } - Message::EmptyQueryResponse - | Message::CommandComplete(_) - | Message::PortalSuspended => {} + Message::EmptyQueryResponse | Message::PortalSuspended => {} + Message::CommandComplete(body) => { + if let Ok(tag) = body.tag() { + *this.command_tag = Some(tag.to_string()); + } + } Message::ReadyForQuery(_) => return Poll::Ready(None), _ => return Poll::Ready(Some(Err(Error::unexpected_message()))), } } } } + +impl RowStream { + /// Returns the command tag of this query. + /// + /// This is only available after the stream has been exhausted. + pub fn command_tag(&self) -> Option<String> { + self.command_tag.clone() + } +} diff --git a/tokio-postgres/src/replication.rs b/tokio-postgres/src/replication.rs new file mode 100644 index 000000000..e7c845958 --- /dev/null +++ b/tokio-postgres/src/replication.rs @@ -0,0 +1,201 @@ +//! Utilities for working with the PostgreSQL replication copy both format. + +use crate::copy_both::CopyBothDuplex; +use crate::Error; +use bytes::{BufMut, Bytes, BytesMut}; +use futures_util::{ready, SinkExt, Stream}; +use pin_project_lite::pin_project; +use postgres_protocol::message::backend::{LogicalReplicationMessage, ReplicationMessage}; +use postgres_protocol::PG_EPOCH; +use postgres_types::PgLsn; +use std::pin::Pin; +use std::task::{Context, Poll}; +use std::time::SystemTime; + +const STANDBY_STATUS_UPDATE_TAG: u8 = b'r'; +const HOT_STANDBY_FEEDBACK_TAG: u8 = b'h'; +const ZENITH_STATUS_UPDATE_TAG_BYTE: u8 = b'z'; + +pin_project! { + /// A type which deserializes the postgres replication protocol. This type can be used with + /// both physical and logical replication to get access to the byte content of each replication + /// message. + /// + /// The replication *must* be explicitly completed via the `finish` method. + pub struct ReplicationStream { + #[pin] + stream: CopyBothDuplex<Bytes>, + } +} + +impl ReplicationStream { + /// Creates a new ReplicationStream that will wrap the underlying CopyBoth stream + pub fn new(stream: CopyBothDuplex<Bytes>) -> Self { + Self { stream } + } + + /// Send zenith status update to server. + pub async fn zenith_status_update( + self: Pin<&mut Self>, + len: u64, + data: &[u8], + ) -> Result<(), Error> { + let mut this = self.project(); + + let mut buf = BytesMut::new(); + buf.put_u8(ZENITH_STATUS_UPDATE_TAG_BYTE); + buf.put_u64(len); + buf.put_slice(data); + + this.stream.send(buf.freeze()).await + } + + /// Send standby update to server. + pub async fn standby_status_update( + self: Pin<&mut Self>, + write_lsn: PgLsn, + flush_lsn: PgLsn, + apply_lsn: PgLsn, + timestamp: SystemTime, + reply: u8, + ) -> Result<(), Error> { + let mut this = self.project(); + + let timestamp = match timestamp.duration_since(*PG_EPOCH) { + Ok(d) => d.as_micros() as i64, + Err(e) => -(e.duration().as_micros() as i64), + }; + + let mut buf = BytesMut::new(); + buf.put_u8(STANDBY_STATUS_UPDATE_TAG); + buf.put_u64(write_lsn.into()); + buf.put_u64(flush_lsn.into()); + buf.put_u64(apply_lsn.into()); + buf.put_i64(timestamp); + buf.put_u8(reply); + + this.stream.send(buf.freeze()).await + } + + /// Send hot standby feedback message to server. 
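+    ///
+    /// The message body mirrors the buffer written below: the tag byte `b'h'`,
+    /// an `i64` timestamp in microseconds since the PostgreSQL epoch, and the
+    /// xmin/epoch pairs as big-endian `u32`s.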
+ pub async fn hot_standby_feedback( + self: Pin<&mut Self>, + timestamp: SystemTime, + global_xmin: u32, + global_xmin_epoch: u32, + catalog_xmin: u32, + catalog_xmin_epoch: u32, + ) -> Result<(), Error> { + let mut this = self.project(); + + let timestamp = match timestamp.duration_since(*PG_EPOCH) { + Ok(d) => d.as_micros() as i64, + Err(e) => -(e.duration().as_micros() as i64), + }; + + let mut buf = BytesMut::new(); + buf.put_u8(HOT_STANDBY_FEEDBACK_TAG); + buf.put_i64(timestamp); + buf.put_u32(global_xmin); + buf.put_u32(global_xmin_epoch); + buf.put_u32(catalog_xmin); + buf.put_u32(catalog_xmin_epoch); + + this.stream.send(buf.freeze()).await + } +} + +impl Stream for ReplicationStream { + type Item = Result<ReplicationMessage<Bytes>, Error>; + + fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> { + let this = self.project(); + + match ready!(this.stream.poll_next(cx)) { + Some(Ok(buf)) => { + Poll::Ready(Some(ReplicationMessage::parse(&buf).map_err(Error::parse))) + } + Some(Err(err)) => Poll::Ready(Some(Err(err))), + None => Poll::Ready(None), + } + } +} + +pin_project! { + /// A type which deserializes the postgres logical replication protocol. This type gives access + /// to a high level representation of the changes in transaction commit order. + /// + /// The replication *must* be explicitly completed via the `finish` method. + pub struct LogicalReplicationStream { + #[pin] + stream: ReplicationStream, + } +} + +impl LogicalReplicationStream { + /// Creates a new LogicalReplicationStream that will wrap the underlying CopyBoth stream + pub fn new(stream: CopyBothDuplex<Bytes>) -> Self { + Self { + stream: ReplicationStream::new(stream), + } + } + + /// Send standby update to server. + pub async fn standby_status_update( + self: Pin<&mut Self>, + write_lsn: PgLsn, + flush_lsn: PgLsn, + apply_lsn: PgLsn, + timestamp: SystemTime, + reply: u8, + ) -> Result<(), Error> { + let this = self.project(); + this.stream + .standby_status_update(write_lsn, flush_lsn, apply_lsn, timestamp, reply) + .await + } + + /// Send hot standby feedback message to server. 
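+    ///
+    /// Delegates to the inner `ReplicationStream`.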
+ pub async fn hot_standby_feedback( + self: Pin<&mut Self>, + timestamp: SystemTime, + global_xmin: u32, + global_xmin_epoch: u32, + catalog_xmin: u32, + catalog_xmin_epoch: u32, + ) -> Result<(), Error> { + let this = self.project(); + this.stream + .hot_standby_feedback( + timestamp, + global_xmin, + global_xmin_epoch, + catalog_xmin, + catalog_xmin_epoch, + ) + .await + } +} + +impl Stream for LogicalReplicationStream { + type Item = Result<ReplicationMessage<LogicalReplicationMessage>, Error>; + + fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> { + let this = self.project(); + + match ready!(this.stream.poll_next(cx)) { + Some(Ok(ReplicationMessage::XLogData(body))) => { + let body = body + .map_data(|buf| LogicalReplicationMessage::parse(&buf)) + .map_err(Error::parse)?; + Poll::Ready(Some(Ok(ReplicationMessage::XLogData(body)))) + } + Some(Ok(ReplicationMessage::PrimaryKeepAlive(body))) => { + Poll::Ready(Some(Ok(ReplicationMessage::PrimaryKeepAlive(body)))) + } + Some(Ok(_)) => Poll::Ready(Some(Err(Error::unexpected_message()))), + Some(Err(err)) => Poll::Ready(Some(Err(err))), + None => Poll::Ready(None), + } + } +} diff --git a/tokio-postgres/src/row.rs b/tokio-postgres/src/row.rs index db179b432..ce4efed7e 100644 --- a/tokio-postgres/src/row.rs +++ b/tokio-postgres/src/row.rs @@ -7,6 +7,7 @@ use crate::types::{FromSql, Type, WrongType}; use crate::{Error, Statement}; use fallible_iterator::FallibleIterator; use postgres_protocol::message::backend::DataRowBody; +use postgres_types::{Format, WrongFormat}; use std::fmt; use std::ops::Range; use std::str; @@ -187,6 +188,27 @@ impl Row { let range = self.ranges[idx].to_owned()?; Some(&self.body.buffer()[range]) } + + /// Interpret the column at the given index as text + /// + /// Useful when using query_raw_txt() which sets text transfer mode + pub fn as_text(&self, idx: usize) -> Result<Option<&str>, Error> { + if self.statement.output_format() == Format::Text { + match self.col_buffer(idx) { + Some(raw) => { + FromSql::from_sql(&Type::TEXT, raw).map_err(|e| Error::from_sql(e, idx)) + } + None => Ok(None), + } + } else { + Err(Error::from_sql(Box::new(WrongFormat {}), idx)) + } + } + + /// Row byte size + pub fn body_len(&self) -> usize { + self.body.buffer().len() + } } impl AsName for SimpleColumn { diff --git a/tokio-postgres/src/simple_query.rs b/tokio-postgres/src/simple_query.rs index 7c266e409..70f48a7d8 100644 --- a/tokio-postgres/src/simple_query.rs +++ b/tokio-postgres/src/simple_query.rs @@ -62,7 +62,7 @@ pub async fn batch_execute(client: &InnerClient, query: &str) -> Result<(), Erro } } -fn encode(client: &InnerClient, query: &str) -> Result<Bytes, Error> { +pub(crate) fn encode(client: &InnerClient, query: &str) -> Result<Bytes, Error> { client.with_buf(|buf| { frontend::query(query, buf).map_err(Error::encode)?; Ok(buf.split().freeze()) diff --git a/tokio-postgres/src/statement.rs b/tokio-postgres/src/statement.rs index 97561a8e4..b7ab11866 100644 --- a/tokio-postgres/src/statement.rs +++ b/tokio-postgres/src/statement.rs @@ -3,6 +3,7 @@ use crate::codec::FrontendMessage; use crate::connection::RequestMessages; use crate::types::Type; use postgres_protocol::message::frontend; +use postgres_types::Format; use std::{ fmt, sync::{Arc, Weak}, @@ -13,6 +14,7 @@ struct StatementInner { name: String, params: Vec<Type>, columns: Vec<Column>, + output_format: Format, } impl Drop for StatementInner { @@ -46,6 +48,22 @@ impl Statement { name, params, columns, + output_format: Format::Binary, 
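+            // statements prepared through the regular path return binary-format
+            // results; `new_text` below opts into the text format instead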
+        }))
+    }
+
+    pub(crate) fn new_text(
+        inner: &Arc<InnerClient>,
+        name: String,
+        params: Vec<Type>,
+        columns: Vec<Column>,
+    ) -> Statement {
+        Statement(Arc::new(StatementInner {
+            client: Arc::downgrade(inner),
+            name,
+            params,
+            columns,
+            output_format: Format::Text,
         }))
     }
 
@@ -62,6 +80,11 @@ impl Statement {
     pub fn columns(&self) -> &[Column] {
         &self.0.columns
     }
+
+    /// Returns the output format for the statement.
+    pub fn output_format(&self) -> Format {
+        self.0.output_format
+    }
 }
 
 /// Information about a column of a query.
diff --git a/tokio-postgres/tests/test/main.rs b/tokio-postgres/tests/test/main.rs
index 0ab4a7bab..551f6ec5c 100644
--- a/tokio-postgres/tests/test/main.rs
+++ b/tokio-postgres/tests/test/main.rs
@@ -22,6 +22,8 @@ use tokio_postgres::{
 mod binary_copy;
 mod parse;
 #[cfg(feature = "runtime")]
+mod replication;
+#[cfg(feature = "runtime")]
 mod runtime;
 mod types;
 
@@ -249,6 +251,78 @@ async fn custom_array() {
     }
 }
 
+#[tokio::test]
+async fn query_raw_txt() {
+    let client = connect("user=postgres").await;
+
+    let rows: Vec<tokio_postgres::Row> = client
+        .query_raw_txt("SELECT 55 * $1", ["42"])
+        .await
+        .unwrap()
+        .try_collect()
+        .await
+        .unwrap();
+
+    assert_eq!(rows.len(), 1);
+    let res: i32 = rows[0].as_text(0).unwrap().unwrap().parse::<i32>().unwrap();
+    assert_eq!(res, 55 * 42);
+
+    let rows: Vec<tokio_postgres::Row> = client
+        .query_raw_txt("SELECT $1", ["42"])
+        .await
+        .unwrap()
+        .try_collect()
+        .await
+        .unwrap();
+
+    assert_eq!(rows.len(), 1);
+    assert_eq!(rows[0].get::<_, &str>(0), "42");
+    assert!(rows[0].body_len() > 0);
+}
+
+#[tokio::test]
+async fn limit_max_backend_message_size() {
+    let client = connect("user=postgres max_backend_message_size=10000").await;
+    let small: Vec<tokio_postgres::Row> = client
+        .query_raw_txt("SELECT REPEAT('a', 20)", [])
+        .await
+        .unwrap()
+        .try_collect()
+        .await
+        .unwrap();
+
+    assert_eq!(small.len(), 1);
+    assert_eq!(small[0].as_text(0).unwrap().unwrap().len(), 20);
+
+    let large: Result<Vec<tokio_postgres::Row>, Error> = client
+        .query_raw_txt("SELECT REPEAT('a', 2000000)", [])
+        .await
+        .unwrap()
+        .try_collect()
+        .await;
+
+    assert!(large.is_err());
+}
+
+#[tokio::test]
+async fn command_tag() {
+    let client = connect("user=postgres").await;
+
+    let row_stream = client
+        .query_raw_txt("select unnest('{1,2,3}'::int[]);", [])
+        .await
+        .unwrap();
+
+    pin_mut!(row_stream);
+
+    let mut rows: Vec<tokio_postgres::Row> = Vec::new();
+    while let Some(row) = row_stream.next().await {
+        rows.push(row.unwrap());
+    }
+
+    assert_eq!(row_stream.command_tag(), Some("SELECT 3".to_string()));
+}
+
 #[tokio::test]
 async fn custom_composite() {
     let client = connect("user=postgres").await;
diff --git a/tokio-postgres/tests/test/replication.rs b/tokio-postgres/tests/test/replication.rs
new file mode 100644
index 000000000..c176a4104
--- /dev/null
+++ b/tokio-postgres/tests/test/replication.rs
@@ -0,0 +1,146 @@
+use futures_util::StreamExt;
+use std::time::SystemTime;
+
+use postgres_protocol::message::backend::LogicalReplicationMessage::{Begin, Commit, Insert};
+use postgres_protocol::message::backend::ReplicationMessage::*;
+use postgres_protocol::message::backend::TupleData;
+use postgres_types::PgLsn;
+use tokio_postgres::replication::LogicalReplicationStream;
+use tokio_postgres::NoTls;
+use tokio_postgres::SimpleQueryMessage::Row;
+
+#[tokio::test]
+async fn test_replication() {
+    // open a regular SQL connection; replication=database also allows replication commands
+    let conninfo = "host=127.0.0.1 port=5433 user=postgres replication=database";
+    let (client, connection) = tokio_postgres::connect(conninfo, NoTls).await.unwrap();
+    tokio::spawn(async move {
+        if let Err(e) = connection.await {
+            eprintln!("connection error: {}", e);
+        }
+    });
+
+    client
+        .simple_query("DROP TABLE IF EXISTS test_logical_replication")
+        .await
+        .unwrap();
+    client
+        .simple_query("CREATE TABLE test_logical_replication(i int)")
+        .await
+        .unwrap();
+    let res = client
+        .simple_query("SELECT 'test_logical_replication'::regclass::oid")
+        .await
+        .unwrap();
+    let rel_id: u32 = if let Row(row) = &res[0] {
+        row.get("oid").unwrap().parse().unwrap()
+    } else {
+        panic!("unexpected query message");
+    };
+
+    client
+        .simple_query("DROP PUBLICATION IF EXISTS test_pub")
+        .await
+        .unwrap();
+    client
+        .simple_query("CREATE PUBLICATION test_pub FOR ALL TABLES")
+        .await
+        .unwrap();
+
+    let slot = "test_logical_slot";
+
+    let query = format!(
+        r#"CREATE_REPLICATION_SLOT {:?} TEMPORARY LOGICAL "pgoutput""#,
+        slot
+    );
+    let slot_query = client.simple_query(&query).await.unwrap();
+    let lsn = if let Row(row) = &slot_query[0] {
+        row.get("consistent_point").unwrap()
+    } else {
+        panic!("unexpected query message");
+    };
+
+    // issue a query that will appear in the slot's stream since it happened after its creation
+    client
+        .simple_query("INSERT INTO test_logical_replication VALUES (42)")
+        .await
+        .unwrap();
+
+    let options = r#"("proto_version" '1', "publication_names" 'test_pub')"#;
+    let query = format!(
+        r#"START_REPLICATION SLOT {:?} LOGICAL {} {}"#,
+        slot, lsn, options
+    );
+    let copy_stream = client
+        .copy_both_simple::<bytes::Bytes>(&query)
+        .await
+        .unwrap();
+
+    let stream = LogicalReplicationStream::new(copy_stream);
+    tokio::pin!(stream);
+
+    // verify that we can observe the transaction in the replication stream
+    let begin = loop {
+        match stream.next().await {
+            Some(Ok(XLogData(body))) => {
+                if let Begin(begin) = body.into_data() {
+                    break begin;
+                }
+            }
+            Some(Ok(_)) => (),
+            Some(Err(_)) => panic!("unexpected replication stream error"),
+            None => panic!("unexpected replication stream end"),
+        }
+    };
+    let insert = loop {
+        match stream.next().await {
+            Some(Ok(XLogData(body))) => {
+                if let Insert(insert) = body.into_data() {
+                    break insert;
+                }
+            }
+            Some(Ok(_)) => (),
+            Some(Err(_)) => panic!("unexpected replication stream error"),
+            None => panic!("unexpected replication stream end"),
+        }
+    };
+
+    let commit = loop {
+        match stream.next().await {
+            Some(Ok(XLogData(body))) => {
+                if let Commit(commit) = body.into_data() {
+                    break commit;
+                }
+            }
+            Some(Ok(_)) => (),
+            Some(Err(_)) => panic!("unexpected replication stream error"),
+            None => panic!("unexpected replication stream end"),
+        }
+    };
+
+    assert_eq!(begin.final_lsn(), commit.commit_lsn());
+    assert_eq!(insert.rel_id(), rel_id);
+
+    let tuple_data = insert.tuple().tuple_data();
+    assert_eq!(tuple_data.len(), 1);
+    assert!(matches!(tuple_data[0], TupleData::Text(_)));
+    if let TupleData::Text(data) = &tuple_data[0] {
+        assert_eq!(data, &b"42"[..]);
+    }
+
+    // send a standby status update requesting an immediate keepalive reply (reply = 1),
+    // then wait for that keepalive to arrive
+    let lsn: PgLsn = lsn.parse().unwrap();
+    stream
+        .as_mut()
+        .standby_status_update(lsn, lsn, lsn, SystemTime::now(), 1)
+        .await
+        .unwrap();
+    loop {
+        match stream.next().await {
+            Some(Ok(PrimaryKeepAlive(_))) => break,
+            Some(Ok(_)) => (),
+            Some(Err(e)) => panic!("unexpected replication stream error: {}", e),
+            None => panic!("unexpected replication stream end"),
+        }
+    }
+}