آموزش مدل‌های بزرگ وقتی بسته‌ها تو راه گم میشن! (یا: چطوری با اینترنت نصفه‌نیمه هم میشه هوش مصنوعی یاد داد؟)

Fall Back

خب بچه‌ها، امروز می‌خوام راجع به یه موضوع جالب حرف بزنم: آموزش مدل‌های هوش مصنوعی گنده، مخصوصاً اونایی که زبان یا تصویر رو پردازش می‌کنن، درست وقتی که اینترنت یا شبکه وسط کار «نصفه‌نیمه» عمل می‌کنه و بعضی بسته‌ها تو راه گم میشن!

اول بذارید یه مقدمه رفیقونه بگم. الان تو دنیا برای اینکه مدل‌های خیلی پیشرفته مثل زبان (مثلاً ChatGPT) یا بینایی (کارهایی مثل تشخیص تصویر)، حسابی آموزش ببینن و قوی بشن، باید رو هزار تا GPU اجرا بشن — اونم معمولاً تو دیتاسنترهای مختلف که هر کدوم جای دنیاست. حالا تو این داستان یه مشکل بزرگ هست: سیستم‌های امروزی فرض کردن که ارتباط بین این ماشین‌ها همیشه مطمئنه! مثلاً با فناوری‌هایی مثل InfiniBand یا RoCE کار می‌کنن، که اینا شبکه‌هایی خیلی سریع و قابل اعتمادن.

حالا مشکل چیه؟

دقت کن! وقتی تو شبکه ارتباط مطمئن باشه، هی باید مطمئن بشی هر پیامی رسیده یا نه، اگه نرسیده بازفرستی. این باعث میشه کلی پیغام اضافی، acknowledgment یعنی اعلان رسید، و دیتای تکراری رد و بدل شه. خلاصه، این داستان باعث میشه سرعت کار بیاد پایین و دیگه آموزش مدل‌ها اون‌قدری که می‌خوای مقیاس‌پذیر یا سریع نباشه.

اما اگه بگیم «بیخیال مطمئن بودن! ارتباط نامطمئن باشه»، خب، بسته‌ها ممکنه تو راه گم بشن؛ سرعت میره بالا، ولی یه جورایی ممکنه هم مدل خوب آموزش نبینه، هم دقتش خراب شه.

حالا اون چیزی که تو دیتاسنترهای معمولی مشکل‌زاست همین بسته گم شدنه — یعنی packet loss. اصلاً packet loss یعنی وقتی یه تیکه از دیتایی که تو شبکه می‌فرستی، به هر دلیلی (شلختگی اینترنت!) نمی‌رسه.

مسئله اینجاست که کسی تا الان یه راه‌حل دقیق، اصولی، و تضمینی نداشته که با شبکه نامطمئن هم بشه آموزش مدل رو دقیق و درست انجام داد. یعنی یا باید قید سرعت رو بزنی و همه چیز مطمئن باشه، یا قید دقت رو بزنی و سریع باشه ولی مدل هر چی می‌خواد درمیاد!

خبر خوب: یه تیم باحال یه چارچوب آموزش توزیع‌شده جدید ساخته که دقیقاً برای همین موقع‌هاست: یعنی وقتی بسته‌ها گم میشن و قرار هم نیست شبکه خیلی پیشرفته باشه. توی این روش دیگه لازم نیست مدل یا optimizer خودت رو تغییر بدی. (optimizer یعنی اون بخشی که تصمیم می‌گیره مدل چجوری «یاد بگیره»، مثلاً با SGD یادگار قدیمی‌ها!)

حالا چطوری اینا بسته‌های گم‌شده رو هندل می‌کنن؟ راهکار چند مرحله‌ای دارن:

۱. جمع‌آوری گرادیان بی‌طرفانه

منظور از
«گرادیان» تو یادگیری ماشین اینه که مدل چقدر باید پارامترهاش رو تغییر بده تا بهتر شه. حالا چون بسته‌ها ممکنه نرسن، هر worker (همون ماشین یا GPU که داره مدل رو یاد می‌گیره) یاد گرفته که با هرچی گرادیان به دستش می‌رسه، یه تخمین درست بزنه. یعنی حتی اگه نصف اطلاعات بیاد، نتیجه کار از لحاظ متوسط درست درمیاد.

۲. انتشار پارامتر با اختلاف کنترل‌شده

تو آموزش موازی ممکنه هر worker مدلش یه کم فرق کنه چون بعضی دیتاها وسط راه گم شدن. اینجا، کاری کردن که این اختلاف هیچ‌وقت زیادی بزرگ نشه (به قول ریاضی‌دان‌ها: O(1) بمونه)، یعنی هر چقدر هم تکرار کنیم، مدل‌ها هیچ‌وقت از کنترل خارج نمی‌شن. این جلوی اون «واگرایی» بدجور رو می‌گیره که تو روش‌های آسینکرونیک داریم. (واگرایی یعنی مدل‌ها مسیر خودشون رو برن و دیگه بهم ربطی نداشته باشن!)

هم از نظر تئوری هم آزمایش مثال زدن! مثلاً با مدل LLAMA2 (که ۷ میلیارد پارامتر داره — یعنی واقعاً گنده‌ست!) گرفتن و رو ۶۴ تا GPU اجرا کردن. حتی وقتی ۱۰٪ بسته‌ها تصادفی گم می‌شدن، فقط ۰/۸٪ تغییر تو perplexity مدل دیدن (perplexity یعنی چقدر مدل تو پیش‌بینی بعدی گیج میشه — که هر چی کمتر باشه بهتره).

نتیجه جالبی که این کار داشته اینه که مثل یه پل زده بین دیتاسنترهای معمولی و سازمانی (که پروتکل‌های ارتباطی اقتصادی و سریع دارن) و تیم‌هایی که میخوان دقت مدلشون بالا بمونه، مخصوصاً موقعی که می‌خوان مدل خیلی بزرگی رو آموزش بدن. دیگه لازم نیست فقط رو شبکه‌های گران و خاص حساب کنی — روی حتی شبکه‌های کم‌ کیفیت و عادی هم میشه آموزش قوی و درست داشت و پکت گم شده باشه، باز نتیجه عالی دربیاد!

خلاصه، راه‌حلشون باعث میشه بتونی با منابع معمولی هم مدل حرفه‌ای بسازی، بدون اینکه حساسیت زیادی به کیفیت شبکه داشته باشی. یعنی آینده آموزش مدل‌های هوش مصنوعی، آزادتر و دم‌دستی‌تر واسه همه میشه!

اگه کنجکاوی اسم این مقاله چی بود، “Distributed Training under Packet Loss” بود؛ یعنی «آموزش توزیع‌شده وقتی بسته‌ها گم میشن!» خلاصه، یه قدم بزرگ برای همه ما که دوست داریم مدل‌های گنده رو با هرچی سخت‌افزار دم دست هست، روشون کار کنیم.

منبع: +