Created
April 18, 2024 15:48
-
-
Save Sam-Izdat/eb24e1f0f304d7794da8f5669545669c to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import random | |
import torch | |
import numpy as np | |
import typing | |
from typing import Union | |
from combomethod import combomethod | |
from tinycio import fsio | |
from .util import * | |
from .resampling import Resampling | |
class Atlas: | |
"""Texture atlas packing and sampling.""" | |
min_auto_size = 64 | |
max_auto_size = 8192 | |
auto_force_square = False | |
def __init__(self, min_auto_size:int=64, max_auto_size:int=8192, auto_force_square:bool=False): | |
self.min_auto_size = min_auto_size | |
self.max_auto_size = max_auto_size | |
self.auto_force_square = auto_force_square | |
class _TextureRect: | |
def __init__(self, tensor, idx, key): | |
self.tensor = tensor | |
self.idx = idx | |
self.key = key | |
self.x = 0 | |
self.y = 0 | |
self.was_packed = False | |
err_out_of_bounds = 'failed to to fit textures into atlas' | |
@combomethod | |
def pack(cls, | |
textures:dict, | |
max_height:int=0, | |
max_width:int=0, | |
auto_crop:bool=True, | |
row_pack:bool=False, | |
sort:str='height') -> (torch.Tensor, tuple): | |
H, W = max_height, max_width | |
auto_crop = auto_crop and not cls.auto_force_square | |
if W == 0 or H == 0: | |
i = 0 | |
auto_width, auto_height = W, H | |
while auto_height < cls.max_auto_size and auto_width < cls.max_auto_size: | |
if cls.auto_force_square: | |
auto_height = (H or int(next_pot(cls.min_auto_size)) << i) | |
auto_width = (W or int(next_pot(cls.min_auto_size)) << i) | |
else: | |
if H == 0 and (W != 0 or auto_width > auto_height): | |
auto_height = (H or int(next_pot(cls.min_auto_size)) << i) | |
elif W == 0 and (H != 0 or auto_height >= auto_width): | |
auto_width = (W or int(next_pot(cls.min_auto_size)) << i) | |
else: | |
raise Exception('undefined') # should not happen | |
atlas, index = cls.__row_pack(textures, (auto_height, auto_width), sort=sort, auto_crop=auto_crop) if row_pack \ | |
else cls.__rect_pack(textures, (auto_height, auto_width), sort=sort, auto_crop=auto_crop) | |
if not atlas is False: | |
return atlas, index | |
i += 1 | |
raise Exception(cls.err_out_of_bounds + f" at w {auto_width} h {auto_height}") | |
return cls.__row_pack(textures, (H, W), auto_crop=auto_crop, sort=sort, must_succeed=True) if row_pack \ | |
else cls.__rect_pack(textures, (H, W), auto_crop=auto_crop, sort=sort, must_succeed=True) | |
@combomethod | |
def pack_dir(cls, | |
dp:str, | |
ext:str='.png', | |
max_height:int=0, | |
max_width:int=0, | |
auto_crop:bool=True, | |
row_pack:bool=False, | |
sort:str='height', | |
channels:int=3, | |
allow_channel_mismatch=True) -> (torch.Tensor, list): | |
""" | |
Pack atlas with image files in target directory. Images base filenames will be used as keys. Non-recursive. | |
""" | |
textures = {} | |
for fn in os.listdir(dp): | |
fp = os.path.join(dp, fn) | |
fnext = os.path.splitext(fn)[1] | |
if os.path.isfile(fp) and (fnext == ext): | |
im = fsio.load_image(fp) | |
if allow_channel_mismatch: | |
if im.size(0) < channels: im = im.repeat(channels, 1, 1) | |
if im.size(0) > channels: im = im[0:channels, ...] | |
else: | |
assert channels == im.size(0), f"channel size mismatch for: {fn} - expected {channels}, got {im.size(0)}" | |
fnwe = os.path.splitext(os.path.basename(fp))[0] | |
textures[fnwe] = im | |
assert len(textures) > 0, "did not find any textures" | |
return cls.pack(textures=textures, max_height=max_height, max_width=max_width, auto_crop=auto_crop, row_pack=row_pack, sort=sort) | |
@combomethod | |
def sample(cls, atlas:torch.Tensor, index:dict, key:Union[str, int]): | |
assert len(index) > 0, "index is empty" | |
if isinstance(key, str) and key in index: | |
x0, y0, x1, y1 = index[key] | |
return atlas[:, int(y0):int(y1), int(x0):int(x1)] | |
elif isinstance(key, int): | |
x0, y0, x1, y1 = list(index.values())[key] | |
return atlas[:, int(y0):int(y1), int(x0):int(x1)] | |
else: | |
raise KeyError("key not found in index") | |
@combomethod | |
def sample_random(cls, atlas:torch.Tensor, index:dict): | |
assert len(index) > 0, "index is empty" | |
x0, y0, x1, y1 = random.choice(list(index.values())) | |
return atlas[:, int(y0):int(y1), int(x0):int(x1)] | |
@classmethod | |
def __sp_push_back(cls, spaces, space): | |
return space if spaces == None else torch.cat([spaces, space], dim=0) | |
@classmethod | |
def __sp_rem(cls, spaces, idx): | |
return torch.cat((spaces[:idx], spaces[idx+1:])) | |
# https://github.com/TeamHypersomnia/rectpack2D?tab=readme-ov-file#algorithm | |
# A bit slower. Suitable for high variance. | |
@classmethod | |
def __rect_pack(cls, | |
textures:dict, | |
shape:tuple, | |
auto_crop:bool=True, | |
sort:str='height', | |
must_succeed:bool=False) -> (torch.Tensor, tuple): | |
texture_rects = [] | |
max_w, max_h = 0, 0 | |
for k, v in enumerate(textures): | |
texture_rects.append(cls._TextureRect(textures[v], idx=k, key=v)) | |
atlas_height = shape[0] | |
atlas_width = shape[1] | |
atlas = torch.zeros(texture_rects[0].tensor.size(0), atlas_height, atlas_width) | |
# x0, y0, x1, y1 | |
empty_spaces = None | |
empty_spaces = cls.__sp_push_back(empty_spaces, torch.Tensor([[0, 0, atlas_width, atlas_height]])) | |
# Sort textures in descending order | |
if sort == 'height': | |
texture_rects.sort(key=lambda tex: tex.tensor.size(1), reverse=True) | |
elif sort == 'width': | |
texture_rects.sort(key=lambda tex: tex.tensor.size(2), reverse=True) | |
elif sort == 'area': | |
texture_rects.sort(key=lambda tex: (tex.tensor.size(1) * tex.tensor.size(2)), reverse=True) | |
else: | |
raise Exception(f'unrecognized sort order: {sort}') | |
for i, tex in enumerate(texture_rects): | |
tex_h, tex_w = tex.tensor.shape[1:] | |
best_fit_area = None | |
best_fit_idx = None | |
for space_idx in range(empty_spaces.size(0)): | |
space_idx = empty_spaces.size(0) - 1 - space_idx | |
space = empty_spaces[space_idx:space_idx+1,...] | |
sp_w, sp_h = space[0,2].item(), space[0,3].item() | |
if sp_w >= tex_w and sp_h >= tex_h: | |
if best_fit_area == None or best_fit_area > sp_w * sp_h: | |
best_fit_area = sp_w * sp_h | |
best_fit_idx = space_idx | |
if best_fit_idx == None: | |
if must_succeed: | |
raise Exception(cls.err_out_of_bounds + f" at w {atlas_width} h {atlas_height}") | |
else: | |
return False, False | |
space = empty_spaces[best_fit_idx:best_fit_idx+1,...] | |
sp_x, sp_y = space[0,0].item(), space[0,1].item() | |
sp_w, sp_h = space[0,2].item(), space[0,3].item() | |
atlas[..., | |
int(sp_y):int(sp_y+tex_h), | |
int(sp_x):int(sp_x+tex_w)] = tex.tensor | |
tex.x = sp_x | |
tex.y = sp_y | |
tex.was_packed = True | |
if sp_w > tex_w and sp_h > tex_h: | |
split1 = torch.Tensor([[ | |
sp_x, | |
sp_y+tex_h, | |
sp_w, | |
sp_h-tex_h]]) | |
split2 = torch.Tensor([[ | |
sp_x+tex_w, | |
sp_y, | |
sp_w-tex_w, | |
tex_h]]) | |
empty_spaces = cls.__sp_rem(empty_spaces, best_fit_idx) | |
if split2[0,2].item()*split2[0,3].item() > split1[0,2].item()*split1[0,3].item(): | |
empty_spaces = cls.__sp_push_back(empty_spaces, split2) | |
empty_spaces = cls.__sp_push_back(empty_spaces, split1) | |
else: | |
empty_spaces = cls.__sp_push_back(empty_spaces, split1) | |
empty_spaces = cls.__sp_push_back(empty_spaces, split2) | |
elif sp_w > tex_w: | |
split = torch.Tensor([[ | |
sp_x+tex_w, | |
sp_y, | |
sp_w-tex_w, | |
tex_h]]) | |
empty_spaces = cls.__sp_rem(empty_spaces, best_fit_idx) | |
empty_spaces = cls.__sp_push_back(empty_spaces, split) | |
elif sp_h > tex_h: | |
split = torch.Tensor([[ | |
sp_x, | |
sp_y+tex_h, | |
sp_w, | |
sp_h-tex_h]]) | |
empty_spaces = cls.__sp_rem(empty_spaces, best_fit_idx) | |
empty_spaces = cls.__sp_push_back(empty_spaces, split) | |
elif sp_h == tex_h and sp_w == tex_w: | |
empty_spaces = cls.__sp_rem(empty_spaces, best_fit_idx) | |
else: | |
raise Exception(cls.err_out_of_bounds, + f" at w {atlas_width} h {atlas_height}") | |
# Sort textures by input order | |
texture_rects.sort(key=lambda tex: tex.idx) | |
index = {} | |
for tex in texture_rects: | |
index[tex.key] = (tex.x, tex.y, tex.x+tex.tensor.size(2), tex.y+tex.tensor.size(1)) | |
if (tex.x+tex.tensor.size(2)) > max_w: max_w = tex.x+tex.tensor.size(2) | |
if (tex.y+tex.tensor.size(1)) > max_h: max_h = tex.y+tex.tensor.size(1) | |
if auto_crop: atlas = Resampling.crop(atlas, (int(max_h), int(max_w))) | |
return atlas, index | |
# https://www.david-colson.com/2020/03/10/exploring-rect-packing.html | |
# Faster. Suitable for low variance. | |
@classmethod | |
def __row_pack(cls, | |
textures:dict, | |
shape:tuple, | |
auto_crop:bool=True, | |
sort:str='height', | |
must_succeed:bool=False) -> (torch.Tensor, tuple): | |
texture_rects = [] | |
max_w, max_h = 0, 0 | |
for k, v in enumerate(textures): | |
texture_rects.append(cls._TextureRect(textures[v], idx=k, key=v)) | |
# Sort textures in descending order | |
if sort == 'height': | |
texture_rects.sort(key=lambda tex: tex.tensor.size(1), reverse=True) | |
elif sort == 'width': | |
texture_rects.sort(key=lambda tex: tex.tensor.size(2), reverse=True) | |
elif sort == 'area': | |
texture_rects.sort(key=lambda tex: (tex.tensor.size(1) * tex.tensor.size(2)), reverse=True) | |
else: | |
raise Exception(f'unrecognized sort order: {sort}') | |
atlas_height = shape[0] | |
atlas_width = shape[1] | |
atlas = torch.zeros(texture_rects[0].tensor.size(0), atlas_height, atlas_width) | |
x_pos = 0 | |
y_pos = 0 | |
largest_height_this_row = 0 | |
# Loop over all the textures | |
for tex in texture_rects: | |
tex_h, tex_w = tex.tensor.shape[1:] | |
# If this texture will go past the width of the atlas, | |
# loop around to the next row, using the largest height from the previous row | |
if (x_pos + tex.tensor.size(2)) > atlas_width: | |
y_pos = y_pos + largest_height_this_row | |
x_pos = 0 | |
largest_height_this_row = 0 | |
if (y_pos + tex_h) > atlas_height or (x_pos + tex_w) > atlas_width: | |
if must_succeed: | |
raise Exception(cls.err_out_of_bounds, + f" at w {atlas_width} h {atlas_height}") | |
else: | |
return False, False | |
tex.x = x_pos | |
tex.y = y_pos | |
atlas[:, y_pos:y_pos + tex_h, x_pos:x_pos + tex_w] = tex.tensor | |
x_pos += tex_w | |
# Save largest height in the new row | |
if tex_h > largest_height_this_row: | |
largest_height_this_row = tex_h | |
tex.was_packed = True | |
# Sort textures by input order | |
index = {} | |
for tex in texture_rects: | |
index[tex.key] = (tex.x, tex.y, tex.x+tex.tensor.size(2), tex.y+tex.tensor.size(1)) | |
if (tex.x+tex.tensor.size(2)) > max_w: max_w = tex.x+tex.tensor.size(2) | |
if (tex.y+tex.tensor.size(1)) > max_h: max_h = tex.y+tex.tensor.size(1) | |
if auto_crop: atlas = Resampling.crop(atlas, (max_h, max_w)) | |
return atlas, index |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment