diff --git a/.gitignore b/.gitignore index 499205fbd..0c4736d00 100644 --- a/.gitignore +++ b/.gitignore @@ -1,4 +1,5 @@ .idea +.run *.iml dist/ bin/ diff --git a/management/server/account.go b/management/server/account.go index 459cabcd3..8d6e50b55 100644 --- a/management/server/account.go +++ b/management/server/account.go @@ -100,12 +100,6 @@ type UserInfo struct { Role string `json:"role"` } -// NewAccount creates a new Account with a generated ID and generated default setup keys -func NewAccount(userId, domain string) *Account { - accountId := xid.New().String() - return newAccountWithId(accountId, userId, domain) -} - func (a *Account) Copy() *Account { peers := map[string]*Peer{} for id, peer := range a.Peers { @@ -198,6 +192,27 @@ func BuildManager( } +// newAccount creates a new Account with a generated ID and generated default setup keys. +// If ID is already in use (due to collision) we try one more time before returning error +func (am *DefaultAccountManager) newAccount(userID, domain string) (*Account, error) { + for i := 0; i < 2; i++ { + accountId := xid.New().String() + + _, err := am.Store.GetAccount(accountId) + statusErr, _ := status.FromError(err) + if err == nil { + log.Warnf("an account with ID already exists, retrying...") + continue + } else if statusErr.Code() == codes.NotFound { + return newAccountWithId(accountId, userID, domain), nil + } else { + return nil, err + } + } + + return nil, status.Errorf(codes.Internal, "error while creating new account") +} + func (am *DefaultAccountManager) warmupIDPCache() error { userData, err := am.idpManager.GetAllAccounts() if err != nil { @@ -368,7 +383,7 @@ func mergeLocalAndQueryUser(queried idp.UserData, local User) *UserInfo { } } -func (am *DefaultAccountManager) loadFromCache(ctx context.Context, accountID interface{}) (interface{}, error) { +func (am *DefaultAccountManager) loadFromCache(_ context.Context, accountID interface{}) (interface{}, error) { return am.idpManager.GetAccount(fmt.Sprintf("%v", accountID)) } @@ -458,8 +473,17 @@ func (am *DefaultAccountManager) updateAccountDomainAttributes( primaryDomain bool, ) error { account.IsDomainPrimaryAccount = primaryDomain - account.Domain = strings.ToLower(claims.Domain) - account.DomainCategory = claims.DomainCategory + + lowerDomain := strings.ToLower(claims.Domain) + userObj := account.Users[claims.UserId] + if account.Domain != lowerDomain && userObj.Role == UserRoleAdmin { + account.Domain = lowerDomain + } + // prevent updating category for different domain until admin logs in + if account.Domain == lowerDomain { + account.DomainCategory = claims.DomainCategory + } + err := am.Store.SaveAccount(account) if err != nil { return status.Errorf(codes.Internal, "failed saving updated account") @@ -523,7 +547,10 @@ func (am *DefaultAccountManager) handleNewUserAccount( return nil, status.Errorf(codes.Internal, "failed saving updated account") } } else { - account = NewAccount(claims.UserId, lowerDomain) + account, err = am.newAccount(claims.UserId, lowerDomain) + if err != nil { + return nil, err + } err = am.updateAccountDomainAttributes(account, claims, true) if err != nil { return nil, err diff --git a/management/server/account_test.go b/management/server/account_test.go index c7de2e838..1f05324cd 100644 --- a/management/server/account_test.go +++ b/management/server/account_test.go @@ -96,7 +96,8 @@ func TestNewAccount(t *testing.T) { domain := "netbird.io" userId := "account_creator" - account := NewAccount(userId, domain) + accountID := "account_id" + account := newAccountWithId(accountID, userId, domain) verifyNewAccountHasDefaultFields(t, account, userId, domain, []string{userId}) } diff --git a/management/server/file_store_test.go b/management/server/file_store_test.go index 27b996949..6a2c9f7ea 100644 --- a/management/server/file_store_test.go +++ b/management/server/file_store_test.go @@ -33,7 +33,7 @@ func TestNewStore(t *testing.T) { func TestSaveAccount(t *testing.T) { store := newStore(t) - account := NewAccount("testuser", "") + account := newAccountWithId("account_id", "testuser", "") setupKey := GenerateDefaultSetupKey() account.SetupKeys[setupKey.Key] = setupKey account.Peers["testpeer"] = &Peer{ @@ -72,7 +72,7 @@ func TestSaveAccount(t *testing.T) { func TestStore(t *testing.T) { store := newStore(t) - account := NewAccount("testuser", "") + account := newAccountWithId("account_id", "testuser", "") account.Peers["testpeer"] = &Peer{ Key: "peerkey", SetupKey: "peerkeysetupkey", diff --git a/management/server/user.go b/management/server/user.go index 05c42af2e..f69ec40e4 100644 --- a/management/server/user.go +++ b/management/server/user.go @@ -59,7 +59,10 @@ func (am *DefaultAccountManager) GetOrCreateAccountByUser(userId, domain string) account, err := am.Store.GetUserAccount(userId) if err != nil { if s, ok := status.FromError(err); ok && s.Code() == codes.NotFound { - account = NewAccount(userId, lowerDomain) + account, err = am.newAccount(userId, lowerDomain) + if err != nil { + return nil, err + } err = am.Store.SaveAccount(account) if err != nil { return nil, status.Errorf(codes.Internal, "failed creating account") @@ -70,7 +73,9 @@ func (am *DefaultAccountManager) GetOrCreateAccountByUser(userId, domain string) } } - if account.Domain != lowerDomain { + userObj := account.Users[userId] + + if account.Domain != lowerDomain && userObj.Role == UserRoleAdmin { account.Domain = lowerDomain err = am.Store.SaveAccount(account) if err != nil {