forked from PAIR-code/lit
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathregression_module.ts
154 lines (133 loc) · 4.92 KB
/
regression_module.ts
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
/**
* @license
* Copyright 2020 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
// tslint:disable:no-new-decorators
import {customElement, html} from 'lit-element';
import {observable} from 'mobx';
import {app} from '../core/lit_app';
import {LitModule} from '../core/lit_module';
import {IndexedInput, ModelInfoMap, Spec} from '../lib/types';
import {doesOutputSpecContain, findSpecKeys} from '../lib/utils';
import {RegressionService} from '../services/services';
import {styles} from './regression_module.css';
import {styles as sharedStyles} from './shared_styles.css';
/**
 * Flat map of output field name to numeric value for one example.
 *
 * Holds the model's predicted scores keyed by output field, plus any derived
 * error entries stored under keys produced by the regression service's
 * error-key helper.
 */
type RegressionResult = Record<string, number>;
/**
 * A LIT module that renders regression results.
 *
 * For the primary selected example it shows one row per `RegressionScore`
 * output field of the model. Each row contains the field name, the value of
 * the field's `parent` input (ground truth) if present, the predicted score,
 * and the error reported by the RegressionService. The ground-truth and error
 * columns are hidden when no row has an error value to display.
 */
@customElement('regression-module')
export class RegressionModule extends LitModule {
  static title = 'Regression Results';
  static duplicateForExampleComparison = true;
  static numCols = 3;
  static template = (model = '', selectionServiceIndex = 0) => {
    return html`<regression-module model=${model} selectionServiceIndex=${
        selectionServiceIndex}></regression-module>`;
  };

  static get styles() {
    return [sharedStyles, styles];
  }

  private readonly regressionService = app.getService(RegressionService);

  /** Scores (and error values) for the example currently being displayed. */
  @observable private result: RegressionResult|null = null;

  firstUpdated() {
    // Re-run updateSelection now and whenever the primary selection changes.
    const getPrimarySelectedInputData = () =>
        this.selectionService.primarySelectedInputData;
    this.reactImmediately(
        getPrimarySelectedInputData, primarySelectedInputData => {
          this.updateSelection(primarySelectedInputData);
        });
  }

  /**
   * Loads predictions and per-field errors for `inputData` into `this.result`.
   *
   * @param inputData The newly selected example, or null to clear the display.
   */
  private async updateSelection(inputData: IndexedInput|null) {
    if (inputData == null) {
      this.result = null;
      return;
    }
    const dataset = this.appState.currentDataset;
    const promise = this.regressionService.getRegressionPreds(
        [inputData], this.model, dataset);
    // loadLatest may resolve to null; NOTE(review): presumably when a newer
    // 'regressionPreds' request has superseded this one — confirm in LitModule.
    const results = await this.loadLatest('regressionPreds', promise);
    if (results === null || results.length === 0) {
      this.result = null;
      return;
    }

    // Only a single input was requested, so there is a single prediction.
    // Copy it instead of mutating the object handed back by the service.
    const result: RegressionResult = {...results[0]};
    const fields = Object.keys(result);

    // The per-field error lookups are independent, so issue them together.
    // Promise.all keeps the output order aligned with `fields`.
    const infos = await Promise.all(fields.map(
        field => this.regressionService.getResults(
            [inputData.id], this.model, field)));

    // The awaits above are not guarded by loadLatest. If the selection moved
    // on while they were pending, drop this result so a stale response cannot
    // overwrite what the newer selection has already displayed.
    const current = this.selectionService.primarySelectedInputData;
    if (current == null || current.id !== inputData.id) {
      return;
    }

    fields.forEach((field, i) => {
      const info = infos[i][0];
      // A missing entry or error simply means there is no error to show.
      if (info != null && info.error != null) {
        result[this.regressionService.getErrorKey(field)] = info.error;
      }
    });
    this.result = result;
  }

  render() {
    const result = this.result;
    if (result == null) {
      return null;
    }
    const input = this.selectionService.primarySelectedInputData;
    if (input == null) {
      return null;
    }

    // Use the spec to find which fields we should display.
    const spec = this.appState.getModelSpec(this.model);
    const scoreFields: string[] = findSpecKeys(spec.output, 'RegressionScore');
    const rows: string[][] = [];
    let hasParent = false;

    // One row per output: score, plus ground truth and error when available.
    for (const scoreField of scoreFields) {
      const score = result[scoreField] == null ?
          '' :
          result[scoreField].toFixed(4);
      // Input field holding the target value this output is compared against.
      const parentField = spec.output[scoreField].parent || '';
      const parentScore = input.data[parentField] == null ?
          '' :
          input.data[parentField].toFixed(4);
      let errorScore = '';
      if (parentField && parentScore) {
        const error = result[this.regressionService.getErrorKey(scoreField)];
        if (error != null) {
          hasParent = true;
          errorScore = error.toFixed(4);
        }
      }
      rows.push([scoreField, parentScore, score, errorScore]);
    }

    // If no field has a ground-truth comparison, hide the ground-truth-related
    // columns.
    const columnNames = ['Field', 'Ground truth', 'Score', 'Error'];
    const columnVisibility = new Map<string, boolean>();
    for (const name of columnNames) {
      const isComparisonColumn = name === 'Ground truth' || name === 'Error';
      columnVisibility.set(name, hasParent || !isComparisonColumn);
    }
    return html`
        <lit-data-table
          .columnVisibility=${columnVisibility}
          .data=${rows} selectionDisabled
        ></lit-data-table>`;
  }

  /**
   * Shows this module only when some model outputs a `RegressionScore`.
   * `datasetSpec` is unused but kept for the shared module signature.
   */
  static shouldDisplayModule(modelSpecs: ModelInfoMap, datasetSpec: Spec) {
    return doesOutputSpecContain(modelSpecs, 'RegressionScore');
  }
}
// Global type augmentation: maps the 'regression-module' custom-element tag
// to the RegressionModule class in TypeScript's DOM typings. This is a
// type-level declaration only and emits no runtime code.
declare global {
  interface HTMLElementTagNameMap {
    'regression-module': RegressionModule;
  }
}