
DiscQuant: A Quantization Method for
Neural Networks Inspired by Discrepancy Theory

Jerry Chee (Department of Computer Science, Cornell University); Arturs Backurs (Microsoft Research); Rainie Heck (Department of Mathematics, University of Washington); Li Zhang (Microsoft Research); Janardhan Kulkarni (Microsoft Research); Thomas Rothvoss (Department of Mathematics, University of Washington); Sivakanth Gopi (Microsoft Research)
Abstract

Quantizing the weights of a neural network has two steps: (1) Finding a good low bit-complexity representation for weights (which we call the quantization grid) and (2) Rounding the original weights to values in the quantization grid. In this paper, we study the problem of rounding optimally given any quantization grid. The simplest and most commonly used way to round is Round-to-Nearest (RTN). By rounding in a data-dependent way instead, one can improve the quality of the quantized model significantly.

We study the rounding problem from the lens of discrepancy theory, which studies how well we can round a continuous solution to a discrete solution without affecting solution quality too much. We prove that given $m=\mathrm{poly}(1/\varepsilon)$ samples from the data distribution, we can round all but $O(m)$ model weights such that the expected approximation error of the quantized model on the true data distribution is $\leq\varepsilon$, as long as the space of gradients of the original model is approximately low rank (which we empirically validate).

Our proof, which is algorithmic, inspired a simple and practical rounding algorithm called DiscQuant. In our experiments, we demonstrate that DiscQuant significantly improves over the prior state-of-the-art rounding method called GPTQ and the baseline RTN over a range of benchmarks on Phi3mini-3.8B and Llama3.1-8B. For example, rounding Phi3mini-3.8B to a fixed quantization grid with 3.25 bits per parameter using DiscQuant gets 64% accuracy on the GSM8k dataset, whereas GPTQ achieves 54% and RTN achieves 31% (the original model achieves 84%). We make our code available at https://github.com/jerry-chee/DiscQuant.

1 Introduction

Modern deep learning models continue to grow in size, making them increasingly challenging to train and serve. Post-training compression methods have emerged that aim to make model inference faster and cheaper. Compressing after pretraining is desirable for practitioners who either cannot afford to train models themselves or do not want to change the expensive training process too much. In this paper, we study post-training quantization (PTQ) of the model weights. Quantization reduces the memory requirements of the model and speeds up inference for LLMs in memory-bound settings such as the generation phase (as opposed to the prefill phase, which is compute-bound) (Kwon et al., 2023).

The quantization problem can be divided into two overall steps: (1) Construct a good low bit-complexity representation for the weights (we colloquially call this the quantization grid), and (2) Round the original weights to values in the quantization grid. Within step (1), we also consider those methods which apply a transformation on the weights to better match the encoding format. There has been much recent work on weights-only PTQ for LLMs. To date, the vast majority of such research has been focused on step (1): constructing good low bit representations (Shao et al., 2024; Tseng et al., 2024a; Egiazarian et al., 2024). However, work on rounding methods is under-explored. To the best of our knowledge, Round-to-Nearest (RTN) and GPTQ (Hassibi et al., 1993; Frantar et al., 2022, 2023) are the primary rounding methods for LLM weight quantization. RTN is a simple baseline, and GPTQ is a data dependent method which aims to match the activations of the quantized model with that of the original model layer-by-layer.

Let $f(w;s)$ be the loss function of a neural network where $w$ are the original pretrained weights and $s$ is an input sample; for example, $f$ can be the usual cross-entropy loss on input $s$. To find a good rounding, we look for perturbations of the original weights $w\in\mathbb{R}^{n}$ that correspond to values in the quantization grid and do not increase the loss $f$ too much. We further impose the constraint that each parameter is only rounded up or down; this ensures that we do not change the original model weights too much. The set of allowed quantization points can then be pictured as the vertices of a hypercube $H$ around $w$. Let $\hat{w}\in\mathbb{R}^{n}$ be the perturbed weights and $\Delta f=f(\hat{w};s)-f(w;s)$ be the resulting change in the loss for a sample $s$. We approximate $\Delta f$ via a first-order Taylor expansion: $\Delta f\approx\langle\nabla_{w}f(w;s),\hat{w}-w\rangle$. Some prior works such as Nagel et al. (2020); Hassibi et al. (1993) assume the gradients of a pretrained model to be nearly zero and focus on the second-order terms. We show that this assumption is not always true: the average gradient is close to zero, but per-sample gradients can be large; in fact, the first-order term is a good approximation to $\Delta f$ (see Figure 3).

Figure 1: An illustrative figure showing the convex polytope $K$ formed by the intersection of an $n$-dimensional hypercube $H$ and an $(n-m)$-dimensional affine subspace $V$. Any vertex of $K$ has $n-m$ coordinates which are fully rounded.

Therefore, to incur a small $\Delta f$, we want $\langle\nabla_{w}f(w;s),\hat{w}-w\rangle\approx 0$ for $s$ sampled from the data distribution $\mathcal{D}_{\textrm{data}}$. Given $m$ independent samples $s_{1},s_{2},\dots,s_{m}\sim\mathcal{D}_{\textrm{data}}$, we can impose the constraints $\langle\nabla_{w}f(w;s_{i}),\hat{w}-w\rangle=0$, which correspond to an affine subspace $V$ of dimension $n-m$. The intersection of the subspace $V$ and the hypercube $H$ is a convex polytope $K$. It can be shown that any vertex of $K$ has at least $n-m$ fully rounded parameters; see Figure 1 for an illustration. Since the number of parameters $n\gg m$, any vertex of $K$ gives an almost fully rounded solution. Obviously this solution satisfies the linear constraints for the samples $s_{1},s_{2},\dots,s_{m}$. But will it generalize to unseen samples from the data distribution $\mathcal{D}_{\textrm{data}}$? We prove that it does if the distribution of gradients $g=\nabla_{w}f(w;s)$ for $s\sim\mathcal{D}_{\textrm{data}}$ is approximately low rank. Let $\Sigma=\mathbb{E}_{s\sim\mathcal{D}_{\textrm{data}}}[gg^{T}]$, where $g=\nabla_{w}f(w;s)$, be the covariance matrix of the gradients. We prove the following theorem; the algorithm and the proof draw on techniques from discrepancy theory, in particular the famous Lovett-Meka algorithm (Lovett and Meka, 2012).

Theorem 1.1 (Informal).

If the eigenvalues of the covariance matrix of gradients decay polynomially fast, then given $m=\mathrm{poly}\left(\frac{\log n}{\varepsilon}\right)$ samples $s_{1},s_{2},\dots,s_{m}\sim\mathcal{D}_{\textrm{data}}$, there is a randomized algorithm to find $\hat{w}$ with $n-m$ weights rounded such that $\mathbb{E}_{s\sim\mathcal{D}_{\textrm{data}}}[|\Delta f|]\leq\varepsilon$.

From these insights we develop a practical rounding algorithm called DiscQuant. The Lovett-Meka algorithm does a random walk starting from the original weights until it converges to a vertex of $K$. Instead, we can find a vertex of $K$ by minimizing a linear function over the convex polytope $K$. DiscQuant uses stochastic gradient descent to minimize two objectives, one corresponding to a low $\Delta f$ and the other corresponding to minimizing a linear function. We take a knowledge distillation approach for the first term, minimizing the KL divergence between the original and quantized model. These two losses are balanced with a regularization parameter $\lambda>0$:

$$\min_{\hat{w}}\ \lambda\left\langle c,\hat{w}\right\rangle+\mathbb{E}_{z\sim\mathcal{D}_{\textrm{data}}}\mathbb{E}_{i}\left[D_{KL}\left(p_{w}(\cdot|z_{<i})\,\|\,p_{\hat{w}}(\cdot|z_{<i})\right)\right]\qquad(1)$$
$$\text{s.t.}\quad\hat{w}\in H.$$

Here $p_{w}(\cdot|z_{<i})$ is the next-token distribution given prefix $z_{<i}$. An astute reader may notice that the first-order approximation of the KL divergence in (1) is exactly zero, and wonder how our discussion above applies. In Section 4, where we describe our exact optimization objective in detail, we also show that the second-order term of the KL divergence can be written as

$$\mathbb{E}_{z\sim\mathcal{D}_{\textrm{data}}}\mathbb{E}_{i}\mathbb{E}_{t\sim p_{w}(\cdot|z_{<i})}\left[\left\langle\nabla_{w}\log p_{w}(t|z_{<i}),\hat{w}-w\right\rangle^{2}\right].$$

So minimizing the KL divergence is a succinct way to impose constraints of the form $\langle\nabla_{w}\log p_{w}(t|z_{<i}),\hat{w}-w\rangle\approx 0$, or equivalently $\log p_{w}(t|z_{<i})\approx\log p_{\hat{w}}(t|z_{<i})$, where $t\sim p_{w}(\cdot|z_{<i})$ and $z\sim\mathcal{D}_{\textrm{data}}$. Therefore our framework still applies.
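For completeness, the vanishing of the first-order term mentioned above follows from a standard one-line calculation (assuming the usual regularity conditions for exchanging the sum over tokens with the gradient, and using that $\sum_{t}p_{\hat{w}}(t|z_{<i})=1$):

$$\nabla_{\hat{w}}\,D_{KL}\big(p_{w}(\cdot|z_{<i})\,\|\,p_{\hat{w}}(\cdot|z_{<i})\big)\Big|_{\hat{w}=w}=-\sum_{t}p_{w}(t|z_{<i})\,\nabla_{\hat{w}}\log p_{\hat{w}}(t|z_{<i})\Big|_{\hat{w}=w}=-\sum_{t}\nabla_{w}p_{w}(t|z_{<i})=-\nabla_{w}\sum_{t}p_{w}(t|z_{<i})=0.$$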

After every step of gradient descent, we project the weights back to the hypercube $H$. This ensures that the trajectory of DiscQuant remains within the convex polytope $K$ and eventually converges to a vertex of $K$ with almost all coordinates rounded. Instead of picking a random direction $c$ to find a random vertex of $K$, we use a special $c^{*}$ which lets us find the vertex closest to the original weights $w$ (see Section 4). We use RTN to round the few unrounded parameters left at the end of the optimization.

We perform extensive experiments which show the strength of our method: on the models Phi-3-mini-4k-instruct and Meta-Llama-3.1-8B-Instruct, across a variety of evaluation tasks, and across the block scaling and incoherence processing quantization formats. DiscQuant is agnostic to the quantization grid and can therefore be composed with other quantization methods. Block scaling sets a bits parameter, which determines the number of grid points, and a scaling parameter shared by every groupsize weights (Frantar et al., 2023). Incoherence processing applies a random orthogonal transformation, which reduces the weight ranges and can make quantization easier (Chee et al., 2023; Tseng et al., 2024a). A subset of results can be found in Figure 2. Across tasks, models, and quantization levels, our method DiscQuant achieves superior compression over the baselines GPTQ and RTN.

Figure 2: Select results quantizing Phi-3-mini-4k-instruct and Meta-Llama-3.1-8B-Instruct using block scaling quantization. GSM8k is a math-based generative task, and WinoGrande and PIQA are multiple choice commonsense reasoning tasks. Error bars are standard errors from lm-evaluation-harness. See Section 5 for full results.

We summarize our main contributions:

  • Theoretical developments: We prove that it is possible to achieve generalization error $\leq\varepsilon$ on the true data distribution by rounding all but $\mathrm{poly}(\log n/\varepsilon)$ weights, so long as the gradients of the original model are approximately low rank.

  • Practical algorithm: We develop a simple and practical algorithm DiscQuant guided by our theoretical analysis. We perform extensive experiments on Phi-3-mini-4k-instruct and Meta-Llama-3.1-8B-Instruct, over block scaling and incoherence processing quantization formats, and a variety of evaluation tasks. Our method DiscQuant achieves superior or comparable quantization to the baselines GPTQ and RTN as can be seen from Figure 2.

2 Related Work

In this paper we focus on weights-only PTQ. Quantization can also be applied to the activations or KV-cache (Ashkboos et al., 2024; Liu et al., 2024a, b). Other compression methods such as pruning (Frantar and Alistarh, 2023; Sun et al., 2023) are also outside the scope of this work. As discussed in the introduction, post-training quantization can be divided into two overall steps: (1) construct a good low bit-complexity representation for the weights (the quantization grid), and (2) round the original weights to values in the quantization grid. To date, the vast majority of PTQ research for LLMs has focused on step (1). Note that determining a good compressed representation can involve both encoding formats as well as transformations to ensure the weights better match the encoding format.

2.1 Quantization Grids

One of the more common quantization formats is called block scaling, or group-wise quantization (Frantar et al., 2023). In addition to the bits parameter, which determines the number of representable points, every group of groupsize parameters shares a scaling parameter. Another successful encoding is to identify a small set of important weights and keep them in high precision (Dettmers et al., 2022, 2024; Kim et al., 2024). Shao et al. (2024) learn the quantization parameters. Other works apply transformations to make quantization easier, either relatively simple invariant scalings (Xiao et al., 2023; Lin et al., 2024) or more complicated random orthogonal transformations (Chee et al., 2023; Liu et al., 2024a). Beyond block scaling, there has been work quantizing multiple parameters together using vector quantization (Tseng et al., 2024a; Egiazarian et al., 2024; van Baalen et al., 2024) or trellis quantization (Tseng et al., 2024b).

2.2 Rounding

To the best of our knowledge, GPTQ (Frantar et al., 2023) is the main rounding method for LLMs. It is based on the Optimal Brain Surgeon (Hassibi et al., 1993), which was adapted for pruning and quantization in Frantar et al. (2022) and then refined for quantization in GPTQ. GPTQ works by minimizing a layer-wise objective $\|WX-\hat{W}X\|_{2}^{2}$, where $W$ is the weight matrix of a linear layer and $X$ is the matrix of input activations to that layer (stacked as columns). Two other LLM rounding methods both use coordinate descent: Nair and Suggala (2024) only has results on the closed source PaLM-2 models with no released code, and Behdin et al. (2023) has results on the OPT, BLOOM, and Falcon model families.
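As a concrete illustration of that layer-wise proxy objective (this is not GPTQ's actual solver, only the quantity it minimizes; the shapes and calibration data below are synthetic), one can compute it as follows:

```python
import torch

def layerwise_error(W: torch.Tensor, W_hat: torch.Tensor, X: torch.Tensor) -> torch.Tensor:
    """Proxy objective ||W X - W_hat X||_F^2 for one linear layer.

    W, W_hat: (out_features, in_features) original and quantized weights.
    X: (in_features, num_calibration_tokens) input activations stacked as columns.
    """
    return ((W - W_hat) @ X).pow(2).sum()

# Toy example with hypothetical shapes and a naive round-to-nearest baseline.
W = torch.randn(128, 256)
X = torch.randn(256, 1024)
W_rtn = (W * 8).round() / 8  # RTN on a uniform grid with spacing 1/8
print(layerwise_error(W, W_rtn, X).item())
```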

There was more work on rounding methods several years ago, before the LLM boom; these papers typically targeted smaller vision models. This line of work was started by AdaRound (Nagel et al., 2020) and continued with AdaQuant (Hubara et al., 2021) and BRECQ (Li et al., 2021). These methods employ an approach similar to ours: they optimize interpolation variables between the closest up ($w^{\textrm{up}}$) and down ($w^{\textrm{down}}$) quantization grid points, adding a concave regularization term to encourage rounding and using a rectified sigmoid to interpolate between $w^{\textrm{up}}$ and $w^{\textrm{down}}$. They also round layer by layer. In contrast, our method uses a linear regularizer inspired by our theoretical insights from discrepancy theory, uses simple linear interpolation between $w^{\textrm{up}}$ and $w^{\textrm{down}}$, and rounds the entire model at once.

Figure 3: First order approximation of the error function $\Delta f$ when quantizing the model to 4.25 bits using RTN and DiscQuant. Here $f$ is the per-token loss function and $s$ is sampled from the WikiText-2 dataset.
Figure 4: Eigenvalues of the covariance matrix of the gradients of pre-trained models. The covariance matrix is estimated by averaging over 8k sample gradients from RedPajama-1T-Sample and projecting them to 2048 dimensions using Johnson-Lindenstrauss projections.

2.3 Discrepancy Theory

Discrepancy theory is a deep branch of mathematics and theoretical computer science, and we refer the reader to standard textbooks for more details (Matousek, 2009; Chazelle et al., 2004; Bansal, 2022). To our knowledge, only Lybrand and Saab (2021) make the connection between discrepancy theory and quantization. However, besides the high-level motivational similarities, their work is not directly relevant to ours. Lybrand and Saab (2021) reduce the problem of understanding the error introduced by quantization on the output of a single neuron to a problem in discrepancy, and construct an algorithm for quantizing a single neuron. Their theoretical analysis of the generalization error only applies to quantizing the first layer of a neural network. On the other hand, we use discrepancy theory to understand when the whole network $f(w;s)$ can be approximated by $f(\hat{w};s)$ with $\hat{w}$ in the quantization grid, and our theory holds for any network as a whole as long as our assumptions are true.

3 Connections to Discrepancy Theory

Model $\|\mathbb{E}(g)\|^{2}$ $\mathbb{E}\|g\|^{2}$
Phi3-mini-128k 0.1021 4.7812
Llama3.1-8B 1.6328 107
Table 1: $\|\mathbb{E}(g)\|^{2}$ vs $\mathbb{E}\|g\|^{2}$ over 8192 samples from the RedPajama-1T-Sample dataset with window size 2048.

Let $f(w;s)$ be the loss function of a pre-trained neural network with weights $w\in\mathbb{R}^{n}$ on an input sample $s$ and let $\mathcal{D}_{\textrm{data}}$ be the sample data distribution. Suppose we are also given a (scalar) quantization grid $\mathcal{Q}=Q_{1}\times Q_{2}\times\dots\times Q_{n}$ where $Q_{j}\subset\mathbb{R}$ is a finite set of quantization points available to quantize the $j^{th}$ parameter. (The quantization grid $\mathcal{Q}$ can depend on $w$, as in block scaling (Frantar et al., 2023); ideally we would write $\mathcal{Q}_{w}$, but we ignore this dependence to simplify notation.) In this work, we focus on scalar quantization, which allows us to write the quantization grid as a product set, i.e., each parameter can be independently rounded to a finite set of available values. Alternatively, in vector quantization a group of $d$ variables is rounded together to one of a finite set of quantization points in $\mathbb{R}^{d}$, which has been used in some prior works (Tseng et al., 2024a; Egiazarian et al., 2024; van Baalen et al., 2024). Generalizing our method to vector quantizers is an interesting future research direction.

Our goal is to find a rounding $\hat{w}\in\mathcal{Q}$ of the original weights $w$ such that $f(\hat{w};s)\approx f(w;s)$ where $s\sim\mathcal{D}_{\textrm{data}}$. We further impose the constraint that each parameter $w_{j}$ is only rounded up or down to the available values in $Q_{j}$, i.e., we only have two choices for $\hat{w}_{j}$, denoted by $w^{\textrm{up}}_{j},w^{\textrm{down}}_{j}\in Q_{j}$ with $w^{\textrm{down}}_{j}\leq w_{j}\leq w^{\textrm{up}}_{j}$. (If $w_{j}<\min Q_{j}$ or $w_{j}>\max Q_{j}$, we just set $w^{\textrm{up}}_{j}=w^{\textrm{down}}_{j}=\min Q_{j}$ or $\max Q_{j}$, respectively.) We make this assumption because we do not want to change any parameter of the original model too much during quantization; we consider this an important property of the algorithms we design. Using Taylor expansion:

$$\Delta f=f(\hat{w};s)-f(w;s)=\left\langle\nabla_{w}f(w;s),\hat{w}-w\right\rangle+(\hat{w}-w)^{T}\nabla_{w}^{2}f(w;s)(\hat{w}-w)+\cdots\qquad(2)$$

Assuming that the quantization grid $\mathcal{Q}$ is fine enough, and since we only round each parameter up or down, $\lVert\hat{w}-w\rVert$ is small and we can ignore the higher order terms. We claim that the first order term is the dominant term. Prior works such as Nagel et al. (2020); Hassibi et al. (1993); LeCun et al. (1989) assumed that the first order term can be taken to be zero because the model is trained to convergence, and focused on reducing the second order term. But the model being trained to convergence only means that the average gradient over many samples from the distribution is nearly zero; the gradients still have some variance, and gradients w.r.t. individual samples from the data distribution are not approximately zero (see Table 1). Figure 3 demonstrates this by showing that the error term $\Delta f$ is well-correlated with the first order approximation $\left\langle\nabla_{w}f(w;s),\hat{w}-w\right\rangle$. (In the special case when $f$ is the KL distillation loss between the original and quantized model, the first order term vanishes exactly; see Section 4 for why this analysis still applies.)
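As a minimal sanity check of this claim (a toy stand-in for the experiment behind Figure 3, using a small random network and random up/down roundings instead of an LLM), one can compare $\Delta f$ against its first-order approximation:

```python
import copy
import torch
import torch.nn.functional as F

torch.manual_seed(0)
model = torch.nn.Sequential(torch.nn.Linear(32, 64), torch.nn.ReLU(), torch.nn.Linear(64, 10))
x, y = torch.randn(8, 32), torch.randint(0, 10, (8,))

def loss_of(m):
    return F.cross_entropy(m(x), y)

# Gradient of the loss at the original weights w.
loss_w = loss_of(model)
grads = torch.autograd.grad(loss_w, list(model.parameters()))

# Randomly round each weight up or down on a uniform grid with spacing delta.
delta = 1 / 32
quant = copy.deepcopy(model)
with torch.no_grad():
    first_order = 0.0
    for p, p_hat, g in zip(model.parameters(), quant.parameters(), grads):
        down = torch.floor(p / delta) * delta
        p_hat.copy_(torch.where(torch.rand_like(p) < 0.5, down, down + delta))
        first_order += (g * (p_hat - p)).sum()

delta_f = loss_of(quant) - loss_w
print(f"actual Delta f = {delta_f.item():.4f}, first-order approx = {first_order.item():.4f}")
```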

So the goal now is to find a rounding $\hat{w}$ such that $\left\langle\nabla_{w}f(w;s),\hat{w}-w\right\rangle\approx 0$ for samples $s\sim\mathcal{D}_{\textrm{data}}$. Suppose we draw $m$ samples $s_{1},s_{2},\dots,s_{m}\sim\mathcal{D}_{\textrm{data}}$ independently from the data distribution, where $m\ll n$. We break our task into two parts, bounding the empirical error and the generalization error, as follows:

Question 3.1.

Can we find $\hat{w}\in\mathcal{Q}$ (with $\hat{w}_{j}\in\{w^{\textrm{down}}_{j},w^{\textrm{up}}_{j}\}$) such that $\left\langle\nabla_{w}f(w;s_{i}),\hat{w}-w\right\rangle\approx 0$ for all the samples $s_{1},\dots,s_{m}$?

Question 3.2.

Once we find such a $\hat{w}$, will it generalize to the true data distribution, i.e., will $\left\langle\nabla_{w}f(w;s),\hat{w}-w\right\rangle\approx 0$ for $s\sim\mathcal{D}_{\textrm{data}}$? How many samples $m$ do we need for this?

3.1 Bounding empirical error (Question 3.1)

For simplicity, let us assume that the quantization grid is uniform and $w^{\textrm{up}}_{i}-w^{\textrm{down}}_{i}=\delta$ for all $i\in[n]$, where $\delta>0$ is the distance between grid points. See Appendix C for how to generalize this to non-uniform grids. We introduce new parameters $x\in[0,1]^{n}$ and define $w^{x}=w^{\textrm{down}}+\delta x$. Note that $w^{x}_{i}$ interpolates between $w^{\textrm{down}}_{i}$ and $w^{\textrm{up}}_{i}$, with $w^{x}_{i}=w^{\textrm{down}}_{i}$ if $x_{i}=0$ and $w^{x}_{i}=w^{\textrm{up}}_{i}$ if $x_{i}=1$. Let $y\in[0,1]^{n}$ be the interpolation point corresponding to the original weights, i.e., $w^{y}=w$. We can rewrite the linear constraints in terms of $x$ as follows:

$$\left\langle\nabla_{w}f(w;s_{i}),w^{x}-w\right\rangle=\left\langle\nabla_{w}f(w;s_{i}),w^{x}-w^{y}\right\rangle=\delta\left\langle\nabla_{w}f(w;s_{i}),x-y\right\rangle.$$

Let $M$ be an $m\times n$ matrix whose $i^{th}$ row is $\nabla_{w}f(w;s_{i})$. Then the linear constraints can simply be written as $M(x-y)=0$. Our goal is to find a fully integral $\hat{x}\in\{0,1\}^{n}$ such that $M(\hat{x}-y)=0$. Let $V=\{x\in\mathbb{R}^{n}:Mx=My\}$, which is an affine subspace of dimension $\geq n-m$. Define $K=[0,1]^{n}\cap V$ as the intersection of the hypercube with this subspace. $K$ is a convex polytope and it is non-empty because $y\in K$. Therefore any vertex of $K$ has at least $n-m$ integral coordinates (i.e., coordinates $j$ such that $x_{j}\in\{0,1\}$). (This is because a vertex requires $n$ tight constraints, and $V$ imposes only $m$ of them, so the remaining $n-m$ tight constraints must come from the hypercube; such points are also called basic feasible solutions in linear programming.)

See Figure 1 for geometric intuition about why this is true. Since the number of parameters $n$ is much larger than the number of samples $m$, any vertex of $K$ is almost fully integral and exactly satisfies all the $m$ linear constraints.
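To make the geometry concrete, here is a toy example (random $M$ and $y$ standing in for sample gradients and original weights) that finds a vertex of $K=[0,1]^{n}\cap\{x:Mx=My\}$ by minimizing a random linear objective with an off-the-shelf LP solver; the returned basic feasible solution has at least $n-m$ coordinates in $\{0,1\}$ and exactly satisfies the linear constraints:

```python
import numpy as np
from scipy.optimize import linprog

rng = np.random.default_rng(0)
n, m = 50, 5                       # n parameters, m linear constraints (n >> m)
M = rng.standard_normal((m, n))    # stand-in for the m sample gradients
y = rng.uniform(0, 1, size=n)      # interpolation point of the original weights

c = rng.standard_normal(n)         # random linear objective -> random vertex of K
res = linprog(c, A_eq=M, b_eq=M @ y, bounds=[(0, 1)] * n, method="highs")
x = res.x

num_integral = np.sum((x < 1e-8) | (x > 1 - 1e-8))
print(f"integral coordinates: {num_integral} / {n}  (guaranteed >= n - m = {n - m})")
print("max constraint violation:", np.abs(M @ (x - y)).max())
```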

Suppose we further ask for a fully integral $\hat{x}$ which approximately satisfies all the $m$ linear constraints. This is precisely the question studied by discrepancy theory, which relates the achievable approximation error to properties of $M$ such as its hereditary discrepancy (Lovász et al., 1986; Bansal, 2022). We do not explore this direction further because the almost integral $\hat{x}$ (a vertex of $K$) is good enough if we apply RTN to the few remaining fractional parameters; we observe that the linear constraints are still all approximately satisfied.

3.2 Bounding Generalization Error (Question 3.2)

How do we bound the generalization error if we know that the empirical approximation error is small? If $\hat{w}-w$ is approximately orthogonal to the $m$ sample gradients $\nabla_{w}f(w;s_{i})$ for $i=1$ to $m$, why should we expect that $\hat{w}-w$ is orthogonal to unseen gradients $\nabla_{w}f(w;s)$ for samples $s\sim\mathcal{D}_{\textrm{data}}$? This can happen only if the gradients are approximately low rank. More precisely, let

$$\Sigma=\mathbb{E}_{s\sim\mathcal{D}_{\textrm{data}}}[gg^{T}]\quad\text{ where }\quad g=\nabla_{w}f(w;s)$$

be the covariance matrix of the distribution of sample gradients and let $\lambda_{1}\geq\lambda_{2}\geq\dots\geq\lambda_{n}$ be its eigenvalues. We observe that the eigenvalues decay very fast; see Figure 4 for empirical validation on some real-world models. We model this by assuming that $\lambda_{k}\leq\lambda_{1}/k^{\alpha}$ for some $\alpha>1$. The assumption that $\alpha>1$ is reasonable since

$$\mathbb{E}_{s}[\lVert g\rVert^{2}]=\mathbb{E}_{s}[{\rm Tr}(gg^{T})]={\rm Tr}(\mathbb{E}_{s}[gg^{T}])={\rm Tr}(\Sigma)=\sum_{i=1}^{n}\lambda_{i}.$$

It is well-known that the gradients of a pretrained model have constant norm on most samples (see Table 1 for empirical validation). Therefore $\sum_{i=1}^{n}\lambda_{i}=O(1)$, and so the decay coefficient $\alpha$ has to be at least 1.
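The sketch below shows one way to estimate this spectrum in the spirit of Figure 4: compute per-sample gradients, apply a Johnson-Lindenstrauss projection so that the covariance matrix stays small, and take an eigendecomposition. The tiny model and random data here are placeholders for an actual LLM and its calibration set.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
model = torch.nn.Sequential(torch.nn.Linear(64, 128), torch.nn.ReLU(), torch.nn.Linear(128, 16))
n = sum(p.numel() for p in model.parameters())
k, num_samples = 256, 512           # JL projection dimension and number of sample gradients
P = torch.randn(k, n) / k ** 0.5    # Johnson-Lindenstrauss projection matrix

cov = torch.zeros(k, k)
for _ in range(num_samples):
    x, y = torch.randn(1, 64), torch.randint(0, 16, (1,))
    loss = F.cross_entropy(model(x), y)
    g = torch.cat([gi.reshape(-1) for gi in torch.autograd.grad(loss, list(model.parameters()))])
    gp = P @ g                       # projected per-sample gradient
    cov += torch.outer(gp, gp) / num_samples

eigvals = torch.linalg.eigvalsh(cov).flip(0)   # eigenvalues in decreasing order
print(eigvals[:10])
```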

Under this assumption, it is reasonable to expect generalization. But it is not at all obvious how to find a generalizing solution; in fact, a deterministic algorithm which chooses one of the vertices of $K$ will most likely not generalize. We give a randomized rounding algorithm (see Algorithm B.2), based on the famous Lovett-Meka algorithm from discrepancy theory (Lovett and Meka, 2012), which finds a vertex of $K$ with low generalization error. The algorithm starts at $y$ and does a random walk (Brownian motion) inside the $(n-m)$-dimensional subspace $V$ formed by the linear constraints imposed by the $m$ samples. Whenever it hits a face $x_{i}=0$ or $x_{i}=1$ of the hypercube, it fixes that variable and continues the random walk until almost all the variables are rounded.

In order to prove rigorous bounds, we also need a mild assumption that the distribution of gradients is well-behaved. Following O'Donnell (2014), we say that, for a parameter $\beta\geq 1$, a random vector $X\in\mathbb{R}^{n}$ is $\beta$-reasonable if

$$\mathbb{E}[\left\langle X,\theta\right\rangle^{4}]\leq\beta\cdot\mathbb{E}[\left\langle X,\theta\right\rangle^{2}]^{2}\quad\forall\theta\in\mathbb{R}^{n}.$$

For example, a uniform $X\sim\{-1,1\}^{n}$ and a Gaussian $X\sim N(\bm{0},\Sigma)$ are both $O(1)$-reasonable. Our main theoretical result (proved in Appendix B) is then:

Theorem 3.3.

Let $\alpha>1$ and $\beta\geq 1$ be constants and let $1\leq m\leq\frac{n}{16}$. Let $\mathcal{D}$ be a $\beta$-reasonable distribution with unknown covariance matrix $\Sigma\in\mathbb{R}^{n\times n}$ whose eigenvalues satisfy $\lambda_{k}\leq\frac{\lambda_{1}}{k^{\alpha}}$ for all $k=1,\ldots,n$. Then there is a randomized polynomial time algorithm that, given $y\in[0,1]^{n}$ and $m$ independent samples $g_{1},\ldots,g_{m}\sim\mathcal{D}$, with high probability produces an $x\in[0,1]^{n}$ such that all but $O(m)$ parameters in $x$ are fully rounded and

$$\mathbb{E}_{g\sim\mathcal{D}}[\left\langle g,x-y\right\rangle^{2}]=(x-y)^{T}\Sigma(x-y)\lesssim_{\alpha,\beta}\lambda_{1}m^{-\min\{1/2,\alpha-1\}}(\log n)^{2}.$$

4 DiscQuant: Algorithm

In this section, we present DiscQuant, a simple and practical rounding algorithm inspired by the theoretical insights in Section 3. Instead of trying to approximate the loss function of the pre-trained model, i.e., $f(\hat{w};s)\approx f(w;s)$, we take a distillation approach and minimize the KL divergence between the next-token distributions of the original and quantized models. Let $p_{w}(\cdot|z_{<i})$ be the distribution of the next token predicted by the original model given prefix $z_{<i}$, where $z\sim\mathcal{D}_{\textrm{data}}$ is a sample from the data distribution. We want $\textrm{error}(\hat{w})=\mathbb{E}_{z\sim\mathcal{D}_{\textrm{data}}}\mathbb{E}_{i}\,D_{KL}\left(p_{w}(\cdot|z_{<i})\,\|\,p_{\hat{w}}(\cdot|z_{<i})\right)\approx 0$.

Expanding $\textrm{error}(\hat{w})$ in a Taylor series, we see that the first order term vanishes exactly, so the second order term is the dominant term (see Appendix D). By Lemma D.1, the Hessian of $\textrm{error}(\hat{w})$ can be written as a covariance of gradients:

$$H_{w}=\mathbb{E}_{z\sim\mathcal{D}_{\textrm{data}}}\mathbb{E}_{i}\mathbb{E}_{t\sim p_{w}(\cdot|z_{<i})}\left[\nabla_{w}\log p_{w}(t|z_{<i})\left(\nabla_{w}\log p_{w}(t|z_{<i})\right)^{T}\right].$$

Therefore

$$\textrm{error}(\hat{w})\approx(\hat{w}-w)^{T}H_{w}(\hat{w}-w)=\mathbb{E}_{z\sim\mathcal{D}_{\textrm{data}}}\mathbb{E}_{i}\mathbb{E}_{t\sim p_{w}(\cdot|z_{<i})}\left[\left\langle\nabla_{w}\log p_{w}(t|z_{<i}),\hat{w}-w\right\rangle^{2}\right].$$

So minimizing $\textrm{error}(\hat{w})$ is a succinct way to impose constraints of the form $\langle\nabla_{w}\log p_{w}(t|z_{<i}),\hat{w}-w\rangle\approx 0$, or equivalently $\log p_{w}(t|z_{<i})\approx\log p_{\hat{w}}(t|z_{<i})$, where $t\sim p_{w}(\cdot|z_{<i})$ and $z\sim\mathcal{D}_{\textrm{data}}$. Therefore we can use the same techniques developed in Section 3 here as well. Assuming that the gradients are low rank, the set of $x$ satisfying these constraints (where $\hat{w}=w^{x}$) forms an affine subspace $V$ of dimension $\geq n-m$, where $m$ is the number of samples. We are again interested in finding a vertex of the polytope $K=[0,1]^{n}\cap V$, which will have $\geq n-m$ integral coordinates. At this point, we could use the Lovett-Meka algorithm (Algorithm B.2), which has provable generalization guarantees, but explicitly calculating and storing all the gradients is infeasible. Instead, a simple heuristic way to find a random vertex of the polytope $K$ is to minimize a random linear function: let $c\in\mathbb{R}^{n}$ be some arbitrary vector; we minimize the linear function $\left\langle c,x\right\rangle$ together with the KL divergence by taking a linear combination of the two. The final optimization objective is shown in (3), where $\lambda>0$ is a regularization coefficient.

$$\min_{x}\ \lambda\left\langle c,x\right\rangle+\mathbb{E}_{z\sim\mathcal{D}_{\textrm{data}}}\mathbb{E}_{i}\left[D_{KL}\left(p_{w}(\cdot|z_{<i})\,\|\,p_{w^{x}}(\cdot|z_{<i})\right)\right]\qquad(3)$$
$$\text{s.t.}\quad x\in[0,1]^{n}.$$

We solve the optimization problem (3) using projected stochastic gradient descent, where we project $x$ back to the hypercube after every gradient update. Optimizing (3) keeps us close to the polytope $K$ and approximately converges to a vertex of $K$ which is almost integral. We round whatever fractional coordinates are left using RTN to get a fully integral solution.

We use one additional heuristic to improve the performance of the algorithm in practice. Instead of choosing a random vertex of the polytope $K$ by picking the vector $c$ at random, we choose it carefully so as to find the vertex of $K$ which is closest to $y$, the interpolation point corresponding to the original model weights (i.e., $y$ such that $w^{y}=w$). We have:

$$\left\lVert x-y\right\rVert^{2}=\sum_{i}(x_{i}^{2}-2x_{i}y_{i}+y_{i}^{2})\approx\sum_{i}(x_{i}-2x_{i}y_{i}+y_{i}^{2})=\left\langle c^{*},x\right\rangle+\left\lVert y\right\rVert^{2}$$

where $c^{*}=(1-2y)$. Here we used the fact that $x_{i}^{2}=x_{i}$ whenever $x_{i}\in\{0,1\}$; since $x$ is almost integral, we can use this approximation in the summation above. With this approximation, minimizing $\left\lVert x-y\right\rVert^{2}$ over almost integral $x$ is equivalent to minimizing $\left\langle c^{*},x\right\rangle$. So in the DiscQuant algorithm, we use $c=c^{*}$ instead of a random $c$.
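Putting the pieces together, here is a minimal, self-contained sketch of the DiscQuant optimization loop from (3) on a toy classifier. The actual implementation (see the linked repository) operates on an LLM's next-token distributions with the hyper-parameters listed in Appendix A; the uniform grid, random "calibration" batches, regularization value, and tiny model below are placeholders.

```python
import torch
import torch.nn.functional as F
from torch.func import functional_call

torch.manual_seed(0)
model = torch.nn.Sequential(torch.nn.Linear(32, 64), torch.nn.ReLU(), torch.nn.Linear(64, 16))
for p in model.parameters():
    p.requires_grad_(False)                       # the original weights w stay frozen

delta = 1 / 16                                    # toy uniform grid spacing (stand-in for block scaling)
names = [n for n, _ in model.named_parameters()]
w = {n: p.detach().clone() for n, p in model.named_parameters()}
w_down = {n: torch.floor(wi / delta) * delta for n, wi in w.items()}
y = {n: (w[n] - w_down[n]) / delta for n in names}        # w = w_down + delta * y
c_star = {n: 1 - 2 * y[n] for n in names}                 # the special direction c* = 1 - 2y
x = {n: y[n].clone().requires_grad_(True) for n in names} # interpolation variables in [0,1]^n

opt = torch.optim.AdamW(list(x.values()), lr=0.05)
lam = 1e-4                                        # toy regularization coefficient

for step in range(300):
    data = torch.randn(16, 32)                    # placeholder calibration batch
    with torch.no_grad():
        p_teacher = F.log_softmax(model(data), dim=-1)
    w_x = {n: w_down[n] + delta * x[n] for n in names}    # quantized-interpolated weights w^x
    p_student = F.log_softmax(functional_call(model, w_x, (data,)), dim=-1)
    kl = F.kl_div(p_student, p_teacher, log_target=True, reduction="batchmean")
    reg = sum((c_star[n] * x[n]).sum() for n in names)
    loss = kl + lam * reg
    opt.zero_grad(); loss.backward(); opt.step()
    with torch.no_grad():                         # project back onto the hypercube [0,1]^n
        for n in names:
            x[n].clamp_(0.0, 1.0)

with torch.no_grad():                             # RTN on the few remaining fractional coordinates
    w_hat = {n: w_down[n] + delta * x[n].round() for n in names}
```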

5 Experiments

We evaluate our method on the Phi-3-mini-4k-instruct (Abdin et al., 2024) and Meta-Llama-3.1-8B-Instruct (Dubey et al., 2024) models, and compare against GPTQ and greedy rounding (i.e., round-to-nearest, or RTN). We use lm-evaluation-harness (Gao et al., 2023) to evaluate on the Wikitext, GSM8k_cot 8-shot, MMLU 5-shot, ARC_Challenge 0-shot, PIQA 0-shot, HellaSwag 0-shot, and WinoGrande 0-shot tasks, and we report its standard errors. Wikitext measures perplexity, GSM8k is a generative task, and the remaining are multiple-choice tasks. Note that generative tasks are typically more difficult than multiple-choice tasks and better reflect how the models are used in practice. See Appendix A for details on the hardware used and the hyper-parameter settings. Our method has similar memory requirements to knowledge distillation, which also requires two copies of the model. We do not perform inference timing experiments; DiscQuant can optimize over a given quantization grid, so we can utilize any pre-existing inference optimizations. For example, there are inference kernels for block scaling (Frantar et al., 2024) and incoherence processing (Tseng et al., 2024a). Ablations on the loss formulation are in Appendix A.

Method Wbits Wiki↓ GSM8k↑ MMLU↑ ArcC↑ PIQA↑ Hella↑ Wino↑
16.0 9.5 84.4±1.0 70.4±0.4 56.7±1.4 80.8±0.9 77.4±0.4 73.5±1.2
RTN 3.0 6.3E5 1.0±0.3 23.3±0.4 26.9±1.3 53.4±1.2 28.2±0.4 48.6±1.4
GPTQ 3.0 28.2 2.3±0.4 37.7±0.4 34.8±1.4 64.3±1.1 56.5±0.5 52.6±1.4
DiscQ 3.0 17.7 26.8±1.2 45.6±0.4 44.1±1.5 73.9±1.0 63.3±0.5 66.6±1.3
RTN 3.25 22.5 31.0±1.3 53.2±0.4 48.4±1.5 72.5±1.0 68.3±0.5 62.6±1.4
GPTQ 3.25 13.8 54.3±1.4 59.0±0.4 49.6±1.5 77.3±1.0 71.1±0.5 66.5±1.3
DiscQ 3.25 12.6 64.2±1.3 60.7±0.4 53.5±1.5 78.7±1.0 72.3±0.4 72.5±1.3
RTN 3.5 18.8 46.3±1.4 57.0±0.4 46.2±1.5 73.8±1.0 70.0±0.5 63.9±1.4
GPTQ 3.5 12.8 54.6±1.4 61.7±0.4 51.6±1.5 78.9±1.0 72.3±0.4 68.3±1.3
DiscQ 3.5 12.0 69.5±1.3 63.0±0.4 51.1±1.5 78.9±1.0 73.0±0.4 73.9±1.2
RTN 4.0 14.6 62.2±1.3 61.2±0.4 53.6±1.5 76.3±1.0 72.9±0.4 65.3±1.3
GPTQ 4.0 11.5 71.5±1.2 65.1±0.4 54.6±1.5 78.8±1.0 74.7±0.4 70.9±1.3
DiscQ 4.0 11.2 77.3±1.2 65.7±0.4 56.8±1.4 79.5±0.9 74.5±0.4 72.0±1.3
RTN 4.25 11.2 64.4±1.3 67.5±0.4 55.5±1.5 79.3±0.9 76.1±0.4 69.1±1.3
GPTQ 4.25 10.3 81.0±1.1 68.5±0.4 56.9±1.4 79.7±0.9 76.1±0.4 72.1±1.3
DiscQ 4.25 10.2 80.7±1.1 68.4±0.4 57.3±1.4 80.7±0.9 76.3±0.4 74.2±1.2
RTN 4.5 10.8 71.6±1.2 67.7±0.4 57.5±1.4 79.3±0.9 76.6±0.4 72.2±1.3
GPTQ 4.5 10.1 82.0±1.1 68.8±0.4 55.8±1.5 80.8±0.9 76.5±0.4 71.8±1.3
DiscQ 4.5 10.0 82.1±1.1 68.5±0.4 56.6±1.4 80.2±0.9 76.7±0.4 74.2±1.2
Table 2: Phi-3-mini-4k-instruct. Across all tasks and bit settings, our method DiscQuant achieves superior or comparable results to the baseline RTN and GPTQ methods. On the ArcC, PIQA, and Wino tasks, DiscQuant achieves full recovery with at least 0.25 fewer bits per parameter than GPTQ and RTN.
Method Wbits Wiki↓ GSM8k↑ MMLU↑ ArcC↑ PIQA↑ Hella↑ Wino↑
16.0 8.7 77.0±1.2 68.0±0.4 55.2±1.5 81.3±0.9 79.3±0.4 73.7±1.2
RTN 3.0 4.4E3 0.5±0.2 23.2±0.4 22.3±1.2 52.4±1.2 29.1±0.5 50.0±1.4
GPTQ 3.0 23.2 3.6±0.5 24.6±0.4 31.8±1.4 66.6±1.1 45.8±0.5 54.1±1.4
DiscQ 3.0 15.2 14.3±1.0 44.6±0.4 39.4±1.4 73.2±1.0 64.4±0.5 62.8±1.4
RTN 3.25 15.2 10.8±0.9 50.5±0.4 44.3±1.5 75.2±1.0 71.4±0.5 67.2±1.3
GPTQ 3.25 10.7 56.3±1.4 60.5±0.4 46.3±1.5 76.7±1.0 74.4±0.4 68.7±1.3
DiscQ 3.25 10.5 58.3±1.4 60.2±0.4 49.1±1.5 79.1±0.9 75.1±0.4 72.1±1.3
RTN 3.5 12.7 35.9±1.3 51.4±0.4 48.4±1.5 76.7±1.0 73.0±0.4 69.1±1.3
GPTQ 3.5 10.4 57.0±1.4 62.1±0.4 49.9±1.5 77.3±1.0 75.1±0.4 71.1±1.3
DiscQ 3.5 10.3 60.7±1.3 60.9±0.4 51.7±1.5 79.2±0.9 76.3±0.4 72.5±1.3
RTN 4.0 12.5 50.8±1.4 59.3±0.4 50.5±1.5 77.6±1.0 74.7±0.4 69.9±1.3
GPTQ 4.0 9.9 63.2±1.3 64.4±0.4 52.4±1.5 78.4±1.0 75.9±0.4 71.7±1.3
DiscQ 4.0 9.8 66.5±1.3 63.4±0.4 51.6±1.5 79.2±0.9 76.9±0.4 72.8±1.3
RTN 4.25 9.4 70.6±1.3 65.7±0.4 54.2±1.5 80.1±0.9 78.0±0.4 73.9±1.2
GPTQ 4.25 9.1 74.6±1.2 66.8±0.4 53.4±1.5 79.6±0.9 77.9±0.4 73.5±1.2
DiscQ 4.25 9.1 74.9±1.2 66.9±0.4 53.6±1.5 79.9±0.9 78.4±0.4 72.6±1.3
RTN 4.5 9.3 71.9±1.2 65.8±0.4 54.8±1.5 80.3±0.9 78.4±0.4 72.4±1.3
GPTQ 4.5 9.0 73.8±1.2 66.9±0.4 53.6±1.5 79.6±0.9 78.1±0.4 73.7±1.2
DiscQ 4.5 9.1 74.8±1.2 66.8±0.4 54.1±1.5 80.6±0.9 78.7±0.4 72.9±1.2
Table 3: Meta-Llama-3.1-8B-Instruct. Our method DiscQuant achieves superior compression on the vast majority of quantization levels and tasks over the baselines GPTQ and RTN.

5.1 Block Scaling

Our first experiments use standard block scaling quantization, determined by a bits and a groupsize parameter. There are $2^{\texttt{bits}}$ unique grid points, and every group of groupsize parameters shares a 16-bit scale parameter. For example, 3.25 bits per parameter is achieved with bits=3, groupsize=64. We use the block scaling implementation from Frantar et al. (2024), which is symmetric linear quantization. Table 2 shows the results quantizing Phi-3-mini-4k-instruct. Across all tasks and all bit settings, our method DiscQuant achieves superior or comparable compression over the baseline GPTQ and RTN methods. The gap between DiscQuant and the baselines is greater at lower bits. On the ARC_Challenge, PIQA, and WinoGrande tasks, DiscQuant achieves full recovery with at least 0.25 fewer bits per parameter than GPTQ and RTN. For example, on ARC_Challenge, DiscQuant achieves full recovery at 4 bits per weight, whereas GPTQ requires 4.25 bits and RTN 4.5 bits. DiscQuant achieves better compression on the more difficult generative GSM8k task: at 4 bits DiscQuant gets 77.3% accuracy, while GPTQ gets 71.5% and RTN gets 62.2%. Table 3 shows the results quantizing Meta-Llama-3.1-8B-Instruct. Overall the story is the same: DiscQuant achieves improved compression on the majority of quantization levels and tasks. For example, at 4 bits DiscQuant gets 66.5% GSM8k accuracy, while GPTQ gets 63.2% and RTN gets 50.8%.
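For reference, here is a minimal sketch of symmetric block-scaling quantization as described above (bits grid points per weight plus one 16-bit scale shared by every groupsize weights); the exact code layout of the kernel-compatible implementation in Frantar et al. (2024) differs in details, so treat this only as an illustration of the grid:

```python
import torch

def block_scaling_rtn(W: torch.Tensor, bits: int = 3, groupsize: int = 64):
    """Symmetric group-wise quantization with round-to-nearest.

    Returns the dequantized weights and the effective bits per parameter
    (bits for the codes plus a 16-bit scale shared by every groupsize weights).
    """
    out, in_ = W.shape
    groups = W.reshape(out, in_ // groupsize, groupsize)
    qmax = 2 ** (bits - 1) - 1                     # symmetric signed range
    scale = groups.abs().amax(dim=-1, keepdim=True) / qmax
    codes = torch.clamp((groups / scale).round(), -qmax - 1, qmax)
    W_hat = (codes * scale).reshape(out, in_)
    eff_bits = bits + 16 / groupsize               # e.g. 3 + 16/64 = 3.25
    return W_hat, eff_bits

W = torch.randn(256, 512)
W_hat, eff_bits = block_scaling_rtn(W, bits=3, groupsize=64)
print(f"effective bits/param = {eff_bits}, max error = {(W - W_hat).abs().max():.4f}")
```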

Figure 5: Quantizing Phi-3-mini-4k-instruct and Meta-LLama-3.1-8B-Instruct with block scaling, and additional incoherence processing. DiscQuant can compose with other quantization improvements, and with incoherence processing remains competitive with GPTQ.

5.2 Incoherence Processing

We explore another quantization format to show that our method can compose with other quantization improvements. Incoherence processing has been shown to improve quantization, especially below 4 bits per weight (Chee et al., 2023). The weights are multiplied by certain random orthogonal matrices prior to quantization, which can reduce the range of the weights and make quantization easier. We employ the Randomized Hadamard Transform from Tseng et al. (2024a) and use the same block scaling quantization grid as in the previous subsection. A subset of our results is shown in Figure 5, where we superimpose bar plots for block scaling and block scaling + incoherence processing. In the majority of cases, adding incoherence processing increases task accuracy, especially at lower bits. We do not use fractional bits (i.e., no groupsize), because both techniques affect outliers and can interfere with one another. Incoherence especially helps GPTQ at 3 bits, and for Phi-3, DiscQuant without incoherence is competitive with GPTQ with incoherence. For full results see Appendix A.
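The sketch below illustrates the incoherence-processing idea with a randomized Hadamard transform applied to the input dimension of one weight matrix; it is a simplified toy version (no blocking, no fused kernels, a single synthetic outlier) rather than the pipeline of Tseng et al. (2024a):

```python
import torch
from scipy.linalg import hadamard

def randomized_hadamard(W: torch.Tensor, seed: int = 0):
    """Multiply W on the input side by a random orthogonal matrix Q = diag(s) H / sqrt(n),
    where H is a Hadamard matrix and s is a random sign vector (n must be a power of two)."""
    n = W.shape[1]
    g = torch.Generator().manual_seed(seed)
    signs = (torch.randint(0, 2, (n,), generator=g) * 2 - 1).float()
    H = torch.from_numpy(hadamard(n).astype("float32")) / n ** 0.5
    Q = torch.diag(signs) @ H
    return W @ Q, Q

W = torch.randn(128, 256)
W[0, 0] = 50.0                       # a single large outlier weight
W_rot, Q = randomized_hadamard(W)
print("max |w| before:", W.abs().max().item(), " after:", W_rot.abs().max().item())
# At inference time the quantized W_rot is used with Q applied to the incoming activations,
# since (W Q)(Q^T a) = W a for orthogonal Q.
```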

Figure 6: Effect of increasing the fraction of math data when quantizing Phi-3-mini-4k-instruct at 3.25 bits. Out of 8192 total samples, a fraction is math subject data (GSM8k & MetaMathQA) and the remainder is our standard RedPajama data. As expected, performance on GSM8k increases with more math data. Expected behavior on the other tasks is unclear.

5.3 Effect of Data

We perform a simple investigation into the effect of the dataset on quantization. We mix math subject data (GSM8k and MetaMathQA) with our standard RedPajama dataset. Figure 6 shows the results of quantizing Phi-3-mini-4k-instruct at 3.25 bits with such a mix. As expected, both methods increase accuracy on GSM8k when there is a greater fraction of math data. On HellaSwag, DiscQuant improves with more math data, whereas GPTQ gets worse. On PIQA, both methods get worse. See Appendix A for all tasks. There is a meaningful change in accuracy as a result of changing the data mix. Choosing an appropriate data mix for quantization remains an important open question.

References

  • Abdin et al. (2024) Marah Abdin, Jyoti Aneja, Hany Awadalla, Ahmed Awadallah, Ammar Ahmad Awan, Nguyen Bach, Amit Bahree, Arash Bakhtiari, Jianmin Bao, and Harkirat Behl. Phi-3 technical report: A highly capable language model locally on your phone, 2024. URL https://arxiv.org/abs/2404.14219.
  • Ashkboos et al. (2024) Saleh Ashkboos, Amirkeivan Mohtashami, Maximilian L Croci, Bo Li, Martin Jaggi, Dan Alistarh, Torsten Hoefler, and James Hensman. Quarot: Outlier-free 4-bit inference in rotated llms. In Thirty-eighth Conference on Neural Information Processing Systems, 2024.
  • Bansal (2022) Nikhil Bansal. Discrepancy theory and related algorithms. In Proc. Int. Cong. Math, volume 7, pages 5178–5210, 2022.
  • Behdin et al. (2023) Kayhan Behdin, Ayan Acharya, Aman Gupta, Sathiya Keerthi, Rahul Mazumder, Zhu Siyu, and Song Qingquan. Quantease: Optimization-based quantization for language models–an efficient and intuitive algorithm. arXiv preprint arXiv:2309.01885, 2023.
  • Chazelle et al. (2004) Bernard Chazelle, William WL Chen, and Anand Srivastav. Discrepancy theory and its applications. Oberwolfach Reports, 1(1):673–722, 2004.
  • Chee et al. (2023) Jerry Chee, Yaohui Cai, Volodymyr Kuleshov, and Christopher De Sa. QuIP: 2-bit quantization of large language models with guarantees. In Thirty-seventh Conference on Neural Information Processing Systems, 2023. URL https://openreview.net/forum?id=xrk9g5vcXR.
  • Computer (2023) Together Computer. Redpajama: An open source recipe to reproduce llama training dataset, 2023. URL https://github.com/togethercomputer/RedPajama-Data.
  • Dettmers et al. (2022) Tim Dettmers, Mike Lewis, Younes Belkada, and Luke Zettlemoyer. Llm.int8(): 8-bit matrix multiplication for transformers at scale. In Advances in Neural Information Processing Systems, 2022.
  • Dettmers et al. (2024) Tim Dettmers, Ruslan Svirschevski, Vage Egiazarian, Denis Kuznedelev, Elias Frantar, Saleh Ashkboos, Alexander Borzunov, Torsten Hoefler, and Dan Alistarh. Spqr: A sparse-quantized representation for near-lossless llm weight compression, 2024.
  • Dubey et al. (2024) Abhimanyu Dubey, Abhinav Jauhri, Abhinav Pandey, Abhishek Kadian, Ahmad Al-Dahle, Aiesha Letman, Akhil Mathur, Alan Schelten, Amy Yang, Angela Fan, et al. The llama 3 herd of models, 2024. URL https://arxiv.org/abs/2407.21783.
  • Egiazarian et al. (2024) Vage Egiazarian, Andrei Panferov, Denis Kuznedelev, Elias Frantar, Artem Babenko, and Dan Alistarh. Extreme compression of large language models via additive quantization. In Forty-First International Conference on Machine Learning, 2024.
  • Frantar and Alistarh (2023) Elias Frantar and Dan Alistarh. Sparsegpt: Massive language models can be accurately pruned in one-shot. In Proceedings of the International Conference on Machine Learning, 2023.
  • Frantar et al. (2022) Elias Frantar, Sidak Pal Singh, and Dan Alistarh. Optimal brain compression: A framework for accurate post-training quantization and pruning. In Advances in Neural Information Processing Systems, 2022.
  • Frantar et al. (2023) Elias Frantar, Saleh Ashkboos, Torsten Hoefler, and Dan Alistarh. OPTQ: Accurate quantization for generative pre-trained transformers. In The Eleventh International Conference on Learning Representations, 2023.
  • Frantar et al. (2024) Elias Frantar, Roberto L Castro, Jiale Chen, Torsten Hoefler, and Dan Alistarh. Marlin: Mixed-precision auto-regressive parallel inference on large language models. arXiv preprint arXiv:2408.11743, 2024.
  • Gao et al. (2023) Leo Gao, Jonathan Tow, Baber Abbasi, Stella Biderman, Sid Black, Anthony DiPofi, Charles Foster, Laurence Golding, Jeffrey Hsu, Alain Le Noac’h, Haonan Li, Kyle McDonell, Niklas Muennighoff, Chris Ociepa, Jason Phang, Laria Reynolds, Hailey Schoelkopf, Aviya Skowron, Lintang Sutawika, Eric Tang, Anish Thite, Ben Wang, Kevin Wang, and Andy Zou. A framework for few-shot language model evaluation, 12 2023. URL https://zenodo.org/records/10256836.
  • Hassibi et al. (1993) Babak Hassibi, David G Stork, and Gregory J Wolff. Optimal brain surgeon and general network pruning. In IEEE International Conference on Neural Networks, 1993.
  • Hubara et al. (2021) Itay Hubara, Yury Nahshan, Yair Hanani, Ron Banner, and Daniel Soudry. Accurate post training quantization with small calibration sets. In Thirty-Eighth International Conference on Machine Learning, 2021.
  • Kim et al. (2024) Sehoon Kim, Coleman Hooper, Amir Gholami, Zhen Dong, Xiuyu Li, Sheng Shen, Michael Mahoney, and Kurt Keutzer. Squeezellm: Dense-and-sparse quantization. In Forty-First International Conference on Machine Learning, 2024.
  • Kurtic et al. (2023) Eldar Kurtic, Denis Kuznedelev, Elias Frantar, Michael Goin, and Dan Alistarh. Sparse fine-tuning for inference acceleration of large language models, 2023. URL https://arxiv.org/abs/2310.06927.
  • Kwon et al. (2023) Woosuk Kwon, Zhuohan Li, Siyuan Zhuang, Ying Sheng, Lianmin Zheng, Cody Hao Yu, Joseph Gonzalez, Hao Zhang, and Ion Stoica. Efficient memory management for large language model serving with pagedattention. In Proceedings of the 29th Symposium on Operating Systems Principles, pages 611–626, 2023.
  • LeCun et al. (1989) Yann LeCun, John Denker, and Sara Solla. Optimal brain damage. Advances in neural information processing systems, 2, 1989.
  • Li et al. (2021) Yuhang Li, Ruihao Gong, Xu Tan, Yang Yang, Peng Hu, Qi Zhang, Fengwei Yu, Wei Wang, and Shi Gu. Brecq: Pushing the limit of post-training quantization by block reconstruction. In The Ninth International Conference on Learning Representations, 2021.
  • Lin et al. (2024) Ji Lin, Jiaming Tang, Haotian Tang, Shang Yang, Wei-Ming Chen, Wei-Chen Wang, Guangxuan Xiao, Xingyu Dang, Chuang Gan, and Song Han. Awq: Activation-aware weight quantization for on-device llm compression and acceleration. In Seventh Conference on Machine Learning and Systems, 2024.
  • Liu et al. (2024a) Zechun Liu, Changsheng Zhao, Igor Fedorov, Bilge Soran, Dhruv Choudhary, Raghuraman Krishnamoorthi, Vikas Chandra, Yuandong Tian, and Tijmen Blankevoort. Spinquant–llm quantization with learned rotations. arXiv preprint arXiv:2405.16406, 2024a.
  • Liu et al. (2024b) Zirui Liu, Jiayi Yuan, Hongye Jin, Shaochen Zhong, Zhaozhuo Xu, Vladimir Braverman, Beidi Chen, and Xia Hu. Kivi: A tuning-free asymmetric 2bit quantization for kv cache. In Forty-First International Conference on Machine Learning, 2024b.
  • Loshchilov and Hutter (2019) Ilya Loshchilov and Frank Hutter. Decoupled weight decay regularization. The International Conference on Learning Representations, 2019.
  • Lovász et al. (1986) László Lovász, Joel Spencer, and Katalin Vesztergombi. Discrepancy of set-systems and matrices. European Journal of Combinatorics, 7(2):151–160, 1986.
  • Lovett and Meka (2012) Shachar Lovett and Raghu Meka. Constructive discrepancy minimization by walking on the edges. In FOCS, pages 61–67. IEEE Computer Society, 2012.
  • Lybrand and Saab (2021) Eric Lybrand and Rayan Saab. A greedy algorithm for quantizing neural networks. Journal of Machine Learning Research, 22(156):1–38, 2021.
  • Matousek (2009) Jiri Matousek. Geometric discrepancy: An illustrated guide, volume 18. Springer Science & Business Media, 2009.
  • Nagel et al. (2020) Markus Nagel, Rana Ali Amjad, Mart Van Baalen, Christos Louizos, and Tijmen Blankevoort. Up or down? Adaptive rounding for post-training quantization. In Hal Daumé III and Aarti Singh, editors, Proceedings of the 37th International Conference on Machine Learning, volume 119 of Proceedings of Machine Learning Research, pages 7197–7206. PMLR, 13–18 Jul 2020. URL https://proceedings.mlr.press/v119/nagel20a.html.
  • Nair and Suggala (2024) Pranav Ajit Nair and Arun Sai Suggala. Cdquant: Accurate post-training weight quantization of large pre-trained models using greedy coordinate descent, 2024. URL https://arxiv.org/abs/2406.17542.
  • O’Donnell (2014) Ryan O’Donnell. Analysis of Boolean Functions. Cambridge University Press, 2014.
  • Shao et al. (2024) Wenqi Shao, Mengzhao Chen, Zhaoyang Zhang, Peng Xu, Lirui Zhao, Zhiqian Li, Kaipeng Zhang, Peng Gao, Yu Qiao, and Ping Luo. Omniquant: Omnidirectionally calibrated quantization for large language models. In The Twelfth International Conference on Learning Representations, 2024. URL https://openreview.net/forum?id=8Wuvhh0LYW.
  • Sun et al. (2023) Mingjie Sun, Zhuang Liu, Anna Bair, and J Zico Kolter. A simple and effective pruning approach for large language models. In Workshop on Efficient Systems for Foundation Models @ ICML2023, 2023. URL https://openreview.net/forum?id=tz9JV2PRSv.
  • Tseng et al. (2024a) Albert Tseng, Jerry Chee, Qingyao Sun, Volodymyr Kuleshov, and Christopher De Sa. QuIP#: Even better llm quantization with hadamard incoherence and lattice codebooks. In Forty-First International Conference on Machine Learning, 2024a.
  • Tseng et al. (2024b) Albert Tseng, Qingyao Sun, David Hou, and Christopher De Sa. QTIP: Quantization with trellises and incoherence processing. In Advances in Neural Information Processing Systems, 2024b.
  • van Baalen et al. (2024) Mart van Baalen, Andrey Kuzmin, Markus Nagel, Peter Couperus, Cedric Bastoul, Eric Mahurin, Tijmen Blankevoort, and Paul Whatmough. Gptvq: The blessing of dimensionality in llm quantization. arXiv preprint arXiv:2402.15319, 2024.
  • Xiao et al. (2023) Guangxuan Xiao, Ji Lin, Mickael Seznec, Hao Wu, Julien Demouth, and Song Han. Smoothquant: Accurate and efficient post-training quantization for large language models. In Fortieth International Conference on Machine Learning, 2023.

Appendix A Additional Experiments

Method Wbits Wiki↓ GSM8k↑ MMLU↑ ArcC↑ PIQA↑ Hella↑ Wino↑
16.0 9.5 84.4±1.0 70.4±0.4 56.7±1.4 80.8±0.9 77.4±0.4 73.5±1.2
RTN 3.0 2.6E5 0.0±0.0 23.4±0.4 28.5±1.3 49.8±1.2 26.0±0.4 50.2±1.4
GPTQ 3.0 20.8 10.0±0.8 43.8±0.4 39.2±1.4 70.5±1.1 60.7±0.5 58.0±1.4
DiscQ 3.0 16.7 29.9±1.3 48.0±0.4 46.2±1.5 75.1±1.0 64.5±0.5 66.6±1.3
RTN 4.0 15.9 56.3±1.4 55.0±0.4 53.5±1.5 77.4±1.0 68.8±0.5 70.6±1.3
GPTQ 4.0 11.0 77.6±1.1 65.8±0.4 53.7±1.5 80.2±0.9 74.9±0.4 72.5±1.3
DiscQ 4.0 11.0 76.7±1.2 65.6±0.4 56.0±1.5 79.5±0.9 74.9±0.4 74.2±1.2
Table 4: Phi-3-mini-4k-instruct with incoherence processing. At 3 bits per weight, DiscQuant achieves superior compression across all tasks. At 4 bits per weight, DiscQuant achieves comparable compression.
Method Wbits Wiki↓ GSM8k↑ MMLU↑ ArcC↑ PIQA↑ Hella↑ Wino↑
16.0 8.7 77.0±1.2 68.0±0.4 55.2±1.5 81.3±0.9 79.3±0.4 73.7±1.2
RTN 3.0 2.4E3 2.1±0.4 25.2±0.4 24.3±1.3 54.7±1.2 29.5±0.5 49.6±1.4
GPTQ 3.0 13.9 24.4±1.2 49.7±0.4 41.7±1.4 73.1±1.0 70.4±0.5 66.2±1.3
DiscQ 3.0 13.4 25.4±1.2 51.5±0.4 40.4±1.4 73.2±1.0 69.6±0.5 64.2±1.3
RTN 4.0 11.2 51.6±1.4 59.5±0.4 50.1±1.5 78.9±1.0 74.5±0.4 71.0±1.3
GPTQ 4.0 9.5 70.7±1.3 64.9±0.4 52.8±1.5 80.0±0.9 77.4±0.4 72.7±1.3
DiscQ 4.0 9.6 69.4±1.3 63.7±0.4 54.1±1.5 80.7±0.9 77.0±0.4 73.2±1.2
Table 5: Meta-Llama-3.1-8B-Instruct with incoherence processing. Across a majority of bits and tasks, DiscQuant achieves comparable compression with GPTQ, and does better than RTN.

A.1 Experimental Setup Details

The experiments for the Phi-3-mini model were conducted on either a single 80GB Nvidia A100 GPU or 2x40GB A100 GPUs, while the Llama-3.1-8B model used either 2x80GB A100s or 4x40GB A100s. We use the PyTorch framework. We initialize $x\in[0,1]^{n}$ uniformly at random, and use AdamW (Loshchilov and Hutter, 2019) with a cosine learning rate schedule. We multiply the regularization coefficient $\lambda$ with the KL loss term, and perform entry-wise gradient clipping on the KL loss term. For DiscQuant, we tuned the hyper-parameters for each model and bit setting. The hyper-parameters clamp, $\lambda$, lr, batch_size, num_iter and warmup were tuned. In the block scaling setting we found that clamp={1.0, 0.5}, $\lambda$=200, lr={0.1, 0.05}, batch_size={4, 8}, num_iter=1024, warmup=128 worked well for both models. In the incoherence processing setting we found that clamp={0.05, 0.01}, lr={0.05, 0.01} worked well for both models, all other parameters being the same as before. For GPTQ, we used the actorder and true_sequential heuristics, and tuned the number of samples over {1024, 4096, 8192} for each model and bit setting. Our quantization dataset is constructed from the RedPajama-1T-Sample training set (Computer, 2023). We concatenate random samples up to a sequence length of 2048, truncating the last sample if necessary. Greedy or round-to-nearest requires no data and no hyper-parameter tuning.
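A hedged sketch of the optimizer setup described above (AdamW with linear warmup into a cosine schedule, entry-wise gradient clipping, and projection back onto $[0,1]^{n}$); the variable names are illustrative, `x_params` stands in for the interpolation variables $x$, and the loss is a placeholder for the actual KL distillation term:

```python
import math
import torch

# Hypothetical handle: x_params are the interpolation variables x, initialized uniformly in [0,1].
x_params = [torch.rand(4096, requires_grad=True)]

lr, num_iter, warmup, clamp = 0.1, 1024, 128, 1.0
opt = torch.optim.AdamW(x_params, lr=lr)
sched = torch.optim.lr_scheduler.LambdaLR(
    opt, lambda s: s / warmup if s < warmup
    else 0.5 * (1 + math.cos(math.pi * (s - warmup) / (num_iter - warmup))))

for step in range(num_iter):
    kl_loss = (x_params[0] ** 2).mean()           # placeholder for the KL distillation term
    kl_loss.backward()
    for p in x_params:
        p.grad.clamp_(-clamp, clamp)              # entry-wise gradient clipping on the KL gradient
    opt.step(); sched.step(); opt.zero_grad()
    with torch.no_grad():
        for p in x_params:
            p.clamp_(0.0, 1.0)                    # projection back onto the hypercube
```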

Figure 7: Quantizing Phi-3-mini-4k-instruct with block scaling, and additional incoherence processing. Adding incoherence processing largely improves model quality at 3 bits. At 4 bits, these improvements are smaller. At 3 bits, DiscQuant is better than GPTQ with incoherence processing.
Figure 8: Quantizing Meta-Llama-3.1-8B-Instruct with block scaling, and additional incoherence processing. Adding incoherence processing largely improves model quality at 3 bits. At 4 bits, these improvements are smaller. After incoherence, DiscQuant is largely comparable to GPTQ.

A.2 Incoherence Processing

Table 4 shows our results quantizing Phi-3-mini-4k-instruct with incoherence processing. At 3 bits per weight, DiscQuant achieves superior compression across all tasks. At 4 bits per weight, DiscQuant achieves comparable compression. For example, on ARC_Challenge at 3 bits, DiscQuant achieves 46.2% accuracy, while GPTQ achieves 39.2%, and RTN 28.5%. Table 5 shows our results quantizing Meta-Llama-3.1-8B-Instruct with incoherence processing. DiscQuant performs comparably to GPTQ, and better than RTN. For example, on WinoGrande at 4 bits, DiscQuant achieves 73.2% accuracy, while GPTQ achieves 72.7%, and RTN 71.0%.

Figures 7 and 8 show the results adding incoherence processing superimposed over just using block scaling. Incoherence processing largely improves quantization at 3 bits across both models, whereas at 4 bits the improvements are smaller. In the Phi-3 model at 3 bits, DiscQuant without incoherence is better than GPTQ with incoherence. Across the other models and bit settings, DiscQuant and GPTQ are comparable after incoherence processing.

Figure 9: Effect of increasing the fraction of math data when quantizing Phi-3-mini-4k-instruct at 3.25 bits. Out of 8192 total samples, a fraction comes from math subject data (GSM8k and MetaMathQA) and the remainder from our standard RedPajama data. Across all evaluations, changing the data mix produces a meaningful change in the metrics.
KL Coeff Intermed Coeff Intermed Type Wiki\downarrow GSM8k\uparrow
1.0 0.0 None 12.8 64.9±\pm1.3
0.0 1.0 Layer 14.7 54.1±\pm1.4
0.0 1.0 Linear 14.3 60.1±\pm1.4
0.1 0.9 Linear 13.1 61.4±\pm1.3
0.5 0.5 Linear 12.9 63.9±\pm1.3
0.9 0.1 Linear 12.8 63.8±\pm1.3
Table 6: Distillation ablations. We quantize Phi-3-mini-4k-instruct to 3.25 bits using a reduced set of 1024 RedPajama samples. We test affine combinations of the KL divergence loss and an intermediate L2 loss computed either between decoder layers or between linear layers. The standard KL divergence does best.

A.3 Effect of Data

Here we give the full set of evaluation tasks when changing the mix of math subject data while quantizing Phi-3-mini-4k-instruct to 3.25 bits. Interestingly, across all evaluation tasks there is a meaningful change in the metrics as a result of changing the data mix. We leave appropriate data curation as an important open question.

A.4 Ablations

We tried several distillation formulations, but ultimately chose a standard KL divergence between the outputs of the original and quantized models as the best approach; see Table 6. We quantize Phi-3-mini-4k-instruct to 3.25 bits using 1024 samples (fewer than in our main experiments) and tune the hyper-parameters as described at the beginning of this section. In addition to the standard KL divergence, we tried several intermediate loss formulations for knowledge distillation: a normalized L2 loss between the outputs of the teacher and student, computed either per decoder layer (Intermed Type = Layer) or between each linear layer (Intermed Type = Linear). This formulation was presented in Kurtic et al. (2023) for recovering LLMs after pruning. We also investigated affine combinations of the KL and intermediate losses with several different coefficients. As Table 6 shows, using just the KL divergence gives the best results. Finally, we tried minimizing the ground truth loss instead of a distillation loss in the same setup; it achieves 52.7% GSM8k accuracy and 13.6 Wikitext perplexity. Therefore we use the KL divergence.
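For reference, a minimal sketch of the distillation loss we settled on (the tensor names and shapes are illustrative; this is not our exact implementation):

```python
import torch
import torch.nn.functional as F

def kl_distillation_loss(teacher_logits: torch.Tensor,
                         student_logits: torch.Tensor) -> torch.Tensor:
    """KL(p_teacher || p_student) averaged over all token positions.

    Both tensors have shape (batch, seq_len, vocab_size): the teacher is the
    original model, the student the quantized/interpolated one.
    """
    teacher_logp = F.log_softmax(teacher_logits, dim=-1)
    student_logp = F.log_softmax(student_logits, dim=-1)
    kl_per_token = (teacher_logp.exp() * (teacher_logp - student_logp)).sum(dim=-1)
    return kl_per_token.mean()
```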

Appendix B Rounding weights via Discrepancy Theory

B.1 The Lovett-Meka algorithm

A seminal result of Lovett and Meka (2012) works as follows: we are given a point y[0,1]ny\in[0,1]^{n} in the hypercube, vectors v1,,vmnv_{1},\ldots,v_{m}\in\mathbb{R}^{n} with vj2=1\|v_{j}\|_{2}=1 and parameters cj0c_{j}\geq 0 so that j=1mecj2/16n16\sum_{j=1}^{m}e^{-c_{j}^{2}/16}\leq\frac{n}{16}. Then in randomized polynomial time one can find a point x[0,1]nx\in[0,1]^{n} so that |vj,xy|cj|\left<v_{j},x-y\right>|\leq c_{j} for all jj and at least half the coordinates of xx are integral. Their algorithm is simple and elegant: xx is constructed as the outcome of a random walk starting at yy. Iteratively, for some small step size δ>0\delta>0, we add δ\delta times a random Gaussian to the current point. Once a constraint xi=0x_{i}=0, xi=1x_{i}=1 or |vj,xy|=cj|\left<v_{j},x-y\right>|=c_{j} becomes tight, all subsequent Gaussian updates are taken orthogonal to the corresponding normal vectors; in other words, the random walk continues inside the face of the described polytope. Lovett and Meka (2012) prove that after O(1δ2)O(\frac{1}{\delta^{2}}) iterations the walk covers enough distance so that, on average, Θ(n)\Theta(n) box constraints become tight.
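To make the random-walk description concrete, here is a minimal NumPy sketch of one such phase, specialized to the case cj=0c_{j}=0 used below, so the walk stays orthogonal to all constraint vectors from the start. The function name, step size, iteration count, and the tolerance for declaring a coordinate tight are illustrative choices, not the constants from the analysis.

```python
import numpy as np

def lm_walk(y: np.ndarray, V: np.ndarray, delta: float = 1e-2,
            tol: float = 1e-6, seed: int = 0) -> np.ndarray:
    """One random-walk phase: start at y in [0,1]^n, keep <v_j, x - y> = 0 for the
    rows v_j of V, and freeze coordinates that hit 0 or 1. Returns the final point."""
    rng = np.random.default_rng(seed)
    n = y.size
    x = y.copy()
    frozen = np.zeros(n, dtype=bool)
    num_steps = int(1.0 / delta**2)
    for _ in range(num_steps):
        # Normals of all currently tight constraints: the constraint vectors
        # (tight throughout since c_j = 0) plus frozen coordinate directions.
        normals = np.vstack([V, np.eye(n)[frozen]])
        g = rng.standard_normal(n)
        if normals.size:
            # Project the Gaussian step onto the orthogonal complement of the normals.
            coeffs, *_ = np.linalg.lstsq(normals.T, g, rcond=None)
            g = g - normals.T @ coeffs
        x = np.clip(x + delta * g, 0.0, 1.0)
        frozen |= (x < tol) | (x > 1.0 - tol)
        if frozen.sum() >= n / 2:
            break
    return x
```

Calling this phase repeatedly on its own output tends to make a constant fraction of the remaining coordinates integral, which is exactly how it is used in the rounding algorithm of Section B.2.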

In our setting we only need the parameters cj=0c_{j}=0. However, we use some properties of the Lovett-Meka algorithm that are not explicitly stated elsewhere. Here M𝒮(1)\|M\|_{\mathcal{S}(1)} denotes the sum of the singular values of a matrix MM (also called the Schatten-1 norm, nuclear norm or trace norm of MM).

Theorem B.1 (Derived from Lovett and Meka (2012)).

Let g1,,gmng_{1},\ldots,g_{m}\in\mathbb{R}^{n} be any vectors with mn16m\leq\frac{n}{16} and let y[0,1]ny\in[0,1]^{n}. Then in polynomial time one can compute a sample x𝒟:=𝒟LM(g1,,gm,y)x\sim\mathcal{D}:=\mathcal{D}_{LM}(g_{1},\ldots,g_{m},y) so that

  1. (i)

    One has x[0,1]nx\in[0,1]^{n} and with probability at least 110\frac{1}{10} one has |{j[n]:xj{0,1}}|n2|\{j\in[n]:x_{j}\in\{0,1\}\}|\geq\frac{n}{2}.

  2. (ii)

    For any vector θn\theta\in\mathbb{R}^{n} one has 𝔼x𝒟[θ,xy2]O(θ22)\mathbb{E}_{x\sim\mathcal{D}}[\left<\theta,x-y\right>^{2}]\leq O(\|\theta\|_{2}^{2}).

  3. (iii)

    For any symmetric matrix Mn×nM\in\mathbb{R}^{n\times n} one has 𝔼[M,(xy)(xy)T]O(M𝒮(1))\mathbb{E}[\left<M,(x-y)(x-y)^{T}\right>]\leq O(\|M\|_{\mathcal{S}(1)}).

Proof.

(i) is explicitly in Lovett and Meka (2012). For (ii) we use that the outcome of the random walk is of the form

x=y+δt=1O(1/δ2)utwhereutN(𝟎,Σt)x=y+\delta\sum_{t=1}^{O(1/\delta^{2})}u_{t}\quad\textrm{where}\quad u_{t}\sim N(\bm{0},\Sigma_{t})

Here 0ΣtIn0\preceq\Sigma_{t}\preceq I_{n}. Crucially, each covariance matrix Σt\Sigma_{t} may depend on the outcome of u1,,ut1u_{1},\ldots,u_{t-1}; in particular it is not true that xyx-y is Gaussian. However, xyx-y is a martingale, so the cross terms between different steps vanish and the per-step variances simply add up. Since for each step tt one has 𝔼[ut,θ]=0\mathbb{E}[\left<u_{t},\theta\right>]=0 and 𝔼[ut,θ2]O(θ22)\mathbb{E}[\left<u_{t},\theta\right>^{2}]\leq O(\|\theta\|_{2}^{2}), summing over the O(1/\delta^{2}) steps (each weighted by \delta^{2}) gives 𝔼[<δt=1O(1/δ2)ut,θ>2]O(θ22)\mathbb{E}[\big{<}\delta\sum_{t=1}^{O(1/\delta^{2})}u_{t},\theta\big{>}^{2}]\leq O(\|\theta\|_{2}^{2}), which settles (ii). Finally we argue why (iii) holds. Note that (ii) can be restated as 𝔼x𝒟[(xy)(xy)T]O(1)In\mathbb{E}_{x\sim\mathcal{D}}[(x-y)(x-y)^{T}]\preceq O(1)\cdot I_{n}. Then

𝔼[M,(xy)(xy)T]\displaystyle\mathbb{E}[\left<M,(x-y)(x-y)^{T}\right>] =M,𝔼[(xy)(xy)T]\displaystyle=\left<M,\mathbb{E}[(x-y)(x-y)^{T}]\right>
M𝒮(1)𝔼[(xy)(xy)T]op\displaystyle\leq\|M\|_{\mathcal{S}(1)}\cdot\|\mathbb{E}[(x-y)(x-y)^{T}]\|_{\textrm{op}}
O(M𝒮(1)).\displaystyle\leq O(\|M\|_{\mathcal{S}(1)}). ∎

B.2 The main theoretical result

As explained earlier, we assume that we are given a weight vector y[0,1]ny\in[0,1]^{n} and have access to samples g1,,gm𝒟g_{1},\ldots,g_{m}\sim\mathcal{D}, where 𝒟\mathcal{D} is a distribution on n\mathbb{R}^{n} whose covariance matrix Σ:=𝔼g𝒟[ggT]\Sigma:=\mathbb{E}_{g\sim\mathcal{D}}[gg^{T}] has rapidly decaying eigenvalues, say λkCkα\lambda_{k}\leq\frac{C}{k^{\alpha}} for some constants C>0C>0 and α>1\alpha>1. In order to prove rigorous bounds we also need a mild assumption ensuring that the distribution is well-behaved. Following O’Donnell (2014), we say that for a parameter β1\beta\geq 1, a random vector XnX\in\mathbb{R}^{n} is β\beta-reasonable if

𝔼[X,θ4]β𝔼[X,θ2]2θn\mathbb{E}[\left<X,\theta\right>^{4}]\leq\beta\cdot\mathbb{E}[\left<X,\theta\right>^{2}]^{2}\quad\forall\theta\in\mathbb{R}^{n}
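As a quick illustration of this definition (a standard computation recorded here for convenience), a centered Gaussian X\sim N(\bm{0},\Sigma) satisfies it with \beta=3: the inner product \left<X,\theta\right> is a centered Gaussian with variance \theta^{T}\Sigma\theta, so the Gaussian fourth-moment identity \mathbb{E}[Z^{4}]=3\sigma^{4} gives

\mathbb{E}[\left<X,\theta\right>^{4}]=3\,(\theta^{T}\Sigma\theta)^{2}=3\,\mathbb{E}[\left<X,\theta\right>^{2}]^{2}.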

Both X{1,1}nX\sim\{-1,1\}^{n} and a Gaussian XN(𝟎,Σ)X\sim N(\bm{0},\Sigma) are O(1)O(1)-reasonable, as illustrated above for the Gaussian case. Our main theoretical result is then:

Theorem B.2.

Let α>1\alpha>1 and β1\beta\geq 1 be constants and let 1mn161\leq m\leq\frac{n}{16}. Let 𝒟\mathcal{D} be a β\beta-reasonable distribution with unknown covariance matrix Σn×n\Sigma\in\mathbb{R}^{n\times n} whose eigenvalues satisfy λk1kα\lambda_{k}\leq\frac{1}{k^{\alpha}} for all k=1,,nk=1,\ldots,n. Then there is a randomized polynomial time algorithm that, given y[0,1]ny\in[0,1]^{n} and mm independent samples g1,,gm𝒟g_{1},\ldots,g_{m}\sim\mathcal{D}, produces an x[0,1]nx\in[0,1]^{n} so that with probability at least 0.99 one has

  1. (i)

    |frac(x)|16m|\textrm{frac}(x)|\leq 16m

  2. (ii)

    Σ,(xy)(xy)Tαlog(nm)Fα(m,n)\left<\Sigma,(x-y)(x-y)^{T}\right>\lesssim_{\alpha}\log(\frac{n}{m})\cdot F_{\alpha}(m,n) where

    Fα(m,n):={m1αif 1<α<32log(n)mif α=321mif α>3/2.F_{\alpha}(m,n):=\begin{cases}m^{1-\alpha}&\textrm{if }1<\alpha<\frac{3}{2}\\ \frac{\log(n)}{\sqrt{m}}&\textrm{if }\alpha=\frac{3}{2}\\ \frac{1}{\sqrt{m}}&\textrm{if }\alpha>3/2.\end{cases}

Ignoring polylogarithmic factors, this means that we can find an xx with O(m)O(m) fractional coordinates left and Σ,(xy)(xy)Tmax{m1α,1m}\left<\Sigma,(x-y)(x-y)^{T}\right>\leq\max\{m^{1-\alpha},\frac{1}{\sqrt{m}}\}. The algorithm to compute xx as in Theorem B.2 is simple:

Lovett-Meka Rounding Algorithm
Input: Weight vector y[0,1]ny\in[0,1]^{n} and parameter mm
Output: Rounded vector xx
(1) Sample g1,,gm𝒟g_{1},\ldots,g_{m}\sim\mathcal{D}. Initialize x(0):=yx^{(0)}:=y
(2) FOR t=1t=1 TO \infty DO
(3) IF |frac(x(t1))|16m|\textrm{frac}(x^{(t-1)})|\leq 16m THEN return x(t1)x^{(t-1)}
(4) Set x(t):=𝒟LM(g1,,gm,x(t1))x^{(t)}:=\mathcal{D}_{LM}(g_{1},\ldots,g_{m},x^{(t-1)})
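For completeness, a minimal Python sketch of this outer loop, reusing the hypothetical lm_walk function sketched in Section B.1; the 16m threshold comes from the pseudocode above, while the tolerance for calling a coordinate fractional is an illustrative choice.

```python
import numpy as np

def lm_round(y: np.ndarray, G: np.ndarray, tol: float = 1e-6) -> np.ndarray:
    """Outer loop of the rounding algorithm above. G has shape (m, n); its rows
    are the sampled vectors g_1, ..., g_m. Reuses the lm_walk sketch from B.1."""
    m = G.shape[0]
    x = y.copy()

    def num_fractional(v: np.ndarray) -> int:
        return int(np.sum((v > tol) & (v < 1.0 - tol)))

    while num_fractional(x) > 16 * m:
        x = lm_walk(x, G)   # one partial-coloring phase
    return x
```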

A crucial aspect of analyzing this algorithm is understanding how far the covariance estimator 1mj=1mgjgjT\frac{1}{m}\sum_{j=1}^{m}g_{j}g_{j}^{T} is from the actual covariance matrix Σ\Sigma in terms of the Schatten 1-norm 𝒮(1)\|\cdot\|_{\mathcal{S}(1)}. We use the following result.

Proposition B.3.

Let α>1\alpha>1, β1\beta\geq 1 and let 𝒟\mathcal{D} be a β\beta-reasonable distribution with covariance matrix Σn×n\Sigma\in\mathbb{R}^{n\times n} whose eigenvalues satisfy λk1kα\lambda_{k}\leq\frac{1}{k^{\alpha}} for all k=1,,nk=1,\ldots,n. Let g1,,gm𝒟g_{1},\ldots,g_{m}\sim\mathcal{D} be independent samples and let X():=1mggTX^{(\ell)}:=\frac{1}{m}g_{\ell}g_{\ell}^{T} and X:==1mX()X:=\sum_{\ell=1}^{m}X^{(\ell)}. Then

𝔼[XΣ𝒮(1)]α,βFα(m,n)\mathbb{E}[\|X-\Sigma\|_{\mathcal{S}(1)}]\lesssim_{\alpha,\beta}F_{\alpha}(m,n)

where Fα(m,n)F_{\alpha}(m,n) is as defined in Theorem B.2.
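For intuition, the quantity bounded by Proposition B.3 can be computed directly for simulated data. The small NumPy sketch below draws Gaussian samples with an assumed diagonal covariance with eigenvalues k^{-\alpha} and evaluates the Schatten-1 distance between the estimator and the true covariance; all constants are illustrative.

```python
import numpy as np

# Illustrative check of the quantity in Proposition B.3: Gaussian samples with an
# assumed diagonal covariance lambda_k = k^{-alpha} (all constants are arbitrary).
n, m, alpha = 512, 32, 2.0
lam = np.arange(1, n + 1, dtype=float) ** (-alpha)
Sigma = np.diag(lam)

rng = np.random.default_rng(0)
G = rng.standard_normal((m, n)) * np.sqrt(lam)   # each row g_l ~ N(0, Sigma)
X = (G.T @ G) / m                                # covariance estimator (1/m) sum_l g_l g_l^T
print(np.linalg.norm(X - Sigma, ord="nuc"))      # Schatten-1 distance ||X - Sigma||_S(1)
```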

We postpone the proof of Prop B.3 to Section B.3 and first conclude the proof of Theorem B.2.

Proof of Theorem B.2.

Suppose x(t)x^{(t^{*})} is the vector that the algorithm returned in (3). It will be notationally convenient to define x(t):=x(t)x^{(t)}:=x^{(t^{*})} for all t>tt>t^{*}. We say that iteration tt is good if either |frac(x(t1))|16m|\textrm{frac}(x^{(t-1)})|\leq 16m or if |frac(x(t))|12|frac(x(t1))||\textrm{frac}(x^{(t)})|\leq\frac{1}{2}|\textrm{frac}(x^{(t-1)})|. If an iteration tt is not good, we repeat the iteration until it is good. From Theorem B.1.(i) we know that every iteration is good with probability at least 110\frac{1}{10} (independently of previous outcomes). Thus, by standard Chernoff bounds, with probability at least 0.99 there are at least log(nm)\log(\frac{n}{m}) good iterations within the first T:=Clog(nm)T:=C^{\prime}\log(\frac{n}{m}) iterations, for a sufficiently large constant C>0C^{\prime}>0. After log(nm)\log(\frac{n}{m}) good iterations one has |frac(x(T))|16m|\textrm{frac}(x^{(T)})|\leq 16m, and moreover the discrepancy incurred is

𝔼[<Σ,(x(T)y)(x(T)y)T>]t=1T𝔼[<Σ,(x(t)x(t1))(x(t)x(t1))T>]α,βTFα(m,n).\mathbb{E}\big{[}\big{<}\Sigma,(x^{(T)}-y)(x^{(T)}-y)^{T}\big{>}\big{]}\leq\sum_{t=1}^{T}\mathbb{E}\big{[}\big{<}\Sigma,(x^{(t)}-x^{(t-1)})(x^{(t)}-x^{(t-1)})^{T}\big{>}\big{]}\lesssim_{\alpha,\beta}T\cdot F_{\alpha}(m,n).

The claim then follows. ∎

B.3 Analyzing the covariance estimator

It remains to prove Prop B.3.

Proof of Prop B.3.

We first present the proof for the case 1<α<321<\alpha<\frac{3}{2} and then discuss the modifications for the other two cases. The claim is invariant under a change of basis, hence we may assume that Σ\Sigma is a diagonal matrix with eigenvalues λ1λn0\lambda_{1}\geq\ldots\geq\lambda_{n}\geq 0, i.e., Σii=λi\Sigma_{ii}=\lambda_{i} for all i[n]i\in[n]. We first bound the variance of every entry (diagonal or not):
Claim I. For all i,j[n]i,j\in[n] one has 𝔼[|XijΣij|2]βλiλjm\mathbb{E}[|X_{ij}-\Sigma_{ij}|^{2}]\lesssim_{\beta}\frac{\lambda_{i}\lambda_{j}}{m}.
Proof of Claim I. We recall that 𝔼[X]=Σ\mathbb{E}[X]=\Sigma and 𝔼[X()]=1mΣ\mathbb{E}[X^{(\ell)}]=\frac{1}{m}\Sigma. For all i,j[n]i,j\in[n] one has

𝔼[|XijΣij|2]\displaystyle\mathbb{E}[|X_{ij}-\Sigma_{ij}|^{2}] =\displaystyle= Var[Xij]\displaystyle\textrm{Var}[X_{ij}]
=\displaystyle= =1mVar[Xij()]\displaystyle\sum_{\ell=1}^{m}\textrm{Var}[X_{ij}^{(\ell)}]
=\displaystyle= 1m𝔼h𝒟[|hihjΣij|2]\displaystyle\frac{1}{m}\mathbb{E}_{h\sim\mathcal{D}}[|h_{i}h_{j}-\Sigma_{ij}|^{2}]
\displaystyle\leq 2m(𝔼h𝒟[hi2hj2]+Σij2λiλj)\displaystyle\frac{2}{m}\Big{(}\mathbb{E}_{h\sim\mathcal{D}}[h_{i}^{2}h_{j}^{2}]+\underbrace{\Sigma_{ij}^{2}}_{\leq\lambda_{i}\lambda_{j}}\Big{)}
()\displaystyle\stackrel{{\scriptstyle(*)}}{{\leq}} 2m(𝔼h𝒟[hi4]1/2𝔼h𝒟[hj4]1/2+λiλj)\displaystyle\frac{2}{m}(\mathbb{E}_{h\sim\mathcal{D}}[h_{i}^{4}]^{1/2}\mathbb{E}_{h\sim\mathcal{D}}[h_{j}^{4}]^{1/2}+\lambda_{i}\lambda_{j})
()\displaystyle\stackrel{{\scriptstyle(**)}}{{\leq}} 2m(β1/2𝔼h𝒟[hi2]=λiβ1/2𝔼h𝒟[hj2]=λj+λiλj)=2β+2mλiλj\displaystyle\frac{2}{m}\big{(}\beta^{1/2}\underbrace{\mathbb{E}_{h\sim\mathcal{D}}[h_{i}^{2}]}_{=\lambda_{i}}\cdot\beta^{1/2}\underbrace{\mathbb{E}_{h\sim\mathcal{D}}[h_{j}^{2}]}_{=\lambda_{j}}+\lambda_{i}\lambda_{j}\big{)}=\frac{2\beta+2}{m}\cdot\lambda_{i}\lambda_{j}

Here we use the inequality (ab)22a2+2b2(a-b)^{2}\leq 2a^{2}+2b^{2}. Moreover Σij2λiλj\Sigma_{ij}^{2}\leq\lambda_{i}\lambda_{j} holds because Σ\Sigma is a diagonal matrix. Note that we have used Cauchy-Schwarz in ()(*) and the assumption that 𝒟\mathcal{D} is β\beta-reasonable in ()(**). ∎
Now let J:={i[n]21i<2}J_{\ell}:=\{i\in[n]\mid 2^{\ell-1}\leq i<2^{\ell}\}. It will be useful to note that |J|2|J_{\ell}|\leq 2^{\ell} and that the sum of the eigenvalues in each block satisfies iJλi2(2)α=(2)1α\sum_{i\in J_{\ell}}\lambda_{i}\lesssim 2^{\ell}\cdot(2^{\ell})^{-\alpha}=(2^{\ell})^{1-\alpha}. Our strategy is to use the triangle inequality to bound:

𝔼[XΣ𝒮(1)]21k𝔼[XJ,JkΣJ,Jk𝒮(1)]\mathbb{E}[\|X-\Sigma\|_{\mathcal{S}(1)}]\leq 2\sum_{\ell\geq 1}\sum_{k\geq\ell}\mathbb{E}[\|X_{J_{\ell},J_{k}}-\Sigma_{J_{\ell},J_{k}}\|_{\mathcal{S}(1)}] (4)

Here XJ,JkX_{J_{\ell},J_{k}} is the |J|×|Jk||J_{\ell}|\times|J_{k}| submatrix of XX that is indexed by rows JJ_{\ell} and columns JkJ_{k}. In the following we will estimate the contribution of the different blocks depending on their parameter regime and whether they are diagonal or off-diagonal.

Claim II. Let k\ell\leq k and abbreviate Y:=XJ,JkΣJ,JkY:=X_{J_{\ell},J_{k}}-\Sigma_{J_{\ell},J_{k}}. Then

𝔼[Y𝒮(1)]rm2+k2(1α)\mathbb{E}[\|Y\|_{\mathcal{S}(1)}]\lesssim\sqrt{\frac{r}{m}}\cdot 2^{\frac{\ell+k}{2}(1-\alpha)}

assuming that rank(Y)r\textrm{rank}(Y)\leq r for any outcome of YY.
Proof of Claim II. We recall that for any matrix AA one has A𝒮(1)rank(A)AF\|A\|_{\mathcal{S}(1)}\leq\sqrt{\textrm{rank}(A)}\cdot\|A\|_{F}. Then for all k\ell\leq k we can bound

𝔼[Y𝒮(1)]\displaystyle\mathbb{E}[\|Y\|_{\mathcal{S}(1)}] \displaystyle\leq r𝔼[YF]\displaystyle\sqrt{r}\cdot\mathbb{E}[\|Y\|_{F}]
Jensen\displaystyle\stackrel{{\scriptstyle\textrm{Jensen}}}{{\leq}} r𝔼[YF2]1/2\displaystyle\sqrt{r}\cdot\mathbb{E}[\|Y\|_{F}^{2}]^{1/2}
βClaim Iβ\displaystyle\stackrel{{\scriptstyle\textrm{Claim I}}}{{\lesssim_{\beta}}} r(1m(iJλi)(jJkλj))1/2\displaystyle\sqrt{r}\cdot\Big{(}\frac{1}{m}\Big{(}\sum_{i\in J_{\ell}}\lambda_{i}\Big{)}\Big{(}\sum_{j\in J_{k}}\lambda_{j}\Big{)}\Big{)}^{1/2}
\displaystyle\lesssim r1m(2)1α(2k)1α\displaystyle\sqrt{r}\cdot\sqrt{\frac{1}{m}\cdot(2^{\ell})^{1-\alpha}\cdot(2^{k})^{1-\alpha}}
=\displaystyle= rm2+k2(1α)\displaystyle\sqrt{\frac{r}{m}}\cdot 2^{\frac{\ell+k}{2}(1-\alpha)}\qed

Now we can bound the contribution that off-diagonal blocks have to Eq (4). Here we use that ΣJ,Jk=𝟎\Sigma_{J_{\ell},J_{k}}=\bm{0} and rank(XJ,Jk)min{m,2}\textrm{rank}(X_{J_{\ell},J_{k}})\leq\min\{m,2^{\ell}\}. Then

1k>𝔼[XJ,JkΣJ,Jk=𝟎𝒮(1)]\displaystyle\sum_{\ell\geq 1}\sum_{k>\ell}\mathbb{E}\big{[}\|X_{J_{\ell},J_{k}}-\underbrace{\Sigma_{J_{\ell},J_{k}}}_{=\bm{0}}\|_{\mathcal{S}(1)}\big{]} Claim II\displaystyle\stackrel{{\scriptstyle\textrm{Claim II}}}{{\leq}} 1k>min{m,2}m2+k2(1α)\displaystyle\sum_{\ell\geq 1}\sum_{k>\ell}\frac{\sqrt{\min\{m,2^{\ell}\}}}{\sqrt{m}}\cdot 2^{\frac{\ell+k}{2}(1-\alpha)}
=\displaystyle= 1min{1,2/m}22(1α)k>2k2(1α)α2(1α)/2\displaystyle\sum_{\ell\geq 1}\min\big{\{}1,\sqrt{2^{\ell}/m}\big{\}}\cdot 2^{\frac{\ell}{2}(1-\alpha)}\underbrace{\sum_{k>\ell}\cdot 2^{\frac{k}{2}(1-\alpha)}}_{\lesssim_{\alpha}2^{\ell(1-\alpha)/2}}
α\displaystyle\lesssim_{\alpha} 1min{1,2/m}(2)1α\displaystyle\sum_{\ell\geq 1}\min\big{\{}1,\sqrt{2^{\ell}/m}\big{\}}\cdot(2^{\ell})^{1-\alpha}
α\displaystyle\lesssim_{\alpha} m1α\displaystyle m^{1-\alpha}

In the last step we use that the function zzz1αz\mapsto\sqrt{z}\cdot z^{1-\alpha} is monotonically increasing while zz1αz\mapsto z^{1-\alpha} is monotonically decreasing as we assume that 1<α<321<\alpha<\frac{3}{2}. Hence the term with m=2m=2^{\ell} dominates the sum.

It remains to bound the diagonal blocks. First we consider the regime of small indices. Here we use the bound rank(XJ,JΣJ,J)|J|2\textrm{rank}(X_{J_{\ell},J_{\ell}}-\Sigma_{J_{\ell},J_{\ell}})\leq|J_{\ell}|\leq 2^{\ell} which gives

:2m𝔼[XJ,JΣJ,J𝒮(1)]Claim II:2m2m2(1α)m1α\sum_{\ell:2^{\ell}\leq m}\mathbb{E}[\|X_{J_{\ell},J_{\ell}}-\Sigma_{J_{\ell},J_{\ell}}\|_{\mathcal{S}(1)}]\stackrel{{\scriptstyle\textrm{Claim II}}}{{\leq}}\sum_{\ell:2^{\ell}\leq m}\sqrt{\frac{2^{\ell}}{m}}\cdot 2^{\ell(1-\alpha)}\lesssim m^{1-\alpha} (6)

Here the last summand (with 2=m2^{\ell}=m) dominates the sum in (6), again as zzz1αz\mapsto\sqrt{z}\cdot z^{1-\alpha} is monotonically increasing.

The final regime to consider is the one of large indices, i.e. diagonal blocks with 2>m2^{\ell}>m. In that case we can ignore any concentration that the randomness may provide and simply bound

:2>m𝔼[XJ,JΣJ,J𝒮(1)]\displaystyle\sum_{\ell:2^{\ell}>m}\mathbb{E}[\|X_{J_{\ell},J_{\ell}}-\Sigma_{J_{\ell},J_{\ell}}\|_{\mathcal{S}(1)}] \displaystyle\leq :2>m(𝔼[XJ,J𝒮(1)]+ΣJ,J𝒮(1))\displaystyle\sum_{\ell:2^{\ell}>m}\big{(}\mathbb{E}[\|X_{J_{\ell},J_{\ell}}\|_{\mathcal{S}(1)}]+\|\Sigma_{J_{\ell},J_{\ell}}\|_{\mathcal{S}(1)}\big{)} (7)
=\displaystyle= :2>m(𝔼[Tr[XJ,J]]+Tr[ΣJ,J])\displaystyle\sum_{\ell:2^{\ell}>m}\big{(}\mathbb{E}[\textrm{Tr}[X_{J_{\ell},J_{\ell}}]]+\textrm{Tr}[\Sigma_{J_{\ell},J_{\ell}}]\big{)}
=\displaystyle= j=mn(𝔼[Xjj]=Σjj+Σjjjα)\displaystyle\sum_{j=m}^{n}(\underbrace{\mathbb{E}[X_{jj}]}_{=\Sigma_{jj}}+\underbrace{\Sigma_{jj}}_{\leq j^{-\alpha}})
\displaystyle\lesssim jm1jαm1α\displaystyle\sum_{j\geq m}\frac{1}{j^{\alpha}}\lesssim m^{1-\alpha}

Here we use again the triangle inequality of the trace norm and the fact that the matrices XJ,JX_{J_{\ell},J_{\ell}} and ΣJ,J\Sigma_{J_{\ell},J_{\ell}} are always positive semidefinite. This concludes the argument for 1<α<321<\alpha<\frac{3}{2}. If α=32\alpha=\frac{3}{2} then 2/m(2)1α1m\sqrt{2^{\ell}/m}\cdot(2^{\ell})^{1-\alpha}\leq\frac{1}{\sqrt{m}} for each 1\ell\geq 1 and so (B.3) is bounded by log(n)m\frac{\log(n)}{\sqrt{m}}. Moreover, the last two cases can be merged as

1𝔼[XJ,JΣJ,J𝒮(1)]Claim II12m2(1α)log(n)m\sum_{\ell\geq 1}\mathbb{E}[\|X_{J_{\ell},J_{\ell}}-\Sigma_{J_{\ell},J_{\ell}}\|_{\mathcal{S}(1)}]\stackrel{{\scriptstyle\textrm{Claim II}}}{{\leq}}\sum_{\ell\geq 1}\sqrt{\frac{2^{\ell}}{m}}\cdot 2^{\ell(1-\alpha)}\lesssim\frac{\log(n)}{\sqrt{m}} (8)

Finally, if α>32\alpha>\frac{3}{2} then the first term (for =1\ell=1) dominates the sums in (B.3) and (8) and the extra log(n)\log(n) term can be omitted. ∎

Appendix C Non-uniform Quantization Grid

We will introduce new parameters x[0,1]nx\in[0,1]^{n} and define

wx=wdown(1x)+wupxw^{x}=w^{\textrm{down}}\odot(1-x)+w^{\textrm{up}}\odot x

where \odot is the component-wise product. Note that wixw^{x}_{i} interpolates between widownw^{\textrm{down}}_{i} and wiupw^{\textrm{up}}_{i}, with wix=widownw^{x}_{i}=w^{\textrm{down}}_{i} if xi=0x_{i}=0 and wix=wiupw^{x}_{i}=w^{\textrm{up}}_{i} if xi=1x_{i}=1. Let y[0,1]ny\in[0,1]^{n} be the interpolation point corresponding to the original weights, i.e., wy=ww^{y}=w. We can rewrite the linear constraints in terms of xx as follows:

wf(w;si),wxw\displaystyle\left\langle\nabla_{w}f(w;s_{i}),w^{x}-w\right\rangle =wf(w;si),wxwy\displaystyle=\left\langle\nabla_{w}f(w;s_{i}),w^{x}-w^{y}\right\rangle
=wf(w;si),(wupwdown)(xy)\displaystyle=\left\langle\nabla_{w}f(w;s_{i}),(w^{\textrm{up}}-w^{\textrm{down}})\odot(x-y)\right\rangle
=wf(w;si)(wupwdown),xy.\displaystyle=\left\langle\nabla_{w}f(w;s_{i})\odot(w^{\textrm{up}}-w^{\textrm{down}}),x-y\right\rangle.

Let MM be an m×nm\times n matrix whose ithi^{th} row is given by wf(w;si)(wupwdown)\nabla_{w}f(w;s_{i})\odot(w^{\textrm{up}}-w^{\textrm{down}}). Then the linear constraints can be simply written as M(xy)=0M(x-y)=0.
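A minimal sketch of this reparameterization is given below. The function names, and the assumption that the grid neighbors w^down and w^up of each weight have already been computed (by whatever possibly non-uniform grid is in use), are ours for illustration.

```python
import torch

def interpolation_setup(w: torch.Tensor, w_down: torch.Tensor,
                        w_up: torch.Tensor, grads: torch.Tensor):
    """w, w_down, w_up: (n,) tensors, where w_down/w_up are the grid neighbors
    below/above each weight (from an arbitrary, possibly non-uniform grid).
    grads: (m, n) matrix whose rows are the per-sample gradients of f(w; s_i).
    Returns the interpolation point y with w^y = w, and the constraint matrix M."""
    y = (w - w_down) / (w_up - w_down)     # lies in [0,1]^n since w_down <= w <= w_up
    M = grads * (w_up - w_down)            # i-th row: grad_i elementwise (w_up - w_down)
    return y, M

def w_of_x(w_down: torch.Tensor, w_up: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
    # w^x = w_down * (1 - x) + w_up * x ; x = y recovers w, x in {0,1}^n is quantized
    return w_down * (1 - x) + w_up * x
```

The linear constraints above then read M @ (x - y) == 0, and rounding corresponds to pushing the entries of x to {0, 1}.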

Appendix D Taylor Series for KL Divergence

Let pw(|z<i)p_{w}(\cdot|z_{<i}) be the distribution of the next token predicted by the original model given prefix z<iz_{<i} where z𝒟dataz\sim\mathcal{D}_{\textrm{data}} is a sample from the data distribution. Let

error(w^)=𝔼z𝒟data𝔼iDKL(pw(|z<i)pw^(|z<i))\textrm{error}(\hat{w})=\mathbb{E}_{z\sim\mathcal{D}_{\textrm{data}}}\mathbb{E}_{i}D_{KL}\left(p_{w}(\cdot|z_{<i})\left\|\right.p_{\hat{w}}(\cdot|z_{<i})\right)

be the KL divergence between the original model and quantized model.

Lemma D.1.

Let

error(w^)=gw,w^w+(w^w)THw(w^w)+\displaystyle\textrm{error}(\hat{w})=\left\langle g_{w},\hat{w}-w\right\rangle+(\hat{w}-w)^{T}H_{w}(\hat{w}-w)+\cdots

be the Taylor series expansion of the KL divergence where gwg_{w} is the gradient and HwH_{w} is the Hessian. Then

  1. 1.

    gw=0g_{w}=0,

  2. 2.

    Hw=𝔼z𝒟data𝔼i𝔼tpw(|z<i)[(wlogpw(t|z<i))(wlogpw(t|z<i))T]H_{w}=\mathbb{E}_{z\sim\mathcal{D}_{\textrm{data}}}\mathbb{E}_{i}\mathbb{E}_{t\sim p_{w}(\cdot|z_{<i})}[(\nabla_{w}\log p_{w}(t|z_{<i}))(\nabla_{w}\log p_{w}(t|z_{<i}))^{T}]

Therefore error(w^)𝔼z𝒟data𝔼i𝔼tpw(|z<i)[wlogpw(t|z<i),w^w2].\textrm{error}(\hat{w})\approx\mathbb{E}_{z\sim\mathcal{D}_{\textrm{data}}}\mathbb{E}_{i}\mathbb{E}_{t\sim p_{w}(\cdot|z_{<i})}[\left\langle\nabla_{w}\log p_{w}(t|z_{<i}),\hat{w}-w\right\rangle^{2}].

Proof.

To simplify notation, we will ignore the z,iz,i variables coming from 𝔼z𝒟data\mathbb{E}_{z\sim\mathcal{D}_{\textrm{data}}} and 𝔼i\mathbb{E}_{i} and also drop them from pw(|z<i)p_{w}(\cdot|z_{<i}) and just write pw().p_{w}(\cdot). Adding these back and taking expectations over these variables, we get the desired result. We can expand the KL divergence using Taylor series and evaluate the first and second order terms.

error(w^)\displaystyle\textrm{error}(\hat{w}) =DKL(pw()pw^())\displaystyle=D_{KL}\left(p_{w}(\cdot)\left\|\right.p_{\hat{w}}(\cdot)\right)
=Etpw[logpw^(t)logpw(t)]\displaystyle=-E_{t\sim p_{w}}\left[\log p_{\hat{w}}(t)-\log p_{w}(t)\right]
=Etpw[wlogpw(t),w^w+(w^w)Tw2logpw(t)(w^w)+]\displaystyle=-E_{t\sim p_{w}}\left[\left\langle\nabla_{w}\log p_{w}(t),\hat{w}-w\right\rangle+(\hat{w}-w)^{T}\nabla^{2}_{w}\log p_{w}(t)(\hat{w}-w)+\cdots\right]
=gw,w^w+(w^w)THw(w^w)+\displaystyle=\left\langle g_{w},\hat{w}-w\right\rangle+(\hat{w}-w)^{T}H_{w}(\hat{w}-w)+\cdots

where gw=Etpw[wlogpw(t)]g_{w}=-E_{t\sim p_{w}}[\nabla_{w}\log p_{w}(t)] and Hw=Etpw[w2logpw(t)].H_{w}=-E_{t\sim p_{w}}[\nabla^{2}_{w}\log p_{w}(t)].
(1) We first evaluate gw.g_{w}.

gw=Etpw[wlogpw(t)]\displaystyle g_{w}=-E_{t\sim p_{w}}[\nabla_{w}\log p_{w}(t)] =𝔼tpw[wpw(t)pw(t)]\displaystyle=\mathbb{E}_{t\sim p_{w}}\left[\frac{\nabla_{w}p_{w}(t)}{p_{w}(t)}\right]
=twpw(t)\displaystyle=\sum_{t}\nabla_{w}p_{w}(t)
=w(tpw(t))\displaystyle=\nabla_{w}(\sum_{t}p_{w}(t))
=w(1)=0.\displaystyle=\nabla_{w}(1)=0.

(2) We now evaluate HwH_{w}.

Hw\displaystyle H_{w} =Etpw[w2logpw(t)]\displaystyle=-E_{t\sim p_{w}}[\nabla^{2}_{w}\log p_{w}(t)]
=𝔼tpw[w(wpw(t)pw(t))]\displaystyle=-\mathbb{E}_{t\sim p_{w}}\left[\nabla_{w}\left(\frac{\nabla_{w}p_{w}(t)}{p_{w}(t)}\right)\right]
=𝔼tpw[w2pw(t)pw(t)(wpw(t))(wpw(t))Tpw(t)2]\displaystyle=-\mathbb{E}_{t\sim p_{w}}\left[\frac{\nabla^{2}_{w}p_{w}(t)}{p_{w}(t)}-\frac{(\nabla_{w}p_{w}(t))(\nabla_{w}p_{w}(t))^{T}}{p_{w}(t)^{2}}\right]
=𝔼tpw[w2pw(t)pw(t)(wlogpw(t))(wlogpw(t))T]\displaystyle=-\mathbb{E}_{t\sim p_{w}}\left[\frac{\nabla^{2}_{w}p_{w}(t)}{p_{w}(t)}-(\nabla_{w}\log p_{w}(t))(\nabla_{w}\log p_{w}(t))^{T}\right]
=tw2pw(t)+𝔼tpw[(wlogpw(t))(wlogpw(t))T]\displaystyle=-\sum_{t}\nabla^{2}_{w}p_{w}(t)+\mathbb{E}_{t\sim p_{w}}\left[(\nabla_{w}\log p_{w}(t))(\nabla_{w}\log p_{w}(t))^{T}\right]
=w2(tpw(t))+𝔼tpw[(wlogpw(t))(wlogpw(t))T]\displaystyle=-\nabla^{2}_{w}\left(\sum_{t}p_{w}(t)\right)+\mathbb{E}_{t\sim p_{w}}\left[(\nabla_{w}\log p_{w}(t))(\nabla_{w}\log p_{w}(t))^{T}\right]
=w2(1)+𝔼tpw[(wlogpw(t))(wlogpw(t))T]\displaystyle=-\nabla^{2}_{w}\left(1\right)+\mathbb{E}_{t\sim p_{w}}\left[(\nabla_{w}\log p_{w}(t))(\nabla_{w}\log p_{w}(t))^{T}\right]
=𝔼tpw[(wlogpw(t))(wlogpw(t))T]\displaystyle=\mathbb{E}_{t\sim p_{w}}\left[(\nabla_{w}\log p_{w}(t))(\nabla_{w}\log p_{w}(t))^{T}\right]
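As a concrete, purely illustrative way to use this approximation, the sketch below Monte Carlo estimates the inner expectation for a single prefix by sampling next tokens from the model and squaring the inner product between the score gradient and the weight perturbation. It assumes a Hugging Face-style causal LM exposing .logits and a dictionary w_hat_minus_w mapping parameter names to the corresponding (w^hat − w) tensors; neither is part of the paper's code.

```python
import torch

def quadratic_error_estimate(model, w_hat_minus_w: dict, input_ids: torch.Tensor,
                             num_samples: int = 8) -> float:
    """Monte Carlo estimate of E_{t ~ p_w}[ <grad_w log p_w(t | prefix), w_hat - w>^2 ]
    for one prefix. `w_hat_minus_w` maps parameter name -> (w_hat - w) tensor."""
    logits = model(input_ids).logits[0, -1]                 # next-token logits
    log_probs = torch.log_softmax(logits, dim=-1)
    probs = log_probs.exp().detach()
    total = 0.0
    for _ in range(num_samples):
        t = int(torch.multinomial(probs, 1))                # t ~ p_w(. | prefix)
        grads = torch.autograd.grad(log_probs[t], list(model.parameters()),
                                    retain_graph=True, allow_unused=True)
        inner = sum((g * w_hat_minus_w[name]).sum()
                    for (name, _), g in zip(model.named_parameters(), grads)
                    if g is not None)
        total += float(inner) ** 2
    return total / num_samples
```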

Appendix E LoRA experiments

See Table 7 for our LoRA experiments. We initialize the model with the best DiscQuant rounding at 3.25 bits and add LoRA adapters; a sketch of the adapter structure is given below. We then train the LoRA adapters while freezing all other parameters.
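For reference, a minimal sketch of the adapter structure trained on top of the frozen quantized weights (a generic LoRA wrapper with one common scaling convention; it is not our exact training code):

```python
import torch
import torch.nn as nn

class LoRALinear(nn.Module):
    """Wraps a frozen linear layer with a trainable low-rank update W + (alpha/r) * B A."""
    def __init__(self, base: nn.Linear, rank: int = 16, alpha: float = 16.0):
        super().__init__()
        self.base = base
        for p in self.base.parameters():
            p.requires_grad = False                      # freeze the quantized weights
        self.A = nn.Parameter(torch.randn(rank, base.in_features) * 0.01)
        self.B = nn.Parameter(torch.zeros(base.out_features, rank))  # zero init: no change at start
        self.scale = alpha / rank

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.base(x) + self.scale * (x @ self.A.T @ self.B.T)
```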

LoRA lr LoRA rank GSM8k\uparrow Wiki\downarrow MMLU\uparrow Wino\uparrow
0.0 0 62.9±\pm 1.3 12.6 60.5±\pm 0.4 73.0±\pm 1.2
3E-06 8 63.2±\pm 1.3 12.6 60.8±\pm 0.4 72.8±\pm 1.3
3E-06 16 63.3±\pm 1.3 12.6 60.8±\pm 0.4 72.8±\pm 1.2
3E-06 32 63.3±\pm 1.3 12.6 60.8±\pm 0.4 73.0±\pm 1.2
1E-05 8 63.7±\pm 1.3 12.6 61.0±\pm 0.4 73.1±\pm 1.2
1E-05 16 63.5±\pm 1.3 12.6 60.9±\pm 0.4 73.2±\pm 1.2
1E-05 32 63.9±\pm 1.3 12.6 60.9±\pm 0.4 73.0±\pm 1.2
3E-05 8 64.4±\pm 1.3 12.5 61.1±\pm 0.4 73.0±\pm 1.2
3E-05 16 64.1±\pm 1.3 12.5 61.2±\pm 0.4 72.8±\pm 1.3
3E-05 32 64.1±\pm 1.3 12.5 61.0±\pm 0.4 72.9±\pm 1.2
1E-04 8 66.2±\pm 1.3 12.4 61.1±\pm 0.4 72.9±\pm 1.2
1E-04 16 66.7±\pm 1.3 12.4 61.4±\pm 0.4 72.6±\pm 1.3
1E-04 32 67.0±\pm 1.3 12.4 61.4±\pm 0.4 73.0±\pm 1.2
3E-04 8 65.3±\pm 1.3 12.3 61.2±\pm 0.4 73.5±\pm 1.2
3E-04 16 66.5±\pm 1.3 12.3 61.3±\pm 0.4 73.1±\pm 1.2
3E-04 32 66.8±\pm 1.3 12.3 61.4±\pm 0.4 73.1±\pm 1.2
1E-03 8 0.0±\pm 0.0 21056.1 22.9±\pm 0.4 51.3±\pm 1.4
1E-03 16 59.2±\pm 1.4 13.0 58.1±\pm 0.4 73.2±\pm 1.2
1E-03 32 59.4±\pm 1.4 12.9 58.5±\pm 0.4 72.7±\pm 1.3
Base model 84.4±\pm 1.0 9.5 70.4±\pm 0.4 73.5±\pm 1.2
Table 7: LoRA experiments at 3.25 bits. The first row (LoRA lr = 0, rank = 0) corresponds to the DiscQuant model at 3.25 bits without any LoRA training. The last row corresponds to the full-precision base model.