forked from denosaurs/netsaur
-
Notifications
You must be signed in to change notification settings - Fork 0
/
backend.ts
107 lines (95 loc) · 2.73 KB
/
backend.ts
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
import { Rank, Shape } from "../../core/api/shape.ts";
import { Tensor } from "../../core/tensor/tensor.ts";
import { length } from "../../core/tensor/util.ts";
import { Backend, DataSet, NetworkConfig } from "../../core/types.ts";
import { Library } from "./mod.ts";
import {
Buffer,
encodeDatasets,
encodeJSON,
PredictOptions,
TrainOptions,
} from "./util.ts";
/**
 * GPU backend implementation of {@link Backend}.
 *
 * Thin wrapper around the native FFI library: every operation forwards to
 * an `ffi_backend_*` symbol, identified by an opaque `bigint` handle that
 * the native side returns on creation/loading.
 */
export class GPUBackend implements Backend {
  /** FFI bindings used for every native call. */
  library: Library;
  /** Shape of the network's output tensor. */
  outputShape: Shape[Rank];
  /** Opaque handle to the native backend instance. */
  #id: bigint;

  /**
   * @param library FFI bindings to the native backend library.
   * @param outputShape Output tensor shape reported by the native side.
   * @param id Native backend handle.
   */
  constructor(
    library: Library,
    outputShape: Shape[Rank],
    id: bigint,
  ) {
    this.library = library;
    this.outputShape = outputShape;
    this.#id = id;
  }

  /**
   * Create a new native backend from a network configuration.
   *
   * @param config Network configuration, JSON-encoded before crossing FFI.
   * @param library FFI bindings to use.
   * @returns A backend wrapping the freshly created native instance.
   */
  static create(config: NetworkConfig, library: Library) {
    const encodedConfig = encodeJSON(config);
    const shapeBuffer = new Buffer();
    const id = library.symbols.ffi_backend_create(
      encodedConfig,
      encodedConfig.length,
      shapeBuffer.allocBuffer,
    ) as bigint;
    // The first element of the returned buffer is dropped; the remainder is
    // the output shape. NOTE(review): presumably the leading element encodes
    // the rank/length — confirm against the native ffi_backend_create.
    const outputShape = Array.from(shapeBuffer.buffer.slice(1)) as Shape[Rank];
    return new GPUBackend(library, outputShape, id);
  }

  /**
   * Train the network on the given datasets.
   *
   * @param datasets Training data; shapes are taken from the first entry.
   * @param epochs Number of passes over the data.
   * @param batches Batch count.
   * @param rate Learning rate.
   */
  train(datasets: DataSet[], epochs: number, batches: number, rate: number) {
    const encodedOptions = encodeJSON({
      datasets: datasets.length,
      inputShape: datasets[0].inputs.shape,
      outputShape: datasets[0].outputs.shape,
      epochs,
      batches,
      rate,
    } as TrainOptions);
    const encodedData = encodeDatasets(datasets);
    this.library.symbols.ffi_backend_train(
      this.#id,
      encodedData,
      encodedData.byteLength,
      encodedOptions,
      encodedOptions.byteLength,
    );
  }

  /**
   * Run a forward pass on a single input tensor.
   *
   * Declared `async` to satisfy the {@link Backend} contract even though the
   * FFI call itself is synchronous.
   *
   * @param input Input tensor; a batch dimension of 1 is prepended.
   * @returns The output tensor with this backend's output shape.
   */
  // deno-lint-ignore require-await
  async predict(input: Tensor<Rank>): Promise<Tensor<Rank>> {
    const encodedOptions = encodeJSON({
      inputShape: [1, ...input.shape],
      outputShape: this.outputShape,
    } as PredictOptions);
    // The native call writes the prediction into this preallocated buffer.
    const result = new Float32Array(length(this.outputShape));
    this.library.symbols.ffi_backend_predict(
      this.#id,
      input.data as Float32Array,
      encodedOptions,
      encodedOptions.length,
      result,
    );
    return new Tensor(result, this.outputShape);
  }

  /**
   * Serialize the native backend state.
   *
   * @returns The serialized model bytes produced by the native side.
   */
  save(): Uint8Array {
    const out = new Buffer();
    this.library.symbols.ffi_backend_save(this.#id, out.allocBuffer);
    return out.buffer;
  }

  /**
   * Serialize the backend and write it to disk.
   *
   * @param path Destination file path.
   */
  saveFile(path: string): void {
    Deno.writeFileSync(path, this.save());
  }

  /**
   * Restore a backend from previously saved bytes.
   *
   * @param buffer Serialized model bytes (as produced by {@link save}).
   * @param library FFI bindings to use.
   * @returns A backend wrapping the restored native instance.
   */
  static load(buffer: Uint8Array, library: Library): GPUBackend {
    const shapeBuffer = new Buffer();
    const id = library.symbols.ffi_backend_load(
      buffer,
      buffer.length,
      shapeBuffer.allocBuffer,
    ) as bigint;
    // Same shape protocol as create(): skip the leading element.
    const outputShape = Array.from(shapeBuffer.buffer.slice(1)) as Shape[Rank];
    return new GPUBackend(library, outputShape, id);
  }

  /**
   * Restore a backend from a file on disk.
   *
   * @param path Path to a file written by {@link saveFile}.
   * @param library FFI bindings to use.
   */
  static loadFile(path: string, library: Library): GPUBackend {
    return this.load(Deno.readFileSync(path), library);
  }
}