Skip to content

Commit

Permalink
added ArraySubtract
Browse files Browse the repository at this point in the history
  • Loading branch information
sivakumar-mahalingam committed Oct 24, 2023
1 parent 14c84d0 commit 7a54175
Show file tree
Hide file tree
Showing 3 changed files with 154 additions and 2 deletions.
60 changes: 60 additions & 0 deletions src/main/java/mercury/udf/collect/ArraySubtract.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
package mercury.udf.collect;

import org.apache.hadoop.hive.ql.exec.Description;
import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
import org.apache.hadoop.hive.ql.metadata.HiveException;
import org.apache.hadoop.hive.ql.udf.generic.GenericUDF;
import org.apache.hadoop.hive.serde2.objectinspector.ListObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils;
import org.apache.hadoop.hive.serde2.objectinspector.StandardListObjectInspector;

import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Set;

/**
* UDF for subtracting one list from another
*
* @author: Sivakumar Mahalingam
*/

@Description(name = "array_subtract", value = "_FUNC_(a,b) - Returns a list of items from one list and not in second")
public class ArraySubtract extends GenericUDF {
private StandardListObjectInspector standardListObjectInspector;

@Override
public ObjectInspector initialize(ObjectInspector[] arguments) throws UDFArgumentException {
if (arguments.length < 2) {
throw new UDFArgumentException("ArraySubtract requires at least 2 arguments");
}

for (int i = 0; i < arguments.length; i++) {
if (!arguments[i].getCategory().equals(ObjectInspector.Category.LIST)) {
throw new UDFArgumentException("Argument " + i + " must be a list");
}
}

standardListObjectInspector = (StandardListObjectInspector) arguments[0];
return ObjectInspectorFactory.getStandardListObjectInspector(standardListObjectInspector.getListElementObjectInspector());
}

@Override
public Object evaluate(DeferredObject[] arguments) throws HiveException {
List<Object> subtractList = new ArrayList<>();
List<Object> list1 = (List<Object>) arguments[0].get();
List<Object> list2 = (List<Object>) arguments[1].get();

for (Object element : list1) {
if (!list2.contains(element)) {
subtractList.add(element);
}
}
return subtractList;
}

@Override
public String getDisplayString(String[] strings) {
return "ArraySubtract";
}
}
5 changes: 3 additions & 2 deletions src/main/resources/mercury.hql
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
--collect
CREATE FUNCTION array_intersection AS 'mercury.udf.collect.ArrayIntersection';
CREATE FUNCTION array_union AS 'mercury.udf.collect.ArrayUnion';
CREATE FUNCTION array_sort AS 'mercury.udf.collect.ArraySort';

CREATE FUNCTION array_subtract AS 'mercury.udf.collect.ArraySubtract';
CREATE FUNCTION array_union AS 'mercury.udf.collect.ArrayUnion';
--{A}=={B}

--statistics
CREATE FUNCTION jaccard_similarity AS 'mercury.udf.statistics.JaccardSimilarity';
Expand Down
91 changes: 91 additions & 0 deletions src/test/java/mercury/udf/collect/ArraySubtractTest.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
package mercury.udf.collect;

import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
import org.apache.hadoop.hive.ql.metadata.HiveException;
import org.apache.hadoop.hive.ql.udf.generic.GenericUDF;
import org.apache.hadoop.hive.ql.udf.generic.GenericUDF.DeferredJavaObject;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory;
import org.apache.hadoop.hive.serde2.objectinspector.StandardListObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
import org.junit.Before;
import org.junit.Test;

import java.util.ArrayList;
import java.util.List;

import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue;

/**
 * Unit tests for {@link ArraySubtract}: argument validation performed by
 * {@code initialize} and the element removal performed by {@code evaluate}.
 */
public class ArraySubtractTest {
    private ArraySubtract udf;

    @Before
    public void before() {
        udf = new ArraySubtract();
    }

    @Test(expected = UDFArgumentException.class)
    public void testInitializeWithNoArguments() throws UDFArgumentException {
        ObjectInspector[] noInspectors = new ObjectInspector[0];
        udf.initialize(noInspectors);
    }

    @Test(expected = UDFArgumentException.class)
    public void testInitializeWithOneArgument() throws UDFArgumentException {
        ObjectInspector stringList = listInspectorOf(PrimitiveObjectInspectorFactory.javaStringObjectInspector);
        udf.initialize(new ObjectInspector[]{stringList});
    }

    @Test
    public void testInitializeWithTwoArguments() throws UDFArgumentException {
        ObjectInspector stringList = listInspectorOf(PrimitiveObjectInspectorFactory.javaStringObjectInspector);
        udf.initialize(new ObjectInspector[]{stringList, stringList});
    }

    /*
     * {1, 2} - {1, 2} = {}
     */
    @Test
    public void testEvaluateWithSameElements() throws HiveException {
        StandardListObjectInspector resultInspector = initializeForIntLists();

        Object difference = subtract(integers(1, 2), integers(1, 2));

        assertEquals(0, resultInspector.getListLength(difference));
    }

    /*
     * {1, 2, 3} - {1, 2, 4} = {3}
     */
    @Test
    public void testEvaluateWithTwoArrays() throws HiveException {
        StandardListObjectInspector resultInspector = initializeForIntLists();

        Object difference = subtract(integers(1, 2, 3), integers(1, 2, 4));

        assertEquals(1, resultInspector.getListLength(difference));
        assertTrue(resultInspector.getList(difference).contains(3));
    }

    /** Wraps an element inspector into a standard list inspector. */
    private static ObjectInspector listInspectorOf(ObjectInspector elementInspector) {
        return ObjectInspectorFactory.getStandardListObjectInspector(elementInspector);
    }

    /** Builds a mutable list holding the given values in order. */
    private static List<Integer> integers(int... values) {
        List<Integer> boxed = new ArrayList<Integer>();
        for (int value : values) {
            boxed.add(value);
        }
        return boxed;
    }

    /** Initializes the UDF for two int-list arguments and returns its result inspector. */
    private StandardListObjectInspector initializeForIntLists() throws UDFArgumentException {
        ObjectInspector intList = listInspectorOf(PrimitiveObjectInspectorFactory.javaIntObjectInspector);
        return (StandardListObjectInspector) udf.initialize(new ObjectInspector[]{intList, intList});
    }

    /** Evaluates the UDF on the two given lists. */
    private Object subtract(List<Integer> minuend, List<Integer> subtrahend) throws HiveException {
        GenericUDF.DeferredObject[] deferred = new GenericUDF.DeferredObject[]{
                new DeferredJavaObject(minuend),
                new DeferredJavaObject(subtrahend)
        };
        return udf.evaluate(deferred);
    }
}

0 comments on commit 7a54175

Please sign in to comment.