Skip to content

Commit a6037ca

Browse files
committed
Don't output derivative lines for constant nodes
1 parent a284c15 commit a6037ca

File tree

6 files changed

+206
-13
lines changed

6 files changed

+206
-13
lines changed

package.json

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
{
22
"name": "@davepagurek/glsl-autodiff",
3-
"version": "0.0.19",
3+
"version": "0.0.20",
44
"main": "build/autodiff.js",
55
"author": "Dave Pagurek <[email protected]>",
66
"license": "MIT",

src/arithmetic.ts

+9-1
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,15 @@ export class Mult extends Op {
2424
}
2525
derivative(param: Param) {
2626
const [f, g] = this.dependsOn
27-
return `${f.ref()}*${g.derivRef(param)}+${g.ref()}*${f.derivRef(param)}`
27+
const fIsConst = f.isConst(param)
28+
const gIsConst = g.isConst(param)
29+
if (fIsConst && !gIsConst) {
30+
return `${f.ref()}*${g.derivRef(param)}`
31+
} else if (!fIsConst && gIsConst) {
32+
return `${g.ref()}*${f.derivRef(param)}`
33+
} else {
34+
return `${f.ref()}*${g.derivRef(param)}+${g.ref()}*${f.derivRef(param)}`
35+
}
2836
}
2937
}
3038

src/base.ts

+21-8
Original file line numberDiff line numberDiff line change
@@ -105,8 +105,14 @@ export abstract class Op {
105105
}
106106
}
107107

