# Source: GitHub gist by @azkalot1 (gist id dafa0f2b3ff5b306fe00c958e611139f),
# created December 18, 2020 05:31.
# Callback pipeline for the segmentation run.
#
# Two criterion callbacks each compute one partial loss against the "mask"
# target and publish it under its own prefix ("loss_dice", "loss_bce").
# `criterion_key` presumably selects an entry from the `criterion` passed to
# runner.train, so that object is expected to be a dict with "dice" and "bce"
# keys -- NOTE(review): confirm against where `criterion` is built.
#
# The partial losses are then folded into a single "loss" value as a weighted
# sum (dice weight 1.0, bce weight 0.8). Finally an IoU metric is computed in
# binary mode against the same "mask" target.
callbacks = [
    CriterionCallback(input_key="mask", prefix="loss_dice", criterion_key="dice"),
    CriterionCallback(input_key="mask", prefix="loss_bce", criterion_key="bce"),
    # Combine the partial losses listed above into one scalar named "loss".
    MetricAggregationCallback(
        prefix="loss",
        mode="weighted_sum",
        metrics={"loss_dice": 1.0, "loss_bce": 0.8},
    ),
    # Evaluation metric (does not contribute to the loss).
    IoUMetricsCallback(mode="binary", input_key="mask"),
]
# Supervised runner: the model input is read from the batch entry named
# "features" and the target from "mask" (presumably each batch is a dict with
# those keys -- confirm against the loaders).
runner = dl.SupervisedRunner(
    input_key="features",
    input_target_key="mask",
)

runner.train(
    # Training components, all built elsewhere in the script.
    model=model,
    criterion=criterion,
    optimizer=optimizer,
    scheduler=scheduler,
    loaders=loaders,
    callbacks=callbacks,
    # Output location is relative to the current working directory.
    logdir="../logs/xray_test_log",
    num_epochs=100,
    # "loss" is the aggregated weighted-sum value produced by the
    # MetricAggregationCallback(prefix="loss"); lower is treated as better.
    main_metric="loss",
    minimize_metric=True,
    verbose=True,
)
# (End of gist. The GitHub page footer prompting sign-in to comment has been
# turned into this comment so the file remains valid Python.)