Skip to content

Commit

Permalink
Merge pull request #80 from timescale/adn/error-handling
Browse files Browse the repository at this point in the history
Return error instead of panicking
  • Loading branch information
alejandrodnm authored Jul 3, 2024
2 parents 1a1d2ae + 2cc32f0 commit 4aa8b5c
Show file tree
Hide file tree
Showing 8 changed files with 102 additions and 55 deletions.
3 changes: 2 additions & 1 deletion cmd/timescaledb-parallel-copy/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
package main

import (
"context"
"flag"
"fmt"
"io"
Expand Down Expand Up @@ -132,7 +133,7 @@ func main() {
reader = os.Stdin
}

result, err := copier.Copy(reader)
result, err := copier.Copy(context.Background(), reader)
if err != nil {
log.Fatal("failed to copy CSV:", err)
}
Expand Down
6 changes: 5 additions & 1 deletion internal/batch/scan.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package batch

import (
"bufio"
"context"
"fmt"
"io"
"net"
Expand All @@ -27,7 +28,7 @@ type Options struct {
// Scan expects the input to be in Postgres CSV format. Since this format allows
// rows to be split over multiple lines, the caller may provide opts.Quote and
// opts.Escape as the QUOTE and ESCAPE characters used for the CSV input.
func Scan(r io.Reader, out chan<- net.Buffers, opts Options) error {
func Scan(ctx context.Context, r io.Reader, out chan<- net.Buffers, opts Options) error {
var rowsRead int64
reader := bufio.NewReader(r)

Expand Down Expand Up @@ -110,6 +111,9 @@ func Scan(r io.Reader, out chan<- net.Buffers, opts Options) error {
}

if bufferedRows >= opts.Size { // dispatch to COPY worker & reset
if ctx.Err() != nil {
return nil
}
out <- bufs
bufs = make(net.Buffers, 0, opts.Size)
bufferedRows = 0
Expand Down
3 changes: 2 additions & 1 deletion internal/batch/scan_internal_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package batch

import (
"context"
"errors"
"fmt"
"os"
Expand Down Expand Up @@ -227,7 +228,7 @@ func TestCSVRowState(t *testing.T) {
allLines := strings.Join(c.input, "")
copyCmd := fmt.Sprintf(`COPY csv FROM STDIN WITH %s`, copyOpts)

num, err := db.CopyFromLines(d, strings.NewReader(allLines), copyCmd)
num, err := db.CopyFromLines(context.Background(), d, strings.NewReader(allLines), copyCmd)

if c.expectMore {
// If our test case claimed to be unterminated, then the DB
Expand Down
7 changes: 4 additions & 3 deletions internal/batch/scan_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package batch_test

import (
"bytes"
"context"
"errors"
"fmt"
"io"
Expand Down Expand Up @@ -262,7 +263,7 @@ d"
Escape: byte(c.escape),
}

err := batch.Scan(reader, rowChan, opts)
err := batch.Scan(context.Background(), reader, rowChan, opts)
if err != nil {
t.Fatalf("Scan() returned error: %v", err)
}
Expand Down Expand Up @@ -307,7 +308,7 @@ d"
Skip: c.skip,
}

err := batch.Scan(reader, rowChan, opts)
err := batch.Scan(context.Background(), reader, rowChan, opts)
if !errors.Is(err, expected) {
t.Errorf("Scan() returned unexpected error: %v", err)
t.Logf("want: %v", expected)
Expand Down Expand Up @@ -416,7 +417,7 @@ func BenchmarkScan(b *testing.B) {
for i := 0; i < b.N; i++ {
reader.Reset(data) // rewind to the beginning

err := batch.Scan(reader, rowChan, opts)
err := batch.Scan(context.Background(), reader, rowChan, opts)
if err != nil {
b.Errorf("Scan() returned unexpected error: %v", err)
}
Expand Down
6 changes: 3 additions & 3 deletions internal/db/db.go
Original file line number Diff line number Diff line change
Expand Up @@ -156,8 +156,8 @@ func Connect(connStr string, overrides ...Overrideable) (*sqlx.DB, error) {
// CopyFromLines bulk-loads data using the given copyCmd. lines must provide a
// set of complete lines of CSV data, including the end-of-line delimiters.
// Returns the number of rows inserted.
func CopyFromLines(db *sqlx.DB, lines io.Reader, copyCmd string) (int64, error) {
conn, err := db.Conn(context.Background())
func CopyFromLines(ctx context.Context, db *sqlx.DB, lines io.Reader, copyCmd string) (int64, error) {
conn, err := db.Conn(ctx)
if err != nil {
return 0, fmt.Errorf("acquiring DB connection for COPY: %w", err)
}
Expand All @@ -171,7 +171,7 @@ func CopyFromLines(db *sqlx.DB, lines io.Reader, copyCmd string) (int64, error)
// the pgx.Conn, and the pgconn.PgConn.
pg := driverConn.(*stdlib.Conn).Conn().PgConn()

result, err := pg.CopyFrom(context.Background(), lines, copyCmd)
result, err := pg.CopyFrom(ctx, lines, copyCmd)
if err != nil {
return err
}
Expand Down
7 changes: 4 additions & 3 deletions internal/db/db_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package db_test

import (
"bytes"
"context"
"errors"
"os"
"reflect"
Expand Down Expand Up @@ -153,7 +154,7 @@ func TestCopyFromLines(t *testing.T) {

// Load the rows into it.
allLines := strings.Join(append(c.lines, ""), "\n")
num, err := db.CopyFromLines(d, strings.NewReader(allLines), c.copyCmd)
num, err := db.CopyFromLines(context.Background(), d, strings.NewReader(allLines), c.copyCmd)
if err != nil {
t.Errorf("CopyFromLines() returned error: %v", err)
}
Expand Down Expand Up @@ -198,7 +199,7 @@ func TestCopyFromLines(t *testing.T) {
lines := bytes.Repeat([]byte{'\n'}, 10000)
badCopy := `COPY BUT NOT REALLY`

num, err := db.CopyFromLines(d, bytes.NewReader(lines), badCopy)
num, err := db.CopyFromLines(context.Background(), d, bytes.NewReader(lines), badCopy)
if num != 0 {
t.Errorf("CopyFromLines() reported %d new rows, want 0", num)
}
Expand Down Expand Up @@ -249,7 +250,7 @@ func TestCopyFromLines(t *testing.T) {
}

lineData := strings.Join(append(lines, ""), "\n")
_, err := db.CopyFromLines(d, strings.NewReader(lineData), cmd)
_, err := db.CopyFromLines(context.Background(), d, strings.NewReader(lineData), cmd)
if err != nil {
t.Fatalf("CopyFromLines() returned error: %v", err)
}
Expand Down
123 changes: 81 additions & 42 deletions pkg/csvcopy/csvcopy.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package csvcopy

import (
"context"
"errors"
"fmt"
"io"
Expand Down Expand Up @@ -152,19 +153,31 @@ func (c *Copier) Truncate() (err error) {
return err
}

func (c *Copier) Copy(reader io.Reader) (Result, error) {
func (c *Copier) Copy(ctx context.Context, reader io.Reader) (Result, error) {
var wg sync.WaitGroup
batchChan := make(chan net.Buffers, c.workers*2)

ctx, cancel := context.WithCancel(ctx)
defer cancel()
errCh := make(chan error, c.workers+1)

// Generate COPY workers
for i := 0; i < c.workers; i++ {
wg.Add(1)
go c.processBatches(&wg, batchChan)
go func() {
defer wg.Done()
err := c.processBatches(ctx, batchChan)
if err != nil {
errCh <- err
cancel()
}
}()

}

// Reporting thread
if c.reportingPeriod > (0 * time.Second) {
go c.report()
go c.report(ctx)
}

opts := batch.Options{
Expand All @@ -183,12 +196,21 @@ func (c *Copier) Copy(reader io.Reader) (Result, error) {
}

start := time.Now()
if err := batch.Scan(reader, batchChan, opts); err != nil {
return Result{}, fmt.Errorf("failed reading input: %w", err)
}

close(batchChan)
wg.Add(1)
go func() {
defer wg.Done()
if err := batch.Scan(ctx, reader, batchChan, opts); err != nil {
errCh <- fmt.Errorf("failed reading input: %w", err)
cancel()
}
close(batchChan)
}()
wg.Wait()
close(errCh)
// We are only interested in the first error message, since all other errors
// are most probably related to the context being canceled.
err := <-errCh

end := time.Now()
took := end.Sub(start)

Expand All @@ -199,15 +221,15 @@ func (c *Copier) Copy(reader io.Reader) (Result, error) {
RowsRead: rowsRead,
Duration: took,
RowRate: rowRate,
}, nil
}, err
}

// processBatches reads batches from channel ch and copies them to the target
// server while tracking stats on the write.
func (c *Copier) processBatches(wg *sync.WaitGroup, ch chan net.Buffers) {
func (c *Copier) processBatches(ctx context.Context, ch chan net.Buffers) (err error) {
dbx, err := db.Connect(c.dbURL, c.overrides...)
if err != nil {
panic(err)
return err
}
defer dbx.Close()

Expand All @@ -233,46 +255,63 @@ func (c *Copier) processBatches(wg *sync.WaitGroup, ch chan net.Buffers) {
copyCmd = fmt.Sprintf("COPY %s FROM STDIN WITH DELIMITER %s %s %s", c.getFullTableName(), delimStr, quotes, c.copyOptions)
}

for batch := range ch {
start := time.Now()
rows, err := db.CopyFromLines(dbx, &batch, copyCmd)
if err != nil {
panic(err)
for {
if ctx.Err() != nil {
return nil
}
atomic.AddInt64(&c.rowCount, rows)

if c.logBatches {
took := time.Since(start)
fmt.Printf("[BATCH] took %v, batch size %d, row rate %f/sec\n", took, c.batchSize, float64(c.batchSize)/float64(took.Seconds()))
select {
case <-ctx.Done():
return nil
case batch, ok := <-ch:
if !ok {
return
}
start := time.Now()
rows, err := db.CopyFromLines(ctx, dbx, &batch, copyCmd)
if err != nil {
return err
}
atomic.AddInt64(&c.rowCount, rows)

if c.logBatches {
took := time.Since(start)
fmt.Printf("[BATCH] took %v, batch size %d, row rate %f/sec\n", took, c.batchSize, float64(c.batchSize)/float64(took.Seconds()))
}
}
}
wg.Done()
}

// report periodically prints the write rate in number of rows per second
func (c *Copier) report() {
func (c *Copier) report(ctx context.Context) {
start := time.Now()
prevTime := start
prevRowCount := int64(0)

for now := range time.NewTicker(c.reportingPeriod).C {
rCount := atomic.LoadInt64(&c.rowCount)

took := now.Sub(prevTime)
rowrate := float64(rCount-prevRowCount) / float64(took.Seconds())
overallRowrate := float64(rCount) / float64(now.Sub(start).Seconds())
totalTook := now.Sub(start)

c.logger.Infof(
"at %v, row rate %0.2f/sec (period), row rate %0.2f/sec (overall), %E total rows",
totalTook-(totalTook%time.Second),
rowrate,
overallRowrate,
float64(rCount),
)

prevRowCount = rCount
prevTime = now
ticker := time.NewTicker(c.reportingPeriod)
defer ticker.Stop()

for {
select {
case now := <-ticker.C:
rCount := atomic.LoadInt64(&c.rowCount)

took := now.Sub(prevTime)
rowrate := float64(rCount-prevRowCount) / float64(took.Seconds())
overallRowrate := float64(rCount) / float64(now.Sub(start).Seconds())
totalTook := now.Sub(start)

c.logger.Infof(
"at %v, row rate %0.2f/sec (period), row rate %0.2f/sec (overall), %E total rows",
totalTook-(totalTook%time.Second),
rowrate,
overallRowrate,
float64(rCount),
)

prevRowCount = rCount
prevTime = now
case <-ctx.Done():
return
}
}
}

Expand Down
2 changes: 1 addition & 1 deletion pkg/csvcopy/csvcopy_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ func TestWriteDataToCSV(t *testing.T) {

reader, err := os.Open(tmpfile.Name())
require.NoError(t, err)
r, err := copier.Copy(reader)
r, err := copier.Copy(context.Background(), reader)
require.NoError(t, err)
require.NotNil(t, r)

Expand Down

0 comments on commit 4aa8b5c

Please sign in to comment.