-
Notifications
You must be signed in to change notification settings - Fork 1.2k
Open
Description
I am trying to quantize open_clip's pre-trained model and then run a zero-shot classification test with clip_benchmark, but I get an error:
AttributeError: module 'triton.language' has no attribute 'libdevice'
Here's my code:
import open_clip
import bitsandbytes as bnb
def quantize_clip(
    model_name: str = "ViT-B-32-quickgelu",
    pretrained: str = "laion400m_e32",
    cache_dir: str = None,
    device="cpu",
    **kwargs,
):
    """Create an open_clip model and swap selected linear layers for SwitchBackLinear.

    The model is created with ``open_clip.create_model_and_transforms`` and
    passed to ``open_clip.utils.replace_linear`` together with
    ``bnb.nn.triton_based_modules.SwitchBackLinear`` and
    ``include_modules=['c_fc', 'c_proj']``.

    Args:
        model_name: Architecture name; also used to look up the tokenizer.
        pretrained: Forwarded as ``pretrained=`` to model creation.
        cache_dir: Forwarded as ``cache_dir=`` to model creation.
        device: Device the converted model is moved to last.
        **kwargs: Accepted but not used by this function.

    Returns:
        A ``(model, transform, tokenizer)`` tuple, where ``transform`` is the
        third value returned by ``create_model_and_transforms``.
    """
    print("====quantize clip====")

    base_model, _unused, transform = open_clip.create_model_and_transforms(
        model_name,
        pretrained=pretrained,
        cache_dir=cache_dir,
    )

    switchback_cls = bnb.nn.triton_based_modules.SwitchBackLinear
    converted = open_clip.utils.replace_linear(
        base_model,
        switchback_cls,
        include_modules=['c_fc', 'c_proj'],
    )

    # The model is always sent to CUDA first and only then to `device`.
    # NOTE(review): with device="cpu" the model ends up on the CPU, where the
    # triton-based layers presumably cannot run -- confirm against callers.
    int8_model = converted.cuda()
    int8_model = int8_model.to(device)

    tokenizer = open_clip.get_tokenizer(model_name)

    print("====return quantize clip====")
    return int8_model, transform, tokenizer
