Skip to content

Commit 4da856d

Browse files
committed
Make Flair conversion work with more models
found on HuggingFace hub
1 parent 64b4552 commit 4da856d

File tree

12 files changed

+150
-28
lines changed

12 files changed

+150
-28
lines changed
Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
package builtins
2+
3+
import (
4+
"fmt"
5+
6+
"github.com/nlpodyssey/cybertron/pkg/converter/flair/conversion"
7+
)
8+
9+
// Getattr is a stateless callable standing in for Python's built-in getattr
// function; its behavior lives entirely in the Call method.
type Getattr struct{}
10+
11+
func (Getattr) Call(args ...any) (any, error) {
12+
if len(args) != 2 && len(args) != 3 {
13+
return nil, fmt.Errorf("builtins.getattr: want 2 or 3 args, got %d: %#v", len(args), args)
14+
}
15+
16+
object, ok := args[0].(conversion.PyAttributeGettable)
17+
if !ok {
18+
return nil, fmt.Errorf("builtins.getattr: 1st arg (object) does not satisfy PyAttributeGettable interface: %T", args[0])
19+
}
20+
21+
name, ok := args[1].(string)
22+
if !ok {
23+
return nil, fmt.Errorf("builtins.getattr: want 2nd arg (name) to be string, got %T: %#v", args[1], args[1])
24+
}
25+
26+
value, exists, err := object.PyGetAttribute(name)
27+
if err != nil {
28+
return nil, fmt.Errorf("builtins.getattr(%#v): PyGetAttribute failed: %w", args, err)
29+
}
30+
31+
if len(args) == 3 && !exists {
32+
return args[2], nil
33+
}
34+
return value, nil
35+
}
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
package builtins
2+
3+
// Int is an empty placeholder type for Python's built-in int class.
// It declares no fields; no methods are defined for it alongside this
// declaration.
// NOTE(review): presumably it only needs to be resolvable by name while
// unpickling (e.g. as a defaultdict factory) — confirm against callers.
type Int struct{}
Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
package collections
2+
3+
import (
4+
"fmt"
5+
6+
"github.com/nlpodyssey/gopickle/types"
7+
)
8+
9+
// DefaultDictClass is a stateless callable standing in for Python's
// collections.defaultdict class; calling it yields a *DefaultDict.
type DefaultDictClass struct{}
10+
11+
type DefaultDict struct {
12+
*types.Dict
13+
DefaultFactory any
14+
}
15+
16+
func (DefaultDictClass) Call(args ...any) (any, error) {
17+
if len(args) != 1 {
18+
return nil, fmt.Errorf("DefaultDictClass: want 1 argument, got %d: %#v", len(args), args)
19+
}
20+
21+
return &DefaultDict{
22+
Dict: types.NewDict(),
23+
DefaultFactory: args[0],
24+
}, nil
25+
}

pkg/converter/flair/conversion/flair/dictionary.go

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,15 +8,17 @@ import (
88
"fmt"
99

1010
"github.com/nlpodyssey/cybertron/pkg/converter/flair/conversion"
11+
"github.com/nlpodyssey/cybertron/pkg/converter/flair/conversion/collections"
1112
"github.com/nlpodyssey/gopickle/types"
1213
)
1314

1415
type DictionaryClass struct{}
1516

1617
type Dictionary struct {
17-
Item2Idx map[string]int
18-
Idx2Item []string
19-
MultiLabel bool
18+
Item2Idx map[string]int
19+
Idx2Item []string
20+
Item2IdxNotEncoded *collections.DefaultDict
21+
MultiLabel bool
2022
}
2123

