Skip to content

Commit bae093f

Browse files
committed
feat: support custom field ordering and subsetting in Metric.write()
Fixes #177
1 parent a1ffed0 commit bae093f

File tree

2 files changed

+339
-5
lines changed

2 files changed

+339
-5
lines changed

fgpyo/util/metric.py

Lines changed: 112 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -111,11 +111,47 @@
111111
>>> Person(name=Name(first='john', last='doe'), age=42, address=None).formatted_values()
112112
["first last", "42"]
113113
```
114+
115+
## Customizing Field Order and Selection
116+
117+
There are two ways to control which fields are written and in what order:
118+
119+
### 1. Class-Level: Override `_fields_to_write()`
120+
121+
Use this when a class should **always** write fields in a specific order. This is useful for
122+
subclasses where child fields should appear before parent fields:
123+
124+
```python
125+
>>> from typing import List
126+
>>> from fgpyo.util.inspect import FieldType
127+
>>> @dataclasses.dataclass(frozen=True)
128+
... class ChildMetric(ParentMetric):
129+
... priority_field: str
130+
...
131+
... @classmethod
132+
... def _fields_to_write(cls, field_types: List[FieldType]) -> List[str]:
133+
... return ["priority_field", "inherited_field"] # Child first
134+
```
135+
136+
### 2. Call-Level: Use `include_fields` or `exclude_fields`
137+
138+
Use this for **one-off** or **varying** field selection:
139+
140+
```python
141+
>>> # Write only specific fields in a specific order
142+
>>> MyMetric.write(path, metric, include_fields=["field_a", "field_b"])
143+
>>> # Write all fields except some
144+
>>> MyMetric.write(path, metric, exclude_fields=["internal_field"])
145+
```
146+
147+
**A good guideline**: If you find yourself passing the same `include_fields`
148+
to every `write()` call, consider overriding `_fields_to_write()` instead.
114149
"""
115150

116151
import dataclasses
117152
import sys
118153
from abc import ABC
154+
from collections import Counter
119155
from contextlib import AbstractContextManager
120156
from csv import DictWriter
121157
from dataclasses import dataclass
@@ -144,6 +180,7 @@
144180

145181
from fgpyo import io
146182
from fgpyo.util import inspect
183+
from fgpyo.util.inspect import FieldType
147184

148185
MetricType = TypeVar("MetricType", bound="Metric")
149186

@@ -300,24 +337,92 @@ def parse(cls, fields: List[str]) -> Any:
300337
return inspect.attr_from(cls=cls, kwargs=dict(zip(header, fields)), parsers=parsers)
301338

302339
@classmethod
303-
def write(cls, path: Path, *values: MetricType, threads: Optional[int] = None) -> None:
340+
def write(
341+
cls,
342+
path: Path,
343+
*values: MetricType,
344+
include_fields: Optional[List[str]] = None,
345+
exclude_fields: Optional[List[str]] = None,
346+
threads: Optional[int] = None,
347+
) -> None:
304348
"""Writes zero or more metrics to the given path.
305349
306350
The header will always be written.
307351
308352
Args:
309353
path: Path to the output file.
310354
values: Zero or more metrics.
355+
include_fields: If specified, only write these fields, in this order.
356+
Overrides any class-level _fields_to_write() customization.
357+
exclude_fields: If specified, omit these fields; the rest keep `_fields_to_write()` order.
358+
Cannot be used together with include_fields.
311359
threads: The number of threads to use when compressing gzip files.
312360
313361
"""
314-
with MetricWriter[MetricType](path, metric_class=cls, threads=threads) as writer:
362+
with MetricWriter[MetricType](
363+
path,
364+
metric_class=cls,
365+
include_fields=include_fields,
366+
exclude_fields=exclude_fields,
367+
threads=threads,
368+
) as writer:
315369
writer.writeall(values)
316370

371+
@classmethod
372+
def _fields_to_write(cls, field_types: List[FieldType]) -> List[str]:
373+
"""Returns field names for writing, allowing reordering or subsetting.
374+
375+
Override this method when your class should ALWAYS write fields in a
376+
specific order or exclude certain fields. This is useful for:
377+
378+
- Subclasses where child fields should appear before parent fields
379+
- Classes with internal fields that should never be serialized
380+
- Enforcing a consistent output format across all write() calls
381+
382+
For one-off or varying field orders, use the `include_fields` parameter
383+
on write() instead.
384+
385+
Args:
386+
field_types: The list of field types for the class, in definition order.
387+
388+
Returns:
389+
A list of field names to write, in the desired order.
390+
391+
Example:
392+
>>> @dataclass
393+
... class ChildMetric(ParentMetric):
394+
... child_field: str
395+
...
396+
... @classmethod
397+
... def _fields_to_write(cls, field_types):
398+
... # Put child_field before parent fields
399+
... return ["child_field", "parent_field"]
400+
"""
401+
return [f.name for f in field_types]
402+
317403
@classmethod
318404
def header(cls) -> List[str]:
319405
"""The list of header values for the metric."""
320-
return [a.name for a in inspect.get_fields(cls)] # type: ignore[arg-type]
406+
field_types = list(inspect.get_fields(cls)) # type: ignore[arg-type]
407+
field_names = {field.name for field in field_types}
408+
header = cls._fields_to_write(field_types=field_types)
409+
410+
# Validate no extra fields
411+
extra_fields = [h for h in header if h not in field_names]
412+
if extra_fields:
413+
raise ValueError(
414+
f"_fields_to_write() returned fields not in class: {', '.join(extra_fields)}"
415+
)
416+
417+
# Validate no duplicates
418+
counts = Counter(header)
419+
if len(header) != len(counts):
420+
duplicates = [h for h, c in counts.items() if c > 1]
421+
raise ValueError(
422+
f"_fields_to_write() returned duplicate fields: {', '.join(duplicates)}"
423+
)
424+
425+
return header
321426

322427
@classmethod
323428
def format_value(cls, value: Any) -> str: # noqa: C901
@@ -613,12 +718,14 @@ def _validate_and_generate_final_output_fieldnames(
613718
)
614719
elif exclude_fields is not None:
615720
_assert_fieldnames_are_metric_attributes(exclude_fields, metric_class)
616-
output_fieldnames = [f for f in metric_class.keys() if f not in exclude_fields]
721+
# Use header() to respect _fields_to_write() ordering
722+
output_fieldnames = [f for f in metric_class.header() if f not in exclude_fields]
617723
elif include_fields is not None:
618724
_assert_fieldnames_are_metric_attributes(include_fields, metric_class)
619725
output_fieldnames = include_fields
620726
else:
621-
output_fieldnames = list(metric_class.keys())
727+
# Use header() to respect _fields_to_write() ordering
728+
output_fieldnames = metric_class.header()
622729

623730
return output_fieldnames
624731

0 commit comments

Comments
 (0)