object MultiheadAttention extends Serializable
 
Value Members
-   final  def !=(arg0: Any): Boolean
- Definition Classes
 - AnyRef → Any
 
 -   final  def ##: Int
- Definition Classes
 - AnyRef → Any
 
 -   final  def ==(arg0: Any): Boolean
- Definition Classes
 - AnyRef → Any
 
 -  def apply[S](dQ: Int, dK: Int, dV: Int, hiddenPerHead: Int, out: Int, dropout: Double, numHeads: Int, tOpt: STenOptions, linearized: Boolean, causalMask: Boolean)(implicit arg0: Sc[S]): MultiheadAttention
 -   final  def asInstanceOf[T0]: T0
- Definition Classes
 - Any
 
 -    def clone(): AnyRef
- Attributes
 - protected[lang]
 - Definition Classes
 - AnyRef
 - Annotations
 - @throws(classOf[java.lang.CloneNotSupportedException]) @IntrinsicCandidate() @native()
 
 -   final  def eq(arg0: AnyRef): Boolean
- Definition Classes
 - AnyRef
 
 -    def equals(arg0: AnyRef): Boolean
- Definition Classes
 - AnyRef → Any
 
 -   final  def getClass(): Class[_ <: AnyRef]
- Definition Classes
 - AnyRef → Any
 - Annotations
 - @IntrinsicCandidate() @native()
 
 -    def hashCode(): Int
- Definition Classes
 - AnyRef → Any
 - Annotations
 - @IntrinsicCandidate() @native()
 
 -   final  def isInstanceOf[T0]: Boolean
- Definition Classes
 - Any
 
 -    def linearizedAttention[S](query: Variable, keys: Variable, values: Variable, maxLength: Option[STen], dropout: Double, trainDropout: Boolean)(implicit arg0: Sc[S]): Variable
Linearized dot product attention https://arxiv.org/pdf/2006.16236.pdf
Replaces exp(a dot b) with f(a) dot f(b), where f is an elementwise function. The paper uses f(x) = elu(x) + 1; this implementation uses f(x) = swish1(x) + 1. Because of this decomposition, the chained matrix multiplication can be evaluated in a more efficient order: (Q Kt) V = Q (Kt V).
Applies masking according to maskedSoftmax.
- query
 batch x num queries x key dim
- maxLength
 batch x num queries OR batch, type Long
- returns
 batch x num queries x value dim
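The reordering can be illustrated outside lamp with a minimal sketch on plain Scala arrays for a single batch element, without masking or dropout. The object and method names are hypothetical, and the row normalization follows the cited paper; whether lamp normalizes in exactly the same way is an assumption.

```scala
object LinearizedAttentionSketch {
  type Mat = Array[Array[Double]]

  // (n x k) * (k x m) => n x m
  def matmul(a: Mat, b: Mat): Mat =
    Array.tabulate(a.length, b(0).length)((i, j) => b.indices.map(p => a(i)(p) * b(p)(j)).sum)

  // Feature map f(x) = swish1(x) + 1, with swish1(x) = x * sigmoid(x); always positive
  def feature(x: Double): Double = x / (1.0 + math.exp(-x)) + 1.0

  def mapM(a: Mat)(f: Double => Double): Mat = a.map(_.map(f))

  // Quadratic order: build the nq x nk matrix f(Q) f(K)^T, normalize its rows, multiply by V
  def quadratic(q: Mat, k: Mat, v: Mat): Mat = {
    val scores = matmul(mapM(q)(feature), mapM(k)(feature).transpose)
    val weights = scores.map { row => val s = row.sum; row.map(_ / s) }
    matmul(weights, v)
  }

  // Linear order: build the dk x dv matrix f(K)^T V once, then multiply f(Q) by it
  def linearized(q: Mat, k: Mat, v: Mat): Mat = {
    val fq = mapM(q)(feature)
    val fk = mapM(k)(feature)
    val numer = matmul(fq, matmul(fk.transpose, v)) // nq x dv
    // Row normalizer f(q_i) . sum_j f(k_j), equal to the row sums used in quadratic
    val kSum = fk.transpose.map(_.sum)
    val denom = fq.map(row => row.indices.map(p => row(p) * kSum(p)).sum)
    numer.zip(denom).map { case (row, d) => row.map(_ / d) }
  }
}
```

Both methods give the same result up to rounding; only the order of the multiplications differs, so the linear order never materializes the num queries x num keys matrix.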
 -  implicit val load: Load[MultiheadAttention]
 -    def maskedSoftmax[S](input: Variable, maxLength: STen)(implicit arg0: Sc[S]): Variable
- input
 batch x seq x ???
- maxLength
 batch x seq OR batch, type Long
- returns
 batch x seq x ???
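A minimal sketch of the masked softmax for the 1D maxLength case, on plain nested arrays rather than lamp's Variable and STen. Masked positions are filled with -1e6 before the softmax, as in the d2l chapter cited elsewhere on this page; the fill value lamp uses is an assumption here, and the names are hypothetical.

```scala
object MaskedSoftmaxSketch {
  type T3 = Array[Array[Array[Double]]]

  // input: batch x seq x keys; validLength(b): number of unmasked key positions in batch b
  def maskedSoftmax(input: T3, validLength: Array[Long]): T3 =
    input.zipWithIndex.map { case (rows, b) =>
      rows.map { row =>
        // Replace positions k >= validLength(b) with a large negative value
        val masked = row.zipWithIndex.map { case (x, k) => if (k >= validLength(b)) -1e6 else x }
        // Numerically stable softmax over the last axis
        val mx = masked.max
        val exps = masked.map(x => math.exp(x - mx))
        val total = exps.sum
        exps.map(_ / total)
      }
    }
}
```

After the softmax the masked positions carry (numerically) zero weight and each row still sums to one.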
 -    def multiheadAttention[S](query: Variable, keys: Variable, values: Variable, maxLength: Option[STen], dropout: Double, trainDropout: Boolean, wQuery: Variable, wKeys: Variable, wValues: Variable, wOutput: Variable, numHeads: Int, linearized: Boolean, causalMask: Boolean)(implicit arg0: Sc[S]): Variable
Multi-head scaled dot product attention
See chapter 11.5 in d2l v1.0.0-beta0
Attention masking is implemented similarly to chapter 11.3.2.1 in d2l.ai v1.0.0-beta0. It supports unmasked attention, attention on variable length input, and left-to-right attention.
- query
 batch x num queries x dq
- maxLength
 batch x num queries OR batch, type Long
- wQuery
 dq x hidden
- wKeys
 dk x hidden
- wValues
 dv x hidden
- wOutput
 hidden x po
- numHeads
 number of attention heads; hidden must be divisible by numHeads
- linearized
 if true, uses linearized attention; if false, uses scaled dot product attention
- returns
 batch x num queries x po
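A single-sequence sketch of the head split, assuming (as in the d2l formulation) that each head attends over a contiguous slice of hidden / numHeads columns of the projected queries, keys and values, and that the head outputs are concatenated before the wOutput projection. Batching, masking, dropout and the linearized variant are omitted, and all names are hypothetical.

```scala
object MultiheadSketch {
  type Mat = Array[Array[Double]]

  def matmul(a: Mat, b: Mat): Mat =
    Array.tabulate(a.length, b(0).length)((i, j) => b.indices.map(p => a(i)(p) * b(p)(j)).sum)

  def softmaxRows(a: Mat): Mat = a.map { row =>
    val mx = row.max
    val e = row.map(x => math.exp(x - mx))
    val total = e.sum
    e.map(_ / total)
  }

  // Columns [lo, hi) of a matrix
  def cols(a: Mat, lo: Int, hi: Int): Mat = a.map(_.slice(lo, hi))

  // Scaled dot product attention for one head
  def attention(q: Mat, k: Mat, v: Mat): Mat = {
    val d = q(0).length
    val scores = matmul(q, k.transpose).map(_.map(_ / math.sqrt(d.toDouble)))
    matmul(softmaxRows(scores), v)
  }

  // query: nq x dq, keys: nk x dk, values: nk x dv
  // wQuery: dq x hidden, wKeys: dk x hidden, wValues: dv x hidden, wOutput: hidden x po
  def multihead(query: Mat, keys: Mat, values: Mat,
                wQuery: Mat, wKeys: Mat, wValues: Mat, wOutput: Mat,
                numHeads: Int): Mat = {
    val hidden = wQuery(0).length
    require(hidden % numHeads == 0, "hidden must be divisible by numHeads")
    val perHead = hidden / numHeads
    val q = matmul(query, wQuery)
    val k = matmul(keys, wKeys)
    val v = matmul(values, wValues)
    // Each head attends over its own contiguous slice of the hidden dimension
    val heads = (0 until numHeads).map { h =>
      val (lo, hi) = (h * perHead, (h + 1) * perHead)
      attention(cols(q, lo, hi), cols(k, lo, hi), cols(v, lo, hi))
    }
    // Concatenate head outputs along the feature axis: nq x hidden
    val concat = query.indices.map(i => heads.map(h => h(i)).reduce(_ ++ _)).toArray
    matmul(concat, wOutput) // nq x po
  }
}
```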
 -   final  def ne(arg0: AnyRef): Boolean
- Definition Classes
 - AnyRef
 
 -   final  def notify(): Unit
- Definition Classes
 - AnyRef
 - Annotations
 - @IntrinsicCandidate() @native()
 
 -   final  def notifyAll(): Unit
- Definition Classes
 - AnyRef
 - Annotations
 - @IntrinsicCandidate() @native()
 
 -    def scaledDotProductAttention[S](query: Variable, keys: Variable, values: Variable, maxLength: Option[STen], dropout: Double, trainDropout: Boolean)(implicit arg0: Sc[S]): Variable
Scaled dot product attention
If maxLength is 2D: (batch, query, key) locations where key >= maxLength(batch, query) are ignored.
If maxLength is 1D: (batch, query, key) locations where key >= maxLength(batch) are ignored.
See chapter 11.3.3 in d2l v1.0.0-beta0
- query
 batch x num queries x key dim
- maxLength
 batch x num queries OR batch, type long
- returns
 batch x num queries x value dim
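A minimal sketch of scaled dot product attention for one batch element, with an optional per-query valid length playing the role of one row of a 2D maxLength. Scores are divided by the square root of the key dimension as in the d2l chapter; the -1e6 fill value and all names are assumptions for illustration.

```scala
object ScaledDotProductSketch {
  type Mat = Array[Array[Double]]

  def matmul(a: Mat, b: Mat): Mat =
    Array.tabulate(a.length, b(0).length)((i, j) => b.indices.map(p => a(i)(p) * b(p)(j)).sum)

  // query: nq x d, keys: nk x d, values: nk x dv
  // validKeys(i): number of keys query i may attend to (None means unmasked)
  def attend(query: Mat, keys: Mat, values: Mat, validKeys: Option[Array[Long]]): Mat = {
    val d = query(0).length
    val scores = matmul(query, keys.transpose).map(_.map(_ / math.sqrt(d.toDouble)))
    val masked = validKeys match {
      case None => scores
      case Some(len) =>
        scores.zipWithIndex.map { case (row, i) =>
          row.zipWithIndex.map { case (s, k) => if (k >= len(i)) -1e6 else s }
        }
    }
    // Row-wise softmax turns scores into attention weights
    val weights = masked.map { row =>
      val mx = row.max
      val e = row.map(x => math.exp(x - mx))
      val total = e.sum
      e.map(_ / total)
    }
    matmul(weights, values) // nq x dv
  }
}
```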
 -    def sequenceMask[S](maxLength: STen, maskable: Variable, fill: Double)(implicit arg0: Sc[S]): Variable
Masks on the 3rd axis of maskable depending on the dimensions of maxLength
If maxLength is 2D: (batch, query, key) locations where key >= maxLength(batch, query) are ignored.
If maxLength is 1D: (batch, query, key) locations where key >= maxLength(batch) are ignored.
 -    def sequenceMaskValidLength1D[S](maxLength: STen, maskable: Variable, fill: Double)(implicit arg0: Sc[S]): Variable
Masks the maskable(i,j,k) cell iff k >= maxLength(i)
- maxLength
 batch, type Long
- maskable
 batch x seq x ???
- fill
 scalar
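The 1D rule can be sketched on plain nested arrays (a hypothetical helper, not lamp's implementation): every cell whose last-axis index k reaches maxLength(i) is overwritten with fill, regardless of j.

```scala
object SequenceMask1DSketch {
  type T3 = Array[Array[Array[Double]]]

  // maxLength: batch; maskable: batch x seq x n
  def mask1D(maxLength: Array[Long], maskable: T3, fill: Double): T3 =
    Array.tabulate(maskable.length) { i =>
      // The same cutoff maxLength(i) applies to every row j of batch element i
      maskable(i).map(row => row.zipWithIndex.map { case (x, k) => if (k >= maxLength(i)) fill else x })
    }
}
```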
 -    def sequenceMaskValidLength2D[S](maxLength: STen, maskable: Variable, fill: Double)(implicit arg0: Sc[S]): Variable
Masks the maskable(i,j,k) cell iff k >= maxLength(i,j)
Masks some elements on the last (3rd) axis of maskable
- maxLength
 batch x seq, type Long
- maskable
 batch x seq x ???
- fill
 scalar
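The 2D rule differs only in that the cutoff depends on both the batch index and the row index. In this hypothetical sketch, choosing maxLength(i)(j) = j + 1 keeps exactly the cells with k <= j, which corresponds to left-to-right (causal) masking.

```scala
object SequenceMask2DSketch {
  type T3 = Array[Array[Array[Double]]]

  // maxLength: batch x seq; maskable: batch x seq x n
  def mask2D(maxLength: Array[Array[Long]], maskable: T3, fill: Double): T3 =
    Array.tabulate(maskable.length, maskable(0).length) { (i, j) =>
      // Cell (i, j, k) is replaced by fill iff k >= maxLength(i)(j)
      maskable(i)(j).zipWithIndex.map { case (x, k) => if (k >= maxLength(i)(j)) fill else x }
    }
}
```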
 -   final  def synchronized[T0](arg0: => T0): T0
- Definition Classes
 - AnyRef
 
 -    def toString(): String
- Definition Classes
 - AnyRef → Any
 
 -  implicit val trainingMode: TrainingMode[MultiheadAttention]
 -   final  def wait(arg0: Long, arg1: Int): Unit
- Definition Classes
 - AnyRef
 - Annotations
 - @throws(classOf[java.lang.InterruptedException])
 
 -   final  def wait(arg0: Long): Unit
- Definition Classes
 - AnyRef
 - Annotations
 - @throws(classOf[java.lang.InterruptedException]) @native()
 
 -   final  def wait(): Unit
- Definition Classes
 - AnyRef
 - Annotations
 - @throws(classOf[java.lang.InterruptedException])
 
 -  case object WeightsK extends LeafTag with Product with Serializable
 -  case object WeightsO extends LeafTag with Product with Serializable
 -  case object WeightsQ extends LeafTag with Product with Serializable
 -  case object WeightsV extends LeafTag with Product with Serializable
 
Deprecated Value Members
-    def finalize(): Unit
- Attributes
 - protected[lang]
 - Definition Classes
 - AnyRef
 - Annotations
 - @throws(classOf[java.lang.Throwable]) @Deprecated
 - Deprecated
 (Since version 9)