feat: add aggregation_units to metric and datasource (#553)
mikewilli authored Aug 20, 2024
1 parent 6194422 commit a931259
Showing 10 changed files with 475 additions and 169 deletions.
6 changes: 6 additions & 0 deletions lib/metric-config-parser/metric_config_parser/__init__.py
@@ -0,0 +1,6 @@
from enum import Enum


class AnalysisUnit(Enum):
    CLIENT = "client_id"
    PROFILE_GROUP = "profile_group_id"
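
For orientation: each member's value is the ID column that aggregations key on, so configuration files can name units by column. A minimal sketch of how the enum round-trips from config strings (standard Enum behavior, not code from this commit):

from metric_config_parser import AnalysisUnit

# Members map to the ID columns used for aggregation.
assert AnalysisUnit.CLIENT.value == "client_id"
assert AnalysisUnit.PROFILE_GROUP.value == "profile_group_id"

# Config strings such as "profile_group_id" structure back into members by value.
assert AnalysisUnit("profile_group_id") is AnalysisUnit.PROFILE_GROUP
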
22 changes: 20 additions & 2 deletions lib/metric-config-parser/metric_config_parser/data_source.py
@@ -13,6 +13,7 @@
from .definition import DefinitionSpecSub
from .project import ProjectConfiguration

from . import AnalysisUnit
from .util import converter, is_valid_slug


@@ -77,6 +78,16 @@ class DataSource:
            `{dataset}` in from_expr if a value is not provided
            at runtime. Mandatory if from_expr contains a
            `{dataset}` parameter.
        build_id_column (str, optional):
            Default 'SAFE.SUBSTR(application.build_id, 0, 8)'.
        friendly_name (str, optional)
        description (str, optional)
        joins (list[DataSourceJoin], optional)
        columns_as_dimensions (bool, optional): Default false.
        analysis_units (list[AnalysisUnit], optional): denotes which
            aggregations are supported by this data_source. At time
            of writing, this means 'client_id', 'profile_group_id',
            or both. Defaults to 'client_id'.
    """

    name = attr.ib(validator=attr.validators.instance_of(str))
@@ -90,6 +101,7 @@ class DataSource:
    description = attr.ib(default=None, type=str)
    joins = attr.ib(default=None, type=List[DataSourceJoin])
    columns_as_dimensions = attr.ib(default=False, type=bool)
    analysis_units = attr.ib(default=[AnalysisUnit.CLIENT], type=List[AnalysisUnit])

    EXPERIMENT_COLUMN_TYPES = (None, "simple", "native", "glean")

@@ -162,6 +174,7 @@ class DataSourceDefinition:
    description: Optional[str] = None
    joins: Optional[Dict[str, Dict[str, Any]]] = None
    columns_as_dimensions: Optional[bool] = None
    analysis_units: Optional[list[AnalysisUnit]] = None

    def resolve(
        self,
@@ -179,7 +192,10 @@ def resolve(
+ "Wildcard characters are only allowed if matching slug is defined."
)

params: Dict[str, Any] = {"name": self.name, "from_expression": self.from_expression}
params: Dict[str, Any] = {
"name": self.name,
"from_expression": self.from_expression,
}
# Allow mozanalysis to infer defaults for these values:
for k in (
"experiments_column_type",
@@ -190,6 +206,7 @@ def resolve(
"friendly_name",
"description",
"columns_as_dimensions",
"analysis_units",
):
v = getattr(self, k)
if v:
@@ -243,7 +260,8 @@ class DataSourcesSpec:
def from_dict(cls, d: dict) -> "DataSourcesSpec":
definitions = {
k: converter.structure(
{"name": k, **dict((kk.lower(), vv) for kk, vv in v.items())}, DataSourceDefinition
{"name": k, **dict((kk.lower(), vv) for kk, vv in v.items())},
DataSourceDefinition,
)
for k, v in d.items()
}
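
Usage sketch for the new field: cattrs structures the TOML strings into AnalysisUnit members by value, and a definition that omits analysis_units is left unset, so the [AnalysisUnit.CLIENT] attrs default applies when the definition resolves into a DataSource. The `definitions` attribute below is assumed from the from_dict pattern above, not confirmed by this diff:

import toml

from metric_config_parser import AnalysisUnit
from metric_config_parser.data_source import DataSourcesSpec

raw = toml.loads(
    """
    [main]
    from_expression = "(SELECT 1)"

    [test_main]
    from_expression = "(SELECT 1)"
    analysis_units = ["client_id", "profile_group_id"]
    """
)
spec = DataSourcesSpec.from_dict(raw)

# Assumed attribute: from_dict builds a `definitions` dict of DataSourceDefinitions.
assert spec.definitions["test_main"].analysis_units == [
    AnalysisUnit.CLIENT,
    AnalysisUnit.PROFILE_GROUP,
]
# `main` leaves analysis_units as None; resolve() falls back to the
# [AnalysisUnit.CLIENT] default on DataSource.
assert spec.definitions["main"].analysis_units is None
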
37 changes: 31 additions & 6 deletions lib/metric-config-parser/metric_config_parser/metric.py
@@ -19,6 +19,7 @@
from .definition import DefinitionSpecSub
from .project import ProjectConfiguration

from . import AnalysisUnit
from .data_source import DataSource, DataSourceReference
from .parameter import ParameterDefinition
from .pre_treatment import PreTreatmentReference
@@ -96,6 +97,7 @@ class Metric:
    owner: Optional[List[str]] = None
    deprecated: bool = False
    level: Optional[MetricLevel] = None
    analysis_units: List[AnalysisUnit] = [AnalysisUnit.CLIENT]


@attr.s(auto_attribs=True)
@@ -147,6 +149,7 @@ class MetricDefinition:
    owner: Optional[Union[str, List[str]]] = None
    deprecated: bool = False
    level: Optional[MetricLevel] = None
    analysis_units: Optional[List[AnalysisUnit]] = None

    @staticmethod
    def generate_select_expression(
@@ -239,13 +242,21 @@ def resolve(
                owner=[self.owner] if isinstance(self.owner, str) else self.owner,
                deprecated=self.deprecated,
                level=self.level,
                analysis_units=self.analysis_units or [AnalysisUnit.CLIENT],
            )
        elif metric_definition:
            metric_definition.analysis_bases = self.analysis_bases or [
                AnalysisBasis.ENROLLMENTS,
                AnalysisBasis.EXPOSURES,
            ]
            metric_definition.analysis_bases = (
                self.analysis_bases
                or metric_definition.analysis_bases
                or [
                    AnalysisBasis.ENROLLMENTS,
                    AnalysisBasis.EXPOSURES,
                ]
            )
            metric_definition.statistics = self.statistics
            metric_definition.analysis_units = (
                self.analysis_units or metric_definition.analysis_units or [AnalysisUnit.CLIENT]
            )
            metric_summary = metric_definition.resolve(spec, conf, configs)
        else:
            select_expression = self.generate_select_expression(
@@ -254,9 +265,21 @@ def resolve(
                configs=configs,
            )

            # ensure all of metric's analysis_units are supported by data_source
            resolved_ds = self.data_source.resolve(spec, conf, configs)
            analysis_units = self.analysis_units or [AnalysisUnit.CLIENT]
            for agg_unit in analysis_units:
                if agg_unit not in resolved_ds.analysis_units:
                    raise ValueError(
                        f"data_source {resolved_ds.name} does not support "
                        f"all analysis_units specified by metric {self.name}: "
                        f"analysis_units for metric: {analysis_units}, "
                        f"analysis_units for data_source: {resolved_ds.analysis_units}"
                    )

            metric = Metric(
                name=self.name,
                data_source=self.data_source.resolve(spec, conf, configs),
                data_source=resolved_ds,
                select_expression=select_expression,
                friendly_name=(
                    dedent(self.friendly_name) if self.friendly_name else self.friendly_name
@@ -271,6 +294,7 @@ def resolve(
                owner=[self.owner] if isinstance(self.owner, str) else self.owner,
                deprecated=self.deprecated,
                level=self.level,
                analysis_units=analysis_units,
            )

        metrics_with_treatments = []
@@ -361,7 +385,8 @@ def from_dict(cls, d: dict) -> "MetricsSpec":

params["definitions"] = {
k: converter.structure(
{"name": k, **dict((kk.lower(), vv) for kk, vv in v.items())}, MetricDefinition
{"name": k, **dict((kk.lower(), vv) for kk, vv in v.items())},
MetricDefinition,
)
for k, v in d.items()
if k not in known_keys and k != "28_day"
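
The guard added to resolve above means a metric may only request units its data source advertises. A sketch of the failing path, mirroring the parametrized tests further down; `experiment` and `config_collection` stand in for the repo's pytest fixtures:

import toml
import pytest

from metric_config_parser.analysis import AnalysisSpec

BAD_CONFIG = """
[metrics]
weekly = ["spam"]
[metrics.spam]
data_source = "eggs"
select_expression = "1"
analysis_units = ["client_id"]
[metrics.spam.statistics.bootstrap_mean]
[data_sources.eggs]
from_expression = "england.camelot"
analysis_units = ["profile_group_id"]
"""


def check_units_guard(experiment, config_collection):
    # Metric requests client_id; the data source only supports
    # profile_group_id, so resolve() raises the new ValueError.
    spec = AnalysisSpec.from_dict(toml.loads(BAD_CONFIG))
    with pytest.raises(ValueError, match="data_source eggs does not support"):
        spec.resolve(experiment, config_collection)
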
@@ -148,6 +148,7 @@ def local_tmp_repo(tmp_path):
    shutil.copytree(TEST_DIR / "data", tmp_path / "metrics")
    r.config_writer().set_value("user", "name", "test").release()
    r.config_writer().set_value("user", "email", "[email protected]").release()
    r.config_writer().set_value("commit", "gpgsign", "false").release()
    r.git.add(".")
    r.git.commit("-m", "commit")
    return tmp_path
@@ -2,7 +2,7 @@

[metrics.unenroll]
data_source = "normandy_events"
select_expression='''{{agg_any(
select_expression = '''{{agg_any(
"""
event_category = 'normandy'
AND event_method = 'unenroll'
@@ -67,11 +67,28 @@ data_source = "joined_baseline"
select_expression = "SELECT 1"


[metrics.test_active_hours]
friendly_name = "Test Active hours"
description = """
Contrived example for specific test case.
"""
select_expression = '{{agg_sum("test_active_hours_sum")}}'
data_source = "test_main"
analysis_bases = ["exposures"]
analysis_units = ["profile_group_id"]

[metrics.test_active_hours.statistics.bootstrap_mean]


[data_sources]

[data_sources.main]
from_expression = "(SELECT 1)"

[data_sources.test_main]
from_expression = "(SELECT 1)"
analysis_units = ["client_id", "profile_group_id"]

[data_sources.clients_daily]
from_expression = "mozdata.telemetry.clients_daily"
friendly_name = "Clients Daily"
@@ -91,7 +108,7 @@ from_expression = """(
FROM mozdata.telemetry.events
WHERE event_category = 'normandy'
)"""
experiments_column_type="native"
experiments_column_type = "native"
friendly_name = "Normandy Events"
description = "Normandy Events"

147 changes: 147 additions & 0 deletions lib/metric-config-parser/metric_config_parser/tests/test_analysis.py
@@ -6,6 +6,7 @@
from cattrs.errors import ClassValidationError
from mozilla_nimbus_schemas.jetstream import AnalysisBasis

from metric_config_parser import AnalysisUnit
from metric_config_parser.analysis import AnalysisConfiguration, AnalysisSpec
from metric_config_parser.data_source import DataSource
from metric_config_parser.errors import InvalidConfigurationException
@@ -347,6 +348,7 @@ def test_merge_configs_override_metric(self, experiments, config_collection):
        assert spam.metric.data_source.name == "main"
        assert spam.metric.select_expression == "2"
        assert spam.metric.analysis_bases == [AnalysisBasis.ENROLLMENTS, AnalysisBasis.EXPOSURES]
        assert spam.metric.analysis_units == [AnalysisUnit.CLIENT]
        assert spam.statistic.name == "bootstrap_mean"
        assert spam.statistic.params["num_samples"] == 100

@@ -763,3 +765,148 @@ def test_multiple_nesting(self, experiments, config_collection):
        assert len(metric.depends_on[0].metric.depends_on) == 1
        assert metric.depends_on[0].metric.depends_on[0].metric.name == "spam"
        assert metric.depends_on[0].metric.depends_on[0].metric.select_expression == "1"

    def test_default_analysis_units(self, experiments, config_collection):
        config_str = dedent(
            """
            [metrics]
            weekly = ["active_hours"]
            [metrics.active_hours.statistics.bootstrap_mean]
            """
        )

        spec = AnalysisSpec.from_dict(toml.loads(config_str))
        cfg = spec.resolve(experiments[0], config_collection)

        metric = [m for m in cfg.metrics[AnalysisPeriod.WEEK] if m.metric.name == "active_hours"][
            0
        ].metric
        assert metric.analysis_units == [AnalysisUnit.CLIENT]

        data_source = metric.data_source
        assert data_source.analysis_units == [AnalysisUnit.CLIENT]

    def test_no_override_defined_analysis_units_with_defaults(self, experiments, config_collection):
        # ensure the resolve function does not override
        # an upstream definition with the default value
        custom_conf = dedent(
            """
            [metrics]
            weekly = ["test_active_hours"]
            [metrics.test_active_hours]
            friendly_name = "Overridden Active Hours"
            [metrics.test_active_hours.statistics.bootstrap_mean]
            """
        )

        spec = AnalysisSpec.from_dict(toml.loads(custom_conf))
        cfg = spec.resolve(experiments[0], config_collection)

        spam = [
            m for m in cfg.metrics[AnalysisPeriod.WEEK] if m.metric.name == "test_active_hours"
        ][0]

        assert len(cfg.metrics[AnalysisPeriod.WEEK]) == 1
        assert spam.metric.data_source.name == "test_main"
        assert spam.metric.analysis_bases == [AnalysisBasis.EXPOSURES]
        assert spam.metric.analysis_units == [AnalysisUnit.PROFILE_GROUP]

    @pytest.mark.parametrize(
        "metric_units, ds_units",
        (
            (
                "analysis_units = ['profile_group_id']",
                "analysis_units = ['profile_group_id']",
            ),
            (
                "analysis_units = ['client_id']",
                "analysis_units = ['client_id']",
            ),
            (
                "analysis_units = ['profile_group_id']",
                "analysis_units = ['client_id', 'profile_group_id']",
            ),
            (
                "analysis_units = ['client_id']",
                "analysis_units = ['profile_group_id', 'client_id']",
            ),
            (
                "analysis_units = ['client_id', 'profile_group_id']",
                "analysis_units = ['profile_group_id', 'client_id']",
            ),
        ),
    )
    def test_valid_analysis_units_combinations(
        self, metric_units, ds_units, experiments, config_collection
    ):
        config_str = dedent(
            f"""
            [metrics]
            weekly = ["spam"]
            [metrics.spam]
            data_source = "eggs"
            select_expression = "1"
            {metric_units}
            [metrics.spam.statistics.bootstrap_mean]
            [data_sources.eggs]
            from_expression = "england.camelot"
            client_id_column = "client_info.client_id"
            {ds_units}
            """
        )

        spec = AnalysisSpec.from_dict(toml.loads(config_str))
        cfg = spec.resolve(experiments[0], config_collection)
        metric = [m for m in cfg.metrics[AnalysisPeriod.WEEK] if m.metric.name == "spam"][0].metric
        assert metric.analysis_units is not None
        assert metric.data_source.analysis_units is not None
        for unit in metric.analysis_units:
            assert unit in metric.data_source.analysis_units

    @pytest.mark.parametrize(
        "metric_units,ds_units",
        (
            ("analysis_units = ['client_id']", "analysis_units = ['profile_group_id']"),
            (
                "analysis_units = ['client_id', 'profile_group_id']",
                "analysis_units = ['profile_group_id']",
            ),
            ("analysis_units = ['profile_group_id']", "analysis_units = ['client_id']"),
            (
                "analysis_units = ['client_id', 'profile_group_id']",
                "analysis_units = ['client_id']",
            ),
        ),
    )
    def test_invalid_analysis_units_combinations(
        self, metric_units, ds_units, experiments, config_collection
    ):
        config_str = dedent(
            f"""
            [metrics]
            weekly = ["spam"]
            [metrics.spam]
            data_source = "eggs"
            select_expression = "1"
            {metric_units}
            [metrics.spam.statistics.bootstrap_mean]
            [data_sources.eggs]
            from_expression = "england.camelot"
            client_id_column = "client_info.client_id"
            {ds_units}
            """
        )

        spec = AnalysisSpec.from_dict(toml.loads(config_str))

        with pytest.raises(ValueError, match="data_source eggs does not support"):
            spec.resolve(experiments[0], config_collection)
