Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add aggregation_units to metric and datasource #553

Merged
merged 7 commits into from
Aug 20, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions lib/metric-config-parser/metric_config_parser/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
from enum import Enum


class AnalysisUnit(Enum):
    """Unit of aggregation for metrics and data sources.

    Each member's value is the column name that rows are grouped by when
    computing a metric: per client or per profile group. Data sources declare
    which units they support and metrics declare which units they need; the
    two are cross-checked at resolve time.
    """

    # Aggregate per individual client.
    CLIENT = "client_id"
    # Aggregate per profile group (a set of related profiles).
    PROFILE_GROUP = "profile_group_id"
22 changes: 20 additions & 2 deletions lib/metric-config-parser/metric_config_parser/data_source.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from .definition import DefinitionSpecSub
from .project import ProjectConfiguration

from . import AnalysisUnit
from .util import converter, is_valid_slug


Expand Down Expand Up @@ -77,6 +78,16 @@ class DataSource:
`{dataset}` in from_expr if a value is not provided
at runtime. Mandatory if from_expr contains a
`{dataset}` parameter.
build_id_column (str, optional):
Default 'SAFE.SUBSTR(application.build_id, 0, 8)'.
friendly_name (str, optional)
description (str, optional)
joins (list[DataSourceJoin], optional)
columns_as_dimensions (bool, optional): Default false.
analysis_units (list[AnalysisUnit], optional): denotes which
aggregations are supported by this data_source. At time
of writing, this means 'client_id', 'profile_group_id',
or both. Defaults to 'client_id'.
"""

name = attr.ib(validator=attr.validators.instance_of(str))
Expand All @@ -90,6 +101,7 @@ class DataSource:
description = attr.ib(default=None, type=str)
joins = attr.ib(default=None, type=List[DataSourceJoin])
columns_as_dimensions = attr.ib(default=False, type=bool)
analysis_units = attr.ib(default=[AnalysisUnit.CLIENT], type=List[AnalysisUnit])

EXPERIMENT_COLUMN_TYPES = (None, "simple", "native", "glean")

Expand Down Expand Up @@ -162,6 +174,7 @@ class DataSourceDefinition:
description: Optional[str] = None
joins: Optional[Dict[str, Dict[str, Any]]] = None
columns_as_dimensions: Optional[bool] = None
analysis_units: Optional[list[AnalysisUnit]] = None

def resolve(
self,
Expand All @@ -179,7 +192,10 @@ def resolve(
+ "Wildcard characters are only allowed if matching slug is defined."
)

params: Dict[str, Any] = {"name": self.name, "from_expression": self.from_expression}
params: Dict[str, Any] = {
"name": self.name,
"from_expression": self.from_expression,
}
# Allow mozanalysis to infer defaults for these values:
for k in (
"experiments_column_type",
Expand All @@ -190,6 +206,7 @@ def resolve(
"friendly_name",
"description",
"columns_as_dimensions",
"analysis_units",
):
v = getattr(self, k)
if v:
Expand Down Expand Up @@ -243,7 +260,8 @@ class DataSourcesSpec:
def from_dict(cls, d: dict) -> "DataSourcesSpec":
definitions = {
k: converter.structure(
{"name": k, **dict((kk.lower(), vv) for kk, vv in v.items())}, DataSourceDefinition
{"name": k, **dict((kk.lower(), vv) for kk, vv in v.items())},
DataSourceDefinition,
)
for k, v in d.items()
}
Expand Down
37 changes: 31 additions & 6 deletions lib/metric-config-parser/metric_config_parser/metric.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from .definition import DefinitionSpecSub
from .project import ProjectConfiguration

from . import AnalysisUnit
from .data_source import DataSource, DataSourceReference
from .parameter import ParameterDefinition
from .pre_treatment import PreTreatmentReference
Expand Down Expand Up @@ -96,6 +97,7 @@ class Metric:
owner: Optional[List[str]] = None
deprecated: bool = False
level: Optional[MetricLevel] = None
analysis_units: List[AnalysisUnit] = [AnalysisUnit.CLIENT]


@attr.s(auto_attribs=True)
Expand Down Expand Up @@ -147,6 +149,7 @@ class MetricDefinition:
owner: Optional[Union[str, List[str]]] = None
deprecated: bool = False
level: Optional[MetricLevel] = None
analysis_units: Optional[List[AnalysisUnit]] = None

@staticmethod
def generate_select_expression(
Expand Down Expand Up @@ -239,13 +242,21 @@ def resolve(
owner=[self.owner] if isinstance(self.owner, str) else self.owner,
deprecated=self.deprecated,
level=self.level,
analysis_units=self.analysis_units or [AnalysisUnit.CLIENT],
)
elif metric_definition:
metric_definition.analysis_bases = self.analysis_bases or [
AnalysisBasis.ENROLLMENTS,
AnalysisBasis.EXPOSURES,
]
metric_definition.analysis_bases = (
self.analysis_bases
or metric_definition.analysis_bases
or [
AnalysisBasis.ENROLLMENTS,
AnalysisBasis.EXPOSURES,
]
)
metric_definition.statistics = self.statistics
metric_definition.analysis_units = (
self.analysis_units or metric_definition.analysis_units or [AnalysisUnit.CLIENT]
)
metric_summary = metric_definition.resolve(spec, conf, configs)
else:
select_expression = self.generate_select_expression(
Expand All @@ -254,9 +265,21 @@ def resolve(
configs=configs,
)

# ensure all of metric's analysis_units are supported by data_source
resolved_ds = self.data_source.resolve(spec, conf, configs)
analysis_units = self.analysis_units or [AnalysisUnit.CLIENT]
for agg_unit in analysis_units:
if agg_unit not in resolved_ds.analysis_units:
raise ValueError(
f"data_source {resolved_ds.name} does not support "
f"all analysis_units specified by metric {self.name}: "
f"analysis_units for metric: {analysis_units}, "
f"analysis_units for data_source: {resolved_ds.analysis_units}"
)

metric = Metric(
name=self.name,
data_source=self.data_source.resolve(spec, conf, configs),
data_source=resolved_ds,
select_expression=select_expression,
friendly_name=(
dedent(self.friendly_name) if self.friendly_name else self.friendly_name
Expand All @@ -271,6 +294,7 @@ def resolve(
owner=[self.owner] if isinstance(self.owner, str) else self.owner,
deprecated=self.deprecated,
level=self.level,
analysis_units=analysis_units,
)

metrics_with_treatments = []
Expand Down Expand Up @@ -361,7 +385,8 @@ def from_dict(cls, d: dict) -> "MetricsSpec":

params["definitions"] = {
k: converter.structure(
{"name": k, **dict((kk.lower(), vv) for kk, vv in v.items())}, MetricDefinition
{"name": k, **dict((kk.lower(), vv) for kk, vv in v.items())},
MetricDefinition,
)
for k, v in d.items()
if k not in known_keys and k != "28_day"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,7 @@ def local_tmp_repo(tmp_path):
shutil.copytree(TEST_DIR / "data", tmp_path / "metrics")
r.config_writer().set_value("user", "name", "test").release()
r.config_writer().set_value("user", "email", "[email protected]").release()
r.config_writer().set_value("commit", "gpgsign", "false").release()
r.git.add(".")
r.git.commit("-m", "commit")
return tmp_path
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

[metrics.unenroll]
data_source = "normandy_events"
select_expression='''{{agg_any(
select_expression = '''{{agg_any(
"""
event_category = 'normandy'
AND event_method = 'unenroll'
Expand Down Expand Up @@ -67,11 +67,28 @@ data_source = "joined_baseline"
select_expression = "SELECT 1"


