Skip to content

Instantly share code, notes, and snippets.

@tori29umai0123
Last active August 10, 2024 08:05
Show Gist options
  • Save tori29umai0123/b6ce6d6450e5f87c00633af7d37915ea to your computer and use it in GitHub Desktop.
Save tori29umai0123/b6ce6d6450e5f87c00633af7d37915ea to your computer and use it in GitHub Desktop.
AI-NovelChat
import os
import sys
import time
import socket
import gradio as gr
from llama_cpp import Llama
import datetime
from jinja2 import Template
import configparser
import threading
import asyncio
import csv
DEFAULT_INI_FILE = 'settings.ini'
# ビルドしているかしていないかでパスを変更
if getattr(sys, 'frozen', False):
path = os.path.dirname(sys.executable)
model_dir = os.path.join(os.path.dirname(path), "AI-NovelAssistant", "models")
else:
path = os.path.dirname(os.path.abspath(__file__))
model_dir = os.path.join(path, "models")
def get_model_files():
return [f for f in os.listdir(model_dir) if f.endswith('.gguf')]
def load_settings_from_ini(filename):
config = configparser.ConfigParser()
if not os.path.exists(filename):
print(f"{filename} が見つかりません。デフォルト設定で作成します。")
create_default_ini(filename)
config.read(filename, encoding='utf-8')
settings = {}
if 'Character' in config:
settings['instructions'] = config['Character'].get('instructions', '')
settings['example_qa'] = config['Character'].get('example_qa', '').split('\n')
settings['initial_conversation'] = config['Character'].get('initial_conversation', '')
if 'Models' in config:
settings['DEFAULT_CHAT_MODEL'] = config['Models'].get('DEFAULT_CHAT_MODEL', '')
settings['DEFAULT_GEN_MODEL'] = config['Models'].get('DEFAULT_GEN_MODEL', '')
return settings
def save_settings_to_ini(settings, filename):
config = configparser.ConfigParser()
config['Character'] = {
'instructions': settings.get('instructions', ''),
'example_qa': '\n'.join(settings.get('example_qa', [])),
'initial_conversation': settings.get('initial_conversation', '')
}
config['Models'] = {
'DEFAULT_CHAT_MODEL': settings.get('DEFAULT_CHAT_MODEL', ''),
'DEFAULT_GEN_MODEL': settings.get('DEFAULT_GEN_MODEL', '')
}
with open(filename, 'w', encoding='utf-8') as configfile:
config.write(configfile)
def create_default_ini(filename):
default_settings = {
'instructions': "丁寧な敬語でアイディアのヒアリングしてください。物語をより面白くする提案、キャラクター造形の考察、世界観を膨らませる手伝いなどをお願いします。求められた時以外は基本、聞き役に徹してユーザー自身に言語化させるよう促してください。ユーザーのことは『ユーザー』と呼んでください。",
'example_qa': [
"user: キャラクターの設定について悩んでいます。",
"assistant: 承知いたしました。キャラクター設定は物語の核となる重要な要素ですね。ユーザー様が現在考えているキャラクターについて、簡単にご説明いただけますでしょうか?例えば、年齢、性別、職業、性格の特徴などから始めていただけると、より具体的なアドバイスができるかと思います。"
"user: プロットを書き出したいので、ヒアリングお願いします。",
"assistant: 承知しました。ではまず『起承転結』の起から考えていきましょう。",
"user: 読者を惹きこむ為のコツを提案してください",
"assistant: 諸説ありますが、『謎・ピンチ・意外性』を冒頭に持ってくることが重要だと言います。",
"user: プロットが面白いか自信がないので、考察のお手伝いをお願いします",
"assistant: 承知しました。まずコメントをする前にこの物語の『売り』について簡単に言語化してください",
],
'DEFAULT_CHAT_MODEL': 'Ninja-v1-RP-expressive-v2_Q4_K_M.gguf',
'DEFAULT_GEN_MODEL': 'Mistral-Nemo-Instruct-2407-Q8_0.gguf'
}
save_settings_to_ini(default_settings, filename)
def list_log_files():
logs_dir = os.path.join(path, "logs")
if not os.path.exists(logs_dir):
return []
return [f for f in os.listdir(logs_dir) if f.endswith('.csv')]
def load_chat_log(file_name):
file_path = os.path.join(path, "logs", file_name)
chat_history = []
with open(file_path, 'r', encoding='utf-8') as csvfile:
reader = csv.reader(csvfile)
next(reader) # Skip header
for row in reader:
if len(row) == 2:
role, message = row
if role == "user":
chat_history.append([message, None])
elif role == "assistant":
if chat_history and chat_history[-1][1] is None:
chat_history[-1][1] = message
else:
chat_history.append([None, message])
return chat_history
class GentextParams:
def __init__(self):
self.gen_temperature = 0.35
self.gen_top_p = 1.0
self.gen_top_k = 40
self.gen_rep_pen = 1.0
self.chat_temperature = 0.5
self.chat_top_p = 0.7
self.chat_top_k = 80
self.chat_rep_pen = 1.2
def update_generate_parameters(self, temperature, top_p, top_k, rep_pen):
self.gen_temperature = temperature
self.gen_top_p = top_p
self.gen_top_k = top_k
self.gen_rep_pen = rep_pen
def update_chat_parameters(self, temperature, top_p, top_k, rep_pen):
self.chat_temperature = temperature
self.chat_top_p = top_p
self.chat_top_k = top_k
self.chat_rep_pen = rep_pen
params = GentextParams()
class LlamaAdapter:
def __init__(self, model_path, params):
self.llm = Llama(model_path=model_path, n_ctx=10000)
self.params = params
def generate_text(self, text, author_description, token_multiplier, instruction):
input_tokens = self.llm.tokenize(text.encode())
max_tokens = int(len(input_tokens) * token_multiplier)
response = self.llm.create_chat_completion(
messages=[
{"role": "system", "content": author_description},
{"role": "user", "content": f"{instruction}\n\n{text}"},
],
max_tokens=max_tokens, temperature=self.params.gen_temperature, top_p=self.params.gen_top_p, top_k=self.params.gen_top_k, repeat_penalty=self.params.gen_rep_pen,
)
return response["choices"][0]["message"]["content"].strip()
def generate(self, prompt, max_new_tokens=10000):
return self.llm(prompt, temperature=self.params.chat_temperature, max_tokens=max_new_tokens, top_p=self.params.chat_top_p, top_k=self.params.chat_top_k, repeat_penalty=self.params.chat_rep_pen, stop=["user:", "・会話履歴", "<END>"])
class CharacterMaker:
def __init__(self):
self.llama = None
self.history = []
self.settings = None
self.model_loaded = threading.Event()
def set_model(self, model_name):
def load_model():
try:
model_path = os.path.join(model_dir, model_name)
self.llama = LlamaAdapter(model_path, params)
self.model_loaded.set()
print(f"モデル {model_name} のロードが完了しました。")
except Exception as e:
print(f"モデルのロード中にエラーが発生しました: {str(e)}")
self.model_loaded.set() # エラーの場合でもイベントをセット
threading.Thread(target=load_model).start()
def make(self, input_str: str):
if not self.model_loaded.is_set():
return "モデルをロード中です。しばらくお待ちください。"
if not self.llama:
return "モデルのロードに失敗しました。設定を確認してください。"
prompt = self._generate_prompt(input_str)
res = self.llama.generate(prompt, max_new_tokens=1000, stop=["<END>", "\n"])
res_text = res["choices"][0]["text"]
self.history.append({"user": input_str, "assistant": res_text})
return res_text
def make_prompt(self, input_str: str):
prompt_template = """{{instructions}}
・キャラクターの回答例
{% for qa in example_qa %}
{{qa}}
{% endfor %}
・会話履歴
{% for history in histories %}
user: {{history.user}}
assistant: {{history.assistant}}
{% endfor %}
user: {{input_str}}
assistant:"""
template = Template(prompt_template)
return template.render(
instructions=self.settings.get('instructions', ''),
example_qa=self.settings.get('example_qa', []),
histories=self.history,
input_str=input_str
)
def _generate_prompt(self, input_str: str):
return self.make_prompt(input_str)
def update_settings(self, new_settings, filename):
self.settings.update(new_settings)
save_settings_to_ini(self.settings, filename)
self.set_model(self.settings['DEFAULT_CHAT_MODEL'])
def load_character(self, filename):
if isinstance(filename, list):
filename = filename[0] if filename else ""
self.settings = load_settings_from_ini(filename)
if self.settings:
self.set_model(self.settings['DEFAULT_CHAT_MODEL'])
return f"{filename}から設定を読み込み、モデル {self.settings['DEFAULT_CHAT_MODEL']} を設定しました。"
return f"{filename}の読み込みに失敗しました。"
def reset(self):
self.history = []
if self.llama:
self.set_model(self.settings['DEFAULT_CHAT_MODEL'])
character_maker = CharacterMaker()
async def chat_with_character(message, history):
character_maker.history = [{"user": h[0], "assistant": h[1]} for h in history]
prompt = character_maker._generate_prompt(message)
response = character_maker.llama.generate(prompt, max_new_tokens=1000)["choices"][0]["text"]
for i in range(len(response)):
time.sleep(0.01)
yield response[: i+1]
def generate_text_with_token_multiplier(text, author_type, genre, writing_style, target_audience, token_multiplier, model_name, instruction):
author_description = f"あなたは{author_type}で、{genre}{writing_style}の文体で{target_audience}に人気があります。"
model_path = os.path.join(model_dir, model_name)
llama = LlamaAdapter(model_path, params)
return llama.generate_text(text, author_description, token_multiplier, instruction)
def clear_chat():
character_maker.reset()
return []
def save_chat_log(chat_history):
current_time = datetime.datetime.now().strftime("%Y%m%d%H%M%S")
filename = f"{current_time}.csv"
logs_dir = os.path.join(path, "logs")
if not os.path.exists(logs_dir):
os.makedirs(logs_dir)
file_path = os.path.join(logs_dir, filename)
with open(file_path, 'w', newline='', encoding='utf-8') as csvfile:
writer = csv.writer(csvfile)
writer.writerow(["Role", "Message"])
for user_message, assistant_message in chat_history:
if user_message:
writer.writerow(["user", user_message])
if assistant_message:
writer.writerow(["assistant", assistant_message])
return f"チャットログが {file_path} に保存されました。"
def get_ip_address():
s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
try:
s.connect(('10.255.255.255', 1))
IP = s.getsockname()[0]
except Exception:
IP = '127.0.0.1'
finally:
s.close()
return IP
def is_port_in_use(port):
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
return s.connect_ex(('localhost', port)) == 0
def find_available_port(starting_port):
port = starting_port
while is_port_in_use(port):
print(f"Port {port} is in use, trying next one.")
port += 1
return port
model_files = get_model_files()
def build_gradio_interface():
global demo
# カスタムCSS
custom_css = """
#chatbot, #chatbot_read {
height: 50vh;
overflow-y: auto;
resize: vertical;
border: 1px solid #ccc;
}
/* サイズ変更用のグリップをより直感的に操作できるようにスタイリング */
.resizer-grip {
height: 10px;
background: #ccc;
cursor: ns-resize;
}
"""
with gr.Blocks(css=custom_css) as demo:
# HTMLブロックでカスタムJavaScriptと追加のCSSを注入
gr.HTML("""
<style>
#chatbot, #chatbot_read {
resize: both;
overflow: auto;
min-height: 100px;
max-height: 80vh;
}
</style>
<script>
// リサイズを処理するためのJavaScript、必要であれば
document.addEventListener('DOMContentLoaded', function() {
const chatboxes = document.querySelectorAll('#chatbot, #chatbot_read');
chatboxes.forEach(chatbox => {
chatbox.addEventListener('mousedown', function(e) {
console.log('Resizing started');
});
});
});
</script>
""")
with gr.Tab("チャット"):
chatbot = gr.Chatbot(elem_id="chatbot")
chat_interface = gr.ChatInterface(
chat_with_character,
chatbot=chatbot,
textbox=gr.Textbox(placeholder="メッセージを入力してください...", container=False, scale=7),
theme="soft",
submit_btn="送信",
stop_btn="停止",
retry_btn="もう一度生成",
undo_btn="前のメッセージを取り消す",
clear_btn="チャットをクリア",
)
with gr.Row():
model_dropdown = gr.Dropdown(choices=model_files, label="モデル選択", value=character_maker.settings.get('DEFAULT_CHAT_MODEL', ''))
save_log_button = gr.Button("チャットログを保存")
save_log_output = gr.Textbox(label="保存状態")
with gr.Accordion("詳細設定", open=False):
chat_temperature = gr.Slider(label="Temperature", value=0.5, minimum=0.0, maximum=1.0, step=0.05, interactive=True)
chat_top_p = gr.Slider(label="Top-p (nucleus sampling)", value=0.7, minimum=0.0, maximum=1, step=0.05, interactive=True)
chat_top_k = gr.Slider(label="Top-k", value=80, minimum=1, maximum=200, step=1, interactive=True)
chat_rep_pen = gr.Slider(label="Repetition penalty", value=1.2, minimum=1.0, maximum=2.0, step=0.05, interactive=True)
apply_settings = gr.Button("設定を適用")
def apply_chat_settings(temp, top_p, top_k, rep_pen):
params.update_chat_parameters(temp, top_p, top_k, rep_pen)
return f"設定を適用しました: Temperature={temp}, Top-p={top_p}, Top-k={top_k}, Repetition Penalty={rep_pen}"
apply_settings.click(
apply_chat_settings,
inputs=[chat_temperature, chat_top_p, chat_top_k, chat_rep_pen],
outputs=[save_log_output]
)
save_log_button.click(
save_chat_log,
inputs=[chatbot],
outputs=[save_log_output]
)
model_dropdown.change(
lambda x: character_maker.set_model(x),
inputs=[model_dropdown],
outputs=[]
)
with gr.Tab("文章生成"):
with gr.Row():
with gr.Column(scale=2):
instruction_type = gr.Dropdown(
choices=["自由入力", "推敲", "プロット作成", "あらすじ作成"],
label="指示タイプ",
value="自由入力"
)
gen_instruction = gr.Textbox(
label="指示",
value="",
lines=3
)
gen_input_text = gr.Textbox(lines=5, label="処理されるテキストを入力してください")
with gr.Column(scale=1):
gen_author_type = gr.Textbox(label="作家のタイプ", value="新進気鋭のSF小説家")
gen_genre = gr.Textbox(label="ジャンル", value="斬新なアイデア")
gen_writing_style = gr.Textbox(label="文体", value="切れ味のある文体、流麗な文章")
gen_target_audience = gr.Textbox(label="ターゲット読者", value="若い世代")
token_multiplier = gr.Slider(minimum=0.1, maximum=20, value=1.5, step=0.1, label="トークン倍率", info="入力トークン数に対する生成トークン数の倍率(0.1〜20)")
gen_model = gr.Dropdown(choices=model_files, label="モデル選択", value=character_maker.settings.get('DEFAULT_GEN_MODEL', ''))
generate_button = gr.Button("文章生成開始")
generated_output = gr.Textbox(label="生成された文章")
with gr.Accordion("詳細設定", open=False):
gen_temperature = gr.Slider(label="Temperature", value=0.35, minimum=0.0, maximum=1.0, step=0.05, interactive=True)
gen_top_p = gr.Slider(label="Top-p (nucleus sampling)", value=0.9, minimum=0.0, maximum=1, step=0.05, interactive=True)
gen_top_k = gr.Slider(label="Top-k", value=40, minimum=1, maximum=200, step=1, interactive=True)
gen_rep_pen = gr.Slider(label="Repetition penalty", value=1.2, minimum=1.0, maximum=2.0, step=0.05)
apply_settings = gr.Button("設定を適用")
def apply_gen_settings(temp, top_p, top_k, rep_pen):
params.update_generate_parameters(temp, top_p, top_k, rep_pen)
return f"設定を適用しました: Temperature={temp}, Top-p={top_p}, Top-k={top_k}, Repetition Penalty={rep_pen}"
apply_settings.click(
apply_gen_settings,
inputs=[gen_temperature, gen_top_p, gen_top_k, gen_rep_pen],
outputs=[save_log_output]
)
generate_button.click(
generate_text_with_token_multiplier,
inputs=[gen_input_text, gen_author_type, gen_genre, gen_writing_style, gen_target_audience, token_multiplier, gen_model, gen_instruction],
outputs=[generated_output]
)
def update_instruction(choice):
instructions = {
"自由入力": "",
"推敲": "以下のテキストを推敲してください。原文の文体や特徴的な表現は保持しつつ、必要に応じて微調整を加えてください。文章の流れを自然にし、表現を洗練させることが目標ですが、元の雰囲気や個性を損なわないよう注意してください",
"プロット作成": "以下のテキストをプロットにしてください。起承転結に分割すること。",
"あらすじ作成": "以下のテキストをあらすじにして、簡潔にまとめて下さい。"
}
return instructions.get(choice, "")
instruction_type.change(
update_instruction,
inputs=[instruction_type],
outputs=[gen_instruction]
)
generate_button.click(
generate_text_with_token_multiplier,
inputs=[gen_input_text, gen_author_type, gen_genre, gen_writing_style, gen_target_audience, token_multiplier, gen_model, gen_instruction],
outputs=[generated_output]
)
# Gradioインターフェースの "チャットログ閲覧" タブを更新
with gr.Tab("ログ閲覧"):
gr.Markdown("## チャットログ閲覧")
chatbot_read = gr.Chatbot(elem_id="chatbot_read")
log_file_dropdown = gr.Dropdown(label="ログファイル選択", choices=list_log_files())
refresh_log_list_button = gr.Button("ログファイルリストを更新")
def update_log_dropdown():
return gr.update(choices=list_log_files())
def load_and_display_chat_log(file_name):
chat_history = load_chat_log(file_name)
return gr.update(value=chat_history)
refresh_log_list_button.click(
update_log_dropdown,
outputs=[log_file_dropdown]
)
log_file_dropdown.change(
load_and_display_chat_log,
inputs=[log_file_dropdown],
outputs=[chatbot_read]
)
async def load_model_and_start_gradio():
# INIファイルが存在しない場合、デフォルトのINIファイルを作成
if not os.path.exists(DEFAULT_INI_FILE):
print(f"{DEFAULT_INI_FILE} が見つかりません。デフォルト設定で作成します。")
create_default_ini(DEFAULT_INI_FILE)
# デフォルト設定の読み込み
result = character_maker.load_character(DEFAULT_INI_FILE)
print(result)
# モデルのロード完了を待つ
while not character_maker.model_loaded.is_set():
await asyncio.sleep(1)
if not character_maker.llama:
print("モデルのロードに失敗しました。アプリケーションを終了します。")
return
# Gradio インターフェースの構築
build_gradio_interface()
ip_address = get_ip_address()
starting_port = 7860
port = find_available_port(starting_port)
print(f"サーバーのアドレス: http://{ip_address}:{port}")
demo.queue()
demo.launch(
server_name='0.0.0.0',
server_port=port,
share=False,
favicon_path=os.path.join(path, "custom.html")
)
if __name__ == "__main__":
asyncio.run(load_model_and_start_gradio())
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment