Skip to content

Commit 774b282

Browse files
upgrading to python 3.11 and tensorflow 2.19
1 parent 67be9eb commit 774b282

File tree

3 files changed

+41
-42
lines changed

3 files changed

+41
-42
lines changed

.github/workflows/ci.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ jobs:
1313
runs-on: ubuntu-latest
1414
strategy:
1515
matrix:
16-
python-version: ["3.10"]
16+
python-version: ["3.10", "3.11"]
1717

1818
steps:
1919
- name: Checkout repository

FLiESANN/run_FLiES_ANN_inference.py

Lines changed: 34 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -72,15 +72,43 @@ def run_FLiES_ANN_inference(
7272

7373
# Convert DataFrame to numpy array and reshape for the model
7474
inputs_array = inputs.values
75-
# The model expects 3D input: (batch_size, sequence_length, features)
76-
# Reshape from (batch_size, features) to (batch_size, 1, features)
77-
inputs_array = inputs_array.reshape(inputs_array.shape[0], 1, inputs_array.shape[1])
75+
76+
# Check what input shape the model expects and adapt accordingly
77+
# Different TensorFlow/Keras versions may have different input requirements
78+
try:
79+
model_input_shape = ANN_model.input_shape
80+
if len(model_input_shape) == 3:
81+
# Model expects 3D input: (batch_size, sequence_length, features)
82+
# Reshape from (batch_size, features) to (batch_size, 1, features)
83+
inputs_array = inputs_array.reshape(inputs_array.shape[0], 1, inputs_array.shape[1])
84+
expects_3d = True
85+
elif len(model_input_shape) == 2:
86+
# Model expects 2D input: (batch_size, features)
87+
# Keep the original 2D shape
88+
expects_3d = False
89+
else:
90+
# Fallback: try 2D first
91+
expects_3d = False
92+
except (AttributeError, TypeError):
93+
# If input_shape is not available, try 2D first
94+
expects_3d = False
7895

7996
# Run inference using the ANN model
80-
outputs = ANN_model.predict(inputs_array)
97+
try:
98+
outputs = ANN_model.predict(inputs_array)
99+
except ValueError as e:
100+
error_msg = str(e)
101+
if not expects_3d and ("expected shape" in error_msg or "incompatible" in error_msg):
102+
# Try reshaping to 3D if 2D failed
103+
inputs_array = inputs.values # Reset to original 2D shape
104+
inputs_array = inputs_array.reshape(inputs_array.shape[0], 1, inputs_array.shape[1])
105+
outputs = ANN_model.predict(inputs_array)
106+
expects_3d = True
107+
else:
108+
raise e
81109

82-
# The model returns 3D output due to the reshaped input, squeeze out the middle dimension
83-
if len(outputs.shape) == 3:
110+
# Handle output dimensions based on input dimensions used
111+
if expects_3d and len(outputs.shape) == 3:
84112
outputs = outputs.squeeze(axis=1)
85113

86114
shape = COT.shape

Processing FLiES with a raster and default parameters.ipynb

Lines changed: 6 additions & 35 deletions
Large diffs are not rendered by default.

0 commit comments

Comments
 (0)