I have a 160 GB image dataset that was generated with TensorFlow Datasets (tfds). It can be loaded directly as a tf.data.Dataset:

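import tensorflow_datasets as tfds

(ds_train, ds_test), ds_info = tfds.load(
    'mnist',
    split=['train[-80%:]', 'train[:20%]'],
    shuffle_files=True,
    as_supervised=True,  # yield (image, label) tuples so x[1] is the label
    with_info=True
)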

The dataset is imbalanced. To account for this, I want to compute class weights and pass them to model.fit() like this:

import numpy as np
from sklearn.utils.class_weight import compute_class_weight

class_weights = compute_class_weight(
    class_weight='balanced',
    classes=np.unique(class_labels),
    y=class_labels
)
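Note that Keras' model.fit() takes class_weight as a dict mapping class index to weight, whereas compute_class_weight returns an array ordered like `classes`. scikit-learn documents the 'balanced' heuristic as n_samples / (n_classes * np.bincount(y)), so the weights depend only on per-class counts, not on the full label array. A minimal NumPy sketch of that formula, assuming integer labels 0..K-1; the function name `balanced_class_weights` is hypothetical:

```python
import numpy as np


def balanced_class_weights(counts):
    """Reproduce scikit-learn's 'balanced' heuristic from per-class counts.

    counts[i] is the number of training samples with label i (assumed to be
    an integer in 0..K-1). Classes with zero samples are skipped, mirroring
    classes=np.unique(y). Returns a {class_index: weight} dict, the form
    Keras' model.fit(class_weight=...) expects.
    """
    counts = np.asarray(counts, dtype=np.float64)
    present = np.flatnonzero(counts)  # label indices that actually occur
    n_samples = counts.sum()
    # n_samples / (n_classes * count_per_class), as in the sklearn docs
    weights = n_samples / (len(present) * counts[present])
    return {int(c): float(w) for c, w in zip(present, weights)}


# Example: three classes with 6, 3 and 1 samples; the rarest class
# receives the largest weight.
print(balanced_class_weights([6, 3, 1]))
```

The resulting dict can be passed as model.fit(..., class_weight=weights).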

Here class_labels has to be an array containing the label of every training example. The only way I have found to iterate over the dataset and collect the labels is np.array([x[1].numpy() for x in list(ds_train)]). With this approach I run into an OOM error, because list(ds_train) materializes the whole dataset, images included, in memory.
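One way around the OOM error may be to stream over the labels in batches and only accumulate per-class counts, so at most one batch of labels is held at a time. A minimal sketch, assuming integer labels in 0..num_classes-1; the helper `count_labels` and the batch size 4096 are hypothetical choices, and the tf.data usage in the comments relies on as_supervised=True:

```python
import numpy as np


def count_labels(label_batches, num_classes):
    """Accumulate per-class counts from an iterable of label batches.

    Each batch is anything np.ravel accepts (array, list, scalar) holding
    non-negative integer labels below num_classes. Only one batch is in
    memory at a time, so the full dataset is never materialized.
    """
    counts = np.zeros(num_classes, dtype=np.int64)
    for batch in label_batches:
        counts += np.bincount(np.ravel(batch), minlength=num_classes)
    return counts


# Hypothetical usage with the dataset from the question:
#   labels_only = ds_train.map(lambda image, label: label)  # drop images
#   counts = count_labels(labels_only.batch(4096).as_numpy_iterator(),
#                         ds_info.features['label'].num_classes)

print(count_labels([np.array([0, 1, 1]), np.array([2, 1])], num_classes=3))
# prints [1 3 1]
```

This still reads through the whole training split once, but memory stays bounded by the batch size. If the full label array is really needed for compute_class_weight, it can be rebuilt cheaply from the counts with np.repeat(np.arange(num_classes), counts), since it then holds only integers. Passing decoders={'image': tfds.decode.SkipDecoding()} to tfds.load may additionally avoid decoding images during this pass.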

The only working solution I have right now is to precompute the class distribution from the raw image dataset, which is far from ideal. Hopefully this information will be available from tf.data.Dataset in a future TensorFlow release.