Dtype object is not a valid JAX array type. Only arrays of numeric types are supported by JAX. #296
Have you converted your DataFrame or dataset to an array, e.g. with np.array(media_data_train)?
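For context: if a DataFrame has even one non-numeric column (for example, numbers read from a CSV as strings), np.array() on it returns an array of dtype object, and JAX rejects that dtype. Below is a minimal NumPy-only sketch of coercing such data to float32 before scaling. `to_float_array` is a hypothetical helper, not part of any library, and the object array stands in for such a DataFrame.

```python
import numpy as np


def to_float_array(data, dtype=np.float32):
    """Hypothetical helper: return `data` as a numeric array JAX can accept."""
    arr = np.asarray(data)
    # Object dtype usually means mixed or non-numeric entries. astype()
    # converts element by element and raises ValueError on values such
    # as "n/a" that cannot be parsed as numbers.
    return arr.astype(dtype)


# Stand-in for np.array(df) where one column was read as strings.
raw = np.array([[1, 2.5], [3, "4"]], dtype=object)
print(raw.dtype)    # object  (JAX would reject this)
clean = to_float_array(raw)
print(clean.dtype)  # float32
```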
It is an array, and it is float32, so it should already be in good shape. Which dataset threw the error: the target, the media data, or another input?
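One way to answer that question is to check the dtype of every input before any scaling. This is a minimal sketch assuming the inputs are collected in a dict keyed by name; `non_numeric_inputs` is a hypothetical helper name.

```python
import numpy as np


def non_numeric_inputs(inputs):
    """Hypothetical helper: name the inputs whose dtype JAX would reject.

    `inputs` maps a name (e.g. "media_data_train") to an array-like.
    """
    bad = []
    for name, data in inputs.items():
        dtype = np.asarray(data).dtype
        # JAX accepts boolean and numeric dtypes; object and string do not qualify.
        if not (np.issubdtype(dtype, np.number) or dtype == np.bool_):
            bad.append(f"{name} ({dtype})")
    return bad


inputs = {
    "media_data_train": np.ones((4, 3), dtype=np.float32),
    "target_train": np.array([1.0, 2.0, "3.0", 4.0], dtype=object),
    "costs": np.array([10.0, 20.0, 30.0], dtype=np.float32),
}
print(non_numeric_inputs(inputs))  # ['target_train (object)']
```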
Hmm... I know you have already checked, but is there any chance that your target_train contains NaN or null values, zeros, or a different dtype? Would you mind sharing your notebook and data for further investigation? I am not sure whether GitHub has a private-messaging feature.
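The NaN/null and zero checks can also be scripted. This sketch assumes NumPy and a hypothetical helper `find_value_problems`. The all-zero check assumes the scaler divides each column by its mean (as a `divide_operation=jnp.mean` setup might): an all-zero column then has a zero divisor, which produces inf or NaN.

```python
import numpy as np


def find_value_problems(name, data):
    """Hypothetical helper: report NaN/null values and all-zero columns."""
    # Converting to float turns None (null) entries into NaN and raises
    # ValueError for strings that are not numbers.
    arr = np.asarray(data, dtype=np.float64)
    problems = []
    n_nan = int(np.isnan(arr).sum())
    if n_nan:
        problems.append(f"{name}: {n_nan} NaN/null value(s)")
    # Treat a 1-D array as a single column; flatten trailing dimensions.
    cols = arr.reshape(len(arr), -1) if arr.ndim else arr.reshape(1, 1)
    # Assumption: per-column mean scaling, so an all-zero column divides by 0.
    zero_cols = np.flatnonzero(~cols.any(axis=0))
    if zero_cols.size:
        problems.append(f"{name}: all-zero column(s) {zero_cols.tolist()}")
    return problems


target_train = np.array([10.0, None, 12.0], dtype=object)
media = np.array([[1.0, 0.0], [2.0, 0.0]])
print(find_value_problems("target_train", target_train))
print(find_value_problems("media_data_train", media))
```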
When I try to run this:
import jax.numpy as jnp
from lightweight_mmm import preprocessing

# media_data_train, target_train and costs come from the earlier tutorial steps.
media_scaler = preprocessing.CustomScaler(divide_operation=jnp.mean)
target_scaler = preprocessing.CustomScaler(divide_operation=jnp.mean)
cost_scaler = preprocessing.CustomScaler(divide_operation=jnp.mean)
media_data_train = media_scaler.fit_transform(media_data_train)
target_train = target_scaler.fit_transform(target_train)
costs2 = cost_scaler.fit_transform(costs)
That is where I get the error. Up to that point, everything ran as in the tutorial.