Bootstrap

fastapi 调用ollama之下的sqlcoder模式进行对话操作数据库

from fastapi import FastAPI, HTTPException, Request
from pydantic import BaseModel
import ollama
import mysql.connector
from mysql.connector.cursor import MySQLCursor
import json

app = FastAPI()

# 数据库连接配置
DB_CONFIG = {
    "database": "web",        # 您的数据库名,用于存储业务数据
    "user": "root",          # 数据库用户名,需要有读写权限
    "password": "XXXXXX",    # 数据库密码,建议使用强密码
    "host": "127.0.0.1",    # 数据库主机地址,本地开发环境使用localhost
    "port": "3306"          # MySQL 默认端口,可根据实际配置修改
}

# 数据库连接函数
def get_db_connection():
    try:
        conn = mysql.connector.connect(**DB_CONFIG)
        return conn
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"数据库连接失败: {str(e)}")

class SQLRequest(BaseModel):
    question: str

def get_table_relationships():
    """动态获取表之间的关联关系"""
    conn = get_db_connection()
    cur = conn.cursor()
    try:
        # 获取当前数据库名
        cur.execute("SELECT DATABASE()")
        db_name = cur.fetchone()[0]
        
        # 获取外键关系
        cur.execute("""
            SELECT 
                TABLE_NAME,
                COLUMN_NAME,
                REFERENCED_TABLE_NAME,
                REFERENCED_COLUMN_NAME
            FROM INFORMATION_SCHEMA.KEY_COLUMN_USAGE
            WHERE TABLE_SCHEMA = %s
                AND REFERENCED_TABLE_NAME IS NOT NULL
            ORDER BY TABLE_NAME, COLUMN_NAME
        """, (db_name,))
        
        relationships = []
        for row in rows:
            table_name, column_name, ref_table, ref_column = row
            relationships.append(
                f"-- {table_name}.{column_name} can be joined with {ref_table}.{ref_column}"
            )
        
        return "\n".join(relationships) if relationships else "-- No foreign key relationships found"
        
    finally:
        cur.close()
        conn.close()

def get_database_schema():
    """获取MySQL数据库表结构,以CREATE TABLE格式返回"""
    conn = get_db_connection()
    cur = conn.cursor()
    try:
        # 获取当前数据库名
        cur.execute("SELECT DATABASE()")
        db_name = cur.fetchone()[0]
        
        # 获取所有表的结构信息
        cur.execute("""
            SELECT 
                t.TABLE_NAME,
                c.COLUMN_NAME,
                c.COLUMN_TYPE,
                c.IS_NULLABLE,
                c.COLUMN_KEY,
                c.COLUMN_COMMENT
            FROM INFORMATION_SCHEMA.TABLES t
            JOIN INFORMATION_SCHEMA.COLUMNS c 
                ON t.TABLE_NAME = c.TABLE_NAME
            WHERE t.TABLE_SCHEMA = %s
                AND t.TABLE_TYPE = 'BASE TABLE'
            ORDER BY t.TABLE_NAME, c.ORDINAL_POSITION
        """, (db_name,))
        
        rows = cur.fetchall()
        
        schema = []
        current_table = None
        table_columns = []
        
        for row in rows:
            table_name, column_name, column_type, nullable, key, comment = row
            
            if current_table != table_name:
                if current_table is not None:
                    schema.append(f"CREATE TABLE {current_table} (\n" + 
                                ",\n".join(table_columns) + 
                                "\n);\n")
                current_table = table_name
                table_columns = []
            
            # 构建列定义
            column_def = f"  {column_name} {column_type.upper()}"
            if key == "PRI":
                column_def += " PRIMARY KEY"
            elif nullable == "NO":
                column_def += " NOT NULL"
                
            if comment:
                column_def += f" -- {comment}"
                
            table_columns.append(column_def)
        
        # 添加最后一个表
        if current_table is not None:
            schema.append(f"CREATE TABLE {current_table} (\n" + 
                        ",\n".join(table_columns) + 
                        "\n);\n")
            
        return "\n".join(schema)
    finally:
        cur.close()
        conn.close()

def get_chinese_table_mapping():
    """动态生成表名的中文映射"""
    conn = get_db_connection()
    cur = conn.cursor()
    try:
        # 获取所有表的注释信息
        cur.execute("""
            SELECT 
                t.TABLE_NAME,
                t.TABLE_COMMENT
            FROM information_schema.TABLES t
            WHERE t.TABLE_SCHEMA = DATABASE()
            ORDER BY t.TABLE_NAME
        """)
        
        mappings = []
        for table_name, table_comment in cur.fetchall():
            # 生成表的中文名称
            chinese_name = table_name
            if table_name.startswith('web_'):
                chinese_name = table_name.replace('web_', '').replace('_', '')
            if table_comment:
                chinese_name = table_comment.split('--')[0].strip()
                # 如果中文名称以"表"结尾,则去掉"表"if chinese_name.endswith('表'):
                    chinese_name = chinese_name[:-1]
            
            mappings.append(f'           - "{chinese_name}" -> {table_name} table')
        
        return "\n".join(mappings)
    finally:
        cur.close()
        conn.close()

