object MultiheadAttention extends Serializable
Value Members
- def apply[S](dQ: Int, dK: Int, dV: Int, hiddenPerHead: Int, out: Int, dropout: Double, numHeads: Int, tOpt: STenOptions, linearized: Boolean, causalMask: Boolean)(implicit arg0: Sc[S]): MultiheadAttention
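A minimal construction sketch for the factory method above. It assumes lamp's usual top-level imports, Scope-based memory management, and STenOptions.f (32-bit floats on CPU); the parameter values are illustrative only.

import lamp._
import lamp.nn._

Scope.root { implicit scope =>
  // Construct the module with the apply signature above.
  // Presumably the total hidden width of the projections is hiddenPerHead * numHeads.
  val attention = MultiheadAttention(
    dQ = 64,
    dK = 64,
    dV = 64,
    hiddenPerHead = 16,
    out = 64,
    dropout = 0.1,
    numHeads = 4,
    tOpt = STenOptions.f,
    linearized = false,
    causalMask = false
  )
  // attention is a lamp module; its weight parameters correspond to the
  // WeightsQ / WeightsK / WeightsV / WeightsO tags listed further below.
}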
- def linearizedAttention[S](query: Variable, keys: Variable, values: Variable, maxLength: Option[STen], dropout: Double, trainDropout: Boolean)(implicit arg0: Sc[S]): Variable
Linearized dot product attention (https://arxiv.org/pdf/2006.16236.pdf).
Replaces exp(a dot b) with f(a) dot f(b), where f is an elementwise function; the paper uses f(x) = elu(x) + 1, here f(x) = swish1(x) + 1. This decomposition allows a more efficient association of the chained matrix multiplications: (Q Kt) V = Q (Kt V).
Applies masking according to maskedSoftmax.
- query
batch x num queries x key dim
- maxLength
batch x num queries OR batch, type Long
- returns
batch x num queries x value dim
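For reference, a sketch of the algebra behind the linearization (following the linked paper), written for a single query q_i with the feature map f named above:

\[
o_i \;=\; \sum_j \frac{\exp(q_i^\top k_j)}{\sum_l \exp(q_i^\top k_l)}\, v_j
\;\;\longrightarrow\;\;
o_i \;=\; \frac{f(q_i)^\top \sum_j f(k_j)\, v_j^\top}{f(q_i)^\top \sum_j f(k_j)}
\]

The sums over j do not depend on the query, so they can be computed once and reused for every query, which is exactly the (Q Kt) V = Q (Kt V) reassociation mentioned above and makes the cost linear in sequence length.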
- implicit val load: Load[MultiheadAttention]
- def maskedSoftmax[S](input: Variable, maxLength: STen)(implicit arg0: Sc[S]): Variable
- input
batch x seq x ???
- maxLength
batch x seq OR batch, type Long
- returns
batch x seq x ???
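A conceptual sketch of what the masked softmax computes, on plain Scala arrays rather than lamp Variables (the 1D maxLength case; the 2D case differs only in indexing a valid length per (batch, seq) position):

// Conceptual sketch only, not the lamp implementation.
// input: batch x seq x n, maxLength: valid length per batch element (assumed >= 1).
def maskedSoftmaxSketch(
    input: Array[Array[Array[Double]]],
    maxLength: Array[Int]
): Array[Array[Array[Double]]] =
  input.zipWithIndex.map { case (rows, b) =>
    rows.map { row =>
      // positions k >= maxLength(b) are excluded from the softmax
      val masked = row.zipWithIndex.map { case (x, k) =>
        if (k >= maxLength(b)) Double.NegativeInfinity else x
      }
      val m = masked.max
      val exps = masked.map(x => math.exp(x - m))
      val z = exps.sum
      exps.map(_ / z)
    }
  }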
- def multiheadAttention[S](query: Variable, keys: Variable, values: Variable, maxLength: Option[STen], dropout: Double, trainDropout: Boolean, wQuery: Variable, wKeys: Variable, wValues: Variable, wOutput: Variable, numHeads: Int, linearized: Boolean, causalMask: Boolean)(implicit arg0: Sc[S]): Variable
Multi-head scaled dot product attention.
See chapter 11.5 in d2l.ai v1.0.0-beta0.
Attention masking is implemented similarly to chapter 11.3.2.1 in d2l.ai v1.0.0-beta0. It supports unmasked attention, attention over variable length inputs, and left-to-right (causal) attention.
- query
batch x num queries x dq
- maxLength
batch x num queries OR batch, type Long
- wQuery
dq x hidden
- wKeys
dk x hidden
- wValues
dv x hidden
- wOutput
hidden x po
- numHeads
number of output heads; hidden must be divisible by numHeads
- linearized
If true, uses linearized attention; if false, uses scaled dot product attention.
- returns
batch x num queries x po
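For orientation, the standard multi-head formulation from the cited d2l chapter, expressed in the shapes listed above (exactly how the hidden dimension is partitioned across heads is an implementation detail of this module):

\[
\mathrm{head}_i = \mathrm{Attention}(Q W^Q_i,\; K W^K_i,\; V W^V_i), \qquad
\mathrm{MultiHead}(Q, K, V) = \mathrm{Concat}(\mathrm{head}_1, \ldots, \mathrm{head}_h)\, W^O
\]

Here Q is batch x num queries x dq, the combined wQuery is dq x hidden (analogously for keys and values), wOutput is hidden x po, and h = numHeads; Attention is either scaledDotProductAttention or linearizedAttention depending on the linearized flag.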
- def scaledDotProductAttention[S](query: Variable, keys: Variable, values: Variable, maxLength: Option[STen], dropout: Double, trainDropout: Boolean)(implicit arg0: Sc[S]): Variable
Scaled dot product attention.
If maxLength is 2D: (batch, query, key) locations where key >= maxLength(batch, query) are ignored.
If maxLength is 1D: (batch, query, key) locations where key >= maxLength(batch) are ignored.
See chapter 11.3.3 in d2l.ai v1.0.0-beta0.
- query
batch x num queries x key dim
- maxLength
batch x num queries OR batch, type Long
- returns
batch x num queries x value dim
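For reference, the quantity computed (standard scaled dot product attention, see the cited d2l chapter), with masking applied inside the softmax as described under maskedSoftmax:

\[
\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\!\left(\frac{Q K^\top}{\sqrt{d_k}}\right) V
\]

where d_k is the key dimension, the softmax is taken over the key axis, and masked (batch, query, key) positions are excluded from that softmax.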
- def sequenceMask[S](maxLength: STen, maskable: Variable, fill: Double)(implicit arg0: Sc[S]): Variable
Masks along the 3rd axis of maskable, dispatching on the dimensionality of maxLength; masked cells are set to fill (see the sketch after sequenceMaskValidLength2D below).
If maxLength is 2D: (batch, query, key) locations where key >= maxLength(batch, query) are masked.
If maxLength is 1D: (batch, query, key) locations where key >= maxLength(batch) are masked.
- def sequenceMaskValidLength1D[S](maxLength: STen, maskable: Variable, fill: Double)(implicit arg0: Sc[S]): Variable
Masks the maskable(i,j,k) cell iff k >= maxLength(i)
- maxLength
batch, type Long
- maskable
batch x seq x ???
- fill
scalar
- def sequenceMaskValidLength2D[S](maxLength: STen, maskable: Variable, fill: Double)(implicit arg0: Sc[S]): Variable
Masks the maskable(i,j,k) cell iff k >= maxLength(i,j)
Masks some elements on the last (3rd) axis of maskable
- maxLength
batch x seq, type Long
- maskable
batch x seq x ???
- fill
scalar
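A plain-Scala sketch of the two masking rules above (conceptual only; the actual methods operate on STen/Variable):

// Conceptual sketches, not the lamp implementation. maskable: batch x seq x n.

// 1D rule: maskable(i)(j)(k) is replaced by fill iff k >= maxLength(i)
def maskValidLength1D(
    maxLength: Array[Int],
    maskable: Array[Array[Array[Double]]],
    fill: Double
): Array[Array[Array[Double]]] =
  maskable.zipWithIndex.map { case (rows, i) =>
    rows.map(_.zipWithIndex.map { case (x, k) =>
      if (k >= maxLength(i)) fill else x
    })
  }

// 2D rule: maskable(i)(j)(k) is replaced by fill iff k >= maxLength(i)(j)
def maskValidLength2D(
    maxLength: Array[Array[Int]],
    maskable: Array[Array[Array[Double]]],
    fill: Double
): Array[Array[Array[Double]]] =
  maskable.zipWithIndex.map { case (rows, i) =>
    rows.zipWithIndex.map { case (row, j) =>
      row.zipWithIndex.map { case (x, k) =>
        if (k >= maxLength(i)(j)) fill else x
      }
    }
  }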
- implicit val trainingMode: TrainingMode[MultiheadAttention]
- case object WeightsK extends LeafTag with Product with Serializable
- case object WeightsO extends LeafTag with Product with Serializable
- case object WeightsQ extends LeafTag with Product with Serializable
- case object WeightsV extends LeafTag with Product with Serializable