update texturing pipeline

This commit is contained in:
JeffreyXiang
2025-12-23 12:57:08 +00:00
parent 1762f493fe
commit 903bfcf51a
10 changed files with 629 additions and 36 deletions

View File

@@ -49,7 +49,7 @@ Data processing is streamlined for instant conversions that are fully **renderin
- [x] Release image-to-3D inference code
- [x] Release pretrained checkpoints (4B)
- [x] Hugging Face Spaces demo
- [ ] Release shape-conditioned texture generation inference code (Current schdule: before 12/24/2025)
- [x] Release shape-conditioned texture generation inference code
- [ ] Release training code (Current schedule: before 12/31/2025)
@@ -184,7 +184,7 @@ Then, you can access the demo at the address shown in the terminal.
### 2. PBR Texture Generation
Will be released soon. Please stay tuned!
Please refer to the [example_texturing.py](example_texturing.py) for an example of how to generate PBR textures for a given 3D shape. Also, you can use the [app_texturing.py](app_texturing.py) to run a web demo for PBR texture generation.
## 🧩 Related Packages

151
app_texturing.py Normal file
View File

@@ -0,0 +1,151 @@
import gradio as gr
import os
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
from datetime import datetime
import shutil
from typing import *
import torch
import numpy as np
import trimesh
from PIL import Image
from trellis2.pipelines import Trellis2TexturingPipeline
MAX_SEED = np.iinfo(np.int32).max
TMP_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'tmp')
def start_session(req: gr.Request):
    """Create the per-session temp directory when a client connects."""
    session_dir = os.path.join(TMP_DIR, str(req.session_hash))
    os.makedirs(session_dir, exist_ok=True)
def end_session(req: gr.Request):
    """
    Remove the per-session temp directory when the client disconnects.

    Args:
        req (gr.Request): Gradio request carrying the session hash.
    """
    user_dir = os.path.join(TMP_DIR, str(req.session_hash))
    # ignore_errors: the directory may never have been created (session ended
    # before any generation) or may already be gone — unload must not raise.
    shutil.rmtree(user_dir, ignore_errors=True)
def preprocess_image(image: Image.Image) -> Image.Image:
    """
    Preprocess the input image.

    Args:
        image (Image.Image): The input image.

    Returns:
        Image.Image: The preprocessed image.
    """
    # Delegate to the pipeline's own preprocessing (mask / crop / composite).
    return pipeline.preprocess_image(image)
def get_seed(randomize_seed: bool, seed: int) -> int:
    """Return a freshly drawn seed when randomization is on, else the given one."""
    if randomize_seed:
        return np.random.randint(0, MAX_SEED)
    return seed
def shapeimage_to_tex(
    mesh_file: str,
    image: Image.Image,
    seed: int,
    resolution: str,
    texture_size: int,
    tex_slat_guidance_strength: float,
    tex_slat_guidance_rescale: float,
    tex_slat_sampling_steps: int,
    tex_slat_rescale_t: float,
    req: gr.Request,
    progress=gr.Progress(track_tqdm=True),
) -> Tuple[str, str]:
    """
    Texture the uploaded mesh using the image prompt and export it as a GLB.

    Args:
        mesh_file (str): Path to the uploaded mesh file.
        image (Image.Image): The image prompt (already preprocessed by the
            upload handler, hence preprocess_image=False below).
        seed (int): Random seed for sampling.
        resolution (str): Voxel-grid resolution as the radio string value
            ("512" / "1024" / "1536"); converted to int before use.
        texture_size (int): Side length of the baked texture in pixels.
        tex_slat_guidance_strength (float): Guidance strength for the texture sampler.
        tex_slat_guidance_rescale (float): Guidance rescale factor for the texture sampler.
        tex_slat_sampling_steps (int): Number of sampler steps.
        tex_slat_rescale_t (float): Timestep rescale factor for the texture sampler.
        req (gr.Request): Gradio request, used to locate the per-session temp dir.
        progress: Gradio progress tracker (tracks tqdm bars inside the pipeline).

    Returns:
        Tuple[str, str]: The GLB path twice — once for the Model3D viewer,
        once for the DownloadButton.
    """
    mesh = trimesh.load(mesh_file)
    # Flatten multi-node scenes (e.g. GLB/GLTF files) into a single mesh.
    if isinstance(mesh, trimesh.Scene):
        mesh = mesh.to_mesh()
    output = pipeline.run(
        mesh,
        image,
        seed=seed,
        preprocess_image=False,
        tex_slat_sampler_params={
            "steps": tex_slat_sampling_steps,
            "guidance_strength": tex_slat_guidance_strength,
            "guidance_rescale": tex_slat_guidance_rescale,
            "rescale_t": tex_slat_rescale_t,
        },
        resolution=int(resolution),
        texture_size=texture_size,
    )
    # Millisecond-precision timestamp keeps per-session filenames unique.
    now = datetime.now()
    timestamp = now.strftime("%Y-%m-%dT%H%M%S") + f".{now.microsecond // 1000:03d}"
    user_dir = os.path.join(TMP_DIR, str(req.session_hash))
    os.makedirs(user_dir, exist_ok=True)
    glb_path = os.path.join(user_dir, f'sample_{timestamp}.glb')
    output.export(glb_path, extension_webp=True)
    # Free cached CUDA memory between generations.
    torch.cuda.empty_cache()
    return glb_path, glb_path
