我有一个简单的memoizer,我用来节省昂贵的网络电话的时间.粗略地说,我的代码看起来像这样:
# mem.py import functools import time def memoize(fn): """ Decorate a function so that it results are cached in memory. >>> import random >>> random.seed(0) >>> f = lambda x: random.randint(0,10) >>> [f(1) for _ in range(10)] [9,8,4,2,5,3,6] >>> [f(2) for _ in range(10)] [9,6,10,9] >>> g = memoize(f) >>> [g(1) for _ in range(10)] [3,3] >>> [g(2) for _ in range(10)] [8,8] """ cache = {} @functools.wraps(fn) def wrapped(*args,**kwargs): key = args,tuple(sorted(kwargs)) try: return cache[key] except KeyError: cache[key] = fn(*args,**kwargs) return cache[key] return wrapped def network_call(user_id): time.sleep(1) return 1 @memoize def search(user_id): response = network_call(user_id) # do stuff to response return response
我对这段代码进行了测试,在这里我模拟了network_call()的不同返回值,以确保我在search()中做的一些修改按预期工作.
import mock import mem @mock.patch('mem.network_call') def test_search(mock_network_call): mock_network_call.return_value = 2 assert mem.search(1) == 2 @mock.patch('mem.network_call') def test_search_2(mock_network_call): mock_network_call.return_value = 3 assert mem.search(1) == 3
但是,当我运行这些测试时,我得到了一个失败,因为search()返回一个缓存的结果.
CAESAR-BAUTISTA:~ caesarbautista$py.test test_mem.py ============================= test session starts ============================== platform darwin -- Python 2.7.8 -- py-1.4.26 -- pytest-2.6.4 collected 2 items test_mem.py .F =================================== FAILURES =================================== ________________________________ test_search_2 _________________________________ args = (<MagicMock name='network_call' id='4438999312'>,),keywargs = {} extra_args = [<MagicMock name='network_call' id='4438999312'>] entered_patchers = [<mock._patch object at 0x108913dd0>] exc_info = (<class '_pytest.assertion.reinterpret.AssertionError'>,AssertionError(u'assert 2 == 3\n + where 2 = <function search at 0x10893f848>(1)\n + where <function search at 0x10893f848> = mem.search',<traceback object at 0x1089502d8>) patching = <mock._patch object at 0x108913dd0> arg = <MagicMock name='network_call' id='4438999312'> @wraps(func) def patched(*args,**keywargs): # don't use a with here (backwards compatability with Python 2.4) extra_args = [] entered_patchers = [] # can't use try...except...finally because of Python 2.4 # compatibility exc_info = tuple() try: try: for patching in patched.patchings: arg = patching.__enter__() entered_patchers.append(patching) if patching.attribute_name is not None: keywargs.update(arg) elif patching.new is DEFAULT: extra_args.append(arg) args += tuple(extra_args) > return func(*args,**keywargs) /opt/Boxen/homebrew/lib/python2.7/site-packages/mock.py:1201: _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ mock_network_call = <MagicMock name='network_call' id='4438999312'> @mock.patch('mem.network_call') def test_search_2(mock_network_call): mock_network_call.return_value = 3 > assert mem.search(1) == 3 E assert 2 == 3 E + where 2 = <function search at 0x10893f848>(1) E + where <function search at 0x10893f848> = mem.search test_mem.py:15: AssertionError ====================== 1 Failed,1 passed in 0.03 seconds ======================
有没有办法测试记忆功能?我考虑了一些替代方案,但它们都有缺点.
一种解决方案是模拟memoize().我不愿意这样做,因为它泄漏了测试的实现细节.从理论上讲,我应该能够在没有系统其他部分的情况下记忆和取消默认功能,包括测试,从功能角度注意.
另一种解决方案是重写代码以公开修饰函数.也就是说,我可以这样做:
def _search(user_id): return network_call(user_id) search = memoize(_search)
然而,这遇到了与上面相同的问题,尽管它可能更糟,因为它不适用于递归函数.
解决方法
是否真的需要在功能级别定义您的memoization?
这有效地使得memoized数据成为一个全局变量(就像函数一样,它的共享范围).
顺便说一下,这就是你在测试时遇到困难的原因!
那么,如何将它包装成一个对象呢?
import functools import time def memoize(meth): @functools.wraps(meth) def wrapped(self,*args,**kwargs): # Prepare and get reference to cache attr = "_memo_{0}".format(meth.__name__) if not hasattr(self,attr): setattr(self,attr,{}) cache = getattr(self,attr) # Actual caching key = args,tuple(sorted(kwargs)) try: return cache[key] except KeyError: cache[key] = meth(self,**kwargs) return cache[key] return wrapped def network_call(user_id): print "Was called with: %s" % user_id return 1 class NetworkEngine(object): @memoize def search(self,user_id): return network_call(user_id) if __name__ == "__main__": e = NetworkEngine() for v in [1,1,2]: e.search(v) NetworkEngine().search(1)
产量:
Was called with: 1 Was called with: 2 Was called with: 1
换句话说,NetworkEngine的每个实例都有自己的缓存.只需重用相同的一个来共享一个缓存,或者实例化一个新缓存以获得一个新的缓存.
在您的测试代码中,您将使用:
@mock.patch('mem.network_call') def test_search(mock_network_call): mock_network_call.return_value = 2 assert mem.NetworkEngine().search(1) == 2