kd.metrics.State
- class kauldron.metrics.State(*, parent: kauldron.metrics.base_state._MetricT = _EMPTY_TYPE.EMPTY)
Bases: abc.ABC, Generic[kauldron.metrics.base_state._MetricT]
Base metric state class.
In Kauldron, kd.metrics.Metric objects are stateless, pure-Python objects. Instead of holding state, each metric emits a kd.metrics.State when calling state = metric.get_state(**kwargs) (often inside the jax.jit-compiled train or eval step). These states can then be accumulated across multiple steps (with state.merge) before the final value is computed (with state.compute()).
    metric = kd.metrics.Accuracy()
    state = metric.get_state(logits=logits, labels=labels)
    # Optionally accumulate the state across multiple batches
    state = state.merge(other_state)
    values = state.compute()  # Get the final value
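The contract above (emit a state per batch, merge, then compute once) can be sketched without Kauldron itself. In this minimal sketch, StateSketch, AccuracyState and accuracy_state are hypothetical stand-ins, not Kauldron classes; the real kd.metrics.State subclasses are defined differently. Storing counts rather than a per-batch accuracy is what makes the merge exact.

```python
import abc
import dataclasses
from typing import Any, TypeVar

_SelfT = TypeVar("_SelfT", bound="StateSketch")


class StateSketch(abc.ABC):
    """Hypothetical stand-in for the interface documented for kd.metrics.State."""

    @classmethod
    @abc.abstractmethod
    def empty(cls: type[_SelfT]) -> _SelfT:
        """Returns a state such that state.merge(cls.empty()) is a no-op."""

    @abc.abstractmethod
    def merge(self: _SelfT, other: _SelfT) -> _SelfT:
        """Returns a new state accumulating self and other."""

    @abc.abstractmethod
    def compute(self) -> Any:
        """Computes the final metric value from the accumulated values."""


@dataclasses.dataclass(frozen=True)
class AccuracyState(StateSketch):
    """Hypothetical accuracy state: stores counts, not per-batch accuracies."""

    correct: int = 0
    total: int = 0

    @classmethod
    def empty(cls) -> "AccuracyState":
        return cls()

    def merge(self, other: "AccuracyState") -> "AccuracyState":
        # Sum the raw counts; neither input is modified.
        return AccuracyState(self.correct + other.correct, self.total + other.total)

    def compute(self) -> float:
        return self.correct / self.total if self.total else 0.0


def accuracy_state(predictions: list[int], labels: list[int]) -> AccuracyState:
    """Hypothetical per-batch step, playing the role of metric.get_state(...)."""
    correct = sum(pred == label for pred, label in zip(predictions, labels))
    return AccuracyState(correct=correct, total=len(labels))


state = accuracy_state([1, 0, 1, 1], [1, 1, 1, 1])  # 3 of 4 correct
other_state = accuracy_state([0], [0])               # 1 of 1 correct
state = state.merge(other_state)
print(state.compute())  # 0.8
```

Because the frozen dataclass returns new objects from merge, the per-batch states stay usable after accumulation, which matches the "returns a new state" wording of merge below.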
- Attributes:
  - parent: A reference to the metric that emitted this state. Automatically added by metric.get_state().
- parent: kauldron.metrics.base_state._MetricT = 1
- abstractmethod classmethod empty() → kauldron.metrics.base_state._SelfT
Returns an empty instance (i.e. .merge(State.empty()) is a no-op).
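Since merging with empty() changes nothing, empty() is a natural initial value when folding a list of per-batch states, for example with functools.reduce. A minimal sketch, assuming a hypothetical SumState rather than any real Kauldron class:

```python
import dataclasses
import functools


@dataclasses.dataclass(frozen=True)
class SumState:
    """Hypothetical state accumulating a running sum and an element count."""

    total: float = 0.0
    count: int = 0

    @classmethod
    def empty(cls) -> "SumState":
        # Identity element: x.merge(SumState.empty()) == x.
        return cls()

    def merge(self, other: "SumState") -> "SumState":
        return SumState(self.total + other.total, self.count + other.count)

    def compute(self) -> float:
        return self.total / self.count if self.count else 0.0


batch_states = [SumState(2.0, 1), SumState(4.0, 2), SumState(6.0, 3)]

# empty() is the starting accumulator, so folding an empty list also works.
accumulated = functools.reduce(SumState.merge, batch_states, SumState.empty())
print(accumulated.compute())  # 2.0
```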
- abstractmethod merge(other: kauldron.metrics.base_state._SelfT)
Returns a new state that is the accumulation of self and other.
- abstractmethod compute() → Any
Computes final metrics from intermediate values.
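Computing from intermediate values, rather than merging already-computed per-batch values, matters when batches differ in size. A small worked example with hypothetical numbers: averaging per-batch accuracies weights every batch equally, while computing once from merged counts weights every example equally.

```python
# Hypothetical per-batch results: (number correct, batch size).
batches = [(3, 4), (1, 1)]

# Averaging values already computed per batch: (0.75 + 1.0) / 2.
mean_of_batch_accuracies = sum(c / n for c, n in batches) / len(batches)

# Merging the intermediate counts first, then computing once: 4 / 5.
correct = sum(c for c, _ in batches)
total = sum(n for _, n in batches)
pooled_accuracy = correct / total

print(mean_of_batch_accuracies)  # 0.875
print(pooled_accuracy)           # 0.8
```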
- static isinstance(other) → bool
Returns whether other is a State (used in tree_map(is_leaf=)).
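A predicate like this is needed because a state is typically itself a container of fields, so a plain tree map would descend into it. The sketch below imitates that behaviour with a hypothetical, minimal tree_map over dicts and dataclasses (not JAX's implementation) and a hypothetical is_state predicate standing in for State.isinstance, to merge two nested collections of states leaf by leaf.

```python
import dataclasses
from typing import Any, Callable, Optional


@dataclasses.dataclass(frozen=True)
class CountState:
    """Hypothetical state holding a single counter."""

    count: int = 0

    def merge(self, other: "CountState") -> "CountState":
        return CountState(self.count + other.count)


def is_state(obj: Any) -> bool:
    # Plays the role of State.isinstance: whole states are treated as leaves.
    return isinstance(obj, CountState)


def tree_map(fn: Callable[..., Any], tree: Any, *rest: Any,
             is_leaf: Optional[Callable[[Any], bool]] = None) -> Any:
    """Hypothetical minimal tree map over dicts and dataclasses."""
    if is_leaf is not None and is_leaf(tree):
        return fn(tree, *rest)
    if isinstance(tree, dict):
        return {k: tree_map(fn, v, *(r[k] for r in rest), is_leaf=is_leaf)
                for k, v in tree.items()}
    if dataclasses.is_dataclass(tree) and not isinstance(tree, type):
        # Dataclasses are treated as containers (like pytree nodes), so
        # without is_leaf the map descends into their fields.
        return dataclasses.replace(tree, **{
            f.name: tree_map(fn, getattr(tree, f.name),
                             *(getattr(r, f.name) for r in rest),
                             is_leaf=is_leaf)
            for f in dataclasses.fields(tree)
        })
    return fn(tree, *rest)


step1 = {"train": {"loss": CountState(2)}, "eval": CountState(1)}
step2 = {"train": {"loss": CountState(3)}, "eval": CountState(4)}

# With is_leaf, fn receives whole CountState objects and can call merge.
merged = tree_map(lambda a, b: a.merge(b), step1, step2, is_leaf=is_state)
print(merged)
```

Without is_leaf=is_state, the same call would descend into each CountState and hand the int count fields to the lambda, which would then fail on a.merge.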
- replace(**updates)
Returns a new object replacing the specified fields with new values.
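A short sketch of the replace pattern on an immutable state. It assumes, as an illustration only, that replace behaves like dataclasses.replace on a frozen dataclass; MeanState is hypothetical and not part of Kauldron.

```python
import dataclasses


@dataclasses.dataclass(frozen=True)
class MeanState:
    """Hypothetical immutable state with a replace() helper."""

    total: float = 0.0
    count: int = 0

    def replace(self, **updates) -> "MeanState":
        # Copy with the given fields overwritten; self is not mutated.
        return dataclasses.replace(self, **updates)


state = MeanState(total=10.0, count=4)
reset = state.replace(total=0.0)
print(state)  # MeanState(total=10.0, count=4)
print(reset)  # MeanState(total=0.0, count=4)
```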