首先是 `PySODEvalToolkit/tools/check_path.py` 这个文件,需要替换成我修改后的如下版本:
# -*- coding: utf-8 -*-
import argparse
import json
import os
from collections import OrderedDict
def build_parser():
    """Build the command-line parser for the json config checker."""
    parser = argparse.ArgumentParser(description="A simple tool for checking your json config file.")
    parser.add_argument(
        "-m", "--method-jsons", nargs="+", required=True, help="The json file about all methods."
    )
    parser.add_argument(
        "-d", "--dataset-jsons", nargs="+", required=True, help="The json file about all datasets."
    )
    return parser


def load_ordered_json(path):
    """Load the JSON file at ``path``, keeping key order (objects become ``OrderedDict``)."""
    with open(path, encoding="utf-8", mode="r") as f:
        return json.load(f, object_hook=OrderedDict)


def strip_affixes(name, prefix, suffix):
    """Return ``name`` with a leading ``prefix`` and a trailing ``suffix`` removed.

    The caller must already have checked ``name.startswith(prefix)`` and
    ``name.endswith(suffix)``. Slicing up to ``len(name) - len(suffix)``
    (instead of ``-len(suffix)``) keeps the whole tail when ``suffix`` is empty;
    ``name[a:-0]`` would wrongly yield ``""``.
    """
    return name[len(prefix):len(name) - len(suffix)]


def check_pair(methods_info, datasets_info):
    """Check every method's prediction folders against the dataset ground truth.

    Args:
        methods_info: mapping ``method name -> {dataset name -> info or None}``,
            where ``info`` has ``"path"``, ``"suffix"`` and optionally ``"prefix"``.
        datasets_info: mapping ``dataset name -> {"mask": {"path", "suffix"}, ...}``.

    Returns:
        A list of human-readable problem messages; empty when everything matches.
        Progress and debug information is printed while checking.
    """
    total_msgs = []
    for method_name, method_info in methods_info.items():
        print(f"Checking for {method_name} ...")
        for dataset_name, results_info in method_info.items():
            if results_info is None:
                continue
            # Report (instead of crashing with KeyError) datasets that the
            # method config uses but the dataset config does not define.
            if dataset_name not in datasets_info:
                total_msgs.append(f"{method_name} 中的数据集 {dataset_name} 未在数据集配置文件中定义")
                continue

            dataset_mask_info = datasets_info[dataset_name]["mask"]
            mask_path = dataset_mask_info["path"]
            mask_suffix = dataset_mask_info["suffix"]

            dir_path = results_info["path"]
            file_prefix = results_info.get("prefix", "")
            file_suffix = results_info["suffix"]

            if not os.path.exists(dir_path):
                total_msgs.append(f"{dir_path} 不存在")
                continue
            if not os.path.isdir(dir_path):
                total_msgs.append(f"{dir_path} 不是正常的文件夹路径")
                continue

            pred_names = [
                strip_affixes(name, file_prefix, file_suffix)
                for name in os.listdir(dir_path)
                if name.startswith(file_prefix) and name.endswith(file_suffix)
            ]
            if not pred_names:
                total_msgs.append(f"{dir_path} 中不包含前缀为 {file_prefix} 且后缀为 {file_suffix} 的文件")
                continue

            if not os.path.isdir(mask_path):
                total_msgs.append(f"{mask_path} 不存在或不是正常的文件夹路径")
                continue
            # Ground-truth masks carry no method prefix: strip only the mask suffix.
            mask_names = [
                strip_affixes(name, "", mask_suffix)
                for name in os.listdir(mask_path)
                if name.endswith(mask_suffix)
            ]

            # Debug output: show exactly which names are being compared.
            print(f"Prefix: {file_prefix}")
            print(f"Suffix: {file_suffix}")
            print(f"Prediction names in {dir_path}: {sorted(pred_names)}")
            print(f"Ground truth mask names in {mask_path}: {sorted(mask_names)}")

            pred_names_set = set(pred_names)
            mask_names_set = set(mask_names)
            intersection_names = pred_names_set.intersection(mask_names_set)
            if not intersection_names:
                total_msgs.append(f"{dir_path} 中数据名字与真值 {mask_path} 不匹配")
            elif len(intersection_names) != len(mask_names_set):
                difference_names = mask_names_set.difference(pred_names_set)
                total_msgs.append(
                    f"{dir_path} 中数据({len(pred_names_set)})与真值({len(mask_names_set)})不一致: {difference_names}"
                )
    return total_msgs


