Fu_L's Library

This documentation is automatically generated by online-judge-tools/verification-helper

View the Project on GitHub Fu-L/cp-library

:heavy_check_mark: inverse
(src/matrix/inverse.hpp)

inverse

Matrix<T> inverse(const Matrix<T>& A)

$N \times N$ 行列 $A$ の逆行列を返します.

逆行列が存在しない場合は $0 \times 0$ の行列を返します.

制約

- $A$ は $N \times N$ の正方行列 ($N \geq 1$)

計算量

- $O(N^3)$

Depends on

- src/template/template.hpp
- src/matrix/matrix.hpp
- src/matrix/gauss_elimination.hpp

Verified with

Code

#pragma once
#include "../template/template.hpp"
#include "./matrix.hpp"
#include "./gauss_elimination.hpp"
// Computes the inverse of the square matrix `a`.
//
// The augmented matrix [a | I] is row-reduced over its left n columns; when
// the left half has full rank the right half is returned as the inverse.
// Returns a 0 x 0 matrix when the reported rank is below n (no inverse).
// Asserts that `a` is non-empty and square.
template <typename T>
Matrix<T> inverse(const Matrix<T>& a) {
    assert(a.H() > 0);
    assert(a.H() == a.W());
    const int n = a.H();

    // Left half: copy of a. Right half: identity.
    Matrix<T> aug(n, 2 * n);
    for(int r = 0; r < n; ++r) {
        const auto& src = a[r];
        auto& dst = aug[r];
        std::copy_n(src.begin(), n, dst.begin());
        dst[n + r] = 1;
    }

    // Only the first n columns may host pivots; the right half just follows along.
    const int rank = gauss_elimination(aug, n).first;
    if(rank != n) return Matrix<T>(0, 0);

    // The right half of the reduced matrix is the result.
    Matrix<T> result(n, n);
    for(int r = 0; r < n; ++r) {
        std::copy_n(aug[r].begin() + n, n, result[r].begin());
    }
    return result;
}
#line 2 "src/template/template.hpp"
#include <bits/stdc++.h>
using namespace std;
// Shorthand for the 64-bit signed integer type.
using ll = long long;
// Pair of two long long values.
using P = pair<long long, long long>;
// rep(i, a, b): counts i upward over the half-open range [a, b); i is a long long.
#define rep(i, a, b) for(long long i = (a); i < (b); ++i)
// rrep(i, a, b): counts i downward from a to b, both ends inclusive; i is a long long.
#define rrep(i, a, b) for(long long i = (a); i >= (b); --i)
// Large sentinel constant (4 * 10^18, converted from a floating-point literal).
constexpr long long inf = 4e18;
// Configures the standard streams once, when the global `setup_io` object is
// constructed: C stdio synchronisation is turned off, cin is untied from cout,
// and cout prints floating-point values in fixed notation with 30 digits.
struct SetupIO {
    SetupIO() {
        std::ios_base::sync_with_stdio(false);
        std::cin.tie(nullptr);
        std::cout.setf(std::ios_base::fixed, std::ios_base::floatfield);
        std::cout.precision(30);
    }
} setup_io;
#line 3 "src/matrix/matrix.hpp"
// Dense h x w matrix over T, stored as a vector of rows.
// T must be constructible from 0 and 1 and support the arithmetic and
// comparison operators used by the members below.
template <typename T>
struct Matrix {
    // Creates an h x w matrix with every entry set to `val`.
    Matrix(const int h, const int w, const T& val = 0)
        : h(h), w(w), A(h, std::vector<T>(w, val)) {}
    // Number of rows.
    int H() const {
        return h;
    }
    // Number of columns.
    int W() const {
        return w;
    }
    // Row access (read-only). Asserts that i is a valid row index.
    const std::vector<T>& operator[](const int i) const {
        assert(0 <= i and i < h);
        return A[i];
    }
    // Row access (mutable). Asserts that i is a valid row index.
    std::vector<T>& operator[](const int i) {
        assert(0 <= i and i < h);
        return A[i];
    }
    // Returns the n x n identity matrix.
    static Matrix I(const int n) {
        Matrix mat(n, n);
        for(int i = 0; i < n; ++i) mat[i][i] = 1;
        return mat;
    }
    // Element-wise addition; both matrices must have the same shape.
    Matrix& operator+=(const Matrix& B) {
        assert(h == B.h and w == B.w);
        for(int i = 0; i < h; ++i) {
            for(int j = 0; j < w; ++j) {
                (*this)[i][j] += B[i][j];
            }
        }
        return (*this);
    }
    // Element-wise subtraction; both matrices must have the same shape.
    Matrix& operator-=(const Matrix& B) {
        assert(h == B.h and w == B.w);
        for(int i = 0; i < h; ++i) {
            for(int j = 0; j < w; ++j) {
                (*this)[i][j] -= B[i][j];
            }
        }
        return (*this);
    }
    // Matrix product: (h x w) * (w x B.w) -> (h x B.w).
    // Safe when B aliases *this (the product is built in a separate buffer).
    Matrix& operator*=(const Matrix& B) {
        assert(w == B.h);
        const int new_w = B.w;
        std::vector<std::vector<T>> C(h, std::vector<T>(new_w, 0));
        // i-k-j order keeps the innermost loop walking rows contiguously.
        for(int i = 0; i < h; ++i) {
            for(int k = 0; k < w; ++k) {
                for(int j = 0; j < new_w; ++j) {
                    C[i][j] += (*this)[i][k] * B[k][j];
                }
            }
        }
        A.swap(C);
        // The product has B.w columns; keep W() in sync with the new storage.
        w = new_w;
        return (*this);
    }
    // Replaces this square matrix with its t-th power (t >= 0) by repeated
    // squaring. t == 0 yields the identity.
    Matrix& pow(long long t) {
        assert(h == w);
        assert(t >= 0);
        Matrix B = Matrix::I(h);
        while(t > 0) {
            if(t & 1ll) B *= (*this);
            t >>= 1ll;
            // Skip the final squaring: its result would be discarded, and for
            // non-modular T it could overflow for no reason.
            if(t > 0) (*this) *= (*this);
        }
        A.swap(B.A);
        return (*this);
    }
    Matrix operator+(const Matrix& B) const {
        return (Matrix(*this) += B);
    }
    Matrix operator-(const Matrix& B) const {
        return (Matrix(*this) -= B);
    }
    Matrix operator*(const Matrix& B) const {
        return (Matrix(*this) *= B);
    }
    // Entry-wise equality. Asserts that both matrices have the same shape.
    bool operator==(const Matrix& B) const {
        assert(h == B.H() and w == B.W());
        for(int i = 0; i < h; ++i) {
            for(int j = 0; j < w; ++j) {
                if(A[i][j] != B[i][j]) return false;
            }
        }
        return true;
    }
    // Negation of operator== (same shape assertion applies).
    bool operator!=(const Matrix& B) const {
        return !(*this == B);
    }

