case class MaskedLanguageModelModule(mlp: MaskedLanguageModelModule.MLP) extends GenericModule[(Variable, STen), Variable] with Product with Serializable
Masked language model.
Input is a pair (embedding, positions):
- embedding: size (batch, num tokens, embedding dim)
- positions: long tensor of size (batch, max num tokens) indicating which positions to make predictions on
Output: size (batch, len(positions), vocabulary size)
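The shape contract above can be illustrated with a small, dependency-free sketch. It is not lamp's implementation: nested arrays stand in for Variable and STen, and the single linear projection w (embedding dim × vocabulary size) is a hypothetical stand-in for the module's MLP. For each batch element, it gathers the embedding at each requested position and projects it to vocabulary-sized logits.

```scala
// Hypothetical, dependency-free sketch of the shapes documented above.
// Nested arrays stand in for lamp's Variable/STen; the projection `w`
// is an assumed stand-in for MaskedLanguageModelModule.MLP.
object MlmShapeSketch {
  type Matrix = Array[Array[Double]]

  /** embedding: (batch)(numTokens)(embeddingDim)
    * positions: (batch)(numPositions) indices of tokens to predict
    * w:         (embeddingDim)(vocabularySize) assumed projection weights
    * returns:   (batch)(numPositions)(vocabularySize) logits
    */
  def predict(
      embedding: Array[Matrix],
      positions: Array[Array[Long]],
      w: Matrix
  ): Array[Matrix] =
    embedding.zip(positions).map { case (tokens, pos) =>
      pos.map { p =>
        // gather the embedding vector of the token at position p
        val e = tokens(p.toInt)
        // project embeddingDim -> vocabularySize with a dot product per column
        w.head.indices.map { v =>
          e.indices.map(d => e(d) * w(d)(v)).sum
        }.toArray
      }
    }
}
```

With one sequence of 3 tokens, embedding dim 2, two requested positions and a vocabulary of 3, the result has shape (1, 2, 3), matching (batch, len(positions), vocabulary size).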
Linear supertypes: Serializable, Product, Equals, GenericModule, AnyRef, Any
Instance Constructors
- new MaskedLanguageModelModule(mlp: MaskedLanguageModelModule.MLP)
Value Members
- final def !=(arg0: Any): Boolean
- Definition Classes
- AnyRef → Any
- final def ##: Int
- Definition Classes
- AnyRef → Any
- final def ==(arg0: Any): Boolean
- Definition Classes
- AnyRef → Any
- def apply[S](a: (Variable, STen))(implicit arg0: Sc[S]): Variable
Alias of forward
- Definition Classes
- GenericModule
- final def asInstanceOf[T0]: T0
- Definition Classes
- Any
- def clone(): AnyRef
- Attributes
- protected[lang]
- Definition Classes
- AnyRef
- Annotations
- @throws(classOf[java.lang.CloneNotSupportedException]) @IntrinsicCandidate() @native()
- final def eq(arg0: AnyRef): Boolean
- Definition Classes
- AnyRef
- def forward[S](x: (Variable, STen))(implicit arg0: Sc[S]): Variable
The implementation of the function.
In addition to x, it can also use all the state to compute its value.
- Definition Classes
- MaskedLanguageModelModule → GenericModule
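The relationship between apply and forward can be shown with a hypothetical miniature of the module contract. ToyModule and AddBias are illustrative names, not lamp types, and the implicit scope parameter (Sc[S]) is omitted for brevity. apply only delegates to forward, while forward may read the module's state (here a bias) in addition to its input.

```scala
// Hypothetical miniature of the GenericModule contract (not lamp's trait):
// apply is an alias that delegates to forward.
trait ToyModule[A, B] {
  def forward(x: A): B
  final def apply(x: A): B = forward(x)
}

// forward uses the module's state (the bias) in addition to the input x.
final case class AddBias(bias: Double) extends ToyModule[Double, Double] {
  def forward(x: Double): Double = x + bias
}
```

Calling AddBias(0.5)(1.0) and AddBias(0.5).forward(1.0) both yield 1.5.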
- final def getClass(): Class[_ <: AnyRef]
- Definition Classes
- AnyRef → Any
- Annotations
- @IntrinsicCandidate() @native()
- final def gradients(loss: Variable, zeroGrad: Boolean = true): Seq[Option[STen]]
Computes the gradient of loss with respect to the parameters.
- Definition Classes
- GenericModule
- final def isInstanceOf[T0]: Boolean
- Definition Classes
- Any
- final def learnableParameters: Long
Returns the total number of optimizable parameters.
- Definition Classes
- GenericModule
- val mlp: MaskedLanguageModelModule.MLP
- final def ne(arg0: AnyRef): Boolean
- Definition Classes
- AnyRef
- final def notify(): Unit
- Definition Classes
- AnyRef
- Annotations
- @IntrinsicCandidate() @native()
- final def notifyAll(): Unit
- Definition Classes
- AnyRef
- Annotations
- @IntrinsicCandidate() @native()
- final def parameters: Seq[(Constant, PTag)]
Returns the state variables which need gradient computation.
- Definition Classes
- GenericModule
- def productElementNames: Iterator[String]
- Definition Classes
- Product
- def state: Seq[(Constant, PTag)]
List of optimizable parameters, or of non-optimizable but stateful parameters.
Stateful means that the state is carried over the repeated forward calls.
- Definition Classes
- MaskedLanguageModelModule → GenericModule
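A sketch of how state, parameters and learnableParameters relate, based only on the descriptions on this page: parameters are the state entries that need gradient computation, and learnableParameters counts their elements. ToyTensor and ToyModuleState are hypothetical stand-ins for lamp's (Constant, PTag) pairs and are not part of the library.

```scala
// Hypothetical stand-in for a (Constant, PTag) entry: a tensor shape plus
// a flag saying whether it needs gradient computation.
final case class ToyTensor(shape: Seq[Int], needsGrad: Boolean) {
  // number of scalar elements in the tensor
  def numel: Long = shape.foldLeft(1L)(_ * _)
}

final case class ToyModuleState(state: Seq[ToyTensor]) {
  // parameters: the state entries that need gradient computation
  def parameters: Seq[ToyTensor] = state.filter(_.needsGrad)
  // learnableParameters: total number of optimizable scalars
  def learnableParameters: Long = parameters.map(_.numel).sum
}
```

For a 4×3 weight and a length-3 bias that both need gradients, plus a length-3 non-optimizable stateful tensor, state has 3 entries, parameters has 2, and learnableParameters is 15.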
- final def synchronized[T0](arg0: => T0): T0
- Definition Classes
- AnyRef
- final def wait(arg0: Long, arg1: Int): Unit
- Definition Classes
- AnyRef
- Annotations
- @throws(classOf[java.lang.InterruptedException])
- final def wait(arg0: Long): Unit
- Definition Classes
- AnyRef
- Annotations
- @throws(classOf[java.lang.InterruptedException]) @native()
- final def wait(): Unit
- Definition Classes
- AnyRef
- Annotations
- @throws(classOf[java.lang.InterruptedException])
- final def zeroGrad(): Unit
- Definition Classes
- GenericModule
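The zeroGrad flag on gradients matters because autograd backends in the libtorch style accumulate gradients across backward passes. The toy below assumes that zeroGrad = true clears stored gradients before the new backward pass; this page does not state the exact ordering. ToyParam and ToyGrad are hypothetical names, and the "loss" is simply value * x, whose gradient with respect to value is x.

```scala
// Toy parameter whose gradient accumulates across backward passes
// unless it is explicitly reset.
final class ToyParam {
  var grad: Double = 0.0
  def zeroGrad(): Unit = { grad = 0.0 }
  // backward pass of loss = value * x adds d(loss)/d(value) = x
  def backward(x: Double): Unit = { grad += x }
}

object ToyGrad {
  // Assumed semantics: zeroGrad = true resets gradients before backward.
  def gradients(p: ToyParam, x: Double, zeroGrad: Boolean = true): Double = {
    if (zeroGrad) p.zeroGrad()
    p.backward(x)
    p.grad
  }
}
```

Repeated calls with the default flag return the fresh gradient each time; passing zeroGrad = false sums the new gradient onto the previous one.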
Deprecated Value Members
- def finalize(): Unit
- Attributes
- protected[lang]
- Definition Classes
- AnyRef
- Annotations
- @throws(classOf[java.lang.Throwable]) @Deprecated
- Deprecated
(Since version 9)