biotransformers.lightning_utils.data

Module Contents

Classes

AlphabetDataLoader

Class that carries tokenizer information

BatchWithConstantNumberTokensSampler

Sampler that returns batches of sequence indices so that each batch holds a fixed number of tokens rather than a fixed number of sequences.

DistributedBatchWithConstantNumberTokensSampler

Distributed version of the sampler that returns batches of sequence indices with a fixed number of tokens per batch, for use with the DDP accelerator.

BatchWithConstantNumberTokensDataset

Dataset class designed to work together with the BatchWithConstantNumberTokensSampler.

BatchWithConstantNumberTokensDataModule

Functions

convert_ckpt_to_statedict(checkpoint_state_dict: collections.OrderedDict) → collections.OrderedDict

Converts a state_dict from a PyTorch Lightning checkpoint into a state_dict that can be loaded directly into the bio-transformers model.

worker_init_fn(worker_id: int)

Set numpy random seed for each worker.

mask_seq(seq: str, tokens: torch.Tensor, prepend_bos: bool, mask_idx: int, pad_idx: int, masking_ratio: float, masking_prob: float, random_token_prob: float, random_token_indices: List[int]) → Tuple[torch.Tensor, torch.Tensor]

Mask one sequence randomly.

collate_fn(samples: Sequence[Tuple[str, str]], tokenizer: esm.data.BatchConverter, alphabet: AlphabetDataLoader, masking_ratio: float, masking_prob: float, random_token_prob: float) → Tuple[torch.Tensor, torch.Tensor]

Collate function to mask tokens.

crop_sequence(sequence: str, crop_length: int) → str

If the sequence is longer than crop_length, randomly crop it to that length.

get_batch_indices(sequence_strs: List[str], toks_per_batch: int, crop_sizes: Tuple[int, int] = (600, 1200), seed: int = 0) → List[List[List[Tuple[int, int]]]]

Creates batches that hold a constant number of tokens rather than a fixed number of sequences.

class biotransformers.lightning_utils.data.AlphabetDataLoader(prepend_bos: bool, append_eos: bool, mask_idx: int, pad_idx: int, standard_toks: List[str], model_dir: str, lambda_toks_to_ids: Callable, lambda_tokenizer: Callable)

Class that carries tokenizer information

tok_to_idx(self, x)
tokenizer(self)

Return seq-token based on sequence

biotransformers.lightning_utils.data.convert_ckpt_to_statedict(checkpoint_state_dict: collections.OrderedDict) → collections.OrderedDict

Converts a state_dict from a PyTorch Lightning checkpoint into a model state_dict that can be loaded directly into the bio-transformers model.

The keys are updated so that they match those in the bio-transformers model.

Parameters

checkpoint_state_dict – a state_dict loaded from a checkpoint
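The key update described above can be pictured as a prefix rename. The following is only a minimal sketch: the helper name `strip_key_prefix` and the `"model."` prefix are assumptions made for illustration, since this page does not state which keys the real function rewrites.

```python
from collections import OrderedDict


def strip_key_prefix(state_dict: OrderedDict, prefix: str = "model.") -> OrderedDict:
    """Return a copy of state_dict with `prefix` removed from matching keys.

    Hypothetical helper: the "model." prefix is an assumption, not the
    documented behaviour of convert_ckpt_to_statedict.
    """
    renamed = OrderedDict()
    for key, value in state_dict.items():
        # Drop the wrapper prefix if present, otherwise keep the key as-is.
        new_key = key[len(prefix):] if key.startswith(prefix) else key
        renamed[new_key] = value
    return renamed


# Stand-in values instead of tensors, to keep the sketch dependency-free.
checkpoint = OrderedDict([("model.layer.weight", 1.0), ("model.layer.bias", 0.0)])
converted = strip_key_prefix(checkpoint)
```

Returning a new OrderedDict keeps the original checkpoint untouched and preserves parameter order.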

biotransformers.lightning_utils.data.worker_init_fn(worker_id: int)

Set numpy random seed for each worker.

https://github.com/pytorch/pytorch/issues/5059#issuecomment-404232359

Parameters

worker_id – unique id for each worker
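One common way to give each DataLoader worker its own NumPy seed is to offset the inherited seed by the worker id. The sketch below illustrates that pattern; the function name `seed_worker` is hypothetical, and the exact formula used by `worker_init_fn` is not given on this page.

```python
import numpy as np


def seed_worker(worker_id: int) -> None:
    """Reseed NumPy's global generator with a value unique to this worker."""
    # The first word of the Mersenne Twister state serves as a base seed
    # (assumption: the real function may derive its seed differently).
    base_seed = int(np.random.get_state()[1][0])
    # np.random.seed accepts integers in [0, 2**32 - 1].
    np.random.seed((base_seed + worker_id) % 2**32)
```

A function like this would typically be passed as the `worker_init_fn` argument of `torch.utils.data.DataLoader`, which calls it once in each worker process with that worker's id.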

biotransformers.lightning_utils.data.mask_seq(seq: str, tokens: torch.Tensor, prepend_bos: bool, mask_idx: int, pad_idx: int, masking_ratio: float, masking_prob: float, random_token_prob: float, random_token_indices: List[int]) → Tuple[torch.Tensor, torch.Tensor]

Mask one sequence randomly.

Parameters
  • seq – string of the sequence.

  • tokens – tokens corresponding to the sequence, length can be longer than the seq.

  • prepend_bos – whether the tokenizer prepends a <bos> token

  • mask_idx – index of the mask token

  • pad_idx – index of the padding token

  • masking_ratio – ratio of tokens to be masked.

  • masking_prob – probability that a chosen token is replaced with the mask token.

  • random_token_prob – probability that a chosen token is replaced with a random token.

  • random_token_indices – list of token indices that random replacement selects from.

Returns

tokens: masked tokens

targets: same length as tokens

Return type

Tuple[torch.Tensor, torch.Tensor]
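The interaction of masking_ratio, masking_prob and random_token_prob can be sketched as follows. This is an illustration, not the library's implementation: it uses plain Python lists instead of torch tensors, takes a hypothetical `seq_len` and `rng` argument, and assumes that unselected positions in the targets are filled with pad_idx.

```python
import random
from typing import List, Tuple


def mask_tokens_sketch(
    tokens: List[int],
    seq_len: int,
    prepend_bos: bool,
    mask_idx: int,
    pad_idx: int,
    masking_ratio: float,
    masking_prob: float,
    random_token_prob: float,
    random_token_indices: List[int],
    rng: random.Random,
) -> Tuple[List[int], List[int]]:
    """Return (masked_tokens, targets), both the same length as tokens."""
    offset = 1 if prepend_bos else 0  # skip the <bos> position if present
    n_selected = int(seq_len * masking_ratio)
    selected = rng.sample(range(offset, offset + seq_len), n_selected)

    masked = list(tokens)
    targets = [pad_idx] * len(tokens)  # assumption: pad_idx marks "no target"
    for pos in selected:
        targets[pos] = tokens[pos]
        draw = rng.random()
        if draw < masking_prob:
            masked[pos] = mask_idx
        elif draw < masking_prob + random_token_prob:
            masked[pos] = rng.choice(random_token_indices)
        # Otherwise the selected token is left unchanged.
    return masked, targets


# Toy ids: 0 = <bos>, 2 = <eos>, 1 = <pad>, 32 = <mask>.
toy_tokens = [0, 5, 6, 7, 8, 9, 2, 1, 1]
masked_toks, target_toks = mask_tokens_sketch(
    toy_tokens, seq_len=5, prepend_bos=True, mask_idx=32, pad_idx=1,
    masking_ratio=0.4, masking_prob=0.8, random_token_prob=0.1,
    random_token_indices=[5, 6, 7, 8, 9], rng=random.Random(0),
)
```

With these toy values, two of the five sequence positions are selected; each is masked with probability 0.8, replaced by a random token with probability 0.1, and kept as-is otherwise.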

biotransformers.lightning_utils.data.collate_fn(samples: Sequence[Tuple[str, str]], tokenizer: esm.data.BatchConverter, alphabet: AlphabetDataLoader, masking_ratio: float, masking_prob: float, random_token_prob: float) → Tuple[torch.Tensor, torch.Tensor]

Collate function to mask tokens.

Parameters
  • samples – a sequence of (label, seq) pairs.

  • tokenizer – facebook tokenizer, that accepts sequences of (label, seq_str) and outputs (labels, seq_strs, tokens).

  • alphabet – facebook alphabet.

  • masking_ratio – ratio of tokens to be masked.

  • masking_prob – probability that a chosen token is replaced with the mask token.

  • random_token_prob – probability that a chosen token is replaced with a random token.

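Returns

tokens: model input

targets: model target

mask_indices: indices of masked tokens (listed in the docstring, although the signature declares a two-tensor tuple)

Return type

Tuple[torch.Tensor, torch.Tensor]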

biotransformers.lightning_utils.data.crop_sequence(sequence: str, crop_length: int) → str

If the sequence is longer than crop_length, randomly crop it to a contiguous window of length crop_length.
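A random crop of this kind can be sketched as picking a random start offset and slicing. The name `crop_sequence_sketch` and the explicit `rng` argument are assumptions added for reproducibility; the real function takes only the sequence and the crop length.

```python
import random


def crop_sequence_sketch(sequence: str, crop_length: int, rng: random.Random) -> str:
    """Return sequence unchanged if short enough, else a random window of crop_length."""
    if len(sequence) <= crop_length:
        return sequence
    # randint is inclusive on both ends, so the last valid start is len - crop_length.
    start = rng.randint(0, len(sequence) - crop_length)
    return sequence[start:start + crop_length]


protein = "MKTAYIAKQRQISFVKSHFSRQ"
cropped = crop_sequence_sketch(protein, 10, random.Random(0))
short = crop_sequence_sketch("MKTA", 10, random.Random(0))
```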

biotransformers.lightning_utils.data.get_batch_indices(sequence_strs: List[str], toks_per_batch: int, crop_sizes: Tuple[int, int] = (600, 1200), seed: int = 0) → List[List[List[Tuple[int, int]]]]

This function builds batches that hold a constant number of tokens rather than a fixed number of sequences, so a batch may contain a few long sequences or many short ones.

It returns batches of indices that achieve this property. It also decides whether each sequence must be cropped and returns the desired length. The cropping length is sampled randomly for each sequence at each epoch within the range given by crop_sizes.

The result is a list of lists of tuples, where each tuple holds the index and the length of one sequence in the batch.

Example

Returning [[(1, 100), (3, 600)], [(4, 100), (7, 1200), (10, 600)], [(12, 1000)]] means that the first batch is composed of the sequences at indices 1 and 3, with lengths 100 and 600. The third batch contains only sequence 12, with a length of 1000.

Parameters
  • sequence_strs – list of sequence strings

  • toks_per_batch (int) – maximum number of tokens per batch

  • extra_toks_per_seq (int, optional) – Defaults to 0.

  • crop_sizes (Tuple[int, int]) – min and max sequence lengths when cropping

  • seed (int) – seed to be used for random generator

Returns

List of batches, each holding the indices and lengths of its sequences

Return type

List
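The token-budget batching described above can be sketched with a greedy pass over the sequences. This is an illustrative approximation under stated assumptions: the helper name `batch_by_token_budget` is hypothetical, sequences are visited in input order, and the result has a single level of batches, whereas the real function returns one more level of nesting.

```python
import random
from typing import List, Tuple


def batch_by_token_budget(
    sequence_strs: List[str],
    toks_per_batch: int,
    crop_sizes: Tuple[int, int] = (600, 1200),
    seed: int = 0,
) -> List[List[Tuple[int, int]]]:
    """Group (index, length) pairs so each batch stays within toks_per_batch."""
    rng = random.Random(seed)
    batches: List[List[Tuple[int, int]]] = []
    current: List[Tuple[int, int]] = []
    current_tokens = 0
    for index, seq in enumerate(sequence_strs):
        # Sample a crop length per sequence; shorter sequences keep their length.
        crop = rng.randint(crop_sizes[0], crop_sizes[1])
        length = min(len(seq), crop)
        # Close the current batch if this sequence would exceed the budget.
        if current and current_tokens + length > toks_per_batch:
            batches.append(current)
            current, current_tokens = [], 0
        current.append((index, length))
        current_tokens += length
    if current:
        batches.append(current)
    return batches


toy_sequences = ["A" * 100, "A" * 600, "A" * 100, "A" * 1500, "A" * 50]
toy_batches = batch_by_token_budget(toy_sequences, toks_per_batch=1300)
```

In this sketch a sequence whose cropped length exceeds toks_per_batch would end up alone in an oversized batch; the toy budget of 1300 is above the maximum crop size, so that case does not arise here.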

class biotransformers.lightning_utils.data.BatchWithConstantNumberTokensSampler(sequence_strs: List[str], toks_per_batch: int, crop_sizes: Tuple[int, int] = (512, 1024))

Bases: torch.utils.data.Sampler

Sampler that returns batches of sequence indices so that each batch holds a fixed number of tokens rather than a fixed number of sequences. Because sequences may be cropped dynamically when sampling, the sampler returns the desired cropping length alongside each index to inform the dataloader.

__len__(self)
__iter__(self)

class biotransformers.lightning_utils.data.DistributedBatchWithConstantNumberTokensSampler(sequence_strs: List[str], toks_per_batch: int, crop_sizes: Tuple[int, int] = (512, 1024), num_replicas: Optional[int] = None, rank: Optional[int] = None, seed: int = 0)

Bases: torch.utils.data.Sampler

Sampler that returns batches of sequence indices so that each batch holds a fixed number of tokens rather than a fixed number of sequences. Because sequences may be cropped dynamically when sampling, the sampler returns the desired cropping length alongside each index to inform the dataloader. This distributed version is intended for use with the DDP accelerator.

__len__(self) → int
set_epoch(self, epoch: int) → None
__iter__(self)
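A distributed batch sampler usually has to give every replica a disjoint share of the batches while keeping the shuffle identical across replicas. The sketch below shows one way to do that, seeding the shuffle with seed plus epoch and striding by rank. The helper `shard_batches` is hypothetical, and this page does not confirm that the class works this way; padding so that every rank receives the same number of batches is also left out.

```python
import random
from typing import List, Sequence, TypeVar

T = TypeVar("T")


def shard_batches(
    batches: Sequence[T], num_replicas: int, rank: int, seed: int, epoch: int
) -> List[T]:
    """Return the batches assigned to `rank` for the given epoch."""
    order = list(range(len(batches)))
    # Every replica uses the same seed, so all replicas compute the same order.
    random.Random(seed + epoch).shuffle(order)
    # Rank r takes positions r, r + num_replicas, r + 2 * num_replicas, ...
    return [batches[i] for i in order[rank::num_replicas]]


all_batches = [[(0, 100)], [(1, 600)], [(2, 300), (3, 200)], [(4, 900)]]
shard_0 = shard_batches(all_batches, num_replicas=2, rank=0, seed=0, epoch=1)
shard_1 = shard_batches(all_batches, num_replicas=2, rank=1, seed=0, epoch=1)
```

In a scheme like this, calling set_epoch before each epoch changes the shuffle seed, so the batch order differs between epochs but stays consistent across replicas.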

class biotransformers.lightning_utils.data.BatchWithConstantNumberTokensDataset(sequences: List[str])

Bases: torch.utils.data.Dataset

Dataset class designed to work together with the BatchWithConstantNumberTokensSampler.

__len__(self)
__getitem__(self, sampler_out) → List[str]
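Since the sampler yields whole batches of (index, length) pairs, the dataset's `__getitem__` receives one batch at a time rather than a single index. A minimal sketch of that pairing follows. The class name `TokenBudgetDatasetSketch` is hypothetical, it does not subclass torch.utils.data.Dataset so that it stays dependency-free, and it crops with a plain prefix slice for determinism where the module documents random cropping.

```python
from typing import List, Tuple


class TokenBudgetDatasetSketch:
    """Map a batch of (index, length) pairs to a list of cropped sequences."""

    def __init__(self, sequences: List[str]) -> None:
        self.sequences = sequences

    def __len__(self) -> int:
        return len(self.sequences)

    def __getitem__(self, sampler_out: List[Tuple[int, int]]) -> List[str]:
        # Assumption: each pair is (sequence index, desired cropped length).
        return [self.sequences[index][:length] for index, length in sampler_out]


toy_dataset = TokenBudgetDatasetSketch(["MKTAYIAK", "GSHM", "AAAAAAAAAAAA"])
toy_batch = toy_dataset[[(0, 4), (2, 6)]]
```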

class biotransformers.lightning_utils.data.BatchWithConstantNumberTokensDataModule(train_sequences: List[str], validation_sequences: List[str], alphabet: AlphabetDataLoader, masking_ratio: float, masking_prob: float, random_token_prob: float, num_workers: int, toks_per_batch: int, crop_sizes: Tuple[int, int] = (512, 1024))

Bases: pytorch_lightning.LightningDataModule

_get_dataloader(self, sequences: List[str]) → torch.utils.data.DataLoader
train_dataloader(self)
val_dataloader(self)