additional tests for huggingface inference
Add validation tests to ensure uri and name values populate as expected

Signed-off-by: Jeffrey Martin <[email protected]>
jmartin-tech committed Nov 18, 2024
1 parent 3b98a00 commit ca2e050
Showing 3 changed files with 73 additions and 4 deletions.
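
The new tests assert two different relationships between name and uri: for InferenceAPI the model name should appear inside the uri, while for InferenceEndpoint the configured name is itself the full uri. A minimal sketch of that assumed behavior (illustrative only, not garak source; the base URI below is an assumption):

# Illustrative sketch only -- not garak source. Assumes InferenceAPI appends
# the model name to a hosted-inference base URI, while InferenceEndpoint
# treats the configured name as the complete endpoint URL.
BASE_URI = "https://api-inference.huggingface.co/models/"  # assumed base


class InferenceAPISketch:
    def __init__(self, name: str):
        self.name = name
        self.uri = BASE_URI + name  # hence `assert model_name in g.uri`


class InferenceEndpointSketch:
    def __init__(self, name: str):
        self.name = name
        self.uri = name  # hence `assert g.uri == model_name`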
7 changes: 7 additions & 0 deletions tests/generators/conftest.py
@@ -11,3 +11,10 @@ def openai_compat_mocks():
    """Mock responses for OpenAI compatible endpoints"""
    with open(pathlib.Path(__file__).parents[0] / "openai.json") as mock_openai:
        return json.load(mock_openai)


@pytest.fixture
def hf_endpoint_mocks():
    """Mock responses for Huggingface InferenceAPI based endpoints"""
    with open(pathlib.Path(__file__).parents[0] / "hf_inference.json") as mock_hf:
        return json.load(mock_hf)
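
Because this fixture lives in tests/generators/conftest.py, pytest injects it into any test in that directory by parameter name, with no import required. For example (hypothetical test, not part of the commit):

def test_mock_shape(hf_endpoint_mocks):
    # pytest resolves hf_endpoint_mocks from conftest.py automatically
    assert hf_endpoint_mocks["hf_inference"]["code"] == 200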
10 changes: 10 additions & 0 deletions tests/generators/hf_inference.json
@@ -0,0 +1,10 @@
{
  "hf_inference": {
    "code": 200,
    "json": [
      {
        "generated_text": "restricted by their policy,"
      }
    ]
  }
}
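
The "json" payload mirrors the Hugging Face text-generation response shape: a list of objects carrying a generated_text field. A minimal sketch of pulling text out of such a payload (illustrative, not garak source):

import json

# The HF text-generation API returns a list of {"generated_text": ...}
# objects; the mock above reproduces that shape.
payload = json.loads('[{"generated_text": "restricted by their policy,"}]')
texts = [entry["generated_text"] for entry in payload]
assert texts == ["restricted by their policy,"]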
60 changes: 56 additions & 4 deletions tests/generators/test_huggingface.py
@@ -1,4 +1,5 @@
import pytest
import requests
import transformers
import garak.generators.huggingface
from garak._config import GarakSubConfig
@@ -8,6 +9,7 @@
def hf_generator_config():
    gen_config = {
        "huggingface": {
            "api_key": "fake",
            "hf_args": {
                "device": "cpu",
                "torch_dtype": "float32",
@@ -19,6 +21,17 @@ def hf_generator_config():
    return config_root


@pytest.fixture
def hf_mock_response(hf_endpoint_mocks):
    import json

    mock_resp_data = hf_endpoint_mocks["hf_inference"]
    mock_resp = requests.Response()
    mock_resp.status_code = mock_resp_data["code"]
    mock_resp._content = json.dumps(mock_resp_data["json"]).encode("UTF-8")
    return mock_resp


def test_pipeline(hf_generator_config):
    generations = 10
    g = garak.generators.huggingface.Pipeline("gpt2", config_root=hf_generator_config)
@@ -37,16 +50,55 @@ def test_pipeline(hf_generator_config):
        assert isinstance(item, str)


def test_inference():
    return  # slow w/o key
    g = garak.generators.huggingface.InferenceAPI("gpt2")
    assert g.name == "gpt2"
def test_inference(mocker, hf_mock_response, hf_generator_config):
    model_name = "gpt2"
    mock_request = mocker.patch.object(
        requests, "request", return_value=hf_mock_response
    )

    g = garak.generators.huggingface.InferenceAPI(
        model_name, config_root=hf_generator_config
    )
    assert g.name == model_name
    assert model_name in g.uri

    hf_generator_config.generators["huggingface"]["name"] = model_name
    g = garak.generators.huggingface.InferenceAPI(config_root=hf_generator_config)
    assert g.name == model_name
    assert model_name in g.uri
    assert isinstance(g.max_tokens, int)
    g.max_tokens = 99
    assert g.max_tokens == 99
    g.temperature = 0.1
    assert g.temperature == 0.1
    output = g.generate("")
    mock_request.assert_called_once()
    assert len(output) == 1  # 1 generation by default
    for item in output:
        assert isinstance(item, str)


def test_endpoint(mocker, hf_mock_response, hf_generator_config):
    model_name = "https://localhost:8000/gpt2"
    mock_request = mocker.patch.object(requests, "post", return_value=hf_mock_response)

    g = garak.generators.huggingface.InferenceEndpoint(
        model_name, config_root=hf_generator_config
    )
    assert g.name == model_name
    assert g.uri == model_name

    hf_generator_config.generators["huggingface"]["name"] = model_name
    g = garak.generators.huggingface.InferenceEndpoint(config_root=hf_generator_config)
    assert g.name == model_name
    assert g.uri == model_name
    assert isinstance(g.max_tokens, int)
    g.max_tokens = 99
    assert g.max_tokens == 99
    g.temperature = 0.1
    assert g.temperature == 0.1
    output = g.generate("")
    mock_request.assert_called_once()
    assert len(output) == 1  # 1 generation by default
    for item in output:
        assert isinstance(item, str)
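
The mocking pattern used above (building a real requests.Response, setting _content to encoded JSON, and patching the requests call with pytest-mock's mocker fixture) is reusable outside garak. A self-contained sketch, where fetch_generation is a hypothetical helper, not a garak function:

import json

import requests


def fetch_generation(uri: str) -> list:
    # Hypothetical helper: POST to an inference endpoint and decode its JSON body.
    resp = requests.post(uri, timeout=10)
    resp.raise_for_status()
    return resp.json()


def test_fetch_generation(mocker):
    # Build a genuine Response so .json() and .raise_for_status() behave normally.
    mock_resp = requests.Response()
    mock_resp.status_code = 200
    mock_resp._content = json.dumps(
        [{"generated_text": "restricted by their policy,"}]
    ).encode("UTF-8")
    mocked = mocker.patch.object(requests, "post", return_value=mock_resp)

    result = fetch_generation("https://localhost:8000/gpt2")

    mocked.assert_called_once()
    assert result[0]["generated_text"] == "restricted by their policy,"

Because the patched call returns a genuine Response, the code under test exercises its real JSON-decoding path instead of a MagicMock, which is what lets these tests validate response handling without network access.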