108+
public zeroDerivative() {
109+
return '0.0'
110+
}
111+
108112
public derivRef(param: Param): string {
109-
if (this.useTempVar()) {
113+
if (this.isConst(param)) {
114+
return this.zeroDerivative()
115+
} else if (this.useTempVar()) {
110116
return `_glslad_dv${this.id}_d${param.safeName()}`
111117
} else {
112118
return `(${this.derivative(param)})`
@@ -127,15 +133,15 @@ export abstract class Op {
127133
}
128134

129135
public derivInitializer(param: Param): string {
130-
if (this.useTempVar()) {
131-
return `${this.glslType()} ${this.derivRef(param)}=${this.derivative(param)};\n`
132-
} else {
136+
if (this.isConst(param) || !this.useTempVar()) {
133137
return ''
138+
} else {
139+
return `${this.glslType()} ${this.derivRef(param)}=${this.derivative(param)};\n`
134140
}
135141
}
136142

137-
public isConst(): boolean {
138-
return this.dependsOn.every((op) => op.isConst())
143+
public isConst(param?: Param): boolean {
144+
return this.dependsOn.every((op) => op.isConst(param))
139145
}
140146

141147
public outputDependencies({ deps, derivDeps }: { deps: Set<Op>; derivDeps: Map<Param, Set<Op>> }): string {
@@ -162,6 +168,7 @@ export abstract class Op {
162168
}
163169

164170
public outputDerivDependencies(param: Param, { deps, derivDeps }: { deps: Set<Op>; derivDeps: Map<Param, Set<Op>> }): string {
171+
if (this.isConst()) return ''
165172
let code = ''
166173
for (const op of this.dependsOn) {
167174
if (!deps.has(op)) {
@@ -170,7 +177,7 @@ export abstract class Op {
170177
code += op.initializer()
171178
}
172179

173-
if (!derivDeps.get(param)?.has(op)) {
180+
if (!derivDeps.get(param)?.has(op) && !op.isConst(param)) {
174181
const paramDerivDeps = derivDeps.get(param) ?? new Set<Op>()
175182
paramDerivDeps.add(op)
176183
derivDeps.set(param, paramDerivDeps)
@@ -318,7 +325,13 @@ export class Param extends OpLiteral {
318325
}).join('') + this.id // Add id to ensure uniqueness
319326
}
320327

321-
isConst() { return false }
328+
isConst(param?: Param) {
329+
if (param) {
330+
return param !== this
331+
} else {
332+
return false
333+
}
334+
}
322335
definition() { return this.name }
323336
derivative(param: Param) {
324337
if (param === this) {

src/vecBase.ts

+27-3
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,14 @@ export class VecParamElementRef extends Param {
3333
this.ad.registerParam(this, this.name)
3434
}
3535

36+
public isConst(param?: Param) {
37+
if (param) {
38+
return param !== this
39+
} else {
40+
return false
41+
}
42+
}
43+
3644
definition() { return `${this.dependsOn[0].ref()}.${this.prop}` }
3745
derivative(param: Param) {
3846
if (param === this) {
@@ -91,6 +99,14 @@ export abstract class VectorOp extends Op {
9199
}
92100
}
93101

102+
public glslType() {
103+
return `vec${this.size()}`
104+
}
105+
106+
zeroDerivative() {
107+
return `${this.glslType()}(0.0)`
108+
}
109+
94110
public u() { return this.x() }
95111
public v() { return this.y() }
96112
public r() { return this.x() }
@@ -363,13 +379,21 @@ export class VecParam extends VectorOp {
363379
return `vec${this.size()}(${this.getElems().map((el) => el.derivRef(param)).join(',')})`
364380
}
365381

382+
public isConst(param?: Param) {
383+
if (param) {
384+
return param !== this.x() && param !== this.y() && param !== this.z()
385+
} else {
386+
return false
387+
}
388+
}
389+
366390
public override initializer() { return '' }
367391
public override ref() { return this.definition() }
368392
public override derivInitializer(param: Param) {
369-
if (this.useTempVar()) {
370-
return `vec${this.size()} ${this.derivRef(param)}=${this.derivative(param)};\n`
371-
} else {
393+
if (this.isConst(param) || !this.useTempVar()) {
372394
return ''
395+
} else {
396+
return `vec${this.size()} ${this.derivRef(param)}=${this.derivative(param)};\n`
373397
}
374398
}
375399
}

test/simple-wiggle/index.html

+11
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
<!DOCTYPE html>
2+
<html>
3+
<head>
4+
<title>glsl-autodiff test</title>
5+
<script type="text/javascript" src="https://cdnjs.cloudflare.com/ajax/libs/p5.js/1.3.1/p5.min.js"></script>
6+
<script type="text/javascript" src="../../build/autodiff.js"></script>
7+
<script type="text/javascript" src="test.js"></script>
8+
</head>
9+
<body>
10+
</body>
11+
</html>

test/simple-wiggle/test.js

+137
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,137 @@
1+
const vert = `
2+
attribute vec3 aPosition;
3+
attribute vec3 aNormal;
4+
attribute vec2 aTexCoord;
5+
6+
uniform mat4 uModelViewMatrix;
7+
uniform mat4 uProjectionMatrix;
8+
uniform mat3 uNormalMatrix;
9+
10+
uniform float time;
11+
12+
varying vec2 vTexCoord;
13+
varying vec3 vNormal;
14+
varying vec3 vPosition;
15+
16+
void main(void) {
17+
vec4 objSpacePosition = vec4(aPosition, 1.0);
18+
float origZ = objSpacePosition.z;
19+
${AutoDiff.gen((ad) => {
20+
const pos = ad.vec3Param('objSpacePosition')
21+
const y = pos.y()
22+
const time = ad.param('time')
23+
24+
offset = ad.vec3(time.mult(0.005).add(y.mult(2)).sin().mult(0.5), 0, 0)
25+
offset.output('offset')
26+
offset.adjustNormal(ad.vec3Param('aNormal'), pos).output('normal')
27+
//offset.output('z')
28+
//offset.outputDeriv('dzdx', x)
29+
//offset.outputDeriv('dzdy', y)
30+
}, { debug: true, maxDepthPerVariable: 8 })}
31+
objSpacePosition.xyz += offset;
32+
//vec3 slopeX = vec3(1.0, 0.0, dzdx);
33+
//vec3 slopeY = vec3(0.0, 1.0, dzdy);
34+
vec4 worldSpacePosition = uModelViewMatrix * objSpacePosition;
35+
gl_Position = uProjectionMatrix * worldSpacePosition;
36+
vTexCoord = aTexCoord;
37+
vPosition = worldSpacePosition.xyz;
38+
//vNormal = uNormalMatrix * aNormal;
39+
//normal=cross(_glslad_v66,_glslad_v65);
40+
//normal=_glslad_v66;
41+
vNormal = uNormalMatrix * normal;
42+
}
43+
`
44+
console.log(vert)
45+
46+
const frag = `
47+
precision mediump float;
48+
const int MAX_LIGHTS = 3;
49+
50+
varying vec2 vTexCoord;
51+
varying vec3 vNormal;
52+
varying vec3 vPosition;
53+
54+
uniform sampler2D img;
55+
uniform int numLights;
56+
uniform vec3 lightPositions[MAX_LIGHTS];
57+
uniform vec3 lightColors[MAX_LIGHTS];
58+
uniform float lightStrengths[MAX_LIGHTS];
59+
uniform vec3 ambientLight;
60+
uniform float materialShininess;
61+
62+
void main(void) {
63+
vec3 materialColor = texture2D(img, vTexCoord).rgb;
64+
vec3 normal = normalize(vNormal);
65+
gl_FragColor = vec4(abs(normal), 1.); return;
66+
//gl_FragColor = length(vNormal) * vec4(1.); return;
67+
vec3 color = vec3(0.0, 0.0, 0.0);
68+
for (int i = 0; i < MAX_LIGHTS; i++) {
69+
if (i >= numLights) break;
70+
vec3 lightPosition = lightPositions[i];
71+
float distanceSquared = 0.0; /*0.00015*dot(
72+
lightPosition - vPosition,
73+
lightPosition - vPosition);*/
74+
vec3 lightDir = normalize(lightPosition - vPosition);
75+
float lambertian = max(dot(lightDir, normal), 0.0);
76+
color += lambertian * materialColor * lightColors[i] *
77+
lightStrengths[i] / (1.0 + distanceSquared);
78+
vec3 viewDir = normalize(-vPosition);
79+
float spec = pow(
80+
max(dot(viewDir, reflect(-lightDir, normal)), 0.0),
81+
materialShininess);
82+
color += spec * lightStrengths[i] * lightColors[i] /
83+
(1.0 + distanceSquared);
84+
}
85+
color += ambientLight * materialColor;
86+
gl_FragColor = vec4(color, 1.0);
87+
}
88+
`
89+
90+
let distortShader
91+
let texture
92+
function setup() {
93+
createCanvas(800, 600, WEBGL)
94+
distortShader = createShader(vert, frag)
95+
texture = createGraphics(500, 500)
96+
}
97+
98+
const lights = [{
99+
position: [200, 50, -100],
100+
color: [1, 1, 1],
101+
strength: 0.5,
102+
},
103+
{
104+
position: [-200, -50, -100],
105+
color: [1, 1, 1],
106+
strength: 0.5,
107+
},
108+
];
109+
110+
function draw() {
111+
texture.background(255, 0, 0)
112+
texture.fill(255)
113+
texture.noStroke()
114+
texture.textSize(70)
115+
texture.textAlign(CENTER, CENTER)
116+
texture.text('hello, world', texture.width / 2, texture.height / 2)
117+
118+
background(0)
119+
120+
const shininess = 1000
121+
const ambient = [0.2, 0.2, 0.2]
122+
123+
orbitControl()
124+
noStroke()
125+
shader(distortShader)
126+
distortShader.setUniform('img', texture)
127+
distortShader.setUniform('lightPositions', lights.map(l => l.position).flat())
128+
distortShader.setUniform('lightColors', lights.map(l => l.color).flat())
129+
distortShader.setUniform('lightStrengths', lights.map(l => l.strength).flat())
130+
distortShader.setUniform('numLights', lights.length)
131+
distortShader.setUniform('ambientLight', ambient)
132+
distortShader.setUniform('materialShininess', shininess)
133+
distortShader.setUniform('time', millis())
134+
push()
135+
sphere(200, 60, 30)
136+
pop()
137+
}

0 commit comments

Comments
 (0)