eis/py/comlib/db/DBOperator.py

589 lines
26 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

from sqlalchemy import create_engine, MetaData, Table, select, update, delete, insert
from sqlalchemy.engine import URL
from sqlalchemy.exc import SQLAlchemyError
from typing import Optional, Dict, List, Any, Union
import pandas as pd
from log.LogUtil import LogUtil
import logging as d
#import os
#os.environ['IBM_DB_HOME'] = r"D:\win\IBM\SQLLIB\dsdriver"
#os.add_dll_directory(r"D:\win\IBM\SQLLIB\BIN")
#import ibm_db
"""
DBOperator - 通用数据库操作类
功能:
- 提供基础的数据库连接管理
- 支持SQL查询、更新、删除和插入操作
- 支持连接池配置
- 支持多种数据库类型
支持的数据库类型:
- MySQL
- PostgreSQL
- SQLite
- Oracle
- IBM DB2 (需额外配置)
依赖:
- sqlalchemy >= 1.4.0
- pandas
- ibm_db (仅用于IBM DB2)
- pymysql (用于MySQL)
- psycopg2 (用于PostgreSQL)
Author
- Author : zoufuzhou
- Date : 2025-05-21 16:34:37
- Description : db operator tool
- LastEditTime : 2025-05-21 16:34:37
"""
class DBOperator:
    """Generic database operation helper.

    Wraps SQLAlchemy to provide connection management plus query,
    insert, update, delete and DataFrame bulk-write helpers for
    MySQL, PostgreSQL, SQLite, Oracle and IBM DB2.
    """

    def __init__(self,
                 db_type: str,
                 host: str,
                 port: int,
                 database: str,
                 username: str,
                 password: str,
                 pool_size: int = 5):
        """Initialize the connection engine and reflection metadata.

        :param db_type: database kind (mysql/postgresql/sqlite/oracle/db2)
        :param host: server host address
        :param port: server port number
        :param database: database name (file path for SQLite)
        :param username: login user
        :param password: login password
        :param pool_size: connection-pool size
        """
        # Normalize once so every later dispatch compares lowercase.
        self.db_type = db_type.lower()
        self.host = host
        self.port = port
        self.database = database
        self.username = username
        self.password = password
        # Create the engine eagerly so configuration errors surface at
        # construction time rather than on first query.
        self.engine = self._create_engine(pool_size)
        self.metadata = MetaData()
def _create_engine(self, pool_size: int):
    """Create and return a pooled SQLAlchemy engine.

    :param pool_size: number of connections kept in the pool
    :raises Exception: re-raises any engine-creation failure after logging
    """
    try:
        connection_url = self._generate_connection_url()
        d.info(f"创建{self.db_type}数据库引擎")
        engine = create_engine(
            connection_url,
            pool_size=pool_size,
            max_overflow=10,      # extra connections beyond the pool on bursts
            pool_pre_ping=True,   # validate connections before handing them out
            pool_recycle=3600,    # recycle connections hourly to dodge timeouts
        )
        return engine
    except Exception as e:
        d.error(f"创建数据库引擎失败: {e}")
        raise
