Skip to content

error when use try to use int8 operations with OpenCLIP. #1012

@xiaohoua

Description

@xiaohoua

I am trying to quantify open_clip's pre-trained model and then do a zero sample classification test on clip_benchmark. But get an error:
AttributeError: module 'triton.language' has no attribute 'libdevice'
Here's my code:

import open_clip
import bitsandbytes as bnb

def quantize_clip(model_name: str = "ViT-B-32-quickgelu", pretrained: str = "laion400m_e32", cache_dir: str = None, device="cpu", **kwargs):
    print("====quantize clip====")
    model, _, transform = open_clip.create_model_and_transforms(model_name, pretrained=pretrained, cache_dir=cache_dir)
    int8_linear_layer = bnb.nn.triton_based_modules.SwitchBackLinear
    int8_model = open_clip.utils.replace_linear(model, int8_linear_layer, include_modules=['c_fc', 'c_proj']).cuda()
    int8_model = int8_model.to(device)
    tokenizer = open_clip.get_tokenizer(model_name)
    print("====return quantize clip====")
    return int8_model, transform, tokenizer

This code is placed in /clip_benchmark/models/quantize_model.py
My run command is as follows:
clip_benchmark eval --dataset=cifar10 --task=zeroshot_classification --pretrained=/models/clip_model/DFN5B-CLIP-ViT-H-14-378/open_clip_pytorch_model.bin --model=ViT-H-14-378-quickgelu --output=quantize.json --batch_size=64 --/CLIP_benchmark/dataset --model_type=quantize

Detailed error records are as follows:

