# -*- coding: utf-8 -*-
"""
|
||
数据清洗工具模块
|
||
|
||
该模块提供用于清洗和预处理Pandas DataFrame的功能,包括:
|
||
- 缺失值处理(删除、填充、插值)
|
||
- 重复行删除
|
||
- 数据编码管理(通过_encoding_maps属性)
|
||
|
||
主要类:
|
||
DataFrameCleaner: 提供链式调用的数据清洗接口,支持多种清洗策略
|
||
|
||
Author:
|
||
- Author : zoufuzhou
|
||
- Date : 2025-05-21 16:34:37
|
||
- LastEditTime : 2025-05-21 16:34:37
|
||
"""
|
||
|
||
import pandas as pd
|
||
import numpy as np
|
||
from typing import Union, List, Dict, Optional, Callable
|
||
|
||
|
||
class DataFrameCleaner:
    """Chainable cleaning helper for a pandas DataFrame.

    The cleaner works on its own copy of the frame given to the constructor.
    Each cleaning method updates that working copy and returns ``self`` so
    calls can be chained; ``get_cleaned_data`` returns the current working
    frame and ``reset`` restores the data as it was at construction time.
    """

    def __init__(self, df: pd.DataFrame):
        """Create a cleaner for *df*.

        :param df: DataFrame to clean; it is copied, never modified in place.
        """
        self.df = df.copy()
        # Untouched snapshot used by reset().
        self._original_df = df.copy()
        # column name -> {original label: integer code}; filled by label encoding.
        self._encoding_maps = {}

    def get_cleaned_data(self) -> pd.DataFrame:
        """Return the current working DataFrame (not a copy)."""
        return self.df

    def reset(self) -> 'DataFrameCleaner':
        """Discard all cleaning steps and restore the original data."""
        self.df = self._original_df.copy()
        return self

    def handle_missing_values(self,
                              strategy: str = 'drop',
                              fill_value: Optional[Union[int, float, str]] = None,
                              columns: Optional[Union[str, List[str]]] = None) -> 'DataFrameCleaner':
        """Handle missing values.

        :param strategy: 'drop' removes rows with a missing value in the
            selected columns, 'fill' replaces missing values with
            ``fill_value``, 'interpolate' interpolates numeric columns.
            Any other value leaves the data unchanged.
        :param fill_value: replacement value, required when strategy='fill'.
        :param columns: column name or list of names to process
            (default: all columns).
        :raises ValueError: if strategy='fill' and ``fill_value`` is None.
        """
        # Accept a bare column name like the other methods of this class do.
        if isinstance(columns, str):
            columns = [columns]
        if columns is None or len(columns) == 0:
            cols = self.df.columns
        else:
            cols = columns

        if strategy == 'drop':
            self.df = self.df.dropna(subset=cols)
        elif strategy == 'fill':
            if fill_value is None:
                raise ValueError("fill策略需要指定fill_value")
            self.df[cols] = self.df[cols].fillna(fill_value)
        elif strategy == 'interpolate':
            # Interpolation is only meaningful for numeric data; passing
            # object/string columns to interpolate() warns or raises in
            # recent pandas, so those columns are left untouched.
            numeric_cols = self.df[cols].select_dtypes(include='number').columns
            if len(numeric_cols) > 0:
                self.df[numeric_cols] = self.df[numeric_cols].interpolate()

        return self

    def remove_duplicates(self,
                          subset: Optional[List[str]] = None,
                          keep: Union[str, bool] = 'first') -> 'DataFrameCleaner':
        """Remove duplicate rows.

        :param subset: columns used to detect duplicates (default: all columns).
        :param keep: which duplicate to keep: 'first', 'last', or False to
            drop every duplicated row.
        """
        self.df = self.df.drop_duplicates(subset=subset, keep=keep)
        return self

    def convert_types(self,
                      type_map: Dict[str, str]) -> 'DataFrameCleaner':
        """Convert column dtypes.

        :param type_map: mapping {column name: target type}. 'datetime' is
            converted with ``pd.to_datetime``; every other value ('int',
            'float', 'str', 'bool', 'category', ...) is passed to
            ``Series.astype``. Columns not present in the frame are skipped.
        """
        for col, dtype in type_map.items():
            if col not in self.df.columns:
                continue
            if dtype == 'datetime':
                self.df[col] = pd.to_datetime(self.df[col])
            else:
                self.df[col] = self.df[col].astype(dtype)
        return self

    def handle_outliers(self,
                        column: str,
                        method: str = 'iqr',
                        threshold: float = 1.5) -> 'DataFrameCleaner':
        """Remove rows whose value in *column* is an outlier.

        Rows where the value is missing are kept: a missing value is not an
        outlier, and ``handle_missing_values`` is the place to deal with it.
        Nothing happens if the column is absent or not numeric, or if
        *method* is not recognised.

        :param column: column name.
        :param method: 'iqr' keeps values inside
            [Q1 - threshold*IQR, Q3 + threshold*IQR]; 'zscore' keeps values
            whose absolute z-score is at most *threshold*.
        :param threshold: IQR multiplier or z-score limit.
        """
        if column not in self.df.columns:
            return self
        values = self.df[column]
        if not pd.api.types.is_numeric_dtype(values):
            return self

        # NaN fails every comparison below, so remember it explicitly.
        missing = values.isna()

        if method == 'iqr':
            q1 = values.quantile(0.25)
            q3 = values.quantile(0.75)
            spread = q3 - q1
            lower = q1 - threshold * spread
            upper = q3 + threshold * spread
            in_range = (values >= lower) & (values <= upper)
            self.df = self.df[in_range | missing]
        elif method == 'zscore':
            std = values.std()
            # A constant (std == 0) or single-value (std is NaN) column has
            # no defined z-scores; treating them as outliers would drop
            # every row, so leave the data unchanged.
            if pd.isna(std) or std == 0:
                return self
            zscore = (values - values.mean()) / std
            self.df = self.df[(zscore.abs() <= threshold) | missing]

        return self

    def normalize_strings(self,
                          columns: Union[str, List[str]],
                          case: str = 'lower',
                          strip: bool = True) -> 'DataFrameCleaner':
        """Normalize string columns.

        Columns that are absent or not of string dtype are skipped.

        :param columns: column name or list of column names.
        :param case: 'lower', 'upper' or 'title'; any other value leaves
            the case unchanged.
        :param strip: whether to strip leading/trailing whitespace.
        """
        if isinstance(columns, str):
            columns = [columns]

        for col in columns:
            if col not in self.df.columns:
                continue
            if not pd.api.types.is_string_dtype(self.df[col]):
                continue
            if strip:
                self.df[col] = self.df[col].str.strip()
            if case in ('lower', 'upper', 'title'):
                # The .str accessor exposes methods with these exact names.
                self.df[col] = getattr(self.df[col].str, case)()

        return self

    def normalize_headers(self, case: str = 'lower') -> 'DataFrameCleaner':
        """Unify the case of column names.

        Only string labels are changed; non-string labels (for example
        integer column names) are kept as they are.

        :param case: 'lower', 'upper' or 'title'; any other value leaves
            the headers unchanged.
        """
        if case in ('lower', 'upper', 'title'):
            self.df.columns = [
                getattr(label, case)() if isinstance(label, str) else label
                for label in self.df.columns
            ]
        return self

    def apply_custom(self,
                     columns: Union[str, List[str]],
                     func: Callable) -> 'DataFrameCleaner':
        """Apply a custom cleaning function element-wise.

        :param columns: column name or list of column names; absent
            columns are skipped.
        :param func: function taking one value and returning the cleaned value.
        """
        if isinstance(columns, str):
            columns = [columns]

        for col in columns:
            if col in self.df.columns:
                self.df[col] = self.df[col].apply(func)

        return self

    def encode_categorical(self,
                           columns: Union[str, List[str]],
                           method: str = 'label',
                           drop: bool = True) -> 'DataFrameCleaner':
        """Encode categorical columns as numbers.

        Only columns that exist and have the pandas ``category`` dtype are
        encoded (use ``convert_types`` first if needed). Requires
        scikit-learn, which is imported when this method is called.

        :param columns: column name or list of column names.
        :param method: 'label' replaces the column with integer codes and
            records the label -> code mapping in ``_encoding_maps``;
            'onehot' adds one ``<column>_<category>`` column per category.
        :param drop: whether to drop the original column (onehot only).
        """
        from sklearn.preprocessing import LabelEncoder, OneHotEncoder

        if isinstance(columns, str):
            columns = [columns]

        for col in columns:
            if col not in self.df.columns:
                continue
            if not isinstance(self.df[col].dtype, pd.CategoricalDtype):
                continue

            if method == 'label':
                encoder = LabelEncoder()
                self.df[col] = encoder.fit_transform(self.df[col])
                # Remember how labels were mapped to integer codes.
                self._encoding_maps[col] = dict(zip(encoder.classes_, range(len(encoder.classes_))))
            elif method == 'onehot':
                encoder = OneHotEncoder()
                encoded = encoder.fit_transform(self.df[[col]].to_numpy().reshape(-1, 1)).toarray()
                # One indicator column per category found by the encoder.
                for i, cls in enumerate(encoder.categories_[0]):
                    self.df[f"{col}_{cls}"] = encoded[:, i]
                if drop:
                    self.df.drop(col, axis=1, inplace=True)

        return self

    def convert_strings_to_numeric(self,
                                   columns: Union[str, List[str]],
                                   pattern: Optional[str] = None,
                                   func: Optional[Callable] = None) -> 'DataFrameCleaner':
        """Convert string columns to numbers.

        Columns that are absent or not of string dtype are skipped. For each
        remaining column: *func* is applied element-wise if given; otherwise
        the first capture group of *pattern* is extracted and cast to float
        if *pattern* is given; otherwise ``pd.to_numeric`` is used and values
        that cannot be parsed become NaN.

        :param columns: column name or list of column names.
        :param pattern: regular expression with a capture group
            (e.g. one that extracts the digits).
        :param func: custom conversion function.
        """
        if isinstance(columns, str):
            columns = [columns]

        for col in columns:
            if col not in self.df.columns:
                continue
            if not pd.api.types.is_string_dtype(self.df[col]):
                continue

            if func:
                self.df[col] = self.df[col].apply(func)
            elif pattern:
                self.df[col] = self.df[col].str.extract(pattern, expand=False).astype(float)
            else:
                # No hint given: try an automatic conversion.
                self.df[col] = pd.to_numeric(self.df[col], errors='coerce')

        return self
|
||
|
||
|
||
|
||
# Usage example
# from PandasDataIO import PandasDataIO
# if __name__ == "__main__":
#
#     data = PandasDataIO().read_csv('t_mode_pdo.csv')
#     print("\n原始数据:")
#     print(data)
#     cleaned_df = (
#         DataFrameCleaner(data)
#         # .normalize_strings('name', case='title', strip=True)
#         .normalize_headers(case='lower')
#         .convert_types({'steelgrade': 'category'})  # convert the column dtype
#         .encode_categorical(['steelgrade'], method='label', drop=True)).get_cleaned_data()
#     print("\n清洗后数据:")
#     print(cleaned_df)
#     print(cleaned_df.dtypes)
|