Post

DDPM 논문 이해하기 (3편): Loss 수식의 이해

[PDF] [Code]

지난 글(2편, diffusion의 forward와 reverse process에 대한 이해)에 이어서 음의 로그 가능도에 대한 변분 하한을 최적화하는 과정을 설명하려 합니다. 이 과정은 Denoising Diffusion Probabilistic Models (DDPM)과 같은 모델의 학습 과정의 핵심이기 때문에 정확하게 파악하는 것이 중요합니다. 목표는 입력 데이터를 가장 잘 설명하는 확률 분포인 $p_\theta (x_0)$을 추정하는 것입니다!



1. 변분 하한과 주변 분포

우리는 원본 이미지 데이터를 가장 잘 나타내는 확률분포 $p_\theta(\mathbf{x}_0)$을 추정하려고 합니다. 이는 diffusion 과정을 통해 분해되며, 알고자 하는 $\mathbf{x}_0$에 대해서만 남기는 주변 분포(marginal distribution)로 표현됩니다. 이 과정은 다음과 같이 나타낼 수 있습니다:

\[p_\theta(\mathbf{x}_0) = \int p_\theta(\mathbf{x}_{0:T})\,d\mathbf{x}_{1:T}\]

즉, $p_\theta(\mathbf{x}_0)$는 중간 상태인 $\mathbf{x}_1, \mathbf{x}_2, \dots, \mathbf{x}_T$에 대해 적분한 값을 의미합니다. 이를 계산 가능하게 하기 위해, 우리는 근사 분포 $q(\mathbf{x}_{1:T} | \mathbf{x}_0)$를 도입하고, 분모와 분자에 이 분포를 곱하여 다음과 같이 변환합니다:

\[p_\theta(\mathbf{x}_0) = \int \frac{p_\theta(\mathbf{x}_{0:T})}{q(\mathbf{x}_{1:T}|\mathbf{x}_0)} q(\mathbf{x}_{1:T}|\mathbf{x}_0) \, d\mathbf{x}_{1:T}\]

이제 이 식은 기댓값 형태로 다시 쓸 수 있습니다. (참고):

\[p_\theta(\mathbf{x}_0) = \mathbb{E}_{q}\left[ \frac{p_\theta(\mathbf{x}_{0:T})}{q(\mathbf{x}_{1:T}|\mathbf{x}_0)} \right]\]

이를 통해 우리는 복잡한 다변량 분포를 효율적으로 처리할 수 있으며, 확률 분포의 기댓값을 계산하는 방식으로 접근하게 됩니다.


2. 젠슨의 부등식과 음의 로그 가능도

이제 위에서 얻은 식에 로그를 취해 로그 가능도를 구해보겠습니다. 이는 모델이 얼마나 잘 데이터를 설명하는지 나타내는 지표입니다. 로그를 취하면 다음과 같은 식을 얻게 됩니다:

\[\log p_\theta(\mathbf{x}_0) = \log \mathbb{E}_{q}\left[ \frac{p_\theta(\mathbf{x}_{0:T})}{q(\mathbf{x}_{1:T}|\mathbf{x}_0)} \right]\]

여기서 젠슨의 부등식을 적용할 수 있습니다. 젠슨의 부등식은 기댓값의 로그가 기댓값 자체의 로그보다 항상 작거나 같다는 원리를 의미합니다. 이를 통해 다음과 같은 하한을 구할 수 있습니다:

\[\log p_\theta(\mathbf{x}_0) \ge \mathbb{E}_{q}\left[ \log \frac{p_\theta(\mathbf{x}_{0:T})}{q(\mathbf{x}_{1:T}|\mathbf{x}_0)} \right]\]

즉, 우리는 로그 가능도의 하한을 구할 수 있으며, 이 하한을 최대화함으로써 모델을 학습하게 됩니다. 다음으로, 이 식을 조금 더 풀어 설명해 보겠습니다. 분자와 분모에 등장하는 확률 분포들을 마코프 체인이라는 성질을 이용해 단계적으로 나눌 수 있습니다. 마코프 체인이란, 현재 상태가 이전 상태에만 의존하는 성질을 의미합니다. 이를 적용하면 아래와 같이 표현할 수 있습니다:

\[\log p_\theta(\mathbf{x}_0) \ge \mathbb{E}_q \left[ \log \frac{p_\theta(\mathbf{x}_T)\prod_{t=1}^T p_\theta(\mathbf{x}_{t-1}|\mathbf{x}_t)}{\prod_{t=1}^T q(\mathbf{x}_t|\mathbf{x}_{t-1})} \right] = \mathbb{E}_q \left[ \log p_\theta(\mathbf{x}_T) + \sum_{t=1}^T \log \frac{p_\theta(\mathbf{x}_{t-1}|\mathbf{x}_t)}{q(\mathbf{x}_t | \mathbf{x}_{t-1})}\right]\]

