@@ -22,7 +22,8 @@ use crate::scramble::create_response_for_ed25519;
2222use crate :: {
2323 constants:: {
2424 CapabilityFlags , ColumnFlags , ColumnType , Command , CursorType , MAX_PAYLOAD_LEN ,
25- SessionStateType , StatusFlags , StmtExecuteParamFlags , StmtExecuteParamsFlags ,
25+ MariadbCapabilities , SessionStateType , StatusFlags , StmtExecuteParamFlags ,
26+ StmtExecuteParamsFlags ,
2627 } ,
2728 io:: { BufMutExt , ParseBuf } ,
2829 misc:: {
@@ -1594,7 +1595,9 @@ pub struct HandshakePacket<'a> {
15941595 // upper 16 bytes
15951596 capabilities_2 : Const < CapabilityFlags , LeU32UpperHalf > ,
15961597 auth_plugin_data_len : RawInt < u8 > ,
1597- __reserved : Skip < 10 > ,
1598+ __reserved : Skip < 6 > ,
1599+ // MariaDB uses last 4 reserved bytes to pass its extended capabilities.
1600+ mariadb_ext_capabilities : Const < MariadbCapabilities , LeU32 > ,
15981601 scramble_2 : Option < RawBytes < ' a , BareBytes < { ( u8:: MAX as usize ) - 8 } > > > ,
15991602 auth_plugin_name : Option < RawBytes < ' a , NullBytes > > ,
16001603}
@@ -1618,6 +1621,9 @@ impl<'de> MyDeserialize<'de> for HandshakePacket<'de> {
16181621 let capabilities_2: RawConst < LeU32UpperHalf , CapabilityFlags > = sbuf. parse_unchecked ( ( ) ) ?;
16191622 let auth_plugin_data_len: RawInt < u8 > = sbuf. parse_unchecked ( ( ) ) ?;
16201623 let __reserved = sbuf. parse_unchecked ( ( ) ) ?;
1624+ // If the server is MariaDB, it will pass its extended capabilities
1625+ // in the last 4 reserved bytes.
1626+ let mariadb_capabiities: RawConst < LeU32 , MariadbCapabilities > = sbuf. parse_unchecked ( ( ) ) ?;
16211627 let mut scramble_2 = None ;
16221628 if capabilities_1. 0 & CapabilityFlags :: CLIENT_SECURE_CONNECTION . bits ( ) > 0 {
16231629 let len = max ( 13 , auth_plugin_data_len. 0 as i8 - 8 ) as usize ;
@@ -1644,6 +1650,9 @@ impl<'de> MyDeserialize<'de> for HandshakePacket<'de> {
16441650 capabilities_2 : Const :: new ( CapabilityFlags :: from_bits_truncate ( capabilities_2. 0 ) ) ,
16451651 auth_plugin_data_len,
16461652 __reserved,
1653+ mariadb_ext_capabilities : Const :: new ( MariadbCapabilities :: from_bits_truncate (
1654+ mariadb_capabiities. 0 ,
1655+ ) ) ,
16471656 scramble_2,
16481657 auth_plugin_name,
16491658 } )
@@ -1676,7 +1685,8 @@ impl MySerialize for HandshakePacket<'_> {
16761685 buf. put_u8 ( 0 ) ;
16771686 }
16781687
1679- buf. put_slice ( & [ 0_u8 ; 10 ] [ ..] ) ;
1688+ self . __reserved . serialize ( & mut * buf) ;
1689+ self . mariadb_ext_capabilities . serialize ( & mut * buf) ;
16801690
16811691 // Assume that the packet is well formed:
16821692 // * the CLIENT_SECURE_CONNECTION is set.
@@ -1704,6 +1714,7 @@ impl<'a> HandshakePacket<'a> {
17041714 default_collation : u8 ,
17051715 status_flags : StatusFlags ,
17061716 auth_plugin_name : Option < impl Into < Cow < ' a , [ u8 ] > > > ,
1717+ mariadb_capabilities : MariadbCapabilities ,
17071718 ) -> Self {
17081719 // Safety:
17091720 // * capabilities are given as a valid CapabilityFlags instance
@@ -1732,6 +1743,7 @@ impl<'a> HandshakePacket<'a> {
17321743 . unwrap_or_default ( ) ,
17331744 ) ,
17341745 __reserved : Skip ,
1746+ mariadb_ext_capabilities : Const :: new ( mariadb_capabilities) ,
17351747 scramble_2,
17361748 auth_plugin_name : auth_plugin_name. map ( RawBytes :: new) ,
17371749 }
@@ -1750,6 +1762,7 @@ impl<'a> HandshakePacket<'a> {
17501762 capabilities_2 : self . capabilities_2 ,
17511763 auth_plugin_data_len : self . auth_plugin_data_len ,
17521764 __reserved : self . __reserved ,
1765+ mariadb_ext_capabilities : self . mariadb_ext_capabilities ,
17531766 scramble_2 : self . scramble_2 . map ( |x| x. into_owned ( ) ) ,
17541767 auth_plugin_name : self . auth_plugin_name . map ( RawBytes :: into_owned) ,
17551768 }
@@ -1834,6 +1847,10 @@ impl<'a> HandshakePacket<'a> {
18341847 self . capabilities_1 . 0 | self . capabilities_2 . 0
18351848 }
18361849
1850+ /// Value of MariaDB specific server capabilities
1851+ pub fn mariadb_ext_capabilities ( & self ) -> MariadbCapabilities {
1852+ self . mariadb_ext_capabilities . 0
1853+ }
18371854 /// Value of the default_collation field of an initial handshake packet.
18381855 pub fn default_collation ( & self ) -> u8 {
18391856 self . default_collation . 0
@@ -2105,6 +2122,7 @@ pub struct HandshakeResponse<'a> {
21052122 db_name : Option < RawBytes < ' a , NullBytes > > ,
21062123 auth_plugin : Option < AuthPlugin < ' a > > ,
21072124 connect_attributes : Option < HashMap < RawBytes < ' a , LenEnc > , RawBytes < ' a , LenEnc > > > ,
2125+ mariadb_ext_capabilities : Const < MariadbCapabilities , LeU32 > ,
21082126}
21092127
21102128impl < ' a > HandshakeResponse < ' a > {
@@ -2170,13 +2188,26 @@ impl<'a> HandshakeResponse<'a> {
21702188 . collect ( )
21712189 } ) ,
21722190 max_packet_size : RawInt :: new ( max_packet_size) ,
2191+ mariadb_ext_capabilities : Const :: new ( MariadbCapabilities :: empty ( ) ) ,
21732192 }
21742193 }
21752194
2195+ pub fn with_mariadb_ext_capabilities (
2196+ mut self ,
2197+ mariadb_ext_capabilities : MariadbCapabilities ,
2198+ ) -> Self {
2199+ self . mariadb_ext_capabilities = Const :: new ( mariadb_ext_capabilities) ;
2200+ self
2201+ }
2202+
21762203 pub fn capabilities ( & self ) -> CapabilityFlags {
21772204 self . capabilities . 0
21782205 }
21792206
2207+ pub fn mariadb_ext_capabilities ( & self ) -> MariadbCapabilities {
2208+ self . mariadb_ext_capabilities . 0
2209+ }
2210+
21802211 pub fn collation ( & self ) -> u8 {
21812212 self . collation . 0
21822213 }
@@ -2223,7 +2254,8 @@ impl<'de> MyDeserialize<'de> for HandshakeResponse<'de> {
22232254 let client_flags: RawConst < LeU32 , CapabilityFlags > = sbuf. parse_unchecked ( ( ) ) ?;
22242255 let max_packet_size: RawInt < LeU32 > = sbuf. parse_unchecked ( ( ) ) ?;
22252256 let collation = sbuf. parse_unchecked ( ( ) ) ?;
2226- sbuf. parse_unchecked :: < Skip < 23 > > ( ( ) ) ?;
2257+ sbuf. parse_unchecked :: < Skip < 19 > > ( ( ) ) ?;
2258+ let mariadb_flags: RawConst < LeU32 , MariadbCapabilities > = sbuf. parse_unchecked ( ( ) ) ?;
22272259
22282260 let user = buf. parse ( ( ) ) ?;
22292261 let scramble_buf =
@@ -2260,6 +2292,9 @@ impl<'de> MyDeserialize<'de> for HandshakeResponse<'de> {
22602292 db_name,
22612293 auth_plugin,
22622294 connect_attributes,
2295+ mariadb_ext_capabilities : Const :: new ( MariadbCapabilities :: from_bits_truncate (
2296+ mariadb_flags. 0 ,
2297+ ) ) ,
22632298 } )
22642299 }
22652300}
@@ -2269,7 +2304,8 @@ impl MySerialize for HandshakeResponse<'_> {
22692304 self . capabilities . serialize ( & mut * buf) ;
22702305 self . max_packet_size . serialize ( & mut * buf) ;
22712306 self . collation . serialize ( & mut * buf) ;
2272- buf. put_slice ( & [ 0 ; 23 ] ) ;
2307+ buf. put_slice ( & [ 0 ; 19 ] ) ;
2308+ self . mariadb_ext_capabilities . serialize ( & mut * buf) ;
22732309 self . user . serialize ( & mut * buf) ;
22742310 self . scramble_buf . serialize ( & mut * buf) ;
22752311
@@ -4085,6 +4121,71 @@ mod test {
40854121 assert_eq ! ( expected, actual) ;
40864122 }
40874123
4124+ #[ test]
4125+ fn should_parse_handshake_packet_with_mariadb_ext_capabilities ( ) {
4126+ const HSP : & [ u8 ] = b"\x0a 5.5.5-11.4.7-MariaDB-log\x00 \x0b \x00 \
4127+ \x00 \x00 \x64 \x76 \x48 \x40 \x49 \x2d \x43 \x4a \x00 \xff \xf7 \x08 \x02 \x00 \
4128+ \x00 \x00 \x00 \x00 \x00 \x00 \x00 \x00 \x00 \x10 \x00 \x00 \x00 \x2a \x34 \x64 \
4129+ \x7c \x63 \x5a \x77 \x6b \x34 \x5e \x5d \x3a \x00 ";
4130+
4131+ let hsp = HandshakePacket :: deserialize ( ( ) , & mut ParseBuf ( HSP ) ) . unwrap ( ) ;
4132+ assert_eq ! ( hsp. protocol_version( ) , 0x0a ) ;
4133+ assert_eq ! ( hsp. server_version_str( ) , "5.5.5-11.4.7-MariaDB-log" ) ;
4134+ assert_eq ! ( hsp. server_version_parsed( ) , Some ( ( 5 , 5 , 5 ) ) ) ;
4135+ assert_eq ! ( hsp. maria_db_server_version_parsed( ) , Some ( ( 11 , 4 , 7 ) ) ) ;
4136+ assert_eq ! ( hsp. connection_id( ) , 0x0b ) ;
4137+ assert_eq ! ( hsp. scramble_1_ref( ) , b"dvH@I-CJ" ) ;
4138+ assert_eq ! (
4139+ hsp. capabilities( ) ,
4140+ CapabilityFlags :: from_bits_truncate( 0xf7ff )
4141+ ) ;
4142+ assert_eq ! ( hsp. default_collation( ) , 0x08 ) ;
4143+ assert_eq ! ( hsp. status_flags( ) , StatusFlags :: from_bits_truncate( 0x0002 ) ) ;
4144+ assert_eq ! ( hsp. scramble_2_ref( ) , Some ( & b"*4d|cZwk4^]:\x00 " [ ..] ) ) ;
4145+ assert_eq ! ( hsp. auth_plugin_name_ref( ) , None ) ;
4146+ assert_eq ! (
4147+ hsp. mariadb_ext_capabilities( ) ,
4148+ MariadbCapabilities :: MARIADB_CLIENT_CACHE_METADATA
4149+ ) ;
4150+ let mut output = Vec :: new ( ) ;
4151+ hsp. serialize ( & mut output) ;
4152+ assert_eq ! ( & output, HSP ) ;
4153+ }
4154+
4155+ #[ test]
4156+ fn should_build_handshake_response_with_mariadb_capabilities ( ) {
4157+ let flags_without_db_name = CapabilityFlags :: from_bits_truncate ( 0x81aea205 ) ;
4158+ let response = HandshakeResponse :: new (
4159+ Some ( & [ ] [ ..] ) ,
4160+ ( 5u16 , 5 , 5 ) ,
4161+ Some ( & b"root" [ ..] ) ,
4162+ None :: < & ' static [ u8 ] > ,
4163+ Some ( AuthPlugin :: MysqlNativePassword ) ,
4164+ flags_without_db_name,
4165+ None ,
4166+ 1_u32 . to_be ( ) ,
4167+ )
4168+ . with_mariadb_ext_capabilities ( MariadbCapabilities :: MARIADB_CLIENT_CACHE_METADATA ) ;
4169+ let mut actual = Vec :: new ( ) ;
4170+ response. serialize ( & mut actual) ;
4171+
4172+ let expected: Vec < u8 > = [
4173+ 0x05 , 0xa2 , 0xae , 0x81 , // client capabilities
4174+ 0x00 , 0x00 , 0x00 , 0x01 , // max packet
4175+ 0x2d , // charset
4176+ 0x00 , 0x00 , 0x00 , 0x00 , 0x00 , 0x00 , 0x00 , 0x00 , 0x00 , 0x00 , 0x00 , 0x00 , 0x00 , 0x00 ,
4177+ 0x00 , 0x00 , 0x00 , 0x00 , 0x00 , // reserved
4178+ 0x10 , 0x00 , 0x00 , 0x00 , // mariadb capabilities
4179+ 0x72 , 0x6f , 0x6f , 0x74 , 0x00 , // username=root
4180+ 0x00 , // blank scramble
4181+ 0x6d , 0x79 , 0x73 , 0x71 , 0x6c , 0x5f , 0x6e , 0x61 , 0x74 , 0x69 , 0x76 , 0x65 , 0x5f , 0x70 ,
4182+ 0x61 , 0x73 , 0x73 , 0x77 , 0x6f , 0x72 , 0x64 , 0x00 , // mysql_native_password
4183+ ]
4184+ . to_vec ( ) ;
4185+
4186+ assert_eq ! ( expected, actual) ;
4187+ }
4188+
40884189 #[ test]
40894190 fn parse_str_to_sid ( ) {
40904191 let input = "3E11FA47-71CA-11E1-9E33-C80AA9429562:23" ;
0 commit comments