I have a 160 GB TensorFlow image dataset that was generated using tfds. It can be loaded directly as a tf.data.Dataset:

(ds_train, ds_test), ds_info = tfds.load(

The dataset is imbalanced. To account for this, I want to compute the class weights and pass them to model.fit() like this:

from sklearn.utils.class_weight import compute_class_weight

class_weights = compute_class_weight(

The class_labels must be passed as a NumPy array. The only way I have found to iterate over the dataset and collect the labels is np.array([x[1].numpy() for x in list(ds_train)]). With this approach I run into an OOM error, because list(ds_train) materializes every element, images included, in memory.

The only working solution I have right now is to precompute the class distribution from the raw image dataset, which is far from ideal. Hopefully, this information will be available in tf.data.Dataset in a future TensorFlow release.
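One approach that might avoid the OOM error is to stream over the labels in batches and keep only running per-class counts, instead of building a list of the whole dataset. The sketch below is untested against the real dataset and rests on a few assumptions: the helper name balanced_class_weights is made up for illustration, the labels are integer class indices, and ds_train yields (image, label) pairs. It reproduces the "balanced" heuristic n_samples / (n_classes * count) in plain Python rather than calling compute_class_weight.

```python
from collections import Counter


def balanced_class_weights(label_batches):
    """Compute 'balanced' class weights from an iterable of label batches.

    Only the running per-class counts are kept in memory, never the
    full list of labels (and never the images).
    """
    counts = Counter()
    for batch in label_batches:
        # Accepts Python lists or 1-D NumPy arrays of integer labels.
        counts.update(int(label) for label in batch)

    n_samples = sum(counts.values())
    n_classes = len(counts)
    # "Balanced" heuristic: n_samples / (n_classes * count_of_class)
    return {cls: n_samples / (n_classes * n) for cls, n in sorted(counts.items())}


# Hypothetical usage, assuming ds_train yields (image, label) pairs:
#   labels = ds_train.map(lambda image, label: label).batch(1024)
#   class_weight = balanced_class_weights(labels.as_numpy_iterator())
#   model.fit(..., class_weight=class_weight)
```

The result is a dict mapping class index to weight, which is the form the class_weight argument of model.fit() accepts. This still reads through the dataset once, so it will probably be slow on 160 GB, but memory use should stay bounded by the batch size rather than the dataset size.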