eis/TestProject/pyml/ml_utiliy.py

142 lines
5.0 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#!/users/dsc/bin/Python3
#-*-coding: UTF-8 -*-
from typing_extensions import Self
from ml_model import models
from Serialization import Serialization,Deserialization
import ihyper_db
import numpy as np
import datetime
import copy
def train(ruleid, model_type, x_train, y_train):
    """Fit a new model, save it, and return its score on the training data.

    The fitted model is persisted with ``Serialization`` under the key
    ``"<ruleid>_<model_type>"`` so that ``predict_score`` / ``read_mode`` can
    load it later.

    Args:
        ruleid (str): Rule identifier; forms the first part of the save key.
        model_type (str): Model type handed to ``models.model``. Documented
            options: MLP (multi-layer perceptron), LR (logistic regression),
            RFC (random forest), DTC (decision tree). The exact accepted
            strings are defined by ``models.model``.
        x_train (numpy.ndarray): Feature data.
        y_train (numpy.ndarray): Target data.

    Returns:
        ``now_model.score(x_train, y_train)`` — the model's score on the data
        it was trained on (documented as training accuracy).
    """
    now_model = models.model(model_type)
    now_model.fit(x_train, y_train)
    key = ruleid + "_" + model_type
    Serialization(key, now_model)
    return now_model.score(x_train, y_train)
def predict_score(ruleid, model_type, x_test, y_test):
    """Score the model saved under ``"<ruleid>_<model_type>"`` on the given data.

    Args:
        ruleid (str): Rule identifier; first part of the saved-model key.
        model_type (str): Model type; second part of the saved-model key
            (documented options: MLP, LR, RFC, DTC).
        x_test (numpy.ndarray): Feature data.
        y_test (numpy.ndarray): Target data.

    Returns:
        ``score(x_test, y_test)`` of the deserialized model. When
        ``Deserialization`` yields a falsy value (no saved model), a new
        model is trained on ``x_test``/``y_test`` via ``train`` and that
        training score is returned instead.
    """
    saved_model = Deserialization("_".join((ruleid, model_type)))
    if not saved_model:
        # NOTE(review): this fallback fits on the evaluation data, so the
        # returned score is a training score rather than a held-out one —
        # confirm callers expect that.
        return train(ruleid, model_type, x_test, y_test)
    return saved_model.score(x_test, y_test)
def read_mode(ruleid, model_type):
    """Load and return the model saved under ``"<ruleid>_<model_type>"``.

    The result is whatever ``Deserialization`` returns for that key;
    presumably a falsy value when nothing was saved (``predict_score``
    treats a falsy result as "no model") — confirm.
    """
    return Deserialization("_".join((ruleid, model_type)))
