0%

glm3源码学习

原文链接:chatglm源码学习

GLM3源码:https://github.com/THUDM/ChatGLM3

我们直接从openai_api_demo入手,因为api_demo一般是nlp模型后端核心功能实现的部分

openai_api_demo源码

api_server.py

api_server.py是提供web api接口的入口文件,是使用flask框架提供的一个异步接口支持

1
app = FastAPI(lifespan=lifespan)
1
2
3
4
class ModelCard(BaseModel):
...
class ChatCompletionResponse(BaseModel):

上面这一堆class是实现chat这个api功能的主要对象,如模型卡、请求体和响应体

1
2
3
4
@app.get("/health")
async def health() -> Response:
"""Health check."""
return Response(status_code=200)

这个是测试api状态函数,可以看到这个测试功能还是很直接的,没有考虑部署应用下的问题,如负载情况和安全状况,这个demo也就是一个学习的小demo项目。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
@app.post("/v1/embeddings", response_model=EmbeddingResponse)
async def get_embeddings(request: EmbeddingRequest):
if isinstance(request.input, str):# 判断输入是否是字符串,字符串直接编码,否则对字符串列表编码
embeddings = [embedding_model.encode(request.input)]
else:
embeddings = [embedding_model.encode(text) for text in request.input]
embeddings = [embedding.tolist() for embedding in embeddings]

def num_tokens_from_string(string: str) -> int:
"""
Returns the number of tokens in a text string.
use cl100k_base tokenizer
"""
encoding = tiktoken.get_encoding('cl100k_base')
num_tokens = len(encoding.encode(string))
return num_tokens

response = {
"data": [
{
"object": "embedding",
"embedding": embedding,
"index": index
}
for index, embedding in enumerate(embeddings)
],
"model": request.model,
"object": "list",
"usage": CompletionUsage(
prompt_tokens=sum(len(text.split()) for text in request.input),
completion_tokens=0,
total_tokens=sum(num_tokens_from_string(text) for text in request.input),
)
}
return response

这个函数是获取文本向量编码的,sentences_to_embeddings功能。
这里面有个函数num_tokens_from_string是统计文本的tokens数量,使用的tiktoken模块是openai开源的一个快速分词统计库,cl100k_base是和gpt4同款编码器,也就是说glm3的tokenizer实际上是使用的gpt4的tokenizer,在论文里面glm的baseline是最开始的gpt-1模型,那从理论上,glm3的性能提升肯定会受到分词的影响的(清华博士教大家的水论文小技巧hhh)。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
class ModelCard(BaseModel):
id: str
object: str = "model"
created: int = Field(default_factory=lambda: int(time.time()))
owned_by: str = "owner"
root: Optional[str] = None
parent: Optional[str] = None
permission: Optional[list] = None

@app.get("/v1/models", response_model=ModelList)
async def list_models():
model_card = ModelCard(
id="chatglm3-6b"
)
return ModelList(
data=[model_card]
)

这个list_models直接限定了就是chatglm3-6b模型了,里面没有包括实际的模型

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
@app.post("/v1/chat/completions", response_model=ChatCompletionResponse)
async def create_chat_completion(request: ChatCompletionRequest):
global model, tokenizer

if len(request.messages) < 1 or request.messages[-1].role == "assistant":
raise HTTPException(status_code=400, detail="Invalid request")#截这个的原因是gpt模型是允许任意角色的消息序列的包括assistant多次生成的功能。glm3则不允许
if request.stream:
#SSE流式响应
response = generate_chatglm3(model, tokenizer, gen_params) #直接响应
message = ChatMessage(role="assistant",content=response["text"],...)#创建消息体

#计算使用量然后返回响应体,choice_data里面只放了一个数据
return ChatCompletionResponse(
model=request.model,
id="", # for open_source model, id is empty
choices=[choice_data],
object="chat.completion",
usage=usage
)

chat最核心的响应函数了,由于函数较长就不全截了。

首先我们看到的是一个消息验证不允许assistant多次生成,原因主要是这个功能本身对助手是没有什么意义的,而且多次生成的训练效果比较差,之前我测试过gpt api的多次生成。因为他们用的训练数据基本上都是一个消息内全部回复了,上下文数据本身不存在多次生成的场景,因此这些模型多次生成并不是把问题分多次回复(和人类不同,一句话可以多方面讲,分段讲),只是把答案回答多次。

