
Shuffling, Splitting, and Reading TFRecords - 多事鬼间人

source link: https://www.cnblogs.com/yc0806/p/16526114.html

Shuffling a dataset requires setting a buffer_size parameter, which amounts to loading that many samples into memory, and this memory stays occupied for the whole training run. A perfect shuffle would require reading the entire dataset into memory, which is unrealistic for large datasets. Instead, taking device memory and batch size into account, the TFRecord file can be randomly split into multiple subfiles, and the dataset then given a local shuffle (i.e., a relatively small buffer_size, no smaller than the number of samples in a single subfile).
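
As a minimal sketch of this idea (the subfile names and sizes below are hypothetical, not taken from the original code), the order of the subfiles can itself be shuffled almost for free before the record-level local shuffle:

import tensorflow as tf

# Hypothetical setup: 10 subfiles, each holding roughly 10,000 examples.
filenames = ['train_{}.tfrecords'.format(i) for i in range(10)]

# File-level shuffle is nearly free (it only reorders 10 names); the
# record-level buffer then holds ~10,000 serialized examples instead of
# the entire dataset.
files = tf.data.Dataset.from_tensor_slices(filenames).shuffle(10)
dataset = files.interleave(
    lambda f: tf.data.TFRecordDataset(f, compression_type="GZIP"),
    cycle_length=4)
dataset = dataset.shuffle(buffer_size=10000)

Reshuffling the file order each epoch changes which records can meet inside the small buffer, so the combined effect approaches a full shuffle at a fraction of the memory cost.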

Shuffle and Split

The example below uses an anomaly detection dataset (with imbalanced positive and negative samples). When generating the first batch of TFRecords, I wrote positive and negative samples to separate TFRecord files, so that if the two classes later need different processing strategies there is no need to parse example_proto again. For instance, in the code below I use different validation-set ratios for positive and negative samples and write them to different validation files.

import numpy as np
import tensorflow as tf
from tqdm.notebook import tqdm

# Placeholder constants -- set these for your own dataset.
SUBFILE_NUM = 9              # SUBFILE_NUM + 1 training subfiles are created
LEN_NORMAL_DATASET = 100000  # number of normal examples (for the progress bar)
LEN_ANOMALY_DATASET = 5000   # number of anomalous examples (for the progress bar)
BATCH_SIZE = 64

# Split the TFRecords into validation files and multiple training subfiles
raw_normal_dataset = tf.data.TFRecordDataset("normal_16_256.tfrecords", "GZIP")
raw_anomaly_dataset = tf.data.TFRecordDataset("anomaly_16_256.tfrecords", "GZIP")
normal_val_writer = tf.io.TFRecordWriter('ex_1/normal_val_16_256.tfrecords', "GZIP")
anomaly_val_writer = tf.io.TFRecordWriter('ex_1/anomaly_val_16_256.tfrecords', "GZIP")
train_writer_list = [tf.io.TFRecordWriter('ex_1/train_16_256_{}.tfrecords'.format(i), "GZIP")
                     for i in range(SUBFILE_NUM + 1)]
with tqdm(total=LEN_NORMAL_DATASET+LEN_ANOMALY_DATASET) as pbar:
    for example_proto in raw_normal_dataset:
        # Route each example to the validation file or a random training subfile
        if np.random.random() > 0.99:  # ~1% of normal examples go to validation
            normal_val_writer.write(example_proto.numpy())
        else:
            train_writer_list[np.random.randint(0,SUBFILE_NUM+1)].write(example_proto.numpy())
        pbar.update(1)

    for example_proto in raw_anomaly_dataset:
        # Route each example to the validation file or a random training subfile
        if np.random.random() > 0.7:  # ~30% of anomalous examples go to validation
            anomaly_val_writer.write(example_proto.numpy())
        else:
            train_writer_list[np.random.randint(0,SUBFILE_NUM+1)].write(example_proto.numpy())
        pbar.update(1)
normal_val_writer.close()
anomaly_val_writer.close()
for train_writer in train_writer_list:
    train_writer.close()
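
Reading

With all writers closed, the training subfiles are read back as a single dataset for the local shuffle, while the two validation files keep their own pipelines:
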
raw_train_dataset = tf.data.TFRecordDataset(
    ['ex_1/train_16_256_{}.tfrecords'.format(i) for i in range(SUBFILE_NUM + 1)], "GZIP")
# Local shuffle: buffer_size need only cover the samples of a single subfile
raw_train_dataset = raw_train_dataset.shuffle(buffer_size=100000).batch(BATCH_SIZE)
parsed_train_dataset = raw_train_dataset.map(map_func=map_func)

raw_normal_val_dataset = tf.data.TFRecordDataset('ex_1/normal_val_16_256.tfrecords', "GZIP")
raw_anomaly_val_dataset = tf.data.TFRecordDataset('ex_1/anomaly_val_16_256.tfrecords', "GZIP")
parsed_normal_val_dataset = raw_normal_val_dataset.batch(BATCH_SIZE).map(map_func=map_func)
parsed_anomaly_val_dataset = raw_anomaly_val_dataset.batch(BATCH_SIZE).map(map_func=map_func)
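
The map_func above is not shown in the original post. A minimal sketch of what it might look like, assuming each example stores a float feature under a hypothetical key 'signal' of shape 16x256 (suggested by the filenames but not confirmed by the source):

import tensorflow as tf

feature_spec = {
    # Hypothetical feature name and shape
    'signal': tf.io.FixedLenFeature([16 * 256], tf.float32),
}

def map_func(serialized_batch):
    # batch() is applied before map(), so map_func receives a batch of
    # serialized protos; tf.io.parse_example parses them all at once.
    parsed = tf.io.parse_example(serialized_batch, feature_spec)
    return tf.reshape(parsed['signal'], [-1, 16, 256])

Batching before mapping lets tf.io.parse_example vectorize the parsing over the whole batch, which is generally faster than parsing one record at a time.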
