够快!爆火的ChatGPT等价开源项目来了,网友:我担心跑不起来
机器之心报道
编辑:杜伟、陈萍
感兴趣的小伙伴不妨一试。
最近一段时间,由 OpenAI 开发的 AI 聊天机器人程序 ChatGPT 横扫各大 AI 社区,大家对它的热情只增不减,不断挖掘其潜力。
有些研究者坐不住了,开始琢磨怎样才能开发个等同于 ChatGPT 的开源软件。还没有行动的小伙伴这次参考示例来了,下面我们将要介绍的这个项目(PaLM + RLHF)就实现了这样的功能。
项目地址:https://github.com/lucidrains/PaLM-rlhf-pytorch
该项目是在 PaLM 架构之上实施 RLHF(人类反馈强化学习)。基本上等同于 ChatGPT,区别是使用了 PaLM。PaLM 是在谷歌的通用 AI 架构「Pathways」上训练而成的具有 5400 亿参数的大型语言模型。而 RLHF,是 ChatGPT 在 GPT 3.5 系列模型的基础上,引入「人工标注数据 + 强化学习」(RLHF)来不断微调预训练语言模型,旨在让大型语言模型(LLM)学会理解人类的命令,并学会根据给定的 prompt 给出最优的答案。
想要了解 RLHF 更多内容,可以参考:https://huggingface.co/blog/rlhf
正如网友所说的:「在 AI 领域中,每有一次专项突破,开发者们很快就会复现出一个开源版本。」
不过该项目目前只包含训练架构和代码,没有预先训练好的权重。在使用说明上,文档也显示必须先要训练 PaLM。
对此也有网友表示担心,表示:这不是一个开箱即用的项目,还只是一个架构,就像 shell 一样,需要昂贵的开销才能训练完成,没有机构能够像谷歌那样训练 PaLM。
还有网友表示:「没有预训练权重是非常糟糕的,官方至少需要释放 50% 的稀疏权重,剩下的让开发者自己训练,才是最好的选择。」
不过也有网友表示自己会去尝试:
下面我们来看看这个项目是如何运行的。
安装
$pip install palm-rlhf-pytorch
用法
首先训练 PaLM,就像任何其他自回归 transformer 一样。
import torch
from palm_rlhf_pytorch import PaLM
palm = PaLM(
num_tokens = 20000,
dim = 512,
depth = 12
).cuda()
seq = torch.randint(0, 20000, (1, 2048)).cuda()
loss = palm(seq, return_loss = True)loss.backward()
after much training, you can now generate sequences
generated = palm.generate(2048)(1, 2048)
接着使用精选的人类反馈来训练奖励模型。在原始论文中,在没有出现过拟合的情况下,无法从预训练 transformer 中获得微调的奖励模型。项目作者则提供了使用 LoRA 进行微调的选项。
import torch
from palm_rlhf_pytorch import PaLM, RewardModel
palm = PaLM(
num_tokens = 20000,
dim = 512,
depth = 12,
causal = False
)
reward_model = RewardModel(
palm,
num_binned_output = 5say rating from 1 to 5
).cuda()
mock data
seq = torch.randint(0, 20000, (1, 1024)).cuda()prompt_mask = torch.zeros(1, 1024).bool().cuda()which part of the sequence is prompt, which part is response
labels = torch.randint(0, 5, (1,)).cuda()
train
loss = reward_model(seq, prompt_mask = prompt_mask, labels = labels)loss.backward()
after much training
reward = reward_model(seq, prompt_mask = prompt_mask)
最后将 transformer 和奖励模型传递给 RLHFTrainer。
importtorch
frompalm_rlhf_pytorch import PaLM, RewardModel, RLHFTrainer
load your pretrained palm
palm=PaLM(
num_tokens=20000,
dim=512,
depth=12
).cuda()
palm.load(./path/to/pretrained/palm.pt)
load your pretrained reward model
reward_model=RewardModel(
palm,
num_binned_output=5
).cuda()
reward_model.load(./path/to/pretrained/reward_model.pt)
ready your list of prompts for reinforcement learning
prompts=torch.randint(0, 256, (50000, 512)).cuda() 50k prompts
pass it all to the trainer and train
trainer=RLHFTrainer(
palm=palm,
reward_model=reward_model,
prompt_token_ids=prompts
)
trainer.train(num_episodes=50000)
then, if it succeeded...
generate say 10 samples and use the reward model to return the best one
answer=trainer.generate(2048, prompt = prompts[0], num_samples = 10) (<= 2048,)
更多细节内容请参阅原项目。
参考链接:https://twitter.com/rasbt/status/1608133663937495041
© THE END
转载请联系本公众号获得授权
投稿或寻求报道:content@jiqizhixin.com
-
上一篇
ChatGPT Android 是通过 Stream Chat SDK for Compose 构建的 ChatGPT Android 项目。该存储库主要是以演示为目的:
使用 ChatGPT 的非官方 API。
使用 Jetpack Compose 实现整个 UI 元素。
使用 Hilt 和 AppStartup 等 Jetpack 库实现 Android 架构组件。
使用 Kotlin 协程执行后台任务。
将聊天系统与 Stream Chat SDK 集成以进行实时事件处理。
Github:https://github.com/skydoves/chatgpt-android
2、chatgpt-java
ChatGPT Java 版本,OpenAI ChatGPT 的逆向工程 SDK,可扩展用于聊天机器人等。
Github:https://github.com/PlexPt/chatgpt-java
3、chatgpt-vscode
开源推荐,ChatGPT开源项目!
-
下一篇
就连马斯克都公开表达了对
火爆全网的ChatGPT开源项目