diff --git a/cmd/timescaledb-parallel-copy/main.go b/cmd/timescaledb-parallel-copy/main.go
index 07959e5..becc021 100644
--- a/cmd/timescaledb-parallel-copy/main.go
+++ b/cmd/timescaledb-parallel-copy/main.go
@@ -2,6 +2,7 @@
 package main
 
 import (
+	"context"
 	"flag"
 	"fmt"
 	"io"
@@ -132,7 +133,7 @@ func main() {
 		reader = os.Stdin
 	}
 
-	result, err := copier.Copy(reader)
+	result, err := copier.Copy(context.Background(), reader)
 	if err != nil {
 		log.Fatal("failed to copy CSV:", err)
 	}
diff --git a/internal/batch/scan.go b/internal/batch/scan.go
index 8e15a01..76f0453 100644
--- a/internal/batch/scan.go
+++ b/internal/batch/scan.go
@@ -2,6 +2,7 @@ package batch
 
 import (
 	"bufio"
+	"context"
 	"fmt"
 	"io"
 	"net"
@@ -27,7 +28,7 @@ type Options struct {
 // Scan expects the input to be in Postgres CSV format. Since this format allows
 // rows to be split over multiple lines, the caller may provide opts.Quote and
 // opts.Escape as the QUOTE and ESCAPE characters used for the CSV input.
-func Scan(r io.Reader, out chan<- net.Buffers, opts Options) error {
+func Scan(ctx context.Context, r io.Reader, out chan<- net.Buffers, opts Options) error {
 	var rowsRead int64
 	reader := bufio.NewReader(r)
 
@@ -110,6 +111,9 @@ func Scan(ctx context.Context, r io.Reader, out chan<- net.Buffers, opts Options) error {
 		}
 
 		if bufferedRows >= opts.Size { // dispatch to COPY worker & reset
+			if ctx.Err() != nil {
+				return nil
+			}
 			out <- bufs
 			bufs = make(net.Buffers, 0, opts.Size)
 			bufferedRows = 0
diff --git a/internal/batch/scan_internal_test.go b/internal/batch/scan_internal_test.go
index 1fce4c9..2d3adf0 100644
--- a/internal/batch/scan_internal_test.go
+++ b/internal/batch/scan_internal_test.go
@@ -1,6 +1,7 @@
 package batch
 
 import (
+	"context"
 	"errors"
 	"fmt"
 	"os"
@@ -227,7 +228,7 @@ func TestCSVRowState(t *testing.T) {
 			allLines := strings.Join(c.input, "")
 
 			copyCmd := fmt.Sprintf(`COPY csv FROM STDIN WITH %s`, copyOpts)
-			num, err := db.CopyFromLines(d, strings.NewReader(allLines), copyCmd)
+			num, err := db.CopyFromLines(context.Background(), d, strings.NewReader(allLines), copyCmd)
 
 			if c.expectMore {
 				// If our test case claimed to be unterminated, then the DB
diff --git a/internal/batch/scan_test.go b/internal/batch/scan_test.go
index cc624b6..cef7341 100644
--- a/internal/batch/scan_test.go
+++ b/internal/batch/scan_test.go
@@ -2,6 +2,7 @@ package batch_test
 
 import (
 	"bytes"
+	"context"
 	"errors"
 	"fmt"
 	"io"
@@ -262,7 +263,7 @@ d"
 				Escape: byte(c.escape),
 			}
 
-			err := batch.Scan(reader, rowChan, opts)
+			err := batch.Scan(context.Background(), reader, rowChan, opts)
 			if err != nil {
 				t.Fatalf("Scan() returned error: %v", err)
 			}
@@ -307,7 +308,7 @@ d"
 				Skip: c.skip,
 			}
 
-			err := batch.Scan(reader, rowChan, opts)
+			err := batch.Scan(context.Background(), reader, rowChan, opts)
 			if !errors.Is(err, expected) {
 				t.Errorf("Scan() returned unexpected error: %v", err)
 				t.Logf("want: %v", expected)
@@ -416,7 +417,7 @@ func BenchmarkScan(b *testing.B) {
 	for i := 0; i < b.N; i++ {
 		reader.Reset(data) // rewind to the beginning
 
-		err := batch.Scan(reader, rowChan, opts)
+		err := batch.Scan(context.Background(), reader, rowChan, opts)
 		if err != nil {
 			b.Errorf("Scan() returned unexpected error: %v", err)
 		}
diff --git a/internal/db/db.go b/internal/db/db.go
index aa2c505..2604169 100644
--- a/internal/db/db.go
+++ b/internal/db/db.go
@@ -156,8 +156,8 @@ func Connect(connStr string, overrides ...Overrideable) (*sqlx.DB, error) {
 // CopyFromLines bulk-loads data using the given copyCmd. lines must provide a
 // set of complete lines of CSV data, including the end-of-line delimiters.
 // Returns the number of rows inserted.
-func CopyFromLines(db *sqlx.DB, lines io.Reader, copyCmd string) (int64, error) {
-	conn, err := db.Conn(context.Background())
+func CopyFromLines(ctx context.Context, db *sqlx.DB, lines io.Reader, copyCmd string) (int64, error) {
+	conn, err := db.Conn(ctx)
 	if err != nil {
 		return 0, fmt.Errorf("acquiring DB connection for COPY: %w", err)
 	}
@@ -171,7 +171,7 @@ func CopyFromLines(db *sqlx.DB, lines io.Reader, copyCmd string) (int64, error)
 		// the pgx.Conn, and the pgconn.PgConn.
 		pg := driverConn.(*stdlib.Conn).Conn().PgConn()
 
-		result, err := pg.CopyFrom(context.Background(), lines, copyCmd)
+		result, err := pg.CopyFrom(ctx, lines, copyCmd)
 		if err != nil {
 			return err
 		}
diff --git a/internal/db/db_test.go b/internal/db/db_test.go
index e6058f3..b4932ea 100644
--- a/internal/db/db_test.go
+++ b/internal/db/db_test.go
@@ -2,6 +2,7 @@ package db_test
 
 import (
 	"bytes"
+	"context"
 	"errors"
 	"os"
 	"reflect"
@@ -153,7 +154,7 @@ func TestCopyFromLines(t *testing.T) {
 			// Load the rows into it.
 			allLines := strings.Join(append(c.lines, ""), "\n")
 
-			num, err := db.CopyFromLines(d, strings.NewReader(allLines), c.copyCmd)
+			num, err := db.CopyFromLines(context.Background(), d, strings.NewReader(allLines), c.copyCmd)
 			if err != nil {
 				t.Errorf("CopyFromLines() returned error: %v", err)
 			}
@@ -198,7 +199,7 @@ func TestCopyFromLines(t *testing.T) {
 		lines := bytes.Repeat([]byte{'\n'}, 10000)
 		badCopy := `COPY BUT NOT REALLY`
 
-		num, err := db.CopyFromLines(d, bytes.NewReader(lines), badCopy)
+		num, err := db.CopyFromLines(context.Background(), d, bytes.NewReader(lines), badCopy)
 		if num != 0 {
 			t.Errorf("CopyFromLines() reported %d new rows, want 0", num)
 		}
@@ -249,7 +250,7 @@ func TestCopyFromLines(t *testing.T) {
 		}
 		lineData := strings.Join(append(lines, ""), "\n")
 
-		_, err := db.CopyFromLines(d, strings.NewReader(lineData), cmd)
+		_, err := db.CopyFromLines(context.Background(), d, strings.NewReader(lineData), cmd)
 		if err != nil {
 			t.Fatalf("CopyFromLines() returned error: %v", err)
 		}
diff --git a/pkg/csvcopy/csvcopy.go b/pkg/csvcopy/csvcopy.go
index c91d55e..98a53ff 100644
--- a/pkg/csvcopy/csvcopy.go
+++ b/pkg/csvcopy/csvcopy.go
@@ -1,6 +1,7 @@
 package csvcopy
 
 import (
+	"context"
 	"errors"
 	"fmt"
 	"io"
@@ -152,19 +153,31 @@ func (c *Copier) Truncate() (err error) {
 	return err
 }
 
-func (c *Copier) Copy(reader io.Reader) (Result, error) {
+func (c *Copier) Copy(ctx context.Context, reader io.Reader) (Result, error) {
 	var wg sync.WaitGroup
 	batchChan := make(chan net.Buffers, c.workers*2)
 
+	ctx, cancel := context.WithCancel(ctx)
+	defer cancel()
+	errCh := make(chan error, c.workers+1)
+
 	// Generate COPY workers
 	for i := 0; i < c.workers; i++ {
 		wg.Add(1)
-		go c.processBatches(&wg, batchChan)
+		go func() {
+			defer wg.Done()
+			err := c.processBatches(ctx, batchChan)
+			if err != nil {
+				errCh <- err
+				cancel()
+			}
+		}()
+
 	}
 
 	// Reporting thread
 	if c.reportingPeriod > (0 * time.Second) {
-		go c.report()
+		go c.report(ctx)
 	}
 
 	opts := batch.Options{
@@ -183,12 +196,21 @@ func (c *Copier) Copy(reader io.Reader) (Result, error) {
 	}
 
 	start := time.Now()
-	if err := batch.Scan(reader, batchChan, opts); err != nil {
-		return Result{}, fmt.Errorf("failed reading input: %w", err)
-	}
-
-	close(batchChan)
+	wg.Add(1)
+	go func() {
+		defer wg.Done()
+		if err := batch.Scan(ctx, reader, batchChan, opts); err != nil {
+			errCh <- fmt.Errorf("failed reading input: %w", err)
+			cancel()
+		}
+		close(batchChan)
+	}()
 	wg.Wait()
+	close(errCh)
+	// We are only interested in the first error message since all other errors
+	// are most likely related to the context being canceled.
+	err := <-errCh
+
 	end := time.Now()
 	took := end.Sub(start)
@@ -199,15 +221,15 @@ func (c *Copier) Copy(reader io.Reader) (Result, error) {
 		RowsRead: rowsRead,
 		Duration: took,
 		RowRate:  rowRate,
-	}, nil
+	}, err
 }
 
 // processBatches reads batches from channel c and copies them to the target
 // server while tracking stats on the write.
-func (c *Copier) processBatches(wg *sync.WaitGroup, ch chan net.Buffers) {
+func (c *Copier) processBatches(ctx context.Context, ch chan net.Buffers) (err error) {
 	dbx, err := db.Connect(c.dbURL, c.overrides...)
 	if err != nil {
-		panic(err)
+		return err
 	}
 	defer dbx.Close()
 
@@ -233,46 +255,63 @@ func (c *Copier) processBatches(wg *sync.WaitGroup, ch chan net.Buffers) {
 		copyCmd = fmt.Sprintf("COPY %s FROM STDIN WITH DELIMITER %s %s %s", c.getFullTableName(), delimStr, quotes, c.copyOptions)
 	}
 
-	for batch := range ch {
-		start := time.Now()
-		rows, err := db.CopyFromLines(dbx, &batch, copyCmd)
-		if err != nil {
-			panic(err)
+	for {
+		if ctx.Err() != nil {
+			return nil
 		}
-		atomic.AddInt64(&c.rowCount, rows)
-
-		if c.logBatches {
-			took := time.Since(start)
-			fmt.Printf("[BATCH] took %v, batch size %d, row rate %f/sec\n", took, c.batchSize, float64(c.batchSize)/float64(took.Seconds()))
+		select {
+		case <-ctx.Done():
+			return nil
+		case batch, ok := <-ch:
+			if !ok {
+				return
+			}
+			start := time.Now()
+			rows, err := db.CopyFromLines(ctx, dbx, &batch, copyCmd)
+			if err != nil {
+				return err
+			}
+			atomic.AddInt64(&c.rowCount, rows)
+
+			if c.logBatches {
+				took := time.Since(start)
+				fmt.Printf("[BATCH] took %v, batch size %d, row rate %f/sec\n", took, c.batchSize, float64(c.batchSize)/float64(took.Seconds()))
+			}
 		}
 	}
-	wg.Done()
 }
 
 // report periodically prints the write rate in number of rows per second
-func (c *Copier) report() {
+func (c *Copier) report(ctx context.Context) {
 	start := time.Now()
 	prevTime := start
 	prevRowCount := int64(0)
-
-	for now := range time.NewTicker(c.reportingPeriod).C {
-		rCount := atomic.LoadInt64(&c.rowCount)
-
-		took := now.Sub(prevTime)
-		rowrate := float64(rCount-prevRowCount) / float64(took.Seconds())
-		overallRowrate := float64(rCount) / float64(now.Sub(start).Seconds())
-		totalTook := now.Sub(start)
-
-		c.logger.Infof(
-			"at %v, row rate %0.2f/sec (period), row rate %0.2f/sec (overall), %E total rows",
-			totalTook-(totalTook%time.Second),
-			rowrate,
-			overallRowrate,
-			float64(rCount),
-		)
-
-		prevRowCount = rCount
-		prevTime = now
+	ticker := time.NewTicker(c.reportingPeriod)
+	defer ticker.Stop()
+
+	for {
+		select {
+		case now := <-ticker.C:
+			rCount := atomic.LoadInt64(&c.rowCount)
+
+			took := now.Sub(prevTime)
+			rowrate := float64(rCount-prevRowCount) / float64(took.Seconds())
+			overallRowrate := float64(rCount) / float64(now.Sub(start).Seconds())
+			totalTook := now.Sub(start)
+
+			c.logger.Infof(
+				"at %v, row rate %0.2f/sec (period), row rate %0.2f/sec (overall), %E total rows",
+				totalTook-(totalTook%time.Second),
+				rowrate,
+				overallRowrate,
+				float64(rCount),
+			)
+
+			prevRowCount = rCount
+			prevTime = now
+		case <-ctx.Done():
+			return
+		}
 	}
 }
 
diff --git a/pkg/csvcopy/csvcopy_test.go b/pkg/csvcopy/csvcopy_test.go
index 717eb64..300ad76 100644
--- a/pkg/csvcopy/csvcopy_test.go
+++ b/pkg/csvcopy/csvcopy_test.go
@@ -72,7 +72,7 @@ func TestWriteDataToCSV(t *testing.T) {
 	reader, err := os.Open(tmpfile.Name())
 	require.NoError(t, err)
 
-	r, err := copier.Copy(reader)
+	r, err := copier.Copy(context.Background(), reader)
 	require.NoError(t, err)
 	require.NotNil(t, r)
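Note for reviewers (not part of the patch): a minimal caller-side sketch of the new context-aware API. Only the Copy(ctx, reader) signature, the csvcopy.Copier and csvcopy.Result types come from this diff; the import path is assumed from the repository layout, and how the Copier is constructed is left out since it is not shown here. Cancelling the context, whether by signal or by deadline, propagates into batch.Scan and the COPY workers, so Copy unblocks early and returns the first reported error.

```go
package main

import (
	"context"
	"os"
	"os/signal"
	"time"

	"github.com/timescale/timescaledb-parallel-copy/pkg/csvcopy"
)

// copyWithCancel drives Copier.Copy with a context that is cancelled either
// by SIGINT or by a 30-minute deadline. On cancellation the scanner goroutine
// and the COPY workers observe ctx and wind down instead of running to completion.
func copyWithCancel(copier *csvcopy.Copier, input *os.File) (csvcopy.Result, error) {
	// Cancel on Ctrl-C.
	ctx, stop := signal.NotifyContext(context.Background(), os.Interrupt)
	defer stop()

	// Also enforce a hard upper bound on the copy duration.
	ctx, cancel := context.WithTimeout(ctx, 30*time.Minute)
	defer cancel()

	return copier.Copy(ctx, input)
}
```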