# eis/py/comlib/db/DBOperator.py
import logging as d
from typing import Optional, Dict, List, Any, Union

import pandas as pd
from sqlalchemy import create_engine, MetaData, Table, select, update, delete, insert
from sqlalchemy.dialects.mysql import insert as mysql_insert
from sqlalchemy.dialects.postgresql import insert as pg_insert
from sqlalchemy.engine import URL
from sqlalchemy.exc import SQLAlchemyError

from log.LogUtil import LogUtil
#import os
#os.environ['IBM_DB_HOME'] = r"D:\win\IBM\SQLLIB\dsdriver"
#os.add_dll_directory(r"D:\win\IBM\SQLLIB\BIN")
#import ibm_db
"""
DBOperator - 通用数据库操作类
功能:
- 提供基础的数据库连接管理
- 支持SQL查询更新删除和插入操作
- 支持连接池配置
- 支持多种数据库类型
支持的数据库类型:
- MySQL
- PostgreSQL
- SQLite
- Oracle
- IBM DB2 (需额外配置)
依赖:
- sqlalchemy >= 1.4.0
- pandas
- ibm_db (仅用于IBM DB2)
- pymysql (用于MySQL)
- psycopg2 (用于PostgreSQL)
Author
- Author : zoufuzhou
- Date : 2025-05-21 16:34:37
- Description : db operator tool
- LastEditTime : 2025-05-21 16:34:37
"""
class DBOperator:
    """Generic database helper built on SQLAlchemy.

    One instance wraps a pooled engine for a single database and exposes
    query/update/insert/delete helpers plus bulk DataFrame writes.
    """

    def __init__(self,
                 db_type: str,
                 host: str,
                 port: int,
                 database: str,
                 username: str,
                 password: str,
                 pool_size: int = 5):
        """Store the connection settings and build the engine.

        :param db_type: database kind (mysql/postgresql/sqlite/oracle/db2)
        :param host: server host name or address
        :param port: server port number
        :param database: database name (file path for SQLite)
        :param username: login user
        :param password: login password
        :param pool_size: connection pool size
        """
        self.db_type = db_type.lower()
        self.host, self.port = host, port
        self.database = database
        self.username, self.password = username, password
        # The engine is created eagerly so misconfiguration fails fast here.
        self.engine = self._create_engine(pool_size)
        self.metadata = MetaData()