# Gradio UI: mesh/image inputs and settings on the left, textured GLB preview
# on the right. Cached files are deleted after 600 s (delete_cache).
with gr.Blocks(delete_cache=(600, 600)) as demo:
    gr.Markdown("""
    ## Texturing a mesh with [TRELLIS.2](https://microsoft.github.io/TRELLIS.2)
    * Upload a mesh and corresponding reference image (preferably with an alpha-masked foreground object) and click Generate to create a textured 3D asset.
    """)

    with gr.Row():
        with gr.Column(scale=1, min_width=360):
            mesh_file = gr.File(label="Upload Mesh", file_types=[".ply", ".obj", ".glb", ".gltf"], file_count="single")
            image_prompt = gr.Image(label="Image Prompt", format="png", image_mode="RGBA", type="pil", height=400)
            resolution = gr.Radio(["512", "1024", "1536"], label="Resolution", value="1024")
            seed = gr.Slider(0, MAX_SEED, label="Seed", value=0, step=1)
            randomize_seed = gr.Checkbox(label="Randomize Seed", value=True)
            texture_size = gr.Slider(1024, 4096, label="Texture Size", value=2048, step=1024)

            generate_btn = gr.Button("Generate")

            with gr.Accordion(label="Advanced Settings", open=False):
                with gr.Row():
                    tex_slat_guidance_strength = gr.Slider(1.0, 10.0, label="Guidance Strength", value=1.0, step=0.1)
                    tex_slat_guidance_rescale = gr.Slider(0.0, 1.0, label="Guidance Rescale", value=0.0, step=0.01)
                    tex_slat_sampling_steps = gr.Slider(1, 50, label="Sampling Steps", value=12, step=1)
                    tex_slat_rescale_t = gr.Slider(1.0, 6.0, label="Rescale T", value=3.0, step=0.1)

        with gr.Column(scale=10):
            glb_output = gr.Model3D(label="Extracted GLB", height=724, show_label=True, display_mode="solid", clear_color=(0.25, 0.25, 0.25, 1.0))
            download_btn = gr.DownloadButton(label="Download GLB")

    # Handlers
    demo.load(start_session)
    demo.unload(end_session)

    # Preprocess the image prompt as soon as it is uploaded, in place.
    image_prompt.upload(
        preprocess_image,
        inputs=[image_prompt],
        outputs=[image_prompt],
    )

    # Resolve the effective seed first, then run texturing with it.
    generate_btn.click(
        get_seed,
        inputs=[randomize_seed, seed],
        outputs=[seed],
    ).then(
        shapeimage_to_tex,
        inputs=[
            mesh_file, image_prompt, seed, resolution, texture_size,
            tex_slat_guidance_strength, tex_slat_guidance_rescale, tex_slat_sampling_steps, tex_slat_rescale_t,
        ],
        outputs=[glb_output, download_btn],
    )


# Launch the Gradio app
if __name__ == "__main__":
    os.makedirs(TMP_DIR, exist_ok=True)
    # NOTE(review): `pipeline` is a module-level global read by the handlers
    # above; it is only bound when this file is run as a script.
    pipeline = Trellis2TexturingPipeline.from_pretrained('microsoft/TRELLIS.2-4B', config_file="texturing_pipeline.json")
    pipeline.cuda()
    demo.launch()

Binary file not shown.

After

Width:  |  Height:  |  Size: 62 KiB

View File

@@ -0,0 +1,11 @@
## Asset Information
* Title: The Forgotten Knight
* Author: dark_igorek
* Source: https://sketchfab.com/3d-models/the-forgotten-knight-d14eb14d83bd4e7ba7cbe443d76a10fd
* License: Creative Commons Attribution (CC BY)
## Usage
The asset is used for research purposes only.
Please credit the original author and include the Sketchfab link when using or redistributing this model.

Binary file not shown.

17
example_texturing.py Normal file
View File

