defbinary_search_e1(arr, t): l, r = 0, len(arr) - 1 while l <= r: m = (l + r) // 2# 取整 print(" current l: %d, m: %d, r: %d" % (l, m, r)) if arr[m] < t: l = m # 将搜索范围调整为 m 右侧 elif arr[m] > t: r = m # 将搜索范围调整为 m 左侧 else: return m # 找到 return -1# 没有找到
# 一个脚手架做简单测试 whileTrue: # 被查数组 arr = [] n, t = map(int, input().split()) # no assertion for inputs if -1 == n: break for i inrange(n): arr.append(10 * i) print("binary_search_e1: %d\n" % (binary_search_e1(arr, t)))
5 30 current l: 0, m: 2, r: 4 current l: 2, m: 3, r: 4 binary_search_e1: 3
5 40 current l: 0, m: 2, r: 4 current l: 2, m: 3, r: 4 current l: 3, m: 3, r: 4 current l: 3, m: 3, r: 4 ... ...
发现在输入 5 40 时,程序进入了死循环。在本例中,由于区间边界的计算逻辑错误导致 l <= r 在上述情况下无法达到。经过对边界条件设置的摸索后,改成了第二段代码的形式。
第二段代码
1 2 3 4 5 6 7 8 9 10 11 12
defbinary_search_e2(arr, t): l, r = 0, len(arr) - 1 while l <= r: m = round((l + r) / 2) # 取整 # print(" current l: %d, m: %d, r: %d" % (l, m, r)) if arr[m] < t: l = m + 1# 将搜索范围调整为 m 右侧 elif arr[m] > t: r = m - 1# 将搜索范围调整为 m 左侧 else: return m # 找到 return -1# 没有找到
>>> assert(0 > 1) Traceback (most recent call last): File "<stdin>", line 1, in <module> AssertionError
又如本文,应确保待查询的数组是严格有序的,可以使用如下 assert,
1 2 3 4 5 6 7
defsorted(arr): for i inrange(len(arr)-1): if arr[i+1] < arr[i]: returnFalse returnTrue
assert(sorted())
然后一个自动测试的脚手架被设计成下面这样。
自动测试的脚手架
1 2 3 4 5 6 7 8 9
arr = [] MAX_N = int(input("Input maximum length for the array: ")) for i inrange(MAX_N): arr.append(10 * i) for i inrange(MAX_N): assert(binary_search_right(arr, i * 10) == i) assert(binary_search_right(arr, i * 10 - 5) == -1) assert(binary_search_right(arr, 10 * MAX_N - 5) == -1) assert(binary_search_right(arr, 10 * MAX_N) == -1)
运行这个脚手架,
1 2 3 4 5 6
$ python binary_search_autotest.py Input maximum length for the array: 100 $ python binary_search_autotest.py Input maximum length for the array: 1000 $ python binary_search_autotest.py Input maximum length for the array: 100000
通过细心的分析,最终,发现错误出在 m = (l + r) / 2 导致了溢出。当输入大于 INT_MAX 的一半时,求和的结果不再是正确的 int 值。而原来由于 Python 的特殊数据类型机制,导致 Python 中的计算,不会出现溢出。
好了,写了这么多,就是想说问题就出在 m = (l + r) / 2。于是将其改成 m = l + (r - l) / 2, 或者 m = l + ((r - l) >> 1),确保不溢出。
再编译、运行,成功通过上述不成功的用例。
第三段代码
终于,一个看起来没有错误的二分查找算法完成了。
1 2 3 4 5 6 7 8 9 10 11 12
defbinary_search_right(arr, t): l, r = 0, len(arr) - 1 while l <= r: m = l + (r - l) // 2# 取整 # print(" current l: %d, m: %d, r: %d" % (l, m, r)) if arr[m] < t: l = m + 1# 将搜索范围调整为 m 右侧 elif arr[m] > t: r = m - 1# 将搜索范围调整为 m 左侧 else: return m # 找到 return -1# 没有找到