[metrics.test_active_hours]
friendly_name = "Test Active hours"
description = """
Contrived example for specific test case.
"""
select_expression = '{{agg_sum("test_active_hours_sum")}}'
data_source = "test_main"
analysis_bases = ["exposures"]
analysis_units = ["profile_group_id"]

[metrics.test_active_hours.statistics.bootstrap_mean]


[data_sources]

[data_sources.main]
from_expression = "(SELECT 1)"

[data_sources.test_main]
from_expression = "(SELECT 1)"
analysis_units = ["client_id", "profile_group_id"]

[data_sources.clients_daily]
from_expression = "mozdata.telemetry.clients_daily"
friendly_name = "Clients Daily"
Expand All @@ -91,7 +108,7 @@ from_expression = """(
FROM mozdata.telemetry.events
WHERE event_category = 'normandy'
)"""
experiments_column_type="native"
experiments_column_type = "native"
friendly_name = "Normandy Events"
description = "Normandy Events"

Expand Down
147 changes: 147 additions & 0 deletions lib/metric-config-parser/metric_config_parser/tests/test_analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from cattrs.errors import ClassValidationError
from mozilla_nimbus_schemas.jetstream import AnalysisBasis

from metric_config_parser import AnalysisUnit
from metric_config_parser.analysis import AnalysisConfiguration, AnalysisSpec
from metric_config_parser.data_source import DataSource
from metric_config_parser.errors import InvalidConfigurationException
Expand Down Expand Up @@ -347,6 +348,7 @@ def test_merge_configs_override_metric(self, experiments, config_collection):
assert spam.metric.data_source.name == "main"
assert spam.metric.select_expression == "2"
assert spam.metric.analysis_bases == [AnalysisBasis.ENROLLMENTS, AnalysisBasis.EXPOSURES]
assert spam.metric.analysis_units == [AnalysisUnit.CLIENT]
assert spam.statistic.name == "bootstrap_mean"
assert spam.statistic.params["num_samples"] == 100

Expand Down Expand Up @@ -763,3 +765,148 @@ def test_multiple_nesting(self, experiments, config_collection):
assert len(metric.depends_on[0].metric.depends_on) == 1
assert metric.depends_on[0].metric.depends_on[0].metric.name == "spam"
assert metric.depends_on[0].metric.depends_on[0].metric.select_expression == "1"