@@ -0,0 +1,17 @@
# Minimal example: generate PBR textures for an existing mesh with TRELLIS.2.
import os
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True" # Can save GPU memory
import trimesh
from PIL import Image
from trellis2.pipelines import Trellis2TexturingPipeline

# 1. Load Pipeline (texturing-specific config from the TRELLIS.2-4B repo)
pipeline = Trellis2TexturingPipeline.from_pretrained("microsoft/TRELLIS.2-4B", config_file="texturing_pipeline.json")
pipeline.cuda()

# 2. Load Mesh, image & Run
mesh = trimesh.load("assets/example_texturing/the_forgotten_knight.ply")
image = Image.open("assets/example_texturing/image.webp")
output = pipeline.run(mesh, image)

# 3. Export the textured mesh as GLB (WebP-compressed textures)
output.export("textured.glb", extension_webp=True)

View File

@@ -2,8 +2,7 @@ import importlib
__attributes = {
"Trellis2ImageTo3DPipeline": "trellis2_image_to_3d",
"Trellis2ImageTo3DCascadePipeline": "trellis2_image_to_3d_cascade",
"Trellis2ImageToTexturePipeline": "trellis2_image_to_tex",
"Trellis2TexturingPipeline": "trellis2_texturing",
}
__submodules = ['samplers', 'rembg']
@@ -49,7 +48,5 @@ def from_pretrained(path: str):
# For PyLance
if __name__ == '__main__':
from . import samplers, rembg
from .trellis_image_to_3d import TrellisImageTo3DPipeline
from .trellis2_image_to_3d import Trellis2ImageTo3DPipeline
from .trellis2_image_to_3d_cascade import Trellis2ImageTo3DCascadePipeline
from .trellis2_image_to_tex import Trellis2ImageToTexturePipeline
from .trellis2_texturing import Trellis2TexturingPipeline

View File

@@ -18,32 +18,34 @@ class Pipeline:
for model in self.models.values():
model.eval()
@staticmethod
def from_pretrained(path: str) -> "Pipeline":
@classmethod
def from_pretrained(cls, path: str, config_file: str = "pipeline.json") -> "Pipeline":
"""
Load a pretrained model.
"""
import os
import json
is_local = os.path.exists(f"{path}/pipeline.json")
is_local = os.path.exists(f"{path}/{config_file}")
if is_local:
config_file = f"{path}/pipeline.json"
config_file = f"{path}/{config_file}"
else:
from huggingface_hub import hf_hub_download
config_file = hf_hub_download(path, "pipeline.json")
config_file = hf_hub_download(path, config_file)
with open(config_file, 'r') as f:
args = json.load(f)['args']
_models = {}
for k, v in args['models'].items():
if hasattr(cls, 'model_names_to_load') and k not in cls.model_names_to_load:
continue
try:
_models[k] = models.from_pretrained(f"{path}/{v}")
except Exception as e:
_models[k] = models.from_pretrained(v)
new_pipeline = Pipeline(_models)
new_pipeline = cls(_models)
new_pipeline._pretrained_args = args
return new_pipeline

View File