This code is placed in /clip_benchmark/models/quantize_model.py
My run command is as follows:
clip_benchmark eval --dataset=cifar10 --task=zeroshot_classification --pretrained=/models/clip_model/DFN5B-CLIP-ViT-H-14-378/open_clip_pytorch_model.bin --model=ViT-H-14-378-quickgelu --output=quantize.json --batch_size=64 --/CLIP_benchmark/dataset --model_type=quantize
Detailed error records are as follows:
Traceback (most recent call last):
File "/data/home/houyazhou/miniconda3/envs/clip_benchmark/bin/clip_benchmark", line 8, in <module>
sys.exit(main())
File "/data/home/houyazhou/miniconda3/envs/clip_benchmark/lib/python3.8/site-packages/clip_benchmark/cli.py", line 87, in main
main_eval(base)
File "/data/home/houyazhou/miniconda3/envs/clip_benchmark/lib/python3.8/site-packages/clip_benchmark/cli.py", line 193, in main_eval
run(args)
File "/data/home/houyazhou/miniconda3/envs/clip_benchmark/lib/python3.8/site-packages/clip_benchmark/cli.py", line 313, in run
metrics = zeroshot_classification.evaluate(
File "/data/home/houyazhou/miniconda3/envs/clip_benchmark/lib/python3.8/site-packages/clip_benchmark/metrics/zeroshot_classification.py", line 204, in evaluate
classifier = zero_shot_classifier(model, tokenizer, classnames, templates, device, amp=amp)
File "/data/home/houyazhou/miniconda3/envs/clip_benchmark/lib/python3.8/site-packages/clip_benchmark/metrics/zeroshot_classification.py", line 52, in zero_shot_classifier
class_embeddings = model.encode_text(texts)
File "/data/home/houyazhou/miniconda3/envs/clip_benchmark/lib/python3.8/site-packages/open_clip/model.py", line 279, in encode_text
x = self.transformer(x, attn_mask=self.attn_mask)
File "/data/home/houyazhou/miniconda3/envs/clip_benchmark/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/data/home/houyazhou/miniconda3/envs/clip_benchmark/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
return forward_call(*args, **kwargs)
File "/data/home/houyazhou/miniconda3/envs/clip_benchmark/lib/python3.8/site-packages/open_clip/transformer.py", line 364, in forward
x = r(x, attn_mask=attn_mask)
File "/data/home/houyazhou/miniconda3/envs/clip_benchmark/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/data/home/houyazhou/miniconda3/envs/clip_benchmark/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
return forward_call(*args, **kwargs)
File "/data/home/houyazhou/miniconda3/envs/clip_benchmark/lib/python3.8/site-packages/open_clip/transformer.py", line 264, in forward
x = x + self.ls_2(self.mlp(self.ln_2(x)))
File "/data/home/houyazhou/miniconda3/envs/clip_benchmark/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/data/home/houyazhou/miniconda3/envs/clip_benchmark/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
return forward_call(*args, **kwargs)
File "/data/home/houyazhou/miniconda3/envs/clip_benchmark/lib/python3.8/site-packages/torch/nn/modules/container.py", line 219, in forward
input = module(input)
File "/data/home/houyazhou/miniconda3/envs/clip_benchmark/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/data/home/houyazhou/miniconda3/envs/clip_benchmark/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
return forward_call(*args, **kwargs)
File "/data/home/houyazhou/miniconda3/envs/clip_benchmark/lib/python3.8/site-packages/bitsandbytes/nn/triton_based_modules.py", line 209, in forward
return self._fn.apply(x, self.weight, self.bias)
File "/data/home/houyazhou/miniconda3/envs/clip_benchmark/lib/python3.8/site-packages/torch/autograd/function.py", line 574, in apply
return super().apply(*args, **kwargs) # type: ignore[misc]
File "/data/home/houyazhou/miniconda3/envs/clip_benchmark/lib/python3.8/site-packages/bitsandbytes/nn/triton_based_modules.py", line 31, in forward
X_int8, state_X = quantize_rowwise(X)
File "/data/home/houyazhou/miniconda3/envs/clip_benchmark/lib/python3.8/site-packages/bitsandbytes/triton/quantize_rowwise.py", line 66, in quantize_rowwise
_quantize_rowwise[grid](x, output, output_maxs, n_elements, BLOCK_SIZE=x.shape[1], P2=P2)
File "/data/home/houyazhou/miniconda3/envs/clip_benchmark/lib/python3.8/site-packages/triton/runtime/jit.py", line 345, in <lambda>
return lambda *args, **kwargs: self.run(grid=grid, warmup=False, *args, **kwargs)
File "/data/home/houyazhou/miniconda3/envs/clip_benchmark/lib/python3.8/site-packages/triton/runtime/autotuner.py", line 156, in run
timings = {config: self._bench(*args, config=config, **kwargs) for config in pruned_configs}
File "/data/home/houyazhou/miniconda3/envs/clip_benchmark/lib/python3.8/site-packages/triton/runtime/autotuner.py", line 156, in <dictcomp>
timings = {config: self._bench(*args, config=config, **kwargs) for config in pruned_configs}
File "/data/home/houyazhou/miniconda3/envs/clip_benchmark/lib/python3.8/site-packages/triton/runtime/autotuner.py", line 133, in _bench
return do_bench(kernel_call, warmup=self.num_warmups, rep=self.num_reps, quantiles=(0.5, 0.2, 0.8))
File "/data/home/houyazhou/miniconda3/envs/clip_benchmark/lib/python3.8/site-packages/triton/testing.py", line 103, in do_bench
fn()
File "/data/home/houyazhou/miniconda3/envs/clip_benchmark/lib/python3.8/site-packages/triton/runtime/autotuner.py", line 114, in kernel_call
self.fn.run(
File "/data/home/houyazhou/miniconda3/envs/clip_benchmark/lib/python3.8/site-packages/triton/runtime/jit.py", line 662, in run
kernel = self.compile(
File "/data/home/houyazhou/miniconda3/envs/clip_benchmark/lib/python3.8/site-packages/triton/compiler/compiler.py", line 240, in compile
key = f"{triton_key()}-{src.hash()}-{backend.hash()}-{options.hash()}-{str(sorted(env_vars.items()))}"
File "/data/home/houyazhou/miniconda3/envs/clip_benchmark/lib/python3.8/site-packages/triton/compiler/compiler.py", line 109, in hash
key = f"{self.fn.cache_key}-{self.attrs.hash()}-{sorted_sig}-{sorted_constants}"
File "/data/home/houyazhou/miniconda3/envs/clip_benchmark/lib/python3.8/site-packages/triton/runtime/jit.py", line 758, in cache_key
dependencies_finder.visit(self.parse())
File "/data/home/houyazhou/miniconda3/envs/clip_benchmark/lib/python3.8/ast.py", line 371, in visit
return visitor(node)
File "/data/home/houyazhou/miniconda3/envs/clip_benchmark/lib/python3.8/ast.py", line 379, in generic_visit
self.visit(item)
File "/data/home/houyazhou/miniconda3/envs/clip_benchmark/lib/python3.8/ast.py", line 371, in visit
return visitor(node)
File "/data/home/houyazhou/miniconda3/envs/clip_benchmark/lib/python3.8/site-packages/triton/runtime/jit.py", line 167, in visit_FunctionDef
self.generic_visit(node)
File "/data/home/houyazhou/miniconda3/envs/clip_benchmark/lib/python3.8/ast.py", line 379, in generic_visit
self.visit(item)
File "/data/home/houyazhou/miniconda3/envs/clip_benchmark/lib/python3.8/ast.py", line 371, in visit
return visitor(node)
File "/data/home/houyazhou/miniconda3/envs/clip_benchmark/lib/python3.8/site-packages/triton/runtime/jit.py", line 221, in visit_Assign
self.generic_visit(node)
File "/data/home/houyazhou/miniconda3/envs/clip_benchmark/lib/python3.8/ast.py", line 381, in generic_visit
self.visit(value)
File "/data/home/houyazhou/miniconda3/envs/clip_benchmark/lib/python3.8/ast.py", line 371, in visit
return visitor(node)
File "/data/home/houyazhou/miniconda3/envs/clip_benchmark/lib/python3.8/site-packages/triton/runtime/jit.py", line 127, in visit_Call
func = self.visit(node.func)
File "/data/home/houyazhou/miniconda3/envs/clip_benchmark/lib/python3.8/ast.py", line 371, in visit
return visitor(node)
File "/data/home/houyazhou/miniconda3/envs/clip_benchmark/lib/python3.8/site-packages/triton/runtime/jit.py", line 112, in visit_Attribute
lhs = self.visit(node.value)
File "/data/home/houyazhou/miniconda3/envs/clip_benchmark/lib/python3.8/ast.py", line 371, in visit
return visitor(node)
File "/data/home/houyazhou/miniconda3/envs/clip_benchmark/lib/python3.8/site-packages/triton/runtime/jit.py", line 117, in visit_Attribute
return getattr(lhs, node.attr)
AttributeError: module 'triton.language' has no attribute 'libdevice'
Metadata
Assignees
Labels
No labels