Abstract
State Space Models (SSMs) retain linear computational complexity, in contrast to the attention modules in transformers, and have been applied to vision tasks as a new type of powerful vision foundation model. Inspired by the observation that the final prediction in vision transformers (ViTs) depends on only a subset of the most informative tokens, we take the novel step of enhancing the efficiency of SSM-based vision models through token-based pruning. However, directly applying existing token pruning techniques designed for ViTs fails to deliver good performance, even with extensive fine-tuning. To address this issue, we revisit the unique computational characteristics of SSMs and discover that naive pruning disrupts the sequential positions of tokens. This insight motivates us to design a novel and general token pruning method specifically for SSM-based vision models. We first introduce a pruning-aware hidden state alignment method that stabilizes the neighborhood of the remaining tokens and improves performance. In addition, based on our detailed analysis, we propose a token importance evaluation method adapted to SSMs to guide the token pruning. With an efficient implementation and practical acceleration methods, our approach yields actual speedup. Extensive experiments demonstrate that our approach achieves significant computation reduction with minimal impact on performance across different tasks. Notably, the pruned PlainMamba-L3 achieves 81.7% accuracy on ImageNet with a 41.6% reduction in FLOPs. Furthermore, our work provides deeper insight into the behavior of SSM-based vision models for future research.
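To make the idea of order-preserving token pruning concrete, below is a minimal PyTorch sketch, not the paper's released implementation: it uses a placeholder importance score (per-token feature norm) rather than the SSM-adapted importance measure the abstract describes, and it omits the pruning-aware hidden state alignment. Its only purpose is to illustrate that kept tokens must be re-sorted into their original scan order so the SSM's sequential processing is not disrupted.

# Minimal sketch (assumptions: PyTorch, L2-norm importance as a stand-in score).
import torch

def prune_tokens(tokens: torch.Tensor, keep_ratio: float = 0.6):
    """tokens: (batch, seq_len, dim). Returns pruned tokens and kept indices."""
    b, n, d = tokens.shape
    n_keep = max(1, int(n * keep_ratio))
    # Placeholder importance score: L2 norm of each token's features.
    scores = tokens.norm(dim=-1)                   # (b, n)
    keep_idx = scores.topk(n_keep, dim=1).indices  # (b, n_keep)
    # Re-sort kept indices so the SSM's sequential scan order is preserved.
    keep_idx, _ = keep_idx.sort(dim=1)
    pruned = torch.gather(tokens, 1, keep_idx.unsqueeze(-1).expand(-1, -1, d))
    return pruned, keep_idx

if __name__ == "__main__":
    x = torch.randn(2, 196, 192)       # e.g. 14x14 patch tokens, dim 192
    y, idx = prune_tokens(x, keep_ratio=0.6)
    print(y.shape, idx.shape)          # (2, 117, 192), (2, 117)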
URL
https://arxiv.org/abs/2409.18962