Skip to content

Commit 2bfda06

Browse files
reczkokiwoplaza
andauthored
fix: Function call member access type inference (#1062)
* Fix function member access type inference * Add WGSL void type and error handling * Move $repr symbol from type import to value import * Refactor internal symbol and dual implementation * Set Internal Implementation to jsImpl * Review fixes --------- Co-authored-by: Iwo Plaza <[email protected]>
1 parent a9548fc commit 2bfda06

File tree

13 files changed

+194
-52
lines changed

13 files changed

+194
-52
lines changed

packages/typegpu/src/core/function/tgpuFn.ts

+2-10
Original file line numberDiff line numberDiff line change
@@ -221,11 +221,7 @@ function createFn<
221221
}),
222222
);
223223

224-
Object.defineProperty(call, $internal, {
225-
value: {
226-
implementation,
227-
},
228-
});
224+
call[$internal].implementation = implementation;
229225

230226
const fn = Object.assign(call, fnBase as This) as unknown as TgpuFn<
231227
Args,
@@ -310,11 +306,7 @@ function createBoundFunction<
310306
},
311307
});
312308

313-
Object.defineProperty(fn, $internal, {
314-
value: {
315-
implementation: innerFn[$internal].implementation,
316-
},
317-
});
309+
fn[$internal].implementation = innerFn[$internal].implementation;
318310

319311
return fn;
320312
}

packages/typegpu/src/core/resolve/resolveData.ts

+4
Original file line numberDiff line numberDiff line change
@@ -257,5 +257,9 @@ export function resolveData(ctx: ResolutionCtx, data: AnyData): string {
257257
throw new Error('Abstract types have no concrete representation in WGSL');
258258
}
259259

260+
if (data.type === 'void') {
261+
throw new Error('Void has no representation in WGSL');
262+
}
263+
260264
assertExhaustive(data, 'resolveData');
261265
}

packages/typegpu/src/data/dataTypes.ts

+9-1
Original file line numberDiff line numberDiff line change
@@ -6,11 +6,19 @@ import type {
66
InferPartialRecord,
77
InferRecord,
88
} from '../shared/repr.js';
9+
import { $internal } from '../shared/symbols.js';
910
import type { Prettify } from '../shared/utilityTypes.js';
1011
import { vertexFormats } from '../shared/vertexFormat.js';
1112
import type { PackedData } from './vertexFormatData.js';
1213
import * as wgsl from './wgslTypes.js';
1314

15+
export type TgpuDualFn<TImpl extends (...args: unknown[]) => unknown> =
16+
TImpl & {
17+
[$internal]: {
18+
implementation: TImpl | string;
19+
};
20+
};
21+
1422
/**
1523
* Array schema constructed via `d.disarrayOf` function.
1624
*
@@ -139,5 +147,5 @@ export function isData(value: unknown): value is AnyData {
139147
export type AnyData = wgsl.AnyWgslData | AnyLooseData;
140148
export type AnyConcreteData = Exclude<
141149
AnyData,
142-
wgsl.AbstractInt | wgsl.AbstractFloat
150+
wgsl.AbstractInt | wgsl.AbstractFloat | wgsl.Void
143151
>;

packages/typegpu/src/data/matrix.ts

+2-1
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ import { $repr } from '../shared/repr.js';
33
import type { SelfResolvable } from '../types';
44
import { vec2f, vec3f, vec4f } from './vector';
55
import type {
6+
AnyWgslData,
67
Mat2x2f,
78
Mat3x3f,
89
Mat4x4f,
@@ -53,7 +54,7 @@ function createMatSchema<
5354
[$repr]: undefined as unknown as ValueType,
5455
type: options.type,
5556
label: options.type,
56-
};
57+
} as unknown as AnyWgslData;
5758

5859
const construct = createDualImpl(
5960
// CPU implementation

packages/typegpu/src/data/numeric.ts

+4-4
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ const u32Cast = createDualImpl(
6060
*/
6161
export const u32: U32 = Object.assign(u32Cast, {
6262
type: 'u32',
63-
}) as U32;
63+
}) as unknown as U32;
6464

