Program Listing for File optscaling_oracle.hpp
Return to documentation for file (ellcpp/oracles/optscaling_oracle.hpp)
// -*- coding: utf-8 -*-
#pragma once
#include <cassert>
#include <tuple>
#include <utility>

#include <xtensor/xarray.hpp>

#include "network_oracle.hpp"
/**
 * @brief Cutting-plane oracle for an optimal-scaling problem on a graph.
 *
 * Feasibility of a query point is delegated to a network_oracle whose
 * per-edge function is the nested `Ratio` class; the objective handled here
 * is `x(0) - x(1)`.
 *
 * NOTE(review): `x` is only ever read at indices 0 and 1, so it is assumed
 * to be a 2-vector (presumably the upper/lower bounds of the scaled costs)
 * -- confirm against callers.
 *
 * @tparam Graph     graph type; must provide `edge_t` and `end_points(e)`.
 * @tparam Container forwarded by non-const reference to network_oracle
 *                   (presumably per-node potentials -- confirm there).
 * @tparam Fn        callable invoked as `get_cost(e)` for an `edge_t e`;
 *                   its result is combined arithmetically with doubles.
 */
template <typename Graph, typename Container, typename Fn> //
class optscaling_oracle
{
    using Arr = xt::xarray<double, xt::layout_type::row_major>;
    using edge_t = typename Graph::edge_t;
    using Cut = std::tuple<Arr, double>;

    /**
     * @brief Per-edge constraint value and gradient with respect to x.
     *
     * For an edge (u, v) with u < v the value is x(0) - cost(e) and the
     * gradient is (1, 0); otherwise the value is cost(e) - x(1) and the
     * gradient is (0, -1). Self-loops (u == v) are rejected by assert.
     */
    class Ratio
    {
      private:
        const Graph& _G; // non-owning: the graph must outlive this object
        Fn _get_cost;

      public:
        /**
         * @param[in] G        graph whose `end_points` orient each edge.
         * @param[in] get_cost edge-cost callable (sink parameter, moved in).
         */
        Ratio(const Graph& G, Fn get_cost)
            : _G {G}
            , _get_cost {std::move(get_cost)}
        {
        }

        /// Copies must be requested explicitly (direct-initialization).
        explicit Ratio(const Ratio&) = default;

        /// A user-declared copy constructor suppresses the implicit move
        /// constructor, which would make every "move" of a Ratio copy
        /// `_get_cost`; restore the move explicitly.
        Ratio(Ratio&&) = default;

        /**
         * @brief Constraint value of edge @p e at point @p x.
         * @param[in] e edge of `_G` (must not be a self-loop).
         * @param[in] x query point; reads x(0) or x(1) depending on orientation.
         * @return x(0) - cost(e) if u < v, else cost(e) - x(1).
         */
        auto eval(const edge_t& e, const Arr& x) const -> double
        {
            const auto [u, v] = this->_G.end_points(e);
            const auto cost = this->_get_cost(e);
            assert(u != v);
            return (u < v) ? x(0) - cost : cost - x(1);
        }

        /**
         * @brief Gradient of eval() with respect to x (independent of x).
         * @param[in] e edge of `_G` (must not be a self-loop).
         * @return (1, 0) if u < v, else (0, -1).
         */
        auto grad(const edge_t& e, const Arr& /* x */) const -> Arr
        {
            const auto [u, v] = this->_G.end_points(e);
            assert(u != v);
            return (u < v) ? Arr {1., 0.} : Arr {0., -1.};
        }
    };

    network_oracle<Graph, Container, Ratio> _network;

  public:
    /**
     * @param[in]     G        graph; passed by reference to the network oracle.
     * @param[in,out] u        container forwarded to the network oracle.
     * @param[in]     get_cost edge-cost callable (sink parameter, moved in).
     */
    optscaling_oracle(const Graph& G, Container& u, Fn get_cost)
        : _network(G, u, Ratio {G, std::move(get_cost)})
    {
    }

    /// Copies must be requested explicitly (direct-initialization).
    explicit optscaling_oracle(const optscaling_oracle&) = default;
    // optscaling_oracle& operator=(const optscaling_oracle&) = delete;
    // optscaling_oracle(optscaling_oracle&&) = default;

    /**
     * @brief Evaluate the oracle at @p x against the current best value @p t.
     *
     * @param[in]     x query point; only x(0) and x(1) are read here.
     * @param[in,out] t best objective value so far; set to x(0) - x(1) when
     *                  the network oracle reports no cut and that value is
     *                  strictly smaller than @p t.
     * @return ((g, f), updated):
     *         - the network oracle's cut and `false` if it reports one;
     *         - ((1, -1), 0) and `true` if @p t was lowered;
     *         - ((1, -1), x(0) - x(1) - t) and `false` otherwise.
     */
    auto operator()(const Arr& x, double& t) -> std::tuple<Cut, bool>
    {
        // Non-const so the cut can be moved out instead of copied.
        auto cut = this->_network(x);
        if (cut)
        {
            return {std::move(*cut), false};
        }
        const auto s = x(0) - x(1);
        const auto fj = s - t;
        if (fj < 0)
        {
            t = s;
            return {{Arr {1., -1.}, 0.}, true};
        }
        return {{Arr {1., -1.}, fj}, false};
    }
};