The per_image_whitening function in TensorFlow (Abadi et al., 2015) normalizes each image on its own, shifting and scaling its pixel values so that the image has zero mean and unit standard deviation. Despite the name, it does not decorrelate pixels the way true (e.g. ZCA) whitening does; it only standardizes them. Putting every input image on a comparable scale in this way can help stabilize the training of deep learning models.
How per_image_whitening Works
For a given image $I$, the function applies the following transformation:
$$I' = \frac{I - \mu}{\sigma_{\text{adj}}}$$
where:
- $I$ is the input image (tensor).
- $\mu$ is the mean pixel value of the image.
- $\sigma_{\text{adj}}$ is the adjusted standard deviation (explained below).
- $I'$ is the whitened image.
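As a quick arithmetic check of the formula, here is a hand-sized example in NumPy (used instead of TensorFlow so the numbers are easy to inspect; the 2 × 2 single-channel image is made up for illustration). The standard deviation is taken with $N$ in the denominator, which is what np.std does by default.

```python
import numpy as np

# A tiny 2x2 single-channel "image" in (H, W, C) layout, chosen for easy arithmetic.
image = np.array([[[0.0], [2.0]],
                  [[4.0], [6.0]]])

mu = image.mean()                                   # (0 + 2 + 4 + 6) / 4 = 3.0
sigma = image.std()                                 # sqrt((9 + 1 + 1 + 9) / 4) = sqrt(5)
sigma_adj = max(sigma, 1.0 / np.sqrt(image.size))   # floor is 1/sqrt(4) = 0.5, so sigma wins

whitened = (image - mu) / sigma_adj
print(whitened.ravel())  # approximately [-1.342, -0.447, 0.447, 1.342]
```

The output has mean 0 and standard deviation 1, as the formula promises.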
Adjusted Standard Deviation Calculation:
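Instead of dividing by the plain standard deviation, TensorFlow puts a lower bound on it so that a uniform image (where $\sigma = 0$) cannot cause a division by zero:
$$\sigma_{\text{adj}} = \max\left(\sigma, \frac{1}{\sqrt{N}}\right)$$
where:
- $\sigma$ is the standard deviation of all pixel values in the image.
- $N$ is the total number of elements in the image (height × width × channels).
No separate small constant $\epsilon$ is added; the lower bound $1/\sqrt{N}$ plays that role.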
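The floor only matters for (nearly) uniform images. A minimal NumPy sketch, using a hypothetical helper adjusted_std that implements the formula above, shows a constant image being mapped to all zeros instead of producing NaNs from 0 / 0:

```python
import numpy as np

def adjusted_std(image):
    """Hypothetical helper: max(sigma, 1/sqrt(N)), with N = H * W * C elements."""
    image = np.asarray(image, dtype=np.float64)
    return max(image.std(), 1.0 / np.sqrt(image.size))

# A perfectly uniform image: every value is 128, so the plain standard deviation is 0.
flat = np.full((4, 4, 3), 128.0)
n = flat.size                          # 4 * 4 * 3 = 48 elements
divisor = adjusted_std(flat)           # falls back to 1 / sqrt(48), roughly 0.144
whitened = (flat - flat.mean()) / divisor

print(n, divisor)
print(np.all(whitened == 0.0))         # True: all zeros, no NaN
```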
Key Properties of per_image_whitening
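- Zero Mean: The mean of all pixel values (pooled across height, width and channels, not computed per channel) becomes 0.
- Unit Variance (with adjustment): The standard deviation becomes 1, except for nearly uniform images whose standard deviation is below $1/\sqrt{N}$; those are divided by the floor instead and end up with a standard deviation below 1.
- Applied Per Image: Unlike batch normalization, which normalizes using statistics computed across a batch of images, this method uses only each image's own statistics.
- Useful for CNNs: Putting every input image on a comparable scale can help convergence.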
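To illustrate the "Applied Per Image" property, the sketch below whitens a dark and a bright image separately and compares that with a simplified batch-wide normalization (one mean and one standard deviation shared by both images). The helper whiten is hypothetical, and the batch-wide variant is only a stand-in for contrast: real batch normalization works per channel and has learned scale and shift parameters.

```python
import numpy as np

def whiten(image):
    """Hypothetical helper: per-image whitening of one (H, W, C) image."""
    image = np.asarray(image, dtype=np.float64)
    adjusted = max(image.std(), 1.0 / np.sqrt(image.size))
    return (image - image.mean()) / adjusted

rng = np.random.default_rng(0)
dark = rng.uniform(0, 50, size=(8, 8, 3))       # a dim image
bright = rng.uniform(200, 255, size=(8, 8, 3))  # a bright image

# Per-image: each image is centered and scaled using only its own statistics.
per_image = [whiten(img) for img in (dark, bright)]

# Simplified batch-wide alternative: one mean and one std shared by both images.
batch = np.stack([dark, bright])
shared = (batch - batch.mean()) / batch.std()

for p, s in zip(per_image, shared):
    # per-image mean, per-image std, mean of the same image under shared statistics
    print(round(p.mean(), 3), round(p.std(), 3), round(s.mean(), 3))
```

With per-image whitening both images end up centered on 0; with shared statistics the dark image stays clearly negative and the bright one clearly positive, so their brightness difference survives normalization.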
TensorFlow Implementation
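In early (pre-1.0) TensorFlow releases the function was called tf.image.per_image_whitening; it has since been renamed tf.image.per_image_standardization, which is the name used here: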
import tensorflow as tf
# Load or define an image tensor (H, W, C)
image = tf.random.normal(shape=[128, 128, 3]) # Example image
# Apply per-image whitening
whitened_image = tf.image.per_image_standardization(image)
# This function performs per-image whitening as:
# I' = (I - mean) / adjusted_std
Modern TensorFlow (TF 2.x) Equivalent
The following function reproduces, step by step, the computation that tf.image.per_image_standardization performs:
import tensorflow as tf
def per_image_whitening(image):
    # Work in float32 so integer (e.g. uint8) images are accepted
    image = tf.cast(image, tf.float32)
    mean = tf.reduce_mean(image)
    stddev = tf.math.reduce_std(image)
    # Floor the divisor at 1/sqrt(N), where N = H * W * C
    adjusted_stddev = tf.math.maximum(stddev, 1.0 / tf.sqrt(tf.cast(tf.size(image), tf.float32)))
    return (image - mean) / adjusted_stddev
# Example image tensor (H, W, C)
image = tf.random.normal([128, 128, 3])
# Apply whitening
whitened_image = per_image_whitening(image)
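The same computation can be sketched in NumPy and extended to a batch with shape (..., H, W, C) by pooling the statistics over the last three axes only. per_image_whitening_np is a hypothetical helper, not a TensorFlow API; it is assumed, not guaranteed, to mirror how tf.image.per_image_standardization treats batched input in TF 2.x.

```python
import numpy as np

def per_image_whitening_np(images):
    """Hypothetical NumPy sketch for one image (H, W, C) or a batch (..., H, W, C).

    Statistics are pooled over the last three axes only, so every image in a
    batch gets its own mean and adjusted standard deviation.
    """
    images = np.asarray(images, dtype=np.float32)     # also accepts uint8 input
    axes = (-3, -2, -1)
    mean = images.mean(axis=axes, keepdims=True)
    std = images.std(axis=axes, keepdims=True)
    n = np.prod(images.shape[-3:])                    # H * W * C elements per image
    adjusted = np.maximum(std, 1.0 / np.sqrt(n))
    return (images - mean) / adjusted

rng = np.random.default_rng(42)
batch = rng.integers(0, 256, size=(4, 32, 32, 3), dtype=np.uint8)
out = per_image_whitening_np(batch)

# Whitening the whole batch at once should match whitening each image separately.
one_by_one = np.stack([per_image_whitening_np(img) for img in batch])
print(np.allclose(out, one_by_one, atol=1e-5))
```

Using keepdims=True keeps the per-image statistics broadcastable against the batch, so no explicit loop is needed.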
Let's break down per-image whitening step by step: we'll generate a sample image, compute its mean and standard deviation, apply whitening, and visualize the intermediate results.
Step 1: Generate a Sample Image
We'll create a random 128 × 128 image with 3 color channels, with pixel values drawn uniformly between 0 and 255, as a stand-in for a real image.
Step 2: Compute the Mean and Standard Deviation
We calculate:
- $\mu$, the mean pixel value
- $\sigma_{\text{adj}}$, the adjusted standard deviation
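For the uniform-noise image used in this walkthrough, the statistics can be predicted in advance: a uniform distribution on [0, 255) has mean 127.5 and standard deviation 255/√12 ≈ 73.6, far above the floor 1/√49152 ≈ 0.0045. A NumPy sketch (a seeded NumPy generator stands in for tf.random.uniform; the seed is arbitrary) confirms this:

```python
import numpy as np

# 128 x 128 x 3 image with values drawn uniformly from [0, 255).
rng = np.random.default_rng(0)
image = rng.uniform(0.0, 255.0, size=(128, 128, 3))

n = image.size                   # 128 * 128 * 3 = 49152 elements
mu = image.mean()                # expected near 255 / 2 = 127.5
sigma = image.std()              # expected near 255 / sqrt(12), about 73.6
floor = 1.0 / np.sqrt(n)         # about 0.0045
sigma_adj = max(sigma, floor)    # the floor does not bind for an image like this

print(n, round(mu, 1), round(sigma, 1), sigma_adj == sigma)
```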
Step 3: Apply Whitening Transformation
Each pixel is transformed using:
$$I' = \frac{I - \mu}{\sigma_{\text{adj}}}$$
where $\sigma_{\text{adj}}$ is adjusted to avoid division by zero.
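Applying the transformation is a single broadcast expression, since $\mu$ and $\sigma_{\text{adj}}$ are scalars for the whole image. In a NumPy sketch (same assumed seed and uniform sample image as a stand-in for the TensorFlow tensor), the result has mean ≈ 0 and standard deviation ≈ 1, and for uniform input it stays within about ±√3 ≈ ±1.73:

```python
import numpy as np

rng = np.random.default_rng(0)
image = rng.uniform(0.0, 255.0, size=(128, 128, 3))

mu = image.mean()
sigma_adj = max(image.std(), 1.0 / np.sqrt(image.size))

# The same scalar mu and sigma_adj are broadcast over every pixel and channel.
whitened = (image - mu) / sigma_adj

print(round(whitened.mean(), 6), round(whitened.std(), 6))  # about 0 and 1
print(whitened.min(), whitened.max())                        # roughly -1.73 and 1.73
```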
Step 4: Visualize the Changes
We'll show:
- The original image.
- The pixel intensity histogram before and after whitening.
- The whitened image.
Python Code for Visualization
We'll use TensorFlow for preprocessing and Matplotlib for visualization.
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
# Generate a random image (128x128 with 3 channels), values in [0, 255)
image = tf.random.uniform(shape=[128, 128, 3], minval=0, maxval=255, dtype=tf.float32)
# Per-image whitening that also returns the intermediate statistics
def per_image_whitening(image):
    image = tf.cast(image, tf.float32)
    mean = tf.reduce_mean(image)
    stddev = tf.math.reduce_std(image)
    # Floor the divisor at 1/sqrt(N), where N = H * W * C
    adjusted_stddev = tf.maximum(stddev, 1.0 / tf.sqrt(tf.cast(tf.size(image), tf.float32)))
    whitened_image = (image - mean) / adjusted_stddev
    return whitened_image, mean, stddev, adjusted_stddev
# Apply whitening
whitened_image, mean, stddev, adjusted_stddev = per_image_whitening(image)
# Convert to NumPy for plotting
image_np = image.numpy()
whitened_np = whitened_image.numpy()
# imshow expects float RGB values in [0, 1], so rescale both images for display only
image_display = image_np / 255.0
whitened_display = (whitened_np - whitened_np.min()) / (whitened_np.max() - whitened_np.min())
fig, ax = plt.subplots(1, 4, figsize=(20, 5))
ax[0].imshow(image_display)
ax[0].set_title(f"Original Image\nMean: {mean.numpy():.2f}, Std: {stddev.numpy():.2f}")
ax[0].axis("off")
# Histograms use the actual pixel values, not the display-rescaled ones
ax[1].hist(image_np.ravel(), bins=50, color='blue', alpha=0.7)
ax[1].set_title("Pixel Intensity Histogram (Before Whitening)")
ax[2].hist(whitened_np.ravel(), bins=50, color='red', alpha=0.7)
ax[2].set_title("Pixel Intensity Histogram (After Whitening)")
ax[3].imshow(whitened_display)
ax[3].set_title(f"Whitened Image (display rescaled)\nMean: {whitened_np.mean():.2f}, Std: {whitened_np.std():.2f}")
ax[3].axis("off")
plt.tight_layout()
plt.show()
Expected Results
- Original Image: Displays the raw image with its original color distribution.
- Histogram Before Whitening: Shows a roughly flat spread of pixel values between 0 and 255, since the sample image is uniform noise.
- Histogram After Whitening: Shows the same shape, now centered on zero with unit standard deviation.
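Because whitening only shifts the values and divides them by a positive constant, the two histograms should have the same shape; only the x-axis changes. A NumPy sketch (same assumed uniform sample image, with np.histogram standing in for Matplotlib's plot) compares the bin counts directly:

```python
import numpy as np

rng = np.random.default_rng(0)
image = rng.uniform(0.0, 255.0, size=(128, 128, 3))
whitened = (image - image.mean()) / max(image.std(), 1.0 / np.sqrt(image.size))

# With an integer bin count, np.histogram spans each array's own [min, max] range.
counts_before, edges_before = np.histogram(image.ravel(), bins=50)
counts_after, edges_after = np.histogram(whitened.ravel(), bins=50)

# An affine map preserves the binning, so the counts agree except, at most,
# for a value that floating-point rounding pushes across a bin edge.
print(np.abs(counts_before - counts_after).sum())
print(edges_before[[0, -1]], edges_after[[0, -1]])
```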
Conclusion
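The adjusted standard deviation keeps the division well defined, even for a completely uniform image.
per_image_whitening operates on each image independently, unlike batch normalization, which relies on statistics computed across a batch.
It is commonly used in preprocessing pipelines for image models.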