def _create_engine(self, pool_size: int):
"""创建SQLAlchemy引擎"""
try:
# 生成连接URL
url = self._generate_connection_url()
d.info(f"创建{self.db_type}数据库引擎")
return create_engine(
url,
pool_size=pool_size,
max_overflow=10,
pool_pre_ping=True,
pool_recycle=3600
)
except Exception as e:
d.error(f"创建数据库引擎失败: {e}")
raise
def _generate_connection_url(self) -> URL:
"""生成数据库连接URL"""
if self.db_type == "mysql":
return URL.create(
"mysql+pymysql",
username=self.username,
password=self.password,
host=self.host,
port=self.port,
database=self.database
)
elif self.db_type == "postgresql":
return URL.create(
"postgresql+psycopg2",
username=self.username,
password=self.password,
host=self.host,
port=self.port,
database=self.database
)
elif self.db_type == "sqlite":
return f"sqlite:///{self.database}"
elif self.db_type == "oracle":
return URL.create(
"oracle+cx_oracle",
username=self.username,
password=self.password,
host=self.host,
port=self.port,
database=self.database
)
elif self.db_type == "db2":
return URL.create(
"db2+ibm_db",
username=self.username,
password=self.password,
host=self.host,
port=self.port,
database=self.database
)
else:
raise ValueError(f"不支持的数据库类型: {self.db_type}")
def execute_query(self, table_name: str,
filters: Optional[Dict[str, Any]] = None,
columns: Optional[List[str]] = None,
as_dataframe: bool = False) -> Union[List[Dict[str, Any]], pd.DataFrame]:
"""
执行查询
:param table_name: 表名
:param filters: 过滤条件字典 {列名: }
:param columns: 要查询的列名列表
:param as_dataframe: 是否返回DataFrame
:return: 结果字典列表或DataFrame
"""
try:
table = Table(table_name, self.metadata, autoload_with=self.engine)
# 构建查询
query = select(table)
# 添加过滤条件(不区分大小写)
if filters:
# 查找匹配的列名(不区分大小写)
conditions = []
for k, v in filters.items():
matching_cols = [col for col in table.columns if col.name.lower() == k.lower()]
if matching_cols:
col_name = matching_cols[0].name
col_obj = getattr(table.c, col_name)
# 处理特殊操作符
if isinstance(v, dict) and len(v) == 1:
op, val = next(iter(v.items()))
if op == '$like':
conditions.append(col_obj.like(val))
elif op == '$gt':
conditions.append(col_obj > val)
elif op == '$gte':
conditions.append(col_obj >= val)
elif op == '$lt':
conditions.append(col_obj < val)
elif op == '$lte':
conditions.append(col_obj <= val)
elif op == '$between':
if isinstance(val, (list, tuple)) and len(val) == 2:
conditions.append(col_obj.between(val[0], val[1]))
else:
# 默认等于操作
conditions.append(col_obj == v)
if conditions:
query = query.where(*conditions)
# 指定列(不区分大小写)
if columns:
# 查找匹配的列名(不区分大小写)
selected_cols = []
for col in columns:
matching_cols = [table_col for table_col in table.columns
if table_col.name.lower() == col.lower()]
if matching_cols:
selected_cols.append(getattr(table.c, matching_cols[0].name))
if selected_cols:
query = query.with_only_columns(*selected_cols)
# 执行查询
with self.engine.connect() as conn:
result = conn.execute(query)
# 处理DB2结果集确保列名和值正确映射
columns = result.keys()
result_data = [dict(zip(columns, row)) for row in result]
if as_dataframe:
return pd.DataFrame(result_data)
return result_data
except SQLAlchemyError as e:
d.error(f"查询执行失败: {e}")
return pd.DataFrame() if as_dataframe else []
def execute_update(self, table_name: str,
data: Dict[str, Any],
filters: Optional[Dict[str, Any]] = None) -> int:
"""
执行更新
:param table_name: 表名
:param data: 更新数据字典 {列名: 新值}
:param filters: 过滤条件字典 {列名: }
:return: 影响的行数
"""
try:
table = Table(table_name, self.metadata, autoload_with=self.engine)
# 构建更新语句
stmt = update(table).values(**data)
# 添加过滤条件(不区分大小写)
if filters:
# 查找匹配的列名(不区分大小写)
conditions = []
for k, v in filters.items():
matching_cols = [col for col in table.columns if col.name.lower() == k.lower()]
if matching_cols:
conditions.append(getattr(table.c, matching_cols[0].name) == v)
if conditions:
stmt = stmt.where(*conditions)
# 执行更新
with self.engine.begin() as conn:
result = conn.execute(stmt)
return result.rowcount
except SQLAlchemyError as e:
d.error(f"更新执行失败: {e}")
return 0
def execute_insert(self, table_name: str, data: Dict[str, Any]) -> int:
"""
执行插入
:param table_name: 表名
:param data: 插入数据字典 {列名: }
:return: 插入的行数(通常为1)
"""
try:
table = Table(table_name, self.metadata, autoload_with=self.engine)
# 执行插入(不区分大小写)
with self.engine.begin() as conn:
# 转换数据键名为表列名(不区分大小写)
insert_data = {}
for k, v in data.items():
matching_cols = [col for col in table.columns if col.name.lower() == k.lower()]
if matching_cols:
insert_data[matching_cols[0].name] = v
if insert_data:
result = conn.execute(insert(table).values(**insert_data))
return result.rowcount
except SQLAlchemyError as e:
d.error(f"插入执行失败: {e}")
return 0
def execute_delete(self, table_name: str,
filters: Dict[str, Any]) -> int:
"""
执行删除
:param table_name: 表名
:param filters: 过滤条件字典 {列名: }
:return: 影响的行数
"""
try:
table = Table(table_name, self.metadata, autoload_with=self.engine)
# 执行删除(不区分大小写)
with self.engine.begin() as conn:
# 查找匹配的列名(不区分大小写)
conditions = []
for k, v in filters.items():
matching_cols = [col for col in table.columns if col.name.lower() == k.lower()]
if matching_cols:
conditions.append(getattr(table.c, matching_cols[0].name) == v)
if conditions:
result = conn.execute(delete(table).where(*conditions))
return result.rowcount
except SQLAlchemyError as e:
d.error(f"删除执行失败: {e}")
return 0
def write_dataframe(self,
df: pd.DataFrame,
table_name: str,
upsert_keys: Optional[List[str]] = None,
batch_size: int = 1000) -> int:
"""
将DataFrame写入数据库表(支持存在更新不存在插入)
:param df: 要写入的DataFrame
:param table_name: 目标表名
:param upsert_keys: 用于判断记录是否存在的键列名列表
:param batch_size: 批量处理大小
:return: 处理的行数
"""
if df.empty:
d.warning("空DataFrame跳过写入")
return 0
try:
# 反射表结构
table = Table(table_name, self.metadata, autoload_with=self.engine)
# 检查DataFrame列是否匹配表结构(不区分大小写)
df_columns_lower = [col.lower() for col in df.columns]
missing_cols = [col.name for col in table.columns
if not col.nullable and col.name.lower() not in df_columns_lower]
if missing_cols:
raise ValueError(f"DataFrame缺少必填列: {missing_cols}")
# 记录表结构详情
table_info = {col.name: str(col.type) for col in table.columns}
# d.info(f"表{table_name}结构: {table_info}")
# 转换数据类型以匹配表定义(不区分大小写)
for col in table.columns:
# 查找匹配的DataFrame列名(不区分大小写)
matching_cols = [df_col for df_col in df.columns
if df_col.lower() == col.name.lower()]
if matching_cols:
df_col = matching_cols[0] # 取第一个匹配的列名
if 'INT' in str(col.type):
df[df_col] = df[df_col].astype('int64')
elif 'FLOAT' in str(col.type) or 'DECIMAL' in str(col.type):
df[df_col] = df[df_col].astype('float64')
elif 'DATE' in str(col.type) or 'TIME' in str(col.type):
df[df_col] = pd.to_datetime(df[df_col])
total_rows = len(df)
processed_rows = 0
# 分批处理
for i in range(0, total_rows, batch_size):
batch = df.iloc[i:i + batch_size]
data_list = batch.to_dict('records')
with self.engine.begin() as conn:
if upsert_keys is None:
# 纯插入模式,失败则转为更新
for data in data_list:
try:
# 转换数据键名为表列名(不区分大小写)
insert_data = {}
for k, v in data.items():
matching_cols = [col for col in table.columns if col.name.lower() == k.lower()]
if matching_cols:
insert_data[matching_cols[0].name] = v
if insert_data:
conn.execute(insert(table).values(**insert_data))
except SQLAlchemyError as e:
# 如果是唯一约束冲突,则转为更新
if "unique constraint" in str(e).lower() or "duplicate key" in str(e).lower():
d.warning(f"插入冲突,转为更新: {e}")
# 尝试提取冲突键
conflict_keys = []
if self.db_type == 'postgresql':
# PostgreSQL错误信息通常包含约束名
conflict_keys = [k for k in data.keys()
if k in table.primary_key.columns]
elif self.db_type == 'mysql':
# MySQL错误信息通常包含键名
conflict_keys = [k for k in data.keys()
if k in table.primary_key.columns]
elif self.db_type == 'db2':
# DB2错误信息通常包含键名
conflict_keys = [k for k in data.keys()
if k in table.primary_key.columns]
if not conflict_keys:
conflict_keys = list(data.keys())
# 执行更新
update_stmt = update(table).values(**data)
for key in conflict_keys:
update_stmt = update_stmt.where(
getattr(table.c, key) == data[key]
)
conn.execute(update_stmt)
else:
# 其他错误仍然抛出
# d.error(f"插入失败 - 表结构: {table_info}")
d.error(f"尝试插入的数据: {data}")
if "NOT NULL" in str(e):
missing = [col.name for col in table.columns
if not col.nullable and col.name not in data]
raise ValueError(f"缺少必填字段: {missing}") from e
raise
else:
# Upsert模式
if self.db_type == 'postgresql':
# PostgreSQL使用ON CONFLICT语法
for data in data_list:
# 转换数据键名为表列名(不区分大小写)
insert_data = {}
for k, v in data.items():
matching_cols = [col for col in table.columns if col.name.lower() == k.lower()]
if matching_cols:
insert_data[matching_cols[0].name] = v
# 转换upsert_keys为表列名(不区分大小写)
actual_upsert_keys = []
for key in upsert_keys:
matching_cols = [col for col in table.columns if col.name.lower() == key.lower()]
if matching_cols:
actual_upsert_keys.append(matching_cols[0].name)
if insert_data and actual_upsert_keys:
update_dict = {k: v for k, v in insert_data.items()
if k not in actual_upsert_keys}
stmt = insert(table).values(**insert_data)
stmt = stmt.on_conflict_do_update(
index_elements=actual_upsert_keys,
set_=update_dict
)
conn.execute(stmt)
elif self.db_type == 'mysql':
# MySQL使用ON DUPLICATE KEY UPDATE语法
for data in data_list:
# 转换数据键名为表列名(不区分大小写)
insert_data = {}
for k, v in data.items():
matching_cols = [col for col in table.columns if col.name.lower() == k.lower()]
if matching_cols:
insert_data[matching_cols[0].name] = v
# 转换upsert_keys为表列名(不区分大小写)
actual_upsert_keys = []
for key in upsert_keys:
matching_cols = [col for col in table.columns if col.name.lower() == key.lower()]
if matching_cols:
actual_upsert_keys.append(matching_cols[0].name)
if insert_data and actual_upsert_keys:
update_dict = {k: v for k, v in insert_data.items()
if k not in actual_upsert_keys}
stmt = insert(table).values(**insert_data)
stmt = stmt.on_duplicate_key_update(**update_dict)
conn.execute(stmt)
elif self.db_type == 'db2':
# DB2使用MERGE语法
for data in data_list:
# 转换数据键名为表列名(不区分大小写)
insert_data = {}
for k, v in data.items():
matching_cols = [col for col in table.columns if col.name.lower() == k.lower()]
if matching_cols:
insert_data[matching_cols[0].name] = v
# 转换upsert_keys为表列名(不区分大小写)
actual_upsert_keys = []
for key in upsert_keys:
matching_cols = [col for col in table.columns if col.name.lower() == key.lower()]
if matching_cols:
actual_upsert_keys.append(matching_cols[0].name)
if insert_data and actual_upsert_keys:
# 构建MERGE语句
update_cols = [col for col in insert_data.keys()
if col not in actual_upsert_keys]
merge_set = {col: insert_data[col] for col in update_cols}
merge_condition = {key: insert_data[key] for key in actual_upsert_keys}
# 先尝试更新
update_stmt = update(table).values(**merge_set)
for key in actual_upsert_keys:
update_stmt = update_stmt.where(
getattr(table.c, key) == merge_condition[key]
)
result = conn.execute(update_stmt)
# 如果没有更新到行,则插入
if result.rowcount == 0:
conn.execute(insert(table).values(**insert_data))
else:
raise ValueError(f"不支持的数据库类型: {self.db_type}")
processed_rows += len(data_list)
d.info(f"已处理{processed_rows}/{total_rows}行到{table_name}")
return processed_rows
except SQLAlchemyError as e:
d.error(f"DataFrame写入失败: {e}")
raise
def close(self):
"""关闭所有连接"""
self.engine.dispose()
d.info("数据库连接已关闭")
# Usage example: ad-hoc smoke test against a local DB2 instance.
if __name__ == "__main__":
    # NOTE(review): credentials are hard-coded for local testing — move them
    # into configuration before this leaves a development machine.
    db2_db = DBOperator(
        db_type="db2",
        host="192.168.137.100",
        port=50000,
        database="appdb",
        username="som",
        password="dscdsc1"
    )
    LogUtil.init("app")
    # Query with a case-insensitive LIKE filter and get a DataFrame back.
    db2_results = db2_db.execute_query(
        "T_TCM_DEC",
        {"entId": {"$like": 'test%'}},
        as_dataframe=True,
    )
    print("DB2经理员工:", db2_results)
    # Further worked examples (execute_insert, write_dataframe in plain-insert
    # and upsert modes, db2_db.close()) live in version-control history.