Python SQLAlchemy 基础使用
这里记录一下在编写 Python 代码过程中对 SQLAlchemy 的封装和基本使用方法。
(持续完善中)
基本使用
创建Session
这里采用“读取 YAML 配置文件 + 调用封装函数”的方式创建 Session:
from util.database.sql_execute import create_session
from util.file.config_read import CONFIG
def get_db_config():
    """
    Return the connection settings of the ``business-5`` database.

    Returns:
        dict: the ``CONFIG['database']['business-5']`` section, with keys such as
        dbms, driver, ip, port, username, password and database.
    """
    databases = CONFIG['database']
    return databases['business-5']
# Build a session factory for the 'business-5' database at import time, and
# open one module-level session that other modules import (e.g. as `session5`).
DBSession = create_session(get_db_config())
session = DBSession()
YAML 配置(放在项目根目录下的 config 目录中,文件名为 config-{profile}.yml 或 config-{profile}.yaml):
# Database connection settings, one entry per database.
# Each entry provides the keys read by create_session():
# dbms, driver, ip, port, username, password, database.
database:
  business-245:
    dbms: 'mysql'
    driver: 'pymysql'
    ip: '10.135.149.245'
    port: '3306'
    username: 'xxxx'
    password: 'xxxx'
    database: 'xxxx'
  business-119:
    dbms: 'mssql'
    driver: 'pymssql'
    ip: '10.135.30.119'
    port: '1433'
    username: 'xxxx'
    password: 'xxxx'
    database: 'xxxx'
  business-5:
    dbms: 'mssql'
    driver: 'pymssql'
    ip: '10.135.149.5'
    port: '1433'
    username: 'xxx'
    password: 'xxxx'
    database: 'xxx'
封装(数据库会话:util/database/sql_execute.py;配置读取:util/file/config_read.py):
from urllib.parse import quote_plus as urlquote
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker, declarative_base
from sqlalchemy.dialects.mysql import insert
# Declarative base class for the ORM models; model classes such as
# TyphoonInfo subclass it.
Base = declarative_base()
def create_session(params: dict):
    """
    Create a session factory from a dict of connection parameters.

    Examples:
        params = {'dbms': 'mysql',
                  'driver': 'pymysql',
                  'ip': '192.168.30.193',
                  'port': '3306',
                  'username': 'root',
                  'password': 'xxxx',
                  'database': 'disaster_census'}
        DBSession = create_session(params)
        session = DBSession()

    Args:
        params: connection parameters (dbms: database type, driver: DBAPI driver,
            username / password: credentials, ip: host address,
            port: port number (int or numeric string), database: database name)

    Returns:
        A ``sessionmaker`` bound to a newly created engine for that database.
    """
    # Build the URL with URL.create() instead of string formatting. Every
    # component is passed through verbatim. Special characters in the username,
    # password or database name ('@', ':', '/', '+', spaces, ...) therefore need
    # no manual escaping and cannot corrupt the URL. The previous quote_plus()
    # approach also turned spaces into '+', which SQLAlchemy does not decode.
    from sqlalchemy.engine import URL

    url = URL.create(
        drivername=f"{params['dbms']}+{params['driver']}",
        username=params['username'],
        password=params['password'],
        host=params['ip'],
        port=int(params['port']),
        database=params['database'],
    )
    # create_engine() (used by create_session_url) accepts a URL object
    # as well as a string.
    return create_session_url(url)
def create_session_url(url: str):
    """
    Create a session factory from a database URL.

    Args:
        url: "dialect+driver://username:password@host:port/database?options",
            e.g. mysql+pymysql://root:password@localhost:3306/test

    Returns:
        A ``sessionmaker`` whose sessions are bound to a newly created engine.
    """
    # Each call builds a fresh engine (with its own connection pool).
    return sessionmaker(bind=create_engine(url))
import yaml
import os
from util.file.path import get_root_path
def get_config_path():
    """
    Return the project's configuration directory: ``<project root>/config``.

    Returns:
        Path string of the config directory.
    """
    root_dir = get_root_path()
    config_dir = root_dir + '/config'
    return config_dir
def get_profile():
    """
    Read the active profile name (e.g. prod, dev) from ``<config dir>/config.yaml``.
    The file contains an entry such as ``profile: dev``.

    Returns:
        The profile value, or '' when config.yaml does not exist, is empty,
        or has no (non-empty) ``profile`` entry.
    """
    filename = 'config.yaml'
    profile_path = f'{get_config_path()}/{filename}'
    if not os.path.exists(profile_path):
        print(f"读取profile配置文件失败, '{filename}' 不存在!")
        return ''
    # Read as UTF-8 explicitly. The platform default encoding (e.g. GBK on
    # Chinese Windows) fails on UTF-8 files containing non-ASCII text.
    with open(profile_path, 'r', encoding='utf-8') as f:
        yaml_config = yaml.load(f, Loader=yaml.FullLoader)
    # Report success only after the file has actually been parsed.
    print(f"读取配置文件 '{filename}' 成功!")
    # An empty file loads as None. A missing key falls back to '', as documented,
    # instead of raising.
    if not isinstance(yaml_config, dict):
        return ''
    return yaml_config.get('profile') or ''
