作者 | Team PyTorch
译者 | Monanfei
责编 | 夕颜
出品 | AI科技大本营(ID: rgznai100)
导读:6月11日,Facebook PyTorch 团队推出了全新 API PyTorch
Hub,提供模型的基本构建模块,用于提高机器学习研究的模型复现性。PyTorch Hub
包含一个经过预训练的模型库,内置对Colab的支持,而且能够与Papers With Code 集成。另外重要的一点是,它的整个工作流程大大简化。
简化到什么程度呢?Facebook 首席 AI 科学家Yann LeCun 兼图灵奖图灵奖得主Yann LeCun发表 Twitter强烈推荐,使用
PyTorch Hub,无论是ResNet、BERT、GPT、VGG、PGAN 还是 MobileNet 等经典模型,只需输入一行代码,就能实现一键调用。
Twitter 一发,立刻引来众多网友评论点赞,并有网友表示希望看到PyTorch Hub 与TensorFlow Hub的区别。
这个模型聚合中心到底如何呢?我们来一探究竟。
模型复现是许多领域的基本要求,尤其是在与机器学习相关的邻域中。然而,许多机器学习相关的出版物,要么不可复现,要么难以复现。随着出版物数量的不断增长(包括在
arXiv
上发表的成数万篇论文,以及会议提交的大量论文),模型复现比以往任何时候都更加重要。虽然这些出版物大多数都包含代码和训练好的模型,但如果用户想复现这些模型,还需要做大量的额外的工作。
今天,我们很荣幸地宣布推出 PyTorch
Hub,它是一个非常简单的API,并且具有极其简单的工作流程。它提供模型的基本构建模块,用于提高机器学习研究的模型复现性。PyTorch Hub
包含一个经过预训练的模型库,专门用于促进研究的可重复性和快速开展新的研究。PyTorch Hub 内置了对 Colab的 支持,并且能够与 Papers
With Code 集成。目前 PyTorch Hub 已包含一系列广泛的模型,包括分类器和分割器、生成器、变换器等。
【开发者】发布模型
通过添加简单 hubconf.py 文件,开发者能够将预训练的模型(模型定义和预训练的权重)发布到 GitHub
仓库中。该文件提供了所支持模型的枚举,以及运行这些模型的依赖环境列表。相关的例子可以参见 torchvision、 huggingface-bert 和
gan-model-zoo 仓库。
让我们看看最简单的例子:torchvision’s hubconf.py:
在 torchvision 中,各模型具有如下性质:
*
每个模型文件都能作为函数调用,或者独立执行
*
除了 PyTorch 之外(在 hubconf.py 中编码为 dependencies['torch']),它们不需要任何其他包的支持
*
不需要单独的接入点,因为模型在创建时可以无缝地接入
PyTorch Hub 将包的依赖性降到了最小,当使用者加载模型并立即进行实验时,该特性能够提高用户体验。
接下来我们看一个较为复杂的例子:HuggingFace’s BERT 模型,下面是该模型的 hubconf.py:
1dependencies = ['torch', 'tqdm', 'boto3', 'requests', 'regex']
2
3from hubconfs.bert_hubconf import (
4 bertTokenizer,
5 bertModel,
6 bertForNextSentencePrediction,
7 bertForPreTraining,
8 bertForMaskedLM,
9 bertForSequenceClassification,
10 bertForMultipleChoice,
11 bertForQuestionAnswering,
12 bertForTokenClassification
13)
每个模型都需要创建一个接入点,一下代码用于指定 bertForMaskedLM 模型的接入点,并返回预训练的模型权重。
1def bertForMaskedLM(*args, **kwargs):
2 """
3 BertForMaskedLM includes the BertModel Transformer followed by the
4 pre-trained masked language modeling head.
5 Example:
6 ...
7 """
8 model = BertForMaskedLM.from_pretrained(*args, **kwargs)
9 return model
这些接入点可以作为复杂模型的包装器,它们能够提供干净且一致的帮助文档字符串,支持使用者选择是否下载预训练权重(例如
pretrained=True),并且具有其它的特定功能,例如可视化。
创建好 hubconf.py 后,可以根据此模板发送 github 推送请求 。PyTorch Hub
的目标是为研究复现提供高质量、易于重复、高效的模型。因此,我们可能会与开发者合作完善推送请求,并在某些情况下拒绝发布一些低质量的模型。一旦我们接受了开发者的推送请求,开发者的模型将很快出现在
Pytorch 中心网页上,从而供所有的用户浏览。
【用户】工作流程
作为用户,PyTorch Hub
提供非常简单的工作流程,用户只需要按照以下三个步骤执行即可:(1)探索有价值的模型;(2)加载模型;(3)了解任何给定模型的可用方法。接下来,让我们分别看看每个步骤。
探索可用的接入点
用户可以使用 torch.hub.list() 列出仓库中所有可用的接入点。
1>>> torch.hub.list('pytorch/vision')
2>>>
3['alexnet',
4'deeplabv3_resnet101',
5'densenet121',
6...
7'vgg16',
8'vgg16_bn',
9'vgg19',
10 'vgg19_bn']
值得注意的是,PyTorch Hub 还允许辅助接入点(除了预训练模型)。例如,bertTokenizer 可以用于 BERT
模型中的预处理,这使得用户的工作流程更加顺畅。
加载模型
现在,我们已经知道了 Hub中可用的模型,那么用户便能够使用 torch.hub.load() 来加载模型接入点。该命令无需安装其他依赖包,此外,
torch.hub.help() 提供了如何实例化模型的信息。
1print(torch.hub.help('pytorch/vision', 'deeplabv3_resnet101'))
2model = torch.hub.load('pytorch/vision', 'deeplabv3_resnet101', pretrained=
True)
由于开发者会不断修复 bug,改进模型,因此 PyTorch Hub 也提供了便捷的方法,使得用户可以非常容易地获取最新的更新:
1model = torch.hub.load(..., force_reload=True)
我们相信,这些功能可以让开发者更加专注于他们的研究,而不用为这些繁琐的事情浪费时间。同时,这能够确保用户享受最新的模型。
从另一个方面来看,对用户而言,稳定性是非常重要的。因此,一些开发者会在其他分支上推送稳定的模型,而不是在 mater
分支上推送,这样能够保证代码的稳定性。例如,pytorch_GAN_zoo 在 hub 分支上提供稳定的版本。
1model = torch.hub.load('facebookresearch/pytorch_GAN_zoo:hub', 'DCGAN'
, pretrained=True, useGPU=False)
请注意,hub.load() 中的 *args 和 **kwargs 用于实例化模型。在上面的例子中,pretrained=True 以及
useGPU=False 会被传递给模型的接入点。
探索加载的模型
从PyTorch Hub加载模型后,用户可以使用下面的工作流程找出模型的可用方法,并更好地了解运行该模型所需的参数。
dir(model) 用于查看模型的所有可用方法。接下来,让我们看看 bertForMaskedLM 可用的方法。
1>>> dir(model)
2>>>
3['forward'
4...
5'to'
6'state_dict',
7]
help(model.forward) 用于展示模型运行所需的参数
1>>> help(model.forward)
2>>>
3Help on method forward in module pytorch_pretrained_bert.modeling:
4
forward(input_ids, token_type_ids=None, attention_mask=None, masked_lm_labels=None)
5...
在 BERT 和 DeepLabV3 页面中,用户可以详细了解这些模型的使用方法。
其他探索的方式
PyTorch Hub中提供的模型支持 Colab,并且直接链接在 Papers With Code上,只需单击即可使用。下面是一个很好的入门示例。
其他资源
*
PyTorch Hub API文档(https://pytorch.org/docs/stable/hub.html)
*
提交模型(https://github.com/pytorch/hub)
*
可用模型的更多信息(https://pytorch.org/hub)
*
探索更多模型(https://paperswithcode.com/)
原文链接:
https://pytorch.org/blog/towards-reproducible-research-with-pytorch-hub/
(*本文为 AI科技大本营编译文章,转载请微信联系 1092722531)
◆
精彩推荐
◆
6月29-30日,2019以太坊技术及应用大会 特邀以太坊创始人V神与以太坊基金会核心成员
,以及海内外知名专家齐聚北京,聚焦前沿技术,把握时代机遇,深耕行业应用,共话以太坊2.0新生态。
扫码或点击阅读原文,既享优惠购票!
推荐阅读
*
Bert时代的创新:Bert在NLP各领域的应用进展 | 技术头条
<https://blog.csdn.net/dQCFKyQDXYm3F8rB0/article/details/91388515>
*
免费GPU哪家强?谷歌Kaggle vs. Colab
<https://blog.csdn.net/dQCFKyQDXYm3F8rB0/article/details/91388651>
*
高能!8段代码演示Numpy数据运算的神操作
<https://mp.weixin.qq.com/s?__biz=MzU5MjEwMTE2OQ==&mid=2247486619&idx=1&sn=74e644a444efd786c18591178fffeb2b&scene=21#wechat_redirect>
*
Python编写循环的两个建议 | 鹅厂实战
<https://blog.csdn.net/dQCFKyQDXYm3F8rB0/article/details/91488255>
*
Lambda 表达式有何用处?
<https://blog.csdn.net/FL63Zv9Zou86950w/article/details/91467997>
*
9年前他用1万个比特币买了两个披萨, 9年后他把当年的代码卖给了苹果,成为了 GPU 挖矿之父
<https://mp.weixin.qq.com/s?__biz=MzA5MzY4NTQwMA==&mid=2651011547&idx=1&sn=c1edeb322b2f5ae4dee029f81f4f3e67&scene=21#wechat_redirect>
*
TIOBE 6月编程语言排行榜:Python 势不可挡,或在四年之内超越Java、C
<https://mp.weixin.qq.com/s?__biz=MjM5MjAwODM4MA==&mid=2650722168&idx=1&sn=3a286d91889491c7084755a92d49681d&scene=21#wechat_redirect>
*
漫威金刚狼男主弃影炒币了? <https://blog.csdn.net/Blockchain_lemon/article/details/91466947>
你点的每个“在看”,我都认真当成了喜欢
热门工具 换一换