mirror of
https://github.com/microsoft/TRELLIS.2.git
synced 2026-04-02 02:27:08 -04:00
update texturing pipeline
This commit is contained in:
@@ -49,7 +49,7 @@ Data processing is streamlined for instant conversions that are fully **renderin
|
||||
- [x] Release image-to-3D inference code
|
||||
- [x] Release pretrained checkpoints (4B)
|
||||
- [x] Hugging Face Spaces demo
|
||||
- [ ] Release shape-conditioned texture generation inference code (Current schedule: before 12/24/2025)
|
||||
- [x] Release shape-conditioned texture generation inference code
|
||||
- [ ] Release training code (Current schedule: before 12/31/2025)
|
||||
|
||||
|
||||
@@ -184,7 +184,7 @@ Then, you can access the demo at the address shown in the terminal.
|
||||
|
||||
### 2. PBR Texture Generation
|
||||
|
||||
Will be released soon. Please stay tuned!
|
||||
Please refer to the [example_texturing.py](example_texturing.py) for an example of how to generate PBR textures for a given 3D shape. Also, you can use the [app_texturing.py](app_texturing.py) to run a web demo for PBR texture generation.
|
||||
|
||||
## 🧩 Related Packages
|
||||
|
||||
|
||||
151
app_texturing.py
Normal file
151
app_texturing.py
Normal file
@@ -0,0 +1,151 @@
|
||||
import gradio as gr
|
||||
|
||||
import os
|
||||
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
|
||||
from datetime import datetime
|
||||
import shutil
|
||||
from typing import *
|
||||
import torch
|
||||
import numpy as np
|
||||
import trimesh
|
||||
from PIL import Image
|
||||
from trellis2.pipelines import Trellis2TexturingPipeline
|
||||
|
||||
|
||||
MAX_SEED = np.iinfo(np.int32).max  # largest 32-bit signed int; upper bound for random seeds

TMP_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'tmp')  # per-session scratch directory root
|
||||
|
||||
|
||||
def start_session(req: gr.Request):
    """Create the per-session scratch directory when a client connects."""
    session_dir = os.path.join(TMP_DIR, str(req.session_hash))
    os.makedirs(session_dir, exist_ok=True)
|
||||
|
||||
|
||||
def end_session(req: gr.Request):
    """Delete the user's temporary directory when their session ends.

    Args:
        req (gr.Request): The Gradio request carrying the session hash.
    """
    user_dir = os.path.join(TMP_DIR, str(req.session_hash))
    # ignore_errors: the directory may never have been created (session ended
    # before any generation) or may already have been cleaned up — a bare
    # rmtree would raise FileNotFoundError and surface as an unload error.
    shutil.rmtree(user_dir, ignore_errors=True)
|
||||
|
||||
|
||||
def preprocess_image(image: Image.Image) -> Image.Image:
    """
    Run the texturing pipeline's image preprocessing on an uploaded prompt.

    Args:
        image (Image.Image): The raw uploaded image.

    Returns:
        Image.Image: The preprocessed image, ready for conditioning.
    """
    return pipeline.preprocess_image(image)
|
||||
|
||||
|
||||
def get_seed(randomize_seed: bool, seed: int) -> int:
    """Return a freshly drawn seed when randomization is enabled, else the given seed."""
    if randomize_seed:
        return np.random.randint(0, MAX_SEED)
    return seed
|
||||
|
||||
|
||||
def shapeimage_to_tex(
    mesh_file: str,
    image: Image.Image,
    seed: int,
    resolution: str,
    texture_size: int,
    tex_slat_guidance_strength: float,
    tex_slat_guidance_rescale: float,
    tex_slat_sampling_steps: int,
    tex_slat_rescale_t: float,
    req: gr.Request,
    progress=gr.Progress(track_tqdm=True),
) -> tuple[str, str]:
    """
    Generate PBR textures for an uploaded mesh conditioned on an image prompt.

    Args:
        mesh_file (str): Path to the uploaded mesh (.ply/.obj/.glb/.gltf).
        image (Image.Image): The image prompt (already preprocessed on upload).
        seed (int): Random seed for sampling.
        resolution (str): Voxel grid resolution as a string ("512"/"1024"/"1536").
        texture_size (int): Output texture resolution in pixels.
        tex_slat_guidance_strength (float): Guidance strength for the texture latent sampler.
        tex_slat_guidance_rescale (float): Guidance rescale for the texture latent sampler.
        tex_slat_sampling_steps (int): Number of sampler steps.
        tex_slat_rescale_t (float): Timestep rescale parameter for the sampler.
        req (gr.Request): Gradio request, used to locate the per-session directory.
        progress: Gradio progress tracker (mirrors tqdm bars from the pipeline).

    Returns:
        tuple[str, str]: The exported GLB path twice — once for the 3D viewer
        and once for the download button.
    """
    mesh = trimesh.load(mesh_file)
    # Multi-geometry scenes are flattened into a single mesh before texturing.
    if isinstance(mesh, trimesh.Scene):
        mesh = mesh.to_mesh()
    output = pipeline.run(
        mesh,
        image,
        seed=seed,
        preprocess_image=False,  # image was already preprocessed by the upload handler
        tex_slat_sampler_params={
            "steps": tex_slat_sampling_steps,
            "guidance_strength": tex_slat_guidance_strength,
            "guidance_rescale": tex_slat_guidance_rescale,
            "rescale_t": tex_slat_rescale_t,
        },
        resolution=int(resolution),
        texture_size=texture_size,
    )
    # Millisecond-resolution timestamp keeps filenames unique within a session.
    now = datetime.now()
    timestamp = now.strftime("%Y-%m-%dT%H%M%S") + f".{now.microsecond // 1000:03d}"
    user_dir = os.path.join(TMP_DIR, str(req.session_hash))
    os.makedirs(user_dir, exist_ok=True)
    glb_path = os.path.join(user_dir, f'sample_{timestamp}.glb')
    output.export(glb_path, extension_webp=True)
    torch.cuda.empty_cache()
    return glb_path, glb_path
