Skip to content

Commit 0ded5d3

Browse files
authored
Merge pull request #208 from AgentOps-AI/cursor/add-aws-bedrock-model-cost-tracking-223f
Add aws bedrock model cost tracking
2 parents d3ee634 + e502444 commit 0ded5d3

File tree

2 files changed

+129
-7
lines changed

2 files changed

+129
-7
lines changed

tokencost/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,5 +5,7 @@
55
calculate_prompt_cost,
66
calculate_all_costs_and_tokens,
77
calculate_cost_by_tokens,
8+
configure_model,
9+
register_model_pattern,
810
)
911
from .constants import TOKEN_COSTS_STATIC, TOKEN_COSTS, update_token_costs, refresh_prices

tokencost/costs.py

Lines changed: 127 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,8 @@
99
from .constants import TOKEN_COSTS
1010
from decimal import Decimal
1111
import logging
12+
import re
13+
from typing import Optional, Tuple, Pattern
1214

1315
logger = logging.getLogger(__name__)
1416

@@ -19,6 +21,122 @@
1921
TokenType = Literal["input", "output", "cached"]
2022

2123

24+
# Ordered registry of (compiled regex, pricing entry) pairs populated by
# register_model_pattern(); consulted as a fallback when a model name has no
# exact key in TOKEN_COSTS.
MODEL_PRICE_PATTERNS: List[Tuple[Pattern[str], Dict[str, Union[int, float, str, bool]]]] = []
25+
26+
27+
def _to_per_token(cost_per_1k_tokens: Union[int, float, Decimal]) -> float:
28+
"""Convert a price expressed per 1K tokens to a per-token float."""
29+
return float(Decimal(str(cost_per_1k_tokens)) / Decimal(1000))
30+
31+
32+
def configure_model(
    model_name: str,
    input_cost_per_1k_tokens: Union[int, float, Decimal],
    output_cost_per_1k_tokens: Union[int, float, Decimal],
    *,
    max_input_tokens: Optional[int] = None,
    max_output_tokens: Optional[int] = None,
    litellm_provider: Optional[str] = None,
    mode: str = "chat",
) -> None:
    """
    Register or override pricing for a specific model name.

    Args:
        model_name: The exact model identifier to store (case-insensitive).
        input_cost_per_1k_tokens: USD per 1K input tokens (e.g., 0.003 for $3.00/M).
        output_cost_per_1k_tokens: USD per 1K output tokens (e.g., 0.015 for $15.00/M).
        max_input_tokens: Optional maximum input tokens.
        max_output_tokens: Optional maximum output tokens.
        litellm_provider: Optional provider hint (e.g., "bedrock").
        mode: Model mode, defaults to "chat".
    """
    # Assemble the full pricing entry before touching TOKEN_COSTS so the
    # registry is replaced in a single assignment.
    entry: Dict[str, Union[int, float, str, bool]] = {
        "input_cost_per_token": _to_per_token(input_cost_per_1k_tokens),
        "output_cost_per_token": _to_per_token(output_cost_per_1k_tokens),
        "mode": mode,
    }
    # Optional limits are stored only when supplied, coerced to int.
    for field, limit in (
        ("max_input_tokens", max_input_tokens),
        ("max_output_tokens", max_output_tokens),
    ):
        if limit is not None:
            entry[field] = int(limit)
    if litellm_provider is not None:
        entry["litellm_provider"] = litellm_provider
    # Lookup keys are lowercase throughout the pricing table.
    TOKEN_COSTS[model_name.lower()] = entry
66+
67+
68+
def register_model_pattern(
    pattern: str,
    input_cost_per_1k_tokens: Union[int, float, Decimal],
    output_cost_per_1k_tokens: Union[int, float, Decimal],
    *,
    max_input_tokens: Optional[int] = None,
    max_output_tokens: Optional[int] = None,
    litellm_provider: Optional[str] = None,
    mode: str = "chat",
) -> None:
    """
    Register a wildcard or regex-like pattern that assigns pricing to any matching model.

    The pattern supports '*' as a wildcard. It is converted to a full regex match.
    Example: "bedrock/anthropic.claude-3-5-sonnet-*".
    """
    # Escape everything, then restore '*' as a ".*" wildcard; anchor both ends
    # so the pattern must match the whole model name.
    compiled = re.compile("^" + re.escape(pattern).replace(r"\*", ".*") + "$")

    pricing: Dict[str, Union[int, float, str, bool]] = {
        "input_cost_per_token": _to_per_token(input_cost_per_1k_tokens),
        "output_cost_per_token": _to_per_token(output_cost_per_1k_tokens),
        "mode": mode,
    }
    # Optional fields are included only when explicitly provided.
    for field, limit in (
        ("max_input_tokens", max_input_tokens),
        ("max_output_tokens", max_output_tokens),
    ):
        if limit is not None:
            pricing[field] = int(limit)
    if litellm_provider is not None:
        pricing["litellm_provider"] = litellm_provider

    MODEL_PRICE_PATTERNS.append((compiled, pricing))
