1 | |
|
2 | |
|
3 | |
|
4 | |
|
5 | |
|
6 | |
|
7 | |
|
8 | |
|
9 | |
|
10 | |
|
11 | |
|
12 | |
|
13 | |
|
14 | |
|
15 | |
|
16 | |
|
17 | |
|
18 | |
|
19 | |
package org.apache.giraph.examples; |
20 | |
|
21 | |
import org.apache.giraph.aggregators.LongSumAggregator; |
22 | |
import org.apache.giraph.bsp.BspInputSplit; |
23 | |
import org.apache.giraph.edge.Edge; |
24 | |
import org.apache.giraph.edge.EdgeFactory; |
25 | |
import org.apache.giraph.graph.BasicComputation; |
26 | |
import org.apache.giraph.master.DefaultMasterCompute; |
27 | |
import org.apache.giraph.graph.Vertex; |
28 | |
import org.apache.giraph.io.EdgeInputFormat; |
29 | |
import org.apache.giraph.io.EdgeReader; |
30 | |
import org.apache.giraph.io.VertexReader; |
31 | |
import org.apache.giraph.io.formats.GeneratedVertexInputFormat; |
32 | |
import org.apache.hadoop.conf.Configuration; |
33 | |
import org.apache.hadoop.io.DoubleWritable; |
34 | |
import org.apache.hadoop.io.FloatWritable; |
35 | |
import org.apache.hadoop.io.LongWritable; |
36 | |
import org.apache.hadoop.mapreduce.InputSplit; |
37 | |
import org.apache.hadoop.mapreduce.JobContext; |
38 | |
import org.apache.hadoop.mapreduce.TaskAttemptContext; |
39 | |
import org.apache.log4j.Logger; |
40 | |
|
41 | |
import com.google.common.collect.Lists; |
42 | |
|
43 | |
import java.io.IOException; |
44 | |
import java.util.ArrayList; |
45 | |
import java.util.List; |
46 | |
|
47 | |
|
48 | 0 | public class AggregatorsTestComputation extends |
49 | |
BasicComputation<LongWritable, DoubleWritable, FloatWritable, |
50 | |
DoubleWritable> { |
51 | |
|
52 | |
|
53 | |
private static final String REGULAR_AGG = "regular"; |
54 | |
|
55 | |
private static final String PERSISTENT_AGG = "persistent"; |
56 | |
|
57 | |
private static final String INPUT_VERTEX_PERSISTENT_AGG |
58 | |
= "input_super_step_vertex_agg"; |
59 | |
|
60 | |
private static final String INPUT_EDGE_PERSISTENT_AGG |
61 | |
= "input_super_step_edge_agg"; |
62 | |
|
63 | |
private static final String MASTER_WRITE_AGG = "master"; |
64 | |
|
65 | |
private static final long MASTER_VALUE = 12345; |
66 | |
|
67 | |
private static final String ARRAY_PREFIX_AGG = "array"; |
68 | |
|
69 | |
private static final int NUM_OF_AGGREGATORS_IN_ARRAY = 100; |
70 | |
|
71 | |
@Override |
72 | |
public void compute( |
73 | |
Vertex<LongWritable, DoubleWritable, FloatWritable> vertex, |
74 | |
Iterable<DoubleWritable> messages) throws IOException { |
75 | 0 | long superstep = getSuperstep(); |
76 | |
|
77 | 0 | LongWritable myValue = new LongWritable(1L << superstep); |
78 | 0 | aggregate(REGULAR_AGG, myValue); |
79 | 0 | aggregate(PERSISTENT_AGG, myValue); |
80 | |
|
81 | 0 | long nv = getTotalNumVertices(); |
82 | 0 | if (superstep > 0) { |
83 | 0 | assertEquals(nv * (1L << (superstep - 1)), |
84 | 0 | ((LongWritable) getAggregatedValue(REGULAR_AGG)).get()); |
85 | |
} else { |
86 | 0 | assertEquals(0, |
87 | 0 | ((LongWritable) getAggregatedValue(REGULAR_AGG)).get()); |
88 | |
} |
89 | 0 | assertEquals(nv * ((1L << superstep) - 1), |
90 | 0 | ((LongWritable) getAggregatedValue(PERSISTENT_AGG)).get()); |
91 | 0 | assertEquals(MASTER_VALUE * (1L << superstep), |
92 | 0 | ((LongWritable) getAggregatedValue(MASTER_WRITE_AGG)).get()); |
93 | |
|
94 | 0 | for (int i = 0; i < NUM_OF_AGGREGATORS_IN_ARRAY; i++) { |
95 | 0 | aggregate(ARRAY_PREFIX_AGG + i, new LongWritable((superstep + 1) * i)); |
96 | 0 | assertEquals(superstep * getTotalNumVertices() * i, |
97 | 0 | ((LongWritable) getAggregatedValue(ARRAY_PREFIX_AGG + i)).get()); |
98 | |
} |
99 | |
|
100 | 0 | if (getSuperstep() == 10) { |
101 | 0 | vertex.voteToHalt(); |
102 | |
} |
103 | 0 | } |
104 | |
|
105 | |
|
106 | 0 | public static class AggregatorsTestMasterCompute extends |
107 | |
DefaultMasterCompute { |
108 | |
@Override |
109 | |
public void compute() { |
110 | 0 | long superstep = getSuperstep(); |
111 | |
|
112 | 0 | LongWritable myValue = |
113 | |
new LongWritable(MASTER_VALUE * (1L << superstep)); |
114 | 0 | setAggregatedValue(MASTER_WRITE_AGG, myValue); |
115 | |
|
116 | 0 | long nv = getTotalNumVertices(); |
117 | 0 | if (superstep >= 0) { |
118 | 0 | assertEquals(100, ((LongWritable) |
119 | 0 | getAggregatedValue(INPUT_VERTEX_PERSISTENT_AGG)).get()); |
120 | |
} |
121 | 0 | if (superstep >= 0) { |
122 | 0 | assertEquals(4500, ((LongWritable) |
123 | 0 | getAggregatedValue(INPUT_EDGE_PERSISTENT_AGG)).get()); |
124 | |
} |
125 | 0 | if (superstep > 0) { |
126 | 0 | assertEquals(nv * (1L << (superstep - 1)), |
127 | 0 | ((LongWritable) getAggregatedValue(REGULAR_AGG)).get()); |
128 | |
} else { |
129 | 0 | assertEquals(0, |
130 | 0 | ((LongWritable) getAggregatedValue(REGULAR_AGG)).get()); |
131 | |
} |
132 | 0 | assertEquals(nv * ((1L << superstep) - 1), |
133 | 0 | ((LongWritable) getAggregatedValue(PERSISTENT_AGG)).get()); |
134 | |
|
135 | 0 | for (int i = 0; i < NUM_OF_AGGREGATORS_IN_ARRAY; i++) { |
136 | 0 | assertEquals(superstep * getTotalNumVertices() * i, |
137 | 0 | ((LongWritable) getAggregatedValue(ARRAY_PREFIX_AGG + i)).get()); |
138 | |
} |
139 | 0 | } |
140 | |
|
141 | |
@Override |
142 | |
public void initialize() throws InstantiationException, |
143 | |
IllegalAccessException { |
144 | 0 | registerPersistentAggregator( |
145 | |
INPUT_VERTEX_PERSISTENT_AGG, LongSumAggregator.class); |
146 | 0 | registerPersistentAggregator( |
147 | |
INPUT_EDGE_PERSISTENT_AGG, LongSumAggregator.class); |
148 | 0 | registerAggregator(REGULAR_AGG, LongSumAggregator.class); |
149 | 0 | registerPersistentAggregator(PERSISTENT_AGG, |
150 | |
LongSumAggregator.class); |
151 | 0 | registerAggregator(MASTER_WRITE_AGG, LongSumAggregator.class); |
152 | |
|
153 | 0 | for (int i = 0; i < NUM_OF_AGGREGATORS_IN_ARRAY; i++) { |
154 | 0 | registerAggregator(ARRAY_PREFIX_AGG + i, LongSumAggregator.class); |
155 | |
} |
156 | 0 | } |
157 | |
} |
158 | |
|
159 | |
|
160 | |
|
161 | |
|
162 | |
|
163 | |
|
164 | |
|
165 | |
private static void assertEquals(long expected, long actual) { |
166 | 0 | if (expected != actual) { |
167 | 0 | throw new RuntimeException("expected: " + expected + |
168 | |
", actual: " + actual); |
169 | |
} |
170 | 0 | } |
171 | |
|
172 | |
|
173 | |
|
174 | |
|
175 | 0 | public static class SimpleVertexReader extends |
176 | |
GeneratedVertexReader<LongWritable, DoubleWritable, FloatWritable> { |
177 | |
|
178 | 0 | private static final Logger LOG = |
179 | 0 | Logger.getLogger(SimpleVertexReader.class); |
180 | |
|
181 | |
@Override |
182 | |
public boolean nextVertex() { |
183 | 0 | return totalRecords > recordsRead; |
184 | |
} |
185 | |
|
186 | |
@Override |
187 | |
public Vertex<LongWritable, DoubleWritable, |
188 | |
FloatWritable> getCurrentVertex() throws IOException { |
189 | 0 | Vertex<LongWritable, DoubleWritable, FloatWritable> vertex = |
190 | 0 | getConf().createVertex(); |
191 | 0 | LongWritable vertexId = new LongWritable( |
192 | 0 | (inputSplit.getSplitIndex() * totalRecords) + recordsRead); |
193 | 0 | DoubleWritable vertexValue = new DoubleWritable(vertexId.get() * 10d); |
194 | 0 | long targetVertexId = |
195 | 0 | (vertexId.get() + 1) % |
196 | 0 | (inputSplit.getNumSplits() * totalRecords); |
197 | 0 | float edgeValue = vertexId.get() * 100f; |
198 | 0 | List<Edge<LongWritable, FloatWritable>> edges = Lists.newLinkedList(); |
199 | 0 | edges.add(EdgeFactory.create(new LongWritable(targetVertexId), |
200 | |
new FloatWritable(edgeValue))); |
201 | 0 | vertex.initialize(vertexId, vertexValue, edges); |
202 | 0 | ++recordsRead; |
203 | 0 | if (LOG.isInfoEnabled()) { |
204 | 0 | LOG.info("next vertex: Return vertexId=" + vertex.getId().get() + |
205 | 0 | ", vertexValue=" + vertex.getValue() + |
206 | |
", targetVertexId=" + targetVertexId + ", edgeValue=" + edgeValue); |
207 | |
} |
208 | 0 | aggregate(INPUT_VERTEX_PERSISTENT_AGG, |
209 | 0 | new LongWritable((long) vertex.getValue().get())); |
210 | 0 | return vertex; |
211 | |
} |
212 | |
} |
213 | |
|
214 | |
|
215 | |
|
216 | |
|
217 | 0 | public static class SimpleVertexInputFormat extends |
218 | |
GeneratedVertexInputFormat<LongWritable, DoubleWritable, FloatWritable> { |
219 | |
@Override |
220 | |
public VertexReader<LongWritable, DoubleWritable, |
221 | |
FloatWritable> createVertexReader(InputSplit split, |
222 | |
TaskAttemptContext context) |
223 | |
throws IOException { |
224 | 0 | return new SimpleVertexReader(); |
225 | |
} |
226 | |
} |
227 | |
|
228 | |
|
229 | |
|
230 | |
|
231 | 0 | public static class SimpleEdgeReader extends |
232 | |
GeneratedEdgeReader<LongWritable, FloatWritable> { |
233 | |
|
234 | 0 | private static final Logger LOG = Logger.getLogger(SimpleEdgeReader.class); |
235 | |
|
236 | |
@Override |
237 | |
public boolean nextEdge() { |
238 | 0 | return totalRecords > recordsRead; |
239 | |
} |
240 | |
|
241 | |
@Override |
242 | |
public Edge<LongWritable, FloatWritable> getCurrentEdge() |
243 | |
throws IOException { |
244 | 0 | LongWritable vertexId = new LongWritable( |
245 | 0 | (inputSplit.getSplitIndex() * totalRecords) + recordsRead); |
246 | 0 | long targetVertexId = (vertexId.get() + 1) % |
247 | 0 | (inputSplit.getNumSplits() * totalRecords); |
248 | 0 | float edgeValue = vertexId.get() * 100f; |
249 | 0 | Edge<LongWritable, FloatWritable> edge = EdgeFactory.create( |
250 | |
new LongWritable(targetVertexId), new FloatWritable(edgeValue)); |
251 | 0 | ++recordsRead; |
252 | 0 | if (LOG.isInfoEnabled()) { |
253 | 0 | LOG.info("next edge: Return targetVertexId=" + targetVertexId + |
254 | |
", edgeValue=" + edgeValue); |
255 | |
} |
256 | 0 | aggregate(INPUT_EDGE_PERSISTENT_AGG, new LongWritable((long) edge |
257 | 0 | .getValue().get())); |
258 | 0 | return edge; |
259 | |
} |
260 | |
|
261 | |
@Override |
262 | |
public LongWritable getCurrentSourceId() throws IOException, |
263 | |
InterruptedException { |
264 | 0 | LongWritable vertexId = new LongWritable( |
265 | 0 | (inputSplit.getSplitIndex() * totalRecords) + recordsRead); |
266 | 0 | return vertexId; |
267 | |
} |
268 | |
} |
269 | |
|
270 | |
|
271 | |
|
272 | |
|
273 | 0 | public static class SimpleEdgeInputFormat extends |
274 | |
EdgeInputFormat<LongWritable, FloatWritable> { |
275 | 0 | @Override public void checkInputSpecs(Configuration conf) { } |
276 | |
|
277 | |
@Override |
278 | |
public EdgeReader<LongWritable, FloatWritable> createEdgeReader( |
279 | |
InputSplit split, TaskAttemptContext context) throws IOException { |
280 | 0 | return new SimpleEdgeReader(); |
281 | |
} |
282 | |
|
283 | |
@Override |
284 | |
public List<InputSplit> getSplits(JobContext context, int minSplitCountHint) |
285 | |
throws IOException, InterruptedException { |
286 | 0 | List<InputSplit> inputSplitList = new ArrayList<InputSplit>(); |
287 | 0 | for (int i = 0; i < minSplitCountHint; ++i) { |
288 | 0 | inputSplitList.add(new BspInputSplit(i, minSplitCountHint)); |
289 | |
} |
290 | 0 | return inputSplitList; |
291 | |
} |
292 | |
} |
293 | |
} |