contriever
facebook
transformers
facebook/contriever
Downloads: 7,815,693
Likes: 291
Views: 18
License: -
Introduction
This model was trained without supervision using the approach described in "Towards Unsupervised Dense Information Retrieval with Contrastive Learning". The associated GitHub repository is available at https://github.com/facebookresearch/contriever.
Model configuration
Model type: bert
Architecture: Contriever
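Model details
This model was trained without supervision following the approach described in the paper "Towards Unsupervised Dense Information Retrieval with Contrastive Learning" (https://arxiv.org/abs/2112.09118). The associated GitHub repository is available at https://github.com/facebookresearch/contriever.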
Usage (HuggingFace Transformers)
When using the model directly through HuggingFace transformers, add a mean pooling operation over the token embeddings to obtain sentence embeddings.
import torch
from transformers import AutoTokenizer, AutoModel
tokenizer = AutoTokenizer.from_pretrained('facebook/contriever')
model = AutoModel.from_pretrained('facebook/contriever')
sentences = [
"Where was Marie Curie born?",
"Maria Sklodowska, later known as Marie Curie, was born on November 7, 1867.",
"Born in Paris on 15 May 1859, Pierre Curie was the son of Eugène Curie, a doctor of French Catholic origin from Alsace."
]
# Apply tokenizer
inputs = tokenizer(sentences, padding=True, truncation=True, return_tensors='pt')
# Compute token embeddings
outputs = model(**inputs)
# Mean pooling
def mean_pooling(token_embeddings, mask):
    token_embeddings = token_embeddings.masked_fill(~mask[..., None].bool(), 0.)
    sentence_embeddings = token_embeddings.sum(dim=1) / mask.sum(dim=1)[..., None]
    return sentence_embeddings
embeddings = mean_pooling(outputs[0], inputs['attention_mask'])
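To make the pooling step concrete, the following is a minimal, dependency-free sketch of what masked mean pooling computes, followed by a dot-product comparison of the first embedding (the query) against the others. The toy numbers are invented for illustration and are not model outputs, the helper `dot` is hypothetical, and using the dot product as the relevance score is an assumption that is common for dense retrievers rather than something stated on this page.

```python
def mean_pooling(token_embeddings, mask):
    """Average the token vectors of each sentence, skipping positions where mask is 0."""
    pooled = []
    for vectors, flags in zip(token_embeddings, mask):
        kept = [vec for vec, flag in zip(vectors, flags) if flag]
        dim = len(vectors[0])
        pooled.append([sum(vec[i] for vec in kept) / len(kept) for i in range(dim)])
    return pooled


def dot(u, v):
    """Hypothetical helper: dot product of two equal-length vectors."""
    return sum(a * b for a, b in zip(u, v))


# Toy batch: 3 "sentences", 3 token slots each, 2-dimensional token vectors.
# The values are made up; real token embeddings come from model(**inputs).
token_embeddings = [
    [[1.0, 0.0], [3.0, 2.0], [9.0, 9.0]],  # last slot is padding
    [[2.0, 1.0], [2.0, 1.0], [2.0, 1.0]],  # no padding
    [[0.0, 4.0], [0.0, 2.0], [9.0, 9.0]],  # last slot is padding
]
attention_mask = [
    [1, 1, 0],
    [1, 1, 1],
    [1, 1, 0],
]

embeddings = mean_pooling(token_embeddings, attention_mask)

# Treat the first sentence as the query and the rest as candidate passages
# (assumption: a higher dot product means a more relevant passage).
query, passages = embeddings[0], embeddings[1:]
scores = [dot(query, passage) for passage in passages]
ranking = sorted(range(len(passages)), key=lambda i: scores[i], reverse=True)

print(embeddings)
print(scores, ranking)
```

The padded slots hold large values on purpose: because the mask excludes them from both the sum and the count, they have no effect on the pooled vectors. With the real model, the same comparison could be written on the tensors from the snippet above, for example `embeddings[0] @ embeddings[1]`.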
Tags
bert
arxiv:2112.09118
endpoints_compatible
region:us