Problem description
When I passed a callback to fit(), I got TypeError: set_model() missing 1 required positional argument: 'model'
My code:
import time
import keras

# Records the duration of each epoch and of the whole training run
class TimeHistory(keras.callbacks.Callback):
    def on_train_begin(self, logs=None):
        self.times = []
        self.totaltime = time.time()

    def on_train_end(self, logs=None):
        self.totaltime = time.time() - self.totaltime

    def on_epoch_begin(self, epoch, logs=None):
        self.epoch_time_start = time.time()

    def on_epoch_end(self, epoch, logs=None):
        self.times.append(time.time() - self.epoch_time_start)

time_callback = TimeHistory  # assigns the class itself, not an instance
history = model.fit(train_img, train_labels, epochs=10, batch_size=128, callbacks=[time_callback])
print(time_callback.times)
print(time_callback.totaltime)
Running the code above raises TypeError: set_model() missing 1 required positional argument: 'model'. The full error is:
---------------------------------------------------------------------------
TypeError Traceback (most recent call last)
<ipython-input-35-bf08d059ed19> in <module>
33
34 time_callback = TimeHistory
---> 35 history = model.fit(train_img, train_labels, epochs=10, batch_size=128, callbacks=[time_callback])
36 # history = model.fit(train_img, train_labels, epochs=10, batch_size=128)
37
~/.pyenv/versions/3.7.0/envs/mypython3.7/lib/python3.7/site-packages/keras/engine/training.py in fit(self, x, y, batch_size, epochs, verbose, callbacks, validation_split, validation_data, shuffle, class_weight, sample_weight, initial_epoch, steps_per_epoch, validation_steps, **kwargs)
1037 initial_epoch=initial_epoch,
1038 steps_per_epoch=steps_per_epoch,
-> 1039 validation_steps=validation_steps)
1040
1041 def evaluate(self, x=None, y=None,
~/.pyenv/versions/3.7.0/envs/mypython3.7/lib/python3.7/site-packages/keras/engine/training_arrays.py in fit_loop(model, f, ins, out_labels, batch_size, epochs, verbose, callbacks, val_f, val_ins, shuffle, callback_metrics, initial_epoch, steps_per_epoch, validation_steps)
115 callback_model = model
116
--> 117 callbacks.set_model(callback_model)
118 callbacks.set_params({
119 'batch_size': batch_size,
~/.pyenv/versions/3.7.0/envs/mypython3.7/lib/python3.7/site-packages/keras/callbacks.py in set_model(self, model)
52 def set_model(self, model):
53 for callback in self.callbacks:
---> 54 callback.set_model(model)
55
56 def on_epoch_begin(self, epoch, logs=None):
TypeError: set_model() missing 1 required positional argument: 'model'
Solution
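The error occurs because the callbacks argument of fit() received a class rather than an object: you are passing the class instead of an object of that class.
Here I wrote time_callback = TimeHistory, so time_callback is the class itself, not an instance. When Keras runs callback.set_model(model) on it, the call goes to the plain function TimeHistory.set_model, the model argument is consumed as self, and the model parameter is left unfilled.
Passing an instance fixes the problem. In this case that means writing time_callback = TimeHistory(); nothing else needs to change.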
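To see why the message complains about a missing 'model' argument, the sketch below reproduces the mechanism without Keras. `Callback` is a hypothetical stand-in for keras.callbacks.Callback that models only set_model, and `set_model_on_all` is a hypothetical helper mimicking the loop shown in the traceback; neither is the real Keras code. On newer Python versions the error text may be prefixed with the qualified name (e.g. `Callback.set_model()`).

```python
class Callback:
    """Hypothetical stand-in for keras.callbacks.Callback (only set_model is modeled)."""

    def set_model(self, model):
        self.model = model


class TimeHistory(Callback):
    pass


def set_model_on_all(callbacks, model):
    """Hypothetical helper mimicking the traceback loop: callback.set_model(model)."""
    for callback in callbacks:
        callback.set_model(model)


model = object()  # placeholder for a real Keras model

# Wrong: the list holds the class. TimeHistory.set_model is a plain function,
# so `model` is bound to `self` and the `model` parameter is missing.
try:
    set_model_on_all([TimeHistory], model)
except TypeError as exc:
    print(exc)

# Right: the list holds an instance, so `self` is bound automatically.
time_callback = TimeHistory()
set_model_on_all([time_callback], model)
print(time_callback.model is model)
```

The only difference between the two calls is the pair of parentheses after TimeHistory, which is exactly the fix described above.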
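Other
In Python, instantiating a class always requires parentheses.
When defining a class, the parentheses are optional; if the class inherits from another class, the parentheses (containing the base class) are required.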
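A short illustration of these rules in plain Python (the class names A, B and C are arbitrary examples): a definition with or without empty parentheses produces the same kind of class, inheritance needs the base class inside parentheses, and only calling the class with () creates an instance.

```python
class A:          # no parentheses
    pass


class B():        # empty parentheses: equivalent to `class B:`
    pass


class C(A):       # inheritance: the base class goes inside the parentheses
    pass


a_class = A       # the class object itself; no instance is created
a_obj = A()       # calling the class creates an instance

print(isinstance(a_class, type))    # A is a class
print(isinstance(a_obj, A))         # a_obj is an instance of A
print(isinstance(a_obj, type))      # an instance is not a class
print(issubclass(C, A))             # C inherits from A
print(A.__bases__ == B.__bases__)   # both implicitly inherit from object
```

The a_class / a_obj pair mirrors time_callback = TimeHistory versus time_callback = TimeHistory() in the Keras example.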