Skip to content

Commit c14a3da

Browse files
Address comments
1 parent dd429ef commit c14a3da

File tree

8 files changed

+28
-32
lines changed

8 files changed

+28
-32
lines changed

docs/source/en/api/pipelines/cosmos.md

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,12 @@ output.save("output.png")
7070
- all
7171
- __call__
7272

73+
## Cosmos2_5_PredictBasePipeline
74+
75+
[[autodoc]] Cosmos2_5_PredictBasePipeline
76+
- all
77+
- __call__
78+
7379
## CosmosPipelineOutput
7480

7581
[[autodoc]] pipelines.cosmos.pipeline_output.CosmosPipelineOutput

scripts/convert_cosmos_to_diffusers.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@
6363
FlowMatchEulerDiscreteScheduler,
6464
UniPCMultistepScheduler,
6565
)
66-
from diffusers.pipelines.cosmos.pipeline_cosmos2_5_predict import Cosmos2_5_PredictBase
66+
from diffusers.pipelines.cosmos.pipeline_cosmos2_5_predict import Cosmos2_5_PredictBasePipeline
6767

6868

6969
def remove_keys_(key: str, state_dict: Dict[str, Any]):
@@ -545,7 +545,7 @@ def save_pipeline_cosmos2_5(args, transformer, vae):
545545
sigma_min=0.01,
546546
)
547547

