Skip to content

Commit

Permalink
FMWK-276 Fix parameterized queries in PreparedStatement
Browse files Browse the repository at this point in the history
  • Loading branch information
reugn committed Dec 6, 2023
1 parent 9f1c08c commit d17c235
Show file tree
Hide file tree
Showing 9 changed files with 159 additions and 107 deletions.
27 changes: 19 additions & 8 deletions src/main/java/com/aerospike/jdbc/AerospikeDatabaseMetadata.java
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@
import com.aerospike.jdbc.sql.ListRecordSet;
import com.aerospike.jdbc.sql.SimpleWrapper;
import com.aerospike.jdbc.util.AerospikeUtils;
import com.google.common.cache.Cache;
import com.google.common.cache.CacheBuilder;

import java.io.IOException;
import java.io.StringReader;
Expand All @@ -22,6 +24,7 @@
import java.sql.Statement;
import java.util.*;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ExecutionException;
import java.util.function.Function;
import java.util.logging.Logger;
import java.util.regex.Pattern;
Expand Down Expand Up @@ -49,14 +52,15 @@ public class AerospikeDatabaseMetadata implements DatabaseMetaData, SimpleWrappe
private static final String NEW_LINE = System.lineSeparator();

private final String url;
private final AerospikeConnection connection;
private final Connection connection;
private final String dbBuild;
private final String dbEdition;
private final List<String> catalogs;
private final Map<String, Collection<String>> tables;
private final Map<String, Collection<AerospikeSecondaryIndex>> catalogIndexes;
private final Map<String, AerospikeSecondaryIndex> secondaryIndexes;
private final AerospikeSchemaBuilder schemaBuilder;
private final Cache<String, ResultSetMetaData> resultSetMetaDataCache;

