Bootstrap

keras中计算precision和recall的一点思考

需要对模型的precision和recall进行衡量,希望使用metrics在训练的时候将这两个指标体现出来。    
keras2.0中已经删除了F1score、precision和recall的计算。按照

https://github.com/keras-team/keras/issues/5400) 的说法,可以自己编写batch_wise的precision和recall。自定义了如下的代码。

def precision(y_true,y_pred,n=0):#精准率
    threshold = K.constant(n)
    true_positives = K.sum(K.cast(K.greater(y_true,threshold)&K.greater(y_pred,threshold),tf.float32))
    possible_positives = K.sum(K.cast(K.greater(y_pred,threshold),tf.float32))
    precision = true_positives / (possible_positives + K.epsilon())
    return precision
def recall(y_true,y_pred,n=0):#召回率
    threshold = K.constant(n)
    true_positives = K.sum(K.cast(K.greater(y_true,threshold)&K.greater(y_pred,threshold),tf.float32))
    possible_positives = K.sum(K.cast(K.greater(y_true,threshold),tf.float32))
    recall = true_positives / (possible_positives + K.epsilon())
    return recall
代码设置了个阈值n,精准率是预测值中大于n的数量中,真实值也大于n的比例。召回率是真实值大于n的数量中,预测值也大于n的比例。
然而,越测试越感觉不对,batch_wise的结果是我们想要的吗?问题出在一个batch中TP都为0的情况下,此时precision的计算结果为0;而分母并非全都非0的,为0的数字不该记到分母的总和里面去。当batch很小,或者准确率很低的情况下问题就会很突出。显示的结果不能反映真实的情况。
所以直接使用Batch_wise的指标是不行的,原因是里面的触发。但是如果换个思路呢?不直接使用precision、recall这类指标,而是取分解的tp、tn、fp、fn,这几个指标可以在batch里求平均,再在整个epoch上平均就是它本身的平均值,因为里面的除法的分母都是batchsize。而precision=tp/(tp+fp),recall =tp/(tp+tn),total_accuracy =(tp+fn)/(tp+fn+tn+fp)
重新定义了四个函数:
def tp(y_true,y_pred,n=0):#提供真实值大于n且预测值大于n的数量的平均值
    threshold = K.constant(n)
    true_positives = K.mean(K.cast(K.greater(y_true,threshold)&K.greater(y_pred,threshold),tf.float32))
    return true_positives
def tn(y_true,y_pred,n=0):#提供真实值大于n且预测值小于等于n的数量的平均值
    threshold = K.constant(n)
    true_negatives = K.mean(K.cast(K.greater(y_true,threshold)&K.less_equal(y_pred,threshold),tf.float32))
    return true_negatives
def fp(y_true,y_pred,n=0):#提供真实值小于等于n且预测值大于n的数量的平均值
    threshold = K.constant(n)
    false_positives = K.mean(K.cast(K.less_equal(y_true,threshold)&K.greater(y_pred,threshold),tf.float32))
    return false_positives
def fn(y_true,y_pred,n=0):#提供真实值小于等于n且预测值小于等于n的数量的平均值
    threshold = K.constant(n)
    false_negatives = K.mean(K.cast(K.less_equal(y_true,threshold)&K.less_equal(y_pred,threshold),tf.float32))
    return false_negatives

然后定义了一个callback类:

class myCallback(tf.keras.callbacks.Callback): 
    def on_epoch_end(self, epoch, logs={}): 
        logs = logs or {}
        precision = logs.get('tp')/(logs.get('tp')+logs.get('fp')+K.epsilon())
        recall = logs.get('tp')/(logs.get('tp')+logs.get('tn')+K.epsilon())
        total_accuracy = (logs.get('tp')+logs.get('fn'))/(logs.get('tp')+logs.get('fp')+logs.get('tn')+logs.get('fn')+K.epsilon())
        val_precision = logs.get('val_tp')/(logs.get('val_tp')+logs.get('val_fp')+K.epsilon())
        val_recall = logs.get('val_tp')/(logs.get('val_tp')+logs.get('val_tn')+K.epsilon())
        val_total_accuracy = (logs.get('val_tp')+logs.get('val_fn'))/(logs.get('val_tp')+\
                                                        logs.get('val_fp')+logs.get('val_tn')+logs.get('val_fn')+K.epsilon())
        logs['total_accuracy'] = total_accuracy
        logs['recall'] = recall
        logs['precision'] = precision
        logs['val_total_accuracy'] = val_total_accuracy
        logs['val_recall'] = val_recall
        logs['val_precision'] = val_precision
        print(" — precision: %0.4f — recall: %0.4f — total_accuracy: %0.4f — val_precision: %0.4f — val_recall: %0.4f — val_total_accuracy: %0.4f "\
              % (precision,recall,total_accuracy,val_precision, val_recall,val_total_accuracy))

