Skip to content

Commit

Permalink
HIVE-9487: Make Remote Spark Context secure [Spark Branch] (Marcelo via Xuefu)
Browse files Browse the repository at this point in the history

git-svn-id: https://svn.apache.org/repos/asf/hive/branches/spark@1655926 13f79535-47bb-0310-9956-ffa450edef68
  • Loading branch information
Xuefu Zhang committed Jan 30, 2015
1 parent c0d1e54 commit 4f3187c
Show file tree
Hide file tree
Showing 12 changed files with 672 additions and 165 deletions.
4 changes: 3 additions & 1 deletion common/src/java/org/apache/hadoop/hive/conf/HiveConf.java
Original file line number Diff line number Diff line change
Expand Up @@ -2018,7 +2018,9 @@ public static enum ConfVars {
SPARK_RPC_MAX_MESSAGE_SIZE("hive.spark.client.rpc.max.size", 50 * 1024 * 1024,
"Maximum message size in bytes for communication between Hive client and remote Spark driver. Default is 50MB."),
SPARK_RPC_CHANNEL_LOG_LEVEL("hive.spark.client.channel.log.level", null,
"Channel logging level for remote Spark driver. One of {DEBUG, ERROR, INFO, TRACE, WARN}.");
"Channel logging level for remote Spark driver. One of {DEBUG, ERROR, INFO, TRACE, WARN}."),
SPARK_RPC_SASL_MECHANISM("hive.spark.client.rpc.sasl.mechanisms", "DIGEST-MD5",
"Name of the SASL mechanism to use for authentication.");