public AerospikeDatabaseMetadata(String url, IAerospikeClient client, AerospikeConnection connection) {
logger.info("Init AerospikeDatabaseMetadata");
Expand Down Expand Up @@ -103,6 +107,7 @@ public AerospikeDatabaseMetadata(String url, IAerospikeClient client, AerospikeC
.collect(Collectors.toMap(AerospikeSecondaryIndex::toKey, Function.identity()));

schemaBuilder = new AerospikeSchemaBuilder(client, connection.getConfiguration().getDriverPolicy());
resultSetMetaDataCache = CacheBuilder.newBuilder().build();

dbBuild = join("N/A", ", ", builds);
dbEdition = join("Aerospike", ", ", editions);
Expand Down Expand Up @@ -1304,13 +1309,19 @@ private int ordinal(ResultSetMetaData md, String columnName) {
}

private ResultSetMetaData getMetadata(String namespace, String table) {
try (Statement statement = connection.createStatement()) {
String query = format("SELECT * FROM \"%s.%s\" LIMIT %d", namespace, table,
connection.getConfiguration().getDriverPolicy().getSchemaBuilderMaxRecords());
return statement.executeQuery(query).getMetaData();
} catch (SQLException e) {
logger.severe(() -> format("Exception in getMetadata, namespace: %s, table: %s", namespace, table));
throw new IllegalArgumentException(e);
final String key = format("%s.%s", namespace, table);
try {
return resultSetMetaDataCache.get(key, () -> {
try (Statement statement = connection.createStatement()) {
String query = format("SELECT * FROM \"%s.%s\" LIMIT 1", namespace, table);
return statement.executeQuery(query).getMetaData();
} catch (SQLException e) {
logger.severe(() -> format("Exception in getMetadata, namespace: %s, table: %s", namespace, table));
throw new IllegalArgumentException(e);
}
});
} catch (ExecutionException e) {
throw new IllegalArgumentException(e.getCause());
}
}

Expand Down
79 changes: 44 additions & 35 deletions src/main/java/com/aerospike/jdbc/AerospikePreparedStatement.java
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import com.aerospike.client.Value;
import com.aerospike.jdbc.model.AerospikeQuery;
import com.aerospike.jdbc.model.DataColumn;
import com.aerospike.jdbc.model.QueryType;
import com.aerospike.jdbc.sql.AerospikeResultSetMetaData;
import com.aerospike.jdbc.sql.SimpleParameterMetaData;
import com.aerospike.jdbc.sql.type.ByteArrayBlob;
Expand All @@ -26,43 +27,40 @@

import static com.aerospike.jdbc.util.PreparedStatement.parseParameters;
import static java.lang.String.format;
import static java.util.Objects.isNull;

public class AerospikePreparedStatement extends AerospikeStatement implements PreparedStatement {

private static final Logger logger = Logger.getLogger(AerospikePreparedStatement.class.getName());

private final String sql;
private final AerospikeConnection connection;
private final Object[] parameterValues;
private final AerospikeQuery query;
private final String sqlStatement;
private final Object[] sqlParameters;

public AerospikePreparedStatement(IAerospikeClient client, AerospikeConnection connection, String sql) {
public AerospikePreparedStatement(IAerospikeClient client, AerospikeConnection connection, String sqlStatement) {
super(client, connection);
this.sql = sql;
this.connection = connection;
parameterValues = buildParameterValues(sql);
try {
query = parseQuery(sql);
} catch (SQLException e) {
throw new UnsupportedOperationException(e);
}
this.sqlStatement = sqlStatement;
sqlParameters = buildSqlParameters(sqlStatement);
logger.info(() -> format("statement: %s, params: %d", sqlStatement, sqlParameters.length));
}

private Object[] buildParameterValues(String sql) {
private Object[] buildSqlParameters(String sql) {
int params = parseParameters(sql, 0).getValue();
return new Object[params];
}

@Override
public ResultSet executeQuery() throws SQLException {
logger.info("AerospikePreparedStatement executeQuery");
return super.executeQuery(sql);
String preparedQueryString = prepareQueryString();
logger.info(() -> "executeQuery: " + preparedQueryString);
AerospikeQuery query = parseQuery(preparedQueryString);
runQuery(query);
return resultSet;
}

@Override
public int executeUpdate() throws SQLException {
logger.info("AerospikePreparedStatement executeUpdate");
return super.executeUpdate(sql);
executeQuery();
return updateCount;
}

@Override
Expand Down Expand Up @@ -116,7 +114,7 @@ public void setBigDecimal(int parameterIndex, BigDecimal x) throws SQLException

@Override
public void setString(int parameterIndex, String x) throws SQLException {
setObject(parameterIndex, "\"" + x + "\"");
setObject(parameterIndex, format("\"%s\"", x));
}

@Override
Expand Down Expand Up @@ -149,6 +147,7 @@ public void setAsciiStream(int parameterIndex, InputStream x, int length) throws
*/
@Override
@Deprecated
@SuppressWarnings("java:S1133")
public void setUnicodeStream(int parameterIndex, InputStream x, int length) throws SQLException {
throw new SQLFeatureNotSupportedException("setUnicodeStream is deprecated");
}
Expand All @@ -160,7 +159,7 @@ public void setBinaryStream(int parameterIndex, InputStream x, int length) throw

@Override
public void clearParameters() {
Arrays.fill(parameterValues, null);
Arrays.fill(sqlParameters, null);
}

@Override
Expand All @@ -170,28 +169,36 @@ public void setObject(int parameterIndex, Object x, int targetSqlType) throws SQ

@Override
public void setObject(int parameterIndex, Object x) throws SQLException {
if (parameterIndex <= 0 || parameterIndex > parameterValues.length) {
throw new SQLException(parameterValues.length == 0 ?
"Current SQL statement does not have parameters" :
format("Wrong parameter index. Expected from %d till %d", 1, parameterValues.length));
if (parameterIndex <= 0 || parameterIndex > sqlParameters.length) {
throw new SQLDataException(sqlParameters.length == 0
? "Current SQL statement does not have parameters"
: format("The parameter index %d is out of range, number of parameters: %d",
parameterIndex, sqlParameters.length));
}
parameterValues[parameterIndex - 1] = x;
sqlParameters[parameterIndex - 1] = x;
}

@Override
public boolean execute() throws SQLException {
String preparedQuery = prepareQuery();
logger.info(preparedQuery);
return execute(preparedQuery);
}

private String prepareQuery() {
return format(this.sql.replace("?", "%s"), parameterValues);
String preparedQueryString = prepareQueryString();
logger.info(() -> "execute: " + preparedQueryString);
AerospikeQuery query = parseQuery(preparedQueryString);
runQuery(query);
return query.getQueryType() == QueryType.SELECT;
}

private String prepareQueryString() {
String preparedQueryString = sqlStatement;
for (Object value : sqlParameters) {
String replacement = isNull(value) ? "?" : value.toString();
preparedQueryString = preparedQueryString.replaceFirst("\\?", replacement);
}
return preparedQueryString;
}

@Override
public void addBatch() throws SQLException {
addBatch(sql);
addBatch(prepareQueryString());
}

@Override
Expand Down Expand Up @@ -221,6 +228,7 @@ public void setArray(int parameterIndex, Array x) throws SQLException {

@Override
public ResultSetMetaData getMetaData() throws SQLException {
AerospikeQuery query = parseQuery(prepareQueryString());
List<DataColumn> columns = ((AerospikeDatabaseMetadata) connection.getMetaData())
.getSchemaBuilder()
.getSchema(query.getSchemaTable());
Expand Down Expand Up @@ -254,6 +262,7 @@ public void setURL(int parameterIndex, URL url) throws SQLException {

@Override
public ParameterMetaData getParameterMetaData() throws SQLException {
AerospikeQuery query = parseQuery(prepareQueryString());
List<DataColumn> columns = ((AerospikeDatabaseMetadata) connection.getMetaData())
.getSchemaBuilder()
.getSchema(query.getSchemaTable());
Expand Down Expand Up @@ -297,9 +306,9 @@ public void setClob(int parameterIndex, Reader reader, long length) throws SQLEx
@Override
public void setBlob(int parameterIndex, InputStream inputStream, long length) throws SQLException {
byte[] bytes = new byte[(int) length];
DataInputStream dis = new DataInputStream(inputStream);
DataInputStream dataInputStream = new DataInputStream(inputStream);
try {
dis.readFully(bytes);
dataInputStream.readFully(bytes);
if (inputStream.read() != -1) {
throw new SQLException(format("Source contains more bytes than required %d", length));
}
Expand Down
21 changes: 14 additions & 7 deletions src/main/java/com/aerospike/jdbc/AerospikeStatement.java
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import com.aerospike.client.IAerospikeClient;
import com.aerospike.jdbc.model.AerospikeQuery;
import com.aerospike.jdbc.model.Pair;
import com.aerospike.jdbc.model.QueryType;
import com.aerospike.jdbc.query.QueryPerformer;
import com.aerospike.jdbc.sql.SimpleWrapper;
import com.aerospike.jdbc.util.AuxStatementParser;
Expand All @@ -29,12 +30,14 @@ public class AerospikeStatement implements Statement, SimpleWrapper {
private static final String AUTO_GENERATED_KEYS_NOT_SUPPORTED_MESSAGE = "Auto-generated keys are not supported";

protected final IAerospikeClient client;
private final Connection connection;
protected final AerospikeConnection connection;

protected String schema;
protected ResultSet resultSet;
protected int updateCount;

private int maxRows = Integer.MAX_VALUE;
private int queryTimeout;
private ResultSet resultSet;
private int updateCount;

public AerospikeStatement(IAerospikeClient client, AerospikeConnection connection) {
this.client = client;
Expand All @@ -50,12 +53,14 @@ public AerospikeStatement(IAerospikeClient client, AerospikeConnection connectio
public ResultSet executeQuery(String sql) throws SQLException {
logger.info(() -> "executeQuery: " + sql);
AerospikeQuery query = parseQuery(sql);
runQuery(query);
return resultSet;
}

protected void runQuery(AerospikeQuery query) {
Pair<ResultSet, Integer> result = QueryPerformer.executeQuery(client, this, query);
resultSet = result.getLeft();
updateCount = result.getRight();

return resultSet;
}

protected AerospikeQuery parseQuery(String sql) throws SQLException {
Expand Down Expand Up @@ -140,8 +145,10 @@ public void setCursorName(String name) throws SQLException {

@Override
public boolean execute(String sql) throws SQLException {
resultSet = executeQuery(sql);
return true;
logger.info(() -> "execute: " + sql);
AerospikeQuery query = parseQuery(sql);
runQuery(query);
return query.getQueryType() == QueryType.SELECT;
}

@Override
Expand Down
10 changes: 6 additions & 4 deletions src/main/java/com/aerospike/jdbc/model/DriverPolicy.java
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@

public class DriverPolicy {

private static final int DEFAULT_CAPACITY = 256;
private static final int DEFAULT_TIMEOUT_MS = 1000;
private static final int DEFAULT_RECORD_SET_QUEUE_CAPACITY = 256;
private static final int DEFAULT_RECORD_SET_TIMEOUT_MS = 1000;
private static final int DEFAULT_METADATA_CACHE_TTL_SECONDS = 3600;
private static final int DEFAULT_SCHEMA_BUILDER_MAX_RECORDS = 1000;

Expand All @@ -15,8 +15,10 @@ public class DriverPolicy {
private final int schemaBuilderMaxRecords;

public DriverPolicy(Properties properties) {
recordSetQueueCapacity = parseInt(properties.getProperty("recordSetQueueCapacity"), DEFAULT_CAPACITY);
recordSetTimeoutMs = parseInt(properties.getProperty("recordSetTimeoutMs"), DEFAULT_TIMEOUT_MS);
recordSetQueueCapacity = parseInt(properties.getProperty("recordSetQueueCapacity"),
DEFAULT_RECORD_SET_QUEUE_CAPACITY);
recordSetTimeoutMs = parseInt(properties.getProperty("recordSetTimeoutMs"),
DEFAULT_RECORD_SET_TIMEOUT_MS);
metadataCacheTtlSeconds = parseInt(properties.getProperty("metadataCacheTtlSeconds"),
DEFAULT_METADATA_CACHE_TTL_SECONDS);
schemaBuilderMaxRecords = parseInt(properties.getProperty("schemaBuilderMaxRecords"),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,7 @@ default InputStream getAsciiStream(int columnIndex) throws SQLException {
*/
@Override
@Deprecated
@SuppressWarnings("java:S1133")
default InputStream getUnicodeStream(int columnIndex) throws SQLException {
return getUnicodeStream(getColumnLabel(columnIndex));
}
Expand Down
10 changes: 4 additions & 6 deletions src/test/java/com/aerospike/jdbc/DatabaseMetadataTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
import java.util.Objects;

import static com.aerospike.jdbc.util.TestUtil.closeQuietly;
import static java.lang.String.format;
import static org.testng.Assert.assertEquals;
import static org.testng.Assert.assertFalse;
import static org.testng.Assert.assertTrue;
Expand All @@ -25,10 +26,8 @@ public void setUp() throws SQLException {
Objects.requireNonNull(connection, "connection is null");
PreparedStatement statement = null;
int count;
String query = String.format(
"insert into %s (bin1, int1, str1, bool1) values (11100, 1, \"bar\", true)",
tableName
);
String query = format("insert into %s (bin1, int1, str1, bool1) values (11100, 1, \"bar\", true)",
tableName);
try {
statement = connection.prepareStatement(query);
count = statement.executeUpdate();
Expand All @@ -43,7 +42,7 @@ public void tearDown() throws SQLException {
Objects.requireNonNull(connection, "connection is null");
PreparedStatement statement = null;
ResultSet resultSet = null;
String query = String.format("delete from %s", tableName);
String query = format("delete from %s", tableName);
try {
statement = connection.prepareStatement(query);
resultSet = statement.executeQuery();
Expand All @@ -57,7 +56,6 @@ public void tearDown() throws SQLException {
@Test
public void testGetTables() throws SQLException {
DatabaseMetaData databaseMetaData = connection.getMetaData();

ResultSet rs = databaseMetaData.getTables(namespace, namespace, tableName, null);

assertTrue(rs.next());
Expand Down
Loading

0 comments on commit d17c235

Please sign in to comment.