Paper Reading AI Learner

Exploring Token Pruning in Vision State Space Models

2024-09-27 17:59:50
Zheng Zhan, Zhenglun Kong, Yifan Gong, Yushu Wu, Zichong Meng, Hangyu Zheng, Xuan Shen, Stratis Ioannidis, Wei Niu, Pu Zhao, Yanzhi Wang

Abstract

State Space Models (SSMs) have the advantage of maintaining linear computational complexity, in contrast to the attention modules in transformers, and have been applied to vision tasks as a new type of powerful vision foundation model. Inspired by the observation that the final prediction in vision transformers (ViTs) is based on only a subset of the most informative tokens, we take the novel step of enhancing the efficiency of SSM-based vision models through token-based pruning. However, directly applying existing token pruning techniques designed for ViTs fails to deliver good performance, even with extensive fine-tuning. To address this issue, we revisit the unique computational characteristics of SSMs and discover that naive application disrupts the sequential token positions. This insight motivates us to design a novel and general token pruning method specifically for SSM-based vision models. We first introduce a pruning-aware hidden state alignment method that stabilizes the neighborhood of the remaining tokens to enhance performance. In addition, based on our detailed analysis, we propose a token importance evaluation method adapted to SSM models to guide the token pruning. With an efficient implementation and practical acceleration methods, our approach delivers actual speedup. Extensive experiments demonstrate that our approach achieves significant computation reduction with minimal impact on performance across different tasks. Notably, we achieve 81.7% accuracy on ImageNet with a 41.6% reduction in FLOPs for the pruned PlainMamba-L3. Furthermore, our work provides deeper insights into the behavior of SSM-based vision models for future research.
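To make the abstract's core idea concrete, here is a minimal NumPy sketch of token pruning applied before a simplified diagonal SSM recurrence. The importance scores, keep ratio, and the toy scan are illustrative placeholders, not the paper's actual scoring or hidden-state-alignment method; the sketch only shows the general pattern of dropping low-importance tokens while preserving the sequential order of the survivors, which the abstract identifies as critical for SSMs.

```python
import numpy as np

def ssm_scan(x, A, B, C):
    # Toy diagonal SSM recurrence (illustrative only):
    #   h_t = A * h_{t-1} + B * x_t,   y_t = C * h_t   (elementwise)
    T, d = x.shape
    h = np.zeros(d)
    ys = []
    for t in range(T):
        h = A * h + B * x[t]
        ys.append(C * h)
    return np.stack(ys)

def prune_tokens(x, scores, keep_ratio=0.6):
    # Keep the top-k tokens by an (assumed, externally supplied) importance
    # score, while preserving their original sequential order -- naive
    # reordering would disrupt the token positions the SSM scan depends on.
    T = x.shape[0]
    k = max(1, int(T * keep_ratio))
    keep = np.sort(np.argsort(scores)[-k:])  # kept indices, in original order
    return x[keep], keep

# Usage: prune a toy token sequence, then run the scan on the survivors.
rng = np.random.default_rng(0)
x = rng.normal(size=(10, 4))       # 10 tokens, dim 4
scores = rng.normal(size=10)       # placeholder importance scores
A, B, C = np.full(4, 0.9), np.ones(4), np.ones(4)
x_kept, keep = prune_tokens(x, scores, keep_ratio=0.5)
y = ssm_scan(x_kept, A, B, C)
```

The key design point mirrored from the abstract is that `keep` is re-sorted: the pruned sequence keeps its relative order so the recurrence still sees tokens in their original positions.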

URL

https://arxiv.org/abs/2409.18962

PDF

https://arxiv.org/pdf/2409.18962.pdf
