biotransformers.lightning_utils.models

Module Contents

Classes

LightningModule

Create a Lightning model for DDP (distributed data-parallel) training.

class biotransformers.lightning_utils.models.LightningModule(model, alphabet, lr: float, warmup_end_lr: float, warmup_updates: int = 10, warmup_init_lr: float = 1e-07)

Bases: pytorch_lightning.LightningModule

Create a Lightning model for DDP (distributed data-parallel) training.

forward(self, x)
configure_optimizers(self) → Tuple[List[torch.optim.Optimizer], List[Dict]]

Configure the optimizer and learning rate scheduler.

Returns

  • list of optimizers.

  • list of learning-rate scheduler configurations (dictionaries).
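The constructor takes warmup_init_lr, warmup_end_lr and warmup_updates, which suggests the scheduler ramps the learning rate during the first updates. The sketch below is an assumption, not this module's code. It uses a linear ramp from warmup_init_lr to warmup_end_lr over warmup_updates steps and then holds the rate constant. The function name warmup_lr is hypothetical, and the post-warmup behaviour of the real scheduler may differ (for example, it may decay).

```python
def warmup_lr(
    step: int,
    warmup_end_lr: float,
    warmup_updates: int = 10,
    warmup_init_lr: float = 1e-07,
) -> float:
    """Hypothetical linear-warmup learning rate for a given update step.

    Assumption: the rate rises linearly from warmup_init_lr to warmup_end_lr
    over warmup_updates steps, then stays at warmup_end_lr.
    """
    if warmup_updates <= 0 or step >= warmup_updates:
        # Warmup finished (or disabled): hold the target rate.
        return warmup_end_lr
    # Fraction of warmup completed, in [0, 1).
    frac = step / warmup_updates
    # Linear interpolation between the initial and final warmup rates.
    return warmup_init_lr + frac * (warmup_end_lr - warmup_init_lr)


# Learning rate for the first few updates with a hypothetical target of 1e-4.
schedule = [warmup_lr(s, 1e-4, warmup_updates=4) for s in range(6)]
print(schedule)
```

In a Lightning setup, a rule like this is usually wrapped in torch.optim.lr_scheduler.LambdaLR. That class expects a multiplicative factor, so the value would be divided by the optimizer's base learning rate. The scheduler is then returned in a dictionary such as {"scheduler": ..., "interval": "step"} so it advances on every update rather than every epoch.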

cross_entropy_loss(self, logits, targets)
training_step(self, train_batch, batch_idx)
validation_step(self, val_batch, batch_idx)

Log the loss and metrics for a batch.

Parameters
  • val_batch – batch input.

  • batch_idx – index of the batch.

get_tensor_accuracy(self, logits: torch.Tensor, targets: torch.Tensor) → Tuple[torch.Tensor, torch.Tensor]

Calculate the accuracy for multi-masking, summed over the batch.

Parameters
  • logits – prediction from the model, shape = (batch, len_tokens, len_vocab)

  • targets – ground truth, shape = (batch, len_tokens)

Returns

accuracy value.
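The following is a minimal sketch of masked-token accuracy summed over a batch, using the shapes documented above. It rests on two assumptions. First, unmasked or padded positions carry a sentinel target value, here the hypothetical IGNORE_INDEX = -100 (PyTorch's default ignore_index for cross-entropy), and only the remaining positions are scored. Second, the two returned values are read as (correct count, scored count), which is only one plausible interpretation of the Tuple[torch.Tensor, torch.Tensor] return type. The function name masked_accuracy_counts is hypothetical. The sketch uses plain Python lists so it runs without torch. With tensors, the same steps would be logits.argmax(dim=-1) followed by a boolean mask on targets.

```python
from typing import Sequence, Tuple

IGNORE_INDEX = -100  # hypothetical sentinel for positions that were not masked


def masked_accuracy_counts(
    logits: Sequence[Sequence[Sequence[float]]],  # (batch, len_tokens, len_vocab)
    targets: Sequence[Sequence[int]],  # (batch, len_tokens)
    ignore_index: int = IGNORE_INDEX,
) -> Tuple[int, int]:
    """Return (correct, scored) counts summed over the whole batch."""
    correct = 0
    scored = 0
    for seq_logits, seq_targets in zip(logits, targets):
        for token_logits, target in zip(seq_logits, seq_targets):
            if target == ignore_index:
                continue  # only masked positions contribute
            # Predicted token id = index of the largest logit over the vocabulary.
            pred = max(range(len(token_logits)), key=token_logits.__getitem__)
            correct += int(pred == target)
            scored += 1
    return correct, scored


logits = [
    [[0.1, 2.0, 0.3, 0.0], [1.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 5.0]],
    [[0.0, 0.0, 3.0, 0.0], [0.0, 1.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0]],
]
targets = [[1, IGNORE_INDEX, 2], [2, 1, IGNORE_INDEX]]
correct, scored = masked_accuracy_counts(logits, targets)
print(correct / scored)  # 0.75
```

Returning sums instead of a per-batch ratio is useful when batches contain different numbers of masked tokens. Under this reading, an epoch-level accuracy would be computed as the total correct divided by the total scored across batches, not as the mean of per-batch accuracies.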