|
111 | 111 | >>> Person(name=Name(first='john', last='doe'), age=42, address=None).formatted_values() |
112 | 112 | ["first last", "42"] |
113 | 113 | ``` |
| 114 | +
|
| 115 | +## Customizing Field Order and Selection |
| 116 | +
|
| 117 | +There are two ways to control which fields are written and in what order: |
| 118 | +
|
| 119 | +### 1. Class-Level: Override `_fields_to_write()` |
| 120 | +
|
| 121 | +Use this when a class should **always** write fields in a specific order. This is useful for |
| 122 | +subclasses where child fields should appear before parent fields: |
| 123 | +
|
| 124 | +```python |
| 125 | + >>> from typing import List |
| 126 | + >>> from fgpyo.util.inspect import FieldType |
| 127 | + >>> @dataclasses.dataclass(frozen=True) |
| 128 | + ... class ChildMetric(ParentMetric): |
| 129 | + ... priority_field: str |
| 130 | + ... |
| 131 | + ... @classmethod |
| 132 | + ... def _fields_to_write(cls, field_types: List[FieldType]) -> List[str]: |
| 133 | + ... return ["priority_field", "inherited_field"] # Child first |
| 134 | +``` |
| 135 | +
|
| 136 | +### 2. Call-Level: Use `include_fields` or `exclude_fields` |
| 137 | +
|
| 138 | +Use this for **one-off** or **varying** field selection: |
| 139 | +
|
| 140 | +```python |
| 141 | + >>> # Write only specific fields in a specific order |
| 142 | + >>> MyMetric.write(path, metric, include_fields=["field_a", "field_b"]) |
| 143 | + >>> # Write all fields except some |
| 144 | + >>> MyMetric.write(path, metric, exclude_fields=["internal_field"]) |
| 145 | +``` |
| 146 | +
|
| 147 | +**A good guideline**: If you find yourself passing the same `include_fields` |
| 148 | +to every write() call, consider overriding `_fields_to_write()` instead. |
114 | 149 | """ |
115 | 150 |
|
116 | 151 | import dataclasses |
|
144 | 179 |
|
145 | 180 | from fgpyo import io |
146 | 181 | from fgpyo.util import inspect |
| 182 | +from fgpyo.util.inspect import FieldType |
147 | 183 |
|
148 | 184 | MetricType = TypeVar("MetricType", bound="Metric") |
149 | 185 |
|
@@ -300,24 +336,91 @@ def parse(cls, fields: List[str]) -> Any: |
300 | 336 | return inspect.attr_from(cls=cls, kwargs=dict(zip(header, fields)), parsers=parsers) |
301 | 337 |
|
302 | 338 | @classmethod |
303 | | - def write(cls, path: Path, *values: MetricType, threads: Optional[int] = None) -> None: |
| 339 | + def write( |
| 340 | + cls, |
| 341 | + path: Path, |
| 342 | + *values: MetricType, |
| 343 | + include_fields: Optional[List[str]] = None, |
| 344 | + exclude_fields: Optional[List[str]] = None, |
| 345 | + threads: Optional[int] = None, |
| 346 | + ) -> None: |
304 | 347 | """Writes zero or more metrics to the given path. |
305 | 348 |
|
306 | 349 | The header will always be written. |
307 | 350 |
|
308 | 351 | Args: |
309 | 352 | path: Path to the output file. |
310 | 353 | values: Zero or more metrics. |
| 354 | + include_fields: If specified, only write these fields, in this order. |
| 355 | + Overrides any class-level _fields_to_write() customization. |
| 356 | + exclude_fields: If specified, exclude these fields from output. |
| 357 | + Cannot be used together with include_fields. |
311 | 358 | threads: the number of threads to use when compressing gzip files |
312 | 359 |
|
313 | 360 | """ |
314 | | - with MetricWriter[MetricType](path, metric_class=cls, threads=threads) as writer: |
| 361 | + with MetricWriter[MetricType]( |
| 362 | + path, |
| 363 | + metric_class=cls, |
| 364 | + include_fields=include_fields, |
| 365 | + exclude_fields=exclude_fields, |
| 366 | + threads=threads, |
| 367 | + ) as writer: |
315 | 368 | writer.writeall(values) |
316 | 369 |
|
| 370 | + @classmethod |
| 371 | + def _fields_to_write(cls, field_types: List[FieldType]) -> List[str]: |
| 372 | + """Returns field names for writing, allowing reordering or subsetting. |
| 373 | +
|
| 374 | + Override this method when your class should ALWAYS write fields in a |
| 375 | + specific order or exclude certain fields. This is useful for: |
| 376 | +
|
| 377 | + - Subclasses where child fields should appear before parent fields |
| 378 | + - Classes with internal fields that should never be serialized |
| 379 | + - Enforcing a consistent output format across all write() calls |
| 380 | +
|
| 381 | + For one-off or varying field orders, use the `include_fields` parameter |
| 382 | + on write() instead. |
| 383 | +
|
| 384 | + Args: |
| 385 | + field_types: The list of field types for the class, in definition order. |
| 386 | +
|
| 387 | + Returns: |
| 388 | + A list of field names to write, in the desired order. |
| 389 | +
|
| 390 | + Example: |
| 391 | + >>> @dataclass |
| 392 | + ... class ChildMetric(ParentMetric): |
| 393 | + ... child_field: str |
| 394 | + ... |
| 395 | + ... @classmethod |
| 396 | + ... def _fields_to_write(cls, field_types): |
| 397 | + ... # Put child_field before parent fields |
| 398 | + ... return ["child_field", "parent_field"] |
| 399 | + """ |
| 400 | + return [f.name for f in field_types] |
| 401 | + |
317 | 402 | @classmethod |
318 | 403 | def header(cls) -> List[str]: |
319 | 404 | """The list of header values for the metric.""" |
320 | | - return [a.name for a in inspect.get_fields(cls)] # type: ignore[arg-type] |
| 405 | + field_types = list(inspect.get_fields(cls)) # type: ignore[arg-type] |
| 406 | + field_names = {field.name for field in field_types} |
| 407 | + header = cls._fields_to_write(field_types=field_types) |
| 408 | + |
| 409 | + # Validate no extra fields |
| 410 | + extra_fields = [h for h in header if h not in field_names] |
| 411 | + if extra_fields: |
| 412 | + raise ValueError( |
| 413 | + f"_fields_to_write() returned fields not in class: {', '.join(extra_fields)}" |
| 414 | + ) |
| 415 | + |
| 416 | + # Validate no duplicates |
| 417 | + if len(header) != len(set(header)): |
| 418 | + duplicates = [h for h in header if header.count(h) > 1] |
| 419 | + raise ValueError( |
| 420 | + f"_fields_to_write() returned duplicate fields: {', '.join(set(duplicates))}" |
| 421 | + ) |
| 422 | + |
| 423 | + return header |
321 | 424 |
|
322 | 425 | @classmethod |
323 | 426 | def format_value(cls, value: Any) -> str: # noqa: C901 |
@@ -613,12 +716,14 @@ def _validate_and_generate_final_output_fieldnames( |
613 | 716 | ) |
614 | 717 | elif exclude_fields is not None: |
615 | 718 | _assert_fieldnames_are_metric_attributes(exclude_fields, metric_class) |
616 | | - output_fieldnames = [f for f in metric_class.keys() if f not in exclude_fields] |
| 719 | + # Use header() to respect _fields_to_write() ordering |
| 720 | + output_fieldnames = [f for f in metric_class.header() if f not in exclude_fields] |
617 | 721 | elif include_fields is not None: |
618 | 722 | _assert_fieldnames_are_metric_attributes(include_fields, metric_class) |
619 | 723 | output_fieldnames = include_fields |
620 | 724 | else: |
621 | | - output_fieldnames = list(metric_class.keys()) |
| 725 | + # Use header() to respect _fields_to_write() ordering |
| 726 | + output_fieldnames = metric_class.header() |
622 | 727 |
|
623 | 728 | return output_fieldnames |
624 | 729 |
|
|
0 commit comments