An image consists of matrices of pixel values, which can be represented as a tf.Tensor. Feeding raw 8-bit or 16-bit integer tensors to a CNN is not recommended, as training may fail to converge. This post shows how a 16-bit RGB image tensor that does not use its full value range can be normalized to a float64 tensor that is rescaled channel-wise to span the full [0,1] range.

First, let’s look at the RGB image tensor:

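import tensorflow as tf

def stats(tensor):
    # Cast to float first: tf.math.reduce_mean on an integer tensor
    # computes the mean in integer arithmetic and truncates it.
    tensor = tf.cast(tensor, tf.float64)
    return f'min: {tf.math.reduce_min(tensor)}, max:{tf.math.reduce_max(tensor)}, mean:{tf.math.reduce_mean(tensor)}'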

print('shape', tf.shape(image_tensor))
print('dtype', image_tensor.dtype)
print('R', stats(image_tensor[:,:,0]))
print('G', stats(image_tensor[:,:,1]))
print('B', stats(image_tensor[:,:,2]))

This gives the following output:

shape tf.Tensor([500 500   3], shape=(3,), dtype=int32)
dtype <dtype: 'uint16'>
R min: 91.0, max:5892.0, mean:912.79528
G min: 124.0, max:6206.0, mean:831.6976
B min: 53.0, max:16102.0, mean:572.277328
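The cast in stats is there because tf.math.reduce_mean keeps the input dtype, so on an integer tensor the mean is truncated. A minimal check on a toy tensor (the values are made up for illustration, not taken from the image):

```python
import tensorflow as tf

# Toy integer tensor; the true mean of these values is 0.5.
values = tf.constant([1, 0, 1, 0], dtype=tf.int32)

int_mean = tf.math.reduce_mean(values)                         # integer mean, truncated
float_mean = tf.math.reduce_mean(tf.cast(values, tf.float64))  # exact mean after casting

print(int_mean.numpy(), float_mean.numpy())  # 0 0.5
```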

Here we have a 500x500 pixel image. The first two dimensions are height and width, and the last dimension holds the pixel values for the red, green and blue color channels. In this case the pixel values are unsigned 16-bit integers with a [0,65535] range, and as the statistics show, none of the channels comes close to using that range. The goal is to rescale the tensor to a [0,1] range so that it can be fed to a CNN. If the tensor were a NumPy array, we could use np.interp(a, (a.min(), a.max()), (0, 1)) for linear interpolation, but a similar function is not available in core TensorFlow. Instead, we write our own functions: rescale_0_1 normalizes the image using a single minimum and maximum taken over all channels, and rescale_0_1_channel_wise applies the same normalization to each channel separately, so that every channel gets its own full [0,1] range and a better spread of pixel values:
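To confirm that the min-max formula used below is the same linear mapping as the np.interp call, the two can be compared on a small array. This is a NumPy-only sketch; the sample values are hypothetical (loosely modelled on the red channel's min and max above), not the actual image data.

```python
import numpy as np

# Hypothetical uint16 pixel values, not taken from the real image.
a = np.array([91, 500, 912, 5892], dtype=np.uint16)

# Linear interpolation from [a.min(), a.max()] onto [0, 1], as quoted in the text.
via_interp = np.interp(a, (a.min(), a.max()), (0, 1))

# The same mapping written out explicitly, working in float64
# just as rescale_0_1 does after tf.cast.
a_f = a.astype(np.float64)
via_formula = (a_f - a_f.min()) / (a_f.max() - a_f.min())

print(np.allclose(via_interp, via_formula))  # True
```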

def rescale_0_1(tensor):
    # Min-max normalization with one min and one max over the whole tensor
    tensor = tf.cast(tensor, tf.float64)
    t_min = tf.math.reduce_min(tensor)
    t_max = tf.math.reduce_max(tensor)
    return (tensor - t_min) / (t_max - t_min)

def rescale_0_1_channel_wise(tensor):
    # Normalize each channel of an H x W x C tensor independently
    num_channels = tf.shape(tensor)[-1]
    channels = tf.TensorArray(tf.float64, size=num_channels)
    for channel_idx in tf.range(num_channels):
        channel = rescale_0_1(tensor[:,:,channel_idx])
        channels = channels.write(channel_idx, channel)
    # stack() returns C x H x W; transpose back to H x W x C
    tensor = tf.transpose(channels.stack(), [1,2,0])
    return tensor
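The TensorArray loop is one way to do this. An alternative sketch, not from the original post, computes the per-channel minima and maxima with axis=[0, 1] and keepdims=True and lets broadcasting apply them to every pixel, so no explicit loop is needed. It assumes the same H x W x C layout; the function name rescale_0_1_broadcast and the test image are hypothetical.

```python
import tensorflow as tf

def rescale_0_1_broadcast(tensor):
    """Hypothetical vectorized variant: min-max scale each channel of an H x W x C tensor to [0, 1]."""
    tensor = tf.cast(tensor, tf.float64)
    # Shape (1, 1, C): one min and one max per channel, kept 3-D so they broadcast.
    ch_min = tf.math.reduce_min(tensor, axis=[0, 1], keepdims=True)
    ch_max = tf.math.reduce_max(tensor, axis=[0, 1], keepdims=True)
    # As in rescale_0_1, a constant channel (max == min) would divide by zero.
    return (tensor - ch_min) / (ch_max - ch_min)

# Small deterministic uint16 test image (made-up data): 4 x 4 pixels, 3 channels.
image = tf.reshape(tf.range(48, dtype=tf.int32) * 100 + 50, [4, 4, 3])
image = tf.cast(image, tf.uint16)

scaled = rescale_0_1_broadcast(image)
print(tf.math.reduce_min(scaled, axis=[0, 1]).numpy())  # [0. 0. 0.]
print(tf.math.reduce_max(scaled, axis=[0, 1]).numpy())  # [1. 1. 1.]
```

Compared with the TensorArray version, this keeps everything as a few element-wise ops, which tends to be simpler to read and to trace inside tf.function.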

Note that we must write channels = channels.write(channel_idx, channel) instead of just channels.write(channel_idx, channel). In eager mode, write() also updates the TensorArray in place as a convenience, but in graph mode (for example inside a tf.function) the TensorArray returned by write() must be used for all subsequent operations (see this issue).
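A minimal sketch of the same pattern inside tf.function, where AutoGraph turns the loop over tf.range into a graph loop: the TensorArray returned by each write() is reassigned and finally stacked. The function squares and its input are illustrative only.

```python
import tensorflow as tf

@tf.function
def squares(n):
    # Hypothetical example: collect i * i for i in [0, n) in a TensorArray.
    arr = tf.TensorArray(tf.int32, size=n)
    for i in tf.range(n):
        # Reassign: in graph mode the returned TensorArray carries the write.
        arr = arr.write(i, i * i)
    return arr.stack()

print(squares(tf.constant(5)).numpy())  # [ 0  1  4  9 16]
```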

Let’s take a look at the color histograms of each channel. We compare the interpolation across all channels with the channel-wise interpolation:

import matplotlib.pyplot as plt

def plot(tensor, tensor_cw):
    rows, cols = 2, 3
    fig, axs = plt.subplots(rows, cols, figsize=(10,5))
    colors = ['red', 'green', 'blue']
    row_names = ['normalized', 'channel-wise normalized']

    for row, t in enumerate([tensor, tensor_cw]):
        for col, color in enumerate(colors):
            # Flatten the H x W channel to 1-D: hist() would treat each
            # column of a 2-D array as a separate dataset.
            channel = tf.reshape(t[:,:,col], [-1]).numpy()
            axs[row, col].hist(channel, facecolor=color, range=(0, 1))
            axs[row, col].set_title(stats(t[:,:,col]), fontsize='small')
        axs[row, 0].set_ylabel(row_names[row], size='large')

    plt.tight_layout()
    plt.show()

tensor = rescale_0_1(image_tensor)
tensor_cw = rescale_0_1_channel_wise(image_tensor)
plot(tensor, tensor_cw)

Figure: Channel-wise RGB image normalization

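As the histograms show, channel-wise normalization spreads each channel over the full [0,1] range, whereas normalizing across all channels leaves the red and green values compressed into roughly the lower 40% of the range, because the global minimum (53) and maximum (16102) both come from the blue channel.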
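The compression in the top row can be checked with a little arithmetic on the statistics printed earlier. With rescale_0_1, every value is mapped as (x - 53) / (16102 - 53), so each channel's maximum lands at:

```python
# Channel maxima taken from the stats output printed for the original image.
channel_max = {'R': 5892.0, 'G': 6206.0, 'B': 16102.0}
global_min = 53.0      # smallest value over all channels (blue channel)
global_max = 16102.0   # largest value over all channels (blue channel)

for name, ch_max in channel_max.items():
    # Position of this channel's maximum after global min-max scaling (rescale_0_1).
    scaled = (ch_max - global_min) / (global_max - global_min)
    print(f'{name}: max after global scaling = {scaled:.3f}')
# Prints 0.364 for R, 0.383 for G and 1.000 for B.
```

With rescale_0_1_channel_wise, each channel is scaled by its own minimum and maximum instead, so every channel's maximum maps to exactly 1.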