Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions tests/trace/test_call_object.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,4 +24,9 @@ def greet(name: str, age: int) -> str:
"thread_id": call.thread_id,
"turn_id": call.turn_id,
"project_id": call.project_id,
"wb_run_id": call.wb_run_id,
"wb_run_step": call.wb_run_step,
"wb_run_step_end": call.wb_run_step_end,
"storage_size_bytes": call.storage_size_bytes,
"total_storage_size_bytes": call.total_storage_size_bytes,
}
31 changes: 28 additions & 3 deletions tests/trace/test_weave_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@
get_serializer_for_obj,
register_serializer,
)
from weave.trace.wandb_run_context import WandbRunContext
from weave.trace_server.clickhouse_trace_server_batched import NotFoundError
from weave.trace_server.common_interface import SortBy
from weave.trace_server.constants import MAX_DISPLAY_NAME_LENGTH
Expand Down Expand Up @@ -4054,9 +4055,6 @@ def test_table_create_from_digests(network_proxy_client):
def test_calls_query_with_wb_run_id_not_null(client, monkeypatch):
"""Test optimized stats query for wb_run_id not null."""
# Mock wandb to simulate a run
from weave.trace import weave_client
from weave.trace.wandb_run_context import WandbRunContext

mock_run_id = f"{client._project_id()}/test_run_123"
monkeypatch.setattr(
weave_client,
Expand All @@ -4079,6 +4077,33 @@ def test_op(x: int) -> int:
assert len(calls) == 1
assert calls[0].wb_run_id == mock_run_id


def test_get_calls_columns_wb_run_id(client, monkeypatch):
mock_run_id = f"{client._project_id()}/test_run_456"
monkeypatch.setattr(
weave_client,
"get_global_wb_run_context",
lambda: WandbRunContext(run_id="test_run_456", step=7),
)

@weave.op
def test_op(x: int) -> int:
return x * 3

_, call = test_op.call(2)
client.flush()

calls = list(
client.get_calls(
columns=["wb_run_id"],
filter=tsi.CallsFilter(call_ids=[call.id]),
)
)

assert len(calls) == 1
assert hasattr(calls[0], "wb_run_id")
assert calls[0].wb_run_id == mock_run_id

# Now query for calls with wb_run_id not null using limit=1 to trigger optimization
query = tsi.Query(
**{
Expand Down
10 changes: 10 additions & 0 deletions weave/trace/call.py
Original file line number Diff line number Diff line change
Expand Up @@ -286,6 +286,11 @@ def to_dict(self) -> CallDict:
deleted_at=self.deleted_at,
thread_id=self.thread_id,
turn_id=self.turn_id,
wb_run_id=self.wb_run_id,
wb_run_step=self.wb_run_step,
wb_run_step_end=self.wb_run_step_end,
storage_size_bytes=self.storage_size_bytes,
total_storage_size_bytes=self.total_storage_size_bytes,
)


Expand Down Expand Up @@ -313,6 +318,11 @@ class CallDict(TypedDict):
deleted_at: datetime.datetime | None
thread_id: str | None
turn_id: str | None
wb_run_id: str | None
wb_run_step: int | None
wb_run_step_end: int | None
storage_size_bytes: int | None
total_storage_size_bytes: int | None


CallsIter = PaginatedIterator[CallSchema, WeaveObject]
Expand Down
Loading