67 lines
2.2 KiB
Python
67 lines
2.2 KiB
Python
|
|
#!/users/dsc/bin/Python3
|
|||
|
|
#-*-coding: UTF-8 -*-
|
|||
|
|
from typing_extensions import Self
|
|||
|
|
from ml_model import models
|
|||
|
|
from Serialization import Serialization,Deserialization
|
|||
|
|
import numpy as np
|
|||
|
|
from pyml.pyml import logger
|
|||
|
|
class ModelsAPI(object):
|
|||
|
|
def __init__(self) -> None:
|
|||
|
|
self.models_={}
|
|||
|
|
def train(self,x_train,y_train,model_info):
|
|||
|
|
ruleid=model_info.ruleid
|
|||
|
|
model_type=model_info.type#int 1.分类;2.回归
|
|||
|
|
model_name=model_info.name
|
|||
|
|
#根据类型而定
|
|||
|
|
# 分类器
|
|||
|
|
# 1.多层感知机 MLP
|
|||
|
|
# 2.逻辑斯蒂回归 LR
|
|||
|
|
# 3.随机森林 RFC
|
|||
|
|
# 4.决策树 DTC
|
|||
|
|
# 回归模型
|
|||
|
|
# 1.多层感知机 MLPR
|
|||
|
|
# 2.逻辑斯蒂回归 LRR
|
|||
|
|
# 3.随机森林 RFR
|
|||
|
|
# 4.决策树 DTR
|
|||
|
|
if(model_type==1):
|
|||
|
|
try:
|
|||
|
|
now_model=models.model(model_name)
|
|||
|
|
now_model.fit(x_train, y_train)
|
|||
|
|
key=ruleid+"_"+model_type+"_"+model_name
|
|||
|
|
Serialization(key,now_model)
|
|||
|
|
self.models_[key]=now_model
|
|||
|
|
now_score=now_model.score(x_train, y_train)
|
|||
|
|
logger.info(key+",训练完成!score:"+ str(now_score))
|
|||
|
|
return now_score
|
|||
|
|
except:
|
|||
|
|
logger.error(key+",训练失败")
|
|||
|
|
def predict(self,x_test,model_info):
|
|||
|
|
ruleid=model_info.ruleid
|
|||
|
|
model_type=model_info.type#int 1.分类;2.回归
|
|||
|
|
model_name=model_info.name
|
|||
|
|
#根据类型而定
|
|||
|
|
# 分类器
|
|||
|
|
# 1.多层感知机 MLP
|
|||
|
|
# 2.逻辑斯蒂回归 LR
|
|||
|
|
# 3.随机森林 RFC
|
|||
|
|
# 4.决策树 DTC
|
|||
|
|
# 回归模型
|
|||
|
|
# 1.多层感知机 MLPR
|
|||
|
|
# 2.逻辑斯蒂回归 LRR
|
|||
|
|
# 3.随机森林 RFR
|
|||
|
|
# 4.决策树 DTR
|
|||
|
|
key=ruleid+"_"+model_type+"_"+model_name
|
|||
|
|
if(key in self.models_):
|
|||
|
|
pass
|
|||
|
|
else:
|
|||
|
|
try:
|
|||
|
|
now_model=Deserialization(key)
|
|||
|
|
self.models_[key]=now_model
|
|||
|
|
except:
|
|||
|
|
logger.error("predict ERROR! Deserialization ERROR!")
|
|||
|
|
try:
|
|||
|
|
return self.models_[key].predict(x_test)
|
|||
|
|
except:
|
|||
|
|
logger.error("predict ERROR! model.predict ERROR! ")
|
|||
|
|
return None
|
|||
|
|
|