|
||||
|
||||
|
||||
# Gradio UI definition.
# NOTE(review): delete_cache=(600, 600) presumably clears cached files every
# 600 s, deleting those older than 600 s — confirm against Gradio Blocks docs.
with gr.Blocks(delete_cache=(600, 600)) as demo:
    gr.Markdown("""
    ## Texturing a mesh with [TRELLIS.2](https://microsoft.github.io/TRELLIS.2)
    * Upload a mesh and corresponding reference image (preferably with an alpha-masked foreground object) and click Generate to create a textured 3D asset.
    """)

    with gr.Row():
        # Left column: inputs and generation controls.
        with gr.Column(scale=1, min_width=360):
            mesh_file = gr.File(label="Upload Mesh", file_types=[".ply", ".obj", ".glb", ".gltf"], file_count="single")
            image_prompt = gr.Image(label="Image Prompt", format="png", image_mode="RGBA", type="pil", height=400)

            resolution = gr.Radio(["512", "1024", "1536"], label="Resolution", value="1024")
            seed = gr.Slider(0, MAX_SEED, label="Seed", value=0, step=1)
            randomize_seed = gr.Checkbox(label="Randomize Seed", value=True)
            texture_size = gr.Slider(1024, 4096, label="Texture Size", value=2048, step=1024)

            generate_btn = gr.Button("Generate")

            with gr.Accordion(label="Advanced Settings", open=False):
                with gr.Row():
                    tex_slat_guidance_strength = gr.Slider(1.0, 10.0, label="Guidance Strength", value=1.0, step=0.1)
                    tex_slat_guidance_rescale = gr.Slider(0.0, 1.0, label="Guidance Rescale", value=0.0, step=0.01)
                    tex_slat_sampling_steps = gr.Slider(1, 50, label="Sampling Steps", value=12, step=1)
                    tex_slat_rescale_t = gr.Slider(1.0, 6.0, label="Rescale T", value=3.0, step=0.1)

        # Right column: 3D viewer and download.
        with gr.Column(scale=10):
            glb_output = gr.Model3D(label="Extracted GLB", height=724, show_label=True, display_mode="solid", clear_color=(0.25, 0.25, 0.25, 1.0))
            download_btn = gr.DownloadButton(label="Download GLB")


    # Handlers
    # Per-session temp dirs are created on connect and removed on disconnect.
    demo.load(start_session)
    demo.unload(end_session)

    # Preprocess the prompt image in place as soon as it is uploaded.
    image_prompt.upload(
        preprocess_image,
        inputs=[image_prompt],
        outputs=[image_prompt],
    )

    # Resolve the seed first (so a randomized seed is shown in the UI),
    # then run texturing with the resolved value.
    generate_btn.click(
        get_seed,
        inputs=[randomize_seed, seed],
        outputs=[seed],
    ).then(
        shapeimage_to_tex,
        inputs=[
            mesh_file, image_prompt, seed, resolution, texture_size,
            tex_slat_guidance_strength, tex_slat_guidance_rescale, tex_slat_sampling_steps, tex_slat_rescale_t,
        ],
        outputs=[glb_output, download_btn],
    )
|
||||
|
||||
|
||||
# Launch the Gradio app
if __name__ == "__main__":
    os.makedirs(TMP_DIR, exist_ok=True)

    # `pipeline` is a module-level global referenced by the handlers above;
    # it must be constructed before demo.launch() starts serving requests.
    pipeline = Trellis2TexturingPipeline.from_pretrained('microsoft/TRELLIS.2-4B', config_file="texturing_pipeline.json")
    pipeline.cuda()

    demo.launch()
