Skip to content

Commit 9be384a

Browse files
committed
feat: support custom field ordering and subsetting in Metric.write()
Fixes #177
1 parent a1ffed0 commit 9be384a

File tree

2 files changed

+342
-5
lines changed

2 files changed

+342
-5
lines changed

fgpyo/util/metric.py

Lines changed: 110 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -111,6 +111,41 @@
111111
>>> Person(name=Name(first='john', last='doe'), age=42, address=None).formatted_values()
112112
["first last", "42"]
113113
```
114+
115+
## Customizing Field Order and Selection
116+
117+
There are two ways to control which fields are written and in what order:
118+
119+
### 1. Class-Level: Override `_fields_to_write()`
120+
121+
Use this when a class should **always** write fields in a specific order. This is useful for
122+
subclasses where child fields should appear before parent fields:
123+
124+
```python
125+
>>> from typing import List
126+
>>> from fgpyo.util.inspect import FieldType
127+
>>> @dataclasses.dataclass(frozen=True)
128+
... class ChildMetric(ParentMetric):
129+
... priority_field: str
130+
...
131+
... @classmethod
132+
... def _fields_to_write(cls, field_types: List[FieldType]) -> List[str]:
133+
... return ["priority_field", "inherited_field"] # Child first
134+
```
135+
136+
### 2. Call-Level: Use `include_fields` or `exclude_fields`
137+
138+
Use this for **one-off** or **varying** field selection (the two parameters cannot be combined in a single call):
139+
140+
```python
141+
>>> # Write only specific fields in a specific order
142+
>>> MyMetric.write(path, metric, include_fields=["field_a", "field_b"])
143+
>>> # Write all fields except some
144+
>>> MyMetric.write(path, metric, exclude_fields=["internal_field"])
145+
```
146+
147+
**A good guideline**: If you find yourself passing the same `include_fields`
148+
to every write() call, consider overriding `_fields_to_write()` instead.
114149
"""
115150

116151
import dataclasses
@@ -144,6 +179,7 @@
144179

145180
from fgpyo import io
146181
from fgpyo.util import inspect
182+
from fgpyo.util.inspect import FieldType
147183

148184
MetricType = TypeVar("MetricType", bound="Metric")
149185

@@ -300,24 +336,91 @@ def parse(cls, fields: List[str]) -> Any:
300336
return inspect.attr_from(cls=cls, kwargs=dict(zip(header, fields)), parsers=parsers)
301337

302338
@classmethod
def write(
    cls,
    path: Path,
    *values: MetricType,
    include_fields: Optional[List[str]] = None,
    exclude_fields: Optional[List[str]] = None,
    threads: Optional[int] = None,
) -> None:
    """Writes a header line followed by zero or more metrics to the given path.

    The header is written even when no metrics are supplied.

    Args:
        path: Path to the output file.
        values: The metrics to write; may be empty.
        include_fields: When given, only these fields are written, in exactly this
            order. Takes precedence over any class-level `_fields_to_write()`
            customization.
        exclude_fields: When given, these fields are omitted from the output.
            Cannot be combined with `include_fields`.
        threads: The number of threads to use when compressing gzip files.
    """
    # Field selection and validation are delegated to the writer; this method only
    # forwards the caller's choices and streams the metrics through it.
    with MetricWriter[MetricType](
        path,
        metric_class=cls,
        include_fields=include_fields,
        exclude_fields=exclude_fields,
        threads=threads,
    ) as metric_writer:
        metric_writer.writeall(values)
316369

370+
@classmethod
def _fields_to_write(cls, field_types: List[FieldType]) -> List[str]:
    """Returns the names of the fields to write, in output order.

    The default implementation returns every field name in the order given by
    `field_types`. Override it when a class should ALWAYS be written with a
    particular field order or subset, for example:

    - a subclass whose own fields should precede the inherited ones
    - a class with internal fields that should never be serialized
    - a class that must produce the same column layout on every write() call

    For one-off or varying layouts, pass `include_fields` to write() instead.

    Args:
        field_types: The field types for the class, in definition order.

    Returns:
        The field names to write, in the desired order.

    Example:
        >>> @dataclasses.dataclass(frozen=True)
        ... class ChildMetric(ParentMetric):
        ...     child_field: str
        ...
        ...     @classmethod
        ...     def _fields_to_write(cls, field_types):
        ...         # Put child_field before parent fields
        ...         return ["child_field", "parent_field"]
    """
    names: List[str] = []
    for field_type in field_types:
        names.append(field_type.name)
    return names
401+
317402
@classmethod
def header(cls) -> List[str]:
    """The list of header values for the metric.

    The names come from `_fields_to_write()`, so a class-level override controls
    both which fields appear and their order. The result is validated before it
    is returned.

    Returns:
        A new list of field names, in output order.

    Raises:
        ValueError: If `_fields_to_write()` returns a name that is not a field of
            the class, or returns the same name more than once.
    """
    field_types = list(inspect.get_fields(cls))  # type: ignore[arg-type]
    field_names = {field.name for field in field_types}
    # Copy the override's result: it may be a tuple or a shared class-level list,
    # and callers must be free to mutate what they get back.
    header = list(cls._fields_to_write(field_types=field_types))

    # Validate no extra fields
    extra_fields = [h for h in header if h not in field_names]
    if extra_fields:
        raise ValueError(
            f"_fields_to_write() returned fields not in class: {', '.join(extra_fields)}"
        )

    # Validate no duplicates
    if len(header) != len(set(header)):
        # dict.fromkeys de-duplicates while keeping first-occurrence order, so the
        # message is stable across runs (joining a set is not).
        duplicates = list(dict.fromkeys(h for h in header if header.count(h) > 1))
        raise ValueError(
            f"_fields_to_write() returned duplicate fields: {', '.join(duplicates)}"
        )

    return header
321424

322425
@classmethod
323426
def format_value(cls, value: Any) -> str: # noqa: C901
@@ -613,12 +716,14 @@ def _validate_and_generate_final_output_fieldnames(
613716
)
614717
elif exclude_fields is not None:
615718
_assert_fieldnames_are_metric_attributes(exclude_fields, metric_class)
616-
output_fieldnames = [f for f in metric_class.keys() if f not in exclude_fields]
719+
# Use header() to respect _fields_to_write() ordering
720+
output_fieldnames = [f for f in metric_class.header() if f not in exclude_fields]
617721
elif include_fields is not None:
618722
_assert_fieldnames_are_metric_attributes(include_fields, metric_class)
619723
output_fieldnames = include_fields
620724
else:
621-
output_fieldnames = list(metric_class.keys())
725+
# Use header() to respect _fields_to_write() ordering
726+
output_fieldnames = metric_class.header()
622727

623728
return output_fieldnames
624729

0 commit comments

Comments
 (0)