2224
func (DictionaryClass) PyNew(args ...any) (any, error) {
@@ -40,6 +42,8 @@ func (d *Dictionary) PyDictSet(k, v any) (err error) {
4042
if err == nil {
4143
err = d.setIdx2Item(l)
4244
}
45+
case "item2idx_not_encoded":
46+
err = conversion.AssignAssertedType(v, &d.Item2IdxNotEncoded)
4347
case "multi_label":
4448
err = conversion.AssignAssertedType(v, &d.MultiLabel)
4549
default:

pkg/converter/flair/conversion/flair/flairembeddings.go

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,9 @@ type FlairEmbeddings struct {
2424
CharsPerChunk int
2525
embeddingLength int
2626
PretrainedModelArchiveMap map[string]string
27+
InstanceParameters *types.Dict
28+
WithWhitespace bool // Default: true
29+
TokenizedLM bool // Default: true
2730
LM *LanguageModel
2831
}
2932

@@ -66,6 +69,12 @@ func (f *FlairEmbeddings) PyDictSet(k, v any) (err error) {
6669
if err == nil {
6770
err = conversion.AssignDictToMap(d, &f.PretrainedModelArchiveMap)
6871
}
72+
case "instance_parameters":
73+
err = conversion.AssignAssertedType(v, &f.InstanceParameters)
74+
case "with_whitespace":
75+
err = conversion.AssignAssertedType(v, &f.WithWhitespace)
76+
case "tokenized_lm":
77+
err = conversion.AssignAssertedType(v, &f.TokenizedLM)
6978
case "detach", "cache": // TODO
7079
default:
7180
err = fmt.Errorf("unexpected key with value %#v", v)

pkg/converter/flair/conversion/flair/languagemodel.go

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -16,13 +16,14 @@ type LanguageModelClass struct{}
1616

1717
type LanguageModel struct {
1818
torch.Module
19-
Dictionary *Dictionary
20-
IsForwardLm bool
21-
Dropout float64
22-
HiddenSize int
23-
EmbeddingSize int
24-
NLayers int
25-
NOut int
19+
Dictionary *Dictionary
20+
IsForwardLm bool
21+
Dropout float64
22+
HiddenSize int
23+
EmbeddingSize int
24+
NLayers int
25+
NOut int
26+
DocumentDelimiter string
2627

2728
Encoder *torch.SparseEmbedding
2829
Decoder *torch.Linear
@@ -69,6 +70,8 @@ func (l *LanguageModel) PyDictSet(k, v any) (err error) {
6970
if v != nil {
7071
err = fmt.Errorf("only nil is supported, got %T: %#v", v, v)
7172
}
73+
case "document_delimiter":
74+
err = conversion.AssignAssertedType(v, &l.DocumentDelimiter)
7275
default:
7376
err = fmt.Errorf("unexpected key with value %#v", v)
7477
}

pkg/converter/flair/conversion/flair/unpickling.go

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,18 +8,24 @@ import (
88
"fmt"
99
"io"
1010

11+
"github.com/nlpodyssey/cybertron/pkg/converter/flair/conversion/builtins"
12+
"github.com/nlpodyssey/cybertron/pkg/converter/flair/conversion/collections"
1113
"github.com/nlpodyssey/cybertron/pkg/converter/flair/conversion/gensim"
1214
"github.com/nlpodyssey/cybertron/pkg/converter/flair/conversion/numpy"
1315
"github.com/nlpodyssey/cybertron/pkg/converter/flair/conversion/torch"
1416
"github.com/nlpodyssey/gopickle/pickle"
1517
)
1618

1719
var allClasses = map[string]any{
20+
"builtins.getattr": builtins.Getattr{},
21+
"builtins.int": builtins.Int{},
22+
"collections.defaultdict": collections.DefaultDictClass{},
1823
"flair.data.Dictionary": DictionaryClass{},
1924
"flair.embeddings.FlairEmbeddings": FlairEmbeddingsClass{},
2025
"flair.embeddings.StackedEmbeddings": StackedEmbeddingsClass{},
21-
"flair.embeddings.token.StackedEmbeddings": StackedEmbeddingsClass{},
2226
"flair.embeddings.WordEmbeddings": WordEmbeddingsClass{},
27+
"flair.embeddings.token.FlairEmbeddings": FlairEmbeddingsClass{},
28+
"flair.embeddings.token.StackedEmbeddings": StackedEmbeddingsClass{},
2329
"flair.embeddings.token.WordEmbeddings": WordEmbeddingsClass{},
2430
"flair.models.language_model.LanguageModel": LanguageModelClass{},
2531
"gensim.models.keyedvectors.Vocab": gensim.VocabClass{},
@@ -28,11 +34,11 @@ var allClasses = map[string]any{
2834
"numpy.dtype": numpy.DTypeClass{},
2935
"numpy.ndarray": numpy.NDArrayClass{},
3036
"torch._utils._rebuild_parameter": torch.RebuildParameter{},
37+
"torch.backends.cudnn.rnn.Unserializable": torch.RNNUnserializableClass{},
3138
"torch.nn.modules.dropout.Dropout": torch.DropoutClass{},
3239
"torch.nn.modules.linear.Linear": torch.LinearClass{},
3340
"torch.nn.modules.rnn.LSTM": torch.LSTMClass{},
3441
"torch.nn.modules.sparse.Embedding": torch.SparseEmbeddingClass{},
35-
"torch.backends.cudnn.rnn.Unserializable": torch.RNNUnserializableClass{},
3642
}
3743

3844
func newUnpickler(r io.Reader) pickle.Unpickler {

pkg/converter/flair/conversion/flair/wordembeddings.go

Lines changed: 23 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -10,19 +10,21 @@ import (
1010
"github.com/nlpodyssey/cybertron/pkg/converter/flair/conversion"
1111
"github.com/nlpodyssey/cybertron/pkg/converter/flair/conversion/gensim"
1212
"github.com/nlpodyssey/cybertron/pkg/converter/flair/conversion/torch"
13+
"github.com/nlpodyssey/gopickle/types"
1314
"github.com/nlpodyssey/spago/mat"
1415
)
1516

1617
type WordEmbeddingsClass struct{}
1718

1819
type WordEmbeddings struct {
1920
TokenEmbeddingsModule
20-
Embeddings string
21-
Name string
22-
StaticEmbeddings bool
23-
Embedding *torch.Embedding
24-
Vocab map[string]int
25-
embeddingLength int
21+
Embeddings string
22+
Name string
23+
StaticEmbeddings bool
24+
Embedding *torch.Embedding
25+
Vocab map[string]int
26+
InstanceParameters *types.Dict
27+
embeddingLength int
2628
}
2729

2830
var _ TokenEmbeddings = &WordEmbeddings{}
@@ -48,6 +50,10 @@ func (w *WordEmbeddings) PyDictSet(k, v any) (err error) {
4850
switch k {
4951
case "embeddings":
5052
err = conversion.AssignAssertedType(v, &w.Embeddings)
53+
case "get_cached_vec":
54+
// present on older models, can be ignored
55+
case "instance_parameters":
56+
err = conversion.AssignAssertedType(v, &w.InstanceParameters)
5157
case "name":
5258
err = conversion.AssignAssertedType(v, &w.Name)
5359
case "static_embeddings":
@@ -93,6 +99,17 @@ func (w *WordEmbeddings) setPrecomputedWordEmbeddings(kv *gensim.KeyedVectors) e
9399
return nil
94100
}
95101

102+
func (w *WordEmbeddings) PyGetAttribute(name string) (value any, exists bool, err error) {
103+
switch name {
104+
case "get_cached_vec":
105+
// this ignores the get_cached_vec method when loading older versions
106+
// it is needed for compatibility reasons
107+
return nil, true, nil
108+
default:
109+
return nil, false, fmt.Errorf("WordEmbeddings: unexpected __getattribute__(%q)", name)
110+
}
111+
}
112+
96113
func (w *WordEmbeddings) LoadStateDictEntry(string, any) error {
97114
return fmt.Errorf("WordEmbeddings: loading from state dict entry not implemented")
98115
}

pkg/converter/flair/conversion/torch/module.go

Lines changed: 12 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -13,15 +13,16 @@ import (
1313
)
1414

1515
type Module struct {
16-
Training bool
17-
Parameters *types.OrderedDict
18-
Buffers *types.OrderedDict
19-
BackwardHooks *types.OrderedDict
20-
ForwardHooks *types.OrderedDict
21-
ForwardPreHooks *types.OrderedDict
22-
StateDictHooks *types.OrderedDict
23-
LoadStateDictPreHooks *types.OrderedDict
24-
Modules *types.OrderedDict
16+
Training bool
17+
Parameters *types.OrderedDict
18+
Buffers *types.OrderedDict
19+
BackwardHooks *types.OrderedDict
20+
ForwardHooks *types.OrderedDict
21+
ForwardPreHooks *types.OrderedDict
22+
StateDictHooks *types.OrderedDict
23+
LoadStateDictPreHooks *types.OrderedDict
24+
Modules *types.OrderedDict
25+
NonPersistentBuffersSet *types.Set
2526
}
2627

2728
func GetSubModule[T any](mod Module, name string) (v T, err error) {
@@ -70,6 +71,8 @@ func (m *Module) PyDictSet(k, v any) (err error) {
7071
err = conversion.AssignAssertedType(v, &m.ForwardPreHooks)
7172
case "_state_dict_hooks":
7273
err = conversion.AssignAssertedType(v, &m.StateDictHooks)
74+
case "_non_persistent_buffers_set":
75+
err = conversion.AssignAssertedType(v, &m.NonPersistentBuffersSet)
7376
case "_load_state_dict_pre_hooks":
7477
err = conversion.AssignAssertedType(v, &m.LoadStateDictPreHooks)
7578
case "_modules":

pkg/converter/flair/conversion/torch/rnn.go

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@ type RNNBase struct {
4242
Bidirectional bool
4343
ProjSize int
4444
FlatWeightsNames []string
45+
FlatWeights []*Parameter
4546
AllWeights [][]string
4647
Parameters map[string]*Parameter
4748
}
@@ -202,6 +203,18 @@ func (r *RNNBase) PyDictSet(k, v any) (err error) {
202203
err = conversion.AssignAssertedType(v, &r.Bidirectional)
203204
case "_all_weights":
204205
err = r.convertAndSetAllWeights(v)
206+
case "_flat_weights":
207+
var l *types.List
208+
err = conversion.AssignAssertedType(v, &l)
209+
if err == nil {
210+
err = conversion.AssignListToSlice(l, &r.FlatWeights)
211+
}
212+
case "_flat_weights_names":
213+
var l *types.List
214+
err = conversion.AssignAssertedType(v, &l)
215+
if err == nil {
216+
err = conversion.AssignListToSlice(l, &r.FlatWeightsNames)
217+
}
205218
case "_data_ptrs", "_param_buf_size":
206219
default:
207220
err = fmt.Errorf("unexpected key with value %#v", v)

0 commit comments

Comments
 (0)