暂无图片
暂无图片
暂无图片
暂无图片
暂无图片

是不是很好奇如何动手写一个简单的redis客户端?

游在鱼里的水 2022-12-12
417

redis对于各位程序员大佬肯定是耳熟能详了,一个单线程高性能的内存数据库,常用于存放缓存,不用再去读数据库,减少数据库的压力。

为什么这么快呢?除了单线程,除了数据结构,除了全部在内存,除了IO多路复用等,还有就是redis底层通信协议设计的简单高效也是重要的一方面。

redis协议的设计哲学:

  1. 1. 实现要简单

  2. 2. 对计算机来说,解析速度快

  3. 3. 对人类来说,可读性强

redis协议的简单介绍

其实从官方文档的介绍就能知道这协议真的是设计的简单精妙,不拖泥带水。

个人认为“\r\n”的分隔符可以避免网络编程中常说的粘包问题吧。

我们这里不去讨论redis的高级功能,我们这里只关注redis基本的数据结构操作。

官方文档的介绍可以查阅: https://redis.io/docs/reference/protocol-spec/

其实就是以下五种情况:

  • • For Simple Strings, the first byte of the reply is "+"

  • • For Errors, the first byte of the reply is "-"

  • • For Integers, the first byte of the reply is ":"

  • • For Bulk Strings, the first byte of the reply is "$"

  • • For Arrays, the first byte of the reply is "*"

第一种就是简单字符串,以字符"+"号开头,后面就接上实际的字符串即可,但是别忘了最后“\r\n”代表字符串结束了,以下几种同理。

第二种就是错误提示,以字符"-"开头。

第三种就是整形类型,以字符":"开头。

第四种就是多条字符串返回的时候会用到,先用字符"$"后面的数字代表后面有多少个字符,这种方式不仅是在发送命令报文的时候会用到,而且在接收响应报文的时候也会用到。

第五种就是代表数字,同样的在发送和接收时都会用到。

如果你现在不太懂没关系,继续看下去,通过手敲代码肯定能够帮助自己理解更深刻。

redis协议简单实现

这里我们选择python语言来实现一下,python语言简洁易懂,非常适合来写一个简单入门的redis client.

事先声明下,这里不会考虑客户端维护连接池什么的,发多个请求去和redis服务端去通信,我们的代码就是线性地一步步地运行,我们关注的重点是在协议的解析上面。

选择的框架是: asyncio

初始化连接

连接的代码比较简单,返回一个reader和writer用于从建立的连接中读取和写入数据。

class RedisClient:
    def __init__(self):
        self.writer = None
        self.reader = None
    # 连接redis服务器的ip地址和端口
    async def connect(self,host,port):
        self.reader, self.writer = await asyncio.open_connection(host, port)

set key value

# 设置一个key value
    async def set(self, key, value):
        # 发送命令
        self.writer.write(f"set \"{key}\" \"{value}\"\r\n".encode())
        await self.writer.drain()
        return await self.get_reply()

这里也没啥难度吧

get key

# 获取一个key
    async def get(self, key):
        # 发送命令
        self.writer.write(f"get \"{key}\"\r\n".encode())
        await self.writer.drain()
        return await self.get_reply()

一样很好理解

解析响应(RESP)

这里就是重头戏了,先放一下代码:

# 接收响应
    async def get_reply(self):
        # 识别返回响应的第一个字符
        tag = (await self.reader.read(1))
        # 如果是简单字符串响应
        if tag == b'+':
            result = b''
            next_char = b''
            while next_char != b'\n':
                next_char = (await self.reader.read(1))
                result += next_char
            return result[:-1].decode()
        # 如果是错误
        elif tag == b'-':
            result = b''
            next_char = b''
            while next_char != b'\n':
                next_char = (await self.reader.read(1))
                result += next_char
            raise  Exception(result[:-1].decode())
        # 如果是bulk字符串
        elif tag == b'$':
            # 先获取总长度
            result = b''
            next_char = b''
            while next_char != b'\n':
                next_char = (await self.reader.read(1))
                result += next_char
            # 加上最后\r\n的2个字符长度
            total_length = int(result[:-1].decode()) + 2
            # 如果是-1的话,其实是代表字符串为一个NULL value
            if total_length == 1:
                await self.reader.read(2)
                return ''
            result = b''
            while len(result) < total_length:
                result += (await self.reader.read(total_length))
            return result[:-2].decode()
        # 说明返回的是一个数字
        elif tag == b':':
            result = b''
            next_char = b''
            while next_char != b'\n':
                next_char = (await self.reader.read(1))
                result += next_char
            return int(result[:-1].decode())
        # 说明返回的是一个数组
        elif tag == b'*':
            # 先获取数组总长度
            result = b''
            next_char = b''
            while next_char != b'\n':
                next_char = (await self.reader.read(1))
                result += next_char
            # 加上最后\r\n的2个字符长度
            arr_length = int(result[:-1].decode())
            if arr_length == 0:
                return []
            else:
                res = []
                for _ in range(arr_length):
                    res.append(await self.get_reply())
                return res
        else:
            msg = (await self.reader.read(100))
            # msg[:-2]代表去掉最后的\r\n
            raise Exception(f"unsupported tag: {tag}, msg: {msg[:-2].decode()}")

正常客户端在建立的连接中发送命令报文,服务端收到的时候肯定会返回客户端响应报文,因此客户端需要通过reader去读取服务端返回的数据到底是什么意思。

所以才有“协议”这种双方约定好的东西。

所以一般来说,客户端都是先读取响应报文的第一个字符来判断接下来的操作,对于四种简单的来说,我们这里只讲解下简单字符串的,比如服务端给客户端回复的是:

