:mod:`biotransformers.lightning_utils.data`
===========================================

.. py:module:: biotransformers.lightning_utils.data


Module Contents
---------------

Classes
~~~~~~~

.. autoapisummary::

   biotransformers.lightning_utils.data.AlphabetDataLoader
   biotransformers.lightning_utils.data.BatchWithConstantNumberTokensSampler
   biotransformers.lightning_utils.data.DistributedBatchWithConstantNumberTokensSampler
   biotransformers.lightning_utils.data.BatchWithConstantNumberTokensDataset
   biotransformers.lightning_utils.data.BatchWithConstantNumberTokensDataModule



Functions
~~~~~~~~~

.. autoapisummary::

   biotransformers.lightning_utils.data.convert_ckpt_to_statedict
   biotransformers.lightning_utils.data.worker_init_fn
   biotransformers.lightning_utils.data.mask_seq
   biotransformers.lightning_utils.data.collate_fn
   biotransformers.lightning_utils.data.crop_sequence
   biotransformers.lightning_utils.data.get_batch_indices



.. class:: AlphabetDataLoader(prepend_bos: bool, append_eos: bool, mask_idx: int, pad_idx: int, standard_toks: List[str], model_dir: str, lambda_toks_to_ids: Callable, lambda_tokenizer: Callable)


   Class that carries tokenizer information.

   .. method:: tok_to_idx(self, x)


   .. method:: tokenizer(self)

      Return sequence tokens based on the sequence.



.. function:: convert_ckpt_to_statedict(checkpoint_state_dict: collections.OrderedDict) -> collections.OrderedDict

   Convert a state_dict coming from a PyTorch Lightning checkpoint into a model
   state_dict that can be loaded directly into the bio-transformers model.

   The keys are updated so that they match those in the bio-transformers model.

   :param checkpoint_state_dict: a state_dict loaded from a checkpoint

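A minimal sketch of this kind of key conversion, assuming that the Lightning module stores the network under an attribute named ``model`` so that every checkpoint key carries a ``model.`` prefix. The prefix and the helper name ``convert_ckpt_to_statedict_sketch`` are assumptions for illustration, not the library's confirmed behaviour.

```python
from collections import OrderedDict

# Assumption: Lightning prefixes each parameter key with the attribute name "model."
PREFIX = "model."


def convert_ckpt_to_statedict_sketch(checkpoint_state_dict: OrderedDict) -> OrderedDict:
    """Hypothetical helper: strip the Lightning prefix from every key."""
    new_state_dict = OrderedDict()
    for key, value in checkpoint_state_dict.items():
        # Only remove the prefix when it starts the key; keep other keys as-is
        new_key = key[len(PREFIX):] if key.startswith(PREFIX) else key
        new_state_dict[new_key] = value
    return new_state_dict


ckpt = OrderedDict([("model.embed_tokens.weight", 1), ("model.lm_head.bias", 2)])
print(list(convert_ckpt_to_statedict_sketch(ckpt)))
# → ['embed_tokens.weight', 'lm_head.bias']
```

The converted dictionary can then be passed to ``torch.nn.Module.load_state_dict`` on the underlying model.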

.. function:: worker_init_fn(worker_id: int)

   Set numpy random seed for each worker.

   https://github.com/pytorch/pytorch/issues/5059#issuecomment-404232359

   :param worker_id: unique id for each worker

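A minimal sketch of the per-worker seeding pattern discussed in the linked PyTorch issue: each forked worker inherits the same NumPy RNG state, so a distinct seed is derived from that shared state plus ``worker_id``. The name ``worker_init_fn_sketch`` is hypothetical and this is not necessarily the exact implementation.

```python
import numpy as np


def worker_init_fn_sketch(worker_id: int) -> None:
    # All workers start from the parent's NumPy state; read its first key word
    base_seed = int(np.random.get_state()[1][0])
    # np.random.seed only accepts values in [0, 2**32 - 1], hence the modulo
    np.random.seed((base_seed + worker_id) % 2**32)
```

Such a function is passed to a loader as ``torch.utils.data.DataLoader(dataset, num_workers=4, worker_init_fn=worker_init_fn_sketch)``, and PyTorch calls it once in each worker process.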

.. function:: mask_seq(seq: str, tokens: torch.Tensor, prepend_bos: bool, mask_idx: int, pad_idx: int, masking_ratio: float, masking_prob: float, random_token_prob: float, random_token_indices: List[int]) -> Tuple[torch.Tensor, torch.Tensor]

   Mask one sequence randomly.

   :param seq: string of the sequence.
   :param tokens: tokens corresponding to the sequence; their length can be greater than that of ``seq``.
   :param prepend_bos: whether the tokenizer adds a ``<bos>`` token.
   :param mask_idx: index of the mask token.
   :param pad_idx: index of the padding token.
   :param masking_ratio: ratio of tokens to be masked.
   :param masking_prob: probability that a chosen token is replaced with a mask token.
   :param random_token_prob: probability that a chosen token is replaced with a random token.
   :param random_token_indices: list of token indices that random replacement selects from.

   :returns: ``tokens``, the masked tokens, and ``targets``, which has the same length as ``tokens``.
   :rtype: Tuple[torch.Tensor, torch.Tensor]

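A minimal, pure-Python sketch of BERT-style masking driven by the parameters above. Assumptions not confirmed by the source: masked positions are drawn uniformly among the sequence characters, a ``<bos>`` token shifts every position by one, and unmasked target positions are filled with ``pad_idx`` so that a loss can ignore them. Plain lists stand in for ``torch.Tensor``, and ``mask_seq_sketch`` is a hypothetical name.

```python
import random
from typing import List, Optional, Tuple


def mask_seq_sketch(
    seq: str,
    tokens: List[int],
    prepend_bos: bool,
    mask_idx: int,
    pad_idx: int,
    masking_ratio: float,
    masking_prob: float,
    random_token_prob: float,
    random_token_indices: List[int],
    rng: Optional[random.Random] = None,
) -> Tuple[List[int], List[int]]:
    rng = rng if rng is not None else random.Random()
    offset = 1 if prepend_bos else 0  # skip the <bos> position if present
    masked = list(tokens)
    # Assumption: unmasked targets are pad_idx so the loss can ignore them
    targets = [pad_idx] * len(tokens)
    n_mask = round(masking_ratio * len(seq))
    for pos in rng.sample(range(len(seq)), n_mask):
        i = pos + offset
        targets[i] = tokens[i]  # the model must predict the original token
        r = rng.random()
        if r < masking_prob:
            masked[i] = mask_idx
        elif r < masking_prob + random_token_prob:
            masked[i] = rng.choice(random_token_indices)
        # otherwise the original token is left in place
    return masked, targets
```

With ``masking_prob=0.8`` and ``random_token_prob=0.1`` this gives the usual 80/10/10 split between mask, random and unchanged tokens.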

.. function:: collate_fn(samples: Sequence[Tuple[str, str]], tokenizer: esm.data.BatchConverter, alphabet: AlphabetDataLoader, masking_ratio: float, masking_prob: float, random_token_prob: float) -> Tuple[torch.Tensor, torch.Tensor]

   Collate function to mask tokens.

   :param samples: a sequence of ``(label, seq)`` pairs.
   :param tokenizer: Facebook (ESM) tokenizer that accepts sequences of ``(label, seq_str)``
                     and outputs ``(labels, seq_strs, tokens)``.
   :param alphabet: Facebook alphabet.
   :param masking_ratio: ratio of tokens to be masked.
   :param masking_prob: probability that a chosen token is replaced with a mask token.
   :param random_token_prob: probability that a chosen token is replaced with a random token.

   :returns: ``tokens`` (model input), ``targets`` (model target) and
             ``mask_indices`` (indices of masked tokens).
   :rtype: tuple

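A minimal sketch of the collate step: tokenize a batch of ``(label, seq)`` pairs, pad it, then mask each row. ``toy_batch_converter`` is a hypothetical stand-in for ``esm.data.BatchConverter`` (it only mimics the ``(labels, seq_strs, tokens)`` output described above), the special-token ids and vocabulary are invented, and returning ``mask_indices`` as ``(row, column)`` pairs is an assumption.

```python
import random
from typing import List, Sequence, Tuple

# Hypothetical special-token ids and vocabulary, for illustration only
BOS, PAD, EOS, MASK = 0, 1, 2, 3
VOCAB = {aa: i + 4 for i, aa in enumerate("ACDEFGHIKLMNPQRSTVWY")}


def toy_batch_converter(samples: Sequence[Tuple[str, str]]):
    """Stand-in for a batch converter: returns (labels, seq_strs, padded tokens)."""
    labels = [label for label, _ in samples]
    strs = [s for _, s in samples]
    max_len = max(len(s) for s in strs) + 2  # room for <bos> and <eos>
    tokens = []
    for s in strs:
        row = [BOS] + [VOCAB[c] for c in s] + [EOS]
        tokens.append(row + [PAD] * (max_len - len(row)))
    return labels, strs, tokens


def collate_fn_sketch(samples, masking_ratio, masking_prob, random_token_prob, rng):
    _, strs, tokens = toy_batch_converter(samples)
    inputs, targets, mask_indices = [], [], []
    for row_idx, (s, row) in enumerate(zip(strs, tokens)):
        row = list(row)
        tgt = [PAD] * len(row)  # assumption: PAD marks positions without a target
        for pos in rng.sample(range(len(s)), round(masking_ratio * len(s))):
            i = pos + 1  # shift past <bos>
            tgt[i] = row[i]
            mask_indices.append((row_idx, i))
            r = rng.random()
            if r < masking_prob:
                row[i] = MASK
            elif r < masking_prob + random_token_prob:
                row[i] = rng.choice(list(VOCAB.values()))
        inputs.append(row)
        targets.append(tgt)
    return inputs, targets, mask_indices
```

Padding positions are never chosen for masking here because positions are drawn only from the characters of each sequence.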

.. function:: crop_sequence(sequence: str, crop_length: int) -> str

   If the sequence is longer than ``crop_length``, crop it randomly to
   ``crop_length`` characters.

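A minimal sketch of random cropping, assuming the crop window is a contiguous substring whose start is drawn uniformly; the name ``crop_sequence_sketch`` and the optional ``rng`` argument are hypothetical additions for reproducibility.

```python
import random
from typing import Optional


def crop_sequence_sketch(sequence: str, crop_length: int, rng: Optional[random.Random] = None) -> str:
    if len(sequence) <= crop_length:
        return sequence  # already short enough, keep it whole
    rng = rng if rng is not None else random.Random()
    # Uniform start so that every window of crop_length fits inside the sequence
    start = rng.randint(0, len(sequence) - crop_length)
    return sequence[start:start + crop_length]
```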

.. function:: get_batch_indices(sequence_strs: List[str], toks_per_batch: int, crop_sizes: Tuple[int, int] = (600, 1200), seed: int = 0) -> List[List[List[Tuple[int, int]]]]

   This sampler aims to create batches that contain not a fixed number of sequences
   but rather a constant number of tokens. A batch can therefore contain a few long
   sequences or many short ones.

   This sampler returns batches of indices to achieve this property. It also decides
   whether sequences must be cropped and returns the desired length. The cropping
   length is sampled randomly for each sequence at each epoch within the range given
   by ``crop_sizes``.

   This sampler computes a list of lists of tuples, each tuple containing the index
   and the length of a sequence inside the batch.

   .. rubric:: Example

   Returning ``[[(1, 100), (3, 600)], [(4, 100), (7, 1200), (10, 600)], [(12, 1000)]]``
   means that the first batch is composed of the sequences at indices 1 and 3, with
   lengths 100 and 600. The third batch contains only sequence 12, with a length
   of 1000.

   :param sequence_strs: list of sequence strings
   :param toks_per_batch: maximum number of tokens per batch
   :type toks_per_batch: int
   :param extra_toks_per_seq: Defaults to 0.
   :type extra_toks_per_seq: int, optional
   :param crop_sizes: min and max sequence lengths when cropping
   :type crop_sizes: Tuple[int, int]
   :param seed: seed to be used for random generator
   :type seed: int

   :returns: List of batches indexes and lengths
   :rtype: List

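A minimal sketch of greedy token packing, assuming sequences are visited in dataset order, each crop length is drawn uniformly from ``crop_sizes`` and capped at ``toks_per_batch``, and a batch is closed as soon as the next sequence would push it over the budget. The name ``get_batch_indices_sketch`` is hypothetical, and it returns a flat list of batches shaped like the example above rather than the nested type in the signature.

```python
import random
from typing import List, Tuple


def get_batch_indices_sketch(
    sequence_strs: List[str],
    toks_per_batch: int,
    crop_sizes: Tuple[int, int] = (600, 1200),
    seed: int = 0,
) -> List[List[Tuple[int, int]]]:
    rng = random.Random(seed)
    batches: List[List[Tuple[int, int]]] = []
    current: List[Tuple[int, int]] = []
    current_toks = 0
    for idx, seq in enumerate(sequence_strs):
        crop_len = rng.randint(crop_sizes[0], crop_sizes[1])
        # Effective length: the sequence itself, its crop, or the whole budget
        length = min(len(seq), crop_len, toks_per_batch)
        if current and current_toks + length > toks_per_batch:
            batches.append(current)  # budget reached, start a new batch
            current, current_toks = [], 0
        current.append((idx, length))
        current_toks += length
    if current:
        batches.append(current)
    return batches
```

Calling it with a different ``seed`` at each epoch (for instance ``seed + epoch``) would resample the crop lengths, in line with the per-epoch sampling described above.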

.. class:: BatchWithConstantNumberTokensSampler(sequence_strs: List[str], toks_per_batch: int, crop_sizes: Tuple[int, int] = (512, 1024))


   Bases: :py:obj:`torch.utils.data.Sampler`

   Sampler that returns batches of sequence indices from the dataset, ensuring a
   fixed number of tokens per batch rather than a fixed number of sequences. This
   sampler also accounts for sequences being cropped dynamically during sampling,
   so in addition to indices it returns the desired cropping lengths to inform the
   dataloader.

   .. method:: __len__(self)


   .. method:: __iter__(self)



.. class:: DistributedBatchWithConstantNumberTokensSampler(sequence_strs: List[str], toks_per_batch: int, crop_sizes: Tuple[int, int] = (512, 1024), num_replicas: Optional[int] = None, rank: Optional[int] = None, seed: int = 0)


   Bases: :py:obj:`torch.utils.data.Sampler`

   Sampler that returns batches of sequence indices from the dataset, ensuring a
   fixed number of tokens per batch rather than a fixed number of sequences. This
   sampler also accounts for sequences being cropped dynamically during sampling,
   so in addition to indices it returns the desired cropping lengths to inform the
   dataloader. This version of the sampler is distributed, for use with the DDP
   accelerator.

   .. method:: __len__(self) -> int


   .. method:: set_epoch(self, epoch: int) -> None


   .. method:: __iter__(self)

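A minimal sketch of how precomputed batches could be split across DDP replicas. It borrows the approach of ``torch.utils.data.DistributedSampler`` (shuffle with ``seed + epoch``, pad by repeating items so every replica gets the same count, then take every ``num_replicas``-th item starting at ``rank``); whether this class does exactly the same is an assumption, and ``split_batches_for_rank`` is a hypothetical name.

```python
import random
from typing import List, Sequence, Tuple

Batch = List[Tuple[int, int]]  # (sequence index, length) pairs


def split_batches_for_rank(
    batches: Sequence[Batch], num_replicas: int, rank: int, seed: int = 0, epoch: int = 0
) -> List[Batch]:
    order = list(range(len(batches)))
    # Same seed + epoch on every rank, so all replicas agree on the shuffle
    random.Random(seed + epoch).shuffle(order)
    # Round the total up to a multiple of num_replicas
    total = -(-len(order) // num_replicas) * num_replicas
    padding = total - len(order)
    if padding:
        # Repeat batches so each replica receives the same number of steps
        order += (order * (padding // len(order) + 1))[:padding]
    return [batches[i] for i in order[rank:total:num_replicas]]
```

Equal batch counts per replica matter under DDP because every process must run the same number of steps; ``set_epoch`` plays the role of the ``epoch`` argument here.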


.. class:: BatchWithConstantNumberTokensDataset(sequences: List[str])


   Bases: :py:obj:`torch.utils.data.Dataset`

   Dataset class designed to work together with ``BatchWithConstantNumberTokensSampler``.

   .. method:: __len__(self)


   .. method:: __getitem__(self, sampler_out) -> List[str]

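A minimal sketch of a dataset whose ``__getitem__`` receives a whole sampler output, assuming ``sampler_out`` is one batch of ``(index, desired length)`` pairs and that sequences longer than the desired length are cropped at a uniformly random start. ``ConstantTokensDatasetSketch`` is a hypothetical class, not the library's implementation.

```python
import random
from typing import List, Optional, Sequence, Tuple


class ConstantTokensDatasetSketch:
    """Hypothetical stand-in: one __getitem__ call returns one whole batch."""

    def __init__(self, sequences: List[str], seed: Optional[int] = None):
        self.sequences = sequences
        self.rng = random.Random(seed)

    def __len__(self) -> int:
        return len(self.sequences)

    def __getitem__(self, sampler_out: Sequence[Tuple[int, int]]) -> List[str]:
        batch = []
        for idx, length in sampler_out:
            seq = self.sequences[idx]
            if len(seq) > length:
                # Crop to the length chosen by the sampler
                start = self.rng.randint(0, len(seq) - length)
                seq = seq[start:start + length]
            batch.append(seq)
        return batch
```

One plausible wiring in PyTorch is ``DataLoader(dataset, sampler=sampler, batch_size=None)``: with automatic batching disabled, each item the sampler yields (here, a whole batch) is passed directly to ``dataset[...]``.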


.. class:: BatchWithConstantNumberTokensDataModule(train_sequences: List[str], validation_sequences: List[str], alphabet: AlphabetDataLoader, masking_ratio: float, masking_prob: float, random_token_prob: float, num_workers: int, toks_per_batch: int, crop_sizes: Tuple[int, int] = (512, 1024))


   Bases: :py:obj:`pytorch_lightning.LightningDataModule`

   .. method:: _get_dataloader(self, sequences: List[str]) -> torch.utils.data.DataLoader


   .. method:: train_dataloader(self)


   .. method:: val_dataloader(self)



