Skip to content

Commit e360b36

Browse files
committed
TSFM
1 parent 78842a8 commit e360b36

File tree

3 files changed

+185
-56
lines changed

3 files changed

+185
-56
lines changed

iotfunctions/bif.py

+8-55
Original file line numberDiff line numberDiff line change
@@ -3231,6 +3231,11 @@ def login(self):
32313231
if self.logged_on:
32323232
return
32333233

3234+
self.logged_on = self.initialize_local_model()
3235+
if self.logged_on:
3236+
logger.info('using local model')
3237+
return
3238+
32343239
# retrieve WML credentials as constant
32353240
# {"apikey": api_key, "url": 'https://' + location + '.ml.cloud.ibm.com'}
32363241
c = None
@@ -3247,7 +3252,7 @@ def login(self):
32473252
elif self.wml_auth is not None:
32483253
# check if exists, but empty
32493254
if not self.wml_auth:
3250-
self.init_local_model = init_local_model(self)
3255+
self.logged_on = self.initialize_local_model()
32513256
try:
32523257
c = self._entity_type.get_attributes_dict()
32533258
except Exception:
@@ -3365,7 +3370,7 @@ def execute(self, df):
33653370

33663371
def _calc(self, df):
33673372

3368-
#entity = df.index[0][0]
3373+
entity = df.index[0][0]
33693374

33703375
# get rid of entity id as part of the index
33713376
#df = df.droplevel(0)
@@ -3374,7 +3379,7 @@ def _calc(self, df):
33743379
# do inference with the local model
33753380
if self.init_local_model:
33763381
logging.info("Calling local model")
3377-
return call_local_model(df)
3382+
return self.call_local_model(df)
33783383

33793384

33803385
if len(self.input_items) >= 1:
@@ -3439,58 +3444,6 @@ def build_ui(cls):
34393444
return (inputs, outputs)
34403445

34413446

