:mod:`biotransformers.lightning_utils.models`
=============================================

.. py:module:: biotransformers.lightning_utils.models


Module Contents
---------------

Classes
~~~~~~~

.. autoapisummary::

   biotransformers.lightning_utils.models.LightningModule




.. class:: LightningModule(model, alphabet, lr: float, warmup_end_lr: float, warmup_updates: int = 10, warmup_init_lr: float = 1e-07)


   Bases: :py:obj:`pytorch_lightning.LightningModule`

   Wrap a model in a PyTorch Lightning module so that it can be trained with distributed data parallel (DDP).

   .. method:: forward(self, x)


   .. method:: configure_optimizers(self) -> Tuple[List[torch.optim.Optimizer], List[Dict]]

      Configure the optimizer and learning rate scheduler.

      :returns: A tuple containing:

         - the list of optimizers;
         - the list of learning-rate scheduler configurations (dictionaries).
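
      The parameter names ``warmup_init_lr``, ``warmup_end_lr`` and ``warmup_updates`` match those of the fairseq-style inverse-square-root schedule, so the sketch below shows that schedule as one plausible reading. Whether this method actually uses it is an assumption, not something stated here, and ``inverse_sqrt_lr`` is a hypothetical helper written in pure Python so that it runs without PyTorch. The learning rate rises linearly from ``warmup_init_lr`` to ``warmup_end_lr`` over ``warmup_updates`` steps, then decays in proportion to the inverse square root of the step count.

      .. code-block:: python

         import math


         def inverse_sqrt_lr(
             num_updates: int,
             warmup_end_lr: float,
             warmup_updates: int = 10,
             warmup_init_lr: float = 1e-07,
         ) -> float:
             """Hypothetical inverse-square-root schedule with linear warmup."""
             if warmup_updates < 1:
                 raise ValueError("warmup_updates must be at least 1")
             if num_updates < warmup_updates:
                 # Linear warmup from warmup_init_lr towards warmup_end_lr.
                 step = (warmup_end_lr - warmup_init_lr) / warmup_updates
                 return warmup_init_lr + num_updates * step
             # After warmup, decay as 1/sqrt(num_updates); the factor makes the
             # schedule equal warmup_end_lr exactly at num_updates == warmup_updates.
             decay_factor = warmup_end_lr * math.sqrt(warmup_updates)
             return decay_factor / math.sqrt(num_updates)


         print(inverse_sqrt_lr(0, warmup_end_lr=1e-4))   # 1e-07
         print(inverse_sqrt_lr(40, warmup_end_lr=1e-4))  # approximately 5e-05 (half of warmup_end_lr)

      To drive such a schedule from Lightning, one option is ``torch.optim.lr_scheduler.LambdaLR``, whose ``lr_lambda`` returns a multiplier of the optimizer's base learning rate (so the value above would be divided by that base rate), placed in a scheduler dictionary with ``"interval": "step"`` so that it advances on every optimizer step rather than every epoch.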


   .. method:: cross_entropy_loss(self, logits, targets)


   .. method:: training_step(self, train_batch, batch_idx)


   .. method:: validation_step(self, val_batch, batch_idx)

      Log the loss and metrics for a validation batch.

      :param val_batch: validation batch input.
      :param batch_idx: index of the batch.


   .. method:: get_tensor_accuracy(self, logits: torch.Tensor, targets: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]

      Calculate the accuracy on masked tokens when several tokens per sequence may be masked (multi-masking), summed over the batch.

      :param logits: predictions from the model, of shape ``(batch, len_tokens, len_vocab)``.
      :param targets: ground truth, of shape ``(batch, len_tokens)``.

      :returns: a tuple of two tensors holding the accuracy statistics, summed over the batch.
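
      A minimal pure-Python sketch of multi-mask accuracy follows, using nested lists in place of tensors. It assumes that the two returned values are the number of correctly predicted masked tokens and the number of masked tokens, and that unscored positions in ``targets`` are marked with a hypothetical ``ignore_index`` of ``-1``; the actual method may instead rely on the alphabet's padding index. ``masked_accuracy_counts`` is a hypothetical name.

      .. code-block:: python

         from typing import Sequence, Tuple


         def masked_accuracy_counts(
             logits: Sequence[Sequence[Sequence[float]]],
             targets: Sequence[Sequence[int]],
             ignore_index: int = -1,
         ) -> Tuple[int, int]:
             """Return (correct, counted), summed over the whole batch."""
             correct = 0
             counted = 0
             for seq_logits, seq_targets in zip(logits, targets):
                 for token_logits, target in zip(seq_logits, seq_targets):
                     if target == ignore_index:
                         continue  # position is not scored
                     # Argmax over the vocabulary dimension (first index wins ties).
                     pred = max(range(len(token_logits)), key=token_logits.__getitem__)
                     correct += int(pred == target)
                     counted += 1
             return correct, counted


         logits = [
             [[0.1, 0.9, 0.0], [0.8, 0.1, 0.1]],
             [[0.2, 0.2, 0.6], [0.3, 0.4, 0.3]],
         ]
         targets = [[1, -1], [0, 1]]
         print(masked_accuracy_counts(logits, targets))  # (2, 3)

      If the method does return counts, they can be added across batches (and across DDP processes) before dividing, which weights every masked token equally even when batches contain different numbers of masked tokens.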



