3535import java .util .concurrent .ExecutionException ;
3636import java .util .concurrent .TimeUnit ;
3737import java .util .concurrent .atomic .AtomicInteger ;
38+ import java .util .function .BiConsumer ;
3839import java .util .function .BiPredicate ;
40+ import java .util .function .Function ;
3941import java .util .stream .Collectors ;
4042import java .util .stream .Stream ;
4143import java .util .stream .StreamSupport ;
5052
5153import com .datastax .driver .core .utils .UUIDs ;
5254import org .apache .cassandra .Util ;
55+ import org .apache .cassandra .cql3 .CQL3Type ;
5356import org .apache .cassandra .cql3 .QueryProcessor ;
5457import org .apache .cassandra .cql3 .UntypedResultSet ;
5558import org .apache .cassandra .cql3 .constraints .ConstraintViolationException ;
6164import org .apache .cassandra .db .compression .CompressionDictionary ;
6265import org .apache .cassandra .db .compression .CompressionDictionary .DictId ;
6366import org .apache .cassandra .db .compression .ZstdCompressionDictionary ;
67+ import org .apache .cassandra .db .marshal .AbstractType ;
6468import org .apache .cassandra .db .marshal .FloatType ;
69+ import org .apache .cassandra .db .marshal .SimpleDateType ;
70+ import org .apache .cassandra .db .marshal .TimeType ;
6571import org .apache .cassandra .db .marshal .UTF8Type ;
6672import org .apache .cassandra .dht .ByteOrderedPartitioner ;
6773import org .apache .cassandra .dht .Murmur3Partitioner ;
@@ -1621,9 +1627,36 @@ public void testSkipBuildingIndexesWithSAI() throws Exception
16211627 @ Test
16221628 public void testWritingVectorData () throws Exception
16231629 {
1630+ testWritingVectorData (CQL3Type .Native .FLOAT , FloatType .instance , (i ) -> (float ) i , (i , vector ) -> {
1631+ assertThat (vector ).allMatch (val -> val instanceof Float );
1632+ assertThat (vector ).allMatch (val -> (float ) val == (float ) i );
1633+ });
1634+
1635+ perTestSetup ();
1636+
1637+ testWritingVectorData (CQL3Type .Native .DATE , SimpleDateType .instance , LocalDate ::fromDaysSinceEpoch , (i , vector ) -> {
1638+ assertThat (vector ).allMatch (val -> val instanceof Integer );
1639+ assertThat (vector ).allMatch (val -> {
1640+ int days = (int ) val - Integer .MIN_VALUE ; // signed to unsigned conversion
1641+ return days == i ;
1642+ });
1643+ });
1644+
1645+ perTestSetup ();
1646+
1647+ testWritingVectorData (CQL3Type .Native .TIME , TimeType .instance , (i ) -> (long ) i , (i , vector ) -> {
1648+ assertThat (vector ).allMatch (val -> val instanceof Long );
1649+ assertThat (vector ).allMatch (val -> (long ) val == (long ) i );
1650+ });
1651+ }
1652+
1653+ private void testWritingVectorData (CQL3Type .Native cqlType , AbstractType <?> subType , Function <Integer , ?> valueFactory ,
1654+ BiConsumer <Integer , List <?>> checkFunction ) throws Exception
1655+ {
1656+ final int dimensions = 5 ;
16241657 final String schema = "CREATE TABLE " + qualifiedTable + " ("
16251658 + " k int,"
1626- + " v1 VECTOR<FLOAT, 5 >,"
1659+ + " v1 VECTOR<" + cqlType . name () + ", " + dimensions + " >,"
16271660 + " PRIMARY KEY (k)"
16281661 + ")" ;
16291662
@@ -1635,7 +1668,12 @@ public void testWritingVectorData() throws Exception
16351668
16361669 for (int i = 0 ; i < 100 ; i ++)
16371670 {
1638- writer .addRow (i , List .of ( (float )i , (float )i , (float )i , (float )i , (float )i ));
1671+ List <Object > vector = new ArrayList <>(dimensions );
1672+ for (int j = 0 ; j < dimensions ; j ++)
1673+ {
1674+ vector .add (valueFactory .apply (i ));
1675+ }
1676+ writer .addRow (i , vector );
16391677 }
16401678
16411679 writer .close ();
@@ -1650,10 +1688,9 @@ public void testWritingVectorData() throws Exception
16501688 for (UntypedResultSet .Row row : resultSet )
16511689 {
16521690 assertEquals (cnt , row .getInt ("k" ));
1653- List <Float > vector = row .getVector ("v1" , FloatType .instance , 5 );
1654- assertThat (vector ).hasSize (5 );
1655- final float floatCount = (float )cnt ;
1656- assertThat (vector ).allMatch (val -> val == floatCount );
1691+ List <?> vector = row .getVector ("v1" , subType , dimensions );
1692+ assertThat (vector ).hasSize (dimensions );
1693+ checkFunction .accept (cnt , vector );
16571694 cnt ++;
16581695 }
16591696 }
0 commit comments