Commit 2297429
Inference Gym: Add 64 bit support to all targets.
This is done via adding a `dtype` argument to the initializer which is then used to cast the inputs (if any) as well as the dtype for de novo arrays (like tf.zeros). This is a departure from the usual TensorFlow Probability style where the dtype is inferred from the inputs because:
- Inference Gym targets often have no inputs
- Inference Gym is not designed to do deferred array materialization like TFP is
We also depart from the typical JAX style of using the maximum precision available because we want to enable testing 32 and 64 bit implementations in the same process.
NOTE: At this time, ground truth remains always numpy arrays and always 64 bit.
Fixes #1993
PiperOrigin-RevId: 7362313101 parent b688014 commit 2297429
File tree
35 files changed
+597
-269
lines changed- spinoffs/inference_gym/inference_gym
- internal
- targets
- tensorflow_probability/python/math
35 files changed
+597
-269
lines changedLines changed: 13 additions & 1 deletion
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
23 | 23 | | |
24 | 24 | | |
25 | 25 | | |
| 26 | + | |
26 | 27 | | |
27 | 28 | | |
28 | 29 | | |
| |||
279 | 280 | | |
280 | 281 | | |
281 | 282 | | |
| 283 | + | |
282 | 284 | | |
283 | 285 | | |
284 | 286 | | |
| |||
295 | 297 | | |
296 | 298 | | |
297 | 299 | | |
| 300 | + | |
298 | 301 | | |
299 | 302 | | |
300 | 303 | | |
| |||
331 | 334 | | |
332 | 335 | | |
333 | 336 | | |
| 337 | + | |
| 338 | + | |
| 339 | + | |
| 340 | + | |
| 341 | + | |
| 342 | + | |
| 343 | + | |
334 | 344 | | |
335 | 345 | | |
336 | | - | |
| 346 | + | |
337 | 347 | | |
338 | 348 | | |
339 | 349 | | |
340 | 350 | | |
341 | 351 | | |
342 | 352 | | |
343 | 353 | | |
| 354 | + | |
| 355 | + | |
344 | 356 | | |
345 | 357 | | |
346 | 358 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
63 | 63 | | |
64 | 64 | | |
65 | 65 | | |
| 66 | + | |
66 | 67 | | |
67 | 68 | | |
68 | 69 | | |
| |||
73 | 74 | | |
74 | 75 | | |
75 | 76 | | |
| 77 | + | |
76 | 78 | | |
77 | 79 | | |
78 | 80 | | |
| |||
146 | 148 | | |
147 | 149 | | |
148 | 150 | | |
| 151 | + | |
149 | 152 | | |
150 | 153 | | |
151 | 154 | | |
| |||
167 | 170 | | |
168 | 171 | | |
169 | 172 | | |
| 173 | + | |
170 | 174 | | |
171 | 175 | | |
172 | 176 | | |
| |||
220 | 224 | | |
221 | 225 | | |
222 | 226 | | |
| 227 | + | |
223 | 228 | | |
224 | 229 | | |
225 | 230 | | |
| |||
275 | 280 | | |
276 | 281 | | |
277 | 282 | | |
| 283 | + | |
278 | 284 | | |
279 | 285 | | |
280 | 286 | | |
| |||
319 | 325 | | |
320 | 326 | | |
321 | 327 | | |
| 328 | + | |
322 | 329 | | |
323 | 330 | | |
324 | 331 | | |
| |||
340 | 347 | | |
341 | 348 | | |
342 | 349 | | |
| 350 | + | |
343 | 351 | | |
344 | 352 | | |
345 | 353 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
61 | 61 | | |
62 | 62 | | |
63 | 63 | | |
| 64 | + | |
64 | 65 | | |
65 | 66 | | |
66 | 67 | | |
67 | | - | |
| 68 | + | |
68 | 69 | | |
69 | 70 | | |
70 | 71 | | |
71 | 72 | | |
72 | 73 | | |
73 | 74 | | |
| 75 | + | |
74 | 76 | | |
75 | 77 | | |
76 | 78 | | |
| |||
87 | 89 | | |
88 | 90 | | |
89 | 91 | | |
90 | | - | |
| 92 | + | |
91 | 93 | | |
92 | | - | |
| 94 | + | |
| 95 | + | |
| 96 | + | |
93 | 97 | | |
94 | 98 | | |
95 | 99 | | |
96 | 100 | | |
97 | 101 | | |
98 | 102 | | |
99 | | - | |
| 103 | + | |
| 104 | + | |
| 105 | + | |
100 | 106 | | |
101 | 107 | | |
102 | 108 | | |
| |||
115 | 121 | | |
116 | 122 | | |
117 | 123 | | |
| 124 | + | |
118 | 125 | | |
119 | 126 | | |
120 | 127 | | |
| |||
Lines changed: 10 additions & 2 deletions
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
14 | 14 | | |
15 | 15 | | |
16 | 16 | | |
| 17 | + | |
| 18 | + | |
| 19 | + | |
17 | 20 | | |
18 | 21 | | |
19 | 22 | | |
| |||
22 | 25 | | |
23 | 26 | | |
24 | 27 | | |
25 | | - | |
| 28 | + | |
| 29 | + | |
26 | 30 | | |
27 | 31 | | |
28 | 32 | | |
29 | 33 | | |
| 34 | + | |
| 35 | + | |
| 36 | + | |
30 | 37 | | |
31 | | - | |
| 38 | + | |
32 | 39 | | |
33 | 40 | | |
34 | 41 | | |
35 | 42 | | |
36 | 43 | | |
| 44 | + | |
37 | 45 | | |
38 | 46 | | |
39 | 47 | | |
| |||
Lines changed: 43 additions & 16 deletions
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
59 | 59 | | |
60 | 60 | | |
61 | 61 | | |
62 | | - | |
| 62 | + | |
| 63 | + | |
| 64 | + | |
63 | 65 | | |
64 | | - | |
65 | | - | |
66 | | - | |
| 66 | + | |
| 67 | + | |
| 68 | + | |
| 69 | + | |
| 70 | + | |
67 | 71 | | |
68 | 72 | | |
69 | 73 | | |
70 | | - | |
| 74 | + | |
| 75 | + | |
71 | 76 | | |
72 | 77 | | |
73 | | - | |
74 | | - | |
| 78 | + | |
| 79 | + | |
75 | 80 | | |
76 | 81 | | |
77 | 82 | | |
78 | 83 | | |
79 | 84 | | |
| 85 | + | |
80 | 86 | | |
81 | 87 | | |
82 | 88 | | |
| |||
86 | 92 | | |
87 | 93 | | |
88 | 94 | | |
89 | | - | |
| 95 | + | |
| 96 | + | |
| 97 | + | |
| 98 | + | |
| 99 | + | |
| 100 | + | |
90 | 101 | | |
91 | 102 | | |
92 | 103 | | |
| |||
117 | 128 | | |
118 | 129 | | |
119 | 130 | | |
| 131 | + | |
120 | 132 | | |
121 | 133 | | |
122 | 134 | | |
| |||
130 | 142 | | |
131 | 143 | | |
132 | 144 | | |
| 145 | + | |
133 | 146 | | |
134 | 147 | | |
135 | 148 | | |
136 | 149 | | |
137 | 150 | | |
| 151 | + | |
| 152 | + | |
| 153 | + | |
| 154 | + | |
| 155 | + | |
| 156 | + | |
138 | 157 | | |
139 | 158 | | |
140 | 159 | | |
| |||
150 | 169 | | |
151 | 170 | | |
152 | 171 | | |
153 | | - | |
| 172 | + | |
| 173 | + | |
154 | 174 | | |
155 | 175 | | |
156 | 176 | | |
| |||
164 | 184 | | |
165 | 185 | | |
166 | 186 | | |
| 187 | + | |
167 | 188 | | |
168 | 189 | | |
169 | 190 | | |
| |||
193 | 214 | | |
194 | 215 | | |
195 | 216 | | |
196 | | - | |
| 217 | + | |
197 | 218 | | |
198 | 219 | | |
199 | 220 | | |
200 | 221 | | |
201 | 222 | | |
| 223 | + | |
202 | 224 | | |
203 | 225 | | |
204 | 226 | | |
| |||
226 | 248 | | |
227 | 249 | | |
228 | 250 | | |
| 251 | + | |
229 | 252 | | |
230 | 253 | | |
231 | 254 | | |
| |||
238 | 261 | | |
239 | 262 | | |
240 | 263 | | |
| 264 | + | |
241 | 265 | | |
242 | 266 | | |
243 | 267 | | |
| |||
247 | 271 | | |
248 | 272 | | |
249 | 273 | | |
250 | | - | |
| 274 | + | |
| 275 | + | |
251 | 276 | | |
252 | 277 | | |
253 | 278 | | |
254 | 279 | | |
255 | | - | |
| 280 | + | |
| 281 | + | |
256 | 282 | | |
257 | 283 | | |
258 | 284 | | |
| |||
266 | 292 | | |
267 | 293 | | |
268 | 294 | | |
269 | | - | |
270 | | - | |
271 | | - | |
| 295 | + | |
| 296 | + | |
| 297 | + | |
272 | 298 | | |
273 | 299 | | |
274 | 300 | | |
| |||
300 | 326 | | |
301 | 327 | | |
302 | 328 | | |
303 | | - | |
| 329 | + | |
304 | 330 | | |
305 | 331 | | |
306 | 332 | | |
307 | 333 | | |
308 | 334 | | |
309 | 335 | | |
310 | 336 | | |
| 337 | + | |
311 | 338 | | |
0 commit comments