case class TransformerEncoderBlock(attention: MultiheadAttention, layerNorm1: LayerNorm, layerNorm2: LayerNorm, w1: Constant, b1: Constant, w2: Constant, b2: Constant, scale1: Constant, scale2: Constant, dropout: Double, train: Boolean, gptOrder: Boolean) extends GenericModule[(Variable, Option[STen]), Variable] with Product with Serializable
A single block of the transformer self-attention encoder using GELU.
Input is (data, maxLength) where data is (batch, sequence, input dimension), a double tensor, and maxLength is a 1D or 2D long tensor used for attention masking.
The order of operations depends on the gptOrder parameter.
If gptOrder is true then:
- y = attention(norm(input)) + input
- result = mlp(norm(y)) + y
- Note that in this case there is no normalization at the end of the transformer. One may want to add one separately. This is how GPT-2 is defined in Hugging Face or nanoGPT.
- Note that the residual connection has a path which does not flow through the normalization.
- A dimension-wise learnable scale parameter is applied in each residual path (see the sketch below).
If gptOrder is false then:
- y = norm(attention(input) + input)
- result = norm(mlp(y) + y)
- This follows chapter 11.7 of d2l.ai v1.0.0-beta0 (the same order as in https://arxiv.org/pdf/1706.03762.pdf).
- Note that the residual connection has a path which flows through the normalization.
Output is (batch, sequence, output dimension).
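The two orderings correspond to the pre-norm and post-norm transformer variants. The following is a minimal sketch for illustration only, not the library's implementation; attn, norm1, norm2 and mlp stand in for the block's sub-modules, and the exact placement of the scale1/scale2 factors inside the residual paths is an assumption based on the list above.

```scala
import lamp.Scope
import lamp.autograd.Variable

// Sketch of the two residual orderings; not the actual implementation.
def encoderBlockSketch(
    input: Variable,
    gptOrder: Boolean,
    attn: Variable => Variable,  // stand-in for masked multihead attention
    norm1: Variable => Variable, // stand-in for layerNorm1
    norm2: Variable => Variable, // stand-in for layerNorm2
    mlp: Variable => Variable,   // stand-in for the GELU MLP (w1, b1, w2, b2)
    scale1: Variable,            // learnable scale; placement is assumed
    scale2: Variable
)(implicit scope: Scope): Variable =
  if (gptOrder) {
    // pre-norm (GPT-2 style): the residual path bypasses the normalization
    val y = attn(norm1(input)) * scale1 + input
    mlp(norm2(y)) * scale2 + y
  } else {
    // post-norm (Vaswani et al., 2017): the residual flows through the norm
    val y = norm1(attn(input) + input)
    norm2(mlp(y) + y)
  }
```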
Inheritance
- TransformerEncoderBlock
- Serializable
- Product
- Equals
- GenericModule
- AnyRef
- Any
Instance Constructors
- new TransformerEncoderBlock(attention: MultiheadAttention, layerNorm1: LayerNorm, layerNorm2: LayerNorm, w1: Constant, b1: Constant, w2: Constant, b2: Constant, scale1: Constant, scale2: Constant, dropout: Double, train: Boolean, gptOrder: Boolean)
Value Members
- final def !=(arg0: Any): Boolean
- Definition Classes
- AnyRef → Any
- final def ##: Int
- Definition Classes
- AnyRef → Any
- final def ==(arg0: Any): Boolean
- Definition Classes
- AnyRef → Any
- def apply[S](a: (Variable, Option[STen]))(implicit arg0: Sc[S]): Variable
Alias of forward
- Definition Classes
- GenericModule
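A minimal usage sketch, assuming an already constructed block and inputs (construction of the attention and normalization sub-modules is omitted):

```scala
import lamp._
import lamp.autograd.Variable
import lamp.nn.TransformerEncoderBlock

Scope.root { implicit scope =>
  val block: TransformerEncoderBlock = ??? // construction omitted
  val data: Variable = ???           // (batch, sequence, input dim), double
  val maxLength: Option[STen] = None // or Some(lengths) to mask attention

  // apply delegates to forward
  val out: Variable = block((data, maxLength))
  // out has shape (batch, sequence, output dimension)
}
```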
- final def asInstanceOf[T0]: T0
- Definition Classes
- Any
- val attention: MultiheadAttention
- val b1: Constant
- val b2: Constant
- def clone(): AnyRef
- Attributes
- protected[lang]
- Definition Classes
- AnyRef
- Annotations
- @throws(classOf[java.lang.CloneNotSupportedException]) @IntrinsicCandidate() @native()
- val dropout: Double
- final def eq(arg0: AnyRef): Boolean
- Definition Classes
- AnyRef
- def forward[S](x: (Variable, Option[STen]))(implicit arg0: Sc[S]): Variable
The implementation of the function.
In addition to x it can also use all the state to compute its value.
- Definition Classes
- TransformerEncoderBlock → GenericModule
- final def getClass(): Class[_ <: AnyRef]
- Definition Classes
- AnyRef → Any
- Annotations
- @IntrinsicCandidate() @native()
- val gptOrder: Boolean
- final def gradients(loss: Variable, zeroGrad: Boolean = true): Seq[Option[STen]]
Computes the gradient of loss with respect to the parameters.
- Definition Classes
- GenericModule
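A sketch of a manual backward pass, with block, data and maxLength as in the apply example above and a stand-in scalar loss:

```scala
Scope.root { implicit scope =>
  val out  = block((data, maxLength))
  val loss = out.sum // stand-in scalar reduction
  // zeroGrad = true (the default) clears previously accumulated gradients;
  // the result holds one optional gradient tensor per parameter
  val grads: Seq[Option[STen]] = block.gradients(loss)
}
```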
- final def isInstanceOf[T0]: Boolean
- Definition Classes
- Any
- val layerNorm1: LayerNorm
- val layerNorm2: LayerNorm
- final def learnableParameters: Long
Returns the total number of optimizable parameters.
- Definition Classes
- GenericModule
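For example:

```scala
// Total number of scalars across all optimizable parameter tensors
println(s"learnable parameters: ${block.learnableParameters}")
```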
- final def ne(arg0: AnyRef): Boolean
- Definition Classes
- AnyRef
- final def notify(): Unit
- Definition Classes
- AnyRef
- Annotations
- @IntrinsicCandidate() @native()
- final def notifyAll(): Unit
- Definition Classes
- AnyRef
- Annotations
- @IntrinsicCandidate() @native()
- final def parameters: Seq[(Constant, PTag)]
Returns the state variables which need gradient computation.
- Definition Classes
- GenericModule
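A short sketch of inspecting the parameter list (param.value is the underlying STen):

```scala
// Enumerate the parameter tensors and their tags
block.parameters.foreach { case (param, tag) =>
  println(s"$tag: ${param.value.shape}")
}
```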
- def productElementNames: Iterator[String]
- Definition Classes
- Product
- val scale1: Constant
- val scale2: Constant
- def state: List[(Constant, LeafTag)]
List of parameters that are either optimizable, or non-optimizable but stateful.
Stateful means that the state is carried over repeated forward calls.
- Definition Classes
- TransformerEncoderBlock → GenericModule
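For this block the state covers the attention sub-module, both layer norms, the MLP weights and biases (w1, b1, w2, b2) and the residual scales (scale1, scale2); parameters is the subset of these that needs gradient computation. For example:

```scala
// The full stateful tensor list; `parameters` is the subset of these
// that needs gradient computation
val st: List[(Constant, LeafTag)] = block.state
st.foreach { case (_, tag) => println(tag) }
```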
- final def synchronized[T0](arg0: => T0): T0
- Definition Classes
- AnyRef
- val train: Boolean
- val w1: Constant
- val w2: Constant
- final def wait(arg0: Long, arg1: Int): Unit
- Definition Classes
- AnyRef
- Annotations
- @throws(classOf[java.lang.InterruptedException])
- final def wait(arg0: Long): Unit
- Definition Classes
- AnyRef
- Annotations
- @throws(classOf[java.lang.InterruptedException]) @native()
- final def wait(): Unit
- Definition Classes
- AnyRef
- Annotations
- @throws(classOf[java.lang.InterruptedException])
- final def zeroGrad(): Unit
- Definition Classes
- GenericModule
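For example, when accumulating gradients over several batches one might call gradients(loss, zeroGrad = false) and reset the accumulators explicitly:

```scala
// Clear the accumulated gradients of all parameters
block.zeroGrad()
```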
Deprecated Value Members
- def finalize(): Unit
- Attributes
- protected[lang]
- Definition Classes
- AnyRef
- Annotations
- @throws(classOf[java.lang.Throwable]) @Deprecated
- Deprecated
(Since version 9)