本代码最初在 Colab 中实现，以下为全部代码及运行输出结果。
挂载谷歌云盘,解压数据集
'/content'
1 !unzip '/content/drive/MyDrive/AI_content/RCNN/Images.zip' -d '/content/drive/MyDrive/AI_content/RCNN'
1 !unzip '/content/drive/MyDrive/AI_content/RCNN/Airplanes_Annotations.zip' -d '/content/drive/MyDrive/AI_content/RCNN'
安装并导入依赖
1 !pip install tensorflow==2.8 .0
1 2 3 4 5 import os,cv2,kerasimport pandas as pdimport matplotlib.pyplot as pltimport numpy as npimport tensorflow as tf
'2.8.0'
更改工作目录
1 cd /content/drive/MyDrive/AI_content/RCNN
/content/drive/MyDrive/AI_content/RCNN
# Dataset locations: `path` holds the .jpg images, `annot` holds one .csv of
# boxes per image (same base name as the image).
path, annot = (
    '/content/drive/MyDrive/AI_content/RCNN/Images',
    '/content/drive/MyDrive/AI_content/RCNN/Airplanes_Annotations',
)
查看数据和标签
# Display one training image, then the same image with its ground-truth boxes.
Index = 148
# Image files are zero-padded to three digits (e.g. "airplane_020.jpg"),
# so str(Index) would produce a non-existent name for indices below 100.
filename = f"airplane_{Index:03d}.jpg"
print(filename)
img = cv2.imread(os.path.join(path, filename))
df = pd.read_csv(os.path.join(annot, filename.replace(".jpg", ".csv")))
# OpenCV loads images as BGR; matplotlib expects RGB.
plt.imshow(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
for row in df.iterrows():
    # Each row holds one "x1 y1 x2 y2" string. Use .iloc for positional
    # access: integer keys on a label-indexed Series are deprecated in pandas.
    x1, y1, x2, y2 = map(int, row[1].iloc[0].split(" ")[:4])
    # (0, 0, 255) is red in OpenCV's BGR order, so the box displays red.
    cv2.rectangle(img, (x1, y1), (x2, y2), (0, 0, 255), 2)
plt.figure()
plt.imshow(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
plt.show()
airplane_148.jpg
OKYkfL.png
OKYzBi.png
Selective search
# Turn on OpenCV's optimized code paths, then build a single Selective Search
# segmenter (`ss`) that the cells below reuse for every image.
cv2.setUseOptimized(True)
ss = cv2.ximgproc.segmentation.createSelectiveSearchSegmentation()
# Visualize every Selective Search (fast mode) proposal on one sample image.
im = cv2.imread(os.path.join(path, "42850.jpg"))
ss.setBaseImage(im)
ss.switchToSelectiveSearchFast()
rects = ss.process()
imOut = im.copy()
# Each proposal is (x, y, w, h).
for x, y, w, h in rects:
    cv2.rectangle(imOut, (x, y), (x + w, y + h), (0, 255, 0), 1, cv2.LINE_AA)
# OpenCV images are BGR; convert so matplotlib shows the true colors.
plt.imshow(cv2.cvtColor(imOut, cv2.COLOR_BGR2RGB))
plt.show()
OKYbpC.png
IOU
def get_iou(bb1, bb2):
    """Return the intersection-over-union of two axis-aligned boxes.

    Each box is a dict with keys 'x1', 'x2', 'y1', 'y2' where
    'x1' < 'x2' and 'y1' < 'y2' (checked with assert, so a box with zero
    width or height raises AssertionError). Boxes that do not overlap,
    or only touch along an edge, give 0.0; the result lies in [0, 1].
    """
    for box in (bb1, bb2):
        assert box['x1'] < box['x2']
        assert box['y1'] < box['y2']

    # Width/height of the overlap rectangle; negative means no overlap.
    inter_w = min(bb1['x2'], bb2['x2']) - max(bb1['x1'], bb2['x1'])
    inter_h = min(bb1['y2'], bb2['y2']) - max(bb1['y1'], bb2['y1'])
    if inter_w < 0 or inter_h < 0:
        return 0.0

    intersection = inter_w * inter_h
    area1 = (bb1['x2'] - bb1['x1']) * (bb1['y2'] - bb1['y1'])
    area2 = (bb2['x2'] - bb2['x1']) * (bb2['y2'] - bb2['y1'])
    iou = intersection / float(area1 + area2 - intersection)
    assert 0.0 <= iou <= 1.0
    return iou
准备训练数据
# Build the fine-tuning set for the VGG16 classifier. For each annotated
# image, the first 2000 Selective Search proposals are scanned:
#   - IoU > 0.70 with some ground-truth box -> positive (label 1), max 30,
#   - IoU < 0.3 with every ground-truth box  -> negative (label 0), max 30.
# Every kept proposal is cropped and resized to 224x224.
train_images = []
train_labels = []

for e, i in enumerate(os.listdir(annot)):
    if not i.startswith("airplane"):
        continue
    filename = i.split(".")[0] + ".jpg"
    try:
        print(e, filename)
        image = cv2.imread(os.path.join(path, filename))
        df = pd.read_csv(os.path.join(annot, i))
        gtvalues = []
        for row in df.iterrows():
            # Each row holds one "x1 y1 x2 y2" string. Use .iloc for
            # positional access (integer keys on a label-indexed Series are
            # deprecated in pandas).
            x1, y1, x2, y2 = map(int, row[1].iloc[0].split(" ")[:4])
            gtvalues.append({"x1": x1, "x2": x2, "y1": y1, "y2": y2})
        if not gtvalues:
            continue

        ss.setBaseImage(image)
        ss.switchToSelectiveSearchFast()
        ssresults = ss.process()
        imout = image.copy()
        counter = 0
        falsecounter = 0
        for x, y, w, h in ssresults[:2000]:
            proposal = {"x1": x, "x2": x + w, "y1": y, "y2": y + h}
            # Label each proposal once, by its best overlap with any
            # ground-truth box. Labeling per box would add a proposal that
            # covers one airplane as "background" for every other airplane
            # in the image, and add true background once per box.
            iou = max(get_iou(gtval, proposal) for gtval in gtvalues)
            if iou > 0.70 and counter < 30:
                label = 1
            elif iou < 0.3 and falsecounter < 30:
                label = 0
            else:
                label = None  # ambiguous overlap, or that quota is full
            if label is not None:
                # NumPy images are indexed [row, col], i.e. [y, x].
                timage = imout[y:y + h, x:x + w]
                resized = cv2.resize(timage, (224, 224), interpolation=cv2.INTER_AREA)
                train_images.append(resized)
                train_labels.append(label)
                if label == 1:
                    counter += 1
                else:
                    falsecounter += 1
            if counter >= 30 and falsecounter >= 30:
                print("inside")
                break  # both quotas filled; skip the remaining proposals
    except Exception as err:
        # Best effort: skip an image that fails to load/crop, but report it.
        print(err)
        print("error in " + filename)
# Stack the crops/labels into arrays and persist them so later sessions can
# start from here.
X_new = np.array(train_images)
Y_new = np.array(train_labels)

np.save('save_X_new', X_new)
np.save('save_Y_new', Y_new)

# Reload from disk. The image array (~30k crops of 224x224x3) is several GB,
# and only a random subset is used for training. Memory-map it so only the
# sampled rows are actually read into RAM.
X_new = np.load('save_X_new.npy', mmap_mode='r')
Y_new = np.load('save_Y_new.npy')

total_nums = len(Y_new)
print(total_nums)
30229
# Train on a random subset (without replacement) to limit memory use.
# Clamp to the dataset size: np.random.choice(..., replace=False) raises
# ValueError when asked for more samples than exist.
num_samples = min(5000, len(X_new))
random_indices = np.random.choice(len(X_new), num_samples, replace=False)
X_new_subset = X_new[random_indices]
Y_new_subset = Y_new[random_indices]
预训练(使用VGG16模型创建一个迁移学习模型)
from keras.layers import Dense
from keras import Model
from keras import optimizers

# ImageNet-pretrained VGG16, including its fully connected head.
vgg = tf.keras.applications.vgg16.VGG16(
    include_top=True,
    weights='imagenet',
    input_tensor=None,
    input_shape=None,
    pooling=None,
    classes=1000,
)

# Freeze every layer except the last two (fc2 and the 1000-way output);
# only fc2 is reused, so fc2 is the one pretrained layer that keeps learning.
for layer in vgg.layers[:-2]:
    layer.trainable = False

# Swap the ImageNet head for a single sigmoid unit on top of fc2
# (airplane vs. background).
fc2_output = vgg.get_layer('fc2').output
airplane_prob = tf.keras.layers.Dense(1, activation='sigmoid')(fc2_output)
model = tf.keras.Model(vgg.input, airplane_prob)
model.compile(optimizer="adam", loss='binary_crossentropy', metrics=['acc'])
# NOTE(review): the crops are fed as raw cv2 (BGR, uint8) images without
# tf.keras.applications.vgg16.preprocess_input; confirm this is intended for
# the frozen ImageNet weights.

# NOTE(review): this save runs before model.fit(), so the file holds the
# untrained sigmoid head.
model.save('my_model_vgg16.h5')
/usr/local/lib/python3.10/dist-packages/keras/src/engine/training.py:3103:
UserWarning: You are saving your model as an HDF5 file via
model.save()
. This file format is considered legacy. We
recommend using instead the native Keras format, e.g.
model.save('my_model.keras')
. saving_api.save_model(
model.summary()
# Fine-tune on the random subset; 5% of it is held out for validation.
hist_vgg = model.fit(
    X_new_subset,
    Y_new_subset,
    batch_size=32,
    epochs=3,
    verbose=1,
    validation_split=.05,
    shuffle=True,
)
# The earlier model.save() ran before fit(). Save again so the file on disk
# contains the fine-tuned weights rather than the untrained head.
model.save('my_model_vgg16.h5')
Model: "model"
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
input_1 (InputLayer) [(None, 224, 224, 3)] 0
block1_conv1 (Conv2D) (None, 224, 224, 64) 1792
block1_conv2 (Conv2D) (None, 224, 224, 64) 36928
block1_pool (MaxPooling2D) (None, 112, 112, 64) 0
block2_conv1 (Conv2D) (None, 112, 112, 128) 73856
block2_conv2 (Conv2D) (None, 112, 112, 128) 147584
block2_pool (MaxPooling2D) (None, 56, 56, 128) 0
block3_conv1 (Conv2D) (None, 56, 56, 256) 295168
block3_conv2 (Conv2D) (None, 56, 56, 256) 590080
block3_conv3 (Conv2D) (None, 56, 56, 256) 590080
block3_pool (MaxPooling2D) (None, 28, 28, 256) 0
block4_conv1 (Conv2D) (None, 28, 28, 512) 1180160
block4_conv2 (Conv2D) (None, 28, 28, 512) 2359808
block4_conv3 (Conv2D) (None, 28, 28, 512) 2359808
block4_pool (MaxPooling2D) (None, 14, 14, 512) 0
block5_conv1 (Conv2D) (None, 14, 14, 512) 2359808
block5_conv2 (Conv2D) (None, 14, 14, 512) 2359808
block5_conv3 (Conv2D) (None, 14, 14, 512) 2359808
block5_pool (MaxPooling2D) (None, 7, 7, 512) 0
flatten (Flatten) (None, 25088) 0
fc1 (Dense) (None, 4096) 102764544
fc2 (Dense) (None, 4096) 16781312
dense (Dense) (None, 1) 4097
=================================================================
Total params: 134264641 (512.18 MB)
Trainable params: 16785409 (64.03 MB)
Non-trainable params: 117479232 (448.15 MB)
_________________________________________________________________
Epoch 1/3
149/149 [==============================] - 43s 214ms/step - loss: 1.4586 - acc: 0.7680 - val_loss: 0.3230 - val_acc: 0.8880
Epoch 2/3
149/149 [==============================] - 19s 129ms/step - loss: 0.3880 - acc: 0.8215 - val_loss: 0.3339 - val_acc: 0.8560
Epoch 3/3
149/149 [==============================] - 20s 131ms/step - loss: 0.3464 - acc: 0.8495 - val_loss: 0.3104 - val_acc: 0.8760
<keras.src.callbacks.History at 0x7a67e60b2140>
创建带有SVM的新网络
创建供SVM使用的数据集
# Build the training set for the 2-way SVM head:
#   - every ground-truth airplane crop -> positive, label [0, 1];
#   - per image, up to 5 of the first 2000 Selective Search proposals whose
#     IoU is < 0.3 with every ground-truth box -> negative, label [1, 0].
# All crops are resized to 224x224.
svm_image = []
svm_label = []

for e, i in enumerate(os.listdir(annot)):
    if not i.startswith("airplane"):
        continue
    filename = i.split(".")[0] + ".jpg"
    try:
        print(e, filename)
        image = cv2.imread(os.path.join(path, filename))
        df = pd.read_csv(os.path.join(annot, i))
        gtvalues = []
        for row in df.iterrows():
            # Each row holds one "x1 y1 x2 y2" string; .iloc gives positional
            # access without pandas' deprecated integer-key fallback.
            x1, y1, x2, y2 = map(int, row[1].iloc[0].split(" ")[:4])
            gtvalues.append({"x1": x1, "x2": x2, "y1": y1, "y2": y2})
            # NumPy images are indexed [row, col], i.e. [y, x].
            timage = image[y1:y2, x1:x2]
            resized = cv2.resize(timage, (224, 224), interpolation=cv2.INTER_AREA)
            svm_image.append(resized)
            svm_label.append([0, 1])
        if not gtvalues:
            continue

        ss.setBaseImage(image)
        ss.switchToSelectiveSearchFast()
        ssresults = ss.process()
        imout = image.copy()
        falsecounter = 0
        for x, y, w, h in ssresults[:2000]:
            proposal = {"x1": x, "x2": x + w, "y1": y, "y2": y + h}
            # Background only if the proposal misses *every* airplane; checking
            # per box would also accept proposals covering another airplane.
            if max(get_iou(gtval, proposal) for gtval in gtvalues) < 0.3:
                timage = imout[y:y + h, x:x + w]
                resized = cv2.resize(timage, (224, 224), interpolation=cv2.INTER_AREA)
                svm_image.append(resized)
                svm_label.append([1, 0])
                falsecounter += 1
                if falsecounter >= 5:
                    break  # negative quota filled; skip remaining proposals
    except Exception as err:
        # Best effort: skip an image that fails to load/crop, but report it.
        print(err)
        print("error in " + filename)
# Stack the SVM crops/labels into arrays, persist both, and report the count.
X_svm, Y_svm = np.array(svm_image), np.array(svm_label)
for npy_name, npy_array in (('save_X_svm', X_svm), ('save_Y_svm', Y_svm)):
    np.save(npy_name, npy_array)

total_nums_svm = len(Y_svm)
print(total_nums_svm)
7750
# Reload the SVM dataset. Memory-map the image array so that only the
# randomly sampled rows are read into RAM.
X_svm = np.load('save_X_svm.npy', mmap_mode='r')
Y_svm = np.load('save_Y_svm.npy')

# Random subset without replacement. Clamp to the dataset size, because
# np.random.choice(..., replace=False) raises if asked for too many samples.
num_samples = min(2000, len(X_svm))
random_indices = np.random.choice(len(X_svm), num_samples, replace=False)
X_svm_subset = X_svm[random_indices]
Y_svm_subset = Y_svm[random_indices]
SVM模型结构
# Reuse the fine-tuned network up to fc2 and attach a 2-unit linear layer.
# Trained with hinge loss, this linear layer plays the role of the SVM.
fc2_features = model.get_layer('fc2').output
svm_scores = tf.keras.layers.Dense(2)(fc2_features)
final_model = tf.keras.Model(model.input, svm_scores)
final_model.compile(
    optimizer='adam',
    loss='hinge',
    metrics=['accuracy'],
)
final_model.summary()
Model: "model_1"
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
input_1 (InputLayer) [(None, 224, 224, 3)] 0
block1_conv1 (Conv2D) (None, 224, 224, 64) 1792
block1_conv2 (Conv2D) (None, 224, 224, 64) 36928
block1_pool (MaxPooling2D) (None, 112, 112, 64) 0
block2_conv1 (Conv2D) (None, 112, 112, 128) 73856
block2_conv2 (Conv2D) (None, 112, 112, 128) 147584
block2_pool (MaxPooling2D) (None, 56, 56, 128) 0
block3_conv1 (Conv2D) (None, 56, 56, 256) 295168
block3_conv2 (Conv2D) (None, 56, 56, 256) 590080
block3_conv3 (Conv2D) (None, 56, 56, 256) 590080
block3_pool (MaxPooling2D) (None, 28, 28, 256) 0
block4_conv1 (Conv2D) (None, 28, 28, 512) 1180160
block4_conv2 (Conv2D) (None, 28, 28, 512) 2359808
block4_conv3 (Conv2D) (None, 28, 28, 512) 2359808
block4_pool (MaxPooling2D) (None, 14, 14, 512) 0
block5_conv1 (Conv2D) (None, 14, 14, 512) 2359808
block5_conv2 (Conv2D) (None, 14, 14, 512) 2359808
block5_conv3 (Conv2D) (None, 14, 14, 512) 2359808
block5_pool (MaxPooling2D) (None, 7, 7, 512) 0
flatten (Flatten) (None, 25088) 0
fc1 (Dense) (None, 4096) 102764544
fc2 (Dense) (None, 4096) 16781312
dense_1 (Dense) (None, 2) 8194
=================================================================
Total params: 134268738 (512.19 MB)
Trainable params: 16789506 (64.05 MB)
Non-trainable params: 117479232 (448.15 MB)
_________________________________________________________________
模型训练
# Train the hinge-loss head on the sampled SVM subset; 5% is held out for
# validation. The History object is kept for the loss plot.
hist_final = final_model.fit(
    x=X_svm_subset,
    y=Y_svm_subset,
    epochs=20,
    batch_size=32,
    validation_split=0.05,
    shuffle=True,
    verbose=1,
)
Epoch 1/20
60/60 [==============================] - 17s 217ms/step - loss: 0.7543 - accuracy: 0.6689 - val_loss: 0.7870 - val_accuracy: 0.6100
Epoch 2/20
60/60 [==============================] - 8s 128ms/step - loss: 0.5756 - accuracy: 0.7463 - val_loss: 0.6612 - val_accuracy: 0.7200
Epoch 3/20
60/60 [==============================] - 8s 134ms/step - loss: 0.4762 - accuracy: 0.7905 - val_loss: 0.6471 - val_accuracy: 0.7300
Epoch 4/20
60/60 [==============================] - 8s 132ms/step - loss: 0.4303 - accuracy: 0.8232 - val_loss: 0.7102 - val_accuracy: 0.6900
Epoch 5/20
60/60 [==============================] - 8s 131ms/step - loss: 0.4178 - accuracy: 0.8200 - val_loss: 0.6434 - val_accuracy: 0.7000
Epoch 6/20
60/60 [==============================] - 8s 136ms/step - loss: 0.3378 - accuracy: 0.8558 - val_loss: 0.7106 - val_accuracy: 0.7000
Epoch 7/20
60/60 [==============================] - 8s 135ms/step - loss: 0.3321 - accuracy: 0.8647 - val_loss: 0.6975 - val_accuracy: 0.7400
Epoch 8/20
60/60 [==============================] - 8s 130ms/step - loss: 0.3007 - accuracy: 0.8737 - val_loss: 0.7403 - val_accuracy: 0.7500
Epoch 9/20
60/60 [==============================] - 8s 138ms/step - loss: 0.2735 - accuracy: 0.8858 - val_loss: 0.8128 - val_accuracy: 0.7100
Epoch 10/20
60/60 [==============================] - 8s 133ms/step - loss: 0.2241 - accuracy: 0.9153 - val_loss: 0.9183 - val_accuracy: 0.6900
Epoch 11/20
60/60 [==============================] - 8s 141ms/step - loss: 0.2732 - accuracy: 0.8916 - val_loss: 0.7814 - val_accuracy: 0.7100
Epoch 12/20
60/60 [==============================] - 8s 142ms/step - loss: 0.2009 - accuracy: 0.9163 - val_loss: 0.8467 - val_accuracy: 0.7100
Epoch 13/20
60/60 [==============================] - 8s 140ms/step - loss: 0.2214 - accuracy: 0.9126 - val_loss: 0.7389 - val_accuracy: 0.7500
Epoch 14/20
60/60 [==============================] - 9s 144ms/step - loss: 0.1761 - accuracy: 0.9268 - val_loss: 0.8726 - val_accuracy: 0.7300
Epoch 15/20
60/60 [==============================] - 9s 142ms/step - loss: 0.1645 - accuracy: 0.9295 - val_loss: 0.7956 - val_accuracy: 0.7400
Epoch 16/20
60/60 [==============================] - 9s 142ms/step - loss: 0.1385 - accuracy: 0.9453 - val_loss: 0.7979 - val_accuracy: 0.7200
Epoch 17/20
60/60 [==============================] - 8s 138ms/step - loss: 0.1303 - accuracy: 0.9495 - val_loss: 0.8370 - val_accuracy: 0.7700
Epoch 18/20
60/60 [==============================] - 8s 139ms/step - loss: 0.1469 - accuracy: 0.9426 - val_loss: 0.7578 - val_accuracy: 0.7900
Epoch 19/20
60/60 [==============================] - 9s 143ms/step - loss: 0.1020 - accuracy: 0.9611 - val_loss: 0.7946 - val_accuracy: 0.7600
Epoch 20/20
60/60 [==============================] - 9s 147ms/step - loss: 0.1020 - accuracy: 0.9632 - val_loss: 0.8619 - val_accuracy: 0.7500
绘制损失变化曲线
# Plot training vs. validation loss of the SVM head and save the chart.
loss_history = hist_final.history
plt.plot(loss_history['loss'])
plt.plot(loss_history['val_loss'])
plt.title("model loss")
plt.ylabel("Loss")
plt.xlabel("Epoch")
plt.legend(["Loss", "Validation Loss"])
# Save BEFORE show(): once show() has displayed the figure, a later savefig()
# writes a fresh, empty figure (the "Figure ... with 0 Axes" output).
plt.savefig('chart_loss.png')
plt.show()
OKYT8S.png
<Figure size 640x480 with 0 Axes>
测试
# Detect airplanes in one test image: score the first 50 Selective Search
# proposals with the SVM-head model and draw the ones classified as airplane.
image = cv2.imread(os.path.join(path, 'airplane_020.jpg'))
ss.setBaseImage(image)
ss.switchToSelectiveSearchFast()
ssresults = ss.process()
imOut = image.copy()

# Proposals are (x, y, w, h); convert to plain ints for OpenCV drawing calls.
candidates = [tuple(int(v) for v in rect) for rect in ssresults[:50]]
# Crop from THIS test image. NumPy images are indexed [row, col] == [y, x].
crops = [
    cv2.resize(image[y:y + h, x:x + w], (224, 224), interpolation=cv2.INTER_AREA)
    for x, y, w, h in candidates
]

boxes = []
if crops:
    # One batched forward pass instead of one predict() call per proposal.
    scores = final_model.predict(np.stack(crops))
    for e, (box, out) in enumerate(zip(candidates, scores)):
        print(e, out)
        # Column 1 scores the airplane class: the SVM data labeled
        # ground-truth airplanes [0, 1] and background [1, 0].
        if out[0] < out[1]:
            boxes.append(list(box))
count = len(boxes)

for x, y, w, h in boxes:
    print(x, y, w, h)
    cv2.rectangle(imOut, (x, y), (x + w, y + h), (0, 255, 0), 1, cv2.LINE_AA)
# OpenCV images are BGR; convert so matplotlib shows the true colors.
plt.imshow(cv2.cvtColor(imOut, cv2.COLOR_BGR2RGB))
plt.show()
1/1 [==============================] - 1s 1s/step
0 [[ 2.551831 -2.6361141]]
1/1 [==============================] - 0s 31ms/step
1 [[ 1.2116516 -1.1462703]]
1/1 [==============================] - 0s 43ms/step
2 [[ 3.1247723 -3.0577419]]
1/1 [==============================] - 0s 38ms/step
3 [[ 0.3094477 -0.35184118]]
1/1 [==============================] - 0s 33ms/step
4 [[ 16.334412 -16.097301]]
1/1 [==============================] - 0s 36ms/step
5 [[-1.9055386 1.8246138]]
1/1 [==============================] - 0s 46ms/step
6 [[ 3.849068 -3.596069]]
1/1 [==============================] - 0s 35ms/step
7 [[-4.3343387 4.457566 ]]
1/1 [==============================] - 0s 28ms/step
8 [[ 2.1157243 -2.0935826]]
1/1 [==============================] - 0s 37ms/step
9 [[ 1.1227907 -1.0547544]]
1/1 [==============================] - 0s 32ms/step
10 [[ 3.028215 -3.0655315]]
1/1 [==============================] - 0s 35ms/step
11 [[-3.4406524 3.4818974]]
1/1 [==============================] - 0s 36ms/step
12 [[-3.3148727 3.2502732]]
1/1 [==============================] - 0s 69ms/step
13 [[-1.7705667 1.8401496]]
1/1 [==============================] - 0s 119ms/step
14 [[ 17.1168 -17.020542]]
1/1 [==============================] - 0s 31ms/step
15 [[-0.54532474 0.49859324]]
1/1 [==============================] - 0s 39ms/step
16 [[ 1.3955598 -1.4487445]]
1/1 [==============================] - 0s 39ms/step
17 [[-0.9255678 0.7681236]]
1/1 [==============================] - 0s 34ms/step
18 [[-1.0967708 1.0601681]]
1/1 [==============================] - 0s 35ms/step
19 [[ 1.6157322 -1.5883387]]
1/1 [==============================] - 0s 19ms/step
20 [[ 6.222667 -6.078978]]
1/1 [==============================] - 0s 22ms/step
21 [[ 1.9781907 -1.9643315]]
1/1 [==============================] - 0s 21ms/step
22 [[ 2.6352754 -2.6751401]]
1/1 [==============================] - 0s 22ms/step
23 [[-0.6199166 0.6234232]]
1/1 [==============================] - 0s 26ms/step
24 [[ 0.56931984 -0.52301127]]
1/1 [==============================] - 0s 19ms/step
25 [[-4.092036 4.0529504]]
1/1 [==============================] - 0s 20ms/step
26 [[-1.1211745 1.1134607]]
1/1 [==============================] - 0s 20ms/step
27 [[ 1.5422791 -1.504165 ]]
1/1 [==============================] - 0s 20ms/step
28 [[ 0.9709082 -1.1293985]]
1/1 [==============================] - 0s 22ms/step
29 [[ 6.2005806 -6.223065 ]]
1/1 [==============================] - 0s 19ms/step
30 [[ 0.7283702 -0.67930716]]
1/1 [==============================] - 0s 20ms/step
31 [[ 3.7712991 -3.7369084]]
1/1 [==============================] - 0s 20ms/step
32 [[ 1.7139057 -1.7024881]]
1/1 [==============================] - 0s 21ms/step
33 [[ 12.521779 -12.554838]]
1/1 [==============================] - 0s 24ms/step
34 [[ 3.4832761 -3.3890066]]
1/1 [==============================] - 0s 23ms/step
35 [[ 1.2881904 -1.3030462]]
1/1 [==============================] - 0s 31ms/step
36 [[ 1.3349662 -1.2856408]]
1/1 [==============================] - 0s 20ms/step
37 [[ 0.29870683 -0.25320527]]
1/1 [==============================] - 0s 20ms/step
38 [[-1.2835077 1.3210849]]
1/1 [==============================] - 0s 19ms/step
39 [[ 1.3556112 -1.3576012]]
1/1 [==============================] - 0s 20ms/step
40 [[ 5.945995 -5.7617545]]
1/1 [==============================] - 0s 20ms/step
41 [[ 4.5127177 -4.54366 ]]
1/1 [==============================] - 0s 19ms/step
42 [[ 1.2226268 -1.2334687]]
1/1 [==============================] - 0s 19ms/step
43 [[ 2.2175348 -2.1676831]]
1/1 [==============================] - 0s 20ms/step
44 [[ 5.3103013 -5.1500263]]
1/1 [==============================] - 0s 22ms/step
45 [[ 1.7600315 -1.8218967]]
1/1 [==============================] - 0s 20ms/step
46 [[ 3.2599857 -2.9830909]]
1/1 [==============================] - 0s 19ms/step
47 [[-1.2734337 1.2411362]]
1/1 [==============================] - 0s 21ms/step
48 [[ 10.405064 -10.193722]]
1/1 [==============================] - 0s 20ms/step
49 [[-1.247491 1.2709295]]
145 129 35 31
0 71 98 70
176 148 64 43
49 91 49 22
0 71 77 70
199 148 38 29
174 130 21 27
19 95 58 45
111 142 27 23
120 127 32 33
0 74 48 53
120 143 33 19
29 149 36 40
111 117 34 46
OKYWxN.png
最后的测试效果不太理想，可能是因为使用的训练数据过少。如果内存（或显存）足够，可以不截取子集，而是使用全部训练数据进行训练，效果应该会有所提升。