-
Notifications
You must be signed in to change notification settings - Fork 110
New Interfaces for LibTuner
#771
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
src/flag_gems/utils/libentry.py
Outdated
lambda idx0, idx1: self.strategy[idx0](args[idx1]), | ||
enumerate(self.keys), | ||
) | ||
) + tuple(str(arg.dtype) for arg in args.values() if hasattr(arg, "dtype")) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
append the data types for both status
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Fix.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
lgtm
key = tuple(args[k] for k in self.keys if k in args) | ||
else: | ||
key = tuple( | ||
starmap( | ||
lambda idx0, idx1: self.strategy[idx0](args[idx1]), | ||
enumerate(self.keys), | ||
) | ||
) | ||
key += tuple(str(arg.dtype) for arg in args.values() if hasattr(arg, "dtype")) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is it OK that self.keys
is not filtered against args
in the else branch?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think so. The previous implementation doesn't check args
and works well. Do you think I should add it here?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Not a big issue assuming self.key
not including kwargs.
src/flag_gems/utils/libentry.py
Outdated
""" | ||
|
||
def decorator( | ||
policy: Callable[ |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
What about renaming it as policy_impl
to distinguish from AnonymousLibTunerImpl
's policy
method. It's bit confusing.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Fix.
src/flag_gems/utils/libentry.py
Outdated
args: Tuple[Any], | ||
kwargs: Dict[str, Any], | ||
) -> Tuple[triton.Config, Dict[str, float]]: | ||
return policy(fn, configs, args, kwargs) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Calling policy_impl
instead.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Fix.
return strategy | ||
|
||
return decorator | ||
|
||
def run(self, *args, **kwargs): | ||
self.nargs = dict(zip(self.arg_names, args)) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Please add a comment for clarification, noting that arg_names correspond to a JITFunction's parameter names in its signature.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Take a look plz. I'm unsure whether it meets your requirements. Thanks!
), f"the length of strategy {len(strategy)} must match the length of keys {len(self.keys)}" | ||
strategy: List[Callable[[Any], Any]] = [ | ||
LibTuner.get_strategy(s) if isinstance(s, str) else s for s in strategy | ||
] | ||
self.strategy = strategy | ||
# Use table name with hash instead of hash in key | ||
self.kernel_hash = None | ||
self.table_name = f"{self.__name__}_{self.get_kernel_hash()}" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
To deal with re-entrance is similar to what jit's _do_compile does in the async_mode, where the cache_key is voluntarily calculated with get_cache_key
. I think it's more appropriate than just reusing cache key?
cache_key = get_cache_key(src, backend, options, env_vars)
PR Category
libtuner
Type of Change
New Feature
Description
This is a proposal to provide a new interface to extend the current
LibTuner
. It redefinedLibTuner
as a base class and allows extending it by overriding thepolicy
method. For easily usage, we also provide a subclass calledOfflineLibTuner
so our current implementation, including the default one and new strategies, could be added to the newLibTuner
system bymake
static method. For the future extension including online autotuner, it could be integrated by providing new subclasses. Totally our design principle is to allow different granularity to control the autotuner details.Except that, we also provide a dispatch mechanism so new policies and
LibTuner
s could be registered easily into the global dispatcher easily and use them with their name.If this work is approved, future work includes migrate #755, #763, and implement online autotuner #762 on it.
Issue
This is a competitive proposal against #763. Please have a detailed discussion on which design we should work on, or anything we can make better.
Progress
Performance