142 lines
5.0 KiB
Python
142 lines
5.0 KiB
Python
#!/users/dsc/bin/Python3
|
||
#-*-coding: UTF-8 -*-
|
||
from typing_extensions import Self
|
||
from ml_model import models
|
||
from Serialization import Serialization,Deserialization
|
||
import ihyper_db
|
||
import numpy as np
|
||
import datetime
|
||
import copy
|
||
def train(ruleid, model_type, x_train, y_train):
    """Fit a new model, persist it, and return its score on the training data.

    The fitted model is saved with ``Serialization`` under the key
    ``ruleid + "_" + model_type``.

    Args:
        ruleid (str): rule identifier. Combined with ``model_type`` to build
            the key the model is saved under (the original docs describe it
            as the model's save-file name).
        model_type (str): model type passed to ``models.model``. Documented
            choices: MLP (multi-layer perceptron), LR (logistic regression),
            RFC (random forest), DTC (decision tree).
        x_train (numpy.ndarray): feature data.
        y_train (numpy.ndarray): target data.

    Returns:
        The result of ``model.score(x_train, y_train)``, i.e. the training
        accuracy.
    """
    now_model = models.model(model_type)
    now_model.fit(x_train, y_train)
    # Same key format as predict_score/read_mode use to load the model back.
    key = ruleid + "_" + model_type
    Serialization(key, now_model)
    return now_model.score(x_train, y_train)
|
||
|
||
def predict_score(ruleid, model_type, x_test, y_test):
    """Score the saved model for ``(ruleid, model_type)`` on the given data.

    The model is loaded with ``Deserialization`` from the key
    ``ruleid + "_" + model_type``.

    Args:
        ruleid (str): rule identifier used to build the model key.
        model_type (str): model type. Documented choices: MLP
            (multi-layer perceptron), LR (logistic regression), RFC
            (random forest), DTC (decision tree).
        x_test (numpy.ndarray): feature data.
        y_test (numpy.ndarray): target data.

    Returns:
        ``model.score(x_test, y_test)`` for the saved model. If
        ``Deserialization`` returns a falsy value (no saved model), a new
        model is trained on ``x_test``/``y_test`` via ``train`` and that
        training score is returned instead.
    """
    saved = Deserialization(ruleid + "_" + model_type)
    if not saved:
        # NOTE(review): this fallback fits on the evaluation data itself, so
        # the returned value is a training score, not a held-out test score.
        return train(ruleid, model_type, x_test, y_test)
    return saved.score(x_test, y_test)
|
||
|
||
def read_mode(ruleid, model_type):
    """Load the model saved under the key ``ruleid + "_" + model_type``.

    Args:
        ruleid (str): rule identifier used to build the model key.
        model_type (str): model type used to build the model key.

    Returns:
        Whatever ``Deserialization`` returns for that key.
    """
    return Deserialization(ruleid + "_" + model_type)
