字符串相乘 - LeetCode
导入依赖
主要依赖的库有:
-
math
:用来进行幂运算。 -
random
:用来生成随机测试用例
import math
import random
拆分、填充
- 默认输入的格式为不固定长度的字符串,如
"123456"
。 - 需要对输入的字符串拆分成长度为 的数字类型列表,如
[1,2,3,4,5,6]
。 - 并对其进行填充,找到 的指数 ,满足 ,如:
时,有 ,
满足 。
- 使用
0
对列表进行填充后的长度满足: ,如[1,2,3,4,5,6,0,0,0,0,0,0,0,0,0,0]
。
def to_list(num1:str, num2:str) -> tuple:
# 拆分为 list
a = [int(i) for i in num1]
b = [int(i) for i in num2]
# 反转列表,将低阶项系数放在列表前面
a.reverse()
b.reverse()
max_len = max(len(a),len(b))
# 对齐使长度相等
l = len(a)-len(b)
zeros = [0] * abs(l)
if l < 0:
a = a + zeros
elif l > 0:
b = b + zeros
# 补充前导 0,使得长度为 2^n
fill_count = int(2**math.ceil(math.log2(max_len*2)) - max_len)
fill = [0] * fill_count
return a+fill,b+fill
多项式表示
对于输入的两个数 、 ,将其处理成两个多项式:
最终的目标是对多项式 进行求解。
傅里叶变换求解
-
将处理后的两个列表进行快速傅里叶变换(fft),得到 个点值对的取值
-
得到两个新的列表并将其按元素相乘,得到待求解的多项式 的值
-
再进行逆离散傅里叶变换(idft),将点值表示转换为 的系数;
-
傅里叶变换后的结果是虚数,其实部四舍五入后取整,便是结果多项式对应项的系数,将以 为底的多项式计算求和,得到乘法的结果。
def multiply(num1: str, num2: str) -> str:
l = len(num1)
a,b = to_list(num1, num2)
# 傅里叶变换
a_fft, b_fft = fft(a), fft(b)
t = []
# 对应项相乘
for i in range(len(a_fft)):
t.append(a_fft[i] * b_fft[i])
# 逆傅里叶变换
ans = idft(t)
sum = 0
# 计算多项式
for i,r in enumerate(ans):
# 实部四舍五入取整
sum += int(r.real+0.5) * (10 ** i)
return str(sum)
傅里叶变换实现
傅里叶变换与逆傅里叶变换的主要区别在于:逆傅里叶变换需要对计算的结果除以 (并不是在递归中进行),并且在计算的过程中 。
def _ft(l:list, idft = False):
"""
基础的变换方法,通过变量控制进行dft还是idft
:param bool idft: 控制进行傅里叶变换还是逆傅里叶变换
"""
n = len(l)
if n == 1:
return l
# dft 与 idft 分别处理 $\omega$
o_n_e = -2j if idft else 2j
o = 1
o_n = math.e ** (o_n_e * math.pi / n)
# 拆分奇偶项
even_index = l[::2]
odd_index = l[1::2]
y_even = _ft(even_index, idft)
y_odd = _ft(odd_index, idft)
y = [0]*n
for i in range(n//2):
y[i] = y_even[i] + o * y_odd[i]
y[i+n//2] = y_even[i] - o * y_odd[i]
o *= o_n
return y
def fft(l:list):
"""
傅里叶变换
"""
output = _ft(l)
return output
def idft(l:list):
"""
逆傅里叶变换
"""
n = len(l)
output = _ft(l,True)
# 将计算的结果除以 $N$
output = [i/n for i in output]
return output
测试
将multiply()
方法输出的结果与自带的乘法计算结果进行比较,并输出测试结果。
def test(num1:str, num2:str):
r = int(multiply(num1,num2))
s = int(num1)*int(num2)
t = 30
print(f"{'-'*t} Test {'-'*t}")
print(f"Test case: \n\t{num1} \n\t{num2}")
print(f"Program output: \n\t{r}")
print(f"Expected output: \n\t{s}")
print(f"❌ FAILED" if r != s else "✔ OK")
return r == s
编写测试用例
# 测试用例数
test_cases = 10
# 数据长度
INT_MAX = 1e100
for i in range(test_caese):
num1 = str(random.randint(0, INT_MAX))
num2 = str(random.randint(0, INT_MAX))
test(num1, num2)
LeetCode
AC 代码
class Solution:
def to_list(self, num1:str, num2:str) -> tuple:
a = [int(i) for i in num1]
b = [int(i) for i in num2]
a.reverse()
b.reverse()
l = len(a)-len(b)
max_len = max(len(a),len(b))
# 对齐使长度相等
zeros = [0] * abs(l)
if l < 0:
a = a + zeros
elif l > 0:
b = b + zeros
# 补充前导 0,使得长度为 2^n
fill_count = int(2**math.ceil(math.log2(max_len*2)) - max_len)
fill = [0] * fill_count
return a+fill,b+fill
def multiply(self, num1: str, num2: str) -> str:
l = len(num1)
a,b = self.to_list(num1, num2)
a_fft, b_fft = self.fft(a), self.fft(b)
t = []
_3 = []
for i in range(len(a_fft)):
t.append(a_fft[i] * b_fft[i])
ans = self.idft(t)
sum = 0
for i,r in enumerate(ans):
sum += int(r.real+0.5) * (10 ** i)
return str(sum)
def _ft(self, l:list, idft = False):
n = len(l)
if n == 1:
return l
o_n_e = -2j if idft else 2j
even_index = l[::2]
odd_index = l[1::2]
o = 1
o_n = math.e ** (o_n_e * math.pi / n)
y_even = self._ft(even_index, idft)
y_odd = self._ft(odd_index, idft)
y = [0]*n
for i in range(n//2):
y[i] = y_even[i] + o * y_odd[i]
y[i+n//2] = y_even[i] - o * y_odd[i]
o *= o_n
return y
def fft(self, l:list):
output = self._ft(l)
return output
def idft(self, l:list):
n = len(l)
output = self._ft(l,True)
output = [i/n for i in output]
return output