
Causal Diffusion Autoencoders: Toward Counterfactual Generation via Diffusion Probabilistic Models

2024-04-27 00:09:26
Aneesh Komanduri, Chen Zhao, Feng Chen, Xintao Wu

Abstract

Diffusion probabilistic models (DPMs) have become the state of the art in high-quality image generation. However, DPMs have an arbitrary, noisy latent space with no interpretable or controllable semantics. Although there has been significant research effort to improve image sample quality, there is little work on representation-controlled generation using diffusion models. Specifically, causal modeling and controllable counterfactual generation using DPMs remain underexplored. In this work, we propose CausalDiffAE, a diffusion-based causal representation learning framework to enable counterfactual generation according to a specified causal model. Our key idea is to use an encoder to extract high-level, semantically meaningful causal variables from high-dimensional data and to model stochastic variation using reverse diffusion. We propose a causal encoding mechanism that maps high-dimensional data to causally related latent factors, and we parameterize the causal mechanisms among these factors using neural networks. To enforce the disentanglement of causal variables, we formulate a variational objective and leverage auxiliary label information in a prior to regularize the latent space. We propose a DDIM-based counterfactual generation procedure subject to do-interventions. Finally, to address the limited-label-supervision scenario, we also study the application of CausalDiffAE when part of the training data is unlabeled; this setting additionally enables granular control over the strength of interventions when generating counterfactuals at inference time. We empirically show that CausalDiffAE learns a disentangled latent space and is capable of generating high-quality counterfactual images.
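
For readers who want a concrete picture of the mechanism the abstract describes, the snippet below is a minimal PyTorch sketch of how causally related latents might be computed from encoder-produced exogenous noise via neural causal mechanisms, and how a do-intervention on one latent would propagate to its descendants before the latents condition reverse (DDIM) decoding. The class name `NeuralSCM`, the MLP mechanisms, the adjacency convention, and the three-variable chain are illustrative assumptions, not the paper's implementation.

```python
# Minimal sketch (PyTorch) of neural causal mechanisms over latent variables
# and do-interventions whose effects propagate to downstream variables.
# All names, shapes, and the example graph are illustrative assumptions.
import torch
import torch.nn as nn


class NeuralSCM(nn.Module):
    """Computes causal latents z_i = f_i(parents(z_i), u_i) from exogenous
    noise u, where each mechanism f_i is a small MLP and variables are
    assumed to be indexed in topological order of a fixed DAG."""

    def __init__(self, num_vars: int, dim: int, adjacency: torch.Tensor):
        super().__init__()
        self.num_vars, self.dim = num_vars, dim
        # adjacency[i, j] = 1 means z_j is a parent of z_i.
        self.register_buffer("adj", adjacency.float())
        self.mechanisms = nn.ModuleList([
            nn.Sequential(
                nn.Linear(dim * (num_vars + 1), dim),
                nn.ReLU(),
                nn.Linear(dim, dim),
            )
            for _ in range(num_vars)
        ])

    def forward(self, u: torch.Tensor, interventions: dict | None = None):
        # u: (batch, num_vars, dim) exogenous noise produced by the encoder.
        # interventions: {variable index: value of shape (dim,)} for do(z_i := value).
        zs = []
        for i in range(self.num_vars):
            if interventions and i in interventions:
                # Hard intervention: ignore parents and exogenous noise.
                zs.append(interventions[i].expand(u.size(0), self.dim))
                continue
            # Keep only the already-computed parents of z_i; zero out the rest.
            parent_slots = [
                zs[j] if (j < i and self.adj[i, j] > 0)
                else torch.zeros(u.size(0), self.dim, device=u.device)
                for j in range(self.num_vars)
            ]
            inp = torch.cat(parent_slots + [u[:, i]], dim=-1)
            zs.append(self.mechanisms[i](inp))
        # (batch, num_vars, dim) causal latents that would condition the diffusion decoder.
        return torch.stack(zs, dim=1)


# Usage sketch on a hypothetical 3-variable chain z0 -> z1 -> z2.
adj = torch.tensor([[0, 0, 0],
                    [1, 0, 0],
                    [0, 1, 0]])
scm = NeuralSCM(num_vars=3, dim=16, adjacency=adj)
u = torch.randn(8, 3, 16)                          # exogenous noise from the encoder
z_factual = scm(u)                                 # observational latents
z_counterfactual = scm(u, interventions={0: torch.zeros(16)})  # do(z0 := 0)
```

In a setup like the one the abstract outlines, `z_counterfactual` would then condition a DDIM reverse pass, for example one started from noise deterministically inferred from the factual image, so that non-semantic details are preserved while the intervened attributes and their causal descendants change.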


URL

https://arxiv.org/abs/2404.17735

PDF

https://arxiv.org/pdf/2404.17735.pdf

