InstructGPT模型结构

谈到ChatGPT肯定是绕不开instructGPT的,或者说ChatGPT的内核就是instructGPT。那么,想要了解ChatGPT,就是直接了解instructGPT,instructGPT分为如下三大步:

SFT:生成模型GPT的有监督精调 (supervised fine-tuning)
RM:奖励模型的训练(reward model training)
PPO:近端策略优化模型( reinforcement learning via proximal policy optimization)
下面根据这三大步分为三个Step进行讲解以及实操。

SFT(supervised fine-tuning)原理
其实这一步没啥好说的,主要的东西还是大量的Prompt数据,GPT模型通过有监督的Prompt数据进行精调,其实就是做next token prediction任务。然后用精调后的模型对每个输入的[文本+prompt]进行generate,生成4~9个输出,并且进行解码
操作。具体的模型流程如下图所示:

SFT代码实操

数据准备

我们去成都旅游,必须要去的地方是大熊猫繁殖基地。 大熊猫是 一种有黑白斑纹的动物。
我们去成都旅游,必须要去的地方是大熊猫繁殖基地。 大熊猫是 中国特有种,主要栖息地是中国四川、陕西和甘肃的山区。
我们去成都旅游,必须要去的地方是大熊猫繁殖基地。 大熊猫是 已在地球上生存了至少800万年,被誉为“活化石”和“中国国宝”即国兽,世界自然基金会的形象大使,是世界生物多样性保护的旗舰物种。

首先将以上数据形成一个输入列表,如下所示:

Raw Data Prompt Label
我们去成都旅游,必须要去的地方是大熊猫繁殖基地。 大熊猫是 一种有黑白斑纹的动物
我们去成都旅游,必须要去的地方是大熊猫繁殖基地。 大熊猫是 中国特有种,主要栖息地是中国四川、陕西和甘肃的山区。
我们去成都旅游,必须要去的地方是大熊猫繁殖基地。 大熊猫是 已在地球上生存了至少800万年,被誉为“活化石”和“中国国宝”即国兽,世界自然基金会的形象大使,是世界生物多样性保护的旗舰物种。
1
2
3
4
5
6
raw_data = "我们去成都旅游,必须要去的地方是大熊猫繁殖基地。"
prompt = "大熊猫是"
labels = ["一种有黑白斑纹的动物。","中国特有种,主要栖息地是中国四川、陕西和甘肃的山区。",
"已在地球上生存了至少800万年,被誉为“活化石”和“中国国宝”即国兽,世界自然基金会的形象大使,是世界生物多样性保护的旗舰物种。",
"属于熊科、大熊猫属的哺乳动物。仅有二个亚种。雄性个体稍大于雌性。体型肥硕似熊、丰腴富态,头圆尾短,头躯长1.2-1.8米,尾长10-12厘米。"]
combine_data = [raw_data+prompt+label for label in labels]2.

初始化模型

对输入数据进行编码,这里采用的是GPT2模型,如下所示

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
from torch.utils.data import Dataset
from transformers import Trainer, TrainingArguments
from transformers import AutoTokenizer, AutoModelForCausalLM
# 模型加载
tokenizer = BloomTokenizerFast.from_pretrained('pre_train_model/gpt2')
model = BloomForCaualLM.from_pretrained('pre_train_model/gpt2')
# 自定义DataSet类
class Datasets(Dataset):
def __init__(self, sample):
super(Datasets, self).__init__()
self.sample = sample

def __getitem__(self, item):
res = {k: v[item] for k, v in self.sample.items()}
return res

def __len__(self):
return len(self.sample['labels'])
# 数据转换
combine_data_token = tokenizer.batch_encode_plus(
initial_data_,
max_length=256,
padding='max_length',
truncation=True,
return_tensors='pt'
)
# 将标签标签加入
combine_data_token['labels'] = combine_data_token['input_ids']
combine_data_token['labels'] = torch.where(
combine_data_token['labels']==0,
-100,
combine_data_token['labels']
)
# 模型训练保存
trainer_args = TrainingArguments("./model/", learning_rate=2e-5, weight_decay=0.01, num_train_epochs=10, auto_find_batch_size=True)
trainer = Trainer(model=initial_model, args=trainer_args, train_dataset=Datasets(initial_token_info))
trainer.train()
trainer.save_model()