def _generate_connection_url(self) -> URL:
"""生成数据库连接URL"""
if self.db_type == "mysql":
return URL.create(
"mysql+pymysql",
username=self.username,
password=self.password,
host=self.host,
port=self.port,
database=self.database
)
elif self.db_type == "postgresql":
return URL.create(
"postgresql+psycopg2",
username=self.username,
password=self.password,
host=self.host,
port=self.port,
database=self.database
)
elif self.db_type == "sqlite":
return f"sqlite:///{self.database}"
elif self.db_type == "oracle":
return URL.create(
"oracle+cx_oracle",
username=self.username,
password=self.password,
host=self.host,
port=self.port,
database=self.database
)
elif self.db_type == "db2":
return URL.create(
"db2+ibm_db",
username=self.username,
password=self.password,
host=self.host,
port=self.port,
database=self.database
)
else:
raise ValueError(f"不支持的数据库类型: {self.db_type}")
def execute_query(self, table_name: str,
                  filters: Optional[Dict[str, Any]] = None,
                  columns: Optional[List[str]] = None,
                  as_dataframe: bool = False) -> Union[List[Dict[str, Any]], pd.DataFrame]:
    """Run a SELECT against *table_name*.

    :param table_name: target table (reflected from the live database)
    :param filters: {column: value} for equality, or
        {column: {"$op": value}} with op one of
        $like/$gt/$gte/$lt/$lte/$between; column names are matched
        case-insensitively and unknown names are silently skipped
    :param columns: columns to select (case-insensitive); unknown names
        are silently skipped
    :param as_dataframe: return a :class:`pd.DataFrame` instead of a
        list of dicts
    :return: list of row dicts or DataFrame; empty on SQL error
    """
    try:
        table = Table(table_name, self.metadata, autoload_with=self.engine)
        query = select(table)

        if filters:
            conditions = []
            for key, value in filters.items():
                # Resolve the real column name case-insensitively.
                matches = [c for c in table.columns if c.name.lower() == key.lower()]
                if not matches:
                    continue  # unknown filter keys are ignored
                col_obj = getattr(table.c, matches[0].name)
                # A single-entry dict selects a comparison operator;
                # anything else is treated as a plain equality value.
                if isinstance(value, dict) and len(value) == 1:
                    op, val = next(iter(value.items()))
                    if op == '$like':
                        conditions.append(col_obj.like(val))
                    elif op == '$gt':
                        conditions.append(col_obj > val)
                    elif op == '$gte':
                        conditions.append(col_obj >= val)
                    elif op == '$lt':
                        conditions.append(col_obj < val)
                    elif op == '$lte':
                        conditions.append(col_obj <= val)
                    elif op == '$between':
                        # $between requires a 2-element list/tuple;
                        # malformed values are skipped, as before.
                        if isinstance(val, (list, tuple)) and len(val) == 2:
                            conditions.append(col_obj.between(val[0], val[1]))
                else:
                    conditions.append(col_obj == value)
            if conditions:
                # Multiple where() clauses are AND-ed by SQLAlchemy.
                query = query.where(*conditions)

        if columns:
            selected = []
            for requested in columns:
                matches = [c for c in table.columns
                           if c.name.lower() == requested.lower()]
                if matches:
                    selected.append(getattr(table.c, matches[0].name))
            if selected:
                query = query.with_only_columns(*selected)

        with self.engine.connect() as conn:
            result = conn.execute(query)
            # Use the result's own key order so names and values stay
            # aligned (important for DB2 result sets). Bug fix: the
            # previous code rebound the `columns` parameter here.
            result_columns = result.keys()
            rows = [dict(zip(result_columns, row)) for row in result]
        if as_dataframe:
            return pd.DataFrame(rows)
        return rows
    except SQLAlchemyError as e:
        d.error(f"查询执行失败: {e}")
        return pd.DataFrame() if as_dataframe else []
def execute_update(self, table_name: str,
                   data: Dict[str, Any],
                   filters: Optional[Dict[str, Any]] = None) -> int:
    """Run an UPDATE on *table_name*.

    :param table_name: target table (reflected from the live database)
    :param data: {column: new value}; must be non-empty
    :param filters: {column: value}, matched case-insensitively; keys
        that do not match any column are ignored. With no usable
        condition the update applies to ALL rows (unchanged behavior).
    :return: number of affected rows (0 on error or empty *data*)
    """
    # Guard instead of letting .values() raise inside the except block
    # (the old code reached the same result via a logged exception).
    if not data:
        return 0
    try:
        table = Table(table_name, self.metadata, autoload_with=self.engine)
        stmt = update(table).values(**data)
        if filters:
            conditions = []
            for key, value in filters.items():
                matches = [c for c in table.columns
                           if c.name.lower() == key.lower()]
                if matches:
                    conditions.append(getattr(table.c, matches[0].name) == value)
            if conditions:
                stmt = stmt.where(*conditions)
        # engine.begin() commits on success, rolls back on exception.
        with self.engine.begin() as conn:
            return conn.execute(stmt).rowcount
    except SQLAlchemyError as e:
        d.error(f"更新执行失败: {e}")
        return 0
def execute_insert(self, table_name: str, data: Dict[str, Any]) -> int:
    """Insert one row into *table_name*.

    Keys in *data* are matched to table columns case-insensitively;
    keys with no matching column are dropped.

    :param table_name: target table (reflected from the live database)
    :param data: {column: value}
    :return: inserted row count (0 when nothing matched or on error)
    """
    try:
        table = Table(table_name, self.metadata, autoload_with=self.engine)
        # Re-key the payload onto the table's real column names.
        insert_data = {}
        for key, value in data.items():
            matches = [c for c in table.columns if c.name.lower() == key.lower()]
            if matches:
                insert_data[matches[0].name] = value
        if not insert_data:
            # Bug fix: the old code fell through here and implicitly
            # returned None, violating the declared -> int contract.
            return 0
        with self.engine.begin() as conn:
            return conn.execute(insert(table).values(**insert_data)).rowcount
    except SQLAlchemyError as e:
        d.error(f"插入执行失败: {e}")
        return 0
