Python划分数据集
- 由于模型需要使用训练集、验证集和测试集,而我只有一个总的数据集,因此用Python实现了数据集的划分,特此小记一下。
- 同时也是为了记录这个过程中用到的Python的一些知识。
1划分要求
- 原始数据集的格式为:第一行为标签,之后每一行为一条数据。为了处理方便,写代码的时候使用的数据集就从简了,如下图所示。
- 训练集、验证集、测试集的格式要求和原始数据集一样
2函数名称解释
-
数据集划分的代码我写在一个函数里面。函数为:
split_data(before_dataset_filepath, output_dir,split_prop)
-
参数解释:
before_dataset_filepath
:待划分的数据集文件的路径,是相对路径。output_dir
:训练集、测试集、验证集文件所在的目录,是相对路径。split_prop
:划分比例的列表。这里参照了文献的划分比例,即[训练集:验证集:测试集]=[3:1:2]
。该列表共三个元素,例如[3,1,2]
,表示[训练集:验证集:测试集]=[3:1:2]
。
3划分过程
3.1读取待划分数据
- 这里使用了
numpy.loadtxt
函数,以字符串的形式将数据读取进来,用numpy.ndarray
进行存储。 - 分隔符为换行符
- 以字符串的形式读取每一条记录
data_list = np.loadtxt(before_dataset_filepath, dtype="str", comments=None, delimiter="\n", encoding="utf-8-sig")
3.2确定各数据集的数量
- 首先确定读取进来的数据的记录数(第一行为标签行,不能算在内)
- 根据设置好的划分比例,计算每个数据集的数量(这里没有进行四舍五入)。
####确定样本总数,以及训练、验证、测试集的数量
total_num = len(data_list) - 1
print("total_num:", total_num)
train_num=int(split_prop[0]/sum(split_prop)*total_num)#使用int()相当于直接舍弃掉小数位
dev_num = int(split_prop[1] / sum(split_prop)*total_num)
test_num = total_num-train_num-dev_num
print("train_num:", train_num)
print("dev_num:", dev_num)
print("test_num:", test_num)
####我这里的示例结果:
#total_num: 10
#train_num: 5
#dev_num: 1
#test_num: 4
3.3删除标签行
确定了各数据集的数量之后,就需要从待划分数据集中抽取数据了。但在这之前需要把第一行的标签行去掉。这里使用了numpy.delete()
函数。
#删除第一行的标签
data_list=numpy.delete(data_list,[0])#这里的data_list就是一个一维的数组,因此第二个参数指定为了该数组的第一个元素
3.4打乱数据集
这里采用了random.shuffle()
来随机打乱待划分数据集数组的索引,比较方便。
- 生成一个代表原数据集的索引列表,列表元素就是原数据集的每一个索引
- 然后打乱这个生成的列表
####随机打乱划分前的数据
index=list(range(0,total_num))
print("打乱前索引:",index)
random.shuffle(index)
print("打乱后索引:", index)
####我这里的示例结果:
#打乱前索引: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
#打乱后索引: [2, 3, 8, 0, 6, 9, 1, 5, 4, 7]
3.5读取并写入对应数据集文件
这里有两点需要注意:
- 为了每次运行该程序的方便,需要在写入之前先将对应的文件清空。我这里采用的方式是打开对应的文件,然后以
'w'
的方式写入一个空字符串。官方文档也说了,f.truncate()
函数不常用,而且我用了之后好像也没成功~ - 注意读取
index
索引列表时候的下标范围。
####读取并写入训练数据集文件
train_txt=output_dir+"train.txt"#训练数据集文件的路径
#先清空文件内容
with open(train_txt, 'w', encoding="utf-8-sig") as f1:
f1.write("")
#再写入
with open(train_txt,'a',encoding="utf-8-sig") as f1:
#先写标签
f1.write("text_a\tlabel")
f1.write("\n")
for i in range(0,train_num):
f1.write(data_list[index[i]])
f1.write("\n")
####读取并写入验证数据集文件
dev_txt = output_dir + "dev.txt"
# 先清空文件内容
with open(dev_txt, 'w', encoding="utf-8-sig") as f2:
f2.write("")
with open(dev_txt,'a',encoding="utf-8-sig") as f2:
f2.write("text_a\tlabel")
f2.write("\n")
for i in range(train_num,train_num+dev_num):
f2.write(data_list[index[i]])
f2.write("\n")
####读取并写入测试数据集文件
test_txt = output_dir + "test.txt"
# 先清空文件内容
with open(test_txt, 'w', encoding="utf-8-sig") as f3:
f3.write("")
with open(test_txt,'a',encoding="utf-8-sig") as f3:
f3.write("text_a\tlabel")
f3.write("\n")
for i in range(train_num+dev_num,total_num):
f3.write(data_list[index[i]])
f3.write("\n")
4完整的代码
- 首先是封装好的
split_data()
函数
def split_data(before_dataset_filepath, output_dir,split_prop):
data_list = np.loadtxt(before_dataset_filepath, dtype="str", comments=None, delimiter="\n", encoding="utf-8-sig")
print(type(data_list))
#确定样本总数,以及训练、验证、测试集的数量
total_num = len(data_list) - 1
print("total_num:", total_num)
train_num=int(split_prop[0]/sum(split_prop)*total_num)
dev_num = int(split_prop[1] / sum(split_prop)*total_num)
test_num = total_num-train_num-dev_num
print("train_num:", train_num)
print("dev_num:", dev_num)
print("test_num:", test_num)
#删除第一行的标签
data_list=numpy.delete(data_list,[0])
#随机打乱划分前的数据
index=list(range(0,total_num))
print("打乱前索引:",index)
random.shuffle(index)
print("打乱后索引:", index)
#读取并写入训练数据集文件
train_txt=output_dir+"train.txt"
#先清空文件内容
with open(train_txt, 'w', encoding="utf-8-sig") as f1:
f1.write("")
with open(train_txt,'a',encoding="utf-8-sig") as f1:
f1.write("text_a\tlabel")
f1.write("\n")
for i in range(0,train_num):
f1.write(data_list[index[i]])
f1.write("\n")
# 读取并写入验证数据集文件
dev_txt = output_dir + "dev.txt"
# 先清空文件内容
with open(dev_txt, 'w', encoding="utf-8-sig") as f2:
f2.write("")
with open(dev_txt,'a',encoding="utf-8-sig") as f2:
f2.write("text_a\tlabel")
f2.write("\n")
for i in range(train_num,train_num+dev_num):
f2.write(data_list[index[i]])
f2.write("\n")
# 读取并写入测试数据集文件
test_txt = output_dir + "test.txt"
# 先清空文件内容
with open(test_txt, 'w', encoding="utf-8-sig") as f3:
f3.write("")
with open(test_txt,'a',encoding="utf-8-sig") as f3:
f3.write("text_a\tlabel")
f3.write("\n")
for i in range(train_num+dev_num,total_num):
f3.write(data_list[index[i]])
f3.write("\n")
- 其次是调用的代码
####相关参数
PROC_DATA_DIR = "data/processed_data/"
TRAIN_DATA_FILENAME = "train_data.txt"
TRAIN_DATA_FILEPATH = os.path.join(PROC_DATA_DIR, TRAIN_DATA_FILENAME)
out_put_dir="data/split_data/"
split_prop=[3,1,2]#[训练集,验证集,测试集]
####调用
split_data(TRAIN_DATA_FILEPATH,out_put_dir,split_prop)
TRAIN_DATA_FILEPATH = os.path.join(PROC_DATA_DIR, TRAIN_DATA_FILENAME)
out_put_dir="data/split_data/"
split_prop=[3,1,2]#[训练集,验证集,测试集]
####调用
split_data(TRAIN_DATA_FILEPATH,out_put_dir,split_prop)