001/*
002 * Licensed to the Apache Software Foundation (ASF) under one
003 * or more contributor license agreements.  See the NOTICE file
004 * distributed with this work for additional information
005 * regarding copyright ownership.  The ASF licenses this file
006 * to you under the Apache License, Version 2.0 (the
007 * "License"); you may not use this file except in compliance
008 * with the License.  You may obtain a copy of the License at
009 *
010 *   http://www.apache.org/licenses/LICENSE-2.0
011 *
012 * Unless required by applicable law or agreed to in writing,
013 * software distributed under the License is distributed on an
014 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
015 * KIND, either express or implied.  See the License for the
016 * specific language governing permissions and limitations
017 * under the License.
018 */
019package org.apache.reef.examples.group.broadcast;
020
021import org.apache.reef.annotations.audience.DriverSide;
022import org.apache.reef.driver.context.ActiveContext;
023import org.apache.reef.driver.context.ClosedContext;
024import org.apache.reef.driver.context.ContextConfiguration;
025import org.apache.reef.driver.evaluator.AllocatedEvaluator;
026import org.apache.reef.driver.evaluator.EvaluatorRequest;
027import org.apache.reef.driver.evaluator.EvaluatorRequestor;
028import org.apache.reef.driver.task.FailedTask;
029import org.apache.reef.driver.task.TaskConfiguration;
030import org.apache.reef.evaluator.context.parameters.ContextIdentifier;
031import org.apache.reef.examples.group.bgd.operatornames.ControlMessageBroadcaster;
032import org.apache.reef.examples.group.bgd.parameters.AllCommunicationGroup;
033import org.apache.reef.examples.group.bgd.parameters.ModelDimensions;
034import org.apache.reef.examples.group.broadcast.parameters.ModelBroadcaster;
035import org.apache.reef.examples.group.broadcast.parameters.ModelReceiveAckReducer;
036import org.apache.reef.examples.group.broadcast.parameters.NumberOfReceivers;
037import org.apache.reef.io.network.group.api.driver.CommunicationGroupDriver;
038import org.apache.reef.io.network.group.api.driver.GroupCommDriver;
039import org.apache.reef.io.network.group.impl.config.BroadcastOperatorSpec;
040import org.apache.reef.io.network.group.impl.config.ReduceOperatorSpec;
041import org.apache.reef.io.serialization.SerializableCodec;
042import org.apache.reef.poison.PoisonedConfiguration;
043import org.apache.reef.tang.Configuration;
044import org.apache.reef.tang.Injector;
045import org.apache.reef.tang.Tang;
046import org.apache.reef.tang.annotations.Parameter;
047import org.apache.reef.tang.annotations.Unit;
048import org.apache.reef.tang.exceptions.InjectionException;
049import org.apache.reef.tang.formats.ConfigurationSerializer;
050import org.apache.reef.wake.EventHandler;
051import org.apache.reef.wake.time.event.StartTime;
052
053import javax.inject.Inject;
054import java.util.concurrent.atomic.AtomicBoolean;
055import java.util.concurrent.atomic.AtomicInteger;
056import java.util.logging.Level;
057import java.util.logging.Logger;
058
059@DriverSide
060@Unit
061public class BroadcastDriver {
062
063  private static final Logger LOG = Logger.getLogger(BroadcastDriver.class.getName());
064
065  private final AtomicBoolean masterSubmitted = new AtomicBoolean(false);
066  private final AtomicInteger slaveIds = new AtomicInteger(0);
067  private final AtomicInteger failureSet = new AtomicInteger(0);
068
069  private final GroupCommDriver groupCommDriver;
070  private final CommunicationGroupDriver allCommGroup;
071  private final ConfigurationSerializer confSerializer;
072  private final int dimensions;
073  private final EvaluatorRequestor requestor;
074  private final int numberOfReceivers;
075  private final AtomicInteger numberOfAllocatedEvaluators;
076
077  private String groupCommConfiguredMasterId;
078
079  @Inject
080  public BroadcastDriver(
081      final EvaluatorRequestor requestor,
082      final GroupCommDriver groupCommDriver,
083      final ConfigurationSerializer confSerializer,
084      @Parameter(ModelDimensions.class) final int dimensions,
085      @Parameter(NumberOfReceivers.class) final int numberOfReceivers) {
086
087    this.requestor = requestor;
088    this.groupCommDriver = groupCommDriver;
089    this.confSerializer = confSerializer;
090    this.dimensions = dimensions;
091    this.numberOfReceivers = numberOfReceivers;
092    this.numberOfAllocatedEvaluators = new AtomicInteger(numberOfReceivers + 1);
093
094    this.allCommGroup = this.groupCommDriver.newCommunicationGroup(
095        AllCommunicationGroup.class, numberOfReceivers + 1);
096
097    LOG.info("Obtained all communication group");
098
099    this.allCommGroup
100        .addBroadcast(ControlMessageBroadcaster.class,
101            BroadcastOperatorSpec.newBuilder()
102                .setSenderId(MasterTask.TASK_ID)
103                .setDataCodecClass(SerializableCodec.class)
104                .build())
105        .addBroadcast(ModelBroadcaster.class,
106            BroadcastOperatorSpec.newBuilder()
107                .setSenderId(MasterTask.TASK_ID)
108                .setDataCodecClass(SerializableCodec.class)
109                .build())
110        .addReduce(ModelReceiveAckReducer.class,
111            ReduceOperatorSpec.newBuilder()
112                .setReceiverId(MasterTask.TASK_ID)
113                .setDataCodecClass(SerializableCodec.class)
114                .setReduceFunctionClass(ModelReceiveAckReduceFunction.class)
115                .build())
116        .finalise();
117
118    LOG.info("Added operators to allCommGroup");
119  }
120
121  /**
122   * Handles the StartTime event: Request numOfReceivers Evaluators.
123   */
124  final class StartHandler implements EventHandler<StartTime> {
125    @Override
126    public void onNext(final StartTime startTime) {
127      final int numEvals = BroadcastDriver.this.numberOfReceivers + 1;
128      LOG.log(Level.FINE, "Requesting {0} evaluators", numEvals);
129      BroadcastDriver.this.requestor.submit(EvaluatorRequest.newBuilder()
130          .setNumber(numEvals)
131          .setMemory(2048)
132          .build());
133    }
134  }
135
136  /**
137   * Handles AllocatedEvaluator: Submits a context with an id.
138   */
139  final class EvaluatorAllocatedHandler implements EventHandler<AllocatedEvaluator> {
140    @Override
141    public void onNext(final AllocatedEvaluator allocatedEvaluator) {
142      LOG.log(Level.INFO, "Submitting an id context to AllocatedEvaluator: {0}", allocatedEvaluator);
143      final Configuration contextConfiguration = ContextConfiguration.CONF
144          .set(ContextConfiguration.IDENTIFIER, "BroadcastContext-" +
145              BroadcastDriver.this.numberOfAllocatedEvaluators.getAndDecrement())
146          .build();
147      allocatedEvaluator.submitContext(contextConfiguration);
148    }
149  }
150
151  public class FailedTaskHandler implements EventHandler<FailedTask> {
152
153    @Override
154    public void onNext(final FailedTask failedTask) {
155
156      LOG.log(Level.FINE, "Got failed Task: {0}", failedTask.getId());
157
158      final ActiveContext activeContext = failedTask.getActiveContext().get();
159      final Configuration partialTaskConf = Tang.Factory.getTang()
160          .newConfigurationBuilder(
161              TaskConfiguration.CONF
162                  .set(TaskConfiguration.IDENTIFIER, failedTask.getId())
163                  .set(TaskConfiguration.TASK, SlaveTask.class)
164                  .build(),
165              PoisonedConfiguration.TASK_CONF
166                  .set(PoisonedConfiguration.CRASH_PROBABILITY, "0")
167                  .set(PoisonedConfiguration.CRASH_TIMEOUT, "1")
168                  .build())
169          .bindNamedParameter(ModelDimensions.class, "" + dimensions)
170          .build();
171
172      // Do not add the task back:
173      // allCommGroup.addTask(partialTaskConf);
174
175      final Configuration taskConf = groupCommDriver.getTaskConfiguration(partialTaskConf);
176      LOG.log(Level.FINER, "Submit SlaveTask conf: {0}", confSerializer.toString(taskConf));
177
178      activeContext.submitTask(taskConf);
179    }
180  }
181
182  public class ContextActiveHandler implements EventHandler<ActiveContext> {
183
184    private final AtomicBoolean storeMasterId = new AtomicBoolean(false);
185
186    @Override
187    public void onNext(final ActiveContext activeContext) {
188
189      LOG.log(Level.FINE, "Got active context: {0}", activeContext.getId());
190
191      /**
192       * The active context can be either from data loading service or after network
193       * service has loaded contexts. So check if the GroupCommDriver knows if it was
194       * configured by one of the communication groups.
195       */
196      if (groupCommDriver.isConfigured(activeContext)) {
197
198        if (activeContext.getId().equals(groupCommConfiguredMasterId) && !masterTaskSubmitted()) {
199
200          final Configuration partialTaskConf = Tang.Factory.getTang()
201              .newConfigurationBuilder(
202                  TaskConfiguration.CONF
203                      .set(TaskConfiguration.IDENTIFIER, MasterTask.TASK_ID)
204                      .set(TaskConfiguration.TASK, MasterTask.class)
205                      .build())
206              .bindNamedParameter(ModelDimensions.class, Integer.toString(dimensions))
207              .build();
208
209          allCommGroup.addTask(partialTaskConf);
210
211          final Configuration taskConf = groupCommDriver.getTaskConfiguration(partialTaskConf);
212          LOG.log(Level.FINER, "Submit MasterTask conf: {0}", confSerializer.toString(taskConf));
213
214          activeContext.submitTask(taskConf);
215
216        } else {
217
218          final Configuration partialTaskConf = Tang.Factory.getTang()
219              .newConfigurationBuilder(
220                  TaskConfiguration.CONF
221                      .set(TaskConfiguration.IDENTIFIER, getSlaveId(activeContext))
222                      .set(TaskConfiguration.TASK, SlaveTask.class)
223                      .build(),
224                  PoisonedConfiguration.TASK_CONF
225                      .set(PoisonedConfiguration.CRASH_PROBABILITY, "0.4")
226                      .set(PoisonedConfiguration.CRASH_TIMEOUT, "1")
227                      .build())
228              .bindNamedParameter(ModelDimensions.class, Integer.toString(dimensions))
229              .build();
230
231          allCommGroup.addTask(partialTaskConf);
232
233          final Configuration taskConf = groupCommDriver.getTaskConfiguration(partialTaskConf);
234          LOG.log(Level.FINER, "Submit SlaveTask conf: {0}", confSerializer.toString(taskConf));
235
236          activeContext.submitTask(taskConf);
237        }
238      } else {
239
240        final Configuration contextConf = groupCommDriver.getContextConfiguration();
241        final String contextId = contextId(contextConf);
242
243        if (storeMasterId.compareAndSet(false, true)) {
244          groupCommConfiguredMasterId = contextId;
245        }
246
247        final Configuration serviceConf = groupCommDriver.getServiceConfiguration();
248        LOG.log(Level.FINER, "Submit GCContext conf: {0}", confSerializer.toString(contextConf));
249        LOG.log(Level.FINER, "Submit Service conf: {0}", confSerializer.toString(serviceConf));
250
251        activeContext.submitContextAndService(contextConf, serviceConf);
252      }
253    }
254
255    private String contextId(final Configuration contextConf) {
256      try {
257        final Injector injector = Tang.Factory.getTang().newInjector(contextConf);
258        return injector.getNamedInstance(ContextIdentifier.class);
259      } catch (final InjectionException e) {
260        throw new RuntimeException("Unable to inject context identifier from context conf", e);
261      }
262    }
263
264    private String getSlaveId(final ActiveContext activeContext) {
265      return "SlaveTask-" + slaveIds.getAndIncrement();
266    }
267
268    private boolean masterTaskSubmitted() {
269      return !masterSubmitted.compareAndSet(false, true);
270    }
271  }
272
273  public class ContextCloseHandler implements EventHandler<ClosedContext> {
274
275    @Override
276    public void onNext(final ClosedContext closedContext) {
277      LOG.log(Level.FINE, "Got closed context: {0}", closedContext.getId());
278      final ActiveContext parentContext = closedContext.getParentContext();
279      if (parentContext != null) {
280        LOG.log(Level.FINE, "Closing parent context: {0}", parentContext.getId());
281        parentContext.close();
282      }
283    }
284  }
285}