Skip to content

Commit f67639b

Browse files
add post init for safty checker (#12794)
* add post init for safty checker Signed-off-by: jiqing-feng <[email protected]> * check transformers version before post init Signed-off-by: jiqing-feng <[email protected]> * Apply style fixes --------- Signed-off-by: jiqing-feng <[email protected]> Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
1 parent 5a74319 commit f67639b

File tree

1 file changed

+4
-1
lines changed

1 file changed

+4
-1
lines changed

src/diffusers/pipelines/stable_diffusion/safety_checker.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
import torch.nn as nn
1818
from transformers import CLIPConfig, CLIPVisionModel, PreTrainedModel
1919

20-
from ...utils import logging
20+
from ...utils import is_transformers_version, logging
2121

2222

2323
logger = logging.get_logger(__name__)
@@ -46,6 +46,9 @@ def __init__(self, config: CLIPConfig):
4646

4747
self.concept_embeds_weights = nn.Parameter(torch.ones(17), requires_grad=False)
4848
self.special_care_embeds_weights = nn.Parameter(torch.ones(3), requires_grad=False)
49+
# Model requires post_init after transformers v4.57.3
50+
if is_transformers_version(">", "4.57.3"):
51+
self.post_init()
4952

5053
@torch.no_grad()
5154
def forward(self, clip_input, images):

0 commit comments

Comments
 (0)