torch/nn/modules/conv.py
from .. import functional as F
def _conv_forward(self, input: Tensor, weight: Tensor, bias: Optional[Tensor]):
    if self.padding_mode != 'zeros':
        return F.conv2d(F.pad(input, self._reversed_padding_repeated_twice, mode=self.padding_mode),
                        weight, bias, self.stride,
                        _pair(0), self.dilation, self.groups)
    return F.conv2d(input, weight, bias, self.stride,
                    self.padding, self.dilation, self.groups)

def forward(self, input: Tensor) -> Tensor:
    return self._conv_forward(input, self.weight, self.bias)
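A quick sanity check of the two branches above (a minimal Python sketch, not part of the PyTorch source): with a non-'zeros' padding_mode, the module pads explicitly with F.pad and then calls F.conv2d with zero padding, so the manual path below should match the module output.

import torch
import torch.nn.functional as F
from torch import nn

x = torch.randn(1, 3, 8, 8)
conv = nn.Conv2d(3, 4, kernel_size=3, padding=1, padding_mode='reflect')

# _reversed_padding_repeated_twice is (1, 1, 1, 1) for padding=1 here.
y_module = conv(x)
y_manual = F.conv2d(
    F.pad(x, (1, 1, 1, 1), mode='reflect'),
    conv.weight, conv.bias, conv.stride, 0, conv.dilation, conv.groups)
print(torch.allclose(y_module, y_manual))  # True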
torch/nn/functional.py
conv2d = _add_docstr(
    torch.conv2d,
    r"""...docstring elided..."""
)  # noqa: E501
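F.conv2d is not a Python implementation at all; _add_docstr just attaches documentation to the builtin ATen operator and returns it. Assuming _add_docstr returns the wrapped object unchanged (as it does today), the functional and the raw op are the same object:

import torch
import torch.nn.functional as F

print(F.conv2d is torch.conv2d)  # True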
aten/src/ATen/native/native_functions.yaml
- func: conv2d(Tensor input, Tensor weight, Tensor? bias=None, SymInt[2] stride=1, SymInt[2] padding=0, SymInt[2] dilation=1, SymInt groups=1) -> Tensor
  dispatch:
    CompositeImplicitAutograd: conv2d_symint
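Because conv2d is registered as CompositeImplicitAutograd, it has no dedicated kernel or derivative formula; it decomposes into aten::convolution below. The registered schema can also be invoked directly through the dispatcher (a small sketch with made-up shapes):

import torch

x = torch.randn(2, 3, 8, 8)
w = torch.randn(4, 3, 3, 3)

# Same argument order as the schema: input, weight, bias, stride, padding, dilation, groups.
y = torch.ops.aten.conv2d(x, w, None, [1, 1], [0, 0], [1, 1], 1)
print(y.shape)  # torch.Size([2, 4, 6, 6])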
aten/src/ATen/native/Convolution.cpp
at::Tensor conv2d_symint(
    const Tensor& input_, const Tensor& weight, const c10::optional<Tensor>& bias_opt,
    SymIntArrayRef stride, SymIntArrayRef padding, SymIntArrayRef dilation, c10::SymInt groups) {
  // See [Note: hacky wrapper removal for optional tensor]
  c10::MaybeOwned<Tensor> bias_maybe_owned = at::borrow_from_optional_tensor(bias_opt);
  const Tensor& bias = *bias_maybe_owned;

  TORCH_CHECK(
      !bias.defined() || bias.dtype() == input_.dtype(),
      "Input type (",
      input_.dtype().name(),
      ") and bias type (",
      bias.dtype().name(),
      ") should be the same");

  Tensor input;
  bool is_batched;
  std::tie(input, is_batched) = batchify(input_, /*num_spatial_dims=*/ 2, "conv2d");
  Tensor output;
  if (at::isComplexType(input_.scalar_type())) {
    output = complex_convolution(input, weight, bias, stride, padding, dilation, false, {{0, 0}}, groups);
  } else {
    output = at::convolution_symint(input, weight, bias, stride, padding, dilation, false, {{0, 0}}, groups);
  }
  return is_batched ? std::move(output) : output.squeeze(0);
}
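batchify is what makes unbatched inputs work: a 3-d (C, H, W) input gets a singleton batch dimension prepended, and is_batched remembers to squeeze it back off at the end. A small sketch (with made-up shapes) of the observable behavior:

import torch
import torch.nn.functional as F

w = torch.randn(4, 3, 3, 3)

y_unbatched = F.conv2d(torch.randn(3, 8, 8), w)   # treated as a single unbatched sample
y_batched = F.conv2d(torch.randn(1, 3, 8, 8), w)
print(y_unbatched.shape, y_batched.shape)  # torch.Size([4, 6, 6]) torch.Size([1, 4, 6, 6])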
aten/src/ATen/native/native_functions.yaml
- func: convolution(Tensor input, Tensor weight, Tensor? bias, SymInt[] stride, SymInt[] padding, SymInt[] dilation, bool transposed, SymInt[] output_padding, SymInt groups) -> Tensor
  dispatch:
    CompositeExplicitAutograd: convolution
  autogen: convolution.out
  tags: core

- func: _convolution(Tensor input, Tensor weight, Tensor? bias, SymInt[] stride, SymInt[] padding, SymInt[] dilation, bool transposed, SymInt[] output_padding, SymInt groups, bool benchmark, bool deterministic, bool cudnn_enabled, bool allow_tf32) -> Tensor
  dispatch:
    CompositeExplicitAutograd: _convolution
  autogen: _convolution.out
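conv2d therefore lowers to the generic aten::convolution entry point, with transposed=False and output_padding=[0, 0] hard-coded (the {{0, 0}} seen above). This can be checked from Python (a sketch with hypothetical shapes):

import torch

x = torch.randn(2, 3, 8, 8)
w = torch.randn(4, 3, 3, 3)

y = torch.ops.aten.convolution(x, w, None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1)
print(torch.allclose(y, torch.ops.aten.conv2d(x, w)))  # True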
aten/src/ATen/native/Convolution.cpp
at::Tensor convolution(
    const Tensor& input, const Tensor& weight, const c10::optional<Tensor>& bias_opt,
    IntArrayRef stride, IntArrayRef padding, IntArrayRef dilation,
    bool transposed, IntArrayRef output_padding, int64_t groups) {
  // See [Note: hacky wrapper removal for optional tensor]
  c10::MaybeOwned<Tensor> bias_maybe_owned = at::borrow_from_optional_tensor(bias_opt);
  const Tensor& bias = *bias_maybe_owned;

  auto& ctx = at::globalContext();
  // See Note [Enabling Deterministic Operations]
  bool deterministic = ctx.deterministicCuDNN() || ctx.deterministicAlgorithms();
  return at::_convolution(input, weight, bias, stride, padding, dilation,
                          transposed, output_padding, groups,
                          ctx.benchmarkCuDNN(), deterministic, ctx.userEnabledCuDNN(), ctx.allowTF32CuDNN());
}
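convolution() is where the global context flags enter the call chain. The values it forwards to at::_convolution correspond to the usual Python-level switches (a sketch of the mapping, assuming the standard torch.backends bindings):

import torch

torch.backends.cudnn.benchmark = True        # ctx.benchmarkCuDNN()
torch.backends.cudnn.deterministic = False   # ctx.deterministicCuDNN()
torch.use_deterministic_algorithms(False)    # ctx.deterministicAlgorithms()
torch.backends.cudnn.allow_tf32 = True       # ctx.allowTF32CuDNN()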
at::Tensor _convolution(
    const Tensor& input_r, const Tensor& weight_r, const c10::optional<Tensor>& bias_r_opt,
    IntArrayRef stride_, IntArrayRef padding_, IntArrayRef dilation_,
    bool transposed_, IntArrayRef output_padding_, int64_t groups_,
    bool benchmark, bool deterministic, bool cudnn_enabled, bool allow_tf32) {
  // See [Note: hacky wrapper removal for optional tensor]
  c10::MaybeOwned<Tensor> bias_r_maybe_owned = at::borrow_from_optional_tensor(bias_r_opt);
  const Tensor& bias_r = *bias_r_maybe_owned;

  auto input = input_r;
  auto weight = weight_r;
  auto bias = bias_r;
  auto k = weight.ndimension();
  c10::IntArrayRef weight_sizes = weight.sizes();
  int64_t dim = k - 2;

  TORCH_CHECK(dim > 0, "weight should have at least three dimensions");
  TORCH_CHECK(groups_ > 0, "non-positive groups is not supported");

  ConvParams<int64_t> params;
  params.stride = expand_param_if_needed(stride_, "stride", dim);
  params.padding = expand_param_if_needed(padding_, "padding", dim);
  params.dilation = expand_param_if_needed(dilation_, "dilation", dim);
  params.transposed = transposed_;
  params.output_padding = expand_param_if_needed(output_padding_, "output_padding", dim);
  params.groups = groups_;
  params.benchmark = benchmark;
  params.deterministic = deterministic;
  params.cudnn_enabled = cudnn_enabled;
  params.allow_tf32 = allow_tf32;

  check_shape_forward(input, weight_sizes, bias, params);

  // Expand 1d -> 2d.
  // This is only done for backends that don't natively support 1d spatial input.
  if (k == 3 && !input.is_mkldnn() && !input.is_xpu()) {
    // avoid accidentally going through NHWC for permuted 3d input.
    input = input.contiguous();
    params.view1d_as_2d();
    input = view4d(input);
    weight = view4d(weight);
  }

  // Select appropriate backend to use.
  auto bias_sizes_opt = bias.defined() ? c10::optional<IntArrayRef>(bias.sizes()) : c10::nullopt;
  bool need_backward = GradMode::is_enabled() &&
      (input.requires_grad() || weight.requires_grad() || (bias.defined() && bias.requires_grad()));
  ConvBackend backend = _select_conv_backend(input, weight, bias, c10::OptionalIntArrayRef(bias_sizes_opt), need_backward, params);
  at::MemoryFormat backend_memory_format = determine_backend_memory_format(input, weight, backend);

  // Call the backend.
  Tensor output;
  auto kernel_size = weight.sizes().slice(2);
  switch (backend) {
    case ConvBackend::CudaDepthwise2d:
      output = at::_conv_depthwise2d(input.contiguous(), weight, kernel_size, bias,
                                     params.stride, params.padding, params.dilation);
      break;
    case ConvBackend::CudaDepthwise3d:
      output = at::conv_depthwise3d(input.contiguous(), weight, kernel_size, bias,
                                    params.stride, params.padding, params.dilation);
      break;
    case ConvBackend::Cudnn:
      check_input_same_type_as_parameters(input, weight, bias);
      output = at::cudnn_convolution(
          input.contiguous(backend_memory_format), weight, params.padding, params.stride,
          params.dilation, params.groups, params.benchmark, params.deterministic, params.allow_tf32);
      if (bias.defined()) {
        output.add_(reshape_bias(input.dim(), bias));
      }
      break;
    case ConvBackend::CudnnTranspose:
      check_input_same_type_as_parameters(input, weight, bias);
      output = at::cudnn_convolution_transpose(
          input.contiguous(backend_memory_format), weight, params.padding, params.output_padding,
          params.stride, params.dilation, params.groups, params.benchmark, params.deterministic, params.allow_tf32);
      if (bias.defined()) {
        output.add_(reshape_bias(input.dim(), bias));
      }
      break;
    // case ConvBackend::Empty: { ... (remaining backends elided)
  }
  return output;
}
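Note that for the Cudnn backend the bias is not fused into at::cudnn_convolution; it is added afterwards, reshaped so it broadcasts over (N, C, H, W). A minimal Python sketch of that reshape_bias step (with made-up shapes):

import torch
import torch.nn.functional as F

x = torch.randn(2, 3, 8, 8)
w = torch.randn(4, 3, 3, 3)
b = torch.randn(4)

y = F.conv2d(x, w, None) + b.reshape(1, -1, 1, 1)   # conv without bias, then broadcasted add
print(torch.allclose(y, F.conv2d(x, w, b)))  # True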
aten/src/ATen/native/cudnn/ConvShared.cpp
Tensor cudnn_convolution(
    const Tensor& input_t, const Tensor& weight_t,
    IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation,
    int64_t groups, bool benchmark, bool deterministic, bool allow_tf32)
{
  TensorArg input  { input_t,  "input",  1 },
            weight { weight_t, "weight", 2 };
  CheckedFrom c = "cudnn_convolution";
  auto output_t = cudnn_convolution_forward(
      c, input, weight, padding, stride, dilation, groups, benchmark, deterministic, allow_tf32);
  return output_t;
}

Tensor cudnn_convolution_forward(
    CheckedFrom c,
    const TensorArg& input, const TensorArg& weight,
    IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups,
    bool benchmark, bool deterministic, bool allow_tf32)
{
  checkAllSameType(c, {input, weight});
  checkAllSameGPU(c, {input, weight});

  auto memory_format = cudnn_conv_suggest_memory_format(*input, *weight);
  Tensor output_t = at::detail::empty_cuda(
      conv_output_size(input->sizes(), weight->sizes(),
                       padding, stride, dilation),
      input->options().memory_format(memory_format));

  if (output_t.numel() == 0) {
    return output_t;
  }

  // Avoid ambiguity of "output" when this is being used as backwards
  TensorArg output{ output_t, "result", 0 };
  convolution_shape_check(c, input, weight, output, padding, stride, dilation, groups);

  Tensor weight_contig = weight->contiguous(memory_format);
  Tensor input_contig = input->contiguous(memory_format);

  raw_cudnn_convolution_forward_out(
      *output, input_contig, weight_contig,
      padding, stride, dilation, groups, benchmark, deterministic, allow_tf32);

  return *output;
}
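The output tensor is allocated up front from the input and weight shapes; conv_output_size is the standard convolution shape formula. A Python sketch of that formula (the helper below is illustrative, not the ATen function itself):

def conv_output_size(input_size, weight_size, padding, stride, dilation):
    # [N, C_out, floor((I + 2*P - D*(K-1) - 1) / S) + 1, ...]
    spatial = [
        (i + 2 * p - d * (k - 1) - 1) // s + 1
        for i, k, p, s, d in zip(input_size[2:], weight_size[2:], padding, stride, dilation)
    ]
    return [input_size[0], weight_size[0]] + spatial

print(conv_output_size([2, 3, 8, 8], [4, 3, 3, 3], [0, 0], [1, 1], [1, 1]))  # [2, 4, 6, 6]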
aten/src/ATen/native/cudnn/Conv_v8.cpp
void raw_cudnn_convolution_forward_out(
    const Tensor& output, const Tensor& input, const Tensor& weight,
    const IntArrayRef padding, const IntArrayRef stride, const IntArrayRef dilation, const int64_t groups,
    const bool benchmark, const bool deterministic, const bool allow_tf32)
{
  if (output.numel() == 0) { return; }
  if (at::native::cudnnv8_enabled_check_debug()) {
    run_single_conv(CUDNN_BACKEND_OPERATION_CONVOLUTION_FORWARD_DESCRIPTOR,
        input, output, weight, padding, stride, dilation, groups,
        benchmark, deterministic, allow_tf32);
  } else {
    raw_cudnn_convolution_forward_out_v7(
        output, input, weight,
        padding, stride, dilation, groups,
        benchmark, deterministic, allow_tf32);
  }
}
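Which of the two paths runs depends on the cuDNN version PyTorch was built against plus a runtime check. The cuDNN version itself can be queried from Python (a small sketch; requires a CUDA build of PyTorch):

import torch

if torch.backends.cudnn.is_available():
    print(torch.backends.cudnn.version())  # e.g. 8xxx for a cuDNN 8 build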
aten/src/ATen/native/cudnn/Conv_v7.cpp
void raw_cudnn_convolution_forward_out_v7(
    const Tensor& output, const Tensor& input, const Tensor& weight,
    IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups,
    bool benchmark, bool deterministic, bool allow_tf32) {
  split_batch_dim_to_32bit_out(output, input, weight, padding, stride, dilation, groups, benchmark, deterministic, allow_tf32, 1024 * 1024 * 256, raw_cudnn_convolution_forward_out_32bit);
}

template <typename func_t>
static inline void split_batch_dim_to_32bit_out(
    const at::Tensor& output,
    const at::Tensor& input,
    const at::Tensor& weight,
    IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups,
    bool benchmark, bool deterministic, bool allow_tf32,
    int64_t max_worksize, func_t func_32bit) {
  constexpr int64_t int_max = std::numeric_limits<int>::max();
  const int64_t ni = input.numel();
  const int64_t no = output.numel();
  // Assume the shape of the tensor is (N, C, D1, D2, ...);
  // if N * C * D1 * D2 * ... <= int_max, then there is no need to split at all.
  if (ni <= int_max && no <= int_max) {
    func_32bit(output, input, weight, padding, stride, dilation, groups, benchmark, deterministic, allow_tf32);
    return;
  }
  // Else, if C * D1 * D2 * ... <= int_max, we just need to split across the N dimension.
  //
  // Here we use a simple heuristic to determine the size of each split:
  // we don't max out the 2^31 address space because an allocation that
  // large is very likely to cause an OOM.
  int64_t n = output.size(0);
  int64_t max_inner_size = std::max<int64_t>(ni, no) / n;
  int64_t split_size = std::max<int64_t>(max_worksize / max_inner_size, 1L);
  int64_t num_splits = (n + split_size - 1) / split_size;
  if (split_size * max_inner_size < int_max) {
    for (const auto i : c10::irange(num_splits)) {
      int64_t start = split_size * i;
      int64_t split_size_ = std::min<int64_t>(split_size, n - start);
      Tensor input_ = input.narrow(0, start, split_size_);
      Tensor output_ = output.narrow(0, start, split_size_);
      func_32bit(output_, input_, weight, padding, stride, dilation, groups, benchmark, deterministic, allow_tf32);
    }
    return;
  }
  // If control flow reaches here, even splitting across N is not enough, and things start to become complicated.
  // For example, for conv2d, the following questions need to be considered:
  // - Is the memory layout NCHW or NHWC?
  // - If the conv is NCHW -> NC'H'W', then should we
  //   - split only NC?
  //   - split only N'C'?
  //   - split both?
  // - If the conv is NHWC, then we need to split across H, and we need to be very careful about the
  //   boundary conditions to make sure that the boundary is handled correctly.
  // - If we decide to make these splits, is the memory contiguous? Do we need to copy the memory?
  // Considering the complexity of this issue, it is better not to use cuDNN for this case.
  TORCH_INTERNAL_ASSERT(false, "This case should not be dispatched to cuDNN.");
}
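The splitting logic above can be summarized as: if either tensor has more than 2^31 - 1 elements, chop the batch dimension into chunks whose per-chunk element count stays under max_worksize, and run the 32-bit kernel once per chunk. A Python sketch of that heuristic (plan_splits and its example inputs are hypothetical, for illustration only):

INT_MAX = 2**31 - 1

def plan_splits(ni, no, n, max_worksize=1024 * 1024 * 256):
    # ni/no: element counts of input/output; n: batch size
    if ni <= INT_MAX and no <= INT_MAX:
        return [(0, n)]                                   # no split needed
    max_inner_size = max(ni, no) // n
    split_size = max(max_worksize // max_inner_size, 1)
    if split_size * max_inner_size >= INT_MAX:
        raise RuntimeError("should not be dispatched to cuDNN")
    return [(start, min(split_size, n - start)) for start in range(0, n, split_size)]

# e.g. a (4096, 64, 512, 512) tensor has 4096*64*512*512 > 2**31 elements
print(plan_splits(ni=4096 * 64 * 512 * 512, no=4096 * 64 * 512 * 512, n=4096)[:2])  # [(0, 16), (16, 16)]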
void raw_cudnn_convolution_forward_out_32bit(
    const Tensor& output, const Tensor& input, const Tensor& weight,
    IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups,
    bool benchmark, bool deterministic, bool allow_tf32) {
  auto dataType = getCudnnDataType(input);

  ConvolutionArgs args{ input, output, weight };
  args.handle = getCudnnHandle();
  at::MemoryFormat memory_format = cudnn_conv_suggest_memory_format(input, weight);
  setConvolutionParams(&args.params, input, weight, padding, stride, dilation, groups, deterministic, allow_tf32, memory_format);
  args.idesc.set(input, memory_format);
  args.wdesc.set(weight, memory_format, 0);
  args.odesc.set(output, memory_format);
  args.cdesc.set(dataType, input.dim() - 2, args.params.padding, args.params.stride, args.params.dilation, args.params.groups, args.params.allow_tf32);

  // TODO: when we do legacy group convolution support, we'll repeatedly
  // reinitialize the workspace for each convolution we do. This is
  // wasteful; we'd rather reuse the workspace. OTOH, legacy group
  // convolution support is already pretty slow, so this might not
  // matter. (This applies to raw_cudnn_convolution_backward_input as well.)
  AlgoIterator<cudnnConvolutionFwdAlgoPerf_t>(args, benchmark).try_all(
      [&](const cudnnConvolutionFwdAlgoPerf_t& fwdAlgPerf) {
        Tensor workspace = allocate_workspace(fwdAlgPerf.memory, input);

        // Update convDesc mathType since cuDNN 7.4+ requires both algo + mathType to
        // decide whether to use Tensor Core kernels or not.
        // See Note [behavior of cudnnFind and cudnnGet]
        ASSERT_CORRECT_PRECISION(fwdAlgPerf.mathType);
        AT_CUDNN_CHECK_WITH_SHAPES(cudnnSetConvolutionMathType(args.cdesc.mut_desc(), fwdAlgPerf.mathType), args);

        Constant one(dataType, 1);
        Constant zero(dataType, 0);

        AT_CUDNN_CHECK_WITH_SHAPES(cudnnConvolutionForward(
            args.handle,
            &one, args.idesc.desc(), input.data_ptr(),
            args.wdesc.desc(), weight.data_ptr(),
            args.cdesc.desc(), fwdAlgPerf.algo, workspace.data_ptr(), fwdAlgPerf.memory,
            &zero, args.odesc.desc(), output.data_ptr()),
            args, "Forward algorithm: ", static_cast<int>(fwdAlgPerf.algo), "\n");
      }
  );
}
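try_all hides the benchmark-vs-heuristic choice: it obtains a list of candidate forward algorithms (via cudnnFind when benchmark is true, cudnnGet otherwise) and runs the lambda for each until one succeeds. In Python-flavoured pseudocode the control flow looks roughly like this (a sketch, not the ATen implementation):

def try_all(candidates, run_with_algo):
    errors = []
    for algo in candidates:
        try:
            return run_with_algo(algo)   # allocate workspace + launch the conv
        except RuntimeError as exc:      # e.g. CUDA OOM while allocating the workspace
            errors.append(exc)
    raise RuntimeError(f"no cuDNN forward algorithm succeeded: {errors}")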