548-
pipe = Cosmos2_5_PredictBase(
548+
pipe = Cosmos2_5_PredictBasePipeline(
549549
text_encoder=text_encoder,
550550
tokenizer=tokenizer,
551551
transformer=transformer,

src/diffusers/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -463,7 +463,7 @@
463463
"CogView4ControlPipeline",
464464
"CogView4Pipeline",
465465
"ConsisIDPipeline",
466-
"Cosmos2_5_PredictBase",
466+
"Cosmos2_5_PredictBasePipeline",
467467
"Cosmos2_5_PredictImage2World",
468468
"Cosmos2_5_PredictText2World",
469469
"Cosmos2_5_PredictVideo2World",
@@ -1179,7 +1179,7 @@
11791179
CogView4ControlPipeline,
11801180
CogView4Pipeline,
11811181
ConsisIDPipeline,
1182-
Cosmos2_5_PredictBase,
1182+
Cosmos2_5_PredictBasePipeline,
11831183
Cosmos2_5_PredictImage2World,
11841184
Cosmos2_5_PredictText2World,
11851185
Cosmos2_5_PredictVideo2World,

src/diffusers/models/transformers/transformer_cosmos.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -488,8 +488,7 @@ def __init__(
488488
hidden_size, patch_size[0] * patch_size[1] * patch_size[2] * out_channels, bias=False
489489
)
490490

491-
self.use_crossattn_projection = use_crossattn_projection
492-
if self.use_crossattn_projection:
491+
if self.config.use_crossattn_projection:
493492
self.crossattn_proj = nn.Sequential(
494493
nn.Linear(crossattn_proj_in_channels, encoder_hidden_states_channels, bias=True),
495494
nn.GELU(),

src/diffusers/pipelines/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -165,7 +165,7 @@
165165
_import_structure["cogview4"] = ["CogView4Pipeline", "CogView4ControlPipeline"]
166166
_import_structure["consisid"] = ["ConsisIDPipeline"]
167167
_import_structure["cosmos"] = [
168-
"Cosmos2_5_PredictBase",
168+
"Cosmos2_5_PredictBasePipeline",
169169
"Cosmos2_5_PredictImage2World",
170170
"Cosmos2_5_PredictText2World",
171171
"Cosmos2_5_PredictVideo2World",
@@ -626,7 +626,7 @@
626626
StableDiffusionXLControlNetXSPipeline,
627627
)
628628
from .cosmos import (
629-
Cosmos2_5_PredictBase,
629+
Cosmos2_5_PredictBasePipeline,
630630
Cosmos2_5_PredictImage2World,
631631
Cosmos2_5_PredictText2World,
632632
Cosmos2_5_PredictVideo2World,

src/diffusers/pipelines/cosmos/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
_dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
2424
else:
2525
_import_structure["pipeline_cosmos2_5_predict"] = [
26-
"Cosmos2_5_PredictBase",
26+
"Cosmos2_5_PredictBasePipeline",
2727
"Cosmos2_5_PredictImage2World",
2828
"Cosmos2_5_PredictText2World",
2929
"Cosmos2_5_PredictVideo2World",
@@ -42,7 +42,7 @@
4242
from ...utils.dummy_torch_and_transformers_objects import *
4343
else:
4444
from .pipeline_cosmos2_5_predict import (
45-
Cosmos2_5_PredictBase,
45+
Cosmos2_5_PredictBasePipeline,
4646
Cosmos2_5_PredictImage2World,
4747
Cosmos2_5_PredictText2World,
4848
Cosmos2_5_PredictVideo2World,

src/diffusers/pipelines/cosmos/pipeline_cosmos2_5_predict.py

Lines changed: 5 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -71,11 +71,11 @@ def retrieve_latents(
7171
Examples:
7272
```python
7373
>>> import torch
74-
>>> from diffusers import Cosmos2_5_PredictBase
74+
>>> from diffusers import Cosmos2_5_PredictBasePipeline
7575
>>> from diffusers.utils import export_to_video, load_image, load_video
7676
7777
>>> model_id = "nvidia/Cosmos-Predict2.5-Base-2B"
78-
>>> pipe = Cosmos2_5_PredictBase.from_pretrained(model_id, torch_dtype=torch.bfloat16)
78+
>>> pipe = Cosmos2_5_PredictBasePipeline.from_pretrained(model_id, torch_dtype=torch.bfloat16)
7979
>>> pipe = pipe.to("cuda")
8080
8181
>>> # Common negative prompt reused across modes.
@@ -163,7 +163,7 @@ def retrieve_latents(
163163
"""
164164

165165

166-
class Cosmos2_5_PredictBase(DiffusionPipeline):
166+
class Cosmos2_5_PredictBasePipeline(DiffusionPipeline):
167167
r"""
168168
Pipeline for [Cosmos Predict2.5](https://github.com/nvidia-cosmos/cosmos-predict2.5) base model.
169169
@@ -233,20 +233,6 @@ def __init__(
233233
if self.latents_mean is None or self.latents_std is None:
234234
raise ValueError("VAE configuration must define both `latents_mean` and `latents_std`.")
235235

236-
237-
@property
238-
def _execution_device(self):
239-
device = super()._execution_device
240-
if isinstance(device, torch.device) and device.type == "cpu":
241-
for module_name in ("transformer", "text_encoder", "vae"):
242-
module = getattr(self, module_name, None)
243-
if module is None or not isinstance(module, torch.nn.Module):
244-
continue
245-
module_device = getattr(module, "device", None)
246-
if isinstance(module_device, torch.device) and module_device.type != "cpu":
247-
return module_device
248-
return device
249-
250236
# Copied from diffusers.pipelines.cosmos.pipeline_cosmos_text2world.CosmosTextToWorldPipeline._get_prompt_embeds
251237
def _get_prompt_embeds(
252238
self,
@@ -398,6 +384,8 @@ def encode_prompt(
398384

399385
return prompt_embeds, negative_prompt_embeds
400386

387+
# Modified from diffusers.pipelines.cosmos.pipeline_cosmos2_video2world.Cosmos2VideoToWorldPipeline.prepare_latents and
388+
# diffusers.pipelines.cosmos.pipeline_cosmos2_text2image.Cosmos2TextToImagePipeline.prepare_latents
401389
def prepare_latents(
402390
self,
403391
video: Optional[torch.Tensor],
@@ -458,8 +446,6 @@ def prepare_latents(
458446

459447
cond_latents = torch.cat(cond_latents, dim=0).to(dtype)
460448

461-
if self.latents_mean is None or self.latents_std is None:
462-
raise ValueError("VAE configuration must define `latents_mean` and `latents_std`.")
463449
latents_mean = self.latents_mean.to(device=device, dtype=dtype)
464450
latents_std = self.latents_std.to(device=device, dtype=dtype)
465451
cond_latents = (cond_latents - latents_mean) / latents_std

tests/pipelines/cosmos/test_cosmos2_5_predict.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,12 @@
2222
import torch
2323
from transformers import AutoTokenizer, Qwen2VLForConditionalGeneration
2424

25-
from diffusers import AutoencoderKLWan, Cosmos2_5_PredictBase, CosmosTransformer3DModel, UniPCMultistepScheduler
25+
from diffusers import (
26+
AutoencoderKLWan,
27+
Cosmos2_5_PredictBasePipeline,
28+
CosmosTransformer3DModel,
29+
UniPCMultistepScheduler,
30+
)
2631

2732
from ...testing_utils import enable_full_determinism, torch_device
2833
from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS
@@ -33,7 +38,7 @@
3338
enable_full_determinism()
3439

3540

36-
class Cosmos2_5_PredictBaseWrapper(Cosmos2_5_PredictBase):
41+
class Cosmos2_5_PredictBaseWrapper(Cosmos2_5_PredictBasePipeline):
3742
@staticmethod
3843
def from_pretrained(*args, **kwargs):
3944
if "safety_checker" not in kwargs or kwargs["safety_checker"] is None:
@@ -42,7 +47,7 @@ def from_pretrained(*args, **kwargs):
4247
if isinstance(torch_dtype, torch.dtype):
4348
safety_checker = safety_checker.to(dtype=torch_dtype)
4449
kwargs["safety_checker"] = safety_checker
45-
return Cosmos2_5_PredictBase.from_pretrained(*args, **kwargs)
50+
return Cosmos2_5_PredictBasePipeline.from_pretrained(*args, **kwargs)
4651

4752

4853
class Cosmos2_5_PredictPipelineFastTests(PipelineTesterMixin, unittest.TestCase):

0 commit comments

Comments
 (0)