Skip to content

Commit 4aa8b5c

Browse files
authored
Merge pull request #80 from timescale/adn/error-handling
Return error instead of panicking
2 parents 1a1d2ae + 2cc32f0 commit 4aa8b5c

File tree

8 files changed

+102
-55
lines changed

8 files changed

+102
-55
lines changed

cmd/timescaledb-parallel-copy/main.go

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
package main
33

44
import (
5+
"context"
56
"flag"
67
"fmt"
78
"io"
@@ -132,7 +133,7 @@ func main() {
132133
reader = os.Stdin
133134
}
134135

135-
result, err := copier.Copy(reader)
136+
result, err := copier.Copy(context.Background(), reader)
136137
if err != nil {
137138
log.Fatal("failed to copy CSV:", err)
138139
}

internal/batch/scan.go

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ package batch
22

33
import (
44
"bufio"
5+
"context"
56
"fmt"
67
"io"
78
"net"
@@ -27,7 +28,7 @@ type Options struct {
2728
// Scan expects the input to be in Postgres CSV format. Since this format allows
2829
// rows to be split over multiple lines, the caller may provide opts.Quote and
2930
// opts.Escape as the QUOTE and ESCAPE characters used for the CSV input.
30-
func Scan(r io.Reader, out chan<- net.Buffers, opts Options) error {
31+
func Scan(ctx context.Context, r io.Reader, out chan<- net.Buffers, opts Options) error {
3132
var rowsRead int64
3233
reader := bufio.NewReader(r)
3334

@@ -110,6 +111,9 @@ func Scan(r io.Reader, out chan<- net.Buffers, opts Options) error {
110111
}
111112

112113
if bufferedRows >= opts.Size { // dispatch to COPY worker & reset
114+
if ctx.Err() != nil {
115+
return nil
116+
}
113117
out <- bufs
114118
bufs = make(net.Buffers, 0, opts.Size)
115119
bufferedRows = 0

internal/batch/scan_internal_test.go

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package batch
22

33
import (
4+
"context"
45
"errors"
56
"fmt"
67
"os"
@@ -227,7 +228,7 @@ func TestCSVRowState(t *testing.T) {
227228
allLines := strings.Join(c.input, "")
228229
copyCmd := fmt.Sprintf(`COPY csv FROM STDIN WITH %s`, copyOpts)
229230

230-
num, err := db.CopyFromLines(d, strings.NewReader(allLines), copyCmd)
231+
num, err := db.CopyFromLines(context.Background(), d, strings.NewReader(allLines), copyCmd)
231232

232233
if c.expectMore {
233234
// If our test case claimed to be unterminated, then the DB

internal/batch/scan_test.go

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ package batch_test
22

33
import (
44
"bytes"
5+
"context"
56
"errors"
67
"fmt"
78
"io"
@@ -262,7 +263,7 @@ d"
262263
Escape: byte(c.escape),
263264
}
264265

265-
err := batch.Scan(reader, rowChan, opts)
266+
err := batch.Scan(context.Background(), reader, rowChan, opts)
266267
if err != nil {
267268
t.Fatalf("Scan() returned error: %v", err)
268269
}
@@ -307,7 +308,7 @@ d"
307308
Skip: c.skip,
308309
}
309310

310-
err := batch.Scan(reader, rowChan, opts)
311+
err := batch.Scan(context.Background(), reader, rowChan, opts)
311312
if !errors.Is(err, expected) {
312313
t.Errorf("Scan() returned unexpected error: %v", err)
313314
t.Logf("want: %v", expected)
@@ -416,7 +417,7 @@ func BenchmarkScan(b *testing.B) {
416417
for i := 0; i < b.N; i++ {
417418
reader.Reset(data) // rewind to the beginning
418419

419-
err := batch.Scan(reader, rowChan, opts)
420+
err := batch.Scan(context.Background(), reader, rowChan, opts)
420421
if err != nil {
421422
b.Errorf("Scan() returned unexpected error: %v", err)
422423
}

internal/db/db.go

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -156,8 +156,8 @@ func Connect(connStr string, overrides ...Overrideable) (*sqlx.DB, error) {
156156
// CopyFromLines bulk-loads data using the given copyCmd. lines must provide a
157157
// set of complete lines of CSV data, including the end-of-line delimiters.
158158
// Returns the number of rows inserted.
159-
func CopyFromLines(db *sqlx.DB, lines io.Reader, copyCmd string) (int64, error) {
160-
conn, err := db.Conn(context.Background())
159+
func CopyFromLines(ctx context.Context, db *sqlx.DB, lines io.Reader, copyCmd string) (int64, error) {
160+
conn, err := db.Conn(ctx)
161161
if err != nil {
162162
return 0, fmt.Errorf("acquiring DB connection for COPY: %w", err)
163163
}
@@ -171,7 +171,7 @@ func CopyFromLines(db *sqlx.DB, lines io.Reader, copyCmd string) (int64, error)
171171
// the pgx.Conn, and the pgconn.PgConn.
172172
pg := driverConn.(*stdlib.Conn).Conn().PgConn()
173173

174-
result, err := pg.CopyFrom(context.Background(), lines, copyCmd)
174+
result, err := pg.CopyFrom(ctx, lines, copyCmd)
175175
if err != nil {
176176
return err
177177
}

internal/db/db_test.go

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ package db_test
22

33
import (
44
"bytes"
5+
"context"
56
"errors"
67
"os"
78
"reflect"
@@ -153,7 +154,7 @@ func TestCopyFromLines(t *testing.T) {
153154

154155
// Load the rows into it.
155156
allLines := strings.Join(append(c.lines, ""), "\n")
156-
num, err := db.CopyFromLines(d, strings.NewReader(allLines), c.copyCmd)
157+
num, err := db.CopyFromLines(context.Background(), d, strings.NewReader(allLines), c.copyCmd)
157158
if err != nil {
158159
t.Errorf("CopyFromLines() returned error: %v", err)
159160
}
@@ -198,7 +199,7 @@ func TestCopyFromLines(t *testing.T) {
198199
lines := bytes.Repeat([]byte{'\n'}, 10000)
199200
badCopy := `COPY BUT NOT REALLY`
200201

201-
num, err := db.CopyFromLines(d, bytes.NewReader(lines), badCopy)
202+
num, err := db.CopyFromLines(context.Background(), d, bytes.NewReader(lines), badCopy)
202203
if num != 0 {
203204
t.Errorf("CopyFromLines() reported %d new rows, want 0", num)
204205
}
@@ -249,7 +250,7 @@ func TestCopyFromLines(t *testing.T) {
249250
}
250251

251252
lineData := strings.Join(append(lines, ""), "\n")
252-
_, err := db.CopyFromLines(d, strings.NewReader(lineData), cmd)
253+
_, err := db.CopyFromLines(context.Background(), d, strings.NewReader(lineData), cmd)
253254
if err != nil {
254255
t.Fatalf("CopyFromLines() returned error: %v", err)
255256
}

pkg/csvcopy/csvcopy.go

Lines changed: 81 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package csvcopy
22

33
import (
4+
"context"
45
"errors"
56
"fmt"
67
"io"
@@ -152,19 +153,31 @@ func (c *Copier) Truncate() (err error) {
152153
return err
153154
}
154155

155-
func (c *Copier) Copy(reader io.Reader) (Result, error) {
156+
func (c *Copier) Copy(ctx context.Context, reader io.Reader) (Result, error) {
156157
var wg sync.WaitGroup
157158
batchChan := make(chan net.Buffers, c.workers*2)
158159

160+
ctx, cancel := context.WithCancel(ctx)
161+
defer cancel()
162+
errCh := make(chan error, c.workers+1)
163+
159164
// Generate COPY workers
160165
for i := 0; i < c.workers; i++ {
161166
wg.Add(1)
162-
go c.processBatches(&wg, batchChan)
167+
go func() {
168+
defer wg.Done()
169+
err := c.processBatches(ctx, batchChan)
170+
if err != nil {
171+
errCh <- err
172+
cancel()
173+
}
174+
}()
175+
163176
}
164177

165178
// Reporting thread
166179
if c.reportingPeriod > (0 * time.Second) {
167-
go c.report()
180+
go c.report(ctx)
168181
}
169182

170183
opts := batch.Options{
@@ -183,12 +196,21 @@ func (c *Copier) Copy(reader io.Reader) (Result, error) {
183196
}
184197

185198
start := time.Now()
186-
if err := batch.Scan(reader, batchChan, opts); err != nil {
187-
return Result{}, fmt.Errorf("failed reading input: %w", err)
188-
}
189-
190-
close(batchChan)
199+
wg.Add(1)
200+
go func() {
201+
defer wg.Done()
202+
if err := batch.Scan(ctx, reader, batchChan, opts); err != nil {
203+
errCh <- fmt.Errorf("failed reading input: %w", err)
204+
cancel()
205+
}
206+
close(batchChan)
207+
}()
191208
wg.Wait()
209+
close(errCh)
210+
// We are only interested on the first error message since all other errors
211+
// must probably are related to the context being canceled.
212+
err := <-errCh
213+
192214
end := time.Now()
193215
took := end.Sub(start)
194216

@@ -199,15 +221,15 @@ func (c *Copier) Copy(reader io.Reader) (Result, error) {
199221
RowsRead: rowsRead,
200222
Duration: took,
201223
RowRate: rowRate,
202-
}, nil
224+
}, err
203225
}
204226

205227
// processBatches reads batches from channel c and copies them to the target
206228
// server while tracking stats on the write.
207-
func (c *Copier) processBatches(wg *sync.WaitGroup, ch chan net.Buffers) {
229+
func (c *Copier) processBatches(ctx context.Context, ch chan net.Buffers) (err error) {
208230
dbx, err := db.Connect(c.dbURL, c.overrides...)
209231
if err != nil {
210-
panic(err)
232+
return err
211233
}
212234
defer dbx.Close()
213235

@@ -233,46 +255,63 @@ func (c *Copier) processBatches(wg *sync.WaitGroup, ch chan net.Buffers) {
233255
copyCmd = fmt.Sprintf("COPY %s FROM STDIN WITH DELIMITER %s %s %s", c.getFullTableName(), delimStr, quotes, c.copyOptions)
234256
}
235257

236-
for batch := range ch {
237-
start := time.Now()
238-
rows, err := db.CopyFromLines(dbx, &batch, copyCmd)
239-
if err != nil {
240-
panic(err)
258+
for {
259+
if ctx.Err() != nil {
260+
return nil
241261
}
242-
atomic.AddInt64(&c.rowCount, rows)
243-
244-
if c.logBatches {
245-
took := time.Since(start)
246-
fmt.Printf("[BATCH] took %v, batch size %d, row rate %f/sec\n", took, c.batchSize, float64(c.batchSize)/float64(took.Seconds()))
262+
select {
263+
case <-ctx.Done():
264+
return nil
265+
case batch, ok := <-ch:
266+
if !ok {
267+
return
268+
}
269+
start := time.Now()
270+
rows, err := db.CopyFromLines(ctx, dbx, &batch, copyCmd)
271+
if err != nil {
272+
return err
273+
}
274+
atomic.AddInt64(&c.rowCount, rows)
275+
276+
if c.logBatches {
277+
took := time.Since(start)
278+
fmt.Printf("[BATCH] took %v, batch size %d, row rate %f/sec\n", took, c.batchSize, float64(c.batchSize)/float64(took.Seconds()))
279+
}
247280
}
248281
}
249-
wg.Done()
250282
}
251283

252284
// report periodically prints the write rate in number of rows per second
253-
func (c *Copier) report() {
285+
func (c *Copier) report(ctx context.Context) {
254286
start := time.Now()
255287
prevTime := start
256288
prevRowCount := int64(0)
257-
258-
for now := range time.NewTicker(c.reportingPeriod).C {
259-
rCount := atomic.LoadInt64(&c.rowCount)
260-
261-
took := now.Sub(prevTime)
262-
rowrate := float64(rCount-prevRowCount) / float64(took.Seconds())
263-
overallRowrate := float64(rCount) / float64(now.Sub(start).Seconds())
264-
totalTook := now.Sub(start)
265-
266-
c.logger.Infof(
267-
"at %v, row rate %0.2f/sec (period), row rate %0.2f/sec (overall), %E total rows",
268-
totalTook-(totalTook%time.Second),
269-
rowrate,
270-
overallRowrate,
271-
float64(rCount),
272-
)
273-
274-
prevRowCount = rCount
275-
prevTime = now
289+
ticker := time.NewTicker(c.reportingPeriod)
290+
defer ticker.Stop()
291+
292+
for {
293+
select {
294+
case now := <-ticker.C:
295+
rCount := atomic.LoadInt64(&c.rowCount)
296+
297+
took := now.Sub(prevTime)
298+
rowrate := float64(rCount-prevRowCount) / float64(took.Seconds())
299+
overallRowrate := float64(rCount) / float64(now.Sub(start).Seconds())
300+
totalTook := now.Sub(start)
301+
302+
c.logger.Infof(
303+
"at %v, row rate %0.2f/sec (period), row rate %0.2f/sec (overall), %E total rows",
304+
totalTook-(totalTook%time.Second),
305+
rowrate,
306+
overallRowrate,
307+
float64(rCount),
308+
)
309+
310+
prevRowCount = rCount
311+
prevTime = now
312+
case <-ctx.Done():
313+
return
314+
}
276315
}
277316
}
278317

pkg/csvcopy/csvcopy_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,7 @@ func TestWriteDataToCSV(t *testing.T) {
7272

7373
reader, err := os.Open(tmpfile.Name())
7474
require.NoError(t, err)
75-
r, err := copier.Copy(reader)
75+
r, err := copier.Copy(context.Background(), reader)
7676
require.NoError(t, err)
7777
require.NotNil(t, r)
7878

0 commit comments

Comments
 (0)