|
||
|
||
class ihyperDB(object):
    """Query raw records from the ihd database through ``ihyper_db``.

    State kept on the instance:

    * ``flag_`` / ``value_``: status and data of the most recent single
      chunk query.
    * ``flags_`` / ``values_``: per-chunk statuses and concatenated data of
      the most recent ``select_raw_data`` call.
    """

    # Largest record count a single ``get_tag_small_time_data`` call is asked for.
    _MAX_QUERY_RECORDS = 65535

    def __init__(self):
        self.flag_ = False
        self.value_ = []
        self.flags_ = []
        self.values_ = []

    def select_raw_data(self, tag_name, stime, etime, interval=50):
        """Query all raw records of ``tag_name`` between ``stime`` and ``etime``.

        The range is split into chunks with ``time_delta_split``. Each chunk
        is queried, and the data of every successful chunk is concatenated.

        Args:
            tag_name: tag (item) to query.
            stime: start time; an int (epoch milliseconds) or a
                ``datetime.datetime``, the same kind as ``etime``.
            etime: end time, the same kind as ``stime``.
            interval (int, optional): sampling interval in milliseconds,
                used to estimate record counts. Defaults to 50.

        Returns:
            tuple: ``(flags, values)``. ``flags`` holds one bool per chunk
            (True if that chunk's query succeeded). ``values`` is the
            concatenation of the records of every successful chunk.
        """
        # Forward the caller's interval (it used to be hard-coded to 50).
        time_points, data_nums = time_delta_split(stime, etime, interval=interval)
        self.flags_ = []
        self.values_ = []
        for t, n in zip(time_points, data_nums):
            print("t:{},n:{}".format(t, n))  # per-chunk progress output
            self.__select_raw_data(tag_name, t, n, interval=interval)
            if self.flag_:
                # Each chunk is fetched backwards from its end time; reversing
                # presumably restores chronological order. TODO: confirm the
                # record order returned by ihyper_db.
                self.value_.reverse()
                self.values_.extend(self.value_)
            self.flags_.append(self.flag_)
        return self.flags_, self.values_

    def __select_raw_data(self, tag_name, etime, data_num, interval=50):
        """Query at most 65535 raw records of ``tag_name`` ending at ``etime``.

        Args:
            tag_name: tag (item) to query.
            etime: end time of the query; an int (epoch milliseconds) or a
                ``datetime.datetime``.
            data_num: number of records to request; clamped to 65535.
            interval (int, optional): unused; kept for signature
                compatibility.

        Returns:
            tuple: ``(flag, value)``. ``flag`` is True when
            ``ihyper_db.get_tag_small_time_data`` reported flag 0. ``value``
            is the returned data, or an empty list when the query failed.

        Raises:
            TypeError: if ``etime`` is neither an int nor a datetime.
        """
        if isinstance(etime, int):
            timepoint = datetime.datetime.fromtimestamp(etime / 1000)
        elif isinstance(etime, datetime.datetime):
            timepoint = etime
        else:
            raise TypeError(
                "etime must be an int (epoch ms) or datetime.datetime, got {}".format(
                    type(etime).__name__
                )
            )
        data_num = min(data_num, self._MAX_QUERY_RECORDS)
        flag, value = ihyper_db.get_tag_small_time_data(tag_name, timepoint, int(data_num))
        if flag == 0:
            self.flag_ = True
            self.value_ = value
        else:
            self.flag_ = False
            # Do not hand back the previous chunk's data on failure.
            self.value_ = []
        return self.flag_, self.value_
|
||
|
||
|
||
def time_delta_split(stime, etime, interval=50, start_mode=False):
    """Split ``[stime, etime]`` into chunks for backward record queries.

    The expected record count is ``(etime - stime) / interval``. The range is
    cut, from ``etime`` backwards, into windows spanning ``65535 * interval``
    milliseconds. Each chunk is described by its end time and the number of
    records to request going backwards from it:

    * the oldest chunk (next to ``stime``) requests the remainder
      ``int(record_count % 65535)``;
    * every later (full) chunk requests 65530 records.

    Chunks are returned oldest first. Chunks that would request zero records
    are omitted.

    Args:
        stime: start of the range; epoch milliseconds (int) or
            ``datetime.datetime``, the same kind as ``etime``.
        etime: end of the range, the same kind as ``stime``.
        interval (int, optional): sampling interval in milliseconds.
            Defaults to 50.
        start_mode (bool, optional): currently unused.

    Returns:
        tuple: ``(timepoints, data_nums)``, the chunk end times
        (``datetime.datetime``) and per-chunk record counts. Both lists are
        empty when the range holds less than one record, or when
        ``stime``/``etime`` are not both ints or both datetimes.
    """
    window = 65535      # records spanned by one chunk window
    full_chunk = 65530  # records requested for each full chunk
    # NOTE(review): requesting 65530 records per 65535-record window leaves
    # about 5 * interval ms between consecutive chunks unqueried (assuming
    # evenly spaced records). Possibly a deliberate margin; confirm.

    if isinstance(stime, int) and isinstance(etime, int):
        span_ms = etime - stime

        def chunk_end(offset_ms):
            # Keep millisecond precision, as __select_raw_data does.
            return datetime.datetime.fromtimestamp((etime - offset_ms) / 1000)
    elif isinstance(stime, datetime.datetime) and isinstance(etime, datetime.datetime):
        span_ms = (etime - stime).total_seconds() * 1000

        def chunk_end(offset_ms):
            return etime - datetime.timedelta(milliseconds=offset_ms)
    else:
        return [], []

    data_num = span_ms / interval
    if data_num <= 0:
        # Empty or reversed range: nothing to query.
        return [], []
    nmax = int(data_num / window)
    nmin = int(data_num % window)

    timepoints = []
    data_nums = []
    # Always consider the oldest (partial) chunk, even when the whole range
    # is shorter than one window (nmax == 0).
    for k in range(nmax + 1):
        count = full_chunk if k > 0 else nmin
        if count <= 0:
            continue
        timepoints.append(chunk_end((nmax - k) * interval * window))
        data_nums.append(count)
    return timepoints, data_nums
|
||
|