この記事について
差分更新型ビームサーチライブラリの実装例について説明します。
差分更新型のビームサーチについては 高速なビームサーチが欲しい!!! などで既に解説されているため、被る部分については詳しく説明しません。
この記事に書いたソースコード をプログラミングコンテスト で自由に使っていただいて構いませんが、当ブログは損害などに対する責任を負いかねますのでご了承ください。
前提知識
差分更新型のビームサーチとは
木上のビームサーチ、Euler Tour ビームサーチとも呼ばれています1 。
ビームサーチは、幅優先探索 に枝刈りを取り入れた手法です。探索木における深さ の頂点集合から深さ の頂点集合を生成し、その中から評価値が高い上位 個を選択します。 のことをビーム幅と呼びます。
ビームサーチの例( 、赤色が探索中のノード、橙色が採用されたノード、灰色が不採用のノード)
愚直なビームサーチでは、頂点ごとに独立した状態を作成します。実装は比較的楽だと思いますが、状態や履歴をコピーする必要が出てくるため、実行速度が遅くなりやすいです2 。
そこで、探索木を明示的に作成し、Euler Tour の順序で1つの状態を更新するようにしたものが差分更新型のビームサーチです。探索木の葉を訪れたときに新しい葉の候補を生成します。
Euler Tour の例
状態遷移の履歴は探索木から復元できるため、状態のコピーだけでなく履歴のコピーも省略できます。
深さが の1つの葉を探索するのに最悪で 回の遷移を必要としますが、実際には多くの葉で近い先祖を共有するため、遷移回数は平均して 回よりもかなり少なくなります。
差分更新型ビームサーチは愚直なビームサーチよりも高速に動作する場合が多く3 、実装がやや複雑なため、ライブラリを作成しました。
実装
ライブラリ全体を折りたたみに記載しました。コードを読む場合は下の方にある beam_search
関数を最初に読むと分かりやすいかもしれません。
ビームサーチライブラリ
#include <bits/stdc++.h>
#include <atcoder/all>
using namespace std ;
namespace beam_search {
template <class T>
class ObjectPool {
public :
T& operator [](int i) {
return data_[i];
}
void reserve (size_t capacity ) {
data_.reserve (capacity );
}
int push (const T& x) {
if (garbage_.empty ()) {
data_.push_back (x);
return data_.size () - 1 ;
} else {
int i = garbage_.top ();
garbage_.pop ();
data_[i] = x;
return i;
}
}
void pop (int i) {
garbage_.push (i);
}
size_t size () {
return data_.size ();
}
private :
vector <T> data_;
stack <int > garbage_;
};
template <class Key, class T>
struct HashMap {
public :
explicit HashMap (uint32_t n) {
n_ = n;
valid_.resize (n_, false );
data_.resize (n_);
}
pair <bool ,int > get_index (Key key) const {
Key i = key % n_;
while (valid_[i]) {
if (data_[i].first == key) {
return {true , i};
}
if (++i == n_) {
i = 0 ;
}
}
return {false , i};
}
void set (int i, Key key, T value) {
valid_[i] = true ;
data_[i] = {key, value};
}
T get (int i) const {
assert (valid_[i]);
return data_[i].second;
}
void clear () {
fill (valid_.begin (), valid_.end (), false );
}
private :
uint32_t n_;
vector <bool > valid_;
vector <pair <Key,T>> data_;
};
using Hash = uint32_t ; TODO
struct Action {
TODO
Action () {
TODO
}
};
using Cost = int ; TODO
struct Evaluator {
TODO
Evaluator () {
TODO
}
Cost evaluate () const {
TODO
}
};
struct Candidate {
Action action;
Evaluator evaluator;
Hash hash ;
int parent;
Cost cost;
Candidate (Action action, Evaluator evaluator, Hash hash , int parent, Cost cost) :
action (action),
evaluator (evaluator),
hash (hash ),
parent (parent),
cost (cost) {}
};
struct Config {
int max_turn;
size_t beam_width;
size_t nodes_capacity;
uint32_t hash_map_capacity;
};
using MaxSegtree = atcoder::segtree<
pair <Cost,int >,
[](pair <Cost,int > a, pair <Cost,int > b){
if (a.first >= b.first) {
return a;
} else {
return b;
}
},
[]() { return make_pair (numeric_limits <Cost>::min (), -1 ); }
>;
class Selector {
public :
explicit Selector (const Config& config) :
hash_to_index_ (config.hash_map_capacity)
{
beam_width = config.beam_width;
candidates_.reserve (beam_width);
full_ = false ;
st_original_.resize (beam_width);
}
void push (Action action, const Evaluator& evaluator, Hash hash , int parent, bool finished) {
Cost cost = evaluator.evaluate ();
if (finished) {
finished_candidates_.emplace_back (Candidate (action, evaluator, hash , parent, cost));
return ;
}
if (full_ && cost >= st_.all_prod ().first) {
return ;
}
auto [valid, i] = hash_to_index_.get_index (hash );
if (valid) {
int j = hash_to_index_.get (i);
if (hash == candidates_[j].hash ) {
if (cost < candidates_[j].cost) {
candidates_[j] = Candidate (action, evaluator, hash , parent, cost);
if (full_) {
st_.set (j, {cost, j});
}
}
return ;
}
}
if (full_) {
int j = st_.all_prod ().second;
hash_to_index_.set (i, hash , j);
candidates_[j] = Candidate (action, evaluator, hash , parent, cost);
st_.set (j, {cost, j});
} else {
hash_to_index_.set (i, hash , candidates_.size ());
candidates_.emplace_back (Candidate (action, evaluator, hash , parent, cost));
if (candidates_.size () == beam_width) {
construct_segment_tree ();
}
}
}
const vector <Candidate>& select () const {
return candidates_;
}
bool have_finished () const {
return !finished_candidates_.empty ();
}
vector <Candidate> get_finished_candidates () const {
return finished_candidates_;
}
void clear () {
candidates_.clear ();
hash_to_index_.clear ();
full_ = false ;
}
private :
size_t beam_width;
vector <Candidate> candidates_;
HashMap<Hash,int > hash_to_index_;
bool full_;
vector <pair <Cost,int >> st_original_;
MaxSegtree st_;
vector <Candidate> finished_candidates_;
void construct_segment_tree () {
full_ = true ;
for (size_t i = 0 ; i < beam_width; ++i) {
st_original_[i] = {candidates_[i].cost, i};
}
st_ = MaxSegtree (st_original_);
}
};
class State {
public :
explicit State () {
TODO
}
void expand (const Evaluator& evaluator, Hash hash , int parent, Selector& selector) {
TODO
}
void move_forward (Action action) {
TODO
}
void move_backward (Action action) {
TODO
}
private :
TODO
};
struct Node {
Action action;
Evaluator evaluator;
Hash hash ;
int parent, child, left , right ;
Node (Action action, const Evaluator& evaluator, Hash hash ) :
action (action),
evaluator (evaluator),
hash (hash ),
parent (-1 ),
child (-1 ),
left (-1 ),
right (-1 ) {}
Node (const Candidate& candidate, int right ) :
action (candidate.action),
evaluator (candidate.evaluator),
hash (candidate.hash ),
parent (candidate.parent),
child (-1 ),
left (-1 ),
right (right ) {}
};
class Tree {
public :
explicit Tree (const State& state, size_t nodes_capacity, const Node& root) :
state_ (state)
{
nodes_.reserve (nodes_capacity);
root_ = nodes_.push (root);
}
void dfs (Selector& selector) {
update_root ();
int v = root_;
while (true ) {
v = move_to_leaf (v);
state_.expand (nodes_[v].evaluator, nodes_[v].hash , v, selector);
v = move_to_ancestor (v);
if (v == root_) {
break ;
}
v = move_to_right (v);
}
}
vector <Action> get_path (int v) {
vector <Action> path ;
while (nodes_[v].parent != -1 ) {
path .push_back (nodes_[v].action);
v = nodes_[v].parent;
}
reverse (path .begin (), path .end ());
return path ;
}
int add_leaf (const Candidate& candidate) {
int parent = candidate.parent;
int sibling = nodes_[parent].child;
int v = nodes_.push (Node (candidate, sibling));
nodes_[parent].child = v;
if (sibling != -1 ) {
nodes_[sibling].left = v;
}
return v;
}
void remove_if_leaf (int v) {
if (nodes_[v].child == -1 ) {
remove_leaf (v);
}
}
int get_best_leaf (const vector <int >& last_nodes) {
assert (!last_nodes.empty ());
int ret = last_nodes[0 ];
for (int v : last_nodes) {
if (nodes_[v].evaluator.evaluate () < nodes_[ret].evaluator.evaluate ()) {
ret = v;
}
}
return ret;
}
private :
State state_;
ObjectPool<Node> nodes_;
int root_;
void update_root () {
int child = nodes_[root_].child;
while (child != -1 && nodes_[child].right == -1 ) {
root_ = child;
state_.move_forward (nodes_[child].action);
child = nodes_[child].child;
}
}
int move_to_leaf (int v) {
int child = nodes_[v].child;
while (child != -1 ) {
v = child;
state_.move_forward (nodes_[child].action);
child = nodes_[child].child;
}
return v;
}
int move_to_ancestor (int v) {
while (v != root_ && nodes_[v].right == -1 ) {
state_.move_backward (nodes_[v].action);
v = nodes_[v].parent;
}
return v;
}
int move_to_right (int v) {
state_.move_backward (nodes_[v].action);
v = nodes_[v].right ;
state_.move_forward (nodes_[v].action);
return v;
}
void remove_leaf (int v) {
while (true ) {
int left = nodes_[v].left ;
int right = nodes_[v].right ;
if (left == -1 ) {
int parent = nodes_[v].parent;
if (parent == -1 ) {
cerr << "ERROR: root is removed" << endl ;
exit (-1 );
}
nodes_.pop (v);
nodes_[parent].child = right ;
if (right != -1 ) {
nodes_[right ].left = -1 ;
return ;
}
v = parent;
} else {
nodes_.pop (v);
nodes_[left ].right = right ;
if (right != -1 ) {
nodes_[right ].left = left ;
}
return ;
}
}
}
};
vector <Action> beam_search (const Config& config, State state, Node root) {
Tree tree (state, config.nodes_capacity, root);
vector <int > curr_nodes;
curr_nodes.reserve (config.beam_width);
vector <int > next_nodes;
next_nodes.reserve (config.beam_width);
Selector selector (config);
for (int turn = 0 ; turn < config.max_turn; ++turn) {
tree.dfs (selector);
if (selector.have_finished ()) {
Candidate candidate = selector.get_finished_candidates ()[0 ];
vector <Action> ret = tree.get_path (candidate.parent);
ret.push_back (candidate.action);
return ret;
}
for (const Candidate& candidate : selector.select ()) {
next_nodes.push_back (tree.add_leaf (candidate));
}
if (next_nodes.empty ()) {
cerr << "ERROR: Failed to find any valid solution" << endl ;
return {};
}
for (int v : curr_nodes) {
tree.remove_if_leaf (v);
}
swap (curr_nodes, next_nodes);
next_nodes.clear ();
selector.clear ();
}
int best_leaf = tree.get_best_leaf (curr_nodes);
return tree.get_path (best_leaf);
}
}
それぞれの構造体やクラスについて見ていきます。
Object Pool
ビームサーチと直接の関係がないクラスです。
配列にオブジェクトを保存し、削除したオブジェクトの場所を再利用します。
また、std::vector
などと同様に reserve
でメモリを確保できるようにしました4 。
例を用いた説明
最初は長さ4の空の配列とします。
3, 1, 4 を順に追加します。追加した場所である 0, 1, 2 を順に返します。
a[0]
a[1]
a[2]
a[3]
3
1
4
a[1] を削除します。
5 を追加します。a[1] と a[3] のどちらに追加してもよいのですが、最後に使用した場所を優先的に使用することにします。今回の例だと a[1] です。追加した場所である 1 を返します。
a[0]
a[1]
a[2]
a[3]
3
5
4
削除した場所のインデックスを保持することでこのような挙動を実装することができます。
差分更新型のビームサーチでは、ノードの追加と削除を頻繁に繰り返します。Object Pool を使うことで次のようなメリットを享受できます。
メモリの再利用により空間計算量を削減できる。
メモリ上の連続した領域を使用するので、データがキャッシュに乗りやすくなる。
十分な大きさのメモリを最初に確保することで、メモリの再割り当てをなくすことができる。
HashMap
Selector でハッシュ値 が重複した候補を除去するところで使います。
open addressing を使用し、インデックスが衝突したときは linear probing を行いました。(ハッシュ値 に対する)ハッシュ関数 がなく、挿入と一括削除しか行わないので、std::unordered_set
よりも単純で高速に動作します。
Hash
が整数型でないときは整数型に変換するか、HashMap を std::unordered_set
に変更する必要があります。
ちなみに、元々私は std::unordered_set
を使っていて、saharan さんのツイート を読んで参考にしたら速度が少し上がりました5 。
Action
状態遷移に必要な情報をまとめます。ここでいう状態遷移は、親からの移動と親への移動の両方を指しています。
Evaluator
評価値を計算するための構造体です。
基本的な実装は次のようなものです。
struct Evaluator {
Cost cost;
Evaluator (Cost cost) : cost (cost) {}
Cost evaluate () const {
return cost;
}
};
Evaluator::evaluate
でコストを返します。コストが低いほど採用されやすくなります。
補足: Action や Evaluator で何を保持すべきか
既に述べたように、状態遷移に必要な情報を Action で保持し、評価結果の比較に必要な情報を Evaluator で保持することを想定しています。
一方で、Action と Evaluator に、より多くの情報を持たせることもできます。
状態が変数 x
をメンバとして保持し、状態を更新するときに毎度 x
を更新するものとします。このとき、x
を状態ではなく Action や Evaluator が保持するようにすれば、Euler Tour における x
の更新を省略できます。状態の関数内で x
を使用したいときには、Action や Evaluator のメンバにアクセスすればよいです。x
の後退処理(Euler Tour における子から親への遷移)を実装する必要がなくなるというメリットもあります。
例えば、複数の評価項目があるときに、Evaluator で各評価項目の値を保持するということが考えられます。また、後退処理がなくなるため、浮動小数点数 も扱いやすくなります6 。
一方で、Action や Evaluator の使用メモリが小さいほどよいという側面もあるため、全ての変数を Action や Evaluator に保持すればいいわけではありません。更新が面倒で使用メモリが少ない変数だけを Action や Evaluator に保持するとよいと思います。
例えば、状態内で探索木の深さを管理する場合、深さを更新するときにインクリメントやデクリメントという非常に軽い処理しか行われないため、深さは状態に保持すればよいと思います。
ちなみに、Action と State を空にして Evaluator で全ての情報を保持するようにすると、愚直なビームサーチらしくなります。
Candidate
新しいノードの候補を表現します。
Selector
Candidate の中から新しいノードとして採用するものを選びます。採用された Candidate を元に新しいノードが追加されます。
新しいノードを生成してから不要なものを削除するという実装も考えられますが、木に対する操作は定数倍が重いため、Candidate 構造体を経由する実装になっています。
コードのコメントに書かれているように、ハッシュ値 が一致した候補については評価が最もよいものに限定し、その中から評価がよい候補をビーム幅の個数だけ選びます。ハッシュ値 の重複を検出するために自作の HashMap を使ったり、ソートをなくすために atcoder::segtree
を削除可能な優先度付きキューとして使ったり、高速化を意識して実装しました。
ハッシュ値 以外で多様性を確保したい場合、Selector を実装し直す必要があります。例えばスライドパズルなどで「現在のマスの座標が一致するものの中から上位 個を選ぶ」という場合には書き直す必要があります。
State
Euler Tour に沿って更新する情報をまとめたクラスです。問題ごとに各メソッドを実装する必要があります。
ビームサーチの最中に State がコピーされることはないため、空間計算量は大きくても構いません。一方で、各メソッドは速いほどよいです。
Tree
探索木を二重連 鎖木で表現し、木に対する操作をまとめたクラスです。
二重連 鎖木のノードは次のノードへのポインタ7 を持ちます。
状態の更新順序は Euler Tour と一緒です。
Euler Tour の例
一方で、二重連 鎖木上では兄弟間を直接移動するようにします。
二重連 鎖木の遷移
二重連 鎖木のノードは配列を使用しないため、Euler Tour もノードの追加や削除も簡潔に実装できます。
不要なノードは全て削除します。不要なノードというのは、子ができなかった、あるいは子が全て削除されたノードのことです。
高速なビームサーチが欲しい!!! で紹介されているように、根から一本道の部分は反復しないようにします。
枠で囲った範囲で状態遷移を行う
上図では探索したノードが全て描かれていますが、実際には灰色のノードは作成されず、さらに赤色のノードを子孫として持たない6つの橙色のノードは削除されていることに注意してください。
beam_search
ビームサーチを実行する関数です。ライブラリの外からこの関数を呼び出します。
使用例
TOYOTA Programming Contest 2023 Summer(AtCoder Heuristic Contest 021) の実装例を紹介します。
大まかな方針
番号が小さいボールから揃えます。
番号が最小のボール、あるいはその左上または右上のボールを、左上か右上に移動させます。
紫色のスワップ を遷移の候補とする
評価関数は次のように設定しました。小さいほうがよいです。
ハッシュ値 は、揃えているボールの位置と、既に揃えたボールの位置の集合の2つから生成しました。
ライブラリの使い方を紹介することが目的なので、考察などは省略します。
inline int get_pyramid_index (int x, int y) {
return x * (x - 1 ) / 2 + y;
}
using Hash = uint32_t ;
constexpr Hash hash_mask = ((1U << 23 ) - 1U ) << 9 ;
inline Hash update_target_position (Hash hash , int x, int y) {
return (hash & hash_mask) | get_pyramid_index (x, y);
}
inline Hash update_sorted_position (Hash hash , int x, int y) {
Hash zobrist_hash = get_pyramid_index (x, y);
zobrist_hash |= 512U ;
zobrist_hash *= zobrist_hash * zobrist_hash;
return hash ^ (zobrist_hash & hash_mask);
}
下位9ビットで揃えているボールの位置を保持し、残りのビットで既に揃えたボールの位置の集合の Zobrist hashing を保持しました。
Action
struct Action {
int xyxy;
Action (int x1, int y1, int x2, int y2) {
xyxy = x1 | (y1 << 8 ) | (x2 << 16 ) | (y2 << 24 );
}
tuple <int ,int ,int ,int > decode () const {
return {xyxy & 255 , (xyxy >> 8 ) & 255 , (xyxy >> 16 ) & 255 , xyxy >> 24 };
}
};
AHC021の場合、スワップ 位置が分かれば盤面の更新が行えるため、スワップ する2つの位置を保持しました。
4つの整数を1つの int
にエンコード し、メモリ使用量を減らしています8 。
Evaluator
using Cost = int ;
constexpr int target_coefficient = 600 ;
struct Evaluator {
int target_ball;
int potential;
Evaluator (int target_ball, int potential) :
target_ball (target_ball),
potential (potential) {}
Cost evaluate () const {
return potential - target_coefficient * target_ball;
}
};
2つの評価項目を保持しました。
State
更新部分だけ説明します。他の部分を読みたい場合は提出コードをご覧ください。
class State {
public :
void move_forward (Action action) {
auto [x1, y1, x2, y2] = action.decode ();
swap_balls (x1, y1, x2, y2);
}
void move_backward (Action action) {
auto [x1, y1, x2, y2] = action.decode ();
swap_balls (x1, y1, x2, y2);
}
private :
vector <vector <int >> b_;
array <pair <int ,int >,m> positions_;
void swap_balls (int x1, int y1, int x2, int y2) {
int b1 = b_[x1][y1];
int b2 = b_[x2][y2];
b_[x1][y1] = b2;
b_[x2][y2] = b1;
positions_[b2] = {x1, y1};
positions_[b1] = {x2, y2};
}
};
b_
は盤面を表しています。positions_[i]
はボール i
の位置を表します。 positions_[b_[x][y]] = {x, y}
が成り立ちます。positions_
は、あるボールを揃えた後に次のボールの位置を得るときなどに使われます。
Evaluator
が揃えたボールの数を保持するため、現在揃えているボールの番号を保持・更新する必要がないことに注意してください。
今回は2つのボールを入れ替えるだけなので move_forward
と move_backward
が等しくなっていますが、一般的には異なります。
提出コード
ビーム幅を3500に設定しました。
提出言語は C++ 20 (Clang 16.0.6) です。私のビームサーチライブラリ(の主に Selector)は GCC よりも Clang のほうが高速に動作するようでした。
atcoder.jp
他の使用例
ゲーム実況者Xの挑戦 の提出コードです。解説などは省略します。
atcoder.jp
Euler Tour の辺を保持する実装 (2024/02/07 追記)
二重連 鎖木や Object Pool を使うのではなく、探索木の有向辺を Euler Tour の順序で保持するほうが速いという話があり、実装してみました。
状態を差分計算するときは Euler Tour の配列に前から順にアクセスします。
探索木を更新するときは、辺の追加や削除をしながら、別の配列に Euler Tour をコピーしています。
ビームサーチライブラリ
#include <bits/stdc++.h>
using namespace std ;
namespace beam_search {
struct Config {
int max_turn;
size_t beam_width;
size_t tour_capacity;
uint32_t hash_map_capacity;
};
template <class Key, class T>
struct HashMap {
public :
explicit HashMap (uint32_t n) {
if (n % 2 == 0 ) {
++n;
}
n_ = n;
valid_.resize (n_, false );
data_.resize (n_);
}
pair <bool ,int > get_index (Key key) const {
Key i = key % n_;
while (valid_[i]) {
if (data_[i].first == key) {
return {true , i};
}
if (++i == n_) {
i = 0 ;
}
}
return {false , i};
}
void set (int i, Key key, T value) {
valid_[i] = true ;
data_[i] = {key, value};
}
T get (int i) const {
assert (valid_[i]);
return data_[i].second;
}
void clear () {
fill (valid_.begin (), valid_.end (), false );
}
private :
uint32_t n_;
vector <bool > valid_;
vector <pair <Key,T>> data_;
};
using Hash = uint32_t ; TODO
struct Action {
TODO
Action () {
TODO
}
bool operator ==(const Action& other) const {
TODO
}
};
using Cost = int ;
struct Evaluator {
TODO
Evaluator () {
TODO
}
Cost evaluate () const {
TODO
}
};
struct Candidate {
Action action;
Evaluator evaluator;
Hash hash ;
int parent;
Candidate (Action action, Evaluator evaluator, Hash hash , int parent) :
action (action),
evaluator (evaluator),
hash (hash ),
parent (parent) {}
};
class Selector {
public :
explicit Selector (const Config& config) :
hash_to_index_ (config.hash_map_capacity)
{
beam_width = config.beam_width;
candidates_.reserve (beam_width);
full_ = false ;
costs_.resize (beam_width);
for (size_t i = 0 ; i < beam_width; ++i) {
costs_[i] = {0 , i};
}
}
void push (const Candidate& candidate, bool finished) {
if (finished) {
finished_candidates_.emplace_back (candidate);
return ;
}
Cost cost = candidate.evaluator.evaluate ();
if (full_ && cost >= st_.all_prod ().first) {
return ;
}
auto [valid, i] = hash_to_index_.get_index (candidate.hash );
if (valid) {
int j = hash_to_index_.get (i);
if (candidate.hash == candidates_[j].hash ) {
if (full_) {
if (cost < st_.get (j).first) {
candidates_[j] = candidate;
st_.set (j, {cost, j});
}
} else {
if (cost < costs_[j].first) {
candidates_[j] = candidate;
costs_[j].first = cost;
}
}
return ;
}
}
if (full_) {
int j = st_.all_prod ().second;
hash_to_index_.set (i, candidate.hash , j);
candidates_[j] = candidate;
st_.set (j, {cost, j});
} else {
int j = candidates_.size ();
hash_to_index_.set (i, candidate.hash , j);
candidates_.emplace_back (candidate);
costs_[j].first = cost;
if (candidates_.size () == beam_width) {
full_ = true ;
st_ = MaxSegtree (costs_);
}
}
}
const vector <Candidate>& select () const {
return candidates_;
}
bool have_finished () const {
return !finished_candidates_.empty ();
}
vector <Candidate> get_finished_candidates () const {
return finished_candidates_;
}
Candidate calculate_best_candidate () const {
if (full_) {
size_t best = 0 ;
for (size_t i = 0 ; i < beam_width; ++i) {
if (st_.get (i).first < st_.get (best).first) {
best = i;
}
}
return candidates_[best];
} else {
size_t best = 0 ;
for (size_t i = 0 ; i < candidates_.size (); ++i) {
if (costs_[i].first < costs_[best].first) {
best = i;
}
}
return candidates_[best];
}
}
void clear () {
candidates_.clear ();
hash_to_index_.clear ();
full_ = false ;
}
private :
using MaxSegtree = atcoder::segtree<
pair <Cost,int >,
[](pair <Cost,int > a, pair <Cost,int > b){
if (a.first >= b.first) {
return a;
} else {
return b;
}
},
[]() { return make_pair (numeric_limits <Cost>::min (), -1 ); }
>;
size_t beam_width;
vector <Candidate> candidates_;
HashMap<Hash,int > hash_to_index_;
bool full_;
vector <pair <Cost,int >> costs_;
MaxSegtree st_;
vector <Candidate> finished_candidates_;
};
class State {
public :
explicit State () {
TODO
}
pair <Evaluator,Hash> make_initial_node () {
TODO
}
void expand (const Evaluator& evaluator, Hash hash , int parent, Selector& selector) {
TODO
}
void move_forward (Action action) {
TODO
}
void move_backward (Action action) {
TODO
}
private :
TODO
};
class Tree {
public :
explicit Tree (const State& state, const Config& config) :
state_ (state)
{
curr_tour_.reserve (config.tour_capacity);
next_tour_.reserve (config.tour_capacity);
leaves_.reserve (config.beam_width);
buckets_.assign (config.beam_width, {});
}
void dfs (Selector& selector) {
if (curr_tour_.empty ()) {
auto [evaluator, hash ] = state_.make_initial_node ();
state_.expand (evaluator, hash , 0 , selector);
return ;
}
for (auto [leaf_index, action] : curr_tour_) {
if (leaf_index >= 0 ) {
state_.move_forward (action);
auto & [evaluator, hash ] = leaves_[leaf_index];
state_.expand (evaluator, hash , leaf_index, selector);
state_.move_backward (action);
} else if (leaf_index == -1 ) {
state_.move_forward (action);
} else {
state_.move_backward (action);
}
}
}
void update (const vector <Candidate>& candidates) {
leaves_.clear ();
if (curr_tour_.empty ()) {
for (const Candidate& candidate : candidates) {
curr_tour_.push_back ({(int )leaves_.size (), candidate.action});
leaves_.push_back ({candidate.evaluator, candidate.hash });
}
return ;
}
for (const Candidate& candidate : candidates) {
buckets_[candidate.parent].push_back ({candidate.action, candidate.evaluator, candidate.hash });
}
auto it = curr_tour_.begin ();
while (it->first == -1 && it->second == curr_tour_.back ().second) {
Action action = (it++)->second;
state_.move_forward (action);
direct_road_.push_back (action);
curr_tour_.pop_back ();
}
while (it != curr_tour_.end ()) {
auto [leaf_index, action] = *(it++);
if (leaf_index >= 0 ) {
if (buckets_[leaf_index].empty ()) {
continue ;
}
next_tour_.push_back ({-1 , action});
for (auto [new_action, evaluator, hash ] : buckets_[leaf_index]) {
int new_leaf_index = leaves_.size ();
next_tour_.push_back ({new_leaf_index, new_action});
leaves_.push_back ({evaluator, hash });
}
buckets_[leaf_index].clear ();
next_tour_.push_back ({-2 , action});
} else if (leaf_index == -1 ) {
next_tour_.push_back ({-1 , action});
} else {
auto [old_leaf_index, old_action] = next_tour_.back ();
if (old_leaf_index == -1 ) {
next_tour_.pop_back ();
} else {
next_tour_.push_back ({-2 , action});
}
}
}
swap (curr_tour_, next_tour_);
next_tour_.clear ();
}
vector <Action> calculate_path (int parent, int turn) const {
vector <Action> ret = direct_road_;
ret.reserve (turn);
for (auto [leaf_index, action] : curr_tour_) {
if (leaf_index >= 0 ) {
if (leaf_index == parent) {
ret.push_back (action);
return ret;
}
} else if (leaf_index == -1 ) {
ret.push_back (action);
} else {
ret.pop_back ();
}
}
unreachable ();
}
private :
State state_;
vector <pair <int ,Action>> curr_tour_;
vector <pair <int ,Action>> next_tour_;
vector <pair <Evaluator,Hash>> leaves_;
vector <vector <tuple <Action,Evaluator,Hash>>> buckets_;
vector <Action> direct_road_;
};
vector <Action> beam_search (const Config& config, const State& state) {
Tree tree (state, config);
Selector selector (config);
for (int turn = 0 ; turn < config.max_turn; ++turn) {
tree.dfs (selector);
if (selector.have_finished ()) {
Candidate candidate = selector.get_finished_candidates ()[0 ];
vector <Action> ret = tree.calculate_path (candidate.parent, turn + 1 );
ret.push_back (candidate.action);
return ret;
}
assert (!selector.select ().empty ());
if (turn == config.max_turn - 1 ) {
Candidate best_candidate = selector.calculate_best_candidate ();
vector <Action> ret = tree.calculate_path (best_candidate.parent, turn + 1 );
ret.push_back (best_candidate.action);
return ret;
}
tree.update (selector.select ());
selector.clear ();
}
unreachable ();
}
}
元の二重連 鎖木を使った実装よりも高速に動作するようでした。
atcoder.jp
atcoder.jp
最後に
この記事では私の実装のみを紹介しました。実装した人によって異なる部分があるので調べてみると面白いかもしれません。
最後まで読んでくださりありがとうございました。