|
||||
BIN
assets/example_texturing/image.webp
Normal file
BIN
assets/example_texturing/image.webp
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 62 KiB |
11
assets/example_texturing/readme
Normal file
11
assets/example_texturing/readme
Normal file
@@ -0,0 +1,11 @@
|
||||
## Asset Information
|
||||
|
||||
* Title: The Forgotten Knight
|
||||
* Author: dark_igorek
|
||||
* Source: https://sketchfab.com/3d-models/the-forgotten-knight-d14eb14d83bd4e7ba7cbe443d76a10fd
|
||||
* License: Creative Commons Attribution (CC BY)
|
||||
|
||||
## Usage
|
||||
|
||||
The asset is used for research purposes only.
|
||||
Please credit the original author and include the Sketchfab link when using or redistributing this model.
|
||||
BIN
assets/example_texturing/the_forgotten_knight.ply
Normal file
BIN
assets/example_texturing/the_forgotten_knight.ply
Normal file
Binary file not shown.
17
example_texturing.py
Normal file
17
example_texturing.py
Normal file
@@ -0,0 +1,17 @@
|
||||
"""Minimal example: generate PBR textures for a given shape with TRELLIS.2."""
import os
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"  # Can save GPU memory
import trimesh
from PIL import Image
from trellis2.pipelines import Trellis2TexturingPipeline


# 1. Load the texturing pipeline from the pretrained checkpoint and move it to GPU.
pipeline = Trellis2TexturingPipeline.from_pretrained("microsoft/TRELLIS.2-4B", config_file="texturing_pipeline.json")
pipeline.cuda()

# 2. Load the untextured mesh and the reference image, then run texturing.
mesh = trimesh.load("assets/example_texturing/the_forgotten_knight.ply")
image = Image.open("assets/example_texturing/image.webp")
output = pipeline.run(mesh, image)

