TensorFlow2 tf.dataset的使用

选中文字可对指定文章内容进行评论啦,→和←可快速切换按钮,绿色背景文字可以点击查看评论额。

这个笔记主要是TensorFlow 2.0的tf.dataset接口的使用。下面的示例会把numpy array的数据写入到TFRecord文件中,以及从TFRecord文件中读取数据到numpy array。

安装

可以参考官网的https://www.tensorflow.org/install教程来安装。

安装完成后,检查安装的TensorFlow的版本:

import tensorflow as tf
print(tf.__version__)

TensorFlow Dataset的使用

在TensorFlow 2.0中,向网络灌输数据的最好方法是使用tf.dataset类,dataset本身就是一个迭代器,所以可以使用for循环的方法来迭代dataset里的数据。

1、使用numpy array来创建一个dataset

import numpy as np
np.random.seed(0)
data = np.random.randn(256, 8, 8, 3)
dataset = tf.data.Dataset.from_tensor_slices(data)
print(dataset)
...
<TensorSliceDataset shapes: (8, 8, 3), types: tf.float64>

可以通过print()方法输出dataset,可以看到dataset的shap。

通常,第一维度的数据表示训练样本的数量。DataSet可以生产任何大小的batch size,但默认情况下,batch size的值为1, 也即是生成各自独立的训练样本。

2、迭代dataset

使用for循环可以dataset做迭代,如果想获取每个批次的数据,可以使用Python的enumerate,或者使用Dataset自身的方法enumerate(),迭代示例如下:

for i, batch in enumerate(dataset):
  if i == 255 or i == 256:
  print(i, batch.shape)
...
255 (8, 8, 3)
...
for i, batch in dataset.enumerate():
  if i == 255 or i == 256:
  print(i, batch.shape)
  print(i.numpy(), batch.shape)
...
tf.Tensor(255, shape=(), dtype=int64) (8, 8, 3) 255 (8, 8, 3)
...

可以看到,使用dataset.enumerate()内置方法,返回的的一个值是一个Tensor(张量)。

3、重复迭代dataset

如果需要重复多次对dataset进行迭代,可以使用dataset的内置方法repeat()。示例:

for i, batch in dataset.repeat(2).enumerate():
  if i == 255 or i == 256:
  print(i.numpy(), batch.shape)
...
255 (8, 8, 3)
256 (8, 8, 3)

4、使用take()获取指定数量大小的样本数

如果不想使用整个数据集,可以使用take()方法来获取指定数量的数据集:

for batch in dataset.take(3):
  print(batch.shape)
...
(8, 8, 3)
(8, 8, 3)
(8, 8, 3)

5、设置batch size

默认情况下,dataset是以batch size为1来迭代,可以使用batch()方法设置batch size的大小。

dataset = dataset.batch(16)
for batch in dataset.take(3):
  print(batch.shape)
...
(16, 8, 8, 3)
(16, 8, 8, 3)
(16, 8, 8, 3)

设置了batch size为16

6、打乱数据集

shuffle()方法可以用来打乱数据,其中shuffle()方法会接收一个buffer_size的参数,这个参数作为一个每一次打乱数据的缓存区,也即是每次去出buffer_size大小的数据进行打乱。如果想完全打乱整个数据集,buffer_size需要设置为整个数据集的大小。

示例:

dataset = tf.data.Dataset.from_tensor_slices(np.arange(19))
for batch in dataset.batch(5):
  print(batch)
...
tf.Tensor([0 1 2 3 4], shape=(5,), dtype=int64)
tf.Tensor([5 6 7 8 9], shape=(5,), dtype=int64)
tf.Tensor([10 11 12 13 14], shape=(5,), dtype=int64)
tf.Tensor([15 16 17 18], shape=(4,), dtype=int64)
...
for batch in dataset.shuffle(5).batch(5):
  print(batch)
...
tf.Tensor([2 5 0 4 1], shape=(5,), dtype=int64)
tf.Tensor([ 6  9  3 12 10], shape=(5,), dtype=int64)
tf.Tensor([13  8 15 17 11], shape=(5,), dtype=int64)
tf.Tensor([18 16 14  7], shape=(4,), dtype=int64)

可以看到shuffle()的buffer_size为5,batch size也是5,每次取出5个数据,并进行打乱。打乱后,每个批次的数据就不是原来按顺序的了。

需要注意的是,如果把shuffle()方法和batch()方法调转,会导致的结果是对批次打乱,而不是对数据集里的数据打乱。

for batch in dataset.batch(5).shuffle(5):
  print(batch)
