Skip to content

Commit ce546d8

Browse files
committed
Add cross-imp
1 parent 261e744 commit ce546d8

File tree

1 file changed

+30
-7
lines changed

1 file changed

+30
-7
lines changed

Diff for: src/process.js

+30-7
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ module.exports = class Process {
2121
dimensions: p['Dimensions'],
2222
column: p['Target variable'],
2323
transform: p['Transform'],
24-
method: p['Method'],
24+
method: p['Projection method'],
2525
steps: p['Steps'],
2626
importance: p['Feature importance']
2727
}
@@ -50,6 +50,9 @@ module.exports = class Process {
5050
const Xraw = new Matrix(this.records.map(row => features.map(f => row[f])))
5151
const featuresFiltered = []
5252

53+
let imp
54+
let impMatrix
55+
5356
// Remove columns with many NaNs
5457
console.log('[Vis] Transforming data')
5558
const cols = []
@@ -142,10 +145,11 @@ module.exports = class Process {
142145
})
143146
Y = ae.encode(X)
144147

145-
var impMatrix = []
148+
impMatrix = []
149+
146150
console.log('[Vis] Generate importance matrix with Autoencoder')
147151
featuresFiltered.forEach((f, fi) => {
148-
const imp = []
152+
const impTemp = []
149153
const Xr = []
150154
X.forEach(x => Xr.push(x.slice(0)))
151155
for (let i = Xr.length - 1; i > 0; i--) {
@@ -157,9 +161,9 @@ module.exports = class Process {
157161
const Xp = ae.predict(Xr)
158162
featuresFiltered.forEach((ff, ffi) => {
159163
const mse = Xp.reduce((a, x, xi) => Math.pow(x[ffi] - X[xi][ffi], 2) + a, 0) / Xp.length
160-
imp.push(mse)
164+
impTemp.push(mse)
161165
})
162-
impMatrix.push(imp)
166+
impMatrix.push(impTemp)
163167
})
164168
console.log('[Vis] Autoencoder importance matrix:', impMatrix)
165169
impMatrix = new Matrix(impMatrix).scaleColumns().to2DArray()
@@ -177,7 +181,6 @@ module.exports = class Process {
177181
let target
178182
let colorscale
179183
let g
180-
let imp
181184

182185
if (params.column && params.column.length && (params.column !== 'None')) {
183186
console.log('[Vis] Target variable is present')
@@ -212,9 +215,29 @@ module.exports = class Process {
212215
[0, '#8A8DA1'],
213216
[1, '#8A8DA1']
214217
]
218+
219+
if (params.importance === 'Random Forest') {
220+
console.log('[Vis] Fitting random forest on all variables')
221+
console.log(X, featuresFiltered, featuresFiltered.length)
222+
impMatrix = featuresFiltered.map((f, i) => {
223+
console.log(`Calculating ${i} of ${featuresFiltered.length}: ${f}`)
224+
const Xtemp = X.map(row => row.filter((_, j) => j !== i))
225+
const gtemp = X.map(row => row.filter((_, j) => j === i))
226+
const rf = new RandomForest({
227+
nEstimators: 50,
228+
maxDepth: 5,
229+
maxFeatures: 'auto'
230+
})
231+
rf.train(Xtemp, gtemp)
232+
const impTemp = rf.getFeatureImportances(Xtemp, gtemp, { n: 3, means: true, verbose: false })
233+
// Add importance of the target feature on itself
234+
impTemp.splice(i, 0, 0)
235+
return impTemp
236+
})
237+
}
215238
}
216239

217-
return {Y, X, params, nDims, featuresFiltered, recordsFiltered, target, g, colorscale, impMatrix, corr, imp}
240+
return { Y, X, params, nDims, featuresFiltered, recordsFiltered, target, g, colorscale, impMatrix, corr, imp }
218241
}
219242
}
220243
}

0 commit comments

Comments
 (0)