Cơ chế attention (chú ý) của Transformer hầu như không thay đổi kể từ năm 2017. Hầu hết các nỗ lực cải thiện hiệu quả đều tập trung vào việc thay thế hoàn toàn cơ chế softmax attention. Một nghiên cứu mới đã đi theo một hướng khác. Nghiên cứu này giữ lại softmax attention và bổ sung một nhánh điều chỉnh.
Một nhóm các nhà nghiên cứu từ Đại học Northwestern, Tilde Research và Đại học Washington đã giới thiệu một cơ chế Local Linear Attention (Chú ý tuyến tính cục bộ) được tham số hóa có tên gọi 'Parallax', có khả năng mở rộng cho quá trình huấn luyện trước các mô hình ngôn ngữ lớn (LLM) và được đồng thiết kế với Muon.
Parallax không tìm kiếm hiệu quả bằng cách cắt giảm tính toán. Thay vào đó, nó chủ động tăng cường tính toán, sau đó làm cho quá trình tính toán đó trở nên rẻ hơn khi thực hiện.
Cơ chế chú ý (attention mechanism) của Transformer hầu như không thay đổi kể từ năm 2017. Hầu hết các nỗ lực cải thiện hiệu quả đều nhằm thay thế hoàn toàn cơ chế chú ý softmax. Một nghiên cứu mới đã đi theo một hướng khác. Nghiên cứu này giữ lại cơ chế chú ý softmax và bổ sung một nhánh hiệu chỉnh.
Một nhóm nghiên cứu từ Đại học Northwestern, Tilde Research và Đại học Washington đã giới thiệu một cơ chế Chú ý Tuyến tính Cục bộ (Local Linear Attention) được tham số hóa có tên gọi 'Parallax', có khả năng mở rộng cho quá trình tiền huấn luyện LLM và được đồng thiết kế với Muon.
Parallax không tìm kiếm hiệu quả bằng cách cắt giảm tính toán. Thay vào đó, nó cố tình tăng cường tính toán, sau đó làm cho việc tính toán đó trở nên rẻ hơn khi chạy trên các GPU hiện đại.
**Parallax là gì?**
Parallax được xây dựng dựa trên Chú ý Tuyến tính Cục bộ (LLA). LLA xuất phát từ khung hồi quy thời gian kiểm tra. Khung này coi chú ý như một bộ giải hồi quy trên các cặp khóa-giá trị (key-value pairs).
Theo quan điểm này, các khóa là các điểm dữ liệu huấn luyện. Các giá trị là các nhãn. Truy vấn là điểm kiểm tra. Chú ý softmax là một ước lượng phi tham số được gọi là Nadaraya-Watson. Nó phù hợp với một hàm hằng số cục bộ cho mỗi truy vấn.
LLA nâng cấp ước lượng hằng số cục bộ đó thành một ước lượng tuyến tính cục bộ. Nhóm nghiên cứu chứng minh rằng điều này mang lại sai số bình phương trung bình tích hợp nhỏ hơn đáng kể. Lợi ích là sự đánh đổi giữa độ chệch và phương sai tốt hơn cho bộ nhớ liên kết.
Tuy nhiên, LLA có một vấn đề về khả năng mở rộng. Việc tính toán chuyển tiếp chính xác của nó đòi hỏi phải giải một hệ phương trình tuyến tính cho mỗi truy vấn. Điều đó sử dụng một bộ giải gradient liên hợp song song (CG). Bộ giải CG tạo ra ba vấn đề: I/O chuyên sâu, sự đánh đổi khó khăn giữa điều hòa và khả năng biểu diễn, và không tương thích với độ chính xác thấp.
Parallax loại bỏ bộ giải. Thay vào đó, nó học một ma trận chiếu bổ sung. Nhóm nghiên cứu viết điều này là ρi = WRxi. Ở đây WR là một ma trận có thể học được, thăm dò hiệp phương sai KV trực tiếp từ đầu vào của lớp.
Vì vậy, Parallax giữ nguyên nguyên tắc tuyến tính cục bộ. Nó chỉ thay thế việc giải cho mỗi truy vấn bằng một bộ chiếu giống truy vấn đã học. Điều đó làm cho nó đơn giản hơn, hiệu quả hơn và dễ thực hiện hơn.
**Cơ chế hoạt động như thế nào?**
Parallax tái cấu trúc LLA thành chú ý softmax cộng với một hiệu chỉnh cộng thêm. Đầu ra bằng đầu ra chú ý softmax trừ đi một số hạng hiệp phương sai được chiếu. Trong ký hiệu của bài báo nghiên cứu, số hạng đó là hiệp phương sai KV nhân với đầu dò đã học ρi.
Nhóm nghiên cứu cũng loại bỏ một phần của LLA được gọi là hệ số khuếch đại biên, được đặt bằng 0. Điều này là cần thiết cho sự ổn định. Một khi đầu dò là tham số, cách giải thích hình học ban đầu sẽ bị phá vỡ. Việc giữ lại hệ số có thể khiến việc mở rộng phân kỳ hoặc đổi dấu.
Parallax nằm trong một họ các cơ chế chú ý. Nhóm nghiên cứu sắp xếp chúng trong bài báo theo ba trục: băng thông, cấu trúc đầu dò và cấu trúc affine. Ở một thái cực, Parallax suy biến chính xác thành chú ý softmax khi chuẩn đầu dò tiến về 0.
Đặt WR = 0 làm cho một lớp Parallax hoạt động giống hệt chú ý softmax. Vì vậy, một điểm kiểm tra Transformer đã được tiền huấn luyện có thể được chuyển đổi bằng cách thêm WR và tinh chỉnh.
**Lập luận về phần cứng**
Parallax kế thừa cấu trúc truyền tải của FlashAttention. Nó bổ sung một nhánh hiệp phương sai tái sử dụng cùng một luồng khóa-giá trị.
Nhóm nghiên cứu mở rộng quá trình chuyển tiếp thành hai nhánh tính điểm song song. Cả hai nhánh đều chia sẻ giá trị cực đại trực tuyến, hệ số điều chỉnh và các ô K và V. Vì vậy, Parallax không cần thêm I/O cho mỗi lần lặp.
Thuộc tính chính là cường độ số học (AI) cao hơn. AI là tỷ lệ giữa các phép toán dấu phẩy động và lưu lượng truy cập bộ nhớ băng thông cao. Trong chế độ mà công việc KV chiếm ưu thế, Parallax tăng gấp đôi cường độ số học. Nó bổ sung tính toán trong khi tái sử dụng cùng một luồng bộ nhớ.
Điều này chuyển sự chú ý sang một chế độ bị giới hạn bởi tính toán nhiều hơn. Đó chính xác là chế độ mà việc tối ưu hóa kernel có ích trên phần cứng hiện đại.
Nhóm nghiên cứu đã tạo mẫu một kernel giải mã trong CuTeDSL trên GPU NVIDIA Hopper. Các lệnh matmul lõi tensor của Hopper hoạt động trên các ô có ít nhất 64 hàng. Một bước giải mã chỉ cung cấp một hàng truy vấn. Do đó, các tích QK và RK có thể được tính toán cùng nhau, trong các lệnh mà cơ chế attention tiêu chuẩn đã phát ra.
Họ đã so sánh với FlashAttention 2 và 3 trên GPU H200 ở độ chính xác BF16. Họ đã thử nghiệm với kích thước lô từ 1 đến 2.048 và độ dài ngữ cảnh từ 128 đến 32.768. Kernel nguyên mẫu đạt hoặc vượt trội hơn FlashAttention trên tất cả các cấu hình. Hình dưới đây chú thích tốc độ tăng 1,54 lần trong cài đặt khớp tính toán và 1,14 lần trong cài đặt khớp I/O.
https://arxiv.org/pdf/2605.29157
Kết quả thử nghiệm
Nhóm nghiên cứu đã xác thực Parallax trên các tác vụ tổng hợp và trên quá trình huấn luyện trước LLM ở quy mô 0,6 tỷ và 1,7 tỷ. Các mô hình đã sử dụng kiến trúc Qwen-3 trong kho lưu trữ torchtitan. Họ đã huấn luyện trên tập dữ liệu Ultra-FineWeb với độ dài ngữ cảnh 4096. Các đường cơ sở bao gồm softmax attention (Transformer), Mamba, Gated DeltaNet, MesaNet và Kimi DeltaAttention.
Trên MAD-Benchmark, Parallax đạt độ chính xác tổng thể cao nhất ở mức trung bình 0,716. Nó liên tục cải thiện các tác vụ định hướng recall như In-Context-Recall và Selective-Copying. Nó vẫn cạnh tranh trên các tác vụ nén và ghi nhớ.
Trong mô hình ngôn ngữ, Parallax với Muon đạt độ phức tạp tốt nhất ở cả hai quy mô. Nó cũng đạt độ chính xác downstream trung bình cao nhất. Ở quy mô 1,7 tỷ, Parallax đạt trung bình 62,45 so với 61,43 của Transformer.
Hai kiểm soát thử nghiệm nguồn gốc của sự cải thiện. Một Transformer khớp tham số chỉ thu hẹp một phần nhỏ khoảng cách. Một Parallax khớp tính toán vẫn đánh bại cả hai đường cơ sở. Bài báo lập luận rằng điều này chỉ ra chính cơ chế, chứ không phải các tham số hoặc tính toán bổ sung.
Sự thay đổi của bộ tối ưu hóa
Một phát hiện cốt lõi là sự tương tác giữa bộ tối ưu hóa và kiến trúc. Parallax cho thấy một lợi thế lớn dưới Muon. Dưới AdamW, lợi thế giảm đáng kể hoặc thậm chí biến mất.
Muon là một bộ tối ưu hóa gần đây cho các tham số ma trận trong các lớp ẩn. Nó sử dụng yếu tố cực của bộ đệm động lượng, do đó các bản cập nhật có số điều kiện chính xác là một. Công trình trước đây cho thấy điều này tạo ra các ma trận trọng số được điều kiện tốt hơn.
Nhóm nghiên cứu trong bài báo đã theo dõi
Nguồn tin: MarkTechPost — Tác giả: Asif Razzaq. Bản dịch tiếng Việt do AI thực hiện, có thể có sai sót.