在fit的时候加上callback对象,注意myCallback的后面一定要有括号,否则会出莫名其妙的错误,其实就是要把实例传过去,而不是类名:

with tf.device('/GPU:0'):
    H = model.fit_generator(train_gen,steps_per_epoch=5,
                            validation_data=test_gen,
                            validation_steps=5,
                            epochs=2,callbacks=[myCallback()],initial_epoch=0)

结果就可以在每个epoch显示出指标的情况:

Epoch 1/2
5/5 [==============================] - ETA: 0s - loss: 141.5960 - mae: 9.8554 - tp: 0.3000 - tn: 0.2000 - fp: 0.3000 - fn: 0.2000 - y_true_value: 0.0188 - y_pred_value: 2.8814     — precision: 0.5000 — recall: 0.6000 — total_accuracy: 0.5000 — val_precision: 0.8000 — val_recall: 1.0000 — val_total_accuracy: 0.8000 
5/5 [==============================] - 11s 2s/step - loss: 141.5960 - mae: 9.8554 - tp: 0.3000 - tn: 0.2000 - fp: 0.3000 - fn: 0.2000 - y_true_value: 0.0188 - y_pred_value: 2.8814 - val_loss: 262.5614 - val_mae: 9.4246 - val_tp: 0.8000 - val_tn: 0.0000e+00 - val_fp: 0.2000 - val_fn: 0.0000e+00 - val_y_true_value: 2.0131 - val_y_pred_value: 10.7343 - total_accuracy: 0.5000 - recall: 0.6000 - precision: 0.5000 - val_total_accuracy: 0.8000 - val_recall: 1.0000 - val_precision: 0.8000
Epoch 2/2
5/5 [==============================] - ETA: 0s - loss: 39.7712 - mae: 4.7150 - tp: 0.4000 - tn: 0.1000 - fp: 0.4000 - fn: 0.1000 - y_true_value: 0.0188 - y_pred_value: 4.2211       — precision: 0.5000 — recall: 0.8000 — total_accuracy: 0.5000 — val_precision: 0.0000 — val_recall: 0.0000 — val_total_accuracy: 0.3000 
5/5 [==============================] - 11s 2s/step - loss: 39.7712 - mae: 4.7150 - tp: 0.4000 - tn: 0.1000 - fp: 0.4000 - fn: 0.1000 - y_true_value: 0.0188 - y_pred_value: 4.2211 - val_loss: 29.4384 - val_mae: 4.1652 - val_tp: 0.0000e+00 - val_tn: 0.7000 - val_fp: 0.0000e+00 - val_fn: 0.3000 - val_y_true_value: 2.1435 - val_y_pred_value: -1.7780 - total_accuracy: 0.5000 - recall: 0.8000 - precision: 0.5000 - val_total_accuracy: 0.3000 - val_recall: 0.0000e+00 - val_precision: 0.0000e+00

把代码放到on_batch_end里,也可以跟随每一个batch显示指标进度。不过keras的显示机制好像挺复杂,怎么试也实现不了每一个batch重新显示,\r什么的结果都不对。所以就放弃了。
试着读keras的代码,但是这些代码怎么读不懂了,这还是python吗?有大神给解释一下吗?

@overload
    def update(self, **kwargs: _VT) -> None: ...
    if sys.version_info >= (3,):
        def keys(self) -> KeysView[_KT]: ...
        def values(self) -> ValuesView[_VT]: ...
        def items(self) -> ItemsView[_KT, _VT]: ...
    else:
        def iterkeys(self) -> Iterator[_KT]: ...
        def itervalues(self) -> Iterator[_VT]: ...
        def iteritems(self) -> Iterator[Tuple[_KT, _VT]]: ...
        def viewkeys(self) -> KeysView[_KT]: ...
        def viewvalues(self) -> ValuesView[_VT]: ...
        def viewitems(self) -> ItemsView[_KT, _VT]: ...
;