just_minigpt/api.py

59 lines
1.2 KiB
Python
Raw Normal View History

2024-01-01 21:06:13 +00:00
import logging

from deep_translator import GoogleTranslator
from gradio_client import Client
2024-01-01 22:27:06 +00:00
def predict(prompt, client, model="0.1"):
    """Send *prompt* to a Gradio chat endpoint and return the raw reply.

    The positional payload depends on the model version, because each
    version is served by a different hosted Space with different inputs:

    * ``"0.1"`` -- ``(prompt, "Default")``
    * ``"0.2"`` -- ``(prompt, temperature, max_new_tokens, top_p,
      repetition_penalty)`` with fixed sampling settings.

    Both variants call the ``/chat`` API route.

    Args:
        prompt: Text to send to the model.
        client: Object with a ``predict(*args, api_name=...)`` method,
            normally a ``gradio_client.Client`` connected to the Space that
            matches *model*.
        model: Model version selector, ``"0.1"`` or ``"0.2"``.

    Returns:
        Whatever ``client.predict`` returns, unchanged.

    Raises:
        ValueError: If *model* is not a supported version.
    """
    if model == "0.1":
        args = (prompt, "Default")
    elif model == "0.2":
        args = (
            prompt,
            0.05,  # 'Temperature'
            128,  # 'Max new tokens'
            0.8,  # 'Top-p (nucleus sampling)'
            1.8,  # 'Repetition penalty'
        )
    else:
        # Without this branch an unknown model would leave the reply unbound
        # and surface as a confusing UnboundLocalError.
        raise ValueError(f"unsupported model {model!r}; expected '0.1' or '0.2'")
    return client.predict(*args, api_name="/chat")
# Default translation direction for each supported source language.
_DEFAULT_TARGETS = {"ru": "en", "en": "ru"}


def translate(text, source, target=None):
    """Translate *text* with Google Translate, defaulting to the ru<->en pair.

    Args:
        text: The input text.
        source: Language code of *text*. Without an explicit *target* it
            must be ``"ru"`` (translated to ``"en"``) or ``"en"``
            (translated to ``"ru"``).
        target: Optional explicit target language code; overrides the
            ru<->en default.

    Returns:
        The result of ``GoogleTranslator(...).translate(text)``.

    Raises:
        ValueError: If *target* is omitted and *source* is neither
            ``"ru"`` nor ``"en"``.
    """
    if target is None:
        try:
            target = _DEFAULT_TARGETS[source]
        except KeyError:
            # Previously this path left ``target`` unbound (UnboundLocalError).
            raise ValueError(
                f"no default target for source language {source!r}; "
                "pass target= explicitly"
            ) from None
    return GoogleTranslator(source=source, target=target).translate(text)
# Cached Gradio clients keyed by ``str(id)``; ``gen`` reuses the client stored
# for an id across calls (presumably so one conversation keeps one remote
# session -- TODO confirm against the Spaces' behaviour).
iddb = {}

# Model version each cached client in ``iddb`` was created for, same keys.
# Kept in a separate dict so ``iddb`` still maps ids directly to clients.
_iddb_models = {}

# Hosted Gradio Space serving each supported model version.
_SPACE_URLS = {
    "0.1": "https://afischer1985-ai-interface.hf.space/",
    "0.2": "https://skier8402-mistral-super-fast.hf.space/",
}

_logger = logging.getLogger(__name__)


def _get_client(key, model):
    """Return the cached client for *key*, (re)creating it when *model* changed."""
    try:
        url = _SPACE_URLS[model]
    except KeyError:
        raise ValueError(
            f"unsupported model {model!r}; expected one of {sorted(_SPACE_URLS)}"
        ) from None
    # A client built for another model points at a different Space, which
    # would reject this model's payload on every attempt -- rebuild it.
    if key not in iddb or _iddb_models.get(key) != model:
        iddb[key] = Client(url)
        _iddb_models[key] = model
    return iddb[key]


def gen(text, id, model, max_retries=None):
    """Answer Russian *text* with the chosen remote model, replying in Russian.

    *text* is translated ru -> en, sent to the model through ``predict``,
    the ``"</s>"`` token is stripped from the reply, and the reply is
    translated en -> ru.

    Args:
        text: Prompt in Russian.
        id: Conversation identifier; ``str(id)`` keys the client cache.
        model: Model version, ``"0.1"`` or ``"0.2"``.
        max_retries: How many times to retry a failed ``predict`` call after
            the first attempt. ``None`` (default) retries indefinitely.

    Returns:
        The model's reply translated to Russian.

    Raises:
        ValueError: If *model* is not a supported version.
        Exception: The last ``predict`` error, once *max_retries* is exhausted.
    """
    key = str(id)
    client = _get_client(key, model)
    prompt = translate(text, "ru")

    attempt = 0
    while True:
        attempt += 1
        try:
            predicted = predict(prompt, client, model).replace("</s>", "")
        except Exception as exc:
            # ``except Exception`` (not a bare except) so Ctrl+C / SystemExit
            # can still interrupt the retry loop.
            if max_retries is not None and attempt > max_retries:
                raise
            _logger.warning(
                "predict attempt %d for id %s (model %s) failed: %r",
                attempt, key, model, exc,
            )
            continue
        break
    return translate(predicted, "en")