# File-viewer metadata captured with the source: 426 lines, 15 KiB, Python
#import os
|
||
#os.environ['IBM_DB_HOME'] = r"D:\win\IBM\SQLLIB\dsdriver"
|
||
#os.add_dll_directory(r"D:\win\IBM\SQLLIB\BIN")
|
||
from typing import Optional, List, Dict, Any, TypeAlias, Union
|
||
import pandas as pd
|
||
import ibm_db
|
||
from log.LogUtil import LogUtil
|
||
import logging as d
|
||
|
||
# Type alias for a DB2 connection handle
|
||
DB2Connection: TypeAlias = Any
|
||
"""
|
||
DB2 database utility class

Features:
- Connection pool management
- Automatic reconnection mechanism
- Safe parameter handling
- Basic query execution

Supported databases:
- IBM DB2 (Windows/Linux)

Dependencies:
- ibm_db (IBM DB2 Python driver)
- pandas (data processing)
- logging (log output)

Author: zoufuzhou
Version: 1.0
|
||
"""
|
||
|
||
|
||
class DB2Tool:
    """Small pooled helper around the ``ibm_db`` driver.

    Connections are created eagerly in ``__init__`` and kept in a plain list
    (``connection_pool``). Query helpers borrow a connection, run one
    statement (or one transaction) and hand the connection back.

    The instance can be used as a context manager; leaving the ``with`` block
    closes every pooled connection.
    """

    def __init__(self,
                 host: str,
                 port: int,
                 database: str,
                 username: str,
                 password: str,
                 conn_pool_size: int = 5):
        """
        Store the connection parameters and pre-build the connection pool.

        :param host: server host name or address
        :param port: server port
        :param database: database name
        :param username: user name
        :param password: password
        :param conn_pool_size: maximum number of pooled connections
        """
        self.host = host
        self.port = port
        self.database = database
        self.username = username
        self.password = password
        self.conn_pool_size = conn_pool_size
        self.connection_pool = []

        self._warn_if_sanitized()

        try:
            self._init_connection_pool()
        except Exception as e:
            d.error(f"初始化连接池失败: {e}")
            raise

    def __enter__(self):
        """Return the tool itself so it can be used in a ``with`` statement."""
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        """Close all pooled connections; exceptions are not suppressed."""
        self.close_all_connections()
        return False

    # ------------------------------------------------------------------
    # Connection handling
    # ------------------------------------------------------------------

    @staticmethod
    def _sanitize(value: Any) -> str:
        """Return ``value`` as text that is safe inside a connection string.

        Every character that is not alphanumeric, ``_``, ``-`` or ``.`` is
        replaced by ``_`` so that a value cannot inject extra ``KEY=VALUE;``
        pairs. Text that cannot be encoded as UTF-8 is reduced to ASCII
        alphanumerics.
        """
        text = value if isinstance(value, str) else str(value)
        try:
            text.encode('utf-8')
        except UnicodeError:
            return ''.join(c if ord(c) < 128 and c.isalnum() else '_' for c in text)
        return ''.join(c if c.isalnum() or c in '_-.' else '_' for c in text)

    def _warn_if_sanitized(self) -> None:
        """Log a warning for each parameter that sanitising would alter.

        The altered value is what gets sent to the server, so a credential
        containing special characters will not authenticate. Only the
        parameter name is logged, never its value.
        """
        checked = (
            ("database", self.database),
            ("host", self.host),
            ("username", self.username),
            ("password", self.password),
        )
        for label, raw in checked:
            if self._sanitize(raw) != str(raw):
                d.warning(f"连接参数 {label} 含有特殊字符,已被替换为下划线,可能导致连接失败")

    def _build_conn_str(self, redact_password: bool = False) -> str:
        """Build the ``ibm_db`` connection string.

        :param redact_password: when True the password is replaced by ``***``
            so that the result can be written to a log
        :return: ``KEY=VALUE;`` connection string
        """
        password = "***" if redact_password else self._sanitize(self.password)
        parts = [
            f"DATABASE={self._sanitize(self.database)}",
            f"HOSTNAME={self._sanitize(self.host)}",
            f"PORT={self.port}",
            "PROTOCOL=TCPIP",
            f"UID={self._sanitize(self.username)}",
            f"PWD={password}",
            "CONNECTTIMEOUT=10",
            # Intended to force UTF-8 on both sides.
            # NOTE(review): confirm the DB2 CLI driver honours these keywords.
            "CLIENT_LOCALE=en_US.UTF-8",
            "DB_LOCALE=en_US.UTF-8",
        ]
        return ";".join(parts) + ";"

    @staticmethod
    def _connect_error_text(exc: Exception) -> str:
        """Return the driver's connection error message, or ``str(exc)``."""
        try:
            message = ibm_db.conn_errormsg()
        except Exception:
            message = ""
        return message or str(exc)

    @staticmethod
    def _close_quietly(conn: "DB2Connection") -> None:
        """Close ``conn``; a failure is logged instead of raised."""
        try:
            ibm_db.close(conn)
        except Exception as e:
            d.warning(f"关闭连接失败: {e}")

    @staticmethod
    def _free_statement(stmt: Any) -> None:
        """Free a statement handle if one was created; never raises."""
        if stmt is None:
            return
        try:
            ibm_db.free_stmt(stmt)
        except Exception as e:
            d.warning(f"释放语句句柄失败: {e}")

    def _init_connection_pool(self):
        """Fill the pool with up to ``conn_pool_size`` connections.

        Stops at the first connection that cannot be created: each attempt
        already retries internally, and ``get_connection`` creates
        connections on demand when the pool is empty.
        """
        for _ in range(self.conn_pool_size):
            conn = self._create_connection()
            if not conn:
                d.warning("无法创建连接,停止预建连接池;后续将按需创建连接")
                break
            self.connection_pool.append(conn)

    def _create_connection(self, retries: int = 3) -> "Optional[DB2Connection]":
        """Open one connection, retrying on failure.

        :param retries: maximum number of attempts (default 3)
        :return: an active connection, or None when every attempt failed
        """
        import time  # local import: the module does not import time at top level

        conn_str = self._build_conn_str()
        # The real connection string contains the password; only the redacted
        # form may reach the log.
        loggable = self._build_conn_str(redact_password=True)
        last_error = None

        for attempt in range(1, retries + 1):
            try:
                d.debug(f"尝试连接(第{attempt}次): {loggable}")
                conn = ibm_db.connect(conn_str, "", "")

                if ibm_db.active(conn):
                    d.info(f"DB2连接成功(第{attempt}次尝试)")
                    return conn

                last_error = "connection is not active"
                ibm_db.close(conn)
            except Exception as e:
                last_error = self._connect_error_text(e)
                d.warning(f"DB2连接失败(第{attempt}次): {last_error}")

            if attempt < retries:
                time.sleep(1)  # wait one second before the next attempt

        d.error(f"所有连接尝试失败,最后错误: {last_error}")
        return None

    def get_connection(self) -> "Optional[DB2Connection]":
        """Take a connection from the pool, creating one if the pool is empty.

        :return: a connection, or None when a new one could not be created
        """
        if not self.connection_pool:
            d.warning("连接池为空,创建新连接")
            return self._create_connection()
        return self.connection_pool.pop()

    def release_connection(self, conn: "DB2Connection"):
        """Return ``conn`` to the pool, or close it when the pool is full.

        :param conn: connection obtained from ``get_connection``; a falsy
            value is ignored
        """
        if not conn:
            return
        if len(self.connection_pool) < self.conn_pool_size:
            self.connection_pool.append(conn)
        else:
            # Surplus connections must be closed, otherwise they leak.
            self._close_quietly(conn)

    def close_all_connections(self):
        """Close every pooled connection and empty the pool."""
        while self.connection_pool:
            self._close_quietly(self.connection_pool.pop())
        d.info("所有DB2连接已关闭")

    # ------------------------------------------------------------------
    # Statement execution
    # ------------------------------------------------------------------

    def execute_query(self,
                      sql: str,
                      params: Optional[tuple] = None,
                      as_dataframe: bool = False) -> Union[List[Dict[str, Any]], pd.DataFrame]:
        """
        Run a query and fetch all rows.

        Errors are logged and an empty result is returned.

        :param sql: SQL text, optionally with ``?`` parameter markers
        :param params: values for the parameter markers
        :param as_dataframe: return a DataFrame instead of a list of dicts
        :return: list of row dicts, or a DataFrame when ``as_dataframe`` is True
        """
        conn = self.get_connection()
        if not conn:
            return pd.DataFrame() if as_dataframe else []

        stmt = None  # stays None when prepare() fails, so cleanup can skip it
        try:
            stmt = ibm_db.prepare(conn, sql)
            if params:
                ibm_db.execute(stmt, params)
            else:
                ibm_db.execute(stmt)

            rows = []
            row = ibm_db.fetch_assoc(stmt)
            while row:
                rows.append(row)
                row = ibm_db.fetch_assoc(stmt)

            if as_dataframe:
                return pd.DataFrame(rows) if rows else pd.DataFrame()
            return rows
        except Exception as e:
            d.error(f"查询执行失败: {e}")
            return pd.DataFrame() if as_dataframe else []
        finally:
            self._free_statement(stmt)
            self.release_connection(conn)

    def _run_update(self, sql: str, params: Optional[tuple] = None) -> int:
        """Run one data-changing statement and let any error propagate.

        :return: row count reported by ``ibm_db.num_rows``
        :raises ConnectionError: when no connection could be obtained
        """
        conn = self.get_connection()
        if not conn:
            raise ConnectionError("无法获取DB2连接")

        stmt = None
        try:
            stmt = ibm_db.prepare(conn, sql)
            if params:
                ibm_db.execute(stmt, params)
            else:
                ibm_db.execute(stmt)
            return ibm_db.num_rows(stmt)
        finally:
            self._free_statement(stmt)
            self.release_connection(conn)

    def execute_update(self, sql: str, params: Optional[tuple] = None) -> int:
        """
        Run one data-changing statement.

        Errors are logged and reported as ``0``.

        :param sql: SQL text, optionally with ``?`` parameter markers
        :param params: values for the parameter markers
        :return: number of affected rows, 0 on failure
        """
        try:
            return self._run_update(sql, params)
        except Exception as e:
            d.error(f"更新执行失败: {e}")
            return 0

    def _execute_in_transaction(self, conn: "DB2Connection", sql: str, params: Optional[tuple]) -> None:
        """Prepare, bind and execute one statement on ``conn``.

        The statement handle is always freed. Any error is logged and
        re-raised so that the caller can roll back.
        """
        stmt = None
        try:
            d.debug(f"准备执行SQL: {sql}")
            stmt = ibm_db.prepare(conn, sql)

            if params is not None:
                d.debug(f"参数列表: {params}")
                for position, value in enumerate(params, start=1):
                    # NOTE(review): strings are bound as UTF-8 bytes, as in the
                    # original code; confirm this is what the driver expects
                    # for character columns.
                    bound = value.encode('utf-8') if isinstance(value, str) else value
                    try:
                        ibm_db.bind_param(stmt, position, bound, ibm_db.SQL_PARAM_INPUT)
                    except Exception:
                        d.error(f"参数绑定失败(位置{position}): {value}")
                        raise
                d.debug("执行带绑定参数的SQL")
            else:
                d.debug("执行无参数SQL")

            ibm_db.execute(stmt)
        except Exception as e:
            d.error(f"执行SQL语句失败: {sql}\n错误: {e}")
            raise
        finally:
            self._free_statement(stmt)

    def execute_transaction(self, sql_list: List[str], params_list: Optional[List[tuple]] = None) -> bool:
        """
        Run several statements in one transaction.

        All statements are committed together; on any error the transaction
        is rolled back.

        :param sql_list: SQL statements, executed in order
        :param params_list: parameters for the statement at the same index;
            statements beyond the end of this list run without parameters
        :return: True when committed, False otherwise
        """
        conn = self.get_connection()
        if not conn:
            return False

        try:
            ibm_db.autocommit(conn, ibm_db.SQL_AUTOCOMMIT_OFF)

            for i, sql in enumerate(sql_list):
                has_params = bool(params_list) and i < len(params_list)
                self._execute_in_transaction(conn, sql, params_list[i] if has_params else None)

            ibm_db.commit(conn)
            return True
        except Exception as e:
            d.error(f"事务执行失败: {e}", exc_info=True)
            try:
                ibm_db.rollback(conn)
            except Exception as rollback_err:
                d.error(f"回滚失败: {rollback_err}")
            return False
        finally:
            try:
                ibm_db.autocommit(conn, ibm_db.SQL_AUTOCOMMIT_ON)
            except Exception as e:
                # A connection whose autocommit state is unknown must not go
                # back into the pool; close it instead of leaking it.
                d.error(f"释放连接失败: {e}")
                self._close_quietly(conn)
            else:
                self.release_connection(conn)

    # ------------------------------------------------------------------
    # DataFrame writing
    # ------------------------------------------------------------------

    @staticmethod
    def _to_db_value(value: Any) -> Any:
        """Map pandas missing values (NaN, NaT, NA) to None; pass others through."""
        if pd.api.types.is_scalar(value) and pd.isna(value):
            return None
        return value

    @staticmethod
    def _build_write_sql(table_name: str,
                         columns: List[str],
                         upsert_keys: Optional[List[str]] = None,
                         update_columns: Optional[List[str]] = None) -> str:
        """Build the INSERT or MERGE statement used by ``write_dataframe``.

        Table and column names are interpolated into the SQL text (identifiers
        cannot be bound as parameters), so they must come from trusted code.

        :param table_name: target table
        :param columns: column names, in parameter order
        :param upsert_keys: key columns; None or empty selects a plain INSERT
        :param update_columns: columns updated on a match; None means every
            non-key column
        :return: SQL text with one ``?`` marker per column
        """
        column_list = ",".join(columns)
        placeholders = ",".join(["?"] * len(columns))

        if not upsert_keys:
            return f"INSERT INTO {table_name} ({column_list}) VALUES ({placeholders})"

        if update_columns is None:
            update_columns = [col for col in columns if col not in upsert_keys]

        merge_condition = " AND ".join(f"TARGET.{key}=EXCLUDED.{key}" for key in upsert_keys)
        insert_values = ",".join(f"EXCLUDED.{col}" for col in columns)

        clauses = [
            f"MERGE INTO {table_name} AS TARGET",
            f"USING (VALUES ({placeholders})) AS EXCLUDED ({column_list})",
            f"ON ({merge_condition})",
        ]
        if update_columns:
            # With nothing to update, "UPDATE SET" would be empty and the
            # statement invalid, so the clause is left out entirely.
            merge_set = ",".join(f"{col}=EXCLUDED.{col}" for col in update_columns)
            clauses.append(f"WHEN MATCHED THEN UPDATE SET {merge_set}")
        clauses.append(f"WHEN NOT MATCHED THEN INSERT ({column_list}) VALUES ({insert_values})")
        return "\n".join(clauses)

    def write_dataframe(self,
                        df: pd.DataFrame,
                        table_name: str,
                        batch_size: int = 1000,
                        use_transaction: bool = True,
                        upsert_keys: Optional[List[str]] = None,
                        update_columns: Optional[List[str]] = None) -> int:
        """
        Write a DataFrame to a table, inserting or (with ``upsert_keys``) merging.

        Rows are written in batches. With ``use_transaction`` each batch is
        its own transaction, so batches committed before a failure stay
        committed.

        :param df: data to write; column names are used as table column names
        :param table_name: target table
        :param batch_size: number of rows per batch, must be positive
        :param use_transaction: write each batch in one transaction
        :param upsert_keys: key columns that decide whether a row exists;
            None or empty means plain insert
        :param update_columns: columns to update on a match (None: all
            non-key columns)
        :return: number of rows processed
        :raises ValueError: when ``batch_size`` is not positive
        :raises RuntimeError: when a row or batch could not be written
        """
        if df.empty:
            d.warning("空DataFrame,跳过写入")
            return 0
        if batch_size <= 0:
            raise ValueError(f"batch_size must be positive, got {batch_size}")

        columns = [str(col) for col in df.columns]
        is_upsert = bool(upsert_keys)
        sql = self._build_write_sql(table_name, columns, upsert_keys, update_columns)

        total_rows = len(df)
        processed_rows = 0

        for start in range(0, total_rows, batch_size):
            batch = df.iloc[start:start + batch_size]
            params = [
                tuple(self._to_db_value(value) for value in row)
                for row in batch.itertuples(index=False)
            ]

            if use_transaction:
                if not self.execute_transaction([sql] * len(params), params):
                    raise RuntimeError(f"写入表{table_name}失败")
            else:
                for param in params:
                    try:
                        affected = self._run_update(sql, param)
                    except Exception as e:
                        # execute_update() would hide this as 0, which upsert
                        # mode cannot tell apart from "nothing changed".
                        d.error(f"更新执行失败: {e}")
                        raise RuntimeError(f"写入表{table_name}失败") from e
                    # The row count is only checked for plain inserts.
                    if affected <= 0 and not is_upsert:
                        raise RuntimeError(f"写入表{table_name}失败")

            processed_rows += len(batch)
            d.info(f"已处理{processed_rows}/{total_rows}行到{table_name}")

        return processed_rows

    def __del__(self):
        """Best-effort cleanup when the object is garbage-collected."""
        try:
            if hasattr(self, "connection_pool"):
                self.close_all_connections()
        except Exception:
            # A destructor must not raise; at interpreter shutdown the driver
            # and logging modules may already be torn down.
            pass
