
A Stitch in Time Saves Nine: A Train-Time Regularizing Loss for Improved Neural Network Calibration

2022-03-25 18:02:13
Ramya Hebbalaguppe, Jatin Prakash, Neelabh Madan, Chetan Arora

Abstract

Deep Neural Networks (DNNs) are known to make overconfident mistakes, which makes their use problematic in safety-critical applications. State-of-the-art (SOTA) calibration techniques improve the confidence of the predicted label alone and leave the confidence of non-max classes (e.g. top-2, top-5) uncalibrated. Such calibration is not suitable for label refinement using post-processing. Further, most SOTA techniques learn a few hyper-parameters post-hoc, leaving no scope for image-specific or pixel-specific calibration. This makes them unsuitable for calibration under domain shift, or for dense prediction tasks like semantic segmentation. In this paper, we argue for intervening at train time itself, so as to directly produce calibrated DNN models. We propose a novel auxiliary loss function, Multi-class Difference in Confidence and Accuracy (MDCA), to achieve this. MDCA can be used in conjunction with other application- or task-specific loss functions. We show that training with MDCA leads to better-calibrated models in terms of Expected Calibration Error (ECE) and Static Calibration Error (SCE) on image classification and segmentation tasks. We report an ECE (SCE) score of 0.72 (1.60) on the CIFAR-100 dataset, compared to 1.90 (1.71) by the SOTA. Under domain shift, a ResNet-18 model trained on the PACS dataset using MDCA gives an average ECE (SCE) score of 19.7 (9.7) across all domains, compared to 24.2 (11.8) by the SOTA. For the segmentation task, we report a 2x reduction in calibration error on the PASCAL-VOC dataset in comparison to Focal Loss. Finally, MDCA training improves calibration even on imbalanced data, and for natural language classification tasks. We have released the code; it is available at this https URL
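
The abstract names ECE and MDCA without defining them, so a small sketch may help make the quantities concrete. The first function below follows the commonly used binned definition of ECE (confidence of the predicted class versus accuracy, weighted by bin size). The second is only an assumption about what an MDCA-style term looks like, inferred from the name: for each class, the gap between the mean predicted probability and the mean occurrence of that class in a batch, averaged over classes. The function names, the bin count, and the use of NumPy are all illustrative choices, not taken from the paper or its released code.

```python
import numpy as np


def expected_calibration_error(probs, labels, n_bins=15):
    """Binned ECE. probs: (N, K) class probabilities, labels: (N,) int ids."""
    probs = np.asarray(probs, dtype=float)
    labels = np.asarray(labels)
    conf = probs.max(axis=1)                 # confidence of the predicted class
    pred = probs.argmax(axis=1)
    correct = (pred == labels).astype(float)

    edges = np.linspace(0.0, 1.0, n_bins + 1)
    ece = 0.0
    for lo, hi in zip(edges[:-1], edges[1:]):
        in_bin = (conf > lo) & (conf <= hi)
        if in_bin.any():
            # |accuracy - mean confidence| in this bin, weighted by bin mass
            gap = abs(correct[in_bin].mean() - conf[in_bin].mean())
            ece += in_bin.mean() * gap
    return ece


def mdca_style_term(probs, labels):
    """Assumed MDCA-style quantity: per-class |mean confidence - mean
    occurrence| over a batch, averaged over classes (a guess from the name)."""
    probs = np.asarray(probs, dtype=float)
    labels = np.asarray(labels)
    num_classes = probs.shape[1]
    one_hot = np.eye(num_classes)[labels]    # (N, K) indicator of true class
    return np.abs(probs.mean(axis=0) - one_hot.mean(axis=0)).mean()


# Hypothetical toy batch: four samples, two classes.
toy_probs = np.array([[0.95, 0.05], [0.85, 0.15], [0.25, 0.75], [0.65, 0.35]])
toy_labels = np.array([0, 1, 1, 0])
toy_ece = expected_calibration_error(toy_probs, toy_labels, n_bins=10)
toy_mdca = mdca_style_term(toy_probs, toy_labels)
```

Unlike the top-label ECE, the second quantity involves every class column of the probability matrix, which is in line with the abstract's concern about non-max classes. Using such a term as a training loss would require a differentiable framework and a weighting against the main task loss; this NumPy version only evaluates the quantity.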

URL

https://arxiv.org/abs/2203.13834

PDF

https://arxiv.org/pdf/2203.13834.pdf

