From bb824885b50ecca3de142de8d9514cf6396758fe Mon Sep 17 00:00:00 2001 From: t Date: Tue, 2 Jan 2024 01:27:06 +0300 Subject: [PATCH] Fixs and add more model. --- api.py | 41 ++++++++++++++++++++++++++--------------- minigpt.py | 30 +++++++++++++++++++++++++++--- 2 files changed, 53 insertions(+), 18 deletions(-) diff --git a/api.py b/api.py index 50c2106..cd4f288 100644 --- a/api.py +++ b/api.py @@ -1,18 +1,25 @@ from gradio_client import Client -#client = Client("https://skier8402-mistral-super-fast.hf.space/") from deep_translator import GoogleTranslator -def predict(prompt, client): +def predict(prompt, client, model = "0.1"): global iddb - result = client.predict( - prompt, - 0.5, # 'Temperature' - 256, # 'Max new tokens' - 0.9, # 'Top-p (nucleus sampling)' - 1.2, # 'Repetition penalty' - api_name="/chat" - ) + if model == "0.1": + result = client.predict( + prompt, + "Default", + api_name="/chat" + ) + elif model == "0.2": + result = client.predict( + prompt, + 0.3, # 'Temperature' + 128, # 'Max new tokens' + 0.8, # 'Top-p (nucleus sampling)' + 1.5, # 'Repetition penalty' + api_name="/chat" + ) + return result # text IN language IN @@ -21,21 +28,25 @@ def translate(text, source): target = "en" elif source == "en": target = "ru" - out = GoogleTranslator(source = source, target = target).translate(text) return out iddb = {} -def gen(text, id): +def gen(text, id, model): global iddb + if str(id) not in iddb: - client = Client("https://skier8402-mistral-super-fast.hf.space/") + if model == "0.1": + client = Client("https://afischer1985-ai-interface.hf.space/--replicas/salfk/") + elif model == "0.2": + client = Client("https://skier8402-mistral-super-fast.hf.space/") iddb[str(id)] = client else: client = iddb[str(id)] prompt = translate(text, "ru") - predicted = translate( predict(prompt, client), "en" ).replace("", "") - return predicted + predicted = predict(prompt, client, model).replace("", "") + + return translate(predicted, "en") diff --git a/minigpt.py b/minigpt.py index d6eb89b..1cd8c33 100644 --- a/minigpt.py +++ b/minigpt.py @@ -18,20 +18,44 @@ API_TOKEN = db["token"] bot = telebot.TeleBot(API_TOKEN) ################## + @bot.message_handler(commands=['help', 'start']) def send_welcome(message): - bot.reply_to(message, "None") + bot.reply_to(message, "Скоро...") + ### MAIN ### from api import * +setted_models = {} + +@bot.message_handler(commands=['model']) +def set_model(message): + global setted_models, iddb + try: + iddb.pop(str(message.chat.id)) + except: + pass + + model = message.text.split()[1] + if model == "0.1" or model == "0.2": + setted_models[str(message.chat.id)] = model + bot.reply_to(message, "Установлена новая модель 🤖") + else: + bot.reply_to(message, "Неизвестная модель") + + @bot.message_handler(func=lambda message: True) def echo_message(message): - bot.send_chat_action(message.chat.id, "typing", 30) - bot.reply_to(message, gen(message.text, message.chat.id)) + global setted_models + prompt = 'Отвечай кратко не давая никакой лишней информации и не делая своих умозаключений. \n\n' + message.text + id = str(message.chat.id) + if id not in setted_models: + setted_models[id] = "0.1" + bot.reply_to(message, gen(prompt, message.chat.id, setted_models[id])) ############