@@ -28,6 +28,7 @@ def test_olmo_model(model_path: str):
 @pytest.mark.skipif(torch.cuda.device_count() < 1, reason="Requires CUDA devices")
 def test_flash_attention_2(model_path: str):
     from transformers import AutoModelForCausalLM, AutoTokenizer
+
     import hf_olmo  # noqa: F401
 
     hf_model = AutoModelForCausalLM.from_pretrained(model_path)
@@ -45,6 +46,7 @@ def test_flash_attention_2(model_path: str):
 
 def test_sdpa(model_path: str):
     from transformers import AutoModelForCausalLM, AutoTokenizer
+
     import hf_olmo  # noqa: F401
 
     hf_model = AutoModelForCausalLM.from_pretrained(model_path)
@@ -62,6 +64,7 @@ def test_sdpa(model_path: str):
 
 def test_gradient_checkpointing(model_path: str):
     from transformers import AutoModelForCausalLM, AutoTokenizer, PreTrainedModel
+
     import hf_olmo  # noqa: F401
 
     hf_model: PreTrainedModel = AutoModelForCausalLM.from_pretrained(model_path)
@@ -81,6 +84,7 @@ def test_gradient_checkpointing(model_path: str):
 
 def test_gradient_checkpointing_disable(model_path: str):
     from transformers import AutoModelForCausalLM, AutoTokenizer, PreTrainedModel
+
     import hf_olmo  # noqa: F401
 
     hf_model: PreTrainedModel = AutoModelForCausalLM.from_pretrained(model_path)
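For context on what these tests exercise: importing hf_olmo registers the OLMo architecture with transformers, after which a converted checkpoint loads through the standard Auto classes. A minimal sketch of the same pattern (the model_path value and the attn_implementation choice are illustrative assumptions, not taken from this diff):

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

# Importing hf_olmo registers the OLMo config/model classes with transformers.
import hf_olmo  # noqa: F401

model_path = "path/to/converted-olmo-checkpoint"  # hypothetical checkpoint location

# Load with an explicit attention backend; "sdpa" and "flash_attention_2"
# are the backends the tests above exercise.
hf_model = AutoModelForCausalLM.from_pretrained(model_path, attn_implementation="sdpa")
tokenizer = AutoTokenizer.from_pretrained(model_path)

inputs = tokenizer("Hello, OLMo!", return_tensors="pt")
outputs = hf_model(**inputs)
```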