Bootstrap

Python Sqlalchemy基础使用

Python Sqlalchemy基础使用

这里记录在编写 Python 代码的过程中,对 SQLAlchemy 的封装方式和基本使用方法。

(持续完善ing)

基本使用

创建Session

这里先读取 YAML 配置文件中的数据库配置,再调用封装好的 create_session 函数创建会话:

from util.database.sql_execute import create_session
from util.file.config_read import CONFIG

def get_db_config():
    """
    Return the connection settings of the 'business-5' database.

    Returns:
        dict: the ``database`` -> ``business-5`` section of CONFIG.
    """
    all_databases = CONFIG['database']
    return all_databases['business-5']


# Create the session factory for the 'business-5' database
DBSession = create_session(get_db_config())

# Module-level session instance opened from the factory
session = DBSession()

yaml配置:

# 数据库配置
database:
  business-245:
    dbms: 'mysql'
    driver: 'pymysql'
    ip: '10.135.149.245'
    port: '3306'
    username: 'xxxx'
    password: 'xxxx'
    database: 'xxxx'
  business-119:
    dbms: 'mssql'
    driver: 'pymssql'
    ip: '10.135.30.119'
    port: '1433'
    username: 'xxxx'
    password: 'xxxx'
    database: 'xxxx'
  business-5:
    dbms: 'mssql'
    driver: 'pymssql'
    ip: '10.135.149.5'
    port: '1433'
    username: 'xxx'
    password: 'xxxx'
    database: 'xxx'

封装(util/database/sql_execute.py 负责创建会话,util/file/config_read.py 负责读取配置):

from urllib.parse import quote_plus as urlquote

from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker, declarative_base
from sqlalchemy.dialects.mysql import insert

# Declarative base class; ORM models (e.g. TyphoonInfo) subclass it.
Base = declarative_base()


def create_session(params: dict):
    """
    Build a session factory from a dict of connection settings.

    Examples:
    params = {'dbms': 'mysql',
              'driver': 'pymysql',
              'ip': '192.168.30.193',
              'port': '3306',
              'username': 'root',
              'password': 'xxxx',
              'database': 'disaster_census'}

    Args:
        params: connection settings with the keys
            dbms (database type, e.g. 'mysql'), driver (DBAPI driver, e.g. 'pymysql'),
            username, password, ip (host address), port, database.

    Returns:
        The session factory produced by create_session_url; call it to open a session.
    """
    # URL shape: "<dbms>+<driver>://<username>:<password>@<ip>:<port>/<database>"
    # Both credentials are percent-encoded so characters such as '@', ':' or '/'
    # cannot break the URL structure (previously only the password was escaped).
    # str() keeps non-string usernames (e.g. an unquoted number in YAML) working.
    username = urlquote(str(params['username']))
    password = urlquote(params['password'])
    url = (
        f"{params['dbms']}+{params['driver']}://"
        f"{username}:{password}"
        f"@{params['ip']}:{params['port']}/{params['database']}"
    )

    return create_session_url(url)


def create_session_url(url: str):
    """
    Create a session factory for the given database URL.

    Args:
        url: SQLAlchemy URL of the form
            "<dbms>+<driver>://<user>:<password>@<host>:<port>/<database>?<options>",
            e.g. mysql+pymysql://root:password@localhost:3306/test

    Returns:
        A sessionmaker bound to a new engine for ``url``; call it to open a session.
    """
    return sessionmaker(bind=create_engine(url))
import yaml
import os

from util.file.path import get_root_path


def get_config_path():
    """
    Return the project's configuration directory.

    Returns:
        The project root path (from get_root_path) with '/config' appended.
    """
    root_path = get_root_path()
    return root_path + '/config'


def get_profile():
    """
    Read the active profile name, e.g. 'prod' or 'dev'.

    The value is the top-level ``profile`` key of ``config.yaml`` inside the
    directory returned by get_config_path().

    Returns:
        The profile value, or '' when the file does not exist, is empty /
        not a mapping, or has no ``profile`` key.
    """
    filename = 'config.yaml'
    profile_path = f'{get_config_path()}/{filename}'

    if not os.path.exists(profile_path):
        print(f"读取profile配置文件失败, '{filename}' 不存在!")
        return ''

    print(f"读取配置文件 '{filename}' 成功!")

    # Explicit UTF-8 so decoding does not depend on the platform locale
    # (e.g. GBK on Chinese Windows).
    with open(profile_path, 'r', encoding='utf-8') as f:
        yaml_config = yaml.load(f.read(), Loader=yaml.FullLoader)

    # An empty file loads as None; treat any non-mapping like a missing key.
    if not isinstance(yaml_config, dict):
        return ''
    return yaml_config.get('profile', '')


