package erp.graph; import java.util.*; import erp.util.*; /** An augmented binary tree representing a (non-strictly) monotonically increasing sequence. Notable operations: * Compute the average value of any contiguous interval in O(log N) time * Ensure that all elements with index after i are at least value v in O(log N) time * Limit all values to v in O(log N) time * Scale all values according to v/(1+v) in O(N) time Organization: Each node maintains three values: a minimum value, a maximum value, and a sum. Nominally: * the sum is the sum of all the children's values. * minimum value: no child is allowed to have a smaller value * maximum value of the children A parent that has minimum==maximum overrides any data of its children: all the children have that value. Indexing scheme: we use an in-place allocation that uses exactly 2N nodes to store N values. We use the term "address" to mean the node number. The address for any index is always 2*i. The parents has an address equal to the average of its children. 7 ROOT ADDRESS 3 11 PARENT ADDRESSES 1 5 9 13 PARENT ADDRESSES 0 2 4 6 8 10 12 14 ADDRESS OF LEAF (0) (1) (2) (3) (4) (5) (6) (7) INDEX Since nodes can override their children, most operations proceed from the root and traverse downwards. Most write operations crawl down the tree looking for the nodes that need to be modified in order to achieve the correct changes, then crawl back up the tree, fixing up the parent nodes of the changed children. The run-time performance derives from the fact that the tree contains monotonically increasing values, and that operations deal with contiguous blocks of indices. Any contiguous block of indices has at most log N ancestors that have only the affected indices and descendents. In the figure above, the indices (5-7) has a single ancestor: 11. The indices 1-7 have ancestors 2, 5, and 11. Write operations need only locate these nodes. **/ public final class LearningRatesTree implements LearningRates { DoubleVector vs; int size; static final int STORAGE_PER_NODE = 3; int ROOT_NODE; int NUM_NODES; public LearningRatesTree() { this(4); } public LearningRatesTree(int sz) { ensureSize(sz); } public LearningRatesTree copy() { LearningRatesTree lt = new LearningRatesTree(size); lt.size = size; lt.vs = vs.copy(); lt.ROOT_NODE = ROOT_NODE; lt.NUM_NODES = NUM_NODES; return lt; } public int size() { return size; } int nextPowerOfTwo(int n) { n |= (n>>1); n |= (n>>2); n |= (n>>4); n |= (n>>8); n |= (n>>16); return n + 1; } public void ensureSize(int newsize) { if (newsize <= size) return; // constraint size to be a power of two: this eliminates many // special cases when crawling the tree. if ((newsize&(newsize-1))!=0) newsize = nextPowerOfTwo(newsize); int vssize = 2*STORAGE_PER_NODE*newsize; if (vs == null) vs = new DoubleVector(vssize); vs.addZeros(vssize - vs.size()); int newRootNode = nextPowerOfTwo(newsize)/2 - 1; int newMaxNodes = 2*newsize; if (size > 0) { int node = ROOT_NODE; while (node != newRootNode) { node = parent(node); int rightchild = rightChild(node); int leftchild = leftChild(node); setMin(rightchild, getMax(leftchild)); setMax(rightchild, getMax(leftchild)); setSum(rightchild, numberOfChildren(rightchild) * getMax(leftchild)); fixup(node); } } ROOT_NODE = newRootNode; NUM_NODES = newMaxNodes; size = newsize; } public void debug() { debug(size); } public void debug(int maxsize) { int TAB = 10; ArrayList lines = new ArrayList(); ArrayList tabs = new ArrayList(); for (int step = 0; step < size && tabs.size()!=1; step++) { int tabidx = 0; String line = ""; ArrayList newtabs = new ArrayList(); for (int idx = (1<=0; i--) System.out.println(lines.get(i)); double values[] = dump(); for (int i = 0; i < Math.min(16, size); i++) { System.out.printf("%4d %6.5f %6.5f %10.5f\n", i, get(i), values[i], cumulativeSum(i)); } System.out.println("\n"); } String spaces(int c) { if (c>0) return String.format("%"+c+"s",""); return ""; } double getMin(int node) { return vs.get(node*STORAGE_PER_NODE); } void setMin(int node, double v) { vs.set(node*STORAGE_PER_NODE, v); } double getMax(int node) { return vs.get(node*STORAGE_PER_NODE + 1); } void setMax(int node, double v) { vs.set(node*STORAGE_PER_NODE+1, v); } double getSum(int node) { return vs.get(node*STORAGE_PER_NODE + 2); } void setSum(int node, double v) { vs.set(node*STORAGE_PER_NODE+2, v); } static final int idx2node(int idx) { return idx*2; } static final boolean isLeftChild(int node) { int v = node + 1; int w = v^(v&(v-1)); return (v&((2*w)))==0; } static final boolean isRightChild(int node) { return !isLeftChild(node); } static final int rightChild(int node) { int v = (node + 1)/2; v = v^(v&(v-1)); return node + v; } static final int leftChild(int node) { int v = (node + 1)/2; v = v^(v&(v-1)); return node - v; } static final int parent(int node) { int v = (node+1); int lo = v^(v&(v-1)); return (node&(~(lo*2)))+lo; } static final int rightSibling(int node) { int v = node + 1; int lo = v^(v&(v-1)); return v + lo*2 -1 ; } static final int leftSibling(int node) { int v = node + 1; int lo = v^(v&(v-1)); return v - lo*2 - 1 ; } static final boolean isLeaf(int node) { return numberOfChildren(node)==1; } static final int numberOfChildren(int node) { int v = node+1; return v^(v&(v-1)); } /** setLowerLimit is a bit more complicated than setUpperLimit * because of the limitation on which indices we can modify: it * means that we must sometimes dive more deeply into the graph, * past a place where min==max. Consequently, as we're descending, * we must track the override value and propagate it as we go. **/ public void setLowerLimit(int idx, double llimit) { setLowerLimitRecurse(idx2node(idx), ROOT_NODE, llimit, false, 0); } void setLowerLimitRecurse(int limitnode, int node, double llimit, boolean overriding, double override_value) { int rightmargin = (rightSibling(node) + node)/2 - 1; int leftmargin = (leftSibling(node) + node)/2 + 1; if (overriding) { setMin(node, override_value); setMax(node, override_value); setSum(node, numberOfChildren(node)*override_value); } if (rightmargin < limitnode) return; double min = getMin(node), max = getMax(node); if (getMin(node) >= llimit) // nothing to do return; if (min == max && leftmargin >= limitnode) { setMin(node, llimit); setMax(node, llimit); setSum(node, numberOfChildren(node)*llimit); return; } if (min == max && !overriding) { overriding = true; override_value = min; } setLowerLimitRecurse(limitnode, leftChild(node), llimit, overriding, override_value); setLowerLimitRecurse(limitnode, rightChild(node), llimit, overriding, override_value); fixup(node); } public void setUpperLimit(double ulimit) { setUpperLimitRecurse(ROOT_NODE, ulimit); } void setUpperLimitRecurse(int node, double ulimit) { double min = getMin(node), max = getMax(node); if (getMax(node) <= ulimit) // nothing to do return; if (min == max) // leaf { setMin(node, ulimit); setMax(node, ulimit); setSum(node, numberOfChildren(node)*ulimit); return; } setUpperLimitRecurse(leftChild(node), ulimit); setUpperLimitRecurse(rightChild(node), ulimit); fixup(node); } /** recompute this node's min, max, sum from its children. It is important that the children be initialized properly (via prepareChild) or some other method, or you'll get garbage here. **/ void fixup(int node) { int left = leftChild(node), right = rightChild(node); setMax(node, getMax(right)); setMin(node, getMin(left)); setSum(node, getSum(left) + getSum(right)); } public double cumulativeSum(int idx) { if (idx < 0) return 0; int node = ROOT_NODE; int goalnode = idx2node(idx); double sum = 0; int nodesToLeft = 0; // a count of the nodes that we've accounted for to our left while (true) { double min = getMin(node), max = getMax(node); if (min == max) { // how many nodes to our left in this sub tree? // System.out.println(idx +" "+nodesToLeft); sum += min*(idx - nodesToLeft + 1); return sum; } // we need to turn left or right. int left = leftChild(node), right = rightChild(node); if (node > goalnode) { node = left; } else { sum += getSum(left); nodesToLeft += numberOfChildren(left); node = right; } } } public double mean(int idx0, int idx1) { double sum0 = cumulativeSum(idx0-1); double sum1 = cumulativeSum(idx1); return (sum1-sum0)/(idx1-idx0+1); } public double get(int idx) { int goalnode = idx2node(idx); int node = ROOT_NODE; // traverse from the root to the leaf, ensuring that there are // no invalid nodes along the path. while (true) { double min = getMin(node), max = getMax(node); if (min == max || node == goalnode) return min; if (node > goalnode) // go right node = leftChild(node); else node = rightChild(node); } } void dump_recurse(int node, double[] values) { double min = getMin(node), max = getMax(node); if (isLeaf(node)) { values[node/2] = getMin(node); return; } if (min == max) { int low = (leftSibling(node) + node)/2 + 1; int high = (rightSibling(node) + node)/2 -1; for (int i = low; i <= high; i++) values[i/2] = min; return; } dump_recurse(leftChild(node), values); dump_recurse(rightChild(node), values); } double[] dump() { double values[] = new double[size]; dump_recurse(ROOT_NODE, values); return values; } /** Decrease each element in the (implicit) array according to n' = n/(n+1) **/ public void age() { double values[] = dump(); vs.setToZero(); // compute new leaf values. for (int i = 0; i < values.length; i++) { double newval = values[i]/(1+values[i]); setMin(i*2, newval); setMax(i*2, newval); setSum(i*2, newval); } int rowoffset = 1; int rowinc = 4; while (true) { for (int i = rowoffset; i < NUM_NODES; i+=rowinc) fixup(i); if (rowoffset == ROOT_NODE) break; rowoffset = rowoffset*2 + 1; rowinc *=2; } } public static void main(String args[]) { for (int i = 0; i < 15; i++) { // System.out.printf("%4d parent %4d\n", i, parent(i)); // System.out.printf("%4d isLeftChild %b\n", i, isLeftChild(i)); // System.out.printf("%4d children are %4d, %4d\n", i, leftChild(i), rightChild(i)); System.out.printf("%4d right sibling is %4d\n", i, rightSibling(i)); System.out.printf("%4d left sibling is %4d\n", i, leftSibling(i)); // System.out.printf("%4d has children: %4d\n", i, numberOfChildren(i)); } LearningRatesTree lrates = new LearningRatesTree(2); lrates.setLowerLimit(0, 1.0); lrates.debug(); lrates.ensureSize(8); lrates.debug(); lrates.setLowerLimit(4, 1.0); lrates.debug(); lrates.setLowerLimit(3, 4.0); lrates.debug(); lrates.setLowerLimit(5, 0); lrates.debug(); lrates.setLowerLimit(6,8); lrates.debug(); lrates.setUpperLimit(2); lrates.debug(); } }