斯坦福算法专项课程Course1 week2算法实现
这次回顾第二周的算法。
课程地址:https://www.coursera.org/specializations/algorithms
Closest Pair
import numpy as np
import functools
import time
#生成测试数据
def generate(n):
return np.random.randn(n, 2)
#比较函数
def cmp1(a, b):
return a[0] - b[0]
def cmp2(a, b):
return a[1] - b[1]
#距离函数
def dist(a, b):
return (a[0] - b[0]) ** 2 + (a[1] - b[1]) ** 2
#直接解法
def Closest_Pair_bf(data):
n = len(data)
d_min = 1e9
p = []
q = []
for i in range(n):
for j in range(i + 1, n):
d1 = dist(data[i], data[j])
if d1 < d_min:
d_min = d1
p = data[i]
q = data[j]
return p, q, np.sqrt(d_min)
#将Px划分为前半部分以及后半部分
def Split(Px, Py):
n = len(Px)
m = n // 2
#中间位置的x坐标
x_mid = Px[m][0]
#x部分直接划分即可
Lx = Px[:m]
Rx = Px[m:]
#y部分遍历即可
Ly = []
Ry = []
for i in range(n):
if (Py[i][0] < x_mid):
Ly.append(Py[i])
else:
Ry.append(Py[i])
return Lx, Ly, Rx, Ry
def Closest_Split_Pair(Px, Py, d):
n = len(Px)
m = n // 2
#左半部分最大坐标
x_mid = Px[m - 1][0]
#Sy为x坐标在[x_mid - d, x_mid + d]范围内的点,按y坐标排序
Sy = []
for i in range(n):
if (Py[i][0] >= x_mid - d and Py[i][0] <= x_mid + d):
Sy.append(Py[i])
#Sy的长度
l = len(Sy)
#初始化
d_min = 1e9
p = []
q = []
if (l == 0):
return p, q, np.sqrt(d_min)
for i in range(l - 1):
k = min(7, l - 1 - i)
for j in range(1, k + 1):
d1 = dist(Sy[i], Sy[i + j])
if (d1 < d_min):
d_min = d1
p = Sy[i]
q = Sy[i + j]
return p, q, np.sqrt(d_min)
def Closest_Pair(Px, Py):
n = len(Px)
if (n <= 3):
return Closest_Pair_bf(Px)
else:
#划分
Lx, Ly, Rx, Ry = Split(Px, Py)
p1, q1, d1 = Closest_Pair(Lx, Ly)
p2, q2, d2 = Closest_Pair(Rx, Ry)
d = min(d1, d2)
p3, q3, d3 = Closest_Split_Pair(Px, Py, d)
d_min = min(d3, d)
#输出
if (d_min == d1):
return p1, q1, d1
elif (d_min == d2):
return p2, q2, d2
else:
return p3, q3, d3
####测试算法
#数据数量
N = 2000
#生成数据
data = generate(N)
#直接解法
t1 = time.time()
p1, q1, d_min1 = Closest_Pair_bf(data)
t2 = time.time()
print(p1, q1, d_min1)
print(t2 - t1)
#快速算法
t1 = time.time()
#按x坐标排序
Px = sorted(data, key=functools.cmp_to_key(cmp1))
#按y坐标排序
Py = sorted(data, key=functools.cmp_to_key(cmp2))
p2, q2, d_min2 = Closest_Pair(Px, Py)
t2 = time.time()
print(p2, q2, d_min2)
print(t2 - t1)
Counting Inversions
import numpy as np
def Count_split_inv(A, l, m, r):
"""
l, ..., m为第一部分
m + 1, ..., r为第二部分
"""
#存储结果
C = []
#索引
i = l
j = m + 1
#逆序数
cnt = 0
while ((i <= m) and (j <= r)):
if (A[i] <= A[j]):
C.append(A[i])
i += 1
else:
C.append(A[j])
j += 1
cnt += m - i + 1
#处理剩余部分
while (i <= m):
C.append(A[i])
i += 1
while (j <= r):
C.append(A[j])
j += 1
#更新
A[l: (r + 1)] = C
return cnt
def Sort_and_Count(A, l, r):
'''
对l, ... , r排序并计算逆序数
'''
if (l >= r):
return 0
else:
m = (l + r) // 2
x = Sort_and_Count(A, l, m)
y = Sort_and_Count(A, m + 1, r)
z = Count_split_inv(A, l, m, r)
return x + y + z
def Count_inv(A):
n = len(A)
return Sort_and_Count(A, 0, n - 1)
def Count_inv_bf(A):
cnt = 0
n = len(A)
for i in range(n):
for j in range(i + 1, n):
if (A[i] > A[j]):
cnt += 1
return cnt
#测试
N = 1000
data = np.random.randn(N)
cnt1 = Count_inv_bf(data)
cnt2 = Count_inv(data)
print(cnt1, cnt2)
print(cnt1 == cnt2)
本博客所有文章除特别声明外,均采用 CC BY-NC-SA 4.0 许可协议。转载请注明来自 Doraemonzzz!
评论
ValineLivere