001/*
002 * Licensed to the Apache Software Foundation (ASF) under one
003 * or more contributor license agreements.  See the NOTICE file
004 * distributed with this work for additional information
005 * regarding copyright ownership.  The ASF licenses this file
006 * to you under the Apache License, Version 2.0 (the
007 * "License"); you may not use this file except in compliance
008 * with the License.  You may obtain a copy of the License at
009 *
010 *   http://www.apache.org/licenses/LICENSE-2.0
011 *
012 * Unless required by applicable law or agreed to in writing,
013 * software distributed under the License is distributed on an
014 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
015 * KIND, either express or implied.  See the License for the
016 * specific language governing permissions and limitations
017 * under the License.
018 */
019package org.apache.reef.wake.examples.accumulate;
020
021
022import org.apache.reef.wake.Stage;
023import org.apache.reef.wake.rx.Observer;
024
025import java.util.Map;
026import java.util.concurrent.ConcurrentSkipListMap;
027
028public class CombinerStage<K extends Comparable<K>, V> implements Stage {
029
030  private final Combiner<K, V> c;
031  private final Observer<Map.Entry<K, V>> o;
032  private final OutputThread worker = new OutputThread();
033  private final ConcurrentSkipListMap<K, V> register = new ConcurrentSkipListMap<>();
034  private volatile boolean done = false;
035
036  public CombinerStage(final Combiner<K, V> c, final Observer<Map.Entry<K, V>> o) {
037    this.c = c;
038    this.o = o;
039    worker.start();
040  }
041
042  public Observer<Map.Entry<K, V>> wireIn() {
043    return new Observer<Map.Entry<K, V>>() {
044      @Override
045      public void onNext(final Map.Entry<K, V> pair) {
046        V old;
047        V newVal;
048        final boolean wasEmpty = register.isEmpty();
049        boolean succ = false;
050
051        while (!succ) {
052          old = register.get(pair.getKey());
053          newVal = c.combine(pair.getKey(), old, pair.getValue());
054          if (old == null) {
055            succ = (null == register.putIfAbsent(pair.getKey(), newVal));
056          } else {
057            succ = register.replace(pair.getKey(), old, newVal);
058          }
059        }
060
061        if (wasEmpty) {
062          synchronized (register) {
063            register.notify();
064          }
065        }
066      }
067
068      @Override
069      public void onError(final Exception error) {
070        o.onError(error);
071      }
072
073      @Override
074      public void onCompleted() {
075        synchronized (register) {
076          done = true;
077          if (register.isEmpty()) {
078            register.notify();
079          }
080        }
081      }
082    };
083  }
084
085  @Override
086  public void close() throws Exception {
087    worker.join();
088  }
089
090  public interface Combiner<K extends Comparable<K>, V> {
091    V combine(K key, V old, V cur);
092  }
093
094  public static class Pair<K extends Comparable<K>, V> implements Map.Entry<K, V>, Comparable<Map.Entry<K, V>> {
095    private final K k;
096    private final V v;
097
098    public Pair(final K k, final V v) {
099      this.k = k;
100      this.v = v;
101    }
102
103    @Override
104    public int compareTo(final Map.Entry<K, V> arg0) {
105      return k.compareTo(arg0.getKey());
106    }
107
108    @Override
109    public K getKey() {
110      return k;
111    }
112
113    @Override
114    public V getValue() {
115      return v;
116    }
117
118    @Override
119    public V setValue(final V value) {
120      throw new UnsupportedOperationException();
121    }
122  }
123
124  private class OutputThread extends Thread {
125    public OutputThread() {
126      super("grouper-output-thread");
127    }
128
129    @Override
130    public void run() {
131      while (true) {
132        if (register.isEmpty()) {
133          synchronized (register) {
134            while (register.isEmpty() && !done) {
135              try {
136                register.wait();
137              } catch (final InterruptedException e) {
138                throw new IllegalStateException(e);
139              }
140            }
141            if (done) {
142              break;
143            }
144          }
145        }
146        Map.Entry<K, V> cursor = register.pollFirstEntry();
147        while (cursor != null) {
148          o.onNext(cursor);
149          final K nextKey = register.higherKey(cursor.getKey());
150
151          /* If there is more than one OutputThread worker then the remove() -> null case
152           * must be handled
153           */
154          cursor = (nextKey == null) ? null : new Pair<>(nextKey, register.remove(nextKey));
155        }
156      }
157      o.onCompleted();
158    }
159  }
160
161}