
Speeding Up Network Training with Multi-threaded Data Loading

November 27, 2017

Generating data with multiple threads

Consider this approach when the CPU cannot read and prepare data fast enough to keep up with the GPU. Its advantages are a simple data interface and a substantial reduction in overall training time, especially on servers with many CPU cores. Experiments confirm that this method can greatly improve GPU utilization.

First, define a class that wraps a data generator with a queue:

"""
this file is modified from the keras implementation of multi-threaded data processing,
see https://github.com/fchollet/keras/blob/master/keras/utils/data_utils.py
"""
import time
import numpy as np
import threading
import multiprocessing
try:
    import queue
except ImportError:
    import Queue as queue


class GeneratorEnqueuer():
    """
    Builds a queue out of a data generator.
    Used in `fit_generator`, `evaluate_generator`, `predict_generator`.
    # Arguments
        generator: a generator function which endlessly yields data
        use_multiprocessing: use multiprocessing if True, otherwise threading
        wait_time: time to sleep in-between calls to `put()`
        random_seed: Initial seed for workers,
            will be incremented by one for each workers.
    """

    def __init__(self, generator,
                 use_multiprocessing=False,
                 wait_time=0.05,
                 random_seed=None):
        self.wait_time = wait_time
        self._generator = generator
        self._use_multiprocessing = use_multiprocessing
        self._threads = []
        self._stop_event = None
        self.queue = None
        self.random_seed = random_seed

    def start(self, workers=1, max_queue_size=10):
        """Kicks off threads which add data from the generator into the queue.
        # Arguments
            workers: number of worker threads
            max_queue_size: queue size
                (when full, threads could block on `put()`)
        """

        def data_generator_task():
            while not self._stop_event.is_set():
                try:
                    if self._use_multiprocessing or self.queue.qsize() < max_queue_size:
                        generator_output = next(self._generator)
                        self.queue.put(generator_output)
                    else:
                        time.sleep(self.wait_time)
                except Exception:
                    self._stop_event.set()
                    raise

        try:
            if self._use_multiprocessing:
                self.queue = multiprocessing.Queue(maxsize=max_queue_size)
                self._stop_event = multiprocessing.Event()
            else:
                self.queue = queue.Queue()
                self._stop_event = threading.Event()

            for _ in range(workers):
                if self._use_multiprocessing:
                    # Reset random seed else all children processes
                    # share the same seed
                    np.random.seed(self.random_seed)
                    thread = multiprocessing.Process(target=data_generator_task)
                    thread.daemon = True
                    if self.random_seed is not None:
                        self.random_seed += 1
                else:
                    thread = threading.Thread(target=data_generator_task)
                self._threads.append(thread)
                thread.start()
        except:
            self.stop()
            raise

    def is_running(self):
        return self._stop_event is not None and not self._stop_event.is_set()

    def stop(self, timeout=None):
        """Stops running threads and wait for them to exit, if necessary.
        Should be called by the same thread which called `start()`.
        # Arguments
            timeout: maximum time to wait on `thread.join()`.
        """
        if self.is_running():
            self._stop_event.set()

        for thread in self._threads:
            if thread.is_alive():
                if self._use_multiprocessing:
                    thread.terminate()
                else:
                    thread.join(timeout)

        if self._use_multiprocessing:
            if self.queue is not None:
                self.queue.close()

        self._threads = []
        self._stop_event = None
        self.queue = None

    def get(self):
        """Creates a generator to extract data from the queue.
        Skip the data if it is `None`.
        # Returns
            A generator
        """
        while self.is_running():
            if not self.queue.empty():
                inputs = self.queue.get()
                if inputs is not None:
                    yield inputs
            else:
                time.sleep(self.wait_time)
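The class above is used through start() / get() / stop(). Below is a minimal usage sketch with a toy counting generator; the generator is only for illustration, a real one would yield training batches:

def _toy_generator():
    # an endless generator, as GeneratorEnqueuer expects
    i = 0
    while True:
        yield i
        i += 1

enqueuer = GeneratorEnqueuer(_toy_generator(), use_multiprocessing=False)
enqueuer.start(workers=1, max_queue_size=10)
batches = enqueuer.get()       # a generator that pulls items out of the internal queue
print(next(batches))
enqueuer.stop()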

Then define a data_provide function that produces a single data sample and its label:

def data_provide():
    ...
    return data, label
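For illustration only, a data_provide that draws one random sample could look like the sketch below; the cv2 dependency, the fixed 256x256 size, and the "path label" list file (the same format as the train_list.txt used later in this post) are assumptions of this sketch:

import random
import cv2
import numpy as np

# assumed: a list file with one "path/image.jpg label" pair per line
samples = [line.split(' ') for line in open('train_list.txt').read().splitlines()]

def data_provide():
    # pick one random sample, decode the image and return (image, label)
    path, label = random.choice(samples)
    im = cv2.imread(path)                                   # HWC, BGR, uint8
    im = cv2.resize(im, (256, 256)).astype(np.float32)
    return im, np.int64(label)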

The batch generator:

def generator(batch_size=32):
    images = []
    labels = []
    while True:
        try:
            im, label = data_provide()
            images.append(im)
            labels.append(label)

            if len(images) == batch_size:
                yield images, labels
                images = []
                labels = []
        except Exception as e:
            print(e)
            import traceback
            traceback.print_exc()
            continue

def get_batch(num_workers, **kwargs):
    enqueuer = None  # initialized first so the finally clause is safe if construction fails
    try:
        enqueuer = GeneratorEnqueuer(generator(**kwargs), use_multiprocessing=True)
        enqueuer.start(max_queue_size=64, workers=num_workers)
        generator_output = None
        while True:
            while enqueuer.is_running():
                if not enqueuer.queue.empty():
                    generator_output = enqueuer.queue.get()
                    break
                else:
                    time.sleep(0.01)
            yield generator_output
            generator_output = None
    finally:
        if enqueuer is not None:
            enqueuer.stop()

if __name__ == '__main__':
    gen = get_batch(num_workers=64, batch_size=128)
    while 1:
        start = time.time()
        images, labels = next(gen)
        end = time.time()
        print(end - start)
        print(len(images), " ", images[0].shape)
        print(len(labels), " ", labels[0].shape)
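A minimal sketch of how these batches could feed a training step; the placeholders, the tiny linear model, the loss and the optimizer below are hypothetical stand-ins for your own network:

import numpy as np
import tensorflow as tf

images_ph = tf.placeholder(tf.float32, [None, 256, 256, 3])
labels_ph = tf.placeholder(tf.int64, [None])
flat = tf.reshape(images_ph, [-1, 256 * 256 * 3])
logits = tf.layers.dense(flat, 2)                 # stand-in for a real model
loss = tf.reduce_mean(
    tf.nn.sparse_softmax_cross_entropy_with_logits(labels=labels_ph, logits=logits))
train_op = tf.train.AdamOptimizer(1e-4).minimize(loss)

gen = get_batch(num_workers=8, batch_size=32)
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    for step in range(1000):
        images, labels = next(gen)                # already assembled by the workers
        _, loss_val = sess.run([train_op, loss],
                               feed_dict={images_ph: np.asarray(images),
                                          labels_ph: np.asarray(labels)})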

Using the Dataset API built into TensorFlow 1.3+

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import numpy as np
import random

import tensorflow as tf

class DataReader(object):

    def __init__(self, num_class = 2, batch_size = 128, epoch = 10, file_list = None):
        self.num_class = num_class
        self.train_data_list = file_list
        self.batch_size = batch_size
        self.epoch = epoch
        self.dataset_iterator = None
        self._init_dataset()

    def _get_filename_list(self):
        lines = open(self.train_data_list,'r').read().splitlines()
        random.shuffle(lines)
        filename_list = []
        label_list = []
        for i in range(len(lines)):
            filename_list.append(lines[i].split(' ')[0])
            label_list.append(int(lines[i].split(' ')[1]))
        return tf.constant(filename_list), tf.constant(label_list) #must be tensor type

    def _parse_function(self, filename, label):
        image_string = tf.read_file(filename)
        image_decoded = tf.image.decode_image(image_string)
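        # note: decode_image leaves the static shape unknown, so the images are
        # assumed to already be 256x256 RGB for the reshape below to succeed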
        channel_mean = tf.constant(np.array([123.68,116.779,103.938], dtype=np.float32))
        image = tf.subtract(tf.cast(image_decoded, dtype=tf.float32),channel_mean)
        image_label = tf.one_hot(indices = label, depth = self.num_class)
        return tf.reshape(image,[256,256,3]), tf.reshape(image_label,[self.num_class])

    def _init_dataset(self):
        # Read sample files
        sample_files, sample_labels = self._get_filename_list()
        # make dataset
        dataset = tf.data.Dataset.from_tensor_slices((sample_files, sample_labels))
        dataset = dataset.map(self._parse_function)
        dataset = dataset.shuffle(1000).batch(self.batch_size).repeat(self.epoch)
        self.dataset_iterator = dataset.make_one_shot_iterator()

    def inputs(self):
        return self.dataset_iterator.get_next()

if __name__ == '__main__':
    data_generator = DataReader(file_list="train_list.txt")
    """
    train_list.txt contains one "path label" pair per line (the label is an integer class index):

    path/image1.jpg label_1
    path/image2.jpg label_2
    path/image3.jpg label_3
    path/image4.jpg label_4
    """
    sess = tf.Session()
    x, y = sess.run(data_generator.inputs())
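Alternatively, the tensors returned by inputs() can be wired straight into a graph so the whole input pipeline runs inside sess.run(). A minimal sketch; the flatten + dense head, loss and optimizer are hypothetical stand-ins for a real network:

images, labels = data_generator.inputs()          # tensors, evaluated inside sess.run
flat = tf.reshape(images, [-1, 256 * 256 * 3])
logits = tf.layers.dense(flat, data_generator.num_class)    # stand-in model
loss = tf.losses.softmax_cross_entropy(onehot_labels=labels, logits=logits)
train_op = tf.train.AdamOptimizer(1e-4).minimize(loss)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    try:
        while True:
            _, loss_val = sess.run([train_op, loss])   # pulls the next batch automatically
    except tf.errors.OutOfRangeError:
        pass                                           # all epochs have been consumed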

If you need to process the data with plain Python, or have other special requirements, you can use the following interface instead, which works seamlessly with Python and NumPy:

    def _read_py_function(self, filename, label):
        image_decoded = py_function(filename)
        label = py_function(label)
        return image_decoded, label


    def _init_dataset(self):

        # Read sample files
        sample_files, sample_labels = self._get_filename_list()

        # make dataset
        dataset = tf.data.Dataset.from_tensor_slices((sample_files, sample_labels))

        dataset = dataset.map(
            lambda filename, label: tuple(tf.py_func(
                self._read_py_function, [filename, label], [tf.int32, label.dtype])))

        self.dataset_iterator = dataset.make_one_shot_iterator()
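As a concrete illustration, _read_py_function could decode images with OpenCV; the cv2 dependency and the fixed 256x256 size are assumptions of this sketch, and the int32 cast matches the [tf.int32, label.dtype] output types declared above:

    def _read_py_function(self, filename, label):
        # runs as ordinary Python inside tf.py_func, so cv2/numpy code is allowed
        # (assumes `import cv2` at module level)
        image = cv2.imread(filename.decode('utf-8'))        # filename arrives as bytes
        image = cv2.resize(image, (256, 256)).astype(np.int32)
        return image, label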

Using Python's multiprocessing package to process and generate data in parallel

import numpy as np
import time
from multiprocessing import Pool, cpu_count

def batch_works(k):
    # split the set of jobs into one chunk per CPU core
    if k == n_processes - 1:
        nums = jobs[k * int(len(jobs) / n_processes) : ]
    else:
        nums = jobs[k * int(len(jobs) / n_processes) : (k + 1) * int(len(jobs) / n_processes)]
    for j in nums:
        py_function()

if __name__ == '__main__':
    jobs = range(100)  # the sequence of tasks to execute
    n_processes = cpu_count()

    pool = Pool(processes=n_processes)
    start = time.time()
    pool.map(batch_works, range(n_processes))
    pool.close()
    pool.join()
    end = time.time()
    print(end - start)
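A variant of the same pattern that also collects what each worker returns via pool.map; this sketch is self-contained, and the squaring only stands in for real preprocessing work:

import time
from multiprocessing import Pool, cpu_count

def batch_works_with_result(nums):
    # placeholder work; replace with your real per-chunk processing
    return [n * n for n in nums]

if __name__ == '__main__':
    jobs = list(range(100))
    n_processes = cpu_count()
    # split the job list into one chunk per process
    chunks = [jobs[i::n_processes] for i in range(n_processes)]

    pool = Pool(processes=n_processes)
    start = time.time()
    results = pool.map(batch_works_with_result, chunks)   # one result list per chunk
    pool.close()
    pool.join()
    print(time.time() - start)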
