Skip to content

Commit f36b6e0

Browse files
authored
Merge pull request #1203 from Leowbattle/safetensors
safetensors: Add decoder
2 parents 740031b + 7f21d1b commit f36b6e0

File tree

9 files changed

+267
-0
lines changed

9 files changed

+267
-0
lines changed

README.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -141,6 +141,7 @@ prores_frame,
141141
protobuf_widevine,
142142
pssh_playready,
143143
[rtmp](doc/formats.md#rtmp),
144+
safetensors,
144145
sll2_packet,
145146
sll_packet,
146147
[tap](doc/formats.md#tap),

doc/formats.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -113,6 +113,7 @@
113113
|`protobuf_widevine` |Widevine&nbsp;protobuf |<sub>`protobuf`</sub>|
114114
|`pssh_playready` |PlayReady&nbsp;PSSH |<sub></sub>|
115115
|[`rtmp`](#rtmp) |Real-Time&nbsp;Messaging&nbsp;Protocol |<sub>`amf0` `mpeg_asc`</sub>|
116+
|`safetensors` |SafeTensors |<sub>`json`</sub>|
116117
|`sll2_packet` |Linux&nbsp;cooked&nbsp;capture&nbsp;encapsulation&nbsp;v2 |<sub>`inet_packet`</sub>|
117118
|`sll_packet` |Linux&nbsp;cooked&nbsp;capture&nbsp;encapsulation |<sub>`inet_packet`</sub>|
118119
|[`tap`](#tap) |TAP&nbsp;tape&nbsp;format&nbsp;for&nbsp;ZX&nbsp;Spectrum&nbsp;computers |<sub></sub>|

format/all/all.fqtest

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -158,6 +158,7 @@ protobuf Protobuf
158158
protobuf_widevine Widevine protobuf
159159
pssh_playready PlayReady PSSH
160160
rtmp Real-Time Messaging Protocol
161+
safetensors SafeTensors
161162
sll2_packet Linux cooked capture encapsulation v2
162163
sll_packet Linux cooked capture encapsulation
163164
tap TAP tape format for ZX Spectrum computers

format/all/all.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,7 @@ import (
5454
_ "github.com/wader/fq/format/protobuf"
5555
_ "github.com/wader/fq/format/riff"
5656
_ "github.com/wader/fq/format/rtmp"
57+
_ "github.com/wader/fq/format/safetensors"
5758
_ "github.com/wader/fq/format/tap"
5859
_ "github.com/wader/fq/format/tar"
5960
_ "github.com/wader/fq/format/text"

format/format.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -165,6 +165,7 @@ var (
165165
ProtobufWidevine = &decode.Group{Name: "protobuf_widevine"}
166166
PSSH_Playready = &decode.Group{Name: "pssh_playready"}
167167
RTMP = &decode.Group{Name: "rtmp"}
168+
SAFETENSORS = &decode.Group{Name: "safetensors"}
168169
SLL_Packet = &decode.Group{Name: "sll_packet"}
169170
SLL2_Packet = &decode.Group{Name: "sll2_packet"}
170171
TAP = &decode.Group{Name: "tap"}

format/safetensors/safetensors.go

Lines changed: 162 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,162 @@
1+
package safetensors
2+
3+
// https://huggingface.co/docs/safetensors/en/index
4+
5+
import (
6+
"fmt"
7+
"math"
8+
"sort"
9+
10+
"github.com/wader/fq/format"
11+
"github.com/wader/fq/internal/mapstruct"
12+
"github.com/wader/fq/pkg/decode"
13+
"github.com/wader/fq/pkg/interp"
14+
"github.com/wader/fq/pkg/scalar"
15+
)
16+
17+
var jsonFormat decode.Group
18+
19+
type TensorInfo struct {
20+
Dtype string `mapstruct:"dtype"`
21+
Shape []int `mapstruct:"shape"`
22+
DataOffsets []int `mapstruct:"data_offsets"`
23+
}
24+
25+
type SafeTensorsHeader struct {
26+
Tensors map[string]TensorInfo `mapstruct:",remain"`
27+
Metadata map[string]any `mapstruct:"__metadata__"`
28+
}
29+
30+
func init() {
31+
interp.RegisterFormat(
32+
format.SAFETENSORS,
33+
&decode.Format{
34+
Description: "SafeTensors",
35+
DecodeFn: decodeSafeTensors,
36+
Dependencies: []decode.Dependency{
37+
{Groups: []*decode.Group{format.JSON}, Out: &jsonFormat},
38+
},
39+
})
40+
}
41+
42+
func parseHeader(dv *decode.Value) (*SafeTensorsHeader, error) {
43+
actualVal, ok := dv.V.(*scalar.Any)
44+
if !ok {
45+
return nil, fmt.Errorf("expected scalar.Any, got %T", dv.V)
46+
}
47+
48+
headerMap, ok := actualVal.Actual.(map[string]any)
49+
if !ok {
50+
return nil, fmt.Errorf("expected map[string]any, got %T", actualVal.Actual)
51+
}
52+
53+
var header SafeTensorsHeader
54+
if err := mapstruct.ToStruct(headerMap, &header); err != nil {
55+
return nil, fmt.Errorf("failed to parse header: %w", err)
56+
}
57+
58+
return &header, nil
59+
}
60+
61+
// https://en.wikipedia.org/wiki/Bfloat16_floating-point_format
62+
// https://en.wikipedia.org/wiki/Single-precision_floating-point_format
63+
// float32: 1 sign bit, 8 exponent bits, 23 fraction bits
64+
// bfloat16: 1 sign bit, 8 exponent bits, 7 fraction bits
65+
// To convert bfloat16 to float32, we can shift the bits to the left by 16.
66+
func bfloat16_bits_to_float(bits uint16) float32 {
67+
return math.Float32frombits(uint32(bits) << 16)
68+
}
69+
70+
var dataDecoders = map[string]func(d *decode.D){
71+
"F64": func(d *decode.D) { d.FieldF64("x") },
72+
"F32": func(d *decode.D) { d.FieldF32("x") },
73+
"F16": func(d *decode.D) { d.FieldF16("x") },
74+
"BF16": func(d *decode.D) {
75+
d.FieldFltFn("x", func(d *decode.D) float64 {
76+
return float64(bfloat16_bits_to_float(uint16(d.U16())))
77+
})
78+
},
79+
"I64": func(d *decode.D) { d.FieldS64("x") },
80+
"I32": func(d *decode.D) { d.FieldS32("x") },
81+
"I16": func(d *decode.D) { d.FieldS16("x") },
82+
"I8": func(d *decode.D) { d.FieldS8("x") },
83+
"U8": func(d *decode.D) { d.FieldU8("x") },
84+
"BOOL": func(d *decode.D) { d.FieldBool("x") },
85+
}
86+
87+
func decodeSafeTensors(d *decode.D) any {
88+
d.Endian = decode.LittleEndian
89+
90+
headerSize := d.FieldU64("header size")
91+
92+
var dv *decode.Value
93+
94+
d.LimitedFn(8*int64(headerSize), func(d *decode.D) {
95+
dv, _ = d.FieldFormat("header", &jsonFormat, nil)
96+
})
97+
98+
d.FieldStruct("tensors", func(d *decode.D) {
99+
header, err := parseHeader(dv)
100+
if err != nil {
101+
d.Fatalf("failed to parse header: %v", err)
102+
return
103+
}
104+
105+
// Get tensor names and sort them for deterministic output
106+
tensorNames := make([]string, 0, len(header.Tensors))
107+
for tensorName := range header.Tensors {
108+
tensorNames = append(tensorNames, tensorName)
109+
}
110+
sort.Strings(tensorNames)
111+
112+
for _, tensorName := range tensorNames {
113+
tensorInfo := header.Tensors[tensorName]
114+
115+
decoder, exists := dataDecoders[tensorInfo.Dtype]
116+
if !exists {
117+
d.Fatalf("unsupported dtype: %s", tensorInfo.Dtype)
118+
continue
119+
}
120+
121+
if len(tensorInfo.DataOffsets) < 2 {
122+
d.Fatalf("invalid data_offsets for tensor %s: %v", tensorName, tensorInfo.DataOffsets)
123+
continue
124+
}
125+
126+
begin := tensorInfo.DataOffsets[0]
127+
128+
d.FieldStruct(tensorName, func(d *decode.D) {
129+
d.FieldArray("shape", func(d *decode.D) {
130+
for _, s := range tensorInfo.Shape {
131+
d.FieldValueSint("dim", int64(s))
132+
}
133+
})
134+
135+
if len(tensorInfo.Shape) == 0 {
136+
return
137+
}
138+
139+
d.SeekAbs(8*(8+int64(headerSize)+int64(begin)), func(d *decode.D) {
140+
var reshape func(d *decode.D, i int)
141+
reshape = func(d *decode.D, i int) {
142+
d.FieldArray("data", func(d *decode.D) {
143+
if i == len(tensorInfo.Shape)-1 {
144+
for range tensorInfo.Shape[i] {
145+
decoder(d)
146+
}
147+
} else {
148+
for range tensorInfo.Shape[i] {
149+
reshape(d, i+1)
150+
}
151+
}
152+
})
153+
}
154+
reshape(d, 0)
155+
})
156+
157+
})
158+
}
159+
})
160+
161+
return nil
162+
}
Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
import torch
2+
from safetensors.torch import save_file
3+
4+
tensors = {
5+
"weight1": torch.reshape(torch.arange(12, dtype=torch.float32), (12,)),
6+
"weight2": torch.reshape(torch.arange(12, dtype=torch.int64), (3, 4)),
7+
"weight3": torch.reshape(torch.arange(12, dtype=torch.float16), (2, 2, 3)),
8+
"weight4": torch.reshape(torch.arange(12, dtype=torch.bfloat16), (4, 3)),
9+
}
10+
11+
save_file(tensors, "format/safetensors/testdata/test.safetensors")
Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,89 @@
1+
$ fq -d safetensors dv test.safetensors
2+
|00 01 02 03 04 05 06 07 08 09 0a 0b 0c 0d 0e 0f|0123456789abcdef|.{}: test.safetensors (safetensors) 0x0-0x1d0 (464)
3+
0x000|08 01 00 00 00 00 00 00 |........ | header size: 264 0x0-0x8 (8)
4+
|00 01 02 03 04 05 06 07 08 09 0a 0b 0c 0d 0e 0f|0123456789abcdef|
5+
0x000| 7b 22 77 65 69 67 68 74| {"weight| header: {} (json) 0x8-0x110 (264)
6+
0x010|32 22 3a 7b 22 64 74 79 70 65 22 3a 22 49 36 34|2":{"dtype":"I64|
7+
* |until 0x10f.7 (264) | |
8+
| | | tensors{}: 0x110-0x1d0 (192)
9+
| | | weight1{}: 0x110-0x1a0 (144)
10+
| | | shape[0:1]: 0x110-0x110 (0)
11+
| | | [0]: 12 dim
12+
| | | data[0:12]: 0x170-0x1a0 (48)
13+
0x170|00 00 00 00 |.... | [0]: 0 x 0x170-0x174 (4)
14+
0x170| 00 00 80 3f | ...? | [1]: 1 x 0x174-0x178 (4)
15+
0x170| 00 00 00 40 | ...@ | [2]: 2 x 0x178-0x17c (4)
16+
0x170| 00 00 40 40| ..@@| [3]: 3 x 0x17c-0x180 (4)
17+
0x180|00 00 80 40 |...@ | [4]: 4 x 0x180-0x184 (4)
18+
0x180| 00 00 a0 40 | ...@ | [5]: 5 x 0x184-0x188 (4)
19+
0x180| 00 00 c0 40 | ...@ | [6]: 6 x 0x188-0x18c (4)
20+
0x180| 00 00 e0 40| ...@| [7]: 7 x 0x18c-0x190 (4)
21+
0x190|00 00 00 41 |...A | [8]: 8 x 0x190-0x194 (4)
22+
0x190| 00 00 10 41 | ...A | [9]: 9 x 0x194-0x198 (4)
23+
0x190| 00 00 20 41 | .. A | [10]: 10 x 0x198-0x19c (4)
24+
0x190| 00 00 30 41| ..0A| [11]: 11 x 0x19c-0x1a0 (4)
25+
| | | weight2{}: 0x110-0x170 (96)
26+
| | | shape[0:2]: 0x110-0x110 (0)
27+
| | | [0]: 3 dim
28+
| | | [1]: 4 dim
29+
| | | data[0:3]: 0x110-0x170 (96)
30+
| | | [0][0:4]: data 0x110-0x130 (32)
31+
0x110|00 00 00 00 00 00 00 00 |........ | [0]: 0 x 0x110-0x118 (8)
32+
0x110| 01 00 00 00 00 00 00 00| ........| [1]: 1 x 0x118-0x120 (8)
33+
0x120|02 00 00 00 00 00 00 00 |........ | [2]: 2 x 0x120-0x128 (8)
34+
0x120| 03 00 00 00 00 00 00 00| ........| [3]: 3 x 0x128-0x130 (8)
35+
| | | [1][0:4]: data 0x130-0x150 (32)
36+
0x130|04 00 00 00 00 00 00 00 |........ | [0]: 4 x 0x130-0x138 (8)
37+
0x130| 05 00 00 00 00 00 00 00| ........| [1]: 5 x 0x138-0x140 (8)
38+
0x140|06 00 00 00 00 00 00 00 |........ | [2]: 6 x 0x140-0x148 (8)
39+
0x140| 07 00 00 00 00 00 00 00| ........| [3]: 7 x 0x148-0x150 (8)
40+
| | | [2][0:4]: data 0x150-0x170 (32)
41+
0x150|08 00 00 00 00 00 00 00 |........ | [0]: 8 x 0x150-0x158 (8)
42+
0x150| 09 00 00 00 00 00 00 00| ........| [1]: 9 x 0x158-0x160 (8)
43+
0x160|0a 00 00 00 00 00 00 00 |........ | [2]: 10 x 0x160-0x168 (8)
44+
0x160| 0b 00 00 00 00 00 00 00| ........| [3]: 11 x 0x168-0x170 (8)
45+
| | | weight3{}: 0x110-0x1d0 (192)
46+
| | | shape[0:3]: 0x110-0x110 (0)
47+
| | | [0]: 2 dim
48+
| | | [1]: 2 dim
49+
| | | [2]: 3 dim
50+
| | | data[0:2]: 0x1b8-0x1d0 (24)
51+
| | | [0][0:2]: data 0x1b8-0x1c4 (12)
52+
| | | [0][0:3]: data 0x1b8-0x1be (6)
53+
0x1b0| 00 00 | .. | [0]: 0 x 0x1b8-0x1ba (2)
54+
0x1b0| 00 3c | .< | [1]: 1 x 0x1ba-0x1bc (2)
55+
0x1b0| 00 40 | .@ | [2]: 2 x 0x1bc-0x1be (2)
56+
| | | [1][0:3]: data 0x1be-0x1c4 (6)
57+
0x1b0| 00 42| .B| [0]: 3 x 0x1be-0x1c0 (2)
58+
0x1c0|00 44 |.D | [1]: 4 x 0x1c0-0x1c2 (2)
59+
0x1c0| 00 45 | .E | [2]: 5 x 0x1c2-0x1c4 (2)
60+
| | | [1][0:2]: data 0x1c4-0x1d0 (12)
61+
| | | [0][0:3]: data 0x1c4-0x1ca (6)
62+
0x1c0| 00 46 | .F | [0]: 6 x 0x1c4-0x1c6 (2)
63+
0x1c0| 00 47 | .G | [1]: 7 x 0x1c6-0x1c8 (2)
64+
0x1c0| 00 48 | .H | [2]: 8 x 0x1c8-0x1ca (2)
65+
| | | [1][0:3]: data 0x1ca-0x1d0 (6)
66+
0x1c0| 80 48 | .H | [0]: 9 x 0x1ca-0x1cc (2)
67+
0x1c0| 00 49 | .I | [1]: 10 x 0x1cc-0x1ce (2)
68+
0x1c0| 80 49| .I| [2]: 11 x 0x1ce-0x1d0 (2)
69+
| | | weight4{}: 0x110-0x1b8 (168)
70+
| | | shape[0:2]: 0x110-0x110 (0)
71+
| | | [0]: 4 dim
72+
| | | [1]: 3 dim
73+
| | | data[0:4]: 0x1a0-0x1b8 (24)
74+
| | | [0][0:3]: data 0x1a0-0x1a6 (6)
75+
0x1a0|00 00 |.. | [0]: 0 x 0x1a0-0x1a2 (2)
76+
0x1a0| 80 3f | .? | [1]: 1 x 0x1a2-0x1a4 (2)
77+
0x1a0| 00 40 | .@ | [2]: 2 x 0x1a4-0x1a6 (2)
78+
| | | [1][0:3]: data 0x1a6-0x1ac (6)
79+
0x1a0| 40 40 | @@ | [0]: 3 x 0x1a6-0x1a8 (2)
80+
0x1a0| 80 40 | .@ | [1]: 4 x 0x1a8-0x1aa (2)
81+
0x1a0| a0 40 | .@ | [2]: 5 x 0x1aa-0x1ac (2)
82+
| | | [2][0:3]: data 0x1ac-0x1b2 (6)
83+
0x1a0| c0 40 | .@ | [0]: 6 x 0x1ac-0x1ae (2)
84+
0x1a0| e0 40| .@| [1]: 7 x 0x1ae-0x1b0 (2)
85+
0x1b0|00 41 |.A | [2]: 8 x 0x1b0-0x1b2 (2)
86+
| | | [3][0:3]: data 0x1b2-0x1b8 (6)
87+
0x1b0| 10 41 | .A | [0]: 9 x 0x1b2-0x1b4 (2)
88+
0x1b0| 20 41 | A | [1]: 10 x 0x1b4-0x1b6 (2)
89+
0x1b0| 30 41 | 0A | [2]: 11 x 0x1b6-0x1b8 (2)
464 Bytes
Binary file not shown.

0 commit comments

Comments
 (0)