# __init__.py
import torch
from .registery import *
from .logger import get_logger_and_log_dir
from .fix_pred import fix_text_obj

# Names exported by a wildcard import of this package.
# NOTE(review): 'Registry' is not defined in this file; presumably it is
# provided by the `from .registery import *` above — confirm.
# NOTE(review): `fix_text_obj` is imported above but not listed here —
# confirm whether that omission is intentional.
__all__ = [
    'Registry',
    'get_logger_and_log_dir',
    'sequence_mask',
]

def sequence_mask(X, valid_len, value=0):
    """Overwrite the entries of ``X`` that lie beyond each row's valid length.

    Position ``j`` along dimension 1 of row ``i`` is kept when
    ``j < valid_len[i]``; every other position is set to ``value``.
    Defined in :numref:`sec_seq2seq_decoder`.

    Args:
        X: Tensor with at least two dimensions; dimension 1 is the one
            being masked.
        valid_len: Count of leading positions to keep for each row,
            indexed here as ``valid_len[:, None]``.
        value: Fill written into the masked-out positions (default 0).

    Returns:
        ``X`` itself, which has been modified in place.
    """
    num_steps = X.shape[1]
    # Row vector of position indices 0..num_steps-1, created on X's device.
    positions = torch.arange(
        num_steps, dtype=torch.float32, device=X.device
    ).unsqueeze(0)
    # Broadcasting against the column of lengths yields one boolean per
    # (row, position): True where the position is still valid.
    keep = positions < valid_len[:, None]
    # Boolean indexing (not masked_fill_) so trailing dims of X are filled too.
    X[torch.logical_not(keep)] = value
    return X