@@ -21,6 +21,7 @@ import (
21
21
"errors"
22
22
"fmt"
23
23
"io"
24
+ "slices"
24
25
"sync"
25
26
"sync/atomic"
26
27
@@ -116,6 +117,7 @@ type colReaderImpl interface {
116
117
GetDefLevels () ([]int16 , error )
117
118
GetRepLevels () ([]int16 , error )
118
119
Field () * arrow.Field
120
+ SeekToRow (int64 ) error
119
121
IsOrHasRepeatedChild () bool
120
122
Retain ()
121
123
Release ()
@@ -427,6 +429,20 @@ func (fr *FileReader) getColumnReader(ctx context.Context, i int, colFactory itr
427
429
type RecordReader interface {
428
430
array.RecordReader
429
431
arrio.Reader
432
+ // SeekToRow will shift the record reader so that subsequent calls to Read
433
+ // or Next will begin from the specified row.
434
+ //
435
+ // If the record reader was constructed with a request for a subset of row
436
+ // groups, then rows are counted across the requested row groups, not the
437
+ // entire file. This prevents reading row groups that were requested to be
438
+ // skipped, and allows treating the subset of row groups as a single collection
439
+ // of rows.
440
+ //
441
+ // If the file contains Offset indexes for a given column, then they will be
442
+ // utilized to skip pages as needed to find the requested row. Otherwise page
443
+ // headers will still have to be read to find the right page to begin reading
444
+ // from.
445
+ SeekToRow (int64 ) error
430
446
}
431
447
432
448
// GetRecordReader returns a record reader that reads only the requested column indexes and row groups.
@@ -537,12 +553,8 @@ func (fr *FileReader) getReader(ctx context.Context, field *SchemaField, arrowFi
537
553
}
538
554
539
555
// because we performed getReader concurrently, we need to prune out any empty readers
540
- for n := len (childReaders ) - 1 ; n >= 0 ; n -- {
541
- if childReaders [n ] == nil {
542
- childReaders = append (childReaders [:n ], childReaders [n + 1 :]... )
543
- childFields = append (childFields [:n ], childFields [n + 1 :]... )
544
- }
545
- }
556
+ childReaders = slices .DeleteFunc (childReaders ,
557
+ func (r * ColumnReader ) bool { return r == nil })
546
558
if len (childFields ) == 0 {
547
559
return nil , nil
548
560
}
@@ -615,15 +627,45 @@ type columnIterator struct {
615
627
rdr * file.Reader
616
628
schema * schema.Schema
617
629
rowGroups []int
630
+
631
+ rgIdx int
618
632
}
619
633
620
- func (c * columnIterator ) NextChunk ( ) (file.PageReader , error ) {
634
+ func (c * columnIterator ) FindChunkForRow ( rowIdx int64 ) (file.PageReader , int64 , error ) {
621
635
if len (c .rowGroups ) == 0 {
636
+ return nil , 0 , nil
637
+ }
638
+
639
+ if rowIdx < 0 || rowIdx > c .rdr .NumRows () {
640
+ return nil , 0 , fmt .Errorf ("invalid row index %d, file only has %d rows" , rowIdx , c .rdr .NumRows ())
641
+ }
642
+
643
+ idx := int64 (0 )
644
+ for i , rg := range c .rowGroups {
645
+ rgr := c .rdr .RowGroup (rg )
646
+ if idx + rgr .NumRows () > rowIdx {
647
+ c .rgIdx = i + 1
648
+ pr , err := rgr .GetColumnPageReader (c .index )
649
+ if err != nil {
650
+ return nil , 0 , err
651
+ }
652
+
653
+ return pr , rowIdx - idx , nil
654
+ }
655
+ idx += rgr .NumRows ()
656
+ }
657
+
658
+ return nil , 0 , fmt .Errorf ("%w: invalid row index %d, row group subset only has %d total rows" ,
659
+ arrow .ErrInvalid , rowIdx , idx )
660
+ }
661
+
662
+ func (c * columnIterator ) NextChunk () (file.PageReader , error ) {
663
+ if len (c .rowGroups ) == 0 || c .rgIdx >= len (c .rowGroups ) {
622
664
return nil , nil
623
665
}
624
666
625
- rgr := c .rdr .RowGroup (c .rowGroups [0 ])
626
- c .rowGroups = c . rowGroups [ 1 :]
667
+ rgr := c .rdr .RowGroup (c .rowGroups [c . rgIdx ])
668
+ c .rgIdx ++
627
669
return rgr .GetColumnPageReader (c .index )
628
670
}
629
671
@@ -643,6 +685,25 @@ type recordReader struct {
643
685
refCount int64
644
686
}
645
687
688
+ func (r * recordReader ) SeekToRow (row int64 ) error {
689
+ if r .cur != nil {
690
+ r .cur .Release ()
691
+ r .cur = nil
692
+ }
693
+
694
+ if row < 0 || row >= r .numRows {
695
+ return fmt .Errorf ("invalid row index %d, file only has %d rows" , row , r .numRows )
696
+ }
697
+
698
+ for _ , fr := range r .fieldReaders {
699
+ if err := fr .SeekToRow (row ); err != nil {
700
+ return err
701
+ }
702
+ }
703
+
704
+ return nil
705
+ }
706
+
646
707
func (r * recordReader ) Retain () {
647
708
atomic .AddInt64 (& r .refCount , 1 )
648
709
}
0 commit comments