Skip to content

Commit

Permalink
HIVE-9178: Create a separate API for remote Spark Context RPC other t…
Browse files Browse the repository at this point in the history
…han job submission [Spark Branch] (Marcelo via Xuefu)

git-svn-id: https://svn.apache.org/repos/asf/hive/branches/spark@1652105 13f79535-47bb-0310-9956-ffa450edef68
  • Loading branch information
Xuefu Zhang committed Jan 15, 2015
1 parent 6ff82ba commit f5fdb96
Show file tree
Hide file tree
Showing 6 changed files with 118 additions and 13 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.Future;
import java.util.concurrent.TimeUnit;

/**
Expand Down Expand Up @@ -145,7 +146,7 @@ private SparkJobInfo getSparkJobInfo() {
return getDefaultJobInfo(sparkJobId, JobExecutionStatus.FAILED);
}
}
JobHandle<SparkJobInfo> getJobInfo = sparkClient.submit(
Future<SparkJobInfo> getJobInfo = sparkClient.run(
new GetJobInfoJob(jobHandle.getClientJobId(), sparkJobId));
try {
return getJobInfo.get(sparkClientTimeoutInSeconds, TimeUnit.SECONDS);
Expand All @@ -156,7 +157,7 @@ private SparkJobInfo getSparkJobInfo() {
}

private SparkStageInfo getSparkStageInfo(int stageId) {
JobHandle<SparkStageInfo> getStageInfo = sparkClient.submit(new GetStageInfoJob(stageId));
Future<SparkStageInfo> getStageInfo = sparkClient.run(new GetStageInfoJob(stageId));
try {
return getStageInfo.get(sparkClientTimeoutInSeconds, TimeUnit.SECONDS);
} catch (Throwable t) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -138,4 +138,18 @@ protected static class JobSubmitted implements Serializable {
}
}

/**
 * Message wrapping a {@link Job} that is sent to the remote driver to be run
 * synchronously, i.e. by the code handling the message rather than through
 * the regular job submission path.
 *
 * @param <T> The type of the result produced by the wrapped job.
 */
protected static class SyncJobRequest<T extends Serializable> implements Serializable {

  /** The job to execute; {@code null} only when built through the no-arg constructor. */
  final Job<T> job;

  /**
   * No-arg constructor that leaves {@link #job} unset.
   * NOTE(review): presumably exists for the serialization library - confirm.
   */
  SyncJobRequest() {
    this(null);
  }

