
estimate token use before sending openai completions #1112

Open · wants to merge 8 commits into base: main
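
In outline, the change estimates prompt token usage with tiktoken before each request and shrinks or drops the requested completion budget so the call stays inside the target model's context window. A minimal standalone sketch of that idea, assuming tiktoken is installed (the helper name and numbers are illustrative, not the garak implementation):

import tiktoken

def clamp_completion_tokens(prompt: str, model: str, requested: int, context_len: int) -> int:
    """Shrink `requested` completion tokens so prompt + completion fits in `context_len`."""
    try:
        encoding = tiktoken.encoding_for_model(model)
        prompt_tokens = len(encoding.encode(prompt))
    except KeyError:
        # unknown model name: fall back to a rough 1 token ~= 3/4 of a word estimate
        prompt_tokens = int(len(prompt.split()) * 4 / 3)
    available = context_len - prompt_tokens
    if available <= 0:
        raise ValueError("prompt alone exceeds the model context length")
    return min(requested, available)

# a short prompt against a 128k-token window keeps its full requested budget
print(clamp_completion_tokens("test values", "gpt-4o-mini", 1024, 128000))  # 1024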
91 changes: 88 additions & 3 deletions garak/generators/openai.py
@@ -14,6 +14,7 @@
import json
import logging
import re
import tiktoken
from typing import List, Union

import openai
@@ -113,12 +114,24 @@
"gpt-4o": 128000,
"gpt-4o-2024-05-13": 128000,
"gpt-4o-2024-08-06": 128000,
"gpt-4o-mini": 16384,
"gpt-4o-mini": 128000,
"gpt-4o-mini-2024-07-18": 16384,
"o1-mini": 65536,
"o1": 200000,
"o1-mini": 128000,
"o1-mini-2024-09-12": 65536,
"o1-preview": 32768,
"o1-preview-2024-09-12": 32768,
"o3-mini": 200000,
}

output_max = {
"gpt-3.5-turbo": 4096,
"gpt-4": 8192,
"gpt-4o": 16384,
"o3-mini": 100000,
"o1": 100000,
"o1-mini": 65536,
"gpt-4o-mini": 16384,
}
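
The two tables serve different purposes: context_lengths bounds prompt plus completion for a model, while output_max bounds the completion alone (gpt-4o, for instance, accepts 128000 tokens of context but emits at most 16384). A simplified sketch of how a requested budget could be checked against both tables; the real _validate_token_args below instead drops the limit with a warning, or raises, in the edge cases:

def request_cap(model: str, requested: int, prompt_tokens: int, overhead: int = 7) -> int:
    """Simplified: cap `requested` by the model's output ceiling and remaining context window."""
    cap = requested
    if model in output_max:
        cap = min(cap, output_max[model])  # per-response output ceiling
    if model in context_lengths:
        cap = min(cap, context_lengths[model] - prompt_tokens - overhead)  # remaining window
    return cap

print(request_cap("gpt-4o", 200_000, prompt_tokens=500))  # 16384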


@@ -171,6 +184,75 @@ def _clear_client(self):
def _validate_config(self):
pass

def _validate_token_args(self, create_args: dict, prompt: str) -> dict:
"""Ensure maximum token limit compatibility with OpenAI create request"""
token_limit_key = "max_tokens"
fixed_cost = 0
if (
self.generator == self.client.chat.completions
and self.max_tokens is not None
):
token_limit_key = "max_completion_tokens"
if not hasattr(self, "max_completion_tokens"):
create_args["max_completion_tokens"] = self.max_tokens

create_args.pop(
"max_tokens", None
) # remove deprecated value, utilize `max_completion_tokens`
# chat accounting per the cookbook: roughly 3 tokens of per-message framing, 3 tokens
# priming the reply with <|start|>assistant<|message|>, plus 1 for a name field
# see https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb
# section 6 "Counting tokens for chat completions API calls"
fixed_cost = 7

# basic token boundary validation to ensure requests are not rejected for exceeding target context length
token_limit = create_args.pop(token_limit_key, None)
if token_limit is not None:
# Suppress max_tokens if greater than context_len
if (
hasattr(self, "context_len")
and self.context_len is not None
and token_limit > self.context_len
):
logging.warning(
f"Requested garak maximum tokens {token_limit} exceeds context length {self.context_len}, no limit will be applied to the request"
)
token_limit = None

if self.name in output_max and token_limit > output_max[self.name]:
logging.warning(
f"Requested maximum tokens {token_limit} exceeds max output {output_max[self.name]}, no limit will be applied to the request"
)
token_limit = None

if self.context_len is not None and token_limit is not None:
# count tokens in the prompt and ensure the requested token_limit fits within the allowed context_len / output_max
prompt_tokens = 0 # note: counts the prompt text only; ideally this would cover the full messages object
try:
encoding = tiktoken.encoding_for_model(self.name)
prompt_tokens = len(encoding.encode(prompt))
except KeyError as e:
prompt_tokens = int(
len(prompt.split()) * 4 / 3
) # extra naive fallback 1 token ~= 3/4 of a word

if (prompt_tokens + fixed_cost + token_limit > self.context_len) and (
prompt_tokens + fixed_cost < self.context_len
):
token_limit = self.context_len - prompt_tokens - fixed_cost
elif token_limit > prompt_tokens + fixed_cost:
token_limit = token_limit - prompt_tokens - fixed_cost
else:
raise garak.exception.GarakException(
"A response of %s toks plus prompt %s toks cannot be generated; API capped at context length %s toks"
% (
self.max_tokens,
prompt_tokens + fixed_cost,
self.context_len,
)
)
create_args[token_limit_key] = token_limit
return create_args

def __init__(self, name="", config_root=_config):
self.name = name
self._load_config(config_root)
@@ -216,13 +298,16 @@ def _call_model(
create_args = {}
if "n" not in self.suppressed_params:
create_args["n"] = generations_this_call
-        for arg in inspect.signature(self.generator.create).parameters:
+        create_params = inspect.signature(self.generator.create).parameters
+        for arg in create_params:
if arg == "model":
create_args[arg] = self.name
continue
if hasattr(self, arg) and arg not in self.suppressed_params:
create_args[arg] = getattr(self, arg)

create_args = self._validate_token_args(create_args, prompt)

if self.generator == self.client.completions:
if not isinstance(prompt, str):
msg = (
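
For reference, the fixed_cost of 7 above lines up with the chat-completions accounting in the linked cookbook notebook: roughly 3 tokens of per-message framing, 3 tokens priming the reply with <|start|>assistant<|message|>, and 1 for a name field. A hedged sketch of that counting for chat messages, adapted from the cookbook (the constants are assumptions and may drift for newer models):

import tiktoken

def num_tokens_from_messages(messages: list[dict], model: str = "gpt-4o-mini") -> int:
    try:
        encoding = tiktoken.encoding_for_model(model)
    except KeyError:
        encoding = tiktoken.get_encoding("o200k_base")  # assumed fallback encoding
    tokens_per_message = 3  # per-message framing tokens
    tokens_per_name = 1
    num_tokens = 0
    for message in messages:
        num_tokens += tokens_per_message
        for key, value in message.items():
            num_tokens += len(encoding.encode(value))
            if key == "name":
                num_tokens += tokens_per_name
    num_tokens += 3  # every reply is primed with <|start|>assistant<|message|>
    return num_tokens

print(num_tokens_from_messages([{"role": "user", "content": "test values"}]))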
118 changes: 116 additions & 2 deletions tests/generators/test_openai_compatible.py
@@ -9,14 +9,18 @@
import inspect

from collections.abc import Iterable
-from garak.generators.openai import OpenAICompatible
+from garak.generators.openai import OpenAICompatible, output_max, context_lengths


# TODO: expand this when we have faster loading; currently processing all generators costs 30s for 3 tests
# GENERATORS = [
# classname for (classname, active) in _plugins.enumerate_plugins("generators")
# ]
GENERATORS = ["generators.openai.OpenAIGenerator", "generators.nim.NVOpenAIChat", "generators.groq.GroqChat"]
GENERATORS = [
"generators.openai.OpenAIGenerator",
"generators.nim.NVOpenAIChat",
"generators.groq.GroqChat",
]

MODEL_NAME = "gpt-3.5-turbo-instruct"
ENV_VAR = os.path.abspath(
@@ -98,3 +102,113 @@ def test_openai_multiprocessing(openai_compat_mocks, classname):
with Pool(parallel_attempts) as attempt_pool:
for result in attempt_pool.imap_unordered(generate_in_subprocess, prompts):
assert result is not None


def test_validate_call_model_chat_token_restrictions(openai_compat_mocks):
import lorem
import json
import tiktoken
from garak.exception import GarakException

generator = build_test_instance(OpenAICompatible)
mock_url = getattr(generator, "uri", "https://api.openai.com/v1")
with respx.mock(base_url=mock_url, assert_all_called=False) as respx_mock:
mock_response = openai_compat_mocks["chat"]
respx_mock.post("chat/completions").mock(
return_value=httpx.Response(
mock_response["code"], json=mock_response["json"]
)
)
generator._call_model("test values")
req_body = json.loads(respx_mock.routes[0].calls[0].request.content)
assert (
req_body["max_completion_tokens"] <= generator.max_tokens
), "request max_completion_tokens must account for prompt tokens"

test_large_context = ""
encoding = tiktoken.encoding_for_model(MODEL_NAME)
while len(encoding.encode(test_large_context)) < generator.max_tokens:
test_large_context += "\n".join(lorem.paragraph())
large_context_len = len(encoding.encode(test_large_context))

generator.context_len = large_context_len * 2
generator.max_tokens = generator.context_len * 2
generator._call_model("test values")
req_body = json.loads(respx_mock.routes[0].calls[1].request.content)
assert (
req_body.get("max_completion_tokens", None) is None
and req_body.get("max_tokens", None) is None
), "request max_completion_tokens is suppressed when larger than context length"

generator.max_tokens = large_context_len - int(large_context_len / 2)
generator.context_len = large_context_len
with pytest.raises(GarakException) as exc_info:
generator._call_model(test_large_context)
assert "API capped" in str(
exc_info.value
), "a prompt larger than max_tokens must raise exception"

max_output_model = "gpt-3.5-turbo"
generator.name = max_output_model
generator.max_tokens = output_max[max_output_model] * 2
generator.context_len = generator.max_tokens * 2
generator._call_model("test values")
req_body = json.loads(respx_mock.routes[0].calls[2].request.content)
assert (
req_body.get("max_completion_tokens", None) is None
and req_body.get("max_tokens", None) is None
), "request max_completion_tokens is suppressed when larger than output_max limited known model"

generator.max_completion_tokens = int(output_max[max_output_model] / 2)
generator._call_model("test values")
req_body = json.loads(respx_mock.routes[0].calls[3].request.content)
assert (
req_body["max_completion_tokens"] < generator.max_completion_tokens
and req_body.get("max_tokens", None) is None
), "request max_completion_tokens is suppressed when larger than output_max limited known model"


def test_validate_call_model_completion_token_restrictions(openai_compat_mocks):
import lorem
import json
import tiktoken
from garak.exception import GarakException

generator = build_test_instance(OpenAICompatible)
generator._load_client()
generator.generator = generator.client.completions
mock_url = getattr(generator, "uri", "https://api.openai.com/v1")
with respx.mock(base_url=mock_url, assert_all_called=False) as respx_mock:
mock_response = openai_compat_mocks["completion"]
respx_mock.post("/completions").mock(
return_value=httpx.Response(
mock_response["code"], json=mock_response["json"]
)
)
generator._call_model("test values")
req_body = json.loads(respx_mock.routes[0].calls[0].request.content)
assert (
req_body["max_tokens"] <= generator.max_tokens
), "request max_tokens must account for prompt tokens"

test_large_context = ""
encoding = tiktoken.encoding_for_model(MODEL_NAME)
while len(encoding.encode(test_large_context)) < generator.max_tokens:
test_large_context += "\n".join(lorem.paragraph())
large_context_len = len(encoding.encode(test_large_context))

generator.context_len = large_context_len * 2
generator.max_tokens = generator.context_len * 2
generator._call_model("test values")
req_body = json.loads(respx_mock.routes[0].calls[1].request.content)
assert (
req_body.get("max_tokens", None) is None
), "request max_tokens is suppressed when larger than context length"

generator.max_tokens = large_context_len - int(large_context_len / 2)
generator.context_len = large_context_len
with pytest.raises(GarakException) as exc_info:
generator._call_model(test_large_context)
assert "API capped" in str(
exc_info.value
), "a prompt larger than max_tokens must raise exception"