110 lines
2.7 KiB
Python
110 lines
2.7 KiB
Python
from gradio_client import Client
|
|
from deep_translator import GoogleTranslator
|
|
|
|
# Dispatch a generation request to the Gradio Space behind `client`.
def predict(prompt, client, model = "0.1"):
    """Send *prompt* to a Gradio Space client and return its raw result.

    Each supported model lives on a different Space whose endpoint expects a
    different argument list, so *model* must match the Space *client* was
    created for.

    Args:
        prompt: Text passed as the first input of the Space endpoint.
        client: A ``gradio_client.Client`` (or anything exposing a compatible
            ``predict(*args, **kwargs)`` method).
        model: ``"0.1"``, ``"0.2"`` or ``"RWKV"``.

    Returns:
        Whatever ``client.predict`` returns.

    Raises:
        ValueError: if *model* is not one of the supported names.
    """
    if model == "0.1":
        # NOTE(review): meaning of the "Default" option is not visible here —
        # presumably a preset selector on the Space; confirm.
        return client.predict(
            prompt,
            "Default",
            api_name="/chat",
        )

    if model == "0.2":
        return client.predict(
            prompt,
            0.5,  # temperature
            128,  # length
            0.8,  # top-p (nucleus sampling)
            1.8,  # repetition penalty
            api_name="/chat",
        )

    if model == "RWKV":
        return client.predict(
            prompt,
            333,  # unlabeled; presumably the token budget — TODO confirm
            0.6,  # unlabeled; presumably temperature — TODO confirm
            0,  # 'Top P' slider, 0.0-1.0
            0,  # 'Presence Penalty' slider, 0.0-1.0
            0,  # 'Count Penalty' slider, 0.0-1.0
            fn_index=0,
        )

    # Previously this printed a message and then crashed with an
    # UnboundLocalError on the return; fail loudly with the actual cause.
    raise ValueError(f"INCORRECT MODEL: {model!r} ({type(model).__name__})")
|
|
|
|
# Decide whether a piece of text is the body of a fenced code block.
def iscode(text):
    """Return True if *text* looks like the inside of a ``` code fence.

    *text* is expected to be one piece of a message split on ``` fences, so a
    code piece begins with its language tag immediately followed by a line
    break. The break may be a real newline or a literal backslash-n escape;
    model replies can contain the escaped form, which is only turned into a
    real newline later. Assembly listings, which often carry no tag, are
    recognised by a ``section .`` directive anywhere in the text.

    Only the start of the piece is inspected for a tag. Matching the tag
    anywhere would misfire on ordinary prose such as ``basic\\n`` or ``ago\\n``.
    """
    # Language tags recognised right after an opening fence (compared
    # case-insensitively).
    langs = (
        'sql', 'php', 'js', 'javascript', 'ts', 'typescript', 'java',
        'c', 'cpp', 'csharp', 'python', 'py', 'go', 'rust',
        'bash', 'sh', 'asm', 'html', 'css', 'json',
    )
    # Line-break forms that may follow the tag: escaped "\n" or real newlines.
    breaks = ('\\n', '\n', '\r\n')

    head = text.lstrip(' \t').lower()
    for lang in langs:
        if head.startswith(lang) and head[len(lang):].startswith(breaks):
            return True

    # Assembler marker (e.g. "section .data" / "section .text").
    return 'section .' in text
