Program Listing for File optscaling_oracle.hpp

Return to documentation for file (ellcpp/oracles/optscaling_oracle.hpp)

// -*- coding: utf-8 -*-
#pragma once

#include "network_oracle.hpp"
#include <cassert>
#include <tuple>
#include <utility>
#include <xtensor/xarray.hpp>

/**
 * @brief Cutting-plane oracle for the optimal-scaling problem (as the name
 *        suggests), over the two-component variable x = (x(0), x(1)).
 *
 * Per-edge constraints are checked by a network_oracle built on @p Graph
 * with the nested Ratio functor. When that reports no violated constraint,
 * the objective value x(0) - x(1) is compared against the bound t.
 *
 * @tparam Graph     graph type; must provide `edge_t` and `end_points(e)`
 *                   returning a pair of comparable node handles.
 * @tparam Container type of the object passed (by reference) to
 *                   network_oracle as its second constructor argument.
 * @tparam Fn        edge-cost callable, invoked as `get_cost(e)` through a
 *                   const object.
 */
template <typename Graph, typename Container, typename Fn> //
class optscaling_oracle
{
    using Arr = xt::xarray<double, xt::layout_type::row_major>;
    using edge_t = typename Graph::edge_t;
    using Cut = std::tuple<Arr, double>;

    /**
     * @brief Per-edge constraint value and gradient handed to network_oracle.
     *
     * For an edge (u, v) with u < v the value is x(0) - cost(e); otherwise
     * (u > v) it is cost(e) - x(1). Self-loops (u == v) violate an assert.
     */
    class Ratio
    {
      private:
        const Graph& _G; // held by reference: the graph must outlive Ratio
        Fn _get_cost;

      public:
        /**
         * @param[in] G        graph whose edges are evaluated
         * @param[in] get_cost edge-cost callable (taken by value, moved in)
         */
        Ratio(const Graph& G, Fn get_cost)
            : _G {G}
            , _get_cost {std::move(get_cost)}
        {
        }

        explicit Ratio(const Ratio&) = default;

        /**
         * @brief Constraint value of edge @p e at point @p x.
         * @return x(0) - cost(e) if the tail precedes the head, else
         *         cost(e) - x(1).
         */
        auto eval(const edge_t& e, const Arr& x) const -> double
        {
            const auto [u, v] = this->_G.end_points(e);
            const auto cost = this->_get_cost(e);
            assert(u != v);
            return (u < v) ? x(0) - cost : cost - x(1);
        }

        /**
         * @brief Gradient of eval() with respect to x.
         * @return [1, 0] if the tail precedes the head, else [0, -1].
         */
        auto grad(const edge_t& e, const Arr& x) const -> Arr
        {
            const auto [u, v] = this->_G.end_points(e);
            assert(u != v);
            return (u < v) ? Arr {1., 0.} : Arr {0., -1.};
        }
    };

    network_oracle<Graph, Container, Ratio> _network;

  public:
    /**
     * @param[in]     G        graph defining the edge constraints
     * @param[in,out] u        container forwarded by reference to network_oracle
     * @param[in]     get_cost edge-cost callable
     */
    optscaling_oracle(const Graph& G, Container& u, Fn get_cost)
        // get_cost is a by-value sink parameter: move it rather than copy it.
        : _network(G, u, Ratio {G, std::move(get_cost)})
    {
    }

    explicit optscaling_oracle(const optscaling_oracle&) = default;

    // optscaling_oracle& operator=(const optscaling_oracle&) = delete;
    // optscaling_oracle(optscaling_oracle&&) = default;

    /**
     * @brief Evaluate the oracle at @p x.
     *
     * @param[in]     x candidate point (x(0), x(1))
     * @param[in,out] t objective bound; replaced by x(0) - x(1) when x passes
     *                  the network check and that value is smaller than t
     * @return {cut, updated}:
     *         - a violated network constraint: the network oracle's cut,
     *           with updated == false;
     *         - otherwise, with s = x(0) - x(1):
     *           - if s < t: ([1, -1], 0) with updated == true, and t is set to s;
     *           - if s >= t: ([1, -1], s - t) with updated == false.
     */
    auto operator()(const Arr& x, double& t) -> std::tuple<Cut, bool>
    {
        // Non-const so the cut can be moved out instead of deep-copied.
        if (auto cut = this->_network(x))
        {
            return {std::move(*cut), false};
        }
        const auto s = x(0) - x(1);
        const auto fj = s - t;
        if (fj < 0)
        {
            t = s;
            return {{Arr {1., -1.}, 0.}, true};
        }
        return {{Arr {1., -1.}, fj}, false};
    }
};