package bert
Type Members
- case class BertEncoder(tokenEmbedding: Embedding, segmentEmbedding: Embedding, positionalEmbedding: Constant, blocks: Seq[TransformerEncoderBlock]) extends GenericModule[(Variable, Variable, Option[STen]), Variable] with Product with Serializable
BertEncoder module
Input is (tokens, segments, maxLength), where tokens and segments are both (batch, num tokens) long tensors and maxLength is an optional 1D long tensor indicating the length of each input sequence.
Output is (batch, num tokens, out dimension).
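To illustrate the maxLength convention, the sketch below builds a (batch, num tokens) validity mask from plain Scala arrays. It assumes that a token at position j of example i is kept only if j < maxLength(i), which matches the "tokens at positions higher or equal than the sequence length are ignored" rule in BertPretrainInput. `MaxLengthMask.validMask` is a hypothetical helper, not the library's implementation, and uses no STen types.

```scala
object MaxLengthMask {
  /** Returns a (batch, numTokens) mask that is true where position j < maxLength(i).
    * Hypothetical helper illustrating the maxLength convention; not part of the library.
    */
  def validMask(maxLength: Array[Long], numTokens: Int): Array[Array[Boolean]] =
    maxLength.map { len =>
      // Each length must lie in [0, numTokens]
      require(len >= 0 && len <= numTokens, s"maxLength $len out of [0, $numTokens]")
      Array.tabulate(numTokens)(j => j < len)
    }
}
```

For example, `MaxLengthMask.validMask(Array(3L, 5L), 5)` marks the last two positions of the first example as padding and keeps all five positions of the second.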
- case class BertLoss(pretrain: BertPretrainModule, mlmLoss: LossFunction, wholeSentenceLoss: LossFunction) extends GenericModule[BertLossInput, Variable] with Product with Serializable
- case class BertLossInput(input: BertPretrainInput, maskedLanguageModelTarget: STen, wholeSentenceTarget: STen) extends Product with Serializable
Input to BertLoss module
- input: feature data, see documentation of BertPretrainInput
- maskedLanguageModelTarget: long tensor of size (batch size, number of masked positions (variable)). Values are the true tokens that were masked out at the positions given in input.positions.
- wholeSentenceTarget: float tensor of size (batch size). Values are the ground-truth targets for the whole sentence loss, which is a binary cross-entropy loss on logits (BCEWithLogitLoss). Values are floats in [0,1].
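As a reference for what a binary cross-entropy loss on logits computes for wholeSentenceTarget, here is a minimal plain-Scala sketch of the batch-mean loss using the numerically stable form max(x, 0) - x*y + log(1 + exp(-|x|)). The object and function names are hypothetical. The library's own loss function may use a different reduction or weighting, so treat this only as an illustration of the formula.

```scala
object WholeSentenceLossSketch {
  /** Mean binary cross-entropy with logits over a batch.
    * logits: raw whole sentence scores, one per example.
    * targets: floats in [0, 1], one per example.
    */
  def bceWithLogits(logits: Array[Double], targets: Array[Double]): Double = {
    require(logits.length == targets.length && logits.nonEmpty, "shapes must match and be non-empty")
    val perExample = logits.zip(targets).map { case (x, y) =>
      require(y >= 0.0 && y <= 1.0, s"target $y not in [0, 1]")
      // Stable rewrite of -(y * log(sigmoid(x)) + (1 - y) * log(1 - sigmoid(x)))
      math.max(x, 0.0) - x * y + math.log1p(math.exp(-math.abs(x)))
    }
    perExample.sum / perExample.length
  }
}
```

A logit of 0 gives a loss of log 2 regardless of the target, and a large positive logit with target 1 gives a loss close to 0.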
- case class BertPretrainInput(tokens: Constant, segments: Constant, positions: STen, maxLength: Option[STen]) extends Product with Serializable
Input for BERT pretrain module
- tokens: long tensor of size (batch, sequence length). Sequence length includes the cls and sep tokens. Values are tokens of the input vocabulary plus 4 additional control tokens: cls, sep, pad, mask. The first token must be cls.
- segments: long tensor of size (batch, sequence length). Values are segment tokens.
- positions: long tensor of size (batch, mask size (variable)). Values are indices in [0, sequence length) selecting the masked sequence positions. They never select positions holding cls, sep or pad tokens.
- maxLength: optional 1D long tensor of size (batch). Values are in [0, sequence length]. For each example, tokens at positions greater than or equal to its maxLength value are ignored.
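The constraints above can be checked per example before building the tensors. The sketch below works on plain Scala arrays for a single example and verifies that the first token is cls and that every masked position lies in [0, sequence length) and does not point at a cls, sep or pad token. The control token ids (0 to 3) are hypothetical placeholders; the real ids depend on how the vocabulary is built. A position may point at the mask token, since that is what a masked slot usually holds.

```scala
object PretrainInputCheck {
  // Hypothetical control token ids; the real ids depend on the vocabulary.
  val Cls = 0L
  val Sep = 1L
  val Pad = 2L
  val Mask = 3L

  /** Checks the documented constraints for one example:
    * the first token is cls, and every masked position is in
    * [0, sequence length) and never points at cls, sep or pad.
    */
  def validExample(tokens: Array[Long], positions: Array[Long]): Boolean = {
    val seqLen = tokens.length
    val firstIsCls = seqLen > 0 && tokens(0) == Cls
    val positionsOk = positions.forall { p =>
      // The range check short-circuits before indexing into tokens
      p >= 0 && p < seqLen && {
        val t = tokens(p.toInt)
        t != Cls && t != Sep && t != Pad
      }
    }
    firstIsCls && positionsOk
  }
}
```

For the sequence cls, 10, mask, 12, sep, pad, the positions 2 and 3 are valid, while position 4 (sep) or 6 (out of range) would be rejected.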
- case class BertPretrainModule(encoder: BertEncoder, mlm: MaskedLanguageModelModule, wholeSentenceBinaryClassifier: BertPretrainModule.MLP) extends GenericModule[BertPretrainInput, BertPretrainOutput] with Product with Serializable
- case class BertPretrainOutput(encoded: Variable, languageModelScores: Variable, wholeSentenceBinaryClassifierScore: Variable) extends Product with Serializable
Output of BERT
- encoded: float tensor of size (batch, sequence length, embedding dimension) holds per token embeddings
- languageModelScores: float tensor of size (batch, number of masked positions, vocabulary size) holds log probability distributions over the vocabulary (from logSoftMax) for each masked position selected by positions
- wholeSentenceBinaryClassifierScore: float tensor of size (batch) holds the output score (a logit) of the whole sentence prediction task, suitable for a binary cross-entropy loss on logits (BCEWithLogitLoss)
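To read predictions out of these outputs, the per-position log probabilities can be reduced with an argmax over the vocabulary axis, and the whole sentence logit can be mapped to a probability with the logistic sigmoid. The sketch below does this on plain Scala arrays for a single example; the object and method names are hypothetical and not part of the library's API.

```scala
object PretrainOutputDecode {
  /** Argmax over the vocabulary axis of per-position log probabilities,
    * given as (positions, vocabulary size) for a single example.
    */
  def predictedTokens(logProbs: Array[Array[Double]]): Array[Int] =
    logProbs.map { row =>
      require(row.nonEmpty, "vocabulary axis must be non-empty")
      row.indices.maxBy(i => row(i))
    }

  /** Converts a whole sentence logit into a probability via the logistic sigmoid. */
  def sentenceProbability(logit: Double): Double = 1.0 / (1.0 + math.exp(-logit))
}
```

Because the scores come from logSoftMax, taking the argmax of the log probabilities selects the same token as taking the argmax of the probabilities themselves.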
- case class MaskedLanguageModelModule(mlp: MaskedLanguageModelModule.MLP) extends GenericModule[(Variable, STen), Variable] with Product with Serializable
Masked Language Model.
Input is (embedding, positions), where embedding is of size (batch, num tokens, embedding dim) and positions is a long tensor of size (batch, number of positions) indicating which positions to make predictions on.
Output is (batch, len(positions), vocabulary size).
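One way to obtain an output of size (batch, len(positions), vocabulary size) is to first select the embeddings at the requested positions (a gather along the num tokens axis) and then apply the MLP to each selected embedding. The plain-Scala sketch below shows only the selection step, assuming positions index the num tokens axis of each example; the MLP step is omitted, and `gather` is a hypothetical helper, not the library's implementation.

```scala
object GatherPositions {
  /** Selects rows of a (batch, numTokens, dim) embedding at the given
    * (batch, numPositions) positions, producing (batch, numPositions, dim).
    * Throws if a position is outside [0, numTokens).
    */
  def gather(
      embedding: Array[Array[Array[Double]]],
      positions: Array[Array[Long]]
  ): Array[Array[Array[Double]]] = {
    require(embedding.length == positions.length, "batch sizes differ")
    embedding.zip(positions).map { case (tokens, pos) =>
      // For each requested position, pick that token's embedding vector
      pos.map(p => tokens(p.toInt))
    }
  }
}
```

Each selected (embedding dim) vector would then be mapped to vocabulary-sized scores, giving the documented output shape.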
Value Members
- object BertEncoder extends Serializable
- object BertLoss extends Serializable
- object BertLossInput extends Serializable
- object BertPretrainInput extends Serializable
- object BertPretrainModule extends Serializable
- object MaskedLanguageModelModule extends Serializable