class ihyperDB(object):
    """Query raw tag records through the ``ihyper_db`` module.

    Long time spans are split into chunks (see ``time_delta_split``) that
    are fetched one by one and concatenated.
    """

    def __init__(self):
        self.flag_ = False  # success flag of the most recent chunk query
        self.value_ = []    # records of the most recent chunk query ([] if it failed)
        self.flags_ = []    # per-chunk success flags of the last select_raw_data call
        self.values_ = []   # records accumulated by the last select_raw_data call

    def select_raw_data(self, tag_name, stime, etime, interval=50):
        """Query the raw records of ``tag_name`` between ``stime`` and ``etime``.

        Args:
            tag_name: Tag (item) to query.
            stime: Start time; an ``int`` millisecond timestamp or a
                ``datetime.datetime`` (same kind as ``etime``).
            etime: End time.
            interval (int, optional): Sampling interval in milliseconds,
                used to size the query chunks. Defaults to 50.

        Returns:
            ``(flags_, values_)``: the success flag of every chunk queried and
            the records of all successful chunks, concatenated in chunk order.
        """
        # Size the chunks with the caller's interval rather than a fixed 50 ms.
        time_points, data_nums = time_delta_split(stime, etime, interval=interval)
        self.flags_ = []
        self.values_ = []
        for t, n in zip(time_points, data_nums):
            print("t:{},n:{}".format(t, n))
            self.__select_raw_data(tag_name, t, n, interval=interval)
            if self.flag_:
                # Each chunk is read backwards from its end time; presumably the
                # records arrive newest-first, so reverse to oldest-first — confirm.
                self.value_.reverse()
                self.values_.extend(self.value_)
            self.flags_.append(self.flag_)
        return self.flags_, self.values_

    def __select_raw_data(self, tag_name, etime, data_num, interval=50):
        """Query up to 65535 raw records of ``tag_name`` reading back from ``etime``.

        Args:
            tag_name: Tag (item) to query.
            etime: End time; an ``int`` millisecond timestamp or a
                ``datetime.datetime``.
            data_num: Number of records to request; capped at 65535.
            interval (int, optional): Not used by this method; kept for
                interface compatibility. Defaults to 50.

        Returns:
            ``(flag_, value_)``: ``(True, records)`` when
            ``ihyper_db.get_tag_small_time_data`` reports status 0,
            otherwise ``(False, [])``.

        Raises:
            TypeError: If ``etime`` is neither an ``int`` nor a
                ``datetime.datetime``.
        """
        if isinstance(etime, int):
            timepoint = datetime.datetime.fromtimestamp(etime / 1000)
        elif isinstance(etime, datetime.datetime):
            timepoint = etime
        else:
            raise TypeError(
                "etime must be an int millisecond timestamp or a "
                "datetime.datetime, got {}".format(type(etime).__name__))
        data_num = min(data_num, 65535)  # per-query record limit
        flag, value = ihyper_db.get_tag_small_time_data(tag_name, timepoint, int(data_num))
        if flag == 0:
            self.flag_ = True
            self.value_ = value
        else:
            self.flag_ = False
            # Do not hand back the previous query's records alongside a failure.
            self.value_ = []
        return self.flag_, self.value_
def time_delta_split(stime, etime, interval=50, start_mode=False):
    """Split the span ``[stime, etime]`` into chunks of at most 65535 samples.

    Each chunk is described by its end time and the number of samples to read
    backwards from that end time (up to 65535 records per query, reading back
    from the end time).

    Args:
        stime: Start of the span; an ``int`` timestamp in milliseconds or a
            ``datetime.datetime``. Must be the same kind as ``etime``.
        etime: End of the span (same kind as ``stime``).
        interval: Sampling interval in milliseconds. Defaults to 50.
        start_mode: Unused; kept for backward compatibility.

    Returns:
        ``(timepoints, data_nums)``: parallel lists, oldest chunk first.
        ``timepoints`` holds ``datetime.datetime`` chunk end times and
        ``data_nums`` the sample count to read back from each end time.
        Both lists are empty when the span is empty or reversed, or when
        ``stime``/``etime`` are not both ints or both datetimes.
    """
    max_chunk = 65535  # records per query
    # NOTE(review): chunks after the first request 65530 samples although they
    # are spaced 65535 samples apart, leaving a ~5-sample gap before each
    # boundary. Kept as-is — confirm whether this margin is intentional.
    full_chunk = 65530

    if isinstance(stime, int) and isinstance(etime, int):
        span_ms = etime - stime

        def chunk_end(offset_ms):
            # Divide without truncating so sub-second precision is kept.
            return datetime.datetime.fromtimestamp((etime - offset_ms) / 1000)
    elif isinstance(stime, datetime.datetime) and isinstance(etime, datetime.datetime):
        span_ms = (etime - stime).total_seconds() * 1000

        def chunk_end(offset_ms):
            return etime - datetime.timedelta(milliseconds=offset_ms)
    else:
        return [], []

    data_num = span_ms / interval
    if data_num <= 0:
        # Empty or reversed span: nothing to query (and `%` on a negative
        # count would otherwise yield a bogus positive remainder).
        return [], []
    nmax, nmin = divmod(int(data_num), max_chunk)

    timepoints = []
    data_nums = []
    # k == 0 is the partial chunk nearest stime; k == nmax ends exactly at etime.
    # With nmax == 0 this yields a single chunk, so short spans are queried too.
    for k in range(nmax + 1):
        count = full_chunk if k > 0 else nmin
        if count <= 0:
            continue  # span is an exact multiple of max_chunk: no partial chunk
        timepoints.append(chunk_end((nmax - k) * interval * max_chunk))
        data_nums.append(count)
    return timepoints, data_nums