Skip to content

Commit

Permalink
[SPARK-19843][SQL] UTF8String => (int / long) conversion expensive fo…
Browse files Browse the repository at this point in the history
…r invalid inputs

## What changes were proposed in this pull request?

Jira : https://issues.apache.org/jira/browse/SPARK-19843

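Created wrapper classes (`IntWrapper`, `LongWrapper`) to hold the result of parsing (which is a primitive type). Instead of throwing a `NumberFormatException` on invalid input, the parsing methods now return a boolean: `true` if parsing succeeded (the parsed value is stored in the wrapper), `false` otherwise.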

## How was this patch tested?

- Added new unit tests
- Ran a prod job which had conversion from string -> int and verified the outputs

## Performance

No meaningful change when all strings are valid integers (the difference is within noise)

```
conversion to int:       Best/Avg Time(ms)    Rate(M/s)   Per Row(ns)   Relative
--------------------------------------------------------------------------------
trunk                         502 /  522         33.4          29.9       1.0X
SPARK-19843                   493 /  503         34.0          29.4       1.0X
```

Huge gain when all strings are invalid integers
```
conversion to int:      Best/Avg Time(ms)    Rate(M/s)   Per Row(ns)   Relative
-------------------------------------------------------------------------------
trunk                     33913 / 34219          0.5        2021.4       1.0X
SPARK-19843                  154 /  162        108.8           9.2     220.0X
```

Author: Tejas Patil <[email protected]>

Closes apache#17184 from tejasapatil/SPARK-19843_is_numeric_maybe.
  • Loading branch information
tejasapatil authored and cloud-fan committed Mar 8, 2017
1 parent 47b2f68 commit c96d14a
Show file tree
Hide file tree
Showing 3 changed files with 247 additions and 82 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -850,26 +850,27 @@ public UTF8String translate(Map<Character, Character> dict) {
return fromString(sb.toString());
}

private int getDigit(byte b) {
if (b >= '0' && b <= '9') {
return b - '0';
}
throw new NumberFormatException(toString());
/**
 * Mutable holder for a primitive {@code long}. {@code toLong(LongWrapper)} stores the parsed
 * number in {@code value} when parsing succeeds; the field starts at 0.
 */
public static class LongWrapper {
public long value = 0;
}

/**
* Parses this UTF8String to long.
*
* Note that, in this method we accumulate the result in negative format, and convert it to
* positive format at the end, if this string is not started with '-'. This is because min value
* is bigger than max value in digits, e.g. Integer.MAX_VALUE is '2147483647' and
* Integer.MIN_VALUE is '-2147483648'.
* is bigger than max value in digits, e.g. Long.MAX_VALUE is '9223372036854775807' and
* Long.MIN_VALUE is '-9223372036854775808'.
*
* This code is mostly copied from LazyLong.parseLong in Hive.
*
* @param toLongResult If a valid `long` was parsed from this UTF8String, then its value would
* be set in `toLongResult`
* @return true if the parsing was successful else false
*/
public long toLong() {
public boolean toLong(LongWrapper toLongResult) {
if (numBytes == 0) {
throw new NumberFormatException("Empty string");
return false;
}

byte b = getByte(0);
Expand All @@ -878,7 +879,7 @@ public long toLong() {
if (negative || b == '+') {
offset++;
if (numBytes == 1) {
throw new NumberFormatException(toString());
return false;
}
}

Expand All @@ -897,41 +898,52 @@ public long toLong() {
break;
}

int digit = getDigit(b);
int digit;
if (b >= '0' && b <= '9') {
digit = b - '0';
} else {
return false;
}

// We are going to process the new digit and accumulate the result. However, before doing
// this, if the result is already smaller than the stopValue(Long.MIN_VALUE / radix), then
// result * 10 will definitely be smaller than minValue, and we can stop and throw exception.
// result * 10 will definitely be smaller than minValue, and we can stop.
if (result < stopValue) {
throw new NumberFormatException(toString());
return false;
}

result = result * radix - digit;
// Since the previous result is less than or equal to stopValue(Long.MIN_VALUE / radix), we
// can just use `result > 0` to check overflow. If result overflows, we should stop and throw
// exception.
// can just use `result > 0` to check overflow. If result overflows, we should stop.
if (result > 0) {
throw new NumberFormatException(toString());
return false;
}
}

// This is the case when we've encountered a decimal separator. The fractional
// part will not change the number, but we will verify that the fractional part
// is well formed.
while (offset < numBytes) {
if (getDigit(getByte(offset)) == -1) {
throw new NumberFormatException(toString());
byte currentByte = getByte(offset);
if (currentByte < '0' || currentByte > '9') {
return false;
}
offset++;
}

if (!negative) {
result = -result;
if (result < 0) {
throw new NumberFormatException(toString());
return false;
}
}

return result;
toLongResult.value = result;
return true;
}

/**
 * Mutable holder for a primitive {@code int}. {@code toInt(IntWrapper)} stores the parsed
 * number in {@code value} when parsing succeeds; the field starts at 0.
 */
public static class IntWrapper {
public int value = 0;
}

/**
Expand All @@ -946,10 +958,14 @@ public long toLong() {
*
* Note that, this method is almost same as `toLong`, but we leave it duplicated for performance
* reasons, like Hive does.
*
* @param intWrapper If a valid `int` was parsed from this UTF8String, then its value would
* be set in `intWrapper`
* @return true if the parsing was successful else false
*/
public int toInt() {
public boolean toInt(IntWrapper intWrapper) {
if (numBytes == 0) {
throw new NumberFormatException("Empty string");
return false;
}

byte b = getByte(0);
Expand All @@ -958,7 +974,7 @@ public int toInt() {
if (negative || b == '+') {
offset++;
if (numBytes == 1) {
throw new NumberFormatException(toString());
return false;
}
}

Expand All @@ -977,61 +993,69 @@ public int toInt() {
break;
}

int digit = getDigit(b);
int digit;
if (b >= '0' && b <= '9') {
digit = b - '0';
} else {
return false;
}

// We are going to process the new digit and accumulate the result. However, before doing
// this, if the result is already smaller than the stopValue(Integer.MIN_VALUE / radix), then
// result * 10 will definitely be smaller than minValue, and we can stop and throw exception.
// result * 10 will definitely be smaller than minValue, and we can stop
if (result < stopValue) {
throw new NumberFormatException(toString());
return false;
}

result = result * radix - digit;
// Since the previous result is less than or equal to stopValue(Integer.MIN_VALUE / radix),
// we can just use `result > 0` to check overflow. If result overflows, we should stop and
// throw exception.
// we can just use `result > 0` to check overflow. If result overflows, we should stop
if (result > 0) {
throw new NumberFormatException(toString());
return false;
}
}

// This is the case when we've encountered a decimal separator. The fractional
// part will not change the number, but we will verify that the fractional part
// is well formed.
while (offset < numBytes) {
if (getDigit(getByte(offset)) == -1) {
throw new NumberFormatException(toString());
byte currentByte = getByte(offset);
if (currentByte < '0' || currentByte > '9') {
return false;
}
offset++;
}

if (!negative) {
result = -result;
if (result < 0) {
throw new NumberFormatException(toString());
return false;
}
}

return result;
intWrapper.value = result;
return true;
}

public short toShort() {
int intValue = toInt();
short result = (short) intValue;
if (result != intValue) {
throw new NumberFormatException(toString());
public boolean toShort(IntWrapper intWrapper) {
if (toInt(intWrapper)) {
int intValue = intWrapper.value;
short result = (short) intValue;
if (result == intValue) {
return true;
}
}

return result;
return false;
}

public byte toByte() {
int intValue = toInt();
byte result = (byte) intValue;
if (result != intValue) {
throw new NumberFormatException(toString());
public boolean toByte(IntWrapper intWrapper) {
if (toInt(intWrapper)) {
int intValue = intWrapper.value;
byte result = (byte) intValue;
if (result == intValue) {
return true;
}
}

return result;
return false;
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,7 @@
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.nio.charset.StandardCharsets;
import java.util.Arrays;
import java.util.HashMap;
import java.util.HashSet;
import java.util.*;

import com.google.common.collect.ImmutableMap;
import org.apache.spark.unsafe.Platform;
Expand Down Expand Up @@ -608,4 +606,128 @@ public void writeToOutputStreamIntArray() throws IOException {
.writeTo(outputStream);
assertEquals("大千世界", outputStream.toString("UTF-8"));
}

@Test
public void testToShort() throws IOException {
  // Well-formed inputs mapped to the short value each one should parse to.
  Map<String, Short> expectedByInput = new HashMap<>();
  expectedByInput.put("1", (short) 1);
  expectedByInput.put("+1", (short) 1);
  expectedByInput.put("-1", (short) -1);
  expectedByInput.put("0", (short) 0);
  // Digits after the decimal separator are expected to be dropped.
  expectedByInput.put("1111.12345678901234567890", (short) 1111);
  expectedByInput.put(String.valueOf(Short.MAX_VALUE), Short.MAX_VALUE);
  expectedByInput.put(String.valueOf(Short.MIN_VALUE), Short.MIN_VALUE);

  // Ten extra random values; the generator is unseeded, so they differ per run.
  Random random = new Random();
  int remaining = 10;
  while (remaining-- > 0) {
    short sample = (short) random.nextInt();
    expectedByInput.put(String.valueOf(sample), sample);
  }

  // One wrapper is reused for every call; a successful parse overwrites its value.
  IntWrapper parsed = new IntWrapper();
  for (String input : expectedByInput.keySet()) {
    assertTrue(input, UTF8String.fromString(input).toShort(parsed));
    assertEquals((short) expectedByInput.get(input), parsed.value);
  }

  // Malformed text, plus a number that fits in an int but not in a short.
  String[] rejected = {"", " ", "null", "NULL", "\n", "~1212121", "3276700"};
  for (String input : rejected) {
    assertFalse(input, UTF8String.fromString(input).toShort(parsed));
  }
}

/**
 * Verifies {@code UTF8String.toByte(IntWrapper)}: valid inputs must return true and store the
 * expected value in the wrapper; malformed or out-of-range inputs must return false.
 */
@Test
public void testToByte() throws IOException {
  Map<String, Byte> inputToExpectedOutput = new HashMap<>();
  inputToExpectedOutput.put("1", (byte) 1);
  inputToExpectedOutput.put("+1", (byte) 1);
  inputToExpectedOutput.put("-1", (byte) -1);
  inputToExpectedOutput.put("0", (byte) 0);
  // Digits after the decimal separator are expected to be dropped.
  inputToExpectedOutput.put("111.12345678901234567890", (byte) 111);
  inputToExpectedOutput.put(String.valueOf(Byte.MAX_VALUE), Byte.MAX_VALUE);
  inputToExpectedOutput.put(String.valueOf(Byte.MIN_VALUE), Byte.MIN_VALUE);

  // Ten extra random values; the generator is unseeded, so they differ per run.
  Random rand = new Random();
  for (int i = 0; i < 10; i++) {
    byte value = (byte) rand.nextInt();
    inputToExpectedOutput.put(String.valueOf(value), value);
  }

  IntWrapper intWrapper = new IntWrapper();
  for (Map.Entry<String, Byte> entry : inputToExpectedOutput.entrySet()) {
    assertTrue(entry.getKey(), UTF8String.fromString(entry.getKey()).toByte(intWrapper));
    assertEquals((byte) entry.getValue(), intWrapper.value);
  }

  // Malformed text, a number too large even for an int, and the two values just outside the
  // byte range. The last two parse fine as ints, so they are the only inputs that reach the
  // byte-narrowing check inside toByte rather than failing earlier in toInt.
  List<String> negativeInputs = Arrays.asList(
    "", " ", "null", "NULL", "\n", "~1212121", "12345678901234567890",
    String.valueOf(Byte.MAX_VALUE + 1), String.valueOf(Byte.MIN_VALUE - 1));

  for (String negativeInput : negativeInputs) {
    assertFalse(negativeInput, UTF8String.fromString(negativeInput).toByte(intWrapper));
  }
}

@Test
public void testToInt() throws IOException {
  // Well-formed inputs mapped to the int value each one should parse to.
  Map<String, Integer> expectedByInput = new HashMap<>();
  expectedByInput.put("1", 1);
  expectedByInput.put("+1", 1);
  expectedByInput.put("-1", -1);
  expectedByInput.put("0", 0);
  // Digits after the decimal separator are expected to be dropped.
  expectedByInput.put("11111.1234567", 11111);
  expectedByInput.put(String.valueOf(Integer.MAX_VALUE), Integer.MAX_VALUE);
  expectedByInput.put(String.valueOf(Integer.MIN_VALUE), Integer.MIN_VALUE);

  // Ten extra random values; the generator is unseeded, so they differ per run.
  Random random = new Random();
  for (int added = 0; added != 10; added++) {
    int sample = random.nextInt();
    expectedByInput.put(String.valueOf(sample), sample);
  }

  // One wrapper is reused for every call; a successful parse overwrites its value.
  IntWrapper parsed = new IntWrapper();
  for (Map.Entry<String, Integer> pair : expectedByInput.entrySet()) {
    String input = pair.getKey();
    boolean ok = UTF8String.fromString(input).toInt(parsed);
    assertTrue(input, ok);
    assertEquals((int) pair.getValue(), parsed.value);
  }

  // Malformed text, plus a number far outside the int range.
  String[] rejected = {"", " ", "null", "NULL", "\n", "~1212121", "12345678901234567890"};
  for (String input : rejected) {
    assertFalse(input, UTF8String.fromString(input).toInt(parsed));
  }
}

@Test
public void testToLong() throws IOException {
  // Well-formed inputs mapped to the long value each one should parse to.
  Map<String, Long> expectedByInput = new HashMap<>();
  expectedByInput.put("1", 1L);
  expectedByInput.put("+1", 1L);
  expectedByInput.put("-1", -1L);
  expectedByInput.put("0", 0L);
  // Digits after the decimal separator are expected to be dropped.
  expectedByInput.put("1076753423.12345678901234567890", 1076753423L);
  expectedByInput.put(String.valueOf(Long.MAX_VALUE), Long.MAX_VALUE);
  expectedByInput.put(String.valueOf(Long.MIN_VALUE), Long.MIN_VALUE);

  // Ten extra random values; the generator is unseeded, so they differ per run.
  Random random = new Random();
  int remaining = 10;
  while (remaining-- > 0) {
    long sample = random.nextLong();
    expectedByInput.put(String.valueOf(sample), sample);
  }

  // One wrapper is reused for every call; a successful parse overwrites its value.
  LongWrapper parsed = new LongWrapper();
  for (String input : expectedByInput.keySet()) {
    assertTrue(input, UTF8String.fromString(input).toLong(parsed));
    assertEquals((long) expectedByInput.get(input), parsed.value);
  }

  // Malformed text, plus a number far outside the long range.
  String[] rejected = {
    "", " ", "null", "NULL", "\n", "~1212121", "1234567890123456789012345678901234"
  };
  for (String input : rejected) {
    assertFalse(input, UTF8String.fromString(input).toLong(parsed));
  }
}
}
Loading

0 comments on commit c96d14a

Please sign in to comment.