模型加载/生成

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
model = AutoModelForCausalLM.from_pretrained('./model')
# 处理输入数据
input_data = raw_input + prompt
input_datas = tokenizer.encode_plus(
input_data,
return_tensors='pt'
)
input_ids = input_datas['input_ids']
# 模型生成
result = model.generate(
input_ids=input_ids,
max_length=256,
do_sample=True, # 增加随机性
num_beams=5,
num_return_sequences=5, # 每个样本生成5个结果
no_repeat_ngram_size=3, # 防止重复的token
early_stopping=True # 提前停止
)

decode_tokens = tokenizer.batch_decode(
result,
skip_special_tokens=True
)

results = [i.replace(' ', '') for i in decode_tokens]

print("results",results)

结果如下所示:

我们去成都旅游,必须要去的地方是大熊猫繁殖基地。大熊猫是今世界上保存最完好的哺乳动物之一,也是世界自然保护联盟濒危物种红色名录的保护对象之一。在这里,你可以看到全世界最大的熊猫栖息地成都。成都是中国国家林业局直属的国家重点风景名胜区,是国家森林公园、国家湿地公园和国家地质公园的重要组成部分,是全国重点文物保护单位、全国生态文明建设示范区、中国红色旅游名城、国际生态旅游目的地和国际旅游岛建设先进区。地址:四川省成都市绵阳市成华区成都高新技术产业开发区成华大道1号乘车路线:成都绵阳都江堰雅
我们去成都旅游,必须要去的地方是大熊猫繁殖基地。大熊猫是我国唯一的国家二级保护动物,是世界上保存最完整的动物种群之一,也是我国第一个国家级自然保护区。我们是四川省的首批国家重点保护野生动物和珍稀动物基金会的成员,被誉为中国动物保护的摇篮和世界生物多样性保护基地,被中国科学院、中华人民共和国国家林业局授予全国生态文明建设示范区称号,被国务院批准为国家森林城市、国际生态旅游目的地。熊猫基地位于成都市双流区东南部,是国家aaaa级旅游景区,国家地理标志保护单位。熊猫栖息地为亚热带或热带的高山
我们去成都旅游,必须要去的地方是大熊猫繁殖基地。大熊猫是我国唯一的国家级自然保护区,也是世界上保存最完好的熊猫种群之一。它们栖息在亚热带或热带的高海拔草原上,生活
环境十分优越,是中国四大自然奇观之一,被誉为世界自然遗产和中国国家森林公园。熊猫栖息地主要分布在中国大陆的西藏、青海、甘肃、宁夏、新疆、内蒙古、山西、辽宁、吉林、黑龙江、江苏、河南、安徽、湖北、湖南、江西、广东、海南、四川、云南、贵州、陕西等地。中国熊猫研究中心主任、中国科学院院士、国家自然科学基金委员会委员、中华全国工商业联合会副主席
我们去成都旅游,必须要去的地方是大熊猫繁殖基地。大熊猫是我国唯一的国家级自然保护区,也是世界上保存最完整、规模最大的野生动物种类繁多的地区之一,是中国国家重点保护的珍稀濒危动物及其栖息地和世界自然遗产的重要组成部分,被誉为中国最美丽的城市和世界生物多样性保护基地,被国际旅游组织评为全球生态旅游目的地。成都熊猫国家公园位于四川省甘孜藏族自治州,是国家aaaa级旅游景区,被《世界遗产名录》列为全国重点文物保护单位。目前,我国已建成国家森林公园、国家湿地公园和国家地质公园,国家林业局、国务院扶贫
我们去成都旅游,必须要去的地方是大熊猫繁殖基地。大熊猫是现存最大、保存最完整的动物,属于国家二级保护动物。熊猫种类繁多,分布广泛,主要分布在四川、云南、陕西、甘肃、宁夏、内蒙古、新疆、青海、吉林、辽宁、黑龙江、山西、江苏、江西、河南、湖北、湖南、广东、广西、海南、重庆、贵州、西藏、四川等省区市。它们的栖息地主要为亚热带或热带的(低地)湿润低地林、亚高山草原、高山湖泊、高原湿润山区和高原沼泽地等,常栖息在高海拔地区。在中国大陆,熊猫分布于四川省甘孜藏族自治州和青海省西宁市等地。雄性熊猫体长约1.5米