-
Notifications
You must be signed in to change notification settings - Fork 216
Description
🐞 Bug Description
When using TorchSharp's Module system with inheritance, if a derived class replaces a submodule (e.g., assigns a new Sequential to a field that the base class already defined and registered), calling RegisterComponents() again in the derived class does not update the structure reported by state_dict(). The state dictionary still returns the parameters of the original module registered in the base class.
✅ Expected Behavior
When creating two C2 objects with different configurations (one using a regular Conv, the other using a custom DSConv), the state_dict() should reflect the actual registered submodules.
📎 Reproduction Code (Minimal Example)
/// <summary>
/// Depthwise-separable convolution block used by the reproduction. It owns a
/// depthwise convolution (<c>dw</c>), a pointwise convolution (<c>pw</c>) and a
/// batch-norm layer (<c>bn</c>).
/// </summary>
public class DSConv : Module<Tensor, Tensor>
{
    private readonly Conv2d dw;
    private readonly Conv2d pw;
    private readonly BatchNorm2d bn;

    /// <summary>Creates the three layers and registers them as submodules.</summary>
    /// <param name="inChannels">Number of input channels.</param>
    /// <param name="outChannels">Number of output channels.</param>
    /// <param name="c">Used here as the depthwise kernel size (callers pass 3) - TODO confirm against the original model.</param>
    public DSConv(int inChannels, int outChannels, int c) :
        base(nameof(DSConv))
    {
        // RegisterComponents() picks submodules up from the field values, so the
        // fields must hold live modules before it runs; unassigned they are null
        // and this block would add no entries to state_dict().
        // bias: false keeps each convolution down to a single "weight" entry.
        dw = Conv2d(inChannels, inChannels, c, groups: inChannels, bias: false);
        pw = Conv2d(inChannels, outChannels, 1, bias: false);
        bn = BatchNorm2d(outChannels);
        RegisterComponents();
    }

    /// <summary>
    /// Placeholder forward pass: returns <paramref name="input"/> unchanged. The
    /// reproduction only inspects state_dict(), so the layers are not applied.
    /// </summary>
    public override Tensor forward(Tensor input)
    {
        return input;
    }
}
/// <summary>
/// Plain convolution block used by the reproduction. It owns a convolution
/// (<c>conv</c>) and a batch-norm layer (<c>bn</c>).
/// </summary>
public class Conv : Module<Tensor, Tensor>
{
    private readonly Conv2d conv;
    private readonly BatchNorm2d bn;

    /// <summary>Creates both layers and registers them as submodules.</summary>
    /// <param name="inChannels">Number of input channels.</param>
    /// <param name="outChannels">Number of output channels.</param>
    /// <param name="c">Used here as the convolution kernel size (callers pass 3) - TODO confirm against the original model.</param>
    public Conv(int inChannels, int outChannels, int c) :
        base(nameof(Conv))
    {
        // RegisterComponents() picks submodules up from the field values, so the
        // fields must hold live modules before it runs; unassigned they are null
        // and this block would add no entries to state_dict().
        // bias: false keeps the convolution down to a single "weight" entry.
        conv = Conv2d(inChannels, outChannels, c, bias: false);
        bn = BatchNorm2d(outChannels);
        RegisterComponents();
    }

    /// <summary>
    /// Placeholder forward pass: returns <paramref name="input"/> unchanged. The
    /// reproduction only inspects state_dict(), so the layers are not applied.
    /// </summary>
    public override Tensor forward(Tensor input)
    {
        return input;
    }
}
/// <summary>
/// Base module of the reproduction. It exposes a public <see cref="m"/> container
/// that holds a single <c>Conv</c> block with kernel argument 3.
/// </summary>
public class C1 : Module<Tensor, Tensor>
{
    public Sequential m;

    /// <summary>Builds <see cref="m"/> and then calls RegisterComponents().</summary>
    /// <param name="inChannels">Forwarded to the <c>Conv</c> block.</param>
    /// <param name="outChannels">Forwarded to the <c>Conv</c> block.</param>
    public C1(int inChannels, int outChannels) : base(nameof(C1))
    {
        // Create the container and attach its only child in one expression.
        m = Sequential().append(new Conv(inChannels, outChannels, 3));
        RegisterComponents();
    }

    /// <summary>Placeholder forward pass: returns the input unchanged.</summary>
    public override Tensor forward(Tensor input) => input;
}
/// <summary>
/// Derived module of the reproduction. After the base constructor has run, it
/// replaces the inherited <c>m</c> container with a new one holding either a
/// <c>Conv</c> or a <c>DSConv</c> block, then calls RegisterComponents() again.
/// </summary>
public class C2 : C1
{
    /// <param name="inChannels">Forwarded to the base class and to the child block.</param>
    /// <param name="outChannels">Forwarded to the base class and to the child block.</param>
    /// <param name="dsc3k">When true a <c>Conv</c> block is used; when false a <c>DSConv</c> block.</param>
    public C2(int inChannels, int outChannels, bool dsc3k)
        : base(inChannels, outChannels)
    {
        // Fresh container that takes the place of the one built by C1.
        m = Sequential();

        // Choose the single child block for the new container.
        Module<Tensor, Tensor> block = dsc3k
            ? (Module<Tensor, Tensor>)new Conv(inChannels, outChannels, 3)
            : new DSConv(inChannels, outChannels, 3);
        m.append(block);

        // Second registration pass; C1's constructor already ran the first one.
        RegisterComponents();
    }
}
class Program
{
// Builds one C2 per configuration and prints the state_dict() keys of each,
// separated by a "======" line.
static void Main(string[] args)
{
    var convVariant = new C2(3, 3, true);
    var dsConvVariant = new C2(3, 3, false);

    foreach (var key in convVariant.state_dict().Keys)
    {
        Console.WriteLine(key);
    }

    Console.WriteLine("======");

    foreach (var key in dsConvVariant.state_dict().Keys)
    {
        Console.WriteLine(key);
    }
}
}
📌 Actual Output
m.0.conv.weight
m.0.bn.weight
m.0.bn.bias
m.0.bn.running_mean
m.0.bn.running_var
m.0.bn.num_batches_tracked
======
m.0.conv.weight
m.0.bn.weight
m.0.bn.bias
m.0.bn.running_mean
m.0.bn.running_var
m.0.bn.num_batches_tracked
✅ Expected Output
m.0.conv.weight
m.0.bn.weight
m.0.bn.bias
m.0.bn.running_mean
m.0.bn.running_var
m.0.bn.num_batches_tracked
======
m.0.dw.weight
m.0.pw.weight
m.0.bn.weight
m.0.bn.bias
m.0.bn.running_mean
m.0.bn.running_var
m.0.bn.num_batches_tracked
🔍 Root Cause
In the base class C1, RegisterComponents() registers the initial instance of m. In the derived class C2, the m field is reassigned to a new Sequential (the field is replaced, not overridden), but even after RegisterComponents() is called again, the original Sequential registered by C1 remains in the module registry under the name m.
This leads to an incorrect state_dict() that still reflects the structure of the base class, not the new submodules created in the derived class.