Skip to content

Commit

Permalink
Linting, normalization of some variables, and cleanup.
Browse files Browse the repository at this point in the history
  • Loading branch information
wbrown committed Aug 25, 2022
1 parent 375da42 commit d0b0b30
Showing 1 changed file with 21 additions and 15 deletions.
36 changes: 21 additions & 15 deletions src/stability_sdk/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,14 +80,14 @@ def process_artifacts_from_answers(
contents = artifact.binary
elif artifact.type == generation.ARTIFACT_CLASSIFICATIONS:
ext = ".pb.json"
contents = MessageToJson(artifact.classifier).encode('utf-8')
contents = MessageToJson(artifact.classifier).encode("utf-8")
elif artifact.type == generation.ARTIFACT_TEXT:
ext = ".pb.json"
contents = MessageToJson(artifact).encode('utf-8')
contents = MessageToJson(artifact).encode("utf-8")
else:
ext = ".pb"
contents = artifact.SerializeToString()
out_p = f"{artifact_p}{ext}"
out_p = f"{artifact_p}{ext}"
if write:
with open(out_p, "wb") as f:
f.write(bytes(contents))
Expand Down Expand Up @@ -130,7 +130,7 @@ def __init__(
host: str = "grpc.stability.ai:443",
key: str = "",
engine: str = "stable-diffusion-v1",
verbose=False,
verbose: bool = False,
wait_for_ready: bool = True,
):
"""
Expand Down Expand Up @@ -247,13 +247,19 @@ def generate(
duration = time.time() - start
if self.verbose:
if len(answer.artifacts) > 0:
artifact_ts = [generation.ArtifactType.Name(artifact.type)
for artifact in answer.artifacts]
logger.info(f"Got {answer.answer_id} with {artifact_ts} in "
f"{duration:0.2f}s")
artifact_ts = [
generation.ArtifactType.Name(artifact.type)
for artifact in answer.artifacts
]
logger.info(
f"Got {answer.answer_id} with {artifact_ts} in "
f"{duration:0.2f}s"
)
else:
logger.info(f"Got keepalive {answer.answer_id} in "
f"{duration:0.2f}s")
logger.info(
f"Got keepalive {answer.answer_id} in "
f"{duration:0.2f}s"
)

yield answer
start = time.time()
Expand Down Expand Up @@ -283,17 +289,17 @@ def build_request_dict(cli_args: Namespace) -> Dict[str, Any]:
fh.setFormatter(fh_formatter)
logger.addHandler(fh)

INFERENCE_HOST = os.getenv("STABILITY_HOST", "grpc.stability.ai:443")
STABILITY_HOST = os.getenv("STABILITY_HOST", "grpc.stability.ai:443")
STABILITY_KEY = os.getenv("STABILITY_KEY", "")

if not INFERENCE_HOST:
if not STABILITY_HOST:
logger.warning("STABILITY_HOST environment variable needs to be set.")
sys.exit(1)

if not STABILITY_KEY:
logger.warning(
"STABILITY_KEY environment variable needs to be set. You may"
" need to login to the Stability website to obtain your"
" need to login to the Stability website to obtain the"
" API key."
)
sys.exit(1)
Expand Down Expand Up @@ -324,8 +330,8 @@ def build_request_dict(cli_args: Namespace) -> Dict[str, Any]:
"--prefix",
"-p",
type=str,
help="output prefixes for artifacts",
default="generation",
help="output prefixes for artifacts",
)
parser.add_argument(
"--no-store", action="store_true", help="do not write out artifacts"
Expand Down Expand Up @@ -354,7 +360,7 @@ def build_request_dict(cli_args: Namespace) -> Dict[str, Any]:
request = build_request_dict(args)

stability_api = StabilityInference(
INFERENCE_HOST, STABILITY_KEY, engine=args.engine, verbose=True
STABILITY_HOST, STABILITY_KEY, engine=args.engine, verbose=True
)

answers = stability_api.generate(args.prompt, **request)
Expand Down

0 comments on commit d0b0b30

Please sign in to comment.