def _deep_merge(base: dict, override: dict) -> dict:
    """
    Recursively merge ``override`` into ``base`` in place and return ``base``.

    Nested dicts are merged key by key. Any other value in ``override``
    replaces the one in ``base``.
    """
    for key, value in override.items():
        if isinstance(value, dict) and isinstance(base.get(key), dict):
            _deep_merge(base[key], value)
        else:
            base[key] = value
    return base


def get_config_yaml():
    """
    Load the profile-specific YAML configuration.

    Candidate files are ``config-{profile}.yml`` and then
    ``config-{profile}.yaml``, both in the project's config directory. Missing
    files are skipped. Later files take precedence:
    - nested mappings are merged key by key;
    - any other value from a later file replaces the earlier one.

    Returns:
        dict: the merged configuration ({} when no file is found).
    """
    # 1. Resolve the config directory and the candidate file names.
    config_path = get_config_path()
    profile = get_profile()
    filenames = [f'config-{profile}.yml', f'config-{profile}.yaml', ]
    config = {}
    print(f'开始读取配置文件,path = {config_path}')
    # 2. Read each existing file in order and merge it over the previous ones.
    for filename in filenames:
        path = f'{config_path}/{filename}'
        if not os.path.exists(path):
            print(f"读取配置文件失败, '{filename}' 不存在!")
            continue
        # Read as UTF-8 explicitly instead of relying on the platform default.
        with open(path, 'r', encoding='utf-8') as f:
            yaml_config = yaml.load(f, Loader=yaml.FullLoader)
        print(f"读取配置文件 '{filename}' 成功!")
        # An empty file loads as None; skip it instead of crashing in the merge.
        if isinstance(yaml_config, dict):
            # Deep merge, so a later file overriding one nested key does not
            # wipe out the sibling keys of the same section.
            _deep_merge(config, yaml_config)
    return config
# Load the configuration once, at import time; other modules import this dict
# (e.g. `from util.file.config_read import CONFIG`).
CONFIG = get_config_yaml()
创建 ORM 对象
import datetime
from sqlalchemy import Column, String, Integer, DateTime, DECIMAL, BigInteger, and_
from util.database.sql_execute import Base
# Define the TyphoonInfo (typhoon track) ORM model:
class TyphoonInfo(Base):
    """ORM mapping for the ``typhoon_info`` table (typhoon track records)."""

    # Table name:
    __tablename__ = 'typhoon_info'

    # Table columns:
    id = Column(BigInteger)
    sort = Column(Integer)
    code_global = Column(Integer)
    track_count = Column(Integer)
    cyclone_no = Column(Integer)
    code_cn = Column(Integer)
    cyclone_end = Column(Integer)
    typhoon_name_en = Column(String(255))
    typhoon_name_cn = Column(String(255))
    observe_time = Column(DateTime)
    start_time = Column(DateTime)
    end_time = Column(DateTime)
    zj_across = Column(Integer, default=0)
    # Pass the callable, not its result. ``datetime.datetime.now()`` would be
    # evaluated once at import time and stamp every inserted row with the same
    # value; the callable is invoked by SQLAlchemy for each INSERT.
    create_time = Column(DateTime, default=datetime.datetime.now)
    create_by = Column(String(32), default='-1')
    modify_time = Column(DateTime)
    modify_by = Column(String(32))

    # No column is declared with primary_key=True, so tell the mapper to use
    # ``id`` as the identity column.
    __mapper_args__ = {
        "primary_key": [id]
    }

    def __repr__(self):
        # Skip private/internal attributes such as SQLAlchemy's
        # ``_sa_instance_state`` so the repr shows only the loaded column values.
        attributes = ', '.join(
            f"{k}={v!r}" for k, v in vars(self).items() if not k.startswith('_')
        )
        return f"{self.__class__.__name__}({attributes})"
查询
创建好 session 后,在需要的地方导入并执行查询:
from src.config.db_business_5 import session as session5
def select_typhoon_info_year(year):
    """
    Query typhoons with ``zj_across == 1`` whose start time or end time falls in ``year``.

    Args:
        year: calendar year, e.g. 2023.

    Returns:
        A lazy SQLAlchemy ``Query`` over TyphoonInfo. Iterate it or call
        ``.all()`` to fetch the rows.
    """
    import datetime

    # ``or_`` was used without being imported, which raised NameError.
    from sqlalchemy import and_, or_

    # Half-open interval [Jan 1 of year, Jan 1 of next year). Unlike the old
    # inclusive bound of Dec 31 23:59, it also covers the year's last minute.
    start_time = datetime.datetime(year, 1, 1)
    end_time = datetime.datetime(year + 1, 1, 1)
    typhoon_list = session5.query(TyphoonInfo).filter(
        and_(
            # Common to both branches, so it is factored out of the OR.
            TyphoonInfo.zj_across == 1,
            or_(
                and_(TyphoonInfo.start_time >= start_time, TyphoonInfo.start_time < end_time),
                and_(TyphoonInfo.end_time >= start_time, TyphoonInfo.end_time < end_time),
            ),
        )
    )
    return typhoon_list
