Learning to Reason in LLMs by Expectation Maximization
Abstract
Large language models (LLMs) solve reasoning problems by first generating a rationale and then answering. We formalize reasoning as a latent variable model and derive an expectation-maximization (EM) objective for learning to reason. This view connects EM and modern reward-based optimization, and shows that the main challenge lies in designing a sampling distribution that generates rationales that justify correct answers. We instantiate and compare several sampling schemes: rejection sampling with a budget, self-taught reasoner (STaR), and prompt posterior sampling (PPS), which only keeps the rationalization stage of STaR. Our experiments on the ARC, MMLU, and OpenBookQA datasets with the Llama and Qwen models show that the sampling scheme can significantly affect the accuracy of learned reasoning models. Despite its simplicity, we observe that PPS outperforms the other sampling schemes.
Summary
This paper addresses the problem of improving reasoning in large language models (LLMs) by framing it as a latent variable model (LVM) and applying the Expectation-Maximization (EM) algorithm. The core idea is that reasoning can be seen as generating a rationale (latent variable) that leads to the correct answer. The authors derive an EM objective specialized for LLMs, approximating the E-step with Monte Carlo sampling and the M-step with a filtered gradient-based update. They highlight the importance of the sampling distribution used to generate rationales, arguing that its design is crucial for effective learning. The paper then compares three sampling schemes: Rejection Sampling with Budget (RS-M), Self-Taught Reasoner (STaR), and Prompt Posterior Sampling (PPS), a simplified version of STaR that keeps only the rationalization stage. The key finding is that PPS, despite its simplicity, consistently outperforms the other sampling schemes on the ARC, MMLU, and OpenBookQA datasets using the Llama3.2-3B-Instruct and Qwen2.5-3B-Instruct models. This suggests that directly prompting the model to generate rationales conditioned on the correct answer is more effective than rejection sampling alone or combining it with rationalization, as STaR does. The authors also demonstrate that PPS leads to shorter, more focused rationales. This work contributes by providing a theoretical connection between EM and modern reward-based optimization for LLMs and by empirically showing the significant impact of the rationale sampling scheme on the accuracy of learned reasoning models.
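To make the approximate EM procedure described above concrete, here is a minimal sketch of one round. The callables `sample_rationales`, `answer_given`, and `sft_update` are hypothetical stand-ins for the proposal distribution q(z | x, y⋆; θ), the model's answering step, and the gradient-based M-step; the paper's actual implementation may differ.

```python
from dataclasses import dataclass
from typing import Callable, List, Tuple

# Hypothetical type aliases; the paper operates on token sequences, but strings
# keep the sketch readable.
Question = str
Answer = str
Rationale = str


@dataclass
class Example:
    x: Question       # question
    y_star: Answer    # gold answer y*


def em_round(
    data: List[Example],
    sample_rationales: Callable[[Question, Answer, int], List[Rationale]],   # proposal q(z | x, y*; theta)
    answer_given: Callable[[Question, Rationale], Answer],                    # model's answer after reasoning on z
    sft_update: Callable[[List[Tuple[Question, Rationale, Answer]]], None],   # gradient-based M-step (SFT)
    k: int = 4,
) -> None:
    """One approximate EM round.

    E-step: draw k rationales per question from the proposal and keep only those
    whose induced answer matches the gold answer y* (Monte Carlo sampling plus filtering).
    M-step: maximize log p(z, y* | x; theta) on the accepted set via a filtered,
    gradient-based (supervised fine-tuning) update.
    """
    accepted: List[Tuple[Question, Rationale, Answer]] = []
    for ex in data:
        for z in sample_rationales(ex.x, ex.y_star, k):
            if answer_given(ex.x, z) == ex.y_star:  # filter: rationale must justify y*
                accepted.append((ex.x, z, ex.y_star))
    if accepted:
        sft_update(accepted)
```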
Key Insights
- The paper formalizes learning to reason in LLMs as an Expectation-Maximization (EM) problem, connecting it to latent variable models.
- The authors demonstrate that many self-improvement reasoning algorithms can be viewed as approximate E-steps followed by M-steps within the EM framework.
- The paper highlights the importance of the rationale sampling distribution in the E-step of the EM algorithm, arguing that it significantly impacts the performance of the learned reasoning model.
- Prompt Posterior Sampling (PPS), which prompts the model to generate rationales conditioned on the correct answer, outperforms Rejection Sampling with Budget (RS-M) and Self-Taught Reasoner (STaR); the three schemes are sketched after this list.
- Experimental results on the ARC, MMLU, and OpenBookQA datasets show that PPS achieves higher test accuracy than RS-M and STaR with the Llama3.2-3B-Instruct and Qwen2.5-3B-Instruct models.
- The authors observe that PPS generally results in shorter and more focused reasoning than the other sampling schemes.
- The study reveals that simply generating more training data (as STaR does) does not necessarily translate to better performance; the quality of the rationales matters more.
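The three sampling schemes differ only in how rationales are proposed. Below is a minimal sketch of that difference, assuming hypothetical model calls `generate(prompt, n)` (returns n sampled rationale strings) and `answer_of(question, rationale)` (returns the model's answer given that rationale); the prompt wordings are illustrative, not the paper's exact prompts.

```python
from typing import Callable, List


def rs_with_budget(x: str, y_star: str, generate: Callable, answer_of: Callable,
                   budget: int = 8) -> List[str]:
    """RS-M: sample rationales conditioned on the question only and keep those
    that lead to the correct answer, up to a fixed sampling budget."""
    prompt = f"Question: {x}\nThink step by step, then give the answer."
    return [z for z in generate(prompt, budget) if answer_of(x, z) == y_star]


def star(x: str, y_star: str, generate: Callable, answer_of: Callable,
         budget: int = 8) -> List[str]:
    """STaR: rejection sampling first; if nothing is accepted, fall back to
    rationalization, i.e. prompt with the gold answer as a hint."""
    kept = rs_with_budget(x, y_star, generate, answer_of, budget)
    if kept:
        return kept
    hint = (f"Question: {x}\nThe correct answer is {y_star}. "
            "Explain step by step why it is correct.")
    # The answer is already given in the hint, so no correctness filter is
    # applied here (assumption; details may differ in the paper).
    return generate(hint, 1)


def pps(x: str, y_star: str, generate: Callable, n: int = 1) -> List[str]:
    """PPS: keep only the rationalization stage of STaR, i.e. always condition
    the proposal on the gold answer y*."""
    hint = (f"Question: {x}\nThe correct answer is {y_star}. "
            "Explain step by step why it is correct.")
    return generate(hint, n)
```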
Practical Implications
- Practitioners can use Prompt Posterior Sampling (PPS) as a simple and effective technique for improving the reasoning capabilities of LLMs on question-answering tasks.
- The EM framework provides a theoretical foundation for understanding and designing new learning-to-reason algorithms for LLMs.
- The findings suggest that focusing on generating high-quality rationales, even at the cost of a mismatch between train-time and test-time sampling conditions, can be more beneficial than relying solely on methods like rejection sampling.
- This research opens avenues for future work on alternative rationale proposal distributions q(· | x, y⋆; θ) and on more sophisticated approximations of the E-step in the EM algorithm for LLMs.
- The paper emphasizes the importance of designing prompts that elicit informative reasoning from LLMs; a sketch of such prompts follows this list.
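As an illustration of the last point, here is a sketch of PPS-style prompt templates and of how a fine-tuning example might be assembled. The exact wordings, the multiple-choice format, and the helper `build_sft_example` are assumptions for illustration, not taken from the paper.

```python
# Hypothetical PPS-style prompt templates (wordings are assumptions).
RATIONALIZATION_PROMPT = (
    "Question: {question}\n"
    "Choices: {choices}\n"
    "The correct answer is {answer}. "
    "Give a short step-by-step explanation of why this answer is correct."
)

TRAINING_PROMPT = (
    "Question: {question}\n"
    "Choices: {choices}\n"
    "Think step by step, then give the answer."
)


def build_sft_example(question: str, choices: str, answer: str, rationale: str) -> dict:
    """Pair the hint-free training prompt with a rationale that was generated
    under the answer hint. This is the train-time / test-time mismatch noted
    above: the rationale was sampled conditioned on the gold answer, but the
    model is fine-tuned to produce it without that hint."""
    return {
        "prompt": TRAINING_PROMPT.format(question=question, choices=choices),
        "completion": f"{rationale}\nAnswer: {answer}",
    }
```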