이 수식은 모델이 단계별로 상태 $\mathbf{x}_0$에서 시작해 $\mathbf{x}_T$에 이르는 과정을 어떻게 예측하는지를 설명합니다. 각각의 상태 변화 과정에서, 모델의 예측 확률과 우리가 설정한 근사 확률 $q$ 간의 차이를 최대한 줄이는 방향으로 학습을 진행하게 됩니다.


3. 손실 항으로의 분해

우리가 학습하는 모델이 데이터를 얼마나 잘 설명하는지를 평가하기 위해 손실 함수라는 개념을 사용합니다. 이 손실 함수는 음의 로그 가능도에 기반하며, 값이 작을수록 모델이 데이터를 더 잘 설명하는 것을 의미합니다. 먼저, 음의 로그 가능도 식에서 손실 함수 $L$을 다음과 같이 정의할 수 있습니다:

\[L = \mathbb{E}_q \left[ -\log p_\theta(\mathbf{x}_T) - \sum_{t=1}^T \log \frac{p_\theta(\mathbf{x}_{t-1}|\mathbf{x}_t)}{q(\mathbf{x}_t | \mathbf{x}_{t-1})}\right]\]

이 수식은 모델이 각 단계에서 상태 $\mathbf{x}_t$에서 이전 상태 $\mathbf{x}_{t-1}$로 가는 과정을 설명합니다. 이제 이 손실 함수를 좀 더 자세히 분해하여 전개해 보겠습니다.


Negative Log-Likelihood (NLL)의 수식 전개

먼저 이 손실 함수는 다음과 같이 다시 쓸 수 있습니다:

\[L = \mathbb{E}_{q}\left[-\log p_\theta(\mathbf{x}_T) - \sum_{t\geq 1} \log\frac{p_\theta(\mathbf{x}_{t-1}|\mathbf{x}_t)}{q(\mathbf{x}_t|\mathbf{x}_{t-1})} \right]\]

여기서, $t = 1$인 경우를 따로 분리해서 다음과 같이 쓸 수 있습니다:

\[L = \mathbb{E}_{q}\left[-\log p_\theta(\mathbf{x}_T) - \sum_{t > 1} \log\frac{p_\theta(\mathbf{x}_{t-1}|\mathbf{x}_t)}{q(\mathbf{x}_{t}|\mathbf{x}_{t-1})} -\log \frac{p_\theta(\mathbf{x}_0|\mathbf{x}_1)}{q(\mathbf{x}_1|\mathbf{x}_0)}\right]\]


마코프 체인과 베이즈 정리 사용!

이제, 마코프 체인이라는 성질을 이용해 추가적인 조건부 확률을 넣어 전개할 수 있습니다. 여기서 마코프 체인은 각 상태가 직전 상태에만 의존하는 성질을 말합니다. 이를 이용하여 다음과 같이 전개합니다:

\[\frac{1}{q(\mathbf{x}_t | \mathbf{x}_{t-1})}=\frac{1}{q(\mathbf{x}_t | \mathbf{x}_{t-1}, \mathbf{x}_0)} = \frac{1}{\frac{q(\mathbf{x}_t,\mathbf{x}_{t-1}, \mathbf{x}_0)}{q(\mathbf{x}_{t-1}, \mathbf{x}_0)}} = \frac{q(\mathbf{x}_{t-1}, \mathbf{x}_0)}{q(\mathbf{x}_t,\mathbf{x}_{t-1}, \mathbf{x}_0)}\]

이를 통해 각 항을 마코프 체인의 성질로 나눌 수 있으며, 분자를 정리하면:

\[\frac{q(\mathbf{x}_{t-1}, \mathbf{x}_0)}{q(\mathbf{x}_t,\mathbf{x}_{t-1}, \mathbf{x}_0)}=\frac{q(\mathbf{x}_{t-1}, \mathbf{x}_0)}{q(\mathbf{x}_t,\mathbf{x}_{t-1}, \mathbf{x}_0)}\cdot \frac{q(\mathbf{x}_t, \mathbf{x}_0)}{q(\mathbf{x}_t, \mathbf{x}_0)} = \frac{1}{q(\mathbf{x}_{t-1}|\mathbf{x}_t, \mathbf{x}_0)} \cdot \frac{q(\mathbf{x}_{t-1}, \mathbf{x}_0)}{q(\mathbf{x}_t, \mathbf{x}_0)}\]