def execute_delete(self, table_name: str,
                   filters: Dict[str, Any]) -> int:
    """Run a DELETE on *table_name*.

    :param table_name: target table (reflected from the live database)
    :param filters: {column: value}; names matched case-insensitively.
        Keys matching no column are ignored; if no condition can be
        built, NOTHING is deleted (guards against an accidental
        full-table delete) and 0 is returned.
    :return: number of deleted rows (0 on error)
    """
    try:
        table = Table(table_name, self.metadata, autoload_with=self.engine)
        conditions = []
        for key, value in filters.items():
            matches = [c for c in table.columns if c.name.lower() == key.lower()]
            if matches:
                conditions.append(getattr(table.c, matches[0].name) == value)
        if not conditions:
            # Bug fix: the old code implicitly returned None here,
            # violating the declared -> int contract.
            return 0
        with self.engine.begin() as conn:
            return conn.execute(delete(table).where(*conditions)).rowcount
    except SQLAlchemyError as e:
        d.error(f"删除执行失败: {e}")
        return 0
def write_dataframe(self,
                    df: pd.DataFrame,
                    table_name: str,
                    upsert_keys: Optional[List[str]] = None,
                    batch_size: int = 1000) -> int:
    """Write a DataFrame into *table_name* (insert or upsert).

    Column names are matched to table columns case-insensitively.

    Bug fixes versus the previous version:
    - PostgreSQL/MySQL upserts used the generic core ``insert``, which
      has no ``on_conflict_do_update`` / ``on_duplicate_key_update``;
      the dialect-specific inserts are now imported and used.
    - dtype coercion mutated the caller's DataFrame in place; we now
      work on a copy.

    :param df: data to write
    :param table_name: target table (reflected from the live database)
    :param upsert_keys: columns identifying an existing row. When None,
        rows are plainly inserted and unique-constraint conflicts fall
        back to an UPDATE keyed on the primary key (or on every
        provided key if none is a primary-key column).
    :param batch_size: rows committed per transaction
    :return: number of rows processed
    :raises ValueError: missing NOT NULL columns, or unsupported db type
        in upsert mode
    :raises SQLAlchemyError: re-raised after logging on write failure
    """
    if df.empty:
        d.warning("空DataFrame跳过写入")
        return 0

    def match_column(name):
        # Real column name for *name* (case-insensitive), or None.
        for col in table.columns:
            if col.name.lower() == name.lower():
                return col.name
        return None

    def to_row(record):
        # Re-key a record dict onto actual table column names,
        # dropping keys with no matching column.
        row = {}
        for k, v in record.items():
            real = match_column(k)
            if real is not None:
                row[real] = v
        return row

    try:
        table = Table(table_name, self.metadata, autoload_with=self.engine)

        # Every NOT NULL column must be present in the DataFrame.
        df_columns_lower = [c.lower() for c in df.columns]
        missing_cols = [col.name for col in table.columns
                        if not col.nullable and col.name.lower() not in df_columns_lower]
        if missing_cols:
            raise ValueError(f"DataFrame缺少必填列: {missing_cols}")

        # Copy so the dtype coercions below do not mutate the caller's
        # DataFrame (the previous implementation modified it in place).
        df = df.copy()

        # Coerce DataFrame dtypes to match the table definition.
        for col in table.columns:
            matching = [c for c in df.columns if c.lower() == col.name.lower()]
            if not matching:
                continue
            df_col = matching[0]
            col_type = str(col.type)
            if 'INT' in col_type:
                df[df_col] = df[df_col].astype('int64')
            elif 'FLOAT' in col_type or 'DECIMAL' in col_type:
                df[df_col] = df[df_col].astype('float64')
            elif 'DATE' in col_type or 'TIME' in col_type:
                df[df_col] = pd.to_datetime(df[df_col])

        # Resolve upsert key names once, outside the row loop.
        actual_upsert_keys = None
        if upsert_keys is not None:
            actual_upsert_keys = [k for k in (match_column(key) for key in upsert_keys)
                                  if k is not None]

        total_rows = len(df)
        processed_rows = 0
        for start in range(0, total_rows, batch_size):
            data_list = df.iloc[start:start + batch_size].to_dict('records')
            with self.engine.begin() as conn:
                if upsert_keys is None:
                    # Plain-insert mode: fall back to UPDATE on conflict.
                    for data in data_list:
                        try:
                            insert_data = to_row(data)
                            if insert_data:
                                conn.execute(insert(table).values(**insert_data))
                        except SQLAlchemyError as e:
                            err_text = str(e).lower()
                            if "unique constraint" in err_text or "duplicate key" in err_text:
                                d.warning(f"插入冲突,转为更新: {e}")
                                # Prefer primary-key columns as the conflict
                                # key (the old per-dialect branches were
                                # identical copies of this); fall back to
                                # every provided key.
                                conflict_keys = [k for k in data.keys()
                                                 if k in table.primary_key.columns]
                                if not conflict_keys:
                                    conflict_keys = list(data.keys())
                                update_stmt = update(table).values(**data)
                                for key in conflict_keys:
                                    update_stmt = update_stmt.where(
                                        getattr(table.c, key) == data[key]
                                    )
                                conn.execute(update_stmt)
                            else:
                                d.error(f"尝试插入的数据: {data}")
                                if "NOT NULL" in str(e):
                                    missing = [col.name for col in table.columns
                                               if not col.nullable and col.name not in data]
                                    raise ValueError(f"缺少必填字段: {missing}") from e
                                raise
                elif self.db_type == 'postgresql':
                    # Dialect-specific insert carries ON CONFLICT support.
                    from sqlalchemy.dialects.postgresql import insert as pg_insert
                    for data in data_list:
                        insert_data = to_row(data)
                        if insert_data and actual_upsert_keys:
                            update_dict = {k: v for k, v in insert_data.items()
                                           if k not in actual_upsert_keys}
                            stmt = pg_insert(table).values(**insert_data)
                            stmt = stmt.on_conflict_do_update(
                                index_elements=actual_upsert_keys,
                                set_=update_dict
                            )
                            conn.execute(stmt)
                elif self.db_type == 'mysql':
                    # Dialect-specific insert carries ON DUPLICATE KEY UPDATE.
                    from sqlalchemy.dialects.mysql import insert as mysql_insert
                    for data in data_list:
                        insert_data = to_row(data)
                        if insert_data and actual_upsert_keys:
                            update_dict = {k: v for k, v in insert_data.items()
                                           if k not in actual_upsert_keys}
                            stmt = mysql_insert(table).values(**insert_data)
                            stmt = stmt.on_duplicate_key_update(**update_dict)
                            conn.execute(stmt)
                elif self.db_type == 'db2':
                    # Emulate MERGE: UPDATE first, INSERT when no row matched.
                    for data in data_list:
                        insert_data = to_row(data)
                        if insert_data and actual_upsert_keys:
                            update_vals = {k: v for k, v in insert_data.items()
                                           if k not in actual_upsert_keys}
                            update_stmt = update(table).values(**update_vals)
                            for key in actual_upsert_keys:
                                update_stmt = update_stmt.where(
                                    getattr(table.c, key) == insert_data[key]
                                )
                            result = conn.execute(update_stmt)
                            if result.rowcount == 0:
                                conn.execute(insert(table).values(**insert_data))
                else:
                    raise ValueError(f"不支持的数据库类型: {self.db_type}")
            processed_rows += len(data_list)
            d.info(f"已处理{processed_rows}/{total_rows}行到{table_name}")
        return processed_rows
    except SQLAlchemyError as e:
        d.error(f"DataFrame写入失败: {e}")
        raise
