Skip to content

Commit 51f031f

Browse files
add interactions
1 parent ff69e01 commit 51f031f

File tree

4 files changed

+101
-104
lines changed

4 files changed

+101
-104
lines changed
Lines changed: 27 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -1,57 +1,46 @@
1-
import json
2-
import os
3-
from typing import Dict, List, Union
1+
from typing import Dict, List, Tuple, Union
42

5-
from pandas import Series
6-
from scipy.sparse._csr import csr_matrix
3+
from sklearn.feature_extraction import DictVectorizer
74
from xgboost import Booster
85

9-
from mlops.utils..data_preparation.feature_engineering import combine_features
10-
from mlops.utils..models.xgboost import build_data, load_model
6+
from mlops.utils.data_preparation.feature_engineering import combine_features
7+
from mlops.utils.models.xgboost import build_data
118

129
if 'custom' not in globals():
1310
from mage_ai.data_preparation.decorators import custom
1411

12+
DEFAULT_INPUTS = [
13+
{
14+
# target = "duration": 11.5
15+
'DOLocationID': 239,
16+
'PULocationID': 236,
17+
'trip_distance': 1.98,
18+
},
19+
{
20+
# target = "duration" 20.8666666667
21+
'DOLocationID': '170',
22+
'PULocationID': '65',
23+
'trip_distance': 6.54,
24+
},
25+
]
26+
1527

1628
@custom
1729
def predict(
18-
training_set: Dict[str, List[Union[Series, csr_matrix]]],
19-
model_settings: Dict[str, List[Booster]],
30+
model_settings: Dict[str, Tuple[Booster, DictVectorizer]],
2031
**kwargs,
2132
) -> List[float]:
22-
inputs: List[Dict[str, Union[float, int]]] = kwargs.get(
23-
'inputs',
24-
[
25-
{
26-
# target = "duration": 11.5
27-
'DOLocationID': 239,
28-
'PULocationID': 236,
29-
'trip_distance': 1.98,
30-
},
31-
{
32-
# target = "duration" 20.8666666667
33-
'DOLocationID': '170',
34-
'PULocationID': '65',
35-
'trip_distance': 6.54,
36-
},
37-
],
38-
)
39-
40-
dict_vectorizer = training_set['build'][6]
41-
print(dict_vectorizer)
42-
vectors = dict_vectorizer.transform(combine_features(inputs))
43-
44-
print(model_settings)
45-
model = model_settings.get(
46-
'xgboost',
47-
load_model(kwargs.get('model_dir'), 'model.ubj', 'config.json'),
48-
)
33+
inputs: List[Dict[str, Union[float, int]]] = kwargs.get('inputs', DEFAULT_INPUTS)
34+
inputs = combine_features(inputs)
35+
36+
model, vectorizer = model_settings['xgboost']
37+
vectors = vectorizer.transform(inputs)
4938

5039
predictions = model.predict(build_data(vectors))
5140

52-
for idx, input_feature in enumerate(input_dicts):
41+
for idx, input_feature in enumerate(inputs):
5342
print(f'Prediction of duration using these features: {predictions[idx]}')
54-
for key, value in input_features[idx].items():
43+
for key, value in inputs[idx].items():
5544
print(f'\t{key}: {value}')
5645

5746
return predictions.tolist()
Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
inputs:
2+
number field:
3+
style:
4+
input_type: number
5+
type: text_field
6+
layout:
7+
- - variable: PULocationID
8+
width: 1
9+
- variable: DOLocationID
10+
width: 1
11+
- - variable: trip_distance
12+
width: 1
13+
variables:
14+
DOLocationID:
15+
description: e.g. 239, 170
16+
input: number field
17+
name: DOLocationID
18+
types:
19+
- integer
20+
PULocationID:
21+
description: e.g. 236, 65
22+
input: number field
23+
name: PULocationID
24+
types:
25+
- integer
26+
trip_distance:
27+
description: e.g. 1.98, 6.54
28+
input: number field
29+
name: Trip distance
30+
types:
31+
- float
Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,5 @@
11
blocks:
22
inference:
3-
- description: null
4-
layout: []
5-
name: null
6-
permissions: []
7-
uuid: playground
8-
variables: {}
3+
- uuid: playground.yaml
94
layout: []
5+
permissions: []
Lines changed: 41 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -1,67 +1,47 @@
11
blocks:
2-
- all_upstream_blocks_executed: true
3-
color: null
4-
configuration:
5-
global_data_product:
6-
uuid: training_set
7-
downstream_blocks:
8-
- inference
9-
- model
10-
executor_config: null
11-
executor_type: local_python
12-
has_callback: false
13-
language: python
14-
name: Training data
15-
retry_config: null
16-
status: executed
17-
timeout: null
18-
type: global_data_product
19-
upstream_blocks: []
20-
uuid: training_data
21-
- all_upstream_blocks_executed: true
22-
color: null
23-
configuration:
24-
file_source:
25-
path: unit_4_triggering/global_data_products/model.py
26-
global_data_product:
27-
uuid: xgboost
28-
downstream_blocks:
29-
- inference
30-
executor_config: null
31-
executor_type: local_python
32-
has_callback: false
33-
language: python
34-
name: Model
35-
retry_config: null
36-
status: executed
37-
timeout: null
38-
type: global_data_product
39-
upstream_blocks: []
40-
uuid: model
41-
- all_upstream_blocks_executed: true
42-
color: teal
43-
configuration:
44-
file_source:
45-
path: null
46-
downstream_blocks: []
47-
executor_config: null
48-
executor_type: local_python
49-
has_callback: false
50-
language: python
51-
name: inference
52-
retry_config: null
53-
status: failed
54-
timeout: null
55-
type: custom
56-
upstream_blocks:
57-
- training_data
58-
- model
59-
uuid: inference
2+
- all_upstream_blocks_executed: true
3+
color: null
4+
configuration:
5+
file_source:
6+
path: unit_4_triggering/global_data_products/model.py
7+
global_data_product:
8+
uuid: xgboost
9+
downstream_blocks:
10+
- inference
11+
executor_config: null
12+
executor_type: local_python
13+
has_callback: false
14+
language: python
15+
name: Model
16+
retry_config: null
17+
status: executed
18+
timeout: null
19+
type: global_data_product
20+
upstream_blocks: []
21+
uuid: model
22+
- all_upstream_blocks_executed: true
23+
color: teal
24+
configuration:
25+
file_source:
26+
path: null
27+
downstream_blocks: []
28+
executor_config: null
29+
executor_type: local_python
30+
has_callback: false
31+
language: python
32+
name: inference
33+
retry_config: null
34+
status: executed
35+
timeout: null
36+
type: custom
37+
upstream_blocks:
38+
- model
39+
uuid: inference
6040
cache_block_output_in_memory: false
6141
callbacks: []
6242
concurrency_config: {}
6343
conditionals: []
64-
created_at: "2024-05-09 02:45:15.656239+00:00"
44+
created_at: '2024-05-09 02:45:15.656239+00:00'
6545
data_integration: null
6646
description: Online inference pipeline.
6747
executor_config: {}
@@ -80,5 +60,6 @@ tags: []
8060
type: python
8161
uuid: predict
8262
variables:
83-
model_dir: ""
63+
model_dir: ''
64+
variables_dir: /root/.mage_data/unit_4_triggering
8465
widgets: []

0 commit comments

Comments
 (0)