Skip to content

Commit d6bc8b0

Browse files
authored
Create trainingData
1 parent 1ae4731 commit d6bc8b0

File tree

1 file changed

+36
-0
lines changed

1 file changed

+36
-0
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
from cv2 import cv2
2+
from PIL import Image
3+
import numpy as np
4+
import os
5+
6+
# 创建一个垃圾分类的字典
7+
rubbishDict = {'dry':1, 'harmful':2, 'recycle':3, 'residual':4}
8+
# 创建一个存放源数据的数组
9+
data = []
10+
# 存放对应标签的数组
11+
label = []
12+
# 图片总文件夹
13+
imageFile = 'E:/python/Project/rubbish/image'
14+
# 对各个子文件夹进行读取
15+
for i in rubbishDict.keys():
16+
# 图片路径
17+
imageListDir = imageFile + '/' + i
18+
# 获取该路径下的文件名目录
19+
imageList = os.listdir(imageListDir)
20+
# 对每一张图片进行操作
21+
for j in imageList:
22+
# 图片路径
23+
imageDir = imageListDir + '/' + j
24+
# 读取图片
25+
src = cv2.imread(imageDir)
26+
# 将array类型的输入转换成RGB格式的图片
27+
arrayToImage = Image.fromarray(src, 'RGB')
28+
# 将输入数组统一尺寸
29+
sizeArray = arrayToImage.resize((50, 50))
30+
data.append(np.array(sizeArray))
31+
label.append(rubbishDict[i])
32+
# 生成训练数据集
33+
x_train = np.array(data)
34+
y_train = np.array(label)
35+
print(x_train.shape)
36+
print(y_train.shape)

0 commit comments

Comments
 (0)