Skip to content

Commit 55bc6b7

Browse files
authored
Merge pull request #14 from davepagurek/feat/normal-output
Add adjustNormal() method
2 parents 42f10c9 + fe18e43 commit 55bc6b7

File tree

8 files changed

+467
-49
lines changed

8 files changed

+467
-49
lines changed

package.json

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
{
22
"name": "@davepagurek/glsl-autodiff",
3-
"version": "0.0.15",
3+
"version": "0.0.16",
44
"main": "build/autodiff.js",
55
"author": "Dave Pagurek <[email protected]>",
66
"license": "MIT",

src/base.ts

+138-7
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,7 @@ export abstract class Op {
5656
public dependsOn: Op[]
5757
public usedIn: Op[] = []
5858
public srcLine: string = ''
59+
public internalDerivatives: { op: Op, param: Param }[] = []
5960

6061
constructor(ad: ADBase, ...params: Op[]) {
6162
this.ad = ad
@@ -106,7 +107,7 @@ export abstract class Op {
106107

107108
public derivRef(param: Param): string {
108109
if (this.useTempVar()) {
109-
return `_glslad_dv${this.id}_d${param.name}`
110+
return `_glslad_dv${this.id}_d${param.safeName()}`
110111
} else {
111112
return `(${this.derivative(param)})`
112113
}
@@ -137,15 +138,58 @@ export abstract class Op {
137138
return this.dependsOn.every((op) => op.isConst())
138139
}
139140

140-
public deepDependencies(): Set<Op> {
141-
const deps = new Set<Op>()
141+
public outputDependencies({ deps, derivDeps }: { deps: Set<Op>; derivDeps: Map<Param, Set<Op>> }): string {
142+
let code = ''
142143
for (const op of this.dependsOn) {
143-
for (const dep of op.deepDependencies().values()) {
144-
deps.add(dep)
144+
if (!deps.has(op)) {
145+
deps.add(op)
146+
code += op.outputDependencies({ deps, derivDeps })
147+
code += op.initializer()
145148
}
146-
deps.add(op)
147149
}
148-
return deps
150+
151+
for (const { param, op } of this.internalDerivatives) {
152+
if (!derivDeps.get(param)?.has(op)) {
153+
const paramDerivDeps = derivDeps.get(param) ?? new Set<Op>()
154+
paramDerivDeps.add(op)
155+
derivDeps.set(param, paramDerivDeps)
156+
code += op.outputDerivDependencies(param, { deps, derivDeps })
157+
code += op.derivInitializer(param)
158+
}
159+
}
160+
161+
return code
162+
}
163+
164+
public outputDerivDependencies(param: Param, { deps, derivDeps }: { deps: Set<Op>; derivDeps: Map<Param, Set<Op>> }): string {
165+
let code = ''
166+
for (const op of this.dependsOn) {
167+
if (!deps.has(op)) {
168+
deps.add(op)
169+
code += op.outputDependencies({ deps, derivDeps })
170+
code += op.initializer()
171+
}
172+
173+
if (!derivDeps.get(param)?.has(op)) {
174+
const paramDerivDeps = derivDeps.get(param) ?? new Set<Op>()
175+
paramDerivDeps.add(op)
176+
derivDeps.set(param, paramDerivDeps)
177+
code += op.outputDerivDependencies(param, { deps, derivDeps })
178+
code += op.derivInitializer(param)
179+
}
180+
}
181+
182+
for (const { param, op } of this.internalDerivatives) {
183+
if (!derivDeps.get(param)?.has(op)) {
184+
const paramDerivDeps = derivDeps.get(param) ?? new Set<Op>()
185+
paramDerivDeps.add(op)
186+
derivDeps.set(param, paramDerivDeps)
187+
code += op.outputDependencies({ deps, derivDeps })
188+
code += op.derivInitializer(param)
189+
}
190+
}
191+
192+
return code
149193
}
150194

151195
public output(name: string) { return this.ad.output(name, this) }
@@ -155,6 +199,81 @@ export abstract class Op {
155199
public abstract derivative(param: Param): string
156200
}
157201

202+
export abstract class BooleanOp extends Op {
203+
abstract operator(): string
204+
definition() {
205+
return this.dependsOn.map((op) => op.ref()).join(this.operator())
206+
}
207+
derivative(): string {
208+
throw new Error('unimplemented')
209+
}
210+
isConst() {
211+
// They might not actually be constant, but we don't have derivatives
212+
// for these so we just treat them like they are
213+
return true;
214+
}
215+
glslType() {
216+
return 'bool'
217+
}
218+
}
219+
220+
export class EqOp extends BooleanOp {
221+
operator() {
222+
return '=='
223+
}
224+
}
225+
226+
export class NeOp extends BooleanOp {
227+
operator() {
228+
return '!='
229+
}
230+
}
231+
232+
export class LtOp extends BooleanOp {
233+
operator() {
234+
return '<'
235+
}
236+
}
237+
238+
export class LeOp extends BooleanOp {
239+
operator() {
240+
return '<='
241+
}
242+
}
243+
244+
export class GtOp extends BooleanOp {
245+
operator() {
246+
return '>'
247+
}
248+
}
249+
250+
export class GeOp extends BooleanOp {
251+
operator() {
252+
return '>='
253+
}
254+
}
255+
256+
export class NotOp extends BooleanOp {
257+
operator() {
258+
return '!'
259+
}
260+
definition() {
261+
return this.operator() + this.dependsOn[0].ref()
262+
}
263+
}
264+
265+
export class AndOp extends BooleanOp {
266+
operator() {
267+
return '&&'
268+
}
269+
}
270+
271+
export class OrOp extends BooleanOp {
272+
operator() {
273+
return '||'
274+
}
275+
}
276+
158277
export abstract class OpLiteral extends Op {
159278
public override initializer() { return '' }
160279
public override derivInitializer() { return '' }
@@ -187,6 +306,18 @@ export class Param extends OpLiteral {
187306
this.ad.registerParam(this, name)
188307
}
189308

309+
safeName() {
310+
// A version of the name that can be used in temp variable
311+
// names
312+
return this.name.split('').map((c) => {
313+
if (c.match(/[\w\d]/)) {
314+
return c
315+
} else {
316+
return '_'
317+
}
318+
}).join('') + this.id // Add id to ensure uniqueness
319+
}
320+
190321
isConst() { return false }
191322
definition() { return this.name }
192323
derivative(param: Param) {

src/functions.ts

+9
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,15 @@ export class IfElse extends Op {
7171
}
7272
}
7373

74+
export class Abs extends Op {
75+
definition() {
76+
return `abs(${this.dependsOn[0].ref()})`
77+
}
78+
derivative(param: Param) {
79+
return '0.0'
80+
}
81+
}
82+
7483
declare module './base' {
7584
interface Op {
7685
sin(): Op

src/index.ts

+10-34
Original file line numberDiff line numberDiff line change
@@ -58,47 +58,23 @@ class AutoDiffImpl implements ADBase {
5858

5959
// Add initializers for outputs
6060
const deps = new Set<Op>()
61-
for (const name in this.outputs) {
62-
for (const dep of this.outputs[name].deepDependencies().values()) {
63-
deps.add(dep)
64-
}
65-
deps.add(this.outputs[name])
66-
}
67-
for (const param in this.derivOutputs) {
68-
for (const name in this.derivOutputs[param]) {
69-
for (const dep of this.derivOutputs[param][name].deepDependencies().values()) {
70-
deps.add(dep)
71-
}
72-
deps.add(this.derivOutputs[param][name])
73-
}
74-
}
75-
for (const dep of deps.values()) {
76-
code += dep.initializer()
77-
}
61+
const derivDeps = new Map<Param, Set<Op>>()
7862

79-
// Add outputs
8063
for (const name in this.outputs) {
64+
code += this.outputs[name].outputDependencies({ deps, derivDeps })
65+
code += this.outputs[name].initializer()
66+
deps.add(this.outputs[name])
8167
code += `${this.outputs[name].glslType()} ${name}=${this.outputs[name].ref()};\n`
8268
}
83-
8469
for (const param in this.derivOutputs) {
8570
const paramOp = this.params[param]
86-
87-
// Add initializers for derivative outputs
88-
const derivDeps = new Set<Op>()
89-
for (const name in this.derivOutputs[param]) {
90-
for (const dep of this.derivOutputs[param][name].deepDependencies().values()) {
91-
derivDeps.add(dep)
92-
}
93-
derivDeps.add(this.derivOutputs[param][name])
94-
}
95-
for (const dep of derivDeps.values()) {
96-
code += dep.derivInitializer(paramOp)
97-
}
98-
99-
// Add derivative outputs
10071
for (const name in this.derivOutputs[param]) {
72+
code += this.derivOutputs[param][name].outputDerivDependencies(paramOp, { deps, derivDeps })
73+
code += this.derivOutputs[param][name].derivInitializer(paramOp)
10174
code += `${this.derivOutputs[param][name].glslType()} ${name}=${this.derivOutputs[param][name].derivRef(paramOp)};\n`
75+
const paramDerivDeps = derivDeps.get(paramOp) ?? new Set<Op>()
76+
paramDerivDeps.add(this.derivOutputs[param][name])
77+
derivDeps.set(paramOp, paramDerivDeps)
10278
}
10379
}
10480

@@ -109,7 +85,7 @@ class AutoDiffImpl implements ADBase {
10985
// TODO figure out a better way of writing this that Typescript can still infer the type of
11086
const ExtendedAD = WithVecFunctions(WithVecArithmetic(WithVecBase(WithFunctions(WithArithmetic(AutoDiffImpl)))))
11187
type GetType<T> = T extends new (...args: any[]) => infer V ? V : never
112-
type AD = GetType<typeof ExtendedAD>
88+
export type AD = GetType<typeof ExtendedAD>
11389

11490
export const gen = (cb: (ad: AD) => void, settings: Partial<ADSettings> = {}): string => {
11591
const ad = new ExtendedAD()

src/vecBase.ts

+72-5
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
import { Op, OpLiteral, ADBase, Param, Input, ADConstructor, UserInput } from './base'
1+
import { Op, OpLiteral, ADBase, Param, Input, ADConstructor, UserInput, EqOp } from './base'
22

33
export interface VecOp extends Op {
44
x(): Op
@@ -58,6 +58,8 @@ export function Cache(target: Object, propertyKey: string, descriptor: PropertyD
5858
return descriptor
5959
}
6060

61+
62+
6163
export abstract class VectorOp extends Op {
6264
scalar() { return false }
6365

@@ -243,6 +245,71 @@ export abstract class ScalarWithVecDependencies extends Op {
243245
}
244246
}
245247

248+
export class OffsetJacobian extends WithVecDependencies {
249+
constructor(ad, ...args) {
250+
super(ad, ...args)
251+
this.internalDerivatives.push(
252+
{ op: this.offset(), param: this.position().x() },
253+
{ op: this.offset(), param: this.position().y() },
254+
{ op: this.offset(), param: this.position().z() },
255+
)
256+
for (const { op } of this.internalDerivatives) {
257+
op.usedIn.push(this)
258+
}
259+
}
260+
261+
public size(): number {
262+
return 3
263+
}
264+
265+
private position() {
266+
return this.dependsOn[0] as VecParam
267+
}
268+
269+
private offset() {
270+
return this.dependsOn[1] as VectorOp
271+
}
272+
273+
glslType() {
274+
return 'mat3'
275+
}
276+
277+
definition() {
278+
const dodx = this.offset().derivRef(this.position().x())
279+
const dody = this.offset().derivRef(this.position().y())
280+
const dodz = this.offset().derivRef(this.position().z())
281+
return `mat3(${dodx},${dody},${dodz})`
282+
}
283+
derivative(_param: Param): string {
284+
throw new Error('Unimplemented')
285+
}
286+
287+
public dot(vec3: VectorOp) {
288+
return new Mat3Dot(this.ad, this, vec3)
289+
}
290+
}
291+
292+
export class Mat3Dot extends VectorOp {
293+
public size() {
294+
return 3
295+
}
296+
297+
private mat3() {
298+
return this.dependsOn[0] as OffsetJacobian
299+
}
300+
301+
private vec3() {
302+
return this.dependsOn[1] as VectorOp
303+
}
304+
305+
definition() {
306+
return `${this.mat3().ref()}*${this.vec3().ref()}`
307+
}
308+
derivative(_param: Param): string {
309+
throw new Error('Unimplemented')
310+
}
311+
}
312+
246313
export class Vec extends VectorOp {
247314
public size(): number {
248315
return this.dependsOn.length
@@ -273,16 +340,16 @@ export class VecParam extends VectorOp {
273340
}
274341

275342
@Cache
276-
public x(): Op { return new VecParamElementRef(this.ad, 'x', this) }
343+
public x() { return new VecParamElementRef(this.ad, 'x', this) }
277344

278345
@Cache
279-
public y(): Op { return new VecParamElementRef(this.ad, 'y', this) }
346+
public y() { return new VecParamElementRef(this.ad, 'y', this) }
280347

281348
@Cache
282-
public z(): Op { return new VecParamElementRef(this.ad, 'z', this) }
349+
public z() { return new VecParamElementRef(this.ad, 'z', this) }
283350

284351
@Cache
285-
public w(): Op { return new VecParamElementRef(this.ad, 'w', this) }
352+
public w() { return new VecParamElementRef(this.ad, 'w', this) }
286353

287354
private getElems() {
288355
return 'xyzw'

0 commit comments

Comments
 (0)