Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions examples/common/common.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -960,7 +960,7 @@ ArgOptions SDGenerationParams::get_options() {
&hires_upscaler},
{"",
"--extra-sample-args",
"extra sampler/scheduler/guidance args, key=value list. CFG supports guidance_schedule; APG supports apg_eta, apg_momentum, apg_norm_threshold, apg_norm_threshold_smoothing; SLG supports slg_uncond; lcm supports noise_clip_std, noise_scale_start, noise_scale_end; ltx2 supports max_shift, base_shift, stretch, terminal; euler_ge supports gamma;; logit_normal supports mu, std, logsnr_min, logsnr_max, resolution_aware",
"extra sampler/scheduler/guidance args, key=value list. CFG supports guidance_schedule; APG supports apg_eta, apg_momentum, apg_norm_threshold, apg_norm_threshold_smoothing; SLG supports slg_uncond; lcm supports noise_clip_std, noise_scale_start, noise_scale_end; flux supports base_shift, max_shift; ltx2 supports max_shift, base_shift, stretch, terminal; euler_ge supports gamma;; logit_normal supports mu, std, logsnr_min, logsnr_max, resolution_aware",
(int)',',
&extra_sample_args},
{"",
Expand Down Expand Up @@ -1475,7 +1475,7 @@ ArgOptions SDGenerationParams::get_options() {
on_high_noise_sample_method_arg},
{"",
"--scheduler",
"denoiser sigma scheduler, one of [discrete, karras, exponential, ays, gits, smoothstep, sgm_uniform, simple, kl_optimal, lcm, bong_tangent, ltx2, logit_normal, flux2], default: model-specific",
"denoiser sigma scheduler, one of [discrete, karras, exponential, ays, gits, smoothstep, sgm_uniform, simple, kl_optimal, lcm, bong_tangent, ltx2, logit_normal, flux2, flux], default: model-specific",
on_scheduler_arg},
{"",
"--sigmas",
Expand Down
1 change: 1 addition & 0 deletions include/stable-diffusion.h
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ enum scheduler_t {
LTX2_SCHEDULER,
LOGIT_NORMAL_SCHEDULER,
FLUX2_SCHEDULER,
FLUX_SCHEDULER,
SCHEDULER_COUNT
};

Expand Down
64 changes: 64 additions & 0 deletions src/runtime/denoiser.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -563,6 +563,65 @@ inline float flux_time_shift(float mu, float sigma, float t) {
return ::expf(mu) / (::expf(mu) + ::powf((1.0f / t - 1.0f), sigma));
}

// https://github.com/black-forest-labs/flux/blob/main/src/flux/sampling.py#L289
struct FluxScheduler : SigmaScheduler {
int image_seq_len = 0;
float base_shift = 0.5f;
float max_shift = 1.15f;

explicit FluxScheduler(int image_seq_len, const char* extra_sample_args = nullptr)
: image_seq_len(image_seq_len) {
parse_extra_sample_args(extra_sample_args);
}

void parse_extra_sample_args(const char* extra_sample_args) {
for (const auto& [key, value] : parse_key_value_args(extra_sample_args, "flux scheduler arg")) {
if (key == "base_shift") {
if (!parse_strict_float(value, base_shift)) {
LOG_WARN("ignoring invalid flux scheduler arg '%s=%s'", key.c_str(), value.c_str());
}
} else if (key == "max_shift") {
if (!parse_strict_float(value, max_shift)) {
LOG_WARN("ignoring invalid flux scheduler arg '%s=%s'", key.c_str(), value.c_str());
}
}
}
}

float compute_mu() const {
constexpr float base_shift_anchor = 256.0f;
constexpr float max_shift_anchor = 4096.0f;
float m = (max_shift - base_shift) / (max_shift_anchor - base_shift_anchor);
float b = base_shift - m * base_shift_anchor;
return static_cast<float>(image_seq_len) * m + b;
}

std::vector<float> get_sigmas(uint32_t n, float /*sigma_min*/, float /*sigma_max*/, t_to_sigma_t /*t_to_sigma*/) override {
std::vector<float> sigmas;
sigmas.reserve(n + 1);

float mu = compute_mu();
LOG_DEBUG("Flux scheduler: image_seq_len=%d, steps=%u, mu=%.3f", image_seq_len, n, mu);

if (n == 0) {
sigmas.push_back(1.0f);
return sigmas;
}

for (uint32_t i = 0; i <= n; ++i) {
float t = 1.0f - static_cast<float>(i) / static_cast<float>(n);
if (t <= 0.0f) {
sigmas.push_back(0.0f);
} else {
sigmas.push_back(flux_time_shift(mu, 1.0f, t));
}
}

sigmas[n] = 0.0f;
return sigmas;
}
};

// https://github.com/black-forest-labs/flux2/blob/main/src/flux2/sampling.py#L244
struct Flux2Scheduler : SigmaScheduler {
int image_seq_len = 0;
Expand Down Expand Up @@ -886,6 +945,11 @@ struct Denoiser {
scheduler = std::make_shared<Flux2Scheduler>(image_seq_len);
break;
}
case FLUX_SCHEDULER: {
LOG_INFO("get_sigmas with Flux scheduler");
scheduler = std::make_shared<FluxScheduler>(image_seq_len, extra_sample_args);
break;
}
default:
LOG_INFO("get_sigmas with discrete scheduler (default)");
scheduler = std::make_shared<DiscreteScheduler>();
Expand Down
3 changes: 3 additions & 0 deletions src/stable-diffusion.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2561,6 +2561,7 @@ const char* scheduler_to_str[] = {
"ltx2",
"logit_normal",
"flux2",
"flux",
};

const char* sd_scheduler_name(enum scheduler_t scheduler) {
Expand Down Expand Up @@ -3161,6 +3162,8 @@ enum scheduler_t sd_get_default_scheduler(const sd_ctx_t* sd_ctx, enum sample_me
return LCM_SCHEDULER;
} else if (sample_method == DDIM_TRAILING_SAMPLE_METHOD) {
return SIMPLE_SCHEDULER;
} else if (sd_ctx != nullptr && sd_ctx->sd != nullptr && sd_version_is_flux(sd_ctx->sd->version)) {
return FLUX_SCHEDULER;
} else if (sd_ctx != nullptr && sd_ctx->sd != nullptr && sd_version_is_flux2(sd_ctx->sd->version)) {
return FLUX2_SCHEDULER;
} else if (sd_ctx != nullptr && sd_ctx->sd != nullptr && sd_version_is_ltxav(sd_ctx->sd->version)) {
Expand Down
Loading