
AttributeError: module 'jax.random' has no attribute 'KeyArray' #3

@VVh5912

Hi, when I run the inference scripts (both text-to-image and style transfer), I encounter this error:

2024-05-03 06:36:47.216259: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-05-03 06:36:47.216309: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-05-03 06:36:47.217764: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2024-05-03 06:36:48.362862: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT
╭─────────────────────────────── Traceback (most recent call last) ────────────────────────────────╮
│ /content/gdrive/MyDrive/dreamstyler/dreamstyler/inference_style_transfer.py:12 in │
│ │
│ 9 import imageio │
│ 10 import numpy as np │
│ 11 from PIL import Image │
│ ❱ 12 from diffusers import ControlNetModel, UniPCMultistepScheduler │
│ 13 from transformers import CLIPTextModel, CLIPTokenizer │
│ 14 from controlnet_aux.processor import Processor │
│ 15 import custom_pipelines │
│ │
│ /usr/local/lib/python3.10/dist-packages/diffusers/__init__.py:38 in │
│ │
│ 35 except OptionalDependencyNotAvailable: │
│ 36 │ from .utils.dummy_pt_objects import * # noqa F403 │
│ 37 else: │
│ ❱ 38 │ from .models import ( │
│ 39 │ │ AsymmetricAutoencoderKL, │
│ 40 │ │ AutoencoderKL, │
│ 41 │ │ AutoencoderTiny, │
│ │
│ /usr/local/lib/python3.10/dist-packages/diffusers/models/__init__.py:36 in │
│ │
│ 33 │ from .vq_model import VQModel │
│ 34 │
│ 35 if is_flax_available(): │
│ ❱ 36 │ from .controlnet_flax import FlaxControlNetModel │
│ 37 │ from .unet_2d_condition_flax import FlaxUNet2DConditionModel │
│ 38 │ from .vae_flax import FlaxAutoencoderKL │
│ 39 │
│ │
│ /usr/local/lib/python3.10/dist-packages/diffusers/models/controlnet_flax.py:25 in │
│ │
│ 22 from ..configuration_utils import ConfigMixin, flax_register_to_config │
│ 23 from ..utils import BaseOutput │
│ 24 from .embeddings_flax import FlaxTimestepEmbedding, FlaxTimesteps │
│ ❱ 25 from .modeling_flax_utils import FlaxModelMixin │
│ 26 from .unet_2d_blocks_flax import ( │
│ 27 │ FlaxCrossAttnDownBlock2D, │
│ 28 │ FlaxDownBlock2D, │
│ │
│ /usr/local/lib/python3.10/dist-packages/diffusers/models/modeling_flax_utils.py:46 in │
│ │
│ 43 logger = logging.get_logger(__name__) │
│ 44 │
│ 45 │
│ ❱ 46 class FlaxModelMixin(PushToHubMixin): │
│ 47 │ r""" │
│ 48 │ Base class for all Flax models. │
│ 49 │
│ │
│ /usr/local/lib/python3.10/dist-packages/diffusers/models/modeling_flax_utils.py:195 in │
│ FlaxModelMixin │
│ │
│ 192 │ │ ```""" │
│ 193 │ │ return self._cast_floating_to(params, jnp.float16, mask) │
│ 194 │ │
│ ❱ 195 │ def init_weights(self, rng: jax.random.KeyArray) -> Dict: │
│ 196 │ │ raise NotImplementedError(f"init_weights method has to be implemented for {self} │
│ 197 │ │
│ 198 │ @classmethod
│ │
│ /usr/local/lib/python3.10/dist-packages/jax/_src/deprecations.py:54 in getattr │
│ │
│ 51 │ │ raise AttributeError(message) │
│ 52 │ warnings.warn(message, DeprecationWarning, stacklevel=2) │
│ 53 │ return fn │
│ ❱ 54 │ raise AttributeError(f"module {module!r} has no attribute {name!r}") │
│ 55 │
│ 56 return getattr │
│ 57 │
╰──────────────────────────────────────────────────────────────────────────────────────────────────╯
AttributeError: module 'jax.random' has no attribute 'KeyArray'
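
A possible workaround, offered as a sketch rather than a confirmed fix: the traceback fails on the type annotation `rng: jax.random.KeyArray` in diffusers' `modeling_flax_utils.py`, which suggests the installed JAX no longer exposes that name while the installed diffusers still refers to it. As far as I know, newer JAX releases dropped `jax.random.KeyArray` in favour of `jax.Array`, so either upgrading diffusers or pinning an older JAX should also help, but I have not verified exact versions here. The helper below (`ensure_alias` is a hypothetical name, not part of JAX or diffusers) adds a missing attribute to a module without overwriting an existing one. It is demonstrated on a stand-in module so it runs without JAX installed.

```python
import types


def ensure_alias(module, name, fallback):
    """Set module.<name> to `fallback` only if the attribute is missing.

    Returns True if the alias was added, False if the attribute already existed.
    """
    if hasattr(module, name):
        return False
    setattr(module, name, fallback)
    return True


# Stand-in module so the sketch is runnable without JAX.
fake_random = types.ModuleType("fake_random")
added = ensure_alias(fake_random, "KeyArray", object)     # attribute missing, so it is added
added_again = ensure_alias(fake_random, "KeyArray", int)  # already present, left untouched

# Intended use (assumption; must run BEFORE `import diffusers`):
#   import jax
#   ensure_alias(jax.random, "KeyArray", jax.Array)
```

Since the name is only used as an annotation at that point in the traceback, aliasing it to `jax.Array` should be enough to get past the import; it does not change any runtime behaviour of the Flax models.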
