
Keras: passing a callback to fit raises "TypeError: set_model() missing 1 required positional argument: 'model'"

Problem description

Using a callback in fit raises TypeError: set_model() missing 1 required positional argument: 'model'

My code:

import time

import keras


class TimeHistory(keras.callbacks.Callback):
    """Records the duration of each epoch and of the whole training run."""

    def on_train_begin(self, logs=None):
        self.times = []
        self.totaltime = time.time()

    def on_train_end(self, logs=None):
        # Turn the start timestamp into the total elapsed time
        self.totaltime = time.time() - self.totaltime

    def on_epoch_begin(self, epoch, logs=None):
        self.epoch_time_start = time.time()

    def on_epoch_end(self, epoch, logs=None):
        self.times.append(time.time() - self.epoch_time_start)


# Note: no parentheses here (this is the line that leads to the error below)
time_callback = TimeHistory
history = model.fit(train_img, train_labels, epochs=10, batch_size=128, callbacks=[time_callback])

print(time_callback.times)
print(time_callback.totaltime)

Running the code above raises TypeError: set_model() missing 1 required positional argument: 'model'. The full traceback is:

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
<ipython-input-35-bf08d059ed19> in <module>
     33 
     34 time_callback = TimeHistory
---> 35 history = model.fit(train_img, train_labels, epochs=10, batch_size=128, callbacks=[time_callback])
     36 # history = model.fit(train_img, train_labels, epochs=10, batch_size=128)
     37 

~/.pyenv/versions/3.7.0/envs/mypython3.7/lib/python3.7/site-packages/keras/engine/training.py in fit(self, x, y, batch_size, epochs, verbose, callbacks, validation_split, validation_data, shuffle, class_weight, sample_weight, initial_epoch, steps_per_epoch, validation_steps, **kwargs)
   1037                                         initial_epoch=initial_epoch,
   1038                                         steps_per_epoch=steps_per_epoch,
-> 1039                                         validation_steps=validation_steps)
   1040 
   1041     def evaluate(self, x=None, y=None,

~/.pyenv/versions/3.7.0/envs/mypython3.7/lib/python3.7/site-packages/keras/engine/training_arrays.py in fit_loop(model, f, ins, out_labels, batch_size, epochs, verbose, callbacks, val_f, val_ins, shuffle, callback_metrics, initial_epoch, steps_per_epoch, validation_steps)
    115         callback_model = model
    116 
--> 117     callbacks.set_model(callback_model)
    118     callbacks.set_params({
    119         'batch_size': batch_size,

~/.pyenv/versions/3.7.0/envs/mypython3.7/lib/python3.7/site-packages/keras/callbacks.py in set_model(self, model)
     52     def set_model(self, model):
     53         for callback in self.callbacks:
---> 54             callback.set_model(model)
     55 
     56     def on_epoch_begin(self, epoch, logs=None):

TypeError: set_model() missing 1 required positional argument: 'model'

Solution

This error occurs because the callbacks list passed to fit contains a class, not an object: you are passing the class instead of an object of that class.

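In my code, time_callback = TimeHistory makes time_callback the class itself rather than an instance of it. Keras then calls callback.set_model(model) on the class, so model is bound to the self parameter and the model parameter is left unfilled.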
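The mechanism can be reproduced without Keras. The following is a minimal sketch, assuming a stand-in `Callback` base class that only mimics the `set_model(self, model)` signature seen in the traceback; it is not the real Keras class, and `model` is just a placeholder object.

```python
class Callback:
    """Stand-in for keras.callbacks.Callback (an assumption, not the real class)."""

    def set_model(self, model):
        self.model = model


class TimeHistory(Callback):
    pass


model = object()  # placeholder for a real Keras model

# Calling through the class: the single argument fills `self`,
# so the `model` parameter is left without a value.
try:
    TimeHistory.set_model(model)
    error_message = None
except TypeError as exc:
    error_message = str(exc)

print(error_message)

# Calling through an instance: `self` is bound automatically,
# so the argument fills `model` as intended.
callback = TimeHistory()
callback.set_model(model)
```

The exact wording of the message varies slightly between Python versions, but it always reports that the `model` argument is missing.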

So passing an instance of the class fixes the problem. Here that means writing time_callback = TimeHistory(); nothing else needs to change.
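To catch this mistake before training starts, one option is a small guard that rejects classes in the callbacks list. This is only a sketch: `check_callbacks` is a hypothetical helper, not part of Keras, and the `TimeHistory` class here is an empty stand-in for the callback class above.

```python
import inspect


def check_callbacks(callbacks):
    """Hypothetical helper: raise early if an entry is a class, not an instance."""
    for cb in callbacks:
        if inspect.isclass(cb):
            raise TypeError(
                f"{cb.__name__} is a class; pass an instance instead, "
                f"e.g. {cb.__name__}()"
            )
    return callbacks


class TimeHistory:
    """Empty stand-in for the callback class defined earlier."""


# An instance passes the check and the list is returned unchanged.
checked = check_callbacks([TimeHistory()])

# The bare class is rejected with a more descriptive message.
try:
    check_callbacks([TimeHistory])
    message = None
except TypeError as exc:
    message = str(exc)

print(message)
```

With such a helper, the call would become callbacks=check_callbacks([time_callback]), which fails with a clearer message than the one Keras produces.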

Other notes

In Python, instantiating a class always requires parentheses.

When defining a class, the parentheses are optional. They are required only when the class inherits from another class.
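A short illustration of these two rules, using made-up class names (`Plain`, `AlsoPlain`, `Child`) that do not come from the original code:

```python
class Plain:  # definition without parentheses
    pass


class AlsoPlain():  # definition with empty parentheses, equivalent
    pass


class Child(Plain):  # inheritance: the base class goes inside the parentheses
    pass


instance = Plain()  # parentheses: creates a new object
alias = Plain       # no parentheses: just another name for the class itself

print(type(instance).__name__)
print(alias is Plain)
```

Here `instance` is an object of `Plain`, while `alias` refers to the class itself, which is exactly the situation that produced the error above.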

Reference: http://bit.ly/2SsWoaz
