diff --git a/src/epstats/server/req.py b/src/epstats/server/req.py index 698fc69..87628cc 100644 --- a/src/epstats/server/req.py +++ b/src/epstats/server/req.py @@ -153,18 +153,30 @@ class Filter(BaseModel): Filter specification for data to evaluate. """ - dimension: str = Field(..., title="Name of the dimension") - - value: List[Any] = Field(..., title="List of possible values") - scope: FilterScope = Field( ..., title="Scope of the filter", description="Scope of the filter is either `exposure` or `goal`.", ) + dimension: Optional[str] = Field(None, title="Name of the dimension") + + value: Optional[List[Any]] = Field(None, title="List of possible values") + + goal: Optional[str] = Field(None, title="Specify goals if filter scope is `trigger`") + + @model_validator(mode="after") + def check_trigger(self): + if self.scope == FilterScope.trigger: + if not self.goal: + raise ValueError("Goal is required for scope 'trigger'") + else: + if not (self.dimension and self.value): + raise ValueError("Dimension and value are required for this scope") + return self + def to_filter(self): - return EvFilter(self.dimension, self.value, self.scope) + return EvFilter(self.scope, self.dimension, self.value, self.goal) class Experiment(BaseModel): diff --git a/src/epstats/toolkit/experiment.py b/src/epstats/toolkit/experiment.py index 98fc5fe..79c50ff 100644 --- a/src/epstats/toolkit/experiment.py +++ b/src/epstats/toolkit/experiment.py @@ -112,6 +112,7 @@ class FilterScope(str, Enum): exposure = "exposure" goal = "goal" + trigger = "trigger" @dataclass @@ -120,9 +121,18 @@ class Filter: Filter specification for data to evaluate. """ - dimension: str - value: List[Any] scope: FilterScope + dimension: Optional[str] = None + value: Optional[List[Any]] = None + goal: Optional[str] = None + + def __post_init__(self): + if self.scope == FilterScope.trigger: + if not self.goal: + raise ValueError("Trigger scope requires goal") + else: + if not (self.dimension and self.value): + raise ValueError("Dimension and value are required for this scope") class Experiment: diff --git a/tests/epstats/server/test_req.py b/tests/epstats/server/test_req.py index 8926ee6..96ac9d6 100644 --- a/tests/epstats/server/test_req.py +++ b/tests/epstats/server/test_req.py @@ -257,3 +257,60 @@ def test_validate_date_for_between_date_to_and_date_for(): json = resp.json() assert json["detail"][0]["loc"][0] == "body" assert json["detail"][0]["type"] == "value_error" + + +def test_filter_scope_trigger_empty_goal(): + json_blob = { + "id": "test-trigger", + "control_variant": "a", + "variants": ["a", "b"], + "unit_type": "test_unit_type", + "filters": [ + {"dimension": "element", "value": ["button-1"], "scope": "trigger"}, + ], + "metrics": [], + "checks": [], + } + + resp = client.post("/evaluate", json=json_blob) + assert resp.status_code == 422 + json = resp.json() + assert json["detail"][0]["loc"][0] == "body" + assert json["detail"][0]["type"] == "value_error" + + +def test_filter_scope_trigger_empty_dimension(): + json_blob = { + "id": "test-trigger", + "control_variant": "a", + "variants": ["a", "b"], + "unit_type": "test_unit_type", + "filters": [ + {"dimension": None, "value": [], "scope": "trigger", "goal": "view"}, + ], + "metrics": [], + "checks": [], + } + + resp = client.post("/evaluate", json=json_blob) + assert resp.status_code == 200 + + +def test_filter_scope_exposure_empty_dimension(): + json_blob = { + "id": "test-trigger", + "control_variant": "a", + "variants": ["a", "b"], + "unit_type": "test_unit_type", + "filters": [ + {"dimension": None, "value": [], "scope": "exposure"}, + ], + "metrics": [], + "checks": [], + } + + resp = client.post("/evaluate", json=json_blob) + assert resp.status_code == 422 + json = resp.json() + assert json["detail"][0]["loc"][0] == "body" + assert json["detail"][0]["type"] == "value_error" diff --git a/tests/epstats/toolkit/test_experiment.py b/tests/epstats/toolkit/test_experiment.py index dd04b55..26f27e3 100644 --- a/tests/epstats/toolkit/test_experiment.py +++ b/tests/epstats/toolkit/test_experiment.py @@ -410,7 +410,7 @@ def test_filter_scope_goal(dao, metrics, checks, unit_type): [SrmCheck(1, "SRM", "count(test_unit_type.global.exposure)")], unit_type=unit_type, variants=["a", "b"], - filters=[Filter("element", ["button-1"], FilterScope.goal)], + filters=[Filter(FilterScope.goal, "element", ["button-1"])], ) evaluate_experiment_agg(experiment, dao)