99from .constants import TOKEN_COSTS
1010from decimal import Decimal
1111import logging
12+ import re
13+ from typing import Optional , Tuple , Pattern
1214
1315logger = logging .getLogger (__name__ )
1416
1921TokenType = Literal ["input" , "output" , "cached" ]
2022
2123
# Wildcard pricing rules as (compiled regex, pricing entry) pairs.
# register_model_pattern appends to this list; _normalize_model_for_pricing
# scans it in order and uses the first regex that matches.
MODEL_PRICE_PATTERNS: List[Tuple[Pattern[str], Dict[str, Union[int, float, str, bool]]]] = []
25+
26+
27+ def _to_per_token (cost_per_1k_tokens : Union [int , float , Decimal ]) -> float :
28+ """Convert a price expressed per 1K tokens to a per-token float."""
29+ return float (Decimal (str (cost_per_1k_tokens )) / Decimal (1000 ))
30+
31+
def configure_model(
    model_name: str,
    input_cost_per_1k_tokens: Union[int, float, Decimal],
    output_cost_per_1k_tokens: Union[int, float, Decimal],
    *,
    max_input_tokens: Optional[int] = None,
    max_output_tokens: Optional[int] = None,
    litellm_provider: Optional[str] = None,
    mode: str = "chat",
) -> None:
    """
    Add or replace the pricing entry for one model in ``TOKEN_COSTS``.

    The entry is stored under the lowercased ``model_name`` and always
    contains the per-token input/output costs and the mode. The optional
    fields are written only when a value other than ``None`` is supplied.

    Args:
        model_name: Model identifier; lowercased before being used as the key.
        input_cost_per_1k_tokens: USD per 1K input tokens (e.g., 0.003 for $3.00/M).
        output_cost_per_1k_tokens: USD per 1K output tokens (e.g., 0.015 for $15.00/M).
        max_input_tokens: Optional maximum input tokens; stored as ``int``.
        max_output_tokens: Optional maximum output tokens; stored as ``int``.
        litellm_provider: Optional provider hint (e.g., "bedrock"); stored as given.
        mode: Model mode, defaults to "chat".
    """
    key = model_name.lower()
    TOKEN_COSTS[key] = {
        "input_cost_per_token": _to_per_token(input_cost_per_1k_tokens),
        "output_cost_per_token": _to_per_token(output_cost_per_1k_tokens),
        "mode": mode,
    }

    # (field name, supplied value, coercion) for each optional attribute,
    # listed in the order they are written into the entry.
    optional_fields = (
        ("max_input_tokens", max_input_tokens, int),
        ("max_output_tokens", max_output_tokens, int),
        ("litellm_provider", litellm_provider, lambda value: value),
    )
    for field_name, supplied, coerce in optional_fields:
        if supplied is None:
            continue
        TOKEN_COSTS[key][field_name] = coerce(supplied)
66+
67+
def register_model_pattern(
    pattern: str,
    input_cost_per_1k_tokens: Union[int, float, Decimal],
    output_cost_per_1k_tokens: Union[int, float, Decimal],
    *,
    max_input_tokens: Optional[int] = None,
    max_output_tokens: Optional[int] = None,
    litellm_provider: Optional[str] = None,
    mode: str = "chat",
) -> None:
    """
    Register a wildcard pattern that assigns pricing to any matching model.

    The pattern supports '*' as a wildcard meaning "any run of characters";
    every other character is matched literally. The pattern is lowercased
    and anchored at both ends, so it must cover the whole model name.
    Example: "bedrock/anthropic.claude-3-5-sonnet-*".

    The compiled pattern and its pricing entry are appended to
    ``MODEL_PRICE_PATTERNS``.

    Args:
        pattern: Wildcard pattern for model names (case-insensitive).
        input_cost_per_1k_tokens: USD per 1K input tokens.
        output_cost_per_1k_tokens: USD per 1K output tokens.
        max_input_tokens: Optional maximum input tokens; stored as ``int``.
        max_output_tokens: Optional maximum output tokens; stored as ``int``.
        litellm_provider: Optional provider hint (e.g., "bedrock").
        mode: Model mode, defaults to "chat".
    """
    # Model names are lowercased before they are tested against registered
    # patterns, so a pattern containing uppercase characters could never
    # match. Lowercase it here, the same way configure_model lowercases keys.
    normalized_pattern = pattern.lower()

    # Escape everything, then turn each escaped '*' back into a regex wildcard.
    regex_str = "^" + re.escape(normalized_pattern).replace(r"\*", ".*") + "$"
    compiled = re.compile(regex_str)

    entry: Dict[str, Union[int, float, str, bool]] = {
        "input_cost_per_token": _to_per_token(input_cost_per_1k_tokens),
        "output_cost_per_token": _to_per_token(output_cost_per_1k_tokens),
        "mode": mode,
    }
    if max_input_tokens is not None:
        entry["max_input_tokens"] = int(max_input_tokens)
    if max_output_tokens is not None:
        entry["max_output_tokens"] = int(max_output_tokens)
    if litellm_provider is not None:
        entry["litellm_provider"] = litellm_provider
    MODEL_PRICE_PATTERNS.append((compiled, entry))