...
tf.Tensor([5 6 7 8 9], shape=(5,), dtype=int64)
tf.Tensor([15 16 17 18], shape=(4,), dtype=int64)
tf.Tensor([0 1 2 3 4], shape=(5,), dtype=int64)
tf.Tensor([10 11 12 13 14], shape=(5,), dtype=int64)

7、转换数据

如果我们想对导入的数据做预处理,可以使用map方法。

def tranform(data):
  mean = tf.reduce_mean(data)
  return data - mean
for batch in dataset.shuffle(5).batch(5).map(tranform):
  print(batch)
...
tf.Tensor([ 2  3 -1  0 -2], shape=(5,), dtype=int64)
tf.Tensor([-2 -5  2  3  4], shape=(5,), dtype=int64)
tf.Tensor([-1  1  2 -5  3], shape=(5,), dtype=int64)
tf.Tensor([ 3 -3  7 -4], shape=(4,), dtype=int64)

8、预取指定大小的batch来做训练

通常,读取和处理dataset的数据会很耗时,即耗CPU时间,为了让GPU不出现太多空闲,可以使用prefetch()方法预取一定数据的batch来做训练。

dataset.shuffle(5).batch(5).prefetch(buffer_size=tf.data.experimental.AUTOTUNE)

其中,把buffer_size设置为tf.data.experimental.AUTOTUNE,意思是让TensorFlow自己找到一个合适的最优的buffer_size。

 

 

版权声明:著作权归作者所有。

相关推荐

java.util.Objects的使用

java.utils.Objects针对Object对象提供了几个静态的工具方法,这些方法可以归类为:null安全检查对象比较计算对象hash code对象转换为Stringnull安全检查null安全检查有5个方法:isNull(Object obj):检查对象是否为null,null返回true,否则返回falsenonNull(Object obj):与isNu

Android使用Fresco加载图片的用法

在Android的App开发中,延时加载图片是硬需求。有好几个开源的项目也提供了延时加载图片的功能,常用的有:Fresco,Glide和Universal Image Loader。这里主要简单介绍下Fresco。添加依赖在build.gradle添加依赖如下:dependencies {   ...    compi

Java Predicate接口的使用

Java 8新增了Predicate接口,它是一个函数接口,提供的test函数会接收一个参数,并返回一个bool值,我们可以用它来做过滤,检测类等功能。源码说明@FunctionalInterfacepublic interface Predicate<T> { /** * 具体过滤操作 需要被子类实现. * 用来处理参数T是否满足要求,可以理解为 条件A

SwiftUI Alerts的使用示例

SwiftUI里的Alerts可以分为三类:警告对话框(Alert Dialogs)操作列表(Action Sheets)弹窗(Popovers)警告对话框(Alert Dialogs)示例使用SwiftUI,我们可以很容易地使用声明的方式来创建警告框以及定义操作,示例如下:struct AlertView: View { @State private var showingAlert =

.NET Core 3 System.Text.JSON的使用

ASP.NET Core 3.0引入了原生支持处理json的Sytem.Text.Json,替换了之前的Newtonsoft.Json。如果项目中使用的是.NET Standard或者.NET framework(v4.6.1+),如果想使用System.Text.Json,则需要安装System.Text.Jsonde NuGet包。添加Namespaces使用System.Text.Json首

shell脚本变量的使用

问var=value 在export前后的差在哪? 这次让我们暂时丢开command line,先了解一下bash变量(variable)吧…所谓的变量,就是利用一个固定的”名称”(name),来存取一段可以变化的”值”(value)。1. 变量设定(set)在bash中, 你可以用”=”来设定或者重新定义变量的内容: name=value 在设定变量的时候,得遵守如

Linux下gcc的使用

在Linux系统中,可执行文件没有统一的后缀,系统从文件的属性来区分可执行文件和不可执行文件.而gcc则通过后缀来区别输入文件的类别,下面介绍gcc所遵循的部分约定规则. .c为后缀的文件,是C语言源代码文件; .a为后缀的文件,是由目标文件构成的库文件; .C,.cc,.cxx为后缀的文件,是C++源代码文件; .h为后缀的文件,是程序所包含的头文件; .i为后缀的文件

ES6中async的使用案例

在项目中有时会遇到异步操作的问题,async就是解决异步操作的终极操作。我会以终极三问(what,why,when)的形式来说明什么是async。由于这是第一篇文章不知道怎么写,有很大部分是借鉴阮一峰老的原文,事例将会从我的项目中摘取。 async是什么? 官方例子 官方文档 async相当于对Generator 函数的一个语法糖const fs = require

LINQ group by的使用示例

下面通过一个示例来展示linq中group by的使用。类Person如下:class Person { internal int PersonId; internal string car ; }Person列表List<Person>:persons[0] = new Person { PersonID = 1, car = "Ferrari" }; person