Test the Unbalanced Data Pipeline
In this step, you will test the unbalanced_data_pipeline function to ensure it works as expected.
Add the following code to the unbalanced_data_pipeline.py file.
if __name__ == "__main__":
    data = [
        [[1, 2, 5], [1, 0]],
        [[1, 6, 0], [1, 0]],
        [[4, 1, 8], [1, 0]],
        [[7, 0, 4], [0, 1]],
        [[5, 9, 4], [0, 1]],
        [[2, 0, 1], [0, 1]],
        [[1, 9, 3], [0, 1]],
        [[5, 5, 5], [0, 1]],
        [[8, 4, 0], [0, 1]],
        [[9, 6, 3], [0, 1]],
        [[7, 7, 0], [0, 1]],
        [[0, 3, 4], [0, 1]],
    ]
    for epoch in range(10):
        batch_data = unbalanced_data_pipeline(data, 6)
        batch_data = list(batch_data)
        print(f"{epoch=}, {batch_data=}")
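In the if __name__ == "__main__": block, we call the unbalanced_data_pipeline function with the sample data and a batch size of 6. The sample data is unbalanced: 3 samples are labeled [1, 0] and 9 are labeled [0, 1]. The loop requests one batch per epoch for 10 epochs, converts the result to a list, and prints it.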
Run the unbalanced_data_pipeline.py file to see the output.
python unbalanced_data_pipeline.py
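The output should look similar to the example provided in the original challenge. The selected samples and their order differ between epochs and will likely differ between runs; in the example below, every batch contains three samples of each class: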
epoch=0, batch_data=[[[1, 2, 5], [1, 0]], [[4, 1, 8], [1, 0]], [[1, 6, 0], [1, 0]], [[2, 0, 1], [0, 1]], [[7, 0, 4], [0, 1]], [[5, 9, 4], [0, 1]]]
epoch=1, batch_data=[[[4, 1, 8], [1, 0]], [[1, 2, 5], [1, 0]], [[1, 6, 0], [1, 0]], [[2, 0, 1], [0, 1]], [[9, 6, 3], [0, 1]], [[1, 9, 3], [0, 1]]]
epoch=2, batch_data=[[[4, 1, 8], [1, 0]], [[1, 2, 5], [1, 0]], [[1, 6, 0], [1, 0]], [[5, 5, 5], [0, 1]], [[7, 0, 4], [0, 1]], [[8, 4, 0], [0, 1]]]
epoch=3, batch_data=[[[1, 2, 5], [1, 0]], [[1, 6, 0], [1, 0]], [[4, 1, 8], [1, 0]], [[7, 7, 0], [0, 1]], [[8, 4, 0], [0, 1]], [[0, 3, 4], [0, 1]]]
epoch=4, batch_data=[[[4, 1, 8], [1, 0]], [[1, 6, 0], [1, 0]], [[1, 2, 5], [1, 0]], [[5, 5, 5], [0, 1]], [[0, 3, 4], [0, 1]], [[8, 4, 0], [0, 1]]]
epoch=5, batch_data=[[[1, 6, 0], [1, 0]], [[4, 1, 8], [1, 0]], [[1, 2, 5], [1, 0]], [[2, 0, 1], [0, 1]], [[7, 0, 4], [0, 1]], [[7, 7, 0], [0, 1]]]
epoch=6, batch_data=[[[1, 2, 5], [1, 0]], [[1, 6, 0], [1, 0]], [[4, 1, 8], [1, 0]], [[8, 4, 0], [0, 1]], [[5, 9, 4], [0, 1]], [[0, 3, 4], [0, 1]]]
epoch=7, batch_data=[[[1, 2, 5], [1, 0]], [[1, 6, 0], [1, 0]], [[4, 1, 8], [1, 0]], [[2, 0, 1], [0, 1]], [[0, 3, 4], [0, 1]], [[1, 9, 3], [0, 1]]]
epoch=8, batch_data=[[[1, 6, 0], [1, 0]], [[4, 1, 8], [1, 0]], [[1, 2, 5], [1, 0]], [[7, 7, 0], [0, 1]], [[2, 0, 1], [0, 1]], [[0, 3, 4], [0, 1]]]
epoch=9, batch_data=[[[1, 2, 5], [1, 0]], [[4, 1, 8], [1, 0]], [[1, 6, 0], [1, 0]], [[7, 0, 4], [0, 1]], [[0, 3, 4], [0, 1]], [[5, 5, 5], [0, 1]]]
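Reading the printed batches by eye is error-prone, so it can help to count the labels in a batch instead. The following is a minimal sketch, assuming each sample is a [features, one_hot_label] pair as in the sample data. The label_counts helper is a hypothetical name introduced here for illustration; it is not part of the original challenge.

```python
from collections import Counter


def label_counts(batch):
    """Count how many samples in a batch carry each one-hot label."""
    # Lists are unhashable, so convert each label to a tuple before counting.
    return Counter(tuple(label) for _, label in batch)


# The epoch=0 batch copied from the sample output above.
batch = [
    [[1, 2, 5], [1, 0]],
    [[4, 1, 8], [1, 0]],
    [[1, 6, 0], [1, 0]],
    [[2, 0, 1], [0, 1]],
    [[7, 0, 4], [0, 1]],
    [[5, 9, 4], [0, 1]],
]

counts = label_counts(batch)
print(counts)
```

For this batch, both labels are counted three times. The same helper could be called on batch_data inside the epoch loop to check every batch the pipeline returns.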
In this step, you tested the unbalanced_data_pipeline function to confirm it works as expected. The function should now process the unbalanced data and return batches with approximately balanced class distributions.