42 lines
919 B
Python
42 lines
919 B
Python
|
from gradio_client import Client
|
||
|
#client = Client("https://skier8402-mistral-super-fast.hf.space/")
|
||
|
|
||
|
from deep_translator import GoogleTranslator
|
||
|
|
||
|
def predict(prompt, client, *, temperature=0.5, max_new_tokens=256,
            top_p=0.9, repetition_penalty=1.2):
    """Send *prompt* to the Space's ``/chat`` endpoint and return its result.

    The sampling settings are passed positionally after the prompt, in the
    order the endpoint expects: temperature, max new tokens, top-p
    (nucleus sampling) and repetition penalty.

    Args:
        prompt: Text sent to the model.
        client: A gradio ``Client`` connected to the chat Space (anything
            exposing a compatible ``predict(*args, api_name=...)`` works).
        temperature: Sampling temperature.
        max_new_tokens: Maximum number of tokens to generate.
        top_p: Nucleus-sampling probability mass.
        repetition_penalty: Penalty applied to repeated tokens.

    Returns:
        Whatever ``client.predict`` returns for the ``/chat`` endpoint
        (presumably the generated text - confirm against the Space's API).
    """
    return client.predict(
        prompt,
        temperature,
        max_new_tokens,
        top_p,
        repetition_penalty,
        api_name="/chat",
    )
|
||
|
|
||
|
def translate(text, source, target=None):
    """Translate *text*, written in language *source*, with Google Translate.

    When *target* is omitted, the text is translated into the other language
    of the ru/en pair: ``"ru"`` -> ``"en"`` and ``"en"`` -> ``"ru"``.

    Args:
        text: The text to translate.
        source: Language code of *text*.
        target: Optional language code to translate into. It is required
            when *source* is neither ``"ru"`` nor ``"en"``.

    Returns:
        The translated text as returned by ``GoogleTranslator.translate``.

    Raises:
        ValueError: If *target* is omitted and *source* is neither
            ``"ru"`` nor ``"en"``. Previously this case crashed with an
            ``UnboundLocalError``.
    """
    if target is None:
        # The default pairing only covers Russian <-> English.
        counterpart = {"ru": "en", "en": "ru"}
        try:
            target = counterpart[source]
        except KeyError:
            raise ValueError(
                f"no default target language for source {source!r}; "
                "pass target explicitly"
            ) from None
    return GoogleTranslator(source=source, target=target).translate(text)
|
||
|
|
||
|
|
||
|
# Per-id cache of gradio Client objects, keyed by str(id) and filled lazily by gen().
iddb = dict()
|
||
|
|
||
|
def gen(text, id):
    """Answer a Russian prompt with the Mistral Space and reply in Russian.

    The prompt is translated ru -> en, sent to the model through the gradio
    client cached for this *id*, and the model's reply is translated back
    en -> ru.

    Args:
        text: The user's prompt, expected to be in Russian.
        id: Caller/session identifier. One gradio ``Client`` is created per
            distinct ``str(id)`` and reused on later calls.

    Returns:
        The model's reply translated to Russian, with ``"</s>"`` markers
        removed.
    """
    client = _get_client(id)
    prompt = translate(text, "ru")
    # Strip the "</s>" marker (presumably the model's end-of-sequence token)
    # BEFORE translating, so the translator never sees it and cannot alter
    # it into a form a later replace() would miss.
    answer = predict(prompt, client).replace("</s>", "")
    return translate(answer, "en")


def _get_client(id):
    """Return the cached gradio Client for ``str(id)``, creating it on first use."""
    key = str(id)
    client = iddb.get(key)
    if client is None:
        client = Client("https://skier8402-mistral-super-fast.hf.space/")
        iddb[key] = client
    return client
|