eis/py/comlib/db/db2_tool.py

426 lines
15 KiB
Python
Raw Normal View History

#import os
#os.environ['IBM_DB_HOME'] = r"D:\win\IBM\SQLLIB\dsdriver"
#os.add_dll_directory(r"D:\win\IBM\SQLLIB\BIN")
from typing import Optional, List, Dict, Any, TypeAlias, Union
import pandas as pd
import ibm_db
from log.LogUtil import LogUtil
import logging as d
# Type alias used to annotate DB2 connection handles. It is plain ``Any``,
# so it documents intent only and adds no static type checking.
DB2Connection: TypeAlias = Any
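"""
Utility class for DB2 database operations.
Features:
- Connection pool management
- Automatic reconnection
- Safe handling of connection parameters
- Basic query execution
Supported databases:
- IBM DB2 (Windows/Linux)
Dependencies:
- ibm_db (IBM DB2 Python driver)
- pandas (data processing)
- logging (log output)
Author: zoufuzhou
Version: 1.0
"""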
class DB2Tool:
    """Helper for running SQL against IBM DB2 through ``ibm_db``.

    The instance keeps a simple list-based pool of open connections. The
    query/update/transaction helpers borrow one connection, run their
    statement(s) and hand the connection back. Database failures are logged
    and reported through the return value (empty result, ``0`` or ``False``)
    instead of being raised.

    NOTE(review): the pool is a plain list with no locking - confirm that an
    instance is never shared between threads.
    """

    def __init__(self,
                 host: str,
                 port: int,
                 database: str,
                 username: str,
                 password: str,
                 conn_pool_size: int = 5):
        """Store the connection settings and pre-open the pooled connections.

        :param host: database server host name or address
        :param port: database server port
        :param database: database name
        :param username: user name
        :param password: password
        :param conn_pool_size: number of connections to open up front; also
            the maximum number of idle connections kept in the pool
        :raises Exception: anything raised while filling the pool is logged
            and re-raised. A connection that simply cannot be opened does not
            raise; the pool then ends up smaller than requested.
        """
        self.host = host
        self.port = port
        self.database = database
        self.username = username
        self.password = password
        self.conn_pool_size = conn_pool_size
        self.connection_pool: List[Any] = []
        try:
            self._init_connection_pool()
        except Exception as e:
            d.error(f"初始化连接池失败: {e}")
            raise

    def _init_connection_pool(self):
        """Open ``conn_pool_size`` connections and keep the successful ones."""
        for _ in range(self.conn_pool_size):
            conn = self._create_connection()
            if conn:
                self.connection_pool.append(conn)

    def _build_connection_string(self, redact: bool = False) -> str:
        """Build the ``KEY=value;`` connection string passed to ``ibm_db.connect``.

        Values are inserted verbatim. Rewriting characters (as an earlier
        version did) silently changes passwords and host names, so the only
        check is for ``;``, which would end the value early.

        :param redact: when true the password is replaced by ``***`` so the
            result can be written to a log
        :return: the connection string, terminated by ``;``
        :raises ValueError: if any setting contains ``;``
        """
        settings = {
            "DATABASE": self.database,
            "HOSTNAME": self.host,
            "PORT": self.port,
            "UID": self.username,
            "PWD": self.password,
        }
        for key, value in settings.items():
            if ";" in str(value):
                raise ValueError(f"连接参数 {key} 不能包含分号")
        password = "***" if redact else str(self.password)
        # NOTE(review): the two *_LOCALE keywords are kept from the original
        # implementation, which intended to force UTF-8 on both sides;
        # confirm that the DB2 CLI driver honours them.
        parts = [
            f"DATABASE={self.database}",
            f"HOSTNAME={self.host}",
            f"PORT={self.port}",
            "PROTOCOL=TCPIP",
            f"UID={self.username}",
            f"PWD={password}",
            "CONNECTTIMEOUT=10",
            "CLIENT_LOCALE=en_US.UTF-8",
            "DB_LOCALE=en_US.UTF-8",
        ]
        return ";".join(parts) + ";"

    # The alias name is quoted in annotations so it is resolved lazily.
    def _create_connection(self, retries: int = 3) -> Optional["DB2Connection"]:
        """Open one connection, retrying on failure.

        :param retries: maximum number of attempts (default 3); one second is
            waited between failed attempts
        :return: an active connection, or ``None`` when the settings are
            invalid or every attempt failed
        """
        import time

        try:
            conn_str = self._build_connection_string()
            # Only the redacted form is ever logged, never the password.
            loggable = self._build_connection_string(redact=True)
        except ValueError as e:
            d.error(f"连接参数无效: {e}")
            return None

        last_error = None
        for attempt in range(1, retries + 1):
            d.debug(f"尝试连接(第{attempt}次): {loggable}")
            try:
                conn = ibm_db.connect(conn_str, "", "")
                # Make sure the handle is really usable before returning it.
                if ibm_db.active(conn):
                    d.info(f"DB2连接成功(第{attempt}次尝试)")
                    return conn
                ibm_db.close(conn)
            except Exception as e:
                # A connection failure is reported by the exception itself or
                # by conn_errormsg(); stmt_errormsg() is for statement errors.
                last_error = str(e) or ibm_db.conn_errormsg()
                d.warning(f"DB2连接失败(第{attempt}次): {last_error}")
                if attempt < retries:
                    time.sleep(1)
        d.error(f"所有连接尝试失败,最后错误: {last_error}")
        return None

    def get_connection(self) -> Optional["DB2Connection"]:
        """Take a connection from the pool, opening a new one if it is empty.

        :return: a connection, or ``None`` if a new one could not be opened
        """
        if not self.connection_pool:
            d.warning("连接池为空,创建新连接")
            return self._create_connection()
        return self.connection_pool.pop()

    def release_connection(self, conn: "DB2Connection"):
        """Return a connection to the pool, or close it when the pool is full.

        :param conn: connection obtained from :meth:`get_connection`; falsy
            values are ignored
        """
        if not conn:
            return
        if len(self.connection_pool) < self.conn_pool_size:
            self.connection_pool.append(conn)
            return
        # Surplus connection (created while the pool was empty): close it
        # explicitly instead of just dropping the reference.
        try:
            ibm_db.close(conn)
        except Exception as e:
            d.warning(f"关闭多余连接失败: {e}")

    @staticmethod
    def _free_statement(stmt) -> None:
        """Free a statement handle if one was created; never raises."""
        if not stmt:
            return
        try:
            ibm_db.free_stmt(stmt)
        except Exception as e:
            d.warning(f"释放语句句柄失败: {e}")

    def execute_query(self,
                      sql: str,
                      params: Optional[tuple] = None,
                      as_dataframe: bool = False) -> Union[List[Dict[str, Any]], pd.DataFrame]:
        """Run a query and fetch all rows.

        :param sql: SQL text, optionally with ``?`` parameter markers
        :param params: values for the parameter markers
        :param as_dataframe: return a DataFrame instead of a list of dicts
        :return: the fetched rows; an empty list / empty DataFrame when no
            connection is available or the query fails (the error is logged)
        """
        conn = self.get_connection()
        if not conn:
            return pd.DataFrame() if as_dataframe else []
        # Initialised up front so ``finally`` is safe when prepare() raises.
        stmt = None
        try:
            stmt = ibm_db.prepare(conn, sql)
            if params:
                ibm_db.execute(stmt, params)
            else:
                ibm_db.execute(stmt)
            result = []
            row = ibm_db.fetch_assoc(stmt)
            while row:
                result.append(row)
                row = ibm_db.fetch_assoc(stmt)
            if as_dataframe:
                return pd.DataFrame(result) if result else pd.DataFrame()
            return result
        except Exception as e:
            d.error(f"查询执行失败: {e}")
            return pd.DataFrame() if as_dataframe else []
        finally:
            self._free_statement(stmt)
            self.release_connection(conn)

    def execute_update(self, sql: str, params: Optional[tuple] = None) -> int:
        """Run a single data-changing statement.

        :param sql: SQL text, optionally with ``?`` parameter markers
        :param params: values for the parameter markers
        :return: the row count reported by ``ibm_db.num_rows``; ``0`` when no
            connection is available or the statement fails (error is logged)
        """
        conn = self.get_connection()
        if not conn:
            return 0
        # Initialised up front so ``finally`` is safe when prepare() raises.
        stmt = None
        try:
            stmt = ibm_db.prepare(conn, sql)
            if params:
                ibm_db.execute(stmt, params)
            else:
                ibm_db.execute(stmt)
            return ibm_db.num_rows(stmt)
        except Exception as e:
            d.error(f"更新执行失败: {e}")
            return 0
        finally:
            self._free_statement(stmt)
            self.release_connection(conn)

    def _execute_in_transaction(self, conn: "DB2Connection", sql: str, params: Optional[tuple]) -> None:
        """Prepare, bind and execute one statement of a transaction.

        :param conn: connection with autocommit switched off
        :param sql: SQL text
        :param params: values to bind, or ``None`` for a statement without
            parameters
        :raises Exception: any driver error is logged and re-raised so the
            caller can roll back
        """
        stmt = None
        try:
            d.debug(f"准备执行SQL: {sql}")
            stmt = ibm_db.prepare(conn, sql)
            if params is not None:
                d.debug(f"参数列表: {params}")
                for param_idx, param_val in enumerate(params, start=1):
                    if isinstance(param_val, str):
                        # Strings are bound as UTF-8 bytes, as in the original
                        # implementation. NOTE(review): confirm this is still
                        # required with the ibm_db version in use.
                        try:
                            ibm_db.bind_param(stmt, param_idx, param_val.encode('utf-8'), ibm_db.SQL_PARAM_INPUT)
                        except Exception:
                            d.error(f"参数绑定失败(位置{param_idx}): {param_val}")
                            raise
                    else:
                        ibm_db.bind_param(stmt, param_idx, param_val, ibm_db.SQL_PARAM_INPUT)
                d.debug("执行带绑定参数的SQL")
            else:
                d.debug("执行无参数SQL")
            ibm_db.execute(stmt)
        except Exception as e:
            d.error(f"执行SQL语句失败: {sql}\n错误: {e}")
            raise
        finally:
            # Free the handle on failure too, not only on success.
            self._free_statement(stmt)

    def execute_transaction(self, sql_list: List[str], params_list: Optional[List[tuple]] = None) -> bool:
        """Run several statements in one transaction.

        :param sql_list: SQL statements, executed in order
        :param params_list: parameters per statement, matched by position;
            statements beyond its length (or with a ``None`` entry) run
            without parameters
        :return: ``True`` after a successful commit; ``False`` when no
            connection is available or any statement failed (rolled back)
        """
        conn = self.get_connection()
        if not conn:
            return False
        try:
            ibm_db.autocommit(conn, ibm_db.SQL_AUTOCOMMIT_OFF)
            for i, sql in enumerate(sql_list):
                has_params = bool(params_list) and i < len(params_list)
                self._execute_in_transaction(conn, sql, params_list[i] if has_params else None)
            ibm_db.commit(conn)
            return True
        except Exception as e:
            d.error(f"事务执行失败: {e}", exc_info=True)
            try:
                ibm_db.rollback(conn)
            except Exception as rollback_err:
                d.error(f"回滚失败: {rollback_err}")
            return False
        finally:
            try:
                # The connection goes back to the pool only if autocommit
                # could be restored; otherwise later execute_update() calls
                # on it would never be committed.
                ibm_db.autocommit(conn, ibm_db.SQL_AUTOCOMMIT_ON)
                self.release_connection(conn)
            except Exception as e:
                d.error(f"释放连接失败: {e}")

    def close_all_connections(self):
        """Close every pooled connection and empty the pool.

        A failure to close one connection is logged and does not stop the
        remaining connections from being closed.
        """
        for conn in self.connection_pool:
            try:
                ibm_db.close(conn)
            except Exception as e:
                d.warning(f"关闭连接失败: {e}")
        self.connection_pool.clear()
        d.info("所有DB2连接已关闭")

    @staticmethod
    def _build_write_sql(table_name: str,
                         columns: List[str],
                         upsert_keys: Optional[List[str]] = None,
                         update_columns: Optional[List[str]] = None) -> str:
        """Build the INSERT or MERGE statement used by :meth:`write_dataframe`.

        Table and column names are inserted into the SQL text as-is (they
        cannot be bound as parameters), so they must come from trusted code.

        :param table_name: target table
        :param columns: column names, in the order the values are bound
        :param upsert_keys: ``None`` for a plain INSERT, otherwise the key
            columns of the MERGE match condition
        :param update_columns: columns updated on a match; ``None`` means all
            non-key columns. When no column is left to update, the
            ``WHEN MATCHED`` clause is omitted (insert-if-missing).
        :return: SQL text with one ``?`` marker per column
        :raises ValueError: if ``upsert_keys`` is empty or names a column
            that is not in ``columns``
        """
        column_list = ",".join(columns)
        placeholders = ",".join(["?"] * len(columns))
        if upsert_keys is None:
            return f"INSERT INTO {table_name} ({column_list}) VALUES ({placeholders})"

        if not upsert_keys:
            raise ValueError("upsert_keys不能为空")
        missing = [key for key in upsert_keys if key not in columns]
        if missing:
            raise ValueError(f"upsert_keys中的列不在DataFrame中: {missing}")
        if update_columns is None:
            update_columns = [col for col in columns if col not in upsert_keys]

        merge_condition = " AND ".join(f"TARGET.{key}=EXCLUDED.{key}" for key in upsert_keys)
        clauses = [
            f"MERGE INTO {table_name} AS TARGET",
            f"USING (VALUES ({placeholders})) AS EXCLUDED ({column_list})",
            f"ON ({merge_condition})",
        ]
        if update_columns:
            merge_set = ",".join(f"{col}=EXCLUDED.{col}" for col in update_columns)
            clauses.append(f"WHEN MATCHED THEN UPDATE SET {merge_set}")
        insert_values = ",".join(f"EXCLUDED.{col}" for col in columns)
        clauses.append(f"WHEN NOT MATCHED THEN INSERT ({column_list}) VALUES ({insert_values})")
        return "\n".join(clauses)

    def write_dataframe(self,
                        df: pd.DataFrame,
                        table_name: str,
                        batch_size: int = 1000,
                        use_transaction: bool = True,
                        upsert_keys: Optional[List[str]] = None,
                        update_columns: Optional[List[str]] = None) -> int:
        """Write a DataFrame to a table (insert, or update-else-insert).

        :param df: rows to write; column names must match the table columns
        :param table_name: target table
        :param batch_size: rows per batch (one transaction per batch when
            ``use_transaction`` is true)
        :param use_transaction: write each batch through
            :meth:`execute_transaction`; otherwise rows are written one by
            one through :meth:`execute_update`
        :param upsert_keys: key columns deciding whether a row already exists;
            ``None`` selects plain INSERT
        :param update_columns: columns to update on a match (``None`` = all
            non-key columns)
        :return: number of rows processed (``0`` for an empty DataFrame)
        :raises ValueError: for a non-positive ``batch_size`` or invalid
            ``upsert_keys``
        :raises RuntimeError: when a batch (transaction mode) or a plain
            INSERT (row mode) fails
        """
        if df.empty:
            d.warning("空DataFrame跳过写入")
            return 0
        if batch_size < 1:
            raise ValueError("batch_size必须大于0")

        columns = [str(col) for col in df.columns]
        total_rows = len(df)
        sql = self._build_write_sql(table_name, columns, upsert_keys, update_columns)

        processed_rows = 0
        for start in range(0, total_rows, batch_size):
            batch = df.iloc[start:start + batch_size]
            params = [tuple(row) for row in batch.itertuples(index=False)]
            if use_transaction:
                if not self.execute_transaction([sql] * len(params), params):
                    raise RuntimeError(f"写入表{table_name}失败")
            else:
                for param in params:
                    affected = self.execute_update(sql, param)
                    # Only a plain INSERT is checked: every inserted row must
                    # report at least one affected row.
                    if affected <= 0 and upsert_keys is None:
                        raise RuntimeError(f"写入表{table_name}失败")
            processed_rows += len(batch)
            d.info(f"已处理{processed_rows}/{total_rows}行到{table_name}")
        return processed_rows

    def __del__(self):
        """Best-effort cleanup of pooled connections on garbage collection."""
        try:
            # The attribute is missing when __init__ failed before creating it.
            if hasattr(self, "connection_pool"):
                self.close_all_connections()
        except Exception:
            # Deliberately silent: during interpreter shutdown module globals
            # may already be gone and a finalizer must not raise.
            pass
# Usage example
if __name__ == "__main__":
    # Initialise logging first: the DB2Tool constructor already opens the
    # pooled connections and logs every attempt.
    # NOTE(review): assumes LogUtil.init only configures logging and does not
    # need an existing DB2Tool instance - confirm.
    LogUtil.init("app")
    # Connection settings of the demo database.
    # NOTE(review): credentials are hard-coded for the demo; load them from
    # configuration or the environment before using this outside a sandbox.
    db2 = DB2Tool(
        host="192.168.137.100",
        port=50000,
        database="appdb",
        username="som",
        password="dscdsc1"
    )
    try:
        # Query example
        results = db2.execute_query("SELECT * FROM T_SOM_STAT WHERE STEEL_GRADE = ? ", ("steel",), as_dataframe=True)
        print("查询结果:", results)
        # Update example
        affected_rows = db2.execute_update("UPDATE T_TCM_DEC SET USE_TEN_DEVIATE = ? WHERE entid = ?", (1, "test1"))
        print("影响行数:", affected_rows)
        # Transaction example
        transaction_sqls = [
            "INSERT INTO T_TCM_DEC (entid, USE_TEN_DEVIATE) VALUES ('test1', 1)",
            "UPDATE T_TCM_DEC SET USE_TEN_DEVIATE = 2 WHERE entid = 'test1'"
        ]
        success = db2.execute_transaction(transaction_sqls)
        print("事务执行结果:", success)
        # DataFrame write example (upsert mode)
        try:
            test_data_update = {
                'entid': ['test4', 'test5', 'test9'],  # test9 is the new record
                'USE_TEN_DEVIATE': [2, 9, 1]
            }
            test_df_update = pd.DataFrame(test_data_update)
            print(test_df_update)
            rows_processed = db2.write_dataframe(
                df=test_df_update,
                table_name="T_TCM_DEC",
                batch_size=2,
                upsert_keys=['entid'],  # entid is the unique key
                update_columns=['USE_TEN_DEVIATE']  # columns updated on a match
            )
            print(f"成功处理{rows_processed}行(更新2行,插入1行)")
        except Exception as e:
            print(f"DataFrame写入失败: {e}")
    finally:
        # Close the pooled connections explicitly instead of relying on
        # __del__ running at interpreter exit.
        db2.close_all_connections()