Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion examples/common/common.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1475,7 +1475,7 @@ ArgOptions SDGenerationParams::get_options() {
on_high_noise_sample_method_arg},
{"",
"--scheduler",
"denoiser sigma scheduler, one of [discrete, karras, exponential, ays, gits, smoothstep, sgm_uniform, simple, kl_optimal, lcm, bong_tangent, ltx2, logit_normal, flux2, flux], alias: normal=discrete, default: model-specific",
"denoiser sigma scheduler, one of [discrete, karras, exponential, ays, gits, smoothstep, sgm_uniform, simple, kl_optimal, lcm, bong_tangent, ltx2, logit_normal, flux2, flux, beta], alias: normal=discrete, default: model-specific",
on_scheduler_arg},
{"",
"--sigmas",
Expand Down
1 change: 1 addition & 0 deletions include/stable-diffusion.h
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ enum scheduler_t {
LOGIT_NORMAL_SCHEDULER,
FLUX2_SCHEDULER,
FLUX_SCHEDULER,
BETA_SCHEDULER,
SCHEDULER_COUNT
};

Expand Down
135 changes: 135 additions & 0 deletions src/runtime/denoiser.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -302,6 +302,137 @@ struct KarrasScheduler : SigmaScheduler {
}
};

struct BetaScheduler : SigmaScheduler {
static constexpr double alpha = 0.6;
static constexpr double beta = 0.6;

static double log_beta(double a, double b) {
return std::lgamma(a) + std::lgamma(b) - std::lgamma(a + b);
}

static double incbeta(double x, double a, double b) {
if (x <= 0.0) {
return 0.0;
}
if (x >= 1.0) {
return 1.0;
}

// Continued fraction approximation using Lentz's method.
const int max_iter = 200;
const double epsilon = 3.0e-7;
const double tiny = 1e-30;

const double qab = a + b;
const double qap = a + 1.0;
const double qam = a - 1.0;

double c = 1.0;
double d = 1.0 - qab * x / qap;
if (std::abs(d) < tiny) {
d = tiny;
}
d = 1.0 / d;
double h = d;

for (int m = 1; m <= max_iter; m++) {
const int m2 = 2 * m;

double aa = m * (b - m) * x / ((qam + m2) * (a + m2));
d = 1.0 + aa * d;
if (std::abs(d) < tiny) {
d = tiny;
}
c = 1.0 + aa / c;
if (std::abs(c) < tiny) {
c = tiny;
}
d = 1.0 / d;
h *= d * c;

aa = -(a + m) * (qab + m) * x / ((a + m2) * (qap + m2));
d = 1.0 + aa * d;
if (std::abs(d) < tiny) {
d = tiny;
}
c = 1.0 + aa / c;
if (std::abs(c) < tiny) {
c = tiny;
}
d = 1.0 / d;
const double del = d * c;
h *= del;

if (std::abs(del - 1.0) < epsilon) {
break;
}
}

return std::exp(a * std::log(x) + b * std::log(1.0 - x) - log_beta(a, b)) / a * h;
}

static double beta_cdf(double x, double a, double b) {
if (x == 0.0) {
return 0.0;
}
if (x == 1.0) {
return 1.0;
}
if (x < (a + 1.0) / (a + b + 2.0)) {
return incbeta(x, a, b);
}
return 1.0 - incbeta(1.0 - x, b, a);
}

static double beta_ppf(double u, double a, double b, int max_iter = 30) {
double x = 0.5;
for (int i = 0; i < max_iter; i++) {
const double f = beta_cdf(x, a, b) - u;
if (std::abs(f) < 1e-10) {
break;
}
const double df = std::exp((a - 1.0) * std::log(x) + (b - 1.0) * std::log(1.0 - x) - log_beta(a, b));
x -= f / df;
if (x <= 0.0) {
x = 1e-10;
}
if (x >= 1.0) {
x = 1.0 - 1e-10;
}
}
return x;
}

std::vector<float> get_sigmas(uint32_t n, float /*sigma_min*/, float /*sigma_max*/, t_to_sigma_t t_to_sigma) override {
std::vector<float> result;
result.reserve(n + 1);

const int t_max = TIMESTEPS - 1;
if (n == 0) {
return result;
} else if (n == 1) {
result.push_back(t_to_sigma(static_cast<float>(t_max)));
result.push_back(0.f);
return result;
}

int last_t = -1;
for (uint32_t i = 0; i < n; i++) {
const double u = 1.0 - static_cast<double>(i) / static_cast<double>(n);
const double t_cont = beta_ppf(u, alpha, beta) * t_max;
const int t = static_cast<int>(std::lround(t_cont));

if (t != last_t) {
result.push_back(t_to_sigma(static_cast<float>(t)));
last_t = t;
}
}

result.push_back(0.f);
return result;
}
};

struct SimpleScheduler : SigmaScheduler {
std::vector<float> get_sigmas(uint32_t n, float sigma_min, float sigma_max, t_to_sigma_t t_to_sigma) override {
std::vector<float> result_sigmas;
Expand Down Expand Up @@ -895,6 +1026,10 @@ struct Denoiser {
LOG_INFO("get_sigmas with Karras scheduler");
scheduler = std::make_shared<KarrasScheduler>();
break;
case BETA_SCHEDULER:
LOG_INFO("get_sigmas with Beta scheduler");
scheduler = std::make_shared<BetaScheduler>();
break;
case EXPONENTIAL_SCHEDULER:
LOG_INFO("get_sigmas exponential scheduler");
scheduler = std::make_shared<ExponentialScheduler>();
Expand Down
1 change: 1 addition & 0 deletions src/stable-diffusion.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2562,6 +2562,7 @@ const char* scheduler_to_str[] = {
"logit_normal",
"flux2",
"flux",
"beta",
};

const char* sd_scheduler_name(enum scheduler_t scheduler) {
Expand Down
Loading