outposts/ldap: copy boundUsers map when running refresh instead of using blank map
closes #1651 Signed-off-by: Jens Langhammer <jens.langhammer@beryju.org>
This commit is contained in:
		| @ -29,6 +29,7 @@ type ProviderInstance struct { | |||||||
| 	tlsServerName       *string | 	tlsServerName       *string | ||||||
| 	cert                *tls.Certificate | 	cert                *tls.Certificate | ||||||
| 	outpostName         string | 	outpostName         string | ||||||
|  | 	outpostPk           int32 | ||||||
| 	searchAllowedGroups []*strfmt.UUID | 	searchAllowedGroups []*strfmt.UUID | ||||||
| 	boundUsersMutex     sync.RWMutex | 	boundUsersMutex     sync.RWMutex | ||||||
| 	boundUsers          map[string]flags.UserFlags | 	boundUsers          map[string]flags.UserFlags | ||||||
|  | |||||||
| @ -17,6 +17,15 @@ import ( | |||||||
| 	memorysearch "goauthentik.io/internal/outpost/ldap/search/memory" | 	memorysearch "goauthentik.io/internal/outpost/ldap/search/memory" | ||||||
| ) | ) | ||||||
|  |  | ||||||
|  | func (ls *LDAPServer) getCurrentProvider(pk int32) *ProviderInstance { | ||||||
|  | 	for _, p := range ls.providers { | ||||||
|  | 		if p.outpostPk == pk { | ||||||
|  | 			return p | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | 	return nil | ||||||
|  | } | ||||||
|  |  | ||||||
| func (ls *LDAPServer) Refresh() error { | func (ls *LDAPServer) Refresh() error { | ||||||
| 	outposts, _, err := ls.ac.Client.OutpostsApi.OutpostsLdapList(context.Background()).Execute() | 	outposts, _, err := ls.ac.Client.OutpostsApi.OutpostsLdapList(context.Background()).Execute() | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| @ -31,6 +40,15 @@ func (ls *LDAPServer) Refresh() error { | |||||||
| 		groupDN := strings.ToLower(fmt.Sprintf("ou=%s,%s", constants.OUGroups, *provider.BaseDn)) | 		groupDN := strings.ToLower(fmt.Sprintf("ou=%s,%s", constants.OUGroups, *provider.BaseDn)) | ||||||
| 		virtualGroupDN := strings.ToLower(fmt.Sprintf("ou=%s,%s", constants.OUVirtualGroups, *provider.BaseDn)) | 		virtualGroupDN := strings.ToLower(fmt.Sprintf("ou=%s,%s", constants.OUVirtualGroups, *provider.BaseDn)) | ||||||
| 		logger := log.WithField("logger", "authentik.outpost.ldap").WithField("provider", provider.Name) | 		logger := log.WithField("logger", "authentik.outpost.ldap").WithField("provider", provider.Name) | ||||||
|  |  | ||||||
|  | 		// Get existing instance so we can transfer boundUsers | ||||||
|  | 		existing := ls.getCurrentProvider(provider.Pk) | ||||||
|  | 		users := make(map[string]flags.UserFlags) | ||||||
|  | 		if existing != nil { | ||||||
|  | 			existing.boundUsersMutex.Unlock() | ||||||
|  | 			users = existing.boundUsers | ||||||
|  | 		} | ||||||
|  |  | ||||||
| 		providers[idx] = &ProviderInstance{ | 		providers[idx] = &ProviderInstance{ | ||||||
| 			BaseDN:              *provider.BaseDn, | 			BaseDN:              *provider.BaseDn, | ||||||
| 			VirtualGroupDN:      virtualGroupDN, | 			VirtualGroupDN:      virtualGroupDN, | ||||||
| @ -40,13 +58,14 @@ func (ls *LDAPServer) Refresh() error { | |||||||
| 			flowSlug:            provider.BindFlowSlug, | 			flowSlug:            provider.BindFlowSlug, | ||||||
| 			searchAllowedGroups: []*strfmt.UUID{(*strfmt.UUID)(provider.SearchGroup.Get())}, | 			searchAllowedGroups: []*strfmt.UUID{(*strfmt.UUID)(provider.SearchGroup.Get())}, | ||||||
| 			boundUsersMutex:     sync.RWMutex{}, | 			boundUsersMutex:     sync.RWMutex{}, | ||||||
| 			boundUsers:          make(map[string]flags.UserFlags), | 			boundUsers:          users, | ||||||
| 			s:                   ls, | 			s:                   ls, | ||||||
| 			log:                 logger, | 			log:                 logger, | ||||||
| 			tlsServerName:       provider.TlsServerName, | 			tlsServerName:       provider.TlsServerName, | ||||||
| 			uidStartNumber:      *provider.UidStartNumber, | 			uidStartNumber:      *provider.UidStartNumber, | ||||||
| 			gidStartNumber:      *provider.GidStartNumber, | 			gidStartNumber:      *provider.GidStartNumber, | ||||||
| 			outpostName:         ls.ac.Outpost.Name, | 			outpostName:         ls.ac.Outpost.Name, | ||||||
|  | 			outpostPk:           provider.Pk, | ||||||
| 		} | 		} | ||||||
| 		if provider.Certificate.Get() != nil { | 		if provider.Certificate.Get() != nil { | ||||||
| 			kp := provider.Certificate.Get() | 			kp := provider.Certificate.Get() | ||||||
|  | |||||||
		Reference in New Issue
	
	Block a user
	 Jens Langhammer
					Jens Langhammer