PyTorch Convolution

This post traces a single conv2d call from top to bottom: from the Python nn.Conv2d module, through the functional API and the operator registrations in native_functions.yaml, down to the cuDNN kernel launch.

torch/nn/modules/conv.py

from .. import functional as F

class Conv2d(_ConvNd):
    ...
    def _conv_forward(self, input: Tensor, weight: Tensor, bias: Optional[Tensor]):
        if self.padding_mode != 'zeros':
            return F.conv2d(F.pad(input, self._reversed_padding_repeated_twice, mode=self.padding_mode),
                            weight, bias, self.stride,
                            _pair(0), self.dilation, self.groups)
        return F.conv2d(input, weight, bias, self.stride,
                        self.padding, self.dilation, self.groups)

    def forward(self, input: Tensor) -> Tensor:
        return self._conv_forward(input, self.weight, self.bias)
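
Non-zero padding modes ('reflect', 'replicate', 'circular') are therefore implemented by explicitly padding the input with F.pad and then convolving with zero padding. A minimal sketch of that equivalence, using only standard torch APIs:

import torch
import torch.nn as nn
import torch.nn.functional as F

conv = nn.Conv2d(3, 8, kernel_size=3, padding=1, padding_mode='reflect')
x = torch.randn(1, 3, 16, 16)
y = conv(x)  # forward() -> _conv_forward() -> F.pad + F.conv2d

# Equivalent explicit form of the 'reflect' path above;
# (1, 1, 1, 1) is what _reversed_padding_repeated_twice holds for padding=1.
y_manual = F.conv2d(F.pad(x, (1, 1, 1, 1), mode='reflect'),
                    conv.weight, conv.bias, stride=1, padding=0)
assert torch.allclose(y, y_manual)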

torch/nn/functional.py

conv2d = _add_docstr(
    torch.conv2d,
    r""" comments etc... """
)  # noqa: E501
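
There is no Python logic left at this level: _add_docstr only attaches documentation to the torch.conv2d binding and returns it, so F.conv2d is the ATen operator itself and the call drops straight into the C++ dispatcher. A quick check (should print True):

import torch
import torch.nn.functional as F

# F.conv2d is the torch.conv2d ATen binding with a docstring attached.
print(F.conv2d is torch.conv2d)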

aten/src/ATen/native/native_functions.yaml

- func: conv2d(Tensor input, Tensor weight, Tensor? bias=None, SymInt[2] stride=1, SymInt[2] padding=0, SymInt[2] dilation=1, SymInt groups=1) -> Tensor
  dispatch:
    CompositeImplicitAutograd: conv2d_symint
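
This YAML schema is what registers conv2d with the dispatcher; CompositeImplicitAutograd means the op has no explicit backward formula and autograd is derived from the ops conv2d_symint calls. The registered op, including the defaults from the schema, is reachable directly through torch.ops:

import torch

x = torch.randn(1, 3, 8, 8)
w = torch.randn(4, 3, 3, 3)

# stride=1, padding=0, dilation=1, groups=1 come from the YAML defaults.
y = torch.ops.aten.conv2d(x, w)
print(y.shape)  # torch.Size([1, 4, 6, 6])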

aten/src/ATen/native/Convolution.cpp

at::Tensor conv2d_symint(
    const Tensor& input_, const Tensor& weight, const c10::optional<Tensor>& bias_opt,
    SymIntArrayRef stride, SymIntArrayRef padding, SymIntArrayRef dilation, c10::SymInt groups) {
  // See [Note: hacky wrapper removal for optional tensor]
  c10::MaybeOwned<Tensor> bias_maybe_owned = at::borrow_from_optional_tensor(bias_opt);
  const Tensor& bias = *bias_maybe_owned;

  TORCH_CHECK(
    !bias.defined() || bias.dtype() == input_.dtype(),
    "Input type (",
    input_.dtype().name(),
    ") and bias type (",
    bias.dtype().name(),
    ") should be the same");

  Tensor input;
  bool is_batched;
  std::tie(input, is_batched) = batchify(input_, /*num_spatial_dims=*/ 2, "conv2d");
  Tensor output;
  if (at::isComplexType(input_.scalar_type())) {
    output = complex_convolution(input, weight, bias, stride, padding, dilation, false, {{0, 0}}, groups);
  } else {
    output = at::convolution_symint(input, weight, bias, stride, padding, dilation, false, {{0, 0}}, groups);
  }
  return is_batched ? std::move(output) : output.squeeze(0);
}
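
Besides the bias dtype check, the interesting step here is batchify(): a 3-dimensional (C, H, W) input is treated as unbatched, a leading batch dimension is unsqueezed on for the backend call, and squeeze(0) removes it again on the way out. This is observable from Python (unbatched conv inputs require PyTorch 1.12 or newer):

import torch
import torch.nn.functional as F

x = torch.randn(3, 8, 8)    # unbatched (C, H, W) input
w = torch.randn(4, 3, 3, 3)

y = F.conv2d(x, w)          # batchify() adds the batch dim internally
print(y.shape)              # torch.Size([4, 6, 6])
assert torch.allclose(y, F.conv2d(x.unsqueeze(0), w).squeeze(0))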

aten/src/ATen/native/native_functions.yaml

- func: convolution(Tensor input, Tensor weight, Tensor? bias, SymInt[] stride, SymInt[] padding, SymInt[] dilation, bool transposed, SymInt[] output_padding, SymInt groups) -> Tensor
  dispatch:
    CompositeExplicitAutograd: convolution
  autogen: convolution.out
  tags: core

- func: _convolution(Tensor input, Tensor weight, Tensor? bias, SymInt[] stride, SymInt[] padding, SymInt[] dilation, bool transposed, SymInt[] output_padding, SymInt groups, bool benchmark, bool deterministic, bool cudnn_enabled, bool allow_tf32) -> Tensor
  dispatch:
    CompositeExplicitAutograd: _convolution
  autogen: _convolution.out

aten/src/ATen/native/Convolution.cpp

at::Tensor convolution(
    const Tensor& input, const Tensor& weight, const c10::optional<Tensor>& bias_opt,
    IntArrayRef stride, IntArrayRef padding, IntArrayRef dilation,
    bool transposed, IntArrayRef output_padding, int64_t groups) {
  // See [Note: hacky wrapper removal for optional tensor]
  c10::MaybeOwned<Tensor> bias_maybe_owned = at::borrow_from_optional_tensor(bias_opt);
  const Tensor& bias = *bias_maybe_owned;

  auto& ctx = at::globalContext();
  // See Note [Enabling Deterministic Operations]
  bool deterministic = ctx.deterministicCuDNN() || ctx.deterministicAlgorithms();
  return at::_convolution(input, weight, bias, stride, padding, dilation,
                          transposed, output_padding, groups,
                          ctx.benchmarkCuDNN(), deterministic, ctx.userEnabledCuDNN(), ctx.allowTF32CuDNN());
}
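
convolution() only fills in the global-context flags before forwarding to at::_convolution. Each of those flags maps to a Python-level knob; the lines below show the correspondence (all are standard torch APIs):

import torch

torch.backends.cudnn.benchmark = True        # ctx.benchmarkCuDNN()
torch.backends.cudnn.deterministic = False   # ctx.deterministicCuDNN()
torch.backends.cudnn.enabled = True          # ctx.userEnabledCuDNN()
torch.backends.cudnn.allow_tf32 = True       # ctx.allowTF32CuDNN()
torch.use_deterministic_algorithms(False)    # ctx.deterministicAlgorithms()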


at::Tensor _convolution(
    const Tensor& input_r, const Tensor& weight_r, const c10::optional<Tensor>& bias_r_opt,
    IntArrayRef stride_, IntArrayRef padding_, IntArrayRef dilation_,
    bool transposed_, IntArrayRef output_padding_, int64_t groups_,
    bool benchmark, bool deterministic, bool cudnn_enabled, bool allow_tf32) {
  // See [Note: hacky wrapper removal for optional tensor]
  c10::MaybeOwned<Tensor> bias_r_maybe_owned = at::borrow_from_optional_tensor(bias_r_opt);
  const Tensor& bias_r = *bias_r_maybe_owned;

  auto input = input_r;
  auto weight = weight_r;
  auto bias = bias_r;
  auto k = weight.ndimension();
  c10::IntArrayRef weight_sizes = weight.sizes();
  int64_t dim = k - 2;

  TORCH_CHECK(dim > 0, "weight should have at least three dimensions");
  TORCH_CHECK(groups_ > 0, "non-positive groups is not supported");

  ConvParams<int64_t> params;
  params.stride = expand_param_if_needed(stride_, "stride", dim);
  params.padding = expand_param_if_needed(padding_, "padding", dim);
  params.dilation = expand_param_if_needed(dilation_, "dilation", dim);
  params.transposed = transposed_;
  params.output_padding = expand_param_if_needed(output_padding_, "output_padding", dim);
  params.groups = groups_;
  params.benchmark = benchmark;
  params.deterministic = deterministic;
  params.cudnn_enabled = cudnn_enabled;
  params.allow_tf32 = allow_tf32;

  check_shape_forward(input, weight_sizes, bias, params);

  // Expand 1d -> 2d.
  // This is only done for backends that don't natively support 1d spatial input.
  if (k == 3 && !input.is_mkldnn() && !input.is_xpu()) {
    // avoid accidentally going through NHWC for permuted 3d input.
    input = input.contiguous();
    params.view1d_as_2d();
    input = view4d(input);
    weight = view4d(weight);
  }

  // Select appropriate backend to use.
  auto bias_sizes_opt = bias.defined() ? c10::optional<IntArrayRef>(bias.sizes()) : c10::nullopt;
  bool need_backward = GradMode::is_enabled() &&
      (input.requires_grad() || weight.requires_grad() || (bias.defined() && bias.requires_grad()));
  ConvBackend backend = _select_conv_backend(input, weight, bias, c10::OptionalIntArrayRef(bias_sizes_opt), need_backward, params);
  at::MemoryFormat backend_memory_format = determine_backend_memory_format(input, weight, backend);

  // Call the backend.
  Tensor output;
  auto kernel_size = weight.sizes().slice(2);
  switch (backend) {
    case ConvBackend::CudaDepthwise2d:
      output = at::_conv_depthwise2d(input.contiguous(), weight, kernel_size, bias,
          params.stride, params.padding, params.dilation);
      break;
    case ConvBackend::CudaDepthwise3d:
      output = at::conv_depthwise3d(input.contiguous(), weight, kernel_size, bias,
          params.stride, params.padding, params.dilation);
      break;
    case ConvBackend::Cudnn:
      check_input_same_type_as_parameters(input, weight, bias);
      output = at::cudnn_convolution(
          input.contiguous(backend_memory_format), weight, params.padding, params.stride,
          params.dilation, params.groups, params.benchmark, params.deterministic, params.allow_tf32);
      if (bias.defined()) {
        output.add_(reshape_bias(input.dim(), bias));
      }
      break;
    case ConvBackend::CudnnTranspose:
      check_input_same_type_as_parameters(input, weight, bias);
      output = at::cudnn_convolution_transpose(
          input.contiguous(backend_memory_format), weight, params.padding, params.output_padding,
          params.stride, params.dilation, params.groups, params.benchmark, params.deterministic, params.allow_tf32);
      if (bias.defined()) {
        output.add_(reshape_bias(input.dim(), bias));
      }
      break;
    // ... remaining ConvBackend cases (Empty, Miopen, Mkldnn, the slow fallbacks, ...) elided ...
  }
  return output;
}
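
Two things are worth pulling out of _convolution. First, the 1d -> 2d expansion: backends without native 1d support run conv1d as a conv2d with a dummy height of 1 (view4d()/view1d_as_2d()), which can be checked numerically from Python:

import torch
import torch.nn.functional as F

x = torch.randn(2, 3, 10)
w = torch.randn(5, 3, 3)

# conv1d == conv2d with a dummy spatial dim inserted, the same trick
# view4d() applies internally.
y1 = F.conv1d(x, w)
y2 = F.conv2d(x.unsqueeze(2), w.unsqueeze(2)).squeeze(2)
assert torch.allclose(y1, y2, atol=1e-6)

Second, the backend switch: on the cuDNN paths the bias is not passed into the convolution kernel at all; it is reshaped and added to the output afterwards via output.add_(reshape_bias(...)).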

aten/src/ATen/native/cudnn/ConvShared.cpp

Tensor cudnn_convolution(
    const Tensor& input_t, const Tensor& weight_t,
    IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation,
    int64_t groups, bool benchmark, bool deterministic, bool allow_tf32)
{
  TensorArg input  { input_t,  "input",  1 },
            weight { weight_t, "weight", 2 };
  CheckedFrom c = "cudnn_convolution";
  auto output_t = cudnn_convolution_forward(
    c, input, weight, padding, stride, dilation, groups, benchmark, deterministic, allow_tf32);
  return output_t;
}


Tensor cudnn_convolution_forward(
    CheckedFrom c,
    const TensorArg& input, const TensorArg& weight,
    IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups,
    bool benchmark, bool deterministic, bool allow_tf32)
{
  checkAllSameType(c, {input, weight});
  checkAllSameGPU(c, {input, weight});

  auto memory_format = cudnn_conv_suggest_memory_format(*input, *weight);
  Tensor output_t = at::detail::empty_cuda(
      conv_output_size(input->sizes(), weight->sizes(),
                       padding, stride, dilation),
      input->options().memory_format(memory_format));

  if (output_t.numel() == 0) {
    return output_t;
  }

  // Avoid ambiguity of "output" when this is being used as backwards
  TensorArg output{ output_t, "result", 0 };
  convolution_shape_check(c, input, weight, output, padding, stride, dilation, groups);

  Tensor weight_contig = weight->contiguous(memory_format);
  Tensor input_contig = input->contiguous(memory_format);

  raw_cudnn_convolution_forward_out(
      *output, input_contig, weight_contig,
      padding, stride, dilation, groups, benchmark, deterministic, allow_tf32);

  return *output;
}
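
cudnn_convolution_forward allocates the output tensor itself, so it needs the output shape up front. conv_output_size applies the standard convolution size formula per spatial dimension; here is a Python re-implementation of that formula for illustration (the helper name mirrors the C++ one, this is not the actual implementation):

def conv_output_size(input_size, weight_size, padding, stride, dilation):
    # Per spatial dim: out = (in + 2*pad - dilation*(kernel-1) - 1) // stride + 1
    batch, out_channels = input_size[0], weight_size[0]
    spatial = [
        (i + 2 * p - d * (k - 1) - 1) // s + 1
        for i, k, p, s, d in zip(input_size[2:], weight_size[2:],
                                 padding, stride, dilation)
    ]
    return [batch, out_channels] + spatial

print(conv_output_size([1, 3, 8, 8], [4, 3, 3, 3], [0, 0], [1, 1], [1, 1]))
# [1, 4, 6, 6]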

aten/src/ATen/native/cudnn/Conv_v8.cpp

void raw_cudnn_convolution_forward_out(
    const Tensor& output, const Tensor& input, const Tensor& weight,
    const IntArrayRef padding, const IntArrayRef stride, const IntArrayRef dilation, const int64_t groups,
    const bool benchmark, const bool deterministic, const bool allow_tf32)
{
  if (output.numel() == 0) { return; }
  if (at::native::cudnnv8_enabled_check_debug()) {
    run_single_conv(CUDNN_BACKEND_OPERATION_CONVOLUTION_FORWARD_DESCRIPTOR,
      input, output, weight, padding, stride, dilation, groups,
      benchmark, deterministic, allow_tf32);
  } else {
    raw_cudnn_convolution_forward_out_v7(
      output, input, weight,
      padding, stride, dilation, groups,
      benchmark, deterministic, allow_tf32);
  }
}

aten/src/ATen/native/cudnn/Conv_v7.cpp

void raw_cudnn_convolution_forward_out_v7(
    const Tensor& output, const Tensor& input, const Tensor& weight,
    IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups,
    bool benchmark, bool deterministic, bool allow_tf32) {
  split_batch_dim_to_32bit_out(output, input, weight, padding, stride, dilation, groups,
      benchmark, deterministic, allow_tf32, 1024 * 1024 * 256,
      raw_cudnn_convolution_forward_out_32bit);
}


template <typename func_t>
static inline void split_batch_dim_to_32bit_out(
    const at::Tensor& output,
    const at::Tensor& input,
    const at::Tensor& weight,
    IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups,
    bool benchmark, bool deterministic, bool allow_tf32,
    int64_t max_worksize, func_t func_32bit) {
  constexpr int64_t int_max = std::numeric_limits<int>::max();
  const int64_t ni = input.numel();
  const int64_t no = output.numel();
  // Assume the shape of the tensor is (N, C, D1, D2, ...)
  // if N * C * D1 * D2 * ... <= int_max, then no need to split at all
  if (ni <= int_max && no <= int_max) {
    func_32bit(output, input, weight, padding, stride, dilation, groups, benchmark, deterministic, allow_tf32);
    return;
  }
  // Otherwise, if C * D1 * D2 * ... <= int_max, we only need to split across the N dimension.
  //
  // We use a simple heuristic to pick the size of each split: we don't max
  // out the full 2^31 address space, because a chunk that large is very
  // likely to OOM.
  int64_t n = output.size(0);
  int64_t max_inner_size = std::max<int64_t>(ni, no) / n;
  int64_t split_size = std::max<int64_t>(max_worksize / max_inner_size, 1L);
  int64_t num_splits = (n + split_size - 1) / split_size;
  if (split_size * max_inner_size < int_max) {
    for (const auto i : c10::irange(num_splits)) {
      int64_t start = split_size * i;
      int64_t split_size_ = std::min<int64_t>(split_size, n - start);
      Tensor input_ = input.narrow(0, start, split_size_);
      Tensor output_ = output.narrow(0, start, split_size_);
      func_32bit(output_, input_, weight, padding, stride, dilation, groups, benchmark, deterministic, allow_tf32);
    }
    return;
  }
  // If control flow reaches here, even splitting N is not enough, and things
  // start to get complicated. For conv2d, for example, the following questions
  // need to be considered:
  // - Is the memory layout NCHW or NHWC?
  // - If the conv is NCHW -> NC'H'W', should we
  //   - split only NC?
  //   - split only N'C'?
  //   - split both?
  // - If the conv is NHWC, we need to split across H, and be very careful about
  //   the boundary conditions so that they are handled correctly.
  // - If we decide to make these splits, is the memory contiguous? Do we need to copy it?
  // Given the complexity of this case, it is better not to dispatch it to cuDNN at all.
  TORCH_INTERNAL_ASSERT(false, "This case should not be dispatched to cuDNN.");
}
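
The split arithmetic is easiest to see with concrete numbers. A worked example with hypothetical sizes (the tensor sizes below are made up purely to trigger the split path):

int_max = 2**31 - 1                 # std::numeric_limits<int>::max()
max_worksize = 1024 * 1024 * 256    # 2^28, as passed by the v7 wrapper

n = 64                              # batch size
inner = 50_000_000                  # elements per sample; n * inner > int_max

split_size = max(max_worksize // inner, 1)       # 5 samples per chunk
num_splits = (n + split_size - 1) // split_size  # 13 chunks
assert split_size * inner < int_max              # each chunk is 32-bit safe
print(split_size, num_splits)                    # 5 13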


void raw_cudnn_convolution_forward_out_32bit(
    const Tensor& output, const Tensor& input, const Tensor& weight,
    IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups,
    bool benchmark, bool deterministic, bool allow_tf32) {

  auto dataType = getCudnnDataType(input);

  ConvolutionArgs args{ input, output, weight };
  args.handle = getCudnnHandle();
  at::MemoryFormat memory_format = cudnn_conv_suggest_memory_format(input, weight);
  setConvolutionParams(&args.params, input, weight, padding, stride, dilation, groups, deterministic, allow_tf32, memory_format);
  args.idesc.set(input, memory_format);
  args.wdesc.set(weight, memory_format, 0);
  args.odesc.set(output, memory_format);
  args.cdesc.set(dataType, input.dim() - 2, args.params.padding, args.params.stride, args.params.dilation, args.params.groups, args.params.allow_tf32);

  // TODO: when we do legacy group convolution support, we'll repeatedly
  // reinitialize the workspace for each convolution we do.  This is
  // wasteful; we'd rather reuse the workspace.  OTOH, legacy group
  // convolution support is already pretty slow, so this might not
  // matter.  (This applies to raw_cudnn_convolution_backward_input as well.)
  AlgoIterator<cudnnConvolutionFwdAlgoPerf_t>(args, benchmark).try_all(
    [&](const cudnnConvolutionFwdAlgoPerf_t &fwdAlgPerf){
      Tensor workspace = allocate_workspace(fwdAlgPerf.memory, input);

      // Update convDesc mathType: cuDNN 7.4+ requires both the algo and the
      // mathType to decide whether to use Tensor Core kernels.
      // See Note [behavior of cudnnFind and cudnnGet]
      ASSERT_CORRECT_PRECISION(fwdAlgPerf.mathType);
      AT_CUDNN_CHECK_WITH_SHAPES(cudnnSetConvolutionMathType(args.cdesc.mut_desc(), fwdAlgPerf.mathType), args);

      Constant one(dataType, 1);
      Constant zero(dataType, 0);

      AT_CUDNN_CHECK_WITH_SHAPES(cudnnConvolutionForward(
          args.handle,
          &one, args.idesc.desc(), input.data_ptr(),
          args.wdesc.desc(), weight.data_ptr(),
          args.cdesc.desc(), fwdAlgPerf.algo, workspace.data_ptr(), fwdAlgPerf.memory,
          &zero, args.odesc.desc(), output.data_ptr()),
        args, "Forward algorithm: ", static_cast<int>(fwdAlgPerf.algo), "\n");
      }
  );
}
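
The descriptors above are all built from cudnn_conv_suggest_memory_format, so whether the NCHW or NHWC kernels run is decided by the layout of the tensors you pass in. A sketch of opting into the channels-last path (assumes a CUDA build with cuDNN available):

import torch
import torch.nn as nn

# channels_last tensors make cudnn_conv_suggest_memory_format() choose
# the NHWC descriptors for the call above.
conv = nn.Conv2d(3, 8, 3).cuda().to(memory_format=torch.channels_last)
x = torch.randn(1, 3, 32, 32, device='cuda').to(memory_format=torch.channels_last)
y = conv(x)
print(y.is_contiguous(memory_format=torch.channels_last))  # True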