Tech </> tech.log
· - ·
논문리뷰

[LLM Serving] FlashAttention

기다릴 바엔 다시 계산해버린다.

LLM에서 처리해야 하는 Context들은 시간이 지날 수록 길어지고 있다.

따라서 트랜스포머의 핵심인 Attention은 Long context를 처리하면서 시퀀스 길이가 길어질수록 메모리 점유율이 제곱으로 폭주하는 '성능의 병목'을 맞이할 수밖에 없었다.

단순히 수식을 최적화한 것을 넘어, 메모리 전송 속도의 한계를 GPU 연산기의 강력함으로 정복한 논문, FlashAttention에 대해 학습하고자 한다.

Motivation: 하드웨어의 불균형이 만든 '메모리 벽'

현대 반도체의 발전 양상을 확인한다면 흥미로운 점을 발견할 수 있다.

실질적인 데이터 처리를 담당하는 연산기(ALU)의 계산 능력은 ‘무어의 법칙’에 따라서 집적되는 트랜지스터의 수가 매 세대 비약적으로 상승했지만, 데이터를 공급하는 메모리 대역폭의 발전은 그것을 따라가지 못했다.

이로 인해 발생하는 현상이 바로 **'Memory-bound'**이다.

나는 일을 엄청 잘 해서 매우 빨리 일을 다 끝낼 수 있는데, 일을 시키는 명령이 카톡이 아니라 편지로 하나씩 배달 오고 있어서 모든 일을 끝내려면 편지를 기다리느라 한참 걸리는 것이죠. 이처럼 대부분의 어텐션 연산(Softmax, MatrixMultiplication 등)은 실제 계산 시간보다 데이터를 읽고 쓰는, 즉 연산기가 데이터를 기다리고 있는 시간이 훨씬 오래 걸린다.

그렇다면,

"메모리가 데이터를 보내주는 속도가 너무 답답하다면, 차라리 연산기가 조금 더 고생해서(재계산) 그 시간을 아끼면 어떨까?"

이 역발상이 FlashAttention의 시작이다.

즉, 연산량(FLOPs)을 추가 지불하여 느린 메모리 접근 횟수를 줄이는 전략이다.

Background: 하드웨어 계층 구조와 IO-Awareness

1. HBM vs SRAM (면적과 용량의 트레이드오프)

GPU 메모리는 성능과 용량에 따라 크게 두 계층으로 나뉜다.

• HBM (High Bandwidth Memory): 용량은 크지만(4080GB) 상대적으로 느림 (1.52.0 TB/s).

• SRAM (On-chip Memory): 용량은 매우 작지만(192KB/SM), 속도는 HBM보다 10배 이상 빠름 (약 19 TB/s).

이것은 마치 CPU에서 main memory DRAM과 속도는 매우 빠르지만 칩 안에서 차지하는 area가 너무 커서 비효율적인 Cache의 GPU ver라고 생각하면 된다.

따라서 제한된 크기의 SRAM을 극한으로 활용해야 더 빠른 연산이 가능한 것이다.

2. IO-Awareness: 하드웨어를 아는 소프트웨어

일반적으로 소프트웨어의 알고리즘에서는 하드웨어의 성능을 명확하게 파악하지 않는다. 하드웨어의 세부 명세에 의존하지 않으므로 어떤 하드웨어에 붙어도 실행될 수 있도록 호환성을 유지하고, 하드웨어에서 올바른 데이터를 넘겨준다는 것을 상정하고 있어야 소프트웨어을 구성할 수 있기 때문이다.

하지만 특수한 목적을 위해서 소프트웨어가 하드웨어의 성능과 용량을 파악하고 있다면 더욱 뛰어난 성능을 보일 수 있다. 그럴 때 IO-awareness구조를 통해 소프트웨어가 하드웨어의 input과 output을 인지한 채로 실행되도록 한다.

