Skip to content

Segment Tree

Code

The node type must provide an associative `operator+` and a default constructor that yields the monoid's neutral element. `sum(l, r)` queries the closed range [l, r].

Segment Tree
//{{{ Segment Tree
/**
 * Iterative (bottom-up) segment tree over a monoid.
 *
 * T must provide an associative `operator+` and a default constructor that
 * yields the monoid's neutral element. `operator+` need not be commutative:
 * queries combine elements strictly from left to right.
 *
 * Layout: leaves live in data[N, 2N); internal node p (1 <= p < N) holds
 * data[2p] + data[2p+1]. data[0] is unused.
 */
template<typename T>
class SegmentTree {
  int N;
  vector<T> data;
public:
  /// Creates a tree of N elements, all equal to the neutral element T().
  explicit SegmentTree(int N) : N(N), data(2*N) {}

  /// Builds a tree holding a copy of A in O(|A|).
  explicit SegmentTree(vector<T> const& A) : SegmentTree(static_cast<int>(A.size())) {
    // The delegated constructor sizes `data` to 2N before the leaves are written.
    for (int i = 0; i < N; i++) data[i + N] = A[i];
    // Fill internal nodes from the bottom up; children have larger indices.
    for (int i = N - 1; i > 0; i--) data[i] = data[2*i] + data[2*i+1];
  }

  /// Sets element p (0-based, 0 <= p < N) to val and recomputes its ancestors. O(log N).
  void set(int p, T const& val) {
    for (data[p += N] = val; p /= 2;)
      data[p] = data[2*p] + data[2*p+1];
  }

  /// Returns element p (0-based, 0 <= p < N). O(1).
  T get(int p) const {
    return data[p + N];
  }

  /// Replaces element p with get(p) + val. O(log N).
  void add(int p, T const& val) {
    set(p, get(p) + val);
  }

  /**
   * Returns A[l] + A[l+1] + ... + A[r] over the CLOSED range [l, r], in order.
   * Requires 0 <= l and r < N; if l > r the result is T() + T() (the neutral element).
   * O(log N).
   */
  T sum(int l, int r) const {
    // rl accumulates nodes from the left edge, rr from the right edge,
    // so non-commutative operators are combined in the correct order.
    T rl = T(), rr = T();
    for (l += N, r += N + 1; l < r; l /= 2, r /= 2) {
      if (l & 1) rl = rl + data[l++];
      if (r & 1) rr = data[--r] + rr;
    }
    return rl + rr;
  }
};
//}}}

Tested on Library Checker - Point Add Range Sum [Submission]

Code (Simple)

This segment tree has a simpler interface in some cases, but calling through `std::function` adds a little overhead: it runs about 15% slower on Yosupo (Library Checker), though it is probably still faster than a recursive segment tree.

If you're only doing simple operations on integers, this version will probably result in shorter code. With the implementation above, a max segment tree needs a custom node type:

// Node for a max segment tree; assumes INF is a large constant defined elsewhere.
struct MaxNode {
    int x;
    // Default construction must give the neutral element for max, i.e. -INF.
    MaxNode() : x(-INF) {}
    MaxNode(int x) : x(x) {}
    // The "sum" of two nodes is their maximum.
    friend MaxNode operator+(MaxNode a, MaxNode b) {
        return MaxNode(max(a.x, b.x));
    }
};
SegmentTree<MaxNode> S(N);

With this version, you can simply write:

// Max segment tree: neutral element -INF, merge operation = max.
SegmentTreeSimple<int> S(N, -INF, [](int a, int b) { return max(a, b); });
Segment Tree Simple
//{{{ Segment Tree Simple
/**
 * Iterative (bottom-up) segment tree whose monoid is supplied at runtime.
 *
 * `merge` must be associative and `neutral` must be its identity element.
 * `merge` need not be commutative: queries combine elements left to right.
 * T does not need a default constructor (every slot starts as `neutral`).
 *
 * Layout: leaves live in data[N, 2N); internal node p (1 <= p < N) holds
 * merge(data[2p], data[2p+1]).
 */
template<typename T>
class SegmentTreeSimple {
  int N;
  T neutral;
  vector<T> data;
  function<T(T,T)> merge;
public:
  /// Creates a tree of N elements, all equal to `neutral`.
  SegmentTreeSimple(int N, T neutral, function<T(T,T)> merge)
    : N(N), neutral(neutral), data(2*N, neutral), merge(merge) {}

  /// Builds a tree holding a copy of A in O(|A|) merges.
  SegmentTreeSimple(vector<T> const& A, T neutral, function<T(T,T)> merge)
    : SegmentTreeSimple(static_cast<int>(A.size()), neutral, merge) {
    for (int i = 0; i < N; i++) data[i + N] = A[i];
    // Fill internal nodes from the bottom up; children have larger indices.
    for (int i = N - 1; i > 0; i--)
      data[i] = this->merge(data[2*i], data[2*i+1]);
  }

  /// Sets element p (0-based, 0 <= p < N) to val and recomputes its ancestors. O(log N).
  void set(int p, T const& val) {
    for (data[p += N] = val; p /= 2;)
      data[p] = merge(data[2*p], data[2*p+1]);
  }

  /// Returns element p (0-based, 0 <= p < N). O(1).
  T get(int p) const {
    return data[p + N];
  }

  /// Replaces element p with get(p) + val. Note: uses T's operator+, not `merge`.
  void add(int p, T const& val) {
    set(p, get(p) + val);
  }

  /**
   * Returns merge(A[l], ..., A[r]) over the CLOSED range [l, r], in order.
   * Requires 0 <= l and r < N; if l > r the result is merge(neutral, neutral).
   * O(log N).
   */
  T sum(int l, int r) const {
    // rl accumulates nodes from the left edge, rr from the right edge,
    // so non-commutative merges are applied in the correct order.
    T rl = neutral, rr = neutral;
    for (l += N, r += N + 1; l < r; l /= 2, r /= 2) {
      if (l & 1) rl = merge(rl, data[l++]);
      if (r & 1) rr = merge(data[--r], rr);
    }
    return merge(rl, rr);
  }
};
//}}}

Tested on Library Checker - Point Add Range Sum [Submission]