暂无图片
暂无图片
暂无图片
暂无图片
暂无图片

Kaggle知识点:TFRecord使用教程

Coggle数据科学 2022-04-21
487

为了高效地读取数据,比较有帮助的一种做法是对数据进行序列化并将其存储在一组可线性读取的文件(每个文件 100-200MB)中。这尤其适用于通过网络进行流式传输的数据。这种做法对缓冲任何数据预处理也十分有用。

TFRecord
格式是一种用于存储二进制记录序列的简单格式。协议缓冲区是一个跨平台、跨语言的库,用于高效地序列化结构化数据。协议消息由.proto
文件定义,这通常是了解消息类型最简单的方法。

tf.Example
消息(或 protobuf)是一种灵活的消息类型,表示{"string": value}
映射。它专为 TensorFlow 而设计,也可以搭配Pytroch使用。

TFRecords 格式

TFRecord
文件包含一系列记录。该文件只能按顺序读取。每条记录包含一个字节字符串(用于数据有效负载),外加数据长度,以及用于完整性检查的哈希值。

每条记录会存储为以下格式:

uint64 length uint32 masked_crc32_of_length byte   data[length] uint32 masked_crc32_of_data

TFRecords 写入案例

tf.data.TFRecordWriter
用于写入一条序列化后的信息。

import numpy as np

# Write the records to a file.
with tf.io.TFRecordWriter("tmp.tfrecords"as file_writer:
    for _ in range(4):
        x, y = np.random.random(), np.random.random()

        record_bytes = tf.train.Example(
            features=tf.train.Features(
                feature={
                    "x": tf.train.Feature(float_list=tf.train.FloatList(value=[x])),
                    "y": tf.train.Feature(float_list=tf.train.FloatList(value=[y])),
                }
            )
        ).SerializeToString()
        file_writer.write(record_bytes)

TFRecords 读取案例

tf.data.TFRecordDataset
可以将多个TFRecord
格式合并为Dataset
,这里需要编写解码函数。

# Read the data back out.
def decode_fn(record_bytes):
    return tf.io.parse_single_example(
        # Data
        record_bytes,
        
        # Schema
        {
            "x": tf.io.FixedLenFeature([], dtype=tf.float32),
            "y": tf.io.FixedLenFeature([], dtype=tf.float32),
        },
    )


for batch in tf.data.TFRecordDataset(["tmp.tfrecords"]).map(decode_fn):
    print("x = {x:.4f},  y = {y:.4f}".format(**batch))

案例:写入和读取图片

写入图片

# 创建一个写tfrecord的变量
with tf.io.TFRecordWriter("train.tfrecord"as writer:
    
    for path in train_path:
        
        # 读取单个文件的过程
        # image_string 图片内容
        # image_label 图片标签
        image_string, image_label = encode_sigle_example(path)
        # 跳过空图片
        if image_string is None:
            continue
        
        feature = {
            "label": _int64_feature(image_label),
            "image": _bytes_feature(image_string),
        }

        tf_example = tf.train.Example(features=tf.train.Features(feature=feature))
        writer.write(tf_example.SerializeToString())

读取图片

image_feature_description = {
    'label': tf.io.FixedLenFeature([], tf.int64),
    'image': tf.io.FixedLenFeature([], tf.string),
}

from functools import partial
def _parse_image_function(example_proto):
    data = tf.io.parse_single_example(example_proto, image_feature_description)
    image = tf.image.decode_image(data['image'])
    print(data['label'])
    return image# , # data['label']
    
image_dataset = tf.data.TFRecordDataset('train.tfrecord')
train_image_dataset = image_dataset.map(_parse_image_function)

注意事项

  • TFRecord
    格式的写入格式和解码格式要一致,否则会报错。
  • TFRecord
    格式适合存储为定长的文件,可以考虑将图片进行缩放然后写入。


加好友  领取完整代码 #

△长按添加竞赛小助手
添加Coggle小助手微信(ID : coggle666)


每天Kaggle算法竞赛、干货资讯汇总

与 22000+来自竞赛爱好者一起交流~




文章转载自Coggle数据科学,如果涉嫌侵权,请发送邮件至:contact@modb.pro进行举报,并提供相关证据,一经查实,墨天轮将立刻删除相关内容。

评论