# *****************************************************************************
# © Copyright IBM Corp. 2024 All Rights Reserved.
#
# This program and the accompanying materials
# are made available under the terms of the Apache V2.0 license
# which accompanies this distribution, and is available at
# http://www.apache.org/licenses/LICENSE-2.0
#
# *****************************************************************************

| 11 | +""" |
| 12 | +The experimental functions module contains (no surprise here) experimental functions |
| 13 | +""" |
| 14 | + |
import datetime as dt
import logging
import re
import time
import warnings
import ast
import os
import subprocess
import importlib
from collections import OrderedDict

import numpy as np
import scipy as sp
import pandas as pd
from sqlalchemy import String

import torch

from .base import (BaseTransformer, BaseEvent, BaseSCDLookup, BaseSCDLookupWithDefault, BaseMetadataProvider,
                   BasePreload, BaseDatabaseLookup, BaseDataSource, BaseDBActivityMerge, BaseSimpleAggregator,
                   DataExpanderTransformer)
from .bif import (InvokeWMLModel)
from .loader import _generate_metadata
from .ui import (UISingle, UIMultiItem, UIFunctionOutSingle, UISingleItem, UIFunctionOutMulti, UIMulti, UIExpression,
                 UIText, UIParameters)
from .util import adjust_probabilities, reset_df_index, asList, UNIQUE_EXTENSION_LABEL

logger = logging.getLogger(__name__)
PACKAGE_URL = 'git+https://github.com/ibm-watson-iot/functions.git@'

# Do away with numba and onnxscript logs
numba_logger = logging.getLogger('numba')
numba_logger.setLevel(logging.INFO)
onnx_logger = logging.getLogger('onnxscript')
onnx_logger.setLevel(logging.ERROR)


def install_and_activate_granite_tsfm():
    """Install the granite-tsfm package via pip and verify that it can be imported."""
    # alternative: install via the database object
    # db.install_package("git+https://github.com/sedgewickmm18/granite-tsfm")
    # db.import_target("tsfm_public.models", "tinytimemixer", "TinyTimeMixerForPrediction")

    url = "git+https://github.com/sedgewickmm18/granite-tsfm"
    try:
        completed_process = subprocess.run(['pip', 'install', '-U', '--break-system-packages', url],
                                           stderr=subprocess.STDOUT, stdout=subprocess.PIPE,
                                           universal_newlines=True)
    except Exception as e:
        logger.error('pip install for url %s failed: %s', url, str(e))
        return False

    if completed_process.returncode == 0:
        importlib.invalidate_caches()
        logger.debug('pip install for url %s was successful: \n%s', url, completed_process.stdout)
    else:
        logger.error('pip install for url %s failed: \n%s', url, completed_process.stdout)
        return False

    # make sure the freshly installed package is actually importable
    try:
        from tsfm_public.models.tinytimemixer import TinyTimeMixerForPrediction  # noqa: F401
    except Exception as e:
        logger.error('Could not import TinyTimeMixerForPrediction: %s', str(e))
        return False

    return True


class TSFMZeroShotScorer(InvokeWMLModel):
    """
    Invoke a time series foundation model (granite-tsfm TinyTimeMixer) for zero-shot forecasting.
    """
    def __init__(self, input_items, output_items=None, context=512, horizon=96, watsonx_auth=None):
        logger.debug(str(input_items) + ', ' + str(output_items))

        super().__init__(input_items, watsonx_auth, output_items)

        self.context = context
        self.horizon = horizon
        self.whoami = 'TSFMZeroShot'

        # allow for expansion of the dataframe
        self.allowed_to_expand = True

        self.init_local_model = install_and_activate_granite_tsfm()
        self.model = None  # cache model for multiple calls

    # ask for more data if we do not have enough data for context and horizon
    def check_size(self, size_df):
        return min(size_df) < self.context + self.horizon
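
    # Illustrative example (made-up numbers): with the defaults context=512 and horizon=96,
    # check_size() returns True for any entity with fewer than 608 rows, i.e. not enough history
    # for one context window plus one forecast horizon; the pipeline is then expected to fetch
    # more data, which is why allowed_to_expand is set above.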

    # TODO implement local model lookup and initialization later
    # initialize local model is a NoOp for the superclass
    def initialize_local_model(self):
        logger.info('initialize local model')
        try:
            from tsfm_public.models.tinytimemixer import TinyTimeMixerForPrediction
            TTM_MODEL_REVISION = "main"
            # load the pretrained TinyTimeMixer forecasting model
            self.model = TinyTimeMixerForPrediction.from_pretrained("ibm/TTM", revision=TTM_MODEL_REVISION)
        except Exception as e:
            logger.error("Failed to load local model with error " + str(e))
            return False
        logger.info('local model ready')
        return True

    # inference on local model
    def call_local_model(self, df):
        logger.info('call local model')

        logger.debug('df columns ' + str(df.columns))
        logger.debug('df index ' + str(df.index.names))

        if self.model is not None:
            logger.debug('Forecast ' + str((df.shape[0] - self.context) // self.horizon) + ' times')
            # slide over the dataframe: each window of `horizon` rows starting at row i is
            # predicted from the preceding `context` rows i-context:i
            for i in range(self.context, df.shape[0], self.horizon):
                inputtensor_ = torch.from_numpy(df[i-self.context:i][self.input_items].values)
                # add a batch dimension; the tensor holds only the historic context
                inputtensor = inputtensor_[None, :, :]
                # the model returns the forecast for the next `horizon` steps
                outputtensor = self.model(inputtensor)['prediction_outputs']
                # write the forecast back into the output columns of the dataframe
                try:
                    df.loc[df[i:i + self.horizon].index, self.output_items] = outputtensor[0].detach().numpy()
                except Exception:
                    # typically the last, partial window at the end of the dataframe
                    logger.debug('Issue with ' + str(i) + ':' + str(i + self.horizon))

        return df

    @classmethod
    def build_ui(cls):

        # define arguments that behave as function inputs
        inputs = []

        inputs.append(UIMultiItem(name='input_items', datatype=float, required=True, output_item='output_items',
                                  is_output_datatype_derived=True))
        inputs.append(
            UISingle(name='context', datatype=int, required=False, description='Context - past data'))
        inputs.append(
            UISingle(name='horizon', datatype=int, required=False, description='Forecasting horizon'))
        inputs.append(UISingle(name='watsonx_auth', datatype=str,
                               description='Endpoint to the WatsonX service where the model is hosted',
                               tags=['TEXT'], required=True))

        # define arguments that behave as function outputs; the output items are derived
        # from the input items above, so no separate output arguments are declared
        outputs = []
        return inputs, outputs
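
# ---------------------------------------------------------------------------
# Minimal usage sketch, kept off the import path behind a __main__ guard.
# It exercises only the local TinyTimeMixer path of TSFMZeroShotScorer; the
# entity name, column names and sizes below are made-up illustrations. It
# assumes that InvokeWMLModel stores input_items/output_items as attributes
# (as call_local_model expects) and that watsonx_auth may be left as None when
# only local scoring is used; note that constructing the scorer triggers the
# pip install performed by install_and_activate_granite_tsfm() above.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    logging.basicConfig(level=logging.INFO)

    scorer = TSFMZeroShotScorer(input_items=['temperature'], output_items=['temperature_forecast'],
                                context=512, horizon=96, watsonx_auth=None)
    if scorer.initialize_local_model():
        # one synthetic entity with exactly one context window plus one forecast horizon (512 + 96 rows)
        index = pd.MultiIndex.from_product(
            [['pump_1'], pd.date_range('2024-01-01', periods=608, freq='min')], names=['id', 'timestamp'])
        demo_df = pd.DataFrame({'temperature': np.sin(np.linspace(0, 60, 608)).astype(np.float32)}, index=index)
        demo_df['temperature_forecast'] = 0.0
        demo_df = scorer.call_local_model(demo_df)
        logger.info('forecast tail:\n%s', demo_df['temperature_forecast'].tail())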