001/* 002 * Licensed to the Apache Software Foundation (ASF) under one 003 * or more contributor license agreements. See the NOTICE file 004 * distributed with this work for additional information 005 * regarding copyright ownership. The ASF licenses this file 006 * to you under the Apache License, Version 2.0 (the 007 * "License"); you may not use this file except in compliance 008 * with the License. You may obtain a copy of the License at 009 * 010 * http://www.apache.org/licenses/LICENSE-2.0 011 * 012 * Unless required by applicable law or agreed to in writing, 013 * software distributed under the License is distributed on an 014 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 015 * KIND, either express or implied. See the License for the 016 * specific language governing permissions and limitations 017 * under the License. 018 */ 019package org.apache.reef.examples.group.broadcast; 020 021import org.apache.reef.annotations.audience.DriverSide; 022import org.apache.reef.driver.context.ActiveContext; 023import org.apache.reef.driver.context.ClosedContext; 024import org.apache.reef.driver.context.ContextConfiguration; 025import org.apache.reef.driver.evaluator.AllocatedEvaluator; 026import org.apache.reef.driver.evaluator.EvaluatorRequest; 027import org.apache.reef.driver.evaluator.EvaluatorRequestor; 028import org.apache.reef.driver.task.FailedTask; 029import org.apache.reef.driver.task.TaskConfiguration; 030import org.apache.reef.evaluator.context.parameters.ContextIdentifier; 031import org.apache.reef.examples.group.bgd.operatornames.ControlMessageBroadcaster; 032import org.apache.reef.examples.group.bgd.parameters.AllCommunicationGroup; 033import org.apache.reef.examples.group.bgd.parameters.ModelDimensions; 034import org.apache.reef.examples.group.broadcast.parameters.ModelBroadcaster; 035import org.apache.reef.examples.group.broadcast.parameters.ModelReceiveAckReducer; 036import org.apache.reef.examples.group.broadcast.parameters.NumberOfReceivers; 037import org.apache.reef.io.network.group.api.driver.CommunicationGroupDriver; 038import org.apache.reef.io.network.group.api.driver.GroupCommDriver; 039import org.apache.reef.io.network.group.impl.config.BroadcastOperatorSpec; 040import org.apache.reef.io.network.group.impl.config.ReduceOperatorSpec; 041import org.apache.reef.io.serialization.SerializableCodec; 042import org.apache.reef.poison.PoisonedConfiguration; 043import org.apache.reef.tang.Configuration; 044import org.apache.reef.tang.Injector; 045import org.apache.reef.tang.Tang; 046import org.apache.reef.tang.annotations.Parameter; 047import org.apache.reef.tang.annotations.Unit; 048import org.apache.reef.tang.exceptions.InjectionException; 049import org.apache.reef.tang.formats.ConfigurationSerializer; 050import org.apache.reef.wake.EventHandler; 051import org.apache.reef.wake.time.event.StartTime; 052 053import javax.inject.Inject; 054import java.util.concurrent.atomic.AtomicBoolean; 055import java.util.concurrent.atomic.AtomicInteger; 056import java.util.logging.Level; 057import java.util.logging.Logger; 058 059@DriverSide 060@Unit 061public class BroadcastDriver { 062 063 private static final Logger LOG = Logger.getLogger(BroadcastDriver.class.getName()); 064 065 private final AtomicBoolean masterSubmitted = new AtomicBoolean(false); 066 private final AtomicInteger slaveIds = new AtomicInteger(0); 067 private final AtomicInteger failureSet = new AtomicInteger(0); 068 069 private final GroupCommDriver groupCommDriver; 070 private final CommunicationGroupDriver allCommGroup; 071 private final ConfigurationSerializer confSerializer; 072 private final int dimensions; 073 private final EvaluatorRequestor requestor; 074 private final int numberOfReceivers; 075 private final AtomicInteger numberOfAllocatedEvaluators; 076 077 private String groupCommConfiguredMasterId; 078 079 @Inject 080 public BroadcastDriver( 081 final EvaluatorRequestor requestor, 082 final GroupCommDriver groupCommDriver, 083 final ConfigurationSerializer confSerializer, 084 @Parameter(ModelDimensions.class) final int dimensions, 085 @Parameter(NumberOfReceivers.class) final int numberOfReceivers) { 086 087 this.requestor = requestor; 088 this.groupCommDriver = groupCommDriver; 089 this.confSerializer = confSerializer; 090 this.dimensions = dimensions; 091 this.numberOfReceivers = numberOfReceivers; 092 this.numberOfAllocatedEvaluators = new AtomicInteger(numberOfReceivers + 1); 093 094 this.allCommGroup = this.groupCommDriver.newCommunicationGroup( 095 AllCommunicationGroup.class, numberOfReceivers + 1); 096 097 LOG.info("Obtained all communication group"); 098 099 this.allCommGroup 100 .addBroadcast(ControlMessageBroadcaster.class, 101 BroadcastOperatorSpec.newBuilder() 102 .setSenderId(MasterTask.TASK_ID) 103 .setDataCodecClass(SerializableCodec.class) 104 .build()) 105 .addBroadcast(ModelBroadcaster.class, 106 BroadcastOperatorSpec.newBuilder() 107 .setSenderId(MasterTask.TASK_ID) 108 .setDataCodecClass(SerializableCodec.class) 109 .build()) 110 .addReduce(ModelReceiveAckReducer.class, 111 ReduceOperatorSpec.newBuilder() 112 .setReceiverId(MasterTask.TASK_ID) 113 .setDataCodecClass(SerializableCodec.class) 114 .setReduceFunctionClass(ModelReceiveAckReduceFunction.class) 115 .build()) 116 .finalise(); 117 118 LOG.info("Added operators to allCommGroup"); 119 } 120 121 /** 122 * Handles the StartTime event: Request numOfReceivers Evaluators. 123 */ 124 final class StartHandler implements EventHandler<StartTime> { 125 @Override 126 public void onNext(final StartTime startTime) { 127 final int numEvals = BroadcastDriver.this.numberOfReceivers + 1; 128 LOG.log(Level.FINE, "Requesting {0} evaluators", numEvals); 129 BroadcastDriver.this.requestor.submit(EvaluatorRequest.newBuilder() 130 .setNumber(numEvals) 131 .setMemory(2048) 132 .build()); 133 } 134 } 135 136 /** 137 * Handles AllocatedEvaluator: Submits a context with an id. 138 */ 139 final class EvaluatorAllocatedHandler implements EventHandler<AllocatedEvaluator> { 140 @Override 141 public void onNext(final AllocatedEvaluator allocatedEvaluator) { 142 LOG.log(Level.INFO, "Submitting an id context to AllocatedEvaluator: {0}", allocatedEvaluator); 143 final Configuration contextConfiguration = ContextConfiguration.CONF 144 .set(ContextConfiguration.IDENTIFIER, "BroadcastContext-" + 145 BroadcastDriver.this.numberOfAllocatedEvaluators.getAndDecrement()) 146 .build(); 147 allocatedEvaluator.submitContext(contextConfiguration); 148 } 149 } 150 151 public class FailedTaskHandler implements EventHandler<FailedTask> { 152 153 @Override 154 public void onNext(final FailedTask failedTask) { 155 156 LOG.log(Level.FINE, "Got failed Task: {0}", failedTask.getId()); 157 158 final ActiveContext activeContext = failedTask.getActiveContext().get(); 159 final Configuration partialTaskConf = Tang.Factory.getTang() 160 .newConfigurationBuilder( 161 TaskConfiguration.CONF 162 .set(TaskConfiguration.IDENTIFIER, failedTask.getId()) 163 .set(TaskConfiguration.TASK, SlaveTask.class) 164 .build(), 165 PoisonedConfiguration.TASK_CONF 166 .set(PoisonedConfiguration.CRASH_PROBABILITY, "0") 167 .set(PoisonedConfiguration.CRASH_TIMEOUT, "1") 168 .build()) 169 .bindNamedParameter(ModelDimensions.class, "" + dimensions) 170 .build(); 171 172 // Do not add the task back: 173 // allCommGroup.addTask(partialTaskConf); 174 175 final Configuration taskConf = groupCommDriver.getTaskConfiguration(partialTaskConf); 176 LOG.log(Level.FINER, "Submit SlaveTask conf: {0}", confSerializer.toString(taskConf)); 177 178 activeContext.submitTask(taskConf); 179 } 180 } 181 182 public class ContextActiveHandler implements EventHandler<ActiveContext> { 183 184 private final AtomicBoolean storeMasterId = new AtomicBoolean(false); 185 186 @Override 187 public void onNext(final ActiveContext activeContext) { 188 189 LOG.log(Level.FINE, "Got active context: {0}", activeContext.getId()); 190 191 /** 192 * The active context can be either from data loading service or after network 193 * service has loaded contexts. So check if the GroupCommDriver knows if it was 194 * configured by one of the communication groups. 195 */ 196 if (groupCommDriver.isConfigured(activeContext)) { 197 198 if (activeContext.getId().equals(groupCommConfiguredMasterId) && !masterTaskSubmitted()) { 199 200 final Configuration partialTaskConf = Tang.Factory.getTang() 201 .newConfigurationBuilder( 202 TaskConfiguration.CONF 203 .set(TaskConfiguration.IDENTIFIER, MasterTask.TASK_ID) 204 .set(TaskConfiguration.TASK, MasterTask.class) 205 .build()) 206 .bindNamedParameter(ModelDimensions.class, Integer.toString(dimensions)) 207 .build(); 208 209 allCommGroup.addTask(partialTaskConf); 210 211 final Configuration taskConf = groupCommDriver.getTaskConfiguration(partialTaskConf); 212 LOG.log(Level.FINER, "Submit MasterTask conf: {0}", confSerializer.toString(taskConf)); 213 214 activeContext.submitTask(taskConf); 215 216 } else { 217 218 final Configuration partialTaskConf = Tang.Factory.getTang() 219 .newConfigurationBuilder( 220 TaskConfiguration.CONF 221 .set(TaskConfiguration.IDENTIFIER, getSlaveId(activeContext)) 222 .set(TaskConfiguration.TASK, SlaveTask.class) 223 .build(), 224 PoisonedConfiguration.TASK_CONF 225 .set(PoisonedConfiguration.CRASH_PROBABILITY, "0.4") 226 .set(PoisonedConfiguration.CRASH_TIMEOUT, "1") 227 .build()) 228 .bindNamedParameter(ModelDimensions.class, Integer.toString(dimensions)) 229 .build(); 230 231 allCommGroup.addTask(partialTaskConf); 232 233 final Configuration taskConf = groupCommDriver.getTaskConfiguration(partialTaskConf); 234 LOG.log(Level.FINER, "Submit SlaveTask conf: {0}", confSerializer.toString(taskConf)); 235 236 activeContext.submitTask(taskConf); 237 } 238 } else { 239 240 final Configuration contextConf = groupCommDriver.getContextConfiguration(); 241 final String contextId = contextId(contextConf); 242 243 if (storeMasterId.compareAndSet(false, true)) { 244 groupCommConfiguredMasterId = contextId; 245 } 246 247 final Configuration serviceConf = groupCommDriver.getServiceConfiguration(); 248 LOG.log(Level.FINER, "Submit GCContext conf: {0}", confSerializer.toString(contextConf)); 249 LOG.log(Level.FINER, "Submit Service conf: {0}", confSerializer.toString(serviceConf)); 250 251 activeContext.submitContextAndService(contextConf, serviceConf); 252 } 253 } 254 255 private String contextId(final Configuration contextConf) { 256 try { 257 final Injector injector = Tang.Factory.getTang().newInjector(contextConf); 258 return injector.getNamedInstance(ContextIdentifier.class); 259 } catch (final InjectionException e) { 260 throw new RuntimeException("Unable to inject context identifier from context conf", e); 261 } 262 } 263 264 private String getSlaveId(final ActiveContext activeContext) { 265 return "SlaveTask-" + slaveIds.getAndIncrement(); 266 } 267 268 private boolean masterTaskSubmitted() { 269 return !masterSubmitted.compareAndSet(false, true); 270 } 271 } 272 273 public class ContextCloseHandler implements EventHandler<ClosedContext> { 274 275 @Override 276 public void onNext(final ClosedContext closedContext) { 277 LOG.log(Level.FINE, "Got closed context: {0}", closedContext.getId()); 278 final ActiveContext parentContext = closedContext.getParentContext(); 279 if (parentContext != null) { 280 LOG.log(Level.FINE, "Closing parent context: {0}", parentContext.getId()); 281 parentContext.close(); 282 } 283 } 284 } 285}