biotransformers.lightning_utils.data
Module Contents
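Classes
AlphabetDataLoader | Class that carries tokenizer information
BatchWithConstantNumberTokensSampler | Sampler that returns batches of sequence indices with a fixed number of tokens per batch
DistributedBatchWithConstantNumberTokensSampler | Distributed version of the same sampler, to be used with the DDP accelerator
BatchWithConstantNumberTokensDataset | Dataset class to work in pair with the BatchWithConstantNumberTokensSampler
BatchWithConstantNumberTokensDataModule |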
Functions
convert_ckpt_to_statedict | Convert a state_dict from a PyTorch Lightning checkpoint to a state_dict that the bio-transformers model can load
worker_init_fn | Set numpy random seed for each worker
mask_seq | Mask one sequence randomly
collate_fn | Collate function to mask tokens
crop_sequence | Randomly crop a sequence that is longer than crop_length
get_batch_indices | Create batches that contain a constant number of tokens rather than a fixed number of sequences
- class biotransformers.lightning_utils.data.AlphabetDataLoader(prepend_bos: bool, append_eos: bool, mask_idx: int, pad_idx: int, standard_toks: List[str], model_dir: str, lambda_toks_to_ids: Callable, lambda_tokenizer: Callable)
Class that carries tokenizer information
- tok_to_idx(self, x)
- tokenizer(self)
Return seq-token based on sequence
- biotransformers.lightning_utils.data.convert_ckpt_to_statedict(checkpoint_state_dict: collections.OrderedDict) → collections.OrderedDict
Converts a state_dict coming from a PyTorch Lightning checkpoint into a model state_dict that can be loaded directly into the bio-transformers model.
The keys are updated so that they match those in the bio-transformers model.
- Parameters
checkpoint_state_dict – a state_dict loaded from a checkpoint
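The page does not say how the keys are rewritten. As a rough sketch only, assuming the Lightning module stores the network under an attribute named `model` (so checkpoint keys carry a `model.` prefix), the conversion could look like the following. The function name `strip_key_prefix` and the `prefix` argument are hypothetical and are not part of bio-transformers.

```python
from collections import OrderedDict


def strip_key_prefix(checkpoint_state_dict: OrderedDict, prefix: str = "model.") -> OrderedDict:
    """Return a copy of the state_dict with `prefix` removed from matching keys.

    Assumption: the only difference between the two key sets is a leading
    prefix added by the Lightning wrapper. The real mapping may differ.
    """
    converted = OrderedDict()
    for key, value in checkpoint_state_dict.items():
        # Keys without the prefix are copied unchanged.
        new_key = key[len(prefix):] if key.startswith(prefix) else key
        converted[new_key] = value
    return converted
```

Insertion order is preserved, so the converted dictionary lists parameters in the same order as the checkpoint.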
- biotransformers.lightning_utils.data.worker_init_fn(worker_id: int)
Set numpy random seed for each worker.
https://github.com/pytorch/pytorch/issues/5059#issuecomment-404232359
- Parameters
worker_id – unique id for each worker
- biotransformers.lightning_utils.data.mask_seq(seq: str, tokens: torch.Tensor, prepend_bos: bool, mask_idx: int, pad_idx: int, masking_ratio: float, masking_prob: float, random_token_prob: float, random_token_indices: List[int]) → Tuple[torch.Tensor, torch.Tensor]
Mask one sequence randomly.
- Parameters
seq – string of the sequence.
tokens – tokens corresponding to the sequence, length can be longer than the seq.
prepend_bos – if tokenizer adds <bos> token
mask_idx – index of the mask token
pad_idx – index of the padding token
masking_ratio – ratio of tokens to be masked.
masking_prob – probability that a chosen token is replaced with the mask token.
random_token_prob – probability that a chosen token is replaced with a random token.
random_token_indices – list of token indices that random replacement selects from.
- Returns
tokens – masked tokens
targets – same length as tokens
- Return type
Tuple[torch.Tensor, torch.Tensor]
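The parameters above suggest a BERT-style masking scheme. The sketch below illustrates that scheme on plain Python lists instead of tensors. It rests on several assumptions not confirmed by this page: a fraction `masking_ratio` of the sequence positions is selected, unselected positions get `pad_idx` as their target, and a selected token that is neither masked nor randomized is left unchanged. The name `mask_seq_sketch` and the `rng` argument are hypothetical.

```python
import random
from typing import List, Tuple


def mask_seq_sketch(
    seq: str,
    tokens: List[int],
    prepend_bos: bool,
    mask_idx: int,
    pad_idx: int,
    masking_ratio: float,
    masking_prob: float,
    random_token_prob: float,
    random_token_indices: List[int],
    rng: random.Random,
) -> Tuple[List[int], List[int]]:
    """Mask one tokenized sequence; returns (masked_tokens, targets)."""
    # Sequence positions start after the <bos> token when one is prepended.
    offset = 1 if prepend_bos else 0
    positions = list(range(offset, offset + len(seq)))
    n_selected = int(len(seq) * masking_ratio)
    selected = rng.sample(positions, n_selected)

    masked = list(tokens)
    # Assumption: unselected positions are marked with pad_idx in the targets.
    targets = [pad_idx] * len(tokens)
    for pos in selected:
        targets[pos] = tokens[pos]
        draw = rng.random()
        if draw < masking_prob:
            masked[pos] = mask_idx
        elif draw < masking_prob + random_token_prob:
            masked[pos] = rng.choice(random_token_indices)
        # Otherwise the token is kept as-is but still counts as a target.
    return masked, targets
```

With `masking_prob=0.8` and `random_token_prob=0.1`, roughly 80% of the selected tokens become the mask token, 10% become a random token, and 10% stay unchanged.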
- biotransformers.lightning_utils.data.collate_fn(samples: Sequence[Tuple[str, str]], tokenizer: esm.data.BatchConverter, alphabet: AlphabetDataLoader, masking_ratio: float, masking_prob: float, random_token_prob: float) → Tuple[torch.Tensor, torch.Tensor]
Collate function to mask tokens.
- Parameters
samples – a sequence of (label, seq) tuples.
tokenizer – facebook tokenizer, that accepts sequences of (label, seq_str) and outputs (labels, seq_strs, tokens).
alphabet – facebook alphabet.
masking_ratio – ratio of tokens to be masked.
masking_prob – probability that a chosen token is replaced with the mask token.
random_token_prob – probability that a chosen token is replaced with a random token.
- Returns
tokens – model input
targets – model target
mask_indices – indices of masked tokens
- Return type
Tuple[torch.Tensor, torch.Tensor]
- biotransformers.lightning_utils.data.crop_sequence(sequence: str, crop_length: int) → str
If the sequence is longer than crop_length, crop it randomly to get the proper length.
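A minimal sketch of random cropping, assuming the crop is a contiguous window whose start position is drawn uniformly. The page does not state how the crop is chosen, and the name `crop_sequence_sketch` and the explicit `rng` argument are hypothetical.

```python
import random


def crop_sequence_sketch(sequence: str, crop_length: int, rng: random.Random) -> str:
    """Return the sequence unchanged, or a random window of crop_length characters."""
    if len(sequence) <= crop_length:
        return sequence
    # randint is inclusive on both ends, so the last valid start is allowed.
    start = rng.randint(0, len(sequence) - crop_length)
    return sequence[start:start + crop_length]
```

Passing the random generator explicitly keeps the crop reproducible for a given seed.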
- biotransformers.lightning_utils.data.get_batch_indices(sequence_strs: List[str], toks_per_batch: int, crop_sizes: Tuple[int, int] = (600, 1200), seed: int = 0) → List[List[List[Tuple[int, int]]]]
This sampler aims to create batches that contain a constant number of tokens rather than a fixed number of sequences. A batch can therefore contain a few long sequences or many short ones.
This sampler returns batches of indices to achieve this property. It also decides whether sequences must be cropped and returns the desired length. The cropping length is sampled randomly for each sequence at each epoch within the range given by crop_sizes.
This sampler computes a list of lists of tuples, which contain the indices and lengths of the sequences inside each batch.
Example
Returning [[(1, 100), (3, 600)], [(4, 100), (7, 1200), (10, 600)], [(12, 1000)]] means that the first batch is composed of the sequences at indices 1 and 3, with lengths 100 and 600. The third batch contains only sequence 12, with a length of 1000.
- Parameters
sequence_strs – list of sequence strings
toks_per_batch (int) – Maximum number of tokens per batch
extra_toks_per_seq (int, optional) – . Defaults to 0.
crop_sizes (Tuple[int, int]) – min and max sequence lengths when cropping
seed (int) – seed to be used for random generator
- Returns
List of batches, each a list of (index, length) tuples
- Return type
List
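The token-budget idea can be sketched with a greedy pass over the sequences. This is an illustration under stated assumptions, not the library's algorithm: sequences are visited in their original order, one crop length is drawn per sequence from crop_sizes, and a batch is closed as soon as the next sequence would exceed toks_per_batch. The name `batch_by_token_budget` is hypothetical.

```python
import random
from typing import List, Tuple


def batch_by_token_budget(
    sequence_strs: List[str],
    toks_per_batch: int,
    crop_sizes: Tuple[int, int] = (600, 1200),
    seed: int = 0,
) -> List[List[Tuple[int, int]]]:
    """Group (index, length) pairs so that each batch stays within the token budget."""
    rng = random.Random(seed)
    batches: List[List[Tuple[int, int]]] = []
    current: List[Tuple[int, int]] = []
    current_tokens = 0
    for index, seq in enumerate(sequence_strs):
        # Draw a crop length; sequences shorter than it keep their full length.
        crop_length = rng.randint(crop_sizes[0], crop_sizes[1])
        length = min(len(seq), crop_length)
        # Close the current batch if this sequence would overflow the budget.
        if current and current_tokens + length > toks_per_batch:
            batches.append(current)
            current, current_tokens = [], 0
        current.append((index, length))
        current_tokens += length
    if current:
        batches.append(current)
    return batches
```

In this sketch a single sequence longer than the budget still forms a batch of its own, so toks_per_batch should be at least as large as the upper bound of crop_sizes.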
- class biotransformers.lightning_utils.data.BatchWithConstantNumberTokensSampler(sequence_strs: List[str], toks_per_batch: int, crop_sizes: Tuple[int, int] = (512, 1024))
Bases: torch.utils.data.Sampler
Sampler that returns batches of sequence indices such that each batch holds a fixed number of tokens rather than a fixed number of sequences. Because sequences may be cropped dynamically when sampling, the sampler returns the desired cropping lengths in addition to the indices, to inform the dataloader.
- __len__(self)
- __iter__(self)
- class biotransformers.lightning_utils.data.DistributedBatchWithConstantNumberTokensSampler(sequence_strs: List[str], toks_per_batch: int, crop_sizes: Tuple[int, int] = (512, 1024), num_replicas: Optional[int] = None, rank: Optional[int] = None, seed: int = 0)
Bases: torch.utils.data.Sampler
Sampler that returns batches of sequence indices such that each batch holds a fixed number of tokens rather than a fixed number of sequences. Because sequences may be cropped dynamically when sampling, the sampler returns the desired cropping lengths in addition to the indices, to inform the dataloader. This version of the sampler is distributed, to be used with the DDP accelerator.
- __len__(self) → int
- set_epoch(self, epoch: int) → None
- __iter__(self)
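How the batches are split across replicas is not described on this page. The sketch below shows one common arrangement, modeled on the way torch.utils.data.DistributedSampler uses `seed`, `set_epoch`, `num_replicas` and `rank`: shuffle the batch order with a seed derived from the epoch, then give each rank every num_replicas-th batch. Dropping the remainder so that all ranks receive the same number of batches is an assumption, and the name `shard_batches` is hypothetical.

```python
import random
from typing import List, Tuple

Batch = List[Tuple[int, int]]  # (sequence index, cropped length) pairs


def shard_batches(
    batches: List[Batch], num_replicas: int, rank: int, seed: int, epoch: int
) -> List[Batch]:
    """Return the batches assigned to `rank` for the given epoch."""
    # Every rank builds the same shuffled order because the seed is shared.
    rng = random.Random(seed + epoch)
    order = list(range(len(batches)))
    rng.shuffle(order)
    # Drop the remainder so that all ranks run the same number of steps.
    usable = len(order) - len(order) % num_replicas
    return [batches[i] for i in order[rank:usable:num_replicas]]
```

Calling the equivalent of `set_epoch` before each epoch changes the seed, which is what lets the shuffling differ from one epoch to the next while staying identical across ranks.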
- class biotransformers.lightning_utils.data.BatchWithConstantNumberTokensDataset(sequences: List[str])
Bases: torch.utils.data.Dataset
Dataset class to work in pair with the BatchWithConstantNumberTokensSampler.
- __len__(self)
- __getitem__(self, sampler_out) → List[str]
- class biotransformers.lightning_utils.data.BatchWithConstantNumberTokensDataModule(train_sequences: List[str], validation_sequences: List[str], alphabet: AlphabetDataLoader, masking_ratio: float, masking_prob: float, random_token_prob: float, num_workers: int, toks_per_batch: int, crop_sizes: Tuple[int, int] = (512, 1024))
Bases: pytorch_lightning.LightningDataModule
- _get_dataloader(self, sequences: List[str]) → torch.utils.data.DataLoader
- train_dataloader(self)
- val_dataloader(self)