redis对于各位程序员大佬肯定是耳熟能详了,一个单线程高性能的内存数据库,常用于存放缓存,不用再去读数据库,减少数据库的压力。
为什么这么快呢?除了单线程,除了数据结构,除了全部在内存,除了IO多路复用等,还有就是redis底层通信协议设计的简单高效也是重要的一方面。
redis协议的设计哲学:
1. 实现要简单
2. 对计算机来说,解析速度快
3. 对人类来说,可读性强
redis协议的简单介绍
其实从官方文档的介绍就能知道这协议真的是设计的简单精妙,不拖泥带水。
个人认为“\r\n”的分隔符可以避免网络编程中常说的粘包问题吧。
我们这里不去讨论redis的高级功能,我们这里只关注redis基本的数据结构操作。
官方文档的介绍可以查阅: https://redis.io/docs/reference/protocol-spec/
其实就是以下五种情况:
• For Simple Strings, the first byte of the reply is "+"
• For Errors, the first byte of the reply is "-"
• For Integers, the first byte of the reply is ":"
• For Bulk Strings, the first byte of the reply is "$"
• For Arrays, the first byte of the reply is "*"
第一种就是简单字符串,以字符"+"号开头,后面就接上实际的字符串即可,但是别忘了最后“\r\n”代表字符串结束了,以下几种同理。
第二种就是错误提示,以字符"-"开头。
第三种就是整形类型,以字符":"开头。
第四种就是多条字符串返回的时候会用到,先用字符"$"后面的数字代表后面有多少个字符,这种方式不仅是在发送命令报文的时候会用到,而且在接收响应报文的时候也会用到。
第五种就是代表数字,同样的在发送和接收时都会用到。
如果你现在不太懂没关系,继续看下去,通过手敲代码肯定能够帮助自己理解更深刻。
redis协议简单实现
这里我们选择python语言来实现一下,python语言简洁易懂,非常适合来写一个简单入门的redis client.
事先声明下,这里不会考虑客户端维护连接池什么的,发多个请求去和redis服务端去通信,我们的代码就是线性地一步步地运行,我们关注的重点是在协议的解析上面。
选择的框架是: asyncio
初始化连接
连接的代码比较简单,返回一个reader和writer用于从建立的连接中读取和写入数据。
class RedisClient:
def __init__(self):
self.writer = None
self.reader = None
# 连接redis服务器的ip地址和端口
async def connect(self,host,port):
self.reader, self.writer = await asyncio.open_connection(host, port)
set key value
# 设置一个key value
async def set(self, key, value):
# 发送命令
self.writer.write(f"set \"{key}\" \"{value}\"\r\n".encode())
await self.writer.drain()
return await self.get_reply()
这里也没啥难度吧
get key
# 获取一个key
async def get(self, key):
# 发送命令
self.writer.write(f"get \"{key}\"\r\n".encode())
await self.writer.drain()
return await self.get_reply()
一样很好理解
解析响应(RESP)
这里就是重头戏了,先放一下代码:
# 接收响应
async def get_reply(self):
# 识别返回响应的第一个字符
tag = (await self.reader.read(1))
# 如果是简单字符串响应
if tag == b'+':
result = b''
next_char = b''
while next_char != b'\n':
next_char = (await self.reader.read(1))
result += next_char
return result[:-1].decode()
# 如果是错误
elif tag == b'-':
result = b''
next_char = b''
while next_char != b'\n':
next_char = (await self.reader.read(1))
result += next_char
raise Exception(result[:-1].decode())
# 如果是bulk字符串
elif tag == b'$':
# 先获取总长度
result = b''
next_char = b''
while next_char != b'\n':
next_char = (await self.reader.read(1))
result += next_char
# 加上最后\r\n的2个字符长度
total_length = int(result[:-1].decode()) + 2
# 如果是-1的话,其实是代表字符串为一个NULL value
if total_length == 1:
await self.reader.read(2)
return ''
result = b''
while len(result) < total_length:
result += (await self.reader.read(total_length))
return result[:-2].decode()
# 说明返回的是一个数字
elif tag == b':':
result = b''
next_char = b''
while next_char != b'\n':
next_char = (await self.reader.read(1))
result += next_char
return int(result[:-1].decode())
# 说明返回的是一个数组
elif tag == b'*':
# 先获取数组总长度
result = b''
next_char = b''
while next_char != b'\n':
next_char = (await self.reader.read(1))
result += next_char
# 加上最后\r\n的2个字符长度
arr_length = int(result[:-1].decode())
if arr_length == 0:
return []
else:
res = []
for _ in range(arr_length):
res.append(await self.get_reply())
return res
else:
msg = (await self.reader.read(100))
# msg[:-2]代表去掉最后的\r\n
raise Exception(f"unsupported tag: {tag}, msg: {msg[:-2].decode()}")
正常客户端在建立的连接中发送命令报文,服务端收到的时候肯定会返回客户端响应报文,因此客户端需要通过reader去读取服务端返回的数据到底是什么意思。
所以才有“协议”这种双方约定好的东西。
所以一般来说,客户端都是先读取响应报文的第一个字符来判断接下来的操作,对于四种简单的来说,我们这里只讲解下简单字符串的,比如服务端给客户端回复的是:
+OK\r\n
客户端收到的时候发现第一个字符是+号,那么后面只需要读取一个简单的字符串就行,遇到\r\n肯定代表字符串读取结束了,那么直接根据OK去正常处理之后的逻辑即可。
特别要注意的是数组的报文形式,数组的这种格式也可以用在发送端,比如上面的send方法,例如:
*3\r\n$3\r\nset\r\n$5\r\nhello\r\n$5\r\nworld\r\n
其实这个报文翻译下来就是发送一个set命令:
set hello world
对照一下,第一个3代表该报文有三个字符串,后面每个字符串都要先指定字符串的长度再跟上实际的字符串即可。
是不是也不是那么复杂?
为什么要有这个数组的形式呢?大家可以自己思考下,我的理解是:正常字符串中如果有\r\n字符串的话,会影响分隔符的判断。
例如下面的这段代码,如果直接用set命令会报错:
# 会报错:ERR Protocol error: unbalanced quotes in request
#print(await client.set("banana", "hello \n world"))
print(await client.send("set", "banana", "hello \n world"))
下面放一下完整的代码,大家有空的时候一定要动手敲一遍才会有更深刻的认识。
import asyncio
class RedisClient:
def __init__(self):
self.writer = None
self.reader = None
async def connect(self,host,port):
self.reader, self.writer = await asyncio.open_connection(host, port)
# 直接发送命令
async def send_cmd(self, cmd):
self.writer.write(f"{cmd}\r\n".encode())
await self.writer.drain()
return await self.get_reply()
# 设置一个key value
async def set(self, key, value):
# 发送命令
self.writer.write(f"set \"{key}\" \"{value}\"\r\n".encode())
await self.writer.drain()
return await self.get_reply()
# 获取一个key
async def get(self, key):
# 发送命令
self.writer.write(f"get \"{key}\"\r\n".encode())
await self.writer.drain()
return await self.get_reply()
# 增长一个key
async def incr(self, key):
# 发送命令
self.writer.write(f"incr \"{key}\"\r\n".encode())
await self.writer.drain()
return await self.get_reply()
# 接收响应
async def get_reply(self):
# 识别返回响应的第一个字符
tag = (await self.reader.read(1))
# 如果是简单字符串响应
if tag == b'+':
result = b''
next_char = b''
while next_char != b'\n':
next_char = (await self.reader.read(1))
result += next_char
return result[:-1].decode()
# 如果是错误
elif tag == b'-':
result = b''
next_char = b''
while next_char != b'\n':
next_char = (await self.reader.read(1))
result += next_char
raise Exception(result[:-1].decode())
# 如果是bulk字符串
elif tag == b'$':
# 先获取总长度
result = b''
next_char = b''
while next_char != b'\n':
next_char = (await self.reader.read(1))
result += next_char
# 加上最后\r\n的2个字符长度
total_length = int(result[:-1].decode()) + 2
# 如果是-1的话,其实是代表字符串为一个NULL value
if total_length == 1:
await self.reader.read(2)
return ''
result = b''
while len(result) < total_length:
result += (await self.reader.read(total_length))
return result[:-2].decode()
# 说明返回的是一个数字
elif tag == b':':
result = b''
next_char = b''
while next_char != b'\n':
next_char = (await self.reader.read(1))
result += next_char
return int(result[:-1].decode())
# 说明返回的是一个数组
elif tag == b'*':
# 先获取数组总长度
result = b''
next_char = b''
while next_char != b'\n':
next_char = (await self.reader.read(1))
result += next_char
# 加上最后\r\n的2个字符长度
arr_length = int(result[:-1].decode())
if arr_length == 0:
return []
else:
res = []
for _ in range(arr_length):
res.append(await self.get_reply())
return res
else:
msg = (await self.reader.read(100))
# msg[:-2]代表去掉最后的\r\n
raise Exception(f"unsupported tag: {tag}, msg: {msg[:-2].decode()}")
# 以数组形式发送
async def send(self,*args):
cmd_args = "".join([f"${len(x)}\r\n{x}\r\n" for x in args])
self.writer.write(f"*{len(args)}\r\n{cmd_args}".encode())
await self.writer.drain()
return await self.get_reply()
async def main():
print("=====redis client for py3 started!=====")
client = RedisClient()
# 连接redis
await client.connect("localhost",6379)
print(await client.set("banana","hello world"))
print(await client.get("banana"))
# 删除,不然会报错
print(await client.send_cmd("del banana"))
print(await client.incr("banana"))
print(await client.incr("banana"))
print(await client.incr("banana"))
print(await client.send_cmd("del banana"))
# 会报错:ERR Protocol error: unbalanced quotes in request
#print(await client.set("banana", "hello \n world"))
print(await client.send("set", "banana", "hello \n world"))
print(await client.send("hset","myhash","k1", "v1", "k2", "v2"))
print(await client.send("hget","myhash","k1"))
print(await client.send("hgetall","myhash"))
if __name__ == '__main__':
asyncio.run(main())
python代码水平有限,如果哪里有所不足,欢迎指正,就先写到这!




