Matrix Chain Multiplication

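Given an array arr[] that represents the dimensions of a sequence of matrices, where the i-th matrix has dimensions arr[i-1] x arr[i] for i ≥ 1, find the most efficient way to multiply these matrices together. Matrix multiplication is associative, so every parenthesization yields the same product, but the number of scalar multiplications differs. The most efficient way is the one that requires the fewest scalar multiplications; return that minimum number.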
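The cost of a single product follows from the schoolbook algorithm. Each of the p · r entries of the product of a p x q matrix and a q x r matrix is a dot product of length q, so the product needs p · q · r scalar multiplications. This is the cost model the examples use. The sketch below, built around a hypothetical helper `multiply_counting`, multiplies two matrices stored as nested lists and counts the multiplications it performs.

```python
from typing import List, Tuple

Matrix = List[List[int]]


def multiply_counting(a: Matrix, b: Matrix) -> Tuple[Matrix, int]:
    """Schoolbook product of a (p x q) and b (q x r).
    Also returns the number of scalar multiplications, always p * q * r."""
    p, q, r = len(a), len(b), len(b[0])
    assert all(len(row) == q for row in a), "inner dimensions must match"
    result = [[0] * r for _ in range(p)]
    count = 0
    for i in range(p):
        for j in range(r):
            for k in range(q):  # dot product of row i of a and column j of b
                result[i][j] += a[i][k] * b[k][j]
                count += 1
    return result, count


# A 2 x 1 matrix times a 1 x 3 matrix: 2 * 1 * 3 = 6 multiplications.
m1 = [[1], [2]]
m2 = [[3, 4, 5]]
product, ops = multiply_counting(m1, m2)
print(product, ops)  # [[3, 4, 5], [6, 8, 10]] 6
```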

Examples:

Input: arr[] = [2, 1, 3, 4]
Output: 20
Explanation: There are 3 matrices of dimensions 2 × 1, 1 × 3, and 3 × 4. Let these three input matrices be M1, M2, and M3. There are two ways to multiply them: ((M1 x M2) x M3) and (M1 x (M2 x M3)). Note that the result of (M1 x M2) is a 2 x 3 matrix and the result of (M2 x M3) is a 1 x 4 matrix.
((M1 x M2) x M3) requires (2 x 1 x 3) + (2 x 3 x 4) = 30 multiplications, while (M1 x (M2 x M3)) requires (1 x 3 x 4) + (2 x 1 x 4) = 20.
The minimum of these two is 20.
Input: arr[] = [1, 2, 3, 4, 3]
Output: 30
Explanation: There are 4 matrices of dimensions 1 × 2, 2 × 3, 3 × 4, and 4 × 3. Let these four input matrices be M1, M2, M3, and M4. The minimum number of multiplications is obtained by (((M1 x M2) x M3) x M4), which costs (1 x 2 x 3) + (1 x 3 x 4) + (1 x 4 x 3) = 30.
Input: arr[] = [3, 4]
Output: 0
Explanation: There is only one matrix, so no multiplication is needed and the cost is 0.
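For small inputs, the examples can be checked by brute force. The sketch below (the function name `all_parenthesizations` is hypothetical) tries every position for the last multiplication, recursively enumerates both halves, and reports each parenthesization with its cost. The number of parenthesizations grows exponentially (they are counted by the Catalan numbers), so this is only a checking tool, not a practical solution for arr.size() up to 100.

```python
from typing import List, Tuple


def all_parenthesizations(arr: List[int], i: int, j: int) -> List[Tuple[str, int]]:
    """Return (expression, cost) for every way to parenthesize M_i .. M_j,
    where matrix M_k has dimensions arr[k-1] x arr[k]."""
    if i == j:
        return [(f"M{i}", 0)]
    results = []
    for k in range(i, j):  # last multiplication: (M_i .. M_k) x (M_k+1 .. M_j)
        for left_expr, left_cost in all_parenthesizations(arr, i, k):
            for right_expr, right_cost in all_parenthesizations(arr, k + 1, j):
                cost = left_cost + right_cost + arr[i - 1] * arr[k] * arr[j]
                results.append((f"({left_expr} x {right_expr})", cost))
    return results


for expr, cost in all_parenthesizations([2, 1, 3, 4], 1, 3):
    print(expr, cost)
# (M1 x (M2 x M3)) 20
# ((M1 x M2) x M3) 30

for arr in ([1, 2, 3, 4, 3], [3, 4]):
    print(arr, min(cost for _, cost in all_parenthesizations(arr, 1, len(arr) - 1)))
# [1, 2, 3, 4, 3] 30
# [3, 4] 0
```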

Constraints: 
2 ≤ arr.size() ≤ 100
1 ≤ arr[i] ≤ 200
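The constraints also bound the answer. Any parenthesization of k = arr.size() - 1 matrices performs exactly k - 1 pairwise products, each costing at most 200 × 200 × 200 = 8,000,000 scalar multiplications, so even the most expensive order costs at most 98 × 8,000,000 = 784,000,000. Every value the solution computes, including each intermediate currentOperations sum, is the cost of some parenthesization of a sub-chain and therefore stays within this bound, which fits in a signed 32-bit int (maximum 2,147,483,647). A short check of the arithmetic:

```python
# Worst case allowed by the constraints: 2 <= arr.size() <= 100, 1 <= arr[i] <= 200.
MAX_SIZE = 100
MAX_DIM = 200
INT32_MAX = 2**31 - 1

num_matrices = MAX_SIZE - 1        # k = arr.size() - 1 matrices
num_products = num_matrices - 1    # every parenthesization performs k - 1 products
max_product_cost = MAX_DIM ** 3    # p * q * r with all three dimensions at 200
upper_bound = num_products * max_product_cost

print(upper_bound)               # 784000000
print(upper_bound <= INT32_MAX)  # True
```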


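Approach 01: Recursion with Memoization (Top-Down Dynamic Programming)

Let \( cost(l, r) \) be the minimum number of multiplications needed to compute \( M_l \times \dots \times M_r \). A single matrix costs nothing, so \( cost(l, l) = 0 \). Otherwise the last multiplication splits the chain at some partition \( p \) with \( l \le p < r \), combining an arr[l-1] x arr[p] result with an arr[p] x arr[r] result:

\( cost(l, r) = \min_{l \le p < r} \left( cost(l, p) + cost(p + 1, r) + \text{arr}[l-1] \cdot \text{arr}[p] \cdot \text{arr}[r] \right) \)

The answer is \( cost(1, n - 1) \), where n is the length of arr. Each \( cost(l, r) \) is cached in a memo table so that it is computed only once.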

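#include <algorithm>
#include <climits>
#include <vector>

using namespace std;

class Solution {
public:
    // Returns the minimum number of scalar multiplications needed to compute
    // M_left x ... x M_right, where matrix M_i has dimensions
    // dimensions[i - 1] x dimensions[i].
    int matrixChainMultiplication(vector<int>& dimensions, int left, int right, vector<vector<int>>& memo) {
        // A single matrix needs no multiplication.
        if (left == right) {
            return 0;
        }

        // Reuse the cached answer if this sub-chain was already solved.
        if (memo[left][right] != -1) {
            return memo[left][right];
        }

        int minOperations = INT_MAX;

        // Try every split (M_left .. M_partition) x (M_partition+1 .. M_right).
        for (int partition = left; partition <= right - 1; partition++) {
            // Cost of both halves plus the final product of a
            // dimensions[left-1] x dimensions[partition] matrix and a
            // dimensions[partition] x dimensions[right] matrix.
            int currentOperations = matrixChainMultiplication(dimensions, left, partition, memo) +
                                    matrixChainMultiplication(dimensions, partition + 1, right, memo) +
                                    dimensions[left - 1] * dimensions[partition] * dimensions[right];

            minOperations = min(minOperations, currentOperations);
        }

        return memo[left][right] = minOperations;
    }

    int matrixMultiplication(vector<int>& dimensions) {
        int n = dimensions.size();  // n dimensions describe n - 1 matrices
        if (n <= 2) return 0;       // At most one matrix: no multiplication needed

        // memo[i][j] caches the answer for M_i .. M_j; -1 marks "not computed yet".
        vector<vector<int>> memo(n, vector<int>(n, -1));
        return matrixChainMultiplication(dimensions, 1, n - 1, memo);
    }
};
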
from typing import List

class Solution:
    def matrixChainMultiplication(self, dimensions: List[int], left: int, right: int, memo: List[List[int]]) -> int:
        # A single matrix needs no multiplication.
        if left == right:
            return 0

        # Reuse the cached answer if this sub-chain was already solved.
        if memo[left][right] != -1:
            return memo[left][right]

        minOperations = float('inf')

        # Try every split (M_left .. M_partition) x (M_partition+1 .. M_right).
        for partition in range(left, right):
            # Cost of both halves plus the final product of a
            # dimensions[left-1] x dimensions[partition] matrix and a
            # dimensions[partition] x dimensions[right] matrix.
            currentOperations = (self.matrixChainMultiplication(dimensions, left, partition, memo) +
                                 self.matrixChainMultiplication(dimensions, partition + 1, right, memo) +
                                 dimensions[left - 1] * dimensions[partition] * dimensions[right])

            minOperations = min(minOperations, currentOperations)

        memo[left][right] = minOperations
        return minOperations

    def matrixMultiplication(self, dimensions: List[int]) -> int:
        n = len(dimensions)  # n dimensions describe n - 1 matrices
        if n <= 2:
            return 0  # At most one matrix: no multiplication needed

        # memo[i][j] caches the answer for M_i .. M_j; -1 marks "not computed yet".
        memo = [[-1] * n for _ in range(n)]
        return self.matrixChainMultiplication(dimensions, 1, n - 1, memo)
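The same recurrence can also be evaluated bottom-up, filling the table in order of increasing sub-chain length so that every value a cell needs is already available. This tabulation variant is a swapped-in alternative to the memoized approach, not part of the original solution, and the names `matrix_chain_bottom_up`, `split` and `build` are hypothetical. It has the same O(n³) time and O(n²) space, fills the table without recursion, and additionally records the best split for each sub-chain so that one optimal parenthesization can be rebuilt.

```python
from typing import List, Tuple


def matrix_chain_bottom_up(arr: List[int]) -> Tuple[int, str]:
    """Bottom-up (tabulation) version of the same recurrence.
    Returns the minimum cost and one optimal parenthesization."""
    n = len(arr)  # n dimensions describe matrices M1 .. M(n-1)
    dp = [[0] * n for _ in range(n)]     # dp[i][j]: min cost of M_i .. M_j
    split = [[0] * n for _ in range(n)]  # split[i][j]: best partition for M_i .. M_j

    for length in range(2, n):  # number of matrices in the sub-chain
        for i in range(1, n - length + 1):
            j = i + length - 1
            dp[i][j] = float("inf")
            for k in range(i, j):
                cost = dp[i][k] + dp[k + 1][j] + arr[i - 1] * arr[k] * arr[j]
                if cost < dp[i][j]:
                    dp[i][j] = cost
                    split[i][j] = k

    def build(i: int, j: int) -> str:
        # Rebuild the parenthesization by following the recorded split points.
        if i == j:
            return f"M{i}"
        k = split[i][j]
        return f"({build(i, k)} x {build(k + 1, j)})"

    return dp[1][n - 1], build(1, n - 1)


print(matrix_chain_bottom_up([2, 1, 3, 4]))     # (20, '(M1 x (M2 x M3))')
print(matrix_chain_bottom_up([1, 2, 3, 4, 3]))  # (30, '(((M1 x M2) x M3) x M4)')
print(matrix_chain_bottom_up([3, 4]))           # (0, 'M1')
```

Filling by increasing length guarantees that dp[i][k] and dp[k + 1][j] are final before dp[i][j] reads them, which is the ordering the memoized version obtains implicitly through recursion.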

Time Complexity:

  • Recursive Calls:

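    Each subproblem is a sub-chain (left, right) with 1 ≤ left < right ≤ n - 1, where n is the length of arr, so there are \( O(n^2) \) unique subproblems. Memoization guarantees that each one is computed only once, and its loop runs at most \( n - 2 \) times with \( O(1) \) work per iteration (each recursive call either hits the memo table or is charged to its own subproblem). The total is therefore \( O(n^2) \cdot O(n) = O(n^3) \).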

  • Total Time Complexity:

    The overall time complexity is \( O(n^3) \).
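The cubic bound can be made exact by counting loop iterations. With k = n - 1 matrices, there are k - L + 1 sub-chains of length L, and each tries L - 1 split points; summing over L = 2..k gives k(k - 1)(k + 1)/6 = C(k + 1, 3) iterations in total. The sketch below is an illustrative instrumentation, not part of the original solution (`count_split_evaluations` is a hypothetical name): it reruns the memoized recursion with a counter and compares the count with that formula.

```python
from math import comb
from typing import List


def count_split_evaluations(arr: List[int]) -> int:
    """Run the memoized recursion of Approach 01 and count how many
    (left, right, partition) combinations it evaluates."""
    n = len(arr)
    memo = [[-1] * n for _ in range(n)]
    evaluations = 0

    def solve(left: int, right: int) -> int:
        nonlocal evaluations
        if left == right:
            return 0
        if memo[left][right] != -1:
            return memo[left][right]
        best = float("inf")
        for partition in range(left, right):
            evaluations += 1  # one iteration of the split loop
            cost = (solve(left, partition) + solve(partition + 1, right)
                    + arr[left - 1] * arr[partition] * arr[right])
            best = min(best, cost)
        memo[left][right] = best
        return best

    solve(1, n - 1)
    return evaluations


for k in (3, 10, 50, 99):  # k = number of matrices = len(arr) - 1
    print(k, count_split_evaluations([2] * (k + 1)), comb(k + 1, 3))  # the two counts match
```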

Space Complexity:

  • Memoization Table:

    The algorithm uses an \( n \times n \) memoization table, requiring \( O(n^2) \) space.

  • Recursive Call Stack:

    Each recursive call works on a strictly shorter sub-chain than its caller, so the recursion depth is at most \( n - 1 \), adding \( O(n) \) stack space.

  • Total Space Complexity:

    The overall space complexity is \( O(n^2) + O(n) = O(n^2) \), dominated by the memoization table.
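Both space terms can be observed by instrumenting the recursion. The sketch below is an illustrative addition, not part of the original solution (`measure_space` is a hypothetical name): it tracks the current call depth and counts the memo entries that get written. With k = n - 1 matrices, the depth reaches exactly k (for example along (1, k), (2, k), ..., (k, k)), and one entry is written per sub-chain with left < right, k(k - 1)/2 in total, so only the upper triangle of the n × n table is ever used.

```python
from typing import List, Tuple


def measure_space(arr: List[int]) -> Tuple[int, int]:
    """Return (maximum recursion depth, memo entries filled) for the
    memoized recursion of Approach 01 on the chain arr."""
    n = len(arr)
    memo = [[-1] * n for _ in range(n)]
    depth = 0
    max_depth = 0

    def solve(left: int, right: int) -> int:
        nonlocal depth, max_depth
        depth += 1
        max_depth = max(max_depth, depth)
        try:
            if left == right:
                return 0
            if memo[left][right] != -1:
                return memo[left][right]
            best = float("inf")
            for partition in range(left, right):
                cost = (solve(left, partition) + solve(partition + 1, right)
                        + arr[left - 1] * arr[partition] * arr[right])
                best = min(best, cost)
            memo[left][right] = best
            return best
        finally:
            depth -= 1  # leaving this call

    solve(1, n - 1)
    filled = sum(1 for row in memo for value in row if value != -1)
    return max_depth, filled


for k in (1, 5, 20, 99):  # k = number of matrices = len(arr) - 1
    print(k, measure_space([3] * (k + 1)))  # depth is k; k * (k - 1) // 2 entries filled
```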
