From 1e132007c192f6e92e7d95c993f5f9d6fcaf9c73 Mon Sep 17 00:00:00 2001 From: David Coe Date: Tue, 19 Nov 2024 13:13:05 -0500 Subject: [PATCH] run pre-commit on .cs files --- csharp/src/Drivers/Apache/ApacheParameters.cs | 58 +- csharp/src/Drivers/Apache/ApacheUtility.cs | 162 +- .../Apache/Hive2/HiveServer2Connection.cs | 376 +-- .../Apache/Hive2/HiveServer2Statement.cs | 356 +-- .../Drivers/Apache/Spark/SparkConnection.cs | 2148 ++++++++--------- .../Apache/Spark/SparkConnectionTest.cs | 596 ++--- .../Drivers/Apache/Spark/StatementTests.cs | 464 ++-- 7 files changed, 2080 insertions(+), 2080 deletions(-) diff --git a/csharp/src/Drivers/Apache/ApacheParameters.cs b/csharp/src/Drivers/Apache/ApacheParameters.cs index 414f255ddf..d483bdb161 100644 --- a/csharp/src/Drivers/Apache/ApacheParameters.cs +++ b/csharp/src/Drivers/Apache/ApacheParameters.cs @@ -1,29 +1,29 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -namespace Apache.Arrow.Adbc.Drivers.Apache -{ - /// - /// Options common to all Apache drivers. - /// - public class ApacheParameters - { - public const string PollTimeMilliseconds = "adbc.apache.statement.polltime_ms"; - public const string BatchSize = "adbc.apache.statement.batch_size"; - public const string QueryTimeoutSeconds = "adbc.apache.statement.query_timeout_s"; - } - } +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +namespace Apache.Arrow.Adbc.Drivers.Apache +{ + /// + /// Options common to all Apache drivers. 
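+    /// <para>
+    /// A minimal, hypothetical usage sketch (the values are illustrative, not defaults):
+    /// these keys are consumed by the statement's SetOption, e.g.
+    /// statement.SetOption(ApacheParameters.PollTimeMilliseconds, "250");
+    /// statement.SetOption(ApacheParameters.BatchSize, "10000");
+    /// statement.SetOption(ApacheParameters.QueryTimeoutSeconds, "120");
+    /// </para>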
+ /// + public class ApacheParameters + { + public const string PollTimeMilliseconds = "adbc.apache.statement.polltime_ms"; + public const string BatchSize = "adbc.apache.statement.batch_size"; + public const string QueryTimeoutSeconds = "adbc.apache.statement.query_timeout_s"; + } + } diff --git a/csharp/src/Drivers/Apache/ApacheUtility.cs b/csharp/src/Drivers/Apache/ApacheUtility.cs index 7c02e137dc..73c42e0f98 100644 --- a/csharp/src/Drivers/Apache/ApacheUtility.cs +++ b/csharp/src/Drivers/Apache/ApacheUtility.cs @@ -1,81 +1,81 @@ -/* -* Licensed to the Apache Software Foundation (ASF) under one or more -* contributor license agreements. See the NOTICE file distributed with -* this work for additional information regarding copyright ownership. -* The ASF licenses this file to You under the Apache License, Version 2.0 -* (the "License"); you may not use this file except in compliance with -* the License. You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. -*/ - -using System; -using System.Collections.Generic; -using System.Linq; -using System.Text; -using System.Threading; -using System.Threading.Tasks; - -namespace Apache.Arrow.Adbc.Drivers.Apache -{ - internal class ApacheUtility - { - internal const int QueryTimeoutSecondsDefault = 60; - - public enum TimeUnit - { - Seconds, - Milliseconds - } - - public static CancellationToken GetCancellationToken(int timeout, TimeUnit timeUnit) - { - TimeSpan span; - - if (timeout == -1 || timeout == int.MaxValue) - { - // the max TimeSpan for CancellationTokenSource is int.MaxValue in milliseconds (not TimeSpan.MaxValue) - // no matter what the unit is - span = TimeSpan.FromMilliseconds(int.MaxValue); - } - else - { - if (timeUnit == TimeUnit.Seconds) - { - span = TimeSpan.FromSeconds(timeout); - } - else - { - span = TimeSpan.FromMilliseconds(timeout); - } - } - - return GetCancellationToken(span); - } - - private static CancellationToken GetCancellationToken(TimeSpan timeSpan) - { - var cts = new CancellationTokenSource(timeSpan); - return cts.Token; - } - - public static bool QueryTimeoutIsValid(string key, string value, out int queryTimeoutSeconds) - { - if (!string.IsNullOrEmpty(value) && int.TryParse(value, out int queryTimeout) && (queryTimeout > 0 || queryTimeout == -1)) - { - queryTimeoutSeconds = queryTimeout; - return true; - } - else - { - throw new ArgumentOutOfRangeException(key, value, $"The value '{value}' for option '{key}' is invalid. Must be a numeric value of -1 (infinite) or greater than zero."); - } - } - } -} +/* +* Licensed to the Apache Software Foundation (ASF) under one or more +* contributor license agreements. See the NOTICE file distributed with +* this work for additional information regarding copyright ownership. +* The ASF licenses this file to You under the Apache License, Version 2.0 +* (the "License"); you may not use this file except in compliance with +* the License. 
You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text; +using System.Threading; +using System.Threading.Tasks; + +namespace Apache.Arrow.Adbc.Drivers.Apache +{ + internal class ApacheUtility + { + internal const int QueryTimeoutSecondsDefault = 60; + + public enum TimeUnit + { + Seconds, + Milliseconds + } + + public static CancellationToken GetCancellationToken(int timeout, TimeUnit timeUnit) + { + TimeSpan span; + + if (timeout == -1 || timeout == int.MaxValue) + { + // the max TimeSpan for CancellationTokenSource is int.MaxValue in milliseconds (not TimeSpan.MaxValue) + // no matter what the unit is + span = TimeSpan.FromMilliseconds(int.MaxValue); + } + else + { + if (timeUnit == TimeUnit.Seconds) + { + span = TimeSpan.FromSeconds(timeout); + } + else + { + span = TimeSpan.FromMilliseconds(timeout); + } + } + + return GetCancellationToken(span); + } + + private static CancellationToken GetCancellationToken(TimeSpan timeSpan) + { + var cts = new CancellationTokenSource(timeSpan); + return cts.Token; + } + + public static bool QueryTimeoutIsValid(string key, string value, out int queryTimeoutSeconds) + { + if (!string.IsNullOrEmpty(value) && int.TryParse(value, out int queryTimeout) && (queryTimeout > 0 || queryTimeout == -1)) + { + queryTimeoutSeconds = queryTimeout; + return true; + } + else + { + throw new ArgumentOutOfRangeException(key, value, $"The value '{value}' for option '{key}' is invalid. Must be a numeric value of -1 (infinite) or greater than zero."); + } + } + } +} diff --git a/csharp/src/Drivers/Apache/Hive2/HiveServer2Connection.cs b/csharp/src/Drivers/Apache/Hive2/HiveServer2Connection.cs index a60ea5f0ec..d3ad45bec5 100644 --- a/csharp/src/Drivers/Apache/Hive2/HiveServer2Connection.cs +++ b/csharp/src/Drivers/Apache/Hive2/HiveServer2Connection.cs @@ -1,188 +1,188 @@ -/* -* Licensed to the Apache Software Foundation (ASF) under one or more -* contributor license agreements. See the NOTICE file distributed with -* this work for additional information regarding copyright ownership. -* The ASF licenses this file to You under the Apache License, Version 2.0 -* (the "License"); you may not use this file except in compliance with -* the License. You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. 
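// A rough consumption sketch for ApacheUtility.GetCancellationToken above, matching the
// call shape used throughout this patch (the 30-second value is illustrative):
//
//     CancellationToken token = ApacheUtility.GetCancellationToken(30, ApacheUtility.TimeUnit.Seconds);
//     TOpenSessionResp session = await client.OpenSession(request, token);
//
// A timeout of -1 (or int.MaxValue) maps to int.MaxValue milliseconds, the largest delay
// a CancellationTokenSource accepts (it cannot take TimeSpan.MaxValue).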
-*/ - -using System; -using System.Collections.Generic; -using System.Threading; -using System.Threading.Tasks; -using Apache.Arrow.Ipc; -using Apache.Hive.Service.Rpc.Thrift; -using Thrift.Protocol; -using Thrift.Transport; - -namespace Apache.Arrow.Adbc.Drivers.Apache.Hive2 -{ - internal abstract class HiveServer2Connection : AdbcConnection - { - internal const long BatchSizeDefault = 50000; - internal const int PollTimeMillisecondsDefault = 500; - private const int ConnectTimeoutMillisecondDefault = 30000; - private TTransport? _transport; - private TCLIService.Client? _client; - private readonly Lazy _vendorVersion; - private readonly Lazy _vendorName; - - internal HiveServer2Connection(IReadOnlyDictionary properties) - { - Properties = properties; - // Note: "LazyThreadSafetyMode.PublicationOnly" is thread-safe initialization where - // the first successful thread sets the value. If an exception is thrown, initialization - // will retry until it successfully returns a value without an exception. - // https://learn.microsoft.com/en-us/dotnet/framework/performance/lazy-initialization#exceptions-in-lazy-objects - _vendorVersion = new Lazy(() => GetInfoTypeStringValue(TGetInfoType.CLI_DBMS_VER), LazyThreadSafetyMode.PublicationOnly); - _vendorName = new Lazy(() => GetInfoTypeStringValue(TGetInfoType.CLI_DBMS_NAME), LazyThreadSafetyMode.PublicationOnly); - - if (properties.TryGetValue(ApacheParameters.QueryTimeoutSeconds, out string? queryTimeoutSecondsSettingValue)) - { - if (ApacheUtility.QueryTimeoutIsValid(ApacheParameters.QueryTimeoutSeconds, queryTimeoutSecondsSettingValue, out int queryTimeoutSeconds)) - { - QueryTimeoutSeconds = queryTimeoutSeconds; - } - } - } - - internal TCLIService.Client Client - { - get { return _client ?? throw new InvalidOperationException("connection not open"); } - } - - internal string VendorVersion => _vendorVersion.Value; - - internal string VendorName => _vendorName.Value; - - protected internal int QueryTimeoutSeconds { get; private set; } = ApacheUtility.QueryTimeoutSecondsDefault; - - internal IReadOnlyDictionary Properties { get; } - - internal async Task OpenAsync() - { - try - { - TTransport transport = await CreateTransportAsync(); - TProtocol protocol = await CreateProtocolAsync(transport); - _transport = protocol.Transport; - _client = new TCLIService.Client(protocol); - TOpenSessionReq request = CreateSessionRequest(); - - CancellationToken timeoutToken = ApacheUtility.GetCancellationToken(ConnectTimeoutMilliseconds, ApacheUtility.TimeUnit.Milliseconds); - TOpenSessionResp? session = await Client.OpenSession(request, timeoutToken); - - // Explicitly check the session status - if (session == null) - { - throw new HiveServer2Exception("Unable to open session. Unknown error."); - } - else if (session.Status.StatusCode != TStatusCode.SUCCESS_STATUS) - { - throw new HiveServer2Exception(session.Status.ErrorMessage) - .SetNativeError(session.Status.ErrorCode) - .SetSqlState(session.Status.SqlState); - } - - SessionHandle = session.SessionHandle; - } - catch (OperationCanceledException ex) - { - throw new TimeoutException("The operation timed out while attempting to open a session.", ex); - } - catch (Exception ex) - { - // Handle other exceptions if necessary - throw new HiveServer2Exception("An unexpected error occurred while opening the session.", ex); - } - } - - internal TSessionHandle? 
SessionHandle { get; private set; } - - protected internal DataTypeConversion DataTypeConversion { get; set; } = DataTypeConversion.None; - - protected internal HiveServer2TlsOption TlsOptions { get; set; } = HiveServer2TlsOption.Empty; - - protected internal int ConnectTimeoutMilliseconds { get; set; } = ConnectTimeoutMillisecondDefault; - - protected abstract Task CreateTransportAsync(); - - protected abstract Task CreateProtocolAsync(TTransport transport); - - protected abstract TOpenSessionReq CreateSessionRequest(); - - internal abstract SchemaParser SchemaParser { get; } - - internal abstract IArrowArrayStream NewReader(T statement, Schema schema) where T : HiveServer2Statement; - - public override IArrowArrayStream GetObjects(GetObjectsDepth depth, string? catalogPattern, string? dbSchemaPattern, string? tableNamePattern, IReadOnlyList? tableTypes, string? columnNamePattern) - { - throw new NotImplementedException(); - } - - public override IArrowArrayStream GetTableTypes() - { - throw new NotImplementedException(); - } - - internal static async Task PollForResponseAsync(TOperationHandle operationHandle, TCLIService.IAsync client, int pollTimeMilliseconds, CancellationToken cancellationToken = default) - { - TGetOperationStatusResp? statusResponse = null; - do - { - if (statusResponse != null) { await Task.Delay(pollTimeMilliseconds); } - TGetOperationStatusReq request = new(operationHandle); - statusResponse = await client.GetOperationStatus(request, cancellationToken); - } while (statusResponse.OperationState == TOperationState.PENDING_STATE || statusResponse.OperationState == TOperationState.RUNNING_STATE); - } - - private string GetInfoTypeStringValue(TGetInfoType infoType) - { - TGetInfoReq req = new() - { - SessionHandle = SessionHandle ?? throw new InvalidOperationException("session not created"), - InfoType = infoType, - }; - - TGetInfoResp getInfoResp = Client.GetInfo(req).Result; - if (getInfoResp.Status.StatusCode == TStatusCode.ERROR_STATUS) - { - throw new HiveServer2Exception(getInfoResp.Status.ErrorMessage) - .SetNativeError(getInfoResp.Status.ErrorCode) - .SetSqlState(getInfoResp.Status.SqlState); - } - - return getInfoResp.InfoValue.StringValue; - } - - public override void Dispose() - { - if (_client != null) - { - TCloseSessionReq r6 = new TCloseSessionReq(SessionHandle); - _client.CloseSession(r6).Wait(); - - _transport?.Close(); - _client.Dispose(); - _transport = null; - _client = null; - } - } - - internal static async Task GetResultSetMetadataAsync(TOperationHandle operationHandle, TCLIService.IAsync client, CancellationToken cancellationToken = default) - { - TGetResultSetMetadataReq request = new(operationHandle); - TGetResultSetMetadataResp response = await client.GetResultSetMetadata(request, cancellationToken); - return response; - } - } -} +/* +* Licensed to the Apache Software Foundation (ASF) under one or more +* contributor license agreements. See the NOTICE file distributed with +* this work for additional information regarding copyright ownership. +* The ASF licenses this file to You under the Apache License, Version 2.0 +* (the "License"); you may not use this file except in compliance with +* the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +using System; +using System.Collections.Generic; +using System.Threading; +using System.Threading.Tasks; +using Apache.Arrow.Ipc; +using Apache.Hive.Service.Rpc.Thrift; +using Thrift.Protocol; +using Thrift.Transport; + +namespace Apache.Arrow.Adbc.Drivers.Apache.Hive2 +{ + internal abstract class HiveServer2Connection : AdbcConnection + { + internal const long BatchSizeDefault = 50000; + internal const int PollTimeMillisecondsDefault = 500; + private const int ConnectTimeoutMillisecondDefault = 30000; + private TTransport? _transport; + private TCLIService.Client? _client; + private readonly Lazy _vendorVersion; + private readonly Lazy _vendorName; + + internal HiveServer2Connection(IReadOnlyDictionary properties) + { + Properties = properties; + // Note: "LazyThreadSafetyMode.PublicationOnly" is thread-safe initialization where + // the first successful thread sets the value. If an exception is thrown, initialization + // will retry until it successfully returns a value without an exception. + // https://learn.microsoft.com/en-us/dotnet/framework/performance/lazy-initialization#exceptions-in-lazy-objects + _vendorVersion = new Lazy(() => GetInfoTypeStringValue(TGetInfoType.CLI_DBMS_VER), LazyThreadSafetyMode.PublicationOnly); + _vendorName = new Lazy(() => GetInfoTypeStringValue(TGetInfoType.CLI_DBMS_NAME), LazyThreadSafetyMode.PublicationOnly); + + if (properties.TryGetValue(ApacheParameters.QueryTimeoutSeconds, out string? queryTimeoutSecondsSettingValue)) + { + if (ApacheUtility.QueryTimeoutIsValid(ApacheParameters.QueryTimeoutSeconds, queryTimeoutSecondsSettingValue, out int queryTimeoutSeconds)) + { + QueryTimeoutSeconds = queryTimeoutSeconds; + } + } + } + + internal TCLIService.Client Client + { + get { return _client ?? throw new InvalidOperationException("connection not open"); } + } + + internal string VendorVersion => _vendorVersion.Value; + + internal string VendorName => _vendorName.Value; + + protected internal int QueryTimeoutSeconds { get; private set; } = ApacheUtility.QueryTimeoutSecondsDefault; + + internal IReadOnlyDictionary Properties { get; } + + internal async Task OpenAsync() + { + try + { + TTransport transport = await CreateTransportAsync(); + TProtocol protocol = await CreateProtocolAsync(transport); + _transport = protocol.Transport; + _client = new TCLIService.Client(protocol); + TOpenSessionReq request = CreateSessionRequest(); + + CancellationToken timeoutToken = ApacheUtility.GetCancellationToken(ConnectTimeoutMilliseconds, ApacheUtility.TimeUnit.Milliseconds); + TOpenSessionResp? session = await Client.OpenSession(request, timeoutToken); + + // Explicitly check the session status + if (session == null) + { + throw new HiveServer2Exception("Unable to open session. Unknown error."); + } + else if (session.Status.StatusCode != TStatusCode.SUCCESS_STATUS) + { + throw new HiveServer2Exception(session.Status.ErrorMessage) + .SetNativeError(session.Status.ErrorCode) + .SetSqlState(session.Status.SqlState); + } + + SessionHandle = session.SessionHandle; + } + catch (OperationCanceledException ex) + { + throw new TimeoutException("The operation timed out while attempting to open a session.", ex); + } + catch (Exception ex) + { + // Handle other exceptions if necessary + throw new HiveServer2Exception("An unexpected error occurred while opening the session.", ex); + } + } + + internal TSessionHandle? 
SessionHandle { get; private set; } + + protected internal DataTypeConversion DataTypeConversion { get; set; } = DataTypeConversion.None; + + protected internal HiveServer2TlsOption TlsOptions { get; set; } = HiveServer2TlsOption.Empty; + + protected internal int ConnectTimeoutMilliseconds { get; set; } = ConnectTimeoutMillisecondDefault; + + protected abstract Task CreateTransportAsync(); + + protected abstract Task CreateProtocolAsync(TTransport transport); + + protected abstract TOpenSessionReq CreateSessionRequest(); + + internal abstract SchemaParser SchemaParser { get; } + + internal abstract IArrowArrayStream NewReader(T statement, Schema schema) where T : HiveServer2Statement; + + public override IArrowArrayStream GetObjects(GetObjectsDepth depth, string? catalogPattern, string? dbSchemaPattern, string? tableNamePattern, IReadOnlyList? tableTypes, string? columnNamePattern) + { + throw new NotImplementedException(); + } + + public override IArrowArrayStream GetTableTypes() + { + throw new NotImplementedException(); + } + + internal static async Task PollForResponseAsync(TOperationHandle operationHandle, TCLIService.IAsync client, int pollTimeMilliseconds, CancellationToken cancellationToken = default) + { + TGetOperationStatusResp? statusResponse = null; + do + { + if (statusResponse != null) { await Task.Delay(pollTimeMilliseconds); } + TGetOperationStatusReq request = new(operationHandle); + statusResponse = await client.GetOperationStatus(request, cancellationToken); + } while (statusResponse.OperationState == TOperationState.PENDING_STATE || statusResponse.OperationState == TOperationState.RUNNING_STATE); + } + + private string GetInfoTypeStringValue(TGetInfoType infoType) + { + TGetInfoReq req = new() + { + SessionHandle = SessionHandle ?? throw new InvalidOperationException("session not created"), + InfoType = infoType, + }; + + TGetInfoResp getInfoResp = Client.GetInfo(req).Result; + if (getInfoResp.Status.StatusCode == TStatusCode.ERROR_STATUS) + { + throw new HiveServer2Exception(getInfoResp.Status.ErrorMessage) + .SetNativeError(getInfoResp.Status.ErrorCode) + .SetSqlState(getInfoResp.Status.SqlState); + } + + return getInfoResp.InfoValue.StringValue; + } + + public override void Dispose() + { + if (_client != null) + { + TCloseSessionReq r6 = new TCloseSessionReq(SessionHandle); + _client.CloseSession(r6).Wait(); + + _transport?.Close(); + _client.Dispose(); + _transport = null; + _client = null; + } + } + + internal static async Task GetResultSetMetadataAsync(TOperationHandle operationHandle, TCLIService.IAsync client, CancellationToken cancellationToken = default) + { + TGetResultSetMetadataReq request = new(operationHandle); + TGetResultSetMetadataResp response = await client.GetResultSetMetadata(request, cancellationToken); + return response; + } + } +} diff --git a/csharp/src/Drivers/Apache/Hive2/HiveServer2Statement.cs b/csharp/src/Drivers/Apache/Hive2/HiveServer2Statement.cs index bbb96e0af2..4c7f4593d6 100644 --- a/csharp/src/Drivers/Apache/Hive2/HiveServer2Statement.cs +++ b/csharp/src/Drivers/Apache/Hive2/HiveServer2Statement.cs @@ -1,178 +1,178 @@ -/* -* Licensed to the Apache Software Foundation (ASF) under one or more -* contributor license agreements. See the NOTICE file distributed with -* this work for additional information regarding copyright ownership. -* The ASF licenses this file to You under the Apache License, Version 2.0 -* (the "License"); you may not use this file except in compliance with -* the License. 
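// A condensed lifecycle sketch for the HiveServer2Connection defined above (the concrete
// subclass is driver-specific; SparkConnection, later in this patch, is one such subclass):
//
//     var properties = new Dictionary<string, string>
//     {
//         [ApacheParameters.QueryTimeoutSeconds] = "120",
//     };
//     HiveServer2Connection connection = /* concrete subclass built from properties */;
//     await connection.OpenAsync();          // Thrift transport + TOpenSessionReq, status checked
//     string vendor = connection.VendorName; // lazily fetched via TGetInfoReq(CLI_DBMS_NAME)
//     connection.Dispose();                  // TCloseSessionReq, then transport close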
You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. -*/ - -using System; -using System.Threading; -using System.Threading.Tasks; -using Apache.Arrow.Ipc; -using Apache.Hive.Service.Rpc.Thrift; - -namespace Apache.Arrow.Adbc.Drivers.Apache.Hive2 -{ - internal abstract class HiveServer2Statement : AdbcStatement - { - protected HiveServer2Statement(HiveServer2Connection connection) - { - Connection = connection; - } - - protected virtual void SetStatementProperties(TExecuteStatementReq statement) - { - statement.QueryTimeout = QueryTimeoutSeconds; - } - - public override QueryResult ExecuteQuery() => ExecuteQueryAsync().AsTask().Result; - - public override UpdateResult ExecuteUpdate() => ExecuteUpdateAsync().Result; - - public override async ValueTask ExecuteQueryAsync() - { - try - { - CancellationToken timeoutToken = ApacheUtility.GetCancellationToken(QueryTimeoutSeconds, ApacheUtility.TimeUnit.Seconds); - - // this could either: - // take QueryTimeoutSeconds * 3 - // OR - // take QueryTimeoutSeconds (but this could be restricting) - - await ExecuteStatementAsync(timeoutToken); // --> get QueryTimeout + - await HiveServer2Connection.PollForResponseAsync(OperationHandle!, Connection.Client, PollTimeMilliseconds, timeoutToken); // + poll, up to QueryTimeout - Schema schema = await GetResultSetSchemaAsync(OperationHandle!, Connection.Client, timeoutToken); // + get the result, up to QueryTimeout - - // TODO: Ensure this is set dynamically based on server capabilities - return new QueryResult(-1, Connection.NewReader(this, schema)); - } - catch (OperationCanceledException ex) - { - throw new TimeoutException("The query execution timed out.", ex); - } - } - - private async Task GetResultSetSchemaAsync(TOperationHandle operationHandle, TCLIService.IAsync client, CancellationToken cancellationToken = default) - { - TGetResultSetMetadataResp response = await HiveServer2Connection.GetResultSetMetadataAsync(operationHandle, client, cancellationToken); - return Connection.SchemaParser.GetArrowSchema(response.Schema, Connection.DataTypeConversion); - } - - public override async Task ExecuteUpdateAsync() - { - const string NumberOfAffectedRowsColumnName = "num_affected_rows"; - // TODO: Add CTS here --> this is inside of ExecuteQueryAsync - - QueryResult queryResult = await ExecuteQueryAsync(); - if (queryResult.Stream == null) - { - throw new AdbcException("no data found"); - } - - using IArrowArrayStream stream = queryResult.Stream; - - // Check if the affected rows columns are returned in the result. - Field affectedRowsField = stream.Schema.GetFieldByName(NumberOfAffectedRowsColumnName); - if (affectedRowsField != null && affectedRowsField.DataType.TypeId != Types.ArrowTypeId.Int64) - { - throw new AdbcException($"Unexpected data type for column: '{NumberOfAffectedRowsColumnName}'", new ArgumentException(NumberOfAffectedRowsColumnName)); - } - - // The default is -1. - if (affectedRowsField == null) return new UpdateResult(-1); - - long? 
affectedRows = null; - while (true) - { - using RecordBatch nextBatch = await stream.ReadNextRecordBatchAsync(); - if (nextBatch == null) { break; } - Int64Array numOfModifiedArray = (Int64Array)nextBatch.Column(NumberOfAffectedRowsColumnName); - // Note: should only have one item, but iterate for completeness - for (int i = 0; i < numOfModifiedArray.Length; i++) - { - // Note: handle the case where the affected rows are zero (0). - affectedRows = (affectedRows ?? 0) + numOfModifiedArray.GetValue(i).GetValueOrDefault(0); - } - } - - // If no altered rows, i.e. DDC statements, then -1 is the default. - return new UpdateResult(affectedRows ?? -1); - } - - public override void SetOption(string key, string value) - { - switch (key) - { - case ApacheParameters.PollTimeMilliseconds: - UpdatePollTimeIfValid(key, value); - break; - case ApacheParameters.BatchSize: - UpdateBatchSizeIfValid(key, value); - break; - case ApacheParameters.QueryTimeoutSeconds: - if (ApacheUtility.QueryTimeoutIsValid(key, value, out int queryTimeoutSeconds)) - { - QueryTimeoutSeconds = queryTimeoutSeconds; - } - break; - default: - throw AdbcException.NotImplemented($"Option '{key}' is not implemented."); - } - } - - protected async Task ExecuteStatementAsync(CancellationToken cancellationToken = default) - { - TExecuteStatementReq executeRequest = new TExecuteStatementReq(Connection.SessionHandle, SqlQuery); - SetStatementProperties(executeRequest); - TExecuteStatementResp executeResponse = await Connection.Client.ExecuteStatement(executeRequest, cancellationToken); - if (executeResponse.Status.StatusCode == TStatusCode.ERROR_STATUS) - { - throw new HiveServer2Exception(executeResponse.Status.ErrorMessage) - .SetSqlState(executeResponse.Status.SqlState) - .SetNativeError(executeResponse.Status.ErrorCode); - } - OperationHandle = executeResponse.OperationHandle; - } - - protected internal int PollTimeMilliseconds { get; private set; } = HiveServer2Connection.PollTimeMillisecondsDefault; - - protected internal long BatchSize { get; private set; } = HiveServer2Connection.BatchSizeDefault; - - protected internal int QueryTimeoutSeconds { get; set; } = ApacheUtility.QueryTimeoutSecondsDefault; - - public HiveServer2Connection Connection { get; private set; } - - public TOperationHandle? OperationHandle { get; private set; } - - private void UpdatePollTimeIfValid(string key, string value) => PollTimeMilliseconds = !string.IsNullOrEmpty(key) && int.TryParse(value, result: out int pollTimeMilliseconds) && pollTimeMilliseconds >= 0 - ? pollTimeMilliseconds - : throw new ArgumentOutOfRangeException(key, value, $"The value '{value}' for option '{key}' is invalid. Must be a numeric value greater than or equal to 0."); - - private void UpdateBatchSizeIfValid(string key, string value) => BatchSize = !string.IsNullOrEmpty(value) && long.TryParse(value, out long batchSize) && batchSize > 0 - ? batchSize - : throw new ArgumentOutOfRangeException(key, value, $"The value '{value}' for option '{key}' is invalid. Must be a numeric value greater than zero."); - - public override void Dispose() - { - if (OperationHandle != null) - { - TCloseOperationReq request = new TCloseOperationReq(OperationHandle); - Connection.Client.CloseOperation(request).Wait(); - OperationHandle = null; - } - - base.Dispose(); - } - } -} +/* +* Licensed to the Apache Software Foundation (ASF) under one or more +* contributor license agreements. See the NOTICE file distributed with +* this work for additional information regarding copyright ownership. 
+* The ASF licenses this file to You under the Apache License, Version 2.0 +* (the "License"); you may not use this file except in compliance with +* the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +using System; +using System.Threading; +using System.Threading.Tasks; +using Apache.Arrow.Ipc; +using Apache.Hive.Service.Rpc.Thrift; + +namespace Apache.Arrow.Adbc.Drivers.Apache.Hive2 +{ + internal abstract class HiveServer2Statement : AdbcStatement + { + protected HiveServer2Statement(HiveServer2Connection connection) + { + Connection = connection; + } + + protected virtual void SetStatementProperties(TExecuteStatementReq statement) + { + statement.QueryTimeout = QueryTimeoutSeconds; + } + + public override QueryResult ExecuteQuery() => ExecuteQueryAsync().AsTask().Result; + + public override UpdateResult ExecuteUpdate() => ExecuteUpdateAsync().Result; + + public override async ValueTask ExecuteQueryAsync() + { + try + { + CancellationToken timeoutToken = ApacheUtility.GetCancellationToken(QueryTimeoutSeconds, ApacheUtility.TimeUnit.Seconds); + + // this could either: + // take QueryTimeoutSeconds * 3 + // OR + // take QueryTimeoutSeconds (but this could be restricting) + + await ExecuteStatementAsync(timeoutToken); // --> get QueryTimeout + + await HiveServer2Connection.PollForResponseAsync(OperationHandle!, Connection.Client, PollTimeMilliseconds, timeoutToken); // + poll, up to QueryTimeout + Schema schema = await GetResultSetSchemaAsync(OperationHandle!, Connection.Client, timeoutToken); // + get the result, up to QueryTimeout + + // TODO: Ensure this is set dynamically based on server capabilities + return new QueryResult(-1, Connection.NewReader(this, schema)); + } + catch (OperationCanceledException ex) + { + throw new TimeoutException("The query execution timed out.", ex); + } + } + + private async Task GetResultSetSchemaAsync(TOperationHandle operationHandle, TCLIService.IAsync client, CancellationToken cancellationToken = default) + { + TGetResultSetMetadataResp response = await HiveServer2Connection.GetResultSetMetadataAsync(operationHandle, client, cancellationToken); + return Connection.SchemaParser.GetArrowSchema(response.Schema, Connection.DataTypeConversion); + } + + public override async Task ExecuteUpdateAsync() + { + const string NumberOfAffectedRowsColumnName = "num_affected_rows"; + // TODO: Add CTS here --> this is inside of ExecuteQueryAsync + + QueryResult queryResult = await ExecuteQueryAsync(); + if (queryResult.Stream == null) + { + throw new AdbcException("no data found"); + } + + using IArrowArrayStream stream = queryResult.Stream; + + // Check if the affected rows columns are returned in the result. + Field affectedRowsField = stream.Schema.GetFieldByName(NumberOfAffectedRowsColumnName); + if (affectedRowsField != null && affectedRowsField.DataType.TypeId != Types.ArrowTypeId.Int64) + { + throw new AdbcException($"Unexpected data type for column: '{NumberOfAffectedRowsColumnName}'", new ArgumentException(NumberOfAffectedRowsColumnName)); + } + + // The default is -1. + if (affectedRowsField == null) return new UpdateResult(-1); + + long? 
affectedRows = null;
+            while (true)
+            {
+                using RecordBatch nextBatch = await stream.ReadNextRecordBatchAsync();
+                if (nextBatch == null) { break; }
+                Int64Array numOfModifiedArray = (Int64Array)nextBatch.Column(NumberOfAffectedRowsColumnName);
+                // Note: should only have one item, but iterate for completeness
+                for (int i = 0; i < numOfModifiedArray.Length; i++)
+                {
+                    // Note: handle the case where the affected rows are zero (0).
+                    affectedRows = (affectedRows ?? 0) + numOfModifiedArray.GetValue(i).GetValueOrDefault(0);
+                }
+            }
+
+            // If no altered rows, i.e. DDL statements, then -1 is the default.
+            return new UpdateResult(affectedRows ?? -1);
+        }
+
+        public override void SetOption(string key, string value)
+        {
+            switch (key)
+            {
+                case ApacheParameters.PollTimeMilliseconds:
+                    UpdatePollTimeIfValid(key, value);
+                    break;
+                case ApacheParameters.BatchSize:
+                    UpdateBatchSizeIfValid(key, value);
+                    break;
+                case ApacheParameters.QueryTimeoutSeconds:
+                    if (ApacheUtility.QueryTimeoutIsValid(key, value, out int queryTimeoutSeconds))
+                    {
+                        QueryTimeoutSeconds = queryTimeoutSeconds;
+                    }
+                    break;
+                default:
+                    throw AdbcException.NotImplemented($"Option '{key}' is not implemented.");
+            }
+        }
+
+        protected async Task ExecuteStatementAsync(CancellationToken cancellationToken = default)
+        {
+            TExecuteStatementReq executeRequest = new TExecuteStatementReq(Connection.SessionHandle, SqlQuery);
+            SetStatementProperties(executeRequest);
+            TExecuteStatementResp executeResponse = await Connection.Client.ExecuteStatement(executeRequest, cancellationToken);
+            if (executeResponse.Status.StatusCode == TStatusCode.ERROR_STATUS)
+            {
+                throw new HiveServer2Exception(executeResponse.Status.ErrorMessage)
+                    .SetSqlState(executeResponse.Status.SqlState)
+                    .SetNativeError(executeResponse.Status.ErrorCode);
+            }
+            OperationHandle = executeResponse.OperationHandle;
+        }
+
+        protected internal int PollTimeMilliseconds { get; private set; } = HiveServer2Connection.PollTimeMillisecondsDefault;
+
+        protected internal long BatchSize { get; private set; } = HiveServer2Connection.BatchSizeDefault;
+
+        protected internal int QueryTimeoutSeconds { get; set; } = ApacheUtility.QueryTimeoutSecondsDefault;
+
+        public HiveServer2Connection Connection { get; private set; }
+
+        public TOperationHandle? OperationHandle { get; private set; }
+
+        private void UpdatePollTimeIfValid(string key, string value) => PollTimeMilliseconds = !string.IsNullOrEmpty(value) && int.TryParse(value, result: out int pollTimeMilliseconds) && pollTimeMilliseconds >= 0
+            ? pollTimeMilliseconds
+            : throw new ArgumentOutOfRangeException(key, value, $"The value '{value}' for option '{key}' is invalid. Must be a numeric value greater than or equal to 0.");
+
+        private void UpdateBatchSizeIfValid(string key, string value) => BatchSize = !string.IsNullOrEmpty(value) && long.TryParse(value, out long batchSize) && batchSize > 0
+            ? batchSize
+            : throw new ArgumentOutOfRangeException(key, value, $"The value '{value}' for option '{key}' is invalid. 
Must be a numeric value greater than zero."); + + public override void Dispose() + { + if (OperationHandle != null) + { + TCloseOperationReq request = new TCloseOperationReq(OperationHandle); + Connection.Client.CloseOperation(request).Wait(); + OperationHandle = null; + } + + base.Dispose(); + } + } +} diff --git a/csharp/src/Drivers/Apache/Spark/SparkConnection.cs b/csharp/src/Drivers/Apache/Spark/SparkConnection.cs index 40a50e3e0c..e4c21c91de 100644 --- a/csharp/src/Drivers/Apache/Spark/SparkConnection.cs +++ b/csharp/src/Drivers/Apache/Spark/SparkConnection.cs @@ -1,1074 +1,1074 @@ -/* -* Licensed to the Apache Software Foundation (ASF) under one or more -* contributor license agreements. See the NOTICE file distributed with -* this work for additional information regarding copyright ownership. -* The ASF licenses this file to You under the Apache License, Version 2.0 -* (the "License"); you may not use this file except in compliance with -* the License. You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. -*/ - -using System; -using System.Collections.Generic; -using System.Diagnostics; -using System.Linq; -using System.Reflection; -using System.Text; -using System.Text.RegularExpressions; -using System.Threading; -using System.Threading.Tasks; -using Apache.Arrow.Adbc.Drivers.Apache.Hive2; -using Apache.Arrow.Adbc.Drivers.Apache.Thrift; -using Apache.Arrow.Adbc.Extensions; -using Apache.Arrow.Ipc; -using Apache.Arrow.Types; -using Apache.Hive.Service.Rpc.Thrift; - -namespace Apache.Arrow.Adbc.Drivers.Apache.Spark -{ - internal abstract class SparkConnection : HiveServer2Connection - { - internal static readonly string s_userAgent = $"{InfoDriverName.Replace(" ", "")}/{ProductVersionDefault}"; - - readonly AdbcInfoCode[] infoSupportedCodes = new[] { - AdbcInfoCode.DriverName, - AdbcInfoCode.DriverVersion, - AdbcInfoCode.DriverArrowVersion, - AdbcInfoCode.VendorName, - AdbcInfoCode.VendorSql, - AdbcInfoCode.VendorVersion, - }; - - const string ProductVersionDefault = "1.0.0"; - const string InfoDriverName = "ADBC Spark Driver"; - const string InfoDriverArrowVersion = "1.0.0"; - const bool InfoVendorSql = true; - const string ColumnDef = "COLUMN_DEF"; - const string ColumnName = "COLUMN_NAME"; - const string DataType = "DATA_TYPE"; - const string IsAutoIncrement = "IS_AUTO_INCREMENT"; - const string IsNullable = "IS_NULLABLE"; - const string OrdinalPosition = "ORDINAL_POSITION"; - const string TableCat = "TABLE_CAT"; - const string TableCatalog = "TABLE_CATALOG"; - const string TableName = "TABLE_NAME"; - const string TableSchem = "TABLE_SCHEM"; - const string TableType = "TABLE_TYPE"; - const string TypeName = "TYPE_NAME"; - const string Nullable = "NULLABLE"; - private readonly Lazy _productVersion; - - internal static TSparkGetDirectResults sparkGetDirectResults = new TSparkGetDirectResults(1000); - - internal static readonly Dictionary timestampConfig = new Dictionary - { - { "spark.thriftserver.arrowBasedRowSet.timestampAsString", "false" } - }; - - /// - /// The Spark data type definitions based on the JDBC Types constants. 
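// A short usage sketch for the statement surface defined above (the SQL text is
// illustrative; on a Spark connection, CreateStatement returns a SparkStatement,
// which is a HiveServer2Statement):
//
//     AdbcStatement statement = connection.CreateStatement();
//     statement.SetOption(ApacheParameters.QueryTimeoutSeconds, "120");
//     statement.SqlQuery = "SELECT 1";
//     QueryResult result = statement.ExecuteQuery(); // execute + poll + schema fetch, all bounded by the timeout
//
// ExecuteUpdate follows the same path, then drains the stream and sums the
// num_affected_rows column when the server returns one.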
- /// - /// - /// This enumeration can be used to determine the Spark-specific data types that are contained in fields xdbc_data_type and xdbc_sql_data_type - /// in the column metadata . This column metadata is returned as a result of a call to - /// - /// when depth is set to . - /// - internal enum ColumnTypeId - { - // NOTE: There is a partial copy of this enumeration in test/Drivers/Apache/Spark/DriverTests.cs - // Please keep up-to-date. - // Copied from https://docs.oracle.com/en%2Fjava%2Fjavase%2F21%2Fdocs%2Fapi%2F%2F/constant-values.html#java.sql.Types.ARRAY - - /// - /// Identifies the generic SQL type ARRAY - /// - ARRAY = 2003, - /// - /// Identifies the generic SQL type BIGINT - /// - BIGINT = -5, - /// - /// Identifies the generic SQL type BINARY - /// - BINARY = -2, - /// - /// Identifies the generic SQL type BOOLEAN - /// - BOOLEAN = 16, - /// - /// Identifies the generic SQL type CHAR - /// - CHAR = 1, - /// - /// Identifies the generic SQL type DATE - /// - DATE = 91, - /// - /// Identifies the generic SQL type DECIMAL - /// - DECIMAL = 3, - /// - /// Identifies the generic SQL type DOUBLE - /// - DOUBLE = 8, - /// - /// Identifies the generic SQL type FLOAT - /// - FLOAT = 6, - /// - /// Identifies the generic SQL type INTEGER - /// - INTEGER = 4, - /// - /// Identifies the generic SQL type JAVA_OBJECT (MAP) - /// - JAVA_OBJECT = 2000, - /// - /// identifies the generic SQL type LONGNVARCHAR - /// - LONGNVARCHAR = -16, - /// - /// identifies the generic SQL type LONGVARBINARY - /// - LONGVARBINARY = -4, - /// - /// identifies the generic SQL type LONGVARCHAR - /// - LONGVARCHAR = -1, - /// - /// identifies the generic SQL type NCHAR - /// - NCHAR = -15, - /// - /// identifies the generic SQL type NULL - /// - NULL = 0, - /// - /// identifies the generic SQL type NUMERIC - /// - NUMERIC = 2, - /// - /// identifies the generic SQL type NVARCHAR - /// - NVARCHAR = -9, - /// - /// identifies the generic SQL type REAL - /// - REAL = 7, - /// - /// Identifies the generic SQL type SMALLINT - /// - SMALLINT = 5, - /// - /// Identifies the generic SQL type STRUCT - /// - STRUCT = 2002, - /// - /// Identifies the generic SQL type TIMESTAMP - /// - TIMESTAMP = 93, - /// - /// Identifies the generic SQL type TINYINT - /// - TINYINT = -6, - /// - /// Identifies the generic SQL type VARBINARY - /// - VARBINARY = -3, - /// - /// Identifies the generic SQL type VARCHAR - /// - VARCHAR = 12, - // ====================== - // Unused/unsupported - // ====================== - /// - /// Identifies the generic SQL type BIT - /// - BIT = -7, - /// - /// Identifies the generic SQL type BLOB - /// - BLOB = 2004, - /// - /// Identifies the generic SQL type CLOB - /// - CLOB = 2005, - /// - /// Identifies the generic SQL type DATALINK - /// - DATALINK = 70, - /// - /// Identifies the generic SQL type DISTINCT - /// - DISTINCT = 2001, - /// - /// identifies the generic SQL type NCLOB - /// - NCLOB = 2011, - /// - /// Indicates that the SQL type is database-specific and gets mapped to a Java object - /// - OTHER = 1111, - /// - /// Identifies the generic SQL type REF CURSOR - /// - REF_CURSOR = 2012, - /// - /// Identifies the generic SQL type REF - /// - REF = 2006, - /// - /// Identifies the generic SQL type ROWID - /// - ROWID = -8, - /// - /// Identifies the generic SQL type XML - /// - SQLXML = 2009, - /// - /// Identifies the generic SQL type TIME - /// - TIME = 92, - /// - /// Identifies the generic SQL type TIME WITH TIMEZONE - /// - TIME_WITH_TIMEZONE = 2013, - /// - /// Identifies the 
generic SQL type TIMESTAMP WITH TIMEZONE - /// - TIMESTAMP_WITH_TIMEZONE = 2014, - } - - internal SparkConnection(IReadOnlyDictionary properties) - : base(properties) - { - ValidateProperties(); - _productVersion = new Lazy(() => GetProductVersion(), LazyThreadSafetyMode.PublicationOnly); - } - - private void ValidateProperties() - { - ValidateAuthentication(); - ValidateConnection(); - ValidateOptions(); - } - - protected string ProductVersion => _productVersion.Value; - - public override AdbcStatement CreateStatement() - { - return new SparkStatement(this); - } - - public override IArrowArrayStream GetInfo(IReadOnlyList codes) - { - const int strValTypeID = 0; - const int boolValTypeId = 1; - - UnionType infoUnionType = new UnionType( - new Field[] - { - new Field("string_value", StringType.Default, true), - new Field("bool_value", BooleanType.Default, true), - new Field("int64_value", Int64Type.Default, true), - new Field("int32_bitmask", Int32Type.Default, true), - new Field( - "string_list", - new ListType( - new Field("item", StringType.Default, true) - ), - false - ), - new Field( - "int32_to_int32_list_map", - new ListType( - new Field("entries", new StructType( - new Field[] - { - new Field("key", Int32Type.Default, false), - new Field("value", Int32Type.Default, true), - } - ), false) - ), - true - ) - }, - new int[] { 0, 1, 2, 3, 4, 5 }, - UnionMode.Dense); - - if (codes.Count == 0) - { - codes = infoSupportedCodes; - } - - UInt32Array.Builder infoNameBuilder = new UInt32Array.Builder(); - ArrowBuffer.Builder typeBuilder = new ArrowBuffer.Builder(); - ArrowBuffer.Builder offsetBuilder = new ArrowBuffer.Builder(); - StringArray.Builder stringInfoBuilder = new StringArray.Builder(); - BooleanArray.Builder booleanInfoBuilder = new BooleanArray.Builder(); - - int nullCount = 0; - int arrayLength = codes.Count; - int offset = 0; - - foreach (AdbcInfoCode code in codes) - { - switch (code) - { - case AdbcInfoCode.DriverName: - infoNameBuilder.Append((UInt32)code); - typeBuilder.Append(strValTypeID); - offsetBuilder.Append(offset++); - stringInfoBuilder.Append(InfoDriverName); - booleanInfoBuilder.AppendNull(); - break; - case AdbcInfoCode.DriverVersion: - infoNameBuilder.Append((UInt32)code); - typeBuilder.Append(strValTypeID); - offsetBuilder.Append(offset++); - stringInfoBuilder.Append(ProductVersion); - booleanInfoBuilder.AppendNull(); - break; - case AdbcInfoCode.DriverArrowVersion: - infoNameBuilder.Append((UInt32)code); - typeBuilder.Append(strValTypeID); - offsetBuilder.Append(offset++); - stringInfoBuilder.Append(InfoDriverArrowVersion); - booleanInfoBuilder.AppendNull(); - break; - case AdbcInfoCode.VendorName: - infoNameBuilder.Append((UInt32)code); - typeBuilder.Append(strValTypeID); - offsetBuilder.Append(offset++); - string vendorName = VendorName; - stringInfoBuilder.Append(vendorName); - booleanInfoBuilder.AppendNull(); - break; - case AdbcInfoCode.VendorVersion: - infoNameBuilder.Append((UInt32)code); - typeBuilder.Append(strValTypeID); - offsetBuilder.Append(offset++); - string? 
vendorVersion = VendorVersion; - stringInfoBuilder.Append(vendorVersion); - booleanInfoBuilder.AppendNull(); - break; - case AdbcInfoCode.VendorSql: - infoNameBuilder.Append((UInt32)code); - typeBuilder.Append(boolValTypeId); - offsetBuilder.Append(offset++); - stringInfoBuilder.AppendNull(); - booleanInfoBuilder.Append(InfoVendorSql); - break; - default: - infoNameBuilder.Append((UInt32)code); - typeBuilder.Append(strValTypeID); - offsetBuilder.Append(offset++); - stringInfoBuilder.AppendNull(); - booleanInfoBuilder.AppendNull(); - nullCount++; - break; - } - } - - StructType entryType = new StructType( - new Field[] { - new Field("key", Int32Type.Default, false), - new Field("value", Int32Type.Default, true)}); - - StructArray entriesDataArray = new StructArray(entryType, 0, - new[] { new Int32Array.Builder().Build(), new Int32Array.Builder().Build() }, - new ArrowBuffer.BitmapBuilder().Build()); - - IArrowArray[] childrenArrays = new IArrowArray[] - { - stringInfoBuilder.Build(), - booleanInfoBuilder.Build(), - new Int64Array.Builder().Build(), - new Int32Array.Builder().Build(), - new ListArray.Builder(StringType.Default).Build(), - new List(){ entriesDataArray }.BuildListArrayForType(entryType) - }; - - DenseUnionArray infoValue = new DenseUnionArray(infoUnionType, arrayLength, childrenArrays, typeBuilder.Build(), offsetBuilder.Build(), nullCount); - - IArrowArray[] dataArrays = new IArrowArray[] - { - infoNameBuilder.Build(), - infoValue - }; - StandardSchemas.GetInfoSchema.Validate(dataArrays); - - return new SparkInfoArrowStream(StandardSchemas.GetInfoSchema, dataArrays); - - } - - public override IArrowArrayStream GetTableTypes() - { - TGetTableTypesReq req = new() - { - SessionHandle = SessionHandle ?? throw new InvalidOperationException("session not created"), - GetDirectResults = sparkGetDirectResults - }; - - CancellationToken timeoutToken = ApacheUtility.GetCancellationToken(QueryTimeoutSeconds, ApacheUtility.TimeUnit.Seconds); - - TGetTableTypesResp resp = Client.GetTableTypes(req, timeoutToken).Result; - - if (resp.Status.StatusCode == TStatusCode.ERROR_STATUS) - { - throw new HiveServer2Exception(resp.Status.ErrorMessage) - .SetNativeError(resp.Status.ErrorCode) - .SetSqlState(resp.Status.SqlState); - } - - TRowSet rowSet = GetRowSetAsync(resp).Result; - StringArray tableTypes = rowSet.Columns[0].StringVal.Values; - - StringArray.Builder tableTypesBuilder = new StringArray.Builder(); - tableTypesBuilder.AppendRange(tableTypes); - - IArrowArray[] dataArrays = new IArrowArray[] - { - tableTypesBuilder.Build() - }; - - return new SparkInfoArrowStream(StandardSchemas.TableTypesSchema, dataArrays); - } - - public override Schema GetTableSchema(string? catalog, string? dbSchema, string? 
tableName) - { - TGetColumnsReq getColumnsReq = new TGetColumnsReq(SessionHandle); - getColumnsReq.CatalogName = catalog; - getColumnsReq.SchemaName = dbSchema; - getColumnsReq.TableName = tableName; - getColumnsReq.GetDirectResults = sparkGetDirectResults; - - CancellationToken timeoutToken = ApacheUtility.GetCancellationToken(QueryTimeoutSeconds, ApacheUtility.TimeUnit.Seconds); - - var columnsResponse = Client.GetColumns(getColumnsReq, timeoutToken).Result; - if (columnsResponse.Status.StatusCode == TStatusCode.ERROR_STATUS) - { - throw new Exception(columnsResponse.Status.ErrorMessage); - } - - TRowSet rowSet = GetRowSetAsync(columnsResponse).Result; - List columns = rowSet.Columns; - int rowCount = rowSet.Columns[3].StringVal.Values.Length; - - Field[] fields = new Field[rowCount]; - for (int i = 0; i < rowCount; i++) - { - string columnName = columns[3].StringVal.Values.GetString(i); - int? columnType = columns[4].I32Val.Values.GetValue(i); - string typeName = columns[5].StringVal.Values.GetString(i); - // Note: the following two columns do not seem to be set correctly for DECIMAL types. - //int? columnSize = columns[6].I32Val.Values.GetValue(i); - //int? decimalDigits = columns[8].I32Val.Values.GetValue(i); - bool nullable = columns[10].I32Val.Values.GetValue(i) == 1; - IArrowType dataType = SparkConnection.GetArrowType(columnType!.Value, typeName); - fields[i] = new Field(columnName, dataType, nullable); - } - return new Schema(fields, null); - } - - public override IArrowArrayStream GetObjects(GetObjectsDepth depth, string? catalogPattern, string? dbSchemaPattern, string? tableNamePattern, IReadOnlyList? tableTypes, string? columnNamePattern) - { - Trace.TraceError($"getting objects with depth={depth.ToString()}, catalog = {catalogPattern}, dbschema = {dbSchemaPattern}, tablename = {tableNamePattern}"); - - Dictionary>> catalogMap = new Dictionary>>(); - CancellationToken timeoutToken = ApacheUtility.GetCancellationToken(QueryTimeoutSeconds, ApacheUtility.TimeUnit.Seconds); - - if (depth == GetObjectsDepth.All || depth >= GetObjectsDepth.Catalogs) - { - TGetCatalogsReq getCatalogsReq = new TGetCatalogsReq(SessionHandle); - getCatalogsReq.GetDirectResults = sparkGetDirectResults; - - TGetCatalogsResp getCatalogsResp = Client.GetCatalogs(getCatalogsReq, timeoutToken).Result; - - if (getCatalogsResp.Status.StatusCode == TStatusCode.ERROR_STATUS) - { - throw new Exception(getCatalogsResp.Status.ErrorMessage); - } - var catalogsMetadata = GetResultSetMetadataAsync(getCatalogsResp).Result; - IReadOnlyDictionary columnMap = GetColumnIndexMap(catalogsMetadata.Schema.Columns); - - string catalogRegexp = PatternToRegEx(catalogPattern); - TRowSet rowSet = GetRowSetAsync(getCatalogsResp).Result; - IReadOnlyList list = rowSet.Columns[columnMap[TableCat]].StringVal.Values; - for (int i = 0; i < list.Count; i++) - { - string col = list[i]; - string catalog = col; - - if (Regex.IsMatch(catalog, catalogRegexp, RegexOptions.IgnoreCase)) - { - catalogMap.Add(catalog, new Dictionary>()); - } - } - // Handle the case where server does not support 'catalog' in the namespace. 
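// (The catalog names above are filtered with PatternToRegEx, which is not shown in
// this hunk; presumably it rewrites the SQL LIKE-style pattern, '%' and '_' wildcards,
// into an anchored regex, roughly:
//     "^" + Regex.Escape(pattern).Replace("%", ".*").Replace("_", ".") + "$"
// so a pattern like "m%" matches names via "^m.*$".)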
- if (list.Count == 0 && string.IsNullOrEmpty(catalogPattern)) - { - catalogMap.Add(string.Empty, []); - } - } - - if (depth == GetObjectsDepth.All || depth >= GetObjectsDepth.DbSchemas) - { - TGetSchemasReq getSchemasReq = new TGetSchemasReq(SessionHandle); - getSchemasReq.CatalogName = catalogPattern; - getSchemasReq.SchemaName = dbSchemaPattern; - getSchemasReq.GetDirectResults = sparkGetDirectResults; - - TGetSchemasResp getSchemasResp = Client.GetSchemas(getSchemasReq, timeoutToken).Result; - if (getSchemasResp.Status.StatusCode == TStatusCode.ERROR_STATUS) - { - throw new Exception(getSchemasResp.Status.ErrorMessage); - } - - TGetResultSetMetadataResp schemaMetadata = GetResultSetMetadataAsync(getSchemasResp).Result; - IReadOnlyDictionary columnMap = GetColumnIndexMap(schemaMetadata.Schema.Columns); - TRowSet rowSet = GetRowSetAsync(getSchemasResp).Result; - - IReadOnlyList catalogList = rowSet.Columns[columnMap[TableCatalog]].StringVal.Values; - IReadOnlyList schemaList = rowSet.Columns[columnMap[TableSchem]].StringVal.Values; - - for (int i = 0; i < catalogList.Count; i++) - { - string catalog = catalogList[i]; - string schemaDb = schemaList[i]; - // It seems Spark sometimes returns empty string for catalog on some schema (temporary tables). - catalogMap.GetValueOrDefault(catalog)?.Add(schemaDb, new Dictionary()); - } - } - - if (depth == GetObjectsDepth.All || depth >= GetObjectsDepth.Tables) - { - TGetTablesReq getTablesReq = new TGetTablesReq(SessionHandle); - getTablesReq.CatalogName = catalogPattern; - getTablesReq.SchemaName = dbSchemaPattern; - getTablesReq.TableName = tableNamePattern; - getTablesReq.GetDirectResults = sparkGetDirectResults; - - TGetTablesResp getTablesResp = Client.GetTables(getTablesReq, timeoutToken).Result; - if (getTablesResp.Status.StatusCode == TStatusCode.ERROR_STATUS) - { - throw new Exception(getTablesResp.Status.ErrorMessage); - } - - TGetResultSetMetadataResp tableMetadata = GetResultSetMetadataAsync(getTablesResp).Result; - IReadOnlyDictionary columnMap = GetColumnIndexMap(tableMetadata.Schema.Columns); - TRowSet rowSet = GetRowSetAsync(getTablesResp).Result; - - IReadOnlyList catalogList = rowSet.Columns[columnMap[TableCat]].StringVal.Values; - IReadOnlyList schemaList = rowSet.Columns[columnMap[TableSchem]].StringVal.Values; - IReadOnlyList tableList = rowSet.Columns[columnMap[TableName]].StringVal.Values; - IReadOnlyList tableTypeList = rowSet.Columns[columnMap[TableType]].StringVal.Values; - - for (int i = 0; i < catalogList.Count; i++) - { - string catalog = catalogList[i]; - string schemaDb = schemaList[i]; - string tableName = tableList[i]; - string tableType = tableTypeList[i]; - TableInfo tableInfo = new(tableType); - catalogMap.GetValueOrDefault(catalog)?.GetValueOrDefault(schemaDb)?.Add(tableName, tableInfo); - } - } - - if (depth == GetObjectsDepth.All) - { - TGetColumnsReq columnsReq = new TGetColumnsReq(SessionHandle); - columnsReq.CatalogName = catalogPattern; - columnsReq.SchemaName = dbSchemaPattern; - columnsReq.TableName = tableNamePattern; - columnsReq.GetDirectResults = sparkGetDirectResults; - - if (!string.IsNullOrEmpty(columnNamePattern)) - columnsReq.ColumnName = columnNamePattern; - - var columnsResponse = Client.GetColumns(columnsReq, timeoutToken).Result; - if (columnsResponse.Status.StatusCode == TStatusCode.ERROR_STATUS) - { - throw new Exception(columnsResponse.Status.ErrorMessage); - } - - TGetResultSetMetadataResp columnsMetadata = GetResultSetMetadataAsync(columnsResponse).Result; - IReadOnlyDictionary 
columnMap = GetColumnIndexMap(columnsMetadata.Schema.Columns); - TRowSet rowSet = GetRowSetAsync(columnsResponse).Result; - - IReadOnlyList catalogList = rowSet.Columns[columnMap[TableCat]].StringVal.Values; - IReadOnlyList schemaList = rowSet.Columns[columnMap[TableSchem]].StringVal.Values; - IReadOnlyList tableList = rowSet.Columns[columnMap[TableName]].StringVal.Values; - IReadOnlyList columnNameList = rowSet.Columns[columnMap[ColumnName]].StringVal.Values; - ReadOnlySpan columnTypeList = rowSet.Columns[columnMap[DataType]].I32Val.Values.Values; - IReadOnlyList typeNameList = rowSet.Columns[columnMap[TypeName]].StringVal.Values; - ReadOnlySpan nullableList = rowSet.Columns[columnMap[Nullable]].I32Val.Values.Values; - IReadOnlyList columnDefaultList = rowSet.Columns[columnMap[ColumnDef]].StringVal.Values; - ReadOnlySpan ordinalPosList = rowSet.Columns[columnMap[OrdinalPosition]].I32Val.Values.Values; - IReadOnlyList isNullableList = rowSet.Columns[columnMap[IsNullable]].StringVal.Values; - IReadOnlyList isAutoIncrementList = rowSet.Columns[columnMap[IsAutoIncrement]].StringVal.Values; - - for (int i = 0; i < catalogList.Count; i++) - { - // For systems that don't support 'catalog' in the namespace - string catalog = catalogList[i] ?? string.Empty; - string schemaDb = schemaList[i]; - string tableName = tableList[i]; - string columnName = columnNameList[i]; - short colType = (short)columnTypeList[i]; - string typeName = typeNameList[i]; - short nullable = (short)nullableList[i]; - string? isAutoIncrementString = isAutoIncrementList[i]; - bool isAutoIncrement = (!string.IsNullOrEmpty(isAutoIncrementString) && (isAutoIncrementString.Equals("YES", StringComparison.InvariantCultureIgnoreCase) || isAutoIncrementString.Equals("TRUE", StringComparison.InvariantCultureIgnoreCase))); - string isNullable = isNullableList[i] ?? "YES"; - string columnDefault = columnDefaultList[i] ?? ""; - // Spark/Databricks reports ordinal index zero-indexed, instead of one-indexed - int ordinalPos = ordinalPosList[i] + 1; - TableInfo? tableInfo = catalogMap.GetValueOrDefault(catalog)?.GetValueOrDefault(schemaDb)?.GetValueOrDefault(tableName); - tableInfo?.ColumnName.Add(columnName); - tableInfo?.ColType.Add(colType); - tableInfo?.Nullable.Add(nullable); - tableInfo?.IsAutoIncrement.Add(isAutoIncrement); - tableInfo?.IsNullable.Add(isNullable); - tableInfo?.ColumnDefault.Add(columnDefault); - tableInfo?.OrdinalPosition.Add(ordinalPos); - SetPrecisionScaleAndTypeName(colType, typeName, tableInfo); - } - } - - StringArray.Builder catalogNameBuilder = new StringArray.Builder(); - List catalogDbSchemasValues = new List(); - - foreach (KeyValuePair>> catalogEntry in catalogMap) - { - catalogNameBuilder.Append(catalogEntry.Key); - - if (depth == GetObjectsDepth.Catalogs) - { - catalogDbSchemasValues.Add(null); - } - else - { - catalogDbSchemasValues.Add(GetDbSchemas( - depth, catalogEntry.Value)); - } - } - - Schema schema = StandardSchemas.GetObjectsSchema; - IReadOnlyList dataArrays = schema.Validate( - new List - { - catalogNameBuilder.Build(), - catalogDbSchemasValues.BuildListArrayForType(new StructType(StandardSchemas.DbSchemaSchema)), - }); - - return new SparkInfoArrowStream(schema, dataArrays); - } - - private static IReadOnlyDictionary GetColumnIndexMap(List columns) => columns - .Select(t => new { Index = t.Position - 1, t.ColumnName }) - .ToDictionary(t => t.ColumnName, t => t.Index); - - private static void SetPrecisionScaleAndTypeName(short colType, string typeName, TableInfo? 
tableInfo) - { - // Keep the original type name - tableInfo?.TypeName.Add(typeName); - switch (colType) - { - case (short)ColumnTypeId.DECIMAL: - case (short)ColumnTypeId.NUMERIC: - { - SqlDecimalParserResult result = SqlTypeNameParser.Parse(typeName, colType); - tableInfo?.Precision.Add(result.Precision); - tableInfo?.Scale.Add((short)result.Scale); - tableInfo?.BaseTypeName.Add(result.BaseTypeName); - break; - } - - case (short)ColumnTypeId.CHAR: - case (short)ColumnTypeId.NCHAR: - case (short)ColumnTypeId.VARCHAR: - case (short)ColumnTypeId.LONGVARCHAR: - case (short)ColumnTypeId.LONGNVARCHAR: - case (short)ColumnTypeId.NVARCHAR: - { - SqlCharVarcharParserResult result = SqlTypeNameParser.Parse(typeName, colType); - tableInfo?.Precision.Add(result.ColumnSize); - tableInfo?.Scale.Add(null); - tableInfo?.BaseTypeName.Add(result.BaseTypeName); - break; - } - - default: - { - SqlTypeNameParserResult result = SqlTypeNameParser.Parse(typeName, colType); - tableInfo?.Precision.Add(null); - tableInfo?.Scale.Add(null); - tableInfo?.BaseTypeName.Add(result.BaseTypeName); - break; - } - } - } - - private static IArrowType GetArrowType(int columnTypeId, string typeName) - { - switch (columnTypeId) - { - case (int)ColumnTypeId.BOOLEAN: - return BooleanType.Default; - case (int)ColumnTypeId.TINYINT: - return Int8Type.Default; - case (int)ColumnTypeId.SMALLINT: - return Int16Type.Default; - case (int)ColumnTypeId.INTEGER: - return Int32Type.Default; - case (int)ColumnTypeId.BIGINT: - return Int64Type.Default; - case (int)ColumnTypeId.FLOAT: - case (int)ColumnTypeId.REAL: - return FloatType.Default; - case (int)ColumnTypeId.DOUBLE: - return DoubleType.Default; - case (int)ColumnTypeId.VARCHAR: - case (int)ColumnTypeId.NVARCHAR: - case (int)ColumnTypeId.LONGVARCHAR: - case (int)ColumnTypeId.LONGNVARCHAR: - return StringType.Default; - case (int)ColumnTypeId.TIMESTAMP: - return new TimestampType(TimeUnit.Microsecond, timezone: (string?)null); - case (int)ColumnTypeId.BINARY: - case (int)ColumnTypeId.VARBINARY: - case (int)ColumnTypeId.LONGVARBINARY: - return BinaryType.Default; - case (int)ColumnTypeId.DATE: - return Date32Type.Default; - case (int)ColumnTypeId.CHAR: - case (int)ColumnTypeId.NCHAR: - return StringType.Default; - case (int)ColumnTypeId.DECIMAL: - case (int)ColumnTypeId.NUMERIC: - // Note: parsing the type name for SQL DECIMAL types as the precision and scale values - // are not returned in the Thrift call to GetColumns - return SqlTypeNameParser - .Parse(typeName, columnTypeId) - .Decimal128Type; - case (int)ColumnTypeId.NULL: - return NullType.Default; - case (int)ColumnTypeId.ARRAY: - case (int)ColumnTypeId.JAVA_OBJECT: - case (int)ColumnTypeId.STRUCT: - return StringType.Default; - default: - throw new NotImplementedException($"Column type id: {columnTypeId} is not supported."); - } - } - - private static StructArray GetDbSchemas( - GetObjectsDepth depth, - Dictionary> schemaMap) - { - StringArray.Builder dbSchemaNameBuilder = new StringArray.Builder(); - List dbSchemaTablesValues = new List(); - ArrowBuffer.BitmapBuilder nullBitmapBuffer = new ArrowBuffer.BitmapBuilder(); - int length = 0; - - - foreach (KeyValuePair> schemaEntry in schemaMap) - { - - dbSchemaNameBuilder.Append(schemaEntry.Key); - length++; - nullBitmapBuffer.Append(true); - - if (depth == GetObjectsDepth.DbSchemas) - { - dbSchemaTablesValues.Add(null); - } - else - { - dbSchemaTablesValues.Add(GetTableSchemas( - depth, schemaEntry.Value)); - } - - } - - IReadOnlyList schema = StandardSchemas.DbSchemaSchema; - 
IReadOnlyList dataArrays = schema.Validate( - new List - { - dbSchemaNameBuilder.Build(), - dbSchemaTablesValues.BuildListArrayForType(new StructType(StandardSchemas.TableSchema)), - }); - - return new StructArray( - new StructType(schema), - length, - dataArrays, - nullBitmapBuffer.Build()); - } - - private static StructArray GetTableSchemas( - GetObjectsDepth depth, - Dictionary tableMap) - { - StringArray.Builder tableNameBuilder = new StringArray.Builder(); - StringArray.Builder tableTypeBuilder = new StringArray.Builder(); - List tableColumnsValues = new List(); - List tableConstraintsValues = new List(); - ArrowBuffer.BitmapBuilder nullBitmapBuffer = new ArrowBuffer.BitmapBuilder(); - int length = 0; - - - foreach (KeyValuePair tableEntry in tableMap) - { - tableNameBuilder.Append(tableEntry.Key); - tableTypeBuilder.Append(tableEntry.Value.Type); - nullBitmapBuffer.Append(true); - length++; - - - tableConstraintsValues.Add(null); - - - if (depth == GetObjectsDepth.Tables) - { - tableColumnsValues.Add(null); - } - else - { - tableColumnsValues.Add(GetColumnSchema(tableEntry.Value)); - } - } - - - IReadOnlyList schema = StandardSchemas.TableSchema; - IReadOnlyList dataArrays = schema.Validate( - new List - { - tableNameBuilder.Build(), - tableTypeBuilder.Build(), - tableColumnsValues.BuildListArrayForType(new StructType(StandardSchemas.ColumnSchema)), - tableConstraintsValues.BuildListArrayForType( new StructType(StandardSchemas.ConstraintSchema)) - }); - - return new StructArray( - new StructType(schema), - length, - dataArrays, - nullBitmapBuffer.Build()); - } - - private static StructArray GetColumnSchema(TableInfo tableInfo) - { - StringArray.Builder columnNameBuilder = new StringArray.Builder(); - Int32Array.Builder ordinalPositionBuilder = new Int32Array.Builder(); - StringArray.Builder remarksBuilder = new StringArray.Builder(); - Int16Array.Builder xdbcDataTypeBuilder = new Int16Array.Builder(); - StringArray.Builder xdbcTypeNameBuilder = new StringArray.Builder(); - Int32Array.Builder xdbcColumnSizeBuilder = new Int32Array.Builder(); - Int16Array.Builder xdbcDecimalDigitsBuilder = new Int16Array.Builder(); - Int16Array.Builder xdbcNumPrecRadixBuilder = new Int16Array.Builder(); - Int16Array.Builder xdbcNullableBuilder = new Int16Array.Builder(); - StringArray.Builder xdbcColumnDefBuilder = new StringArray.Builder(); - Int16Array.Builder xdbcSqlDataTypeBuilder = new Int16Array.Builder(); - Int16Array.Builder xdbcDatetimeSubBuilder = new Int16Array.Builder(); - Int32Array.Builder xdbcCharOctetLengthBuilder = new Int32Array.Builder(); - StringArray.Builder xdbcIsNullableBuilder = new StringArray.Builder(); - StringArray.Builder xdbcScopeCatalogBuilder = new StringArray.Builder(); - StringArray.Builder xdbcScopeSchemaBuilder = new StringArray.Builder(); - StringArray.Builder xdbcScopeTableBuilder = new StringArray.Builder(); - BooleanArray.Builder xdbcIsAutoincrementBuilder = new BooleanArray.Builder(); - BooleanArray.Builder xdbcIsGeneratedcolumnBuilder = new BooleanArray.Builder(); - ArrowBuffer.BitmapBuilder nullBitmapBuffer = new ArrowBuffer.BitmapBuilder(); - int length = 0; - - - for (int i = 0; i < tableInfo.ColumnName.Count; i++) - { - columnNameBuilder.Append(tableInfo.ColumnName[i]); - ordinalPositionBuilder.Append(tableInfo.OrdinalPosition[i]); - // Use the "remarks" field to store the original type name value - remarksBuilder.Append(tableInfo.TypeName[i]); - xdbcColumnSizeBuilder.Append(tableInfo.Precision[i]); - xdbcDecimalDigitsBuilder.Append(tableInfo.Scale[i]); - 
xdbcDataTypeBuilder.Append(tableInfo.ColType[i]); - // Just the base type name without precision or scale clause - xdbcTypeNameBuilder.Append(tableInfo.BaseTypeName[i]); - xdbcNumPrecRadixBuilder.AppendNull(); - xdbcNullableBuilder.Append(tableInfo.Nullable[i]); - xdbcColumnDefBuilder.Append(tableInfo.ColumnDefault[i]); - xdbcSqlDataTypeBuilder.Append(tableInfo.ColType[i]); - xdbcDatetimeSubBuilder.AppendNull(); - xdbcCharOctetLengthBuilder.AppendNull(); - xdbcIsNullableBuilder.Append(tableInfo.IsNullable[i]); - xdbcScopeCatalogBuilder.AppendNull(); - xdbcScopeSchemaBuilder.AppendNull(); - xdbcScopeTableBuilder.AppendNull(); - xdbcIsAutoincrementBuilder.Append(tableInfo.IsAutoIncrement[i]); - xdbcIsGeneratedcolumnBuilder.Append(true); - nullBitmapBuffer.Append(true); - length++; - } - - IReadOnlyList schema = StandardSchemas.ColumnSchema; - IReadOnlyList dataArrays = schema.Validate( - new List - { - columnNameBuilder.Build(), - ordinalPositionBuilder.Build(), - remarksBuilder.Build(), - xdbcDataTypeBuilder.Build(), - xdbcTypeNameBuilder.Build(), - xdbcColumnSizeBuilder.Build(), - xdbcDecimalDigitsBuilder.Build(), - xdbcNumPrecRadixBuilder.Build(), - xdbcNullableBuilder.Build(), - xdbcColumnDefBuilder.Build(), - xdbcSqlDataTypeBuilder.Build(), - xdbcDatetimeSubBuilder.Build(), - xdbcCharOctetLengthBuilder.Build(), - xdbcIsNullableBuilder.Build(), - xdbcScopeCatalogBuilder.Build(), - xdbcScopeSchemaBuilder.Build(), - xdbcScopeTableBuilder.Build(), - xdbcIsAutoincrementBuilder.Build(), - xdbcIsGeneratedcolumnBuilder.Build() - }); - - return new StructArray( - new StructType(schema), - length, - dataArrays, - nullBitmapBuffer.Build()); - } - - private static string PatternToRegEx(string? pattern) - { - if (pattern == null) - return ".*"; - - StringBuilder builder = new StringBuilder("(?i)^"); - string convertedPattern = pattern.Replace("_", ".").Replace("%", ".*"); - builder.Append(convertedPattern); - builder.Append('$'); - - return builder.ToString(); - } - - - private static string GetProductVersion() - { - FileVersionInfo fileVersionInfo = FileVersionInfo.GetVersionInfo(Assembly.GetExecutingAssembly().Location); - return fileVersionInfo.ProductVersion ?? ProductVersionDefault; - } - - protected static Uri GetBaseAddress(string? uri, string? hostName, string? path, string? port) - { - // Uri property takes precedent. - if (!string.IsNullOrWhiteSpace(uri)) - { - var uriValue = new Uri(uri); - if (uriValue.Scheme != Uri.UriSchemeHttp && uriValue.Scheme != Uri.UriSchemeHttps) - throw new ArgumentOutOfRangeException( - AdbcOptions.Uri, - uri, - $"Unsupported scheme '{uriValue.Scheme}'"); - return uriValue; - } - - bool isPortSet = !string.IsNullOrEmpty(port); - bool isValidPortNumber = int.TryParse(port, out int portNumber) && portNumber > 0; - bool isDefaultHttpsPort = !isPortSet || (isValidPortNumber && portNumber == 443); - string uriScheme = isDefaultHttpsPort ? 
Uri.UriSchemeHttps : Uri.UriSchemeHttp; - int uriPort; - if (!isPortSet) - uriPort = -1; - else if (isValidPortNumber) - uriPort = portNumber; - else - throw new ArgumentOutOfRangeException(nameof(port), portNumber, $"Port number is not in a valid range."); - - Uri baseAddress = new UriBuilder(uriScheme, hostName, uriPort, path).Uri; - return baseAddress; - } - - protected abstract void ValidateConnection(); - protected abstract void ValidateAuthentication(); - protected abstract void ValidateOptions(); - - protected abstract Task GetRowSetAsync(TGetTableTypesResp response); - protected abstract Task GetRowSetAsync(TGetColumnsResp response); - protected abstract Task GetRowSetAsync(TGetTablesResp response); - protected abstract Task GetRowSetAsync(TGetCatalogsResp getCatalogsResp); - protected abstract Task GetRowSetAsync(TGetSchemasResp getSchemasResp); - protected abstract Task GetResultSetMetadataAsync(TGetSchemasResp response); - protected abstract Task GetResultSetMetadataAsync(TGetCatalogsResp response); - protected abstract Task GetResultSetMetadataAsync(TGetColumnsResp response); - protected abstract Task GetResultSetMetadataAsync(TGetTablesResp response); - - internal abstract SparkServerType ServerType { get; } - - internal struct TableInfo(string type) - { - public string Type { get; } = type; - - public List ColumnName { get; } = new(); - - public List ColType { get; } = new(); - - public List BaseTypeName { get; } = new(); - - public List TypeName { get; } = new(); - - public List Nullable { get; } = new(); - - public List Precision { get; } = new(); - - public List Scale { get; } = new(); - - public List OrdinalPosition { get; } = new(); - - public List ColumnDefault { get; } = new(); - - public List IsNullable { get; } = new(); - - public List IsAutoIncrement { get; } = new(); - } - - internal class SparkInfoArrowStream : IArrowArrayStream - { - private Schema schema; - private RecordBatch? batch; - - public SparkInfoArrowStream(Schema schema, IReadOnlyList data) - { - this.schema = schema; - this.batch = new RecordBatch(schema, data, data[0].Length); - } - - public Schema Schema { get { return this.schema; } } - - public ValueTask ReadNextRecordBatchAsync(CancellationToken cancellationToken = default) - { - RecordBatch? batch = this.batch; - this.batch = null; - return new ValueTask(batch); - } - - public void Dispose() - { - this.batch?.Dispose(); - this.batch = null; - } - } - } -} +/* +* Licensed to the Apache Software Foundation (ASF) under one or more +* contributor license agreements. See the NOTICE file distributed with +* this work for additional information regarding copyright ownership. +* The ASF licenses this file to You under the Apache License, Version 2.0 +* (the "License"); you may not use this file except in compliance with +* the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. 
+*/ + +using System; +using System.Collections.Generic; +using System.Diagnostics; +using System.Linq; +using System.Reflection; +using System.Text; +using System.Text.RegularExpressions; +using System.Threading; +using System.Threading.Tasks; +using Apache.Arrow.Adbc.Drivers.Apache.Hive2; +using Apache.Arrow.Adbc.Drivers.Apache.Thrift; +using Apache.Arrow.Adbc.Extensions; +using Apache.Arrow.Ipc; +using Apache.Arrow.Types; +using Apache.Hive.Service.Rpc.Thrift; + +namespace Apache.Arrow.Adbc.Drivers.Apache.Spark +{ + internal abstract class SparkConnection : HiveServer2Connection + { + internal static readonly string s_userAgent = $"{InfoDriverName.Replace(" ", "")}/{ProductVersionDefault}"; + + readonly AdbcInfoCode[] infoSupportedCodes = new[] { + AdbcInfoCode.DriverName, + AdbcInfoCode.DriverVersion, + AdbcInfoCode.DriverArrowVersion, + AdbcInfoCode.VendorName, + AdbcInfoCode.VendorSql, + AdbcInfoCode.VendorVersion, + }; + + const string ProductVersionDefault = "1.0.0"; + const string InfoDriverName = "ADBC Spark Driver"; + const string InfoDriverArrowVersion = "1.0.0"; + const bool InfoVendorSql = true; + const string ColumnDef = "COLUMN_DEF"; + const string ColumnName = "COLUMN_NAME"; + const string DataType = "DATA_TYPE"; + const string IsAutoIncrement = "IS_AUTO_INCREMENT"; + const string IsNullable = "IS_NULLABLE"; + const string OrdinalPosition = "ORDINAL_POSITION"; + const string TableCat = "TABLE_CAT"; + const string TableCatalog = "TABLE_CATALOG"; + const string TableName = "TABLE_NAME"; + const string TableSchem = "TABLE_SCHEM"; + const string TableType = "TABLE_TYPE"; + const string TypeName = "TYPE_NAME"; + const string Nullable = "NULLABLE"; + private readonly Lazy _productVersion; + + internal static TSparkGetDirectResults sparkGetDirectResults = new TSparkGetDirectResults(1000); + + internal static readonly Dictionary timestampConfig = new Dictionary + { + { "spark.thriftserver.arrowBasedRowSet.timestampAsString", "false" } + }; + + /// + /// The Spark data type definitions based on the JDBC Types constants. + /// + /// + /// This enumeration can be used to determine the Spark-specific data types that are contained in fields xdbc_data_type and xdbc_sql_data_type + /// in the column metadata . This column metadata is returned as a result of a call to + /// + /// when depth is set to . + /// + internal enum ColumnTypeId + { + // NOTE: There is a partial copy of this enumeration in test/Drivers/Apache/Spark/DriverTests.cs + // Please keep up-to-date. 
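+        // Illustrative example (not from the Thrift spec): a Spark DECIMAL(10,2) column is
+        // reported by GetColumns with DATA_TYPE = 3 (DECIMAL below) and TYPE_NAME "DECIMAL(10,2)";
+        // precision and scale are recovered by parsing the type name, see SetPrecisionScaleAndTypeName.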
+ // Copied from https://docs.oracle.com/en%2Fjava%2Fjavase%2F21%2Fdocs%2Fapi%2F%2F/constant-values.html#java.sql.Types.ARRAY + + /// + /// Identifies the generic SQL type ARRAY + /// + ARRAY = 2003, + /// + /// Identifies the generic SQL type BIGINT + /// + BIGINT = -5, + /// + /// Identifies the generic SQL type BINARY + /// + BINARY = -2, + /// + /// Identifies the generic SQL type BOOLEAN + /// + BOOLEAN = 16, + /// + /// Identifies the generic SQL type CHAR + /// + CHAR = 1, + /// + /// Identifies the generic SQL type DATE + /// + DATE = 91, + /// + /// Identifies the generic SQL type DECIMAL + /// + DECIMAL = 3, + /// + /// Identifies the generic SQL type DOUBLE + /// + DOUBLE = 8, + /// + /// Identifies the generic SQL type FLOAT + /// + FLOAT = 6, + /// + /// Identifies the generic SQL type INTEGER + /// + INTEGER = 4, + /// + /// Identifies the generic SQL type JAVA_OBJECT (MAP) + /// + JAVA_OBJECT = 2000, + /// + /// identifies the generic SQL type LONGNVARCHAR + /// + LONGNVARCHAR = -16, + /// + /// identifies the generic SQL type LONGVARBINARY + /// + LONGVARBINARY = -4, + /// + /// identifies the generic SQL type LONGVARCHAR + /// + LONGVARCHAR = -1, + /// + /// identifies the generic SQL type NCHAR + /// + NCHAR = -15, + /// + /// identifies the generic SQL type NULL + /// + NULL = 0, + /// + /// identifies the generic SQL type NUMERIC + /// + NUMERIC = 2, + /// + /// identifies the generic SQL type NVARCHAR + /// + NVARCHAR = -9, + /// + /// identifies the generic SQL type REAL + /// + REAL = 7, + /// + /// Identifies the generic SQL type SMALLINT + /// + SMALLINT = 5, + /// + /// Identifies the generic SQL type STRUCT + /// + STRUCT = 2002, + /// + /// Identifies the generic SQL type TIMESTAMP + /// + TIMESTAMP = 93, + /// + /// Identifies the generic SQL type TINYINT + /// + TINYINT = -6, + /// + /// Identifies the generic SQL type VARBINARY + /// + VARBINARY = -3, + /// + /// Identifies the generic SQL type VARCHAR + /// + VARCHAR = 12, + // ====================== + // Unused/unsupported + // ====================== + /// + /// Identifies the generic SQL type BIT + /// + BIT = -7, + /// + /// Identifies the generic SQL type BLOB + /// + BLOB = 2004, + /// + /// Identifies the generic SQL type CLOB + /// + CLOB = 2005, + /// + /// Identifies the generic SQL type DATALINK + /// + DATALINK = 70, + /// + /// Identifies the generic SQL type DISTINCT + /// + DISTINCT = 2001, + /// + /// identifies the generic SQL type NCLOB + /// + NCLOB = 2011, + /// + /// Indicates that the SQL type is database-specific and gets mapped to a Java object + /// + OTHER = 1111, + /// + /// Identifies the generic SQL type REF CURSOR + /// + REF_CURSOR = 2012, + /// + /// Identifies the generic SQL type REF + /// + REF = 2006, + /// + /// Identifies the generic SQL type ROWID + /// + ROWID = -8, + /// + /// Identifies the generic SQL type XML + /// + SQLXML = 2009, + /// + /// Identifies the generic SQL type TIME + /// + TIME = 92, + /// + /// Identifies the generic SQL type TIME WITH TIMEZONE + /// + TIME_WITH_TIMEZONE = 2013, + /// + /// Identifies the generic SQL type TIMESTAMP WITH TIMEZONE + /// + TIMESTAMP_WITH_TIMEZONE = 2014, + } + + internal SparkConnection(IReadOnlyDictionary properties) + : base(properties) + { + ValidateProperties(); + _productVersion = new Lazy(() => GetProductVersion(), LazyThreadSafetyMode.PublicationOnly); + } + + private void ValidateProperties() + { + ValidateAuthentication(); + ValidateConnection(); + ValidateOptions(); + } + + protected string ProductVersion => 
_productVersion.Value; + + public override AdbcStatement CreateStatement() + { + return new SparkStatement(this); + } + + public override IArrowArrayStream GetInfo(IReadOnlyList codes) + { + const int strValTypeID = 0; + const int boolValTypeId = 1; + + UnionType infoUnionType = new UnionType( + new Field[] + { + new Field("string_value", StringType.Default, true), + new Field("bool_value", BooleanType.Default, true), + new Field("int64_value", Int64Type.Default, true), + new Field("int32_bitmask", Int32Type.Default, true), + new Field( + "string_list", + new ListType( + new Field("item", StringType.Default, true) + ), + false + ), + new Field( + "int32_to_int32_list_map", + new ListType( + new Field("entries", new StructType( + new Field[] + { + new Field("key", Int32Type.Default, false), + new Field("value", Int32Type.Default, true), + } + ), false) + ), + true + ) + }, + new int[] { 0, 1, 2, 3, 4, 5 }, + UnionMode.Dense); + + if (codes.Count == 0) + { + codes = infoSupportedCodes; + } + + UInt32Array.Builder infoNameBuilder = new UInt32Array.Builder(); + ArrowBuffer.Builder typeBuilder = new ArrowBuffer.Builder(); + ArrowBuffer.Builder offsetBuilder = new ArrowBuffer.Builder(); + StringArray.Builder stringInfoBuilder = new StringArray.Builder(); + BooleanArray.Builder booleanInfoBuilder = new BooleanArray.Builder(); + + int nullCount = 0; + int arrayLength = codes.Count; + int offset = 0; + + foreach (AdbcInfoCode code in codes) + { + switch (code) + { + case AdbcInfoCode.DriverName: + infoNameBuilder.Append((UInt32)code); + typeBuilder.Append(strValTypeID); + offsetBuilder.Append(offset++); + stringInfoBuilder.Append(InfoDriverName); + booleanInfoBuilder.AppendNull(); + break; + case AdbcInfoCode.DriverVersion: + infoNameBuilder.Append((UInt32)code); + typeBuilder.Append(strValTypeID); + offsetBuilder.Append(offset++); + stringInfoBuilder.Append(ProductVersion); + booleanInfoBuilder.AppendNull(); + break; + case AdbcInfoCode.DriverArrowVersion: + infoNameBuilder.Append((UInt32)code); + typeBuilder.Append(strValTypeID); + offsetBuilder.Append(offset++); + stringInfoBuilder.Append(InfoDriverArrowVersion); + booleanInfoBuilder.AppendNull(); + break; + case AdbcInfoCode.VendorName: + infoNameBuilder.Append((UInt32)code); + typeBuilder.Append(strValTypeID); + offsetBuilder.Append(offset++); + string vendorName = VendorName; + stringInfoBuilder.Append(vendorName); + booleanInfoBuilder.AppendNull(); + break; + case AdbcInfoCode.VendorVersion: + infoNameBuilder.Append((UInt32)code); + typeBuilder.Append(strValTypeID); + offsetBuilder.Append(offset++); + string? 
vendorVersion = VendorVersion; + stringInfoBuilder.Append(vendorVersion); + booleanInfoBuilder.AppendNull(); + break; + case AdbcInfoCode.VendorSql: + infoNameBuilder.Append((UInt32)code); + typeBuilder.Append(boolValTypeId); + offsetBuilder.Append(offset++); + stringInfoBuilder.AppendNull(); + booleanInfoBuilder.Append(InfoVendorSql); + break; + default: + infoNameBuilder.Append((UInt32)code); + typeBuilder.Append(strValTypeID); + offsetBuilder.Append(offset++); + stringInfoBuilder.AppendNull(); + booleanInfoBuilder.AppendNull(); + nullCount++; + break; + } + } + + StructType entryType = new StructType( + new Field[] { + new Field("key", Int32Type.Default, false), + new Field("value", Int32Type.Default, true)}); + + StructArray entriesDataArray = new StructArray(entryType, 0, + new[] { new Int32Array.Builder().Build(), new Int32Array.Builder().Build() }, + new ArrowBuffer.BitmapBuilder().Build()); + + IArrowArray[] childrenArrays = new IArrowArray[] + { + stringInfoBuilder.Build(), + booleanInfoBuilder.Build(), + new Int64Array.Builder().Build(), + new Int32Array.Builder().Build(), + new ListArray.Builder(StringType.Default).Build(), + new List(){ entriesDataArray }.BuildListArrayForType(entryType) + }; + + DenseUnionArray infoValue = new DenseUnionArray(infoUnionType, arrayLength, childrenArrays, typeBuilder.Build(), offsetBuilder.Build(), nullCount); + + IArrowArray[] dataArrays = new IArrowArray[] + { + infoNameBuilder.Build(), + infoValue + }; + StandardSchemas.GetInfoSchema.Validate(dataArrays); + + return new SparkInfoArrowStream(StandardSchemas.GetInfoSchema, dataArrays); + + } + + public override IArrowArrayStream GetTableTypes() + { + TGetTableTypesReq req = new() + { + SessionHandle = SessionHandle ?? throw new InvalidOperationException("session not created"), + GetDirectResults = sparkGetDirectResults + }; + + CancellationToken timeoutToken = ApacheUtility.GetCancellationToken(QueryTimeoutSeconds, ApacheUtility.TimeUnit.Seconds); + + TGetTableTypesResp resp = Client.GetTableTypes(req, timeoutToken).Result; + + if (resp.Status.StatusCode == TStatusCode.ERROR_STATUS) + { + throw new HiveServer2Exception(resp.Status.ErrorMessage) + .SetNativeError(resp.Status.ErrorCode) + .SetSqlState(resp.Status.SqlState); + } + + TRowSet rowSet = GetRowSetAsync(resp).Result; + StringArray tableTypes = rowSet.Columns[0].StringVal.Values; + + StringArray.Builder tableTypesBuilder = new StringArray.Builder(); + tableTypesBuilder.AppendRange(tableTypes); + + IArrowArray[] dataArrays = new IArrowArray[] + { + tableTypesBuilder.Build() + }; + + return new SparkInfoArrowStream(StandardSchemas.TableTypesSchema, dataArrays); + } + + public override Schema GetTableSchema(string? catalog, string? dbSchema, string? 
tableName) + { + TGetColumnsReq getColumnsReq = new TGetColumnsReq(SessionHandle); + getColumnsReq.CatalogName = catalog; + getColumnsReq.SchemaName = dbSchema; + getColumnsReq.TableName = tableName; + getColumnsReq.GetDirectResults = sparkGetDirectResults; + + CancellationToken timeoutToken = ApacheUtility.GetCancellationToken(QueryTimeoutSeconds, ApacheUtility.TimeUnit.Seconds); + + var columnsResponse = Client.GetColumns(getColumnsReq, timeoutToken).Result; + if (columnsResponse.Status.StatusCode == TStatusCode.ERROR_STATUS) + { + throw new Exception(columnsResponse.Status.ErrorMessage); + } + + TRowSet rowSet = GetRowSetAsync(columnsResponse).Result; + List columns = rowSet.Columns; + int rowCount = rowSet.Columns[3].StringVal.Values.Length; + + Field[] fields = new Field[rowCount]; + for (int i = 0; i < rowCount; i++) + { + string columnName = columns[3].StringVal.Values.GetString(i); + int? columnType = columns[4].I32Val.Values.GetValue(i); + string typeName = columns[5].StringVal.Values.GetString(i); + // Note: the following two columns do not seem to be set correctly for DECIMAL types. + //int? columnSize = columns[6].I32Val.Values.GetValue(i); + //int? decimalDigits = columns[8].I32Val.Values.GetValue(i); + bool nullable = columns[10].I32Val.Values.GetValue(i) == 1; + IArrowType dataType = SparkConnection.GetArrowType(columnType!.Value, typeName); + fields[i] = new Field(columnName, dataType, nullable); + } + return new Schema(fields, null); + } + + public override IArrowArrayStream GetObjects(GetObjectsDepth depth, string? catalogPattern, string? dbSchemaPattern, string? tableNamePattern, IReadOnlyList? tableTypes, string? columnNamePattern) + { + Trace.TraceError($"getting objects with depth={depth.ToString()}, catalog = {catalogPattern}, dbschema = {dbSchemaPattern}, tablename = {tableNamePattern}"); + + Dictionary>> catalogMap = new Dictionary>>(); + CancellationToken timeoutToken = ApacheUtility.GetCancellationToken(QueryTimeoutSeconds, ApacheUtility.TimeUnit.Seconds); + + if (depth == GetObjectsDepth.All || depth >= GetObjectsDepth.Catalogs) + { + TGetCatalogsReq getCatalogsReq = new TGetCatalogsReq(SessionHandle); + getCatalogsReq.GetDirectResults = sparkGetDirectResults; + + TGetCatalogsResp getCatalogsResp = Client.GetCatalogs(getCatalogsReq, timeoutToken).Result; + + if (getCatalogsResp.Status.StatusCode == TStatusCode.ERROR_STATUS) + { + throw new Exception(getCatalogsResp.Status.ErrorMessage); + } + var catalogsMetadata = GetResultSetMetadataAsync(getCatalogsResp).Result; + IReadOnlyDictionary columnMap = GetColumnIndexMap(catalogsMetadata.Schema.Columns); + + string catalogRegexp = PatternToRegEx(catalogPattern); + TRowSet rowSet = GetRowSetAsync(getCatalogsResp).Result; + IReadOnlyList list = rowSet.Columns[columnMap[TableCat]].StringVal.Values; + for (int i = 0; i < list.Count; i++) + { + string col = list[i]; + string catalog = col; + + if (Regex.IsMatch(catalog, catalogRegexp, RegexOptions.IgnoreCase)) + { + catalogMap.Add(catalog, new Dictionary>()); + } + } + // Handle the case where server does not support 'catalog' in the namespace. 
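+                // When that happens, GetCatalogs returns zero rows; the empty-string entry added
+                // below then serves as the single root catalog for the schema/table lookups.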
+ if (list.Count == 0 && string.IsNullOrEmpty(catalogPattern)) + { + catalogMap.Add(string.Empty, []); + } + } + + if (depth == GetObjectsDepth.All || depth >= GetObjectsDepth.DbSchemas) + { + TGetSchemasReq getSchemasReq = new TGetSchemasReq(SessionHandle); + getSchemasReq.CatalogName = catalogPattern; + getSchemasReq.SchemaName = dbSchemaPattern; + getSchemasReq.GetDirectResults = sparkGetDirectResults; + + TGetSchemasResp getSchemasResp = Client.GetSchemas(getSchemasReq, timeoutToken).Result; + if (getSchemasResp.Status.StatusCode == TStatusCode.ERROR_STATUS) + { + throw new Exception(getSchemasResp.Status.ErrorMessage); + } + + TGetResultSetMetadataResp schemaMetadata = GetResultSetMetadataAsync(getSchemasResp).Result; + IReadOnlyDictionary columnMap = GetColumnIndexMap(schemaMetadata.Schema.Columns); + TRowSet rowSet = GetRowSetAsync(getSchemasResp).Result; + + IReadOnlyList catalogList = rowSet.Columns[columnMap[TableCatalog]].StringVal.Values; + IReadOnlyList schemaList = rowSet.Columns[columnMap[TableSchem]].StringVal.Values; + + for (int i = 0; i < catalogList.Count; i++) + { + string catalog = catalogList[i]; + string schemaDb = schemaList[i]; + // It seems Spark sometimes returns empty string for catalog on some schema (temporary tables). + catalogMap.GetValueOrDefault(catalog)?.Add(schemaDb, new Dictionary()); + } + } + + if (depth == GetObjectsDepth.All || depth >= GetObjectsDepth.Tables) + { + TGetTablesReq getTablesReq = new TGetTablesReq(SessionHandle); + getTablesReq.CatalogName = catalogPattern; + getTablesReq.SchemaName = dbSchemaPattern; + getTablesReq.TableName = tableNamePattern; + getTablesReq.GetDirectResults = sparkGetDirectResults; + + TGetTablesResp getTablesResp = Client.GetTables(getTablesReq, timeoutToken).Result; + if (getTablesResp.Status.StatusCode == TStatusCode.ERROR_STATUS) + { + throw new Exception(getTablesResp.Status.ErrorMessage); + } + + TGetResultSetMetadataResp tableMetadata = GetResultSetMetadataAsync(getTablesResp).Result; + IReadOnlyDictionary columnMap = GetColumnIndexMap(tableMetadata.Schema.Columns); + TRowSet rowSet = GetRowSetAsync(getTablesResp).Result; + + IReadOnlyList catalogList = rowSet.Columns[columnMap[TableCat]].StringVal.Values; + IReadOnlyList schemaList = rowSet.Columns[columnMap[TableSchem]].StringVal.Values; + IReadOnlyList tableList = rowSet.Columns[columnMap[TableName]].StringVal.Values; + IReadOnlyList tableTypeList = rowSet.Columns[columnMap[TableType]].StringVal.Values; + + for (int i = 0; i < catalogList.Count; i++) + { + string catalog = catalogList[i]; + string schemaDb = schemaList[i]; + string tableName = tableList[i]; + string tableType = tableTypeList[i]; + TableInfo tableInfo = new(tableType); + catalogMap.GetValueOrDefault(catalog)?.GetValueOrDefault(schemaDb)?.Add(tableName, tableInfo); + } + } + + if (depth == GetObjectsDepth.All) + { + TGetColumnsReq columnsReq = new TGetColumnsReq(SessionHandle); + columnsReq.CatalogName = catalogPattern; + columnsReq.SchemaName = dbSchemaPattern; + columnsReq.TableName = tableNamePattern; + columnsReq.GetDirectResults = sparkGetDirectResults; + + if (!string.IsNullOrEmpty(columnNamePattern)) + columnsReq.ColumnName = columnNamePattern; + + var columnsResponse = Client.GetColumns(columnsReq, timeoutToken).Result; + if (columnsResponse.Status.StatusCode == TStatusCode.ERROR_STATUS) + { + throw new Exception(columnsResponse.Status.ErrorMessage); + } + + TGetResultSetMetadataResp columnsMetadata = GetResultSetMetadataAsync(columnsResponse).Result; + IReadOnlyDictionary 
columnMap = GetColumnIndexMap(columnsMetadata.Schema.Columns); + TRowSet rowSet = GetRowSetAsync(columnsResponse).Result; + + IReadOnlyList catalogList = rowSet.Columns[columnMap[TableCat]].StringVal.Values; + IReadOnlyList schemaList = rowSet.Columns[columnMap[TableSchem]].StringVal.Values; + IReadOnlyList tableList = rowSet.Columns[columnMap[TableName]].StringVal.Values; + IReadOnlyList columnNameList = rowSet.Columns[columnMap[ColumnName]].StringVal.Values; + ReadOnlySpan columnTypeList = rowSet.Columns[columnMap[DataType]].I32Val.Values.Values; + IReadOnlyList typeNameList = rowSet.Columns[columnMap[TypeName]].StringVal.Values; + ReadOnlySpan nullableList = rowSet.Columns[columnMap[Nullable]].I32Val.Values.Values; + IReadOnlyList columnDefaultList = rowSet.Columns[columnMap[ColumnDef]].StringVal.Values; + ReadOnlySpan ordinalPosList = rowSet.Columns[columnMap[OrdinalPosition]].I32Val.Values.Values; + IReadOnlyList isNullableList = rowSet.Columns[columnMap[IsNullable]].StringVal.Values; + IReadOnlyList isAutoIncrementList = rowSet.Columns[columnMap[IsAutoIncrement]].StringVal.Values; + + for (int i = 0; i < catalogList.Count; i++) + { + // For systems that don't support 'catalog' in the namespace + string catalog = catalogList[i] ?? string.Empty; + string schemaDb = schemaList[i]; + string tableName = tableList[i]; + string columnName = columnNameList[i]; + short colType = (short)columnTypeList[i]; + string typeName = typeNameList[i]; + short nullable = (short)nullableList[i]; + string? isAutoIncrementString = isAutoIncrementList[i]; + bool isAutoIncrement = (!string.IsNullOrEmpty(isAutoIncrementString) && (isAutoIncrementString.Equals("YES", StringComparison.InvariantCultureIgnoreCase) || isAutoIncrementString.Equals("TRUE", StringComparison.InvariantCultureIgnoreCase))); + string isNullable = isNullableList[i] ?? "YES"; + string columnDefault = columnDefaultList[i] ?? ""; + // Spark/Databricks reports ordinal index zero-indexed, instead of one-indexed + int ordinalPos = ordinalPosList[i] + 1; + TableInfo? tableInfo = catalogMap.GetValueOrDefault(catalog)?.GetValueOrDefault(schemaDb)?.GetValueOrDefault(tableName); + tableInfo?.ColumnName.Add(columnName); + tableInfo?.ColType.Add(colType); + tableInfo?.Nullable.Add(nullable); + tableInfo?.IsAutoIncrement.Add(isAutoIncrement); + tableInfo?.IsNullable.Add(isNullable); + tableInfo?.ColumnDefault.Add(columnDefault); + tableInfo?.OrdinalPosition.Add(ordinalPos); + SetPrecisionScaleAndTypeName(colType, typeName, tableInfo); + } + } + + StringArray.Builder catalogNameBuilder = new StringArray.Builder(); + List catalogDbSchemasValues = new List(); + + foreach (KeyValuePair>> catalogEntry in catalogMap) + { + catalogNameBuilder.Append(catalogEntry.Key); + + if (depth == GetObjectsDepth.Catalogs) + { + catalogDbSchemasValues.Add(null); + } + else + { + catalogDbSchemasValues.Add(GetDbSchemas( + depth, catalogEntry.Value)); + } + } + + Schema schema = StandardSchemas.GetObjectsSchema; + IReadOnlyList dataArrays = schema.Validate( + new List + { + catalogNameBuilder.Build(), + catalogDbSchemasValues.BuildListArrayForType(new StructType(StandardSchemas.DbSchemaSchema)), + }); + + return new SparkInfoArrowStream(schema, dataArrays); + } + + private static IReadOnlyDictionary GetColumnIndexMap(List columns) => columns + .Select(t => new { Index = t.Position - 1, t.ColumnName }) + .ToDictionary(t => t.ColumnName, t => t.Index); + + private static void SetPrecisionScaleAndTypeName(short colType, string typeName, TableInfo? 
tableInfo)
+        {
+            // Keep the original type name
+            tableInfo?.TypeName.Add(typeName);
+            switch (colType)
+            {
+                case (short)ColumnTypeId.DECIMAL:
+                case (short)ColumnTypeId.NUMERIC:
+                {
+                    SqlDecimalParserResult result = SqlTypeNameParser<SqlDecimalParserResult>.Parse(typeName, colType);
+                    tableInfo?.Precision.Add(result.Precision);
+                    tableInfo?.Scale.Add((short)result.Scale);
+                    tableInfo?.BaseTypeName.Add(result.BaseTypeName);
+                    break;
+                }
+
+                case (short)ColumnTypeId.CHAR:
+                case (short)ColumnTypeId.NCHAR:
+                case (short)ColumnTypeId.VARCHAR:
+                case (short)ColumnTypeId.LONGVARCHAR:
+                case (short)ColumnTypeId.LONGNVARCHAR:
+                case (short)ColumnTypeId.NVARCHAR:
+                {
+                    SqlCharVarcharParserResult result = SqlTypeNameParser<SqlCharVarcharParserResult>.Parse(typeName, colType);
+                    tableInfo?.Precision.Add(result.ColumnSize);
+                    tableInfo?.Scale.Add(null);
+                    tableInfo?.BaseTypeName.Add(result.BaseTypeName);
+                    break;
+                }
+
+                default:
+                {
+                    SqlTypeNameParserResult result = SqlTypeNameParser<SqlTypeNameParserResult>.Parse(typeName, colType);
+                    tableInfo?.Precision.Add(null);
+                    tableInfo?.Scale.Add(null);
+                    tableInfo?.BaseTypeName.Add(result.BaseTypeName);
+                    break;
+                }
+            }
+        }
+
+        private static IArrowType GetArrowType(int columnTypeId, string typeName)
+        {
+            switch (columnTypeId)
+            {
+                case (int)ColumnTypeId.BOOLEAN:
+                    return BooleanType.Default;
+                case (int)ColumnTypeId.TINYINT:
+                    return Int8Type.Default;
+                case (int)ColumnTypeId.SMALLINT:
+                    return Int16Type.Default;
+                case (int)ColumnTypeId.INTEGER:
+                    return Int32Type.Default;
+                case (int)ColumnTypeId.BIGINT:
+                    return Int64Type.Default;
+                case (int)ColumnTypeId.FLOAT:
+                case (int)ColumnTypeId.REAL:
+                    return FloatType.Default;
+                case (int)ColumnTypeId.DOUBLE:
+                    return DoubleType.Default;
+                case (int)ColumnTypeId.VARCHAR:
+                case (int)ColumnTypeId.NVARCHAR:
+                case (int)ColumnTypeId.LONGVARCHAR:
+                case (int)ColumnTypeId.LONGNVARCHAR:
+                    return StringType.Default;
+                case (int)ColumnTypeId.TIMESTAMP:
+                    return new TimestampType(TimeUnit.Microsecond, timezone: (string?)null);
+                case (int)ColumnTypeId.BINARY:
+                case (int)ColumnTypeId.VARBINARY:
+                case (int)ColumnTypeId.LONGVARBINARY:
+                    return BinaryType.Default;
+                case (int)ColumnTypeId.DATE:
+                    return Date32Type.Default;
+                case (int)ColumnTypeId.CHAR:
+                case (int)ColumnTypeId.NCHAR:
+                    return StringType.Default;
+                case (int)ColumnTypeId.DECIMAL:
+                case (int)ColumnTypeId.NUMERIC:
+                    // Note: parsing the type name for SQL DECIMAL types as the precision and scale values
+                    // are not returned in the Thrift call to GetColumns
+                    return SqlTypeNameParser<SqlDecimalParserResult>
+                        .Parse(typeName, columnTypeId)
+                        .Decimal128Type;
+                case (int)ColumnTypeId.NULL:
+                    return NullType.Default;
+                case (int)ColumnTypeId.ARRAY:
+                case (int)ColumnTypeId.JAVA_OBJECT:
+                case (int)ColumnTypeId.STRUCT:
+                    return StringType.Default;
+                default:
+                    throw new NotImplementedException($"Column type id: {columnTypeId} is not supported.");
+            }
+        }
+
+        private static StructArray GetDbSchemas(
+            GetObjectsDepth depth,
+            Dictionary<string, Dictionary<string, TableInfo>> schemaMap)
+        {
+            StringArray.Builder dbSchemaNameBuilder = new StringArray.Builder();
+            List<IArrowArray?> dbSchemaTablesValues = new List<IArrowArray?>();
+            ArrowBuffer.BitmapBuilder nullBitmapBuffer = new ArrowBuffer.BitmapBuilder();
+            int length = 0;
+
+            foreach (KeyValuePair<string, Dictionary<string, TableInfo>> schemaEntry in schemaMap)
+            {
+                dbSchemaNameBuilder.Append(schemaEntry.Key);
+                length++;
+                nullBitmapBuffer.Append(true);
+
+                if (depth == GetObjectsDepth.DbSchemas)
+                {
+                    dbSchemaTablesValues.Add(null);
+                }
+                else
+                {
+                    dbSchemaTablesValues.Add(GetTableSchemas(
+                        depth, schemaEntry.Value));
+                }
+            }
+
+            IReadOnlyList<Field> schema = StandardSchemas.DbSchemaSchema;
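+            // DbSchemaSchema is the ADBC standard layout for db_schema entries; Validate below
+            // is expected to check the built arrays against that layout before they are wrapped.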
IReadOnlyList dataArrays = schema.Validate( + new List + { + dbSchemaNameBuilder.Build(), + dbSchemaTablesValues.BuildListArrayForType(new StructType(StandardSchemas.TableSchema)), + }); + + return new StructArray( + new StructType(schema), + length, + dataArrays, + nullBitmapBuffer.Build()); + } + + private static StructArray GetTableSchemas( + GetObjectsDepth depth, + Dictionary tableMap) + { + StringArray.Builder tableNameBuilder = new StringArray.Builder(); + StringArray.Builder tableTypeBuilder = new StringArray.Builder(); + List tableColumnsValues = new List(); + List tableConstraintsValues = new List(); + ArrowBuffer.BitmapBuilder nullBitmapBuffer = new ArrowBuffer.BitmapBuilder(); + int length = 0; + + + foreach (KeyValuePair tableEntry in tableMap) + { + tableNameBuilder.Append(tableEntry.Key); + tableTypeBuilder.Append(tableEntry.Value.Type); + nullBitmapBuffer.Append(true); + length++; + + + tableConstraintsValues.Add(null); + + + if (depth == GetObjectsDepth.Tables) + { + tableColumnsValues.Add(null); + } + else + { + tableColumnsValues.Add(GetColumnSchema(tableEntry.Value)); + } + } + + + IReadOnlyList schema = StandardSchemas.TableSchema; + IReadOnlyList dataArrays = schema.Validate( + new List + { + tableNameBuilder.Build(), + tableTypeBuilder.Build(), + tableColumnsValues.BuildListArrayForType(new StructType(StandardSchemas.ColumnSchema)), + tableConstraintsValues.BuildListArrayForType( new StructType(StandardSchemas.ConstraintSchema)) + }); + + return new StructArray( + new StructType(schema), + length, + dataArrays, + nullBitmapBuffer.Build()); + } + + private static StructArray GetColumnSchema(TableInfo tableInfo) + { + StringArray.Builder columnNameBuilder = new StringArray.Builder(); + Int32Array.Builder ordinalPositionBuilder = new Int32Array.Builder(); + StringArray.Builder remarksBuilder = new StringArray.Builder(); + Int16Array.Builder xdbcDataTypeBuilder = new Int16Array.Builder(); + StringArray.Builder xdbcTypeNameBuilder = new StringArray.Builder(); + Int32Array.Builder xdbcColumnSizeBuilder = new Int32Array.Builder(); + Int16Array.Builder xdbcDecimalDigitsBuilder = new Int16Array.Builder(); + Int16Array.Builder xdbcNumPrecRadixBuilder = new Int16Array.Builder(); + Int16Array.Builder xdbcNullableBuilder = new Int16Array.Builder(); + StringArray.Builder xdbcColumnDefBuilder = new StringArray.Builder(); + Int16Array.Builder xdbcSqlDataTypeBuilder = new Int16Array.Builder(); + Int16Array.Builder xdbcDatetimeSubBuilder = new Int16Array.Builder(); + Int32Array.Builder xdbcCharOctetLengthBuilder = new Int32Array.Builder(); + StringArray.Builder xdbcIsNullableBuilder = new StringArray.Builder(); + StringArray.Builder xdbcScopeCatalogBuilder = new StringArray.Builder(); + StringArray.Builder xdbcScopeSchemaBuilder = new StringArray.Builder(); + StringArray.Builder xdbcScopeTableBuilder = new StringArray.Builder(); + BooleanArray.Builder xdbcIsAutoincrementBuilder = new BooleanArray.Builder(); + BooleanArray.Builder xdbcIsGeneratedcolumnBuilder = new BooleanArray.Builder(); + ArrowBuffer.BitmapBuilder nullBitmapBuffer = new ArrowBuffer.BitmapBuilder(); + int length = 0; + + + for (int i = 0; i < tableInfo.ColumnName.Count; i++) + { + columnNameBuilder.Append(tableInfo.ColumnName[i]); + ordinalPositionBuilder.Append(tableInfo.OrdinalPosition[i]); + // Use the "remarks" field to store the original type name value + remarksBuilder.Append(tableInfo.TypeName[i]); + xdbcColumnSizeBuilder.Append(tableInfo.Precision[i]); + xdbcDecimalDigitsBuilder.Append(tableInfo.Scale[i]); + 
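+                // Fields not supplied by the Thrift GetColumns response (num_prec_radix,
+                // datetime_sub, char_octet_length and the scope_* columns) are appended as null below.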
xdbcDataTypeBuilder.Append(tableInfo.ColType[i]); + // Just the base type name without precision or scale clause + xdbcTypeNameBuilder.Append(tableInfo.BaseTypeName[i]); + xdbcNumPrecRadixBuilder.AppendNull(); + xdbcNullableBuilder.Append(tableInfo.Nullable[i]); + xdbcColumnDefBuilder.Append(tableInfo.ColumnDefault[i]); + xdbcSqlDataTypeBuilder.Append(tableInfo.ColType[i]); + xdbcDatetimeSubBuilder.AppendNull(); + xdbcCharOctetLengthBuilder.AppendNull(); + xdbcIsNullableBuilder.Append(tableInfo.IsNullable[i]); + xdbcScopeCatalogBuilder.AppendNull(); + xdbcScopeSchemaBuilder.AppendNull(); + xdbcScopeTableBuilder.AppendNull(); + xdbcIsAutoincrementBuilder.Append(tableInfo.IsAutoIncrement[i]); + xdbcIsGeneratedcolumnBuilder.Append(true); + nullBitmapBuffer.Append(true); + length++; + } + + IReadOnlyList schema = StandardSchemas.ColumnSchema; + IReadOnlyList dataArrays = schema.Validate( + new List + { + columnNameBuilder.Build(), + ordinalPositionBuilder.Build(), + remarksBuilder.Build(), + xdbcDataTypeBuilder.Build(), + xdbcTypeNameBuilder.Build(), + xdbcColumnSizeBuilder.Build(), + xdbcDecimalDigitsBuilder.Build(), + xdbcNumPrecRadixBuilder.Build(), + xdbcNullableBuilder.Build(), + xdbcColumnDefBuilder.Build(), + xdbcSqlDataTypeBuilder.Build(), + xdbcDatetimeSubBuilder.Build(), + xdbcCharOctetLengthBuilder.Build(), + xdbcIsNullableBuilder.Build(), + xdbcScopeCatalogBuilder.Build(), + xdbcScopeSchemaBuilder.Build(), + xdbcScopeTableBuilder.Build(), + xdbcIsAutoincrementBuilder.Build(), + xdbcIsGeneratedcolumnBuilder.Build() + }); + + return new StructArray( + new StructType(schema), + length, + dataArrays, + nullBitmapBuffer.Build()); + } + + private static string PatternToRegEx(string? pattern) + { + if (pattern == null) + return ".*"; + + StringBuilder builder = new StringBuilder("(?i)^"); + string convertedPattern = pattern.Replace("_", ".").Replace("%", ".*"); + builder.Append(convertedPattern); + builder.Append('$'); + + return builder.ToString(); + } + + + private static string GetProductVersion() + { + FileVersionInfo fileVersionInfo = FileVersionInfo.GetVersionInfo(Assembly.GetExecutingAssembly().Location); + return fileVersionInfo.ProductVersion ?? ProductVersionDefault; + } + + protected static Uri GetBaseAddress(string? uri, string? hostName, string? path, string? port) + { + // Uri property takes precedent. + if (!string.IsNullOrWhiteSpace(uri)) + { + var uriValue = new Uri(uri); + if (uriValue.Scheme != Uri.UriSchemeHttp && uriValue.Scheme != Uri.UriSchemeHttps) + throw new ArgumentOutOfRangeException( + AdbcOptions.Uri, + uri, + $"Unsupported scheme '{uriValue.Scheme}'"); + return uriValue; + } + + bool isPortSet = !string.IsNullOrEmpty(port); + bool isValidPortNumber = int.TryParse(port, out int portNumber) && portNumber > 0; + bool isDefaultHttpsPort = !isPortSet || (isValidPortNumber && portNumber == 443); + string uriScheme = isDefaultHttpsPort ? 
Uri.UriSchemeHttps : Uri.UriSchemeHttp;
+            int uriPort;
+            if (!isPortSet)
+                uriPort = -1;
+            else if (isValidPortNumber)
+                uriPort = portNumber;
+            else
+                throw new ArgumentOutOfRangeException(nameof(port), portNumber, "Port number is not in a valid range.");
+
+            Uri baseAddress = new UriBuilder(uriScheme, hostName, uriPort, path).Uri;
+            return baseAddress;
+        }
+
+        protected abstract void ValidateConnection();
+        protected abstract void ValidateAuthentication();
+        protected abstract void ValidateOptions();
+
+        protected abstract Task<TRowSet> GetRowSetAsync(TGetTableTypesResp response);
+        protected abstract Task<TRowSet> GetRowSetAsync(TGetColumnsResp response);
+        protected abstract Task<TRowSet> GetRowSetAsync(TGetTablesResp response);
+        protected abstract Task<TRowSet> GetRowSetAsync(TGetCatalogsResp getCatalogsResp);
+        protected abstract Task<TRowSet> GetRowSetAsync(TGetSchemasResp getSchemasResp);
+        protected abstract Task<TGetResultSetMetadataResp> GetResultSetMetadataAsync(TGetSchemasResp response);
+        protected abstract Task<TGetResultSetMetadataResp> GetResultSetMetadataAsync(TGetCatalogsResp response);
+        protected abstract Task<TGetResultSetMetadataResp> GetResultSetMetadataAsync(TGetColumnsResp response);
+        protected abstract Task<TGetResultSetMetadataResp> GetResultSetMetadataAsync(TGetTablesResp response);
+
+        internal abstract SparkServerType ServerType { get; }
+
+        internal struct TableInfo(string type)
+        {
+            public string Type { get; } = type;
+
+            public List<string> ColumnName { get; } = new();
+
+            public List<short> ColType { get; } = new();
+
+            public List<string> BaseTypeName { get; } = new();
+
+            public List<string> TypeName { get; } = new();
+
+            public List<short> Nullable { get; } = new();
+
+            public List<int?> Precision { get; } = new();
+
+            public List<short?> Scale { get; } = new();
+
+            public List<int> OrdinalPosition { get; } = new();
+
+            public List<string> ColumnDefault { get; } = new();
+
+            public List<string> IsNullable { get; } = new();
+
+            public List<bool> IsAutoIncrement { get; } = new();
+        }
+
+        internal class SparkInfoArrowStream : IArrowArrayStream
+        {
+            private Schema schema;
+            private RecordBatch? batch;
+
+            public SparkInfoArrowStream(Schema schema, IReadOnlyList<IArrowArray> data)
+            {
+                this.schema = schema;
+                this.batch = new RecordBatch(schema, data, data[0].Length);
+            }
+
+            public Schema Schema { get { return this.schema; } }
+
+            public ValueTask<RecordBatch?> ReadNextRecordBatchAsync(CancellationToken cancellationToken = default)
+            {
+                RecordBatch? batch = this.batch;
+                this.batch = null;
+                return new ValueTask<RecordBatch?>(batch);
+            }
+
+            public void Dispose()
+            {
+                this.batch?.Dispose();
+                this.batch = null;
+            }
+        }
+    }
+}
diff --git a/csharp/test/Drivers/Apache/Spark/SparkConnectionTest.cs b/csharp/test/Drivers/Apache/Spark/SparkConnectionTest.cs
index 605693fad4..2f448b6df0 100644
--- a/csharp/test/Drivers/Apache/Spark/SparkConnectionTest.cs
+++ b/csharp/test/Drivers/Apache/Spark/SparkConnectionTest.cs
@@ -1,298 +1,298 @@
-/*
-* Licensed to the Apache Software Foundation (ASF) under one or more
-* contributor license agreements. See the NOTICE file distributed with
-* this work for additional information regarding copyright ownership.
-* The ASF licenses this file to You under the Apache License, Version 2.0
-* (the "License"); you may not use this file except in compliance with
-* the License. You may obtain a copy of the License at
-*
-* http://www.apache.org/licenses/LICENSE-2.0
-*
-* Unless required by applicable law or agreed to in writing, software
-* distributed under the License is distributed on an "AS IS" BASIS,
-* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-* See the License for the specific language governing permissions and
-* limitations under the License.
-*/ - -using System; -using System.Collections.Generic; -using System.Globalization; -using System.Net; -using System.Threading.Tasks; -using Apache.Arrow.Adbc.Drivers.Apache.Spark; -using Thrift.Protocol.Entities; -using Thrift.Transport; -using Xunit; -using Xunit.Abstractions; - -namespace Apache.Arrow.Adbc.Tests.Drivers.Apache.Spark -{ - /// - /// Class for testing the Spark ADBC connection tests. - /// - public class SparkConnectionTest : TestBase - { - public SparkConnectionTest(ITestOutputHelper? outputHelper) : base(outputHelper, new SparkTestEnvironment.Factory()) - { - Skip.IfNot(Utils.CanExecuteTestConfig(TestConfigVariable)); - } - - /// - /// Validates database can detect invalid connection parameter combinations. - /// - [SkippableTheory] - [ClassData(typeof(InvalidConnectionParametersTestData))] - internal void CanDetectConnectionParameterErrors(ParametersWithExceptions test) - { - AdbcDriver driver = NewDriver; - AdbcDatabase database = driver.Open(test.Parameters); - Exception exeption = Assert.Throws(test.ExceptionType, () => database.Connect(test.Parameters)); - OutputHelper?.WriteLine(exeption.Message); - } - - [SkippableTheory] - [InlineData(-1, typeof(TimeoutException))] - [InlineData(10, typeof(TimeoutException))] - [InlineData(30000, null)] - [InlineData(null, null)] - [InlineData(-1, null)] - public void ConnectionTimeoutTest(int? connectTimeoutMilliseconds, Type? exceptionType) - { - SparkTestConfiguration testConfiguration = (SparkTestConfiguration)TestConfiguration.Clone(); - - if (connectTimeoutMilliseconds.HasValue) - testConfiguration.ConnectTimeoutMilliseconds = connectTimeoutMilliseconds.Value.ToString(); - - OutputHelper?.WriteLine($"ConnectTimeoutMilliseconds: {testConfiguration.ConnectTimeoutMilliseconds}. ShouldSucceed: {exceptionType == null}"); - - try - { - NewConnection(testConfiguration); - } - catch(AggregateException aex) - { - if (exceptionType != null) - { - Assert.IsType(exceptionType, aex.InnerException); - } - else - { - throw; - } - } - } - - /// - /// Tests the various metadata calls on a SparkConnection - /// - /// - [SkippableTheory] - [ClassData(typeof(MetadataTimeoutTestData))] - internal void MetadataTimeoutTest(MetadataWithExceptions metadataWithException) - { - SparkTestConfiguration testConfiguration = (SparkTestConfiguration)TestConfiguration.Clone(); - - if (metadataWithException.QueryTimeoutSeconds.HasValue) - testConfiguration.QueryTimeoutSeconds = metadataWithException.QueryTimeoutSeconds.Value.ToString(); - - OutputHelper?.WriteLine($"Action: {metadataWithException.ActionName}. QueryTimeoutSeconds: {testConfiguration.QueryTimeoutSeconds}. ShouldSucceed: {metadataWithException.ExceptionType == null}"); - - try - { - metadataWithException.MetadataAction(testConfiguration); - } - catch (AggregateException aex) - { - if (metadataWithException.ExceptionType != null) - { - if (metadataWithException.AlternateExceptionType != null && aex.InnerException?.GetType() != metadataWithException.ExceptionType) - { - Assert.IsType(metadataWithException.AlternateExceptionType, aex.InnerException); - } - else - { - Assert.IsType(metadataWithException.ExceptionType, aex.InnerException); - } - } - else - { - throw; - } - } - } - - internal class MetadataWithExceptions - { - public MetadataWithExceptions(int? queryTimeoutSeconds, string actionName, Action action, Type? exceptionType, Type? 
alternateExceptionType) - { - QueryTimeoutSeconds = queryTimeoutSeconds; - ActionName = actionName; - MetadataAction = action; - ExceptionType = exceptionType; - AlternateExceptionType = alternateExceptionType; - } - - /// - /// If null, uses the default timeout. - /// - public int? QueryTimeoutSeconds { get; } - - public string ActionName { get; } - - /// - /// If null, expected to succeed. - /// - public Type? ExceptionType { get; } - - /// - /// Sometimes you can expect one but may get another. - /// For example, on GetObjectsAll, sometimes a TTransportException is expected but a TaskCanceledException is received during the test. - /// - public Type? AlternateExceptionType { get; } - - /// - /// The metadata action to perform. - /// - public Action MetadataAction { get; } - } - - /// - /// Used for testing timeouts on metadata calls. - /// - internal class MetadataTimeoutTestData : TheoryData - { - public MetadataTimeoutTestData() - { - SparkConnectionTest sparkConnectionTest = new SparkConnectionTest(null); - - Action getObjectsAll = (testConfiguration) => - { - AdbcConnection cn = sparkConnectionTest.NewConnection(testConfiguration); - cn.GetObjects(AdbcConnection.GetObjectsDepth.All, testConfiguration.Metadata.Catalog, testConfiguration.Metadata.Schema, testConfiguration.Metadata.Table, null, null); - }; - - Action getObjectsCatalogs = (testConfiguration) => - { - AdbcConnection cn = sparkConnectionTest.NewConnection(testConfiguration); - cn.GetObjects(AdbcConnection.GetObjectsDepth.Catalogs, testConfiguration.Metadata.Catalog, testConfiguration.Metadata.Schema, testConfiguration.Metadata.Schema, null, null); - }; - - Action getObjectsDbSchemas = (testConfiguration) => - { - AdbcConnection cn = sparkConnectionTest.NewConnection(testConfiguration); - cn.GetObjects(AdbcConnection.GetObjectsDepth.DbSchemas, testConfiguration.Metadata.Catalog, testConfiguration.Metadata.Schema, testConfiguration.Metadata.Schema, null, null); - }; - - Action getObjectsTables = (testConfiguration) => - { - AdbcConnection cn = sparkConnectionTest.NewConnection(testConfiguration); - cn.GetObjects(AdbcConnection.GetObjectsDepth.Tables, testConfiguration.Metadata.Catalog, testConfiguration.Metadata.Schema, testConfiguration.Metadata.Schema, null, null); - }; - - AddAction("getObjectsAll", getObjectsAll, new List() { null, typeof(TaskCanceledException), null, null, null } ); - AddAction("getObjectsCatalogs", getObjectsCatalogs); - AddAction("getObjectsDbSchemas", getObjectsDbSchemas); - AddAction("getObjectsTables", getObjectsTables); - - Action getTableTypes = (testConfiguration) => - { - AdbcConnection cn = sparkConnectionTest.NewConnection(testConfiguration); - cn.GetTableTypes(); - }; - - AddAction("getTableTypes", getTableTypes); - - Action getTableSchema = (testConfiguration) => - { - AdbcConnection cn = sparkConnectionTest.NewConnection(testConfiguration); - cn.GetTableSchema(testConfiguration.Metadata.Catalog, testConfiguration.Metadata.Schema, testConfiguration.Metadata.Table); - }; - - AddAction("getTableSchema", getTableSchema); - } - - private void AddAction(string name, Action action, List? alternateExceptions = null) - { - List expectedExceptions = new List() - { - null, // QueryTimeout = -1 - typeof(TTransportException), // QueryTimeout = 1 - typeof(TimeoutException), // QueryTimeout = 10 - null, // QueryTimeout = default - null // QueryTimeout = 300 - }; - - AddAction(name, action, expectedExceptions, alternateExceptions); - } - - /// - /// Adds the action with the default timeouts. 
- ///
- /// The action to perform.
- /// The expected exceptions.
- ///
- /// For List the position is based on the behavior when:
- /// [0] QueryTimeout = -1
- /// [1] QueryTimeout = 1
- /// [2] QueryTimeout = 10
- /// [3] QueryTimeout = default
- /// [4] QueryTimeout = 300
- ///
- private void AddAction(string name, Action action, List expectedExceptions, List? alternateExceptions)
- {
- Assert.True(expectedExceptions.Count == 5);
-
- Add(new(-1, name, action, expectedExceptions[0], alternateExceptions?[0]));
- Add(new(1, name, action, expectedExceptions[1], alternateExceptions?[1]));
- Add(new(10, name, action, expectedExceptions[2], alternateExceptions?[2]));
- Add(new(null, name, action, expectedExceptions[3], alternateExceptions?[3]));
- Add(new(300, name, action, expectedExceptions[4], alternateExceptions?[4]));
- }
- }
-
- internal class ParametersWithExceptions
- {
- public ParametersWithExceptions(Dictionary parameters, Type exceptionType)
- {
- Parameters = parameters;
- ExceptionType = exceptionType;
- }
-
- public IReadOnlyDictionary Parameters { get; }
- public Type ExceptionType { get; }
- }
-
- internal class InvalidConnectionParametersTestData : TheoryData
- {
- public InvalidConnectionParametersTestData()
- {
- Add(new([], typeof(ArgumentException)));
- Add(new(new() { [SparkParameters.Type] = " " }, typeof(ArgumentException)));
- Add(new(new() { [SparkParameters.Type] = "xxx" }, typeof(ArgumentOutOfRangeException)));
- Add(new(new() { [SparkParameters.Type] = SparkServerTypeConstants.Standard }, typeof(ArgumentOutOfRangeException)));
- Add(new(new() { [SparkParameters.Type] = SparkServerTypeConstants.Http, [SparkParameters.HostName] = " " }, typeof(ArgumentException)));
- Add(new(new() { [SparkParameters.Type] = SparkServerTypeConstants.Http, [SparkParameters.HostName] = "invalid!server.com" }, typeof(ArgumentException)));
- Add(new(new() { [SparkParameters.Type] = SparkServerTypeConstants.Http, [SparkParameters.HostName] = "http://valid.server.com" }, typeof(ArgumentException)));
- Add(new(new() { [SparkParameters.Type] = SparkServerTypeConstants.Http, [SparkParameters.HostName] = "valid.server.com" }, typeof(ArgumentException)));
- Add(new(new() { [SparkParameters.Type] = SparkServerTypeConstants.Http, [SparkParameters.HostName] = "valid.server.com", [SparkParameters.AuthType] = $"unknown_auth_type" }, typeof(ArgumentOutOfRangeException)));
- Add(new(new() { [SparkParameters.Type] = SparkServerTypeConstants.Http, [SparkParameters.HostName] = "valid.server.com", [SparkParameters.AuthType] = $"{SparkAuthTypeConstants.Basic}" }, typeof(ArgumentException)));
- Add(new(new() { [SparkParameters.Type] = SparkServerTypeConstants.Http, [SparkParameters.HostName] = "valid.server.com", [SparkParameters.AuthType] = $"{SparkAuthTypeConstants.Token}" }, typeof(ArgumentException)));
- Add(new(new() { [SparkParameters.Type] = SparkServerTypeConstants.Http, [SparkParameters.HostName] = "valid.server.com", [SparkParameters.AuthType] = $"{SparkAuthTypeConstants.Basic}", [SparkParameters.Token] = "abcdef" }, typeof(ArgumentException)));
- Add(new(new() { [SparkParameters.Type] = SparkServerTypeConstants.Http, [SparkParameters.HostName] = "valid.server.com", [SparkParameters.AuthType] = $"{SparkAuthTypeConstants.Token}", [AdbcOptions.Username] = "user", [AdbcOptions.Password] = "myPassword" }, typeof(ArgumentException)));
- Add(new(new() { [SparkParameters.Type] = SparkServerTypeConstants.Http, [SparkParameters.HostName] = "valid.server.com", [AdbcOptions.Username] = "user" }, typeof(ArgumentException)));
- Add(new(new() { [SparkParameters.Type] = SparkServerTypeConstants.Http, [SparkParameters.HostName] = "valid.server.com", [AdbcOptions.Password] = "myPassword" }, typeof(ArgumentException)));
- Add(new(new() { [SparkParameters.Type] = SparkServerTypeConstants.Databricks, [SparkParameters.HostName] = "valid.server.com", [SparkParameters.Token] = "abcdef", [SparkParameters.Port] = "-1" }, typeof(ArgumentOutOfRangeException)));
- Add(new(new() { [SparkParameters.Type] = SparkServerTypeConstants.Databricks, [SparkParameters.HostName] = "valid.server.com", [SparkParameters.Token] = "abcdef", [SparkParameters.Port] = IPEndPoint.MinPort.ToString(CultureInfo.InvariantCulture) }, typeof(ArgumentOutOfRangeException)));
- Add(new(new() { [SparkParameters.Type] = SparkServerTypeConstants.Databricks, [SparkParameters.HostName] = "valid.server.com", [SparkParameters.Token] = "abcdef", [SparkParameters.Port] = (IPEndPoint.MaxPort + 1).ToString(CultureInfo.InvariantCulture) }, typeof(ArgumentOutOfRangeException)));
- Add(new(new() { [SparkParameters.Type] = SparkServerTypeConstants.Databricks, [SparkParameters.HostName] = "valid.server.com", [SparkParameters.Token] = "abcdef", [AdbcOptions.Uri] = "httpxxz://hostname.com" }, typeof(ArgumentOutOfRangeException)));
- Add(new(new() { [SparkParameters.Type] = SparkServerTypeConstants.Databricks, [SparkParameters.HostName] = "valid.server.com", [SparkParameters.Token] = "abcdef", [AdbcOptions.Uri] = "http-//hostname.com" }, typeof(UriFormatException)));
- Add(new(new() { [SparkParameters.Type] = SparkServerTypeConstants.Databricks, [SparkParameters.HostName] = "valid.server.com", [SparkParameters.Token] = "abcdef", [AdbcOptions.Uri] = "httpxxz://hostname.com:1234567890" }, typeof(UriFormatException)));
- Add(new(new() { [SparkParameters.Type] = SparkServerTypeConstants.Http, [SparkParameters.HostName] = "valid.server.com", [AdbcOptions.Username] = "user", [AdbcOptions.Password] = "myPassword", [SparkParameters.ConnectTimeoutMilliseconds] = "0" }, typeof(ArgumentOutOfRangeException)));
- Add(new(new() { [SparkParameters.Type] = SparkServerTypeConstants.Http, [SparkParameters.HostName] = "valid.server.com", [AdbcOptions.Username] = "user", [AdbcOptions.Password] = "myPassword", [SparkParameters.ConnectTimeoutMilliseconds] = ((long)int.MaxValue + 1).ToString() }, typeof(ArgumentOutOfRangeException)));
- Add(new(new() { [SparkParameters.Type] = SparkServerTypeConstants.Http, [SparkParameters.HostName] = "valid.server.com", [AdbcOptions.Username] = "user", [AdbcOptions.Password] = "myPassword", [SparkParameters.ConnectTimeoutMilliseconds] = "non-numeric" }, typeof(ArgumentOutOfRangeException)));
- Add(new(new() { [SparkParameters.Type] = SparkServerTypeConstants.Http, [SparkParameters.HostName] = "valid.server.com", [AdbcOptions.Username] = "user", [AdbcOptions.Password] = "myPassword", [SparkParameters.ConnectTimeoutMilliseconds] = "" }, typeof(ArgumentOutOfRangeException)));
- }
- }
- }
-}
+/*
+* Licensed to the Apache Software Foundation (ASF) under one or more
+* contributor license agreements. See the NOTICE file distributed with
+* this work for additional information regarding copyright ownership.
+* The ASF licenses this file to You under the Apache License, Version 2.0
+* (the "License"); you may not use this file except in compliance with
+* the License. You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*/
+
+using System;
+using System.Collections.Generic;
+using System.Globalization;
+using System.Net;
+using System.Threading.Tasks;
+using Apache.Arrow.Adbc.Drivers.Apache.Spark;
+using Thrift.Protocol.Entities;
+using Thrift.Transport;
+using Xunit;
+using Xunit.Abstractions;
+
+namespace Apache.Arrow.Adbc.Tests.Drivers.Apache.Spark
+{
+ ///
+ /// Class for testing the Spark ADBC connection.
+ ///
+ public class SparkConnectionTest : TestBase
+ {
+ public SparkConnectionTest(ITestOutputHelper? outputHelper) : base(outputHelper, new SparkTestEnvironment.Factory())
+ {
+ Skip.IfNot(Utils.CanExecuteTestConfig(TestConfigVariable));
+ }
+
+ ///
+ /// Validates that the database can detect invalid connection parameter combinations.
+ ///
+ [SkippableTheory]
+ [ClassData(typeof(InvalidConnectionParametersTestData))]
+ internal void CanDetectConnectionParameterErrors(ParametersWithExceptions test)
+ {
+ AdbcDriver driver = NewDriver;
+ AdbcDatabase database = driver.Open(test.Parameters);
+ Exception exception = Assert.Throws(test.ExceptionType, () => database.Connect(test.Parameters));
+ OutputHelper?.WriteLine(exception.Message);
+ }
+
+ [SkippableTheory]
+ [InlineData(-1, typeof(TimeoutException))]
+ [InlineData(10, typeof(TimeoutException))]
+ [InlineData(30000, null)]
+ [InlineData(null, null)]
+ [InlineData(-1, null)]
+ public void ConnectionTimeoutTest(int? connectTimeoutMilliseconds, Type? exceptionType)
+ {
+ SparkTestConfiguration testConfiguration = (SparkTestConfiguration)TestConfiguration.Clone();
+
+ if (connectTimeoutMilliseconds.HasValue)
+ testConfiguration.ConnectTimeoutMilliseconds = connectTimeoutMilliseconds.Value.ToString();
+
+ OutputHelper?.WriteLine($"ConnectTimeoutMilliseconds: {testConfiguration.ConnectTimeoutMilliseconds}. ShouldSucceed: {exceptionType == null}");
+
+ try
+ {
+ NewConnection(testConfiguration);
+ }
+ catch (AggregateException aex)
+ {
+ if (exceptionType != null)
+ {
+ Assert.IsType(exceptionType, aex.InnerException);
+ }
+ else
+ {
+ throw;
+ }
+ }
+ }
+
+ ///
+ /// Tests the various metadata calls on a SparkConnection.
+ ///
+ ///
+ [SkippableTheory]
+ [ClassData(typeof(MetadataTimeoutTestData))]
+ internal void MetadataTimeoutTest(MetadataWithExceptions metadataWithException)
+ {
+ SparkTestConfiguration testConfiguration = (SparkTestConfiguration)TestConfiguration.Clone();
+
+ if (metadataWithException.QueryTimeoutSeconds.HasValue)
+ testConfiguration.QueryTimeoutSeconds = metadataWithException.QueryTimeoutSeconds.Value.ToString();
+
+ OutputHelper?.WriteLine($"Action: {metadataWithException.ActionName}. QueryTimeoutSeconds: {testConfiguration.QueryTimeoutSeconds}. ShouldSucceed: {metadataWithException.ExceptionType == null}");
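+
+ // A null ExceptionType means the metadata call is expected to complete within the configured timeout.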
+ try
+ {
+ metadataWithException.MetadataAction(testConfiguration);
+ }
+ catch (AggregateException aex)
+ {
+ if (metadataWithException.ExceptionType != null)
+ {
+ if (metadataWithException.AlternateExceptionType != null && aex.InnerException?.GetType() != metadataWithException.ExceptionType)
+ {
+ Assert.IsType(metadataWithException.AlternateExceptionType, aex.InnerException);
+ }
+ else
+ {
+ Assert.IsType(metadataWithException.ExceptionType, aex.InnerException);
+ }
+ }
+ else
+ {
+ throw;
+ }
+ }
+ }
+
+ internal class MetadataWithExceptions
+ {
+ public MetadataWithExceptions(int? queryTimeoutSeconds, string actionName, Action action, Type? exceptionType, Type? alternateExceptionType)
+ {
+ QueryTimeoutSeconds = queryTimeoutSeconds;
+ ActionName = actionName;
+ MetadataAction = action;
+ ExceptionType = exceptionType;
+ AlternateExceptionType = alternateExceptionType;
+ }
+
+ ///
+ /// If null, uses the default timeout.
+ ///
+ public int? QueryTimeoutSeconds { get; }
+
+ public string ActionName { get; }
+
+ ///
+ /// If null, expected to succeed.
+ ///
+ public Type? ExceptionType { get; }
+
+ ///
+ /// Sometimes one exception type is expected, but a different one is received.
+ /// For example, on GetObjectsAll, sometimes a TTransportException is expected but a TaskCanceledException is received during the test.
+ ///
+ public Type? AlternateExceptionType { get; }
+
+ ///
+ /// The metadata action to perform.
+ ///
+ public Action MetadataAction { get; }
+ }
+
+ ///
+ /// Used for testing timeouts on metadata calls.
+ ///
+ internal class MetadataTimeoutTestData : TheoryData
+ {
+ public MetadataTimeoutTestData()
+ {
+ SparkConnectionTest sparkConnectionTest = new SparkConnectionTest(null);
+
+ Action getObjectsAll = (testConfiguration) =>
+ {
+ AdbcConnection cn = sparkConnectionTest.NewConnection(testConfiguration);
+ cn.GetObjects(AdbcConnection.GetObjectsDepth.All, testConfiguration.Metadata.Catalog, testConfiguration.Metadata.Schema, testConfiguration.Metadata.Table, null, null);
+ };
+
+ Action getObjectsCatalogs = (testConfiguration) =>
+ {
+ AdbcConnection cn = sparkConnectionTest.NewConnection(testConfiguration);
+ cn.GetObjects(AdbcConnection.GetObjectsDepth.Catalogs, testConfiguration.Metadata.Catalog, testConfiguration.Metadata.Schema, testConfiguration.Metadata.Schema, null, null);
+ };
+
+ Action getObjectsDbSchemas = (testConfiguration) =>
+ {
+ AdbcConnection cn = sparkConnectionTest.NewConnection(testConfiguration);
+ cn.GetObjects(AdbcConnection.GetObjectsDepth.DbSchemas, testConfiguration.Metadata.Catalog, testConfiguration.Metadata.Schema, testConfiguration.Metadata.Schema, null, null);
+ };
+
+ Action getObjectsTables = (testConfiguration) =>
+ {
+ AdbcConnection cn = sparkConnectionTest.NewConnection(testConfiguration);
+ cn.GetObjects(AdbcConnection.GetObjectsDepth.Tables, testConfiguration.Metadata.Catalog, testConfiguration.Metadata.Schema, testConfiguration.Metadata.Schema, null, null);
+ };
+
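+ // getObjectsAll may intermittently produce TaskCanceledException instead of the expected TTransportException (see AlternateExceptionType above).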
+ AddAction("getObjectsAll", getObjectsAll, new List() { null, typeof(TaskCanceledException), null, null, null });
+ AddAction("getObjectsCatalogs", getObjectsCatalogs);
+ AddAction("getObjectsDbSchemas", getObjectsDbSchemas);
+ AddAction("getObjectsTables", getObjectsTables);
+
+ Action getTableTypes = (testConfiguration) =>
+ {
+ AdbcConnection cn = sparkConnectionTest.NewConnection(testConfiguration);
+ cn.GetTableTypes();
+ };
+
+ AddAction("getTableTypes", getTableTypes);
+
+ Action getTableSchema = (testConfiguration) =>
+ {
+ AdbcConnection cn = sparkConnectionTest.NewConnection(testConfiguration);
+ cn.GetTableSchema(testConfiguration.Metadata.Catalog, testConfiguration.Metadata.Schema, testConfiguration.Metadata.Table);
+ };
+
+ AddAction("getTableSchema", getTableSchema);
+ }
+
+ private void AddAction(string name, Action action, List? alternateExceptions = null)
+ {
+ List expectedExceptions = new List()
+ {
+ null, // QueryTimeout = -1
+ typeof(TTransportException), // QueryTimeout = 1
+ typeof(TimeoutException), // QueryTimeout = 10
+ null, // QueryTimeout = default
+ null // QueryTimeout = 300
+ };
+
+ AddAction(name, action, expectedExceptions, alternateExceptions);
+ }
+
+ ///
+ /// Adds the action with the default timeouts.
+ ///
+ /// The action to perform.
+ /// The expected exceptions.
+ ///
+ /// For List the position is based on the behavior when:
+ /// [0] QueryTimeout = -1
+ /// [1] QueryTimeout = 1
+ /// [2] QueryTimeout = 10
+ /// [3] QueryTimeout = default
+ /// [4] QueryTimeout = 300
+ ///
+ private void AddAction(string name, Action action, List expectedExceptions, List? alternateExceptions)
+ {
+ Assert.True(expectedExceptions.Count == 5);
+
+ Add(new(-1, name, action, expectedExceptions[0], alternateExceptions?[0]));
+ Add(new(1, name, action, expectedExceptions[1], alternateExceptions?[1]));
+ Add(new(10, name, action, expectedExceptions[2], alternateExceptions?[2]));
+ Add(new(null, name, action, expectedExceptions[3], alternateExceptions?[3]));
+ Add(new(300, name, action, expectedExceptions[4], alternateExceptions?[4]));
+ }
+ }
+
+ internal class ParametersWithExceptions
+ {
+ public ParametersWithExceptions(Dictionary parameters, Type exceptionType)
+ {
+ Parameters = parameters;
+ ExceptionType = exceptionType;
+ }
+
+ public IReadOnlyDictionary Parameters { get; }
+ public Type ExceptionType { get; }
+ }
+
+ internal class InvalidConnectionParametersTestData : TheoryData
+ {
+ public InvalidConnectionParametersTestData()
+ {
+ Add(new([], typeof(ArgumentException)));
+ Add(new(new() { [SparkParameters.Type] = " " }, typeof(ArgumentException)));
+ Add(new(new() { [SparkParameters.Type] = "xxx" }, typeof(ArgumentOutOfRangeException)));
+ Add(new(new() { [SparkParameters.Type] = SparkServerTypeConstants.Standard }, typeof(ArgumentOutOfRangeException)));
+ Add(new(new() { [SparkParameters.Type] = SparkServerTypeConstants.Http, [SparkParameters.HostName] = " " }, typeof(ArgumentException)));
+ Add(new(new() { [SparkParameters.Type] = SparkServerTypeConstants.Http, [SparkParameters.HostName] = "invalid!server.com" }, typeof(ArgumentException)));
+ Add(new(new() { [SparkParameters.Type] = SparkServerTypeConstants.Http, [SparkParameters.HostName] = "http://valid.server.com" }, typeof(ArgumentException)));
+ Add(new(new() { [SparkParameters.Type] = SparkServerTypeConstants.Http, [SparkParameters.HostName] = "valid.server.com" }, typeof(ArgumentException)));
+ Add(new(new() { [SparkParameters.Type] = SparkServerTypeConstants.Http, [SparkParameters.HostName] = "valid.server.com", [SparkParameters.AuthType] = $"unknown_auth_type" }, typeof(ArgumentOutOfRangeException)));
+ Add(new(new() { [SparkParameters.Type] = SparkServerTypeConstants.Http, [SparkParameters.HostName] = "valid.server.com", [SparkParameters.AuthType] = $"{SparkAuthTypeConstants.Basic}" }, typeof(ArgumentException)));
+ Add(new(new() { [SparkParameters.Type] = SparkServerTypeConstants.Http, [SparkParameters.HostName] = "valid.server.com", [SparkParameters.AuthType] = $"{SparkAuthTypeConstants.Token}" }, typeof(ArgumentException)));
+ Add(new(new() { [SparkParameters.Type] = SparkServerTypeConstants.Http, [SparkParameters.HostName] = "valid.server.com", [SparkParameters.AuthType] = $"{SparkAuthTypeConstants.Basic}", [SparkParameters.Token] = "abcdef" }, typeof(ArgumentException)));
+ Add(new(new() { [SparkParameters.Type] = SparkServerTypeConstants.Http, [SparkParameters.HostName] = "valid.server.com", [SparkParameters.AuthType] = $"{SparkAuthTypeConstants.Token}", [AdbcOptions.Username] = "user", [AdbcOptions.Password] = "myPassword" }, typeof(ArgumentException)));
+ Add(new(new() { [SparkParameters.Type] = SparkServerTypeConstants.Http, [SparkParameters.HostName] = "valid.server.com", [AdbcOptions.Username] = "user" }, typeof(ArgumentException)));
+ Add(new(new() { [SparkParameters.Type] = SparkServerTypeConstants.Http, [SparkParameters.HostName] = "valid.server.com", [AdbcOptions.Password] = "myPassword" }, typeof(ArgumentException)));
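+ // The Databricks server type additionally validates the port range and the URI scheme/format.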
+ Add(new(new() { [SparkParameters.Type] = SparkServerTypeConstants.Databricks, [SparkParameters.HostName] = "valid.server.com", [SparkParameters.Token] = "abcdef", [SparkParameters.Port] = "-1" }, typeof(ArgumentOutOfRangeException)));
+ Add(new(new() { [SparkParameters.Type] = SparkServerTypeConstants.Databricks, [SparkParameters.HostName] = "valid.server.com", [SparkParameters.Token] = "abcdef", [SparkParameters.Port] = IPEndPoint.MinPort.ToString(CultureInfo.InvariantCulture) }, typeof(ArgumentOutOfRangeException)));
+ Add(new(new() { [SparkParameters.Type] = SparkServerTypeConstants.Databricks, [SparkParameters.HostName] = "valid.server.com", [SparkParameters.Token] = "abcdef", [SparkParameters.Port] = (IPEndPoint.MaxPort + 1).ToString(CultureInfo.InvariantCulture) }, typeof(ArgumentOutOfRangeException)));
+ Add(new(new() { [SparkParameters.Type] = SparkServerTypeConstants.Databricks, [SparkParameters.HostName] = "valid.server.com", [SparkParameters.Token] = "abcdef", [AdbcOptions.Uri] = "httpxxz://hostname.com" }, typeof(ArgumentOutOfRangeException)));
+ Add(new(new() { [SparkParameters.Type] = SparkServerTypeConstants.Databricks, [SparkParameters.HostName] = "valid.server.com", [SparkParameters.Token] = "abcdef", [AdbcOptions.Uri] = "http-//hostname.com" }, typeof(UriFormatException)));
+ Add(new(new() { [SparkParameters.Type] = SparkServerTypeConstants.Databricks, [SparkParameters.HostName] = "valid.server.com", [SparkParameters.Token] = "abcdef", [AdbcOptions.Uri] = "httpxxz://hostname.com:1234567890" }, typeof(UriFormatException)));
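+ // The connect timeout must parse as a positive 32-bit integer; zero, overflowing, non-numeric, and empty values are rejected.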
+ Add(new(new() { [SparkParameters.Type] = SparkServerTypeConstants.Http, [SparkParameters.HostName] = "valid.server.com", [AdbcOptions.Username] = "user", [AdbcOptions.Password] = "myPassword", [SparkParameters.ConnectTimeoutMilliseconds] = "0" }, typeof(ArgumentOutOfRangeException)));
+ Add(new(new() { [SparkParameters.Type] = SparkServerTypeConstants.Http, [SparkParameters.HostName] = "valid.server.com", [AdbcOptions.Username] = "user", [AdbcOptions.Password] = "myPassword", [SparkParameters.ConnectTimeoutMilliseconds] = ((long)int.MaxValue + 1).ToString() }, typeof(ArgumentOutOfRangeException)));
+ Add(new(new() { [SparkParameters.Type] = SparkServerTypeConstants.Http, [SparkParameters.HostName] = "valid.server.com", [AdbcOptions.Username] = "user", [AdbcOptions.Password] = "myPassword", [SparkParameters.ConnectTimeoutMilliseconds] = "non-numeric" }, typeof(ArgumentOutOfRangeException)));
+ Add(new(new() { [SparkParameters.Type] = SparkServerTypeConstants.Http, [SparkParameters.HostName] = "valid.server.com", [AdbcOptions.Username] = "user", [AdbcOptions.Password] = "myPassword", [SparkParameters.ConnectTimeoutMilliseconds] = "" }, typeof(ArgumentOutOfRangeException)));
+ }
+ }
+ }
+}
diff --git a/csharp/test/Drivers/Apache/Spark/StatementTests.cs b/csharp/test/Drivers/Apache/Spark/StatementTests.cs
index 3364e58c6a..2baf40b0d9 100644
--- a/csharp/test/Drivers/Apache/Spark/StatementTests.cs
+++ b/csharp/test/Drivers/Apache/Spark/StatementTests.cs
@@ -1,232 +1,232 @@
-/*
-* Licensed to the Apache Software Foundation (ASF) under one or more
-* contributor license agreements. See the NOTICE file distributed with
-* this work for additional information regarding copyright ownership.
-* The ASF licenses this file to You under the Apache License, Version 2.0
-* (the "License"); you may not use this file except in compliance with
-* the License. You may obtain a copy of the License at
-*
-* http://www.apache.org/licenses/LICENSE-2.0
-*
-* Unless required by applicable law or agreed to in writing, software
-* distributed under the License is distributed on an "AS IS" BASIS,
-* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-* See the License for the specific language governing permissions and
-* limitations under the License.
-*/
-
-using System;
-using System.Collections.Generic;
-using System.Threading.Tasks;
-using Apache.Arrow.Adbc.Drivers.Apache.Spark;
-using Apache.Arrow.Adbc.Tests.Xunit;
-using Thrift.Transport;
-using Xunit;
-using Xunit.Abstractions;
-using static Apache.Arrow.Adbc.Tests.Drivers.Apache.Spark.SparkConnectionTest;
-
-namespace Apache.Arrow.Adbc.Tests.Drivers.Apache.Spark
-{
- ///
- /// Class for testing the Snowflake ADBC driver connection tests.
- ///
- ///
- /// Tests are ordered to ensure data is created for the other
- /// queries to run.
- ///
- [TestCaseOrderer("Apache.Arrow.Adbc.Tests.Xunit.TestOrderer", "Apache.Arrow.Adbc.Tests")]
- public class StatementTests : TestBase
- {
- private static List DefaultTableTypes => new() { "TABLE", "VIEW" };
-
- public StatementTests(ITestOutputHelper? outputHelper) : base(outputHelper, new SparkTestEnvironment.Factory())
- {
- Skip.IfNot(Utils.CanExecuteTestConfig(TestConfigVariable));
- }
-
- ///
- /// Validates if the SetOption handle valid/invalid data correctly for the PollTime option.
- ///
- [SkippableTheory]
- [InlineData("-1", true)]
- [InlineData("zero", true)]
- [InlineData("-2147483648", true)]
- [InlineData("2147483648", true)]
- [InlineData("0")]
- [InlineData("1")]
- [InlineData("2147483647")]
- public void CanSetOptionPollTime(string value, bool throws = false)
- {
- var testConfiguration = TestConfiguration.Clone() as SparkTestConfiguration;
- testConfiguration!.PollTimeMilliseconds = value;
- if (throws)
- {
- Assert.Throws(() => NewConnection(testConfiguration).CreateStatement());
- }
-
- AdbcStatement statement = NewConnection().CreateStatement();
- if (throws)
- {
- Assert.Throws(() => statement.SetOption(SparkStatement.Options.PollTimeMilliseconds, value));
- }
- else
- {
- statement.SetOption(SparkStatement.Options.PollTimeMilliseconds, value);
- }
- }
-
- ///
- /// Validates if the SetOption handle valid/invalid data correctly for the BatchSize option.
- ///
- [SkippableTheory]
- [InlineData("-1", true)]
- [InlineData("one", true)]
- [InlineData("-2147483648", true)]
- [InlineData("2147483648", false)]
- [InlineData("9223372036854775807", false)]
- [InlineData("9223372036854775808", true)]
- [InlineData("0", true)]
- [InlineData("1")]
- [InlineData("2147483647")]
- public void CanSetOptionBatchSize(string value, bool throws = false)
- {
- var testConfiguration = TestConfiguration.Clone() as SparkTestConfiguration;
- testConfiguration!.BatchSize = value;
- if (throws)
- {
- Assert.Throws(() => NewConnection(testConfiguration).CreateStatement());
- }
-
- AdbcStatement statement = NewConnection().CreateStatement();
- if (throws)
- {
- Assert.Throws(() => statement!.SetOption(SparkStatement.Options.BatchSize, value));
- }
- else
- {
- statement.SetOption(SparkStatement.Options.BatchSize, value);
- }
- }
-
- ///
- /// Validates if the SetOption handle valid/invalid data correctly for the QueryTimeout option.
- ///
- [SkippableTheory]
- [InlineData("zero", true)]
- [InlineData("-2147483648", true)]
- [InlineData("2147483648", true)]
- [InlineData("0", true)]
- [InlineData("-1")]
- [InlineData("1")]
- [InlineData("2147483647")]
- public void CanSetOptionQueryTimeout(string value, bool throws = false)
- {
- var testConfiguration = TestConfiguration.Clone() as SparkTestConfiguration;
- testConfiguration!.QueryTimeoutSeconds = value;
- if (throws)
- {
- Assert.Throws(() => NewConnection(testConfiguration).CreateStatement());
- }
-
- AdbcStatement statement = NewConnection().CreateStatement();
- if (throws)
- {
- Assert.Throws(() => statement.SetOption(SparkStatement.Options.QueryTimeoutSeconds, value));
- }
- else
- {
- statement.SetOption(SparkStatement.Options.QueryTimeoutSeconds, value);
- }
- }
-
- [SkippableTheory]
- [ClassData(typeof(StatementTimeoutTestData))]
- internal void StatementTimeoutTest(StatementWithExceptions statementWithExceptions)
- {
- SparkTestConfiguration testConfiguration = (SparkTestConfiguration)TestConfiguration.Clone();
-
- if (statementWithExceptions.QueryTimeoutSeconds.HasValue)
- testConfiguration.QueryTimeoutSeconds = statementWithExceptions.QueryTimeoutSeconds.Value.ToString();
-
- if (!string.IsNullOrEmpty(statementWithExceptions.Query))
- testConfiguration.Query = statementWithExceptions.Query!;
-
- OutputHelper?.WriteLine($"QueryTimeoutSeconds: {testConfiguration.QueryTimeoutSeconds}. ShouldSucceed: {statementWithExceptions.ExceptionType == null}. Query: [{testConfiguration.Query}]");
-
- try
- {
- AdbcStatement st = NewConnection(testConfiguration).CreateStatement();
- st.SqlQuery = testConfiguration.Query;
- QueryResult qr = st.ExecuteQuery();
-
- OutputHelper?.WriteLine($"QueryResultRowCount: {qr.RowCount}");
- }
- catch (AggregateException aex)
- {
- if (statementWithExceptions.ExceptionType != null)
- {
- Assert.IsType(statementWithExceptions.ExceptionType, aex.InnerException);
- }
- else
- {
- throw;
- }
- }
- }
-
- ///
- /// Validates if the driver can execute update statements.
- ///
- [SkippableFact, Order(1)]
- public async Task CanInteractUsingSetOptions()
- {
- const string columnName = "INDEX";
- Statement.SetOption(SparkStatement.Options.PollTimeMilliseconds, "100");
- Statement.SetOption(SparkStatement.Options.BatchSize, "10");
- using TemporaryTable temporaryTable = await NewTemporaryTableAsync(Statement, $"{columnName} INT");
- await ValidateInsertSelectDeleteSingleValueAsync(temporaryTable.TableName, columnName, 1);
- }
- }
-
- internal class StatementWithExceptions
- {
- public StatementWithExceptions(int? queryTimeoutSeconds, string? query, Type? exceptionType)
- {
- QueryTimeoutSeconds = queryTimeoutSeconds;
- Query = query;
- ExceptionType = exceptionType;
- }
-
- ///
- /// If null, uses the default timeout.
- ///
- public int? QueryTimeoutSeconds { get; }
-
- ///
- /// If null, expected to succeed.
- ///
- public Type? ExceptionType { get; }
-
- ///
- /// If null, uses the default TestConfiguration
- ///
- public string? Query { get; }
- }
-
- internal class StatementTimeoutTestData : TheoryData
- {
- public StatementTimeoutTestData()
- {
- string longRunningQuery = "SELECT COUNT(*) AS total_count\nFROM (\n SELECT t1.id AS id1, t2.id AS id2\n FROM RANGE(1000000) t1\n CROSS JOIN RANGE(10000) t2\n) subquery\nWHERE MOD(id1 + id2, 2) = 0";
-
- Add(new(-1, null, null));
- Add(new(null, null, null));
- Add(new(1, null, typeof(TTransportException)));
- Add(new(5, null, null));
- Add(new(30, null, null));
- Add(new(5, longRunningQuery, typeof(TTransportException)));
- Add(new(null, longRunningQuery, typeof(TimeoutException)));
- Add(new(-1, longRunningQuery, null));
- }
- }
-}
+/*
+* Licensed to the Apache Software Foundation (ASF) under one or more
+* contributor license agreements. See the NOTICE file distributed with
+* this work for additional information regarding copyright ownership.
+* The ASF licenses this file to You under the Apache License, Version 2.0
+* (the "License"); you may not use this file except in compliance with
+* the License. You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*/
+
+using System;
+using System.Collections.Generic;
+using System.Threading.Tasks;
+using Apache.Arrow.Adbc.Drivers.Apache.Spark;
+using Apache.Arrow.Adbc.Tests.Xunit;
+using Thrift.Transport;
+using Xunit;
+using Xunit.Abstractions;
+using static Apache.Arrow.Adbc.Tests.Drivers.Apache.Spark.SparkConnectionTest;
+
+namespace Apache.Arrow.Adbc.Tests.Drivers.Apache.Spark
+{
+ ///
+ /// Class for testing the Spark ADBC driver statements.
+ ///
+ ///
+ /// Tests are ordered to ensure data is created for the other
+ /// queries to run.
+ ///
+ [TestCaseOrderer("Apache.Arrow.Adbc.Tests.Xunit.TestOrderer", "Apache.Arrow.Adbc.Tests")]
+ public class StatementTests : TestBase
+ {
+ private static List DefaultTableTypes => new() { "TABLE", "VIEW" };
+
+ public StatementTests(ITestOutputHelper? outputHelper) : base(outputHelper, new SparkTestEnvironment.Factory())
+ {
+ Skip.IfNot(Utils.CanExecuteTestConfig(TestConfigVariable));
+ }
+
+ ///
+ /// Validates that SetOption handles valid/invalid data correctly for the PollTime option.
+ ///
+ [SkippableTheory]
+ [InlineData("-1", true)]
+ [InlineData("zero", true)]
+ [InlineData("-2147483648", true)]
+ [InlineData("2147483648", true)]
+ [InlineData("0")]
+ [InlineData("1")]
+ [InlineData("2147483647")]
+ public void CanSetOptionPollTime(string value, bool throws = false)
+ {
+ var testConfiguration = TestConfiguration.Clone() as SparkTestConfiguration;
+ testConfiguration!.PollTimeMilliseconds = value;
+ if (throws)
+ {
+ Assert.Throws(() => NewConnection(testConfiguration).CreateStatement());
+ }
+
+ AdbcStatement statement = NewConnection().CreateStatement();
+ if (throws)
+ {
+ Assert.Throws(() => statement.SetOption(SparkStatement.Options.PollTimeMilliseconds, value));
+ }
+ else
+ {
+ statement.SetOption(SparkStatement.Options.PollTimeMilliseconds, value);
+ }
+ }
+
+ ///
+ /// Validates that SetOption handles valid/invalid data correctly for the BatchSize option.
+ ///
+ [SkippableTheory]
+ [InlineData("-1", true)]
+ [InlineData("one", true)]
+ [InlineData("-2147483648", true)]
+ [InlineData("2147483648", false)]
+ [InlineData("9223372036854775807", false)]
+ [InlineData("9223372036854775808", true)]
+ [InlineData("0", true)]
+ [InlineData("1")]
+ [InlineData("2147483647")]
+ public void CanSetOptionBatchSize(string value, bool throws = false)
+ {
+ var testConfiguration = TestConfiguration.Clone() as SparkTestConfiguration;
+ testConfiguration!.BatchSize = value;
+ if (throws)
+ {
+ Assert.Throws(() => NewConnection(testConfiguration).CreateStatement());
+ }
+
+ AdbcStatement statement = NewConnection().CreateStatement();
+ if (throws)
+ {
+ Assert.Throws(() => statement!.SetOption(SparkStatement.Options.BatchSize, value));
+ }
+ else
+ {
+ statement.SetOption(SparkStatement.Options.BatchSize, value);
+ }
+ }
+
+ ///
+ /// Validates that SetOption handles valid/invalid data correctly for the QueryTimeout option.
+ ///
+ [SkippableTheory]
+ [InlineData("zero", true)]
+ [InlineData("-2147483648", true)]
+ [InlineData("2147483648", true)]
+ [InlineData("0", true)]
+ [InlineData("-1")]
+ [InlineData("1")]
+ [InlineData("2147483647")]
+ public void CanSetOptionQueryTimeout(string value, bool throws = false)
+ {
+ var testConfiguration = TestConfiguration.Clone() as SparkTestConfiguration;
+ testConfiguration!.QueryTimeoutSeconds = value;
+ if (throws)
+ {
+ Assert.Throws(() => NewConnection(testConfiguration).CreateStatement());
+ }
+
+ AdbcStatement statement = NewConnection().CreateStatement();
+ if (throws)
+ {
+ Assert.Throws(() => statement.SetOption(SparkStatement.Options.QueryTimeoutSeconds, value));
+ }
+ else
+ {
+ statement.SetOption(SparkStatement.Options.QueryTimeoutSeconds, value);
+ }
+ }
+
+ [SkippableTheory]
+ [ClassData(typeof(StatementTimeoutTestData))]
+ internal void StatementTimeoutTest(StatementWithExceptions statementWithExceptions)
+ {
+ SparkTestConfiguration testConfiguration = (SparkTestConfiguration)TestConfiguration.Clone();
+
+ if (statementWithExceptions.QueryTimeoutSeconds.HasValue)
+ testConfiguration.QueryTimeoutSeconds = statementWithExceptions.QueryTimeoutSeconds.Value.ToString();
+
+ if (!string.IsNullOrEmpty(statementWithExceptions.Query))
+ testConfiguration.Query = statementWithExceptions.Query!;
+
+ OutputHelper?.WriteLine($"QueryTimeoutSeconds: {testConfiguration.QueryTimeoutSeconds}. ShouldSucceed: {statementWithExceptions.ExceptionType == null}. Query: [{testConfiguration.Query}]");
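+
+ // Timeout failures surface as an AggregateException; the inner exception type is asserted below.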
+ try
+ {
+ AdbcStatement st = NewConnection(testConfiguration).CreateStatement();
+ st.SqlQuery = testConfiguration.Query;
+ QueryResult qr = st.ExecuteQuery();
+
+ OutputHelper?.WriteLine($"QueryResultRowCount: {qr.RowCount}");
+ }
+ catch (AggregateException aex)
+ {
+ if (statementWithExceptions.ExceptionType != null)
+ {
+ Assert.IsType(statementWithExceptions.ExceptionType, aex.InnerException);
+ }
+ else
+ {
+ throw;
+ }
+ }
+ }
+
+ ///
+ /// Validates that the driver can set statement options and execute update statements.
+ ///
+ [SkippableFact, Order(1)]
+ public async Task CanInteractUsingSetOptions()
+ {
+ const string columnName = "INDEX";
+ Statement.SetOption(SparkStatement.Options.PollTimeMilliseconds, "100");
+ Statement.SetOption(SparkStatement.Options.BatchSize, "10");
+ using TemporaryTable temporaryTable = await NewTemporaryTableAsync(Statement, $"{columnName} INT");
+ await ValidateInsertSelectDeleteSingleValueAsync(temporaryTable.TableName, columnName, 1);
+ }
+ }
+
+ internal class StatementWithExceptions
+ {
+ public StatementWithExceptions(int? queryTimeoutSeconds, string? query, Type? exceptionType)
+ {
+ QueryTimeoutSeconds = queryTimeoutSeconds;
+ Query = query;
+ ExceptionType = exceptionType;
+ }
+
+ ///
+ /// If null, uses the default timeout.
+ ///
+ public int? QueryTimeoutSeconds { get; }
+
+ ///
+ /// If null, expected to succeed.
+ ///
+ public Type? ExceptionType { get; }
+
+ ///
+ /// If null, uses the query from the default TestConfiguration.
+ ///
+ public string? Query { get; }
+ }
+
+ internal class StatementTimeoutTestData : TheoryData
+ {
+ public StatementTimeoutTestData()
+ {
+ string longRunningQuery = "SELECT COUNT(*) AS total_count\nFROM (\n SELECT t1.id AS id1, t2.id AS id2\n FROM RANGE(1000000) t1\n CROSS JOIN RANGE(10000) t2\n) subquery\nWHERE MOD(id1 + id2, 2) = 0";
+
+ Add(new(-1, null, null));
+ Add(new(null, null, null));
+ Add(new(1, null, typeof(TTransportException)));
+ Add(new(5, null, null));
+ Add(new(30, null, null));
+ Add(new(5, longRunningQuery, typeof(TTransportException)));
+ Add(new(null, longRunningQuery, typeof(TimeoutException)));
+ Add(new(-1, longRunningQuery, null));
+ }
+ }
+}
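Reviewer note: for anyone exercising the options covered by these tests outside the test harness, a minimal sketch follows. The wiring is hypothetical (the tests obtain drivers and connections through their TestBase helpers, and the driver parameter below simply stands in for a Spark ADBC driver instance); the parameter keys, option keys, and the Open/Connect/CreateStatement/SetOption sequence are the ones the tests above exercise. Generic type arguments, which this patch rendering strips elsewhere, are assumed to be string-to-string.

    using System.Collections.Generic;
    using Apache.Arrow.Adbc;
    using Apache.Arrow.Adbc.Drivers.Apache.Spark;

    static AdbcStatement OpenConfiguredStatement(AdbcDriver driver)
    {
        // Combinations that omit credentials or use a malformed host are
        // rejected, as InvalidConnectionParametersTestData demonstrates.
        var parameters = new Dictionary<string, string>
        {
            [SparkParameters.Type] = SparkServerTypeConstants.Http,
            [SparkParameters.HostName] = "valid.server.com",
            [AdbcOptions.Username] = "user",
            [AdbcOptions.Password] = "myPassword",
            // Must parse as a positive Int32; "0", "", and non-numeric values throw.
            [SparkParameters.ConnectTimeoutMilliseconds] = "30000",
        };

        AdbcDatabase database = driver.Open(parameters);
        AdbcConnection connection = database.Connect(parameters);

        AdbcStatement statement = connection.CreateStatement();
        // Same keys CanSetOptionQueryTimeout/PollTime/BatchSize validate; "-1" is
        // also accepted for the query timeout, and the -1 long-running case in
        // StatementTimeoutTestData shows it lets a slow query run to completion.
        statement.SetOption(SparkStatement.Options.QueryTimeoutSeconds, "60");
        statement.SetOption(SparkStatement.Options.PollTimeMilliseconds, "100");
        statement.SetOption(SparkStatement.Options.BatchSize, "1024");
        return statement;
    }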