diff --git a/controllers/nvidiadriver_controller.go b/controllers/nvidiadriver_controller.go
index 3dea19e24..19762a633 100644
--- a/controllers/nvidiadriver_controller.go
+++ b/controllers/nvidiadriver_controller.go
@@ -248,6 +248,33 @@ func (r *NVIDIADriverReconciler) updateCrStatus(
 	return nil
 }
 
+// enqueueAllNVIDIADrivers lists all NVIDIADriver instances and returns one reconcile request per
+// instance. The watch map functions call it so that a single relevant event (e.g. a ClusterPolicy,
+// NVIDIADriver or Node event) reconciles every instance. A list error is logged and yields no requests.
+func (r *NVIDIADriverReconciler) enqueueAllNVIDIADrivers(ctx context.Context) []reconcile.Request {
+	logger := log.FromContext(ctx)
+	list := &nvidiav1alpha1.NVIDIADriverList{}
+
+	err := r.List(ctx, list)
+	if err != nil {
+		logger.Error(err, "Unable to list NVIDIADriver resources")
+		return []reconcile.Request{}
+	}
+
+	reconcileRequests := make([]reconcile.Request, 0, len(list.Items))
+	for _, nvidiaDriver := range list.Items {
+		reconcileRequests = append(reconcileRequests,
+			reconcile.Request{
+				NamespacedName: types.NamespacedName{
+					Name:      nvidiaDriver.GetName(),
+					Namespace: nvidiaDriver.GetNamespace(),
+				},
+			})
+	}
+
+	return reconcileRequests
+}
+
 // SetupWithManager sets up the controller with the Manager.
 func (r *NVIDIADriverReconciler) SetupWithManager(ctx context.Context, mgr ctrl.Manager) error {
 	// Create state manager
@@ -277,11 +304,17 @@ func (r *NVIDIADriverReconciler) SetupWithManager(ctx context.Context, mgr ctrl.
 		return err
 	}
 
-	// Watch for changes to the primary resource NVIDIaDriver
+	// Watch for changes to NVIDIADriver CRs. Whenever an event is generated for an NVIDIADriver CR,
+	// enqueue a reconcile request for all NVIDIADriver instances.
+	nvidiaDriverMapFn := func(ctx context.Context, _ *nvidiav1alpha1.NVIDIADriver) []reconcile.Request {
+		return r.enqueueAllNVIDIADrivers(ctx)
+	}
+
+	// Watch for changes to the primary resource NVIDIADriver
 	err = c.Watch(source.Kind(
 		mgr.GetCache(),
 		&nvidiav1alpha1.NVIDIADriver{},
-		&handler.TypedEnqueueRequestForObject[*nvidiav1alpha1.NVIDIADriver]{},
+		handler.TypedEnqueueRequestsFromMapFunc(nvidiaDriverMapFn),
 		predicate.TypedGenerationChangedPredicate[*nvidiav1alpha1.NVIDIADriver]{},
 	),
 	)
@@ -291,63 +324,21 @@ func (r *NVIDIADriverReconciler) SetupWithManager(ctx context.Context, mgr ctrl.
 
 	// Watch for changes to ClusterPolicy. Whenever an event is generated for ClusterPolicy, enqueue
 	// a reconcile request for all NVIDIADriver instances.
-	mapFn := func(ctx context.Context, cp *gpuv1.ClusterPolicy) []reconcile.Request {
-		logger := log.FromContext(ctx)
-		opts := []client.ListOption{}
-		list := &nvidiav1alpha1.NVIDIADriverList{}
-
-		err := mgr.GetClient().List(ctx, list, opts...)
-		if err != nil {
-			logger.Error(err, "Unable to list NVIDIADriver resources")
-			return []reconcile.Request{}
-		}
-
-		reconcileRequests := []reconcile.Request{}
-		for _, nvidiaDriver := range list.Items {
-			reconcileRequests = append(reconcileRequests,
-				reconcile.Request{
-					NamespacedName: types.NamespacedName{
-						Name:      nvidiaDriver.GetName(),
-						Namespace: nvidiaDriver.GetNamespace(),
-					},
-				})
-		}
-
-		return reconcileRequests
+	mapFn := func(ctx context.Context, _ *gpuv1.ClusterPolicy) []reconcile.Request {
+		return r.enqueueAllNVIDIADrivers(ctx)
 	}
 
-	// Watch for changes to the Nodes. Whenever an event is generated for ClusterPolicy, enqueue
+	// Watch for changes to the Nodes. Whenever an event is generated for a Node, enqueue
 	// a reconcile request for all NVIDIADriver instances.
-	nodeMapFn := func(ctx context.Context, cp *corev1.Node) []reconcile.Request {
-		logger := log.FromContext(ctx)
-		opts := []client.ListOption{}
-		list := &nvidiav1alpha1.NVIDIADriverList{}
-
-		err := mgr.GetClient().List(ctx, list, opts...)
-		if err != nil {
-			logger.Error(err, "Unable to list NVIDIADriver resources")
-			return []reconcile.Request{}
-		}
-
-		reconcileRequests := []reconcile.Request{}
-		for _, nvidiaDriver := range list.Items {
-			reconcileRequests = append(reconcileRequests,
-				reconcile.Request{
-					NamespacedName: types.NamespacedName{
-						Name:      nvidiaDriver.GetName(),
-						Namespace: nvidiaDriver.GetNamespace(),
-					},
-				})
-		}
-
-		return reconcileRequests
+	nodeMapFn := func(ctx context.Context, _ *corev1.Node) []reconcile.Request {
+		return r.enqueueAllNVIDIADrivers(ctx)
 	}
 
 	err = c.Watch(
 		source.Kind(
 			mgr.GetCache(),
 			&gpuv1.ClusterPolicy{},
-			handler.TypedEnqueueRequestsFromMapFunc[*gpuv1.ClusterPolicy](mapFn),
+			handler.TypedEnqueueRequestsFromMapFunc(mapFn),
 			predicate.TypedGenerationChangedPredicate[*gpuv1.ClusterPolicy]{},
 		),
 	)
@@ -385,7 +376,7 @@ func (r *NVIDIADriverReconciler) SetupWithManager(ctx context.Context, mgr ctrl.
 	err = c.Watch(
 		source.Kind(mgr.GetCache(),
 			&corev1.Node{},
-			handler.TypedEnqueueRequestsFromMapFunc[*corev1.Node](nodeMapFn),
+			handler.TypedEnqueueRequestsFromMapFunc(nodeMapFn),
 			nodePredicate,
 		),
 	)
diff --git a/controllers/nvidiadriver_controller_test.go b/controllers/nvidiadriver_controller_test.go
index 922c6ed67..f89dfc16e 100644
--- a/controllers/nvidiadriver_controller_test.go
+++ b/controllers/nvidiadriver_controller_test.go
@@ -20,6 +20,7 @@ import (
 	"bytes"
 	"context"
 	"errors"
+	"sort"
 	"testing"
 
 	"github.com/go-logr/logr"
@@ -193,3 +194,24 @@ func TestReconcile(t *testing.T) {
 		})
 	}
 }
+
+func TestEnqueueAllNVIDIADrivers(t *testing.T) {
+	scheme := runtime.NewScheme()
+	require.NoError(t, nvidiav1alpha1.AddToScheme(scheme))
+
+	client := fake.NewClientBuilder().WithScheme(scheme).WithObjects(
+		&nvidiav1alpha1.NVIDIADriver{ObjectMeta: metav1.ObjectMeta{Name: "driver-a", Namespace: "default"}},
+		&nvidiav1alpha1.NVIDIADriver{ObjectMeta: metav1.ObjectMeta{Name: "driver-b", Namespace: "default"}},
+	).Build()
+
+	reconciler := &NVIDIADriverReconciler{Client: client}
+	requests := reconciler.enqueueAllNVIDIADrivers(context.Background())
+
+	require.Len(t, requests, 2)
+	got := []string{
+		requests[0].String(),
+		requests[1].String(),
+	}
+	sort.Strings(got)
+	require.Equal(t, []string{"default/driver-a", "default/driver-b"}, got)
+}