# __init__.py — package initializer; re-exports the public helpers listed in __all__.
import torch
from .registery import *
from .logger import get_logger_and_log_dir
# Names exposed by `from <this package> import *`.
__all__ = [
    "Registry",
    "get_logger_and_log_dir",
    "sequence_mask",
]
def sequence_mask(X, valid_len, value=0):
    """Mask the entries of each sequence that lie beyond its valid length.

    For every row ``i`` of ``X``, the positions ``j >= valid_len[i]`` along
    dimension 1 are overwritten with ``value``. Any trailing dimensions of
    ``X`` (e.g. a feature axis) are masked together with their position.

    Originally defined in :numref:`sec_seq2seq_decoder`.

    Args:
        X: Tensor with at least two dimensions. Dim 0 indexes sequences and
            dim 1 indexes positions within a sequence. Modified in place.
        valid_len: Per-sequence valid lengths, one entry per row of ``X``.
            Either a 1-D tensor or any sequence accepted by
            ``torch.as_tensor``. It is moved to ``X``'s device if needed.
        value: Fill value written into the masked positions.

    Returns:
        ``X`` itself after in-place masking (returned for convenience).
    """
    maxlen = X.size(1)
    # Integer positions keep the comparison exact for any length; float32
    # cannot represent every integer above 2**24.
    positions = torch.arange(maxlen, device=X.device)
    # Accept lists/tuples and tensors on another device without failing.
    valid_len = torch.as_tensor(valid_len, device=X.device)
    # (1, maxlen) < (batch, 1) broadcasts to a (batch, maxlen) keep-mask.
    mask = positions[None, :] < valid_len[:, None]
    X[~mask] = value
    return X