Program Listing for File ldlt_ext.hpp

Return to documentation for file (ellcpp/oracles/ldlt_ext.hpp)

// -*- coding: utf-8 -*-
#pragma once

#include <ellcpp/ell_assert.hpp> // ELL_UNLIKELY
#include <ellcpp/utility.hpp>
#include <xtensor/xarray.hpp>

/**
 * @brief LDLT factorization with a tracked active row range (used by LMI oracles).
 *
 * The factor is kept in the single n-by-n matrix `T`:
 *   - `T(i, i)`            holds the pivot D(i);
 *   - `T(i, j)`, j < i     holds the unit-lower-triangular entry L(i, j);
 *   - `T(j, i)`, j < i     holds the unnormalized value D(j) * L(i, j).
 *
 * `p = {start, stop}`: `stop` is set to `i + 1` when row `i` yields a failing
 * pivot (the factorization then stops early) and stays 0 otherwise; `start`
 * is the first row of the current block, which only moves away from 0 in
 * semidefinite mode after a zero pivot.
 */
class ldlt_ext
{
    using Arr = xt::xarray<double, xt::layout_type::row_major>;
    using Vec = Arr;
    using Mat = Arr;
    using Rng = std::pair<size_t, size_t>;

  public:
    Rng p {0U, 0U}; ///< {start, stop}; stop == 0 means no failing pivot was found
    Vec v;          ///< length-N vector initialized with zeros({N}) by the constructor

  private:
    const size_t n; ///< dimension of the matrices handled
    Mat T;          ///< packed storage for D, L and D*L (see class comment)

  public:
    /**
     * @brief Construct a factorizer for N-by-N matrices.
     * @param[in] N matrix dimension
     */
    explicit ldlt_ext(size_t N)
        : v {zeros({N})}
        , n {N}
        , T {zeros({N, N})}
    {
    }

    ldlt_ext(const ldlt_ext&) = delete;
    ldlt_ext& operator=(const ldlt_ext&) = delete;
    ldlt_ext(ldlt_ext&&) = default;

    /**
     * @brief Factorize a dense matrix.
     * @param[in] A matrix of at least n-by-n; only its lower triangle
     *              (A(i, j) with j <= i) is read
     * @return true if no failing pivot was found (see is_spd())
     */
    auto factorize(const Mat& A) -> bool
    {
        return this->factor([&](size_t i, size_t j) { return A(i, j); });
    }

    /**
     * @brief Row-by-row LDLT factorization that reads entries lazily.
     *
     * @tparam Callable           callable invoked as getA(size_t i, size_t j)
     * @tparam Allow_semidefinite if true, a zero pivot restarts the block at
     *                            the next row and only a negative pivot stops
     *                            the factorization; if false, any pivot <= 0
     *                            stops it
     * @param[in] getA            returns A(i, j); only called with j <= i
     * @return true if no failing pivot was found (see is_spd())
     */
    template <typename Callable, bool Allow_semidefinite = false>
    auto factor(Callable&& getA) -> bool
    {
        this->p = {0U, 0U};
        auto& [start, stop] = this->p;

        // size_t (not unsigned int) so the counter has the same type as n and
        // p, and cannot wrap before reaching n.
        for (size_t i = 0U; i != this->n; ++i)
        {
            // d holds A(i, s) - sum_{k in [start, s)} L(i, k) * (D(k) * L(s, k)),
            // starting with s == start.
            auto d = getA(i, start);
            for (auto j = start; j != i; ++j)
            {
                // Keep the unnormalized value D(j) * L(i, j) in the upper
                // triangle; the pivot update below (s == i) reads it back.
                this->T(j, i) = d;
                this->T(i, j) = d / this->T(j, j); // L(i, j)
                auto s = j + 1;
                d = getA(i, s);
                for (auto k = start; k != s; ++k)
                {
                    d -= this->T(i, k) * this->T(k, s);
                }
            }
            this->T(i, i) = d; // pivot D(i)

            if constexpr (Allow_semidefinite)
            {
                if (d < 0.)
                {
                    stop = i + 1; // negative pivot: stop here
                    break;
                }
                if (ELL_UNLIKELY(d == 0.))
                {
                    // Zero pivot: restart the block at row i + 1
                    // (special handling for the LMI oracle).
                    start = i + 1;
                }
            }
            else // not Allow_semidefinite
            {
                if (d <= 0.)
                {
                    stop = i + 1; // non-positive pivot: stop here
                    break;
                }
            }
        }

        return this->is_spd();
    }

    /**
     * @brief Whether the last factorization finished without a failing pivot.
     * @return true iff p.second (stop) is 0
     */
    auto is_spd() const noexcept -> bool
    {
        return this->p.second == 0;
    }

    /// Defined out of line. NOTE(review): presumably uses T and p from the last
    /// failed factorization (and possibly fills v) — confirm against the definition.
    auto witness() -> double;

    /// Defined out of line. NOTE(review): presumably evaluates a quadratic form
    /// involving v — confirm against the definition.
    auto sym_quad(const Vec& A) const -> double;

    /// Defined out of line. NOTE(review): presumably derives a square-root
    /// factor from T — confirm against the definition.
    auto sqrt() -> Mat;
};