@@ -28,6 +28,17 @@ class Trellis2ImageTo3DPipeline(Pipeline):
rembg_model (Callable): The model for removing background.
low_vram (bool): Whether to use low-VRAM mode.
"""
model_names_to_load = [
'sparse_structure_flow_model',
'sparse_structure_decoder',
'shape_slat_flow_model_512',
'shape_slat_flow_model_1024',
'shape_slat_decoder',
'tex_slat_flow_model_512',
'tex_slat_flow_model_1024',
'tex_slat_decoder',
]
def __init__(
self,
models: dict[str, nn.Module] = None,
@@ -67,45 +78,43 @@ class Trellis2ImageTo3DPipeline(Pipeline):
}
self._device = 'cpu'
@staticmethod
def from_pretrained(path: str) -> "Trellis2ImageTo3DPipeline":
@classmethod
def from_pretrained(cls, path: str, config_file: str = "pipeline.json") -> "Trellis2ImageTo3DPipeline":
"""
Load a pretrained model.
Args:
path (str): The path to the model. Can be either local path or a Hugging Face repository.
"""
pipeline = super(Trellis2ImageTo3DPipeline, Trellis2ImageTo3DPipeline).from_pretrained(path)
new_pipeline = Trellis2ImageTo3DPipeline()
new_pipeline.__dict__ = pipeline.__dict__
pipeline = super().from_pretrained(path, config_file)
args = pipeline._pretrained_args
new_pipeline.sparse_structure_sampler = getattr(samplers, args['sparse_structure_sampler']['name'])(**args['sparse_structure_sampler']['args'])
new_pipeline.sparse_structure_sampler_params = args['sparse_structure_sampler']['params']
pipeline.sparse_structure_sampler = getattr(samplers, args['sparse_structure_sampler']['name'])(**args['sparse_structure_sampler']['args'])
pipeline.sparse_structure_sampler_params = args['sparse_structure_sampler']['params']
new_pipeline.shape_slat_sampler = getattr(samplers, args['shape_slat_sampler']['name'])(**args['shape_slat_sampler']['args'])
new_pipeline.shape_slat_sampler_params = args['shape_slat_sampler']['params']
pipeline.shape_slat_sampler = getattr(samplers, args['shape_slat_sampler']['name'])(**args['shape_slat_sampler']['args'])
pipeline.shape_slat_sampler_params = args['shape_slat_sampler']['params']
new_pipeline.tex_slat_sampler = getattr(samplers, args['tex_slat_sampler']['name'])(**args['tex_slat_sampler']['args'])
new_pipeline.tex_slat_sampler_params = args['tex_slat_sampler']['params']
pipeline.tex_slat_sampler = getattr(samplers, args['tex_slat_sampler']['name'])(**args['tex_slat_sampler']['args'])
pipeline.tex_slat_sampler_params = args['tex_slat_sampler']['params']
new_pipeline.shape_slat_normalization = args['shape_slat_normalization']
new_pipeline.tex_slat_normalization = args['tex_slat_normalization']
pipeline.shape_slat_normalization = args['shape_slat_normalization']
pipeline.tex_slat_normalization = args['tex_slat_normalization']
new_pipeline.image_cond_model = getattr(image_feature_extractor, args['image_cond_model']['name'])(**args['image_cond_model']['args'])
new_pipeline.rembg_model = getattr(rembg, args['rembg_model']['name'])(**args['rembg_model']['args'])
pipeline.image_cond_model = getattr(image_feature_extractor, args['image_cond_model']['name'])(**args['image_cond_model']['args'])
pipeline.rembg_model = getattr(rembg, args['rembg_model']['name'])(**args['rembg_model']['args'])
new_pipeline.low_vram = args.get('low_vram', True)
new_pipeline.default_pipeline_type = args.get('default_pipeline_type', '1024_cascade')
new_pipeline.pbr_attr_layout = {
pipeline.low_vram = args.get('low_vram', True)
pipeline.default_pipeline_type = args.get('default_pipeline_type', '1024_cascade')
pipeline.pbr_attr_layout = {
'base_color': slice(0, 3),
'metallic': slice(3, 4),
'roughness': slice(4, 5),
'alpha': slice(5, 6),
}
new_pipeline._device = 'cpu'
pipeline._device = 'cpu'
return new_pipeline
return pipeline
def to(self, device: torch.device) -> None:
self._device = device
@@ -364,7 +373,6 @@ class Trellis2ImageTo3DPipeline(Pipeline):
Args:
slat (SparseTensor): The structured latent.
formats (List[str]): The formats to decode the structured latent to.
Returns:
List[Mesh]: The decoded meshes.
@@ -433,10 +441,9 @@ class Trellis2ImageTo3DPipeline(Pipeline):
Args:
slat (SparseTensor): The structured latent.
formats (List[str]): The formats to decode the structured latent to.
Returns:
List[SparseTensor]: The decoded texture voxels
SparseTensor: The decoded texture voxels
"""
if self.low_vram:
self.models['tex_slat_decoder'].to(self.device)

View File