Traceback (most recent call last):
  File "/data/home/houyazhou/miniconda3/envs/clip_benchmark/bin/clip_benchmark", line 8, in <module>
    sys.exit(main())
  File "/data/home/houyazhou/miniconda3/envs/clip_benchmark/lib/python3.8/site-packages/clip_benchmark/cli.py", line 87, in main
    main_eval(base)
  File "/data/home/houyazhou/miniconda3/envs/clip_benchmark/lib/python3.8/site-packages/clip_benchmark/cli.py", line 193, in main_eval
    run(args)
  File "/data/home/houyazhou/miniconda3/envs/clip_benchmark/lib/python3.8/site-packages/clip_benchmark/cli.py", line 313, in run
    metrics = zeroshot_classification.evaluate(
  File "/data/home/houyazhou/miniconda3/envs/clip_benchmark/lib/python3.8/site-packages/clip_benchmark/metrics/zeroshot_classification.py", line 204, in evaluate
    classifier = zero_shot_classifier(model, tokenizer, classnames, templates, device, amp=amp)
  File "/data/home/houyazhou/miniconda3/envs/clip_benchmark/lib/python3.8/site-packages/clip_benchmark/metrics/zeroshot_classification.py", line 52, in zero_shot_classifier
    class_embeddings = model.encode_text(texts)
  File "/data/home/houyazhou/miniconda3/envs/clip_benchmark/lib/python3.8/site-packages/open_clip/model.py", line 279, in encode_text
    x = self.transformer(x, attn_mask=self.attn_mask)
  File "/data/home/houyazhou/miniconda3/envs/clip_benchmark/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/data/home/houyazhou/miniconda3/envs/clip_benchmark/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
  File "/data/home/houyazhou/miniconda3/envs/clip_benchmark/lib/python3.8/site-packages/open_clip/transformer.py", line 364, in forward
    x = r(x, attn_mask=attn_mask)
  File "/data/home/houyazhou/miniconda3/envs/clip_benchmark/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/data/home/houyazhou/miniconda3/envs/clip_benchmark/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
  File "/data/home/houyazhou/miniconda3/envs/clip_benchmark/lib/python3.8/site-packages/open_clip/transformer.py", line 264, in forward
    x = x + self.ls_2(self.mlp(self.ln_2(x)))
  File "/data/home/houyazhou/miniconda3/envs/clip_benchmark/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/data/home/houyazhou/miniconda3/envs/clip_benchmark/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
  File "/data/home/houyazhou/miniconda3/envs/clip_benchmark/lib/python3.8/site-packages/torch/nn/modules/container.py", line 219, in forward
    input = module(input)
  File "/data/home/houyazhou/miniconda3/envs/clip_benchmark/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/data/home/houyazhou/miniconda3/envs/clip_benchmark/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
  File "/data/home/houyazhou/miniconda3/envs/clip_benchmark/lib/python3.8/site-packages/bitsandbytes/nn/triton_based_modules.py", line 209, in forward
    return self._fn.apply(x, self.weight, self.bias)
  File "/data/home/houyazhou/miniconda3/envs/clip_benchmark/lib/python3.8/site-packages/torch/autograd/function.py", line 574, in apply
    return super().apply(*args, **kwargs)  # type: ignore[misc]
  File "/data/home/houyazhou/miniconda3/envs/clip_benchmark/lib/python3.8/site-packages/bitsandbytes/nn/triton_based_modules.py", line 31, in forward
    X_int8, state_X = quantize_rowwise(X)
  File "/data/home/houyazhou/miniconda3/envs/clip_benchmark/lib/python3.8/site-packages/bitsandbytes/triton/quantize_rowwise.py", line 66, in quantize_rowwise
    _quantize_rowwise[grid](x, output, output_maxs, n_elements, BLOCK_SIZE=x.shape[1], P2=P2)
  File "/data/home/houyazhou/miniconda3/envs/clip_benchmark/lib/python3.8/site-packages/triton/runtime/jit.py", line 345, in <lambda>
    return lambda *args, **kwargs: self.run(grid=grid, warmup=False, *args, **kwargs)
  File "/data/home/houyazhou/miniconda3/envs/clip_benchmark/lib/python3.8/site-packages/triton/runtime/autotuner.py", line 156, in run
    timings = {config: self._bench(*args, config=config, **kwargs) for config in pruned_configs}
  File "/data/home/houyazhou/miniconda3/envs/clip_benchmark/lib/python3.8/site-packages/triton/runtime/autotuner.py", line 156, in <dictcomp>
    timings = {config: self._bench(*args, config=config, **kwargs) for config in pruned_configs}
  File "/data/home/houyazhou/miniconda3/envs/clip_benchmark/lib/python3.8/site-packages/triton/runtime/autotuner.py", line 133, in _bench
    return do_bench(kernel_call, warmup=self.num_warmups, rep=self.num_reps, quantiles=(0.5, 0.2, 0.8))
  File "/data/home/houyazhou/miniconda3/envs/clip_benchmark/lib/python3.8/site-packages/triton/testing.py", line 103, in do_bench
    fn()
  File "/data/home/houyazhou/miniconda3/envs/clip_benchmark/lib/python3.8/site-packages/triton/runtime/autotuner.py", line 114, in kernel_call
    self.fn.run(
  File "/data/home/houyazhou/miniconda3/envs/clip_benchmark/lib/python3.8/site-packages/triton/runtime/jit.py", line 662, in run
    kernel = self.compile(
  File "/data/home/houyazhou/miniconda3/envs/clip_benchmark/lib/python3.8/site-packages/triton/compiler/compiler.py", line 240, in compile
    key = f"{triton_key()}-{src.hash()}-{backend.hash()}-{options.hash()}-{str(sorted(env_vars.items()))}"
  File "/data/home/houyazhou/miniconda3/envs/clip_benchmark/lib/python3.8/site-packages/triton/compiler/compiler.py", line 109, in hash
    key = f"{self.fn.cache_key}-{self.attrs.hash()}-{sorted_sig}-{sorted_constants}"
  File "/data/home/houyazhou/miniconda3/envs/clip_benchmark/lib/python3.8/site-packages/triton/runtime/jit.py", line 758, in cache_key
    dependencies_finder.visit(self.parse())
  File "/data/home/houyazhou/miniconda3/envs/clip_benchmark/lib/python3.8/ast.py", line 371, in visit
    return visitor(node)
  File "/data/home/houyazhou/miniconda3/envs/clip_benchmark/lib/python3.8/ast.py", line 379, in generic_visit
    self.visit(item)
  File "/data/home/houyazhou/miniconda3/envs/clip_benchmark/lib/python3.8/ast.py", line 371, in visit
    return visitor(node)
  File "/data/home/houyazhou/miniconda3/envs/clip_benchmark/lib/python3.8/site-packages/triton/runtime/jit.py", line 167, in visit_FunctionDef
    self.generic_visit(node)
  File "/data/home/houyazhou/miniconda3/envs/clip_benchmark/lib/python3.8/ast.py", line 379, in generic_visit
    self.visit(item)
  File "/data/home/houyazhou/miniconda3/envs/clip_benchmark/lib/python3.8/ast.py", line 371, in visit
    return visitor(node)
  File "/data/home/houyazhou/miniconda3/envs/clip_benchmark/lib/python3.8/site-packages/triton/runtime/jit.py", line 221, in visit_Assign
    self.generic_visit(node)
  File "/data/home/houyazhou/miniconda3/envs/clip_benchmark/lib/python3.8/ast.py", line 381, in generic_visit
    self.visit(value)
  File "/data/home/houyazhou/miniconda3/envs/clip_benchmark/lib/python3.8/ast.py", line 371, in visit
    return visitor(node)
  File "/data/home/houyazhou/miniconda3/envs/clip_benchmark/lib/python3.8/site-packages/triton/runtime/jit.py", line 127, in visit_Call
    func = self.visit(node.func)
  File "/data/home/houyazhou/miniconda3/envs/clip_benchmark/lib/python3.8/ast.py", line 371, in visit
    return visitor(node)
  File "/data/home/houyazhou/miniconda3/envs/clip_benchmark/lib/python3.8/site-packages/triton/runtime/jit.py", line 112, in visit_Attribute
    lhs = self.visit(node.value)
  File "/data/home/houyazhou/miniconda3/envs/clip_benchmark/lib/python3.8/ast.py", line 371, in visit
    return visitor(node)
  File "/data/home/houyazhou/miniconda3/envs/clip_benchmark/lib/python3.8/site-packages/triton/runtime/jit.py", line 117, in visit_Attribute
    return getattr(lhs, node.attr)
AttributeError: module 'triton.language' has no attribute 'libdevice'

Metadata

Metadata

Assignees

No one assigned

    Labels

    No labels
    No labels

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions