Bootstrap

3-Python数据划分代码-小记

Python划分数据集

  1. 由于模型需要使用训练集、验证集和测试集,而我只有一个总的数据集,因此用Python实现了数据集的划分,特此小记一下。
  2. 同时也是为了记录这个过程中用到的Python的一些知识。

1划分要求

  • 原始数据集的格式为:第一行为标签,之后每一行为一条数据。为了处理方便,写代码的时候使用的数据集就从简了,如下图所示。

在这里插入图片描述

  • 训练集、验证集、测试集的格式要求和原始数据集一样

2函数名称解释

  • 数据集划分的代码我写在一个函数里面。函数为:split_data(before_dataset_filepath, output_dir,split_prop)

  • 参数解释:

    • before_dataset_filepath:待划分的数据集文件的路径,是相对路径。
    • output_dir:训练集、测试集、验证集文件所在的目录,是相对路径。
    • split_prop:划分比例的列表。这里参照了文献的划分比例,即[训练集:验证集:测试集]=[3:1:2]。该列表共三个元素,例如[3,1,2],表示[训练集:验证集:测试集]=[3:1:2]

3划分过程

3.1读取待划分数据

  1. 这里使用了numpy.loadtxt函数,以字符串的形式将数据读取进来,用numpy.ndarray进行存储。
  2. 分隔符为换行符
  3. 以字符串的形式读取每一条记录
data_list = np.loadtxt(before_dataset_filepath, dtype="str", comments=None, delimiter="\n", encoding="utf-8-sig")

3.2确定各数据集的数量

  1. 首先确定读取进来的数据的记录数(第一行为标签行,不能算在内)
  2. 根据设置好的划分比例,计算每个数据集的数量(这里没有进行四舍五入)。
####确定样本总数,以及训练、验证、测试集的数量
total_num = len(data_list) - 1
print("total_num:", total_num)
train_num=int(split_prop[0]/sum(split_prop)*total_num)#使用int()相当于直接舍弃掉小数位
dev_num = int(split_prop[1] / sum(split_prop)*total_num)
test_num = total_num-train_num-dev_num
print("train_num:", train_num)
print("dev_num:", dev_num)
print("test_num:", test_num)
####我这里的示例结果:
#total_num: 10
#train_num: 5
#dev_num: 1
#test_num: 4

3.3删除标签行

确定了各数据集的数量之后,就需要从待划分数据集中抽取数据了。但在这之前需要把第一行的标签行去掉。这里使用了numpy.delete()函数。

#删除第一行的标签
data_list=numpy.delete(data_list,[0])#这里的data_list就是一个一维的数组,因此第二个参数指定为了该数组的第一个元素

3.4打乱数据集

这里采用了random.shuffle()来随机打乱待划分数据集数组的索引,比较方便。

  • 生成一个代表原数据集的索引列表,列表元素就是原数据集的每一个索引
  • 然后打乱这个生成的列表
####随机打乱划分前的数据
index=list(range(0,total_num))
print("打乱前索引:",index)
random.shuffle(index)
print("打乱后索引:", index)
####我这里的示例结果:
#打乱前索引: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
#打乱后索引: [2, 3, 8, 0, 6, 9, 1, 5, 4, 7]

3.5读取并写入对应数据集文件

这里有两点需要注意:

  1. 为了每次运行该程序的方便,需要在写入之前先将对应的文件清空。我这里采用的方式是打开对应的文件,然后以'w'的方式写入一个空字符串。官方文档也说了,f.truncate()函数不常用,而且我用了之后好像也没成功~
  2. 注意读取index索引列表时候的下标范围。
####读取并写入训练数据集文件
train_txt=output_dir+"train.txt"#训练数据集文件的路径
#先清空文件内容
with open(train_txt, 'w', encoding="utf-8-sig") as f1:
    f1.write("")
#再写入
with open(train_txt,'a',encoding="utf-8-sig") as f1:
    #先写标签
    f1.write("text_a\tlabel")
    f1.write("\n")
    for i in range(0,train_num):
        f1.write(data_list[index[i]])
        f1.write("\n")
####读取并写入验证数据集文件
dev_txt = output_dir + "dev.txt"
# 先清空文件内容
with open(dev_txt, 'w', encoding="utf-8-sig") as f2:
    f2.write("")
with open(dev_txt,'a',encoding="utf-8-sig") as f2:
    f2.write("text_a\tlabel")
    f2.write("\n")
    for i in range(train_num,train_num+dev_num):
        f2.write(data_list[index[i]])
        f2.write("\n")
####读取并写入测试数据集文件
test_txt = output_dir + "test.txt"
# 先清空文件内容
with open(test_txt, 'w', encoding="utf-8-sig") as f3:
    f3.write("")
with open(test_txt,'a',encoding="utf-8-sig") as f3:
    f3.write("text_a\tlabel")
    f3.write("\n")
    for i in range(train_num+dev_num,total_num):
        f3.write(data_list[index[i]])
        f3.write("\n")

4完整的代码

  1. 首先是封装好的split_data()函数
def split_data(before_dataset_filepath, output_dir,split_prop):
    data_list = np.loadtxt(before_dataset_filepath, dtype="str", comments=None, delimiter="\n", encoding="utf-8-sig")
    print(type(data_list))

    #确定样本总数,以及训练、验证、测试集的数量
    total_num = len(data_list) - 1
    print("total_num:", total_num)
    train_num=int(split_prop[0]/sum(split_prop)*total_num)
    dev_num = int(split_prop[1] / sum(split_prop)*total_num)
    test_num = total_num-train_num-dev_num
    print("train_num:", train_num)
    print("dev_num:", dev_num)
    print("test_num:", test_num)
    #删除第一行的标签
    data_list=numpy.delete(data_list,[0])
    #随机打乱划分前的数据
    index=list(range(0,total_num))
    print("打乱前索引:",index)
    random.shuffle(index)
    print("打乱后索引:", index)

    #读取并写入训练数据集文件
    train_txt=output_dir+"train.txt"
    #先清空文件内容
    with open(train_txt, 'w', encoding="utf-8-sig") as f1:
        f1.write("")
    with open(train_txt,'a',encoding="utf-8-sig") as f1:

        f1.write("text_a\tlabel")
        f1.write("\n")
        for i in range(0,train_num):
            f1.write(data_list[index[i]])
            f1.write("\n")
    # 读取并写入验证数据集文件
    dev_txt = output_dir + "dev.txt"
    # 先清空文件内容
    with open(dev_txt, 'w', encoding="utf-8-sig") as f2:
        f2.write("")
    with open(dev_txt,'a',encoding="utf-8-sig") as f2:

        f2.write("text_a\tlabel")
        f2.write("\n")
        for i in range(train_num,train_num+dev_num):
            f2.write(data_list[index[i]])
            f2.write("\n")
    # 读取并写入测试数据集文件
    test_txt = output_dir + "test.txt"
    # 先清空文件内容
    with open(test_txt, 'w', encoding="utf-8-sig") as f3:
        f3.write("")
    with open(test_txt,'a',encoding="utf-8-sig") as f3:

        f3.write("text_a\tlabel")
        f3.write("\n")
        for i in range(train_num+dev_num,total_num):
            f3.write(data_list[index[i]])
            f3.write("\n")
  1. 其次是调用的代码
####相关参数
PROC_DATA_DIR = "data/processed_data/"
TRAIN_DATA_FILENAME = "train_data.txt"
TRAIN_DATA_FILEPATH = os.path.join(PROC_DATA_DIR, TRAIN_DATA_FILENAME)
out_put_dir="data/split_data/"
split_prop=[3,1,2]#[训练集,验证集,测试集]
####调用
split_data(TRAIN_DATA_FILEPATH,out_put_dir,split_prop)

TRAIN_DATA_FILEPATH = os.path.join(PROC_DATA_DIR, TRAIN_DATA_FILENAME)
out_put_dir="data/split_data/"
split_prop=[3,1,2]#[训练集,验证集,测试集]
####调用
split_data(TRAIN_DATA_FILEPATH,out_put_dir,split_prop)
;