def test_default_analysis_units(self, experiments, config_collection):
    """A metric and its data source default to client_id analysis units
    when neither declares `analysis_units` explicitly."""
    config_str = dedent(
        """
        [metrics]
        weekly = ["active_hours"]

        [metrics.active_hours.statistics.bootstrap_mean]
        """
    )
    spec = AnalysisSpec.from_dict(toml.loads(config_str))
    configuration = spec.resolve(experiments[0], config_collection)

    # Pull the resolved metric out of the weekly period.
    resolved = next(
        summary.metric
        for summary in configuration.metrics[AnalysisPeriod.WEEK]
        if summary.metric.name == "active_hours"
    )

    assert resolved.analysis_units == [AnalysisUnit.CLIENT]
    assert resolved.data_source.analysis_units == [AnalysisUnit.CLIENT]

def test_no_override_defined_analysis_units_with_defaults(self, experiments, config_collection):
    """Resolving a partial override must keep the upstream definition's
    analysis_bases/analysis_units instead of clobbering them with defaults."""
    custom_conf = dedent(
        """
        [metrics]
        weekly = ["test_active_hours"]

        [metrics.test_active_hours]
        friendly_name = "Overridden Active Hours"

        [metrics.test_active_hours.statistics.bootstrap_mean]
        """
    )
    spec = AnalysisSpec.from_dict(toml.loads(custom_conf))
    configuration = spec.resolve(experiments[0], config_collection)

    weekly = configuration.metrics[AnalysisPeriod.WEEK]
    assert len(weekly) == 1

    summary = next(s for s in weekly if s.metric.name == "test_active_hours")
    # Values below come from the upstream definition, not the override.
    assert summary.metric.data_source.name == "test_main"
    assert summary.metric.analysis_bases == [AnalysisBasis.EXPOSURES]
    assert summary.metric.analysis_units == [AnalysisUnit.PROFILE_GROUP]

@pytest.mark.parametrize(
    "metric_units, ds_units",
    (
        (
            "analysis_units = ['profile_group_id']",
            "analysis_units = ['profile_group_id']",
        ),
        (
            "analysis_units = ['client_id']",
            "analysis_units = ['client_id']",
        ),
        (
            "analysis_units = ['profile_group_id']",
            "analysis_units = ['client_id', 'profile_group_id']",
        ),
        (
            "analysis_units = ['client_id']",
            "analysis_units = ['profile_group_id', 'client_id']",
        ),
        (
            "analysis_units = ['client_id', 'profile_group_id']",
            "analysis_units = ['profile_group_id', 'client_id']",
        ),
    ),
)
def test_valid_analysis_units_combinations(self, metric_units, ds_units, experiments, config_collection):
    """Resolution succeeds whenever every analysis unit the metric asks for
    is also offered by its data source (order does not matter)."""
    config_str = dedent(
        f"""
        [metrics]
        weekly = ["spam"]

        [metrics.spam]
        data_source = "eggs"
        select_expression = "1"
        {metric_units}

        [metrics.spam.statistics.bootstrap_mean]

        [data_sources.eggs]
        from_expression = "england.camelot"
        client_id_column = "client_info.client_id"
        {ds_units}
        """
    )
    spec = AnalysisSpec.from_dict(toml.loads(config_str))
    configuration = spec.resolve(experiments[0], config_collection)

    resolved = next(
        s.metric for s in configuration.metrics[AnalysisPeriod.WEEK] if s.metric.name == "spam"
    )
    assert resolved.analysis_units is not None
    assert resolved.data_source.analysis_units is not None
    # Every unit requested by the metric must be supported by the data source.
    assert set(resolved.analysis_units).issubset(resolved.data_source.analysis_units)

@pytest.mark.parametrize(
    "metric_units,ds_units",
    (
        ("analysis_units = ['client_id']", "analysis_units = ['profile_group_id']"),
        (
            "analysis_units = ['client_id', 'profile_group_id']",
            "analysis_units = ['profile_group_id']",
        ),
        ("analysis_units = ['profile_group_id']", "analysis_units = ['client_id']"),
        (
            "analysis_units = ['client_id', 'profile_group_id']",
            "analysis_units = ['client_id']",
        ),
    ),
)
def test_invalid_analysis_units_combinations(self, metric_units, ds_units, experiments, config_collection):
    """Resolution raises when the metric requests an analysis unit that its
    data source does not declare support for."""
    config_str = dedent(
        f"""
        [metrics]
        weekly = ["spam"]

        [metrics.spam]
        data_source = "eggs"
        select_expression = "1"
        {metric_units}

        [metrics.spam.statistics.bootstrap_mean]

        [data_sources.eggs]
        from_expression = "england.camelot"
        client_id_column = "client_info.client_id"
        {ds_units}
        """
    )
    spec = AnalysisSpec.from_dict(toml.loads(config_str))

    # The mismatch is detected during metric resolution, not at parse time.
    with pytest.raises(ValueError, match="data_source eggs does not support"):
        spec.resolve(experiments[0], config_collection)
Loading
Loading