torch/ao/nn/quantized/functional.py
def conv2d(input, weight, bias,
           stride=1, padding=0, dilation=1, groups=1,
           padding_mode='zeros',
           scale=1.0, zero_point=0,
           dtype=torch.quint8):
    r"""
    Applies a 2D convolution over a quantized 2D input composed of several input
    planes.

    See :class:`~torch.ao.nn.quantized.Conv2d` for details and output shape.

    Args:
        input: quantized input tensor of shape :math:`(\text{minibatch} , \text{in\_channels} , iH , iW)`
        weight: quantized filters of shape :math:`(\text{out\_channels} , \frac{\text{in\_channels}}{\text{groups}} , kH , kW)`
        bias: **non-quantized** bias tensor of shape :math:`(\text{out\_channels})`. The tensor type must be `torch.float`.
        stride: the stride of the convolving kernel. Can be a single number or a
            tuple `(sH, sW)`. Default: 1
        padding: implicit paddings on both sides of the input. Can be a
            single number or a tuple `(padH, padW)`. Default: 0
        dilation: the spacing between kernel elements. Can be a single number or
            a tuple `(dH, dW)`. Default: 1
        groups: split input into groups, :math:`\text{in\_channels}` should be divisible by the
            number of groups. Default: 1
        padding_mode: the padding mode to use. Only "zeros" is supported for quantized convolution at the moment. Default: "zeros"
        scale: quantization scale for the output. Default: 1.0
        zero_point: quantization zero_point for the output. Default: 0
        dtype: quantization data type to use. Default: ``torch.quint8``
    """
    if padding_mode != 'zeros':
        raise NotImplementedError("Only zero-padding is supported!")
    if input.dtype != torch.quint8:
        raise NotImplementedError("Only torch.quint8 is supported for activation tensor!")
    if weight.dtype != torch.qint8:
        raise NotImplementedError("Only torch.qint8 is supported for weight tensor!")
    if input.ndim != 4:
        raise ValueError("Input shape must be `(N, C, H, W)`!")
    stride = _pair(stride)
    padding = _pair(padding)
    dilation = _pair(dilation)
    packed_params = torch.ops.quantized.conv2d_prepack(
        weight, bias, stride, padding, dilation, groups)
    return torch.ops.quantized.conv2d(input, packed_params, scale, zero_point)
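To see the contract this functional enforces in action, here is a minimal usage sketch (scales and zero points are illustrative placeholders, not calibrated values): quint8 activations, qint8 weights, and a float bias, exactly as the dtype checks above require.

import torch
from torch.ao.nn.quantized import functional as qF

inputs = torch.randn(1, 4, 5, 5, dtype=torch.float)
filters = torch.randn(8, 4, 3, 3, dtype=torch.float)
bias = torch.randn(8, dtype=torch.float)  # bias stays float

q_inputs = torch.quantize_per_tensor(inputs, 0.1, 128, torch.quint8)
q_filters = torch.quantize_per_tensor(filters, 0.05, 0, torch.qint8)

# scale/zero_point parameterize the quantization of the *output*
out = qF.conv2d(q_inputs, q_filters, bias, padding=1, scale=0.2, zero_point=64)

The two ops this function calls, conv2d_prepack and conv2d, dispatch to the cuDNN implementations examined next.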
torch.ops.quantized.conv2d_prepack
aten/src/ATen/native/quantized/cudnn/ConvPrepack.cpp
template <int kSpatialDim>
c10::intrusive_ptr<ConvPackedParamsBase<kSpatialDim>> PackedConvWeightCudnn<
kSpatialDim>::
prepack(
at::Tensor weight,
c10::optional<at::Tensor> bias,
torch::List<int64_t> stride,
torch::List<int64_t> padding,
torch::List<int64_t> output_padding,
torch::List<int64_t> dilation,
int64_t groups,
bool transpose) {
// TODO: need to work out how to implement groups for the conv operator in Conv.cpp
TORCH_CHECK(groups == 1, "Quantized cudnn conv2d is currently limited to groups = 1; received groups = ", groups);
TORCH_CHECK(weight.qscheme() == c10::kPerTensorAffine, "Unsupported qscheme: ", toString(weight.qscheme()));
TORCH_CHECK(
kSpatialDim == 2, // 1D is packed as 2d, hence we don't need other checks
"cuDNN packing only supports 2D convolution.");
TORCH_CHECK(
weight.ndimension() == kSpatialDim + 2,
"Weights are expected to have ",
kSpatialDim + 2,
" dimensions");
TORCH_CHECK(
stride.size() == kSpatialDim,
"stride should contain ",
kSpatialDim,
" elements for ",
kSpatialDim,
"D convolution.");
TORCH_CHECK(
padding.size() == kSpatialDim,
"quantized::conv_prepack (cudnn): Specify front/top/left padding only. "
"end/bottom/right padding assumed to be equal to front/top/left");
TORCH_CHECK(
!transpose || output_padding.size() == kSpatialDim,
"quantized::conv_prepack: Specify top/left output padding "
"only. bottom/right padding assumed to be equal to top/left");
TORCH_CHECK(
dilation.size() == kSpatialDim,
"quantized::conv_prepack (cudnn): dilation should contain ",
kSpatialDim,
" elements for ",
kSpatialDim,
"D convolution.");
TORCH_CHECK(!transpose, "cuDNN quantized conv prepack expects transpose = false");
const int num_unpadded_output_channels = weight.size(0);
const auto qtype = weight.qscheme();
if (bias.has_value()) {
TORCH_CHECK(bias.value().dim() == 1, "bias should be a vector (1D Tensor)");
TORCH_CHECK(
bias.value().size(0) == num_unpadded_output_channels,
"bias should have K elements: " + std::to_string(num_unpadded_output_channels));
// TODO: we create a broadcasted_bias tensor later so I think we don't need to make this contiguous here.
// we will revisit this when nvidia adds proper support for broadcasting
// bias_contig = bias->contiguous();
}
// cudnn v8.4.0 expects conv2d's int8 weight tensor's input and output channels to be multiples of 4. If they are not,
// we need to explicitly pad the weight to a multiple of 4 ourselves, as cudnn does not currently support padding.
// TODO: when and if cudnn enables padding in their operators, we can remove the padding on our end;
// currently, we limit padding support to groups = 1 (ungrouped conv)
// TODO: implement this for groups > 1
auto num_input_channels = weight.size(1);
int8_t num_output_slices2pad = (4 - num_unpadded_output_channels % 4) % 4;
int8_t num_input_slices2pad = (4 - num_input_channels % 4) % 4;
if (num_output_slices2pad != 0 || num_input_slices2pad != 0) {
// the second argument is an initializer list of padded values. there are 2 values for each dimension.
// refer to https://pytorch.org/docs/stable/generated/torch.nn.functional.pad.html for more details
weight = at::pad(weight, {0, 0, 0, 0, 0, num_input_slices2pad, 0, num_output_slices2pad}, "constant", 0);
if (bias.has_value()) {
bias.value() = at::pad(bias.value(), {0, num_output_slices2pad}, "constant", 0);
}
}
auto ret_ptr = c10::make_intrusive<PackedConvWeightCudnn<kSpatialDim>>(
weight.to(c10::MemoryFormat::ChannelsLast), // TODO: this assumes 2D I think. make it more general?
bias,
stride,
padding,
output_padding,
dilation,
groups,
transpose,
qtype,
num_unpadded_output_channels);
return ret_ptr;
}
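The channel-padding arithmetic above is easy to get wrong, so here is the same computation as a small Python sketch (pad_weight_to_multiple_of_4 is a hypothetical helper for illustration, not a PyTorch API; it operates on float tensors for simplicity):

import torch
import torch.nn.functional as F

def pad_weight_to_multiple_of_4(weight, bias=None):
    # mirrors the C++ above: cuDNN's int8 conv wants both in_channels
    # and out_channels to be multiples of 4
    out_c, in_c = weight.size(0), weight.size(1)
    out_pad = (4 - out_c % 4) % 4
    in_pad = (4 - in_c % 4) % 4
    if out_pad or in_pad:
        # the pad spec is read from the last dimension backwards, two values
        # per dimension: (W_lo, W_hi, H_lo, H_hi, inC_lo, inC_hi, outC_lo, outC_hi)
        weight = F.pad(weight, (0, 0, 0, 0, 0, in_pad, 0, out_pad))
        if bias is not None:
            bias = F.pad(bias, (0, out_pad))
    return weight, bias

w, b = pad_weight_to_multiple_of_4(torch.randn(6, 3, 3, 3), torch.randn(6))
assert w.shape == (8, 4, 3, 3) and b.shape == (8,)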
template <int kSpatialDim = 2>
class QConvPackWeightInt8Cudnn final {
public:
static c10::intrusive_ptr<ConvPackedParamsBase<kSpatialDim>> run_conv(
Tensor weight,
c10::optional<Tensor> bias,
torch::List<int64_t> stride,
torch::List<int64_t> padding,
torch::List<int64_t> dilation,
int64_t groups) {
torch::List<int64_t> output_padding;
output_padding.reserve(kSpatialDim);
for (C10_UNUSED const auto idx : c10::irange(kSpatialDim)) {
output_padding.push_back((int64_t)0);
}
return _run(weight, bias, stride, padding, output_padding, dilation, groups,
/*transpose=*/false);
}
private:
static c10::intrusive_ptr<ConvPackedParamsBase<kSpatialDim>> _run(
Tensor weight,
c10::optional<Tensor> bias,
torch::List<int64_t> stride,
torch::List<int64_t> padding,
torch::List<int64_t> output_padding,
torch::List<int64_t> dilation,
int64_t groups,
bool transpose) {
return PackedConvWeightCudnn<kSpatialDim>::prepack(
weight, bias, stride, padding, output_padding, dilation, groups,
transpose);
}
};
TORCH_LIBRARY_IMPL(quantized, QuantizedCUDA, m) {
m.impl(TORCH_SELECTIVE_NAME("quantized::conv2d_prepack"), TORCH_FN(QConvPackWeightInt8Cudnn<2>::run_conv));
}
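With this registration in place, the cuDNN prepack is reachable from Python through the raw op. A sketch, assuming a CUDA build with the (experimental) quantized cuDNN backend enabled; per the TORCH_CHECKs above, the weight must be a 4-D per-tensor-affine qint8 tensor and groups must be 1:

import torch

w = torch.quantize_per_tensor(
    torch.randn(8, 4, 3, 3, device="cuda"), 0.05, 0, torch.qint8)
b = torch.randn(8, device="cuda")  # float bias, K == out_channels

# (weight, bias, stride, padding, dilation, groups)
packed = torch.ops.quantized.conv2d_prepack(w, b, [1, 1], [0, 0], [1, 1], 1)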
torch.ops.quantized.conv2d
aten/src/ATen/native/quantized/cudnn/Conv.cpp
namespace {
struct CacheKey {
at::native::ConvolutionParams params;
uint8_t input_alignment;
uint8_t weight_alignment;
uint8_t output_alignment;
// default to -1 when no bias
int8_t bias_alignment;
bool kReluFused;
};
std::unordered_map<CacheKey, cudnn_frontend::ExecutionPlan, at::native::ParamsHash<CacheKey>, at::native::ParamsEqual<CacheKey>> execution_plan_cache;
} // anonymous namespace
// TODO: we can use cudnn_frontend::ExecutionPlanCache when it supports caching
// multiple operators
// reference: https://github.com/NVIDIA/cudnn-frontend/blob/main/samples/conv_sample.cpp#L293
//static cudnn_frontend::ExecutionPlanCache plan_cache("sample_cache");
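A note on the byte-wise hashing: ParamsHash and ParamsEqual operate on the raw bytes of CacheKey, so compiler-inserted padding participates in the hash; this is why apply_impl_helper below memsets the key before filling it in. A Python sketch of where such padding lives (a hypothetical two-field layout; ctypes zero-fills the padding, whereas a C++ stack object would leave it uninitialized):

import ctypes

class Key(ctypes.Structure):
    # an 8-byte field followed by a 1-byte field: the compiler pads the
    # struct to 16 bytes, and the 7 trailing bytes get hashed too
    _fields_ = [("params", ctypes.c_int64),
                ("alignment", ctypes.c_uint8)]

print(ctypes.sizeof(Key))      # 16, not 9
print(bytes(Key(1, 2)).hex())  # the padding bytes are visible at the end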
// the parameter quantized_output is a quantized tensor
template <int kSpatialDim>
template <bool kReluFused>
void PackedConvWeightCudnn<kSpatialDim>::apply_impl_helper(const at::Tensor& quantized_output, const at::Tensor& input, double output_scale) {
auto act_scale = input.q_scale();
auto weight_scale = maybe_padded_weight_.q_scale();
auto requantize_multiplier = act_scale * weight_scale / output_scale;
at::Tensor requantize_multiplier_tensor = cudnn_utils::getRequantMultiplierTensor(requantize_multiplier, kSpatialDim + 2);
c10::optional<at::Tensor> bias_multiplier_tensor;
c10::optional<at::Tensor> broadcasted_bias;
if (bias_.has_value()) {
// the input bias is a 1-D tensor whose size is the same as the size of the second dimension of quantized_output.
// we need to add trailing dimensions in order to properly broadcast bias, otherwise broadcast_to will fail.
// the number of trailing dimensions is quantized_output.dim() - 2, so the new size of broadcasted_bias
// becomes quantized_output.dim() - 2 + 1 = quantized_output.dim() - 1. nothing needs to be done for the leading dimensions
std::vector<int64_t> new_size(quantized_output.dim() - 1, 1);
new_size[0] = bias_.value().size(0);
broadcasted_bias = bias_.value().reshape(new_size);
broadcasted_bias.value() = broadcasted_bias.value().broadcast_to(quantized_output.sizes());
broadcasted_bias.value() = broadcasted_bias.value().to(c10::MemoryFormat::ChannelsLast);
bias_multiplier_tensor = at::empty(quantized_output.sizes(), at::device(at::kCUDA).dtype(at::kFloat), at::MemoryFormat::ChannelsLast);
auto bias_multiplier = 1.0 / (act_scale * weight_scale);
bias_multiplier_tensor.value().fill_(bias_multiplier);
}
cudnnHandle_t handle = at::native::getCudnnHandle();
CacheKey key;
// memset is needed here because there is implicit packing added for CacheKey, and this can result in uninitialized
// padded values that are used for hashing (see how at::native::ParamsHash is defined). without memset, we can
// potentially come across a situation where two CacheKey objects have the same user-defined parameters, but
// different padded values, resulting in different hash outputs.
memset(&key, 0, sizeof(key));
bool deterministic{true};
bool allow_tf32{false};
auto padding_vec = padding_.vec();
auto stride_vec = stride_.vec();
auto dilation_vec = dilation_.vec();
setConvolutionParams(&key.params, input, maybe_padded_weight_, padding_vec, stride_vec, dilation_vec, groups_, deterministic, allow_tf32, input.suggest_memory_format());
// operator datatype needs to be int32 for int8 convolution, but we can
// set the datatype for output tensor to int32 or fp32
key.params.dataType = CUDNN_DATA_INT32;
key.input_alignment = cudnn_utils::getAlignment(input);
key.output_alignment = cudnn_utils::getAlignment(quantized_output);
key.weight_alignment = cudnn_utils::getAlignment(maybe_padded_weight_);
if (bias_.has_value()) {
key.bias_alignment = cudnn_utils::getAlignment(broadcasted_bias.value());
} else {
key.bias_alignment = -1;
}
key.kReluFused = kReluFused;
auto run = [&](const cudnn_frontend::ExecutionPlan& plan_desc) {
auto workspace_size = plan_desc.getWorkspaceSize();
auto workspace_ptr = c10::cuda::CUDACachingAllocator::get()->allocate(workspace_size);
at::SmallVector<void *, 7> data_ptrs;
at::SmallVector<int64_t, 7> uids;
data_ptrs = {input.data_ptr<int8_t>(), maybe_padded_weight_.data_ptr<int8_t>(),
requantize_multiplier_tensor.data_ptr(), quantized_output.data_ptr<int8_t>()};
uids = {'x', 'w', 's', 'r'};
if (bias_.has_value()) {
data_ptrs.insert(data_ptrs.end(), {broadcasted_bias.value().data_ptr(), bias_multiplier_tensor.value().data_ptr(),
broadcasted_bias.value().data_ptr()});
uids.insert(uids.end(), {'b', 'c', 'd'});
}
auto variantPack = cudnn_frontend::VariantPackBuilder()
.setWorkspacePointer(workspace_size ? workspace_ptr.get() : nullptr)
.setDataPointers(uids.size(), data_ptrs.data())
.setUids(uids.size(), uids.data())
.build();
auto variant_pack_desc = variantPack.get_raw_desc();
AT_CUDNN_CHECK(cudnnBackendExecute(handle, plan_desc.get_raw_desc(), variant_pack_desc));
};
auto search = execution_plan_cache.find(key);
if (search != execution_plan_cache.end()) {
cudnn_frontend::ExecutionPlan plan_desc = search->second;
run(plan_desc);
return;
}
// conv_op computes act_int8 * w_int8 (convolution)
// where act_int8 and w_int8 are the input and weight variables, resp.
// output is a (virtual) fp32 tensor
auto conv_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_CONVOLUTION_FORWARD_DESCRIPTOR)
.setxDesc(cudnn_utils::getTensorDescriptor(input.sizes(), input.strides(), CUDNN_DATA_INT8, 'x', key.input_alignment))
// for virtual tensors, the alignment is not used, so we can just put an arbitrary value here, e.g., key.output_alignment
.setyDesc(cudnn_utils::getTensorDescriptor(quantized_output.sizes(), quantized_output.strides(), CUDNN_DATA_FLOAT, 'y', key.output_alignment, true))
.setwDesc(cudnn_utils::getTensorDescriptor(maybe_padded_weight_.sizes(), maybe_padded_weight_.strides(), CUDNN_DATA_INT8, 'w', key.weight_alignment))
.setcDesc(getConvDescriptor(key.params.dataType, padding_vec, stride_vec, dilation_vec))
.build();
// std::cout << "operator:" << conv_op.describe() << std::endl;
c10::optional<cudnn_frontend::Operation> bias_mult_op;
c10::optional<cudnn_frontend::Operation> sum_conv_bias_op;
if (bias_.has_value()) {
// we can't directly assign bias_mult_op because operator= is deleted for cudnn_frontend::Operation;
// alternatively, I think we can use std::unique_ptr and dynamically allocate these builder ops
// but here, we chose to do it statically. c10::optional<T>::emplace() enables this approach
// bias_mult_op computes bias_fp32 / (act_scale * w_scale) or bias_fp32 * (1 / (act_scale * w_scale))
// where bias_multiplier = (1 / (act_scale * w_scale))
// output is a fp32 tensor
// we use an in-place operation here, where the output is assigned to the input
bias_mult_op.emplace(cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR)
.setxDesc(cudnn_utils::getTensorDescriptor(broadcasted_bias.value(), 'b', cudnn_utils::getAlignment(broadcasted_bias.value())))
.setbDesc(cudnn_utils::getTensorDescriptor(bias_multiplier_tensor.value(), 'c', cudnn_utils::getAlignment(bias_multiplier_tensor.value())))
.setyDesc(cudnn_utils::getTensorDescriptor(broadcasted_bias.value(), 'd', cudnn_utils::getAlignment(broadcasted_bias.value())))
.setpwDesc(cudnn_utils::getPointWiseMulDescriptor(at::native::getCudnnDataType(bias_multiplier_tensor.value())))
.build());
// computes (act_int8 * w_int8 + [bias_fp32/(act_scale * w_scale)])
// where the 1st summand is the output of conv_op and the 2nd is broadcasted_bias (rescaled in place by bias_mult_op above)
// output is a fp32 tensor
// we use an in-place operation here, where the output is assigned to the input
sum_conv_bias_op.emplace(cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR)
.setxDesc(conv_op.getOutputTensor())
.setbDesc(cudnn_utils::getTensorDescriptor(broadcasted_bias.value(), 'd', cudnn_utils::getAlignment(broadcasted_bias.value())))
// for virtual tensors, the alignment is not used, so we can just put an arbitrary value here, e.g., key.output_alignment
.setyDesc(cudnn_utils::getTensorDescriptor(quantized_output.sizes(), quantized_output.strides(), CUDNN_DATA_FLOAT, 'e', key.output_alignment, true))
.setpwDesc(cudnn_utils::getPointWiseAddDescriptor(at::native::getCudnnDataType(broadcasted_bias.value())))
.build());
}
// relu_op computes relu(act_int8 * w_int8 + [bias_fp32/(act_scale * w_scale)])
// or relu(act_int8 * w_int8) if bias is not present.
// output is a fp32 tensor
c10::optional<cudnn_frontend::Operation> relu_op;
std::shared_ptr<cudnn_frontend::OpaqueBackendPointer> tensor2requant_ptr = bias_.has_value() ? sum_conv_bias_op.value().getOutputTensor() : conv_op.getOutputTensor();
if (kReluFused) {
// we use an in-place operation here, where the output is assigned to the input
relu_op.emplace(cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR)
.setxDesc(tensor2requant_ptr)
// for virtual tensors, the alignment is not used, so we can just put an arbitrary value here, e.g., key.output_alignment
.setyDesc(cudnn_utils::getTensorDescriptor(quantized_output.sizes(), quantized_output.strides(), CUDNN_DATA_FLOAT, 'f', key.output_alignment, true))
.setpwDesc(cudnn_utils::getPointWiseReluDescriptor(CUDNN_DATA_FLOAT))
.build());
}
// requant_op computes relu(act_int8 * w_int8 + [bias_fp32/(act_scale * w_scale)]) / (out_scale / (act_scale * w_scale))
// or relu(act_int8 * w_int8) / (out_scale / (act_scale * w_scale)) if bias is not present
// (the relu is applied only when kReluFused is true). output is the final int8 tensor
auto requant_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR)
.setxDesc(kReluFused ? relu_op.value().getOutputTensor() : tensor2requant_ptr)
.setbDesc(cudnn_utils::getTensorDescriptor(requantize_multiplier_tensor, 's', cudnn_utils::getAlignment(requantize_multiplier_tensor)))
.setyDesc(cudnn_utils::getTensorDescriptor(quantized_output.sizes(), quantized_output.strides(), CUDNN_DATA_INT8, 'r', key.output_alignment))
.setpwDesc(cudnn_utils::getPointWiseMulDescriptor(at::native::getCudnnDataType(requantize_multiplier_tensor)))
.build();
// std::cout << "operator:" << requant_op.describe() << std::endl;
std::vector<cudnn_frontend::Operation const *> ops{&conv_op};
if (bias_.has_value()) {
ops.emplace_back(&(bias_mult_op.value()));
ops.emplace_back(&(sum_conv_bias_op.value()));
}
if (kReluFused) {
ops.emplace_back(&(relu_op.value()));
}
ops.emplace_back(&requant_op);
auto opGraph = cudnn_frontend::OperationGraphBuilder()
.setHandle(handle)
.setOperationGraph(ops.size(), ops.data())
.build();
// std::cout << "opGraph: " << opGraph.describe() << std::endl;
auto heuristics = cudnn_frontend::EngineHeuristicsBuilder()
.setOperationGraph(opGraph)
.setHeurMode(CUDNN_HEUR_MODE_INSTANT)
.build();
auto fallback = cudnn_frontend::EngineFallbackListBuilder()
.setOperationGraph(opGraph)
.setOperation(CUDNN_BACKEND_OPERATION_CONVOLUTION_FORWARD_DESCRIPTOR)
.build();
auto& engine_configs = heuristics.getEngineConfig(heuristics.getEngineConfigCount());
auto& fallback_list = fallback.getFallbackList();
cudnn_frontend::EngineConfigList filtered_configs;
cudnn_utils::filterEngineConfigs(engine_configs, filtered_configs, deterministic, allow_tf32, at::kChar);
cudnn_utils::filterEngineConfigs(fallback_list, filtered_configs, deterministic, allow_tf32, at::kChar);
for (auto &cfg : filtered_configs) {
try {
auto plan = cudnn_frontend::ExecutionPlanBuilder()
.setHandle(handle)
.setEngineConfig(cfg)
.build();
run(plan);
execution_plan_cache.emplace(key, plan);
return;
} catch (cudnn_frontend::cudnnException &e) {
  std::cout << "cudnn error:" << e.what() << std::endl;
} catch (c10::CuDNNError &e) {
  std::cout << "other error" << e.what() << std::endl;
}
}
TORCH_CHECK(false, "Unable to find an engine to execute this computation in Quantized Conv2D Cudnn");
}
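The broadcasted_bias setup above (reshape to add trailing singleton dimensions, broadcast, then materialize as channels-last) can be mirrored in plain PyTorch. A minimal sketch with illustrative shapes:

import torch

out = torch.empty(2, 8, 5, 5)     # stands in for quantized_output (NCHW)
bias = torch.randn(8)             # 1-D, size == out.shape[1]

new_size = [1] * (out.dim() - 1)  # quantized_output.dim() - 1 entries
new_size[0] = bias.size(0)        # -> [8, 1, 1]
b = bias.reshape(new_size)        # trailing singletons make broadcasting legal
b = b.broadcast_to(out.shape).contiguous(memory_format=torch.channels_last)
assert b.shape == out.shape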
//
// output Tensor will be a clamped int8 Tensor
// both act and weight will be int8 Tensors
/*
Numerics:
out_fp32 = conv_fp32(act_fp32, w_fp32, …)
= act_fp32 * w_fp32 + bias_fp32
act_int8 = act_fp32 / act_scale + act_zero_point
w_int8 = w_fp32 / w_scale + w_zero_point
out_int8 = out_fp32 / out_scale + out_zero_point
out_int8 = (act_fp32 * w_fp32 + [bias_fp32]) / out_scale + out_zero_point
= (act_int8 - act_zero_point) * act_scale * (w_int8 - w_zero_point) * w_scale / out_scale + out_zero_point + [bias_fp32 / out_scale]
= (act_int8 * w_int8 - act_int8 * w_zero_point - act_zero_point * w_int8 + act_zero_point * w_zero_point) * act_scale * w_scale / out_scale + out_zero_point + [bias_fp32 / out_scale]
= (if both act and weight are symmetrically quantized int8, then act_zero_point = w_zero_point = 0; the output is likewise assumed symmetric, so out_zero_point = 0)
= (act_int8 * w_int8 + [bias_fp32/(act_scale * w_scale)]) * act_scale * w_scale / out_scale
= (act_int8 * w_int8 + [bias_fp32/(act_scale * w_scale)]) / (out_scale / (act_scale * w_scale))
= requantize((act_int8 * w_int8 + [bias_fp32/(act_scale * w_scale)]), out_scale / (act_scale * w_scale))
*/
template <int kSpatialDim>
template <bool kReluFused>
at::Tensor PackedConvWeightCudnn<kSpatialDim>::apply_impl(
const at::Tensor& act,
double output_scale,
int64_t output_zero_point) {
const auto batch_size = kSpatialDim == 2 ? act.size(0) : 1;
const auto num_input_channels = act.size(kSpatialDim - 1);
const auto H = act.size(kSpatialDim);
const auto W = act.size(kSpatialDim + 1);
const auto num_output_channels = maybe_padded_weight_.size(0); // output channels
std::vector<int64_t> kernel_size = {maybe_padded_weight_.size(2), maybe_padded_weight_.size(3)};
auto output_shape = at::native::quantized::MakeConvOutputShape<kSpatialDim>(batch_size, num_output_channels, {H, W},
kernel_size, stride_, padding_, dilation_);
at::Tensor quantized_output = at::_empty_affine_quantized(
output_shape,
at::device(at::kCUDA).dtype(at::ScalarType::QInt8),
output_scale,
output_zero_point,
at::MemoryFormat::ChannelsLast);
// cudnn v8.4.0 expects conv2d's int8 activation tensor's input channels to be a multiple of 4. If it is not,
// we need to explicitly pad it to a multiple of 4 ourselves, as cudnn does not currently support padding.
// TODO: when and if cudnn enables padding in their operators, we can remove the padding on our end;
// currently, we limit padding support to groups = 1 (ungrouped conv)
// TODO: implement this for groups > 1; should be straightforward since we're only padding a single dimension
auto act_maybe_padded = act;
if (num_input_channels % 4 != 0) {
int8_t num_slices = 4 - num_input_channels % 4; // number of slices we need to pad
act_maybe_padded = at::pad(act, {0, 0, 0, 0, 0, num_slices, 0, 0}, "constant", 0);
}
apply_impl_helper<kReluFused>(
quantized_output, act_maybe_padded.to(c10::MemoryFormat::ChannelsLast), output_scale);
// need to return sliced tensor if output_channels was padded
if (num_unpadded_output_channels_ != maybe_padded_weight_.size(0)) {
return quantized_output.slice(1, 0, num_unpadded_output_channels_);
}
return quantized_output;
}
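The requantization identity from the Numerics comment can be sanity-checked with scalar arithmetic; a sketch with illustrative scales, the "convolution" reduced to a dot product and all zero points taken as 0:

import torch

act_scale, w_scale, out_scale = 0.1, 0.05, 0.25
act_int8 = torch.randint(-128, 128, (64,)).double()
w_int8 = torch.randint(-128, 128, (64,)).double()
bias_fp32 = 0.37

# reference: dequantize, accumulate in floating point, add bias, quantize
ref = (((act_int8 * act_scale) * (w_int8 * w_scale)).sum() + bias_fp32) / out_scale

# cuDNN-style: integer accumulation, bias pre-scaled by 1/(act_scale*w_scale),
# then a single requantize multiply at the end
acc = (act_int8 * w_int8).sum() + bias_fp32 / (act_scale * w_scale)
req = acc * (act_scale * w_scale / out_scale)

assert torch.isclose(ref, req)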
template <int kSpatialDim>
at::Tensor PackedConvWeightCudnn<kSpatialDim>::apply(
const at::Tensor& input,
double output_scale,
int64_t output_zero_point) {
return apply_impl<false>(input, output_scale, output_zero_point);
}
template <int kSpatialDim>
at::Tensor PackedConvWeightCudnn<kSpatialDim>::apply_relu(
const at::Tensor& input,
double output_scale,
int64_t output_zero_point) {
return apply_impl<true>(input, output_scale, output_zero_point);
}
template <int kSpatialDim, bool kReluFused>
class QConvInt8 final {
public:
static at::Tensor run(
at::Tensor act,
const c10::intrusive_ptr<ConvPackedParamsBase<kSpatialDim>>& packed_weight,
double output_scale,
int64_t output_zero_point) {
TORCH_CHECK(kSpatialDim == 1 || kSpatialDim == 2, "Error in quantized cudnn conv2d operator: "
"Expected kSpatialDim == 1 || kSpatialDim == 2; received kSpatialDim=", kSpatialDim);
// TODO: check all zero_points are zero/all tensors are symmetrically quantized
if (kReluFused) {
return packed_weight->apply_relu(act, output_scale, output_zero_point);
} else {
return packed_weight->apply(act, output_scale, output_zero_point);
}
}
};
TORCH_LIBRARY_IMPL(quantized, QuantizedCUDA, m) {
m.impl(TORCH_SELECTIVE_NAME("quantized::conv2d.new"), QConvInt8<2, false>::run);
m.impl(TORCH_SELECTIVE_NAME("quantized::conv2d_relu.new"), QConvInt8<2, true>::run);
}
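Putting the two registrations together, the whole CUDA path can be exercised from Python with the raw ops. A sketch (experimental backend, illustrative scales); note the qint8 activation here, in line with the int8 data pointers above, unlike the quint8 the CPU functional checks for:

import torch

x = torch.quantize_per_tensor(
    torch.randn(1, 4, 5, 5, device="cuda"), 0.1, 0, torch.qint8)
w = torch.quantize_per_tensor(
    torch.randn(8, 4, 3, 3, device="cuda"), 0.05, 0, torch.qint8)
b = torch.randn(8, device="cuda")

packed = torch.ops.quantized.conv2d_prepack(w, b, [1, 1], [0, 0], [1, 1], 1)
y = torch.ops.quantized.conv2d(x, packed, 0.2, 0)            # conv2d.new
y_relu = torch.ops.quantized.conv2d_relu(x, packed, 0.2, 0)  # conv2d_relu.new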