def get_config_yaml():
    """
    Load the profile-specific YAML configuration.

    Files are looked up in the directory returned by get_config_path() under
    the names ``config-{profile}.yml`` then ``config-{profile}.yaml``.

    Returns:
        dict: the merged configuration; {} if neither file exists.
    """
    # 1. Resolve the config directory and the candidate file names
    config_path = get_config_path()
    profile = get_profile()
    filenames = [f'config-{profile}.yml', f'config-{profile}.yaml', ]

    config = {}

    print(f'开始读取配置文件,path = {config_path}')

    # 2. Read each existing file in order; for duplicate keys, later files win.
    for filename in filenames:
        path = config_path + '/' + filename

        if not os.path.exists(path):
            print(f"读取配置文件失败, '{filename}' 不存在!")
            continue

        print(f"读取配置文件 '{filename}' 成功!")
        with open(path, 'r', encoding='utf-8') as f:
            yaml_config = yaml.load(f.read(), Loader=yaml.FullLoader)

        # An empty file loads as None, which dict.update() would reject.
        # NOTE: update() is shallow - a top-level key present in both files is
        # replaced wholesale rather than merged recursively.
        config.update(yaml_config or {})

    return config


# Configuration loaded once, when this module is first imported.
CONFIG = get_config_yaml()

创建ORM对象

import datetime

from sqlalchemy import Column, String, Integer, DateTime, DECIMAL, BigInteger, and_, or_

from util.database.sql_execute import Base


# 定义 TyphoonInfo(台风轨迹)对象:
class TyphoonInfo(Base):
    """ORM mapping for the ``typhoon_info`` table (typhoon track records)."""

    # Table name
    __tablename__ = 'typhoon_info'

    # Table columns
    id = Column(BigInteger)
    sort = Column(Integer)
    code_global = Column(Integer)
    track_count = Column(Integer)
    cyclone_no = Column(Integer)
    code_cn = Column(Integer)
    cyclone_end = Column(Integer)
    typhoon_name_en = Column(String(255))
    typhoon_name_cn = Column(String(255))
    observe_time = Column(DateTime)
    start_time = Column(DateTime)
    end_time = Column(DateTime)
    zj_across = Column(Integer, default=0)

    # Pass the callable, not its result: `datetime.datetime.now()` would be
    # evaluated once at import time and every inserted row would share it.
    create_time = Column(DateTime, default=datetime.datetime.now)
    create_by = Column(String(32), default='-1')
    modify_time = Column(DateTime)
    modify_by = Column(String(32))

    # No Column is declared with primary_key=True, so tell the mapper which
    # column identifies a row.
    __mapper_args__ = {
        "primary_key": [id]
    }

    def __repr__(self):
        # Skip SQLAlchemy internals such as `_sa_instance_state`.
        attributes = ', '.join(
            f"{k}={v!r}" for k, v in vars(self).items() if not k.startswith('_')
        )
        return f"{self.__class__.__name__}({attributes})"

查询

创建好session后,在需要的地方导入并执行:

from src.config.db_business_5 import session as session5

def select_typhoon_info_year(year):
    """
    Query TyphoonInfo rows flagged ``zj_across == 1`` that start or end in ``year``.

    Args:
        year: calendar year (int).

    Returns:
        A lazy SQLAlchemy query over TyphoonInfo; iterate it or call ``.all()``
        to fetch the rows.
    """
    # Half-open interval [Jan 1 of year, Jan 1 of year + 1) so timestamps in the
    # final minute of Dec 31 (e.g. 23:59:30) are not dropped.
    year_start = datetime.datetime(year, 1, 1)
    next_year_start = datetime.datetime(year + 1, 1, 1)

    # Multiple filter() criteria are AND-ed: zj_across == 1 AND (starts OR ends in year).
    typhoon_list = session5.query(TyphoonInfo).filter(
        TyphoonInfo.zj_across == 1,
        or_(
            and_(TyphoonInfo.start_time >= year_start, TyphoonInfo.start_time < next_year_start),
            and_(TyphoonInfo.end_time >= year_start, TyphoonInfo.end_time < next_year_start),
        ),
    )

    return typhoon_list