99+
100+
101+
def _normalize_model_for_pricing(model: str) -> str:
    """
    Normalize a model identifier for price lookup.

    Rules:
    - Lowercase everything
    - Keep exact matches if present
    - Special-case Bedrock Anthropics: strip the leading "bedrock/" prefix when the next
      segment starts with "anthropic.", since pricing keys are stored without the prefix.
    - Otherwise, try the last segment after '/'. This helps for provider prefixes like
      "azure/", "openai/", etc., when prices are stored under the bare model key.
    - Finally, try user-registered wildcard patterns; on a hit the pricing entry is
      cached in TOKEN_COSTS under the exact (lowercased) model string.

    Returns the normalized key, or the lowercased input unchanged when no rule applies
    (the caller is responsible for handling an unknown key).
    """
    m = model.lower()
    if m in TOKEN_COSTS:
        return m

    # bedrock/anthropic.* => anthropic.* (pricing keys stored this way).
    # Note: startswith("bedrock/") already guarantees a '/' is present, so the
    # original's extra "/" check (and its unused `first` binding) is dropped.
    if m.startswith("bedrock/"):
        rest = m.split("/", 1)[1]
        if rest.startswith("anthropic.") and rest in TOKEN_COSTS:
            return rest

    # Try last path segment as a fallback (handles e.g., azure/gpt-4o)
    if "/" in m:
        last = m.rsplit("/", 1)[-1]
        if last in TOKEN_COSTS:
            return last

    # Try matching any user-registered wildcard patterns. If matched, bind pricing
    # to this key so subsequent lookups hit the exact-match fast path.
    for regex, entry in MODEL_PRICE_PATTERNS:
        if regex.match(m):
            # Copy the entry so later pattern edits don't mutate the cached price.
            TOKEN_COSTS[m] = dict(entry)
            return m

    return m
138+
139+
22140
def _get_field_from_token_type(token_type: TokenType) -> str:
23141
"""
24142
Get the field name from the token type.
@@ -97,7 +215,7 @@ def count_message_tokens(messages: List[Dict[str, str]], model: str) -> int:
97215
model = strip_ft_model_name(model)
98216

99217
# Anthropic token counting requires a valid API key
100-
if "claude-" in model:
218+
if "claude-" in model and not model.startswith("anthropic."):
101219
logger.warning(
102220
"Warning: Anthropic token counting API is currently in beta. Please expect differences in costs!"
103221
)
@@ -199,7 +317,7 @@ def calculate_cost_by_tokens(num_tokens: int, model: str, token_type: TokenType)
199317
Returns:
200318
Decimal: The calculated cost in USD.
201319
"""
202-
model = model.lower()
320+
model = _normalize_model_for_pricing(model)
203321
if model not in TOKEN_COSTS:
204322
raise KeyError(
205323
f"""Model {model} is not implemented.
@@ -238,7 +356,8 @@ def calculate_prompt_cost(prompt: Union[List[dict], str], model: str) -> Decimal
238356
"""
239357
model = model.lower()
240358
model = strip_ft_model_name(model)
241-
if model not in TOKEN_COSTS:
359+
pricing_model = _normalize_model_for_pricing(model)
360+
if pricing_model not in TOKEN_COSTS:
242361
raise KeyError(
243362
f"""Model {model} is not implemented.
244363
Double-check your spelling, or submit an issue/PR"""
@@ -253,7 +372,7 @@ def calculate_prompt_cost(prompt: Union[List[dict], str], model: str) -> Decimal
253372
else count_message_tokens(prompt, model)
254373
)
255374

256-
return calculate_cost_by_tokens(prompt_tokens, model, "input")
375+
return calculate_cost_by_tokens(prompt_tokens, pricing_model, "input")
257376

258377

259378
def calculate_completion_cost(completion: str, model: str) -> Decimal:
@@ -273,7 +392,8 @@ def calculate_completion_cost(completion: str, model: str) -> Decimal:
273392
Decimal('0.000014')
274393
"""
275394
model = strip_ft_model_name(model)
276-
if model not in TOKEN_COSTS:
395+
pricing_model = _normalize_model_for_pricing(model)
396+
if pricing_model not in TOKEN_COSTS:
277397
raise KeyError(
278398
f"""Model {model} is not implemented.
279399
Double-check your spelling, or submit an issue/PR"""
@@ -291,7 +411,7 @@ def calculate_completion_cost(completion: str, model: str) -> Decimal:
291411
else:
292412
completion_tokens = count_string_tokens(completion, model)
293413

294-
return calculate_cost_by_tokens(completion_tokens, model, "output")
414+
return calculate_cost_by_tokens(completion_tokens, pricing_model, "output")
295415

296416

297417
def calculate_all_costs_and_tokens(
@@ -322,7 +442,7 @@ def calculate_all_costs_and_tokens(
322442
else count_message_tokens(prompt, model)
323443
)
324444

325-
if "claude-" in model:
445+
if "claude-" in model and not model.startswith("anthropic."):
326446
logger.warning("Warning: Token counting is estimated for ")
327447
completion_list = [{"role": "assistant", "content": completion}]
328448
# Anthropic appends some 13 additional tokens to the actual completion tokens

0 commit comments

Comments (0)