
Source code for torch.nn.quantized.modules.conv

# coding=utf-8
r"""Quantized convolution modules."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals

from typing import Optional, List

import torch
import torch.nn as nn
import torch.nn.intrinsic as nni
import torch.nn.intrinsic.qat as nniqat

from torch._ops import ops
from torch.nn.modules.utils import _single, _pair, _triple
from torch.nn.quantized.modules.utils import _pair_from_first
from torch.nn.quantized.modules.utils import _quantize_weight
from torch.nn.utils import fuse_conv_bn_weights

class _ConvNd(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, stride,
                 padding, dilation,
                 transposed, output_padding,
                 groups, bias,
                 padding_mode='zeros'):
        super(_ConvNd, self).__init__()
        if padding_mode != 'zeros':
            raise NotImplementedError(
                "Currently only zero-padding is supported by quantized conv")
        if in_channels % groups != 0:
            raise ValueError('in_channels must be divisible by groups')
        if out_channels % groups != 0:
            raise ValueError('out_channels must be divisible by groups')
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding
        self.dilation = dilation
        self.transposed = transposed
        self.output_padding = output_padding
        self.groups = groups
        self.padding_mode = padding_mode
        # Initialize as NCHW. set_weight will internally transpose to NHWC.
        qweight = torch._empty_affine_quantized(
            [out_channels, in_channels // self.groups] + list(kernel_size),
            scale=1, zero_point=0, dtype=torch.qint8)
        bias_float = (
            torch.zeros(out_channels, dtype=torch.float) if bias else None)

        self.set_weight_bias(qweight, bias_float)
        self.scale = 1.0
        self.zero_point = 0

    def extra_repr(self):
        s = ('{in_channels}, {out_channels}, kernel_size={kernel_size}'
             ', stride={stride}, scale={scale}, zero_point={zero_point}')
        if self.padding != (0,) * len(self.padding):
            s += ', padding={padding}'
        if self.dilation != (1,) * len(self.dilation):
            s += ', dilation={dilation}'
        if self.output_padding != (0,) * len(self.output_padding):
            s += ', output_padding={output_padding}'
        if self.groups != 1:
            s += ', groups={groups}'
        if self.bias() is None:
            s += ', bias=False'
        return s.format(**self.__dict__)

    # ===== Serialization methods =====
    # The special consideration here is that we have to unpack the weights into
    # their regular QTensor form for serialization. Packed weights should not
    # live outside the process in which they were created, rather they should be
    # derived from the QTensor weight.
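    #
    # For illustration (a rough sketch; the prefix 'conv.' is hypothetical),
    # saving one of these modules produces state_dict entries along the lines of:
    #   conv.weight     -> unpacked quantized weight tensor (qint8)
    #   conv.bias       -> float bias tensor (or None)
    #   conv.scale      -> scalar tensor holding the output scale
    #   conv.zero_point -> scalar tensor holding the output zero point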
    def _save_to_state_dict(self, destination, prefix, keep_vars):
        super(_ConvNd, self)._save_to_state_dict(destination, prefix, keep_vars)
        (w, b) = self._weight_bias()
        destination[prefix + 'weight'] = w
        destination[prefix + 'scale'] = torch.tensor(self.scale)
        destination[prefix + 'zero_point'] = torch.tensor(self.zero_point)
        destination[prefix + 'bias'] = b

    @torch.jit.export
    def __getstate__(self):
        if not torch.jit.is_scripting():
            raise RuntimeError(
                'torch.save() is not currently supported for quantized modules.'
                ' See https://github.com/pytorch/pytorch/issues/24045.'
                ' Please use state_dict or torch.jit serialization.')
        (w, b) = self._weight_bias()
        return (
            self.in_channels,
            self.out_channels,
            self.kernel_size,
            self.stride,
            self.padding,
            self.dilation,
            self.transposed,
            self.output_padding,
            self.groups,
            self.padding_mode,
            w,
            b,
            self.scale,
            self.zero_point,
            self.training
        )

    # ===== Deserialization methods =====
    # Counterpart to the serialization methods, we must pack the serialized
    # QTensor weight into its packed format for use by the FBGEMM ops.
    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
                              missing_keys, unexpected_keys, error_msgs):
        self.set_weight_bias(
            state_dict[prefix + 'weight'], state_dict[prefix + 'bias'])
        state_dict.pop(prefix + 'weight')
        state_dict.pop(prefix + 'bias')
        self.scale = float(state_dict[prefix + 'scale'])
        state_dict.pop(prefix + 'scale')
        self.zero_point = int(state_dict[prefix + 'zero_point'])
        state_dict.pop(prefix + 'zero_point')
        super(_ConvNd, self)._load_from_state_dict(
            state_dict, prefix, local_metadata, False, missing_keys,
            unexpected_keys, error_msgs)
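
    # A rough round-trip sketch via state_dict (module names are hypothetical):
    #   sd = qconv.state_dict()                       # packed weight is unpacked on save
    #   fresh = Conv2d(qconv.in_channels, qconv.out_channels, qconv.kernel_size)
    #   fresh.load_state_dict(sd)                     # weight is re-packed via set_weight_bias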

    @torch.jit.export
    def __setstate__(self, state):
        self.in_channels = state[0]
        self.out_channels = state[1]
        self.kernel_size = state[2]
        self.stride = state[3]
        self.padding = state[4]
        self.dilation = state[5]
        self.transposed = state[6]
        self.output_padding = state[7]
        self.groups = state[8]
        self.padding_mode = state[9]
        self.set_weight_bias(state[10], state[11])
        self.scale = state[12]
        self.zero_point = state[13]
        self.training = state[14]


class Conv1d(_ConvNd):
    r"""Applies a 1D convolution over a quantized input signal composed of
    several quantized input planes.

    For details on input arguments, parameters, and implementation see
    :class:`~torch.nn.Conv1d`.

    .. note::
        Only `zeros` is supported for the :attr:`padding_mode` argument.

    .. note::
        Only `torch.quint8` is supported for the input data type.


    Attributes:
        weight (Tensor):     packed tensor derived from the learnable weight
                             parameter.
        scale (Tensor):      scalar for the output scale
        zero_point (Tensor): scalar for the output zero point

    See :class:`~torch.nn.Conv1d` for other attributes.

    Examples::

        >>> m = nn.quantized.Conv1d(16, 33, 3, stride=2)
        >>> input = torch.randn(20, 16, 100)
        >>> # quantize input to quint8
        >>> q_input = torch.quantize_per_tensor(input, scale=1.0, zero_point=0, dtype=torch.quint8)
        >>> output = m(q_input)

    """

    _FLOAT_MODULE = nn.Conv1d

    # We are using Conv2d to run the Conv1d. For that we need to know which
    # dimension is squeezed/unsqueezed.
    _SQUEEZE_DIM = -2  # -2 is faster than -1.
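    # For example (illustrative shapes): with _SQUEEZE_DIM = -2, a Conv1d weight
    # of shape (out_channels, in_channels, k) is packed as
    # (out_channels, in_channels, 1, k), and a (N, C, L) input is run through
    # the quantized conv2d op as (N, C, 1, L) before being squeezed back.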

    def __init__(self, in_channels, out_channels, kernel_size, stride=1,
                 padding=0, dilation=1, groups=1, bias=True,
                 padding_mode='zeros'):
        kernel_size = list(_pair_from_first(kernel_size))
        kernel_size[self._SQUEEZE_DIM] = 1
        kernel_size = tuple(kernel_size)

        stride = list(_pair_from_first(stride))
        stride[self._SQUEEZE_DIM] = 1
        stride = tuple(stride)

        padding = list(_pair_from_first(padding))
        padding[self._SQUEEZE_DIM] = 0
        padding = tuple(padding)

        dilation = list(_pair_from_first(dilation))
        dilation[self._SQUEEZE_DIM] = 1
        dilation = tuple(dilation)

        super(Conv1d, self).__init__(
            in_channels, out_channels, kernel_size, stride, padding, dilation,
            False, _single(0), groups, bias, padding_mode)

    def _get_name(self):
        return 'QuantizedConv1d'

    def set_weight_bias(self, w, b):
        # type: (torch.Tensor, Optional[torch.Tensor]) -> None
        # Check if need to unsqueeze
        if len(w.shape) == 3:
            w = w.unsqueeze(dim=self._SQUEEZE_DIM)
        self._packed_params = torch.ops.quantized.conv2d_prepack(
            w, b, self.stride, self.padding, self.dilation, self.groups)

    def _weight_bias(self):
        w, b = torch.ops.quantized.conv2d_unpack(self._packed_params)
        w = w.squeeze(dim=self._SQUEEZE_DIM)
        return w, b

    def weight(self):
        return self._weight_bias()[0]

    def bias(self):
        return self._weight_bias()[1]

    def forward(self, input):
        # Temporarily using len(shape) instead of ndim due to JIT issue
        # https://github.com/pytorch/pytorch/issues/23890
        if len(input.shape) != 3:
            raise ValueError("Input shape must be `(N, C, L)`!")
        input = input.unsqueeze(dim=self._SQUEEZE_DIM)
        output = ops.quantized.conv2d(
            input, self._packed_params, self.stride, self.padding,
            self.dilation, self.groups, self.scale, self.zero_point)
        return output.squeeze(dim=self._SQUEEZE_DIM)

    @classmethod
    def from_float(cls, mod):
        r"""Creates a quantized module from a float module or qparams_dict.

        Args:
            mod (Module): a float module, either produced by torch.quantization
              utilities or provided by the user
        """
        assert type(mod) == cls._FLOAT_MODULE, \
            ' nnq.' + cls.__name__ + '.from_float only works for ' + \
            cls._FLOAT_MODULE.__name__
        assert hasattr(mod, 'qconfig'), \
            'Input float module must have qconfig defined.'
        weight_post_process = mod.qconfig.weight()
        weight_post_process(mod.weight)
        act_scale, act_zp = mod.activation_post_process.calculate_qparams()
        assert weight_post_process.dtype == torch.qint8, \
            'Weight observer must have a dtype of qint8'
        qweight = _quantize_weight(mod.weight.float(), weight_post_process)
        qconv = cls(mod.in_channels, mod.out_channels, mod.kernel_size,
                    mod.stride, mod.padding, mod.dilation, mod.groups,
                    mod.bias is not None, mod.padding_mode)
        qconv.set_weight_bias(qweight, mod.bias)
        qconv.scale = float(act_scale)
        qconv.zero_point = int(act_zp)

        return qconv
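
    # A minimal conversion sketch (hypothetical set-up; from_float is normally
    # invoked for you by torch.quantization.convert):
    #   model = torch.nn.Sequential(torch.nn.Conv1d(16, 33, 3))
    #   model.qconfig = torch.quantization.default_qconfig
    #   torch.quantization.prepare(model, inplace=True)   # attaches observers
    #   model(torch.randn(20, 16, 100))                   # calibrate
    #   qconv = Conv1d.from_float(model[0])               # build the quantized module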


class Conv2d(_ConvNd):
    r"""Applies a 2D convolution over a quantized input signal composed of
    several quantized input planes.

    For details on input arguments, parameters, and implementation see
    :class:`~torch.nn.Conv2d`.

    .. note::
        Only `zeros` is supported for the :attr:`padding_mode` argument.

    .. note::
        Only `torch.quint8` is supported for the input data type.


    Attributes:
        weight (Tensor):     packed tensor derived from the learnable weight
                             parameter.
        scale (Tensor):      scalar for the output scale
        zero_point (Tensor): scalar for the output zero point

    See :class:`~torch.nn.Conv2d` for other attributes.

    Examples::

        >>> # With square kernels and equal stride
        >>> m = nn.quantized.Conv2d(16, 33, 3, stride=2)
        >>> # non-square kernels and unequal stride and with padding
        >>> m = nn.quantized.Conv2d(16, 33, (3, 5), stride=(2, 1), padding=(4, 2))
        >>> # non-square kernels and unequal stride and with padding and dilation
        >>> m = nn.quantized.Conv2d(16, 33, (3, 5), stride=(2, 1), padding=(4, 2), dilation=(3, 1))
        >>> input = torch.randn(20, 16, 50, 100)
        >>> # quantize input to quint8
        >>> q_input = torch.quantize_per_tensor(input, scale=1.0, zero_point=0, dtype=torch.quint8)
        >>> output = m(q_input)

    """

    _FLOAT_MODULE = nn.Conv2d

    def __init__(self, in_channels, out_channels, kernel_size, stride=1,
                 padding=0, dilation=1, groups=1, bias=True,
                 padding_mode='zeros'):
        kernel_size = _pair(kernel_size)
        stride = _pair(stride)
        padding = _pair(padding)
        dilation = _pair(dilation)
        super(Conv2d, self).__init__(
            in_channels, out_channels, kernel_size, stride, padding, dilation,
            False, _pair(0), groups, bias, padding_mode)

    def _get_name(self):
        return 'QuantizedConv2d'

    def set_weight_bias(self, w, b):
        # type: (torch.Tensor, Optional[torch.Tensor]) -> None
        self._packed_params = torch.ops.quantized.conv2d_prepack(
            w, b, self.stride, self.padding, self.dilation, self.groups)

    def _weight_bias(self):
        return torch.ops.quantized.conv2d_unpack(self._packed_params)

    def weight(self):
        (w, _) = torch.ops.quantized.conv2d_unpack(self._packed_params)
        return w

    def bias(self):
        (_, b) = torch.ops.quantized.conv2d_unpack(self._packed_params)
        return b

    def forward(self, input):
        # Temporarily using len(shape) instead of ndim due to JIT issue
        # https://github.com/pytorch/pytorch/issues/23890
        if len(input.shape) != 4:
            raise ValueError("Input shape must be `(N, C, H, W)`!")
        return ops.quantized.conv2d(
            input, self._packed_params, self.stride, self.padding,
            self.dilation, self.groups, self.scale, self.zero_point)
    @classmethod
    def from_float(cls, mod):
        r"""Creates a quantized module from a float module or qparams_dict.

        Args:
            mod (Module): a float module, either produced by torch.quantization
              utilities or provided by the user
        """
        if hasattr(mod, 'weight_fake_quant'):
            # assert type(mod) == cls.__QAT_MODULE, ' nnq.' + cls.__name__ + \
            #     '.from_float only works for ' + cls.__QAT_MODULE.__name__
            if type(mod) == nniqat.ConvBn2d:
                mod.weight, mod.bias = fuse_conv_bn_weights(
                    mod.weight, mod.bias, mod.running_mean, mod.running_var,
                    mod.eps, mod.gamma, mod.beta)
            assert hasattr(mod, 'activation_post_process'), \
                'Input QAT module must have observer attached'
            weight_post_process = mod.weight_fake_quant
            activation_post_process = mod.activation_post_process
        else:
            assert type(mod) == cls._FLOAT_MODULE, \
                ' nnq.' + cls.__name__ + '.from_float only works for ' + \
                cls._FLOAT_MODULE.__name__
            assert hasattr(mod, 'qconfig'), \
                'Input float module must have qconfig defined.'
            # workaround for sequential, ConvReLU2d should probably
            # inherit from Conv2d instead
            if type(mod) == nni.ConvReLU2d:
                activation_post_process = mod[1].activation_post_process
                mod = mod[0]
            else:
                activation_post_process = mod.activation_post_process
            weight_post_process = mod.qconfig.weight()
        weight_post_process(mod.weight)
        act_scale, act_zp = activation_post_process.calculate_qparams()
        assert weight_post_process.dtype == torch.qint8, \
            'Weight observer must have a dtype of qint8'
        qweight = _quantize_weight(mod.weight.float(), weight_post_process)
        qconv = cls(mod.in_channels, mod.out_channels, mod.kernel_size,
                    mod.stride, mod.padding, mod.dilation, mod.groups,
                    mod.bias is not None, mod.padding_mode)
        qconv.set_weight_bias(qweight, mod.bias)
        qconv.scale = float(act_scale)
        qconv.zero_point = int(act_zp)

        return qconv
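
    # A hypothetical sketch of the QAT path handled above (fused Conv + BN):
    #   model = torch.nn.Sequential(torch.nn.Conv2d(16, 33, 3), torch.nn.BatchNorm2d(33))
    #   model.qconfig = torch.quantization.get_default_qat_qconfig('fbgemm')
    #   torch.quantization.fuse_modules(model, [['0', '1']], inplace=True)
    #   torch.quantization.prepare_qat(model, inplace=True)     # nniqat.ConvBn2d with fake-quant
    #   # ... fine-tune ...
    #   torch.quantization.convert(model.eval(), inplace=True)  # QuantizedConv2d via from_float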

class Conv3d(_ConvNd):
    r"""Applies a 3D convolution over a quantized input signal composed of
    several quantized input planes.

    For details on input arguments, parameters, and implementation see
    :class:`~torch.nn.Conv3d`.

    .. note::
        Only `zeros` is supported for the :attr:`padding_mode` argument.

    .. note::
        Only `torch.quint8` is supported for the input data type.


    Attributes:
        weight (Tensor):     packed tensor derived from the learnable weight
                             parameter.
        scale (Tensor):      scalar for the output scale
        zero_point (Tensor): scalar for the output zero point

    See :class:`~torch.nn.Conv3d` for other attributes.

    Examples::

        >>> # With square kernels and equal stride
        >>> m = nn.quantized.Conv3d(16, 33, 3, stride=2)
        >>> # non-square kernels and unequal stride and with padding
        >>> m = nn.quantized.Conv3d(16, 33, (3, 5, 5), stride=(1, 2, 2), padding=(1, 2, 2))
        >>> # non-square kernels and unequal stride and with padding and dilation
        >>> m = nn.quantized.Conv3d(16, 33, (3, 5, 5), stride=(1, 2, 2), padding=(1, 2, 2), dilation=(1, 2, 2))
        >>> input = torch.randn(20, 16, 56, 56, 56)
        >>> # quantize input to quint8
        >>> q_input = torch.quantize_per_tensor(input, scale=1.0, zero_point=0, dtype=torch.quint8)
        >>> output = m(q_input)

    """

    _FLOAT_MODULE = nn.Conv3d

    def __init__(self, in_channels, out_channels, kernel_size, stride=1,
                 padding=0, dilation=1, groups=1, bias=True,
                 padding_mode='zeros'):
        kernel_size = _triple(kernel_size)
        stride = _triple(stride)
        padding = _triple(padding)
        dilation = _triple(dilation)
        super(Conv3d, self).__init__(
            in_channels, out_channels, kernel_size, stride, padding, dilation,
            False, _triple(0), groups, bias, padding_mode)

    def _get_name(self):
        return 'QuantizedConv3d'

    def set_weight_bias(self, w, b):
        # type: (torch.Tensor, Optional[torch.Tensor]) -> None
        self._packed_params = torch.ops.quantized.conv3d_prepack(
            w, b, self.stride, self.padding, self.dilation, self.groups)

    def _weight_bias(self):
        return torch.ops.quantized.conv3d_unpack(self._packed_params)

    def weight(self):
        (w, _) = torch.ops.quantized.conv3d_unpack(self._packed_params)
        return w

    def bias(self):
        (_, b) = torch.ops.quantized.conv3d_unpack(self._packed_params)
        return b

    def forward(self, input):
        # Temporarily using len(shape) instead of ndim due to JIT issue
        # https://github.com/pytorch/pytorch/issues/23890
        if len(input.shape) != 5:
            raise ValueError("Input shape must be `(N, C, D, H, W)`!")
        return ops.quantized.conv3d(
            input, self._packed_params, self.stride, self.padding,
            self.dilation, self.groups, self.scale, self.zero_point)
    @classmethod
    def from_float(cls, mod):
        r"""Creates a quantized module from a float module or qparams_dict.

        Args:
            mod (Module): a float module, either produced by torch.quantization
              utilities or provided by the user
        """
        assert type(mod) == cls._FLOAT_MODULE, \
            ' nnq.' + cls.__name__ + '.from_float only works for ' + \
            cls._FLOAT_MODULE.__name__
        assert hasattr(mod, 'qconfig'), \
            'Input float module must have qconfig defined.'
        # Workaround for sequential, ConvReLU3d should probably inherit from
        # Conv3d instead
        if type(mod) == nni.ConvReLU3d:
            activation_post_process = mod[1].activation_post_process
            mod = mod[0]
        else:
            activation_post_process = mod.activation_post_process
        weight_post_process = mod.qconfig.weight()
        weight_post_process(mod.weight)
        act_scale, act_zp = activation_post_process.calculate_qparams()
        assert weight_post_process.dtype == torch.qint8, \
            'Weight observer must have a dtype of qint8'
        qweight = _quantize_weight(mod.weight.float(), weight_post_process)
        qconv = cls(mod.in_channels, mod.out_channels, mod.kernel_size,
                    mod.stride, mod.padding, mod.dilation, mod.groups,
                    mod.bias is not None, mod.padding_mode)
        qconv.set_weight_bias(qweight, mod.bias)
        qconv.scale = float(act_scale)
        qconv.zero_point = int(act_zp)

        return qconv

# === Transposed Convolutions ===
# Note: The MRO should make sure that the `super` in the `_ConvNd` will be
# called, while the one in the `nn._ConvTransposeNd` won't.

class _ConvTransposeNd(_ConvNd, nn.modules.conv._ConvTransposeNd):
    def __init__(self, in_channels, out_channels, kernel_size, stride,
                 padding, dilation, transposed, output_padding,
                 groups, bias, padding_mode):
        if padding_mode != 'zeros':
            raise ValueError(
                'Only "zeros" padding mode is supported for {}'.format(
                    self.__class__.__name__))

        super(_ConvTransposeNd, self).__init__(
            in_channels, out_channels, kernel_size, stride, padding, dilation,
            transposed, output_padding, groups, bias, padding_mode)

    def _input_padding(self, kernel_size, dilation, padding):
        # type: (List[int], List[int], List[int]) -> List[int]
        res = torch.jit.annotate(List[int], [])
        for kdx in range(len(kernel_size)):
            pad = (dilation[kdx] * (kernel_size[kdx] - 1) - padding[kdx])
            res.append(pad)
        return res
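
    # Quick worked example of the _input_padding formula above: with
    # kernel_size=[3], dilation=[1] and padding=[1], the equivalent input
    # padding per spatial dimension is 1 * (3 - 1) - 1 = 1.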
