@city96
Created August 2, 2024 13:33
```python
# Force model to always use the specified device
# Place in `ComfyUI\custom_nodes` to use
# City96 [Apache2]
#
import types
import torch
import comfy.model_management


class OverrideDevice:
    @classmethod
    def INPUT_TYPES(s):
        devices = ["cpu",]
        for k in range(0, torch.cuda.device_count()):
            devices.append(f"cuda:{k}")
        return {
            "required": {
                "device": (devices, {"default":"cpu"}),
            }
        }

    FUNCTION = "patch"
    CATEGORY = "other"

    def override(self, model, model_attr, device):
        # set model/patcher attributes
        model.device = device
        patcher = getattr(model, "patcher", model) #.clone()
        for name in ["device", "load_device", "offload_device", "current_device", "output_device"]:
            setattr(patcher, name, device)

        # move model to device (restore the real .to first, in case it was already patched)
        py_model = getattr(model, model_attr)
        py_model.to = types.MethodType(torch.nn.Module.to, py_model)
        py_model.to(device)

        # remove ability to move model: .to becomes a no-op that still
        # returns the module, matching the return value of torch.nn.Module.to
        def to(self, *args, **kwargs):
            return self
        py_model.to = types.MethodType(to, py_model)
        return (model,)

    def patch(self, *args, **kwargs):
        raise NotImplementedError


class OverrideCLIPDevice(OverrideDevice):
    @classmethod
    def INPUT_TYPES(s):
        k = super().INPUT_TYPES()
        k["required"]["clip"] = ("CLIP",)
        return k

    RETURN_TYPES = ("CLIP",)
    TITLE = "Force/Set CLIP Device"

    def patch(self, clip, device):
        return self.override(clip, "cond_stage_model", torch.device(device))


class OverrideVAEDevice(OverrideDevice):
    @classmethod
    def INPUT_TYPES(s):
        k = super().INPUT_TYPES()
        k["required"]["vae"] = ("VAE",)
        return k

    RETURN_TYPES = ("VAE",)
    TITLE = "Force/Set VAE Device"

    def patch(self, vae, device):
        return self.override(vae, "first_stage_model", torch.device(device))


NODE_CLASS_MAPPINGS = {
    "OverrideCLIPDevice": OverrideCLIPDevice,
    "OverrideVAEDevice": OverrideVAEDevice,
}
NODE_DISPLAY_NAME_MAPPINGS = {k:v.TITLE for k,v in NODE_CLASS_MAPPINGS.items()}
```
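The core of `override` is a per-instance method swap: move the module once with the real `.to`, then shadow `.to` on that one instance with a no-op so later calls from ComfyUI's memory management can't move it. The torch-free sketch below shows the same pattern. `FakeModule` and `pin_to_device` are hypothetical names, with `FakeModule` standing in for `torch.nn.Module` and only tracking a device string.

```python
import types


class FakeModule:
    """Hypothetical stand-in for torch.nn.Module; it only records a device name."""

    def __init__(self):
        self.device = "cpu"

    def to(self, device):
        self.device = device
        return self


def pin_to_device(module, device):
    # Re-bind the class-level .to first, so a module that was pinned earlier
    # can still be moved once more (same idea as re-binding torch.nn.Module.to).
    module.to = types.MethodType(FakeModule.to, module)
    module.to(device)

    # Shadow .to on this instance only; other instances are unaffected.
    # The no-op returns the module so `m = m.to(...)` keeps a valid reference.
    def noop_to(self, *args, **kwargs):
        return self

    module.to = types.MethodType(noop_to, module)
    return module


m = pin_to_device(FakeModule(), "cuda:1")
m.to("cpu")                    # ignored by the no-op
print(m.device)                # cuda:1
m = pin_to_device(m, "cpu")    # explicitly re-pinning still works
print(m.device)                # cpu
```

Patching the instance rather than the class keeps the change local to the one CLIP or VAE model the node receives.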
jdc4429 commented Aug 6, 2024

My menu in ComfyUI is very large because I have over 300 added nodes, and I couldn't find where these nodes are located to add them to a workflow. For anyone else looking for them, I finally found them under ExtraModels/Other. A new question: can you set which GPU loads the main checkpoint? I'm not sure how it decides which GPU to use for this when none is specified.

city96 (author) commented Aug 6, 2024

By default, the nodes are under the "other" category. You can also double-click anywhere and search for "override" to find them.

ExtraModels/other is the category if you have the Extra Models node pack installed. I've also added this node to that pack, so there's no need for this script if you already have it installed; this is just for people who prefer to have the function as a standalone node.

Selecting the default device for the UNet is a bit trickier, since you can't use the launch argument or CUDA_VISIBLE_DEVICES (both hide all other GPUs), and moving it with a node would probably break a bunch of stuff. Maybe raise an issue on the main ComfyUI repo about the device selection logic?

Also, here's a simple example workflow: offload_workflow_base.json
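It helps to model what `CUDA_VISIBLE_DEVICES` does to see why it is too blunt for picking the UNet's GPU. CUDA enumerates only the listed GPUs, renumbers them from `cuda:0`, and stops at the first invalid entry. The hidden GPUs then don't exist for the whole process. Because the override nodes build their dropdown from `torch.cuda.device_count()`, a masked GPU can't be chosen for CLIP or VAE either. The helper `visible_device_map` below is hypothetical (not part of ComfyUI or PyTorch). It is a simplified model that handles only comma-separated integer indices, not UUIDs or MIG entries.

```python
def visible_device_map(physical_count, cuda_visible_devices=None):
    """Simplified, hypothetical model of CUDA_VISIBLE_DEVICES masking.

    Returns {logical device name: physical GPU index}. Only plain
    comma-separated integer indices are modeled (no UUIDs or MIG entries).
    """
    if cuda_visible_devices is None:
        # Variable unset: every physical GPU is visible in its natural order.
        order = list(range(physical_count))
    else:
        order = []
        for token in cuda_visible_devices.split(","):
            token = token.strip()
            if not token.isdigit() or int(token) >= physical_count:
                break  # enumeration stops at the first invalid index
            order.append(int(token))
    # Visible GPUs are renumbered from 0 in the order they were listed.
    return {f"cuda:{i}": phys for i, phys in enumerate(order)}


# Unmasked: the node's dropdown would offer both GPUs.
print(visible_device_map(2))
# Masked to GPU 1: it becomes cuda:0, and GPU 0 vanishes for the whole process.
print(visible_device_map(2, "1"))
```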

@kobechenyang

Is it possible to offload loading the ControlNet model to a second GPU?

jdc4429 commented Aug 13, 2024 via email

jdc4429 commented Aug 13, 2024

[Image attachment: unnamed]
