llvm: Implement compiled Angle and SoftMax derivative Functions #2528

Merged · 7 commits · Nov 9, 2022
llvm, functions/SoftMax: Implement compiled 'derivative' variant
Signed-off-by: Jan Vesely <[email protected]>
jvesely committed Nov 9, 2022
commit cae1465b319539f17bb651dd7961ede88a7c4022
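
For reference (not stated in the commit itself): with s = softmax(x) and m the index of its largest element, the derivative compiled here is the Jacobian row of that element, ignoring any gain scaling:

    \frac{\partial s_m}{\partial x_i} = s_m(\delta_{im} - s_i) = s_i(\delta_{im} - s_m)

The compiled loops below use the right-hand form: each output element is sm[i] * ((i == max_pos ? 1 : 0) - max_val).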
@@ -2697,8 +2697,47 @@ def __gen_llvm_apply(self, ctx, builder, params, state, arg_in, arg_out, output_

return builder

def _gen_llvm_function_derivative_body(self, ctx, builder, params, state, arg_in, arg_out, *, tags:frozenset):
assert "derivative" in tags
forward_tags = tags.difference({"derivative"})
all_out = builder.alloca(arg_out.type.pointee)
builder = self._gen_llvm_function_body(ctx, builder, params, state, arg_in, all_out, output_type=ALL, tags=forward_tags)

max_pos_ptr = builder.alloca(ctx.int32_ty)
builder.store(max_pos_ptr.type.pointee(-1), max_pos_ptr)
max_val_ptr = builder.alloca(arg_out.type.pointee.element)
builder.store(max_val_ptr.type.pointee(float("NaN")), max_val_ptr)

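        # Find the maximum of all_out; the unordered ">" compare is also true
        # against the initial NaN, so the first element seeds the running max.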
with pnlvm.helpers.array_ptr_loop(builder, all_out, id="max") as (b, idx):
val_ptr = b.gep(all_out, [ctx.int32_ty(0), idx])
val = b.load(val_ptr)
max_val = b.load(max_val_ptr)
new_max = b.fcmp_unordered(">", val, max_val)
with b.if_then(new_max):
b.store(val, max_val_ptr)
b.store(idx, max_pos_ptr)

max_val = builder.load(max_val_ptr)
max_pos = builder.load(max_pos_ptr)

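        # Derivative loop: out[i] = sm[i] * ((i == max_pos ? 1 : 0) - max_val),
        # i.e. the softmax Jacobian row of the maximum element.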
with pnlvm.helpers.array_ptr_loop(builder, all_out, id="derivative") as (b, idx):
val_ptr = b.gep(all_out, [ctx.int32_ty(0), idx])
val = b.load(val_ptr)
is_max_pos = b.icmp_unsigned("==", idx, max_pos)

d = b.select(is_max_pos, val.type(1), val.type(0))
dv = b.fsub(d, max_val)
val = b.fmul(val, dv)

out_ptr = b.gep(arg_out, [ctx.int32_ty(0), idx])
b.store(val, out_ptr)

return builder

def _gen_llvm_function_body(self, ctx, builder, params, state, arg_in, arg_out, output_type=None, *, tags:frozenset):
output_type = self.output if output_type is None else output_type
if "derivative" in tags:
return self._gen_llvm_function_derivative_body(ctx, builder, params, state, arg_in, arg_out, tags=tags)

if self.parameters.per_item.get():
assert isinstance(arg_in.type.pointee.element, pnlvm.ir.ArrayType)
@@ -2781,8 +2820,8 @@ def derivative(self, input=None, output=None, context=None):
"""

output_type = self._get_current_parameter_value(OUTPUT_TYPE, context)
- size = len(output)
- sm = self.function(output, params={OUTPUT_TYPE: ALL}, context=context)
+ size = len(input)
+ sm = self.function(input, params={OUTPUT_TYPE: ALL}, context=context)
sm = np.squeeze(sm)

if output_type == ALL:
@@ -2800,7 +2839,7 @@ def derivative(self, input=None, output=None, context=None):
# Return 1d array of derivatives for max element (i.e., the one chosen by SoftMax)
derivative = np.empty(size)
# Get the element of output returned as non-zero when output_type is not ALL
- index_of_max = int(np.where(output == np.max(output))[0][0])
+ index_of_max = int(np.where(sm == np.max(sm))[0])
max_val = sm[index_of_max]
for i in range(size):
if i == index_of_max:
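Roughly, the compiled derivative body above amounts to the following NumPy sketch (illustrative only: the helper name is made up, and the actual compiled forward pass may handle gain and stability shifts differently):

    import numpy as np

    def softmax_max_val_derivative(x, gain=1.0):
        # Forward pass with output_type=ALL: the full softmax distribution.
        e = np.exp(gain * (x - np.max(x)))  # max-shift for numerical stability
        sm = e / e.sum()
        # "max" loop: position and value of the largest output.
        max_pos = int(np.argmax(sm))
        max_val = sm[max_pos]
        # "derivative" loop: d_i = sm_i * (1[i == max_pos] - max_val).
        indicator = (np.arange(sm.size) == max_pos).astype(sm.dtype)
        return sm * (indicator - max_val)

This also makes the Python-side change above easier to read: derivative now recomputes the softmax from the function's input rather than reusing the passed-in output.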
6 changes: 6 additions & 0 deletions tests/functions/test_transfer.py
@@ -82,6 +82,12 @@ def test_execute(func, variable, params, expected, benchmark, func_mode):
(Functions.Logistic, test_var, {'gain':RAND1, 'x_0':RAND2, 'offset':RAND3, 'scale':RAND4}, RAND1 * RAND4 * logistic_helper * (1 - logistic_helper)),
(Functions.ReLU, test_var, {'gain':RAND1, 'bias':RAND2, 'leak':RAND3}, np.where((test_var - RAND2) > 0, RAND1, RAND1 * RAND3)),
(Functions.Tanh, test_var, {'gain':RAND1, 'bias':RAND2, 'offset':RAND3, 'scale':RAND4}, tanh_derivative_helper),
(Functions.SoftMax, test_var, {'gain':RAND1, 'params':{kw.OUTPUT_TYPE:kw.MAX_VAL}, 'per_item': False},
[-0.010680386821751537, -0.011118109698906909, -0.01082040340318878, -0.010670257514724047, -0.010362498859374309,
-0.010933660158663306, -0.010397412260182806, -0.011602329078808718, 0.09684744183944892, -0.010262384043848513]),
(Functions.SoftMax, test_var, {'gain':RAND1, 'params':{kw.OUTPUT_TYPE:kw.MAX_INDICATOR}, 'per_item': False},
[-0.010680386821751537, -0.011118109698906909, -0.01082040340318878, -0.010670257514724047, -0.010362498859374309,
-0.010933660158663306, -0.010397412260182806, -0.011602329078808718, 0.09684744183944892, -0.010262384043848513]),
]

@pytest.mark.function
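Note that the two new test cases expect identical vectors: MAX_VAL and MAX_INDICATOR differ only in the forward output, while the derivative is taken of the same maximum element in both. A hedged usage sketch (the top-level import path and the 'output' constructor argument are assumptions, not taken from this PR):

    import numpy as np
    import psyneulink as pnl

    # Derivative of the maximum softmax element with respect to each input.
    sm = pnl.SoftMax(gain=2.0, output=pnl.MAX_VAL)
    print(sm.derivative(input=np.array([0.1, 0.5, 0.2])))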