Neural Network Quantization: From Theory to Browser
Experience neural network quantization interactively in your browser
What is Quantization?
Quantization is a technique that converts neural network weights and activations from high-precision representations (e.g., 32-bit floating-point) to low-precision formats (e.g., 8-bit integers). This significantly reduces model size and computational cost, with a trade-off in accuracy.
Core challenge: Finding the optimal balance between precision and efficiency.
Mathematical Foundation
Quantization is essentially a mapping function from continuous floating-point space to discrete integer space:
\[q = \text{round}\left(\frac{r}{s}\right) + z\]
Where:
- $r$ is the real-valued floating-point number
- $q$ is the quantized integer value
- $s$ is the scale factor
- $z$ is the zero-point offset
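Before the interactive version, here is the mapping worked through in plain JavaScript (no jax-js; the [-1, 1] value range and unsigned 8-bit target are illustrative choices, not from the article's demos):

```javascript
// Affine quantization of a single value, following q = round(r / s) + z.
// Suppose real values span [-1.0, 1.0] and we target unsigned 8-bit, [0, 255].
const rMin = -1.0, rMax = 1.0;
const qMin = 0, qMax = 255;

const scale = (rMax - rMin) / (qMax - qMin);       // step size s ≈ 0.00784
const zeroPoint = Math.round(qMin - rMin / scale); // z = 128: where r = 0 lands

function quantizeValue(r) {
  const q = Math.round(r / scale) + zeroPoint;
  return Math.min(qMax, Math.max(qMin, q));        // clip to the integer range
}
function dequantizeValue(q) {
  return scale * (q - zeroPoint);                  // r = s · (q − z)
}

const q = quantizeValue(0.5);
console.log(q, dequantizeValue(q)); // round-trips 0.5 to within one step s
```

Every representable value is a multiple of $s$ offset by $z$, so the round-trip error is at most $s/2$.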
The dequantization process:
\[r = s \cdot (q - z)\]
🎮 Interactive Quantization Simulator
Let’s implement a real-time quantization simulator using jax-js in the browser. You can adjust bit-width and observe precision loss directly.
<script type="module">
import { numpy as np, random } from "https://esm.sh/@jax-js/jax";
function roundToEven(x) {
const floorX = np.floor(x.ref);
const frac = x.sub(floorX.ref);
const gtHalf = frac.ref.greater(0.5);
const eqHalf = frac.equal(0.5);
const ceilX = floorX.ref.add(1);
const rounded = np.where(gtHalf, ceilX.ref, floorX.ref);
const floorIsEven = np.floor(floorX.ref.div(2)).mul(2).equal(floorX.ref);
const tieRounded = np.where(floorIsEven, floorX, ceilX);
return np.where(eqHalf, tieRounded, rounded);
}
// Quantization function
function quantize(x, bits) {
const qmin = 0;
const qmax = (1 << bits) - 1;
const xForMin = x.ref;
const xForMax = x.ref;
const xMin = xForMin.min();
const xMax = xForMax.max();
// Compute scale and zero_point
const scale = xMax.sub(xMin.ref).div(qmax - qmin);
const zeroPoint = roundToEven(xMin.div(scale.ref).mul(-1).add(qmin));
// Quantize
const xScaled = x.div(scale.ref).add(zeroPoint.ref);
const xQuantized = np.clip(roundToEven(xScaled), qmin, qmax);
return { quantized: xQuantized, scale, zeroPoint };
}
// Dequantization function
function dequantize(xq, scale, zeroPoint) {
return xq.sub(zeroPoint).mul(scale);
}
// Create test data
const key = random.key(95);
const weights = random.normal(key, [4, 4]).mul(0.5);
// 8-bit quantization
const { quantized, scale, zeroPoint } = quantize(weights.ref, 8);
const recovered = dequantize(quantized, scale, zeroPoint);
// Compute quantization error
const error = np.abs(weights.sub(recovered)).mean();
console.log("Quantization error:", error.item());
</script>
Try it yourself (demo runs when visible):
Key insight: lower bit-widths mean coarser discretization and therefore larger quantization error. In practice, 8-bit quantization typically retains over 99% of the original model's accuracy.
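The trend the simulator visualizes can be reproduced in a few lines of plain JavaScript (no jax-js; the evenly spaced sample data is illustrative):

```javascript
// Mean absolute round-trip error of affine quantization at several bit-widths.
function roundTripError(values, bits) {
  const qMax = (1 << bits) - 1;
  const rMin = Math.min(...values), rMax = Math.max(...values);
  const scale = (rMax - rMin) / qMax;
  const zp = Math.round(-rMin / scale);
  let total = 0;
  for (const r of values) {
    const q = Math.min(qMax, Math.max(0, Math.round(r / scale) + zp));
    total += Math.abs(r - scale * (q - zp)); // |original − dequantized|
  }
  return total / values.length;
}

// 256 evenly spaced samples in [-1, 1]
const xs = Array.from({ length: 256 }, (_, i) => -1 + i / 127.5);
for (const bits of [2, 4, 8]) {
  console.log(bits, "bits →", roundTripError(xs, bits).toFixed(5));
}
```

Halving the bit-width roughly squares the number of values forced onto each quantization level, so the error grows sharply below 8 bits.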
🔥 Symmetric vs Affine Quantization
Symmetric quantization constrains the zero-point to 0, simplifying computation:
// Symmetric quantization: the zero-point is fixed at 0
function symmetricQuantize(x, bits) {
// Signed range: e.g. bits = 8 → qmax = 127
const qmax = (1 << (bits - 1)) - 1;
// Scale maps the largest absolute value onto qmax
const scale = np.abs(x.ref).max().div(qmax);
// Round half up via floor(x + 0.5), then clip to [-qmax, qmax]
const xQuantized = np.clip(
np.floor(x.div(scale.ref).add(0.5)),
-qmax,
qmax
);
return { quantized: xQuantized, scale };
}
Advantages:
- Zero-point elimination simplifies multiplication operations
- Suitable for symmetric weight distributions
- Hardware-friendly for accelerated inference
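The key property, sketched in plain JavaScript with made-up numbers: with $z = 0$, a real value of exactly zero quantizes to the integer 0, so zero padding and ReLU outputs stay exact after quantization:

```javascript
// Symmetric quantization: scale from max |x|, zero-point fixed at 0.
function symQuantize(values, bits) {
  const qMax = (1 << (bits - 1)) - 1;          // e.g. 127 for 8 bits
  const scale = Math.max(...values.map(Math.abs)) / qMax;
  const quantized = values.map((r) =>
    Math.min(qMax, Math.max(-qMax, Math.round(r / scale)))
  );
  return { quantized, scale };
}

const { quantized, scale } = symQuantize([-0.8, -0.1, 0.0, 0.4, 0.8], 8);
console.log(quantized);                        // 0.0 maps exactly to integer 0
console.log(quantized.map((q) => q * scale));  // dequantize is just q · s, no offset
```

Because dequantization is a single multiply, symmetric schemes avoid the zero-point correction terms that affine quantization needs in matrix products.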
⚡ Quantized Matrix Multiplication
The core value of quantization lies in accelerating matrix operations. Here’s how to replace floating-point arithmetic with integer operations:
// Quantized matrix multiplication
function quantizedMatmul(A, B, scaleA, zpA, scaleB, zpB) {
// Integer matrix multiplication
const C = np.matmul(A, B);
// Output scale
const scaleC = scaleA.mul(scaleB);
// Zero-point correction term (can be precomputed)
const zpCorrection = zpA.ref.mul(B.sum(0)).add(
zpB.ref.mul(A.sum(1, true))
).sub(zpA.mul(zpB).mul(A.shape[1]));
return { result: C.sub(zpCorrection), scale: scaleC };
}
Performance gains: on hardware with native integer support, int8 matrix multiplication can run 2-4× faster than float32, and storing values in 8 bits cuts memory use by 75% relative to float32.
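The correction term can be sanity-checked in plain JavaScript on tiny matrices (all scales, zero-points, and codes below are made-up illustrative values): the corrected integer product, rescaled, should match the float product of the dequantized matrices.

```javascript
// Verify: sA·sB·(Σ aq·bq − zA·colsum(Bq) − zB·rowsum(Aq) + K·zA·zB) = A·B,
// where A and B are the dequantized real matrices and K is the inner dimension.
const sA = 0.1, zA = 3, sB = 0.2, zB = 5;
const Aq = [[10, 20], [30, 40]]; // quantized codes of A (made-up)
const Bq = [[7, 8], [9, 10]];    // quantized codes of B (made-up)
const K = 2;                      // inner dimension

// Dequantized real matrices: r = s · (q − z)
const A = Aq.map((row) => row.map((q) => sA * (q - zA)));
const B = Bq.map((row) => row.map((q) => sB * (q - zB)));

function matmul(X, Y) {
  return X.map((row, i) =>
    Y[0].map((_, j) => row.reduce((acc, _, k) => acc + X[i][k] * Y[k][j], 0))
  );
}

const floatC = matmul(A, B);   // reference float result
const intC = matmul(Aq, Bq);   // pure integer product
// Apply zero-point correction, then rescale by sA·sB
const corrected = intC.map((row, i) =>
  row.map((c, j) => {
    const colSumB = Bq.reduce((acc, r) => acc + r[j], 0);
    const rowSumA = Aq[i].reduce((acc, v) => acc + v, 0);
    return sA * sB * (c - zA * colSumB - zB * rowSumA + K * zA * zB);
  })
);
console.log(floatC, corrected); // match to floating-point precision
```

Since the codes here are exact integers, the only discrepancy is floating-point rounding; in a real pipeline the residual error comes from quantizing A and B in the first place.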
📊 Visualizing Weight Distribution Changes
Interactive visualization (loads when visible):
Observations:
- 2-bit: Significant distortion, only for extreme compression scenarios
- 4-bit: Usable with Quantization-Aware Training (QAT)
- 8-bit: Nearly lossless, industry standard
- 16-bit: Virtually identical to float32
🤖 Model Inference Comparison
Interactive inference demo (loads when visible):
Key finding: 8-bit quantization typically preserves predictions with output difference < 0.01.
🎨 Error Heatmap: Visualizing Quantization Loss
Interactive heatmap (loads when visible, click cells to see values):
Insights:
- Deep blue regions: Quantization error < 0.001 (nearly perfect)
- Purple regions: Moderate error 0.001-0.01
- Red regions: High error > 0.01 (requires attention)
Most weights exhibit minimal quantization error; only a few outliers need special handling.
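Why outliers deserve special handling can be seen in plain JavaScript with made-up numbers: a single large weight stretches the scale, so every other value gets fewer effective levels and a larger round-trip error.

```javascript
// Mean absolute error of symmetric 8-bit quantization,
// with and without a single outlier in the tensor.
function meanQuantError(values, bits = 8) {
  const qMax = (1 << (bits - 1)) - 1;
  const scale = Math.max(...values.map(Math.abs)) / qMax;
  return (
    values.reduce((acc, r) => {
      const q = Math.min(qMax, Math.max(-qMax, Math.round(r / scale)));
      return acc + Math.abs(r - q * scale);
    }, 0) / values.length
  );
}

const typical = Array.from({ length: 100 }, (_, i) => -0.5 + i / 99); // in [-0.5, 0.5]
const withOutlier = [...typical, 10.0]; // one extreme weight

console.log("no outlier  :", meanQuantError(typical));
console.log("with outlier:", meanQuantError(withOutlier)); // scale is ~20× larger
```

This is the motivation for techniques such as per-channel scales and group quantization: they confine each outlier's influence to a small slice of the tensor.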
🚀 Performance Arena: Float vs Integer
import { jit, numpy as np, random } from "https://esm.sh/@jax-js/jax";
// JIT-compiled float32 matmul
const float32Matmul = jit((A, B) => np.matmul(A, B));
// JIT-compiled "int8" matmul. Note: jax-js dispatches the same kernel here;
// the quantized inputs below hold rounded integer values (stored as floats
// in this demo), so this benchmark isolates the effect of the data.
const int8Matmul = jit((A, B) => np.matmul(A, B));
// Benchmark function
function benchmark(fn, A, B, iterations = 100) {
const start = performance.now();
for (let i = 0; i < iterations; i++) {
fn(A.ref, B.ref);
}
return (performance.now() - start) / iterations;
}
// Create test matrices
const key = random.key(95);
const [keyA, keyB] = random.split(key, 2);
const A_float = random.normal(keyA, [128, 128]);
const B_float = random.normal(keyB, [128, 128]);
// `quantize` is the affine quantization function defined in the simulator above
const { quantized: A_int } = quantize(A_float.ref, 8);
const { quantized: B_int } = quantize(B_float.ref, 8);
const timeFloat = benchmark(float32Matmul, A_float, B_float);
const timeInt = benchmark(int8Matmul, A_int, B_int);
console.log(`Float32: ${timeFloat.toFixed(2)}ms`);
console.log(`Int8: ${timeInt.toFixed(2)}ms`);
console.log(`Speedup: ${(timeFloat / timeInt).toFixed(2)}x`);
💡 Practical Guidelines
When to use quantization?
- Edge deployment: Mobile devices and IoT with limited memory and compute
- Large-scale inference: Cloud services requiring cost reduction
- Real-time applications: Latency-sensitive scenarios (speech recognition, video processing)
Quantization strategy selection:
| Scenario | Recommended Approach | Accuracy Loss |
|---|---|---|
| Image Classification | 8-bit Post-Training Quantization | < 1% |
| Object Detection | 8-bit + QAT | < 2% |
| Large Language Models | 4-bit Group Quantization | 2-5% |
| Generative Models | Mixed Precision (critical layers 16-bit) | < 3% |
Best practices:
- ❌ Don’t quantize Batch Normalization layers (numerically unstable)
- ✅ Fuse Conv + BN + ReLU before quantization
- ❌ Don’t use uniform bit-width for all layers
- ✅ Maintain higher precision for sensitive layers (first/last layers)
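The Conv + BN fusion recommended above folds the BatchNorm statistics into the preceding layer's weights before quantization, so BN never has to be quantized on its own. A minimal sketch in plain JavaScript for one output channel (the weights, bias, and BN statistics are all illustrative made-up values):

```javascript
// Fold BatchNorm into the preceding linear/conv weights:
//   y = gamma · (w·x + b − mean) / sqrt(var + eps) + beta
//     = (gamma / sqrt(var + eps)) · w · x + folded bias
function foldBatchNorm(w, b, { gamma, beta, mean, variance, eps = 1e-5 }) {
  const k = gamma / Math.sqrt(variance + eps);
  return {
    w: w.map((wi) => wi * k), // scaled weights: quantize these instead
    b: (b - mean) * k + beta, // folded bias
  };
}

// Check the fold on one input vector (all numbers illustrative)
const w = [0.2, -0.5, 0.1], b = 0.05;
const bn = { gamma: 1.2, beta: -0.3, mean: 0.4, variance: 0.25 };
const x = [1.0, 2.0, 3.0];

const pre = w.reduce((acc, wi, i) => acc + wi * x[i], 0) + b; // layer output
const bnOut = bn.gamma * (pre - bn.mean) / Math.sqrt(bn.variance + 1e-5) + bn.beta;

const folded = foldBatchNorm(w, b, bn);
const foldedOut = folded.w.reduce((acc, wi, i) => acc + wi * x[i], 0) + folded.b;
console.log(bnOut, foldedOut); // equal: quantize the folded layer
```

After folding, the network contains one fewer op per layer and the quantizer only ever sees ordinary weights and biases.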
🎯 Summary
Quantization is a fundamental technique for model compression. Through interactive demonstrations using jax-js in the browser, we can:
- Intuitively understand the mathematical principles and precision trade-offs
- Compare in real-time different quantization strategies
- Experiment at zero cost without requiring GPU infrastructure
Next steps to explore:
- Implement Quantization-Aware Training (QAT)
- Explore mixed-precision quantization
- Compare quantization frameworks (TensorRT, ONNX Runtime)
Interactive Demo: Visit Live Demo to experience the complete quantization simulator