eis/TestProject/pyml/pymlAPI.py

67 lines
2.2 KiB
Python
Raw Permalink Normal View History

#!/users/dsc/bin/Python3
#-*-coding: UTF-8 -*-
from typing_extensions import Self
from ml_model import models
from Serialization import Serialization,Deserialization
import numpy as np
from pyml.pyml import logger
class ModelsAPI(object):
def __init__(self) -> None:
self.models_={}
def train(self,x_train,y_train,model_info):
ruleid=model_info.ruleid
model_type=model_info.type#int 1.分类2.回归
model_name=model_info.name
#根据类型而定
# 分类器
# 1.多层感知机 MLP
# 2.逻辑斯蒂回归 LR
# 3.随机森林 RFC
# 4.决策树 DTC
# 回归模型
# 1.多层感知机 MLPR
# 2.逻辑斯蒂回归 LRR
# 3.随机森林 RFR
# 4.决策树 DTR
if(model_type==1):
try:
now_model=models.model(model_name)
now_model.fit(x_train, y_train)
key=ruleid+"_"+model_type+"_"+model_name
Serialization(key,now_model)
self.models_[key]=now_model
now_score=now_model.score(x_train, y_train)
logger.info(key+",训练完成score:"+ str(now_score))
return now_score
except:
logger.error(key+",训练失败")
def predict(self,x_test,model_info):
ruleid=model_info.ruleid
model_type=model_info.type#int 1.分类2.回归
model_name=model_info.name
#根据类型而定
# 分类器
# 1.多层感知机 MLP
# 2.逻辑斯蒂回归 LR
# 3.随机森林 RFC
# 4.决策树 DTC
# 回归模型
# 1.多层感知机 MLPR
# 2.逻辑斯蒂回归 LRR
# 3.随机森林 RFR
# 4.决策树 DTR
key=ruleid+"_"+model_type+"_"+model_name
if(key in self.models_):
pass
else:
try:
now_model=Deserialization(key)
self.models_[key]=now_model
except:
logger.error("predict ERROR! Deserialization ERROR!")
try:
return self.models_[key].predict(x_test)
except:
logger.error("predict ERROR! model.predict ERROR! ")
return None