표준 어텐션은 $N \times N$ 크기의 거대한 어텐션 행렬을 매 단계마다 느린 HBM에 썼다가 다시 읽어오는 과정을 반복합니다. FlashAttention은 이 행렬을 HBM에 아예 생성하지 않고, 빠른 SRAM 내부에서 모든 계산을 완결하는 것을 목표로 한다.

연산량을 추가 지불하되, 최소한으로 지불해야 최고의 이득을 얻을 수 있기 때문에 FlashAttention은 알고리즘이 하드웨어의 특성을 인지하여 HBM이 아닌 SRAM에서 연산해야 한다는 IO-Aware 설계를 제안한다.

핵심 기술 1: Tiling & Online Softmax

거대한 $N \times N$ 행렬을 한꺼번에 HBM에 가져와서 처리하는 것이 비효율적이라면, 작게 쪼개서 일부만 SRAM으로 가져와서 처리하면 된다.

이것이 바로 **Tiling(타일링)**기법 이다.

이때 SRAM을 극한으로 활용하기 위한 방법으로 SRAM의 사이즈에 딱 맞도록 Tiling해서 데이터를 가져오기 위해 IO-Aware 설계를 활용하는 것이다.

하지만 여기서 Softmax의 수학적 제약이라는 난관에 봉착한다.

❓ What is Softmax?

$\text{softmax}(x)_i = \frac{e^{x_i}}{\sum e^{x_j}} = \frac{e^{x_i - m}}{\sum e^{x_j - m}}$

Softmax는 Attention Score에 대해서 여러 숫자 중 어떤 것이 얼마나 중요한지 ‘비율(확률)’을 알려주는 공식이다. 이 때 분모를 집중하면 지수합 ($l = \sum e^{x_j}$**)**은 모든 값의 지수를 다 더한 값이다. 이게 전체 비중을 결정한다.

다만 이렇게 연산하게 되면 지수 함수의 값이 너무 크기 때문에 그것을 보정하기 위해 그 행의 최댓값 $e^m$을 분모, 분자에 나누어 전체 softmax의 값은 동일하지만 연산하는 값의 크기는 매우 작도록하는 Safe Softmax를 활용한다.

🚫 Softmax의 딜레마

그 분모는 행 전체의 데이터($x_j$)를 다 알아야 분모인 지수합($l(x)$)을 구할 수 있다. 블록 단위로 쪼개서 데이터를 읽으면, 현재 블록 바깥에 있는 더 큰 값이나 지수합을 알 수 없어 정확한 정규화가 불가능해 보인다.

💡 해결책: Online Softmax (Incremental Update)

FlashAttention은 **'메모지(통계량)'**를 활용해 이 문제를 해결한다. 전체를 다 보기 전이라도, 지금까지 본 것들만 가지고 임시 계산을 해두고 새로운 데이터가 올 때마다 수정하는 방식이다.

다음의 두 가지 핵심 정보를 메모지에 적어둡니다.

  1. 최댓값 ($m$): 지금까지 본 점수 중 가장 높은 점수 (수학적으로 값이 너무 커져서 컴퓨터가 계산 불능에 빠지는 '오버플로'를 막기 위해 필요).
  2. 지수합 ($l$): 지금까지 본 점수들의 지수 합계 (분모 역할).

✅ 연산 방법

이것을 이용하여 연산하는 방법은 새로운 데이터가 들어올 때마다 과거의 계산을 **'보정(Rescaling)'**해주는 것이다.

  • Step 1 (순차 처리): 새로운 블록이 들어오면 그 블록만의 최고점($\tilde{m}$)과 합계($\tilde{l}$)를 구한다.

  • Step 2 (메모지 업데이트): 기존 최고점($m_{old}$)과 새 최고점($\tilde{m}$)을 비교해 진짜 최고점($m_{new}$)을 갱신한다.

  • Step 3 (과거 보정): 이게 핵심입니다! 예전에 계산해둔 결과값($O_{old}$)은 옛날 최고점과 옛날 합계 기준이었죠?
    이걸 **새로운 최고점과 새로운 합계 기준으로 강제로 '업데이트'**해서 합친다.

    $l_{new} = e^{m_{old} - m_{new}} \cdot l_{old} + e^{\tilde{m} - m_{new}} \cdot \tilde{l}$

