You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

110 lines
2.7 KiB

10 months ago
from gradio_client import Client
from deep_translator import GoogleTranslator
# Handle a generation request.
def predict(prompt, client, model="0.1"):
    """Send *prompt* to a Gradio *client* using the argument list of *model*.

    Each supported model identifier expects a different positional argument
    list for ``client.predict``, so the call is assembled here.

    Args:
        prompt: Text passed as the first argument to ``client.predict``.
        client: Object with a ``predict`` method, e.g. a
            ``gradio_client.Client`` connected to the Space for *model*.
        model: One of ``"0.1"``, ``"0.2"`` or ``"RWKV"``.

    Returns:
        Whatever ``client.predict`` returns.

    Raises:
        ValueError: if *model* is not a supported identifier.
    """
    if model == "0.1":
        return client.predict(prompt, "Default", api_name="/chat")
    if model == "0.2":
        return client.predict(
            prompt,
            0.5,  # temperature
            128,  # length
            0.8,  # top-p (nucleus sampling)
            1.8,  # repetition penalty
            api_name="/chat",
        )
    if model == "RWKV":
        return client.predict(
            prompt,
            333,  # NOTE(review): 333 and 0.6 are unlabelled here; confirm
            0.6,  # their meaning against the Space's API page.
            0,  # 'Top P' slider, value in 0.0-1.0
            0,  # 'Presence Penalty' slider, value in 0.0-1.0
            0,  # 'Count Penalty' slider, value in 0.0-1.0
            fn_index=0,
        )
    # Print the bad value and its type (helps spot e.g. a float passed instead
    # of a string), then fail with a clear error rather than an unbound name.
    print("INCORRECT MODEL: ", model)
    print(type(model))
    raise ValueError(f"Unsupported model: {model!r}")
# Определяем что это код
def iscode(text):
# Теги
langs = ['sql','php','js','java','c','cpp','python','go']
is_code = False
for i in langs:
if i + r'\n' in text:
is_code = True
break
# Тег ассемблера
spec = ['section .']
if not is_code:
for i in spec:
if i in text:
is_code = True
break
return is_code
10 months ago
# text IN language IN
def translate(text, source):
if source == "ru":
target = "en"
elif source == "en":
target = "ru"
# Исправление перевода кода
if '```' in text:
out = ''
for i in text.split('```'):
if iscode(i):
out += '```' + i + '```'
else:
out += GoogleTranslator(source = source, target = target).translate(i)
else:
out = GoogleTranslator(source = source, target = target).translate(text)
10 months ago
return out
7 months ago
# Словарь пользователь - сессия
10 months ago
iddb = {}
def gen(prompt, id, model):
10 months ago
global iddb
# Если нету сессии
10 months ago
if str(id) not in iddb:
if model == "0.1":
client = Client("https://afischer1985-ai-interface.hf.space/")
elif model == "0.2":
7 months ago
client = Client("https://nonamed33-minigpt-api.hf.space/")
elif model == "RWKV":
client = Client("https://blinkdl-rwkv-gradio-1.hf.space/")
iddb[str(id)] = client
10 months ago
else:
client = iddb[str(id)]
try:
if model == "0.1" or model == "0.2":
prompt = translate(prompt, "ru")
predicted = predict(prompt, client, model).replace("</s>", "")
predicted = translate(predicted, "en")
elif model == "RWKV":
predicted = predict(prompt, client, model)
else:
print("INCORRECT MODEL: ", model)
print(type(model))
except:
pass
fixed = predicted.replace(r'\n', '\n').replace('\\ n', '\n')
return fixed