# 3. Export the textured mesh as GLB (textures stored as WebP).
output.export("textured.glb", extension_webp=True)
|
||||
@@ -2,8 +2,7 @@ import importlib
|
||||
|
||||
__attributes = {
|
||||
"Trellis2ImageTo3DPipeline": "trellis2_image_to_3d",
|
||||
"Trellis2ImageTo3DCascadePipeline": "trellis2_image_to_3d_cascade",
|
||||
"Trellis2ImageToTexturePipeline": "trellis2_image_to_tex",
|
||||
"Trellis2TexturingPipeline": "trellis2_texturing",
|
||||
}
|
||||
|
||||
__submodules = ['samplers', 'rembg']
|
||||
@@ -49,7 +48,5 @@ def from_pretrained(path: str):
|
||||
# For PyLance
|
||||
if __name__ == '__main__':
|
||||
from . import samplers, rembg
|
||||
from .trellis_image_to_3d import TrellisImageTo3DPipeline
|
||||
from .trellis2_image_to_3d import Trellis2ImageTo3DPipeline
|
||||
from .trellis2_image_to_3d_cascade import Trellis2ImageTo3DCascadePipeline
|
||||
from .trellis2_image_to_tex import Trellis2ImageToTexturePipeline
|
||||
from .trellis2_texturing import Trellis2TexturingPipeline
|
||||
|
||||
@@ -18,32 +18,34 @@ class Pipeline:
|
||||
for model in self.models.values():
|
||||
model.eval()
|
||||
|
||||
@staticmethod
|
||||
def from_pretrained(path: str) -> "Pipeline":
|
||||
@classmethod
|
||||
def from_pretrained(cls, path: str, config_file: str = "pipeline.json") -> "Pipeline":
|
||||
"""
|
||||
Load a pretrained model.
|
||||
"""
|
||||
import os
|
||||
import json
|
||||
is_local = os.path.exists(f"{path}/pipeline.json")
|
||||
is_local = os.path.exists(f"{path}/{config_file}")
|
||||
|
||||
if is_local:
|
||||
config_file = f"{path}/pipeline.json"
|
||||
config_file = f"{path}/{config_file}"
|
||||
else:
|
||||
from huggingface_hub import hf_hub_download
|
||||
config_file = hf_hub_download(path, "pipeline.json")
|
||||
config_file = hf_hub_download(path, config_file)
|
||||
|
||||
with open(config_file, 'r') as f:
|
||||
args = json.load(f)['args']
|
||||
|
||||
_models = {}
|
||||
for k, v in args['models'].items():
|
||||
if hasattr(cls, 'model_names_to_load') and k not in cls.model_names_to_load:
|
||||
continue
|
||||
try:
|
||||
_models[k] = models.from_pretrained(f"{path}/{v}")
|
||||
except Exception as e:
|
||||
_models[k] = models.from_pretrained(v)
|
||||
|
||||
new_pipeline = Pipeline(_models)
|
||||
new_pipeline = cls(_models)
|
||||
new_pipeline._pretrained_args = args
|
||||
return new_pipeline
|
||||
|
||||
|
||||
@@ -28,6 +28,17 @@ class Trellis2ImageTo3DPipeline(Pipeline):
|
||||
rembg_model (Callable): The model for removing background.
|
||||
low_vram (bool): Whether to use low-VRAM mode.
|
||||
"""
|
||||
model_names_to_load = [
|
||||
'sparse_structure_flow_model',
|
||||
'sparse_structure_decoder',
|
||||
'shape_slat_flow_model_512',
|
||||
'shape_slat_flow_model_1024',
|
||||
'shape_slat_decoder',
|
||||
'tex_slat_flow_model_512',
|
||||
'tex_slat_flow_model_1024',
|
||||
'tex_slat_decoder',
|
||||
]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
models: dict[str, nn.Module] = None,
|
||||
@@ -67,45 +78,43 @@ class Trellis2ImageTo3DPipeline(Pipeline):
|
||||
}
|
||||
self._device = 'cpu'
|
||||
|
||||
@staticmethod
|
||||
def from_pretrained(path: str) -> "Trellis2ImageTo3DPipeline":
|
||||
@classmethod
|
||||
def from_pretrained(cls, path: str, config_file: str = "pipeline.json") -> "Trellis2ImageTo3DPipeline":
|
||||
"""
|
||||
Load a pretrained model.
|
||||
|
||||
Args:
|
||||
path (str): The path to the model. Can be either local path or a Hugging Face repository.
|
||||
"""
|
||||
pipeline = super(Trellis2ImageTo3DPipeline, Trellis2ImageTo3DPipeline).from_pretrained(path)
|
||||
new_pipeline = Trellis2ImageTo3DPipeline()
|
||||
new_pipeline.__dict__ = pipeline.__dict__
|
||||
pipeline = super().from_pretrained(path, config_file)
|
||||
args = pipeline._pretrained_args
|
||||
|
||||
new_pipeline.sparse_structure_sampler = getattr(samplers, args['sparse_structure_sampler']['name'])(**args['sparse_structure_sampler']['args'])
|
||||
new_pipeline.sparse_structure_sampler_params = args['sparse_structure_sampler']['params']
|
||||
pipeline.sparse_structure_sampler = getattr(samplers, args['sparse_structure_sampler']['name'])(**args['sparse_structure_sampler']['args'])
|
||||
pipeline.sparse_structure_sampler_params = args['sparse_structure_sampler']['params']
|
||||
|
||||
new_pipeline.shape_slat_sampler = getattr(samplers, args['shape_slat_sampler']['name'])(**args['shape_slat_sampler']['args'])
|
||||
new_pipeline.shape_slat_sampler_params = args['shape_slat_sampler']['params']
|
||||
pipeline.shape_slat_sampler = getattr(samplers, args['shape_slat_sampler']['name'])(**args['shape_slat_sampler']['args'])
|
||||
pipeline.shape_slat_sampler_params = args['shape_slat_sampler']['params']
|
||||
|
||||
new_pipeline.tex_slat_sampler = getattr(samplers, args['tex_slat_sampler']['name'])(**args['tex_slat_sampler']['args'])
|
||||
new_pipeline.tex_slat_sampler_params = args['tex_slat_sampler']['params']
|
||||
pipeline.tex_slat_sampler = getattr(samplers, args['tex_slat_sampler']['name'])(**args['tex_slat_sampler']['args'])
|
||||
pipeline.tex_slat_sampler_params = args['tex_slat_sampler']['params']
|
||||
|
||||
new_pipeline.shape_slat_normalization = args['shape_slat_normalization']
|
||||
new_pipeline.tex_slat_normalization = args['tex_slat_normalization']
|
||||
pipeline.shape_slat_normalization = args['shape_slat_normalization']
|
||||
pipeline.tex_slat_normalization = args['tex_slat_normalization']
|
||||
|
||||
new_pipeline.image_cond_model = getattr(image_feature_extractor, args['image_cond_model']['name'])(**args['image_cond_model']['args'])
|
||||
new_pipeline.rembg_model = getattr(rembg, args['rembg_model']['name'])(**args['rembg_model']['args'])
|
||||
pipeline.image_cond_model = getattr(image_feature_extractor, args['image_cond_model']['name'])(**args['image_cond_model']['args'])
|
||||
pipeline.rembg_model = getattr(rembg, args['rembg_model']['name'])(**args['rembg_model']['args'])
|
||||
|
||||
new_pipeline.low_vram = args.get('low_vram', True)
|
||||
new_pipeline.default_pipeline_type = args.get('default_pipeline_type', '1024_cascade')
|
||||
new_pipeline.pbr_attr_layout = {
|
||||
pipeline.low_vram = args.get('low_vram', True)
|
||||
pipeline.default_pipeline_type = args.get('default_pipeline_type', '1024_cascade')
|
||||
pipeline.pbr_attr_layout = {
|
||||
'base_color': slice(0, 3),
|
||||
'metallic': slice(3, 4),
|
||||
'roughness': slice(4, 5),
|
||||
'alpha': slice(5, 6),
|
||||
}
|
||||
new_pipeline._device = 'cpu'
|
||||
pipeline._device = 'cpu'
|
||||
|
||||
return new_pipeline
|
||||
return pipeline
|
||||
|
||||
def to(self, device: torch.device) -> None:
|
||||
self._device = device
|
||||
@@ -364,7 +373,6 @@ class Trellis2ImageTo3DPipeline(Pipeline):
|
||||
|
||||
Args:
|
||||
slat (SparseTensor): The structured latent.
|
||||
formats (List[str]): The formats to decode the structured latent to.
|
||||
|
||||
Returns:
|
||||
List[Mesh]: The decoded meshes.
|
||||
@@ -433,10 +441,9 @@ class Trellis2ImageTo3DPipeline(Pipeline):
|
||||
|
||||
Args:
|
||||
slat (SparseTensor): The structured latent.
|
||||
formats (List[str]): The formats to decode the structured latent to.
|
||||
|
||||
Returns:
|
||||
List[SparseTensor]: The decoded texture voxels
|
||||
SparseTensor: The decoded texture voxels
|
||||
"""
|
||||
if self.low_vram:
|
||||
self.models['tex_slat_decoder'].to(self.device)
|
||||
|
||||
408
trellis2/pipelines/trellis2_texturing.py
Executable file
408
trellis2/pipelines/trellis2_texturing.py
Executable file
@@ -0,0 +1,408 @@
|
||||
from typing import *
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
import trimesh
|
||||
from .base import Pipeline
|
||||
from . import samplers, rembg
|
||||
from ..modules.sparse import SparseTensor
|
||||
from ..modules import image_feature_extractor
|
||||
import o_voxel
|
||||
import cumesh
|
||||
import nvdiffrast.torch as dr
|
||||
import cv2
|
||||
import flex_gemm
|
||||
|
||||
|
||||
class Trellis2TexturingPipeline(Pipeline):
    """
    Pipeline for generating PBR textures for a given 3D shape, conditioned on
    an image prompt.

    Args:
        models (dict[str, nn.Module]): The models to use in the pipeline.
        tex_slat_sampler (samplers.Sampler): The sampler for the texture latent.
        tex_slat_sampler_params (dict): The parameters for the texture latent sampler.
        shape_slat_normalization (dict): The normalization parameters for the structured latent.
        tex_slat_normalization (dict): The normalization parameters for the texture latent.
        image_cond_model (Callable): The image conditioning model.
        rembg_model (Callable): The model for removing background.
        low_vram (bool): Whether to use low-VRAM mode (models are moved to the
            device only while in use, then back to CPU).
    """
    # Restricts which checkpoint models the base Pipeline.from_pretrained loads.
    model_names_to_load = [
        'shape_slat_encoder',
        'tex_slat_decoder',
        'tex_slat_flow_model_512',
        'tex_slat_flow_model_1024'
    ]

    def __init__(
        self,
        models: dict[str, nn.Module] = None,
        tex_slat_sampler: samplers.Sampler = None,
        tex_slat_sampler_params: dict = None,
        shape_slat_normalization: dict = None,
        tex_slat_normalization: dict = None,
        image_cond_model: Callable = None,
        rembg_model: Callable = None,
        low_vram: bool = True,
    ):
        # Bare-construction escape hatch: with no models, skip initialization
        # entirely (attributes are filled in later, e.g. by from_pretrained).
        if models is None:
            return
        super().__init__(models)
        self.tex_slat_sampler = tex_slat_sampler
        self.tex_slat_sampler_params = tex_slat_sampler_params
        self.shape_slat_normalization = shape_slat_normalization
        self.tex_slat_normalization = tex_slat_normalization
        self.image_cond_model = image_cond_model
        self.rembg_model = rembg_model
        self.low_vram = low_vram
        # Channel layout of the decoded PBR attribute voxels.
        self.pbr_attr_layout = {
            'base_color': slice(0, 3),
            'metallic': slice(3, 4),
            'roughness': slice(4, 5),
            'alpha': slice(5, 6),
        }
        self._device = 'cpu'

    @classmethod
    def from_pretrained(cls, path: str, config_file: str = "pipeline.json") -> "Trellis2TexturingPipeline":
        """
        Load a pretrained model.

        Args:
            path (str): The path to the model. Can be either local path or a Hugging Face repository.
            config_file (str): Name of the pipeline config JSON inside `path`.
        """
        # The base class loads the models and stores the parsed config in
        # `_pretrained_args`; the remaining attributes are populated here.
        pipeline = super().from_pretrained(path, config_file)
        args = pipeline._pretrained_args

        pipeline.tex_slat_sampler = getattr(samplers, args['tex_slat_sampler']['name'])(**args['tex_slat_sampler']['args'])
        pipeline.tex_slat_sampler_params = args['tex_slat_sampler']['params']

        pipeline.shape_slat_normalization = args['shape_slat_normalization']
        pipeline.tex_slat_normalization = args['tex_slat_normalization']

        pipeline.image_cond_model = getattr(image_feature_extractor, args['image_cond_model']['name'])(**args['image_cond_model']['args'])
        pipeline.rembg_model = getattr(rembg, args['rembg_model']['name'])(**args['rembg_model']['args'])

        pipeline.low_vram = args.get('low_vram', True)
        # Channel layout of the decoded PBR attribute voxels.
        pipeline.pbr_attr_layout = {
            'base_color': slice(0, 3),
            'metallic': slice(3, 4),
            'roughness': slice(4, 5),
            'alpha': slice(5, 6),
        }
        pipeline._device = 'cpu'
        return pipeline

    def to(self, device: torch.device) -> None:
        """Move the pipeline to `device`.

        In low-VRAM mode the heavyweight models stay on CPU and are shuttled
        on demand; the conditioning/rembg models are moved eagerly either way.
        """
        self._device = device
        if not self.low_vram:
            super().to(device)
        self.image_cond_model.to(device)
        if self.rembg_model is not None:
            self.rembg_model.to(device)

    def preprocess_mesh(self, mesh: trimesh.Trimesh) -> trimesh.Trimesh:
        """
        Preprocess the input mesh: normalize it into the [-0.5, 0.5] cube and
        rotate it (y <- -z, z <- y) into the pipeline's coordinate convention.
        """
        vertices = mesh.vertices
        vertices_min = vertices.min(axis=0)
        vertices_max = vertices.max(axis=0)
        center = (vertices_min + vertices_max) / 2
        # 0.99999 keeps vertices strictly inside the unit cube after scaling.
        scale = 0.99999 / (vertices_max - vertices_min).max()
        vertices = (vertices - center) * scale
        tmp = vertices[:, 1].copy()
        vertices[:, 1] = -vertices[:, 2]
        vertices[:, 2] = tmp
        assert np.all(vertices >= -0.5) and np.all(vertices <= 0.5), 'vertices out of range'
        return trimesh.Trimesh(vertices=vertices, faces=mesh.faces, process=False)

    def preprocess_image(self, input: Image.Image) -> Image.Image:
        """
        Preprocess the input image: downscale to at most 1024 px, remove the
        background unless a meaningful alpha channel is present, crop to a
        square around the foreground, and premultiply by alpha.
        """
        # if has alpha channel, use it directly; otherwise, remove background
        has_alpha = False
        if input.mode == 'RGBA':
            alpha = np.array(input)[:, :, 3]
            # A constant-255 alpha channel carries no mask information.
            if not np.all(alpha == 255):
                has_alpha = True
        max_size = max(input.size)
        scale = min(1, 1024 / max_size)
        if scale < 1:
            input = input.resize((int(input.width * scale), int(input.height * scale)), Image.Resampling.LANCZOS)
        if has_alpha:
            output = input
        else:
            input = input.convert('RGB')
            if self.low_vram:
                self.rembg_model.to(self.device)
            output = self.rembg_model(input)
            if self.low_vram:
                self.rembg_model.cpu()
        # Square crop centered on the bounding box of confidently-foreground
        # pixels (alpha > 0.8 * 255).
        output_np = np.array(output)
        alpha = output_np[:, :, 3]
        bbox = np.argwhere(alpha > 0.8 * 255)
        bbox = np.min(bbox[:, 1]), np.min(bbox[:, 0]), np.max(bbox[:, 1]), np.max(bbox[:, 0])
        center = (bbox[0] + bbox[2]) / 2, (bbox[1] + bbox[3]) / 2
        size = max(bbox[2] - bbox[0], bbox[3] - bbox[1])
        size = int(size * 1)
        bbox = center[0] - size // 2, center[1] - size // 2, center[0] + size // 2, center[1] + size // 2
        output = output.crop(bbox)  # type: ignore
        # Premultiply RGB by alpha so the background is black.
        output = np.array(output).astype(np.float32) / 255
        output = output[:, :, :3] * output[:, :, 3:4]
        output = Image.fromarray((output * 255).astype(np.uint8))
        return output

    def get_cond(self, image: Union[torch.Tensor, list[Image.Image]], resolution: int, include_neg_cond: bool = True) -> dict:
        """
        Get the conditioning information for the model.

        Args:
            image (Union[torch.Tensor, list[Image.Image]]): The image prompts.
            resolution (int): Image size fed to the conditioning model.
            include_neg_cond (bool): Whether to also return the negative
                (all-zeros) conditioning for classifier-free guidance.

        Returns:
            dict: The conditioning information
        """
        self.image_cond_model.image_size = resolution
        if self.low_vram:
            self.image_cond_model.to(self.device)
        cond = self.image_cond_model(image)
        if self.low_vram:
            self.image_cond_model.cpu()
        if not include_neg_cond:
            return {'cond': cond}
        neg_cond = torch.zeros_like(cond)
        return {
            'cond': cond,
            'neg_cond': neg_cond,
        }

    def encode_shape_slat(
        self,
        mesh: trimesh.Trimesh,
        resolution: int = 1024,
    ) -> SparseTensor:
        """
        Encode the meshes to structured latent.

        Args:
            mesh (trimesh.Trimesh): The mesh to encode.
            resolution (int): The resolution of mesh

        Returns:
            SparseTensor: The encoded structured latent.
        """
        vertices = torch.from_numpy(mesh.vertices).float()
        faces = torch.from_numpy(mesh.faces).long()

        # Voxelize the mesh into a flexible dual grid on CPU.
        voxel_indices, dual_vertices, intersected = o_voxel.convert.mesh_to_flexible_dual_grid(
            vertices.cpu(), faces.cpu(),
            grid_size=resolution,
            aabb=[[-0.5,-0.5,-0.5],[0.5,0.5,0.5]],
            face_weight=1.0,
            boundary_weight=0.2,
            regularization_weight=1e-2,
            timing=True,
        )

        # Features are dual-vertex offsets relative to their voxel; coords get
        # a leading zeros column (batch index 0).
        vertices = SparseTensor(
            feats=dual_vertices * resolution - voxel_indices,
            coords=torch.cat([torch.zeros_like(voxel_indices[:, 0:1]), voxel_indices], dim=-1)
        ).to(self.device)
        intersected = vertices.replace(intersected).to(self.device)

        if self.low_vram:
            self.models['shape_slat_encoder'].to(self.device)
        shape_slat = self.models['shape_slat_encoder'](vertices, intersected)
        if self.low_vram:
            self.models['shape_slat_encoder'].cpu()
        return shape_slat

    def sample_tex_slat(
        self,
        cond: dict,
        flow_model,
        shape_slat: SparseTensor,
        sampler_params: dict = {},
    ) -> SparseTensor:
        """
        Sample structured latent with the given conditioning.

        Args:
            cond (dict): The conditioning information.
            flow_model: The flow model (or a sequence of them) to sample with.
            shape_slat (SparseTensor): The structured latent for shape
            sampler_params (dict): Additional parameters for the sampler.
        """
        # Sample structured latent
        # Normalize the shape latent before using it as concat conditioning.
        std = torch.tensor(self.shape_slat_normalization['std'])[None].to(shape_slat.device)
        mean = torch.tensor(self.shape_slat_normalization['mean'])[None].to(shape_slat.device)
        shape_slat = (shape_slat - mean) / std

        # Noise fills only the texture channels; the shape channels are
        # supplied via concat_cond.
        in_channels = flow_model.in_channels if isinstance(flow_model, nn.Module) else flow_model[0].in_channels
        noise = shape_slat.replace(feats=torch.randn(shape_slat.coords.shape[0], in_channels - shape_slat.feats.shape[1]).to(self.device))
        sampler_params = {**self.tex_slat_sampler_params, **sampler_params}
        if self.low_vram:
            flow_model.to(self.device)
        slat = self.tex_slat_sampler.sample(
            flow_model,
            noise,
            concat_cond=shape_slat,
            **cond,
            **sampler_params,
            verbose=True,
            tqdm_desc="Sampling texture SLat",
        ).samples
        if self.low_vram:
            flow_model.cpu()

        # De-normalize the sampled texture latent.
        std = torch.tensor(self.tex_slat_normalization['std'])[None].to(slat.device)
        mean = torch.tensor(self.tex_slat_normalization['mean'])[None].to(slat.device)
        slat = slat * std + mean

        return slat

    def decode_tex_slat(
        self,
        slat: SparseTensor,
    ) -> SparseTensor:
        """
        Decode the structured latent.

        Args:
            slat (SparseTensor): The structured latent.

        Returns:
            SparseTensor: The decoded texture voxels
        """
        if self.low_vram:
            self.models['tex_slat_decoder'].to(self.device)
        # Decoder output is in [-1, 1]; remap to [0, 1] attribute range.
        ret = self.models['tex_slat_decoder'](slat) * 0.5 + 0.5
        if self.low_vram:
            self.models['tex_slat_decoder'].cpu()
        return ret

    def postprocess_mesh(
        self,
        mesh: trimesh.Trimesh,
        pbr_voxel: SparseTensor,
        resolution: int = 1024,
        texture_size: int = 1024,
    ) -> trimesh.Trimesh:
        """Bake the PBR attribute voxels into UV textures on the mesh.

        Args:
            mesh (trimesh.Trimesh): The (preprocessed) mesh to texture.
            pbr_voxel (SparseTensor): Decoded PBR attributes on the voxel grid.
            resolution (int): Voxel grid resolution used for sampling.
            texture_size (int): Output texture resolution in pixels.

        Returns:
            trimesh.Trimesh: Mesh with a PBR material and UVs attached.
        """
        vertices = mesh.vertices
        faces = mesh.faces
        normals = mesh.vertex_normals
        vertices_torch = torch.from_numpy(vertices).float().cuda()
        faces_torch = torch.from_numpy(faces).int().cuda()
        if hasattr(mesh, 'visual') and hasattr(mesh.visual, 'uv') and mesh.visual.uv is not None:
            # Reuse existing UVs, flipping V to the rasterizer's convention.
            uvs = mesh.visual.uv.copy()
            uvs[:, 1] = 1 - uvs[:, 1]
            uvs_torch = torch.from_numpy(uvs).float().cuda()
        else:
            # No UVs: unwrap with cumesh; vmap maps new vertices to originals
            # so per-vertex normals can be carried over.
            _cumesh = cumesh.CuMesh()
            _cumesh.init(vertices_torch, faces_torch)
            vertices_torch, faces_torch, uvs_torch, vmap = _cumesh.uv_unwrap(return_vmaps=True)
            vertices_torch = vertices_torch.cuda()
            faces_torch = faces_torch.cuda()
            uvs_torch = uvs_torch.cuda()
            vertices = vertices_torch.cpu().numpy()
            faces = faces_torch.cpu().numpy()
            uvs = uvs_torch.cpu().numpy()
            normals = normals[vmap.cpu().numpy()]

        # rasterize
        # Rasterize in UV space: each texel records which triangle covers it.
        ctx = dr.RasterizeCudaContext()
        uvs_torch = torch.cat([uvs_torch * 2 - 1, torch.zeros_like(uvs_torch[:, :1]), torch.ones_like(uvs_torch[:, :1])], dim=-1).unsqueeze(0)
        rast, _ = dr.rasterize(
            ctx, uvs_torch, faces_torch,
            resolution=[texture_size, texture_size],
        )
        mask = rast[0, ..., 3] > 0
        # Interpolate 3D positions per covered texel, then sample the PBR
        # voxel grid at those positions.
        pos = dr.interpolate(vertices_torch.unsqueeze(0), rast, faces_torch)[0][0]

        attrs = torch.zeros(texture_size, texture_size, pbr_voxel.shape[1], device=self.device)
        attrs[mask] = flex_gemm.ops.grid_sample.grid_sample_3d(
            pbr_voxel.feats,
            pbr_voxel.coords,
            shape=torch.Size([*pbr_voxel.shape, *pbr_voxel.spatial_shape]),
            grid=((pos[mask] + 0.5) * resolution).reshape(1, -1, 3),
            mode='trilinear',
        )

        # construct mesh
        mask = mask.cpu().numpy()
        base_color = np.clip(attrs[..., self.pbr_attr_layout['base_color']].cpu().numpy() * 255, 0, 255).astype(np.uint8)
        metallic = np.clip(attrs[..., self.pbr_attr_layout['metallic']].cpu().numpy() * 255, 0, 255).astype(np.uint8)
        roughness = np.clip(attrs[..., self.pbr_attr_layout['roughness']].cpu().numpy() * 255, 0, 255).astype(np.uint8)
        alpha = np.clip(attrs[..., self.pbr_attr_layout['alpha']].cpu().numpy() * 255, 0, 255).astype(np.uint8)

        # extend
        # Inpaint uncovered texels so bilinear filtering at UV seams does not
        # bleed in background values.
        mask = (~mask).astype(np.uint8)
        base_color = cv2.inpaint(base_color, mask, 3, cv2.INPAINT_TELEA)
        metallic = cv2.inpaint(metallic, mask, 1, cv2.INPAINT_TELEA)[..., None]
        roughness = cv2.inpaint(roughness, mask, 1, cv2.INPAINT_TELEA)[..., None]
        alpha = cv2.inpaint(alpha, mask, 1, cv2.INPAINT_TELEA)[..., None]

        # glTF convention: metallic in B, roughness in G of the same texture.
        material = trimesh.visual.material.PBRMaterial(
            baseColorTexture=Image.fromarray(np.concatenate([base_color, alpha], axis=-1)),
            baseColorFactor=np.array([255, 255, 255, 255], dtype=np.uint8),
            metallicRoughnessTexture=Image.fromarray(np.concatenate([np.zeros_like(metallic), roughness, metallic], axis=-1)),
            metallicFactor=1.0,
            roughnessFactor=1.0,
            alphaMode='OPAQUE',
            doubleSided=True,
        )

        # Swap Y and Z axes, invert Y (common conversion for GLB compatibility)
        vertices[:, 1], vertices[:, 2] = vertices[:, 2], -vertices[:, 1]
        normals[:, 1], normals[:, 2] = normals[:, 2], -normals[:, 1]
        uvs[:, 1] = 1 - uvs[:, 1]  # Flip UV V-coordinate

        textured_mesh = trimesh.Trimesh(
            vertices=vertices,
            faces=faces,
            vertex_normals=normals,
            process=False,
            visual=trimesh.visual.TextureVisuals(uv=uvs, material=material)
        )

        return textured_mesh


    @torch.no_grad()
    def run(
        self,
        mesh: trimesh.Trimesh,
        image: Image.Image,
        seed: int = 42,
        tex_slat_sampler_params: dict = {},
        preprocess_image: bool = True,
        resolution: int = 1024,
        texture_size: int = 2048,
    ) -> trimesh.Trimesh:
        """
        Run the pipeline.

        Args:
            mesh (trimesh.Trimesh): The mesh to texture.
            image (Image.Image): The image prompt.
            seed (int): The random seed.
            tex_slat_sampler_params (dict): Additional parameters for the texture latent sampler.
            preprocess_image (bool): Whether to preprocess the image.
            resolution (int): Voxel grid resolution (512 uses the 512 flow model, otherwise the 1024 one).
            texture_size (int): Output texture resolution in pixels.
        """
        if preprocess_image:
            image = self.preprocess_image(image)
        mesh = self.preprocess_mesh(mesh)
        torch.manual_seed(seed)
        # Conditioning image size and flow model both follow the resolution.
        cond = self.get_cond([image], 512) if resolution == 512 else self.get_cond([image], 1024)
        shape_slat = self.encode_shape_slat(mesh, resolution)
        tex_model = self.models['tex_slat_flow_model_512'] if resolution == 512 else self.models['tex_slat_flow_model_1024']
        tex_slat = self.sample_tex_slat(
            cond, tex_model,
            shape_slat, tex_slat_sampler_params
        )
        pbr_voxel = self.decode_tex_slat(tex_slat)
        out_mesh = self.postprocess_mesh(mesh, pbr_voxel, resolution, texture_size)
        return out_mesh
|
||||
Reference in New Issue
Block a user