forked from intel/dffml
-
Notifications
You must be signed in to change notification settings - Fork 0
/
test_quickstart.py
122 lines (109 loc) · 4.81 KB
/
test_quickstart.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
import os
import ast
import sys
import json
import shlex
import asyncio
import pathlib
import tempfile
import contextlib
import subprocess
from dffml import chdir, AsyncTestCase
from dffml_service_http.cli import HTTPService
from dffml_service_http.util.testing import ServerRunner
def sh_filepath(filename):
    """Return the absolute path of *filename* inside the ``quickstart``
    directory that sits next to this test module."""
    here = os.path.dirname(__file__)
    return os.path.join(here, "quickstart", filename)
@contextlib.contextmanager
def directory_with_csv_files():
    """Context manager yielding a temporary directory populated with the
    quickstart CSV data files.

    Changes into a fresh temporary directory, runs the three data-generation
    shell scripts there, and yields the directory path. Everything is cleaned
    up when the context exits.
    """
    with tempfile.TemporaryDirectory() as tempdir:
        with chdir(tempdir):
            # Each script writes one of the CSV files into the current dir
            for script in ("train_data.sh", "test_data.sh", "predict_data.sh"):
                subprocess.check_output(["sh", sh_filepath(script)])
            yield tempdir
class TestQuickstart(AsyncTestCase):
    """End-to-end checks for the quickstart examples.

    Exercises the Python example scripts, the shell-based CLI workflow, and
    the HTTP service deployment, verifying that training, accuracy scoring,
    and prediction all produce the documented results.
    """

    def python_test(self, filename):
        """Run the example script *filename* with the current interpreter and
        validate the accuracy line and the two predicted salaries it prints."""
        # Resolve the script relative to this test module
        script = os.path.join(os.path.dirname(__file__), filename)
        # Run it and capture everything it prints
        output = subprocess.check_output([sys.executable, script])
        lines = output.decode().split("\n")
        # Line 0 reports the model accuracy
        self.assertIn("Accuracy: 0.0", lines[0])
        # Lines 1-2 are dict reprs holding the predicted salaries
        self.assertEqual(round(ast.literal_eval(lines[1])["Salary"]), 70)
        self.assertEqual(round(ast.literal_eval(lines[2])["Salary"]), 80)

    def test_python(self):
        """Synchronous quickstart example."""
        self.python_test("quickstart.py")

    def test_python_async(self):
        """Async variant of the quickstart example."""
        self.python_test("quickstart_async.py")

    def test_python_filenames(self):
        """Variant that loads its data from CSV files on disk."""
        with directory_with_csv_files():
            self.python_test("quickstart_filenames.py")

    def test_shell(self):
        """Train, score, and predict via the quickstart shell scripts."""
        with directory_with_csv_files():
            # Train the model first
            subprocess.check_output(["sh", sh_filepath("train.sh")])
            # Accuracy script prints a single float
            accuracy_out = subprocess.check_output(
                ["sh", sh_filepath("accuracy.sh")]
            )
            self.assertAlmostEqual(float(accuracy_out.decode().strip()), 0.0)
            # Prediction script emits JSON records
            predict_out = subprocess.check_output(
                ["sh", sh_filepath("predict.sh")]
            )
            records = json.loads(predict_out.decode())
            # Validate both predicted salaries
            self.assertEqual(
                round(records[0]["prediction"]["Salary"]["value"]), 70
            )
            self.assertEqual(
                round(records[1]["prediction"]["Salary"]["value"]), 80
            )

    async def test_http(self):
        """Serve the trained model over HTTP and predict via curl."""
        # Load the server start command and normalize it: strip line
        # continuations, drop the CLI prefix, and bind to an ephemeral port.
        server_cmd = pathlib.Path(sh_filepath("model_start_http.sh")).read_text()
        for old, new in (
            ("\n", ""),
            ("\\", ""),
            # Remove `dffml service http server`
            ("dffml service http server", ""),
            # Port 0 lets the OS pick a free port
            ("8080", "0"),
        ):
            server_cmd = server_cmd.replace(old, new)
        server_cmd = shlex.split(server_cmd)
        # Load the curl command and point it at the interpreter running tests
        curl_cmd = pathlib.Path(sh_filepath("model_curl_http.sh")).read_text()
        curl_cmd = curl_cmd.replace("python", sys.executable)
        # Work inside a temp directory seeded with the CSV data
        with directory_with_csv_files():
            # Train before serving
            subprocess.check_output(["sh", sh_filepath("train.sh")])
            async with ServerRunner.patch(HTTPService.server) as tserver:
                # Launch the HTTP server on the ephemeral port
                cli = await tserver.start(HTTPService.server.cli(*server_cmd))
                # Point the curl command at the port actually bound
                curl_cmd = curl_cmd.replace("8080", str(cli.port))
                pathlib.Path("curl.sh").write_text(curl_cmd)
                # Issue the prediction request
                proc = await asyncio.create_subprocess_exec(
                    "sh",
                    "curl.sh",
                    stdout=asyncio.subprocess.PIPE,
                    stderr=asyncio.subprocess.PIPE,
                )
                stdout, stderr = await proc.communicate()
                if proc.returncode != 0:
                    raise Exception(stderr.decode())
                response = json.loads(stdout)
                # Exactly one record should come back
                records = response["records"]
                self.assertEqual(len(records), 1)
                for record in records.values():
                    # Correct value should be 90
                    should_be = 90
                    prediction = record["prediction"]["Salary"]["value"]
                    # Check prediction within 20% of correct value
                    percent_error = abs(should_be - prediction) / should_be
                    self.assertLess(percent_error, 0.2)