I wanted to use distillation to train my model. The plan was:
- Train model A and model B with the same code and the same dataset
- Run model A and model B over the dataset and store the average of their predictions
- Use the averaged prediction as the target of a new training process
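The steps above can be sketched in plain Python. This is a minimal illustration, not the author's code. It assumes step 1 is already done: `model_a`, `model_b` and `student` are hypothetical stand-ins that return class probabilities. The soft-target cross-entropy loss is also an assumption, since the post does not say which loss the new training uses.

```python
import math
from typing import List, Sequence

Probs = List[float]


def average_predictions(preds_a: Sequence[Probs], preds_b: Sequence[Probs]) -> List[Probs]:
    """Step 2: per-sample, element-wise mean of model A's and model B's predictions."""
    return [[(pa + pb) / 2.0 for pa, pb in zip(row_a, row_b)]
            for row_a, row_b in zip(preds_a, preds_b)]


def soft_target_cross_entropy(student_probs: Probs, target: Probs, eps: float = 1e-12) -> float:
    """Step 3 (assumed loss): cross-entropy of the student's output against the soft target."""
    return -sum(t * math.log(max(s, eps)) for s, t in zip(student_probs, target))


# Hypothetical stand-ins for trained model A, trained model B and the new (student) model.
def model_a(x: float) -> Probs:
    return [1.0, 0.0] if x > 0 else [0.0, 1.0]


def model_b(x: float) -> Probs:
    return [0.5, 0.5]


def student(x: float) -> Probs:
    return [0.5, 0.5]


dataset = [1.0, -2.0]

# Step 2: predict the dataset with both models and store the average.
targets = average_predictions([model_a(x) for x in dataset],
                              [model_b(x) for x in dataset])
print(targets)  # [[0.75, 0.25], [0.25, 0.75]]

# Step 3: the stored averages serve as targets for the new training run.
loss = sum(soft_target_cross_entropy(student(x), t)
           for x, t in zip(dataset, targets)) / len(dataset)
print(loss)  # about 0.693 (ln 2), because the student outputs uniform probabilities
```

In a real setup the averaged targets would be computed once and saved, so the two teacher models do not need to run during the new training.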
Steps 1 and 2 succeeded. But when I ran the new training process, the loss became “NaN” after some steps.
To find the cause, I printed all the averaged prediction results at every step. At first they looked normal, but after a while I found that some of them contained “NaN”.
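A NaN in even one soft target is enough to turn the whole batch loss into NaN, because NaN propagates through every arithmetic operation that touches it. The sketch below uses plain Python and hypothetical helper names (`nonfinite_rows`, `batch_loss`). It shows that effect, plus a helper that reports only the samples with non-finite values, which is easier to scan than printing every averaged prediction. The loss is again an assumed soft-target cross-entropy.

```python
import math
from typing import List, Sequence


def nonfinite_rows(targets: Sequence[Sequence[float]]) -> List[int]:
    """Return the indices of samples whose target contains NaN or +/-inf."""
    return [i for i, row in enumerate(targets)
            if not all(math.isfinite(v) for v in row)]


def batch_loss(student: Sequence[Sequence[float]], targets: Sequence[Sequence[float]]) -> float:
    """Mean soft-target cross-entropy over a batch (assumed loss, for illustration)."""
    per_sample = [-sum(t * math.log(s) for s, t in zip(s_row, t_row))
                  for s_row, t_row in zip(student, targets)]
    return sum(per_sample) / len(per_sample)


student = [[0.5, 0.5], [0.5, 0.5], [0.5, 0.5]]
targets = [[0.75, 0.25], [float("nan"), 0.5], [0.25, 0.75]]

print(nonfinite_rows(targets))                   # [1]
print(math.isnan(batch_loss(student, targets)))  # True: one bad row poisons the mean
```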
Why is there “NaN” in the averaged prediction results? My guess is that some samples are so rare (or unusual) that the models give unreliable outputs for them. It is easy to simply skip such samples:
```python
import numpy as np

if not np.isfinite(label).all():  # isfinite is False for NaN and +/-inf, so this covers np.isnan too
    return None  # drop the corresponding sample
```
Now the distillation training can continue.
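The same check can also run once, when the averages are stored in step 2, instead of in the loader at every training step. The sketch below is an assumption-based variant, not the author's code: `build_distill_set` is a hypothetical name, and inputs and predictions are plain lists. It returns the kept (input, target) pairs together with the indices of the dropped samples.

```python
import math
from typing import List, Sequence, Tuple


def build_distill_set(
    inputs: Sequence[float],
    preds_a: Sequence[Sequence[float]],
    preds_b: Sequence[Sequence[float]],
) -> Tuple[List[Tuple[float, List[float]]], List[int]]:
    """Average the two models' predictions and keep only samples whose average is finite.

    Returns (kept (input, target) pairs, indices of dropped samples).
    """
    kept: List[Tuple[float, List[float]]] = []
    dropped: List[int] = []
    for i, (x, row_a, row_b) in enumerate(zip(inputs, preds_a, preds_b)):
        target = [(a + b) / 2.0 for a, b in zip(row_a, row_b)]
        # math.isfinite is False for NaN and for +/-inf, so one test covers both cases.
        if all(math.isfinite(v) for v in target):
            kept.append((x, target))
        else:
            dropped.append(i)
    return kept, dropped


inputs = [0.1, 0.2, 0.3]
preds_a = [[1.0, 0.0], [float("nan"), 0.5], [0.0, 1.0]]
preds_b = [[0.5, 0.5], [0.5, 0.5], [float("inf"), 0.0]]
kept, dropped = build_distill_set(inputs, preds_a, preds_b)
print(dropped)    # [1, 2]
print(len(kept))  # 1
```

Printing `dropped` also gives a quick way to test the guess above. If the dropped indices really are rare or unusual samples, they should make up only a small fraction of the dataset.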