
contriever

facebook transformers
facebook/contriever
Downloads: 7,815,693
Favorites: 291
Views: 18
License: -

Overview

This model was trained without supervision using the method described in "Towards Unsupervised Dense Information Retrieval with Contrastive Learning". The associated GitHub repository is available at https://github.com/facebookresearch/contriever.

Model configuration

Model type: bert
Architecture: Contriever

Model details

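This model was trained without supervision, following the method described in the paper "Towards Unsupervised Dense Information Retrieval with Contrastive Learning" (https://arxiv.org/abs/2112.09118). The associated GitHub repository is available at https://github.com/facebookresearch/contriever.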
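The paper trains Contriever with a contrastive InfoNCE objective: each query must score its positive key (built from the same document, e.g. by independent cropping) above negative keys drawn from other documents. The sketch below is a generic, single-query InfoNCE loss in plain Python, not the training code from facebookresearch/contriever; the helper names `dot` and `info_nce_loss` and the default temperature of 0.05 are illustrative assumptions.

```python
import math
from typing import List


def dot(u: List[float], v: List[float]) -> float:
    """Inner product of two equal-length vectors."""
    return sum(a * b for a, b in zip(u, v))


def info_nce_loss(query: List[float], positive: List[float],
                  negatives: List[List[float]],
                  temperature: float = 0.05) -> float:
    """Generic InfoNCE loss for one query (hypothetical helper).

    Returns -log( exp(s+/t) / (exp(s+/t) + sum_i exp(s_i/t)) ), where s+ is
    the query/positive inner product and s_i are the query/negative ones.
    The temperature default is an assumption for illustration.
    """
    # Logit 0 is the positive pair; the rest are the negatives
    logits = [dot(query, positive) / temperature]
    logits += [dot(query, neg) / temperature for neg in negatives]
    # Log-sum-exp with the max subtracted for numerical stability
    m = max(logits)
    log_denominator = m + math.log(sum(math.exp(x - m) for x in logits))
    # -log softmax probability of the positive key
    return log_denominator - logits[0]


if __name__ == "__main__":
    q = [1.0, 0.0]
    print(info_nce_loss(q, [0.9, 0.1], [[0.0, 1.0], [-1.0, 0.0]]))
```

When every score is equal (for example a zero query vector against three keys), the loss equals log 3, the value of a uniform guess; it approaches 0 as the positive's score dominates the negatives.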

Usage (HuggingFace Transformers)

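When the model is loaded directly with HuggingFace Transformers, it outputs one embedding per token, so a mean pooling step must be added to obtain a sentence embedding.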

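import torch
from transformers import AutoTokenizer, AutoModel

# Load the tokenizer and the Contriever encoder from the Hugging Face Hub
tokenizer = AutoTokenizer.from_pretrained('facebook/contriever')
model = AutoModel.from_pretrained('facebook/contriever')

sentences = [
    "Where was Marie Curie born?",
    "Maria Sklodowska, later known as Marie Curie, was born on November 7, 1867.",
    "Born in Paris on 15 May 1859, Pierre Curie was the son of Eugène Curie, a doctor of French Catholic origin from Alsace."
]

# Apply tokenizer: pad to the longest sentence, truncate to the model's max length
inputs = tokenizer(sentences, padding=True, truncation=True, return_tensors='pt')

# Compute token embeddings (inference only, so gradients are not tracked)
with torch.no_grad():
    outputs = model(**inputs)

# Mean pooling: average the token embeddings, ignoring padding positions
def mean_pooling(token_embeddings, mask):
    # Zero out the embeddings of padding tokens (mask == 0)
    token_embeddings = token_embeddings.masked_fill(~mask[..., None].bool(), 0.)
    # Divide the summed embeddings by the number of real (non-padding) tokens
    sentence_embeddings = token_embeddings.sum(dim=1) / mask.sum(dim=1)[..., None]
    return sentence_embeddings

# outputs[0] is the last hidden state, shape (batch, seq_len, hidden_size)
embeddings = mean_pooling(outputs[0], inputs['attention_mask'])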
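The snippet above stops at the sentence embeddings. For retrieval, the first sentence acts as the query and the other two as candidate passages, which are ranked by their similarity to the query. The sketch below assumes inner-product scoring (the similarity used by the paper's bi-encoder setup) and redoes masked mean pooling in plain Python on made-up toy vectors so each step can be checked by hand; `masked_mean` and `rank_by_dot_product` are hypothetical helpers, not part of Transformers.

```python
from typing import List, Tuple


def masked_mean(token_vectors: List[List[float]], mask: List[int]) -> List[float]:
    """Average only the token vectors whose mask entry is 1 (padding excluded)."""
    kept = [vec for vec, m in zip(token_vectors, mask) if m]
    dim = len(token_vectors[0])
    return [sum(vec[i] for vec in kept) / len(kept) for i in range(dim)]


def rank_by_dot_product(query: List[float],
                        passages: List[List[float]]) -> List[Tuple[int, float]]:
    """Return (passage_index, score) pairs sorted by descending inner product."""
    scores = [(i, sum(q * p for q, p in zip(query, passage)))
              for i, passage in enumerate(passages)]
    return sorted(scores, key=lambda pair: pair[1], reverse=True)


if __name__ == "__main__":
    # Toy query: the last token is padding (mask 0), so [9.0, 9.0] is ignored
    query = masked_mean([[1.0, 0.0], [0.0, 1.0], [9.0, 9.0]], [1, 1, 0])
    passages = [
        masked_mean([[1.0, 1.0], [1.0, 1.0]], [1, 1]),
        masked_mean([[-1.0, 0.0], [0.0, 0.0]], [1, 1]),
    ]
    print(query)
    print(rank_by_dot_product(query, passages))
```

With the tensors from the card, the corresponding PyTorch expression would be `embeddings[1:] @ embeddings[0]`, which yields one score per passage.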

Tags

bert arxiv:2112.09118 endpoints_compatible region:us

Details

Vendor: facebook
Framework: transformers
Model type: bert