@app.post("/query")
async def query_database(request: Request):
    try:
        # 获取请求体数据并确保正确处理中文
        body = await request.body()
        try:
            request_data = json.loads(body.decode('utf-8'))
        except UnicodeDecodeError:
            request_data = json.loads(body.decode('gbk'))
        
        question = request_data.get('question')
        print(f"收到问题: {question}")  # 调试日志
        
        if not question:
            raise HTTPException(status_code=400, detail="缺少 question 参数")
            
        # 获取数据库结构
        db_schema = get_database_schema()
        #print(f"数据库结构: {db_schema}")  # 调试日志
        
        # 获取中文映射并打印
        chinese_mapping = get_chinese_table_mapping()
        #print(f"表映射关系:\n{chinese_mapping}")  # 添加这行来打印映射
        
        # 修改 prompt 使用更严格的指导
        prompt = f"""
        ### Instructions:
        Convert Chinese question to MySQL query. Follow these rules strictly:
        1. ONLY return a valid SELECT SQL query
        2. Use EXACT table names from the mapping below
        3. DO NOT use any table that's not in the mapping
        4. For Chinese terms, use these exact mappings:
{chinese_mapping}

        ### Examples:
        Question: 所有装修记录
        SQL: SELECT * FROM web_decoration ORDER BY id;

        Question: 查询装修
        SQL: SELECT * FROM web_decoration ORDER BY id;

        ### Database Schema:
        {db_schema}

        ### Question:
        {question}

        ### SQL Query:
        """
        
        # 获取 SQL 查询
        response = ollama.chat(model='sqlcoder:latest', 
                             messages=[{'role': 'user', 'content': prompt}])
        sql_query = response['message']['content'].strip()
        
        print(f"生成的SQL: {sql_query}")  # 调试日志
        
        # 验证 SQL 查询
        if "装修" in question and "web_decoration" not in sql_query:
            sql_query = "SELECT * FROM web_decoration ORDER BY id"
            
        if not sql_query.upper().startswith('SELECT'):
            raise HTTPException(status_code=400, detail="无效的SQL查询格式")
            
        conn = get_db_connection()
        cur = conn.cursor(dictionary=True)
        
        try:
            cur.execute(sql_query)
            results = cur.fetchall()
            return {
                "sql_query": sql_query,
                "results": results
            }
        except mysql.connector.Error as e:
            raise HTTPException(status_code=400, detail=f"SQL 执行错误: {str(e)}")
        finally:
            cur.close()
            conn.close()
            
    except json.JSONDecodeError as e:
        raise HTTPException(status_code=400, detail=f"JSON 解析错误: {str(e)}")
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"处理请求时生错误: {str(e)}")

if __name__ == "__main__":
    import uvicorn
    port = 8666
    print(f"Starting server on port {port}")
    uvicorn.run(
        "main:app",
        host="0.0.0.0", 
        port=port,
        reload=True
    )

我来为您解释这段代码的主要功能和结构:

主要功能

这是一个使用 FastAPI 框架开发的 API 服务,主要用于将中文问题转换为 SQL 查询语句并执行。

核心组件说明

  1. 数据库配置与连接
DB_CONFIG = {
    "database": "web",
    "user": "root",
    "password": "XXXXXX",
    "host": "127.0.0.1",
    "port": "3306"
}

提供了 MySQL 数据库的基本连接配置。

  1. 主要工具函数
  • get_table_relationships(): 获取数据库表之间的外键关系
  • get_database_schema(): 获取数据库表结构
  • get_chinese_table_mapping(): 生成表名的中文映射关系
  1. 核心 API 端点
@app.post("/query")

这个端点接收中文问题,主要处理流程:

  • 接收并解析用户的中文问题
  • 获取数据库结构和表映射
  • 使用 ollama 模型将中文转换为 SQL 查询
  • 执行 SQL 查询并返回结果
  1. 智能转换功能
    使用 ollamasqlcoder 模型将中文问题转换为 SQL 查询,包含:
  • 严格的表名映射
  • SQL 查询验证
  • 错误处理机制

特点

  1. 支持中文输入处理
  2. 自动获取数据库结构
  3. 动态生成中文表名映射
  4. 完善的错误处理机制
  5. 支持热重载的开发模式

使用示例

可以通过 POST 请求访问 /query 端点:

{
    "question": "查询所有装修记录"
}

服务会返回:

{
    "sql_query": "SELECT * FROM web_decoration ORDER BY id",
    "results": [...]
}

安全特性

  1. 数据库连接错误处理
  2. SQL 注入防护
  3. 请求体编码自适应(支持 UTF-8 和 GBK)
  4. 查询结果的安全封装

查看效果:
在这里插入图片描述

;