-
Notifications
You must be signed in to change notification settings - Fork 845
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Moving the plugins SDK to org.pytorch. Added plugins example. #261
Changes from all commits
2b00bd1
d17aaf6
c66292d
097516e
7894834
cf118b7
45e31c2
9db8926
44f5508
8855afa
1bee126
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,88 @@ | ||
/* | ||
* This file was generated by the Gradle 'init' task. | ||
* | ||
* This generated file contains a sample Java Library project to get you started. | ||
* For more details take a look at the Java Libraries chapter in the Gradle | ||
* User Manual available at https://docs.gradle.org/5.4.1/userguide/java_library_plugin.html | ||
*/ | ||
|
||
allprojects { | ||
apply plugin: 'idea' | ||
apply plugin: 'java' | ||
|
||
version = '1.0' | ||
|
||
repositories { | ||
jcenter() | ||
} | ||
|
||
idea { | ||
module { | ||
outputDir = file('build/classes/java/main') | ||
testOutputDir = file('build/classes/java/test') | ||
} | ||
} | ||
|
||
task buildSagemaker("type": Jar) { | ||
|
||
doFirst{ task -> | ||
println "building $task.project.name" | ||
} | ||
|
||
with project.jar | ||
|
||
doLast { | ||
copy { | ||
def fromDir = project.jar | ||
def intoDir = "${rootProject.projectDir}/build/plugins" | ||
from fromDir | ||
into intoDir | ||
println "Copying files from" + fromDir + " into " + intoDir | ||
} | ||
} | ||
} | ||
|
||
buildSagemaker.onlyIf {project.hasProperty("sagemaker")} | ||
} | ||
|
||
def javaProjects() { | ||
return subprojects.findAll() | ||
} | ||
|
||
configure(javaProjects()) { | ||
sourceCompatibility = 1.8 | ||
targetCompatibility = 1.8 | ||
|
||
defaultTasks 'jar' | ||
|
||
apply from: file("${rootProject.projectDir}/tools/gradle/formatter.gradle") | ||
apply from: file("${rootProject.projectDir}/tools/gradle/check.gradle") | ||
|
||
test { | ||
useTestNG() { | ||
// suiteXmlFiles << new File(rootDir, "testng.xml") //This is how to add custom testng.xml | ||
} | ||
|
||
testLogging { | ||
showStandardStreams = true | ||
events "passed", "skipped", "failed", "standardOut", "standardError" | ||
} | ||
} | ||
|
||
test.finalizedBy(project.tasks.jacocoTestReport) | ||
|
||
compileJava { | ||
options.compilerArgs << "-Xlint:all,-options,-static" << "-Werror" | ||
} | ||
|
||
jacocoTestCoverageVerification { | ||
violationRules { | ||
rule { | ||
limit { | ||
minimum = 0.75 | ||
} | ||
} | ||
} | ||
} | ||
} | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,19 @@ | ||
dependencies { | ||
compile "com.google.code.gson:gson:${gson_version}" | ||
compile "org.pytorch:torchserve-plugins-sdk:${torchserve_sdk_version}" | ||
} | ||
|
||
project.ext{ | ||
sagemaker = true | ||
} | ||
|
||
jar { | ||
includeEmptyDirs = false | ||
|
||
exclude "META-INF/maven/**" | ||
exclude "META-INF/INDEX.LIST" | ||
exclude "META-INF/MANIFEST*" | ||
exclude "META-INF//LICENSE*" | ||
exclude "META-INF//NOTICE*" | ||
} | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,80 @@ | ||
package org.pytorch.serve.plugins.endpoint; | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It's not clear if this plugin is needed to move the plugins SDK? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This is an example of how to define an endpoint. This example shows how to define endpoints which can be readily used/supported by SM services. |
||
|
||
import com.google.gson.GsonBuilder; | ||
import com.google.gson.annotations.SerializedName; | ||
import java.io.IOException; | ||
import java.nio.charset.StandardCharsets; | ||
import java.util.Properties; | ||
import org.pytorch.serve.servingsdk.Context; | ||
import org.pytorch.serve.servingsdk.ModelServerEndpoint; | ||
import org.pytorch.serve.servingsdk.annotations.Endpoint; | ||
import org.pytorch.serve.servingsdk.annotations.helpers.EndpointTypes; | ||
import org.pytorch.serve.servingsdk.http.Request; | ||
import org.pytorch.serve.servingsdk.http.Response; | ||
|
||
@Endpoint( | ||
urlPattern = "execution-parameters", | ||
endpointType = EndpointTypes.INFERENCE, | ||
description = "Execution parameters endpoint") | ||
public class ExecutionParameters extends ModelServerEndpoint { | ||
|
||
@Override | ||
public void doGet(Request req, Response rsp, Context ctx) throws IOException { | ||
Properties prop = ctx.getConfig(); | ||
// 6 * 1024 * 1024 | ||
int maxRequestSize = Integer.parseInt(prop.getProperty("max_request_size", "6291456")); | ||
ExecutionParametersResponse r = new ExecutionParametersResponse(); | ||
r.setMaxConcurrentTransforms(Integer.parseInt(prop.getProperty("NUM_WORKERS", "1"))); | ||
r.setBatchStrategy("MULTI_RECORD"); | ||
r.setMaxPayloadInMB(maxRequestSize / (1024 * 1024)); | ||
rsp.getOutputStream() | ||
.write( | ||
new GsonBuilder() | ||
.setPrettyPrinting() | ||
.create() | ||
.toJson(r) | ||
.getBytes(StandardCharsets.UTF_8)); | ||
} | ||
|
||
/** Response for Model server endpoint */ | ||
public static class ExecutionParametersResponse { | ||
@SerializedName("MaxConcurrentTransforms") | ||
private int maxConcurrentTransforms; | ||
|
||
@SerializedName("BatchStrategy") | ||
private String batchStrategy; | ||
|
||
@SerializedName("MaxPayloadInMB") | ||
private int maxPayloadInMB; | ||
|
||
public ExecutionParametersResponse() { | ||
maxConcurrentTransforms = 4; | ||
batchStrategy = "MULTI_RECORD"; | ||
maxPayloadInMB = 6; | ||
} | ||
|
||
public int getMaxConcurrentTransforms() { | ||
return maxConcurrentTransforms; | ||
} | ||
|
||
public String getBatchStrategy() { | ||
return batchStrategy; | ||
} | ||
|
||
public int getMaxPayloadInMB() { | ||
return maxPayloadInMB; | ||
} | ||
|
||
public void setMaxConcurrentTransforms(int newMaxConcurrentTransforms) { | ||
maxConcurrentTransforms = newMaxConcurrentTransforms; | ||
} | ||
|
||
public void setBatchStrategy(String newBatchStrategy) { | ||
batchStrategy = newBatchStrategy; | ||
} | ||
|
||
public void setMaxPayloadInMB(int newMaxPayloadInMB) { | ||
maxPayloadInMB = newMaxPayloadInMB; | ||
} | ||
} | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
curious about this version ID, where does this come from?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is the SDK version. We publish this version to JCenter. Customers who want to define their custom endpoints would have to pull in 0.0.3 version of this SDK and define their endpoint. Similar to PyPi version, we will have to bump the package version for every release.