如果要实现更真实的问答AI,拥有更真实的对话体验,那对数据的要求是很高的,最好的数据集应该是QQ微信这种聊天软件的数据,但是企业是不可能拿这些隐私数据训练的。不过也有平替,如贴吧微博这些开放平台的数据也是很好的,但是这些数据看过后,上下文的逻辑性还是有问题的,并且多轮对话的人物被屏蔽了,也就是说明明是多个人的对话被训练成了二人的对话,这些模型后面肯定被高质量多轮对话微调过,不然单纯这些语料不会达到gpt的这种效果。

响应类型分为直接响应和SSE响应,其中直接响应简单,就是拿model直接推理得到message。
这里有个问题是这个chat函数是asyn异步的,但是model资源是global的单个模型,如果同时多个请求可能会报错。可以对模型封装个请求拥塞队列,比如大于3个请求就返回繁忙。

SSE响应部分和直接响应不同,SSE没有提供使用量这些信息,仅返回了响应文本,SSE还对前端的响应方法有要求,因此如果是仅学习开发和小规模应用没有必要追求SSE

1
2
3
4
predict_stream_generator = predict_stream(request.model, gen_params)
output = next(predict_stream_generator)
if not contains_custom_function(output):
return EventSourceResponse(predict_stream_generator, media_type="text/event-stream")

通过predict_stream创建一个生成器,next生成下一个字符然后返回。

utils.py

utils.py提供了响应的实现函数generate_stream_chatglm3和generate_chatglm3

1
2
3
4
def generate_chatglm3(model: PreTrainedModel, tokenizer: PreTrainedTokenizer, params: dict):
for response in generate_stream_chatglm3(model, tokenizer, params):
pass
return response

循环调用generate_stream_chatglm3后返回响应

generate_stream_chatglm3函数

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
def generate_stream_chatglm3(model: PreTrainedModel, tokenizer: PreTrainedTokenizer, params: dict):
messages = params["messages"] #消息
tools = params["tools"] #工具
temperature = float(params.get("temperature", 1.0)) #温度参数
repetition_penalty = float(params.get("repetition_penalty", 1.0))#惩罚参数,transformer有个问题就是高概率文本会重复生成,在有的论文中提出了惩罚参数,即对已经生成的token的概率乘上惩罚参数让这个token的概率变小,减小重复概率。
top_p = float(params.get("top_p", 1.0)) #top_p top_k是采样的一个过滤方法,p是按概率阈值过滤,k是按排序过滤
max_new_tokens = int(params.get("max_tokens", 256)) #最大允许新生成的tokens
echo = params.get("echo", True)
messages = process_chatglm_messages(messages, tools=tools)#消息处理
query, role = messages[-1]["content"], messages[-1]["role"]#最后一个消息内容

inputs = tokenizer.build_chat_input(query, history=messages[:-1], role=role) #把历史和问题构建输入
inputs = inputs.to(model.device)
input_echo_len = len(inputs["input_ids"][0])#输入编码序列长度

if input_echo_len >= model.config.seq_length: #输入序列长度限制
print(f"Input length larger than {model.config.seq_length}")

eos_token_id = [ #结束token
tokenizer.eos_token_id,
tokenizer.get_command("<|user|>"),
tokenizer.get_command("<|observation|>")
]

gen_kwargs = { #控制参数
"max_new_tokens": max_new_tokens,
"do_sample": True if temperature > 1e-5 else False,
"top_p": top_p,
"repetition_penalty": repetition_penalty,
"logits_processor": [InvalidScoreLogitsProcessor()],
}
if temperature > 1e-5:
gen_kwargs["temperature"] = temperature

total_len = 0
for total_ids in model.stream_generate(**inputs, eos_token_id=eos_token_id, **gen_kwargs):
total_ids = total_ids.tolist()[0]
total_len = len(total_ids)
if echo: #没看懂echo什么意思 input_echo_len应该是生成的total_ids中echo控制是否对问题重复一遍,重复了就减掉
output_ids = total_ids[:-1]
else:
output_ids = total_ids[input_echo_len:-1]
#反正output_ids是stream_generate的ids

