Skip to content

Commit 1b8164e

Browse files
authored
Merge pull request #171 from lawrinn/master
Issue#401 Adding required definitions for MariaDB metadata skipping
2 parents a4827e3 + 29ce486 commit 1b8164e

File tree

2 files changed

+135
-5
lines changed

2 files changed

+135
-5
lines changed

src/constants.rs

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -360,6 +360,35 @@ my_bitflags! {
360360
}
361361
}
362362

363+
my_bitflags! {
364+
MariadbCapabilities,
365+
#[error("Unknown flags in the raw value of MariadbCapabilities (raw={0:b})")]
366+
UnknownMariadbCapabilityFlags,
367+
u32,
368+
369+
/// Mariadb client capability flags
370+
#[derive(PartialEq, Eq, Hash, Debug, Clone, Copy)]
371+
pub struct MariadbCapabilities: u32 {
372+
/// Permits feedback during long-running operations
373+
const MARIADB_CLIENT_PROGRESS = 0x0000_0001;
374+
375+
/// Former COM_MULTI, don't use
376+
const MARIADB_CLIENT_RESERVED_1 = 0x0000_0002;
377+
378+
/// Support of parameter arrays in COM_STMT_EXECUTE, since 10.2.0
379+
const MARIADB_CLIENT_STMT_BULK_OPERATIONS = 0x0000_0004;
380+
381+
/// Support of extended data type/format information, since 10.5.0
382+
const MARIADB_CLIENT_EXTENDED_METADATA = 0x0000_0008;
383+
384+
/// Do not resend metadata for prepared statements, since 10.6.0
385+
const MARIADB_CLIENT_CACHE_METADATA = 0x0000_0010;
386+
387+
/// Permits sending unit result-set for BULK commands
388+
const MARIADB_CLIENT_BULK_UNIT_RESULTS = 0x0000_0020;
389+
}
390+
}
391+
363392
my_bitflags! {
364393
CursorType,
365394
#[error("Unknown flags in the raw value of CursorType (raw={0:b})")]

src/packets/mod.rs

Lines changed: 106 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,8 @@ use crate::scramble::create_response_for_ed25519;
2222
use crate::{
2323
constants::{
2424
CapabilityFlags, ColumnFlags, ColumnType, Command, CursorType, MAX_PAYLOAD_LEN,
25-
SessionStateType, StatusFlags, StmtExecuteParamFlags, StmtExecuteParamsFlags,
25+
MariadbCapabilities, SessionStateType, StatusFlags, StmtExecuteParamFlags,
26+
StmtExecuteParamsFlags,
2627
},
2728
io::{BufMutExt, ParseBuf},
2829
misc::{
@@ -1594,7 +1595,9 @@ pub struct HandshakePacket<'a> {
15941595
// upper 16 bytes
15951596
capabilities_2: Const<CapabilityFlags, LeU32UpperHalf>,
15961597
auth_plugin_data_len: RawInt<u8>,
1597-
__reserved: Skip<10>,
1598+
__reserved: Skip<6>,
1599+
// MariaDB uses last 4 reserved bytes to pass its extended capabilities.
1600+
mariadb_ext_capabilities: Const<MariadbCapabilities, LeU32>,
15981601
scramble_2: Option<RawBytes<'a, BareBytes<{ (u8::MAX as usize) - 8 }>>>,
15991602
auth_plugin_name: Option<RawBytes<'a, NullBytes>>,
16001603
}
@@ -1618,6 +1621,9 @@ impl<'de> MyDeserialize<'de> for HandshakePacket<'de> {
16181621
let capabilities_2: RawConst<LeU32UpperHalf, CapabilityFlags> = sbuf.parse_unchecked(())?;
16191622
let auth_plugin_data_len: RawInt<u8> = sbuf.parse_unchecked(())?;
16201623
let __reserved = sbuf.parse_unchecked(())?;
1624+
// If the server is MariaDB, it will pass its extended capabilities
1625+
// in the last 4 reserved bytes.
1626+
let mariadb_capabiities: RawConst<LeU32, MariadbCapabilities> = sbuf.parse_unchecked(())?;
16211627
let mut scramble_2 = None;
16221628
if capabilities_1.0 & CapabilityFlags::CLIENT_SECURE_CONNECTION.bits() > 0 {
16231629
let len = max(13, auth_plugin_data_len.0 as i8 - 8) as usize;
@@ -1644,6 +1650,9 @@ impl<'de> MyDeserialize<'de> for HandshakePacket<'de> {
16441650
capabilities_2: Const::new(CapabilityFlags::from_bits_truncate(capabilities_2.0)),
16451651
auth_plugin_data_len,
16461652
__reserved,
1653+
mariadb_ext_capabilities: Const::new(MariadbCapabilities::from_bits_truncate(
1654+
mariadb_capabiities.0,
1655+
)),
16471656
scramble_2,
16481657
auth_plugin_name,
16491658
})
@@ -1676,7 +1685,8 @@ impl MySerialize for HandshakePacket<'_> {
16761685
buf.put_u8(0);
16771686
}
16781687