插入
普通插入(也可以使用下文“进阶操作”中的“插入存在时更新”):
def insert_db(track_list: list):
    """
    Insert ORM objects with plain INSERTs and commit.

    Uses the module-level ``session``. If adding or committing fails, the
    transaction is rolled back so the session stays usable, and the error is
    re-raised.

    Args:
        track_list: ORM instances to insert.
    """
    try:
        # Previously referenced the undefined name ``tracks`` (NameError).
        session.add_all(track_list)
        session.commit()
    except Exception:
        session.rollback()
        raise
进阶操作
插入存在时更新(upsert,基于 MySQL 的 INSERT ... ON DUPLICATE KEY UPDATE,仅适用于 MySQL)
封装:
def upsert(session, table_cls, records: list, chunk_size=5000, commit_on_chunk=True, except_cols_on_update=None):
    """
    Batch INSERT ... ON DUPLICATE KEY UPDATE (MySQL dialect) of ORM objects.

    Examples:
        upsert(session, SurfStationTyphoonHistory, insert_list,
               except_cols_on_update=["typhoon_id", "station_id"])

    Args:
        session: SQLAlchemy session used to execute the statements.
        table_cls: mapped ORM class of the target table.
        records: ORM instances to write. The column set is taken from the
            public attributes present on ``records[0]``.
        chunk_size: number of rows per INSERT statement.
        commit_on_chunk: commit after each chunk. When False, the caller is
            responsible for committing.
        except_cols_on_update: columns left untouched when the row already
            exists (typically the unique-key columns). Defaults to none.

    Returns:
        None. Does nothing when ``records`` is empty.
    """
    # Guard: records[0] below would raise IndexError on an empty list.
    if not records:
        return
    # None instead of a mutable ``[]`` default shared across calls.
    if except_cols_on_update is None:
        except_cols_on_update = []
    # Column names are the public instance attributes of the first record. This
    # skips SQLAlchemy internals such as ``_sa_instance_state``.
    columns = [key for key in vars(records[0]) if not key.startswith("_")]
    rows = [{c: getattr(r, c) for c in columns} for r in records]
    update_keys = [c for c in columns if c not in except_cols_on_update]
    for i in range(0, len(rows), chunk_size):
        chunk = rows[i:i + chunk_size]
        insert_stmt = insert(table_cls).values(chunk)
        # ``inserted`` refers to the values of the row being inserted. On a
        # duplicate key, each updatable column is overwritten with the new value.
        update_columns = {x.name: x for x in insert_stmt.inserted if x.name in update_keys}
        upsert_stmt = insert_stmt.on_duplicate_key_update(**update_columns)
        session.execute(upsert_stmt)
        if commit_on_chunk:
            session.commit()
插入:
from src.config.db_business_245 import session as session245
from util.database.sql_execute import upsert
def insert_stats_surf_station_typhoon_history(insert_list: list):
    """
    Upsert SurfStationTyphoonHistory rows into the business-245 database.

    When a row already exists (MySQL duplicate-key match), its ``typhoon_id``
    and ``station_id`` are kept and every other column is overwritten.

    Args:
        insert_list: SurfStationTyphoonHistory instances to write.
    """
    keep_on_update = ["typhoon_id", "station_id"]
    upsert(session245, SurfStationTyphoonHistory, insert_list,
           except_cols_on_update=keep_on_update)
执行原生 SQL
if __name__ == '__main__':
    # ``text`` was used without being imported.
    from sqlalchemy import text

    # The table name carries the year (CAWSData_2023). Keep the time range
    # within that year; presumably there is one table per year, so verify this.
    start_time = '2023-01-01 00:00:00'
    end_time = '2023-01-01 23:59:00'
    # The time range is passed as bound parameters, not formatted into the SQL.
    result_list = session5.execute(
        text("select 区站号,日期时间,一小时雨量,最大风速,极大风速,最大风速对应时间,极大风速对应时间 from zdz.CAWSData_2023 where 日期时间 between :start_time and :end_time"),
        params={'start_time': start_time, 'end_time': end_time},
    )
    for row in result_list:
        # Unpack in SELECT-column order.
        (station_id_c, observe_time, pre_1h, win_max, win_max_inst,
         win_max_ot, win_max_inst_ot) = row
        print(station_id_c, observe_time, pre_1h, win_max, win_max_inst,
              win_max_ot, win_max_inst_ot)