Post

PyTorch에서 `LSTMCell` 이해하기 - 내부 동작과 구현

딥러닝에서 순환 신경망(RNN)은 시계열 데이터나 자연어 처리(NLP)에서 매우 중요한 역할을 합니다. 하지만 RNN은 긴 시퀀스 데이터에서의 장기 의존성을 학습하는 데 한계가 있습니다. 이를 해결하기 위해 LSTM(Long Short-Term Memory)이라는 구조가 도입되었고, PyTorch에서는 LSTMCell()을 통해 이를 쉽게 구현할 수 있습니다. 이번 글에서는 LSTMCell이 무엇인지, 어떻게 동작하는지, 그리고 이를 구현하는 방법에 대해 쉽게 설명하겠습니다.


1. LSTMCell이란?

LSTMCell은 LSTM(Long Short-Term Memory)의 기본 단위입니다. 일반적인 LSTM 레이어는 여러 타임스텝을 한꺼번에 처리하지만, LSTMCell은 개별 타임스텝에서의 계산을 수행합니다. 이를 통해 RNN의 동작을 더 세밀하게 제어할 수 있습니다.


2. LSTM의 기본 개념

LSTM은 일반적인 RNN이 시퀀스 데이터를 처리할 때 발생하는 기울기 소실 문제(gradient vanishing problem)를 해결하기 위해 설계된 구조입니다. 이를 위해 LSTM은 셀 상태(cell state)은닉 상태(hidden state)라는 두 가지 주요 상태를 유지하면서, 네 가지 주요 게이트를 통해 정보를 선택적으로 기억하거나 잊습니다.

alt text

출처: https://colah.github.io/posts/2015-08-Understanding-LSTMs/


또한, LSTM에서 은닉 상태셀 상태의 차원은 동일해야 합니다.

  • 은닉 상태: 네트워크가 출력하는 상태로, 다음 타임스텝의 계산에 사용됩니다.
  • 셀 상태: 더 장기적인 메모리를 유지하는 상태로, 현재 타임스텝의 셀 상태는 이전 타임스텝의 셀 상태와 입력, 그리고 은닉 상태에 기반해 업데이트됩니다.

이 두 상태가 동일한 차원을 가져야, LSTM이 내부적으로 계산을 수행할 때 일관성을 유지할 수 있습니다. PyTorch와 같은 프레임워크에서도 이 점을 강제하며, 두 상태의 차원이 다를 경우 오류가 발생합니다.


3. LSTM의 구성 요소

LSTM은 다음과 같은 네 가지 게이트와 상태로 구성됩니다:

  1. 입력 게이트 ($i_t$): 새 정보가 셀 상태에 얼마나 들어올지를 결정합니다.
  2. 망각 게이트 ($f_t$): 이전 셀 상태를 얼마나 유지할지를 결정합니다.
  3. 셀 상태 후보 ($\tilde{C_t}$): 셀 상태를 업데이트하기 위한 새로운 정보입니다.
  4. 출력 게이트 ($o_t$): 은닉 상태가 얼마나 업데이트될지를 결정합니다.


4. LSTM의 수학적 연산

LSTM은 다음과 같은 연산을 통해 동작합니다:

\[\begin{aligned} f_t &= \sigma(W_f \cdot [h_{t-1}, x_t] + b_f) \\ i_t &= \sigma(W_i \cdot [h_{t-1}, x_t] + b_i) \\ \tilde{C_t} &= \tanh(W_C \cdot [h_{t-1}, x_t] + b_C) \\ C_t &= f_t \cdot C_{t-1} + i_t \cdot \tilde{C_t} \\ o_t &= \sigma(W_o \cdot [h_{t-1}, x_t] + b_o) \\ h_t &= o_t \cdot \tanh(C_t) \end{aligned}\]

여기서:

  • $h_t$는 은닉 상태입니다.
  • $C_t$는 셀 상태입니다.
  • $x_t$는 현재 입력입니다.
  • $W_f, W_i, W_C, W_o$는 가중치 행렬이고, $b_f, b_i, b_C, b_o$는 편향입니다.
  • $\sigma$는 시그모이드 함수, $\tanh$는 하이퍼볼릭 탄젠트 함수입니다.


5. LSTMCell 구현하기

