diff --git a/api.py b/api.py index cd4f288..37fdaa9 100644 --- a/api.py +++ b/api.py @@ -13,10 +13,10 @@ def predict(prompt, client, model = "0.1"): elif model == "0.2": result = client.predict( prompt, - 0.3, # 'Temperature' + 0.05, # 'Temperature' 128, # 'Max new tokens' 0.8, # 'Top-p (nucleus sampling)' - 1.5, # 'Repetition penalty' + 1.8, # 'Repetition penalty' api_name="/chat" ) @@ -31,7 +31,6 @@ def translate(text, source): out = GoogleTranslator(source = source, target = target).translate(text) return out - iddb = {} def gen(text, id, model): @@ -39,7 +38,7 @@ def gen(text, id, model): if str(id) not in iddb: if model == "0.1": - client = Client("https://afischer1985-ai-interface.hf.space/--replicas/salfk/") + client = Client("https://afischer1985-ai-interface.hf.space/") elif model == "0.2": client = Client("https://skier8402-mistral-super-fast.hf.space/") iddb[str(id)] = client @@ -47,6 +46,13 @@ def gen(text, id, model): client = iddb[str(id)] prompt = translate(text, "ru") - predicted = predict(prompt, client, model).replace("", "") + + success = False + while not success: + try: + predicted = predict(prompt, client, model).replace("", "") + success = True + except: + pass return translate(predicted, "en") diff --git a/minigpt.py b/minigpt.py old mode 100644 new mode 100755 index 1cd8c33..9356115 --- a/minigpt.py +++ b/minigpt.py @@ -29,6 +29,23 @@ def send_welcome(message): from api import * setted_models = {} +system_prompts = {} + +@bot.message_handler(commands=['info']) +def info(message): + global setted_models, system_prompts + id = str(message.chat.id) + if id not in setted_models: + setted_models[id] = "0.1" + if id not in system_prompts: + prompt = "None" + else: + prompt = system_prompts[str(message.chat.id)] + + bot.send_message(message.chat.id, f"""____ Информация ____ +Версия: {setted_models[id]} +System-prompt: {telebot.formatting.hcode(prompt)} +""", parse_mode="HTML") @bot.message_handler(commands=['model']) def set_model(message): @@ -46,16 +63,35 @@ def set_model(message): bot.reply_to(message, "Неизвестная модель") +@bot.message_handler(commands=['prompt']) +def set_prompt(message): + global system_prompts + system_prompts[str(message.chat.id)] = message.text[8:] + bot.reply_to(message, "Установлен новый system-prompt") +@bot.message_handler(commands=['cprompt']) +def clear_prompt(message): + global system_prompts + system_prompts.pop(str(message.chat.id)) + bot.reply_to(message, "System-prompt очищен") @bot.message_handler(func=lambda message: True) def echo_message(message): - global setted_models - prompt = 'Отвечай кратко не давая никакой лишней информации и не делая своих умозаключений. \n\n' + message.text + global setted_models, system_prompts id = str(message.chat.id) if id not in setted_models: setted_models[id] = "0.1" - bot.reply_to(message, gen(prompt, message.chat.id, setted_models[id])) + + if id in system_prompts: + prompt = '[INST]' + system_prompts[id] + '[/INST]\n\n' + message.text + else: + prompt = message.text + + st = bot.send_message(message.chat.id, "Печатает...") + bot.reply_to(message, gen(prompt, message.chat.id, setted_models[id]).replace(r'\n', '\n'), parse_mode="HTML") + bot.delete_message(message.chat.id, st.id) + + ############