diff --git a/backend/gce.go b/backend/gce.go index 4adac9d71..f9539a70d 100644 --- a/backend/gce.go +++ b/backend/gce.go @@ -14,6 +14,7 @@ import ( mathrand "math/rand" "net/http" "net/url" + "regexp" "sort" "strconv" "strings" @@ -48,7 +49,6 @@ import ( ) const ( - defaultGCEZone = "us-central1-a" defaultGCEMachineType = "n1-standard-2" defaultGCEPremiumMachineType = "n1-standard-4" defaultGCENetwork = "default" @@ -59,7 +59,6 @@ const ( defaultGCEStopPollSleep = 3 * time.Second defaultGCEStopPrePollSleep = 15 * time.Second defaultGCESubnet = "default" - defaultGCERegion = "us-central1" defaultGCEUploadRetries = uint64(120) defaultGCEUploadRetrySleep = 1 * time.Second defaultGCEImageSelectorType = "env" @@ -92,10 +91,11 @@ var ( "IMAGE_SELECTOR_URL": "URL for image selector API, used only when image selector is \"api\"", "IMAGE_[ALIAS_]{ALIAS}": "full name for a given alias given via IMAGE_ALIASES, where the alias form in the key is uppercased and normalized by replacing non-alphanumerics with _", "MACHINE_TYPE": fmt.Sprintf("machine name (default %q)", defaultGCEMachineType), + "MINIMUM_CPU_PLATFORM": "minimum cpu platform", "NETWORK": fmt.Sprintf("network name (default %q)", defaultGCENetwork), "PREEMPTIBLE": "boot job instances with preemptible flag enabled (default false)", "PREMIUM_MACHINE_TYPE": fmt.Sprintf("premium machine type (default %q)", defaultGCEPremiumMachineType), - "PROJECT_ID": "[REQUIRED] GCE project id", + "PROJECT_ID": "[REQUIRED] GCE project id (will try to auto detect it if not set)", "PUBLIC_IP": "boot job instances with a public ip, disable this for NAT (default true)", "PUBLIC_IP_CONNECT": "connect to the public ip of the instance instead of the internal, only takes effect if PUBLIC_IP is true (default true)", "IMAGE_PROJECT_ID": "GCE project id to use for images, will use PROJECT_ID if not specified", @@ -107,8 +107,8 @@ var ( "RATE_LIMIT_DYNAMIC_CONFIG": "get max-calls and duration dynamically through redis (default false)", "RATE_LIMIT_DYNAMIC_CONFIG_TTL": fmt.Sprintf("time to cache dynamic config for (default %v)", defaultGCERateLimitDynamicConfigTTL), - "BACKOFF_RETRY_MAX": "Maximum allowed duration of generic exponential backoff retries (default 1m)", - "REGION": fmt.Sprintf("only takes effect when SUBNETWORK is defined; region in which to deploy (default %v)", defaultGCERegion), + "BACKOFF_RETRY_MAX": "maximum allowed duration of generic exponential backoff retries (default 1m)", + "REGION": "[REQUIRED] region in which to deploy", "SKIP_STOP_POLL": "immediately return after issuing first instance deletion request (default false)", "SSH_DIAL_TIMEOUT": fmt.Sprintf("connection timeout for ssh connections (default %v)", defaultGCESSHDialTimeout), "STOP_POLL_SLEEP": fmt.Sprintf("sleep interval between polling server for instance stop status (default %v)", defaultGCEStopPollSleep), @@ -119,7 +119,7 @@ var ( "WARMER_URL": "URL for warmer service", "WARMER_TIMEOUT": fmt.Sprintf("timeout for requests to warmer service (default %v)", defaultGCEWarmerTimeout), "WARMER_SSH_PASSPHRASE": fmt.Sprintf("The passphrase used to decipher instace SSH keys"), - "ZONE": fmt.Sprintf("zone name (default %q)", defaultGCEZone), + "ZONE": "zone in which to deploy job instances into (default is to use all zones in the region)", } errGCEMissingIPAddressError = fmt.Errorf("no IP address found") @@ -182,7 +182,7 @@ type gceProvider struct { imageProjectID string ic *gceInstanceConfig cfg *config.ProviderConfig - alternateZones []string + allZonesForRegion []string machineTypeSelfLinks map[string]string backoffRetryMax time.Duration @@ -210,6 +210,7 @@ type gceProvider struct { } type gceInstanceConfig struct { + MinimumCpuPlatform string MachineType string PremiumMachineType string Zone *compute.Zone @@ -255,7 +256,6 @@ type gceStartContext struct { instanceWarmedIP string windowsPassword string zoneName string - zonePinned bool machineType string premiumMachineType string } @@ -310,6 +310,10 @@ func newGCEProvider(cfg *config.ProviderConfig) (Provider, error) { return nil, err } + if !cfg.IsSet("ACCOUNT_JSON") { + return nil, fmt.Errorf("missing ACCOUNT_JSON") + } + projectID := cfg.Get("PROJECT_ID") if metadata.OnGCE() { projectID, err = metadata.ProjectID() @@ -326,18 +330,32 @@ func newGCEProvider(cfg *config.ProviderConfig) (Provider, error) { imageProjectID = cfg.Get("IMAGE_PROJECT_ID") } - zoneName := defaultGCEZone - if metadata.OnGCE() { - zoneName, err = metadata.Zone() + region := "" + if cfg.IsSet("REGION") { + region = cfg.Get("REGION") + } else if metadata.OnGCE() { + zoneName, err := metadata.Zone() if err != nil { return nil, errors.Wrap(err, "could not get zone from metadata api") } + zone, zErr := client.Zones.Get(projectID, zoneName).Do() + if zErr != nil { + return nil, errors.Wrap(zErr, "could not get zone from compute api") + } + zoneURI := fmt.Sprintf("https://www.googleapis.com/compute/v1/projects/%s/regions/(.*)", projectID) + quotedZoneURI := strings.Replace(zoneURI, "/", "\\/", -1) + re := regexp.MustCompile(quotedZoneURI) + region = re.FindStringSubmatch(zone.Region)[1] + } else { + return nil, fmt.Errorf("missing REGION") } - if cfg.IsSet("ZONE") { - zoneName = cfg.Get("ZONE") - } - cfg.Set("ZONE", zoneName) + cfg.Set("REGION", region) + + minimumCpuPlatform := "" + if cfg.IsSet("MINIMUM_CPU_PLATFORM") { + minimumCpuPlatform = cfg.Get("MINIMUM_CPU_PLATFORM") + } mtName := defaultGCEMachineType if cfg.IsSet("MACHINE_TYPE") { @@ -597,7 +615,7 @@ func newGCEProvider(cfg *config.ProviderConfig) (Provider, error) { projectID: projectID, imageProjectID: imageProjectID, cfg: cfg, - alternateZones: []string{}, + allZonesForRegion: []string{}, machineTypeSelfLinks: map[string]string{}, sshDialer: sshDialer, sshDialTimeout: sshDialTimeout, @@ -614,6 +632,7 @@ func newGCEProvider(cfg *config.ProviderConfig) (Provider, error) { SkipStopPoll: skipStopPoll, Site: site, AcceleratorConfig: defaultAcceleratorConfig, + MinimumCpuPlatform: minimumCpuPlatform, MachineType: mtName, PremiumMachineType: premiumMTName, }, @@ -697,27 +716,29 @@ func (p *gceProvider) apiRateLimit(ctx gocontext.Context) error { func (p *gceProvider) Setup(ctx gocontext.Context) error { logger := context.LoggerFromContext(ctx).WithField("self", "backend/gce_provider") - logger.WithField("zone", p.cfg.Get("ZONE")).Debug("resolving configured zone") + if p.cfg.Get("ZONE") != "" { + logger.WithField("zone", p.cfg.Get("ZONE")).Debug("resolving configured zone") + + err := p.backoffRetry(ctx, func() error { + _ = p.apiRateLimit(ctx) + zone, zErr := p.client.Zones. + Get(p.projectID, p.cfg.Get("ZONE")). + Context(ctx). + Do() + if zErr == nil { + p.ic.Zone = zone + } + return zErr + }) - err := p.backoffRetry(ctx, func() error { - _ = p.apiRateLimit(ctx) - zone, zErr := p.client.Zones. - Get(p.projectID, p.cfg.Get("ZONE")). - Context(ctx). - Do() - if zErr == nil { - p.ic.Zone = zone + if err != nil { + return errors.Wrap(err, "failed to resolve configured zone") } - return zErr - }) - - if err != nil { - return errors.Wrap(err, "failed to resolve configured zone") } logger.WithField("network", p.cfg.Get("NETWORK")).Debug("resolving configured network") - err = p.backoffRetry(ctx, func() error { + err := p.backoffRetry(ctx, func() error { _ = p.apiRateLimit(ctx) nw, nwErr := p.client.Networks. Get(p.projectID, p.cfg.Get("NETWORK")). @@ -730,17 +751,18 @@ func (p *gceProvider) Setup(ctx gocontext.Context) error { }) if err != nil { - return errors.Wrap(err, "failed te resolve configured network") + return errors.Wrap(err, "failed to resolve configured network") } - region := defaultGCERegion - if metadata.OnGCE() { - logger.WithField("region", p.ic.Zone.Region).Debug("setting region from zone when on gce") - region = p.ic.Zone.Region - } + region := "" if p.cfg.IsSet("REGION") { - logger.WithField("region", p.ic.Zone.Region).Debug("setting region from config") + logger.WithField("region", p.cfg.Get("REGION")).Info("setting region from config") region = p.cfg.Get("REGION") + } else if metadata.OnGCE() { + logger.WithField("region", p.ic.Zone.Region).Info("setting region from zone when on gce") + region = p.ic.Zone.Region + } else { + return errors.Wrap(err, "failed to resolve configured region") } if p.cfg.IsSet("SUBNETWORK") { @@ -763,33 +785,40 @@ func (p *gceProvider) Setup(ctx gocontext.Context) error { } } - logger.Debug("finding alternate zones") + logger.Debug("finding all zones for region") err = p.backoffRetry(ctx, func() error { _ = p.apiRateLimit(ctx) + + regionURL := fmt.Sprintf("https://www.googleapis.com/compute/v1/projects/%s/regions/%s", p.projectID, p.cfg.Get("REGION")) + zl, zlErr := p.client.Zones.List(p.projectID). Context(ctx). Filter("status eq UP"). - Filter(fmt.Sprintf("region eq %s", p.ic.Zone.Region)).Do() + Filter(fmt.Sprintf("region eq %s", regionURL)).Do() if zlErr != nil { return zlErr } - p.alternateZones = []string{} + p.allZonesForRegion = []string{} for _, z := range zl.Items { - p.alternateZones = append(p.alternateZones, z.Name) + p.allZonesForRegion = append(p.allZonesForRegion, z.Name) + } + + if len(p.allZonesForRegion) == 0 { + return fmt.Errorf("no zones found for region %s", p.cfg.Get("REGION")) } return nil }) if err != nil { - return errors.Wrap(err, "failed to find alternate zones") + return errors.Wrap(err, "failed to find zones for region") } logger.Debug("building machine type self link map") - for _, zoneName := range append([]string{p.ic.Zone.Name}, p.alternateZones...) { + for _, zoneName := range p.allZonesForRegion { for _, machineType := range []string{p.ic.MachineType, p.ic.PremiumMachineType} { if zoneName == "" || machineType == "" { continue @@ -913,7 +942,6 @@ func (p *gceProvider) StartWithProgress(ctx gocontext.Context, startAttributes * c := &gceStartContext{ startAttributes: startAttributes, - zoneName: p.ic.Zone.Name, machineType: p.ic.MachineType, premiumMachineType: p.ic.PremiumMachineType, progresser: progresser, @@ -1045,22 +1073,7 @@ func (p *gceProvider) stepInsertInstance(c *gceStartContext) multistep.StepActio logger := context.LoggerFromContext(c.ctx).WithField("self", "backend/gce_provider") - if c.startAttributes.VMConfig.Zone != "" { - err := p.backoffRetry(ctx, func() error { - _ = p.apiRateLimit(ctx) - zone, zErr := p.client.Zones.Get(p.projectID, c.startAttributes.VMConfig.Zone).Context(ctx).Do() - if zErr != nil { - return zErr - } - c.zoneName = zone.Name - c.zonePinned = true - return nil - }) - - if err != nil { - return multistep.ActionHalt - } - } + c.zoneName = p.pickZone("") inst, err := p.buildInstance(ctx, c) if err != nil { @@ -1142,15 +1155,13 @@ func (p *gceProvider) stepInsertInstance(c *gceStartContext) multistep.StepActio op, insErr := p.client.Instances.Insert(p.projectID, c.zoneName, c.instance).Context(c.ctx).Do() if insErr != nil { - if !c.zonePinned { - altZone := p.pickAlternateZone(c.zoneName) - logger.WithFields(logrus.Fields{ - "err": insErr, - "prev_zone": c.zoneName, - "next_zone": altZone, - }).Warn("switching zones due to error") - p.setStartContextZone(c, altZone) - } + nextZone := p.pickZone(c.zoneName) + logger.WithFields(logrus.Fields{ + "err": insErr, + "prev_zone": c.zoneName, + "next_zone": nextZone, + }).Warn("switching zones due to error") + p.setStartContextZone(c, nextZone) return insErr } @@ -1451,6 +1462,10 @@ func (p *gceProvider) buildInstance(ctx gocontext.Context, c *gceStartContext) ( }, } + if p.ic.MinimumCpuPlatform != "" { + inst.MinCpuPlatform = p.ic.MinimumCpuPlatform + } + machineType := p.ic.MachineType if c.startAttributes.VMType == "premium" { machineType = p.ic.PremiumMachineType @@ -1614,20 +1629,12 @@ func (p *gceProvider) warmerRequestInstance(ctx gocontext.Context, zone string, return warmerRes, nil } -func (p *gceProvider) pickAlternateZone(zoneName string) string { - if len(p.alternateZones) == 0 { - return zoneName +func (p *gceProvider) pickZone(zoneName string) string { + if p.cfg.Get("ZONE") != "" { + return p.cfg.Get("ZONE") } - for { - altZone := p.alternateZones[mathrand.Intn(len(p.alternateZones))] - if altZone != zoneName { - return altZone - } - if len(p.alternateZones) == 1 { - return zoneName - } - } + return p.allZonesForRegion[mathrand.Intn(len(p.allZonesForRegion))] } func (p *gceProvider) setStartContextZone(c *gceStartContext, zoneName string) { diff --git a/backend/gce_test.go b/backend/gce_test.go index e66510cf3..3e6e5a03f 100644 --- a/backend/gce_test.go +++ b/backend/gce_test.go @@ -74,6 +74,7 @@ func gceTestSetup(t *testing.T, cfg *config.ProviderConfig, resp *gceTestRespons "PROJECT_ID": "project_id", "IMAGE_ALIASES": "foo", "IMAGE_ALIAS_FOO": "default", + "REGION": "us-central1", }) }