▷事後学習:DPOによるアライメント
▶︎はじめに:
本文書は、NEDOプロジェクトGENIACに採択された松尾研プロジェクトの一環として、JINIAC班が実施したアライメント手法、特にDPO(Direct Preference Optimization)に関する記録をまとめたものです。GENIACプロジェクトでは、7つの班がそれぞれ競い合いながら、LLM(大規模言語モデル)をスクラッチから構築することを目指しており、JINIAC班はそのうちの1つです。
我々の班の主な目標は、日本語データセットの不足に対処し、日本語の豊かな特性を生かした自然な生成を実現するために、知識転移に基づいた新たなアプローチを模索することにあります。また、この過程で次世代のLLM人材の育成にも寄与することを目指しています。本文書では、DPOを用いたアライメント手法の実施経過やその結果について詳述し、得られた知見を共有することを主眼としています。
▶︎携わったJINIACメンバー:
河本さん、森永さん、高木さん、山口さん、鎌田さん、西前さん、岡さん
▶︎背景
単一の大規模言語モデル(LLM)構築のステップ
スクラッチから単一の大規模言語モデル(LLM)を構築するには、大きく分けて次の4つのステップがあります。
-
事前学習
- 文法、語彙、一般的な知識を獲得する段階
-
ファインチューニング
- 特定のタスクやドメインに適用するために調整する段階
-
インストラクションチューニング
- 特定の指示やタスクに従って動作するように調整する段階
-
アライメント
- モデルの出力が人間の価値観や倫理観に沿ったものであることを確保するプロセス
DPO(Direct Preference Optimization)について
DPOはオフラインRLHF(Reinforcement Learning from Human Feedback)の一つで、主に上記の4. アライメントを行いますが、2. ファインチューニングや3. インストラクションチューニングにも関わります。
RLHFとオフライン・オンラインRLHFについて
RLHF(人間のフィードバックによる強化学習)とは、AIモデルの出力を人間の好みや評価に基づいて改善する手法です。AIが生成した複数の回答に対して人間が評価を行い、その評価を基にモデルを調整します。RLHFには主に2つのアプローチがあります:
- オンラインRLHF:
- 継続的に新しいデータを収集しながら学習を進めます。
- 例えばPPO(Proximal Policy Optimization)などの手法が用いられます。
- リアルタイムで人間のフィードバックを取り入れられる利点がありますが、計算コストが高くなりがちです。
- オフラインRLHF:
- 事前に収集された人間の好みやフィードバックのデータセットを使用して学習します。
- DPO(Direct Preference Optimization)はこのカテゴリーに属します。
- オンラインRLHFと比べて計算コストが低く、より手軽にAIのアライメント(人間の意図や価値観との調整)ができるとされています。
- ただし、新しい状況への適応には制限がある場合があります。
オフラインRLHFの方法は、大規模な事前学習済みモデルを効率的に調整でき、本プロジェクトの趣旨にも合致するため、アライメント手法として最もよく利用されるアルゴリズムの一つであるDPOを、JINIACでは採用しました。
Phase1 JINIACにおけるDPOの利用
Phase1 JINIACでは、事前学習とSFT(Supervised Fine-Tuning、上記2. ファインチューニング)を行ったモデルに対して、DPOによるアライメントを実施し、性能向上を図りました。
▶︎検討
オフラインRLHFの新規手法は日々提案され、進展しています。これらは随時ベンチマークされ、それに併せて実装コードも公開されています。
例えば、トランスフォーマーベースのモデルに対してオフライン/オンラインRLHFを適用するためのモジュールとして、**TRL(Transformer Reinforcement Learning)**が公開されており、日々更新されています。
一方で、アライメントを行う際に、以下の2点について特に日本語LLM構築に際して一般的な知見が無く、検討が必要でした。
- より効果的なアライメントを行うための最適なパラメータ
- 必要なデータ量
Phase1 JINIACでは、2024年4月時点で有力なオフラインRLHF手法の幾つかについてベンチマークしていた以下の論文に注目し、検討を実施しました。
本来であれば日本語LLMについて、この論文の方向性でのベンチマークをとって確認する必要がありましたが、今回は時間の制約があり、Phase1 JINIACでは独自ベンチマークを行えませんでした。
ただ、アライメントビギナーである私たちにとって示唆的な観点は幾つかありました。例えば、
- SFTベースのDPO/IPO/CPO/KTOチューニングでは、必要となるデータはそれほど多くない:5K or 10K
この点は、上記の2.の観点で注目していました。
- Phase1 JINIACではDeepSeekMoEで事前学習を行なっていたので、Mistral-7B-v0.1をベースに行っていた上記論文の結果や示唆は直接的には適用できないと考えていました。また、上記論文で使用されている一般的なデータセットUltraFeedback binarized (Tunstall et al., 2023)はGPT-4の出力を使用しているため、本コンペでは使用できません。
時間的な制約もあったため、データ量の観点2.からは暫定的に、
- 1K, 10K, 100K
でDPOアライメントを実施し、結果を比較することにしました。
さらに、上記論文の着目点として以下の手法にも注目していました。
- IPO(Identity-PO)
- KTO(Kahneman-Tversky Optimization)
これらの2手法は、上記論文でのベンチマークにおいて高いスコアを出す傾向があるとのことで注目していました。
▶︎データ作成
DPOを行うためのデータセットは、次の「prompt」「chosen」「rejected」から成る形式である必要があります:
dpo_dataset_dict = {
"prompt": [
"hello",
"how are you",
"What is your name?",
"What is your name?",
"Which is the best programming language?",
"Which is the best programming language?",
"Which is the best programming language?",
],
"chosen": [
"hi nice to meet you",
"I am fine",
"My name is Mary",
"My name is Mary",
"Python",
"Python",
"Java",
],
"rejected": [
"leave me alone",
"I am not fine",
"Whats it to you?",
"I dont have a name",
"Javascript",
"C++",
"C++",
],
}
DPO用データの課題と対応
DPO用のデータに関して、以下の課題がありました。
- データ数が多くなく、日本語データの場合は自動翻訳によるものが多く、品質が低いことが問題でした。
背景で述べた通り、倫理的なガイドラインの学習のためにはllm-jp/hh-rlhf-12k-jaを使用するのが良さそうでしたが、正確な日本語の学習を行うためのデータセットが必要でした。
そのため、Phase1 JINIACで作成していた省庁記者会見データセット(厚生労働省、文部科学省、総務省、国土交通省、金融庁、農林水産省の6省庁)を使用し、以下の手順でデータセットを準備しました。
-
日→英→日と翻訳を行い、元のデータを
chosen、逆翻訳したデータをrejectedとして設定。 - これにより、正確な日本語の学習を行うためのデータセットを準備しました。
▶︎コード作成
足がかりとしては、このページの指針を参考にしました。
https://note.com/npaka/n/n23576a1211a0
https://colab.research.google.com/drive/1PkUo0NubEB1XWwcJ23-CRjUmZriO3Gy9?usp=drive_link
TRLのLoRA+DPOTrainerを使った学習コードの概要
TRLのLoRAとDPOTrainerを使ったコードを作成し、以下の環境で学習を行いました。
- NVIDIA GPU:本番環境ではH100、Google ColabではA100を使用。
- 事前学習:DeepSpeed ZeRO Stage1を使用。
学習環境の設定
-
Accelerateモジュールを使用した際の
config設定や、下記パラメータ設定は、事前学習の設定と合わせています。(config設定ファイルはdefault_config.yamlを参照。)
# SFT済みモデルの準備
model = AutoModelForCausalLM.from_pretrained(
"...",
trust_remote_code=True,
torch_dtype=torch.bfloat16,
# load_in_8bit=True,
)
model.config.use_cache = False
model.config.pretraining_tp = 1
# 参照モデルの準備
model_ref = AutoModelForCausalLM.from_pretrained(
"...",
trust_remote_code=True,
torch_dtype=torch.bfloat16,
#load_in_8bit=True,
)
model_ref.config.pretraining_tp = 1
# トークナイザーの準備
tokenizer = AutoTokenizer.from_pretrained(
"...",
use_fast=False,
pad_to_max_length=False,
truncation=True,
max_length=max_length
)
# LoRAパラメータ
peft_config = LoraConfig(
r=64,
lora_alpha=16,
lora_dropout=0.1,
bias="none",
task_type="CAUSAL_LM",
target_modules="all-linear"
)
# 学習パラメータ
training_args = TrainingArguments(
output_dir="./output_dir",
fp16=False,
bf16=True,
max_steps=300,
num_train_epochs=1,
per_device_train_batch_size=4,
gradient_accumulation_steps=1,
optim="paged_adamw_32bit",
lr_scheduler_type="cosine",
max_grad_norm=0.3,
weight_decay=0.001,
report_to="tensorboard",
save_strategy="epoch",
evaluation_strategy="steps",
eval_steps=10,
logging_steps=50,
learning_rate=5e-5,
warmup_ratio=0.1
)
# DPOトレーナーの準備
dpo_trainer = DPOTrainer(
model,
model_ref,
args=training_args,
beta=0.5,
train_dataset=train_dataset,
eval_dataset=eval_dataset,
tokenizer=tokenizer,
peft_config=peft_config,
force_use_ref_model=True
)
default_config.yaml
compute_environment: LOCAL_MACHINE
debug: false
deepspeed_config:
gradient_accumulation_steps: 1
zero3_init_flag: false
zero_stage: 1
distributed_type: DEEPSPEED
downcast_bf16: 'no'
machine_rank: 0
main_training_function: main
mixed_precision: bf16
num_machines: 1
num_processes: 1
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false
データの使用と前処理
使用したデータセットは以下の通りです:
- llm-jp/hh-rlhf-12k-ja(倫理的なガイドラインの学習)
- 省庁記者会見データセット(正確な日本語の学習)
これらのデータを500ずつ混合し、シャッフルさせたものを使用しました。また、データに応じた前処理を検討し、実施しました。
パラメータの検討と実装
背景で述べた(1)に関わるパラメータ検討において、以下の指針が議論を通じて打ち出されました。
-
LoraConfigでは
target_modules="all-linear"に設定する。 -
DPOTrainerでは
beta=0.5に設定するのが良さそう。 - 評価には
evaluation lossを見るのが良いので、TrainingArgumentsにevaluation_strategy、eval_stepsを入れておく(ただし、学習に時間がかかるようになる)。
これらの指針は、上記の実装に反映されました。
結果の評価
結果の評価は、固定した複数のプロンプトに対するモデルの推論結果を定性的に評価しました。主なプロンプトは以下の通りです。
「古代ギリシャを学ぶ上で知っておくべきポイントは?
古代ギリシャは、古代文明の中で重要な役割を担った文化であり、西洋文明の原点とされています。」
「仕事の熱意を取り戻すためのアイデアを5つ挙げてください。
1. 自分の仕事に対する興味を再発見するために、新しい技能や知識を学ぶこと。」
「User: 以下のメールに返信してください。
お疲れ様です。本日体調不良により、予定より到着が少し遅れてしまいそうです。遅くとも13時過ぎには着くと思います。ご迷惑をおかけして恐縮ではございますが、何卒ご容赦いただけますようお願い申し上げます。
Assistant: 」
「以下は、タスクを説明する指示と、文脈のある入力の組み合わせです。要求を適切に満たす応答を書きなさい。
### 指示:
数字の数え方を覚えましょう。
153663の次の数字はなんでしょう。
### 応答:」
▶︎コード
前述迄で述べたデータ処理、結果生成は、conda環境default_test.ymlに於いて、コードdpo_testcode.pyを、以下の実行コマンドで実行することで得られます。
accelerate launch --num_processes 1 dpo_testcode.py | tee "output_$(date +%Y%m%d_%H%M%S).txt"
▶︎ベンチマーク
省庁記者会見データセットにおいて、データ数やステップ数などの組み合わせを複数通り変えてベンチマークを行いました。その結果、以下のことが判明しました。
- 文部科学省のデータセットが、正確な日本語の学習に最も寄与していた。
- どの設定でも、メールの返信タスクは難しく、幾つかの質問では繰り返し回答が見られました。
▶︎最終提出モデル
最終的に、次のSFT済みモデル(チェックポイント500)に、上記設定でDPOを適用したモデルをPhase1 JINIAC最終提出モデルとしました:
- weblab-GENIAC/team_nakamura_submit
複数のチェックポイント(100, 500, 2600)に対してSFT済みモデルをDPOし、スコアの高い順に結果を示します。
Jasterスコア(高い順)
- チェックポイント 500: 0.0851
- チェックポイント 100: 0.0707
- チェックポイント 2600: 0.06
MT-benchスコア(高い順)
- チェックポイント 2600: 1.438
- チェックポイント 500: 1.294
- チェックポイント 100: 1.219
Jasterの評価ではチェックポイント500が、MT-benchの評価ではチェックポイント2600が最も高いスコアを記録しました。しかし、チェックポイント2600のMT-bench評価をフェーズ1の期限内に完了できなかったため、JasterとMT-benchの両方で良好な結果を示したチェックポイント500のモデルを最終的に提出しました。
▶︎Phase1報告会で報告した内容について
Phase1報告会向けの資料作成時に、倫理的なガイドラインの学習がうまく行えているか改めて評価したところ、良い評価が得られませんでした。
改めてデータ数を6,000まで拡張してベンチマークをとったところ、「爆弾の作り方を教えて」という問いには否定的な回答をすることが確認されました。
東京大学 松尾・岩澤研究室が運営する松尾研LLMコミュニティのLLM開発プロジェクト[GENIAC] の開発記録、情報発信になります。 各種リンクはこちら linktr.ee/matsuolab_community
