eis/TestProject/pyml/pymlAPI.py

67 lines
2.2 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#!/users/dsc/bin/Python3
#-*-coding: UTF-8 -*-
from typing_extensions import Self
from ml_model import models
from Serialization import Serialization,Deserialization
import numpy as np
from pyml.pyml import logger
class ModelsAPI(object):
    """Train, cache and serve models identified by rule id, type and name.

    Fitted models are kept in ``self.models_`` under the key
    ``"<ruleid>_<type>_<name>"``. ``train`` also hands each fitted model to
    ``Serialization(key, model)``. On a cache miss, ``predict`` obtains the
    model from ``Deserialization(key)``. Presumably these persist/reload the
    model — confirm against the Serialization module.

    Model names by type (the names are interpreted by ``models.model``):
      classifiers (type 1): MLP (multi-layer perceptron), LR (logistic
          regression), RFC (random forest), DTC (decision tree)
      regressors  (type 2): MLPR (multi-layer perceptron), LRR (listed as
          "logistic regression" in the original notes), RFR (random forest),
          DTR (decision tree)
    """

    def __init__(self) -> None:
        # In-memory cache: model key -> fitted model object.
        self.models_ = {}

    @staticmethod
    def _model_key(ruleid, model_type, model_name) -> str:
        """Build the cache/serialization key ``"<ruleid>_<type>_<name>"``.

        Every part is stringified, so integer rule ids and model types (the
        documented type of ``model_info.type``) no longer raise ``TypeError``.
        For all-string inputs the key is identical to the one produced by the
        previous ``+`` concatenation.
        """
        return f"{ruleid}_{model_type}_{model_name}"

    def train(self, x_train, y_train, model_info):
        """Fit, serialize and cache the model described by *model_info*.

        Args:
            x_train: training features, passed unchanged to ``fit``/``score``.
            y_train: training targets, passed unchanged to ``fit``/``score``.
            model_info: object with ``ruleid``, ``type`` (int: 1 = classification,
                2 = regression) and ``name`` (a model name for ``models.model``).

        Returns:
            The model's ``score`` on the training data. Returns ``None`` in
            two cases: the model type is not handled (only type 1 is trained
            here), or building/fitting/serializing/scoring failed. Each
            failure is logged.
        """
        model_type = model_info.type
        model_name = model_info.name
        # Computed before the try block so the error handler can always use it
        # (previously a failing fit left ``key`` unbound in the handler).
        key = self._model_key(model_info.ruleid, model_type, model_name)
        if model_type != 1:
            # Only classification is wired up; previously other types were
            # skipped silently, so at least make the skip visible.
            logger.warning(key + ",unsupported model type, training skipped")
            return None
        try:
            now_model = models.model(model_name)
            now_model.fit(x_train, y_train)
            Serialization(key, now_model)
            self.models_[key] = now_model
            now_score = now_model.score(x_train, y_train)
        except Exception:
            # Best-effort contract is kept (log and return None), but the
            # traceback is recorded and KeyboardInterrupt/SystemExit propagate.
            logger.exception(key + ",训练失败")
            return None
        logger.info(key + ",训练完成score:" + str(now_score))
        return now_score

    def predict(self, x_test, model_info):
        """Run ``predict`` with the model identified by *model_info*.

        The model is looked up in the in-memory cache. On a miss it is loaded
        with ``Deserialization(key)`` and cached.

        Args:
            x_test: input passed unchanged to the model's ``predict``.
            model_info: object with ``ruleid``, ``type`` and ``name``.

        Returns:
            Whatever the model's ``predict`` returns. Returns ``None`` if the
            model could not be loaded or prediction raised; the failure is
            logged.
        """
        key = self._model_key(model_info.ruleid, model_info.type, model_info.name)
        if key not in self.models_:
            try:
                loaded = Deserialization(key)
            except Exception:
                # Stop here: continuing would only hit a KeyError and log a
                # second, misleading "predict" error.
                logger.exception("predict ERROR! Deserialization ERROR!")
                return None
            self.models_[key] = loaded
        try:
            return self.models_[key].predict(x_test)
        except Exception:
            logger.exception("predict ERROR! model.predict ERROR! ")
            return None