插入

普通插入(也可以使用后面“进阶操作”一节中的“插入存在时更新”):

def insert_db(track_list: list):
    """
    Insert ``track_list`` with a plain INSERT and commit (no upsert).

    Args:
        track_list: ORM instances to insert (e.g. TyphoonInfo objects).
    """
    # add_all() stages every object in the session; commit() writes them out.
    session.add_all(track_list)
    session.commit()

进阶操作

插入存在时更新

封装(基于 sqlalchemy.dialects.mysql 的 INSERT ... ON DUPLICATE KEY UPDATE,因此只适用于 MySQL,不适用于上面配置中的 mssql 库):

def upsert(session, table_cls, records: list, chunk_size=5000, commit_on_chunk=True, except_cols_on_update=None):
    """
    Insert ``records`` in chunks, updating existing rows on a duplicate key (MySQL only).

    Examples: upsert(session, SurfStationTyphoonHistory, insert_list, except_cols_on_update=["typhoon_id", "station_id"])

    Args:
        session: SQLAlchemy session used to execute the statements.
        table_cls: mapped ORM class whose table is written to.
        records: ORM instances to write. The column list is taken from the
            public (non ``_``-prefixed) instance attributes of ``records[0]``.
            An empty list is a no-op.
        chunk_size: number of rows per INSERT statement.
        commit_on_chunk: commit after every chunk when True; otherwise the
            caller is responsible for committing.
        except_cols_on_update: columns NOT overwritten when a duplicate key is
            hit (typically the unique-key columns). None means no exclusions.

    Returns:
        None
    """
    if not records:
        return

    # None instead of a mutable [] default, so no list is shared across calls.
    excluded = set(except_cols_on_update or ())

    # Column names from the first record; `_`-prefixed attributes such as
    # SQLAlchemy's `_sa_instance_state` are skipped.
    columns = [key for key in vars(records[0]) if not key.startswith("_")]

    # ORM objects -> plain dicts accepted by insert().values()
    rows = [{c: getattr(r, c) for c in columns} for r in records]

    update_keys = {c for c in columns if c not in excluded}

    for start in range(0, len(rows), chunk_size):
        chunk = rows[start:start + chunk_size]
        insert_stmt = insert(table_cls).values(chunk)
        # `inserted` refers to the would-be inserted values (MySQL VALUES()).
        update_columns = {col.name: col for col in insert_stmt.inserted if col.name in update_keys}
        upsert_stmt = insert_stmt.on_duplicate_key_update(**update_columns)
        session.execute(upsert_stmt)
        if commit_on_chunk:
            session.commit()

插入:

from src.config.db_business_245 import session as session245
from util.database.sql_execute import upsert

def insert_stats_surf_station_typhoon_history(insert_list: list):
    """
    Upsert SurfStationTyphoonHistory records through the business-245 session.

    On a duplicate key, typhoon_id and station_id are left as they are and the
    other columns are updated from the new values.

    Args:
        insert_list: SurfStationTyphoonHistory instances to write.
    """
    upsert(
        session245,
        SurfStationTyphoonHistory,
        insert_list,
        except_cols_on_update=["typhoon_id", "station_id"],
    )

执行SQL

if __name__ == '__main__':
    # Example: run a raw SQL query with bound parameters through session5.
    from sqlalchemy import text

    start_time = '2023-01-01 00:00:00'
    end_time = '2023-01-01 23:59:00'
    # :start_time / :end_time are bound from `params`, never formatted into the SQL.
    result_list = session5.execute(
        text("select 区站号,日期时间,一小时雨量,最大风速,极大风速,最大风速对应时间,极大风速对应时间 "
             "from zdz.CAWSData_2023 where 日期时间 between :start_time and :end_time"),
        params={'start_time': start_time, 'end_time': end_time},
    )
    for row in result_list:
        # Columns in SELECT order: station number, observation time,
        # 1-hour precipitation, max wind speed, extreme wind speed,
        # time of max wind speed, time of extreme wind speed.
        (station_id_c, observe_time, pre_1h, win_max, win_max_inst,
         win_max_ot, win_max_inst_ot) = row
        # print(row)
    print()