Class distribution in tf.data.Dataset
I have a 160GB TensorFlow image dataset that was generated using tfds. It can be loaded directly as a tf.data.Dataset:
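```python
import tensorflow_datasets as tfds

# as_supervised=True yields (image, label) tuples, so x[1] below is the label
(ds_train, ds_test), ds_info = tfds.load(
    'mnist',
    split=['train[-80%:]', 'train[:20%]'],
    shuffle_files=True,
    as_supervised=True,
    with_info=True
)
```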
The dataset is imbalanced, so I want to compute class weights and pass them to model.fit(). I compute the weights like this:
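```python
import numpy as np
from sklearn.utils.class_weight import compute_class_weight

class_weights = compute_class_weight(
    class_weight='balanced',
    classes=np.unique(class_labels),
    y=class_labels
)
# model.fit() expects {class_index: weight}; assumes labels are 0..n_classes-1
class_weight_dict = dict(enumerate(class_weights))
```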
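With class_weight='balanced', compute_class_weight does not actually need every label, only the per-class counts: sklearn documents these weights as n_samples / (n_classes * np.bincount(y)). A minimal sketch of that formula, assuming the labels are integer class indices 0..n_classes-1 (as in MNIST). The helper name balanced_weights_from_counts is hypothetical, not part of sklearn or Keras; it returns a dict because model.fit() takes class_weight as a mapping from class index to weight.

```python
import numpy as np

def balanced_weights_from_counts(counts):
    """Hypothetical helper: 'balanced' class weights from per-class counts.

    Uses the formula sklearn documents for class_weight='balanced':
    n_samples / (n_classes * count_per_class).
    Assumes counts[i] is the number of examples with label i.
    """
    counts = np.asarray(counts, dtype=np.float64)
    if np.any(counts == 0):
        raise ValueError("every class needs at least one example")
    n_samples = counts.sum()
    n_classes = counts.size
    weights = n_samples / (n_classes * counts)
    # Keras model.fit(class_weight=...) expects {class_index: weight}
    return {i: float(w) for i, w in enumerate(weights)}

# Example: 80 examples of class 0, 20 examples of class 1
print(balanced_weights_from_counts([80, 20]))  # {0: 0.625, 1: 2.5}
```

The minority class gets the larger weight, and the weighted counts always sum back to n_samples, so the overall loss scale stays comparable to unweighted training.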
The class_labels must be passed as a NumPy array. The only way I have found to iterate over the dataset and collect the labels is np.array([x[1].numpy() for x in list(ds_train)]). With this approach I run into an OOM error, because list(ds_train) materializes the whole dataset, images included, in memory.
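The memory problem comes from list(ds_train) keeping every (image, label) pair alive at once, while only a running count per class is needed. A minimal sketch, assuming the labels are integer class indices below num_classes and that the dataset was loaded with as_supervised=True so elements are (image, label) tuples. count_labels is a hypothetical helper that accepts any iterable of label batches, so it can be fed from a tf.data pipeline that drops the images right after reading them.

```python
import numpy as np

def count_labels(label_batches, num_classes):
    """Hypothetical helper: accumulate per-class counts batch by batch.

    Only one batch of labels is held in memory at a time, so memory use
    does not grow with the size of the dataset.
    Assumes every label is an integer in [0, num_classes).
    """
    counts = np.zeros(num_classes, dtype=np.int64)
    for batch in label_batches:
        labels = np.asarray(batch).ravel()
        counts += np.bincount(labels, minlength=num_classes)
    return counts

# With tf.data (assuming as_supervised=True, so elements are (image, label)):
#   labels_only = ds_train.map(lambda image, label: label).batch(4096)
#   counts = count_labels(labels_only.as_numpy_iterator(),
#                         ds_info.features['label'].num_classes)

# Self-contained demo with plain Python batches instead of a tf.data pipeline
demo_batches = [[0, 1, 1], [2, 1], [0]]
print(count_labels(demo_batches, num_classes=3))  # [2 3 1]
```

The pipeline still reads every record from disk, so this is one full pass over the 160GB, but memory stays bounded by the batch size. If decoding the images dominates that pass, tfds.load also accepts decoders={'image': tfds.decode.SkipDecoding()}, which may make a labels-only pass cheaper.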
The only working solution I have right now is to precompute the class distribution from the raw image dataset, which is far from ideal. Hopefully this information will be available in tf.data.Dataset in a future TensorFlow release.