diff --git a/src/main/java/mercury/udf/collect/ArraySubtract.java b/src/main/java/mercury/udf/collect/ArraySubtract.java new file mode 100644 index 0000000..91d0504 --- /dev/null +++ b/src/main/java/mercury/udf/collect/ArraySubtract.java @@ -0,0 +1,60 @@ +package mercury.udf.collect; + +import org.apache.hadoop.hive.ql.exec.Description; +import org.apache.hadoop.hive.ql.exec.UDFArgumentException; +import org.apache.hadoop.hive.ql.metadata.HiveException; +import org.apache.hadoop.hive.ql.udf.generic.GenericUDF; +import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector; +import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory; +import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils; +import org.apache.hadoop.hive.serde2.objectinspector.StandardListObjectInspector; + +import java.util.ArrayList; +import java.util.HashSet; +import java.util.List; + +/** + * UDF for subtracting one list from another + * + * @author: Sivakumar Mahalingam + */ + +@Description(name = "array_subtract", value = "_FUNC_(a,b) - Returns a list of items from one list and not in second") +public class ArraySubtract extends GenericUDF { + private StandardListObjectInspector standardListObjectInspector; + + @Override + public ObjectInspector initialize(ObjectInspector[] arguments) throws UDFArgumentException { + if (arguments.length < 2) { + throw new UDFArgumentException("ArraySubtract requires at least 2 arguments"); + } + + for (int i = 0; i < arguments.length; i++) { + if (!arguments[i].getCategory().equals(ObjectInspector.Category.LIST)) { + throw new UDFArgumentException("Argument " + i + " must be a list"); + } + } + + standardListObjectInspector = (StandardListObjectInspector) arguments[0]; + return ObjectInspectorFactory.getStandardListObjectInspector(standardListObjectInspector.getListElementObjectInspector()); + } + + @Override + public Object evaluate(DeferredObject[] arguments) throws HiveException { + List subtractList = new ArrayList<>(); + List list1 = (List) arguments[0].get(); + List list2 = (List) arguments[1].get(); + + for (Object element : list1) { + if (!list2.contains(element)) { + subtractList.add(element); + } + } + return subtractList; + } + + @Override + public String getDisplayString(String[] strings) { + return "ArraySubtract"; + } +} diff --git a/src/main/resources/mercury.hql b/src/main/resources/mercury.hql index 291fc37..0225216 100644 --- a/src/main/resources/mercury.hql +++ b/src/main/resources/mercury.hql @@ -1,8 +1,9 @@ --collect CREATE FUNCTION array_intersection AS 'mercury.udf.collect.ArrayIntersection'; -CREATE FUNCTION array_union AS 'mercury.udf.collect.ArrayUnion'; CREATE FUNCTION array_sort AS 'mercury.udf.collect.ArraySort'; - +CREATE FUNCTION array_subtract AS 'mercury.udf.collect.ArraySubtract'; +CREATE FUNCTION array_union AS 'mercury.udf.collect.ArrayUnion'; +--{A}=={B} --statistics CREATE FUNCTION jaccard_similarity AS 'mercury.udf.statistics.JaccardSimilarity'; diff --git a/src/test/java/mercury/udf/collect/ArraySubtractTest.java b/src/test/java/mercury/udf/collect/ArraySubtractTest.java new file mode 100644 index 0000000..499bd57 --- /dev/null +++ b/src/test/java/mercury/udf/collect/ArraySubtractTest.java @@ -0,0 +1,91 @@ +package mercury.udf.collect; + +import org.apache.hadoop.hive.ql.exec.UDFArgumentException; +import org.apache.hadoop.hive.ql.metadata.HiveException; +import org.apache.hadoop.hive.ql.udf.generic.GenericUDF; +import org.apache.hadoop.hive.ql.udf.generic.GenericUDF.DeferredJavaObject; +import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector; +import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory; +import org.apache.hadoop.hive.serde2.objectinspector.StandardListObjectInspector; +import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory; +import org.junit.Before; +import org.junit.Test; + +import java.util.ArrayList; +import java.util.List; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertTrue; + +public class ArraySubtractTest { + private ArraySubtract udf; + + @Before + public void before() { + udf = new ArraySubtract(); + } + + @Test(expected = UDFArgumentException.class) + public void testInitializeWithNoArguments() throws UDFArgumentException { + udf.initialize(new ObjectInspector[0]); + } + + @Test(expected = UDFArgumentException.class) + public void testInitializeWithOneArgument() throws UDFArgumentException { + ObjectInspector stringOi = PrimitiveObjectInspectorFactory.javaStringObjectInspector; + ObjectInspector listOi = ObjectInspectorFactory.getStandardListObjectInspector(stringOi); + udf.initialize(new ObjectInspector[]{listOi}); + } + + @Test + public void testInitializeWithTwoArguments() throws UDFArgumentException { + ObjectInspector stringOi = PrimitiveObjectInspectorFactory.javaStringObjectInspector; + ObjectInspector listOi = ObjectInspectorFactory.getStandardListObjectInspector(stringOi); + udf.initialize(new ObjectInspector[]{listOi, listOi}); + } + + /* + * {1, 2} - {1, 2} = {} + */ + @Test + public void testEvaluateWithSameElements() throws HiveException { + ObjectInspector intOi = PrimitiveObjectInspectorFactory.javaIntObjectInspector; + ObjectInspector listOi = ObjectInspectorFactory.getStandardListObjectInspector(intOi); + StandardListObjectInspector resultOi = (StandardListObjectInspector) udf.initialize(new ObjectInspector[]{listOi, listOi}); + + List one = new ArrayList(); + one.add(1); + one.add(2); + + List two = new ArrayList(); + two.add(1); + two.add(2); + + Object result = udf.evaluate(new GenericUDF.DeferredObject[]{new DeferredJavaObject(one), new DeferredJavaObject(two)}); + assertEquals(0, resultOi.getListLength(result)); + } + + /* + * {1, 2, 3} - {1, 2, 4} = {3} + */ + @Test + public void testEvaluateWithTwoArrays() throws HiveException { + ObjectInspector intOi = PrimitiveObjectInspectorFactory.javaIntObjectInspector; + ObjectInspector listOi = ObjectInspectorFactory.getStandardListObjectInspector(intOi); + StandardListObjectInspector resultOi = (StandardListObjectInspector) udf.initialize(new ObjectInspector[]{listOi, listOi}); + + List one = new ArrayList(); + one.add(1); + one.add(2); + one.add(3); + + List two = new ArrayList(); + two.add(1); + two.add(2); + two.add(4); + + Object result = udf.evaluate(new GenericUDF.DeferredObject[]{new DeferredJavaObject(one), new DeferredJavaObject(two)}); + assertEquals(1, resultOi.getListLength(result)); + assertTrue(resultOi.getList(result).contains(3)); + } +}