Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Nan handling #198

Merged
merged 2 commits into from
Dec 20, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
65 changes: 63 additions & 2 deletions src/timesfm/timesfm_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,15 +51,65 @@ def freq_map(freq: str):
"""Returns the frequency map for the given frequency string."""
freq = str.upper(freq)
if (freq.endswith("H") or freq.endswith("T") or freq.endswith("MIN") or
freq.endswith("D") or freq.endswith("B") or freq.endswith("U")):
freq.endswith("D") or freq.endswith("B") or freq.endswith("U") or
freq.endswith("S")):
return 0
elif freq.endswith(("W", "M", "MS")):
return 1
elif freq.endswith("Y") or freq.endswith("Q"):
elif freq.endswith("Y") or freq.endswith("Q") or freq.endswith("A"):
return 2
else:
raise ValueError(f"Invalid frequency: {freq}")

def strip_leading_nans(arr):
"""
Removes contiguous NaN values from the beginning of a NumPy array.

Args:
arr: The input NumPy array.

Returns:
A new NumPy array with leading NaN values removed.
If the array is all NaNs or empty, returns an empty array.
"""

isnan = np.isnan(arr)
first_valid_index = np.argmax(~isnan)
return arr[first_valid_index:]

def linear_interpolation(arr):
"""
Performs linear interpolation to fill NaN values in a 1D numpy array.

Args:
arr: The 1D numpy array containing NaN values.

Returns:
A new numpy array with NaN values filled using linear interpolation,
or the original array if no NaNs are present.
Returns None if the input is not a 1D array.
Returns the original array if there are no NaN values.
"""

nans = np.isnan(arr)
if not np.any(nans): # Check if there are any NaNs
return arr

x = lambda z: z.nonzero()[0]
nans_indices = x(nans)
non_nans_indices = x(~nans)
non_nans_values = arr[~nans]

try:
arr[nans] = np.interp(nans_indices, non_nans_indices, non_nans_values)
except ValueError:
if len(non_nans_values) > 0:
mu = np.nanmean(arr)
else:
mu = 0.0
arr = np.where(np.isfinite(arr), arr, mu)
return arr


# Per time series normalization: forward.
def _normalize(batch):
Expand Down Expand Up @@ -313,6 +363,17 @@ def forecast(
ValueError: If the checkpoint is not properly loaded.
"""
stats = None

tmp_inputs = []
for each_input in inputs:
arr = np.array(each_input)
if not np.isfinite(arr).all():
arr = np.where(np.isfinite(arr), arr, np.nan)
arr = strip_leading_nans(arr)
arr = linear_interpolation(arr)
tmp_inputs.append(arr)

inputs = tmp_inputs
if normalize:
inputs, stats = _normalize(inputs)
mean_forecast, quantile_forecast = self._forecast(
Expand Down
Loading