60 changes: 60 additions & 0 deletions tests/flow/test_evaluation_imperative.py
@@ -606,6 +606,66 @@ def test_evaluation_logger_with_custom_attributes(client):
assert calls[0].attributes["custom_attribute"] == "value"


def test_evaluation_logger_prediction_metadata(client):
ev = weave.EvaluationLogger()
pred = ev.log_prediction(
inputs={"a": 1},
output=2,
metadata={"prediction_metadata": "value"},
)
pred.finish()
ev.finish()
client.flush()

calls = client.get_calls()
predict_and_score_call = next(
c for c in calls if op_name_from_call(c) == "Evaluation.predict_and_score"
)
predict_call = next(c for c in calls if op_name_from_call(c) == "Model.predict")

assert predict_and_score_call.attributes["prediction_metadata"] == "value"
assert predict_call.attributes["prediction_metadata"] == "value"


def test_evaluation_logger_example_metadata(client):
ev = weave.EvaluationLogger()
ev.log_example(
inputs={"a": 1},
output=2,
scores={"correctness": 1.0},
metadata={"prediction_metadata": "value"},
)
ev.finish()
client.flush()

calls = client.get_calls()
predict_and_score_call = next(
c for c in calls if op_name_from_call(c) == "Evaluation.predict_and_score"
)
predict_call = next(c for c in calls if op_name_from_call(c) == "Model.predict")

assert predict_and_score_call.attributes["prediction_metadata"] == "value"
assert predict_call.attributes["prediction_metadata"] == "value"


def test_evaluation_logger_score_metadata(client):
ev = weave.EvaluationLogger()
with ev.log_prediction(inputs={"a": 1}, output=2) as pred:
pred.log_score("correctness", 1.0, metadata={"score_metadata": "direct"})
with pred.log_score("quality", metadata={"score_metadata": "context"}) as score:
score.value = 0.9

ev.finish()
client.flush()

calls = client.get_calls()
correctness_call = next(c for c in calls if op_name_from_call(c) == "correctness")
quality_call = next(c for c in calls if op_name_from_call(c) == "quality")

assert correctness_call.attributes["score_metadata"] == "direct"
assert quality_call.attributes["score_metadata"] == "context"


def test_evaluation_logger_uses_passed_output_not_model_predict(client):
"""Test that EvaluationLogger uses the passed output instead of calling model.predict.

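For readers skimming the new API surface, here is a minimal usage sketch of the metadata plumbing these tests exercise; the `weave.init` call and project name are assumptions standing in for the test `client` fixture, everything else mirrors the diff:

```python
import weave

# Hypothetical project name; the tests above use a client fixture instead.
weave.init("metadata-demo")

ev = weave.EvaluationLogger()

# Prediction-level metadata is attached as attributes on both the
# Evaluation.predict_and_score call and the Model.predict call.
pred = ev.log_prediction(
    inputs={"a": 1},
    output=2,
    metadata={"prediction_metadata": "value"},
)

# Score-level metadata is attached as attributes on the scorer call,
# in both the direct and the context-manager form.
pred.log_score("correctness", 1.0, metadata={"score_metadata": "direct"})
with pred.log_score("quality", metadata={"score_metadata": "context"}) as score:
    score.value = 0.9

pred.finish()
ev.finish()
```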
95 changes: 71 additions & 24 deletions weave/evaluation/eval_imperative.py
@@ -305,7 +305,8 @@ class ScoreLogger:
"""Interface for logging scores and managing prediction outputs.

This class is returned by `EvaluationLogger.log_prediction()` and can be used
either directly or as a context manager.
either directly or as a context manager. Prediction-level metadata can be
attached via `EvaluationLogger.log_prediction(..., metadata=...)`.

Direct usage - when output is known upfront:

@@ -413,7 +414,11 @@ def _prepare_scorer(self, scorer: Scorer | dict | str) -> Scorer:

return scorer

def _create_score_call(self, scorer: Scorer | dict | str) -> tuple[Call, Scorer]:
def _create_score_call(
self,
scorer: Scorer | dict | str,
metadata: dict[str, Any] | None = None,
) -> tuple[Call, Scorer]:
"""Create a score call but don't finish it yet."""
scorer = self._prepare_scorer(scorer)

@@ -425,18 +430,23 @@ def score_method(self: Scorer, *, output: Any, inputs: Any) -> ScoreType:
scorer.__dict__["score"] = MethodType(score_method, scorer)

# Create the score call with predict_and_score as parent
with attributes(IMPERATIVE_SCORE_MARKER):
wc = require_weave_client()
score_call = wc.create_call(
as_op(scorer.score),
inputs={
"self": scorer,
"output": self._predict_output,
"inputs": self.predict_call.inputs,
},
parent=self.predict_and_score_call,
use_stack=False,
)
score_attributes = (
IMPERATIVE_SCORE_MARKER
if metadata is None
else metadata | IMPERATIVE_SCORE_MARKER
)
wc = require_weave_client()
score_call = wc.create_call(
as_op(scorer.score),
inputs={
"self": scorer,
"output": self._predict_output,
"inputs": self.predict_call.inputs,
},
parent=self.predict_and_score_call,
attributes=score_attributes,
use_stack=False,
)

return score_call, scorer

@@ -458,22 +468,30 @@ def log_score(
self,
scorer: Scorer | dict | str,
score: ScoreType,
*,
metadata: dict[str, Any] | None = None,
) -> None: ...

@overload
def log_score(
self,
scorer: Scorer | dict | str,
score: _NotSetType = NOT_SET,
*,
metadata: dict[str, Any] | None = None,
) -> _LogScoreContext: ...

def log_score(
self,
scorer: Scorer | dict | str,
score: ScoreType | _NotSetType = NOT_SET,
*,
metadata: dict[str, Any] | None = None,
) -> _LogScoreContext | None:
"""Log a score synchronously or return a context manager for deferred scoring.

Metadata can be attached to the score call via the ``metadata`` argument.

Can be used in two ways:

1. Direct scoring (immediate):
@@ -484,14 +502,16 @@ def log_score(

2. Context manager (deferred with automatic call stack):
```python
with pred.log_score("correctness") as score_ctx:
with pred.log_score("correctness", metadata={"source": "review"}) as score_ctx:
result = calculate_score(...)
score_ctx.value = result
```
"""
# If no score provided, return a context manager for deferred scoring
if score is NOT_SET:
score_call, prepared_scorer = self._create_score_call(scorer)
score_call, prepared_scorer = self._create_score_call(
scorer, metadata=metadata
)
return _LogScoreContext(self, prepared_scorer, score_call)

# Type narrowing: score is now guaranteed to be ScoreType
@@ -505,7 +525,9 @@ def log_score(
loop = asyncio.get_running_loop()
if asyncio.current_task() is None:
# We're not in an async context, but a loop exists
return loop.run_until_complete(self.alog_score(scorer, score_value))
return loop.run_until_complete(
self.alog_score(scorer, score_value, metadata=metadata)
)

# We're in an async context, we need to handle this differently
result = None
@@ -519,7 +541,7 @@ def run_in_new_loop() -> None:
asyncio.set_event_loop(new_loop)
try:
result = new_loop.run_until_complete(
self.alog_score(scorer, score_value)
self.alog_score(scorer, score_value, metadata=metadata)
)
finally:
new_loop.close()
@@ -536,12 +558,14 @@ def run_in_new_loop() -> None:
return result
except RuntimeError:
# No event loop exists, create one with asyncio.run
return asyncio.run(self.alog_score(scorer, score_value))
return asyncio.run(self.alog_score(scorer, score_value, metadata=metadata))

async def alog_score(
self,
scorer: Scorer | dict | str,
score: ScoreType,
*,
metadata: dict[str, Any] | None = None,
) -> None:
if self._has_finished:
raise ValueError("Cannot log score after finish has been called")
@@ -561,7 +585,12 @@ def score_method(self: Scorer, *, output: Any, inputs: Any) -> ScoreType:
[self.evaluate_call, self.predict_and_score_call]
):
with _set_current_score(score):
with attributes(IMPERATIVE_SCORE_MARKER):
score_attributes = (
IMPERATIVE_SCORE_MARKER
if metadata is None
else metadata | IMPERATIVE_SCORE_MARKER
)
with attributes(score_attributes):
await self.predict_call.apply_scorer(scorer)

# this is always true because of how the scorer is created in the validator
@@ -788,14 +817,20 @@ def _finalize_evaluation(

self._is_finalized = True

def log_prediction(self, inputs: dict[str, Any], output: Any = None) -> ScoreLogger:
def log_prediction(
self,
inputs: dict[str, Any],
output: Any = None,
metadata: dict[str, Any] | None = None,
) -> ScoreLogger:
"""Log a prediction to the Evaluation.

Returns a ScoreLogger that can be used directly or as a context manager.

Args:
inputs: The input data for the prediction
output: The output value. Defaults to None. Can be set later using pred.output.
metadata: Optional metadata to attach to the prediction calls as attributes.

Returns:
ScoreLogger for logging scores and optionally finishing the prediction.
@@ -820,11 +855,17 @@ def log_prediction(self, inputs: dict[str, Any], output: Any = None) -> ScoreLog
original_method = self.model.__dict__.get("predict")
self.model.__dict__["predict"] = self._context_predict_method

prediction_attributes = (
IMPERATIVE_EVAL_MARKER
if metadata is None
else metadata | IMPERATIVE_EVAL_MARKER
)

try:
with call_context.set_call_stack([self._evaluate_call]):
# Make the prediction call
with _set_current_output(output):
with attributes(IMPERATIVE_EVAL_MARKER):
with attributes(prediction_attributes):
_, predict_and_score_call = (
self._pseudo_evaluation.predict_and_score.call(
self._pseudo_evaluation,
@@ -861,7 +902,12 @@ def log_prediction(self, inputs: dict[str, Any], output: Any = None) -> ScoreLog
return pred

def log_example(
self, inputs: dict[str, Any], output: Any, scores: dict[str, ScoreType]
self,
inputs: dict[str, Any],
output: Any,
scores: dict[str, ScoreType],
*,
metadata: dict[str, Any] | None = None,
) -> None:
"""Log a complete example with inputs, output, and scores.

@@ -872,6 +918,7 @@ def log_example(
inputs: The input data for the prediction
output: The output value
scores: Dictionary mapping scorer names to score values
metadata: Optional metadata to attach to the prediction calls as attributes.

Example:
```python
@@ -890,7 +937,7 @@ def log_example(
)

# Log the prediction with the output
pred = self.log_prediction(inputs=inputs, output=output)
pred = self.log_prediction(inputs=inputs, output=output, metadata=metadata)

# Log all the scores
for scorer_name, score_value in scores.items():
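A note on the merge expressions above (`metadata | IMPERATIVE_SCORE_MARKER` and `metadata | IMPERATIVE_EVAL_MARKER`): the marker dict is the right-hand operand of the union, so on a key collision the marker value wins and user metadata cannot overwrite it. A standalone sketch of that precedence, using a made-up marker value since the real constant is not shown in this diff:

```python
# Hypothetical stand-in for IMPERATIVE_SCORE_MARKER; the real value is defined
# elsewhere in eval_imperative.py and is not part of this diff.
IMPERATIVE_SCORE_MARKER = {"_weave_eval_meta": {"imperative": True, "score": True}}

metadata = {"score_metadata": "direct", "_weave_eval_meta": "user-supplied"}

# dict union: the right-hand operand takes precedence on key collisions,
# so the marker key survives even if the caller passes a colliding key.
merged = metadata | IMPERATIVE_SCORE_MARKER
assert merged["_weave_eval_meta"] == {"imperative": True, "score": True}
assert merged["score_metadata"] == "direct"
```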