From e5006e514b81f7b906f1c05ff9b40c2871aa7f02 Mon Sep 17 00:00:00 2001
From: Juliano Viana
Date: Mon, 29 Aug 2022 11:35:12 -0400
Subject: [PATCH 1/3] first attempt at supporting the ColBERT model - the converter needs refactoring but works for now

---
 cmd/converter/converter.go    | 17 ++++++++++++++++
 pkg/converter/bert/convert.go |  8 +++++++-
 pkg/converter/bert/mapper.go  |  6 ++++++
 pkg/models/bert/colbert.go    | 37 +++++++++++++++++++++++++++++++++++
 4 files changed, 67 insertions(+), 1 deletion(-)
 create mode 100644 cmd/converter/converter.go
 create mode 100644 pkg/models/bert/colbert.go

diff --git a/cmd/converter/converter.go b/cmd/converter/converter.go
new file mode 100644
index 0000000..ce9ac2b
--- /dev/null
+++ b/cmd/converter/converter.go
@@ -0,0 +1,17 @@
+package main
+
+import (
+    "fmt"
+    "os"
+
+    "github.com/nlpodyssey/cybertron/pkg/converter"
+)
+
+func main() {
+    modelDir:=os.Args[1]
+    fmt.Printf("Converting model from dir %s\n", modelDir)
+    err:=converter.Convert[float32](modelDir,true)
+    if err!=nil{
+        panic(err)
+    }
+}
diff --git a/pkg/converter/bert/convert.go b/pkg/converter/bert/convert.go
index bb44479..c54998e 100644
--- a/pkg/converter/bert/convert.go
+++ b/pkg/converter/bert/convert.go
@@ -101,6 +101,7 @@ func Convert[T float.DType](modelDir string, overwriteIfExist bool) error {
     bertForSequenceClassification := bert.NewModelForSequenceClassification[T](m)
     bertForTokenClassification := bert.NewModelForTokenClassification[T](m)
     bertForSequenceEncoding := bert.NewModelForSequenceEncoding(m)
+    colBert := bert.NewColbertModel[T](m)
 
     {
         source := pyParams.Pop("bert.embeddings.word_embeddings.weight")
@@ -150,7 +151,7 @@ func Convert[T float.DType](modelDir string, overwriteIfExist bool) error {
             mapTokenClassifier(bertForTokenClassification.Classifier, params)
         }
     }
-
+    mapLinear(colBert.Linear, params)
     mapping := make(map[string]*mappingParam)
     for k, v := range params {
         mapping[k] = &mappingParam{value: v, matched: false}
@@ -222,6 +223,11 @@ func Convert[T float.DType](modelDir string, overwriteIfExist bool) error {
         if err != nil {
             return err
         }
+    case "HF_ColBERT":
+        err := nn.DumpToFile(colBert, goModelFilename)
+        if err != nil {
+            return err
+        }
     default:
         panic(fmt.Errorf("bert: unsupported architecture %s", config.Architectures[0]))
     }
diff --git a/pkg/converter/bert/mapper.go b/pkg/converter/bert/mapper.go
index bdb8785..c12ca62 100644
--- a/pkg/converter/bert/mapper.go
+++ b/pkg/converter/bert/mapper.go
@@ -66,6 +66,12 @@ func mapTokenClassifier(model *linear.Model, params paramsMap) {
     params["classifier.weight"] = model.W.Value()
     params["classifier.bias"] = model.B.Value()
 }
+
+// mapLinear maps the parameters of the ColBERT dimensionality-reduction layer.
+func mapLinear(model *linear.Model, params paramsMap) {
+    params["linear.weight"] = model.W.Value()
+    params["linear.bias"] = model.B.Value()
+}
 
 // mapProjectionLayer maps the projection layer parameters.
 func mapQAClassifier(model *linear.Model, params paramsMap) {
diff --git a/pkg/models/bert/colbert.go b/pkg/models/bert/colbert.go
new file mode 100644
index 0000000..4adf513
--- /dev/null
+++ b/pkg/models/bert/colbert.go
@@ -0,0 +1,37 @@
+package bert
+
+import (
+    "encoding/gob"
+
+    "github.com/nlpodyssey/spago/ag"
+    "github.com/nlpodyssey/spago/mat/float"
+    "github.com/nlpodyssey/spago/nn"
+    "github.com/nlpodyssey/spago/nn/linear"
+)
+
+type ColbertModel struct {
+    nn.Model
+    // Bert is the fine-tuned BERT model.
+    Bert *Model
+    // Linear is the linear layer for dimensionality reduction.
+    Linear *linear.Model
+}
+
+func init() {
+    gob.Register(&ColbertModel{})
+}
+
+// NewColbertModel returns a new model for information retrieval using ColBERT.
+func NewColbertModel[T float.DType](bert *Model) *ColbertModel {
+    return &ColbertModel{
+        Bert:   bert,
+        Linear: linear.New[T](bert.Config.HiddenSize, 128),
+        // TODO: read the size of the dimensionality-reduction layer from the
+        // config (artifact-config.metadata, key: dim)
+    }
+}
+
+// Forward returns the representation for the provided tokens.
+func (m *ColbertModel) Forward(tokens []string) []ag.Node {
+    return m.Linear.Forward(m.Bert.Encode(tokens)...)
+}
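For context, a minimal sketch of how a model converted through the new HF_ColBERT branch is meant to be consumed. The model path is hypothetical, and a real caller must also attach the token-embeddings store that the underlying BERT encoder needs at inference time, exactly as the scoring task in patch 3 does:

package main

import (
    "fmt"

    "github.com/nlpodyssey/cybertron/pkg/models/bert"
    "github.com/nlpodyssey/spago/nn"
)

func main() {
    // Hypothetical path to a model produced by the converter above.
    m, err := nn.LoadFromFile[*bert.ColbertModel]("models/colbert/spago_model.bin")
    if err != nil {
        panic(err)
    }
    // NOTE: before running the model for real, the embeddings repository must
    // be attached via m.Bert.SetEmbeddings(...); see patch 3 below.
    tokens := []string{"[CLS]", "[unused0]", "hello", "world", "[SEP]"}
    vectors := m.Forward(tokens) // one reduced-dimension vector per input token
    fmt.Println(len(vectors))    // 5
}

The per-token output dimension comes from the Linear layer above (128 until it is read from the config); keeping the per-token vectors small is what makes ColBERT's token-level document index affordable.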
From fbaddb09cbe6e70487ba7d365ee414444dc12428 Mon Sep 17 00:00:00 2001
From: Juliano Viana
Date: Mon, 29 Aug 2022 17:53:12 -0400
Subject: [PATCH 2/3] refactoring bert converter to make control flow more explicit

---
 cmd/converter/converter.go | 16 +++++++++++++---
 1 file changed, 13 insertions(+), 3 deletions(-)

diff --git a/cmd/converter/converter.go b/cmd/converter/converter.go
index ce9ac2b..3af8e45 100644
--- a/cmd/converter/converter.go
+++ b/cmd/converter/converter.go
@@ -3,15 +3,25 @@ package main
 import (
     "fmt"
     "os"
+    "strconv"
 
     "github.com/nlpodyssey/cybertron/pkg/converter"
 )
 
 func main() {
-    modelDir:=os.Args[1]
+    if len(os.Args) != 3 {
+        fmt.Printf("Usage: %s <model_dir> <overwrite_if_exists>\n", os.Args[0])
+        os.Exit(1)
+    }
+    modelDir := os.Args[1]
+    overwriteIfExists, err := strconv.ParseBool(os.Args[2])
+    if err != nil {
+        fmt.Printf("Failed to parse overwrite_if_exists: %s\n", err)
+        os.Exit(1)
+    }
     fmt.Printf("Converting model from dir %s\n", modelDir)
-    err:=converter.Convert[float32](modelDir,true)
-    if err!=nil{
+    err = converter.Convert[float32](modelDir, overwriteIfExists)
+    if err != nil {
         panic(err)
     }
 }
From 402c690d1d24e289a624ba64a696fbfc841a2818 Mon Sep 17 00:00:00 2001
From: Juliano Viana
Date: Fri, 23 Sep 2022 14:42:21 -0400
Subject: [PATCH 3/3] implementing scoring task

---
 pkg/converter/bert/convert.go             | 110 ++++++++----------
 pkg/tasks/scoring/colbert/.gitignore      |   1 +
 pkg/tasks/scoring/colbert/ranking.go      | 128 ++++++++++++++++++++
 pkg/tasks/scoring/colbert/ranking_test.go |  77 ++++++++++++
 4 files changed, 251 insertions(+), 65 deletions(-)
 create mode 100644 pkg/tasks/scoring/colbert/.gitignore
 create mode 100644 pkg/tasks/scoring/colbert/ranking.go
 create mode 100644 pkg/tasks/scoring/colbert/ranking_test.go

diff --git a/pkg/converter/bert/convert.go b/pkg/converter/bert/convert.go
index c54998e..d871c5c 100644
--- a/pkg/converter/bert/convert.go
+++ b/pkg/converter/bert/convert.go
@@ -96,22 +96,16 @@ func Convert[T float.DType](modelDir string, overwriteIfExist bool) error {
         panic(err)
     }
 
-    m := bert.New[T](config, repo)
-    bertForQuestionAnswering := bert.NewModelForQuestionAnswering[T](m)
-    bertForSequenceClassification := bert.NewModelForSequenceClassification[T](m)
-    bertForTokenClassification := bert.NewModelForTokenClassification[T](m)
-    bertForSequenceEncoding := bert.NewModelForSequenceEncoding(m)
-    colBert := bert.NewColbertModel[T](m)
-
+    baseModel := bert.New[T](config, repo)
     {
         source := pyParams.Pop("bert.embeddings.word_embeddings.weight")
-        size := m.Embeddings.Tokens.Config.Size
+        size := baseModel.Embeddings.Tokens.Config.Size
         for i := 0; i < config.VocabSize; i++ {
             key, _ := vocab.Term(i)
             if len(key) == 0 {
                 continue // skip empty key
             }
-            item, _ := m.Embeddings.Tokens.Embedding(key)
+            item, _ := baseModel.Embeddings.Tokens.Embedding(key)
             item.ReplaceValue(mat.NewVecDense[T](source[i*size : (i+1)*size]))
         }
     }
@@ -120,7 +114,7 @@ func Convert[T float.DType](modelDir string, overwriteIfExist bool) error {
 
     {
         source := pyParams.Pop("bert.embeddings.position_embeddings.weight")
-        dest := m.Embeddings.Positions
+        dest := baseModel.Embeddings.Positions
         for i := 0; i < config.MaxPositionEmbeddings; i++ {
             item, _ := dest.Embedding(i)
             item.ReplaceValue(mat.NewVecDense[T](source[i*cols : (i+1)*cols]))
@@ -129,7 +123,7 @@ func Convert[T float.DType](modelDir string, overwriteIfExist bool) error {
 
     {
         source := pyParams.Pop("bert.embeddings.token_type_embeddings.weight")
-        dest := m.Embeddings.TokenTypes
+        dest := baseModel.Embeddings.TokenTypes
         for i := 0; i < config.TypeVocabSize; i++ {
             item, _ := dest.Embedding(i)
             item.ReplaceValue(mat.NewVecDense[T](source[i*cols : (i+1)*cols]))
@@ -137,21 +131,43 @@ func Convert[T float.DType](modelDir string, overwriteIfExist bool) error {
     }
 
     params := make(paramsMap)
-    mapPooler(m.Pooler, params)
-    mapEmbeddingsLayerNorm(m.Embeddings.Norm, params)
-    mapEncoderParams(m.Encoder, params)
-    mapQAClassifier(bertForQuestionAnswering.Classifier, params)
-
-    {
-        // both architectures map `classifier` params
-        switch config.Architectures[0] {
-        case "BertForSequenceClassification":
-            mapSeqClassifier(bertForSequenceClassification.Classifier, params)
-        case "BertForTokenClassification":
-            mapTokenClassifier(bertForTokenClassification.Classifier, params)
-        }
+    mapPooler(baseModel.Pooler, params)
+    mapEmbeddingsLayerNorm(baseModel.Embeddings.Norm, params)
+    mapEncoderParams(baseModel.Encoder, params)
+
+    if config.Architectures == nil {
+        config.Architectures = append(config.Architectures, "BertBase")
+    }
+
+    var finalModel any
+    switch config.Architectures[0] {
+    case "BertBase":
+        finalModel = baseModel
+    case "BertModel":
+        finalModel = bert.NewModelForSequenceEncoding(baseModel)
+    case "BertForQuestionAnswering":
+        qaModel := bert.NewModelForQuestionAnswering[T](baseModel)
+        mapQAClassifier(qaModel.Classifier, params)
+        finalModel = qaModel
+
+    case "BertForSequenceClassification":
+        scModel := bert.NewModelForSequenceClassification[T](baseModel)
+        mapSeqClassifier(scModel.Classifier, params)
+        finalModel = scModel
+
+    case "BertForTokenClassification":
+        tcModel := bert.NewModelForTokenClassification[T](baseModel)
+        mapTokenClassifier(tcModel.Classifier, params)
+        finalModel = tcModel
+
+    case "HF_ColBERT":
+        colbertModel := bert.NewColbertModel[T](baseModel)
+        mapLinear(colbertModel.Linear, params)
+        finalModel = colbertModel
+    default:
+        panic(fmt.Errorf("bert: unsupported architecture %s", config.Architectures[0]))
     }
-    mapLinear(colBert.Linear, params)
+
     mapping := make(map[string]*mappingParam)
     for k, v := range params {
         mapping[k] = &mappingParam{value: v, matched: false}
@@ -192,47 +208,11 @@ func Convert[T float.DType](modelDir string, overwriteIfExist bool) error {
     }
 
     fmt.Printf("Serializing model to \"%s\"... ", goModelFilename)
+    err = nn.DumpToFile(finalModel, goModelFilename)
+    if err != nil {
+        return err
+    }
 
-    if config.Architectures == nil {
-        config.Architectures = append(config.Architectures, "BertBase")
-    }
-    {
-        switch config.Architectures[0] {
-        case "BertBase":
-            err := nn.DumpToFile(m, goModelFilename)
-            if err != nil {
-                return err
-            }
-        case "BertModel":
-            err := nn.DumpToFile(bertForSequenceEncoding, goModelFilename)
-            if err != nil {
-                return err
-            }
-        case "BertForQuestionAnswering":
-            err := nn.DumpToFile(bertForQuestionAnswering, goModelFilename)
-            if err != nil {
-                return err
-            }
-        case "BertForSequenceClassification":
-            err := nn.DumpToFile(bertForSequenceClassification, goModelFilename)
-            if err != nil {
-                return err
-            }
-        case "BertForTokenClassification":
-            err := nn.DumpToFile(bertForTokenClassification, goModelFilename)
-            if err != nil {
-                return err
-            }
-        case "HF_ColBERT":
-            err := nn.DumpToFile(colBert, goModelFilename)
-            if err != nil {
-                return err
-            }
-        default:
-            panic(fmt.Errorf("bert: unsupported architecture %s", config.Architectures[0]))
-        }
-    }
-
     fmt.Println("Done.")
 
     return nil
diff --git a/pkg/tasks/scoring/colbert/.gitignore b/pkg/tasks/scoring/colbert/.gitignore
new file mode 100644
index 0000000..d383c56
--- /dev/null
+++ b/pkg/tasks/scoring/colbert/.gitignore
@@ -0,0 +1 @@
+testdata
diff --git a/pkg/tasks/scoring/colbert/ranking.go b/pkg/tasks/scoring/colbert/ranking.go
new file mode 100644
index 0000000..b9ddcb8
--- /dev/null
+++ b/pkg/tasks/scoring/colbert/ranking.go
@@ -0,0 +1,128 @@
+package colbert
+
+import (
+    "fmt"
+    "path"
+    "path/filepath"
+    "strings"
+
+    "github.com/nlpodyssey/cybertron/pkg/models/bert"
+    "github.com/nlpodyssey/cybertron/pkg/tokenizers"
+    "github.com/nlpodyssey/cybertron/pkg/tokenizers/wordpiecetokenizer"
+    "github.com/nlpodyssey/cybertron/pkg/vocabulary"
+    "github.com/nlpodyssey/spago/ag"
+    "github.com/nlpodyssey/spago/embeddings/store/diskstore"
+    "github.com/nlpodyssey/spago/nn"
+)
+
+const SpecialDocumentMarker = "[unused1]"
+
+const SpecialQueryMarker = "[unused0]"
+
+type DocumentScorer struct {
+    Model     *bert.ColbertModel
+    Tokenizer *wordpiecetokenizer.WordPieceTokenizer
+}
+
+func LoadDocumentScorer(modelPath string) (*DocumentScorer, error) {
+    vocab, err := vocabulary.NewFromFile(filepath.Join(modelPath, "vocab.txt"))
+    if err != nil {
+        return nil, fmt.Errorf("failed to load vocabulary: %w", err)
+    }
+
+    tokenizer := wordpiecetokenizer.New(vocab)
+
+    embeddingsRepo, err := diskstore.NewRepository(filepath.Join(modelPath, "repo"), diskstore.ReadOnlyMode)
+    if err != nil {
+        return nil, fmt.Errorf("failed to load embeddings repository: %w", err)
+    }
+
+    m, err := nn.LoadFromFile[*bert.ColbertModel](path.Join(modelPath, "spago_model.bin"))
+    if err != nil {
+        return nil, fmt.Errorf("failed to load colbert model: %w", err)
+    }
+
+    err = m.Bert.SetEmbeddings(embeddingsRepo)
+    if err != nil {
+        return nil, fmt.Errorf("failed to set embeddings: %w", err)
+    }
+    return &DocumentScorer{
+        Model:     m,
+        Tokenizer: tokenizer,
+    }, nil
+}
+
+func (r *DocumentScorer) encode(text string, specialMarker string) []ag.Node {
+    tokens := r.Tokenizer.Tokenize(strings.ToLower(text))
+
+    stringTokens := tokenizers.GetStrings(tokens)
+    stringTokens = append([]string{wordpiecetokenizer.DefaultClassToken, specialMarker}, stringTokens...)
+    stringTokens = append(stringTokens, wordpiecetokenizer.DefaultSequenceSeparator)
+    embeddings := normalizeEmbeddings(r.Model.Forward(stringTokens))
+    return filterEmbeddings(embeddings, stringTokens)
+}
+
+func (r *DocumentScorer) EncodeDocument(text string) []ag.Node {
+    return r.encode(text, SpecialDocumentMarker)
+}
+
+func (r *DocumentScorer) EncodeQuery(text string) []ag.Node {
+    return r.encode(text, SpecialQueryMarker)
+}
+
+func (r *DocumentScorer) ScoreDocument(query []ag.Node, document []ag.Node) ag.Node {
+    var score ag.Node
+    score = ag.Scalar(0.0)
+    for i, q := range query {
+        if i < 2 || i == len(query)-1 {
+            continue // skip the special tokens ([CLS], marker, [SEP])
+        }
+        score = ag.Add(score, r.maxSimilarity(q, document))
+    }
+    return score
+}
+
+func (r *DocumentScorer) maxSimilarity(query ag.Node, document []ag.Node) ag.Node {
+    var max ag.Node
+    max = ag.Scalar(0.0)
+    for i, d := range document {
+        if i < 2 || i == len(document)-1 {
+            continue // skip the special tokens ([CLS], marker, [SEP])
+        }
+        sim := ag.Dot(query, d)
+        max = ag.Max(max, sim)
+    }
+    return max
+}
+
+func normalizeEmbeddings(embeddings []ag.Node) []ag.Node {
+    // Perform l2 normalization of each embedding
+    normalized := make([]ag.Node, len(embeddings))
+    for i, e := range embeddings {
+        normalized[i] = ag.DivScalar(e, ag.Sqrt(ag.ReduceSum(ag.Square(e))))
+    }
+    return normalized
+}
+
+func isPunctuation(token string) bool {
+    return token == "." || token == "," || token == "!" || token == "?" ||
+        token == ":" || token == ";" || token == "-" || token == "'" ||
+        token == "\"" || token == "(" || token == ")" || token == "[" ||
+        token == "]" || token == "{" || token == "}" || token == "*" ||
+        token == "&" || token == "%" || token == "$" || token == "#" ||
+        token == "@" || token == "=" || token == "+" ||
+        token == "_" || token == "~" || token == "/" || token == "\\" ||
+        token == "|" || token == "`" || token == "^" || token == ">" ||
+        token == "<"
+}
+
+func filterEmbeddings(embeddings []ag.Node, tokens []string) []ag.Node {
+    filtered := make([]ag.Node, 0, len(embeddings))
+    for i, e := range embeddings {
+        if isPunctuation(tokens[i]) {
+            continue
+        }
+        filtered = append(filtered, e)
+    }
+    return filtered
+}
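ScoreDocument above implements ColBERT's late interaction (MaxSim): each query-token embedding is compared against every document-token embedding and the per-query-token maxima are summed; since the embeddings are L2-normalized, the dot products act as cosine similarities. The same computation in dependency-free form, a sketch that assumes already-normalized vectors and leaves out the special-token handling:

// maxSim computes sum_i max_j <q_i, d_j> over L2-normalized token vectors.
func maxSim(query, document [][]float64) float64 {
    score := 0.0
    for _, q := range query {
        // Negative similarities are floored at zero, mirroring the
        // zero-initialized max in maxSimilarity above.
        best := 0.0
        for _, d := range document {
            dot := 0.0
            for k := range q {
                dot += q[k] * d[k]
            }
            if dot > best {
                best = dot
            }
        }
        score += best
    }
    return score
}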
diff --git a/pkg/tasks/scoring/colbert/ranking_test.go b/pkg/tasks/scoring/colbert/ranking_test.go
new file mode 100644
index 0000000..df72130
--- /dev/null
+++ b/pkg/tasks/scoring/colbert/ranking_test.go
@@ -0,0 +1,77 @@
+package colbert
+
+import (
+    "os"
+    "sort"
+    "testing"
+
+    "github.com/nlpodyssey/spago/ag"
+    "github.com/stretchr/testify/require"
+)
+
+func TestDocumentScorer_ScoreDocument(t *testing.T) {
+
+    tests := []struct {
+        name        string
+        query       string
+        documents   []string
+        wantRanking []int
+        wantScores  []float64
+    }{
+        {
+            name:        "test1",
+            query:       "hello world",
+            documents:   []string{"hello world"},
+            wantRanking: []int{0},
+            wantScores:  []float64{1.0},
+        },
+        {
+            name:  "test2",
+            query: "In which year was the first iPhone released?",
+            documents: []string{"The first Nokia phone was released in 1987.",
+                "The iPhone 3G was released in 2008.",
+                "The original iPhone was first sold in 2007."},
+            wantRanking: []int{2, 0, 1},
+        },
+    }
+    // Set the directory where the colbert model is stored here:
+    colbertModelDir := "testdata/colbert"
+    // if dir does not exist, skip test
+    if _, err := os.Stat(colbertModelDir); os.IsNotExist(err) {
+        t.Skip("Colbert model directory does not exist, skipping test")
+    }
+
+    scorer, err := LoadDocumentScorer(colbertModelDir)
+    require.NoError(t, err)
+
+    for _, tt := range tests {
+        t.Run(tt.name, func(t *testing.T) {
+            query := scorer.EncodeQuery(tt.query)
+            var scores []float64
+            for i, doc := range tt.documents {
+                document := scorer.EncodeDocument(doc)
+                score := scorer.ScoreDocument(query, document)
+                // Normalize the score by the number of non-special tokens of the query
+                // (this is not in the original paper btw, but it makes sense to me)
+                score = ag.Div(score, ag.Scalar(float64(len(query)-3)))
+                if tt.wantScores != nil {
+                    require.InDelta(t, tt.wantScores[i], score.Value().Data().F64()[0], 0.01)
+                }
+                scores = append(scores, score.Value().Data().F64()[0])
+            }
+            ranking := rank(scores)
+            require.Equal(t, tt.wantRanking, ranking)
+        })
+    }
+}
+
+func rank(scores []float64) []int {
+    var ranking []int
+    for i := range scores {
+        ranking = append(ranking, i)
+    }
+    sort.SliceStable(ranking, func(i, j int) bool {
+        return scores[ranking[i]] > scores[ranking[j]]
+    })
+    return ranking
+}
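Taken together, the scoring task added by this series is driven the same way the test drives it; a sketch with a hypothetical model directory and error handling reduced to panics:

package main

import (
    "fmt"

    "github.com/nlpodyssey/cybertron/pkg/tasks/scoring/colbert"
    "github.com/nlpodyssey/spago/ag"
)

func main() {
    // Hypothetical directory containing vocab.txt, repo/ and spago_model.bin.
    scorer, err := colbert.LoadDocumentScorer("testdata/colbert")
    if err != nil {
        panic(err)
    }
    query := scorer.EncodeQuery("In which year was the first iPhone released?")
    docs := []string{
        "The iPhone 3G was released in 2008.",
        "The original iPhone was first sold in 2007.",
    }
    for _, d := range docs {
        score := scorer.ScoreDocument(query, scorer.EncodeDocument(d))
        // Normalize by the number of scored query tokens, as the test does.
        norm := ag.Div(score, ag.Scalar(float64(len(query)-3)))
        fmt.Printf("%.3f  %s\n", norm.Value().Data().F64()[0], d)
    }
}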