Skip to content

Commit 94b9bb6

Browse files
GH-46574: [C++][FlightRPC] ODBC Driver Connectivity support (#47971)
### Rationale for this change Add ODBC driver connection support for the driver, so the driver can connect to servers. ### What changes are included in this PR? Implementation of unicode APIs: - SQLDriverConnect - SQLDisconnect - SQLConnect - DSN Window - Test refactoring ### Are these changes tested? Tested locally on MSVC Windows. ### Are there any user-facing changes? N/A * GitHub Issue: #46574 Lead-authored-by: Alina (Xi) Li <[email protected]> Co-authored-by: justing-bq <[email protected]> Co-authored-by: justing-bq <[email protected]> Signed-off-by: David Li <[email protected]>
1 parent d1e6d6b commit 94b9bb6

23 files changed

+877
-371
lines changed

.github/workflows/cpp.yml

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -310,7 +310,7 @@ jobs:
310310
ARROW_DATASET: ON
311311
ARROW_FLIGHT: ON
312312
ARROW_FLIGHT_SQL: ON
313-
ARROW_FLIGHT_SQL_ODBC: ON
313+
ARROW_FLIGHT_SQL_ODBC: OFF
314314
ARROW_GANDIVA: ON
315315
ARROW_GCS: ON
316316
ARROW_HDFS: OFF
@@ -389,10 +389,6 @@ jobs:
389389
PIPX_BASE_PYTHON: ${{ steps.python-install.outputs.python-path }}
390390
run: |
391391
ci/scripts/install_gcs_testbench.sh default
392-
- name: Register Flight SQL ODBC Driver
393-
shell: cmd
394-
run: |
395-
call "cpp\src\arrow\flight\sql\odbc\tests\install_odbc.cmd" ${{ github.workspace }}\build\cpp\%ARROW_BUILD_TYPE%\libarrow_flight_sql_odbc.dll
396392
- name: Test
397393
shell: msys2 {0}
398394
run: |

cpp/src/arrow/flight/sql/odbc/odbc.def

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,7 @@
1717

1818
LIBRARY arrow_flight_sql_odbc
1919
EXPORTS
20-
; GH-46574 TODO enable DSN window
21-
; ConfigDSNW
20+
ConfigDSNW
2221
SQLAllocConnect
2322
SQLAllocEnv
2423
SQLAllocHandle

cpp/src/arrow/flight/sql/odbc/odbc_api.cc

Lines changed: 147 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,11 @@
3131
#include "arrow/flight/sql/odbc/odbc_impl/spi/connection.h"
3232
#include "arrow/util/logging.h"
3333

34+
#if defined _WIN32
35+
// For displaying DSN Window
36+
# include "arrow/flight/sql/odbc/odbc_impl/system_dsn.h"
37+
#endif // defined(_WIN32)
38+
3439
namespace arrow::flight::sql::odbc {
3540
SQLRETURN SQLAllocHandle(SQLSMALLINT type, SQLHANDLE parent, SQLHANDLE* result) {
3641
ARROW_LOG(DEBUG) << "SQLAllocHandle called with type: " << type
@@ -718,8 +723,30 @@ SQLRETURN SQLSetConnectAttr(SQLHDBC conn, SQLINTEGER attr, SQLPOINTER value_ptr,
718723
ARROW_LOG(DEBUG) << "SQLSetConnectAttrW called with conn: " << conn
719724
<< ", attr: " << attr << ", value_ptr: " << value_ptr
720725
<< ", value_len: " << value_len;
721-
// GH-47708 TODO: Implement SQLSetConnectAttr
722-
return SQL_INVALID_HANDLE;
726+
// GH-47708 TODO: Add tests for SQLSetConnectAttr
727+
using ODBC::ODBCConnection;
728+
729+
return ODBCConnection::ExecuteWithDiagnostics(conn, SQL_ERROR, [=]() {
730+
const bool is_unicode = true;
731+
ODBCConnection* connection = reinterpret_cast<ODBCConnection*>(conn);
732+
connection->SetConnectAttr(attr, value_ptr, value_len, is_unicode);
733+
return SQL_SUCCESS;
734+
});
735+
}
736+
737+
// Load properties from the given DSN. The properties loaded do _not_ overwrite existing
738+
// entries in the properties.
739+
void LoadPropertiesFromDSN(const std::string& dsn,
740+
Connection::ConnPropertyMap& properties) {
741+
arrow::flight::sql::odbc::config::Configuration config;
742+
config.LoadDsn(dsn);
743+
Connection::ConnPropertyMap dsn_properties = config.GetProperties();
744+
for (auto& [key, value] : dsn_properties) {
745+
auto prop_iter = properties.find(key);
746+
if (prop_iter == properties.end()) {
747+
properties.emplace(std::make_pair(std::move(key), std::move(value)));
748+
}
749+
}
723750
}
724751

725752
SQLRETURN SQLDriverConnect(SQLHDBC conn, SQLHWND window_handle,
@@ -740,13 +767,73 @@ SQLRETURN SQLDriverConnect(SQLHDBC conn, SQLHWND window_handle,
740767
<< out_connection_string_buffer_len << ", out_connection_string_len: "
741768
<< static_cast<const void*>(out_connection_string_len)
742769
<< ", driver_completion: " << driver_completion;
770+
743771
// GH-46449 TODO: Implement FILEDSN and SAVEFILE keywords according to the spec
744772

745773
// GH-46560 TODO: Copy connection string properly in SQLDriverConnect according to the
746774
// spec
747775

748-
// GH-46574 TODO: Implement SQLDriverConnect
749-
return SQL_INVALID_HANDLE;
776+
using ODBC::ODBCConnection;
777+
778+
return ODBCConnection::ExecuteWithDiagnostics(conn, SQL_ERROR, [=]() {
779+
ODBCConnection* connection = reinterpret_cast<ODBCConnection*>(conn);
780+
std::string connection_string =
781+
ODBC::SqlWcharToString(in_connection_string, in_connection_string_len);
782+
Connection::ConnPropertyMap properties;
783+
std::string dsn_value = "";
784+
std::optional<std::string> dsn = ODBCConnection::GetDsnIfExists(connection_string);
785+
if (dsn.has_value()) {
786+
dsn_value = dsn.value();
787+
LoadPropertiesFromDSN(dsn_value, properties);
788+
}
789+
ODBCConnection::GetPropertiesFromConnString(connection_string, properties);
790+
791+
std::vector<std::string_view> missing_properties;
792+
793+
// GH-46448 TODO: Implement SQL_DRIVER_COMPLETE_REQUIRED in SQLDriverConnect according
794+
// to the spec
795+
#if defined _WIN32
796+
// Load the DSN window according to driver_completion
797+
if (driver_completion == SQL_DRIVER_PROMPT) {
798+
// Load DSN window before first attempt to connect
799+
arrow::flight::sql::odbc::config::Configuration config;
800+
if (!DisplayConnectionWindow(window_handle, config, properties)) {
801+
return static_cast<SQLRETURN>(SQL_NO_DATA);
802+
}
803+
connection->Connect(dsn_value, properties, missing_properties);
804+
} else if (driver_completion == SQL_DRIVER_COMPLETE ||
805+
driver_completion == SQL_DRIVER_COMPLETE_REQUIRED) {
806+
try {
807+
connection->Connect(dsn_value, properties, missing_properties);
808+
} catch (const DriverException&) {
809+
// If first connection fails due to missing attributes, load
810+
// the DSN window and try to connect again
811+
if (!missing_properties.empty()) {
812+
arrow::flight::sql::odbc::config::Configuration config;
813+
missing_properties.clear();
814+
815+
if (!DisplayConnectionWindow(window_handle, config, properties)) {
816+
return static_cast<SQLRETURN>(SQL_NO_DATA);
817+
}
818+
connection->Connect(dsn_value, properties, missing_properties);
819+
} else {
820+
throw;
821+
}
822+
}
823+
} else {
824+
// Default case: attempt connection without showing DSN window
825+
connection->Connect(dsn_value, properties, missing_properties);
826+
}
827+
#else
828+
// Attempt connection without loading DSN window on macOS/Linux
829+
connection->Connect(dsn, properties, missing_properties);
830+
#endif
831+
// Copy connection string to out_connection_string after connection attempt
832+
return ODBC::GetStringAttribute(true, connection_string, false, out_connection_string,
833+
out_connection_string_buffer_len,
834+
out_connection_string_len,
835+
connection->GetDiagnostics());
836+
});
750837
}
751838

752839
SQLRETURN SQLConnect(SQLHDBC conn, SQLWCHAR* dsn_name, SQLSMALLINT dsn_name_len,
@@ -759,14 +846,48 @@ SQLRETURN SQLConnect(SQLHDBC conn, SQLWCHAR* dsn_name, SQLSMALLINT dsn_name_len,
759846
<< ", user_name_len: " << user_name_len
760847
<< ", password: " << static_cast<const void*>(password)
761848
<< ", password_len: " << password_len;
762-
// GH-46574 TODO: Implement SQLConnect
763-
return SQL_INVALID_HANDLE;
849+
850+
using ODBC::ODBCConnection;
851+
852+
using ODBC::SqlWcharToString;
853+
854+
return ODBCConnection::ExecuteWithDiagnostics(conn, SQL_ERROR, [=]() {
855+
ODBCConnection* connection = reinterpret_cast<ODBCConnection*>(conn);
856+
std::string dsn = SqlWcharToString(dsn_name, dsn_name_len);
857+
858+
Configuration config;
859+
config.LoadDsn(dsn);
860+
861+
if (user_name) {
862+
std::string uid = SqlWcharToString(user_name, user_name_len);
863+
config.Emplace(FlightSqlConnection::UID, std::move(uid));
864+
}
865+
866+
if (password) {
867+
std::string pwd = SqlWcharToString(password, password_len);
868+
config.Emplace(FlightSqlConnection::PWD, std::move(pwd));
869+
}
870+
871+
std::vector<std::string_view> missing_properties;
872+
873+
connection->Connect(dsn, config.GetProperties(), missing_properties);
874+
875+
return SQL_SUCCESS;
876+
});
764877
}
765878

766879
SQLRETURN SQLDisconnect(SQLHDBC conn) {
767880
ARROW_LOG(DEBUG) << "SQLDisconnect called with conn: " << conn;
768-
// GH-46574 TODO: Implement SQLDisconnect
769-
return SQL_INVALID_HANDLE;
881+
882+
using ODBC::ODBCConnection;
883+
884+
return ODBCConnection::ExecuteWithDiagnostics(conn, SQL_ERROR, [=]() {
885+
ODBCConnection* connection = reinterpret_cast<ODBCConnection*>(conn);
886+
887+
connection->Disconnect();
888+
889+
return SQL_SUCCESS;
890+
});
770891
}
771892

772893
SQLRETURN SQLGetInfo(SQLHDBC conn, SQLUSMALLINT info_type, SQLPOINTER info_value_ptr,
@@ -776,8 +897,24 @@ SQLRETURN SQLGetInfo(SQLHDBC conn, SQLUSMALLINT info_type, SQLPOINTER info_value
776897
<< ", info_value_ptr: " << info_value_ptr << ", buf_len: " << buf_len
777898
<< ", string_length_ptr: "
778899
<< static_cast<const void*>(string_length_ptr);
779-
// GH-47709 TODO: Implement SQLGetInfo
780-
return SQL_INVALID_HANDLE;
900+
901+
// GH-47709 TODO: Update SQLGetInfo implementation and add tests for SQLGetInfo
902+
using ODBC::ODBCConnection;
903+
904+
return ODBCConnection::ExecuteWithDiagnostics(conn, SQL_ERROR, [=]() {
905+
ODBCConnection* connection = reinterpret_cast<ODBCConnection*>(conn);
906+
907+
// Set character type to be Unicode by default
908+
const bool is_unicode = true;
909+
910+
if (!info_value_ptr && !string_length_ptr) {
911+
return static_cast<SQLRETURN>(SQL_ERROR);
912+
}
913+
914+
connection->GetInfo(info_type, info_value_ptr, buf_len, string_length_ptr,
915+
is_unicode);
916+
return static_cast<SQLRETURN>(SQL_SUCCESS);
917+
});
781918
}
782919

783920
SQLRETURN SQLGetStmtAttr(SQLHSTMT stmt, SQLINTEGER attribute, SQLPOINTER value_ptr,

cpp/src/arrow/flight/sql/odbc/odbc_impl/CMakeLists.txt

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -124,7 +124,9 @@ if(WIN32)
124124
ui/dsn_configuration_window.h
125125
ui/window.cc
126126
ui/window.h
127-
system_dsn.cc)
127+
win_system_dsn.cc
128+
system_dsn.cc
129+
system_dsn.h)
128130
endif()
129131

130132
target_link_libraries(arrow_odbc_spi_impl PUBLIC arrow_flight_sql_shared

cpp/src/arrow/flight/sql/odbc/odbc_impl/attribute_utils.h

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -17,17 +17,17 @@
1717

1818
#pragma once
1919

20-
#include <arrow/flight/sql/odbc/odbc_impl/diagnostics.h>
21-
#include <arrow/flight/sql/odbc/odbc_impl/exceptions.h>
22-
#include <arrow/flight/sql/odbc/odbc_impl/platform.h>
2320
#include <sql.h>
2421
#include <sqlext.h>
2522
#include <algorithm>
2623
#include <cstring>
2724
#include <memory>
25+
#include "arrow/flight/sql/odbc/odbc_impl/diagnostics.h"
26+
#include "arrow/flight/sql/odbc/odbc_impl/encoding_utils.h"
27+
#include "arrow/flight/sql/odbc/odbc_impl/exceptions.h"
28+
#include "arrow/flight/sql/odbc/odbc_impl/platform.h"
2829

29-
#include <arrow/flight/sql/odbc/odbc_impl/encoding_utils.h>
30-
30+
// GH-48083 TODO: replace `namespace ODBC` with `namespace arrow::flight::sql::odbc`
3131
namespace ODBC {
3232

3333
using arrow::flight::sql::odbc::Diagnostics;
@@ -48,12 +48,12 @@ inline void GetAttribute(T attribute_value, SQLPOINTER output, O output_size,
4848
}
4949

5050
template <typename O>
51-
inline SQLRETURN GetAttributeUTF8(const std::string& attribute_value, SQLPOINTER output,
51+
inline SQLRETURN GetAttributeUTF8(std::string_view attribute_value, SQLPOINTER output,
5252
O output_size, O* output_len_ptr) {
5353
if (output) {
5454
size_t output_len_before_null =
5555
std::min(static_cast<O>(attribute_value.size()), static_cast<O>(output_size - 1));
56-
memcpy(output, attribute_value.c_str(), output_len_before_null);
56+
std::memcpy(output, attribute_value.data(), output_len_before_null);
5757
reinterpret_cast<char*>(output)[output_len_before_null] = '\0';
5858
}
5959

@@ -68,7 +68,7 @@ inline SQLRETURN GetAttributeUTF8(const std::string& attribute_value, SQLPOINTER
6868
}
6969

7070
template <typename O>
71-
inline SQLRETURN GetAttributeUTF8(const std::string& attribute_value, SQLPOINTER output,
71+
inline SQLRETURN GetAttributeUTF8(std::string_view attribute_value, SQLPOINTER output,
7272
O output_size, O* output_len_ptr,
7373
Diagnostics& diagnostics) {
7474
SQLRETURN result =
@@ -80,7 +80,7 @@ inline SQLRETURN GetAttributeUTF8(const std::string& attribute_value, SQLPOINTER
8080
}
8181

8282
template <typename O>
83-
inline SQLRETURN GetAttributeSQLWCHAR(const std::string& attribute_value,
83+
inline SQLRETURN GetAttributeSQLWCHAR(std::string_view attribute_value,
8484
bool is_length_in_bytes, SQLPOINTER output,
8585
O output_size, O* output_len_ptr) {
8686
size_t length = ConvertToSqlWChar(
@@ -104,7 +104,7 @@ inline SQLRETURN GetAttributeSQLWCHAR(const std::string& attribute_value,
104104
}
105105

106106
template <typename O>
107-
inline SQLRETURN GetAttributeSQLWCHAR(const std::string& attribute_value,
107+
inline SQLRETURN GetAttributeSQLWCHAR(std::string_view attribute_value,
108108
bool is_length_in_bytes, SQLPOINTER output,
109109
O output_size, O* output_len_ptr,
110110
Diagnostics& diagnostics) {
@@ -117,7 +117,7 @@ inline SQLRETURN GetAttributeSQLWCHAR(const std::string& attribute_value,
117117
}
118118

119119
template <typename O>
120-
inline SQLRETURN GetStringAttribute(bool is_unicode, const std::string& attribute_value,
120+
inline SQLRETURN GetStringAttribute(bool is_unicode, std::string_view attribute_value,
121121
bool is_length_in_bytes, SQLPOINTER output,
122122
O output_size, O* output_len_ptr,
123123
Diagnostics& diagnostics) {

cpp/src/arrow/flight/sql/odbc/odbc_impl/config/configuration.cc

Lines changed: 14 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ static const char DEFAULT_USE_CERT_STORE[] = TRUE_STR;
3535
static const char DEFAULT_DISABLE_CERT_VERIFICATION[] = FALSE_STR;
3636

3737
namespace {
38-
std::string ReadDsnString(const std::string& dsn, const std::string_view& key,
38+
std::string ReadDsnString(const std::string& dsn, std::string_view key,
3939
const std::string& dflt = "") {
4040
CONVERT_WIDE_STR(const std::wstring wdsn, dsn);
4141
CONVERT_WIDE_STR(const std::wstring wkey, key);
@@ -150,11 +150,11 @@ void Configuration::LoadDsn(const std::string& dsn) {
150150

151151
void Configuration::Clear() { this->properties_.clear(); }
152152

153-
bool Configuration::IsSet(const std::string_view& key) const {
153+
bool Configuration::IsSet(std::string_view key) const {
154154
return 0 != this->properties_.count(key);
155155
}
156156

157-
const std::string& Configuration::Get(const std::string_view& key) const {
157+
const std::string& Configuration::Get(std::string_view key) const {
158158
const auto itr = this->properties_.find(key);
159159
if (itr == this->properties_.cend()) {
160160
static const std::string empty("");
@@ -163,15 +163,22 @@ const std::string& Configuration::Get(const std::string_view& key) const {
163163
return itr->second;
164164
}
165165

166-
void Configuration::Set(const std::string_view& key, const std::wstring& wvalue) {
166+
void Configuration::Set(std::string_view key, const std::wstring& wvalue) {
167167
CONVERT_UTF8_STR(const std::string value, wvalue);
168168
Set(key, value);
169169
}
170170

171-
void Configuration::Set(const std::string_view& key, const std::string& value) {
171+
void Configuration::Set(std::string_view key, const std::string& value) {
172172
const std::string copy = boost::trim_copy(value);
173173
if (!copy.empty()) {
174-
this->properties_[key] = value;
174+
this->properties_[std::string(key)] = value;
175+
}
176+
}
177+
178+
void Configuration::Emplace(std::string_view key, std::string&& value) {
179+
const std::string copy = boost::trim_copy(value);
180+
if (!copy.empty()) {
181+
this->properties_.emplace(std::make_pair(key, std::move(value)));
175182
}
176183
}
177184

@@ -182,7 +189,7 @@ const Connection::ConnPropertyMap& Configuration::GetProperties() const {
182189
std::vector<std::string_view> Configuration::GetCustomKeys() const {
183190
Connection::ConnPropertyMap copy_props(properties_);
184191
for (auto& key : FlightSqlConnection::ALL_KEYS) {
185-
copy_props.erase(key);
192+
copy_props.erase(std::string(key));
186193
}
187194
std::vector<std::string_view> keys;
188195
boost::copy(copy_props | boost::adaptors::map_keys, std::back_inserter(keys));

0 commit comments

Comments
 (0)