from gradio_client import Client
from deep_translator import GoogleTranslator

# Hugging Face Space URL for each supported model name.
MODEL_URLS = {
    "0.1": "https://afischer1985-ai-interface.hf.space/",
    "0.2": "https://nonamed33-minigpt-api.hf.space/",
    "RWKV": "https://blinkdl-rwkv-gradio-1.hf.space/",
}


def predict(prompt, client, model="0.1"):
    """Handle a generation request: send *prompt* to *client* as *model* expects.

    Each Space exposes a different endpoint signature, so the positional
    arguments passed to ``client.predict`` depend on *model*.

    Args:
        prompt: Text to send to the model.
        client: ``gradio_client.Client`` connected to the Space for *model*.
        model: One of ``"0.1"``, ``"0.2"`` or ``"RWKV"``.

    Returns:
        The value returned by ``client.predict`` (the callers in this file
        treat it as a string).

    Raises:
        ValueError: If *model* is not a supported model name.
    """
    if model == "0.1":
        return client.predict(prompt, "Default", api_name="/chat")
    if model == "0.2":
        return client.predict(
            prompt,
            0.5,  # 'Temperature'
            128,  # 'Length'
            0.8,  # 'Top-p (nucleus sampling)'
            1.8,  # 'Repetition penalty'
            api_name="/chat",
        )
    if model == "RWKV":
        return client.predict(
            prompt,
            333,  # NOTE(review): unlabeled; presumably a length/token limit - confirm
            0.6,  # NOTE(review): unlabeled; presumably temperature - confirm
            0,  # 'Top P' slider, 0.0-1.0
            0,  # 'Presence Penalty' slider, 0.0-1.0
            0,  # 'Count Penalty' slider, 0.0-1.0
            fn_index=0,
        )
    # No endpoint signature is known for any other name; fail loudly instead
    # of falling through with nothing to return.
    raise ValueError(f"INCORRECT MODEL: {model!r} ({type(model)})")


# Language tags that mark a segment as code when followed by an escaped newline.
_CODE_LANGS = ('sql', 'php', 'js', 'java', 'c', 'cpp', 'python', 'go')
# Substrings that mark untagged code (assembler section directives).
_CODE_MARKERS = ('section .',)


def iscode(text):
    r"""Heuristically decide whether *text* is a piece of source code.

    *text* counts as code when it contains one of the known language tags
    immediately followed by the two-character sequence ``\n`` (backslash + n,
    the escaped form that ``gen`` later turns into real newlines), or when it
    contains an assembler ``section .`` directive.

    NOTE: this is substring matching, so short tags such as ``c`` or ``go``
    also match ordinary words that end right before an escaped ``\n``, and a
    tag followed by a *real* newline is not recognized. ``translate`` no
    longer relies on this heuristic; it is kept for existing callers.
    """
    if any(lang + r'\n' in text for lang in _CODE_LANGS):
        return True
    return any(marker in text for marker in _CODE_MARKERS)


def _translate_prose(translator, segment):
    """Translate one non-code *segment*, preserving its outer whitespace.

    Blank segments are returned unchanged without calling the translator.
    Only the stripped text is sent for translation; the original leading and
    trailing whitespace is re-attached so line breaks around code fences
    survive.
    """
    core = segment.strip()
    if not core:
        return segment
    head = segment[:len(segment) - len(segment.lstrip())]
    tail = segment[len(segment.rstrip()):]
    return head + translator.translate(core) + tail


def translate(text, source):
    """Translate *text* between Russian and English, leaving code untouched.

    Args:
        text: Text to translate; may contain Markdown ``` fenced blocks.
        source: Source language: ``"ru"`` (translated to English) or
            ``"en"`` (translated to Russian).

    Returns:
        The translated text. Segments between ``` fences (the odd-numbered
        pieces of ``text.split('```')``) are copied verbatim and re-wrapped in
        fences, so an unterminated final block also gets a closing fence.
        Everything outside the fences is translated.

    Raises:
        ValueError: If *source* is neither ``"ru"`` nor ``"en"``.
    """
    targets = {"ru": "en", "en": "ru"}
    if source not in targets:
        raise ValueError(f"unsupported source language: {source!r}")
    # One translator instance serves every segment of this text.
    translator = GoogleTranslator(source=source, target=targets[source])

    if '```' not in text:
        return _translate_prose(translator, text)

    out = []
    for index, segment in enumerate(text.split('```')):
        if index % 2:
            # Odd pieces sit between an opening and a closing fence: code.
            out.append('```' + segment + '```')
        else:
            out.append(_translate_prose(translator, segment))
    return ''.join(out)


# User -> session dictionary: str(user id) -> gradio_client.Client.
iddb = {}
# Model name each client cached in `iddb` was created for (same keys).
_iddb_models = {}


def gen(prompt, id, model):
    r"""Generate a reply to *prompt* for user *id* with the chosen *model*.

    A ``gradio_client.Client`` is cached per user in ``iddb`` and reused on
    later calls. The client is recreated if the user switches to a different
    model. For ``"0.1"`` and ``"0.2"``, the prompt is translated ru->en and
    the reply en->ru. ``"RWKV"`` is called without translation. Finally,
    literal ``\n`` sequences (and the spaced variant ``\ n``) in the reply are
    converted to real newlines.

    Args:
        prompt: User prompt (Russian for the translated models).
        id: User identifier; stored as ``str(id)`` in ``iddb``.
        model: One of ``"0.1"``, ``"0.2"`` or ``"RWKV"``.

    Returns:
        The model's reply as a string.

    Raises:
        ValueError: If *model* is not a supported model name.
        Any exception raised by the Gradio client or the translator is
        propagated unchanged.
    """
    if model not in MODEL_URLS:
        raise ValueError(f"INCORRECT MODEL: {model!r} ({type(model)})")

    key = str(id)
    client = iddb.get(key)
    # Each model lives on its own Space with its own API, so a client cached
    # for another model cannot be reused.
    if client is None or _iddb_models.get(key) != model:
        client = Client(MODEL_URLS[model])
        iddb[key] = client
        _iddb_models[key] = model

    if model == "RWKV":
        predicted = predict(prompt, client, model)
    else:
        predicted = predict(translate(prompt, "ru"), client, model)
        predicted = translate(predicted, "en")

    return predicted.replace(r'\n', '\n').replace('\\ n', '\n')