Skip to content

Commit

Permalink
service: http: Models via CLI
Browse files Browse the repository at this point in the history
Signed-off-by: John Andersen <[email protected]>
  • Loading branch information
pdxjohnny committed Apr 14, 2020
1 parent 2947288 commit bb95d48
Show file tree
Hide file tree
Showing 4 changed files with 87 additions and 2 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Example CLI commands and Python code for `SLRModel`
- `save` function in high level API to quickly save all given records to a
source
- Ability to configure models for HTTP API from command line when starting
server
### Changed
- Renamed `"arg"` to `"plugin"`.
- CSV source sorts feature names within headers when saving
Expand Down
11 changes: 11 additions & 0 deletions service/http/dffml_service_http/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,12 @@

from aiohttp import web

from dffml import Model
from dffml.util.cli.arg import Arg
from dffml.util.cli.cmd import CMD
from dffml.util.cli.parser import list_action
from dffml.util.entrypoint import entrypoint
from dffml.util.asynchelper import AsyncContextManagerList

from .routes import Routes

Expand Down Expand Up @@ -181,6 +184,14 @@ class Server(TLSCMD, MultiCommCMD, Routes):
nargs="+",
default=[],
)
# CLI flag declaring models to pre-configure when the HTTP server starts.
# Usage (see tests): `-models mymodel=slr` — each value appears to be a
# "label=plugin" pair resolved by Model.load_labeled; list_action then
# collects the loaded model classes into an AsyncContextManagerList so
# they can all be entered/exited as one async context.
# NOTE(review): default is a shared mutable AsyncContextManagerList()
# class attribute — presumably Arg copies defaults per-parse; confirm
# this is not mutated across server instantiations.
arg_models = Arg(
"-models",
help="Models configured on start",
nargs="+",
default=AsyncContextManagerList(),
type=Model.load_labeled,
action=list_action(AsyncContextManagerList),
)

def __init__(self, *args, **kwargs):
self.site = None
Expand Down
19 changes: 17 additions & 2 deletions service/http/dffml_service_http/routes.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import os
import json
import secrets
import inspect
import pathlib
import traceback
import pkg_resources
Expand Down Expand Up @@ -646,8 +647,22 @@ async def setup(self, **kwargs):
self.app["sources"] = {}
self.app["source_contexts"] = {}
self.app["source_records_iterkeys"] = {}
self.app["models"] = {}
self.app["model_contexts"] = {}

# Instantiate models if they aren't instantiated yet
for i, model in enumerate(self.models):
if inspect.isclass(model):
self.models[i] = model.withconfig(self.extra_config)

await self.app["exit_stack"].enter_async_context(self.models)
self.app["models"] = {
model.ENTRY_POINT_LABEL: model for model in self.models
}

mctx = await self.app["exit_stack"].enter_async_context(self.models())
self.app["model_contexts"] = {
model_ctx.parent.ENTRY_POINT_LABEL: model_ctx for model_ctx in mctx
}

self.app.update(kwargs)
# Allow no routes other than pre-registered if in atomic mode
self.routes = (
Expand Down
57 changes: 57 additions & 0 deletions service/http/tests/test_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@

import aiohttp

from dffml.model.slr import SLRModel
from dffml import Features, DefFeature, train, accuracy
from dffml.util.asynctestcase import AsyncTestCase

from dffml_service_http.cli import HTTPService
Expand Down Expand Up @@ -221,3 +223,58 @@ async def test_mc_config(self):
{"Feedface": {"response": message}},
await response.json(),
)

async def test_models(self):
with tempfile.TemporaryDirectory() as tempdir:
# Model the HTTP API will pre-load
model = SLRModel(
features=Features(DefFeature("f1", float, 1)),
predict=DefFeature("ans", int, 1),
directory=tempdir,
)

# y = m * x + b for equation SLR is solving for
m = 5
b = 3

# Train the model
await train(
model, *[{"f1": x, "ans": m * x + b} for x in range(0, 10)]
)

await accuracy(
model, *[{"f1": x, "ans": m * x + b} for x in range(10, 20)]
)

async with ServerRunner.patch(HTTPService.server) as tserver:
cli = await tserver.start(
HTTPService.server.cli(
"-insecure",
"-port",
"0",
"-models",
"mymodel=slr",
"-model-mymodel-directory",
tempdir,
"-model-mymodel-features",
"f1:float:1",
"-model-mymodel-predict",
"ans:int:1",
)
)
async with self.post(
cli,
f"/model/mymodel/predict/0",
json={
f"record_{x}": {"features": {"f1": x}}
for x in range(20, 30)
},
) as response:
response = await response.json()
records = response["records"]
self.assertEqual(len(records), 10)
for record in records.values():
should_be = m * record["features"]["f1"] + b
prediction = record["prediction"]["ans"]["value"]
percent_error = abs(should_be - prediction) / should_be
self.assertLess(percent_error, 0.2)

0 comments on commit bb95d48

Please sign in to comment.