openRiemann
2022-02-24 20:19

Where is the core implementation of _VF.rnn_tanh used by PyTorch's torch.nn.RNN, and how is it called?

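I've recently been reading the PyTorch RNN source code. In torch/nn/modules/rnn.py, the core of the RNN implementation is a call to _VF.rnn_tanh. The relevant source snippet is below (see the annotated line):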

    def forward(self, input, hx=None):  # noqa: F811
        orig_input = input
        if isinstance(orig_input, PackedSequence):
            input, batch_sizes, sorted_indices, unsorted_indices = input
            max_batch_size = int(batch_sizes[0])
        else:
            batch_sizes = None
            is_batched = input.dim() == 3
            batch_dim = 0 if self.batch_first else 1
            if not is_batched:
                input = input.unsqueeze(batch_dim)
                if hx is not None:
                    if hx.dim() != 2:
                        raise RuntimeError(
                            f"For unbatched 2-D input, hx should also be 2-D but got {hx.dim()}-D tensor")
                    hx = hx.unsqueeze(1)
            else:
                if hx is not None and hx.dim() != 3:
                    raise RuntimeError(
                        f"For batched 3-D input, hx should also be 3-D but got {hx.dim()}-D tensor")
            max_batch_size = input.size(0) if self.batch_first else input.size(1)
            sorted_indices = None
            unsorted_indices = None

        if hx is None:
            num_directions = 2 if self.bidirectional else 1
            hx = torch.zeros(self.num_layers * num_directions,
                             max_batch_size, self.hidden_size,
                             dtype=input.dtype, device=input.device)
        else:
            # Each batch of the hidden state should match the input sequence that
            # the user believes he/she is passing in.
            hx = self.permute_hidden(hx, sorted_indices)

        assert hx is not None
        self.check_forward_args(input, hx, batch_sizes)
        assert self.mode == 'RNN_TANH' or self.mode == 'RNN_RELU'
        if batch_sizes is None:
            if self.mode == 'RNN_TANH':
                result = _VF.rnn_tanh(input, hx, self._flat_weights, self.bias, self.num_layers,
                                      self.dropout, self.training, self.bidirectional,
                                      self.batch_first)  # annotated by @openriemann FIXME: where is the source code of :obj:`_VF.rnn_tanh`?
            else:
                result = _VF.rnn_relu(input, hx, self._flat_weights, self.bias, self.num_layers,
                                      self.dropout, self.training, self.bidirectional,
                                      self.batch_first)
        else:
            if self.mode == 'RNN_TANH':
                result = _VF.rnn_tanh(input, batch_sizes, hx, self._flat_weights, self.bias,
                                      self.num_layers, self.dropout, self.training,
                                      self.bidirectional)
            else:
                result = _VF.rnn_relu(input, batch_sizes, hx, self._flat_weights, self.bias,
                                      self.num_layers, self.dropout, self.training,
                                      self.bidirectional)

        output = result[0]
        hidden = result[1]

        if isinstance(orig_input, PackedSequence):
            output_packed = PackedSequence(output, batch_sizes, sorted_indices, unsorted_indices)
            return output_packed, self.permute_hidden(hidden, unsorted_indices)

        if not is_batched:
            output = output.squeeze(batch_dim)
            hidden = hidden.squeeze(1)

        return output, self.permute_hidden(hidden, unsorted_indices)
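For reference, the recurrence that the RNN_TANH mode computes is documented for torch.nn.RNN as h_t = tanh(W_ih x_t + b_ih + W_hh h_{t-1} + b_hh). The sketch below is a pure-Python model of what the two _VF.rnn_tanh call shapes in forward compute, restricted to one layer, one direction, no dropout and time-major input. It is not PyTorch's C++ code. It assumes the per-layer weights arrive in the order weight_ih, weight_hh, bias_ih, bias_hh (my understanding of how self._flat_weights is laid out when bias=True), and the names rnn_tanh_padded and rnn_tanh_packed are hypothetical.

```python
import math
from typing import List, Sequence, Tuple

Vector = List[float]
Matrix = List[List[float]]  # row-major, shape (out_features, in_features)
# Assumed per-layer order (hypothetical model of self._flat_weights): w_ih, w_hh, b_ih, b_hh
Params = Tuple[Matrix, Matrix, Vector, Vector]


def affine(w: Matrix, x: Vector, b: Vector) -> Vector:
    """Return w @ x + b for a single vector x."""
    return [sum(w_ij * x_j for w_ij, x_j in zip(row, x)) + b_i
            for row, b_i in zip(w, b)]


def rnn_tanh_cell(x: Vector, h: Vector, params: Params) -> Vector:
    """One time step: h_new = tanh(W_ih x + b_ih + W_hh h + b_hh)."""
    w_ih, w_hh, b_ih, b_hh = params
    a = affine(w_ih, x, b_ih)
    c = affine(w_hh, h, b_hh)
    return [math.tanh(ai + ci) for ai, ci in zip(a, c)]


def rnn_tanh_padded(inputs: Sequence[Matrix], h0: Matrix,
                    params: Params) -> Tuple[List[Matrix], Matrix]:
    """Hypothetical model of the (input, hx, ...) call shape.

    inputs: [seq_len][batch][input_size], time-major (batch_first=False).
    h0:     [batch][hidden_size] for a single layer and direction.
    Returns (hidden state at every time step, final hidden state).
    """
    h = [list(row) for row in h0]
    outputs: List[Matrix] = []
    for x_t in inputs:
        # zip pairs each batch element's input with its previous hidden state
        h = [rnn_tanh_cell(x, h_b, params) for x, h_b in zip(x_t, h)]
        outputs.append(h)
    return outputs, h


def rnn_tanh_packed(data: Matrix, batch_sizes: Sequence[int], h0: Matrix,
                    params: Params) -> Tuple[Matrix, Matrix]:
    """Hypothetical model of the (data, batch_sizes, hx, ...) call shape.

    data holds the rows of all time steps concatenated; step t owns the
    next batch_sizes[t] rows. Sequences are sorted by decreasing length,
    so the still-active sequences are always the first batch_sizes[t] rows.
    """
    h = [list(row) for row in h0]
    output: Matrix = []
    offset = 0
    for bs in batch_sizes:
        x_t = data[offset:offset + bs]
        offset += bs
        new_h = [rnn_tanh_cell(x, h[i], params) for i, x in enumerate(x_t)]
        h[:bs] = new_h  # rows past bs have already ended and keep their last state
        output.extend(new_h)
    return output, h
```

In the packed variant, a sequence that ends early simply stops being updated, so the returned hidden state holds each sequence's last valid step in sorted order; forward then restores the caller's original order with self.permute_hidden(hidden, unsorted_indices).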

However, torch/_VF.py contains no definition of rnn_tanh or anything related to it:

    """
    This makes the functions in torch._C._VariableFunctions available as
        torch._VF.<funcname>
    without mypy being able to find them.

    A subset of those functions are mapped to ATen functions in
    torch/jit/_builtins.py

    See https://github.com/pytorch/pytorch/issues/21478 for the reason for
    introducing torch._VF

    """
    import torch
    import sys
    import types


    class VFModule(types.ModuleType):
        vf: types.ModuleType

        def __init__(self, name):
            super(VFModule, self).__init__(name)
            self.vf = torch._C._VariableFunctions

        def __getattr__(self, attr):
            return getattr(self.vf, attr)


    sys.modules[__name__] = VFModule(__name__)
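This file explains why nothing named rnn_tanh appears in it: _VF defines no functions of its own. It replaces its own entry in sys.modules with a module object whose __getattr__ forwards every unknown name to torch._C._VariableFunctions, an object provided by PyTorch's C extension, so _VF.rnn_tanh is looked up there at attribute-access time. The sketch below reproduces the same pattern with the standard math module standing in for torch._C._VariableFunctions; the module name fake_vf and the class ForwardingModule are hypothetical.

```python
import math
import sys
import types


class ForwardingModule(types.ModuleType):
    """Module whose missing attributes are delegated to another object,
    mirroring the VFModule pattern (hypothetical stand-in)."""

    def __init__(self, name: str, target: object):
        super().__init__(name)
        self.target = target  # plays the role of torch._C._VariableFunctions

    def __getattr__(self, attr: str):
        # Only called when normal lookup fails, i.e. for names not stored
        # on the module itself; each such name is resolved on the target.
        return getattr(self.target, attr)


# Register the forwarding object under a module name so `import` finds it.
sys.modules["fake_vf"] = ForwardingModule("fake_vf", math)

import fake_vf  # noqa: E402

print(fake_vf.sqrt(16.0))         # 4.0, resolved via math.sqrt
print(fake_vf.sqrt is math.sqrt)  # True: the very same object, no wrapper
print("sqrt" in vars(fake_vf))    # False: not a real attribute of the module
```

Because the forwarded names only exist on the target object at runtime, a static checker sees no such attribute on the module, which matches the docstring's remark that the functions are available "without mypy being able to find them".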

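My guess is that rnn_tanh is implemented by the underlying cuDNN library, but I don't know which C++ function rnn_tanh corresponds to. I'm not yet very fluent in C++, and I'm also still unfamiliar with CMake.
My question: where is the core source code of _VF.rnn_tanh? If someone could also explain how _VF.rnn_tanh gets from C++ to being called from Python, that would be even better! Thanks in advance!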
