/** * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file * distributed with this work for additional information * regarding copyright ownership. The ASF licenses this file * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package org.apache.hadoop.hive.ql.exec.vector.expressions.aggregates.gen; import java.util.ArrayList; import java.util.List; import org.apache.hadoop.hive.ql.exec.Description; import org.apache.hadoop.hive.ql.exec.vector.expressions.VectorExpression; import org.apache.hadoop.hive.ql.exec.vector.expressions.aggregates.VectorAggregateExpression; import org.apache.hadoop.hive.ql.exec.vector.VectorAggregationBufferRow; import org.apache.hadoop.hive.ql.exec.vector.VectorizedRowBatch; import org.apache.hadoop.hive.ql.exec.vector.LongColumnVector; import org.apache.hadoop.hive.ql.exec.vector.DoubleColumnVector; import org.apache.hadoop.hive.ql.metadata.HiveException; import org.apache.hadoop.hive.ql.plan.AggregationDesc; import org.apache.hadoop.hive.ql.util.JavaDataModel; import org.apache.hadoop.io.LongWritable; import org.apache.hadoop.hive.serde2.io.DoubleWritable; import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector; import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory; import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory; /** * . Vectorized implementation for VARIANCE aggregates. */ @Description(name = "", value = "") public class extends VectorAggregateExpression { private static final long serialVersionUID = 1L; /** /* class for storing the current aggregate value. */ private static final class Aggregation implements AggregationBuffer { private static final long serialVersionUID = 1L; transient private double sum; transient private long count; transient private double variance; /** * Value is explicitly (re)initialized in reset() (despite the init() bellow...) */ transient private boolean isNull = true; public void init() { isNull = false; sum = 0; count = 0; variance = 0; } @Override public int getVariableSize() { throw new UnsupportedOperationException(); } @Override public void reset () { isNull = true; sum = 0; count = 0; variance = 0; } } private VectorExpression inputExpression; transient private LongWritable resultCount; transient private DoubleWritable resultSum; transient private DoubleWritable resultVariance; transient private Object[] partialResult; transient private ObjectInspector soi; public (VectorExpression inputExpression) { this(); this.inputExpression = inputExpression; } public () { super(); partialResult = new Object[3]; resultCount = new LongWritable(); resultSum = new DoubleWritable(); resultVariance = new DoubleWritable(); partialResult[0] = resultCount; partialResult[1] = resultSum; partialResult[2] = resultVariance; initPartialResultInspector(); } private void initPartialResultInspector() { List foi = new ArrayList(); foi.add(PrimitiveObjectInspectorFactory.writableLongObjectInspector); foi.add(PrimitiveObjectInspectorFactory.writableDoubleObjectInspector); foi.add(PrimitiveObjectInspectorFactory.writableDoubleObjectInspector); List fname = new ArrayList(); fname.add("count"); fname.add("sum"); fname.add("variance"); soi = ObjectInspectorFactory.getStandardStructObjectInspector(fname, foi); } private Aggregation getCurrentAggregationBuffer( VectorAggregationBufferRow[] aggregationBufferSets, int aggregateIndex, int row) { VectorAggregationBufferRow mySet = aggregationBufferSets[row]; Aggregation myagg = (Aggregation) mySet.getAggregationBuffer(aggregateIndex); return myagg; } @Override public void aggregateInputSelection( VectorAggregationBufferRow[] aggregationBufferSets, int aggregateIndex, VectorizedRowBatch batch) throws HiveException { inputExpression.evaluate(batch); inputVector = ()batch. cols[this.inputExpression.getOutputColumn()]; int batchSize = batch.size; if (batchSize == 0) { return; } [] vector = inputVector.vector; if (inputVector.isRepeating) { if (inputVector.noNulls || !inputVector.isNull[0]) { iterateRepeatingNoNullsWithAggregationSelection( aggregationBufferSets, aggregateIndex, vector[0], batchSize); } } else if (!batch.selectedInUse && inputVector.noNulls) { iterateNoSelectionNoNullsWithAggregationSelection( aggregationBufferSets, aggregateIndex, vector, batchSize); } else if (!batch.selectedInUse) { iterateNoSelectionHasNullsWithAggregationSelection( aggregationBufferSets, aggregateIndex, vector, batchSize, inputVector.isNull); } else if (inputVector.noNulls){ iterateSelectionNoNullsWithAggregationSelection( aggregationBufferSets, aggregateIndex, vector, batchSize, batch.selected); } else { iterateSelectionHasNullsWithAggregationSelection( aggregationBufferSets, aggregateIndex, vector, batchSize, inputVector.isNull, batch.selected); } } private void iterateRepeatingNoNullsWithAggregationSelection( VectorAggregationBufferRow[] aggregationBufferSets, int aggregateIndex, double value, int batchSize) { for (int i=0; i 1) { double t = myagg.count*value - myagg.sum; myagg.variance += (t*t) / ((double)myagg.count*(myagg.count-1)); } } } private void iterateSelectionHasNullsWithAggregationSelection( VectorAggregationBufferRow[] aggregationBufferSets, int aggregateIndex, [] vector, int batchSize, boolean[] isNull, int[] selected) { for (int j=0; j< batchSize; ++j) { Aggregation myagg = getCurrentAggregationBuffer( aggregationBufferSets, aggregateIndex, j); int i = selected[j]; if (!isNull[i]) { double value = vector[i]; if (myagg.isNull) { myagg.init (); } myagg.sum += value; myagg.count += 1; if(myagg.count > 1) { double t = myagg.count*value - myagg.sum; myagg.variance += (t*t) / ((double)myagg.count*(myagg.count-1)); } } } } private void iterateSelectionNoNullsWithAggregationSelection( VectorAggregationBufferRow[] aggregationBufferSets, int aggregateIndex, [] vector, int batchSize, int[] selected) { for (int i=0; i< batchSize; ++i) { Aggregation myagg = getCurrentAggregationBuffer( aggregationBufferSets, aggregateIndex, i); double value = vector[selected[i]]; if (myagg.isNull) { myagg.init (); } myagg.sum += value; myagg.count += 1; if(myagg.count > 1) { double t = myagg.count*value - myagg.sum; myagg.variance += (t*t) / ((double)myagg.count*(myagg.count-1)); } } } private void iterateNoSelectionHasNullsWithAggregationSelection( VectorAggregationBufferRow[] aggregationBufferSets, int aggregateIndex, [] vector, int batchSize, boolean[] isNull) { for(int i=0;i 1) { double t = myagg.count*value - myagg.sum; myagg.variance += (t*t) / ((double)myagg.count*(myagg.count-1)); } } } } private void iterateNoSelectionNoNullsWithAggregationSelection( VectorAggregationBufferRow[] aggregationBufferSets, int aggregateIndex, [] vector, int batchSize) { for (int i=0; i 1) { double t = myagg.count*value - myagg.sum; myagg.variance += (t*t) / ((double)myagg.count*(myagg.count-1)); } } } @Override public void aggregateInput(AggregationBuffer agg, VectorizedRowBatch batch) throws HiveException { inputExpression.evaluate(batch); inputVector = ()batch. cols[this.inputExpression.getOutputColumn()]; int batchSize = batch.size; if (batchSize == 0) { return; } Aggregation myagg = (Aggregation)agg; [] vector = inputVector.vector; if (inputVector.isRepeating) { if (inputVector.noNulls) { iterateRepeatingNoNulls(myagg, vector[0], batchSize); } } else if (!batch.selectedInUse && inputVector.noNulls) { iterateNoSelectionNoNulls(myagg, vector, batchSize); } else if (!batch.selectedInUse) { iterateNoSelectionHasNulls(myagg, vector, batchSize, inputVector.isNull); } else if (inputVector.noNulls){ iterateSelectionNoNulls(myagg, vector, batchSize, batch.selected); } else { iterateSelectionHasNulls(myagg, vector, batchSize, inputVector.isNull, batch.selected); } } private void iterateRepeatingNoNulls( Aggregation myagg, double value, int batchSize) { if (myagg.isNull) { myagg.init (); } // TODO: conjure a formula w/o iterating // myagg.sum += value; myagg.count += 1; if(myagg.count > 1) { double t = myagg.count*value - myagg.sum; myagg.variance += (t*t) / ((double)myagg.count*(myagg.count-1)); } // We pulled out i=0 so we can remove the count > 1 check in the loop for (int i=1; i[] vector, int batchSize, boolean[] isNull, int[] selected) { for (int j=0; j< batchSize; ++j) { int i = selected[j]; if (!isNull[i]) { double value = vector[i]; if (myagg.isNull) { myagg.init (); } myagg.sum += value; myagg.count += 1; if(myagg.count > 1) { double t = myagg.count*value - myagg.sum; myagg.variance += (t*t) / ((double)myagg.count*(myagg.count-1)); } } } } private void iterateSelectionNoNulls( Aggregation myagg, [] vector, int batchSize, int[] selected) { if (myagg.isNull) { myagg.init (); } double value = vector[selected[0]]; myagg.sum += value; myagg.count += 1; if(myagg.count > 1) { double t = myagg.count*value - myagg.sum; myagg.variance += (t*t) / ((double)myagg.count*(myagg.count-1)); } // i=0 was pulled out to remove the count > 1 check in the loop // for (int i=1; i< batchSize; ++i) { value = vector[selected[i]]; myagg.sum += value; myagg.count += 1; double t = myagg.count*value - myagg.sum; myagg.variance += (t*t) / ((double)myagg.count*(myagg.count-1)); } } private void iterateNoSelectionHasNulls( Aggregation myagg, [] vector, int batchSize, boolean[] isNull) { for(int i=0;i 1) { double t = myagg.count*value - myagg.sum; myagg.variance += (t*t) / ((double)myagg.count*(myagg.count-1)); } } } } private void iterateNoSelectionNoNulls( Aggregation myagg, [] vector, int batchSize) { if (myagg.isNull) { myagg.init (); } double value = vector[0]; myagg.sum += value; myagg.count += 1; if(myagg.count > 1) { double t = myagg.count*value - myagg.sum; myagg.variance += (t*t) / ((double)myagg.count*(myagg.count-1)); } // i=0 was pulled out to remove count > 1 check for (int i=1; i