模型库 / Falconsai/nsfw_image_detection

nsfw_image_detection

Falconsai image-classification transformers
Falconsai/nsfw_image_detection
15,384,704
下载量
1266
收藏数
22
浏览量
apache-2.0
许可

简介

Model Card: Fine-Tuned Vision Transformer (ViT) for NSFW Image Classification

模型卡片

许可协议 apache-2.0
任务 image-classification

模型配置

模型类型 vit
架构 ViTForImageClassification

模型详情

已翻译

Model Card: 用于 NSFW 图像分类的微调 Vision Transformer (ViT)

模型描述

Fine-Tuned Vision Transformer (ViT) 是一种类似于 BERT 的 transformer 编码器架构变体,已适配用于图像分类任务。该特定模型名为 "google/vit-base-patch16-224-in21k",以监督方式在大量图像集合上进行了预训练,利用了 ImageNet-21k 数据集。预训练数据集中的图像被调整为 224x224 像素的分辨率,使其适用于广泛的图像识别任务。

在训练阶段,对超参数设置给予了细致关注,以确保模型性能达到最优。模型以经过审慎选择的 batch size 16 进行微调。这一选择不仅平衡了计算效率,还使模型能够有效处理和学习多样化的图像。

为促进这一微调过程,采用了 5e-5 的学习率。学习率是一个关键的调优参数,决定了训练过程中模型参数调整的幅度。在此案例中,选择 5e-5 的学习率是为了在快速收敛与稳定优化之间取得和谐平衡,从而使模型不仅学习迅速,还能在整个训练过程中稳步提升其能力。

该训练阶段使用了一个包含 80,000 张图像的专有数据集,每张图像都具有显著的可变性。数据集经过精心策划,包含两个不同的类别,即 "normal" 和 "nsfw"。这种多样性使模型能够掌握细微的视觉模式,使其具备准确区分安全内容与敏感内容的能力。

这一细致训练过程的总体目标是赋予模型对视觉线索的深刻理解,确保其在处理 NSFW 图像分类这一特定任务时的稳健性和能力。最终得到的模型已准备好为内容安全与审核做出重要贡献,同时保持最高标准的准确性和可靠性。

预期用途与局限性

预期用途

  • NSFW 图像分类:该模型的主要预期用途是对 NSFW(不适合工作场所)图像进行分类。它已为此目的进行了微调,适用于在各种应用中过滤敏感或不适当的内容。

使用方法

以下是使用该模型基于 2 个类别(normal, nsfw)对图像进行分类的方法:

# Use a pipeline as a high-level helper
from PIL import Image
from transformers import pipeline

img = Image.open("")
classifier = pipeline("image-classification", model="Falconsai/nsfw_image_detection")
classifier(img)
# Load model directly
import torch
from PIL import Image
from transformers import AutoModelForImageClassification, ViTImageProcessor

img = Image.open("")
model = AutoModelForImageClassification.from_pretrained("Falconsai/nsfw_image_detection")
processor = ViTImageProcessor.from_pretrained('Falconsai/nsfw_image_detection')
with torch.no_grad():
    inputs = processor(images=img, return_tensors="pt")
    outputs = model(**inputs)
    logits = outputs.logits

predicted_label = logits.argmax(-1).item()
model.config.id2label[predicted_label]

运行 YOLO 版本

import os
import matplotlib.pyplot as plt
from PIL import Image
import numpy as np
import onnxruntime as ort
import json # Added import for json

# Predict using YOLOv9 model
def predict_with_yolov9(image_path, model_path, labels_path, input_size):
    """
    Run inference using the converted YOLOv9 model on a single image.

    Args:
        image_path (str): Path to the input image file.
        model_path (str): Path to the ONNX model file.
        labels_path (str): Path to the JSON file containing class labels.
        input_size (tuple): The expected input size (height, width) for the model.

    Returns:
        str: The predicted class label.
        PIL.Image.Image: The original loaded image.
    """
    def load_json(file_path):
        with open(file_path, "r") as f:
            return json.load(f)

    # Load labels
    labels = load_json(labels_path)

    # Preprocess image
    original_image = Image.open(image_path).convert("RGB")
    image_resized = original_image.resize(input_size, Image.Resampling.BILINEAR)
    image_np = np.array(image_resized, dtype=np.float32) / 255.0
    image_np = np.transpose(image_np, (2, 0, 1))  # [C, H, W]
    input_tensor = np.expand_dims(image_np, axis=0).astype(np.float32)

    # Load YOLOv9 model
    session = ort.InferenceSession(model_path)
    input_name = session.get_inputs()[0].name
    output_name = session.get_outputs()[0].name # Assuming classification output

    # Run inference
    outputs = session.run([output_name], {input_name: input_tensor})
    predictions = outputs[0]

    # Postprocess predictions (assuming classification output)
    # Adapt this section if your model output is different (e.g., detection boxes)
    predicted_index = np.argmax(predictions)
    predicted_label = labels[str(predicted_index)] # Assumes labels are indexed by string numbers

    return predicted_label, original_image

# Display prediction for a single image
def display_single_prediction(image_path, model_path, labels_path, input_size):
    """
    Predicts the class for a single image and displays the image with its prediction.

    Args:
        image_path (str): Path to the input image file.
        model_path (str): Path to the ONNX model file.
        labels_path (str): Path to the JSON file containing class labels.
        input_size (tuple): The expected input size (height, width) for the model.
    """
    try:
        # Run prediction
        prediction, img = predict_with_yolov9(image_path, model_path, labels_path, input_size)

        # Display image and prediction
        fig, ax = plt.subplots(1, 1, figsize=(8, 8)) # Create a single plot
        ax.imshow(img)
        ax.set_title(f"Prediction: {prediction}", fontsize=14)
        ax.axis("off") # Hide axes ticks and labels

        plt.tight_layout()
        plt.show()

    except FileNotFoundError:
        print(f"Error: Image file not found at {image_path}")
    except Exception as e:
        print(f"An error occurred: {e}")

# --- Main Execution ---

# Paths and parameters - **MODIFY THESE**
single_image_path = "path/to/your/single_image.jpg"  # 

### 局限性
- **特定任务微调**:虽然该模型擅长 NSFW 图像分类,但在应用于其他任务时,其性能可能会有所变化。
- 有兴趣将该模型用于不同任务的用户,应探索模型中心中可用的微调版本,以获得最佳效果。

## 训练数据

该模型的训练数据包含一个约 80,000 张图像的专有数据集。该数据集具有显著的可变性,包含两个不同的类别:"normal" 和 "nsfw"。在此数据上的训练过程旨在使模型具备有效区分安全内容与敏感内容的能力。

### 训练统计信息
``` markdown

- 'eval_loss': 0.07463177293539047,
- 'eval_accuracy': 0.980375, 
- 'eval_runtime': 304.9846, 
- 'eval_samples_per_second': 52.462, 
- 'eval_steps_per_second': 3.279

注意: 在将模型应用于实际场景(尤其是涉及潜在敏感内容的场景)时,必须负责任且合乎道德地使用,遵守内容指南和适用法规。

有关模型微调和使用的更多详细信息,请参阅模型文档和模型中心。

参考文献

免责声明: 模型的性能可能受到其微调所用数据的质量和代表性的影响。建议用户评估模型对其特定应用和数据集的适用性。

标签

vit arxiv:2010.11929 license:apache-2.0 endpoints_compatible deploy:azure region:us

操作


详细信息

厂商
Falconsai
任务
image-classification
框架
transformers
模型类型
vit
许可(HF)
apache-2.0