이제, LSTMCell을 수동으로 구현해보겠습니다. PyTorch 없이 순수 Python과 NumPy를 사용해 간단한 LSTM 셀을 만들어 보겠습니다.

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
import numpy as np

class LSTMCell:
    def __init__(self, input_size, hidden_size):
        self.input_size = input_size
        self.hidden_size = hidden_size

        # Weight matrices and biases for the gates
        self.W_f = np.random.randn(hidden_size, hidden_size + input_size)
        self.b_f = np.zeros((hidden_size, 1))

        self.W_i = np.random.randn(hidden_size, hidden_size + input_size)
        self.b_i = np.zeros((hidden_size, 1))

        self.W_C = np.random.randn(hidden_size, hidden_size + input_size)
        self.b_C = np.zeros((hidden_size, 1))

        self.W_o = np.random.randn(hidden_size, hidden_size + input_size)
        self.b_o = np.zeros((hidden_size, 1))

    def forward(self, x_t, h_prev, C_prev):
        # Concatenate h_prev and x_t
        combined = np.concatenate((h_prev, x_t), axis=1)

        # Forget gate
        f_t = self.sigmoid(np.dot(self.W_f, combined.T) + self.b_f)

        # Input gate
        i_t = self.sigmoid(np.dot(self.W_i, combined.T) + self.b_i)

        # Candidate memory cell
        C_tilde = np.tanh(np.dot(self.W_C, combined.T) + self.b_C)

        # Cell state
        C_t = f_t * C_prev + i_t * C_tilde

        # Output gate
        o_t = self.sigmoid(np.dot(self.W_o, combined.T) + self.b_o)

        # Hidden state
        h_t = o_t * np.tanh(C_t)

        return h_t, C_t

    def sigmoid(self, x):
        return 1 / (1 + np.exp(-x))

# Example usage
input_size = 5
hidden_size = 3
lstm_cell = LSTMCell(input_size, hidden_size)

x_t = np.random.randn(1, input_size)
h_prev = np.random.randn(1, hidden_size)
C_prev = np.random.randn(1, hidden_size)

h_t, C_t = lstm_cell.forward(x_t, h_prev, C_prev)

print("h_t:", h_t)
print("C_t:", C_t)


6. PyTorch에서 LSTMCell 사용 예시

LSTMCell을 직접 구현하는 대신, PyTorch의 내장 클래스를 사용할 수도 있습니다. PyTorch에서 nn.LSTMCell을 사용하는 예시는 다음과 같습니다:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
import torch
import torch.nn as nn

# LSTMCell 초기화
input_size = 10
hidden_size = 20
lstm_cell = nn.LSTMCell(input_size, hidden_size)

# 입력 데이터 (예: batch_size=3, input_size=10)
input = torch.randn(3, 10)

# 이전 타임스텝의 은닉 상태 및 셀 상태 초기화
hx = torch.randn(3, 20)  # 은닉 상태
cx = torch.randn(3, 20)  # 셀 상태

# LSTMCell 실행 (하나의 타임스텝)
hx, cx = lstm_cell(input, (hx, cx))

print(hx.shape, cx.shape)  # 출력: (3, 20) (3, 20)


여기서 nn.LSTMCellinput_sizehidden_size를 지정해 초기화됩니다. input_size는 입력 데이터의 크기이고, hidden_size는 은닉 상태 및 셀 상태의 크기입니다. 이 셀을 이용해 각 타임스텝에서의 연산을 직접 제어할 수 있습니다.


7. 결론

LSTMCell은 PyTorch에서 LSTM의 기본 단위를 구현하는 클래스입니다. 특히 은닉 상태와 셀 상태의 차원이 동일해야 한다는 점은 LSTM의 연산이 올바르게 이루어지기 위한 중요한 조건 중 하나입니다.

  • LSTMCell: 한 번에 하나의 타임스텝 만 처리하며, 은닉 상태와 셀 상태를 수동으로 업데이트할 수 있습니다. 더 세밀한 제어가 필요할 때 사용됩니다.
  • LSTM: 여러 타임스텝 을 자동으로 처리하며, 전체 시퀀스를 한꺼번에 처리할 때 사용됩니다.
This post is licensed under CC BY 4.0 by the author.