kd.knn.Dense

kd.knn.Dense

class kauldron.klinen.Dense(features: int, use_bias: bool = True, dtype: Union[str, type[Any], numpy.dtype, jax._src.typing.SupportsDType, Any, NoneType] = None, param_dtype: Union[str, type[Any], numpy.dtype, jax._src.typing.SupportsDType, Any] = <class 'jax.numpy.float32'>, precision: Union[NoneType, str, jax._src.lax.lax.Precision, tuple[str, str], tuple[jax._src.lax.lax.Precision, jax._src.lax.lax.Precision]] = None, kernel_init: Union[jax.nn.initializers.Initializer, collections.abc.Callable[..., Any]] = <function variance_scaling.<locals>.init>, bias_init: Union[jax.nn.initializers.Initializer, collections.abc.Callable[..., Any]] = <function zeros>, promote_dtype: flax.linen.linear.PromoteDtypeFn = <function promote_dtype>, dot_general: collections.abc.Callable[..., typing.Union[jax.Array, typing.Any]] | None = None, dot_general_cls: Any = None, parent: Union[flax.linen.module.Module, flax.core.scope.Scope, flax.linen.module._Sentinel, NoneType] = <flax.linen.module._Sentinel object>, name: Optional[str] = None, *, _kd_state: 'Optional[_ModuleState]' = None)

Bases: flax.linen.linear.Dense, kauldron.klinen.module.Module

name: str | None = None
parent: flax.linen.module.Module | flax.core.scope.Scope | flax.linen.module._Sentinel | None = None
scope: Scope | None = None
Note: this copy of the page was captured through the Morty proxy (a proxied and sanitized view); see the original documentation site for the authoritative version.