Skip to content

Commit

Permalink
added ArraySort
Browse files Browse the repository at this point in the history
  • Loading branch information
sivakumar-mahalingam committed Jul 23, 2023
1 parent 62cbb0e commit 8ee6823
Show file tree
Hide file tree
Showing 5 changed files with 284 additions and 6 deletions.
2 changes: 1 addition & 1 deletion src/main/java/mercury/udf/collect/ArrayIntersection.java
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ public class ArrayIntersection extends GenericUDF {
@Override
public ObjectInspector initialize(ObjectInspector[] arguments) throws UDFArgumentException {
if (arguments.length < 2) {
throw new UDFArgumentException("IntersectionArraysUDF requires at least 2 arguments");
throw new UDFArgumentException("ArrayIntersection requires at least 2 arguments");
}

for (int i = 0; i < arguments.length; i++) {
Expand Down
77 changes: 77 additions & 0 deletions src/main/java/mercury/udf/collect/ArraySort.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
package mercury.udf.collect;

import org.apache.hadoop.hive.ql.exec.Description;
import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
import org.apache.hadoop.hive.ql.metadata.HiveException;
import org.apache.hadoop.hive.ql.udf.generic.GenericUDF;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils;
import org.apache.hadoop.hive.serde2.objectinspector.StandardListObjectInspector;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;

/**
 * UDF that returns a sorted copy of a list.
 *
 * <p>Usage: {@code array_sort(list)} or {@code array_sort(list, order)}, where {@code order} is
 * {@code 'asc'} (the default) or {@code 'desc'}, matched case-insensitively.
 *
 * <p>A NULL list yields NULL. NULL elements are placed first in ascending order and last in
 * descending order. Non-null elements are ordered through {@link Comparable}.
 *
 * @author: Sivakumar Mahalingam
 */

@Description(name = "array_sort", value = "_FUNC_(list[, order]) - Returns the list sorted in ascending order, or in descending order when order is 'desc'")
public class ArraySort extends GenericUDF {
    private static final String ASCENDING = "asc";
    private static final String DESCENDING = "desc";

    /** Inspector of the list argument; set by {@link #initialize} and used by {@link #evaluate}. */
    private ObjectInspector listObjectInspector;

    /**
     * Validates the argument inspectors and declares the return type.
     *
     * @param arguments one inspector (the list) or two (the list and the sort order)
     * @return the standard inspector matching the objects {@link #evaluate} produces
     * @throws UDFArgumentException if the argument count is not 1 or 2, the first argument is
     *         not a list, or the second argument is not a primitive
     */
    @Override
    public ObjectInspector initialize(ObjectInspector[] arguments) throws UDFArgumentException {
        if (arguments.length < 1) {
            throw new UDFArgumentException("ArraySort requires at least 1 argument");
        }
        if (arguments.length > 2) {
            throw new UDFArgumentException("ArraySort requires at most 2 arguments");
        }
        // The list check applies to the one-argument form as well, so a bad call fails with a
        // UDFArgumentException rather than a ClassCastException.
        if (!arguments[0].getCategory().equals(ObjectInspector.Category.LIST)) {
            throw new UDFArgumentException("Argument 1 must be a list");
        }
        if (arguments.length == 2
                && !arguments[1].getCategory().equals(ObjectInspector.Category.PRIMITIVE)) {
            throw new UDFArgumentException("Argument 2 must be a string");
        }

        // Keep the inspector as supplied (no cast to a concrete list inspector class) and derive
        // the return inspector with the same utility that evaluate() uses to copy the data, so
        // the declared type and the returned objects always agree.
        listObjectInspector = arguments[0];
        return ObjectInspectorUtils.getStandardObjectInspector(listObjectInspector);
    }

    /**
     * Sorts the list passed as the first argument.
     *
     * @param arguments the list and, optionally, the sort order
     * @return a new sorted list, or {@code null} when the input list is NULL
     * @throws HiveException if a sort order is given and is not 'asc' or 'desc'
     */
    @Override
    public Object evaluate(DeferredObject[] arguments) throws HiveException {
        // Validate the order first so an invalid value is reported even for rows with a NULL list.
        boolean descending = arguments.length == 2 && isDescending(arguments[1].get());

        Object standardList = ObjectInspectorUtils.copyToStandardObject(arguments[0].get(), listObjectInspector);
        if (standardList == null) {
            return null;
        }

        // Copy into a mutable list that tolerates null elements.
        List<Object> sortedList = new ArrayList<>((List<?>) standardList);
        if (descending) {
            sortedList.sort((left, right) -> compareNullsFirst(right, left));
        } else {
            sortedList.sort(ArraySort::compareNullsFirst);
        }
        return sortedList;
    }

    /**
     * Builds the text shown for this call in query plans.
     *
     * @param strings display strings of the operands
     * @return {@code array_sort(...)} with the operands separated by commas
     */
    @Override
    public String getDisplayString(String[] strings) {
        return "array_sort(" + String.join(", ", strings) + ")";
    }

    /**
     * Interprets the sort-order argument.
     *
     * <p>The value is read through {@code toString()} instead of a {@code String} cast because
     * the runtime object is not necessarily a {@code java.lang.String}.
     */
    private static boolean isDescending(Object rawOrder) throws HiveException {
        if (rawOrder != null) {
            String order = rawOrder.toString();
            if (ASCENDING.equalsIgnoreCase(order)) {
                return false;
            }
            if (DESCENDING.equalsIgnoreCase(order)) {
                return true;
            }
        }
        throw new HiveException("Array sort order should be either 'asc' or 'desc'");
    }

    /** Natural-order comparison that sorts {@code null} before every non-null value. */
    @SuppressWarnings({"unchecked", "rawtypes"})
    private static int compareNullsFirst(Object left, Object right) {
        if (left == null) {
            return right == null ? 0 : -1;
        }
        if (right == null) {
            return 1;
        }
        // Same contract as Arrays.sort(Object[]): elements must be mutually Comparable.
        return ((Comparable) left).compareTo(right);
    }
}
2 changes: 1 addition & 1 deletion src/main/java/mercury/udf/collect/ArrayUnion.java
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ public class ArrayUnion extends GenericUDF {
@Override
public ObjectInspector initialize(ObjectInspector[] arguments) throws UDFArgumentException {
if (arguments.length < 2) {
throw new UDFArgumentException("MergeArraysUDF requires at least 2 arguments");
throw new UDFArgumentException("ArrayUnion requires at least 2 arguments");
}

for (int i = 0; i < arguments.length; i++) {
Expand Down
8 changes: 4 additions & 4 deletions src/test/java/mercury/udf/collect/ArrayIntersectionTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ public void testInitializeWithTwoArguments() throws UDFArgumentException {
}

/*
* {1, 2} {1, 2} = {1, 2}
* {1, 2} {1, 2} = {1, 2}
*/
@Test
public void testEvaluateWithSameElements() throws HiveException {
Expand All @@ -68,7 +68,7 @@ public void testEvaluateWithSameElements() throws HiveException {
}

/*
* {1, 2} {"A", "B"} = Error
* {1, 2} {"A", "B"} = {}
*/
@Test
public void testEvaluateWithDifferentTypes() throws HiveException {
Expand All @@ -89,7 +89,7 @@ public void testEvaluateWithDifferentTypes() throws HiveException {
}

/*
* {1, 2} {2, 3} = {1, 2, 3}
* {1, 2} {2, 3} = {2}
*/
@Test
public void testEvaluateWithTwoArrays() throws HiveException {
Expand All @@ -111,7 +111,7 @@ public void testEvaluateWithTwoArrays() throws HiveException {
}

/*
* {1, 2} {2, 5} u {3, 4} = {1, 2, 3, 4, 5}
* {1, 2} {2, 5} {3, 4} = {2, 4}
*/
@Test
public void testEvaluateWithThreeArrays() throws HiveException {
Expand Down
201 changes: 201 additions & 0 deletions src/test/java/mercury/udf/collect/ArraySortTest.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,201 @@
package mercury.udf.collect;

import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
import org.apache.hadoop.hive.ql.metadata.HiveException;
import org.apache.hadoop.hive.ql.udf.generic.GenericUDF;
import org.apache.hadoop.hive.ql.udf.generic.GenericUDF.DeferredJavaObject;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory;
import org.apache.hadoop.hive.serde2.objectinspector.StandardListObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
import org.junit.Before;
import org.junit.Test;

import java.util.ArrayList;
import java.util.List;

import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue;

/**
 * Unit tests for {@link ArraySort}: argument validation in {@code initialize} and sorting of
 * integer and string lists in ascending, descending and default order.
 */
public class ArraySortTest {
    private ArraySort udf;

    @Before
    public void before() {
        udf = new ArraySort();
    }

    @Test(expected = UDFArgumentException.class)
    public void testInitializeWithNoArguments() throws UDFArgumentException {
        udf.initialize(new ObjectInspector[0]);
    }

    @Test
    public void testInitializeWithOneArgument() throws UDFArgumentException {
        ObjectInspector stringOi = PrimitiveObjectInspectorFactory.javaStringObjectInspector;
        ObjectInspector listOi = ObjectInspectorFactory.getStandardListObjectInspector(stringOi);
        udf.initialize(new ObjectInspector[]{listOi});
    }

    @Test
    public void testInitializeWithTwoArguments() throws UDFArgumentException {
        ObjectInspector stringOi = PrimitiveObjectInspectorFactory.javaStringObjectInspector;
        ObjectInspector listOi = ObjectInspectorFactory.getStandardListObjectInspector(stringOi);
        udf.initialize(new ObjectInspector[]{listOi, stringOi});
    }

    /*
     * array_sort([1, 3, 2], asc) = [1, 2, 3]
     */
    @Test
    public void testEvaluateIntegersAscending() throws HiveException {
        assertSorted(PrimitiveObjectInspectorFactory.javaIntObjectInspector,
                List.of(1, 3, 2), "asc", List.of(1, 2, 3));
    }

    /*
     * array_sort([1, 3, 2], desc) = [3, 2, 1]
     */
    @Test
    public void testEvaluateIntegersDescending() throws HiveException {
        assertSorted(PrimitiveObjectInspectorFactory.javaIntObjectInspector,
                List.of(1, 3, 2), "desc", List.of(3, 2, 1));
    }

    /*
     * array_sort([1, 3, 2]) = [1, 2, 3]
     */
    @Test
    public void testEvaluateIntegersNoArgument() throws HiveException {
        assertSorted(PrimitiveObjectInspectorFactory.javaIntObjectInspector,
                List.of(1, 3, 2), null, List.of(1, 2, 3));
    }

    /*
     * array_sort([a, c, b], asc) = [a, b, c]
     */
    @Test
    public void testEvaluateStringsAscending() throws HiveException {
        assertSorted(PrimitiveObjectInspectorFactory.javaStringObjectInspector,
                List.of("a", "c", "b"), "asc", List.of("a", "b", "c"));
    }

    /*
     * array_sort([a, c, b], desc) = [c, b, a]
     */
    @Test
    public void testEvaluateStringsDescending() throws HiveException {
        assertSorted(PrimitiveObjectInspectorFactory.javaStringObjectInspector,
                List.of("a", "c", "b"), "desc", List.of("c", "b", "a"));
    }

    /*
     * array_sort([a, c, b]) = [a, b, c]
     */
    @Test
    public void testEvaluateStringsNoArgument() throws HiveException {
        assertSorted(PrimitiveObjectInspectorFactory.javaStringObjectInspector,
                List.of("a", "c", "b"), null, List.of("a", "b", "c"));
    }

    /**
     * Initializes the UDF for a list of {@code elementOi} elements, evaluates it on
     * {@code input} and checks the result against {@code expected}.
     *
     * @param sortOrder the order argument, or {@code null} to call the one-argument form;
     *        in that case the UDF is also initialized with a single inspector
     */
    private void assertSorted(ObjectInspector elementOi, List<?> input, String sortOrder, List<?> expected)
            throws HiveException {
        ObjectInspector listOi = ObjectInspectorFactory.getStandardListObjectInspector(elementOi);

        ObjectInspector[] inspectors;
        GenericUDF.DeferredObject[] arguments;
        if (sortOrder == null) {
            inspectors = new ObjectInspector[]{listOi};
            arguments = new GenericUDF.DeferredObject[]{new DeferredJavaObject(input)};
        } else {
            inspectors = new ObjectInspector[]{listOi, PrimitiveObjectInspectorFactory.javaStringObjectInspector};
            arguments = new GenericUDF.DeferredObject[]{new DeferredJavaObject(input), new DeferredJavaObject(sortOrder)};
        }

        StandardListObjectInspector resultOi = (StandardListObjectInspector) udf.initialize(inspectors);
        Object result = udf.evaluate(arguments);

        assertEquals(expected.size(), resultOi.getListLength(result));
        assertEquals(expected, resultOi.getList(result));
    }
}

0 comments on commit 8ee6823

Please sign in to comment.