Metrics#
Loss, metrics, summaries all share the same API.
See the available metrics:
Using a metric#
Kauldron usage#
In Kauldron, the metrics are automatically applied and accumulated by the training loop. Users specify what the metrics inputs are through the keys.
cfg.metrics = {
'reconstruction': kd.losses.L2(preds="preds.image", targets="batch.image"),
'roc_auc': kd.metrics.RocAuc(preds="preds.logits", targets="batch.label"),
}
Standalone usage#
Metrics can be used outside Kauldron, as standalone modules (using
//third_party/py/kauldron/metrics
,…):
from kauldron import metrics
from kauldron import losses
from kauldron import summaries
Metrics are stateless objects.
Creation:
metric1 = metrics.Accuracy()
Usage (1-time):
accuracy = metric(logits=logits, labels=labels)
Equivalent to:
accuracy = metric.get_state(logits=logits, labels=labels).compute()
Usage (accumulated):
Some metrics require accumulating values over multiple steps. In this case, every metric can emit states which are merged together:
state0 = metric.get_state(logits=logits, labels=labels)
state1 = metric.get_state(logits=logits, labels=labels)
# Accumulate the states
acc_state = state0.merge(state1)
# Compute the final metric value
accuracy = acc_state.compute()
Creating a metric#
Metric#
Metrics inherit the kd.metrics.Metric
class and overwrite the State
class
and get_state
attribute.
@dataclasses.dataclass(eq=True, frozen=True, kw_only=True)
class Accuracy(kd.metrics.Metric):
"""Classification Accuracy."""
logits: kontext.Key = kd.kontext.REQUIRED # e.g. "preds.logits"
labels: kontext.Key = kd.kontext.REQUIRED # e.g. "batch.label"
# Could be `State = kd.metrics.AverageState` but inheritance gives a better
# name `Accuracy.State`
class State(kd.metrics.AverageState):
pass
@typechecked
def get_state(self, logits: Float["*b n"], labels: Float["*b"]) -> Float["*b"]:
correct = logits.argmax(axis=-1) == labels
return self.State.from_values(values=correct)
The state performs the aggregation of the metric values. Some states are provided by default:
kd.metrics.AverageState
: for simple averaging of a value (e.g.kd.metrics.Norm
).kd.metrics.CollectingState
(deprecated): for metrics that need to collect and concatenate model outputs over many batches (e.g.kd.metrics.RocAuc
).kd.metrics.AutoState
: A more flexible version ofCollectingState
which supports arbitrary aggregation. SeeSummary
section below.
You can also implement your custom State
. To choose whether the logic should
go in State
or Metric
:
Metric.get_state
: Is executed insidejax.jit
State.compute
: Is executed outsidejax.jit
, so can contain arbitrary Python code (e.g. somesklearn.metrics
contain logic which would be hard to implement in pure Jax)
Loss#
All losses inherit from kd.metrics.Metric
, so use the same API as above.
However for convenience, a kd.losses.Loss
base class is provided which also
supports handling masks, averaging, and loss-weight.
The difference is that Loss
implements the get_values
method instead of
get_state
(and the get_state
method is implemented by the base class).
@dataclasses.dataclass(eq=True, frozen=True, kw_only=True)
class L2(kd.losses.Loss):
"""L2 loss."""
preds: kontext.Key = kd.kontext.REQUIRED
targets: kontext.Key = kd.kontext.REQUIRED
@typechecked
def get_values(self, preds: Float["*b"], targets: Float["*b"]) -> Float["*b"]:
return jnp.square(preds - targets)
Note: The get_values
method returns the per-example loss (i.e. the returned
value has the batch dimension), and the averaging is done in the base class.
This ensures accumulating losses over multiple batches (e.g. in eval) works
correctly.
Losses also adds:
weight
: If multiple losses are used, they can be weighted differently.mask
: Allow to filter out examples from the batch.
Summary#
Summaries are like metrics. The only difference is the State.compute()
method
does not return a scalar but instead returns one of:
Array['b h w c']
: for image summaries.str
for text summaries.kd.summaries.Histogram
: for histogram summaries.kd.summaries.PointCloud
: for point cloud summaries.
For convenience, the State
can inherit from kd.metrics.AutoState
to perform
automated aggregations (e.g. only keep the first x
images):
@dataclasses.dataclass(kw_only=True, frozen=True)
class ShowImages(kd.metrics.Metric):
num_images: int = 5
@flax.struct.dataclass
class State(kd.metrics.AutoState):
"""Collects the first num_images images and boxes."""
# When the states are aggregated (e.g. across multiple batches, only keep
# the first `num_images` images).
images: Float["n h w c"] = kd.metrics.truncate_field(num_field="parent.num_images")
boxes: Float["n k 4"] = kd.metrics.truncate_field(num_field="parent.num_images")
...
...