kd.losses.L2

class kauldron.losses.L2(
*,
step: kontext.Key = 'step',
mask: kontext.Key | None = None,
weight: int | float | Schedule = 1.0,
normalize_by: Literal['mask', 'values'] = 'mask',
preds: kontext.Key = '__KEY_REQUIRED__',
targets: kontext.Key = '__KEY_REQUIRED__',
)

Bases: kauldron.losses.base.Loss

L2 loss.

preds: kontext.Key = '__KEY_REQUIRED__'
targets: kontext.Key = '__KEY_REQUIRED__'
get_values(
preds: jaxtyping.Float[Array, '*a'] | jaxtyping.Float[ndarray, '*a'],
targets: jaxtyping.Float[Array, '*a'] | jaxtyping.Float[ndarray, '*a'],
) -> jaxtyping.Float[Array, '*a'] | jaxtyping.Float[ndarray, '*a']

Compute the loss values (before masking, averaging and weighting).

Subclasses need to implement this method.

Parameters:
  • *args – Any required arguments (names should match kontext.Key annotations).

  • **kwargs – Any arguments (names should match kontext.Key annotations).

Returns:

A jnp.Array of loss values compatible in shape with any desired masking.
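As an illustration of the shape contract, the sketch below computes per-element values the way an L2 loss is commonly defined: the squared difference between predictions and targets. The formula is an assumption, not taken from the kauldron source. The helper name `l2_values` is hypothetical, and NumPy stands in for `jax.numpy`.

```python
import numpy as np


def l2_values(preds: np.ndarray, targets: np.ndarray) -> np.ndarray:
    """Hypothetical stand-in for get_values: element-wise squared error."""
    # Assumption: "L2 loss" means (preds - targets) ** 2 per element.
    # No reduction happens here, so the output keeps the input shape '*a'.
    return np.square(preds - targets)


preds = np.array([[1.0, 2.0], [3.0, 5.0]])
targets = np.array([[1.0, 0.0], [2.0, 5.0]])
values = l2_values(preds, targets)  # same shape as preds and targets
```

Because the values are left unreduced, a mask of a compatible shape can still be applied afterwards, which is what the masking, averaging and weighting step relies on.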

empty() -> kauldron.metrics.base.Metric.State
get_state(
*args,
mask: jaxtyping.Shaped[Array, '...'] | jaxtyping.Shaped[ndarray, '...'] | None = None,
step: int | None = None,
**kwargs,
) -> kauldron.losses.base.AllReduceMean

Compute the loss state, taking care of masking and loss weighting.

The Loss.State is AllReduceMean by default, which keeps track of a single scalar loss value and ensures correct averaging even when masks are used.

Parameters:
  • *args – Positional arguments to be passed on to get_values.

  • mask – An optional mask that excludes some of the loss values from the total. Its shape must be broadcastable to the shape of the values returned by get_values. A mask value of 1 includes the corresponding loss value; 0 excludes it.

  • step – The current step to be used to compute the loss-weight if self.weight is set to a schedule. Otherwise step is ignored.

  • **kwargs – Keyword arguments to be passed on to get_values.

Returns:

An instance of Loss.State (AllReduceMean by default), which keeps track of a single scalar loss value and ensures correct averaging even when masks are used. The final loss value is obtained by calling state.compute(). Before that, the state can optionally be reduced (to remove the device dimension after pmap) or merged with other (previous) loss states.
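To make the masking, weighting and merging behaviour concrete, here is a minimal sketch of the idea. It is not kauldron's implementation. `MeanState` and `get_state_sketch` are hypothetical names, a schedule is modelled as a plain callable of the step, and NumPy stands in for `jax.numpy`. The sketch assumes the state stores a running sum and a count, so that merged states average over elements rather than over batches.

```python
import dataclasses

import numpy as np


@dataclasses.dataclass(frozen=True)
class MeanState:
    """Hypothetical stand-in for a mean-tracking loss state."""

    total: float  # sum of the included (masked, weighted) loss values
    count: float  # number of included elements

    def merge(self, other: "MeanState") -> "MeanState":
        # Numerators and denominators are added separately, so the result
        # is the mean over all included elements of both states.
        return MeanState(self.total + other.total, self.count + other.count)

    def compute(self) -> float:
        return self.total / self.count


def get_state_sketch(values, mask=None, weight=1.0, step=None) -> MeanState:
    """Builds a state from per-element values, an optional mask and a weight."""
    values = np.asarray(values, dtype=float)
    if mask is None:
        mask = np.ones_like(values)
    # The mask only needs to be broadcastable to the shape of the values.
    mask = np.broadcast_to(np.asarray(mask, dtype=float), values.shape)
    # Assumption: a schedule is a callable of step; otherwise step is ignored.
    w = weight(step) if callable(weight) else weight
    total = float(w * np.sum(values * mask))
    count = float(np.sum(mask))
    return MeanState(total=total, count=count)


values_a = np.array([[0.0, 4.0], [1.0, 0.0]])
mask_a = np.array([[1.0], [0.0]])  # keeps row 0, drops row 1
state_a = get_state_sketch(values_a, mask=mask_a)
state_b = get_state_sketch(np.array([1.0, 1.0, 1.0, 1.0]))
merged = state_a.merge(state_b)
```

In this sketch `state_a` averages over the two unmasked elements only, and `merged.compute()` weights each state by its element count instead of taking a mean of the two per-state means.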