TL;DR
Stop worrying about plasticity loss: just apply FIRE before training on new data.
Overview. FIRE balances stability and plasticity in a principled manner via constrained optimization. (a) FIRE balances stability and plasticity in continual visual learning: stability and plasticity across diverse architectures and datasets, including CIFAR-10 with ResNet-18, CIFAR-100 with ViT-Tiny, and Tiny ImageNet with VGG-16. (b) FIRE balances stability and plasticity in continual pretraining of LLMs: GPT-0.1B was trained on WikiText-103 and OpenWebText. From left to right, results correspond to models initialized from the best checkpoint during pretraining, from 30k pretraining iterations, and from 60k pretraining iterations. (c) FIRE balances stability and plasticity in reinforcement learning: continuous control with SAC on three HumanoidBench tasks. The black dashed line indicates the point at which reinitialization is applied.
Abstract
We propose FIRE, a principled reinitialization method that explicitly balances the stability-plasticity tradeoff. FIRE quantifies stability through the Squared Frobenius Error (SFE), which measures proximity to past weights, and plasticity through the Deviation from Isometry (DfI), which reflects weight isotropy. The reinitialization point is obtained by solving a constrained optimization problem, minimizing SFE subject to DfI being zero, whose solution is efficiently approximated by the Newton-Schulz iteration. FIRE is evaluated on continual visual learning (CIFAR-10 with ResNet-18), language modeling (OpenWebText with GPT-0.1B), and reinforcement learning (HumanoidBench with SAC and Atari games with DQN). Across all domains, FIRE consistently outperforms both naive training without intervention and standard reinitialization methods, demonstrating that it effectively balances the stability-plasticity tradeoff.
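Concretely, a plausible formalization of the two quantities and the resulting problem is given below. This is our reading of the abstract, not a formula copied from the paper; the notation $W_{\text{prev}}$ for the weights after training on the previous data is an assumption.

$$
\mathrm{SFE}(W) = \lVert W - W_{\text{prev}} \rVert_F^2,
\qquad
\mathrm{DfI}(W) = \lVert W^\top W - I \rVert_F^2,
$$

$$
W^\star = \operatorname*{arg\,min}_{W} \; \mathrm{SFE}(W)
\quad \text{subject to} \quad \mathrm{DfI}(W) = 0,
$$

i.e., $W^\star$ is the semi-orthogonal matrix nearest to $W_{\text{prev}}$ in Frobenius norm (the orthogonal polar factor of $W_{\text{prev}}$).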
FIRE Algorithm
(a) Illustration of FIRE: by solving a constrained optimization problem, FIRE places the weights at the intersection of the high-stability and high-plasticity manifolds; the solution is the nearest semi-orthogonal weight matrix. (b) The FIRE algorithm: FIRE approximates this solution via the Newton-Schulz iteration, orthogonalizing the neural network weights after training on the current dataset and before learning on a new one.
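A minimal PyTorch sketch of this orthogonalization step is shown below. It uses the classic cubic Newton-Schulz iteration for the orthogonal polar factor; the function name `fire_reinit`, the iteration count, and the Frobenius-norm scaling are illustrative assumptions, not the authors' exact implementation.

```python
import torch

def fire_reinit(W: torch.Tensor, num_iters: int = 10) -> torch.Tensor:
    """Approximate the semi-orthogonal matrix nearest to W via Newton-Schulz."""
    # Work with a tall matrix so the Gram matrix below is the smaller one.
    transposed = W.shape[0] < W.shape[1]
    X = W.T if transposed else W
    # Scale so all (nonzero) singular values fall in (0, sqrt(3)), the
    # region where the cubic Newton-Schulz iteration converges to the
    # orthogonal polar factor.
    X = X / (X.norm() + 1e-7)
    eye = torch.eye(X.shape[1], device=X.device, dtype=X.dtype)
    for _ in range(num_iters):
        # X <- X (1.5 I - 0.5 X^T X): drives X^T X toward the identity.
        X = X @ (1.5 * eye - 0.5 * (X.T @ X))
    return X.T if transposed else X
```

For convolutional layers one would typically flatten each kernel to two dimensions (e.g. `weight.reshape(out_channels, -1)`), orthogonalize, and reshape back; whether the paper handles convolutions exactly this way is an assumption here.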
FIRE with Continual Visual Learning
We evaluate FIRE on three continual visual learning setups spanning three architecture-dataset pairs: CIFAR-10 with ResNet-18, CIFAR-100 with ViT-Tiny, and Tiny ImageNet with VGG-16. A sketch of the continual-learning loop follows below.
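For illustration, a hypothetical continual-learning loop applying FIRE between tasks might look as follows. Here `tasks`, `train`, `model`, and the restriction to `nn.Linear` layers are placeholder assumptions, and `fire_reinit` is the sketch from the previous section.

```python
import torch

for i, task in enumerate(tasks):
    if i > 0:
        # Apply FIRE before training on new data: orthogonalize
        # every weight matrix in place.
        with torch.no_grad():
            for module in model.modules():
                if isinstance(module, torch.nn.Linear):
                    module.weight.copy_(fire_reinit(module.weight))
    train(model, task)
```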
FIRE with Continual Pretraining of LLMs
We pretrained GPT-0.1B on WikiText-103 and then continued training on a new dataset that mixes OpenWebText and WikiText-103. From left to right, results correspond to models initialized from the best checkpoint during pretraining, from 30k pretraining iterations, and from 60k pretraining iterations.
FIRE with Reinforcement Learning
(top) Discrete control with DQN on three Atari environments that suffer from severe plasticity loss. (bottom) Continuous control with SAC on three HumanoidBench tasks. The black dashed line indicates the point at which reinitialization is applied.
Paper
FIRE: Frobenius-Isometry Reinitialization for Balancing the Stability-Plasticity Tradeoff
Isaac Han, Sangyeon Park, Seungwon Oh, Donghu Kim, Hojoon Lee, Kyung-Joong Kim
OpenReview
Citation
If you find our work useful, please consider citing the paper as follows: