
Hyperspherical Loss-Aware Ternary Quantization

Dan Liu, Xue Liu
Abstract

Most existing works perform ternary quantization with projection functions in discrete space. Scaling factors and thresholds are used in some cases to improve the model accuracy. However, the gradients used for optimization are inaccurate and result in a notable accuracy gap between the full precision and ternary models. To obtain more accurate gradients, some works gradually increase the discrete portion of the full precision weights in the forward propagation pass, e.g., using a temperature-based Sigmoid function. Instead of directly performing ternary quantization in discrete space, we push the full precision weights close to ternary ones through a regularization term prior to ternary quantization. In addition, inspired by the temperature-based method, we introduce a re-scaling factor to obtain more accurate gradients by simulating the derivative of the Sigmoid function. The experimental results show that our method can significantly improve the accuracy of ternary quantization in both image classification and object detection tasks.

Introduction

Most deep neural network (DNN) models have a huge number of parameters, making it impractical to deploy them on edge devices. When deploying DNN models under latency, memory, and power restrictions, inference efficiency and model size are the main obstacles. There are many studies on how to use quantization and pruning to reduce model size and computation footprint.

DNN model compression has long been an important area of research. For instance, quantization and pruning are frequently used because they can minimize model size and resource requirements. Pruning can retain higher accuracy at the expense of a longer inference time, while quantization can accelerate and compress the model at the expense of accuracy. The purpose of model quantization is to obtain a DNN model with maximum accuracy and minimum bit width. Low-bit quantization has the advantage of fast inference, but accuracy is often reduced by inaccurate gradients (yin2019understanding) that are estimated using discrete weights (gholami2021survey). Compared to low-bit methods, weight sharing quantization (stock2019killthebits; martinez_2020_pqf; cho2021dkm) suffers a smaller accuracy drop, but inference time is longer because the actual weight values used remain full precision.

The discrepancy between the quantized weights in the forward pass and the full precision weights in the backward pass can be reduced by gradually discretizing the weights (louizos2018relaxed; jang2016categorical; chung2016hierarchical). However, this only works well at 4-bit (or higher) precision, as ultra-low-bit (binary or ternary) quantizers may drastically affect the weight magnitude and lead to unstable weights (gholami2021survey). Recent research indicates that the essential semantics in feature maps are preserved by direction information (liu2016large; liu2017deephyperspherical; liu2017sphereface; TimRDavidson2018HypersphericalVA; deng2019arcface; SungWooPark2019SphereGA; BeidiChen2020AngularVH).

In this work, inspired by hyperspherical learning (liu2017deephyperspherical), we rely on the direction information of weight values to perform ternary quantization. Before performing quantization, we use a regularization term to reduce the cosine distance between the full precision weight values and their ternary masks under hyperspherical settings (liu2017deephyperspherical). Then we use the proposed gradient scaling method with the straight-through estimator (STE) (bengio2013estimating) to fulfill the ternary quantization. The following is a summary of our contributions:

  • We propose a loss-aware ternary quantization method which uses a loss regularization term to reduce the cosine distance between full precision weight values and their ternary counterparts. Once the training is done, the weight values are separated into three clusters, allowing for a more accurate quantization.

  • We propose a learned quantization threshold to improve the weight sparsity and the quantization performance. We use a re-scaling method during backpropagation to simulate the temperature-based method and obtain more accurate gradients.

Related Work

To reduce model size and accelerate inference, a notable amount of research has been devoted to model quantization methods, such as binary quantization (courbariaux2015binaryconnect; rastegari2016xnor) and ternary quantization (hwang2014fixed; li2016twn; zhu2016ttq). When training ternary models, optimization is difficult because the discrete weight values hinder efficient local search. Gradient projection methods such as the straight-through estimator (STE) are usually used to overcome this difficulty:

$$\left\{\begin{array}{l}\hat{\mathbf{W}}_{t}=\operatorname{Proj}\left(\mathbf{W}_{t}\right)=\operatorname{sign}\left(\mathbf{W}_{t}\right)\\ \mathbf{W}_{t+1}=\mathbf{W}_{t}-\eta_{t}\hat{\nabla}L\left(\hat{\mathbf{W}}_{t}\right)\end{array}\right. \quad (1)$$

$\operatorname{Proj}$ is the projection operator and projects $\mathbf{W}\in\mathbb{R}$ to a discrete $\hat{\mathbf{W}}\in\{0,\pm 1\}$. The optimization of $\operatorname{Proj}$ is equivalent to:

$$\underset{\alpha,\hat{\mathbf{W}}}{\arg\min}\;\|\mathbf{W}-\alpha\hat{\mathbf{W}}\|_{2}^{2}, \quad (2)$$

where $\alpha$ is the scaling factor (bai2018proxquant; parikh2014proximal; li2016twn). A fixed threshold $\Delta$ is often introduced by previous works (li2016twn; zhu2016ttq) to determine the quantization intervals of $\operatorname{Proj}$ (Eq. (1)). Therefore, ternary quantization methods can be divided into estimation-based and optimization-based methods, depending on how $\alpha$ and $\Delta$ are obtained.
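To make the projection/STE interplay of Eq. (1) concrete, the following is a minimal PyTorch sketch (not the implementation of any particular paper): the forward pass uses the projected weights, while the gradient step updates the full precision weights. The layer shapes and data are arbitrary placeholders.

```python
import torch
import torch.nn.functional as F

# Full precision weights W; data x, y are random placeholders.
w = torch.randn(64, 10, requires_grad=True)
x, y = torch.randn(32, 64), torch.randint(0, 10, (32,))
opt = torch.optim.SGD([w], lr=0.1)

for _ in range(3):
    # Proj(W) = sign(W) in the forward pass; the "+ w - w.detach()" trick
    # makes the backward pass treat d(w_hat)/dw as the identity (STE).
    w_hat = torch.sign(w.detach()) + w - w.detach()
    loss = F.cross_entropy(x @ w_hat, y)
    opt.zero_grad()
    loss.backward()   # gradient is taken w.r.t. the projected weights...
    opt.step()        # ...but the update is applied to the full precision W (Eq. (1))
```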

Estimation-Based Ternary Quantization

Estimation-based methods, such as (li2016twn), use the approximated forms $\Delta=0.7\times\boldsymbol{E}(|\tilde{w}_{l}|)$ and $\alpha=\underset{w>\Delta}{\boldsymbol{E}}(|w|)$, as optimizing $\alpha$ and $\Delta$ directly is time consuming. (wang2018two_27) uses an alternating greedy approximation method to improve $\alpha$ and $\Delta$. Given $\Delta$, $\alpha$ has the closed-form optimal solution $\alpha=\underset{w>\Delta}{\boldsymbol{E}}(|w|)$ (li2016twn; zhu2016ttq). However, direct or alternating estimation is a very rough approximation of $\Delta$, so the best optimization result of Eq. (1) cannot be guaranteed.
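As a concrete illustration of this estimation scheme, the sketch below computes $\Delta$ and $\alpha$ as described above; it is a minimal reading of the TWN-style formulas, not the original authors' code.

```python
import torch

def twn_estimate(w: torch.Tensor):
    """Estimation-based ternary quantization in the style of TWN:
    Delta ~= 0.7 * E(|w|); given Delta, alpha = E_{|w|>Delta}(|w|)."""
    delta = 0.7 * w.abs().mean()
    mask = w.abs() > delta
    alpha = w.abs()[mask].mean()
    w_hat = torch.where(mask, alpha * torch.sign(w), torch.zeros_like(w))
    return w_hat, alpha, delta

w_hat, alpha, delta = twn_estimate(torch.randn(256, 256))
```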

Optimization-Based Ternary Quantization

The optimization-based method TTQ (zhu2016ttq) uses a fixed $\Delta=0.05\times\max(|\mathbf{W}|)$ and two SGD-optimized scaling factors to improve the quantization results. The intuition is straightforward: since $\Delta$ is an unstable approximation, SGD can be used to optimize the scaling factor $\alpha$. Subsequent works (yang2019quantization_22; dbouk2020dbq_20; li2020rtn_23) add more complicated scaling factors and provide corresponding optimization schemes. A minimal sketch of this style of quantizer is shown below.
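The sketch follows the TTQ description above with two learnable scaling factors (named `wp` and `wn` here purely for illustration); shapes and initialization are our assumptions.

```python
import torch

def ttq_quantize(w: torch.Tensor, wp: torch.Tensor, wn: torch.Tensor):
    """TTQ-style quantizer: fixed Delta = 0.05 * max|W| with two
    SGD-optimized scaling factors, wp for positive and wn for negative weights."""
    delta = 0.05 * w.abs().max()
    pos = (w > delta).float()
    neg = (w < -delta).float()
    return wp * pos - wn * neg   # values in {+wp, 0, -wn}

w = torch.randn(128, 128)
wp = torch.nn.Parameter(torch.tensor(1.0))  # learned by SGD alongside W
wn = torch.nn.Parameter(torch.tensor(1.0))
w_hat = ttq_quantize(w, wp, wn)
```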

There is another subclass of optimization-based methods, which aims to manipulate the gradient and filter out the less salient weight values by adding regularization terms to the loss function. For example, (hou2018loss_31) uses second-order information to rescale the gradient and find the weight values that are not sensitive to ternarization, i.e., the less salient weight values. (zhou2018explicit_26.5) adds a regularization term to refine the gradient through the L1 distance: if $w$ is close to $\hat{w}$, the regularization should be small; otherwise, it should be large. The intuition is simple: if $w$ and $\hat{w}$ are close to each other, they should not change frequently. However, the discrete ternary values are always involved in this kind of optimization-based method, so the gradient is not accurate, which affects the optimization results. More advanced methods (dbouk2020dbq_20; yang2019quantization_22) introduce a temperature-based Sigmoid function after the ternary projection in order to gradually discretize $w$. Since part of the optimization process is in continuous space, more accurate $\alpha$ and $\Delta$ can be obtained after the weights are fully discretized. One advantage of using the Sigmoid function is that the gradient is rescaled during backpropagation, i.e., as $w$ approaches $\hat{w}$, the gradient approaches 0, and as $w$ approaches $\Delta$, the gradient gets larger. However, this Sigmoid-based method is very time-consuming as it performs layer-wise training and quantization, which is infeasible for very deep neural networks.

Unlike the above methods, our proposed method optimizes the distance between $\mathbf{W}$ and $\hat{\mathbf{W}}$ in continuous space, before the conventional ternary quantization, to overcome inaccurate gradients. Our method has a similar advantage to the Sigmoid-based methods: full precision weights are used to optimize the distance. Compared with their time-consuming layer-wise training strategy, our method works globally and requires far fewer training epochs. In addition, we propose a novel way to re-scale gradients during backpropagation.

Preliminary and Notations

Hyperspherical Networks

A hyperspherical neural network layer (liu2017deephyperspherical) is defined as:

$$\mathbf{y}=\phi(\mathbf{W}^{\top}\mathbf{x}), \quad (3)$$

where $\mathbf{W}\in\mathbb{R}^{m\times n}$ is the weight matrix, $\mathbf{x}\in\mathbb{R}^{m}$ is the input vector to the layer, $\phi$ represents a nonlinear activation function, and $\mathbf{y}\in\mathbb{R}^{n}$ is the output feature vector. The input vector $\mathbf{x}$ and each column vector $\mathbf{w}_{j}\in\mathbb{R}^{m}$ of $\mathbf{W}$ are subject to $\|\mathbf{w}_{j}\|_{2}=1$ for all $j=1,\ldots,n$, and $\|\mathbf{x}\|_{2}=1$.
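A minimal sketch of such a layer is shown below, with $\phi$ taken to be ReLU purely for illustration; the normalization of the weight columns and the input follows Eq. (3), but the module itself is our assumption rather than the reference implementation.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class HypersphericalLinear(nn.Module):
    """y = phi(W^T x) with unit-norm weight columns and unit-norm input."""

    def __init__(self, in_features: int, out_features: int):
        super().__init__()
        # W has shape (m, n): one column w_j per output feature.
        self.weight = nn.Parameter(torch.randn(in_features, out_features))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        w = F.normalize(self.weight, dim=0)   # ||w_j||_2 = 1 for every column
        x = F.normalize(x, dim=-1)            # ||x||_2 = 1
        return F.relu(x @ w)                  # phi = ReLU (illustrative choice)

layer = HypersphericalLinear(64, 32)
y = layer(torch.randn(8, 64))                 # batch of 8 inputs
```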

Ternary Quantization

To adapt to the hyperspherical settings, the ternary quantizer is defined as:

$$\hat{\mathbf{W}}=\texttt{Ternary}(\mathbf{W},\Delta)=\left\{\begin{aligned} \alpha&: w_{ij}>\Delta,\\ 0&: |w_{ij}|\leq\Delta,\\ -\alpha&: w_{ij}<-\Delta,\end{aligned}\right. \quad (4)$$
$$s.t.~~\alpha=\frac{1}{\sqrt{\|\mathbf{w}_{j}\|_{0}}}, \quad (5)$$

where $\Delta$ is a quantization threshold, $\|\mathbf{w}_{j}\|_{0}$ denotes the number of non-zero values in $\mathbf{w}_{j}$, and $\|\hat{\mathbf{w}}_{j}\|_{2}=1$. The ternary hyperspherical layer has the following property: $\phi(\hat{\mathbf{w}}_{j}^{\top}\mathbf{x})=\phi(\alpha\bar{\mathbf{w}}_{j}^{\top}\mathbf{x})$, where $\bar{\mathbf{w}}\in\{-1,0,1\}$.
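The sketch below transcribes Eqs. (4)-(5) column-wise for a weight matrix of shape $(m, n)$; it is a direct reading of the formulas rather than the authors' released code.

```python
import torch

def ternary(w: torch.Tensor, delta: float) -> torch.Tensor:
    """Ternary(W, Delta): map each column w_j to {-alpha_j, 0, +alpha_j}
    with alpha_j = 1 / sqrt(||w_j||_0), so that ||w_hat_j||_2 = 1."""
    mask = (w.abs() > delta).float()                    # 1 where |w_ij| > Delta
    nnz = mask.sum(dim=0, keepdim=True).clamp(min=1.0)  # ||w_j||_0 per column
    alpha = 1.0 / nnz.sqrt()                            # Eq. (5)
    return alpha * torch.sign(w) * mask                 # Eq. (4)
```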

Loss-Aware Ternary Quantization

The magnitude and direction information of weight vectors change dramatically during ternary quantization. Hyperspherical neural networks can work without taking the magnitude change into account, so combining ternary quantization with hyperspherical learning may offer stable features that benefit quantization performance.

In this section, to relieve the magnitude and direction deviation in the subsequent quantization procedure, we first introduce a regularization term to push full precision weights toward their ternary counterparts. Then, a re-scaling factor is applied during the quantization stage to obtain more accurate gradients. We also propose a learned threshold to facilitate the quantization process.

Figure 1: The weight distribution of a layer after applying the proposed regularization term $L_{d}$ (Eq. (6)), for sparsity factors (a) $t=0$, (b) $t=0.5$, (c) $t=0.7$, and (d) $t=0.9$. Panel (a) is a ResNet-18 baseline model. Prior to ternary quantization, the regularization term $L_{d}$ divides the full precision weight values into three clusters. As the sparsity factor $t$ increases, more weight values move closer to zero.

Pushing $\mathbf{W}$ close to $\hat{\mathbf{W}}$ before Quantization

Given a regular objective function $L$, we formulate the optimization process as:

$$\min_{\mathbf{W}}J(\mathbf{W})=L(\mathbf{W})+L_{d}(\mathbf{W},\Delta). \quad (6)$$

The regularization term $L_{d}$ is defined as:

$$L_{d}(\mathbf{W},\Delta)=\frac{1}{n}\left(\texttt{diag}(\mathbf{W}^{\top}\hat{\mathbf{W}}-\mathbf{I})\right)^{2} \quad (7)$$
$$s.t.~~\hat{\mathbf{W}}=\texttt{Ternary}(\mathbf{W},\Delta),$$

where $\mathbf{I}$ is the identity matrix, and $\texttt{diag}(\cdot)$ returns the diagonal elements of a matrix. The threshold $\Delta$ is:

$$\Delta=\texttt{T}(t)=\texttt{ValueAtIndex}(\texttt{sort}(|\mathbf{W}|),\texttt{idx}), \quad (8)$$
$$s.t.~~\texttt{idx}=\lfloor t\times m\times n\rfloor,\; 0<t<1,$$

where $t$ controls the sparsity, i.e., the percentage of zero values, of $\hat{\mathbf{W}}$ (Figure 1), and $\texttt{ValueAtIndex}([\cdot],\texttt{idx})$ returns the value of $[\cdot]$ at the index $\texttt{idx}$. During training, based on the work of (liu2018rethinking), we initially assign $t=0.5$, namely, removing $50\%$ of the values in $\mathbf{W}$ by their magnitude to form the potential ternary references $\hat{\mathbf{W}}$, and then $t$ is gradually increased to $t=0.7$, i.e., 70% of $\hat{\mathbf{W}}$ is zero. We use the quadratic term to keep $L(\mathbf{W})$ and $L_{d}(\mathbf{W},\Delta)$ at the same scale. Since $J(\mathbf{W})$ consists of two parts and is optimized in continuous space, the distance between $\mathbf{W}$ and $\hat{\mathbf{W}}$ gradually decreases without much change in model accuracy. In practice, we observe that applying $L_{d}(\mathbf{W},\Delta)$ barely reduces the model accuracy.
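A minimal sketch of the threshold $\texttt{T}(t)$ in Eq. (8) is shown below; it simply reads off the magnitude at index $\lfloor t\times m\times n\rfloor$ of the sorted $|\mathbf{W}|$.

```python
import torch

def magnitude_threshold(w: torch.Tensor, t: float) -> torch.Tensor:
    """Delta = T(t): the |w| value at index floor(t * m * n) of the sorted
    magnitudes, so that roughly a fraction t of W_hat becomes zero."""
    flat = w.abs().flatten()
    idx = int(t * flat.numel())          # idx = floor(t * m * n)
    return torch.sort(flat).values[idx]

delta = magnitude_threshold(torch.randn(512, 512), t=0.5)
```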

Since we are training with hyperspherical settings, i.e., $\|\mathbf{w}_{j}\|_{2}=\|\hat{\mathbf{w}}_{j}\|_{2}=1$, the diagonal elements of $\mathbf{W}^{\top}\hat{\mathbf{W}}$ in Eq. (7) denote the cosine similarities between $\mathbf{w}_{j}$ and $\hat{\mathbf{w}}_{j}$. Minimizing $L_{d}$ is equivalent to pushing $\mathbf{w}_{j}$ close to $\hat{\mathbf{w}}_{j}$, namely, making part of the weight magnitudes close to $\alpha$ (Eq. (4)) while the rest move close to zero (Figure 1). In addition, the gradient of $L_{d}$ can adjust $\frac{\partial J}{\partial\mathbf{W}}$, which is similar to the works of (hou2018loss_31; yang2019quantization_22; zhou2018explicit_26.5; dbouk2020dbq_20).
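Under the unit-norm assumption, one straightforward reading of Eq. (7) is the mean squared deviation of the per-column cosine similarity from 1, as sketched below; averaging over columns is our interpretation of the $1/n$ factor.

```python
import torch

def loss_ld(w: torch.Tensor, w_hat: torch.Tensor) -> torch.Tensor:
    """L_d(W, Delta): assumes both W and W_hat = Ternary(W, Delta) have
    unit-norm columns, so diag(W^T W_hat) holds cosine similarities."""
    cos = (w * w_hat).sum(dim=0)        # diagonal of W^T W_hat, one value per column
    return ((cos - 1.0) ** 2).mean()    # (1/n) * sum_j (cos_j - 1)^2
```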

Rescaling the Gradient During Quantization

Figure 2: Functions $f$ and $g$ share a similar shape. $S=\frac{df}{dw}$ is an alteration of $\frac{dg}{dw}$ that re-scales the gradient and is much easier to calculate.

With $\mathbf{W}$ close to $\hat{\mathbf{W}}$ before the ternary quantization, we still need to convert the full precision weights to ternary values. We change $\Delta$ to a learnable threshold $\bar{\Delta}$ and empirically initialize $\bar{\Delta}=\texttt{T}(t+0.1)$ via Eq. (8). Eq. (6) then becomes:

$$\min_{\mathbf{W},\bar{\Delta}}J(\mathbf{W},\bar{\Delta})=L(\hat{\mathbf{W}})+L_{d}(\mathbf{W},\bar{\Delta}). \quad (9)$$

We introduce a gradient scaling factor $S$ and adopt the STE (bengio2013estimating) to bypass the non-differentiability of $\texttt{Ternary}(\cdot)$:

Forward:

$$\hat{\mathbf{W}}=\texttt{Ternary}(\mathbf{W},\bar{\Delta}); \quad (10)$$

Backward:

$$\frac{\partial J}{\partial\mathbf{W}}=\frac{\partial J}{\partial\hat{\mathbf{W}}}\frac{\partial\hat{\mathbf{W}}}{\partial\mathbf{W}}\underset{STE}{\approx}\frac{\partial J}{\partial\hat{\mathbf{W}}}\times S, \quad (11)$$

where $S$ is defined as:

$$S=\mathbf{1}-\mathbf{W}\odot\mathbf{W}, \quad (12)$$

and $\odot$ denotes element-wise multiplication. Inspired by the work of (yang2019quantization_22), $1-w^{2}$ is used to re-scale the gradient of $w_{ij}$, as it shares similarities with the derivative of $\texttt{Sigmoid}(\cdot)$ (Figure 2). $S$ is intuitive and easy to calculate. The gradient of $w_{ij}$ should become smaller as it gets farther away from $\bar{\Delta}$. This intuition is the same as in the works of (hou2018loss_31; zhou2018explicit_26.5; dbouk2020dbq_20).
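A minimal autograd sketch of Eqs. (10)-(12) is given below; the quantizer inside the forward pass follows Eq. (4), and $\bar{\Delta}$ is treated as a plain float here since its update via Eq. (13) is handled separately.

```python
import torch

class TernarySTE(torch.autograd.Function):
    """Forward: W_hat = Ternary(W, Delta_bar) (Eq. (10)).
    Backward: straight-through gradient re-scaled by S = 1 - W * W
    (Eqs. (11)-(12))."""

    @staticmethod
    def forward(ctx, w, delta_bar):
        ctx.save_for_backward(w)
        mask = (w.abs() > delta_bar).float()
        nnz = mask.sum(dim=0, keepdim=True).clamp(min=1.0)
        return torch.sign(w) * mask / nnz.sqrt()   # alpha = 1/sqrt(||w_j||_0)

    @staticmethod
    def backward(ctx, grad_out):
        (w,) = ctx.saved_tensors
        return grad_out * (1.0 - w * w), None      # dJ/dW ~= dJ/dW_hat * S

w = torch.randn(64, 32, requires_grad=True)
w_hat = TernarySTE.apply(w, 0.05)
w_hat.sum().backward()                             # populates w.grad via S
```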

We simply apply the averaged gradients of $w_{ij}$ to update the quantization threshold $\bar{\Delta}$:

$$\frac{\partial L}{\partial\bar{\Delta}}=\frac{1}{m\times n}\sum_{i=1}^{m}\sum_{j=1}^{n}\frac{\partial L}{\partial w_{ij}}, \quad (13)$$
$$s.t.~~w_{ij}\in\mathbf{W}\text{ and }w_{ij}\neq 0.$$

Figure 1 shows that we can initialize $\bar{\Delta}$ with a smaller value and then gradually increase it until the near-zero portion of the weights is converted to zero during quantization. The threshold $\bar{\Delta}$ should increase rapidly when the training error is very large; as the training error becomes smaller, the increase should slow down.
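Assuming the weight gradients have already been populated (e.g., by the STE sketch above), the snippet below illustrates one way to apply Eq. (13) as a plain SGD step on $\bar{\Delta}$; the learning rate is an arbitrary placeholder.

```python
import torch

@torch.no_grad()
def update_delta_bar(delta_bar: torch.Tensor, w: torch.Tensor, lr: float = 0.01):
    """Eq. (13): the gradient of the learnable threshold Delta_bar is the
    sum of dL/dw_ij over the non-zero weights, divided by m * n."""
    grad = w.grad[w != 0].sum() / w.numel()
    delta_bar -= lr * grad          # plain SGD step on Delta_bar
    return delta_bar
```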

Implementation Details

Training Algorithm

The proposed method is summarized in Algorithm 1. We gradually increase $t$ from $t_{1}=0.3$ to $t_{2}=0.7$ based on (liu2018rethinking); $t=0.7$ gives $\hat{\mathbf{W}}$ 70% sparsity (Eq. (8)). The overall process can be summarized as: i) fine-tuning from pre-trained model weights with the hyperspherical learning architecture (liu2017deephyperspherical); ii) initializing $\bar{\Delta}=\texttt{T}(t_{2}+0.1)$ to convert near-zero weight values to zero; iii) ternary quantization, updating the weights and $\bar{\Delta}$ through the STE. Note that $\Delta$ is a global magnitude threshold (blalock2020state), while $\bar{\Delta}$ is initialized globally and updated by SGD in a layer-wise manner.

Algorithm 1 HLA training approach
1: Input: Input $\mathbf{x}$, a hyperspherical neural layer $\phi(\cdot)$, $t_{1}=0.3$ and $t_{2}=0.7$.
2: Result: Quantized ternary networks for inference
3: 1. Fine-tuning:
4: $t=t_{1}$
5: while $t<t_{2}$ do
6:     $\Delta=\texttt{T}(t)$  ▷ Update $\Delta$
7:     while not converged do
8:         $y=\phi(\mathbf{W}^{\top}\mathbf{x})$  ▷ Hyperspherical training
9:         Minimize Eq. (6)
10:        Perform SGD, calculate $\frac{\partial J}{\partial\mathbf{W}}$, and update $\mathbf{W}$
11:    end while
12:    $t$ += 0.04  ▷ Increase $t$
13: end while
14: 2. Ternary Quantization:
15: $\bar{\Delta}=\texttt{T}(t_{2}+0.1)$
16: while not converged do
17:     $\hat{\mathbf{W}}=\texttt{Ternary}(\mathbf{W},\bar{\Delta})$
18:     $y=\phi(\hat{\mathbf{W}}^{\top}\mathbf{x})$
19:     Minimize Eq. (9)
20:     Get $\frac{\partial J}{\partial\mathbf{W}}$, $\frac{\partial J}{\partial\bar{\Delta}}$ via SGD; update $\mathbf{W}$, $\bar{\Delta}$  ▷ Eq. (11), (13)
21: end while

Training Time

When training the ResNet-18 model with 8×V100 GPUs and mixed precision (16-bit), each epoch takes about 6 minutes. It takes about 100 epochs to obtain a hyperspherical ready-to-quantize model. The inner SGD loop (Lines 7-11 in Algorithm 1) takes about 10 epochs. The ternary quantization loop (Lines 16-21 in Algorithm 1) takes about 200 epochs.

Discussion

Our work shows that the model’s angular information, i.e., the cosine similarity, connects sparsity and quantization on the hypersphere. We demonstrate how gradually adjusting the sparsity constraints can facilitate ternary quantization.

Most ternary quantization works try to directly project full precision models into ternary ones. However, the discrepancy between the full precision and ternary values before quantization is largely ignored. Only a few approaches (ding2017three_40; hu2019cluster_24) attempt to minimize the discrepancy prior to quantization. Intuitively, the gradients during ternary quantization are inaccurate; therefore, it is inappropriate to use such gradients for distance optimization. Accurate gradients, which are produced by full precision weights, are better than estimated gradients in terms of reducing the discrepancy. Our work shows that by making the full precision weights close to the ternary ones in the initial steps, we can improve the performance of ternary quantization.

In mainstream ternary quantization works, such as TWN (li2016twn) and TTQ (zhu2016ttq), due to the non-differentiable thresholds and unstable weight magnitudes, fixed thresholds and learned scaling factors are used to optimize the ternary boundaries that determine which values should be zero. In our work, the clustered weight values produced by hyperspherical learning allow us to use the average gradient to update the quantization threshold $\bar{\Delta}$ directly, as the thresholding only needs a tiny increment to push the less important weights to zero.

Compared to temperature-based works (yang2019quantization_22) that use a Sigmoid function for projection and layer-wise training, our proposed method is simple yet effective: by combining a bell-shaped gradient re-scaling factor $S$ with ternary quantization, we achieve a better efficiency-accuracy trade-off. In addition, our method does not need many hyperparameters or sophisticated learned variables (li2020rtn_23; yang2019quantization_22). Compared to other loss-aware ternary quantization works (zhou2018explicit_26.5; hou2018loss_31) that use estimated gradients to optimize the loss penalty term, our method minimizes the loss function using full precision weights before quantization, which is more accurate and efficient.

Experiments

Our method is evaluated on image classification and object detection tasks with the ImageNet dataset (ILSVRC15). The ResNet-18/50 (he2016deep) and MobileNetV2 (sandler2018mobilenetv2) architectures are used for image classification. Mask R-CNN (he2017mask) with a ResNet-50 FPN backbone (wu2019detectron2) is used for object detection. The object detection tasks are performed on the MS COCO (lin2014mscoco) dataset. The pre-trained weights are provided by the PyTorch model zoo and Detectron2 (wu2019detectron2).

Experiment Setup

For image classification, the batch size is 256. The weight decay is 0.0001, and the momentum of stochastic gradient descent (SGD) is 0.9. We use the cosine annealing schedule with restarts (loshchilov2016sgdr) to adjust the learning rate. The initial learning rate is 0.01. All of the experiments use 16-bit Automatic Mixed Precision (AMP) from PyTorch to accelerate the training process. The first convolutional layer and the last fully-connected layer are skipped when quantizing ResNet models. For MobileNetV2 and Mask R-CNN, we only skip the first convolutional layer.
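For illustration, a minimal PyTorch sketch of this training setup is shown below. It assumes a CUDA device; the restart period `T_0` of the cosine schedule is not reported in the paper and is an arbitrary placeholder, as is the single dummy batch.

```python
import torch
import torchvision
from torch.optim import SGD
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts

model = torchvision.models.resnet18(num_classes=1000).cuda()
optimizer = SGD(model.parameters(), lr=0.01, momentum=0.9, weight_decay=1e-4)
scheduler = CosineAnnealingWarmRestarts(optimizer, T_0=10)   # T_0 is a placeholder
scaler = torch.cuda.amp.GradScaler()                         # 16-bit AMP

images = torch.randn(256, 3, 224, 224).cuda()                # dummy batch of size 256
labels = torch.randint(0, 1000, (256,)).cuda()

optimizer.zero_grad()
with torch.cuda.amp.autocast():
    loss = torch.nn.functional.cross_entropy(model(images), labels)
scaler.scale(loss).backward()
scaler.step(optimizer)
scaler.update()
scheduler.step()                                             # cosine annealing with restarts
```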

Image Classification