response = tokenizer.decode(output_ids)
if response and response[-1] != "�": #乱码了就跳出
response, stop_found = apply_stopping_strings(response, ["<|observation|>"]) #判断是否结束

yield { #yield作为一个生成器每次调用生成output_ids然后返回
"text": response,
"usage": {
"prompt_tokens": input_echo_len,#输入tokens
"completion_tokens": total_len - input_echo_len,#总tokens-重复的输入tokens
"total_tokens": total_len,#总tokens
},
"finish_reason": "function_call" if stop_found else None,
}

if stop_found:
break
#最后一个字符跳出返回结束
# Only last stream result contains finish_reason, we set finish_reason as stop
ret = {
"text": response,
"usage": {
"prompt_tokens": input_echo_len,
"completion_tokens": total_len - input_echo_len,
"total_tokens": total_len,
},
"finish_reason": "stop",
}
yield ret
#内存显存收下垃圾
gc.collect()
torch.cuda.empty_cache()

其中里面有个函数很关键process_chatglm_messages:消息处理函数

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
def process_chatglm_messages(messages, tools=None):
_messages = messages
messages = []
msg_has_sys = False
if tools:
messages.append(
{
"role": "system",
"content": "Answer the following questions as best as you can. You have access to the following tools:",
"tools": tools
}
)
msg_has_sys = True

for m in _messages:
role, content, func_call = m.role, m.content, m.function_call
if role == "function":
messages.append(
{
"role": "observation",
"content": content
}
)

elif role == "assistant" and func_call is not None:
for response in content.split("<|assistant|>"):
metadata, sub_content = response.split("\n", maxsplit=1)
messages.append(
{
"role": role,
"metadata": metadata,
"content": sub_content.strip()
}
)
else:
if role == "system" and msg_has_sys:
msg_has_sys = False
continue
messages.append({"role": role, "content": content})
return messages

这个函数就是把message对象转化为dict对象,我们可以看到这里面有system、observation、assistant、user。

在generate_stream_chatglm3:inputs = tokenizer.build_chat_input(query, history=messages[:-1], role=role)
这个函数中把dict对象对应的history转换成了文本格式,例如:

1
2
3
4
5
6
<|system|>
You are ChatGLM3, a large language model trained by Zhipu.AI. Follow the user's instructions carefully. Respond using markdown.
<|user|>
Hello
<|assistant|>
Hello, I'm ChatGLM3. What can I assist you today?

也就是说我们以为的多轮对话实际上就是把历史记录拼起来的。

这部分想到了个idea,这种拼起来的实际上有历史限制,如果让模型生成每个对话的重要性,然后按照重要性+时间权重排序选择性记忆能不能增强长期记忆能力?感觉这部分应该有人在做或者做出来了。

main

最后回来看下api_server.py的main函数

1
2
3
4
5
6
7
8
if __name__ == "__main__":
# Load LLM
tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_PATH, trust_remote_code=True)
model = AutoModel.from_pretrained(MODEL_PATH, trust_remote_code=True, device_map="auto").eval()

# load Embedding
embedding_model = SentenceTransformer(EMBEDDING_PATH, device="cuda")
uvicorn.run(app, host='0.0.0.0', port=8000, workers=1)

main函数从transformers加载模型然后作为global对象推理。
transformers模型就和传统的bert、t5类似了

chatglm的改进主要包括:

  • 二维位置编码+GELU+残差、层归一化重排序
  • 文档级+句子 NLG预训练
  • NLG+NLU两种任务都进行训练,同时微调的时候还使用了slot填空的NLU方法

总结

之前没怎么看过这种有上下文模型响应的完整流程,这趟下来解决了我之前好几个疑惑:

  1. transformer的重复问题我遇到了好几次,可以通过惩罚参数控制
  2. 上下文实现方法-实际上还是把历史对话融在一起
  3. 模型推理资源占用问题,请求队列感觉是一定要有的,web框架本身是异步请求响应的,不对临界资源管理感觉没啥可靠性

加上这个,目前已经把带上下文的文本生成+知识库扩展永久记忆解决了,后面再对模型结构魔改下,然后集成一些动作指令,就可以实现本地部署家用AI女仆了hhh。