Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Moving the plugins SDK to org.pytorch. Added plugins example. #261

Merged
merged 11 commits into from
May 4, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions ci/buildspec.yml
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ phases:
install:
commands:
- apt-get update
- apt-get install -y curl
- apt-get install -y curl gnupg2
- pip install pip -U
- pip install future
- pip install Pillow
Expand All @@ -20,7 +20,7 @@ phases:
build:
commands:
- ./torchserve_sanity.sh
- cd serving-sdk/ && mvn clean deploy && cd ../
- cd serving-sdk/ && mvn clean install && cd ../

artifacts:
files:
Expand Down
2 changes: 1 addition & 1 deletion frontend/gradle.properties
Original file line number Diff line number Diff line change
Expand Up @@ -6,4 +6,4 @@ slf4j_log4j12_version=1.7.25
gson_version=2.8.5
commons_cli_version=1.3.1
testng_version=6.8.1
mms_server_sdk_version=1.0.1
torchserve_sdk_version=0.0.3
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

curious about this version ID, where does this come from?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is the SDK version. We publish this version to JCenter. Customers who want to define their custom endpoints would have to pull in the 0.0.3 version of this SDK and define their endpoint. Similar to the PyPI package, we will have to bump this package version for every release.

2 changes: 1 addition & 1 deletion frontend/server/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ dependencies {
compile "io.netty:netty-all:${netty_version}"
compile project(":modelarchive")
compile "commons-cli:commons-cli:${commons_cli_version}"
compile "software.amazon.ai:mms-plugins-sdk:${mms_server_sdk_version}"
compile "org.pytorch:torchserve-plugins-sdk:${torchserve_sdk_version}"
testCompile "org.testng:testng:${testng_version}"
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,9 @@
import org.pytorch.serve.archive.ModelArchive;
import org.pytorch.serve.archive.ModelException;
import org.pytorch.serve.metrics.MetricManager;
import org.pytorch.serve.servingsdk.ModelServerEndpoint;
import org.pytorch.serve.servingsdk.annotations.Endpoint;
import org.pytorch.serve.servingsdk.annotations.helpers.EndpointTypes;
import org.pytorch.serve.servingsdk.impl.PluginsManager;
import org.pytorch.serve.snapshot.InvalidSnapshotException;
import org.pytorch.serve.snapshot.SnapshotManager;
Expand All @@ -40,9 +43,6 @@
import org.pytorch.serve.wlm.WorkLoadManager;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import software.amazon.ai.mms.servingsdk.ModelServerEndpoint;
import software.amazon.ai.mms.servingsdk.annotations.Endpoint;
import software.amazon.ai.mms.servingsdk.annotations.helpers.EndpointTypes;

public class ModelServer {

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,15 +11,15 @@
import java.util.Map;
import org.pytorch.serve.archive.ModelException;
import org.pytorch.serve.archive.ModelNotFoundException;
import org.pytorch.serve.servingsdk.ModelServerEndpoint;
import org.pytorch.serve.servingsdk.ModelServerEndpointException;
import org.pytorch.serve.servingsdk.impl.ModelServerContext;
import org.pytorch.serve.servingsdk.impl.ModelServerRequest;
import org.pytorch.serve.servingsdk.impl.ModelServerResponse;
import org.pytorch.serve.util.NettyUtils;
import org.pytorch.serve.wlm.ModelManager;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import software.amazon.ai.mms.servingsdk.ModelServerEndpoint;
import software.amazon.ai.mms.servingsdk.ModelServerEndpointException;

public abstract class HttpRequestHandlerChain {
private static final Logger logger = LoggerFactory.getLogger(HttpRequestHandler.class);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
import org.pytorch.serve.archive.ModelException;
import org.pytorch.serve.archive.ModelNotFoundException;
import org.pytorch.serve.openapi.OpenApiUtils;
import org.pytorch.serve.servingsdk.ModelServerEndpoint;
import org.pytorch.serve.util.NettyUtils;
import org.pytorch.serve.util.messages.InputParameter;
import org.pytorch.serve.util.messages.RequestInput;
Expand All @@ -23,7 +24,6 @@
import org.pytorch.serve.wlm.ModelManager;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import software.amazon.ai.mms.servingsdk.ModelServerEndpoint;

/**
* A class handling inbound HTTP requests to the management API.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,14 +23,14 @@
import org.pytorch.serve.archive.ModelNotFoundException;
import org.pytorch.serve.archive.ModelVersionNotFoundException;
import org.pytorch.serve.http.messages.RegisterModelRequest;
import org.pytorch.serve.servingsdk.ModelServerEndpoint;
import org.pytorch.serve.snapshot.SnapshotManager;
import org.pytorch.serve.util.ConfigManager;
import org.pytorch.serve.util.JsonUtils;
import org.pytorch.serve.util.NettyUtils;
import org.pytorch.serve.wlm.Model;
import org.pytorch.serve.wlm.ModelManager;
import org.pytorch.serve.wlm.WorkerThread;
import software.amazon.ai.mms.servingsdk.ModelServerEndpoint;

/**
* A class handling inbound HTTP requests to the management API.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,10 @@
import java.util.HashMap;
import java.util.Map;
import java.util.Properties;
import org.pytorch.serve.servingsdk.Context;
import org.pytorch.serve.servingsdk.Model;
import org.pytorch.serve.util.ConfigManager;
import org.pytorch.serve.wlm.ModelManager;
import software.amazon.ai.mms.servingsdk.Context;
import software.amazon.ai.mms.servingsdk.Model;

public class ModelServerContext implements Context {
@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,9 @@

import java.util.ArrayList;
import java.util.List;
import org.pytorch.serve.servingsdk.Model;
import org.pytorch.serve.servingsdk.Worker;
import org.pytorch.serve.wlm.ModelManager;
import software.amazon.ai.mms.servingsdk.Model;
import software.amazon.ai.mms.servingsdk.Worker;

public class ModelServerModel implements Model {
private final org.pytorch.serve.wlm.Model model;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import software.amazon.ai.mms.servingsdk.http.Request;
import org.pytorch.serve.servingsdk.http.Request;

public class ModelServerRequest implements Request {
private FullHttpRequest req;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import io.netty.handler.codec.http.HttpHeaderNames;
import io.netty.handler.codec.http.HttpResponseStatus;
import java.io.OutputStream;
import software.amazon.ai.mms.servingsdk.http.Response;
import org.pytorch.serve.servingsdk.http.Response;

public class ModelServerResponse implements Response {

Expand Down
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
package org.pytorch.serve.servingsdk.impl;

import org.pytorch.serve.servingsdk.Worker;
import org.pytorch.serve.wlm.WorkerState;
import org.pytorch.serve.wlm.WorkerThread;
import software.amazon.ai.mms.servingsdk.Worker;

public class ModelWorker implements Worker {
private boolean running;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,11 @@
import java.util.Map;
import java.util.ServiceLoader;
import org.pytorch.serve.http.InvalidPluginException;
import org.pytorch.serve.servingsdk.ModelServerEndpoint;
import org.pytorch.serve.servingsdk.annotations.Endpoint;
import org.pytorch.serve.servingsdk.annotations.helpers.EndpointTypes;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import software.amazon.ai.mms.servingsdk.ModelServerEndpoint;
import software.amazon.ai.mms.servingsdk.annotations.Endpoint;
import software.amazon.ai.mms.servingsdk.annotations.helpers.EndpointTypes;

public final class PluginsManager {

Expand Down
88 changes: 88 additions & 0 deletions plugins/build.gradle
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
/*
 * This file was generated by the Gradle 'init' task.
 *
 * This generated file contains a sample Java Library project to get you started.
 * For more details take a look at the Java Libraries chapter in the Gradle
 * User Manual available at https://docs.gradle.org/5.4.1/userguide/java_library_plugin.html
 */

// Root build script for the TorchServe plugins examples. Every subproject is a
// Java library with IDEA support, shared formatting/lint checks, TestNG-based
// tests, and JaCoCo coverage verification.
allprojects {
apply plugin: 'idea'
apply plugin: 'java'

version = '1.0'

repositories {
jcenter()
}

// Keep IDEA's output directories aligned with Gradle's so IDE builds and
// command-line builds share class files.
idea {
module {
outputDir = file('build/classes/java/main')
testOutputDir = file('build/classes/java/test')
}
}

// Builds the plugin jar and copies it into <root>/build/plugins so it can be
// dropped into a TorchServe plugins directory. Only runs when the build is
// invoked with -Psagemaker (see onlyIf below).
task buildSagemaker("type": Jar) {

doFirst{ task ->
println "building $task.project.name"
}

// Reuse the standard jar task's content specification for this jar.
with project.jar

doLast {
copy {
def fromDir = project.jar
def intoDir = "${rootProject.projectDir}/build/plugins"
from fromDir
into intoDir
println "Copying files from" + fromDir + " into " + intoDir
}
}
}

// Gate the task on the 'sagemaker' project property (-Psagemaker).
buildSagemaker.onlyIf {project.hasProperty("sagemaker")}
}

// All subprojects are Java projects; kept as a helper in case non-Java
// subprojects are added later.
def javaProjects() {
return subprojects.findAll()
}

configure(javaProjects()) {
// Target Java 8 bytecode for compatibility with the TorchServe frontend.
sourceCompatibility = 1.8
targetCompatibility = 1.8

defaultTasks 'jar'

// Shared code-format and static-check configuration from the root tools dir.
apply from: file("${rootProject.projectDir}/tools/gradle/formatter.gradle")
apply from: file("${rootProject.projectDir}/tools/gradle/check.gradle")

test {
useTestNG() {
// suiteXmlFiles << new File(rootDir, "testng.xml") //This is how to add custom testng.xml
}

testLogging {
showStandardStreams = true
events "passed", "skipped", "failed", "standardOut", "standardError"
}
}

// NOTE(review): jacocoTestReport is expected to come from one of the applied
// script plugins (check.gradle) — confirm the jacoco plugin is applied there.
test.finalizedBy(project.tasks.jacocoTestReport)

compileJava {
// Treat all lint warnings (except -options/-static) as errors.
options.compilerArgs << "-Xlint:all,-options,-static" << "-Werror"
}

// Fail the build if line coverage drops below 75%.
jacocoTestCoverageVerification {
violationRules {
rule {
limit {
minimum = 0.75
}
}
}
}
}

19 changes: 19 additions & 0 deletions plugins/endpoints/build.gradle
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
// Build script for the example endpoint plugin. Produces a jar against the
// org.pytorch plugins SDK that can be loaded by the TorchServe frontend.
dependencies {
compile "com.google.code.gson:gson:${gson_version}"
compile "org.pytorch:torchserve-plugins-sdk:${torchserve_sdk_version}"
}

// Marks this subproject for the root build's buildSagemaker task.
project.ext{
sagemaker = true
}

jar {
includeEmptyDirs = false

// Strip build metadata and legal boilerplate from the packaged jar.
// Fixed: the LICENSE/NOTICE patterns previously used a doubled slash
// ("META-INF//LICENSE*"), whose empty path segment prevents the Ant-style
// pattern from matching anything.
exclude "META-INF/maven/**"
exclude "META-INF/INDEX.LIST"
exclude "META-INF/MANIFEST*"
exclude "META-INF/LICENSE*"
exclude "META-INF/NOTICE*"
}

Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
package org.pytorch.serve.plugins.endpoint;
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's not clear if this plugin is needed to move the plugins SDK?

Copy link
Collaborator Author

@vdantu vdantu May 1, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is an example of how to define an endpoint. This example shows how to define endpoints which can be readily used/supported by SM services.


import com.google.gson.GsonBuilder;
import com.google.gson.annotations.SerializedName;
import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.util.Properties;
import org.pytorch.serve.servingsdk.Context;
import org.pytorch.serve.servingsdk.ModelServerEndpoint;
import org.pytorch.serve.servingsdk.annotations.Endpoint;
import org.pytorch.serve.servingsdk.annotations.helpers.EndpointTypes;
import org.pytorch.serve.servingsdk.http.Request;
import org.pytorch.serve.servingsdk.http.Response;

@Endpoint(
        urlPattern = "execution-parameters",
        endpointType = EndpointTypes.INFERENCE,
        description = "Execution parameters endpoint")
public class ExecutionParameters extends ModelServerEndpoint {

    /**
     * Answers a GET on {@code /execution-parameters} with a pretty-printed JSON
     * document describing the server's transform settings, derived from the
     * model-server configuration.
     *
     * @param req inbound HTTP request (unused)
     * @param rsp response to which the JSON body is written
     * @param ctx serving context providing the server configuration
     * @throws IOException if writing to the response stream fails
     */
    @Override
    public void doGet(Request req, Response rsp, Context ctx) throws IOException {
        Properties config = ctx.getConfig();
        // Default cap is 6 MiB (6 * 1024 * 1024 = 6291456 bytes).
        int maxRequestBytes =
                Integer.parseInt(config.getProperty("max_request_size", "6291456"));

        ExecutionParametersResponse payload = new ExecutionParametersResponse();
        payload.setMaxConcurrentTransforms(
                Integer.parseInt(config.getProperty("NUM_WORKERS", "1")));
        payload.setBatchStrategy("MULTI_RECORD");
        // Integer division intentionally truncates to whole MiB.
        payload.setMaxPayloadInMB(maxRequestBytes / (1024 * 1024));

        String json = new GsonBuilder().setPrettyPrinting().create().toJson(payload);
        rsp.getOutputStream().write(json.getBytes(StandardCharsets.UTF_8));
    }

    /** Response for Model server endpoint */
    public static class ExecutionParametersResponse {
        @SerializedName("MaxConcurrentTransforms")
        private int maxConcurrentTransforms;

        @SerializedName("BatchStrategy")
        private String batchStrategy;

        @SerializedName("MaxPayloadInMB")
        private int maxPayloadInMB;

        /** Creates a response pre-populated with the documented defaults. */
        public ExecutionParametersResponse() {
            maxConcurrentTransforms = 4;
            batchStrategy = "MULTI_RECORD";
            maxPayloadInMB = 6;
        }

        public int getMaxConcurrentTransforms() {
            return maxConcurrentTransforms;
        }

        public String getBatchStrategy() {
            return batchStrategy;
        }

        public int getMaxPayloadInMB() {
            return maxPayloadInMB;
        }

        public void setMaxConcurrentTransforms(int value) {
            maxConcurrentTransforms = value;
        }

        public void setBatchStrategy(String value) {
            batchStrategy = value;
        }

        public void setMaxPayloadInMB(int value) {
            maxPayloadInMB = value;
        }
    }
}
Loading