kd.metrics.sum_field
kauldron.metrics.sum_field(*, default: typing.Any = dataclasses.MISSING, **kwargs)
Define an AutoState data field that is merged by summation (a + b).
The merged value keeps the field's shape; the field being merged in is assumed to have the same shape.
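The merge rule can be illustrated with plain NumPy arrays. This is a minimal sketch, not Kauldron's implementation: `merge_by_sum` is a hypothetical helper, and it checks shapes explicitly to make the "same shape" assumption visible. The documentation does not say whether the real merger validates shapes or would simply broadcast.

```python
import numpy as np


def merge_by_sum(a: np.ndarray, b: np.ndarray) -> np.ndarray:
    """Merge two partial values of a summed field (hypothetical helper).

    Mirrors the documented rule: the result is a + b, and both inputs are
    expected to have the same shape, so the merged value keeps that shape.
    """
    a = np.asarray(a)
    b = np.asarray(b)
    if a.shape != b.shape:
        # The documented merger assumes equal shapes; this sketch fails loudly
        # instead of silently broadcasting.
        raise ValueError(f"shape mismatch: {a.shape} vs {b.shape}")
    return a + b


batch1 = np.array([[1.0, 2.0], [3.0, 4.0]])
batch2 = np.array([[10.0, 20.0], [30.0, 40.0]])
merged = merge_by_sum(batch1, batch2)
print(merged.shape)  # (2, 2)
```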
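Usage:

    import flax.struct
    from kauldron.metrics import AutoState, sum_field
    from kauldron.typing import Float  # Array type annotation, e.g. Float['*any']

    @flax.struct.dataclass
    class ShapePreservingAverage(AutoState):
      summed_values: Float['*any'] = sum_field()
      total_values: Float['*any'] = sum_field()

      def compute(self):
        return self.summed_values / self.total_values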
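The reason both fields of the example use summation is that an average over several batches must be computed as (sum of sums) / (sum of counts). The sketch below reproduces that pattern without Kauldron: `ShapePreservingAverageSketch` is a hypothetical stand-in built on a frozen standard-library dataclass rather than `flax.struct.dataclass`/`AutoState`, and its hand-written `merge` mimics what two summed fields would do.

```python
import dataclasses

import numpy as np


@dataclasses.dataclass(frozen=True)
class ShapePreservingAverageSketch:
    """Plain-Python stand-in for the ShapePreservingAverage example.

    Not a Kauldron class: both fields are merged by summation, mimicking
    what sum_field() declares.
    """

    summed_values: np.ndarray
    total_values: np.ndarray

    def merge(self, other: "ShapePreservingAverageSketch") -> "ShapePreservingAverageSketch":
        # Every field is a summed field, so merging is elementwise addition.
        return ShapePreservingAverageSketch(
            summed_values=self.summed_values + other.summed_values,
            total_values=self.total_values + other.total_values,
        )

    def compute(self) -> np.ndarray:
        return self.summed_values / self.total_values


# Two batches of per-position values: b1 has 2 rows, b2 has 1 row.
b1 = np.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
b2 = np.array([[7.0, 8.0, 9.0]])
# Each partial state stores a per-position sum and a per-position count.
s1 = ShapePreservingAverageSketch(b1.sum(axis=0), np.full(3, b1.shape[0], dtype=float))
s2 = ShapePreservingAverageSketch(b2.sum(axis=0), np.full(3, b2.shape[0], dtype=float))
avg = s1.merge(s2).compute()
print(avg)  # [4. 5. 6.]
```

Averaging the per-batch averages instead would give [4.75 5.75 6.75] here, because the batches have different sizes; keeping sums and counts as separate summed fields avoids that error.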
Parameters:
default – The default value of the field.
**kwargs – Additional arguments passed to dataclasses.field.
Returns:
A dataclasses.Field instance whose metadata marks the field as a JAX pytree node (pytree_node) and sets its merger to _ReduceSum().