   private:
    int h, w;                        // row and column counts
    std::vector<std::vector<T>> A;   // A[i][j] is the entry at row i, column j
};
#line 4 "src/matrix/gauss_elimination.hpp"
template <typename T>
pair<int, T> gauss_elimination(Matrix<T>& a, int pivot_end = -1) {
    const int h = a.H(), w = a.W();
    int rank = 0;
    assert(-1 <= pivot_end and pivot_end <= w);
    if(pivot_end == -1) pivot_end = w;
    T det = 1;
    for(int j = 0; j < pivot_end; ++j) {
        int idx = -1;
        for(int i = rank; i < h; ++i) {
            if(a[i][j] != T(0)) {
                idx = i;
                break;
            }
        }
        if(idx == -1) {
            det = 0;
            continue;
        }
        if(rank != idx) det = -det, swap(a[rank], a[idx]);
        det *= a[rank][j];
        if(a[rank][j] != T(1)) {
            const T coeff = T(1) / a[rank][j];
            for(int k = j; k < w; ++k) a[rank][k] *= coeff;
        }
        for(int i = 0; i < h; ++i) {
            if(i == rank) continue;
            if(a[i][j] != T(0)) {
                const T coeff = a[i][j] / a[rank][j];
                for(int k = j; k < w; ++k) a[i][k] -= a[rank][k] * coeff;
            }
        }
        ++rank;
    }
    return {rank, det};
}
#line 5 "src/matrix/inverse.hpp"
// Computes the inverse of the square matrix `a`.
//
// The augmented matrix [a | I] is row-reduced over its left n columns; when
// the left half has full rank the right half is returned as the inverse.
// Returns a 0 x 0 matrix when the reported rank is below n (no inverse).
// Asserts that `a` is non-empty and square.
template <typename T>
Matrix<T> inverse(const Matrix<T>& a) {
    assert(a.H() > 0);
    assert(a.H() == a.W());
    const int n = a.H();

    // Left half: copy of a. Right half: identity.
    Matrix<T> aug(n, 2 * n);
    for(int r = 0; r < n; ++r) {
        const auto& src = a[r];
        auto& dst = aug[r];
        std::copy_n(src.begin(), n, dst.begin());
        dst[n + r] = 1;
    }

    // Only the first n columns may host pivots; the right half just follows along.
    const int rank = gauss_elimination(aug, n).first;
    if(rank != n) return Matrix<T>(0, 0);

    // The right half of the reduced matrix is the result.
    Matrix<T> result(n, n);
    for(int r = 0; r < n; ++r) {
        std::copy_n(aug[r].begin() + n, n, result[r].begin());
    }
    return result;
}
Back to top page