Segment Tree
Code
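The node type must provide an associative `operator+` (the monoid operation) and a default constructor that returns the monoid's neutral element. The operation does not need to be commutative: `sum` combines values in left-to-right order.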
Segment Tree
//{{{ Segment Tree
// Iterative (bottom-up) segment tree over a monoid T.
//   - T() must be the monoid's neutral element;
//   - a + b must be the associative monoid operation. It need not be
//     commutative: sum() combines values in left-to-right order.
// Leaves live at data[N .. 2N-1]; internal node i (1 <= i < N) holds
// data[2i] + data[2i+1]; data[0] is never read.
// Positions are 0-based and are not bounds-checked.
template<typename T>
class SegmentTree {
    int N;
    vector<T> data;
public:
    // Creates N positions, all equal to T().
    explicit SegmentTree(int N) : N(N), data(2*N) {}

    // Creates a tree holding a copy of A, built bottom-up in O(|A|).
    // data must be sized (2 * |A|) before any leaf is written.
    explicit SegmentTree(vector<T> const& A)
        : N(static_cast<int>(A.size())), data(2*A.size()) {
        for (int i = 0; i < N; i++) data[N + i] = A[i];
        // Children 2i and 2i+1 have larger indices than i, so a descending
        // sweep always combines children that are already built.
        for (int i = N - 1; i > 0; i--) data[i] = data[2*i] + data[2*i+1];
    }

    // Assigns val to position p and recomputes its ancestors: O(log N).
    void set(int p, T const& val) {
        for (data[p += N] = val; p /= 2;)
            data[p] = data[2*p] + data[2*p+1];
    }

    // Returns the value stored at position p.
    T get(int p) const {
        return data[p + N];
    }

    // Replaces position p with get(p) + val.
    void add(int p, T const& val) {
        set(p, get(p) + val);
    }

    // Returns a[l] + a[l+1] + ... + a[r] (both bounds inclusive),
    // or T() when l > r.
    T sum(int l, int r) const {
        // Separate left/right accumulators keep the operands in array order.
        T rl = T(), rr = T();
        for (l += N, r += N + 1; l < r; l /= 2, r /= 2) {
            if (l & 1) rl = rl + data[l++];
            if (r & 1) rr = data[--r] + rr;
        }
        return rl + rr;
    }
};
//}}}
Tested on Library Checker - Point Add Range Sum [Submission]
Code (Simple)
This segment tree has a simpler interface in some cases, but calling `merge` through `std::function` adds a little overhead.
It runs about 15% slower on Library Checker (Yosupo), though it is probably still faster than a recursive segment tree.
If you're doing simple operations on integers, it will probably result in shorter code. With the implementation above, a max segment tree needs a custom node type:
// Node for a max segment tree: operator+ keeps the larger value and a
// default-constructed node holds -INF.
// NOTE(review): -INF is only a true neutral element if every stored value is
// >= -INF; INF is defined outside this snippet.
struct MaxNode {
    int x = -INF;

    MaxNode() = default;
    // Non-explicit, so an int converts implicitly to a MaxNode.
    MaxNode(int v) : x(v) {}

    friend MaxNode operator+(MaxNode lhs, MaxNode rhs) {
        return lhs.x < rhs.x ? rhs : lhs;
    }
};
// Max segment tree over N positions; each starts as MaxNode(), i.e. -INF.
SegmentTree<MaxNode> S(N);
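With the simple version, no custom node type is needed:
SegmentTreeSimple<int> S(N, -INF, [](int a, int b) { return max(a, b); });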
Segment Tree Simple
//{{{ Segment Tree Simple
// Iterative (bottom-up) segment tree whose combine operation is supplied at
// runtime as a std::function instead of T's operator+.
//   - merge must be associative and neutral must be its identity element;
//   - merge need not be commutative: sum() combines in left-to-right order.
// Leaves live at data[N .. 2N-1]; internal node i (1 <= i < N) holds
// merge(data[2i], data[2i+1]); data[0] is never read.
// Positions are 0-based and are not bounds-checked.
template<typename T>
class SegmentTreeSimple {
    int N;
    T neutral;
    vector<T> data;
    function<T(T,T)> merge;
public:
    // Creates N positions, all equal to `neutral`.
    // Members are initialized (not assigned), so T needs no default constructor.
    SegmentTreeSimple(int N, T neutral, function<T(T,T)> merge)
        : N(N), neutral(neutral), data(2*N, neutral), merge(merge) {}

    // Creates a tree holding a copy of A, built bottom-up in O(|A|).
    SegmentTreeSimple(vector<T> const& A, T neutral, function<T(T,T)> merge)
        : N(static_cast<int>(A.size())), neutral(neutral),
          data(2*A.size(), neutral), merge(merge) {
        for (int i = 0; i < N; i++) data[N + i] = A[i];
        // Children 2i and 2i+1 have larger indices than i, so a descending
        // sweep always combines children that are already built.
        for (int i = N - 1; i > 0; i--)
            data[i] = this->merge(data[2*i], data[2*i+1]);
    }

    // Assigns val to position p and recomputes its ancestors: O(log N).
    void set(int p, T const& val) {
        for (data[p += N] = val; p /= 2;)
            data[p] = merge(data[2*p], data[2*p+1]);
    }

    // Returns the value stored at position p.
    T get(int p) const {
        return data[p + N];
    }

    // Replaces position p with get(p) + val. This uses T's operator+, not
    // merge, so on a max tree it adds val to the element.
    void add(int p, T const& val) {
        set(p, get(p) + val);
    }

    // Returns merge(a[l], ..., a[r]) over the inclusive range [l, r],
    // or merge(neutral, neutral) when l > r.
    T sum(int l, int r) const {
        // Separate left/right accumulators keep the operands in array order.
        T rl = neutral, rr = neutral;
        for (l += N, r += N + 1; l < r; l /= 2, r /= 2) {
            if (l & 1) rl = merge(rl, data[l++]);
            if (r & 1) rr = merge(data[--r], rr);
        }
        return merge(rl, rr);
    }
};
//}}}
Tested on Library Checker - Point Add Range Sum [Submission]