在当今数据驱动的世界中,能够通过自然语言与数据库进行交互已成为一项重要能力。本文将深入解析一个基于Python和Streamlit构建的自然语言到SQL(NL2SQL)转换Web应用程序,该应用允许用户通过简单的自然语言查询来访问数据库信息。
一、 完整代码(此处略)
见末尾的附录部分
二、 环境准备
2.1 启动大模型服务
基于自研大模型训推平台,部署'Qwen3-0.6B',虽然模型比较小,但是根据自然语言翻译为SQL语句大概率可以正确写出。

2.2 启动应用
streamlit run web_nl2sql.py

自动打开浏览器,并输入地址:http://localhost:8501/

2.3 问数示例
点击推荐问题,或者手动输入问题,然后点击“发送”按钮。

如果正常,会返回SQL语句,以表格形式展现的数据,以图表形式展现的数据。

三、 应用概述
这个Web应用实现了完整的NL2SQL功能,具有以下核心特性:
- 可视化的聊天界面,模拟真实的对话体验
- 自然语言转SQL查询功能
- 结构化数据展示和可视化图表生成功能
- 示例数据库和快速查询引导
四、 技术架构分析
4.1 核心组件
应用主要依赖以下几个关键技术组件:
- Streamlit - 用于构建Web界面的Python框架
- SQLite - 轻量级数据库系统
- LangChain - 用于集成大语言模型的框架
- Pandas - 数据处理和分析库
- Matplotlib/Seaborn - 数据可视化库
4.2 应用初始化与状态管理
应用采用了Streamlit的session_state机制来管理应用状态:
if 'initialized' not in st.session_state:
st.session_state.initialized = False
st.session_state.db_initialized = False
st.session_state.query_results = {
'interactive': None,
'llm': None
}
st.session_state.chat_history = []
这种设计确保了在用户与应用交互过程中保持状态一致性,特别是在页面刷新或重新运行时。
4.3 数据库设计与初始化
应用内置了一个示例公司数据库,包含两个核心表:
- employees表 - 存储员工信息(ID、姓名、部门、薪资、入职日期)
- sales表 - 存储销售记录(ID、员工ID、销售金额、销售日期)
def init_database():
"""Initialize a sample SQLite database for demonstration"""
conn = sqlite3.connect('company.db')
cursor = conn.cursor()
cursor.execute('''
CREATE TABLE IF NOT EXISTS employees (
id INTEGER PRIMARY KEY,
name TEXT NOT NULL,
department TEXT NOT NULL,
salary REAL NOT NULL,
hire_date DATE NOT NULL
)
''')
cursor.execute('''
CREATE TABLE IF NOT EXISTS sales (
id INTEGER PRIMARY KEY,
employee_id INTEGER,
amount REAL NOT NULL,
sale_date DATE NOT NULL,
FOREIGN KEY (employee_id) REFERENCES employees (id)
)
''')
# 插入示例数据...
这种设计使用户无需配置外部数据库即可体验完整功能。
五、 NL2SQL核心实现
5.1 LLM集成与提示工程(这是让大模型通过自然语言生成SQL的核心中的核心)
应用通过LangChain集成了大语言模型,关键在于精心设计的提示模板:
template = """
你是一个专业的 SQLite 数据库助手。请严格遵守以下规则:
1. **绝对禁止**输出任何推理过程、解释、说明文字或额外内容。
2. 你的回答**必须且只能**是以下四行格式,每行以指定前缀开头:
Question: "{input}"
SQLQuery: "你的 SQLite 查询语句"
SQLResult: ""
Answer: ""
3. SQLQuery 必 须是合法的、可直接执行的 SQLite SELECT 语句。
4. 字段说明(极其重要!):
- 销售额、销售金额 → 必须使用 `sales.amount`
- 员工工资、薪资 → 必须使用 `employees.salary`
- 员工姓名 → `employees.name`
- 部门 → `employees.department`
5. 当问题涉及"每个"、"按...分组"时,**必须在 SELECT 子句中包含分组字段和聚合字段**,并在 GROUP BY 子句中指定分组字段。
6. 未指定数量时,不加 LIMIT。
这个提示模板通过严格的格式要求和明确的业务规则指导LLM生成准确的SQL查询。
5.2 SQL执行与安全控制
为了防止恶意操作,应用对SQL查询进行了安全检查:
def execute_sql(query: str):
clean_query = re.sub(r"^[\"'`]+|[\"'`]+$", "", query.strip())
# 安全检查
if not clean_query.lower().lstrip().startswith("select"):
st.error(f"⚠️ 非 SELECT 语句被拒绝。SQL: `{clean_query}`")
return pd.DataFrame()
conn = get_db_connection()
try:
df = pd.read_sql_query(clean_query, conn)
return df
except Exception as e:
st.error(f"❌ SQL 执行失败:\n{e}\n\n你提交的 SQL 是:\n```sql\n{clean_query}\n```")
return pd.DataFrame()
finally:
conn.close()
只允许执行SELECT查询,有效防止了DELETE、UPDATE等危险操作。
六、 用户界面设计
6.1 聊天式交互界面
应用采用现代化的聊天界面设计,用户可以在底部输入框中输入自然语言问题,系统在主区域展示完整的对话历史:
def render_chat_message(sender: str, content, message_type: str = "text"):
"""渲染单条聊天消息 - 清新配色"""
if sender == "user":
# 用户消息 - 右侧,浅蓝色
st.markdown(
f"""
<div style="display: flex; justify-content: flex-end; margin: 12px 0;">
<div style="
background-color: #e3f2fd;
color: #1976d2;
padding: 14px 18px;
border-radius: 18px 18px 4px 18px;
max-width: 70%;
text-align: left;
font-size: 15px;
box-shadow: 0 1px 3px rgba(0,0,0,0.1);
border: 1px solid #bbdefb;
">
{content}
</div>
</div>
""",
unsafe_allow_html=True
)
这种设计提供了直观友好的用户体验,让用户感觉像是在与智能助手对话。
6.2 数据可视化
应用不仅能以表格形式展示查询结果,还能自动生成相应的图表:
def render_chart_for_dataframe(df):
"""根据 DataFrame 的内容生成并渲染一个美观的图表"""
if df.empty or len(df) == 0:
return None
# 情况1: 单值
if df.shape == (1, 1):
value = df.iloc[0, 0]
if pd.api.types.is_numeric_dtype(df.iloc[:, 0]):
html = f"""
<div style="text-align: center; margin: 16px 0;">
<div style="font-size: 32px; font-weight: bold; color: #1976d2;">{value:,.2f}</div>
<div style="font-size: 14px; color: #666;">结果</div>
</div>
"""
return html
return None
# 更多图表类型处理...
系统能自动识别数据特征并选择合适的图表类型,包括柱状图、折线图、热力图等。
七、 快速入门与用户体验优化
应用提供了预设的示例问题,帮助用户快速上手:
st.markdown("### 💡 快速开始")
example_questions = [
"每个部门有多少员工?",
"每个员工的总销售额是多少?",
"列出所有员工及其薪资,按薪资降序排列",
"按部门的平均工资是多少?",
"销售额前三的员工是谁?"
]
quick_cols = st.columns(len(example_questions))
for i, question in enumerate(example_questions):
if quick_cols[i].button(f"「{question}」", key=f"quick_{i}", use_container_width=True):
st.session_state.user_query_input = question
st.session_state.trigger_submit = True
st.rerun()
这种设计大大降低了用户的使用门槛,即使是初次接触的用户也能快速理解系统的功能。
八、 安全性和错误处理
应用实现了多层次的安全防护和错误处理机制:
- SQL注入防护 - 通过限制只能执行SELECT语句防止恶意操作
- 异常捕获 - 对SQL执行错误进行捕获并友好提示
- 输入验证 - 对用户输入进行清理和验证
try:
df = pd.read_sql_query(clean_query, conn)
return df
except Exception as e:
st.error(f"❌ SQL 执行失败:\n{e}\n\n你提交的 SQL 是:\n```sql\n{clean_query}\n```")
return pd.DataFrame()
九、 部署和扩展性
该应用基于Streamlit构建,具有良好的部署便利性。可以通过以下方式运行:
streamlit run web_nl2sql.py
同时,应用的设计具有良好的扩展性:
- 可以轻松替换底层数据库
- 可以集成不同的大语言模型
- 可以增加更多的可视化类型
- 可以扩展支持更复杂的查询场景
十、 总结
这个NL2SQL Web应用展示了如何结合现代Web技术和人工智能技术构建智能数据查询系统。通过精心设计的用户界面、严谨的安全控制和强大的数据可视化功能,它为用户提供了一种全新的数据库交互方式。
该应用不仅是一个实用的工具,也是一个很好的学习范例,展示了如何将大语言模型集成到实际应用中,解决真实世界的业务问题。随着自然语言处理技术的不断发展,这类NL2SQL应用将在数据分析和商业智能领域发挥越来越重要的作用。
十一、 附录:完整代码
web_nl2sql.py
"""
Web-based NL2SQL Interface using Streamlit (TRUE CHAT LAYOUT)
- 底部固定输入区
- 聊天区域显示完整对话(含数据表格)
- 清新配色,无错误
"""
import streamlit as st
import sqlite3
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import os
import warnings
import re
warnings.filterwarnings("ignore")
from dotenv import load_dotenv
# Try to import required packages
try:
from langchain_openai import ChatOpenAI
from langchain_community.utilities import SQLDatabase
from langchain_core.prompts import PromptTemplate
LANGCHAIN_AVAILABLE = True
except ImportError:
LANGCHAIN_AVAILABLE = False
st.warning("LangChain 不可用。LLM 功能将被禁用。")
# Load environment variables
load_dotenv()
# Initialize session state
if 'initialized' not in st.session_state:
st.session_state.initialized = False
st.session_state.db_initialized = False
st.session_state.query_results = {
'interactive': None,
'llm': None
}
st.session_state.chat_history = [] # 存储完整对话历史
def init_database():
"""Initialize a sample SQLite database for demonstration"""
conn = sqlite3.connect('company.db')
cursor = conn.cursor()
cursor.execute('''
CREATE TABLE IF NOT EXISTS employees (
id INTEGER PRIMARY KEY,
name TEXT NOT NULL,
department TEXT NOT NULL,
salary REAL NOT NULL,
hire_date DATE NOT NULL
)
''')
cursor.execute('''
CREATE TABLE IF NOT EXISTS sales (
id INTEGER PRIMARY KEY,
employee_id INTEGER,
amount REAL NOT NULL,
sale_date DATE NOT NULL,
FOREIGN KEY (employee_id) REFERENCES employees (id)
)
''')
employees_data = [
(1, 'Alice Johnson', 'Engineering', 95000, '2022-01-15'),
(2, 'Bob Smith', 'Marketing', 75000, '2021-03-22'),
(3, 'Carol Davis', 'Engineering', 105000, '2020-07-10'),
(4, 'David Wilson', 'Sales', 65000, '2023-02-05'),
(5, 'Eve Brown', 'Sales', 70000, '2022-11-30'),
(6, 'Frank Miller', 'Engineering', 90000, '2021-06-15'),
(7, 'Grace Lee', 'Marketing', 72000, '2022-09-01'),
(8, 'Henry Taylor', 'Sales', 68000, '2020-11-20'),
]
cursor.executemany('INSERT OR REPLACE INTO employees VALUES (?, ?, ?, ?, ?)', employees_data)
sales_data = [
(1, 1, 15000, '2023-01-15'),
(2, 2, 12000, '2023-02-20'),
(3, 3, 18000, '2023-03-10'),
(4, 4, 10000, '2023-01-30'),
(5, 5, 13000, '2023-02-28'),
(6, 1, 20000, '2023-03-15'),
(7, 4, 11000, '2023-03-22'),
(8, 5, 16000, '2023-03-25'),
(9, 6, 17000, '2023-04-05'),
(10, 7, 14000, '2023-04-18'),
(11, 8, 12500, '2023-04-22'),
(12, 2, 15500, '2023-05-01'),
]
cursor.executemany('INSERT OR REPLACE INTO sales VALUES (?, ?, ?, ?)', sales_data)
conn.commit()
conn.close()
st.session_state.db_initialized = True
def get_db_connection():
"""Get database connection"""
conn = sqlite3.connect('company.db')
return conn
def execute_sql(query: str):
clean_query = re.sub(r"^[\"'`]+|[\"'`]+$", "", query.strip())
# 安全检查
if not clean_query.lower().lstrip().startswith("select"):
st.error(f"⚠️ 非 SELECT 语句被拒绝。SQL: `{clean_query}`")
return pd.DataFrame()
conn = get_db_connection()
try:
df = pd.read_sql_query(clean_query, conn)
return df
except Exception as e:
st.error(f"❌ SQL 执行失败:\n{e}\n\n你提交的 SQL 是:\n```sql\n{clean_query}\n```")
# 打印到控制台便于调试
print(f"DEBUG: Failed SQL: {clean_query}")
return pd.DataFrame()
finally:
conn.close()
def get_table_info():
"""Get information about tables in the database"""
conn = get_db_connection()
cursor = conn.cursor()
cursor.execute("SELECT name FROM sqlite_master WHERE type='table';")
tables = cursor.fetchall()
table_info = {}
for table in tables:
table_name = table[0]
cursor.execute(f"PRAGMA table_info({table_name})")
columns = cursor.fetchall()
table_info[table_name] = [(col[1], col[2]) for col in columns]
conn.close()
return table_info
def create_sql_chain(llm, db):
table_info = db.get_table_info()
# --- 强化后的 Prompt ---
template = """
你是一个专业的 SQLite 数据库助手。请严格遵守以下规则:
1. **绝对禁止**输出任何推理过程、解释、娓 标签、说明文字或额外内容。
2. 你的回答**必须且只能**是以下四行格式,每行以指定前缀开头:
Question: "{input}"
SQLQuery: "你的 SQLite 查询语句"
SQLResult: ""
Answer: ""
3. SQLQuery 必须是合法的、可直接执行的 SQLite SELECT 语句。
4. 字段说明(极其重要!):
- 销售额、销售金额 → 必须使用 `sales.amount`
- 员工工资、薪资 → 必须使用 `employees.salary`
- 员工姓名 → `employees.name`
- 部门 → `employees.department`
5. 当问题涉及“每个”、“按...分组”时,**必须在 SELECT 子句中包含分组字段和聚合字段**,并在 GROUP BY 子句中指定分组字段。
例如:`SELECT department, COUNT(*) FROM employees GROUP BY department;`
6. 未指定数量时,不加 LIMIT。
数据库 Schema:
{table_info}
现在,请严格按以上格式回答:
""".strip()
prompt = PromptTemplate.from_template(template)
def sql_chain(question):
formatted_prompt = prompt.format(table_info=table_info, input=question)
response = llm.invoke(formatted_prompt)
content = response.content.strip()
# 尝试方法1:按标准格式提取
if "SQLQuery:" in content:
start = content.find("SQLQuery:") + len("SQLQuery:")
# 找到下一个字段(SQLResult 或 Answer)或结尾
end = content.find("SQLResult:")
if end == -1:
end = content.find("Answer:")
if end == -1:
end = len(content)
sql_query = content[start:end].strip()
# 清理引号和多余字符
sql_query = sql_query.strip().strip('"').strip("'").strip("`").strip(";")
return sql_query
# 尝试方法2:如果全是自由文本,尝试提取第一个 SELECT 语句
sql_match = re.search(r"(SELECT\s+[^;]*);?", content, re.IGNORECASE | re.DOTALL)
if sql_match:
return sql_match.group(1).strip()
# 如果都失败,抛出错误
raise ValueError(f"LLM 未生成有效 SQL。原始输出:\n{content[:500]}...")
return sql_chain
def render_chat_message(sender: str, content, message_type: str = "text"):
"""渲染单条聊天消息 - 清新配色"""
if sender == "user":
# 用户消息 - 右侧,浅蓝色
st.markdown(
f"""
<div style="display: flex; justify-content: flex-end; margin: 12px 0;">
<div style="
background-color: #e3f2fd;
color: #1976d2;
padding: 14px 18px;
border-radius: 18px 18px 4px 18px;
max-width: 70%;
text-align: left;
font-size: 15px;
box-shadow: 0 1px 3px rgba(0,0,0,0.1);
border: 1px solid #bbdefb;
">
{content}
</div>
</div>
""",
unsafe_allow_html=True
)
else:
# 助手消息 - 左侧,浅灰色
if message_type == "sql":
content_html = f"""<div style="font-family: 'Consolas', 'Courier New', monospace;
background-color: #f0f7f4;
color: #2e7d32;
padding: 12px;
border-radius: 8px;
margin: 8px 0;
border-left: 3px solid #4caf50;">
<b>生成的 SQL:</b><br><code>{content}</code></div>"""
st.markdown(content_html, unsafe_allow_html=True)
elif message_type == "data":
# --- 关键修复:直接渲染 HTML ---
st.markdown(content, unsafe_allow_html=True)
elif message_type == "error":
content_html = f"""<div style="color: #d32f2f; padding: 8px 0; background-color: #ffebee; border-radius: 4px; padding: 10px;">
❌ {content}</div>"""
st.markdown(content_html, unsafe_allow_html=True)
elif message_type == "info":
content_html = f"""<div style="color: #1976d2; padding: 8px 0; background-color: #e3f2fd; border-radius: 4px; padding: 10px;">
ℹ️ {content}</div>"""
st.markdown(content_html, unsafe_allow_html=True)
else:
st.markdown(
f"""
<div style="display: flex; justify-content: flex-start; margin: 12px 0;">
<div style="
background-color: #f5f5f5;
color: #333;
padding: 14px 18px;
border-radius: 18px 18px 18px 4px;
max-width: 80%;
text-align: left;
font-size: 15px;
box-shadow: 0 1px 3px rgba(0,0,0,0.08);
border: 1px solid #e0e0e0;
">
{content}
</div>
</div>
""",
unsafe_allow_html=True
)
def render_chart_for_dataframe(df):
"""根据 DataFrame 的内容生成并渲染一个美观的图表(支持多列情况)"""
if df.empty or len(df) == 0:
return None
# 情况1: 单值
if df.shape == (1, 1):
value = df.iloc[0, 0]
if pd.api.types.is_numeric_dtype(df.iloc[:, 0]):
html = f"""
<div style="text-align: center; margin: 16px 0;">
<div style="font-size: 32px; font-weight: bold; color: #1976d2;">{value:,.2f}</div>
<div style="font-size: 14px; color: #666;">结果</div>
</div>
"""
return html
return None
# 提取数值列和类别列
numeric_cols = df.select_dtypes(include='number').columns.tolist()
object_cols = df.select_dtypes(include=['object', 'category']).columns.tolist()
# 情况2: 两列(一类别一数值)——原有逻辑增强
if df.shape[1] == 2 and len(numeric_cols) == 1 and len(object_cols) == 1:
cat_col = object_cols[0]
val_col = numeric_cols[0]
plot_df = df.copy()
if plot_df[cat_col].nunique() > 15:
plot_df = plot_df.nlargest(15, val_col)
fig, ax = plt.subplots(figsize=(8, 4))
colors = sns.color_palette("Set2", len(plot_df))
bars = ax.bar(plot_df[cat_col].astype(str), plot_df[val_col], color=colors)
ax.set_xlabel(cat_col, fontsize=12)
ax.set_ylabel(val_col, fontsize=12)
ax.set_title(f'{val_col} by {cat_col}', fontsize=14, pad=12)
plt.xticks(rotation=30, ha='right')
for bar in bars:
height = bar.get_height()
ax.annotate(f'{height:,.0f}', xy=(bar.get_x() + bar.get_width() / 2, height),
xytext=(0, 3), textcoords="offset points", ha='center', va='bottom', fontsize=9)
plt.tight_layout()
# 情况3: 单列数值 → 直方图
elif len(numeric_cols) == 1 and df.shape[1] == 1:
col = numeric_cols[0]
fig, ax = plt.subplots(figsize=(8, 3))
sns.histplot(df[col].dropna(), kde=True, ax=ax, color='#4CAF50')
ax.set_title(f'Distribution of {col}', fontsize=14)
plt.tight_layout()
# 情况4: 多列数值(≥2),无类别列 → 热力图(相关性)
elif len(numeric_cols) >= 2 and len(object_cols) == 0:
if len(numeric_cols) <= 6: # 避免太多列导致热力图拥挤
corr = df[numeric_cols].corr()
fig, ax = plt.subplots(figsize=(6, 5))
sns.heatmap(corr, annot=True, fmt=".2f", cmap="Blues", ax=ax, cbar_kws={'shrink': 0.8})
ax.set_title("Correlation Heatmap", fontsize=14, pad=12)
plt.tight_layout()
else:
# 降维为前5列展示折线图(按行索引)
plot_cols = numeric_cols[:5]
fig, ax = plt.subplots(figsize=(8, 4))
for col in plot_cols:
ax.plot(df.index, df[col], label=col, marker='o', markersize=3)
ax.set_title("Trend of Numerical Columns", fontsize=14)
ax.legend()
plt.tight_layout()
# 情况5: 多列,包含类别列 + 多个数值列 → 分组柱状图(取前2数值列)
elif len(object_cols) >= 1 and len(numeric_cols) >= 2:
cat_col = object_cols[0]
num_cols = numeric_cols[:2] # 取前两个数值列
if df[cat_col].nunique() <= 10:
plot_df = df[[cat_col] + num_cols].set_index(cat_col)
fig, ax = plt.subplots(figsize=(8, 4))
plot_df.plot(kind='bar', ax=ax, color=['#2196F3', '#FF9800'])
ax.set_title(f'Comparison of {", ".join(num_cols)} by {cat_col}', fontsize=14)
ax.set_xlabel(cat_col)
plt.xticks(rotation=30, ha='right')
plt.tight_layout()
else:
return None # 类别太多,不绘图
# 情况6: 全是类别列(无数值)→ 不绘图
elif len(numeric_cols) == 0:
return None
# 其他情况:尝试折线图(以索引为x轴)
else:
if len(numeric_cols) > 0:
plot_cols = numeric_cols[:4]
fig, ax = plt.subplots(figsize=(8, 4))
for col in plot_cols:
ax.plot(df.index, df[col], label=col, marker='.', linewidth=1.5)
ax.set_title("Numerical Trends", fontsize=14)
ax.legend()
plt.tight_layout()
else:
return None
# 统一渲染为 base64 图像
import io, base64
buf = io.BytesIO()
fig.savefig(buf, format='png', dpi=150, bbox_inches='tight')
buf.seek(0)
img_base64 = base64.b64encode(buf.read()).decode('utf-8')
plt.close(fig)
return f'<div style="text-align: center; margin-top: 16px;"><img src="data:image/png;base64,{img_base64}" style="max-width: 100%; height: auto; border-radius: 8px;"></div>'
def display_dataframe_in_chat(df):
"""在聊天区域显示数据表格和图表"""
if df.empty:
return "查询未返回任何数据。"
# --- 1. 生成表格 ---
html_table = df.to_html(classes='chat-dataframe', index=False, escape=False)
table_html = f"""
{html_table}
<div style="margin-top: 8px; font-size: 13px; color: #666;">
共 {len(df)} 行数据
</div>
"""
# --- 2. 生成图表 ---
chart_html = render_chart_for_dataframe(df)
if chart_html:
table_html += chart_html
# --- 3. 合并返回 ---
return table_html
def main():
# 在 main() 开头添加
st.markdown(
"""
<style>
.chat-dataframe {
border-collapse: collapse;
margin: 10px 0;
font-size: 14px;
width: 100%;
}
.chat-dataframe th {
background-color: #e3f2fd;
color: #1976d2;
padding: 8px;
text-align: left;
border: 1px solid #bbdefb;
}
.chat-dataframe td {
padding: 8px;
border: 1px solid #e0e0e0;
background-color: white;
}
.chat-dataframe tr:nth-child(even) {
background-color: #fafafa;
}
</style>
""",
unsafe_allow_html=True
)
st.set_page_config(page_title="NL2SQL 智能问数", layout="wide")
st.title("🤖 数据库智能问答")
# Initialize database
if not st.session_state.db_initialized:
with st.spinner("正在初始化数据库..."):
init_database()
st.success("数据库初始化成功!")
# Sidebar for navigation
st.sidebar.title("导航")
page = st.sidebar.radio("", ["数据库概览", "智能问数"])
if page == "数据库概览":
st.header("数据库概览")
st.subheader("表结构")
table_info = get_table_info()
for table_name, columns in table_info.items():
st.write(f"**{table_name}**")
cols_df = pd.DataFrame(columns, columns=["列名", "数据类型"])
st.table(cols_df)
st.markdown("---")
st.subheader("示例数据")
tab1, tab2 = st.tabs(["员工", "销售"])
with tab1:
df_employees = execute_sql("SELECT * FROM employees LIMIT 10")
st.dataframe(df_employees)
with tab2:
df_sales = execute_sql("SELECT * FROM sales LIMIT 10")
st.dataframe(df_sales)
elif page == "智能问数":
st.header("💬 智能问数")
if not LANGCHAIN_AVAILABLE:
st.error("LangChain 不可用。请安装以使用此功能:")
st.code("pip install langchain langchain-openai langchain-community")
return
st.info("💡 在底部输入框中提出您的问题,系统将自动生成 SQL 并返回结果")
# 聊天历史区域
chat_container = st.container()
with chat_container:
if st.session_state.chat_history:
for msg in st.session_state.chat_history:
if msg["role"] == "user":
render_chat_message("user", msg["content"])
else:
render_chat_message(
"assistant",
msg["content"],
msg.get("type", "text")
)
else:
st.info("还没有对话记录。在底部输入您的问题开始对话!")
# 检查 LLM 连接
try:
llm = ChatOpenAI(
base_url="http://10.7.11.21:8000/v1",
api_key="not-needed",
model="Qwen3-0.6B",
temperature=0,
max_tokens=512,
)
db = SQLDatabase.from_uri("sqlite:///company.db")
sql_chain = create_sql_chain(llm, db)
llm_available = True
except Exception as e:
st.error(f"连接本地 LLM 失败: {e}")
llm_available = False
return
# --- 快速开始 ---
st.markdown("### 💡 快速开始")
example_questions = [
"每个部门有多少员工?",
"每个员工的总销售额是多少?",
"列出所有员工及其薪资,按薪资降序排列",
"按部门的平均工资是多少?",
"销售额前三的员工是谁?"
]
quick_cols = st.columns(len(example_questions))
for i, question in enumerate(example_questions):
if quick_cols[i].button(f"「{question}」", key=f"quick_{i}", use_container_width=True):
st.session_state.user_query_input = question
st.session_state.trigger_submit = True
st.rerun()
# --- 主输入区:单行布局(输入框 + 按钮)---
# st.markdown("<br><br>", unsafe_allow_html=True) # 确保底部有足够空间
# 使用一个空容器来控制布局
input_container = st.empty()
with input_container.container():
col_input, col_btn = st.columns([9, 1], gap="small") # 添加 gap 参数
# 确保 session_state 中有这个 key,初始化为空字符串
if "user_query_input" not in st.session_state:
st.session_state.user_query_input = ""
with col_input:
user_input = st.text_input(
label="您的问题",
# 移除了 value 参数!
placeholder="例如:每个部门有多少员工?",
label_visibility="collapsed",
key="user_query_input", # Streamlit 会自动将输入内容同步到 st.session_state["user_query_input"]
)
with col_btn:
# 关键:移除所有手动的 margin 调整
send_clicked = st.button("🚀", use_container_width=True)
# --- 处理提交逻辑(来自按钮点击或示例点击)---
# --- 处理提交逻辑(来自按钮点击或示例点击)---
should_submit = send_clicked or st.session_state.get("trigger_submit", False)
if should_submit and user_input.strip():
# 清除触发标记
if "trigger_submit" in st.session_state:
del st.session_state.trigger_submit
# 添加用户消息
st.session_state.chat_history.append({
"role": "user",
"content": user_input.strip()
})
# LLM 生成 SQL 并执行...
try:
with st.spinner("🧠 思考中..."):
sql_query = sql_chain(user_input.strip())
st.session_state.chat_history.append({
"role": "assistant",
"content": sql_query,
"type": "sql"
})
with st.spinner("⏱️ 执行查询..."):
df_result = execute_sql(sql_query)
if not df_result.empty:
# --- 渲染表格 ---
table_html = display_dataframe_in_chat(df_result)
st.session_state.chat_history.append({
"role": "assistant",
"content": table_html,
"type": "data"
})
st.session_state.last_csv = df_result.to_csv(index=False)
# --- 渲染图表 ---
# 图表已经在 display_dataframe_in_chat 中处理,不需要单独调用
else:
st.session_state.chat_history.append({
"role": "assistant",
"content": "查询成功,但无数据返回。",
"type": "info"
})
st.session_state.pop("last_csv", None)
except Exception as e:
st.session_state.chat_history.append({
"role": "assistant",
"content": f"❌ 错误:{str(e)}",
"type": "error"
})
st.rerun()
if __name__ == "__main__":
main()