Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

chore(weave): pyright: enable reportCallIssue #2881

Draft
wants to merge 15 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,12 @@ repos:
# Note: You have to update pyproject.toml[tool.mypy] too!
args: ["--config-file=pyproject.toml"]
exclude: (.*pyi$)|(weave_query)|(tests)|(examples)
- repo: https://github.com/RobertCraigie/pyright-python
rev: v1.1.387
hooks:
- id: pyright
additional_dependencies: [".[tests]"]

# This is legacy Weave when we were building a notebook product - should be removed
- repo: local
hooks:
Expand Down
5 changes: 5 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -189,6 +189,11 @@ exclude = ["weave_query", "tests", "examples", "docs", "noxfile.py"]
# In cases where we support multiple versions of an integration, some imports can be missing
reportMissingImports = false

# TODO: Gradually remove as we improve our code!
reportAttributeAccessIssue = false
reportPossiblyUnboundVariable = false
reportOptionalMemberAccess = false
reportArgumentType = false

[tool.mypy]
warn_unused_configs = true
Expand Down
2 changes: 1 addition & 1 deletion weave/flow/prompt/prompt.py
Original file line number Diff line number Diff line change
Expand Up @@ -242,7 +242,7 @@ def bind_rows(self, dataset: Union[list[dict], Any]) -> list["Prompt"]:
return bound

@overload
def __getitem__(self, index: SupportsIndex) -> Any: ...
def __getitem__(self, key: SupportsIndex) -> Any: ...