1679-
buf.put_slice(&[0_u8; 10][..]);
1688+
self.__reserved.serialize(&mut *buf);
1689+
self.mariadb_ext_capabilities.serialize(&mut *buf);
16801690

16811691
// Assume that the packet is well formed:
16821692
// * the CLIENT_SECURE_CONNECTION is set.
@@ -1704,6 +1714,7 @@ impl<'a> HandshakePacket<'a> {
17041714
default_collation: u8,
17051715
status_flags: StatusFlags,
17061716
auth_plugin_name: Option<impl Into<Cow<'a, [u8]>>>,
1717+
mariadb_capabilities: MariadbCapabilities,
17071718
) -> Self {
17081719
// Safety:
17091720
// * capabilities are given as a valid CapabilityFlags instance
@@ -1732,6 +1743,7 @@ impl<'a> HandshakePacket<'a> {
17321743
.unwrap_or_default(),
17331744
),
17341745
__reserved: Skip,
1746+
mariadb_ext_capabilities: Const::new(mariadb_capabilities),
17351747
scramble_2,
17361748
auth_plugin_name: auth_plugin_name.map(RawBytes::new),
17371749
}
@@ -1750,6 +1762,7 @@ impl<'a> HandshakePacket<'a> {
17501762
capabilities_2: self.capabilities_2,
17511763
auth_plugin_data_len: self.auth_plugin_data_len,
17521764
__reserved: self.__reserved,
1765+
mariadb_ext_capabilities: self.mariadb_ext_capabilities,
17531766
scramble_2: self.scramble_2.map(|x| x.into_owned()),
17541767
auth_plugin_name: self.auth_plugin_name.map(RawBytes::into_owned),
17551768
}
@@ -1834,6 +1847,10 @@ impl<'a> HandshakePacket<'a> {
18341847
self.capabilities_1.0 | self.capabilities_2.0
18351848
}
18361849

1850+
/// Value of MariaDB specific server capabilities
1851+
pub fn mariadb_ext_capabilities(&self) -> MariadbCapabilities {
1852+
self.mariadb_ext_capabilities.0
1853+
}
18371854
/// Value of the default_collation field of an initial handshake packet.
18381855
pub fn default_collation(&self) -> u8 {
18391856
self.default_collation.0
@@ -2105,6 +2122,7 @@ pub struct HandshakeResponse<'a> {
21052122
db_name: Option<RawBytes<'a, NullBytes>>,
21062123
auth_plugin: Option<AuthPlugin<'a>>,
21072124
connect_attributes: Option<HashMap<RawBytes<'a, LenEnc>, RawBytes<'a, LenEnc>>>,
2125+
mariadb_ext_capabilities: Const<MariadbCapabilities, LeU32>,
21082126
}
21092127

21102128
impl<'a> HandshakeResponse<'a> {
@@ -2170,13 +2188,26 @@ impl<'a> HandshakeResponse<'a> {
21702188
.collect()
21712189
}),
21722190
max_packet_size: RawInt::new(max_packet_size),
2191+
mariadb_ext_capabilities: Const::new(MariadbCapabilities::empty()),
21732192
}
21742193
}
21752194

2195+
pub fn with_mariadb_ext_capabilities(
2196+
mut self,
2197+
mariadb_ext_capabilities: MariadbCapabilities,
2198+
) -> Self {
2199+
self.mariadb_ext_capabilities = Const::new(mariadb_ext_capabilities);
2200+
self
2201+
}
2202+
21762203
pub fn capabilities(&self) -> CapabilityFlags {
21772204
self.capabilities.0
21782205
}
21792206

2207+
pub fn mariadb_ext_capabilities(&self) -> MariadbCapabilities {
2208+
self.mariadb_ext_capabilities.0
2209+
}
2210+
21802211
pub fn collation(&self) -> u8 {
21812212
self.collation.0
21822213
}
@@ -2223,7 +2254,8 @@ impl<'de> MyDeserialize<'de> for HandshakeResponse<'de> {
22232254
let client_flags: RawConst<LeU32, CapabilityFlags> = sbuf.parse_unchecked(())?;
22242255
let max_packet_size: RawInt<LeU32> = sbuf.parse_unchecked(())?;
22252256
let collation = sbuf.parse_unchecked(())?;
2226-
sbuf.parse_unchecked::<Skip<23>>(())?;
2257+
sbuf.parse_unchecked::<Skip<19>>(())?;
2258+
let mariadb_flags: RawConst<LeU32, MariadbCapabilities> = sbuf.parse_unchecked(())?;
22272259