위 수식에서 $e^{m_{old} - m_{new}}$ 같은 부분들이 바로 **"옛날 데이터를 현재 기준에 맞춰라"**라고 명령하는 보정 계수이다.

정리하자면

표준 어텐션은 모든 블럭의 데이터를 다 모을 때까지 기다렸다가 한꺼번에 비중을 구한다.

하지만 FlashAttention은 한 개씩 계산하면서도 메모지에 "지금까지 최고점은 몇 점이고, 합계는 얼마야"라고 적어두고, 새로운 데이터가 올 때마다 그 메모지를 바탕으로 기존 점수 비중을 살짝살짝 수정해나간다.

핵심 기술 2: Recomputation (연산과 메모리의 Trade-off)

LLM의 학습(Training) 과정에서는 역전파(Backward pass)를 수행하기 위해 순전파 때 계산했던 중간 값($S, P$)들이 필요하다. 표준 방식은 이 거대한 $N^2$ 행렬들을 HBM에 고이 저장해두지만, 이는 치명적인 메모리 점유를 야기한다.

따라서 FlashAttention은 여기서 **'재계산(Recomputation)'**이라는 선택을 제안한다.

  • 버리기: 순전파 때 계산한 $N \times N$ 중간 행렬을 HBM에 저장하지 않고 과감히 버린다.
  • 다시 계산하기: 역전파 시점이 되면, HBM에서 데이터를 읽어오는 대신 저장된 통계량($m, l$)과 출력값($O$)을 바탕으로 SRAM 안에서 즉석에서 다시 계산한다.

이 선택으로 인해 전체 연산량(FLOPs)은 소폭 증가하지만, 느린 HBM 도로를 달리는 시간을 획기적으로 줄였기 때문에 전체 실행 속도는 오히려 훨씬 빨라지게 되는 것이다.

결국 이 기법은 **'가장 흔한 자원(연산 능력)을 아낌없이 투자해 가장 귀한 자원(메모리 대역폭)을 사오는 전략적 선택'**이다. 만약 연산기와 메모리가 똑같은 속도로 발전했다면 재계산은 무의미한 낭비였겠지만, '무어의 법칙'으로 폭발한 연산 능력이 **'메모리 벽'**에 가로막힌 대역폭을 앞지른 현대 반도체의 불균형이 낳은 최적의 공학적 승부수라고 할 수 있다.

결론: 알고리즘이 하드웨어를 이해할 때 일어나는 일

FlashAttention은 단순히 수학적 최적화에 그치지 않고,

"하드웨어의 물리적 병목(IO)을 인지하고, 풍부한 자원(연산)을 활용해 부족한 자원(메모리 대역폭)을 보완”

한 공학적 승리이다.

이제 거의 모든 transformer 기반 에서 FlashAttention은 필수적인 기반 기술이 되었다.

이 논문을 깊이 이해하기 전에는 "연산을 다시 수행하는 것은 무조건 비효율적"이라는 고정관념에 갇혀 있었다. 단순히 FLOPs(연산량) 수치가 늘어나는 것을 경계하거나, 혹은 반대로 산술 집약도(Arithmetic Intensity)가 높아졌다는 수치적 지표만 확인하며 성능이 개선되었다는 결과론적인 해석에 머물렀을지도 모른다.

하지만 진짜 최적화는 숫자 너머의 **'자원 간의 트레이드오프'**를 이해하는 데 있었다. 성능을 측정하고 분석하는 입장에서 단순히 벤치마크 수치가 좋아졌는가에만 집중할 것이 아니라, "이 커널이 하드웨어 구조상 정말 최선인가?", "다른 자원을 희생해서라도 병목을 해결할 대안은 없는가?"를 끊임없이 고민해야 한다는 소중한 선례를 얻었다.

앞으로도 관성적인 개발 방식에서 벗어나 시스템의 본질을 꿰뚫는 커널 설계와 최적화의 가능성을 항상 열어두어야겠다.