3442-
class TSFMZeroShotScorer(InvokeWMLModel):
3443-
"""
3444-
Call time series foundation model
3445-
"""
3446-
def __init__(self, input_items, output_items=None, context=512, horizon=96, watsonx_auth=None):
3447-
logger.debug(str(input_items) + ', ' + str(output_items))
3448-
3449-
super().__init__(input_items, watsonx_auth, output_items)
3450-
3451-
self.context = context
3452-
self.horizon = horizon
3453-
self.whoami = 'TSFMZeroShot'
3454-
3455-
# allow for expansion of the dataframe
3456-
self.allowed_to_expand = True
3457-
3458-
# ask for more data if we do not have enough data for context and horizon
3459-
def check_size(self, size_df):
3460-
return min(size_df) < self.context + self.horizon
3461-
3462-
# TODO implement local model lookup and initialization later
3463-
# initialize local model is a NoOp for superclass
3464-
def initialize_local_model(self):
3465-
return False
3466-
3467-
# inference on local model is a NoOp for superclass
3468-
def call_local_model(self, df):
3469-
return df
3470-
3471-
3472-
@classmethod
3473-
def build_ui(cls):
3474-
3475-
# define arguments that behave as function inputs
3476-
inputs = []
3477-
3478-
inputs.append(UIMultiItem(name='input_items', datatype=float, required=True, output_item='output_items',
3479-
is_output_datatype_derived=True))
3480-
inputs.append(
3481-
UISingle(name='context', datatype=int, required=False, description='Context - past data'))
3482-
inputs.append(
3483-
UISingle(name='horizon', datatype=int, required=False, description='Forecasting horizon'))
3484-
inputs.append(UISingle(name='watsonx_auth', datatype=str,
3485-
description='Endpoint to the WatsonX service where model is hosted', tags=['TEXT'], required=True))
3486-
3487-
# define arguments that behave as function outputs
3488-
outputs=[]
3489-
#outputs.append(UISingle(name='output_items', datatype=float))
3490-
return inputs, outputs
3491-
3492-
3493-
34943447
class InvokeWMLClassifier(InvokeWMLModel):
34953448
'''
34963449
Pass multivariate data in input_items to a classification function deployed to

iotfunctions/db.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -1273,8 +1273,9 @@ def install_package(self, url):
12731273
logger.warning(('Request to install package %s was ignored. This package'
12741274
' is pre-installed.'), url)
12751275
else:
1276+
# --process-dependency-links is gone with pip19
12761277
try:
1277-
completedProcess = subprocess.run(['pip', 'install', '--process-dependency-links', '--upgrade', url],
1278+
completedProcess = subprocess.run(['pip', 'install', '-U', '--break-system-packages', url],
12781279
stderr=subprocess.STDOUT, stdout=subprocess.PIPE,
12791280
universal_newlines=True)
12801281
except Exception as e:

iotfunctions/experimental.py

+175
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,175 @@
1+
# *****************************************************************************
2+
# © Copyright IBM Corp. 2024 All Rights Reserved.
3+
#
4+
# This program and the accompanying materials
5+
# are made available under the terms of the Apache V2.0 license
6+
# which accompanies this distribution, and is available at
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# *****************************************************************************
10+
11+
"""
12+
The experimental functions module contains (no surprise here) experimental functions
13+
"""
14+
15+
import datetime as dt
16+
import logging
17+
import re
18+
import time
19+
import warnings
20+
import ast
21+
import os
22+
import subprocess
23+
import importlib
24+
from collections import OrderedDict
25+
26+
import numpy as np
27+
import scipy as sp
28+
import pandas as pd
29+
from sqlalchemy import String
30+
31+
import torch
32+
33+
from .base import (BaseTransformer, BaseEvent, BaseSCDLookup, BaseSCDLookupWithDefault, BaseMetadataProvider,
34+
BasePreload, BaseDatabaseLookup, BaseDataSource, BaseDBActivityMerge, BaseSimpleAggregator,
35+
DataExpanderTransformer)
36+
from .bif import (InvokeWMLModel)
37+
from .loader import _generate_metadata
38+
from .ui import (UISingle, UIMultiItem, UIFunctionOutSingle, UISingleItem, UIFunctionOutMulti, UIMulti, UIExpression,
39+
UIText, UIParameters)
40+
from .util import adjust_probabilities, reset_df_index, asList, UNIQUE_EXTENSION_LABEL
41+
42+
logger = logging.getLogger(__name__)
43+
PACKAGE_URL = 'git+https://github.com/ibm-watson-iot/functions.git@'
44+
45+
# Do away with numba and onnxscript logs
46+
numba_logger = logging.getLogger('numba')
47+
numba_logger.setLevel(logging.INFO)
48+
onnx_logger = logging.getLogger('onnxscript')
49+
onnx_logger.setLevel(logging.ERROR)
50+
51+
52+
def install_and_activate_granite_tsfm():
    """Install the granite-tsfm package with pip and verify it is importable.

    Returns:
        bool: True when the package installed cleanly and
        ``TinyTimeMixerForPrediction`` can be imported, False otherwise.

    Best-effort by design: every failure is logged and reported as False
    (never raised) so the caller can fall back to a remote WML model.
    """
    url = "git+https://github.com/sedgewickmm18/granite-tsfm"

    try:
        # shell=False list form; '--break-system-packages' is required on
        # PEP 668 "externally managed" environments (e.g. Debian/Ubuntu
        # system Python) where plain 'pip install' refuses to run.
        completed_process = subprocess.run(['pip', 'install', '-U', '--break-system-packages', url],
                                           stderr=subprocess.STDOUT, stdout=subprocess.PIPE,
                                           universal_newlines=True)
    except Exception:
        logger.error('Could not download from ' + url)
        return False

    if completed_process.returncode != 0:
        # NOTE(review): the original raised ImportError here, which made the
        # subsequent logging and 'return False' unreachable and contradicted
        # this function's boolean contract; log and return False instead.
        logger.error('pip install for url %s failed: \n %s', url, completed_process.stdout)
        return False

    # Make the freshly installed package visible to this interpreter session.
    importlib.invalidate_caches()
    logger.debug('pip install for url %s was successful: \n %s', url, completed_process.stdout)

    try:
        # Plain import instead of exec(): we only need to prove importability,
        # and exec() would bind the name in a throwaway namespace anyway.
        from tsfm_public.models.tinytimemixer import TinyTimeMixerForPrediction  # noqa: F401
    except Exception:
        logger.error('Could not import TinyTimeMixers')
        return False

    return True
82+
83+
class TSFMZeroShotScorer(InvokeWMLModel):
    """
    Call a time series foundation model (IBM TinyTimeMixer, "TTM") for
    zero-shot forecasting.

    The scorer slides over the input dataframe: each window of ``context``
    past rows is fed to the model, which returns the next ``horizon`` rows
    of forecasts; the forecasts are written into ``output_items``.
    Falls back to the WML remote path of the superclass when the local
    model cannot be set up.
    """

    def __init__(self, input_items, output_items=None, context=512, horizon=96, watsonx_auth=None):
        """
        Parameters
        ----------
        input_items : list
            Input columns to forecast.
        output_items : list, optional
            Output columns receiving the forecasts (derived from inputs).
        context : int
            Number of past rows the model consumes per forecast.
        horizon : int
            Number of future rows the model predicts per forecast.
        watsonx_auth : str, optional
            Endpoint of the hosted WatsonX service (remote fallback).
        """
        logger.debug(str(input_items) + ', ' + str(output_items))

        super().__init__(input_items, watsonx_auth, output_items)

        self.context = context
        self.horizon = horizon
        self.whoami = 'TSFMZeroShot'

        # allow for expansion of the dataframe
        self.allowed_to_expand = True

        # try to install the granite-tsfm package up front; the boolean
        # records whether the local-model path is available at all
        self.init_local_model = install_and_activate_granite_tsfm()
        self.model = None  # cache model for multiple calls

    def check_size(self, size_df):
        """Ask for more data if there is not enough for context + horizon."""
        return min(size_df) < self.context + self.horizon

    def initialize_local_model(self):
        """Load the pretrained TTM model from the Hugging Face hub.

        Returns True when the model is ready, False on any failure
        (missing package, no network, download error).
        """
        logger.info('initialize local model')
        try:
            from tsfm_public.models.tinytimemixer import TinyTimeMixerForPrediction
            TTM_MODEL_REVISION = "main"
            self.model = TinyTimeMixerForPrediction.from_pretrained("ibm/TTM", revision=TTM_MODEL_REVISION)
        except Exception as e:
            logger.error("Failed to load local model with error " + str(e))
            return False
        logger.info('local model ready')
        return True

    def call_local_model(self, df):
        """Run sliding-window inference with the cached local model.

        Writes forecasts into ``self.output_items`` in place and returns the
        dataframe. A no-op (returns ``df`` unchanged) when no model is loaded.
        """
        logger.info('call local model')

        logger.debug('df columns ' + str(df.columns))
        logger.debug('df index ' + str(df.index.names))

        if self.model is None:
            return df

        logger.debug('Forecast ' + str(df.shape[0] / self.horizon) + ' times')

        # inference only - disable autograd bookkeeping to save time and memory
        with torch.no_grad():
            # step forward in horizon-sized strides; each step consumes the
            # preceding 'context' rows and predicts the next 'horizon' rows
            for i in range(self.context, df.shape[0], self.horizon):
                input_tensor_ = torch.from_numpy(df[i - self.context:i][self.input_items].values)
                # add a leading batch dimension: (1, context, n_features)
                input_tensor = input_tensor_[None, :, :]
                # model output carries the forecasting horizon
                output_tensor = self.model(input_tensor)['prediction_outputs']
                try:
                    # write the horizon forecast back into the output columns
                    df.loc[df[i:i + self.horizon].index, self.output_items] = output_tensor[0].detach().numpy()
                except Exception as e:
                    # typically the trailing window is shorter than 'horizon';
                    # log (was a silent bare except) and continue
                    logger.debug('Issue with ' + str(i) + ':' + str(i + self.horizon) + ' - ' + str(e))

        return df

    @classmethod
    def build_ui(cls):
        """Describe function inputs/outputs for the UI registration."""

        # define arguments that behave as function inputs
        inputs = []

        inputs.append(UIMultiItem(name='input_items', datatype=float, required=True, output_item='output_items',
                                  is_output_datatype_derived=True))
        inputs.append(
            UISingle(name='context', datatype=int, required=False, description='Context - past data'))
        inputs.append(
            UISingle(name='horizon', datatype=int, required=False, description='Forecasting horizon'))
        inputs.append(UISingle(name='watsonx_auth', datatype=str,
                               description='Endpoint to the WatsonX service where model is hosted', tags=['TEXT'], required=True))

        # define arguments that behave as function outputs
        outputs = []
        return inputs, outputs
175+

0 commit comments

Comments
 (0)