Skip to content

Commit

Permalink
implement get set functinality
Browse files Browse the repository at this point in the history
  • Loading branch information
rebeccacxy committed Mar 9, 2024
1 parent 609fe26 commit 8d42623
Show file tree
Hide file tree
Showing 3 changed files with 156 additions and 10 deletions.
72 changes: 62 additions & 10 deletions app/main.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,74 @@
import asyncio
import io
import app.redis_db as redis_db

def parse_wire_protocol(message):
return _parse_wire_protocol(io.BytesIO(message))

def _parse_wire_protocol(msg_buffer):
current_line = msg_buffer.readline()
msg_type, remaining = chr(current_line[0]), current_line[1:]
if msg_type == '+':
return remaining.rstrip(b'\r\n').decode()
elif msg_type == ':':
return int(remaining)
elif msg_type == '$':
msg_length = int(remaining)
if msg_length == -1:
return None
result = msg_buffer.read(msg_length)
msg_buffer.readline() # move past \r\n
return result
elif msg_type == '*':
array_length = int(remaining)
return [_parse_wire_protocol(msg_buffer) for _ in range(array_length)]

def serialize_to_wire(value):
if isinstance(value, str):
return ('+%s' % value).encode() + b'\r\n'
elif isinstance(value, bool) and value:
return b"+OK\r\n"
elif isinstance(value, int):
return (':%s' % value).encode() + b'\r\n'
elif isinstance(value, bytes):
return (b'$' + str(len(value)).encode() +
b'\r\n' + value + b'\r\n')
elif value is None:
return b'$-1\r\n'
elif isinstance(value, list):
base = b'*' + str(len(value)).encode() + b'\r\n'
for item in value:
base += serialize_to_wire(item)
return base

class RedisServerProtocol(asyncio.Protocol):
def __init__(self, db):
self._db = db
self.transport = None

def connection_made(self, transport):
self.transport = transport

def data_received(self, data):
message = data.decode()
if 'GET' in message:
self.transport.write(b"$3\r\n")
self.transport.write(b"BAZ\r\n")
else:
self.transport.write(b"-ERR unknown command\r\n")

def main():
parsed = parse_wire_protocol(data)
command = parsed[0].decode().lower()
if command == 'get':
response = self._db.get(parsed[1])
elif command == 'set':
response = self._db.set(parsed[1], parsed[2])

wire_response = serialize_to_wire(response)
self.transport.write(wire_response)

def main(hostname='localhost', port=6379):
loop = asyncio.get_event_loop()
coro = loop.create_server(RedisServerProtocol, '127.0.0.1', 6379)
server = loop.run_until_complete(coro)
db = redis_db.DB()
protocol_factory = lambda: RedisServerProtocol(db)

coro = loop.create_server(protocol_factory,
hostname, port)
server = loop.run_until_complete(coro)
print("Listening on port {}".format(port))
try:
loop.run_forever()
except KeyboardInterrupt:
Expand Down
82 changes: 82 additions & 0 deletions app/my_redis.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
import asyncio
import io

def parse_wire_protocol(message):
return _parse_wire_protocol(io.BytesIO(message))

def _parse_wire_protocol(msg_buffer):
current_line = msg_buffer.readline()
msg_type, remaining = chr(current_line[0]), current_line[1:]
if msg_type == '+':
return remaining.rstrip(b'\r\n').decode()
elif msg_type == ':':
return int(remaining)
elif msg_type == '$':
msg_length = int(remaining)
if msg_length == -1:
return None
result = msg_buffer.read(msg_length)
# There's a '\r\n' that comes after a bulk string
# so we .readline() to move passed that crlf.
msg_buffer.readline()
return result
elif msg_type == '*':
array_length = int(remaining)
return [_parse_wire_protocol(msg_buffer) for _ in range(array_length)]

def serialize_to_wire(value):
if isinstance(value, str):
return ('+%s' % value).encode() + b'\r\n'
elif isinstance(value, bool) and value:
return b"+OK\r\n"
elif isinstance(value, int):
return (':%s' % value).encode() + b'\r\n'
elif isinstance(value, bytes):
return (b'$' + str(len(value)).encode() +
b'\r\n' + value + b'\r\n')
elif value is None:
return b'$-1\r\n'
elif isinstance(value, list):
base = b'*' + str(len(value)).encode() + b'\r\n'
for item in value:
base += serialize_to_wire(item)
return base

class RedisServerProtocol(asyncio.Protocol):
def __init__(self, redis):
self._redis = redis
self.transport = None

def connection_made(self, transport):
self.transport = transport

def data_received(self, data):
parsed = parse_wire_protocol(data)
# parsed is an array of [command, *args]
command = parsed[0].decode().lower()
try:
method = getattr(self._redis, command)
except AttributeError:
self.transport.write(
b"-ERR unknown command " + parsed[0] + b"\r\n")
return
result = method(*parsed[1:])
serialized = serialize_to_wire(result)
self.transport.write(serialized)

class WireRedisConverter(object):
def __init__(self, redis):
self._redis = redis

def lrange(self, name, start, end):
return self._redis.lrange(name, int(start), int(end))

def hmset(self, name, *args):
converted = {}
iter_args = iter(list(args))
for key, val in zip(iter_args, iter_args):
converted[key] = val
return self._redis.hmset(name, converted)

def __getattr__(self, name):
return getattr(self._redis, name)
12 changes: 12 additions & 0 deletions app/redis_db.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
class DB:
def __init__(self, db=None):
if db is None:
db = {}
self._db = db

def get(self, item):
return self._db.get(item)

def set(self, item, value):
self._db[item] = value
return True

0 comments on commit 8d42623

Please sign in to comment.