package languagemodel
Type Members
- case class LanguageModelInput(tokens: Constant, maxLength: Option[STen], positions: Option[STen]) extends Product with Serializable
Input to language model
- tokens
batch x sequence, type long
- maxLength
batch x sequence OR batch, see maskedSoftmax. Used to define masking of the attention matrix. Use cases:
- Left-to-right (causal) attention with uniform sequence length. In this case use a batch x sequence 2D matrix with arange(0,sequence) in each row.
- Variable length sequences with bidirectional attention. In this case use a 1D [batch] vector with the real length of each sequence (rest are padded).
- If empty the attention matrix is not masked
- positions
batch x sequence, type long in [0,sequence], selects positions. Final LM logits are computed on the selected positions. If empty then selects all positions.
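A minimal sketch of how the two `maxLength` forms described above might be laid out, using plain Scala arrays in place of `STen` (only the shapes and values follow the description; the actual tensor construction in lamp is not shown):

```scala
object MaxLengthSketch {
  val batch = 2
  val sequence = 4

  // Causal (left-to-right) attention with uniform sequence length:
  // a batch x sequence matrix with arange(0, sequence) in each row.
  val causalMaxLength: Array[Array[Long]] =
    Array.fill(batch)(Array.tabulate(sequence)(_.toLong))

  // Variable-length sequences with bidirectional attention:
  // a 1D [batch] vector holding the real (unpadded) length of each sequence.
  val realLengths: Array[Long] = Array(3L, 2L)
}
```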
- case class LanguageModelLoss(languageModel: LanguageModelModule, loss: LossFunction) extends GenericModule[LossInput, Variable] with Product with Serializable
Module with the language model and a loss
Main training entry point of the language model.
- case class LanguageModelModule(tokenEmbedding: Embedding, positionEmbedding: Embedding, encoder: TransformerEncoder, finalNorm: LayerNorm) extends GenericModule[LanguageModelInput, LanguageModelOutput] with Product with Serializable
Transformer based language model module
The initial embedding is the sum of the token and position embeddings. The token embedding is learned; the position embedding is likewise learned (not sinusoidal).
The initial embeddings are fed into layers of transformer blocks. Attention masking is governed by the input, similarly to the scheme described in chapter 11.3.2.1 of d2l v1.0.0-beta0.
Selected sequence positions in the output of the transformer chain are linearly mapped back into the desired vocabulary size.
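The dataflow described above can be sketched with plain Scala arrays standing in for tensors. The tiny embedding tables, the identity "encoder", and the projection back to vocabulary size are all illustrative assumptions, not the module's actual parameters:

```scala
object LmForwardSketch {
  // Toy setup: vocabulary of 3 tokens, embedding dimension 2, 4 positions.
  val tokenEmbedding: Array[Array[Double]] =
    Array(Array(1.0, 0.0), Array(0.0, 1.0), Array(1.0, 1.0))
  val positionEmbedding: Array[Array[Double]] =
    Array(Array(0.1, 0.1), Array(0.2, 0.2), Array(0.3, 0.3), Array(0.4, 0.4))

  // Initial embedding: sum of token and (learned) position embedding.
  def embed(tokens: Array[Int]): Array[Array[Double]] =
    tokens.zipWithIndex.map { case (t, p) =>
      tokenEmbedding(t).zip(positionEmbedding(p)).map { case (a, b) => a + b }
    }

  // Stand-in for the transformer encoder stack (identity here).
  def encode(x: Array[Array[Double]]): Array[Array[Double]] = x

  // Select positions, then linearly map back to vocabulary size
  // (tying the projection to tokenEmbedding, purely for illustration).
  def logits(
      encoded: Array[Array[Double]],
      positions: Array[Int]
  ): Array[Array[Double]] =
    positions.map(encoded).map { e =>
      tokenEmbedding.map(row => row.zip(e).map { case (a, b) => a * b }.sum)
    }

  // Compute LM logits only at the last position.
  val out = logits(encode(embed(Array(0, 2, 1))), positions = Array(2))
}
```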
- case class LanguageModelOutput(encoded: Variable, languageModelLogits: Variable) extends Product with Serializable
Output of LM
- encoded
float tensor of size (batch, sequence length, embedding dimension); holds the per-token embeddings
- languageModelLogits
float tensor of size (batch, sequence length, vocabulary size) holds per token logits. Use logSoftMax(dim=2) to get log probabilities.
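As a quick illustration of the `logSoftMax(dim=2)` step, a numerically stable log-softmax over the last (vocabulary) dimension might look like this in plain Scala (a sketch over nested arrays; lamp's own `logSoftMax` operates on `Variable`s):

```scala
object LogSoftmaxSketch {
  // Numerically stable log-softmax over one vocabulary-sized vector.
  def logSoftmax(logits: Array[Double]): Array[Double] = {
    val m = logits.max
    val logSumExp = m + math.log(logits.map(x => math.exp(x - m)).sum)
    logits.map(_ - logSumExp)
  }

  // Apply along dim=2 of (batch, sequence, vocabulary) logits.
  def logProbs(x: Array[Array[Array[Double]]]): Array[Array[Array[Double]]] =
    x.map(_.map(logSoftmax))

  val lp = logProbs(Array(Array(Array(1.0, 2.0, 3.0))))
}
```

Exponentiating each per-token row of the result recovers a probability distribution over the vocabulary.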
- case class LanguageModelOutputNonVariable(encoded: STen, languageModelLogits: STen) extends Product with Serializable
- case class LossInput(input: LanguageModelInput, languageModelTarget: STen) extends Product with Serializable
Language model input and target for loss calculation
- languageModelTarget
batch x sequence
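For next-token prediction the target is typically the input shifted left by one position. A plain-Scala sketch of building such a batch x sequence target follows; the shift-by-one convention and the padding value are common-practice assumptions, not something `LossInput` itself prescribes:

```scala
object ShiftTargetSketch {
  // Illustrative ignore-index style padding value for the final position.
  val pad = -100L

  // target(b)(i) = tokens(b)(i + 1); the last position has no next token.
  def nextTokenTarget(tokens: Array[Array[Long]]): Array[Array[Long]] =
    tokens.map(row => row.drop(1) :+ pad)

  val target = nextTokenTarget(Array(Array(5L, 7L, 9L)))
}
```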
Value Members
- object LanguageModelInput extends Serializable
- object LanguageModelLoss extends Serializable
- object LanguageModelModule extends Serializable
- object LanguageModelOutputNonVariable extends Serializable
- object LossInput extends Serializable