@overload
def __getitem__(self, key: slice) -> "EasyPrompt": ...
Expand Down
10 changes: 5 additions & 5 deletions weave/integrations/cohere/cohere_sdk.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ def _accumulate_content(

def cohere_wrapper(name: str) -> Callable:
def wrapper(fn: Callable) -> Callable:
op = weave.op()(fn)
op = weave.op(fn)
op.name = name # type: ignore
return op

Expand Down Expand Up @@ -122,7 +122,7 @@ def _wrapper(*args: Any, **kwargs: Any) -> Any:

return _wrapper

op = weave.op()(_post_process_response(fn))
op = weave.op(_post_process_response(fn))
op.name = name # type: ignore
return op

Expand Down Expand Up @@ -156,7 +156,7 @@ async def _wrapper(*args: Any, **kwargs: Any) -> Any:

return _wrapper

op = weave.op()(_post_process_response(fn))
op = weave.op(_post_process_response(fn))
op.name = name # type: ignore
return op

Expand All @@ -165,7 +165,7 @@ async def _wrapper(*args: Any, **kwargs: Any) -> Any:

def cohere_stream_wrapper(name: str) -> Callable:
def wrapper(fn: Callable) -> Callable:
op = weave.op()(fn)
op = weave.op(fn)
op.name = name # type: ignore
return add_accumulator(op, lambda inputs: cohere_accumulator) # type: ignore

Expand All @@ -174,7 +174,7 @@ def wrapper(fn: Callable) -> Callable:

def cohere_stream_wrapper_v2(name: str) -> Callable:
def wrapper(fn: Callable) -> Callable:
op = weave.op()(fn)
op = weave.op(fn)
op.name = name # type: ignore
return add_accumulator(
op, make_accumulator=lambda inputs: cohere_accumulator_v2
Expand Down
8 changes: 4 additions & 4 deletions weave/integrations/instructor/instructor_iterable_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,10 @@


def instructor_iterable_accumulator(
acc: Optional[BaseModel], value: BaseModel
acc: Optional[list[BaseModel]], value: BaseModel
) -> list[BaseModel]:
if acc is None:
acc = [value]
return [value]
if acc[-1] != value:
acc.append(value)
return acc
Expand All @@ -29,7 +29,7 @@ def should_accumulate_iterable(inputs: dict) -> bool:

def instructor_wrapper_sync(name: str) -> Callable[[Callable], Callable]:
def wrapper(fn: Callable) -> Callable:
op = weave.op()(fn)
op = weave.op(fn)
op.name = name # type: ignore
return add_accumulator(
op, # type: ignore
Expand All @@ -50,7 +50,7 @@ async def _async_wrapper(*args: Any, **kwargs: Any) -> Any:
return _async_wrapper

"We need to do this so we can check if `stream` is used"
op = weave.op()(_fn_wrapper(fn))
op = weave.op(_fn_wrapper(fn))
op.name = name # type: ignore
return add_accumulator(
op, # type: ignore
Expand Down
8 changes: 5 additions & 3 deletions weave/integrations/langchain/langchain.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@
except ImportError:
import_failed = True

from typing import Any, Dict, Generator, List, Optional
from typing import Any, Dict, Generator, List, Optional, cast

RUNNABLE_SEQUENCE_NAME = "RunnableSequence"

Expand Down Expand Up @@ -89,7 +89,7 @@ def _run_to_dict(run: Run, as_input: bool = False) -> dict:
run_dict = {k: v for k, v in run_dict.items() if v}
return run_dict

class WeaveTracer(BaseTracer):
class WeaveTracer(BaseTracer): # pyright: ignore[reportRedeclaration]
run_inline: bool = True

def __init__(self, **kwargs: Any) -> None:
Expand Down Expand Up @@ -182,7 +182,9 @@ def _persist_run_single(self, run: Run) -> None:
# Note: this is implemented as a network call - it would be much nice
# to refactor `create_call` such that it could accept a parent_id instead
# of an entire Parent object.
parent_run = self.gc.get_call(wv_current_run.parent_id)
parent_run = cast(
Call, self.gc.get_call(wv_current_run.parent_id)
)

fn_name = make_pythonic_function_name(run.name)
complete_op_name = f"langchain.{run.run_type.capitalize()}.{fn_name}"
Expand Down
2 changes: 1 addition & 1 deletion weave/integrations/llamaindex/llamaindex.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@

if not import_failed:

class WeaveCallbackHandler(BaseCallbackHandler):
class WeaveCallbackHandler(BaseCallbackHandler): # pyright: ignore[reportRedeclaration]
"""Base callback handler that can be used to track event starts and ends."""

def __init__(
Expand Down
5 changes: 3 additions & 2 deletions weave/integrations/notdiamond/custom_router.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,7 @@ def _get_model_results(provider_name: str) -> pd.DataFrame:
class _DummyEvalModel(weave.Model):
model_results: pd.DataFrame

@weave.op
def predict(self, prompt: str) -> Dict[str, Any]:
response, score = self.model_results[
self.model_results[prompt_column] == prompt
Expand All @@ -92,12 +93,12 @@ def predict(self, prompt: str) -> Dict[str, Any]:
class BestRoutedModel(_DummyEvalModel):
model_name: str

@weave.op()
@weave.op
def predict(self, prompt: str) -> Dict[str, Any]:
return super().predict(prompt)

class NotDiamondRoutedModel(_DummyEvalModel):
@weave.op()
@weave.op
def predict(self, prompt: str) -> Dict[str, Any]:
return super().predict(prompt)

Expand Down
2 changes: 2 additions & 0 deletions weave/scorers/base_scorer.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ class Scorer(Object):
description="A mapping from column names in the dataset to the names expected by the scorer",
)

@weave.op
def score(self, *, output: Any, **kwargs: Any) -> Any:
raise NotImplementedError

Expand Down Expand Up @@ -87,6 +88,7 @@ def auto_summarize(data: list) -> Optional[dict[str, Any]]:
def get_scorer_attributes(
scorer: Union[Callable, Op, Scorer],
) -> Tuple[str, Callable, Callable]:
score_fn: Union[Op, Callable[..., Any]]
if weave_isinstance(scorer, Scorer):
scorer_name = scorer.name
if scorer_name is None:
Expand Down
4 changes: 2 additions & 2 deletions weave/trace/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ def cli() -> None:
pass


@cli.command(help="Serve weave models.")
@cli.command(help="Serve weave models.") # type: ignore
@click.argument("model_ref")
@click.option("--method", help="Method name to serve.")
@click.option("--project", help="W&B project name.")
Expand Down Expand Up @@ -54,7 +54,7 @@ def serve(
)


@cli.group(help="Deploy weave models.")
@cli.group(help="Deploy weave models.") # type: ignore
def deploy() -> None:
pass

Expand Down
2 changes: 1 addition & 1 deletion weave/trace/op.py
Original file line number Diff line number Diff line change
Expand Up @@ -611,7 +611,7 @@ def create_wrapper(func: Callable) -> Op:
if is_async:

@wraps(func)
async def wrapper(*args: Any, **kwargs: Any) -> Any:
async def wrapper(*args: Any, **kwargs: Any) -> Any: # pyright: ignore[reportRedeclaration]
res, _ = await _do_call_async(
cast(Op, wrapper), *args, __should_raise=True, **kwargs
)
Expand Down
3 changes: 2 additions & 1 deletion weave/trace/weave_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -267,7 +267,8 @@ def children(self) -> "CallsIter":
def delete(self) -> bool:
"""Delete the call."""
client = weave_client_context.require_weave_client()
return client.delete_call(call=self)
client.delete_call(call=self)
return True

def set_display_name(self, name: Optional[str]) -> None:
"""
Expand Down
Loading