
Neural Network Quantization: From Theory to Browser

Experience neural network quantization interactively in your browser

What is Quantization?

Quantization is a technique that converts neural network weights and activations from high-precision representations (e.g., 32-bit floating-point) to low-precision formats (e.g., 8-bit integers). This significantly reduces model size and computational cost, with a trade-off in accuracy.

Core challenge: Finding the optimal balance between precision and efficiency.

Mathematical Foundation

Quantization is essentially a mapping function from continuous floating-point space to discrete integer space:

\[q = \text{round}\left(\frac{r}{s}\right) + z\]

Where:

  • $r$ is the real-valued floating-point number
  • $q$ is the quantized integer value
  • $s$ is the scale factor
  • $z$ is the zero-point offset

The dequantization process:

\[r = s \cdot (q - z)\]
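These two formulas can be sanity-checked in a few lines of plain JavaScript (the helper names `quantizeArray` and `dequantizeArray` are illustrative, not part of any library):

```javascript
// Affine quantization of a small array to unsigned n-bit integers.
function quantizeArray(xs, bits) {
  const qmin = 0;
  const qmax = (1 << bits) - 1;
  const rmin = Math.min(...xs);
  const rmax = Math.max(...xs);
  const s = (rmax - rmin) / (qmax - qmin);   // scale
  const z = Math.round(qmin - rmin / s);     // zero-point
  const qs = xs.map((r) =>
    Math.min(qmax, Math.max(qmin, Math.round(r / s) + z))
  );
  return { qs, s, z };
}

function dequantizeArray(qs, s, z) {
  return qs.map((q) => s * (q - z));         // r = s * (q - z)
}

const xs = [-1.0, -0.25, 0.0, 0.5, 1.0];
const { qs, s, z } = quantizeArray(xs, 8);
const rs = dequantizeArray(qs, s, z);
// The round-trip error per element stays within about s/2.
```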

🎮 Interactive Quantization Simulator

Let’s implement a real-time quantization simulator using jax-js in the browser. You can adjust bit-width and observe precision loss directly.

<script type="module">
import { numpy as np, random } from "https://esm.sh/@jax-js/jax";

function roundToEven(x) {
  const floorX = np.floor(x.ref);
  const frac = x.sub(floorX.ref);
  const gtHalf = frac.ref.greater(0.5);
  const eqHalf = frac.equal(0.5);
  const ceilX = floorX.ref.add(1);
  const rounded = np.where(gtHalf, ceilX.ref, floorX.ref);
  const floorIsEven = np.floor(floorX.ref.div(2)).mul(2).equal(floorX.ref);
  const tieRounded = np.where(floorIsEven, floorX, ceilX);
  return np.where(eqHalf, tieRounded, rounded);
}

// Quantization function
function quantize(x, bits) {
  const qmin = 0;
  const qmax = (1 << bits) - 1;

  const xForMin = x.ref;
  const xForMax = x.ref;
  const xMin = xForMin.min();
  const xMax = xForMax.max();

  // Compute scale and zero_point
  const scale = xMax.sub(xMin.ref).div(qmax - qmin);
  const zeroPoint = roundToEven(xMin.div(scale.ref).mul(-1).add(qmin));

  // Quantize
  const xScaled = x.div(scale.ref).add(zeroPoint.ref);
  const xQuantized = np.clip(roundToEven(xScaled), qmin, qmax);

  return { quantized: xQuantized, scale, zeroPoint };
}

// Dequantization function
function dequantize(xq, scale, zeroPoint) {
  return xq.sub(zeroPoint).mul(scale);
}

// Create test data
const key = random.key(95);
const weights = random.normal(key, [4, 4]).mul(0.5);

// 8-bit quantization
const { quantized, scale, zeroPoint } = quantize(weights.ref, 8);
const recovered = dequantize(quantized, scale, zeroPoint);

// Compute quantization error
const error = np.abs(weights.sub(recovered)).mean();
console.log("Quantization error:", error.item());
</script>

Try it yourself (demo runs when visible):


Key insight: Lower bit-widths coarsen the discretization and increase error. In practice, 8-bit quantization typically preserves more than 99% of full-precision accuracy.
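The trend is easy to reproduce outside the demo with a small plain-JavaScript sweep over bit-widths (a self-contained sketch; `roundTripError` is an illustrative helper mirroring the affine scheme above):

```javascript
// Mean absolute round-trip error of affine quantization at a given bit-width.
function roundTripError(xs, bits) {
  const qmax = (1 << bits) - 1;
  const rmin = Math.min(...xs);
  const rmax = Math.max(...xs);
  const s = (rmax - rmin) / qmax;
  const z = Math.round(-rmin / s);
  let total = 0;
  for (const r of xs) {
    const q = Math.min(qmax, Math.max(0, Math.round(r / s) + z));
    total += Math.abs(r - s * (q - z));
  }
  return total / xs.length;
}

// Deterministic pseudo-random sample in [-1, 1).
const xs = Array.from({ length: 1000 }, (_, i) => Math.sin(i * 12.9898));
for (const bits of [2, 4, 8, 16]) {
  console.log(bits, roundTripError(xs, bits)); // error shrinks as bits grow
}
```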

🔥 Symmetric vs Affine Quantization

Symmetric quantization constrains the zero-point to 0, simplifying computation:

function symmetricQuantize(x, bits) {
  const qmax = (1 << (bits - 1)) - 1;
  const scale = np.abs(x.ref).max().div(qmax);

  const xQuantized = np.clip(
    np.floor(x.div(scale.ref).add(0.5)),
    -qmax,
    qmax
  );

  return { quantized: xQuantized, scale };
}

Advantages:

  • Zero-point elimination simplifies multiplication operations
  • Suitable for symmetric weight distributions
  • Hardware-friendly for accelerated inference
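To see the zero-point elimination concretely, here is the same scheme on a plain JavaScript array (illustrative helpers, independent of the jax-js snippet above): dequantization collapses to a single multiply.

```javascript
// Symmetric quantization to signed n-bit: q = clamp(round(r / s), -qmax, qmax).
function symmetricQuantizeArray(xs, bits) {
  const qmax = (1 << (bits - 1)) - 1;
  const s = Math.max(...xs.map(Math.abs)) / qmax;
  const qs = xs.map((r) =>
    Math.min(qmax, Math.max(-qmax, Math.round(r / s)))
  );
  return { qs, s };
}

const xs = [-0.8, -0.1, 0.4, 0.8];
const { qs, s } = symmetricQuantizeArray(xs, 8);
// Dequantization is just r ≈ s * q — no zero-point term to subtract.
const rs = qs.map((q) => s * q);
```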

⚡ Quantized Matrix Multiplication

The core value of quantization lies in accelerating matrix operations. Here’s how to replace floating-point arithmetic with integer operations:

// Quantized matrix multiplication
function quantizedMatmul(A, B, scaleA, zpA, scaleB, zpB) {
  // Integer matrix multiplication
  const C = np.matmul(A, B);

  // Output scale
  const scaleC = scaleA.mul(scaleB);

  // Zero-point correction term (can be precomputed)
  const zpCorrection = zpA.ref.mul(B.sum(0)).add(
    zpB.ref.mul(A.sum(1, true))
  ).sub(zpA.mul(zpB).mul(A.shape[1]));

  return { result: C.sub(zpCorrection), scale: scaleC };
}

Performance gains: int8 storage uses 75% less memory than float32, and integer arithmetic typically runs 2-4× faster on hardware with native int8 support.
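The zero-point correction can be verified on a single entry of a tiny integer matmul in plain JavaScript: expanding sum_k (a − zpA)(b − zpB) term by term yields exactly the correction used above (a sketch of the algebra, not the jax-js API):

```javascript
// Verify: sum_k (a - zpA)(b - zpB)
//       = sum_k a*b - zpA * colsum(B) - zpB * rowsum(A) + K * zpA * zpB
const A = [[3, 7], [1, 5]];          // quantized integer values
const B = [[2, 6], [4, 8]];
const zpA = 2, zpB = 3;
const K = 2;                          // inner dimension

function entry00(f) {
  let acc = 0;
  for (let k = 0; k < K; k++) acc += f(A[0][k], B[k][0]);
  return acc;
}

// Direct: multiply the zero-point-shifted values.
const direct = entry00((a, b) => (a - zpA) * (b - zpB));

// Corrected: raw integer matmul plus precomputable correction terms.
const raw = entry00((a, b) => a * b);
const colSumB = B[0][0] + B[1][0];
const rowSumA = A[0][0] + A[0][1];
const corrected = raw - zpA * colSumB - zpB * rowSumA + K * zpA * zpB;

console.log(direct === corrected);   // → true
```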

📊 Visualizing Weight Distribution Changes

Interactive visualization (loads when visible):

Observations:

  • 2-bit: Significant distortion, only for extreme compression scenarios
  • 4-bit: Usable with Quantization-Aware Training (QAT)
  • 8-bit: Nearly lossless, industry standard
  • 16-bit: Virtually identical to float32

🤖 Model Inference Comparison

Interactive inference demo (loads when visible):


Key finding: 8-bit quantization typically preserves predictions with output difference < 0.01.

🎨 Error Heatmap: Visualizing Quantization Loss

Interactive heatmap (loads when visible, click cells to see values):

Insights:

  • Deep blue regions: Quantization error < 0.001 (nearly perfect)
  • Purple regions: Moderate error 0.001-0.01
  • Red regions: High error > 0.01 (requires attention)

Most weights exhibit minimal quantization error; only a few outliers need special handling.

🚀 Performance Arena: Float vs Integer

import { jit, numpy as np, random } from "https://esm.sh/@jax-js/jax";

// JIT-compiled float32 matmul
const float32Matmul = jit((A, B) => np.matmul(A, B));

// JIT-compiled matmul reused for the quantized tensors (same kernel; any
// speedup depends on the backend dispatching a lower-precision path for
// the quantized dtype)
const int8Matmul = jit((A, B) => np.matmul(A, B));

// Benchmark function (warm-up call keeps JIT compile time out of the timing)
function benchmark(fn, A, B, iterations = 100) {
  fn(A.ref, B.ref); // warm-up: triggers compilation on the first call
  const start = performance.now();
  for (let i = 0; i < iterations; i++) {
    fn(A.ref, B.ref);
  }
  return (performance.now() - start) / iterations;
}

// Create test matrices
const key = random.key(95);
const [keyA, keyB] = random.split(key, 2);
const A_float = random.normal(keyA, [128, 128]);
const B_float = random.normal(keyB, [128, 128]);

const { quantized: A_int } = quantize(A_float.ref, 8);
const { quantized: B_int } = quantize(B_float.ref, 8);

const timeFloat = benchmark(float32Matmul, A_float, B_float);
const timeInt = benchmark(int8Matmul, A_int, B_int);

console.log(`Float32: ${timeFloat.toFixed(2)}ms`);
console.log(`Int8: ${timeInt.toFixed(2)}ms`);
console.log(`Speedup: ${(timeFloat / timeInt).toFixed(2)}x`);

💡 Practical Guidelines

When to use quantization?

  1. Edge deployment: Mobile devices and IoT with limited memory and compute
  2. Large-scale inference: Cloud services requiring cost reduction
  3. Real-time applications: Latency-sensitive scenarios (speech recognition, video processing)

Quantization strategy selection:

| Scenario | Recommended Approach | Accuracy Loss |
| --- | --- | --- |
| Image Classification | 8-bit Post-Training Quantization | < 1% |
| Object Detection | 8-bit + QAT | < 2% |
| Large Language Models | 4-bit Group Quantization | 2-5% |
| Generative Models | Mixed Precision (critical layers 16-bit) | < 3% |

Best practices:

  • ❌ Don’t quantize Batch Normalization layers (numerically unstable)
  • ✅ Fuse Conv + BN + ReLU before quantization
  • ❌ Don’t use uniform bit-width for all layers
  • ✅ Maintain higher precision for sensitive layers (first/last layers)
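BN folding is a closed-form rewrite: with BN parameters γ, β, μ, σ² and gain g = γ/√(σ² + ε), the fused weights are W·g and the fused bias is β + (b − μ)·g. A minimal per-channel sketch in plain JavaScript (the function name and data layout are illustrative):

```javascript
// Fold BatchNorm(gamma, beta, mean, variance) into a preceding layer's (W, b),
// per output channel c: W'[c] = W[c] * g, b'[c] = beta[c] + (b[c] - mean[c]) * g,
// where g = gamma[c] / sqrt(variance[c] + eps).
function foldBatchNorm(W, b, { gamma, beta, mean, variance, eps = 1e-5 }) {
  return W.map((row, c) => {
    const g = gamma[c] / Math.sqrt(variance[c] + eps);
    return {
      weights: row.map((w) => w * g),
      bias: beta[c] + (b[c] - mean[c]) * g,
    };
  });
}

// One output channel with two weights, as a minimal check (eps = 0 so g = 2).
const fused = foldBatchNorm(
  [[1.0, -2.0]], [0.5],
  { gamma: [2.0], beta: [0.1], mean: [0.5], variance: [1.0], eps: 0 }
);
// weights become [2, -4]; bias becomes 0.1 + (0.5 - 0.5) * 2 = 0.1
```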

🎯 Summary

Quantization is a fundamental technique for model compression. Through interactive demonstrations using jax-js in the browser, we can:

  1. Intuitively understand the mathematical principles and precision trade-offs
  2. Compare in real-time different quantization strategies
  3. Experiment at zero cost without requiring GPU infrastructure

Next steps to explore:

  • Implement Quantization-Aware Training (QAT)
  • Explore mixed-precision quantization
  • Compare quantization frameworks (TensorRT, ONNX Runtime)

Interactive Demo: Visit Live Demo to experience the complete quantization simulator


This post is licensed under CC BY 4.0 by the author.