Thanks a lot for the detailed feedback, even including a nice proposal how to update the API!

Classes vs. Instances

The biggest difference in your design, as compared to the existing `clu.metrics` API, is that you work with metric instances, whereas the current API works with classes. The advantage of using classes is that we declare collections of metrics declaratively, which feels very Flaxy:

```python
@flax.struct.dataclass
class Metrics(metrics.Collection):
  accuracy: metrics.Accuracy
  loss: metrics.Average.from_output("loss")
  loss_std: metrics.Std.from_output("loss")
```

Parametrized Metrics

There are different ways to parametrize an existing `Metric`.

Efficiency

Thanks for highlighting the issue with the jitted `eval_step`. A minimal way of extending the existing API would be to add the following class method:

```python
class Metric:
  # [...]

  @classmethod
  def empty(cls) -> "Metric":
    """Returns an empty instance (i.e. `.merge(Metric.empty())` is a no-op)."""
    raise NotImplementedError("Must override empty()")
```

Every sub-classed metric that wants to make use of the added API would then need to implement this new class method. For example, the `Average` metric would implement it as:

```python
@flax.struct.dataclass
class Average(Metric):
  # [...]

  @classmethod
  def empty(cls) -> Metric:
    return cls(total=jnp.array(0, jnp.float32), count=jnp.array(0, jnp.int32))
```

Finally, the `Collection` class would also implement it, so a whole collection can start out empty:

```python
@flax.struct.dataclass
class Collection:
  # [...]

  @classmethod
  def empty(cls) -> "Collection":
    return cls(
        _reduction_counter=_ReductionCounter(jnp.array(1)),
        **{
            metric_name: metric.empty()
            for metric_name, metric in cls.__annotations__.items()
        })
```

So finally we can move the merging of the metrics inside the jitted `eval_step`:

```python
from clu import metrics
import flax
import jax
@flax.struct.dataclass # required for jax.tree_*
class Metrics(metrics.Collection):
  accuracy: metrics.Accuracy
  loss: metrics.Average.from_output("loss")
  loss_std: metrics.Std.from_output("loss")


def eval_step(ms, model, variables, inputs, labels):
  loss, logits = get_loss_and_logits(model, variables, inputs, labels)
  return ms.merge(Metrics.gather_from_model_output(
      loss=loss, logits=logits, labels=labels))


p_eval_step = jax.pmap(
    eval_step, axis_name="batch", static_broadcasted_argnums=0)


def evaluate(model, p_variables, test_ds):
  ms = Metrics.empty()
  for inputs, labels in test_ds:
    ms = flax.jax_utils.unreplicate(
        p_eval_step(ms, model, p_variables, inputs, labels))
  return ms.compute()
```

Summary

I think this small API extension would address your concern 2 (concern 1 is already covered in the existing API). I would prefer to keep as much as possible from the existing API, because we already have a lot of users using that API, and updating them to a new API would be very costly.

The proposed API change is purely additive, so users who do the metric summation outside their jitted `eval_step` would not be affected.
Hey @andsteing, thanks for the detailed response! I understand that drastically changing the API might be challenging or even impossible given it could break Google internal code. I do have a couple of additional points I will mention, but in the end I think your proposed extension addresses my concerns.

Typing

When I was playing with `clu.metrics` I noticed that this pattern:

```python
@flax.struct.dataclass
class Metrics(metrics.Collection):
  accuracy: metrics.Accuracy
  loss: metrics.Average.from_output("loss")
  loss_std: metrics.Std.from_output("loss")
```

is actually disliked / not encouraged by some Python linters. I don't know if there is a PEP for this, but Pylance warns against this pattern (annotating a field with a call expression such as `metrics.Average.from_output("loss")`). I have not checked with other type checkers.

Empty

I like this solution. The nice thing about having access to an instance thanks to `empty()` is that `eval_step` can work purely with the instance, without referencing the concrete `Metrics` class:

```python
def eval_step(ms: Collection, model, variables, inputs, labels):
  loss, logits = get_loss_and_logits(model, variables, inputs, labels)
  updates = ms.gather_from_model_output(loss=loss, logits=logits, labels=labels)
  return ms.merge(updates)
```

I like it.

Parametrization via function local classes

Edit: No longer relevant as I saw the strategy in a colab you shared. Original response: I tried to create a simple (possibly flawed) `BinaryAccuracy` with a `threshold` parameter in `clu`, this is what I got:

```python
@flax.struct.dataclass
class BinaryAccuracy(metrics.Average):
  @classmethod
  def from_model_output(cls, logits: jnp.ndarray, labels: jnp.ndarray, **kwargs):
    values = ((logits > 0.5) == labels).astype(jnp.float32)
    return super().from_model_output(values, **kwargs)

  @staticmethod
  def with_params(threshold: float = 0.5):
    @flax.struct.dataclass
    class BinaryAccuracyWithParams(metrics.Average):
      @classmethod
      def from_model_output(cls, logits: jnp.ndarray, labels: jnp.ndarray, **kwargs):
        values = ((logits > threshold) == labels).astype(jnp.float32)
        return super().from_model_output(values, **kwargs)

    return BinaryAccuracyWithParams
```

I am not too happy about the approach; maybe it can be cleaned up to avoid code duplication, but it feels a bit more complex than having instances, which by nature are easy to parametrize.

Some thoughts (opinion)

Feel free to ignore this section, it's just some random thoughts I've had during the process.

Asymmetry between Metric and Collection APIs

I am very curious why the `Metric` and `Collection` APIs are not symmetric.

Collection-like API via instances

This is probably not important, but I'll just mention it in case you are interested: I did try to mimic a `Collection`-like API via instances in `flax-tools`:

```python
import numpy as np
from flax_tools.metrics import Metrics, Accuracy, Mean
loss = np.random.uniform(size=(10,))
logits = np.random.uniform(size=(10, 10))
labels = np.random.randint(0, 10, size=(10,))
metrics = Metrics.new(
    [
        Accuracy.new(),
        Mean.new(name="loss").on_args("loss"),
    ]
).reset()

metrics = metrics.update(preds=logits, target=labels, loss=loss)
logs = metrics.compute()  # e.g.: {'accuracy': 0.3, 'loss': 0.47997332}
```
Current State
`Metric` from `clu` currently exposes the following API: a classmethod `from_model_output` that builds a metric from model outputs, `merge` to accumulate two metric instances, and `compute` to produce the final value. Documentation currently suggests they are used like this:
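A minimal sketch of that suggested usage, assuming a single `Accuracy` metric and a plain Python eval loop (`model`, `variables` and `test_ds` are placeholders, not part of the original snippet):

```python
from clu import metrics

accuracy = None
for inputs, labels in test_ds:
  logits = model.apply(variables, inputs)  # placeholder model call
  update = metrics.Accuracy.from_model_output(logits=logits, labels=labels)
  accuracy = update if accuracy is None else accuracy.merge(update)

print(accuracy.compute())
```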
However, if you try to implement it in terms of a realistic jitted `eval_step`, this pattern becomes more complex:
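A rough sketch of what I mean, with the metric threaded through a jitted step function (single-device `jax.jit` for brevity; the model call is again a placeholder):

```python
import jax
from clu import metrics

@jax.jit
def eval_step(variables, metric, inputs, labels):
  logits = model.apply(variables, inputs)  # placeholder model call
  update = metrics.Accuracy.from_model_output(logits=logits, labels=labels)
  return update if metric is None else metric.merge(update)

metric = None
for inputs, labels in test_ds:
  # Retraces once `metric` changes from None to an Accuracy instance.
  metric = eval_step(variables, metric, inputs, labels)

print(metric.compute())
```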
This has the following downsides:

- `eval_step` always has to recompile twice, as `metric` will change from `None` in the first step to a `Metric` instance from then on.
- `Metric`s cannot be (easily) parametrized, since `MetricClass.from_model_output` takes no parameters other than the actual values. If we take a look at a more complex metric such as tf.keras.metrics.BinaryIoU, we see it has a couple of parameters such as `threshold`. I might be missing something, but I don't see an easy way of implementing this in `clu`.

Proposal
My suggestion to solve both is the following API:
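A minimal sketch of that interface, using the method names from the list below (the exact signatures are my assumption):

```python
class Metric:

  def reset(self) -> "Metric":
    """Returns a copy of the metric with its state set to a neutral/zero value."""
    raise NotImplementedError()

  def update(self, **kwargs) -> "Metric":
    """Returns a copy of the metric with its state updated from incoming values."""
    raise NotImplementedError()

  def batch_updates(self, **kwargs) -> "Metric":
    """Per-batch state, equivalent to `self.reset().update(**kwargs)`."""
    return self.reset().update(**kwargs)

  def merge(self, other: "Metric") -> "Metric":
    """Accumulates the state of two metrics."""
    raise NotImplementedError()

  def compute(self):
    """Returns the final metric value."""
    raise NotImplementedError()
```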
This has the following differences:
- A `reset` method which should leave the metric in a neutral/zero state.
- `update` knows how to update the current state based on incoming values.
- `from_model_output` is replaced with `batch_updates`, which does `reset` + `update`.
A simple implementation of
Accuracycould be:Now the previous can be slightly simplified example can be slightly simplified to:
For a non-distributed setup you can even just use
.updatedirectly since you don't need to synchronize metric state (batch_updates) between devices:Parametrized Metrics
Now the obvious benefit of being able to instantiate a
Metricfrom outside is that you can define parametrized metrics e.g. you could implement anAccuracymetric that with atopkparameter:And use it like this:
Reference Implementation
I've been playing around with this definition of
Metricin this non-published repo calledflax-tools, you can check the definition ofMetricand implementation of a couple of non-trivial metrics ported from Treex in flax_tools/metrics.cc: @jheek @marcvanzee
Beta Was this translation helpful? Give feedback.
All reactions