Paper Reading AI Learner

OTTER: Improving Zero-Shot Classification via Optimal Transport

2024-04-12 13:18:47
Changho Shin, Jitian Zhao, Sonia Cromp, Harit Vishwakarma, Frederic Sala

Abstract

Popular zero-shot models suffer from artifacts inherited from pretraining. A particularly detrimental artifact, caused by unbalanced web-scale pretraining data, is a mismatched label distribution. Existing approaches that seek to repair the label distribution are not suitable in zero-shot settings, as they have incompatible requirements, such as access to labeled downstream task data or knowledge of the true label balance in the pretraining distribution. We sidestep these challenges and introduce a simple and lightweight approach to adjust pretrained model predictions via optimal transport. Our technique requires only an estimate of the label distribution of a downstream task. Theoretically, we characterize the improvement produced by our procedure under certain mild conditions and provide bounds on the error caused by misspecification. Empirically, we validate our method on a wide array of zero-shot image and text classification tasks, improving accuracy by 4.8% and 15.9% on average, respectively, and beating baselines such as Prior Matching, often by significant margins, on 17 out of 21 datasets.
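The abstract only names the ingredients: zero-shot class probabilities, an estimated downstream label distribution, and optimal transport. The Python sketch below shows one way these pieces can fit together, under our own assumptions rather than the authors' code. The cost of sending test example x_i to class j is taken to be -log p(y=j | x_i). Each of the n test examples carries mass 1/n, and the estimated label distribution is the target marginal. The transport problem is approximated with entropically regularized Sinkhorn iterations instead of an exact OT solver. The names `ot_rebalance` and `logsumexp`, and the `eps` and `n_iters` defaults, are hypothetical.

```python
import math


def logsumexp(values):
    """Numerically stable log(sum(exp(v) for v in values))."""
    m = max(values)
    if m == -math.inf:
        return -math.inf
    return m + math.log(sum(math.exp(v - m) for v in values))


def ot_rebalance(probs, target_dist, eps=0.05, n_iters=500):
    """Relabel a batch of predictions so label frequencies follow target_dist.

    probs:       n x k zero-shot class probabilities, probs[i][j] = p(y=j | x_i)
    target_dist: length-k estimate of the downstream label distribution
                 (entries must be positive and sum to 1)
    eps:         entropic regularization strength (hypothetical default)
    Returns the argmax label of each row of the approximate transport plan.
    """
    n, k = len(probs), len(target_dist)
    # Assumed cost C_ij = -log p_ij, so the Gibbs kernel is log K_ij = -C_ij / eps.
    log_kernel = [[math.log(max(p, 1e-300)) / eps for p in row] for row in probs]
    log_a = [-math.log(n)] * n                   # each test point carries mass 1/n
    log_b = [math.log(b) for b in target_dist]   # target label marginal
    f = [0.0] * n  # row scaling, log scale
    g = [0.0] * k  # column scaling, log scale
    for _ in range(n_iters):
        # Rescale rows so each test point sends out exactly 1/n mass.
        f = [log_a[i] - logsumexp([log_kernel[i][j] + g[j] for j in range(k)])
             for i in range(n)]
        # Rescale columns so class j receives exactly target_dist[j] mass.
        g = [log_b[j] - logsumexp([log_kernel[i][j] + f[i] for i in range(n)])
             for j in range(k)]
    # log pi_ij = f_i + log K_ij + g_j; f_i is constant within a row, so it
    # does not affect the per-row argmax.
    return [max(range(k), key=lambda j: log_kernel[i][j] + g[j]) for i in range(n)]


if __name__ == "__main__":
    # Toy batch where the zero-shot model favors class 0 for every input.
    probs = [[0.9, 0.1], [0.8, 0.2], [0.7, 0.3], [0.6, 0.4]]
    print([max(range(2), key=lambda j: row[j]) for row in probs])  # [0, 0, 0, 0]
    print(ot_rebalance(probs, [0.5, 0.5]))  # [0, 0, 1, 1]
```

In the toy batch, plain argmax labels every input as class 0. A balanced target of [0.5, 0.5] forces half of the mass onto class 1, and that mass goes to the two inputs whose class-1 probability is relatively highest. A smaller `eps` brings the plan closer to the exact (unregularized) OT solution but slows Sinkhorn convergence. The log-domain updates keep small `eps` values from underflowing.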

URL

https://arxiv.org/abs/2404.08461

PDF

https://arxiv.org/pdf/2404.08461.pdf