Paper Reading AI Learner

Implicit Regularization of Gradient Flow on One-Layer Softmax Attention

2024-03-13 17:02:27
Heejune Sheen, Siyu Chen, Tianhao Wang, Harrison H. Zhou

Abstract

We study gradient flow on the exponential loss for a classification problem with a one-layer softmax attention model, where the key and query weight matrices are trained separately. Under a separability assumption on the data, we show that when gradient flow achieves the minimal loss value, it further implicitly minimizes the nuclear norm of the product of the key and query weight matrices. Such implicit regularization can be described by a Support Vector Machine (SVM) problem with respect to the attention weights. This finding contrasts with prior results showing that gradient descent induces an implicit regularization on the Frobenius norm of the product weight matrix when the key and query matrices are combined into a single weight matrix for training. For diagonal key and query matrices, our analysis builds upon the reparameterization technique and exploits approximate KKT conditions of the SVM associated with the classification data. Moreover, the results are extended to general weight configurations, given proper alignment of the weight matrices' singular spaces with the data features at initialization.
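
To make the setting concrete, here is a minimal NumPy sketch; it is not the authors' code. The abstract does not give the model equations, so the sketch assumes one common form of a one-layer softmax attention classifier, f(X) = v^T X^T softmax(X K Q^T z), with a fixed output head v, a query token z, square d×d key and query matrices K and Q, and labels y in {-1, +1}. The toy data generator `make_data`, the helper names, and all hyperparameters are hypothetical. Plain gradient descent with a small step stands in for gradient flow. The code trains either the separate factors K and Q or the combined matrix W = K Q^T directly, then reports the nuclear and Frobenius norms of W (their ratio equals 1 exactly when W has rank one). It does not solve the SVM problem mentioned in the abstract, so it only lets one inspect trends, not verify the theorem. For intuition only: the standard matrix identity ||W||_* = min over K Q^T = W of (||K||_F^2 + ||Q||_F^2)/2 is one informal reason a factored parameterization can favor a small nuclear norm; this is a general fact, not the paper's proof.

```python
import numpy as np


def softmax(a):
    # Numerically stable softmax over the last axis.
    a = a - a.max(axis=-1, keepdims=True)
    e = np.exp(a)
    return e / e.sum(axis=-1, keepdims=True)


def make_data(n=20, T=4, d=5, seed=0):
    # Hypothetical toy data: token 0 carries the label in coordinate 0,
    # the other tokens are Gaussian noise; the query z_i is token 0.
    rng = np.random.default_rng(seed)
    v = np.zeros(d)
    v[0] = 1.0  # fixed (untrained) output head
    y = rng.choice([-1.0, 1.0], size=n)
    X = rng.normal(scale=0.5, size=(n, T, d))
    X[:, 0, 0] = 2.0 * y
    Z = X[:, 0, :].copy()
    return X, Z, y, v


def loss_and_grad_W(W, X, Z, y, v):
    # Assumed model: f_i = v^T X_i^T softmax(X_i W z_i), exponential loss.
    a = np.einsum("ntd,de,ne->nt", X, W, Z)  # a_{i,t} = x_{i,t}^T W z_i
    s = softmax(a)                             # attention weights per sequence
    h = X @ v                                  # h_{i,t} = v^T x_{i,t}
    f = np.sum(s * h, axis=1)                  # model outputs f_i
    ell = np.exp(-y * f)                       # per-sample exponential loss
    g = s * (h - f[:, None])                   # df_i / da_{i,t}
    coef = -y * ell / len(y)                   # dL / df_i
    G = np.einsum("n,nt,ntd,ne->de", coef, g, X, Z)  # dL / dW
    return ell.mean(), G


def train(X, Z, y, v, factored=True, lr=0.1, steps=3000, init=0.1, seed=1):
    # Small-step gradient descent, used here as a crude stand-in for gradient flow.
    d = X.shape[2]
    rng = np.random.default_rng(seed)
    K = init * rng.normal(size=(d, d))
    Q = init * rng.normal(size=(d, d))
    W = K @ Q.T  # only updated directly when factored=False
    for _ in range(steps):
        if factored:
            _, G = loss_and_grad_W(K @ Q.T, X, Z, y, v)
            # Chain rule for W = K Q^T: dL/dK = G Q, dL/dQ = G^T K.
            K, Q = K - lr * G @ Q, Q - lr * G.T @ K
        else:
            _, G = loss_and_grad_W(W, X, Z, y, v)
            W = W - lr * G
    W_final = K @ Q.T if factored else W
    loss, _ = loss_and_grad_W(W_final, X, Z, y, v)
    return W_final, loss


def norms(W):
    # Nuclear norm (sum of singular values) and Frobenius norm.
    sv = np.linalg.svd(W, compute_uv=False)
    return sv.sum(), np.sqrt(np.sum(sv ** 2))


if __name__ == "__main__":
    X, Z, y, v = make_data()
    for factored in (True, False):
        W, loss = train(X, Z, y, v, factored=factored)
        nuc, fro = norms(W)
        print(f"factored={factored}: loss={loss:.4f}, nuclear={nuc:.2f}, "
              f"Frobenius={fro:.2f}, nuclear/Frobenius={nuc / fro:.2f}")
```

Running the file prints, for each parameterization, the final loss and the two norms of W; the numbers depend on the toy data and the step count, so no particular output is claimed here. Because the head v is fixed in this toy setup, the loss only approaches its infimum as the attention saturates and ||W|| grows, so comparing the normalized direction W/||W|| across runs is more informative than comparing raw norms.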

URL

https://arxiv.org/abs/2403.08699

PDF

https://arxiv.org/pdf/2403.08699.pdf
