Skip to content

Commit

Permalink
fix: handle RateLimitError zhayujie#50 zhayujie#51 zhayujie#54
Browse files Browse the repository at this point in the history
  • Loading branch information
ubuntu committed Feb 6, 2023
1 parent d69fd2b commit c7d1e77
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 50 deletions.
77 changes: 29 additions & 48 deletions bot/openai/open_ai_bot.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from config import conf
from common.log import logger
import openai
from datetime import date
import time

user_session = dict()

Expand All @@ -26,16 +26,16 @@ def reply(self, query, context=None):
new_query = Session.build_session_query(query, from_user_id)
logger.debug("[OPEN_AI] session query={}".format(new_query))

reply_content = self.reply_text(new_query, from_user_id)
logger.debug("[OPEN_AI] new_query={}, user={}".format(new_query, from_user_id))
reply_content = self.reply_text(new_query, from_user_id, 0)
logger.debug("[OPEN_AI] new_query={}, user={}, reply_cont={}".format(new_query, from_user_id, reply_content))
if reply_content and query:
Session.save_session(query, reply_content, from_user_id)
return reply_content

elif context.get('type', None) == 'IMAGE_CREATE':
return self.create_img(query)
return self.create_img(query, 0)

def reply_text(self, query, user_id):
def reply_text(self, query, user_id, retry_count=0):
try:
response = openai.Completion.create(
model="text-davinci-003", # 对话模型的名称
Expand All @@ -48,14 +48,25 @@ def reply_text(self, query, user_id):
stop=["#"]
)
res_content = response.choices[0]["text"].strip().rstrip("<|im_end|>")
logger.info("[OPEN_AI] reply={}".format(res_content))
return res_content
except openai.error.RateLimitError as e:
# rate limit exception
logger.warn(e)
if retry_count < 1:
time.sleep(5)
logger.warn("[OPEN_AI] RateLimit exceed, 第{}次重试".format(retry_count+1))
return self.reply_text(query, user_id, retry_count+1)
else:
return "提问太快啦,请休息一下再问我吧"
except Exception as e:
# unknown exception
logger.exception(e)
Session.clear_session(user_id)
return None
logger.info("[OPEN_AI] reply={}".format(res_content))
return res_content
return "请再问我一次吧"


def create_img(self, query):
def create_img(self, query, retry_count=0):
try:
logger.info("[OPEN_AI] image_query={}".format(query))
response = openai.Image.create(
Expand All @@ -65,48 +76,18 @@ def create_img(self, query):
)
image_url = response['data'][0]['url']
logger.info("[OPEN_AI] image_url={}".format(image_url))
return image_url
except openai.error.RateLimitError as e:
logger.warn(e)
if retry_count < 1:
time.sleep(5)
logger.warn("[OPEN_AI] ImgCreate RateLimit exceed, 第{}次重试".format(retry_count+1))
return self.reply_text(query, retry_count+1)
else:
return "提问太快啦,请休息一下再问我吧"
except Exception as e:
logger.exception(e)
return None
return image_url

def edit_img(self, query, src_img):
try:
response = openai.Image.create_edit(
image=open(src_img, 'rb'),
mask=open('cat-mask.png', 'rb'),
prompt=query,
n=1,
size='512x512'
)
image_url = response['data'][0]['url']
logger.info("[OPEN_AI] image_url={}".format(image_url))
except Exception as e:
logger.exception(e)
return None
return image_url

def migration_img(self, query, src_img):

try:
response = openai.Image.create_variation(
image=open(src_img, 'rb'),
n=1,
size="512x512"
)
image_url = response['data'][0]['url']
logger.info("[OPEN_AI] image_url={}".format(image_url))
except Exception as e:
logger.exception(e)
return None
return image_url

def append_question_mark(self, query):
end_symbols = [".", "。", "?", "?", "!", "!"]
for symbol in end_symbols:
if query.endswith(symbol):
return query
return query + "?"


class Session(object):
Expand Down
4 changes: 2 additions & 2 deletions channel/wechat/wechat_channel.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ def _do_send(self, query, reply_user_id):
return
context = dict()
context['from_user_id'] = reply_user_id
reply_text = super().build_reply_content(query, context).strip()
reply_text = super().build_reply_content(query, context)
if reply_text:
self.send(conf().get("single_chat_reply_prefix") + reply_text, reply_user_id)
except Exception as e:
Expand Down Expand Up @@ -144,8 +144,8 @@ def _do_send_group(self, query, msg):
context = dict()
context['from_user_id'] = msg['ActualUserName']
reply_text = super().build_reply_content(query, context)
reply_text = '@' + msg['ActualNickName'] + ' ' + reply_text.strip()
if reply_text:
reply_text = '@' + msg['ActualNickName'] + ' ' + reply_text.strip()
self.send(conf().get("group_chat_reply_prefix", "") + reply_text, msg['User']['UserName'])


Expand Down

0 comments on commit c7d1e77

Please sign in to comment.