16

深度学习实践-使用CelebA_Spoof训练的权重测试NUAA

 2 years ago
source link: https://segmentfault.com/a/1190000040712441
Go to the source link to view the article. You can view the picture content, updated content and better typesetting reading experience. If the link is broken, please click the button below to view the snapshot at that time.
neoserver,ios ssh client

1.直接使用保存的网络测试NUAA

测试代码:

def read_test_file(base_path=r'E:\ml\fas\data\NUAA'):
    """Load the NUAA test split as resized images and integer labels.

    Reads ``test.txt`` under *base_path*. Each non-blank line is expected to
    look like ``<image path>,<integer label>``. Every listed image is read
    with OpenCV and resized to 100x100.

    Args:
        base_path: Directory that contains ``test.txt``. Defaults to the
            dataset location the script was originally written for.

    Returns:
        A ``(images, labels)`` tuple of NumPy arrays built from the resized
        images and the ``int`` labels, in file order.

    Raises:
        ValueError: If a non-blank line has no comma-separated label field,
            or the label is not an integer.
        OSError: If OpenCV cannot read an image listed in the file.
    """
    list_file_path = os.path.join(base_path, "test.txt")
    images = []
    labels = []
    with open(list_file_path) as f:
        for line_no, line in enumerate(f, start=1):
            # Blank lines (e.g. a trailing newline at end of file) are not samples.
            if not line.strip():
                continue
            fields = line.split(',')
            if len(fields) < 2:
                raise ValueError(
                    f"{list_file_path}:{line_no}: expected '<path>,<label>', got {line!r}"
                )
            image_path = fields[0]
            img = cv2.imread(image_path)
            # cv2.imread reports failure by returning None instead of raising;
            # fail here with the offending path rather than inside cv2.resize.
            if img is None:
                raise OSError(
                    f"{list_file_path}:{line_no}: cannot read image {image_path!r}"
                )
            images.append(cv2.resize(img, (100, 100)))
            labels.append(int(fields[1]))
    return np.asarray(images), np.asarray(labels)

def fit2():
    """Evaluate the model saved at ``model/live_model.h5`` on the test split.

    Loads the test images and labels via ``read_test_file``, runs
    ``evaluate`` on the loaded model, and prints the resulting loss and
    accuracy values.
    """
    images, labels = read_test_file()
    net = load_model('model/live_model.h5')
    loss, accuracy = net.evaluate(images, labels, verbose=2)
    print(loss, accuracy)

# Script entry point: run fit2() only when this file is executed directly,
# not when it is imported as a module.
if __name__ == '__main__':
    fit2()
200/200 - 8s - loss: 1.3156 - accuracy: 0.5285
1.3155993223190308 0.5284998416900635

可以看到,用CelebA_Spoof训练的模型在CelebA_Spoof自己的测试集上准确率可以达到99%以上,
但是在NUAA数据集上测试时,准确率只有约0.528。
查看CelebA_Spoof数据集后发现,CelebA_Spoof的spoof数据都是打印在纸上的假头像,而NUAA里面的则是对着电脑屏幕或者手机屏幕拍摄的图片。

2. 使用CelebA_Spoof 训练的网络再次训练NUAA数据

def read_train_file(base_path=r'E:\ml\fas\data\NUAA'):
    """Load the NUAA training split as resized images and integer labels.

    Reads ``train.txt`` under *base_path*. Each non-blank line is expected to
    look like ``<image path>,<integer label>``. Every listed image is read
    with OpenCV and resized to 100x100.

    Args:
        base_path: Directory that contains ``train.txt``. Defaults to the
            dataset location the script was originally written for.

    Returns:
        A ``(images, labels)`` tuple of NumPy arrays built from the resized
        images and the ``int`` labels, in file order.

    Raises:
        ValueError: If a non-blank line has no comma-separated label field,
            or the label is not an integer.
        OSError: If OpenCV cannot read an image listed in the file.
    """
    train_file_path = os.path.join(base_path, "train.txt")
    images = []
    labels = []
    with open(train_file_path) as f:
        for line_no, line in enumerate(f, start=1):
            # Blank lines (e.g. a trailing newline at end of file) are not samples.
            if not line.strip():
                continue
            fields = line.split(',')
            if len(fields) < 2:
                raise ValueError(
                    f"{train_file_path}:{line_no}: expected '<path>,<label>', got {line!r}"
                )
            image_path = fields[0]
            img = cv2.imread(image_path)
            # cv2.imread reports failure by returning None instead of raising;
            # fail here with the offending path rather than inside cv2.resize.
            if img is None:
                raise OSError(
                    f"{train_file_path}:{line_no}: cannot read image {image_path!r}"
                )
            images.append(cv2.resize(img, (100, 100)))
            labels.append(int(fields[1]))
    return np.asarray(images), np.asarray(labels)


