Skip to content

Commit 1745524

Browse files
committed
prefix_sum
1 parent d83ef1e commit 1745524

File tree

5 files changed

+124
-168
lines changed

5 files changed

+124
-168
lines changed

examples/jsm/gpgpu/BitonicSort.js

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -119,7 +119,7 @@ export class BitonicSort {
119119
*
120120
* @type {StorageBufferNode}
121121
*/
122-
this.workgroupSize = options.workgroupSize ? Math.min( this.dispatchSize, options.workgroupSize ) : Math.min( this.dispatchSize, 64 );
122+
this.workgroupSize = options.workgroupSize ? Math.min( this.dispatchSize, options.workgroupSize ) : Math.min( this.dispatchSize, this.renderer.backend.device.limits.maxComputeWorkgroupSizeX );
123123

124124
/**
125125
* A node representing a workgroup scoped buffer that holds locally sorted elements.

examples/jsm/gpgpu/PrefixSum.js

Lines changed: 44 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,7 @@
1-
import { Fn, If, instancedArray, invocationLocalIndex, countTrailingZeros, Loop, workgroupArray, subgroupSize, workgroupBarrier, workgroupId, uint, select, invocationSubgroupIndex, dot, uvec4, vec4, float, subgroupAdd, array, subgroupShuffle, subgroupInclusiveAdd, subgroupBroadcast, invocationSubgroupMetaIndex, arrayBuffer } from 'three/tsl';
1+
import {
2+
StorageInstancedBufferAttribute
3+
} from 'three';
4+
import { Fn, If, instancedArray, invocationLocalIndex, countTrailingZeros, Loop, workgroupArray, subgroupSize, workgroupBarrier, workgroupId, uint, select, invocationSubgroupIndex, dot, uvec4, vec4, float, subgroupAdd, array, subgroupShuffle, subgroupInclusiveAdd, subgroupBroadcast, invocationSubgroupMetaIndex, arrayBuffer, storage } from 'three/tsl';
25

36
const divRoundUp = ( size, part_size ) => {
47

@@ -68,6 +71,12 @@ export class PrefixSum {
6871
*/
6972
this.renderer = renderer;
7073

74+
if ( this.renderer.backend.device === null ) {
75+
76+
renderer.backend.init();
77+
78+
}
79+
7180
/**
7281
* @type {PrefixSumStorageObjects}
7382
*/
@@ -132,7 +141,14 @@ export class PrefixSum {
132141
*
133142
* @type {number}
134143
*/
135-
this.workgroupSize = options.workgroupSize ? options.workgroupSize : Math.min( this.vecCount, 64 );
144+
this.workgroupSize = options.workgroupSize ? options.workgroupSize : Math.min( this.vecCount, this.renderer.backend.device.limits.maxComputeWorkgroupSizeX );
145+
146+
/**
147+
* The minimumn subgroup size specified by the renderer's graphics device.
148+
*
149+
* @type {number}
150+
*/
151+
this.minSubgroupSize = ( this.renderer.backend.device.adapterInfo && this.renderer.backend.device.adapterInfo.subgroupMinSize ) ? this.renderer.backend.device.adapterInfo.subgroupMinSize : 4;
136152

137153
/**
138154
* The maximum number of elements that will be read by an individual workgroup in the reduction step.
@@ -179,10 +195,17 @@ export class PrefixSum {
179195
_createStorageBuffers( inputArray ) {
180196

181197
this.arrayBuffer = this.type === 'uint' ? Uint32Array.from( inputArray ) : Float32Array.from( inputArray );
198+
this.outputArrayBuffer = this.type === 'uint' ? Uint32Array.from( inputArray ) : Float32Array.from( inputArray );
199+
200+
const inputAttribute = new StorageInstancedBufferAttribute( this.arrayBuffer, 1 );
201+
const outputAttribute = new StorageInstancedBufferAttribute( this.outputArrayBuffer, 1 );
202+
203+
this.storageBuffers.dataBuffer = storage( inputAttribute, this.vecType, inputAttribute.count / 4 ).setName( `Prefix_Sum_Input_Vec_${id}` );
204+
this.storageBuffers.unvectorizedDataBuffer = storage( inputAttribute, this.type, inputAttribute.count ).setName( `Prefix_Sum_Input_Unvec_${id}` );
205+
206+
this.storageBuffers.outputBuffer = storage( outputAttribute, this.vecType, outputAttribute.count / 4 ).setName( `Prefix_Sum_Output_Vec_${id}` );
207+
this.storageBuffers.unvectorizedOutputBuffer = storage( outputAttribute, this.type, outputAttribute.count ).setName( `Prefix_Sum_Output_Unvec_${id}` );
182208

183-
this.storageBuffers.unvectorizedDataBuffer = instancedArray( this.arrayBuffer, this.type ).setPBO( true ).setName( `Prefix_Sum_Input_Unvec_${id}` );
184-
this.storageBuffers.dataBuffer = instancedArray( this.arrayBuffer, this.vecType ).setPBO( true ).setName( `Prefix_Sum_Input_Vec_${id}` );
185-
this.storageBuffers.outputBuffer = instancedArray( this.arrayBuffer, this.vecType ).setName( `Prefix_Sum_Output_${id}` );
186209
this.storageBuffers.reductionBuffer = instancedArray( this.numWorkgroups, this.type ).setPBO( true ).setName( `Prefix_Sum_Reduction_${id}` );
187210

188211
}
@@ -472,6 +495,19 @@ export class PrefixSum {
472495
_getSpineScanFn() {
473496

474497
const { reductionBuffer } = this.storageBuffers;
498+
499+
if ( this.numWorkgroups <= this.minSubgroupSize ) {
500+
501+
const fnDef = Fn( () => {
502+
503+
reductionBuffer.element( invocationSubgroupIndex ).assign( subgroupInclusiveAdd( reductionBuffer.element( invocationSubgroupIndex ) ) );
504+
505+
} )().compute( this.numWorkgroups, [ this.workgroupSize ] );
506+
507+
return fnDef;
508+
509+
}
510+
475511
const { subgroupReductionArray, unvectorizedSubgroupOffset, spineSize, subgroupSizeLog } = this.utilityNodes;
476512
const { unvectorizedWorkPerInvocation } = this;
477513

@@ -630,16 +666,13 @@ export class PrefixSum {
630666

631667
} )().compute( this.numWorkgroups, [ this.workgroupSize ] );
632668

633-
console.log( fnDef );
634-
635669
return fnDef;
636670

637671
}
638672

639673
_getDownsweepFn() {
640674

641675
const { dataBuffer, reductionBuffer, outputBuffer } = this.storageBuffers;
642-
const { vecType } = this;
643676
const { subgroupOffset, workgroupOffset, subgroupReductionArray, subgroupSizeLog, spineSize } = this.utilityNodes;
644677

645678
const { workPerInvocation, vecCount } = this;
@@ -958,9 +991,9 @@ export class PrefixSum {
958991
*/
959992
async compute() {
960993

961-
await this.computeStep( this.currentStep );
962-
await this.computeStep( this.currentStep );
963-
await this.computeStep( this.currentStep );
994+
await this.computeReduce();
995+
await this.computeSpineScan();
996+
await this.computeDownsweep();
964997

965998
}
966999

0 commit comments

Comments
 (0)