gpt2 inference working
xl0 committed Dec 7, 2023
1 parent a65c794 commit 40e376b
Showing 7 changed files with 598 additions and 65 deletions.
138 changes: 120 additions & 18 deletions nbs/01_tensor.ipynb
@@ -32,6 +32,7 @@
"import numpy as np\n",
"from lovely_numpy import lovely\n",
"\n",
"\n",
"class Tensor:\n",
" ..."
]
@@ -57,14 +58,14 @@
" \"\"\"Calculate the target shape for broadcasting two tensors\"\"\"\n",
"\n",
" # expand shaped to be the same length. Note (1,) * <negative> is empty\n",
" s2 = (1, ) * (len(s1) - len(s2)) + s2\n",
" s1 = (1, ) * (len(s2) - len(s1)) + s1\n",
" s2 = (1,) * (len(s1) - len(s2)) + s2\n",
" s1 = (1,) * (len(s2) - len(s1)) + s1\n",
"\n",
" out_shape = ()\n",
" for dims in list(zip(reversed(s1), reversed(s2))):\n",
" if dims[0] != 1 and dims[1] != 1 and dims[0] != dims[1]:\n",
" raise ValueError(f\"Cannot broadcast {s1} and {s2}\")\n",
" out_shape = (max(dims), ) + out_shape\n",
" out_shape = (max(dims),) + out_shape\n",
"\n",
" return out_shape"
]
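
As a sanity check on the rule in this cell, here is a minimal standalone sketch; the name broadcast_shape and the numpy cross-check are illustrative additions, not the notebook's code:

import numpy as np

def broadcast_shape(s1: tuple, s2: tuple) -> tuple:
    """Same right-aligned broadcasting rule as the cell above."""
    # Left-pad the shorter shape with 1s; (1,) * <negative> is empty.
    s2 = (1,) * (len(s1) - len(s2)) + s2
    s1 = (1,) * (len(s2) - len(s1)) + s1

    out_shape = ()
    for d1, d2 in zip(reversed(s1), reversed(s2)):
        if d1 != 1 and d2 != 1 and d1 != d2:
            raise ValueError(f"Cannot broadcast {s1} and {s2}")
        out_shape = (max(d1, d2),) + out_shape
    return out_shape

assert broadcast_shape((2, 3), (3,)) == (2, 3)
# Cross-check against numpy's own broadcasting rule:
assert broadcast_shape((4, 1, 5), (3, 1)) == np.broadcast_shapes((4, 1, 5), (3, 1))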
@@ -88,6 +89,7 @@
"\n",
" return a, b\n",
"\n",
"\n",
"def maybe_broadcast_matmul(a: Tensor, b: Tensor):\n",
" \"\"\"Broadcast two tensors if they have different shapes, except for the last two dimensions\"\"\"\n",
"\n",
Expand All @@ -99,8 +101,16 @@
" # print(\n",
" # f\"Matmul broadcasted {a.data.shape} and {b.data.shape} to {target_shape + a.data.shape[-2:]} and {target_shape + b.data.shape[-2:]}\"\n",
" # )\n",
" a = (a.broadcast(target_shape + a.data.shape[-2:]) if a_short_shape != target_shape else a)\n",
" b = (b.broadcast(target_shape + b.data.shape[-2:]) if b_short_shape != target_shape else b)\n",
" a = (\n",
" a.broadcast(target_shape + a.data.shape[-2:])\n",
" if a_short_shape != target_shape\n",
" else a\n",
" )\n",
" b = (\n",
" b.broadcast(target_shape + b.data.shape[-2:])\n",
" if b_short_shape != target_shape\n",
" else b\n",
" )\n",
"\n",
" return a, b"
]
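
Why only the leading dimensions are broadcast: in a batched matmul, the last two axes do the matrix product and everything in front behaves like an elementwise op. A quick standalone numpy illustration (my example, not library code):

import numpy as np

a = np.random.randn(3, 1, 4, 5)  # batch dims (3, 1), matrix dims (4, 5)
b = np.random.randn(2, 5, 6)     # batch dims (2,),   matrix dims (5, 6)

# Batch dims broadcast like an elementwise op: (3, 1) and (2,) -> (3, 2).
# Matrix dims contract over the shared 5:      (4, 5) @ (5, 6) -> (4, 6).
assert (a @ b).shape == (3, 2, 4, 6)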
@@ -112,37 +122,59 @@
"outputs": [],
"source": [
"# | export\n",
"_num_ops = 0\n",
"\n",
"\n",
"class BaseOp:\n",
" \"\"\"Base class for all operations\"\"\"\n",
"\n",
" name_template = \"??\"\n",
"\n",
" def __init__(self, *args, name: str = None):\n",
" assert isinstance(name, (str, type(None))), f\"name= should be str, got {type(name)}. You probably meant something else.\"\n",
"\n",
" self.args = [arg if isinstance(arg, Tensor) else Tensor(data=np.asarray(arg, dtype=np.float32)) for arg in args]\n",
" self.name = (self.name_template.format(*[arg.name for arg in self.args]) if name is None else name)\n",
" global _num_ops\n",
" _num_ops += 1\n",
" assert isinstance(\n",
" name, (str, type(None))\n",
" ), f\"name= should be str, got {type(name)}. You probably meant something else.\"\n",
"\n",
" self.args = [\n",
" arg\n",
" if isinstance(arg, Tensor)\n",
" else Tensor(data=np.asarray(arg, dtype=np.float32))\n",
" for arg in args\n",
" ]\n",
" self.name = \"\" # (self.name_template.format(*[arg.name for arg in self.args]) if name is None else name)\n",
" self.requires_grad = any(arg.requires_grad for arg in self.args) and _grad\n",
"\n",
" def set_out(self, data):\n",
" self.out = Tensor(data=data, requires_grad=self.requires_grad, name=self.name, op=self)\n",
" self.out = Tensor(\n",
" data=data, requires_grad=self.requires_grad, name=self.name, op=self\n",
" )\n",
"\n",
" def check_backward(self):\n",
" # Add more checks here?\n",
" assert (self.out.requires_grad), f\"You are trying to backpropagate through a non-differentiable operation:\\n{self}\"\n",
" assert (\n",
" self.out.requires_grad\n",
" ), f\"You are trying to backpropagate through a non-differentiable operation:\\n{self}\"\n",
"\n",
" def __repr__(self):\n",
" return (f\"{self.__class__.__name__}({', '.join([str(arg) for arg in self.args])})\")\n",
" return (\n",
" f\"{self.__class__.__name__}({', '.join([str(arg) for arg in self.args])})\"\n",
" )\n",
"\n",
"\n",
"class BinaryElementwiseOp(BaseOp):\n",
" \"\"\"Base class for binary elementwise operations\"\"\"\n",
"\n",
" def __init__(self, a, b, name=None):\n",
" super().__init__(a, b, name=name)\n",
" self.args = maybe_broadcast_elementwise(*self.args)\n",
" self.parents = self.args if self.requires_grad else []\n",
"\n",
"\n",
"class UnaryElementwiseOp(BaseOp):\n",
" \"\"\"Base class for unary elementwise operations\"\"\"\n",
"\n",
" def __init__(self, a, name=None):\n",
" super().__init__(a, name=name)\n",
" self.parents = self.args if self.requires_grad else []"
@@ -391,6 +423,41 @@
" self.parents[0].accum_grad(summed)\n",
"\n",
"\n",
"class Slice(UnaryElementwiseOp):\n",
" name_template = \"slice({})\"\n",
"\n",
" def __init__(self, a, key, name=None):\n",
" super().__init__(a, name=name)\n",
" self.key = key\n",
" self.set_out(self.args[0].data[key])\n",
"\n",
" def backward(self):\n",
" self.check_backward()\n",
" p = self.parents[0]\n",
"\n",
" if not p.requires_grad:\n",
" return\n",
"\n",
" if p.grad is None:\n",
" p.grad = np.zeros_like(p.data)\n",
"\n",
" p.grad[self.key] += self.out.grad\n",
"\n",
"# class SetSlice(BaseOp):\n",
"# name_template = \"setslice({})\"\n",
" \n",
"# def __init__(self, a: Tensor, key, value: Tensor, name=None):\n",
"# # super().__init__(a, value, name=name)\n",
"# self.key = key\n",
" \n",
"# a.out[key] = value.out\n",
"# a.parents.append(value)\n",
"\n",
"\n",
"\n",
"\n",
"\n",
"\n",
"# class LessThan(BinaryElementwiseOp):\n",
"# name_template = \"({}<{})\"\n",
"\n",
@@ -459,11 +526,17 @@
"outputs": [],
"source": [
"# | export\n",
"\n",
"_num_tensors = 0\n",
"\n",
"\n",
"class Tensor:\n",
" # op = \"L\"\n",
" name: str = \"\"\n",
"\n",
" def __init__(self, data, name=None, op=None, eps=1e-8, requires_grad=False):\n",
" global _num_tensors\n",
" _num_tensors += 1\n",
" self.data = np.asarray(data)\n",
"\n",
" self.grad = (\n",
@@ -482,7 +555,8 @@
" if self.op.parents\n",
" else \"\"\n",
" )\n",
" return f'Tensor{list(self.data.shape)}(name=\"{self.name}\" op={type(self.op).__name__}{parents}):\\n {value_str}\\n {grad_str}'\n",
" # name=\"{self.name}\n",
" return f'Tensor{list(self.data.shape)}(\" op={type(self.op).__name__}{parents}):\\n {value_str}\\n {grad_str}'\n",
"\n",
" def accum_grad(self, grad):\n",
" if not self.requires_grad:\n",
@@ -552,9 +626,7 @@
" else np.prod([self.data.shape[i] for i in axis])\n",
" )\n",
"\n",
" corrected = var.sum(axis=axis, keepdims=keepdims) / (\n",
" numel - correction\n",
" )\n",
" corrected = var.sum(axis=axis, keepdims=keepdims) / (numel - correction)\n",
"\n",
" return corrected**0.5\n",
"\n",
@@ -595,6 +667,12 @@
" other = other if isinstance(other, Tensor) else Tensor(other)\n",
" return self.data == other.data\n",
"\n",
" def __getitem__(self, key):\n",
" return Slice(self, key).out\n",
" \n",
" def __setitem__(self, key, value):\n",
" return SetSlice(self, key, value)\n",
"\n",
" @property\n",
" def shape(self):\n",
" return self.data.shape\n",
@@ -639,8 +717,32 @@
{
"data": {
"text/plain": [
"Tensor[2, 3](name=\"x\" op=Load):\n",
" v=array[2, 3] n=6 x∈[-1.591, 2.410] μ=0.177 σ=1.231 [[-0.272, 0.850, 2.410], [0.039, -1.591, -0.376]]\n",
"Tensor[5, 5](\" op=Slice):\n",
" v=array[5, 5] n=25 x∈[-2.029, 2.323] μ=0.259 σ=1.027\n",
" "
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"x = Tensor(np.random.randn(10, 10), name=\"x\")\n",
"\n",
"x[:5, 5:]"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"Tensor[2, 3](\" op=Load):\n",
" v=array[2, 3] n=6 x∈[-0.472, 1.793] μ=0.604 σ=0.758 [[1.215, 1.793, 0.202], [0.824, 0.061, -0.472]]\n",
" "
]
},