Paper Reading AI Learner

Score as Action: Fine-Tuning Diffusion Generative Models by Continuous-time Reinforcement Learning

2025-02-03 20:50:05
Hanyang Zhao, Haoxian Chen, Ji Zhang, David D. Yao, Wenpin Tang

Abstract

Reinforcement learning from human feedback (RLHF), which aligns a diffusion model with an input prompt, has become a crucial step in building reliable generative AI models. Most works in this area use a discrete-time formulation, which is prone to induced errors and often not applicable to models with higher-order or black-box solvers. The objective of this study is to develop a disciplined approach to fine-tuning diffusion models using continuous-time RL, formulated as a stochastic control problem with a reward function that aligns the end result (terminal state) with the input prompt. The key idea is to treat score matching as controls or actions, thereby making connections to policy optimization and regularization in continuous-time RL. To carry out this idea, we lay out a new policy optimization framework for continuous-time RL and illustrate its potential to enlarge the design space of value networks by leveraging the structural properties of diffusion models. We validate the advantages of our method with experiments on downstream tasks, fine-tuning the large-scale Text2Image model Stable Diffusion v1.5.
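
To make the stochastic control formulation concrete, here is a generic sketch in standard diffusion fine-tuning notation (placeholder symbols, not necessarily the paper's exact setup). The sampler's state X_t follows a controlled SDE whose drift is steered by an action a, identified with the learned score, and the objective is a terminal reward regularized toward the pretrained model:

    dX_t = [ f(X_t, t) + g(t)^2 a(X_t, t) ] dt + g(t) dW_t,    X_0 ~ ν,    t ∈ [0, T],

    max_a  E[ r(X_T) ]  −  β · KL( P^a ‖ P^pre ),

where f and g are the drift and diffusion coefficients of the sampling dynamics, r(X_T) scores how well the terminal sample matches the input prompt, β weighs the regularization, and P^a, P^pre are the path measures induced by the fine-tuned and pretrained scores. By Girsanov's theorem, the KL term reduces to an integrated quadratic penalty (1/2) E ∫_0^T g(t)^2 ‖a − a_pre‖^2 dt, which is what makes the continuous-time policy optimization tractable.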

Abstract (translated)

Reinforcement learning from human feedback (RLHF), which aligns a diffusion model with an input prompt, has become a crucial step in building reliable generative AI models. Most work in this area adopts a discrete-time formulation, which is prone to induced errors and often not applicable to models with higher-order/black-box solvers. The goal of this study is to develop an approach to fine-tuning diffusion models based on continuous-time reinforcement learning, cast as a stochastic control problem whose reward function aligns the end result (terminal state) with the input prompt. The core idea is to treat score matching as controls or actions, thereby establishing connections to policy optimization and regularization in continuous-time RL. To realize this idea, we propose a new policy optimization framework for continuous-time RL and demonstrate its potential to expand the design space of value networks by exploiting the structural properties of diffusion models. We validate the advantages of the method experimentally on the downstream task of fine-tuning the large-scale text-to-image (Text2Image) model Stable Diffusion v1.5. This work not only offers a new route around problems common to discrete-time reinforcement learning, but also demonstrates the advantages of the continuous-time framework for complex generative tasks, especially those involving higher-order or black-box solvers that require fine-grained control and alignment, further improving the quality and reliability of AI-generated content, with notable gains in tasks such as image generation.
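
As a further illustration, below is a minimal, self-contained toy sketch in PyTorch of the "score as action" recipe (hypothetical names, a 2-D toy reward, unit diffusion coefficient, and no base drift; an assumption-laden illustration, not the paper's implementation or Stable Diffusion). A fine-tuned score network plays the role of the policy, sampling is done by Euler-Maruyama, and a REINFORCE-style policy gradient maximizes terminal reward minus the Girsanov KL penalty against a frozen pretrained score:

import torch
import torch.nn as nn

torch.manual_seed(0)

class ScoreNet(nn.Module):
    """Tiny MLP standing in for a score network; takes (x, t)."""
    def __init__(self, dim=2, hidden=64):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(dim + 1, hidden), nn.SiLU(),
            nn.Linear(hidden, hidden), nn.SiLU(),
            nn.Linear(hidden, dim))

    def forward(self, x, t):
        tcol = torch.full_like(x[:, :1], t)   # broadcast the scalar time
        return self.net(torch.cat([x, tcol], dim=-1))

def reward(x):
    # Hypothetical terminal reward: prefer samples near (1, 1).
    return -((x - 1.0) ** 2).sum(-1)

policy = ScoreNet()                    # fine-tuned score = the "action"
pretrained = ScoreNet()                # frozen reference score
pretrained.load_state_dict(policy.state_dict())
for p in pretrained.parameters():
    p.requires_grad_(False)

opt = torch.optim.Adam(policy.parameters(), lr=1e-3)
K, T, lam, batch = 50, 1.0, 0.1, 256   # Euler steps, horizon, KL weight, batch
dt = T / K

for step in range(200):
    x = torch.randn(batch, 2)          # start the sampler from N(0, I)
    logp = torch.zeros(batch)          # policy log-likelihood of the path
    kl = torch.zeros(batch)            # running Girsanov KL penalty
    for k in range(K):
        t = 1.0 - k * dt               # integrate in the sampler's own time
        a = policy(x, t)               # action = score evaluated at (x, t)
        a_ref = pretrained(x, t)
        mean = x + a * dt              # controlled drift (g = 1, f omitted)
        x_next = mean + dt ** 0.5 * torch.randn_like(x)
        # Gaussian log-density of the Euler-Maruyama transition
        # (theta-independent constants dropped):
        logp = logp - ((x_next.detach() - mean) ** 2).sum(-1) / (2 * dt)
        # KL between controlled and pretrained path measures (Girsanov):
        kl = kl + 0.5 * ((a - a_ref) ** 2).sum(-1) * dt
        x = x_next.detach()
    adv = (reward(x) - lam * kl).detach()
    adv = (adv - adv.mean()) / (adv.std() + 1e-8)   # variance reduction
    loss = -(adv * logp).mean() + lam * kl.mean()   # REINFORCE + KL penalty
    opt.zero_grad()
    loss.backward()
    opt.step()

In the real setting, ScoreNet would be the diffusion model's score/noise network, reward would be a learned prompt-alignment reward model, and the drift and diffusion coefficients would match the model's SDE; the paper's framework additionally involves value networks, which this plain REINFORCE sketch omits.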

URL

https://arxiv.org/abs/2502.01819

PDF

https://arxiv.org/pdf/2502.01819.pdf