6565
const i32Cast = createDualImpl(
6666
// CPU implementation
@@ -99,7 +99,7 @@ const i32Cast = createDualImpl(
9999
*/
100100
export const i32: I32 = Object.assign(i32Cast, {
101101
type: 'i32',
102-
}) as I32;
102+
}) as unknown as I32;
103103

104104
const f32Cast = createDualImpl(
105105
// CPU implementation
@@ -127,7 +127,7 @@ const f32Cast = createDualImpl(
127127
*/
128128
export const f32: F32 = Object.assign(f32Cast, {
129129
type: 'f32',
130-
}) as F32;
130+
}) as unknown as F32;
131131

132132
const f16Cast = createDualImpl(
133133
// CPU implementation
@@ -158,4 +158,4 @@ const f16Cast = createDualImpl(
158158
*/
159159
export const f16: F16 = Object.assign(f16Cast, {
160160
type: 'f16',
161-
}) as F16;
161+
}) as unknown as F16;

packages/typegpu/src/data/wgslTypes.ts

+10-2
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
import type { TgpuNamable } from '../namable.js';
22
import type {
3-
$repr,
43
Infer,
54
InferGPU,
65
InferGPURecord,
@@ -10,6 +9,7 @@ import type {
109
MemIdentity,
1110
MemIdentityRecord,
1211
} from '../shared/repr.js';
12+
import { $repr } from '../shared/repr.js';
1313
import type { Prettify } from '../shared/utilityTypes.js';
1414

1515
type DecoratedLocation<T extends BaseData> = Decorated<T, Location<number>[]>;
@@ -50,6 +50,12 @@ export interface AbstractFloat {
5050
readonly [$repr]: number;
5151
}
5252

53+
export interface Void {
54+
readonly type: 'void';
55+
readonly [$repr]: undefined;
56+
}
57+
export const Void: Void = { type: 'void', [$repr]: undefined };
58+
5359
interface Swizzle2<T2, T3, T4> {
5460
readonly xx: T2;
5561
readonly xy: T2;
@@ -1158,6 +1164,7 @@ export const wgslTypeLiterals = [
11581164
'decorated',
11591165
'abstractInt',
11601166
'abstractFloat',
1167+
'void',
11611168
] as const;
11621169

11631170
export type WgslTypeLiteral = (typeof wgslTypeLiterals)[number];
@@ -1222,7 +1229,8 @@ export type AnyWgslData =
12221229
| Atomic
12231230
| Decorated
12241231
| AbstractInt
1225-
| AbstractFloat;
1232+
| AbstractFloat
1233+
| Void;
12261234

12271235
// #endregion
12281236

packages/typegpu/src/shared/generators.ts

+12-2
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
1+
import type { TgpuDualFn } from '../data/dataTypes';
12
import { inGPUMode } from '../gpuMode';
23
import type { Resource } from '../types';
4+
import { $internal } from './symbols';
35

46
/**
57
* Yields values in the sequence 0,1,2..∞ except for the ones in the `excluded` set.
@@ -24,8 +26,8 @@ type MapValueToResource<T> = { [K in keyof T]: Resource };
2426
export function createDualImpl<T extends (...args: any[]) => any>(
2527
jsImpl: T,
2628
gpuImpl: (...args: MapValueToResource<Parameters<T>>) => Resource,
27-
): T {
28-
return ((...args: Parameters<T>) => {
29+
): TgpuDualFn<T> {
30+
const impl = ((...args: Parameters<T>) => {
2931
if (inGPUMode()) {
3032
return gpuImpl(
3133
...(args as unknown as MapValueToResource<Parameters<T>>),
@@ -34,4 +36,12 @@ export function createDualImpl<T extends (...args: any[]) => any>(
3436
// biome-ignore lint/suspicious/noExplicitAny: <it's very convenient>
3537
return jsImpl(...(args as any));
3638
}) as T;
39+
40+
Object.defineProperty(impl, $internal, {
41+
value: {
42+
implementation: jsImpl,
43+
},
44+
});
45+
46+
return impl as TgpuDualFn<T>;
3747
}

packages/typegpu/src/smol/generationHelpers.ts

+13-12
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import { isDerived, isSlot } from '../core/slot/slotTypes';
2+
import type { AnyData } from '../data/dataTypes';
23
import { mat2x2f, mat3x3f, mat4x4f } from '../data/matrix';
34
import {
45
abstractFloat,
@@ -27,8 +28,6 @@ import {
2728
vec4u,
2829
} from '../data/vector';
2930
import {
30-
type AnyWgslData,
31-
type BaseData,
3231
type WgslStruct,
3332
isDecorated,
3433
isWgslArray,
@@ -67,7 +66,7 @@ type SwizzleLength = 1 | 2 | 3 | 4;
6766

6867
const swizzleLenToType: Record<
6968
SwizzleableType,
70-
Record<SwizzleLength, AnyWgslData>
69+
Record<SwizzleLength, AnyData>
7170
> = {
7271
f: {
7372
1: f32,
@@ -146,7 +145,7 @@ const indexableTypeToResult = {
146145
export function getTypeForPropAccess(
147146
targetType: Wgsl,
148147
propName: string,
149-
): BaseData | UnknownData {
148+
): AnyData | UnknownData {
150149
if (
151150
typeof targetType === 'string' ||
152151
typeof targetType === 'number' ||
@@ -164,23 +163,25 @@ export function getTypeForPropAccess(
164163
}
165164
const unwrapped = ctx.unwrap(targetType);
166165

167-
return getTypeFromWgsl(unwrapped as Wgsl) as BaseData;
166+
return getTypeFromWgsl(unwrapped);
168167
}
169168

170-
let target = targetType as BaseData;
169+
let target = targetType as AnyData;
171170

172171
if (hasInternalDataType(target)) {
173-
target = target[$internal].dataType;
172+
target = target[$internal].dataType as AnyData;
174173
}
175174
while (isDecorated(target)) {
176-
target = target.inner;
175+
target = target.inner as AnyData;
177176
}
178177

179178
const targetTypeStr =
180179
'kind' in target ? (target.kind as string) : target.type;
181180

182181
if (targetTypeStr === 'struct') {
183-
return (target as WgslStruct).propTypes[propName] ?? UnknownData;
182+
return (
183+
((target as WgslStruct).propTypes[propName] as AnyData) ?? UnknownData
184+
);
184185
}
185186

186187
const propLength = propName.length;
@@ -204,11 +205,11 @@ export function getTypeForPropAccess(
204205
return isWgslData(target) ? target : UnknownData;
205206
}
206207

207-
export function getTypeForIndexAccess(resource: Wgsl): BaseData | UnknownData {
208+
export function getTypeForIndexAccess(resource: Wgsl): AnyData | UnknownData {
208209
if (isWgslData(resource)) {
209210
// array
210211
if (isWgslArray(resource)) {
211-
return resource.elementType;
212+
return resource.elementType as AnyData;
212213
}
213214

214215
// vector or matrix
@@ -222,7 +223,7 @@ export function getTypeForIndexAccess(resource: Wgsl): BaseData | UnknownData {
222223
return UnknownData;
223224
}
224225

225-
export function getTypeFromWgsl(resource: Wgsl): BaseData | UnknownData {
226+
export function getTypeFromWgsl(resource: Wgsl): AnyData | UnknownData {
226227
if (isDerived(resource) || isSlot(resource)) {
227228
return getTypeFromWgsl(resource.value as Wgsl);
228229
}

packages/typegpu/src/smol/wgslGenerator.ts

+42-8
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,13 @@
11
import type * as smol from 'tinyest';
22
import * as d from '../data';
3+
import type { AnyData } from '../data/dataTypes';
34
import { abstractInt } from '../data/numeric.js';
45
import * as wgsl from '../data/wgslTypes.js';
56
import {
67
type ResolutionCtx,
78
type Resource,
89
UnknownData,
9-
type Wgsl,
10+
isMarkedInternal,
1011
isWgsl,
1112
} from '../types.js';
1213
import {
@@ -46,8 +47,8 @@ type Operator =
4647
| smol.UnaryOperator;
4748

4849
function operatorToType<
49-
TL extends wgsl.AnyWgslData | UnknownData,
50-
TR extends wgsl.AnyWgslData | UnknownData,
50+
TL extends AnyData | UnknownData,
51+
TR extends AnyData | UnknownData,
5152
>(lhs: TL, op: Operator, rhs?: TR): TL | TR | wgsl.Bool {
5253
if (!rhs) {
5354
if (op === '!' || op === '~') {
@@ -190,7 +191,9 @@ export function generateExpression(
190191
if (typeof target.value === 'string') {
191192
return {
192193
value: `${target.value}.${property}`,
193-
dataType: getTypeForPropAccess(target.dataType as Wgsl, property),
194+
dataType: d.isData(target.dataType)
195+
? getTypeForPropAccess(target.dataType, property)
196+
: UnknownData,
194197
};
195198
}
196199

@@ -214,10 +217,24 @@ export function generateExpression(
214217
// biome-ignore lint/suspicious/noExplicitAny: <sorry TypeScript>
215218
const propValue = (target.value as any)[property];
216219

220+
if (target.dataType.type !== 'unknown') {
221+
if (wgsl.isMat(target.dataType) && property === 'columns') {
222+
return {
223+
value: target.value,
224+
dataType: target.dataType,
225+
};
226+
}
227+
228+
return {
229+
value: propValue,
230+
dataType: getTypeForPropAccess(target.dataType, property),
231+
};
232+
}
233+
217234
if (isWgsl(target.value)) {
218235
return {
219236
value: propValue,
220-
dataType: getTypeForPropAccess(target.value as d.AnyWgslData, property),
237+
dataType: getTypeForPropAccess(target.value, property),
221238
};
222239
}
223240

@@ -245,7 +262,9 @@ export function generateExpression(
245262

246263
return {
247264
value: `${targetStr}[${propertyStr}]`,
248-
dataType: getTypeForIndexAccess(targetExpr.dataType as d.AnyWgslData),
265+
dataType: d.isData(targetExpr.dataType)
266+
? getTypeForIndexAccess(targetExpr.dataType)
267+
: UnknownData,
249268
};
250269
}
251270

@@ -291,10 +310,21 @@ export function generateExpression(
291310
};
292311
}
293312

313+
if (!isMarkedInternal(idValue)) {
314+
throw new Error(
315+
`Function ${String(idValue)} has not been created using TypeGPU APIs. Did you mean to wrap the function with tgpu.fn(args, return)(...) ?`,
316+
);
317+
}
318+
294319
// Assuming that `id` is callable
295-
return (idValue as unknown as (...args: unknown[]) => unknown)(
320+
const fnRes = (idValue as unknown as (...args: unknown[]) => unknown)(
296321
...resolvedResources,
297322
) as Resource;
323+
324+
return {
325+
value: resolveRes(ctx, fnRes),
326+
dataType: fnRes.dataType,
327+
};
298328
}
299329

300330
if ('o' in expression) {
@@ -369,7 +399,7 @@ export function generateExpression(
369399

370400
return {
371401
value: `${arrayType}( ${arrayValues.join(', ')} )`,
372-
dataType: d.arrayOf(type as d.AnyWgslData, values.length),
402+
dataType: d.arrayOf(type, values.length) as d.AnyWgslData,
373403
};
374404
}
375405

@@ -449,6 +479,10 @@ ${alternate}`;
449479
throw new Error('Cannot create variable without an initial value.');
450480
}
451481

482+
if (d.isLooseData(eq.dataType)) {
483+
throw new Error('Cannot create variable with loose data type.');
484+
}
485+
452486
registerBlockVariable(ctx, rawId, eq.dataType);
453487
const id = resolveRes(ctx, generateIdentifier(ctx, rawId));
454488

0 commit comments

Comments
 (0)