May 23, 2025

Grigory Khromov

NaNs Out of Nowhere? The Trick That Saved My Training Run

A technical dive into a sneaky numerical instability issue with a custom loss function and how math tricks helped to solve it.

Training deep learning models can be a real headache. We often wrestle with finding good data, crafting the right architecture, and battling the demons of CUDA drivers, multi-GPU setups, and conflicting library versions. And let’s not even talk about the cryptic and sometimes outdated documentation of our favourite frameworks.

However, today I’ll tell you a story about yet another challenge, one that is less common for the average ML engineer: numerical instability. And I am not talking about the classic exploding/vanishing gradients, but about the design of the loss function itself.


The Problem: A Custom Loss Goes Rogue

For most Large Language Model (LLM) training, any sane person would stick to the tried-and-true cross-entropy loss. My use case, however, was a bit… special. I needed to teach my model when to say “stop”: specifically, to penalise it for generating an End-of-Sequence (EOS) token when it shouldn’t, and for not generating one when it should. If you are wondering why anyone would want to do this: in my case, the position of the EOS token was key to classifying examples as positive or negative.

So my genius (not really) idea was to construct the following loss function:

$$L = L_{CE} + \alpha L_{pos\_penalty} + \beta L_{neg\_penalty}$$

Here, $L_{CE}$ is the regular cross-entropy loss, $L_{pos\_penalty}$ punishes the model for spitting out an EOS for a “positive” example, and $L_{neg\_penalty}$ punishes it for holding back the EOS for a “negative” example. $\alpha$ and $\beta$ are just the tuning knobs.

The penalty terms looked like this:

$$L_{pos\_penalty} = -\log\left(1 - sm(x_{EOS}^{+})\right)$$

$$L_{neg\_penalty} = -\log\left(sm(x_{EOS}^{-})\right)$$

Where $x_{EOS}$ is the model’s logit (raw score) for the EOS token, and $sm$ is the softmax function. The $+$ and $-$ superscripts simply indicate whether we’re looking at positive or negative examples from the batch.
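To make the two penalties concrete, here is a small pure-Python sketch on a toy three-token vocabulary. The logits and the EOS index are made up for illustration, and the plain (unshifted) softmax is only acceptable here because the logits are small.

```python
import math

EOS_ID = 2  # hypothetical: EOS is the last token of a toy 3-token vocabulary

def softmax(logits):
    # Plain softmax; fine for small illustrative logits, not for real training
    exps = [math.exp(x) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]

def pos_penalty(logits):
    # -log(1 - p_EOS): grows as the model gets eager to stop on a positive example
    return -math.log(1.0 - softmax(logits)[EOS_ID])

def neg_penalty(logits):
    # -log(p_EOS): grows as the model refuses to stop on a negative example
    return -math.log(softmax(logits)[EOS_ID])

logits = [1.0, 0.0, 2.0]  # EOS has the highest logit, so p_EOS > 0.5
print(f"p_EOS = {softmax(logits)[EOS_ID]:.3f}")
print(f"pos_penalty = {pos_penalty(logits):.3f}, neg_penalty = {neg_penalty(logits):.3f}")
```

Raising the EOS logit increases pos_penalty and decreases neg_penalty, which is exactly the tug-of-war the two terms are meant to create.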

Seems straightforward enough, doesn’t it? Here is my first simple PyTorch implementation:

import torch
import torch.nn.functional as F

# x_logits: [batch_size, vocab_size]
# eos_id: int, ID of the EOS token
# positive_examples: [batch_size], boolean mask for positive examples
# negative_examples: [batch_size], boolean mask for negative examples

# Probability of the EOS token for every example in the batch
softmax_eos = F.softmax(x_logits, dim=1)[:, eos_id]

# We take the mean over the batch
neg_penalty = torch.mean(
    -torch.log(softmax_eos[negative_examples]).clamp(max=10.0)
)

pos_penalty = torch.mean(
    -torch.log(1 - softmax_eos[positive_examples]).clamp(max=10.0)
)

I even threw in a .clamp(max=10.0) for good measure, trying to prevent any extreme values. Can you spot the bug?

Well, it is quite tricky 🙂 I didn’t see it either, at first. My model would train happily for an epoch or so. All the metrics - loss, gradient norm - looked perfectly fine. Then, suddenly, midway through the second epoch: NaN. The loss went NaN. The gradient norm went NaN. Peeking inside, I found all my model weights had turned into NaN. Yet, there were no warning signs, no dramatic explosion in gradients beforehand. Very suspicious indeed.


The Solution

So, what went wrong? Let’s dissect the negative penalty calculation. When we apply the softmax definition, we get:

$$
\begin{aligned}
L_{neg\_penalty} &= -\log\left(sm(x_{EOS}^{-})\right) = -\log\left(\frac{\exp(x_{EOS}^{-})}{\sum_{i=1}^{V}\exp(x_i^{-})}\right) \\
&= -\log\left(\exp(x_{EOS}^{-})\right) + \log\left(\sum_{i=1}^{V}\exp(x_i^{-})\right) \\
&= -x_{EOS}^{-} + \log\left(\sum_{i=1}^{V}\exp(x_i^{-})\right)
\end{aligned}
$$

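Look closely at the second term. We are calculating $\exp$ for every single logit in the vocabulary of size $V$ and then summing them up. If any of those logits $x_i^{-}$ is very large or very small, its exponential can overflow or underflow. F.softmax already subtracts the maximum logit internally, so it will not overflow, but the probability it returns can still underflow to exactly 0 when the EOS logit sits far below the others (and, in the positive penalty, 1 - softmax_eos becomes exactly 0 once the EOS probability rounds to 1). The log of 0 is $-\infty$, and its gradient is infinite. The clamp cannot help: as written, -torch.log(p).clamp(max=10.0) parses as -(torch.log(p).clamp(max=10.0)), which clamps $\log p$, a value that is never above 0. And even a clamp on the whole term only caps the forward value: its gradient past the cap is 0, and $0 \cdot \infty$ is NaN. A single such example is enough to turn the gradients into NaN, and the next optimizer step then poisons the entire model, with no gradual blow-up beforehand.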
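To see the failure in isolation, here is a minimal sketch that runs the naive negative penalty on a single made-up “negative” example whose EOS logit is -1000, once with the original parenthesisation and once with the clamp applied to the whole term. It assumes torch is installed; the 3-token vocabulary and EOS index are illustrative.

```python
import torch
import torch.nn.functional as F

EOS_ID = 2  # hypothetical EOS index in a toy 3-token vocabulary

def naive_neg_penalty(clamp_whole_term):
    # One made-up "negative" example whose EOS logit is so low that its
    # float32 softmax probability underflows to exactly 0
    logits = torch.tensor([[0.0, 0.0, -1000.0]], requires_grad=True)
    softmax_eos = F.softmax(logits, dim=1)[:, EOS_ID]
    if clamp_whole_term:
        # Clamp on -log(p): caps the forward value at 10
        loss = torch.mean((-torch.log(softmax_eos)).clamp(max=10.0))
    else:
        # Original parenthesisation: -(log(p).clamp(max=10.0)), so the clamp never fires
        loss = torch.mean(-torch.log(softmax_eos).clamp(max=10.0))
    loss.backward()
    return loss.detach(), logits.grad

for clamp_whole_term in (False, True):
    loss, grad = naive_neg_penalty(clamp_whole_term)
    print(f"clamp_whole_term={clamp_whole_term}: loss={loss.item()}, grad={grad}")
```

Only the forward value differs between the two variants (inf versus 10); in both, the gradient for every logit comes out as NaN, because the softmax backward pass mixes the single bad entry into all positions.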

This isn’t a new problem. It’s a classic numerical trap, and thankfully, there’s a classic solution: the Log-Sum-Exp (LSE) trick. It’s a way to compute $\log\left(\sum_i \exp(x_i)\right)$ in a numerically stable manner by shifting all values by their maximum before exponentiating. I highly recommend checking out this blog post by Gregory Gundersen for a deeper dive. This trick is a cornerstone of many stable computations in machine learning; even the famous Flash Attention uses a similar idea (see section 3.1 of the paper).
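In formula form, the trick rests on the identity $\log\sum_i \exp(x_i) = m + \log\sum_i \exp(x_i - m)$ with $m = \max_i x_i$: after the shift every exponent is at most 0 and the largest term is exactly 1, so the sum can neither overflow nor collapse to 0. A minimal pure-Python sketch (the function names are illustrative, not from any library):

```python
import math

def logsumexp_naive(xs):
    # Exponentiates the raw values: math.exp raises OverflowError for large
    # inputs, and very negative inputs underflow to a sum of 0 (log(0) fails)
    return math.log(sum(math.exp(x) for x in xs))

def logsumexp_stable(xs):
    # LSE trick: shift by the maximum so every exponent is <= 0 and the
    # largest term is exactly exp(0) = 1, which keeps the sum in [1, len(xs)]
    m = max(xs)
    if math.isinf(m):
        # +inf if any input is +inf; -inf only if every input is -inf
        return m
    return m + math.log(sum(math.exp(x - m) for x in xs))

print(logsumexp_stable([1000.0, 1000.0]))    # 1000 + log(2)
print(logsumexp_stable([-1000.0, -1000.0]))  # -1000 + log(2)
print(logsumexp_stable([1.0, 2.0, 3.0]), logsumexp_naive([1.0, 2.0, 3.0]))
```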

Luckily, PyTorch has our back with a built-in, stable implementation: torch.logsumexp. Using it, we can rewrite our negative penalty code:

import torch

# Calculate LogSumExp across the vocabulary dimension (numerically stable)
log_sum_exp_x_logits = torch.logsumexp(x_logits, dim=1)

# Apply the stable formula for negative penalty
neg_penalty = torch.mean((
    -x_logits[negative_examples, eos_id] +
    log_sum_exp_x_logits[negative_examples]
).clamp(max=10.0))
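One way to confirm the rewrite is to wrap the snippet above in a function and differentiate it on made-up logits. This sketch assumes torch is available; the name stable_neg_penalty, the 3-token vocabulary and EOS index 2 are illustrative.

```python
import torch
import torch.nn.functional as F

EOS_ID = 2  # hypothetical EOS index in a toy 3-token vocabulary

def stable_neg_penalty(x_logits, negative_examples, eos_id=EOS_ID):
    # -x_EOS + logsumexp(x): the EOS probability is never formed explicitly
    log_sum_exp_x_logits = torch.logsumexp(x_logits, dim=1)
    return torch.mean((
        -x_logits[negative_examples, eos_id] +
        log_sum_exp_x_logits[negative_examples]
    ).clamp(max=10.0))

mask = torch.tensor([True])

# EOS logit far below the rest: the naive softmax-then-log version gives NaN gradients here
extreme = torch.tensor([[0.0, 0.0, -1000.0]], requires_grad=True)
stable_neg_penalty(extreme, mask).backward()
print(extreme.grad)  # zero gradient (the term is above the cap), but no NaN

# Moderate logits: below the cap, the gradient is softmax(x) - one_hot(EOS)
moderate = torch.tensor([[0.0, 0.0, -5.0]], requires_grad=True)
stable_neg_penalty(moderate, mask).backward()
expected = F.softmax(moderate.detach(), dim=1) - F.one_hot(torch.tensor([EOS_ID]), 3)
print(torch.allclose(moderate.grad, expected))
```

Note that the clamp still zeroes the gradient for examples whose penalty exceeds 10; the difference is that the capped value is now finite, so the zero gradient stays zero instead of turning into NaN.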

We can perform a similar transformation for the positive penalty:

$$L_{pos\_penalty} = -\log\left(\sum_{i \ne EOS}\exp(x_i^{+})\right) + \log\left(\sum_{i=1}^{V}\exp(x_i^{+})\right)$$
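The intermediate step is to write the complement of the EOS probability as a ratio of sums; taking $-\log$ of both sides then gives the difference of two log-sum-exp terms, each of which can be computed stably. Subtracting $1 - p$ in floating point, by contrast, yields exactly 0 as soon as $p$ rounds to 1.

$$
1 - sm(x_{EOS}^{+}) = \frac{\sum_{i=1}^{V}\exp(x_i^{+}) - \exp(x_{EOS}^{+})}{\sum_{i=1}^{V}\exp(x_i^{+})} = \frac{\sum_{i \ne EOS}\exp(x_i^{+})}{\sum_{i=1}^{V}\exp(x_i^{+})}
$$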

This translates to:

import torch

# LogSumExp for all logits (already calculated)
log_sum_exp_x_logits = torch.logsumexp(x_logits, dim=1)

# Calculate LogSumExp *excluding* the EOS token
# (slicing also works when eos_id is 0 or V-1: one of the two slices is simply empty)
x_not_eos_logits = torch.concatenate([
    x_logits[positive_examples, :eos_id],
    x_logits[positive_examples, eos_id+1:],
], dim=1)
log_sum_exp_x_not_eos_logits = torch.logsumexp(x_not_eos_logits, dim=1)

# Apply the stable formula for positive penalty
pos_penalty = torch.mean((
    -log_sum_exp_x_not_eos_logits +
    log_sum_exp_x_logits[positive_examples]
).clamp(max=10.0))
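Putting the pieces together, here is a sketch of both penalties in one function. Two details are assumptions of this sketch rather than part of the original code: the EOS logit is excluded by masking it to $-\infty$ with masked_fill (since $\exp(-\infty) = 0$, it drops out of the sum) instead of concatenating slices, and an empty mask returns 0 instead of calling torch.mean on an empty tensor, which returns NaN. The function name eos_penalties and its cap parameter are hypothetical.

```python
import torch
import torch.nn.functional as F

def eos_penalties(x_logits, eos_id, positive_examples, negative_examples, cap=10.0):
    """Return (pos_penalty, neg_penalty) computed with log-sum-exp only.

    x_logits: [batch_size, vocab_size]; the two masks: [batch_size] booleans.
    """
    # log(sum_i exp(x_i)) over the whole vocabulary
    lse_all = torch.logsumexp(x_logits, dim=1)

    # log(sum_{i != EOS} exp(x_i)): setting the EOS logit to -inf removes it
    # from the sum because exp(-inf) = 0 (an alternative to concatenating slices)
    eos_mask = torch.zeros_like(x_logits, dtype=torch.bool)
    eos_mask[:, eos_id] = True
    lse_not_eos = torch.logsumexp(x_logits.masked_fill(eos_mask, float("-inf")), dim=1)

    # Per-example penalty terms, capped like in the snippets above
    neg_terms = (-x_logits[:, eos_id] + lse_all)[negative_examples].clamp(max=cap)
    pos_terms = (-lse_not_eos + lse_all)[positive_examples].clamp(max=cap)

    # torch.mean of an empty tensor is NaN, so a batch without positive (or
    # negative) examples would bring NaNs back; use 0 for that term instead
    zero = torch.zeros((), dtype=x_logits.dtype, device=x_logits.device)
    pos_penalty = pos_terms.mean() if pos_terms.numel() > 0 else zero
    neg_penalty = neg_terms.mean() if neg_terms.numel() > 0 else zero
    return pos_penalty, neg_penalty

# Usage with made-up logits: matches the direct formulas when nothing underflows
torch.manual_seed(0)
x = torch.randn(4, 5)
pos_mask = torch.tensor([True, False, True, False])
pos, neg = eos_penalties(x, 3, pos_mask, ~pos_mask)
p_eos = F.softmax(x, dim=1)[:, 3]
print(pos, (-torch.log(1 - p_eos[pos_mask])).mean())
print(neg, (-torch.log(p_eos[~pos_mask])).mean())
```

In a training loop the two values would be combined as in the loss at the top of the post, for example ce_loss + alpha * pos + beta * neg, where ce_loss, alpha and beta come from your own setup.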

By switching to the LSE-based formulation, the dreaded NaNs vanished, and my model could finally train stably. Success! 🎉


Conclusion

Numerical stability is one of those “invisible” challenges in deep learning. Things can look fine on the surface, with seemingly sound mathematical formulas, but the limitations of floating-point arithmetic can bite you when you least expect it, especially when dealing with exponentials and logarithms. So, the next time your training run inexplicably blows up, remember the silent threat of numerical instability, and perhaps, the saving grace of the Log-Sum-Exp trick. Happy (and stable) training!


References