22282260
let user = buf.parse(())?;
22292261
let scramble_buf =
@@ -2260,6 +2292,9 @@ impl<'de> MyDeserialize<'de> for HandshakeResponse<'de> {
22602292
db_name,
22612293
auth_plugin,
22622294
connect_attributes,
2295+
mariadb_ext_capabilities: Const::new(MariadbCapabilities::from_bits_truncate(
2296+
mariadb_flags.0,
2297+
)),
22632298
})
22642299
}
22652300
}
@@ -2269,7 +2304,8 @@ impl MySerialize for HandshakeResponse<'_> {
22692304
self.capabilities.serialize(&mut *buf);
22702305
self.max_packet_size.serialize(&mut *buf);
22712306
self.collation.serialize(&mut *buf);
2272-
buf.put_slice(&[0; 23]);
2307+
buf.put_slice(&[0; 19]);
2308+
self.mariadb_ext_capabilities.serialize(&mut *buf);
22732309
self.user.serialize(&mut *buf);
22742310
self.scramble_buf.serialize(&mut *buf);
22752311

@@ -4085,6 +4121,71 @@ mod test {
40854121
assert_eq!(expected, actual);
40864122
}
40874123

4124+
#[test]
4125+
fn should_parse_handshake_packet_with_mariadb_ext_capabilities() {
4126+
const HSP: &[u8] = b"\x0a5.5.5-11.4.7-MariaDB-log\x00\x0b\x00\
4127+
\x00\x00\x64\x76\x48\x40\x49\x2d\x43\x4a\x00\xff\xf7\x08\x02\x00\
4128+
\x00\x00\x00\x00\x00\x00\x00\x00\x00\x10\x00\x00\x00\x2a\x34\x64\
4129+
\x7c\x63\x5a\x77\x6b\x34\x5e\x5d\x3a\x00";
4130+
4131+
let hsp = HandshakePacket::deserialize((), &mut ParseBuf(HSP)).unwrap();
4132+
assert_eq!(hsp.protocol_version(), 0x0a);
4133+
assert_eq!(hsp.server_version_str(), "5.5.5-11.4.7-MariaDB-log");
4134+
assert_eq!(hsp.server_version_parsed(), Some((5, 5, 5)));
4135+
assert_eq!(hsp.maria_db_server_version_parsed(), Some((11, 4, 7)));
4136+
assert_eq!(hsp.connection_id(), 0x0b);
4137+
assert_eq!(hsp.scramble_1_ref(), b"dvH@I-CJ");
4138+
assert_eq!(
4139+
hsp.capabilities(),
4140+
CapabilityFlags::from_bits_truncate(0xf7ff)
4141+
);
4142+
assert_eq!(hsp.default_collation(), 0x08);
4143+
assert_eq!(hsp.status_flags(), StatusFlags::from_bits_truncate(0x0002));
4144+
assert_eq!(hsp.scramble_2_ref(), Some(&b"*4d|cZwk4^]:\x00"[..]));
4145+
assert_eq!(hsp.auth_plugin_name_ref(), None);
4146+
assert_eq!(
4147+
hsp.mariadb_ext_capabilities(),
4148+
MariadbCapabilities::MARIADB_CLIENT_CACHE_METADATA
4149+
);
4150+
let mut output = Vec::new();
4151+
hsp.serialize(&mut output);
4152+
assert_eq!(&output, HSP);
4153+
}
4154+
4155+
#[test]
4156+
fn should_build_handshake_response_with_mariadb_capabilities() {
4157+
let flags_without_db_name = CapabilityFlags::from_bits_truncate(0x81aea205);
4158+
let response = HandshakeResponse::new(
4159+
Some(&[][..]),
4160+
(5u16, 5, 5),
4161+
Some(&b"root"[..]),
4162+
None::<&'static [u8]>,
4163+
Some(AuthPlugin::MysqlNativePassword),
4164+
flags_without_db_name,
4165+
None,
4166+
1_u32.to_be(),
4167+
)
4168+
.with_mariadb_ext_capabilities(MariadbCapabilities::MARIADB_CLIENT_CACHE_METADATA);
4169+
let mut actual = Vec::new();
4170+
response.serialize(&mut actual);
4171+
4172+
let expected: Vec<u8> = [
4173+
0x05, 0xa2, 0xae, 0x81, // client capabilities
4174+
0x00, 0x00, 0x00, 0x01, // max packet
4175+
0x2d, // charset
4176+
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
4177+
0x00, 0x00, 0x00, 0x00, 0x00, // reserved
4178+
0x10, 0x00, 0x00, 0x00, // mariadb capabilities
4179+
0x72, 0x6f, 0x6f, 0x74, 0x00, // username=root
4180+
0x00, // blank scramble
4181+
0x6d, 0x79, 0x73, 0x71, 0x6c, 0x5f, 0x6e, 0x61, 0x74, 0x69, 0x76, 0x65, 0x5f, 0x70,
4182+
0x61, 0x73, 0x73, 0x77, 0x6f, 0x72, 0x64, 0x00, // mysql_native_password
4183+
]
4184+
.to_vec();
4185+
4186+
assert_eq!(expected, actual);
4187+
}
4188+
40884189
#[test]
40894190
fn parse_str_to_sid() {
40904191
let input = "3E11FA47-71CA-11E1-9E33-C80AA9429562:23";

0 commit comments

Comments
 (0)