따라서 손실 함수는 다음과 같이 정리됩니다:

\[\mathbb{E}_{q}\left[-\log p_\theta(\mathbf{x}_{T}) - \sum_{t > 1} \log\frac{p_{\theta}(\mathbf{x}_{t-1}|\mathbf{x}_{t})}{q(\mathbf{x}_{t-1}|\mathbf{x}_t, \mathbf{x}_0)} \cdot \frac{q(\mathbf{x}_{t-1}|\mathbf{x}_0)}{q(\mathbf{x}_t | \mathbf{x}_0)} -\log \frac{p_\theta(\mathbf{x}_0|\mathbf{x}_1)}{q(\mathbf{x_1|}\mathbf{x}_0)}\right]\]


$\Sigma$ 의 전개와 로그의 합을 곱으로 변환

위 식에서 세 번째 항을 풀어서 작성하면 다음과 같이 됩니다:

\[- \sum_{t > 1} \log\frac{q(\mathbf{x}_{t-1}|\mathbf{x}_0)}{q(\mathbf{x}_t | \mathbf{x}_0)} = -\log\frac{q(\mathbf{x}_1|\mathbf{x}_0)}{q(\mathbf{x}_2|\mathbf{x}_0)} -\log\frac{q(\mathbf{x}_2|\mathbf{x}_0)}{q(\mathbf{x}_3|\mathbf{x}_0)} - \cdots -\log\frac{q(\mathbf{x}_{T-1}|\mathbf{x}_0)}{q(\mathbf{x}_T|\mathbf{x}_0)}\]

로그의 성질에 따라 로그의 합은 곱으로 변환할 수 있습니다:

\[- \sum_{t > 1} \log\frac{q(\mathbf{x}_{t-1}|\mathbf{x}_0)}{q(\mathbf{x}_t | \mathbf{x}_0)} = -\log \frac{q(\mathbf{x}_1|\mathbf{x}_0)}{q(\mathbf{x}_T|\mathbf{x}_0)}\]

이를 최종적으로 정리하면 손실 함수는 다음과 같이 쓸 수 있습니다:

\[\mathbb{E}_{q}\left[-\log \frac{p_\theta(\mathbf{x}_{T})}{q(\mathbf{x}_T|\mathbf{x}_0)} - \sum_{t > 1} \log\frac{p_{\theta}(\mathbf{x}_{t-1}|\mathbf{x}_{t})}{q(\mathbf{x}_{t-1}|\mathbf{x}_t, \mathbf{x}_0)} -\log {p_\theta(\mathbf{x}_0|\mathbf{x}_1)}\right]\]


KL divergence로 나타낸 손실함수

이 식을 KL 발산(Kullback-Leibler Divergence)이라는 개념으로 정리할 수 있습니다. KL 발산은 두 확률 분포가 얼마나 다른지를 나타내는 지표입니다. KL 발산을 사용하면 손실 함수는 다음과 같이 표현됩니다:

\[L = \mathbb{E}_{q}\left[ D_{\mathrm{KL}}\!\left(q(\mathbf{x}_T|\mathbf{x}_0)~{}\|~p_\theta(\mathbf{x}_T)\right) + \sum_{t>1} D_{\mathrm{KL}}\!\left(q(\mathbf{x}_{t-1}|\mathbf{x}_t,\mathbf{x}_0)~{}\|~p_\theta(\mathbf{x}_{t-1}|\mathbf{x}_t)\right) -\log p_\theta(\mathbf{x}_0|\mathbf{x}_1)\right]\]

여기서 KL 발산 항은 모델이 예측한 분포 $p_\theta$와 우리가 설정한 근사 분포 $q$ 간의 차이를 측정합니다. 이 값이 작을수록 두 분포가 비슷해지며, 이는 모델이 더 정확하게 데이터를 설명한다는 뜻입니다.


4. 결론

이와 같은 세부적인 분해 과정을 통해, 변분 하한을 이용하여 음의 로그 가능도를 최적화함으로써 효율적인 학습이 가능합니다. 특히, 마코프 성질과 베이즈 정리를 사용하여 KL 발산 항으로 분해하고, 가우시안 분포 간의 비교로 표현할 수 있습니다. 이를 통해 학습이 더 안정적이고 빠르게 수렴할 수 있게 됩니다!

이번 글에서는 DDPM 논문에 정의된 음의 로그 가능도에 대한 변분 하한 최적화 과정과 이 개념들이 학습 과정에서 어떻게 상호작용하는지 살펴보았습니다. 다음 글에서는 본격적으로 diffusion 모델에 대해 다뤄보겠습니다.

This post is licensed under CC BY 4.0 by the author.