def main(argv=None):
    """Check each (method json, dataset json) pair given on the command line."""
    parser = build_parser()
    args = parser.parse_args(argv)
    if len(args.method_jsons) != len(args.dataset_jsons):
        # zip() below would otherwise silently ignore the unpaired files.
        parser.error("--method-jsons and --dataset-jsons must list the same number of files")

    for method_json, dataset_json in zip(args.method_jsons, args.dataset_jsons):
        methods_info = load_ordered_json(method_json)
        datasets_info = load_ordered_json(dataset_json)
        total_msgs = check_pair(methods_info, datasets_info)
        if total_msgs:
            print(*total_msgs, sep="\n")
        else:
            print(f"{method_json} & {dataset_json} 基本正常")


if __name__ == "__main__":
    main()
然后是 `examples/config_method_json_example.json`:需要删除这个 JSON 文件里原有的所有方法(即下面这些内容),否则会因为 PASCAL-S 数据集而报错。
"Method1": {
"PASCAL-S": {
"path": "Path_Of_Method1/PASCAL-S/DGRL",
"prefix": "some_method_prefix",
"suffix": ".png"
},
"ECSSD": {
"path": "Path_Of_Method1/ECSSD/DGRL",
"prefix": "some_method_prefix",
"suffix": ".png"
},
"HKU-IS": {
"path": "Path_Of_Method1/HKU-IS/DGRL",
"prefix": "some_method_prefix",
"suffix": ".png"
},
"DUT-OMRON": {
"path": "Path_Of_Method1/DUT-OMRON/DGRL",
"prefix": "some_method_prefix",
"suffix": ".png"
},
"DUTS-TE": {
"path": "Path_Of_Method1/DUTS-TE/DGRL",
"suffix": ".png"
}
},
"Method2": {
"PASCAL-S": {
"path": "Path_Of_Method2/pascal",
"prefix": "pascal_",
"suffix": ".png"
},
"ECSSD": {
"path": "Path_Of_Method2/ecssd",
"prefix": "ecssd_",
"suffix": ".png"
},
"HKU-IS": {
"path": "Path_Of_Method2/hku",
"prefix": "hku_",
"suffix": ".png"
},
"DUT-OMRON": {
"path": "Path_Of_Method2/duto",
"prefix": "duto_",
"suffix": ".png"
},
"DUTS-TE": {
"path": "Path_Of_Method2/dut_te",
"prefix": "dut_te_",
"suffix": ".png"
}
},
"Method3": {
"PASCAL-S": {
"path": "Path_Of_Method3/pascal",
"prefix": "pascal_",
"suffix": "_fused_sod.png"
},
"ECSSD": {
"path": "Path_Of_Method3/ecssd",
"prefix": "ecssd_",
"suffix": "_fused_sod.png"
},
"HKU-IS": {
"path": "Path_Of_Method3/hku",
"prefix": "hku_",
"suffix": "_fused_sod.png"
},
"DUT-OMRON": {
"path": "Path_Of_Method3/duto",
"prefix": "duto_",
"suffix": "_fused_sod.png"
},
"DUTS-TE": {
"path": "Path_Of_Method3/dut_te",
"prefix": "dut_te_",
"suffix": "_fused_sod.png"
}
}
然后,在最新的 PySODEvalToolkit GitHub 项目中,作者删除了原先的 `check_files.py` 和 `check_nyp.py`。我觉得有必要把它们添加回来,因为这两个文件可以帮助调试、找出出错的原因。
这是 `check_files.py` 的代码:
"""Report files that exist in only one of two directories (e.g. GT vs. Mask)."""
import argparse
import os