+OK\r\n

客户端收到的时候发现第一个字符是+号,那么后面只需要读取一个简单的字符串就行,遇到\r\n肯定代表字符串读取结束了,那么直接根据OK去正常处理之后的逻辑即可。

特别要注意的是数组的报文形式,数组的这种格式也可以用在发送端,比如上面的send方法,例如:

*3\r\n$3\r\nset\r\n$5\r\nhello\r\n$5\r\nworld\r\n

其实这个报文翻译下来就是发送一个set命令:

set hello world

对照一下,第一个3代表该报文有三个字符串,后面每个字符串都要先指定字符串的长度再跟上实际的字符串即可。

是不是也不是那么复杂?

为什么要有这个数组的形式呢?大家可以自己思考下,我的理解是:正常字符串中如果有\r\n字符串的话,会影响分隔符的判断。

例如下面的这段代码,如果直接用set命令会报错:

# 会报错:ERR Protocol error: unbalanced quotes in request
    #print(await client.set("banana", "hello \n world"))
    print(await client.send("set""banana""hello \n world"))

下面放一下完整的代码,大家有空的时候一定要动手敲一遍才会有更深刻的认识。

import asyncio

class RedisClient:
    def __init__(self):
        self.writer = None
        self.reader = None

    async def connect(self,host,port):
        self.reader, self.writer = await asyncio.open_connection(host, port)

    # 直接发送命令
    async def send_cmd(self, cmd):
        self.writer.write(f"{cmd}\r\n".encode())
        await self.writer.drain()
        return await self.get_reply()
    # 设置一个key value
    async def set(self, key, value):
        # 发送命令
        self.writer.write(f"set \"{key}\" \"{value}\"\r\n".encode())
        await self.writer.drain()
        return await self.get_reply()

    # 获取一个key
    async def get(self, key):
        # 发送命令
        self.writer.write(f"get \"{key}\"\r\n".encode())
        await self.writer.drain()
        return await self.get_reply()

    # 增长一个key
    async def incr(self, key):
        # 发送命令
        self.writer.write(f"incr \"{key}\"\r\n".encode())
        await self.writer.drain()
        return await self.get_reply()
    # 接收响应
    async def get_reply(self):
        # 识别返回响应的第一个字符
        tag = (await self.reader.read(1))
        # 如果是简单字符串响应
        if tag == b'+':
            result = b''
            next_char = b''
            while next_char != b'\n':
                next_char = (await self.reader.read(1))
                result += next_char
            return result[:-1].decode()
        # 如果是错误
        elif tag == b'-':
            result = b''
            next_char = b''
            while next_char != b'\n':
                next_char = (await self.reader.read(1))
                result += next_char
            raise  Exception(result[:-1].decode())
        # 如果是bulk字符串
        elif tag == b'$':
            # 先获取总长度
            result = b''
            next_char = b''
            while next_char != b'\n':
                next_char = (await self.reader.read(1))
                result += next_char
            # 加上最后\r\n的2个字符长度
            total_length = int(result[:-1].decode()) + 2
            # 如果是-1的话,其实是代表字符串为一个NULL value
            if total_length == 1:
                await self.reader.read(2)
                return ''
            result = b''
            while len(result) < total_length:
                result += (await self.reader.read(total_length))
            return result[:-2].decode()
        # 说明返回的是一个数字
        elif tag == b':':
            result = b''
            next_char = b''
            while next_char != b'\n':
                next_char = (await self.reader.read(1))
                result += next_char
            return int(result[:-1].decode())
        # 说明返回的是一个数组
        elif tag == b'*':
            # 先获取数组总长度
            result = b''
            next_char = b''
            while next_char != b'\n':
                next_char = (await self.reader.read(1))
                result += next_char
            # 加上最后\r\n的2个字符长度
            arr_length = int(result[:-1].decode())
            if arr_length == 0:
                return []
            else:
                res = []
                for _ in range(arr_length):
                    res.append(await self.get_reply())
                return res
        else:
            msg = (await self.reader.read(100))
            # msg[:-2]代表去掉最后的\r\n
            raise Exception(f"unsupported tag: {tag}, msg: {msg[:-2].decode()}")
    # 以数组形式发送
    async def send(self,*args):
        cmd_args = "".join([f"${len(x)}\r\n{x}\r\n" for x in args])
        self.writer.write(f"*{len(args)}\r\n{cmd_args}".encode())
        await self.writer.drain()
        return await self.get_reply()

async def main():
    print("=====redis client for py3 started!=====")
    client = RedisClient()
    # 连接redis
    await client.connect("localhost",6379)
    print(await client.set("banana","hello world"))

    print(await client.get("banana"))

    # 删除,不然会报错
    print(await client.send_cmd("del banana"))

    print(await client.incr("banana"))
    print(await client.incr("banana"))
    print(await client.incr("banana"))

    print(await client.send_cmd("del banana"))

    # 会报错:ERR Protocol error: unbalanced quotes in request
    #print(await client.set("banana", "hello \n world"))
    print(await client.send("set""banana""hello \n world"))

    print(await client.send("hset","myhash","k1""v1""k2""v2"))

    print(await client.send("hget","myhash","k1"))

    print(await client.send("hgetall","myhash"))


if __name__ == '__main__':
    asyncio.run(main())

python代码水平有限,如果哪里有所不足,欢迎指正,就先写到这!


文章转载自游在鱼里的水,如果涉嫌侵权,请发送邮件至:contact@modb.pro进行举报,并提供相关证据,一经查实,墨天轮将立刻删除相关内容。

评论