67 lines
2.2 KiB
Python
67 lines
2.2 KiB
Python
#!/users/dsc/bin/Python3
|
||
#-*-coding: UTF-8 -*-
|
||
from typing_extensions import Self
|
||
from ml_model import models
|
||
from Serialization import Serialization,Deserialization
|
||
import numpy as np
|
||
from pyml.pyml import logger
|
||
class ModelsAPI(object):
|
||
def __init__(self) -> None:
|
||
self.models_={}
|
||
def train(self,x_train,y_train,model_info):
|
||
ruleid=model_info.ruleid
|
||
model_type=model_info.type#int 1.分类;2.回归
|
||
model_name=model_info.name
|
||
#根据类型而定
|
||
# 分类器
|
||
# 1.多层感知机 MLP
|
||
# 2.逻辑斯蒂回归 LR
|
||
# 3.随机森林 RFC
|
||
# 4.决策树 DTC
|
||
# 回归模型
|
||
# 1.多层感知机 MLPR
|
||
# 2.逻辑斯蒂回归 LRR
|
||
# 3.随机森林 RFR
|
||
# 4.决策树 DTR
|
||
if(model_type==1):
|
||
try:
|
||
now_model=models.model(model_name)
|
||
now_model.fit(x_train, y_train)
|
||
key=ruleid+"_"+model_type+"_"+model_name
|
||
Serialization(key,now_model)
|
||
self.models_[key]=now_model
|
||
now_score=now_model.score(x_train, y_train)
|
||
logger.info(key+",训练完成!score:"+ str(now_score))
|
||
return now_score
|
||
except:
|
||
logger.error(key+",训练失败")
|
||
def predict(self,x_test,model_info):
|
||
ruleid=model_info.ruleid
|
||
model_type=model_info.type#int 1.分类;2.回归
|
||
model_name=model_info.name
|
||
#根据类型而定
|
||
# 分类器
|
||
# 1.多层感知机 MLP
|
||
# 2.逻辑斯蒂回归 LR
|
||
# 3.随机森林 RFC
|
||
# 4.决策树 DTC
|
||
# 回归模型
|
||
# 1.多层感知机 MLPR
|
||
# 2.逻辑斯蒂回归 LRR
|
||
# 3.随机森林 RFR
|
||
# 4.决策树 DTR
|
||
key=ruleid+"_"+model_type+"_"+model_name
|
||
if(key in self.models_):
|
||
pass
|
||
else:
|
||
try:
|
||
now_model=Deserialization(key)
|
||
self.models_[key]=now_model
|
||
except:
|
||
logger.error("predict ERROR! Deserialization ERROR!")
|
||
try:
|
||
return self.models_[key].predict(x_test)
|
||
except:
|
||
logger.error("predict ERROR! model.predict ERROR! ")
|
||
return None
|
||
|