diff --git a/Makefile b/Makefile index 23c088646..3979b80cf 100644 --- a/Makefile +++ b/Makefile @@ -4,6 +4,7 @@ gen-proto: protoc api/v1/runtime/common.proto --go_out=paths=source_relative:. protoc api/v1/runtime/runtime_agent_api.proto --go_out=paths=source_relative:. --go-grpc_out=require_unimplemented_servers=false:. --go-grpc_opt=paths=source_relative protoc api/v1/kube/kube_api.proto --go_out=paths=source_relative:. --go-grpc_out=require_unimplemented_servers=false:. --go-grpc_opt=paths=source_relative + protoc api/v1/proxy/proxy.proto --go_out=paths=source_relative:. --go-grpc_out=require_unimplemented_servers=false:. --go-grpc_opt=paths=source_relative UNAME_M ?= $(shell uname -m) diff --git a/api/v1/proxy/proxy.pb.go b/api/v1/proxy/proxy.pb.go new file mode 100644 index 000000000..d85678def --- /dev/null +++ b/api/v1/proxy/proxy.pb.go @@ -0,0 +1,394 @@ +// Code generated by protoc-gen-go. DO NOT EDIT. +// versions: +// protoc-gen-go v1.36.9 +// protoc v6.32.0 +// source: api/v1/proxy/proxy.proto + +package proxy + +import ( + protoreflect "google.golang.org/protobuf/reflect/protoreflect" + protoimpl "google.golang.org/protobuf/runtime/protoimpl" + reflect "reflect" + sync "sync" + unsafe "unsafe" +) + +const ( + // Verify that this generated code is sufficiently up-to-date. + _ = protoimpl.EnforceVersion(20 - protoimpl.MinVersion) + // Verify that runtime/protoimpl is sufficiently up-to-date. + _ = protoimpl.EnforceVersion(protoimpl.MaxVersion - 20) +) + +type SubscribeRequest struct { + state protoimpl.MessageState `protogen:"open.v1"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *SubscribeRequest) Reset() { + *x = SubscribeRequest{} + mi := &file_api_v1_proxy_proxy_proto_msgTypes[0] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *SubscribeRequest) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*SubscribeRequest) ProtoMessage() {} + +func (x *SubscribeRequest) ProtoReflect() protoreflect.Message { + mi := &file_api_v1_proxy_proxy_proto_msgTypes[0] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use SubscribeRequest.ProtoReflect.Descriptor instead. +func (*SubscribeRequest) Descriptor() ([]byte, []int) { + return file_api_v1_proxy_proxy_proto_rawDescGZIP(), []int{0} +} + +type SendResponseResult struct { + state protoimpl.MessageState `protogen:"open.v1"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *SendResponseResult) Reset() { + *x = SendResponseResult{} + mi := &file_api_v1_proxy_proxy_proto_msgTypes[1] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *SendResponseResult) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*SendResponseResult) ProtoMessage() {} + +func (x *SendResponseResult) ProtoReflect() protoreflect.Message { + mi := &file_api_v1_proxy_proxy_proto_msgTypes[1] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use SendResponseResult.ProtoReflect.Descriptor instead. +func (*SendResponseResult) Descriptor() ([]byte, []int) { + return file_api_v1_proxy_proxy_proto_rawDescGZIP(), []int{1} +} + +type HttpRequest struct { + state protoimpl.MessageState `protogen:"open.v1"` + RequestId string `protobuf:"bytes,1,opt,name=request_id,json=requestId,proto3" json:"request_id,omitempty"` + Method string `protobuf:"bytes,2,opt,name=method,proto3" json:"method,omitempty"` + Path string `protobuf:"bytes,3,opt,name=path,proto3" json:"path,omitempty"` + Headers []*Header `protobuf:"bytes,4,rep,name=headers,proto3" json:"headers,omitempty"` + Body []byte `protobuf:"bytes,5,opt,name=body,proto3" json:"body,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *HttpRequest) Reset() { + *x = HttpRequest{} + mi := &file_api_v1_proxy_proxy_proto_msgTypes[2] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *HttpRequest) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*HttpRequest) ProtoMessage() {} + +func (x *HttpRequest) ProtoReflect() protoreflect.Message { + mi := &file_api_v1_proxy_proxy_proto_msgTypes[2] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use HttpRequest.ProtoReflect.Descriptor instead. +func (*HttpRequest) Descriptor() ([]byte, []int) { + return file_api_v1_proxy_proxy_proto_rawDescGZIP(), []int{2} +} + +func (x *HttpRequest) GetRequestId() string { + if x != nil { + return x.RequestId + } + return "" +} + +func (x *HttpRequest) GetMethod() string { + if x != nil { + return x.Method + } + return "" +} + +func (x *HttpRequest) GetPath() string { + if x != nil { + return x.Path + } + return "" +} + +func (x *HttpRequest) GetHeaders() []*Header { + if x != nil { + return x.Headers + } + return nil +} + +func (x *HttpRequest) GetBody() []byte { + if x != nil { + return x.Body + } + return nil +} + +type HttpResponse struct { + state protoimpl.MessageState `protogen:"open.v1"` + RequestId string `protobuf:"bytes,1,opt,name=request_id,json=requestId,proto3" json:"request_id,omitempty"` + StatusCode int32 `protobuf:"varint,2,opt,name=status_code,json=statusCode,proto3" json:"status_code,omitempty"` + Headers []*Header `protobuf:"bytes,3,rep,name=headers,proto3" json:"headers,omitempty"` + Body []byte `protobuf:"bytes,4,opt,name=body,proto3" json:"body,omitempty"` + More bool `protobuf:"varint,5,opt,name=more,proto3" json:"more,omitempty"` + Error string `protobuf:"bytes,6,opt,name=error,proto3" json:"error,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *HttpResponse) Reset() { + *x = HttpResponse{} + mi := &file_api_v1_proxy_proxy_proto_msgTypes[3] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *HttpResponse) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*HttpResponse) ProtoMessage() {} + +func (x *HttpResponse) ProtoReflect() protoreflect.Message { + mi := &file_api_v1_proxy_proxy_proto_msgTypes[3] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use HttpResponse.ProtoReflect.Descriptor instead. +func (*HttpResponse) Descriptor() ([]byte, []int) { + return file_api_v1_proxy_proxy_proto_rawDescGZIP(), []int{3} +} + +func (x *HttpResponse) GetRequestId() string { + if x != nil { + return x.RequestId + } + return "" +} + +func (x *HttpResponse) GetStatusCode() int32 { + if x != nil { + return x.StatusCode + } + return 0 +} + +func (x *HttpResponse) GetHeaders() []*Header { + if x != nil { + return x.Headers + } + return nil +} + +func (x *HttpResponse) GetBody() []byte { + if x != nil { + return x.Body + } + return nil +} + +func (x *HttpResponse) GetMore() bool { + if x != nil { + return x.More + } + return false +} + +func (x *HttpResponse) GetError() string { + if x != nil { + return x.Error + } + return "" +} + +type Header struct { + state protoimpl.MessageState `protogen:"open.v1"` + Key string `protobuf:"bytes,1,opt,name=key,proto3" json:"key,omitempty"` + Values []string `protobuf:"bytes,2,rep,name=values,proto3" json:"values,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *Header) Reset() { + *x = Header{} + mi := &file_api_v1_proxy_proxy_proto_msgTypes[4] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *Header) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*Header) ProtoMessage() {} + +func (x *Header) ProtoReflect() protoreflect.Message { + mi := &file_api_v1_proxy_proxy_proto_msgTypes[4] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use Header.ProtoReflect.Descriptor instead. +func (*Header) Descriptor() ([]byte, []int) { + return file_api_v1_proxy_proxy_proto_rawDescGZIP(), []int{4} +} + +func (x *Header) GetKey() string { + if x != nil { + return x.Key + } + return "" +} + +func (x *Header) GetValues() []string { + if x != nil { + return x.Values + } + return nil +} + +var File_api_v1_proxy_proxy_proto protoreflect.FileDescriptor + +const file_api_v1_proxy_proxy_proto_rawDesc = "" + + "\n" + + "\x18api/v1/proxy/proxy.proto\x12\bproxy.v1\"\x12\n" + + "\x10SubscribeRequest\"\x14\n" + + "\x12SendResponseResult\"\x98\x01\n" + + "\vHttpRequest\x12\x1d\n" + + "\n" + + "request_id\x18\x01 \x01(\tR\trequestId\x12\x16\n" + + "\x06method\x18\x02 \x01(\tR\x06method\x12\x12\n" + + "\x04path\x18\x03 \x01(\tR\x04path\x12*\n" + + "\aheaders\x18\x04 \x03(\v2\x10.proxy.v1.HeaderR\aheaders\x12\x12\n" + + "\x04body\x18\x05 \x01(\fR\x04body\"\xb8\x01\n" + + "\fHttpResponse\x12\x1d\n" + + "\n" + + "request_id\x18\x01 \x01(\tR\trequestId\x12\x1f\n" + + "\vstatus_code\x18\x02 \x01(\x05R\n" + + "statusCode\x12*\n" + + "\aheaders\x18\x03 \x03(\v2\x10.proxy.v1.HeaderR\aheaders\x12\x12\n" + + "\x04body\x18\x04 \x01(\fR\x04body\x12\x12\n" + + "\x04more\x18\x05 \x01(\bR\x04more\x12\x14\n" + + "\x05error\x18\x06 \x01(\tR\x05error\"2\n" + + "\x06Header\x12\x10\n" + + "\x03key\x18\x01 \x01(\tR\x03key\x12\x16\n" + + "\x06values\x18\x02 \x03(\tR\x06values2\x9f\x01\n" + + "\x0fKubernetesProxy\x12B\n" + + "\tSubscribe\x12\x1a.proxy.v1.SubscribeRequest\x1a\x15.proxy.v1.HttpRequest\"\x000\x01\x12H\n" + + "\fSendResponse\x12\x16.proxy.v1.HttpResponse\x1a\x1c.proxy.v1.SendResponseResult\"\x00(\x01B'Z%github.com/castai/kvisor/api/v1/proxyb\x06proto3" + +var ( + file_api_v1_proxy_proxy_proto_rawDescOnce sync.Once + file_api_v1_proxy_proxy_proto_rawDescData []byte +) + +func file_api_v1_proxy_proxy_proto_rawDescGZIP() []byte { + file_api_v1_proxy_proxy_proto_rawDescOnce.Do(func() { + file_api_v1_proxy_proxy_proto_rawDescData = protoimpl.X.CompressGZIP(unsafe.Slice(unsafe.StringData(file_api_v1_proxy_proxy_proto_rawDesc), len(file_api_v1_proxy_proxy_proto_rawDesc))) + }) + return file_api_v1_proxy_proxy_proto_rawDescData +} + +var file_api_v1_proxy_proxy_proto_msgTypes = make([]protoimpl.MessageInfo, 5) +var file_api_v1_proxy_proxy_proto_goTypes = []any{ + (*SubscribeRequest)(nil), // 0: proxy.v1.SubscribeRequest + (*SendResponseResult)(nil), // 1: proxy.v1.SendResponseResult + (*HttpRequest)(nil), // 2: proxy.v1.HttpRequest + (*HttpResponse)(nil), // 3: proxy.v1.HttpResponse + (*Header)(nil), // 4: proxy.v1.Header +} +var file_api_v1_proxy_proxy_proto_depIdxs = []int32{ + 4, // 0: proxy.v1.HttpRequest.headers:type_name -> proxy.v1.Header + 4, // 1: proxy.v1.HttpResponse.headers:type_name -> proxy.v1.Header + 0, // 2: proxy.v1.KubernetesProxy.Subscribe:input_type -> proxy.v1.SubscribeRequest + 3, // 3: proxy.v1.KubernetesProxy.SendResponse:input_type -> proxy.v1.HttpResponse + 2, // 4: proxy.v1.KubernetesProxy.Subscribe:output_type -> proxy.v1.HttpRequest + 1, // 5: proxy.v1.KubernetesProxy.SendResponse:output_type -> proxy.v1.SendResponseResult + 4, // [4:6] is the sub-list for method output_type + 2, // [2:4] is the sub-list for method input_type + 2, // [2:2] is the sub-list for extension type_name + 2, // [2:2] is the sub-list for extension extendee + 0, // [0:2] is the sub-list for field type_name +} + +func init() { file_api_v1_proxy_proxy_proto_init() } +func file_api_v1_proxy_proxy_proto_init() { + if File_api_v1_proxy_proxy_proto != nil { + return + } + type x struct{} + out := protoimpl.TypeBuilder{ + File: protoimpl.DescBuilder{ + GoPackagePath: reflect.TypeOf(x{}).PkgPath(), + RawDescriptor: unsafe.Slice(unsafe.StringData(file_api_v1_proxy_proxy_proto_rawDesc), len(file_api_v1_proxy_proxy_proto_rawDesc)), + NumEnums: 0, + NumMessages: 5, + NumExtensions: 0, + NumServices: 1, + }, + GoTypes: file_api_v1_proxy_proxy_proto_goTypes, + DependencyIndexes: file_api_v1_proxy_proxy_proto_depIdxs, + MessageInfos: file_api_v1_proxy_proxy_proto_msgTypes, + }.Build() + File_api_v1_proxy_proxy_proto = out.File + file_api_v1_proxy_proxy_proto_goTypes = nil + file_api_v1_proxy_proxy_proto_depIdxs = nil +} diff --git a/api/v1/proxy/proxy.proto b/api/v1/proxy/proxy.proto new file mode 100644 index 000000000..65590de74 --- /dev/null +++ b/api/v1/proxy/proxy.proto @@ -0,0 +1,36 @@ +syntax = "proto3"; + +package proxy.v1; + +option go_package = "github.com/castai/kvisor/api/v1/proxy"; + +service KubernetesProxy { + rpc Subscribe(SubscribeRequest) returns (stream HttpRequest) {} + rpc SendResponse(stream HttpResponse) returns (SendResponseResult) {} +} + +message SubscribeRequest {} + +message SendResponseResult {} + +message HttpRequest { + string request_id = 1; + string method = 2; + string path = 3; + repeated Header headers = 4; + bytes body = 5; +} + +message HttpResponse { + string request_id = 1; + int32 status_code = 2; + repeated Header headers = 3; + bytes body = 4; + bool more = 5; + string error = 6; +} + +message Header { + string key = 1; + repeated string values = 2; +} diff --git a/api/v1/proxy/proxy_grpc.pb.go b/api/v1/proxy/proxy_grpc.pb.go new file mode 100644 index 000000000..49e6766eb --- /dev/null +++ b/api/v1/proxy/proxy_grpc.pb.go @@ -0,0 +1,153 @@ +// Code generated by protoc-gen-go-grpc. DO NOT EDIT. +// versions: +// - protoc-gen-go-grpc v1.5.1 +// - protoc v6.32.0 +// source: api/v1/proxy/proxy.proto + +package proxy + +import ( + context "context" + grpc "google.golang.org/grpc" + codes "google.golang.org/grpc/codes" + status "google.golang.org/grpc/status" +) + +// This is a compile-time assertion to ensure that this generated file +// is compatible with the grpc package it is being compiled against. +// Requires gRPC-Go v1.64.0 or later. +const _ = grpc.SupportPackageIsVersion9 + +const ( + KubernetesProxy_Subscribe_FullMethodName = "/proxy.v1.KubernetesProxy/Subscribe" + KubernetesProxy_SendResponse_FullMethodName = "/proxy.v1.KubernetesProxy/SendResponse" +) + +// KubernetesProxyClient is the client API for KubernetesProxy service. +// +// For semantics around ctx use and closing/ending streaming RPCs, please refer to https://pkg.go.dev/google.golang.org/grpc/?tab=doc#ClientConn.NewStream. +type KubernetesProxyClient interface { + Subscribe(ctx context.Context, in *SubscribeRequest, opts ...grpc.CallOption) (grpc.ServerStreamingClient[HttpRequest], error) + SendResponse(ctx context.Context, opts ...grpc.CallOption) (grpc.ClientStreamingClient[HttpResponse, SendResponseResult], error) +} + +type kubernetesProxyClient struct { + cc grpc.ClientConnInterface +} + +func NewKubernetesProxyClient(cc grpc.ClientConnInterface) KubernetesProxyClient { + return &kubernetesProxyClient{cc} +} + +func (c *kubernetesProxyClient) Subscribe(ctx context.Context, in *SubscribeRequest, opts ...grpc.CallOption) (grpc.ServerStreamingClient[HttpRequest], error) { + cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...) + stream, err := c.cc.NewStream(ctx, &KubernetesProxy_ServiceDesc.Streams[0], KubernetesProxy_Subscribe_FullMethodName, cOpts...) + if err != nil { + return nil, err + } + x := &grpc.GenericClientStream[SubscribeRequest, HttpRequest]{ClientStream: stream} + if err := x.ClientStream.SendMsg(in); err != nil { + return nil, err + } + if err := x.ClientStream.CloseSend(); err != nil { + return nil, err + } + return x, nil +} + +// This type alias is provided for backwards compatibility with existing code that references the prior non-generic stream type by name. +type KubernetesProxy_SubscribeClient = grpc.ServerStreamingClient[HttpRequest] + +func (c *kubernetesProxyClient) SendResponse(ctx context.Context, opts ...grpc.CallOption) (grpc.ClientStreamingClient[HttpResponse, SendResponseResult], error) { + cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...) + stream, err := c.cc.NewStream(ctx, &KubernetesProxy_ServiceDesc.Streams[1], KubernetesProxy_SendResponse_FullMethodName, cOpts...) + if err != nil { + return nil, err + } + x := &grpc.GenericClientStream[HttpResponse, SendResponseResult]{ClientStream: stream} + return x, nil +} + +// This type alias is provided for backwards compatibility with existing code that references the prior non-generic stream type by name. +type KubernetesProxy_SendResponseClient = grpc.ClientStreamingClient[HttpResponse, SendResponseResult] + +// KubernetesProxyServer is the server API for KubernetesProxy service. +// All implementations should embed UnimplementedKubernetesProxyServer +// for forward compatibility. +type KubernetesProxyServer interface { + Subscribe(*SubscribeRequest, grpc.ServerStreamingServer[HttpRequest]) error + SendResponse(grpc.ClientStreamingServer[HttpResponse, SendResponseResult]) error +} + +// UnimplementedKubernetesProxyServer should be embedded to have +// forward compatible implementations. +// +// NOTE: this should be embedded by value instead of pointer to avoid a nil +// pointer dereference when methods are called. +type UnimplementedKubernetesProxyServer struct{} + +func (UnimplementedKubernetesProxyServer) Subscribe(*SubscribeRequest, grpc.ServerStreamingServer[HttpRequest]) error { + return status.Errorf(codes.Unimplemented, "method Subscribe not implemented") +} +func (UnimplementedKubernetesProxyServer) SendResponse(grpc.ClientStreamingServer[HttpResponse, SendResponseResult]) error { + return status.Errorf(codes.Unimplemented, "method SendResponse not implemented") +} +func (UnimplementedKubernetesProxyServer) testEmbeddedByValue() {} + +// UnsafeKubernetesProxyServer may be embedded to opt out of forward compatibility for this service. +// Use of this interface is not recommended, as added methods to KubernetesProxyServer will +// result in compilation errors. +type UnsafeKubernetesProxyServer interface { + mustEmbedUnimplementedKubernetesProxyServer() +} + +func RegisterKubernetesProxyServer(s grpc.ServiceRegistrar, srv KubernetesProxyServer) { + // If the following call pancis, it indicates UnimplementedKubernetesProxyServer was + // embedded by pointer and is nil. This will cause panics if an + // unimplemented method is ever invoked, so we test this at initialization + // time to prevent it from happening at runtime later due to I/O. + if t, ok := srv.(interface{ testEmbeddedByValue() }); ok { + t.testEmbeddedByValue() + } + s.RegisterService(&KubernetesProxy_ServiceDesc, srv) +} + +func _KubernetesProxy_Subscribe_Handler(srv interface{}, stream grpc.ServerStream) error { + m := new(SubscribeRequest) + if err := stream.RecvMsg(m); err != nil { + return err + } + return srv.(KubernetesProxyServer).Subscribe(m, &grpc.GenericServerStream[SubscribeRequest, HttpRequest]{ServerStream: stream}) +} + +// This type alias is provided for backwards compatibility with existing code that references the prior non-generic stream type by name. +type KubernetesProxy_SubscribeServer = grpc.ServerStreamingServer[HttpRequest] + +func _KubernetesProxy_SendResponse_Handler(srv interface{}, stream grpc.ServerStream) error { + return srv.(KubernetesProxyServer).SendResponse(&grpc.GenericServerStream[HttpResponse, SendResponseResult]{ServerStream: stream}) +} + +// This type alias is provided for backwards compatibility with existing code that references the prior non-generic stream type by name. +type KubernetesProxy_SendResponseServer = grpc.ClientStreamingServer[HttpResponse, SendResponseResult] + +// KubernetesProxy_ServiceDesc is the grpc.ServiceDesc for KubernetesProxy service. +// It's only intended for direct use with grpc.RegisterService, +// and not to be introspected or modified (even as a copy) +var KubernetesProxy_ServiceDesc = grpc.ServiceDesc{ + ServiceName: "proxy.v1.KubernetesProxy", + HandlerType: (*KubernetesProxyServer)(nil), + Methods: []grpc.MethodDesc{}, + Streams: []grpc.StreamDesc{ + { + StreamName: "Subscribe", + Handler: _KubernetesProxy_Subscribe_Handler, + ServerStreams: true, + }, + { + StreamName: "SendResponse", + Handler: _KubernetesProxy_SendResponse_Handler, + ClientStreams: true, + }, + }, + Metadata: "api/v1/proxy/proxy.proto", +} diff --git a/charts/kvisor/templates/controller.yaml b/charts/kvisor/templates/controller.yaml index e59b8a018..604a57692 100644 --- a/charts/kvisor/templates/controller.yaml +++ b/charts/kvisor/templates/controller.yaml @@ -86,6 +86,10 @@ spec: {{- if .Values.controller.netflow.staticCIDRs.mappings }} - "--cloud-provider-static-cidrs-file=/etc/kvisor/static-cidrs/static-cidrs.yaml" {{- end }} + {{- if .Values.controller.kubeProxy.enabled }} + - "--kube-proxy-enabled=true" + - "--kube-proxy-restricted-sa-name={{ .Values.controller.kubeProxy.restrictedServiceAccount.name }}" + {{- end }} {{- range $key, $value := .Values.controller.extraArgs }} - "--{{ $key }}={{ $value }}" {{- end }} @@ -335,6 +339,16 @@ rules: - get - list - watch + {{- if .Values.controller.kubeProxy.enabled }} + - apiGroups: + - "" + resources: + - serviceaccounts/token + resourceNames: + - {{ .Values.controller.kubeProxy.restrictedServiceAccount.name }} + verbs: + - create + {{- end }} --- apiVersion: rbac.authorization.k8s.io/v1 kind: ClusterRoleBinding diff --git a/charts/kvisor/templates/proxy-rbac.yaml b/charts/kvisor/templates/proxy-rbac.yaml new file mode 100644 index 000000000..07f00f57e --- /dev/null +++ b/charts/kvisor/templates/proxy-rbac.yaml @@ -0,0 +1,93 @@ +{{- if and .Values.controller.enabled .Values.controller.kubeProxy.enabled }} +apiVersion: v1 +kind: ServiceAccount +metadata: + name: {{ .Values.controller.kubeProxy.restrictedServiceAccount.name }} + namespace: {{ .Release.Namespace }} + labels: + {{- include "kvisor.labels" . | nindent 4 }} +--- +apiVersion: rbac.authorization.k8s.io/v1 +kind: ClusterRole +metadata: + name: {{ include "kvisor.fullname" . }}-proxy + labels: + {{- include "kvisor.labels" . | nindent 4 }} +rules: + - apiGroups: [""] + resources: + - pods + - pods/log + - pods/status + - services + - configmaps + - namespaces + - nodes + - endpoints + - events + - persistentvolumeclaims + - persistentvolumes + - resourcequotas + - serviceaccounts + - limitranges + - replicationcontrollers + verbs: ["get", "list", "watch"] + - apiGroups: ["apps"] + resources: + - deployments + - replicasets + - daemonsets + - statefulsets + verbs: ["get", "list", "watch"] + - apiGroups: ["batch"] + resources: + - jobs + - cronjobs + verbs: ["get", "list", "watch"] + - apiGroups: ["networking.k8s.io"] + resources: + - ingresses + - ingressclasses + - networkpolicies + verbs: ["get", "list", "watch"] + - apiGroups: ["rbac.authorization.k8s.io"] + resources: + - roles + - rolebindings + - clusterroles + - clusterrolebindings + verbs: ["get", "list", "watch"] + - apiGroups: ["storage.k8s.io"] + resources: + - storageclasses + - volumeattachments + verbs: ["get", "list", "watch"] + - apiGroups: ["autoscaling"] + resources: + - horizontalpodautoscalers + verbs: ["get", "list", "watch"] + - apiGroups: ["policy"] + resources: + - poddisruptionbudgets + verbs: ["get", "list", "watch"] + - apiGroups: ["metrics.k8s.io"] + resources: + - pods + - nodes + verbs: ["get", "list"] +--- +apiVersion: rbac.authorization.k8s.io/v1 +kind: ClusterRoleBinding +metadata: + name: {{ include "kvisor.fullname" . }}-proxy + labels: + {{- include "kvisor.labels" . | nindent 4 }} +roleRef: + apiGroup: rbac.authorization.k8s.io + kind: ClusterRole + name: {{ include "kvisor.fullname" . }}-proxy +subjects: + - kind: ServiceAccount + name: {{ .Values.controller.kubeProxy.restrictedServiceAccount.name }} + namespace: {{ .Release.Namespace }} +{{- end }} diff --git a/charts/kvisor/values.yaml b/charts/kvisor/values.yaml index 1bc33da04..4dd5280a4 100644 --- a/charts/kvisor/values.yaml +++ b/charts/kvisor/values.yaml @@ -282,6 +282,11 @@ controller: # Additional environment variables for the controller container via configMaps or secrets. envFrom: [] + kubeProxy: + enabled: false + restrictedServiceAccount: + name: "kvisor-proxy" + # Deprecated: use additionalEnv instead. extraEnv: {} diff --git a/cmd/controller/app/app.go b/cmd/controller/app/app.go index 9e2ed41ce..7d26405c9 100644 --- a/cmd/controller/app/app.go +++ b/cmd/controller/app/app.go @@ -20,10 +20,14 @@ import ( "github.com/samber/lo" "golang.org/x/sync/errgroup" "google.golang.org/grpc" + authv1 "k8s.io/api/authentication/v1" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" "k8s.io/client-go/informers" "k8s.io/client-go/kubernetes" + "k8s.io/client-go/rest" kubepb "github.com/castai/kvisor/api/v1/kube" + proxypb "github.com/castai/kvisor/api/v1/proxy" "github.com/castai/kvisor/cmd/controller/config" "github.com/castai/kvisor/cmd/controller/controllers" "github.com/castai/kvisor/cmd/controller/controllers/imagescan" @@ -33,21 +37,22 @@ import ( "github.com/castai/kvisor/pkg/blobscache" "github.com/castai/kvisor/pkg/castai" "github.com/castai/kvisor/pkg/cloudprovider" + "github.com/castai/kvisor/pkg/kubeproxy" "github.com/castai/logging" "github.com/castai/logging/components" ) -func New(cfg config.Config, clientset kubernetes.Interface) *App { +func New(cfg config.Config, clientset kubernetes.Interface, kubeConfig *rest.Config) *App { if err := validator.New().Struct(cfg); err != nil { panic(fmt.Errorf("invalid config: %w", err).Error()) } - return &App{cfg: cfg, kubeClient: clientset} + return &App{cfg: cfg, kubeClient: clientset, kubeConfig: kubeConfig} } type App struct { - cfg config.Config - + cfg config.Config kubeClient kubernetes.Interface + kubeConfig *rest.Config } func parseLogLevel(lvlStr string) (slog.Level, error) { @@ -230,6 +235,20 @@ func (a *App) Run(ctx context.Context) error { return kubeBenchCtrl.Run(ctx) }) } + + if cfg.KubeProxy.Enabled { + proxyClient, err := setupKubeProxy(log, cfg, castaiClient, clientset, a.kubeConfig) + if err != nil { + log.Errorf("failed to setup kube proxy: %v", err) + } else { + errg.Go(func() error { + if err := proxyClient.Run(ctx); err != nil && !errors.Is(err, context.Canceled) { + log.Errorf("kube proxy client stopped: %v", err) + } + return nil + }) + } + } } errg.Go(func() error { @@ -372,3 +391,38 @@ func (a *App) runMetricsHTTPServer(ctx context.Context, log *logging.Logger) err return nil } + +func setupKubeProxy(log *logging.Logger, cfg config.Config, castaiClient *castai.Client, clientset kubernetes.Interface, restCfg *rest.Config) (*kubeproxy.Client, error) { + expSeconds := cfg.KubeProxy.TokenExpirationSeconds + tokenProvider := kubeproxy.NewTokenProvider(kubeproxy.TokenProviderConfig{ + CreateToken: func(ctx context.Context) (string, time.Time, error) { + treq, err := clientset.CoreV1().ServiceAccounts(cfg.KubeProxy.RestrictedSANamespace).CreateToken( + ctx, + cfg.KubeProxy.RestrictedSAName, + &authv1.TokenRequest{ + Spec: authv1.TokenRequestSpec{ + ExpirationSeconds: &expSeconds, + }, + }, + metav1.CreateOptions{}, + ) + if err != nil { + return "", time.Time{}, fmt.Errorf("creating token for restricted SA: %w", err) + } + return treq.Status.Token, treq.Status.ExpirationTimestamp.Time, nil + }, + }) + + baseTransport, err := rest.TransportFor(restCfg) + if err != nil { + return nil, fmt.Errorf("building k8s transport: %w", err) + } + + httpClient := &http.Client{ + Transport: kubeproxy.NewTokenRoundTripper(tokenProvider, baseTransport), + Timeout: 30 * time.Second, + } + + proxyGRPC := proxypb.NewKubernetesProxyClient(castaiClient.GRPCConn()) + return kubeproxy.NewClient(log, proxyGRPC, httpClient, restCfg.Host) +} diff --git a/cmd/controller/config/config.go b/cmd/controller/config/config.go index a9c1b9f9e..2188d9491 100644 --- a/cmd/controller/config/config.go +++ b/cmd/controller/config/config.go @@ -41,6 +41,14 @@ type Config struct { JobsCleanup controllers.JobsCleanupConfig `json:"jobsCleanup"` AgentConfig AgentConfig `json:"agentConfig"` CloudProviderConfig CloudProviderConfig `json:"cloudProviderConfig"` + KubeProxy KubeProxyConfig `json:"kubeProxy"` +} + +type KubeProxyConfig struct { + Enabled bool `json:"enabled"` + RestrictedSAName string `json:"restrictedSAName"` + RestrictedSANamespace string `json:"restrictedSANamespace"` + TokenExpirationSeconds int64 `json:"tokenExpirationSeconds"` } type AgentConfig struct { diff --git a/cmd/controller/main.go b/cmd/controller/main.go index 94e0573b6..becfcbd64 100644 --- a/cmd/controller/main.go +++ b/cmd/controller/main.go @@ -96,6 +96,10 @@ var ( jobsCleanupJobAge = pflag.Duration("jobs-cleanup-job-age", 10*time.Minute, "Jobs cleanup job age") agentEnabled = pflag.Bool("agent-enabled", false, "Whether kvisor-agent is enabled (used for reporting; does not enable agent)") + + kubeProxyEnabled = pflag.Bool("kube-proxy-enabled", false, "Enable kube proxy for remote K8s API access") + kubeProxyRestrictedSAName = pflag.String("kube-proxy-restricted-sa-name", "kvisor-proxy", "Name of the restricted service account for proxy") + kubeProxyTokenExpiration = pflag.Int64("kube-proxy-token-expiration", 900, "Token expiration in seconds for the restricted SA") ) func main() { @@ -209,6 +213,12 @@ func main() { AgentConfig: config.AgentConfig{ Enabled: *agentEnabled, }, + KubeProxy: config.KubeProxyConfig{ + Enabled: *kubeProxyEnabled, + RestrictedSAName: *kubeProxyRestrictedSAName, + RestrictedSANamespace: podNs, + TokenExpirationSeconds: *kubeProxyTokenExpiration, + }, CloudProviderConfig: config.CloudProviderConfig{ CloudProvider: cloudtypes.ProviderConfig{ Type: cloudProviderType, @@ -230,6 +240,7 @@ func main() { }, }, clientset, + kubeConfig, ) if err := appInstance.Run(ctx); err != nil && !errors.Is(err, context.Canceled) { diff --git a/pkg/castai/client.go b/pkg/castai/client.go index b4d062759..e100ec6e5 100644 --- a/pkg/castai/client.go +++ b/pkg/castai/client.go @@ -4,11 +4,13 @@ import ( "context" "fmt" "strings" + "time" "google.golang.org/grpc" "google.golang.org/grpc/credentials" "google.golang.org/grpc/credentials/insecure" "google.golang.org/grpc/encoding/gzip" + "google.golang.org/grpc/keepalive" "google.golang.org/grpc/metadata" castaipb "github.com/castai/kvisor/api/v1/runtime" @@ -25,6 +27,11 @@ func NewClient(userAgent string, cfg Config) (*Client, error) { grpc.WithUserAgent(userAgent), grpc.WithUnaryInterceptor(newCastaiGrpcUnaryMetadataInterceptor(cfg)), grpc.WithStreamInterceptor(newCastaiGrpcStreamMetadataInterceptor(cfg)), + grpc.WithKeepaliveParams(keepalive.ClientParameters{ + Time: 10 * time.Second, + Timeout: 5 * time.Second, + PermitWithoutStream: true, + }), ) if err != nil { return nil, fmt.Errorf("castai grpc server dial: %w", err) @@ -53,6 +60,10 @@ func (c *Client) Close() { } } +func (c *Client) GRPCConn() grpc.ClientConnInterface { + return c.grpcConn +} + func (c *Client) GetCompressionName() string { return c.compressionName } diff --git a/pkg/kubeproxy/client.go b/pkg/kubeproxy/client.go new file mode 100644 index 000000000..e56ae0018 --- /dev/null +++ b/pkg/kubeproxy/client.go @@ -0,0 +1,358 @@ +package kubeproxy + +import ( + "bytes" + "context" + "fmt" + "io" + "net/http" + "net/url" + "path" + "strings" + "sync" + "time" + + "github.com/cenkalti/backoff/v5" + "google.golang.org/grpc/codes" + + proxypb "github.com/castai/kvisor/api/v1/proxy" + "github.com/castai/kvisor/pkg/castai" + "github.com/castai/logging" +) + +const ( + maxConcurrentRequests = 50 + maxResponseChunkSize = 32 * 1024 + sendResponseTimeout = 30 * time.Second +) + +var allowedResponseHeaders = map[string]bool{ + "Audit-Id": true, + "Content-Type": true, + "Content-Length": true, + "Content-Encoding": true, + "Cache-Control": true, + "Date": true, + "X-Kubernetes-Pf-Flowschema-Uid": true, + "X-Kubernetes-Pf-Prioritylevel-Uid": true, +} + +var blockedRequestHeaders = map[string]bool{ + "authorization": true, + "impersonate-user": true, + "impersonate-group": true, + "impersonate-uid": true, +} + +func isBlockedRequestHeader(key string) bool { + lower := strings.ToLower(key) + return blockedRequestHeaders[lower] || strings.HasPrefix(lower, "impersonate-extra-") +} + +var blockedSubresources = map[string]bool{ + "exec": true, + "attach": true, + "portforward": true, + "proxy": true, +} + +type Client struct { + log *logging.Logger + proxyClient proxypb.KubernetesProxyClient + httpClient *http.Client + kubeHost *url.URL +} + +func NewClient(log *logging.Logger, proxyClient proxypb.KubernetesProxyClient, httpClient *http.Client, kubeHost string) (*Client, error) { + parsed, err := url.Parse(kubeHost) + if err != nil { + return nil, fmt.Errorf("parsing kube host URL: %w", err) + } + return &Client{ + log: log, + proxyClient: proxyClient, + httpClient: httpClient, + kubeHost: parsed, + }, nil +} + +func (c *Client) Run(ctx context.Context) error { + c.log.Info("starting kube proxy client") + defer c.log.Info("stopping kube proxy client") + + op := func() (struct{}, error) { + err := c.subscribe(ctx) + if ctx.Err() != nil { + return struct{}{}, backoff.Permanent(ctx.Err()) + } + if castai.IsGRPCError(err, codes.PermissionDenied, codes.Unauthenticated, codes.Unimplemented) { + c.log.Errorf("proxy subscription failed permanently: %v", err) + return struct{}{}, backoff.Permanent(err) + } + c.log.Warnf("proxy subscription closed, reconnecting: %v", err) + return struct{}{}, err + } + + eb := backoff.NewExponentialBackOff() + eb.InitialInterval = 1 * time.Second + eb.MaxInterval = 30 * time.Second + + // context.Background() prevents sibling errgroup cancellations from interrupting + // the backoff sleep; shutdown is handled inside op via ctx.Err(). + _, err := backoff.Retry(context.Background(), op, + backoff.WithBackOff(eb), + ) + return err +} + +func (c *Client) subscribe(ctx context.Context) error { + subCtx, subCancel := context.WithCancel(ctx) + defer subCancel() + + stream, err := c.proxyClient.Subscribe(subCtx, &proxypb.SubscribeRequest{}) + if err != nil { + return fmt.Errorf("subscribe: %w", err) + } + + c.log.Info("subscribed to proxy requests") + + sem := make(chan struct{}, maxConcurrentRequests) + var wg sync.WaitGroup + + for { + req, err := stream.Recv() + if err != nil { + subCancel() + wg.Wait() + return fmt.Errorf("recv: %w", err) + } + + select { + case sem <- struct{}{}: + case <-subCtx.Done(): + wg.Wait() + return subCtx.Err() + } + wg.Add(1) + go func() { + defer wg.Done() + defer func() { <-sem }() + c.handleRequest(subCtx, req) + }() + } +} + +func (c *Client) handleRequest(ctx context.Context, req *proxypb.HttpRequest) { + log := c.log.With("request_id", req.RequestId, "method", req.Method, "path", req.Path) + + if err := validateRequest(req); err != nil { + log.Warnf("invalid request: %v", err) + c.sendErrorResponse(ctx, req.RequestId, http.StatusBadRequest, err.Error()) + return + } + + sanitized := sanitizeRequestURL(req.Path) + pathPart, rawQuery, _ := strings.Cut(sanitized, "?") + reqURL := *c.kubeHost + reqURL.Path = pathPart + reqURL.RawQuery = rawQuery + url := reqURL.String() + var body io.Reader + if len(req.Body) > 0 { + body = bytes.NewReader(req.Body) + } + httpReq, err := http.NewRequestWithContext(ctx, req.Method, url, body) + if err != nil { + log.Warnf("creating http request: %v", err) + c.sendErrorResponse(ctx, req.RequestId, http.StatusInternalServerError, "failed to create request") + return + } + + for _, h := range req.Headers { + if isBlockedRequestHeader(h.Key) { + continue + } + for _, v := range h.Values { + httpReq.Header.Add(h.Key, v) + } + } + + resp, err := c.httpClient.Do(httpReq) + if err != nil { + log.Warnf("executing k8s request: %v", err) + c.sendErrorResponse(ctx, req.RequestId, http.StatusBadGateway, "failed to reach kubernetes api") + return + } + defer resp.Body.Close() + + if err := c.streamResponse(ctx, req.RequestId, resp); err != nil { + log.Warnf("streaming response: %v", err) + } +} + +func (c *Client) streamResponse(ctx context.Context, requestID string, resp *http.Response) error { + sendCtx, sendCancel := context.WithTimeout(ctx, sendResponseTimeout) + defer sendCancel() + + sendStream, err := c.proxyClient.SendResponse(sendCtx) + if err != nil { + return fmt.Errorf("opening send stream: %w", err) + } + + headers := filterResponseHeaders(resp.Header) + + buf := make([]byte, maxResponseChunkSize) + first := true + for { + n, readErr := resp.Body.Read(buf) + isLast := readErr != nil + + if n > 0 { + msg := &proxypb.HttpResponse{ + RequestId: requestID, + Body: buf[:n], + More: !isLast, + } + if first { + msg.StatusCode = int32(resp.StatusCode) + msg.Headers = headers + first = false + } + if err := sendStream.Send(msg); err != nil { + return fmt.Errorf("send response chunk: %w", err) + } + } + + // EOF with 0 bytes means the previous chunk was already sent with More=true + // (we didn't know it was last until this read). Send an explicit terminator. + if readErr == io.EOF { + if n == 0 && !first { + if err := sendStream.Send(&proxypb.HttpResponse{ + RequestId: requestID, + More: false, + }); err != nil { + return fmt.Errorf("send response chunk: %w", err) + } + } + break + } + if readErr != nil { + _ = sendStream.Send(&proxypb.HttpResponse{ + RequestId: requestID, + Error: readErr.Error(), + }) + break + } + } + + if first { + if err := sendStream.Send(&proxypb.HttpResponse{ + RequestId: requestID, + StatusCode: int32(resp.StatusCode), + Headers: headers, + }); err != nil { + return fmt.Errorf("send empty response: %w", err) + } + } + + if _, err := sendStream.CloseAndRecv(); err != nil { + return fmt.Errorf("close send stream: %w", err) + } + return nil +} + +func (c *Client) sendErrorResponse(ctx context.Context, requestID string, statusCode int, errMsg string) { + sendCtx, sendCancel := context.WithTimeout(ctx, sendResponseTimeout) + defer sendCancel() + + sendStream, err := c.proxyClient.SendResponse(sendCtx) + if err != nil { + c.log.Warnf("opening error send stream: %v", err) + return + } + + _ = sendStream.Send(&proxypb.HttpResponse{ + RequestId: requestID, + StatusCode: int32(statusCode), + Error: errMsg, + }) + _, _ = sendStream.CloseAndRecv() +} + +func validateRequest(req *proxypb.HttpRequest) error { + if req.Method != http.MethodGet { + return fmt.Errorf("only GET requests are allowed, got %s", req.Method) + } + + pathPart, _, _ := strings.Cut(req.Path, "?") + cleaned := path.Clean(pathPart) + if !isAllowedPath(cleaned) { + return fmt.Errorf("path %q is not allowed, must start with /api/ or /apis/", req.Path) + } + + subresource := extractSubresource(cleaned) + if blockedSubresources[subresource] { + return fmt.Errorf("subresource %q is not allowed", subresource) + } + + return nil +} + +func sanitizeRequestURL(raw string) string { + pathPart, query, _ := strings.Cut(raw, "?") + cleaned := path.Clean(pathPart) + if query != "" { + return cleaned + "?" + query + } + return cleaned +} + +func isAllowedPath(p string) bool { + return p == "/api" || p == "/apis" || + strings.HasPrefix(p, "/api/") || strings.HasPrefix(p, "/apis/") +} + +func extractSubresource(path string) string { + parts := strings.Split(strings.TrimPrefix(path, "/"), "/") + + // /api/v1/namespaces/{ns}/pods/{name}/{subresource} + // /apis/{group}/{version}/namespaces/{ns}/{resource}/{name}/{subresource} + if len(parts) < 2 { + return "" + } + + for i, p := range parts { + if p == "namespaces" && i+1 < len(parts) { + rest := parts[i+2:] + if len(rest) >= 3 { + return rest[2] + } + return "" + } + } + + // /api/v1/{resource}/{name}/{subresource} + // /apis/{group}/{version}/{resource}/{name}/{subresource} + startIdx := 2 + if parts[0] == "apis" { + startIdx = 3 + } + if len(parts) > startIdx+2 { + return parts[startIdx+2] + } + + return "" +} + +func filterResponseHeaders(headers http.Header) []*proxypb.Header { + var result []*proxypb.Header + for k, v := range headers { + if allowedResponseHeaders[k] { + result = append(result, &proxypb.Header{ + Key: k, + Values: v, + }) + } + } + return result +} diff --git a/pkg/kubeproxy/client_test.go b/pkg/kubeproxy/client_test.go new file mode 100644 index 000000000..c09b5a49f --- /dev/null +++ b/pkg/kubeproxy/client_test.go @@ -0,0 +1,294 @@ +package kubeproxy + +import ( + "net/http" + "testing" + + proxypb "github.com/castai/kvisor/api/v1/proxy" + "github.com/stretchr/testify/require" +) + +func TestValidateRequest(t *testing.T) { + tests := []struct { + name string + req *proxypb.HttpRequest + wantErr string + }{ + { + name: "valid GET to core API", + req: &proxypb.HttpRequest{ + Method: http.MethodGet, + Path: "/api/v1/namespaces/default/pods", + }, + }, + { + name: "valid GET to extended API", + req: &proxypb.HttpRequest{ + Method: http.MethodGet, + Path: "/apis/apps/v1/namespaces/kube-system/deployments/coredns", + }, + }, + { + name: "valid GET to list all pods", + req: &proxypb.HttpRequest{ + Method: http.MethodGet, + Path: "/api/v1/pods", + }, + }, + { + name: "POST rejected", + req: &proxypb.HttpRequest{ + Method: http.MethodPost, + Path: "/api/v1/namespaces/default/pods", + }, + wantErr: "only GET requests are allowed", + }, + { + name: "DELETE rejected", + req: &proxypb.HttpRequest{ + Method: http.MethodDelete, + Path: "/api/v1/namespaces/default/pods/my-pod", + }, + wantErr: "only GET requests are allowed", + }, + { + name: "/healthz path rejected", + req: &proxypb.HttpRequest{ + Method: http.MethodGet, + Path: "/healthz", + }, + wantErr: "path \"/healthz\" is not allowed", + }, + { + name: "/metrics path rejected", + req: &proxypb.HttpRequest{ + Method: http.MethodGet, + Path: "/metrics", + }, + wantErr: "is not allowed", + }, + { + name: "/debug/pprof path rejected", + req: &proxypb.HttpRequest{ + Method: http.MethodGet, + Path: "/debug/pprof/", + }, + wantErr: "is not allowed", + }, + { + name: "root path rejected", + req: &proxypb.HttpRequest{ + Method: http.MethodGet, + Path: "/", + }, + wantErr: "is not allowed", + }, + { + name: "exec subresource blocked", + req: &proxypb.HttpRequest{ + Method: http.MethodGet, + Path: "/api/v1/namespaces/default/pods/my-pod/exec", + }, + wantErr: `subresource "exec" is not allowed`, + }, + { + name: "attach subresource blocked", + req: &proxypb.HttpRequest{ + Method: http.MethodGet, + Path: "/api/v1/namespaces/default/pods/my-pod/attach", + }, + wantErr: `subresource "attach" is not allowed`, + }, + { + name: "portforward subresource blocked", + req: &proxypb.HttpRequest{ + Method: http.MethodGet, + Path: "/api/v1/namespaces/default/pods/my-pod/portforward", + }, + wantErr: `subresource "portforward" is not allowed`, + }, + { + name: "log subresource allowed", + req: &proxypb.HttpRequest{ + Method: http.MethodGet, + Path: "/api/v1/namespaces/default/pods/my-pod/log", + }, + }, + { + name: "logs subresource allowed", + req: &proxypb.HttpRequest{ + Method: http.MethodGet, + Path: "/api/v1/namespaces/default/pods/my-pod/logs", + }, + }, + { + name: "proxy subresource blocked", + req: &proxypb.HttpRequest{ + Method: http.MethodGet, + Path: "/api/v1/namespaces/default/pods/my-pod/proxy", + }, + wantErr: `subresource "proxy" is not allowed`, + }, + { + name: "status subresource allowed", + req: &proxypb.HttpRequest{ + Method: http.MethodGet, + Path: "/api/v1/namespaces/default/pods/my-pod/status", + }, + }, + { + name: "path traversal via .. rejected", + req: &proxypb.HttpRequest{ + Method: http.MethodGet, + Path: "/api/v1/../../debug/pprof/", + }, + wantErr: "is not allowed", + }, + { + name: "exec with query string still blocked", + req: &proxypb.HttpRequest{ + Method: http.MethodGet, + Path: "/api/v1/namespaces/default/pods/my-pod/exec?command=ls", + }, + wantErr: `subresource "exec" is not allowed`, + }, + { + name: "log with follow query string allowed", + req: &proxypb.HttpRequest{ + Method: http.MethodGet, + Path: "/api/v1/namespaces/default/pods/my-pod/log?follow=true", + }, + }, + { + name: "valid path with query string allowed", + req: &proxypb.HttpRequest{ + Method: http.MethodGet, + Path: "/api/v1/pods?labelSelector=app%3Dnginx", + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := validateRequest(tt.req) + if tt.wantErr != "" { + require.Error(t, err) + require.Contains(t, err.Error(), tt.wantErr) + } else { + require.NoError(t, err) + } + }) + } +} + +func TestExtractSubresource(t *testing.T) { + tests := []struct { + name string + path string + want string + }{ + { + name: "no subresource - list pods", + path: "/api/v1/pods", + want: "", + }, + { + name: "no subresource - namespaced pods", + path: "/api/v1/namespaces/default/pods", + want: "", + }, + { + name: "no subresource - specific pod", + path: "/api/v1/namespaces/default/pods/my-pod", + want: "", + }, + { + name: "exec subresource", + path: "/api/v1/namespaces/default/pods/my-pod/exec", + want: "exec", + }, + { + name: "log subresource", + path: "/api/v1/namespaces/default/pods/my-pod/log", + want: "log", + }, + { + name: "status subresource on apps resource", + path: "/apis/apps/v1/namespaces/default/deployments/my-deploy/status", + want: "status", + }, + { + name: "cluster-scoped - nodes status", + path: "/api/v1/nodes/my-node/status", + want: "status", + }, + { + name: "cluster-scoped - no subresource", + path: "/api/v1/nodes/my-node", + want: "", + }, + { + name: "empty path", + path: "/", + want: "", + }, + { + name: "just api version", + path: "/api/v1", + want: "", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := extractSubresource(tt.path) + require.Equal(t, tt.want, got) + }) + } +} + +func TestIsAllowedPath(t *testing.T) { + tests := []struct { + path string + want bool + }{ + {"/api/v1/pods", true}, + {"/api/v1/namespaces/default/pods", true}, + {"/apis/apps/v1/deployments", true}, + {"/healthz", false}, + {"/metrics", false}, + {"/debug/pprof/", false}, + {"/", false}, + {"/version", false}, + {"/openapi/v2", false}, + } + + for _, tt := range tests { + t.Run(tt.path, func(t *testing.T) { + require.Equal(t, tt.want, isAllowedPath(tt.path)) + }) + } +} + +func TestFilterResponseHeaders(t *testing.T) { + headers := http.Header{ + "Content-Type": {"application/json"}, + "Content-Length": {"1234"}, + "X-Secret-Header": {"should-be-filtered"}, + "Cache-Control": {"no-cache"}, + "Set-Cookie": {"should-be-filtered"}, + } + + result := filterResponseHeaders(headers) + + resultMap := make(map[string][]string) + for _, h := range result { + resultMap[h.Key] = h.Values + } + + require.Contains(t, resultMap, "Content-Type") + require.Contains(t, resultMap, "Content-Length") + require.Contains(t, resultMap, "Cache-Control") + require.NotContains(t, resultMap, "X-Secret-Header") + require.NotContains(t, resultMap, "Set-Cookie") +} diff --git a/pkg/kubeproxy/token.go b/pkg/kubeproxy/token.go new file mode 100644 index 000000000..85c2e7d68 --- /dev/null +++ b/pkg/kubeproxy/token.go @@ -0,0 +1,79 @@ +package kubeproxy + +import ( + "context" + "fmt" + "net/http" + "sync" + "time" +) + +type TokenCreatorFunc func(ctx context.Context) (token string, expiresAt time.Time, err error) + +type TokenProviderConfig struct { + CreateToken TokenCreatorFunc +} + +type TokenProvider struct { + createToken TokenCreatorFunc + mu sync.RWMutex + token string + expiresAt time.Time +} + +func NewTokenProvider(cfg TokenProviderConfig) *TokenProvider { + return &TokenProvider{createToken: cfg.CreateToken} +} + +func (tp *TokenProvider) isValid() bool { + return tp.token != "" && time.Now().Before(tp.expiresAt) +} + +func (tp *TokenProvider) GetToken(ctx context.Context) (string, error) { + tp.mu.RLock() + if tp.isValid() { + token := tp.token + tp.mu.RUnlock() + return token, nil + } + tp.mu.RUnlock() + return tp.refreshToken(ctx) +} + +func (tp *TokenProvider) refreshToken(ctx context.Context) (string, error) { + tp.mu.Lock() + defer tp.mu.Unlock() + if tp.isValid() { + return tp.token, nil + } + token, expiresAt, err := tp.createToken(ctx) + if err != nil { + return "", err + } + tp.token = token + ttl := time.Until(expiresAt) + tp.expiresAt = expiresAt.Add(-ttl / 2) + return token, nil +} + +type tokenRoundTripper struct { + tp *TokenProvider + base http.RoundTripper +} + +func NewTokenRoundTripper(tp *TokenProvider, base http.RoundTripper) http.RoundTripper { + if base == nil { + base = http.DefaultTransport + } + return &tokenRoundTripper{tp: tp, base: base} +} + +func (t *tokenRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) { + token, err := t.tp.GetToken(req.Context()) + if err != nil { + return nil, fmt.Errorf("getting token: %w", err) + } + req = req.Clone(req.Context()) + req.Header.Set("Authorization", "Bearer "+token) + return t.base.RoundTrip(req) +} diff --git a/pkg/kubeproxy/token_test.go b/pkg/kubeproxy/token_test.go new file mode 100644 index 000000000..3a59fa902 --- /dev/null +++ b/pkg/kubeproxy/token_test.go @@ -0,0 +1,97 @@ +package kubeproxy + +import ( + "context" + "fmt" + "net/http" + "net/http/httptest" + "testing" + "time" + + "github.com/stretchr/testify/require" +) + +func TestTokenProvider_GetToken(t *testing.T) { + t.Run("requests and caches token", func(t *testing.T) { + callCount := 0 + tp := NewTokenProvider(TokenProviderConfig{ + CreateToken: func(ctx context.Context) (string, time.Time, error) { + callCount++ + return "restricted-token-123", time.Now().Add(15 * time.Minute), nil + }, + }) + + ctx := context.Background() + + token, err := tp.GetToken(ctx) + require.NoError(t, err) + require.Equal(t, "restricted-token-123", token) + require.Equal(t, 1, callCount) + + token2, err := tp.GetToken(ctx) + require.NoError(t, err) + require.Equal(t, "restricted-token-123", token2) + require.Equal(t, 1, callCount, "second call should use cache") + }) + + t.Run("refreshes expired token", func(t *testing.T) { + callCount := 0 + tp := NewTokenProvider(TokenProviderConfig{ + CreateToken: func(ctx context.Context) (string, time.Time, error) { + callCount++ + return fmt.Sprintf("token-%d", callCount), time.Now().Add(15 * time.Minute), nil + }, + }) + + ctx := context.Background() + + _, err := tp.GetToken(ctx) + require.NoError(t, err) + require.Equal(t, 1, callCount) + + tp.mu.Lock() + tp.expiresAt = time.Now().Add(-1 * time.Minute) + tp.mu.Unlock() + + _, err = tp.GetToken(ctx) + require.NoError(t, err) + require.Equal(t, 2, callCount, "should refresh expired token") + }) + + t.Run("returns error on CreateToken failure", func(t *testing.T) { + tp := NewTokenProvider(TokenProviderConfig{ + CreateToken: func(ctx context.Context) (string, time.Time, error) { + return "", time.Time{}, fmt.Errorf("token request returned status 403: forbidden") + }, + }) + + _, err := tp.GetToken(context.Background()) + require.Error(t, err) + require.Contains(t, err.Error(), "403") + }) +} + +func TestTokenRoundTripper(t *testing.T) { + tp := NewTokenProvider(TokenProviderConfig{ + CreateToken: func(ctx context.Context) (string, time.Time, error) { + return "injected-token", time.Now().Add(15 * time.Minute), nil + }, + }) + + var capturedAuth string + backendSrv := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + capturedAuth = r.Header.Get("Authorization") + w.WriteHeader(http.StatusOK) + })) + defer backendSrv.Close() + + rt := NewTokenRoundTripper(tp, backendSrv.Client().Transport) + client := &http.Client{Transport: rt} + + req, _ := http.NewRequest(http.MethodGet, backendSrv.URL+"/api/v1/pods", nil) + resp, err := client.Do(req) + require.NoError(t, err) + defer resp.Body.Close() + require.Equal(t, http.StatusOK, resp.StatusCode) + require.Equal(t, "Bearer injected-token", capturedAuth) +}