kd.metrics.State

class kauldron.metrics.State(
    *,
    parent: kauldron.metrics.base_state._MetricT = _EMPTY_TYPE.EMPTY,
)

Bases: abc.ABC, Generic[kauldron.metrics.base_state._MetricT]

Base metric state class.

In Kauldron, kd.metrics.Metric objects are stateless, pure-Python objects. Instead of holding state themselves, each metric emits a kd.metrics.State when calling state = metric.get_state(**kwargs) (often inside the jax.jit train or eval step).

Those states can then be accumulated across multiple steps (with state.merge()) before computing the final value (with state.compute()).

from kauldron import kd

metric = kd.metrics.Accuracy()

state = metric.get_state(logits=logits, labels=labels)

# Optionally accumulate the state across multiple batches
state = state.merge(other_state)

values = state.compute()  # Get the final value

Attributes:

parent – A reference to the metric that emitted this state. Automatically added by metric.get_state().
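
To make the emit, merge, compute life cycle concrete without depending on Kauldron itself, here is a minimal stand-alone sketch. `AccuracyState` and `get_accuracy_state` are hypothetical stand-ins invented for illustration, not Kauldron classes; they only mirror the interface described above.

```python
import dataclasses
from typing import Sequence


@dataclasses.dataclass(frozen=True)
class AccuracyState:
    """Hypothetical stand-in for a metric state (not a Kauldron class)."""

    correct: int = 0
    total: int = 0

    def merge(self, other: "AccuracyState") -> "AccuracyState":
        # Accumulate the intermediate counts into a new state.
        return AccuracyState(self.correct + other.correct, self.total + other.total)

    def compute(self) -> float:
        # Turn the accumulated counts into the final metric value.
        return self.correct / self.total


def get_accuracy_state(
    predictions: Sequence[int], labels: Sequence[int]
) -> AccuracyState:
    """Hypothetical stateless metric function: emits the state for one batch."""
    correct = sum(int(p == y) for p, y in zip(predictions, labels))
    return AccuracyState(correct=correct, total=len(labels))


state = get_accuracy_state(predictions=[0, 1, 1], labels=[0, 1, 0])  # 2 of 3 correct
other_state = get_accuracy_state(predictions=[2, 2], labels=[2, 2])  # 2 of 2 correct
state = state.merge(other_state)
value = state.compute()  # 4 correct out of 5 examples
```

The metric function keeps no state between calls; everything needed for the final value travels in the returned state objects.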

parent: kauldron.metrics.base_state._MetricT = _EMPTY_TYPE.EMPTY

abstractmethod classmethod empty() → kauldron.metrics.base_state._SelfT

Returns an empty instance, i.e. the identity for merge: state.merge(State.empty()) is equivalent to state.
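
One practical use of an empty instance is as the starting value of an accumulation loop, so the first batch needs no special case. The sketch below assumes a hypothetical `SumState` class (not part of Kauldron) whose `empty()` is the identity for `merge`.

```python
import dataclasses
import functools


@dataclasses.dataclass(frozen=True)
class SumState:
    """Hypothetical stand-in state (not a Kauldron class)."""

    total: float = 0.0

    @classmethod
    def empty(cls) -> "SumState":
        # The identity element for merge.
        return cls(total=0.0)

    def merge(self, other: "SumState") -> "SumState":
        return SumState(self.total + other.total)


batch_states = [SumState(1.5), SumState(2.5), SumState(4.0)]

# Fold all per-batch states together, starting from the empty state.
accumulated = functools.reduce(
    lambda acc, s: acc.merge(s), batch_states, SumState.empty()
)
```

With this starting value, folding over an empty list of batches simply yields the empty state.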

abstractmethod merge(
    other: kauldron.metrics.base_state._SelfT,
) → kauldron.metrics.base_state._SelfT

Returns a new state that is the accumulation of self and other.

Parameters:

other – A State whose intermediate values should be accumulated onto the values of self.

Returns:

A new State that accumulates the values of both self and other.
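
Because merge returns a new state, neither input is modified. A minimal sketch of that contract, using a hypothetical `CollectState` (not a Kauldron class) whose accumulation is concatenation rather than addition:

```python
import dataclasses
from typing import Tuple


@dataclasses.dataclass(frozen=True)
class CollectState:
    """Hypothetical stand-in state that keeps every value it has seen."""

    values: Tuple[float, ...] = ()

    def merge(self, other: "CollectState") -> "CollectState":
        # Build a new object instead of mutating self or other.
        return CollectState(self.values + other.values)


a = CollectState((0.1, 0.4))
b = CollectState((0.9,))
merged = a.merge(b)  # a and b are left untouched
```

Making the stand-in a frozen dataclass turns accidental in-place updates into errors, which is one way to enforce the "returns a new state" behaviour.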

abstractmethod compute() → Any

Computes final metrics from intermediate values.
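
Deferring compute until after all merges matters when batches differ in size. The following sketch, built on a hypothetical `MeanState` (not a Kauldron class), compares merging first against averaging per-batch results.

```python
import dataclasses


@dataclasses.dataclass(frozen=True)
class MeanState:
    """Hypothetical stand-in state (not a Kauldron class)."""

    total: float
    count: int

    def merge(self, other: "MeanState") -> "MeanState":
        return MeanState(self.total + other.total, self.count + other.count)

    def compute(self) -> float:
        return self.total / self.count


batch_a = MeanState(total=4.0, count=4)  # per-batch mean: 1.0
batch_b = MeanState(total=3.0, count=1)  # per-batch mean: 3.0

# Merge first, compute once: every example has the same weight.
merged_then_computed = batch_a.merge(batch_b).compute()  # 7.0 / 5

# Compute per batch, then average: the single-example batch is over-weighted.
computed_then_averaged = (batch_a.compute() + batch_b.compute()) / 2  # (1.0 + 3.0) / 2
```

The two results differ, which is why the intermediate values are kept in the state and only reduced to a final number at the end.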

static isinstance(other) → bool

Returns whether other is a State. Intended for use as the is_leaf argument of tree_map, so that states are treated as leaves rather than traversed.
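
To illustrate why such a leaf predicate is needed, the sketch below uses a deliberately simplified, pure-Python stand-in for a tree map. `simple_tree_map`, `MeanState`, and `is_state` are hypothetical names for illustration only; real code would pass State.isinstance as is_leaf to JAX's tree_map.

```python
from typing import Any, Callable, NamedTuple, Optional


class MeanState(NamedTuple):
    """Hypothetical stand-in state; being a tuple, it is itself a container."""

    total: float
    count: int

    def compute(self) -> float:
        return self.total / self.count


def is_state(x: Any) -> bool:
    # Plays the role that State.isinstance plays in the text above.
    return isinstance(x, MeanState)


def simple_tree_map(
    fn: Callable[[Any], Any],
    tree: Any,
    is_leaf: Optional[Callable[[Any], bool]] = None,
) -> Any:
    """Simplified stand-in for a tree map: recurses into dicts and tuples."""
    if is_leaf is not None and is_leaf(tree):
        return fn(tree)  # Stop here: treat the whole object as one leaf.
    if isinstance(tree, dict):
        return {k: simple_tree_map(fn, v, is_leaf) for k, v in tree.items()}
    if isinstance(tree, tuple):
        return tuple(simple_tree_map(fn, v, is_leaf) for v in tree)
    return fn(tree)


states = {"train": MeanState(6.0, 3), "eval": MeanState(1.0, 2)}
values = simple_tree_map(lambda s: s.compute(), states, is_leaf=is_state)
```

Without the predicate, the map would descend into each state and call the function on its individual fields, which have no compute method.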

replace(**updates)

Returns a new object replacing the specified fields with new values.
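
A minimal sketch of this behaviour, under the assumption that replace follows the same semantics as the standard library's dataclasses.replace. `CountState` is a hypothetical stand-in, not a Kauldron class.

```python
import dataclasses


@dataclasses.dataclass(frozen=True)
class CountState:
    """Hypothetical stand-in state (not a Kauldron class)."""

    correct: int
    total: int

    def replace(self, **updates) -> "CountState":
        # Assumption: same semantics as dataclasses.replace.
        return dataclasses.replace(self, **updates)


original = CountState(correct=3, total=10)
updated = original.replace(correct=4)  # total is carried over unchanged
```

Fields not named in the call keep their previous values, and the original object is left unchanged.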
