
Optimizing Deep Learning: A Guide to Weight Quantization

Optimizing deep learning models for resource-constrained environments is increasingly important. Weight quantization addresses this by reducing the precision of model parameters, typically from 32-bit floating-point values to lower bit-width representations such as 8-bit integers. The result is a smaller model that can run more efficiently on constrained hardware. This tutorial walks through weight quantization using PyTorch's dynamic quantization on a pre-trained ResNet18 model. You'll learn how to inspect weight distributions, apply dynamic quantization to the model's fully connected (nn.Linear) layers (in ResNet18, that is the single final fc layer), compare model sizes, and visualize how the weights change. Along the way, it covers both the underlying ideas and the practical steps needed to deploy quantized models.
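To make "reducing precision" concrete, here is a minimal sketch of 8-bit affine quantization on a few made-up values, using `torch.quantize_per_tensor`. Each real value x is stored as an integer code q = round(x / scale) + zero_point, clamped to the int8 range [-128, 127], and recovered as (q - zero_point) * scale. The values and the symmetric choice of scale below are illustrative assumptions, not the parameters PyTorch will pick for ResNet18.

```python
import torch

# Illustrative FP32 values (hypothetical, not taken from ResNet18).
w = torch.tensor([-0.51, -0.12, 0.0, 0.07, 0.33, 0.49])

# Symmetric int8 quantization: zero_point = 0, and scale maps max|w| to code 127.
scale = w.abs().max().item() / 127
zero_point = 0

# Store each value as an 8-bit integer code plus the shared (scale, zero_point).
q = torch.quantize_per_tensor(w, scale=scale, zero_point=zero_point, dtype=torch.qint8)
codes = q.int_repr()          # the raw int8 codes
recovered = q.dequantize()    # (code - zero_point) * scale, back in float32

# The same codes computed by hand from the formula in the text.
manual_codes = torch.clamp(torch.round(w / scale) + zero_point, -128, 127).to(torch.int8)

print("codes:    ", codes.tolist())
print("recovered:", recovered.tolist())
print("max error:", (recovered - w).abs().max().item())
```

Dequantizing recovers each in-range value to within scale / 2; that rounding error is the precision traded for storing 1 byte per value instead of 4.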

import torch
import torch.nn as nn
import torch.quantization
import torchvision.models as models
import matplotlib.pyplot as plt
import numpy as np
import os

print("Torch version:", torch.__version__)

Here, we import the required libraries (PyTorch, torchvision, matplotlib, and NumPy) and print the PyTorch version, so you can confirm which release the rest of the code runs against.
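Dynamic quantization relies on a quantized-kernel backend (an "engine") compiled into your PyTorch build, commonly x86 or fbgemm on Intel/AMD CPUs and qnnpack on ARM. As an optional sanity check before quantizing, a short sketch like the following lists the available engines and picks one only if none is active; the preference order is an assumption, not a PyTorch requirement.

```python
import torch

# Engine names compiled into this PyTorch build.
supported = torch.backends.quantized.supported_engines
print("Supported quantized engines:", supported)
print("Active engine:", torch.backends.quantized.engine)

# Assumed preference order; only applied when no engine is active.
if torch.backends.quantized.engine == "none":
    for candidate in ("x86", "fbgemm", "qnnpack"):
        if candidate in supported:
            torch.backends.quantized.engine = candidate
            break

print("Engine used for quantized ops:", torch.backends.quantized.engine)
```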

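# torchvision >= 0.13 replaces the deprecated pretrained=True with the weights argument
model_fp32 = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)
model_fp32.eval()  # inference mode: BatchNorm uses its running statistics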

print("Pretrained ResNet18 (FP32) model loaded.")

We load ResNet18 with pre-trained ImageNet weights in 32-bit floating-point (FP32) precision and put it in evaluation mode, ready for quantization.
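Because PyTorch's dynamic quantization targets nn.Linear layers, it helps to know how ResNet18's parameters are split across layer types. The helper below (`params_by_layer_type` is a hypothetical name, not a torchvision API) sums each module's own parameters by class; it is demonstrated on a tiny stand-in model so it runs without downloading weights.

```python
from collections import Counter

import torch.nn as nn


def params_by_layer_type(model: nn.Module) -> Counter:
    """Sum parameter counts per module class.

    Only each module's own parameters (recurse=False) are counted, so
    container modules such as nn.Sequential do not double-count children.
    """
    counts = Counter()
    for module in model.modules():
        n = sum(p.numel() for p in module.parameters(recurse=False))
        if n:
            counts[type(module).__name__] += n
    return counts


# Tiny stand-in model (hypothetical): a 3x32x32 input becomes 8x30x30 after the conv.
toy = nn.Sequential(
    nn.Conv2d(3, 8, kernel_size=3),
    nn.BatchNorm2d(8),
    nn.Flatten(),
    nn.Linear(8 * 30 * 30, 10),
)
print(params_by_layer_type(toy))
```

Calling `params_by_layer_type(model_fp32)` on ResNet18 should report 513,000 Linear parameters (512 × 1000 weights plus 1000 biases), a small share of the model's roughly 11.7 million parameters, with nearly all the rest in Conv2d layers.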

fc_weights_fp32 = model_fp32.fc.weight.data.cpu().numpy().flatten()

plt.figure(figsize=(8, 4))
plt.hist(fc_weights_fp32, bins=50, color='skyblue', edgecolor='black')
plt.title("FP32 - FC Layer Weight Distribution")
plt.xlabel("Weight values")
plt.ylabel("Frequency")
plt.grid(True)
plt.show()

Here, we extract and flatten the weights from the final fully connected layer of the FP32 model, then plot a histogram to visualize their distribution before quantization.
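The histogram can be complemented with a few numbers. The sketch below (`describe_weights` is a hypothetical helper) reports the range and spread of a flat weight array such as `fc_weights_fp32`. Assuming a symmetric min/max int8 range, the ratio of max|w| to the standard deviation hints at how much of the 8-bit range the bulk of the weights will actually use.

```python
import numpy as np


def describe_weights(w) -> dict:
    """Numeric companion to the histogram: range and spread of a weight array."""
    w = np.asarray(w, dtype=np.float64).ravel()
    max_abs = float(np.abs(w).max())
    std = float(w.std())
    return {
        "min": float(w.min()),
        "max": float(w.max()),
        "mean": float(w.mean()),
        "std": std,
        # With a symmetric int8 range, one step is about max_abs / 127.5, so a
        # large max_abs / std ratio means most weights fall into a narrow band of codes.
        "max_abs_over_std": max_abs / std if std > 0 else float("inf"),
    }


# Toy input; with the tutorial's data you would pass fc_weights_fp32.
stats = describe_weights(np.array([-0.2, -0.1, 0.0, 0.1, 0.2]))
print(stats)
```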

quantized_model = torch.quantization.quantize_dynamic(model_fp32, {nn.Linear}, dtype=torch.qint8)
quantized_model.eval()  

print("Dynamic quantization applied to the model.")

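Dynamic quantization is applied to the model's nn.Linear layers: their weights are converted ahead of time to 8-bit integers (torch.qint8), while activations are quantized on the fly during inference. Lower-precision weights reduce model size and can lower inference latency for those layers. By default, quantize_dynamic returns a quantized copy and leaves model_fp32 unchanged. Note that convolutional layers, which hold most of ResNet18's parameters, stay in FP32 with this approach.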
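To see what happens to a single weight matrix during this step, the sketch below computes quantization parameters with `MinMaxObserver` from `torch.ao.quantization` and quantizes a random stand-in for the fc weight. It assumes the per-tensor symmetric qint8 weight observer that, to our knowledge, PyTorch's default dynamic qconfig uses; check `torch.ao.quantization.default_dynamic_qconfig` in your version.

```python
import torch
from torch.ao.quantization import MinMaxObserver

torch.manual_seed(0)
# Hypothetical stand-in for model_fp32.fc.weight (1000 x 512), not the real weights.
weight = torch.randn(1000, 512) * 0.05

# Assumed to match the default dynamic-quantization weight observer:
# per-tensor symmetric qint8 with min/max calibration.
observer = MinMaxObserver(dtype=torch.qint8, qscheme=torch.per_tensor_symmetric)
observer(weight)  # records the tensor's min and max
scale, zero_point = observer.calculate_qparams()  # one-element tensors

# Quantize with those parameters and measure the rounding error.
qweight = torch.quantize_per_tensor(weight, scale.item(), int(zero_point.item()), torch.qint8)
max_err = (qweight.dequantize() - weight).abs().max().item()

print(f"scale = {scale.item():.6f}, zero_point = {int(zero_point.item())}")
print(f"max |error| = {max_err:.6f}, scale / 2 = {scale.item() / 2:.6f}")
```

With a symmetric scheme the zero point is 0 and the scale is max|w| divided by 127.5 (half of the 255 steps between -128 and 127), so a single large-magnitude weight widens the step size for every other weight in the tensor.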

def get_model_size(model, filename="temp.p"):
    torch.save(model.state_dict(), filename)
    size = os.path.getsize(filename) / 1e6
    os.remove(filename)
    return size

fp32_size = get_model_size(model_fp32, "fp32_model.p")
quant_size = get_model_size(quantized_model, "quant_model.p")

print(f"FP32 Model Size: {fp32_size:.2f} MB")
print(f"Quantized Model Size: {quant_size:.2f} MB")

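We define a helper that saves a model's state dict to disk, reads the file size in megabytes, and deletes the file. Running it on both models shows how much quantization compresses the model. Because only the fc layer is quantized (1000 × 512 weights, about 2 MB in FP32 and about 0.5 MB in int8), expect a reduction of roughly 1.5 MB rather than a fourfold shrink of the whole model.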
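A variant of the size helper can serialize to an in-memory buffer (`io.BytesIO`) instead of a temporary file. The sketch below applies it to a stand-in model containing only a 512 → 1000 Linear layer, the same shape as ResNet18's fc, to isolate the part that dynamic quantization actually shrinks. `state_dict_size_mb` is a hypothetical name; the layer is wrapped in nn.Sequential because, as far as we know, quantize_dynamic swaps child modules rather than the top-level module itself.

```python
import io

import torch
import torch.nn as nn
from torch.ao.quantization import quantize_dynamic


def state_dict_size_mb(model: nn.Module) -> float:
    """Serialize the state dict into memory and return its size in MB (1e6 bytes)."""
    buffer = io.BytesIO()
    torch.save(model.state_dict(), buffer)
    return buffer.getbuffer().nbytes / 1e6


# Stand-in with the same shape as ResNet18's fc layer (512 inputs, 1000 outputs).
fc_only = nn.Sequential(nn.Linear(512, 1000))
fc_quant = quantize_dynamic(fc_only, {nn.Linear}, dtype=torch.qint8)

print(f"FP32 fc-only model: {state_dict_size_mb(fc_only):.2f} MB")
print(f"int8 fc-only model: {state_dict_size_mb(fc_quant):.2f} MB")
```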

dummy_input = torch.randn(1, 3, 224, 224)

with torch.no_grad():
    output_fp32 = model_fp32(dummy_input)
    output_quant = quantized_model(dummy_input)

print("Output from FP32 model (first 5 elements):", output_fp32[0][:5])
print("Output from Quantized model (first 5 elements):", output_quant[0][:5])

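We create a random tensor with the shape of a single 224 × 224 RGB image and run both models on it, printing the first five logits of each. Close values suggest that quantization changes the outputs only slightly; a random tensor is not a real image, though, so this is a sanity check rather than an accuracy evaluation.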
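Eyeballing five logits is easy but coarse. The sketch below summarizes the drift with a few numbers: maximum absolute difference, relative L2 error, and how often the top-1 class agrees. `compare_outputs` and the small stand-in classifier head are hypothetical; with ResNet18 you would pass the outputs of `model_fp32` and `quantized_model` on a batch of real images.

```python
import torch
import torch.nn as nn
from torch.ao.quantization import quantize_dynamic


def compare_outputs(ref: torch.Tensor, test: torch.Tensor) -> dict:
    """Summarize how far quantized logits drift from the FP32 logits."""
    diff = ref - test
    return {
        "max_abs_diff": diff.abs().max().item(),
        "rel_l2_error": (torch.linalg.norm(diff) / torch.linalg.norm(ref)).item(),
        "top1_agreement": (ref.argmax(dim=1) == test.argmax(dim=1)).float().mean().item(),
    }


torch.manual_seed(0)
# Hypothetical stand-in classifier head built only from Linear layers.
head = nn.Sequential(nn.Linear(512, 256), nn.ReLU(), nn.Linear(256, 10)).eval()
head_q = quantize_dynamic(head, {nn.Linear}, dtype=torch.qint8)

x = torch.randn(64, 512)  # a batch of random features, not real images
with torch.no_grad():
    stats = compare_outputs(head(x), head_q(x))
print(stats)
```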

# In a dynamically quantized Linear layer, weight() is a method that returns the
# qint8 weight tensor; dequantize() maps it back to float values for plotting.
fc_weights_quant = quantized_model.fc.weight().dequantize().cpu().numpy().flatten()

plt.figure(figsize=(14, 5))

plt.subplot(1, 2, 1)
plt.hist(fc_weights_fp32, bins=50, color='skyblue', edgecolor='black')
plt.title("FP32 - FC Layer Weight Distribution")
plt.xlabel("Weight values")
plt.ylabel("Frequency")
plt.grid(True)

plt.subplot(1, 2, 2)
plt.hist(fc_weights_quant, bins=50, color='salmon', edgecolor='black')
plt.title("Quantized - FC Layer Weight Distribution")
plt.xlabel("Weight values")
plt.ylabel("Frequency")
plt.grid(True)

plt.tight_layout()
plt.show()

We extract the quantized weights (after dequantization) from the fully connected layer and compare them with the original FP32 weights using histograms, illustrating the changes due to quantization.
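The histograms can be backed by a direct check. The sketch below (`weight_quant_report` is a hypothetical helper) compares a float Linear weight with its dynamically quantized counterpart: it counts the distinct int8 codes in use (at most 256) and compares the largest rounding error with half the quantization step. It handles both per-tensor and per-channel schemes, since we are not certain which one a given PyTorch version defaults to; with the tutorial's models you would pass `model_fp32.fc` and `quantized_model.fc`.

```python
import torch
import torch.nn as nn
from torch.ao.quantization import quantize_dynamic


def weight_quant_report(fp32_linear: nn.Linear, q_linear) -> dict:
    """Compare a float Linear weight with its dynamically quantized version."""
    q = q_linear.weight()  # quantized (qint8) weight tensor
    if q.qscheme() in (torch.per_tensor_affine, torch.per_tensor_symmetric):
        scale = q.q_scale()
    else:  # per-channel: the largest channel scale bounds every channel's error
        scale = q.q_per_channel_scales().max().item()
    err = (q.dequantize() - fp32_linear.weight.detach()).abs().max().item()
    return {
        "distinct_int8_codes": torch.unique(q.int_repr()).numel(),
        "max_abs_error": err,
        "half_scale": scale / 2,
    }


torch.manual_seed(0)
float_net = nn.Sequential(nn.Linear(512, 1000))  # stand-in with fc's shape
quant_net = quantize_dynamic(float_net, {nn.Linear}, dtype=torch.qint8)  # copy; float_net unchanged
report = weight_quant_report(float_net[0], quant_net[0])
print(report)
```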

In summary, this tutorial walked through weight quantization with PyTorch's dynamic quantization, showing its effect on model size, model outputs, and weight distributions. Through quant
