برش هوشمند مدل‌های بزرگ با StruPrune – وقتی حافظه‌ی کارت گرافیک کم میاد!

Fall Back

احتمالاً اگه با مدل‌های زبانی خیلی بزرگ (LLMs) کار کرده باشین، می‌دونین که چقدر حافظه‌ی کارت گرافیک سریع پر میشه! این مدل‌ها گاهی میلیاردها پارامتر دارن و وقتی می‌خواییم کاری مثل pruning یعنی هرس کردن پارامترهای اضافی انجام بدیم تا سرعت و کارایی‌شون بهتر شه، معمولاً یه مشکل بزرگ داریم: مصرف حافظه! حالا بیاین ببینیم دنیاش چجوریه و چی کار کردن که راحت‌ترش کنن.

اصلاً pruning چیه؟ وقتی مدل، پارامترهای زیادی داره، همه‌شون به یه اندازه مهم نیستن. بتونیم اونایی که اهمیت کمتری دارن رو حذف کنیم (یا به قول خارجی‌ها prune کنیم)، هم مدل سبک‌تر میشه، هم مصرف انرژی و حافظه میاد پایین، هم سرعت بهبود پیدا می‌کنه.

تا الان چند مدل مختلف pruning داشتیم، دوتاش رو اینجا توضیح میدم:

  1. Global pruning: یعنی همه پارامترها رو با هم بررسی می‌کنیم و هر پارامتری که مهم نبود رو قطع می‌کنیم. کیفیتش خوبه ولی نیاز به حافظه خیلی زیاد داره (یعنی باید کل مدل رو یکجا تو حافظه بار کنیم). برای مدل‌هایی که میلیاردها پارامتر دارن، انصافاً شدنی نیست!

  2. Local pruning: این دفعه میایم فقط یه لایه از مدل رو همزمان تو حافظه میاریم و همون رو prune می‌کنیم. این باعث میشه مصرف حافظه کم بشه ولی یه مشکل داره: دیگه به روابط بین لایه‌ها توجهی نمیشه و کارایی مدل مخصوصاً وقتی خیلی پارامتر حذف می‌کنیم، ممکنه کمتر شه.

حالا می‌رسیم به structured pruning. این یعنی چی؟ به زبان ساده، به جای اینکه هر پارامتر رو جدا حذف کنیم، میایم ساختارمندی حذف می‌کنیم (مثلاً کل ردیف یا ستون وزن‌ها رو باهم حذف کنیم). این کار خیلی با سخت‌افزار سازگارتره (یعنی GPU-Kernel ها راحت‌تر اجراش می‌کنن و کتابخونه‌ها بهش بهینه شدن)، اما معمولاً نیاز به همون Global pruning داره و اگه بخوایم locally انجام بدیم، مدل خیلی افت می‌کنه.

اینجا جرقه‌ی ایده‌ی StruPrune زده میشه. تیم مقاله اومدن یه روش هوشمندانه پیشنهاد دادن تا مزیت هر دو دنیا رو داشته باشیم: هم حافظه‌ی کمی مصرف کنیم، هم مدل‌زدنی ساختارمند باشه و با سخت‌افزار هماهنگ بمونه. ایده‌شون اینه که کار رو به چند زیرمسئله کوچیک تقسیم می‌کنن (divide and conquer)، هرکدوم به اندازه هر ماژول مدل جداگانه، توی حافظه جا میشه. بعد این زیرمسئله‌ها با هم هماهنگ میشن تا هدف کلی هرس کردن وهمچنان ساختارمندی مدل رو حفظ کنن.

برای این کار، یه فریم‌ورک به اسم StruPrune طراحی کردن. (این اسم ترکیب Structured Pruning هست) و بر اساس ADMM ساخته شده — اینم یه روش ریاضی و بهینه‌سازی هست برای حل مسائلی که محدودیت دارن و دنبال جواب بهینه می‌گردن.

جالب اینجاست که واسه StruPrune تونستن یه فرمول تحلیلی دقیق برای ماسک‌های هرس ساختارمند بسازن. این یعنی دقیقاً می‌دونیم چطور پراکندگی هرس به صورت لایه به لایه باید باشه! حتی یه چارچوب مبتنی بر انرژی هم دادن و از اون حالت softmax استفاده کردن (softmax هم یعنی هر لایه به نسبت اهمیت، سهمش رو از مقدار هرس شدن می‌گیره) تا تخصیص بهینه‌تر بشه و دیگه لازم نباشه واسه هر لایه به طور جدا چونه بزنیم.

حالا نتیجه چی شد؟ توی آزمایش‌هاشون نشون دادن که StruPrune تقریباً همون کیفیت Global Structured pruning رو میده (مثلاً معیار Perplexity که یه جور سنجش مدل‌های زبانیه رو حفظ می‌کنه)، اما مصرف حافظه‌ش از مقدار O(N) اومده پایین به O(√N) — یعنی موقع استفاده، لازم نیست حافظه به اندازه کل مدل داشته باشین؛ حتی با مدل‌هایی که میلیاردی پارامتر دارن میشه این روش رو پیاده کرد!

خلاصه اگه دوست دارین مدل‌تون رو سبک و چابک کنین، ولی نه با هزینه‌ی کل حافظه‌ی دنیا، StruPrune می‌تونه دوست شما باشه! مخصوصاً برای کسایی که میخوان مدل‌های بزرگ رو سریع و کارآمد تربیت کنن یا به کار بگیرن، این تکنیک کلی راه رو باز می‌کنه.

اگه اسم‌ها و تکنیک‌های بالا براتون تازه بودن، بدونین دنیای یادگیری ماشین هر روز یه چیزی نو درمیاره، و StruPrune واقعاً از اون ابداع‌های باحال و کاربردیه که شاید خیلیا سر و صداش رو بزودی بیشتر بشنون!

منبع: +