started to generalize stitching modules
Philip Daniel Keicher committed Nov 8, 2024
1 parent 86d0734 commit e774a5f
Showing 3 changed files with 181 additions and 71 deletions.
67 changes: 50 additions & 17 deletions hbt/config/configs_hbt.py
@@ -415,28 +415,61 @@ def if_era(

# define inclusive datasets for the stitched process identification with corresponding leaf processes
if run == 3 and not sync_mode:
    # stitching definitions
    # format:
    # {
    #     "dataset_tag": {
    #         "dataset_to_stitch": {
    #             "inclusive_dataset": dataset,
    #             "leaf_processes": [proc1, proc2, ...],
    #         },
    #     }
    # }
    cfg.x.stitching = {
        # DY stitching
        "is_dy": {
            "m50toinf": {
                "inclusive_dataset": cfg.datasets.n.dy_m50toinf_amcatnlo,
                "leaf_processes": [
                    # the following processes cover the full njet and pt phase space
                    procs.n.dy_m50toinf_0j,
                    *(
                        procs.get(f"dy_m50toinf_{nj}j_pt{pt}")
                        for nj in [1, 2]
                        for pt in ["0to40", "40to100", "100to200", "200to400", "400to600", "600toinf"]
                    ),
                    procs.n.dy_m50toinf_ge3j,
                ],
            },
        },
        # w+jets
        # TODO: add
    }


# dataset groups for conveniently looping over certain datasets
# (used in wrapper_factory and during plotting)
cfg.x.dataset_groups = {
    "dy": [
        # dy
        "dy_m4to10_amcatnlo",
        "dy_m10to50_amcatnlo",
        "dy_m50toinf_amcatnlo",
        "dy_m50toinf_0j_amcatnlo",
        "dy_m50toinf_1j_amcatnlo",
        "dy_m50toinf_2j_amcatnlo",
        "dy_m50toinf_1j_pt40to100_amcatnlo",
        "dy_m50toinf_1j_pt100to200_amcatnlo",
        "dy_m50toinf_1j_pt200to400_amcatnlo",
        "dy_m50toinf_1j_pt400to600_amcatnlo",
        "dy_m50toinf_1j_pt600toinf_amcatnlo",
        "dy_m50toinf_2j_pt40to100_amcatnlo",
        "dy_m50toinf_2j_pt100to200_amcatnlo",
        "dy_m50toinf_2j_pt200to400_amcatnlo",
        "dy_m50toinf_2j_pt400to600_amcatnlo",
        "dy_m50toinf_2j_pt600toinf_amcatnlo",
    ],
}

# category groups for conveniently looping over certain categories
# (used during plotting)
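
The nested config entry is consumed by looking up the dataset's tag and stitching key; a minimal sketch of the access pattern (mirroring hbt/selection/default.py below, assuming a config instance cfg as above):

dy_cfg = cfg.x.stitching["is_dy"]["m50toinf"]
inclusive_dataset = dy_cfg["inclusive_dataset"]  # an order dataset instance
leaf_processes = dy_cfg["leaf_processes"]  # processes covering the full phase space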
182 changes: 129 additions & 53 deletions hbt/production/processes.py
@@ -3,16 +3,18 @@
"""
Process ID producer relevant for the stitching of the DY samples.
"""

from __future__ import annotations

import functools

import law
import order

from columnflow.production import Producer, producer
from columnflow.util import maybe_import, InsertableDict
from columnflow.columnar_util import set_ak_column, Route
from columnflow.types import Callable

from hbt.util import IF_DATASET_IS_DY

np = maybe_import("numpy")
ak = maybe_import("awkward")
@@ -27,61 +29,135 @@

set_ak_column_i64 = functools.partial(set_ak_column, value_type=np.int64)


class stitched_process_ids(Producer):

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        # attributes that are expected to be provided by derived classes (via derive's
        # cls_dict) or during setup; do not overwrite values that are already set
        self.key_func: Callable | None = getattr(self, "key_func", None)
        # plain function that is called with this producer instance as its first argument
        self.cross_check_func: Callable | None = getattr(self, "cross_check_func", None) or stitching_range_cross_check
        self.stitching_observables: list[str] | None = getattr(self, "stitching_observables", None)
        self.cross_check_translation_dict: dict[str, str] | None = getattr(self, "cross_check_translation_dict", None)

    def init_func(self, *args, **kwargs):
        self.uses |= {IF_DATASET_IS_DY(*self.stitching_observables)}
        self.produces |= {IF_DATASET_IS_DY("process_id")}

    def call_func(self, events: ak.Array, **kwargs) -> ak.Array:
        """
        Assigns each event a single process id, based on the stitching observables of the LHE
        record. This is used for the stitching of the respective samples.
        """
        # as always, we assume that each dataset has exactly one process associated to it
        if len(self.dataset_inst.processes) != 1:
            raise NotImplementedError(
                f"dataset {self.dataset_inst.name} has {len(self.dataset_inst.processes)} processes "
                "assigned, which is not yet implemented",
            )
        process_inst = self.dataset_inst.processes.get_first()

        # get the values of the stitching observables
        stitching_obs_values = [Route(obs).apply(events) for obs in self.stitching_observables]

        # optionally cross check that the values lie within the expected process boundaries
        if self.cross_check_translation_dict and callable(self.cross_check_func):
            self.cross_check_func(self, process_inst, stitching_obs_values)

        # lookup the id and check for invalid values
        process_ids = np.squeeze(np.asarray(self.id_table[self.key_func(*stitching_obs_values)].todense()))
        invalid_mask = process_ids == 0
        if ak.any(invalid_mask):
            raise ValueError(
                f"found {ak.sum(invalid_mask)} events that could not be assigned to a process",
            )

        # store them
        events = set_ak_column_i64(events, "process_id", process_ids)

        return events


def stitching_range_cross_check(
    self: Producer,
    process_inst: order.Process,
    stitching_values: list[ak.Array],
) -> None:
    # map each stitching observable to the process auxiliary value to compare with, and
    # raise a warning if a dataset was created for a specific "bin" (leaf process) but
    # actually contains values outside of it
    for obs_name, obs_values in zip(self.stitching_observables, stitching_values):
        aux_name = self.cross_check_translation_dict.get(obs_name, obs_name)
        aux_values = process_inst.x(aux_name, None)
        if aux_values is not None:
            outliers = (obs_values < aux_values[0]) | (obs_values >= aux_values[1])
            if ak.any(outliers):
                logger.warning(
                    f"dataset {self.dataset_inst.name} is meant to contain {aux_name} values "
                    f"in the range [{aux_values[0]}, {aux_values[1]}), but found "
                    f"{ak.sum(outliers)} events outside this range",
                )


process_ids_dy = stitched_process_ids.derive(
    "process_ids_dy",
    cls_dict={
        "stitching_observables": ["LHE.NpNLO", "LHE.Vpt"],
        "cross_check_translation_dict": {"LHE.NpNLO": "njets", "LHE.Vpt": "ptll"},
    },
)
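
# deriving id producers for other sample families follows the same pattern; a
# hypothetical w+jets variant as a sketch (the observable list and the auxiliary
# names used here are assumptions, the corresponding config entry is still a TODO):
# process_ids_w = stitched_process_ids.derive(
#     "process_ids_w",
#     cls_dict={
#         "stitching_observables": ["LHE.NpNLO", "LHE.Vpt"],
#         "cross_check_translation_dict": {"LHE.NpNLO": "njets", "LHE.Vpt": "ptv"},
#     },
# )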


@process_ids_dy.setup
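
The body of the setup hook is not shown in this excerpt; it is expected to fill id_table and key_func, which call_func combines as id_table[key_func(*observables)]. A minimal, self-contained sketch of that lookup mechanic, assuming scipy sparse matrices (implied by the .todense() call above); the bin edges and process ids below are illustrative only, not the actual values:

import numpy as np
import scipy.sparse as sp

# illustrative stitching bins: njets in {0, 1, 2, >=3} and three LHE pt bins
njet_edges = np.array([0, 1, 2, 3, 101])
pt_edges = np.array([0.0, 40.0, 100.0, np.inf])
n_pt_bins = len(pt_edges) - 1

def key_func(njets, pt):
    # map observable values to (row, flat column index) of the sparse id table
    njet_idx = np.digitize(np.asarray(njets), njet_edges) - 1
    pt_idx = np.digitize(np.asarray(pt), pt_edges) - 1
    return 0, njet_idx * n_pt_bins + pt_idx

# one sparse row holding a non-zero process id per (njet, pt) bin
id_table = sp.lil_matrix((1, (len(njet_edges) - 1) * n_pt_bins), dtype=np.int64)
id_table[0, :] = np.arange(1, id_table.shape[1] + 1)

# example lookup for three events, matching the expression in call_func
njets = np.array([0, 1, 2])
pt = np.array([10.0, 50.0, 250.0])
process_ids = np.squeeze(np.asarray(id_table[key_func(njets, pt)].todense()))  # -> [1, 5, 9]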
3 changes: 2 additions & 1 deletion hbt/selection/default.py
@@ -192,8 +192,9 @@ def default_init(self: Selector) -> None:

    self.process_ids_dy: process_ids_dy | None = None
    if self.dataset_inst.has_tag("is_dy"):
        dy_stitching = self.config_inst.x.stitching["is_dy"]
        # check if this dataset is covered by any dy id producer
        for name, dy_cfg in dy_stitching.items():
            dataset_inst = dy_cfg["inclusive_dataset"]
            # the dataset is "covered" if its process is a subprocess of that of the dy dataset
            if dataset_inst.has_process(self.dataset_inst.processes.get_first()):
