用自然语言查询数据库——不用写 SQL,只说你要什么数据
"帮我查一下上个月销售额最高的 5 个产品"——只需要说这句话,AI 自动生成 SQL 并执行查询,返回结果。
核心代码
#!/usr/bin/env python3
# askdb.py
import sqlite3, os, re
from openai import OpenAI
from dotenv import load_dotenv
load_dotenv()
client = OpenAI(api_key=os.getenv("DEEPSEEK_API_KEY"), base_url="https://api.deepseek.com/v1")
def get_schema(db_path):
"""获取数据库表结构。"""
conn = sqlite3.connect(db_path)
tables = conn.execute("SELECT name FROM sqlite_master WHERE type='table'").fetchall()
schema = {}
for t in tables:
cols = conn.execute(f"PRAGMA table_info({t[0]})").fetchall()
schema[t[0]] = [f"{c[1]} {c[2]}" for c in cols]
conn.close()
return schema
def generate_sql(question, schema):
"""AI 根据自然语言生成 SQL。"""
schema_text = "\n".join(f"{t}: {', '.join(c)}" for t, c in schema.items())
resp = client.chat.completions.create(
model="deepseek-chat",
messages=[{"role": "system", "content": f"""你是一个 SQL 专家。根据表结构生成 SQL 查询。
表结构:{schema_text}
规则:
1. 只生成 SELECT 查询,不执行 INSERT/UPDATE/DELETE
2. 只返回 SQL,不要解释
3. 用 LIMIT 限制结果数量(默认 20)"""},
{"role": "user", "content": question}],
temperature=0.1,
max_tokens=500,
)
return resp.choices[0].message.content.strip()
def execute_query(db_path, sql):
"""安全执行 SQL 查询。"""
# 安全检查:只允许 SELECT
if not re.match(r"^\s*SELECT", sql, re.IGNORECASE):
return None, "只允许 SELECT 查询"
# 限制返回行数
if "LIMIT" not in sql.upper():
sql += " LIMIT 20"
conn = sqlite3.connect(db_path)
conn.row_factory = sqlite3.Row
try:
rows = conn.execute(sql).fetchall()
return [dict(r) for r in rows], len(rows)
except Exception as e:
return None, str(e)
finally:
conn.close()
def ask(db_path, question):
"""自然语言查询数据库。"""
schema = get_schema(db_path)
print(f"📊 数据库:{db_path}({len(schema)} 个表)\n")
sql = generate_sql(question, schema)
print(f"🔍 SQL:{sql}\n")
result, info = execute_query(db_path, sql)
if result is None:
print(f"❌ 查询失败:{info}")
return
print(f"✅ 返回 {info} 条记录:")
if result:
cols = result[0].keys()
print(" | ".join(cols))
print("-" * 50)
for row in result[:10]:
print(" | ".join(str(v)[:50] for v in row.values()))
if __name__ == "__main__":
ask("data.db", "上个月哪些用户发帖最多?显示前 5 名")
使用方式
python askdb.py "今年哪些分类的文章阅读量最高"
# 📊 数据库:zyentor.db(15 个表)
# 🔍 SQL:SELECT category, SUM(view_count) FROM news WHERE is_published=1 GROUP BY category ORDER BY SUM(view_count) DESC LIMIT 5
# ✅ 返回 5 条记录:
# category | SUM(view_count)
# ai_dev | 15420
# news | 8234
安全机制
- 只允许 SELECT 查询,拒绝 INSERT/UPDATE/DELETE
- 自动加 LIMIT 限制返回行数
- 用异常处理防止 SQL 错误
总结
一个简单实用的自然语言数据库查询工具,让不懂 SQL 的人也能查数据。
本文由 Zyentor(智元界)原创发布