electra-base-discriminator
google
transformers
en
google/electra-base-discriminator
54,079,034
下载量
299
收藏数
93
浏览量
apache-2.0
许可
简介
ELECTRA:以判别器而非生成器方式预训练文本编码器
模型卡片
许可协议
apache-2.0
语言
en
模型配置
模型类型
electra
架构
ElectraForPreTraining
模型详情
已翻译ELECTRA: 将文本编码器预训练为判别器而非生成器
ELECTRA 是一种用于自监督语言表示学习的新方法。它可以用相对较少的计算资源来预训练 transformer 网络。ELECTRA 模型被训练用于区分"真实"输入 token 与由另一个神经网络生成的"虚假"输入 token,类似于 GAN 中的判别器。在小规模场景下,即使仅在单个 GPU 上训练,ELECTRA 也能取得强劲的结果。在大规模场景下,ELECTRA 在 SQuAD 2.0 数据集上达到了最先进的结果。
有关详细描述和实验结果,请参阅我们的论文 ELECTRA: Pre-training Text Encoders as Discriminators Rather Than Generators。
本仓库包含预训练 ELECTRA 的代码,包括在单个 GPU 上训练小型 ELECTRA 模型。它还支持在下游任务上微调 ELECTRA,包括分类任务(例如 GLUE)、问答任务(例如 SQuAD)以及序列标注任务(例如 文本组块分析)。
如何在 transformers 中使用判别器
from transformers import ElectraForPreTraining, ElectraTokenizerFast
import torch
discriminator = ElectraForPreTraining.from_pretrained("google/electra-base-discriminator")
tokenizer = ElectraTokenizerFast.from_pretrained("google/electra-base-discriminator")
sentence = "The quick brown fox jumps over the lazy dog"
fake_sentence = "The quick brown fox fake over the lazy dog"
fake_tokens = tokenizer.tokenize(fake_sentence)
fake_inputs = tokenizer.encode(fake_sentence, return_tensors="pt")
discriminator_outputs = discriminator(fake_inputs)
predictions = torch.round((torch.sign(discriminator_outputs[0]) + 1) / 2)
[print("%7s" % token, end="") for token in fake_tokens]
[print("%7s" % int(prediction), end="") for prediction in predictions.tolist()]
正在翻译中,请稍候...
标签
tf
jax
rust
electra
pretraining
en
arxiv:1406.2661
license:apache-2.0
endpoints_compatible