def close(self):
    """Release every pooled connection held by the engine."""
    # dispose() closes checked-in connections and discards the pool.
    self.engine.dispose()
    d.info("数据库连接已关闭")
# Usage example: exercise the operator against a DB2 instance.
if __name__ == "__main__":
    operator = DBOperator(
        db_type="db2",
        host="192.168.137.100",
        port=50000,
        database="appdb",
        username="som",
        password="dscdsc1"
    )
    LogUtil.init("app")

    # Query rows whose entId starts with "test" (case-insensitive
    # column matching, $like operator), returned as a DataFrame.
    db2_results = operator.execute_query(
        "T_TCM_DEC", {"entId": {"$like": 'test%'}}, as_dataframe=True
    )
    print("DB2经理员工:", db2_results)

    # Further examples (left disabled — they mutate the database):
    #
    #   # Single-row insert:
    #   operator.execute_insert("T_TCM_DEC",
    #                           {"entId": "test3", "f1e_ten_deviate": 1})
    #
    #   # Bulk write, pure-insert mode:
    #   test_df = pd.DataFrame({'entid': ['test7', 'test8', 'test9'],
    #                           'F1e_ten_deviate': [1, 0, 1]})
    #   operator.write_dataframe(df=test_df, table_name="T_TCM_DEC",
    #                            batch_size=2)
    #
    #   # Bulk write, upsert mode keyed on entid:
    #   test_df_update = pd.DataFrame({'entid': ['test7', 'test8', 'test10'],
    #                                  'f1e_ten_deviate': [2, 8, 10]})
    #   operator.write_dataframe(df=test_df_update, table_name="T_TCM_DEC",
    #                            upsert_keys=['entid'], batch_size=2)
    #
    #   # Release the pool when done:
    #   operator.close()