
Nghiên cứu của Nous đề xuất Lighthouse Chú ý: Chú ý theo cấp bậc dựa trên lựa chọn chỉ dành cho đào tạo giúp tăng tốc độ đào tạo trước 1,4–1,7× ở bối cảnh dài
Việc đào tạo các mô hình ngôn ngữ lớn theo chuỗi dài có một vấn đề nổi tiếng: sự chú ý rất tốn kém. Sự chú ý của sản phẩm chấm (SDPA) được chia tỷ lệ ở lõi của mỗi máy biến áp có tỷ lệ bậc hai Θ(N²) trong cả điện toán và bộ nhớ với độ dài chuỗi N. FlashAttention đã giải quyết vấn đề này thông qua việc xếp lớp nhận biết IO để tránh hiện thực hóa toàn bộ ma trận chú ý N×N trong bộ nhớ băng thông cao, giảm đáng kể dung lượng bộ nhớ nhưng vẫn duy trì tỷ lệ điện toán Θ(N²) cơ bản. Các nhà nghiên cứu tại Nous Research đã giới thiệu một phương pháp mới có tên là Lighthouse Chú ý nhằm giải quyết vấn đề tắc nghẽn này.
Việc đào tạo các mô hình ngôn ngữ lớn theo chuỗi dài có một vấn đề nổi tiếng: sự chú ý rất tốn kém. Sự chú ý của sản phẩm chấm (SDPA) được chia tỷ lệ ở lõi của mỗi máy biến áp có tỷ lệ bậc hai Θ(N²) trong cả điện toán và bộ nhớ với độ dài chuỗi N. FlashAttention đã giải quyết vấn đề này thông qua việc xếp lớp nhận biết IO để tránh hiện thực hóa toàn bộ ma trận chú ý N×N trong bộ nhớ băng thông cao, giảm đáng kể dung lượng bộ nhớ nhưng vẫn duy trì tỷ lệ điện toán Θ(N²) cơ bản. Các nhà nghiên cứu tại Nous Research đã giới thiệu một phương pháp mới có tên là Lighthouse Chú ý nhằm giải quyết nút thắt cổ chai này một cách cụ thể tại thời điểm đào tạo trước, đạt được tốc độ tăng tốc đồng hồ treo tường từ 1,40× đến 1,69× từ đầu đến cuối so với đường cơ sở SDPA được cuDNN hỗ trợ, với tổn thất đào tạo cuối cùng tương đương hoặc thấp hơn.
Vấn đề cốt lõi của các phương pháp chú ý thưa thớt hiện có
Để hiểu lý do tại sao Lighthouse hoạt động như vậy, cần biết các phương pháp chú ý thưa thớt hiện có để làm gì. Hầu hết các công việc trước đây như NSA, HISA, DSA, MoBA đều đưa ra hai quyết định thiết kế giống nhau. Đầu tiên, chúng chỉ gộp phần khóa và giá trị trong khi để lại các truy vấn ở độ phân giải đầy đủ (nén không đối xứng). Thứ hai, logic lựa chọn của họ nằm bên trong hạt nhân chú ý tùy chỉnh, có nghĩa là các nhóm không thể sử dụng lại hạt nhân chú ý dày đặc đã được tối ưu hóa mà các lõi tensor GPU hiện đại được xây dựng xung quanh.
Ngoài ra còn có một mối quan tâm cụ thể đối với việc đào tạo mà các phương pháp thưa thớt chỉ suy luận không gặp phải. Một phương thức thưa thớt theo thời gian suy luận chỉ được đánh giá dựa trên xương sống dày đặc của nó và nó tốt nhất là tốt như xương sống đó. Phương pháp thưa thớt về thời gian huấn luyện phải đối mặt với một bài kiểm tra khó hơn: sau khi huấn luyện xong, liệu các trọng số thu được có còn tạo ra một mô hình tập trung hiệu quả khi suy luận không? Lighthouse coi câu hỏi đó là tiêu chí chính xác trọng tâm của nó.
Lighthouse có cách tiếp cận khác đối với cả hai quyết định thiết kế. Nó gộp các truy vấn, khóa và giá trị một cách đối xứng trên một kim tự tháp nhiều cấp và nó đặt lựa chọn hoàn toàn bên ngoài hạt nhân chú ý. Sau khi chọn, hệ thống tập hợp các mục đã chọn thành một chuỗi con liền kề, dày đặc và chạy FlashAttention gốc trên đó — cùng một hạt nhân được sử dụng bởi đường cơ sở dày đặc.
https://arxiv.org/pdf/2605.06554
Cách thức hoạt động của đường ống bốn giai đoạn
Lớp chú ý của Ngọn hải đăng bao quanh nhưng không sửa đổi, thu nhỏ sự chú ý của sản phẩm chấm. Đường ống có bốn giai đoạn.
Trong giai đoạn đầu tiên, việc gộp trung bình xây dựng một kim tự tháp cấp L từ Q, K và V. Với hệ số gộp p, cấp ℓ của kim tự tháp có N/p^ℓ mã thông báo, mỗi mã thông báo tóm tắt p^ℓ vị trí cơ sở. Điều quan trọng là việc gộp chung giống nhau áp dụng cho cả ba phép chiếu, tạo ra bộ ba (Q^(ℓ), K^(ℓ), V^(ℓ)) nhất quán ở mọi cấp độ. Tổng chi phí xây dựng kim tự tháp là Θ(N) thời gian và bộ nhớ.
Ở giai đoạn thứ hai, trình chấm điểm không có tham số sẽ chỉ định cho mỗi mục nhập kim tự tháp hai điểm vô hướng bằng cách sử dụng định mức ℓ₂ cho mỗi đầu người: một điểm làm điểm truy vấn (∥Q^(ℓ)_i∥₂) và một điểm làm điểm chính (∥K^(ℓ)_i∥₂). Các cấp độ thô hơn kế thừa điểm số từ các cấp độ tốt hơn thông qua tổng hợp tối đa, do đó, phạm vi thô sẽ cho thấy tầm quan trọng của mã thông báo mạnh nhất của nó. Sau đó, một hạt nhân top-K chunked-bitonic hợp nhất sẽ chọn k mục cùng nhau trên tất cả các cấp độ kim tự tháp. Một chi tiết thiết kế đáng chú ý: cấp độ kim tự tháp thô nhất luôn được giữ lại đầy đủ - nó rẻ và đảm bảo có ít nhất một người đóng góp ở mọi vị trí cơ sở; ngân sách lựa chọn còn lại được chi cho các cấp độ tốt hơn. Ngoài ra, thiết kế chunked-bitonic tạo ra top-K được phân tầng thay vì top-K toàn cầu nghiêm ngặt: luồng điểm được phân chia thành các phần có kích thước cố định, mỗi phần duy trì một bộ đệm top-m trong sổ đăng ký, vì vậy nếu k mục có điểm cao nhất trên toàn cầu được nhóm lại thành một phần, một số mục sẽ được thay thế bằng các mục có điểm thấp hơn từ các phần khác. Kết quả là phạm vi chú ý được cân bằng hơn trong toàn bộ chuỗi và tránh việc lựa chọn bị thu hẹp trong một phạm vi hẹp.
Bước top-K là rời rạc và không khả vi - không có công cụ ước tính xuyên suốt, không có softmax Gumbel. Các chỉ số lựa chọn không có độ dốc. Độ dốc chỉ chảy qua các mục nhập Q, K, V được thu thập vào WQ, WK, WV, do đó, các phép chiếu học cách tạo ra các giá trị hữu ích khi được chọn thay vì điểm tốt khi chọn.
Ở giai đoạn thứ ba, các mục đã chọn sẽ được tập hợp thành một chuỗi con liền kề có độ dài S = N/p^(L−1) + (L−1)·p·k và được chuyển đến FlashAttention tiêu chuẩn. Tại N = 1.000.000 với L = 4, p = 4, k = 4.096, S ≈ 65.000 - nhỏ hơn nhiều so với N. Một đặc tính quan trọng của quy trình thu thập là nó đảm bảo không có “lỗ hổng” hoặc khoảng trống nào trong chuỗi con được lắp ráp. Điều này đặc biệt quan trọng vì Lighthouse cũng nén các truy vấn: một khoảng trống trong chuỗi có nghĩa là những mã thông báo bị thiếu đó không có đường dẫn chuyển màu trong quá trình truyền ngược và có thể gây ra sự mất ổn định trong quá trình đào tạo. Các phương pháp bất đối xứng để lại các truy vấn ở độ phân giải đầy đủ không gặp phải vấn đề này, nhưng thiết kế đối xứng của Lighthouse yêu cầu chuỗi con được thu thập vẫn hoàn toàn dày đặc.
Trong giai đoạn thứ tư, mỗi mục đầu ra được phân tán trở lại vị trí cơ sở p^ℓ mà nó đại diện thông qua hạt nhân phân tán nguyên tử số nguyên xác định, với độ dịch chuyển p^ℓ - 1 để bảo toàn quan hệ nhân quả. Fan-in trên mỗi vị trí được giới hạn bởi L bất kể k.
https://arxiv.org/pdf/2605.06554
Tại sao việc gộp đối xứng lại thay đổi cách tính toán
Việc gộp các truy vấn cùng với các khóa và giá trị sẽ thay đổi đặc tính tính toán của lệnh gọi chú ý từ O(N Sd) thành O(S² d) tại thời điểm đào tạo. Bởi vì S ≪ N ở bối cảnh dài, đây là điều tạo ra lợi thế về độ trễ. Được đo điểm chuẩn trên một NVIDIA B200 duy nhất ở bối cảnh 512K (bfloat16, B=1, H=8, kích thước đầu 128, L=3, p=4, độ thưa thớt ≈ 1:64), Lighthouse nhanh hơn 21 lần khi chuyển tiếp và nhanh hơn 17,3 lần khi chuyển tiếp tiến + lùi kết hợp so với SDPA được cuDNN hỗ trợ.
Từ quan điểm tiệm cận, việc đặt L = logp(N/k) sẽ đưa ra kích thước chuỗi con được tập hợp là S = Θ(k log N), điều này làm cho chi phí cuộc gọi FlashAttention dày đặc Θ(k² log² N d) — đa logarit tính bằng N tại k cố định. Kết hợp với các giai đoạn chi phí tuyến tính (xây dựng kim tự tháp, tính điểm, sc


Nguồn tin: MarkTechPost — Tác giả: Asif Razzaq. Bản dịch tiếng Việt do AI thực hiện, có thể có sai sót.