# Defaults reproduce the original hard-coded COD10K test-set check.
DEFAULT_GT_PATH = "Data/COD10K/Test/GT"
DEFAULT_MASK_PATH = "Data/COD10K/Test/Mask"
DEFAULT_PREFIX = "COD10K-CAM-"
DEFAULT_SUFFIX = ".png"


def compare_dirs(gt_path, mask_path, prefix=DEFAULT_PREFIX, suffix=DEFAULT_SUFFIX):
    """Compare the file names of two directories.

    Only names that start with ``prefix`` and end with ``suffix`` are considered.

    Returns:
        ``(missing_in_mask, missing_in_gt)``: sorted lists of names found only in
        ``gt_path`` and only in ``mask_path`` respectively.
    """

    def _collect(path):
        # Filter by prefix/suffix so unrelated files in the folder are ignored.
        return {f for f in os.listdir(path) if f.startswith(prefix) and f.endswith(suffix)}

    gt_files = _collect(gt_path)
    mask_files = _collect(mask_path)
    return sorted(gt_files - mask_files), sorted(mask_files - gt_files)


def main(argv=None):
    """Parse the command line, compare the two directories and print the result."""
    parser = argparse.ArgumentParser(description="Check that two directories contain the same files.")
    parser.add_argument("--gt-path", default=DEFAULT_GT_PATH, help="Ground-truth directory.")
    parser.add_argument("--mask-path", default=DEFAULT_MASK_PATH, help="Mask directory.")
    parser.add_argument("--prefix", default=DEFAULT_PREFIX, help="Only check files with this prefix.")
    parser.add_argument("--suffix", default=DEFAULT_SUFFIX, help="Only check files with this suffix.")
    args = parser.parse_args(argv)

    for path in (args.gt_path, args.mask_path):
        if not os.path.isdir(path):
            parser.error(f"{path} is not an existing directory")

    missing_in_mask, missing_in_gt = compare_dirs(
        args.gt_path, args.mask_path, prefix=args.prefix, suffix=args.suffix
    )

    if missing_in_mask:
        print("The following ground truth files are missing in mask directory:")
        for f in missing_in_mask:
            print(f)

    if missing_in_gt:
        print("The following mask files are missing in ground truth directory:")
        for f in missing_in_gt:
            print(f)

    if not missing_in_mask and not missing_in_gt:
        print("All files match correctly!")


if __name__ == "__main__":
    main()
这是 `check_nyp.py` 的代码:
"""List the datasets and methods stored in a curves ``.npy`` file."""
import argparse

import numpy as np

# Default matches the --curves-npy path used for eval.py in this setup.
DEFAULT_CURVES_NPY = "output/curves.npy"


def load_curves(path=DEFAULT_CURVES_NPY):
    """Load the curves dictionary stored at ``path``.

    The file holds a dict wrapped in a 0-d object array (what ``np.save`` writes
    for a dict), hence ``allow_pickle=True`` and ``.item()``.
    Unpickling can execute arbitrary code: only load files you produced yourself.
    """
    return np.load(path, allow_pickle=True).item()


def list_methods(curves):
    """Return ``{dataset name: [method names]}`` in the stored order."""
    return {dataset_name: list(method_infos.keys()) for dataset_name, method_infos in curves.items()}


def main(argv=None):
    """Print the dataset and method names contained in a curves file."""
    parser = argparse.ArgumentParser(description="List the datasets and methods stored in a curves .npy file.")
    parser.add_argument(
        "npy_path", nargs="?", default=DEFAULT_CURVES_NPY, help="Path of the curves .npy file."
    )
    args = parser.parse_args(argv)

    # Only the names are printed, not the curve values themselves.
    for dataset_name, method_names in list_methods(load_curves(args.npy_path)).items():
        print(f"Dataset: {dataset_name}")
        for method_name in method_names:
            print(f" Method: {method_name}")


if __name__ == "__main__":
    main()
然后,我觉得有必要单独编写两个 bash 脚本,分别执行 eval 和 plot,因为这两个 Python 命令的参数太多、太长了。
在 `PySODEvalToolkit/` 目录下创建一个 `run_eval.sh` 文件:
#!/bin/bash
# Run PySODEvalToolkit's eval.py with a fixed set of arguments.
# Edit the variables below instead of typing the long command by hand.
# Any extra command-line arguments are forwarded to eval.py unchanged.
set -eu

