在本节中,我们将专注于构建一个与转换器模型通信的包装器,以会话格式从用户向 API 发送提示,以及接收和转换我们的聊天应用程序的响应。

如何开始使用 Huggingface

我们不会在 Hugginface 上构建或部署任何语言模型。相反,我们将专注于使用 Huggingface 的加速推理 API 连接到预训练模型。

我们将使用的模型是 EleutherAI](https://huggingface.co/EleutherAI/gpt-j-6B)提供的[GPT-J-6B 模型。这是一个生成语言模型,经过 60 亿个参数训练。

Huggingface 为我们提供了一个按需受限的 API,几乎可以免费连接到这个模型。

要开始使用 Huggingface,请创建一个免费帐户。在您的设置中,生成一个新的访问令牌。对于多达 30k 个令牌,Huggingface 提供对推理 API 的免费访问。

你可以在这里监控你的 API 使用情况。确保您妥善保管此令牌,并且不要将其公开。

注意:我们将使用 HTTP 连接与 API 进行通信,因为我们使用的是免费帐户。但是 PRO Huggingface 帐户支持使用 WebSockets进行流式传输,请参阅并行和批处理作业。

这有助于显着改善模型和我们的聊天应用程序之间的响应时间,我希望在后续文章中介绍这种方法。

如何与语言模型交互

首先,我们将 Huggingface 连接凭据添加到工作目录中的 .env 文件中。

export HUGGINFACE_INFERENCE_TOKEN=<HUGGINGFACE ACCESS TOKEN>
export MODEL_URL=https://api-inference.huggingface.co/models/EleutherAI/gpt-j-6B
worker.srcmodelgptj.py
import os
from dotenv import load_dotenv
import requests
import json

load_dotenv()

class GPT:
    def __init__(self):
        self.url = os.environ.get('MODEL_URL')
        self.headers = {
            "Authorization": f"Bearer {os.environ.get('HUGGINFACE_INFERENCE_TOKEN')}"}
        self.payload = {
            "inputs": "",
            "parameters": {
                "return_full_text": False,
                "use_cache": True,
                "max_new_tokens": 25
            }

        }

    def query(self, input: str) -> list:
        self.payload["inputs"] = input
        data = json.dumps(self.payload)
        response = requests.request(
            "POST", self.url, headers=self.headers, data=data)
        print(json.loads(response.content.decode("utf-8")))
        return json.loads(response.content.decode("utf-8"))

if __name__ == "__main__":
    GPT().query("Will artificial intelligence help humanity conquer the universe?")
GPTurlheaderpayloadquery
python src/model/gptj.py
[{'generated_text': ' (AI) could solve all the problems on this planet? I am of the opinion that in the short term artificial intelligence is much better than human beings, but in the long and distant future human beings will surpass artificial intelligence.\n\nIn the distant'}]

接下来,我们对输入进行一些调整,通过更改输入的格式使与模型的交互更具会话性。

GPT

class GPT:
    def __init__(self):
        self.url = os.environ.get('MODEL_URL')
        self.headers = {
            "Authorization": f"Bearer {os.environ.get('HUGGINFACE_INFERENCE_TOKEN')}"}
        self.payload = {
            "inputs": "",
            "parameters": {
                "return_full_text": False,
                "use_cache": False,
                "max_new_tokens": 25
            }

        }

    def query(self, input: str) -> list:
        self.payload["inputs"] = f"Human: {input} Bot:"
        data = json.dumps(self.payload)
        response = requests.request(
            "POST", self.url, headers=self.headers, data=data)
        data = json.loads(response.content.decode("utf-8"))
        text = data[0]['generated_text']
        res = str(text.split("Human:")[0]).strip("\n").strip()
        return res


if __name__ == "__main__":
    GPT().query("Will artificial intelligence help humanity conquer the universe?")
f"Human: {input} Bot:"
  • use_cache: 如果你希望模型在输入相同的情况下创建一个新的响应,你可以把这个设为 False。我建议在生产中将此保留为 True 以防止在用户不断向机器人发送相同消息的垃圾邮件时耗尽您的免费令牌。使用缓存实际上不会从模型加载新的响应。

  • return_full_text: 是 False,因为我们不需要返回输入——我们已经有了它。当我们收到响应时,我们会从响应中去除“Bot:”和前导/尾随空格,并仅返回响应文本。

如何模拟 AI 模型的短期记忆

对于我们发送给模型的每个新输入,模型都无法记住对话历史记录。如果我们想在对话中保持上下文,这一点很重要。

但请记住,随着我们发送给模型的令牌数量的增加,处理变得更加昂贵,响应时间也更长。

所以我们需要找到一种方法来检索短期历史并将其发送到模型。我们还需要找出一个最佳点——我们要检索多少历史数据并将其发送到模型?

token
worker.src.redis.config.pycreate_rejson_connection
worker.src.redis.config.py

import os
from dotenv import load_dotenv
import aioredis
from rejson import Client


load_dotenv()


class Redis():
    def __init__(self):
        """initialize  connection """
        self.REDIS_URL = os.environ['REDIS_URL']
        self.REDIS_PASSWORD = os.environ['REDIS_PASSWORD']
        self.REDIS_USER = os.environ['REDIS_USER']
        self.connection_url = f"redis://{self.REDIS_USER}:{self.REDIS_PASSWORD}@{self.REDIS_URL}"
        self.REDIS_HOST = os.environ['REDIS_HOST']
        self.REDIS_PORT = os.environ['REDIS_PORT']

    async def create_connection(self):
        self.connection = aioredis.from_url(
            self.connection_url, db=0)

        return self.connection

    def create_rejson_connection(self):
        self.redisJson = Client(host=self.REDIS_HOST,
                                port=self.REDIS_PORT, decode_responses=True, username=self.REDIS_USER, password=self.REDIS_PASSWORD)

        return self.redisJson

虽然您的 .env 文件应如下所示:

export REDIS_URL=<REDIS URL PROVIDED IN REDIS CLOUD>
export REDIS_USER=<REDIS USER IN REDIS CLOUD>
export REDIS_PASSWORD=<DATABASE PASSWORD IN REDIS CLOUD>
export REDIS_HOST=<REDIS HOST IN REDIS CLOUD>
export REDIS_PORT=<REDIS PORT IN REDIS CLOUD>
export HUGGINFACE_INFERENCE_TOKEN=<HUGGINGFACE ACCESS TOKEN>
export MODEL_URL=https://api-inference.huggingface.co/models/EleutherAI/gpt-j-6B
worker.src.rediscache.py
from .config import Redis
from rejson import Path

class Cache:
    def __init__(self, json_client):
        self.json_client = json_client

    async def get_chat_history(self, token: str):
        data = self.json_client.jsonget(
            str(token), Path.rootPath())

        return data
get_chat_history
worker.main.py
from src.redis.config import Redis
import asyncio
from src.model.gptj import GPT
from src.redis.cache import Cache

redis = Redis()

async def main():
    json_client = redis.create_rejson_connection()
    data = await Cache(json_client).get_chat_history(token="18196e23-763b-4808-ae84-064348a0daff")
    print(data)

if __name__ == "__main__":
    asyncio.run(main())
/tokenpython main.py
构建全栈 AI 聊天机器人第 4 部分
{'token': '18196e23-763b-4808-ae84-064348a0daff', 'messages': [], 'name': 'Stephen', 'session_start': '2022-07-16 13:20:01.092109'}
Cacheadd_message_to_cache

  async def add_message_to_cache(self, token: str, message_data: dict):
      self.json_client.jsonarrappend(
          str(token), Path('.messages'), message_data)
jsonarrappend
.messages

为了测试这个方法,更新main.py文件中的main函数,代码如下:

async def main():
    json_client = redis.create_rejson_connection()

    await Cache(json_client).add_message_to_cache(token="18196e23-763b-4808-ae84-064348a0daff", message_data={
        "id": "1",
        "msg": "Hello",
        "timestamp": "2022-07-16 13:20:01.092109"
    })

    data = await Cache(json_client).get_chat_history(token="18196e23-763b-4808-ae84-064348a0daff")
    print(data)
python main.py
{'token': '18196e23-763b-4808-ae84-064348a0daff', 'messages': [{'id': '1', 'msg': 'Hello', 'timestamp': '2022-07-16 13:20:01.092109'}], 'name': 'Stephen', 'session_start': '2022-07-16 13:20:01.092109'}

最后,我们需要更新主函数以将消息数据发送到 GPT 模型,并使用客户端和模型之间发送的last 4 消息更新输入。

add_message_to_cache
add_message_to_cache
  async def add_message_to_cache(self, token: str, source: str, message_data: dict):
      if source == "human":
          message_data['msg'] = "Human: " + (message_data['msg'])
      elif source == "bot":
          message_data['msg'] = "Bot: " + (message_data['msg'])

      self.json_client.jsonarrappend(
          str(token), Path('.messages'), message_data)
python main.py
async def main():
    json_client = redis.create_rejson_connection()

    await Cache(json_client).add_message_to_cache(token="18196e23-763b-4808-ae84-064348a0daff", source="human", message_data={
        "id": "1",
        "msg": "Hello",
        "timestamp": "2022-07-16 13:20:01.092109"
    })

    data = await Cache(json_client).get_chat_history(token="18196e23-763b-4808-ae84-064348a0daff")
    print(data)

接下来,我们需要更新 main 函数,将新消息添加到缓存中,从缓存中读取前 4 条消息,然后使用查询方法对模型进行 API 调用。它将有一个由最后 4 条消息的复合字符串组成的有效负载。

您可以随时调整要提取的历史记录中的消息数量,但我认为 4 条消息对于演示来说是一个相当不错的数字。

worker.srcchat.py
from datetime import datetime
from pydantic import BaseModel
from typing import List, Optional
import uuid


class Message(BaseModel):
    id = str(uuid.uuid4())
    msg: str
    timestamp = str(datetime.now())

接下来,更新main.py文件,如下所示:

async def main():

    json_client = redis.create_rejson_connection()

    await Cache(json_client).add_message_to_cache(token="18196e23-763b-4808-ae84-064348a0daff", source="human", message_data={
        "id": "3",
        "msg": "I would like to go to the moon to, would you take me?",
        "timestamp": "2022-07-16 13:20:01.092109"
    })

    data = await Cache(json_client).get_chat_history(token="18196e23-763b-4808-ae84-064348a0daff")

    print(data)

    message_data = data['messages'][-4:]

    input = ["" + i['msg'] for i in message_data]
    input = " ".join(input)

    res = GPT().query(input=input)

    msg = Message(
        msg=res
    )

    print(msg)
    await Cache(json_client).add_message_to_cache(token="18196e23-763b-4808-ae84-064348a0daff", source="bot", message_data=msg.dict())

在上面的代码中,我们将新的消息数据添加到缓存中。该消息最终将来自消息队列。接下来,我们从缓存中获取聊天历史记录,其中现在将包含我们添加的最新数据。

请注意,我们使用相同的硬编码令牌添加到缓存并从缓存中获取,暂时只是为了测试一下。

接下来,我们修剪缓存数据并仅提取最后 4 项。然后我们通过在列表中提取 msg 来合并输入数据并将其连接到一个空字符串。

最后,我们为机器人响应创建一个新的 Message 实例,并将响应添加到缓存中,指定源为“bot”

python main.py

打开 Redis Insight,你应该有类似下面的内容:

Stream Consumer 和从消息队列中拉取实时数据

worker.main.py
worker.src.redisstream.pyStreamConsumer
class StreamConsumer:
    def __init__(self, redis_client):
        self.redis_client = redis_client

    async def consume_stream(self, count: int, block: int,  stream_channel):

        response = await self.redis_client.xread(
            streams={stream_channel:  '0-0'}, count=count, block=block)

        return response

    async def delete_message(self, stream_channel, message_id):
        await self.redis_client.xdel(stream_channel, message_id)
StreamConsumerconsume_streamxread
worker.main.py

from src.redis.config import Redis
import asyncio
from src.model.gptj import GPT
from src.redis.cache import Cache
from src.redis.config import Redis
from src.redis.stream import StreamConsumer
import os
from src.schema.chat import Message


redis = Redis()


async def main():
    json_client = redis.create_rejson_connection()
    redis_client = await redis.create_connection()
    consumer = StreamConsumer(redis_client)
    cache = Cache(json_client)

    print("Stream consumer started")
    print("Stream waiting for new messages")

    while True:
        response = await consumer.consume_stream(stream_channel="message_channel", count=1, block=0)

        if response:
            for stream, messages in response:
                # Get message from stream, and extract token, message data and message id
                for message in messages:
                    message_id = message[0]
                    token = [k.decode('utf-8')
                             for k, v in message[1].items()][0]
                    message = [v.decode('utf-8')
                               for k, v in message[1].items()][0]
                    print(token)

                    # Create a new message instance and add to cache, specifying the source as human
                    msg = Message(msg=message)

                    await cache.add_message_to_cache(token=token, source="human", message_data=msg.dict())

                    # Get chat history from cache
                    data = await cache.get_chat_history(token=token)

                    # Clean message input and send to query
                    message_data = data['messages'][-4:]

                    input = ["" + i['msg'] for i in message_data]
                    input = " ".join(input)

                    res = GPT().query(input=input)

                    msg = Message(
                        msg=res
                    )

                    print(msg)

                    await cache.add_message_to_cache(token=token, source="bot", message_data=msg.dict())

                # Delete messaage from queue after it has been processed

                await consumer.delete_message(stream_channel="message_channel", message_id=message_id)


if __name__ == "__main__":
    asyncio.run(main())

这是相当大的更新,所以让我们一步一步来:

while True
consume_streamquery
add_message_to_cache

如何使用 AI 响应更新聊天客户端

到目前为止,我们正在从客户端向 message_channel 发送一条聊天消息(由查询 AI 模型的工作人员接收)以获取响应。

接下来,我们需要将此响应发送给客户端。只要套接字连接仍然打开,客户端应该能够接收到响应。

refresh_token
worker.src.redisproducer.pyProducer

class Producer:
    def __init__(self, redis_client):
        self.redis_client = redis_client

    async def add_to_stream(self,  data: dict, stream_channel) -> bool:
        msg_id = await self.redis_client.xadd(name=stream_channel, id="*", fields=data)
        print(f"Message id {msg_id} added to {stream_channel} stream")
        return msg_id
main.pyadd_to_streamresponse_channel
from src.redis.config import Redis
import asyncio
from src.model.gptj import GPT
from src.redis.cache import Cache
from src.redis.config import Redis
from src.redis.stream import StreamConsumer
import os
from src.schema.chat import Message
from src.redis.producer import Producer


redis = Redis()


async def main():
    json_client = redis.create_rejson_connection()
    redis_client = await redis.create_connection()
    consumer = StreamConsumer(redis_client)
    cache = Cache(json_client)
    producer = Producer(redis_client)

    print("Stream consumer started")
    print("Stream waiting for new messages")

    while True:
        response = await consumer.consume_stream(stream_channel="message_channel", count=1, block=0)

        if response:
            for stream, messages in response:
                # Get message from stream, and extract token, message data and message id
                for message in messages:
                    message_id = message[0]
                    token = [k.decode('utf-8')
                             for k, v in message[1].items()][0]
                    message = [v.decode('utf-8')
                               for k, v in message[1].items()][0]

                    # Create a new message instance and add to cache, specifying the source as human
                    msg = Message(msg=message)

                    await cache.add_message_to_cache(token=token, source="human", message_data=msg.dict())

                    # Get chat history from cache
                    data = await cache.get_chat_history(token=token)

                    # Clean message input and send to query
                    message_data = data['messages'][-4:]

                    input = ["" + i['msg'] for i in message_data]
                    input = " ".join(input)

                    res = GPT().query(input=input)

                    msg = Message(
                        msg=res
                    )

                    stream_data = {}
                    stream_data[str(token)] = str(msg.dict())

                    await producer.add_to_stream(stream_data, "response_channel")

                    await cache.add_message_to_cache(token=token, source="bot", message_data=msg.dict())

                # Delete messaage from queue after it has been processed
                await consumer.delete_message(stream_channel="message_channel", message_id=message_id)


if __name__ == "__main__":
    asyncio.run(main())
/chat

请注意,我们还需要通过添加逻辑来检查响应是针对哪个客户端的,以检查连接的令牌是否等于响应中的令牌。然后我们删除响应队列中的消息,一旦它被读取。

server.src.redisStreamConsumer
from .config import Redis

class StreamConsumer:
    def __init__(self, redis_client):
        self.redis_client = redis_client

    async def consume_stream(self, count: int, block: int,  stream_channel):
        response = await self.redis_client.xread(
            streams={stream_channel:  '0-0'}, count=count, block=block)

        return response

    async def delete_message(self, stream_channel, message_id):
        await self.redis_client.xdel(stream_channel, message_id)
/chat
from ..redis.stream import StreamConsumer

@chat.websocket("/chat")
async def websocket_endpoint(websocket: WebSocket, token: str = Depends(get_token)):
    await manager.connect(websocket)
    redis_client = await redis.create_connection()
    producer = Producer(redis_client)
    json_client = redis.create_rejson_connection()
    consumer = StreamConsumer(redis_client)

    try:
        while True:
            data = await websocket.receive_text()
            stream_data = {}
            stream_data[str(token)] = str(data)
            await producer.add_to_stream(stream_data, "message_channel")
            response = await consumer.consume_stream(stream_channel="response_channel", block=0)

            print(response)
            for stream, messages in response:
                for message in messages:
                    response_token = [k.decode('utf-8')
                                      for k, v in message[1].items()][0]

                    if token == response_token:
                        response_message = [v.decode('utf-8')
                                            for k, v in message[1].items()][0]

                        print(message[0].decode('utf-8'))
                        print(token)
                        print(response_token)

                        await manager.send_personal_message(response_message, websocket)

                    await consumer.delete_message(stream_channel="response_channel", message_id=message[0].decode('utf-8'))

    except WebSocketDisconnect:
        manager.disconnect(websocket)

刷新令牌

/refresh_tokenCache
server.src.rediscache.py

from rejson import Path

class Cache:
    def __init__(self, json_client):
        self.json_client = json_client

    async def get_chat_history(self, token: str):
        data = self.json_client.jsonget(
            str(token), Path.rootPath())

        return data
server.src.routes.chat.pyCache/token

from ..redis.cache import Cache

@chat.get("/refresh_token")
async def refresh_token(request: Request, token: str):
    json_client = redis.create_rejson_connection()
    cache = Cache(json_client)
    data = await cache.get_chat_history(token)

    if data == None:
        raise HTTPException(
            status_code=400, detail="Session expired or does not exist")
    else:
        return data
/refresh_token

如果令牌没有超时,数据将被发送给用户。或者如果找不到令牌,它将发送 400 响应。

在 Postman 中测试与多个客户端的聊天

最后,我们将通过在 Postman 中创建多个聊天会话、在 Postman 中连接多个客户端以及在客户端上与机器人聊天来测试聊天系统。最后,我们将尝试获取客户的聊天记录,并希望得到适当的回应。

回顾

让我们快速回顾一下我们通过聊天系统取得的成就。聊天客户端为与客户端的每个聊天会话创建一个令牌。此令牌用于标识每个客户端,连接到或 Web 服务器的客户端发送的每条消息都在 Redis 通道 (message_chanel) 中排队,由该令牌标识。

我们的工作环境从这个频道读取。它不知道客户端是谁(除了它是一个唯一令牌)并使用队列中的消息将请求发送到 Huggingface 推理 API。

当它得到响应时,响应被添加到响应通道并更新聊天记录。监听 response_channel 的客户端一旦收到带有其令牌的响应,就会立即将响应发送给客户端。

如果套接字仍然打开,则发送此响应。如果套接字关闭,我们可以确定响应会被保留,因为响应会添加到聊天历史记录中。即使发生页面刷新或连接丢失,客户端也可以获取历史记录。

恭喜你走到这一步!您已经能够构建一个有效的聊天系统。

在后续文章中,我将专注于为客户端构建聊天用户界面、创建单元和功能测试、微调我们的工作环境以加快 WebSockets 和异步请求的响应时间,并最终在 AWS 上部署聊天应用程序。

本文是使用 Python、React、Huggingface、Redis 等工具构建全栈智能聊天机器人系列的一部分。你可以在我的博客上关注完整系列:blog.stephensanwo.dev - AI ChatBot 系列

**您可以在My Github Repository上下载完整的存储库 **

我与 Redis 合作编写了本教程。需要帮助开始使用 Redis?尝试以下资源:

  • 免费试用 Redis Cloud

  • 观看此视频,了解 Redis Cloud 相对于其他 Redis 提供商的优势

  • Redis Developer Hub - 关于 Redis 的工具、指南和教程

  • RedisInsight 桌面 GUI