暂无图片
暂无图片
暂无图片
暂无图片
暂无图片

大模型推理性能优化之KV Cache解读

motianlun20240602 2024-06-11
1275

0. 引言

做大模型性能优化的一定对KV Cache不陌生,那么我们对这个技术了解到什么程度呢?请尝试回答如下问题:

  1. KV Cache节省了Self-Attention层中哪部分的计算?
  2. KV Cache对MLP层的计算量有影响吗?
  3. KV Cache对block间的数据传输量有影响吗?

本文打算剖析该技术并给出上面问题的答案。

1. KV Cache是啥?

大模型推理性能优化的一个常用技术是KV Cache,该技术可以在不影响任何计算精度的前提下,通过空间换时间思想,提高推理性能。网上有一些关于该技术的分析博客,但读过后仍然会很迷糊,甚至可能会被带偏,认为这个Cache过程和数据库读取或CPU Cache加速类似的荒谬结论。刚开始我也有类似误解,直到逐行查阅并运行源码,才清楚了解到其Cache了啥,以及如何节省计算的。

2. 背景

生成式generative模型的推理过程很有特点,我们给一个输入文本,模型会输出一个回答(长度为N),其实该过程中执行了N次推理过程。即GPT类模型一次推理只输出一个token,输出token会与输入tokens 拼接在一起,然后作为下一次推理的输入,这样不断反复直到遇到终止符。

如上描述是我们通常认知的GPT推理过程。

3. 原理

在上面的推理过程中,每 step 内,输入一个 token序列,经过Embedding层将输入token序列变为一个三维张量[b, s, h],经过一通计算,最后经logits层将计算结果映射至词表空间,输出张量维度为[b, s, vocab_size]。

当前轮输出token与输入tokens拼接,并作为下一轮的输入tokens,反复多次。可以看出第𝑖+1 轮输入数据只比第𝑖轮输入数据新增了一个token,其他全部相同!因此第𝑖+1轮推理时必然包含了第 𝑖 轮的部分计算。KV Cache的出发点就在这里,缓存当前轮可重复利用的计算结果,下一轮计算时直接读取缓存结果,就是这么简单,不存在什么Cache miss问题。

4. 实现细节

目前各大模型推理都实现了KV Cache,下面就看如何使用了。我们可以在上面代码基础上修改,主要改动:

  • 在推理时新增了 past_key_values 参数,该参数就会以追加方式保存每一轮的K V值。kvcache变量内容为((k,v), (k,v), ..., (k,v)),即有 𝑛𝑙𝑎𝑦𝑒𝑟𝑠 个 k,v 组成的一个元组,其中 k 和 v 的维度均为 [b, n_head, s, head_dims]。这里可以顺带计算出每轮推理对应的 cache 数据量为 2𝑏𝑠𝑛𝑙𝑎𝑦𝑒𝑟𝑠 ,这里 𝑠 值等于当前轮次值。以GPT3-175B为例,假设以 float16 来保存 KV cache,senquence长度为100,batchsize=1,则 KV cache占用显存为 2×100×12288×96×2 Byte= 472MB。
  • 推理输出的token直接作为下一轮的输入,不再拼接,因为上文信息已经在 kvcache 中。
  • 其实,KV Cache 配置开启后,推理过程可以分为2个阶段:

    1. 预填充阶段:发生在计算第一个输出token过程中,这时Cache是空的,计算时需要为每个 transformer layer 计算并保存key cache和value cache,在输出token时Cache完成填充;FLOPs同KV Cache关闭一致,存在大量gemm操作,推理速度慢。
    2. 使用KV Cache阶段:发生在计算第二个输出token至最后一个token过程中,这时Cache是有值的,每轮推理只需读取Cache,同时将当前轮计算出的新的Key、Value追加写入至Cache;FLOPs降低,gemm变为gemv操作,推理速度相对第一阶段变快,这时属于Memory-bound类型计算。

    这里用图可能更有助理解,下图是一个Decoder Block(仅以MHA为例),含有Self-Attention和MLP,标红部分为KV Cache影响到的内容,即KV Cache开启后,标红的序列长度 𝑠 变为 1,当batch_size=1时,Self-Attention中的2个dense全都变为gemv操作,MLP中的dense也全都变为gemv操作。看懂这个图就可以答对上面的3个问题啦。图中数据维度相关字母的含义:

5. 总结

KV Cache是Transformer推理性能优化的一项重要工程化技术,各大推理框架都已实现并将其进行了封装(例如 transformers库 generate 函数已经将其封装,用户不需要手动传入past_key_values)并默认开启(config.json文件中use_cache=True)。

「喜欢这篇文章,您的关注和赞赏是给作者最好的鼓励」
关注作者
【版权声明】本文为墨天轮用户原创内容,转载时必须标注文章的来源(墨天轮),文章链接,文章作者等基本信息,否则作者和墨天轮有权追究责任。如果您发现墨天轮中有涉嫌抄袭或者侵权的内容,欢迎发送邮件至:contact@modb.pro进行举报,并提供相关证据,一经查实,墨天轮将立刻删除相关内容。

评论