Engineering
NaNs Out of Nowhere? The Trick That Saved My Training Run
A technical dive into a sneaky numerical instability issue in a custom loss function, and the math trick that solved it.

Training deep learning models can be a real headache. We often wrestle with finding good data, crafting the right architecture, and battling the demons of CUDA drivers, multi-GPU setups, and conflicting library versions. And let’s not even talk about the cryptic and sometimes outdated documentation of our favourite frameworks.
However, today I’ll tell you a story about yet another challenge, one that is less common for the average ML engineer: numerical instability. And I am not talking about the classic exploding/vanishing gradients, but rather about the design of the loss function itself.
The Problem: A Custom Loss Goes Rogue
For most Large Language Model (LLM) training, any sane person would stick to the tried-and-true cross-entropy loss. My use case, however, was a bit… special. I needed to teach my model when to say “stop” - specifically, to penalise it for generating an End-of-Sequence (EOS) token when it shouldn’t, and for not generating one when it should. If you’re wondering why anyone would ever want to do this, I’ll briefly say that in my case the position of the EOS token was key to classifying examples as positive or negative.
So my genius (not really) idea was to construct the following loss function:

$$\mathcal{L} = \mathcal{L}_{\text{CE}} + \alpha \cdot \mathcal{L}_{\text{pos}} + \beta \cdot \mathcal{L}_{\text{neg}}$$

Here, $\mathcal{L}_{\text{CE}}$ is the regular cross-entropy loss, $\mathcal{L}_{\text{pos}}$ punishes the model for spitting out an EOS for a “positive” example, and $\mathcal{L}_{\text{neg}}$ punishes it for holding back the EOS for a “negative” example. $\alpha$ and $\beta$ are just the tuning knobs.
The penalty terms looked like this:

$$\mathcal{L}_{\text{pos}} = -\log\left(1 - \sigma(x^{+})_{\text{EOS}}\right), \qquad \mathcal{L}_{\text{neg}} = -\log\left(\sigma(x^{-})_{\text{EOS}}\right)$$

Where $x_{\text{EOS}}$ is the model’s logit (raw score) for the EOS token, and $\sigma$ is the softmax function, so $\sigma(x)_{\text{EOS}}$ is the predicted probability of the EOS token. The $+$ and $-$ superscripts simply indicate whether we’re looking at positive or negative examples from the batch.
Seems straightforward enough, doesn’t it? Here is my first simple PyTorch implementation:
import torch
import torch.nn.functional as F
# x_logits: [batch_size, vocab_size]
# eos_id: int, ID of the EOS token
# positive_examples: [batch_size], boolean mask for positive examples
# negative_examples: [batch_size], boolean mask for negative examples
softmax_eos = F.softmax(x_logits, dim=1)[:, eos_id]
# We take the mean over the batch
neg_penalty = torch.mean(
    -torch.log(softmax_eos[negative_examples]).clamp(max=10.0)
)
pos_penalty = torch.mean(
    -torch.log(1 - softmax_eos[positive_examples]).clamp(max=10.0)
)
I even threw in a .clamp(max=10.0) for good measure, trying to prevent any extreme values. Can you spot the bug?
Well, it is quite tricky 🙂 I didn’t see it either, at first. My model would train happily for an epoch or so. All the metrics - loss, gradient norm - looked perfectly fine. Then, suddenly, midway through the second epoch: NaN. The loss went NaN. The gradient norm went NaN. Peeking inside, I found all my model weights had turned into NaN. Yet, there were no warning signs, no dramatic explosion in gradients beforehand. Very suspicious indeed.
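In case you’re curious, “peeking inside” boils down to a quick check along these lines (a hypothetical helper, not part of any library):

import torch
from torch import nn

def any_nan_params(model: nn.Module) -> bool:
    # Returns True if any parameter of the model contains a NaN.
    return any(torch.isnan(p).any().item() for p in model.parameters())

In my unlucky case, this returned True for every single parameter tensor.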
The Solution
So, what went wrong? Let’s dissect the negative penalty calculation. When we apply the softmax definition, we get:

$$\mathcal{L}_{\text{neg}} = -\log\frac{e^{x^{-}_{\text{EOS}}}}{\sum_{j} e^{x^{-}_{j}}} = -x^{-}_{\text{EOS}} + \log\sum_{j} e^{x^{-}_{j}}$$

Look closely at the second term. We are calculating the exponential $e^{x^{-}_{j}}$ of every single logit in the vocabulary and then summing them up. If any of those logits is either very large or very small, its exponential might over- or underflow. This propagates the error through the loss calculations, turning relevant values into NaN and poisoning the entire training process.
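To make the failure concrete, here is a minimal, made-up reproduction - the three-token vocabulary and the logit values are invented purely for illustration. The EOS probability underflows to exactly zero, the log turns that into infinity, and the backward pass turns the infinity into NaN:

import torch
import torch.nn.functional as F

# Toy setup: pretend index 0 is the EOS token, and one other logit is huge -
# perfectly normal for a confident model in float32.
logits = torch.tensor([[0.0, 200.0, 1.0]], requires_grad=True)
eos_id = 0

softmax_eos = F.softmax(logits, dim=1)[:, eos_id]
print(softmax_eos.item())  # 0.0 -- the EOS probability underflowed to exactly zero

# Note: the clamp caps log(p), which is already <= 0, so the leading minus
# still lets the penalty blow up to infinity.
neg_penalty = torch.mean(-torch.log(softmax_eos).clamp(max=10.0))
print(neg_penalty.item())  # inf

neg_penalty.backward()
print(logits.grad)  # tensor([[nan, nan, nan]]) -- and the NaNs start spreading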
This isn’t a new problem. It’s a classic numerical trap, and thankfully, there’s a classic solution: the Log-Sum-Exp (LSE) trick. It’s a way to compute $\log\sum_{j} e^{x_{j}}$ in a numerically stable manner, typically by shifting the values by their maximum before exponentiating. I highly recommend checking out this blog post by Gregory Gundersen for a deeper dive. This trick is a cornerstone of many stable computations in machine learning - even the famous Flash Attention uses a similar idea (see section 3.1 of the paper).
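For intuition, here is a rough sketch of the shift-by-max idea behind the trick (torch.logsumexp does this, and more, for you - the sketch is only for illustration):

import torch

def naive_logsumexp(x: torch.Tensor, dim: int) -> torch.Tensor:
    # Direct translation of the formula: overflows as soon as any logit is large.
    return torch.log(torch.exp(x).sum(dim=dim))

def stable_logsumexp(x: torch.Tensor, dim: int) -> torch.Tensor:
    # Shift by the maximum first: exp() then only sees values <= 0, which cannot
    # overflow, and the shift is added back outside the log.
    m = x.max(dim=dim, keepdim=True).values
    return m.squeeze(dim) + torch.log(torch.exp(x - m).sum(dim=dim))

x = torch.tensor([[0.0, 200.0, 1.0]])
print(naive_logsumexp(x, dim=1))   # tensor([inf])
print(stable_logsumexp(x, dim=1))  # tensor([200.]) -- matches torch.logsumexp(x, dim=1)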
Luckily, PyTorch has our back with a built-in, stable implementation: torch.logsumexp. Using it, we can rewrite our negative penalty code:
import torch
# Calculate LogSumExp across the vocabulary dimension (numerically stable)
log_sum_exp_x_logits = torch.logsumexp(x_logits, dim=1)
# Apply the stable formula for negative penalty
neg_penalty = torch.mean((
    -x_logits[negative_examples, eos_id] +
    log_sum_exp_x_logits[negative_examples]
).clamp(max=10.0))
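As a quick sanity check (with made-up toy logits), the stable expression matches the naive -log(softmax) whenever the latter is finite, and stays finite when the naive version blows up:

import torch
import torch.nn.functional as F

torch.manual_seed(0)
eos_id = 0

# Well-behaved logits: the two formulations agree.
x_logits = torch.randn(4, 8)
naive = -torch.log(F.softmax(x_logits, dim=1)[:, eos_id])
stable = -x_logits[:, eos_id] + torch.logsumexp(x_logits, dim=1)
print(torch.allclose(naive, stable))  # True

# One extreme logit: the naive version saturates to inf, the stable one stays finite.
x_logits[0, 3] = 200.0
naive = -torch.log(F.softmax(x_logits, dim=1)[:, eos_id])
stable = -x_logits[:, eos_id] + torch.logsumexp(x_logits, dim=1)
print(naive[0].item(), stable[0].item())  # inf vs. a large but finite value around 200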
We can perform a similar transformation for the positive penalty:

$$\mathcal{L}_{\text{pos}} = -\log\left(1 - \frac{e^{x^{+}_{\text{EOS}}}}{\sum_{j} e^{x^{+}_{j}}}\right) = -\log\frac{\sum_{j \neq \text{EOS}} e^{x^{+}_{j}}}{\sum_{j} e^{x^{+}_{j}}} = -\log\sum_{j \neq \text{EOS}} e^{x^{+}_{j}} + \log\sum_{j} e^{x^{+}_{j}}$$

This translates to:
import torch
# LogSumExp for all logits (already calculated)
log_sum_exp_x_logits = torch.logsumexp(x_logits, dim=1)
# Calculate LogSumExp *excluding* the EOS token
# (This assumes eos_id is not 0 or V-1, otherwise needs more careful handling)
x_not_eos_logits = torch.cat([
    x_logits[positive_examples, :eos_id],
    x_logits[positive_examples, eos_id+1:],
], dim=1)
log_sum_exp_x_not_eos_logits = torch.logsumexp(x_not_eos_logits, dim=1)
# Apply the stable formula for positive penalty
pos_penalty = torch.mean((
    -log_sum_exp_x_not_eos_logits +
    log_sum_exp_x_logits[positive_examples]
).clamp(max=10.0))
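Putting everything together, the full loss from the formula at the top would look roughly like this. This is just a sketch: penalised_loss, labels, alpha and beta are placeholder names, and the logits are assumed to already be flattened to [batch_size, vocab_size] as in the snippets above:

import torch
import torch.nn.functional as F

def penalised_loss(x_logits, labels, positive_examples, negative_examples,
                   eos_id, alpha=1.0, beta=1.0):
    # Standard cross-entropy term.
    ce_loss = F.cross_entropy(x_logits, labels)

    # Shared LogSumExp over the vocabulary (numerically stable).
    log_sum_exp_x_logits = torch.logsumexp(x_logits, dim=1)

    # Stable negative penalty: -x_EOS + logsumexp(x).
    neg_penalty = torch.mean((
        -x_logits[negative_examples, eos_id] +
        log_sum_exp_x_logits[negative_examples]
    ).clamp(max=10.0))

    # Stable positive penalty: -logsumexp(x without EOS) + logsumexp(x).
    x_not_eos_logits = torch.cat([
        x_logits[positive_examples, :eos_id],
        x_logits[positive_examples, eos_id + 1:],
    ], dim=1)
    pos_penalty = torch.mean((
        -torch.logsumexp(x_not_eos_logits, dim=1) +
        log_sum_exp_x_logits[positive_examples]
    ).clamp(max=10.0))

    return ce_loss + alpha * pos_penalty + beta * neg_penalty

In a real setup you would also want to guard against batches with no positive or no negative examples, since the mean over an empty selection is, once again, NaN.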
By switching to the LSE-based formulation, the dreaded NaNs vanished, and my model could finally train stably. Success! 🎉
Conclusion
Numerical stability is one of those “invisible” challenges in deep learning. Things can look fine on the surface, with seemingly sound mathematical formulas, but the limitations of floating-point arithmetic can bite you when you least expect it, especially when dealing with exponentials and logarithms. So, the next time your training run inexplicably blows up, remember the silent threat of numerical instability, and perhaps, the saving grace of the Log-Sum-Exp trick. Happy (and stable) training!
References
- Gundersen, G. (2020). The Log-Sum-Exp Trick. https://gregorygundersen.com/blog/2020/02/09/log-sum-exp/
- Dao, T., Fu, D. Y., Ermon, S., Rudra, A., & Ré, C. (2022). FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness. arXiv preprint arXiv:2205.14135. https://arxiv.org/abs/2205.14135
- PyTorch Documentation. torch.logsumexp. https://docs.pytorch.org/docs/stable/generated/torch.logsumexp.html#torch-logsumexp