[WIP] Modular Diffusers support custom code/pipeline blocks #11539
base: modular-refactor
Thanks, Dhruv!
I couldn't run the code because there seem to be some problems with undefined variables. I can give the saving logic a look after that is fixed.
@@ -154,15 +157,132 @@ def check_imports(filename):
    return get_relative_imports(filename)


def get_class_in_module(class_name, module_path):
def resolve_trust_remote_code(trust_remote_code, model_name, has_local_code, has_remote_code):
(nit): Would prefer this to be a private method.
    if trust_remote_code is None:
        if has_local_code:
            trust_remote_code = False
        elif has_remote_code and TIME_OUT_REMOTE_CODE > 0:
            prev_sig_handler = None
            try:
                prev_sig_handler = signal.signal(signal.SIGALRM, _raise_timeout_error)
                signal.alarm(TIME_OUT_REMOTE_CODE)
Variables like `signal` seem to be undefined. An after-effect of resolving merge conflicts?
                        f"The repository for {model_name} contains custom code which must be executed to correctly "
                        f"load the model. You can inspect the repository content at https://hf.co/{model_name}.\n"
                        f"You can avoid this prompt in future by passing the argument `trust_remote_code=True`.\n\n"
                        f"Do you wish to run the custom code? [y/N] "
                    )
                    if answer.lower() in ["yes", "y", "1"]:
                        trust_remote_code = True
                    elif answer.lower() in ["no", "n", "0", ""]:
                        trust_remote_code = False
                signal.alarm(0)
If it's just about passing `trust_remote_code=True`, would it be too much to just enforce that and avoid `signal` altogether?
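To make the suggestion concrete, here is a minimal sketch of a non-interactive variant that never prompts and never touches `signal`. The name `resolve_trust_remote_code_strict`, its behavior when no code is present, and the error wording are assumptions for illustration, not code from this PR.

```python
def resolve_trust_remote_code_strict(trust_remote_code, model_name, has_local_code, has_remote_code):
    """Hypothetical variant: callers must opt in explicitly, so no prompt or timeout is needed."""
    remote_only = has_remote_code and not has_local_code

    if trust_remote_code is None:
        if remote_only:
            # Instead of asking on stdin, fail fast and tell the user what to pass.
            raise ValueError(
                f"The repository for {model_name} contains custom code. "
                "Pass `trust_remote_code=True` to allow running it."
            )
        # Assumption: with local code (or no custom code at all) nothing remote needs to be trusted.
        trust_remote_code = False

    if remote_only and not trust_remote_code:
        # An explicit `trust_remote_code=False` cannot load a repository that only has remote code.
        raise ValueError(
            f"Loading {model_name} requires running custom code from the repository. "
            "Pass `trust_remote_code=True` to allow it."
        )

    return trust_remote_code
```

The trade-off is that users who relied on the interactive prompt now get an error, but the function no longer depends on `SIGALRM`, which is not available on Windows.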
@@ -37,6 +39,7 @@

# See https://huggingface.co/datasets/diffusers/community-pipelines-mirror
COMMUNITY_PIPELINES_MIRROR_ID = "diffusers/community-pipelines-mirror"
_HF_REMOTE_CODE_LOCK = threading.Lock()
Why do we have to acquire a lock?
        # Hash the module file and all its relative imports to check if we need to reload it
        module_files: list[Path] = [module_file] + sorted(map(Path, get_relative_import_files(module_file)))
        module_hash: str = hashlib.sha256(b"".join(bytes(f) + f.read_bytes() for f in module_files)).hexdigest()
Hmm, if there are too many imports, could this cause any side effects? I guess since these are custom blocks, users already have some awareness of that. Just wanted to flag it.
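For reference, a small sketch of the same hashing idea that feeds each file into the digest incrementally instead of joining everything into one `bytes` object first. The helper name `hash_module_files` and the demo file names are hypothetical. Each file is still read fully into memory, so this only reduces the peak, it does not bound it.

```python
import hashlib
import tempfile
from pathlib import Path


def hash_module_files(module_file, relative_imports):
    """Hash a module file plus its relative imports (paths and contents)."""
    files = [Path(module_file)] + sorted(Path(p) for p in relative_imports)
    digest = hashlib.sha256()
    for f in files:
        digest.update(bytes(f))        # include the path, so a rename changes the hash
        digest.update(f.read_bytes())  # include the contents, so an edit changes the hash
    return digest.hexdigest()


# Demo: editing an imported helper changes the combined hash.
with tempfile.TemporaryDirectory() as tmp:
    main = Path(tmp) / "block.py"
    helper = Path(tmp) / "helper.py"
    main.write_text("from .helper import f\n")
    helper.write_text("def f():\n    return 1\n")
    before = hash_module_files(main, [helper])

    helper.write_text("def f():\n    return 2\n")
    after = hash_module_files(main, [helper])
```

Feeding the pieces to `update()` one at a time produces the same digest as hashing their concatenation, so this is a drop-in change in terms of the resulting value.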
        # reload in both cases, unless the module is already imported and the hash hits
        if getattr(module, "__transformers_module_hash__", "") != module_hash:
            module_spec.loader.exec_module(module)
            module.__transformers_module_hash__ = module_hash
Why is it only `transformers`?
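A rough sketch of how the reload-on-hash-change logic could look with a library-neutral attribute name. The attribute `__diffusers_module_hash__`, the helper `load_module_if_changed`, and the demo module name are hypothetical. The lock is included on the assumption that it guards against concurrent loads of the same module, which this thread does not confirm.

```python
import hashlib
import importlib.util
import sys
import tempfile
import threading
from pathlib import Path

_LOAD_LOCK = threading.Lock()               # assumed purpose: serialize concurrent loads
_HASH_ATTR = "__diffusers_module_hash__"    # hypothetical attribute name


def load_module_if_changed(name, module_file):
    """Load `module_file` as module `name`, re-executing it only when its hash changes."""
    module_file = Path(module_file)
    module_hash = hashlib.sha256(bytes(module_file) + module_file.read_bytes()).hexdigest()

    with _LOAD_LOCK:
        spec = importlib.util.spec_from_file_location(name, str(module_file))
        module = sys.modules.get(name)
        if module is None:
            module = importlib.util.module_from_spec(spec)
            sys.modules[name] = module
        # Execute on first load, or when the stored hash no longer matches the file.
        if getattr(module, _HASH_ATTR, "") != module_hash:
            spec.loader.exec_module(module)
            setattr(module, _HASH_ATTR, module_hash)
    return module


# Demo: the module is re-executed only after the file content changes.
with tempfile.TemporaryDirectory() as tmp:
    path = Path(tmp) / "custom_block.py"
    path.write_text("VALUE = 1\n")
    first = load_module_if_changed("custom_block_demo", path).VALUE
    unchanged = load_module_if_changed("custom_block_demo", path).VALUE

    path.write_text("VALUE = 'changed'\n")
    second = load_module_if_changed("custom_block_demo", path).VALUE
```

Keeping the attribute name in one constant makes it easy to rename if the maintainers settle on a different convention.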
""" | ||
    Import a module on the cache directory for modules and extract a class from it.
    """
    module_path = module_path.replace(os.path.sep, ".")
    module = importlib.import_module(module_path)
    name = os.path.normpath(module_path)
Are there any common bits shared by `get_class_in_module` and `get_class_in_modular_module` that we could wrap in a method?
What does this PR do?
Add support for loading custom pipeline blocks with Modular Diffusers. The PR is still in very rough shape, but it is functional.
Snippet to test
Note: I think the formatting changes might be due to a difference in ruff versions.
TODOs:
- `get_class_from_dynamic_module`. Probably don't need the `is_modular` argument.
- `get_class_in_modular_module`
- `trust_remote_code=None` is currently broken. Need to fix.

Fixes # (issue)
Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.