用自然语言查询数据库——不用写 SQL,只说你要什么数据

"帮我查一下上个月销售额最高的 5 个产品"——只需要说这句话,AI 自动生成 SQL 并执行查询,返回结果。

核心代码

#!/usr/bin/env python3
# askdb.py
import sqlite3, os, re
from openai import OpenAI
from dotenv import load_dotenv
load_dotenv()

client = OpenAI(api_key=os.getenv("DEEPSEEK_API_KEY"), base_url="https://api.deepseek.com/v1")

def get_schema(db_path):
    """获取数据库表结构。"""
    conn = sqlite3.connect(db_path)
    tables = conn.execute("SELECT name FROM sqlite_master WHERE type='table'").fetchall()
    schema = {}
    for t in tables:
        cols = conn.execute(f"PRAGMA table_info({t[0]})").fetchall()
        schema[t[0]] = [f"{c[1]} {c[2]}" for c in cols]
    conn.close()
    return schema

def generate_sql(question, schema):
    """AI 根据自然语言生成 SQL。"""
    schema_text = "\n".join(f"{t}: {', '.join(c)}" for t, c in schema.items())

    resp = client.chat.completions.create(
        model="deepseek-chat",
        messages=[{"role": "system", "content": f"""你是一个 SQL 专家。根据表结构生成 SQL 查询。

表结构:{schema_text}

规则:
1. 只生成 SELECT 查询,不执行 INSERT/UPDATE/DELETE
2. 只返回 SQL,不要解释
3. 用 LIMIT 限制结果数量(默认 20)"""},
                  {"role": "user", "content": question}],
        temperature=0.1,
        max_tokens=500,
    )
    return resp.choices[0].message.content.strip()

def execute_query(db_path, sql):
    """安全执行 SQL 查询。"""
    # 安全检查:只允许 SELECT
    if not re.match(r"^\s*SELECT", sql, re.IGNORECASE):
        return None, "只允许 SELECT 查询"

    # 限制返回行数
    if "LIMIT" not in sql.upper():
        sql += " LIMIT 20"

    conn = sqlite3.connect(db_path)
    conn.row_factory = sqlite3.Row
    try:
        rows = conn.execute(sql).fetchall()
        return [dict(r) for r in rows], len(rows)
    except Exception as e:
        return None, str(e)
    finally:
        conn.close()

def ask(db_path, question):
    """自然语言查询数据库。"""
    schema = get_schema(db_path)
    print(f"📊 数据库:{db_path}{len(schema)} 个表)\n")

    sql = generate_sql(question, schema)
    print(f"🔍 SQL:{sql}\n")

    result, info = execute_query(db_path, sql)
    if result is None:
        print(f"❌ 查询失败:{info}")
        return

    print(f"✅ 返回 {info} 条记录:")
    if result:
        cols = result[0].keys()
        print(" | ".join(cols))
        print("-" * 50)
        for row in result[:10]:
            print(" | ".join(str(v)[:50] for v in row.values()))

if __name__ == "__main__":
    ask("data.db", "上个月哪些用户发帖最多?显示前 5 名")

使用方式

python askdb.py "今年哪些分类的文章阅读量最高"
# 📊 数据库:zyentor.db(15 个表)
# 🔍 SQL:SELECT category, SUM(view_count) FROM news WHERE is_published=1 GROUP BY category ORDER BY SUM(view_count) DESC LIMIT 5
# ✅ 返回 5 条记录:
# category | SUM(view_count)
# ai_dev | 15420
# news | 8234

安全机制

  • 只允许 SELECT 查询,拒绝 INSERT/UPDATE/DELETE
  • 自动加 LIMIT 限制返回行数
  • 用异常处理防止 SQL 错误

总结

一个简单实用的自然语言数据库查询工具,让不懂 SQL 的人也能查数据。

本文由 Zyentor(智元界)原创发布