  /**
   * @param job The job to be carried by this request.
   */
  SyncJobRequest(Job<T> job) {
    this.job = job;
  }

}

}
Original file line number Diff line number Diff line change
Expand Up @@ -77,20 +77,24 @@ public class RemoteDriver {
private static final Logger LOG = LoggerFactory.getLogger(RemoteDriver.class);

private final Map<String, JobWrapper<?>> activeJobs;
private final Object jcLock;
private final Object shutdownLock;
private final ExecutorService executor;
private final JobContextImpl jc;
private final NioEventLoopGroup egroup;
private final Rpc clientRpc;
private final DriverProtocol protocol;

// Used to queue up requests while the SparkContext is being created.
private final List<JobWrapper<?>> jobQueue = Lists.newLinkedList();

private boolean running;
// jc is effectively final, but it has to be volatile since it's accessed by different
// threads while the constructor is running.
private volatile JobContextImpl jc;
private volatile boolean running;

private RemoteDriver(String[] args) throws Exception {
this.activeJobs = Maps.newConcurrentMap();
this.jcLock = new Object();
this.shutdownLock = new Object();

SparkConf conf = new SparkConf();
Expand Down Expand Up @@ -150,14 +154,20 @@ public void rpcClosed(Rpc rpc) {
try {
JavaSparkContext sc = new JavaSparkContext(conf);
sc.sc().addSparkListener(new ClientListener());
jc = new JobContextImpl(sc);
synchronized (jcLock) {
jc = new JobContextImpl(sc);
jcLock.notifyAll();
}
} catch (Exception e) {
LOG.error("Failed to start SparkContext.", e);
shutdown(e);
synchronized (jcLock) {
jcLock.notifyAll();
}
throw e;
}

synchronized (jobQueue) {
synchronized (jcLock) {
for (Iterator<JobWrapper<?>> it = jobQueue.iterator(); it.hasNext();) {
it.next().submit();
}
Expand All @@ -174,7 +184,7 @@ private void run() throws InterruptedException {
}

private void submit(JobWrapper<?> job) {
synchronized (jobQueue) {
synchronized (jcLock) {
if (jc != null) {
job.submit();
} else {
Expand Down Expand Up @@ -264,6 +274,35 @@ private void handle(ChannelHandlerContext ctx, JobRequest msg) {
submit(wrapper);
}

/**
 * Handles a {@link SyncJobRequest} by running the wrapped job on the calling thread and
 * returning its result.
 * <p>
 * If the job context has not been created yet, the call blocks until it becomes available.
 * While the job runs, a monitor callback is installed that rejects any monitoring attempt;
 * the callback is cleared again once the job finishes, whether it succeeded or threw.
 *
 * @param ctx The channel context of the incoming message (unused here).
 * @param msg The request carrying the job to run.
 * @return Whatever the job's {@code call()} method returns.
 * @throws IllegalStateException If the driver is found not to be running while waiting for
 *     the job context.
 * @throws InterruptedException If the thread is interrupted while waiting for the job context.
 * @throws Exception Anything thrown by the job itself.
 */
private Object handle(ChannelHandlerContext ctx, SyncJobRequest msg) throws Exception {
  // Upper bound for a single wait on jcLock. A notification that is delivered before this
  // thread starts waiting (or a shutdown that never notifies jcLock) would otherwise be
  // missed and leave this RPC blocked forever; with a bounded wait the state is simply
  // re-checked on the next iteration.
  final long waitIntervalMs = 1000L;

  // In case the job context is not up yet, let's wait, since this is supposed to be a
  // "synchronous" RPC.
  if (jc == null) {
    synchronized (jcLock) {
      while (jc == null) {
        jcLock.wait(waitIntervalMs);
        if (!running) {
          throw new IllegalStateException("Remote context is shutting down.");
        }
      }
    }
  }

  // Synchronous jobs are not tracked, so make any attempt to monitor fail loudly.
  jc.setMonitorCb(new MonitorCallback() {
    @Override
    public void call(JavaFutureAction<?> future,
        SparkCounters sparkCounters, Set<Integer> cachedRDDIds) {
      throw new IllegalStateException(
        "JobContext.monitor() is not available for synchronous jobs.");
    }
  });
  try {
    return msg.job.call(jc);
  } finally {
    // Always remove the rejecting callback so it does not leak past this request.
    jc.setMonitorCb(null);
  }
}

}

private class JobWrapper<T extends Serializable> implements Callable<Void> {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,23 @@ public interface SparkClient extends Serializable {
*/
<T extends Serializable> JobHandle<T> submit(Job<T> job);

/**
 * Asks the remote context to run a job immediately.
 * <p>
 * Normally, the remote context will queue jobs and execute them based on how many worker
 * threads have been configured. This method will run the submitted job in the same thread
 * processing the RPC message, so that queueing does not apply.
 * <p>
 * It's recommended that this method only be used to run code that finishes quickly. This
 * avoids interfering with the normal operation of the context.
 * <p>
 * Note: the {@link JobContext#monitor()} functionality is not available when using this method.
 *
 * @param <T> The type of the result produced by the job.
 * @param job The job to execute.
 * @return A future to monitor the result of the job.
 */
<T extends Serializable> Future<T> run(Job<T> job);

/**
* Stops the remote context.
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,11 @@ public <T extends Serializable> JobHandle<T> submit(Job<T> job) {
return protocol.submit(job);
}

/**
 * Runs the given job through the client protocol's synchronous path.
 *
 * @param job The job to execute.
 * @return The future handed back by the protocol for this job.
 */
@Override
public <T extends Serializable> Future<T> run(Job<T> job) {
  final Future<T> pending = protocol.run(job);
  return pending;
}

@Override
public void stop() {
if (isAlive) {
Expand Down Expand Up @@ -144,22 +149,22 @@ public void stop() {

@Override
public Future<?> addJar(URL url) {
return submit(new AddJarJob(url.toString()));
return run(new AddJarJob(url.toString()));
}

@Override
public Future<?> addFile(URL url) {
return submit(new AddFileJob(url.toString()));
return run(new AddFileJob(url.toString()));
}

@Override
public Future<Integer> getExecutorCount() {
return submit(new GetExecutorCountJob());
return run(new GetExecutorCountJob());
}

@Override
public Future<Integer> getDefaultParallelism() {
return submit(new GetDefaultParallelismJob());
return run(new GetDefaultParallelismJob());
}

void cancel(String jobId) {
Expand Down Expand Up @@ -379,16 +384,24 @@ public void operationComplete(io.netty.util.concurrent.Future<Void> f) {
promise.addListener(new GenericFutureListener<Promise<T>>() {
@Override
public void operationComplete(Promise<T> p) {
jobs.remove(jobId);
if (jobId != null) {
jobs.remove(jobId);
}
if (p.isCancelled() && !rpc.isDone()) {
rpc.cancel(true);
}
}
});

return handle;
}

/**
 * Wraps the job in a {@link SyncJobRequest}, sends it over the driver RPC and returns the
 * future for the reply.
 *
 * @param <T> The type of the result produced by the job.
 * @param job The job to execute.
 * @return The RPC future, viewed as a future of the job's result type.
 */
<T extends Serializable> Future<T> run(Job<T> job) {
  // The reply type is requested as plain Serializable, so the resulting future has to be
  // narrowed to T with an unchecked cast. The request itself is parameterized (not raw)
  // so that this cast is the only unchecked operation being suppressed.
  @SuppressWarnings("unchecked")
  final io.netty.util.concurrent.Future<T> rpc = (io.netty.util.concurrent.Future<T>)
      driverRpc.call(new SyncJobRequest<T>(job), Serializable.class);
  return rpc;
}

void cancel(String jobId) {
driverRpc.call(new CancelJob(jobId));
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
import java.util.HashMap;
import java.util.Map;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.Future;
import java.util.concurrent.TimeUnit;
import java.util.jar.JarOutputStream;
import java.util.zip.ZipEntry;
Expand Down Expand Up @@ -110,6 +111,17 @@ public void call(SparkClient client) throws Exception {
});
}

/**
 * Checks that a job sent through {@code SparkClient.run()} comes back with the value the
 * job returned.
 */
@Test
public void testSyncRpc() throws Exception {
  final TestFunction syncCall = new TestFunction() {
    @Override
    public void call(SparkClient client) throws Exception {
      Future<String> reply = client.run(new SyncRpc());
      String greeting = reply.get(TIMEOUT, TimeUnit.SECONDS);
      assertEquals("Hello", greeting);
    }
  };
  runTest(true, syncCall);
}

@Test
public void testRemoteClient() throws Exception {
runTest(false, new TestFunction() {
Expand Down Expand Up @@ -333,6 +345,15 @@ public void call(Integer l) throws Exception {

}

/** Trivial job used by the synchronous RPC test; it ignores the context and returns a greeting. */
private static class SyncRpc implements Job<String> {

  /**
   * @param jc The job context (not used).
   * @return The fixed string the test expects.
   */
  @Override
  public String call(JobContext jc) {
    final String greeting = "Hello";
    return greeting;
  }

}

private abstract static class TestFunction {
abstract void call(SparkClient client) throws Exception;
void config(Map<String, String> conf) { }
Expand Down

0 comments on commit f5fdb96

Please sign in to comment.