diff --git a/cmd/timescaledb-parallel-copy/main.go b/cmd/timescaledb-parallel-copy/main.go index 40da550..1088352 100644 --- a/cmd/timescaledb-parallel-copy/main.go +++ b/cmd/timescaledb-parallel-copy/main.go @@ -139,7 +139,7 @@ func main() { result, err := copier.Copy(context.Background(), reader) if err != nil { - log.Fatal("failed to copy CSV:", err) + log.Fatal("failed to copy CSV: ", err) } res := fmt.Sprintf("COPY %d", result.RowsRead) diff --git a/internal/batch/scan.go b/internal/batch/scan.go index 76f0453..7936d31 100644 --- a/internal/batch/scan.go +++ b/internal/batch/scan.go @@ -18,6 +18,25 @@ type Options struct { Escape byte // the ESCAPE character; defaults to QUOTE } +// Batch represents an operation to copy data into the DB +type Batch struct { + Data net.Buffers + Location Location +} + +// Location positions a batch within the original data +type Location struct { + StartRow int64 + Length int +} + +func NewLocation(rowsRead int64, bufferedRows int, skip int) Location { + return Location{ + StartRow: rowsRead - int64(bufferedRows) + int64(skip), + Length: bufferedRows, + } +} + // Scan reads all lines from an io.Reader, partitions them into net.Buffers with // opts.Size rows each, and writes each batch to the out channel. If opts.Skip // is greater than zero, that number of lines will be discarded from the @@ -28,7 +47,7 @@ type Options struct { // Scan expects the input to be in Postgres CSV format. Since this format allows // rows to be split over multiple lines, the caller may provide opts.Quote and // opts.Escape as the QUOTE and ESCAPE characters used for the CSV input. 
-func Scan(ctx context.Context, r io.Reader, out chan<- net.Buffers, opts Options) error { +func Scan(ctx context.Context, r io.Reader, out chan<- Batch, opts Options) error { var rowsRead int64 reader := bufio.NewReader(r) @@ -111,10 +130,14 @@ func Scan(ctx context.Context, r io.Reader, out chan<- net.Buffers, opts Options } if bufferedRows >= opts.Size { // dispatch to COPY worker & reset - if ctx.Err() != nil { - return nil + select { + case out <- Batch{ + Data: bufs, + Location: NewLocation(rowsRead, bufferedRows, opts.Skip), + }: + case <-ctx.Done(): + return ctx.Err() } - out <- bufs bufs = make(net.Buffers, 0, opts.Size) bufferedRows = 0 } @@ -130,7 +153,14 @@ func Scan(ctx context.Context, r io.Reader, out chan<- net.Buffers, opts Options // Finished reading input, make sure last batch goes out. if len(bufs) > 0 { - out <- bufs + select { + case out <- Batch{ + Data: bufs, + Location: NewLocation(rowsRead, bufferedRows, opts.Skip), + }: + case <-ctx.Done(): + return ctx.Err() + } } return nil diff --git a/internal/batch/scan_test.go b/internal/batch/scan_test.go index cef7341..f7b8b01 100644 --- a/internal/batch/scan_test.go +++ b/internal/batch/scan_test.go @@ -6,7 +6,6 @@ import ( "errors" "fmt" "io" - "net" "reflect" "strings" "testing" @@ -239,7 +238,7 @@ d" for _, c := range cases { t.Run(c.name, func(t *testing.T) { - rowChan := make(chan net.Buffers) + rowChan := make(chan batch.Batch) resultChan := make(chan []string) // Collector for the scanned row batches. 
@@ -247,7 +246,7 @@ d" var actual []string for buf := range rowChan { - actual = append(actual, string(bytes.Join(buf, nil))) + actual = append(actual, string(bytes.Join(buf.Data, nil))) } resultChan <- actual @@ -302,7 +301,7 @@ d" should be discarded `), expected) - rowChan := make(chan net.Buffers, 1) + rowChan := make(chan batch.Batch, 1) opts := batch.Options{ Size: 50, Skip: c.skip, @@ -411,7 +410,7 @@ func BenchmarkScan(b *testing.B) { b.Run(name, func(b *testing.B) { // Make sure our output channel won't block. This relies on each // call to Scan() producing exactly one batch. - rowChan := make(chan net.Buffers, b.N) + rowChan := make(chan batch.Batch, b.N) b.ResetTimer() for i := 0; i < b.N; i++ { diff --git a/pkg/csvcopy/csvcopy.go b/pkg/csvcopy/csvcopy.go index 0e25a5f..391ba80 100644 --- a/pkg/csvcopy/csvcopy.go +++ b/pkg/csvcopy/csvcopy.go @@ -5,12 +5,14 @@ import ( "errors" "fmt" "io" - "net" + "regexp" + "strconv" "strings" "sync" "sync/atomic" "time" + "github.com/jackc/pgconn" _ "github.com/jackc/pgx/v4/stdlib" "github.com/timescale/timescaledb-parallel-copy/internal/batch" "github.com/timescale/timescaledb-parallel-copy/internal/db" @@ -161,7 +163,7 @@ func (c *Copier) Truncate() (err error) { func (c *Copier) Copy(ctx context.Context, reader io.Reader) (Result, error) { var wg sync.WaitGroup - batchChan := make(chan net.Buffers, c.workers*2) + batchChan := make(chan batch.Batch, c.workers*2) ctx, cancel := context.WithCancel(ctx) defer cancel() @@ -230,9 +232,50 @@ func (c *Copier) Copy(ctx context.Context, reader io.Reader) (Result, error) { }, err } +type ErrAtRow struct { + Err error + Row int64 +} + +func ErrAtRowFromPGError(pgerr *pgconn.PgError, offset int64) *ErrAtRow { + // Example of Where field + // "COPY metrics, line 1, column value: \"hello\"" + match := regexp.MustCompile(`line (\d+)`).FindStringSubmatch(pgerr.Where) + if len(match) != 2 { + return &ErrAtRow{ + Err: pgerr, + Row: -1, + } + } + + line, err := strconv.Atoi(match[1]) + 
if err != nil {
+		return &ErrAtRow{
+			Err: pgerr,
+			Row: -1,
+		}
+	}
+
+	return &ErrAtRow{
+		Err: pgerr,
+		Row: offset + int64(line),
+	}
+}
+
+func (e *ErrAtRow) Error() string {
+	if e.Err != nil {
+		return fmt.Sprintf("at row %d, error %s", e.Row, e.Err.Error())
+	}
+	return fmt.Sprintf("error at row %d", e.Row)
+}
+
+func (e *ErrAtRow) Unwrap() error {
+	return e.Err
+}
+
 // processBatches reads batches from channel c and copies them to the target
 // server while tracking stats on the write.
-func (c *Copier) processBatches(ctx context.Context, ch chan net.Buffers) (err error) {
+func (c *Copier) processBatches(ctx context.Context, ch chan batch.Batch) (err error) {
 	dbx, err := db.Connect(c.dbURL, c.overrides...)
 	if err != nil {
 		return err
@@ -273,15 +316,18 @@ func (c *Copier) processBatches(ctx context.Context, ch chan net.Buffers) (err e
 			return
 		}
 		start := time.Now()
-		rows, err := db.CopyFromLines(ctx, dbx, &batch, copyCmd)
+		rows, err := db.CopyFromLines(ctx, dbx, &batch.Data, copyCmd)
 		if err != nil {
-			return err
+			if pgerr := (*pgconn.PgError)(nil); errors.As(err, &pgerr) {
+				return ErrAtRowFromPGError(pgerr, batch.Location.StartRow)
+			}
+			return fmt.Errorf("[BATCH] starting at row %d: %w", batch.Location.StartRow, err)
 		}
 		atomic.AddInt64(&c.rowCount, rows)
 
 		if c.logBatches {
 			took := time.Since(start)
-			fmt.Printf("[BATCH] took %v, batch size %d, row rate %f/sec\n", took, c.batchSize, float64(c.batchSize)/float64(took.Seconds()))
+			fmt.Printf("[BATCH] starting at row %d, took %v, batch size %d, row rate %f/sec\n", batch.Location.StartRow, took, batch.Location.Length, float64(batch.Location.Length)/float64(took.Seconds()))
 		}
 	}
 }
diff --git a/pkg/csvcopy/csvcopy_test.go b/pkg/csvcopy/csvcopy_test.go
index 300ad76..4836385 100644
--- a/pkg/csvcopy/csvcopy_test.go
+++ b/pkg/csvcopy/csvcopy_test.go
@@ -97,3 +97,69 @@ func TestWriteDataToCSV(t *testing.T) {
 	require.NoError(t, err)
 	assert.Equal(t, []interface{}{int32(24), "qased", 2.4}, results)
 }
+
+func TestErrorAtRow(t 
*testing.T) {
+	ctx := context.Background()
+
+	pgContainer, err := postgres.RunContainer(ctx,
+		testcontainers.WithImage("postgres:15.3-alpine"),
+		postgres.WithDatabase("test-db"),
+		postgres.WithUsername("postgres"),
+		postgres.WithPassword("postgres"),
+		testcontainers.WithWaitStrategy(
+			wait.ForLog("database system is ready to accept connections").
+				WithOccurrence(2).WithStartupTimeout(5*time.Second)),
+	)
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	t.Cleanup(func() {
+		if err := pgContainer.Terminate(ctx); err != nil {
+			t.Fatalf("failed to terminate pgContainer: %s", err)
+		}
+	})
+
+	connStr, err := pgContainer.ConnectionString(ctx, "sslmode=disable")
+	require.NoError(t, err)
+
+	conn, err := pgx.Connect(ctx, connStr)
+	require.NoError(t, err)
+	defer conn.Close(ctx)
+	_, err = conn.Exec(ctx, "create table public.metrics (device_id int, label text, value float8)")
+	require.NoError(t, err)
+
+	// Create a temporary CSV file
+	tmpfile, err := os.CreateTemp("", "example")
+	require.NoError(t, err)
+	defer os.Remove(tmpfile.Name())
+
+	// Write data to the CSV file
+	writer := csv.NewWriter(tmpfile)
+
+	data := [][]string{
+		{"42", "xasev", "4.2"},
+		{"24", "qased", "2.4"},
+		{"24", "qased", "2.4"},
+		{"24", "qased", "hello"},
+		{"24", "qased", "2.4"},
+		{"24", "qased", "2.4"},
+	}
+
+	for _, record := range data {
+		if err := writer.Write(record); err != nil {
+			t.Fatalf("Error writing record to CSV: %v", err)
+		}
+	}
+
+	writer.Flush()
+
+	copier, err := NewCopier(connStr, "test-db", "public", "metrics", "CSV", ",", "", "", "device_id,label,value", false, 1, 1, 0, 2, true, 0, false)
+	require.NoError(t, err)
+	reader, err := os.Open(tmpfile.Name())
+	require.NoError(t, err)
+	_, err = copier.Copy(context.Background(), reader)
+	assert.Error(t, err)
+	require.IsType(t, &ErrAtRow{}, err)
+	assert.EqualValues(t, 4, err.(*ErrAtRow).Row)
+}