Created August 25, 2023 15:44
-
-
Save catboxanon/c00b16a8afea333b71870f9a17987c36 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch | |
from modules import script_callbacks, shared | |
def on_model_loaded(sd_model):
    """Check the CLIP text encoder's ``position_ids`` buffer and repair it in place.

    Registered with ``script_callbacks.on_model_loaded``. Does nothing unless the
    ``clip_tensor_fix_enabled`` setting is enabled.

    The buffer is expected to hold ``0, 1, ..., N-1`` along its last dimension
    (N was hard-coded as 77 by the original check). If any entry differs, the
    correct ids are written back into the model's own tensor, so the fix
    actually takes effect.

    Args:
        sd_model: the freshly loaded model object passed by the callback.

    Prints a one-line status: OK, repaired, or not found. Errors raised while
    locating or checking the tensor are printed instead of propagated.
    """
    if not (hasattr(shared.opts, 'clip_tensor_fix_enabled') and shared.opts.data.get('clip_tensor_fix_enabled', False)):  # type: ignore
        return
    try:
        # state_dict() builds an entry for every parameter/buffer of the whole
        # model, so build it once instead of once per candidate key.
        state_keys = sd_model.state_dict().keys()
        position_ids = None
        if 'cond_stage_model.wrapped.transformer.text_model.embeddings.position_ids' in state_keys:
            # Layout where the embeddings sit under `text_model`.
            position_ids = sd_model.cond_stage_model.hijack.clip.wrapped.transformer.text_model.embeddings.position_ids
        elif 'cond_stage_model.wrapped.transformer.embeddings.position_ids' in state_keys:
            # Layout where the embeddings sit directly on the transformer.
            position_ids = sd_model.cond_stage_model.hijack.clip.wrapped.transformer.embeddings.position_ids

        if position_ids is None:
            print('CLIP IDs tensor not found.')
            return

        # Assumes the last dimension is the token position, matching the
        # (1, 77) layout the original check compared against.
        expected = torch.arange(position_ids.shape[-1], dtype=torch.int64, device=position_ids.device)
        # Compare as int64, as before; the cast tolerates a non-int64 buffer dtype.
        if torch.all(torch.eq(position_ids.to(torch.int64), expected)).item():
            print('CLIP IDs tensor OK!')
            return

        # Write into the model's buffer itself. Rebinding a local name would
        # leave the model untouched. copy_ broadcasts `expected` over leading
        # dims and converts dtype/device to match the destination.
        with torch.no_grad():
            position_ids.copy_(expected)
        print('CLIP IDs tensor repaired!')
    except Exception as e:
        # Broad catch kept from the original: this runs inside the model-load
        # callback, presumably so a failed check cannot interrupt loading.
        print('Exception thrown when trying to verify/fix CLIP tensor: ', e)
def on_ui_settings():
    """Add the opt-in "CLIP Tensor Fix" toggle (default off) to the settings UI."""
    toggle = shared.OptionInfo(  # type: ignore
        False,
        'Enable CLIP tensor fix on model load',
        section=('clip_tensor_fix', 'CLIP Tensor Fix'),
    )
    shared.opts.add_option('clip_tensor_fix_enabled', toggle)  # type: ignore
# Hook the handlers into the WebUI callback registry (model-load check first,
# then the settings-page option).
for _register, _handler in (
    (script_callbacks.on_model_loaded, on_model_loaded),
    (script_callbacks.on_ui_settings, on_ui_settings),
):
    _register(_handler)
del _register, _handler
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.