kd.metrics.sum_field

kauldron.metrics.sum_field(
    *,
    default: typing.Any = dataclasses.MISSING,
    **kwargs,
)

Define an AutoState data-field that is merged by summation (a + b).

The merge preserves the field's shape and assumes that the other (merged) field has the same shape.
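The merge rule can be sketched in plain NumPy. This is an illustration, not Kauldron code: merge_sum is a hypothetical helper, and it assumes that merging reduces to element-wise a + b. Unlike the documented field, which only assumes matching shapes, the sketch checks them explicitly.

```python
import numpy as np


def merge_sum(a: np.ndarray, b: np.ndarray) -> np.ndarray:
    """Hypothetical helper: merge two field values by element-wise a + b."""
    # Make the same-shape assumption explicit. Without this check NumPy would
    # broadcast mismatched shapes such as (3,) + (1,) without complaint.
    if a.shape != b.shape:
        raise ValueError(f"shape mismatch: {a.shape} vs {b.shape}")
    return a + b  # the result has the same shape as both inputs


merged = merge_sum(np.array([1.0, 2.0, 3.0]), np.array([0.5, 0.5, 0.5]))
print(merged.shape)  # (3,)
print(merged)        # [1.5 2.5 3.5]
```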

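Usage:

import flax.struct
# AutoState and Float import paths are assumed (kd.metrics / kd.typing).
from kauldron.metrics import AutoState, sum_field
from kauldron.typing import Float

@flax.struct.dataclass
class ShapePreservingAverage(AutoState):
  # Both fields are merged by element-wise summation.
  summed_values: Float['*any'] = sum_field()
  total_values: Float['*any'] = sum_field()

  def compute(self):
    return self.summed_values / self.total_values
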
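The usage example keeps a running sum and a running total, not a running average, because sums merge correctly by addition while averages do not. The NumPy sketch below reproduces that arithmetic without Kauldron. It assumes total_values holds per-position example counts, and the batch data and the partial_state helper are made up for illustration.

```python
import numpy as np

# Two hypothetical batches of per-position values, shape (num_examples, 2).
batch_a = np.array([[1.0, 10.0], [3.0, 30.0]])               # 2 examples
batch_b = np.array([[5.0, 50.0], [7.0, 70.0], [9.0, 90.0]])  # 3 examples


def partial_state(values):
    # Hypothetical per-batch state mirroring ShapePreservingAverage:
    # a sum over the batch axis and a count with the same shape.
    summed = values.sum(axis=0)
    total = np.full(values.shape[1:], float(values.shape[0]))
    return summed, total


sum_a, total_a = partial_state(batch_a)
sum_b, total_b = partial_state(batch_b)

# Merging with sum_field semantics adds each field element-wise ...
summed_values = sum_a + sum_b     # [25., 250.]
total_values = total_a + total_b  # [5., 5.]

# ... and compute() divides only once, at the end.
average = summed_values / total_values  # [5., 50.]

# Averaging the per-batch averages ignores the batch sizes.
naive = (batch_a.mean(axis=0) + batch_b.mean(axis=0)) / 2  # [4.5, 45.]
```

Dividing only in compute() yields the mean over all five examples, whereas the naive average of per-batch means is skewed whenever the batches differ in size.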
Parameters:
  • default – The default value of the field.

  • **kwargs – Additional keyword arguments passed through to dataclasses.field.

Returns:

A dataclasses.Field instance whose metadata marks the field as a JAX pytree node (pytree_node) and sets its merger to _ReduceSum().
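The return value can be illustrated with a standard-library sketch of the same idea: a field factory that forwards default and **kwargs to dataclasses.field and records a merger in the field's metadata, plus a generic merge that reads it back. All names here (ReduceSumSketch, sum_field_sketch, merge_states, Counter) are hypothetical, and Kauldron's actual metadata keys and merge code may differ.

```python
import dataclasses
from typing import Any


class ReduceSumSketch:
    """Hypothetical stand-in for the summation merger: returns a + b."""

    def merge(self, a: Any, b: Any) -> Any:
        return a + b


def sum_field_sketch(*, default: Any = dataclasses.MISSING, **kwargs) -> Any:
    # Forward `default` and any extra kwargs to dataclasses.field, and record
    # a pytree-node flag plus the merger in the field's metadata.
    metadata = dict(kwargs.pop("metadata", None) or {})
    metadata.update(pytree_node=True, merger=ReduceSumSketch())
    return dataclasses.field(default=default, metadata=metadata, **kwargs)


def merge_states(a, b):
    # Combine every field of two instances with the merger from its metadata.
    updates = {}
    for f in dataclasses.fields(a):
        merger = f.metadata["merger"]
        updates[f.name] = merger.merge(getattr(a, f.name), getattr(b, f.name))
    return dataclasses.replace(a, **updates)


@dataclasses.dataclass(frozen=True)
class Counter:
    total: float = sum_field_sketch(default=0.0)
    count: int = sum_field_sketch(default=0)


merged = merge_states(Counter(3.0, 2), Counter(4.5, 1))
print(merged)  # Counter(total=7.5, count=3)
```

Storing the merger on the field, rather than in each state class, lets a single generic merge routine handle any state that declares its fields this way.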

Morty Proxy This is a proxified and sanitized view of the page, visit original site.