# Resolve the relative paths below against the directory holding this script
# (it is meant to live in PySODEvalToolkit/, next to eval.py).
cd "$(dirname "$0")"

# Define parameters
DATASET_JSON="examples/config_dataset_json_example.json"
METHOD_JSON="examples/config_method_json_example.json"
METRIC_NPY="output/metrics.npy"
CURVES_NPY="output/curves.npy"
RECORD_TXT="output/results.txt"
RECORD_XLSX="output/results.xlsx"
# Multi-valued options are arrays so every entry becomes a separate argument.
METRIC_NAMES=(sm wfm mae fmeasure em precision recall)
INCLUDE_DATASETS=(CAMO CHAMELEON COD10K NC4K)
INCLUDE_METHODS=(MAMBA_UNET_2024)

# Make sure the output folders exist before eval.py writes into them.
mkdir -p "$(dirname "$METRIC_NPY")" "$(dirname "$CURVES_NPY")" \
    "$(dirname "$RECORD_TXT")" "$(dirname "$RECORD_XLSX")"

# Run the Python script
python eval.py \
    --dataset-json "$DATASET_JSON" \
    --method-json "$METHOD_JSON" \
    --metric-npy "$METRIC_NPY" \
    --curves-npy "$CURVES_NPY" \
    --record-txt "$RECORD_TXT" \
    --record-xlsx "$RECORD_XLSX" \
    --metric-names "${METRIC_NAMES[@]}" \
    --include-datasets "${INCLUDE_DATASETS[@]}" \
    --include-methods "${INCLUDE_METHODS[@]}" \
    "$@"
然后执行如下命令为其添加可执行权限:
chmod +x run_eval.sh
之后就可以直接运行了:
./run_eval.sh
同理,创建一个 `run_plot.sh` 文件(同样先执行 `chmod +x run_plot.sh`,再用 `./run_plot.sh` 运行),其代码如下:
#!/bin/bash
# Run PySODEvalToolkit's plot.py with a fixed set of arguments.
# Edit the variables below instead of typing the long command by hand.
# Any extra command-line arguments are forwarded to plot.py unchanged.
set -eu

# Resolve the relative paths below against the directory holding this script
# (it is meant to live in PySODEvalToolkit/, next to plot.py).
cd "$(dirname "$0")"

# Define parameters
# Switch layouts by swapping which group is commented out; keep STYLE_CFG,
# NUM_ROWS and ROW_TAG consistent with each other.
# STYLE_CFG="examples/single_row_style.yml"
# NUM_ROWS=1
# ROW_TAG="single_row"
STYLE_CFG="examples/two_row_style.yml"
NUM_ROWS=2
ROW_TAG="two_row"
# Multi-valued options are arrays so every entry becomes a separate argument.
CURVES_NPYS=(output/curves.npy)
OUR_METHODS=(MAMBA_UNET_2024)
# Curve type: em, fm or pr.
MODE="fm"
# Derived from MODE and ROW_TAG so changing either never reuses a stale name.
SAVE_NAME="output/simple_curve_${MODE}_${ROW_TAG}"
# NOTE(review): this alias file is named for RGB-D datasets; confirm it covers
# the dataset/method names stored in the curves file you plot.
ALIAS_YAML="examples/rgbd_aliases.yaml"

# Make sure the output folder exists before plot.py saves the figure.
mkdir -p "$(dirname "$SAVE_NAME")"

# Run the Python script
python plot.py \
    --style-cfg "$STYLE_CFG" \
    --num-rows "$NUM_ROWS" \
    --curves-npys "${CURVES_NPYS[@]}" \
    --our-methods "${OUR_METHODS[@]}" \
    --mode "$MODE" \
    --save-name "$SAVE_NAME" \
    --alias-yaml "$ALIAS_YAML" \
    "$@"
`examples/single_row_style.yml` 是单行效果,`examples/two_row_style.yml` 是双行效果(切换样式时要同时修改对应的 `NUM_ROWS`);`MODE` 可以选择 `em`、`fm`、`pr` 这三种曲线。