23 lines
785 B
Python
23 lines
785 B
Python
#!/users/dsc/bin/Python3
|
|
#-*-coding: UTF-8 -*-
|
|
import numpy as np
|
|
import pandas as pd
|
|
from sklearn.datasets import load_digits  # dataset: handwritten digits
|
|
from sklearn.model_selection import train_test_split  # train/test set split function
|
|
import ml_utiliy as mlty
|
|
# Load the handwritten-digits dataset and pull out features and labels.
digits = load_digits()
X, y = digits.data, digits.target  # X: feature data, y: labels

# Split into training and test sets: 20% is held out for testing and the
# split is stratified on the labels.
# NOTE(review): no random_state is passed, so the split presumably differs
# between runs -- confirm whether reproducibility is needed.
x_train, x_test, y_train, y_test = train_test_split(
    X,
    y,
    test_size=0.2,
    stratify=y,
)

# Model loading: names handed to the ml_utiliy helpers (mnTrain/mnTest use
# only the first entry).
model_names = ["MLP", "LR", "RFC", "DTC"]
|
|
def mnTrain():
    """Call ``mlty.train`` for the "mn_test" run on the training split.

    The model is selected by the first entry of ``model_names``.

    Returns:
        Whatever ``mlty.train`` returns.
    """
    selected_model = model_names[0]
    return mlty.train("mn_test", selected_model, x_train, y_train)
|
|
def mnTest():
    """Call ``mlty.predict_score`` for the "mn_test" run on the test split.

    The model is selected by the first entry of ``model_names``.

    Returns:
        Whatever ``mlty.predict_score`` returns.
    """
    selected_model = model_names[0]
    return mlty.predict_score("mn_test", selected_model, x_test, y_test)
|
|
|
|
if __name__ == "__main__":
    # When run as a script, report the shapes of the feature and label arrays,
    # one per line.
    print(
        "X.shape:{}".format(X.shape),
        "y.shape:{}".format(y.shape),
        sep="\n",
    )
|