



package languagemodel

Type Members

  1. case class LanguageModelInput(tokens: Constant, maxLength: Option[STen], positions: Option[STen]) extends Product with Serializable

    Input to language model

    tokens: token ids, batch x sequence, type long

    maxLength: batch x sequence OR batch, see maskedSoftmax. Used to define masking of the attention matrix. Use cases:

    • Left-to-right (causal) attention with uniform sequence length. In this case use a batch x sequence 2D matrix with arange(0,sequence) in each row.
    • Variable length sequences with bidirectional attention. In this case use a 1D [batch] vector with the real length of each sequence (rest are padded).
    • If empty, the attention matrix is not masked.
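    The two maxLength layouts listed above can be built as shown below. This is a minimal sketch in plain Scala: nested arrays stand in for STen, the helper names are hypothetical, and the allowed2D/allowed1D rules are an assumed reading of the masking (query i may attend to key j when j <= maxLength(b)(i) in the 2D form, and when j < maxLength(b) in the 1D form), not lamp's maskedSoftmax code.

```scala
// Hypothetical helpers illustrating the two maxLength layouts.
// Plain arrays stand in for STen; the attention rules in allowed2D and
// allowed1D are an assumed interpretation, not lamp's implementation.
object MaskSketch {

  // Causal case: batch x sequence matrix, each row is arange(0, sequence).
  def causalMaxLength(batch: Int, sequence: Int): Array[Array[Long]] =
    Array.fill(batch)(Array.tabulate(sequence)(_.toLong))

  // Variable-length bidirectional case: 1D [batch] vector of real lengths.
  def paddedMaxLength(lengths: Seq[Int]): Array[Long] =
    lengths.map(_.toLong).toArray

  // Assumed 2D rule: query i in batch b may attend to key j iff j <= maxLength(b)(i).
  def allowed2D(maxLength: Array[Array[Long]], b: Int, i: Int, j: Int): Boolean =
    j <= maxLength(b)(i)

  // Assumed 1D rule: any query in batch b may attend to key j iff j < maxLength(b),
  // so padded keys beyond the real length are masked out.
  def allowed1D(maxLength: Array[Long], b: Int, j: Int): Boolean =
    j < maxLength(b)
}
```

With the causal layout, position 2 can see keys 0 to 2 but position 1 cannot see key 3; with lengths (3, 1), the second sequence only exposes key 0.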

    positions: batch x sequence, type long in [0,sequence], selects positions. Final LM logits are computed on the selected positions only. If empty, all positions are selected.
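    Selecting positions amounts to a gather along the sequence axis before the final vocabulary projection. The sketch below assumes nested arrays in place of STen and uses a hypothetical selectPositions helper; it is not lamp's implementation.

```scala
// Illustrative gather over per-position outputs; nested arrays stand in for STen.
// `selectPositions` is a hypothetical helper, not part of lamp.
object PositionSketch {
  // hidden: batch x sequence x dim; positions: batch x k indices into the sequence axis.
  // Returns batch x k x dim. When positions is None, every position is kept.
  def selectPositions(
      hidden: Array[Array[Array[Double]]],
      positions: Option[Array[Array[Long]]]
  ): Array[Array[Array[Double]]] =
    positions match {
      case None => hidden
      case Some(pos) =>
        // For each batch element, pick the rows named by its index list.
        hidden.zip(pos).map { case (seq, idx) => idx.map(i => seq(i.toInt)) }
    }
}
```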

  2. case class LanguageModelLoss(languageModel: LanguageModelModule, loss: LossFunction) extends GenericModule[LossInput, Variable] with Product with Serializable

    Module with the language model and a loss

    Main training entry point of the language model.
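    LossFunction is generic here, so the sketch below simply assumes a token-level cross-entropy (the mean negative log-likelihood of the target ids under a softmax of the logits), which is a common choice for language model training. Arrays stand in for Variable and STen, and LossSketch/meanCrossEntropy are hypothetical names.

```scala
// Sketch of a token-level cross-entropy, assuming the LossFunction is of this kind.
// Arrays stand in for lamp's Variable/STen; none of these names come from lamp.
object LossSketch {
  // Numerically stable log(sum(exp(row))).
  private def logSumExp(row: Array[Double]): Double = {
    val m = row.max
    m + math.log(row.map(x => math.exp(x - m)).sum)
  }

  // logits: batch x sequence x vocab; targets: batch x sequence token ids.
  // Per token, -log p(target) = logSumExp(row) - row(target); the result is the mean.
  def meanCrossEntropy(logits: Array[Array[Array[Double]]], targets: Array[Array[Long]]): Double = {
    val nlls: Seq[Double] =
      logits.toSeq.zip(targets.toSeq).flatMap { case (seqLogits, seqTargets) =>
        seqLogits.toSeq.zip(seqTargets.toSeq).map { case (row, t) =>
          logSumExp(row) - row(t.toInt)
        }
      }
    nlls.sum / nlls.length
  }
}
```

Uniform logits over two tokens give a loss of log 2, while a confident correct prediction drives the loss toward zero.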

  3. case class LanguageModelModule(tokenEmbedding: Embedding, positionEmbedding: Embedding, encoder: TransformerEncoder, finalNorm: LayerNorm) extends GenericModule[LanguageModelInput, LanguageModelOutput] with Product with Serializable

    Transformer based language model module

    The initial embedding is the sum of the token embedding and the position embedding. Both are learned embeddings; the position embedding is not sinusoidal or otherwise fixed.

    The initial embeddings are fed through a stack of transformer blocks. Attention masking is governed by the input (see maxLength in LanguageModelInput), similarly as described in chapter in d2l v1.0.0-beta0.

    The selected sequence positions of the transformer output are linearly mapped to the vocabulary size, producing the language model logits.
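    The embedding stage and the final projection can be illustrated with plain arrays. In this sketch the position embedding for sequence index i is assumed to be row i of the position table, the transformer blocks and finalNorm are skipped, and the output projection matrix outW is passed in explicitly; where lamp keeps the projection weights is not stated here. All names are hypothetical.

```scala
// Hypothetical sketch of the embedding stage and the final vocabulary projection.
// Nested arrays stand in for lamp's Embedding and tensors; the transformer
// blocks and finalNorm are omitted, so this is not lamp's implementation.
object EmbeddingSketch {
  type Matrix = Array[Array[Double]]

  // tokens: one sequence of token ids; tokenTable: vocab x dim; positionTable: maxSeq x dim.
  // Row i of the result is tokenTable(tokens(i)) + positionTable(i).
  def initialEmbedding(tokens: Array[Long], tokenTable: Matrix, positionTable: Matrix): Matrix =
    tokens.zipWithIndex.map { case (tok, i) =>
      tokenTable(tok.toInt).zip(positionTable(i)).map { case (a, b) => a + b }
    }

  // Linear map from dim to vocab: logits(i)(v) = sum over d of hidden(i)(d) * outW(d)(v).
  def project(hidden: Matrix, outW: Matrix): Matrix =
    hidden.map { h =>
      Array.tabulate(outW.head.length)(v => h.indices.map(d => h(d) * outW(d)(v)).sum)
    }
}
```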

  4. case class LanguageModelOutput(encoded: Variable, languageModelLogits: Variable) extends Product with Serializable

    Output of LM

    encoded: float tensor of size (batch, sequence length, embedding dimension) holding per-token embeddings

    languageModelLogits: float tensor of size (batch, sequence length, vocabulary size) holding per-token logits. Use logSoftMax(dim=2) to get log probabilities.
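    A numerically stable log-softmax over the last (vocabulary) axis, the operation the logSoftMax(dim=2) hint refers to, can be sketched in plain Scala as below. The object and method names are hypothetical and arrays stand in for Variable; argmaxLastDim is an extra illustration of greedy next-token selection, not something the API above provides.

```scala
// Sketch of log-softmax along the vocabulary axis (dim = 2 for a
// batch x sequence x vocabulary tensor). Plain arrays stand in for lamp's Variable;
// these are hypothetical helpers, not lamp's logSoftMax.
object LogSoftmaxSketch {
  // log p(v) = x(v) - log(sum exp(x)); the max is subtracted first for stability.
  def logSoftmaxRow(row: Array[Double]): Array[Double] = {
    val m = row.max
    val lse = m + math.log(row.map(x => math.exp(x - m)).sum)
    row.map(_ - lse)
  }

  // Applies logSoftmaxRow to every (batch, position) row.
  def logSoftmaxLastDim(logits: Array[Array[Array[Double]]]): Array[Array[Array[Double]]] =
    logits.map(_.map(logSoftmaxRow))

  // Most likely token id at each position, e.g. for greedy decoding.
  def argmaxLastDim(logits: Array[Array[Array[Double]]]): Array[Array[Int]] =
    logits.map(_.map(row => row.indices.maxBy(i => row(i))))
}
```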

  5. case class LanguageModelOutputNonVariable(encoded: STen, languageModelLogits: STen) extends Product with Serializable
  6. case class LossInput(input: LanguageModelInput, languageModelTarget: STen) extends Product with Serializable

    Language model input and target for loss calculation

    languageModelTarget: batch x sequence
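    One way to produce tokens and languageModelTarget pairs is to cut a token stream into windows and shift by one. This assumes next-token prediction, where the target at position i is the token at position i + 1; the docs above only give the target shape, so treat this as a hypothetical preparation step rather than lamp's data pipeline. TargetSketch and Example are illustrative names.

```scala
// Hypothetical data preparation for next-token prediction. It assumes the
// target at position i is the token at position i + 1, a common convention
// that the API docs do not spell out. Plain arrays stand in for STen.
object TargetSketch {
  final case class Example(tokens: Array[Long], target: Array[Long])

  // Cuts a token stream into windows of `sequence + 1` ids that overlap by one id.
  // Each window yields the first `sequence` ids as input and the last `sequence`
  // ids (the input shifted left by one) as target. Incomplete tail windows are dropped.
  def windows(stream: Array[Long], sequence: Int): Seq[Example] =
    stream.toSeq
      .sliding(sequence + 1, sequence)
      .filter(_.length == sequence + 1)
      .map(w => Example(w.init.toArray, w.tail.toArray))
      .toList
}
```

Stacking several Example values row by row gives the batch x sequence layout that both tokens and languageModelTarget expect.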
