-
Notifications
You must be signed in to change notification settings - Fork 85
/
Copy pathllmcord.py
284 lines (205 loc) · 12.6 KB
/
llmcord.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
import asyncio
from base64 import b64encode
from dataclasses import dataclass, field
from datetime import datetime as dt
import logging
from typing import Literal, Optional
import discord
import httpx
from openai import AsyncOpenAI
import yaml
logging.basicConfig(
level=logging.INFO,
format="%(asctime)s %(levelname)s: %(message)s",
)
VISION_MODEL_TAGS = ("gpt-4o", "claude-3", "gemini", "pixtral", "llava", "vision", "vl")
PROVIDERS_SUPPORTING_USERNAMES = ("openai", "x-ai")
ALLOWED_FILE_TYPES = ("image", "text")
EMBED_COLOR_COMPLETE = discord.Color.dark_green()
EMBED_COLOR_INCOMPLETE = discord.Color.orange()
STREAMING_INDICATOR = " ⚪"
EDIT_DELAY_SECONDS = 1
MAX_MESSAGE_NODES = 100
def get_config(filename="config.yaml"):
with open(filename, "r") as file:
return yaml.safe_load(file)
cfg = get_config()
if client_id := cfg["client_id"]:
logging.info(f"\n\nBOT INVITE URL:\nhttps://discord.com/api/oauth2/authorize?client_id={client_id}&permissions=412317273088&scope=bot\n")
intents = discord.Intents.default()
intents.message_content = True
activity = discord.CustomActivity(name=(cfg["status_message"] or "github.com/jakobdylanc/llmcord")[:128])
discord_client = discord.Client(intents=intents, activity=activity)
httpx_client = httpx.AsyncClient()
msg_nodes = {}
last_task_time = 0
@dataclass
class MsgNode:
text: Optional[str] = None
images: list = field(default_factory=list)
role: Literal["user", "assistant"] = "assistant"
user_id: Optional[int] = None
has_bad_attachments: bool = False
fetch_parent_failed: bool = False
parent_msg: Optional[discord.Message] = None
lock: asyncio.Lock = field(default_factory=asyncio.Lock)
@discord_client.event
async def on_message(new_msg):
global msg_nodes, last_task_time
is_dm = new_msg.channel.type == discord.ChannelType.private
if (not is_dm and discord_client.user not in new_msg.mentions) or new_msg.author.bot:
return
role_ids = tuple(role.id for role in getattr(new_msg.author, "roles", ()))
channel_ids = tuple(id for id in (new_msg.channel.id, getattr(new_msg.channel, "parent_id", None), getattr(new_msg.channel, "category_id", None)) if id)
cfg = get_config()
allow_dms = cfg["allow_dms"]
permissions = cfg["permissions"]
(allowed_user_ids, blocked_user_ids), (allowed_role_ids, blocked_role_ids), (allowed_channel_ids, blocked_channel_ids) = (
(perm["allowed_ids"], perm["blocked_ids"]) for perm in (permissions["users"], permissions["roles"], permissions["channels"])
)
allow_all_users = not allowed_user_ids if is_dm else not allowed_user_ids and not allowed_role_ids
is_good_user = allow_all_users or new_msg.author.id in allowed_user_ids or any(id in allowed_role_ids for id in role_ids)
is_bad_user = not is_good_user or new_msg.author.id in blocked_user_ids or any(id in blocked_role_ids for id in role_ids)
allow_all_channels = not allowed_channel_ids
is_good_channel = allow_dms if is_dm else allow_all_channels or any(id in allowed_channel_ids for id in channel_ids)
is_bad_channel = not is_good_channel or any(id in blocked_channel_ids for id in channel_ids)
if is_bad_user or is_bad_channel:
return
provider, model = cfg["model"].split("/", 1)
base_url = cfg["providers"][provider]["base_url"]
api_key = cfg["providers"][provider].get("api_key", "sk-no-key-required")
openai_client = AsyncOpenAI(base_url=base_url, api_key=api_key)
accept_images = any(x in model.lower() for x in VISION_MODEL_TAGS)
accept_usernames = any(x in provider.lower() for x in PROVIDERS_SUPPORTING_USERNAMES)
max_text = cfg["max_text"]
max_images = cfg["max_images"] if accept_images else 0
max_messages = cfg["max_messages"]
use_plain_responses = cfg["use_plain_responses"]
max_message_length = 2000 if use_plain_responses else (4096 - len(STREAMING_INDICATOR))
# Build message chain and set user warnings
messages = []
user_warnings = set()
curr_msg = new_msg
while curr_msg != None and len(messages) < max_messages:
curr_node = msg_nodes.setdefault(curr_msg.id, MsgNode())
async with curr_node.lock:
if curr_node.text == None:
cleaned_content = curr_msg.content.removeprefix(discord_client.user.mention).lstrip()
good_attachments = {type: [att for att in curr_msg.attachments if att.content_type and type in att.content_type] for type in ALLOWED_FILE_TYPES}
curr_node.text = "\n".join(
([cleaned_content] if cleaned_content else [])
+ [embed.description for embed in curr_msg.embeds if embed.description]
+ [(await httpx_client.get(att.url)).text for att in good_attachments["text"]]
)
curr_node.images = [
dict(type="image_url", image_url=dict(url=f"data:{att.content_type};base64,{b64encode((await httpx_client.get(att.url)).content).decode('utf-8')}"))
for att in good_attachments["image"]
]
curr_node.role = "assistant" if curr_msg.author == discord_client.user else "user"
curr_node.user_id = curr_msg.author.id if curr_node.role == "user" else None
curr_node.has_bad_attachments = len(curr_msg.attachments) > sum(len(att_list) for att_list in good_attachments.values())
try:
if (
not curr_msg.reference
and discord_client.user.mention not in curr_msg.content
and (prev_msg_in_channel := ([m async for m in curr_msg.channel.history(before=curr_msg, limit=1)] or [None])[0])
and prev_msg_in_channel.type in (discord.MessageType.default, discord.MessageType.reply)
and prev_msg_in_channel.author == (discord_client.user if curr_msg.channel.type == discord.ChannelType.private else curr_msg.author)
):
curr_node.parent_msg = prev_msg_in_channel
else:
is_public_thread = curr_msg.channel.type == discord.ChannelType.public_thread
parent_is_thread_start = is_public_thread and not curr_msg.reference and curr_msg.channel.parent.type == discord.ChannelType.text
if parent_msg_id := curr_msg.channel.id if parent_is_thread_start else getattr(curr_msg.reference, "message_id", None):
if parent_is_thread_start:
curr_node.parent_msg = curr_msg.channel.starter_message or await curr_msg.channel.parent.fetch_message(parent_msg_id)
else:
curr_node.parent_msg = curr_msg.reference.cached_message or await curr_msg.channel.fetch_message(parent_msg_id)
except (discord.NotFound, discord.HTTPException):
logging.exception("Error fetching next message in the chain")
curr_node.fetch_parent_failed = True
if curr_node.images[:max_images]:
content = ([dict(type="text", text=curr_node.text[:max_text])] if curr_node.text[:max_text] else []) + curr_node.images[:max_images]
else:
content = curr_node.text[:max_text]
if content != "":
message = dict(content=content, role=curr_node.role)
if accept_usernames and curr_node.user_id != None:
message["name"] = str(curr_node.user_id)
messages.append(message)
if len(curr_node.text) > max_text:
user_warnings.add(f"⚠️ Max {max_text:,} characters per message")
if len(curr_node.images) > max_images:
user_warnings.add(f"⚠️ Max {max_images} image{'' if max_images == 1 else 's'} per message" if max_images > 0 else "⚠️ Can't see images")
if curr_node.has_bad_attachments:
user_warnings.add("⚠️ Unsupported attachments")
if curr_node.fetch_parent_failed or (curr_node.parent_msg != None and len(messages) == max_messages):
user_warnings.add(f"⚠️ Only using last {len(messages)} message{'' if len(messages) == 1 else 's'}")
curr_msg = curr_node.parent_msg
logging.info(f"Message received (user ID: {new_msg.author.id}, attachments: {len(new_msg.attachments)}, conversation length: {len(messages)}):\n{new_msg.content}")
if system_prompt := cfg["system_prompt"]:
system_prompt_extras = [f"Today's date: {dt.now().strftime('%B %d %Y')}."]
if accept_usernames:
system_prompt_extras.append("User's names are their Discord IDs and should be typed as '<@ID>'.")
full_system_prompt = "\n".join([system_prompt] + system_prompt_extras)
messages.append(dict(role="system", content=full_system_prompt))
# Generate and send response message(s) (can be multiple if response is long)
response_msgs = []
response_contents = []
prev_chunk = None
edit_task = None
embed = discord.Embed()
for warning in sorted(user_warnings):
embed.add_field(name=warning, value="", inline=False)
kwargs = dict(model=model, messages=messages[::-1], stream=True, extra_body=cfg["extra_api_parameters"])
try:
async with new_msg.channel.typing():
async for curr_chunk in await openai_client.chat.completions.create(**kwargs):
prev_content = prev_chunk.choices[0].delta.content if prev_chunk != None and prev_chunk.choices[0].delta.content else ""
curr_content = curr_chunk.choices[0].delta.content or ""
prev_chunk = curr_chunk
if not (response_contents or prev_content):
continue
if start_next_msg := response_contents == [] or len(response_contents[-1] + prev_content) > max_message_length:
response_contents.append("")
response_contents[-1] += prev_content
if not use_plain_responses:
finish_reason = curr_chunk.choices[0].finish_reason
ready_to_edit = (edit_task == None or edit_task.done()) and dt.now().timestamp() - last_task_time >= EDIT_DELAY_SECONDS
msg_split_incoming = len(response_contents[-1] + curr_content) > max_message_length
is_final_edit = finish_reason != None or msg_split_incoming
is_good_finish = finish_reason != None and finish_reason.lower() in ("stop", "end_turn")
if start_next_msg or ready_to_edit or is_final_edit:
if edit_task != None:
await edit_task
embed.description = response_contents[-1] if is_final_edit else (response_contents[-1] + STREAMING_INDICATOR)
embed.color = EMBED_COLOR_COMPLETE if msg_split_incoming or is_good_finish else EMBED_COLOR_INCOMPLETE
if start_next_msg:
reply_to_msg = new_msg if response_msgs == [] else response_msgs[-1]
response_msg = await reply_to_msg.reply(embed=embed, silent=True)
response_msgs.append(response_msg)
msg_nodes[response_msg.id] = MsgNode(parent_msg=new_msg)
await msg_nodes[response_msg.id].lock.acquire()
else:
edit_task = asyncio.create_task(response_msgs[-1].edit(embed=embed))
last_task_time = dt.now().timestamp()
if use_plain_responses:
for content in response_contents:
reply_to_msg = new_msg if response_msgs == [] else response_msgs[-1]
response_msg = await reply_to_msg.reply(content=content, suppress_embeds=True)
response_msgs.append(response_msg)
msg_nodes[response_msg.id] = MsgNode(parent_msg=new_msg)
await msg_nodes[response_msg.id].lock.acquire()
except Exception:
logging.exception("Error while generating response")
for response_msg in response_msgs:
msg_nodes[response_msg.id].text = "".join(response_contents)
msg_nodes[response_msg.id].lock.release()
# Delete oldest MsgNodes (lowest message IDs) from the cache
if (num_nodes := len(msg_nodes)) > MAX_MESSAGE_NODES:
for msg_id in sorted(msg_nodes.keys())[: num_nodes - MAX_MESSAGE_NODES]:
async with msg_nodes.setdefault(msg_id, MsgNode()).lock:
msg_nodes.pop(msg_id, None)
async def main():
await discord_client.start(cfg["bot_token"])
asyncio.run(main())