package nn
Provides building blocks for neural networks.
Notable types:
- nn.GenericModule is an abstraction over parametric functions
- nn.Optimizer is an abstraction over gradient based optimizers
- nn.LossFunction is an abstraction over loss functions; see the companion object for the implemented losses
- nn.SupervisedModel combines a module with a loss
Optimizers:
- nn.AdamW, nn.RAdam, nn.SGDW, nn.Shampoo, nn.Yogi implement gradient based optimizers (see Type Members below)
Modules facilitating composing other modules:
- nn.Sequential composes a homogeneous list of modules (analogous to List)
- nn.sequence composes a heterogeneous list of modules (analogous to tuples); see the sketch after this list
- nn.EitherModule composes two modules in a scala.Either
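To illustrate the distinction, a minimal sketch: the modules are taken as parameters rather than constructed here, and sequence.apply mirroring Seq2 is an assumption.

  // Sketch: homogeneous vs heterogeneous composition.
  def compose(embedding: Embedding, linear1: Linear, linear2: Linear) = {
    val homogeneous = Sequential(linear1, linear2)   // one member type, like List
    val heterogeneous = sequence(embedding, linear1) // mixed member types, like tuples
    (homogeneous, heterogeneous)
  }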
Examples of neural network building blocks, layers etc:
- nn.Linear implements W X + b, with parameters W and b and input X
- nn.BatchNorm, nn.LayerNorm implement batch and layer normalization
- nn.MLP is a factory of a multilayer perceptron architecture
Type Members
- case class AdamW(parameters: Seq[(STen, PTag)], weightDecay: OptimizerHyperparameter, learningRate: OptimizerHyperparameter = simple(0.001), beta1: OptimizerHyperparameter = simple(0.9), beta2: OptimizerHyperparameter = simple(0.999), eps: Double = 1e-8, clip0: Option[Double] = None, debias: Boolean = true, mixedPrecision: Boolean = false) extends Optimizer with Product with Serializable
- See also
https://arxiv.org/pdf/1711.05101.pdf Algorithm 2
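A construction sketch follows. Treating a module's state as (Constant, PTag) pairs and reading the underlying tensor via .value are assumptions (see the GenericModule example below), as are the hyperparameter values.

  // Sketch: hyperparameter values are arbitrary.
  def makeOptimizer(m: GenericModule[_, _]): Optimizer =
    AdamW(
      parameters = m.state.map { case (c, tag) => (c.value, tag) },
      weightDecay = simple(1e-4), // decoupled weight decay, per Algorithm 2 above
      learningRate = simple(1e-3)
    )

Where a hyperparameter should vary by parameter tag, DependentHyperparameter (below) can stand in for simple(...).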
- case class AdversarialTraining(eps: Double) extends LossCalculation[Variable] with Product with Serializable
- case class BatchNorm(weight: Constant, bias: Constant, runningMean: Constant, runningVar: Constant, training: Boolean, momentum: Double, eps: Double, forceTrain: Boolean, forceEval: Boolean, evalIfBatchSizeIsOne: Boolean) extends Module with Product with Serializable
- case class BatchNorm2D(weight: Constant, bias: Constant, runningMean: Constant, runningVar: Constant, training: Boolean, momentum: Double, eps: Double) extends Module with Product with Serializable
- case class Conv1D(weights: Constant, bias: Constant, stride: Long, padding: Long, dilation: Long, groups: Long) extends Module with Product with Serializable
- case class Conv2D(weights: Constant, bias: Constant, stride: Long, padding: Long, dilation: Long, groups: Long) extends Module with Product with Serializable
- case class Conv2DTransposed(weights: Constant, bias: Constant, stride: Long, padding: Long, dilation: Long) extends Module with Product with Serializable
- case class Debug(fun: (STen, Boolean, Boolean) => Unit) extends Module with Product with Serializable
- case class DependentHyperparameter(default: Double)(pf: PartialFunction[PTag, Double]) extends OptimizerHyperparameter with Product with Serializable
- case class Dropout(prob: Double, training: Boolean) extends Module with Product with Serializable
- case class EitherModule[A, B, M1 <: GenericModule[A, B], M2 <: GenericModule[A, B]](members: Either[M1 with GenericModule[A, B], M2 with GenericModule[A, B]]) extends GenericModule[A, B] with Product with Serializable
- case class Embedding(weights: Constant) extends Module with Product with Serializable
Learnable mapping from classes to dense vectors. Equivalent to L * W, where L is the n x C one-hot encoded matrix of the classes, * is matrix multiplication, and W is the C x dim dense matrix. W is learnable; L is never computed directly. C is the number of classes; n is the size of the batch.
Input is a long tensor with values in [0,C-1]. Input shape is arbitrary, (*). Output shape is (* x D) where D is the embedding dimension.
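A construction sketch, assuming initLinear from this package and const, STen.zeros, and STenOptions.l from lamp's autograd and tensor APIs; shapes are illustrative.

  // Sketch: C = 10 classes embedded into D = 4 dimensions.
  Scope.root { implicit scope =>
    val weights = initLinear(in = 10, out = 4, STenOptions.d) // C x D, learnable
    val embedding = Embedding(weights)
    // classes: long tensor with values in [0, 9]; here all zeros, shape (2 x 3)
    val classes = const(STen.zeros(List(2L, 3L), STenOptions.l))
    val vectors = embedding.forward(classes) // shape (2 x 3 x 4)
  }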
- case class FreeRunningRNN[T, M <: StatefulModule[Variable, Variable, T]](module: M with StatefulModule[Variable, Variable, T], timeSteps: Int) extends StatefulModule[Variable, Variable, T] with Product with Serializable
Wraps a (sequence x batch) long -> (sequence x batch x dim) double stateful module and runs it in greedy (argmax) generation mode over timeSteps steps.
- case class Fun(fun: (Scope) => (Variable) => Variable) extends Module with Product with Serializable
- case class GRU(weightXh: Constant, weightHh: Constant, weightXr: Constant, weightXz: Constant, weightHr: Constant, weightHz: Constant, biasR: Constant, biasZ: Constant, biasH: Constant) extends StatefulModule[Variable, Variable, Option[Variable]] with Product with Serializable
Inputs of size (sequence length * batch * in dim). Outputs of size (sequence length * batch * hidden dim).
- case class GenericFun[A, B](fun: (Scope) => (A) => B) extends GenericModule[A, B] with Product with Serializable
- trait GenericModule[A, B] extends AnyRef
Base type of modules.
Modules are functions of type (Seq[lamp.autograd.Constant], A) => B, where the Seq[lamp.autograd.Constant] arguments are optimizable parameters and A is a non-optimizable input. Modules provide a way to build composite functions while also keeping track of the parameter list of the composite function.
Example

  case object Weights extends LeafTag
  case object Bias extends LeafTag

  case class Linear(weights: Constant, bias: Option[Constant]) extends Module {

    override val state = List(
      weights -> Weights
    ) ++ bias.toList.map(b => (b, Bias))

    def forward[S: Sc](x: Variable): Variable = {
      val v = x.mm(weights)
      bias.map(_ + v).getOrElse(v)
    }
  }
Some other attributes of modules are attached by type classes, e.g. nn.TrainingMode and nn.Load.
- A: the argument type of the module
- B: the value type of the module
- See also
nn.Module is an alias for simple Variable => Variable modules
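A usage sketch for the Linear module defined in the example above; Scope.root, const, and STen.rand are assumed from lamp's core and autograd APIs, and shapes are illustrative.

  Scope.root { implicit scope =>
    val weights = initLinear(in = 3, out = 2, STenOptions.d)
    val linear = Linear(weights, bias = None)
    val x = const(STen.rand(List(4L, 3L), STenOptions.d)) // batch of 4 inputs
    val y = linear.forward(x)  // Variable of shape (4 x 2)
    linear.state.map(_._2)     // List(Weights): the tracked parameter tags
  }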
- trait InitState[M, C] extends AnyRef
Type class about how to initialize recurrent neural networks
- implicit class InitStateSyntax[M, C] extends AnyRef
- case class LSTM(weightXi: Constant, weightXf: Constant, weightXo: Constant, weightHi: Constant, weightHf: Constant, weightHo: Constant, weightXc: Constant, weightHc: Constant, biasI: Constant, biasF: Constant, biasO: Constant, biasC: Constant) extends StatefulModule[Variable, Variable, Option[(Variable, Variable)]] with Product with Serializable
Inputs of size (sequence length * batch * vocab). Outputs of size (sequence length * batch * output dim).
- case class LayerNorm(scale: Option[Constant], bias: Option[Constant], eps: Double, normalizedShape: List[Long]) extends Module with Product with Serializable
- trait LeafTag extends PTag
- trait LearningRateSchedule[State] extends AnyRef
- case class LiftedModule[M <: Module](mod: M with Module) extends StatefulModule[Variable, Variable, Unit] with Product with Serializable
- case class Linear(weights: Constant, bias: Option[Constant]) extends Module with Product with Serializable
- trait Load[M] extends AnyRef
Type class about how to load the contents of the state of modules from external tensors
- implicit class LoadSyntax[M] extends AnyRef
- trait LossCalculation[I] extends AnyRef
Loss and gradient calculation.
Takes samples, target, module, and loss function, and computes the loss and the gradients.
- trait LossFunction extends AnyRef
- case class MappedState[A, B, C, D, M <: StatefulModule[A, B, C]](statefulModule: M with StatefulModule[A, B, C], map: (C) => D) extends StatefulModule2[A, B, C, D] with Product with Serializable
- case class ModelWithOptimizer[I, M <: GenericModule[I, Variable]](model: SupervisedModel[I, M], optimizer: Optimizer) extends Product with Serializable
- type Module = GenericModule[Variable, Variable]
- case class MultiheadAttention(wQ: Constant, wK: Constant, wV: Constant, wO: Constant, dropout: Double, train: Boolean, numHeads: Int, linearized: Boolean, causalMask: Boolean) extends GenericModule[(Variable, Variable, Variable, Option[STen]), Variable] with Product with Serializable
Multi-head scaled dot product attention module.
Input: (query, key, value, maxLength) where
- query: batch x num queries x query dim
- key: batch x num k-v x key dim
- value: batch x num k-v x value dim
- maxLength: 1D or 2D long tensor for attention masking
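A calling sketch; the module and input variables are assumed constructed elsewhere, and the implicit Scope satisfying forward's Sc context follows the GenericModule convention above.

  def attend(
      mha: MultiheadAttention,
      query: Variable,        // batch x num queries x query dim
      key: Variable,          // batch x num k-v x key dim
      value: Variable,        // batch x num k-v x value dim
      maxLength: Option[STen] // long tensor for masking, or None to disable
  )(implicit scope: Scope): Variable =
    mha.forward((query, key, value, maxLength))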
- trait Optimizer extends AnyRef
- trait OptimizerHyperparameter extends AnyRef
- trait PTag extends AnyRef
A small trait to mark parameters for unique identification
- class PerturbedLossCalculation[I] extends LossCalculation[I]
Evaluates the gradient at the current point + eps, where eps is I * N(0, noiseLevel)
- case class RAdam(parameters: Seq[(STen, PTag)], weightDecay: OptimizerHyperparameter, learningRate: OptimizerHyperparameter = simple(0.001), beta1: OptimizerHyperparameter = simple(0.9), beta2: OptimizerHyperparameter = simple(0.999), eps: Double = 1e-8, clip0: Option[Double] = None) extends Optimizer with Product with Serializable
Rectified Adam optimizer algorithm
- case class RNN(weightXh: Constant, weightHh: Constant, biasH: Constant) extends StatefulModule[Variable, Variable, Option[Variable]] with Product with Serializable
Inputs of size (sequence length * batch * in dim). Outputs of size (sequence length * batch * hidden dim).
- case class Recursive[A, M <: GenericModule[A, A]](member: M with GenericModule[A, A], n: Int) extends GenericModule[A, A] with Product with Serializable
- case class ResidualModule[M <: Module](transform: M with Module) extends Module with Product with Serializable
- case class SGDW(parameters: Seq[(STen, PTag)], learningRate: OptimizerHyperparameter, weightDecay: OptimizerHyperparameter, momentum: Option[OptimizerHyperparameter] = None, clip0: Option[Double] = None) extends Optimizer with Product with Serializable
- case class Seq2[T1, T2, T3, M1 <: GenericModule[T1, T2], M2 <: GenericModule[T2, T3]](m1: M1 with GenericModule[T1, T2], m2: M2 with GenericModule[T2, T3]) extends GenericModule[T1, T3] with Product with Serializable
- case class Seq2Seq[S0, S1, M1 <: StatefulModule2[Variable, Variable, S0, S1], M2 <: StatefulModule[Variable, Variable, S1]](encoder: M1 with StatefulModule2[Variable, Variable, S0, S1], decoder: M2 with StatefulModule[Variable, Variable, S1]) extends StatefulModule2[(Variable, Variable), Variable, S0, S1] with Product with Serializable
- case class Seq3[T1, T2, T3, T4, M1 <: GenericModule[T1, T2], M2 <: GenericModule[T2, T3], M3 <: GenericModule[T3, T4]](m1: M1 with GenericModule[T1, T2], m2: M2 with GenericModule[T2, T3], m3: M3 with GenericModule[T3, T4]) extends GenericModule[T1, T4] with Product with Serializable
- case class Seq4[T1, T2, T3, T4, T5, M1 <: GenericModule[T1, T2], M2 <: GenericModule[T2, T3], M3 <: GenericModule[T3, T4], M4 <: GenericModule[T4, T5]](m1: M1 with GenericModule[T1, T2], m2: M2 with GenericModule[T2, T3], m3: M3 with GenericModule[T3, T4], m4: M4 with GenericModule[T4, T5]) extends GenericModule[T1, T5] with Product with Serializable
- case class Seq5[T1, T2, T3, T4, T5, T6, M1 <: GenericModule[T1, T2], M2 <: GenericModule[T2, T3], M3 <: GenericModule[T3, T4], M4 <: GenericModule[T4, T5], M5 <: GenericModule[T5, T6]](m1: M1 with GenericModule[T1, T2], m2: M2 with GenericModule[T2, T3], m3: M3 with GenericModule[T3, T4], m4: M4 with GenericModule[T4, T5], m5: M5 with GenericModule[T5, T6]) extends GenericModule[T1, T6] with Product with Serializable
- case class Seq6[T1, T2, T3, T4, T5, T6, T7, M1 <: GenericModule[T1, T2], M2 <: GenericModule[T2, T3], M3 <: GenericModule[T3, T4], M4 <: GenericModule[T4, T5], M5 <: GenericModule[T5, T6], M6 <: GenericModule[T6, T7]](m1: M1 with GenericModule[T1, T2], m2: M2 with GenericModule[T2, T3], m3: M3 with GenericModule[T3, T4], m4: M4 with GenericModule[T4, T5], m5: M5 with GenericModule[T5, T6], m6: M6 with GenericModule[T6, T7]) extends GenericModule[T1, T7] with Product with Serializable
- case class SeqLinear(weight: Constant, bias: Constant) extends Module with Product with Serializable
Inputs of size (sequence length * batch * in dim). Outputs of size (sequence length * batch * output dim). Applies a linear function to each time step.
- case class Sequential[A, M <: GenericModule[A, A]](members: M with GenericModule[A, A]*) extends GenericModule[A, A] with Product with Serializable
- case class Shampoo(parameters: Seq[(STen, PTag)], learningRate: OptimizerHyperparameter = simple(0.001), clip0: Option[Double] = None, eps: Double = 1e-4, diagonalThreshold: Int = 256, updatePreconditionerEveryNIterations: Int = 100, momentum: OptimizerHyperparameter = simple(0d)) extends Optimizer with Product with Serializable
- See also
https://arxiv.org/pdf/1802.09568.pdf Algorithm 1
- class SimpleLossCalculation[I] extends LossCalculation[I]
- type StatefulModule[A, B, C] = GenericModule[(A, C), (B, C)]
- type StatefulModule2[A, B, C, D] = GenericModule[(A, C), (B, D)]
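Unpacking the alias: a stateful module is an ordinary GenericModule that threads its state through each call. A toy sketch, assuming GenericModule's abstract members are state and forward as in the example shown under GenericModule above:

  // Sketch: a pass-through module that counts invocations in its state.
  case class Counter() extends GenericModule[(Variable, Int), (Variable, Int)] {
    override val state = Nil
    def forward[S: Sc](x: (Variable, Int)): (Variable, Int) = (x._1, x._2 + 1)
  }
  val counter: StatefulModule[Variable, Variable, Int] = Counter()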
- case class StatefulSeq2[T1, T2, T3, S1, S2, M1 <: StatefulModule[T1, T2, S1], M2 <: StatefulModule[T2, T3, S2]](m1: M1 with StatefulModule[T1, T2, S1], m2: M2 with StatefulModule[T2, T3, S2]) extends StatefulModule[T1, T3, (S1, S2)] with Product with Serializable
- case class StatefulSeq3[T1, T2, T3, T4, S1, S2, S3, M1 <: StatefulModule[T1, T2, S1], M2 <: StatefulModule[T2, T3, S2], M3 <: StatefulModule[T3, T4, S3]](m1: M1 with StatefulModule[T1, T2, S1], m2: M2 with StatefulModule[T2, T3, S2], m3: M3 with StatefulModule[T3, T4, S3]) extends StatefulModule[T1, T4, (S1, S2, S3)] with Product with Serializable
- case class StatefulSeq4[T1, T2, T3, T4, T5, S1, S2, S3, S4, M1 <: StatefulModule[T1, T2, S1], M2 <: StatefulModule[T2, T3, S2], M3 <: StatefulModule[T3, T4, S3], M4 <: StatefulModule[T4, T5, S4]](m1: M1 with StatefulModule[T1, T2, S1], m2: M2 with StatefulModule[T2, T3, S2], m3: M3 with StatefulModule[T3, T4, S3], m4: M4 with StatefulModule[T4, T5, S4]) extends StatefulModule[T1, T5, (S1, S2, S3, S4)] with Product with Serializable
- case class StatefulSeq5[T1, T2, T3, T4, T5, T6, S1, S2, S3, S4, S5, M1 <: StatefulModule[T1, T2, S1], M2 <: StatefulModule[T2, T3, S2], M3 <: StatefulModule[T3, T4, S3], M4 <: StatefulModule[T4, T5, S4], M5 <: StatefulModule[T5, T6, S5]](m1: M1 with StatefulModule[T1, T2, S1], m2: M2 with StatefulModule[T2, T3, S2], m3: M3 with StatefulModule[T3, T4, S3], m4: M4 with StatefulModule[T4, T5, S4], m5: M5 with StatefulModule[T5, T6, S5]) extends StatefulModule[T1, T6, (S1, S2, S3, S4, S5)] with Product with Serializable
- case class SupervisedModel[I, M <: GenericModule[I, Variable]](module: M with GenericModule[I, Variable], lossFunction: LossFunction, lossCalculation: LossCalculation[I] = new SimpleLossCalculation[I], printMemoryAllocations: Boolean = false)(implicit tm: TrainingMode[M]) extends Product with Serializable
- implicit class ToLift[M <: Module] extends AnyRef
- implicit class ToMappedState[A, B, C, M <: StatefulModule[A, B, C]] extends AnyRef
- implicit class ToUnlift[A, B, C, D, M <: StatefulModule2[A, B, C, D]] extends AnyRef
- implicit class ToWithInit[A, B, C, M <: StatefulModule[A, B, C]] extends AnyRef
- trait TrainingMode[M] extends AnyRef
Type class about how to switch a module into training or evaluation mode
- implicit class TrainingModeSyntax[M] extends AnyRef
- case class Transformer(encoder: TransformerEncoder, decoder: TransformerDecoder) extends GenericModule[(Variable, Variable, Option[STen], Option[STen]), Variable] with Product with Serializable
- case class TransformerDecoder(blocks: Seq[TransformerDecoderBlock]) extends GenericModule[(Variable, Variable, Option[STen]), Variable] with Product with Serializable
- case class TransformerDecoderBlock(attentionDecoderDecoder: MultiheadAttention, attentionEncoderDecoder: MultiheadAttention, layerNorm1: LayerNorm, layerNorm2: LayerNorm, layerNorm3: LayerNorm, layerNorm4: LayerNorm, w1: Constant, b1: Constant, w2: Constant, b2: Constant, dropout: Double, train: Boolean) extends GenericModule[(Variable, Variable, Option[STen]), Variable] with Product with Serializable
- case class TransformerEmbedding(embedding: Embedding, addPositionalEmbedding: Boolean, positionalEmbedding: Constant) extends GenericModule[Variable, Variable] with Product with Serializable
A module with positional and token embeddings.
Token embeddings are lookup embeddings. Positional embeddings are supplied as a constant; they are supposed to come from a fixed, unlearned derivation of the positions.
Token and positional embeddings are summed. Gradients are not computed for positionalEmbedding.
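A construction sketch; const and STen.rand are assumed lamp helpers, and the positional table here is random only for illustration, whereas a real one would be a fixed positional encoding.

  Scope.root { implicit scope =>
    val tokenWeights = initLinear(in = 1000, out = 8, STenOptions.d)  // vocab x dim
    val positional = const(STen.rand(List(512L, 8L), STenOptions.d)) // max len x dim
    val te = TransformerEmbedding(
      embedding = Embedding(tokenWeights),
      addPositionalEmbedding = true,
      positionalEmbedding = positional // no gradient is computed for this
    )
  }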
- case class TransformerEncoder(blocks: Seq[TransformerEncoderBlock]) extends GenericModule[(Variable, Option[STen]), Variable] with Product with Serializable
TransformerEncoder module.
Does *not* include initial embedding or position encoding.
Input is (data, maxLength) where data is a (batch, sequence, input dimension) double tensor and maxLength is a 1D or 2D long tensor used for attention masking.
Attention masking is implemented similarly to chapter 11.3.2.1 in d2l.ai v1.0.0-beta0. It supports unmasked attention, attention on variable length input, and left-to-right attention.
Output is (batch, sequence, output dimension)
- case class TransformerEncoderBlock(attention: MultiheadAttention, layerNorm1: LayerNorm, layerNorm2: LayerNorm, w1: Constant, b1: Constant, w2: Constant, b2: Constant, scale1: Constant, scale2: Constant, dropout: Double, train: Boolean, gptOrder: Boolean) extends GenericModule[(Variable, Option[STen]), Variable] with Product with Serializable
A single block of the transformer self attention encoder using GELU.
Input is (data, maxLength) where data is a (batch, sequence, input dimension) double tensor and maxLength is a 1D or 2D long tensor used for attention masking.
The order of operations depends on the gptOrder param (see the sketch after this entry). If gptOrder is true then:
- y = attention(norm(input)) + input
- result = mlp(norm(y)) + y
- Note that in this case there is no normalization at the end of the transformer. One may want to add one separately. This is how GPT2 is defined in Hugging Face or nanoGPT.
- Note that the residual connection has a path which does not flow through the normalization.
- A dimension-wise learnable scale parameter is added in each residual path.
If gptOrder is false then:
- y = norm(attention(input) + input)
- result = norm(mlp(y) + y)
- This follows chapter 11.7 in d2l.ai v1.0.0-beta0 (same as in https://arxiv.org/pdf/1706.03762.pdf).
- Note that the residual connection has a path which flows through the normalization.
Output is (batch, sequence, output dimension)
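The two orders written out as a pseudocode sketch: attention, mlp, norm1, and norm2 are hypothetical stand-ins passed as plain functions, dropout and the learnable residual scales of the real block are elided, and Variable arithmetic requiring an implicit Scope is assumed.

  def encoderBlock(
      input: Variable,
      attention: Variable => Variable,
      mlp: Variable => Variable,
      norm1: Variable => Variable,
      norm2: Variable => Variable,
      gptOrder: Boolean
  )(implicit scope: Scope): Variable =
    if (gptOrder) {
      val y = attention(norm1(input)) + input // residual path bypasses the norm
      mlp(norm2(y)) + y                       // no normalization at the very end
    } else {
      val y = norm1(attention(input) + input) // residual sum flows through the norm
      norm2(mlp(y) + y)
    }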
- case class UnliftedModule[A, B, C, D, M <: StatefulModule2[A, B, C, D]](statefulModule: M with StatefulModule2[A, B, C, D])(implicit init: InitState[M, C]) extends GenericModule[A, B] with Product with Serializable
- case class WeightNormLinear(weightsV: Constant, weightsG: Constant, bias: Option[Constant]) extends Module with Product with Serializable
- case class WithInit[A, B, C, M <: StatefulModule[A, B, C]](module: M with StatefulModule[A, B, C], init: C) extends StatefulModule[A, B, C] with Product with Serializable
- case class WrapFun[A, B, M <: GenericModule[A, B], O](module: M, fun: (A, B) => O) extends GenericModule[A, (B, O)] with Product with Serializable
- case class Yogi(parameters: Seq[(STen, PTag)], weightDecay: OptimizerHyperparameter, learningRate: OptimizerHyperparameter = simple(0.01), beta1: OptimizerHyperparameter = simple(0.9), beta2: OptimizerHyperparameter = simple(0.999), eps: Double = 1e-3, clip0: Option[Double] = None, debias: Boolean = true) extends Optimizer with Product with Serializable
The Yogi optimizer algorithm, with the decoupled weight decay term added following https://arxiv.org/pdf/1711.05101.pdf
- See also
https://papers.nips.cc/paper/2018/file/90365351ccc7437a1309dc64e4db32a3-Paper.pdf Algorithm 2
- case class simple(v: Double) extends OptimizerHyperparameter with Product with Serializable
Value Members
- def gradientClippingInPlace(gradients: Seq[Option[STen]], theta: STen): Unit
- def initLinear[S](in: Int, out: Int, tOpt: STenOptions)(implicit arg0: Sc[S]): Constant
- def loadMultiple[T1 <: GenericModule[_, _], T2 <: GenericModule[_, _], T3 <: GenericModule[_, _], T4 <: GenericModule[_, _], T5 <: GenericModule[_, _], T6 <: GenericModule[_, _], T7 <: GenericModule[_, _], T8 <: GenericModule[_, _], T9 <: GenericModule[_, _], T10 <: GenericModule[_, _], T11 <: GenericModule[_, _], T12 <: GenericModule[_, _]](t1: T1, t2: T2, t3: T3, t4: T4, t5: T5, t6: T6, t7: T7, t8: T8, t9: T9, t10: T10, t11: T11, t12: T12, tensors: Seq[STen])(implicit arg0: Load[T1], arg1: Load[T2], arg2: Load[T3], arg3: Load[T4], arg4: Load[T5], arg5: Load[T6], arg6: Load[T7], arg7: Load[T8], arg8: Load[T9], arg9: Load[T10], arg10: Load[T11], arg11: Load[T12]): Unit
- def loadMultiple[T1 <: GenericModule[_, _], T2 <: GenericModule[_, _], T3 <: GenericModule[_, _], T4 <: GenericModule[_, _], T5 <: GenericModule[_, _], T6 <: GenericModule[_, _], T7 <: GenericModule[_, _], T8 <: GenericModule[_, _], T9 <: GenericModule[_, _], T10 <: GenericModule[_, _], T11 <: GenericModule[_, _]](t1: T1, t2: T2, t3: T3, t4: T4, t5: T5, t6: T6, t7: T7, t8: T8, t9: T9, t10: T10, t11: T11, tensors: Seq[STen])(implicit arg0: Load[T1], arg1: Load[T2], arg2: Load[T3], arg3: Load[T4], arg4: Load[T5], arg5: Load[T6], arg6: Load[T7], arg7: Load[T8], arg8: Load[T9], arg9: Load[T10], arg10: Load[T11]): Unit
- def loadMultiple[T1 <: GenericModule[_, _], T2 <: GenericModule[_, _], T3 <: GenericModule[_, _], T4 <: GenericModule[_, _], T5 <: GenericModule[_, _], T6 <: GenericModule[_, _], T7 <: GenericModule[_, _], T8 <: GenericModule[_, _], T9 <: GenericModule[_, _], T10 <: GenericModule[_, _]](t1: T1, t2: T2, t3: T3, t4: T4, t5: T5, t6: T6, t7: T7, t8: T8, t9: T9, t10: T10, tensors: Seq[STen])(implicit arg0: Load[T1], arg1: Load[T2], arg2: Load[T3], arg3: Load[T4], arg4: Load[T5], arg5: Load[T6], arg6: Load[T7], arg7: Load[T8], arg8: Load[T9], arg9: Load[T10]): Unit
- def loadMultiple[T1 <: GenericModule[_, _], T2 <: GenericModule[_, _], T3 <: GenericModule[_, _], T4 <: GenericModule[_, _], T5 <: GenericModule[_, _], T6 <: GenericModule[_, _], T7 <: GenericModule[_, _], T8 <: GenericModule[_, _], T9 <: GenericModule[_, _]](t1: T1, t2: T2, t3: T3, t4: T4, t5: T5, t6: T6, t7: T7, t8: T8, t9: T9, tensors: Seq[STen])(implicit arg0: Load[T1], arg1: Load[T2], arg2: Load[T3], arg3: Load[T4], arg4: Load[T5], arg5: Load[T6], arg6: Load[T7], arg7: Load[T8], arg8: Load[T9]): Unit
- def loadMultiple[T1 <: GenericModule[_, _], T2 <: GenericModule[_, _], T3 <: GenericModule[_, _], T4 <: GenericModule[_, _], T5 <: GenericModule[_, _], T6 <: GenericModule[_, _], T7 <: GenericModule[_, _], T8 <: GenericModule[_, _]](t1: T1, t2: T2, t3: T3, t4: T4, t5: T5, t6: T6, t7: T7, t8: T8, tensors: Seq[STen])(implicit arg0: Load[T1], arg1: Load[T2], arg2: Load[T3], arg3: Load[T4], arg4: Load[T5], arg5: Load[T6], arg6: Load[T7], arg7: Load[T8]): Unit
- def loadMultiple[T1 <: GenericModule[_, _], T2 <: GenericModule[_, _], T3 <: GenericModule[_, _], T4 <: GenericModule[_, _], T5 <: GenericModule[_, _], T6 <: GenericModule[_, _], T7 <: GenericModule[_, _]](t1: T1, t2: T2, t3: T3, t4: T4, t5: T5, t6: T6, t7: T7, tensors: Seq[STen])(implicit arg0: Load[T1], arg1: Load[T2], arg2: Load[T3], arg3: Load[T4], arg4: Load[T5], arg5: Load[T6], arg6: Load[T7]): Unit
- def loadMultiple[T1 <: GenericModule[_, _], T2 <: GenericModule[_, _], T3 <: GenericModule[_, _], T4 <: GenericModule[_, _], T5 <: GenericModule[_, _], T6 <: GenericModule[_, _]](t1: T1, t2: T2, t3: T3, t4: T4, t5: T5, t6: T6, tensors: Seq[STen])(implicit arg0: Load[T1], arg1: Load[T2], arg2: Load[T3], arg3: Load[T4], arg4: Load[T5], arg5: Load[T6]): Unit
- def loadMultiple[T1 <: GenericModule[_, _], T2 <: GenericModule[_, _], T3 <: GenericModule[_, _], T4 <: GenericModule[_, _], T5 <: GenericModule[_, _]](t1: T1, t2: T2, t3: T3, t4: T4, t5: T5, tensors: Seq[STen])(implicit arg0: Load[T1], arg1: Load[T2], arg2: Load[T3], arg3: Load[T4], arg4: Load[T5]): Unit
- def loadMultiple[T1 <: GenericModule[_, _], T2 <: GenericModule[_, _], T3 <: GenericModule[_, _], T4 <: GenericModule[_, _]](t1: T1, t2: T2, t3: T3, t4: T4, tensors: Seq[STen])(implicit arg0: Load[T1], arg1: Load[T2], arg2: Load[T3], arg3: Load[T4]): Unit
- def loadMultiple[T1 <: GenericModule[_, _], T2 <: GenericModule[_, _], T3 <: GenericModule[_, _]](t1: T1, t2: T2, t3: T3, tensors: Seq[STen])(implicit arg0: Load[T1], arg1: Load[T2], arg2: Load[T3]): Unit
- def loadMultiple[T1 <: GenericModule[_, _], T2 <: GenericModule[_, _]](t1: T1, t2: T2, tensors: Seq[STen])(implicit arg0: Load[T1], arg1: Load[T2]): Unit
- object AdamW extends Serializable
- object BatchNorm extends Serializable
- object BatchNorm2D extends Serializable
- object Conv1D extends Serializable
- object Conv2D extends Serializable
- object Conv2DTransposed extends Serializable
- object Debug extends Serializable
- object Dropout extends Serializable
- object EitherModule extends Serializable
- object Embedding extends Serializable
- object FreeRunningRNN extends Serializable
- object Fun extends Serializable
- object GRU extends Serializable
- object GenericFun extends Serializable
- object GenericModule
- object InitState
- object LSTM extends Serializable
- object LayerNorm extends Serializable
- object LearningRateSchedule
- object LiftedModule extends Serializable
- object Linear extends Serializable
- object Load
- object LossFunctions
- object MLP
Factory for multilayer fully connected feed forward networks.
Returned network has the following repeated structure: [linear -> batchnorm -> nonlinearity -> dropout]*
The last block does not include the nonlinearity and the dropout.
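A hand-built sketch of the same shape of network, using building blocks from this package; batchnorm and dropout are elided, and the helpers const, STen.rand, and Variable.relu are assumptions. The real factory wires learnable Linear/BatchNorm/Dropout modules for you.

  Scope.root { implicit scope =>
    val net = Sequential(
      Linear(initLinear(in = 8, out = 16, STenOptions.d), bias = None),
      Fun(implicit scope => _.relu), // nonlinearity; batchnorm/dropout elided
      Linear(initLinear(in = 16, out = 4, STenOptions.d), bias = None) // last block: linear only
    )
    val out = net.forward(const(STen.rand(List(2L, 8L), STenOptions.d))) // shape (2 x 4)
  }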
- object MappedState extends Serializable
- object MultiheadAttention extends Serializable
- case object NoTag extends LeafTag with Product with Serializable
- object PTag
- object PositionalEmbedding
- object RAdam extends Serializable
- object RNN extends Serializable
- object Recursive extends Serializable
- object ResidualModule extends Serializable
- object SGDW extends Serializable
- object Seq2 extends Serializable
- object Seq2Seq extends Serializable
- object Seq3 extends Serializable
- object Seq4 extends Serializable
- object Seq5 extends Serializable
- object Seq6 extends Serializable
- object SeqLinear extends Serializable
- object Sequential extends Serializable
- object Shampoo extends Serializable
- object StatefulSeq2 extends Serializable
- object StatefulSeq3 extends Serializable
- object StatefulSeq4 extends Serializable
- object StatefulSeq5 extends Serializable
- object TrainingMode
- object Transformer extends Serializable
- object TransformerDecoder extends Serializable
- object TransformerDecoderBlock extends Serializable
- object TransformerEmbedding extends Serializable
- object TransformerEncoder extends Serializable
- object TransformerEncoderBlock extends Serializable
- object UnliftedModule extends Serializable
- object WeightNormLinear extends Serializable
- object WithInit extends Serializable
- object WrapFun extends Serializable
- object Yogi extends Serializable
- object sequence
- object statefulSequence