|
||
|
||
# Usage example
if __name__ == "__main__":
    import os  # local import: only the demo needs it

    # Initialise logging first: DB2Tool logs while it builds its pool in
    # __init__, which previously happened before LogUtil.init was called.
    LogUtil.init("app")

    # Connection parameters can be overridden through environment variables.
    # NOTE(review): the fallback values are the demo credentials that were
    # hard-coded here; remove them before using this outside a test setup.
    db2 = DB2Tool(
        host=os.environ.get("DB2_HOST", "192.168.137.100"),
        port=int(os.environ.get("DB2_PORT", "50000")),
        database=os.environ.get("DB2_DATABASE", "appdb"),
        username=os.environ.get("DB2_USERNAME", "som"),
        password=os.environ.get("DB2_PASSWORD", "dscdsc1"),
    )

    try:
        # Query example
        results = db2.execute_query("SELECT * FROM T_SOM_STAT WHERE STEEL_GRADE = ? ", ("steel",), as_dataframe=True)
        print("查询结果:", results)

        # Update example
        affected_rows = db2.execute_update("UPDATE T_TCM_DEC SET USE_TEN_DEVIATE = ? WHERE entid = ?", (1, "test1"))
        print("影响行数:", affected_rows)

        # Transaction example
        transaction_sqls = [
            "INSERT INTO T_TCM_DEC (entid, USE_TEN_DEVIATE) VALUES ('test1', 1)",
            "UPDATE T_TCM_DEC SET USE_TEN_DEVIATE = 2 WHERE entid = 'test1'",
        ]
        success = db2.execute_transaction(transaction_sqls)
        print("事务执行结果:", success)

        # DataFrame write example (upsert mode)
        try:
            test_data_update = {
                'entid': ['test4', 'test5', 'test9'],  # test9 is the new record
                'USE_TEN_DEVIATE': [2, 9, 1],
            }
            test_df_update = pd.DataFrame(test_data_update)
            print(test_df_update)

            rows_processed = db2.write_dataframe(
                df=test_df_update,
                table_name="T_TCM_DEC",
                batch_size=2,
                upsert_keys=['entid'],  # entid is the unique key
                update_columns=['USE_TEN_DEVIATE'],  # columns updated on a match
            )
            print(f"成功处理{rows_processed}行(更新2行,插入1行)")
        except Exception as e:
            print(f"DataFrame写入失败: {e}")
    finally:
        # Close the pooled connections explicitly instead of relying on __del__.
        db2.close_all_connections()
|