Diffusers
sayakpaul HF Staff commited on
Commit
6d3538a
·
verified ·
1 Parent(s): b3deb52

Create canny_block.py

Browse files
Files changed (1) hide show
  1. canny_block.py +87 -0
canny_block.py ADDED
@@ -0,0 +1,87 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import List, Union
2
+ from PIL import Image
3
+ import torch
4
+ import numpy as np
5
+
6
+ from diffusers.modular_pipelines import (
7
+ PipelineState,
8
+ ModularPipelineBlocks,
9
+ InputParam,
10
+ ComponentSpec,
11
+ OutputParam,
12
+ )
13
+ from controlnet_aux import CannyDetector
14
+ import numpy as np
15
+
16
+
17
+ class CannyBlock(ModularPipelineBlocks):
18
+ @property
19
+ def expected_components(self):
20
+ return [
21
+ ComponentSpec(name="canny_annotator", type_hint=CannyDetector),
22
+ ]
23
+
24
+ @property
25
+ def inputs(self) -> List[InputParam]:
26
+ return [
27
+ InputParam(
28
+ "image",
29
+ type_hint=Union[Image.Image, np.ndarray],
30
+ required=True,
31
+ description="Image to compute canny filter on",
32
+ ),
33
+ InputParam(
34
+ "low_threshold",
35
+ type_hint=int,
36
+ default=50,
37
+ ),
38
+ InputParam("high_threshold", type_hint=int, default=200),
39
+ InputParam(
40
+ "detect_resolution",
41
+ type_hint=int,
42
+ default=1024,
43
+ description="Resolution to resize to when running the Canny filtering process.",
44
+ ),
45
+ InputParam(
46
+ "image_resolution",
47
+ type_hint=int,
48
+ default=1024,
49
+ description="Resolution to resize the detected Canny edge map to.",
50
+ ),
51
+ ]
52
+
53
+ @property
54
+ def intermediate_outputs(self) -> List[OutputParam]:
55
+ return [
56
+ OutputParam(
57
+ "canny_map",
58
+ type_hint=Image,
59
+ description="Canny map for input image",
60
+ )
61
+ ]
62
+
63
+ def compute_canny(self, components, image, low_threshold, high_threshold, detect_resolution, image_resolution):
64
+ canny_map = components.canny_annotator(
65
+ input_image=image,
66
+ low_threshold=low_threshold,
67
+ high_threshold=high_threshold,
68
+ detect_resolution=detect_resolution,
69
+ image_resolution=image_resolution,
70
+ )
71
+ return canny_map
72
+
73
+ @torch.no_grad()
74
+ def __call__(self, components, state: PipelineState) -> PipelineState:
75
+ block_state = self.get_block_state(state)
76
+
77
+ block_state.canny_map = self.compute_canny(
78
+ components,
79
+ block_state.image,
80
+ block_state.low_threshold,
81
+ block_state.high_threshold,
82
+ block_state.detect_resolution,
83
+ block_state.image_resolution,
84
+ )
85
+ self.set_block_state(state, block_state)
86
+
87
+ return components, state