public final String varname;
private final String defaultExpr;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,8 @@ private RemoteDriver(String[] args) throws Exception {
serverAddress = getArg(args, idx);
} else if (key.equals("--remote-port")) {
serverPort = Integer.parseInt(getArg(args, idx));
} else if (key.equals("--client-id")) {
conf.set(SparkClientFactory.CONF_CLIENT_ID, getArg(args, idx));
} else if (key.equals("--secret")) {
conf.set(SparkClientFactory.CONF_KEY_SECRET, getArg(args, idx));
} else if (key.equals("--conf")) {
Expand All @@ -127,6 +129,8 @@ private RemoteDriver(String[] args) throws Exception {
LOG.debug("Remote Driver configured with: " + e._1() + "=" + e._2());
}

String clientId = mapConf.get(SparkClientFactory.CONF_CLIENT_ID);
Preconditions.checkArgument(clientId != null, "No client ID provided.");
String secret = mapConf.get(SparkClientFactory.CONF_KEY_SECRET);
Preconditions.checkArgument(secret != null, "No secret provided.");

Expand All @@ -140,8 +144,8 @@ private RemoteDriver(String[] args) throws Exception {
this.protocol = new DriverProtocol();

// The RPC library takes care of timing out this.
this.clientRpc = Rpc.createClient(mapConf, egroup, serverAddress, serverPort, secret, protocol)
.get();
this.clientRpc = Rpc.createClient(mapConf, egroup, serverAddress, serverPort,
clientId, secret, protocol).get();
this.running = true;

this.clientRpc.addListener(new Rpc.Listener() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,9 @@ public final class SparkClientFactory {
/** Used to run the driver in-process, mostly for testing. */
static final String CONF_KEY_IN_PROCESS = "spark.client.do_not_use.run_driver_in_process";

/** Used by client and driver to share a client ID for establishing an RPC session. */
static final String CONF_CLIENT_ID = "spark.client.authentication.client_id";

/** Used by client and driver to share a secret for establishing an RPC session. */
static final String CONF_KEY_SECRET = "spark.client.authentication.secret";

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -84,13 +84,14 @@ class SparkClientImpl implements SparkClient {
this.childIdGenerator = new AtomicInteger();
this.jobs = Maps.newConcurrentMap();

String clientId = UUID.randomUUID().toString();
String secret = rpcServer.createSecret();
this.driverThread = startDriver(rpcServer, secret);
this.driverThread = startDriver(rpcServer, clientId, secret);
this.protocol = new ClientProtocol();

try {
// The RPC server will take care of timeouts here.
this.driverRpc = rpcServer.registerClient(secret, protocol).get();
this.driverRpc = rpcServer.registerClient(clientId, secret, protocol).get();
} catch (Exception e) {
LOG.warn("Error while waiting for client to connect.", e);
driverThread.interrupt();
Expand Down Expand Up @@ -174,7 +175,8 @@ void cancel(String jobId) {
protocol.cancel(jobId);
}

private Thread startDriver(RpcServer rpcServer, final String secret) throws IOException {
private Thread startDriver(RpcServer rpcServer, final String clientId, final String secret)
throws IOException {
Runnable runnable;
final String serverAddress = rpcServer.getAddress();
final String serverPort = String.valueOf(rpcServer.getPort());
Expand All @@ -190,6 +192,8 @@ public void run() {
args.add(serverAddress);
args.add("--remote-port");
args.add(serverPort);
args.add("--client-id");
args.add(clientId);
args.add("--secret");
args.add(secret);

Expand Down Expand Up @@ -243,6 +247,7 @@ public void run() {
for (Map.Entry<String, String> e : conf.entrySet()) {
allProps.put(e.getKey(), conf.get(e.getKey()));
}
allProps.put(SparkClientFactory.CONF_CLIENT_ID, clientId);
allProps.put(SparkClientFactory.CONF_KEY_SECRET, secret);
allProps.put(DRIVER_OPTS_KEY, driverJavaOpts);
allProps.put(EXECUTOR_OPTS_KEY, executorJavaOpts);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
package org.apache.hive.spark.client.rpc;

import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.util.Arrays;
import java.util.List;
Expand Down Expand Up @@ -63,9 +64,12 @@ protected Kryo initialValue() {
}
};

private volatile EncryptionHandler encryptionHandler;

/**
 * Creates a codec that (de)serializes the given message classes with Kryo,
 * enforcing {@code maxMessageSize} as an upper bound on encoded frames.
 */
public KryoMessageCodec(int maxMessageSize, Class<?>... messages) {
  this.messages = Arrays.asList(messages);
  this.maxMessageSize = maxMessageSize;
  // No encryption until SASL negotiation installs a handler explicitly.
  this.encryptionHandler = null;
}

@Override
Expand All @@ -86,7 +90,7 @@ protected void decode(ChannelHandlerContext ctx, ByteBuf in, List<Object> out)
}

try {
ByteBuffer nioBuffer = in.nioBuffer(in.readerIndex(), msgSize);
ByteBuffer nioBuffer = maybeDecrypt(in.nioBuffer(in.readerIndex(), msgSize));
Input kryoIn = new Input(new ByteBufferInputStream(nioBuffer));

Object msg = kryos.get().readClassAndObject(kryoIn);
Expand All @@ -106,7 +110,7 @@ protected void encode(ChannelHandlerContext ctx, Object msg, ByteBuf buf)
kryos.get().writeClassAndObject(kryoOut, msg);
kryoOut.flush();

byte[] msgData = bytes.toByteArray();
byte[] msgData = maybeEncrypt(bytes.toByteArray());
LOG.debug("Encoded message of type {} ({} bytes)", msg.getClass().getName(), msgData.length);
checkSize(msgData.length);

Expand All @@ -115,10 +119,56 @@ protected void encode(ChannelHandlerContext ctx, Object msg, ByteBuf buf)
buf.writeBytes(msgData);
}

@Override
public void channelInactive(ChannelHandlerContext ctx) throws Exception {
  // Snapshot the volatile field once, then release any SASL/crypto state
  // held by the encryption handler before propagating the event upstream.
  EncryptionHandler handler = this.encryptionHandler;
  if (handler != null) {
    handler.dispose();
  }
  super.channelInactive(ctx);
}

/**
 * Validates an encoded message size: it must be strictly positive, and when
 * a limit is configured ({@code maxMessageSize > 0}) it must not exceed it.
 *
 * @throws IllegalArgumentException if either constraint is violated
 */
private void checkSize(int msgSize) {
  Preconditions.checkArgument(msgSize > 0, "Message size (%s bytes) must be positive.", msgSize);
  boolean withinLimit = maxMessageSize <= 0 || msgSize <= maxMessageSize;
  Preconditions.checkArgument(withinLimit,
      "Message (%s bytes) exceeds maximum allowed size (%s bytes).", msgSize, maxMessageSize);
}

/** Wraps the serialized payload through the encryption handler, if one is installed. */
private byte[] maybeEncrypt(byte[] data) throws Exception {
  EncryptionHandler handler = this.encryptionHandler;
  if (handler == null) {
    return data;
  }
  return handler.wrap(data, 0, data.length);
}

/**
 * Decrypts an incoming frame when an encryption handler is installed;
 * otherwise returns the buffer untouched. In either case the relevant
 * bytes are consumed from {@code data}.
 */
private ByteBuffer maybeDecrypt(ByteBuffer data) throws Exception {
  if (encryptionHandler != null) {
    byte[] encrypted;
    int len = data.limit() - data.position();
    int offset;
    if (data.hasArray()) {
      // Array-backed buffer: unwrap directly from the backing array to
      // avoid a copy. arrayOffset() accounts for slices/duplicates.
      encrypted = data.array();
      offset = data.position() + data.arrayOffset();
      // Mark the bytes as consumed; unwrap reads from the array itself.
      data.position(data.limit());
    } else {
      // Direct (or otherwise non-array) buffer: copy out the ciphertext.
      // get() advances the position, consuming the bytes.
      encrypted = new byte[len];
      offset = 0;
      data.get(encrypted);
    }
    return ByteBuffer.wrap(encryptionHandler.unwrap(encrypted, offset, len));
  } else {
    return data;
  }
}

/**
 * Installs the handler used to encrypt/decrypt frames, typically after a
 * successful SASL negotiation. Volatile write so the codec's I/O thread
 * observes it. Pass {@code null} to disable encryption.
 */
void setEncryptionHandler(EncryptionHandler handler) {
  this.encryptionHandler = handler;
}

/**
 * Abstraction over a frame-level encryption mechanism (e.g. a negotiated
 * SASL client/server performing wrap/unwrap). Implementations are invoked
 * from the codec's channel thread.
 */
interface EncryptionHandler {

  /** Encrypts {@code len} bytes of {@code data} starting at {@code offset}. */
  byte[] wrap(byte[] data, int offset, int len) throws IOException;

  /** Decrypts {@code len} bytes of {@code data} starting at {@code offset}. */
  byte[] unwrap(byte[] data, int offset, int len) throws IOException;

  /** Releases any resources held by the handler; called when the channel goes inactive. */
  void dispose() throws IOException;

}

}
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ Basic flow of events:
- Client side creates an RPC server
- Client side spawns RemoteDriver, which manages the SparkContext, and provides a secret
- Client side sets up a timer to wait for RemoteDriver to connect back
- RemoteDriver connects back to client, sends Hello message with secret
- RemoteDriver connects back to client, SASL handshake ensues
- Connection is established and now there's a session between the client and the driver.

Features of the RPC layer:
Expand Down
Loading

0 comments on commit 4f3187c

Please sign in to comment.