99+
100+
def _normalize_model_for_pricing(model: str) -> str:
    """
    Resolve a model identifier to the key used for price lookup.

    The lowercased name is tried against ``TOKEN_COSTS`` in this order, and
    the first hit is returned:

    1. The full lowercased name.
    2. For names starting with "bedrock/": the remainder after that prefix,
       provided it starts with "anthropic.".
    3. The last '/'-separated segment (e.g. "azure/gpt-4o" -> "gpt-4o").
    4. The registered wildcard patterns in ``MODEL_PRICE_PATTERNS``. On the
       first match, a copy of that pattern's pricing entry is written to
       ``TOKEN_COSTS`` under the full lowercased name.

    If nothing matches, the lowercased name is returned unchanged; it may
    not be present in ``TOKEN_COSTS``.
    """
    candidate = model.lower()
    if candidate in TOKEN_COSTS:
        return candidate

    # bedrock/anthropic.* => anthropic.* (pricing keys stored this way)
    bedrock_prefix = "bedrock/"
    if candidate.startswith(bedrock_prefix):
        unprefixed = candidate[len(bedrock_prefix):]
        if unprefixed.startswith("anthropic.") and unprefixed in TOKEN_COSTS:
            return unprefixed

    # Fall back to the bare model name after the final '/'.
    if "/" in candidate:
        tail = candidate.rsplit("/", 1)[-1]
        if tail in TOKEN_COSTS:
            return tail

    # First matching wildcard pattern wins; store a copy of its pricing
    # under this exact name so later lookups hit the table directly.
    for compiled, pricing in MODEL_PRICE_PATTERNS:
        if not compiled.match(candidate):
            continue
        TOKEN_COSTS[candidate] = dict(pricing)
        break

    return candidate
138+
139+
22140def _get_field_from_token_type (token_type : TokenType ) -> str :
23141 """
24142 Get the field name from the token type.
@@ -97,7 +215,7 @@ def count_message_tokens(messages: List[Dict[str, str]], model: str) -> int:
97215 model = strip_ft_model_name (model )
98216
99217 # Anthropic token counting requires a valid API key
100- if "claude-" in model :
218+ if "claude-" in model and not model . startswith ( "anthropic." ) :
101219 logger .warning (
102220 "Warning: Anthropic token counting API is currently in beta. Please expect differences in costs!"
103221 )
@@ -199,7 +317,7 @@ def calculate_cost_by_tokens(num_tokens: int, model: str, token_type: TokenType)
199317 Returns:
200318 Decimal: The calculated cost in USD.
201319 """
202- model = model . lower ( )
320+ model = _normalize_model_for_pricing ( model )
203321 if model not in TOKEN_COSTS :
204322 raise KeyError (
205323 f"""Model { model } is not implemented.
@@ -238,7 +356,8 @@ def calculate_prompt_cost(prompt: Union[List[dict], str], model: str) -> Decimal
238356 """
239357 model = model .lower ()
240358 model = strip_ft_model_name (model )
241- if model not in TOKEN_COSTS :
359+ pricing_model = _normalize_model_for_pricing (model )
360+ if pricing_model not in TOKEN_COSTS :
242361 raise KeyError (
243362 f"""Model { model } is not implemented.
244363 Double-check your spelling, or submit an issue/PR"""
@@ -253,7 +372,7 @@ def calculate_prompt_cost(prompt: Union[List[dict], str], model: str) -> Decimal
253372 else count_message_tokens (prompt , model )
254373 )
255374
256- return calculate_cost_by_tokens (prompt_tokens , model , "input" )
375+ return calculate_cost_by_tokens (prompt_tokens , pricing_model , "input" )
257376
258377
259378def calculate_completion_cost (completion : str , model : str ) -> Decimal :
@@ -273,7 +392,8 @@ def calculate_completion_cost(completion: str, model: str) -> Decimal:
273392 Decimal('0.000014')
274393 """
275394 model = strip_ft_model_name (model )
276- if model not in TOKEN_COSTS :
395+ pricing_model = _normalize_model_for_pricing (model )
396+ if pricing_model not in TOKEN_COSTS :
277397 raise KeyError (
278398 f"""Model { model } is not implemented.
279399 Double-check your spelling, or submit an issue/PR"""
@@ -291,7 +411,7 @@ def calculate_completion_cost(completion: str, model: str) -> Decimal:
291411 else :
292412 completion_tokens = count_string_tokens (completion , model )
293413
294- return calculate_cost_by_tokens (completion_tokens , model , "output" )
414+ return calculate_cost_by_tokens (completion_tokens , pricing_model , "output" )
295415
296416
297417def calculate_all_costs_and_tokens (
@@ -322,7 +442,7 @@ def calculate_all_costs_and_tokens(
322442 else count_message_tokens (prompt , model )
323443 )
324444
325- if "claude-" in model :
445+ if "claude-" in model and not model . startswith ( "anthropic." ) :
326446 logger .warning ("Warning: Token counting is estimated for " )
327447 completion_list = [{"role" : "assistant" , "content" : completion }]
328448 # Anthropic appends some 13 additional tokens to the actual completion tokens
0 commit comments