Bootstrap

MindSpore易点通·精讲系列--数据集加载之MindDataset

Dive Into MindSpore – MindDataset For Dataset Load

MindSpore易点通·精讲系列–数据集加载之MindDataset

本文开发环境

  • Ubuntu 20.04
  • Python 3.8
  • MindSpore 1.7.0

本文内容摘要

  • 背景介绍
  • 先看文档
  • 数据生成
  • 数据加载
  • 问题解答
  • 本文总结
  • 本文参考

1. 背景介绍

在前面的文章中,我们介绍了ImageFolderDatasetCSVDatasetTFRecordDataset三个数据集加载API。本文为数据集加载部分的最后一篇文章(当然,如果后续读者有需要,再考虑补充其他API精讲),我们将介绍MindSpore中官方数据格式MindRecord加载所涉及的APIMindDataset

一个完整的机器学习工作流包括数据集读取(可能包含数据处理)、模型定义、模型训练、模型评估。如何在工作流中更好的读取数据,是各个深度学习框架需要解决的一个重要问题。为此,TensorFlow推出了TFRecord数据格式,而MindSpore给出的解决方案就是MindRecord。在正式开始本文的讲解之前,先来看看MindRecord数据格式的特点:

  1. 实现数据统一存储、访问,使得训练时数据读取更加简便。
  2. 数据聚合存储、高效读取,使得训练时数据方便管理和移动。
  3. 高效的数据编解码操作,使得用户可以对数据操作无感知。
  4. 可以灵活控制数据切分的分区大小,实现分布式数据处理。

2. 先看文档

老传统,先看官方文档。

下面对官方文档中的参数,做简单解读:

  • dataset_files – 类型为字符串或者列表。如果为字符串则按照匹配规则自动寻找并加载相应前缀的MindRecord文件;如果为列表,则读取列表内的MindRecord文件,即列表内要为具体的文件名。
  • columns_list – 指定从MindRecord数据文件中读取的数据字段,或者说数据列。默认值为None,即读取全部字段或数据列。
  • 其他参数参见之前文章中的相关解读。

3. 数据生成

本文使用的是THUCNews数据集,如果需要将该数据集用于商业用途,请联系数据集作者

数据集启智社区下载地址

在上面API解读中,我们讲到MindDasetset读取的是MindRecord文件,下面就来介绍一下如何生成MindRecord数据文件。

MindRecord数据文件生成可以简单包含以下几个部分(非顺序):

  • 读取及处理原始数据
  • 声明MindRecord文件格式
  • 定义MindRecord数据字段
  • 添加MindRecord索引字段
  • 写入MindRecord数据内容

3.1 生成代码

下面我们基于THUCNews数据集,来生成MindRecord数据。

3.1.1 代码部分

import codecs
import os
import re

import numpy as np

from collections import Counter
from mindspore.mindrecord import FileWriter


def get_txt_files(data_dir):
    cls_txt_dict = {}
    txt_file_list = []

    # get files list and class files list.
    sub_data_name_list = next(os.walk(data_dir))[1]
    sub_data_name_list = sorted(sub_data_name_list)
    for sub_data_name in sub_data_name_list:
        sub_data_dir = os.path.join(data_dir, sub_data_name)
        data_name_list = next(os.walk(sub_data_dir))[2]
        data_file_list = [os.path.join(sub_data_dir, data_name) for data_name in data_name_list]
        cls_txt_dict[sub_data_name] = data_file_list
        txt_file_list.extend(data_file_list)
        num_data_files = len(data_file_list)
        print("{}: {}".format(sub_data_name, num_data_files), flush=True)
    num_txt_files = len(txt_file_list)
    print("total: {}".format(num_txt_files), flush=True)

    return cls_txt_dict, txt_file_list


def get_txt_data(txt_file):
    with codecs.open(txt_file, "r", "UTF8") as fp:
        txt_content = fp.read()
    txt_data = re.sub("\s+", " ", txt_content)

    return txt_data


def build_vocab(txt_file_list, vocab_size=7000):
    counter = Counter()
    for txt_file in txt_file_list:
        txt_data = get_txt_data(txt_file)
        counter.update(txt_data)

    num_vocab = len(counter)
    if num_vocab < vocab_size - 1:
        real_vocab_size = num_vocab + 2
    else:
        real_vocab_size = vocab_size

    # pad_id is 0, unk_id is 1
    vocab_dict = {word_freq[0]: ix + 1 for ix, word_freq in enumerate(counter.most_common(real_vocab_size - 2))}

    print("real vocab size: {}".format(real_vocab_size), flush=True)
    print("vocab dict:\n{}".forma
;