You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

110 lines
2.7 KiB

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

from gradio_client import Client
from deep_translator import GoogleTranslator
# Handle a generation request
def predict(prompt, client, model="0.1"):
    """Send *prompt* to the remote model through *client* and return its reply.

    Args:
        prompt: Text passed as the first positional argument of the call.
        client: Object exposing ``predict(*args, **kwargs)``; ``gen`` passes
            a ``gradio_client.Client`` here.
        model: Selects the argument layout of the remote call; one of
            ``"0.1"``, ``"0.2"`` or ``"RWKV"``.

    Returns:
        Whatever ``client.predict`` returns for the selected model.

    Raises:
        ValueError: If *model* is not one of the supported identifiers.
    """
    if model == "0.1":
        return client.predict(
            prompt,
            "Default",
            api_name="/chat",
        )
    if model == "0.2":
        return client.predict(
            prompt,
            0.5,  # 'Temperature'
            128,  # 'Length'
            0.8,  # 'Top-p (nucleus sampling)'
            1.8,  # 'Repetition penalty'
            api_name="/chat",
        )
    if model == "RWKV":
        return client.predict(
            prompt,
            333,  # presumably the token limit -- TODO confirm against the Space's API
            0.6,  # presumably the temperature -- TODO confirm against the Space's API
            0,  # 'Top P' slider value
            0,  # 'Presence Penalty' slider value
            0,  # 'Count Penalty' slider value
            fn_index=0,
        )
    # Previously this path fell through to `return result` with `result`
    # never assigned; fail with an explicit error instead.
    print("INCORRECT MODEL: ", model)
    print(type(model))
    raise ValueError("unsupported model: {!r}".format(model))
# Decide whether a piece of text is code
def iscode(text):
    """Return True when *text* carries one of the known code markers.

    Two kinds of marker are recognised:

    * a language tag immediately followed by the two characters backslash
      and ``n`` (the raw string ``r'\\n'``, not a real line break);
    * the assembler marker ``section .`` anywhere in the text.
    """
    # Language tags
    language_tags = ('sql', 'php', 'js', 'java', 'c', 'cpp', 'python', 'go')
    if any(tag + r'\n' in text for tag in language_tags):
        return True
    # Assembler tag
    assembler_markers = ('section .',)
    return any(marker in text for marker in assembler_markers)
# text IN, language IN
def translate(text, source):
    """Translate *text* between Russian and English, keeping code untouched.

    Args:
        text: The text to translate.
        source: Language of *text*: ``"ru"`` translates to English and
            ``"en"`` translates to Russian.

    Returns:
        The translated text. When *text* contains triple-backtick fences it
        is split on them; segments that ``iscode`` recognises are re-wrapped
        in fences verbatim, and every other segment is translated and
        appended without fences. Blank text or segments are returned as-is.

    Raises:
        ValueError: If *source* is neither ``"ru"`` nor ``"en"``.
    """
    targets = {"ru": "en", "en": "ru"}
    if source not in targets:
        # Previously `target` was simply left unassigned for other languages.
        raise ValueError("unsupported source language: {!r}".format(source))
    # Nothing to translate; avoid sending an empty payload to the translator.
    if not text.strip():
        return text
    translator = GoogleTranslator(source=source, target=targets[source])
    if '```' not in text:
        return translator.translate(text)
    # Keep code from being translated
    pieces = []
    for segment in text.split('```'):
        if not segment.strip():
            pieces.append(segment)
        elif iscode(segment):
            pieces.append('```' + segment + '```')
        else:
            pieces.append(translator.translate(segment))
    return ''.join(pieces)
# Dictionary mapping a user id (as a string) to that user's session client
iddb = {}
def gen(prompt, id, model):
    """Generate a reply to *prompt* for user *id* using *model*.

    A client is created on the user's first request and cached in ``iddb``
    under ``str(id)``. For models ``"0.1"`` and ``"0.2"`` the prompt is
    translated from Russian to English before the call, ``</s>`` is stripped
    from the reply, and the reply is translated back to Russian. ``"RWKV"``
    receives the prompt and returns its reply untranslated.

    Args:
        prompt: The user's text.
        id: User/session identifier; converted with ``str`` to form the key.
        model: One of ``"0.1"``, ``"0.2"`` or ``"RWKV"``.

    Returns:
        The reply with literal ``\\n`` and ``\\ n`` sequences turned into
        real line breaks.

    Raises:
        ValueError: If *model* is not supported.
        Exception: Any error raised while creating the client, translating
            or querying the model is propagated to the caller.
    """
    endpoints = {
        "0.1": "https://afischer1985-ai-interface.hf.space/",
        "0.2": "https://nonamed33-minigpt-api.hf.space/",
        "RWKV": "https://blinkdl-rwkv-gradio-1.hf.space/",
    }
    # Validate first: an unknown model used to leave `client` / `predicted`
    # unassigned and surface later as an unrelated NameError.
    if model not in endpoints:
        print("INCORRECT MODEL: ", model)
        print(type(model))
        raise ValueError("unsupported model: {!r}".format(model))

    # NOTE(review): the cache is keyed by user id only, so a user who switches
    # models reuses the client created for their first model -- confirm
    # whether that is intended.
    session_key = str(id)
    # No session yet
    if session_key not in iddb:
        iddb[session_key] = Client(endpoints[model])
    client = iddb[session_key]

    # Failures are no longer swallowed: the old bare `except: pass` hid the
    # real error and then crashed on the unassigned `predicted` anyway.
    if model == "RWKV":
        predicted = predict(prompt, client, model)
    else:
        prompt = translate(prompt, "ru")
        predicted = predict(prompt, client, model).replace("</s>", "")
        predicted = translate(predicted, "en")

    return predicted.replace(r'\n', '\n').replace('\\ n', '\n')