RegisterComponents() does not register overridden submodules correctly in derived class #1485

@dogvane

🐞 Bug Description
When a TorchSharp Module hierarchy uses inheritance and a derived class replaces a submodule that the base class already created and registered (for example, by assigning a new Sequential to the same field), calling RegisterComponents() again in the derived class does not update the registration. state_dict() still returns the parameters of the original module registered by the base class.

✅ Expected Behavior
When two C2 objects are created with different configurations (one using the regular Conv, the other using the custom DSConv), state_dict() for each object should reflect the submodules that object actually holds.

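📎 Reproduction Code (Minimal Example)

    using System;
    using System.Linq;
    using TorchSharp;
    using TorchSharp.Modules;
    using static TorchSharp.torch;
    using static TorchSharp.torch.nn;
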
    public class DSConv : Module<Tensor, Tensor>
    {
        private readonly Conv2d dw;
        private readonly Conv2d pw;
        private readonly BatchNorm2d bn;

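        public DSConv(int inChannels, int outChannels, int c):
            base(nameof(DSConv))
        {
            // Assumed layer setup: depthwise conv (kernel size c), then a 1x1 pointwise conv.
            dw = Conv2d(inChannels, inChannels, c, groups: inChannels, bias: false);
            pw = Conv2d(inChannels, outChannels, 1, bias: false);
            bn = BatchNorm2d(outChannels);
            RegisterComponents();
        }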

        public override Tensor forward(Tensor input)
        {
            return input;
        }
    }

    public class Conv : Module<Tensor, Tensor>
    {
        private readonly Conv2d conv;
        private readonly BatchNorm2d bn;

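        public Conv(int inChannels, int outChannels, int c) :
            base(nameof(Conv))
        {
            // Assumed layer setup: c is used as the kernel size.
            conv = Conv2d(inChannels, outChannels, c, bias: false);
            bn = BatchNorm2d(outChannels);
            RegisterComponents();
        }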

        public override Tensor forward(Tensor input)
        {
            return input;
        }
    }

    public class C1 : Module<Tensor, Tensor>
    {
        public Sequential m;
        public C1(int inChannels, int outChannels) : base(nameof(C1))
        {
            m = Sequential();

            m = m.append(new Conv(inChannels, outChannels, 3));

            RegisterComponents();
        }

        public override Tensor forward(Tensor input)
        {
            return input;
        }
    }

    public class C2 : C1
    {
        public C2(int inChannels, int outChannels, bool dsc3k)
            : base(inChannels, outChannels)
        {
            // Create a new Sequential to replace the m built (and registered) by C1
            m = Sequential();

            if (dsc3k)
            {
                m.append(new Conv(inChannels, outChannels, 3));
            }
            else
            {
                m.append(new DSConv(inChannels, outChannels, 3));
            }
            RegisterComponents();
        }
    }
    class Program
    {
        static void Main(string[] args)
        {

            var c2_1 = new C2(3, 3, true);

            var c2_2 = new C2(3, 3, false);

            c2_1.state_dict().Keys.ToList().ForEach(k => Console.WriteLine(k));

            Console.WriteLine("======");

            c2_2.state_dict().Keys.ToList().ForEach(k => Console.WriteLine(k));

        }    
    }

📌 Actual Output

m.0.conv.weight
m.0.bn.weight
m.0.bn.bias
m.0.bn.running_mean
m.0.bn.running_var
m.0.bn.num_batches_tracked
======
m.0.conv.weight
m.0.bn.weight
m.0.bn.bias
m.0.bn.running_mean
m.0.bn.running_var
m.0.bn.num_batches_tracked

✅ Expected Output

m.0.conv.weight
m.0.bn.weight
m.0.bn.bias
m.0.bn.running_mean
m.0.bn.running_var
m.0.bn.num_batches_tracked
======
m.0.dw.weight
m.0.pw.weight
m.0.bn.weight
m.0.bn.bias
m.0.bn.running_mean
m.0.bn.running_var
m.0.bn.num_batches_tracked

🔍 Root Cause
In the base class C1, RegisterComponents() registers the initial instance of m. The derived class C2 then assigns a new Sequential to the m field, but calling RegisterComponents() a second time does not replace the entry: the old components from C1 remain in the module registry.

As a result, state_dict() still reflects the structure built by the base class rather than the submodules created in the derived class.
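One possible way to sidestep the problem is to build the final Sequential before anything is registered, so that RegisterComponents() runs exactly once, after m has its final value. The sketch below is only an illustration and has not been verified against a specific TorchSharp version. The names C1Alt, C2Alt and buildBody are hypothetical, and Conv and DSConv are assumed to be the classes from the reproduction code above.

```csharp
using System;
using TorchSharp;
using TorchSharp.Modules;
using static TorchSharp.torch;
using static TorchSharp.torch.nn;

// Hypothetical variant of C1: the Sequential comes from a factory delegate,
// so a derived class can choose its contents before registration happens.
public class C1Alt : Module<Tensor, Tensor>
{
    public readonly Sequential m;

    // Default configuration, equivalent to what C1 builds in the reproduction.
    public C1Alt(int inChannels, int outChannels)
        : this(() => Sequential().append(new Conv(inChannels, outChannels, 3)))
    {
    }

    // buildBody is a hypothetical hook: it returns the Sequential to register.
    protected C1Alt(Func<Sequential> buildBody) : base(nameof(C1Alt))
    {
        m = buildBody();
        RegisterComponents(); // the only registration pass; m is already final
    }

    public override Tensor forward(Tensor input)
    {
        return m.forward(input);
    }
}

// Hypothetical variant of C2: it never reassigns m after the base constructor.
public class C2Alt : C1Alt
{
    public C2Alt(int inChannels, int outChannels, bool dsc3k)
        : base(() => dsc3k
            ? Sequential().append(new Conv(inChannels, outChannels, 3))
            : Sequential().append(new DSConv(inChannels, outChannels, 3)))
    {
    }
}
```

With this layout the derived class does not need a second RegisterComponents() call, so whether a repeated call replaces earlier entries no longer matters. It does not address the underlying behavior reported here; it only avoids triggering it.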

Labels: bug