@@ -0,0 +1,408 @@
from typing import *
import torch
import torch.nn as nn
import numpy as np
from PIL import Image
import trimesh
from .base import Pipeline
from . import samplers, rembg
from ..modules.sparse import SparseTensor
from ..modules import image_feature_extractor
import o_voxel
import cumesh
import nvdiffrast.torch as dr
import cv2
import flex_gemm
class Trellis2TexturingPipeline(Pipeline):
    """
    Pipeline for generating PBR textures for a given 3D shape from an image prompt.

    Args:
        models (dict[str, nn.Module]): The models to use in the pipeline.
        tex_slat_sampler (samplers.Sampler): The sampler for the texture latent.
        tex_slat_sampler_params (dict): The parameters for the texture latent sampler.
        shape_slat_normalization (dict): The normalization parameters for the structured latent.
        tex_slat_normalization (dict): The normalization parameters for the texture latent.
        image_cond_model (Callable): The image conditioning model.
        rembg_model (Callable): The model for removing background.
        low_vram (bool): Whether to use low-VRAM mode.
    """
    # Only these checkpoint entries are loaded (consumed by the base
    # Pipeline.from_pretrained); texturing needs no sparse-structure or
    # shape flow models.
    model_names_to_load = [
        'shape_slat_encoder',
        'tex_slat_decoder',
        'tex_slat_flow_model_512',
        'tex_slat_flow_model_1024'
    ]

    def __init__(
        self,
        models: dict[str, nn.Module] = None,
        tex_slat_sampler: samplers.Sampler = None,
        tex_slat_sampler_params: dict = None,
        shape_slat_normalization: dict = None,
        tex_slat_normalization: dict = None,
        image_cond_model: Callable = None,
        rembg_model: Callable = None,
        low_vram: bool = True,
    ):
        # With models=None this constructs an empty shell instance whose
        # attributes are populated elsewhere (e.g. by from_pretrained).
        if models is None:
            return
        super().__init__(models)
        self.tex_slat_sampler = tex_slat_sampler
        self.tex_slat_sampler_params = tex_slat_sampler_params
        self.shape_slat_normalization = shape_slat_normalization
        self.tex_slat_normalization = tex_slat_normalization
        self.image_cond_model = image_cond_model
        self.rembg_model = rembg_model
        self.low_vram = low_vram
        # Channel layout of the decoded PBR attribute voxels.
        self.pbr_attr_layout = {
            'base_color': slice(0, 3),
            'metallic': slice(3, 4),
            'roughness': slice(4, 5),
            'alpha': slice(5, 6),
        }
        self._device = 'cpu'

    @classmethod
    def from_pretrained(cls, path: str, config_file: str = "pipeline.json") -> "Trellis2TexturingPipeline":
        """
        Load a pretrained model.

        Args:
            path (str): The path to the model. Can be either local path or a Hugging Face repository.
            config_file (str): Name of the pipeline config JSON to load from `path`.
        """
        pipeline = super().from_pretrained(path, config_file)
        args = pipeline._pretrained_args
        # Rebuild sampler and conditioning components from the config args.
        pipeline.tex_slat_sampler = getattr(samplers, args['tex_slat_sampler']['name'])(**args['tex_slat_sampler']['args'])
        pipeline.tex_slat_sampler_params = args['tex_slat_sampler']['params']

        pipeline.shape_slat_normalization = args['shape_slat_normalization']
        pipeline.tex_slat_normalization = args['tex_slat_normalization']

        pipeline.image_cond_model = getattr(image_feature_extractor, args['image_cond_model']['name'])(**args['image_cond_model']['args'])
        pipeline.rembg_model = getattr(rembg, args['rembg_model']['name'])(**args['rembg_model']['args'])

        pipeline.low_vram = args.get('low_vram', True)
        # Channel layout of the decoded PBR attribute voxels.
        pipeline.pbr_attr_layout = {
            'base_color': slice(0, 3),
            'metallic': slice(3, 4),
            'roughness': slice(4, 5),
            'alpha': slice(5, 6),
        }
        pipeline._device = 'cpu'
        return pipeline

    def to(self, device: torch.device) -> None:
        # Record the target device; in low-VRAM mode models stay on CPU and
        # each stage moves its model to `self.device` on demand.
        self._device = device
        if not self.low_vram:
            super().to(device)
            self.image_cond_model.to(device)
            if self.rembg_model is not None:
                self.rembg_model.to(device)

    def preprocess_mesh(self, mesh: trimesh.Trimesh) -> trimesh.Trimesh:
        """
        Preprocess the input mesh.

        Normalizes the vertices into the [-0.5, 0.5]^3 cube and applies the
        axis remap (x, y, z) -> (x, -z, y) used internally by the pipeline.
        """
        vertices = mesh.vertices
        vertices_min = vertices.min(axis=0)
        vertices_max = vertices.max(axis=0)
        center = (vertices_min + vertices_max) / 2
        # 0.99999 keeps vertices strictly inside the cube after scaling.
        scale = 0.99999 / (vertices_max - vertices_min).max()
        vertices = (vertices - center) * scale
        # Axis remap: new y = -old z, new z = old y.
        tmp = vertices[:, 1].copy()
        vertices[:, 1] = -vertices[:, 2]
        vertices[:, 2] = tmp
        assert np.all(vertices >= -0.5) and np.all(vertices <= 0.5), 'vertices out of range'
        return trimesh.Trimesh(vertices=vertices, faces=mesh.faces, process=False)

    def preprocess_image(self, input: Image.Image) -> Image.Image:
        """
        Preprocess the input image.

        Uses an existing non-trivial alpha channel as the foreground mask when
        present, otherwise runs background removal; then crops a square around
        the foreground and composites RGB over black.
        """
        # if has alpha channel, use it directly; otherwise, remove background
        has_alpha = False
        if input.mode == 'RGBA':
            alpha = np.array(input)[:, :, 3]
            # A fully-opaque alpha channel carries no mask information.
            if not np.all(alpha == 255):
                has_alpha = True
        # Downscale so the longest side is at most 1024 px.
        max_size = max(input.size)
        scale = min(1, 1024 / max_size)
        if scale < 1:
            input = input.resize((int(input.width * scale), int(input.height * scale)), Image.Resampling.LANCZOS)
        if has_alpha:
            output = input
        else:
            input = input.convert('RGB')
            if self.low_vram:
                self.rembg_model.to(self.device)
            output = self.rembg_model(input)
            if self.low_vram:
                self.rembg_model.cpu()
        # Square crop centered on pixels whose alpha exceeds 0.8 (of 255).
        output_np = np.array(output)
        alpha = output_np[:, :, 3]
        bbox = np.argwhere(alpha > 0.8 * 255)
        bbox = np.min(bbox[:, 1]), np.min(bbox[:, 0]), np.max(bbox[:, 1]), np.max(bbox[:, 0])
        center = (bbox[0] + bbox[2]) / 2, (bbox[1] + bbox[3]) / 2
        size = max(bbox[2] - bbox[0], bbox[3] - bbox[1])
        size = int(size * 1)  # padding factor (currently 1 = tight crop)
        bbox = center[0] - size // 2, center[1] - size // 2, center[0] + size // 2, center[1] + size // 2
        output = output.crop(bbox)  # type: ignore
        # Composite RGB over black using alpha, drop the alpha channel.
        output = np.array(output).astype(np.float32) / 255
        output = output[:, :, :3] * output[:, :, 3:4]
        output = Image.fromarray((output * 255).astype(np.uint8))
        return output

    def get_cond(self, image: Union[torch.Tensor, list[Image.Image]], resolution: int, include_neg_cond: bool = True) -> dict:
        """
        Get the conditioning information for the model.

        Args:
            image (Union[torch.Tensor, list[Image.Image]]): The image prompts.
            resolution (int): Input image size for the conditioning model.
            include_neg_cond (bool): Whether to also return the zero negative condition.

        Returns:
            dict: The conditioning information
        """
        self.image_cond_model.image_size = resolution
        if self.low_vram:
            self.image_cond_model.to(self.device)
        cond = self.image_cond_model(image)
        if self.low_vram:
            self.image_cond_model.cpu()
        if not include_neg_cond:
            return {'cond': cond}
        # All-zero embedding serves as the unconditional branch for guidance.
        neg_cond = torch.zeros_like(cond)
        return {
            'cond': cond,
            'neg_cond': neg_cond,
        }

    def encode_shape_slat(
        self,
        mesh: trimesh.Trimesh,
        resolution: int = 1024,
    ) -> SparseTensor:
        """
        Encode the meshes to structured latent.

        Args:
            mesh (trimesh.Trimesh): The mesh to encode.
            resolution (int): The resolution of mesh

        Returns:
            SparseTensor: The encoded structured latent.
        """
        vertices = torch.from_numpy(mesh.vertices).float()
        faces = torch.from_numpy(mesh.faces).long()
        # Voxelize the mesh into a flexible dual grid over [-0.5, 0.5]^3:
        # occupied voxel indices, one dual vertex per voxel, and intersection data.
        voxel_indices, dual_vertices, intersected = o_voxel.convert.mesh_to_flexible_dual_grid(
            vertices.cpu(), faces.cpu(),
            grid_size=resolution,
            aabb=[[-0.5,-0.5,-0.5],[0.5,0.5,0.5]],
            face_weight=1.0,
            boundary_weight=0.2,
            regularization_weight=1e-2,
            timing=True,
        )
        # Features are dual-vertex offsets within their voxel; coordinates get
        # a zero batch index prepended.
        vertices = SparseTensor(
            feats=dual_vertices * resolution - voxel_indices,
            coords=torch.cat([torch.zeros_like(voxel_indices[:, 0:1]), voxel_indices], dim=-1)
        ).to(self.device)
        intersected = vertices.replace(intersected).to(self.device)
        if self.low_vram:
            self.models['shape_slat_encoder'].to(self.device)
        shape_slat = self.models['shape_slat_encoder'](vertices, intersected)
        if self.low_vram:
            self.models['shape_slat_encoder'].cpu()
        return shape_slat

    def sample_tex_slat(
        self,
        cond: dict,
        flow_model,
        shape_slat: SparseTensor,
        sampler_params: dict = {},
    ) -> SparseTensor:
        """
        Sample structured latent with the given conditioning.

        Args:
            cond (dict): The conditioning information.
            flow_model: The texture flow model (an nn.Module, or a sequence whose
                first element exposes `in_channels`).
            shape_slat (SparseTensor): The structured latent for shape
            sampler_params (dict): Additional parameters for the sampler.
        """
        # Sample structured latent
        # Normalize the shape latent before using it as concat conditioning.
        std = torch.tensor(self.shape_slat_normalization['std'])[None].to(shape_slat.device)
        mean = torch.tensor(self.shape_slat_normalization['mean'])[None].to(shape_slat.device)
        shape_slat = (shape_slat - mean) / std
        # Noise occupies only the channels not taken by the shape latent.
        in_channels = flow_model.in_channels if isinstance(flow_model, nn.Module) else flow_model[0].in_channels
        noise = shape_slat.replace(feats=torch.randn(shape_slat.coords.shape[0], in_channels - shape_slat.feats.shape[1]).to(self.device))
        # Explicit call-time params override the configured defaults.
        sampler_params = {**self.tex_slat_sampler_params, **sampler_params}
        if self.low_vram:
            flow_model.to(self.device)
        slat = self.tex_slat_sampler.sample(
            flow_model,
            noise,
            concat_cond=shape_slat,
            **cond,
            **sampler_params,
            verbose=True,
            tqdm_desc="Sampling texture SLat",
        ).samples
        if self.low_vram:
            flow_model.cpu()
        # Denormalize the sampled latent into texture-latent space.
        std = torch.tensor(self.tex_slat_normalization['std'])[None].to(slat.device)
        mean = torch.tensor(self.tex_slat_normalization['mean'])[None].to(slat.device)
        slat = slat * std + mean
        return slat

    def decode_tex_slat(
        self,
        slat: SparseTensor,
    ) -> SparseTensor:
        """
        Decode the structured latent.

        Args:
            slat (SparseTensor): The structured latent.

        Returns:
            SparseTensor: The decoded texture voxels
        """
        if self.low_vram:
            self.models['tex_slat_decoder'].to(self.device)
        # Map decoder output from [-1, 1] to the [0, 1] attribute range.
        ret = self.models['tex_slat_decoder'](slat) * 0.5 + 0.5
        if self.low_vram:
            self.models['tex_slat_decoder'].cpu()
        return ret

    def postprocess_mesh(
        self,
        mesh: trimesh.Trimesh,
        pbr_voxel: SparseTensor,
        resolution: int = 1024,
        texture_size: int = 1024,
    ) -> trimesh.Trimesh:
        """
        Bake the PBR voxel attributes into texture maps and build a textured mesh.

        Args:
            mesh (trimesh.Trimesh): The (preprocessed) mesh to texture.
            pbr_voxel (SparseTensor): Decoded PBR attribute voxels (see pbr_attr_layout).
            resolution (int): Voxel grid resolution used when sampling the voxels.
            texture_size (int): Side length of the baked texture in pixels.

        Returns:
            trimesh.Trimesh: Mesh with a baked glTF PBR material.
        """
        vertices = mesh.vertices
        faces = mesh.faces
        normals = mesh.vertex_normals
        vertices_torch = torch.from_numpy(vertices).float().cuda()
        faces_torch = torch.from_numpy(faces).int().cuda()
        # Reuse existing UVs when the mesh has them; otherwise unwrap with cumesh.
        if hasattr(mesh, 'visual') and hasattr(mesh.visual, 'uv') and mesh.visual.uv is not None:
            uvs = mesh.visual.uv.copy()
            uvs[:, 1] = 1 - uvs[:, 1]
            uvs_torch = torch.from_numpy(uvs).float().cuda()
        else:
            _cumesh = cumesh.CuMesh()
            _cumesh.init(vertices_torch, faces_torch)
            # Unwrapping may split vertices; vmap maps new vertex ids back to
            # the originals so per-vertex normals can be carried over.
            vertices_torch, faces_torch, uvs_torch, vmap = _cumesh.uv_unwrap(return_vmaps=True)
            vertices_torch = vertices_torch.cuda()
            faces_torch = faces_torch.cuda()
            uvs_torch = uvs_torch.cuda()
            vertices = vertices_torch.cpu().numpy()
            faces = faces_torch.cpu().numpy()
            uvs = uvs_torch.cpu().numpy()
            normals = normals[vmap.cpu().numpy()]
        # rasterize
        # Rasterize in UV space so each covered texel knows its 3D position.
        ctx = dr.RasterizeCudaContext()
        uvs_torch = torch.cat([uvs_torch * 2 - 1, torch.zeros_like(uvs_torch[:, :1]), torch.ones_like(uvs_torch[:, :1])], dim=-1).unsqueeze(0)
        rast, _ = dr.rasterize(
            ctx, uvs_torch, faces_torch,
            resolution=[texture_size, texture_size],
        )
        mask = rast[0, ..., 3] > 0
        pos = dr.interpolate(vertices_torch.unsqueeze(0), rast, faces_torch)[0][0]
        # Trilinearly sample the sparse PBR voxels at covered texels only.
        attrs = torch.zeros(texture_size, texture_size, pbr_voxel.shape[1], device=self.device)
        attrs[mask] = flex_gemm.ops.grid_sample.grid_sample_3d(
            pbr_voxel.feats,
            pbr_voxel.coords,
            shape=torch.Size([*pbr_voxel.shape, *pbr_voxel.spatial_shape]),
            grid=((pos[mask] + 0.5) * resolution).reshape(1, -1, 3),
            mode='trilinear',
        )
        # construct mesh
        mask = mask.cpu().numpy()
        base_color = np.clip(attrs[..., self.pbr_attr_layout['base_color']].cpu().numpy() * 255, 0, 255).astype(np.uint8)
        metallic = np.clip(attrs[..., self.pbr_attr_layout['metallic']].cpu().numpy() * 255, 0, 255).astype(np.uint8)
        roughness = np.clip(attrs[..., self.pbr_attr_layout['roughness']].cpu().numpy() * 255, 0, 255).astype(np.uint8)
        alpha = np.clip(attrs[..., self.pbr_attr_layout['alpha']].cpu().numpy() * 255, 0, 255).astype(np.uint8)
        # extend
        # Inpaint uncovered texels so texture filtering does not bleed
        # background values across UV seams.
        mask = (~mask).astype(np.uint8)
        base_color = cv2.inpaint(base_color, mask, 3, cv2.INPAINT_TELEA)
        metallic = cv2.inpaint(metallic, mask, 1, cv2.INPAINT_TELEA)[..., None]
        roughness = cv2.inpaint(roughness, mask, 1, cv2.INPAINT_TELEA)[..., None]
        alpha = cv2.inpaint(alpha, mask, 1, cv2.INPAINT_TELEA)[..., None]
        # glTF metallic-roughness texture packs roughness in G, metallic in B.
        material = trimesh.visual.material.PBRMaterial(
            baseColorTexture=Image.fromarray(np.concatenate([base_color, alpha], axis=-1)),
            baseColorFactor=np.array([255, 255, 255, 255], dtype=np.uint8),
            metallicRoughnessTexture=Image.fromarray(np.concatenate([np.zeros_like(metallic), roughness, metallic], axis=-1)),
            metallicFactor=1.0,
            roughnessFactor=1.0,
            alphaMode='OPAQUE',
            doubleSided=True,
        )
        # Swap Y and Z axes, invert Y (common conversion for GLB compatibility)
        vertices[:, 1], vertices[:, 2] = vertices[:, 2], -vertices[:, 1]
        normals[:, 1], normals[:, 2] = normals[:, 2], -normals[:, 1]
        uvs[:, 1] = 1 - uvs[:, 1] # Flip UV V-coordinate
        textured_mesh = trimesh.Trimesh(
            vertices=vertices,
            faces=faces,
            vertex_normals=normals,
            process=False,
            visual=trimesh.visual.TextureVisuals(uv=uvs, material=material)
        )
        return textured_mesh

    @torch.no_grad()
    def run(
        self,
        mesh: trimesh.Trimesh,
        image: Image.Image,
        seed: int = 42,
        tex_slat_sampler_params: dict = {},
        preprocess_image: bool = True,
        resolution: int = 1024,
        texture_size: int = 2048,
    ) -> trimesh.Trimesh:
        """
        Run the pipeline.

        Args:
            mesh (trimesh.Trimesh): The mesh to texture.
            image (Image.Image): The image prompt.
            seed (int): The random seed.
            tex_slat_sampler_params (dict): Additional parameters for the texture latent sampler.
            preprocess_image (bool): Whether to preprocess the image.
            resolution (int): Voxel grid resolution; 512 selects the 512 flow
                model / conditioning size, any other value selects the 1024 ones.
            texture_size (int): Side length of the baked texture in pixels.

        Returns:
            trimesh.Trimesh: The textured mesh.
        """
        if preprocess_image:
            image = self.preprocess_image(image)
        mesh = self.preprocess_mesh(mesh)
        torch.manual_seed(seed)
        cond = self.get_cond([image], 512) if resolution == 512 else self.get_cond([image], 1024)
        shape_slat = self.encode_shape_slat(mesh, resolution)
        tex_model = self.models['tex_slat_flow_model_512'] if resolution == 512 else self.models['tex_slat_flow_model_1024']
        tex_slat = self.sample_tex_slat(
            cond, tex_model,
            shape_slat, tex_slat_sampler_params
        )
        pbr_voxel = self.decode_tex_slat(tex_slat)
        out_mesh = self.postprocess_mesh(mesh, pbr_voxel, resolution, texture_size)
        return out_mesh