#!/users/dsc/bin/Python3 #-*-coding: UTF-8 -*- from typing_extensions import Self from ml_model import models from Serialization import Serialization,Deserialization import numpy as np from pyml.pyml import logger class ModelsAPI(object): def __init__(self) -> None: self.models_={} def train(self,x_train,y_train,model_info): ruleid=model_info.ruleid model_type=model_info.type#int 1.分类;2.回归 model_name=model_info.name #根据类型而定 # 分类器 # 1.多层感知机 MLP # 2.逻辑斯蒂回归 LR # 3.随机森林 RFC # 4.决策树 DTC # 回归模型 # 1.多层感知机 MLPR # 2.逻辑斯蒂回归 LRR # 3.随机森林 RFR # 4.决策树 DTR if(model_type==1): try: now_model=models.model(model_name) now_model.fit(x_train, y_train) key=ruleid+"_"+model_type+"_"+model_name Serialization(key,now_model) self.models_[key]=now_model now_score=now_model.score(x_train, y_train) logger.info(key+",训练完成!score:"+ str(now_score)) return now_score except: logger.error(key+",训练失败") def predict(self,x_test,model_info): ruleid=model_info.ruleid model_type=model_info.type#int 1.分类;2.回归 model_name=model_info.name #根据类型而定 # 分类器 # 1.多层感知机 MLP # 2.逻辑斯蒂回归 LR # 3.随机森林 RFC # 4.决策树 DTC # 回归模型 # 1.多层感知机 MLPR # 2.逻辑斯蒂回归 LRR # 3.随机森林 RFR # 4.决策树 DTR key=ruleid+"_"+model_type+"_"+model_name if(key in self.models_): pass else: try: now_model=Deserialization(key) self.models_[key]=now_model except: logger.error("predict ERROR! Deserialization ERROR!") try: return self.models_[key].predict(x_test) except: logger.error("predict ERROR! model.predict ERROR! ") return None