|
|
|
|
# Translate text between Russian and English, keeping fenced code intact.
def translate(text, source):
    """Translate *text* from *source* into the other of Russian/English.

    Content between ``` fences is copied through verbatim (fences included)
    so code is never sent to the translator. Everything else goes through
    ``GoogleTranslator``.

    Args:
        text: Input text, possibly containing ``` fenced code blocks.
        source: ``"ru"`` (translate into English) or ``"en"`` (translate into
            Russian).

    Returns:
        The translated text, with fenced code preserved unchanged.

    Raises:
        ValueError: if *source* is not ``"ru"`` or ``"en"``.
    """
    targets = {"ru": "en", "en": "ru"}
    if source not in targets:
        # Previously an unknown source left `target` unbound.
        raise ValueError(f"Unsupported source language: {source!r}")

    # One translator instance serves every segment of this call.
    translator = GoogleTranslator(source=source, target=targets[source])

    def _translate_prose(chunk):
        # Nothing to translate in blank or letter-less chunks (e.g. the empty
        # piece split() yields before a leading fence). Older deep_translator
        # releases reject empty/whitespace/digit-only payloads outright.
        if not any(ch.isalpha() for ch in chunk):
            return chunk
        # The translator strips surrounding whitespace; restore it so line
        # breaks around code fences survive.
        core = chunk.strip()
        lead = chunk[:len(chunk) - len(chunk.lstrip())]
        trail = chunk[len(chunk.rstrip()):]
        # Fall back to the source text if the translator returns nothing.
        return lead + (translator.translate(core) or core) + trail

    if '```' not in text:
        return _translate_prose(text)

    out = []
    for index, piece in enumerate(text.split('```')):
        # split() puts fenced content at odd indices, so untagged code blocks
        # are preserved too. iscode() additionally catches tagged code whose
        # opening fence is missing.
        if index % 2 == 1 or iscode(piece):
            out.append('```' + piece + '```')
        else:
            out.append(_translate_prose(piece))
    return ''.join(out)
|
|
|
|
|
|
# Session cache: str(user id) -> gradio_client.Client, reused across calls.
iddb = {}

# Model each cached client in `iddb` was created for. It is kept separately so
# `iddb` keeps its id -> client shape for any other code that reads it.
_iddb_models = {}

# Gradio Space URL backing each supported model name.
_MODEL_SPACES = {
    "0.1": "https://afischer1985-ai-interface.hf.space/",
    "0.2": "https://nonamed33-minigpt-api.hf.space/",
    "RWKV": "https://blinkdl-rwkv-gradio-1.hf.space/",
}


def gen(prompt, id, model):
    r"""Generate a reply to *prompt* for user *id* using *model*.

    One ``Client`` per user is cached in ``iddb`` and reused. It is rebuilt
    when the user switches to a different model, because each model lives on
    its own Space with its own argument list.

    For models ``"0.1"`` and ``"0.2"`` the prompt is translated ru->en before
    the call, ``</s>`` markers are removed from the reply, and the reply is
    translated en->ru. ``"RWKV"`` receives the prompt unchanged.

    Args:
        prompt: User message.
        id: User/session identifier; stringified to key the cache.
        model: ``"0.1"``, ``"0.2"`` or ``"RWKV"``.

    Returns:
        The reply, with literal ``\n`` escapes (and the ``\ n`` variant)
        turned into real newlines.

    Raises:
        ValueError: if *model* is not supported.
        Errors from the Space call or the translator propagate to the caller.
        They are no longer swallowed and re-surfaced as an UnboundLocalError.
    """
    if model not in _MODEL_SPACES:
        raise ValueError(f"INCORRECT MODEL: {model!r} ({type(model).__name__})")

    key = str(id)
    client = iddb.get(key)
    # Entries with no recorded model (e.g. inserted by other code) are assumed
    # to belong to *model*.
    if client is None or _iddb_models.get(key, model) != model:
        client = Client(_MODEL_SPACES[model])
        iddb[key] = client
        _iddb_models[key] = model

    if model == "RWKV":
        predicted = predict(prompt, client, model)
    else:
        # Models 0.1/0.2 are driven through English: translate the prompt
        # ru->en and the reply back en->ru.
        prompt = translate(prompt, "ru")
        predicted = predict(prompt, client, model).replace("</s>", "")
        predicted = translate(predicted, "en")

    # Turn escaped line breaks into real newlines. "\ n" is the same escape
    # with a space inserted, presumably by the translator.
    return predicted.replace(r'\n', '\n').replace('\\ n', '\n')
|