jupyter notebook
class _BackboneClassifier(nn.Module):
    """Feature-extractor backbone followed by a classification head.

    Defined at module level (not inside ``build_model``) so a whole model can
    be pickled with ``torch.save(model)`` and ``isinstance`` checks work
    across ``build_model`` calls. Submodule names ``backbone``/``head`` are
    kept so state_dict keys stay ``backbone.*`` / ``head.*``.

    Args:
        backbone: Module mapping images to pooled features, shaped either
            ``(B, C)`` or ``(B, C, 1, 1)``; the head's leading ``nn.Flatten``
            accepts both.
        head: Module mapping flattened features to logits.
        freeze_backbone: When True, the backbone is kept in eval mode even
            after ``train()`` is called, so BatchNorm running statistics stay
            fixed alongside the ``requires_grad=False`` weights.
    """

    def __init__(self, backbone: nn.Module, head: nn.Module, freeze_backbone: bool = False):
        super().__init__()
        self.backbone = backbone
        self.head = head
        self.freeze_backbone = freeze_backbone
        if freeze_backbone:
            self.backbone.eval()

    def train(self, mode: bool = True) -> "_BackboneClassifier":
        super().train(mode)
        if self.freeze_backbone:
            # requires_grad=False alone does not stop BatchNorm from updating
            # its running mean/var in train mode; eval mode does.
            self.backbone.eval()
        return self

    def forward(self, x):
        # Both backbones already emit pooled features: the timm model is built
        # with num_classes=0 and global_pool="avg", and the truncated ResNet
        # ends in its AdaptiveAvgPool2d layer.
        return self.head(self.backbone(x))


def build_model(num_classes: int, arch: str = "xception", dropout: float = 0.2, freeze_backbone: bool = True) -> nn.Module:
    """Build an ImageNet-pretrained image classifier with a fresh head.

    When ``arch`` is ``"xception"``, a timm Xception backbone is tried first.
    If timm is missing or model creation or weight download fails, the function
    falls back to torchvision's ResNet-50 (IMAGENET1K_V2 weights). Any other
    ``arch`` value also uses ResNet-50; values other than ``"resnet50"``
    print a warning. The head attached on top is
    ``Flatten -> Linear(in_feats, 1024) -> ReLU -> Dropout -> Linear(1024, num_classes)``.

    Args:
        num_classes: Number of output logits.
        arch: ``"xception"`` (default) or ``"resnet50"``, matched
            case-insensitively.
        dropout: Dropout probability applied before the final linear layer.
        freeze_backbone: If True, backbone parameters get
            ``requires_grad=False`` and the backbone is kept in eval mode, so
            only the head is trained.

    Returns:
        The assembled model, moved to the module-level ``DEVICE``.
    """
    arch_key = arch.lower()
    backbone = None
    in_feats = None

    if arch_key == "xception":
        try:
            import timm

            # num_classes=0 + global_pool="avg" makes forward() return pooled
            # (B, num_features) features instead of logits.
            candidate = timm.create_model("xception", pretrained=True, num_classes=0, global_pool="avg")
            in_feats = candidate.num_features
            # Commit only once both steps succeeded, so a partial failure
            # cannot leave backbone set with in_feats still None.
            backbone = candidate
        except Exception as e:  # best effort: missing timm, unknown model name, download failure
            print(f"[warn] timm/xception not available ({e}). Falling back to torchvision resnet50.")
    elif arch_key != "resnet50":
        print(f"[warn] unknown arch {arch!r}. Falling back to torchvision resnet50.")

    if backbone is None:
        m = tvm.resnet50(weights=tvm.ResNet50_Weights.IMAGENET1K_V2)
        in_feats = m.fc.in_features
        # Drop the final fc layer; the output ends at avgpool: (B, 2048, 1, 1).
        backbone = nn.Sequential(*list(m.children())[:-1])

    if freeze_backbone:
        for p in backbone.parameters():
            p.requires_grad = False

    head = nn.Sequential(
        nn.Flatten(),
        nn.Linear(in_feats, 1024),
        nn.ReLU(inplace=True),
        nn.Dropout(p=dropout),
        nn.Linear(1024, num_classes),
    )

    return _BackboneClassifier(backbone, head, freeze_backbone=freeze_backbone).to(DEVICE)