biotransformers.lightning_utils.models
Module Contents
Classes
LightningModule | Create a Lightning model for distributed data parallel (DDP) training.
- class biotransformers.lightning_utils.models.LightningModule(model, alphabet, lr: float, warmup_end_lr: float, warmup_updates: int = 10, warmup_init_lr: float = 1e-07)
Bases: pytorch_lightning.LightningModule
Create a Lightning model for distributed data parallel (DDP) training.
- forward(self, x)
- configure_optimizers(self) → Tuple[List[torch.optim.Optimizer], List[Dict]]
Configure the optimizer and learning-rate scheduler.
- Returns
list of optimizers.
list of learning-rate scheduler configurations (dicts).
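The constructor's `warmup_init_lr`, `warmup_end_lr` and `warmup_updates` parameters suggest a learning-rate warmup driven by the scheduler that `configure_optimizers` returns. The sketch below assumes a linear ramp from `warmup_init_lr` to `warmup_end_lr` over `warmup_updates` optimizer steps and a constant rate afterwards. The post-warmup behaviour and the role of `lr` are not documented here, and `warmup_lr` is a hypothetical helper, not part of biotransformers.

```python
def warmup_lr(
    step: int,
    warmup_end_lr: float,
    warmup_updates: int = 10,
    warmup_init_lr: float = 1e-07,
) -> float:
    """Learning rate at optimizer step `step` under an assumed linear warmup.

    Assumption: the rate rises linearly from `warmup_init_lr` to
    `warmup_end_lr` over `warmup_updates` steps, then stays at
    `warmup_end_lr` (the real post-warmup schedule is not documented).
    """
    if warmup_updates <= 0 or step >= warmup_updates:
        return warmup_end_lr
    # 0.0 at step 0, approaching 1.0 as step nears warmup_updates
    fraction = step / warmup_updates
    return warmup_init_lr + fraction * (warmup_end_lr - warmup_init_lr)


if __name__ == "__main__":
    for step in (0, 5, 10, 20):
        print(step, warmup_lr(step, warmup_end_lr=1e-4))
```

To drive a PyTorch optimizer created with `lr=warmup_end_lr`, one option is `torch.optim.lr_scheduler.LambdaLR(optimizer, lambda step: warmup_lr(step, warmup_end_lr) / warmup_end_lr)`, because `LambdaLR` multiplies the base rate by the lambda's return value. In Lightning, such a scheduler is typically returned inside a dict such as `{"scheduler": scheduler, "interval": "step"}` so that it updates every optimizer step rather than every epoch.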
- cross_entropy_loss(self, logits, targets)
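`cross_entropy_loss(logits, targets)` is undocumented; assuming it uses the same shapes documented for `get_tensor_accuracy` (logits of shape (batch, len_tokens, len_vocab), targets of shape (batch, len_tokens)), a dependency-free sketch of a masked-language-model cross-entropy looks like this. It assumes unmasked positions carry the label `-100` (the default `ignore_index` of `torch.nn.functional.cross_entropy`) and are excluded from the mean. Whether biotransformers uses this convention is an assumption, and `masked_cross_entropy` is a hypothetical name.

```python
import math
from typing import Sequence

IGNORE_INDEX = -100  # assumption: label marking positions that were not masked


def masked_cross_entropy(
    logits: Sequence[Sequence[Sequence[float]]],
    targets: Sequence[Sequence[int]],
    ignore_index: int = IGNORE_INDEX,
) -> float:
    """Mean cross-entropy over positions whose target is not ignore_index.

    logits: (batch, len_tokens, len_vocab); targets: (batch, len_tokens).
    """
    total, count = 0.0, 0
    for seq_logits, seq_targets in zip(logits, targets):
        for token_logits, target in zip(seq_logits, seq_targets):
            if target == ignore_index:
                continue  # unmasked position: contributes no loss
            # log-sum-exp with the max logit subtracted for numerical stability
            m = max(token_logits)
            log_sum_exp = m + math.log(sum(math.exp(z - m) for z in token_logits))
            # -log softmax(token_logits)[target]
            total += log_sum_exp - token_logits[target]
            count += 1
    return total / count if count else 0.0


if __name__ == "__main__":
    # one sequence of two tokens; only the first is masked (target 2)
    example_logits = [[[0.0, 0.0, 0.0, 0.0], [1.0, 2.0, 3.0, 4.0]]]
    example_targets = [[2, IGNORE_INDEX]]
    print(masked_cross_entropy(example_logits, example_targets))  # log(4) ≈ 1.386
```

The sketch returns 0.0 when no position is masked, whereas `torch.nn.functional.cross_entropy` with mean reduction returns NaN in that case.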
- training_step(self, train_batch, batch_idx)
- validation_step(self, val_batch, batch_idx)
Log the loss and metrics for a validation batch.
- Parameters
val_batch – batch input.
batch_idx – index of the batch.
- get_tensor_accuracy(self, logits: torch.Tensor, targets: torch.Tensor) → Tuple[torch.Tensor, torch.Tensor]
Calculate accuracy for multi-masking, summed over the batch.
- Parameters
logits – predictions from the model, shape = (batch, len_tokens, len_vocab)
targets – ground truth, shape = (batch, len_tokens)
- Returns
accuracy value.
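`get_tensor_accuracy` is annotated as returning two tensors, while its docstring mentions a single accuracy value. One plausible reading, assumed here rather than confirmed by the source, is a pair (correct predictions, masked tokens) summed over the batch, which the caller then divides. The pure-Python sketch below, under the hypothetical name `masked_accuracy_counts` and with the same assumed `-100` label for unmasked positions, computes that pair.

```python
from typing import Sequence, Tuple

IGNORE_INDEX = -100  # assumption: label marking positions that were not masked


def masked_accuracy_counts(
    logits: Sequence[Sequence[Sequence[float]]],
    targets: Sequence[Sequence[int]],
    ignore_index: int = IGNORE_INDEX,
) -> Tuple[int, int]:
    """Return (correct, total) over masked positions, summed across the batch.

    logits: (batch, len_tokens, len_vocab); targets: (batch, len_tokens).
    """
    correct, total = 0, 0
    for seq_logits, seq_targets in zip(logits, targets):
        for token_logits, target in zip(seq_logits, seq_targets):
            if target == ignore_index:
                continue  # unmasked position: not scored
            # argmax over the vocabulary axis; max() keeps the first index on ties
            predicted = max(range(len(token_logits)), key=token_logits.__getitem__)
            correct += int(predicted == target)
            total += 1
    return correct, total


if __name__ == "__main__":
    example_logits = [[[0.1, 0.9], [0.8, 0.2]]]
    example_targets = [[1, 1]]
    correct, total = masked_accuracy_counts(example_logits, example_targets)
    print(correct / total if total else 0.0)
```

Summing both counts across batches and dividing once at the end gives a token-weighted accuracy, which differs from averaging per-batch ratios when batches contain different numbers of masked tokens.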