@@ -72,15 +72,43 @@ def run_FLiES_ANN_inference(
72
72
73
73
# Convert DataFrame to numpy array and reshape for the model
74
74
inputs_array = inputs .values
75
- # The model expects 3D input: (batch_size, sequence_length, features)
76
- # Reshape from (batch_size, features) to (batch_size, 1, features)
77
- inputs_array = inputs_array .reshape (inputs_array .shape [0 ], 1 , inputs_array .shape [1 ])
75
+
76
+ # Check what input shape the model expects and adapt accordingly
77
+ # Different TensorFlow/Keras versions may have different input requirements
78
+ try :
79
+ model_input_shape = ANN_model .input_shape
80
+ if len (model_input_shape ) == 3 :
81
+ # Model expects 3D input: (batch_size, sequence_length, features)
82
+ # Reshape from (batch_size, features) to (batch_size, 1, features)
83
+ inputs_array = inputs_array .reshape (inputs_array .shape [0 ], 1 , inputs_array .shape [1 ])
84
+ expects_3d = True
85
+ elif len (model_input_shape ) == 2 :
86
+ # Model expects 2D input: (batch_size, features)
87
+ # Keep the original 2D shape
88
+ expects_3d = False
89
+ else :
90
+ # Fallback: try 2D first
91
+ expects_3d = False
92
+ except (AttributeError , TypeError ):
93
+ # If input_shape is not available, try 2D first
94
+ expects_3d = False
78
95
79
96
# Run inference using the ANN model
80
- outputs = ANN_model .predict (inputs_array )
97
+ try :
98
+ outputs = ANN_model .predict (inputs_array )
99
+ except ValueError as e :
100
+ error_msg = str (e )
101
+ if not expects_3d and ("expected shape" in error_msg or "incompatible" in error_msg ):
102
+ # Try reshaping to 3D if 2D failed
103
+ inputs_array = inputs .values # Reset to original 2D shape
104
+ inputs_array = inputs_array .reshape (inputs_array .shape [0 ], 1 , inputs_array .shape [1 ])
105
+ outputs = ANN_model .predict (inputs_array )
106
+ expects_3d = True
107
+ else :
108
+ raise e
81
109
82
- # The model returns 3D output due to the reshaped input, squeeze out the middle dimension
83
- if len (outputs .shape ) == 3 :
110
+ # Handle output dimensions based on input dimensions used
111
+ if expects_3d and len (outputs .shape ) == 3 :
84
112
outputs = outputs .squeeze (axis = 1 )
85
113
86
114
shape = COT .shape
0 commit comments