AdeelH/pytorch-multi-class-focal-loss

Open more actions menu

Folders and files

NameName
Last commit message
Last commit date

Latest commit

 

History

23 Commits
 
 
 
 
 
 
 
 
 
 
 
 

Repository files navigation

DOI

Multi-class Focal Loss

An (unofficial) implementation of Focal Loss, as described in the RetinaNet paper, https://arxiv.org/abs/1708.02002, generalized to the multi-class case.

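It is essentially an extension of cross-entropy loss and is useful for classification tasks with a large class imbalance. It scales each example's cross-entropy term by a factor that shrinks as the model's confidence in the true class grows. This underweights easy, well-classified examples and focuses training on hard ones.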
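
For a single example, the paper defines focal loss as FL(p_t) = -alpha_t (1 - p_t)^gamma log(p_t). Here p_t is the model's predicted probability for the true class, gamma >= 0 is the focusing parameter, and alpha_t is an optional class weight. The plain-Python sketch below compares this with cross-entropy, -log(p_t), for a few values of p_t. The helper names are hypothetical and used only for illustration.

```python
import math


def focal_term(p_t: float, gamma: float = 2.0, alpha_t: float = 1.0) -> float:
    """Focal loss for one example, given p_t, the predicted probability of
    the true class (hypothetical helper, for illustration only)."""
    # (1 - p_t) ** gamma is the modulating factor: close to 0 for confident,
    # correct predictions and close to 1 for badly misclassified ones.
    return -alpha_t * (1.0 - p_t) ** gamma * math.log(p_t)


def cross_entropy_term(p_t: float) -> float:
    """Plain cross-entropy for one example (hypothetical helper)."""
    return -math.log(p_t)


# Compare the two losses for an easy, a borderline, and a hard example.
for p_t in (0.9, 0.5, 0.1):
    ce = cross_entropy_term(p_t)
    fl = focal_term(p_t, gamma=2.0)
    print(f"p_t={p_t:.1f}  CE={ce:.4f}  FL={fl:.4f}  FL/CE={fl / ce:.2f}")
```

The ratio FL/CE is exactly (1 - p_t)^gamma. With gamma = 2, it is 0.01 for a confident p_t = 0.9 but 0.81 for a poor p_t = 0.1, so easy examples contribute far less to the total loss. Setting gamma = 0 and alpha_t = 1 recovers plain cross-entropy.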

Usage

  • FocalLoss is an nn.Module and behaves much like nn.CrossEntropyLoss, i.e. it:

    • supports the reduction and ignore_index params, and
    • works with 2D inputs of shape (N, C) as well as K-dimensional inputs of shape (N, C, d1, d2, ..., dK).
  • Example usage

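    # Assumes focal_loss.py from this repo is importable, and that alpha
    # (per-class weights, shape (C,)) and gamma are already defined.
    from focal_loss import FocalLoss

    focal_loss = FocalLoss(alpha, gamma)
    ...
    inp, targets = batch
    out = model(inp)  # raw logits of shape (N, C) or (N, C, d1, ..., dK)
    loss = focal_loss(out, targets)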
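
The nn.CrossEntropyLoss-style behaviour described above can be illustrated with a minimal from-scratch sketch. It is not this repo's implementation. The class name `SketchFocalLoss` is hypothetical, and the sketch assumes PyTorch is installed. It follows nn.CrossEntropyLoss's shape conventions for inputs and targets. It flattens K-dimensional inputs to one row per element, drops ignore_index positions, and applies the focal term to each remaining element.

```python
from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F


class SketchFocalLoss(nn.Module):
    """Hypothetical minimal multi-class focal loss (NOT this repo's FocalLoss).

    Follows nn.CrossEntropyLoss conventions: raw logits of shape (N, C) or
    (N, C, d1, ..., dK) and integer class targets of shape (N,) or (N, d1, ..., dK).
    """

    def __init__(self, alpha: Optional[torch.Tensor] = None, gamma: float = 0.0,
                 reduction: str = "mean", ignore_index: int = -100):
        super().__init__()
        if reduction not in ("mean", "sum", "none"):
            raise ValueError(f"unsupported reduction: {reduction!r}")
        self.alpha = alpha  # optional per-class weights, shape (C,)
        self.gamma = gamma  # focusing parameter; gamma = 0 gives cross-entropy
        self.reduction = reduction
        self.ignore_index = ignore_index

    def forward(self, logits: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        num_classes = logits.shape[1]
        if logits.ndim > 2:
            # Move the class dim last, then flatten: (N, C, d1..dK) -> (N*d1*...*dK, C)
            logits = logits.movedim(1, -1).reshape(-1, num_classes)
        target = target.reshape(-1)

        # Drop positions labelled ignore_index, as nn.CrossEntropyLoss does.
        keep = target != self.ignore_index
        logits, target = logits[keep], target[keep]

        # log p_t: log-probability assigned to each element's true class.
        log_pt = F.log_softmax(logits, dim=-1).gather(1, target.unsqueeze(1)).squeeze(1)
        pt = log_pt.exp()

        # Focal term: -(1 - p_t)^gamma * log(p_t), optionally scaled by alpha_t.
        loss = -((1.0 - pt) ** self.gamma) * log_pt
        if self.alpha is not None:
            alpha = self.alpha.to(device=loss.device, dtype=loss.dtype)
            loss = alpha[target] * loss

        if self.reduction == "mean":
            # Plain average over kept elements; returns 0 if everything was ignored.
            return loss.mean() if loss.numel() > 0 else loss.sum()
        if self.reduction == "sum":
            return loss.sum()
        return loss  # "none": flat tensor over the non-ignored positions


# Usage on a segmentation-style batch: 2 images, 3 classes, 4x4 pixels.
logits = torch.randn(2, 3, 4, 4, requires_grad=True)
targets = torch.randint(0, 3, (2, 4, 4))
criterion = SketchFocalLoss(alpha=torch.tensor([.5, .25, .25]), gamma=2.0)
loss = criterion(logits, targets)
loss.backward()
```

With gamma=0 and no alpha, this sketch matches F.cross_entropy. Two design choices differ from nn.CrossEntropyLoss. First, with alpha given, the sketch's "mean" is a plain average over non-ignored elements, whereas nn.CrossEntropyLoss with weight divides by the sum of the weights. Second, reduction="none" returns a flat tensor rather than one shaped like the targets.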

Loading through torch.hub

This repo supports loading modules through torch.hub. For example, FocalLoss can be loaded into your code via:

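import torch

focal_loss = torch.hub.load(
    'adeelh/pytorch-multi-class-focal-loss',
    model='FocalLoss',
    alpha=torch.tensor([.75, .25]),  # per-class weights for 2 classes
    gamma=2,
    reduction='mean',
    force_reload=False
)
# 10 samples, 2 classes: logits of shape (N, C), integer targets of shape (N,)
x, y = torch.randn(10, 2), (torch.rand(10) > .5).long()
loss = focal_loss(x, y)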

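Or, using the focal_loss entry point, which accepts alpha as a plain list along with device and dtype arguments: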

import torch

focal_loss = torch.hub.load(
    'adeelh/pytorch-multi-class-focal-loss',
    model='focal_loss',
    alpha=[.75, .25],
    gamma=2,
    reduction='mean',
    device='cpu',
    dtype=torch.float32,
    force_reload=False
)
x, y = torch.randn(10, 2), (torch.rand(10) > .5).long()
loss = focal_loss(x, y)