题目描述
给定两个稀疏矩阵A 和 B,返回AB相乘的结果。
您可以假设A的列数等于B的行数。
样例
Input:
[[1,0,0],[-1,0,3]]
[[7,0,0],[0,0,0],[0,0,1]]
Output:
[[7,0,0],[-7,0,3]]
Explanation:
A = [
[ 1, 0, 0],
[-1, 0, 3]
]
B = [
[ 7, 0, 0 ],
[ 0, 0, 0 ],
[ 0, 0, 1 ]
]
| 1 0 0 | | 7 0 0 | | 7 0 0 |
AB = | -1 0 3 | x | 0 0 0 | = | -7 0 3 |
| 0 0 1 |
Input:
[[1,0],[0,1]]
[[0,1],[1,0]]
Output:
[[0,1],[1,0]]
思路
首先要注意题目描述,是稀疏矩阵!也就是说矩阵中大部分元素值为0。因此,只要矩阵的任一元素出现为0,则不对其进行乘法运算,直接无视掉。
1、先分配好输出矩阵的空间大小,再往其中填入运算结果
我们知道,若矩阵A和B相乘,那么输出矩阵的行数将等于矩阵A的行数,列数将等于矩阵B的列数,于是根据输入矩阵A、B我们可以事先为输出矩阵分配好空间,然后再将对应位置的元素填上:
, where A[i][k] * B[k][j]
0
2、将矩阵相乘转换为行、列向量相乘
矩阵相乘的过程可看作是每个行向量与列向量对应位置的元素相乘然后求和,因此可以事先将矩阵A转换为一个个行向量,将矩阵B转换为一个个列向量。
注意在转换的时候,我们只取非0元素,同时记录下每个元素所在行/列的位置,只有当行、列向量中的元素位置匹配上时这些元素才能进行相乘。
代码
1、先分配好输出矩阵的空间大小,再往其中填入运算结果
class Solution:
"""
@param A: a sparse matrix
@param B: a sparse matrix
@return: the result of A * B
"""
def multiply(self, A, B):
C = [[0] * len(B[0]) for _ in range(len(A))]
for i in range(len(A)):
for k in range(len(B)):
if A[i][k]:
for j in range(len(B[0])):
if B[k][j]:
C[i][j] += A[i][k] * B[k][j]
return C
2、将矩阵相乘转换为行、列向量相乘
class Solution:
"""
@param A: a sparse matrix
@param B: a sparse matrix
@return: the result of A * B
"""
def multiply(self, A, B):
row_vec = self.convet_to_row_vector(A)
col_vec = self.convet_to_col_vector(B)
C = []
for rv in row_vec:
r = []
for cv in col_vec:
# 每个行向量与每个列向量对应相乘
r.append(self.multiply_vector(rv, cv))
C.append(r)
return C
def convet_to_row_vector(self, matrix):
"""将二维矩阵转换为行向量"""
row_vector = []
for row_val in matrix:
vector = []
for j, val in enumerate(row_val):
# 只加入非0元素
if val:
vector.append((j, val))
row_vector.append(vector)
return row_vector
def convet_to_col_vector(self, matrix):
"""将二维矩阵转换为列向量"""
col_vector = []
for j in range(len(matrix[0])):
vector = []
for i in range(len(matrix)):
# 只加入非0元素
if matrix[i][j]:
vector.append((i, matrix[i][j]))
col_vector.append(vector)
return col_vector
def multiply_vector(self, v1, v2):
"""向量相乘,即行向量的第k个元素只能与列向量的第k个元素相乘"""
i = j = 0
result = 0
while i < len(v1) and j < len(v2):
col_num = v1[i][0]
row_num = v2[j][0]
if col_num < row_num:
i += 1
elif col_num > row_num:
j += 1
else:
result += v1[i][1] * v2[j][1]
i += 1
j += 1
return result