def read_val_file(base_path=r'E:\ml\fas\data\NUAA'):
    """Load the NUAA validation split as resized images and integer labels.

    Reads ``val.txt`` under *base_path*. Each non-blank line is expected to
    look like ``<image path>,<integer label>``. Every listed image is read
    with OpenCV and resized to 100x100.

    Args:
        base_path: Directory that contains ``val.txt``. Defaults to the
            dataset location the script was originally written for.

    Returns:
        A ``(images, labels)`` tuple of NumPy arrays built from the resized
        images and the ``int`` labels, in file order.

    Raises:
        ValueError: If a non-blank line has no comma-separated label field,
            or the label is not an integer.
        OSError: If OpenCV cannot read an image listed in the file.
    """
    val_file_path = os.path.join(base_path, "val.txt")
    images = []
    labels = []
    with open(val_file_path) as f:
        for line_no, line in enumerate(f, start=1):
            # Blank lines (e.g. a trailing newline at end of file) are not samples.
            if not line.strip():
                continue
            fields = line.split(',')
            if len(fields) < 2:
                raise ValueError(
                    f"{val_file_path}:{line_no}: expected '<path>,<label>', got {line!r}"
                )
            image_path = fields[0]
            img = cv2.imread(image_path)
            # cv2.imread reports failure by returning None instead of raising;
            # fail here with the offending path rather than inside cv2.resize.
            if img is None:
                raise OSError(
                    f"{val_file_path}:{line_no}: cannot read image {image_path!r}"
                )
            images.append(cv2.resize(img, (100, 100)))
            labels.append(int(fields[1]))
    return np.asarray(images), np.asarray(labels)


def read_test_file(base_path=r'E:\ml\fas\data\NUAA'):
    """Load the NUAA test split as resized images and integer labels.

    Reads ``test.txt`` under *base_path*. Each non-blank line is expected to
    look like ``<image path>,<integer label>``. Every listed image is read
    with OpenCV and resized to 100x100.

    Args:
        base_path: Directory that contains ``test.txt``. Defaults to the
            dataset location the script was originally written for.

    Returns:
        A ``(images, labels)`` tuple of NumPy arrays built from the resized
        images and the ``int`` labels, in file order.

    Raises:
        ValueError: If a non-blank line has no comma-separated label field,
            or the label is not an integer.
        OSError: If OpenCV cannot read an image listed in the file.
    """
    test_file_path = os.path.join(base_path, "test.txt")
    images = []
    labels = []
    with open(test_file_path) as f:
        for line_no, line in enumerate(f, start=1):
            # Blank lines (e.g. a trailing newline at end of file) are not samples.
            if not line.strip():
                continue
            fields = line.split(',')
            if len(fields) < 2:
                raise ValueError(
                    f"{test_file_path}:{line_no}: expected '<path>,<label>', got {line!r}"
                )
            image_path = fields[0]
            img = cv2.imread(image_path)
            # cv2.imread reports failure by returning None instead of raising;
            # fail here with the offending path rather than inside cv2.resize.
            if img is None:
                raise OSError(
                    f"{test_file_path}:{line_no}: cannot read image {image_path!r}"
                )
            images.append(cv2.resize(img, (100, 100)))
            labels.append(int(fields[1]))
    return np.asarray(images), np.asarray(labels)

def fit2():
    """Continue training the saved model on NUAA, then evaluate on the test split.

    Loads the train, validation and test splits through the ``read_*_file``
    helpers, loads the model stored at ``model/live_model.h5``, fits it for
    10 epochs with the validation split as ``validation_data``, prints the
    recorded training history, and finally prints the test loss and accuracy.
    """
    X_train, y_train = read_train_file()
    X_valid, y_valid = read_val_file()
    X_test, y_test = read_test_file()
    model = load_model('model/live_model.h5')
    history = model.fit(X_train, y_train, epochs=10, validation_data=(X_valid, y_valid))
    # Printing the History object itself only shows its repr; the per-epoch
    # metric values are stored in its `history` dict.
    print(history.history)
    test_loss, test_acc = model.evaluate(X_test, y_test, verbose=2)
    print(test_loss, test_acc)


# Script entry point: run fit2() only when this file is executed directly,
# not when it is imported as a module.
if __name__ == '__main__':
    fit2()

3.训练后测试结果

200/200 - 8s - loss: 5.3590 - accuracy: 0.6008
5.35904598236084 0.6008455753326416

虽然准确率有些许提升(从约0.528提高到约0.601),但是提升效果并不明显。


About Joyk


Aggregate valuable and interesting links.
Joyk means Joy of geeK