from datasets import Dataset
data = { "prompt": ["解释相对论", "如何学习Python"], "chosen": ["爱因斯坦提出的时空理论...", "从基础语法开始,多写代码..."], "rejected": ["相对论是关于速度的", "看视频就够了"] }
dataset = Dataset.from_dict(data)
from torch.utils.data import DataLoader
def collate_fn(batch):
    return {
        "prompt": [x["prompt"] for x in batch],
        "chosen": [x["chosen"] for x in batch],
        "rejected": [x["rejected"] for x in batch],
    }
data_loader = DataLoader(dataset, batch_size=2, collate_fn=collate_fn)
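# Optional sanity check (illustrative, not part of the training pipeline): pull one
# batch to confirm collate_fn returns parallel lists of prompts, chosen and rejected.
sample_batch = next(iter(data_loader))
print(sample_batch["prompt"], len(sample_batch["chosen"]), len(sample_batch["rejected"]))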
from transformers import AutoModelForCausalLM, AutoTokenizer
model_name = "gpt2" tokenizer = AutoTokenizer.from_pretrained(model_name) tokenizer.pad_token = tokenizer.eos_token
policy_model = AutoModelForCausalLM.from_pretrained(model_name)
ref_model = AutoModelForCausalLM.from_pretrained(model_name)
for param in ref_model.parameters():
    param.requires_grad = False  # the reference model stays frozen throughout training
import torch
import torch.nn.functional as F

def get_log_probs(model, tokenizer, prompts, responses):
    # Return one summed token log-probability per prompt+response pair (shape: batch_size).
    texts = [p + r for p, r in zip(prompts, responses)]
    inputs = tokenizer(texts, return_tensors="pt", padding=True, truncation=True)
    logits = model(**inputs).logits[:, :-1, :]   # token t predicts token t+1
    labels = inputs["input_ids"][:, 1:]
    mask = inputs["attention_mask"][:, 1:]
    token_logps = torch.gather(F.log_softmax(logits, dim=-1), 2, labels.unsqueeze(-1)).squeeze(-1)
    return (token_logps * mask).sum(dim=-1)
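# Quick shape check (a sketch using the models defined above): the DPO loss below
# needs one log-probability per example, not a batch-averaged scalar.
print(get_log_probs(policy_model, tokenizer, ["Hi"], [" there"]).shape)  # torch.Size([1])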
import torch.nn.functional as F

def dpo_loss(policy_chosen_logps, policy_reject_logps,
             ref_chosen_logps, ref_reject_logps, beta=0.1):
    # DPO objective: -log sigmoid(beta * [(policy chosen-rejected gap) - (reference gap)])
    policy_logratios = policy_chosen_logps - policy_reject_logps
    ref_logratios = ref_chosen_logps - ref_reject_logps
    losses = -F.logsigmoid(beta * (policy_logratios - ref_logratios))
    return losses.mean()
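# Numerical sanity check (illustrative): with identical policy and reference log-ratios
# the loss is -log(0.5) ~ 0.693; it drops once the policy prefers the chosen response
# more strongly than the reference does.
_zeros = torch.zeros(2)
print(dpo_loss(_zeros, _zeros, _zeros, _zeros).item())                    # ~0.6931
print(dpo_loss(torch.tensor([5.0, 5.0]), _zeros, _zeros, _zeros).item())  # ~0.4741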
from torch.optim import AdamW
optimizer = AdamW(policy_model.parameters(), lr=5e-5)
beta = 0.1
from tqdm import tqdm

for epoch in range(3):
    for batch in tqdm(data_loader):
        policy_chosen_logps = get_log_probs(policy_model, tokenizer, batch["prompt"], batch["chosen"])
        policy_reject_logps = get_log_probs(policy_model, tokenizer, batch["prompt"], batch["rejected"])
        with torch.no_grad():  # the frozen reference model needs no gradient graph
            ref_chosen_logps = get_log_probs(ref_model, tokenizer, batch["prompt"], batch["chosen"])
            ref_reject_logps = get_log_probs(ref_model, tokenizer, batch["prompt"], batch["rejected"])

        losses = dpo_loss(policy_chosen_logps, policy_reject_logps, ref_chosen_logps, ref_reject_logps, beta)

        optimizer.zero_grad()
        losses.backward()
        optimizer.step()
        print(f"Epoch: {epoch}, Loss: {losses.item():.4f}")
def generate_response(model, tokenizer, prompt):
    inputs = tokenizer(prompt, return_tensors="pt")
    outputs = model.generate(**inputs, max_length=100)
    return tokenizer.decode(outputs[0], skip_special_tokens=True)
test_prompt = "怎么学习Python" print("优化后:", generate_response(policy_model, tokenizer, test_prompt)) print("参考模型输出:", generate_response(ref_model, tokenizer, test_prompt))
policy_model.save_pretrained("dpo_finetuned_model")
tokenizer.save_pretrained("dpo_finetuned_model")
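# To reuse the tuned model later, reload it from the saved directory with the standard
# from_pretrained API (variable names here are just illustrative):
reloaded_model = AutoModelForCausalLM.from_pretrained("dpo_finetuned_model")
reloaded_tokenizer = AutoTokenizer.from_pretrained("dpo_finetuned_model")
print(generate_response(reloaded_model, reloaded_tokenizer, test_prompt))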