代码实现
直接套用基尼系数公式 $\mathrm{Gini}(D) = 1 - \sum_{k} p_k^2$，划分后的基尼系数为各子节点基尼系数按样本数占比加权求和：
def gini(a: int, b: int) -> float:
    """Return the Gini impurity of a node holding ``a`` and ``b`` samples of two classes.

    Computes ``1 - p_a**2 - p_b**2`` where ``p_a = a / (a + b)`` and
    ``p_b = b / (a + b)``.

    Args:
        a: Number of samples of the first class in the node.
        b: Number of samples of the second class in the node.

    Returns:
        The impurity in ``[0, 0.5]``. An empty node (``a + b == 0``) contains
        no mixture of classes, so it yields ``0.0`` instead of raising
        ``ZeroDivisionError``.
    """
    total = a + b
    if total == 0:
        return 0.0
    return 1 - (a / total) ** 2 - (b / total) ** 2
def gini_total(a: int, b: int, c: int, d: int) -> float:
    """Return the weighted Gini impurity of a binary split.

    The left child holds ``a`` and ``b`` samples of the two classes, the right
    child holds ``c`` and ``d``. Each child's impurity is weighted by its share
    of all ``a + b + c + d`` samples.

    Returns:
        The weighted impurity. An empty child contributes nothing, and a split
        with no samples at all yields ``0.0`` rather than raising
        ``ZeroDivisionError``.
    """

    def _node_gini(first: int, second: int) -> float:
        # Kept local on purpose: the module-level name ``gini`` is rebound
        # later in this file to a list-based version with a different
        # signature, so looking it up at call time is not safe.
        size = first + second
        if size == 0:
            return 0.0
        return 1 - (first / size) ** 2 - (second / size) ** 2

    total = a + b + c + d
    if total == 0:
        return 0.0
    return (a + b) / total * _node_gini(a, b) + (c + d) / total * _node_gini(c, d)
# Impurity of each child node on its own, then of the weighted split.
for count_a, count_b in ((13, 98), (24, 29)):
    print(gini(count_a, count_b))
print(gini_total(13, 98, 24, 29))
NumPy 实现
import numpy as np
def gini(data_list):
    """Return the weighted Gini impurity of a split into several child nodes.

    Args:
        data_list: 2-D array-like of non-negative class counts. Each row is
            one child node and each column one class, e.g.
            ``[[13, 98], [24, 29]]``. Any number of classes per row is
            supported; with two columns the result matches the two-class
            formula.

    Returns:
        ``sum over rows of row_total / grand_total * (1 - sum_k p_k**2)``,
        where ``p_k`` is the share of class ``k`` within that row, as a
        Python ``float``. Empty input, or input whose counts are all zero,
        yields ``0.0``. Rows with no samples carry zero weight.

    Raises:
        ValueError: If ``data_list`` is non-empty but not two-dimensional.
    """
    counts = np.asarray(data_list, dtype=float)
    if counts.size == 0:
        return 0.0
    if counts.ndim != 2:
        raise ValueError("data_list must be a 2-D table of class counts")

    row_totals = counts.sum(axis=1)
    grand_total = row_totals.sum()
    if grand_total == 0:
        return 0.0

    # Per-row class proportions. Empty rows get zeros instead of NaN; their
    # weight below is 0, so they do not affect the result.
    column_totals = row_totals[:, np.newaxis]
    proportions = np.divide(
        counts,
        column_totals,
        out=np.zeros_like(counts),
        where=column_totals != 0,
    )
    impurities = 1 - np.sum(proportions ** 2, axis=1)
    return float(np.sum(impurities * row_totals / grand_total))
测试结果