From 08768609f0357c9e2fc8ff0a08684ea278f26a62 Mon Sep 17 00:00:00 2001 From: Marco Nicola Date: Mon, 17 Oct 2022 20:10:26 +0200 Subject: [PATCH] Return a custom error from BART text2text if sequence is too long instead of panicking --- pkg/tasks/text2text/bart/text2text.go | 5 +++++ pkg/tasks/text2text/text2text.go | 5 +++++ 2 files changed, 10 insertions(+) diff --git a/pkg/tasks/text2text/bart/text2text.go b/pkg/tasks/text2text/bart/text2text.go index 2e89431..5178e3d 100644 --- a/pkg/tasks/text2text/bart/text2text.go +++ b/pkg/tasks/text2text/bart/text2text.go @@ -135,6 +135,11 @@ func (m *Text2Text) Generate(ctx context.Context, text string, opts *text2text.O if err != nil { return text2text.Response{}, err } + + if l, max := len(tokenized), m.Model.Bart.Config.MaxLength; l > max { + return text2text.Response{}, fmt.Errorf("%w: %d > %d", text2text.ErrInputSequenceTooLong, l, max) + } + sequences, scores := m.process(ctx, tokenized, *opts) result := text2text.Response{ Texts: make([]string, len(sequences)), diff --git a/pkg/tasks/text2text/text2text.go b/pkg/tasks/text2text/text2text.go index ffe12bf..467ec52 100644 --- a/pkg/tasks/text2text/text2text.go +++ b/pkg/tasks/text2text/text2text.go @@ -6,6 +6,7 @@ package text2text import ( "context" + "errors" "fmt" "strings" @@ -76,6 +77,10 @@ type Response struct { Scores []float64 } +// ErrInputSequenceTooLong means that pre-processing the input text +// produced a sequence that exceeds the maximum allowed length. +var ErrInputSequenceTooLong = errors.New("sequence too long") + // DefaultOptions returns the default options for generating text. func DefaultOptions() *Options { return &Options{